author     xiubuzhe <xiubuzhe@sina.com>  2023-10-08 20:59:00 +0800
committer  xiubuzhe <xiubuzhe@sina.com>  2023-10-08 20:59:00 +0800
commit     1dac2263372df2b85db5d029a45721fa158a5c9d (patch)
tree       0365f9c57df04178a726d7584ca6a6b955a7ce6a /lib/sqlalchemy
parent     b494be364bb39e1de128ada7dc576a729d99907e (diff)
first add files
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--  lib/sqlalchemy/__init__.py | 158
-rwxr-xr-x  lib/sqlalchemy/cimmutabledict.cpython-39-x86_64-linux-gnu.so | bin 0 -> 53952 bytes
-rw-r--r--  lib/sqlalchemy/connectors/__init__.py | 10
-rw-r--r--  lib/sqlalchemy/connectors/mxodbc.py | 166
-rw-r--r--  lib/sqlalchemy/connectors/pyodbc.py | 193
-rwxr-xr-x  lib/sqlalchemy/cprocessors.cpython-39-x86_64-linux-gnu.so | bin 0 -> 60640 bytes
-rwxr-xr-x  lib/sqlalchemy/cresultproxy.cpython-39-x86_64-linux-gnu.so | bin 0 -> 92632 bytes
-rw-r--r--  lib/sqlalchemy/databases/__init__.py | 38
-rw-r--r--  lib/sqlalchemy/dialects/__init__.py | 72
-rw-r--r--  lib/sqlalchemy/dialects/firebird/__init__.py | 41
-rw-r--r--  lib/sqlalchemy/dialects/firebird/base.py | 989
-rw-r--r--  lib/sqlalchemy/dialects/firebird/fdb.py | 112
-rw-r--r--  lib/sqlalchemy/dialects/firebird/kinterbasdb.py | 202
-rw-r--r--  lib/sqlalchemy/dialects/mssql/__init__.py | 85
-rw-r--r--  lib/sqlalchemy/dialects/mssql/base.py | 3545
-rw-r--r--  lib/sqlalchemy/dialects/mssql/information_schema.py | 232
-rw-r--r--  lib/sqlalchemy/dialects/mssql/json.py | 125
-rw-r--r--  lib/sqlalchemy/dialects/mssql/mxodbc.py | 150
-rw-r--r--  lib/sqlalchemy/dialects/mssql/provision.py | 116
-rw-r--r--  lib/sqlalchemy/dialects/mssql/pymssql.py | 138
-rw-r--r--  lib/sqlalchemy/dialects/mssql/pyodbc.py | 673
-rw-r--r--  lib/sqlalchemy/dialects/mysql/__init__.py | 103
-rw-r--r--  lib/sqlalchemy/dialects/mysql/aiomysql.py | 317
-rw-r--r--  lib/sqlalchemy/dialects/mysql/asyncmy.py | 328
-rw-r--r--  lib/sqlalchemy/dialects/mysql/base.py | 3306
-rw-r--r--  lib/sqlalchemy/dialects/mysql/cymysql.py | 82
-rw-r--r--  lib/sqlalchemy/dialects/mysql/dml.py | 175
-rw-r--r--  lib/sqlalchemy/dialects/mysql/enumerated.py | 263
-rw-r--r--  lib/sqlalchemy/dialects/mysql/expression.py | 130
-rw-r--r--  lib/sqlalchemy/dialects/mysql/json.py | 84
-rw-r--r--  lib/sqlalchemy/dialects/mysql/mariadb.py | 25
-rw-r--r--  lib/sqlalchemy/dialects/mysql/mariadbconnector.py | 240
-rw-r--r--  lib/sqlalchemy/dialects/mysql/mysqlconnector.py | 240
-rw-r--r--  lib/sqlalchemy/dialects/mysql/mysqldb.py | 331
-rw-r--r--  lib/sqlalchemy/dialects/mysql/oursql.py | 273
-rw-r--r--  lib/sqlalchemy/dialects/mysql/provision.py | 78
-rw-r--r--  lib/sqlalchemy/dialects/mysql/pymysql.py | 98
-rw-r--r--  lib/sqlalchemy/dialects/mysql/pyodbc.py | 136
-rw-r--r--  lib/sqlalchemy/dialects/mysql/reflection.py | 558
-rw-r--r--  lib/sqlalchemy/dialects/mysql/reserved_words.py | 564
-rw-r--r--  lib/sqlalchemy/dialects/mysql/types.py | 773
-rw-r--r--  lib/sqlalchemy/dialects/oracle/__init__.py | 58
-rw-r--r--  lib/sqlalchemy/dialects/oracle/base.py | 2522
-rw-r--r--  lib/sqlalchemy/dialects/oracle/cx_oracle.py | 1424
-rw-r--r--  lib/sqlalchemy/dialects/oracle/provision.py | 160
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/__init__.py | 117
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/array.py | 413
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/asyncpg.py | 1112
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/base.py | 4651
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/dml.py | 274
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/ext.py | 277
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/hstore.py | 455
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/json.py | 327
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/pg8000.py | 594
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/provision.py | 124
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/psycopg2.py | 1088
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py | 60
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/pygresql.py | 278
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/pypostgresql.py | 126
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/ranges.py | 138
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/__init__.py | 58
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/aiosqlite.py | 335
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/base.py | 2556
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/dml.py | 200
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/json.py | 84
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/provision.py | 142
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/pysqlcipher.py | 164
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/pysqlite.py | 613
-rw-r--r--  lib/sqlalchemy/dialects/sybase/__init__.py | 67
-rw-r--r--  lib/sqlalchemy/dialects/sybase/base.py | 1100
-rw-r--r--  lib/sqlalchemy/dialects/sybase/mxodbc.py | 34
-rw-r--r--  lib/sqlalchemy/dialects/sybase/pyodbc.py | 89
-rw-r--r--  lib/sqlalchemy/dialects/sybase/pysybase.py | 106
-rw-r--r--  lib/sqlalchemy/engine/__init__.py | 62
-rw-r--r--  lib/sqlalchemy/engine/base.py | 3450
-rw-r--r--  lib/sqlalchemy/engine/characteristics.py | 56
-rw-r--r--  lib/sqlalchemy/engine/create.py | 743
-rw-r--r--  lib/sqlalchemy/engine/cursor.py | 1942
-rw-r--r--  lib/sqlalchemy/engine/default.py | 1936
-rw-r--r--  lib/sqlalchemy/engine/events.py | 835
-rw-r--r--  lib/sqlalchemy/engine/interfaces.py | 1719
-rw-r--r--  lib/sqlalchemy/engine/mock.py | 118
-rw-r--r--  lib/sqlalchemy/engine/reflection.py | 1160
-rw-r--r--  lib/sqlalchemy/engine/result.py | 1857
-rw-r--r--  lib/sqlalchemy/engine/row.py | 621
-rw-r--r--  lib/sqlalchemy/engine/strategies.py | 17
-rw-r--r--  lib/sqlalchemy/engine/url.py | 806
-rw-r--r--  lib/sqlalchemy/engine/util.py | 253
-rw-r--r--  lib/sqlalchemy/event/__init__.py | 17
-rw-r--r--  lib/sqlalchemy/event/api.py | 219
-rw-r--r--  lib/sqlalchemy/event/attr.py | 468
-rw-r--r--  lib/sqlalchemy/event/base.py | 345
-rw-r--r--  lib/sqlalchemy/event/legacy.py | 185
-rw-r--r--  lib/sqlalchemy/event/registry.py | 297
-rw-r--r--  lib/sqlalchemy/events.py | 14
-rw-r--r--  lib/sqlalchemy/exc.py | 733
-rw-r--r--  lib/sqlalchemy/ext/__init__.py | 11
-rw-r--r--  lib/sqlalchemy/ext/associationproxy.py | 1627
-rw-r--r--  lib/sqlalchemy/ext/asyncio/__init__.py | 22
-rw-r--r--  lib/sqlalchemy/ext/asyncio/base.py | 89
-rw-r--r--  lib/sqlalchemy/ext/asyncio/engine.py | 828
-rw-r--r--  lib/sqlalchemy/ext/asyncio/events.py | 44
-rw-r--r--  lib/sqlalchemy/ext/asyncio/exc.py | 21
-rw-r--r--  lib/sqlalchemy/ext/asyncio/result.py | 671
-rw-r--r--  lib/sqlalchemy/ext/asyncio/scoping.py | 107
-rw-r--r--  lib/sqlalchemy/ext/asyncio/session.py | 759
-rw-r--r--  lib/sqlalchemy/ext/automap.py | 1234
-rw-r--r--  lib/sqlalchemy/ext/baked.py | 648
-rw-r--r--  lib/sqlalchemy/ext/compiler.py | 613
-rw-r--r--  lib/sqlalchemy/ext/declarative/__init__.py | 64
-rw-r--r--  lib/sqlalchemy/ext/declarative/extensions.py | 463
-rw-r--r--  lib/sqlalchemy/ext/horizontal_shard.py | 256
-rw-r--r--  lib/sqlalchemy/ext/hybrid.py | 1206
-rw-r--r--  lib/sqlalchemy/ext/indexable.py | 352
-rw-r--r--  lib/sqlalchemy/ext/instrumentation.py | 416
-rw-r--r--  lib/sqlalchemy/ext/mutable.py | 958
-rw-r--r--  lib/sqlalchemy/ext/mypy/__init__.py | 0
-rw-r--r--  lib/sqlalchemy/ext/mypy/apply.py | 299
-rw-r--r--  lib/sqlalchemy/ext/mypy/decl_class.py | 516
-rw-r--r--  lib/sqlalchemy/ext/mypy/infer.py | 556
-rw-r--r--  lib/sqlalchemy/ext/mypy/names.py | 253
-rw-r--r--  lib/sqlalchemy/ext/mypy/plugin.py | 284
-rw-r--r--  lib/sqlalchemy/ext/mypy/util.py | 305
-rw-r--r--  lib/sqlalchemy/ext/orderinglist.py | 388
-rw-r--r--  lib/sqlalchemy/ext/serializer.py | 177
-rw-r--r--  lib/sqlalchemy/future/__init__.py | 18
-rw-r--r--  lib/sqlalchemy/future/engine.py | 413
-rw-r--r--  lib/sqlalchemy/future/orm/__init__.py | 10
-rw-r--r--  lib/sqlalchemy/inspection.py | 93
-rw-r--r--  lib/sqlalchemy/log.py | 241
-rw-r--r--  lib/sqlalchemy/orm/__init__.py | 344
-rw-r--r--  lib/sqlalchemy/orm/attributes.py | 2331
-rw-r--r--  lib/sqlalchemy/orm/base.py | 572
-rw-r--r--  lib/sqlalchemy/orm/clsregistry.py | 441
-rw-r--r--  lib/sqlalchemy/orm/collections.py | 1706
-rw-r--r--  lib/sqlalchemy/orm/context.py | 3136
-rw-r--r--  lib/sqlalchemy/orm/decl_api.py | 1062
-rw-r--r--  lib/sqlalchemy/orm/decl_base.py | 1210
-rw-r--r--  lib/sqlalchemy/orm/dependency.py | 1290
-rw-r--r--  lib/sqlalchemy/orm/descriptor_props.py | 745
-rw-r--r--  lib/sqlalchemy/orm/dynamic.py | 491
-rw-r--r--  lib/sqlalchemy/orm/evaluator.py | 241
-rw-r--r--  lib/sqlalchemy/orm/events.py | 2876
-rw-r--r--  lib/sqlalchemy/orm/exc.py | 204
-rw-r--r--  lib/sqlalchemy/orm/identity.py | 254
-rw-r--r--  lib/sqlalchemy/orm/instrumentation.py | 652
-rw-r--r--  lib/sqlalchemy/orm/interfaces.py | 978
-rw-r--r--  lib/sqlalchemy/orm/loading.py | 1465
-rw-r--r--  lib/sqlalchemy/orm/mapper.py | 3658
-rw-r--r--  lib/sqlalchemy/orm/path_registry.py | 519
-rw-r--r--  lib/sqlalchemy/orm/persistence.py | 2517
-rw-r--r--  lib/sqlalchemy/orm/properties.py | 430
-rw-r--r--  lib/sqlalchemy/orm/query.py | 3508
-rw-r--r--  lib/sqlalchemy/orm/relationships.py | 3684
-rw-r--r--  lib/sqlalchemy/orm/scoping.py | 228
-rw-r--r--  lib/sqlalchemy/orm/session.py | 4386
-rw-r--r--  lib/sqlalchemy/orm/state.py | 1025
-rw-r--r--  lib/sqlalchemy/orm/strategies.py | 3141
-rw-r--r--  lib/sqlalchemy/orm/strategy_options.py | 2008
-rw-r--r--  lib/sqlalchemy/orm/sync.py | 167
-rw-r--r--  lib/sqlalchemy/orm/unitofwork.py | 784
-rw-r--r--  lib/sqlalchemy/orm/util.py | 2149
-rw-r--r--  lib/sqlalchemy/pool/__init__.py | 56
-rw-r--r--  lib/sqlalchemy/pool/base.py | 1121
-rw-r--r--  lib/sqlalchemy/pool/dbapi_proxy.py | 147
-rw-r--r--  lib/sqlalchemy/pool/events.py | 284
-rw-r--r--  lib/sqlalchemy/pool/impl.py | 514
-rw-r--r--  lib/sqlalchemy/processors.py | 176
-rw-r--r--  lib/sqlalchemy/schema.py | 59
-rw-r--r--  lib/sqlalchemy/sql/__init__.py | 150
-rw-r--r--  lib/sqlalchemy/sql/annotation.py | 364
-rw-r--r--  lib/sqlalchemy/sql/base.py | 1702
-rw-r--r--  lib/sqlalchemy/sql/coercions.py | 1096
-rw-r--r--  lib/sqlalchemy/sql/compiler.py | 5525
-rw-r--r--  lib/sqlalchemy/sql/crud.py | 1091
-rw-r--r--  lib/sqlalchemy/sql/ddl.py | 1341
-rw-r--r--  lib/sqlalchemy/sql/default_comparator.py | 360
-rw-r--r--  lib/sqlalchemy/sql/dml.py | 1514
-rw-r--r--  lib/sqlalchemy/sql/elements.py | 5415
-rw-r--r--  lib/sqlalchemy/sql/events.py | 331
-rw-r--r--  lib/sqlalchemy/sql/expression.py | 278
-rw-r--r--  lib/sqlalchemy/sql/functions.py | 1575
-rw-r--r--  lib/sqlalchemy/sql/lambdas.py | 1314
-rw-r--r--  lib/sqlalchemy/sql/naming.py | 210
-rw-r--r--  lib/sqlalchemy/sql/operators.py | 1688
-rw-r--r--  lib/sqlalchemy/sql/roles.py | 239
-rw-r--r--  lib/sqlalchemy/sql/schema.py | 5268
-rw-r--r--  lib/sqlalchemy/sql/selectable.py | 6946
-rw-r--r--  lib/sqlalchemy/sql/sqltypes.py | 3351
-rw-r--r--  lib/sqlalchemy/sql/traversals.py | 1559
-rw-r--r--  lib/sqlalchemy/sql/type_api.py | 1974
-rw-r--r--  lib/sqlalchemy/sql/util.py | 1120
-rw-r--r--  lib/sqlalchemy/sql/visitors.py | 852
-rw-r--r--  lib/sqlalchemy/testing/__init__.py | 86
-rw-r--r--  lib/sqlalchemy/testing/assertions.py | 845
-rw-r--r--  lib/sqlalchemy/testing/assertsql.py | 457
-rw-r--r--  lib/sqlalchemy/testing/asyncio.py | 128
-rw-r--r--  lib/sqlalchemy/testing/config.py | 209
-rw-r--r--  lib/sqlalchemy/testing/engines.py | 465
-rw-r--r--  lib/sqlalchemy/testing/entities.py | 111
-rw-r--r--  lib/sqlalchemy/testing/exclusions.py | 465
-rw-r--r--  lib/sqlalchemy/testing/fixtures.py | 870
-rw-r--r--  lib/sqlalchemy/testing/mock.py | 32
-rw-r--r--  lib/sqlalchemy/testing/pickleable.py | 151
-rw-r--r--  lib/sqlalchemy/testing/plugin/__init__.py | 0
-rw-r--r--  lib/sqlalchemy/testing/plugin/bootstrap.py | 54
-rw-r--r--  lib/sqlalchemy/testing/plugin/plugin_base.py | 789
-rw-r--r--  lib/sqlalchemy/testing/plugin/pytestplugin.py | 820
-rw-r--r--  lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py | 112
-rw-r--r--  lib/sqlalchemy/testing/profiling.py | 335
-rw-r--r--  lib/sqlalchemy/testing/provision.py | 416
-rw-r--r--  lib/sqlalchemy/testing/requirements.py | 1518
-rw-r--r--  lib/sqlalchemy/testing/schema.py | 218
-rw-r--r--  lib/sqlalchemy/testing/suite/__init__.py | 13
-rw-r--r--  lib/sqlalchemy/testing/suite/test_cte.py | 204
-rw-r--r--  lib/sqlalchemy/testing/suite/test_ddl.py | 381
-rw-r--r--  lib/sqlalchemy/testing/suite/test_deprecations.py | 145
-rw-r--r--  lib/sqlalchemy/testing/suite/test_dialect.py | 361
-rw-r--r--  lib/sqlalchemy/testing/suite/test_insert.py | 367
-rw-r--r--  lib/sqlalchemy/testing/suite/test_reflection.py | 1738
-rw-r--r--  lib/sqlalchemy/testing/suite/test_results.py | 426
-rw-r--r--  lib/sqlalchemy/testing/suite/test_rowcount.py | 165
-rw-r--r--  lib/sqlalchemy/testing/suite/test_select.py | 1783
-rw-r--r--  lib/sqlalchemy/testing/suite/test_sequence.py | 282
-rw-r--r--  lib/sqlalchemy/testing/suite/test_types.py | 1508
-rw-r--r--  lib/sqlalchemy/testing/suite/test_unicode_ddl.py | 206
-rw-r--r--  lib/sqlalchemy/testing/suite/test_update_delete.py | 60
-rw-r--r--  lib/sqlalchemy/testing/util.py | 458
-rw-r--r--  lib/sqlalchemy/testing/warnings.py | 82
-rw-r--r--  lib/sqlalchemy/types.py | 119
-rw-r--r--  lib/sqlalchemy/util/__init__.py | 175
-rw-r--r--  lib/sqlalchemy/util/_collections.py | 1089
-rw-r--r--  lib/sqlalchemy/util/_compat_py3k.py | 67
-rw-r--r--  lib/sqlalchemy/util/_concurrency_py3k.py | 194
-rw-r--r--  lib/sqlalchemy/util/_preloaded.py | 68
-rw-r--r--  lib/sqlalchemy/util/compat.py | 632
-rw-r--r--  lib/sqlalchemy/util/concurrency.py | 73
-rw-r--r--  lib/sqlalchemy/util/deprecations.py | 417
-rw-r--r--  lib/sqlalchemy/util/langhelpers.py | 1945
-rw-r--r--  lib/sqlalchemy/util/queue.py | 291
-rw-r--r--  lib/sqlalchemy/util/topological.py | 100
241 files changed, 183942 insertions, 0 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
new file mode 100644
index 0000000..3cae9f5
--- /dev/null
+++ b/lib/sqlalchemy/__init__.py
@@ -0,0 +1,158 @@
+# sqlalchemy/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import util as _util
+from .engine import create_engine
+from .engine import create_mock_engine
+from .engine import engine_from_config
+from .inspection import inspect
+from .schema import BLANK_SCHEMA
+from .schema import CheckConstraint
+from .schema import Column
+from .schema import ColumnDefault
+from .schema import Computed
+from .schema import Constraint
+from .schema import DDL
+from .schema import DefaultClause
+from .schema import FetchedValue
+from .schema import ForeignKey
+from .schema import ForeignKeyConstraint
+from .schema import Identity
+from .schema import Index
+from .schema import MetaData
+from .schema import PrimaryKeyConstraint
+from .schema import Sequence
+from .schema import Table
+from .schema import ThreadLocalMetaData
+from .schema import UniqueConstraint
+from .sql import alias
+from .sql import all_
+from .sql import and_
+from .sql import any_
+from .sql import asc
+from .sql import between
+from .sql import bindparam
+from .sql import case
+from .sql import cast
+from .sql import collate
+from .sql import column
+from .sql import delete
+from .sql import desc
+from .sql import distinct
+from .sql import except_
+from .sql import except_all
+from .sql import exists
+from .sql import extract
+from .sql import false
+from .sql import func
+from .sql import funcfilter
+from .sql import insert
+from .sql import intersect
+from .sql import intersect_all
+from .sql import join
+from .sql import LABEL_STYLE_DEFAULT
+from .sql import LABEL_STYLE_DISAMBIGUATE_ONLY
+from .sql import LABEL_STYLE_NONE
+from .sql import LABEL_STYLE_TABLENAME_PLUS_COL
+from .sql import lambda_stmt
+from .sql import lateral
+from .sql import literal
+from .sql import literal_column
+from .sql import modifier
+from .sql import not_
+from .sql import null
+from .sql import nulls_first
+from .sql import nulls_last
+from .sql import nullsfirst
+from .sql import nullslast
+from .sql import or_
+from .sql import outerjoin
+from .sql import outparam
+from .sql import over
+from .sql import select
+from .sql import subquery
+from .sql import table
+from .sql import tablesample
+from .sql import text
+from .sql import true
+from .sql import tuple_
+from .sql import type_coerce
+from .sql import union
+from .sql import union_all
+from .sql import update
+from .sql import values
+from .sql import within_group
+from .types import ARRAY
+from .types import BIGINT
+from .types import BigInteger
+from .types import BINARY
+from .types import BLOB
+from .types import BOOLEAN
+from .types import Boolean
+from .types import CHAR
+from .types import CLOB
+from .types import DATE
+from .types import Date
+from .types import DATETIME
+from .types import DateTime
+from .types import DECIMAL
+from .types import Enum
+from .types import FLOAT
+from .types import Float
+from .types import INT
+from .types import INTEGER
+from .types import Integer
+from .types import Interval
+from .types import JSON
+from .types import LargeBinary
+from .types import NCHAR
+from .types import NUMERIC
+from .types import Numeric
+from .types import NVARCHAR
+from .types import PickleType
+from .types import REAL
+from .types import SMALLINT
+from .types import SmallInteger
+from .types import String
+from .types import TEXT
+from .types import Text
+from .types import TIME
+from .types import Time
+from .types import TIMESTAMP
+from .types import TupleType
+from .types import TypeDecorator
+from .types import Unicode
+from .types import UnicodeText
+from .types import VARBINARY
+from .types import VARCHAR
+
+
+__version__ = "1.4.40"
+
+
+def __go(lcls):
+ global __all__
+
+ from . import events
+ from . import util as _sa_util
+
+ import inspect as _inspect
+
+ __all__ = sorted(
+ name
+ for name, obj in lcls.items()
+ if not (name.startswith("_") or _inspect.ismodule(obj))
+ )
+
+ _sa_util.preloaded.import_prefix("sqlalchemy")
+
+ from . import exc
+
+ exc._version_token = "".join(__version__.split(".")[0:2])
+
+
+__go(locals())
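
A minimal usage sketch (not part of the patch itself) of what the package-level
re-exports and ``__go()`` above provide: every public name imported in
``__init__.py`` is reachable directly from the ``sqlalchemy`` namespace, and
``__all__`` is built from the non-underscore, non-module names::

    import sqlalchemy
    from sqlalchemy import create_engine, select, Table  # re-exported above

    print(sqlalchemy.__version__)                  # "1.4.40" in this tree
    print("create_engine" in sqlalchemy.__all__)   # True
    print("util" in sqlalchemy.__all__)            # False: modules are filtered out
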
diff --git a/lib/sqlalchemy/cimmutabledict.cpython-39-x86_64-linux-gnu.so b/lib/sqlalchemy/cimmutabledict.cpython-39-x86_64-linux-gnu.so
new file mode 100755
index 0000000..58e90e7
--- /dev/null
+++ b/lib/sqlalchemy/cimmutabledict.cpython-39-x86_64-linux-gnu.so
Binary files differ
diff --git a/lib/sqlalchemy/connectors/__init__.py b/lib/sqlalchemy/connectors/__init__.py
new file mode 100644
index 0000000..e738086
--- /dev/null
+++ b/lib/sqlalchemy/connectors/__init__.py
@@ -0,0 +1,10 @@
+# connectors/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+class Connector(object):
+ pass
diff --git a/lib/sqlalchemy/connectors/mxodbc.py b/lib/sqlalchemy/connectors/mxodbc.py
new file mode 100644
index 0000000..89b3484
--- /dev/null
+++ b/lib/sqlalchemy/connectors/mxodbc.py
@@ -0,0 +1,166 @@
+# connectors/mxodbc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+Provide a SQLAlchemy connector for the eGenix mxODBC commercial
+Python adapter for ODBC. This is not a free product, but eGenix
+provides SQLAlchemy with a license for use in continuous integration
+testing.
+
+This has been tested for use with mxODBC 3.1.2 on SQL Server 2005
+and 2008, using the SQL Server Native driver. However, it is
+possible for this to be used on other database platforms.
+
+For more info on mxODBC, see https://www.egenix.com/
+
+.. deprecated:: 1.4 The mxODBC DBAPI is deprecated and will be removed
+ in a future version. Please use one of the supported DBAPIs to
+ connect to mssql.
+
+"""
+
+import re
+import sys
+import warnings
+
+from . import Connector
+from ..util import warn_deprecated
+
+
+class MxODBCConnector(Connector):
+ driver = "mxodbc"
+
+ supports_sane_multi_rowcount = False
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+
+ supports_native_decimal = True
+
+ @classmethod
+ def dbapi(cls):
+ # this classmethod will normally be replaced by an instance
+ # attribute of the same name, so this is normally only called once.
+ cls._load_mx_exceptions()
+ platform = sys.platform
+ if platform == "win32":
+ from mx.ODBC import Windows as Module
+ # this can be the string "linux2", and possibly others
+ elif "linux" in platform:
+ from mx.ODBC import unixODBC as Module
+ elif platform == "darwin":
+ from mx.ODBC import iODBC as Module
+ else:
+ raise ImportError("Unrecognized platform for mxODBC import")
+
+        warn_deprecated(
+            "The mxODBC DBAPI is deprecated and will be removed "
+            "in a future version. Please use one of the supported DBAPIs to "
+            "connect to mssql.",
+            version="1.4",
+        )
+        return Module
+
+ @classmethod
+ def _load_mx_exceptions(cls):
+ """Import mxODBC exception classes into the module namespace,
+ as if they had been imported normally. This is done here
+ to avoid requiring all SQLAlchemy users to install mxODBC.
+ """
+ global InterfaceError, ProgrammingError
+ from mx.ODBC import InterfaceError
+ from mx.ODBC import ProgrammingError
+
+ def on_connect(self):
+ def connect(conn):
+ conn.stringformat = self.dbapi.MIXED_STRINGFORMAT
+ conn.datetimeformat = self.dbapi.PYDATETIME_DATETIMEFORMAT
+ conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT
+ conn.errorhandler = self._error_handler()
+
+ return connect
+
+ def _error_handler(self):
+ """Return a handler that adjusts mxODBC's raised Warnings to
+ emit Python standard warnings.
+ """
+ from mx.ODBC.Error import Warning as MxOdbcWarning
+
+ def error_handler(connection, cursor, errorclass, errorvalue):
+ if issubclass(errorclass, MxOdbcWarning):
+ errorclass.__bases__ = (Warning,)
+ warnings.warn(
+ message=str(errorvalue), category=errorclass, stacklevel=2
+ )
+ else:
+ raise errorclass(errorvalue)
+
+ return error_handler
+
+ def create_connect_args(self, url):
+ r"""Return a tuple of \*args, \**kwargs for creating a connection.
+
+ The mxODBC 3.x connection constructor looks like this:
+
+ connect(dsn, user='', password='',
+ clear_auto_commit=1, errorhandler=None)
+
+ This method translates the values in the provided URI
+ into args and kwargs needed to instantiate an mxODBC Connection.
+
+ The arg 'errorhandler' is not used by SQLAlchemy and will
+ not be populated.
+
+ """
+ opts = url.translate_connect_args(username="user")
+ opts.update(url.query)
+ args = opts.pop("host")
+ opts.pop("port", None)
+ opts.pop("database", None)
+ return (args,), opts
+
+ def is_disconnect(self, e, connection, cursor):
+ # TODO: eGenix recommends checking connection.closed here
+        # Does that detect dropped connections?
+ if isinstance(e, self.dbapi.ProgrammingError):
+ return "connection already closed" in str(e)
+ elif isinstance(e, self.dbapi.Error):
+ return "[08S01]" in str(e)
+ else:
+ return False
+
+ def _get_server_version_info(self, connection):
+ # eGenix suggests using conn.dbms_version instead
+ # of what we're doing here
+ dbapi_con = connection.connection
+ version = []
+ r = re.compile(r"[.\-]")
+ # 18 == pyodbc.SQL_DBMS_VER
+ for n in r.split(dbapi_con.getinfo(18)[1]):
+ try:
+ version.append(int(n))
+ except ValueError:
+ version.append(n)
+ return tuple(version)
+
+ def _get_direct(self, context):
+ if context:
+ native_odbc_execute = context.execution_options.get(
+ "native_odbc_execute", "auto"
+ )
+            # default to direct=True in all cases; it is more generally
+            # compatible, especially with SQL Server
+            return False if native_odbc_execute is True else True
+ else:
+ return True
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ cursor.executemany(
+ statement, parameters, direct=self._get_direct(context)
+ )
+
+ def do_execute(self, cursor, statement, parameters, context=None):
+ cursor.execute(statement, parameters, direct=self._get_direct(context))
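
A rough sketch (not part of the patch; the DSN name ``mydsn`` is illustrative
only) of how ``MxODBCConnector.create_connect_args`` maps a URL onto the mxODBC
``connect()`` call: the host component is treated as an ODBC DSN and becomes the
single positional argument, while port and database are discarded::

    from sqlalchemy.engine import make_url

    url = make_url("mssql+mxodbc://scott:tiger@mydsn")
    opts = url.translate_connect_args(username="user")
    opts.update(url.query)
    dsn = opts.pop("host")
    opts.pop("port", None)
    opts.pop("database", None)
    print((dsn,), opts)   # ('mydsn',) {'user': 'scott', 'password': 'tiger'}
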
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py
new file mode 100644
index 0000000..9bb67b5
--- /dev/null
+++ b/lib/sqlalchemy/connectors/pyodbc.py
@@ -0,0 +1,193 @@
+# connectors/pyodbc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import re
+
+from . import Connector
+from .. import util
+
+
+class PyODBCConnector(Connector):
+ driver = "pyodbc"
+
+ # this is no longer False for pyodbc in general
+ supports_sane_rowcount_returning = True
+ supports_sane_multi_rowcount = False
+
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+
+ supports_native_decimal = True
+ default_paramstyle = "named"
+
+ use_setinputsizes = False
+
+ # for non-DSN connections, this *may* be used to
+ # hold the desired driver name
+ pyodbc_driver_name = None
+
+ def __init__(
+ self, supports_unicode_binds=None, use_setinputsizes=False, **kw
+ ):
+ super(PyODBCConnector, self).__init__(**kw)
+ if supports_unicode_binds is not None:
+ self.supports_unicode_binds = supports_unicode_binds
+ self.use_setinputsizes = use_setinputsizes
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("pyodbc")
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+ opts.update(url.query)
+
+ keys = opts
+
+ query = url.query
+
+ connect_args = {}
+ for param in ("ansi", "unicode_results", "autocommit"):
+ if param in keys:
+ connect_args[param] = util.asbool(keys.pop(param))
+
+ if "odbc_connect" in keys:
+ connectors = [util.unquote_plus(keys.pop("odbc_connect"))]
+ else:
+
+ def check_quote(token):
+ if ";" in str(token) or str(token).startswith("{"):
+ token = "{%s}" % token.replace("}", "}}")
+ return token
+
+ keys = dict((k, check_quote(v)) for k, v in keys.items())
+
+ dsn_connection = "dsn" in keys or (
+ "host" in keys and "database" not in keys
+ )
+ if dsn_connection:
+ connectors = [
+ "dsn=%s" % (keys.pop("host", "") or keys.pop("dsn", ""))
+ ]
+ else:
+ port = ""
+ if "port" in keys and "port" not in query:
+ port = ",%d" % int(keys.pop("port"))
+
+ connectors = []
+ driver = keys.pop("driver", self.pyodbc_driver_name)
+ if driver is None and keys:
+ # note if keys is empty, this is a totally blank URL
+ util.warn(
+ "No driver name specified; "
+ "this is expected by PyODBC when using "
+ "DSN-less connections"
+ )
+ else:
+ connectors.append("DRIVER={%s}" % driver)
+
+ connectors.extend(
+ [
+ "Server=%s%s" % (keys.pop("host", ""), port),
+ "Database=%s" % keys.pop("database", ""),
+ ]
+ )
+
+ user = keys.pop("user", None)
+ if user:
+ connectors.append("UID=%s" % user)
+ pwd = keys.pop("password", "")
+ if pwd:
+ connectors.append("PWD=%s" % pwd)
+ else:
+ authentication = keys.pop("authentication", None)
+ if authentication:
+ connectors.append("Authentication=%s" % authentication)
+ else:
+ connectors.append("Trusted_Connection=Yes")
+
+ # if set to 'Yes', the ODBC layer will try to automagically
+ # convert textual data from your database encoding to your
+ # client encoding. This should obviously be set to 'No' if
+ # you query a cp1253 encoded database from a latin1 client...
+ if "odbc_autotranslate" in keys:
+ connectors.append(
+ "AutoTranslate=%s" % keys.pop("odbc_autotranslate")
+ )
+
+ connectors.extend(["%s=%s" % (k, v) for k, v in keys.items()])
+
+ return [[";".join(connectors)], connect_args]
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.ProgrammingError):
+ return "The cursor's connection has been closed." in str(
+ e
+ ) or "Attempt to use a closed connection." in str(e)
+ else:
+ return False
+
+ def _dbapi_version(self):
+ if not self.dbapi:
+ return ()
+ return self._parse_dbapi_version(self.dbapi.version)
+
+ def _parse_dbapi_version(self, vers):
+ m = re.match(r"(?:py.*-)?([\d\.]+)(?:-(\w+))?", vers)
+ if not m:
+ return ()
+ vers = tuple([int(x) for x in m.group(1).split(".")])
+ if m.group(2):
+ vers += (m.group(2),)
+ return vers
+
+ def _get_server_version_info(self, connection, allow_chars=True):
+ # NOTE: this function is not reliable, particularly when
+ # freetds is in use. Implement database-specific server version
+ # queries.
+ dbapi_con = connection.connection
+ version = []
+ r = re.compile(r"[.\-]")
+ for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
+ try:
+ version.append(int(n))
+ except ValueError:
+ if allow_chars:
+ version.append(n)
+ return tuple(version)
+
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+        # the rules for these types seem a little strange, as you can pass
+ # non-tuples as well as tuples, however it seems to assume "0"
+ # for the subsequent values if you don't pass a tuple which fails
+ # for types such as pyodbc.SQL_WLONGVARCHAR, which is the datatype
+ # that ticket #5649 is targeting.
+
+ # NOTE: as of #6058, this won't be called if the use_setinputsizes flag
+ # is False, or if no types were specified in list_of_tuples
+
+ cursor.setinputsizes(
+ [
+ (dbtype, None, None)
+ if not isinstance(dbtype, tuple)
+ else dbtype
+ for key, dbtype, sqltype in list_of_tuples
+ ]
+ )
+
+ def set_isolation_level(self, connection, level):
+ # adjust for ConnectionFairy being present
+ # allows attribute set e.g. "connection.autocommit = True"
+ # to work properly
+ if hasattr(connection, "dbapi_connection"):
+ connection = connection.dbapi_connection
+
+ if level == "AUTOCOMMIT":
+ connection.autocommit = True
+ else:
+ connection.autocommit = False
+ super(PyODBCConnector, self).set_isolation_level(connection, level)
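
A rough sketch (not part of the patch; driver and host names are illustrative
only) of the ODBC connection string that ``PyODBCConnector.create_connect_args``
assembles from a hostname-style URL; a DSN-style URL (host only, no database)
would instead produce ``dsn=...;UID=...;PWD=...``::

    from sqlalchemy.engine import make_url
    from sqlalchemy.connectors.pyodbc import PyODBCConnector

    url = make_url(
        "mssql+pyodbc://scott:tiger@myhost:1433/mydb"
        "?driver=ODBC+Driver+17+for+SQL+Server"
    )
    args, kwargs = PyODBCConnector().create_connect_args(url)
    print(args[0])
    # DRIVER={ODBC Driver 17 for SQL Server};Server=myhost,1433;
    # Database=mydb;UID=scott;PWD=tiger
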
diff --git a/lib/sqlalchemy/cprocessors.cpython-39-x86_64-linux-gnu.so b/lib/sqlalchemy/cprocessors.cpython-39-x86_64-linux-gnu.so
new file mode 100755
index 0000000..f2b7b00
--- /dev/null
+++ b/lib/sqlalchemy/cprocessors.cpython-39-x86_64-linux-gnu.so
Binary files differ
diff --git a/lib/sqlalchemy/cresultproxy.cpython-39-x86_64-linux-gnu.so b/lib/sqlalchemy/cresultproxy.cpython-39-x86_64-linux-gnu.so
new file mode 100755
index 0000000..0d851bd
--- /dev/null
+++ b/lib/sqlalchemy/cresultproxy.cpython-39-x86_64-linux-gnu.so
Binary files differ
diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py
new file mode 100644
index 0000000..fa83229
--- /dev/null
+++ b/lib/sqlalchemy/databases/__init__.py
@@ -0,0 +1,38 @@
+# databases/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Include imports from the sqlalchemy.dialects package for backwards
+compatibility with pre 0.6 versions.
+
+"""
+from ..dialects.firebird import base as firebird
+from ..dialects.mssql import base as mssql
+from ..dialects.mysql import base as mysql
+from ..dialects.oracle import base as oracle
+from ..dialects.postgresql import base as postgresql
+from ..dialects.sqlite import base as sqlite
+from ..dialects.sybase import base as sybase
+from ..util import warn_deprecated_20
+
+postgres = postgresql
+
+
+__all__ = (
+ "firebird",
+ "mssql",
+ "mysql",
+ "postgresql",
+ "sqlite",
+ "oracle",
+ "sybase",
+)
+
+
+warn_deprecated_20(
+    "The `databases` package is deprecated and will be removed in v2.0 "
+ "of sqlalchemy. Use the `dialects` package instead."
+)
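
A rough sketch (not part of the patch) of the backwards-compatibility behaviour
described above: importing from ``sqlalchemy.databases`` re-exports the dialect
``base`` modules (with ``postgres`` kept as an alias for ``postgresql``) and
emits the deprecation warning at import time::

    from sqlalchemy.databases import postgres, postgresql

    assert postgres is postgresql
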
diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py
new file mode 100644
index 0000000..84a9ad8
--- /dev/null
+++ b/lib/sqlalchemy/dialects/__init__.py
@@ -0,0 +1,72 @@
+# dialects/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+__all__ = (
+ "firebird",
+ "mssql",
+ "mysql",
+ "oracle",
+ "postgresql",
+ "sqlite",
+ "sybase",
+)
+
+
+from .. import util
+
+
+def _auto_fn(name):
+ """default dialect importer.
+
+ plugs into the :class:`.PluginLoader`
+ as a first-hit system.
+
+ """
+ if "." in name:
+ dialect, driver = name.split(".")
+ else:
+ dialect = name
+ driver = "base"
+
+ try:
+ if dialect == "firebird":
+ try:
+ module = __import__("sqlalchemy_firebird")
+ except ImportError:
+ module = __import__("sqlalchemy.dialects.firebird").dialects
+ module = getattr(module, dialect)
+ elif dialect == "sybase":
+ try:
+ module = __import__("sqlalchemy_sybase")
+ except ImportError:
+ module = __import__("sqlalchemy.dialects.sybase").dialects
+ module = getattr(module, dialect)
+ elif dialect == "mariadb":
+ # it's "OK" for us to hardcode here since _auto_fn is already
+            # hardcoded. If mysql / mariadb etc. were third-party dialects
+ # they would just publish all the entrypoints, which would actually
+ # look much nicer.
+ module = __import__(
+ "sqlalchemy.dialects.mysql.mariadb"
+ ).dialects.mysql.mariadb
+ return module.loader(driver)
+ else:
+ module = __import__("sqlalchemy.dialects.%s" % (dialect,)).dialects
+ module = getattr(module, dialect)
+ except ImportError:
+ return None
+
+ if hasattr(module, driver):
+ module = getattr(module, driver)
+ return lambda: module.dialect
+ else:
+ return None
+
+
+registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn)
+
+plugins = util.PluginLoader("sqlalchemy.plugins")
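
A rough sketch (not part of the patch) of how the ``registry`` above resolves a
``dialectname.drivername`` string through ``_auto_fn``: the dialect package is
imported, the driver submodule is looked up, and its ``dialect`` attribute is
returned; a bare dialect name falls back to the ``base`` driver module::

    from sqlalchemy.dialects import registry

    dialect_cls = registry.load("sqlite.pysqlite")
    print(dialect_cls.name, dialect_cls.driver)   # sqlite pysqlite

    print(registry.load("sqlite"))                # same class, via the "base" fallback
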
diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py
new file mode 100644
index 0000000..a34eecf
--- /dev/null
+++ b/lib/sqlalchemy/dialects/firebird/__init__.py
@@ -0,0 +1,41 @@
+# firebird/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from sqlalchemy.dialects.firebird.base import BIGINT
+from sqlalchemy.dialects.firebird.base import BLOB
+from sqlalchemy.dialects.firebird.base import CHAR
+from sqlalchemy.dialects.firebird.base import DATE
+from sqlalchemy.dialects.firebird.base import FLOAT
+from sqlalchemy.dialects.firebird.base import NUMERIC
+from sqlalchemy.dialects.firebird.base import SMALLINT
+from sqlalchemy.dialects.firebird.base import TEXT
+from sqlalchemy.dialects.firebird.base import TIME
+from sqlalchemy.dialects.firebird.base import TIMESTAMP
+from sqlalchemy.dialects.firebird.base import VARCHAR
+from . import base # noqa
+from . import fdb # noqa
+from . import kinterbasdb # noqa
+
+
+base.dialect = dialect = fdb.dialect
+
+__all__ = (
+ "SMALLINT",
+ "BIGINT",
+ "FLOAT",
+ "FLOAT",
+ "DATE",
+ "TIME",
+ "TEXT",
+ "NUMERIC",
+ "FLOAT",
+ "TIMESTAMP",
+ "VARCHAR",
+ "CHAR",
+ "BLOB",
+ "dialect",
+)
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py
new file mode 100644
index 0000000..e2698b1
--- /dev/null
+++ b/lib/sqlalchemy/dialects/firebird/base.py
@@ -0,0 +1,989 @@
+# firebird/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+
+.. dialect:: firebird
+ :name: Firebird
+
+.. note::
+
+ The Firebird dialect within SQLAlchemy **is not currently supported**.
+ It is not tested within continuous integration and is likely to have
+ many issues and caveats not currently handled. Consider using the
+ `external dialect <https://github.com/pauldex/sqlalchemy-firebird>`_
+ instead.
+
+.. deprecated:: 1.4 The internal Firebird dialect is deprecated and will be
+ removed in a future version. Use the external dialect.
+
+Firebird Dialects
+-----------------
+
+Firebird offers two distinct dialects_ (not to be confused with a
+SQLAlchemy ``Dialect``):
+
+dialect 1
+ This is the old syntax and behaviour, inherited from Interbase pre-6.0.
+
+dialect 3
+ This is the newer and supported syntax, introduced in Interbase 6.0.
+
+The SQLAlchemy Firebird dialect detects these versions and
+adjusts its representation of SQL accordingly. However,
+support for dialect 1 is not well tested and probably has
+incompatibilities.
+
+Locking Behavior
+----------------
+
+Firebird locks tables aggressively. For this reason, a DROP TABLE may
+hang until other transactions are released. SQLAlchemy does its best
+to release transactions as quickly as possible. The most common cause
+of hanging transactions is a non-fully consumed result set, i.e.::
+
+ result = engine.execute(text("select * from table"))
+ row = result.fetchone()
+ return
+
+Where above, the ``CursorResult`` has not been fully consumed. The
+connection will be returned to the pool and the transactional state
+rolled back once the Python garbage collector reclaims the objects
+which hold onto the connection, which often occurs asynchronously.
+The above use case can be alleviated by calling ``first()`` on the
+``CursorResult`` which will fetch the first row and immediately close
+all remaining cursor/connection resources.
+
+RETURNING support
+-----------------
+
+Firebird 2.0 supports returning a result set from inserts, and 2.1
+extends that to deletes and updates. This is generically exposed by
+the SQLAlchemy ``returning()`` method, such as::
+
+ # INSERT..RETURNING
+ result = table.insert().returning(table.c.col1, table.c.col2).\
+ values(name='foo')
+ print(result.fetchall())
+
+ # UPDATE..RETURNING
+ raises = empl.update().returning(empl.c.id, empl.c.salary).\
+ where(empl.c.sales>100).\
+ values(dict(salary=empl.c.salary * 1.1))
+ print(raises.fetchall())
+
+
+.. _dialects: https://mc-computing.com/Databases/Firebird/SQL_Dialect.html
+"""
+
+import datetime
+
+from sqlalchemy import exc
+from sqlalchemy import sql
+from sqlalchemy import types as sqltypes
+from sqlalchemy import util
+from sqlalchemy.engine import default
+from sqlalchemy.engine import reflection
+from sqlalchemy.sql import compiler
+from sqlalchemy.sql import expression
+from sqlalchemy.types import BIGINT
+from sqlalchemy.types import BLOB
+from sqlalchemy.types import DATE
+from sqlalchemy.types import FLOAT
+from sqlalchemy.types import INTEGER
+from sqlalchemy.types import Integer
+from sqlalchemy.types import NUMERIC
+from sqlalchemy.types import SMALLINT
+from sqlalchemy.types import TEXT
+from sqlalchemy.types import TIME
+from sqlalchemy.types import TIMESTAMP
+
+
+RESERVED_WORDS = set(
+ [
+ "active",
+ "add",
+ "admin",
+ "after",
+ "all",
+ "alter",
+ "and",
+ "any",
+ "as",
+ "asc",
+ "ascending",
+ "at",
+ "auto",
+ "avg",
+ "before",
+ "begin",
+ "between",
+ "bigint",
+ "bit_length",
+ "blob",
+ "both",
+ "by",
+ "case",
+ "cast",
+ "char",
+ "character",
+ "character_length",
+ "char_length",
+ "check",
+ "close",
+ "collate",
+ "column",
+ "commit",
+ "committed",
+ "computed",
+ "conditional",
+ "connect",
+ "constraint",
+ "containing",
+ "count",
+ "create",
+ "cross",
+ "cstring",
+ "current",
+ "current_connection",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_transaction",
+ "current_user",
+ "cursor",
+ "database",
+ "date",
+ "day",
+ "dec",
+ "decimal",
+ "declare",
+ "default",
+ "delete",
+ "desc",
+ "descending",
+ "disconnect",
+ "distinct",
+ "do",
+ "domain",
+ "double",
+ "drop",
+ "else",
+ "end",
+ "entry_point",
+ "escape",
+ "exception",
+ "execute",
+ "exists",
+ "exit",
+ "external",
+ "extract",
+ "fetch",
+ "file",
+ "filter",
+ "float",
+ "for",
+ "foreign",
+ "from",
+ "full",
+ "function",
+ "gdscode",
+ "generator",
+ "gen_id",
+ "global",
+ "grant",
+ "group",
+ "having",
+ "hour",
+ "if",
+ "in",
+ "inactive",
+ "index",
+ "inner",
+ "input_type",
+ "insensitive",
+ "insert",
+ "int",
+ "integer",
+ "into",
+ "is",
+ "isolation",
+ "join",
+ "key",
+ "leading",
+ "left",
+ "length",
+ "level",
+ "like",
+ "long",
+ "lower",
+ "manual",
+ "max",
+ "maximum_segment",
+ "merge",
+ "min",
+ "minute",
+ "module_name",
+ "month",
+ "names",
+ "national",
+ "natural",
+ "nchar",
+ "no",
+ "not",
+ "null",
+ "numeric",
+ "octet_length",
+ "of",
+ "on",
+ "only",
+ "open",
+ "option",
+ "or",
+ "order",
+ "outer",
+ "output_type",
+ "overflow",
+ "page",
+ "pages",
+ "page_size",
+ "parameter",
+ "password",
+ "plan",
+ "position",
+ "post_event",
+ "precision",
+ "primary",
+ "privileges",
+ "procedure",
+ "protected",
+ "rdb$db_key",
+ "read",
+ "real",
+ "record_version",
+ "recreate",
+ "recursive",
+ "references",
+ "release",
+ "reserv",
+ "reserving",
+ "retain",
+ "returning_values",
+ "returns",
+ "revoke",
+ "right",
+ "rollback",
+ "rows",
+ "row_count",
+ "savepoint",
+ "schema",
+ "second",
+ "segment",
+ "select",
+ "sensitive",
+ "set",
+ "shadow",
+ "shared",
+ "singular",
+ "size",
+ "smallint",
+ "snapshot",
+ "some",
+ "sort",
+ "sqlcode",
+ "stability",
+ "start",
+ "starting",
+ "starts",
+ "statistics",
+ "sub_type",
+ "sum",
+ "suspend",
+ "table",
+ "then",
+ "time",
+ "timestamp",
+ "to",
+ "trailing",
+ "transaction",
+ "trigger",
+ "trim",
+ "uncommitted",
+ "union",
+ "unique",
+ "update",
+ "upper",
+ "user",
+ "using",
+ "value",
+ "values",
+ "varchar",
+ "variable",
+ "varying",
+ "view",
+ "wait",
+ "when",
+ "where",
+ "while",
+ "with",
+ "work",
+ "write",
+ "year",
+ ]
+)
+
+
+class _StringType(sqltypes.String):
+ """Base for Firebird string types."""
+
+ def __init__(self, charset=None, **kw):
+ self.charset = charset
+ super(_StringType, self).__init__(**kw)
+
+
+class VARCHAR(_StringType, sqltypes.VARCHAR):
+ """Firebird VARCHAR type"""
+
+ __visit_name__ = "VARCHAR"
+
+ def __init__(self, length=None, **kwargs):
+ super(VARCHAR, self).__init__(length=length, **kwargs)
+
+
+class CHAR(_StringType, sqltypes.CHAR):
+ """Firebird CHAR type"""
+
+ __visit_name__ = "CHAR"
+
+ def __init__(self, length=None, **kwargs):
+ super(CHAR, self).__init__(length=length, **kwargs)
+
+
+class _FBDateTime(sqltypes.DateTime):
+ def bind_processor(self, dialect):
+ def process(value):
+ if type(value) == datetime.date:
+ return datetime.datetime(value.year, value.month, value.day)
+ else:
+ return value
+
+ return process
+
+
+colspecs = {sqltypes.DateTime: _FBDateTime}
+
+ischema_names = {
+ "SHORT": SMALLINT,
+ "LONG": INTEGER,
+ "QUAD": FLOAT,
+ "FLOAT": FLOAT,
+ "DATE": DATE,
+ "TIME": TIME,
+ "TEXT": TEXT,
+ "INT64": BIGINT,
+ "DOUBLE": FLOAT,
+ "TIMESTAMP": TIMESTAMP,
+ "VARYING": VARCHAR,
+ "CSTRING": CHAR,
+ "BLOB": BLOB,
+}
+
+
+# TODO: date conversion types (should be implemented as _FBDateTime,
+# _FBDate, etc. as bind/result functionality is required)
+
+
+class FBTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_boolean(self, type_, **kw):
+ return self.visit_SMALLINT(type_, **kw)
+
+ def visit_datetime(self, type_, **kw):
+ return self.visit_TIMESTAMP(type_, **kw)
+
+ def visit_TEXT(self, type_, **kw):
+ return "BLOB SUB_TYPE 1"
+
+ def visit_BLOB(self, type_, **kw):
+ return "BLOB SUB_TYPE 0"
+
+ def _extend_string(self, type_, basic):
+ charset = getattr(type_, "charset", None)
+ if charset is None:
+ return basic
+ else:
+ return "%s CHARACTER SET %s" % (basic, charset)
+
+ def visit_CHAR(self, type_, **kw):
+ basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw)
+ return self._extend_string(type_, basic)
+
+ def visit_VARCHAR(self, type_, **kw):
+ if not type_.length:
+ raise exc.CompileError(
+ "VARCHAR requires a length on dialect %s" % self.dialect.name
+ )
+ basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw)
+ return self._extend_string(type_, basic)
+
+
+class FBCompiler(sql.compiler.SQLCompiler):
+ """Firebird specific idiosyncrasies"""
+
+ ansi_bind_rules = True
+
+    # def visit_contains_op_binary(self, binary, operator, **kw):
+    # can't use CONTAINING because it's case insensitive.
+
+    # def visit_not_contains_op_binary(self, binary, operator, **kw):
+    # can't use NOT CONTAINING because it's case insensitive.
+
+ def visit_now_func(self, fn, **kw):
+ return "CURRENT_TIMESTAMP"
+
+ def visit_startswith_op_binary(self, binary, operator, **kw):
+ return "%s STARTING WITH %s" % (
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw),
+ )
+
+ def visit_not_startswith_op_binary(self, binary, operator, **kw):
+ return "%s NOT STARTING WITH %s" % (
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw),
+ )
+
+ def visit_mod_binary(self, binary, operator, **kw):
+ return "mod(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_alias(self, alias, asfrom=False, **kwargs):
+ if self.dialect._version_two:
+ return super(FBCompiler, self).visit_alias(
+ alias, asfrom=asfrom, **kwargs
+ )
+ else:
+ # Override to not use the AS keyword which FB 1.5 does not like
+ if asfrom:
+ alias_name = (
+ isinstance(alias.name, expression._truncated_label)
+ and self._truncated_identifier("alias", alias.name)
+ or alias.name
+ )
+
+ return (
+ self.process(alias.element, asfrom=asfrom, **kwargs)
+ + " "
+ + self.preparer.format_alias(alias, alias_name)
+ )
+ else:
+ return self.process(alias.element, **kwargs)
+
+ def visit_substring_func(self, func, **kw):
+ s = self.process(func.clauses.clauses[0])
+ start = self.process(func.clauses.clauses[1])
+ if len(func.clauses.clauses) > 2:
+ length = self.process(func.clauses.clauses[2])
+ return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
+ else:
+ return "SUBSTRING(%s FROM %s)" % (s, start)
+
+ def visit_length_func(self, function, **kw):
+ if self.dialect._version_two:
+ return "char_length" + self.function_argspec(function)
+ else:
+ return "strlen" + self.function_argspec(function)
+
+ visit_char_length_func = visit_length_func
+
+ def function_argspec(self, func, **kw):
+ # TODO: this probably will need to be
+ # narrowed to a fixed list, some no-arg functions
+ # may require parens - see similar example in the oracle
+ # dialect
+ if func.clauses is not None and len(func.clauses):
+ return self.process(func.clause_expr, **kw)
+ else:
+ return ""
+
+ def default_from(self):
+ return " FROM rdb$database"
+
+ def visit_sequence(self, seq, **kw):
+ return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
+
+ def get_select_precolumns(self, select, **kw):
+ """Called when building a ``SELECT`` statement, position is just
+        before the column list. Firebird puts the limit and offset right
+ after the ``SELECT``...
+ """
+
+ result = ""
+ if select._limit_clause is not None:
+ result += "FIRST %s " % self.process(select._limit_clause, **kw)
+ if select._offset_clause is not None:
+ result += "SKIP %s " % self.process(select._offset_clause, **kw)
+ result += super(FBCompiler, self).get_select_precolumns(select, **kw)
+ return result
+
+ def limit_clause(self, select, **kw):
+ """Already taken care of in the `get_select_precolumns` method."""
+
+ return ""
+
+ def returning_clause(self, stmt, returning_cols):
+ columns = [
+ self._label_returning_column(stmt, c)
+ for c in expression._select_iterables(returning_cols)
+ ]
+
+ return "RETURNING " + ", ".join(columns)
+
+
+class FBDDLCompiler(sql.compiler.DDLCompiler):
+ """Firebird syntactic idiosyncrasies"""
+
+ def visit_create_sequence(self, create):
+ """Generate a ``CREATE GENERATOR`` statement for the sequence."""
+
+ # no syntax for these
+ # https://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html
+ if create.element.start is not None:
+ raise NotImplementedError(
+ "Firebird SEQUENCE doesn't support START WITH"
+ )
+ if create.element.increment is not None:
+ raise NotImplementedError(
+ "Firebird SEQUENCE doesn't support INCREMENT BY"
+ )
+
+ if self.dialect._version_two:
+ return "CREATE SEQUENCE %s" % self.preparer.format_sequence(
+ create.element
+ )
+ else:
+ return "CREATE GENERATOR %s" % self.preparer.format_sequence(
+ create.element
+ )
+
+ def visit_drop_sequence(self, drop):
+ """Generate a ``DROP GENERATOR`` statement for the sequence."""
+
+ if self.dialect._version_two:
+ return "DROP SEQUENCE %s" % self.preparer.format_sequence(
+ drop.element
+ )
+ else:
+ return "DROP GENERATOR %s" % self.preparer.format_sequence(
+ drop.element
+ )
+
+ def visit_computed_column(self, generated):
+ if generated.persisted is not None:
+ raise exc.CompileError(
+ "Firebird computed columns do not support a persistence "
+ "method setting; set the 'persisted' flag to None for "
+ "Firebird support."
+ )
+ return "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
+ generated.sqltext, include_table=False, literal_binds=True
+ )
+
+
+class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
+ """Install Firebird specific reserved words."""
+
+ reserved_words = RESERVED_WORDS
+ illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union(
+ ["_"]
+ )
+
+ def __init__(self, dialect):
+ super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
+
+
+class FBExecutionContext(default.DefaultExecutionContext):
+ def fire_sequence(self, seq, type_):
+ """Get the next value from the sequence using ``gen_id()``."""
+
+ return self._execute_scalar(
+ "SELECT gen_id(%s, 1) FROM rdb$database"
+ % self.identifier_preparer.format_sequence(seq),
+ type_,
+ )
+
+
+class FBDialect(default.DefaultDialect):
+ """Firebird dialect"""
+
+ name = "firebird"
+ supports_statement_cache = True
+
+ max_identifier_length = 31
+
+ supports_sequences = True
+ sequences_optional = False
+ supports_default_values = True
+ postfetch_lastrowid = False
+
+ supports_native_boolean = False
+
+ requires_name_normalize = True
+ supports_empty_insert = False
+
+ statement_compiler = FBCompiler
+ ddl_compiler = FBDDLCompiler
+ preparer = FBIdentifierPreparer
+ type_compiler = FBTypeCompiler
+ execution_ctx_cls = FBExecutionContext
+
+ colspecs = colspecs
+ ischema_names = ischema_names
+
+ construct_arguments = []
+
+ # defaults to dialect ver. 3,
+    # will be autodetected upon
+ # first connect
+ _version_two = True
+
+ def __init__(self, *args, **kwargs):
+ util.warn_deprecated(
+ "The firebird dialect is deprecated and will be removed "
+ "in a future version. This dialect is superseded by the external "
+ "dialect https://github.com/pauldex/sqlalchemy-firebird.",
+ version="1.4",
+ )
+ super(FBDialect, self).__init__(*args, **kwargs)
+
+ def initialize(self, connection):
+ super(FBDialect, self).initialize(connection)
+ self._version_two = (
+ "firebird" in self.server_version_info
+ and self.server_version_info >= (2,)
+ ) or (
+ "interbase" in self.server_version_info
+ and self.server_version_info >= (6,)
+ )
+
+ if not self._version_two:
+ # TODO: whatever other pre < 2.0 stuff goes here
+ self.ischema_names = ischema_names.copy()
+ self.ischema_names["TIMESTAMP"] = sqltypes.DATE
+ self.colspecs = {sqltypes.DateTime: sqltypes.DATE}
+
+ self.implicit_returning = self._version_two and self.__dict__.get(
+ "implicit_returning", True
+ )
+
+ def has_table(self, connection, table_name, schema=None):
+ """Return ``True`` if the given table exists, ignoring
+ the `schema`."""
+ self._ensure_has_table_connection(connection)
+
+ tblqry = """
+ SELECT 1 AS has_table FROM rdb$database
+ WHERE EXISTS (SELECT rdb$relation_name
+ FROM rdb$relations
+ WHERE rdb$relation_name=?)
+ """
+ c = connection.exec_driver_sql(
+ tblqry, [self.denormalize_name(table_name)]
+ )
+ return c.first() is not None
+
+ def has_sequence(self, connection, sequence_name, schema=None):
+ """Return ``True`` if the given sequence (generator) exists."""
+
+ genqry = """
+ SELECT 1 AS has_sequence FROM rdb$database
+ WHERE EXISTS (SELECT rdb$generator_name
+ FROM rdb$generators
+ WHERE rdb$generator_name=?)
+ """
+ c = connection.exec_driver_sql(
+ genqry, [self.denormalize_name(sequence_name)]
+ )
+ return c.first() is not None
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ # there are two queries commonly mentioned for this.
+ # this one, using view_blr, is at the Firebird FAQ among other places:
+ # https://www.firebirdfaq.org/faq174/
+ s = """
+ select rdb$relation_name
+ from rdb$relations
+ where rdb$view_blr is null
+ and (rdb$system_flag is null or rdb$system_flag = 0);
+ """
+
+ # the other query is this one. It's not clear if there's really
+ # any difference between these two. This link:
+ # https://www.alberton.info/firebird_sql_meta_info.html#.Ur3vXfZGni8
+ # states them as interchangeable. Some discussion at [ticket:2898]
+ # SELECT DISTINCT rdb$relation_name
+ # FROM rdb$relation_fields
+ # WHERE rdb$system_flag=0 AND rdb$view_context IS NULL
+
+ return [
+ self.normalize_name(row[0])
+ for row in connection.exec_driver_sql(s)
+ ]
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ # see https://www.firebirdfaq.org/faq174/
+ s = """
+ select rdb$relation_name
+ from rdb$relations
+ where rdb$view_blr is not null
+ and (rdb$system_flag is null or rdb$system_flag = 0);
+ """
+ return [
+ self.normalize_name(row[0])
+ for row in connection.exec_driver_sql(s)
+ ]
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ qry = """
+ SELECT rdb$view_source AS view_source
+ FROM rdb$relations
+ WHERE rdb$relation_name=?
+ """
+ rp = connection.exec_driver_sql(
+ qry, [self.denormalize_name(view_name)]
+ )
+ row = rp.first()
+ if row:
+ return row["view_source"]
+ else:
+ return None
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ # Query to extract the PK/FK constrained fields of the given table
+ keyqry = """
+ SELECT se.rdb$field_name AS fname
+ FROM rdb$relation_constraints rc
+ JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
+ WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
+ """
+ tablename = self.denormalize_name(table_name)
+ # get primary key fields
+ c = connection.exec_driver_sql(keyqry, ["PRIMARY KEY", tablename])
+ pkfields = [self.normalize_name(r["fname"]) for r in c.fetchall()]
+ return {"constrained_columns": pkfields, "name": None}
+
+ @reflection.cache
+ def get_column_sequence(
+ self, connection, table_name, column_name, schema=None, **kw
+ ):
+ tablename = self.denormalize_name(table_name)
+ colname = self.denormalize_name(column_name)
+ # Heuristic-query to determine the generator associated to a PK field
+ genqry = """
+ SELECT trigdep.rdb$depended_on_name AS fgenerator
+ FROM rdb$dependencies tabdep
+ JOIN rdb$dependencies trigdep
+ ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
+ AND trigdep.rdb$depended_on_type=14
+ AND trigdep.rdb$dependent_type=2
+ JOIN rdb$triggers trig ON
+ trig.rdb$trigger_name=tabdep.rdb$dependent_name
+ WHERE tabdep.rdb$depended_on_name=?
+ AND tabdep.rdb$depended_on_type=0
+ AND trig.rdb$trigger_type=1
+ AND tabdep.rdb$field_name=?
+ AND (SELECT count(*)
+ FROM rdb$dependencies trigdep2
+ WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
+ """
+ genr = connection.exec_driver_sql(genqry, [tablename, colname]).first()
+ if genr is not None:
+ return dict(name=self.normalize_name(genr["fgenerator"]))
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ # Query to extract the details of all the fields of the given table
+ tblqry = """
+ SELECT r.rdb$field_name AS fname,
+ r.rdb$null_flag AS null_flag,
+ t.rdb$type_name AS ftype,
+ f.rdb$field_sub_type AS stype,
+ f.rdb$field_length/
+ COALESCE(cs.rdb$bytes_per_character,1) AS flen,
+ f.rdb$field_precision AS fprec,
+ f.rdb$field_scale AS fscale,
+ COALESCE(r.rdb$default_source,
+ f.rdb$default_source) AS fdefault
+ FROM rdb$relation_fields r
+ JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
+ JOIN rdb$types t
+ ON t.rdb$type=f.rdb$field_type AND
+ t.rdb$field_name='RDB$FIELD_TYPE'
+ LEFT JOIN rdb$character_sets cs ON
+ f.rdb$character_set_id=cs.rdb$character_set_id
+ WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
+ ORDER BY r.rdb$field_position
+ """
+ # get the PK, used to determine the eventual associated sequence
+ pk_constraint = self.get_pk_constraint(connection, table_name)
+ pkey_cols = pk_constraint["constrained_columns"]
+
+ tablename = self.denormalize_name(table_name)
+ # get all of the fields for this table
+ c = connection.exec_driver_sql(tblqry, [tablename])
+ cols = []
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ name = self.normalize_name(row["fname"])
+ orig_colname = row["fname"]
+
+ # get the data type
+ colspec = row["ftype"].rstrip()
+ coltype = self.ischema_names.get(colspec)
+ if coltype is None:
+ util.warn(
+ "Did not recognize type '%s' of column '%s'"
+ % (colspec, name)
+ )
+ coltype = sqltypes.NULLTYPE
+ elif issubclass(coltype, Integer) and row["fprec"] != 0:
+ coltype = NUMERIC(
+ precision=row["fprec"], scale=row["fscale"] * -1
+ )
+ elif colspec in ("VARYING", "CSTRING"):
+ coltype = coltype(row["flen"])
+ elif colspec == "TEXT":
+ coltype = TEXT(row["flen"])
+ elif colspec == "BLOB":
+ if row["stype"] == 1:
+ coltype = TEXT()
+ else:
+ coltype = BLOB()
+ else:
+ coltype = coltype()
+
+ # does it have a default value?
+ defvalue = None
+ if row["fdefault"] is not None:
+ # the value comes down as "DEFAULT 'value'": there may be
+ # more than one whitespace around the "DEFAULT" keyword
+ # and it may also be lower case
+ # (see also https://tracker.firebirdsql.org/browse/CORE-356)
+                defexpr = row["fdefault"].lstrip()
+                assert defexpr[:8].rstrip().upper() == "DEFAULT", (
+                    "Unrecognized default value: %s" % defexpr
+                )
+                defvalue = defexpr[8:].strip()
+                if defvalue == "NULL":
+                    # Redundant
+                    defvalue = None
+ col_d = {
+ "name": name,
+ "type": coltype,
+ "nullable": not bool(row["null_flag"]),
+ "default": defvalue,
+ "autoincrement": "auto",
+ }
+
+ if orig_colname.lower() == orig_colname:
+ col_d["quote"] = True
+
+ # if the PK is a single field, try to see if it's linked to
+ # a sequence through a trigger
+ if len(pkey_cols) == 1 and name == pkey_cols[0]:
+ seq_d = self.get_column_sequence(connection, tablename, name)
+ if seq_d is not None:
+ col_d["sequence"] = seq_d
+
+ cols.append(col_d)
+ return cols
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ # Query to extract the details of each UK/FK of the given table
+ fkqry = """
+ SELECT rc.rdb$constraint_name AS cname,
+ cse.rdb$field_name AS fname,
+ ix2.rdb$relation_name AS targetrname,
+ se.rdb$field_name AS targetfname
+ FROM rdb$relation_constraints rc
+ JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
+ JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
+ JOIN rdb$index_segments cse ON
+ cse.rdb$index_name=ix1.rdb$index_name
+ JOIN rdb$index_segments se
+ ON se.rdb$index_name=ix2.rdb$index_name
+ AND se.rdb$field_position=cse.rdb$field_position
+ WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
+ ORDER BY se.rdb$index_name, se.rdb$field_position
+ """
+ tablename = self.denormalize_name(table_name)
+
+ c = connection.exec_driver_sql(fkqry, ["FOREIGN KEY", tablename])
+ fks = util.defaultdict(
+ lambda: {
+ "name": None,
+ "constrained_columns": [],
+ "referred_schema": None,
+ "referred_table": None,
+ "referred_columns": [],
+ }
+ )
+
+ for row in c:
+ cname = self.normalize_name(row["cname"])
+ fk = fks[cname]
+ if not fk["name"]:
+ fk["name"] = cname
+ fk["referred_table"] = self.normalize_name(row["targetrname"])
+ fk["constrained_columns"].append(self.normalize_name(row["fname"]))
+ fk["referred_columns"].append(
+ self.normalize_name(row["targetfname"])
+ )
+ return list(fks.values())
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ qry = """
+ SELECT ix.rdb$index_name AS index_name,
+ ix.rdb$unique_flag AS unique_flag,
+ ic.rdb$field_name AS field_name
+ FROM rdb$indices ix
+ JOIN rdb$index_segments ic
+ ON ix.rdb$index_name=ic.rdb$index_name
+ LEFT OUTER JOIN rdb$relation_constraints
+ ON rdb$relation_constraints.rdb$index_name =
+ ic.rdb$index_name
+ WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL
+ AND rdb$relation_constraints.rdb$constraint_type IS NULL
+ ORDER BY index_name, ic.rdb$field_position
+ """
+ c = connection.exec_driver_sql(
+ qry, [self.denormalize_name(table_name)]
+ )
+
+ indexes = util.defaultdict(dict)
+ for row in c:
+ indexrec = indexes[row["index_name"]]
+ if "name" not in indexrec:
+ indexrec["name"] = self.normalize_name(row["index_name"])
+ indexrec["column_names"] = []
+ indexrec["unique"] = bool(row["unique_flag"])
+
+ indexrec["column_names"].append(
+ self.normalize_name(row["field_name"])
+ )
+
+ return list(indexes.values())
diff --git a/lib/sqlalchemy/dialects/firebird/fdb.py b/lib/sqlalchemy/dialects/firebird/fdb.py
new file mode 100644
index 0000000..38f4432
--- /dev/null
+++ b/lib/sqlalchemy/dialects/firebird/fdb.py
@@ -0,0 +1,112 @@
+# firebird/fdb.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: firebird+fdb
+ :name: fdb
+ :dbapi: fdb
+ :connectstring: firebird+fdb://user:password@host:port/path/to/db[?key=value&key=value...]
+ :url: https://pypi.org/project/fdb/
+
+ fdb is a kinterbasdb compatible DBAPI for Firebird.
+
+ .. versionchanged:: 0.9 - The fdb dialect is now the default dialect
+ under the ``firebird://`` URL space, as ``fdb`` is now the official
+ Python driver for Firebird.
+
+Arguments
+----------
+
+The ``fdb`` dialect is based on the
+:mod:`sqlalchemy.dialects.firebird.kinterbasdb` dialect; however, it does not
+accept every argument that Kinterbasdb does.
+
+* ``enable_rowcount`` - True by default, setting this to False disables
+ the usage of "cursor.rowcount" with the
+ Kinterbasdb dialect, which SQLAlchemy ordinarily calls upon automatically
+ after any UPDATE or DELETE statement. When disabled, SQLAlchemy's
+ CursorResult will return -1 for result.rowcount. The rationale here is
+ that Kinterbasdb requires a second round trip to the database when
+ .rowcount is called - since SQLA's resultproxy automatically closes
+ the cursor after a non-result-returning statement, rowcount must be
+ called, if at all, before the result object is returned. Additionally,
+ cursor.rowcount may not return correct results with older versions
+ of Firebird, and setting this flag to False will also cause the
+ SQLAlchemy ORM to ignore its usage. The behavior can also be controlled on a
+ per-execution basis using the ``enable_rowcount`` option with
+ :meth:`_engine.Connection.execution_options`::
+
+ conn = engine.connect().execution_options(enable_rowcount=True)
+ r = conn.execute(stmt)
+ print(r.rowcount)
+
+* ``retaining`` - False by default. Setting this to True will pass the
+ ``retaining=True`` keyword argument to the ``.commit()`` and ``.rollback()``
+ methods of the DBAPI connection, which can improve performance in some
+ situations, but apparently with significant caveats.
+ Please read the fdb and/or kinterbasdb DBAPI documentation in order to
+ understand the implications of this flag.
+
+ .. versionchanged:: 0.9.0 - the ``retaining`` flag defaults to ``False``.
+ In 0.8 it defaulted to ``True``.
+
+ .. seealso::
+
+ https://pythonhosted.org/fdb/usage-guide.html#retaining-transactions
+ - information on the "retaining" flag.
+
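+Both options are ordinary dialect arguments; as a sketch (the connection
+URL below is illustrative only), they may also be passed directly to
+:func:`_sa.create_engine`, which hands dialect-level keyword arguments to
+the dialect::
+
+ from sqlalchemy import create_engine
+
+ engine = create_engine(
+ "firebird+fdb://user:password@localhost/path/to/db",
+ enable_rowcount=False, retaining=True)
+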
+""" # noqa
+
+from .kinterbasdb import FBDialect_kinterbasdb
+from ... import util
+
+
+class FBDialect_fdb(FBDialect_kinterbasdb):
+ supports_statement_cache = True
+
+ def __init__(self, enable_rowcount=True, retaining=False, **kwargs):
+ super(FBDialect_fdb, self).__init__(
+ enable_rowcount=enable_rowcount, retaining=retaining, **kwargs
+ )
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("fdb")
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+ if opts.get("port"):
+ opts["host"] = "%s/%s" % (opts["host"], opts["port"])
+ del opts["port"]
+ opts.update(url.query)
+
+ util.coerce_kw_type(opts, "type_conv", int)
+
+ return ([], opts)
+
+ def _get_server_version_info(self, connection):
+ """Get the version of the Firebird server used by a connection.
+
+ Returns a tuple of (`major`, `minor`, `build`), three integers
+ representing the version of the attached server.
+ """
+
+ # This is the simpler approach (the other uses the services api),
+ # that for backward compatibility reasons returns a string like
+ # LI-V6.3.3.12981 Firebird 2.0
+ # where the first version is a fake one resembling the old
+ # Interbase signature.
+
+ isc_info_firebird_version = 103
+ fbconn = connection.connection
+
+ version = fbconn.db_info(isc_info_firebird_version)
+
+ return self._parse_version_info(version)
+
+
+dialect = FBDialect_fdb
diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
new file mode 100644
index 0000000..b999404
--- /dev/null
+++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
@@ -0,0 +1,202 @@
+# firebird/kinterbasdb.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: firebird+kinterbasdb
+ :name: kinterbasdb
+ :dbapi: kinterbasdb
+ :connectstring: firebird+kinterbasdb://user:password@host:port/path/to/db[?key=value&key=value...]
+ :url: https://firebirdsql.org/index.php?op=devel&sub=python
+
+Arguments
+----------
+
+The Kinterbasdb backend accepts the ``enable_rowcount`` and ``retaining``
+arguments accepted by the :mod:`sqlalchemy.dialects.firebird.fdb` dialect.
+In addition, it also accepts the following:
+
+* ``type_conv`` - select the kind of mapping done on the types: by default
+ SQLAlchemy uses 200 with Unicode, datetime and decimal support. See
+ the linked documents below for further information.
+
+* ``concurrency_level`` - set the backend policy with regards to threading
+ issues: by default SQLAlchemy uses policy 1. See the linked documents
+ below for further information.
+
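+As a sketch (the connection URL is illustrative only), both options are
+passed as keyword arguments to :func:`_sa.create_engine`::
+
+ engine = create_engine(
+ "firebird+kinterbasdb://user:password@localhost/path/to/db",
+ type_conv=200, concurrency_level=1)
+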
+.. seealso::
+
+ https://sourceforge.net/projects/kinterbasdb
+
+ https://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation
+
+ https://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency
+
+""" # noqa
+
+import decimal
+from re import match
+
+from .base import FBDialect
+from .base import FBExecutionContext
+from ... import types as sqltypes
+from ... import util
+
+
+class _kinterbasdb_numeric(object):
+ def bind_processor(self, dialect):
+ def process(value):
+ if isinstance(value, decimal.Decimal):
+ return str(value)
+ else:
+ return value
+
+ return process
+
+
+class _FBNumeric_kinterbasdb(_kinterbasdb_numeric, sqltypes.Numeric):
+ pass
+
+
+class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float):
+ pass
+
+
+class FBExecutionContext_kinterbasdb(FBExecutionContext):
+ @property
+ def rowcount(self):
+ if self.execution_options.get(
+ "enable_rowcount", self.dialect.enable_rowcount
+ ):
+ return self.cursor.rowcount
+ else:
+ return -1
+
+
+class FBDialect_kinterbasdb(FBDialect):
+ driver = "kinterbasdb"
+ supports_statement_cache = True
+ supports_sane_rowcount = False
+ supports_sane_multi_rowcount = False
+ execution_ctx_cls = FBExecutionContext_kinterbasdb
+
+ supports_native_decimal = True
+
+ colspecs = util.update_copy(
+ FBDialect.colspecs,
+ {
+ sqltypes.Numeric: _FBNumeric_kinterbasdb,
+ sqltypes.Float: _FBFloat_kinterbasdb,
+ },
+ )
+
+ def __init__(
+ self,
+ type_conv=200,
+ concurrency_level=1,
+ enable_rowcount=True,
+ retaining=False,
+ **kwargs
+ ):
+ super(FBDialect_kinterbasdb, self).__init__(**kwargs)
+ self.enable_rowcount = enable_rowcount
+ self.type_conv = type_conv
+ self.concurrency_level = concurrency_level
+ self.retaining = retaining
+ if enable_rowcount:
+ self.supports_sane_rowcount = True
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("kinterbasdb")
+
+ def do_execute(self, cursor, statement, parameters, context=None):
+ # kinterbasdb does not accept None, but wants an empty list
+ # when there are no arguments.
+ cursor.execute(statement, parameters or [])
+
+ def do_rollback(self, dbapi_connection):
+ dbapi_connection.rollback(self.retaining)
+
+ def do_commit(self, dbapi_connection):
+ dbapi_connection.commit(self.retaining)
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+ if opts.get("port"):
+ opts["host"] = "%s/%s" % (opts["host"], opts["port"])
+ del opts["port"]
+ opts.update(url.query)
+
+ util.coerce_kw_type(opts, "type_conv", int)
+
+ type_conv = opts.pop("type_conv", self.type_conv)
+ concurrency_level = opts.pop(
+ "concurrency_level", self.concurrency_level
+ )
+
+ if self.dbapi is not None:
+ initialized = getattr(self.dbapi, "initialized", None)
+ if initialized is None:
+ # CVS rev 1.96 changed the name of the attribute:
+ # https://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/
+ # Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96
+ initialized = getattr(self.dbapi, "_initialized", False)
+ if not initialized:
+ self.dbapi.init(
+ type_conv=type_conv, concurrency_level=concurrency_level
+ )
+ return ([], opts)
+
+ def _get_server_version_info(self, connection):
+ """Get the version of the Firebird server used by a connection.
+
+ Returns a tuple of (`major`, `minor`, `build`), three integers
+ representing the version of the attached server.
+ """
+
+ # This is the simpler approach (the other uses the services api),
+ # that for backward compatibility reasons returns a string like
+ # LI-V6.3.3.12981 Firebird 2.0
+ # where the first version is a fake one resembling the old
+ # Interbase signature.
+
+ fbconn = connection.connection
+ version = fbconn.server_version
+
+ return self._parse_version_info(version)
+
+ def _parse_version_info(self, version):
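+ # The raw version string is either an Interbase-style signature such as
+ # "WI-V6.3.3.12981" or, for Firebird, that signature followed by
+ # " Firebird <major>.<minor>"; when the trailing part is present the
+ # Firebird version is reported, otherwise the Interbase-compatible one.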
+ m = match(
+ r"\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?", version
+ )
+ if not m:
+ raise AssertionError(
+ "Could not determine version from string '%s'" % version
+ )
+
+ if m.group(5) is not None:
+ return tuple([int(x) for x in m.group(6, 7, 4)] + ["firebird"])
+ else:
+ return tuple([int(x) for x in m.group(1, 2, 3)] + ["interbase"])
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(
+ e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)
+ ):
+ msg = str(e)
+ return (
+ "Error writing data to the connection" in msg
+ or "Unable to complete network request to host" in msg
+ or "Invalid connection state" in msg
+ or "Invalid cursor state" in msg
+ or "connection shutdown" in msg
+ )
+ else:
+ return False
+
+
+dialect = FBDialect_kinterbasdb
diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py
new file mode 100644
index 0000000..cae0168
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/__init__.py
@@ -0,0 +1,85 @@
+# mssql/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import base # noqa
+from . import mxodbc # noqa
+from . import pymssql # noqa
+from . import pyodbc # noqa
+from .base import BIGINT
+from .base import BINARY
+from .base import BIT
+from .base import CHAR
+from .base import DATE
+from .base import DATETIME
+from .base import DATETIME2
+from .base import DATETIMEOFFSET
+from .base import DECIMAL
+from .base import FLOAT
+from .base import IMAGE
+from .base import INTEGER
+from .base import JSON
+from .base import MONEY
+from .base import NCHAR
+from .base import NTEXT
+from .base import NUMERIC
+from .base import NVARCHAR
+from .base import REAL
+from .base import ROWVERSION
+from .base import SMALLDATETIME
+from .base import SMALLINT
+from .base import SMALLMONEY
+from .base import SQL_VARIANT
+from .base import TEXT
+from .base import TIME
+from .base import TIMESTAMP
+from .base import TINYINT
+from .base import try_cast
+from .base import UNIQUEIDENTIFIER
+from .base import VARBINARY
+from .base import VARCHAR
+from .base import XML
+
+
+base.dialect = dialect = pyodbc.dialect
+
+
+__all__ = (
+ "JSON",
+ "INTEGER",
+ "BIGINT",
+ "SMALLINT",
+ "TINYINT",
+ "VARCHAR",
+ "NVARCHAR",
+ "CHAR",
+ "NCHAR",
+ "TEXT",
+ "NTEXT",
+ "DECIMAL",
+ "NUMERIC",
+ "FLOAT",
+ "DATETIME",
+ "DATETIME2",
+ "DATETIMEOFFSET",
+ "DATE",
+ "TIME",
+ "SMALLDATETIME",
+ "BINARY",
+ "VARBINARY",
+ "BIT",
+ "REAL",
+ "IMAGE",
+ "TIMESTAMP",
+ "ROWVERSION",
+ "MONEY",
+ "SMALLMONEY",
+ "UNIQUEIDENTIFIER",
+ "SQL_VARIANT",
+ "XML",
+ "dialect",
+ "try_cast",
+)
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
new file mode 100644
index 0000000..ee6ce87
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -0,0 +1,3545 @@
+# mssql/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""
+.. dialect:: mssql
+ :name: Microsoft SQL Server
+ :full_support: 2017
+ :normal_support: 2012+
+ :best_effort: 2005+
+
+.. _mssql_external_dialects:
+
+External Dialects
+-----------------
+
+In addition to the above DBAPI layers with native SQLAlchemy support, there
+are third-party dialects for other DBAPI layers that are compatible
+with SQL Server. See the "External Dialects" list on the
+:ref:`dialect_toplevel` page.
+
+.. _mssql_identity:
+
+Auto Increment Behavior / IDENTITY Columns
+------------------------------------------
+
+SQL Server provides so-called "auto incrementing" behavior using the
+``IDENTITY`` construct, which can be placed on any single integer column in a
+table. SQLAlchemy considers ``IDENTITY`` within its default "autoincrement"
+behavior for an integer primary key column, described at
+:paramref:`_schema.Column.autoincrement`. This means that by default,
+the first integer primary key column in a :class:`_schema.Table` will be
+considered to be the identity column - unless it is associated with a
+:class:`.Sequence` - and will generate DDL as such::
+
+ from sqlalchemy import Table, MetaData, Column, Integer
+
+ m = MetaData()
+ t = Table('t', m,
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ m.create_all(engine)
+
+The above example will generate DDL as:
+
+.. sourcecode:: sql
+
+ CREATE TABLE t (
+ id INTEGER NOT NULL IDENTITY,
+ x INTEGER NULL,
+ PRIMARY KEY (id)
+ )
+
+For the case where this default generation of ``IDENTITY`` is not desired,
+specify ``False`` for the :paramref:`_schema.Column.autoincrement` flag,
+on the first integer primary key column::
+
+ m = MetaData()
+ t = Table('t', m,
+ Column('id', Integer, primary_key=True, autoincrement=False),
+ Column('x', Integer))
+ m.create_all(engine)
+
+To add the ``IDENTITY`` keyword to a non-primary key column, specify
+``True`` for the :paramref:`_schema.Column.autoincrement` flag on the desired
+:class:`_schema.Column` object, and ensure that
+:paramref:`_schema.Column.autoincrement`
+is set to ``False`` on any integer primary key column::
+
+ m = MetaData()
+ t = Table('t', m,
+ Column('id', Integer, primary_key=True, autoincrement=False),
+ Column('x', Integer, autoincrement=True))
+ m.create_all(engine)
+
+.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct
+ in a :class:`_schema.Column` to specify the start and increment
+ parameters of an IDENTITY. These replace
+ the use of the :class:`.Sequence` object in order to specify these values.
+
+.. deprecated:: 1.4
+
+ The ``mssql_identity_start`` and ``mssql_identity_increment`` parameters
+ to :class:`_schema.Column` are deprecated and should be replaced by
+ an :class:`_schema.Identity` object. Specifying both ways of configuring
+ an IDENTITY will result in a compile error.
+ These options are also no longer returned as part of the
+ ``dialect_options`` key in :meth:`_reflection.Inspector.get_columns`.
+ Use the information in the ``identity`` key instead.
+
+.. deprecated:: 1.3
+
+ The use of :class:`.Sequence` to specify IDENTITY characteristics is
+ deprecated and will be removed in a future release. Please use
+ the :class:`_schema.Identity` object parameters
+ :paramref:`_schema.Identity.start` and
+ :paramref:`_schema.Identity.increment`.
+
+.. versionchanged:: 1.4 Removed the ability to use a :class:`.Sequence`
+ object to modify IDENTITY characteristics. :class:`.Sequence` objects
+ now only manipulate true T-SQL SEQUENCE types.
+
+.. note::
+
+ There can only be one IDENTITY column on the table. When using
+ ``autoincrement=True`` to enable the IDENTITY keyword, SQLAlchemy does not
+ guard against multiple columns specifying the option simultaneously. The
+ SQL Server database will instead reject the ``CREATE TABLE`` statement.
+
+.. note::
+
+ An INSERT statement which attempts to provide a value for a column that is
+ marked with IDENTITY will be rejected by SQL Server. In order for the
+ value to be accepted, a session-level option "SET IDENTITY_INSERT" must be
+ enabled. The SQLAlchemy SQL Server dialect will perform this operation
+ automatically when using a core :class:`_expression.Insert`
+ construct; if the
+ execution specifies a value for the IDENTITY column, the "IDENTITY_INSERT"
+ option will be enabled for the span of that statement's invocation. However,
+ this scenario is not high performing and should not be relied upon for
+ normal use. If a table doesn't actually require IDENTITY behavior in its
+ integer primary key column, the keyword should be disabled when creating
+ the table by ensuring that ``autoincrement=False`` is set.
+
+Controlling "Start" and "Increment"
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Specific control over the "start" and "increment" values for
+the ``IDENTITY`` generator are provided using the
+:paramref:`_schema.Identity.start` and :paramref:`_schema.Identity.increment`
+parameters passed to the :class:`_schema.Identity` object::
+
+ from sqlalchemy import Table, Integer, Column, Identity, String
+
+ test = Table(
+ 'test', metadata,
+ Column(
+ 'id',
+ Integer,
+ Identity(start=100, increment=10),
+ primary_key=True
+ ),
+ Column('name', String(20))
+ )
+
+The CREATE TABLE for the above :class:`_schema.Table` object would be:
+
+.. sourcecode:: sql
+
+ CREATE TABLE test (
+ id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
+ name VARCHAR(20) NULL
+ )
+
+.. note::
+
+ The :class:`_schema.Identity` object supports many other parameters in
+ addition to ``start`` and ``increment``. These are not supported by
+ SQL Server and will be ignored when generating the CREATE TABLE DDL.
+
+.. versionchanged:: 1.3.19 The :class:`_schema.Identity` object is
+ now used to affect the
+ ``IDENTITY`` generator for a :class:`_schema.Column` under SQL Server.
+ Previously, the :class:`.Sequence` object was used. As SQL Server now
+ supports real sequences as a separate construct, :class:`.Sequence` will be
+ functional in the normal way starting from SQLAlchemy version 1.4.
+
+
+Using IDENTITY with Non-Integer numeric types
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+SQL Server also allows ``IDENTITY`` to be used with ``NUMERIC`` columns. To
+implement this pattern smoothly in SQLAlchemy, the primary datatype of the
+column should remain as ``Integer``, however the underlying implementation
+type deployed to the SQL Server database can be specified as ``Numeric`` using
+:meth:`.TypeEngine.with_variant`::
+
+ from sqlalchemy import Column
+ from sqlalchemy import Integer
+ from sqlalchemy import Numeric
+ from sqlalchemy import String
+ from sqlalchemy.ext.declarative import declarative_base
+
+ Base = declarative_base()
+
+ class TestTable(Base):
+ __tablename__ = "test"
+ id = Column(
+ Integer().with_variant(Numeric(10, 0), "mssql"),
+ primary_key=True,
+ autoincrement=True,
+ )
+ name = Column(String)
+
+In the above example, ``Integer().with_variant()`` provides clear usage
+information that accurately describes the intent of the code. The general
+restriction that ``autoincrement`` only applies to ``Integer`` is established
+at the metadata level and not at the per-dialect level.
+
+When using the above pattern, the primary key identifier that comes back from
+the insertion of a row, which is also the value that would be assigned to an
+ORM object such as ``TestTable`` above, will be an instance of ``Decimal()``
+and not ``int`` when using SQL Server. The numeric return type of the
+:class:`_types.Numeric` type can be changed to return floats by passing False
+to :paramref:`_types.Numeric.asdecimal`. To normalize the return type of the
+above ``Numeric(10, 0)`` to return Python ints (which also support "long"
+integer values in Python 3), use :class:`_types.TypeDecorator` as follows::
+
+ from sqlalchemy import TypeDecorator
+
+ class NumericAsInteger(TypeDecorator):
+ '''normalize floating point return values into ints'''
+
+ impl = Numeric(10, 0, asdecimal=False)
+ cache_ok = True
+
+ def process_result_value(self, value, dialect):
+ if value is not None:
+ value = int(value)
+ return value
+
+ class TestTable(Base):
+ __tablename__ = "test"
+ id = Column(
+ Integer().with_variant(NumericAsInteger, "mssql"),
+ primary_key=True,
+ autoincrement=True,
+ )
+ name = Column(String)
+
+
+INSERT behavior
+^^^^^^^^^^^^^^^^
+
+Handling of the ``IDENTITY`` column at INSERT time involves two key
+techniques. The most common is being able to fetch the "last inserted value"
+for a given ``IDENTITY`` column, a process which SQLAlchemy performs
+implicitly in many cases, most importantly within the ORM.
+
+The process for fetching this value has several variants:
+
+* In the vast majority of cases, RETURNING is used in conjunction with INSERT
+ statements on SQL Server in order to get newly generated primary key values:
+
+ .. sourcecode:: sql
+
+ INSERT INTO t (x) OUTPUT inserted.id VALUES (?)
+
+* When RETURNING is not available or has been disabled via
+ ``implicit_returning=False``, either the ``scope_identity()`` function or
+ the ``@@identity`` variable is used; behavior varies by backend:
+
+ * when using PyODBC, the phrase ``; select scope_identity()`` will be
+ appended to the end of the INSERT statement; a second result set will be
+ fetched in order to receive the value. Given a table as::
+
+ t = Table('t', m, Column('id', Integer, primary_key=True),
+ Column('x', Integer),
+ implicit_returning=False)
+
+ an INSERT will look like:
+
+ .. sourcecode:: sql
+
+ INSERT INTO t (x) VALUES (?); select scope_identity()
+
+ * Other dialects such as pymssql will call upon
+ ``SELECT scope_identity() AS lastrowid`` subsequent to an INSERT
+ statement. If the flag ``use_scope_identity=False`` is passed to
+ :func:`_sa.create_engine`,
+ the statement ``SELECT @@identity AS lastrowid``
+ is used instead.
+
+A table that contains an ``IDENTITY`` column will prohibit an INSERT statement
+that refers to the identity column explicitly. The SQLAlchemy dialect will
+detect when an INSERT construct, created using a core
+:func:`_expression.insert`
+construct (not a plain string SQL), refers to the identity column, and
+in this case will emit ``SET IDENTITY_INSERT ON`` prior to the insert
+statement proceeding, and ``SET IDENTITY_INSERT OFF`` subsequent to the
+execution. Given this example::
+
+ m = MetaData()
+ t = Table('t', m, Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ m.create_all(engine)
+
+ with engine.begin() as conn:
+ conn.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2})
+
+The above column will be created with IDENTITY, however the INSERT statement
+we emit is specifying explicit values. In the echo output we can see
+how SQLAlchemy handles this:
+
+.. sourcecode:: sql
+
+ CREATE TABLE t (
+ id INTEGER NOT NULL IDENTITY(1,1),
+ x INTEGER NULL,
+ PRIMARY KEY (id)
+ )
+
+ COMMIT
+ SET IDENTITY_INSERT t ON
+ INSERT INTO t (id, x) VALUES (?, ?)
+ ((1, 1), (2, 2))
+ SET IDENTITY_INSERT t OFF
+ COMMIT
+
+
+
+This is an auxiliary use case suitable for testing and bulk insert scenarios.
+
+SEQUENCE support
+----------------
+
+The :class:`.Sequence` object now creates "real" sequences, i.e.,
+``CREATE SEQUENCE``. To provide compatibility with other dialects,
+:class:`.Sequence` defaults to a start value of 1, even though the
+T-SQL default is -9223372036854775808.
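+
+A minimal sketch (the table and sequence names are illustrative)::
+
+ from sqlalchemy import Column, Integer, MetaData, Sequence, Table
+
+ metadata = MetaData()
+ t = Table(
+ "t", metadata,
+ Column("id", Integer, Sequence("t_id_seq", start=1), primary_key=True),
+ )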
+
+.. versionadded:: 1.4.0
+
+MAX on VARCHAR / NVARCHAR
+-------------------------
+
+SQL Server supports the special string "MAX" within the
+:class:`_types.VARCHAR` and :class:`_types.NVARCHAR` datatypes,
+to indicate "maximum length possible". The dialect currently handles this as
+a length of "None" in the base type, rather than supplying a
+dialect-specific version of these types, so that a base type
+specified such as ``VARCHAR(None)`` can assume "unlengthed" behavior on
+more than one backend without using dialect-specific types.
+
+To build a SQL Server VARCHAR or NVARCHAR with MAX length, use None::
+
+ my_table = Table(
+ 'my_table', metadata,
+ Column('my_data', VARCHAR(None)),
+ Column('my_n_data', NVARCHAR(None))
+ )
+
+
+Collation Support
+-----------------
+
+Character collations are supported by the base string types,
+specified by the string argument "collation"::
+
+ from sqlalchemy import VARCHAR
+ Column('login', VARCHAR(32, collation='Latin1_General_CI_AS'))
+
+When such a column is associated with a :class:`_schema.Table`, the
+CREATE TABLE statement for this column will yield::
+
+ login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL
+
+LIMIT/OFFSET Support
+--------------------
+
+MSSQL has added support for LIMIT / OFFSET as of SQL Server 2012, via the
+"OFFSET n ROWS" and "FETCH NEXT n ROWS" clauses. SQLAlchemy supports these
+syntaxes automatically if SQL Server 2012 or greater is detected.
+
+.. versionchanged:: 1.4 support added for SQL Server "OFFSET n ROWS" and
+ "FETCH NEXT n ROWS" syntax.
+
+For statements that specify only LIMIT and no OFFSET, all versions of SQL
+Server support the TOP keyword. This syntax is used for all SQL Server
+versions when no OFFSET clause is present. A statement such as::
+
+ select(some_table).limit(5)
+
+will render similarly to::
+
+ SELECT TOP 5 col1, col2.. FROM table
+
+For versions of SQL Server prior to SQL Server 2012, a statement that uses
+LIMIT and OFFSET, or just OFFSET alone, will be rendered using the
+``ROW_NUMBER()`` window function. A statement such as::
+
+ select(some_table).order_by(some_table.c.col3).limit(5).offset(10)
+
+will render similarly to::
+
+ SELECT anon_1.col1, anon_1.col2 FROM (SELECT col1, col2,
+ ROW_NUMBER() OVER (ORDER BY col3) AS
+ mssql_rn FROM table WHERE t.x = :x_1) AS
+ anon_1 WHERE mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1
+
+Note that when using LIMIT and/or OFFSET, whether using the older
+or newer SQL Server syntaxes, the statement must have an ORDER BY as well,
+else a :class:`.CompileError` is raised.
+
+.. _mssql_isolation_level:
+
+Transaction Isolation Level
+---------------------------
+
+All SQL Server dialects support setting of transaction isolation level
+both via a dialect-specific parameter
+:paramref:`_sa.create_engine.isolation_level`
+accepted by :func:`_sa.create_engine`,
+as well as the :paramref:`.Connection.execution_options.isolation_level`
+argument as passed to
+:meth:`_engine.Connection.execution_options`.
+This feature works by issuing the
+command ``SET TRANSACTION ISOLATION LEVEL <level>`` for
+each new connection.
+
+To set isolation level using :func:`_sa.create_engine`::
+
+ engine = create_engine(
+ "mssql+pyodbc://scott:tiger@ms_2008",
+ isolation_level="REPEATABLE READ"
+ )
+
+To set using per-connection execution options::
+
+ connection = engine.connect()
+ connection = connection.execution_options(
+ isolation_level="READ COMMITTED"
+ )
+
+Valid values for ``isolation_level`` include:
+
+* ``AUTOCOMMIT`` - pyodbc / pymssql-specific
+* ``READ COMMITTED``
+* ``READ UNCOMMITTED``
+* ``REPEATABLE READ``
+* ``SERIALIZABLE``
+* ``SNAPSHOT`` - specific to SQL Server
+
+There are also more options for isolation level configurations, such as
+"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply
+different isolation level settings. See the discussion at
+:ref:`dbapi_autocommit` for background.
+
+.. seealso::
+
+ :ref:`dbapi_autocommit`
+
+Nullability
+-----------
+MSSQL has support for three levels of column nullability. The default
+nullability allows nulls and is explicit in the CREATE TABLE
+construct::
+
+ name VARCHAR(20) NULL
+
+If ``nullable=None`` is specified then no specification is made. In
+other words the database's configured default is used. This will
+render::
+
+ name VARCHAR(20)
+
+If ``nullable`` is ``True`` or ``False`` then the column will be
+``NULL`` or ``NOT NULL`` respectively.
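+
+A sketch showing the three settings side by side::
+
+ Table("example", metadata,
+ Column("a", String(20), nullable=True),   # a VARCHAR(20) NULL
+ Column("b", String(20), nullable=False),  # b VARCHAR(20) NOT NULL
+ Column("c", String(20), nullable=None),   # c VARCHAR(20)
+ )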
+
+Date / Time Handling
+--------------------
+DATE and TIME are supported. Bind parameters are converted
+to datetime.datetime() objects as required by most MSSQL drivers,
+and results are processed from strings if needed.
+The DATE and TIME types are not available for MSSQL 2005 and
+previous - if a server version below 2008 is detected, DDL
+for these types will be issued as DATETIME.
+
+.. _mssql_large_type_deprecation:
+
+Large Text/Binary Type Deprecation
+----------------------------------
+
+Per
+`SQL Server 2012/2014 Documentation <https://technet.microsoft.com/en-us/library/ms187993.aspx>`_,
+the ``NTEXT``, ``TEXT`` and ``IMAGE`` datatypes are to be removed from SQL
+Server in a future release. SQLAlchemy normally relates these types to the
+:class:`.UnicodeText`, :class:`_expression.TextClause` and
+:class:`.LargeBinary` datatypes.
+
+In order to accommodate this change, a new flag ``deprecate_large_types``
+is added to the dialect, which will be automatically set based on detection
+of the server version in use, if not otherwise set by the user. The
+behavior of this flag is as follows:
+
+* When this flag is ``True``, the :class:`.UnicodeText`,
+ :class:`_expression.TextClause` and
+ :class:`.LargeBinary` datatypes, when used to render DDL, will render the
+ types ``NVARCHAR(max)``, ``VARCHAR(max)``, and ``VARBINARY(max)``,
+ respectively. This is a new behavior as of the addition of this flag.
+
+* When this flag is ``False``, the :class:`.UnicodeText`,
+ :class:`_expression.TextClause` and
+ :class:`.LargeBinary` datatypes, when used to render DDL, will render the
+ types ``NTEXT``, ``TEXT``, and ``IMAGE``,
+ respectively. This is the long-standing behavior of these types.
+
+* The flag begins with the value ``None``, before a database connection is
+ established. If the dialect is used to render DDL without the flag being
+ set, it is interpreted the same as ``False``.
+
+* On first connection, the dialect detects if SQL Server version 2012 or
+ greater is in use; if the flag is still at ``None``, it sets it to ``True``
+ or ``False`` based on whether 2012 or greater is detected.
+
+* The flag can be set to either ``True`` or ``False`` when the dialect
+ is created, typically via :func:`_sa.create_engine`::
+
+ eng = create_engine("mssql+pymssql://user:pass@host/db",
+ deprecate_large_types=True)
+
+* Complete control over whether the "old" or "new" types are rendered is
+ available in all SQLAlchemy versions by using the UPPERCASE type objects
+ instead: :class:`_types.NVARCHAR`, :class:`_types.VARCHAR`,
+ :class:`_types.VARBINARY`, :class:`_types.TEXT`, :class:`_mssql.NTEXT`,
+ :class:`_mssql.IMAGE`
+ will always remain fixed and always output exactly that
+ type.
+
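+For example, a sketch that always renders the legacy ``NTEXT`` type
+regardless of the server version detected::
+
+ from sqlalchemy.dialects.mssql import NTEXT
+
+ Column("comments", NTEXT)
+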
+.. versionadded:: 1.0.0
+
+.. _multipart_schema_names:
+
+Multipart Schema Names
+----------------------
+
+SQL Server schemas sometimes require multiple parts to their "schema"
+qualifier, that is, including the database name and owner name as separate
+tokens, such as ``mydatabase.dbo.some_table``. These multipart names can be set
+at once using the :paramref:`_schema.Table.schema` argument of
+:class:`_schema.Table`::
+
+ Table(
+ "some_table", metadata,
+ Column("q", String(50)),
+ schema="mydatabase.dbo"
+ )
+
+When performing operations such as table or component reflection, a schema
+argument that contains a dot will be split into separate
+"database" and "owner" components in order to correctly query the SQL
+Server information schema tables, as these two values are stored separately.
+Additionally, when rendering the schema name for DDL or SQL, the two
+components will be quoted separately for case sensitive names and other
+special characters. Given an argument as below::
+
+ Table(
+ "some_table", metadata,
+ Column("q", String(50)),
+ schema="MyDataBase.dbo"
+ )
+
+The above schema would be rendered as ``[MyDataBase].dbo``, and also in
+reflection, would be reflected using "dbo" as the owner and "MyDataBase"
+as the database name.
+
+To control how the schema name is broken into database / owner,
+specify brackets (which in SQL Server are quoting characters) in the name.
+Below, the "owner" will be considered as ``MyDataBase.dbo`` and the
+"database" will be None::
+
+ Table(
+ "some_table", metadata,
+ Column("q", String(50)),
+ schema="[MyDataBase.dbo]"
+ )
+
+To individually specify both database and owner name with special characters
+or embedded dots, use two sets of brackets::
+
+ Table(
+ "some_table", metadata,
+ Column("q", String(50)),
+ schema="[MyDataBase.Period].[MyOwner.Dot]"
+ )
+
+
+.. versionchanged:: 1.2 the SQL Server dialect now treats brackets as
+ identifier delimiters splitting the schema into separate database
+ and owner tokens, to allow dots within either name itself.
+
+.. _legacy_schema_rendering:
+
+Legacy Schema Mode
+------------------
+
+Very old versions of the MSSQL dialect introduced the behavior such that a
+schema-qualified table would be auto-aliased when used in a
+SELECT statement; given a table::
+
+ account_table = Table(
+ 'account', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('info', String(100)),
+ schema="customer_schema"
+ )
+
+this legacy mode of rendering would assume that "customer_schema.account"
+would not be accepted by all parts of the SQL statement, as illustrated
+below::
+
+ >>> eng = create_engine("mssql+pymssql://mydsn", legacy_schema_aliasing=True)
+ >>> print(account_table.select().compile(eng))
+ SELECT account_1.id, account_1.info
+ FROM customer_schema.account AS account_1
+
+This mode of behavior is now off by default, as it appears to have served
+no purpose; however in the case that legacy applications rely upon it,
+it is available using the ``legacy_schema_aliasing`` argument to
+:func:`_sa.create_engine` as illustrated above.
+
+.. versionchanged:: 1.1 the ``legacy_schema_aliasing`` flag introduced
+ in version 1.0.5 to allow disabling of legacy mode for schemas now
+ defaults to False.
+
+.. deprecated:: 1.4
+
+ The ``legacy_schema_aliasing`` flag is now
+ deprecated and will be removed in a future release.
+
+.. _mssql_indexes:
+
+Clustered Index Support
+-----------------------
+
+The MSSQL dialect supports clustered indexes (and primary keys) via the
+``mssql_clustered`` option. This option is available to :class:`.Index`,
+:class:`.UniqueConstraint`, and :class:`.PrimaryKeyConstraint`.
+
+To generate a clustered index::
+
+ Index("my_index", table.c.x, mssql_clustered=True)
+
+which renders the index as ``CREATE CLUSTERED INDEX my_index ON table (x)``.
+
+To generate a clustered primary key use::
+
+ Table('my_table', metadata,
+ Column('x', ...),
+ Column('y', ...),
+ PrimaryKeyConstraint("x", "y", mssql_clustered=True))
+
+which will render the table, for example, as::
+
+ CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL,
+ PRIMARY KEY CLUSTERED (x, y))
+
+Similarly, we can generate a clustered unique constraint using::
+
+ Table('my_table', metadata,
+ Column('x', ...),
+ Column('y', ...),
+ PrimaryKeyConstraint("x"),
+ UniqueConstraint("y", mssql_clustered=True),
+ )
+
+To explicitly request a non-clustered primary key (for example, when
+a separate clustered index is desired), use::
+
+ Table('my_table', metadata,
+ Column('x', ...),
+ Column('y', ...),
+ PrimaryKeyConstraint("x", "y", mssql_clustered=False))
+
+which will render the table, for example, as::
+
+ CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL,
+ PRIMARY KEY NONCLUSTERED (x, y))
+
+.. versionchanged:: 1.1 the ``mssql_clustered`` option now defaults
+ to None, rather than False. ``mssql_clustered=False`` now explicitly
+ renders the NONCLUSTERED clause, whereas None omits the CLUSTERED
+ clause entirely, allowing SQL Server defaults to take effect.
+
+
+MSSQL-Specific Index Options
+-----------------------------
+
+In addition to clustering, the MSSQL dialect supports other special options
+for :class:`.Index`.
+
+INCLUDE
+^^^^^^^
+
+The ``mssql_include`` option renders INCLUDE(colname) for the given string
+names::
+
+ Index("my_index", table.c.x, mssql_include=['y'])
+
+would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)``
+
+.. _mssql_index_where:
+
+Filtered Indexes
+^^^^^^^^^^^^^^^^
+
+The ``mssql_where`` option renders WHERE(condition) for the given string
+names::
+
+ Index("my_index", table.c.x, mssql_where=table.c.x > 10)
+
+would render the index as ``CREATE INDEX my_index ON table (x) WHERE x > 10``.
+
+.. versionadded:: 1.3.4
+
+Index ordering
+^^^^^^^^^^^^^^
+
+Index ordering is available via functional expressions, such as::
+
+ Index("my_index", table.c.x.desc())
+
+would render the index as ``CREATE INDEX my_index ON table (x DESC)``
+
+.. seealso::
+
+ :ref:`schema_indexes_functional`
+
+Compatibility Levels
+--------------------
+MSSQL supports the notion of setting compatibility levels at the
+database level. This allows, for instance, a database that is
+compatible with SQL2000 to run on a SQL2005 database
+server. ``server_version_info`` will always return the database
+server version information (in this case SQL2005) and not the
+compatibility level information. Because of this, if running under
+a backwards compatibility mode SQLAlchemy may attempt to use T-SQL
+statements that are unable to be parsed by the database server.
+
+Triggers
+--------
+
+SQLAlchemy by default uses OUTPUT INSERTED to get at newly
+generated primary key values via IDENTITY columns or other
+server side defaults. MS-SQL does not
+allow the usage of OUTPUT INSERTED on tables that have triggers.
+To disable the usage of OUTPUT INSERTED on a per-table basis,
+specify ``implicit_returning=False`` for each :class:`_schema.Table`
+which has triggers::
+
+ Table('mytable', metadata,
+ Column('id', Integer, primary_key=True),
+ # ...,
+ implicit_returning=False
+ )
+
+Declarative form::
+
+ class MyClass(Base):
+ # ...
+ __table_args__ = {'implicit_returning':False}
+
+
+This option can also be specified engine-wide using the
+``implicit_returning=False`` argument on :func:`_sa.create_engine`.
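+
+For example (the connection string is illustrative only)::
+
+ engine = create_engine(
+ "mssql+pyodbc://scott:tiger@mydsn", implicit_returning=False)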
+
+.. _mssql_rowcount_versioning:
+
+Rowcount Support / ORM Versioning
+---------------------------------
+
+The SQL Server drivers may have limited ability to return the number
+of rows updated from an UPDATE or DELETE statement.
+
+As of this writing, the PyODBC driver is not able to return a rowcount when
+OUTPUT INSERTED is used. This impacts the SQLAlchemy ORM's versioning feature
+in many cases where server-side value generators are in use in that while the
+versioning operations can succeed, the ORM cannot always check that an UPDATE
+or DELETE statement matched the number of rows expected, which is how it
+verifies that the version identifier matched. When this condition occurs, a
+warning will be emitted but the operation will proceed.
+
+The use of OUTPUT INSERTED can be disabled by setting the
+:paramref:`_schema.Table.implicit_returning` flag to ``False`` on a particular
+:class:`_schema.Table`, which in declarative looks like::
+
+ class MyTable(Base):
+ __tablename__ = 'mytable'
+ id = Column(Integer, primary_key=True)
+ stuff = Column(String(10))
+ timestamp = Column(TIMESTAMP(), default=text('DEFAULT'))
+ __mapper_args__ = {
+ 'version_id_col': timestamp,
+ 'version_id_generator': False,
+ }
+ __table_args__ = {
+ 'implicit_returning': False
+ }
+
+Enabling Snapshot Isolation
+---------------------------
+
+SQL Server has a default transaction
+isolation mode that locks entire tables, and causes even mildly concurrent
+applications to have long held locks and frequent deadlocks.
+Enabling snapshot isolation for the database as a whole is recommended
+for modern levels of concurrency support. This is accomplished via the
+following ALTER DATABASE commands executed at the SQL prompt::
+
+ ALTER DATABASE MyDatabase SET ALLOW_SNAPSHOT_ISOLATION ON
+
+ ALTER DATABASE MyDatabase SET READ_COMMITTED_SNAPSHOT ON
+
+Background on SQL Server snapshot isolation is available at
+https://msdn.microsoft.com/en-us/library/ms175095.aspx.
+
+""" # noqa
+
+import codecs
+import datetime
+import operator
+import re
+
+from . import information_schema as ischema
+from .json import JSON
+from .json import JSONIndexType
+from .json import JSONPathType
+from ... import exc
+from ... import Identity
+from ... import schema as sa_schema
+from ... import Sequence
+from ... import sql
+from ... import text
+from ... import types as sqltypes
+from ... import util
+from ...engine import cursor as _cursor
+from ...engine import default
+from ...engine import reflection
+from ...sql import coercions
+from ...sql import compiler
+from ...sql import elements
+from ...sql import expression
+from ...sql import func
+from ...sql import quoted_name
+from ...sql import roles
+from ...sql import util as sql_util
+from ...types import BIGINT
+from ...types import BINARY
+from ...types import CHAR
+from ...types import DATE
+from ...types import DATETIME
+from ...types import DECIMAL
+from ...types import FLOAT
+from ...types import INTEGER
+from ...types import NCHAR
+from ...types import NUMERIC
+from ...types import NVARCHAR
+from ...types import SMALLINT
+from ...types import TEXT
+from ...types import VARCHAR
+from ...util import compat
+from ...util import update_wrapper
+from ...util.langhelpers import public_factory
+
+
+# https://sqlserverbuilds.blogspot.com/
+MS_2017_VERSION = (14,)
+MS_2016_VERSION = (13,)
+MS_2014_VERSION = (12,)
+MS_2012_VERSION = (11,)
+MS_2008_VERSION = (10,)
+MS_2005_VERSION = (9,)
+MS_2000_VERSION = (8,)
+
+RESERVED_WORDS = set(
+ [
+ "add",
+ "all",
+ "alter",
+ "and",
+ "any",
+ "as",
+ "asc",
+ "authorization",
+ "backup",
+ "begin",
+ "between",
+ "break",
+ "browse",
+ "bulk",
+ "by",
+ "cascade",
+ "case",
+ "check",
+ "checkpoint",
+ "close",
+ "clustered",
+ "coalesce",
+ "collate",
+ "column",
+ "commit",
+ "compute",
+ "constraint",
+ "contains",
+ "containstable",
+ "continue",
+ "convert",
+ "create",
+ "cross",
+ "current",
+ "current_date",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "cursor",
+ "database",
+ "dbcc",
+ "deallocate",
+ "declare",
+ "default",
+ "delete",
+ "deny",
+ "desc",
+ "disk",
+ "distinct",
+ "distributed",
+ "double",
+ "drop",
+ "dump",
+ "else",
+ "end",
+ "errlvl",
+ "escape",
+ "except",
+ "exec",
+ "execute",
+ "exists",
+ "exit",
+ "external",
+ "fetch",
+ "file",
+ "fillfactor",
+ "for",
+ "foreign",
+ "freetext",
+ "freetexttable",
+ "from",
+ "full",
+ "function",
+ "goto",
+ "grant",
+ "group",
+ "having",
+ "holdlock",
+ "identity",
+ "identity_insert",
+ "identitycol",
+ "if",
+ "in",
+ "index",
+ "inner",
+ "insert",
+ "intersect",
+ "into",
+ "is",
+ "join",
+ "key",
+ "kill",
+ "left",
+ "like",
+ "lineno",
+ "load",
+ "merge",
+ "national",
+ "nocheck",
+ "nonclustered",
+ "not",
+ "null",
+ "nullif",
+ "of",
+ "off",
+ "offsets",
+ "on",
+ "open",
+ "opendatasource",
+ "openquery",
+ "openrowset",
+ "openxml",
+ "option",
+ "or",
+ "order",
+ "outer",
+ "over",
+ "percent",
+ "pivot",
+ "plan",
+ "precision",
+ "primary",
+ "print",
+ "proc",
+ "procedure",
+ "public",
+ "raiserror",
+ "read",
+ "readtext",
+ "reconfigure",
+ "references",
+ "replication",
+ "restore",
+ "restrict",
+ "return",
+ "revert",
+ "revoke",
+ "right",
+ "rollback",
+ "rowcount",
+ "rowguidcol",
+ "rule",
+ "save",
+ "schema",
+ "securityaudit",
+ "select",
+ "session_user",
+ "set",
+ "setuser",
+ "shutdown",
+ "some",
+ "statistics",
+ "system_user",
+ "table",
+ "tablesample",
+ "textsize",
+ "then",
+ "to",
+ "top",
+ "tran",
+ "transaction",
+ "trigger",
+ "truncate",
+ "tsequal",
+ "union",
+ "unique",
+ "unpivot",
+ "update",
+ "updatetext",
+ "use",
+ "user",
+ "values",
+ "varying",
+ "view",
+ "waitfor",
+ "when",
+ "where",
+ "while",
+ "with",
+ "writetext",
+ ]
+)
+
+
+class REAL(sqltypes.REAL):
+ __visit_name__ = "REAL"
+
+ def __init__(self, **kw):
+ # REAL is a synonym for FLOAT(24) on SQL server.
+ # it is only accepted as the word "REAL" in DDL, the numeric
+ # precision value is not allowed to be present
+ kw.setdefault("precision", 24)
+ super(REAL, self).__init__(**kw)
+
+
+class TINYINT(sqltypes.Integer):
+ __visit_name__ = "TINYINT"
+
+
+# MSSQL DATE/TIME types have varied behavior, sometimes returning
+# strings. MSDate/TIME check for everything, and always
+# filter bind parameters into datetime objects (required by pyodbc,
+# not sure about other dialects).
+
+
+class _MSDate(sqltypes.Date):
+ def bind_processor(self, dialect):
+ def process(value):
+ if type(value) == datetime.date:
+ return datetime.datetime(value.year, value.month, value.day)
+ else:
+ return value
+
+ return process
+
+ _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
+
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ if isinstance(value, datetime.datetime):
+ return value.date()
+ elif isinstance(value, util.string_types):
+ m = self._reg.match(value)
+ if not m:
+ raise ValueError(
+ "could not parse %r as a date value" % (value,)
+ )
+ return datetime.date(*[int(x or 0) for x in m.groups()])
+ else:
+ return value
+
+ return process
+
+
+class TIME(sqltypes.TIME):
+ def __init__(self, precision=None, **kwargs):
+ self.precision = precision
+ super(TIME, self).__init__()
+
+ __zero_date = datetime.date(1900, 1, 1)
+
+ def bind_processor(self, dialect):
+ def process(value):
+ if isinstance(value, datetime.datetime):
+ value = datetime.datetime.combine(
+ self.__zero_date, value.time()
+ )
+ elif isinstance(value, datetime.time):
+ """issue #5339
+ per: https://github.com/mkleehammer/pyodbc/wiki/Tips-and-Tricks-by-Database-Platform#time-columns
+ pass TIME value as string
+ """ # noqa
+ value = str(value)
+ return value
+
+ return process
+
+ _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d{0,6}))?")
+
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ if isinstance(value, datetime.datetime):
+ return value.time()
+ elif isinstance(value, util.string_types):
+ m = self._reg.match(value)
+ if not m:
+ raise ValueError(
+ "could not parse %r as a time value" % (value,)
+ )
+ return datetime.time(*[int(x or 0) for x in m.groups()])
+ else:
+ return value
+
+ return process
+
+
+_MSTime = TIME
+
+
+class _BASETIMEIMPL(TIME):
+ __visit_name__ = "_BASETIMEIMPL"
+
+
+class _DateTimeBase(object):
+ def bind_processor(self, dialect):
+ def process(value):
+ if type(value) == datetime.date:
+ return datetime.datetime(value.year, value.month, value.day)
+ else:
+ return value
+
+ return process
+
+
+class _MSDateTime(_DateTimeBase, sqltypes.DateTime):
+ pass
+
+
+class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime):
+ __visit_name__ = "SMALLDATETIME"
+
+
+class DATETIME2(_DateTimeBase, sqltypes.DateTime):
+ __visit_name__ = "DATETIME2"
+
+ def __init__(self, precision=None, **kw):
+ super(DATETIME2, self).__init__(**kw)
+ self.precision = precision
+
+
+class DATETIMEOFFSET(_DateTimeBase, sqltypes.DateTime):
+ __visit_name__ = "DATETIMEOFFSET"
+
+ def __init__(self, precision=None, **kw):
+ super(DATETIMEOFFSET, self).__init__(**kw)
+ self.precision = precision
+
+
+class _UnicodeLiteral(object):
+ def literal_processor(self, dialect):
+ def process(value):
+
+ value = value.replace("'", "''")
+
+ if dialect.identifier_preparer._double_percents:
+ value = value.replace("%", "%%")
+
+ return "N'%s'" % value
+
+ return process
+
+
+class _MSUnicode(_UnicodeLiteral, sqltypes.Unicode):
+ pass
+
+
+class _MSUnicodeText(_UnicodeLiteral, sqltypes.UnicodeText):
+ pass
+
+
+class TIMESTAMP(sqltypes._Binary):
+ """Implement the SQL Server TIMESTAMP type.
+
+ Note this is **completely different** than the SQL Standard
+ TIMESTAMP type, which is not supported by SQL Server. It
+ is a read-only datatype that does not support INSERT of values.
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :class:`_mssql.ROWVERSION`
+
+ """
+
+ __visit_name__ = "TIMESTAMP"
+
+ # expected by _Binary to be present
+ length = None
+
+ def __init__(self, convert_int=False):
+ """Construct a TIMESTAMP or ROWVERSION type.
+
+ :param convert_int: if True, binary integer values will
+ be converted to integers on read.
+
+ .. versionadded:: 1.2
+
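+ E.g., a sketch of a column whose binary rowversion value is read back
+ as a Python integer (the column name is illustrative)::
+
+ Column("row_version", TIMESTAMP(convert_int=True))
+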
+ """
+ self.convert_int = convert_int
+
+ def result_processor(self, dialect, coltype):
+ super_ = super(TIMESTAMP, self).result_processor(dialect, coltype)
+ if self.convert_int:
+
+ def process(value):
+ value = super_(value)
+ if value is not None:
+ # https://stackoverflow.com/a/30403242/34549
+ value = int(codecs.encode(value, "hex"), 16)
+ return value
+
+ return process
+ else:
+ return super_
+
+
+class ROWVERSION(TIMESTAMP):
+ """Implement the SQL Server ROWVERSION type.
+
+ The ROWVERSION datatype is a SQL Server synonym for the TIMESTAMP
+ datatype, however current SQL Server documentation suggests using
+ ROWVERSION for new datatypes going forward.
+
+ The ROWVERSION datatype does **not** reflect (e.g. introspect) from the
+ database as itself; the returned datatype will be
+ :class:`_mssql.TIMESTAMP`.
+
+ This is a read-only datatype that does not support INSERT of values.
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :class:`_mssql.TIMESTAMP`
+
+ """
+
+ __visit_name__ = "ROWVERSION"
+
+
+class NTEXT(sqltypes.UnicodeText):
+
+ """MSSQL NTEXT type, for variable-length unicode text up to 2^30
+ characters."""
+
+ __visit_name__ = "NTEXT"
+
+
+class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary):
+ """The MSSQL VARBINARY type.
+
+ This type adds additional features to the core :class:`_types.VARBINARY`
+ type, including "deprecate_large_types" mode where
+ either ``VARBINARY(max)`` or IMAGE is rendered, as well as the SQL
+ Server ``FILESTREAM`` option.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :ref:`mssql_large_type_deprecation`
+
+ """
+
+ __visit_name__ = "VARBINARY"
+
+ def __init__(self, length=None, filestream=False):
+ """
+ Construct a VARBINARY type.
+
+ :param length: optional, a length for the column for use in
+ DDL statements, for those binary types that accept a length,
+ such as the MySQL BLOB type.
+
+ :param filestream=False: if True, renders the ``FILESTREAM`` keyword
+ in the table definition. In this case ``length`` must be ``None``
+ or ``'max'``.
+
+ .. versionadded:: 1.4.31
+
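+ E.g., a sketch (the column name is illustrative)::
+
+ Column("document", VARBINARY("max", filestream=True))
+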
+ """
+
+ self.filestream = filestream
+ if self.filestream and length not in (None, "max"):
+ raise ValueError(
+ "length must be None or 'max' when setting filestream"
+ )
+ super(VARBINARY, self).__init__(length=length)
+
+
+class IMAGE(sqltypes.LargeBinary):
+ __visit_name__ = "IMAGE"
+
+
+class XML(sqltypes.Text):
+ """MSSQL XML type.
+
+ This is a placeholder type for reflection purposes that does not include
+ any Python-side datatype support. It also does not currently support
+ additional arguments, such as "CONTENT", "DOCUMENT",
+ "xml_schema_collection".
+
+ .. versionadded:: 1.1.11
+
+ """
+
+ __visit_name__ = "XML"
+
+
+class BIT(sqltypes.Boolean):
+ """MSSQL BIT type.
+
+ Both pyodbc and pymssql return values from BIT columns as
+ Python <class 'bool'> so just subclass Boolean.
+
+ """
+
+ __visit_name__ = "BIT"
+
+
+class MONEY(sqltypes.TypeEngine):
+ __visit_name__ = "MONEY"
+
+
+class SMALLMONEY(sqltypes.TypeEngine):
+ __visit_name__ = "SMALLMONEY"
+
+
+class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
+ __visit_name__ = "UNIQUEIDENTIFIER"
+
+
+class SQL_VARIANT(sqltypes.TypeEngine):
+ __visit_name__ = "SQL_VARIANT"
+
+
+class TryCast(sql.elements.Cast):
+ """Represent a SQL Server TRY_CAST expression."""
+
+ __visit_name__ = "try_cast"
+
+ stringify_dialect = "mssql"
+ inherit_cache = True
+
+ def __init__(self, *arg, **kw):
+ """Create a TRY_CAST expression.
+
+ :class:`.TryCast` is a subclass of SQLAlchemy's :class:`.Cast`
+ construct, and works in the same way, except that the SQL expression
+ rendered is "TRY_CAST" rather than "CAST"::
+
+ from sqlalchemy import select
+ from sqlalchemy import Numeric
+ from sqlalchemy.dialects.mssql import try_cast
+
+ stmt = select(
+ try_cast(product_table.c.unit_price, Numeric(10, 4))
+ )
+
+ The above would render::
+
+ SELECT TRY_CAST (product_table.unit_price AS NUMERIC(10, 4))
+ FROM product_table
+
+ .. versionadded:: 1.3.7
+
+ """
+ super(TryCast, self).__init__(*arg, **kw)
+
+
+try_cast = public_factory(TryCast, ".dialects.mssql.try_cast")
+
+# old names.
+MSDateTime = _MSDateTime
+MSDate = _MSDate
+MSReal = REAL
+MSTinyInteger = TINYINT
+MSTime = TIME
+MSSmallDateTime = SMALLDATETIME
+MSDateTime2 = DATETIME2
+MSDateTimeOffset = DATETIMEOFFSET
+MSText = TEXT
+MSNText = NTEXT
+MSString = VARCHAR
+MSNVarchar = NVARCHAR
+MSChar = CHAR
+MSNChar = NCHAR
+MSBinary = BINARY
+MSVarBinary = VARBINARY
+MSImage = IMAGE
+MSBit = BIT
+MSMoney = MONEY
+MSSmallMoney = SMALLMONEY
+MSUniqueIdentifier = UNIQUEIDENTIFIER
+MSVariant = SQL_VARIANT
+
+ischema_names = {
+ "int": INTEGER,
+ "bigint": BIGINT,
+ "smallint": SMALLINT,
+ "tinyint": TINYINT,
+ "varchar": VARCHAR,
+ "nvarchar": NVARCHAR,
+ "char": CHAR,
+ "nchar": NCHAR,
+ "text": TEXT,
+ "ntext": NTEXT,
+ "decimal": DECIMAL,
+ "numeric": NUMERIC,
+ "float": FLOAT,
+ "datetime": DATETIME,
+ "datetime2": DATETIME2,
+ "datetimeoffset": DATETIMEOFFSET,
+ "date": DATE,
+ "time": TIME,
+ "smalldatetime": SMALLDATETIME,
+ "binary": BINARY,
+ "varbinary": VARBINARY,
+ "bit": BIT,
+ "real": REAL,
+ "image": IMAGE,
+ "xml": XML,
+ "timestamp": TIMESTAMP,
+ "money": MONEY,
+ "smallmoney": SMALLMONEY,
+ "uniqueidentifier": UNIQUEIDENTIFIER,
+ "sql_variant": SQL_VARIANT,
+}
+
+
+class MSTypeCompiler(compiler.GenericTypeCompiler):
+ def _extend(self, spec, type_, length=None):
+ """Extend a string-type declaration with standard SQL
+ COLLATE annotations.
+
+ """
+
+ if getattr(type_, "collation", None):
+ collation = "COLLATE %s" % type_.collation
+ else:
+ collation = None
+
+ if not length:
+ length = type_.length
+
+ if length:
+ spec = spec + "(%s)" % length
+
+ return " ".join([c for c in (spec, collation) if c is not None])
+
+ def visit_FLOAT(self, type_, **kw):
+ precision = getattr(type_, "precision", None)
+ if precision is None:
+ return "FLOAT"
+ else:
+ return "FLOAT(%(precision)s)" % {"precision": precision}
+
+ def visit_TINYINT(self, type_, **kw):
+ return "TINYINT"
+
+ def visit_TIME(self, type_, **kw):
+ precision = getattr(type_, "precision", None)
+ if precision is not None:
+ return "TIME(%s)" % precision
+ else:
+ return "TIME"
+
+ def visit_TIMESTAMP(self, type_, **kw):
+ return "TIMESTAMP"
+
+ def visit_ROWVERSION(self, type_, **kw):
+ return "ROWVERSION"
+
+ def visit_datetime(self, type_, **kw):
+ if type_.timezone:
+ return self.visit_DATETIMEOFFSET(type_, **kw)
+ else:
+ return self.visit_DATETIME(type_, **kw)
+
+ def visit_DATETIMEOFFSET(self, type_, **kw):
+ precision = getattr(type_, "precision", None)
+ if precision is not None:
+ return "DATETIMEOFFSET(%s)" % type_.precision
+ else:
+ return "DATETIMEOFFSET"
+
+ def visit_DATETIME2(self, type_, **kw):
+ precision = getattr(type_, "precision", None)
+ if precision is not None:
+ return "DATETIME2(%s)" % precision
+ else:
+ return "DATETIME2"
+
+ def visit_SMALLDATETIME(self, type_, **kw):
+ return "SMALLDATETIME"
+
+ def visit_unicode(self, type_, **kw):
+ return self.visit_NVARCHAR(type_, **kw)
+
+ def visit_text(self, type_, **kw):
+ if self.dialect.deprecate_large_types:
+ return self.visit_VARCHAR(type_, **kw)
+ else:
+ return self.visit_TEXT(type_, **kw)
+
+ def visit_unicode_text(self, type_, **kw):
+ if self.dialect.deprecate_large_types:
+ return self.visit_NVARCHAR(type_, **kw)
+ else:
+ return self.visit_NTEXT(type_, **kw)
+
+ def visit_NTEXT(self, type_, **kw):
+ return self._extend("NTEXT", type_)
+
+ def visit_TEXT(self, type_, **kw):
+ return self._extend("TEXT", type_)
+
+ def visit_VARCHAR(self, type_, **kw):
+ return self._extend("VARCHAR", type_, length=type_.length or "max")
+
+ def visit_CHAR(self, type_, **kw):
+ return self._extend("CHAR", type_)
+
+ def visit_NCHAR(self, type_, **kw):
+ return self._extend("NCHAR", type_)
+
+ def visit_NVARCHAR(self, type_, **kw):
+ return self._extend("NVARCHAR", type_, length=type_.length or "max")
+
+ def visit_date(self, type_, **kw):
+ if self.dialect.server_version_info < MS_2008_VERSION:
+ return self.visit_DATETIME(type_, **kw)
+ else:
+ return self.visit_DATE(type_, **kw)
+
+ def visit__BASETIMEIMPL(self, type_, **kw):
+ return self.visit_time(type_, **kw)
+
+ def visit_time(self, type_, **kw):
+ if self.dialect.server_version_info < MS_2008_VERSION:
+ return self.visit_DATETIME(type_, **kw)
+ else:
+ return self.visit_TIME(type_, **kw)
+
+ def visit_large_binary(self, type_, **kw):
+ if self.dialect.deprecate_large_types:
+ return self.visit_VARBINARY(type_, **kw)
+ else:
+ return self.visit_IMAGE(type_, **kw)
+
+ def visit_IMAGE(self, type_, **kw):
+ return "IMAGE"
+
+ def visit_XML(self, type_, **kw):
+ return "XML"
+
+ def visit_VARBINARY(self, type_, **kw):
+ text = self._extend("VARBINARY", type_, length=type_.length or "max")
+ if getattr(type_, "filestream", False):
+ text += " FILESTREAM"
+ return text
+
+ def visit_boolean(self, type_, **kw):
+ return self.visit_BIT(type_)
+
+ def visit_BIT(self, type_, **kw):
+ return "BIT"
+
+ def visit_JSON(self, type_, **kw):
+ # this is a bit of a break with SQLAlchemy's convention of
+ # "UPPERCASE name goes to UPPERCASE type name with no modification"
+ return self._extend("NVARCHAR", type_, length="max")
+
+ def visit_MONEY(self, type_, **kw):
+ return "MONEY"
+
+ def visit_SMALLMONEY(self, type_, **kw):
+ return "SMALLMONEY"
+
+ def visit_UNIQUEIDENTIFIER(self, type_, **kw):
+ return "UNIQUEIDENTIFIER"
+
+ def visit_SQL_VARIANT(self, type_, **kw):
+ return "SQL_VARIANT"
+
+
+class MSExecutionContext(default.DefaultExecutionContext):
+ _enable_identity_insert = False
+ _select_lastrowid = False
+ _lastrowid = None
+ _rowcount = None
+
+ def _opt_encode(self, statement):
+
+ if not self.dialect.supports_unicode_statements:
+ encoded = self.dialect._encoder(statement)[0]
+ else:
+ encoded = statement
+
+ if self.compiled and self.compiled.schema_translate_map:
+
+ rst = self.compiled.preparer._render_schema_translates
+ encoded = rst(encoded, self.compiled.schema_translate_map)
+
+ return encoded
+
+ def pre_exec(self):
+ """Activate IDENTITY_INSERT if needed."""
+
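+        # Illustrative statement sequence (hypothetical table "t" with an
+        # IDENTITY column "id") when an explicit id value is supplied:
+        #
+        #     SET IDENTITY_INSERT t ON
+        #     INSERT INTO t (id, data) VALUES (1, 'x')
+        #     SET IDENTITY_INSERT t OFF   -- emitted later by post_exec()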
+ if self.isinsert:
+ tbl = self.compiled.compile_state.dml_table
+ id_column = tbl._autoincrement_column
+ insert_has_identity = (id_column is not None) and (
+ not isinstance(id_column.default, Sequence)
+ )
+
+ if insert_has_identity:
+ compile_state = self.compiled.dml_compile_state
+ self._enable_identity_insert = (
+ id_column.key in self.compiled_parameters[0]
+ ) or (
+ compile_state._dict_parameters
+ and (id_column.key in compile_state._insert_col_keys)
+ )
+
+ else:
+ self._enable_identity_insert = False
+
+ self._select_lastrowid = (
+ not self.compiled.inline
+ and insert_has_identity
+ and not self.compiled.returning
+ and not self._enable_identity_insert
+ and not self.executemany
+ )
+
+ if self._enable_identity_insert:
+ self.root_connection._cursor_execute(
+ self.cursor,
+ self._opt_encode(
+ "SET IDENTITY_INSERT %s ON"
+ % self.identifier_preparer.format_table(tbl)
+ ),
+ (),
+ self,
+ )
+
+ def post_exec(self):
+ """Disable IDENTITY_INSERT if enabled."""
+
+ conn = self.root_connection
+
+ if self.isinsert or self.isupdate or self.isdelete:
+ self._rowcount = self.cursor.rowcount
+
+ if self._select_lastrowid:
+ if self.dialect.use_scope_identity:
+ conn._cursor_execute(
+ self.cursor,
+ "SELECT scope_identity() AS lastrowid",
+ (),
+ self,
+ )
+ else:
+ conn._cursor_execute(
+ self.cursor, "SELECT @@identity AS lastrowid", (), self
+ )
+ # fetchall() ensures the cursor is consumed without closing it
+ row = self.cursor.fetchall()[0]
+ self._lastrowid = int(row[0])
+
+ elif (
+ self.isinsert or self.isupdate or self.isdelete
+ ) and self.compiled.returning:
+ self.cursor_fetch_strategy = (
+ _cursor.FullyBufferedCursorFetchStrategy(
+ self.cursor,
+ self.cursor.description,
+ self.cursor.fetchall(),
+ )
+ )
+
+ if self._enable_identity_insert:
+ conn._cursor_execute(
+ self.cursor,
+ self._opt_encode(
+ "SET IDENTITY_INSERT %s OFF"
+ % self.identifier_preparer.format_table(
+ self.compiled.compile_state.dml_table
+ )
+ ),
+ (),
+ self,
+ )
+
+ def get_lastrowid(self):
+ return self._lastrowid
+
+ @property
+ def rowcount(self):
+ if self._rowcount is not None:
+ return self._rowcount
+ else:
+ return self.cursor.rowcount
+
+ def handle_dbapi_exception(self, e):
+ if self._enable_identity_insert:
+ try:
+ self.cursor.execute(
+ self._opt_encode(
+ "SET IDENTITY_INSERT %s OFF"
+ % self.identifier_preparer.format_table(
+ self.compiled.compile_state.dml_table
+ )
+ )
+ )
+ except Exception:
+ pass
+
+ def fire_sequence(self, seq, type_):
+ return self._execute_scalar(
+ (
+ "SELECT NEXT VALUE FOR %s"
+ % self.identifier_preparer.format_sequence(seq)
+ ),
+ type_,
+ )
+
+ def get_insert_default(self, column):
+ if (
+ isinstance(column, sa_schema.Column)
+ and column is column.table._autoincrement_column
+ and isinstance(column.default, sa_schema.Sequence)
+ and column.default.optional
+ ):
+ return None
+ return super(MSExecutionContext, self).get_insert_default(column)
+
+
+class MSSQLCompiler(compiler.SQLCompiler):
+ returning_precedes_values = True
+
+ extract_map = util.update_copy(
+ compiler.SQLCompiler.extract_map,
+ {
+ "doy": "dayofyear",
+ "dow": "weekday",
+ "milliseconds": "millisecond",
+ "microseconds": "microsecond",
+ },
+ )
+
+ def __init__(self, *args, **kwargs):
+ self.tablealiases = {}
+ super(MSSQLCompiler, self).__init__(*args, **kwargs)
+
+ def _with_legacy_schema_aliasing(fn):
+ def decorate(self, *arg, **kw):
+ if self.dialect.legacy_schema_aliasing:
+ return fn(self, *arg, **kw)
+ else:
+ super_ = getattr(super(MSSQLCompiler, self), fn.__name__)
+ return super_(*arg, **kw)
+
+ return decorate
+
+ def visit_now_func(self, fn, **kw):
+ return "CURRENT_TIMESTAMP"
+
+ def visit_current_date_func(self, fn, **kw):
+ return "GETDATE()"
+
+ def visit_length_func(self, fn, **kw):
+ return "LEN%s" % self.function_argspec(fn, **kw)
+
+ def visit_char_length_func(self, fn, **kw):
+ return "LEN%s" % self.function_argspec(fn, **kw)
+
+ def visit_concat_op_binary(self, binary, operator, **kw):
+ return "%s + %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_true(self, expr, **kw):
+ return "1"
+
+ def visit_false(self, expr, **kw):
+ return "0"
+
+ def visit_match_op_binary(self, binary, operator, **kw):
+ return "CONTAINS (%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def get_select_precolumns(self, select, **kw):
+ """MS-SQL puts TOP, it's version of LIMIT here"""
+
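+        # Illustrative rendering (hypothetical table "t"): a statement such
+        # as select(t).limit(5) executes roughly as
+        #
+        #     SELECT TOP 5 t.id, t.data FROM t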
+ s = super(MSSQLCompiler, self).get_select_precolumns(select, **kw)
+
+ if select._has_row_limiting_clause and self._use_top(select):
+ # ODBC drivers and possibly others
+ # don't support bind params in the SELECT clause on SQL Server.
+ # so have to use literal here.
+ kw["literal_execute"] = True
+ s += "TOP %s " % self.process(
+ self._get_limit_or_fetch(select), **kw
+ )
+ if select._fetch_clause is not None:
+ if select._fetch_clause_options["percent"]:
+ s += "PERCENT "
+ if select._fetch_clause_options["with_ties"]:
+ s += "WITH TIES "
+
+ return s
+
+ def get_from_hint_text(self, table, text):
+ return text
+
+ def get_crud_hint_text(self, table, text):
+ return text
+
+ def _get_limit_or_fetch(self, select):
+ if select._fetch_clause is None:
+ return select._limit_clause
+ else:
+ return select._fetch_clause
+
+ def _use_top(self, select):
+ return (select._offset_clause is None) and (
+ select._simple_int_clause(select._limit_clause)
+ or (
+                # limit can use TOP by itself. fetch only uses TOP
+ # when it needs to because of PERCENT and/or WITH TIES
+ select._simple_int_clause(select._fetch_clause)
+ and (
+ select._fetch_clause_options["percent"]
+ or select._fetch_clause_options["with_ties"]
+ )
+ )
+ )
+
+ def fetch_clause(self, cs, **kwargs):
+ return ""
+
+ def limit_clause(self, cs, **kwargs):
+ return ""
+
+ def _check_can_use_fetch_limit(self, select):
+ # to use ROW_NUMBER(), an ORDER BY is required.
+        # OFFSET and FETCH are options of the ORDER BY clause
+ if not select._order_by_clause.clauses:
+ raise exc.CompileError(
+ "MSSQL requires an order_by when "
+ "using an OFFSET or a non-simple "
+ "LIMIT clause"
+ )
+
+ if select._fetch_clause_options is not None and (
+ select._fetch_clause_options["percent"]
+ or select._fetch_clause_options["with_ties"]
+ ):
+ raise exc.CompileError(
+ "MSSQL needs TOP to use PERCENT and/or WITH TIES. "
+ "Only simple fetch without offset can be used."
+ )
+
+ def _row_limit_clause(self, select, **kw):
+ """MSSQL 2012 supports OFFSET/FETCH operators
+ Use it instead subquery with row_number
+
+ """
+
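+        # Illustrative rendering (hypothetical statement) on SQL Server 2012
+        # and above, e.g. for limit(5).offset(10) with an ORDER BY:
+        #
+        #     SELECT t.id FROM t ORDER BY t.id
+        #      OFFSET 10 ROWS
+        #      FETCH FIRST 5 ROWS ONLY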
+ if self.dialect._supports_offset_fetch and not self._use_top(select):
+ self._check_can_use_fetch_limit(select)
+
+ text = ""
+
+ if select._offset_clause is not None:
+ offset_str = self.process(select._offset_clause, **kw)
+ else:
+ offset_str = "0"
+ text += "\n OFFSET %s ROWS" % offset_str
+
+ limit = self._get_limit_or_fetch(select)
+
+ if limit is not None:
+ text += "\n FETCH FIRST %s ROWS ONLY" % self.process(
+ limit, **kw
+ )
+ return text
+ else:
+ return ""
+
+ def visit_try_cast(self, element, **kw):
+ return "TRY_CAST (%s AS %s)" % (
+ self.process(element.clause, **kw),
+ self.process(element.typeclause, **kw),
+ )
+
+ def translate_select_structure(self, select_stmt, **kwargs):
+ """Look for ``LIMIT`` and OFFSET in a select statement, and if
+ so tries to wrap it in a subquery with ``row_number()`` criterion.
+ MSSQL 2012 and above are excluded
+
+ """
+ select = select_stmt
+
+ if (
+ select._has_row_limiting_clause
+ and not self.dialect._supports_offset_fetch
+ and not self._use_top(select)
+ and not getattr(select, "_mssql_visit", None)
+ ):
+ self._check_can_use_fetch_limit(select)
+
+ _order_by_clauses = [
+ sql_util.unwrap_label_reference(elem)
+ for elem in select._order_by_clause.clauses
+ ]
+
+ limit_clause = self._get_limit_or_fetch(select)
+ offset_clause = select._offset_clause
+
+ select = select._generate()
+ select._mssql_visit = True
+ select = (
+ select.add_columns(
+ sql.func.ROW_NUMBER()
+ .over(order_by=_order_by_clauses)
+ .label("mssql_rn")
+ )
+ .order_by(None)
+ .alias()
+ )
+
+ mssql_rn = sql.column("mssql_rn")
+ limitselect = sql.select(
+ *[c for c in select.c if c.key != "mssql_rn"]
+ )
+ if offset_clause is not None:
+ limitselect = limitselect.where(mssql_rn > offset_clause)
+ if limit_clause is not None:
+ limitselect = limitselect.where(
+ mssql_rn <= (limit_clause + offset_clause)
+ )
+ else:
+ limitselect = limitselect.where(mssql_rn <= (limit_clause))
+ return limitselect
+ else:
+ return select
+
+ @_with_legacy_schema_aliasing
+ def visit_table(self, table, mssql_aliased=False, iscrud=False, **kwargs):
+ if mssql_aliased is table or iscrud:
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
+ # alias schema-qualified tables
+ alias = self._schema_aliased_table(table)
+ if alias is not None:
+ return self.process(alias, mssql_aliased=table, **kwargs)
+ else:
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
+ @_with_legacy_schema_aliasing
+ def visit_alias(self, alias, **kw):
+ # translate for schema-qualified table aliases
+ kw["mssql_aliased"] = alias.element
+ return super(MSSQLCompiler, self).visit_alias(alias, **kw)
+
+ @_with_legacy_schema_aliasing
+ def visit_column(self, column, add_to_result_map=None, **kw):
+ if (
+ column.table is not None
+ and (not self.isupdate and not self.isdelete)
+ or self.is_subquery()
+ ):
+ # translate for schema-qualified table aliases
+ t = self._schema_aliased_table(column.table)
+ if t is not None:
+ converted = elements._corresponding_column_or_error(t, column)
+ if add_to_result_map is not None:
+ add_to_result_map(
+ column.name,
+ column.name,
+ (column, column.name, column.key),
+ column.type,
+ )
+
+ return super(MSSQLCompiler, self).visit_column(converted, **kw)
+
+ return super(MSSQLCompiler, self).visit_column(
+ column, add_to_result_map=add_to_result_map, **kw
+ )
+
+ def _schema_aliased_table(self, table):
+ if getattr(table, "schema", None) is not None:
+ if table not in self.tablealiases:
+ self.tablealiases[table] = table.alias()
+ return self.tablealiases[table]
+ else:
+ return None
+
+ def visit_extract(self, extract, **kw):
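+        # e.g. extract("year", t.c.ts) (hypothetical column) renders as
+        # DATEPART(year, t.ts); "dow" maps to DATEPART(weekday, ...) via
+        # extract_map above.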
+ field = self.extract_map.get(extract.field, extract.field)
+ return "DATEPART(%s, %s)" % (field, self.process(extract.expr, **kw))
+
+ def visit_savepoint(self, savepoint_stmt):
+ return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
+
+ def visit_rollback_to_savepoint(self, savepoint_stmt):
+ return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
+
+ def visit_binary(self, binary, **kwargs):
+ """Move bind parameters to the right-hand side of an operator, where
+ possible.
+
+ """
+ if (
+ isinstance(binary.left, expression.BindParameter)
+ and binary.operator == operator.eq
+ and not isinstance(binary.right, expression.BindParameter)
+ ):
+ return self.process(
+ expression.BinaryExpression(
+ binary.right, binary.left, binary.operator
+ ),
+ **kwargs
+ )
+ return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
+
+ def returning_clause(self, stmt, returning_cols):
+ # SQL server returning clause requires that the columns refer to
+ # the virtual table names "inserted" or "deleted". Here, we make
+ # a simple alias of our table with that name, and then adapt the
+ # columns we have from the list of RETURNING columns to that new name
+ # so that they render as "inserted.<colname>" / "deleted.<colname>".
+
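+        # Illustrative output (hypothetical columns): an INSERT returning
+        # (t.c.id, t.c.data) renders an OUTPUT clause ahead of its VALUES
+        # clause:
+        #
+        #     OUTPUT inserted.id, inserted.data
+        #
+        # while a DELETE would render OUTPUT deleted.id, deleted.data.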
+ if self.isinsert or self.isupdate:
+ target = stmt.table.alias("inserted")
+ else:
+ target = stmt.table.alias("deleted")
+
+ adapter = sql_util.ClauseAdapter(target)
+
+ # adapter.traverse() takes a column from our target table and returns
+ # the one that is linked to the "inserted" / "deleted" tables. So in
+ # order to retrieve these values back from the result (e.g. like
+ # row[column]), tell the compiler to also add the original unadapted
+ # column to the result map. Before #4877, these were (unknowingly)
+        # falling back to string-name matching in the result set, which
+        # necessarily incurred an expensive KeyError in order to match.
+
+ columns = [
+ self._label_returning_column(
+ stmt,
+ adapter.traverse(c),
+ {"result_map_targets": (c,)},
+ )
+ for c in expression._select_iterables(returning_cols)
+ ]
+
+ return "OUTPUT " + ", ".join(columns)
+
+ def get_cte_preamble(self, recursive):
+ # SQL Server finds it too inconvenient to accept
+ # an entirely optional, SQL standard specified,
+ # "RECURSIVE" word with their "WITH",
+ # so here we go
+ return "WITH"
+
+ def label_select_column(self, select, column, asfrom):
+ if isinstance(column, expression.Function):
+ return column.label(None)
+ else:
+ return super(MSSQLCompiler, self).label_select_column(
+ select, column, asfrom
+ )
+
+ def for_update_clause(self, select, **kw):
+ # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which
+ # SQLAlchemy doesn't use
+ return ""
+
+ def order_by_clause(self, select, **kw):
+ # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
+ if (
+ self.is_subquery()
+ and not select._limit
+ and (
+ select._offset is None
+ or not self.dialect._supports_offset_fetch
+ )
+ ):
+ # avoid processing the order by clause if we won't end up
+ # using it, because we don't want all the bind params tacked
+ # onto the positional list if that is what the dbapi requires
+ return ""
+
+ order_by = self.process(select._order_by_clause, **kw)
+
+ if order_by:
+ return " ORDER BY " + order_by
+ else:
+ return ""
+
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ """Render the UPDATE..FROM clause specific to MSSQL.
+
+ In MSSQL, if the UPDATE statement involves an alias of the table to
+ be updated, then the table itself must be added to the FROM list as
+ well. Otherwise, it is optional. Here, we add it regardless.
+
+ """
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
+
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
+ """If we have extra froms make sure we render any alias as hint."""
+ ashint = False
+ if extra_froms:
+ ashint = True
+ return from_table._compiler_dispatch(
+ self, asfrom=True, iscrud=True, ashint=ashint
+ )
+
+ def delete_extra_from_clause(
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ """Render the DELETE .. FROM clause specific to MSSQL.
+
+ Yes, it has the FROM keyword twice.
+
+ """
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
+
+ def visit_empty_set_expr(self, type_):
+ return "SELECT 1 WHERE 1!=1"
+
+ def visit_is_distinct_from_binary(self, binary, operator, **kw):
+ return "NOT EXISTS (SELECT %s INTERSECT SELECT %s)" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def visit_is_not_distinct_from_binary(self, binary, operator, **kw):
+ return "EXISTS (SELECT %s INTERSECT SELECT %s)" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def _render_json_extract_from_binary(self, binary, operator, **kw):
+        # note: we intentionally invoke the process() calls in the order in
+        # which they appear in the SQL string, as this ordering is relied
+        # upon by positional parameter rendering
+
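+        # Illustrative rendering for a String-typed JSON path (hypothetical
+        # column "data", path '$.name'):
+        #
+        #     CASE JSON_VALUE(data, '$.name') WHEN NULL THEN NULL
+        #     ELSE JSON_VALUE(data, '$.name') END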
+ if binary.type._type_affinity is sqltypes.JSON:
+ return "JSON_QUERY(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ # as with other dialects, start with an explicit test for NULL
+ case_expression = "CASE JSON_VALUE(%s, %s) WHEN NULL THEN NULL" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ if binary.type._type_affinity is sqltypes.Integer:
+ type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS INTEGER)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+ elif binary.type._type_affinity is sqltypes.Numeric:
+ type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ "FLOAT"
+ if isinstance(binary.type, sqltypes.Float)
+ else "NUMERIC(%s, %s)"
+ % (binary.type.precision, binary.type.scale),
+ )
+ elif binary.type._type_affinity is sqltypes.Boolean:
+ # the NULL handling is particularly weird with boolean, so
+ # explicitly return numeric (BIT) constants
+ type_expression = (
+ "WHEN 'true' THEN 1 WHEN 'false' THEN 0 ELSE NULL"
+ )
+ elif binary.type._type_affinity is sqltypes.String:
+ # TODO: does this comment (from mysql) apply to here, too?
+ # this fails with a JSON value that's a four byte unicode
+ # string. SQLite has the same problem at the moment
+ type_expression = "ELSE JSON_VALUE(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+ else:
+ # other affinity....this is not expected right now
+ type_expression = "ELSE JSON_QUERY(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ return case_expression + " " + type_expression + " END"
+
+ def visit_json_getitem_op_binary(self, binary, operator, **kw):
+ return self._render_json_extract_from_binary(binary, operator, **kw)
+
+ def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+ return self._render_json_extract_from_binary(binary, operator, **kw)
+
+ def visit_sequence(self, seq, **kw):
+ return "NEXT VALUE FOR %s" % self.preparer.format_sequence(seq)
+
+
+class MSSQLStrictCompiler(MSSQLCompiler):
+
+ """A subclass of MSSQLCompiler which disables the usage of bind
+ parameters where not allowed natively by MS-SQL.
+
+ A dialect may use this compiler on a platform where native
+ binds are used.
+
+ """
+
+ ansi_bind_rules = True
+
+ def visit_in_op_binary(self, binary, operator, **kw):
+ kw["literal_execute"] = True
+ return "%s IN %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_not_in_op_binary(self, binary, operator, **kw):
+ kw["literal_execute"] = True
+ return "%s NOT IN %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def render_literal_value(self, value, type_):
+ """
+ For date and datetime values, convert to a string
+ format acceptable to MSSQL. That seems to be the
+ so-called ODBC canonical date format which looks
+ like this:
+
+ yyyy-mm-dd hh:mi:ss.mmm(24h)
+
+ For other data types, call the base class implementation.
+ """
+ # datetime and date are both subclasses of datetime.date
+ if issubclass(type(value), datetime.date):
+ # SQL Server wants single quotes around the date string.
+ return "'" + str(value) + "'"
+ else:
+ return super(MSSQLStrictCompiler, self).render_literal_value(
+ value, type_
+ )
+
+
+class MSDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+ colspec = self.preparer.format_column(column)
+
+ # type is not accepted in a computed column
+ if column.computed is not None:
+ colspec += " " + self.process(column.computed)
+ else:
+ colspec += " " + self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ )
+
+ if column.nullable is not None:
+ if (
+ not column.nullable
+ or column.primary_key
+ or isinstance(column.default, sa_schema.Sequence)
+ or column.autoincrement is True
+ or column.identity
+ ):
+ colspec += " NOT NULL"
+ elif column.computed is None:
+ # don't specify "NULL" for computed columns
+ colspec += " NULL"
+
+ if column.table is None:
+ raise exc.CompileError(
+ "mssql requires Table-bound columns "
+ "in order to generate DDL"
+ )
+
+ d_opt = column.dialect_options["mssql"]
+ start = d_opt["identity_start"]
+ increment = d_opt["identity_increment"]
+ if start is not None or increment is not None:
+ if column.identity:
+ raise exc.CompileError(
+ "Cannot specify options 'mssql_identity_start' and/or "
+ "'mssql_identity_increment' while also using the "
+ "'Identity' construct."
+ )
+ util.warn_deprecated(
+ "The dialect options 'mssql_identity_start' and "
+ "'mssql_identity_increment' are deprecated. "
+ "Use the 'Identity' object instead.",
+ "1.4",
+ )
+
+ if column.identity:
+ colspec += self.process(column.identity, **kwargs)
+ elif (
+ column is column.table._autoincrement_column
+ or column.autoincrement is True
+ ) and (
+ not isinstance(column.default, Sequence) or column.default.optional
+ ):
+ colspec += self.process(Identity(start=start, increment=increment))
+ else:
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ return colspec
+
+ def visit_create_index(self, create, include_schema=False):
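+        # Illustrative DDL (hypothetical index "ix_t_y" on table "t") with
+        # mssql_clustered=False, mssql_include=["x"] and an mssql_where
+        # criterion of "x > 5":
+        #
+        #     CREATE NONCLUSTERED INDEX ix_t_y ON t (y)
+        #     INCLUDE (x) WHERE x > 5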
+ index = create.element
+ self._verify_index_table(index)
+ preparer = self.preparer
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+
+ # handle clustering option
+ clustered = index.dialect_options["mssql"]["clustered"]
+ if clustered is not None:
+ if clustered:
+ text += "CLUSTERED "
+ else:
+ text += "NONCLUSTERED "
+
+ text += "INDEX %s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=include_schema),
+ preparer.format_table(index.table),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
+
+ # handle other included columns
+ if index.dialect_options["mssql"]["include"]:
+ inclusions = [
+ index.table.c[col]
+ if isinstance(col, util.string_types)
+ else col
+ for col in index.dialect_options["mssql"]["include"]
+ ]
+
+ text += " INCLUDE (%s)" % ", ".join(
+ [preparer.quote(c.name) for c in inclusions]
+ )
+
+ whereclause = index.dialect_options["mssql"]["where"]
+
+ if whereclause is not None:
+ whereclause = coercions.expect(
+ roles.DDLExpressionRole, whereclause
+ )
+
+ where_compiled = self.sql_compiler.process(
+ whereclause, include_table=False, literal_binds=True
+ )
+ text += " WHERE " + where_compiled
+
+ return text
+
+ def visit_drop_index(self, drop):
+ return "\nDROP INDEX %s ON %s" % (
+ self._prepared_index_name(drop.element, include_schema=False),
+ self.preparer.format_table(drop.element.table),
+ )
+
+ def visit_primary_key_constraint(self, constraint):
+ if len(constraint) == 0:
+ return ""
+ text = ""
+ if constraint.name is not None:
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(
+ constraint
+ )
+ text += "PRIMARY KEY "
+
+ clustered = constraint.dialect_options["mssql"]["clustered"]
+ if clustered is not None:
+ if clustered:
+ text += "CLUSTERED "
+ else:
+ text += "NONCLUSTERED "
+
+ text += "(%s)" % ", ".join(
+ self.preparer.quote(c.name) for c in constraint
+ )
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def visit_unique_constraint(self, constraint):
+ if len(constraint) == 0:
+ return ""
+ text = ""
+ if constraint.name is not None:
+ formatted_name = self.preparer.format_constraint(constraint)
+ if formatted_name is not None:
+ text += "CONSTRAINT %s " % formatted_name
+ text += "UNIQUE "
+
+ clustered = constraint.dialect_options["mssql"]["clustered"]
+ if clustered is not None:
+ if clustered:
+ text += "CLUSTERED "
+ else:
+ text += "NONCLUSTERED "
+
+ text += "(%s)" % ", ".join(
+ self.preparer.quote(c.name) for c in constraint
+ )
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def visit_computed_column(self, generated):
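+        # Illustrative output for a hypothetical Computed("x + 1",
+        # persisted=True) column: AS (x + 1) PERSISTED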
+ text = "AS (%s)" % self.sql_compiler.process(
+ generated.sqltext, include_table=False, literal_binds=True
+ )
+ # explicitly check for True|False since None means server default
+ if generated.persisted is True:
+ text += " PERSISTED"
+ return text
+
+ def visit_create_sequence(self, create, **kw):
+ prefix = None
+ if create.element.data_type is not None:
+ data_type = create.element.data_type
+ prefix = " AS %s" % self.type_compiler.process(data_type)
+ return super(MSDDLCompiler, self).visit_create_sequence(
+ create, prefix=prefix, **kw
+ )
+
+ def visit_identity_column(self, identity, **kw):
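+        # Illustrative output: Identity(start=100, increment=10) would render
+        # as " IDENTITY(100,10)"; a bare Identity() renders just " IDENTITY".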
+ text = " IDENTITY"
+ if identity.start is not None or identity.increment is not None:
+ start = 1 if identity.start is None else identity.start
+ increment = 1 if identity.increment is None else identity.increment
+ text += "(%s,%s)" % (start, increment)
+ return text
+
+
+class MSIdentifierPreparer(compiler.IdentifierPreparer):
+ reserved_words = RESERVED_WORDS
+
+ def __init__(self, dialect):
+ super(MSIdentifierPreparer, self).__init__(
+ dialect,
+ initial_quote="[",
+ final_quote="]",
+ quote_case_sensitive_collations=False,
+ )
+
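+    # Illustrative quoting (hypothetical identifier): a name such as
+    # My]Table is rendered by this preparer as [My]]Table].
+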
+ def _escape_identifier(self, value):
+ return value.replace("]", "]]")
+
+ def _unescape_identifier(self, value):
+ return value.replace("]]", "]")
+
+ def quote_schema(self, schema, force=None):
+ """Prepare a quoted table and schema name."""
+
+ # need to re-implement the deprecation warning entirely
+ if force is not None:
+ # not using the util.deprecated_params() decorator in this
+ # case because of the additional function call overhead on this
+ # very performance-critical spot.
+ util.warn_deprecated(
+ "The IdentifierPreparer.quote_schema.force parameter is "
+ "deprecated and will be removed in a future release. This "
+ "flag has no effect on the behavior of the "
+ "IdentifierPreparer.quote method; please refer to "
+ "quoted_name().",
+ version="1.3",
+ )
+
+ dbname, owner = _schema_elements(schema)
+ if dbname:
+ result = "%s.%s" % (self.quote(dbname), self.quote(owner))
+ elif owner:
+ result = self.quote(owner)
+ else:
+ result = ""
+ return result
+
+
+def _db_plus_owner_listing(fn):
+ def wrap(dialect, connection, schema=None, **kw):
+ dbname, owner = _owner_plus_db(dialect, schema)
+ return _switch_db(
+ dbname,
+ connection,
+ fn,
+ dialect,
+ connection,
+ dbname,
+ owner,
+ schema,
+ **kw
+ )
+
+ return update_wrapper(wrap, fn)
+
+
+def _db_plus_owner(fn):
+ def wrap(dialect, connection, tablename, schema=None, **kw):
+ dbname, owner = _owner_plus_db(dialect, schema)
+ return _switch_db(
+ dbname,
+ connection,
+ fn,
+ dialect,
+ connection,
+ tablename,
+ dbname,
+ owner,
+ schema,
+ **kw
+ )
+
+ return update_wrapper(wrap, fn)
+
+
+def _switch_db(dbname, connection, fn, *arg, **kw):
+ if dbname:
+ current_db = connection.exec_driver_sql("select db_name()").scalar()
+ if current_db != dbname:
+ connection.exec_driver_sql(
+ "use %s" % connection.dialect.identifier_preparer.quote(dbname)
+ )
+ try:
+ return fn(*arg, **kw)
+ finally:
+ if dbname and current_db != dbname:
+ connection.exec_driver_sql(
+ "use %s"
+ % connection.dialect.identifier_preparer.quote(current_db)
+ )
+
+
+def _owner_plus_db(dialect, schema):
+ if not schema:
+ return None, dialect.default_schema_name
+ elif "." in schema:
+ return _schema_elements(schema)
+ else:
+ return None, schema
+
+
+_memoized_schema = util.LRUCache()
+
+
+def _schema_elements(schema):
+ if isinstance(schema, quoted_name) and schema.quote:
+ return None, schema
+
+ if schema in _memoized_schema:
+ return _memoized_schema[schema]
+
+ # tests for this function are in:
+ # test/dialect/mssql/test_reflection.py ->
+ # OwnerPlusDBTest.test_owner_database_pairs
+ # test/dialect/mssql/test_compiler.py -> test_force_schema_*
+ # test/dialect/mssql/test_compiler.py -> test_schema_many_tokens_*
+ #
+
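+    # Illustrative splits (hypothetical inputs):
+    #
+    #     "mydb.dbo"     -> ("mydb", "dbo")
+    #     "[my.db].dbo"  -> ("my.db", "dbo")   (brackets stripped)
+    #     "dbo"          -> (None, "dbo")
+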
+ if schema.startswith("__[SCHEMA_"):
+ return None, schema
+
+ push = []
+ symbol = ""
+ bracket = False
+ has_brackets = False
+ for token in re.split(r"(\[|\]|\.)", schema):
+ if not token:
+ continue
+ if token == "[":
+ bracket = True
+ has_brackets = True
+ elif token == "]":
+ bracket = False
+ elif not bracket and token == ".":
+ if has_brackets:
+ push.append("[%s]" % symbol)
+ else:
+ push.append(symbol)
+ symbol = ""
+ has_brackets = False
+ else:
+ symbol += token
+ if symbol:
+ push.append(symbol)
+ if len(push) > 1:
+ dbname, owner = ".".join(push[0:-1]), push[-1]
+
+ # test for internal brackets
+ if re.match(r".*\].*\[.*", dbname[1:-1]):
+ dbname = quoted_name(dbname, quote=False)
+ else:
+ dbname = dbname.lstrip("[").rstrip("]")
+
+ elif len(push):
+ dbname, owner = None, push[0]
+ else:
+ dbname, owner = None, None
+
+ _memoized_schema[schema] = dbname, owner
+ return dbname, owner
+
+
+class MSDialect(default.DefaultDialect):
+ # will assume it's at least mssql2005
+ name = "mssql"
+ supports_statement_cache = True
+ supports_default_values = True
+ supports_empty_insert = False
+ execution_ctx_cls = MSExecutionContext
+ use_scope_identity = True
+ max_identifier_length = 128
+ schema_name = "dbo"
+
+ implicit_returning = True
+ full_returning = True
+
+ colspecs = {
+ sqltypes.DateTime: _MSDateTime,
+ sqltypes.Date: _MSDate,
+ sqltypes.JSON: JSON,
+ sqltypes.JSON.JSONIndexType: JSONIndexType,
+ sqltypes.JSON.JSONPathType: JSONPathType,
+ sqltypes.Time: _BASETIMEIMPL,
+ sqltypes.Unicode: _MSUnicode,
+ sqltypes.UnicodeText: _MSUnicodeText,
+ DATETIMEOFFSET: DATETIMEOFFSET,
+ DATETIME2: DATETIME2,
+ SMALLDATETIME: SMALLDATETIME,
+ DATETIME: DATETIME,
+ }
+
+ engine_config_types = default.DefaultDialect.engine_config_types.union(
+ {"legacy_schema_aliasing": util.asbool}
+ )
+
+ ischema_names = ischema_names
+
+ supports_sequences = True
+ sequences_optional = True
+ # T-SQL's actual default is -9223372036854775808
+ default_sequence_base = 1
+
+ supports_native_boolean = False
+ non_native_boolean_check_constraint = False
+ supports_unicode_binds = True
+ postfetch_lastrowid = True
+ _supports_offset_fetch = False
+ _supports_nvarchar_max = False
+
+ legacy_schema_aliasing = False
+
+ server_version_info = ()
+
+ statement_compiler = MSSQLCompiler
+ ddl_compiler = MSDDLCompiler
+ type_compiler = MSTypeCompiler
+ preparer = MSIdentifierPreparer
+
+ construct_arguments = [
+ (sa_schema.PrimaryKeyConstraint, {"clustered": None}),
+ (sa_schema.UniqueConstraint, {"clustered": None}),
+ (sa_schema.Index, {"clustered": None, "include": None, "where": None}),
+ (
+ sa_schema.Column,
+ {"identity_start": None, "identity_increment": None},
+ ),
+ ]
+
+ def __init__(
+ self,
+ query_timeout=None,
+ use_scope_identity=True,
+ schema_name="dbo",
+ isolation_level=None,
+ deprecate_large_types=None,
+ json_serializer=None,
+ json_deserializer=None,
+ legacy_schema_aliasing=None,
+ ignore_no_transaction_on_rollback=False,
+ **opts
+ ):
+ self.query_timeout = int(query_timeout or 0)
+ self.schema_name = schema_name
+
+ self.use_scope_identity = use_scope_identity
+ self.deprecate_large_types = deprecate_large_types
+ self.ignore_no_transaction_on_rollback = (
+ ignore_no_transaction_on_rollback
+ )
+
+ if legacy_schema_aliasing is not None:
+ util.warn_deprecated(
+ "The legacy_schema_aliasing parameter is "
+ "deprecated and will be removed in a future release.",
+ "1.4",
+ )
+ self.legacy_schema_aliasing = legacy_schema_aliasing
+
+ super(MSDialect, self).__init__(**opts)
+
+ self.isolation_level = isolation_level
+ self._json_serializer = json_serializer
+ self._json_deserializer = json_deserializer
+
+ def do_savepoint(self, connection, name):
+ # give the DBAPI a push
+ connection.exec_driver_sql("IF @@TRANCOUNT = 0 BEGIN TRANSACTION")
+ super(MSDialect, self).do_savepoint(connection, name)
+
+ def do_release_savepoint(self, connection, name):
+ # SQL Server does not support RELEASE SAVEPOINT
+ pass
+
+ def do_rollback(self, dbapi_connection):
+ try:
+ super(MSDialect, self).do_rollback(dbapi_connection)
+ except self.dbapi.ProgrammingError as e:
+ if self.ignore_no_transaction_on_rollback and re.match(
+ r".*\b111214\b", str(e)
+ ):
+ util.warn(
+ "ProgrammingError 111214 "
+ "'No corresponding transaction found.' "
+ "has been suppressed via "
+ "ignore_no_transaction_on_rollback=True"
+ )
+ else:
+ raise
+
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ "SNAPSHOT",
+ ]
+ )
+
+ def set_isolation_level(self, connection, level):
+ level = level.replace("_", " ")
+ if level not in self._isolation_lookup:
+ raise exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
+ )
+ cursor = connection.cursor()
+ cursor.execute("SET TRANSACTION ISOLATION LEVEL %s" % level)
+ cursor.close()
+ if level == "SNAPSHOT":
+ connection.commit()
+
+ def get_isolation_level(self, dbapi_connection):
+ cursor = dbapi_connection.cursor()
+ try:
+ cursor.execute(
+ "SELECT name FROM sys.system_views WHERE name IN "
+ "('dm_exec_sessions', 'dm_pdw_nodes_exec_sessions')"
+ )
+ row = cursor.fetchone()
+ if not row:
+ raise NotImplementedError(
+ "Can't fetch isolation level on this particular "
+ "SQL Server version."
+ )
+
+ view_name = "sys.{}".format(row[0])
+ cursor.execute(
+ """
+ SELECT CASE transaction_isolation_level
+ WHEN 0 THEN NULL
+ WHEN 1 THEN 'READ UNCOMMITTED'
+ WHEN 2 THEN 'READ COMMITTED'
+ WHEN 3 THEN 'REPEATABLE READ'
+ WHEN 4 THEN 'SERIALIZABLE'
+ WHEN 5 THEN 'SNAPSHOT' END AS TRANSACTION_ISOLATION_LEVEL
+ FROM {}
+ where session_id = @@SPID
+ """.format(
+ view_name
+ )
+ )
+ row = cursor.fetchone()
+ assert row is not None
+ val = row[0]
+ finally:
+ cursor.close()
+ return val.upper()
+
+ def initialize(self, connection):
+ super(MSDialect, self).initialize(connection)
+ self._setup_version_attributes()
+ self._setup_supports_nvarchar_max(connection)
+
+ def on_connect(self):
+ if self.isolation_level is not None:
+
+ def connect(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ return connect
+ else:
+ return None
+
+ def _setup_version_attributes(self):
+ if self.server_version_info[0] not in list(range(8, 17)):
+ util.warn(
+ "Unrecognized server version info '%s'. Some SQL Server "
+ "features may not function properly."
+ % ".".join(str(x) for x in self.server_version_info)
+ )
+
+ if self.server_version_info >= MS_2008_VERSION:
+ self.supports_multivalues_insert = True
+ if self.deprecate_large_types is None:
+ self.deprecate_large_types = (
+ self.server_version_info >= MS_2012_VERSION
+ )
+
+ self._supports_offset_fetch = (
+ self.server_version_info and self.server_version_info[0] >= 11
+ )
+
+ def _setup_supports_nvarchar_max(self, connection):
+ try:
+ connection.scalar(
+ sql.text("SELECT CAST('test max support' AS NVARCHAR(max))")
+ )
+ except exc.DBAPIError:
+ self._supports_nvarchar_max = False
+ else:
+ self._supports_nvarchar_max = True
+
+ def _get_default_schema_name(self, connection):
+ query = sql.text("SELECT schema_name()")
+ default_schema_name = connection.scalar(query)
+ if default_schema_name is not None:
+ # guard against the case where the default_schema_name is being
+ # fed back into a table reflection function.
+ return quoted_name(default_schema_name, quote=True)
+ else:
+ return self.schema_name
+
+ @_db_plus_owner
+ def has_table(self, connection, tablename, dbname, owner, schema):
+ self._ensure_has_table_connection(connection)
+ if tablename.startswith("#"): # temporary table
+ tables = ischema.mssql_temp_table_columns
+
+ s = sql.select(tables.c.table_name).where(
+ tables.c.table_name.like(
+ self._temp_table_name_like_pattern(tablename)
+ )
+ )
+
+ # #7168: fetch all (not just first match) in case some other #temp
+ # table with the same name happens to appear first
+ table_names = connection.execute(s).scalars().fetchall()
+ # #6910: verify it's not a temp table from another session
+ for table_name in table_names:
+ if bool(
+ connection.scalar(
+ text("SELECT object_id(:table_name)"),
+ {"table_name": "tempdb.dbo.[{}]".format(table_name)},
+ )
+ ):
+ return True
+ else:
+ return False
+ else:
+ tables = ischema.tables
+
+ s = sql.select(tables.c.table_name).where(
+ sql.and_(
+ tables.c.table_type == "BASE TABLE",
+ tables.c.table_name == tablename,
+ )
+ )
+
+ if owner:
+ s = s.where(tables.c.table_schema == owner)
+
+ c = connection.execute(s)
+
+ return c.first() is not None
+
+ @_db_plus_owner
+ def has_sequence(self, connection, sequencename, dbname, owner, schema):
+ sequences = ischema.sequences
+
+ s = sql.select(sequences.c.sequence_name).where(
+ sequences.c.sequence_name == sequencename
+ )
+
+ if owner:
+ s = s.where(sequences.c.sequence_schema == owner)
+
+ c = connection.execute(s)
+
+ return c.first() is not None
+
+ @reflection.cache
+ @_db_plus_owner_listing
+ def get_sequence_names(self, connection, dbname, owner, schema, **kw):
+ sequences = ischema.sequences
+
+ s = sql.select(sequences.c.sequence_name)
+ if owner:
+ s = s.where(sequences.c.sequence_schema == owner)
+
+ c = connection.execute(s)
+
+ return [row[0] for row in c]
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ s = sql.select(ischema.schemata.c.schema_name).order_by(
+ ischema.schemata.c.schema_name
+ )
+ schema_names = [r[0] for r in connection.execute(s)]
+ return schema_names
+
+ @reflection.cache
+ @_db_plus_owner_listing
+ def get_table_names(self, connection, dbname, owner, schema, **kw):
+ tables = ischema.tables
+ s = (
+ sql.select(tables.c.table_name)
+ .where(
+ sql.and_(
+ tables.c.table_schema == owner,
+ tables.c.table_type == "BASE TABLE",
+ )
+ )
+ .order_by(tables.c.table_name)
+ )
+ table_names = [r[0] for r in connection.execute(s)]
+ return table_names
+
+ @reflection.cache
+ @_db_plus_owner_listing
+ def get_view_names(self, connection, dbname, owner, schema, **kw):
+ tables = ischema.tables
+ s = (
+ sql.select(tables.c.table_name)
+ .where(
+ sql.and_(
+ tables.c.table_schema == owner,
+ tables.c.table_type == "VIEW",
+ )
+ )
+ .order_by(tables.c.table_name)
+ )
+ view_names = [r[0] for r in connection.execute(s)]
+ return view_names
+
+ @reflection.cache
+ @_db_plus_owner
+ def get_indexes(self, connection, tablename, dbname, owner, schema, **kw):
+ filter_definition = (
+ "ind.filter_definition"
+ if self.server_version_info >= MS_2008_VERSION
+ else "NULL as filter_definition"
+ )
+ rp = connection.execution_options(future_result=True).execute(
+ sql.text(
+ "select ind.index_id, ind.is_unique, ind.name, "
+ "%s "
+ "from sys.indexes as ind join sys.tables as tab on "
+ "ind.object_id=tab.object_id "
+ "join sys.schemas as sch on sch.schema_id=tab.schema_id "
+ "where tab.name = :tabname "
+ "and sch.name=:schname "
+ "and ind.is_primary_key=0 and ind.type != 0"
+ % filter_definition
+ )
+ .bindparams(
+ sql.bindparam("tabname", tablename, ischema.CoerceUnicode()),
+ sql.bindparam("schname", owner, ischema.CoerceUnicode()),
+ )
+ .columns(name=sqltypes.Unicode())
+ )
+ indexes = {}
+ for row in rp.mappings():
+ indexes[row["index_id"]] = {
+ "name": row["name"],
+ "unique": row["is_unique"] == 1,
+ "column_names": [],
+ "include_columns": [],
+ }
+
+ if row["filter_definition"] is not None:
+ indexes[row["index_id"]].setdefault("dialect_options", {})[
+ "mssql_where"
+ ] = row["filter_definition"]
+
+ rp = connection.execution_options(future_result=True).execute(
+ sql.text(
+ "select ind_col.index_id, ind_col.object_id, col.name, "
+ "ind_col.is_included_column "
+ "from sys.columns as col "
+ "join sys.tables as tab on tab.object_id=col.object_id "
+ "join sys.index_columns as ind_col on "
+ "(ind_col.column_id=col.column_id and "
+ "ind_col.object_id=tab.object_id) "
+ "join sys.schemas as sch on sch.schema_id=tab.schema_id "
+ "where tab.name=:tabname "
+ "and sch.name=:schname"
+ )
+ .bindparams(
+ sql.bindparam("tabname", tablename, ischema.CoerceUnicode()),
+ sql.bindparam("schname", owner, ischema.CoerceUnicode()),
+ )
+ .columns(name=sqltypes.Unicode())
+ )
+ for row in rp.mappings():
+ if row["index_id"] in indexes:
+ if row["is_included_column"]:
+ indexes[row["index_id"]]["include_columns"].append(
+ row["name"]
+ )
+ else:
+ indexes[row["index_id"]]["column_names"].append(
+ row["name"]
+ )
+ for index_info in indexes.values():
+ # NOTE: "root level" include_columns is legacy, now part of
+ # dialect_options (issue #7382)
+ index_info.setdefault("dialect_options", {})[
+ "mssql_include"
+ ] = index_info["include_columns"]
+
+ return list(indexes.values())
+
+ @reflection.cache
+ @_db_plus_owner
+ def get_view_definition(
+ self, connection, viewname, dbname, owner, schema, **kw
+ ):
+ rp = connection.execute(
+ sql.text(
+ "select definition from sys.sql_modules as mod, "
+ "sys.views as views, "
+ "sys.schemas as sch"
+ " where "
+ "mod.object_id=views.object_id and "
+ "views.schema_id=sch.schema_id and "
+ "views.name=:viewname and sch.name=:schname"
+ ).bindparams(
+ sql.bindparam("viewname", viewname, ischema.CoerceUnicode()),
+ sql.bindparam("schname", owner, ischema.CoerceUnicode()),
+ )
+ )
+
+ if rp:
+ view_def = rp.scalar()
+ return view_def
+
+ def _temp_table_name_like_pattern(self, tablename):
+ # LIKE uses '%' to match zero or more characters and '_' to match any
+ # single character. We want to match literal underscores, so T-SQL
+ # requires that we enclose them in square brackets.
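+        # Illustrative pattern (hypothetical name): "#foo" becomes
+        # "#foo[_][_][_]%", matching the session-suffixed name that tempdb
+        # stores for local temp tables; a global "##foo" is left unchanged.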
+ return tablename + (
+ ("[_][_][_]%") if not tablename.startswith("##") else ""
+ )
+
+ def _get_internal_temp_table_name(self, connection, tablename):
+ # it's likely that schema is always "dbo", but since we can
+ # get it here, let's get it.
+ # see https://stackoverflow.com/questions/8311959/
+ # specifying-schema-for-temporary-tables
+
+ try:
+ return connection.execute(
+ sql.text(
+ "select table_schema, table_name "
+ "from tempdb.information_schema.tables "
+ "where table_name like :p1"
+ ),
+ {"p1": self._temp_table_name_like_pattern(tablename)},
+ ).one()
+ except exc.MultipleResultsFound as me:
+ util.raise_(
+ exc.UnreflectableTableError(
+ "Found more than one temporary table named '%s' in tempdb "
+ "at this time. Cannot reliably resolve that name to its "
+ "internal table name." % tablename
+ ),
+ replace_context=me,
+ )
+ except exc.NoResultFound as ne:
+ util.raise_(
+ exc.NoSuchTableError(
+ "Unable to find a temporary table named '%s' in tempdb."
+ % tablename
+ ),
+ replace_context=ne,
+ )
+
+ @reflection.cache
+ @_db_plus_owner
+ def get_columns(self, connection, tablename, dbname, owner, schema, **kw):
+ is_temp_table = tablename.startswith("#")
+ if is_temp_table:
+ owner, tablename = self._get_internal_temp_table_name(
+ connection, tablename
+ )
+
+ columns = ischema.mssql_temp_table_columns
+ else:
+ columns = ischema.columns
+
+ computed_cols = ischema.computed_columns
+ identity_cols = ischema.identity_columns
+ if owner:
+ whereclause = sql.and_(
+ columns.c.table_name == tablename,
+ columns.c.table_schema == owner,
+ )
+ full_name = columns.c.table_schema + "." + columns.c.table_name
+ else:
+ whereclause = columns.c.table_name == tablename
+ full_name = columns.c.table_name
+
+ join = columns.join(
+ computed_cols,
+ onclause=sql.and_(
+ computed_cols.c.object_id == func.object_id(full_name),
+ computed_cols.c.name
+ == columns.c.column_name.collate("DATABASE_DEFAULT"),
+ ),
+ isouter=True,
+ ).join(
+ identity_cols,
+ onclause=sql.and_(
+ identity_cols.c.object_id == func.object_id(full_name),
+ identity_cols.c.name
+ == columns.c.column_name.collate("DATABASE_DEFAULT"),
+ ),
+ isouter=True,
+ )
+
+ if self._supports_nvarchar_max:
+ computed_definition = computed_cols.c.definition
+ else:
+ # tds_version 4.2 does not support NVARCHAR(MAX)
+ computed_definition = sql.cast(
+ computed_cols.c.definition, NVARCHAR(4000)
+ )
+
+ s = (
+ sql.select(
+ columns,
+ computed_definition,
+ computed_cols.c.is_persisted,
+ identity_cols.c.is_identity,
+ identity_cols.c.seed_value,
+ identity_cols.c.increment_value,
+ )
+ .where(whereclause)
+ .select_from(join)
+ .order_by(columns.c.ordinal_position)
+ )
+
+ c = connection.execution_options(future_result=True).execute(s)
+
+ cols = []
+ for row in c.mappings():
+ name = row[columns.c.column_name]
+ type_ = row[columns.c.data_type]
+ nullable = row[columns.c.is_nullable] == "YES"
+ charlen = row[columns.c.character_maximum_length]
+ numericprec = row[columns.c.numeric_precision]
+ numericscale = row[columns.c.numeric_scale]
+ default = row[columns.c.column_default]
+ collation = row[columns.c.collation_name]
+ definition = row[computed_definition]
+ is_persisted = row[computed_cols.c.is_persisted]
+ is_identity = row[identity_cols.c.is_identity]
+ identity_start = row[identity_cols.c.seed_value]
+ identity_increment = row[identity_cols.c.increment_value]
+
+ coltype = self.ischema_names.get(type_, None)
+
+ kwargs = {}
+ if coltype in (
+ MSString,
+ MSChar,
+ MSNVarchar,
+ MSNChar,
+ MSText,
+ MSNText,
+ MSBinary,
+ MSVarBinary,
+ sqltypes.LargeBinary,
+ ):
+ if charlen == -1:
+ charlen = None
+ kwargs["length"] = charlen
+ if collation:
+ kwargs["collation"] = collation
+
+ if coltype is None:
+ util.warn(
+ "Did not recognize type '%s' of column '%s'"
+ % (type_, name)
+ )
+ coltype = sqltypes.NULLTYPE
+ else:
+ if issubclass(coltype, sqltypes.Numeric):
+ kwargs["precision"] = numericprec
+
+ if not issubclass(coltype, sqltypes.Float):
+ kwargs["scale"] = numericscale
+
+ coltype = coltype(**kwargs)
+ cdict = {
+ "name": name,
+ "type": coltype,
+ "nullable": nullable,
+ "default": default,
+ "autoincrement": is_identity is not None,
+ }
+
+ if definition is not None and is_persisted is not None:
+ cdict["computed"] = {
+ "sqltext": definition,
+ "persisted": is_persisted,
+ }
+
+ if is_identity is not None:
+ # identity_start and identity_increment are Decimal or None
+ if identity_start is None or identity_increment is None:
+ cdict["identity"] = {}
+ else:
+ if isinstance(coltype, sqltypes.BigInteger):
+ start = compat.long_type(identity_start)
+ increment = compat.long_type(identity_increment)
+ elif isinstance(coltype, sqltypes.Integer):
+ start = int(identity_start)
+ increment = int(identity_increment)
+ else:
+ start = identity_start
+ increment = identity_increment
+
+ cdict["identity"] = {
+ "start": start,
+ "increment": increment,
+ }
+
+ cols.append(cdict)
+
+ return cols
+
+ @reflection.cache
+ @_db_plus_owner
+ def get_pk_constraint(
+ self, connection, tablename, dbname, owner, schema, **kw
+ ):
+ pkeys = []
+ TC = ischema.constraints
+ C = ischema.key_constraints.alias("C")
+
+ # Primary key constraints
+ s = (
+ sql.select(
+ C.c.column_name, TC.c.constraint_type, C.c.constraint_name
+ )
+ .where(
+ sql.and_(
+ TC.c.constraint_name == C.c.constraint_name,
+ TC.c.table_schema == C.c.table_schema,
+ C.c.table_name == tablename,
+ C.c.table_schema == owner,
+ ),
+ )
+ .order_by(TC.c.constraint_name, C.c.ordinal_position)
+ )
+ c = connection.execution_options(future_result=True).execute(s)
+ constraint_name = None
+ for row in c.mappings():
+ if "PRIMARY" in row[TC.c.constraint_type.name]:
+ pkeys.append(row["COLUMN_NAME"])
+ if constraint_name is None:
+ constraint_name = row[C.c.constraint_name.name]
+ return {"constrained_columns": pkeys, "name": constraint_name}
+
+ @reflection.cache
+ @_db_plus_owner
+ def get_foreign_keys(
+ self, connection, tablename, dbname, owner, schema, **kw
+ ):
+ # Foreign key constraints
+ s = (
+ text(
+ """\
+WITH fk_info AS (
+ SELECT
+ ischema_ref_con.constraint_schema,
+ ischema_ref_con.constraint_name,
+ ischema_key_col.ordinal_position,
+ ischema_key_col.table_schema,
+ ischema_key_col.table_name,
+ ischema_ref_con.unique_constraint_schema,
+ ischema_ref_con.unique_constraint_name,
+ ischema_ref_con.match_option,
+ ischema_ref_con.update_rule,
+ ischema_ref_con.delete_rule,
+ ischema_key_col.column_name AS constrained_column
+ FROM
+ INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS ischema_ref_con
+ INNER JOIN
+ INFORMATION_SCHEMA.KEY_COLUMN_USAGE ischema_key_col ON
+ ischema_key_col.table_schema = ischema_ref_con.constraint_schema
+ AND ischema_key_col.constraint_name =
+ ischema_ref_con.constraint_name
+ WHERE ischema_key_col.table_name = :tablename
+ AND ischema_key_col.table_schema = :owner
+),
+constraint_info AS (
+ SELECT
+ ischema_key_col.constraint_schema,
+ ischema_key_col.constraint_name,
+ ischema_key_col.ordinal_position,
+ ischema_key_col.table_schema,
+ ischema_key_col.table_name,
+ ischema_key_col.column_name
+ FROM
+ INFORMATION_SCHEMA.KEY_COLUMN_USAGE ischema_key_col
+),
+index_info AS (
+ SELECT
+ sys.schemas.name AS index_schema,
+ sys.indexes.name AS index_name,
+ sys.index_columns.key_ordinal AS ordinal_position,
+ sys.schemas.name AS table_schema,
+ sys.objects.name AS table_name,
+ sys.columns.name AS column_name
+ FROM
+ sys.indexes
+ INNER JOIN
+ sys.objects ON
+ sys.objects.object_id = sys.indexes.object_id
+ INNER JOIN
+ sys.schemas ON
+ sys.schemas.schema_id = sys.objects.schema_id
+ INNER JOIN
+ sys.index_columns ON
+ sys.index_columns.object_id = sys.objects.object_id
+ AND sys.index_columns.index_id = sys.indexes.index_id
+ INNER JOIN
+ sys.columns ON
+ sys.columns.object_id = sys.indexes.object_id
+ AND sys.columns.column_id = sys.index_columns.column_id
+)
+ SELECT
+ fk_info.constraint_schema,
+ fk_info.constraint_name,
+ fk_info.ordinal_position,
+ fk_info.constrained_column,
+ constraint_info.table_schema AS referred_table_schema,
+ constraint_info.table_name AS referred_table_name,
+ constraint_info.column_name AS referred_column,
+ fk_info.match_option,
+ fk_info.update_rule,
+ fk_info.delete_rule
+ FROM
+ fk_info INNER JOIN constraint_info ON
+ constraint_info.constraint_schema =
+ fk_info.unique_constraint_schema
+ AND constraint_info.constraint_name =
+ fk_info.unique_constraint_name
+ AND constraint_info.ordinal_position = fk_info.ordinal_position
+ UNION
+ SELECT
+ fk_info.constraint_schema,
+ fk_info.constraint_name,
+ fk_info.ordinal_position,
+ fk_info.constrained_column,
+ index_info.table_schema AS referred_table_schema,
+ index_info.table_name AS referred_table_name,
+ index_info.column_name AS referred_column,
+ fk_info.match_option,
+ fk_info.update_rule,
+ fk_info.delete_rule
+ FROM
+ fk_info INNER JOIN index_info ON
+ index_info.index_schema = fk_info.unique_constraint_schema
+ AND index_info.index_name = fk_info.unique_constraint_name
+ AND index_info.ordinal_position = fk_info.ordinal_position
+
+ ORDER BY fk_info.constraint_schema, fk_info.constraint_name,
+ fk_info.ordinal_position
+"""
+ )
+ .bindparams(
+ sql.bindparam("tablename", tablename, ischema.CoerceUnicode()),
+ sql.bindparam("owner", owner, ischema.CoerceUnicode()),
+ )
+ .columns(
+ constraint_schema=sqltypes.Unicode(),
+ constraint_name=sqltypes.Unicode(),
+ table_schema=sqltypes.Unicode(),
+ table_name=sqltypes.Unicode(),
+ constrained_column=sqltypes.Unicode(),
+ referred_table_schema=sqltypes.Unicode(),
+ referred_table_name=sqltypes.Unicode(),
+ referred_column=sqltypes.Unicode(),
+ )
+ )
+
+ # group rows by constraint ID, to handle multi-column FKs
+ fkeys = []
+
+ def fkey_rec():
+ return {
+ "name": None,
+ "constrained_columns": [],
+ "referred_schema": None,
+ "referred_table": None,
+ "referred_columns": [],
+ "options": {},
+ }
+
+ fkeys = util.defaultdict(fkey_rec)
+
+ for r in connection.execute(s).fetchall():
+ (
+ _, # constraint schema
+ rfknm,
+ _, # ordinal position
+ scol,
+ rschema,
+ rtbl,
+ rcol,
+ # TODO: we support match=<keyword> for foreign keys so
+ # we can support this also, PG has match=FULL for example
+ # but this seems to not be a valid value for SQL Server
+ _, # match rule
+ fkuprule,
+ fkdelrule,
+ ) = r
+
+ rec = fkeys[rfknm]
+ rec["name"] = rfknm
+
+ if fkuprule != "NO ACTION":
+ rec["options"]["onupdate"] = fkuprule
+
+ if fkdelrule != "NO ACTION":
+ rec["options"]["ondelete"] = fkdelrule
+
+ if not rec["referred_table"]:
+ rec["referred_table"] = rtbl
+ if schema is not None or owner != rschema:
+ if dbname:
+ rschema = dbname + "." + rschema
+ rec["referred_schema"] = rschema
+
+ local_cols, remote_cols = (
+ rec["constrained_columns"],
+ rec["referred_columns"],
+ )
+
+ local_cols.append(scol)
+ remote_cols.append(rcol)
+
+ return list(fkeys.values())
diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py
new file mode 100644
index 0000000..df91493
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/information_schema.py
@@ -0,0 +1,232 @@
+# mssql/information_schema.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from ... import cast
+from ... import Column
+from ... import MetaData
+from ... import Table
+from ... import util
+from ...ext.compiler import compiles
+from ...sql import expression
+from ...types import Boolean
+from ...types import Integer
+from ...types import Numeric
+from ...types import String
+from ...types import TypeDecorator
+from ...types import Unicode
+
+
+ischema = MetaData()
+
+
+class CoerceUnicode(TypeDecorator):
+ impl = Unicode
+ cache_ok = True
+
+ def process_bind_param(self, value, dialect):
+ if util.py2k and isinstance(value, util.binary_type):
+ value = value.decode(dialect.encoding)
+ return value
+
+ def bind_expression(self, bindvalue):
+ return _cast_on_2005(bindvalue)
+
+
+class _cast_on_2005(expression.ColumnElement):
+ def __init__(self, bindvalue):
+ self.bindvalue = bindvalue
+
+
+@compiles(_cast_on_2005)
+def _compile(element, compiler, **kw):
+ from . import base
+
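+    # SQL Server versions prior to 2005 (or an unknown version) receive the
+    # plain bind parameter; 2005 and later wrap the value in a CAST to
+    # Unicode (NVARCHAR) so that it is compared as Unicode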
+ if (
+ compiler.dialect.server_version_info is None
+ or compiler.dialect.server_version_info < base.MS_2005_VERSION
+ ):
+ return compiler.process(element.bindvalue, **kw)
+ else:
+ return compiler.process(cast(element.bindvalue, Unicode), **kw)
+
+
+schemata = Table(
+ "SCHEMATA",
+ ischema,
+ Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
+ Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
+ Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
+ schema="INFORMATION_SCHEMA",
+)
+
+tables = Table(
+ "TABLES",
+ ischema,
+ Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("TABLE_TYPE", CoerceUnicode, key="table_type"),
+ schema="INFORMATION_SCHEMA",
+)
+
+columns = Table(
+ "COLUMNS",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("IS_NULLABLE", Integer, key="is_nullable"),
+ Column("DATA_TYPE", String, key="data_type"),
+ Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+ Column(
+ "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
+ ),
+ Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
+ Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
+ Column("COLUMN_DEFAULT", Integer, key="column_default"),
+ Column("COLLATION_NAME", String, key="collation_name"),
+ schema="INFORMATION_SCHEMA",
+)
+
+mssql_temp_table_columns = Table(
+ "COLUMNS",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("IS_NULLABLE", Integer, key="is_nullable"),
+ Column("DATA_TYPE", String, key="data_type"),
+ Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+ Column(
+ "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
+ ),
+ Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
+ Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
+ Column("COLUMN_DEFAULT", Integer, key="column_default"),
+ Column("COLLATION_NAME", String, key="collation_name"),
+ schema="tempdb.INFORMATION_SCHEMA",
+)
+
+constraints = Table(
+ "TABLE_CONSTRAINTS",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ Column("CONSTRAINT_TYPE", CoerceUnicode, key="constraint_type"),
+ schema="INFORMATION_SCHEMA",
+)
+
+column_constraints = Table(
+ "CONSTRAINT_COLUMN_USAGE",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ schema="INFORMATION_SCHEMA",
+)
+
+key_constraints = Table(
+ "KEY_COLUMN_USAGE",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
+ Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+ schema="INFORMATION_SCHEMA",
+)
+
+ref_constraints = Table(
+ "REFERENTIAL_CONSTRAINTS",
+ ischema,
+ Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"),
+ Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ # TODO: is CATLOG misspelled ?
+ Column(
+ "UNIQUE_CONSTRAINT_CATLOG",
+ CoerceUnicode,
+ key="unique_constraint_catalog",
+ ),
+ Column(
+ "UNIQUE_CONSTRAINT_SCHEMA",
+ CoerceUnicode,
+ key="unique_constraint_schema",
+ ),
+ Column(
+ "UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"
+ ),
+ Column("MATCH_OPTION", String, key="match_option"),
+ Column("UPDATE_RULE", String, key="update_rule"),
+ Column("DELETE_RULE", String, key="delete_rule"),
+ schema="INFORMATION_SCHEMA",
+)
+
+views = Table(
+ "VIEWS",
+ ischema,
+ Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
+ Column("CHECK_OPTION", String, key="check_option"),
+ Column("IS_UPDATABLE", String, key="is_updatable"),
+ schema="INFORMATION_SCHEMA",
+)
+
+computed_columns = Table(
+ "computed_columns",
+ ischema,
+ Column("object_id", Integer),
+ Column("name", CoerceUnicode),
+ Column("is_computed", Boolean),
+ Column("is_persisted", Boolean),
+ Column("definition", CoerceUnicode),
+ schema="sys",
+)
+
+sequences = Table(
+ "SEQUENCES",
+ ischema,
+ Column("SEQUENCE_CATALOG", CoerceUnicode, key="sequence_catalog"),
+ Column("SEQUENCE_SCHEMA", CoerceUnicode, key="sequence_schema"),
+ Column("SEQUENCE_NAME", CoerceUnicode, key="sequence_name"),
+ schema="INFORMATION_SCHEMA",
+)
+
+
+class IdentitySqlVariant(TypeDecorator):
+ r"""This type casts sql_variant columns in the identity_columns view
+ to numeric. This is required because:
+
+ * pyodbc does not support sql_variant
+    * pymssql under Python 2 returns the byte representation of the number;
+      e.g. int 1 is returned as "\x01\x00\x00\x00". Under Python 3 it
+      returns the correct value as a string.
+ """
+ impl = Unicode
+ cache_ok = True
+
+ def column_expression(self, colexpr):
+ return cast(colexpr, Numeric)
+
+
+identity_columns = Table(
+ "identity_columns",
+ ischema,
+ Column("object_id", Integer),
+ Column("name", CoerceUnicode),
+ Column("is_identity", Boolean),
+ Column("seed_value", IdentitySqlVariant),
+ Column("increment_value", IdentitySqlVariant),
+ Column("last_value", IdentitySqlVariant),
+ Column("is_not_for_replication", Boolean),
+ schema="sys",
+)
diff --git a/lib/sqlalchemy/dialects/mssql/json.py b/lib/sqlalchemy/dialects/mssql/json.py
new file mode 100644
index 0000000..d515731
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/json.py
@@ -0,0 +1,125 @@
+from ... import types as sqltypes
+
+# technically, all the dialect-specific datatypes that don't have any special
+# behaviors would be private with names like _MSJson. However, we haven't been
+# doing this for mysql.JSON or sqlite.JSON which both have JSON / JSONIndexType
+# / JSONPathType in their json.py files, so keep consistent with that
+# sub-convention for now. A future change can update them all to be
+# package-private at once.
+
+
+class JSON(sqltypes.JSON):
+ """MSSQL JSON type.
+
+ MSSQL supports JSON-formatted data as of SQL Server 2016.
+
+ The :class:`_mssql.JSON` datatype at the DDL level will represent the
+ datatype as ``NVARCHAR(max)``, but provides for JSON-level comparison
+ functions as well as Python coercion behavior.
+
+ :class:`_mssql.JSON` is used automatically whenever the base
+ :class:`_types.JSON` datatype is used against a SQL Server backend.
+
+ .. seealso::
+
+ :class:`_types.JSON` - main documentation for the generic
+ cross-platform JSON datatype.
+
+ The :class:`_mssql.JSON` type supports persistence of JSON values
+ as well as the core index operations provided by :class:`_types.JSON`
+ datatype, by adapting the operations to render the ``JSON_VALUE``
+ or ``JSON_QUERY`` functions at the database level.
+
+ The SQL Server :class:`_mssql.JSON` type necessarily makes use of the
+ ``JSON_QUERY`` and ``JSON_VALUE`` functions when querying for elements
+ of a JSON object. These two functions have a major restriction in that
+ they are **mutually exclusive** based on the type of object to be returned.
+ The ``JSON_QUERY`` function **only** returns a JSON dictionary or list,
+ but not an individual string, numeric, or boolean element; the
+ ``JSON_VALUE`` function **only** returns an individual string, numeric,
+    or boolean element. **Both functions either return NULL or raise
+ an error if they are not used against the correct expected value**.
+
+ To handle this awkward requirement, indexed access rules are as follows:
+
+ 1. When extracting a sub element from a JSON that is itself a JSON
+ dictionary or list, the :meth:`_types.JSON.Comparator.as_json` accessor
+ should be used::
+
+ stmt = select(
+ data_table.c.data["some key"].as_json()
+ ).where(
+ data_table.c.data["some key"].as_json() == {"sub": "structure"}
+ )
+
+ 2. When extracting a sub element from a JSON that is a plain boolean,
+ string, integer, or float, use the appropriate method among
+ :meth:`_types.JSON.Comparator.as_boolean`,
+ :meth:`_types.JSON.Comparator.as_string`,
+ :meth:`_types.JSON.Comparator.as_integer`,
+ :meth:`_types.JSON.Comparator.as_float`::
+
+ stmt = select(
+ data_table.c.data["some key"].as_string()
+ ).where(
+ data_table.c.data["some key"].as_string() == "some string"
+ )
+
+ .. versionadded:: 1.4
+
+
+ """
+
+ # note there was a result processor here that was looking for "number",
+ # but none of the tests seem to exercise it.
+
+
+# Note: these objects currently match exactly those of MySQL, however since
+# these are not generalizable to all JSON implementations, remain separately
+# implemented for each dialect.
+class _FormatTypeMixin(object):
+ def _format_value(self, value):
+ raise NotImplementedError()
+
+ def bind_processor(self, dialect):
+ super_proc = self.string_bind_processor(dialect)
+
+ def process(value):
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+ def literal_processor(self, dialect):
+ super_proc = self.string_literal_processor(dialect)
+
+ def process(value):
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+
+class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
+ def _format_value(self, value):
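+        # e.g. index 5 renders as "$[5]"; key "some key" as '$."some key"'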
+ if isinstance(value, int):
+ value = "$[%s]" % value
+ else:
+ value = '$."%s"' % value
+ return value
+
+
+class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
+ def _format_value(self, value):
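+        # e.g. the path ["key1", 0, "key2"] renders as '$."key1"[0]."key2"'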
+ return "$%s" % (
+ "".join(
+ [
+ "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
+ for elem in value
+ ]
+ )
+ )
diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py
new file mode 100644
index 0000000..95c32d4
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py
@@ -0,0 +1,150 @@
+# mssql/mxodbc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: mssql+mxodbc
+ :name: mxODBC
+ :dbapi: mxodbc
+ :connectstring: mssql+mxodbc://<username>:<password>@<dsnname>
+ :url: https://www.egenix.com/
+
+.. deprecated:: 1.4 The mxODBC DBAPI is deprecated and will be removed
+ in a future version. Please use one of the supported DBAPIs to
+ connect to mssql.
+
+Execution Modes
+---------------
+
+mxODBC features two styles of statement execution, using the
+``cursor.execute()`` and ``cursor.executedirect()`` methods (the second being
+an extension to the DBAPI specification). The former makes use of a particular
+API call specific to the SQL Server Native Client ODBC driver known as
+SQLDescribeParam, while the latter does not.
+
+mxODBC apparently only makes repeated use of a single prepared statement
+when SQLDescribeParam is used. The advantage to prepared statement reuse is
+one of performance. The disadvantage is that SQLDescribeParam has a limited
+set of scenarios in which bind parameters are understood, including that they
+cannot be placed within the argument lists of function calls, anywhere outside
+the FROM, or even within subqueries within the FROM clause - making the usage
+of bind parameters within SELECT statements impossible for all but the most
+simplistic statements.
+
+For this reason, the mxODBC dialect uses the "native" mode by default only for
+INSERT, UPDATE, and DELETE statements, and uses the escaped string mode for
+all other statements.
+
+This behavior can be controlled via
+:meth:`~sqlalchemy.sql.expression.Executable.execution_options` using the
+``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a
+value of ``True`` will unconditionally use native bind parameters and a value
+of ``False`` will unconditionally use string-escaped parameters.
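+
+For example, the option can be applied per-statement as in the following
+sketch, where the table and column names are illustrative only::
+
+    # "my_table" and its column "x" are placeholders for illustration
+    stmt = my_table.insert().execution_options(native_odbc_execute=True)
+    connection.execute(stmt, {"x": 5})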
+
+"""
+
+
+from .base import _MSDate
+from .base import _MSDateTime
+from .base import _MSTime
+from .base import MSDialect
+from .base import VARBINARY
+from .pyodbc import _MSNumeric_pyodbc
+from .pyodbc import MSExecutionContext_pyodbc
+from ... import types as sqltypes
+from ...connectors.mxodbc import MxODBCConnector
+
+
+class _MSNumeric_mxodbc(_MSNumeric_pyodbc):
+ """Include pyodbc's numeric processor."""
+
+
+class _MSDate_mxodbc(_MSDate):
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is not None:
+ return "%s-%s-%s" % (value.year, value.month, value.day)
+ else:
+ return None
+
+ return process
+
+
+class _MSTime_mxodbc(_MSTime):
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is not None:
+ return "%s:%s:%s" % (value.hour, value.minute, value.second)
+ else:
+ return None
+
+ return process
+
+
+class _VARBINARY_mxodbc(VARBINARY):
+
+ """
+ mxODBC Support for VARBINARY column types.
+
+    This handles the special case of null VARBINARY values,
+    mapping None to the mx.ODBC.Manager.BinaryNull symbol.
+ """
+
+ def bind_processor(self, dialect):
+ if dialect.dbapi is None:
+ return None
+
+ DBAPIBinary = dialect.dbapi.Binary
+
+ def process(value):
+ if value is not None:
+ return DBAPIBinary(value)
+ else:
+ # should pull from mx.ODBC.Manager.BinaryNull
+ return dialect.dbapi.BinaryNull
+
+ return process
+
+
+class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
+ """
+ The pyodbc execution context is useful for enabling
+    SELECT SCOPE_IDENTITY in cases where the OUTPUT clause
+    does not work (tables with insert triggers).
+ """
+
+ # todo - investigate whether the pyodbc execution context
+ # is really only being used in cases where OUTPUT
+ # won't work.
+
+
+class MSDialect_mxodbc(MxODBCConnector, MSDialect):
+
+ # this is only needed if "native ODBC" mode is used,
+ # which is now disabled by default.
+ # statement_compiler = MSSQLStrictCompiler
+ supports_statement_cache = True
+
+ execution_ctx_cls = MSExecutionContext_mxodbc
+
+ # flag used by _MSNumeric_mxodbc
+ _need_decimal_fix = True
+
+ colspecs = {
+ sqltypes.Numeric: _MSNumeric_mxodbc,
+ sqltypes.DateTime: _MSDateTime,
+ sqltypes.Date: _MSDate_mxodbc,
+ sqltypes.Time: _MSTime_mxodbc,
+ VARBINARY: _VARBINARY_mxodbc,
+ sqltypes.LargeBinary: _VARBINARY_mxodbc,
+ }
+
+ def __init__(self, description_encoding=None, **params):
+ super(MSDialect_mxodbc, self).__init__(**params)
+ self.description_encoding = description_encoding
+
+
+dialect = MSDialect_mxodbc
diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py
new file mode 100644
index 0000000..56f3305
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/provision.py
@@ -0,0 +1,116 @@
+from sqlalchemy import inspect
+from sqlalchemy import Integer
+from ... import create_engine
+from ... import exc
+from ...schema import Column
+from ...schema import DropConstraint
+from ...schema import ForeignKeyConstraint
+from ...schema import MetaData
+from ...schema import Table
+from ...testing.provision import create_db
+from ...testing.provision import drop_all_schema_objects_pre_tables
+from ...testing.provision import drop_db
+from ...testing.provision import get_temp_table_name
+from ...testing.provision import log
+from ...testing.provision import run_reap_dbs
+from ...testing.provision import temp_table_keyword_args
+
+
+@create_db.for_db("mssql")
+def _mssql_create_db(cfg, eng, ident):
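+    # CREATE DATABASE and the ALTER DATABASE statements below cannot run
+    # inside an explicit transaction, hence the AUTOCOMMIT isolation level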
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+ conn.exec_driver_sql("create database %s" % ident)
+ conn.exec_driver_sql(
+ "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident
+ )
+ conn.exec_driver_sql(
+ "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident
+ )
+ conn.exec_driver_sql("use %s" % ident)
+ conn.exec_driver_sql("create schema test_schema")
+ conn.exec_driver_sql("create schema test_schema_2")
+
+
+@drop_db.for_db("mssql")
+def _mssql_drop_db(cfg, eng, ident):
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+ _mssql_drop_ignore(conn, ident)
+
+
+def _mssql_drop_ignore(conn, ident):
+ try:
+ # typically when this happens, we can't KILL the session anyway,
+ # so let the cleanup process drop the DBs
+ # for row in conn.exec_driver_sql(
+ # "select session_id from sys.dm_exec_sessions "
+ # "where database_id=db_id('%s')" % ident):
+ # log.info("killing SQL server session %s", row['session_id'])
+ # conn.exec_driver_sql("kill %s" % row['session_id'])
+ conn.exec_driver_sql("drop database %s" % ident)
+ log.info("Reaped db: %s", ident)
+ return True
+ except exc.DatabaseError as err:
+ log.warning("couldn't drop db: %s", err)
+ return False
+
+
+@run_reap_dbs.for_db("mssql")
+def _reap_mssql_dbs(url, idents):
+ log.info("db reaper connecting to %r", url)
+ eng = create_engine(url)
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+
+ log.info("identifiers in file: %s", ", ".join(idents))
+
+ to_reap = conn.exec_driver_sql(
+ "select d.name from sys.databases as d where name "
+ "like 'TEST_%' and not exists (select session_id "
+ "from sys.dm_exec_sessions "
+ "where database_id=d.database_id)"
+ )
+ all_names = {dbname.lower() for (dbname,) in to_reap}
+ to_drop = set()
+ for name in all_names:
+ if name in idents:
+ to_drop.add(name)
+
+ dropped = total = 0
+ for total, dbname in enumerate(to_drop, 1):
+ if _mssql_drop_ignore(conn, dbname):
+ dropped += 1
+ log.info(
+ "Dropped %d out of %d stale databases detected", dropped, total
+ )
+
+
+@temp_table_keyword_args.for_db("mssql")
+def _mssql_temp_table_keyword_args(cfg, eng):
+ return {}
+
+
+@get_temp_table_name.for_db("mssql")
+def _mssql_get_temp_table_name(cfg, eng, base_name):
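+    # the "##" prefix denotes a global temporary table in SQL Server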
+ return "##" + base_name
+
+
+@drop_all_schema_objects_pre_tables.for_db("mssql")
+def drop_all_schema_objects_pre_tables(cfg, eng):
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+ inspector = inspect(conn)
+ for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2):
+ for tname in inspector.get_table_names(schema=schema):
+ tb = Table(
+ tname,
+ MetaData(),
+ Column("x", Integer),
+ Column("y", Integer),
+ schema=schema,
+ )
+ for fk in inspect(conn).get_foreign_keys(tname, schema=schema):
+ conn.execute(
+ DropConstraint(
+ ForeignKeyConstraint(
+ [tb.c.x], [tb.c.y], name=fk["name"]
+ )
+ )
+ )
diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py
new file mode 100644
index 0000000..84c5fed
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/pymssql.py
@@ -0,0 +1,138 @@
+# mssql/pymssql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: mssql+pymssql
+ :name: pymssql
+ :dbapi: pymssql
+ :connectstring: mssql+pymssql://<username>:<password>@<freetds_name>/?charset=utf8
+
+pymssql is a Python module that provides a Python DBAPI interface around
+`FreeTDS <https://www.freetds.org/>`_.
+
+.. note::
+
+ pymssql is currently not included in SQLAlchemy's continuous integration
+ (CI) testing.
+
+Modern versions of this driver worked very well with SQL Server and FreeTDS
+from Linux and were highly recommended. However, pymssql is currently
+unmaintained and has fallen behind the progress of the Microsoft ODBC driver in
+its support for newer features of SQL Server. The latest official release of
+pymssql at the time of this document is version 2.1.4 (August, 2018) and it
+lacks support for:
+
+1. table-valued parameters (TVPs),
+2. ``datetimeoffset`` columns using timezone-aware ``datetime`` objects
+ (values are sent and retrieved as strings), and
+3. encrypted connections (e.g., to Azure SQL), when pymssql is installed from
+ the pre-built wheels. Support for encrypted connections requires building
+ pymssql from source, which can be a nuisance, especially under Windows.
+
+The above features are all supported by mssql+pyodbc when using Microsoft's
+ODBC Driver for SQL Server (msodbcsql), which is now available for Windows,
+(several flavors of) Linux, and macOS.
+
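+For reference, a typical :func:`_sa.create_engine` call might look like the
+following, where the host and database names are illustrative only::
+
+    # "freetds_host" and "mydb" are placeholders for illustration
+    engine = create_engine(
+        "mssql+pymssql://scott:tiger@freetds_host:1433/mydb?charset=utf8"
+    )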
+
+""" # noqa
+import re
+
+from .base import MSDialect
+from .base import MSIdentifierPreparer
+from ... import processors
+from ... import types as sqltypes
+from ... import util
+
+
+class _MSNumeric_pymssql(sqltypes.Numeric):
+ def result_processor(self, dialect, type_):
+ if not self.asdecimal:
+ return processors.to_float
+ else:
+ return sqltypes.Numeric.result_processor(self, dialect, type_)
+
+
+class MSIdentifierPreparer_pymssql(MSIdentifierPreparer):
+ def __init__(self, dialect):
+ super(MSIdentifierPreparer_pymssql, self).__init__(dialect)
+ # pymssql has the very unusual behavior that it uses pyformat
+ # yet does not require that percent signs be doubled
+ self._double_percents = False
+
+
+class MSDialect_pymssql(MSDialect):
+ supports_statement_cache = True
+ supports_native_decimal = True
+ driver = "pymssql"
+
+ preparer = MSIdentifierPreparer_pymssql
+
+ colspecs = util.update_copy(
+ MSDialect.colspecs,
+ {sqltypes.Numeric: _MSNumeric_pymssql, sqltypes.Float: sqltypes.Float},
+ )
+
+ @classmethod
+ def dbapi(cls):
+ module = __import__("pymssql")
+        # pymssql < 2.1.1 doesn't have a Binary method; we use string
+ client_ver = tuple(int(x) for x in module.__version__.split("."))
+ if client_ver < (2, 1, 1):
+ # TODO: monkeypatching here is less than ideal
+ module.Binary = lambda x: x if hasattr(x, "decode") else str(x)
+
+ if client_ver < (1,):
+ util.warn(
+ "The pymssql dialect expects at least "
+ "the 1.0 series of the pymssql DBAPI."
+ )
+ return module
+
+ def _get_server_version_info(self, connection):
+ vers = connection.exec_driver_sql("select @@version").scalar()
+ m = re.match(r"Microsoft .*? - (\d+)\.(\d+)\.(\d+)\.(\d+)", vers)
+ if m:
+ return tuple(int(x) for x in m.group(1, 2, 3, 4))
+ else:
+ return None
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+ opts.update(url.query)
+ port = opts.pop("port", None)
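+        # pymssql expects the port as part of the host string,
+        # e.g. "somehost:1433", rather than as a separate keyword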
+ if port and "host" in opts:
+ opts["host"] = "%s:%s" % (opts["host"], port)
+ return [[], opts]
+
+ def is_disconnect(self, e, connection, cursor):
+ for msg in (
+ "Adaptive Server connection timed out",
+ "Net-Lib error during Connection reset by peer",
+ "message 20003", # connection timeout
+ "Error 10054",
+ "Not connected to any MS SQL server",
+ "Connection is closed",
+ "message 20006", # Write to the server failed
+ "message 20017", # Unexpected EOF from the server
+ "message 20047", # DBPROCESS is dead or not enabled
+ ):
+ if msg in str(e):
+ return True
+ else:
+ return False
+
+ def set_isolation_level(self, connection, level):
+ if level == "AUTOCOMMIT":
+ connection.autocommit(True)
+ else:
+ connection.autocommit(False)
+ super(MSDialect_pymssql, self).set_isolation_level(
+ connection, level
+ )
+
+
+dialect = MSDialect_pymssql
diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py
new file mode 100644
index 0000000..edb76f2
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py
@@ -0,0 +1,673 @@
+# mssql/pyodbc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: mssql+pyodbc
+ :name: PyODBC
+ :dbapi: pyodbc
+ :connectstring: mssql+pyodbc://<username>:<password>@<dsnname>
+ :url: https://pypi.org/project/pyodbc/
+
+Connecting to PyODBC
+--------------------
+
+The URL here is to be translated to PyODBC connection strings, as
+detailed in `ConnectionStrings <https://code.google.com/p/pyodbc/wiki/ConnectionStrings>`_.
+
+DSN Connections
+^^^^^^^^^^^^^^^
+
+A DSN connection in ODBC means that a pre-existing ODBC datasource is
+configured on the client machine. The application then specifies the name
+of this datasource, which encompasses details such as the specific ODBC driver
+in use as well as the network address of the database. Assuming a datasource
+is configured on the client, a basic DSN-based connection looks like::
+
+ engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn")
+
+The above URL will pass the following connection string to PyODBC::
+
+ DSN=some_dsn;UID=scott;PWD=tiger
+
+If the username and password are omitted, the DSN form will also add
+the ``Trusted_Connection=yes`` directive to the ODBC string.
+
+Hostname Connections
+^^^^^^^^^^^^^^^^^^^^
+
+Hostname-based connections are also supported by pyodbc. These are often
+easier to use than a DSN and have the additional advantage that the specific
+database name to connect towards may be specified locally in the URL, rather
+than it being fixed as part of a datasource configuration.
+
+When using a hostname connection, the driver name must also be specified in the
+query parameters of the URL. As these names usually have spaces in them, the
+name must be URL encoded which means using plus signs for spaces::
+
+ engine = create_engine("mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server")
+
+Other keywords interpreted by the Pyodbc dialect to be passed to
+``pyodbc.connect()`` in both the DSN and hostname cases include:
+``odbc_autotranslate``, ``ansi``, ``unicode_results``, ``autocommit``,
+``authentication``.
+Note that in order for the dialect to recognize these keywords
+(including the ``driver`` keyword above) they must be all lowercase.
+Multiple additional keyword arguments must be separated by an
+ampersand (``&``), not a semicolon::
+
+ engine = create_engine(
+ "mssql+pyodbc://scott:tiger@myhost:49242/databasename"
+ "?driver=ODBC+Driver+17+for+SQL+Server"
+ "&authentication=ActiveDirectoryIntegrated"
+ )
+
+The equivalent URL can be constructed using :class:`_sa.engine.URL`::
+
+ from sqlalchemy.engine import URL
+ connection_url = URL.create(
+ "mssql+pyodbc",
+ username="scott",
+ password="tiger",
+ host="myhost",
+ port=49242,
+ database="databasename",
+ query={
+ "driver": "ODBC Driver 17 for SQL Server",
+ "authentication": "ActiveDirectoryIntegrated",
+ },
+ )
+
+
+Pass through exact Pyodbc string
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+A PyODBC connection string can also be sent in pyodbc's format directly, as
+specified in `the PyODBC documentation
+<https://github.com/mkleehammer/pyodbc/wiki/Connecting-to-databases>`_,
+using the parameter ``odbc_connect``. A :class:`_sa.engine.URL` object
+can help make this easier::
+
+ from sqlalchemy.engine import URL
+ connection_string = "DRIVER={SQL Server Native Client 10.0};SERVER=dagger;DATABASE=test;UID=user;PWD=password"
+ connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": connection_string})
+
+ engine = create_engine(connection_url)
+
+.. _mssql_pyodbc_access_tokens:
+
+Connecting to databases with access tokens
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Some database servers are set up to only accept access tokens for login. For
+example, SQL Server allows the use of Azure Active Directory tokens to connect
+to databases. This requires creating a credential object using the
+``azure-identity`` library. More information about the authentication step can be
+found in `Microsoft's documentation
+<https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=bash>`_.
+
+After getting an engine, the credentials need to be sent to ``pyodbc.connect``
+each time a connection is requested. One way to do this is to set up an event
+listener on the engine that adds the credential token to the dialect's connect
+call. This is discussed more generally in :ref:`engines_dynamic_tokens`. For
+SQL Server in particular, this is passed as an ODBC connection attribute with
+a data structure `described by Microsoft
+<https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory#authenticating-with-an-access-token>`_.
+
+The following code snippet will create an engine that connects to an Azure SQL
+database using Azure credentials::
+
+ import struct
+ from sqlalchemy import create_engine, event
+ from sqlalchemy.engine.url import URL
+ from azure import identity
+
+ SQL_COPT_SS_ACCESS_TOKEN = 1256 # Connection option for access tokens, as defined in msodbcsql.h
+ TOKEN_URL = "https://database.windows.net/" # The token URL for any Azure SQL database
+
+ connection_string = "mssql+pyodbc://@my-server.database.windows.net/myDb?driver=ODBC+Driver+17+for+SQL+Server"
+
+ engine = create_engine(connection_string)
+
+ azure_credentials = identity.DefaultAzureCredential()
+
+ @event.listens_for(engine, "do_connect")
+ def provide_token(dialect, conn_rec, cargs, cparams):
+ # remove the "Trusted_Connection" parameter that SQLAlchemy adds
+ cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")
+
+ # create token credential
+ raw_token = azure_credentials.get_token(TOKEN_URL).token.encode("utf-16-le")
+ token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token)
+
+ # apply it to keyword arguments
+ cparams["attrs_before"] = {SQL_COPT_SS_ACCESS_TOKEN: token_struct}
+
+.. tip::
+
+ The ``Trusted_Connection`` token is currently added by the SQLAlchemy
+ pyodbc dialect when no username or password is present. This needs
+ to be removed per Microsoft's
+ `documentation for Azure access tokens
+ <https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory#authenticating-with-an-access-token>`_,
+ stating that a connection string when using an access token must not contain
+ ``UID``, ``PWD``, ``Authentication`` or ``Trusted_Connection`` parameters.
+
+.. _azure_synapse_ignore_no_transaction_on_rollback:
+
+Avoiding transaction-related exceptions on Azure Synapse Analytics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Azure Synapse Analytics has a significant difference in its transaction
+handling compared to plain SQL Server; in some cases an error within a Synapse
+transaction can cause it to be arbitrarily terminated on the server side, which
+then causes the DBAPI ``.rollback()`` method (as well as ``.commit()``) to
+fail. The issue prevents the usual DBAPI contract of allowing ``.rollback()``
+to pass silently if no transaction is present as the driver does not expect
+this condition. The symptom of this failure is an exception with a message
+resembling 'No corresponding transaction found. (111214)' when attempting to
+emit a ``.rollback()`` after an operation had a failure of some kind.
+
+This specific case can be handled by passing ``ignore_no_transaction_on_rollback=True`` to
+the SQL Server dialect via the :func:`_sa.create_engine` function as follows::
+
+ engine = create_engine(connection_url, ignore_no_transaction_on_rollback=True)
+
+Using the above parameter, the dialect will catch ``ProgrammingError``
+exceptions raised during ``connection.rollback()`` and emit a warning
+if the error message contains code ``111214``, but will not raise
+an exception.
+
+.. versionadded:: 1.4.40 Added the
+ ``ignore_no_transaction_on_rollback=True`` parameter.
+
+Enable autocommit for Azure SQL Data Warehouse (DW) connections
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Azure SQL Data Warehouse does not support transactions,
+and that can cause problems with SQLAlchemy's "autobegin" (and implicit
+commit/rollback) behavior. We can avoid these problems by enabling autocommit
+at both the pyodbc and engine levels::
+
+ connection_url = sa.engine.URL.create(
+ "mssql+pyodbc",
+ username="scott",
+ password="tiger",
+ host="dw.azure.example.com",
+ database="mydb",
+ query={
+ "driver": "ODBC Driver 17 for SQL Server",
+ "autocommit": "True",
+ },
+ )
+
+ engine = create_engine(connection_url).execution_options(
+ isolation_level="AUTOCOMMIT"
+ )
+
+Avoiding sending large string parameters as TEXT/NTEXT
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+By default, for historical reasons, Microsoft's ODBC drivers for SQL Server
+send long string parameters (greater than 4000 SBCS characters or 2000 Unicode
+characters) as TEXT/NTEXT values. TEXT and NTEXT have been deprecated for many
+years and are starting to cause compatibility issues with newer versions of
+SQL Server / Azure. For example, see `this
+issue <https://github.com/mkleehammer/pyodbc/issues/835>`_.
+
+Starting with ODBC Driver 18 for SQL Server we can override the legacy
+behavior and pass long strings as varchar(max)/nvarchar(max) using the
+``LongAsMax=Yes`` connection string parameter::
+
+ connection_url = sa.engine.URL.create(
+ "mssql+pyodbc",
+ username="scott",
+ password="tiger",
+ host="mssqlserver.example.com",
+ database="mydb",
+ query={
+ "driver": "ODBC Driver 18 for SQL Server",
+ "LongAsMax": "Yes",
+ },
+ )
+
+
+Pyodbc Pooling / connection close behavior
+------------------------------------------
+
+PyODBC uses internal `pooling
+<https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#pooling>`_ by
+default, which means connections will be longer lived than they are within
+SQLAlchemy itself. As SQLAlchemy has its own pooling behavior, it is often
+preferable to disable this behavior. This behavior can only be disabled
+globally at the PyODBC module level, **before** any connections are made::
+
+ import pyodbc
+
+ pyodbc.pooling = False
+
+ # don't use the engine before pooling is set to False
+ engine = create_engine("mssql+pyodbc://user:pass@dsn")
+
+If this variable is left at its default value of ``True``, **the application
+will continue to maintain active database connections**, even when the
+SQLAlchemy engine itself fully discards a connection or if the engine is
+disposed.
+
+.. seealso::
+
+ `pooling <https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#pooling>`_ -
+ in the PyODBC documentation.
+
+Driver / Unicode Support
+-------------------------
+
+PyODBC works best with Microsoft ODBC drivers, particularly in the area
+of Unicode support on both Python 2 and Python 3.
+
+Using the FreeTDS ODBC drivers on Linux or OSX with PyODBC is **not**
+recommended; there have been historically many Unicode-related issues
+in this area, including before Microsoft offered ODBC drivers for Linux
+and OSX. Now that Microsoft offers drivers for all platforms, these are
+recommended for PyODBC support. FreeTDS remains relevant for
+non-ODBC drivers such as pymssql, where it works very well.
+
+
+Rowcount Support
+----------------
+
+Pyodbc only has partial support for rowcount. See the notes at
+:ref:`mssql_rowcount_versioning` for important notes when using ORM
+versioning.
+
+.. _mssql_pyodbc_fastexecutemany:
+
+Fast Executemany Mode
+---------------------
+
+The Pyodbc driver has added support for a "fast executemany" mode of execution
+which greatly reduces round trips for a DBAPI ``executemany()`` call when using
+Microsoft ODBC drivers, for **limited size batches that fit in memory**. The
+feature is enabled by setting the flag ``.fast_executemany`` on the DBAPI
+cursor when an executemany call is to be used. The SQLAlchemy pyodbc SQL
+Server dialect supports setting this flag automatically when the
+``.fast_executemany`` flag is passed to
+:func:`_sa.create_engine` ; note that the ODBC driver must be the Microsoft
+driver in order to use this flag::
+
+ engine = create_engine(
+ "mssql+pyodbc://scott:tiger@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server",
+ fast_executemany=True)
+
+.. warning:: The pyodbc fast_executemany mode **buffers all rows in memory** and is
+ not compatible with very large batches of data. A future version of SQLAlchemy
+ may support this flag as a per-execution option instead.
+
+.. versionadded:: 1.3
+
+.. seealso::
+
+ `fast executemany <https://github.com/mkleehammer/pyodbc/wiki/Features-beyond-the-DB-API#fast_executemany>`_
+ - on github
+
+.. _mssql_pyodbc_setinputsizes:
+
+Setinputsizes Support
+-----------------------
+
+The pyodbc ``cursor.setinputsizes()`` method can be used if necessary. To
+enable this hook, pass ``use_setinputsizes=True`` to :func:`_sa.create_engine`::
+
+ engine = create_engine("mssql+pyodbc://...", use_setinputsizes=True)
+
+The behavior of the hook can then be customized, as may be necessary
+particularly if fast_executemany is in use, via the
+:meth:`.DialectEvents.do_setinputsizes` hook. See that method for usage
+examples.
+
+.. versionchanged:: 1.4.1 The pyodbc dialects will not use setinputsizes
+ unless ``use_setinputsizes=True`` is passed.
+
+""" # noqa
+
+
+import datetime
+import decimal
+import re
+import struct
+
+from .base import BINARY
+from .base import DATETIMEOFFSET
+from .base import MSDialect
+from .base import MSExecutionContext
+from .base import VARBINARY
+from ... import exc
+from ... import types as sqltypes
+from ... import util
+from ...connectors.pyodbc import PyODBCConnector
+
+
+class _ms_numeric_pyodbc(object):
+
+ """Turns Decimals with adjusted() < 0 or > 7 into strings.
+
+ The routines here are needed for older pyodbc versions
+ as well as current mxODBC versions.
+
+ """
+
+ def bind_processor(self, dialect):
+
+ super_process = super(_ms_numeric_pyodbc, self).bind_processor(dialect)
+
+ if not dialect._need_decimal_fix:
+ return super_process
+
+ def process(value):
+ if self.asdecimal and isinstance(value, decimal.Decimal):
+ adjusted = value.adjusted()
+ if adjusted < 0:
+ return self._small_dec_to_string(value)
+ elif adjusted > 7:
+ return self._large_dec_to_string(value)
+
+ if super_process:
+ return super_process(value)
+ else:
+ return value
+
+ return process
+
+    # these routines are needed for older versions of pyodbc.
+ # as of 2.1.8 this logic is integrated.
+
+ def _small_dec_to_string(self, value):
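+        # e.g. Decimal("1E-4") -> "0.0001"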
+ return "%s0.%s%s" % (
+ (value < 0 and "-" or ""),
+ "0" * (abs(value.adjusted()) - 1),
+ "".join([str(nint) for nint in value.as_tuple()[1]]),
+ )
+
+ def _large_dec_to_string(self, value):
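+        # e.g. Decimal("1E+10") -> "10000000000"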
+ _int = value.as_tuple()[1]
+ if "E" in str(value):
+ result = "%s%s%s" % (
+ (value < 0 and "-" or ""),
+ "".join([str(s) for s in _int]),
+ "0" * (value.adjusted() - (len(_int) - 1)),
+ )
+ else:
+ if (len(_int) - 1) > value.adjusted():
+ result = "%s%s.%s" % (
+ (value < 0 and "-" or ""),
+ "".join([str(s) for s in _int][0 : value.adjusted() + 1]),
+ "".join([str(s) for s in _int][value.adjusted() + 1 :]),
+ )
+ else:
+ result = "%s%s" % (
+ (value < 0 and "-" or ""),
+ "".join([str(s) for s in _int][0 : value.adjusted() + 1]),
+ )
+ return result
+
+
+class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric):
+ pass
+
+
+class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float):
+ pass
+
+
+class _ms_binary_pyodbc(object):
+ """Wraps binary values in dialect-specific Binary wrapper.
+ If the value is null, return a pyodbc-specific BinaryNull
+ object to prevent pyODBC [and FreeTDS] from defaulting binary
+ NULL types to SQLWCHAR and causing implicit conversion errors.
+ """
+
+ def bind_processor(self, dialect):
+ if dialect.dbapi is None:
+ return None
+
+ DBAPIBinary = dialect.dbapi.Binary
+
+ def process(value):
+ if value is not None:
+ return DBAPIBinary(value)
+ else:
+ # pyodbc-specific
+ return dialect.dbapi.BinaryNull
+
+ return process
+
+
+class _ODBCDateTimeBindProcessor(object):
+ """Add bind processors to handle datetimeoffset behaviors"""
+
+ has_tz = False
+
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is None:
+ return None
+ elif isinstance(value, util.string_types):
+ # if a string was passed directly, allow it through
+ return value
+ elif not value.tzinfo or (not self.timezone and not self.has_tz):
+ # for DateTime(timezone=False)
+ return value
+ else:
+ # for DATETIMEOFFSET or DateTime(timezone=True)
+ #
+ # Convert to string format required by T-SQL
+ dto_string = value.strftime("%Y-%m-%d %H:%M:%S.%f %z")
+ # offset needs a colon, e.g., -0700 -> -07:00
+ # "UTC offset in the form (+-)HHMM[SS[.ffffff]]"
+ # backend currently rejects seconds / fractional seconds
+ dto_string = re.sub(
+ r"([\+\-]\d{2})([\d\.]+)$", r"\1:\2", dto_string
+ )
+ return dto_string
+
+ return process
+
+
+class _ODBCDateTime(_ODBCDateTimeBindProcessor, sqltypes.DateTime):
+ pass
+
+
+class _ODBCDATETIMEOFFSET(_ODBCDateTimeBindProcessor, DATETIMEOFFSET):
+ has_tz = True
+
+
+class _VARBINARY_pyodbc(_ms_binary_pyodbc, VARBINARY):
+ pass
+
+
+class _BINARY_pyodbc(_ms_binary_pyodbc, BINARY):
+ pass
+
+
+class MSExecutionContext_pyodbc(MSExecutionContext):
+ _embedded_scope_identity = False
+
+ def pre_exec(self):
+ """where appropriate, issue "select scope_identity()" in the same
+ statement.
+
+ Background on why "scope_identity()" is preferable to "@@identity":
+ https://msdn.microsoft.com/en-us/library/ms190315.aspx
+
+ Background on why we attempt to embed "scope_identity()" into the same
+ statement as the INSERT:
+ https://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values?
+
+ """
+
+ super(MSExecutionContext_pyodbc, self).pre_exec()
+
+ # don't embed the scope_identity select into an
+ # "INSERT .. DEFAULT VALUES"
+ if (
+ self._select_lastrowid
+ and self.dialect.use_scope_identity
+ and len(self.parameters[0])
+ ):
+ self._embedded_scope_identity = True
+
+ self.statement += "; select scope_identity()"
+
+ def post_exec(self):
+ if self._embedded_scope_identity:
+ # Fetch the last inserted id from the manipulated statement
+ # We may have to skip over a number of result sets with
+ # no data (due to triggers, etc.)
+ while True:
+ try:
+ # fetchall() ensures the cursor is consumed
+ # without closing it (FreeTDS particularly)
+ row = self.cursor.fetchall()[0]
+ break
+ except self.dialect.dbapi.Error:
+ # no way around this - nextset() consumes the previous set
+ # so we need to just keep flipping
+ self.cursor.nextset()
+
+ self._lastrowid = int(row[0])
+ else:
+ super(MSExecutionContext_pyodbc, self).post_exec()
+
+
+class MSDialect_pyodbc(PyODBCConnector, MSDialect):
+ supports_statement_cache = True
+
+ # mssql still has problems with this on Linux
+ supports_sane_rowcount_returning = False
+
+ execution_ctx_cls = MSExecutionContext_pyodbc
+
+ colspecs = util.update_copy(
+ MSDialect.colspecs,
+ {
+ sqltypes.Numeric: _MSNumeric_pyodbc,
+ sqltypes.Float: _MSFloat_pyodbc,
+ BINARY: _BINARY_pyodbc,
+ # support DateTime(timezone=True)
+ sqltypes.DateTime: _ODBCDateTime,
+ DATETIMEOFFSET: _ODBCDATETIMEOFFSET,
+ # SQL Server dialect has a VARBINARY that is just to support
+ # "deprecate_large_types" w/ VARBINARY(max), but also we must
+ # handle the usual SQL standard VARBINARY
+ VARBINARY: _VARBINARY_pyodbc,
+ sqltypes.VARBINARY: _VARBINARY_pyodbc,
+ sqltypes.LargeBinary: _VARBINARY_pyodbc,
+ },
+ )
+
+ def __init__(
+ self, description_encoding=None, fast_executemany=False, **params
+ ):
+ if "description_encoding" in params:
+ self.description_encoding = params.pop("description_encoding")
+ super(MSDialect_pyodbc, self).__init__(**params)
+ self.use_scope_identity = (
+ self.use_scope_identity
+ and self.dbapi
+ and hasattr(self.dbapi.Cursor, "nextset")
+ )
+ self._need_decimal_fix = self.dbapi and self._dbapi_version() < (
+ 2,
+ 1,
+ 8,
+ )
+ self.fast_executemany = fast_executemany
+
+ def _get_server_version_info(self, connection):
+ try:
+ # "Version of the instance of SQL Server, in the form
+ # of 'major.minor.build.revision'"
+ raw = connection.exec_driver_sql(
+ "SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)"
+ ).scalar()
+ except exc.DBAPIError:
+ # SQL Server docs indicate this function isn't present prior to
+ # 2008. Before we had the VARCHAR cast above, pyodbc would also
+ # fail on this query.
+ return super(MSDialect_pyodbc, self)._get_server_version_info(
+ connection, allow_chars=False
+ )
+ else:
+ version = []
+ r = re.compile(r"[.\-]")
+ for n in r.split(raw):
+ try:
+ version.append(int(n))
+ except ValueError:
+ pass
+ return tuple(version)
+
+ def on_connect(self):
+ super_ = super(MSDialect_pyodbc, self).on_connect()
+
+ def on_connect(conn):
+ if super_ is not None:
+ super_(conn)
+
+ self._setup_timestampoffset_type(conn)
+
+ return on_connect
+
+ def _setup_timestampoffset_type(self, connection):
+ # output converter function for datetimeoffset
+ def _handle_datetimeoffset(dto_value):
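+            # the driver delivers DATETIMEOFFSET as a packed struct: six
+            # int16s (year, month, day, hour, minute, second), a uint32
+            # fraction in nanoseconds, and two int16s for the timezone
+            # offset hours and minutes; // 1000 converts the fraction to
+            # microseconds for datetime()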
+ tup = struct.unpack("<6hI2h", dto_value)
+ return datetime.datetime(
+ tup[0],
+ tup[1],
+ tup[2],
+ tup[3],
+ tup[4],
+ tup[5],
+ tup[6] // 1000,
+ util.timezone(
+ datetime.timedelta(hours=tup[7], minutes=tup[8])
+ ),
+ )
+
+ odbc_SQL_SS_TIMESTAMPOFFSET = -155 # as defined in SQLNCLI.h
+ connection.add_output_converter(
+ odbc_SQL_SS_TIMESTAMPOFFSET, _handle_datetimeoffset
+ )
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ if self.fast_executemany:
+ cursor.fast_executemany = True
+ super(MSDialect_pyodbc, self).do_executemany(
+ cursor, statement, parameters, context=context
+ )
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.Error):
+ code = e.args[0]
+ if code in {
+ "08S01",
+ "01000",
+ "01002",
+ "08003",
+ "08007",
+ "08S02",
+ "08001",
+ "HYT00",
+ "HY010",
+ "10054",
+ }:
+ return True
+ return super(MSDialect_pyodbc, self).is_disconnect(
+ e, connection, cursor
+ )
+
+
+dialect = MSDialect_pyodbc
diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py
new file mode 100644
index 0000000..04c83d1
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/__init__.py
@@ -0,0 +1,103 @@
+# mysql/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import base # noqa
+from . import cymysql # noqa
+from . import mariadbconnector # noqa
+from . import mysqlconnector # noqa
+from . import mysqldb # noqa
+from . import oursql # noqa
+from . import pymysql # noqa
+from . import pyodbc # noqa
+from .base import BIGINT
+from .base import BINARY
+from .base import BIT
+from .base import BLOB
+from .base import BOOLEAN
+from .base import CHAR
+from .base import DATE
+from .base import DATETIME
+from .base import DECIMAL
+from .base import DOUBLE
+from .base import ENUM
+from .base import FLOAT
+from .base import INTEGER
+from .base import JSON
+from .base import LONGBLOB
+from .base import LONGTEXT
+from .base import MEDIUMBLOB
+from .base import MEDIUMINT
+from .base import MEDIUMTEXT
+from .base import NCHAR
+from .base import NUMERIC
+from .base import NVARCHAR
+from .base import REAL
+from .base import SET
+from .base import SMALLINT
+from .base import TEXT
+from .base import TIME
+from .base import TIMESTAMP
+from .base import TINYBLOB
+from .base import TINYINT
+from .base import TINYTEXT
+from .base import VARBINARY
+from .base import VARCHAR
+from .base import YEAR
+from .dml import Insert
+from .dml import insert
+from .expression import match
+from ...util import compat
+
+if compat.py3k:
+ from . import aiomysql # noqa
+ from . import asyncmy # noqa
+
+# default dialect
+base.dialect = dialect = mysqldb.dialect
+
+__all__ = (
+ "BIGINT",
+ "BINARY",
+ "BIT",
+ "BLOB",
+ "BOOLEAN",
+ "CHAR",
+ "DATE",
+ "DATETIME",
+ "DECIMAL",
+ "DOUBLE",
+ "ENUM",
+ "DECIMAL",
+ "FLOAT",
+ "INTEGER",
+ "INTEGER",
+ "JSON",
+ "LONGBLOB",
+ "LONGTEXT",
+ "MEDIUMBLOB",
+ "MEDIUMINT",
+ "MEDIUMTEXT",
+ "NCHAR",
+ "NVARCHAR",
+ "NUMERIC",
+ "SET",
+ "SMALLINT",
+ "REAL",
+ "TEXT",
+ "TIME",
+ "TIMESTAMP",
+ "TINYBLOB",
+ "TINYINT",
+ "TINYTEXT",
+ "VARBINARY",
+ "VARCHAR",
+ "YEAR",
+ "dialect",
+ "insert",
+ "Insert",
+ "match",
+)
diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py
new file mode 100644
index 0000000..975467c
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py
@@ -0,0 +1,317 @@
+# mysql/aiomysql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS
+# file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: mysql+aiomysql
+ :name: aiomysql
+ :dbapi: aiomysql
+ :connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://github.com/aio-libs/aiomysql
+
+.. warning:: The aiomysql dialect is not currently tested as part of
+ SQLAlchemy’s continuous integration. As of September, 2021 the driver
+ appears to be unmaintained and no longer functions for Python version 3.10,
+ and additionally depends on a significantly outdated version of PyMySQL.
+ Please refer to the :ref:`asyncmy` dialect for current MySQL/MariaDB asyncio
+ functionality.
+
+The aiomysql dialect is SQLAlchemy's second Python asyncio dialect.
+
+Using a special asyncio mediation layer, the aiomysql dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4")
+
+
+""" # noqa
+
+from .pymysql import MySQLDialect_pymysql
+from ... import pool
+from ... import util
+from ...engine import AdaptedConnection
+from ...util.concurrency import asyncio
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+class AsyncAdapt_aiomysql_cursor:
+ server_side = False
+ __slots__ = (
+ "_adapt_connection",
+ "_connection",
+ "await_",
+ "_cursor",
+ "_rows",
+ )
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+
+ cursor = self._connection.cursor()
+
+ # see https://github.com/aio-libs/aiomysql/issues/543
+ self._cursor = self.await_(cursor.__aenter__())
+ self._rows = []
+
+ @property
+ def description(self):
+ return self._cursor.description
+
+ @property
+ def rowcount(self):
+ return self._cursor.rowcount
+
+ @property
+ def arraysize(self):
+ return self._cursor.arraysize
+
+ @arraysize.setter
+ def arraysize(self, value):
+ self._cursor.arraysize = value
+
+ @property
+ def lastrowid(self):
+ return self._cursor.lastrowid
+
+ def close(self):
+ # note we aren't actually closing the cursor here,
+ # we are just letting GC do it. to allow this to be async
+ # we would need the Result to change how it does "Safe close cursor".
+ # MySQL "cursors" don't actually have state to be "closed" besides
+ # exhausting rows, which we already have done for sync cursor.
+ # another option would be to emulate aiosqlite dialect and assign
+ # cursor only if we are doing server side cursor operation.
+ self._rows[:] = []
+
+ def execute(self, operation, parameters=None):
+ return self.await_(self._execute_async(operation, parameters))
+
+ def executemany(self, operation, seq_of_parameters):
+ return self.await_(
+ self._executemany_async(operation, seq_of_parameters)
+ )
+
+ async def _execute_async(self, operation, parameters):
+ async with self._adapt_connection._execute_mutex:
+ if parameters is None:
+ result = await self._cursor.execute(operation)
+ else:
+ result = await self._cursor.execute(operation, parameters)
+
+ if not self.server_side:
+ # aiomysql has a "fake" async result, so we have to pull it out
+ # of that here since our default result is not async.
+ # we could just as easily grab "_rows" here and be done with it
+ # but this is safer.
+ self._rows = list(await self._cursor.fetchall())
+ return result
+
+ async def _executemany_async(self, operation, seq_of_parameters):
+ async with self._adapt_connection._execute_mutex:
+ return await self._cursor.executemany(operation, seq_of_parameters)
+
+ def setinputsizes(self, *inputsizes):
+ pass
+
+ def __iter__(self):
+ while self._rows:
+ yield self._rows.pop(0)
+
+ def fetchone(self):
+ if self._rows:
+ return self._rows.pop(0)
+ else:
+ return None
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+
+ retval = self._rows[0:size]
+ self._rows[:] = self._rows[size:]
+ return retval
+
+ def fetchall(self):
+ retval = self._rows[:]
+ self._rows[:] = []
+ return retval
+
+
+class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
+ __slots__ = ()
+ server_side = True
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+
+ cursor = self._connection.cursor(
+ adapt_connection.dbapi.aiomysql.SSCursor
+ )
+
+ self._cursor = self.await_(cursor.__aenter__())
+
+ def close(self):
+ if self._cursor is not None:
+ self.await_(self._cursor.close())
+ self._cursor = None
+
+ def fetchone(self):
+ return self.await_(self._cursor.fetchone())
+
+ def fetchmany(self, size=None):
+ return self.await_(self._cursor.fetchmany(size=size))
+
+ def fetchall(self):
+ return self.await_(self._cursor.fetchall())
+
+
+class AsyncAdapt_aiomysql_connection(AdaptedConnection):
+ await_ = staticmethod(await_only)
+ __slots__ = ("dbapi", "_connection", "_execute_mutex")
+
+ def __init__(self, dbapi, connection):
+ self.dbapi = dbapi
+ self._connection = connection
+ self._execute_mutex = asyncio.Lock()
+
+ def ping(self, reconnect):
+ return self.await_(self._connection.ping(reconnect))
+
+ def character_set_name(self):
+ return self._connection.character_set_name()
+
+ def autocommit(self, value):
+ self.await_(self._connection.autocommit(value))
+
+ def cursor(self, server_side=False):
+ if server_side:
+ return AsyncAdapt_aiomysql_ss_cursor(self)
+ else:
+ return AsyncAdapt_aiomysql_cursor(self)
+
+ def rollback(self):
+ self.await_(self._connection.rollback())
+
+ def commit(self):
+ self.await_(self._connection.commit())
+
+ def close(self):
+ # it's not awaitable.
+ self._connection.close()
+
+
+class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
+ __slots__ = ()
+
+ await_ = staticmethod(await_fallback)
+
+
+class AsyncAdapt_aiomysql_dbapi:
+ def __init__(self, aiomysql, pymysql):
+ self.aiomysql = aiomysql
+ self.pymysql = pymysql
+ self.paramstyle = "format"
+ self._init_dbapi_attributes()
+
+ def _init_dbapi_attributes(self):
+ for name in (
+ "Warning",
+ "Error",
+ "InterfaceError",
+ "DataError",
+ "DatabaseError",
+ "OperationalError",
+ "InterfaceError",
+ "IntegrityError",
+ "ProgrammingError",
+ "InternalError",
+ "NotSupportedError",
+ ):
+ setattr(self, name, getattr(self.aiomysql, name))
+
+ for name in (
+ "NUMBER",
+ "STRING",
+ "DATETIME",
+ "BINARY",
+ "TIMESTAMP",
+ "Binary",
+ ):
+ setattr(self, name, getattr(self.pymysql, name))
+
+ def connect(self, *arg, **kw):
+ async_fallback = kw.pop("async_fallback", False)
+
+ if util.asbool(async_fallback):
+ return AsyncAdaptFallback_aiomysql_connection(
+ self,
+ await_fallback(self.aiomysql.connect(*arg, **kw)),
+ )
+ else:
+ return AsyncAdapt_aiomysql_connection(
+ self,
+ await_only(self.aiomysql.connect(*arg, **kw)),
+ )
+
+
+class MySQLDialect_aiomysql(MySQLDialect_pymysql):
+ driver = "aiomysql"
+ supports_statement_cache = True
+
+ supports_server_side_cursors = True
+ _sscursor = AsyncAdapt_aiomysql_ss_cursor
+
+ is_async = True
+
+ @classmethod
+ def dbapi(cls):
+ return AsyncAdapt_aiomysql_dbapi(
+ __import__("aiomysql"), __import__("pymysql")
+ )
+
+ @classmethod
+ def get_pool_class(cls, url):
+
+ async_fallback = url.query.get("async_fallback", False)
+
+ if util.asbool(async_fallback):
+ return pool.FallbackAsyncAdaptedQueuePool
+ else:
+ return pool.AsyncAdaptedQueuePool
+
+ def create_connect_args(self, url):
+ return super(MySQLDialect_aiomysql, self).create_connect_args(
+ url, _translate_args=dict(username="user", database="db")
+ )
+
+ def is_disconnect(self, e, connection, cursor):
+ if super(MySQLDialect_aiomysql, self).is_disconnect(
+ e, connection, cursor
+ ):
+ return True
+ else:
+ str_e = str(e).lower()
+ return "not connected" in str_e
+
+ def _found_rows_client_flag(self):
+ from pymysql.constants import CLIENT
+
+ return CLIENT.FOUND_ROWS
+
+ def get_driver_connection(self, connection):
+ return connection._connection
+
+
+dialect = MySQLDialect_aiomysql
diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py
new file mode 100644
index 0000000..521918a
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py
@@ -0,0 +1,328 @@
+# mysql/asyncmy.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS
+# file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: mysql+asyncmy
+ :name: asyncmy
+ :dbapi: asyncmy
+ :connectstring: mysql+asyncmy://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://github.com/long2ice/asyncmy
+
+.. note:: The asyncmy dialect was added in September 2021 to provide
+    MySQL/MariaDB asyncio compatibility, given that the :ref:`aiomysql` database
+    driver has become unmaintained; however, asyncmy is itself very new.
+
+Using a special asyncio mediation layer, the asyncmy dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4")
+
+
+""" # noqa
+
+from .pymysql import MySQLDialect_pymysql
+from ... import pool
+from ... import util
+from ...engine import AdaptedConnection
+from ...util.concurrency import asynccontextmanager
+from ...util.concurrency import asyncio
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+class AsyncAdapt_asyncmy_cursor:
+ server_side = False
+ __slots__ = (
+ "_adapt_connection",
+ "_connection",
+ "await_",
+ "_cursor",
+ "_rows",
+ )
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+
+ cursor = self._connection.cursor()
+
+ self._cursor = self.await_(cursor.__aenter__())
+ self._rows = []
+
+ @property
+ def description(self):
+ return self._cursor.description
+
+ @property
+ def rowcount(self):
+ return self._cursor.rowcount
+
+ @property
+ def arraysize(self):
+ return self._cursor.arraysize
+
+ @arraysize.setter
+ def arraysize(self, value):
+ self._cursor.arraysize = value
+
+ @property
+ def lastrowid(self):
+ return self._cursor.lastrowid
+
+ def close(self):
+ # note we aren't actually closing the cursor here,
+ # we are just letting GC do it. to allow this to be async
+ # we would need the Result to change how it does "Safe close cursor".
+ # MySQL "cursors" don't actually have state to be "closed" besides
+ # exhausting rows, which we already have done for sync cursor.
+ # another option would be to emulate aiosqlite dialect and assign
+ # cursor only if we are doing server side cursor operation.
+ self._rows[:] = []
+
+ def execute(self, operation, parameters=None):
+ return self.await_(self._execute_async(operation, parameters))
+
+ def executemany(self, operation, seq_of_parameters):
+ return self.await_(
+ self._executemany_async(operation, seq_of_parameters)
+ )
+
+ async def _execute_async(self, operation, parameters):
+ async with self._adapt_connection._mutex_and_adapt_errors():
+ if parameters is None:
+ result = await self._cursor.execute(operation)
+ else:
+ result = await self._cursor.execute(operation, parameters)
+
+ if not self.server_side:
+ # asyncmy has a "fake" async result, so we have to pull it out
+ # of that here since our default result is not async.
+ # we could just as easily grab "_rows" here and be done with it
+ # but this is safer.
+ self._rows = list(await self._cursor.fetchall())
+ return result
+
+ async def _executemany_async(self, operation, seq_of_parameters):
+ async with self._adapt_connection._mutex_and_adapt_errors():
+ return await self._cursor.executemany(operation, seq_of_parameters)
+
+ def setinputsizes(self, *inputsizes):
+ pass
+
+ def __iter__(self):
+ while self._rows:
+ yield self._rows.pop(0)
+
+ def fetchone(self):
+ if self._rows:
+ return self._rows.pop(0)
+ else:
+ return None
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+
+ retval = self._rows[0:size]
+ self._rows[:] = self._rows[size:]
+ return retval
+
+ def fetchall(self):
+ retval = self._rows[:]
+ self._rows[:] = []
+ return retval
+
+
+class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor):
+ __slots__ = ()
+ server_side = True
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+
+ cursor = self._connection.cursor(
+ adapt_connection.dbapi.asyncmy.cursors.SSCursor
+ )
+
+ self._cursor = self.await_(cursor.__aenter__())
+
+ def close(self):
+ if self._cursor is not None:
+ self.await_(self._cursor.close())
+ self._cursor = None
+
+ def fetchone(self):
+ return self.await_(self._cursor.fetchone())
+
+ def fetchmany(self, size=None):
+ return self.await_(self._cursor.fetchmany(size=size))
+
+ def fetchall(self):
+ return self.await_(self._cursor.fetchall())
+
+
+class AsyncAdapt_asyncmy_connection(AdaptedConnection):
+ await_ = staticmethod(await_only)
+ __slots__ = ("dbapi", "_connection", "_execute_mutex")
+
+ def __init__(self, dbapi, connection):
+ self.dbapi = dbapi
+ self._connection = connection
+ self._execute_mutex = asyncio.Lock()
+
+ @asynccontextmanager
+ async def _mutex_and_adapt_errors(self):
+ async with self._execute_mutex:
+ try:
+ yield
+ except AttributeError:
+ raise self.dbapi.InternalError(
+ "network operation failed due to asyncmy attribute error"
+ )
+
+ def ping(self, reconnect):
+ assert not reconnect
+ return self.await_(self._do_ping())
+
+ async def _do_ping(self):
+ async with self._mutex_and_adapt_errors():
+ return await self._connection.ping(False)
+
+ def character_set_name(self):
+ return self._connection.character_set_name()
+
+ def autocommit(self, value):
+ self.await_(self._connection.autocommit(value))
+
+ def cursor(self, server_side=False):
+ if server_side:
+ return AsyncAdapt_asyncmy_ss_cursor(self)
+ else:
+ return AsyncAdapt_asyncmy_cursor(self)
+
+ def rollback(self):
+ self.await_(self._connection.rollback())
+
+ def commit(self):
+ self.await_(self._connection.commit())
+
+ def close(self):
+ # it's not awaitable.
+ self._connection.close()
+
+
+class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
+ __slots__ = ()
+
+ await_ = staticmethod(await_fallback)
+
+
+def _Binary(x):
+ """Return x as a binary type."""
+ return bytes(x)
+
+
+class AsyncAdapt_asyncmy_dbapi:
+ def __init__(self, asyncmy):
+ self.asyncmy = asyncmy
+ self.paramstyle = "format"
+ self._init_dbapi_attributes()
+
+ def _init_dbapi_attributes(self):
+ for name in (
+ "Warning",
+ "Error",
+ "InterfaceError",
+ "DataError",
+ "DatabaseError",
+ "OperationalError",
+ "InterfaceError",
+ "IntegrityError",
+ "ProgrammingError",
+ "InternalError",
+ "NotSupportedError",
+ ):
+ setattr(self, name, getattr(self.asyncmy.errors, name))
+
+ STRING = util.symbol("STRING")
+ NUMBER = util.symbol("NUMBER")
+ BINARY = util.symbol("BINARY")
+ DATETIME = util.symbol("DATETIME")
+ TIMESTAMP = util.symbol("TIMESTAMP")
+ Binary = staticmethod(_Binary)
+
+ def connect(self, *arg, **kw):
+ async_fallback = kw.pop("async_fallback", False)
+
+ if util.asbool(async_fallback):
+ return AsyncAdaptFallback_asyncmy_connection(
+ self,
+ await_fallback(self.asyncmy.connect(*arg, **kw)),
+ )
+ else:
+ return AsyncAdapt_asyncmy_connection(
+ self,
+ await_only(self.asyncmy.connect(*arg, **kw)),
+ )
+
+
+class MySQLDialect_asyncmy(MySQLDialect_pymysql):
+ driver = "asyncmy"
+ supports_statement_cache = True
+
+ supports_server_side_cursors = True
+ _sscursor = AsyncAdapt_asyncmy_ss_cursor
+
+ is_async = True
+
+ @classmethod
+ def dbapi(cls):
+ return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))
+
+ @classmethod
+ def get_pool_class(cls, url):
+
+ async_fallback = url.query.get("async_fallback", False)
+
+ if util.asbool(async_fallback):
+ return pool.FallbackAsyncAdaptedQueuePool
+ else:
+ return pool.AsyncAdaptedQueuePool
+
+ def create_connect_args(self, url):
+ return super(MySQLDialect_asyncmy, self).create_connect_args(
+ url, _translate_args=dict(username="user", database="db")
+ )
+
+ def is_disconnect(self, e, connection, cursor):
+ if super(MySQLDialect_asyncmy, self).is_disconnect(
+ e, connection, cursor
+ ):
+ return True
+ else:
+ str_e = str(e).lower()
+ return (
+ "not connected" in str_e or "network operation failed" in str_e
+ )
+
+ def _found_rows_client_flag(self):
+ from asyncmy.constants import CLIENT
+
+ return CLIENT.FOUND_ROWS
+
+ def get_driver_connection(self, connection):
+ return connection._connection
+
+
+dialect = MySQLDialect_asyncmy
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
new file mode 100644
index 0000000..111c63b
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -0,0 +1,3306 @@
+# mysql/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+
+.. dialect:: mysql
+ :name: MySQL / MariaDB
+ :full_support: 5.6, 5.7, 8.0 / 10.4, 10.5
+ :normal_support: 5.6+ / 10+
+ :best_effort: 5.0.2+ / 5.0.2+
+
+Supported Versions and Features
+-------------------------------
+
+SQLAlchemy supports MySQL starting with version 5.0.2 through modern releases,
+as well as all modern versions of MariaDB. See the official MySQL
+documentation for detailed information about features supported in any given
+server release.
+
+.. versionchanged:: 1.4 minimum MySQL version supported is now 5.0.2.
+
+MariaDB Support
+~~~~~~~~~~~~~~~
+
+The MariaDB variant of MySQL retains fundamental compatibility with MySQL's
+protocols; however, the development of these two products continues to diverge.
+Within the realm of SQLAlchemy, the two databases have a small number of
+syntactical and behavioral differences that SQLAlchemy accommodates automatically.
+To connect to a MariaDB database, no changes to the database URL are required::
+
+
+ engine = create_engine("mysql+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4")
+
+Upon first connect, the SQLAlchemy dialect employs a
+server version detection scheme that determines if the
+backing database reports as MariaDB. Based on this flag, the dialect
+can make different choices in those areas where its behavior
+must be different.
+
+.. _mysql_mariadb_only_mode:
+
+MariaDB-Only Mode
+~~~~~~~~~~~~~~~~~
+
+The dialect also supports an **optional** "MariaDB-only" mode of connection, which may be
+useful for the case where an application makes use of MariaDB-specific features
+and is not compatible with a MySQL database. To use this mode of operation,
+replace the "mysql" token in the above URL with "mariadb"::
+
+ engine = create_engine("mariadb+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4")
+
+The above engine, upon first connect, will raise an error if the server version
+detection detects that the backing database is not MariaDB.
+
+When using an engine with ``"mariadb"`` as the dialect name, **all mysql-specific options
+that include the name "mysql" in them are now named with "mariadb"**. This means
+options like ``mysql_engine`` should be named ``mariadb_engine``, etc. Both
+"mysql" and "mariadb" options can be used simultaneously for applications that
+use URLs with both "mysql" and "mariadb" dialects::
+
+ my_table = Table(
+ "mytable",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("textdata", String(50)),
+ mariadb_engine="InnoDB",
+ mysql_engine="InnoDB",
+ )
+
+ Index(
+ "textdata_ix",
+ my_table.c.textdata,
+ mysql_prefix="FULLTEXT",
+ mariadb_prefix="FULLTEXT",
+ )
+
+Similar behavior will occur when the above structures are reflected, i.e. the
+"mariadb" prefix will be present in the option names when the database URL
+is based on the "mariadb" name.
+
+.. versionadded:: 1.4 Added "mariadb" dialect name supporting "MariaDB-only mode"
+ for the MySQL dialect.
+
+.. _mysql_connection_timeouts:
+
+Connection Timeouts and Disconnects
+-----------------------------------
+
+MySQL / MariaDB feature an automatic connection close behavior, for connections that
+have been idle for a fixed period of time, defaulting to eight hours.
+To avoid this issue, use
+the :paramref:`_sa.create_engine.pool_recycle` option which ensures that
+a connection will be discarded and replaced with a new one if it has been
+present in the pool for a fixed number of seconds::
+
+ engine = create_engine('mysql+mysqldb://...', pool_recycle=3600)
+
+For more comprehensive disconnect detection of pooled connections, including
+accommodation of server restarts and network issues, a pre-ping approach may
+be employed. See :ref:`pool_disconnects` for current approaches.
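+
+For example, the "pre-ping" approach may be enabled directly on the
+:class:`_engine.Engine` using the ``pool_pre_ping`` flag (a minimal sketch)::
+
+    engine = create_engine("mysql+mysqldb://...", pool_pre_ping=True)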
+
+.. seealso::
+
+ :ref:`pool_disconnects` - Background on several techniques for dealing
+ with timed out connections as well as database restarts.
+
+.. _mysql_storage_engines:
+
+CREATE TABLE arguments including Storage Engines
+------------------------------------------------
+
+Both MySQL's and MariaDB's CREATE TABLE syntax includes a wide array of special options,
+including ``ENGINE``, ``CHARSET``, ``MAX_ROWS``, ``ROW_FORMAT``,
+``INSERT_METHOD``, and many more.
+To accommodate the rendering of these arguments, specify the form
+``mysql_argument_name="value"``. For example, to specify a table with
+``ENGINE`` of ``InnoDB``, ``CHARSET`` of ``utf8mb4``, and ``KEY_BLOCK_SIZE``
+of ``1024``::
+
+ Table('mytable', metadata,
+ Column('data', String(32)),
+ mysql_engine='InnoDB',
+ mysql_charset='utf8mb4',
+ mysql_key_block_size="1024"
+ )
+
+When supporting :ref:`mysql_mariadb_only_mode` mode, similar keys against
+the "mariadb" prefix must be included as well. The values can of course
+vary independently so that different settings on MySQL vs. MariaDB may
+be maintained::
+
+    # support both "mysql" and "mariadb-only" engine URLs
+
+    Table('mytable', metadata,
+          Column('data', String(32)),
+
+          mysql_engine='InnoDB',
+          mariadb_engine='InnoDB',
+
+          mysql_charset='utf8mb4',
+          mariadb_charset='utf8',
+
+          mysql_key_block_size="1024",
+          mariadb_key_block_size="1024"
+
+    )
+
+The MySQL / MariaDB dialects will normally transfer any keyword specified as
+``mysql_keyword_name`` to be rendered as ``KEYWORD_NAME`` in the
+``CREATE TABLE`` statement. A handful of these names will render with a space
+instead of an underscore; to support this, the MySQL dialect has awareness of
+these particular names, which include ``DATA DIRECTORY``
+(e.g. ``mysql_data_directory``), ``CHARACTER SET`` (e.g.
+``mysql_character_set``) and ``INDEX DIRECTORY`` (e.g.
+``mysql_index_directory``).
+
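+As an illustration, a table using one of these space-rendered options might be
+declared as follows (the directory path shown is purely hypothetical)::
+
+    Table('mytable', metadata,
+          Column('data', String(32)),
+          mysql_engine='InnoDB',
+          mysql_data_directory="/var/lib/mysql-data"
+    )
+
+The ``mysql_data_directory`` keyword above is emitted as the ``DATA DIRECTORY``
+option, with a space in its name, within the ``CREATE TABLE`` statement.
+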
+The most common argument is ``mysql_engine``, which refers to the storage
+engine for the table. Historically, MySQL server installations would default
+to ``MyISAM`` for this value, although newer versions default
+to ``InnoDB``. The ``InnoDB`` engine is typically preferred for its support
+of transactions and foreign keys.
+
+A :class:`_schema.Table`
+that is created in a MySQL / MariaDB database with a storage engine
+of ``MyISAM`` will be essentially non-transactional, meaning any
+INSERT/UPDATE/DELETE statement referring to this table will be invoked as
+autocommit. It also will have no support for foreign key constraints; while
+the ``CREATE TABLE`` statement accepts foreign key options, when using the
+``MyISAM`` storage engine these arguments are discarded. Reflecting such a
+table will also produce no foreign key constraint information.
+
+For fully atomic transactions as well as support for foreign key
+constraints, all participating ``CREATE TABLE`` statements must specify a
+transactional engine, which in the vast majority of cases is ``InnoDB``.
+
+
+Case Sensitivity and Table Reflection
+-------------------------------------
+
+Both MySQL and MariaDB have inconsistent support for case-sensitive identifier
+names, basing support on specific details of the underlying
+operating system. However, it has been observed that no matter
+what case sensitivity behavior is present, the names of tables in
+foreign key declarations are *always* received from the database
+as all-lower case, making it impossible to accurately reflect a
+schema where inter-related tables use mixed-case identifier names.
+
+Therefore it is strongly advised that table names be declared as
+all lower case both within SQLAlchemy as well as on the MySQL / MariaDB
+database itself, especially if database reflection features are
+to be used.
+
+.. _mysql_isolation_level:
+
+Transaction Isolation Level
+---------------------------
+
+All MySQL / MariaDB dialects support setting of transaction isolation level both via a
+dialect-specific parameter :paramref:`_sa.create_engine.isolation_level`
+accepted
+by :func:`_sa.create_engine`, as well as the
+:paramref:`.Connection.execution_options.isolation_level` argument as passed to
+:meth:`_engine.Connection.execution_options`.
+This feature works by issuing the
+command ``SET SESSION TRANSACTION ISOLATION LEVEL <level>`` for each new
+connection. For the special AUTOCOMMIT isolation level, DBAPI-specific
+techniques are used.
+
+To set isolation level using :func:`_sa.create_engine`::
+
+ engine = create_engine(
+ "mysql://scott:tiger@localhost/test",
+ isolation_level="READ UNCOMMITTED"
+ )
+
+To set using per-connection execution options::
+
+ connection = engine.connect()
+ connection = connection.execution_options(
+ isolation_level="READ COMMITTED"
+ )
+
+Valid values for ``isolation_level`` include:
+
+* ``READ COMMITTED``
+* ``READ UNCOMMITTED``
+* ``REPEATABLE READ``
+* ``SERIALIZABLE``
+* ``AUTOCOMMIT``
+
+The special ``AUTOCOMMIT`` value makes use of the various "autocommit"
+attributes provided by specific DBAPIs, and is currently supported by
+MySQLdb, MySQL-Client, MySQL-Connector Python, and PyMySQL. Using it,
+the database connection will return true for the value of
+``SELECT @@autocommit;``.
+
+There are also more options for isolation level configurations, such as
+"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply
+different isolation level settings. See the discussion at
+:ref:`dbapi_autocommit` for background.
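+
+A brief sketch of such a "sub-engine", created from a main
+:class:`_engine.Engine` via :meth:`_engine.Engine.execution_options`::
+
+    eng = create_engine("mysql://scott:tiger@localhost/test")
+    autocommit_engine = eng.execution_options(isolation_level="AUTOCOMMIT")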
+
+.. seealso::
+
+ :ref:`dbapi_autocommit`
+
+AUTO_INCREMENT Behavior
+-----------------------
+
+When creating tables, SQLAlchemy will automatically set ``AUTO_INCREMENT`` on
+the first :class:`.Integer` primary key column which is not marked as a
+foreign key::
+
+    >>> t = Table('mytable', metadata,
+    ...   Column('mytable_id', Integer, primary_key=True)
+    ... )
+    >>> t.create()
+    CREATE TABLE mytable (
+            mytable_id INTEGER NOT NULL AUTO_INCREMENT,
+            PRIMARY KEY (mytable_id)
+    )
+
+You can disable this behavior by passing ``False`` to the
+:paramref:`_schema.Column.autoincrement` argument of :class:`_schema.Column`.
+This flag
+can also be used to enable auto-increment on a secondary column in a
+multi-column key for some storage engines::
+
+ Table('mytable', metadata,
+ Column('gid', Integer, primary_key=True, autoincrement=False),
+ Column('id', Integer, primary_key=True)
+ )
+
+.. _mysql_ss_cursors:
+
+Server Side Cursors
+-------------------
+
+Server-side cursor support is available for the mysqlclient, PyMySQL,
+mariadbconnector dialects and may also be available in others. This makes use
+of either the "buffered=True/False" flag if available or by using a class such
+as ``MySQLdb.cursors.SSCursor`` or ``pymysql.cursors.SSCursor`` internally.
+
+
+Server side cursors are enabled on a per-statement basis by using the
+:paramref:`.Connection.execution_options.stream_results` connection execution
+option::
+
+ with engine.connect() as conn:
+ result = conn.execution_options(stream_results=True).execute(text("select * from table"))
+
+Note that some kinds of SQL statements may not be supported with
+server side cursors; generally, only SQL statements that return rows should be
+used with this option.
+
+.. deprecated:: 1.4 The dialect-level server_side_cursors flag is deprecated
+ and will be removed in a future release. Please use the
+ :paramref:`_engine.Connection.stream_results` execution option for
+ unbuffered cursor support.
+
+.. seealso::
+
+ :ref:`engine_stream_results`
+
+.. _mysql_unicode:
+
+Unicode
+-------
+
+Charset Selection
+~~~~~~~~~~~~~~~~~
+
+Most MySQL / MariaDB DBAPIs offer the option to set the client character set for
+a connection. This is typically delivered using the ``charset`` parameter
+in the URL, such as::
+
+ e = create_engine(
+ "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4")
+
+This charset is the **client character set** for the connection. Some
+MySQL DBAPIs will default this to a value such as ``latin1``, and some
+will make use of the ``default-character-set`` setting in the ``my.cnf``
+file as well. Documentation for the DBAPI in use should be consulted
+for specific behavior.
+
+The encoding used for Unicode has traditionally been ``'utf8'``. However, for
+MySQL versions 5.5.3 and MariaDB 5.5 on forward, a new MySQL-specific encoding
+``'utf8mb4'`` has been introduced, and as of MySQL 8.0 a warning is emitted by
+the server if plain ``utf8`` is specified within any server-side directives,
+replaced with ``utf8mb3``. The rationale for this new encoding is due to the
+fact that MySQL's legacy utf-8 encoding only supports codepoints up to three
+bytes instead of four. Therefore, when communicating with a MySQL or MariaDB
+database that includes codepoints more than three bytes in size, this new
+charset is preferred, if supported by both the database as well as the client
+DBAPI, as in::
+
+ e = create_engine(
+ "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4")
+
+All modern DBAPIs should support the ``utf8mb4`` charset.
+
+In order to use ``utf8mb4`` encoding for a schema that was created with legacy
+``utf8``, changes to the MySQL/MariaDB schema and/or server configuration may be
+required.
+
+.. seealso::
+
+ `The utf8mb4 Character Set \
+ <https://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html>`_ - \
+ in the MySQL documentation
+
+.. _mysql_binary_introducer:
+
+Dealing with Binary Data Warnings and Unicode
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+MySQL versions 5.6, 5.7 and later (not MariaDB at the time of this writing) now
+emit a warning when attempting to pass binary data to the database, while a
+character set encoding is also in place, when the binary data itself is not
+valid for that encoding::
+
+ default.py:509: Warning: (1300, "Invalid utf8mb4 character string:
+ 'F9876A'")
+ cursor.execute(statement, parameters)
+
+This warning is due to the fact that the MySQL client library is attempting to
+interpret the binary string as a unicode object even if a datatype such
+as :class:`.LargeBinary` is in use. To resolve this, the SQL statement requires
+a binary "character set introducer" be present before any non-NULL value
+that renders like this::
+
+ INSERT INTO table (data) VALUES (_binary %s)
+
+These character set introducers are provided by the DBAPI driver, assuming the
+use of mysqlclient or PyMySQL (both of which are recommended). Add the query
+string parameter ``binary_prefix=true`` to the URL to repair this warning::
+
+ # mysqlclient
+ engine = create_engine(
+ "mysql+mysqldb://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true")
+
+ # PyMySQL
+ engine = create_engine(
+ "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true")
+
+
+The ``binary_prefix`` flag may or may not be supported by other MySQL drivers.
+
+SQLAlchemy itself cannot render this ``_binary`` prefix reliably, as it does
+not work with the NULL value, which is valid to be sent as a bound parameter.
+As the MySQL driver renders parameters directly into the SQL string, it's the
+most efficient place for this additional keyword to be passed.
+
+.. seealso::
+
+ `Character set introducers <https://dev.mysql.com/doc/refman/5.7/en/charset-introducer.html>`_ - on the MySQL website
+
+
+ANSI Quoting Style
+------------------
+
+MySQL / MariaDB feature two varieties of identifier "quoting style", one using
+backticks and the other using quotes, e.g. ```some_identifier``` vs.
+``"some_identifier"``. All MySQL dialects detect which version
+is in use by checking the value of :ref:`sql_mode<mysql_sql_mode>` when a connection is first
+established with a particular :class:`_engine.Engine`.
+This quoting style comes
+into play when rendering table and column names as well as when reflecting
+existing database structures. The detection is entirely automatic and
+no special configuration is needed to use either quoting style.
+
+
+.. _mysql_sql_mode:
+
+Changing the sql_mode
+---------------------
+
+MySQL supports operating in multiple
+`Server SQL Modes <https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html>`_ for
+both Servers and Clients. To change the ``sql_mode`` for a given application, a
+developer can leverage SQLAlchemy's Events system.
+
+In the following example, the event system is used to set the ``sql_mode`` on
+the ``connect`` event::
+
+ from sqlalchemy import create_engine, event
+
+ eng = create_engine("mysql://scott:tiger@localhost/test", echo='debug')
+
+ # `insert=True` will ensure this is the very first listener to run
+ @event.listens_for(eng, "connect", insert=True)
+ def connect(dbapi_connection, connection_record):
+ cursor = dbapi_connection.cursor()
+ cursor.execute("SET sql_mode = 'STRICT_ALL_TABLES'")
+
+ conn = eng.connect()
+
+In the example illustrated above, the "connect" event will invoke the "SET"
+statement on the connection at the moment a particular DBAPI connection is
+first created for a given Pool, before the connection is made available to the
+connection pool. Additionally, because the function was registered with
+``insert=True``, it will be prepended to the internal list of registered
+functions.
+
+
+MySQL / MariaDB SQL Extensions
+------------------------------
+
+Many of the MySQL / MariaDB SQL extensions are handled through SQLAlchemy's generic
+function and operator support::
+
+ table.select(table.c.password==func.md5('plaintext'))
+ table.select(table.c.username.op('regexp')('^[a-d]'))
+
+And of course any valid SQL statement can be executed as a string as well.
+
+Some limited direct support for MySQL / MariaDB extensions to SQL is currently
+available.
+
+* INSERT..ON DUPLICATE KEY UPDATE: See
+ :ref:`mysql_insert_on_duplicate_key_update`
+
+* SELECT pragma, use :meth:`_expression.Select.prefix_with` and
+ :meth:`_query.Query.prefix_with`::
+
+ select(...).prefix_with(['HIGH_PRIORITY', 'SQL_SMALL_RESULT'])
+
+* UPDATE with LIMIT::
+
+ update(..., mysql_limit=10, mariadb_limit=10)
+
+* optimizer hints, use :meth:`_expression.Select.prefix_with` and
+ :meth:`_query.Query.prefix_with`::
+
+ select(...).prefix_with("/*+ NO_RANGE_OPTIMIZATION(t4 PRIMARY) */")
+
+* index hints, use :meth:`_expression.Select.with_hint` and
+ :meth:`_query.Query.with_hint`::
+
+ select(...).with_hint(some_table, "USE INDEX xyz")
+
+* MATCH operator support::
+
+ from sqlalchemy.dialects.mysql import match
+ select(...).where(match(col1, col2, against="some expr").in_boolean_mode())
+
+ .. seealso::
+
+ :class:`_mysql.match`
+
+.. _mysql_insert_on_duplicate_key_update:
+
+INSERT...ON DUPLICATE KEY UPDATE (Upsert)
+------------------------------------------
+
+MySQL / MariaDB allow "upserts" (update or insert)
+of rows into a table via the ``ON DUPLICATE KEY UPDATE`` clause of the
+``INSERT`` statement. A candidate row will only be inserted if that row does
+not match an existing primary or unique key in the table; otherwise, an UPDATE
+will be performed. The statement allows for separate specification of the
+values to INSERT versus the values for UPDATE.
+
+SQLAlchemy provides ``ON DUPLICATE KEY UPDATE`` support via the MySQL-specific
+:func:`.mysql.insert()` function, which provides
+the generative method :meth:`~.mysql.Insert.on_duplicate_key_update`:
+
+.. sourcecode:: pycon+sql
+
+ >>> from sqlalchemy.dialects.mysql import insert
+
+ >>> insert_stmt = insert(my_table).values(
+ ... id='some_existing_id',
+ ... data='inserted value')
+
+ >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
+ ... data=insert_stmt.inserted.data,
+ ... status='U'
+ ... )
+ >>> print(on_duplicate_key_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%s, %s)
+ ON DUPLICATE KEY UPDATE data = VALUES(data), status = %s
+
+
+Unlike PostgreSQL's "ON CONFLICT" phrase, the "ON DUPLICATE KEY UPDATE"
+phrase will always match on any primary key or unique key, and will always
+perform an UPDATE if there's a match; there are no options for it to raise
+an error or to skip performing an UPDATE.
+
+``ON DUPLICATE KEY UPDATE`` is used to perform an update of the already
+existing row, using any combination of new values as well as values
+from the proposed insertion. These values are normally specified using
+keyword arguments passed to
+:meth:`_mysql.Insert.on_duplicate_key_update`, with column key values
+(usually the name of the column, unless the column specifies
+:paramref:`_schema.Column.key`) as keys and literal or SQL expressions
+as values:
+
+.. sourcecode:: pycon+sql
+
+ >>> insert_stmt = insert(my_table).values(
+ ... id='some_existing_id',
+ ... data='inserted value')
+
+ >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
+ ... data="some data",
+ ... updated_at=func.current_timestamp(),
+ ... )
+
+ >>> print(on_duplicate_key_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%s, %s)
+ ON DUPLICATE KEY UPDATE data = %s, updated_at = CURRENT_TIMESTAMP
+
+In a manner similar to that of :meth:`.UpdateBase.values`, other parameter
+forms are accepted, including a single dictionary:
+
+.. sourcecode:: pycon+sql
+
+ >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
+ ... {"data": "some data", "updated_at": func.current_timestamp()},
+ ... )
+
+as well as a list of 2-tuples, which will automatically provide
+a parameter-ordered UPDATE statement in a manner similar to that described
+at :ref:`tutorial_parameter_ordered_updates`. Unlike the :class:`_expression.Update`
+object,
+no special flag is needed to specify the intent, since the argument form in
+this context is unambiguous:
+
+.. sourcecode:: pycon+sql
+
+ >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
+ ... [
+ ... ("data", "some data"),
+ ... ("updated_at", func.current_timestamp()),
+ ... ]
+ ... )
+
+ >>> print(on_duplicate_key_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%s, %s)
+ ON DUPLICATE KEY UPDATE data = %s, updated_at = CURRENT_TIMESTAMP
+
+.. versionchanged:: 1.3 support for parameter-ordered UPDATE clause within
+ MySQL ON DUPLICATE KEY UPDATE
+
+.. warning::
+
+    The :meth:`_mysql.Insert.on_duplicate_key_update`
+    method does **not** take into
+    account Python-side default UPDATE values or generation functions,
+    e.g. those specified using :paramref:`_schema.Column.onupdate`.
+    These values will not be exercised for an ON DUPLICATE KEY style of UPDATE,
+    unless they are explicitly specified in the parameters.
+
+
+
+In order to refer to the proposed insertion row, the special alias
+:attr:`_mysql.Insert.inserted` is available as an attribute on
+the :class:`_mysql.Insert` object; this object is a
+:class:`_expression.ColumnCollection` which contains all columns of the target
+table:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(
+ ... id='some_id',
+ ... data='inserted value',
+ ... author='jlh')
+
+ >>> do_update_stmt = stmt.on_duplicate_key_update(
+ ... data="updated value",
+ ... author=stmt.inserted.author
+ ... )
+
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data, author) VALUES (%s, %s, %s)
+ ON DUPLICATE KEY UPDATE data = %s, author = VALUES(author)
+
+When rendered, the "inserted" namespace will produce the expression
+``VALUES(<columnname>)``.
+
+.. versionadded:: 1.2 Added support for MySQL ON DUPLICATE KEY UPDATE clause
+
+
+
+rowcount Support
+----------------
+
+SQLAlchemy standardizes the DBAPI ``cursor.rowcount`` attribute to be the
+usual definition of "number of rows matched by an UPDATE or DELETE" statement.
+This is in contradiction to the default setting on most MySQL DBAPI drivers,
+which is "number of rows actually modified/deleted". For this reason, the
+SQLAlchemy MySQL dialects always add the ``constants.CLIENT.FOUND_ROWS``
+flag, or whatever is equivalent for the target dialect, upon connection.
+This setting is currently hardcoded.
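+
+As a short illustration (assuming a ``my_table`` :class:`_schema.Table` with an
+existing row whose ``id`` is 1), an UPDATE that matches a row is counted in
+``rowcount`` even if the new values are identical to the existing ones::
+
+    with engine.begin() as conn:
+        result = conn.execute(
+            my_table.update()
+            .where(my_table.c.id == 1)
+            .values(data="same value as before")
+        )
+        # rows *matched*, not rows actually changed
+        print(result.rowcount)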
+
+.. seealso::
+
+ :attr:`_engine.CursorResult.rowcount`
+
+
+.. _mysql_indexes:
+
+MySQL / MariaDB-Specific Index Options
+-----------------------------------------
+
+MySQL and MariaDB-specific extensions to the :class:`.Index` construct are available.
+
+Index Length
+~~~~~~~~~~~~~
+
+MySQL and MariaDB both provide an option to create index entries with a certain length, where
+"length" refers to the number of characters or bytes in each value which will
+become part of the index. SQLAlchemy provides this feature via the
+``mysql_length`` and/or ``mariadb_length`` parameters::
+
+ Index('my_index', my_table.c.data, mysql_length=10, mariadb_length=10)
+
+ Index('a_b_idx', my_table.c.a, my_table.c.b, mysql_length={'a': 4,
+ 'b': 9})
+
+ Index('a_b_idx', my_table.c.a, my_table.c.b, mariadb_length={'a': 4,
+ 'b': 9})
+
+Prefix lengths are given in characters for nonbinary string types and in bytes
+for binary string types. The value passed to the keyword argument *must* be
+either an integer (and, thus, specify the same prefix length value for all
+columns of the index) or a dict in which keys are column names and values are
+prefix length values for corresponding columns. MySQL and MariaDB only allow a
+length for a column of an index if it is for a CHAR, VARCHAR, TEXT, BINARY,
+VARBINARY or BLOB column.
+
+Index Prefixes
+~~~~~~~~~~~~~~
+
+MySQL storage engines permit you to specify an index prefix when creating
+an index. SQLAlchemy provides this feature via the
+``mysql_prefix`` parameter on :class:`.Index`::
+
+ Index('my_index', my_table.c.data, mysql_prefix='FULLTEXT')
+
+The value passed to the keyword argument will be simply passed through to the
+underlying CREATE INDEX, so it *must* be a valid index prefix for your MySQL
+storage engine.
+
+.. versionadded:: 1.1.5
+
+.. seealso::
+
+ `CREATE INDEX <https://dev.mysql.com/doc/refman/5.0/en/create-index.html>`_ - MySQL documentation
+
+Index Types
+~~~~~~~~~~~~~
+
+Some MySQL storage engines permit you to specify an index type when creating
+an index or primary key constraint. SQLAlchemy provides this feature via the
+``mysql_using`` parameter on :class:`.Index`::
+
+ Index('my_index', my_table.c.data, mysql_using='hash', mariadb_using='hash')
+
+As well as the ``mysql_using`` parameter on :class:`.PrimaryKeyConstraint`::
+
+ PrimaryKeyConstraint("data", mysql_using='hash', mariadb_using='hash')
+
+The value passed to the keyword argument will be simply passed through to the
+underlying CREATE INDEX or PRIMARY KEY clause, so it *must* be a valid index
+type for your MySQL storage engine.
+
+More information can be found at:
+
+https://dev.mysql.com/doc/refman/5.0/en/create-index.html
+
+https://dev.mysql.com/doc/refman/5.0/en/create-table.html
+
+Index Parsers
+~~~~~~~~~~~~~
+
+CREATE FULLTEXT INDEX in MySQL also supports a "WITH PARSER" option. This
+is available using the keyword argument ``mysql_with_parser``::
+
+ Index(
+ 'my_index', my_table.c.data,
+ mysql_prefix='FULLTEXT', mysql_with_parser="ngram",
+ mariadb_prefix='FULLTEXT', mariadb_with_parser="ngram",
+ )
+
+.. versionadded:: 1.3
+
+
+.. _mysql_foreign_keys:
+
+MySQL / MariaDB Foreign Keys
+-----------------------------
+
+MySQL and MariaDB's behavior regarding foreign keys has some important caveats.
+
+Foreign Key Arguments to Avoid
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Neither MySQL nor MariaDB support the foreign key arguments "DEFERRABLE", "INITIALLY",
+or "MATCH". Using the ``deferrable`` or ``initially`` keyword argument with
+:class:`_schema.ForeignKeyConstraint` or :class:`_schema.ForeignKey`
+will have the effect of
+these keywords being rendered in a DDL expression, which will then raise an
+error on MySQL or MariaDB. In order to use these keywords on a foreign key while having
+them ignored on a MySQL / MariaDB backend, use a custom compile rule::
+
+ from sqlalchemy.ext.compiler import compiles
+ from sqlalchemy.schema import ForeignKeyConstraint
+
+ @compiles(ForeignKeyConstraint, "mysql", "mariadb")
+ def process(element, compiler, **kw):
+ element.deferrable = element.initially = None
+ return compiler.visit_foreign_key_constraint(element, **kw)
+
+The "MATCH" keyword is in fact more insidious, and is explicitly disallowed
+by SQLAlchemy in conjunction with the MySQL or MariaDB backends. This argument is
+silently ignored by MySQL / MariaDB, but in addition has the effect of ON UPDATE and ON
+DELETE options also being ignored by the backend. Therefore MATCH should
+never be used with the MySQL / MariaDB backends; as is the case with DEFERRABLE and
+INITIALLY, custom compilation rules can be used to correct a
+ForeignKeyConstraint at DDL definition time.
+
+Reflection of Foreign Key Constraints
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Not all MySQL / MariaDB storage engines support foreign keys. When using the
+very common ``MyISAM`` MySQL storage engine, the information loaded by table
+reflection will not include foreign keys. For these tables, you may supply a
+:class:`~sqlalchemy.ForeignKeyConstraint` at reflection time::
+
+ Table('mytable', metadata,
+ ForeignKeyConstraint(['other_id'], ['othertable.other_id']),
+ autoload_with=engine
+ )
+
+.. seealso::
+
+ :ref:`mysql_storage_engines`
+
+.. _mysql_unique_constraints:
+
+MySQL / MariaDB Unique Constraints and Reflection
+----------------------------------------------------
+
+SQLAlchemy supports both the :class:`.Index` construct with the
+flag ``unique=True``, indicating a UNIQUE index, as well as the
+:class:`.UniqueConstraint` construct, representing a UNIQUE constraint.
+Both objects/syntaxes are supported by MySQL / MariaDB when emitting DDL to create
+these constraints. However, MySQL / MariaDB does not have a unique constraint
+construct that is separate from a unique index; that is, the "UNIQUE"
+constraint on MySQL / MariaDB is equivalent to creating a "UNIQUE INDEX".
+
+When reflecting these constructs, the
+:meth:`_reflection.Inspector.get_indexes`
+and the :meth:`_reflection.Inspector.get_unique_constraints`
+methods will **both**
+return an entry for a UNIQUE index in MySQL / MariaDB. However, when performing
+full table reflection using ``Table(..., autoload_with=engine)``,
+the :class:`.UniqueConstraint` construct is
+**not** part of the fully reflected :class:`_schema.Table` construct under any
+circumstances; this construct is always represented by a :class:`.Index`
+with the ``unique=True`` setting present in the :attr:`_schema.Table.indexes`
+collection.
+
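+A brief sketch of this reflection behavior, using the runtime inspector (the
+table name here is illustrative)::
+
+    from sqlalchemy import inspect
+
+    insp = inspect(engine)
+
+    # a UNIQUE INDEX on MySQL / MariaDB appears in both collections
+    insp.get_indexes("mytable")              # entry includes 'unique': True
+    insp.get_unique_constraints("mytable")   # a matching entry is also present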
+
+TIMESTAMP / DATETIME issues
+---------------------------
+
+.. _mysql_timestamp_onupdate:
+
+Rendering ON UPDATE CURRENT TIMESTAMP for MySQL / MariaDB's explicit_defaults_for_timestamp
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+MySQL / MariaDB have historically expanded the DDL for the :class:`_types.TIMESTAMP`
+datatype into the phrase "TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE
+CURRENT_TIMESTAMP", which includes non-standard SQL that automatically updates
+the column with the current timestamp when an UPDATE occurs, eliminating the
+usual need to use a trigger in such a case where server-side update changes are
+desired.
+
+MySQL 5.6 introduced a new flag `explicit_defaults_for_timestamp
+<https://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html
+#sysvar_explicit_defaults_for_timestamp>`_ which disables the above behavior,
+and in MySQL 8 this flag defaults to true, meaning in order to get a MySQL
+"on update timestamp" without changing this flag, the above DDL must be
+rendered explicitly. Additionally, the same DDL is valid for use of the
+``DATETIME`` datatype as well.
+
+SQLAlchemy's MySQL dialect does not yet have an option to generate
+MySQL's "ON UPDATE CURRENT_TIMESTAMP" clause, noting that this is not a general
+purpose "ON UPDATE" as there is no such syntax in standard SQL. SQLAlchemy's
+:paramref:`_schema.Column.server_onupdate` parameter is currently not related
+to this special MySQL behavior.
+
+To generate this DDL, make use of the :paramref:`_schema.Column.server_default`
+parameter and pass a textual clause that also includes the ON UPDATE clause::
+
+ from sqlalchemy import Table, MetaData, Column, Integer, String, TIMESTAMP
+ from sqlalchemy import text
+
+ metadata = MetaData()
+
+ mytable = Table(
+ "mytable",
+ metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(50)),
+ Column(
+ 'last_updated',
+ TIMESTAMP,
+ server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")
+ )
+ )
+
+The same instructions apply to use of the :class:`_types.DateTime` and
+:class:`_types.DATETIME` datatypes::
+
+ from sqlalchemy import DateTime
+
+ mytable = Table(
+ "mytable",
+ metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(50)),
+ Column(
+ 'last_updated',
+ DateTime,
+ server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")
+ )
+ )
+
+
+Even though the :paramref:`_schema.Column.server_onupdate` feature does not
+generate this DDL, it still may be desirable to signal to the ORM that this
+updated value should be fetched. This syntax looks like the following::
+
+ from sqlalchemy.schema import FetchedValue
+
+ class MyClass(Base):
+ __tablename__ = 'mytable'
+
+ id = Column(Integer, primary_key=True)
+ data = Column(String(50))
+ last_updated = Column(
+ TIMESTAMP,
+ server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
+ server_onupdate=FetchedValue()
+ )
+
+
+.. _mysql_timestamp_null:
+
+TIMESTAMP Columns and NULL
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+MySQL historically enforces that a column which specifies the
+TIMESTAMP datatype implicitly includes a default value of
+CURRENT_TIMESTAMP, even though this is not stated, and additionally
+sets the column as NOT NULL, the opposite behavior vs. that of all
+other datatypes::
+
+ mysql> CREATE TABLE ts_test (
+ -> a INTEGER,
+ -> b INTEGER NOT NULL,
+ -> c TIMESTAMP,
+ -> d TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ -> e TIMESTAMP NULL);
+ Query OK, 0 rows affected (0.03 sec)
+
+ mysql> SHOW CREATE TABLE ts_test;
+ +---------+-----------------------------------------------------
+ | Table | Create Table
+ +---------+-----------------------------------------------------
+ | ts_test | CREATE TABLE `ts_test` (
+ `a` int(11) DEFAULT NULL,
+ `b` int(11) NOT NULL,
+ `c` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ `d` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `e` timestamp NULL DEFAULT NULL
+ ) ENGINE=MyISAM DEFAULT CHARSET=latin1
+
+Above, we see that an INTEGER column defaults to NULL, unless it is specified
+with NOT NULL. But when the column is of type TIMESTAMP, an implicit
+default of CURRENT_TIMESTAMP is generated which also coerces the column
+to be NOT NULL, even though we did not specify it as such.
+
+This behavior of MySQL can be changed on the MySQL side using the
+`explicit_defaults_for_timestamp
+<https://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html
+#sysvar_explicit_defaults_for_timestamp>`_ configuration flag introduced in
+MySQL 5.6. With this server setting enabled, TIMESTAMP columns behave like
+any other datatype on the MySQL side with regards to defaults and nullability.
+
+However, to accommodate the vast majority of MySQL databases that do not
+specify this new flag, SQLAlchemy emits the "NULL" specifier explicitly with
+any TIMESTAMP column that does not specify ``nullable=False``. In order to
+accommodate newer databases that specify ``explicit_defaults_for_timestamp``,
+SQLAlchemy also emits NOT NULL for TIMESTAMP columns that do specify
+``nullable=False``. The following example illustrates::
+
+ from sqlalchemy import MetaData, Integer, Table, Column, text
+ from sqlalchemy.dialects.mysql import TIMESTAMP
+
+ m = MetaData()
+ t = Table('ts_test', m,
+ Column('a', Integer),
+ Column('b', Integer, nullable=False),
+ Column('c', TIMESTAMP),
+ Column('d', TIMESTAMP, nullable=False)
+ )
+
+
+ from sqlalchemy import create_engine
+ e = create_engine("mysql://scott:tiger@localhost/test", echo=True)
+ m.create_all(e)
+
+output::
+
+ CREATE TABLE ts_test (
+ a INTEGER,
+ b INTEGER NOT NULL,
+ c TIMESTAMP NULL,
+ d TIMESTAMP NOT NULL
+ )
+
+.. versionchanged:: 1.0.0 - SQLAlchemy now renders NULL or NOT NULL in all
+   cases for TIMESTAMP columns, to accommodate
+   ``explicit_defaults_for_timestamp``. Prior to this version, it would
+   not render "NOT NULL" for a TIMESTAMP column that is ``nullable=False``.
+
+""" # noqa
+
+from array import array as _array
+from collections import defaultdict
+from itertools import compress
+import re
+
+from sqlalchemy import literal_column
+from sqlalchemy import text
+from sqlalchemy.sql import visitors
+from . import reflection as _reflection
+from .enumerated import ENUM
+from .enumerated import SET
+from .json import JSON
+from .json import JSONIndexType
+from .json import JSONPathType
+from .reserved_words import RESERVED_WORDS_MARIADB
+from .reserved_words import RESERVED_WORDS_MYSQL
+from .types import _FloatType
+from .types import _IntegerType
+from .types import _MatchType
+from .types import _NumericType
+from .types import _StringType
+from .types import BIGINT
+from .types import BIT
+from .types import CHAR
+from .types import DATETIME
+from .types import DECIMAL
+from .types import DOUBLE
+from .types import FLOAT
+from .types import INTEGER
+from .types import LONGBLOB
+from .types import LONGTEXT
+from .types import MEDIUMBLOB
+from .types import MEDIUMINT
+from .types import MEDIUMTEXT
+from .types import NCHAR
+from .types import NUMERIC
+from .types import NVARCHAR
+from .types import REAL
+from .types import SMALLINT
+from .types import TEXT
+from .types import TIME
+from .types import TIMESTAMP
+from .types import TINYBLOB
+from .types import TINYINT
+from .types import TINYTEXT
+from .types import VARCHAR
+from .types import YEAR
+from ... import exc
+from ... import log
+from ... import schema as sa_schema
+from ... import sql
+from ... import types as sqltypes
+from ... import util
+from ...engine import default
+from ...engine import reflection
+from ...sql import coercions
+from ...sql import compiler
+from ...sql import elements
+from ...sql import functions
+from ...sql import operators
+from ...sql import roles
+from ...sql import util as sql_util
+from ...sql.sqltypes import Unicode
+from ...types import BINARY
+from ...types import BLOB
+from ...types import BOOLEAN
+from ...types import DATE
+from ...types import VARBINARY
+from ...util import topological
+
+AUTOCOMMIT_RE = re.compile(
+ r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)",
+ re.I | re.UNICODE,
+)
+SET_RE = re.compile(
+ r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE
+)
+
+
+# old names
+MSTime = TIME
+MSSet = SET
+MSEnum = ENUM
+MSLongBlob = LONGBLOB
+MSMediumBlob = MEDIUMBLOB
+MSTinyBlob = TINYBLOB
+MSBlob = BLOB
+MSBinary = BINARY
+MSVarBinary = VARBINARY
+MSNChar = NCHAR
+MSNVarChar = NVARCHAR
+MSChar = CHAR
+MSString = VARCHAR
+MSLongText = LONGTEXT
+MSMediumText = MEDIUMTEXT
+MSTinyText = TINYTEXT
+MSText = TEXT
+MSYear = YEAR
+MSTimeStamp = TIMESTAMP
+MSBit = BIT
+MSSmallInteger = SMALLINT
+MSTinyInteger = TINYINT
+MSMediumInteger = MEDIUMINT
+MSBigInteger = BIGINT
+MSNumeric = NUMERIC
+MSDecimal = DECIMAL
+MSDouble = DOUBLE
+MSReal = REAL
+MSFloat = FLOAT
+MSInteger = INTEGER
+
+colspecs = {
+ _IntegerType: _IntegerType,
+ _NumericType: _NumericType,
+ _FloatType: _FloatType,
+ sqltypes.Numeric: NUMERIC,
+ sqltypes.Float: FLOAT,
+ sqltypes.Time: TIME,
+ sqltypes.Enum: ENUM,
+ sqltypes.MatchType: _MatchType,
+ sqltypes.JSON: JSON,
+ sqltypes.JSON.JSONIndexType: JSONIndexType,
+ sqltypes.JSON.JSONPathType: JSONPathType,
+}
+
+# Everything 3.23 through 5.1 excepting OpenGIS types.
+ischema_names = {
+ "bigint": BIGINT,
+ "binary": BINARY,
+ "bit": BIT,
+ "blob": BLOB,
+ "boolean": BOOLEAN,
+ "char": CHAR,
+ "date": DATE,
+ "datetime": DATETIME,
+ "decimal": DECIMAL,
+ "double": DOUBLE,
+ "enum": ENUM,
+ "fixed": DECIMAL,
+ "float": FLOAT,
+ "int": INTEGER,
+ "integer": INTEGER,
+ "json": JSON,
+ "longblob": LONGBLOB,
+ "longtext": LONGTEXT,
+ "mediumblob": MEDIUMBLOB,
+ "mediumint": MEDIUMINT,
+ "mediumtext": MEDIUMTEXT,
+ "nchar": NCHAR,
+ "nvarchar": NVARCHAR,
+ "numeric": NUMERIC,
+ "set": SET,
+ "smallint": SMALLINT,
+ "text": TEXT,
+ "time": TIME,
+ "timestamp": TIMESTAMP,
+ "tinyblob": TINYBLOB,
+ "tinyint": TINYINT,
+ "tinytext": TINYTEXT,
+ "varbinary": VARBINARY,
+ "varchar": VARCHAR,
+ "year": YEAR,
+}
+
+
+class MySQLExecutionContext(default.DefaultExecutionContext):
+ def should_autocommit_text(self, statement):
+ return AUTOCOMMIT_RE.match(statement)
+
+ def create_server_side_cursor(self):
+ if self.dialect.supports_server_side_cursors:
+ return self._dbapi_connection.cursor(self.dialect._sscursor)
+ else:
+ raise NotImplementedError()
+
+ def fire_sequence(self, seq, type_):
+ return self._execute_scalar(
+ (
+ "select nextval(%s)"
+ % self.identifier_preparer.format_sequence(seq)
+ ),
+ type_,
+ )
+
+
+class MySQLCompiler(compiler.SQLCompiler):
+
+ render_table_with_column_in_update_from = True
+ """Overridden from base SQLCompiler value"""
+
+ extract_map = compiler.SQLCompiler.extract_map.copy()
+ extract_map.update({"milliseconds": "millisecond"})
+
+ def default_from(self):
+ """Called when a ``SELECT`` statement has no froms,
+ and no ``FROM`` clause is to be appended.
+
+ """
+ if self.stack:
+ stmt = self.stack[-1]["selectable"]
+ if stmt._where_criteria:
+ return " FROM DUAL"
+
+ return ""
+
+ def visit_random_func(self, fn, **kw):
+ return "rand%s" % self.function_argspec(fn)
+
+ def visit_sequence(self, seq, **kw):
+ return "nextval(%s)" % self.preparer.format_sequence(seq)
+
+ def visit_sysdate_func(self, fn, **kw):
+ return "SYSDATE()"
+
+ def _render_json_extract_from_binary(self, binary, operator, **kw):
+ # note we are intentionally calling upon the process() calls in the
+ # order in which they appear in the SQL String as this is used
+ # by positional parameter rendering
+
+ if binary.type._type_affinity is sqltypes.JSON:
+ return "JSON_EXTRACT(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ # for non-JSON, MySQL doesn't handle JSON null at all so it has to
+ # be explicit
+ case_expression = "CASE JSON_EXTRACT(%s, %s) WHEN 'null' THEN NULL" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ if binary.type._type_affinity is sqltypes.Integer:
+ type_expression = (
+ "ELSE CAST(JSON_EXTRACT(%s, %s) AS SIGNED INTEGER)"
+ % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+ )
+ elif binary.type._type_affinity is sqltypes.Numeric:
+ if (
+ binary.type.scale is not None
+ and binary.type.precision is not None
+ ):
+ # using DECIMAL here because MySQL does not recognize NUMERIC
+ type_expression = (
+ "ELSE CAST(JSON_EXTRACT(%s, %s) AS DECIMAL(%s, %s))"
+ % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ binary.type.precision,
+ binary.type.scale,
+ )
+ )
+ else:
+ # FLOAT / REAL not added in MySQL til 8.0.17
+ type_expression = (
+ "ELSE JSON_EXTRACT(%s, %s)+0.0000000000000000000000"
+ % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+ )
+ elif binary.type._type_affinity is sqltypes.Boolean:
+ # the NULL handling is particularly weird with boolean, so
+ # explicitly return true/false constants
+ type_expression = "WHEN true THEN true ELSE false"
+ elif binary.type._type_affinity is sqltypes.String:
+ # (gord): this fails with a JSON value that's a four byte unicode
+ # string. SQLite has the same problem at the moment
+ # (zzzeek): I'm not really sure. let's take a look at a test case
+ # that hits each backend and maybe make a requires rule for it?
+ type_expression = "ELSE JSON_UNQUOTE(JSON_EXTRACT(%s, %s))" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+ else:
+ # other affinity....this is not expected right now
+ type_expression = "ELSE JSON_EXTRACT(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ return case_expression + " " + type_expression + " END"
+
+ def visit_json_getitem_op_binary(self, binary, operator, **kw):
+ return self._render_json_extract_from_binary(binary, operator, **kw)
+
+ def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+ return self._render_json_extract_from_binary(binary, operator, **kw)
+
+ def visit_on_duplicate_key_update(self, on_duplicate, **kw):
+ statement = self.current_executable
+
+ if on_duplicate._parameter_ordering:
+ parameter_ordering = [
+ coercions.expect(roles.DMLColumnRole, key)
+ for key in on_duplicate._parameter_ordering
+ ]
+ ordered_keys = set(parameter_ordering)
+ cols = [
+ statement.table.c[key]
+ for key in parameter_ordering
+ if key in statement.table.c
+ ] + [c for c in statement.table.c if c.key not in ordered_keys]
+ else:
+ cols = statement.table.c
+
+ clauses = []
+ # traverses through all table columns to preserve table column order
+ for column in (col for col in cols if col.key in on_duplicate.update):
+
+ val = on_duplicate.update[column.key]
+
+ if coercions._is_literal(val):
+ val = elements.BindParameter(None, val, type_=column.type)
+ value_text = self.process(val.self_group(), use_schema=False)
+ else:
+
+ def replace(obj):
+ if (
+ isinstance(obj, elements.BindParameter)
+ and obj.type._isnull
+ ):
+ obj = obj._clone()
+ obj.type = column.type
+ return obj
+ elif (
+ isinstance(obj, elements.ColumnClause)
+ and obj.table is on_duplicate.inserted_alias
+ ):
+ obj = literal_column(
+ "VALUES(" + self.preparer.quote(obj.name) + ")"
+ )
+ return obj
+ else:
+ # element is not replaced
+ return None
+
+ val = visitors.replacement_traverse(val, {}, replace)
+ value_text = self.process(val.self_group(), use_schema=False)
+
+ name_text = self.preparer.quote(column.name)
+ clauses.append("%s = %s" % (name_text, value_text))
+
+ non_matching = set(on_duplicate.update) - set(c.key for c in cols)
+ if non_matching:
+ util.warn(
+ "Additional column names not matching "
+ "any column keys in table '%s': %s"
+ % (
+ self.statement.table.name,
+ (", ".join("'%s'" % c for c in non_matching)),
+ )
+ )
+
+ return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses)
+
+ def visit_concat_op_binary(self, binary, operator, **kw):
+ return "concat(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ _match_valid_flag_combinations = frozenset(
+ (
+ # (boolean_mode, natural_language, query_expansion)
+ (False, False, False),
+ (True, False, False),
+ (False, True, False),
+ (False, False, True),
+ (False, True, True),
+ )
+ )
+
+ _match_flag_expressions = (
+ "IN BOOLEAN MODE",
+ "IN NATURAL LANGUAGE MODE",
+ "WITH QUERY EXPANSION",
+ )
+
+ def visit_mysql_match(self, element, **kw):
+ return self.visit_match_op_binary(element, element.operator, **kw)
+
+ def visit_match_op_binary(self, binary, operator, **kw):
+ """
+ Note that `mysql_boolean_mode` is enabled by default because of
+ backward compatibility
+ """
+
+ modifiers = binary.modifiers
+
+ boolean_mode = modifiers.get("mysql_boolean_mode", True)
+ natural_language = modifiers.get("mysql_natural_language", False)
+ query_expansion = modifiers.get("mysql_query_expansion", False)
+
+ flag_combination = (boolean_mode, natural_language, query_expansion)
+
+ if flag_combination not in self._match_valid_flag_combinations:
+ flags = (
+ "in_boolean_mode=%s" % boolean_mode,
+ "in_natural_language_mode=%s" % natural_language,
+ "with_query_expansion=%s" % query_expansion,
+ )
+
+ flags = ", ".join(flags)
+
+ raise exc.CompileError("Invalid MySQL match flags: %s" % flags)
+
+ match_clause = binary.left
+ match_clause = self.process(match_clause, **kw)
+ against_clause = self.process(binary.right, **kw)
+
+ if any(flag_combination):
+ flag_expressions = compress(
+ self._match_flag_expressions,
+ flag_combination,
+ )
+
+ against_clause = [against_clause]
+ against_clause.extend(flag_expressions)
+
+ against_clause = " ".join(against_clause)
+
+ return "MATCH (%s) AGAINST (%s)" % (match_clause, against_clause)
+
+ def get_from_hint_text(self, table, text):
+ return text
+
+ def visit_typeclause(self, typeclause, type_=None, **kw):
+ if type_ is None:
+ type_ = typeclause.type.dialect_impl(self.dialect)
+ if isinstance(type_, sqltypes.TypeDecorator):
+ return self.visit_typeclause(typeclause, type_.impl, **kw)
+ elif isinstance(type_, sqltypes.Integer):
+ if getattr(type_, "unsigned", False):
+ return "UNSIGNED INTEGER"
+ else:
+ return "SIGNED INTEGER"
+ elif isinstance(type_, sqltypes.TIMESTAMP):
+ return "DATETIME"
+ elif isinstance(
+ type_,
+ (
+ sqltypes.DECIMAL,
+ sqltypes.DateTime,
+ sqltypes.Date,
+ sqltypes.Time,
+ ),
+ ):
+ return self.dialect.type_compiler.process(type_)
+ elif isinstance(type_, sqltypes.String) and not isinstance(
+ type_, (ENUM, SET)
+ ):
+ adapted = CHAR._adapt_string_for_cast(type_)
+ return self.dialect.type_compiler.process(adapted)
+ elif isinstance(type_, sqltypes._Binary):
+ return "BINARY"
+ elif isinstance(type_, sqltypes.JSON):
+ return "JSON"
+ elif isinstance(type_, sqltypes.NUMERIC):
+ return self.dialect.type_compiler.process(type_).replace(
+ "NUMERIC", "DECIMAL"
+ )
+ elif (
+ isinstance(type_, sqltypes.Float)
+ and self.dialect._support_float_cast
+ ):
+ return self.dialect.type_compiler.process(type_)
+ else:
+ return None
+
+ def visit_cast(self, cast, **kw):
+ type_ = self.process(cast.typeclause)
+ if type_ is None:
+ util.warn(
+ "Datatype %s does not support CAST on MySQL/MariaDb; "
+ "the CAST will be skipped."
+ % self.dialect.type_compiler.process(cast.typeclause.type)
+ )
+ return self.process(cast.clause.self_group(), **kw)
+
+ return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_)
+
+ def render_literal_value(self, value, type_):
+ value = super(MySQLCompiler, self).render_literal_value(value, type_)
+ if self.dialect._backslash_escapes:
+ value = value.replace("\\", "\\\\")
+ return value
+
+ # override native_boolean=False behavior here, as
+ # MySQL still supports native boolean
+ def visit_true(self, element, **kw):
+ return "true"
+
+ def visit_false(self, element, **kw):
+ return "false"
+
+ def get_select_precolumns(self, select, **kw):
+ """Add special MySQL keywords in place of DISTINCT.
+
+        .. deprecated:: 1.4
+            This usage is deprecated;
+            :meth:`_expression.Select.prefix_with` should be used for special
+            keywords at the start of a SELECT.
+
+ """
+ if isinstance(select._distinct, util.string_types):
+ util.warn_deprecated(
+ "Sending string values for 'distinct' is deprecated in the "
+ "MySQL dialect and will be removed in a future release. "
+ "Please use :meth:`.Select.prefix_with` for special keywords "
+ "at the start of a SELECT statement",
+ version="1.4",
+ )
+ return select._distinct.upper() + " "
+
+ return super(MySQLCompiler, self).get_select_precolumns(select, **kw)
+
+ def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
+ if from_linter:
+ from_linter.edges.add((join.left, join.right))
+
+ if join.full:
+ join_type = " FULL OUTER JOIN "
+ elif join.isouter:
+ join_type = " LEFT OUTER JOIN "
+ else:
+ join_type = " INNER JOIN "
+
+ return "".join(
+ (
+ self.process(
+ join.left, asfrom=True, from_linter=from_linter, **kwargs
+ ),
+ join_type,
+ self.process(
+ join.right, asfrom=True, from_linter=from_linter, **kwargs
+ ),
+ " ON ",
+ self.process(join.onclause, from_linter=from_linter, **kwargs),
+ )
+ )
+
+ def for_update_clause(self, select, **kw):
+ if select._for_update_arg.read:
+ tmp = " LOCK IN SHARE MODE"
+ else:
+ tmp = " FOR UPDATE"
+
+ if select._for_update_arg.of and self.dialect.supports_for_update_of:
+
+ tables = util.OrderedSet()
+ for c in select._for_update_arg.of:
+ tables.update(sql_util.surface_selectables_only(c))
+
+ tmp += " OF " + ", ".join(
+ self.process(table, ashint=True, use_schema=False, **kw)
+ for table in tables
+ )
+
+ if select._for_update_arg.nowait:
+ tmp += " NOWAIT"
+
+ if select._for_update_arg.skip_locked:
+ tmp += " SKIP LOCKED"
+
+ return tmp
+
+ def limit_clause(self, select, **kw):
+ # MySQL supports:
+ # LIMIT <limit>
+ # LIMIT <offset>, <limit>
+ # and in server versions > 3.3:
+ # LIMIT <limit> OFFSET <offset>
+ # The latter is more readable for offsets but we're stuck with the
+ # former until we can refine dialects by server revision.
+
+ limit_clause, offset_clause = (
+ select._limit_clause,
+ select._offset_clause,
+ )
+
+ if limit_clause is None and offset_clause is None:
+ return ""
+ elif offset_clause is not None:
+ # As suggested by the MySQL docs, need to apply an
+ # artificial limit if one wasn't provided
+ # https://dev.mysql.com/doc/refman/5.0/en/select.html
+ if limit_clause is None:
+ # hardwire the upper limit. Currently
+ # needed by OurSQL with Python 3
+ # (https://bugs.launchpad.net/oursql/+bug/686232),
+ # but also is consistent with the usage of the upper
+ # bound as part of MySQL's "syntax" for OFFSET with
+ # no LIMIT
+ return " \n LIMIT %s, %s" % (
+ self.process(offset_clause, **kw),
+ "18446744073709551615",
+ )
+ else:
+ return " \n LIMIT %s, %s" % (
+ self.process(offset_clause, **kw),
+ self.process(limit_clause, **kw),
+ )
+ else:
+ # No offset provided, so just use the limit
+ return " \n LIMIT %s" % (self.process(limit_clause, **kw),)
+
+ def update_limit_clause(self, update_stmt):
+ limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None)
+ if limit:
+ return "LIMIT %s" % limit
+ else:
+ return None
+
+ def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
+ kw["asfrom"] = True
+ return ", ".join(
+ t._compiler_dispatch(self, **kw)
+ for t in [from_table] + list(extra_froms)
+ )
+
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ return None
+
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
+ """If we have extra froms make sure we render any alias as hint."""
+ ashint = False
+ if extra_froms:
+ ashint = True
+ return from_table._compiler_dispatch(
+ self, asfrom=True, iscrud=True, ashint=ashint
+ )
+
+ def delete_extra_from_clause(
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ """Render the DELETE .. USING clause specific to MySQL."""
+ kw["asfrom"] = True
+ return "USING " + ", ".join(
+ t._compiler_dispatch(self, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
+
+ def visit_empty_set_expr(self, element_types):
+ return (
+ "SELECT %(outer)s FROM (SELECT %(inner)s) "
+ "as _empty_set WHERE 1!=1"
+ % {
+ "inner": ", ".join(
+ "1 AS _in_%s" % idx
+ for idx, type_ in enumerate(element_types)
+ ),
+ "outer": ", ".join(
+ "_in_%s" % idx for idx, type_ in enumerate(element_types)
+ ),
+ }
+ )
+
+ def visit_is_distinct_from_binary(self, binary, operator, **kw):
+ return "NOT (%s <=> %s)" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def visit_is_not_distinct_from_binary(self, binary, operator, **kw):
+ return "%s <=> %s" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def _mariadb_regexp_flags(self, flags, pattern, **kw):
+ return "CONCAT('(?', %s, ')', %s)" % (
+ self.process(flags, **kw),
+ self.process(pattern, **kw),
+ )
+
+ def _regexp_match(self, op_string, binary, operator, **kw):
+ flags = binary.modifiers["flags"]
+ if flags is None:
+ return self._generate_generic_binary(binary, op_string, **kw)
+ elif self.dialect.is_mariadb:
+ return "%s%s%s" % (
+ self.process(binary.left, **kw),
+ op_string,
+ self._mariadb_regexp_flags(flags, binary.right),
+ )
+ else:
+ text = "REGEXP_LIKE(%s, %s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ self.process(flags, **kw),
+ )
+ if op_string == " NOT REGEXP ":
+ return "NOT %s" % text
+ else:
+ return text
+
+ def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._regexp_match(" REGEXP ", binary, operator, **kw)
+
+ def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._regexp_match(" NOT REGEXP ", binary, operator, **kw)
+
+ def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+ flags = binary.modifiers["flags"]
+ replacement = binary.modifiers["replacement"]
+ if flags is None:
+ return "REGEXP_REPLACE(%s, %s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ self.process(replacement, **kw),
+ )
+ elif self.dialect.is_mariadb:
+ return "REGEXP_REPLACE(%s, %s, %s)" % (
+ self.process(binary.left, **kw),
+ self._mariadb_regexp_flags(flags, binary.right),
+ self.process(replacement, **kw),
+ )
+ else:
+ return "REGEXP_REPLACE(%s, %s, %s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ self.process(replacement, **kw),
+ self.process(flags, **kw),
+ )
+
+
+class MySQLDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kw):
+ """Builds column DDL."""
+
+ colspec = [
+ self.preparer.format_column(column),
+ self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ ),
+ ]
+
+ if column.computed is not None:
+ colspec.append(self.process(column.computed))
+
+ is_timestamp = isinstance(
+ column.type._unwrapped_dialect_impl(self.dialect),
+ sqltypes.TIMESTAMP,
+ )
+
+ if not column.nullable:
+ colspec.append("NOT NULL")
+
+ # see: https://docs.sqlalchemy.org/en/latest/dialects/mysql.html#mysql_timestamp_null # noqa
+ elif column.nullable and is_timestamp:
+ colspec.append("NULL")
+
+ comment = column.comment
+ if comment is not None:
+ literal = self.sql_compiler.render_literal_value(
+ comment, sqltypes.String()
+ )
+ colspec.append("COMMENT " + literal)
+
+ if (
+ column.table is not None
+ and column is column.table._autoincrement_column
+ and (
+ column.server_default is None
+ or isinstance(column.server_default, sa_schema.Identity)
+ )
+ and not (
+ self.dialect.supports_sequences
+ and isinstance(column.default, sa_schema.Sequence)
+ and not column.default.optional
+ )
+ ):
+ colspec.append("AUTO_INCREMENT")
+ else:
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec.append("DEFAULT " + default)
+ return " ".join(colspec)
+
+ def post_create_table(self, table):
+ """Build table-level CREATE options like ENGINE and COLLATE."""
+
+ table_opts = []
+
+ opts = dict(
+ (k[len(self.dialect.name) + 1 :].upper(), v)
+ for k, v in table.kwargs.items()
+ if k.startswith("%s_" % self.dialect.name)
+ )
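+        # e.g. (illustrative) a Table defined with mysql_engine="InnoDB"
+        # and mysql_charset="utf8mb4" arrives here as
+        # opts == {"ENGINE": "InnoDB", "CHARSET": "utf8mb4"}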
+
+ if table.comment is not None:
+ opts["COMMENT"] = table.comment
+
+ partition_options = [
+ "PARTITION_BY",
+ "PARTITIONS",
+ "SUBPARTITIONS",
+ "SUBPARTITION_BY",
+ ]
+
+ nonpart_options = set(opts).difference(partition_options)
+ part_options = set(opts).intersection(partition_options)
+
+ for opt in topological.sort(
+ [
+ ("DEFAULT_CHARSET", "COLLATE"),
+ ("DEFAULT_CHARACTER_SET", "COLLATE"),
+ ("CHARSET", "COLLATE"),
+ ("CHARACTER_SET", "COLLATE"),
+ ],
+ nonpart_options,
+ ):
+ arg = opts[opt]
+ if opt in _reflection._options_of_type_string:
+
+ arg = self.sql_compiler.render_literal_value(
+ arg, sqltypes.String()
+ )
+
+ if opt in (
+ "DATA_DIRECTORY",
+ "INDEX_DIRECTORY",
+ "DEFAULT_CHARACTER_SET",
+ "CHARACTER_SET",
+ "DEFAULT_CHARSET",
+ "DEFAULT_COLLATE",
+ ):
+ opt = opt.replace("_", " ")
+
+ joiner = "="
+ if opt in (
+ "TABLESPACE",
+ "DEFAULT CHARACTER SET",
+ "CHARACTER SET",
+ "COLLATE",
+ ):
+ joiner = " "
+
+ table_opts.append(joiner.join((opt, arg)))
+
+ for opt in topological.sort(
+ [
+ ("PARTITION_BY", "PARTITIONS"),
+ ("PARTITION_BY", "SUBPARTITION_BY"),
+ ("PARTITION_BY", "SUBPARTITIONS"),
+ ("PARTITIONS", "SUBPARTITIONS"),
+ ("PARTITIONS", "SUBPARTITION_BY"),
+ ("SUBPARTITION_BY", "SUBPARTITIONS"),
+ ],
+ part_options,
+ ):
+ arg = opts[opt]
+ if opt in _reflection._options_of_type_string:
+ arg = self.sql_compiler.render_literal_value(
+ arg, sqltypes.String()
+ )
+
+ opt = opt.replace("_", " ")
+ joiner = " "
+
+ table_opts.append(joiner.join((opt, arg)))
+
+ return " ".join(table_opts)
+
+ def visit_create_index(self, create, **kw):
+ index = create.element
+ self._verify_index_table(index)
+ preparer = self.preparer
+ table = preparer.format_table(index.table)
+
+ columns = [
+ self.sql_compiler.process(
+ elements.Grouping(expr)
+ if (
+ isinstance(expr, elements.BinaryExpression)
+ or (
+ isinstance(expr, elements.UnaryExpression)
+ and expr.modifier
+ not in (operators.desc_op, operators.asc_op)
+ )
+ or isinstance(expr, functions.FunctionElement)
+ )
+ else expr,
+ include_table=False,
+ literal_binds=True,
+ )
+ for expr in index.expressions
+ ]
+
+ name = self._prepared_index_name(index)
+
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+
+ index_prefix = index.kwargs.get("%s_prefix" % self.dialect.name, None)
+ if index_prefix:
+ text += index_prefix + " "
+
+ text += "INDEX "
+ if create.if_not_exists:
+ text += "IF NOT EXISTS "
+ text += "%s ON %s " % (name, table)
+
+ length = index.dialect_options[self.dialect.name]["length"]
+ if length is not None:
+
+ if isinstance(length, dict):
+ # length value can be a (column_name --> integer value)
+ # mapping specifying the prefix length for each column of the
+ # index
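+                # e.g. (illustrative) an Index created as
+                #   Index("idx", tbl.c.a, tbl.c.b,
+                #         mysql_length={"a": 10, "b": 20})
+                # renders its column list as "a(10), b(20)"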
+ columns = ", ".join(
+ "%s(%d)" % (expr, length[col.name])
+ if col.name in length
+ else (
+ "%s(%d)" % (expr, length[expr])
+ if expr in length
+ else "%s" % expr
+ )
+ for col, expr in zip(index.expressions, columns)
+ )
+ else:
+ # or can be an integer value specifying the same
+ # prefix length for all columns of the index
+ columns = ", ".join(
+ "%s(%d)" % (col, length) for col in columns
+ )
+ else:
+ columns = ", ".join(columns)
+ text += "(%s)" % columns
+
+ parser = index.dialect_options["mysql"]["with_parser"]
+ if parser is not None:
+ text += " WITH PARSER %s" % (parser,)
+
+ using = index.dialect_options["mysql"]["using"]
+ if using is not None:
+ text += " USING %s" % (preparer.quote(using))
+
+ return text
+
+ def visit_primary_key_constraint(self, constraint):
+ text = super(MySQLDDLCompiler, self).visit_primary_key_constraint(
+ constraint
+ )
+ using = constraint.dialect_options["mysql"]["using"]
+ if using:
+ text += " USING %s" % (self.preparer.quote(using))
+ return text
+
+ def visit_drop_index(self, drop):
+ index = drop.element
+ text = "\nDROP INDEX "
+ if drop.if_exists:
+ text += "IF EXISTS "
+
+ return text + "%s ON %s" % (
+ self._prepared_index_name(index, include_schema=False),
+ self.preparer.format_table(index.table),
+ )
+
+ def visit_drop_constraint(self, drop):
+ constraint = drop.element
+ if isinstance(constraint, sa_schema.ForeignKeyConstraint):
+ qual = "FOREIGN KEY "
+ const = self.preparer.format_constraint(constraint)
+ elif isinstance(constraint, sa_schema.PrimaryKeyConstraint):
+ qual = "PRIMARY KEY "
+ const = ""
+ elif isinstance(constraint, sa_schema.UniqueConstraint):
+ qual = "INDEX "
+ const = self.preparer.format_constraint(constraint)
+ elif isinstance(constraint, sa_schema.CheckConstraint):
+ if self.dialect.is_mariadb:
+ qual = "CONSTRAINT "
+ else:
+ qual = "CHECK "
+ const = self.preparer.format_constraint(constraint)
+ else:
+ qual = ""
+ const = self.preparer.format_constraint(constraint)
+ return "ALTER TABLE %s DROP %s%s" % (
+ self.preparer.format_table(constraint.table),
+ qual,
+ const,
+ )
+
+ def define_constraint_match(self, constraint):
+ if constraint.match is not None:
+ raise exc.CompileError(
+ "MySQL ignores the 'MATCH' keyword while at the same time "
+ "causes ON UPDATE/ON DELETE clauses to be ignored."
+ )
+ return ""
+
+ def visit_set_table_comment(self, create):
+ return "ALTER TABLE %s COMMENT %s" % (
+ self.preparer.format_table(create.element),
+ self.sql_compiler.render_literal_value(
+ create.element.comment, sqltypes.String()
+ ),
+ )
+
+ def visit_drop_table_comment(self, create):
+ return "ALTER TABLE %s COMMENT ''" % (
+ self.preparer.format_table(create.element)
+ )
+
+ def visit_set_column_comment(self, create):
+ return "ALTER TABLE %s CHANGE %s %s" % (
+ self.preparer.format_table(create.element.table),
+ self.preparer.format_column(create.element),
+ self.get_column_specification(create.element),
+ )
+
+
+class MySQLTypeCompiler(compiler.GenericTypeCompiler):
+ def _extend_numeric(self, type_, spec):
+ "Extend a numeric-type declaration with MySQL specific extensions."
+
+ if not self._mysql_type(type_):
+ return spec
+
+ if type_.unsigned:
+ spec += " UNSIGNED"
+ if type_.zerofill:
+ spec += " ZEROFILL"
+ return spec
+
+ def _extend_string(self, type_, defaults, spec):
+ """Extend a string-type declaration with standard SQL CHARACTER SET /
+ COLLATE annotations and MySQL specific extensions.
+
+ """
+
+ def attr(name):
+ return getattr(type_, name, defaults.get(name))
+
+ if attr("charset"):
+ charset = "CHARACTER SET %s" % attr("charset")
+ elif attr("ascii"):
+ charset = "ASCII"
+ elif attr("unicode"):
+ charset = "UNICODE"
+ else:
+ charset = None
+
+ if attr("collation"):
+ collation = "COLLATE %s" % type_.collation
+ elif attr("binary"):
+ collation = "BINARY"
+ else:
+ collation = None
+
+ if attr("national"):
+ # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets.
+ return " ".join(
+ [c for c in ("NATIONAL", spec, collation) if c is not None]
+ )
+ return " ".join(
+ [c for c in (spec, charset, collation) if c is not None]
+ )
+
+ def _mysql_type(self, type_):
+ return isinstance(type_, (_StringType, _NumericType))
+
+ def visit_NUMERIC(self, type_, **kw):
+ if type_.precision is None:
+ return self._extend_numeric(type_, "NUMERIC")
+ elif type_.scale is None:
+ return self._extend_numeric(
+ type_,
+ "NUMERIC(%(precision)s)" % {"precision": type_.precision},
+ )
+ else:
+ return self._extend_numeric(
+ type_,
+ "NUMERIC(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
+
+ def visit_DECIMAL(self, type_, **kw):
+ if type_.precision is None:
+ return self._extend_numeric(type_, "DECIMAL")
+ elif type_.scale is None:
+ return self._extend_numeric(
+ type_,
+ "DECIMAL(%(precision)s)" % {"precision": type_.precision},
+ )
+ else:
+ return self._extend_numeric(
+ type_,
+ "DECIMAL(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
+
+ def visit_DOUBLE(self, type_, **kw):
+ if type_.precision is not None and type_.scale is not None:
+ return self._extend_numeric(
+ type_,
+ "DOUBLE(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
+ else:
+ return self._extend_numeric(type_, "DOUBLE")
+
+ def visit_REAL(self, type_, **kw):
+ if type_.precision is not None and type_.scale is not None:
+ return self._extend_numeric(
+ type_,
+ "REAL(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
+ else:
+ return self._extend_numeric(type_, "REAL")
+
+ def visit_FLOAT(self, type_, **kw):
+ if (
+ self._mysql_type(type_)
+ and type_.scale is not None
+ and type_.precision is not None
+ ):
+ return self._extend_numeric(
+ type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale)
+ )
+ elif type_.precision is not None:
+ return self._extend_numeric(
+ type_, "FLOAT(%s)" % (type_.precision,)
+ )
+ else:
+ return self._extend_numeric(type_, "FLOAT")
+
+ def visit_INTEGER(self, type_, **kw):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(
+ type_,
+ "INTEGER(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
+ else:
+ return self._extend_numeric(type_, "INTEGER")
+
+ def visit_BIGINT(self, type_, **kw):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(
+ type_,
+ "BIGINT(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
+ else:
+ return self._extend_numeric(type_, "BIGINT")
+
+ def visit_MEDIUMINT(self, type_, **kw):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(
+ type_,
+ "MEDIUMINT(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
+ else:
+ return self._extend_numeric(type_, "MEDIUMINT")
+
+ def visit_TINYINT(self, type_, **kw):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(
+ type_, "TINYINT(%s)" % type_.display_width
+ )
+ else:
+ return self._extend_numeric(type_, "TINYINT")
+
+ def visit_SMALLINT(self, type_, **kw):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(
+ type_,
+ "SMALLINT(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
+ else:
+ return self._extend_numeric(type_, "SMALLINT")
+
+ def visit_BIT(self, type_, **kw):
+ if type_.length is not None:
+ return "BIT(%s)" % type_.length
+ else:
+ return "BIT"
+
+ def visit_DATETIME(self, type_, **kw):
+ if getattr(type_, "fsp", None):
+ return "DATETIME(%d)" % type_.fsp
+ else:
+ return "DATETIME"
+
+ def visit_DATE(self, type_, **kw):
+ return "DATE"
+
+ def visit_TIME(self, type_, **kw):
+ if getattr(type_, "fsp", None):
+ return "TIME(%d)" % type_.fsp
+ else:
+ return "TIME"
+
+ def visit_TIMESTAMP(self, type_, **kw):
+ if getattr(type_, "fsp", None):
+ return "TIMESTAMP(%d)" % type_.fsp
+ else:
+ return "TIMESTAMP"
+
+ def visit_YEAR(self, type_, **kw):
+ if type_.display_width is None:
+ return "YEAR"
+ else:
+ return "YEAR(%s)" % type_.display_width
+
+ def visit_TEXT(self, type_, **kw):
+ if type_.length:
+ return self._extend_string(type_, {}, "TEXT(%d)" % type_.length)
+ else:
+ return self._extend_string(type_, {}, "TEXT")
+
+ def visit_TINYTEXT(self, type_, **kw):
+ return self._extend_string(type_, {}, "TINYTEXT")
+
+ def visit_MEDIUMTEXT(self, type_, **kw):
+ return self._extend_string(type_, {}, "MEDIUMTEXT")
+
+ def visit_LONGTEXT(self, type_, **kw):
+ return self._extend_string(type_, {}, "LONGTEXT")
+
+ def visit_VARCHAR(self, type_, **kw):
+ if type_.length:
+ return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length)
+ else:
+ raise exc.CompileError(
+ "VARCHAR requires a length on dialect %s" % self.dialect.name
+ )
+
+ def visit_CHAR(self, type_, **kw):
+ if type_.length:
+ return self._extend_string(
+ type_, {}, "CHAR(%(length)s)" % {"length": type_.length}
+ )
+ else:
+ return self._extend_string(type_, {}, "CHAR")
+
+ def visit_NVARCHAR(self, type_, **kw):
+ # We'll actually generate the equiv. "NATIONAL VARCHAR" instead
+ # of "NVARCHAR".
+ if type_.length:
+ return self._extend_string(
+ type_,
+ {"national": True},
+ "VARCHAR(%(length)s)" % {"length": type_.length},
+ )
+ else:
+ raise exc.CompileError(
+ "NVARCHAR requires a length on dialect %s" % self.dialect.name
+ )
+
+ def visit_NCHAR(self, type_, **kw):
+ # We'll actually generate the equiv.
+ # "NATIONAL CHAR" instead of "NCHAR".
+ if type_.length:
+ return self._extend_string(
+ type_,
+ {"national": True},
+ "CHAR(%(length)s)" % {"length": type_.length},
+ )
+ else:
+ return self._extend_string(type_, {"national": True}, "CHAR")
+
+ def visit_VARBINARY(self, type_, **kw):
+ return "VARBINARY(%d)" % type_.length
+
+ def visit_JSON(self, type_, **kw):
+ return "JSON"
+
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_BLOB(type_)
+
+ def visit_enum(self, type_, **kw):
+ if not type_.native_enum:
+ return super(MySQLTypeCompiler, self).visit_enum(type_)
+ else:
+ return self._visit_enumerated_values("ENUM", type_, type_.enums)
+
+ def visit_BLOB(self, type_, **kw):
+ if type_.length:
+ return "BLOB(%d)" % type_.length
+ else:
+ return "BLOB"
+
+ def visit_TINYBLOB(self, type_, **kw):
+ return "TINYBLOB"
+
+ def visit_MEDIUMBLOB(self, type_, **kw):
+ return "MEDIUMBLOB"
+
+ def visit_LONGBLOB(self, type_, **kw):
+ return "LONGBLOB"
+
+ def _visit_enumerated_values(self, name, type_, enumerated_values):
+ quoted_enums = []
+ for e in enumerated_values:
+ quoted_enums.append("'%s'" % e.replace("'", "''"))
+ return self._extend_string(
+ type_, {}, "%s(%s)" % (name, ",".join(quoted_enums))
+ )
+
+ def visit_ENUM(self, type_, **kw):
+ return self._visit_enumerated_values("ENUM", type_, type_.enums)
+
+ def visit_SET(self, type_, **kw):
+ return self._visit_enumerated_values("SET", type_, type_.values)
+
+ def visit_BOOLEAN(self, type_, **kw):
+ return "BOOL"
+
+
+class MySQLIdentifierPreparer(compiler.IdentifierPreparer):
+ reserved_words = RESERVED_WORDS_MYSQL
+
+ def __init__(self, dialect, server_ansiquotes=False, **kw):
+ if not server_ansiquotes:
+ quote = "`"
+ else:
+ quote = '"'
+
+ super(MySQLIdentifierPreparer, self).__init__(
+ dialect, initial_quote=quote, escape_quote=quote
+ )
+
+ def _quote_free_identifiers(self, *ids):
+ """Unilaterally identifier-quote any number of strings."""
+
+ return tuple([self.quote_identifier(i) for i in ids if i is not None])
+
+
+class MariaDBIdentifierPreparer(MySQLIdentifierPreparer):
+ reserved_words = RESERVED_WORDS_MARIADB
+
+
+@log.class_logger
+class MySQLDialect(default.DefaultDialect):
+ """Details of the MySQL dialect.
+ Not used directly in application code.
+ """
+
+ name = "mysql"
+ supports_statement_cache = True
+
+ supports_alter = True
+
+ # MySQL has no true "boolean" type; we
+ # allow for the "true" and "false" keywords, however
+ supports_native_boolean = False
+
+    # identifiers are limited to 64 characters, however aliases can be 255...
+ max_identifier_length = 255
+ max_index_name_length = 64
+ max_constraint_name_length = 64
+
+ supports_native_enum = True
+
+ supports_sequences = False # default for MySQL ...
+ # ... may be updated to True for MariaDB 10.3+ in initialize()
+
+ sequences_optional = False
+
+ supports_for_update_of = False # default for MySQL ...
+ # ... may be updated to True for MySQL 8+ in initialize()
+
+ # MySQL doesn't support "DEFAULT VALUES" but *does* support
+ # "VALUES (DEFAULT)"
+ supports_default_values = False
+ supports_default_metavalue = True
+
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = False
+ supports_multivalues_insert = True
+
+ supports_comments = True
+ inline_comments = True
+ default_paramstyle = "format"
+ colspecs = colspecs
+
+ cte_follows_insert = True
+
+ statement_compiler = MySQLCompiler
+ ddl_compiler = MySQLDDLCompiler
+ type_compiler = MySQLTypeCompiler
+ ischema_names = ischema_names
+ preparer = MySQLIdentifierPreparer
+
+ is_mariadb = False
+ _mariadb_normalized_version_info = None
+
+ # default SQL compilation settings -
+ # these are modified upon initialize(),
+ # i.e. first connect
+ _backslash_escapes = True
+ _server_ansiquotes = False
+
+ construct_arguments = [
+ (sa_schema.Table, {"*": None}),
+ (sql.Update, {"limit": None}),
+ (sa_schema.PrimaryKeyConstraint, {"using": None}),
+ (
+ sa_schema.Index,
+ {
+ "using": None,
+ "length": None,
+ "prefix": None,
+ "with_parser": None,
+ },
+ ),
+ ]
+
+ def __init__(
+ self,
+ isolation_level=None,
+ json_serializer=None,
+ json_deserializer=None,
+ is_mariadb=None,
+ **kwargs
+ ):
+ kwargs.pop("use_ansiquotes", None) # legacy
+ default.DefaultDialect.__init__(self, **kwargs)
+ self.isolation_level = isolation_level
+ self._json_serializer = json_serializer
+ self._json_deserializer = json_deserializer
+ self._set_mariadb(is_mariadb, None)
+
+ def on_connect(self):
+ if self.isolation_level is not None:
+
+ def connect(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ return connect
+ else:
+ return None
+
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ ]
+ )
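+    # these values are typically selected via create_engine(), e.g.
+    # (illustrative) create_engine(url, isolation_level="READ COMMITTED")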
+
+ def set_isolation_level(self, connection, level):
+ level = level.replace("_", " ")
+
+ # adjust for ConnectionFairy being present
+ # allows attribute set e.g. "connection.autocommit = True"
+ # to work properly
+ if hasattr(connection, "dbapi_connection"):
+ connection = connection.dbapi_connection
+
+ self._set_isolation_level(connection, level)
+
+ def _set_isolation_level(self, connection, level):
+ if level not in self._isolation_lookup:
+ raise exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
+ )
+ cursor = connection.cursor()
+ cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level)
+ cursor.execute("COMMIT")
+ cursor.close()
+
+ def get_isolation_level(self, connection):
+ cursor = connection.cursor()
+ if self._is_mysql and self.server_version_info >= (5, 7, 20):
+ cursor.execute("SELECT @@transaction_isolation")
+ else:
+ cursor.execute("SELECT @@tx_isolation")
+ row = cursor.fetchone()
+ if row is None:
+ util.warn(
+ "Could not retrieve transaction isolation level for MySQL "
+ "connection."
+ )
+ raise NotImplementedError()
+ val = row[0]
+ cursor.close()
+ if util.py3k and isinstance(val, bytes):
+ val = val.decode()
+ return val.upper().replace("-", " ")
+
+ @classmethod
+ def _is_mariadb_from_url(cls, url):
+ dbapi = cls.dbapi()
+ dialect = cls(dbapi=dbapi)
+
+ cargs, cparams = dialect.create_connect_args(url)
+ conn = dialect.connect(*cargs, **cparams)
+ try:
+ cursor = conn.cursor()
+ cursor.execute("SELECT VERSION() LIKE '%MariaDB%'")
+ val = cursor.fetchone()[0]
+ except:
+ raise
+ else:
+ return bool(val)
+ finally:
+ conn.close()
+
+ def _get_server_version_info(self, connection):
+ # get database server version info explicitly over the wire
+ # to avoid proxy servers like MaxScale getting in the
+ # way with their own values, see #4205
+ dbapi_con = connection.connection
+ cursor = dbapi_con.cursor()
+ cursor.execute("SELECT VERSION()")
+ val = cursor.fetchone()[0]
+ cursor.close()
+ if util.py3k and isinstance(val, bytes):
+ val = val.decode()
+
+ return self._parse_server_version(val)
+
+ def _parse_server_version(self, val):
+ version = []
+ is_mariadb = False
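+        # e.g. (illustrative) "5.7.20-log" parses to (5, 7, 20), while
+        # "10.4.12-MariaDB" parses to (10, 4, 12) with is_mariadb=True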
+
+ r = re.compile(r"[.\-+]")
+ tokens = r.split(val)
+ for token in tokens:
+ parsed_token = re.match(
+ r"^(?:(\d+)(?:a|b|c)?|(MariaDB\w*))$", token
+ )
+ if not parsed_token:
+ continue
+ elif parsed_token.group(2):
+ self._mariadb_normalized_version_info = tuple(version[-3:])
+ is_mariadb = True
+ else:
+ digit = int(parsed_token.group(1))
+ version.append(digit)
+
+ server_version_info = tuple(version)
+
+ self._set_mariadb(server_version_info and is_mariadb, val)
+
+ if not is_mariadb:
+ self._mariadb_normalized_version_info = server_version_info
+
+ if server_version_info < (5, 0, 2):
+ raise NotImplementedError(
+ "the MySQL/MariaDB dialect supports server "
+ "version info 5.0.2 and above."
+ )
+
+        # setting it here to help with the test suite
+ self.server_version_info = server_version_info
+ return server_version_info
+
+ def _set_mariadb(self, is_mariadb, server_version_info):
+ if is_mariadb is None:
+ return
+
+ if not is_mariadb and self.is_mariadb:
+ raise exc.InvalidRequestError(
+ "MySQL version %s is not a MariaDB variant."
+ % (server_version_info,)
+ )
+ if is_mariadb:
+ self.preparer = MariaDBIdentifierPreparer
+ # this would have been set by the default dialect already,
+ # so set it again
+ self.identifier_preparer = self.preparer(self)
+ self.is_mariadb = is_mariadb
+
+ def do_begin_twophase(self, connection, xid):
+ connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid))
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.execute(sql.text("XA END :xid"), dict(xid=xid))
+ connection.execute(sql.text("XA PREPARE :xid"), dict(xid=xid))
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if not is_prepared:
+ connection.execute(sql.text("XA END :xid"), dict(xid=xid))
+ connection.execute(sql.text("XA ROLLBACK :xid"), dict(xid=xid))
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if not is_prepared:
+ self.do_prepare_twophase(connection, xid)
+ connection.execute(sql.text("XA COMMIT :xid"), dict(xid=xid))
+
+ def do_recover_twophase(self, connection):
+ resultset = connection.exec_driver_sql("XA RECOVER")
+ return [row["data"][0 : row["gtrid_length"]] for row in resultset]
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(
+ e,
+ (
+ self.dbapi.OperationalError,
+ self.dbapi.ProgrammingError,
+ self.dbapi.InterfaceError,
+ ),
+ ) and self._extract_error_code(e) in (
+ 1927,
+ 2006,
+ 2013,
+ 2014,
+ 2045,
+ 2055,
+ 4031,
+ ):
+ return True
+ elif isinstance(
+ e, (self.dbapi.InterfaceError, self.dbapi.InternalError)
+ ):
+ # if underlying connection is closed,
+ # this is the error you get
+ return "(0, '')" in str(e)
+ else:
+ return False
+
+ def _compat_fetchall(self, rp, charset=None):
+ """Proxy result rows to smooth over MySQL-Python driver
+ inconsistencies."""
+
+ return [_DecodingRow(row, charset) for row in rp.fetchall()]
+
+ def _compat_fetchone(self, rp, charset=None):
+ """Proxy a result row to smooth over MySQL-Python driver
+ inconsistencies."""
+
+ row = rp.fetchone()
+ if row:
+ return _DecodingRow(row, charset)
+ else:
+ return None
+
+ def _compat_first(self, rp, charset=None):
+ """Proxy a result row to smooth over MySQL-Python driver
+ inconsistencies."""
+
+ row = rp.first()
+ if row:
+ return _DecodingRow(row, charset)
+ else:
+ return None
+
+ def _extract_error_code(self, exception):
+ raise NotImplementedError()
+
+ def _get_default_schema_name(self, connection):
+ return connection.exec_driver_sql("SELECT DATABASE()").scalar()
+
+ def has_table(self, connection, table_name, schema=None):
+ self._ensure_has_table_connection(connection)
+
+ if schema is None:
+ schema = self.default_schema_name
+
+ rs = connection.execute(
+ text(
+ "SELECT COUNT(*) FROM information_schema.tables WHERE "
+ "table_schema = :table_schema AND "
+ "table_name = :table_name"
+ ).bindparams(
+ sql.bindparam("table_schema", type_=Unicode),
+ sql.bindparam("table_name", type_=Unicode),
+ ),
+ {
+ "table_schema": util.text_type(schema),
+ "table_name": util.text_type(table_name),
+ },
+ )
+ return bool(rs.scalar())
+
+ def has_sequence(self, connection, sequence_name, schema=None):
+ if not self.supports_sequences:
+ self._sequences_not_supported()
+ if not schema:
+ schema = self.default_schema_name
+ # MariaDB implements sequences as a special type of table
+ #
+ cursor = connection.execute(
+ sql.text(
+ "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES "
+ "WHERE TABLE_TYPE='SEQUENCE' and TABLE_NAME=:name AND "
+ "TABLE_SCHEMA=:schema_name"
+ ),
+ dict(
+ name=util.text_type(sequence_name),
+ schema_name=util.text_type(schema),
+ ),
+ )
+ return cursor.first() is not None
+
+ def _sequences_not_supported(self):
+ raise NotImplementedError(
+ "Sequences are supported only by the "
+ "MariaDB series 10.3 or greater"
+ )
+
+ @reflection.cache
+ def get_sequence_names(self, connection, schema=None, **kw):
+ if not self.supports_sequences:
+ self._sequences_not_supported()
+ if not schema:
+ schema = self.default_schema_name
+ # MariaDB implements sequences as a special type of table
+ cursor = connection.execute(
+ sql.text(
+ "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES "
+ "WHERE TABLE_TYPE='SEQUENCE' and TABLE_SCHEMA=:schema_name"
+ ),
+ dict(schema_name=schema),
+ )
+ return [
+ row[0]
+ for row in self._compat_fetchall(
+ cursor, charset=self._connection_charset
+ )
+ ]
+
+ def initialize(self, connection):
+ # this is driver-based, does not need server version info
+ # and is fairly critical for even basic SQL operations
+ self._connection_charset = self._detect_charset(connection)
+
+ # call super().initialize() because we need to have
+        # server_version_info set up.  in 1.4 under Python 2 only, this does
+        # the "check unicode returns" step, which is the one area where some
+        # SQL gets compiled within initialize() currently
+ default.DefaultDialect.initialize(self, connection)
+
+ self._detect_sql_mode(connection)
+ self._detect_ansiquotes(connection) # depends on sql mode
+ self._detect_casing(connection)
+ if self._server_ansiquotes:
+ # if ansiquotes == True, build a new IdentifierPreparer
+ # with the new setting
+ self.identifier_preparer = self.preparer(
+ self, server_ansiquotes=self._server_ansiquotes
+ )
+
+ self.supports_sequences = (
+ self.is_mariadb and self.server_version_info >= (10, 3)
+ )
+
+ self.supports_for_update_of = (
+ self._is_mysql and self.server_version_info >= (8,)
+ )
+
+ self._needs_correct_for_88718_96365 = (
+ not self.is_mariadb and self.server_version_info >= (8,)
+ )
+
+ self._warn_for_known_db_issues()
+
+ def _warn_for_known_db_issues(self):
+ if self.is_mariadb:
+ mdb_version = self._mariadb_normalized_version_info
+ if mdb_version > (10, 2) and mdb_version < (10, 2, 9):
+ util.warn(
+ "MariaDB %r before 10.2.9 has known issues regarding "
+ "CHECK constraints, which impact handling of NULL values "
+ "with SQLAlchemy's boolean datatype (MDEV-13596). An "
+ "additional issue prevents proper migrations of columns "
+ "with CHECK constraints (MDEV-11114). Please upgrade to "
+ "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 "
+ "series, to avoid these issues." % (mdb_version,)
+ )
+
+ @property
+ def _support_float_cast(self):
+ if not self.server_version_info:
+ return False
+ elif self.is_mariadb:
+ # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/
+ return self.server_version_info >= (10, 4, 5)
+ else:
+ # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa
+ return self.server_version_info >= (8, 0, 17)
+
+ @property
+ def _is_mariadb(self):
+ return self.is_mariadb
+
+ @property
+ def _is_mysql(self):
+ return not self.is_mariadb
+
+ @property
+ def _is_mariadb_102(self):
+ return self.is_mariadb and self._mariadb_normalized_version_info > (
+ 10,
+ 2,
+ )
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ rp = connection.exec_driver_sql("SHOW schemas")
+ return [r[0] for r in rp]
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ """Return a Unicode SHOW TABLES from a given schema."""
+ if schema is not None:
+ current_schema = schema
+ else:
+ current_schema = self.default_schema_name
+
+ charset = self._connection_charset
+
+ rp = connection.exec_driver_sql(
+ "SHOW FULL TABLES FROM %s"
+ % self.identifier_preparer.quote_identifier(current_schema)
+ )
+
+ return [
+ row[0]
+ for row in self._compat_fetchall(rp, charset=charset)
+ if row[1] == "BASE TABLE"
+ ]
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ if schema is None:
+ schema = self.default_schema_name
+ charset = self._connection_charset
+ rp = connection.exec_driver_sql(
+ "SHOW FULL TABLES FROM %s"
+ % self.identifier_preparer.quote_identifier(schema)
+ )
+ return [
+ row[0]
+ for row in self._compat_fetchall(rp, charset=charset)
+ if row[1] in ("VIEW", "SYSTEM VIEW")
+ ]
+
+ @reflection.cache
+ def get_table_options(self, connection, table_name, schema=None, **kw):
+
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+ return parsed_state.table_options
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+ return parsed_state.columns
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+ for key in parsed_state.keys:
+ if key["type"] == "PRIMARY":
+ # There can be only one.
+ cols = [s[0] for s in key["columns"]]
+ return {"constrained_columns": cols, "name": None}
+ return {"constrained_columns": [], "name": None}
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+ default_schema = None
+
+ fkeys = []
+
+ for spec in parsed_state.fk_constraints:
+ ref_name = spec["table"][-1]
+ ref_schema = len(spec["table"]) > 1 and spec["table"][-2] or schema
+
+ if not ref_schema:
+ if default_schema is None:
+ default_schema = connection.dialect.default_schema_name
+ if schema == default_schema:
+ ref_schema = schema
+
+ loc_names = spec["local"]
+ ref_names = spec["foreign"]
+
+ con_kw = {}
+ for opt in ("onupdate", "ondelete"):
+ if spec.get(opt, False) not in ("NO ACTION", None):
+ con_kw[opt] = spec[opt]
+
+ fkey_d = {
+ "name": spec["name"],
+ "constrained_columns": loc_names,
+ "referred_schema": ref_schema,
+ "referred_table": ref_name,
+ "referred_columns": ref_names,
+ "options": con_kw,
+ }
+ fkeys.append(fkey_d)
+
+ if self._needs_correct_for_88718_96365:
+ self._correct_for_mysql_bugs_88718_96365(fkeys, connection)
+
+ return fkeys
+
+ def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection):
+ # Foreign key is always in lower case (MySQL 8.0)
+ # https://bugs.mysql.com/bug.php?id=88718
+ # issue #4344 for SQLAlchemy
+
+ # table name also for MySQL 8.0
+ # https://bugs.mysql.com/bug.php?id=96365
+ # issue #4751 for SQLAlchemy
+
+ # for lower_case_table_names=2, information_schema.columns
+ # preserves the original table/schema casing, but SHOW CREATE
+        # TABLE does not.  this problem does not occur with
+        # lower_case_table_names=1, but use case-insensitive matching for
+        # these two modes in any case.
+
+ if self._casing in (1, 2):
+
+ def lower(s):
+ return s.lower()
+
+ else:
+            # if case sensitive, there can be two tables referenced
+            # with the same name but different casing, so we need to use
+            # case-sensitive matching.
+ def lower(s):
+ return s
+
+ default_schema_name = connection.dialect.default_schema_name
+ col_tuples = [
+ (
+ lower(rec["referred_schema"] or default_schema_name),
+ lower(rec["referred_table"]),
+ col_name,
+ )
+ for rec in fkeys
+ for col_name in rec["referred_columns"]
+ ]
+
+ if col_tuples:
+
+ correct_for_wrong_fk_case = connection.execute(
+ sql.text(
+ """
+ select table_schema, table_name, column_name
+ from information_schema.columns
+ where (table_schema, table_name, lower(column_name)) in
+ :table_data;
+ """
+ ).bindparams(sql.bindparam("table_data", expanding=True)),
+ dict(table_data=col_tuples),
+ )
+
+ # in casing=0, table name and schema name come back in their
+ # exact case.
+ # in casing=1, table name and schema name come back in lower
+ # case.
+ # in casing=2, table name and schema name come back from the
+ # information_schema.columns view in the case
+ # that was used in CREATE DATABASE and CREATE TABLE, but
+ # SHOW CREATE TABLE converts them to *lower case*, therefore
+ # not matching. So for this case, case-insensitive lookup
+ # is necessary
+ d = defaultdict(dict)
+ for schema, tname, cname in correct_for_wrong_fk_case:
+ d[(lower(schema), lower(tname))]["SCHEMANAME"] = schema
+ d[(lower(schema), lower(tname))]["TABLENAME"] = tname
+ d[(lower(schema), lower(tname))][cname.lower()] = cname
+
+ for fkey in fkeys:
+ rec = d[
+ (
+ lower(fkey["referred_schema"] or default_schema_name),
+ lower(fkey["referred_table"]),
+ )
+ ]
+
+ fkey["referred_table"] = rec["TABLENAME"]
+ if fkey["referred_schema"] is not None:
+ fkey["referred_schema"] = rec["SCHEMANAME"]
+
+ fkey["referred_columns"] = [
+ rec[col.lower()] for col in fkey["referred_columns"]
+ ]
+
+ @reflection.cache
+ def get_check_constraints(self, connection, table_name, schema=None, **kw):
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+
+ return [
+ {"name": spec["name"], "sqltext": spec["sqltext"]}
+ for spec in parsed_state.ck_constraints
+ ]
+
+ @reflection.cache
+ def get_table_comment(self, connection, table_name, schema=None, **kw):
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+ return {
+ "text": parsed_state.table_options.get(
+ "%s_comment" % self.name, None
+ )
+ }
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+
+ indexes = []
+
+ for spec in parsed_state.keys:
+ dialect_options = {}
+ unique = False
+ flavor = spec["type"]
+ if flavor == "PRIMARY":
+ continue
+ if flavor == "UNIQUE":
+ unique = True
+ elif flavor in ("FULLTEXT", "SPATIAL"):
+ dialect_options["%s_prefix" % self.name] = flavor
+ elif flavor is None:
+ pass
+ else:
+ self.logger.info(
+ "Converting unknown KEY type %s to a plain KEY", flavor
+ )
+ pass
+
+ if spec["parser"]:
+ dialect_options["%s_with_parser" % (self.name)] = spec[
+ "parser"
+ ]
+
+ index_d = {}
+ if dialect_options:
+ index_d["dialect_options"] = dialect_options
+
+ index_d["name"] = spec["name"]
+ index_d["column_names"] = [s[0] for s in spec["columns"]]
+ index_d["unique"] = unique
+ if flavor:
+ index_d["type"] = flavor
+ indexes.append(index_d)
+ return indexes
+
+ @reflection.cache
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+
+ return [
+ {
+ "name": key["name"],
+ "column_names": [col[0] for col in key["columns"]],
+ "duplicates_index": key["name"],
+ }
+ for key in parsed_state.keys
+ if key["type"] == "UNIQUE"
+ ]
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+
+ charset = self._connection_charset
+ full_name = ".".join(
+ self.identifier_preparer._quote_free_identifiers(schema, view_name)
+ )
+ sql = self._show_create_table(
+ connection, None, charset, full_name=full_name
+ )
+ return sql
+
+ def _parsed_state_or_create(
+ self, connection, table_name, schema=None, **kw
+ ):
+ return self._setup_parser(
+ connection,
+ table_name,
+ schema,
+ info_cache=kw.get("info_cache", None),
+ )
+
+ @util.memoized_property
+ def _tabledef_parser(self):
+ """return the MySQLTableDefinitionParser, generate if needed.
+
+ The deferred creation ensures that the dialect has
+ retrieved server version information first.
+
+ """
+ preparer = self.identifier_preparer
+ return _reflection.MySQLTableDefinitionParser(self, preparer)
+
+ @reflection.cache
+ def _setup_parser(self, connection, table_name, schema=None, **kw):
+ charset = self._connection_charset
+ parser = self._tabledef_parser
+ full_name = ".".join(
+ self.identifier_preparer._quote_free_identifiers(
+ schema, table_name
+ )
+ )
+ sql = self._show_create_table(
+ connection, None, charset, full_name=full_name
+ )
+ if re.match(r"^CREATE (?:ALGORITHM)?.* VIEW", sql):
+ # Adapt views to something table-like.
+ columns = self._describe_table(
+ connection, None, charset, full_name=full_name
+ )
+ sql = parser._describe_to_create(table_name, columns)
+ return parser.parse(sql, charset)
+
+ def _fetch_setting(self, connection, setting_name):
+ charset = self._connection_charset
+
+ if self.server_version_info and self.server_version_info < (5, 6):
+ sql = "SHOW VARIABLES LIKE '%s'" % setting_name
+ fetch_col = 1
+ else:
+ sql = "SELECT @@%s" % setting_name
+ fetch_col = 0
+
+ show_var = connection.exec_driver_sql(sql)
+ row = self._compat_first(show_var, charset=charset)
+ if not row:
+ return None
+ else:
+ return row[fetch_col]
+
+ def _detect_charset(self, connection):
+ raise NotImplementedError()
+
+ def _detect_casing(self, connection):
+ """Sniff out identifier case sensitivity.
+
+ Cached per-connection. This value can not change without a server
+ restart.
+
+ """
+ # https://dev.mysql.com/doc/refman/en/identifier-case-sensitivity.html
+
+ setting = self._fetch_setting(connection, "lower_case_table_names")
+ if setting is None:
+ cs = 0
+ else:
+ # 4.0.15 returns OFF or ON according to [ticket:489]
+ # 3.23 doesn't, 4.0.27 doesn't..
+ if setting == "OFF":
+ cs = 0
+ elif setting == "ON":
+ cs = 1
+ else:
+ cs = int(setting)
+ self._casing = cs
+ return cs
+
+ def _detect_collations(self, connection):
+ """Pull the active COLLATIONS list from the server.
+
+ Cached per-connection.
+ """
+
+ collations = {}
+ charset = self._connection_charset
+ rs = connection.exec_driver_sql("SHOW COLLATION")
+ for row in self._compat_fetchall(rs, charset):
+ collations[row[0]] = row[1]
+ return collations
+
+ def _detect_sql_mode(self, connection):
+ setting = self._fetch_setting(connection, "sql_mode")
+
+ if setting is None:
+ util.warn(
+ "Could not retrieve SQL_MODE; please ensure the "
+ "MySQL user has permissions to SHOW VARIABLES"
+ )
+ self._sql_mode = ""
+ else:
+ self._sql_mode = setting or ""
+
+ def _detect_ansiquotes(self, connection):
+ """Detect and adjust for the ANSI_QUOTES sql mode."""
+
+ mode = self._sql_mode
+ if not mode:
+ mode = ""
+ elif mode.isdigit():
+ mode_no = int(mode)
+ mode = (mode_no | 4 == mode_no) and "ANSI_QUOTES" or ""
+
+ self._server_ansiquotes = "ANSI_QUOTES" in mode
+
+ # as of MySQL 5.0.1
+ self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode
+
+ def _show_create_table(
+ self, connection, table, charset=None, full_name=None
+ ):
+ """Run SHOW CREATE TABLE for a ``Table``."""
+
+ if full_name is None:
+ full_name = self.identifier_preparer.format_table(table)
+ st = "SHOW CREATE TABLE %s" % full_name
+
+ rp = None
+ try:
+ rp = connection.execution_options(
+ skip_user_error_events=True
+ ).exec_driver_sql(st)
+ except exc.DBAPIError as e:
+ if self._extract_error_code(e.orig) == 1146:
+ util.raise_(exc.NoSuchTableError(full_name), replace_context=e)
+ else:
+ raise
+ row = self._compat_first(rp, charset=charset)
+ if not row:
+ raise exc.NoSuchTableError(full_name)
+ return row[1].strip()
+
+ def _describe_table(self, connection, table, charset=None, full_name=None):
+ """Run DESCRIBE for a ``Table`` and return processed rows."""
+
+ if full_name is None:
+ full_name = self.identifier_preparer.format_table(table)
+ st = "DESCRIBE %s" % full_name
+
+ rp, rows = None, None
+ try:
+ try:
+ rp = connection.execution_options(
+ skip_user_error_events=True
+ ).exec_driver_sql(st)
+ except exc.DBAPIError as e:
+ code = self._extract_error_code(e.orig)
+ if code == 1146:
+ util.raise_(
+ exc.NoSuchTableError(full_name), replace_context=e
+ )
+ elif code == 1356:
+ util.raise_(
+ exc.UnreflectableTableError(
+ "Table or view named %s could not be "
+ "reflected: %s" % (full_name, e)
+ ),
+ replace_context=e,
+ )
+ else:
+ raise
+ rows = self._compat_fetchall(rp, charset=charset)
+ finally:
+ if rp:
+ rp.close()
+ return rows
+
+
+class _DecodingRow(object):
+ """Return unicode-decoded values based on type inspection.
+
+ Smooth over data type issues (esp. with alpha driver versions) and
+ normalize strings as Unicode regardless of user-configured driver
+ encoding settings.
+
+ """
+
+ # Some MySQL-python versions can return some columns as
+ # sets.Set(['value']) (seriously) but thankfully that doesn't
+ # seem to come up in DDL queries.
+
+ _encoding_compat = {
+ "koi8r": "koi8_r",
+ "koi8u": "koi8_u",
+ "utf16": "utf-16-be", # MySQL's uft16 is always bigendian
+ "utf8mb4": "utf8", # real utf8
+ "utf8mb3": "utf8", # real utf8; saw this happen on CI but I cannot
+ # reproduce, possibly mariadb10.6 related
+ "eucjpms": "ujis",
+ }
+
+ def __init__(self, rowproxy, charset):
+ self.rowproxy = rowproxy
+ self.charset = self._encoding_compat.get(charset, charset)
+
+ def __getitem__(self, index):
+ item = self.rowproxy[index]
+ if isinstance(item, _array):
+ item = item.tostring()
+
+ if self.charset and isinstance(item, util.binary_type):
+ return item.decode(self.charset)
+ else:
+ return item
+
+ def __getattr__(self, attr):
+ item = getattr(self.rowproxy, attr)
+ if isinstance(item, _array):
+ item = item.tostring()
+ if self.charset and isinstance(item, util.binary_type):
+ return item.decode(self.charset)
+ else:
+ return item
diff --git a/lib/sqlalchemy/dialects/mysql/cymysql.py b/lib/sqlalchemy/dialects/mysql/cymysql.py
new file mode 100644
index 0000000..a67a194
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/cymysql.py
@@ -0,0 +1,82 @@
+# mysql/cymysql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+
+.. dialect:: mysql+cymysql
+ :name: CyMySQL
+ :dbapi: cymysql
+ :connectstring: mysql+cymysql://<username>:<password>@<host>/<dbname>[?<options>]
+ :url: https://github.com/nakagami/CyMySQL
+
+.. note::
+
+ The CyMySQL dialect is **not tested as part of SQLAlchemy's continuous
+ integration** and may have unresolved issues. The recommended MySQL
+ dialects are mysqlclient and PyMySQL.
+
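+A minimal ``create_engine()`` example (host name and credentials here are
+placeholders)::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine("mysql+cymysql://scott:tiger@localhost/dbname")
+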
+""" # noqa
+
+from .base import BIT
+from .base import MySQLDialect
+from .mysqldb import MySQLDialect_mysqldb
+from ... import util
+
+
+class _cymysqlBIT(BIT):
+ def result_processor(self, dialect, coltype):
+ """Convert MySQL's 64 bit, variable length binary string to a long."""
+
+ def process(value):
+ if value is not None:
+ v = 0
+ for i in util.iterbytes(value):
+ v = v << 8 | i
+ return v
+ return value
+
+ return process
+
+
+class MySQLDialect_cymysql(MySQLDialect_mysqldb):
+ driver = "cymysql"
+ supports_statement_cache = True
+
+ description_encoding = None
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = False
+ supports_unicode_statements = True
+
+ colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT})
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("cymysql")
+
+ def _detect_charset(self, connection):
+ return connection.connection.charset
+
+ def _extract_error_code(self, exception):
+ return exception.errno
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.OperationalError):
+ return self._extract_error_code(e) in (
+ 2006,
+ 2013,
+ 2014,
+ 2045,
+ 2055,
+ )
+ elif isinstance(e, self.dbapi.InterfaceError):
+ # if underlying connection is closed,
+ # this is the error you get
+ return True
+ else:
+ return False
+
+
+dialect = MySQLDialect_cymysql
diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py
new file mode 100644
index 0000000..0c8791a
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/dml.py
@@ -0,0 +1,175 @@
+from ... import exc
+from ... import util
+from ...sql.base import _exclusive_against
+from ...sql.base import _generative
+from ...sql.base import ColumnCollection
+from ...sql.dml import Insert as StandardInsert
+from ...sql.elements import ClauseElement
+from ...sql.expression import alias
+from ...util.langhelpers import public_factory
+
+
+__all__ = ("Insert", "insert")
+
+
+class Insert(StandardInsert):
+ """MySQL-specific implementation of INSERT.
+
+ Adds methods for MySQL-specific syntaxes such as ON DUPLICATE KEY UPDATE.
+
+ The :class:`~.mysql.Insert` object is created using the
+ :func:`sqlalchemy.dialects.mysql.insert` function.
+
+ .. versionadded:: 1.2
+
+ """
+
+ stringify_dialect = "mysql"
+ inherit_cache = False
+
+ @property
+ def inserted(self):
+ """Provide the "inserted" namespace for an ON DUPLICATE KEY UPDATE
+        statement.
+
+ MySQL's ON DUPLICATE KEY UPDATE clause allows reference to the row
+ that would be inserted, via a special function called ``VALUES()``.
+        This attribute makes all columns of that row referenceable, such
+        that they will render within a ``VALUES()`` function inside the
+ ON DUPLICATE KEY UPDATE clause. The attribute is named ``.inserted``
+ so as not to conflict with the existing
+ :meth:`_expression.Insert.values` method.
+
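+        E.g., a minimal sketch (``my_table`` and its columns are
+        placeholders)::
+
+            from sqlalchemy.dialects.mysql import insert
+
+            stmt = insert(my_table).values(id=1, data="inserted value")
+            stmt = stmt.on_duplicate_key_update(data=stmt.inserted.data)
+
+        The above renders the SET clause as ``data = VALUES(data)``
+        within the ON DUPLICATE KEY UPDATE clause.
+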
+ .. tip:: The :attr:`_mysql.Insert.inserted` attribute is an instance
+ of :class:`_expression.ColumnCollection`, which provides an
+ interface the same as that of the :attr:`_schema.Table.c`
+ collection described at :ref:`metadata_tables_and_columns`.
+ With this collection, ordinary names are accessible like attributes
+ (e.g. ``stmt.inserted.some_column``), but special names and
+ dictionary method names should be accessed using indexed access,
+ such as ``stmt.inserted["column name"]`` or
+ ``stmt.inserted["values"]``. See the docstring for
+ :class:`_expression.ColumnCollection` for further examples.
+
+ .. seealso::
+
+ :ref:`mysql_insert_on_duplicate_key_update` - example of how
+ to use :attr:`_expression.Insert.inserted`
+
+ """
+ return self.inserted_alias.columns
+
+ @util.memoized_property
+ def inserted_alias(self):
+ return alias(self.table, name="inserted")
+
+ @_generative
+ @_exclusive_against(
+ "_post_values_clause",
+ msgs={
+ "_post_values_clause": "This Insert construct already "
+ "has an ON DUPLICATE KEY clause present"
+ },
+ )
+ def on_duplicate_key_update(self, *args, **kw):
+ r"""
+ Specifies the ON DUPLICATE KEY UPDATE clause.
+
+ :param \**kw: Column keys linked to UPDATE values. The
+ values may be any SQL expression or supported literal Python
+ values.
+
+ .. warning:: This dictionary does **not** take into account
+ Python-specified default UPDATE values or generation functions,
+ e.g. those specified using :paramref:`_schema.Column.onupdate`.
+ These values will not be exercised for an ON DUPLICATE KEY UPDATE
+ style of UPDATE, unless values are manually specified here.
+
+ :param \*args: As an alternative to passing key/value parameters,
+ a dictionary or list of 2-tuples can be passed as a single positional
+ argument.
+
+ Passing a single dictionary is equivalent to the keyword argument
+ form::
+
+ insert().on_duplicate_key_update({"name": "some name"})
+
+ Passing a list of 2-tuples indicates that the parameter assignments
+ in the UPDATE clause should be ordered as sent, in a manner similar
+ to that described for the :class:`_expression.Update`
+ construct overall
+ in :ref:`tutorial_parameter_ordered_updates`::
+
+ insert().on_duplicate_key_update(
+ [("name", "some name"), ("value", "some value")])
+
+ .. versionchanged:: 1.3 parameters can be specified as a dictionary
+ or list of 2-tuples; the latter form provides for parameter
+ ordering.
+
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`mysql_insert_on_duplicate_key_update`
+
+ """
+ if args and kw:
+ raise exc.ArgumentError(
+ "Can't pass kwargs and positional arguments simultaneously"
+ )
+
+ if args:
+ if len(args) > 1:
+ raise exc.ArgumentError(
+ "Only a single dictionary or list of tuples "
+ "is accepted positionally."
+ )
+ values = args[0]
+ else:
+ values = kw
+
+ inserted_alias = getattr(self, "inserted_alias", None)
+ self._post_values_clause = OnDuplicateClause(inserted_alias, values)
+
+
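+# A minimal usage sketch of the construct above (assumes a Table object
+# ``my_table`` with ``id`` and ``data`` columns and an open ``connection``;
+# illustrative only):
+#
+#     from sqlalchemy.dialects.mysql import insert
+#
+#     stmt = insert(my_table).values(id=1, data="some data")
+#     stmt = stmt.on_duplicate_key_update(data=stmt.inserted.data)
+#     connection.execute(stmt)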
+insert = public_factory(
+ Insert, ".dialects.mysql.insert", ".dialects.mysql.Insert"
+)
+
+
+class OnDuplicateClause(ClauseElement):
+ __visit_name__ = "on_duplicate_key_update"
+
+ _parameter_ordering = None
+
+ stringify_dialect = "mysql"
+
+ def __init__(self, inserted_alias, update):
+ self.inserted_alias = inserted_alias
+
+ # auto-detect that parameters should be ordered. This is copied from
+ # Update._process_colparams(); however, we don't look for a special flag
+ # in this case since we are not disambiguating from other use cases as
+ # we are in Update.values().
+ if isinstance(update, list) and (
+ update and isinstance(update[0], tuple)
+ ):
+ self._parameter_ordering = [key for key, value in update]
+ update = dict(update)
+
+ if isinstance(update, dict):
+ if not update:
+ raise ValueError(
+ "update parameter dictionary must not be empty"
+ )
+ elif isinstance(update, ColumnCollection):
+ update = dict(update)
+ else:
+ raise ValueError(
+ "update parameter must be a non-empty dictionary "
+ "or a ColumnCollection such as the `.c.` collection "
+ "of a Table object"
+ )
+ self.update = update
diff --git a/lib/sqlalchemy/dialects/mysql/enumerated.py b/lib/sqlalchemy/dialects/mysql/enumerated.py
new file mode 100644
index 0000000..6c9ef28
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/enumerated.py
@@ -0,0 +1,263 @@
+# mysql/enumerated.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import re
+
+from .types import _StringType
+from ... import exc
+from ... import sql
+from ... import util
+from ...sql import sqltypes
+from ...sql.base import NO_ARG
+
+
+class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
+ """MySQL ENUM type."""
+
+ __visit_name__ = "ENUM"
+
+ native_enum = True
+
+ def __init__(self, *enums, **kw):
+ """Construct an ENUM.
+
+ E.g.::
+
+ Column('myenum', ENUM("foo", "bar", "baz"))
+
+ :param enums: The range of valid values for this ENUM. Values in
+ enums are not quoted, they will be escaped and surrounded by single
+ quotes when generating the schema. This object may also be a
+ PEP-435-compliant enumerated type.
+
+ .. versionadded:: 1.1 added support for PEP-435-compliant enumerated
+ types.
+
+ :param strict: This flag has no effect.
+
+ .. versionchanged:: The MySQL ENUM type as well as the base Enum
+ type now validates all Python data values.
+
+ :param charset: Optional, a column-level character set for this string
+ value. Takes precedence over the 'ascii' or 'unicode' short-hand.
+
+ :param collation: Optional, a column-level collation for this string
+ value. Takes precedence over the 'binary' short-hand.
+
+ :param ascii: Defaults to False: short-hand for the ``latin1``
+ character set, generates ASCII in schema.
+
+ :param unicode: Defaults to False: short-hand for the ``ucs2``
+ character set, generates UNICODE in schema.
+
+ :param binary: Defaults to False: short-hand, pick the binary
+ collation type that matches the column's character set. Generates
+ BINARY in schema. This does not affect the type of data stored,
+ only the collation of character data.
+
+ :param quoting: Not used. A warning will be raised if provided.
+
+ """
+ if kw.pop("quoting", NO_ARG) is not NO_ARG:
+ util.warn_deprecated_20(
+ "The 'quoting' parameter to :class:`.mysql.ENUM` is deprecated"
+ " and will be removed in a future release. "
+ "This parameter now has no effect."
+ )
+ kw.pop("strict", None)
+ self._enum_init(enums, kw)
+ _StringType.__init__(self, length=self.length, **kw)
+
+ @classmethod
+ def adapt_emulated_to_native(cls, impl, **kw):
+ """Produce a MySQL native :class:`.mysql.ENUM` from plain
+ :class:`.Enum`.
+
+ """
+ kw.setdefault("validate_strings", impl.validate_strings)
+ kw.setdefault("values_callable", impl.values_callable)
+ kw.setdefault("omit_aliases", impl._omit_aliases)
+ return cls(**kw)
+
+ def _object_value_for_elem(self, elem):
+ # mysql sends back a blank string for any value that
+ # was persisted that was not in the enums; that is, it does no
+ # validation on the incoming data; it "truncates" it to be
+ # the blank string. Return it straight.
+ if elem == "":
+ return elem
+ else:
+ return super(ENUM, self)._object_value_for_elem(elem)
+
+ def __repr__(self):
+ return util.generic_repr(
+ self, to_inspect=[ENUM, _StringType, sqltypes.Enum]
+ )
+
+
+class SET(_StringType):
+ """MySQL SET type."""
+
+ __visit_name__ = "SET"
+
+ def __init__(self, *values, **kw):
+ """Construct a SET.
+
+ E.g.::
+
+ Column('myset', SET("foo", "bar", "baz"))
+
+
+ The list of potential values is required in the case that this
+ set will be used to generate DDL for a table, or if the
+ :paramref:`.SET.retrieve_as_bitwise` flag is set to True.
+
+ :param values: The range of valid values for this SET. The values
+ are not quoted, they will be escaped and surrounded by single
+ quotes when generating the schema.
+
+ :param convert_unicode: Same flag as that of
+ :paramref:`.String.convert_unicode`.
+
+ :param collation: same as that of :paramref:`.String.collation`
+
+ :param charset: same as that of :paramref:`.VARCHAR.charset`.
+
+ :param ascii: same as that of :paramref:`.VARCHAR.ascii`.
+
+ :param unicode: same as that of :paramref:`.VARCHAR.unicode`.
+
+ :param binary: same as that of :paramref:`.VARCHAR.binary`.
+
+ :param retrieve_as_bitwise: if True, the data for the set type will be
+ persisted and selected using an integer value, where a set is coerced
+ into a bitwise mask for persistence. MySQL allows this mode which
+ has the advantage of being able to store values unambiguously,
+ such as the blank string ``''``. The datatype will appear
+ as the expression ``col + 0`` in a SELECT statement, so that the
+ value is coerced into an integer value in result sets.
+ This flag is required if one wishes
+ to persist a set that can store the blank string ``''`` as a value.
+
+ .. warning::
+
+ When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is
+ essential that the list of set values is expressed in the
+ **exact same order** as exists on the MySQL database.
+
+ .. versionadded:: 1.0.0
+
+ :param quoting: Not used. A warning will be raised if passed.
+
+ """
+ if kw.pop("quoting", NO_ARG) is not NO_ARG:
+ util.warn_deprecated_20(
+ "The 'quoting' parameter to :class:`.mysql.SET` is deprecated"
+ " and will be removed in a future release. "
+ "This parameter now has no effect."
+ )
+ self.retrieve_as_bitwise = kw.pop("retrieve_as_bitwise", False)
+ self.values = tuple(values)
+ if not self.retrieve_as_bitwise and "" in values:
+ raise exc.ArgumentError(
+ "Can't use the blank value '' in a SET without "
+ "setting retrieve_as_bitwise=True"
+ )
+ if self.retrieve_as_bitwise:
+ self._bitmap = dict(
+ (value, 2 ** idx) for idx, value in enumerate(self.values)
+ )
+ self._bitmap.update(
+ (2 ** idx, value) for idx, value in enumerate(self.values)
+ )
+ length = max([len(v) for v in values] + [0])
+ kw.setdefault("length", length)
+ super(SET, self).__init__(**kw)
+
+ def column_expression(self, colexpr):
+ if self.retrieve_as_bitwise:
+ return sql.type_coerce(
+ sql.type_coerce(colexpr, sqltypes.Integer) + 0, self
+ )
+ else:
+ return colexpr
+
+ def result_processor(self, dialect, coltype):
+ if self.retrieve_as_bitwise:
+
+ def process(value):
+ if value is not None:
+ value = int(value)
+
+ return set(util.map_bits(self._bitmap.__getitem__, value))
+ else:
+ return None
+
+ else:
+ super_convert = super(SET, self).result_processor(dialect, coltype)
+
+ def process(value):
+ if isinstance(value, util.string_types):
+ # MySQLdb returns a string, let's parse
+ if super_convert:
+ value = super_convert(value)
+ return set(re.findall(r"[^,]+", value))
+ else:
+ # mysql-connector-python does a naive
+ # split(",") which throws in an empty string
+ if value is not None:
+ value.discard("")
+ return value
+
+ return process
+
+ def bind_processor(self, dialect):
+ super_convert = super(SET, self).bind_processor(dialect)
+ if self.retrieve_as_bitwise:
+
+ def process(value):
+ if value is None:
+ return None
+ elif isinstance(value, util.int_types + util.string_types):
+ if super_convert:
+ return super_convert(value)
+ else:
+ return value
+ else:
+ int_value = 0
+ for v in value:
+ int_value |= self._bitmap[v]
+ return int_value
+
+ else:
+
+ def process(value):
+ # accept strings and int (actually bitflag) values directly
+ if value is not None and not isinstance(
+ value, util.int_types + util.string_types
+ ):
+ value = ",".join(value)
+
+ if super_convert:
+ return super_convert(value)
+ else:
+ return value
+
+ return process
+
+ def adapt(self, impltype, **kw):
+ kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise
+ return util.constructor_copy(self, impltype, *self.values, **kw)
+
+ def __repr__(self):
+ return util.generic_repr(
+ self,
+ to_inspect=[SET, _StringType],
+ additional_kw=[
+ ("retrieve_as_bitwise", False),
+ ],
+ )
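+
+
+# A brief sketch of the SET round-trip modes handled above (``Column`` is
+# assumed to be imported from sqlalchemy; names are illustrative):
+#
+#     Column("flags", SET("read", "write", retrieve_as_bitwise=True))
+#
+# With retrieve_as_bitwise=True, the bind processor turns {"read", "write"}
+# into the integer 3 (1 | 2), and the result processor maps the integer
+# produced by the ``col + 0`` expression back to the set of labels.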
diff --git a/lib/sqlalchemy/dialects/mysql/expression.py b/lib/sqlalchemy/dialects/mysql/expression.py
new file mode 100644
index 0000000..7a66e9b
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/expression.py
@@ -0,0 +1,130 @@
+from ... import exc
+from ... import util
+from ...sql import coercions
+from ...sql import elements
+from ...sql import operators
+from ...sql import roles
+from ...sql.base import _generative
+from ...sql.base import Generative
+
+
+class match(Generative, elements.BinaryExpression):
+ """Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause.
+
+ E.g.::
+
+ from sqlalchemy import desc
+ from sqlalchemy.dialects.mysql import match
+
+ match_expr = match(
+ users_table.c.firstname,
+ users_table.c.lastname,
+ against="Firstname Lastname",
+ )
+
+ stmt = (
+ select(users_table)
+ .where(match_expr.in_boolean_mode())
+ .order_by(desc(match_expr))
+ )
+
+ Would produce SQL resembling::
+
+ SELECT id, firstname, lastname
+ FROM user
+ WHERE MATCH(firstname, lastname) AGAINST (:param_1 IN BOOLEAN MODE)
+ ORDER BY MATCH(firstname, lastname) AGAINST (:param_2) DESC
+
+ The :func:`_mysql.match` function is a standalone version of the
+ :meth:`_sql.ColumnElement.match` method available on all
+ SQL expressions; unlike :meth:`_expression.ColumnElement.match`, it
+ allows multiple columns to be matched against.
+
+ :param cols: column expressions to match against
+
+ :param against: expression to be compared towards
+
+ :param in_boolean_mode: boolean, set "boolean mode" to true
+
+ :param in_natural_language_mode: boolean, set "natural language" to true
+
+ :param with_query_expansion: boolean, set "query expansion" to true
+
+ .. versionadded:: 1.4.19
+
+ .. seealso::
+
+ :meth:`_expression.ColumnElement.match`
+
+ """
+
+ __visit_name__ = "mysql_match"
+
+ inherit_cache = True
+
+ def __init__(self, *cols, **kw):
+ if not cols:
+ raise exc.ArgumentError("columns are required")
+
+ against = kw.pop("against", None)
+
+ if against is None:
+ raise exc.ArgumentError("against is required")
+ against = coercions.expect(
+ roles.ExpressionElementRole,
+ against,
+ )
+
+ left = elements.BooleanClauseList._construct_raw(
+ operators.comma_op,
+ clauses=cols,
+ )
+ left.group = False
+
+ flags = util.immutabledict(
+ {
+ "mysql_boolean_mode": kw.pop("in_boolean_mode", False),
+ "mysql_natural_language": kw.pop(
+ "in_natural_language_mode", False
+ ),
+ "mysql_query_expansion": kw.pop("with_query_expansion", False),
+ }
+ )
+
+ if kw:
+ raise exc.ArgumentError("unknown arguments: %s" % (", ".join(kw)))
+
+ super(match, self).__init__(
+ left, against, operators.match_op, modifiers=flags
+ )
+
+ @_generative
+ def in_boolean_mode(self):
+ """Apply the "IN BOOLEAN MODE" modifier to the MATCH expression.
+
+ :return: a new :class:`_mysql.match` instance with modifications
+ applied.
+ """
+
+ self.modifiers = self.modifiers.union({"mysql_boolean_mode": True})
+
+ @_generative
+ def in_natural_language_mode(self):
+ """Apply the "IN NATURAL LANGUAGE MODE" modifier to the MATCH
+ expression.
+
+ :return: a new :class:`_mysql.match` instance with modifications
+ applied.
+ """
+
+ self.modifiers = self.modifiers.union({"mysql_natural_language": True})
+
+ @_generative
+ def with_query_expansion(self):
+ """Apply the "WITH QUERY EXPANSION" modifier to the MATCH expression.
+
+ :return: a new :class:`_mysql.match` instance with modifications
+ applied.
+ """
+
+ self.modifiers = self.modifiers.union({"mysql_query_expansion": True})
diff --git a/lib/sqlalchemy/dialects/mysql/json.py b/lib/sqlalchemy/dialects/mysql/json.py
new file mode 100644
index 0000000..857fcce
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/json.py
@@ -0,0 +1,84 @@
+# mysql/json.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import absolute_import
+
+from ... import types as sqltypes
+
+
+class JSON(sqltypes.JSON):
+ """MySQL JSON type.
+
+ MySQL supports JSON as of version 5.7.
+ MariaDB supports JSON (as an alias for LONGTEXT) as of version 10.2.
+
+ :class:`_mysql.JSON` is used automatically whenever the base
+ :class:`_types.JSON` datatype is used against a MySQL or MariaDB backend.
+
+ .. seealso::
+
+ :class:`_types.JSON` - main documentation for the generic
+ cross-platform JSON datatype.
+
+ The :class:`.mysql.JSON` type supports persistence of JSON values
+ as well as the core index operations provided by :class:`_types.JSON`
+ datatype, by adapting the operations to render the ``JSON_EXTRACT``
+ function at the database level.
+
+ .. versionadded:: 1.1
+
+ """
+
+ pass
+
+
+class _FormatTypeMixin(object):
+ def _format_value(self, value):
+ raise NotImplementedError()
+
+ def bind_processor(self, dialect):
+ super_proc = self.string_bind_processor(dialect)
+
+ def process(value):
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+ def literal_processor(self, dialect):
+ super_proc = self.string_literal_processor(dialect)
+
+ def process(value):
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+
+class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
+ def _format_value(self, value):
+ if isinstance(value, int):
+ value = "$[%s]" % value
+ else:
+ value = '$."%s"' % value
+ return value
+
+
+class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
+ def _format_value(self, value):
+ return "$%s" % (
+ "".join(
+ [
+ "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
+ for elem in value
+ ]
+ )
+ )
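+
+
+# Illustrative renderings of the path formatting above (values chosen
+# arbitrarily):
+#
+#     JSONIndexType()._format_value(2)           # -> "$[2]"
+#     JSONIndexType()._format_value("name")      # -> '$."name"'
+#     JSONPathType()._format_value([1, "name"])  # -> '$[1]."name"'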
diff --git a/lib/sqlalchemy/dialects/mysql/mariadb.py b/lib/sqlalchemy/dialects/mysql/mariadb.py
new file mode 100644
index 0000000..568c3f0
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/mariadb.py
@@ -0,0 +1,25 @@
+from .base import MariaDBIdentifierPreparer
+from .base import MySQLDialect
+
+
+class MariaDBDialect(MySQLDialect):
+ is_mariadb = True
+ supports_statement_cache = True
+ name = "mariadb"
+ preparer = MariaDBIdentifierPreparer
+
+
+def loader(driver):
+ driver_mod = __import__(
+ "sqlalchemy.dialects.mysql.%s" % driver
+ ).dialects.mysql
+ driver_cls = getattr(driver_mod, driver).dialect
+
+ return type(
+ "MariaDBDialect_%s" % driver,
+ (
+ MariaDBDialect,
+ driver_cls,
+ ),
+ {"supports_statement_cache": True},
+ )
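+
+
+# For illustration, assuming the pymysql driver module is importable,
+# ``loader("pymysql")`` would return a class roughly equivalent to:
+#
+#     type(
+#         "MariaDBDialect_pymysql",
+#         (MariaDBDialect, MySQLDialect_pymysql),
+#         {"supports_statement_cache": True},
+#     )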
diff --git a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py
new file mode 100644
index 0000000..c8b2ead
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py
@@ -0,0 +1,240 @@
+# mysql/mariadbconnector.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+
+.. dialect:: mysql+mariadbconnector
+ :name: MariaDB Connector/Python
+ :dbapi: mariadb
+ :connectstring: mariadb+mariadbconnector://<user>:<password>@<host>[:<port>]/<dbname>
+ :url: https://pypi.org/project/mariadb/
+
+Driver Status
+-------------
+
+MariaDB Connector/Python enables Python programs to access MariaDB and MySQL
+databases using an API which is compliant with the Python DB API 2.0 (PEP-249).
+It is written in C and uses the MariaDB Connector/C client library for
+client-server communication.
+
+Note that the default driver for a ``mariadb://`` connection URI continues to
+be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver.
+
+.. _mariadb: https://github.com/mariadb-corporation/mariadb-connector-python
+
+""" # noqa
+import re
+
+from .base import MySQLCompiler
+from .base import MySQLDialect
+from .base import MySQLExecutionContext
+from ... import sql
+from ... import util
+
+mariadb_cpy_minimum_version = (1, 0, 1)
+
+
+class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
+ _lastrowid = None
+
+ def create_server_side_cursor(self):
+ return self._dbapi_connection.cursor(buffered=False)
+
+ def create_default_cursor(self):
+ return self._dbapi_connection.cursor(buffered=True)
+
+ def post_exec(self):
+ if self.isinsert and self.compiled.postfetch_lastrowid:
+ self._lastrowid = self.cursor.lastrowid
+
+ def get_lastrowid(self):
+ return self._lastrowid
+
+
+class MySQLCompiler_mariadbconnector(MySQLCompiler):
+ pass
+
+
+class MySQLDialect_mariadbconnector(MySQLDialect):
+ driver = "mariadbconnector"
+ supports_statement_cache = True
+
+ # set this to True at the module level to prevent the driver from running
+ # against a backend that the server detects as MySQL. Currently this appears to
+ # be unnecessary as MariaDB client libraries have always worked against
+ # MySQL databases. However, if this changes at some point, this can be
+ # adjusted, but PLEASE ADD A TEST in test/dialect/mysql/test_dialect.py if
+ # this change is made at some point to ensure the correct exception
+ # is raised at the correct point when running the driver against
+ # a MySQL backend.
+ # is_mariadb = True
+
+ supports_unicode_statements = True
+ encoding = "utf8mb4"
+ convert_unicode = True
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = True
+ supports_native_decimal = True
+ default_paramstyle = "qmark"
+ execution_ctx_cls = MySQLExecutionContext_mariadbconnector
+ statement_compiler = MySQLCompiler_mariadbconnector
+
+ supports_server_side_cursors = True
+
+ @util.memoized_property
+ def _dbapi_version(self):
+ if self.dbapi and hasattr(self.dbapi, "__version__"):
+ return tuple(
+ [
+ int(x)
+ for x in re.findall(
+ r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
+ )
+ ]
+ )
+ else:
+ return (99, 99, 99)
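+ # e.g. a hypothetical mariadb.__version__ of "1.0.11" parses to
+ # (1, 0, 11) with the regex above; (99, 99, 99) is a fallback meaning
+ # the version could not be determined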
+
+ def __init__(self, **kwargs):
+ super(MySQLDialect_mariadbconnector, self).__init__(**kwargs)
+ self.paramstyle = "qmark"
+ if self.dbapi is not None:
+ if self._dbapi_version < mariadb_cpy_minimum_version:
+ raise NotImplementedError(
+ "The minimum required version for MariaDB "
+ "Connector/Python is %s"
+ % ".".join(str(x) for x in mariadb_cpy_minimum_version)
+ )
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("mariadb")
+
+ def is_disconnect(self, e, connection, cursor):
+ if super(MySQLDialect_mariadbconnector, self).is_disconnect(
+ e, connection, cursor
+ ):
+ return True
+ elif isinstance(e, self.dbapi.Error):
+ str_e = str(e).lower()
+ return "not connected" in str_e or "isn't valid" in str_e
+ else:
+ return False
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args()
+
+ int_params = [
+ "connect_timeout",
+ "read_timeout",
+ "write_timeout",
+ "client_flag",
+ "port",
+ "pool_size",
+ ]
+ bool_params = [
+ "local_infile",
+ "ssl_verify_cert",
+ "ssl",
+ "pool_reset_connection",
+ ]
+
+ for key in int_params:
+ util.coerce_kw_type(opts, key, int)
+ for key in bool_params:
+ util.coerce_kw_type(opts, key, bool)
+
+ # FOUND_ROWS must be set in CLIENT_FLAGS to enable
+ # supports_sane_rowcount.
+ client_flag = opts.get("client_flag", 0)
+ if self.dbapi is not None:
+ try:
+ CLIENT_FLAGS = __import__(
+ self.dbapi.__name__ + ".constants.CLIENT"
+ ).constants.CLIENT
+ client_flag |= CLIENT_FLAGS.FOUND_ROWS
+ except (AttributeError, ImportError):
+ self.supports_sane_rowcount = False
+ opts["client_flag"] = client_flag
+ return [[], opts]
+
+ def _extract_error_code(self, exception):
+ try:
+ rc = exception.errno
+ except AttributeError:
+ rc = -1
+ return rc
+
+ def _detect_charset(self, connection):
+ return "utf8mb4"
+
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ "AUTOCOMMIT",
+ ]
+ )
+
+ def _set_isolation_level(self, connection, level):
+ if level == "AUTOCOMMIT":
+ connection.autocommit = True
+ else:
+ connection.autocommit = False
+ super(MySQLDialect_mariadbconnector, self)._set_isolation_level(
+ connection, level
+ )
+
+ def do_begin_twophase(self, connection, xid):
+ connection.execute(
+ sql.text("XA BEGIN :xid").bindparams(
+ sql.bindparam("xid", xid, literal_execute=True)
+ )
+ )
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.execute(
+ sql.text("XA END :xid").bindparams(
+ sql.bindparam("xid", xid, literal_execute=True)
+ )
+ )
+ connection.execute(
+ sql.text("XA PREPARE :xid").bindparams(
+ sql.bindparam("xid", xid, literal_execute=True)
+ )
+ )
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if not is_prepared:
+ connection.execute(
+ sql.text("XA END :xid").bindparams(
+ sql.bindparam("xid", xid, literal_execute=True)
+ )
+ )
+ connection.execute(
+ sql.text("XA ROLLBACK :xid").bindparams(
+ sql.bindparam("xid", xid, literal_execute=True)
+ )
+ )
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if not is_prepared:
+ self.do_prepare_twophase(connection, xid)
+ connection.execute(
+ sql.text("XA COMMIT :xid").bindparams(
+ sql.bindparam("xid", xid, literal_execute=True)
+ )
+ )
+
+
+dialect = MySQLDialect_mariadbconnector
diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
new file mode 100644
index 0000000..356babe
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
@@ -0,0 +1,240 @@
+# mysql/mysqlconnector.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+.. dialect:: mysql+mysqlconnector
+ :name: MySQL Connector/Python
+ :dbapi: myconnpy
+ :connectstring: mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
+ :url: https://pypi.org/project/mysql-connector-python/
+
+.. note::
+
+ The MySQL Connector/Python DBAPI has had many issues since its release,
+ some of which may remain unresolved, and the mysqlconnector dialect is
+ **not tested as part of SQLAlchemy's continuous integration**.
+ The recommended MySQL dialects are mysqlclient and PyMySQL.
+
+""" # noqa
+
+import re
+
+from .base import BIT
+from .base import MySQLCompiler
+from .base import MySQLDialect
+from .base import MySQLIdentifierPreparer
+from ... import processors
+from ... import util
+
+
+class MySQLCompiler_mysqlconnector(MySQLCompiler):
+ def visit_mod_binary(self, binary, operator, **kw):
+ if self.dialect._mysqlconnector_double_percents:
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
+ else:
+ return (
+ self.process(binary.left, **kw)
+ + " % "
+ + self.process(binary.right, **kw)
+ )
+
+ def post_process_text(self, text):
+ if self.dialect._mysqlconnector_double_percents:
+ return text.replace("%", "%%")
+ else:
+ return text
+
+ def escape_literal_column(self, text):
+ if self.dialect._mysqlconnector_double_percents:
+ return text.replace("%", "%%")
+ else:
+ return text
+
+
+class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
+ @property
+ def _double_percents(self):
+ return self.dialect._mysqlconnector_double_percents
+
+ @_double_percents.setter
+ def _double_percents(self, value):
+ pass
+
+ def _escape_identifier(self, value):
+ value = value.replace(self.escape_quote, self.escape_to_quote)
+ if self.dialect._mysqlconnector_double_percents:
+ return value.replace("%", "%%")
+ else:
+ return value
+
+
+class _myconnpyBIT(BIT):
+ def result_processor(self, dialect, coltype):
+ """MySQL-connector already converts mysql bits, so."""
+
+ return None
+
+
+class MySQLDialect_mysqlconnector(MySQLDialect):
+ driver = "mysqlconnector"
+ supports_statement_cache = True
+
+ supports_unicode_binds = True
+
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = True
+
+ supports_native_decimal = True
+
+ default_paramstyle = "format"
+ statement_compiler = MySQLCompiler_mysqlconnector
+
+ preparer = MySQLIdentifierPreparer_mysqlconnector
+
+ colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT})
+
+ def __init__(self, *arg, **kw):
+ super(MySQLDialect_mysqlconnector, self).__init__(*arg, **kw)
+
+ # hack description encoding, since mysqlconnector inconsistently
+ # returns either bytes or str for cursor descriptions
+ self._description_decoder = (
+ processors.to_conditional_unicode_processor_factory
+ )(self.description_encoding)
+
+ def _check_unicode_description(self, connection):
+ # hack description encoding, since mysqlconnector inconsistently
+ # returns either bytes or str for cursor descriptions
+ return False
+
+ @property
+ def description_encoding(self):
+ # total guess
+ return "latin-1"
+
+ @util.memoized_property
+ def supports_unicode_statements(self):
+ return util.py3k or self._mysqlconnector_version_info > (2, 0)
+
+ @classmethod
+ def dbapi(cls):
+ from mysql import connector
+
+ return connector
+
+ def do_ping(self, dbapi_connection):
+ try:
+ dbapi_connection.ping(False)
+ except self.dbapi.Error as err:
+ if self.is_disconnect(err, dbapi_connection, None):
+ return False
+ else:
+ raise
+ else:
+ return True
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+
+ opts.update(url.query)
+
+ util.coerce_kw_type(opts, "allow_local_infile", bool)
+ util.coerce_kw_type(opts, "autocommit", bool)
+ util.coerce_kw_type(opts, "buffered", bool)
+ util.coerce_kw_type(opts, "compress", bool)
+ util.coerce_kw_type(opts, "connection_timeout", int)
+ util.coerce_kw_type(opts, "connect_timeout", int)
+ util.coerce_kw_type(opts, "consume_results", bool)
+ util.coerce_kw_type(opts, "force_ipv6", bool)
+ util.coerce_kw_type(opts, "get_warnings", bool)
+ util.coerce_kw_type(opts, "pool_reset_session", bool)
+ util.coerce_kw_type(opts, "pool_size", int)
+ util.coerce_kw_type(opts, "raise_on_warnings", bool)
+ util.coerce_kw_type(opts, "raw", bool)
+ util.coerce_kw_type(opts, "ssl_verify_cert", bool)
+ util.coerce_kw_type(opts, "use_pure", bool)
+ util.coerce_kw_type(opts, "use_unicode", bool)
+
+ # unfortunately, MySQL/connector python refuses to release a
+ # cursor without reading fully, so non-buffered isn't an option
+ opts.setdefault("buffered", True)
+
+ # FOUND_ROWS must be set in ClientFlag to enable
+ # supports_sane_rowcount.
+ if self.dbapi is not None:
+ try:
+ from mysql.connector.constants import ClientFlag
+
+ client_flags = opts.get(
+ "client_flags", ClientFlag.get_default()
+ )
+ client_flags |= ClientFlag.FOUND_ROWS
+ opts["client_flags"] = client_flags
+ except Exception:
+ pass
+ return [[], opts]
+
+ @util.memoized_property
+ def _mysqlconnector_version_info(self):
+ if self.dbapi and hasattr(self.dbapi, "__version__"):
+ m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
+ if m:
+ return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
+
+ @util.memoized_property
+ def _mysqlconnector_double_percents(self):
+ return not util.py3k and self._mysqlconnector_version_info < (2, 0)
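+ # when double percents are in effect (Python 2 with Connector/Python
+ # older than 2.0), literal "%" characters in SQL text are doubled to
+ # "%%" so the driver's %-based paramstyle does not treat them as bind
+ # placeholders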
+
+ def _detect_charset(self, connection):
+ return connection.connection.charset
+
+ def _extract_error_code(self, exception):
+ return exception.errno
+
+ def is_disconnect(self, e, connection, cursor):
+ errnos = (2006, 2013, 2014, 2045, 2055, 2048)
+ exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
+ if isinstance(e, exceptions):
+ return (
+ e.errno in errnos
+ or "MySQL Connection not available." in str(e)
+ or "Connection to MySQL is not available" in str(e)
+ )
+ else:
+ return False
+
+ def _compat_fetchall(self, rp, charset=None):
+ return rp.fetchall()
+
+ def _compat_fetchone(self, rp, charset=None):
+ return rp.fetchone()
+
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ "AUTOCOMMIT",
+ ]
+ )
+
+ def _set_isolation_level(self, connection, level):
+ if level == "AUTOCOMMIT":
+ connection.autocommit = True
+ else:
+ connection.autocommit = False
+ super(MySQLDialect_mysqlconnector, self)._set_isolation_level(
+ connection, level
+ )
+
+
+dialect = MySQLDialect_mysqlconnector
diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py
new file mode 100644
index 0000000..7a721e8
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py
@@ -0,0 +1,331 @@
+# mysql/mysqldb.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+
+.. dialect:: mysql+mysqldb
+ :name: mysqlclient (maintained fork of MySQL-Python)
+ :dbapi: mysqldb
+ :connectstring: mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
+ :url: https://pypi.org/project/mysqlclient/
+
+Driver Status
+-------------
+
+The mysqlclient DBAPI is a maintained fork of the
+`MySQL-Python <https://sourceforge.net/projects/mysql-python>`_ DBAPI
+that is no longer maintained. `mysqlclient`_ supports Python 2 and Python 3
+and is very stable.
+
+.. _mysqlclient: https://github.com/PyMySQL/mysqlclient-python
+
+.. _mysqldb_unicode:
+
+Unicode
+-------
+
+Please see :ref:`mysql_unicode` for current recommendations on unicode
+handling.
+
+.. _mysqldb_ssl:
+
+SSL Connections
+----------------
+
+The mysqlclient and PyMySQL DBAPIs accept an additional dictionary under the
+key "ssl", which may be specified using the
+:paramref:`_sa.create_engine.connect_args` dictionary::
+
+ engine = create_engine(
+ "mysql+mysqldb://scott:tiger@192.168.0.134/test",
+ connect_args={
+ "ssl": {
+ "ssl_ca": "/home/gord/client-ssl/ca.pem",
+ "ssl_cert": "/home/gord/client-ssl/client-cert.pem",
+ "ssl_key": "/home/gord/client-ssl/client-key.pem"
+ }
+ }
+ )
+
+For convenience, the following keys may also be specified inline within the URL
+where they will be interpreted into the "ssl" dictionary automatically:
+"ssl_ca", "ssl_cert", "ssl_key", "ssl_capath", "ssl_cipher",
+"ssl_check_hostname". An example is as follows::
+
+ connection_uri = (
+ "mysql+mysqldb://scott:tiger@192.168.0.134/test"
+ "?ssl_ca=/home/gord/client-ssl/ca.pem"
+ "&ssl_cert=/home/gord/client-ssl/client-cert.pem"
+ "&ssl_key=/home/gord/client-ssl/client-key.pem"
+ )
+
+If the server uses an automatically-generated certificate that is self-signed
+or does not match the host name (as seen from the client), it may also be
+necessary to indicate ``ssl_check_hostname=false``::
+
+ connection_uri = (
+ "mysql+pymysql://scott:tiger@192.168.0.134/test"
+ "?ssl_ca=/home/gord/client-ssl/ca.pem"
+ "&ssl_cert=/home/gord/client-ssl/client-cert.pem"
+ "&ssl_key=/home/gord/client-ssl/client-key.pem"
+ "&ssl_check_hostname=false"
+ )
+
+
+.. seealso::
+
+ :ref:`pymysql_ssl` in the PyMySQL dialect
+
+
+Using MySQLdb with Google Cloud SQL
+-----------------------------------
+
+Google Cloud SQL now recommends use of the MySQLdb dialect. Connect
+using a URL like the following::
+
+ mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>
+
+Server Side Cursors
+-------------------
+
+The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.
+
+"""
+
+import re
+
+from .base import MySQLCompiler
+from .base import MySQLDialect
+from .base import MySQLExecutionContext
+from .base import MySQLIdentifierPreparer
+from .base import TEXT
+from ... import sql
+from ... import util
+
+
+class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
+ @property
+ def rowcount(self):
+ if hasattr(self, "_rowcount"):
+ return self._rowcount
+ else:
+ return self.cursor.rowcount
+
+
+class MySQLCompiler_mysqldb(MySQLCompiler):
+ pass
+
+
+class MySQLDialect_mysqldb(MySQLDialect):
+ driver = "mysqldb"
+ supports_statement_cache = True
+ supports_unicode_statements = True
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = True
+
+ supports_native_decimal = True
+
+ default_paramstyle = "format"
+ execution_ctx_cls = MySQLExecutionContext_mysqldb
+ statement_compiler = MySQLCompiler_mysqldb
+ preparer = MySQLIdentifierPreparer
+
+ def __init__(self, **kwargs):
+ super(MySQLDialect_mysqldb, self).__init__(**kwargs)
+ self._mysql_dbapi_version = (
+ self._parse_dbapi_version(self.dbapi.__version__)
+ if self.dbapi is not None and hasattr(self.dbapi, "__version__")
+ else (0, 0, 0)
+ )
+
+ def _parse_dbapi_version(self, version):
+ m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
+ if m:
+ return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
+ else:
+ return (0, 0, 0)
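+ # e.g. a mysqlclient __version__ string such as "2.1.1" parses to
+ # (2, 1, 1); anything unparseable falls back to (0, 0, 0)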
+
+ @util.langhelpers.memoized_property
+ def supports_server_side_cursors(self):
+ try:
+ cursors = __import__("MySQLdb.cursors").cursors
+ self._sscursor = cursors.SSCursor
+ return True
+ except (ImportError, AttributeError):
+ return False
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("MySQLdb")
+
+ def on_connect(self):
+ super_ = super(MySQLDialect_mysqldb, self).on_connect()
+
+ def on_connect(conn):
+ if super_ is not None:
+ super_(conn)
+
+ charset_name = conn.character_set_name()
+
+ if charset_name is not None:
+ cursor = conn.cursor()
+ cursor.execute("SET NAMES %s" % charset_name)
+ cursor.close()
+
+ return on_connect
+
+ def do_ping(self, dbapi_connection):
+ try:
+ dbapi_connection.ping(False)
+ except self.dbapi.Error as err:
+ if self.is_disconnect(err, dbapi_connection, None):
+ return False
+ else:
+ raise
+ else:
+ return True
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ rowcount = cursor.executemany(statement, parameters)
+ if context is not None:
+ context._rowcount = rowcount
+
+ def _check_unicode_returns(self, connection):
+ # work around issue fixed in
+ # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
+ # specific issue w/ the utf8mb4_bin collation and unicode returns
+
+ collation = connection.exec_driver_sql(
+ "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'"
+ % (
+ self.identifier_preparer.quote("Charset"),
+ self.identifier_preparer.quote("Collation"),
+ )
+ ).scalar()
+ has_utf8mb4_bin = self.server_version_info > (5,) and collation
+ if has_utf8mb4_bin:
+ additional_tests = [
+ sql.collate(
+ sql.cast(
+ sql.literal_column("'test collated returns'"),
+ TEXT(charset="utf8mb4"),
+ ),
+ "utf8mb4_bin",
+ )
+ ]
+ else:
+ additional_tests = []
+ return super(MySQLDialect_mysqldb, self)._check_unicode_returns(
+ connection, additional_tests
+ )
+
+ def create_connect_args(self, url, _translate_args=None):
+ if _translate_args is None:
+ _translate_args = dict(
+ database="db", username="user", password="passwd"
+ )
+
+ opts = url.translate_connect_args(**_translate_args)
+ opts.update(url.query)
+
+ util.coerce_kw_type(opts, "compress", bool)
+ util.coerce_kw_type(opts, "connect_timeout", int)
+ util.coerce_kw_type(opts, "read_timeout", int)
+ util.coerce_kw_type(opts, "write_timeout", int)
+ util.coerce_kw_type(opts, "client_flag", int)
+ util.coerce_kw_type(opts, "local_infile", int)
+ # Note: using either of the below will cause all strings to be
+ # returned as Unicode, both in raw SQL operations and with column
+ # types like String and MSString.
+ util.coerce_kw_type(opts, "use_unicode", bool)
+ util.coerce_kw_type(opts, "charset", str)
+
+ # Rich values 'cursorclass' and 'conv' are not supported via
+ # query string.
+
+ ssl = {}
+ keys = [
+ ("ssl_ca", str),
+ ("ssl_key", str),
+ ("ssl_cert", str),
+ ("ssl_capath", str),
+ ("ssl_cipher", str),
+ ("ssl_check_hostname", bool),
+ ]
+ for key, kw_type in keys:
+ if key in opts:
+ ssl[key[4:]] = opts[key]
+ util.coerce_kw_type(ssl, key[4:], kw_type)
+ del opts[key]
+ if ssl:
+ opts["ssl"] = ssl
+
+ # FOUND_ROWS must be set in CLIENT_FLAGS to enable
+ # supports_sane_rowcount.
+ client_flag = opts.get("client_flag", 0)
+
+ client_flag_found_rows = self._found_rows_client_flag()
+ if client_flag_found_rows is not None:
+ client_flag |= client_flag_found_rows
+ opts["client_flag"] = client_flag
+ return [[], opts]
+
+ def _found_rows_client_flag(self):
+ if self.dbapi is not None:
+ try:
+ CLIENT_FLAGS = __import__(
+ self.dbapi.__name__ + ".constants.CLIENT"
+ ).constants.CLIENT
+ except (AttributeError, ImportError):
+ return None
+ else:
+ return CLIENT_FLAGS.FOUND_ROWS
+ else:
+ return None
+
+ def _extract_error_code(self, exception):
+ return exception.args[0]
+
+ def _detect_charset(self, connection):
+ """Sniff out the character set in use for connection results."""
+
+ try:
+ # note: the SQL here would be
+ # "SHOW VARIABLES LIKE 'character_set%%'"
+ cset_name = connection.connection.character_set_name
+ except AttributeError:
+ util.warn(
+ "No 'character_set_name' can be detected with "
+ "this MySQL-Python version; "
+ "please upgrade to a recent version of MySQL-Python. "
+ "Assuming latin1."
+ )
+ return "latin1"
+ else:
+ return cset_name()
+
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ "AUTOCOMMIT",
+ ]
+ )
+
+ def _set_isolation_level(self, connection, level):
+ if level == "AUTOCOMMIT":
+ connection.autocommit(True)
+ else:
+ connection.autocommit(False)
+ super(MySQLDialect_mysqldb, self)._set_isolation_level(
+ connection, level
+ )
+
+
+dialect = MySQLDialect_mysqldb
diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py
new file mode 100644
index 0000000..f6287dc
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/oursql.py
@@ -0,0 +1,273 @@
+# mysql/oursql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+
+.. dialect:: mysql+oursql
+ :name: OurSQL
+ :dbapi: oursql
+ :connectstring: mysql+oursql://<user>:<password>@<host>[:<port>]/<dbname>
+ :url: https://packages.python.org/oursql/
+
+.. note::
+
+ The OurSQL MySQL dialect is legacy and is no longer supported upstream,
+ and is **not tested as part of SQLAlchemy's continuous integration**.
+ The recommended MySQL dialects are mysqlclient and PyMySQL.
+
+.. deprecated:: 1.4 The OurSQL DBAPI is deprecated and will be removed
+ in a future version. Please use one of the supported DBAPIs to
+ connect to mysql.
+
+Unicode
+-------
+
+Please see :ref:`mysql_unicode` for current recommendations on unicode
+handling.
+
+
+"""
+
+
+from .base import BIT
+from .base import MySQLDialect
+from .base import MySQLExecutionContext
+from ... import types as sqltypes
+from ... import util
+
+
+class _oursqlBIT(BIT):
+ def result_processor(self, dialect, coltype):
+ """oursql already converts mysql bits, so."""
+
+ return None
+
+
+class MySQLExecutionContext_oursql(MySQLExecutionContext):
+ @property
+ def plain_query(self):
+ return self.execution_options.get("_oursql_plain_query", False)
+
+
+class MySQLDialect_oursql(MySQLDialect):
+ driver = "oursql"
+ supports_statement_cache = True
+
+ if util.py2k:
+ supports_unicode_binds = True
+ supports_unicode_statements = True
+
+ supports_native_decimal = True
+
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = True
+ execution_ctx_cls = MySQLExecutionContext_oursql
+
+ colspecs = util.update_copy(
+ MySQLDialect.colspecs, {sqltypes.Time: sqltypes.Time, BIT: _oursqlBIT}
+ )
+
+ @classmethod
+ def dbapi(cls):
+ util.warn_deprecated(
+ "The OurSQL DBAPI is deprecated and will be removed "
+ "in a future version. Please use one of the supported DBAPIs to "
+ "connect to mysql.",
+ version="1.4",
+ )
+ return __import__("oursql")
+
+ def do_execute(self, cursor, statement, parameters, context=None):
+ """Provide an implementation of
+ *cursor.execute(statement, parameters)*."""
+
+ if context and context.plain_query:
+ cursor.execute(statement, plain_query=True)
+ else:
+ cursor.execute(statement, parameters)
+
+ def do_begin(self, connection):
+ connection.cursor().execute("BEGIN", plain_query=True)
+
+ def _xa_query(self, connection, query, xid):
+ if util.py2k:
+ arg = connection.connection._escape_string(xid)
+ else:
+ charset = self._connection_charset
+ arg = connection.connection._escape_string(
+ xid.encode(charset)
+ ).decode(charset)
+ arg = "'%s'" % arg
+ connection.execution_options(_oursql_plain_query=True).exec_driver_sql(
+ query % arg
+ )
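+ # e.g. _xa_query(connection, "XA BEGIN %s", "xid_1") runs the literal
+ # statement XA BEGIN 'xid_1' through the plain (non-parameterized)
+ # query path; the xid shown is hypothetical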
+
+ # Because mysql is bad, these methods have to be
+ # reimplemented to use _PlainQuery. Basically, some queries
+ # refuse to return any data if they're run through
+ # the parameterized query API, or refuse to be parameterized
+ # in the first place.
+ def do_begin_twophase(self, connection, xid):
+ self._xa_query(connection, "XA BEGIN %s", xid)
+
+ def do_prepare_twophase(self, connection, xid):
+ self._xa_query(connection, "XA END %s", xid)
+ self._xa_query(connection, "XA PREPARE %s", xid)
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if not is_prepared:
+ self._xa_query(connection, "XA END %s", xid)
+ self._xa_query(connection, "XA ROLLBACK %s", xid)
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if not is_prepared:
+ self.do_prepare_twophase(connection, xid)
+ self._xa_query(connection, "XA COMMIT %s", xid)
+
+ # Q: why didn't we need all these "plain_query" overrides earlier?
+ # Am I on a newer/older version of OurSQL?
+ def has_table(self, connection, table_name, schema=None):
+ return MySQLDialect.has_table(
+ self,
+ connection.connect().execution_options(_oursql_plain_query=True),
+ table_name,
+ schema,
+ )
+
+ def get_table_options(self, connection, table_name, schema=None, **kw):
+ return MySQLDialect.get_table_options(
+ self,
+ connection.connect().execution_options(_oursql_plain_query=True),
+ table_name,
+ schema=schema,
+ **kw
+ )
+
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ return MySQLDialect.get_columns(
+ self,
+ connection.connect().execution_options(_oursql_plain_query=True),
+ table_name,
+ schema=schema,
+ **kw
+ )
+
+ def get_view_names(self, connection, schema=None, **kw):
+ return MySQLDialect.get_view_names(
+ self,
+ connection.connect().execution_options(_oursql_plain_query=True),
+ schema=schema,
+ **kw
+ )
+
+ def get_table_names(self, connection, schema=None, **kw):
+ return MySQLDialect.get_table_names(
+ self,
+ connection.connect().execution_options(_oursql_plain_query=True),
+ schema,
+ )
+
+ def get_schema_names(self, connection, **kw):
+ return MySQLDialect.get_schema_names(
+ self,
+ connection.connect().execution_options(_oursql_plain_query=True),
+ **kw
+ )
+
+ def initialize(self, connection):
+ return MySQLDialect.initialize(
+ self, connection.execution_options(_oursql_plain_query=True)
+ )
+
+ def _show_create_table(
+ self, connection, table, charset=None, full_name=None
+ ):
+ return MySQLDialect._show_create_table(
+ self,
+ connection.connect(close_with_result=True).execution_options(
+ _oursql_plain_query=True
+ ),
+ table,
+ charset,
+ full_name,
+ )
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.ProgrammingError):
+ return (
+ e.errno is None
+ and "cursor" not in e.args[1]
+ and e.args[1].endswith("closed")
+ )
+ else:
+ return e.errno in (2006, 2013, 2014, 2045, 2055)
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(
+ database="db", username="user", password="passwd"
+ )
+ opts.update(url.query)
+
+ util.coerce_kw_type(opts, "port", int)
+ util.coerce_kw_type(opts, "compress", bool)
+ util.coerce_kw_type(opts, "autoping", bool)
+ util.coerce_kw_type(opts, "raise_on_warnings", bool)
+
+ util.coerce_kw_type(opts, "default_charset", bool)
+ if opts.pop("default_charset", False):
+ opts["charset"] = None
+ else:
+ util.coerce_kw_type(opts, "charset", str)
+ opts["use_unicode"] = opts.get("use_unicode", True)
+ util.coerce_kw_type(opts, "use_unicode", bool)
+
+ # FOUND_ROWS must be set in CLIENT_FLAGS to enable
+ # supports_sane_rowcount.
+ opts.setdefault("found_rows", True)
+
+ ssl = {}
+ for key in [
+ "ssl_ca",
+ "ssl_key",
+ "ssl_cert",
+ "ssl_capath",
+ "ssl_cipher",
+ ]:
+ if key in opts:
+ ssl[key[4:]] = opts[key]
+ util.coerce_kw_type(ssl, key[4:], str)
+ del opts[key]
+ if ssl:
+ opts["ssl"] = ssl
+
+ return [[], opts]
+
+ def _extract_error_code(self, exception):
+ return exception.errno
+
+ def _detect_charset(self, connection):
+ """Sniff out the character set in use for connection results."""
+
+ return connection.connection.charset
+
+ def _compat_fetchall(self, rp, charset=None):
+ """oursql isn't super-broken like MySQLdb, yaaay."""
+ return rp.fetchall()
+
+ def _compat_fetchone(self, rp, charset=None):
+ """oursql isn't super-broken like MySQLdb, yaaay."""
+ return rp.fetchone()
+
+ def _compat_first(self, rp, charset=None):
+ return rp.first()
+
+
+dialect = MySQLDialect_oursql
diff --git a/lib/sqlalchemy/dialects/mysql/provision.py b/lib/sqlalchemy/dialects/mysql/provision.py
new file mode 100644
index 0000000..86aaa94
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/provision.py
@@ -0,0 +1,78 @@
+from ... import exc
+from ...testing.provision import configure_follower
+from ...testing.provision import create_db
+from ...testing.provision import drop_db
+from ...testing.provision import generate_driver_url
+from ...testing.provision import temp_table_keyword_args
+
+
+@generate_driver_url.for_db("mysql", "mariadb")
+def generate_driver_url(url, driver, query_str):
+ backend = url.get_backend_name()
+
+ # NOTE: at the moment, tests are running mariadbconnector
+ # against both mariadb and mysql backends. If we want this to be
+ # limited, make the decision here to reject a "mysql+mariadbconnector"
+ # URL. Optionally also re-enable the module level
+ # MySQLDialect_mariadbconnector.is_mysql flag as well, which must include
+ # a unit and/or functional test.
+
+ # all the Jenkins tests have been running the mysqlclient Python library
+ # built against mariadb client drivers for years, against all MySQL /
+ # MariaDB versions going back to MySQL 5.6; currently they can talk
+ # to MySQL databases without problems.
+
+ if backend == "mysql":
+ dialect_cls = url.get_dialect()
+ if dialect_cls._is_mariadb_from_url(url):
+ backend = "mariadb"
+
+ new_url = url.set(
+ drivername="%s+%s" % (backend, driver)
+ ).update_query_string(query_str)
+
+ try:
+ new_url.get_dialect()
+ except exc.NoSuchModuleError:
+ return None
+ else:
+ return new_url
+
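+# e.g. for a hypothetical URL "mysql://scott:tiger@localhost/test" and
+# driver "pymysql", the hook above yields a URL for
+# "mysql+pymysql://scott:tiger@localhost/test", or None if no dialect
+# module can be loaded for that combination.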
+
+@create_db.for_db("mysql", "mariadb")
+def _mysql_create_db(cfg, eng, ident):
+ with eng.begin() as conn:
+ try:
+ _mysql_drop_db(cfg, conn, ident)
+ except Exception:
+ pass
+
+ with eng.begin() as conn:
+ conn.exec_driver_sql(
+ "CREATE DATABASE %s CHARACTER SET utf8mb4" % ident
+ )
+ conn.exec_driver_sql(
+ "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb4" % ident
+ )
+ conn.exec_driver_sql(
+ "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb4" % ident
+ )
+
+
+@configure_follower.for_db("mysql", "mariadb")
+def _mysql_configure_follower(config, ident):
+ config.test_schema = "%s_test_schema" % ident
+ config.test_schema_2 = "%s_test_schema_2" % ident
+
+
+@drop_db.for_db("mysql", "mariadb")
+def _mysql_drop_db(cfg, eng, ident):
+ with eng.begin() as conn:
+ conn.exec_driver_sql("DROP DATABASE %s_test_schema" % ident)
+ conn.exec_driver_sql("DROP DATABASE %s_test_schema_2" % ident)
+ conn.exec_driver_sql("DROP DATABASE %s" % ident)
+
+
+@temp_table_keyword_args.for_db("mysql", "mariadb")
+def _mysql_temp_table_keyword_args(cfg, eng):
+ return {"prefixes": ["TEMPORARY"]}
diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py
new file mode 100644
index 0000000..f620133
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/pymysql.py
@@ -0,0 +1,98 @@
+# mysql/pymysql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+
+.. dialect:: mysql+pymysql
+ :name: PyMySQL
+ :dbapi: pymysql
+ :connectstring: mysql+pymysql://<username>:<password>@<host>/<dbname>[?<options>]
+ :url: https://pymysql.readthedocs.io/
+
+Unicode
+-------
+
+Please see :ref:`mysql_unicode` for current recommendations on unicode
+handling.
+
+.. _pymysql_ssl:
+
+SSL Connections
+------------------
+
+The PyMySQL DBAPI accepts the same SSL arguments as that of MySQLdb,
+described at :ref:`mysqldb_ssl`. See that section for examples.
+
+
+MySQL-Python Compatibility
+--------------------------
+
+The pymysql DBAPI is a pure Python port of the MySQL-python (MySQLdb) driver,
+and targets 100% compatibility. Most behavioral notes for MySQL-python apply
+to the pymysql driver as well.
+
+""" # noqa
+
+from .mysqldb import MySQLDialect_mysqldb
+from ...util import langhelpers
+from ...util import py3k
+
+
+class MySQLDialect_pymysql(MySQLDialect_mysqldb):
+ driver = "pymysql"
+ supports_statement_cache = True
+
+ description_encoding = None
+
+ # generally, these two values should be both True
+ # or both False. PyMySQL unicode tests pass all the way back
+ # to 0.4 either way. See [ticket:3337]
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+
+ @langhelpers.memoized_property
+ def supports_server_side_cursors(self):
+ try:
+ cursors = __import__("pymysql.cursors").cursors
+ self._sscursor = cursors.SSCursor
+ return True
+ except (ImportError, AttributeError):
+ return False
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("pymysql")
+
+ def create_connect_args(self, url, _translate_args=None):
+ if _translate_args is None:
+ _translate_args = dict(username="user")
+ return super(MySQLDialect_pymysql, self).create_connect_args(
+ url, _translate_args=_translate_args
+ )
+
+ def is_disconnect(self, e, connection, cursor):
+ if super(MySQLDialect_pymysql, self).is_disconnect(
+ e, connection, cursor
+ ):
+ return True
+ elif isinstance(e, self.dbapi.Error):
+ str_e = str(e).lower()
+ return (
+ "already closed" in str_e or "connection was killed" in str_e
+ )
+ else:
+ return False
+
+ if py3k:
+
+ def _extract_error_code(self, exception):
+ if isinstance(exception.args[0], Exception):
+ exception = exception.args[0]
+ return exception.args[0]
+
+
+dialect = MySQLDialect_pymysql
diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py
new file mode 100644
index 0000000..bfa61f6
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py
@@ -0,0 +1,136 @@
+# mysql/pyodbc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+
+
+.. dialect:: mysql+pyodbc
+ :name: PyODBC
+ :dbapi: pyodbc
+ :connectstring: mysql+pyodbc://<username>:<password>@<dsnname>
+ :url: https://pypi.org/project/pyodbc/
+
+.. note::
+
+ The PyODBC for MySQL dialect is **not tested as part of
+ SQLAlchemy's continuous integration**.
+ The recommended MySQL dialects are mysqlclient and PyMySQL.
+ However, if you want to use the mysql+pyodbc dialect and require
+ full support for ``utf8mb4`` characters (including supplementary
+ characters like emoji) be sure to use a current release of
+ MySQL Connector/ODBC and specify the "ANSI" (**not** "Unicode")
+ version of the driver in your DSN or connection string.
+
+Pass through exact pyodbc connection string::
+
+ import urllib
+ connection_string = (
+ 'DRIVER=MySQL ODBC 8.0 ANSI Driver;'
+ 'SERVER=localhost;'
+ 'PORT=3307;'
+ 'DATABASE=mydb;'
+ 'UID=root;'
+ 'PWD=(whatever);'
+ 'charset=utf8mb4;'
+ )
+ params = urllib.parse.quote_plus(connection_string)
+ connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params
+
+""" # noqa
+
+import re
+
+from .base import MySQLDialect
+from .base import MySQLExecutionContext
+from .types import TIME
+from ... import exc
+from ... import util
+from ...connectors.pyodbc import PyODBCConnector
+from ...sql.sqltypes import Time
+
+
+class _pyodbcTIME(TIME):
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ # pyodbc returns a datetime.time object; no need to convert
+ return value
+
+ return process
+
+
+class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
+ def get_lastrowid(self):
+ cursor = self.create_cursor()
+ cursor.execute("SELECT LAST_INSERT_ID()")
+ lastrowid = cursor.fetchone()[0]
+ cursor.close()
+ return lastrowid
+
+
+class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
+ supports_statement_cache = True
+ colspecs = util.update_copy(MySQLDialect.colspecs, {Time: _pyodbcTIME})
+ supports_unicode_statements = True
+ execution_ctx_cls = MySQLExecutionContext_pyodbc
+
+ pyodbc_driver_name = "MySQL"
+
+ def _detect_charset(self, connection):
+ """Sniff out the character set in use for connection results."""
+
+ # Prefer 'character_set_results' for the current connection over the
+ # value in the driver. SET NAMES or individual variable SETs will
+ # change the charset without updating the driver's view of the world.
+ #
+ # If it's decided that issuing that sort of SQL leaves you SOL, then
+ # this can prefer the driver value.
+
+ # set this to None as _fetch_setting attempts to use it (None is OK)
+ self._connection_charset = None
+ try:
+ value = self._fetch_setting(connection, "character_set_client")
+ if value:
+ return value
+ except exc.DBAPIError:
+ pass
+
+ util.warn(
+ "Could not detect the connection character set. "
+ "Assuming latin1."
+ )
+ return "latin1"
+
+ def _get_server_version_info(self, connection):
+ return MySQLDialect._get_server_version_info(self, connection)
+
+ def _extract_error_code(self, exception):
+ m = re.compile(r"\((\d+)\)").search(str(exception.args))
+ if m:
+ return int(m.group(1))
+ else:
+ return None
+
+ def on_connect(self):
+ super_ = super(MySQLDialect_pyodbc, self).on_connect()
+
+ def on_connect(conn):
+ if super_ is not None:
+ super_(conn)
+
+ # declare Unicode encoding for pyodbc as per
+ # https://github.com/mkleehammer/pyodbc/wiki/Unicode
+ pyodbc_SQL_CHAR = 1 # pyodbc.SQL_CHAR
+ pyodbc_SQL_WCHAR = -8 # pyodbc.SQL_WCHAR
+ conn.setdecoding(pyodbc_SQL_CHAR, encoding="utf-8")
+ conn.setdecoding(pyodbc_SQL_WCHAR, encoding="utf-8")
+ conn.setencoding(encoding="utf-8")
+
+ return on_connect
+
+
+dialect = MySQLDialect_pyodbc
diff --git a/lib/sqlalchemy/dialects/mysql/reflection.py b/lib/sqlalchemy/dialects/mysql/reflection.py
new file mode 100644
index 0000000..27394bb
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/reflection.py
@@ -0,0 +1,558 @@
+# mysql/reflection.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import re
+
+from .enumerated import ENUM
+from .enumerated import SET
+from .types import DATETIME
+from .types import TIME
+from .types import TIMESTAMP
+from ... import log
+from ... import types as sqltypes
+from ... import util
+
+
+class ReflectedState(object):
+ """Stores raw information about a SHOW CREATE TABLE statement."""
+
+ def __init__(self):
+ self.columns = []
+ self.table_options = {}
+ self.table_name = None
+ self.keys = []
+ self.fk_constraints = []
+ self.ck_constraints = []
+
+
+@log.class_logger
+class MySQLTableDefinitionParser(object):
+ """Parses the results of a SHOW CREATE TABLE statement."""
+
+ def __init__(self, dialect, preparer):
+ self.dialect = dialect
+ self.preparer = preparer
+ self._prep_regexes()
+
+ def parse(self, show_create, charset):
+ state = ReflectedState()
+ state.charset = charset
+ for line in re.split(r"\r?\n", show_create):
+ if line.startswith(" " + self.preparer.initial_quote):
+ self._parse_column(line, state)
+ # a regular table options line
+ elif line.startswith(") "):
+ self._parse_table_options(line, state)
+ # an ANSI-mode table options line
+ elif line == ")":
+ pass
+ elif line.startswith("CREATE "):
+ self._parse_table_name(line, state)
+ # Not present in real reflection, but may be if
+ # loading from a file.
+ elif not line:
+ pass
+ else:
+ type_, spec = self._parse_constraints(line)
+ if type_ is None:
+ util.warn("Unknown schema content: %r" % line)
+ elif type_ == "key":
+ state.keys.append(spec)
+ elif type_ == "fk_constraint":
+ state.fk_constraints.append(spec)
+ elif type_ == "ck_constraint":
+ state.ck_constraints.append(spec)
+ else:
+ pass
+ return state
+
+ def _parse_constraints(self, line):
+ """Parse a KEY or CONSTRAINT line.
+
+ :param line: A line of SHOW CREATE TABLE output
+ """
+
+ # KEY
+ m = self._re_key.match(line)
+ if m:
+ spec = m.groupdict()
+ # convert columns into name, length pairs
+ # NOTE: we may want to consider SHOW INDEX as the
+ # format of indexes in MySQL becomes more complex
+ spec["columns"] = self._parse_keyexprs(spec["columns"])
+ if spec["version_sql"]:
+ m2 = self._re_key_version_sql.match(spec["version_sql"])
+ if m2 and m2.groupdict()["parser"]:
+ spec["parser"] = m2.groupdict()["parser"]
+ if spec["parser"]:
+ spec["parser"] = self.preparer.unformat_identifiers(
+ spec["parser"]
+ )[0]
+ return "key", spec
+
+ # FOREIGN KEY CONSTRAINT
+ m = self._re_fk_constraint.match(line)
+ if m:
+ spec = m.groupdict()
+ spec["table"] = self.preparer.unformat_identifiers(spec["table"])
+ spec["local"] = [c[0] for c in self._parse_keyexprs(spec["local"])]
+ spec["foreign"] = [
+ c[0] for c in self._parse_keyexprs(spec["foreign"])
+ ]
+ return "fk_constraint", spec
+
+ # CHECK constraint
+ m = self._re_ck_constraint.match(line)
+ if m:
+ spec = m.groupdict()
+ return "ck_constraint", spec
+
+ # PARTITION and SUBPARTITION
+ m = self._re_partition.match(line)
+ if m:
+ # Punt!
+ return "partition", line
+
+ # No match.
+ return (None, line)
+
+ def _parse_table_name(self, line, state):
+ """Extract the table name.
+
+ :param line: The first line of SHOW CREATE TABLE
+ """
+
+ regex, cleanup = self._pr_name
+ m = regex.match(line)
+ if m:
+ state.table_name = cleanup(m.group("name"))
+
+ def _parse_table_options(self, line, state):
+ """Build a dictionary of all reflected table-level options.
+
+ :param line: The final line of SHOW CREATE TABLE output.
+ """
+
+ options = {}
+
+ if not line or line == ")":
+ pass
+
+ else:
+ rest_of_line = line[:]
+ for regex, cleanup in self._pr_options:
+ m = regex.search(rest_of_line)
+ if not m:
+ continue
+ directive, value = m.group("directive"), m.group("val")
+ if cleanup:
+ value = cleanup(value)
+ options[directive.lower()] = value
+ rest_of_line = regex.sub("", rest_of_line)
+
+ for nope in ("auto_increment", "data directory", "index directory"):
+ options.pop(nope, None)
+
+ for opt, val in options.items():
+ state.table_options["%s_%s" % (self.dialect.name, opt)] = val
+
+ def _parse_column(self, line, state):
+ """Extract column details.
+
+ Falls back to a 'minimal support' variant if full parse fails.
+
+ :param line: Any column-bearing line from SHOW CREATE TABLE
+ """
+
+ spec = None
+ m = self._re_column.match(line)
+ if m:
+ spec = m.groupdict()
+ spec["full"] = True
+ else:
+ m = self._re_column_loose.match(line)
+ if m:
+ spec = m.groupdict()
+ spec["full"] = False
+ if not spec:
+ util.warn("Unknown column definition %r" % line)
+ return
+ if not spec["full"]:
+ util.warn("Incomplete reflection of column definition %r" % line)
+
+ name, type_, args = spec["name"], spec["coltype"], spec["arg"]
+
+ try:
+ col_type = self.dialect.ischema_names[type_]
+ except KeyError:
+ util.warn(
+ "Did not recognize type '%s' of column '%s'" % (type_, name)
+ )
+ col_type = sqltypes.NullType
+
+ # Column type positional arguments eg. varchar(32)
+ if args is None or args == "":
+ type_args = []
+ elif args[0] == "'" and args[-1] == "'":
+ type_args = self._re_csv_str.findall(args)
+ else:
+ type_args = [int(v) for v in self._re_csv_int.findall(args)]
+
+ # Column type keyword options
+ type_kw = {}
+
+ if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)):
+ if type_args:
+ type_kw["fsp"] = type_args.pop(0)
+
+ for kw in ("unsigned", "zerofill"):
+ if spec.get(kw, False):
+ type_kw[kw] = True
+ for kw in ("charset", "collate"):
+ if spec.get(kw, False):
+ type_kw[kw] = spec[kw]
+ if issubclass(col_type, (ENUM, SET)):
+ type_args = _strip_values(type_args)
+
+ if issubclass(col_type, SET) and "" in type_args:
+ type_kw["retrieve_as_bitwise"] = True
+
+ type_instance = col_type(*type_args, **type_kw)
+
+ col_kw = {}
+
+ # NOT NULL
+ col_kw["nullable"] = True
+ # this can be "NULL" in the case of TIMESTAMP
+ if spec.get("notnull", False) == "NOT NULL":
+ col_kw["nullable"] = False
+
+ # AUTO_INCREMENT
+ if spec.get("autoincr", False):
+ col_kw["autoincrement"] = True
+ elif issubclass(col_type, sqltypes.Integer):
+ col_kw["autoincrement"] = False
+
+ # DEFAULT
+ default = spec.get("default", None)
+
+ if default == "NULL":
+ # eliminates the need to deal with this later.
+ default = None
+
+ comment = spec.get("comment", None)
+
+ if comment is not None:
+ comment = comment.replace("\\\\", "\\").replace("''", "'")
+
+ sqltext = spec.get("generated")
+ if sqltext is not None:
+ computed = dict(sqltext=sqltext)
+ persisted = spec.get("persistence")
+ if persisted is not None:
+ computed["persisted"] = persisted == "STORED"
+ col_kw["computed"] = computed
+
+ col_d = dict(
+ name=name, type=type_instance, default=default, comment=comment
+ )
+ col_d.update(col_kw)
+ state.columns.append(col_d)
+
+ def _describe_to_create(self, table_name, columns):
+ """Re-format DESCRIBE output as a SHOW CREATE TABLE string.
+
+ DESCRIBE is a much simpler reflection and is sufficient for
+ reflecting views for runtime use. This method formats DDL
+        for columns only; keys are omitted.
+
+ :param columns: A sequence of DESCRIBE or SHOW COLUMNS 6-tuples.
+ SHOW FULL COLUMNS FROM rows must be rearranged for use with
+ this function.
+ """
+
+ buffer = []
+ for row in columns:
+ (name, col_type, nullable, default, extra) = [
+ row[i] for i in (0, 1, 2, 4, 5)
+ ]
+
+ line = [" "]
+ line.append(self.preparer.quote_identifier(name))
+ line.append(col_type)
+ if not nullable:
+ line.append("NOT NULL")
+ if default:
+ if "auto_increment" in default:
+ pass
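+                # "C" is assumed to match CURRENT_TIMESTAMP-style defaults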
+ elif col_type.startswith("timestamp") and default.startswith(
+ "C"
+ ):
+ line.append("DEFAULT")
+ line.append(default)
+ elif default == "NULL":
+ line.append("DEFAULT")
+ line.append(default)
+ else:
+ line.append("DEFAULT")
+ line.append("'%s'" % default.replace("'", "''"))
+ if extra:
+ line.append(extra)
+
+ buffer.append(" ".join(line))
+
+ return "".join(
+ [
+ (
+ "CREATE TABLE %s (\n"
+ % self.preparer.quote_identifier(table_name)
+ ),
+ ",\n".join(buffer),
+ "\n) ",
+ ]
+ )
+
+ def _parse_keyexprs(self, identifiers):
+ """Unpack '"col"(2),"col" ASC'-ish strings into components."""
+
+ return self._re_keyexprs.findall(identifiers)
+
+ def _prep_regexes(self):
+ """Pre-compile regular expressions."""
+
+ self._re_columns = []
+ self._pr_options = []
+
+ _final = self.preparer.final_quote
+
+ quotes = dict(
+ zip(
+ ("iq", "fq", "esc_fq"),
+ [
+ re.escape(s)
+ for s in (
+ self.preparer.initial_quote,
+ _final,
+ self.preparer._escape_identifier(_final),
+ )
+ ],
+ )
+ )
+
+ self._pr_name = _pr_compile(
+ r"^CREATE (?:\w+ +)?TABLE +"
+ r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($" % quotes,
+ self.preparer._unescape_identifier,
+ )
+
+ # `col`,`col2`(32),`col3`(15) DESC
+ #
+ self._re_keyexprs = _re_compile(
+ r"(?:"
+ r"(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)"
+ r"(?:\((\d+)\))?(?: +(ASC|DESC))?(?=\,|$))+" % quotes
+ )
+
+ # 'foo' or 'foo','bar' or 'fo,o','ba''a''r'
+ self._re_csv_str = _re_compile(r"\x27(?:\x27\x27|[^\x27])*\x27")
+
+ # 123 or 123,456
+ self._re_csv_int = _re_compile(r"\d+")
+
+ # `colname` <type> [type opts]
+ # (NOT NULL | NULL)
+ # DEFAULT ('value' | CURRENT_TIMESTAMP...)
+ # COMMENT 'comment'
+ # COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT)
+ # STORAGE (DISK|MEMORY)
+ self._re_column = _re_compile(
+ r" "
+ r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
+ r"(?P<coltype>\w+)"
+ r"(?:\((?P<arg>(?:\d+|\d+,\d+|"
+ r"(?:'(?:''|[^'])*',?)+))\))?"
+ r"(?: +(?P<unsigned>UNSIGNED))?"
+ r"(?: +(?P<zerofill>ZEROFILL))?"
+ r"(?: +CHARACTER SET +(?P<charset>[\w_]+))?"
+ r"(?: +COLLATE +(?P<collate>[\w_]+))?"
+ r"(?: +(?P<notnull>(?:NOT )?NULL))?"
+ r"(?: +DEFAULT +(?P<default>"
+ r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+"
+ r"(?: +ON UPDATE [\-\w\.\(\)]+)?)"
+ r"))?"
+ r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P<generated>\("
+ r".*\))? ?(?P<persistence>VIRTUAL|STORED)?)?"
+ r"(?: +(?P<autoincr>AUTO_INCREMENT))?"
+ r"(?: +COMMENT +'(?P<comment>(?:''|[^'])*)')?"
+ r"(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?"
+ r"(?: +STORAGE +(?P<storage>\w+))?"
+ r"(?: +(?P<extra>.*))?"
+ r",?$" % quotes
+ )
+
+ # Fallback, try to parse as little as possible
+ self._re_column_loose = _re_compile(
+ r" "
+ r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
+ r"(?P<coltype>\w+)"
+ r"(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?"
+ r".*?(?P<notnull>(?:NOT )NULL)?" % quotes
+ )
+
+ # (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))?
+ # (`col` (ASC|DESC)?, `col` (ASC|DESC)?)
+ # KEY_BLOCK_SIZE size | WITH PARSER name /*!50100 WITH PARSER name */
+ self._re_key = _re_compile(
+ r" "
+ r"(?:(?P<type>\S+) )?KEY"
+ r"(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?"
+ r"(?: +USING +(?P<using_pre>\S+))?"
+ r" +\((?P<columns>.+?)\)"
+ r"(?: +USING +(?P<using_post>\S+))?"
+ r"(?: +KEY_BLOCK_SIZE *[ =]? *(?P<keyblock>\S+))?"
+ r"(?: +WITH PARSER +(?P<parser>\S+))?"
+ r"(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?"
+ r"(?: +/\*(?P<version_sql>.+)\*/ *)?"
+ r",?$" % quotes
+ )
+
+ # https://forums.mysql.com/read.php?20,567102,567111#msg-567111
+ # It means if the MySQL version >= \d+, execute what's in the comment
+ self._re_key_version_sql = _re_compile(
+ r"\!\d+ " r"(?: *WITH PARSER +(?P<parser>\S+) *)?"
+ )
+
+ # CONSTRAINT `name` FOREIGN KEY (`local_col`)
+ # REFERENCES `remote` (`remote_col`)
+ # MATCH FULL | MATCH PARTIAL | MATCH SIMPLE
+ # ON DELETE CASCADE ON UPDATE RESTRICT
+ #
+ # unique constraints come back as KEYs
+ kw = quotes.copy()
+ kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION"
+ self._re_fk_constraint = _re_compile(
+ r" "
+ r"CONSTRAINT +"
+ r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
+ r"FOREIGN KEY +"
+ r"\((?P<local>[^\)]+?)\) REFERENCES +"
+ r"(?P<table>%(iq)s[^%(fq)s]+%(fq)s"
+ r"(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +"
+ r"\((?P<foreign>[^\)]+?)\)"
+ r"(?: +(?P<match>MATCH \w+))?"
+ r"(?: +ON DELETE (?P<ondelete>%(on)s))?"
+ r"(?: +ON UPDATE (?P<onupdate>%(on)s))?" % kw
+ )
+
+ # CONSTRAINT `CONSTRAINT_1` CHECK (`x` > 5)'
+ # testing on MariaDB 10.2 shows that the CHECK constraint
+ # is returned on a line by itself, so to match without worrying
+ # about parenthesis in the expression we go to the end of the line
+ self._re_ck_constraint = _re_compile(
+ r" "
+ r"CONSTRAINT +"
+ r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
+ r"CHECK +"
+ r"\((?P<sqltext>.+)\),?" % kw
+ )
+
+ # PARTITION
+ #
+ # punt!
+ self._re_partition = _re_compile(r"(?:.*)(?:SUB)?PARTITION(?:.*)")
+
+ # Table-level options (COLLATE, ENGINE, etc.)
+ # Do the string options first, since they have quoted
+ # strings we need to get rid of.
+ for option in _options_of_type_string:
+ self._add_option_string(option)
+
+ for option in (
+ "ENGINE",
+ "TYPE",
+ "AUTO_INCREMENT",
+ "AVG_ROW_LENGTH",
+ "CHARACTER SET",
+ "DEFAULT CHARSET",
+ "CHECKSUM",
+ "COLLATE",
+ "DELAY_KEY_WRITE",
+ "INSERT_METHOD",
+ "MAX_ROWS",
+ "MIN_ROWS",
+ "PACK_KEYS",
+ "ROW_FORMAT",
+ "KEY_BLOCK_SIZE",
+ ):
+ self._add_option_word(option)
+
+ self._add_option_regex("UNION", r"\([^\)]+\)")
+ self._add_option_regex("TABLESPACE", r".*? STORAGE DISK")
+ self._add_option_regex(
+ "RAID_TYPE",
+ r"\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+",
+ )
+
+ _optional_equals = r"(?:\s*(?:=\s*)|\s+)"
+
+ def _add_option_string(self, directive):
+ regex = r"(?P<directive>%s)%s" r"'(?P<val>(?:[^']|'')*?)'(?!')" % (
+ re.escape(directive),
+ self._optional_equals,
+ )
+ self._pr_options.append(
+ _pr_compile(
+ regex, lambda v: v.replace("\\\\", "\\").replace("''", "'")
+ )
+ )
+
+ def _add_option_word(self, directive):
+ regex = r"(?P<directive>%s)%s" r"(?P<val>\w+)" % (
+ re.escape(directive),
+ self._optional_equals,
+ )
+ self._pr_options.append(_pr_compile(regex))
+
+ def _add_option_regex(self, directive, regex):
+ regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % (
+ re.escape(directive),
+ self._optional_equals,
+ regex,
+ )
+ self._pr_options.append(_pr_compile(regex))
+
+
+_options_of_type_string = (
+ "COMMENT",
+ "DATA DIRECTORY",
+ "INDEX DIRECTORY",
+ "PASSWORD",
+ "CONNECTION",
+)
+
+
+def _pr_compile(regex, cleanup=None):
+ """Prepare a 2-tuple of compiled regex and callable."""
+
+ return (_re_compile(regex), cleanup)
+
+
+def _re_compile(regex):
+ """Compile a string to regex, I and UNICODE."""
+
+ return re.compile(regex, re.I | re.UNICODE)
+
+
+def _strip_values(values):
+ "Strip reflected values quotes"
+ strip_values = []
+ for a in values:
+ if a[0:1] == '"' or a[0:1] == "'":
+ # strip enclosing quotes and unquote interior
+ a = a[1:-1].replace(a[0] * 2, a[0])
+ strip_values.append(a)
+ return strip_values
diff --git a/lib/sqlalchemy/dialects/mysql/reserved_words.py b/lib/sqlalchemy/dialects/mysql/reserved_words.py
new file mode 100644
index 0000000..995168b
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/reserved_words.py
@@ -0,0 +1,564 @@
+# mysql/reserved_words.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+# generated using:
+# https://gist.github.com/kkirsche/4f31f2153ed7a3248be1ec44ca6ddbc9
+#
+# https://mariadb.com/kb/en/reserved-words/
+# includes: Reserved Words, Oracle Mode (separate set unioned)
+# excludes: Exceptions, Function Names
+RESERVED_WORDS_MARIADB = {
+ "accessible",
+ "add",
+ "all",
+ "alter",
+ "analyze",
+ "and",
+ "as",
+ "asc",
+ "asensitive",
+ "before",
+ "between",
+ "bigint",
+ "binary",
+ "blob",
+ "both",
+ "by",
+ "call",
+ "cascade",
+ "case",
+ "change",
+ "char",
+ "character",
+ "check",
+ "collate",
+ "column",
+ "condition",
+ "constraint",
+ "continue",
+ "convert",
+ "create",
+ "cross",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "cursor",
+ "database",
+ "databases",
+ "day_hour",
+ "day_microsecond",
+ "day_minute",
+ "day_second",
+ "dec",
+ "decimal",
+ "declare",
+ "default",
+ "delayed",
+ "delete",
+ "desc",
+ "describe",
+ "deterministic",
+ "distinct",
+ "distinctrow",
+ "div",
+ "do_domain_ids",
+ "double",
+ "drop",
+ "dual",
+ "each",
+ "else",
+ "elseif",
+ "enclosed",
+ "escaped",
+ "except",
+ "exists",
+ "exit",
+ "explain",
+ "false",
+ "fetch",
+ "float",
+ "float4",
+ "float8",
+ "for",
+ "force",
+ "foreign",
+ "from",
+ "fulltext",
+ "general",
+ "grant",
+ "group",
+ "having",
+ "high_priority",
+ "hour_microsecond",
+ "hour_minute",
+ "hour_second",
+ "if",
+ "ignore",
+ "ignore_domain_ids",
+ "ignore_server_ids",
+ "in",
+ "index",
+ "infile",
+ "inner",
+ "inout",
+ "insensitive",
+ "insert",
+ "int",
+ "int1",
+ "int2",
+ "int3",
+ "int4",
+ "int8",
+ "integer",
+ "intersect",
+ "interval",
+ "into",
+ "is",
+ "iterate",
+ "join",
+ "key",
+ "keys",
+ "kill",
+ "leading",
+ "leave",
+ "left",
+ "like",
+ "limit",
+ "linear",
+ "lines",
+ "load",
+ "localtime",
+ "localtimestamp",
+ "lock",
+ "long",
+ "longblob",
+ "longtext",
+ "loop",
+ "low_priority",
+ "master_heartbeat_period",
+ "master_ssl_verify_server_cert",
+ "match",
+ "maxvalue",
+ "mediumblob",
+ "mediumint",
+ "mediumtext",
+ "middleint",
+ "minute_microsecond",
+ "minute_second",
+ "mod",
+ "modifies",
+ "natural",
+ "no_write_to_binlog",
+ "not",
+ "null",
+ "numeric",
+ "offset",
+ "on",
+ "optimize",
+ "option",
+ "optionally",
+ "or",
+ "order",
+ "out",
+ "outer",
+ "outfile",
+ "over",
+ "page_checksum",
+ "parse_vcol_expr",
+ "partition",
+ "position",
+ "precision",
+ "primary",
+ "procedure",
+ "purge",
+ "range",
+ "read",
+ "read_write",
+ "reads",
+ "real",
+ "recursive",
+ "ref_system_id",
+ "references",
+ "regexp",
+ "release",
+ "rename",
+ "repeat",
+ "replace",
+ "require",
+ "resignal",
+ "restrict",
+ "return",
+ "returning",
+ "revoke",
+ "right",
+ "rlike",
+ "rows",
+ "schema",
+ "schemas",
+ "second_microsecond",
+ "select",
+ "sensitive",
+ "separator",
+ "set",
+ "show",
+ "signal",
+ "slow",
+ "smallint",
+ "spatial",
+ "specific",
+ "sql",
+ "sql_big_result",
+ "sql_calc_found_rows",
+ "sql_small_result",
+ "sqlexception",
+ "sqlstate",
+ "sqlwarning",
+ "ssl",
+ "starting",
+ "stats_auto_recalc",
+ "stats_persistent",
+ "stats_sample_pages",
+ "straight_join",
+ "table",
+ "terminated",
+ "then",
+ "tinyblob",
+ "tinyint",
+ "tinytext",
+ "to",
+ "trailing",
+ "trigger",
+ "true",
+ "undo",
+ "union",
+ "unique",
+ "unlock",
+ "unsigned",
+ "update",
+ "usage",
+ "use",
+ "using",
+ "utc_date",
+ "utc_time",
+ "utc_timestamp",
+ "values",
+ "varbinary",
+ "varchar",
+ "varcharacter",
+ "varying",
+ "when",
+ "where",
+ "while",
+ "window",
+ "with",
+ "write",
+ "xor",
+ "year_month",
+ "zerofill",
+}.union(
+ {
+ "body",
+ "elsif",
+ "goto",
+ "history",
+ "others",
+ "package",
+ "period",
+ "raise",
+ "rowtype",
+ "system",
+ "system_time",
+ "versioning",
+ "without",
+ }
+)
+
+# https://dev.mysql.com/doc/refman/8.0/en/keywords.html
+# https://dev.mysql.com/doc/refman/5.7/en/keywords.html
+# https://dev.mysql.com/doc/refman/5.6/en/keywords.html
+# includes: MySQL x.0 Keywords and Reserved Words
+# excludes: MySQL x.0 New Keywords and Reserved Words,
+# MySQL x.0 Removed Keywords and Reserved Words
+RESERVED_WORDS_MYSQL = {
+ "accessible",
+ "add",
+ "admin",
+ "all",
+ "alter",
+ "analyze",
+ "and",
+ "array",
+ "as",
+ "asc",
+ "asensitive",
+ "before",
+ "between",
+ "bigint",
+ "binary",
+ "blob",
+ "both",
+ "by",
+ "call",
+ "cascade",
+ "case",
+ "change",
+ "char",
+ "character",
+ "check",
+ "collate",
+ "column",
+ "condition",
+ "constraint",
+ "continue",
+ "convert",
+ "create",
+ "cross",
+ "cube",
+ "cume_dist",
+ "current_date",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "cursor",
+ "database",
+ "databases",
+ "day_hour",
+ "day_microsecond",
+ "day_minute",
+ "day_second",
+ "dec",
+ "decimal",
+ "declare",
+ "default",
+ "delayed",
+ "delete",
+ "dense_rank",
+ "desc",
+ "describe",
+ "deterministic",
+ "distinct",
+ "distinctrow",
+ "div",
+ "double",
+ "drop",
+ "dual",
+ "each",
+ "else",
+ "elseif",
+ "empty",
+ "enclosed",
+ "escaped",
+ "except",
+ "exists",
+ "exit",
+ "explain",
+ "false",
+ "fetch",
+ "first_value",
+ "float",
+ "float4",
+ "float8",
+ "for",
+ "force",
+ "foreign",
+ "from",
+ "fulltext",
+ "function",
+ "general",
+ "generated",
+ "get",
+ "get_master_public_key",
+ "grant",
+ "group",
+ "grouping",
+ "groups",
+ "having",
+ "high_priority",
+ "hour_microsecond",
+ "hour_minute",
+ "hour_second",
+ "if",
+ "ignore",
+ "ignore_server_ids",
+ "in",
+ "index",
+ "infile",
+ "inner",
+ "inout",
+ "insensitive",
+ "insert",
+ "int",
+ "int1",
+ "int2",
+ "int3",
+ "int4",
+ "int8",
+ "integer",
+ "interval",
+ "into",
+ "io_after_gtids",
+ "io_before_gtids",
+ "is",
+ "iterate",
+ "join",
+ "json_table",
+ "key",
+ "keys",
+ "kill",
+ "lag",
+ "last_value",
+ "lateral",
+ "lead",
+ "leading",
+ "leave",
+ "left",
+ "like",
+ "limit",
+ "linear",
+ "lines",
+ "load",
+ "localtime",
+ "localtimestamp",
+ "lock",
+ "long",
+ "longblob",
+ "longtext",
+ "loop",
+ "low_priority",
+ "master_bind",
+ "master_heartbeat_period",
+ "master_ssl_verify_server_cert",
+ "match",
+ "maxvalue",
+ "mediumblob",
+ "mediumint",
+ "mediumtext",
+ "member",
+ "middleint",
+ "minute_microsecond",
+ "minute_second",
+ "mod",
+ "modifies",
+ "natural",
+ "no_write_to_binlog",
+ "not",
+ "nth_value",
+ "ntile",
+ "null",
+ "numeric",
+ "of",
+ "on",
+ "optimize",
+ "optimizer_costs",
+ "option",
+ "optionally",
+ "or",
+ "order",
+ "out",
+ "outer",
+ "outfile",
+ "over",
+ "parse_gcol_expr",
+ "partition",
+ "percent_rank",
+ "persist",
+ "persist_only",
+ "precision",
+ "primary",
+ "procedure",
+ "purge",
+ "range",
+ "rank",
+ "read",
+ "read_write",
+ "reads",
+ "real",
+ "recursive",
+ "references",
+ "regexp",
+ "release",
+ "rename",
+ "repeat",
+ "replace",
+ "require",
+ "resignal",
+ "restrict",
+ "return",
+ "revoke",
+ "right",
+ "rlike",
+ "role",
+ "row",
+ "row_number",
+ "rows",
+ "schema",
+ "schemas",
+ "second_microsecond",
+ "select",
+ "sensitive",
+ "separator",
+ "set",
+ "show",
+ "signal",
+ "slow",
+ "smallint",
+ "spatial",
+ "specific",
+ "sql",
+ "sql_after_gtids",
+ "sql_before_gtids",
+ "sql_big_result",
+ "sql_calc_found_rows",
+ "sql_small_result",
+ "sqlexception",
+ "sqlstate",
+ "sqlwarning",
+ "ssl",
+ "starting",
+ "stored",
+ "straight_join",
+ "system",
+ "table",
+ "terminated",
+ "then",
+ "tinyblob",
+ "tinyint",
+ "tinytext",
+ "to",
+ "trailing",
+ "trigger",
+ "true",
+ "undo",
+ "union",
+ "unique",
+ "unlock",
+ "unsigned",
+ "update",
+ "usage",
+ "use",
+ "using",
+ "utc_date",
+ "utc_time",
+ "utc_timestamp",
+ "values",
+ "varbinary",
+ "varchar",
+ "varcharacter",
+ "varying",
+ "virtual",
+ "when",
+ "where",
+ "while",
+ "window",
+ "with",
+ "write",
+ "xor",
+ "year_month",
+ "zerofill",
+}
diff --git a/lib/sqlalchemy/dialects/mysql/types.py b/lib/sqlalchemy/dialects/mysql/types.py
new file mode 100644
index 0000000..b81ee95
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/types.py
@@ -0,0 +1,773 @@
+# mysql/types.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import datetime
+
+from ... import exc
+from ... import types as sqltypes
+from ... import util
+
+
+class _NumericType(object):
+ """Base for MySQL numeric types.
+
+    This is the base for both NUMERIC and INTEGER, hence
+ it's a mixin.
+
+ """
+
+ def __init__(self, unsigned=False, zerofill=False, **kw):
+ self.unsigned = unsigned
+ self.zerofill = zerofill
+ super(_NumericType, self).__init__(**kw)
+
+ def __repr__(self):
+ return util.generic_repr(
+ self, to_inspect=[_NumericType, sqltypes.Numeric]
+ )
+
+
+class _FloatType(_NumericType, sqltypes.Float):
+ def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ if isinstance(self, (REAL, DOUBLE)) and (
+ (precision is None and scale is not None)
+ or (precision is not None and scale is None)
+ ):
+ raise exc.ArgumentError(
+ "You must specify both precision and scale or omit "
+ "both altogether."
+ )
+ super(_FloatType, self).__init__(
+ precision=precision, asdecimal=asdecimal, **kw
+ )
+ self.scale = scale
+
+ def __repr__(self):
+ return util.generic_repr(
+ self, to_inspect=[_FloatType, _NumericType, sqltypes.Float]
+ )
+
+
+class _IntegerType(_NumericType, sqltypes.Integer):
+ def __init__(self, display_width=None, **kw):
+ self.display_width = display_width
+ super(_IntegerType, self).__init__(**kw)
+
+ def __repr__(self):
+ return util.generic_repr(
+ self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer]
+ )
+
+
+class _StringType(sqltypes.String):
+ """Base for MySQL string types."""
+
+ def __init__(
+ self,
+ charset=None,
+ collation=None,
+ ascii=False, # noqa
+ binary=False,
+ unicode=False,
+ national=False,
+ **kw
+ ):
+ self.charset = charset
+
+ # allow collate= or collation=
+ kw.setdefault("collation", kw.pop("collate", collation))
+
+ self.ascii = ascii
+ self.unicode = unicode
+ self.binary = binary
+ self.national = national
+ super(_StringType, self).__init__(**kw)
+
+ def __repr__(self):
+ return util.generic_repr(
+ self, to_inspect=[_StringType, sqltypes.String]
+ )
+
+
+class _MatchType(sqltypes.Float, sqltypes.MatchType):
+ def __init__(self, **kw):
+ # TODO: float arguments?
+ sqltypes.Float.__init__(self)
+ sqltypes.MatchType.__init__(self)
+
+
+class NUMERIC(_NumericType, sqltypes.NUMERIC):
+ """MySQL NUMERIC type."""
+
+ __visit_name__ = "NUMERIC"
+
+ def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ """Construct a NUMERIC.
+
+ :param precision: Total digits in this number. If scale and precision
+ are both None, values are stored to limits allowed by the server.
+
+ :param scale: The number of digits after the decimal point.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+      left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(NUMERIC, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
+
+
+class DECIMAL(_NumericType, sqltypes.DECIMAL):
+ """MySQL DECIMAL type."""
+
+ __visit_name__ = "DECIMAL"
+
+ def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ """Construct a DECIMAL.
+
+ :param precision: Total digits in this number. If scale and precision
+ are both None, values are stored to limits allowed by the server.
+
+ :param scale: The number of digits after the decimal point.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+      left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(DECIMAL, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
+
+
+class DOUBLE(_FloatType):
+ """MySQL DOUBLE type."""
+
+ __visit_name__ = "DOUBLE"
+
+ def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ """Construct a DOUBLE.
+
+ .. note::
+
+ The :class:`.DOUBLE` type by default converts from float
+ to Decimal, using a truncation that defaults to 10 digits.
+ Specify either ``scale=n`` or ``decimal_return_scale=n`` in order
+ to change this scale, or ``asdecimal=False`` to return values
+ directly as Python floating points.
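+
+        For example, a brief sketch that returns plain Python floats rather
+        than ``Decimal`` objects (the column name is illustrative)::
+
+            from sqlalchemy import Column
+            from sqlalchemy.dialects.mysql import DOUBLE
+
+            value = Column("value", DOUBLE(asdecimal=False))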
+
+ :param precision: Total digits in this number. If scale and precision
+ are both None, values are stored to limits allowed by the server.
+
+ :param scale: The number of digits after the decimal point.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+      left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(DOUBLE, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
+
+
+class REAL(_FloatType, sqltypes.REAL):
+ """MySQL REAL type."""
+
+ __visit_name__ = "REAL"
+
+ def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ """Construct a REAL.
+
+ .. note::
+
+ The :class:`.REAL` type by default converts from float
+ to Decimal, using a truncation that defaults to 10 digits.
+ Specify either ``scale=n`` or ``decimal_return_scale=n`` in order
+ to change this scale, or ``asdecimal=False`` to return values
+ directly as Python floating points.
+
+ :param precision: Total digits in this number. If scale and precision
+ are both None, values are stored to limits allowed by the server.
+
+ :param scale: The number of digits after the decimal point.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+      left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(REAL, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
+
+
+class FLOAT(_FloatType, sqltypes.FLOAT):
+ """MySQL FLOAT type."""
+
+ __visit_name__ = "FLOAT"
+
+ def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
+ """Construct a FLOAT.
+
+ :param precision: Total digits in this number. If scale and precision
+ are both None, values are stored to limits allowed by the server.
+
+ :param scale: The number of digits after the decimal point.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+      left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(FLOAT, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
+
+ def bind_processor(self, dialect):
+ return None
+
+
+class INTEGER(_IntegerType, sqltypes.INTEGER):
+ """MySQL INTEGER type."""
+
+ __visit_name__ = "INTEGER"
+
+ def __init__(self, display_width=None, **kw):
+ """Construct an INTEGER.
+
+ :param display_width: Optional, maximum display width for this number.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+      left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(INTEGER, self).__init__(display_width=display_width, **kw)
+
+
+class BIGINT(_IntegerType, sqltypes.BIGINT):
+ """MySQL BIGINTEGER type."""
+
+ __visit_name__ = "BIGINT"
+
+ def __init__(self, display_width=None, **kw):
+ """Construct a BIGINTEGER.
+
+ :param display_width: Optional, maximum display width for this number.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+      left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(BIGINT, self).__init__(display_width=display_width, **kw)
+
+
+class MEDIUMINT(_IntegerType):
+ """MySQL MEDIUMINTEGER type."""
+
+ __visit_name__ = "MEDIUMINT"
+
+ def __init__(self, display_width=None, **kw):
+ """Construct a MEDIUMINTEGER
+
+ :param display_width: Optional, maximum display width for this number.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+      left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(MEDIUMINT, self).__init__(display_width=display_width, **kw)
+
+
+class TINYINT(_IntegerType):
+ """MySQL TINYINT type."""
+
+ __visit_name__ = "TINYINT"
+
+ def __init__(self, display_width=None, **kw):
+ """Construct a TINYINT.
+
+ :param display_width: Optional, maximum display width for this number.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+      left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(TINYINT, self).__init__(display_width=display_width, **kw)
+
+
+class SMALLINT(_IntegerType, sqltypes.SMALLINT):
+ """MySQL SMALLINTEGER type."""
+
+ __visit_name__ = "SMALLINT"
+
+ def __init__(self, display_width=None, **kw):
+ """Construct a SMALLINTEGER.
+
+ :param display_width: Optional, maximum display width for this number.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+      left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(SMALLINT, self).__init__(display_width=display_width, **kw)
+
+
+class BIT(sqltypes.TypeEngine):
+ """MySQL BIT type.
+
+ This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater
+    for MEMORY, InnoDB and BDB. For older versions, use a
+ MSTinyInteger() type.
+
+ """
+
+ __visit_name__ = "BIT"
+
+ def __init__(self, length=None):
+ """Construct a BIT.
+
+ :param length: Optional, number of bits.
+
+ """
+ self.length = length
+
+ def result_processor(self, dialect, coltype):
+ """Convert a MySQL's 64 bit, variable length binary string to a long.
+
+ TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector
+ already do this, so this logic should be moved to those dialects.
+
+ """
+
+ def process(value):
+ if value is not None:
+ v = 0
+ for i in value:
+ if not isinstance(i, int):
+ i = ord(i) # convert byte to int on Python 2
+ v = v << 8 | i
+ return v
+ return value
+
+ return process
+
+
+class TIME(sqltypes.TIME):
+ """MySQL TIME type."""
+
+ __visit_name__ = "TIME"
+
+ def __init__(self, timezone=False, fsp=None):
+ """Construct a MySQL TIME type.
+
+ :param timezone: not used by the MySQL dialect.
+ :param fsp: fractional seconds precision value.
+ MySQL 5.6 supports storage of fractional seconds;
+ this parameter will be used when emitting DDL
+ for the TIME type.
+
+ .. note::
+
+ DBAPI driver support for fractional seconds may
+ be limited; current support includes
+ MySQL Connector/Python.
+
+ """
+ super(TIME, self).__init__(timezone=timezone)
+ self.fsp = fsp
+
+ def result_processor(self, dialect, coltype):
+ time = datetime.time
+
+ def process(value):
+ # convert from a timedelta value
+ if value is not None:
+ microseconds = value.microseconds
+ seconds = value.seconds
+ minutes = seconds // 60
+ return time(
+ minutes // 60,
+ minutes % 60,
+ seconds - minutes * 60,
+ microsecond=microseconds,
+ )
+ else:
+ return None
+
+ return process
+
+
+class TIMESTAMP(sqltypes.TIMESTAMP):
+ """MySQL TIMESTAMP type."""
+
+ __visit_name__ = "TIMESTAMP"
+
+ def __init__(self, timezone=False, fsp=None):
+ """Construct a MySQL TIMESTAMP type.
+
+ :param timezone: not used by the MySQL dialect.
+ :param fsp: fractional seconds precision value.
+ MySQL 5.6.4 supports storage of fractional seconds;
+ this parameter will be used when emitting DDL
+ for the TIMESTAMP type.
+
+ .. note::
+
+ DBAPI driver support for fractional seconds may
+ be limited; current support includes
+ MySQL Connector/Python.
+
+ """
+ super(TIMESTAMP, self).__init__(timezone=timezone)
+ self.fsp = fsp
+
+
+class DATETIME(sqltypes.DATETIME):
+ """MySQL DATETIME type."""
+
+ __visit_name__ = "DATETIME"
+
+ def __init__(self, timezone=False, fsp=None):
+ """Construct a MySQL DATETIME type.
+
+ :param timezone: not used by the MySQL dialect.
+ :param fsp: fractional seconds precision value.
+ MySQL 5.6.4 supports storage of fractional seconds;
+ this parameter will be used when emitting DDL
+ for the DATETIME type.
+
+ .. note::
+
+ DBAPI driver support for fractional seconds may
+ be limited; current support includes
+ MySQL Connector/Python.
+
+ """
+ super(DATETIME, self).__init__(timezone=timezone)
+ self.fsp = fsp
+
+
+class YEAR(sqltypes.TypeEngine):
+ """MySQL YEAR type, for single byte storage of years 1901-2155."""
+
+ __visit_name__ = "YEAR"
+
+ def __init__(self, display_width=None):
+ self.display_width = display_width
+
+
+class TEXT(_StringType, sqltypes.TEXT):
+ """MySQL TEXT type, for text up to 2^16 characters."""
+
+ __visit_name__ = "TEXT"
+
+ def __init__(self, length=None, **kw):
+ """Construct a TEXT.
+
+ :param length: Optional, if provided the server may optimize storage
+ by substituting the smallest TEXT type sufficient to store
+ ``length`` characters.
+
+ :param charset: Optional, a column-level character set for this string
+ value. Takes precedence to 'ascii' or 'unicode' short-hand.
+
+ :param collation: Optional, a column-level collation for this string
+ value. Takes precedence to 'binary' short-hand.
+
+ :param ascii: Defaults to False: short-hand for the ``latin1``
+ character set, generates ASCII in schema.
+
+ :param unicode: Defaults to False: short-hand for the ``ucs2``
+ character set, generates UNICODE in schema.
+
+ :param national: Optional. If true, use the server's configured
+ national character set.
+
+ :param binary: Defaults to False: short-hand, pick the binary
+ collation type that matches the column's character set. Generates
+ BINARY in schema. This does not affect the type of data stored,
+ only the collation of character data.
+
+ """
+ super(TEXT, self).__init__(length=length, **kw)
+
+
+class TINYTEXT(_StringType):
+ """MySQL TINYTEXT type, for text up to 2^8 characters."""
+
+ __visit_name__ = "TINYTEXT"
+
+ def __init__(self, **kwargs):
+ """Construct a TINYTEXT.
+
+ :param charset: Optional, a column-level character set for this string
+ value. Takes precedence to 'ascii' or 'unicode' short-hand.
+
+ :param collation: Optional, a column-level collation for this string
+ value. Takes precedence to 'binary' short-hand.
+
+ :param ascii: Defaults to False: short-hand for the ``latin1``
+ character set, generates ASCII in schema.
+
+ :param unicode: Defaults to False: short-hand for the ``ucs2``
+ character set, generates UNICODE in schema.
+
+ :param national: Optional. If true, use the server's configured
+ national character set.
+
+ :param binary: Defaults to False: short-hand, pick the binary
+ collation type that matches the column's character set. Generates
+ BINARY in schema. This does not affect the type of data stored,
+ only the collation of character data.
+
+ """
+ super(TINYTEXT, self).__init__(**kwargs)
+
+
+class MEDIUMTEXT(_StringType):
+ """MySQL MEDIUMTEXT type, for text up to 2^24 characters."""
+
+ __visit_name__ = "MEDIUMTEXT"
+
+ def __init__(self, **kwargs):
+ """Construct a MEDIUMTEXT.
+
+ :param charset: Optional, a column-level character set for this string
+ value. Takes precedence to 'ascii' or 'unicode' short-hand.
+
+ :param collation: Optional, a column-level collation for this string
+ value. Takes precedence to 'binary' short-hand.
+
+ :param ascii: Defaults to False: short-hand for the ``latin1``
+ character set, generates ASCII in schema.
+
+ :param unicode: Defaults to False: short-hand for the ``ucs2``
+ character set, generates UNICODE in schema.
+
+ :param national: Optional. If true, use the server's configured
+ national character set.
+
+ :param binary: Defaults to False: short-hand, pick the binary
+ collation type that matches the column's character set. Generates
+ BINARY in schema. This does not affect the type of data stored,
+ only the collation of character data.
+
+ """
+ super(MEDIUMTEXT, self).__init__(**kwargs)
+
+
+class LONGTEXT(_StringType):
+ """MySQL LONGTEXT type, for text up to 2^32 characters."""
+
+ __visit_name__ = "LONGTEXT"
+
+ def __init__(self, **kwargs):
+ """Construct a LONGTEXT.
+
+ :param charset: Optional, a column-level character set for this string
+ value. Takes precedence to 'ascii' or 'unicode' short-hand.
+
+ :param collation: Optional, a column-level collation for this string
+ value. Takes precedence to 'binary' short-hand.
+
+ :param ascii: Defaults to False: short-hand for the ``latin1``
+ character set, generates ASCII in schema.
+
+ :param unicode: Defaults to False: short-hand for the ``ucs2``
+ character set, generates UNICODE in schema.
+
+ :param national: Optional. If true, use the server's configured
+ national character set.
+
+ :param binary: Defaults to False: short-hand, pick the binary
+ collation type that matches the column's character set. Generates
+ BINARY in schema. This does not affect the type of data stored,
+ only the collation of character data.
+
+ """
+ super(LONGTEXT, self).__init__(**kwargs)
+
+
+class VARCHAR(_StringType, sqltypes.VARCHAR):
+ """MySQL VARCHAR type, for variable-length character data."""
+
+ __visit_name__ = "VARCHAR"
+
+ def __init__(self, length=None, **kwargs):
+ """Construct a VARCHAR.
+
+ :param charset: Optional, a column-level character set for this string
+ value. Takes precedence to 'ascii' or 'unicode' short-hand.
+
+ :param collation: Optional, a column-level collation for this string
+ value. Takes precedence to 'binary' short-hand.
+
+ :param ascii: Defaults to False: short-hand for the ``latin1``
+ character set, generates ASCII in schema.
+
+ :param unicode: Defaults to False: short-hand for the ``ucs2``
+ character set, generates UNICODE in schema.
+
+ :param national: Optional. If true, use the server's configured
+ national character set.
+
+ :param binary: Defaults to False: short-hand, pick the binary
+ collation type that matches the column's character set. Generates
+ BINARY in schema. This does not affect the type of data stored,
+ only the collation of character data.
+
+ """
+ super(VARCHAR, self).__init__(length=length, **kwargs)
+
+
+class CHAR(_StringType, sqltypes.CHAR):
+ """MySQL CHAR type, for fixed-length character data."""
+
+ __visit_name__ = "CHAR"
+
+ def __init__(self, length=None, **kwargs):
+ """Construct a CHAR.
+
+ :param length: Maximum data length, in characters.
+
+ :param binary: Optional, use the default binary collation for the
+ national character set. This does not affect the type of data
+ stored, use a BINARY type for binary data.
+
+ :param collation: Optional, request a particular collation. Must be
+ compatible with the national character set.
+
+ """
+ super(CHAR, self).__init__(length=length, **kwargs)
+
+ @classmethod
+ def _adapt_string_for_cast(self, type_):
+ # copy the given string type into a CHAR
+ # for the purposes of rendering a CAST expression
+ type_ = sqltypes.to_instance(type_)
+ if isinstance(type_, sqltypes.CHAR):
+ return type_
+ elif isinstance(type_, _StringType):
+ return CHAR(
+ length=type_.length,
+ charset=type_.charset,
+ collation=type_.collation,
+ ascii=type_.ascii,
+ binary=type_.binary,
+ unicode=type_.unicode,
+ national=False, # not supported in CAST
+ )
+ else:
+ return CHAR(length=type_.length)
+
+
+class NVARCHAR(_StringType, sqltypes.NVARCHAR):
+ """MySQL NVARCHAR type.
+
+ For variable-length character data in the server's configured national
+ character set.
+ """
+
+ __visit_name__ = "NVARCHAR"
+
+ def __init__(self, length=None, **kwargs):
+ """Construct an NVARCHAR.
+
+ :param length: Maximum data length, in characters.
+
+ :param binary: Optional, use the default binary collation for the
+ national character set. This does not affect the type of data
+ stored, use a BINARY type for binary data.
+
+ :param collation: Optional, request a particular collation. Must be
+ compatible with the national character set.
+
+ """
+ kwargs["national"] = True
+ super(NVARCHAR, self).__init__(length=length, **kwargs)
+
+
+class NCHAR(_StringType, sqltypes.NCHAR):
+ """MySQL NCHAR type.
+
+ For fixed-length character data in the server's configured national
+ character set.
+ """
+
+ __visit_name__ = "NCHAR"
+
+ def __init__(self, length=None, **kwargs):
+ """Construct an NCHAR.
+
+ :param length: Maximum data length, in characters.
+
+ :param binary: Optional, use the default binary collation for the
+ national character set. This does not affect the type of data
+ stored, use a BINARY type for binary data.
+
+ :param collation: Optional, request a particular collation. Must be
+ compatible with the national character set.
+
+ """
+ kwargs["national"] = True
+ super(NCHAR, self).__init__(length=length, **kwargs)
+
+
+class TINYBLOB(sqltypes._Binary):
+ """MySQL TINYBLOB type, for binary data up to 2^8 bytes."""
+
+ __visit_name__ = "TINYBLOB"
+
+
+class MEDIUMBLOB(sqltypes._Binary):
+ """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes."""
+
+ __visit_name__ = "MEDIUMBLOB"
+
+
+class LONGBLOB(sqltypes._Binary):
+ """MySQL LONGBLOB type, for binary data up to 2^32 bytes."""
+
+ __visit_name__ = "LONGBLOB"
diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py
new file mode 100644
index 0000000..c83e057
--- /dev/null
+++ b/lib/sqlalchemy/dialects/oracle/__init__.py
@@ -0,0 +1,58 @@
+# oracle/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import base # noqa
+from . import cx_oracle # noqa
+from .base import BFILE
+from .base import BINARY_DOUBLE
+from .base import BINARY_FLOAT
+from .base import BLOB
+from .base import CHAR
+from .base import CLOB
+from .base import DATE
+from .base import DOUBLE_PRECISION
+from .base import FLOAT
+from .base import INTERVAL
+from .base import LONG
+from .base import NCHAR
+from .base import NCLOB
+from .base import NUMBER
+from .base import NVARCHAR
+from .base import NVARCHAR2
+from .base import RAW
+from .base import ROWID
+from .base import TIMESTAMP
+from .base import VARCHAR
+from .base import VARCHAR2
+
+
+base.dialect = dialect = cx_oracle.dialect
+
+__all__ = (
+ "VARCHAR",
+ "NVARCHAR",
+ "CHAR",
+ "NCHAR",
+ "DATE",
+ "NUMBER",
+ "BLOB",
+ "BFILE",
+ "CLOB",
+ "NCLOB",
+ "TIMESTAMP",
+ "RAW",
+ "FLOAT",
+ "DOUBLE_PRECISION",
+ "BINARY_DOUBLE",
+ "BINARY_FLOAT",
+ "LONG",
+ "dialect",
+ "INTERVAL",
+ "VARCHAR2",
+ "NVARCHAR2",
+ "ROWID",
+)
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
new file mode 100644
index 0000000..77f0dbd
--- /dev/null
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -0,0 +1,2522 @@
+# oracle/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+.. dialect:: oracle
+ :name: Oracle
+ :full_support: 11.2, 18c
+ :normal_support: 11+
+ :best_effort: 8+
+
+
+Auto Increment Behavior
+-----------------------
+
+SQLAlchemy Table objects which include integer primary keys are usually
+assumed to have "autoincrementing" behavior, meaning they can generate their
+own primary key values upon INSERT. For use within Oracle, two options are
+available, which are the use of IDENTITY columns (Oracle 12 and above only)
+or the association of a SEQUENCE with the column.
+
+Specifying GENERATED AS IDENTITY (Oracle 12 and above)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Starting from version 12 Oracle can make use of identity columns using
+the :class:`_sql.Identity` to specify the autoincrementing behavior::
+
+ t = Table('mytable', metadata,
+ Column('id', Integer, Identity(start=3), primary_key=True),
+ Column(...), ...
+ )
+
+The CREATE TABLE for the above :class:`_schema.Table` object would be:
+
+.. sourcecode:: sql
+
+ CREATE TABLE mytable (
+ id INTEGER GENERATED BY DEFAULT AS IDENTITY (START WITH 3),
+ ...,
+ PRIMARY KEY (id)
+ )
+
+The :class:`_schema.Identity` object supports many options to control the
+"autoincrementing" behavior of the column, like the starting value, the
+incrementing value, etc.
+In addition to the standard options, Oracle supports setting
+:paramref:`_schema.Identity.always` to ``None`` to use the default
+generated mode, rendering GENERATED AS IDENTITY in the DDL. It also supports
+setting :paramref:`_schema.Identity.on_null` to ``True`` to specify ON NULL
+in conjunction with a 'BY DEFAULT' identity column.
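+
+For example, a brief sketch of a "BY DEFAULT ... ON NULL" identity column
+(the table and column names are placeholders)::
+
+    t = Table('mytable', metadata,
+        Column('id', Integer,
+               Identity(always=False, on_null=True, start=3),
+               primary_key=True)
+    )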
+
+Using a SEQUENCE (all Oracle versions)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Older versions of Oracle had no "autoincrement"
+feature; SQLAlchemy relies upon sequences to produce these values. With the
+older Oracle versions, *a sequence must always be explicitly specified to
+enable autoincrement*. This is divergent from the majority of documentation
+examples which assume the usage of an autoincrement-capable database. To
+specify sequences, use the sqlalchemy.schema.Sequence object which is passed
+to a Column construct::
+
+ t = Table('mytable', metadata,
+ Column('id', Integer, Sequence('id_seq'), primary_key=True),
+ Column(...), ...
+ )
+
+This step is also required when using table reflection, i.e. autoload_with=engine::
+
+ t = Table('mytable', metadata,
+ Column('id', Integer, Sequence('id_seq'), primary_key=True),
+ autoload_with=engine
+ )
+
+.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct
+ in a :class:`_schema.Column` to specify the option of an autoincrementing
+ column.
+
+.. _oracle_isolation_level:
+
+Transaction Isolation Level / Autocommit
+----------------------------------------
+
+The Oracle database supports "READ COMMITTED" and "SERIALIZABLE" modes of
+isolation. The AUTOCOMMIT isolation level is also supported by the cx_Oracle
+dialect.
+
+To set using per-connection execution options::
+
+ connection = engine.connect()
+ connection = connection.execution_options(
+ isolation_level="AUTOCOMMIT"
+ )
+
+For ``READ COMMITTED`` and ``SERIALIZABLE``, the Oracle dialect sets the
+level at the session level using ``ALTER SESSION``, which is reverted back
+to its default setting when the connection is returned to the connection
+pool.
+
+Valid values for ``isolation_level`` include:
+
+* ``READ COMMITTED``
+* ``AUTOCOMMIT``
+* ``SERIALIZABLE``
+
+.. note:: The implementation for the
+ :meth:`_engine.Connection.get_isolation_level` method as implemented by the
+ Oracle dialect necessarily forces the start of a transaction using the
+ Oracle LOCAL_TRANSACTION_ID function; otherwise no level is normally
+ readable.
+
+ Additionally, the :meth:`_engine.Connection.get_isolation_level` method will
+ raise an exception if the ``v$transaction`` view is not available due to
+ permissions or other reasons, which is a common occurrence in Oracle
+ installations.
+
+ The cx_Oracle dialect attempts to call the
+ :meth:`_engine.Connection.get_isolation_level` method when the dialect makes
+ its first connection to the database in order to acquire the
+ "default"isolation level. This default level is necessary so that the level
+ can be reset on a connection after it has been temporarily modified using
+ :meth:`_engine.Connection.execution_options` method. In the common event
+ that the :meth:`_engine.Connection.get_isolation_level` method raises an
+ exception due to ``v$transaction`` not being readable as well as any other
+ database-related failure, the level is assumed to be "READ COMMITTED". No
+ warning is emitted for this initial first-connect condition as it is
+ expected to be a common restriction on Oracle databases.
+
+.. versionadded:: 1.3.16 added support for AUTOCOMMIT to the cx_oracle dialect
+ as well as the notion of a default isolation level
+
+.. versionadded:: 1.3.21 Added support for SERIALIZABLE as well as live
+ reading of the isolation level.
+
+.. versionchanged:: 1.3.22 In the event that the default isolation
+ level cannot be read due to permissions on the v$transaction view as
+ is common in Oracle installations, the default isolation level is hardcoded
+ to "READ COMMITTED" which was the behavior prior to 1.3.21.
+
+.. seealso::
+
+ :ref:`dbapi_autocommit`
+
+Identifier Casing
+-----------------
+
+In Oracle, the data dictionary represents all case insensitive identifier
+names using UPPERCASE text. SQLAlchemy on the other hand considers an
+all-lower case identifier name to be case insensitive. The Oracle dialect
+converts all case insensitive identifiers to and from those two formats during
+schema level communication, such as reflection of tables and indexes. Using
+an UPPERCASE name on the SQLAlchemy side indicates a case sensitive
+identifier, and SQLAlchemy will quote the name - this will cause mismatches
+against data dictionary data received from Oracle, so unless identifier names
+have been truly created as case sensitive (i.e. using quoted names), all
+lowercase names should be used on the SQLAlchemy side.
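+
+For example, a short sketch (the table names here are illustrative only)::
+
+    from sqlalchemy import Column, Integer, MetaData, Table
+
+    metadata = MetaData()
+
+    # an all-lowercase name is treated as case insensitive and will match
+    # the UPPERCASE form stored in Oracle's data dictionary on reflection
+    accounts = Table(
+        "accounts", metadata, Column("id", Integer, primary_key=True)
+    )
+
+    # an UPPERCASE name is treated as case sensitive and will be quoted in
+    # DDL; it only matches objects actually created with a quoted name
+    legacy = Table(
+        "LEGACY_ACCOUNTS", metadata,
+        Column("id", Integer, primary_key=True)
+    )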
+
+.. _oracle_max_identifier_lengths:
+
+Max Identifier Lengths
+----------------------
+
+Oracle has changed the default max identifier length as of Oracle Server
+version 12.2. Prior to this version, the length was 30, and for 12.2 and
+greater it is now 128. This change impacts SQLAlchemy in the area of
+generated SQL label names as well as the generation of constraint names,
+particularly in the case where the constraint naming convention feature
+described at :ref:`constraint_naming_conventions` is being used.
+
+To assist with this change and others, Oracle includes the concept of a
+"compatibility" version, which is a version number that is independent of the
+actual server version in order to assist with migration of Oracle databases,
+and may be configured within the Oracle server itself. This compatibility
+version is retrieved using the query ``SELECT value FROM v$parameter WHERE
+name = 'compatible';``. The SQLAlchemy Oracle dialect, when tasked with
+determining the default max identifier length, will attempt to use this query
+upon first connect in order to determine the effective compatibility version of
+the server, which determines what the maximum allowed identifier length is for
+the server. If the ``v$parameter`` view is not available, the server version
+information is used instead.
+
+As of SQLAlchemy 1.4, the default max identifier length for the Oracle dialect
+is 128 characters. Upon first connect, the compatibility version is detected
+and if it is less than Oracle version 12.2, the max identifier length is
+changed to be 30 characters. In all cases, setting the
+:paramref:`_sa.create_engine.max_identifier_length` parameter will bypass this
+change and the value given will be used as is::
+
+ engine = create_engine(
+ "oracle+cx_oracle://scott:tiger@oracle122",
+ max_identifier_length=30)
+
+The maximum identifier length comes into play both when generating anonymized
+SQL labels in SELECT statements, and more crucially when generating constraint
+names from a naming convention. It is this area that has created the need for
+SQLAlchemy to change this default conservatively. For example, the following
+naming convention produces two very different constraint names based on the
+identifier length::
+
+ from sqlalchemy import Column
+ from sqlalchemy import Index
+ from sqlalchemy import Integer
+ from sqlalchemy import MetaData
+ from sqlalchemy import Table
+ from sqlalchemy.dialects import oracle
+ from sqlalchemy.schema import CreateIndex
+
+ m = MetaData(naming_convention={"ix": "ix_%(column_0N_name)s"})
+
+ t = Table(
+ "t",
+ m,
+ Column("some_column_name_1", Integer),
+ Column("some_column_name_2", Integer),
+ Column("some_column_name_3", Integer),
+ )
+
+ ix = Index(
+ None,
+ t.c.some_column_name_1,
+ t.c.some_column_name_2,
+ t.c.some_column_name_3,
+ )
+
+ oracle_dialect = oracle.dialect(max_identifier_length=30)
+ print(CreateIndex(ix).compile(dialect=oracle_dialect))
+
+With an identifier length of 30, the above CREATE INDEX looks like::
+
+ CREATE INDEX ix_some_column_name_1s_70cd ON t
+ (some_column_name_1, some_column_name_2, some_column_name_3)
+
+However with length=128, it becomes::
+
+ CREATE INDEX ix_some_column_name_1some_column_name_2some_column_name_3 ON t
+ (some_column_name_1, some_column_name_2, some_column_name_3)
+
+Applications which have run versions of SQLAlchemy prior to 1.4 on an Oracle
+server version 12.2 or greater are therefore subject to the scenario of a
+database migration that wishes to "DROP CONSTRAINT" on a name that was
+previously generated with the shorter length. This migration will fail when
+the identifier length is changed without the name of the index or constraint
+first being adjusted. Such applications are strongly advised to make use of
+:paramref:`_sa.create_engine.max_identifier_length`
+in order to maintain control
+of the generation of truncated names, and to fully review and test all database
+migrations in a staging environment when changing this value to ensure that the
+impact of this change has been mitigated.
+
+.. versionchanged:: 1.4 the default max_identifier_length for Oracle is 128
+ characters, which is adjusted down to 30 upon first connect if an older
+ version of Oracle server (compatibility version < 12.2) is detected.
+
+
+LIMIT/OFFSET/FETCH Support
+--------------------------
+
+Methods like :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset` currently
+use an emulated approach for LIMIT / OFFSET based on window functions, which
+involves creation of a subquery using ``ROW_NUMBER`` that is prone to
+performance issues as well as SQL construction issues for complex statements.
+However, this approach is supported by all Oracle versions. See notes below.
+
+When using Oracle 12c and above, use the :meth:`_sql.Select.fetch` method
+instead; this will render the more modern
+``FETCH FIRST N ROW / OFFSET N ROWS`` syntax.
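+
+For example, a minimal sketch with an assumed ``some_table`` construct; the
+:meth:`_sql.Select.offset` / :meth:`_sql.Select.fetch` combination renders the
+newer syntax directly::
+
+    from sqlalchemy import select
+
+    stmt = (
+        select(some_table)
+        .order_by(some_table.c.id)
+        .offset(20)
+        .fetch(10)
+    )
+    # emits SQL along the lines of:
+    # ... OFFSET 20 ROWS FETCH FIRST 10 ROWS ONLY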
+
+Notes on LIMIT / OFFSET emulation (when fetch() method cannot be used)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If using :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset`,
+or with the ORM the :meth:`_orm.Query.limit` and :meth:`_orm.Query.offset` methods,
+and the :meth:`_sql.Select.fetch` method **cannot** be used instead, the following
+notes apply:
+
+* SQLAlchemy currently makes use of ROWNUM to achieve
+  LIMIT/OFFSET; the exact methodology is taken from
+  https://blogs.oracle.com/oraclemagazine/on-rownum-and-limiting-results .
+  A short illustration of this emulation follows this list.
+
+* the "FIRST_ROWS()" optimization keyword is not used by default. To enable
+ the usage of this optimization directive, specify ``optimize_limits=True``
+ to :func:`_sa.create_engine`.
+
+ .. versionchanged:: 1.4
+ The Oracle dialect renders limit/offset integer values using a "post
+ compile" scheme which renders the integer directly before passing the
+ statement to the cursor for execution. The ``use_binds_for_limits`` flag
+ no longer has an effect.
+
+ .. seealso::
+
+ :ref:`change_4808`.
+
+* A future release may use ``FETCH FIRST N ROW / OFFSET N ROWS`` automatically
+ when :meth:`_sql.Select.limit`, :meth:`_sql.Select.offset`, :meth:`_orm.Query.limit`,
+ :meth:`_orm.Query.offset` are used.
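+
+As a rough illustration of the ROWNUM approach described in the first bullet
+above, a limited SELECT can be compiled against the Oracle dialect to inspect
+the wrapping subquery that is produced (the exact SQL varies by SQLAlchemy
+version and by whether OFFSET is present)::
+
+    from sqlalchemy import column, select, table
+    from sqlalchemy.dialects import oracle
+
+    t = table("t", column("x"))
+
+    # the statement is wrapped in a subquery filtered on ROWNUM
+    print(select(t.c.x).limit(10).compile(dialect=oracle.dialect()))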
+
+.. _oracle_returning:
+
+RETURNING Support
+-----------------
+
+The Oracle database supports a limited form of RETURNING, in order to retrieve
+result sets of matched rows from INSERT, UPDATE and DELETE statements.
+Oracle's RETURNING..INTO syntax only supports one row being returned, as it
+relies upon OUT parameters in order to function. In addition, supported
+DBAPIs have further limitations (see :ref:`cx_oracle_returning`).
+
+SQLAlchemy's "implicit returning" feature, which employs RETURNING within an
+INSERT and sometimes an UPDATE statement in order to fetch newly generated
+primary key values and other SQL defaults and expressions, is normally enabled
+on the Oracle backend. By default, "implicit returning" typically only
+fetches the value of a single ``nextval(some_seq)`` expression embedded into
+an INSERT statement, in order to increment a sequence and get the value back
+at the same time. To disable this feature across the board,
+specify ``implicit_returning=False`` to :func:`_sa.create_engine`::
+
+ engine = create_engine("oracle://scott:tiger@dsn",
+ implicit_returning=False)
+
+Implicit returning can also be disabled on a table-by-table basis as a table
+option::
+
+ # Core Table
+ my_table = Table("my_table", metadata, ..., implicit_returning=False)
+
+
+ # declarative
+ class MyClass(Base):
+ __tablename__ = 'my_table'
+ __table_args__ = {"implicit_returning": False}
+
+.. seealso::
+
+ :ref:`cx_oracle_returning` - additional cx_oracle-specific restrictions on
+ implicit returning.
+
+ON UPDATE CASCADE
+-----------------
+
+Oracle doesn't have native ON UPDATE CASCADE functionality. A trigger based
+solution is available at
+https://asktom.oracle.com/tkyte/update_cascade/index.html .
+
+When using the SQLAlchemy ORM, the ORM has limited ability to manually issue
+cascading updates - specify ForeignKey objects using the
+"deferrable=True, initially='deferred'" keyword arguments,
+and specify "passive_updates=False" on each relationship().
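+
+A rough sketch of this configuration, assuming a declarative ``Base`` and
+purely illustrative table names::
+
+    from sqlalchemy import Column, ForeignKey, Integer
+    from sqlalchemy.orm import relationship
+
+    class Parent(Base):
+        __tablename__ = "parent"
+        id = Column(Integer, primary_key=True)
+        children = relationship("Child", passive_updates=False)
+
+    class Child(Base):
+        __tablename__ = "child"
+        id = Column(Integer, primary_key=True)
+        parent_id = Column(
+            Integer,
+            ForeignKey("parent.id", deferrable=True, initially="deferred"),
+        )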
+
+Oracle 8 Compatibility
+----------------------
+
+When Oracle 8 is detected, the dialect internally configures itself to the
+following behaviors:
+
+* the use_ansi flag is set to False. This has the effect of converting all
+ JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN
+ makes use of Oracle's (+) operator.
+
+* the NVARCHAR2 and NCLOB datatypes are no longer generated as DDL when
+ the :class:`~sqlalchemy.types.Unicode` is used - VARCHAR2 and CLOB are
+  issued instead. This is because these types don't seem to work correctly on
+ Oracle 8 even though they are available. The
+ :class:`~sqlalchemy.types.NVARCHAR` and
+ :class:`~sqlalchemy.dialects.oracle.NCLOB` types will always generate
+ NVARCHAR2 and NCLOB.
+
+* the "native unicode" mode is disabled when using cx_oracle, i.e. SQLAlchemy
+ encodes all Python unicode objects to "string" before passing in as bind
+ parameters.
+
+Synonym/DBLINK Reflection
+-------------------------
+
+When using reflection with Table objects, the dialect can optionally search
+for tables indicated by synonyms, either in local or remote schemas or
+accessed over DBLINK, by passing the flag ``oracle_resolve_synonyms=True`` as
+a keyword argument to the :class:`_schema.Table` construct::
+
+    some_table = Table('some_table', metadata, autoload_with=some_engine,
+                       oracle_resolve_synonyms=True)
+
+When this flag is set, the given name (such as ``some_table`` above) will
+be searched not just in the ``ALL_TABLES`` view, but also within the
+``ALL_SYNONYMS`` view to see if this name is actually a synonym to another
+name. If the synonym is located and refers to a DBLINK, the oracle dialect
+knows how to locate the table's information using DBLINK syntax (e.g.
+``@dblink``).
+
+``oracle_resolve_synonyms`` is accepted wherever reflection arguments are
+accepted, including methods such as :meth:`_schema.MetaData.reflect` and
+:meth:`_reflection.Inspector.get_columns`.
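+
+For example, a brief sketch with the :class:`_reflection.Inspector` interface
+(``some_engine`` as above)::
+
+    from sqlalchemy import inspect
+
+    inspector = inspect(some_engine)
+    columns = inspector.get_columns(
+        "some_table", oracle_resolve_synonyms=True)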
+
+If synonyms are not in use, this flag should be left disabled.
+
+.. _oracle_constraint_reflection:
+
+Constraint Reflection
+---------------------
+
+The Oracle dialect can return information about foreign key, unique, and
+CHECK constraints, as well as indexes on tables.
+
+Raw information regarding these constraints can be acquired using
+:meth:`_reflection.Inspector.get_foreign_keys`,
+:meth:`_reflection.Inspector.get_unique_constraints`,
+:meth:`_reflection.Inspector.get_check_constraints`, and
+:meth:`_reflection.Inspector.get_indexes`.
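+
+For example, a minimal sketch of gathering this information (the connection
+URL and table name are illustrative only)::
+
+    from sqlalchemy import create_engine, inspect
+
+    engine = create_engine("oracle+cx_oracle://scott:tiger@dsn")
+    inspector = inspect(engine)
+
+    foreign_keys = inspector.get_foreign_keys("some_table")
+    unique_constraints = inspector.get_unique_constraints("some_table")
+    check_constraints = inspector.get_check_constraints("some_table")
+    indexes = inspector.get_indexes("some_table")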
+
+.. versionchanged:: 1.2 The Oracle dialect can now reflect UNIQUE and
+ CHECK constraints.
+
+When using reflection at the :class:`_schema.Table` level, the
+:class:`_schema.Table`
+will also include these constraints.
+
+Note the following caveats:
+
+* When using the :meth:`_reflection.Inspector.get_check_constraints` method,
+ Oracle
+ builds a special "IS NOT NULL" constraint for columns that specify
+ "NOT NULL". This constraint is **not** returned by default; to include
+ the "IS NOT NULL" constraints, pass the flag ``include_all=True``::
+
+ from sqlalchemy import create_engine, inspect
+
+ engine = create_engine("oracle+cx_oracle://s:t@dsn")
+ inspector = inspect(engine)
+ all_check_constraints = inspector.get_check_constraints(
+ "some_table", include_all=True)
+
+* in most cases, when reflecting a :class:`_schema.Table`,
+ a UNIQUE constraint will
+ **not** be available as a :class:`.UniqueConstraint` object, as Oracle
+  mirrors unique constraints with a UNIQUE index (the exception
+ seems to be when two or more unique constraints represent the same columns);
+ the :class:`_schema.Table` will instead represent these using
+ :class:`.Index`
+ with the ``unique=True`` flag set.
+
+* Oracle creates an implicit index for the primary key of a table; this index
+ is **excluded** from all index results.
+
+* the list of columns reflected for an index will not include column names
+ that start with SYS_NC.
+
+Table names with SYSTEM/SYSAUX tablespaces
+-------------------------------------------
+
+The :meth:`_reflection.Inspector.get_table_names` and
+:meth:`_reflection.Inspector.get_temp_table_names`
+methods each return a list of table names for the current engine. These methods
+are also part of the reflection which occurs within an operation such as
+:meth:`_schema.MetaData.reflect`. By default,
+these operations exclude the ``SYSTEM``
+and ``SYSAUX`` tablespaces. The default list of excluded tablespaces can be
+changed at the engine level using the ``exclude_tablespaces`` parameter::
+
+ # exclude SYSAUX and SOME_TABLESPACE, but not SYSTEM
+ e = create_engine(
+ "oracle://scott:tiger@xe",
+ exclude_tablespaces=["SYSAUX", "SOME_TABLESPACE"])
+
+.. versionadded:: 1.1
+
+DateTime Compatibility
+----------------------
+
+Oracle has no datatype known as ``DATETIME``; it instead has only ``DATE``,
+which can actually store a date and time value. For this reason, the Oracle
+dialect provides a type :class:`_oracle.DATE` which is a subclass of
+:class:`.DateTime`. This type has no special behavior, and is only
+present as a "marker" for this type; additionally, when a database column
+is reflected and the type is reported as ``DATE``, the time-supporting
+:class:`_oracle.DATE` type is used.
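+
+For example, a generic :class:`.DateTime` column renders ``DATE`` in Oracle
+DDL; a minimal sketch using an arbitrary table::
+
+    from sqlalchemy import Column, DateTime, MetaData, Table
+    from sqlalchemy.dialects import oracle
+    from sqlalchemy.schema import CreateTable
+
+    t = Table("t", MetaData(), Column("created_at", DateTime))
+
+    # the compiled DDL renders the column as "created_at DATE"
+    print(CreateTable(t).compile(dialect=oracle.dialect()))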
+
+.. versionchanged:: 0.9.4 Added :class:`_oracle.DATE` to subclass
+ :class:`.DateTime`. This is a change as previous versions
+ would reflect a ``DATE`` column as :class:`_types.DATE`, which subclasses
+ :class:`.Date`. The only significance here is for schemes that are
+ examining the type of column for use in special Python translations or
+ for migrating schemas to other database backends.
+
+.. _oracle_table_options:
+
+Oracle Table Options
+-------------------------
+
+The CREATE TABLE phrase supports the following options with Oracle
+in conjunction with the :class:`_schema.Table` construct:
+
+
+* ``ON COMMIT``::
+
+ Table(
+ "some_table", metadata, ...,
+ prefixes=['GLOBAL TEMPORARY'], oracle_on_commit='PRESERVE ROWS')
+
+.. versionadded:: 1.0.0
+
+* ``COMPRESS``::
+
+ Table('mytable', metadata, Column('data', String(32)),
+ oracle_compress=True)
+
+ Table('mytable', metadata, Column('data', String(32)),
+ oracle_compress=6)
+
+ The ``oracle_compress`` parameter accepts either an integer compression
+ level, or ``True`` to use the default compression level.
+
+.. versionadded:: 1.0.0
+
+.. _oracle_index_options:
+
+Oracle Specific Index Options
+-----------------------------
+
+Bitmap Indexes
+~~~~~~~~~~~~~~
+
+You can specify the ``oracle_bitmap`` parameter to create a bitmap index
+instead of a B-tree index::
+
+ Index('my_index', my_table.c.data, oracle_bitmap=True)
+
+Bitmap indexes cannot be unique and cannot be compressed. SQLAlchemy will not
+check for such limitations; only the database will.
+
+.. versionadded:: 1.0.0
+
+Index compression
+~~~~~~~~~~~~~~~~~
+
+Oracle has a more efficient storage mode for indexes containing lots of
+repeated values. Use the ``oracle_compress`` parameter to turn on key
+compression::
+
+ Index('my_index', my_table.c.data, oracle_compress=True)
+
+ Index('my_index', my_table.c.data1, my_table.c.data2, unique=True,
+ oracle_compress=1)
+
+The ``oracle_compress`` parameter accepts either an integer specifying the
+number of prefix columns to compress, or ``True`` to use the default (all
+columns for non-unique indexes, all but the last column for unique indexes).
+
+.. versionadded:: 1.0.0
+
+""" # noqa
+
+from itertools import groupby
+import re
+
+from ... import Computed
+from ... import exc
+from ... import schema as sa_schema
+from ... import sql
+from ... import util
+from ...engine import default
+from ...engine import reflection
+from ...sql import compiler
+from ...sql import expression
+from ...sql import sqltypes
+from ...sql import util as sql_util
+from ...sql import visitors
+from ...types import BLOB
+from ...types import CHAR
+from ...types import CLOB
+from ...types import FLOAT
+from ...types import INTEGER
+from ...types import NCHAR
+from ...types import NVARCHAR
+from ...types import TIMESTAMP
+from ...types import VARCHAR
+from ...util import compat
+
+RESERVED_WORDS = set(
+ "SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN "
+ "DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED "
+ "ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE "
+ "ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE "
+ "BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES "
+ "AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS "
+ "NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER "
+ "CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR "
+ "DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL".split()
+)
+
+NO_ARG_FNS = set(
+ "UID CURRENT_DATE SYSDATE USER " "CURRENT_TIME CURRENT_TIMESTAMP".split()
+)
+
+
+class RAW(sqltypes._Binary):
+ __visit_name__ = "RAW"
+
+
+OracleRaw = RAW
+
+
+class NCLOB(sqltypes.Text):
+ __visit_name__ = "NCLOB"
+
+
+class VARCHAR2(VARCHAR):
+ __visit_name__ = "VARCHAR2"
+
+
+NVARCHAR2 = NVARCHAR
+
+
+class NUMBER(sqltypes.Numeric, sqltypes.Integer):
+ __visit_name__ = "NUMBER"
+
+ def __init__(self, precision=None, scale=None, asdecimal=None):
+ if asdecimal is None:
+ asdecimal = bool(scale and scale > 0)
+
+ super(NUMBER, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal
+ )
+
+ def adapt(self, impltype):
+ ret = super(NUMBER, self).adapt(impltype)
+ # leave a hint for the DBAPI handler
+ ret._is_oracle_number = True
+ return ret
+
+ @property
+ def _type_affinity(self):
+ if bool(self.scale and self.scale > 0):
+ return sqltypes.Numeric
+ else:
+ return sqltypes.Integer
+
+
+class DOUBLE_PRECISION(sqltypes.Float):
+ __visit_name__ = "DOUBLE_PRECISION"
+
+
+class BINARY_DOUBLE(sqltypes.Float):
+ __visit_name__ = "BINARY_DOUBLE"
+
+
+class BINARY_FLOAT(sqltypes.Float):
+ __visit_name__ = "BINARY_FLOAT"
+
+
+class BFILE(sqltypes.LargeBinary):
+ __visit_name__ = "BFILE"
+
+
+class LONG(sqltypes.Text):
+ __visit_name__ = "LONG"
+
+
+class DATE(sqltypes.DateTime):
+ """Provide the oracle DATE type.
+
+ This type has no special Python behavior, except that it subclasses
+ :class:`_types.DateTime`; this is to suit the fact that the Oracle
+ ``DATE`` type supports a time value.
+
+ .. versionadded:: 0.9.4
+
+ """
+
+ __visit_name__ = "DATE"
+
+ def _compare_type_affinity(self, other):
+ return other._type_affinity in (sqltypes.DateTime, sqltypes.Date)
+
+
+class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
+ __visit_name__ = "INTERVAL"
+
+ def __init__(self, day_precision=None, second_precision=None):
+ """Construct an INTERVAL.
+
+ Note that only DAY TO SECOND intervals are currently supported.
+ This is due to a lack of support for YEAR TO MONTH intervals
+ within available DBAPIs.
+
+ :param day_precision: the day precision value. this is the number of
+ digits to store for the day field. Defaults to "2"
+ :param second_precision: the second precision value. this is the
+ number of digits to store for the fractional seconds field.
+ Defaults to "6".
+
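+        E.g. (renders ``INTERVAL DAY(2) TO SECOND(6)`` in DDL)::
+
+            from sqlalchemy.dialects.oracle import INTERVAL
+
+            interval_type = INTERVAL(day_precision=2, second_precision=6)
+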
+ """
+ self.day_precision = day_precision
+ self.second_precision = second_precision
+
+ @classmethod
+ def _adapt_from_generic_interval(cls, interval):
+ return INTERVAL(
+ day_precision=interval.day_precision,
+ second_precision=interval.second_precision,
+ )
+
+ @property
+ def _type_affinity(self):
+ return sqltypes.Interval
+
+ def as_generic(self, allow_nulltype=False):
+ return sqltypes.Interval(
+ native=True,
+ second_precision=self.second_precision,
+ day_precision=self.day_precision,
+ )
+
+ def coerce_compared_value(self, op, value):
+ return self
+
+
+class ROWID(sqltypes.TypeEngine):
+ """Oracle ROWID type.
+
+ When used in a cast() or similar, generates ROWID.
+
+ """
+
+ __visit_name__ = "ROWID"
+
+
+class _OracleBoolean(sqltypes.Boolean):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
+
+
+colspecs = {
+ sqltypes.Boolean: _OracleBoolean,
+ sqltypes.Interval: INTERVAL,
+ sqltypes.DateTime: DATE,
+}
+
+ischema_names = {
+ "VARCHAR2": VARCHAR,
+ "NVARCHAR2": NVARCHAR,
+ "CHAR": CHAR,
+ "NCHAR": NCHAR,
+ "DATE": DATE,
+ "NUMBER": NUMBER,
+ "BLOB": BLOB,
+ "BFILE": BFILE,
+ "CLOB": CLOB,
+ "NCLOB": NCLOB,
+ "TIMESTAMP": TIMESTAMP,
+ "TIMESTAMP WITH TIME ZONE": TIMESTAMP,
+ "INTERVAL DAY TO SECOND": INTERVAL,
+ "RAW": RAW,
+ "FLOAT": FLOAT,
+ "DOUBLE PRECISION": DOUBLE_PRECISION,
+ "LONG": LONG,
+ "BINARY_DOUBLE": BINARY_DOUBLE,
+ "BINARY_FLOAT": BINARY_FLOAT,
+}
+
+
+class OracleTypeCompiler(compiler.GenericTypeCompiler):
+ # Note:
+ # Oracle DATE == DATETIME
+ # Oracle does not allow milliseconds in DATE
+ # Oracle does not support TIME columns
+
+ def visit_datetime(self, type_, **kw):
+ return self.visit_DATE(type_, **kw)
+
+ def visit_float(self, type_, **kw):
+ return self.visit_FLOAT(type_, **kw)
+
+ def visit_unicode(self, type_, **kw):
+ if self.dialect._use_nchar_for_unicode:
+ return self.visit_NVARCHAR2(type_, **kw)
+ else:
+ return self.visit_VARCHAR2(type_, **kw)
+
+ def visit_INTERVAL(self, type_, **kw):
+ return "INTERVAL DAY%s TO SECOND%s" % (
+ type_.day_precision is not None
+ and "(%d)" % type_.day_precision
+ or "",
+ type_.second_precision is not None
+ and "(%d)" % type_.second_precision
+ or "",
+ )
+
+ def visit_LONG(self, type_, **kw):
+ return "LONG"
+
+ def visit_TIMESTAMP(self, type_, **kw):
+ if type_.timezone:
+ return "TIMESTAMP WITH TIME ZONE"
+ else:
+ return "TIMESTAMP"
+
+ def visit_DOUBLE_PRECISION(self, type_, **kw):
+ return self._generate_numeric(type_, "DOUBLE PRECISION", **kw)
+
+ def visit_BINARY_DOUBLE(self, type_, **kw):
+ return self._generate_numeric(type_, "BINARY_DOUBLE", **kw)
+
+ def visit_BINARY_FLOAT(self, type_, **kw):
+ return self._generate_numeric(type_, "BINARY_FLOAT", **kw)
+
+ def visit_FLOAT(self, type_, **kw):
+ # don't support conversion between decimal/binary
+ # precision yet
+ kw["no_precision"] = True
+ return self._generate_numeric(type_, "FLOAT", **kw)
+
+ def visit_NUMBER(self, type_, **kw):
+ return self._generate_numeric(type_, "NUMBER", **kw)
+
+ def _generate_numeric(
+ self, type_, name, precision=None, scale=None, no_precision=False, **kw
+ ):
+ if precision is None:
+ precision = type_.precision
+
+ if scale is None:
+ scale = getattr(type_, "scale", None)
+
+ if no_precision or precision is None:
+ return name
+ elif scale is None:
+ n = "%(name)s(%(precision)s)"
+ return n % {"name": name, "precision": precision}
+ else:
+ n = "%(name)s(%(precision)s, %(scale)s)"
+ return n % {"name": name, "precision": precision, "scale": scale}
+
+ def visit_string(self, type_, **kw):
+ return self.visit_VARCHAR2(type_, **kw)
+
+ def visit_VARCHAR2(self, type_, **kw):
+ return self._visit_varchar(type_, "", "2")
+
+ def visit_NVARCHAR2(self, type_, **kw):
+ return self._visit_varchar(type_, "N", "2")
+
+ visit_NVARCHAR = visit_NVARCHAR2
+
+ def visit_VARCHAR(self, type_, **kw):
+ return self._visit_varchar(type_, "", "")
+
+ def _visit_varchar(self, type_, n, num):
+ if not type_.length:
+ return "%(n)sVARCHAR%(two)s" % {"two": num, "n": n}
+ elif not n and self.dialect._supports_char_length:
+ varchar = "VARCHAR%(two)s(%(length)s CHAR)"
+ return varchar % {"length": type_.length, "two": num}
+ else:
+ varchar = "%(n)sVARCHAR%(two)s(%(length)s)"
+ return varchar % {"length": type_.length, "two": num, "n": n}
+
+ def visit_text(self, type_, **kw):
+ return self.visit_CLOB(type_, **kw)
+
+ def visit_unicode_text(self, type_, **kw):
+ if self.dialect._use_nchar_for_unicode:
+ return self.visit_NCLOB(type_, **kw)
+ else:
+ return self.visit_CLOB(type_, **kw)
+
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_BLOB(type_, **kw)
+
+ def visit_big_integer(self, type_, **kw):
+ return self.visit_NUMBER(type_, precision=19, **kw)
+
+ def visit_boolean(self, type_, **kw):
+ return self.visit_SMALLINT(type_, **kw)
+
+ def visit_RAW(self, type_, **kw):
+ if type_.length:
+ return "RAW(%(length)s)" % {"length": type_.length}
+ else:
+ return "RAW"
+
+ def visit_ROWID(self, type_, **kw):
+ return "ROWID"
+
+
+class OracleCompiler(compiler.SQLCompiler):
+ """Oracle compiler modifies the lexical structure of Select
+ statements to work under non-ANSI configured Oracle databases, if
+ the use_ansi flag is False.
+ """
+
+ compound_keywords = util.update_copy(
+ compiler.SQLCompiler.compound_keywords,
+ {expression.CompoundSelect.EXCEPT: "MINUS"},
+ )
+
+ def __init__(self, *args, **kwargs):
+ self.__wheres = {}
+ super(OracleCompiler, self).__init__(*args, **kwargs)
+
+ def visit_mod_binary(self, binary, operator, **kw):
+ return "mod(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_now_func(self, fn, **kw):
+ return "CURRENT_TIMESTAMP"
+
+ def visit_char_length_func(self, fn, **kw):
+ return "LENGTH" + self.function_argspec(fn, **kw)
+
+ def visit_match_op_binary(self, binary, operator, **kw):
+ return "CONTAINS (%s, %s)" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def visit_true(self, expr, **kw):
+ return "1"
+
+ def visit_false(self, expr, **kw):
+ return "0"
+
+ def get_cte_preamble(self, recursive):
+ return "WITH"
+
+ def get_select_hint_text(self, byfroms):
+ return " ".join("/*+ %s */" % text for table, text in byfroms.items())
+
+ def function_argspec(self, fn, **kw):
+ if len(fn.clauses) > 0 or fn.name.upper() not in NO_ARG_FNS:
+ return compiler.SQLCompiler.function_argspec(self, fn, **kw)
+ else:
+ return ""
+
+ def visit_function(self, func, **kw):
+ text = super(OracleCompiler, self).visit_function(func, **kw)
+ if kw.get("asfrom", False):
+ text = "TABLE (%s)" % func
+ return text
+
+ def visit_table_valued_column(self, element, **kw):
+ text = super(OracleCompiler, self).visit_table_valued_column(
+ element, **kw
+ )
+ text = "COLUMN_VALUE " + text
+ return text
+
+ def default_from(self):
+ """Called when a ``SELECT`` statement has no froms,
+ and no ``FROM`` clause is to be appended.
+
+ The Oracle compiler tacks a "FROM DUAL" to the statement.
+ """
+
+ return " FROM DUAL"
+
+ def visit_join(self, join, from_linter=None, **kwargs):
+ if self.dialect.use_ansi:
+ return compiler.SQLCompiler.visit_join(
+ self, join, from_linter=from_linter, **kwargs
+ )
+ else:
+ if from_linter:
+ from_linter.edges.add((join.left, join.right))
+
+ kwargs["asfrom"] = True
+ if isinstance(join.right, expression.FromGrouping):
+ right = join.right.element
+ else:
+ right = join.right
+ return (
+ self.process(join.left, from_linter=from_linter, **kwargs)
+ + ", "
+ + self.process(right, from_linter=from_linter, **kwargs)
+ )
+
+ def _get_nonansi_join_whereclause(self, froms):
+ clauses = []
+
+ def visit_join(join):
+ if join.isouter:
+ # https://docs.oracle.com/database/121/SQLRF/queries006.htm#SQLRF52354
+ # "apply the outer join operator (+) to all columns of B in
+ # the join condition in the WHERE clause" - that is,
+ # unconditionally regardless of operator or the other side
+ def visit_binary(binary):
+ if isinstance(
+ binary.left, expression.ColumnClause
+ ) and join.right.is_derived_from(binary.left.table):
+ binary.left = _OuterJoinColumn(binary.left)
+ elif isinstance(
+ binary.right, expression.ColumnClause
+ ) and join.right.is_derived_from(binary.right.table):
+ binary.right = _OuterJoinColumn(binary.right)
+
+ clauses.append(
+ visitors.cloned_traverse(
+ join.onclause, {}, {"binary": visit_binary}
+ )
+ )
+ else:
+ clauses.append(join.onclause)
+
+ for j in join.left, join.right:
+ if isinstance(j, expression.Join):
+ visit_join(j)
+ elif isinstance(j, expression.FromGrouping):
+ visit_join(j.element)
+
+ for f in froms:
+ if isinstance(f, expression.Join):
+ visit_join(f)
+
+ if not clauses:
+ return None
+ else:
+ return sql.and_(*clauses)
+
+ def visit_outer_join_column(self, vc, **kw):
+ return self.process(vc.column, **kw) + "(+)"
+
+ def visit_sequence(self, seq, **kw):
+ return self.preparer.format_sequence(seq) + ".nextval"
+
+ def get_render_as_alias_suffix(self, alias_name_text):
+ """Oracle doesn't like ``FROM table AS alias``"""
+
+ return " " + alias_name_text
+
+ def returning_clause(self, stmt, returning_cols):
+ columns = []
+ binds = []
+
+ for i, column in enumerate(
+ expression._select_iterables(returning_cols)
+ ):
+ if (
+ self.isupdate
+ and isinstance(column, sa_schema.Column)
+ and isinstance(column.server_default, Computed)
+ and not self.dialect._supports_update_returning_computed_cols
+ ):
+ util.warn(
+ "Computed columns don't work with Oracle UPDATE "
+ "statements that use RETURNING; the value of the column "
+ "*before* the UPDATE takes place is returned. It is "
+ "advised to not use RETURNING with an Oracle computed "
+ "column. Consider setting implicit_returning to False on "
+ "the Table object in order to avoid implicit RETURNING "
+ "clauses from being generated for this Table."
+ )
+ if column.type._has_column_expression:
+ col_expr = column.type.column_expression(column)
+ else:
+ col_expr = column
+
+ outparam = sql.outparam("ret_%d" % i, type_=column.type)
+ self.binds[outparam.key] = outparam
+ binds.append(
+ self.bindparam_string(self._truncate_bindparam(outparam))
+ )
+
+ # ensure the ExecutionContext.get_out_parameters() method is
+ # *not* called; the cx_Oracle dialect wants to handle these
+ # parameters separately
+ self.has_out_parameters = False
+
+ columns.append(self.process(col_expr, within_columns_clause=False))
+
+ self._add_to_result_map(
+ getattr(col_expr, "name", col_expr._anon_name_label),
+ getattr(col_expr, "name", col_expr._anon_name_label),
+ (
+ column,
+ getattr(column, "name", None),
+ getattr(column, "key", None),
+ ),
+ column.type,
+ )
+
+ return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds)
+
+ def translate_select_structure(self, select_stmt, **kwargs):
+ select = select_stmt
+
+ if not getattr(select, "_oracle_visit", None):
+ if not self.dialect.use_ansi:
+ froms = self._display_froms_for_select(
+ select, kwargs.get("asfrom", False)
+ )
+ whereclause = self._get_nonansi_join_whereclause(froms)
+ if whereclause is not None:
+ select = select.where(whereclause)
+ select._oracle_visit = True
+
+ # if fetch is used this is not needed
+ if (
+ select._has_row_limiting_clause
+ and select._fetch_clause is None
+ ):
+ limit_clause = select._limit_clause
+ offset_clause = select._offset_clause
+
+ if select._simple_int_clause(limit_clause):
+ limit_clause = limit_clause.render_literal_execute()
+
+ if select._simple_int_clause(offset_clause):
+ offset_clause = offset_clause.render_literal_execute()
+
+ # currently using form at:
+ # https://blogs.oracle.com/oraclemagazine/\
+ # on-rownum-and-limiting-results
+
+ orig_select = select
+ select = select._generate()
+ select._oracle_visit = True
+
+ # add expressions to accommodate FOR UPDATE OF
+ for_update = select._for_update_arg
+ if for_update is not None and for_update.of:
+ for_update = for_update._clone()
+ for_update._copy_internals()
+
+ for elem in for_update.of:
+ if not select.selected_columns.contains_column(elem):
+ select = select.add_columns(elem)
+
+ # Wrap the middle select and add the hint
+ inner_subquery = select.alias()
+ limitselect = sql.select(
+ *[
+ c
+ for c in inner_subquery.c
+ if orig_select.selected_columns.corresponding_column(c)
+ is not None
+ ]
+ )
+
+ if (
+ limit_clause is not None
+ and self.dialect.optimize_limits
+ and select._simple_int_clause(limit_clause)
+ ):
+ limitselect = limitselect.prefix_with(
+ expression.text(
+ "/*+ FIRST_ROWS(%s) */"
+ % self.process(limit_clause, **kwargs)
+ )
+ )
+
+ limitselect._oracle_visit = True
+ limitselect._is_wrapper = True
+
+ # add expressions to accommodate FOR UPDATE OF
+ if for_update is not None and for_update.of:
+
+ adapter = sql_util.ClauseAdapter(inner_subquery)
+ for_update.of = [
+ adapter.traverse(elem) for elem in for_update.of
+ ]
+
+ # If needed, add the limiting clause
+ if limit_clause is not None:
+ if select._simple_int_clause(limit_clause) and (
+ offset_clause is None
+ or select._simple_int_clause(offset_clause)
+ ):
+ max_row = limit_clause
+
+ if offset_clause is not None:
+ max_row = max_row + offset_clause
+
+ else:
+ max_row = limit_clause
+
+ if offset_clause is not None:
+ max_row = max_row + offset_clause
+ limitselect = limitselect.where(
+ sql.literal_column("ROWNUM") <= max_row
+ )
+
+ # If needed, add the ora_rn, and wrap again with offset.
+ if offset_clause is None:
+ limitselect._for_update_arg = for_update
+ select = limitselect
+ else:
+ limitselect = limitselect.add_columns(
+ sql.literal_column("ROWNUM").label("ora_rn")
+ )
+ limitselect._oracle_visit = True
+ limitselect._is_wrapper = True
+
+ if for_update is not None and for_update.of:
+ limitselect_cols = limitselect.selected_columns
+ for elem in for_update.of:
+ if (
+ limitselect_cols.corresponding_column(elem)
+ is None
+ ):
+ limitselect = limitselect.add_columns(elem)
+
+ limit_subquery = limitselect.alias()
+ origselect_cols = orig_select.selected_columns
+ offsetselect = sql.select(
+ *[
+ c
+ for c in limit_subquery.c
+ if origselect_cols.corresponding_column(c)
+ is not None
+ ]
+ )
+
+ offsetselect._oracle_visit = True
+ offsetselect._is_wrapper = True
+
+ if for_update is not None and for_update.of:
+ adapter = sql_util.ClauseAdapter(limit_subquery)
+ for_update.of = [
+ adapter.traverse(elem) for elem in for_update.of
+ ]
+
+ offsetselect = offsetselect.where(
+ sql.literal_column("ora_rn") > offset_clause
+ )
+
+ offsetselect._for_update_arg = for_update
+ select = offsetselect
+
+ return select
+
+ def limit_clause(self, select, **kw):
+ return ""
+
+ def visit_empty_set_expr(self, type_):
+ return "SELECT 1 FROM DUAL WHERE 1!=1"
+
+ def for_update_clause(self, select, **kw):
+ if self.is_subquery():
+ return ""
+
+ tmp = " FOR UPDATE"
+
+ if select._for_update_arg.of:
+ tmp += " OF " + ", ".join(
+ self.process(elem, **kw) for elem in select._for_update_arg.of
+ )
+
+ if select._for_update_arg.nowait:
+ tmp += " NOWAIT"
+ if select._for_update_arg.skip_locked:
+ tmp += " SKIP LOCKED"
+
+ return tmp
+
+ def visit_is_distinct_from_binary(self, binary, operator, **kw):
+ return "DECODE(%s, %s, 0, 1) = 1" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def visit_is_not_distinct_from_binary(self, binary, operator, **kw):
+ return "DECODE(%s, %s, 0, 1) = 0" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def _get_regexp_args(self, binary, kw):
+ string = self.process(binary.left, **kw)
+ pattern = self.process(binary.right, **kw)
+ flags = binary.modifiers["flags"]
+ if flags is not None:
+ flags = self.process(flags, **kw)
+ return string, pattern, flags
+
+ def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ string, pattern, flags = self._get_regexp_args(binary, kw)
+ if flags is None:
+ return "REGEXP_LIKE(%s, %s)" % (string, pattern)
+ else:
+ return "REGEXP_LIKE(%s, %s, %s)" % (string, pattern, flags)
+
+ def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ return "NOT %s" % self.visit_regexp_match_op_binary(
+ binary, operator, **kw
+ )
+
+ def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+ string, pattern, flags = self._get_regexp_args(binary, kw)
+ replacement = self.process(binary.modifiers["replacement"], **kw)
+ if flags is None:
+ return "REGEXP_REPLACE(%s, %s, %s)" % (
+ string,
+ pattern,
+ replacement,
+ )
+ else:
+ return "REGEXP_REPLACE(%s, %s, %s, %s)" % (
+ string,
+ pattern,
+ replacement,
+ flags,
+ )
+
+
+class OracleDDLCompiler(compiler.DDLCompiler):
+ def define_constraint_cascades(self, constraint):
+ text = ""
+ if constraint.ondelete is not None:
+ text += " ON DELETE %s" % constraint.ondelete
+
+ # oracle has no ON UPDATE CASCADE -
+ # its only available via triggers
+ # https://asktom.oracle.com/tkyte/update_cascade/index.html
+ if constraint.onupdate is not None:
+ util.warn(
+ "Oracle does not contain native UPDATE CASCADE "
+ "functionality - onupdates will not be rendered for foreign "
+ "keys. Consider using deferrable=True, initially='deferred' "
+ "or triggers."
+ )
+
+ return text
+
+ def visit_drop_table_comment(self, drop):
+ return "COMMENT ON TABLE %s IS ''" % self.preparer.format_table(
+ drop.element
+ )
+
+ def visit_create_index(self, create):
+ index = create.element
+ self._verify_index_table(index)
+ preparer = self.preparer
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+ if index.dialect_options["oracle"]["bitmap"]:
+ text += "BITMAP "
+ text += "INDEX %s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=True),
+ preparer.format_table(index.table, use_schema=True),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
+ if index.dialect_options["oracle"]["compress"] is not False:
+ if index.dialect_options["oracle"]["compress"] is True:
+ text += " COMPRESS"
+ else:
+ text += " COMPRESS %d" % (
+ index.dialect_options["oracle"]["compress"]
+ )
+ return text
+
+ def post_create_table(self, table):
+ table_opts = []
+ opts = table.dialect_options["oracle"]
+
+ if opts["on_commit"]:
+ on_commit_options = opts["on_commit"].replace("_", " ").upper()
+ table_opts.append("\n ON COMMIT %s" % on_commit_options)
+
+ if opts["compress"]:
+ if opts["compress"] is True:
+ table_opts.append("\n COMPRESS")
+ else:
+ table_opts.append("\n COMPRESS FOR %s" % (opts["compress"]))
+
+ return "".join(table_opts)
+
+ def get_identity_options(self, identity_options):
+ text = super(OracleDDLCompiler, self).get_identity_options(
+ identity_options
+ )
+ text = text.replace("NO MINVALUE", "NOMINVALUE")
+ text = text.replace("NO MAXVALUE", "NOMAXVALUE")
+ text = text.replace("NO CYCLE", "NOCYCLE")
+ text = text.replace("NO ORDER", "NOORDER")
+ return text
+
+ def visit_computed_column(self, generated):
+ text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
+ generated.sqltext, include_table=False, literal_binds=True
+ )
+ if generated.persisted is True:
+ raise exc.CompileError(
+ "Oracle computed columns do not support 'stored' persistence; "
+ "set the 'persisted' flag to None or False for Oracle support."
+ )
+ elif generated.persisted is False:
+ text += " VIRTUAL"
+ return text
+
+ def visit_identity_column(self, identity, **kw):
+ if identity.always is None:
+ kind = ""
+ else:
+ kind = "ALWAYS" if identity.always else "BY DEFAULT"
+ text = "GENERATED %s" % kind
+ if identity.on_null:
+ text += " ON NULL"
+ text += " AS IDENTITY"
+ options = self.get_identity_options(identity)
+ if options:
+ text += " (%s)" % options
+ return text
+
+
+class OracleIdentifierPreparer(compiler.IdentifierPreparer):
+
+ reserved_words = {x.lower() for x in RESERVED_WORDS}
+ illegal_initial_characters = {str(dig) for dig in range(0, 10)}.union(
+ ["_", "$"]
+ )
+
+ def _bindparam_requires_quotes(self, value):
+ """Return True if the given identifier requires quoting."""
+ lc_value = value.lower()
+ return (
+ lc_value in self.reserved_words
+ or value[0] in self.illegal_initial_characters
+ or not self.legal_characters.match(util.text_type(value))
+ )
+
+ def format_savepoint(self, savepoint):
+ name = savepoint.ident.lstrip("_")
+ return super(OracleIdentifierPreparer, self).format_savepoint(
+ savepoint, name
+ )
+
+
+class OracleExecutionContext(default.DefaultExecutionContext):
+ def fire_sequence(self, seq, type_):
+ return self._execute_scalar(
+ "SELECT "
+ + self.identifier_preparer.format_sequence(seq)
+ + ".nextval FROM DUAL",
+ type_,
+ )
+
+
+class OracleDialect(default.DefaultDialect):
+ name = "oracle"
+ supports_statement_cache = True
+ supports_alter = True
+ supports_unicode_statements = False
+ supports_unicode_binds = False
+ max_identifier_length = 128
+
+ supports_simple_order_by_label = False
+ cte_follows_insert = True
+
+ supports_sequences = True
+ sequences_optional = False
+ postfetch_lastrowid = False
+
+ default_paramstyle = "named"
+ colspecs = colspecs
+ ischema_names = ischema_names
+ requires_name_normalize = True
+
+ supports_comments = True
+
+ supports_default_values = False
+ supports_default_metavalue = True
+ supports_empty_insert = False
+ supports_identity_columns = True
+
+ statement_compiler = OracleCompiler
+ ddl_compiler = OracleDDLCompiler
+ type_compiler = OracleTypeCompiler
+ preparer = OracleIdentifierPreparer
+ execution_ctx_cls = OracleExecutionContext
+
+ reflection_options = ("oracle_resolve_synonyms",)
+
+ _use_nchar_for_unicode = False
+
+ construct_arguments = [
+ (
+ sa_schema.Table,
+ {"resolve_synonyms": False, "on_commit": None, "compress": False},
+ ),
+ (sa_schema.Index, {"bitmap": False, "compress": False}),
+ ]
+
+ @util.deprecated_params(
+ use_binds_for_limits=(
+ "1.4",
+ "The ``use_binds_for_limits`` Oracle dialect parameter is "
+ "deprecated. The dialect now renders LIMIT /OFFSET integers "
+ "inline in all cases using a post-compilation hook, so that the "
+ "value is still represented by a 'bound parameter' on the Core "
+ "Expression side.",
+ )
+ )
+ def __init__(
+ self,
+ use_ansi=True,
+ optimize_limits=False,
+ use_binds_for_limits=None,
+ use_nchar_for_unicode=False,
+ exclude_tablespaces=("SYSTEM", "SYSAUX"),
+ **kwargs
+ ):
+ default.DefaultDialect.__init__(self, **kwargs)
+ self._use_nchar_for_unicode = use_nchar_for_unicode
+ self.use_ansi = use_ansi
+ self.optimize_limits = optimize_limits
+ self.exclude_tablespaces = exclude_tablespaces
+
+ def initialize(self, connection):
+ super(OracleDialect, self).initialize(connection)
+
+ self.implicit_returning = self.__dict__.get(
+ "implicit_returning", self.server_version_info > (10,)
+ )
+
+ if self._is_oracle_8:
+ self.colspecs = self.colspecs.copy()
+ self.colspecs.pop(sqltypes.Interval)
+ self.use_ansi = False
+
+ self.supports_identity_columns = self.server_version_info >= (12,)
+
+ def _get_effective_compat_server_version_info(self, connection):
+ # dialect does not need compat levels below 12.2, so don't query
+ # in those cases
+
+ if self.server_version_info < (12, 2):
+ return self.server_version_info
+ try:
+ compat = connection.exec_driver_sql(
+ "SELECT value FROM v$parameter WHERE name = 'compatible'"
+ ).scalar()
+ except exc.DBAPIError:
+ compat = None
+
+ if compat:
+ try:
+ return tuple(int(x) for x in compat.split("."))
+ except:
+ return self.server_version_info
+ else:
+ return self.server_version_info
+
+ @property
+ def _is_oracle_8(self):
+ return self.server_version_info and self.server_version_info < (9,)
+
+ @property
+ def _supports_table_compression(self):
+ return self.server_version_info and self.server_version_info >= (10, 1)
+
+ @property
+ def _supports_table_compress_for(self):
+ return self.server_version_info and self.server_version_info >= (11,)
+
+ @property
+ def _supports_char_length(self):
+ return not self._is_oracle_8
+
+ @property
+ def _supports_update_returning_computed_cols(self):
+        # on version 18 this error is no longer present, while it still
+        # occurs on version 11; it may also work on versions before 18
+ return self.server_version_info and self.server_version_info >= (18,)
+
+ def do_release_savepoint(self, connection, name):
+ # Oracle does not support RELEASE SAVEPOINT
+ pass
+
+ def _check_max_identifier_length(self, connection):
+ if self._get_effective_compat_server_version_info(connection) < (
+ 12,
+ 2,
+ ):
+ return 30
+ else:
+ # use the default
+ return None
+
+ def _check_unicode_returns(self, connection):
+ additional_tests = [
+ expression.cast(
+ expression.literal_column("'test nvarchar2 returns'"),
+ sqltypes.NVARCHAR(60),
+ )
+ ]
+ return super(OracleDialect, self)._check_unicode_returns(
+ connection, additional_tests
+ )
+
+ _isolation_lookup = ["READ COMMITTED", "SERIALIZABLE"]
+
+ def get_isolation_level(self, connection):
+ raise NotImplementedError("implemented by cx_Oracle dialect")
+
+ def get_default_isolation_level(self, dbapi_conn):
+ try:
+ return self.get_isolation_level(dbapi_conn)
+ except NotImplementedError:
+ raise
+ except:
+ return "READ COMMITTED"
+
+ def set_isolation_level(self, connection, level):
+ raise NotImplementedError("implemented by cx_Oracle dialect")
+
+ def has_table(self, connection, table_name, schema=None):
+ self._ensure_has_table_connection(connection)
+
+ if not schema:
+ schema = self.default_schema_name
+
+ cursor = connection.execute(
+ sql.text(
+ "SELECT table_name FROM all_tables "
+ "WHERE table_name = CAST(:name AS VARCHAR2(128)) "
+ "AND owner = CAST(:schema_name AS VARCHAR2(128))"
+ ),
+ dict(
+ name=self.denormalize_name(table_name),
+ schema_name=self.denormalize_name(schema),
+ ),
+ )
+ return cursor.first() is not None
+
+ def has_sequence(self, connection, sequence_name, schema=None):
+ if not schema:
+ schema = self.default_schema_name
+ cursor = connection.execute(
+ sql.text(
+ "SELECT sequence_name FROM all_sequences "
+ "WHERE sequence_name = :name AND "
+ "sequence_owner = :schema_name"
+ ),
+ dict(
+ name=self.denormalize_name(sequence_name),
+ schema_name=self.denormalize_name(schema),
+ ),
+ )
+ return cursor.first() is not None
+
+ def _get_default_schema_name(self, connection):
+ return self.normalize_name(
+ connection.exec_driver_sql(
+ "select sys_context( 'userenv', 'current_schema' ) from dual"
+ ).scalar()
+ )
+
+ def _resolve_synonym(
+ self,
+ connection,
+ desired_owner=None,
+ desired_synonym=None,
+ desired_table=None,
+ ):
+ """search for a local synonym matching the given desired owner/name.
+
+ if desired_owner is None, attempts to locate a distinct owner.
+
+ returns the actual name, owner, dblink name, and synonym name if
+ found.
+ """
+
+ q = (
+ "SELECT owner, table_owner, table_name, db_link, "
+ "synonym_name FROM all_synonyms WHERE "
+ )
+ clauses = []
+ params = {}
+ if desired_synonym:
+ clauses.append(
+ "synonym_name = CAST(:synonym_name AS VARCHAR2(128))"
+ )
+ params["synonym_name"] = desired_synonym
+ if desired_owner:
+ clauses.append("owner = CAST(:desired_owner AS VARCHAR2(128))")
+ params["desired_owner"] = desired_owner
+ if desired_table:
+ clauses.append("table_name = CAST(:tname AS VARCHAR2(128))")
+ params["tname"] = desired_table
+
+ q += " AND ".join(clauses)
+
+ result = connection.execution_options(future_result=True).execute(
+ sql.text(q), params
+ )
+ if desired_owner:
+ row = result.mappings().first()
+ if row:
+ return (
+ row["table_name"],
+ row["table_owner"],
+ row["db_link"],
+ row["synonym_name"],
+ )
+ else:
+ return None, None, None, None
+ else:
+ rows = result.mappings().all()
+ if len(rows) > 1:
+ raise AssertionError(
+ "There are multiple tables visible to the schema, you "
+ "must specify owner"
+ )
+ elif len(rows) == 1:
+ row = rows[0]
+ return (
+ row["table_name"],
+ row["table_owner"],
+ row["db_link"],
+ row["synonym_name"],
+ )
+ else:
+ return None, None, None, None
+
+ @reflection.cache
+ def _prepare_reflection_args(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ resolve_synonyms=False,
+ dblink="",
+ **kw
+ ):
+
+ if resolve_synonyms:
+ actual_name, owner, dblink, synonym = self._resolve_synonym(
+ connection,
+ desired_owner=self.denormalize_name(schema),
+ desired_synonym=self.denormalize_name(table_name),
+ )
+ else:
+ actual_name, owner, dblink, synonym = None, None, None, None
+ if not actual_name:
+ actual_name = self.denormalize_name(table_name)
+
+ if dblink:
+ # using user_db_links here since all_db_links appears
+ # to have more restricted permissions.
+ # https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm
+ # will need to hear from more users if we are doing
+ # the right thing here. See [ticket:2619]
+ owner = connection.scalar(
+ sql.text(
+ "SELECT username FROM user_db_links " "WHERE db_link=:link"
+ ),
+ dict(link=dblink),
+ )
+ dblink = "@" + dblink
+ elif not owner:
+ owner = self.denormalize_name(schema or self.default_schema_name)
+
+ return (actual_name, owner, dblink or "", synonym)
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ s = "SELECT username FROM all_users ORDER BY username"
+ cursor = connection.exec_driver_sql(s)
+ return [self.normalize_name(row[0]) for row in cursor]
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ schema = self.denormalize_name(schema or self.default_schema_name)
+
+ # note that table_names() isn't loading DBLINKed or synonym'ed tables
+ if schema is None:
+ schema = self.default_schema_name
+
+ sql_str = "SELECT table_name FROM all_tables WHERE "
+ if self.exclude_tablespaces:
+ sql_str += (
+ "nvl(tablespace_name, 'no tablespace') "
+ "NOT IN (%s) AND "
+ % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces]))
+ )
+ sql_str += (
+ "OWNER = :owner " "AND IOT_NAME IS NULL " "AND DURATION IS NULL"
+ )
+
+ cursor = connection.execute(sql.text(sql_str), dict(owner=schema))
+ return [self.normalize_name(row[0]) for row in cursor]
+
+ @reflection.cache
+ def get_temp_table_names(self, connection, **kw):
+ schema = self.denormalize_name(self.default_schema_name)
+
+ sql_str = "SELECT table_name FROM all_tables WHERE "
+ if self.exclude_tablespaces:
+ sql_str += (
+ "nvl(tablespace_name, 'no tablespace') "
+ "NOT IN (%s) AND "
+ % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces]))
+ )
+ sql_str += (
+ "OWNER = :owner "
+ "AND IOT_NAME IS NULL "
+ "AND DURATION IS NOT NULL"
+ )
+
+ cursor = connection.execute(sql.text(sql_str), dict(owner=schema))
+ return [self.normalize_name(row[0]) for row in cursor]
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ schema = self.denormalize_name(schema or self.default_schema_name)
+ s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner")
+ cursor = connection.execute(
+ s, dict(owner=self.denormalize_name(schema))
+ )
+ return [self.normalize_name(row[0]) for row in cursor]
+
+ @reflection.cache
+ def get_sequence_names(self, connection, schema=None, **kw):
+ if not schema:
+ schema = self.default_schema_name
+ cursor = connection.execute(
+ sql.text(
+ "SELECT sequence_name FROM all_sequences "
+ "WHERE sequence_owner = :schema_name"
+ ),
+ dict(schema_name=self.denormalize_name(schema)),
+ )
+ return [self.normalize_name(row[0]) for row in cursor]
+
+ @reflection.cache
+ def get_table_options(self, connection, table_name, schema=None, **kw):
+ options = {}
+
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+
+ params = {"table_name": table_name}
+
+ columns = ["table_name"]
+ if self._supports_table_compression:
+ columns.append("compression")
+ if self._supports_table_compress_for:
+ columns.append("compress_for")
+
+ text = (
+ "SELECT %(columns)s "
+ "FROM ALL_TABLES%(dblink)s "
+ "WHERE table_name = CAST(:table_name AS VARCHAR(128))"
+ )
+
+ if schema is not None:
+ params["owner"] = schema
+ text += " AND owner = CAST(:owner AS VARCHAR(128)) "
+ text = text % {"dblink": dblink, "columns": ", ".join(columns)}
+
+ result = connection.execute(sql.text(text), params)
+
+ enabled = dict(DISABLED=False, ENABLED=True)
+
+ row = result.first()
+ if row:
+ if "compression" in row._fields and enabled.get(
+ row.compression, False
+ ):
+ if "compress_for" in row._fields:
+ options["oracle_compress"] = row.compress_for
+ else:
+ options["oracle_compress"] = True
+
+ return options
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ """
+
+ kw arguments can be:
+
+ oracle_resolve_synonyms
+
+ dblink
+
+ """
+
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+ columns = []
+ if self._supports_char_length:
+ char_length_col = "char_length"
+ else:
+ char_length_col = "data_length"
+
+ if self.server_version_info >= (12,):
+ identity_cols = """\
+ col.default_on_null,
+ (
+ SELECT id.generation_type || ',' || id.IDENTITY_OPTIONS
+ FROM ALL_TAB_IDENTITY_COLS%(dblink)s id
+ WHERE col.table_name = id.table_name
+ AND col.column_name = id.column_name
+ AND col.owner = id.owner
+ ) AS identity_options""" % {
+ "dblink": dblink
+ }
+ else:
+ identity_cols = "NULL as default_on_null, NULL as identity_options"
+
+ params = {"table_name": table_name}
+
+ text = """
+ SELECT
+ col.column_name,
+ col.data_type,
+ col.%(char_length_col)s,
+ col.data_precision,
+ col.data_scale,
+ col.nullable,
+ col.data_default,
+ com.comments,
+ col.virtual_column,
+ %(identity_cols)s
+ FROM all_tab_cols%(dblink)s col
+ LEFT JOIN all_col_comments%(dblink)s com
+ ON col.table_name = com.table_name
+ AND col.column_name = com.column_name
+ AND col.owner = com.owner
+ WHERE col.table_name = CAST(:table_name AS VARCHAR2(128))
+ AND col.hidden_column = 'NO'
+ """
+ if schema is not None:
+ params["owner"] = schema
+ text += " AND col.owner = :owner "
+ text += " ORDER BY col.column_id"
+ text = text % {
+ "dblink": dblink,
+ "char_length_col": char_length_col,
+ "identity_cols": identity_cols,
+ }
+
+ c = connection.execute(sql.text(text), params)
+
+ for row in c:
+ colname = self.normalize_name(row[0])
+ orig_colname = row[0]
+ coltype = row[1]
+ length = row[2]
+ precision = row[3]
+ scale = row[4]
+ nullable = row[5] == "Y"
+ default = row[6]
+ comment = row[7]
+ generated = row[8]
+ default_on_nul = row[9]
+ identity_options = row[10]
+
+ if coltype == "NUMBER":
+ if precision is None and scale == 0:
+ coltype = INTEGER()
+ else:
+ coltype = NUMBER(precision, scale)
+ elif coltype == "FLOAT":
+ # TODO: support "precision" here as "binary_precision"
+ coltype = FLOAT()
+ elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR", "NCHAR"):
+ coltype = self.ischema_names.get(coltype)(length)
+ elif "WITH TIME ZONE" in coltype:
+ coltype = TIMESTAMP(timezone=True)
+ else:
+ coltype = re.sub(r"\(\d+\)", "", coltype)
+ try:
+ coltype = self.ischema_names[coltype]
+ except KeyError:
+ util.warn(
+ "Did not recognize type '%s' of column '%s'"
+ % (coltype, colname)
+ )
+ coltype = sqltypes.NULLTYPE
+
+ if generated == "YES":
+ computed = dict(sqltext=default)
+ default = None
+ else:
+ computed = None
+
+ if identity_options is not None:
+ identity = self._parse_identity_options(
+ identity_options, default_on_nul
+ )
+ default = None
+ else:
+ identity = None
+
+ cdict = {
+ "name": colname,
+ "type": coltype,
+ "nullable": nullable,
+ "default": default,
+ "autoincrement": "auto",
+ "comment": comment,
+ }
+ if orig_colname.lower() == orig_colname:
+ cdict["quote"] = True
+ if computed is not None:
+ cdict["computed"] = computed
+ if identity is not None:
+ cdict["identity"] = identity
+
+ columns.append(cdict)
+ return columns
+
+ def _parse_identity_options(self, identity_options, default_on_nul):
+ # identity_options is a string that starts with 'ALWAYS,' or
+ # 'BY DEFAULT,' and continues with
+ # START WITH: 1, INCREMENT BY: 1, MAX_VALUE: 123, MIN_VALUE: 1,
+ # CYCLE_FLAG: N, CACHE_SIZE: 1, ORDER_FLAG: N, SCALE_FLAG: N,
+ # EXTEND_FLAG: N, SESSION_FLAG: N, KEEP_VALUE: N
+ parts = [p.strip() for p in identity_options.split(",")]
+ identity = {
+ "always": parts[0] == "ALWAYS",
+ "on_null": default_on_nul == "YES",
+ }
+
+ for part in parts[1:]:
+ option, value = part.split(":")
+ value = value.strip()
+
+ if "START WITH" in option:
+ identity["start"] = compat.long_type(value)
+ elif "INCREMENT BY" in option:
+ identity["increment"] = compat.long_type(value)
+ elif "MAX_VALUE" in option:
+ identity["maxvalue"] = compat.long_type(value)
+ elif "MIN_VALUE" in option:
+ identity["minvalue"] = compat.long_type(value)
+ elif "CYCLE_FLAG" in option:
+ identity["cycle"] = value == "Y"
+ elif "CACHE_SIZE" in option:
+ identity["cache"] = compat.long_type(value)
+ elif "ORDER_FLAG" in option:
+ identity["order"] = value == "Y"
+ return identity
+
+ @reflection.cache
+ def get_table_comment(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ resolve_synonyms=False,
+ dblink="",
+ **kw
+ ):
+
+ info_cache = kw.get("info_cache")
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+
+ if not schema:
+ schema = self.default_schema_name
+
+ COMMENT_SQL = """
+ SELECT comments
+ FROM all_tab_comments
+ WHERE table_name = CAST(:table_name AS VARCHAR(128))
+ AND owner = CAST(:schema_name AS VARCHAR(128))
+ """
+
+ c = connection.execute(
+ sql.text(COMMENT_SQL),
+ dict(table_name=table_name, schema_name=schema),
+ )
+ return {"text": c.scalar()}
+
+ @reflection.cache
+ def get_indexes(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ resolve_synonyms=False,
+ dblink="",
+ **kw
+ ):
+
+ info_cache = kw.get("info_cache")
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+ indexes = []
+
+ params = {"table_name": table_name}
+ text = (
+ "SELECT a.index_name, a.column_name, "
+ "\nb.index_type, b.uniqueness, b.compression, b.prefix_length "
+ "\nFROM ALL_IND_COLUMNS%(dblink)s a, "
+ "\nALL_INDEXES%(dblink)s b "
+ "\nWHERE "
+ "\na.index_name = b.index_name "
+ "\nAND a.table_owner = b.table_owner "
+ "\nAND a.table_name = b.table_name "
+ "\nAND a.table_name = CAST(:table_name AS VARCHAR(128))"
+ )
+
+ if schema is not None:
+ params["schema"] = schema
+ text += "AND a.table_owner = :schema "
+
+ text += "ORDER BY a.index_name, a.column_position"
+
+ text = text % {"dblink": dblink}
+
+ q = sql.text(text)
+ rp = connection.execute(q, params)
+ indexes = []
+ last_index_name = None
+ pk_constraint = self.get_pk_constraint(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms=resolve_synonyms,
+ dblink=dblink,
+ info_cache=kw.get("info_cache"),
+ )
+
+ uniqueness = dict(NONUNIQUE=False, UNIQUE=True)
+ enabled = dict(DISABLED=False, ENABLED=True)
+
+ oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE)
+
+ index = None
+ for rset in rp:
+ index_name_normalized = self.normalize_name(rset.index_name)
+
+ # skip primary key index. This is refined as of
+            # [ticket:5421]. Note that ALL_INDEXES.GENERATED will be "Y"
+ # if the name of this index was generated by Oracle, however
+ # if a named primary key constraint was created then this flag
+ # is false.
+ if (
+ pk_constraint
+ and index_name_normalized == pk_constraint["name"]
+ ):
+ continue
+
+ if rset.index_name != last_index_name:
+ index = dict(
+ name=index_name_normalized,
+ column_names=[],
+ dialect_options={},
+ )
+ indexes.append(index)
+ index["unique"] = uniqueness.get(rset.uniqueness, False)
+
+ if rset.index_type in ("BITMAP", "FUNCTION-BASED BITMAP"):
+ index["dialect_options"]["oracle_bitmap"] = True
+ if enabled.get(rset.compression, False):
+ index["dialect_options"][
+ "oracle_compress"
+ ] = rset.prefix_length
+
+ # filter out Oracle SYS_NC names. could also do an outer join
+ # to the all_tab_columns table and check for real col names there.
+ if not oracle_sys_col.match(rset.column_name):
+ index["column_names"].append(
+ self.normalize_name(rset.column_name)
+ )
+ last_index_name = rset.index_name
+
+ return indexes
+
+ @reflection.cache
+ def _get_constraint_data(
+ self, connection, table_name, schema=None, dblink="", **kw
+ ):
+
+ params = {"table_name": table_name}
+
+ text = (
+ "SELECT"
+ "\nac.constraint_name," # 0
+ "\nac.constraint_type," # 1
+ "\nloc.column_name AS local_column," # 2
+ "\nrem.table_name AS remote_table," # 3
+ "\nrem.column_name AS remote_column," # 4
+ "\nrem.owner AS remote_owner," # 5
+ "\nloc.position as loc_pos," # 6
+ "\nrem.position as rem_pos," # 7
+ "\nac.search_condition," # 8
+ "\nac.delete_rule" # 9
+ "\nFROM all_constraints%(dblink)s ac,"
+ "\nall_cons_columns%(dblink)s loc,"
+ "\nall_cons_columns%(dblink)s rem"
+ "\nWHERE ac.table_name = CAST(:table_name AS VARCHAR2(128))"
+ "\nAND ac.constraint_type IN ('R','P', 'U', 'C')"
+ )
+
+ if schema is not None:
+ params["owner"] = schema
+ text += "\nAND ac.owner = CAST(:owner AS VARCHAR2(128))"
+
+ text += (
+ "\nAND ac.owner = loc.owner"
+ "\nAND ac.constraint_name = loc.constraint_name"
+ "\nAND ac.r_owner = rem.owner(+)"
+ "\nAND ac.r_constraint_name = rem.constraint_name(+)"
+ "\nAND (rem.position IS NULL or loc.position=rem.position)"
+ "\nORDER BY ac.constraint_name, loc.position"
+ )
+
+ text = text % {"dblink": dblink}
+ rp = connection.execute(sql.text(text), params)
+ constraint_data = rp.fetchall()
+ return constraint_data
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+ pkeys = []
+ constraint_name = None
+ constraint_data = self._get_constraint_data(
+ connection,
+ table_name,
+ schema,
+ dblink,
+ info_cache=kw.get("info_cache"),
+ )
+
+ for row in constraint_data:
+ (
+ cons_name,
+ cons_type,
+ local_column,
+ remote_table,
+ remote_column,
+ remote_owner,
+ ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
+ if cons_type == "P":
+ if constraint_name is None:
+ constraint_name = self.normalize_name(cons_name)
+ pkeys.append(local_column)
+ return {"constrained_columns": pkeys, "name": constraint_name}
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ """
+
+ kw arguments can be:
+
+ oracle_resolve_synonyms
+
+ dblink
+
+ """
+ requested_schema = schema # to check later on
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+
+ constraint_data = self._get_constraint_data(
+ connection,
+ table_name,
+ schema,
+ dblink,
+ info_cache=kw.get("info_cache"),
+ )
+
+ def fkey_rec():
+ return {
+ "name": None,
+ "constrained_columns": [],
+ "referred_schema": None,
+ "referred_table": None,
+ "referred_columns": [],
+ "options": {},
+ }
+
+ fkeys = util.defaultdict(fkey_rec)
+
+ for row in constraint_data:
+ (
+ cons_name,
+ cons_type,
+ local_column,
+ remote_table,
+ remote_column,
+ remote_owner,
+ ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
+
+ cons_name = self.normalize_name(cons_name)
+
+ if cons_type == "R":
+ if remote_table is None:
+ # ticket 363
+ util.warn(
+ (
+ "Got 'None' querying 'table_name' from "
+ "all_cons_columns%(dblink)s - does the user have "
+ "proper rights to the table?"
+ )
+ % {"dblink": dblink}
+ )
+ continue
+
+ rec = fkeys[cons_name]
+ rec["name"] = cons_name
+ local_cols, remote_cols = (
+ rec["constrained_columns"],
+ rec["referred_columns"],
+ )
+
+ if not rec["referred_table"]:
+ if resolve_synonyms:
+ (
+ ref_remote_name,
+ ref_remote_owner,
+ ref_dblink,
+ ref_synonym,
+ ) = self._resolve_synonym(
+ connection,
+ desired_owner=self.denormalize_name(remote_owner),
+ desired_table=self.denormalize_name(remote_table),
+ )
+ if ref_synonym:
+ remote_table = self.normalize_name(ref_synonym)
+ remote_owner = self.normalize_name(
+ ref_remote_owner
+ )
+
+ rec["referred_table"] = remote_table
+
+ if (
+ requested_schema is not None
+ or self.denormalize_name(remote_owner) != schema
+ ):
+ rec["referred_schema"] = remote_owner
+
+ if row[9] != "NO ACTION":
+ rec["options"]["ondelete"] = row[9]
+
+ local_cols.append(local_column)
+ remote_cols.append(remote_column)
+
+ return list(fkeys.values())
+
+ @reflection.cache
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+
+ constraint_data = self._get_constraint_data(
+ connection,
+ table_name,
+ schema,
+ dblink,
+ info_cache=kw.get("info_cache"),
+ )
+
+ unique_keys = filter(lambda x: x[1] == "U", constraint_data)
+ uniques_group = groupby(unique_keys, lambda x: x[0])
+
+ index_names = {
+ ix["name"]
+ for ix in self.get_indexes(connection, table_name, schema=schema)
+ }
+ return [
+ {
+ "name": name,
+ "column_names": cols,
+ "duplicates_index": name if name in index_names else None,
+ }
+ for name, cols in [
+ [
+ self.normalize_name(i[0]),
+ [self.normalize_name(x[2]) for x in i[1]],
+ ]
+ for i in uniques_group
+ ]
+ ]
+
+ @reflection.cache
+ def get_view_definition(
+ self,
+ connection,
+ view_name,
+ schema=None,
+ resolve_synonyms=False,
+ dblink="",
+ **kw
+ ):
+ info_cache = kw.get("info_cache")
+ (view_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ view_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+
+ params = {"view_name": view_name}
+ text = "SELECT text FROM all_views WHERE view_name=:view_name"
+
+ if schema is not None:
+ text += " AND owner = :schema"
+ params["schema"] = schema
+
+ rp = connection.execute(sql.text(text), params).scalar()
+ if rp:
+ if util.py2k:
+ rp = rp.decode(self.encoding)
+ return rp
+ else:
+ return None
+
+ @reflection.cache
+ def get_check_constraints(
+ self, connection, table_name, schema=None, include_all=False, **kw
+ ):
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+
+ constraint_data = self._get_constraint_data(
+ connection,
+ table_name,
+ schema,
+ dblink,
+ info_cache=kw.get("info_cache"),
+ )
+
+ check_constraints = filter(lambda x: x[1] == "C", constraint_data)
+
+ return [
+ {"name": self.normalize_name(cons[0]), "sqltext": cons[8]}
+ for cons in check_constraints
+ if include_all or not re.match(r"..+?. IS NOT NULL$", cons[8])
+ ]
+
+
+class _OuterJoinColumn(sql.ClauseElement):
+ __visit_name__ = "outer_join_column"
+
+ def __init__(self, column):
+ self.column = column
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
new file mode 100644
index 0000000..64029a4
--- /dev/null
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -0,0 +1,1424 @@
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+.. dialect:: oracle+cx_oracle
+ :name: cx-Oracle
+ :dbapi: cx_oracle
+ :connectstring: oracle+cx_oracle://user:pass@hostname:port[/dbname][?service_name=<service>[&key=value&key=value...]]
+ :url: https://oracle.github.io/python-cx_Oracle/
+
+DSN vs. Hostname connections
+-----------------------------
+
+cx_Oracle provides several methods of indicating the target database. The
+dialect translates from a series of different URL forms.
+
+Hostname Connections with Easy Connect Syntax
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Given a hostname, port and service name of the target Oracle Database, for
+example from Oracle's `Easy Connect syntax
+<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#easy-connect-syntax-for-connection-strings>`_,
+then connect in SQLAlchemy using the ``service_name`` query string parameter::
+
+ engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:port/?service_name=myservice&encoding=UTF-8&nencoding=UTF-8")
+
+The `full Easy Connect syntax
+<https://www.oracle.com/pls/topic/lookup?ctx=dblatest&id=GUID-B0437826-43C1-49EC-A94D-B650B6A4A6EE>`_
+is not supported. Instead, use a ``tnsnames.ora`` file and connect using a
+DSN.
+
+Connections with tnsnames.ora or Oracle Cloud
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Alternatively, if no port, database name, or ``service_name`` is provided, the
+dialect will use an Oracle DSN "connection string". This takes the "hostname"
+portion of the URL as the data source name. For example, if the
+``tnsnames.ora`` file contains a `Net Service Name
+<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#net-service-names-for-connection-strings>`_
+of ``myalias`` as below::
+
+ myalias =
+ (DESCRIPTION =
+ (ADDRESS = (PROTOCOL = TCP)(HOST = mymachine.example.com)(PORT = 1521))
+ (CONNECT_DATA =
+ (SERVER = DEDICATED)
+ (SERVICE_NAME = orclpdb1)
+ )
+ )
+
+The cx_Oracle dialect connects to this database service when ``myalias`` is the
+hostname portion of the URL, without specifying a port, database name or
+``service_name``::
+
+ engine = create_engine("oracle+cx_oracle://scott:tiger@myalias/?encoding=UTF-8&nencoding=UTF-8")
+
+Users of Oracle Cloud should use this syntax and also configure the cloud
+wallet as shown in cx_Oracle documentation `Connecting to Autonomous Databases
+<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#connecting-to-autononmous-databases>`_.
+
+SID Connections
+^^^^^^^^^^^^^^^
+
+To use Oracle's obsolete SID connection syntax, the SID can be passed in a
+"database name" portion of the URL as below::
+
+ engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:1521/dbname?encoding=UTF-8&nencoding=UTF-8")
+
+Above, the DSN passed to cx_Oracle is created by ``cx_Oracle.makedsn()`` as
+follows::
+
+ >>> import cx_Oracle
+ >>> cx_Oracle.makedsn("hostname", 1521, sid="dbname")
+ '(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=hostname)(PORT=1521))(CONNECT_DATA=(SID=dbname)))'
+
+Passing cx_Oracle connect arguments
+-----------------------------------
+
+Additional connection arguments can usually be passed via the URL
+query string; particular symbols like ``cx_Oracle.SYSDBA`` are intercepted
+and converted to the correct symbol::
+
+ e = create_engine(
+ "oracle+cx_oracle://user:pass@dsn?encoding=UTF-8&nencoding=UTF-8&mode=SYSDBA&events=true")
+
+.. versionchanged:: 1.3 the cx_oracle dialect now accepts all argument names
+ within the URL string itself, to be passed to the cx_Oracle DBAPI. As
+ was the case earlier but not correctly documented, the
+ :paramref:`_sa.create_engine.connect_args` parameter also accepts all
+ cx_Oracle DBAPI connect arguments.
+
+To pass arguments directly to ``.connect()`` without using the query
+string, use the :paramref:`_sa.create_engine.connect_args` dictionary.
+Any cx_Oracle parameter value and/or constant may be passed, such as::
+
+ import cx_Oracle
+ e = create_engine(
+ "oracle+cx_oracle://user:pass@dsn",
+ connect_args={
+ "encoding": "UTF-8",
+ "nencoding": "UTF-8",
+ "mode": cx_Oracle.SYSDBA,
+ "events": True
+ }
+ )
+
+Note that the default value for ``encoding`` and ``nencoding`` was changed to
+"UTF-8" in cx_Oracle 8.0 so these parameters can be omitted when using that
+version, or later.
+
+Options consumed by the SQLAlchemy cx_Oracle dialect outside of the driver
+--------------------------------------------------------------------------
+
+There are also options that are consumed by the SQLAlchemy cx_oracle dialect
+itself. These options are always passed directly to :func:`_sa.create_engine`,
+such as::
+
+ e = create_engine(
+ "oracle+cx_oracle://user:pass@dsn", coerce_to_unicode=False)
+
+The parameters accepted by the cx_oracle dialect are as follows:
+
+* ``arraysize`` - set the cx_oracle.arraysize value on cursors, defaulted
+ to 50. This setting is significant with cx_Oracle as the contents of LOB
+ objects are only readable within a "live" row (e.g. within a batch of
+ 50 rows).
+
+* ``auto_convert_lobs`` - defaults to True; See :ref:`cx_oracle_lob`.
+
+* ``coerce_to_unicode`` - see :ref:`cx_oracle_unicode` for detail.
+
+* ``coerce_to_decimal`` - see :ref:`cx_oracle_numeric` for detail.
+
+* ``encoding_errors`` - see :ref:`cx_oracle_unicode_encoding_errors` for detail.
+
+.. _cx_oracle_sessionpool:
+
+Using cx_Oracle SessionPool
+---------------------------
+
+The cx_Oracle library provides its own connection pool implementation that may
+be used in place of SQLAlchemy's pooling functionality. This can be achieved
+by using the :paramref:`_sa.create_engine.creator` parameter to provide a
+function that returns a new connection, along with setting
+:paramref:`_sa.create_engine.pool_class` to ``NullPool`` to disable
+SQLAlchemy's pooling::
+
+ import cx_Oracle
+ from sqlalchemy import create_engine
+ from sqlalchemy.pool import NullPool
+
+ pool = cx_Oracle.SessionPool(
+ user="scott", password="tiger", dsn="orclpdb",
+ min=2, max=5, increment=1, threaded=True,
+ encoding="UTF-8", nencoding="UTF-8"
+ )
+
+ engine = create_engine("oracle://", creator=pool.acquire, poolclass=NullPool)
+
+The above engine may then be used normally where cx_Oracle's pool handles
+connection pooling::
+
+ with engine.connect() as conn:
+ print(conn.scalar("select 1 FROM dual"))
+
+
+As well as providing a scalable solution for multi-user applications, the
+cx_Oracle session pool supports some Oracle features such as DRCP and
+`Application Continuity
+<https://cx-oracle.readthedocs.io/en/latest/user_guide/ha.html#application-continuity-ac>`_.
+
+Using Oracle Database Resident Connection Pooling (DRCP)
+--------------------------------------------------------
+
+When using Oracle's `DRCP
+<https://www.oracle.com/pls/topic/lookup?ctx=dblatest&id=GUID-015CA8C1-2386-4626-855D-CC546DDC1086>`_,
+the best practice is to pass a connection class and "purity" when acquiring a
+connection from the SessionPool. Refer to the `cx_Oracle DRCP documentation
+<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#database-resident-connection-pooling-drcp>`_.
+
+This can be achieved by wrapping ``pool.acquire()``::
+
+ import cx_Oracle
+ from sqlalchemy import create_engine
+ from sqlalchemy.pool import NullPool
+
+ pool = cx_Oracle.SessionPool(
+ user="scott", password="tiger", dsn="orclpdb",
+ min=2, max=5, increment=1, threaded=True,
+ encoding="UTF-8", nencoding="UTF-8"
+ )
+
+ def creator():
+ return pool.acquire(cclass="MYCLASS", purity=cx_Oracle.ATTR_PURITY_SELF)
+
+ engine = create_engine("oracle://", creator=creator, poolclass=NullPool)
+
+The above engine may then be used normally where cx_Oracle handles session
+pooling and Oracle Database additionally uses DRCP::
+
+ with engine.connect() as conn:
+ print(conn.scalar("select 1 FROM dual"))
+
+.. _cx_oracle_unicode:
+
+Unicode
+-------
+
+As is the case for all DBAPIs under Python 3, all strings are inherently
+Unicode strings. Under Python 2, cx_Oracle also supports Python Unicode
+objects directly. In all cases however, the driver requires an explicit
+encoding configuration.
+
+Ensuring the Correct Client Encoding
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The long accepted standard for establishing client encoding for nearly all
+Oracle related software is via the `NLS_LANG <https://www.oracle.com/database/technologies/faq-nls-lang.html>`_
+environment variable. cx_Oracle like most other Oracle drivers will use
+this environment variable as the source of its encoding configuration. The
+format of this variable is idiosyncratic; a typical value would be
+``AMERICAN_AMERICA.AL32UTF8``.
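+
+As a process-level sketch (the specific ``NLS_LANG`` value shown is an
+assumption and must match the actual client and database character set
+configuration), the variable can be set from Python before any connection is
+made::
+
+    import os
+
+    # hypothetical value; adjust to the actual client/database configuration
+    os.environ.setdefault("NLS_LANG", "AMERICAN_AMERICA.AL32UTF8")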
+
+The cx_Oracle driver also supports a programmatic alternative which is to
+pass the ``encoding`` and ``nencoding`` parameters directly to its
+``.connect()`` function. These can be present in the URL as follows::
+
+ engine = create_engine("oracle+cx_oracle://scott:tiger@orclpdb/?encoding=UTF-8&nencoding=UTF-8")
+
+For the meaning of the ``encoding`` and ``nencoding`` parameters, please
+consult
+`Characters Sets and National Language Support (NLS) <https://cx-oracle.readthedocs.io/en/latest/user_guide/globalization.html#globalization>`_.
+
+.. seealso::
+
+ `Characters Sets and National Language Support (NLS) <https://cx-oracle.readthedocs.io/en/latest/user_guide/globalization.html#globalization>`_
+ - in the cx_Oracle documentation.
+
+
+Unicode-specific Column datatypes
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The Core expression language handles unicode data by use of the :class:`.Unicode`
+and :class:`.UnicodeText`
+datatypes. These types correspond to the VARCHAR2 and CLOB Oracle datatypes by
+default. When using these datatypes with Unicode data, it is expected that
+the Oracle database is configured with a Unicode-aware character set, as well
+as that the ``NLS_LANG`` environment variable is set appropriately, so that
+the VARCHAR2 and CLOB datatypes can accommodate the data.
+
+In the case that the Oracle database is not configured with a Unicode character
+set, the two options are to use the :class:`_types.NCHAR` and
+:class:`_oracle.NCLOB` datatypes explicitly, or to pass the flag
+``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`,
+which will cause the
+SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` /
+:class:`.UnicodeText` datatypes instead of VARCHAR/CLOB.
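+
+A minimal sketch of the ``use_nchar_for_unicode=True`` approach (the
+connection URL here is an assumption)::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine(
+        "oracle+cx_oracle://scott:tiger@orclpdb/?encoding=UTF-8&nencoding=UTF-8",
+        use_nchar_for_unicode=True,
+    )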
+
+.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText`
+ datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle datatypes
+ unless the ``use_nchar_for_unicode=True`` is passed to the dialect
+ when :func:`_sa.create_engine` is called.
+
+Unicode Coercion of result rows under Python 2
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+When result sets are fetched that include strings, under Python 3 the cx_Oracle
+DBAPI returns all strings as Python Unicode objects, since Python 3 only has a
+Unicode string type. This occurs for data fetched from datatypes such as
+VARCHAR2, CHAR, CLOB, NCHAR, NCLOB, etc. In order to provide
+cross-compatibility under Python 2, the SQLAlchemy cx_Oracle dialect will add
+Unicode-conversion to string data under Python 2 as well. Historically, this
+made use of converters that were supplied by cx_Oracle but were found to be
+non-performant; SQLAlchemy's own converters are used for the string to Unicode
+conversion under Python 2. To disable the Python 2 Unicode conversion for
+VARCHAR2, CHAR, and CLOB, the flag ``coerce_to_unicode=False`` can be passed to
+:func:`_sa.create_engine`.
+
+.. versionchanged:: 1.3 Unicode conversion is applied to all string values
+ by default under python 2. The ``coerce_to_unicode`` now defaults to True
+ and can be set to False to disable the Unicode coercion of strings that are
+ delivered as VARCHAR2/CHAR/CLOB data.
+
+.. _cx_oracle_unicode_encoding_errors:
+
+Encoding Errors
+^^^^^^^^^^^^^^^
+
+For the unusual case that data in the Oracle database is present with a broken
+encoding, the dialect accepts a parameter ``encoding_errors`` which will be
+passed to Unicode decoding functions in order to affect how decoding errors are
+handled. The value is ultimately consumed by the Python `decode
+<https://docs.python.org/3/library/stdtypes.html#bytes.decode>`_ function, and
+is passed both via cx_Oracle's ``encodingErrors`` parameter consumed by
+``Cursor.var()``, as well as SQLAlchemy's own decoding function, as the
+cx_Oracle dialect makes use of both under different circumstances.
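+
+As a brief sketch, an engine that replaces undecodable bytes rather than
+raising (the URL and the ``"replace"`` policy are assumptions; any codec
+error handler accepted by ``bytes.decode()`` may be used)::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine(
+        "oracle+cx_oracle://scott:tiger@dsn",
+        encoding_errors="replace",
+    )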
+
+.. versionadded:: 1.3.11
+
+
+.. _cx_oracle_setinputsizes:
+
+Fine grained control over cx_Oracle data binding performance with setinputsizes
+-------------------------------------------------------------------------------
+
+The cx_Oracle DBAPI has a deep and fundamental reliance upon the usage of the
+DBAPI ``setinputsizes()`` call. The purpose of this call is to establish the
+datatypes that are bound to a SQL statement for Python values being passed as
+parameters. While virtually no other DBAPI assigns any use to the
+``setinputsizes()`` call, the cx_Oracle DBAPI relies upon it heavily in its
+interactions with the Oracle client interface, and in some scenarios it is not
+possible for SQLAlchemy to know exactly how data should be bound, as some
+settings can cause profoundly different performance characteristics, while
+altering the type coercion behavior at the same time.
+
+Users of the cx_Oracle dialect are **strongly encouraged** to read through
+cx_Oracle's list of built-in datatype symbols at
+https://cx-oracle.readthedocs.io/en/latest/api_manual/module.html#database-types.
+Note that in some cases, significant performance degradation can occur when
+using these types vs. not, in particular when specifying ``cx_Oracle.CLOB``.
+
+On the SQLAlchemy side, the :meth:`.DialectEvents.do_setinputsizes` event can
+be used both for runtime visibility (e.g. logging) of the setinputsizes step as
+well as to fully control how ``setinputsizes()`` is used on a per-statement
+basis.
+
+.. versionadded:: 1.2.9 Added :meth:`.DialectEvents.do_setinputsizes`
+
+
+Example 1 - logging all setinputsizes calls
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The following example illustrates how to log the intermediary values from a
+SQLAlchemy perspective before they are converted to the raw ``setinputsizes()``
+parameter dictionary. The keys of the dictionary are :class:`.BindParameter`
+objects which have a ``.key`` and a ``.type`` attribute::
+
+ from sqlalchemy import create_engine, event
+
+ engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe")
+
+ @event.listens_for(engine, "do_setinputsizes")
+ def _log_setinputsizes(inputsizes, cursor, statement, parameters, context):
+ for bindparam, dbapitype in inputsizes.items():
+ log.info(
+ "Bound parameter name: %s SQLAlchemy type: %r "
+ "DBAPI object: %s",
+ bindparam.key, bindparam.type, dbapitype)
+
+Example 2 - remove all bindings to CLOB
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The ``CLOB`` datatype in cx_Oracle incurs a significant performance overhead,
+however is set by default for the ``Text`` type within the SQLAlchemy 1.2
+series. This setting can be modified as follows::
+
+ from sqlalchemy import create_engine, event
+ from cx_Oracle import CLOB
+
+ engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe")
+
+ @event.listens_for(engine, "do_setinputsizes")
+ def _remove_clob(inputsizes, cursor, statement, parameters, context):
+ for bindparam, dbapitype in list(inputsizes.items()):
+ if dbapitype is CLOB:
+ del inputsizes[bindparam]
+
+.. _cx_oracle_returning:
+
+RETURNING Support
+-----------------
+
+The cx_Oracle dialect implements RETURNING using OUT parameters.
+The dialect supports RETURNING fully; however, cx_Oracle 6 or later is
+recommended for complete support.
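+
+As a minimal sketch (the table and column names are hypothetical, and an
+``Engine`` as created in the earlier examples is assumed), RETURNING is
+rendered through the normal Core API and the dialect retrieves the value via
+an OUT parameter::
+
+    from sqlalchemy import column, table
+
+    t = table("mytable", column("id"), column("data"))
+
+    with engine.begin() as conn:
+        new_id = conn.execute(
+            t.insert().values(data="some value").returning(t.c.id)
+        ).scalar()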
+
+.. _cx_oracle_lob:
+
+LOB Objects
+-----------
+
+cx_Oracle returns Oracle LOBs using the ``cx_Oracle.LOB`` object. SQLAlchemy
+converts these to strings within a cx_Oracle outputtypehandler so that the
+interface of the Binary type is consistent with that of other backends.
+
+cx_Oracle prior to version 6 would require that LOB objects be read before
+a new batch of rows would be read, as determined by the ``cursor.arraysize``.
+As of the 6 series, this limitation has been lifted. Nevertheless, because
+SQLAlchemy pre-reads these LOBs up front, this issue is avoided in any case.
+
+To disable the auto "read()" feature of the dialect, the flag
+``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`. Under
+the cx_Oracle 5 series, having this flag turned off means there is the chance
+of reading from a stale LOB object if not read as it is fetched. With
+cx_Oracle 6, this issue is resolved.
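+
+A brief sketch of disabling the conversion (the URL is an assumption); result
+rows will then contain ``cx_Oracle.LOB`` objects rather than strings::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine(
+        "oracle+cx_oracle://scott:tiger@dsn",
+        auto_convert_lobs=False,
+    )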
+
+.. versionchanged:: 1.2 the LOB handling system has been greatly simplified
+ internally to make use of outputtypehandlers, and no longer makes use
+ of alternate "buffered" result set objects.
+
+Two Phase Transactions Not Supported
+-------------------------------------
+
+Two phase transactions are **not supported** under cx_Oracle due to poor
+driver support. As of cx_Oracle 6.0b1, the interface for
+two phase transactions has been changed to be more of a direct pass-through
+to the underlying OCI layer with less automation. The additional logic
+to support this system is not implemented in SQLAlchemy.
+
+.. _cx_oracle_numeric:
+
+Precision Numerics
+------------------
+
+SQLAlchemy's numeric types can handle receiving and returning values as Python
+``Decimal`` objects or float objects. When a :class:`.Numeric` object, or a
+subclass such as :class:`.Float`, :class:`_oracle.DOUBLE_PRECISION` etc. is in
+use, the :paramref:`.Numeric.asdecimal` flag determines if values should be
+coerced to ``Decimal`` upon return, or returned as float objects. To make
+matters more complicated under Oracle, Oracle's ``NUMBER`` type can also
+represent integer values if the "scale" is zero, so the Oracle-specific
+:class:`_oracle.NUMBER` type takes this into account as well.
+
+The cx_Oracle dialect makes extensive use of connection- and cursor-level
+"outputtypehandler" callables in order to coerce numeric values as requested.
+These callables are specific to the flavor of :class:`.Numeric` in use, as
+well as to whether any SQLAlchemy typing objects are present at all.  There
+are observed scenarios where Oracle may send incomplete or ambiguous
+information about the numeric types being returned, such as a query where the
+numeric types are buried under multiple levels of subquery.  The type handlers
+do their best to make the right decision in all cases, deferring to the
+underlying cx_Oracle DBAPI for all those cases where the driver can make the
+best decision.
+
+When no typing objects are present, as when executing plain SQL strings, a
+default "outputtypehandler" is present which will generally return numeric
+values which specify precision and scale as Python ``Decimal`` objects. To
+disable this coercion to decimal for performance reasons, pass the flag
+``coerce_to_decimal=False`` to :func:`_sa.create_engine`::
+
+ engine = create_engine("oracle+cx_oracle://dsn", coerce_to_decimal=False)
+
+The ``coerce_to_decimal`` flag only impacts the results of plain string
+SQL statements that are not otherwise associated with a :class:`.Numeric`
+SQLAlchemy type (or a subclass of such).
+
+.. versionchanged:: 1.2 The numeric handling system for cx_Oracle has been
+ reworked to take advantage of newer cx_Oracle features as well
+ as better integration of outputtypehandlers.
+
+""" # noqa
+
+from __future__ import absolute_import
+
+import decimal
+import random
+import re
+
+from . import base as oracle
+from .base import OracleCompiler
+from .base import OracleDialect
+from .base import OracleExecutionContext
+from ... import exc
+from ... import processors
+from ... import types as sqltypes
+from ... import util
+from ...engine import cursor as _cursor
+from ...util import compat
+
+
+class _OracleInteger(sqltypes.Integer):
+ def get_dbapi_type(self, dbapi):
+ # see https://github.com/oracle/python-cx_Oracle/issues/
+ # 208#issuecomment-409715955
+ return int
+
+ def _cx_oracle_var(self, dialect, cursor):
+ cx_Oracle = dialect.dbapi
+ return cursor.var(
+ cx_Oracle.STRING, 255, arraysize=cursor.arraysize, outconverter=int
+ )
+
+ def _cx_oracle_outputtypehandler(self, dialect):
+ def handler(cursor, name, default_type, size, precision, scale):
+ return self._cx_oracle_var(dialect, cursor)
+
+ return handler
+
+
+class _OracleNumeric(sqltypes.Numeric):
+ is_number = False
+
+ def bind_processor(self, dialect):
+ if self.scale == 0:
+ return None
+ elif self.asdecimal:
+ processor = processors.to_decimal_processor_factory(
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
+
+ def process(value):
+ if isinstance(value, (int, float)):
+ return processor(value)
+ elif value is not None and value.is_infinite():
+ return float(value)
+ else:
+ return value
+
+ return process
+ else:
+ return processors.to_float
+
+ def result_processor(self, dialect, coltype):
+ return None
+
+ def _cx_oracle_outputtypehandler(self, dialect):
+ cx_Oracle = dialect.dbapi
+
+ is_cx_oracle_6 = dialect._is_cx_oracle_6
+
+ def handler(cursor, name, default_type, size, precision, scale):
+ outconverter = None
+
+ if precision:
+ if self.asdecimal:
+ if default_type == cx_Oracle.NATIVE_FLOAT:
+ # receiving float and doing Decimal after the fact
+ # allows for float("inf") to be handled
+ type_ = default_type
+ outconverter = decimal.Decimal
+ elif is_cx_oracle_6:
+ type_ = decimal.Decimal
+ else:
+ type_ = cx_Oracle.STRING
+ outconverter = dialect._to_decimal
+ else:
+ if self.is_number and scale == 0:
+ # integer. cx_Oracle is observed to handle the widest
+ # variety of ints when no directives are passed,
+ # from 5.2 to 7.0. See [ticket:4457]
+ return None
+ else:
+ type_ = cx_Oracle.NATIVE_FLOAT
+
+ else:
+ if self.asdecimal:
+ if default_type == cx_Oracle.NATIVE_FLOAT:
+ type_ = default_type
+ outconverter = decimal.Decimal
+ elif is_cx_oracle_6:
+ type_ = decimal.Decimal
+ else:
+ type_ = cx_Oracle.STRING
+ outconverter = dialect._to_decimal
+ else:
+ if self.is_number and scale == 0:
+ # integer. cx_Oracle is observed to handle the widest
+ # variety of ints when no directives are passed,
+ # from 5.2 to 7.0. See [ticket:4457]
+ return None
+ else:
+ type_ = cx_Oracle.NATIVE_FLOAT
+
+ return cursor.var(
+ type_,
+ 255,
+ arraysize=cursor.arraysize,
+ outconverter=outconverter,
+ )
+
+ return handler
+
+
+class _OracleBinaryFloat(_OracleNumeric):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NATIVE_FLOAT
+
+
+class _OracleBINARY_FLOAT(_OracleBinaryFloat, oracle.BINARY_FLOAT):
+ pass
+
+
+class _OracleBINARY_DOUBLE(_OracleBinaryFloat, oracle.BINARY_DOUBLE):
+ pass
+
+
+class _OracleNUMBER(_OracleNumeric):
+ is_number = True
+
+
+class _OracleDate(sqltypes.Date):
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ if value is not None:
+ return value.date()
+ else:
+ return value
+
+ return process
+
+
+# TODO: the names used across CHAR / VARCHAR / NCHAR / NVARCHAR
+# here are inconsistent and not very good
+class _OracleChar(sqltypes.CHAR):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.FIXED_CHAR
+
+
+class _OracleNChar(sqltypes.NCHAR):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.FIXED_NCHAR
+
+
+class _OracleUnicodeStringNCHAR(oracle.NVARCHAR2):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NCHAR
+
+
+class _OracleUnicodeStringCHAR(sqltypes.Unicode):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.LONG_STRING
+
+
+class _OracleUnicodeTextNCLOB(oracle.NCLOB):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NCLOB
+
+
+class _OracleUnicodeTextCLOB(sqltypes.UnicodeText):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.CLOB
+
+
+class _OracleText(sqltypes.Text):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.CLOB
+
+
+class _OracleLong(oracle.LONG):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.LONG_STRING
+
+
+class _OracleString(sqltypes.String):
+ pass
+
+
+class _OracleEnum(sqltypes.Enum):
+ def bind_processor(self, dialect):
+ enum_proc = sqltypes.Enum.bind_processor(self, dialect)
+
+ def process(value):
+ raw_str = enum_proc(value)
+ return raw_str
+
+ return process
+
+
+class _OracleBinary(sqltypes.LargeBinary):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.BLOB
+
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect, coltype):
+ if not dialect.auto_convert_lobs:
+ return None
+ else:
+ return super(_OracleBinary, self).result_processor(
+ dialect, coltype
+ )
+
+
+class _OracleInterval(oracle.INTERVAL):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTERVAL
+
+
+class _OracleRaw(oracle.RAW):
+ pass
+
+
+class _OracleRowid(oracle.ROWID):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.ROWID
+
+
+class OracleCompiler_cx_oracle(OracleCompiler):
+ _oracle_cx_sql_compiler = True
+
+ def bindparam_string(self, name, **kw):
+ quote = getattr(name, "quote", None)
+ if (
+ quote is True
+ or quote is not False
+ and self.preparer._bindparam_requires_quotes(name)
+ and not kw.get("post_compile", False)
+ ):
+ # interesting to note about expanding parameters - since the
+ # new parameters take the form <paramname>_<int>, at least if
+ # they are originally formed from reserved words, they no longer
+ # need quoting :). names that include illegal characters
+ # won't work however.
+ quoted_name = '"%s"' % name
+ kw["escaped_from"] = name
+ name = quoted_name
+
+ return OracleCompiler.bindparam_string(self, name, **kw)
+
+
+class OracleExecutionContext_cx_oracle(OracleExecutionContext):
+ out_parameters = None
+
+ def _generate_out_parameter_vars(self):
+ # check for has_out_parameters or RETURNING, create cx_Oracle.var
+ # objects if so
+ if self.compiled.returning or self.compiled.has_out_parameters:
+ quoted_bind_names = self.compiled.escaped_bind_names
+ for bindparam in self.compiled.binds.values():
+ if bindparam.isoutparam:
+ name = self.compiled.bind_names[bindparam]
+ type_impl = bindparam.type.dialect_impl(self.dialect)
+
+ if hasattr(type_impl, "_cx_oracle_var"):
+ self.out_parameters[name] = type_impl._cx_oracle_var(
+ self.dialect, self.cursor
+ )
+ else:
+ dbtype = type_impl.get_dbapi_type(self.dialect.dbapi)
+
+ cx_Oracle = self.dialect.dbapi
+
+ if dbtype is None:
+ raise exc.InvalidRequestError(
+ "Cannot create out parameter for "
+ "parameter "
+ "%r - its type %r is not supported by"
+ " cx_oracle" % (bindparam.key, bindparam.type)
+ )
+
+ if compat.py2k and dbtype in (
+ cx_Oracle.CLOB,
+ cx_Oracle.NCLOB,
+ ):
+ outconverter = (
+ processors.to_unicode_processor_factory(
+ self.dialect.encoding,
+ errors=self.dialect.encoding_errors,
+ )
+ )
+ self.out_parameters[name] = self.cursor.var(
+ dbtype,
+ outconverter=lambda value: outconverter(
+ value.read()
+ ),
+ )
+
+ elif dbtype in (
+ cx_Oracle.BLOB,
+ cx_Oracle.CLOB,
+ cx_Oracle.NCLOB,
+ ):
+ self.out_parameters[name] = self.cursor.var(
+ dbtype, outconverter=lambda value: value.read()
+ )
+ elif compat.py2k and isinstance(
+ type_impl, sqltypes.Unicode
+ ):
+ outconverter = (
+ processors.to_unicode_processor_factory(
+ self.dialect.encoding,
+ errors=self.dialect.encoding_errors,
+ )
+ )
+ self.out_parameters[name] = self.cursor.var(
+ dbtype, outconverter=outconverter
+ )
+ else:
+ self.out_parameters[name] = self.cursor.var(dbtype)
+ self.parameters[0][
+ quoted_bind_names.get(name, name)
+ ] = self.out_parameters[name]
+
+ def _generate_cursor_outputtype_handler(self):
+ output_handlers = {}
+
+ for (keyname, name, objects, type_) in self.compiled._result_columns:
+ handler = type_._cached_custom_processor(
+ self.dialect,
+ "cx_oracle_outputtypehandler",
+ self._get_cx_oracle_type_handler,
+ )
+
+ if handler:
+ denormalized_name = self.dialect.denormalize_name(keyname)
+ output_handlers[denormalized_name] = handler
+
+ if output_handlers:
+ default_handler = self._dbapi_connection.outputtypehandler
+
+ def output_type_handler(
+ cursor, name, default_type, size, precision, scale
+ ):
+ if name in output_handlers:
+ return output_handlers[name](
+ cursor, name, default_type, size, precision, scale
+ )
+ else:
+ return default_handler(
+ cursor, name, default_type, size, precision, scale
+ )
+
+ self.cursor.outputtypehandler = output_type_handler
+
+ def _get_cx_oracle_type_handler(self, impl):
+ if hasattr(impl, "_cx_oracle_outputtypehandler"):
+ return impl._cx_oracle_outputtypehandler(self.dialect)
+ else:
+ return None
+
+ def pre_exec(self):
+ if not getattr(self.compiled, "_oracle_cx_sql_compiler", False):
+ return
+
+ self.out_parameters = {}
+
+ self._generate_out_parameter_vars()
+
+ self._generate_cursor_outputtype_handler()
+
+ self.include_set_input_sizes = self.dialect._include_setinputsizes
+
+ def post_exec(self):
+ if self.compiled and self.out_parameters and self.compiled.returning:
+ # create a fake cursor result from the out parameters. unlike
+ # get_out_parameter_values(), the result-row handlers here will be
+ # applied at the Result level
+ returning_params = [
+ self.dialect._returningval(self.out_parameters["ret_%d" % i])
+ for i in range(len(self.out_parameters))
+ ]
+
+ fetch_strategy = _cursor.FullyBufferedCursorFetchStrategy(
+ self.cursor,
+ [
+ (getattr(col, "name", col._anon_name_label), None)
+ for col in self.compiled.returning
+ ],
+ initial_buffer=[tuple(returning_params)],
+ )
+
+ self.cursor_fetch_strategy = fetch_strategy
+
+ def create_cursor(self):
+ c = self._dbapi_connection.cursor()
+ if self.dialect.arraysize:
+ c.arraysize = self.dialect.arraysize
+
+ return c
+
+ def get_out_parameter_values(self, out_param_names):
+ # this method should not be called when the compiler has
+ # RETURNING as we've turned the has_out_parameters flag set to
+ # False.
+ assert not self.compiled.returning
+
+ return [
+ self.dialect._paramval(self.out_parameters[name])
+ for name in out_param_names
+ ]
+
+
+class OracleDialect_cx_oracle(OracleDialect):
+ supports_statement_cache = True
+ execution_ctx_cls = OracleExecutionContext_cx_oracle
+ statement_compiler = OracleCompiler_cx_oracle
+
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = True
+
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+
+ use_setinputsizes = True
+
+ driver = "cx_oracle"
+
+ colspecs = {
+ sqltypes.Numeric: _OracleNumeric,
+ sqltypes.Float: _OracleNumeric,
+ oracle.BINARY_FLOAT: _OracleBINARY_FLOAT,
+ oracle.BINARY_DOUBLE: _OracleBINARY_DOUBLE,
+ sqltypes.Integer: _OracleInteger,
+ oracle.NUMBER: _OracleNUMBER,
+ sqltypes.Date: _OracleDate,
+ sqltypes.LargeBinary: _OracleBinary,
+ sqltypes.Boolean: oracle._OracleBoolean,
+ sqltypes.Interval: _OracleInterval,
+ oracle.INTERVAL: _OracleInterval,
+ sqltypes.Text: _OracleText,
+ sqltypes.String: _OracleString,
+ sqltypes.UnicodeText: _OracleUnicodeTextCLOB,
+ sqltypes.CHAR: _OracleChar,
+ sqltypes.NCHAR: _OracleNChar,
+ sqltypes.Enum: _OracleEnum,
+ oracle.LONG: _OracleLong,
+ oracle.RAW: _OracleRaw,
+ sqltypes.Unicode: _OracleUnicodeStringCHAR,
+ sqltypes.NVARCHAR: _OracleUnicodeStringNCHAR,
+ oracle.NCLOB: _OracleUnicodeTextNCLOB,
+ oracle.ROWID: _OracleRowid,
+ }
+
+ execute_sequence_format = list
+
+ _cx_oracle_threaded = None
+
+ @util.deprecated_params(
+ threaded=(
+ "1.3",
+ "The 'threaded' parameter to the cx_oracle dialect "
+ "is deprecated as a dialect-level argument, and will be removed "
+ "in a future release. As of version 1.3, it defaults to False "
+ "rather than True. The 'threaded' option can be passed to "
+ "cx_Oracle directly in the URL query string passed to "
+ ":func:`_sa.create_engine`.",
+ )
+ )
+ def __init__(
+ self,
+ auto_convert_lobs=True,
+ coerce_to_unicode=True,
+ coerce_to_decimal=True,
+ arraysize=50,
+ encoding_errors=None,
+ threaded=None,
+ **kwargs
+ ):
+
+ OracleDialect.__init__(self, **kwargs)
+ self.arraysize = arraysize
+ self.encoding_errors = encoding_errors
+ if threaded is not None:
+ self._cx_oracle_threaded = threaded
+ self.auto_convert_lobs = auto_convert_lobs
+ self.coerce_to_unicode = coerce_to_unicode
+ self.coerce_to_decimal = coerce_to_decimal
+ if self._use_nchar_for_unicode:
+ self.colspecs = self.colspecs.copy()
+ self.colspecs[sqltypes.Unicode] = _OracleUnicodeStringNCHAR
+ self.colspecs[sqltypes.UnicodeText] = _OracleUnicodeTextNCLOB
+
+ cx_Oracle = self.dbapi
+
+ if cx_Oracle is None:
+ self._include_setinputsizes = {}
+ self.cx_oracle_ver = (0, 0, 0)
+ else:
+ self.cx_oracle_ver = self._parse_cx_oracle_ver(cx_Oracle.version)
+ if self.cx_oracle_ver < (5, 2) and self.cx_oracle_ver > (0, 0, 0):
+ raise exc.InvalidRequestError(
+ "cx_Oracle version 5.2 and above are supported"
+ )
+
+ self._include_setinputsizes = {
+ cx_Oracle.DATETIME,
+ cx_Oracle.NCLOB,
+ cx_Oracle.CLOB,
+ cx_Oracle.LOB,
+ cx_Oracle.NCHAR,
+ cx_Oracle.FIXED_NCHAR,
+ cx_Oracle.BLOB,
+ cx_Oracle.FIXED_CHAR,
+ cx_Oracle.TIMESTAMP,
+ _OracleInteger,
+ _OracleBINARY_FLOAT,
+ _OracleBINARY_DOUBLE,
+ }
+
+ self._paramval = lambda value: value.getvalue()
+
+ # https://github.com/oracle/python-cx_Oracle/issues/176#issuecomment-386821291
+ # https://github.com/oracle/python-cx_Oracle/issues/224
+ self._values_are_lists = self.cx_oracle_ver >= (6, 3)
+ if self._values_are_lists:
+ cx_Oracle.__future__.dml_ret_array_val = True
+
+ def _returningval(value):
+ try:
+ return value.values[0][0]
+ except IndexError:
+ return None
+
+ self._returningval = _returningval
+ else:
+ self._returningval = self._paramval
+
+ self._is_cx_oracle_6 = self.cx_oracle_ver >= (6,)
+
+ @property
+ def _cursor_var_unicode_kwargs(self):
+ if self.encoding_errors:
+ if self.cx_oracle_ver >= (6, 4):
+ return {"encodingErrors": self.encoding_errors}
+ else:
+ util.warn(
+ "cx_oracle version %r does not support encodingErrors"
+ % (self.cx_oracle_ver,)
+ )
+
+ return {}
+
+ def _parse_cx_oracle_ver(self, version):
+ m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
+ if m:
+ return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
+ else:
+ return (0, 0, 0)
+
+ @classmethod
+ def dbapi(cls):
+ import cx_Oracle
+
+ return cx_Oracle
+
+ def initialize(self, connection):
+ super(OracleDialect_cx_oracle, self).initialize(connection)
+ if self._is_oracle_8:
+ self.supports_unicode_binds = False
+
+ self._detect_decimal_char(connection)
+
+ def get_isolation_level(self, connection):
+ # sources:
+
+ # general idea of transaction id, have to start one, etc.
+ # https://stackoverflow.com/questions/10711204/how-to-check-isoloation-level
+
+ # how to decode xid cols from v$transaction to match
+ # https://asktom.oracle.com/pls/apex/f?p=100:11:0::::P11_QUESTION_ID:9532779900346079444
+
+ # Oracle tuple comparison without using IN:
+ # https://www.sql-workbench.eu/comparison/tuple_comparison.html
+
+ with connection.cursor() as cursor:
+ # this is the only way to ensure a transaction is started without
+ # actually running DML. There's no way to see the configured
+ # isolation level without getting it from v$transaction which
+ # means transaction has to be started.
+ outval = cursor.var(str)
+ cursor.execute(
+ """
+ begin
+ :trans_id := dbms_transaction.local_transaction_id( TRUE );
+ end;
+ """,
+ {"trans_id": outval},
+ )
+ trans_id = outval.getvalue()
+ xidusn, xidslot, xidsqn = trans_id.split(".", 2)
+
+ cursor.execute(
+ "SELECT CASE BITAND(t.flag, POWER(2, 28)) "
+ "WHEN 0 THEN 'READ COMMITTED' "
+ "ELSE 'SERIALIZABLE' END AS isolation_level "
+ "FROM v$transaction t WHERE "
+ "(t.xidusn, t.xidslot, t.xidsqn) = "
+ "((:xidusn, :xidslot, :xidsqn))",
+ {"xidusn": xidusn, "xidslot": xidslot, "xidsqn": xidsqn},
+ )
+ row = cursor.fetchone()
+ if row is None:
+ raise exc.InvalidRequestError(
+ "could not retrieve isolation level"
+ )
+ result = row[0]
+
+ return result
+
+ def set_isolation_level(self, connection, level):
+ if hasattr(connection, "dbapi_connection"):
+ dbapi_connection = connection.dbapi_connection
+ else:
+ dbapi_connection = connection
+ if level == "AUTOCOMMIT":
+ dbapi_connection.autocommit = True
+ else:
+ dbapi_connection.autocommit = False
+ connection.rollback()
+ with connection.cursor() as cursor:
+ cursor.execute("ALTER SESSION SET ISOLATION_LEVEL=%s" % level)
+
+ def _detect_decimal_char(self, connection):
+ # we have the option to change this setting upon connect,
+ # or just look at what it is upon connect and convert.
+ # to minimize the chance of interference with changes to
+ # NLS_TERRITORY or formatting behavior of the DB, we opt
+ # to just look at it
+
+ self._decimal_char = connection.exec_driver_sql(
+ "select value from nls_session_parameters "
+ "where parameter = 'NLS_NUMERIC_CHARACTERS'"
+ ).scalar()[0]
+ if self._decimal_char != ".":
+ _detect_decimal = self._detect_decimal
+ _to_decimal = self._to_decimal
+
+ self._detect_decimal = lambda value: _detect_decimal(
+ value.replace(self._decimal_char, ".")
+ )
+ self._to_decimal = lambda value: _to_decimal(
+ value.replace(self._decimal_char, ".")
+ )
+
+ def _detect_decimal(self, value):
+ if "." in value:
+ return self._to_decimal(value)
+ else:
+ return int(value)
+
+ _to_decimal = decimal.Decimal
+
+ def _generate_connection_outputtype_handler(self):
+ """establish the default outputtypehandler established at the
+ connection level.
+
+ """
+
+ dialect = self
+ cx_Oracle = dialect.dbapi
+
+ number_handler = _OracleNUMBER(
+ asdecimal=True
+ )._cx_oracle_outputtypehandler(dialect)
+ float_handler = _OracleNUMBER(
+ asdecimal=False
+ )._cx_oracle_outputtypehandler(dialect)
+
+ def output_type_handler(
+ cursor, name, default_type, size, precision, scale
+ ):
+
+ if (
+ default_type == cx_Oracle.NUMBER
+ and default_type is not cx_Oracle.NATIVE_FLOAT
+ ):
+ if not dialect.coerce_to_decimal:
+ return None
+ elif precision == 0 and scale in (0, -127):
+ # ambiguous type, this occurs when selecting
+ # numbers from deep subqueries
+ return cursor.var(
+ cx_Oracle.STRING,
+ 255,
+ outconverter=dialect._detect_decimal,
+ arraysize=cursor.arraysize,
+ )
+ elif precision and scale > 0:
+ return number_handler(
+ cursor, name, default_type, size, precision, scale
+ )
+ else:
+ return float_handler(
+ cursor, name, default_type, size, precision, scale
+ )
+
+ # allow all strings to come back natively as Unicode
+ elif (
+ dialect.coerce_to_unicode
+ and default_type
+ in (
+ cx_Oracle.STRING,
+ cx_Oracle.FIXED_CHAR,
+ )
+ and default_type is not cx_Oracle.CLOB
+ and default_type is not cx_Oracle.NCLOB
+ ):
+ if compat.py2k:
+ outconverter = processors.to_unicode_processor_factory(
+ dialect.encoding, errors=dialect.encoding_errors
+ )
+ return cursor.var(
+ cx_Oracle.STRING,
+ size,
+ cursor.arraysize,
+ outconverter=outconverter,
+ )
+ else:
+ return cursor.var(
+ util.text_type,
+ size,
+ cursor.arraysize,
+ **dialect._cursor_var_unicode_kwargs
+ )
+
+ elif dialect.auto_convert_lobs and default_type in (
+ cx_Oracle.CLOB,
+ cx_Oracle.NCLOB,
+ ):
+ if compat.py2k:
+ outconverter = processors.to_unicode_processor_factory(
+ dialect.encoding, errors=dialect.encoding_errors
+ )
+ return cursor.var(
+ cx_Oracle.LONG_STRING,
+ size,
+ cursor.arraysize,
+ outconverter=outconverter,
+ )
+ else:
+ return cursor.var(
+ cx_Oracle.LONG_STRING,
+ size,
+ cursor.arraysize,
+ **dialect._cursor_var_unicode_kwargs
+ )
+
+ elif dialect.auto_convert_lobs and default_type in (
+ cx_Oracle.BLOB,
+ ):
+ return cursor.var(
+ cx_Oracle.LONG_BINARY,
+ size,
+ cursor.arraysize,
+ )
+
+ return output_type_handler
+
+ def on_connect(self):
+
+ output_type_handler = self._generate_connection_outputtype_handler()
+
+ def on_connect(conn):
+ conn.outputtypehandler = output_type_handler
+
+ return on_connect
+
+ def create_connect_args(self, url):
+ opts = dict(url.query)
+
+ for opt in ("use_ansi", "auto_convert_lobs"):
+ if opt in opts:
+ util.warn_deprecated(
+ "cx_oracle dialect option %r should only be passed to "
+ "create_engine directly, not within the URL string" % opt,
+ version="1.3",
+ )
+ util.coerce_kw_type(opts, opt, bool)
+ setattr(self, opt, opts.pop(opt))
+
+ database = url.database
+ service_name = opts.pop("service_name", None)
+ if database or service_name:
+ # if we have a database, then we have a remote host
+ port = url.port
+ if port:
+ port = int(port)
+ else:
+ port = 1521
+
+ if database and service_name:
+ raise exc.InvalidRequestError(
+ '"service_name" option shouldn\'t '
+ 'be used with a "database" part of the url'
+ )
+ if database:
+ makedsn_kwargs = {"sid": database}
+ if service_name:
+ makedsn_kwargs = {"service_name": service_name}
+
+ dsn = self.dbapi.makedsn(url.host, port, **makedsn_kwargs)
+ else:
+ # we have a local tnsname
+ dsn = url.host
+
+ if dsn is not None:
+ opts["dsn"] = dsn
+ if url.password is not None:
+ opts["password"] = url.password
+ if url.username is not None:
+ opts["user"] = url.username
+
+ if self._cx_oracle_threaded is not None:
+ opts.setdefault("threaded", self._cx_oracle_threaded)
+
+ def convert_cx_oracle_constant(value):
+ if isinstance(value, util.string_types):
+ try:
+ int_val = int(value)
+ except ValueError:
+ value = value.upper()
+ return getattr(self.dbapi, value)
+ else:
+ return int_val
+ else:
+ return value
+
+ util.coerce_kw_type(opts, "mode", convert_cx_oracle_constant)
+ util.coerce_kw_type(opts, "threaded", bool)
+ util.coerce_kw_type(opts, "events", bool)
+ util.coerce_kw_type(opts, "purity", convert_cx_oracle_constant)
+ return ([], opts)
+
+ def _get_server_version_info(self, connection):
+ return tuple(int(x) for x in connection.connection.version.split("."))
+
+ def is_disconnect(self, e, connection, cursor):
+ (error,) = e.args
+ if isinstance(
+ e, (self.dbapi.InterfaceError, self.dbapi.DatabaseError)
+ ) and "not connected" in str(e):
+ return True
+
+ if hasattr(error, "code") and error.code in {
+ 28,
+ 3114,
+ 3113,
+ 3135,
+ 1033,
+ 2396,
+ }:
+ # ORA-00028: your session has been killed
+ # ORA-03114: not connected to ORACLE
+ # ORA-03113: end-of-file on communication channel
+ # ORA-03135: connection lost contact
+ # ORA-01033: ORACLE initialization or shutdown in progress
+ # ORA-02396: exceeded maximum idle time, please connect again
+ # TODO: Others ?
+ return True
+
+ if re.match(r"^(?:DPI-1010|DPI-1080|DPY-1001|DPY-4011)", str(e)):
+ # DPI-1010: not connected
+ # DPI-1080: connection was closed by ORA-3113
+ # python-oracledb's DPY-1001: not connected to database
+ # python-oracledb's DPY-4011: the database or network closed the
+ # connection
+ # TODO: others?
+ return True
+
+ return False
+
+ def create_xid(self):
+ """create a two-phase transaction ID.
+
+ this id will be passed to do_begin_twophase(), do_rollback_twophase(),
+ do_commit_twophase(). its format is unspecified.
+
+ """
+
+ id_ = random.randint(0, 2 ** 128)
+ return (0x1234, "%032x" % id_, "%032x" % 9)
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ if isinstance(parameters, tuple):
+ parameters = list(parameters)
+ cursor.executemany(statement, parameters)
+
+ def do_begin_twophase(self, connection, xid):
+ connection.connection.begin(*xid)
+ connection.connection.info["cx_oracle_xid"] = xid
+
+ def do_prepare_twophase(self, connection, xid):
+ result = connection.connection.prepare()
+ connection.info["cx_oracle_prepared"] = result
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ self.do_rollback(connection.connection)
+ # TODO: need to end XA state here
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+
+ if not is_prepared:
+ self.do_commit(connection.connection)
+ else:
+ if recover:
+ raise NotImplementedError(
+ "2pc recovery not implemented for cx_Oracle"
+ )
+ oci_prepared = connection.info["cx_oracle_prepared"]
+ if oci_prepared:
+ self.do_commit(connection.connection)
+ # TODO: need to end XA state here
+
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ if self.positional:
+ # not usually used, here to support if someone is modifying
+ # the dialect to use positional style
+ cursor.setinputsizes(
+ *[dbtype for key, dbtype, sqltype in list_of_tuples]
+ )
+ else:
+ collection = (
+ (key, dbtype)
+ for key, dbtype, sqltype in list_of_tuples
+ if dbtype
+ )
+
+ if not self.supports_unicode_binds:
+ # oracle 8 only
+ collection = (
+ (self.dialect._encoder(key)[0], dbtype)
+ for key, dbtype in collection
+ )
+
+ cursor.setinputsizes(**{key: dbtype for key, dbtype in collection})
+
+ def do_recover_twophase(self, connection):
+ raise NotImplementedError(
+ "recover two phase query for cx_Oracle not implemented"
+ )
+
+
+dialect = OracleDialect_cx_oracle
diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py
new file mode 100644
index 0000000..74ad1f2
--- /dev/null
+++ b/lib/sqlalchemy/dialects/oracle/provision.py
@@ -0,0 +1,160 @@
+from ... import create_engine
+from ... import exc
+from ...engine import url as sa_url
+from ...testing.provision import configure_follower
+from ...testing.provision import create_db
+from ...testing.provision import drop_db
+from ...testing.provision import follower_url_from_main
+from ...testing.provision import log
+from ...testing.provision import post_configure_engine
+from ...testing.provision import run_reap_dbs
+from ...testing.provision import set_default_schema_on_connection
+from ...testing.provision import stop_test_class_outside_fixtures
+from ...testing.provision import temp_table_keyword_args
+
+
+@create_db.for_db("oracle")
+def _oracle_create_db(cfg, eng, ident):
+ # NOTE: make sure you've run "ALTER DATABASE default tablespace users" or
+ # similar, so that the default tablespace is not "system"; reflection will
+ # fail otherwise
+ with eng.begin() as conn:
+ conn.exec_driver_sql("create user %s identified by xe" % ident)
+ conn.exec_driver_sql("create user %s_ts1 identified by xe" % ident)
+ conn.exec_driver_sql("create user %s_ts2 identified by xe" % ident)
+ conn.exec_driver_sql("grant dba to %s" % (ident,))
+ conn.exec_driver_sql("grant unlimited tablespace to %s" % ident)
+ conn.exec_driver_sql("grant unlimited tablespace to %s_ts1" % ident)
+ conn.exec_driver_sql("grant unlimited tablespace to %s_ts2" % ident)
+
+
+@configure_follower.for_db("oracle")
+def _oracle_configure_follower(config, ident):
+ config.test_schema = "%s_ts1" % ident
+ config.test_schema_2 = "%s_ts2" % ident
+
+
+def _ora_drop_ignore(conn, dbname):
+ try:
+ conn.exec_driver_sql("drop user %s cascade" % dbname)
+ log.info("Reaped db: %s", dbname)
+ return True
+ except exc.DatabaseError as err:
+ log.warning("couldn't drop db: %s", err)
+ return False
+
+
+@drop_db.for_db("oracle")
+def _oracle_drop_db(cfg, eng, ident):
+ with eng.begin() as conn:
+ # cx_Oracle seems to occasionally leak open connections when a large
+        # suite is run, even if we confirm we have zero references to
+ # connection objects.
+ # while there is a "kill session" command in Oracle,
+ # it unfortunately does not release the connection sufficiently.
+ _ora_drop_ignore(conn, ident)
+ _ora_drop_ignore(conn, "%s_ts1" % ident)
+ _ora_drop_ignore(conn, "%s_ts2" % ident)
+
+
+@stop_test_class_outside_fixtures.for_db("oracle")
+def stop_test_class_outside_fixtures(config, db, cls):
+
+ try:
+ with db.begin() as conn:
+ # run magic command to get rid of identity sequences
+ # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501
+ conn.exec_driver_sql("purge recyclebin")
+ except exc.DatabaseError as err:
+ log.warning("purge recyclebin command failed: %s", err)
+
+ # clear statement cache on all connections that were used
+ # https://github.com/oracle/python-cx_Oracle/issues/519
+
+ for cx_oracle_conn in _all_conns:
+ try:
+ sc = cx_oracle_conn.stmtcachesize
+ except db.dialect.dbapi.InterfaceError:
+ # connection closed
+ pass
+ else:
+ cx_oracle_conn.stmtcachesize = 0
+ cx_oracle_conn.stmtcachesize = sc
+ _all_conns.clear()
+
+
+_all_conns = set()
+
+
+@post_configure_engine.for_db("oracle")
+def _oracle_post_configure_engine(url, engine, follower_ident):
+ from sqlalchemy import event
+
+ @event.listens_for(engine, "checkout")
+ def checkout(dbapi_con, con_record, con_proxy):
+ _all_conns.add(dbapi_con)
+
+ @event.listens_for(engine, "checkin")
+ def checkin(dbapi_connection, connection_record):
+ # work around cx_Oracle issue:
+ # https://github.com/oracle/python-cx_Oracle/issues/530
+ # invalidate oracle connections that had 2pc set up
+ if "cx_oracle_xid" in connection_record.info:
+ connection_record.invalidate()
+
+
+@run_reap_dbs.for_db("oracle")
+def _reap_oracle_dbs(url, idents):
+ log.info("db reaper connecting to %r", url)
+ eng = create_engine(url)
+ with eng.begin() as conn:
+
+ log.info("identifiers in file: %s", ", ".join(idents))
+
+ to_reap = conn.exec_driver_sql(
+ "select u.username from all_users u where username "
+ "like 'TEST_%' and not exists (select username "
+ "from v$session where username=u.username)"
+ )
+ all_names = {username.lower() for (username,) in to_reap}
+ to_drop = set()
+ for name in all_names:
+ if name.endswith("_ts1") or name.endswith("_ts2"):
+ continue
+ elif name in idents:
+ to_drop.add(name)
+ if "%s_ts1" % name in all_names:
+ to_drop.add("%s_ts1" % name)
+ if "%s_ts2" % name in all_names:
+ to_drop.add("%s_ts2" % name)
+
+ dropped = total = 0
+ for total, username in enumerate(to_drop, 1):
+ if _ora_drop_ignore(conn, username):
+ dropped += 1
+ log.info(
+ "Dropped %d out of %d stale databases detected", dropped, total
+ )
+
+
+@follower_url_from_main.for_db("oracle")
+def _oracle_follower_url_from_main(url, ident):
+ url = sa_url.make_url(url)
+ return url.set(username=ident, password="xe")
+
+
+@temp_table_keyword_args.for_db("oracle")
+def _oracle_temp_table_keyword_args(cfg, eng):
+ return {
+ "prefixes": ["GLOBAL TEMPORARY"],
+ "oracle_on_commit": "PRESERVE ROWS",
+ }
+
+
+@set_default_schema_on_connection.for_db("oracle")
+def _oracle_set_default_schema_on_connection(
+ cfg, dbapi_connection, schema_name
+):
+ cursor = dbapi_connection.cursor()
+ cursor.execute("ALTER SESSION SET CURRENT_SCHEMA=%s" % schema_name)
+ cursor.close()
diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py
new file mode 100644
index 0000000..12d9e94
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/__init__.py
@@ -0,0 +1,117 @@
+# postgresql/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from . import base
+from . import pg8000 # noqa
+from . import psycopg2 # noqa
+from . import psycopg2cffi # noqa
+from . import pygresql # noqa
+from . import pypostgresql # noqa
+from .array import All
+from .array import Any
+from .array import ARRAY
+from .array import array
+from .base import BIGINT
+from .base import BIT
+from .base import BOOLEAN
+from .base import BYTEA
+from .base import CHAR
+from .base import CIDR
+from .base import CreateEnumType
+from .base import DATE
+from .base import DOUBLE_PRECISION
+from .base import DropEnumType
+from .base import ENUM
+from .base import FLOAT
+from .base import INET
+from .base import INTEGER
+from .base import INTERVAL
+from .base import MACADDR
+from .base import MONEY
+from .base import NUMERIC
+from .base import OID
+from .base import REAL
+from .base import REGCLASS
+from .base import SMALLINT
+from .base import TEXT
+from .base import TIME
+from .base import TIMESTAMP
+from .base import TSVECTOR
+from .base import UUID
+from .base import VARCHAR
+from .dml import Insert
+from .dml import insert
+from .ext import aggregate_order_by
+from .ext import array_agg
+from .ext import ExcludeConstraint
+from .hstore import HSTORE
+from .hstore import hstore
+from .json import JSON
+from .json import JSONB
+from .ranges import DATERANGE
+from .ranges import INT4RANGE
+from .ranges import INT8RANGE
+from .ranges import NUMRANGE
+from .ranges import TSRANGE
+from .ranges import TSTZRANGE
+from ...util import compat
+
+if compat.py3k:
+ from . import asyncpg # noqa
+
+base.dialect = dialect = psycopg2.dialect
+
+
+__all__ = (
+ "INTEGER",
+ "BIGINT",
+ "SMALLINT",
+ "VARCHAR",
+ "CHAR",
+ "TEXT",
+ "NUMERIC",
+ "FLOAT",
+ "REAL",
+ "INET",
+ "CIDR",
+ "UUID",
+ "BIT",
+ "MACADDR",
+ "MONEY",
+ "OID",
+ "REGCLASS",
+ "DOUBLE_PRECISION",
+ "TIMESTAMP",
+ "TIME",
+ "DATE",
+ "BYTEA",
+ "BOOLEAN",
+ "INTERVAL",
+ "ARRAY",
+ "ENUM",
+ "dialect",
+ "array",
+ "HSTORE",
+ "hstore",
+ "INT4RANGE",
+ "INT8RANGE",
+ "NUMRANGE",
+ "DATERANGE",
+ "TSVECTOR",
+ "TSRANGE",
+ "TSTZRANGE",
+ "JSON",
+ "JSONB",
+ "Any",
+ "All",
+ "DropEnumType",
+ "CreateEnumType",
+ "ExcludeConstraint",
+ "aggregate_order_by",
+ "array_agg",
+ "insert",
+ "Insert",
+)
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py
new file mode 100644
index 0000000..daf7c5d
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/array.py
@@ -0,0 +1,413 @@
+# postgresql/array.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import re
+
+from ... import types as sqltypes
+from ... import util
+from ...sql import coercions
+from ...sql import expression
+from ...sql import operators
+from ...sql import roles
+
+
+def Any(other, arrexpr, operator=operators.eq):
+ """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method.
+ See that method for details.
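+
+ E.g., against a hypothetical ARRAY column ``mytable.c.data``,
+ ``Any(7, mytable.c.data)`` is equivalent to ``mytable.c.data.any(7)``.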
+
+ """
+
+ return arrexpr.any(other, operator)
+
+
+def All(other, arrexpr, operator=operators.eq):
+ """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method.
+ See that method for details.
+
+ """
+
+ return arrexpr.all(other, operator)
+
+
+class array(expression.ClauseList, expression.ColumnElement):
+
+ """A PostgreSQL ARRAY literal.
+
+ This is used to produce ARRAY literals in SQL expressions, e.g.::
+
+ from sqlalchemy.dialects.postgresql import array
+ from sqlalchemy.dialects import postgresql
+ from sqlalchemy import select, func
+
+ stmt = select(array([1,2]) + array([3,4,5]))
+
+ print(stmt.compile(dialect=postgresql.dialect()))
+
+ Produces the SQL::
+
+ SELECT ARRAY[%(param_1)s, %(param_2)s] ||
+ ARRAY[%(param_3)s, %(param_4)s, %(param_5)s] AS anon_1
+
+ An instance of :class:`.array` will always have the datatype
+ :class:`_types.ARRAY`. The "inner" type of the array is inferred from
+ the values present, unless the ``type_`` keyword argument is passed::
+
+ array(['foo', 'bar'], type_=CHAR)
+
+ Multidimensional arrays are produced by nesting :class:`.array` constructs.
+ The dimensionality of the final :class:`_types.ARRAY`
+ type is calculated by
+ recursively adding the dimensions of the inner :class:`_types.ARRAY`
+ type::
+
+ stmt = select(
+ array([
+ array([1, 2]), array([3, 4]), array([column('q'), column('x')])
+ ])
+ )
+ print(stmt.compile(dialect=postgresql.dialect()))
+
+ Produces::
+
+ SELECT ARRAY[ARRAY[%(param_1)s, %(param_2)s],
+ ARRAY[%(param_3)s, %(param_4)s], ARRAY[q, x]] AS anon_1
+
+ .. versionadded:: 1.3.6 added support for multidimensional array literals
+
+ .. seealso::
+
+ :class:`_postgresql.ARRAY`
+
+ """
+
+ __visit_name__ = "array"
+
+ stringify_dialect = "postgresql"
+ inherit_cache = True
+
+ def __init__(self, clauses, **kw):
+ clauses = [
+ coercions.expect(roles.ExpressionElementRole, c) for c in clauses
+ ]
+
+ super(array, self).__init__(*clauses, **kw)
+
+ self._type_tuple = [arg.type for arg in clauses]
+ main_type = kw.pop(
+ "type_",
+ self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE,
+ )
+
+ if isinstance(main_type, ARRAY):
+ self.type = ARRAY(
+ main_type.item_type,
+ dimensions=main_type.dimensions + 1
+ if main_type.dimensions is not None
+ else 2,
+ )
+ else:
+ self.type = ARRAY(main_type)
+
+ @property
+ def _select_iterable(self):
+ return (self,)
+
+ def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
+ if _assume_scalar or operator is operators.getitem:
+ return expression.BindParameter(
+ None,
+ obj,
+ _compared_to_operator=operator,
+ type_=type_,
+ _compared_to_type=self.type,
+ unique=True,
+ )
+
+ else:
+ return array(
+ [
+ self._bind_param(
+ operator, o, _assume_scalar=True, type_=type_
+ )
+ for o in obj
+ ]
+ )
+
+ def self_group(self, against=None):
+ if against in (operators.any_op, operators.all_op, operators.getitem):
+ return expression.Grouping(self)
+ else:
+ return self
+
+
+CONTAINS = operators.custom_op("@>", precedence=5, is_comparison=True)
+
+CONTAINED_BY = operators.custom_op("<@", precedence=5, is_comparison=True)
+
+OVERLAP = operators.custom_op("&&", precedence=5, is_comparison=True)
+
+
+class ARRAY(sqltypes.ARRAY):
+
+ """PostgreSQL ARRAY type.
+
+ .. versionchanged:: 1.1 The :class:`_postgresql.ARRAY` type is now
+ a subclass of the core :class:`_types.ARRAY` type.
+
+ The :class:`_postgresql.ARRAY` type is constructed in the same way
+ as the core :class:`_types.ARRAY` type; a member type is required, and a
+ number of dimensions is recommended if the type is to be used for more
+ than one dimension::
+
+ from sqlalchemy.dialects import postgresql
+
+ mytable = Table("mytable", metadata,
+ Column("data", postgresql.ARRAY(Integer, dimensions=2))
+ )
+
+ The :class:`_postgresql.ARRAY` type provides all operations defined on the
+ core :class:`_types.ARRAY` type, including support for "dimensions",
+ indexed access, and simple matching such as
+ :meth:`.types.ARRAY.Comparator.any` and
+ :meth:`.types.ARRAY.Comparator.all`. The :class:`_postgresql.ARRAY`
+ class also
+ provides PostgreSQL-specific methods for containment operations, including
+ :meth:`.postgresql.ARRAY.Comparator.contains`,
+ :meth:`.postgresql.ARRAY.Comparator.contained_by`, and
+ :meth:`.postgresql.ARRAY.Comparator.overlap`, e.g.::
+
+ mytable.c.data.contains([1, 2])
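+
+ The other containment methods follow the same pattern; an illustrative
+ sketch against the same assumed ``mytable``::
+
+ mytable.c.data.contained_by([1, 2, 3, 4])
+
+ mytable.c.data.overlap([1, 7])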
+
+ The :class:`_postgresql.ARRAY` type may not be supported on all
+ PostgreSQL DBAPIs; it is currently known to work on psycopg2 only.
+
+ Additionally, the :class:`_postgresql.ARRAY`
+ type does not work directly in
+ conjunction with the :class:`.ENUM` type. For a workaround, see the
+ special type at :ref:`postgresql_array_of_enum`.
+
+ .. seealso::
+
+ :class:`_types.ARRAY` - base array type
+
+ :class:`_postgresql.array` - produces a literal array value.
+
+ """
+
+ class Comparator(sqltypes.ARRAY.Comparator):
+
+ """Define comparison operations for :class:`_types.ARRAY`.
+
+ Note that these operations are in addition to those provided
+ by the base :class:`.types.ARRAY.Comparator` class, including
+ :meth:`.types.ARRAY.Comparator.any` and
+ :meth:`.types.ARRAY.Comparator.all`.
+
+ """
+
+ def contains(self, other, **kwargs):
+ """Boolean expression. Test if elements are a superset of the
+ elements of the argument array expression.
+
+ kwargs may be ignored by this operator but are required for API
+ conformance.
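+
+ E.g., against a hypothetical ``mytable`` with an ARRAY column
+ ``data``, this renders the PostgreSQL ``@>`` operator::
+
+ mytable.c.data.contains([1, 2])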
+ """
+ return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
+
+ def contained_by(self, other):
+ """Boolean expression. Test if elements are a proper subset of the
+ elements of the argument array expression.
+ """
+ return self.operate(
+ CONTAINED_BY, other, result_type=sqltypes.Boolean
+ )
+
+ def overlap(self, other):
+ """Boolean expression. Test if array has elements in common with
+ an argument array expression.
+ """
+ return self.operate(OVERLAP, other, result_type=sqltypes.Boolean)
+
+ comparator_factory = Comparator
+
+ def __init__(
+ self, item_type, as_tuple=False, dimensions=None, zero_indexes=False
+ ):
+ """Construct an ARRAY.
+
+ E.g.::
+
+ Column('myarray', ARRAY(Integer))
+
+ Arguments are:
+
+ :param item_type: The data type of items of this array. Note that
+ dimensionality is irrelevant here, so multi-dimensional arrays like
+ ``INTEGER[][]``, are constructed as ``ARRAY(Integer)``, not as
+ ``ARRAY(ARRAY(Integer))`` or such.
+
+ :param as_tuple=False: Specify whether return results
+ should be converted to tuples from lists. DBAPIs such
+ as psycopg2 return lists by default. When tuples are
+ returned, the results are hashable.
+
+ :param dimensions: if non-None, the ARRAY will assume a fixed
+ number of dimensions. This will cause the DDL emitted for this
+ ARRAY to include the exact number of bracket clauses ``[]``,
+ and will also optimize the performance of the type overall.
+ Note that PG arrays are always implicitly "non-dimensioned",
+ meaning they can store any number of dimensions no matter how
+ they were declared.
+
+ :param zero_indexes=False: when True, index values will be converted
+ between Python zero-based and PostgreSQL one-based indexes, e.g.
+ a value of one will be added to all index values before passing
+ to the database.
+
+ .. versionadded:: 0.9.5
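+
+ A rough sketch combining these options (the column name is
+ illustrative only)::
+
+ Column("matrix", ARRAY(Integer, dimensions=2, zero_indexes=True))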
+
+
+ """
+ if isinstance(item_type, ARRAY):
+ raise ValueError(
+ "Do not nest ARRAY types; ARRAY(basetype) "
+ "handles multi-dimensional arrays of basetype"
+ )
+ if isinstance(item_type, type):
+ item_type = item_type()
+ self.item_type = item_type
+ self.as_tuple = as_tuple
+ self.dimensions = dimensions
+ self.zero_indexes = zero_indexes
+
+ @property
+ def hashable(self):
+ return self.as_tuple
+
+ @property
+ def python_type(self):
+ return list
+
+ def compare_values(self, x, y):
+ return x == y
+
+ def _proc_array(self, arr, itemproc, dim, collection):
+ if dim is None:
+ arr = list(arr)
+ if (
+ dim == 1
+ or dim is None
+ and (
+ # this has to be (list, tuple), or at least
+ # not hasattr('__iter__'), since Py3K strings
+ # etc. have __iter__
+ not arr
+ or not isinstance(arr[0], (list, tuple))
+ )
+ ):
+ if itemproc:
+ return collection(itemproc(x) for x in arr)
+ else:
+ return collection(arr)
+ else:
+ return collection(
+ self._proc_array(
+ x,
+ itemproc,
+ dim - 1 if dim is not None else None,
+ collection,
+ )
+ for x in arr
+ )
+
+ @util.memoized_property
+ def _against_native_enum(self):
+ return (
+ isinstance(self.item_type, sqltypes.Enum)
+ and self.item_type.native_enum
+ )
+
+ def bind_expression(self, bindvalue):
+ return bindvalue
+
+ def bind_processor(self, dialect):
+ item_proc = self.item_type.dialect_impl(dialect).bind_processor(
+ dialect
+ )
+
+ def process(value):
+ if value is None:
+ return value
+ else:
+ return self._proc_array(
+ value, item_proc, self.dimensions, list
+ )
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ item_proc = self.item_type.dialect_impl(dialect).result_processor(
+ dialect, coltype
+ )
+
+ def process(value):
+ if value is None:
+ return value
+ else:
+ return self._proc_array(
+ value,
+ item_proc,
+ self.dimensions,
+ tuple if self.as_tuple else list,
+ )
+
+ if self._against_native_enum:
+ super_rp = process
+ pattern = re.compile(r"^{(.*)}$")
+
+ def handle_raw_string(value):
+ inner = pattern.match(value).group(1)
+ return _split_enum_values(inner)
+
+ def process(value):
+ if value is None:
+ return value
+ # isinstance(value, util.string_types) is required to handle
+ # the case where a TypeDecorator for an Array of Enum is
+ # used, as was required in sa < 1.3.17
+ return super_rp(
+ handle_raw_string(value)
+ if isinstance(value, util.string_types)
+ else value
+ )
+
+ return process
+
+
+def _split_enum_values(array_string):
+
+ if '"' not in array_string:
+ # no escape char is present so it can just split on the comma
+ return array_string.split(",") if array_string else []
+
+ # handles quoted strings from:
+ # r'abc,"quoted","also\\\\quoted", "quoted, comma", "esc \" quot", qpr'
+ # returns
+ # ['abc', 'quoted', 'also\\quoted', 'quoted, comma', 'esc " quot', 'qpr']
+ text = array_string.replace(r"\"", "_$ESC_QUOTE$_")
+ text = text.replace(r"\\", "\\")
+ result = []
+ on_quotes = re.split(r'(")', text)
+ in_quotes = False
+ for tok in on_quotes:
+ if tok == '"':
+ in_quotes = not in_quotes
+ elif in_quotes:
+ result.append(tok.replace("_$ESC_QUOTE$_", '"'))
+ else:
+ result.extend(re.findall(r"([^\s,]+),?", tok))
+ return result
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
new file mode 100644
index 0000000..305ad46
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -0,0 +1,1112 @@
+# postgresql/asyncpg.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS
+# file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: postgresql+asyncpg
+ :name: asyncpg
+ :dbapi: asyncpg
+ :connectstring: postgresql+asyncpg://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://magicstack.github.io/asyncpg/
+
+The asyncpg dialect is SQLAlchemy's first Python asyncio dialect.
+
+Using a special asyncio mediation layer, the asyncpg dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname")
+
+The dialect can also be run as a "synchronous" dialect within the
+:func:`_sa.create_engine` function, which will pass "await" calls into
+an ad-hoc event loop. This mode of operation is of **limited use**
+and is for special testing scenarios only. The mode can be enabled by
+adding the SQLAlchemy-specific flag ``async_fallback`` to the URL
+in conjunction with :func:`_sa.create_engine`::
+
+ # for testing purposes only; do not use in production!
+ engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true")
+
+
+.. versionadded:: 1.4
+
+.. note::
+
+ By default asyncpg does not decode the ``json`` and ``jsonb`` types and
+ returns them as strings. SQLAlchemy sets default type decoder for ``json``
+ and ``jsonb`` types using the python builtin ``json.loads`` function.
+ The json implementation used can be changed by setting the attribute
+ ``json_deserializer`` when creating the engine with
+ :func:`create_engine` or :func:`create_async_engine`.
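+
+ For example, a sketch passing an alternative deserializer (``orjson``
+ is assumed to be installed and is illustrative only)::
+
+ import orjson
+
+ engine = create_async_engine(
+ "postgresql+asyncpg://user:pass@hostname/dbname",
+ json_deserializer=orjson.loads,
+ )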
+
+
+.. _asyncpg_prepared_statement_cache:
+
+Prepared Statement Cache
+--------------------------
+
+The asyncpg SQLAlchemy dialect makes use of ``asyncpg.connection.prepare()``
+for all statements. The prepared statement objects are cached after
+construction which appears to grant a 10% or more performance improvement for
+statement invocation. The cache is on a per-DBAPI connection basis, which
+means that the primary storage for prepared statements is within DBAPI
+connections pooled within the connection pool. The size of this cache
+defaults to 100 statements per DBAPI connection and may be adjusted using the
+``prepared_statement_cache_size`` DBAPI argument (note that while this argument
+is implemented by SQLAlchemy, it is part of the DBAPI emulation portion of the
+asyncpg dialect, and is therefore handled as a DBAPI argument, not a dialect
+argument)::
+
+
+ engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500")
+
+To disable the prepared statement cache, use a value of zero::
+
+ engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0")
+
+.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg.
+
+
+.. warning:: The ``asyncpg`` database driver necessarily uses caches for
+ PostgreSQL type OIDs, which become stale when custom PostgreSQL datatypes
+ such as ``ENUM`` objects are changed via DDL operations. Additionally,
+ prepared statements themselves which are optionally cached by SQLAlchemy's
+ driver as described above may also become "stale" when DDL has been emitted
+ to the PostgreSQL database which modifies the tables or other objects
+ involved in a particular prepared statement.
+
+ The SQLAlchemy asyncpg dialect will invalidate these caches within its local
+ process when statements that represent DDL are emitted on a local
+ connection, but this is only controllable within a single Python process /
+ database engine. If DDL changes are made from other database engines
+ and/or processes, a running application may encounter asyncpg exceptions
+ ``InvalidCachedStatementError`` and/or ``InternalServerError("cache lookup
+ failed for type <oid>")`` if it refers to pooled database connections which
+ operated upon the previous structures. The SQLAlchemy asyncpg dialect will
+ recover from these error cases when the driver raises these exceptions by
+ clearing its internal caches as well as those of the asyncpg driver in
+ response to them, but cannot prevent them from being raised in the first
+ place if the cached prepared statement or asyncpg type caches have gone
+ stale, nor can it retry the statement as the PostgreSQL transaction is
+ invalidated when these errors occur.
+
+Disabling the PostgreSQL JIT to improve ENUM datatype handling
+---------------------------------------------------------------
+
+Asyncpg has an `issue <https://github.com/MagicStack/asyncpg/issues/727>`_ when
+using PostgreSQL ENUM datatypes, where upon the creation of new database
+connections, an expensive query may be emitted in order to retrieve metadata
+regarding custom types, which has been shown to negatively affect performance.
+To mitigate this issue, the PostgreSQL "jit" setting may be disabled from the
+client using this setting passed to :func:`_asyncio.create_async_engine`::
+
+ engine = create_async_engine(
+ "postgresql+asyncpg://user:password@localhost/tmp",
+ connect_args={"server_settings": {"jit": "off"}},
+ )
+
+.. seealso::
+
+ https://github.com/MagicStack/asyncpg/issues/727
+
+""" # noqa
+
+import collections
+import decimal
+import json as _py_json
+import re
+import time
+
+from . import json
+from .base import _DECIMAL_TYPES
+from .base import _FLOAT_TYPES
+from .base import _INT_TYPES
+from .base import ENUM
+from .base import INTERVAL
+from .base import OID
+from .base import PGCompiler
+from .base import PGDialect
+from .base import PGExecutionContext
+from .base import PGIdentifierPreparer
+from .base import REGCLASS
+from .base import UUID
+from ... import exc
+from ... import pool
+from ... import processors
+from ... import util
+from ...engine import AdaptedConnection
+from ...sql import sqltypes
+from ...util.concurrency import asyncio
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+try:
+ from uuid import UUID as _python_UUID # noqa
+except ImportError:
+ _python_UUID = None
+
+
+class AsyncpgTime(sqltypes.Time):
+ def get_dbapi_type(self, dbapi):
+ if self.timezone:
+ return dbapi.TIME_W_TZ
+ else:
+ return dbapi.TIME
+
+
+class AsyncpgDate(sqltypes.Date):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.DATE
+
+
+class AsyncpgDateTime(sqltypes.DateTime):
+ def get_dbapi_type(self, dbapi):
+ if self.timezone:
+ return dbapi.TIMESTAMP_W_TZ
+ else:
+ return dbapi.TIMESTAMP
+
+
+class AsyncpgBoolean(sqltypes.Boolean):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.BOOLEAN
+
+
+class AsyncPgInterval(INTERVAL):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTERVAL
+
+ @classmethod
+ def adapt_emulated_to_native(cls, interval, **kw):
+
+ return AsyncPgInterval(precision=interval.second_precision)
+
+
+class AsyncPgEnum(ENUM):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.ENUM
+
+
+class AsyncpgInteger(sqltypes.Integer):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTEGER
+
+
+class AsyncpgBigInteger(sqltypes.BigInteger):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.BIGINTEGER
+
+
+class AsyncpgJSON(json.JSON):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.JSON
+
+ def result_processor(self, dialect, coltype):
+ return None
+
+
+class AsyncpgJSONB(json.JSONB):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.JSONB
+
+ def result_processor(self, dialect, coltype):
+ return None
+
+
+class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType):
+ def get_dbapi_type(self, dbapi):
+ raise NotImplementedError("should not be here")
+
+
+class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTEGER
+
+
+class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.STRING
+
+
+class AsyncpgJSONPathType(json.JSONPathType):
+ def bind_processor(self, dialect):
+ def process(value):
+ assert isinstance(value, util.collections_abc.Sequence)
+ tokens = [util.text_type(elem) for elem in value]
+ return tokens
+
+ return process
+
+
+class AsyncpgUUID(UUID):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.UUID
+
+ def bind_processor(self, dialect):
+ if not self.as_uuid and dialect.use_native_uuid:
+
+ def process(value):
+ if value is not None:
+ value = _python_UUID(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if not self.as_uuid and dialect.use_native_uuid:
+
+ def process(value):
+ if value is not None:
+ value = str(value)
+ return value
+
+ return process
+
+
+class AsyncpgNumeric(sqltypes.Numeric):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
+
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect, coltype):
+ if self.asdecimal:
+ if coltype in _FLOAT_TYPES:
+ return processors.to_decimal_processor_factory(
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ # pg8000 returns Decimal natively for 1700
+ return None
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+ else:
+ if coltype in _FLOAT_TYPES:
+ # pg8000 returns float natively for 701
+ return None
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ return processors.to_float
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+
+
+class AsyncpgFloat(AsyncpgNumeric):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.FLOAT
+
+
+class AsyncpgREGCLASS(REGCLASS):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.STRING
+
+
+class AsyncpgOID(OID):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTEGER
+
+
+class PGExecutionContext_asyncpg(PGExecutionContext):
+ def handle_dbapi_exception(self, e):
+ if isinstance(
+ e,
+ (
+ self.dialect.dbapi.InvalidCachedStatementError,
+ self.dialect.dbapi.InternalServerError,
+ ),
+ ):
+ self.dialect._invalidate_schema_cache()
+
+ def pre_exec(self):
+ if self.isddl:
+ self.dialect._invalidate_schema_cache()
+
+ self.cursor._invalidate_schema_cache_asof = (
+ self.dialect._invalidate_schema_cache_asof
+ )
+
+ if not self.compiled:
+ return
+
+ # we have to exclude ENUM because "enum" is not really a "type"
+ # we can cast to; it has to be the name of the type itself.
+ # for now we just omit it from casting
+ self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM}
+
+ def create_server_side_cursor(self):
+ return self._dbapi_connection.cursor(server_side=True)
+
+
+class PGCompiler_asyncpg(PGCompiler):
+ pass
+
+
+class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer):
+ pass
+
+
+class AsyncAdapt_asyncpg_cursor:
+ __slots__ = (
+ "_adapt_connection",
+ "_connection",
+ "_rows",
+ "description",
+ "arraysize",
+ "rowcount",
+ "_inputsizes",
+ "_cursor",
+ "_invalidate_schema_cache_asof",
+ )
+
+ server_side = False
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self._rows = []
+ self._cursor = None
+ self.description = None
+ self.arraysize = 1
+ self.rowcount = -1
+ self._inputsizes = None
+ self._invalidate_schema_cache_asof = 0
+
+ def close(self):
+ self._rows[:] = []
+
+ def _handle_exception(self, error):
+ self._adapt_connection._handle_exception(error)
+
+ def _parameter_placeholders(self, params):
+ if not self._inputsizes:
+ return tuple("$%d" % idx for idx, _ in enumerate(params, 1))
+ else:
+ return tuple(
+ "$%d::%s" % (idx, typ) if typ else "$%d" % idx
+ for idx, typ in enumerate(
+ (_pg_types.get(typ) for typ in self._inputsizes), 1
+ )
+ )
+
+ async def _prepare_and_execute(self, operation, parameters):
+ adapt_connection = self._adapt_connection
+
+ async with adapt_connection._execute_mutex:
+
+ if not adapt_connection._started:
+ await adapt_connection._start_transaction()
+
+ if parameters is not None:
+ operation = operation % self._parameter_placeholders(
+ parameters
+ )
+ else:
+ parameters = ()
+
+ try:
+ prepared_stmt, attributes = await adapt_connection._prepare(
+ operation, self._invalidate_schema_cache_asof
+ )
+
+ if attributes:
+ self.description = [
+ (
+ attr.name,
+ attr.type.oid,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+ for attr in attributes
+ ]
+ else:
+ self.description = None
+
+ if self.server_side:
+ self._cursor = await prepared_stmt.cursor(*parameters)
+ self.rowcount = -1
+ else:
+ self._rows = await prepared_stmt.fetch(*parameters)
+ status = prepared_stmt.get_statusmsg()
+
+ reg = re.match(
+ r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status
+ )
+ if reg:
+ self.rowcount = int(reg.group(1))
+ else:
+ self.rowcount = -1
+
+ except Exception as error:
+ self._handle_exception(error)
+
+ async def _executemany(self, operation, seq_of_parameters):
+ adapt_connection = self._adapt_connection
+
+ async with adapt_connection._execute_mutex:
+ await adapt_connection._check_type_cache_invalidation(
+ self._invalidate_schema_cache_asof
+ )
+
+ if not adapt_connection._started:
+ await adapt_connection._start_transaction()
+
+ operation = operation % self._parameter_placeholders(
+ seq_of_parameters[0]
+ )
+
+ try:
+ return await self._connection.executemany(
+ operation, seq_of_parameters
+ )
+ except Exception as error:
+ self._handle_exception(error)
+
+ def execute(self, operation, parameters=None):
+ self._adapt_connection.await_(
+ self._prepare_and_execute(operation, parameters)
+ )
+
+ def executemany(self, operation, seq_of_parameters):
+ return self._adapt_connection.await_(
+ self._executemany(operation, seq_of_parameters)
+ )
+
+ def setinputsizes(self, *inputsizes):
+ self._inputsizes = inputsizes
+
+ def __iter__(self):
+ while self._rows:
+ yield self._rows.pop(0)
+
+ def fetchone(self):
+ if self._rows:
+ return self._rows.pop(0)
+ else:
+ return None
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+
+ retval = self._rows[0:size]
+ self._rows[:] = self._rows[size:]
+ return retval
+
+ def fetchall(self):
+ retval = self._rows[:]
+ self._rows[:] = []
+ return retval
+
+
+class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
+
+ server_side = True
+ __slots__ = ("_rowbuffer",)
+
+ def __init__(self, adapt_connection):
+ super(AsyncAdapt_asyncpg_ss_cursor, self).__init__(adapt_connection)
+ self._rowbuffer = None
+
+ def close(self):
+ self._cursor = None
+ self._rowbuffer = None
+
+ def _buffer_rows(self):
+ new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
+ self._rowbuffer = collections.deque(new_rows)
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ if not self._rowbuffer:
+ self._buffer_rows()
+
+ while True:
+ while self._rowbuffer:
+ yield self._rowbuffer.popleft()
+
+ self._buffer_rows()
+ if not self._rowbuffer:
+ break
+
+ def fetchone(self):
+ if not self._rowbuffer:
+ self._buffer_rows()
+ if not self._rowbuffer:
+ return None
+ return self._rowbuffer.popleft()
+
+ def fetchmany(self, size=None):
+ if size is None:
+ return self.fetchall()
+
+ if not self._rowbuffer:
+ self._buffer_rows()
+
+ buf = list(self._rowbuffer)
+ lb = len(buf)
+ if size > lb:
+ buf.extend(
+ self._adapt_connection.await_(self._cursor.fetch(size - lb))
+ )
+
+ result = buf[0:size]
+ self._rowbuffer = collections.deque(buf[size:])
+ return result
+
+ def fetchall(self):
+ ret = list(self._rowbuffer) + list(
+ self._adapt_connection.await_(self._all())
+ )
+ self._rowbuffer.clear()
+ return ret
+
+ async def _all(self):
+ rows = []
+
+ # TODO: looks like we have to hand-roll some kind of batching here.
+ # hardcoding for the moment but this should be improved.
+ while True:
+ batch = await self._cursor.fetch(1000)
+ if batch:
+ rows.extend(batch)
+ continue
+ else:
+ break
+ return rows
+
+ def executemany(self, operation, seq_of_parameters):
+ raise NotImplementedError(
+ "server side cursor doesn't support executemany yet"
+ )
+
+
+class AsyncAdapt_asyncpg_connection(AdaptedConnection):
+ __slots__ = (
+ "dbapi",
+ "_connection",
+ "isolation_level",
+ "_isolation_setting",
+ "readonly",
+ "deferrable",
+ "_transaction",
+ "_started",
+ "_prepared_statement_cache",
+ "_invalidate_schema_cache_asof",
+ "_execute_mutex",
+ )
+
+ await_ = staticmethod(await_only)
+
+ def __init__(self, dbapi, connection, prepared_statement_cache_size=100):
+ self.dbapi = dbapi
+ self._connection = connection
+ self.isolation_level = self._isolation_setting = "read_committed"
+ self.readonly = False
+ self.deferrable = False
+ self._transaction = None
+ self._started = False
+ self._invalidate_schema_cache_asof = time.time()
+ self._execute_mutex = asyncio.Lock()
+
+ if prepared_statement_cache_size:
+ self._prepared_statement_cache = util.LRUCache(
+ prepared_statement_cache_size
+ )
+ else:
+ self._prepared_statement_cache = None
+
+ async def _check_type_cache_invalidation(self, invalidate_timestamp):
+ if invalidate_timestamp > self._invalidate_schema_cache_asof:
+ await self._connection.reload_schema_state()
+ self._invalidate_schema_cache_asof = invalidate_timestamp
+
+ async def _prepare(self, operation, invalidate_timestamp):
+ await self._check_type_cache_invalidation(invalidate_timestamp)
+
+ cache = self._prepared_statement_cache
+ if cache is None:
+ prepared_stmt = await self._connection.prepare(operation)
+ attributes = prepared_stmt.get_attributes()
+ return prepared_stmt, attributes
+
+ # asyncpg uses a type cache for the "attributes" which seems to go
+ # stale independently of the PreparedStatement itself, so place that
+ # collection in the cache as well.
+ if operation in cache:
+ prepared_stmt, attributes, cached_timestamp = cache[operation]
+
+ # prepared statements themselves also go stale for certain DDL
+ # changes such as size of a VARCHAR changing, so there is also
+ # a cross-connection invalidation timestamp
+ if cached_timestamp > invalidate_timestamp:
+ return prepared_stmt, attributes
+
+ prepared_stmt = await self._connection.prepare(operation)
+ attributes = prepared_stmt.get_attributes()
+ cache[operation] = (prepared_stmt, attributes, time.time())
+
+ return prepared_stmt, attributes
+
+ def _handle_exception(self, error):
+ if self._connection.is_closed():
+ self._transaction = None
+ self._started = False
+
+ if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
+ exception_mapping = self.dbapi._asyncpg_error_translate
+
+ for super_ in type(error).__mro__:
+ if super_ in exception_mapping:
+ translated_error = exception_mapping[super_](
+ "%s: %s" % (type(error), error)
+ )
+ translated_error.pgcode = (
+ translated_error.sqlstate
+ ) = getattr(error, "sqlstate", None)
+ raise translated_error from error
+ else:
+ raise error
+ else:
+ raise error
+
+ @property
+ def autocommit(self):
+ return self.isolation_level == "autocommit"
+
+ @autocommit.setter
+ def autocommit(self, value):
+ if value:
+ self.isolation_level = "autocommit"
+ else:
+ self.isolation_level = self._isolation_setting
+
+ def set_isolation_level(self, level):
+ if self._started:
+ self.rollback()
+ self.isolation_level = self._isolation_setting = level
+
+ async def _start_transaction(self):
+ if self.isolation_level == "autocommit":
+ return
+
+ try:
+ self._transaction = self._connection.transaction(
+ isolation=self.isolation_level,
+ readonly=self.readonly,
+ deferrable=self.deferrable,
+ )
+ await self._transaction.start()
+ except Exception as error:
+ self._handle_exception(error)
+ else:
+ self._started = True
+
+ def cursor(self, server_side=False):
+ if server_side:
+ return AsyncAdapt_asyncpg_ss_cursor(self)
+ else:
+ return AsyncAdapt_asyncpg_cursor(self)
+
+ def rollback(self):
+ if self._started:
+ try:
+ self.await_(self._transaction.rollback())
+ except Exception as error:
+ self._handle_exception(error)
+ finally:
+ self._transaction = None
+ self._started = False
+
+ def commit(self):
+ if self._started:
+ try:
+ self.await_(self._transaction.commit())
+ except Exception as error:
+ self._handle_exception(error)
+ finally:
+ self._transaction = None
+ self._started = False
+
+ def close(self):
+ self.rollback()
+
+ self.await_(self._connection.close())
+
+
+class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection):
+ __slots__ = ()
+
+ await_ = staticmethod(await_fallback)
+
+
+class AsyncAdapt_asyncpg_dbapi:
+ def __init__(self, asyncpg):
+ self.asyncpg = asyncpg
+ self.paramstyle = "format"
+
+ def connect(self, *arg, **kw):
+ async_fallback = kw.pop("async_fallback", False)
+ prepared_statement_cache_size = kw.pop(
+ "prepared_statement_cache_size", 100
+ )
+ if util.asbool(async_fallback):
+ return AsyncAdaptFallback_asyncpg_connection(
+ self,
+ await_fallback(self.asyncpg.connect(*arg, **kw)),
+ prepared_statement_cache_size=prepared_statement_cache_size,
+ )
+ else:
+ return AsyncAdapt_asyncpg_connection(
+ self,
+ await_only(self.asyncpg.connect(*arg, **kw)),
+ prepared_statement_cache_size=prepared_statement_cache_size,
+ )
+
+ class Error(Exception):
+ pass
+
+ class Warning(Exception): # noqa
+ pass
+
+ class InterfaceError(Error):
+ pass
+
+ class DatabaseError(Error):
+ pass
+
+ class InternalError(DatabaseError):
+ pass
+
+ class OperationalError(DatabaseError):
+ pass
+
+ class ProgrammingError(DatabaseError):
+ pass
+
+ class IntegrityError(DatabaseError):
+ pass
+
+ class DataError(DatabaseError):
+ pass
+
+ class NotSupportedError(DatabaseError):
+ pass
+
+ class InternalServerError(InternalError):
+ pass
+
+ class InvalidCachedStatementError(NotSupportedError):
+ def __init__(self, message):
+ super(
+ AsyncAdapt_asyncpg_dbapi.InvalidCachedStatementError, self
+ ).__init__(
+ message + " (SQLAlchemy asyncpg dialect will now invalidate "
+ "all prepared caches in response to this exception)",
+ )
+
+ @util.memoized_property
+ def _asyncpg_error_translate(self):
+ import asyncpg
+
+ return {
+ asyncpg.exceptions.IntegrityConstraintViolationError: self.IntegrityError, # noqa: E501
+ asyncpg.exceptions.PostgresError: self.Error,
+ asyncpg.exceptions.SyntaxOrAccessError: self.ProgrammingError,
+ asyncpg.exceptions.InterfaceError: self.InterfaceError,
+ asyncpg.exceptions.InvalidCachedStatementError: self.InvalidCachedStatementError, # noqa: E501
+ asyncpg.exceptions.InternalServerError: self.InternalServerError,
+ }
+
+ def Binary(self, value):
+ return value
+
+ STRING = util.symbol("STRING")
+ TIMESTAMP = util.symbol("TIMESTAMP")
+ TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ")
+ TIME = util.symbol("TIME")
+ TIME_W_TZ = util.symbol("TIME_W_TZ")
+ DATE = util.symbol("DATE")
+ INTERVAL = util.symbol("INTERVAL")
+ NUMBER = util.symbol("NUMBER")
+ FLOAT = util.symbol("FLOAT")
+ BOOLEAN = util.symbol("BOOLEAN")
+ INTEGER = util.symbol("INTEGER")
+ BIGINTEGER = util.symbol("BIGINTEGER")
+ BYTES = util.symbol("BYTES")
+ DECIMAL = util.symbol("DECIMAL")
+ JSON = util.symbol("JSON")
+ JSONB = util.symbol("JSONB")
+ ENUM = util.symbol("ENUM")
+ UUID = util.symbol("UUID")
+ BYTEA = util.symbol("BYTEA")
+
+ DATETIME = TIMESTAMP
+ BINARY = BYTEA
+
+
+_pg_types = {
+ AsyncAdapt_asyncpg_dbapi.STRING: "varchar",
+ AsyncAdapt_asyncpg_dbapi.TIMESTAMP: "timestamp",
+ AsyncAdapt_asyncpg_dbapi.TIMESTAMP_W_TZ: "timestamp with time zone",
+ AsyncAdapt_asyncpg_dbapi.DATE: "date",
+ AsyncAdapt_asyncpg_dbapi.TIME: "time",
+ AsyncAdapt_asyncpg_dbapi.TIME_W_TZ: "time with time zone",
+ AsyncAdapt_asyncpg_dbapi.INTERVAL: "interval",
+ AsyncAdapt_asyncpg_dbapi.NUMBER: "numeric",
+ AsyncAdapt_asyncpg_dbapi.FLOAT: "float",
+ AsyncAdapt_asyncpg_dbapi.BOOLEAN: "bool",
+ AsyncAdapt_asyncpg_dbapi.INTEGER: "integer",
+ AsyncAdapt_asyncpg_dbapi.BIGINTEGER: "bigint",
+ AsyncAdapt_asyncpg_dbapi.BYTES: "bytes",
+ AsyncAdapt_asyncpg_dbapi.DECIMAL: "decimal",
+ AsyncAdapt_asyncpg_dbapi.JSON: "json",
+ AsyncAdapt_asyncpg_dbapi.JSONB: "jsonb",
+ AsyncAdapt_asyncpg_dbapi.ENUM: "enum",
+ AsyncAdapt_asyncpg_dbapi.UUID: "uuid",
+ AsyncAdapt_asyncpg_dbapi.BYTEA: "bytea",
+}
+
+
+class PGDialect_asyncpg(PGDialect):
+ driver = "asyncpg"
+ supports_statement_cache = True
+
+ supports_unicode_statements = True
+ supports_server_side_cursors = True
+
+ supports_unicode_binds = True
+
+ default_paramstyle = "format"
+ supports_sane_multi_rowcount = False
+ execution_ctx_cls = PGExecutionContext_asyncpg
+ statement_compiler = PGCompiler_asyncpg
+ preparer = PGIdentifierPreparer_asyncpg
+
+ use_setinputsizes = True
+
+ use_native_uuid = True
+
+ colspecs = util.update_copy(
+ PGDialect.colspecs,
+ {
+ sqltypes.Time: AsyncpgTime,
+ sqltypes.Date: AsyncpgDate,
+ sqltypes.DateTime: AsyncpgDateTime,
+ sqltypes.Interval: AsyncPgInterval,
+ INTERVAL: AsyncPgInterval,
+ UUID: AsyncpgUUID,
+ sqltypes.Boolean: AsyncpgBoolean,
+ sqltypes.Integer: AsyncpgInteger,
+ sqltypes.BigInteger: AsyncpgBigInteger,
+ sqltypes.Numeric: AsyncpgNumeric,
+ sqltypes.Float: AsyncpgFloat,
+ sqltypes.JSON: AsyncpgJSON,
+ json.JSONB: AsyncpgJSONB,
+ sqltypes.JSON.JSONPathType: AsyncpgJSONPathType,
+ sqltypes.JSON.JSONIndexType: AsyncpgJSONIndexType,
+ sqltypes.JSON.JSONIntIndexType: AsyncpgJSONIntIndexType,
+ sqltypes.JSON.JSONStrIndexType: AsyncpgJSONStrIndexType,
+ sqltypes.Enum: AsyncPgEnum,
+ OID: AsyncpgOID,
+ REGCLASS: AsyncpgREGCLASS,
+ },
+ )
+ is_async = True
+ _invalidate_schema_cache_asof = 0
+
+ def _invalidate_schema_cache(self):
+ self._invalidate_schema_cache_asof = time.time()
+
+ @util.memoized_property
+ def _dbapi_version(self):
+ if self.dbapi and hasattr(self.dbapi, "__version__"):
+ return tuple(
+ [
+ int(x)
+ for x in re.findall(
+ r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
+ )
+ ]
+ )
+ else:
+ return (99, 99, 99)
+
+ @classmethod
+ def dbapi(cls):
+ return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg"))
+
+ @util.memoized_property
+ def _isolation_lookup(self):
+ return {
+ "AUTOCOMMIT": "autocommit",
+ "READ COMMITTED": "read_committed",
+ "REPEATABLE READ": "repeatable_read",
+ "SERIALIZABLE": "serializable",
+ }
+
+ def set_isolation_level(self, connection, level):
+ try:
+ level = self._isolation_lookup[level.replace("_", " ")]
+ except KeyError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
+ ),
+ replace_context=err,
+ )
+
+ connection.set_isolation_level(level)
+
+ def set_readonly(self, connection, value):
+ connection.readonly = value
+
+ def get_readonly(self, connection):
+ return connection.readonly
+
+ def set_deferrable(self, connection, value):
+ connection.deferrable = value
+
+ def get_deferrable(self, connection):
+ return connection.deferrable
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+
+ opts.update(url.query)
+ util.coerce_kw_type(opts, "prepared_statement_cache_size", int)
+ util.coerce_kw_type(opts, "port", int)
+ return ([], opts)
+
+ @classmethod
+ def get_pool_class(cls, url):
+
+ async_fallback = url.query.get("async_fallback", False)
+
+ if util.asbool(async_fallback):
+ return pool.FallbackAsyncAdaptedQueuePool
+ else:
+ return pool.AsyncAdaptedQueuePool
+
+ def is_disconnect(self, e, connection, cursor):
+ if connection:
+ return connection._connection.is_closed()
+ else:
+ return isinstance(
+ e, self.dbapi.InterfaceError
+ ) and "connection is closed" in str(e)
+
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ if self.positional:
+ cursor.setinputsizes(
+ *[dbtype for key, dbtype, sqltype in list_of_tuples]
+ )
+ else:
+ cursor.setinputsizes(
+ **{
+ key: dbtype
+ for key, dbtype, sqltype in list_of_tuples
+ if dbtype
+ }
+ )
+
+ async def setup_asyncpg_json_codec(self, conn):
+ """set up JSON codec for asyncpg.
+
+ This occurs for all new connections and
+ can be overridden by third party dialects.
+
+ .. versionadded:: 1.4.27
+
+ """
+
+ asyncpg_connection = conn._connection
+ deserializer = self._json_deserializer or _py_json.loads
+
+ def _json_decoder(bin_value):
+ return deserializer(bin_value.decode())
+
+ await asyncpg_connection.set_type_codec(
+ "json",
+ encoder=str.encode,
+ decoder=_json_decoder,
+ schema="pg_catalog",
+ format="binary",
+ )
+
+ async def setup_asyncpg_jsonb_codec(self, conn):
+ """set up JSONB codec for asyncpg.
+
+ This occurs for all new connections and
+ can be overridden by third party dialects.
+
+ .. versionadded:: 1.4.27
+
+ """
+
+ asyncpg_connection = conn._connection
+ deserializer = self._json_deserializer or _py_json.loads
+
+ def _jsonb_encoder(str_value):
+ # \x01 is the prefix for jsonb used by PostgreSQL.
+ # asyncpg requires it when format='binary'
+ return b"\x01" + str_value.encode()
+
+ deserializer = self._json_deserializer or _py_json.loads
+
+ def _jsonb_decoder(bin_value):
+ # the byte is the \x01 prefix for jsonb used by PostgreSQL.
+ # asyncpg returns it when format='binary'
+ return deserializer(bin_value[1:].decode())
+
+ await asyncpg_connection.set_type_codec(
+ "jsonb",
+ encoder=_jsonb_encoder,
+ decoder=_jsonb_decoder,
+ schema="pg_catalog",
+ format="binary",
+ )
+
+ def on_connect(self):
+ """on_connect for asyncpg
+
+ A major component of this for asyncpg is to set up type decoders at the
+ asyncpg level.
+
+ See https://github.com/MagicStack/asyncpg/issues/623 for
+ notes on JSON/JSONB implementation.
+
+ """
+
+ super_connect = super(PGDialect_asyncpg, self).on_connect()
+
+ def connect(conn):
+ conn.await_(self.setup_asyncpg_json_codec(conn))
+ conn.await_(self.setup_asyncpg_jsonb_codec(conn))
+ if super_connect is not None:
+ super_connect(conn)
+
+ return connect
+
+ def get_driver_connection(self, connection):
+ return connection._connection
+
+
+dialect = PGDialect_asyncpg
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
new file mode 100644
index 0000000..eb84170
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -0,0 +1,4651 @@
+# postgresql/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+.. dialect:: postgresql
+ :name: PostgreSQL
+ :full_support: 9.6, 10, 11, 12, 13, 14
+ :normal_support: 9.6+
+ :best_effort: 8+
+
+.. _postgresql_sequences:
+
+Sequences/SERIAL/IDENTITY
+-------------------------
+
+PostgreSQL supports sequences, and SQLAlchemy uses these as the default means
+of creating new primary key values for integer-based primary key columns. When
+creating tables, SQLAlchemy will issue the ``SERIAL`` datatype for
+integer-based primary key columns, which generates a sequence and server side
+default corresponding to the column.
+
+To specify a specific named sequence to be used for primary key generation,
+use the :func:`~sqlalchemy.schema.Sequence` construct::
+
+ Table('sometable', metadata,
+ Column('id', Integer, Sequence('some_id_seq'), primary_key=True)
+ )
+
+When SQLAlchemy issues a single INSERT statement, to fulfill the contract of
+having the "last insert identifier" available, a RETURNING clause is added to
+the INSERT statement which specifies the primary key columns should be
+returned after the statement completes. The RETURNING functionality only takes
+place if PostgreSQL 8.2 or later is in use. As a fallback approach, the
+sequence, whether specified explicitly or implicitly via ``SERIAL``, is
+executed independently beforehand, with the returned value used in the
+subsequent INSERT. Note that when an
+:func:`~sqlalchemy.sql.expression.insert()` construct is executed using
+"executemany" semantics, the "last inserted identifier" functionality does not
+apply; no RETURNING clause is emitted nor is the sequence pre-executed in this
+case.
+
+To disable the default use of RETURNING, specify the flag
+``implicit_returning=False`` to :func:`_sa.create_engine`.
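+
+A minimal sketch of that flag (the connection URL is a placeholder)::
+
+ from sqlalchemy import create_engine
+
+ engine = create_engine(
+ "postgresql+psycopg2://scott:tiger@localhost/test",
+ implicit_returning=False,
+ )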
+
+PostgreSQL 10 and above IDENTITY columns
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+PostgreSQL 10 and above have a new IDENTITY feature that supersedes the use
+of SERIAL. The :class:`_schema.Identity` construct in a
+:class:`_schema.Column` can be used to control its behavior::
+
+ from sqlalchemy import Table, Column, MetaData, Integer, String, Identity
+
+ metadata = MetaData()
+
+ data = Table(
+ "data",
+ metadata,
+ Column(
+ 'id', Integer, Identity(start=42, cycle=True), primary_key=True
+ ),
+ Column('data', String)
+ )
+
+The CREATE TABLE for the above :class:`_schema.Table` object would be:
+
+.. sourcecode:: sql
+
+ CREATE TABLE data (
+ id INTEGER GENERATED BY DEFAULT AS IDENTITY (START WITH 42 CYCLE),
+ data VARCHAR,
+ PRIMARY KEY (id)
+ )
+
+.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct
+ in a :class:`_schema.Column` to specify the option of an autoincrementing
+ column.
+
+.. note::
+
+ Previous versions of SQLAlchemy did not have built-in support for rendering
+ of IDENTITY, and could use the following compilation hook to replace
+ occurrences of SERIAL with IDENTITY::
+
+ from sqlalchemy.schema import CreateColumn
+ from sqlalchemy.ext.compiler import compiles
+
+
+ @compiles(CreateColumn, 'postgresql')
+ def use_identity(element, compiler, **kw):
+ text = compiler.visit_create_column(element, **kw)
+ text = text.replace(
+ "SERIAL", "INT GENERATED BY DEFAULT AS IDENTITY"
+ )
+ return text
+
+ Using the above, a table such as::
+
+ t = Table(
+ 't', m,
+ Column('id', Integer, primary_key=True),
+ Column('data', String)
+ )
+
+ Will generate on the backing database as::
+
+ CREATE TABLE t (
+ id INT GENERATED BY DEFAULT AS IDENTITY,
+ data VARCHAR,
+ PRIMARY KEY (id)
+ )
+
+.. _postgresql_ss_cursors:
+
+Server Side Cursors
+-------------------
+
+Server-side cursor support is available for the psycopg2 and asyncpg
+dialects and may also be available in others.
+
+Server side cursors are enabled on a per-statement basis by using the
+:paramref:`.Connection.execution_options.stream_results` connection execution
+option::
+
+ with engine.connect() as conn:
+ result = conn.execution_options(stream_results=True).execute(text("select * from table"))
+
+Note that some kinds of SQL statements may not be supported with
+server side cursors; generally, only SQL statements that return rows should be
+used with this option.
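+
+A rough sketch of consuming such a result incrementally (the chunk size and
+table name are illustrative)::
+
+ with engine.connect() as conn:
+ result = conn.execution_options(stream_results=True).execute(
+ text("select * from table")
+ )
+ for partition in result.partitions(100):
+ for row in partition:
+ print(row)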
+
+.. deprecated:: 1.4 The dialect-level server_side_cursors flag is deprecated
+ and will be removed in a future release. Please use the
+ :paramref:`_engine.Connection.stream_results` execution option for
+ unbuffered cursor support.
+
+.. seealso::
+
+ :ref:`engine_stream_results`
+
+.. _postgresql_isolation_level:
+
+Transaction Isolation Level
+---------------------------
+
+Most SQLAlchemy dialects support setting of transaction isolation level
+using the :paramref:`_sa.create_engine.isolation_level` parameter
+at the :func:`_sa.create_engine` level, and at the :class:`_engine.Connection`
+level via the :paramref:`.Connection.execution_options.isolation_level`
+parameter.
+
+For PostgreSQL dialects, this feature works either by making use of the
+DBAPI-specific features, such as psycopg2's isolation level flags which will
+embed the isolation level setting inline with the ``"BEGIN"`` statement, or for
+DBAPIs with no direct support by emitting ``SET SESSION CHARACTERISTICS AS
+TRANSACTION ISOLATION LEVEL <level>`` ahead of the ``"BEGIN"`` statement
+emitted by the DBAPI. For the special AUTOCOMMIT isolation level,
+DBAPI-specific techniques are used, typically an ``.autocommit``
+flag on the DBAPI connection object.
+
+To set isolation level using :func:`_sa.create_engine`::
+
+ engine = create_engine(
+ "postgresql+pg8000://scott:tiger@localhost/test",
+ isolation_level = "REPEATABLE READ"
+ )
+
+To set using per-connection execution options::
+
+ with engine.connect() as conn:
+ conn = conn.execution_options(
+ isolation_level="REPEATABLE READ"
+ )
+ with conn.begin():
+ # ... work with transaction
+
+There are also more options for isolation level configurations, such as
+"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply
+different isolation level settings. See the discussion at
+:ref:`dbapi_autocommit` for background.
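+
+As a brief sketch, one such "sub-engine" sharing the parent's connection pool
+might be created as::
+
+ autocommit_engine = engine.execution_options(isolation_level="AUTOCOMMIT")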
+
+Valid values for ``isolation_level`` on most PostgreSQL dialects include:
+
+* ``READ COMMITTED``
+* ``READ UNCOMMITTED``
+* ``REPEATABLE READ``
+* ``SERIALIZABLE``
+* ``AUTOCOMMIT``
+
+.. seealso::
+
+ :ref:`dbapi_autocommit`
+
+ :ref:`postgresql_readonly_deferrable`
+
+ :ref:`psycopg2_isolation_level`
+
+ :ref:`pg8000_isolation_level`
+
+.. _postgresql_readonly_deferrable:
+
+Setting READ ONLY / DEFERRABLE
+------------------------------
+
+Most PostgreSQL dialects support setting the "READ ONLY" and "DEFERRABLE"
+characteristics of the transaction, which is in addition to the isolation level
+setting. These two attributes can be established either in conjunction with or
+independently of the isolation level by passing the ``postgresql_readonly`` and
+``postgresql_deferrable`` flags with
+:meth:`_engine.Connection.execution_options`. The example below illustrates
+passing the ``"SERIALIZABLE"`` isolation level at the same time as setting
+"READ ONLY" and "DEFERRABLE"::
+
+ with engine.connect() as conn:
+ conn = conn.execution_options(
+ isolation_level="SERIALIZABLE",
+ postgresql_readonly=True,
+ postgresql_deferrable=True
+ )
+ with conn.begin():
+ # ... work with transaction
+
+Note that some DBAPIs such as asyncpg only support "readonly" with
+SERIALIZABLE isolation.
+
+.. versionadded:: 1.4 added support for the ``postgresql_readonly``
+ and ``postgresql_deferrable`` execution options.
+
+.. _postgresql_alternate_search_path:
+
+Setting Alternate Search Paths on Connect
+------------------------------------------
+
+The PostgreSQL ``search_path`` variable refers to the list of schema names
+that will be implicitly referred towards when a particular table or other
+object is referenced in a SQL statement. As detailed in the next section
+:ref:`postgresql_schema_reflection`, SQLAlchemy is generally organized around
+the concept of keeping this variable at its default value of ``public``.
+However, in order to have it set automatically to an arbitrary name or names
+whenever a connection is used, the "SET SESSION search_path" command may be
+invoked for all connections in a pool using the following event handler, as
+discussed at :ref:`schema_set_default_connections`::
+
+ from sqlalchemy import event
+ from sqlalchemy import create_engine
+
+ engine = create_engine("postgresql+psycopg2://scott:tiger@host/dbname")
+
+ @event.listens_for(engine, "connect", insert=True)
+ def set_search_path(dbapi_connection, connection_record):
+ existing_autocommit = dbapi_connection.autocommit
+ dbapi_connection.autocommit = True
+ cursor = dbapi_connection.cursor()
+ cursor.execute("SET SESSION search_path='%s'" % schema_name)
+ cursor.close()
+ dbapi_connection.autocommit = existing_autocommit
+
+The reason the recipe is complicated by use of the ``.autocommit`` DBAPI
+attribute is so that when the ``SET SESSION search_path`` directive is invoked,
+it is invoked outside of the scope of any transaction and therefore will not
+be reverted when the DBAPI connection has a rollback.
+
+.. seealso::
+
+ :ref:`schema_set_default_connections` - in the :ref:`metadata_toplevel` documentation
+
+
+
+
+.. _postgresql_schema_reflection:
+
+Remote-Schema Table Introspection and PostgreSQL search_path
+------------------------------------------------------------
+
+.. admonition:: Section Best Practices Summarized
+
+ keep the ``search_path`` variable set to its default of ``public``, without
+ any other schema names. For other schema names, name these explicitly
+ within :class:`_schema.Table` definitions. Alternatively, the
+ ``postgresql_ignore_search_path`` option will cause all reflected
+ :class:`_schema.Table` objects to have a :attr:`_schema.Table.schema`
+ attribute set up.
+
+The PostgreSQL dialect can reflect tables from any schema, as outlined in
+:ref:`metadata_reflection_schemas`.
+
+With regard to tables which these :class:`_schema.Table`
+objects refer to via foreign key constraint, a decision must be made as to how
+the ``.schema`` is represented in those remote tables, in the case where that
+remote schema name is also a member of the current
+`PostgreSQL search path
+<https://www.postgresql.org/docs/current/static/ddl-schemas.html#DDL-SCHEMAS-PATH>`_.
+
+By default, the PostgreSQL dialect mimics the behavior encouraged by
+PostgreSQL's own ``pg_get_constraintdef()`` builtin procedure. This function
+returns a sample definition for a particular foreign key constraint,
+omitting the referenced schema name from that definition when the name is
+also in the PostgreSQL schema search path. The interaction below
+illustrates this behavior::
+
+ test=> CREATE TABLE test_schema.referred(id INTEGER PRIMARY KEY);
+ CREATE TABLE
+ test=> CREATE TABLE referring(
+ test(> id INTEGER PRIMARY KEY,
+ test(> referred_id INTEGER REFERENCES test_schema.referred(id));
+ CREATE TABLE
+ test=> SET search_path TO public, test_schema;
+ test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM
+ test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n
+ test-> ON n.oid = c.relnamespace
+ test-> JOIN pg_catalog.pg_constraint r ON c.oid = r.conrelid
+ test-> WHERE c.relname='referring' AND r.contype = 'f'
+ test-> ;
+ pg_get_constraintdef
+ ---------------------------------------------------
+ FOREIGN KEY (referred_id) REFERENCES referred(id)
+ (1 row)
+
+Above, we created a table ``referred`` as a member of the remote schema
+``test_schema``, however when we added ``test_schema`` to the
+PG ``search_path`` and then asked ``pg_get_constraintdef()`` for the
+``FOREIGN KEY`` syntax, ``test_schema`` was not included in the output of
+the function.
+
+On the other hand, if we set the search path back to the typical default
+of ``public``::
+
+ test=> SET search_path TO public;
+ SET
+
+The same query against ``pg_get_constraintdef()`` now returns the fully
+schema-qualified name for us::
+
+ test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM
+ test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n
+ test-> ON n.oid = c.relnamespace
+ test-> JOIN pg_catalog.pg_constraint r ON c.oid = r.conrelid
+ test-> WHERE c.relname='referring' AND r.contype = 'f';
+ pg_get_constraintdef
+ ---------------------------------------------------------------
+ FOREIGN KEY (referred_id) REFERENCES test_schema.referred(id)
+ (1 row)
+
+SQLAlchemy will by default use the return value of ``pg_get_constraintdef()``
+in order to determine the remote schema name. That is, if our ``search_path``
+were set to include ``test_schema``, and we invoked a table
+reflection process as follows::
+
+ >>> from sqlalchemy import Table, MetaData, create_engine, text
+ >>> engine = create_engine("postgresql://scott:tiger@localhost/test")
+ >>> with engine.connect() as conn:
+ ... conn.execute(text("SET search_path TO test_schema, public"))
+ ... metadata_obj = MetaData()
+ ... referring = Table('referring', metadata_obj,
+ ... autoload_with=conn)
+ ...
+ <sqlalchemy.engine.result.CursorResult object at 0x101612ed0>
+
+The above process would deliver to the :attr:`_schema.MetaData.tables`
+collection a ``referred`` table named **without** the schema::
+
+ >>> metadata_obj.tables['referred'].schema is None
+ True
+
+To alter the behavior of reflection such that the referred schema is
+maintained regardless of the ``search_path`` setting, use the
+``postgresql_ignore_search_path`` option, which can be specified as a
+dialect-specific argument to both :class:`_schema.Table` as well as
+:meth:`_schema.MetaData.reflect`::
+
+ >>> with engine.connect() as conn:
+ ... conn.execute(text("SET search_path TO test_schema, public"))
+ ... metadata_obj = MetaData()
+ ... referring = Table('referring', metadata_obj,
+ ... autoload_with=conn,
+ ... postgresql_ignore_search_path=True)
+ ...
+ <sqlalchemy.engine.result.CursorResult object at 0x1016126d0>
+
+We will now have ``test_schema.referred`` stored as schema-qualified::
+
+ >>> metadata_obj.tables['test_schema.referred'].schema
+ 'test_schema'
+
+.. sidebar:: Best Practices for PostgreSQL Schema reflection
+
+ The description of PostgreSQL schema reflection behavior is complex, and
+ is the product of many years of dealing with widely varied use cases and
+ user preferences. But in fact, there's no need to understand any of it if
+ you just stick to the simplest use pattern: leave the ``search_path`` set
+ to its default of ``public`` only, never refer to the name ``public`` as
+ an explicit schema name otherwise, and refer to all other schema names
+ explicitly when building up a :class:`_schema.Table` object. The options
+ described here are only for those users who can't, or prefer not to, stay
+ within these guidelines.
+
+Note that **in all cases**, the "default" schema is always reflected as
+``None``. The "default" schema on PostgreSQL is that which is returned by the
+PostgreSQL ``current_schema()`` function. On a typical PostgreSQL
+installation, this is the name ``public``. So a table that refers to another
+which is in the ``public`` (i.e. default) schema will always have the
+``.schema`` attribute set to ``None``.
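+
+For example, given hypothetical tables ``invoice`` and ``customer`` that both
+live in the default ``public`` schema, with ``invoice`` referring to
+``customer`` via foreign key, a reflection sketch would show the referred
+table with no schema (table names here are illustrative only)::
+
+    >>> metadata_obj = MetaData()
+    >>> invoice = Table("invoice", metadata_obj, autoload_with=engine)
+    >>> metadata_obj.tables["customer"].schema is None
+    True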
+
+.. seealso::
+
+ :ref:`reflection_schema_qualified_interaction` - discussion of the issue
+ from a backend-agnostic perspective
+
+ `The Schema Search Path
+ <https://www.postgresql.org/docs/current/static/ddl-schemas.html#DDL-SCHEMAS-PATH>`_
+ - on the PostgreSQL website.
+
+INSERT/UPDATE...RETURNING
+-------------------------
+
+The dialect supports PG 8.2's ``INSERT..RETURNING``, ``UPDATE..RETURNING`` and
+``DELETE..RETURNING`` syntaxes. ``INSERT..RETURNING`` is used by default
+for single-row INSERT statements in order to fetch newly generated
+primary key identifiers. To specify an explicit ``RETURNING`` clause,
+use the :meth:`._UpdateBase.returning` method on a per-statement basis::
+
+    # INSERT..RETURNING
+    result = connection.execute(
+        table.insert()
+        .returning(table.c.col1, table.c.col2)
+        .values(name='foo')
+    )
+    print(result.fetchall())
+
+    # UPDATE..RETURNING
+    result = connection.execute(
+        table.update()
+        .returning(table.c.col1, table.c.col2)
+        .where(table.c.name == 'foo')
+        .values(name='bar')
+    )
+    print(result.fetchall())
+
+    # DELETE..RETURNING
+    result = connection.execute(
+        table.delete()
+        .returning(table.c.col1, table.c.col2)
+        .where(table.c.name == 'foo')
+    )
+    print(result.fetchall())
+
+.. _postgresql_insert_on_conflict:
+
+INSERT...ON CONFLICT (Upsert)
+------------------------------
+
+Starting with version 9.5, PostgreSQL allows "upserts" (update or insert) of
+rows into a table via the ``ON CONFLICT`` clause of the ``INSERT`` statement. A
+candidate row will only be inserted if that row does not violate any unique
+constraints. In the case of a unique constraint violation, a secondary action
+can occur which can be either "DO UPDATE", indicating that the data in the
+target row should be updated, or "DO NOTHING", which indicates to silently skip
+this row.
+
+Conflicts are determined using existing unique constraints and indexes. These
+constraints may be identified either using their name as stated in DDL,
+or they may be inferred by stating the columns and conditions that comprise
+the indexes.
+
+SQLAlchemy provides ``ON CONFLICT`` support via the PostgreSQL-specific
+:func:`_postgresql.insert()` function, which provides
+the generative methods :meth:`_postgresql.Insert.on_conflict_do_update`
+and :meth:`~.postgresql.Insert.on_conflict_do_nothing`:
+
+.. sourcecode:: pycon+sql
+
+ >>> from sqlalchemy.dialects.postgresql import insert
+ >>> insert_stmt = insert(my_table).values(
+ ... id='some_existing_id',
+ ... data='inserted value')
+ >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing(
+ ... index_elements=['id']
+ ... )
+ >>> print(do_nothing_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT (id) DO NOTHING
+ {stop}
+
+ >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+ ... constraint='pk_my_table',
+ ... set_=dict(data='updated value')
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT ON CONSTRAINT pk_my_table DO UPDATE SET data = %(param_1)s
+
+.. versionadded:: 1.1
+
+.. seealso::
+
+ `INSERT .. ON CONFLICT
+ <https://www.postgresql.org/docs/current/static/sql-insert.html#SQL-ON-CONFLICT>`_
+ - in the PostgreSQL documentation.
+
+Specifying the Target
+^^^^^^^^^^^^^^^^^^^^^
+
+Both methods supply the "target" of the conflict using either a
+named constraint or column inference:
+
+* The :paramref:`_postgresql.Insert.on_conflict_do_update.index_elements` argument
+ specifies a sequence containing string column names, :class:`_schema.Column`
+ objects, and/or SQL expression elements, which would identify a unique
+ index:
+
+ .. sourcecode:: pycon+sql
+
+ >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value')
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT (id) DO UPDATE SET data = %(param_1)s
+ {stop}
+
+ >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+ ... index_elements=[my_table.c.id],
+ ... set_=dict(data='updated value')
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT (id) DO UPDATE SET data = %(param_1)s
+
+* When using :paramref:`_postgresql.Insert.on_conflict_do_update.index_elements` to
+  infer an index, a partial index can be inferred by also specifying the
+  :paramref:`_postgresql.Insert.on_conflict_do_update.index_where` parameter:
+
+ .. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data')
+ >>> stmt = stmt.on_conflict_do_update(
+ ... index_elements=[my_table.c.user_email],
+ ... index_where=my_table.c.user_email.like('%@gmail.com'),
+ ... set_=dict(data=stmt.excluded.data)
+ ... )
+ >>> print(stmt)
+ {opensql}INSERT INTO my_table (data, user_email)
+ VALUES (%(data)s, %(user_email)s) ON CONFLICT (user_email)
+ WHERE user_email LIKE %(user_email_1)s DO UPDATE SET data = excluded.data
+
+* The :paramref:`_postgresql.Insert.on_conflict_do_update.constraint` argument is
+ used to specify an index directly rather than inferring it. This can be
+ the name of a UNIQUE constraint, a PRIMARY KEY constraint, or an INDEX:
+
+ .. sourcecode:: pycon+sql
+
+ >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+ ... constraint='my_table_idx_1',
+ ... set_=dict(data='updated value')
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT ON CONSTRAINT my_table_idx_1 DO UPDATE SET data = %(param_1)s
+ {stop}
+
+ >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+ ... constraint='my_table_pk',
+ ... set_=dict(data='updated value')
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT ON CONSTRAINT my_table_pk DO UPDATE SET data = %(param_1)s
+ {stop}
+
+* The :paramref:`_postgresql.Insert.on_conflict_do_update.constraint` argument may
+ also refer to a SQLAlchemy construct representing a constraint,
+ e.g. :class:`.UniqueConstraint`, :class:`.PrimaryKeyConstraint`,
+ :class:`.Index`, or :class:`.ExcludeConstraint`. In this use,
+ if the constraint has a name, it is used directly. Otherwise, if the
+ constraint is unnamed, then inference will be used, where the expressions
+ and optional WHERE clause of the constraint will be spelled out in the
+ construct. This use is especially convenient
+ to refer to the named or unnamed primary key of a :class:`_schema.Table`
+ using the
+ :attr:`_schema.Table.primary_key` attribute:
+
+ .. sourcecode:: pycon+sql
+
+ >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+ ... constraint=my_table.primary_key,
+ ... set_=dict(data='updated value')
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT (id) DO UPDATE SET data = %(param_1)s
+
+The SET Clause
+^^^^^^^^^^^^^^^
+
+``ON CONFLICT...DO UPDATE`` is used to perform an update of the already
+existing row, using any combination of new values as well as values
+from the proposed insertion. These values are specified using the
+:paramref:`_postgresql.Insert.on_conflict_do_update.set_` parameter. This
+parameter accepts a dictionary which consists of direct values
+for UPDATE:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+ >>> do_update_stmt = stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value')
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT (id) DO UPDATE SET data = %(param_1)s
+
+.. warning::
+
+ The :meth:`_expression.Insert.on_conflict_do_update`
+ method does **not** take into
+ account Python-side default UPDATE values or generation functions, e.g.
+ those specified using :paramref:`_schema.Column.onupdate`.
+ These values will not be exercised for an ON CONFLICT style of UPDATE,
+ unless they are manually specified in the
+ :paramref:`_postgresql.Insert.on_conflict_do_update.set_` dictionary.
+
+Updating using the Excluded INSERT Values
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to refer to the proposed insertion row, the special alias
+:attr:`~.postgresql.Insert.excluded` is available as an attribute on
+the :class:`_postgresql.Insert` object; this object is a
+:class:`_expression.ColumnCollection` that contains all columns of the
+target table, addressed under the ``excluded`` alias:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(
+ ... id='some_id',
+ ... data='inserted value',
+ ... author='jlh'
+ ... )
+ >>> do_update_stmt = stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value', author=stmt.excluded.author)
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data, author)
+ VALUES (%(id)s, %(data)s, %(author)s)
+ ON CONFLICT (id) DO UPDATE SET data = %(param_1)s, author = excluded.author
+
+Additional WHERE Criteria
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The :meth:`_expression.Insert.on_conflict_do_update` method also accepts
+a WHERE clause using the :paramref:`_postgresql.Insert.on_conflict_do_update.where`
+parameter, which will limit those rows which receive an UPDATE:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(
+ ... id='some_id',
+ ... data='inserted value',
+ ... author='jlh'
+ ... )
+ >>> on_update_stmt = stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value', author=stmt.excluded.author),
+ ... where=(my_table.c.status == 2)
+ ... )
+ >>> print(on_update_stmt)
+ {opensql}INSERT INTO my_table (id, data, author)
+ VALUES (%(id)s, %(data)s, %(author)s)
+ ON CONFLICT (id) DO UPDATE SET data = %(param_1)s, author = excluded.author
+ WHERE my_table.status = %(status_1)s
+
+Skipping Rows with DO NOTHING
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+``ON CONFLICT`` may be used to skip inserting a row entirely
+if any conflict with a unique or exclusion constraint occurs; below
+this is illustrated using the
+:meth:`~.postgresql.Insert.on_conflict_do_nothing` method:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+ >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id'])
+ >>> print(stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT (id) DO NOTHING
+
+If ``DO NOTHING`` is used without specifying any columns or constraint,
+it has the effect of skipping the INSERT for any unique or exclusion
+constraint violation which occurs:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+ >>> stmt = stmt.on_conflict_do_nothing()
+ >>> print(stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT DO NOTHING
+
+.. _postgresql_match:
+
+Full Text Search
+----------------
+
+SQLAlchemy makes available the PostgreSQL ``@@`` operator via the
+:meth:`_expression.ColumnElement.match` method on any textual column expression.
+
+On the PostgreSQL dialect, an expression like the following::
+
+ select(sometable.c.text.match("search string"))
+
+will emit to the database::
+
+ SELECT text @@ to_tsquery('search string') FROM table
+
+Various other PostgreSQL text search functions such as ``to_tsquery()``,
+``to_tsvector()``, and ``plainto_tsquery()`` are available by explicitly using
+the standard SQLAlchemy :data:`.func` construct.
+
+For example::
+
+ select(func.to_tsvector('fat cats ate rats').match('cat & rat'))
+
+Emits the equivalent of::
+
+ SELECT to_tsvector('fat cats ate rats') @@ to_tsquery('cat & rat')
+
+The :class:`_postgresql.TSVECTOR` type can provide for explicit CAST::
+
+ from sqlalchemy.dialects.postgresql import TSVECTOR
+ from sqlalchemy import select, cast
+ select(cast("some text", TSVECTOR))
+
+produces a statement equivalent to::
+
+ SELECT CAST('some text' AS TSVECTOR) AS anon_1
+
+.. tip::
+
+ It's important to remember that text searching in PostgreSQL is powerful but complicated,
+ and SQLAlchemy users are advised to reference the PostgreSQL documentation
+ regarding
+ `Full Text Search <https://www.postgresql.org/docs/current/textsearch-controls.html>`_.
+
+ There are important differences between ``to_tsquery`` and
+ ``plainto_tsquery``, the most significant of which is that ``to_tsquery``
+ expects specially formatted "querytext" that is written to PostgreSQL's own
+ specification, while ``plainto_tsquery`` expects unformatted text that is
+ transformed into ``to_tsquery`` compatible querytext. This means the input to
+ ``.match()`` under PostgreSQL may be incompatible with the input to
+ ``.match()`` under another database backend. SQLAlchemy users who support
+ multiple backends are advised to carefully implement their usage of
+ ``.match()`` to work around these constraints.
+
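+As an illustration of the latter, ``plainto_tsquery()`` may be invoked
+explicitly through the :data:`.func` namespace and combined with the ``@@``
+operator using :meth:`.Operators.op`; a sketch only, re-using the ``mytable``
+construct from the examples below::
+
+    from sqlalchemy import select, func
+
+    # plainto_tsquery() accepts plain, unformatted search text
+    stmt = select(mytable.c.id).where(
+        func.to_tsvector('english', mytable.c.title)
+        .op('@@')(func.plainto_tsquery('english', 'fat rats'))
+    )
+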
+Full Text Searches in PostgreSQL are influenced by a combination of: the
+PostgreSQL setting of ``default_text_search_config``, the ``regconfig`` used
+to build the GIN/GiST indexes, and the ``regconfig`` optionally passed in
+during a query.
+
+When performing a Full Text Search against a column that has a GIN or
+GiST index that is already pre-computed (which is common on full text
+searches) one may need to explicitly pass in a particular PostgreSQL
+``regconfig`` value to ensure the query-planner utilizes the index and does
+not re-compute the column on demand.
+
+In order to provide for this explicit query planning, or to use different
+search strategies, the ``match`` method accepts a ``postgresql_regconfig``
+keyword argument::
+
+ select(mytable.c.id).where(
+ mytable.c.title.match('somestring', postgresql_regconfig='english')
+ )
+
+Emits the equivalent of::
+
+ SELECT mytable.id FROM mytable
+ WHERE mytable.title @@ to_tsquery('english', 'somestring')
+
+One can also specifically pass in a ``'regconfig'`` value to the
+``to_tsvector()`` command as the initial argument::
+
+ select(mytable.c.id).where(
+ func.to_tsvector('english', mytable.c.title )\
+ .match('somestring', postgresql_regconfig='english')
+ )
+
+produces a statement equivalent to::
+
+ SELECT mytable.id FROM mytable
+ WHERE to_tsvector('english', mytable.title) @@
+ to_tsquery('english', 'somestring')
+
+It is recommended that you use the ``EXPLAIN ANALYZE...`` tool from
+PostgreSQL to ensure that you are generating queries with SQLAlchemy that
+take full advantage of any indexes you may have created for full text search.
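+
+For example, a rough sketch of examining the plan for one of the queries above
+through SQLAlchemy itself, assuming the ``engine`` and ``mytable`` of the
+prior examples::
+
+    from sqlalchemy import text
+
+    with engine.connect() as conn:
+        plan = conn.execute(text(
+            "EXPLAIN ANALYZE SELECT mytable.id FROM mytable "
+            "WHERE mytable.title @@ to_tsquery('english', 'somestring')"
+        ))
+        for row in plan:
+            print(row[0])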
+
+.. seealso::
+
+ `Full Text Search <https://www.postgresql.org/docs/current/textsearch-controls.html>`_ - in the PostgreSQL documentation
+
+
+FROM ONLY ...
+-------------
+
+The dialect supports PostgreSQL's ONLY keyword for targeting only a particular
+table in an inheritance hierarchy. This can be used to produce the
+``SELECT ... FROM ONLY``, ``UPDATE ONLY ...``, and ``DELETE FROM ONLY ...``
+syntaxes. It uses SQLAlchemy's hints mechanism::
+
+    # SELECT ... FROM ONLY ...
+    result = connection.execute(
+        table.select().with_hint(table, 'ONLY', 'postgresql')
+    )
+    print(result.fetchall())
+
+    # UPDATE ONLY ...
+    table.update().values(foo='bar').with_hint(
+        'ONLY', dialect_name='postgresql')
+
+    # DELETE FROM ONLY ...
+    table.delete().with_hint('ONLY', dialect_name='postgresql')
+
+
+.. _postgresql_indexes:
+
+PostgreSQL-Specific Index Options
+---------------------------------
+
+Several extensions to the :class:`.Index` construct are available, specific
+to the PostgreSQL dialect.
+
+Covering Indexes
+^^^^^^^^^^^^^^^^
+
+The ``postgresql_include`` option renders INCLUDE(colname) for the given
+string names::
+
+ Index("my_index", table.c.x, postgresql_include=['y'])
+
+would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)``
+
+Note that this feature requires PostgreSQL 11 or later.
+
+.. versionadded:: 1.4
+
+.. _postgresql_partial_indexes:
+
+Partial Indexes
+^^^^^^^^^^^^^^^
+
+Partial indexes add criteria to the index definition so that the index is
+applied to a subset of rows. These can be specified on :class:`.Index`
+using the ``postgresql_where`` keyword argument::
+
+ Index('my_index', my_table.c.id, postgresql_where=my_table.c.value > 10)
+
+.. _postgresql_operator_classes:
+
+Operator Classes
+^^^^^^^^^^^^^^^^
+
+PostgreSQL allows the specification of an *operator class* for each column of
+an index (see
+https://www.postgresql.org/docs/current/interactive/indexes-opclass.html).
+The :class:`.Index` construct allows these to be specified via the
+``postgresql_ops`` keyword argument::
+
+ Index(
+ 'my_index', my_table.c.id, my_table.c.data,
+ postgresql_ops={
+ 'data': 'text_pattern_ops',
+ 'id': 'int4_ops'
+ })
+
+Note that the keys in the ``postgresql_ops`` dictionaries are the
+"key" name of the :class:`_schema.Column`, i.e. the name used to access it from
+the ``.c`` collection of :class:`_schema.Table`, which can be configured to be
+different than the actual name of the column as expressed in the database.
+
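+For example, a sketch of a hypothetical table whose column ``key`` differs
+from its database name (``id_number`` and ``data`` are illustrative names)::
+
+    from sqlalchemy import Column, Index, Integer, MetaData, Table, Text
+
+    metadata = MetaData()
+    my_table = Table(
+        "my_table", metadata,
+        # the database column is "id_number", but it is accessed in Python
+        # as my_table.c.id because of the explicit key
+        Column("id_number", Integer, key="id"),
+        Column("data", Text),
+    )
+
+    Index(
+        "my_index", my_table.c.id, my_table.c.data,
+        postgresql_ops={
+            # keys here are the Column.key values, not the database names
+            "id": "int4_ops",
+            "data": "text_pattern_ops",
+        },
+    )
+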
+If ``postgresql_ops`` is to be used against a complex SQL expression such
+as a function call, then to apply to the column it must be given a label
+that is identified in the dictionary by name, e.g.::
+
+ Index(
+ 'my_index', my_table.c.id,
+ func.lower(my_table.c.data).label('data_lower'),
+ postgresql_ops={
+ 'data_lower': 'text_pattern_ops',
+ 'id': 'int4_ops'
+ })
+
+Operator classes are also supported by the
+:class:`_postgresql.ExcludeConstraint` construct using the
+:paramref:`_postgresql.ExcludeConstraint.ops` parameter. See that parameter for
+details.
+
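+A sketch of applying an operator class within
+:class:`_postgresql.ExcludeConstraint`, assuming the ``btree_gist`` extension
+is installed so that the ``gist_int4_ops`` operator class is available::
+
+    from sqlalchemy import Column, Integer, MetaData, Table
+    from sqlalchemy.dialects.postgresql import ExcludeConstraint, TSRANGE
+
+    metadata = MetaData()
+    room_booking = Table(
+        "room_booking", metadata,
+        Column("room", Integer),
+        Column("during", TSRANGE),
+        ExcludeConstraint(
+            ("room", "="),
+            ("during", "&&"),
+            using="gist",
+            # keys are column names, values are operator classes
+            ops={"room": "gist_int4_ops"},
+        ),
+    )
+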
+.. versionadded:: 1.3.21 added support for operator classes with
+ :class:`_postgresql.ExcludeConstraint`.
+
+
+Index Types
+^^^^^^^^^^^
+
+PostgreSQL provides several index types: B-Tree, Hash, GiST, and GIN, as well
+as the ability for users to create their own (see
+https://www.postgresql.org/docs/current/static/indexes-types.html). These can be
+specified on :class:`.Index` using the ``postgresql_using`` keyword argument::
+
+ Index('my_index', my_table.c.data, postgresql_using='gin')
+
+The value passed to the keyword argument will be simply passed through to the
+underlying CREATE INDEX command, so it *must* be a valid index type for your
+version of PostgreSQL.
+
+.. _postgresql_index_storage:
+
+Index Storage Parameters
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+PostgreSQL allows storage parameters to be set on indexes. The storage
+parameters available depend on the index method used by the index. Storage
+parameters can be specified on :class:`.Index` using the ``postgresql_with``
+keyword argument::
+
+ Index('my_index', my_table.c.data, postgresql_with={"fillfactor": 50})
+
+.. versionadded:: 1.0.6
+
+PostgreSQL allows defining the tablespace in which to create the index.
+The tablespace can be specified on :class:`.Index` using the
+``postgresql_tablespace`` keyword argument::
+
+ Index('my_index', my_table.c.data, postgresql_tablespace='my_tablespace')
+
+.. versionadded:: 1.1
+
+Note that the same option is available on :class:`_schema.Table` as well.
+
+.. _postgresql_index_concurrently:
+
+Indexes with CONCURRENTLY
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The PostgreSQL index option CONCURRENTLY is supported by passing the
+flag ``postgresql_concurrently`` to the :class:`.Index` construct::
+
+ tbl = Table('testtbl', m, Column('data', Integer))
+
+ idx1 = Index('test_idx1', tbl.c.data, postgresql_concurrently=True)
+
+The above index construct will render DDL for CREATE INDEX, assuming
+PostgreSQL 8.2 or higher is detected or for a connection-less dialect, as::
+
+ CREATE INDEX CONCURRENTLY test_idx1 ON testtbl (data)
+
+For DROP INDEX, assuming PostgreSQL 9.2 or higher is detected or for
+a connection-less dialect, it will emit::
+
+ DROP INDEX CONCURRENTLY test_idx1
+
+.. versionadded:: 1.1 support for CONCURRENTLY on DROP INDEX. The
+ CONCURRENTLY keyword is now only emitted if a high enough version
+ of PostgreSQL is detected on the connection (or for a connection-less
+ dialect).
+
+When using CONCURRENTLY, the PostgreSQL database requires that the statement
+be invoked outside of a transaction block. The Python DBAPI enforces that
+even for a single statement, a transaction is present, so to use this
+construct, the DBAPI's "autocommit" mode must be used::
+
+ metadata = MetaData()
+ table = Table(
+ "foo", metadata,
+ Column("id", String))
+ index = Index(
+ "foo_idx", table.c.id, postgresql_concurrently=True)
+
+ with engine.connect() as conn:
+ with conn.execution_options(isolation_level='AUTOCOMMIT'):
+ table.create(conn)
+
+.. seealso::
+
+ :ref:`postgresql_isolation_level`
+
+.. _postgresql_index_reflection:
+
+PostgreSQL Index Reflection
+---------------------------
+
+The PostgreSQL database creates a UNIQUE INDEX implicitly whenever the
+UNIQUE CONSTRAINT construct is used. When inspecting a table using
+:class:`_reflection.Inspector`, the :meth:`_reflection.Inspector.get_indexes`
+and the :meth:`_reflection.Inspector.get_unique_constraints`
+will report on these
+two constructs distinctly; in the case of the index, the key
+``duplicates_constraint`` will be present in the index entry if it is
+detected as mirroring a constraint. When performing reflection using
+``Table(..., autoload_with=engine)``, the UNIQUE INDEX is **not** returned
+in :attr:`_schema.Table.indexes` when it is detected as mirroring a
+:class:`.UniqueConstraint` in the :attr:`_schema.Table.constraints` collection.
+
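+For example, a minimal sketch of observing the ``duplicates_constraint`` key,
+assuming a hypothetical table ``mytable`` that includes a UNIQUE CONSTRAINT::
+
+    from sqlalchemy import create_engine, inspect
+
+    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
+    insp = inspect(engine)
+
+    for idx in insp.get_indexes("mytable"):
+        # indexes that mirror a UNIQUE CONSTRAINT carry this key
+        if "duplicates_constraint" in idx:
+            print(idx["name"], "mirrors", idx["duplicates_constraint"])
+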
+.. versionchanged:: 1.0.0 - :class:`_schema.Table` reflection now includes
+ :class:`.UniqueConstraint` objects present in the
+ :attr:`_schema.Table.constraints`
+ collection; the PostgreSQL backend will no longer include a "mirrored"
+ :class:`.Index` construct in :attr:`_schema.Table.indexes`
+ if it is detected
+ as corresponding to a unique constraint.
+
+Special Reflection Options
+--------------------------
+
+The :class:`_reflection.Inspector`
+used for the PostgreSQL backend is an instance
+of :class:`.PGInspector`, which offers additional methods::
+
+ from sqlalchemy import create_engine, inspect
+
+ engine = create_engine("postgresql+psycopg2://localhost/test")
+ insp = inspect(engine) # will be a PGInspector
+
+ print(insp.get_enums())
+
+.. autoclass:: PGInspector
+ :members:
+
+.. _postgresql_table_options:
+
+PostgreSQL Table Options
+------------------------
+
+Several options for CREATE TABLE are supported directly by the PostgreSQL
+dialect in conjunction with the :class:`_schema.Table` construct:
+
+* ``TABLESPACE``::
+
+ Table("some_table", metadata, ..., postgresql_tablespace='some_tablespace')
+
+ The above option is also available on the :class:`.Index` construct.
+
+* ``ON COMMIT``::
+
+ Table("some_table", metadata, ..., postgresql_on_commit='PRESERVE ROWS')
+
+* ``WITH OIDS``::
+
+ Table("some_table", metadata, ..., postgresql_with_oids=True)
+
+* ``WITHOUT OIDS``::
+
+ Table("some_table", metadata, ..., postgresql_with_oids=False)
+
+* ``INHERITS``::
+
+ Table("some_table", metadata, ..., postgresql_inherits="some_supertable")
+
+ Table("some_table", metadata, ..., postgresql_inherits=("t1", "t2", ...))
+
+ .. versionadded:: 1.0.0
+
+* ``PARTITION BY``::
+
+ Table("some_table", metadata, ...,
+ postgresql_partition_by='LIST (part_column)')
+
+ .. versionadded:: 1.2.6
+
+.. seealso::
+
+ `PostgreSQL CREATE TABLE options
+ <https://www.postgresql.org/docs/current/static/sql-createtable.html>`_ -
+ in the PostgreSQL documentation.
+
+.. _postgresql_constraint_options:
+
+PostgreSQL Constraint Options
+-----------------------------
+
+The following option(s) are supported by the PostgreSQL dialect in conjunction
+with selected constraint constructs:
+
+* ``NOT VALID``: This option applies towards CHECK and FOREIGN KEY constraints
+ when the constraint is being added to an existing table via ALTER TABLE,
+ and has the effect that existing rows are not scanned during the ALTER
+ operation against the constraint being added.
+
+ When using a SQL migration tool such as `Alembic <https://alembic.sqlalchemy.org>`_
+ that renders ALTER TABLE constructs, the ``postgresql_not_valid`` argument
+ may be specified as an additional keyword argument within the operation
+ that creates the constraint, as in the following Alembic example::
+
+        def upgrade():
+ op.create_foreign_key(
+ "fk_user_address",
+ "address",
+ "user",
+ ["user_id"],
+ ["id"],
+ postgresql_not_valid=True
+ )
+
+ The keyword is ultimately accepted directly by the
+ :class:`_schema.CheckConstraint`, :class:`_schema.ForeignKeyConstraint`
+ and :class:`_schema.ForeignKey` constructs; when using a tool like
+ Alembic, dialect-specific keyword arguments are passed through to
+ these constructs from the migration operation directives::
+
+ CheckConstraint("some_field IS NOT NULL", postgresql_not_valid=True)
+
+ ForeignKeyConstraint(["some_id"], ["some_table.some_id"], postgresql_not_valid=True)
+
+ .. versionadded:: 1.4.32
+
+ .. seealso::
+
+ `PostgreSQL ALTER TABLE options
+ <https://www.postgresql.org/docs/current/static/sql-altertable.html>`_ -
+ in the PostgreSQL documentation.
+
+.. _postgresql_table_valued_overview:
+
+Table values, Table and Column valued functions, Row and Tuple objects
+-----------------------------------------------------------------------
+
+PostgreSQL makes great use of modern SQL forms such as table-valued functions,
+tables and rows as values. These constructs are commonly used as part
+of PostgreSQL's support for complex datatypes such as JSON, ARRAY, and other
+datatypes. SQLAlchemy's SQL expression language has native support for
+most table-valued and row-valued forms.
+
+.. _postgresql_table_valued:
+
+Table-Valued Functions
+^^^^^^^^^^^^^^^^^^^^^^^
+
+Many PostgreSQL built-in functions are intended to be used in the FROM clause
+of a SELECT statement, and are capable of returning table rows or sets of table
+rows. A large portion of PostgreSQL's JSON functions, such as
+``json_array_elements()``, ``json_object_keys()``, ``json_each_text()``,
+``json_each()``, ``json_to_record()``, and ``json_populate_recordset()``, use
+such forms. These classes of SQL function calling forms in SQLAlchemy are available
+using the :meth:`_functions.FunctionElement.table_valued` method in conjunction
+with :class:`_functions.Function` objects generated from the :data:`_sql.func`
+namespace.
+
+Examples from PostgreSQL's reference documentation follow below:
+
+* ``json_each()``::
+
+ >>> from sqlalchemy import select, func
+ >>> stmt = select(func.json_each('{"a":"foo", "b":"bar"}').table_valued("key", "value"))
+ >>> print(stmt)
+ SELECT anon_1.key, anon_1.value
+ FROM json_each(:json_each_1) AS anon_1
+
+* ``json_populate_record()``::
+
+ >>> from sqlalchemy import select, func, literal_column
+ >>> stmt = select(
+ ... func.json_populate_record(
+ ... literal_column("null::myrowtype"),
+ ... '{"a":1,"b":2}'
+ ... ).table_valued("a", "b", name="x")
+ ... )
+ >>> print(stmt)
+ SELECT x.a, x.b
+ FROM json_populate_record(null::myrowtype, :json_populate_record_1) AS x
+
+* ``json_to_record()`` - this form uses a PostgreSQL specific form of derived
+ columns in the alias, where we may make use of :func:`_sql.column` elements with
+ types to produce them. The :meth:`_functions.FunctionElement.table_valued`
+  method produces a :class:`_sql.TableValuedAlias` construct, and the
+  :meth:`_sql.TableValuedAlias.render_derived` method sets up the derived
+ columns specification::
+
+ >>> from sqlalchemy import select, func, column, Integer, Text
+ >>> stmt = select(
+ ... func.json_to_record('{"a":1,"b":[1,2,3],"c":"bar"}').table_valued(
+ ... column("a", Integer), column("b", Text), column("d", Text),
+ ... ).render_derived(name="x", with_types=True)
+ ... )
+ >>> print(stmt)
+ SELECT x.a, x.b, x.d
+ FROM json_to_record(:json_to_record_1) AS x(a INTEGER, b TEXT, d TEXT)
+
+* ``WITH ORDINALITY`` - part of the SQL standard, ``WITH ORDINALITY`` adds an
+ ordinal counter to the output of a function and is accepted by a limited set
+ of PostgreSQL functions including ``unnest()`` and ``generate_series()``. The
+ :meth:`_functions.FunctionElement.table_valued` method accepts a keyword
+ parameter ``with_ordinality`` for this purpose, which accepts the string name
+ that will be applied to the "ordinality" column::
+
+ >>> from sqlalchemy import select, func
+ >>> stmt = select(
+ ... func.generate_series(4, 1, -1).
+ ... table_valued("value", with_ordinality="ordinality").
+ ... render_derived()
+ ... )
+ >>> print(stmt)
+ SELECT anon_1.value, anon_1.ordinality
+ FROM generate_series(:generate_series_1, :generate_series_2, :generate_series_3)
+ WITH ORDINALITY AS anon_1(value, ordinality)
+
+.. versionadded:: 1.4.0b2
+
+.. seealso::
+
+ :ref:`tutorial_functions_table_valued` - in the :ref:`unified_tutorial`
+
+.. _postgresql_column_valued:
+
+Column Valued Functions
+^^^^^^^^^^^^^^^^^^^^^^^
+
+Similar to the table valued function, a column valued function is present
+in the FROM clause, but delivers itself to the columns clause as a single
+scalar value. PostgreSQL functions such as ``json_array_elements()``,
+``unnest()`` and ``generate_series()`` may use this form. Column valued functions are available using the
+:meth:`_functions.FunctionElement.column_valued` method of :class:`_functions.FunctionElement`:
+
+* ``json_array_elements()``::
+
+ >>> from sqlalchemy import select, func
+ >>> stmt = select(func.json_array_elements('["one", "two"]').column_valued("x"))
+ >>> print(stmt)
+ SELECT x
+ FROM json_array_elements(:json_array_elements_1) AS x
+
+* ``unnest()`` - in order to generate a PostgreSQL ARRAY literal, the
+ :func:`_postgresql.array` construct may be used::
+
+
+ >>> from sqlalchemy.dialects.postgresql import array
+ >>> from sqlalchemy import select, func
+ >>> stmt = select(func.unnest(array([1, 2])).column_valued())
+ >>> print(stmt)
+ SELECT anon_1
+ FROM unnest(ARRAY[%(param_1)s, %(param_2)s]) AS anon_1
+
+ The function can of course be used against an existing table-bound column
+ that's of type :class:`_types.ARRAY`::
+
+ >>> from sqlalchemy import table, column, ARRAY, Integer
+ >>> from sqlalchemy import select, func
+ >>> t = table("t", column('value', ARRAY(Integer)))
+ >>> stmt = select(func.unnest(t.c.value).column_valued("unnested_value"))
+ >>> print(stmt)
+ SELECT unnested_value
+ FROM unnest(t.value) AS unnested_value
+
+.. seealso::
+
+ :ref:`tutorial_functions_column_valued` - in the :ref:`unified_tutorial`
+
+
+Row Types
+^^^^^^^^^
+
+Built-in support for rendering a ``ROW`` may be approximated using
+``func.ROW`` with the :attr:`_sa.func` namespace, or by using the
+:func:`_sql.tuple_` construct::
+
+ >>> from sqlalchemy import table, column, func, tuple_
+ >>> t = table("t", column("id"), column("fk"))
+ >>> stmt = t.select().where(
+ ... tuple_(t.c.id, t.c.fk) > (1,2)
+ ... ).where(
+ ... func.ROW(t.c.id, t.c.fk) < func.ROW(3, 7)
+ ... )
+ >>> print(stmt)
+ SELECT t.id, t.fk
+ FROM t
+ WHERE (t.id, t.fk) > (:param_1, :param_2) AND ROW(t.id, t.fk) < ROW(:ROW_1, :ROW_2)
+
+.. seealso::
+
+ `PostgreSQL Row Constructors
+ <https://www.postgresql.org/docs/current/sql-expressions.html#SQL-SYNTAX-ROW-CONSTRUCTORS>`_
+
+ `PostgreSQL Row Constructor Comparison
+ <https://www.postgresql.org/docs/current/functions-comparisons.html#ROW-WISE-COMPARISON>`_
+
+Table Types passed to Functions
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+PostgreSQL supports passing a table as an argument to a function, which it
+refers to as a "record" type. SQLAlchemy :class:`_sql.FromClause` objects
+such as :class:`_schema.Table` support this special form using the
+:meth:`_sql.FromClause.table_valued` method, which is comparable to the
+:meth:`_functions.FunctionElement.table_valued` method except that the collection
+of columns is already established by that of the :class:`_sql.FromClause`
+itself::
+
+
+ >>> from sqlalchemy import table, column, func, select
+    >>> a = table("a", column("id"), column("x"), column("y"))
+ >>> stmt = select(func.row_to_json(a.table_valued()))
+ >>> print(stmt)
+ SELECT row_to_json(a) AS row_to_json_1
+ FROM a
+
+.. versionadded:: 1.4.0b2
+
+
+ARRAY Types
+-----------
+
+The PostgreSQL dialect supports arrays, both as multidimensional column types
+as well as array literals:
+
+* :class:`_postgresql.ARRAY` - ARRAY datatype
+
+* :class:`_postgresql.array` - array literal
+
+* :func:`_postgresql.array_agg` - ARRAY_AGG SQL function
+
+* :class:`_postgresql.aggregate_order_by` - helper for PG's ORDER BY aggregate
+ function syntax.
+
+JSON Types
+----------
+
+The PostgreSQL dialect supports both JSON and JSONB datatypes, including
+psycopg2's native support and support for all of PostgreSQL's special
+operators:
+
+* :class:`_postgresql.JSON`
+
+* :class:`_postgresql.JSONB`
+
+HSTORE Type
+-----------
+
+The PostgreSQL HSTORE type as well as hstore literals are supported:
+
+* :class:`_postgresql.HSTORE` - HSTORE datatype
+
+* :class:`_postgresql.hstore` - hstore literal
+
+ENUM Types
+----------
+
+PostgreSQL has an independently creatable TYPE structure which is used
+to implement an enumerated type. This approach introduces significant
+complexity on the SQLAlchemy side in terms of when this type should be
+CREATED and DROPPED. The type object is also an independently reflectable
+entity. The following sections should be consulted:
+
+* :class:`_postgresql.ENUM` - DDL and typing support for ENUM.
+
+* :meth:`.PGInspector.get_enums` - retrieve a listing of current ENUM types
+
+* :meth:`.postgresql.ENUM.create` , :meth:`.postgresql.ENUM.drop` - individual
+ CREATE and DROP commands for ENUM.
+
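+As a quick sketch of the standalone create/drop calls listed above, assuming
+an ``engine`` connected to PostgreSQL (the type name here is illustrative)::
+
+    from sqlalchemy import MetaData
+    from sqlalchemy.dialects.postgresql import ENUM
+
+    metadata = MetaData()
+    mood = ENUM("happy", "sad", name="mood_enum", metadata=metadata)
+
+    # emit CREATE TYPE / DROP TYPE directly, checking the PG catalog first
+    mood.create(engine, checkfirst=True)
+    mood.drop(engine, checkfirst=True)
+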
+.. _postgresql_array_of_enum:
+
+Using ENUM with ARRAY
+^^^^^^^^^^^^^^^^^^^^^
+
+The combination of ENUM and ARRAY is not directly supported by backend
+DBAPIs at this time. Prior to SQLAlchemy 1.3.17, a special workaround
+was needed in order to allow this combination to work, described below.
+
+.. versionchanged:: 1.3.17 The combination of ENUM and ARRAY is now directly
+ handled by SQLAlchemy's implementation without any workarounds needed.
+
+.. sourcecode:: python
+
+    import re
+
+    import sqlalchemy as sa
+    from sqlalchemy import TypeDecorator
+    from sqlalchemy.dialects.postgresql import ARRAY
+
+ class ArrayOfEnum(TypeDecorator):
+ impl = ARRAY
+
+ def bind_expression(self, bindvalue):
+ return sa.cast(bindvalue, self)
+
+ def result_processor(self, dialect, coltype):
+ super_rp = super(ArrayOfEnum, self).result_processor(
+ dialect, coltype)
+
+ def handle_raw_string(value):
+ inner = re.match(r"^{(.*)}$", value).group(1)
+ return inner.split(",") if inner else []
+
+ def process(value):
+ if value is None:
+ return None
+ return super_rp(handle_raw_string(value))
+ return process
+
+E.g.::
+
+ Table(
+ 'mydata', metadata,
+ Column('id', Integer, primary_key=True),
+        Column('data', ArrayOfEnum(ENUM('a', 'b', 'c', name='myenum')))
+    )
+
+This type is not included as a built-in type as it would be incompatible
+with a DBAPI that suddenly decides to support ARRAY of ENUM directly in
+a new version.
+
+.. _postgresql_array_of_json:
+
+Using JSON/JSONB with ARRAY
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Similar to using ENUM, prior to SQLAlchemy 1.3.17, for an ARRAY of JSON/JSONB
+we needed to render the appropriate CAST. Current psycopg2 drivers accommodate
+the result set correctly without any special steps.
+
+.. versionchanged:: 1.3.17 The combination of JSON/JSONB and ARRAY is now
+ directly handled by SQLAlchemy's implementation without any workarounds
+ needed.
+
+.. sourcecode:: python
+
+    import sqlalchemy as sa
+    from sqlalchemy.dialects.postgresql import ARRAY
+
+    class CastingArray(ARRAY):
+        def bind_expression(self, bindvalue):
+            return sa.cast(bindvalue, self)
+
+E.g.::
+
+ Table(
+ 'mydata', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', CastingArray(JSONB))
+ )
+
+
+""" # noqa: E501
+
+from collections import defaultdict
+import datetime as dt
+import re
+from uuid import UUID as _python_UUID
+
+from . import array as _array
+from . import dml
+from . import hstore as _hstore
+from . import json as _json
+from . import ranges as _ranges
+from ... import exc
+from ... import schema
+from ... import sql
+from ... import util
+from ...engine import characteristics
+from ...engine import default
+from ...engine import reflection
+from ...sql import coercions
+from ...sql import compiler
+from ...sql import elements
+from ...sql import expression
+from ...sql import roles
+from ...sql import sqltypes
+from ...sql import util as sql_util
+from ...sql.ddl import DDLBase
+from ...types import BIGINT
+from ...types import BOOLEAN
+from ...types import CHAR
+from ...types import DATE
+from ...types import FLOAT
+from ...types import INTEGER
+from ...types import NUMERIC
+from ...types import REAL
+from ...types import SMALLINT
+from ...types import TEXT
+from ...types import VARCHAR
+
+IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I)
+
+AUTOCOMMIT_REGEXP = re.compile(
+ r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|GRANT|REVOKE|"
+ "IMPORT FOREIGN SCHEMA|REFRESH MATERIALIZED VIEW|TRUNCATE)",
+ re.I | re.UNICODE,
+)
+
+RESERVED_WORDS = set(
+ [
+ "all",
+ "analyse",
+ "analyze",
+ "and",
+ "any",
+ "array",
+ "as",
+ "asc",
+ "asymmetric",
+ "both",
+ "case",
+ "cast",
+ "check",
+ "collate",
+ "column",
+ "constraint",
+ "create",
+ "current_catalog",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "default",
+ "deferrable",
+ "desc",
+ "distinct",
+ "do",
+ "else",
+ "end",
+ "except",
+ "false",
+ "fetch",
+ "for",
+ "foreign",
+ "from",
+ "grant",
+ "group",
+ "having",
+ "in",
+ "initially",
+ "intersect",
+ "into",
+ "leading",
+ "limit",
+ "localtime",
+ "localtimestamp",
+ "new",
+ "not",
+ "null",
+ "of",
+ "off",
+ "offset",
+ "old",
+ "on",
+ "only",
+ "or",
+ "order",
+ "placing",
+ "primary",
+ "references",
+ "returning",
+ "select",
+ "session_user",
+ "some",
+ "symmetric",
+ "table",
+ "then",
+ "to",
+ "trailing",
+ "true",
+ "union",
+ "unique",
+ "user",
+ "using",
+ "variadic",
+ "when",
+ "where",
+ "window",
+ "with",
+ "authorization",
+ "between",
+ "binary",
+ "cross",
+ "current_schema",
+ "freeze",
+ "full",
+ "ilike",
+ "inner",
+ "is",
+ "isnull",
+ "join",
+ "left",
+ "like",
+ "natural",
+ "notnull",
+ "outer",
+ "over",
+ "overlaps",
+ "right",
+ "similar",
+ "verbose",
+ ]
+)
+
+_DECIMAL_TYPES = (1231, 1700)
+_FLOAT_TYPES = (700, 701, 1021, 1022)
+_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
+
+
+class BYTEA(sqltypes.LargeBinary):
+ __visit_name__ = "BYTEA"
+
+
+class DOUBLE_PRECISION(sqltypes.Float):
+ __visit_name__ = "DOUBLE_PRECISION"
+
+
+class INET(sqltypes.TypeEngine):
+ __visit_name__ = "INET"
+
+
+PGInet = INET
+
+
+class CIDR(sqltypes.TypeEngine):
+ __visit_name__ = "CIDR"
+
+
+PGCidr = CIDR
+
+
+class MACADDR(sqltypes.TypeEngine):
+ __visit_name__ = "MACADDR"
+
+
+PGMacAddr = MACADDR
+
+
+class MONEY(sqltypes.TypeEngine):
+
+ r"""Provide the PostgreSQL MONEY type.
+
+ Depending on driver, result rows using this type may return a
+ string value which includes currency symbols.
+
+ For this reason, it may be preferable to provide conversion to a
+ numerically-based currency datatype using :class:`_types.TypeDecorator`::
+
+        import decimal
+        import re
+        from typing import Any
+
+        from sqlalchemy import TypeDecorator
+
+ class NumericMoney(TypeDecorator):
+ impl = MONEY
+
+            def process_result_value(self, value: Any, dialect: Any) -> Any:
+ if value is not None:
+ # adjust this for the currency and numeric
+ m = re.match(r"\$([\d.]+)", value)
+ if m:
+ value = decimal.Decimal(m.group(1))
+ return value
+
+ Alternatively, the conversion may be applied as a CAST using
+ the :meth:`_types.TypeDecorator.column_expression` method as follows::
+
+ import decimal
+        from sqlalchemy import cast
+        from sqlalchemy import Numeric
+        from sqlalchemy import TypeDecorator
+
+ class NumericMoney(TypeDecorator):
+ impl = MONEY
+
+ def column_expression(self, column: Any):
+ return cast(column, Numeric())
+
+ .. versionadded:: 1.2
+
+ """
+
+ __visit_name__ = "MONEY"
+
+
+class OID(sqltypes.TypeEngine):
+
+ """Provide the PostgreSQL OID type.
+
+ .. versionadded:: 0.9.5
+
+ """
+
+ __visit_name__ = "OID"
+
+
+class REGCLASS(sqltypes.TypeEngine):
+
+ """Provide the PostgreSQL REGCLASS type.
+
+ .. versionadded:: 1.2.7
+
+ """
+
+ __visit_name__ = "REGCLASS"
+
+
+class TIMESTAMP(sqltypes.TIMESTAMP):
+
+ """Provide the PostgreSQL TIMESTAMP type."""
+
+ __visit_name__ = "TIMESTAMP"
+
+ def __init__(self, timezone=False, precision=None):
+ """Construct a TIMESTAMP.
+
+ :param timezone: boolean value if timezone present, default False
+ :param precision: optional integer precision value
+
+ .. versionadded:: 1.4
+
+ """
+ super(TIMESTAMP, self).__init__(timezone=timezone)
+ self.precision = precision
+
+
+class TIME(sqltypes.TIME):
+
+ """PostgreSQL TIME type."""
+
+ __visit_name__ = "TIME"
+
+ def __init__(self, timezone=False, precision=None):
+ """Construct a TIME.
+
+ :param timezone: boolean value if timezone present, default False
+ :param precision: optional integer precision value
+
+ .. versionadded:: 1.4
+
+ """
+ super(TIME, self).__init__(timezone=timezone)
+ self.precision = precision
+
+
+class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
+
+ """PostgreSQL INTERVAL type."""
+
+ __visit_name__ = "INTERVAL"
+ native = True
+
+ def __init__(self, precision=None, fields=None):
+ """Construct an INTERVAL.
+
+ :param precision: optional integer precision value
+ :param fields: string fields specifier. allows storage of fields
+ to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``,
+ etc.
+
+ .. versionadded:: 1.2
+
+ """
+ self.precision = precision
+ self.fields = fields
+
+ @classmethod
+ def adapt_emulated_to_native(cls, interval, **kw):
+ return INTERVAL(precision=interval.second_precision)
+
+ @property
+ def _type_affinity(self):
+ return sqltypes.Interval
+
+ def as_generic(self, allow_nulltype=False):
+ return sqltypes.Interval(native=True, second_precision=self.precision)
+
+ @property
+ def python_type(self):
+ return dt.timedelta
+
+ def coerce_compared_value(self, op, value):
+ return self
+
+
+PGInterval = INTERVAL
+
+
+class BIT(sqltypes.TypeEngine):
+ __visit_name__ = "BIT"
+
+ def __init__(self, length=None, varying=False):
+ if not varying:
+ # BIT without VARYING defaults to length 1
+ self.length = length or 1
+ else:
+ # but BIT VARYING can be unlimited-length, so no default
+ self.length = length
+ self.varying = varying
+
+
+PGBit = BIT
+
+
+class UUID(sqltypes.TypeEngine):
+
+ """PostgreSQL UUID type.
+
+ Represents the UUID column type, interpreting
+ data either as natively returned by the DBAPI
+ or as Python uuid objects.
+
+ The UUID type is currently known to work within the prominent DBAPI
+ drivers supported by SQLAlchemy including psycopg2, pg8000 and
+ asyncpg. Support for other DBAPI drivers may be incomplete or non-present.
+
+ """
+
+ __visit_name__ = "UUID"
+
+ def __init__(self, as_uuid=False):
+ """Construct a UUID type.
+
+
+ :param as_uuid=False: if True, values will be interpreted
+ as Python uuid objects, converting to/from string via the
+ DBAPI.
+
+ """
+ self.as_uuid = as_uuid
+
+ def coerce_compared_value(self, op, value):
+ """See :meth:`.TypeEngine.coerce_compared_value` for a description."""
+
+ if isinstance(value, util.string_types):
+ return self
+ else:
+ return super(UUID, self).coerce_compared_value(op, value)
+
+ def bind_processor(self, dialect):
+ if self.as_uuid:
+
+ def process(value):
+ if value is not None:
+ value = util.text_type(value)
+ return value
+
+ return process
+ else:
+ return None
+
+ def result_processor(self, dialect, coltype):
+ if self.as_uuid:
+
+ def process(value):
+ if value is not None:
+ value = _python_UUID(value)
+ return value
+
+ return process
+ else:
+ return None
+
+ def literal_processor(self, dialect):
+ if self.as_uuid:
+
+ def process(value):
+ if value is not None:
+ value = "'%s'::UUID" % value
+ return value
+
+ return process
+ else:
+
+ def process(value):
+ if value is not None:
+ value = "'%s'" % value
+ return value
+
+ return process
+
+ @property
+ def python_type(self):
+ return _python_UUID if self.as_uuid else str
+
+
+PGUuid = UUID
+
+
+class TSVECTOR(sqltypes.TypeEngine):
+
+ """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
+ text search type TSVECTOR.
+
+ It can be used to do full text queries on natural language
+ documents.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :ref:`postgresql_match`
+
+ """
+
+ __visit_name__ = "TSVECTOR"
+
+
+class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
+
+ """PostgreSQL ENUM type.
+
+ This is a subclass of :class:`_types.Enum` which includes
+ support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
+
+ When the builtin type :class:`_types.Enum` is used and the
+ :paramref:`.Enum.native_enum` flag is left at its default of
+ True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
+ type as the implementation, so the special create/drop rules
+ will be used.
+
+ The create/drop behavior of ENUM is necessarily intricate, due to the
+ awkward relationship the ENUM type has in relationship to the
+ parent table, in that it may be "owned" by just a single table, or
+ may be shared among many tables.
+
+ When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
+ in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
+ corresponding to when the :meth:`_schema.Table.create` and
+ :meth:`_schema.Table.drop`
+ methods are called::
+
+ table = Table('sometable', metadata,
+ Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
+ )
+
+ table.create(engine) # will emit CREATE ENUM and CREATE TABLE
+ table.drop(engine) # will emit DROP TABLE and DROP ENUM
+
+ To use a common enumerated type between multiple tables, the best
+ practice is to declare the :class:`_types.Enum` or
+ :class:`_postgresql.ENUM` independently, and associate it with the
+ :class:`_schema.MetaData` object itself::
+
+ my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
+
+ t1 = Table('sometable_one', metadata,
+ Column('some_enum', myenum)
+ )
+
+ t2 = Table('sometable_two', metadata,
+ Column('some_enum', myenum)
+ )
+
+ When this pattern is used, care must still be taken at the level
+ of individual table creates. Emitting CREATE TABLE without also
+ specifying ``checkfirst=True`` will still cause issues::
+
+ t1.create(engine) # will fail: no such type 'myenum'
+
+ If we specify ``checkfirst=True``, the individual table-level create
+ operation will check for the ``ENUM`` and create if not exists::
+
+ # will check if enum exists, and emit CREATE TYPE if not
+ t1.create(engine, checkfirst=True)
+
+ When using a metadata-level ENUM type, the type will always be created
+ and dropped if either the metadata-wide create/drop is called::
+
+ metadata.create_all(engine) # will emit CREATE TYPE
+ metadata.drop_all(engine) # will emit DROP TYPE
+
+ The type can also be created and dropped directly::
+
+ my_enum.create(engine)
+ my_enum.drop(engine)
+
+ .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type
+ now behaves more strictly with regards to CREATE/DROP. A metadata-level
+ ENUM type will only be created and dropped at the metadata level,
+ not the table level, with the exception of
+ ``table.create(checkfirst=True)``.
+ The ``table.drop()`` call will now emit a DROP TYPE for a table-level
+ enumerated type.
+
+ """
+
+ native_enum = True
+
+ def __init__(self, *enums, **kw):
+ """Construct an :class:`_postgresql.ENUM`.
+
+ Arguments are the same as that of
+ :class:`_types.Enum`, but also including
+ the following parameters.
+
+ :param create_type: Defaults to True.
+ Indicates that ``CREATE TYPE`` should be
+ emitted, after optionally checking for the
+ presence of the type, when the parent
+ table is being created; and additionally
+ that ``DROP TYPE`` is called when the table
+ is dropped. When ``False``, no check
+ will be performed and no ``CREATE TYPE``
+ or ``DROP TYPE`` is emitted, unless
+ :meth:`~.postgresql.ENUM.create`
+ or :meth:`~.postgresql.ENUM.drop`
+ are called directly.
+ Setting to ``False`` is helpful
+ when invoking a creation scheme to a SQL file
+ without access to the actual database -
+ the :meth:`~.postgresql.ENUM.create` and
+ :meth:`~.postgresql.ENUM.drop` methods can
+ be used to emit SQL to a target bind.
+
+ """
+ native_enum = kw.pop("native_enum", None)
+ if native_enum is False:
+ util.warn(
+ "the native_enum flag does not apply to the "
+ "sqlalchemy.dialects.postgresql.ENUM datatype; this type "
+ "always refers to ENUM. Use sqlalchemy.types.Enum for "
+ "non-native enum."
+ )
+ self.create_type = kw.pop("create_type", True)
+ super(ENUM, self).__init__(*enums, **kw)
+
+ @classmethod
+ def adapt_emulated_to_native(cls, impl, **kw):
+ """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
+ :class:`.Enum`.
+
+ """
+ kw.setdefault("validate_strings", impl.validate_strings)
+ kw.setdefault("name", impl.name)
+ kw.setdefault("schema", impl.schema)
+ kw.setdefault("inherit_schema", impl.inherit_schema)
+ kw.setdefault("metadata", impl.metadata)
+ kw.setdefault("_create_events", False)
+ kw.setdefault("values_callable", impl.values_callable)
+ kw.setdefault("omit_aliases", impl._omit_aliases)
+ return cls(**kw)
+
+ def create(self, bind=None, checkfirst=True):
+ """Emit ``CREATE TYPE`` for this
+ :class:`_postgresql.ENUM`.
+
+ If the underlying dialect does not support
+ PostgreSQL CREATE TYPE, no action is taken.
+
+ :param bind: a connectable :class:`_engine.Engine`,
+ :class:`_engine.Connection`, or similar object to emit
+ SQL.
+ :param checkfirst: if ``True``, a query against
+ the PG catalog will be first performed to see
+ if the type does not exist already before
+ creating.
+
+ """
+ if not bind.dialect.supports_native_enum:
+ return
+
+ bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
+
+ def drop(self, bind=None, checkfirst=True):
+ """Emit ``DROP TYPE`` for this
+ :class:`_postgresql.ENUM`.
+
+ If the underlying dialect does not support
+ PostgreSQL DROP TYPE, no action is taken.
+
+ :param bind: a connectable :class:`_engine.Engine`,
+ :class:`_engine.Connection`, or similar object to emit
+ SQL.
+ :param checkfirst: if ``True``, a query against
+ the PG catalog will be first performed to see
+ if the type actually exists before dropping.
+
+ """
+ if not bind.dialect.supports_native_enum:
+ return
+
+ bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
+
+ class EnumGenerator(DDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+ super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+
+ def _can_create_enum(self, enum):
+ if not self.checkfirst:
+ return True
+
+ effective_schema = self.connection.schema_for_object(enum)
+
+ return not self.connection.dialect.has_type(
+ self.connection, enum.name, schema=effective_schema
+ )
+
+ def visit_enum(self, enum):
+ if not self._can_create_enum(enum):
+ return
+
+ self.connection.execute(CreateEnumType(enum))
+
+ class EnumDropper(DDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+ super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+
+ def _can_drop_enum(self, enum):
+ if not self.checkfirst:
+ return True
+
+ effective_schema = self.connection.schema_for_object(enum)
+
+ return self.connection.dialect.has_type(
+ self.connection, enum.name, schema=effective_schema
+ )
+
+ def visit_enum(self, enum):
+ if not self._can_drop_enum(enum):
+ return
+
+ self.connection.execute(DropEnumType(enum))
+
+ def _check_for_name_in_memos(self, checkfirst, kw):
+ """Look in the 'ddl runner' for 'memos', then
+ note our name in that collection.
+
+ This to ensure a particular named enum is operated
+ upon only once within any kind of create/drop
+ sequence without relying upon "checkfirst".
+
+ """
+ if not self.create_type:
+ return True
+ if "_ddl_runner" in kw:
+ ddl_runner = kw["_ddl_runner"]
+ if "_pg_enums" in ddl_runner.memo:
+ pg_enums = ddl_runner.memo["_pg_enums"]
+ else:
+ pg_enums = ddl_runner.memo["_pg_enums"] = set()
+ present = (self.schema, self.name) in pg_enums
+ pg_enums.add((self.schema, self.name))
+ return present
+ else:
+ return False
+
+ def _on_table_create(self, target, bind, checkfirst=False, **kw):
+ if (
+ checkfirst
+ or (
+ not self.metadata
+ and not kw.get("_is_metadata_operation", False)
+ )
+ ) and not self._check_for_name_in_memos(checkfirst, kw):
+ self.create(bind=bind, checkfirst=checkfirst)
+
+ def _on_table_drop(self, target, bind, checkfirst=False, **kw):
+ if (
+ not self.metadata
+ and not kw.get("_is_metadata_operation", False)
+ and not self._check_for_name_in_memos(checkfirst, kw)
+ ):
+ self.drop(bind=bind, checkfirst=checkfirst)
+
+ def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
+ if not self._check_for_name_in_memos(checkfirst, kw):
+ self.create(bind=bind, checkfirst=checkfirst)
+
+ def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
+ if not self._check_for_name_in_memos(checkfirst, kw):
+ self.drop(bind=bind, checkfirst=checkfirst)
+
+
+class _ColonCast(elements.Cast):
+ __visit_name__ = "colon_cast"
+
+ def __init__(self, expression, type_):
+ self.type = type_
+ self.clause = expression
+ self.typeclause = elements.TypeClause(type_)
+
+
+colspecs = {
+ sqltypes.ARRAY: _array.ARRAY,
+ sqltypes.Interval: INTERVAL,
+ sqltypes.Enum: ENUM,
+ sqltypes.JSON.JSONPathType: _json.JSONPathType,
+ sqltypes.JSON: _json.JSON,
+}
+
+ischema_names = {
+ "_array": _array.ARRAY,
+ "hstore": _hstore.HSTORE,
+ "json": _json.JSON,
+ "jsonb": _json.JSONB,
+ "int4range": _ranges.INT4RANGE,
+ "int8range": _ranges.INT8RANGE,
+ "numrange": _ranges.NUMRANGE,
+ "daterange": _ranges.DATERANGE,
+ "tsrange": _ranges.TSRANGE,
+ "tstzrange": _ranges.TSTZRANGE,
+ "integer": INTEGER,
+ "bigint": BIGINT,
+ "smallint": SMALLINT,
+ "character varying": VARCHAR,
+ "character": CHAR,
+ '"char"': sqltypes.String,
+ "name": sqltypes.String,
+ "text": TEXT,
+ "numeric": NUMERIC,
+ "float": FLOAT,
+ "real": REAL,
+ "inet": INET,
+ "cidr": CIDR,
+ "uuid": UUID,
+ "bit": BIT,
+ "bit varying": BIT,
+ "macaddr": MACADDR,
+ "money": MONEY,
+ "oid": OID,
+ "regclass": REGCLASS,
+ "double precision": DOUBLE_PRECISION,
+ "timestamp": TIMESTAMP,
+ "timestamp with time zone": TIMESTAMP,
+ "timestamp without time zone": TIMESTAMP,
+ "time with time zone": TIME,
+ "time without time zone": TIME,
+ "date": DATE,
+ "time": TIME,
+ "bytea": BYTEA,
+ "boolean": BOOLEAN,
+ "interval": INTERVAL,
+ "tsvector": TSVECTOR,
+}
+
+
+class PGCompiler(compiler.SQLCompiler):
+ def visit_colon_cast(self, element, **kw):
+ return "%s::%s" % (
+ element.clause._compiler_dispatch(self, **kw),
+ element.typeclause._compiler_dispatch(self, **kw),
+ )
+
+ def visit_array(self, element, **kw):
+ return "ARRAY[%s]" % self.visit_clauselist(element, **kw)
+
+ def visit_slice(self, element, **kw):
+ return "%s:%s" % (
+ self.process(element.start, **kw),
+ self.process(element.stop, **kw),
+ )
+
+ def visit_json_getitem_op_binary(
+ self, binary, operator, _cast_applied=False, **kw
+ ):
+ if (
+ not _cast_applied
+ and binary.type._type_affinity is not sqltypes.JSON
+ ):
+ kw["_cast_applied"] = True
+ return self.process(sql.cast(binary, binary.type), **kw)
+
+ kw["eager_grouping"] = True
+
+ return self._generate_generic_binary(
+ binary, " -> " if not _cast_applied else " ->> ", **kw
+ )
+
+ def visit_json_path_getitem_op_binary(
+ self, binary, operator, _cast_applied=False, **kw
+ ):
+ if (
+ not _cast_applied
+ and binary.type._type_affinity is not sqltypes.JSON
+ ):
+ kw["_cast_applied"] = True
+ return self.process(sql.cast(binary, binary.type), **kw)
+
+ kw["eager_grouping"] = True
+ return self._generate_generic_binary(
+ binary, " #> " if not _cast_applied else " #>> ", **kw
+ )
+
+ def visit_getitem_binary(self, binary, operator, **kw):
+ return "%s[%s]" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_aggregate_order_by(self, element, **kw):
+ return "%s ORDER BY %s" % (
+ self.process(element.target, **kw),
+ self.process(element.order_by, **kw),
+ )
+
+ def visit_match_op_binary(self, binary, operator, **kw):
+ if "postgresql_regconfig" in binary.modifiers:
+ regconfig = self.render_literal_value(
+ binary.modifiers["postgresql_regconfig"], sqltypes.STRINGTYPE
+ )
+ if regconfig:
+ return "%s @@ to_tsquery(%s, %s)" % (
+ self.process(binary.left, **kw),
+ regconfig,
+ self.process(binary.right, **kw),
+ )
+ return "%s @@ to_tsquery(%s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_ilike_op_binary(self, binary, operator, **kw):
+ escape = binary.modifiers.get("escape", None)
+
+ return "%s ILIKE %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
+
+ def visit_not_ilike_op_binary(self, binary, operator, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return "%s NOT ILIKE %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
+
+ def _regexp_match(self, base_op, binary, operator, kw):
+ flags = binary.modifiers["flags"]
+ if flags is None:
+ return self._generate_generic_binary(
+ binary, " %s " % base_op, **kw
+ )
+ if isinstance(flags, elements.BindParameter) and flags.value == "i":
+ return self._generate_generic_binary(
+ binary, " %s* " % base_op, **kw
+ )
+ flags = self.process(flags, **kw)
+ string = self.process(binary.left, **kw)
+ pattern = self.process(binary.right, **kw)
+ return "%s %s CONCAT('(?', %s, ')', %s)" % (
+ string,
+ base_op,
+ flags,
+ pattern,
+ )
+
+ def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._regexp_match("~", binary, operator, kw)
+
+ def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._regexp_match("!~", binary, operator, kw)
+
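+ # Illustrative note (added for clarity, not part of the original source):
+ # with the helper above, ``col.regexp_match("^a")`` renders with the ``~``
+ # operator, ``col.regexp_match("^a", flags="i")`` with ``~*``, and any
+ # other flags value falls back to
+ # ``col ~ CONCAT('(?', <flags>, ')', <pattern>)``; the negated forms use
+ # ``!~`` / ``!~*`` correspondingly.
+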
+ def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+ string = self.process(binary.left, **kw)
+ pattern = self.process(binary.right, **kw)
+ flags = binary.modifiers["flags"]
+ if flags is not None:
+ flags = self.process(flags, **kw)
+ replacement = self.process(binary.modifiers["replacement"], **kw)
+ if flags is None:
+ return "REGEXP_REPLACE(%s, %s, %s)" % (
+ string,
+ pattern,
+ replacement,
+ )
+ else:
+ return "REGEXP_REPLACE(%s, %s, %s, %s)" % (
+ string,
+ pattern,
+ replacement,
+ flags,
+ )
+
+ def visit_empty_set_expr(self, element_types):
+ # cast the empty set to the type we are comparing against. if
+ # we are comparing against the null type, pick an arbitrary
+ # datatype for the empty set
+ return "SELECT %s WHERE 1!=1" % (
+ ", ".join(
+ "CAST(NULL AS %s)"
+ % self.dialect.type_compiler.process(
+ INTEGER() if type_._isnull else type_
+ )
+ for type_ in element_types or [INTEGER()]
+ ),
+ )
+
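+ # Illustrative note (added for clarity): with the rendering above, an
+ # expression such as ``some_integer_column.in_([])`` is expected to
+ # compile roughly as
+ # ``some_integer_column IN (SELECT CAST(NULL AS INTEGER) WHERE 1!=1)``,
+ # i.e. an empty set cast to the compared-against datatype.
+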
+ def render_literal_value(self, value, type_):
+ value = super(PGCompiler, self).render_literal_value(value, type_)
+
+ if self.dialect._backslash_escapes:
+ value = value.replace("\\", "\\\\")
+ return value
+
+ def visit_sequence(self, seq, **kw):
+ return "nextval('%s')" % self.preparer.format_sequence(seq)
+
+ def limit_clause(self, select, **kw):
+ text = ""
+ if select._limit_clause is not None:
+ text += " \n LIMIT " + self.process(select._limit_clause, **kw)
+ if select._offset_clause is not None:
+ if select._limit_clause is None:
+ text += "\n LIMIT ALL"
+ text += " OFFSET " + self.process(select._offset_clause, **kw)
+ return text
+
+ def format_from_hint_text(self, sqltext, table, hint, iscrud):
+ if hint.upper() != "ONLY":
+ raise exc.CompileError("Unrecognized hint: %r" % hint)
+ return "ONLY " + sqltext
+
+ def get_select_precolumns(self, select, **kw):
+ # Do not call super().get_select_precolumns because
+ # it will warn/raise when distinct on is present
+ if select._distinct or select._distinct_on:
+ if select._distinct_on:
+ return (
+ "DISTINCT ON ("
+ + ", ".join(
+ [
+ self.process(col, **kw)
+ for col in select._distinct_on
+ ]
+ )
+ + ") "
+ )
+ else:
+ return "DISTINCT "
+ else:
+ return ""
+
+ def for_update_clause(self, select, **kw):
+
+ if select._for_update_arg.read:
+ if select._for_update_arg.key_share:
+ tmp = " FOR KEY SHARE"
+ else:
+ tmp = " FOR SHARE"
+ elif select._for_update_arg.key_share:
+ tmp = " FOR NO KEY UPDATE"
+ else:
+ tmp = " FOR UPDATE"
+
+ if select._for_update_arg.of:
+
+ tables = util.OrderedSet()
+ for c in select._for_update_arg.of:
+ tables.update(sql_util.surface_selectables_only(c))
+
+ tmp += " OF " + ", ".join(
+ self.process(table, ashint=True, use_schema=False, **kw)
+ for table in tables
+ )
+
+ if select._for_update_arg.nowait:
+ tmp += " NOWAIT"
+ if select._for_update_arg.skip_locked:
+ tmp += " SKIP LOCKED"
+
+ return tmp
+
+ def returning_clause(self, stmt, returning_cols):
+
+ columns = [
+ self._label_returning_column(stmt, c)
+ for c in expression._select_iterables(returning_cols)
+ ]
+
+ return "RETURNING " + ", ".join(columns)
+
+ def visit_substring_func(self, func, **kw):
+ s = self.process(func.clauses.clauses[0], **kw)
+ start = self.process(func.clauses.clauses[1], **kw)
+ if len(func.clauses.clauses) > 2:
+ length = self.process(func.clauses.clauses[2], **kw)
+ return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
+ else:
+ return "SUBSTRING(%s FROM %s)" % (s, start)
+
+ def _on_conflict_target(self, clause, **kw):
+
+ if clause.constraint_target is not None:
+ # target may be a name of an Index, UniqueConstraint or
+ # ExcludeConstraint. While there is a separate
+ # "max_identifier_length" for indexes, PostgreSQL uses the same
+ # length for all objects so we can use
+ # truncate_and_render_constraint_name
+ target_text = (
+ "ON CONSTRAINT %s"
+ % self.preparer.truncate_and_render_constraint_name(
+ clause.constraint_target
+ )
+ )
+ elif clause.inferred_target_elements is not None:
+ target_text = "(%s)" % ", ".join(
+ (
+ self.preparer.quote(c)
+ if isinstance(c, util.string_types)
+ else self.process(c, include_table=False, use_schema=False)
+ )
+ for c in clause.inferred_target_elements
+ )
+ if clause.inferred_target_whereclause is not None:
+ target_text += " WHERE %s" % self.process(
+ clause.inferred_target_whereclause,
+ include_table=False,
+ use_schema=False,
+ )
+ else:
+ target_text = ""
+
+ return target_text
+
+ @util.memoized_property
+ def _is_safe_for_fast_insert_values_helper(self):
+ # don't allow fast executemany if _post_values_clause is
+ # present and is not an OnConflictDoNothing. what this means
+ # concretely is that the
+ # "fast insert executemany helper" won't be used, in other
+ # words we won't convert "executemany()" of many parameter
+ # sets into a single INSERT with many elements in VALUES.
+ # We can't apply that optimization safely if for example the
+ # statement includes a clause like "ON CONFLICT DO UPDATE"
+
+ return self.insert_single_values_expr is not None and (
+ self.statement._post_values_clause is None
+ or isinstance(
+ self.statement._post_values_clause, dml.OnConflictDoNothing
+ )
+ )
+
+ def visit_on_conflict_do_nothing(self, on_conflict, **kw):
+
+ target_text = self._on_conflict_target(on_conflict, **kw)
+
+ if target_text:
+ return "ON CONFLICT %s DO NOTHING" % target_text
+ else:
+ return "ON CONFLICT DO NOTHING"
+
+ def visit_on_conflict_do_update(self, on_conflict, **kw):
+
+ clause = on_conflict
+
+ target_text = self._on_conflict_target(on_conflict, **kw)
+
+ action_set_ops = []
+
+ set_parameters = dict(clause.update_values_to_set)
+ # create a list of column assignment clauses as tuples
+
+ insert_statement = self.stack[-1]["selectable"]
+ cols = insert_statement.table.c
+ for c in cols:
+ col_key = c.key
+
+ if col_key in set_parameters:
+ value = set_parameters.pop(col_key)
+ elif c in set_parameters:
+ value = set_parameters.pop(c)
+ else:
+ continue
+
+ if coercions._is_literal(value):
+ value = elements.BindParameter(None, value, type_=c.type)
+
+ else:
+ if (
+ isinstance(value, elements.BindParameter)
+ and value.type._isnull
+ ):
+ value = value._clone()
+ value.type = c.type
+ value_text = self.process(value.self_group(), use_schema=False)
+
+ key_text = self.preparer.quote(c.name)
+ action_set_ops.append("%s = %s" % (key_text, value_text))
+
+ # check for names that don't match columns
+ if set_parameters:
+ util.warn(
+ "Additional column names not matching "
+ "any column keys in table '%s': %s"
+ % (
+ self.current_executable.table.name,
+ (", ".join("'%s'" % c for c in set_parameters)),
+ )
+ )
+ for k, v in set_parameters.items():
+ key_text = (
+ self.preparer.quote(k)
+ if isinstance(k, util.string_types)
+ else self.process(k, use_schema=False)
+ )
+ value_text = self.process(
+ coercions.expect(roles.ExpressionElementRole, v),
+ use_schema=False,
+ )
+ action_set_ops.append("%s = %s" % (key_text, value_text))
+
+ action_text = ", ".join(action_set_ops)
+ if clause.update_whereclause is not None:
+ action_text += " WHERE %s" % self.process(
+ clause.update_whereclause, include_table=True, use_schema=False
+ )
+
+ return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text)
+
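+ # Usage sketch for the clause rendered above (illustrative only;
+ # ``my_table`` is a placeholder Table object):
+ #
+ #   from sqlalchemy.dialects.postgresql import insert
+ #   stmt = insert(my_table).values(id=1, data="new value")
+ #   stmt = stmt.on_conflict_do_update(
+ #       index_elements=[my_table.c.id],
+ #       set_=dict(data=stmt.excluded.data),
+ #   )
+ #
+ # which is expected to compile roughly as:
+ #   INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ #   ON CONFLICT (id) DO UPDATE SET data = excluded.data
+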
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ kw["asfrom"] = True
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
+
+ def delete_extra_from_clause(
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ """Render the DELETE .. USING clause specific to PostgreSQL."""
+ kw["asfrom"] = True
+ return "USING " + ", ".join(
+ t._compiler_dispatch(self, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
+
+ def fetch_clause(self, select, **kw):
+ # pg requires parens for non-literal clauses. They are also required
+ # for bind parameters if a ::type cast is used by the driver (asyncpg),
+ # so it's easiest to just always add them.
+ text = ""
+ if select._offset_clause is not None:
+ text += "\n OFFSET (%s) ROWS" % self.process(
+ select._offset_clause, **kw
+ )
+ if select._fetch_clause is not None:
+ text += "\n FETCH FIRST (%s)%s ROWS %s" % (
+ self.process(select._fetch_clause, **kw),
+ " PERCENT" if select._fetch_clause_options["percent"] else "",
+ "WITH TIES"
+ if select._fetch_clause_options["with_ties"]
+ else "ONLY",
+ )
+ return text
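+
+ # Illustrative note (added for clarity):
+ # ``select(t).offset(10).fetch(5, with_ties=True)`` is expected to render
+ # roughly as ``... OFFSET (:param) ROWS FETCH FIRST (:param) ROWS WITH
+ # TIES``, with the parentheses always present per the comment above.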
+
+
+class PGDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+
+ colspec = self.preparer.format_column(column)
+ impl_type = column.type.dialect_impl(self.dialect)
+ if isinstance(impl_type, sqltypes.TypeDecorator):
+ impl_type = impl_type.impl
+
+ has_identity = (
+ column.identity is not None
+ and self.dialect.supports_identity_columns
+ )
+
+ if (
+ column.primary_key
+ and column is column.table._autoincrement_column
+ and (
+ self.dialect.supports_smallserial
+ or not isinstance(impl_type, sqltypes.SmallInteger)
+ )
+ and not has_identity
+ and (
+ column.default is None
+ or (
+ isinstance(column.default, schema.Sequence)
+ and column.default.optional
+ )
+ )
+ ):
+ if isinstance(impl_type, sqltypes.BigInteger):
+ colspec += " BIGSERIAL"
+ elif isinstance(impl_type, sqltypes.SmallInteger):
+ colspec += " SMALLSERIAL"
+ else:
+ colspec += " SERIAL"
+ else:
+ colspec += " " + self.dialect.type_compiler.process(
+ column.type,
+ type_expression=column,
+ identifier_preparer=self.preparer,
+ )
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ if column.computed is not None:
+ colspec += " " + self.process(column.computed)
+ if has_identity:
+ colspec += " " + self.process(column.identity)
+
+ if not column.nullable and not has_identity:
+ colspec += " NOT NULL"
+ elif column.nullable and has_identity:
+ colspec += " NULL"
+ return colspec
+
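+ # Illustration of the SERIAL selection above (assumed output, added for
+ # clarity): an Integer primary key column that is the table's
+ # autoincrement column, with no Identity and no explicit default, is
+ # rendered as e.g. ``id SERIAL NOT NULL``; BigInteger yields BIGSERIAL,
+ # and SmallInteger yields SMALLSERIAL on servers that support it
+ # (PostgreSQL 9.2+).
+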
+ def _define_constraint_validity(self, constraint):
+ not_valid = constraint.dialect_options["postgresql"]["not_valid"]
+ return " NOT VALID" if not_valid else ""
+
+ def visit_check_constraint(self, constraint):
+ if constraint._type_bound:
+ typ = list(constraint.columns)[0].type
+ if (
+ isinstance(typ, sqltypes.ARRAY)
+ and isinstance(typ.item_type, sqltypes.Enum)
+ and not typ.item_type.native_enum
+ ):
+ raise exc.CompileError(
+ "PostgreSQL dialect cannot produce the CHECK constraint "
+ "for ARRAY of non-native ENUM; please specify "
+ "create_constraint=False on this Enum datatype."
+ )
+
+ text = super(PGDDLCompiler, self).visit_check_constraint(constraint)
+ text += self._define_constraint_validity(constraint)
+ return text
+
+ def visit_foreign_key_constraint(self, constraint):
+ text = super(PGDDLCompiler, self).visit_foreign_key_constraint(
+ constraint
+ )
+ text += self._define_constraint_validity(constraint)
+ return text
+
+ def visit_drop_table_comment(self, drop):
+ return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
+ drop.element
+ )
+
+ def visit_create_enum_type(self, create):
+ type_ = create.element
+
+ return "CREATE TYPE %s AS ENUM (%s)" % (
+ self.preparer.format_type(type_),
+ ", ".join(
+ self.sql_compiler.process(sql.literal(e), literal_binds=True)
+ for e in type_.enums
+ ),
+ )
+
+ def visit_drop_enum_type(self, drop):
+ type_ = drop.element
+
+ return "DROP TYPE %s" % (self.preparer.format_type(type_))
+
+ def visit_create_index(self, create):
+ preparer = self.preparer
+ index = create.element
+ self._verify_index_table(index)
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+ text += "INDEX "
+
+ if self.dialect._supports_create_index_concurrently:
+ concurrently = index.dialect_options["postgresql"]["concurrently"]
+ if concurrently:
+ text += "CONCURRENTLY "
+
+ if create.if_not_exists:
+ text += "IF NOT EXISTS "
+
+ text += "%s ON %s " % (
+ self._prepared_index_name(index, include_schema=False),
+ preparer.format_table(index.table),
+ )
+
+ using = index.dialect_options["postgresql"]["using"]
+ if using:
+ text += (
+ "USING %s "
+ % self.preparer.validate_sql_phrase(using, IDX_USING).lower()
+ )
+
+ ops = index.dialect_options["postgresql"]["ops"]
+ text += "(%s)" % (
+ ", ".join(
+ [
+ self.sql_compiler.process(
+ expr.self_group()
+ if not isinstance(expr, expression.ColumnClause)
+ else expr,
+ include_table=False,
+ literal_binds=True,
+ )
+ + (
+ (" " + ops[expr.key])
+ if hasattr(expr, "key") and expr.key in ops
+ else ""
+ )
+ for expr in index.expressions
+ ]
+ )
+ )
+
+ includeclause = index.dialect_options["postgresql"]["include"]
+ if includeclause:
+ inclusions = [
+ index.table.c[col]
+ if isinstance(col, util.string_types)
+ else col
+ for col in includeclause
+ ]
+ text += " INCLUDE (%s)" % ", ".join(
+ [preparer.quote(c.name) for c in inclusions]
+ )
+
+ withclause = index.dialect_options["postgresql"]["with"]
+ if withclause:
+ text += " WITH (%s)" % (
+ ", ".join(
+ [
+ "%s = %s" % storage_parameter
+ for storage_parameter in withclause.items()
+ ]
+ )
+ )
+
+ tablespace_name = index.dialect_options["postgresql"]["tablespace"]
+ if tablespace_name:
+ text += " TABLESPACE %s" % preparer.quote(tablespace_name)
+
+ whereclause = index.dialect_options["postgresql"]["where"]
+ if whereclause is not None:
+ whereclause = coercions.expect(
+ roles.DDLExpressionRole, whereclause
+ )
+
+ where_compiled = self.sql_compiler.process(
+ whereclause, include_table=False, literal_binds=True
+ )
+ text += " WHERE " + where_compiled
+
+ return text
+
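+ # Usage sketch for the PostgreSQL index options handled above
+ # (illustrative only; names are placeholders):
+ #
+ #   Index(
+ #       "ix_data",
+ #       my_table.c.data,
+ #       postgresql_where=my_table.c.data.isnot(None),
+ #       postgresql_with={"fillfactor": 70},
+ #   )
+ #
+ # which is expected to emit CREATE INDEX ix_data ON my_table (data)
+ # WITH (fillfactor = 70) WHERE data IS NOT NULL.  postgresql_using,
+ # postgresql_include, postgresql_concurrently and postgresql_tablespace
+ # are handled analogously.
+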
+ def visit_drop_index(self, drop):
+ index = drop.element
+
+ text = "\nDROP INDEX "
+
+ if self.dialect._supports_drop_index_concurrently:
+ concurrently = index.dialect_options["postgresql"]["concurrently"]
+ if concurrently:
+ text += "CONCURRENTLY "
+
+ if drop.if_exists:
+ text += "IF EXISTS "
+
+ text += self._prepared_index_name(index, include_schema=True)
+ return text
+
+ def visit_exclude_constraint(self, constraint, **kw):
+ text = ""
+ if constraint.name is not None:
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(
+ constraint
+ )
+ elements = []
+ for expr, name, op in constraint._render_exprs:
+ kw["include_table"] = False
+ exclude_element = self.sql_compiler.process(expr, **kw) + (
+ (" " + constraint.ops[expr.key])
+ if hasattr(expr, "key") and expr.key in constraint.ops
+ else ""
+ )
+
+ elements.append("%s WITH %s" % (exclude_element, op))
+ text += "EXCLUDE USING %s (%s)" % (
+ self.preparer.validate_sql_phrase(
+ constraint.using, IDX_USING
+ ).lower(),
+ ", ".join(elements),
+ )
+ if constraint.where is not None:
+ text += " WHERE (%s)" % self.sql_compiler.process(
+ constraint.where, literal_binds=True
+ )
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def post_create_table(self, table):
+ table_opts = []
+ pg_opts = table.dialect_options["postgresql"]
+
+ inherits = pg_opts.get("inherits")
+ if inherits is not None:
+ if not isinstance(inherits, (list, tuple)):
+ inherits = (inherits,)
+ table_opts.append(
+ "\n INHERITS ( "
+ + ", ".join(self.preparer.quote(name) for name in inherits)
+ + " )"
+ )
+
+ if pg_opts["partition_by"]:
+ table_opts.append("\n PARTITION BY %s" % pg_opts["partition_by"])
+
+ if pg_opts["with_oids"] is True:
+ table_opts.append("\n WITH OIDS")
+ elif pg_opts["with_oids"] is False:
+ table_opts.append("\n WITHOUT OIDS")
+
+ if pg_opts["on_commit"]:
+ on_commit_options = pg_opts["on_commit"].replace("_", " ").upper()
+ table_opts.append("\n ON COMMIT %s" % on_commit_options)
+
+ if pg_opts["tablespace"]:
+ tablespace_name = pg_opts["tablespace"]
+ table_opts.append(
+ "\n TABLESPACE %s" % self.preparer.quote(tablespace_name)
+ )
+
+ return "".join(table_opts)
+
+ def visit_computed_column(self, generated):
+ if generated.persisted is False:
+ raise exc.CompileError(
+ "PostrgreSQL computed columns do not support 'virtual' "
+ "persistence; set the 'persisted' flag to None or True for "
+ "PostgreSQL support."
+ )
+
+ return "GENERATED ALWAYS AS (%s) STORED" % self.sql_compiler.process(
+ generated.sqltext, include_table=False, literal_binds=True
+ )
+
+ def visit_create_sequence(self, create, **kw):
+ prefix = None
+ if create.element.data_type is not None:
+ prefix = " AS %s" % self.type_compiler.process(
+ create.element.data_type
+ )
+
+ return super(PGDDLCompiler, self).visit_create_sequence(
+ create, prefix=prefix, **kw
+ )
+
+
+class PGTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_TSVECTOR(self, type_, **kw):
+ return "TSVECTOR"
+
+ def visit_INET(self, type_, **kw):
+ return "INET"
+
+ def visit_CIDR(self, type_, **kw):
+ return "CIDR"
+
+ def visit_MACADDR(self, type_, **kw):
+ return "MACADDR"
+
+ def visit_MONEY(self, type_, **kw):
+ return "MONEY"
+
+ def visit_OID(self, type_, **kw):
+ return "OID"
+
+ def visit_REGCLASS(self, type_, **kw):
+ return "REGCLASS"
+
+ def visit_FLOAT(self, type_, **kw):
+ if not type_.precision:
+ return "FLOAT"
+ else:
+ return "FLOAT(%(precision)s)" % {"precision": type_.precision}
+
+ def visit_DOUBLE_PRECISION(self, type_, **kw):
+ return "DOUBLE PRECISION"
+
+ def visit_BIGINT(self, type_, **kw):
+ return "BIGINT"
+
+ def visit_HSTORE(self, type_, **kw):
+ return "HSTORE"
+
+ def visit_JSON(self, type_, **kw):
+ return "JSON"
+
+ def visit_JSONB(self, type_, **kw):
+ return "JSONB"
+
+ def visit_INT4RANGE(self, type_, **kw):
+ return "INT4RANGE"
+
+ def visit_INT8RANGE(self, type_, **kw):
+ return "INT8RANGE"
+
+ def visit_NUMRANGE(self, type_, **kw):
+ return "NUMRANGE"
+
+ def visit_DATERANGE(self, type_, **kw):
+ return "DATERANGE"
+
+ def visit_TSRANGE(self, type_, **kw):
+ return "TSRANGE"
+
+ def visit_TSTZRANGE(self, type_, **kw):
+ return "TSTZRANGE"
+
+ def visit_datetime(self, type_, **kw):
+ return self.visit_TIMESTAMP(type_, **kw)
+
+ def visit_enum(self, type_, **kw):
+ if not type_.native_enum or not self.dialect.supports_native_enum:
+ return super(PGTypeCompiler, self).visit_enum(type_, **kw)
+ else:
+ return self.visit_ENUM(type_, **kw)
+
+ def visit_ENUM(self, type_, identifier_preparer=None, **kw):
+ if identifier_preparer is None:
+ identifier_preparer = self.dialect.identifier_preparer
+ return identifier_preparer.format_type(type_)
+
+ def visit_TIMESTAMP(self, type_, **kw):
+ return "TIMESTAMP%s %s" % (
+ "(%d)" % type_.precision
+ if getattr(type_, "precision", None) is not None
+ else "",
+ (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE",
+ )
+
+ def visit_TIME(self, type_, **kw):
+ return "TIME%s %s" % (
+ "(%d)" % type_.precision
+ if getattr(type_, "precision", None) is not None
+ else "",
+ (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE",
+ )
+
+ def visit_INTERVAL(self, type_, **kw):
+ text = "INTERVAL"
+ if type_.fields is not None:
+ text += " " + type_.fields
+ if type_.precision is not None:
+ text += " (%d)" % type_.precision
+ return text
+
+ def visit_BIT(self, type_, **kw):
+ if type_.varying:
+ compiled = "BIT VARYING"
+ if type_.length is not None:
+ compiled += "(%d)" % type_.length
+ else:
+ compiled = "BIT(%d)" % type_.length
+ return compiled
+
+ def visit_UUID(self, type_, **kw):
+ return "UUID"
+
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_BYTEA(type_, **kw)
+
+ def visit_BYTEA(self, type_, **kw):
+ return "BYTEA"
+
+ def visit_ARRAY(self, type_, **kw):
+
+ inner = self.process(type_.item_type, **kw)
+ return re.sub(
+ r"((?: COLLATE.*)?)$",
+ (
+ r"%s\1"
+ % (
+ "[]"
+ * (type_.dimensions if type_.dimensions is not None else 1)
+ )
+ ),
+ inner,
+ count=1,
+ )
+
+
+class PGIdentifierPreparer(compiler.IdentifierPreparer):
+
+ reserved_words = RESERVED_WORDS
+
+ def _unquote_identifier(self, value):
+ if value[0] == self.initial_quote:
+ value = value[1:-1].replace(
+ self.escape_to_quote, self.escape_quote
+ )
+ return value
+
+ def format_type(self, type_, use_schema=True):
+ if not type_.name:
+ raise exc.CompileError("PostgreSQL ENUM type requires a name.")
+
+ name = self.quote(type_.name)
+ effective_schema = self.schema_for_object(type_)
+
+ if (
+ not self.omit_schema
+ and use_schema
+ and effective_schema is not None
+ ):
+ name = self.quote_schema(effective_schema) + "." + name
+ return name
+
+
+class PGInspector(reflection.Inspector):
+ def get_table_oid(self, table_name, schema=None):
+ """Return the OID for the given table name."""
+
+ with self._operation_context() as conn:
+ return self.dialect.get_table_oid(
+ conn, table_name, schema, info_cache=self.info_cache
+ )
+
+ def get_enums(self, schema=None):
+ """Return a list of ENUM objects.
+
+ Each member is a dictionary containing these fields:
+
+ * name - name of the enum
+ * schema - the schema name for the enum.
+ * visible - boolean, whether or not this enum is visible
+ in the default search path.
+ * labels - a list of string labels that apply to the enum.
+
+ :param schema: schema name. If None, the default schema
+ (typically 'public') is used. May also be set to '*' to
+ indicate load enums for all schemas.
+
+ .. versionadded:: 1.0.0
+
+ """
+ schema = schema or self.default_schema_name
+ with self._operation_context() as conn:
+ return self.dialect._load_enums(conn, schema)
+
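+ # Usage sketch (illustrative; assumes an Engine connected to a
+ # PostgreSQL database that defines at least one ENUM type):
+ #
+ #   from sqlalchemy import create_engine, inspect
+ #   engine = create_engine("postgresql://scott:tiger@localhost/test")
+ #   insp = inspect(engine)
+ #   for rec in insp.get_enums(schema="*"):
+ #       print(rec["schema"], rec["name"], rec["labels"])
+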
+ def get_foreign_table_names(self, schema=None):
+ """Return a list of FOREIGN TABLE names.
+
+ Behavior is similar to that of
+ :meth:`_reflection.Inspector.get_table_names`,
+ except that the list is limited to those tables that report a
+ ``relkind`` value of ``f``.
+
+ .. versionadded:: 1.0.0
+
+ """
+ schema = schema or self.default_schema_name
+ with self._operation_context() as conn:
+ return self.dialect._get_foreign_table_names(conn, schema)
+
+ def get_view_names(self, schema=None, include=("plain", "materialized")):
+ """Return all view names in `schema`.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+ For special quoting, use :class:`.quoted_name`.
+
+ :param include: specify which types of views to return. Passed
+ as a string value (for a single type) or a tuple (for any number
+ of types). Defaults to ``('plain', 'materialized')``.
+
+ .. versionadded:: 1.1
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_view_names(
+ conn, schema, info_cache=self.info_cache, include=include
+ )
+
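+ # Usage sketch (illustrative):
+ # ``inspect(engine).get_view_names(schema="public", include="materialized")``
+ # returns only materialized view names; the default returns both plain
+ # and materialized views.
+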
+
+class CreateEnumType(schema._CreateDropBase):
+ __visit_name__ = "create_enum_type"
+
+
+class DropEnumType(schema._CreateDropBase):
+ __visit_name__ = "drop_enum_type"
+
+
+class PGExecutionContext(default.DefaultExecutionContext):
+ def fire_sequence(self, seq, type_):
+ return self._execute_scalar(
+ (
+ "select nextval('%s')"
+ % self.identifier_preparer.format_sequence(seq)
+ ),
+ type_,
+ )
+
+ def get_insert_default(self, column):
+ if column.primary_key and column is column.table._autoincrement_column:
+ if column.server_default and column.server_default.has_argument:
+
+ # pre-execute passive defaults on primary key columns
+ return self._execute_scalar(
+ "select %s" % column.server_default.arg, column.type
+ )
+
+ elif column.default is None or (
+ column.default.is_sequence and column.default.optional
+ ):
+ # execute the sequence associated with a SERIAL primary
+ # key column. for non-primary-key SERIAL, the ID just
+ # generates server side.
+
+ try:
+ seq_name = column._postgresql_seq_name
+ except AttributeError:
+ tab = column.table.name
+ col = column.name
+ tab = tab[0 : 29 + max(0, (29 - len(col)))]
+ col = col[0 : 29 + max(0, (29 - len(tab)))]
+ name = "%s_%s_seq" % (tab, col)
+ column._postgresql_seq_name = seq_name = name
+
+ if column.table is not None:
+ effective_schema = self.connection.schema_for_object(
+ column.table
+ )
+ else:
+ effective_schema = None
+
+ if effective_schema is not None:
+ exc = 'select nextval(\'"%s"."%s"\')' % (
+ effective_schema,
+ seq_name,
+ )
+ else:
+ exc = "select nextval('\"%s\"')" % (seq_name,)
+
+ return self._execute_scalar(exc, column.type)
+
+ return super(PGExecutionContext, self).get_insert_default(column)
+
+ def should_autocommit_text(self, statement):
+ return AUTOCOMMIT_REGEXP.match(statement)
+
+
+class PGReadOnlyConnectionCharacteristic(
+ characteristics.ConnectionCharacteristic
+):
+ transactional = True
+
+ def reset_characteristic(self, dialect, dbapi_conn):
+ dialect.set_readonly(dbapi_conn, False)
+
+ def set_characteristic(self, dialect, dbapi_conn, value):
+ dialect.set_readonly(dbapi_conn, value)
+
+ def get_characteristic(self, dialect, dbapi_conn):
+ return dialect.get_readonly(dbapi_conn)
+
+
+class PGDeferrableConnectionCharacteristic(
+ characteristics.ConnectionCharacteristic
+):
+ transactional = True
+
+ def reset_characteristic(self, dialect, dbapi_conn):
+ dialect.set_deferrable(dbapi_conn, False)
+
+ def set_characteristic(self, dialect, dbapi_conn, value):
+ dialect.set_deferrable(dbapi_conn, value)
+
+ def get_characteristic(self, dialect, dbapi_conn):
+ return dialect.get_deferrable(dbapi_conn)
+
+
+class PGDialect(default.DefaultDialect):
+ name = "postgresql"
+ supports_statement_cache = True
+ supports_alter = True
+ max_identifier_length = 63
+ supports_sane_rowcount = True
+
+ supports_native_enum = True
+ supports_native_boolean = True
+ supports_smallserial = True
+
+ supports_sequences = True
+ sequences_optional = True
+ preexecute_autoincrement_sequences = True
+ postfetch_lastrowid = False
+
+ supports_comments = True
+ supports_default_values = True
+
+ supports_default_metavalue = True
+
+ supports_empty_insert = False
+ supports_multivalues_insert = True
+ supports_identity_columns = True
+
+ default_paramstyle = "pyformat"
+ ischema_names = ischema_names
+ colspecs = colspecs
+
+ statement_compiler = PGCompiler
+ ddl_compiler = PGDDLCompiler
+ type_compiler = PGTypeCompiler
+ preparer = PGIdentifierPreparer
+ execution_ctx_cls = PGExecutionContext
+ inspector = PGInspector
+ isolation_level = None
+
+ implicit_returning = True
+ full_returning = True
+
+ connection_characteristics = (
+ default.DefaultDialect.connection_characteristics
+ )
+ connection_characteristics = connection_characteristics.union(
+ {
+ "postgresql_readonly": PGReadOnlyConnectionCharacteristic(),
+ "postgresql_deferrable": PGDeferrableConnectionCharacteristic(),
+ }
+ )
+
+ construct_arguments = [
+ (
+ schema.Index,
+ {
+ "using": False,
+ "include": None,
+ "where": None,
+ "ops": {},
+ "concurrently": False,
+ "with": {},
+ "tablespace": None,
+ },
+ ),
+ (
+ schema.Table,
+ {
+ "ignore_search_path": False,
+ "tablespace": None,
+ "partition_by": None,
+ "with_oids": None,
+ "on_commit": None,
+ "inherits": None,
+ },
+ ),
+ (
+ schema.CheckConstraint,
+ {
+ "not_valid": False,
+ },
+ ),
+ (
+ schema.ForeignKeyConstraint,
+ {
+ "not_valid": False,
+ },
+ ),
+ ]
+
+ reflection_options = ("postgresql_ignore_search_path",)
+
+ _backslash_escapes = True
+ _supports_create_index_concurrently = True
+ _supports_drop_index_concurrently = True
+
+ def __init__(
+ self,
+ isolation_level=None,
+ json_serializer=None,
+ json_deserializer=None,
+ **kwargs
+ ):
+ default.DefaultDialect.__init__(self, **kwargs)
+
+ # the isolation_level parameter to the PGDialect itself is legacy;
+ # it still works, however the execution_options() method is the one
+ # that is documented.
+ self.isolation_level = isolation_level
+ self._json_deserializer = json_deserializer
+ self._json_serializer = json_serializer
+
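+ # Usage sketch for the documented approach mentioned above (illustrative
+ # only):
+ #
+ #   with engine.connect() as conn:
+ #       conn = conn.execution_options(isolation_level="SERIALIZABLE")
+ #       ...
+ #
+ # or per-engine via create_engine(url, isolation_level="REPEATABLE READ").
+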
+ def initialize(self, connection):
+ super(PGDialect, self).initialize(connection)
+
+ if self.server_version_info <= (8, 2):
+ self.full_returning = self.implicit_returning = False
+
+ self.supports_native_enum = self.server_version_info >= (8, 3)
+ if not self.supports_native_enum:
+ self.colspecs = self.colspecs.copy()
+ # pop base Enum type
+ self.colspecs.pop(sqltypes.Enum, None)
+ # psycopg2, others may have placed ENUM here as well
+ self.colspecs.pop(ENUM, None)
+
+ # https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689
+ self.supports_smallserial = self.server_version_info >= (9, 2)
+
+ if self.server_version_info < (8, 2):
+ self._backslash_escapes = False
+ else:
+ # ensure this query is not emitted on server version < 8.2
+ # as it will fail
+ std_string = connection.exec_driver_sql(
+ "show standard_conforming_strings"
+ ).scalar()
+ self._backslash_escapes = std_string == "off"
+
+ self._supports_create_index_concurrently = (
+ self.server_version_info >= (8, 2)
+ )
+ self._supports_drop_index_concurrently = self.server_version_info >= (
+ 9,
+ 2,
+ )
+ self.supports_identity_columns = self.server_version_info >= (10,)
+
+ def on_connect(self):
+ if self.isolation_level is not None:
+
+ def connect(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ return connect
+ else:
+ return None
+
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ ]
+ )
+
+ def set_isolation_level(self, connection, level):
+ level = level.replace("_", " ")
+ if level not in self._isolation_lookup:
+ raise exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
+ )
+ cursor = connection.cursor()
+ cursor.execute(
+ "SET SESSION CHARACTERISTICS AS TRANSACTION "
+ "ISOLATION LEVEL %s" % level
+ )
+ cursor.execute("COMMIT")
+ cursor.close()
+
+ def get_isolation_level(self, connection):
+ cursor = connection.cursor()
+ cursor.execute("show transaction isolation level")
+ val = cursor.fetchone()[0]
+ cursor.close()
+ return val.upper()
+
+ def set_readonly(self, connection, value):
+ raise NotImplementedError()
+
+ def get_readonly(self, connection):
+ raise NotImplementedError()
+
+ def set_deferrable(self, connection, value):
+ raise NotImplementedError()
+
+ def get_deferrable(self, connection):
+ raise NotImplementedError()
+
+ def do_begin_twophase(self, connection, xid):
+ self.do_begin(connection.connection)
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.exec_driver_sql("PREPARE TRANSACTION '%s'" % xid)
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if is_prepared:
+ if recover:
+ # FIXME: ugly hack to get out of transaction
+ # context when committing recoverable transactions
+ # Must find out a way how to make the dbapi not
+ # open a transaction.
+ connection.exec_driver_sql("ROLLBACK")
+ connection.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid)
+ connection.exec_driver_sql("BEGIN")
+ self.do_rollback(connection.connection)
+ else:
+ self.do_rollback(connection.connection)
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if is_prepared:
+ if recover:
+ connection.exec_driver_sql("ROLLBACK")
+ connection.exec_driver_sql("COMMIT PREPARED '%s'" % xid)
+ connection.exec_driver_sql("BEGIN")
+ self.do_rollback(connection.connection)
+ else:
+ self.do_commit(connection.connection)
+
+ def do_recover_twophase(self, connection):
+ resultset = connection.execute(
+ sql.text("SELECT gid FROM pg_prepared_xacts")
+ )
+ return [row[0] for row in resultset]
+
+ def _get_default_schema_name(self, connection):
+ return connection.exec_driver_sql("select current_schema()").scalar()
+
+ def has_schema(self, connection, schema):
+ query = (
+ "select nspname from pg_namespace " "where lower(nspname)=:schema"
+ )
+ cursor = connection.execute(
+ sql.text(query).bindparams(
+ sql.bindparam(
+ "schema",
+ util.text_type(schema.lower()),
+ type_=sqltypes.Unicode,
+ )
+ )
+ )
+
+ return bool(cursor.first())
+
+ def has_table(self, connection, table_name, schema=None):
+ self._ensure_has_table_connection(connection)
+ # seems like case gets folded in pg_class...
+ if schema is None:
+ cursor = connection.execute(
+ sql.text(
+ "select relname from pg_class c join pg_namespace n on "
+ "n.oid=c.relnamespace where "
+ "pg_catalog.pg_table_is_visible(c.oid) "
+ "and relname=:name"
+ ).bindparams(
+ sql.bindparam(
+ "name",
+ util.text_type(table_name),
+ type_=sqltypes.Unicode,
+ )
+ )
+ )
+ else:
+ cursor = connection.execute(
+ sql.text(
+ "select relname from pg_class c join pg_namespace n on "
+ "n.oid=c.relnamespace where n.nspname=:schema and "
+ "relname=:name"
+ ).bindparams(
+ sql.bindparam(
+ "name",
+ util.text_type(table_name),
+ type_=sqltypes.Unicode,
+ ),
+ sql.bindparam(
+ "schema",
+ util.text_type(schema),
+ type_=sqltypes.Unicode,
+ ),
+ )
+ )
+ return bool(cursor.first())
+
+ def has_sequence(self, connection, sequence_name, schema=None):
+ if schema is None:
+ schema = self.default_schema_name
+ cursor = connection.execute(
+ sql.text(
+ "SELECT relname FROM pg_class c join pg_namespace n on "
+ "n.oid=c.relnamespace where relkind='S' and "
+ "n.nspname=:schema and relname=:name"
+ ).bindparams(
+ sql.bindparam(
+ "name",
+ util.text_type(sequence_name),
+ type_=sqltypes.Unicode,
+ ),
+ sql.bindparam(
+ "schema",
+ util.text_type(schema),
+ type_=sqltypes.Unicode,
+ ),
+ )
+ )
+
+ return bool(cursor.first())
+
+ def has_type(self, connection, type_name, schema=None):
+ if schema is not None:
+ query = """
+ SELECT EXISTS (
+ SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n
+ WHERE t.typnamespace = n.oid
+ AND t.typname = :typname
+ AND n.nspname = :nspname
+ )
+ """
+ query = sql.text(query)
+ else:
+ query = """
+ SELECT EXISTS (
+ SELECT * FROM pg_catalog.pg_type t
+ WHERE t.typname = :typname
+ AND pg_type_is_visible(t.oid)
+ )
+ """
+ query = sql.text(query)
+ query = query.bindparams(
+ sql.bindparam(
+ "typname", util.text_type(type_name), type_=sqltypes.Unicode
+ )
+ )
+ if schema is not None:
+ query = query.bindparams(
+ sql.bindparam(
+ "nspname", util.text_type(schema), type_=sqltypes.Unicode
+ )
+ )
+ cursor = connection.execute(query)
+ return bool(cursor.scalar())
+
+ def _get_server_version_info(self, connection):
+ v = connection.exec_driver_sql("select pg_catalog.version()").scalar()
+ m = re.match(
+ r".*(?:PostgreSQL|EnterpriseDB) "
+ r"(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?",
+ v,
+ )
+ if not m:
+ raise AssertionError(
+ "Could not determine version from string '%s'" % v
+ )
+ return tuple([int(x) for x in m.group(1, 2, 3) if x is not None])
+
+ @reflection.cache
+ def get_table_oid(self, connection, table_name, schema=None, **kw):
+ """Fetch the oid for schema.table_name.
+
+ Several reflection methods require the table oid. The idea for using
+ this method is that it can be fetched one time and cached for
+ subsequent calls.
+
+ """
+ table_oid = None
+ if schema is not None:
+ schema_where_clause = "n.nspname = :schema"
+ else:
+ schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
+ query = (
+ """
+ SELECT c.oid
+ FROM pg_catalog.pg_class c
+ LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+ WHERE (%s)
+ AND c.relname = :table_name AND c.relkind in
+ ('r', 'v', 'm', 'f', 'p')
+ """
+ % schema_where_clause
+ )
+ # Since we're binding to unicode, table_name and schema_name must be
+ # unicode.
+ table_name = util.text_type(table_name)
+ if schema is not None:
+ schema = util.text_type(schema)
+ s = sql.text(query).bindparams(table_name=sqltypes.Unicode)
+ s = s.columns(oid=sqltypes.Integer)
+ if schema:
+ s = s.bindparams(sql.bindparam("schema", type_=sqltypes.Unicode))
+ c = connection.execute(s, dict(table_name=table_name, schema=schema))
+ table_oid = c.scalar()
+ if table_oid is None:
+ raise exc.NoSuchTableError(table_name)
+ return table_oid
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ result = connection.execute(
+ sql.text(
+ "SELECT nspname FROM pg_namespace "
+ "WHERE nspname NOT LIKE 'pg_%' "
+ "ORDER BY nspname"
+ ).columns(nspname=sqltypes.Unicode)
+ )
+ return [name for name, in result]
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ result = connection.execute(
+ sql.text(
+ "SELECT c.relname FROM pg_class c "
+ "JOIN pg_namespace n ON n.oid = c.relnamespace "
+ "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')"
+ ).columns(relname=sqltypes.Unicode),
+ dict(
+ schema=schema
+ if schema is not None
+ else self.default_schema_name
+ ),
+ )
+ return [name for name, in result]
+
+ @reflection.cache
+ def _get_foreign_table_names(self, connection, schema=None, **kw):
+ result = connection.execute(
+ sql.text(
+ "SELECT c.relname FROM pg_class c "
+ "JOIN pg_namespace n ON n.oid = c.relnamespace "
+ "WHERE n.nspname = :schema AND c.relkind = 'f'"
+ ).columns(relname=sqltypes.Unicode),
+ dict(
+ schema=schema
+ if schema is not None
+ else self.default_schema_name
+ ),
+ )
+ return [name for name, in result]
+
+ @reflection.cache
+ def get_view_names(
+ self, connection, schema=None, include=("plain", "materialized"), **kw
+ ):
+
+ include_kind = {"plain": "v", "materialized": "m"}
+ try:
+ kinds = [include_kind[i] for i in util.to_list(include)]
+ except KeyError:
+ raise ValueError(
+ "include %r unknown, needs to be a sequence containing "
+ "one or both of 'plain' and 'materialized'" % (include,)
+ )
+ if not kinds:
+ raise ValueError(
+ "empty include, needs to be a sequence containing "
+ "one or both of 'plain' and 'materialized'"
+ )
+
+ result = connection.execute(
+ sql.text(
+ "SELECT c.relname FROM pg_class c "
+ "JOIN pg_namespace n ON n.oid = c.relnamespace "
+ "WHERE n.nspname = :schema AND c.relkind IN (%s)"
+ % (", ".join("'%s'" % elem for elem in kinds))
+ ).columns(relname=sqltypes.Unicode),
+ dict(
+ schema=schema
+ if schema is not None
+ else self.default_schema_name
+ ),
+ )
+ return [name for name, in result]
+
+ @reflection.cache
+ def get_sequence_names(self, connection, schema=None, **kw):
+ if not schema:
+ schema = self.default_schema_name
+ cursor = connection.execute(
+ sql.text(
+ "SELECT relname FROM pg_class c join pg_namespace n on "
+ "n.oid=c.relnamespace where relkind='S' and "
+ "n.nspname=:schema"
+ ).bindparams(
+ sql.bindparam(
+ "schema",
+ util.text_type(schema),
+ type_=sqltypes.Unicode,
+ ),
+ )
+ )
+ return [row[0] for row in cursor]
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ view_def = connection.scalar(
+ sql.text(
+ "SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c "
+ "JOIN pg_namespace n ON n.oid = c.relnamespace "
+ "WHERE n.nspname = :schema AND c.relname = :view_name "
+ "AND c.relkind IN ('v', 'm')"
+ ).columns(view_def=sqltypes.Unicode),
+ dict(
+ schema=schema
+ if schema is not None
+ else self.default_schema_name,
+ view_name=view_name,
+ ),
+ )
+ return view_def
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ generated = (
+ "a.attgenerated as generated"
+ if self.server_version_info >= (12,)
+ else "NULL as generated"
+ )
+ if self.server_version_info >= (10,):
+ # a.attidentity != '' is required or it will reflect also
+ # serial columns as identity.
+ identity = """\
+ (SELECT json_build_object(
+ 'always', a.attidentity = 'a',
+ 'start', s.seqstart,
+ 'increment', s.seqincrement,
+ 'minvalue', s.seqmin,
+ 'maxvalue', s.seqmax,
+ 'cache', s.seqcache,
+ 'cycle', s.seqcycle)
+ FROM pg_catalog.pg_sequence s
+ JOIN pg_catalog.pg_class c on s.seqrelid = c."oid"
+ WHERE c.relkind = 'S'
+ AND a.attidentity != ''
+ AND s.seqrelid = pg_catalog.pg_get_serial_sequence(
+ a.attrelid::regclass::text, a.attname
+ )::regclass::oid
+ ) as identity_options\
+ """
+ else:
+ identity = "NULL as identity_options"
+
+ SQL_COLS = """
+ SELECT a.attname,
+ pg_catalog.format_type(a.atttypid, a.atttypmod),
+ (
+ SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid)
+ FROM pg_catalog.pg_attrdef d
+ WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum
+ AND a.atthasdef
+ ) AS DEFAULT,
+ a.attnotnull,
+ a.attrelid as table_oid,
+ pgd.description as comment,
+ %s,
+ %s
+ FROM pg_catalog.pg_attribute a
+ LEFT JOIN pg_catalog.pg_description pgd ON (
+ pgd.objoid = a.attrelid AND pgd.objsubid = a.attnum)
+ WHERE a.attrelid = :table_oid
+ AND a.attnum > 0 AND NOT a.attisdropped
+ ORDER BY a.attnum
+ """ % (
+ generated,
+ identity,
+ )
+ s = (
+ sql.text(SQL_COLS)
+ .bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer))
+ .columns(attname=sqltypes.Unicode, default=sqltypes.Unicode)
+ )
+ c = connection.execute(s, dict(table_oid=table_oid))
+ rows = c.fetchall()
+
+ # dictionary with (name, ) if default search path or (schema, name)
+ # as keys
+ domains = self._load_domains(connection)
+
+ # dictionary with (name, ) if default search path or (schema, name)
+ # as keys
+ enums = dict(
+ ((rec["name"],), rec)
+ if rec["visible"]
+ else ((rec["schema"], rec["name"]), rec)
+ for rec in self._load_enums(connection, schema="*")
+ )
+
+ # format columns
+ columns = []
+
+ for (
+ name,
+ format_type,
+ default_,
+ notnull,
+ table_oid,
+ comment,
+ generated,
+ identity,
+ ) in rows:
+ column_info = self._get_column_info(
+ name,
+ format_type,
+ default_,
+ notnull,
+ domains,
+ enums,
+ schema,
+ comment,
+ generated,
+ identity,
+ )
+ columns.append(column_info)
+ return columns
+
+ def _get_column_info(
+ self,
+ name,
+ format_type,
+ default,
+ notnull,
+ domains,
+ enums,
+ schema,
+ comment,
+ generated,
+ identity,
+ ):
+ def _handle_array_type(attype):
+ return (
+ # strip '[]' from integer[], etc.
+ re.sub(r"\[\]$", "", attype),
+ attype.endswith("[]"),
+ )
+
+ # strip (*) from character varying(5), timestamp(5)
+ # with time zone, geometry(POLYGON), etc.
+ attype = re.sub(r"\(.*\)", "", format_type)
+
+ # strip '[]' from integer[], etc. and check if an array
+ attype, is_array = _handle_array_type(attype)
+
+ # strip quotes from case sensitive enum or domain names
+ enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+
+ nullable = not notnull
+
+ charlen = re.search(r"\(([\d,]+)\)", format_type)
+ if charlen:
+ charlen = charlen.group(1)
+ args = re.search(r"\((.*)\)", format_type)
+ if args and args.group(1):
+ args = tuple(re.split(r"\s*,\s*", args.group(1)))
+ else:
+ args = ()
+ kwargs = {}
+
+ if attype == "numeric":
+ if charlen:
+ prec, scale = charlen.split(",")
+ args = (int(prec), int(scale))
+ else:
+ args = ()
+ elif attype == "double precision":
+ args = (53,)
+ elif attype == "integer":
+ args = ()
+ elif attype in ("timestamp with time zone", "time with time zone"):
+ kwargs["timezone"] = True
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ args = ()
+ elif attype in (
+ "timestamp without time zone",
+ "time without time zone",
+ "time",
+ ):
+ kwargs["timezone"] = False
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ args = ()
+ elif attype == "bit varying":
+ kwargs["varying"] = True
+ if charlen:
+ args = (int(charlen),)
+ else:
+ args = ()
+ elif attype.startswith("interval"):
+ field_match = re.match(r"interval (.+)", attype, re.I)
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ if field_match:
+ kwargs["fields"] = field_match.group(1)
+ attype = "interval"
+ args = ()
+ elif charlen:
+ args = (int(charlen),)
+
+ while True:
+ # looping here to suit nested domains
+ if attype in self.ischema_names:
+ coltype = self.ischema_names[attype]
+ break
+ elif enum_or_domain_key in enums:
+ enum = enums[enum_or_domain_key]
+ coltype = ENUM
+ kwargs["name"] = enum["name"]
+ if not enum["visible"]:
+ kwargs["schema"] = enum["schema"]
+ args = tuple(enum["labels"])
+ break
+ elif enum_or_domain_key in domains:
+ domain = domains[enum_or_domain_key]
+ attype = domain["attype"]
+ attype, is_array = _handle_array_type(attype)
+ # strip quotes from case sensitive enum or domain names
+ enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+ # A table can't override a not null on the domain,
+ # but can override nullable
+ nullable = nullable and domain["nullable"]
+ if domain["default"] and not default:
+ # It can, however, override the default
+ # value, but can't set it to null.
+ default = domain["default"]
+ continue
+ else:
+ coltype = None
+ break
+
+ if coltype:
+ coltype = coltype(*args, **kwargs)
+ if is_array:
+ coltype = self.ischema_names["_array"](coltype)
+ else:
+ util.warn(
+ "Did not recognize type '%s' of column '%s'" % (attype, name)
+ )
+ coltype = sqltypes.NULLTYPE
+
+ # A zero byte or blank string, depending on driver (also absent
+ # entirely for older PG versions), means this is not a generated
+ # column. Otherwise, "s" = stored. (Other values might be added
+ # in the future.)
+ if generated not in (None, "", b"\x00"):
+ computed = dict(
+ sqltext=default, persisted=generated in ("s", b"s")
+ )
+ default = None
+ else:
+ computed = None
+
+ # adjust the default value
+ autoincrement = False
+ if default is not None:
+ match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
+ if match is not None:
+ if issubclass(coltype._type_affinity, sqltypes.Integer):
+ autoincrement = True
+ # the default is related to a Sequence
+ sch = schema
+ if "." not in match.group(2) and sch is not None:
+ # unconditionally quote the schema name. this could
+ # later be enhanced to obey quoting rules /
+ # "quote schema"
+ default = (
+ match.group(1)
+ + ('"%s"' % sch)
+ + "."
+ + match.group(2)
+ + match.group(3)
+ )
+
+ column_info = dict(
+ name=name,
+ type=coltype,
+ nullable=nullable,
+ default=default,
+ autoincrement=autoincrement or identity is not None,
+ comment=comment,
+ )
+ if computed is not None:
+ column_info["computed"] = computed
+ if identity is not None:
+ column_info["identity"] = identity
+ return column_info
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ if self.server_version_info < (8, 4):
+ PK_SQL = """
+ SELECT a.attname
+ FROM
+ pg_class t
+ join pg_index ix on t.oid = ix.indrelid
+ join pg_attribute a
+ on t.oid=a.attrelid AND %s
+ WHERE
+ t.oid = :table_oid and ix.indisprimary = 't'
+ ORDER BY a.attnum
+ """ % self._pg_index_any(
+ "a.attnum", "ix.indkey"
+ )
+
+ else:
+ # unnest() and generate_subscripts() both introduced in
+ # version 8.4
+ PK_SQL = """
+ SELECT a.attname
+ FROM pg_attribute a JOIN (
+ SELECT unnest(ix.indkey) attnum,
+ generate_subscripts(ix.indkey, 1) ord
+ FROM pg_index ix
+ WHERE ix.indrelid = :table_oid AND ix.indisprimary
+ ) k ON a.attnum=k.attnum
+ WHERE a.attrelid = :table_oid
+ ORDER BY k.ord
+ """
+ t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode)
+ c = connection.execute(t, dict(table_oid=table_oid))
+ cols = [r[0] for r in c.fetchall()]
+
+ PK_CONS_SQL = """
+ SELECT conname
+ FROM pg_catalog.pg_constraint r
+ WHERE r.conrelid = :table_oid AND r.contype = 'p'
+ ORDER BY 1
+ """
+ t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode)
+ c = connection.execute(t, dict(table_oid=table_oid))
+ name = c.scalar()
+
+ return {"constrained_columns": cols, "name": name}
+
+ @reflection.cache
+ def get_foreign_keys(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ postgresql_ignore_search_path=False,
+ **kw
+ ):
+ preparer = self.identifier_preparer
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ FK_SQL = """
+ SELECT r.conname,
+ pg_catalog.pg_get_constraintdef(r.oid, true) as condef,
+ n.nspname as conschema
+ FROM pg_catalog.pg_constraint r,
+ pg_namespace n,
+ pg_class c
+
+ WHERE r.conrelid = :table AND
+ r.contype = 'f' AND
+ c.oid = confrelid AND
+ n.oid = c.relnamespace
+ ORDER BY 1
+ """
+ # https://www.postgresql.org/docs/9.0/static/sql-createtable.html
+ FK_REGEX = re.compile(
+ r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)"
+ r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?"
+ r"[\s]?(ON UPDATE "
+ r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?"
+ r"[\s]?(ON DELETE "
+ r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?"
+ r"[\s]?(DEFERRABLE|NOT DEFERRABLE)?"
+ r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?"
+ )
+
+ t = sql.text(FK_SQL).columns(
+ conname=sqltypes.Unicode, condef=sqltypes.Unicode
+ )
+ c = connection.execute(t, dict(table=table_oid))
+ fkeys = []
+ for conname, condef, conschema in c.fetchall():
+ m = re.search(FK_REGEX, condef).groups()
+
+ (
+ constrained_columns,
+ referred_schema,
+ referred_table,
+ referred_columns,
+ _,
+ match,
+ _,
+ onupdate,
+ _,
+ ondelete,
+ deferrable,
+ _,
+ initially,
+ ) = m
+
+ if deferrable is not None:
+ deferrable = True if deferrable == "DEFERRABLE" else False
+ constrained_columns = [
+ preparer._unquote_identifier(x)
+ for x in re.split(r"\s*,\s*", constrained_columns)
+ ]
+
+ if postgresql_ignore_search_path:
+ # when ignoring search path, we use the actual schema
+ # provided it isn't the "default" schema
+ if conschema != self.default_schema_name:
+ referred_schema = conschema
+ else:
+ referred_schema = schema
+ elif referred_schema:
+ # referred_schema is the schema that we regexp'ed from
+ # pg_get_constraintdef(). If the schema is in the search
+ # path, pg_get_constraintdef() will give us None.
+ referred_schema = preparer._unquote_identifier(referred_schema)
+ elif schema is not None and schema == conschema:
+ # If the actual schema matches the schema of the table
+ # we're reflecting, then we will use that.
+ referred_schema = schema
+
+ referred_table = preparer._unquote_identifier(referred_table)
+ referred_columns = [
+ preparer._unquote_identifier(x)
+ for x in re.split(r"\s*,\s", referred_columns)
+ ]
+ options = {
+ k: v
+ for k, v in [
+ ("onupdate", onupdate),
+ ("ondelete", ondelete),
+ ("initially", initially),
+ ("deferrable", deferrable),
+ ("match", match),
+ ]
+ if v is not None and v != "NO ACTION"
+ }
+ fkey_d = {
+ "name": conname,
+ "constrained_columns": constrained_columns,
+ "referred_schema": referred_schema,
+ "referred_table": referred_table,
+ "referred_columns": referred_columns,
+ "options": options,
+ }
+ fkeys.append(fkey_d)
+ return fkeys
+
+ def _pg_index_any(self, col, compare_to):
+ if self.server_version_info < (8, 1):
+ # https://www.postgresql.org/message-id/10279.1124395722@sss.pgh.pa.us
+ # "In CVS tip you could replace this with "attnum = ANY (indkey)".
+ # Unfortunately, most array support doesn't work on int2vector in
+ # pre-8.1 releases, so I think you're kinda stuck with the above
+ # for now.
+ # regards, tom lane"
+ return "(%s)" % " OR ".join(
+ "%s[%d] = %s" % (compare_to, ind, col) for ind in range(0, 10)
+ )
+ else:
+ return "%s = ANY(%s)" % (col, compare_to)
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema, **kw):
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ # cast indkey as varchar since it's an int2vector,
+ # returned as a list by some drivers such as pypostgresql
+
+ if self.server_version_info < (8, 5):
+ IDX_SQL = """
+ SELECT
+ i.relname as relname,
+ ix.indisunique, ix.indexprs, ix.indpred,
+ a.attname, a.attnum, NULL, ix.indkey%s,
+ %s, %s, am.amname,
+ NULL as indnkeyatts
+ FROM
+ pg_class t
+ join pg_index ix on t.oid = ix.indrelid
+ join pg_class i on i.oid = ix.indexrelid
+ left outer join
+ pg_attribute a
+ on t.oid = a.attrelid and %s
+ left outer join
+ pg_am am
+ on i.relam = am.oid
+ WHERE
+ t.relkind IN ('r', 'v', 'f', 'm')
+ and t.oid = :table_oid
+ and ix.indisprimary = 'f'
+ ORDER BY
+ t.relname,
+ i.relname
+ """ % (
+ # version 8.3 here was based on observing the
+ # cast does not work in PG 8.2.4, does work in 8.3.0.
+ # nothing in PG changelogs regarding this.
+ "::varchar" if self.server_version_info >= (8, 3) else "",
+ "ix.indoption::varchar"
+ if self.server_version_info >= (8, 3)
+ else "NULL",
+ "i.reloptions"
+ if self.server_version_info >= (8, 2)
+ else "NULL",
+ self._pg_index_any("a.attnum", "ix.indkey"),
+ )
+ else:
+ IDX_SQL = """
+ SELECT
+ i.relname as relname,
+ ix.indisunique, ix.indexprs,
+ a.attname, a.attnum, c.conrelid, ix.indkey::varchar,
+ ix.indoption::varchar, i.reloptions, am.amname,
+ pg_get_expr(ix.indpred, ix.indrelid),
+ %s as indnkeyatts
+ FROM
+ pg_class t
+ join pg_index ix on t.oid = ix.indrelid
+ join pg_class i on i.oid = ix.indexrelid
+ left outer join
+ pg_attribute a
+ on t.oid = a.attrelid and a.attnum = ANY(ix.indkey)
+ left outer join
+ pg_constraint c
+ on (ix.indrelid = c.conrelid and
+ ix.indexrelid = c.conindid and
+ c.contype in ('p', 'u', 'x'))
+ left outer join
+ pg_am am
+ on i.relam = am.oid
+ WHERE
+ t.relkind IN ('r', 'v', 'f', 'm', 'p')
+ and t.oid = :table_oid
+ and ix.indisprimary = 'f'
+ ORDER BY
+ t.relname,
+ i.relname
+ """ % (
+ "ix.indnkeyatts"
+ if self.server_version_info >= (11, 0)
+ else "NULL",
+ )
+
+ t = sql.text(IDX_SQL).columns(
+ relname=sqltypes.Unicode, attname=sqltypes.Unicode
+ )
+ c = connection.execute(t, dict(table_oid=table_oid))
+
+ indexes = defaultdict(lambda: defaultdict(dict))
+
+ sv_idx_name = None
+ for row in c.fetchall():
+ (
+ idx_name,
+ unique,
+ expr,
+ col,
+ col_num,
+ conrelid,
+ idx_key,
+ idx_option,
+ options,
+ amname,
+ filter_definition,
+ indnkeyatts,
+ ) = row
+
+ if expr:
+ if idx_name != sv_idx_name:
+ util.warn(
+ "Skipped unsupported reflection of "
+ "expression-based index %s" % idx_name
+ )
+ sv_idx_name = idx_name
+ continue
+
+ has_idx = idx_name in indexes
+ index = indexes[idx_name]
+ if col is not None:
+ index["cols"][col_num] = col
+ if not has_idx:
+ idx_keys = idx_key.split()
+ # "The number of key columns in the index, not counting any
+ # included columns, which are merely stored and do not
+ # participate in the index semantics"
+ if indnkeyatts and idx_keys[indnkeyatts:]:
+ # this is a "covering index" which has INCLUDE columns
+ # as well as regular index columns
+ inc_keys = idx_keys[indnkeyatts:]
+ idx_keys = idx_keys[:indnkeyatts]
+ else:
+ inc_keys = []
+
+ index["key"] = [int(k.strip()) for k in idx_keys]
+ index["inc"] = [int(k.strip()) for k in inc_keys]
+
+ # (new in pg 8.3)
+ # "pg_index.indoption" is list of ints, one per column/expr.
+ # int acts as bitmask: 0x01=DESC, 0x02=NULLSFIRST
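+                    # e.g. an indoption string of "1 0" means the first column
+                    # is DESC NULLS LAST, while the second uses the PG
+                    # defaults and is recorded as an empty tuple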
+ sorting = {}
+ for col_idx, col_flags in enumerate(
+ (idx_option or "").split()
+ ):
+ col_flags = int(col_flags.strip())
+ col_sorting = ()
+ # try to set flags only if they differ from PG defaults...
+ if col_flags & 0x01:
+ col_sorting += ("desc",)
+ if not (col_flags & 0x02):
+ col_sorting += ("nulls_last",)
+ else:
+ if col_flags & 0x02:
+ col_sorting += ("nulls_first",)
+ if col_sorting:
+ sorting[col_idx] = col_sorting
+ if sorting:
+ index["sorting"] = sorting
+
+ index["unique"] = unique
+ if conrelid is not None:
+ index["duplicates_constraint"] = idx_name
+ if options:
+ index["options"] = dict(
+ [option.split("=") for option in options]
+ )
+
+ # it *might* be nice to include that this is 'btree' in the
+ # reflection info. But we don't want an Index object
+ # to have a ``postgresql_using`` in it that is just the
+ # default, so for the moment leaving this out.
+ if amname and amname != "btree":
+ index["amname"] = amname
+
+ if filter_definition:
+ index["postgresql_where"] = filter_definition
+
+ result = []
+ for name, idx in indexes.items():
+ entry = {
+ "name": name,
+ "unique": idx["unique"],
+ "column_names": [idx["cols"][i] for i in idx["key"]],
+ }
+ if self.server_version_info >= (11, 0):
+ # NOTE: this is legacy, this is part of dialect_options now
+ # as of #7382
+ entry["include_columns"] = [idx["cols"][i] for i in idx["inc"]]
+ if "duplicates_constraint" in idx:
+ entry["duplicates_constraint"] = idx["duplicates_constraint"]
+ if "sorting" in idx:
+ entry["column_sorting"] = dict(
+ (idx["cols"][idx["key"][i]], value)
+ for i, value in idx["sorting"].items()
+ )
+ if "include_columns" in entry:
+ entry.setdefault("dialect_options", {})[
+ "postgresql_include"
+ ] = entry["include_columns"]
+ if "options" in idx:
+ entry.setdefault("dialect_options", {})[
+ "postgresql_with"
+ ] = idx["options"]
+ if "amname" in idx:
+ entry.setdefault("dialect_options", {})[
+ "postgresql_using"
+ ] = idx["amname"]
+ if "postgresql_where" in idx:
+ entry.setdefault("dialect_options", {})[
+ "postgresql_where"
+ ] = idx["postgresql_where"]
+ result.append(entry)
+ return result
+
+ @reflection.cache
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ UNIQUE_SQL = """
+ SELECT
+ cons.conname as name,
+ cons.conkey as key,
+ a.attnum as col_num,
+ a.attname as col_name
+ FROM
+ pg_catalog.pg_constraint cons
+ join pg_attribute a
+ on cons.conrelid = a.attrelid AND
+ a.attnum = ANY(cons.conkey)
+ WHERE
+ cons.conrelid = :table_oid AND
+ cons.contype = 'u'
+ """
+
+ t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode)
+ c = connection.execute(t, dict(table_oid=table_oid))
+
+ uniques = defaultdict(lambda: defaultdict(dict))
+ for row in c.fetchall():
+ uc = uniques[row.name]
+ uc["key"] = row.key
+ uc["cols"][row.col_num] = row.col_name
+
+ return [
+ {"name": name, "column_names": [uc["cols"][i] for i in uc["key"]]}
+ for name, uc in uniques.items()
+ ]
+
+ @reflection.cache
+ def get_table_comment(self, connection, table_name, schema=None, **kw):
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ COMMENT_SQL = """
+ SELECT
+ pgd.description as table_comment
+ FROM
+ pg_catalog.pg_description pgd
+ WHERE
+ pgd.objsubid = 0 AND
+ pgd.objoid = :table_oid
+ """
+
+ c = connection.execute(
+ sql.text(COMMENT_SQL), dict(table_oid=table_oid)
+ )
+ return {"text": c.scalar()}
+
+ @reflection.cache
+ def get_check_constraints(self, connection, table_name, schema=None, **kw):
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ CHECK_SQL = """
+ SELECT
+ cons.conname as name,
+ pg_get_constraintdef(cons.oid) as src
+ FROM
+ pg_catalog.pg_constraint cons
+ WHERE
+ cons.conrelid = :table_oid AND
+ cons.contype = 'c'
+ """
+
+ c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid))
+
+ ret = []
+ for name, src in c:
+ # samples:
+ # "CHECK (((a > 1) AND (a < 5)))"
+ # "CHECK (((a = 1) OR ((a > 2) AND (a < 5))))"
+ # "CHECK (((a > 1) AND (a < 5))) NOT VALID"
+ # "CHECK (some_boolean_function(a))"
+ # "CHECK (((a\n < 1)\n OR\n (a\n >= 5))\n)"
+
+ m = re.match(
+ r"^CHECK *\((.+)\)( NOT VALID)?$", src, flags=re.DOTALL
+ )
+ if not m:
+ util.warn("Could not parse CHECK constraint text: %r" % src)
+ sqltext = ""
+ else:
+ sqltext = re.compile(
+ r"^[\s\n]*\((.+)\)[\s\n]*$", flags=re.DOTALL
+ ).sub(r"\1", m.group(1))
+ entry = {"name": name, "sqltext": sqltext}
+ if m and m.group(2):
+ entry["dialect_options"] = {"not_valid": True}
+
+ ret.append(entry)
+ return ret
+
+ def _load_enums(self, connection, schema=None):
+ schema = schema or self.default_schema_name
+ if not self.supports_native_enum:
+ return {}
+
+ # Load data types for enums:
+ SQL_ENUMS = """
+ SELECT t.typname as "name",
+ -- no enum defaults in 8.4 at least
+ -- t.typdefault as "default",
+ pg_catalog.pg_type_is_visible(t.oid) as "visible",
+ n.nspname as "schema",
+ e.enumlabel as "label"
+ FROM pg_catalog.pg_type t
+ LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
+ LEFT JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid
+ WHERE t.typtype = 'e'
+ """
+
+ if schema != "*":
+ SQL_ENUMS += "AND n.nspname = :schema "
+
+ # e.oid gives us label order within an enum
+ SQL_ENUMS += 'ORDER BY "schema", "name", e.oid'
+
+ s = sql.text(SQL_ENUMS).columns(
+ attname=sqltypes.Unicode, label=sqltypes.Unicode
+ )
+
+ if schema != "*":
+ s = s.bindparams(schema=schema)
+
+ c = connection.execute(s)
+
+ enums = []
+ enum_by_name = {}
+ for enum in c.fetchall():
+ key = (enum.schema, enum.name)
+ if key in enum_by_name:
+ enum_by_name[key]["labels"].append(enum.label)
+ else:
+ enum_by_name[key] = enum_rec = {
+ "name": enum.name,
+ "schema": enum.schema,
+ "visible": enum.visible,
+ "labels": [],
+ }
+ if enum.label is not None:
+ enum_rec["labels"].append(enum.label)
+ enums.append(enum_rec)
+ return enums
+
+ def _load_domains(self, connection):
+ # Load data types for domains:
+ SQL_DOMAINS = """
+ SELECT t.typname as "name",
+ pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype",
+ not t.typnotnull as "nullable",
+ t.typdefault as "default",
+ pg_catalog.pg_type_is_visible(t.oid) as "visible",
+ n.nspname as "schema"
+ FROM pg_catalog.pg_type t
+ LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
+ WHERE t.typtype = 'd'
+ """
+
+ s = sql.text(SQL_DOMAINS)
+ c = connection.execution_options(future_result=True).execute(s)
+
+ domains = {}
+ for domain in c.mappings():
+ # strip (30) from character varying(30)
+ attype = re.search(r"([^\(]+)", domain["attype"]).group(1)
+ # 'visible' just means whether or not the domain is in a
+ # schema that's on the search path -- or not overridden by
+ # a schema with higher precedence. If it's not visible,
+ # it will be prefixed with the schema-name when it's used.
+ if domain["visible"]:
+ key = (domain["name"],)
+ else:
+ key = (domain["schema"], domain["name"])
+
+ domains[key] = {
+ "attype": attype,
+ "nullable": domain["nullable"],
+ "default": domain["default"],
+ }
+
+ return domains
diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py
new file mode 100644
index 0000000..b483774
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/dml.py
@@ -0,0 +1,274 @@
+# postgresql/dml.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import ext
+from ... import util
+from ...sql import coercions
+from ...sql import roles
+from ...sql import schema
+from ...sql.base import _exclusive_against
+from ...sql.base import _generative
+from ...sql.base import ColumnCollection
+from ...sql.dml import Insert as StandardInsert
+from ...sql.elements import ClauseElement
+from ...sql.expression import alias
+from ...util.langhelpers import public_factory
+
+
+__all__ = ("Insert", "insert")
+
+
+class Insert(StandardInsert):
+ """PostgreSQL-specific implementation of INSERT.
+
+ Adds methods for PG-specific syntaxes such as ON CONFLICT.
+
+ The :class:`_postgresql.Insert` object is created using the
+ :func:`sqlalchemy.dialects.postgresql.insert` function.
+
+ .. versionadded:: 1.1
+
+ """
+
+ stringify_dialect = "postgresql"
+ inherit_cache = False
+
+ @util.memoized_property
+ def excluded(self):
+ """Provide the ``excluded`` namespace for an ON CONFLICT statement
+
+ PG's ON CONFLICT clause allows reference to the row that would
+        be inserted, known as ``excluded``. This attribute makes all of the
+        columns in that row available for reference.
+
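+        For example (an illustrative sketch; ``my_table`` is a hypothetical
+        :class:`_schema.Table`)::
+
+            from sqlalchemy.dialects.postgresql import insert
+
+            stmt = insert(my_table).values(id=1, data="inserted value")
+            stmt = stmt.on_conflict_do_update(
+                index_elements=[my_table.c.id],
+                set_=dict(data=stmt.excluded.data),
+            )
+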
+ .. tip:: The :attr:`_postgresql.Insert.excluded` attribute is an
+ instance of :class:`_expression.ColumnCollection`, which provides
+ an interface the same as that of the :attr:`_schema.Table.c`
+ collection described at :ref:`metadata_tables_and_columns`.
+ With this collection, ordinary names are accessible like attributes
+ (e.g. ``stmt.excluded.some_column``), but special names and
+ dictionary method names should be accessed using indexed access,
+ such as ``stmt.excluded["column name"]`` or
+ ``stmt.excluded["values"]``. See the docstring for
+ :class:`_expression.ColumnCollection` for further examples.
+
+ .. seealso::
+
+ :ref:`postgresql_insert_on_conflict` - example of how
+ to use :attr:`_expression.Insert.excluded`
+
+ """
+ return alias(self.table, name="excluded").columns
+
+ _on_conflict_exclusive = _exclusive_against(
+ "_post_values_clause",
+ msgs={
+ "_post_values_clause": "This Insert construct already has "
+ "an ON CONFLICT clause established"
+ },
+ )
+
+ @_generative
+ @_on_conflict_exclusive
+ def on_conflict_do_update(
+ self,
+ constraint=None,
+ index_elements=None,
+ index_where=None,
+ set_=None,
+ where=None,
+ ):
+ r"""
+ Specifies a DO UPDATE SET action for ON CONFLICT clause.
+
+ Either the ``constraint`` or ``index_elements`` argument is
+ required, but only one of these can be specified.
+
+ :param constraint:
+ The name of a unique or exclusion constraint on the table,
+ or the constraint object itself if it has a .name attribute.
+
+ :param index_elements:
+ A sequence consisting of string column names, :class:`_schema.Column`
+ objects, or other column expression objects that will be used
+ to infer a target index.
+
+ :param index_where:
+ Additional WHERE criterion that can be used to infer a
+ conditional target index.
+
+ :param set\_:
+ A dictionary or other mapping object
+ where the keys are either names of columns in the target table,
+ or :class:`_schema.Column` objects or other ORM-mapped columns
+ matching that of the target table, and expressions or literals
+ as values, specifying the ``SET`` actions to take.
+
+ .. versionadded:: 1.4 The
+ :paramref:`_postgresql.Insert.on_conflict_do_update.set_`
+ parameter supports :class:`_schema.Column` objects from the target
+ :class:`_schema.Table` as keys.
+
+ .. warning:: This dictionary does **not** take into account
+ Python-specified default UPDATE values or generation functions,
+ e.g. those specified using :paramref:`_schema.Column.onupdate`.
+ These values will not be exercised for an ON CONFLICT style of
+ UPDATE, unless they are manually specified in the
+ :paramref:`.Insert.on_conflict_do_update.set_` dictionary.
+
+ :param where:
+ Optional argument. If present, can be a literal SQL
+ string or an acceptable expression for a ``WHERE`` clause
+ that restricts the rows affected by ``DO UPDATE SET``. Rows
+ not meeting the ``WHERE`` condition will not be updated
+ (effectively a ``DO NOTHING`` for those rows).
+
+ .. versionadded:: 1.1
+
+
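+        For example (an illustrative sketch; ``my_table`` and the connection
+        ``conn`` are hypothetical)::
+
+            from sqlalchemy.dialects.postgresql import insert
+
+            stmt = insert(my_table).values(user_email="a@example.com", data="x")
+            stmt = stmt.on_conflict_do_update(
+                index_elements=[my_table.c.user_email],
+                set_=dict(data=stmt.excluded.data),
+                where=(my_table.c.data != stmt.excluded.data),
+            )
+            conn.execute(stmt)
+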
+ .. seealso::
+
+ :ref:`postgresql_insert_on_conflict`
+
+ """
+ self._post_values_clause = OnConflictDoUpdate(
+ constraint, index_elements, index_where, set_, where
+ )
+
+ @_generative
+ @_on_conflict_exclusive
+ def on_conflict_do_nothing(
+ self, constraint=None, index_elements=None, index_where=None
+ ):
+ """
+ Specifies a DO NOTHING action for ON CONFLICT clause.
+
+ The ``constraint`` and ``index_elements`` arguments
+ are optional, but only one of these can be specified.
+
+ :param constraint:
+ The name of a unique or exclusion constraint on the table,
+ or the constraint object itself if it has a .name attribute.
+
+ :param index_elements:
+ A sequence consisting of string column names, :class:`_schema.Column`
+ objects, or other column expression objects that will be used
+ to infer a target index.
+
+ :param index_where:
+ Additional WHERE criterion that can be used to infer a
+ conditional target index.
+
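+        For example (an illustrative sketch; ``my_table`` and ``conn`` are
+        hypothetical)::
+
+            from sqlalchemy.dialects.postgresql import insert
+
+            stmt = insert(my_table).values(id=1, data="x")
+            stmt = stmt.on_conflict_do_nothing(index_elements=[my_table.c.id])
+            conn.execute(stmt)
+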
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`postgresql_insert_on_conflict`
+
+ """
+ self._post_values_clause = OnConflictDoNothing(
+ constraint, index_elements, index_where
+ )
+
+
+insert = public_factory(
+ Insert, ".dialects.postgresql.insert", ".dialects.postgresql.Insert"
+)
+
+
+class OnConflictClause(ClauseElement):
+ stringify_dialect = "postgresql"
+
+ def __init__(self, constraint=None, index_elements=None, index_where=None):
+
+ if constraint is not None:
+ if not isinstance(constraint, util.string_types) and isinstance(
+ constraint,
+ (schema.Index, schema.Constraint, ext.ExcludeConstraint),
+ ):
+ constraint = getattr(constraint, "name") or constraint
+
+ if constraint is not None:
+ if index_elements is not None:
+ raise ValueError(
+ "'constraint' and 'index_elements' are mutually exclusive"
+ )
+
+ if isinstance(constraint, util.string_types):
+ self.constraint_target = constraint
+ self.inferred_target_elements = None
+ self.inferred_target_whereclause = None
+ elif isinstance(constraint, schema.Index):
+ index_elements = constraint.expressions
+ index_where = constraint.dialect_options["postgresql"].get(
+ "where"
+ )
+ elif isinstance(constraint, ext.ExcludeConstraint):
+ index_elements = constraint.columns
+ index_where = constraint.where
+ else:
+ index_elements = constraint.columns
+ index_where = constraint.dialect_options["postgresql"].get(
+ "where"
+ )
+
+ if index_elements is not None:
+ self.constraint_target = None
+ self.inferred_target_elements = index_elements
+ self.inferred_target_whereclause = index_where
+ elif constraint is None:
+ self.constraint_target = (
+ self.inferred_target_elements
+ ) = self.inferred_target_whereclause = None
+
+
+class OnConflictDoNothing(OnConflictClause):
+ __visit_name__ = "on_conflict_do_nothing"
+
+
+class OnConflictDoUpdate(OnConflictClause):
+ __visit_name__ = "on_conflict_do_update"
+
+ def __init__(
+ self,
+ constraint=None,
+ index_elements=None,
+ index_where=None,
+ set_=None,
+ where=None,
+ ):
+ super(OnConflictDoUpdate, self).__init__(
+ constraint=constraint,
+ index_elements=index_elements,
+ index_where=index_where,
+ )
+
+ if (
+ self.inferred_target_elements is None
+ and self.constraint_target is None
+ ):
+ raise ValueError(
+ "Either constraint or index_elements, "
+ "but not both, must be specified unless DO NOTHING"
+ )
+
+ if isinstance(set_, dict):
+ if not set_:
+ raise ValueError("set parameter dictionary must not be empty")
+ elif isinstance(set_, ColumnCollection):
+ set_ = dict(set_)
+ else:
+ raise ValueError(
+ "set parameter must be a non-empty dictionary "
+ "or a ColumnCollection such as the `.c.` collection "
+ "of a Table object"
+ )
+ self.update_values_to_set = [
+ (coercions.expect(roles.DMLColumnRole, key), value)
+ for key, value in set_.items()
+ ]
+ self.update_whereclause = where
diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py
new file mode 100644
index 0000000..9e52ee1
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/ext.py
@@ -0,0 +1,277 @@
+# postgresql/ext.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .array import ARRAY
+from ... import util
+from ...sql import coercions
+from ...sql import elements
+from ...sql import expression
+from ...sql import functions
+from ...sql import roles
+from ...sql import schema
+from ...sql.schema import ColumnCollectionConstraint
+
+
+class aggregate_order_by(expression.ColumnElement):
+ """Represent a PostgreSQL aggregate order by expression.
+
+ E.g.::
+
+ from sqlalchemy.dialects.postgresql import aggregate_order_by
+ expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc()))
+ stmt = select(expr)
+
+ would represent the expression::
+
+ SELECT array_agg(a ORDER BY b DESC) FROM table;
+
+ Similarly::
+
+ expr = func.string_agg(
+ table.c.a,
+ aggregate_order_by(literal_column("','"), table.c.a)
+ )
+ stmt = select(expr)
+
+ Would represent::
+
+ SELECT string_agg(a, ',' ORDER BY a) FROM table;
+
+ .. versionadded:: 1.1
+
+ .. versionchanged:: 1.2.13 - the ORDER BY argument may be multiple terms
+
+ .. seealso::
+
+ :class:`_functions.array_agg`
+
+ """
+
+ __visit_name__ = "aggregate_order_by"
+
+ stringify_dialect = "postgresql"
+ inherit_cache = False
+
+ def __init__(self, target, *order_by):
+ self.target = coercions.expect(roles.ExpressionElementRole, target)
+ self.type = self.target.type
+
+ _lob = len(order_by)
+ if _lob == 0:
+ raise TypeError("at least one ORDER BY element is required")
+ elif _lob == 1:
+ self.order_by = coercions.expect(
+ roles.ExpressionElementRole, order_by[0]
+ )
+ else:
+ self.order_by = elements.ClauseList(
+ *order_by, _literal_as_text_role=roles.ExpressionElementRole
+ )
+
+ def self_group(self, against=None):
+ return self
+
+ def get_children(self, **kwargs):
+ return self.target, self.order_by
+
+ def _copy_internals(self, clone=elements._clone, **kw):
+ self.target = clone(self.target, **kw)
+ self.order_by = clone(self.order_by, **kw)
+
+ @property
+ def _from_objects(self):
+ return self.target._from_objects + self.order_by._from_objects
+
+
+class ExcludeConstraint(ColumnCollectionConstraint):
+ """A table-level EXCLUDE constraint.
+
+ Defines an EXCLUDE constraint as described in the `PostgreSQL
+ documentation`__.
+
+ __ https://www.postgresql.org/docs/current/static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
+
+ """ # noqa
+
+ __visit_name__ = "exclude_constraint"
+
+ where = None
+ inherit_cache = False
+
+ create_drop_stringify_dialect = "postgresql"
+
+ @elements._document_text_coercion(
+ "where",
+ ":class:`.ExcludeConstraint`",
+ ":paramref:`.ExcludeConstraint.where`",
+ )
+ def __init__(self, *elements, **kw):
+ r"""
+ Create an :class:`.ExcludeConstraint` object.
+
+ E.g.::
+
+ const = ExcludeConstraint(
+ (Column('period'), '&&'),
+ (Column('group'), '='),
+ where=(Column('group') != 'some group'),
+ ops={'group': 'my_operator_class'}
+ )
+
+ The constraint is normally embedded into the :class:`_schema.Table`
+ construct
+ directly, or added later using :meth:`.append_constraint`::
+
+ some_table = Table(
+ 'some_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('period', TSRANGE()),
+ Column('group', String)
+ )
+
+ some_table.append_constraint(
+ ExcludeConstraint(
+ (some_table.c.period, '&&'),
+ (some_table.c.group, '='),
+ where=some_table.c.group != 'some group',
+ name='some_table_excl_const',
+ ops={'group': 'my_operator_class'}
+ )
+ )
+
+ :param \*elements:
+
+ A sequence of two tuples of the form ``(column, operator)`` where
+ "column" is a SQL expression element or a raw SQL string, most
+ typically a :class:`_schema.Column` object,
+ and "operator" is a string
+ containing the operator to use. In order to specify a column name
+ when a :class:`_schema.Column` object is not available,
+ while ensuring
+ that any necessary quoting rules take effect, an ad-hoc
+ :class:`_schema.Column` or :func:`_expression.column`
+ object should be
+ used.
+
+ :param name:
+ Optional, the in-database name of this constraint.
+
+ :param deferrable:
+ Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
+ issuing DDL for this constraint.
+
+ :param initially:
+ Optional string. If set, emit INITIALLY <value> when issuing DDL
+ for this constraint.
+
+ :param using:
+ Optional string. If set, emit USING <index_method> when issuing DDL
+ for this constraint. Defaults to 'gist'.
+
+ :param where:
+ Optional SQL expression construct or literal SQL string.
+ If set, emit WHERE <predicate> when issuing DDL
+ for this constraint.
+
+ :param ops:
+ Optional dictionary. Used to define operator classes for the
+ elements; works the same way as that of the
+ :ref:`postgresql_ops <postgresql_operator_classes>`
+ parameter specified to the :class:`_schema.Index` construct.
+
+ .. versionadded:: 1.3.21
+
+ .. seealso::
+
+ :ref:`postgresql_operator_classes` - general description of how
+ PostgreSQL operator classes are specified.
+
+ """
+ columns = []
+ render_exprs = []
+ self.operators = {}
+
+ expressions, operators = zip(*elements)
+
+ for (expr, column, strname, add_element), operator in zip(
+ coercions.expect_col_expression_collection(
+ roles.DDLConstraintColumnRole, expressions
+ ),
+ operators,
+ ):
+ if add_element is not None:
+ columns.append(add_element)
+
+ name = column.name if column is not None else strname
+
+ if name is not None:
+ # backwards compat
+ self.operators[name] = operator
+
+ render_exprs.append((expr, name, operator))
+
+ self._render_exprs = render_exprs
+
+ ColumnCollectionConstraint.__init__(
+ self,
+ *columns,
+ name=kw.get("name"),
+ deferrable=kw.get("deferrable"),
+ initially=kw.get("initially")
+ )
+ self.using = kw.get("using", "gist")
+ where = kw.get("where")
+ if where is not None:
+ self.where = coercions.expect(roles.StatementOptionRole, where)
+
+ self.ops = kw.get("ops", {})
+
+ def _set_parent(self, table, **kw):
+ super(ExcludeConstraint, self)._set_parent(table)
+
+ self._render_exprs = [
+ (
+ expr if isinstance(expr, elements.ClauseElement) else colexpr,
+ name,
+ operator,
+ )
+ for (expr, name, operator), colexpr in util.zip_longest(
+ self._render_exprs, self.columns
+ )
+ ]
+
+ def _copy(self, target_table=None, **kw):
+ elements = [
+ (
+ schema._copy_expression(expr, self.parent, target_table),
+ self.operators[expr.name],
+ )
+ for expr in self.columns
+ ]
+ c = self.__class__(
+ *elements,
+ name=self.name,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ where=self.where,
+ using=self.using
+ )
+ c.dispatch._update(self.dispatch)
+ return c
+
+
+def array_agg(*arg, **kw):
+ """PostgreSQL-specific form of :class:`_functions.array_agg`, ensures
+ return type is :class:`_postgresql.ARRAY` and not
+ the plain :class:`_types.ARRAY`, unless an explicit ``type_``
+ is passed.
+
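+    For example (an illustrative sketch; ``my_table`` is hypothetical)::
+
+        from sqlalchemy import select
+        from sqlalchemy.dialects.postgresql import array_agg
+
+        # the aggregated expression is typed as postgresql.ARRAY, so PG
+        # array operators such as .contains() remain available on it
+        stmt = select(array_agg(my_table.c.name))
+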
+ .. versionadded:: 1.1
+
+ """
+ kw["_default_array_type"] = ARRAY
+ return functions.func.array_agg(*arg, **kw)
diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py
new file mode 100644
index 0000000..29800d2
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/hstore.py
@@ -0,0 +1,455 @@
+# postgresql/hstore.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import re
+
+from .array import ARRAY
+from ... import types as sqltypes
+from ... import util
+from ...sql import functions as sqlfunc
+from ...sql import operators
+
+
+__all__ = ("HSTORE", "hstore")
+
+idx_precedence = operators._PRECEDENCE[operators.json_getitem_op]
+
+GETITEM = operators.custom_op(
+ "->",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+HAS_KEY = operators.custom_op(
+ "?",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+HAS_ALL = operators.custom_op(
+ "?&",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+HAS_ANY = operators.custom_op(
+ "?|",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+CONTAINS = operators.custom_op(
+ "@>",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+CONTAINED_BY = operators.custom_op(
+ "<@",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+
+class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
+ """Represent the PostgreSQL HSTORE type.
+
+ The :class:`.HSTORE` type stores dictionaries containing strings, e.g.::
+
+ data_table = Table('data_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', HSTORE)
+ )
+
+ with engine.connect() as conn:
+ conn.execute(
+ data_table.insert(),
+ data = {"key1": "value1", "key2": "value2"}
+ )
+
+ :class:`.HSTORE` provides for a wide range of operations, including:
+
+ * Index operations::
+
+ data_table.c.data['some key'] == 'some value'
+
+ * Containment operations::
+
+ data_table.c.data.has_key('some key')
+
+ data_table.c.data.has_all(['one', 'two', 'three'])
+
+ * Concatenation::
+
+ data_table.c.data + {"k1": "v1"}
+
+ For a full list of special methods see
+ :class:`.HSTORE.comparator_factory`.
+
+ For usage with the SQLAlchemy ORM, it may be desirable to combine
+    the usage of :class:`.HSTORE` with the :class:`.MutableDict` dictionary
+    type, part of the :mod:`sqlalchemy.ext.mutable`
+ extension. This extension will allow "in-place" changes to the
+ dictionary, e.g. addition of new keys or replacement/removal of existing
+ keys to/from the current dictionary, to produce events which will be
+ detected by the unit of work::
+
+ from sqlalchemy.ext.mutable import MutableDict
+
+ class MyClass(Base):
+ __tablename__ = 'data_table'
+
+ id = Column(Integer, primary_key=True)
+ data = Column(MutableDict.as_mutable(HSTORE))
+
+ my_object = session.query(MyClass).one()
+
+ # in-place mutation, requires Mutable extension
+ # in order for the ORM to detect
+ my_object.data['some_key'] = 'some value'
+
+ session.commit()
+
+ When the :mod:`sqlalchemy.ext.mutable` extension is not used, the ORM
+ will not be alerted to any changes to the contents of an existing
+ dictionary, unless that dictionary value is re-assigned to the
+ HSTORE-attribute itself, thus generating a change event.
+
+ .. seealso::
+
+ :class:`.hstore` - render the PostgreSQL ``hstore()`` function.
+
+
+ """
+
+ __visit_name__ = "HSTORE"
+ hashable = False
+ text_type = sqltypes.Text()
+
+ def __init__(self, text_type=None):
+ """Construct a new :class:`.HSTORE`.
+
+ :param text_type: the type that should be used for indexed values.
+ Defaults to :class:`_types.Text`.
+
+ .. versionadded:: 1.1.0
+
+ """
+ if text_type is not None:
+ self.text_type = text_type
+
+ class Comparator(
+ sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator
+ ):
+ """Define comparison operations for :class:`.HSTORE`."""
+
+ def has_key(self, other):
+ """Boolean expression. Test for presence of a key. Note that the
+ key may be a SQLA expression.
+ """
+ return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
+
+ def has_all(self, other):
+            """Boolean expression. Test for presence of all keys in the hstore."""
+ return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
+
+ def has_any(self, other):
+            """Boolean expression. Test for presence of any key in the hstore."""
+ return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
+
+ def contains(self, other, **kwargs):
+            """Boolean expression. Test if the keys (or array) are a superset
+            of / contain the keys of the argument hstore expression.
+
+ kwargs may be ignored by this operator but are required for API
+ conformance.
+ """
+ return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
+
+ def contained_by(self, other):
+            """Boolean expression. Test if the keys are a proper subset of the
+            keys of the argument hstore expression.
+ """
+ return self.operate(
+ CONTAINED_BY, other, result_type=sqltypes.Boolean
+ )
+
+ def _setup_getitem(self, index):
+ return GETITEM, index, self.type.text_type
+
+ def defined(self, key):
+ """Boolean expression. Test for presence of a non-NULL value for
+ the key. Note that the key may be a SQLA expression.
+ """
+ return _HStoreDefinedFunction(self.expr, key)
+
+ def delete(self, key):
+ """HStore expression. Returns the contents of this hstore with the
+ given key deleted. Note that the key may be a SQLA expression.
+ """
+ if isinstance(key, dict):
+ key = _serialize_hstore(key)
+ return _HStoreDeleteFunction(self.expr, key)
+
+ def slice(self, array):
+ """HStore expression. Returns a subset of an hstore defined by
+ array of keys.
+ """
+ return _HStoreSliceFunction(self.expr, array)
+
+ def keys(self):
+ """Text array expression. Returns array of keys."""
+ return _HStoreKeysFunction(self.expr)
+
+ def vals(self):
+ """Text array expression. Returns array of values."""
+ return _HStoreValsFunction(self.expr)
+
+ def array(self):
+ """Text array expression. Returns array of alternating keys and
+ values.
+ """
+ return _HStoreArrayFunction(self.expr)
+
+ def matrix(self):
+ """Text array expression. Returns array of [key, value] pairs."""
+ return _HStoreMatrixFunction(self.expr)
+
+ comparator_factory = Comparator
+
+ def bind_processor(self, dialect):
+ if util.py2k:
+ encoding = dialect.encoding
+
+ def process(value):
+ if isinstance(value, dict):
+ return _serialize_hstore(value).encode(encoding)
+ else:
+ return value
+
+ else:
+
+ def process(value):
+ if isinstance(value, dict):
+ return _serialize_hstore(value)
+ else:
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if util.py2k:
+ encoding = dialect.encoding
+
+ def process(value):
+ if value is not None:
+ return _parse_hstore(value.decode(encoding))
+ else:
+ return value
+
+ else:
+
+ def process(value):
+ if value is not None:
+ return _parse_hstore(value)
+ else:
+ return value
+
+ return process
+
+
+class hstore(sqlfunc.GenericFunction):
+ """Construct an hstore value within a SQL expression using the
+ PostgreSQL ``hstore()`` function.
+
+ The :class:`.hstore` function accepts one or two arguments as described
+ in the PostgreSQL documentation.
+
+ E.g.::
+
+ from sqlalchemy.dialects.postgresql import array, hstore
+
+ select(hstore('key1', 'value1'))
+
+ select(
+ hstore(
+ array(['key1', 'key2', 'key3']),
+ array(['value1', 'value2', 'value3'])
+ )
+ )
+
+ .. seealso::
+
+ :class:`.HSTORE` - the PostgreSQL ``HSTORE`` datatype.
+
+ """
+
+ type = HSTORE
+ name = "hstore"
+ inherit_cache = True
+
+
+class _HStoreDefinedFunction(sqlfunc.GenericFunction):
+ type = sqltypes.Boolean
+ name = "defined"
+ inherit_cache = True
+
+
+class _HStoreDeleteFunction(sqlfunc.GenericFunction):
+ type = HSTORE
+ name = "delete"
+ inherit_cache = True
+
+
+class _HStoreSliceFunction(sqlfunc.GenericFunction):
+ type = HSTORE
+ name = "slice"
+ inherit_cache = True
+
+
+class _HStoreKeysFunction(sqlfunc.GenericFunction):
+ type = ARRAY(sqltypes.Text)
+ name = "akeys"
+ inherit_cache = True
+
+
+class _HStoreValsFunction(sqlfunc.GenericFunction):
+ type = ARRAY(sqltypes.Text)
+ name = "avals"
+ inherit_cache = True
+
+
+class _HStoreArrayFunction(sqlfunc.GenericFunction):
+ type = ARRAY(sqltypes.Text)
+ name = "hstore_to_array"
+ inherit_cache = True
+
+
+class _HStoreMatrixFunction(sqlfunc.GenericFunction):
+ type = ARRAY(sqltypes.Text)
+ name = "hstore_to_matrix"
+ inherit_cache = True
+
+
+#
+# parsing. note that none of this is used with the psycopg2 backend,
+# which provides its own native extensions.
+#
+
+# My best guess at the parsing rules of hstore literals, since no formal
+# grammar is given. This is mostly reverse engineered from PG's input parser
+# behavior.
+HSTORE_PAIR_RE = re.compile(
+ r"""
+(
+ "(?P<key> (\\ . | [^"])* )" # Quoted key
+)
+[ ]* => [ ]* # Pair operator, optional adjoining whitespace
+(
+ (?P<value_null> NULL ) # NULL value
+ | "(?P<value> (\\ . | [^"])* )" # Quoted value
+)
+""",
+ re.VERBOSE,
+)
+
+HSTORE_DELIMITER_RE = re.compile(
+ r"""
+[ ]* , [ ]*
+""",
+ re.VERBOSE,
+)
+
+
+def _parse_error(hstore_str, pos):
+ """format an unmarshalling error."""
+
+ ctx = 20
+ hslen = len(hstore_str)
+
+ parsed_tail = hstore_str[max(pos - ctx - 1, 0) : min(pos, hslen)]
+ residual = hstore_str[min(pos, hslen) : min(pos + ctx + 1, hslen)]
+
+ if len(parsed_tail) > ctx:
+ parsed_tail = "[...]" + parsed_tail[1:]
+ if len(residual) > ctx:
+ residual = residual[:-1] + "[...]"
+
+ return "After %r, could not parse residual at position %d: %r" % (
+ parsed_tail,
+ pos,
+ residual,
+ )
+
+
+def _parse_hstore(hstore_str):
+ """Parse an hstore from its literal string representation.
+
+ Attempts to approximate PG's hstore input parsing rules as closely as
+ possible. Although currently this is not strictly necessary, since the
+ current implementation of hstore's output syntax is stricter than what it
+ accepts as input, the documentation makes no guarantees that will always
+ be the case.
+
+ """
+ result = {}
+ pos = 0
+ pair_match = HSTORE_PAIR_RE.match(hstore_str)
+
+ while pair_match is not None:
+ key = pair_match.group("key").replace(r"\"", '"').replace("\\\\", "\\")
+ if pair_match.group("value_null"):
+ value = None
+ else:
+ value = (
+ pair_match.group("value")
+ .replace(r"\"", '"')
+ .replace("\\\\", "\\")
+ )
+ result[key] = value
+
+ pos += pair_match.end()
+
+ delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:])
+ if delim_match is not None:
+ pos += delim_match.end()
+
+ pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:])
+
+ if pos != len(hstore_str):
+ raise ValueError(_parse_error(hstore_str, pos))
+
+ return result
+
+
+def _serialize_hstore(val):
+ """Serialize a dictionary into an hstore literal. Keys and values must
+ both be strings (except None for values).
+
+ """
+
+ def esc(s, position):
+ if position == "value" and s is None:
+ return "NULL"
+ elif isinstance(s, util.string_types):
+ return '"%s"' % s.replace("\\", "\\\\").replace('"', r"\"")
+ else:
+ raise ValueError(
+ "%r in %s position is not a string." % (s, position)
+ )
+
+ return ", ".join(
+ "%s=>%s" % (esc(k, "key"), esc(v, "value")) for k, v in val.items()
+ )
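+
+
+# Illustrative sketch (not part of the original module): the two helpers
+# above round-trip plain Python dictionaries, e.g.
+#
+#   _serialize_hstore({"k1": "v1", "k2": None})
+#   # -> '"k1"=>"v1", "k2"=>NULL'
+#
+#   _parse_hstore('"k1"=>"v1", "k2"=>NULL')
+#   # -> {"k1": "v1", "k2": None}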
diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py
new file mode 100644
index 0000000..daaaeac
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/json.py
@@ -0,0 +1,327 @@
+# postgresql/json.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import absolute_import
+
+from ... import types as sqltypes
+from ... import util
+from ...sql import operators
+
+
+__all__ = ("JSON", "JSONB")
+
+idx_precedence = operators._PRECEDENCE[operators.json_getitem_op]
+
+ASTEXT = operators.custom_op(
+ "->>",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+JSONPATH_ASTEXT = operators.custom_op(
+ "#>>",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+
+HAS_KEY = operators.custom_op(
+ "?",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+HAS_ALL = operators.custom_op(
+ "?&",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+HAS_ANY = operators.custom_op(
+ "?|",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+CONTAINS = operators.custom_op(
+ "@>",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+CONTAINED_BY = operators.custom_op(
+ "<@",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+
+class JSONPathType(sqltypes.JSON.JSONPathType):
+ def bind_processor(self, dialect):
+ super_proc = self.string_bind_processor(dialect)
+
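+        # render a Python sequence such as ("key_1", "key_2", 5) as the
+        # PG path literal "{key_1, key_2, 5}" used by the #> / #>> operators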
+ def process(value):
+ assert isinstance(value, util.collections_abc.Sequence)
+ tokens = [util.text_type(elem) for elem in value]
+ value = "{%s}" % (", ".join(tokens))
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+ def literal_processor(self, dialect):
+ super_proc = self.string_literal_processor(dialect)
+
+ def process(value):
+ assert isinstance(value, util.collections_abc.Sequence)
+ tokens = [util.text_type(elem) for elem in value]
+ value = "{%s}" % (", ".join(tokens))
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+
+class JSON(sqltypes.JSON):
+ """Represent the PostgreSQL JSON type.
+
+ :class:`_postgresql.JSON` is used automatically whenever the base
+ :class:`_types.JSON` datatype is used against a PostgreSQL backend,
+ however base :class:`_types.JSON` datatype does not provide Python
+ accessors for PostgreSQL-specific comparison methods such as
+ :meth:`_postgresql.JSON.Comparator.astext`; additionally, to use
+ PostgreSQL ``JSONB``, the :class:`_postgresql.JSONB` datatype should
+ be used explicitly.
+
+ .. seealso::
+
+ :class:`_types.JSON` - main documentation for the generic
+ cross-platform JSON datatype.
+
+ The operators provided by the PostgreSQL version of :class:`_types.JSON`
+ include:
+
+ * Index operations (the ``->`` operator)::
+
+ data_table.c.data['some key']
+
+ data_table.c.data[5]
+
+
+ * Index operations returning text (the ``->>`` operator)::
+
+ data_table.c.data['some key'].astext == 'some value'
+
+ Note that equivalent functionality is available via the
+ :attr:`.JSON.Comparator.as_string` accessor.
+
+ * Index operations with CAST
+ (equivalent to ``CAST(col ->> ['some key'] AS <type>)``)::
+
+ data_table.c.data['some key'].astext.cast(Integer) == 5
+
+ Note that equivalent functionality is available via the
+ :attr:`.JSON.Comparator.as_integer` and similar accessors.
+
+ * Path index operations (the ``#>`` operator)::
+
+ data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')]
+
+ * Path index operations returning text (the ``#>>`` operator)::
+
+ data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == 'some value'
+
+ .. versionchanged:: 1.1 The :meth:`_expression.ColumnElement.cast`
+ operator on
+ JSON objects now requires that the :attr:`.JSON.Comparator.astext`
+ modifier be called explicitly, if the cast works only from a textual
+ string.
+
+ Index operations return an expression object whose type defaults to
+ :class:`_types.JSON` by default,
+ so that further JSON-oriented instructions
+ may be called upon the result type.
+
+ Custom serializers and deserializers are specified at the dialect level,
+ that is using :func:`_sa.create_engine`. The reason for this is that when
+ using psycopg2, the DBAPI only allows serializers at the per-cursor
+ or per-connection level. E.g.::
+
+ engine = create_engine("postgresql://scott:tiger@localhost/test",
+ json_serializer=my_serialize_fn,
+ json_deserializer=my_deserialize_fn
+ )
+
+ When using the psycopg2 dialect, the json_deserializer is registered
+ against the database using ``psycopg2.extras.register_default_json``.
+
+ .. seealso::
+
+ :class:`_types.JSON` - Core level JSON type
+
+ :class:`_postgresql.JSONB`
+
+ .. versionchanged:: 1.1 :class:`_postgresql.JSON` is now a PostgreSQL-
+ specific specialization of the new :class:`_types.JSON` type.
+
+ """ # noqa
+
+ astext_type = sqltypes.Text()
+
+ def __init__(self, none_as_null=False, astext_type=None):
+ """Construct a :class:`_types.JSON` type.
+
+ :param none_as_null: if True, persist the value ``None`` as a
+ SQL NULL value, not the JSON encoding of ``null``. Note that
+ when this flag is False, the :func:`.null` construct can still
+ be used to persist a NULL value::
+
+ from sqlalchemy import null
+ conn.execute(table.insert(), data=null())
+
+ .. versionchanged:: 0.9.8 - Added ``none_as_null``, and :func:`.null`
+ is now supported in order to persist a NULL value.
+
+ .. seealso::
+
+ :attr:`_types.JSON.NULL`
+
+ :param astext_type: the type to use for the
+ :attr:`.JSON.Comparator.astext`
+ accessor on indexed attributes. Defaults to :class:`_types.Text`.
+
+ .. versionadded:: 1.1
+
+ """
+ super(JSON, self).__init__(none_as_null=none_as_null)
+ if astext_type is not None:
+ self.astext_type = astext_type
+
+ class Comparator(sqltypes.JSON.Comparator):
+ """Define comparison operations for :class:`_types.JSON`."""
+
+ @property
+ def astext(self):
+ """On an indexed expression, use the "astext" (e.g. "->>")
+ conversion when rendered in SQL.
+
+ E.g.::
+
+ select(data_table.c.data['some key'].astext)
+
+ .. seealso::
+
+ :meth:`_expression.ColumnElement.cast`
+
+ """
+ if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType):
+ return self.expr.left.operate(
+ JSONPATH_ASTEXT,
+ self.expr.right,
+ result_type=self.type.astext_type,
+ )
+ else:
+ return self.expr.left.operate(
+ ASTEXT, self.expr.right, result_type=self.type.astext_type
+ )
+
+ comparator_factory = Comparator
+
+
+class JSONB(JSON):
+ """Represent the PostgreSQL JSONB type.
+
+ The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data,
+ e.g.::
+
+ data_table = Table('data_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', JSONB)
+ )
+
+ with engine.connect() as conn:
+ conn.execute(
+ data_table.insert(),
+ data = {"key1": "value1", "key2": "value2"}
+ )
+
+ The :class:`_postgresql.JSONB` type includes all operations provided by
+ :class:`_types.JSON`, including the same behaviors for indexing
+ operations.
+ It also adds additional operators specific to JSONB, including
+ :meth:`.JSONB.Comparator.has_key`, :meth:`.JSONB.Comparator.has_all`,
+ :meth:`.JSONB.Comparator.has_any`, :meth:`.JSONB.Comparator.contains`,
+ and :meth:`.JSONB.Comparator.contained_by`.
+
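+    For example (an illustrative sketch, reusing the ``data_table`` from
+    above)::
+
+        from sqlalchemy import select
+
+        stmt = select(data_table).where(
+            data_table.c.data.contains({"key1": "value1"})
+        )
+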
+ Like the :class:`_types.JSON` type, the :class:`_postgresql.JSONB`
+ type does not detect
+ in-place changes when used with the ORM, unless the
+ :mod:`sqlalchemy.ext.mutable` extension is used.
+
+ Custom serializers and deserializers
+ are shared with the :class:`_types.JSON` class,
+ using the ``json_serializer``
+ and ``json_deserializer`` keyword arguments. These must be specified
+ at the dialect level using :func:`_sa.create_engine`. When using
+ psycopg2, the serializers are associated with the jsonb type using
+ ``psycopg2.extras.register_default_jsonb`` on a per-connection basis,
+ in the same way that ``psycopg2.extras.register_default_json`` is used
+ to register these handlers with the json type.
+
+ .. versionadded:: 0.9.7
+
+ .. seealso::
+
+ :class:`_types.JSON`
+
+ """
+
+ __visit_name__ = "JSONB"
+
+ class Comparator(JSON.Comparator):
+ """Define comparison operations for :class:`_types.JSON`."""
+
+ def has_key(self, other):
+ """Boolean expression. Test for presence of a key. Note that the
+ key may be a SQLA expression.
+ """
+ return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
+
+ def has_all(self, other):
+ """Boolean expression. Test for presence of all keys in jsonb"""
+ return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
+
+ def has_any(self, other):
+ """Boolean expression. Test for presence of any key in jsonb"""
+ return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
+
+ def contains(self, other, **kwargs):
+            """Boolean expression. Test if the keys (or array) are a superset
+            of / contain the keys of the argument jsonb expression.
+
+ kwargs may be ignored by this operator but are required for API
+ conformance.
+ """
+ return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
+
+ def contained_by(self, other):
+ """Boolean expression. Test if keys are a proper subset of the
+ keys of the argument jsonb expression.
+ """
+ return self.operate(
+ CONTAINED_BY, other, result_type=sqltypes.Boolean
+ )
+
+ comparator_factory = Comparator
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
new file mode 100644
index 0000000..98561a9
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py
@@ -0,0 +1,594 @@
+# postgresql/pg8000.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS
+# file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: postgresql+pg8000
+ :name: pg8000
+ :dbapi: pg8000
+ :connectstring: postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://pypi.org/project/pg8000/
+
+.. versionchanged:: 1.4 The pg8000 dialect has been updated for version
+ 1.16.6 and higher, and is again part of SQLAlchemy's continuous integration
+ with full feature support.
+
+.. _pg8000_unicode:
+
+Unicode
+-------
+
+pg8000 will encode / decode string values between it and the server using the
+PostgreSQL ``client_encoding`` parameter; by default this is the value in
+the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
+Typically, this can be changed to ``utf-8``, as a more useful default::
+
+ #client_encoding = sql_ascii # actually, defaults to database
+ # encoding
+ client_encoding = utf8
+
+The ``client_encoding`` can be overridden for a session by executing the
+SQL::
+
+    SET CLIENT_ENCODING TO 'utf8';
+
+SQLAlchemy will execute this SQL on all new connections based on the value
+passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter::
+
+ engine = create_engine(
+ "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8')
+
+.. _pg8000_ssl:
+
+SSL Connections
+---------------
+
+pg8000 accepts a Python ``SSLContext`` object which may be specified using the
+:paramref:`_sa.create_engine.connect_args` dictionary::
+
+ import ssl
+ ssl_context = ssl.create_default_context()
+ engine = sa.create_engine(
+ "postgresql+pg8000://scott:tiger@192.168.0.199/test",
+ connect_args={"ssl_context": ssl_context},
+ )
+
+If the server uses an automatically-generated certificate that is self-signed
+or does not match the host name (as seen from the client), it may also be
+necessary to disable hostname checking::
+
+ import ssl
+ ssl_context = ssl.create_default_context()
+ ssl_context.check_hostname = False
+ ssl_context.verify_mode = ssl.CERT_NONE
+ engine = sa.create_engine(
+ "postgresql+pg8000://scott:tiger@192.168.0.199/test",
+ connect_args={"ssl_context": ssl_context},
+ )
+
+.. _pg8000_isolation_level:
+
+pg8000 Transaction Isolation Level
+-------------------------------------
+
+The pg8000 dialect offers the same isolation level settings as that
+of the :ref:`psycopg2 <psycopg2_isolation_level>` dialect:
+
+* ``READ COMMITTED``
+* ``READ UNCOMMITTED``
+* ``REPEATABLE READ``
+* ``SERIALIZABLE``
+* ``AUTOCOMMIT``
+
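+Any of these may be selected per-engine, for example (an illustrative
+sketch; the connection URL is hypothetical)::
+
+    engine = create_engine(
+        "postgresql+pg8000://scott:tiger@localhost/test",
+        isolation_level="SERIALIZABLE",
+    )
+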
+.. seealso::
+
+ :ref:`postgresql_isolation_level`
+
+ :ref:`psycopg2_isolation_level`
+
+
+""" # noqa
+import decimal
+import re
+from uuid import UUID as _python_UUID
+
+from .array import ARRAY as PGARRAY
+from .base import _ColonCast
+from .base import _DECIMAL_TYPES
+from .base import _FLOAT_TYPES
+from .base import _INT_TYPES
+from .base import ENUM
+from .base import INTERVAL
+from .base import PGCompiler
+from .base import PGDialect
+from .base import PGExecutionContext
+from .base import PGIdentifierPreparer
+from .base import UUID
+from .json import JSON
+from .json import JSONB
+from .json import JSONPathType
+from ... import exc
+from ... import processors
+from ... import types as sqltypes
+from ... import util
+from ...sql.elements import quoted_name
+
+
+class _PGNumeric(sqltypes.Numeric):
+ def result_processor(self, dialect, coltype):
+ if self.asdecimal:
+ if coltype in _FLOAT_TYPES:
+ return processors.to_decimal_processor_factory(
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ # pg8000 returns Decimal natively for 1700
+ return None
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+ else:
+ if coltype in _FLOAT_TYPES:
+ # pg8000 returns float natively for 701
+ return None
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ return processors.to_float
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+
+
+class _PGNumericNoBind(_PGNumeric):
+ def bind_processor(self, dialect):
+ return None
+
+
+class _PGJSON(JSON):
+ def result_processor(self, dialect, coltype):
+ return None
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.JSON
+
+
+class _PGJSONB(JSONB):
+ def result_processor(self, dialect, coltype):
+ return None
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.JSONB
+
+
+class _PGJSONIndexType(sqltypes.JSON.JSONIndexType):
+ def get_dbapi_type(self, dbapi):
+ raise NotImplementedError("should not be here")
+
+
+class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTEGER
+
+
+class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.STRING
+
+
+class _PGJSONPathType(JSONPathType):
+ def get_dbapi_type(self, dbapi):
+ return 1009
+
+
+class _PGUUID(UUID):
+ def bind_processor(self, dialect):
+ if not self.as_uuid:
+
+ def process(value):
+ if value is not None:
+ value = _python_UUID(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if not self.as_uuid:
+
+ def process(value):
+ if value is not None:
+ value = str(value)
+ return value
+
+ return process
+
+
+class _PGEnum(ENUM):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.UNKNOWN
+
+
+class _PGInterval(INTERVAL):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTERVAL
+
+ @classmethod
+ def adapt_emulated_to_native(cls, interval, **kw):
+ return _PGInterval(precision=interval.second_precision)
+
+
+class _PGTimeStamp(sqltypes.DateTime):
+ def get_dbapi_type(self, dbapi):
+ if self.timezone:
+ # TIMESTAMPTZOID
+ return 1184
+ else:
+ # TIMESTAMPOID
+ return 1114
+
+
+class _PGTime(sqltypes.Time):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.TIME
+
+
+class _PGInteger(sqltypes.Integer):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTEGER
+
+
+class _PGSmallInteger(sqltypes.SmallInteger):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTEGER
+
+
+class _PGNullType(sqltypes.NullType):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NULLTYPE
+
+
+class _PGBigInteger(sqltypes.BigInteger):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.BIGINTEGER
+
+
+class _PGBoolean(sqltypes.Boolean):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.BOOLEAN
+
+
+class _PGARRAY(PGARRAY):
+ def bind_expression(self, bindvalue):
+ return _ColonCast(bindvalue, self)
+
+
+_server_side_id = util.counter()
+
+
+class PGExecutionContext_pg8000(PGExecutionContext):
+ def create_server_side_cursor(self):
+ ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
+ return ServerSideCursor(self._dbapi_connection.cursor(), ident)
+
+ def pre_exec(self):
+ if not self.compiled:
+ return
+
+
+class ServerSideCursor:
+ server_side = True
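+    # Emulates a DBAPI cursor on top of a named server-side cursor: execute()
+    # wraps the statement in DECLARE ... NO SCROLL CURSOR FOR, and the fetch
+    # methods issue FETCH FORWARD against the generated cursor name.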
+
+ def __init__(self, cursor, ident):
+ self.ident = ident
+ self.cursor = cursor
+
+ @property
+ def connection(self):
+ return self.cursor.connection
+
+ @property
+ def rowcount(self):
+ return self.cursor.rowcount
+
+ @property
+ def description(self):
+ return self.cursor.description
+
+ def execute(self, operation, args=(), stream=None):
+ op = "DECLARE " + self.ident + " NO SCROLL CURSOR FOR " + operation
+ self.cursor.execute(op, args, stream=stream)
+ return self
+
+ def executemany(self, operation, param_sets):
+ self.cursor.executemany(operation, param_sets)
+ return self
+
+ def fetchone(self):
+ self.cursor.execute("FETCH FORWARD 1 FROM " + self.ident)
+ return self.cursor.fetchone()
+
+ def fetchmany(self, num=None):
+ if num is None:
+ return self.fetchall()
+ else:
+ self.cursor.execute(
+ "FETCH FORWARD " + str(int(num)) + " FROM " + self.ident
+ )
+ return self.cursor.fetchall()
+
+ def fetchall(self):
+ self.cursor.execute("FETCH FORWARD ALL FROM " + self.ident)
+ return self.cursor.fetchall()
+
+ def close(self):
+ self.cursor.execute("CLOSE " + self.ident)
+ self.cursor.close()
+
+ def setinputsizes(self, *sizes):
+ self.cursor.setinputsizes(*sizes)
+
+ def setoutputsize(self, size, column=None):
+ pass
+
+
+class PGCompiler_pg8000(PGCompiler):
+ def visit_mod_binary(self, binary, operator, **kw):
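+        # pg8000 uses the "format" paramstyle, so a literal modulo operator
+        # has to be doubled to %% to avoid being read as a parameter marker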
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
+
+
+class PGIdentifierPreparer_pg8000(PGIdentifierPreparer):
+ def __init__(self, *args, **kwargs):
+ PGIdentifierPreparer.__init__(self, *args, **kwargs)
+ self._double_percents = False
+
+
+class PGDialect_pg8000(PGDialect):
+ driver = "pg8000"
+ supports_statement_cache = True
+
+ supports_unicode_statements = True
+
+ supports_unicode_binds = True
+
+ default_paramstyle = "format"
+ supports_sane_multi_rowcount = True
+ execution_ctx_cls = PGExecutionContext_pg8000
+ statement_compiler = PGCompiler_pg8000
+ preparer = PGIdentifierPreparer_pg8000
+ supports_server_side_cursors = True
+
+ use_setinputsizes = True
+
+ # reversed as of pg8000 1.16.6. 1.16.5 and lower
+ # are no longer compatible
+ description_encoding = None
+ # description_encoding = "use_encoding"
+
+ colspecs = util.update_copy(
+ PGDialect.colspecs,
+ {
+ sqltypes.Numeric: _PGNumericNoBind,
+ sqltypes.Float: _PGNumeric,
+ sqltypes.JSON: _PGJSON,
+ sqltypes.Boolean: _PGBoolean,
+ sqltypes.NullType: _PGNullType,
+ JSONB: _PGJSONB,
+ sqltypes.JSON.JSONPathType: _PGJSONPathType,
+ sqltypes.JSON.JSONIndexType: _PGJSONIndexType,
+ sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType,
+ sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType,
+ UUID: _PGUUID,
+ sqltypes.Interval: _PGInterval,
+ INTERVAL: _PGInterval,
+ sqltypes.DateTime: _PGTimeStamp,
+ sqltypes.Time: _PGTime,
+ sqltypes.Integer: _PGInteger,
+ sqltypes.SmallInteger: _PGSmallInteger,
+ sqltypes.BigInteger: _PGBigInteger,
+ sqltypes.Enum: _PGEnum,
+ sqltypes.ARRAY: _PGARRAY,
+ },
+ )
+
+ def __init__(self, client_encoding=None, **kwargs):
+ PGDialect.__init__(self, **kwargs)
+ self.client_encoding = client_encoding
+
+ if self._dbapi_version < (1, 16, 6):
+ raise NotImplementedError("pg8000 1.16.6 or greater is required")
+
+ @util.memoized_property
+ def _dbapi_version(self):
+ if self.dbapi and hasattr(self.dbapi, "__version__"):
+ return tuple(
+ [
+ int(x)
+ for x in re.findall(
+ r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
+ )
+ ]
+ )
+ else:
+ return (99, 99, 99)
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("pg8000")
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+ if "port" in opts:
+ opts["port"] = int(opts["port"])
+ opts.update(url.query)
+ return ([], opts)
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.InterfaceError) and "network error" in str(
+ e
+ ):
+ # new as of pg8000 1.19.0 for broken connections
+ return True
+
+ # connection was closed normally
+ return "connection is closed" in str(e)
+
+ def set_isolation_level(self, connection, level):
+ level = level.replace("_", " ")
+
+ # adjust for ConnectionFairy possibly being present
+ if hasattr(connection, "dbapi_connection"):
+ connection = connection.dbapi_connection
+
+ if level == "AUTOCOMMIT":
+ connection.autocommit = True
+ elif level in self._isolation_lookup:
+ connection.autocommit = False
+ cursor = connection.cursor()
+ cursor.execute(
+ "SET SESSION CHARACTERISTICS AS TRANSACTION "
+ "ISOLATION LEVEL %s" % level
+ )
+ cursor.execute("COMMIT")
+ cursor.close()
+ else:
+ raise exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s or AUTOCOMMIT"
+ % (level, self.name, ", ".join(self._isolation_lookup))
+ )
+
+ def set_readonly(self, connection, value):
+ cursor = connection.cursor()
+ try:
+ cursor.execute(
+ "SET SESSION CHARACTERISTICS AS TRANSACTION %s"
+ % ("READ ONLY" if value else "READ WRITE")
+ )
+ cursor.execute("COMMIT")
+ finally:
+ cursor.close()
+
+ def get_readonly(self, connection):
+ cursor = connection.cursor()
+ try:
+ cursor.execute("show transaction_read_only")
+ val = cursor.fetchone()[0]
+ finally:
+ cursor.close()
+
+ return val == "on"
+
+ def set_deferrable(self, connection, value):
+ cursor = connection.cursor()
+ try:
+ cursor.execute(
+ "SET SESSION CHARACTERISTICS AS TRANSACTION %s"
+ % ("DEFERRABLE" if value else "NOT DEFERRABLE")
+ )
+ cursor.execute("COMMIT")
+ finally:
+ cursor.close()
+
+ def get_deferrable(self, connection):
+ cursor = connection.cursor()
+ try:
+ cursor.execute("show transaction_deferrable")
+ val = cursor.fetchone()[0]
+ finally:
+ cursor.close()
+
+ return val == "on"
+
+ def set_client_encoding(self, connection, client_encoding):
+ # adjust for ConnectionFairy possibly being present
+ if hasattr(connection, "dbapi_connection"):
+ connection = connection.dbapi_connection
+
+ cursor = connection.cursor()
+ cursor.execute("SET CLIENT_ENCODING TO '" + client_encoding + "'")
+ cursor.execute("COMMIT")
+ cursor.close()
+
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ if self.positional:
+ cursor.setinputsizes(
+ *[dbtype for key, dbtype, sqltype in list_of_tuples]
+ )
+ else:
+ cursor.setinputsizes(
+ **{
+ key: dbtype
+ for key, dbtype, sqltype in list_of_tuples
+ if dbtype
+ }
+ )
+
+ def do_begin_twophase(self, connection, xid):
+ connection.connection.tpc_begin((0, xid, ""))
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.connection.tpc_prepare()
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ connection.connection.tpc_rollback((0, xid, ""))
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ connection.connection.tpc_commit((0, xid, ""))
+
+ def do_recover_twophase(self, connection):
+ return [row[1] for row in connection.connection.tpc_recover()]
+
+ def on_connect(self):
+ fns = []
+
+ def on_connect(conn):
+ conn.py_types[quoted_name] = conn.py_types[util.text_type]
+
+ fns.append(on_connect)
+
+ if self.client_encoding is not None:
+
+ def on_connect(conn):
+ self.set_client_encoding(conn, self.client_encoding)
+
+ fns.append(on_connect)
+
+ if self.isolation_level is not None:
+
+ def on_connect(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ fns.append(on_connect)
+
+ if self._json_deserializer:
+
+ def on_connect(conn):
+ # json
+ conn.register_in_adapter(114, self._json_deserializer)
+
+ # jsonb
+ conn.register_in_adapter(3802, self._json_deserializer)
+
+ fns.append(on_connect)
+
+ if len(fns) > 0:
+
+ def on_connect(conn):
+ for fn in fns:
+ fn(conn)
+
+ return on_connect
+ else:
+ return None
+
+
+dialect = PGDialect_pg8000
diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py
new file mode 100644
index 0000000..98470f3
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/provision.py
@@ -0,0 +1,124 @@
+import time
+
+from ... import exc
+from ... import inspect
+from ... import text
+from ...testing import warn_test_suite
+from ...testing.provision import create_db
+from ...testing.provision import drop_all_schema_objects_post_tables
+from ...testing.provision import drop_all_schema_objects_pre_tables
+from ...testing.provision import drop_db
+from ...testing.provision import log
+from ...testing.provision import prepare_for_drop_tables
+from ...testing.provision import set_default_schema_on_connection
+from ...testing.provision import temp_table_keyword_args
+
+
+@create_db.for_db("postgresql")
+def _pg_create_db(cfg, eng, ident):
+ template_db = cfg.options.postgresql_templatedb
+
+ with eng.execution_options(isolation_level="AUTOCOMMIT").begin() as conn:
+
+ if not template_db:
+ template_db = conn.exec_driver_sql(
+ "select current_database()"
+ ).scalar()
+
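+        # CREATE DATABASE can fail transiently while the template database is
+        # still "accessed by other users"; retry a few times with a short
+        # sleep before giving up.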
+ attempt = 0
+ while True:
+ try:
+ conn.exec_driver_sql(
+ "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db)
+ )
+ except exc.OperationalError as err:
+ attempt += 1
+ if attempt >= 3:
+ raise
+ if "accessed by other users" in str(err):
+ log.info(
+ "Waiting to create %s, URI %r, "
+ "template DB %s is in use sleeping for .5",
+ ident,
+ eng.url,
+ template_db,
+ )
+ time.sleep(0.5)
+ except:
+ raise
+ else:
+ break
+
+
+@drop_db.for_db("postgresql")
+def _pg_drop_db(cfg, eng, ident):
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+ with conn.begin():
+ conn.execute(
+ text(
+ "select pg_terminate_backend(pid) from pg_stat_activity "
+ "where usename=current_user and pid != pg_backend_pid() "
+ "and datname=:dname"
+ ),
+ dict(dname=ident),
+ )
+ conn.exec_driver_sql("DROP DATABASE %s" % ident)
+
+
+@temp_table_keyword_args.for_db("postgresql")
+def _postgresql_temp_table_keyword_args(cfg, eng):
+ return {"prefixes": ["TEMPORARY"]}
+
+
+@set_default_schema_on_connection.for_db("postgresql")
+def _postgresql_set_default_schema_on_connection(
+ cfg, dbapi_connection, schema_name
+):
+ existing_autocommit = dbapi_connection.autocommit
+ dbapi_connection.autocommit = True
+ cursor = dbapi_connection.cursor()
+ cursor.execute("SET SESSION search_path='%s'" % schema_name)
+ cursor.close()
+ dbapi_connection.autocommit = existing_autocommit
+
+
+@drop_all_schema_objects_pre_tables.for_db("postgresql")
+def drop_all_schema_objects_pre_tables(cfg, eng):
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+ for xid in conn.execute("select gid from pg_prepared_xacts").scalars():
+ conn.execute("ROLLBACK PREPARED '%s'" % xid)
+
+
+@drop_all_schema_objects_post_tables.for_db("postgresql")
+def drop_all_schema_objects_post_tables(cfg, eng):
+ from sqlalchemy.dialects import postgresql
+
+ inspector = inspect(eng)
+ with eng.begin() as conn:
+ for enum in inspector.get_enums("*"):
+ conn.execute(
+ postgresql.DropEnumType(
+ postgresql.ENUM(name=enum["name"], schema=enum["schema"])
+ )
+ )
+
+
+@prepare_for_drop_tables.for_db("postgresql")
+def prepare_for_drop_tables(config, connection):
+ """Ensure there are no locks on the current username/database."""
+
+ result = connection.exec_driver_sql(
+ "select pid, state, wait_event_type, query "
+ # "select pg_terminate_backend(pid), state, wait_event_type "
+ "from pg_stat_activity where "
+ "usename=current_user "
+ "and datname=current_database() and state='idle in transaction' "
+ "and pid != pg_backend_pid()"
+ )
+ rows = result.all() # noqa
+ if rows:
+ warn_test_suite(
+ "PostgreSQL may not be able to DROP tables due to "
+ "idle in transaction: %s"
+ % ("; ".join(row._mapping["query"] for row in rows))
+ )
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
new file mode 100644
index 0000000..6747427
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -0,0 +1,1088 @@
+# postgresql/psycopg2.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: postgresql+psycopg2
+ :name: psycopg2
+ :dbapi: psycopg2
+ :connectstring: postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://pypi.org/project/psycopg2/
+
+psycopg2 Connect Arguments
+--------------------------
+
+Keyword arguments that are specific to the SQLAlchemy psycopg2 dialect
+may be passed to :func:`_sa.create_engine()`, and include the following:
+
+
+* ``isolation_level``: This option, available for all PostgreSQL dialects,
+ includes the ``AUTOCOMMIT`` isolation level when using the psycopg2
+ dialect. This option sets the **default** isolation level for the
+ connection that is set immediately upon connection to the database before
+ the connection is pooled. This option is generally superseded by the more
+ modern :paramref:`_engine.Connection.execution_options.isolation_level`
+ execution option, detailed at :ref:`dbapi_autocommit`.
+
+ .. seealso::
+
+ :ref:`psycopg2_isolation_level`
+
+ :ref:`dbapi_autocommit`
+
+
+* ``client_encoding``: sets the client encoding in a libpq-agnostic way,
+ using psycopg2's ``set_client_encoding()`` method.
+
+ .. seealso::
+
+ :ref:`psycopg2_unicode`
+
+* ``use_native_unicode``: Under Python 2 only, this can be set to False to
+ disable the use of psycopg2's native Unicode support.
+
+ .. seealso::
+
+ :ref:`psycopg2_disable_native_unicode`
+
+
+* ``executemany_mode``, ``executemany_batch_page_size``,
+ ``executemany_values_page_size``: Allows use of psycopg2
+ extensions for optimizing "executemany"-style queries. See the referenced
+ section below for details.
+
+ .. seealso::
+
+ :ref:`psycopg2_executemany_mode`
+
+.. tip::
+
+ The above keyword arguments are **dialect** keyword arguments, meaning
+ that they are passed as explicit keyword arguments to :func:`_sa.create_engine()`::
+
+ engine = create_engine(
+ "postgresql+psycopg2://scott:tiger@localhost/test",
+ isolation_level="SERIALIZABLE",
+ )
+
+ These should not be confused with **DBAPI** connect arguments, which
+ are passed as part of the :paramref:`_sa.create_engine.connect_args`
+ dictionary and/or are passed in the URL query string, as detailed in
+ the section :ref:`custom_dbapi_args`.
+
+.. _psycopg2_ssl:
+
+SSL Connections
+---------------
+
+The psycopg2 module has a connection argument named ``sslmode`` for
+controlling its behavior regarding secure (SSL) connections. The default is
+``sslmode=prefer``; it will attempt an SSL connection and if that fails it
+will fall back to an unencrypted connection. ``sslmode=require`` may be used
+to ensure that only secure connections are established. Consult the
+psycopg2 / libpq documentation for further options that are available.
+
+Note that ``sslmode`` is specific to psycopg2 so it is included in the
+connection URI::
+
+ engine = sa.create_engine(
+ "postgresql+psycopg2://scott:tiger@192.168.0.199:5432/test?sslmode=require"
+ )
+
+Unix Domain Connections
+------------------------
+
+psycopg2 supports connecting via Unix domain connections. When the ``host``
+portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2,
+which specifies Unix-domain communication rather than TCP/IP communication::
+
+ create_engine("postgresql+psycopg2://user:password@/dbname")
+
+By default, the connection is made to a Unix-domain socket
+in ``/tmp``, or whatever socket directory was specified when PostgreSQL
+was built. This value can be overridden by passing a pathname to psycopg2,
+using ``host`` as an additional keyword argument::
+
+ create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql")
+
+.. seealso::
+
+ `PQconnectdbParams \
+ <https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-PQCONNECTDBPARAMS>`_
+
+.. _psycopg2_multi_host:
+
+Specifying multiple fallback hosts
+-----------------------------------
+
+psycopg2 supports multiple connection points in the connection string.
+When the ``host`` parameter is used multiple times in the query section of
+the URL, SQLAlchemy will create a single string of the host and port
+information provided to make the connections. Tokens may consist of
+``host:port`` or just ``host``; in the latter case, the default port
+is selected by libpq. In the example below, three host connections
+are specified, for ``HostA:PortA``, ``HostB`` connecting to the default port,
+and ``HostC:PortC``::
+
+ create_engine(
+ "postgresql+psycopg2://user:password@/dbname?host=HostA:PortA&host=HostB&host=HostC:PortC"
+ )
+
+As an alternative, libpq query string format also may be used; this specifies
+``host`` and ``port`` as single query string arguments with comma-separated
+lists - the default port can be chosen by indicating an empty value
+in the comma separated list::
+
+ create_engine(
+ "postgresql+psycopg2://user:password@/dbname?host=HostA,HostB,HostC&port=PortA,,PortC"
+ )
+
+With either URL style, connections to each host are attempted based on a
+configurable strategy, which may be configured using the libpq
+``target_session_attrs`` parameter. Per libpq this defaults to ``any``, which
+indicates that a connection to each host is attempted until a connection is
+successful.
+Other strategies include ``primary``, ``prefer-standby``, etc. The complete
+list is documented by PostgreSQL at
+`libpq connection strings <https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING>`_.
+
+For example, to indicate three hosts using the ``primary`` strategy::
+
+ create_engine(
+ "postgresql+psycopg2://user:password@/dbname?host=HostA:PortA&host=HostB&host=HostC:PortC&target_session_attrs=primary"
+ )
+
+.. versionchanged:: 1.4.40 Port specification in psycopg2 multiple host format
+ is repaired, previously ports were not correctly interpreted in this context.
+ libpq comma-separated format is also now supported.
+
+.. versionadded:: 1.3.20 Support for multiple hosts in PostgreSQL connection
+ string.
+
+.. seealso::
+
+ `libpq connection strings <https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING>`_ - please refer
+ to this section in the libpq documentation for complete background on multiple host support.
+
+
+Empty DSN Connections / Environment Variable Connections
+---------------------------------------------------------
+
+The psycopg2 DBAPI can connect to PostgreSQL by passing an empty DSN to the
+libpq client library, which by default indicates to connect to a localhost
+PostgreSQL database that is open for "trust" connections. This behavior can be
+further tailored using a particular set of environment variables
+prefixed with ``PG...``, which are consumed by ``libpq`` to take the place of
+any or all elements of the connection string.
+
+For this form, the URL can be passed without any elements other than the
+initial scheme::
+
+ engine = create_engine('postgresql+psycopg2://')
+
+In the above form, a blank "dsn" string is passed to the ``psycopg2.connect()``
+function which in turn represents an empty DSN passed to libpq.
+
+.. versionadded:: 1.3.2 support for parameter-less connections with psycopg2.
+
+.. seealso::
+
+ `Environment Variables\
+ <https://www.postgresql.org/docs/current/libpq-envars.html>`_ -
+    PostgreSQL documentation on how to use ``PG...``
+ environment variables for connections.
+
+.. _psycopg2_execution_options:
+
+Per-Statement/Connection Execution Options
+-------------------------------------------
+
+The following DBAPI-specific options are respected when used with
+:meth:`_engine.Connection.execution_options`,
+:meth:`.Executable.execution_options`,
+:meth:`_query.Query.execution_options`,
+in addition to those not specific to DBAPIs:
+
+* ``isolation_level`` - Set the transaction isolation level for the lifespan
+ of a :class:`_engine.Connection` (can only be set on a connection,
+ not a statement
+ or query). See :ref:`psycopg2_isolation_level`.
+
+* ``stream_results`` - Enable or disable usage of psycopg2 server side
+ cursors - this feature makes use of "named" cursors in combination with
+ special result handling methods so that result rows are not fully buffered.
+ Defaults to False, meaning cursors are buffered by default.
+
+* ``max_row_buffer`` - when using ``stream_results``, an integer value that
+ specifies the maximum number of rows to buffer at a time. This is
+ interpreted by the :class:`.BufferedRowCursorResult`, and if omitted the
+ buffer will grow to ultimately store 1000 rows at a time.
+
+ .. versionchanged:: 1.4 The ``max_row_buffer`` size can now be greater than
+ 1000, and the buffer will grow to that size.
+
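+For example, a hypothetical sketch combining these options on a single
+connection (``some_table`` is a placeholder table name, and ``text`` refers
+to :func:`_expression.text`)::
+
+    with engine.connect() as conn:
+        result = conn.execution_options(
+            isolation_level="REPEATABLE READ",
+            stream_results=True,
+            max_row_buffer=5000,
+        ).execute(text("SELECT * FROM some_table"))
+        for row in result:
+            print(row)
+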
+.. _psycopg2_batch_mode:
+
+.. _psycopg2_executemany_mode:
+
+Psycopg2 Fast Execution Helpers
+-------------------------------
+
+Modern versions of psycopg2 include a feature known as
+`Fast Execution Helpers \
+<https://initd.org/psycopg/docs/extras.html#fast-execution-helpers>`_, which
+have been shown in benchmarking to improve psycopg2's executemany()
+performance, primarily with INSERT statements, by multiple orders of magnitude.
+SQLAlchemy internally makes use of these extensions for ``executemany()`` style
+calls, which correspond to lists of parameters being passed to
+:meth:`_engine.Connection.execute` as detailed in :ref:`multiple parameter
+sets <tutorial_multiple_parameters>`. The ORM also uses this mode internally whenever
+possible.
+
+The two available extensions on the psycopg2 side are the ``execute_values()``
+and ``execute_batch()`` functions. The psycopg2 dialect defaults to using the
+``execute_values()`` extension for all qualifying INSERT statements.
+
+.. versionchanged:: 1.4 The psycopg2 dialect now defaults to a new mode
+ ``"values_only"`` for ``executemany_mode``, which allows an order of
+ magnitude performance improvement for INSERT statements, but does not
+ include "batch" mode for UPDATE and DELETE statements which removes the
+ ability of ``cursor.rowcount`` to function correctly.
+
+The use of these extensions is controlled by the ``executemany_mode`` flag
+which may be passed to :func:`_sa.create_engine`::
+
+ engine = create_engine(
+ "postgresql+psycopg2://scott:tiger@host/dbname",
+ executemany_mode='values_plus_batch')
+
+
+Possible options for ``executemany_mode`` include:
+
+* ``values_only`` - this is the default value. The psycopg2 ``execute_values()``
+ extension is used for qualifying INSERT statements, which rewrites the INSERT
+ to include multiple VALUES clauses so that many parameter sets can be
+ inserted with one statement.
+
+ .. versionadded:: 1.4 Added ``"values_only"`` setting for ``executemany_mode``
+ which is also now the default.
+
+* ``None`` - No psycopg2 extensions are used, and the usual
+ ``cursor.executemany()`` method is used when invoking statements with
+ multiple parameter sets.
+
+* ``'batch'`` - Uses ``psycopg2.extras.execute_batch`` for all qualifying
+ INSERT, UPDATE and DELETE statements, so that multiple copies
+ of a SQL query, each one corresponding to a parameter set passed to
+ ``executemany()``, are joined into a single SQL string separated by a
+ semicolon. When using this mode, the :attr:`_engine.CursorResult.rowcount`
+ attribute will not contain a value for executemany-style executions.
+
+* ``'values_plus_batch'`` - ``execute_values`` is used for qualifying INSERT
+ statements, ``execute_batch`` is used for UPDATE and DELETE.
+ When using this mode, the :attr:`_engine.CursorResult.rowcount`
+ attribute will not contain a value for executemany-style executions against
+ UPDATE and DELETE statements.
+
+By "qualifying statements", we mean that the statement being executed
+must be a Core :func:`_expression.insert`, :func:`_expression.update`
+or :func:`_expression.delete` construct, and not a plain textual SQL
+string or one constructed using :func:`_expression.text`. When using the
+ORM, all insert/update/delete statements used by the ORM flush process
+are qualifying.
+
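+For example, as an illustrative sketch (``some_table`` is a placeholder
+:class:`_schema.Table`), the first statement below is a qualifying Core
+construct, while the second, textual form is not and falls back to plain
+``cursor.executemany()``::
+
+    from sqlalchemy import insert, text
+
+    # qualifying: Core insert() construct, eligible for execute_values()
+    conn.execute(insert(some_table), [{"x": 1}, {"x": 2}])
+
+    # not qualifying: textual SQL, executed with cursor.executemany()
+    conn.execute(
+        text("INSERT INTO some_table (x) VALUES (:x)"),
+        [{"x": 1}, {"x": 2}],
+    )
+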
+The "page size" for the "values" and "batch" strategies can be affected
+by using the ``executemany_batch_page_size`` and
+``executemany_values_page_size`` engine parameters. These
+control how many parameter sets
+should be represented in each execution. The "values" page size defaults
+to 1000, which is different that psycopg2's default. The "batch" page
+size defaults to 100. These can be affected by passing new values to
+:func:`_engine.create_engine`::
+
+ engine = create_engine(
+ "postgresql+psycopg2://scott:tiger@host/dbname",
+        executemany_mode='values_plus_batch',
+ executemany_values_page_size=10000, executemany_batch_page_size=500)
+
+.. versionchanged:: 1.4
+
+ The default for ``executemany_values_page_size`` is now 1000, up from
+ 100.
+
+.. seealso::
+
+ :ref:`tutorial_multiple_parameters` - General information on using the
+ :class:`_engine.Connection`
+ object to execute statements in such a way as to make
+ use of the DBAPI ``.executemany()`` method.
+
+
+.. _psycopg2_unicode:
+
+Unicode with Psycopg2
+----------------------
+
+The psycopg2 DBAPI driver supports Unicode data transparently. Under Python 2
+only, the SQLAlchemy psycopg2 dialect will enable the
+``psycopg2.extensions.UNICODE`` extension by default to ensure Unicode is
+handled properly; under Python 3, this is psycopg2's default behavior.
+
+The client character encoding can be controlled for the psycopg2 dialect
+in the following ways:
+
+* For PostgreSQL 9.1 and above, the ``client_encoding`` parameter may be
+ passed in the database URL; this parameter is consumed by the underlying
+ ``libpq`` PostgreSQL client library::
+
+ engine = create_engine("postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8")
+
+ Alternatively, the above ``client_encoding`` value may be passed using
+ :paramref:`_sa.create_engine.connect_args` for programmatic establishment with
+ ``libpq``::
+
+ engine = create_engine(
+ "postgresql+psycopg2://user:pass@host/dbname",
+ connect_args={'client_encoding': 'utf8'}
+ )
+
+* For all PostgreSQL versions, psycopg2 supports a client-side encoding
+ value that will be passed to database connections when they are first
+ established. The SQLAlchemy psycopg2 dialect supports this using the
+ ``client_encoding`` parameter passed to :func:`_sa.create_engine`::
+
+ engine = create_engine(
+ "postgresql+psycopg2://user:pass@host/dbname",
+ client_encoding="utf8"
+ )
+
+ .. tip:: The above ``client_encoding`` parameter admittedly is very similar
+ in appearance to usage of the parameter within the
+ :paramref:`_sa.create_engine.connect_args` dictionary; the difference
+ above is that the parameter is consumed by psycopg2 and is
+ passed to the database connection using ``SET client_encoding TO
+ 'utf8'``; in the previously mentioned style, the parameter is instead
+ passed through psycopg2 and consumed by the ``libpq`` library.
+
+* A common way to set up client encoding with PostgreSQL databases is to
+ ensure it is configured within the server-side postgresql.conf file;
+ this is the recommended way to set encoding for a server that is
+ consistently of one encoding in all databases::
+
+ # postgresql.conf file
+
+ # client_encoding = sql_ascii # actually, defaults to database
+ # encoding
+ client_encoding = utf8
+
+.. _psycopg2_disable_native_unicode:
+
+Disabling Native Unicode
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+Under Python 2 only, SQLAlchemy can also be instructed to skip the usage of the
+psycopg2 ``UNICODE`` extension and to instead utilize its own unicode
+encode/decode services, which are normally reserved only for those DBAPIs that
+don't fully support unicode directly. Passing ``use_native_unicode=False`` to
+:func:`_sa.create_engine` will disable usage of
+``psycopg2.extensions.UNICODE``. SQLAlchemy will instead encode data itself
+into Python bytestrings
+on the way in and coerce from bytes on the way back, using the value of the
+:func:`_sa.create_engine` ``encoding`` parameter, which defaults to ``utf-8``.
+SQLAlchemy's own unicode encode/decode functionality is steadily becoming
+obsolete as most DBAPIs now support unicode fully.
+
+
+Transactions
+------------
+
+The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.
+
+.. _psycopg2_isolation_level:
+
+Psycopg2 Transaction Isolation Level
+-------------------------------------
+
+As discussed in :ref:`postgresql_isolation_level`,
+all PostgreSQL dialects support setting of transaction isolation level
+both via the ``isolation_level`` parameter passed to
+:func:`_sa.create_engine`, as well as the ``isolation_level`` argument used
+by :meth:`_engine.Connection.execution_options`. When using the psycopg2
+dialect, these options make use of psycopg2's ``set_isolation_level()``
+connection method, rather than emitting a PostgreSQL directive; this is
+because psycopg2's API-level setting is always emitted at the start of each
+transaction in any case.
+
+The psycopg2 dialect supports these constants for isolation level:
+
+* ``READ COMMITTED``
+* ``READ UNCOMMITTED``
+* ``REPEATABLE READ``
+* ``SERIALIZABLE``
+* ``AUTOCOMMIT``
+
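+For example, an engine-wide default together with a per-connection override
+might look like the following sketch::
+
+    engine = create_engine(
+        "postgresql+psycopg2://scott:tiger@localhost/test",
+        isolation_level="REPEATABLE READ",
+    )
+
+    with engine.connect() as conn:
+        conn = conn.execution_options(isolation_level="AUTOCOMMIT")
+        conn.execute(text("SELECT 1"))
+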
+.. seealso::
+
+ :ref:`postgresql_isolation_level`
+
+ :ref:`pg8000_isolation_level`
+
+
+NOTICE logging
+---------------
+
+The psycopg2 dialect will log PostgreSQL NOTICE messages
+via the ``sqlalchemy.dialects.postgresql`` logger. When this logger
+is set to the ``logging.INFO`` level, notice messages will be logged::
+
+ import logging
+
+ logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
+
+Above, it is assumed that logging is configured externally. If this is not
+the case, configuration such as ``logging.basicConfig()`` must be utilized::
+
+ import logging
+
+ logging.basicConfig() # log messages to stdout
+ logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
+
+.. seealso::
+
+ `Logging HOWTO <https://docs.python.org/3/howto/logging.html>`_ - on the python.org website
+
+.. _psycopg2_hstore:
+
+HSTORE type
+------------
+
+The ``psycopg2`` DBAPI includes an extension to natively handle marshalling of
+the HSTORE type. The SQLAlchemy psycopg2 dialect will enable this extension
+by default when psycopg2 version 2.4 or greater is used, and
+it is detected that the target database has the HSTORE type set up for use.
+In other words, when the dialect makes the first
+connection, a sequence like the following is performed:
+
+1. Request the available HSTORE oids using
+ ``psycopg2.extras.HstoreAdapter.get_oids()``.
+ If this function returns a list of HSTORE identifiers, we then determine
+ that the ``HSTORE`` extension is present.
+ This function is **skipped** if the version of psycopg2 installed is
+ less than version 2.4.
+
+2. If the ``use_native_hstore`` flag is at its default of ``True``, and
+ we've detected that ``HSTORE`` oids are available, the
+ ``psycopg2.extensions.register_hstore()`` extension is invoked for all
+ connections.
+
+The ``register_hstore()`` extension has the effect of **all Python
+dictionaries being accepted as parameters regardless of the type of target
+column in SQL**. The dictionaries are converted by this extension into a
+textual HSTORE expression. If this behavior is not desired, disable the
+use of the hstore extension by setting ``use_native_hstore`` to ``False`` as
+follows::
+
+ engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
+ use_native_hstore=False)
+
+The ``HSTORE`` type is **still supported** when the
+``psycopg2.extensions.register_hstore()`` extension is not used. It merely
+means that the coercion between Python dictionaries and the HSTORE
+string format, on both the parameter side and the result side, will take
+place within SQLAlchemy's own marshalling logic, and not that of ``psycopg2``
+which may be more performant.
+
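+In either case, a plain Python dictionary may be passed as the parameter for
+an ``HSTORE`` column; an illustrative sketch, where ``some_table`` has a
+column ``data`` of type ``HSTORE``::
+
+    conn.execute(
+        some_table.insert(), {"data": {"key1": "value1", "key2": "value2"}}
+    )
+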
+""" # noqa
+from __future__ import absolute_import
+
+import decimal
+import logging
+import re
+from uuid import UUID as _python_UUID
+
+from .array import ARRAY as PGARRAY
+from .base import _ColonCast
+from .base import _DECIMAL_TYPES
+from .base import _FLOAT_TYPES
+from .base import _INT_TYPES
+from .base import ENUM
+from .base import PGCompiler
+from .base import PGDialect
+from .base import PGExecutionContext
+from .base import PGIdentifierPreparer
+from .base import UUID
+from .hstore import HSTORE
+from .json import JSON
+from .json import JSONB
+from ... import exc
+from ... import processors
+from ... import types as sqltypes
+from ... import util
+from ...engine import cursor as _cursor
+from ...util import collections_abc
+
+
+logger = logging.getLogger("sqlalchemy.dialects.postgresql")
+
+
+class _PGNumeric(sqltypes.Numeric):
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect, coltype):
+ if self.asdecimal:
+ if coltype in _FLOAT_TYPES:
+ return processors.to_decimal_processor_factory(
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ # pg8000 returns Decimal natively for 1700
+ return None
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+ else:
+ if coltype in _FLOAT_TYPES:
+ # pg8000 returns float natively for 701
+ return None
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ return processors.to_float
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+
+
+class _PGEnum(ENUM):
+ def result_processor(self, dialect, coltype):
+ if util.py2k and self._expect_unicode is True:
+ # for py2k, if the enum type needs unicode data (which is set up as
+ # part of the Enum() constructor based on values passed as py2k
+ # unicode objects) we have to use our own converters since
+ # psycopg2's don't work, a rare exception to the "modern DBAPIs
+ # support unicode everywhere" theme of deprecating
+ # convert_unicode=True. Use the special "force_nocheck" directive
+ # which forces unicode conversion to happen on the Python side
+ # without an isinstance() check. in py3k psycopg2 does the right
+ # thing automatically.
+ self._expect_unicode = "force_nocheck"
+ return super(_PGEnum, self).result_processor(dialect, coltype)
+
+
+class _PGHStore(HSTORE):
+ def bind_processor(self, dialect):
+ if dialect._has_native_hstore:
+ return None
+ else:
+ return super(_PGHStore, self).bind_processor(dialect)
+
+ def result_processor(self, dialect, coltype):
+ if dialect._has_native_hstore:
+ return None
+ else:
+ return super(_PGHStore, self).result_processor(dialect, coltype)
+
+
+class _PGARRAY(PGARRAY):
+ def bind_expression(self, bindvalue):
+ return _ColonCast(bindvalue, self)
+
+
+class _PGJSON(JSON):
+ def result_processor(self, dialect, coltype):
+ return None
+
+
+class _PGJSONB(JSONB):
+ def result_processor(self, dialect, coltype):
+ return None
+
+
+class _PGUUID(UUID):
+ def bind_processor(self, dialect):
+ if not self.as_uuid and dialect.use_native_uuid:
+
+ def process(value):
+ if value is not None:
+ value = _python_UUID(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if not self.as_uuid and dialect.use_native_uuid:
+
+ def process(value):
+ if value is not None:
+ value = str(value)
+ return value
+
+ return process
+
+
+_server_side_id = util.counter()
+
+
+class PGExecutionContext_psycopg2(PGExecutionContext):
+ _psycopg2_fetched_rows = None
+
+ def create_server_side_cursor(self):
+ # use server-side cursors:
+ # https://lists.initd.org/pipermail/psycopg/2007-January/005251.html
+ ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
+ return self._dbapi_connection.cursor(ident)
+
+ def post_exec(self):
+ if (
+ self._psycopg2_fetched_rows
+ and self.compiled
+ and self.compiled.returning
+ ):
+ # psycopg2 execute_values will provide for a real cursor where
+ # cursor.description works correctly. however, it executes the
+ # INSERT statement multiple times for multiple pages of rows, so
+ # while this cursor also supports calling .fetchall() directly, in
+ # order to get the list of all rows inserted across multiple pages,
+ # we have to retrieve the aggregated list from the execute_values()
+ # function directly.
+ strat_cls = _cursor.FullyBufferedCursorFetchStrategy
+ self.cursor_fetch_strategy = strat_cls(
+ self.cursor, initial_buffer=self._psycopg2_fetched_rows
+ )
+ self._log_notices(self.cursor)
+
+ def _log_notices(self, cursor):
+ # check also that notices is an iterable, after it's already
+ # established that we will be iterating through it. This is to get
+ # around test suites such as SQLAlchemy's using a Mock object for
+ # cursor
+ if not cursor.connection.notices or not isinstance(
+ cursor.connection.notices, collections_abc.Iterable
+ ):
+ return
+
+ for notice in cursor.connection.notices:
+ # NOTICE messages have a
+ # newline character at the end
+ logger.info(notice.rstrip())
+
+ cursor.connection.notices[:] = []
+
+
+class PGCompiler_psycopg2(PGCompiler):
+ pass
+
+
+class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
+ pass
+
+
+EXECUTEMANY_PLAIN = util.symbol("executemany_plain", canonical=0)
+EXECUTEMANY_BATCH = util.symbol("executemany_batch", canonical=1)
+EXECUTEMANY_VALUES = util.symbol("executemany_values", canonical=2)
+EXECUTEMANY_VALUES_PLUS_BATCH = util.symbol(
+ "executemany_values_plus_batch",
+ canonical=EXECUTEMANY_BATCH | EXECUTEMANY_VALUES,
+)
+
+
+class PGDialect_psycopg2(PGDialect):
+ driver = "psycopg2"
+
+ supports_statement_cache = True
+
+ if util.py2k:
+ # turn off supports_unicode_statements for Python 2. psycopg2 supports
+ # unicode statements in Py2K. But! it does not support unicode *bound
+ # parameter names* because it uses the Python "%" operator to
+ # interpolate these into the string, and this fails. So for Py2K, we
+ # have to use full-on encoding for statements and parameters before
+ # passing to cursor.execute().
+ supports_unicode_statements = False
+
+ supports_server_side_cursors = True
+
+ default_paramstyle = "pyformat"
+ # set to true based on psycopg2 version
+ supports_sane_multi_rowcount = False
+ execution_ctx_cls = PGExecutionContext_psycopg2
+ statement_compiler = PGCompiler_psycopg2
+ preparer = PGIdentifierPreparer_psycopg2
+ psycopg2_version = (0, 0)
+
+ _has_native_hstore = True
+
+ engine_config_types = PGDialect.engine_config_types.union(
+ {"use_native_unicode": util.asbool}
+ )
+
+ colspecs = util.update_copy(
+ PGDialect.colspecs,
+ {
+ sqltypes.Numeric: _PGNumeric,
+ ENUM: _PGEnum, # needs force_unicode
+ sqltypes.Enum: _PGEnum, # needs force_unicode
+ HSTORE: _PGHStore,
+ JSON: _PGJSON,
+ sqltypes.JSON: _PGJSON,
+ JSONB: _PGJSONB,
+ UUID: _PGUUID,
+ sqltypes.ARRAY: _PGARRAY,
+ },
+ )
+
+ def __init__(
+ self,
+ use_native_unicode=True,
+ client_encoding=None,
+ use_native_hstore=True,
+ use_native_uuid=True,
+ executemany_mode="values_only",
+ executemany_batch_page_size=100,
+ executemany_values_page_size=1000,
+ **kwargs
+ ):
+ PGDialect.__init__(self, **kwargs)
+ self.use_native_unicode = use_native_unicode
+ if not use_native_unicode and not util.py2k:
+ raise exc.ArgumentError(
+ "psycopg2 native_unicode mode is required under Python 3"
+ )
+ if not use_native_hstore:
+ self._has_native_hstore = False
+ self.use_native_hstore = use_native_hstore
+ self.use_native_uuid = use_native_uuid
+ self.supports_unicode_binds = use_native_unicode
+ self.client_encoding = client_encoding
+
+ # Parse executemany_mode argument, allowing it to be only one of the
+ # symbol names
+ self.executemany_mode = util.symbol.parse_user_argument(
+ executemany_mode,
+ {
+ EXECUTEMANY_PLAIN: [None],
+ EXECUTEMANY_BATCH: ["batch"],
+ EXECUTEMANY_VALUES: ["values_only"],
+ EXECUTEMANY_VALUES_PLUS_BATCH: ["values_plus_batch", "values"],
+ },
+ "executemany_mode",
+ )
+
+ if self.executemany_mode & EXECUTEMANY_VALUES:
+ self.insert_executemany_returning = True
+
+ self.executemany_batch_page_size = executemany_batch_page_size
+ self.executemany_values_page_size = executemany_values_page_size
+
+ if self.dbapi and hasattr(self.dbapi, "__version__"):
+ m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
+ if m:
+ self.psycopg2_version = tuple(
+ int(x) for x in m.group(1, 2, 3) if x is not None
+ )
+
+ if self.psycopg2_version < (2, 7):
+ raise ImportError(
+ "psycopg2 version 2.7 or higher is required."
+ )
+
+ def initialize(self, connection):
+ super(PGDialect_psycopg2, self).initialize(connection)
+ self._has_native_hstore = (
+ self.use_native_hstore
+ and self._hstore_oids(connection.connection) is not None
+ )
+
+ # PGDialect.initialize() checks server version for <= 8.2 and sets
+ # this flag to False if so
+ if not self.full_returning:
+ self.insert_executemany_returning = False
+ self.executemany_mode = EXECUTEMANY_PLAIN
+
+ self.supports_sane_multi_rowcount = not (
+ self.executemany_mode & EXECUTEMANY_BATCH
+ )
+
+ @classmethod
+ def dbapi(cls):
+ import psycopg2
+
+ return psycopg2
+
+ @classmethod
+ def _psycopg2_extensions(cls):
+ from psycopg2 import extensions
+
+ return extensions
+
+ @classmethod
+ def _psycopg2_extras(cls):
+ from psycopg2 import extras
+
+ return extras
+
+ @util.memoized_property
+ def _isolation_lookup(self):
+ extensions = self._psycopg2_extensions()
+ return {
+ "AUTOCOMMIT": extensions.ISOLATION_LEVEL_AUTOCOMMIT,
+ "READ COMMITTED": extensions.ISOLATION_LEVEL_READ_COMMITTED,
+ "READ UNCOMMITTED": extensions.ISOLATION_LEVEL_READ_UNCOMMITTED,
+ "REPEATABLE READ": extensions.ISOLATION_LEVEL_REPEATABLE_READ,
+ "SERIALIZABLE": extensions.ISOLATION_LEVEL_SERIALIZABLE,
+ }
+
+ def set_isolation_level(self, connection, level):
+ try:
+ level = self._isolation_lookup[level.replace("_", " ")]
+ except KeyError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
+ ),
+ replace_context=err,
+ )
+
+ connection.set_isolation_level(level)
+
+ def set_readonly(self, connection, value):
+ connection.readonly = value
+
+ def get_readonly(self, connection):
+ return connection.readonly
+
+ def set_deferrable(self, connection, value):
+ connection.deferrable = value
+
+ def get_deferrable(self, connection):
+ return connection.deferrable
+
+ def do_ping(self, dbapi_connection):
+ cursor = None
+ before_autocommit = dbapi_connection.autocommit
+ try:
+ if not before_autocommit:
+ dbapi_connection.autocommit = True
+ cursor = dbapi_connection.cursor()
+ try:
+ cursor.execute(self._dialect_specific_select_one)
+ finally:
+ cursor.close()
+ if not before_autocommit and not dbapi_connection.closed:
+ dbapi_connection.autocommit = before_autocommit
+ except self.dbapi.Error as err:
+ if self.is_disconnect(err, dbapi_connection, cursor):
+ return False
+ else:
+ raise
+ else:
+ return True
+
+ def on_connect(self):
+ extras = self._psycopg2_extras()
+ extensions = self._psycopg2_extensions()
+
+ fns = []
+ if self.client_encoding is not None:
+
+ def on_connect(conn):
+ conn.set_client_encoding(self.client_encoding)
+
+ fns.append(on_connect)
+
+ if self.isolation_level is not None:
+
+ def on_connect(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ fns.append(on_connect)
+
+ if self.dbapi and self.use_native_uuid:
+
+ def on_connect(conn):
+ extras.register_uuid(None, conn)
+
+ fns.append(on_connect)
+
+ if util.py2k and self.dbapi and self.use_native_unicode:
+
+ def on_connect(conn):
+ extensions.register_type(extensions.UNICODE, conn)
+ extensions.register_type(extensions.UNICODEARRAY, conn)
+
+ fns.append(on_connect)
+
+ if self.dbapi and self.use_native_hstore:
+
+ def on_connect(conn):
+ hstore_oids = self._hstore_oids(conn)
+ if hstore_oids is not None:
+ oid, array_oid = hstore_oids
+ kw = {"oid": oid}
+ if util.py2k:
+ kw["unicode"] = True
+ kw["array_oid"] = array_oid
+ extras.register_hstore(conn, **kw)
+
+ fns.append(on_connect)
+
+ if self.dbapi and self._json_deserializer:
+
+ def on_connect(conn):
+ extras.register_default_json(
+ conn, loads=self._json_deserializer
+ )
+ extras.register_default_jsonb(
+ conn, loads=self._json_deserializer
+ )
+
+ fns.append(on_connect)
+
+ if fns:
+
+ def on_connect(conn):
+ for fn in fns:
+ fn(conn)
+
+ return on_connect
+ else:
+ return None
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ if (
+ self.executemany_mode & EXECUTEMANY_VALUES
+ and context
+ and context.isinsert
+ and context.compiled._is_safe_for_fast_insert_values_helper
+ ):
+ executemany_values = (
+ "(%s)" % context.compiled.insert_single_values_expr
+ )
+ if not self.supports_unicode_statements:
+ executemany_values = executemany_values.encode(self.encoding)
+
+ # guard for statement that was altered via event hook or similar
+ if executemany_values not in statement:
+ executemany_values = None
+ else:
+ executemany_values = None
+
+ if executemany_values:
+ statement = statement.replace(executemany_values, "%s")
+ if self.executemany_values_page_size:
+ kwargs = {"page_size": self.executemany_values_page_size}
+ else:
+ kwargs = {}
+ xtras = self._psycopg2_extras()
+ context._psycopg2_fetched_rows = xtras.execute_values(
+ cursor,
+ statement,
+ parameters,
+ template=executemany_values,
+ fetch=bool(context.compiled.returning),
+ **kwargs
+ )
+
+ elif self.executemany_mode & EXECUTEMANY_BATCH:
+ if self.executemany_batch_page_size:
+ kwargs = {"page_size": self.executemany_batch_page_size}
+ else:
+ kwargs = {}
+ self._psycopg2_extras().execute_batch(
+ cursor, statement, parameters, **kwargs
+ )
+ else:
+ cursor.executemany(statement, parameters)
+
+ @util.memoized_instancemethod
+ def _hstore_oids(self, conn):
+ extras = self._psycopg2_extras()
+ if hasattr(conn, "dbapi_connection"):
+ conn = conn.dbapi_connection
+ oids = extras.HstoreAdapter.get_oids(conn)
+ if oids is not None and oids[0]:
+ return oids[0:2]
+ else:
+ return None
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+
+ is_multihost = False
+ if "host" in url.query:
+ is_multihost = isinstance(url.query["host"], (list, tuple))
+
+ if opts or url.query:
+ if not opts:
+ opts = {}
+ if "port" in opts:
+ opts["port"] = int(opts["port"])
+ opts.update(url.query)
+ if is_multihost:
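+                # translate repeated "host=h1:p1&host=h2:p2" URL tokens into
+                # the libpq comma-separated "host=h1,h2" / "port=p1,p2" form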
+ hosts, ports = zip(
+ *[
+ token.split(":") if ":" in token else (token, "")
+ for token in url.query["host"]
+ ]
+ )
+ opts["host"] = ",".join(hosts)
+ if "port" in opts:
+ raise exc.ArgumentError(
+ "Can't mix 'multihost' formats together; use "
+ '"host=h1,h2,h3&port=p1,p2,p3" or '
+ '"host=h1:p1&host=h2:p2&host=h3:p3" separately'
+ )
+ opts["port"] = ",".join(ports)
+ return ([], opts)
+ else:
+ # no connection arguments whatsoever; psycopg2.connect()
+ # requires that "dsn" be present as a blank string.
+ return ([""], opts)
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.Error):
+ # check the "closed" flag. this might not be
+ # present on old psycopg2 versions. Also,
+ # this flag doesn't actually help in a lot of disconnect
+ # situations, so don't rely on it.
+ if getattr(connection, "closed", False):
+ return True
+
+ # checks based on strings. in the case that .closed
+ # didn't cut it, fall back onto these.
+ str_e = str(e).partition("\n")[0]
+ for msg in [
+ # these error messages from libpq: interfaces/libpq/fe-misc.c
+ # and interfaces/libpq/fe-secure.c.
+ "terminating connection",
+ "closed the connection",
+ "connection not open",
+ "could not receive data from server",
+ "could not send data to server",
+ # psycopg2 client errors, psycopg2/connection.h,
+ # psycopg2/cursor.h
+ "connection already closed",
+ "cursor already closed",
+ # not sure where this path is originally from, it may
+ # be obsolete. It really says "losed", not "closed".
+ "losed the connection unexpectedly",
+ # these can occur in newer SSL
+ "connection has been closed unexpectedly",
+ "SSL error: decryption failed or bad record mac",
+ "SSL SYSCALL error: Bad file descriptor",
+ "SSL SYSCALL error: EOF detected",
+ "SSL SYSCALL error: Operation timed out",
+ "SSL SYSCALL error: Bad address",
+ ]:
+ idx = str_e.find(msg)
+ if idx >= 0 and '"' not in str_e[:idx]:
+ return True
+ return False
+
+
+dialect = PGDialect_psycopg2
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py
new file mode 100644
index 0000000..10d1aae
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py
@@ -0,0 +1,60 @@
+# postgresql/psycopg2cffi.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: postgresql+psycopg2cffi
+ :name: psycopg2cffi
+ :dbapi: psycopg2cffi
+ :connectstring: postgresql+psycopg2cffi://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://pypi.org/project/psycopg2cffi/
+
+``psycopg2cffi`` is an adaptation of ``psycopg2``, using CFFI for the C
+layer. This makes it suitable for use in e.g. PyPy. Documentation
+is as per ``psycopg2``.
+
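+For example, a minimal sketch of creating an engine with this driver::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine(
+        "postgresql+psycopg2cffi://scott:tiger@localhost/test"
+    )
+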
+.. versionadded:: 1.0.0
+
+.. seealso::
+
+ :mod:`sqlalchemy.dialects.postgresql.psycopg2`
+
+""" # noqa
+from .psycopg2 import PGDialect_psycopg2
+
+
+class PGDialect_psycopg2cffi(PGDialect_psycopg2):
+ driver = "psycopg2cffi"
+ supports_unicode_statements = True
+ supports_statement_cache = True
+
+ # psycopg2cffi's first release is 2.5.0, but reports
+ # __version__ as 2.4.4. Subsequent releases seem to have
+ # fixed this.
+
+ FEATURE_VERSION_MAP = dict(
+ native_json=(2, 4, 4),
+ native_jsonb=(2, 7, 1),
+ sane_multi_rowcount=(2, 4, 4),
+ array_oid=(2, 4, 4),
+ hstore_adapter=(2, 4, 4),
+ )
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("psycopg2cffi")
+
+ @classmethod
+ def _psycopg2_extensions(cls):
+ root = __import__("psycopg2cffi", fromlist=["extensions"])
+ return root.extensions
+
+ @classmethod
+ def _psycopg2_extras(cls):
+ root = __import__("psycopg2cffi", fromlist=["extras"])
+ return root.extras
+
+
+dialect = PGDialect_psycopg2cffi
diff --git a/lib/sqlalchemy/dialects/postgresql/pygresql.py b/lib/sqlalchemy/dialects/postgresql/pygresql.py
new file mode 100644
index 0000000..d273b8c
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/pygresql.py
@@ -0,0 +1,278 @@
+# postgresql/pygresql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""
+.. dialect:: postgresql+pygresql
+ :name: pygresql
+ :dbapi: pgdb
+ :connectstring: postgresql+pygresql://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://www.pygresql.org/
+
+.. note::
+
+ The pygresql dialect is **not tested as part of SQLAlchemy's continuous
+ integration** and may have unresolved issues. The recommended PostgreSQL
+ dialect is psycopg2.
+
+.. deprecated:: 1.4 The pygresql DBAPI is deprecated and will be removed
+ in a future version. Please use one of the supported DBAPIs to
+ connect to PostgreSQL.
+
+""" # noqa
+
+import decimal
+import re
+
+from .base import _DECIMAL_TYPES
+from .base import _FLOAT_TYPES
+from .base import _INT_TYPES
+from .base import PGCompiler
+from .base import PGDialect
+from .base import PGIdentifierPreparer
+from .base import UUID
+from .hstore import HSTORE
+from .json import JSON
+from .json import JSONB
+from ... import exc
+from ... import processors
+from ... import util
+from ...sql.elements import Null
+from ...types import JSON as Json
+from ...types import Numeric
+
+
+class _PGNumeric(Numeric):
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect, coltype):
+ if not isinstance(coltype, int):
+ coltype = coltype.oid
+ if self.asdecimal:
+ if coltype in _FLOAT_TYPES:
+ return processors.to_decimal_processor_factory(
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ # PyGreSQL returns Decimal natively for 1700 (numeric)
+ return None
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+ else:
+ if coltype in _FLOAT_TYPES:
+ # PyGreSQL returns float natively for 701 (float8)
+ return None
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ return processors.to_float
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+
+
+class _PGHStore(HSTORE):
+ def bind_processor(self, dialect):
+ if not dialect.has_native_hstore:
+ return super(_PGHStore, self).bind_processor(dialect)
+ hstore = dialect.dbapi.Hstore
+
+ def process(value):
+ if isinstance(value, dict):
+ return hstore(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if not dialect.has_native_hstore:
+ return super(_PGHStore, self).result_processor(dialect, coltype)
+
+
+class _PGJSON(JSON):
+ def bind_processor(self, dialect):
+ if not dialect.has_native_json:
+ return super(_PGJSON, self).bind_processor(dialect)
+ json = dialect.dbapi.Json
+
+ def process(value):
+ if value is self.NULL:
+ value = None
+ elif isinstance(value, Null) or (
+ value is None and self.none_as_null
+ ):
+ return None
+ if value is None or isinstance(value, (dict, list)):
+ return json(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if not dialect.has_native_json:
+ return super(_PGJSON, self).result_processor(dialect, coltype)
+
+
+class _PGJSONB(JSONB):
+ def bind_processor(self, dialect):
+ if not dialect.has_native_json:
+ return super(_PGJSONB, self).bind_processor(dialect)
+ json = dialect.dbapi.Json
+
+ def process(value):
+ if value is self.NULL:
+ value = None
+ elif isinstance(value, Null) or (
+ value is None and self.none_as_null
+ ):
+ return None
+ if value is None or isinstance(value, (dict, list)):
+ return json(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if not dialect.has_native_json:
+ return super(_PGJSONB, self).result_processor(dialect, coltype)
+
+
+class _PGUUID(UUID):
+ def bind_processor(self, dialect):
+ if not dialect.has_native_uuid:
+ return super(_PGUUID, self).bind_processor(dialect)
+ uuid = dialect.dbapi.Uuid
+
+ def process(value):
+ if value is None:
+ return None
+ if isinstance(value, (str, bytes)):
+ if len(value) == 16:
+ return uuid(bytes=value)
+ return uuid(value)
+ if isinstance(value, int):
+ return uuid(int=value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if not dialect.has_native_uuid:
+ return super(_PGUUID, self).result_processor(dialect, coltype)
+ if not self.as_uuid:
+
+ def process(value):
+ if value is not None:
+ return str(value)
+
+ return process
+
+
+class _PGCompiler(PGCompiler):
+ def visit_mod_binary(self, binary, operator, **kw):
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
+
+ def post_process_text(self, text):
+ return text.replace("%", "%%")
+
+
+class _PGIdentifierPreparer(PGIdentifierPreparer):
+ def _escape_identifier(self, value):
+ value = value.replace(self.escape_quote, self.escape_to_quote)
+ return value.replace("%", "%%")
+
+
+class PGDialect_pygresql(PGDialect):
+
+ driver = "pygresql"
+ supports_statement_cache = True
+
+ statement_compiler = _PGCompiler
+ preparer = _PGIdentifierPreparer
+
+ @classmethod
+ def dbapi(cls):
+ import pgdb
+
+ util.warn_deprecated(
+ "The pygresql DBAPI is deprecated and will be removed "
+ "in a future version. Please use one of the supported DBAPIs to "
+ "connect to PostgreSQL.",
+ version="1.4",
+ )
+
+ return pgdb
+
+ colspecs = util.update_copy(
+ PGDialect.colspecs,
+ {
+ Numeric: _PGNumeric,
+ HSTORE: _PGHStore,
+ Json: _PGJSON,
+ JSON: _PGJSON,
+ JSONB: _PGJSONB,
+ UUID: _PGUUID,
+ },
+ )
+
+ def __init__(self, **kwargs):
+ super(PGDialect_pygresql, self).__init__(**kwargs)
+ try:
+ version = self.dbapi.version
+ m = re.match(r"(\d+)\.(\d+)", version)
+ version = (int(m.group(1)), int(m.group(2)))
+ except (AttributeError, ValueError, TypeError):
+ version = (0, 0)
+ self.dbapi_version = version
+ if version < (5, 0):
+ has_native_hstore = has_native_json = has_native_uuid = False
+ if version != (0, 0):
+ util.warn(
+ "PyGreSQL is only fully supported by SQLAlchemy"
+ " since version 5.0."
+ )
+ else:
+ self.supports_unicode_statements = True
+ self.supports_unicode_binds = True
+ has_native_hstore = has_native_json = has_native_uuid = True
+ self.has_native_hstore = has_native_hstore
+ self.has_native_json = has_native_json
+ self.has_native_uuid = has_native_uuid
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+ if "port" in opts:
+ opts["host"] = "%s:%s" % (
+ opts.get("host", "").rsplit(":", 1)[0],
+ opts.pop("port"),
+ )
+ opts.update(url.query)
+ return [], opts
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.Error):
+ if not connection:
+ return False
+ try:
+ connection = connection.connection
+ except AttributeError:
+ pass
+ else:
+ if not connection:
+ return False
+ try:
+ return connection.closed
+ except AttributeError: # PyGreSQL < 5.0
+ return connection._cnx is None
+ return False
+
+
+dialect = PGDialect_pygresql
diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py
new file mode 100644
index 0000000..886e368
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py
@@ -0,0 +1,126 @@
+# postgresql/pypostgresql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""
+.. dialect:: postgresql+pypostgresql
+ :name: py-postgresql
+ :dbapi: pypostgresql
+ :connectstring: postgresql+pypostgresql://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://python.projects.pgfoundry.org/
+
+.. note::
+
+ The pypostgresql dialect is **not tested as part of SQLAlchemy's continuous
+ integration** and may have unresolved issues. The recommended PostgreSQL
+ driver is psycopg2.
+
+.. deprecated:: 1.4 The py-postgresql DBAPI is deprecated and will be removed
+ in a future version. This DBAPI is superseded by the external
+ version available at external-dialect_. Please use the external version or
+ one of the supported DBAPIs to connect to PostgreSQL.
+
+.. TODO update link
+.. _external-dialect: https://github.com/PyGreSQL
+
+""" # noqa
+
+from .base import PGDialect
+from .base import PGExecutionContext
+from ... import processors
+from ... import types as sqltypes
+from ... import util
+
+
+class PGNumeric(sqltypes.Numeric):
+ def bind_processor(self, dialect):
+ return processors.to_str
+
+ def result_processor(self, dialect, coltype):
+ if self.asdecimal:
+ return None
+ else:
+ return processors.to_float
+
+
+class PGExecutionContext_pypostgresql(PGExecutionContext):
+ pass
+
+
+class PGDialect_pypostgresql(PGDialect):
+ driver = "pypostgresql"
+
+ supports_statement_cache = True
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+ description_encoding = None
+ default_paramstyle = "pyformat"
+
+ # requires trunk version to support sane rowcounts
+ # TODO: use dbapi version information to set this flag appropriately
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = False
+
+ execution_ctx_cls = PGExecutionContext_pypostgresql
+ colspecs = util.update_copy(
+ PGDialect.colspecs,
+ {
+ sqltypes.Numeric: PGNumeric,
+ # prevents PGNumeric from being used
+ sqltypes.Float: sqltypes.Float,
+ },
+ )
+
+ @classmethod
+ def dbapi(cls):
+ from postgresql.driver import dbapi20
+
+ # TODO update link
+ util.warn_deprecated(
+ "The py-postgresql DBAPI is deprecated and will be removed "
+ "in a future version. This DBAPI is superseded by the external"
+ "version available at https://github.com/PyGreSQL. Please "
+ "use one of the supported DBAPIs to connect to PostgreSQL.",
+ version="1.4",
+ )
+
+ return dbapi20
+
+ _DBAPI_ERROR_NAMES = [
+ "Error",
+ "InterfaceError",
+ "DatabaseError",
+ "DataError",
+ "OperationalError",
+ "IntegrityError",
+ "InternalError",
+ "ProgrammingError",
+ "NotSupportedError",
+ ]
+
+ @util.memoized_property
+ def dbapi_exception_translation_map(self):
+ if self.dbapi is None:
+ return {}
+
+ return dict(
+ (getattr(self.dbapi, name).__name__, name)
+ for name in self._DBAPI_ERROR_NAMES
+ )
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+ if "port" in opts:
+ opts["port"] = int(opts["port"])
+ else:
+ opts["port"] = 5432
+ opts.update(url.query)
+ return ([], opts)
+
+ def is_disconnect(self, e, connection, cursor):
+ return "connection is closed" in str(e)
+
+
+dialect = PGDialect_pypostgresql
diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py
new file mode 100644
index 0000000..51f3b04
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/ranges.py
@@ -0,0 +1,138 @@
+# Copyright (C) 2013-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from ... import types as sqltypes
+
+
+__all__ = ("INT4RANGE", "INT8RANGE", "NUMRANGE")
+
+
+class RangeOperators(object):
+ """
+ This mixin provides functionality for the Range Operators
+ listed in the Range Operators table of the `PostgreSQL documentation`__
+ for Range Functions and Operators. It is used by all the range types
+    provided in the ``postgresql`` dialect and can likely be used for
+ any range types you create yourself.
+
+ __ https://www.postgresql.org/docs/current/static/functions-range.html
+
+ No extra support is provided for the Range Functions listed in the Range
+ Functions table of the PostgreSQL documentation. For these, the normal
+ :func:`~sqlalchemy.sql.expression.func` object should be used.
+
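+    For example, assuming a table with an :class:`_postgresql.INT4RANGE`
+    column, the operators above render the corresponding PostgreSQL range
+    operators in SQL (a minimal sketch; the table and value are
+    placeholders)::
+
+        from sqlalchemy import Column, Integer, MetaData, Table, select
+        from sqlalchemy.dialects.postgresql import INT4RANGE
+
+        my_table = Table(
+            "my_table",
+            MetaData(),
+            Column("id", Integer, primary_key=True),
+            Column("period", INT4RANGE),
+        )
+
+        # renders roughly as:
+        #   SELECT ... FROM my_table WHERE my_table.period @> :period_1
+        stmt = select(my_table).where(my_table.c.period.contains(7))
+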
+ """
+
+ class comparator_factory(sqltypes.Concatenable.Comparator):
+ """Define comparison operations for range types."""
+
+ def __ne__(self, other):
+ "Boolean expression. Returns true if two ranges are not equal"
+ if other is None:
+ return super(RangeOperators.comparator_factory, self).__ne__(
+ other
+ )
+ else:
+ return self.expr.op("<>", is_comparison=True)(other)
+
+ def contains(self, other, **kw):
+ """Boolean expression. Returns true if the right hand operand,
+ which can be an element or a range, is contained within the
+ column.
+
+ kwargs may be ignored by this operator but are required for API
+ conformance.
+ """
+ return self.expr.op("@>", is_comparison=True)(other)
+
+ def contained_by(self, other):
+ """Boolean expression. Returns true if the column is contained
+ within the right hand operand.
+ """
+ return self.expr.op("<@", is_comparison=True)(other)
+
+ def overlaps(self, other):
+ """Boolean expression. Returns true if the column overlaps
+ (has points in common with) the right hand operand.
+ """
+ return self.expr.op("&&", is_comparison=True)(other)
+
+ def strictly_left_of(self, other):
+ """Boolean expression. Returns true if the column is strictly
+ left of the right hand operand.
+ """
+ return self.expr.op("<<", is_comparison=True)(other)
+
+ __lshift__ = strictly_left_of
+
+ def strictly_right_of(self, other):
+ """Boolean expression. Returns true if the column is strictly
+ right of the right hand operand.
+ """
+ return self.expr.op(">>", is_comparison=True)(other)
+
+ __rshift__ = strictly_right_of
+
+ def not_extend_right_of(self, other):
+ """Boolean expression. Returns true if the range in the column
+ does not extend right of the range in the operand.
+ """
+ return self.expr.op("&<", is_comparison=True)(other)
+
+ def not_extend_left_of(self, other):
+ """Boolean expression. Returns true if the range in the column
+ does not extend left of the range in the operand.
+ """
+ return self.expr.op("&>", is_comparison=True)(other)
+
+ def adjacent_to(self, other):
+ """Boolean expression. Returns true if the range in the column
+ is adjacent to the range in the operand.
+ """
+ return self.expr.op("-|-", is_comparison=True)(other)
+
+ def __add__(self, other):
+ """Range expression. Returns the union of the two ranges.
+ Will raise an exception if the resulting range is not
+ contiguous.
+ """
+ return self.expr.op("+")(other)
+
+
+class INT4RANGE(RangeOperators, sqltypes.TypeEngine):
+ """Represent the PostgreSQL INT4RANGE type."""
+
+ __visit_name__ = "INT4RANGE"
+
+
+class INT8RANGE(RangeOperators, sqltypes.TypeEngine):
+ """Represent the PostgreSQL INT8RANGE type."""
+
+ __visit_name__ = "INT8RANGE"
+
+
+class NUMRANGE(RangeOperators, sqltypes.TypeEngine):
+ """Represent the PostgreSQL NUMRANGE type."""
+
+ __visit_name__ = "NUMRANGE"
+
+
+class DATERANGE(RangeOperators, sqltypes.TypeEngine):
+ """Represent the PostgreSQL DATERANGE type."""
+
+ __visit_name__ = "DATERANGE"
+
+
+class TSRANGE(RangeOperators, sqltypes.TypeEngine):
+ """Represent the PostgreSQL TSRANGE type."""
+
+ __visit_name__ = "TSRANGE"
+
+
+class TSTZRANGE(RangeOperators, sqltypes.TypeEngine):
+ """Represent the PostgreSQL TSTZRANGE type."""
+
+ __visit_name__ = "TSTZRANGE"
diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py
new file mode 100644
index 0000000..8d8d933
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/__init__.py
@@ -0,0 +1,58 @@
+# sqlite/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import base # noqa
+from . import pysqlcipher # noqa
+from . import pysqlite # noqa
+from .base import BLOB
+from .base import BOOLEAN
+from .base import CHAR
+from .base import DATE
+from .base import DATETIME
+from .base import DECIMAL
+from .base import FLOAT
+from .base import INTEGER
+from .base import JSON
+from .base import NUMERIC
+from .base import REAL
+from .base import SMALLINT
+from .base import TEXT
+from .base import TIME
+from .base import TIMESTAMP
+from .base import VARCHAR
+from .dml import Insert
+from .dml import insert
+from ...util import compat
+
+if compat.py3k:
+ from . import aiosqlite # noqa
+
+# default dialect
+base.dialect = dialect = pysqlite.dialect
+
+
+__all__ = (
+ "BLOB",
+ "BOOLEAN",
+ "CHAR",
+ "DATE",
+ "DATETIME",
+ "DECIMAL",
+ "FLOAT",
+ "INTEGER",
+ "JSON",
+ "NUMERIC",
+ "SMALLINT",
+ "TEXT",
+ "TIME",
+ "TIMESTAMP",
+ "VARCHAR",
+ "REAL",
+ "Insert",
+ "insert",
+ "dialect",
+)
diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py
new file mode 100644
index 0000000..9fc6d35
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py
@@ -0,0 +1,335 @@
+# sqlite/aiosqlite.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+
+.. dialect:: sqlite+aiosqlite
+ :name: aiosqlite
+ :dbapi: aiosqlite
+ :connectstring: sqlite+aiosqlite:///file_path
+ :url: https://pypi.org/project/aiosqlite/
+
+The aiosqlite dialect provides support for the SQLAlchemy asyncio interface
+running on top of pysqlite.
+
+aiosqlite is a wrapper around pysqlite that uses a background thread for
+each connection. It does not actually use non-blocking IO, as SQLite
+databases are not socket-based. However it does provide a working asyncio
+interface that's useful for testing and prototyping purposes.
+
+Using a special asyncio mediation layer, the aiosqlite dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("sqlite+aiosqlite:///filename")
+
+The URL passes through all arguments to the ``pysqlite`` driver, so all
+connection arguments are the same as those accepted by :ref:`pysqlite`.
+
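+A minimal end-to-end sketch using the asyncio extension follows; the file
+name and statement are placeholders::
+
+    import asyncio
+
+    from sqlalchemy import text
+    from sqlalchemy.ext.asyncio import create_async_engine
+
+    async def main():
+        engine = create_async_engine("sqlite+aiosqlite:///filename")
+        async with engine.connect() as conn:
+            result = await conn.execute(text("SELECT 1"))
+            print(result.scalar())
+        await engine.dispose()
+
+    asyncio.run(main())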
+
+""" # noqa
+
+from .base import SQLiteExecutionContext
+from .pysqlite import SQLiteDialect_pysqlite
+from ... import pool
+from ... import util
+from ...engine import AdaptedConnection
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+class AsyncAdapt_aiosqlite_cursor:
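+    # a PEP 249-style cursor facade; each method drives the underlying
+    # aiosqlite cursor through the adapted connection's await_ mediator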
+ __slots__ = (
+ "_adapt_connection",
+ "_connection",
+ "description",
+ "await_",
+ "_rows",
+ "arraysize",
+ "rowcount",
+ "lastrowid",
+ )
+
+ server_side = False
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+ self.arraysize = 1
+ self.rowcount = -1
+ self.description = None
+ self._rows = []
+
+ def close(self):
+ self._rows[:] = []
+
+ def execute(self, operation, parameters=None):
+ try:
+ _cursor = self.await_(self._connection.cursor())
+
+ if parameters is None:
+ self.await_(_cursor.execute(operation))
+ else:
+ self.await_(_cursor.execute(operation, parameters))
+
+ if _cursor.description:
+ self.description = _cursor.description
+ self.lastrowid = self.rowcount = -1
+
+ if not self.server_side:
+ self._rows = self.await_(_cursor.fetchall())
+ else:
+ self.description = None
+ self.lastrowid = _cursor.lastrowid
+ self.rowcount = _cursor.rowcount
+
+ if not self.server_side:
+ self.await_(_cursor.close())
+ else:
+ self._cursor = _cursor
+ except Exception as error:
+ self._adapt_connection._handle_exception(error)
+
+ def executemany(self, operation, seq_of_parameters):
+ try:
+ _cursor = self.await_(self._connection.cursor())
+ self.await_(_cursor.executemany(operation, seq_of_parameters))
+ self.description = None
+ self.lastrowid = _cursor.lastrowid
+ self.rowcount = _cursor.rowcount
+ self.await_(_cursor.close())
+ except Exception as error:
+ self._adapt_connection._handle_exception(error)
+
+ def setinputsizes(self, *inputsizes):
+ pass
+
+ def __iter__(self):
+ while self._rows:
+ yield self._rows.pop(0)
+
+ def fetchone(self):
+ if self._rows:
+ return self._rows.pop(0)
+ else:
+ return None
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+
+ retval = self._rows[0:size]
+ self._rows[:] = self._rows[size:]
+ return retval
+
+ def fetchall(self):
+ retval = self._rows[:]
+ self._rows[:] = []
+ return retval
+
+
+class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor):
+ __slots__ = "_cursor"
+
+ server_side = True
+
+ def __init__(self, *arg, **kw):
+ super().__init__(*arg, **kw)
+ self._cursor = None
+
+ def close(self):
+ if self._cursor is not None:
+ self.await_(self._cursor.close())
+ self._cursor = None
+
+ def fetchone(self):
+ return self.await_(self._cursor.fetchone())
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+ return self.await_(self._cursor.fetchmany(size=size))
+
+ def fetchall(self):
+ return self.await_(self._cursor.fetchall())
+
+
+class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
+ await_ = staticmethod(await_only)
+ __slots__ = ("dbapi", "_connection")
+
+ def __init__(self, dbapi, connection):
+ self.dbapi = dbapi
+ self._connection = connection
+
+ @property
+ def isolation_level(self):
+ return self._connection.isolation_level
+
+ @isolation_level.setter
+ def isolation_level(self, value):
+ try:
+ self._connection.isolation_level = value
+ except Exception as error:
+ self._handle_exception(error)
+
+ def create_function(self, *args, **kw):
+ try:
+ self.await_(self._connection.create_function(*args, **kw))
+ except Exception as error:
+ self._handle_exception(error)
+
+ def cursor(self, server_side=False):
+ if server_side:
+ return AsyncAdapt_aiosqlite_ss_cursor(self)
+ else:
+ return AsyncAdapt_aiosqlite_cursor(self)
+
+ def execute(self, *args, **kw):
+ return self.await_(self._connection.execute(*args, **kw))
+
+ def rollback(self):
+ try:
+ self.await_(self._connection.rollback())
+ except Exception as error:
+ self._handle_exception(error)
+
+ def commit(self):
+ try:
+ self.await_(self._connection.commit())
+ except Exception as error:
+ self._handle_exception(error)
+
+ def close(self):
+ try:
+ self.await_(self._connection.close())
+ except Exception as error:
+ self._handle_exception(error)
+
+ def _handle_exception(self, error):
+ if (
+ isinstance(error, ValueError)
+ and error.args[0] == "no active connection"
+ ):
+ util.raise_(
+ self.dbapi.sqlite.OperationalError("no active connection"),
+ from_=error,
+ )
+ else:
+ raise error
+
+
+class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection):
+ __slots__ = ()
+
+ await_ = staticmethod(await_fallback)
+
+
+class AsyncAdapt_aiosqlite_dbapi:
+ def __init__(self, aiosqlite, sqlite):
+ self.aiosqlite = aiosqlite
+ self.sqlite = sqlite
+ self.paramstyle = "qmark"
+ self._init_dbapi_attributes()
+
+ def _init_dbapi_attributes(self):
+ for name in (
+ "DatabaseError",
+ "Error",
+ "IntegrityError",
+ "NotSupportedError",
+ "OperationalError",
+ "ProgrammingError",
+ "sqlite_version",
+ "sqlite_version_info",
+ ):
+ setattr(self, name, getattr(self.aiosqlite, name))
+
+ for name in ("PARSE_COLNAMES", "PARSE_DECLTYPES"):
+ setattr(self, name, getattr(self.sqlite, name))
+
+ for name in ("Binary",):
+ setattr(self, name, getattr(self.sqlite, name))
+
+ def connect(self, *arg, **kw):
+ async_fallback = kw.pop("async_fallback", False)
+
+ # Q. WHY do we need this?
+ # A. Because there is no way to set connection.isolation_level
+ # otherwise
+ # Q. BUT HOW do you know it is SAFE ?????
+ # A. The only operation that isn't safe is the isolation level set
+ # operation which aiosqlite appears to have let slip through even
+ # though pysqlite appears to do check_same_thread for this.
+ # All execute operations etc. should be safe because they all
+ # go through the single executor thread.
+
+ kw["check_same_thread"] = False
+
+ connection = self.aiosqlite.connect(*arg, **kw)
+
+ # it's a Thread. you'll thank us later
+ connection.daemon = True
+
+ if util.asbool(async_fallback):
+ return AsyncAdaptFallback_aiosqlite_connection(
+ self,
+ await_fallback(connection),
+ )
+ else:
+ return AsyncAdapt_aiosqlite_connection(
+ self,
+ await_only(connection),
+ )
+
+
+class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext):
+ def create_server_side_cursor(self):
+ return self._dbapi_connection.cursor(server_side=True)
+
+
+class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
+ driver = "aiosqlite"
+ supports_statement_cache = True
+
+ is_async = True
+
+ supports_server_side_cursors = True
+
+ execution_ctx_cls = SQLiteExecutionContext_aiosqlite
+
+ @classmethod
+ def dbapi(cls):
+ return AsyncAdapt_aiosqlite_dbapi(
+ __import__("aiosqlite"), __import__("sqlite3")
+ )
+
+ @classmethod
+ def get_pool_class(cls, url):
+ if cls._is_url_file_db(url):
+ return pool.NullPool
+ else:
+ return pool.StaticPool
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(
+ e, self.dbapi.OperationalError
+ ) and "no active connection" in str(e):
+ return True
+
+ return super().is_disconnect(e, connection, cursor)
+
+ def get_driver_connection(self, connection):
+ return connection._connection
+
+
+dialect = SQLiteDialect_aiosqlite
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
new file mode 100644
index 0000000..0959d04
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -0,0 +1,2556 @@
+# sqlite/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+.. dialect:: sqlite
+ :name: SQLite
+ :full_support: 3.21, 3.28+
+ :normal_support: 3.12+
+ :best_effort: 3.7.16+
+
+.. _sqlite_datetime:
+
+Date and Time Types
+-------------------
+
+SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does
+not provide out of the box functionality for translating values between Python
+`datetime` objects and a SQLite-supported format. SQLAlchemy's own
+:class:`~sqlalchemy.types.DateTime` and related types provide date formatting
+and parsing functionality when SQLite is used. The implementation classes are
+:class:`_sqlite.DATETIME`, :class:`_sqlite.DATE` and :class:`_sqlite.TIME`.
+These types represent dates and times as ISO formatted strings, which also
+nicely support ordering. There's no reliance on typical "libc" internals for
+these functions so historical dates are fully supported.
+
+Ensuring Text affinity
+^^^^^^^^^^^^^^^^^^^^^^
+
+The DDL rendered for these types is the standard ``DATE``, ``TIME``
+and ``DATETIME`` indicators. However, custom storage formats can also be
+applied to these types. When the
+storage format is detected as containing no alpha characters, the DDL for
+these types is rendered as ``DATE_CHAR``, ``TIME_CHAR``, and ``DATETIME_CHAR``,
+so that the column continues to have textual affinity.
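+
+As a sketch, a purely numeric custom storage format causes the ``_CHAR``
+variant of the DDL name to be emitted (the format below is illustrative
+only)::
+
+    from sqlalchemy import Column, MetaData, Table
+    from sqlalchemy.dialects.sqlite import DATETIME
+
+    dt = DATETIME(
+        storage_format="%(year)04d%(month)02d%(day)02d"
+        "%(hour)02d%(minute)02d%(second)02d"
+    )
+
+    # the "stamp" column renders as "stamp DATETIME_CHAR" in CREATE TABLE
+    t = Table("t", MetaData(), Column("stamp", dt))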
+
+.. seealso::
+
+ `Type Affinity <https://www.sqlite.org/datatype3.html#affinity>`_ -
+ in the SQLite documentation
+
+.. _sqlite_autoincrement:
+
+SQLite Auto Incrementing Behavior
+----------------------------------
+
+Background on SQLite's autoincrement is at: https://sqlite.org/autoinc.html
+
+Key concepts:
+
+* SQLite has an implicit "auto increment" feature that takes place for any
+ non-composite primary-key column that is specifically created using
+ "INTEGER PRIMARY KEY" for the type + primary key.
+
+* SQLite also has an explicit "AUTOINCREMENT" keyword, that is **not**
+ equivalent to the implicit autoincrement feature; this keyword is not
+ recommended for general use. SQLAlchemy does not render this keyword
+ unless a special SQLite-specific directive is used (see below). However,
+ it still requires that the column's type is named "INTEGER".
+
+Using the AUTOINCREMENT Keyword
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To specifically render the AUTOINCREMENT keyword on the primary key column
+when rendering DDL, add the flag ``sqlite_autoincrement=True`` to the Table
+construct::
+
+ Table('sometable', metadata,
+ Column('id', Integer, primary_key=True),
+ sqlite_autoincrement=True)
+
+Allowing autoincrement behavior for SQLAlchemy types other than Integer/INTEGER
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+SQLite's typing model is based on naming conventions. Among other things, this
+means that any type name which contains the substring ``"INT"`` will be
+determined to be of "integer affinity". A type named ``"BIGINT"``,
+``"SPECIAL_INT"`` or even ``"XYZINTQPR"``, will be considered by SQLite to be
+of "integer" affinity. However, **the SQLite autoincrement feature, whether
+implicitly or explicitly enabled, requires that the name of the column's type
+is exactly the string "INTEGER"**. Therefore, if an application uses a type
+like :class:`.BigInteger` for a primary key, on SQLite this type will need to
+be rendered as the name ``"INTEGER"`` when emitting the initial ``CREATE
+TABLE`` statement in order for the autoincrement behavior to be available.
+
+One approach to achieve this is to use :class:`.Integer` on SQLite
+only using :meth:`.TypeEngine.with_variant`::
+
+ table = Table(
+ "my_table", metadata,
+ Column("id", BigInteger().with_variant(Integer, "sqlite"), primary_key=True)
+ )
+
+Another is to use a subclass of :class:`.BigInteger` that overrides its DDL
+name to be ``INTEGER`` when compiled against SQLite::
+
+ from sqlalchemy import BigInteger
+ from sqlalchemy.ext.compiler import compiles
+
+ class SLBigInteger(BigInteger):
+ pass
+
+ @compiles(SLBigInteger, 'sqlite')
+ def bi_c(element, compiler, **kw):
+ return "INTEGER"
+
+ @compiles(SLBigInteger)
+ def bi_c(element, compiler, **kw):
+ return compiler.visit_BIGINT(element, **kw)
+
+
+ table = Table(
+ "my_table", metadata,
+ Column("id", SLBigInteger(), primary_key=True)
+ )
+
+.. seealso::
+
+ :meth:`.TypeEngine.with_variant`
+
+ :ref:`sqlalchemy.ext.compiler_toplevel`
+
+ `Datatypes In SQLite Version 3 <https://sqlite.org/datatype3.html>`_
+
+.. _sqlite_concurrency:
+
+Database Locking Behavior / Concurrency
+---------------------------------------
+
+SQLite is not designed for a high level of write concurrency. The database
+itself, being a file, is locked completely during write operations within
+transactions, meaning exactly one "connection" (in reality a file handle)
+has exclusive access to the database during this period - all other
+"connections" will be blocked during this time.
+
+The Python DBAPI specification also calls for a connection model that is
+always in a transaction; there is no ``connection.begin()`` method,
+only ``connection.commit()`` and ``connection.rollback()``, upon which a
+new transaction is to be begun immediately. This may seem to imply
+that the SQLite driver would in theory allow only a single filehandle on a
+particular database file at any time; however, there are several
+factors both within SQLite itself as well as within the pysqlite driver
+which loosen this restriction significantly.
+
+However, no matter what locking modes are used, SQLite will still always
+lock the database file once a transaction is started and DML (e.g. INSERT,
+UPDATE, DELETE) has at least been emitted, and this will block
+other transactions at least at the point that they also attempt to emit DML.
+By default, the length of time on this block is very short before it times out
+with an error.
+
+This behavior becomes more critical when used in conjunction with the
+SQLAlchemy ORM. SQLAlchemy's :class:`.Session` object by default runs
+within a transaction, and with its autoflush model, may emit DML preceding
+any SELECT statement. This may lead to a SQLite database that locks
+more quickly than is expected. The locking mode of SQLite and the pysqlite
+driver can be manipulated to some degree; however, it should be noted that
+achieving a high degree of write-concurrency with SQLite is a losing battle.
+
+For more information on SQLite's lack of write concurrency by design, please
+see
+`Situations Where Another RDBMS May Work Better - High Concurrency
+<https://www.sqlite.org/whentouse.html>`_ near the bottom of the page.
+
+The following subsections introduce areas that are impacted by SQLite's
+file-based architecture and additionally will usually require workarounds to
+work when using the pysqlite driver.
+
+.. _sqlite_isolation_level:
+
+Transaction Isolation Level / Autocommit
+----------------------------------------
+
+SQLite supports "transaction isolation" in a non-standard way, along two
+axes. One is that of the
+`PRAGMA read_uncommitted <https://www.sqlite.org/pragma.html#pragma_read_uncommitted>`_
+instruction. This setting can essentially switch SQLite between its
+default mode of ``SERIALIZABLE`` isolation, and a "dirty read" isolation
+mode normally referred to as ``READ UNCOMMITTED``.
+
+SQLAlchemy ties into this PRAGMA statement using the
+:paramref:`_sa.create_engine.isolation_level` parameter of
+:func:`_sa.create_engine`.
+Valid values for this parameter when used with SQLite are ``"SERIALIZABLE"``
+and ``"READ UNCOMMITTED"`` corresponding to a value of 0 and 1, respectively.
+SQLite defaults to ``SERIALIZABLE``; however, its behavior is impacted by
+the pysqlite driver's default behavior.
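+
+For example, to select the "dirty read" mode for a SQLite engine (a sketch;
+the database path is a placeholder)::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine(
+        "sqlite:///some.db", isolation_level="READ UNCOMMITTED"
+    )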
+
+When using the pysqlite driver, the ``"AUTOCOMMIT"`` isolation level is also
+available, which will alter the pysqlite connection using the ``.isolation_level``
+attribute on the DBAPI connection and set it to None for the duration
+of the setting.
+
+.. versionadded:: 1.3.16 added support for SQLite AUTOCOMMIT isolation level
+ when using the pysqlite / sqlite3 SQLite driver.
+
+
+The other axis along which SQLite's transactional locking is impacted is
+via the nature of the ``BEGIN`` statement used. The three varieties
+are "deferred", "immediate", and "exclusive", as described at
+`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_. A straight
+``BEGIN`` statement uses the "deferred" mode, where the database file is
+not locked until the first read or write operation, and read access remains
+open to other transactions until the first write operation. But again,
+it is critical to note that the pysqlite driver interferes with this behavior
+by *not even emitting BEGIN* until the first write operation.
+
+.. warning::
+
+ SQLite's transactional scope is impacted by unresolved
+ issues in the pysqlite driver, which defers BEGIN statements to a greater
+ degree than is often feasible. See the section :ref:`pysqlite_serializable`
+ for techniques to work around this behavior.
+
+.. seealso::
+
+ :ref:`dbapi_autocommit`
+
+SAVEPOINT Support
+----------------------------
+
+SQLite supports SAVEPOINTs, which only function once a transaction is
+begun. SQLAlchemy's SAVEPOINT support is available using the
+:meth:`_engine.Connection.begin_nested` method at the Core level, and
+:meth:`.Session.begin_nested` at the ORM level. However, SAVEPOINTs
+won't work at all with pysqlite unless workarounds are taken.
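+
+At the Core level, the usage is the standard SAVEPOINT pattern once a
+transaction is in progress (a sketch; the engine, table and values are
+placeholders)::
+
+    with engine.begin() as conn:
+        savepoint = conn.begin_nested()
+        try:
+            conn.execute(some_table.insert(), {"id": 1, "data": "d1"})
+            savepoint.commit()
+        except Exception:
+            savepoint.rollback()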
+
+.. warning::
+
+ SQLite's SAVEPOINT feature is impacted by unresolved
+ issues in the pysqlite driver, which defers BEGIN statements to a greater
+ degree than is often feasible. See the section :ref:`pysqlite_serializable`
+ for techniques to work around this behavior.
+
+Transactional DDL
+----------------------------
+
+The SQLite database supports transactional :term:`DDL` as well.
+In this case, the pysqlite driver not only fails to start transactions, it
+also ends any existing transaction when DDL is detected, so again,
+workarounds are required.
+
+.. warning::
+
+ SQLite's transactional DDL is impacted by unresolved issues
+ in the pysqlite driver, which fails to emit BEGIN and additionally
+ forces a COMMIT to cancel any transaction when DDL is encountered.
+ See the section :ref:`pysqlite_serializable`
+ for techniques to work around this behavior.
+
+.. _sqlite_foreign_keys:
+
+Foreign Key Support
+-------------------
+
+SQLite supports FOREIGN KEY syntax when emitting CREATE statements for tables,
+however by default these constraints have no effect on the operation of the
+table.
+
+Constraint checking on SQLite has three prerequisites:
+
+* At least version 3.6.19 of SQLite must be in use
+* The SQLite library must be compiled *without* the SQLITE_OMIT_FOREIGN_KEY
+ or SQLITE_OMIT_TRIGGER symbols enabled.
+* The ``PRAGMA foreign_keys = ON`` statement must be emitted on all
+ connections before use -- including the initial call to
+ :meth:`sqlalchemy.schema.MetaData.create_all`.
+
+SQLAlchemy allows for the ``PRAGMA`` statement to be emitted automatically for
+new connections through the usage of events::
+
+ from sqlalchemy.engine import Engine
+ from sqlalchemy import event
+
+ @event.listens_for(Engine, "connect")
+ def set_sqlite_pragma(dbapi_connection, connection_record):
+ cursor = dbapi_connection.cursor()
+ cursor.execute("PRAGMA foreign_keys=ON")
+ cursor.close()
+
+.. warning::
+
+ When SQLite foreign keys are enabled, it is **not possible**
+ to emit CREATE or DROP statements for tables that contain
+ mutually-dependent foreign key constraints;
+ to emit the DDL for these tables requires that ALTER TABLE be used to
+ create or drop these constraints separately, for which SQLite has
+ no support.
+
+.. seealso::
+
+ `SQLite Foreign Key Support <https://www.sqlite.org/foreignkeys.html>`_
+ - on the SQLite web site.
+
+ :ref:`event_toplevel` - SQLAlchemy event API.
+
+ :ref:`use_alter` - more information on SQLAlchemy's facilities for handling
+ mutually-dependent foreign key constraints.
+
+.. _sqlite_on_conflict_ddl:
+
+ON CONFLICT support for constraints
+-----------------------------------
+
+.. seealso:: This section describes the :term:`DDL` version of "ON CONFLICT" for
+ SQLite, which occurs within a CREATE TABLE statement. For "ON CONFLICT" as
+ applied to an INSERT statement, see :ref:`sqlite_on_conflict_insert`.
+
+SQLite supports a non-standard DDL clause known as ON CONFLICT which can be applied
+to primary key, unique, check, and not null constraints. In DDL, it is
+rendered either within the "CONSTRAINT" clause or within the column definition
+itself depending on the location of the target constraint. To render this
+clause within DDL, the extension parameter ``sqlite_on_conflict`` can be
+specified with a string conflict resolution algorithm within the
+:class:`.PrimaryKeyConstraint`, :class:`.UniqueConstraint`,
+:class:`.CheckConstraint` objects. Within the :class:`_schema.Column` object,
+there
+are individual parameters ``sqlite_on_conflict_not_null``,
+``sqlite_on_conflict_primary_key``, ``sqlite_on_conflict_unique`` which each
+correspond to the three types of relevant constraint types that can be
+indicated from a :class:`_schema.Column` object.
+
+.. seealso::
+
+ `ON CONFLICT <https://www.sqlite.org/lang_conflict.html>`_ - in the SQLite
+ documentation
+
+.. versionadded:: 1.3
+
+
+The ``sqlite_on_conflict`` parameters accept a string argument which is just
+the resolution name to be chosen, which on SQLite can be one of ROLLBACK,
+ABORT, FAIL, IGNORE, and REPLACE. For example, to add a UNIQUE constraint
+that specifies the IGNORE algorithm::
+
+ some_table = Table(
+ 'some_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', Integer),
+ UniqueConstraint('id', 'data', sqlite_on_conflict='IGNORE')
+ )
+
+The above renders CREATE TABLE DDL as::
+
+ CREATE TABLE some_table (
+ id INTEGER NOT NULL,
+ data INTEGER,
+ PRIMARY KEY (id),
+ UNIQUE (id, data) ON CONFLICT IGNORE
+ )
+
+
+When using the :paramref:`_schema.Column.unique`
+flag to add a UNIQUE constraint
+to a single column, the ``sqlite_on_conflict_unique`` parameter can
+be added to the :class:`_schema.Column` as well, which will be added to the
+UNIQUE constraint in the DDL::
+
+ some_table = Table(
+ 'some_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', Integer, unique=True,
+ sqlite_on_conflict_unique='IGNORE')
+ )
+
+rendering::
+
+ CREATE TABLE some_table (
+ id INTEGER NOT NULL,
+ data INTEGER,
+ PRIMARY KEY (id),
+ UNIQUE (data) ON CONFLICT IGNORE
+ )
+
+To apply the FAIL algorithm for a NOT NULL constraint,
+``sqlite_on_conflict_not_null`` is used::
+
+ some_table = Table(
+ 'some_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', Integer, nullable=False,
+ sqlite_on_conflict_not_null='FAIL')
+ )
+
+this renders the column inline ON CONFLICT phrase::
+
+ CREATE TABLE some_table (
+ id INTEGER NOT NULL,
+ data INTEGER NOT NULL ON CONFLICT FAIL,
+ PRIMARY KEY (id)
+ )
+
+
+Similarly, for an inline primary key, use ``sqlite_on_conflict_primary_key``::
+
+ some_table = Table(
+ 'some_table', metadata,
+ Column('id', Integer, primary_key=True,
+ sqlite_on_conflict_primary_key='FAIL')
+ )
+
+SQLAlchemy renders the PRIMARY KEY constraint separately, so the conflict
+resolution algorithm is applied to the constraint itself::
+
+ CREATE TABLE some_table (
+ id INTEGER NOT NULL,
+ PRIMARY KEY (id) ON CONFLICT FAIL
+ )
+
+.. _sqlite_on_conflict_insert:
+
+INSERT...ON CONFLICT (Upsert)
+-----------------------------------
+
+.. seealso:: This section describes the :term:`DML` version of "ON CONFLICT" for
+ SQLite, which occurs within an INSERT statement. For "ON CONFLICT" as
+ applied to a CREATE TABLE statement, see :ref:`sqlite_on_conflict_ddl`.
+
+From version 3.24.0 onwards, SQLite supports "upserts" (update or insert)
+of rows into a table via the ``ON CONFLICT`` clause of the ``INSERT``
+statement. A candidate row will only be inserted if that row does not violate
+any unique or primary key constraints. In the case of a unique constraint violation, a
+secondary action can occur which can be either "DO UPDATE", indicating that
+the data in the target row should be updated, or "DO NOTHING", which indicates
+to silently skip this row.
+
+Conflicts are determined using columns that are part of existing unique
+constraints and indexes. These constraints are identified by stating the
+columns and conditions that comprise the indexes.
+
+SQLAlchemy provides ``ON CONFLICT`` support via the SQLite-specific
+:func:`_sqlite.insert()` function, which provides
+the generative methods :meth:`_sqlite.Insert.on_conflict_do_update`
+and :meth:`_sqlite.Insert.on_conflict_do_nothing`:
+
+.. sourcecode:: pycon+sql
+
+ >>> from sqlalchemy.dialects.sqlite import insert
+
+ >>> insert_stmt = insert(my_table).values(
+ ... id='some_existing_id',
+ ... data='inserted value')
+
+ >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value')
+ ... )
+
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (?, ?)
+ ON CONFLICT (id) DO UPDATE SET data = ?{stop}
+
+ >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing(
+ ... index_elements=['id']
+ ... )
+
+ >>> print(do_nothing_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (?, ?)
+ ON CONFLICT (id) DO NOTHING
+
+.. versionadded:: 1.4
+
+.. seealso::
+
+ `Upsert
+ <https://sqlite.org/lang_UPSERT.html>`_
+ - in the SQLite documentation.
+
+
+Specifying the Target
+^^^^^^^^^^^^^^^^^^^^^
+
+Both methods supply the "target" of the conflict using column inference:
+
+* The :paramref:`_sqlite.Insert.on_conflict_do_update.index_elements` argument
+ specifies a sequence containing string column names, :class:`_schema.Column`
+ objects, and/or SQL expression elements, which would identify a unique index
+ or unique constraint.
+
+* When using :paramref:`_sqlite.Insert.on_conflict_do_update.index_elements`
+ to infer an index, a partial index can be inferred by also specifying the
+ :paramref:`_sqlite.Insert.on_conflict_do_update.index_where` parameter:
+
+ .. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data')
+
+ >>> do_update_stmt = stmt.on_conflict_do_update(
+ ... index_elements=[my_table.c.user_email],
+ ... index_where=my_table.c.user_email.like('%@gmail.com'),
+ ... set_=dict(data=stmt.excluded.data)
+ ... )
+
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (data, user_email) VALUES (?, ?)
+ ON CONFLICT (user_email)
+ WHERE user_email LIKE '%@gmail.com'
+ DO UPDATE SET data = excluded.data
+ >>>
+
+The SET Clause
+^^^^^^^^^^^^^^^
+
+``ON CONFLICT...DO UPDATE`` is used to perform an update of the already
+existing row, using any combination of new values as well as values
+from the proposed insertion. These values are specified using the
+:paramref:`_sqlite.Insert.on_conflict_do_update.set_` parameter. This
+parameter accepts a dictionary which consists of direct values
+for UPDATE:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+
+ >>> do_update_stmt = stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value')
+ ... )
+
+ >>> print(do_update_stmt)
+
+ {opensql}INSERT INTO my_table (id, data) VALUES (?, ?)
+ ON CONFLICT (id) DO UPDATE SET data = ?
+
+.. warning::
+
+ The :meth:`_sqlite.Insert.on_conflict_do_update` method does **not** take
+ into account Python-side default UPDATE values or generation functions,
+ e.g. those specified using :paramref:`_schema.Column.onupdate`. These
+ values will not be exercised for an ON CONFLICT style of UPDATE, unless
+ they are manually specified in the
+ :paramref:`_sqlite.Insert.on_conflict_do_update.set_` dictionary.
+
+Updating using the Excluded INSERT Values
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to refer to the proposed insertion row, the special alias
+:attr:`~.sqlite.Insert.excluded` is available as an attribute on
+the :class:`_sqlite.Insert` object; this object creates an "excluded." prefix
+on a column, which informs the DO UPDATE to update the row with the value that
+would have been inserted had the constraint not failed:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(
+ ... id='some_id',
+ ... data='inserted value',
+ ... author='jlh'
+ ... )
+
+ >>> do_update_stmt = stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value', author=stmt.excluded.author)
+ ... )
+
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data, author) VALUES (?, ?, ?)
+ ON CONFLICT (id) DO UPDATE SET data = ?, author = excluded.author
+
+Additional WHERE Criteria
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The :meth:`_sqlite.Insert.on_conflict_do_update` method also accepts
+a WHERE clause using the :paramref:`_sqlite.Insert.on_conflict_do_update.where`
+parameter, which will limit those rows which receive an UPDATE:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(
+ ... id='some_id',
+ ... data='inserted value',
+ ... author='jlh'
+ ... )
+
+ >>> on_update_stmt = stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value', author=stmt.excluded.author),
+ ... where=(my_table.c.status == 2)
+ ... )
+ >>> print(on_update_stmt)
+ {opensql}INSERT INTO my_table (id, data, author) VALUES (?, ?, ?)
+ ON CONFLICT (id) DO UPDATE SET data = ?, author = excluded.author
+ WHERE my_table.status = ?
+
+
+Skipping Rows with DO NOTHING
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+``ON CONFLICT`` may be used to skip inserting a row entirely
+if any conflict with a unique constraint occurs; below this is illustrated
+using the :meth:`_sqlite.Insert.on_conflict_do_nothing` method:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+ >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id'])
+ >>> print(stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT (id) DO NOTHING
+
+
+If ``DO NOTHING`` is used without specifying any columns or constraint,
+it has the effect of skipping the INSERT for any unique violation which
+occurs:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+ >>> stmt = stmt.on_conflict_do_nothing()
+ >>> print(stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT DO NOTHING
+
+.. _sqlite_type_reflection:
+
+Type Reflection
+---------------
+
+SQLite types are unlike those of most other database backends, in that
+the string name of the type usually does not correspond to a "type" in a
+one-to-one fashion. Instead, SQLite links per-column typing behavior
+to one of five so-called "type affinities" based on a string matching
+pattern for the type.
+
+SQLAlchemy's reflection process, when inspecting types, uses a simple
+lookup table to link the keywords returned to provided SQLAlchemy types.
+This lookup table is present within the SQLite dialect as it is for all
+other dialects. However, the SQLite dialect has a different "fallback"
+routine for when a particular type name is not located in the lookup map;
+it instead implements the SQLite "type affinity" scheme located at
+https://www.sqlite.org/datatype3.html section 2.1.
+
+The provided typemap will make direct associations from an exact string
+name match for the following types:
+
+:class:`_types.BIGINT`, :class:`_types.BLOB`,
+:class:`_types.BOOLEAN`, :class:`_types.CHAR`,
+:class:`_types.DATE`, :class:`_types.DATETIME`,
+:class:`_types.DECIMAL`, :class:`_types.FLOAT`,
+:class:`_types.INTEGER`, :class:`_types.NUMERIC`,
+:class:`_types.REAL`, :class:`_types.SMALLINT`,
+:class:`_types.TEXT`, :class:`_types.TIME`,
+:class:`_types.TIMESTAMP`, :class:`_types.VARCHAR`,
+:class:`_types.NVARCHAR`, :class:`_types.NCHAR`
+
+When a type name does not match one of the above types, the "type affinity"
+lookup is used instead:
+
+* :class:`_types.INTEGER` is returned if the type name includes the
+ string ``INT``
+* :class:`_types.TEXT` is returned if the type name includes the
+ string ``CHAR``, ``CLOB`` or ``TEXT``
+* :class:`_types.NullType` is returned if the type name includes the
+ string ``BLOB``
+* :class:`_types.REAL` is returned if the type name includes the string
+ ``REAL``, ``FLOA`` or ``DOUB``.
+* Otherwise, the :class:`_types.NUMERIC` type is used.
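+
+As a sketch of the affinity fallback, columns created with ad-hoc type names
+reflect as the affinity-derived types (the table and type names here are
+placeholders)::
+
+    from sqlalchemy import create_engine, inspect
+
+    engine = create_engine("sqlite://")
+    with engine.begin() as conn:
+        conn.exec_driver_sql("create table t (x SPECIAL_INT, y FLOATISH)")
+
+    # "SPECIAL_INT" contains "INT"  -> INTEGER
+    # "FLOATISH" contains "FLOA"    -> REAL
+    columns = inspect(engine).get_columns("t")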
+
+.. versionadded:: 0.9.3 Support for SQLite type affinity rules when reflecting
+ columns.
+
+
+.. _sqlite_partial_index:
+
+Partial Indexes
+---------------
+
+A partial index, e.g. one which uses a WHERE clause, can be specified
+with the DDL system using the argument ``sqlite_where``::
+
+ tbl = Table('testtbl', m, Column('data', Integer))
+ idx = Index('test_idx1', tbl.c.data,
+ sqlite_where=and_(tbl.c.data > 5, tbl.c.data < 10))
+
+The index will be rendered at create time as::
+
+ CREATE INDEX test_idx1 ON testtbl (data)
+ WHERE data > 5 AND data < 10
+
+.. versionadded:: 0.9.9
+
+.. _sqlite_dotted_column_names:
+
+Dotted Column Names
+-------------------
+
+Using table or column names that explicitly have periods in them is
+**not recommended**. While this is generally a bad idea for relational
+databases in general, as the dot is a syntactically significant character,
+the SQLite driver up until version **3.10.0** of SQLite has a bug which
+requires that SQLAlchemy filter out these dots in result sets.
+
+.. versionchanged:: 1.1
+
+ The following SQLite issue has been resolved as of version 3.10.0
+ of SQLite. SQLAlchemy as of **1.1** automatically disables its internal
+ workarounds based on detection of this version.
+
+The bug, entirely outside of SQLAlchemy, can be illustrated as follows::
+
+ import sqlite3
+
+ assert sqlite3.sqlite_version_info < (3, 10, 0), "bug is fixed in this version"
+
+ conn = sqlite3.connect(":memory:")
+ cursor = conn.cursor()
+
+ cursor.execute("create table x (a integer, b integer)")
+ cursor.execute("insert into x (a, b) values (1, 1)")
+ cursor.execute("insert into x (a, b) values (2, 2)")
+
+ cursor.execute("select x.a, x.b from x")
+ assert [c[0] for c in cursor.description] == ['a', 'b']
+
+ cursor.execute('''
+ select x.a, x.b from x where a=1
+ union
+ select x.a, x.b from x where a=2
+ ''')
+ assert [c[0] for c in cursor.description] == ['a', 'b'], \
+ [c[0] for c in cursor.description]
+
+The second assertion fails::
+
+ Traceback (most recent call last):
+ File "test.py", line 19, in <module>
+ [c[0] for c in cursor.description]
+ AssertionError: ['x.a', 'x.b']
+
+Where above, the driver incorrectly reports the names of the columns
+including the name of the table, which is entirely inconsistent vs.
+when the UNION is not present.
+
+SQLAlchemy relies upon column names being predictable in how they match
+to the original statement, so the SQLAlchemy dialect has no choice but
+to filter these out::
+
+
+ from sqlalchemy import create_engine
+
+ eng = create_engine("sqlite://")
+ conn = eng.connect()
+
+ conn.exec_driver_sql("create table x (a integer, b integer)")
+ conn.exec_driver_sql("insert into x (a, b) values (1, 1)")
+ conn.exec_driver_sql("insert into x (a, b) values (2, 2)")
+
+ result = conn.exec_driver_sql("select x.a, x.b from x")
+ assert result.keys() == ["a", "b"]
+
+ result = conn.exec_driver_sql('''
+ select x.a, x.b from x where a=1
+ union
+ select x.a, x.b from x where a=2
+ ''')
+ assert result.keys() == ["a", "b"]
+
+Note that above, even though SQLAlchemy filters out the dots, *both
+names are still addressable*::
+
+ >>> row = result.first()
+ >>> row["a"]
+ 1
+ >>> row["x.a"]
+ 1
+ >>> row["b"]
+ 1
+ >>> row["x.b"]
+ 1
+
+Therefore, the workaround applied by SQLAlchemy only impacts
+:meth:`_engine.CursorResult.keys` and :meth:`.Row.keys()` in the public API. In
+the very specific case where an application is forced to use column names that
+contain dots, and the functionality of :meth:`_engine.CursorResult.keys` and
+:meth:`.Row.keys()` is required to return these dotted names unmodified,
+the ``sqlite_raw_colnames`` execution option may be provided, either on a
+per-:class:`_engine.Connection` basis::
+
+ result = conn.execution_options(sqlite_raw_colnames=True).exec_driver_sql('''
+ select x.a, x.b from x where a=1
+ union
+ select x.a, x.b from x where a=2
+ ''')
+ assert result.keys() == ["x.a", "x.b"]
+
+or on a per-:class:`_engine.Engine` basis::
+
+ engine = create_engine("sqlite://", execution_options={"sqlite_raw_colnames": True})
+
+When using the per-:class:`_engine.Engine` execution option, note that
+**Core and ORM queries that use UNION may not function properly**.
+
+SQLite-specific table options
+-----------------------------
+
+One option for CREATE TABLE is supported directly by the SQLite
+dialect in conjunction with the :class:`_schema.Table` construct:
+
+* ``WITHOUT ROWID``::
+
+ Table("some_table", metadata, ..., sqlite_with_rowid=False)
+
+.. seealso::
+
+ `SQLite CREATE TABLE options
+ <https://www.sqlite.org/lang_createtable.html>`_
+
+""" # noqa
+
+import datetime
+import numbers
+import re
+
+from .json import JSON
+from .json import JSONIndexType
+from .json import JSONPathType
+from ... import exc
+from ... import processors
+from ... import schema as sa_schema
+from ... import sql
+from ... import types as sqltypes
+from ... import util
+from ...engine import default
+from ...engine import reflection
+from ...sql import coercions
+from ...sql import ColumnElement
+from ...sql import compiler
+from ...sql import elements
+from ...sql import roles
+from ...sql import schema
+from ...types import BLOB # noqa
+from ...types import BOOLEAN # noqa
+from ...types import CHAR # noqa
+from ...types import DECIMAL # noqa
+from ...types import FLOAT # noqa
+from ...types import INTEGER # noqa
+from ...types import NUMERIC # noqa
+from ...types import REAL # noqa
+from ...types import SMALLINT # noqa
+from ...types import TEXT # noqa
+from ...types import TIMESTAMP # noqa
+from ...types import VARCHAR # noqa
+
+
+class _SQliteJson(JSON):
+ def result_processor(self, dialect, coltype):
+ default_processor = super(_SQliteJson, self).result_processor(
+ dialect, coltype
+ )
+
+ def process(value):
+ try:
+ return default_processor(value)
+ except TypeError:
+ if isinstance(value, numbers.Number):
+ return value
+ else:
+ raise
+
+ return process
+
+
+class _DateTimeMixin(object):
+ _reg = None
+ _storage_format = None
+
+ def __init__(self, storage_format=None, regexp=None, **kw):
+ super(_DateTimeMixin, self).__init__(**kw)
+ if regexp is not None:
+ self._reg = re.compile(regexp)
+ if storage_format is not None:
+ self._storage_format = storage_format
+
+ @property
+ def format_is_text_affinity(self):
+ """return True if the storage format will automatically imply
+ a TEXT affinity.
+
+ If the storage format contains no non-numeric characters,
+ it will imply a NUMERIC storage format on SQLite; in this case,
+ the type will generate its DDL as DATE_CHAR, DATETIME_CHAR,
+ TIME_CHAR.
+
+ .. versionadded:: 1.0.0
+
+ """
+ spec = self._storage_format % {
+ "year": 0,
+ "month": 0,
+ "day": 0,
+ "hour": 0,
+ "minute": 0,
+ "second": 0,
+ "microsecond": 0,
+ }
+ return bool(re.search(r"[^0-9]", spec))
+
+ def adapt(self, cls, **kw):
+ if issubclass(cls, _DateTimeMixin):
+ if self._storage_format:
+ kw["storage_format"] = self._storage_format
+ if self._reg:
+ kw["regexp"] = self._reg
+ return super(_DateTimeMixin, self).adapt(cls, **kw)
+
+ def literal_processor(self, dialect):
+ bp = self.bind_processor(dialect)
+
+ def process(value):
+ return "'%s'" % bp(value)
+
+ return process
+
+
+class DATETIME(_DateTimeMixin, sqltypes.DateTime):
+ r"""Represent a Python datetime object in SQLite using a string.
+
+ The default string storage format is::
+
+ "%(year)04d-%(month)02d-%(day)02d %(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d"
+
+ e.g.::
+
+ 2021-03-15 12:05:57.105542
+
+ The storage format can be customized to some degree using the
+ ``storage_format`` and ``regexp`` parameters, such as::
+
+ import re
+ from sqlalchemy.dialects.sqlite import DATETIME
+
+        dt = DATETIME(
+            storage_format="%(year)04d/%(month)02d/%(day)02d "
+            "%(hour)02d:%(minute)02d:%(second)02d",
+            regexp=r"(\d+)/(\d+)/(\d+) (\d+):(\d+):(\d+)",
+        )
+
+ :param storage_format: format string which will be applied to the dict
+ with keys year, month, day, hour, minute, second, and microsecond.
+
+ :param regexp: regular expression which will be applied to incoming result
+ rows. If the regexp contains named groups, the resulting match dict is
+ applied to the Python datetime() constructor as keyword arguments.
+ Otherwise, if positional groups are used, the datetime() constructor
+ is called with positional arguments via
+ ``*map(int, match_obj.groups(0))``.
+
+ """ # noqa
+
+ _storage_format = (
+ "%(year)04d-%(month)02d-%(day)02d "
+ "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d"
+ )
+
+ def __init__(self, *args, **kwargs):
+ truncate_microseconds = kwargs.pop("truncate_microseconds", False)
+ super(DATETIME, self).__init__(*args, **kwargs)
+ if truncate_microseconds:
+ assert "storage_format" not in kwargs, (
+ "You can specify only "
+ "one of truncate_microseconds or storage_format."
+ )
+ assert "regexp" not in kwargs, (
+ "You can specify only one of "
+ "truncate_microseconds or regexp."
+ )
+ self._storage_format = (
+ "%(year)04d-%(month)02d-%(day)02d "
+ "%(hour)02d:%(minute)02d:%(second)02d"
+ )
+
+ def bind_processor(self, dialect):
+ datetime_datetime = datetime.datetime
+ datetime_date = datetime.date
+ format_ = self._storage_format
+
+ def process(value):
+ if value is None:
+ return None
+ elif isinstance(value, datetime_datetime):
+ return format_ % {
+ "year": value.year,
+ "month": value.month,
+ "day": value.day,
+ "hour": value.hour,
+ "minute": value.minute,
+ "second": value.second,
+ "microsecond": value.microsecond,
+ }
+ elif isinstance(value, datetime_date):
+ return format_ % {
+ "year": value.year,
+ "month": value.month,
+ "day": value.day,
+ "hour": 0,
+ "minute": 0,
+ "second": 0,
+ "microsecond": 0,
+ }
+ else:
+ raise TypeError(
+ "SQLite DateTime type only accepts Python "
+ "datetime and date objects as input."
+ )
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if self._reg:
+ return processors.str_to_datetime_processor_factory(
+ self._reg, datetime.datetime
+ )
+ else:
+ return processors.str_to_datetime
+
+
+class DATE(_DateTimeMixin, sqltypes.Date):
+ r"""Represent a Python date object in SQLite using a string.
+
+ The default string storage format is::
+
+ "%(year)04d-%(month)02d-%(day)02d"
+
+ e.g.::
+
+ 2011-03-15
+
+ The storage format can be customized to some degree using the
+ ``storage_format`` and ``regexp`` parameters, such as::
+
+ import re
+ from sqlalchemy.dialects.sqlite import DATE
+
+ d = DATE(
+ storage_format="%(month)02d/%(day)02d/%(year)04d",
+ regexp=re.compile("(?P<month>\d+)/(?P<day>\d+)/(?P<year>\d+)")
+ )
+
+ :param storage_format: format string which will be applied to the
+ dict with keys year, month, and day.
+
+ :param regexp: regular expression which will be applied to
+ incoming result rows. If the regexp contains named groups, the
+ resulting match dict is applied to the Python date() constructor
+ as keyword arguments. Otherwise, if positional groups are used, the
+ date() constructor is called with positional arguments via
+ ``*map(int, match_obj.groups(0))``.
+ """
+
+ _storage_format = "%(year)04d-%(month)02d-%(day)02d"
+
+ def bind_processor(self, dialect):
+ datetime_date = datetime.date
+ format_ = self._storage_format
+
+ def process(value):
+ if value is None:
+ return None
+ elif isinstance(value, datetime_date):
+ return format_ % {
+ "year": value.year,
+ "month": value.month,
+ "day": value.day,
+ }
+ else:
+ raise TypeError(
+ "SQLite Date type only accepts Python "
+ "date objects as input."
+ )
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if self._reg:
+ return processors.str_to_datetime_processor_factory(
+ self._reg, datetime.date
+ )
+ else:
+ return processors.str_to_date
+
+
+class TIME(_DateTimeMixin, sqltypes.Time):
+ r"""Represent a Python time object in SQLite using a string.
+
+ The default string storage format is::
+
+ "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d"
+
+ e.g.::
+
+        12:05:57.105542
+
+ The storage format can be customized to some degree using the
+ ``storage_format`` and ``regexp`` parameters, such as::
+
+ import re
+ from sqlalchemy.dialects.sqlite import TIME
+
+        t = TIME(
+            storage_format="%(hour)02d-%(minute)02d-"
+            "%(second)02d-%(microsecond)06d",
+            regexp=re.compile(r"(\d+)-(\d+)-(\d+)(?:-(\d+))?"),
+        )
+
+ :param storage_format: format string which will be applied to the dict
+ with keys hour, minute, second, and microsecond.
+
+ :param regexp: regular expression which will be applied to incoming result
+ rows. If the regexp contains named groups, the resulting match dict is
+ applied to the Python time() constructor as keyword arguments. Otherwise,
+ if positional groups are used, the time() constructor is called with
+ positional arguments via ``*map(int, match_obj.groups(0))``.
+ """
+
+ _storage_format = "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d"
+
+ def __init__(self, *args, **kwargs):
+ truncate_microseconds = kwargs.pop("truncate_microseconds", False)
+ super(TIME, self).__init__(*args, **kwargs)
+ if truncate_microseconds:
+ assert "storage_format" not in kwargs, (
+ "You can specify only "
+ "one of truncate_microseconds or storage_format."
+ )
+ assert "regexp" not in kwargs, (
+ "You can specify only one of "
+ "truncate_microseconds or regexp."
+ )
+ self._storage_format = "%(hour)02d:%(minute)02d:%(second)02d"
+
+ def bind_processor(self, dialect):
+ datetime_time = datetime.time
+ format_ = self._storage_format
+
+ def process(value):
+ if value is None:
+ return None
+ elif isinstance(value, datetime_time):
+ return format_ % {
+ "hour": value.hour,
+ "minute": value.minute,
+ "second": value.second,
+ "microsecond": value.microsecond,
+ }
+ else:
+ raise TypeError(
+ "SQLite Time type only accepts Python "
+ "time objects as input."
+ )
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if self._reg:
+ return processors.str_to_datetime_processor_factory(
+ self._reg, datetime.time
+ )
+ else:
+ return processors.str_to_time
+
+
+colspecs = {
+ sqltypes.Date: DATE,
+ sqltypes.DateTime: DATETIME,
+ sqltypes.JSON: _SQliteJson,
+ sqltypes.JSON.JSONIndexType: JSONIndexType,
+ sqltypes.JSON.JSONPathType: JSONPathType,
+ sqltypes.Time: TIME,
+}
+
+ischema_names = {
+ "BIGINT": sqltypes.BIGINT,
+ "BLOB": sqltypes.BLOB,
+ "BOOL": sqltypes.BOOLEAN,
+ "BOOLEAN": sqltypes.BOOLEAN,
+ "CHAR": sqltypes.CHAR,
+ "DATE": sqltypes.DATE,
+ "DATE_CHAR": sqltypes.DATE,
+ "DATETIME": sqltypes.DATETIME,
+ "DATETIME_CHAR": sqltypes.DATETIME,
+ "DOUBLE": sqltypes.FLOAT,
+ "DECIMAL": sqltypes.DECIMAL,
+ "FLOAT": sqltypes.FLOAT,
+ "INT": sqltypes.INTEGER,
+ "INTEGER": sqltypes.INTEGER,
+ "JSON": JSON,
+ "NUMERIC": sqltypes.NUMERIC,
+ "REAL": sqltypes.REAL,
+ "SMALLINT": sqltypes.SMALLINT,
+ "TEXT": sqltypes.TEXT,
+ "TIME": sqltypes.TIME,
+ "TIME_CHAR": sqltypes.TIME,
+ "TIMESTAMP": sqltypes.TIMESTAMP,
+ "VARCHAR": sqltypes.VARCHAR,
+ "NVARCHAR": sqltypes.NVARCHAR,
+ "NCHAR": sqltypes.NCHAR,
+}
+
+
+class SQLiteCompiler(compiler.SQLCompiler):
+ extract_map = util.update_copy(
+ compiler.SQLCompiler.extract_map,
+ {
+ "month": "%m",
+ "day": "%d",
+ "year": "%Y",
+ "second": "%S",
+ "hour": "%H",
+ "doy": "%j",
+ "minute": "%M",
+ "epoch": "%s",
+ "dow": "%w",
+ "week": "%W",
+ },
+ )
+
+ def visit_now_func(self, fn, **kw):
+ return "CURRENT_TIMESTAMP"
+
+ def visit_localtimestamp_func(self, func, **kw):
+ return 'DATETIME(CURRENT_TIMESTAMP, "localtime")'
+
+ def visit_true(self, expr, **kw):
+ return "1"
+
+ def visit_false(self, expr, **kw):
+ return "0"
+
+ def visit_char_length_func(self, fn, **kw):
+ return "length%s" % self.function_argspec(fn)
+
+ def visit_cast(self, cast, **kwargs):
+ if self.dialect.supports_cast:
+ return super(SQLiteCompiler, self).visit_cast(cast, **kwargs)
+ else:
+ return self.process(cast.clause, **kwargs)
+
+ def visit_extract(self, extract, **kw):
+ try:
+ return "CAST(STRFTIME('%s', %s) AS INTEGER)" % (
+ self.extract_map[extract.field],
+ self.process(extract.expr, **kw),
+ )
+ except KeyError as err:
+ util.raise_(
+ exc.CompileError(
+ "%s is not a valid extract argument." % extract.field
+ ),
+ replace_context=err,
+ )
+
+ def limit_clause(self, select, **kw):
+ text = ""
+ if select._limit_clause is not None:
+ text += "\n LIMIT " + self.process(select._limit_clause, **kw)
+ if select._offset_clause is not None:
+ if select._limit_clause is None:
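+                # SQLite requires a LIMIT clause in order to use OFFSET;
+                # LIMIT -1 means "no limit"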
+ text += "\n LIMIT " + self.process(sql.literal(-1))
+ text += " OFFSET " + self.process(select._offset_clause, **kw)
+ else:
+ text += " OFFSET " + self.process(sql.literal(0), **kw)
+ return text
+
+ def for_update_clause(self, select, **kw):
+ # sqlite has no "FOR UPDATE" AFAICT
+ return ""
+
+ def visit_is_distinct_from_binary(self, binary, operator, **kw):
+ return "%s IS NOT %s" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def visit_is_not_distinct_from_binary(self, binary, operator, **kw):
+ return "%s IS %s" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def visit_json_getitem_op_binary(self, binary, operator, **kw):
+ if binary.type._type_affinity is sqltypes.JSON:
+ expr = "JSON_QUOTE(JSON_EXTRACT(%s, %s))"
+ else:
+ expr = "JSON_EXTRACT(%s, %s)"
+
+ return expr % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+ if binary.type._type_affinity is sqltypes.JSON:
+ expr = "JSON_QUOTE(JSON_EXTRACT(%s, %s))"
+ else:
+ expr = "JSON_EXTRACT(%s, %s)"
+
+ return expr % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_empty_set_op_expr(self, type_, expand_op):
+ # slightly old SQLite versions don't seem to be able to handle
+ # the empty set impl
+ return self.visit_empty_set_expr(type_)
+
+ def visit_empty_set_expr(self, element_types):
+ return "SELECT %s FROM (SELECT %s) WHERE 1!=1" % (
+ ", ".join("1" for type_ in element_types or [INTEGER()]),
+ ", ".join("1" for type_ in element_types or [INTEGER()]),
+ )
+
+ def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._generate_generic_binary(binary, " REGEXP ", **kw)
+
+ def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._generate_generic_binary(binary, " NOT REGEXP ", **kw)
+
+ def _on_conflict_target(self, clause, **kw):
+ if clause.constraint_target is not None:
+ target_text = "(%s)" % clause.constraint_target
+ elif clause.inferred_target_elements is not None:
+ target_text = "(%s)" % ", ".join(
+ (
+ self.preparer.quote(c)
+ if isinstance(c, util.string_types)
+ else self.process(c, include_table=False, use_schema=False)
+ )
+ for c in clause.inferred_target_elements
+ )
+ if clause.inferred_target_whereclause is not None:
+ target_text += " WHERE %s" % self.process(
+ clause.inferred_target_whereclause,
+ include_table=False,
+ use_schema=False,
+ literal_binds=True,
+ )
+
+ else:
+ target_text = ""
+
+ return target_text
+
+ def visit_on_conflict_do_nothing(self, on_conflict, **kw):
+
+ target_text = self._on_conflict_target(on_conflict, **kw)
+
+ if target_text:
+ return "ON CONFLICT %s DO NOTHING" % target_text
+ else:
+ return "ON CONFLICT DO NOTHING"
+
+ def visit_on_conflict_do_update(self, on_conflict, **kw):
+ clause = on_conflict
+
+ target_text = self._on_conflict_target(on_conflict, **kw)
+
+ action_set_ops = []
+
+ set_parameters = dict(clause.update_values_to_set)
+ # create a list of column assignment clauses as tuples
+
+ insert_statement = self.stack[-1]["selectable"]
+ cols = insert_statement.table.c
+ for c in cols:
+ col_key = c.key
+
+ if col_key in set_parameters:
+ value = set_parameters.pop(col_key)
+ elif c in set_parameters:
+ value = set_parameters.pop(c)
+ else:
+ continue
+
+ if coercions._is_literal(value):
+ value = elements.BindParameter(None, value, type_=c.type)
+
+ else:
+ if (
+ isinstance(value, elements.BindParameter)
+ and value.type._isnull
+ ):
+ value = value._clone()
+ value.type = c.type
+ value_text = self.process(value.self_group(), use_schema=False)
+
+ key_text = self.preparer.quote(c.name)
+ action_set_ops.append("%s = %s" % (key_text, value_text))
+
+ # check for names that don't match columns
+ if set_parameters:
+ util.warn(
+ "Additional column names not matching "
+ "any column keys in table '%s': %s"
+ % (
+ self.current_executable.table.name,
+ (", ".join("'%s'" % c for c in set_parameters)),
+ )
+ )
+ for k, v in set_parameters.items():
+ key_text = (
+ self.preparer.quote(k)
+ if isinstance(k, util.string_types)
+ else self.process(k, use_schema=False)
+ )
+ value_text = self.process(
+ coercions.expect(roles.ExpressionElementRole, v),
+ use_schema=False,
+ )
+ action_set_ops.append("%s = %s" % (key_text, value_text))
+
+ action_text = ", ".join(action_set_ops)
+ if clause.update_whereclause is not None:
+ action_text += " WHERE %s" % self.process(
+ clause.update_whereclause, include_table=True, use_schema=False
+ )
+
+ return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text)
+
+
+class SQLiteDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+
+ coltype = self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ )
+ colspec = self.preparer.format_column(column) + " " + coltype
+ default = self.get_column_default_string(column)
+ if default is not None:
+ if isinstance(column.server_default.arg, ColumnElement):
+ default = "(" + default + ")"
+ colspec += " DEFAULT " + default
+
+ if not column.nullable:
+ colspec += " NOT NULL"
+
+ on_conflict_clause = column.dialect_options["sqlite"][
+ "on_conflict_not_null"
+ ]
+ if on_conflict_clause is not None:
+ colspec += " ON CONFLICT " + on_conflict_clause
+
+ if column.primary_key:
+ if (
+ column.autoincrement is True
+ and len(column.table.primary_key.columns) != 1
+ ):
+ raise exc.CompileError(
+ "SQLite does not support autoincrement for "
+ "composite primary keys"
+ )
+
+ if (
+ column.table.dialect_options["sqlite"]["autoincrement"]
+ and len(column.table.primary_key.columns) == 1
+ and issubclass(column.type._type_affinity, sqltypes.Integer)
+ and not column.foreign_keys
+ ):
+ colspec += " PRIMARY KEY"
+
+ on_conflict_clause = column.dialect_options["sqlite"][
+ "on_conflict_primary_key"
+ ]
+ if on_conflict_clause is not None:
+ colspec += " ON CONFLICT " + on_conflict_clause
+
+ colspec += " AUTOINCREMENT"
+
+ if column.computed is not None:
+ colspec += " " + self.process(column.computed)
+
+ return colspec
+
+ def visit_primary_key_constraint(self, constraint):
+ # for columns with sqlite_autoincrement=True,
+ # the PRIMARY KEY constraint can only be inline
+ # with the column itself.
+ if len(constraint.columns) == 1:
+ c = list(constraint)[0]
+ if (
+ c.primary_key
+ and c.table.dialect_options["sqlite"]["autoincrement"]
+ and issubclass(c.type._type_affinity, sqltypes.Integer)
+ and not c.foreign_keys
+ ):
+ return None
+
+ text = super(SQLiteDDLCompiler, self).visit_primary_key_constraint(
+ constraint
+ )
+
+ on_conflict_clause = constraint.dialect_options["sqlite"][
+ "on_conflict"
+ ]
+ if on_conflict_clause is None and len(constraint.columns) == 1:
+ on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][
+ "on_conflict_primary_key"
+ ]
+
+ if on_conflict_clause is not None:
+ text += " ON CONFLICT " + on_conflict_clause
+
+ return text
+
+ def visit_unique_constraint(self, constraint):
+ text = super(SQLiteDDLCompiler, self).visit_unique_constraint(
+ constraint
+ )
+
+ on_conflict_clause = constraint.dialect_options["sqlite"][
+ "on_conflict"
+ ]
+ if on_conflict_clause is None and len(constraint.columns) == 1:
+ col1 = list(constraint)[0]
+ if isinstance(col1, schema.SchemaItem):
+ on_conflict_clause = list(constraint)[0].dialect_options[
+ "sqlite"
+ ]["on_conflict_unique"]
+
+ if on_conflict_clause is not None:
+ text += " ON CONFLICT " + on_conflict_clause
+
+ return text
+
+ def visit_check_constraint(self, constraint):
+ text = super(SQLiteDDLCompiler, self).visit_check_constraint(
+ constraint
+ )
+
+ on_conflict_clause = constraint.dialect_options["sqlite"][
+ "on_conflict"
+ ]
+
+ if on_conflict_clause is not None:
+ text += " ON CONFLICT " + on_conflict_clause
+
+ return text
+
+ def visit_column_check_constraint(self, constraint):
+ text = super(SQLiteDDLCompiler, self).visit_column_check_constraint(
+ constraint
+ )
+
+ if constraint.dialect_options["sqlite"]["on_conflict"] is not None:
+ raise exc.CompileError(
+ "SQLite does not support on conflict clause for "
+ "column check constraint"
+ )
+
+ return text
+
+ def visit_foreign_key_constraint(self, constraint):
+
+ local_table = constraint.elements[0].parent.table
+ remote_table = constraint.elements[0].column.table
+
+ if local_table.schema != remote_table.schema:
+ return None
+ else:
+ return super(SQLiteDDLCompiler, self).visit_foreign_key_constraint(
+ constraint
+ )
+
+ def define_constraint_remote_table(self, constraint, table, preparer):
+ """Format the remote table clause of a CREATE CONSTRAINT clause."""
+
+ return preparer.format_table(table, use_schema=False)
+
+ def visit_create_index(
+ self, create, include_schema=False, include_table_schema=True
+ ):
+ index = create.element
+ self._verify_index_table(index)
+ preparer = self.preparer
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+
+ text += "INDEX "
+
+ if create.if_not_exists:
+ text += "IF NOT EXISTS "
+
+ text += "%s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=True),
+ preparer.format_table(index.table, use_schema=False),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
+
+ whereclause = index.dialect_options["sqlite"]["where"]
+ if whereclause is not None:
+ where_compiled = self.sql_compiler.process(
+ whereclause, include_table=False, literal_binds=True
+ )
+ text += " WHERE " + where_compiled
+
+ return text
+
+ def post_create_table(self, table):
+ if table.dialect_options["sqlite"]["with_rowid"] is False:
+ return "\n WITHOUT ROWID"
+ return ""
+
+
+class SQLiteTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_BLOB(type_)
+
+ def visit_DATETIME(self, type_, **kw):
+ if (
+ not isinstance(type_, _DateTimeMixin)
+ or type_.format_is_text_affinity
+ ):
+ return super(SQLiteTypeCompiler, self).visit_DATETIME(type_)
+ else:
+ return "DATETIME_CHAR"
+
+ def visit_DATE(self, type_, **kw):
+ if (
+ not isinstance(type_, _DateTimeMixin)
+ or type_.format_is_text_affinity
+ ):
+ return super(SQLiteTypeCompiler, self).visit_DATE(type_)
+ else:
+ return "DATE_CHAR"
+
+ def visit_TIME(self, type_, **kw):
+ if (
+ not isinstance(type_, _DateTimeMixin)
+ or type_.format_is_text_affinity
+ ):
+ return super(SQLiteTypeCompiler, self).visit_TIME(type_)
+ else:
+ return "TIME_CHAR"
+
+ def visit_JSON(self, type_, **kw):
+ # note this name provides NUMERIC affinity, not TEXT.
+ # should not be an issue unless the JSON value consists of a single
+ # numeric value. JSONTEXT can be used if this case is required.
+ return "JSON"
+
+
+class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
+ reserved_words = set(
+ [
+ "add",
+ "after",
+ "all",
+ "alter",
+ "analyze",
+ "and",
+ "as",
+ "asc",
+ "attach",
+ "autoincrement",
+ "before",
+ "begin",
+ "between",
+ "by",
+ "cascade",
+ "case",
+ "cast",
+ "check",
+ "collate",
+ "column",
+ "commit",
+ "conflict",
+ "constraint",
+ "create",
+ "cross",
+ "current_date",
+ "current_time",
+ "current_timestamp",
+ "database",
+ "default",
+ "deferrable",
+ "deferred",
+ "delete",
+ "desc",
+ "detach",
+ "distinct",
+ "drop",
+ "each",
+ "else",
+ "end",
+ "escape",
+ "except",
+ "exclusive",
+ "exists",
+ "explain",
+ "false",
+ "fail",
+ "for",
+ "foreign",
+ "from",
+ "full",
+ "glob",
+ "group",
+ "having",
+ "if",
+ "ignore",
+ "immediate",
+ "in",
+ "index",
+ "indexed",
+ "initially",
+ "inner",
+ "insert",
+ "instead",
+ "intersect",
+ "into",
+ "is",
+ "isnull",
+ "join",
+ "key",
+ "left",
+ "like",
+ "limit",
+ "match",
+ "natural",
+ "not",
+ "notnull",
+ "null",
+ "of",
+ "offset",
+ "on",
+ "or",
+ "order",
+ "outer",
+ "plan",
+ "pragma",
+ "primary",
+ "query",
+ "raise",
+ "references",
+ "reindex",
+ "rename",
+ "replace",
+ "restrict",
+ "right",
+ "rollback",
+ "row",
+ "select",
+ "set",
+ "table",
+ "temp",
+ "temporary",
+ "then",
+ "to",
+ "transaction",
+ "trigger",
+ "true",
+ "union",
+ "unique",
+ "update",
+ "using",
+ "vacuum",
+ "values",
+ "view",
+ "virtual",
+ "when",
+ "where",
+ ]
+ )
+
+
+class SQLiteExecutionContext(default.DefaultExecutionContext):
+ @util.memoized_property
+ def _preserve_raw_colnames(self):
+ return (
+ not self.dialect._broken_dotted_colnames
+ or self.execution_options.get("sqlite_raw_colnames", False)
+ )
+
+ def _translate_colname(self, colname):
+ # TODO: detect SQLite version 3.10.0 or greater;
+ # see [ticket:3633]
+
+ # adjust for dotted column names. SQLite
+ # in the case of UNION may store col names as
+ # "tablename.colname", or if using an attached database,
+ # "database.tablename.colname", in cursor.description
+ if not self._preserve_raw_colnames and "." in colname:
+ return colname.split(".")[-1], colname
+ else:
+ return colname, None
+
+
+class SQLiteDialect(default.DefaultDialect):
+ name = "sqlite"
+ supports_alter = False
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+
+    # SQLite supports "DEFAULT VALUES" but *does not* support
+ # "VALUES (DEFAULT)"
+ supports_default_values = True
+ supports_default_metavalue = False
+
+ supports_empty_insert = False
+ supports_cast = True
+ supports_multivalues_insert = True
+ tuple_in_values = True
+ supports_statement_cache = True
+
+ default_paramstyle = "qmark"
+ execution_ctx_cls = SQLiteExecutionContext
+ statement_compiler = SQLiteCompiler
+ ddl_compiler = SQLiteDDLCompiler
+ type_compiler = SQLiteTypeCompiler
+ preparer = SQLiteIdentifierPreparer
+ ischema_names = ischema_names
+ colspecs = colspecs
+ isolation_level = None
+
+ construct_arguments = [
+ (
+ sa_schema.Table,
+ {
+ "autoincrement": False,
+ "with_rowid": True,
+ },
+ ),
+ (sa_schema.Index, {"where": None}),
+ (
+ sa_schema.Column,
+ {
+ "on_conflict_primary_key": None,
+ "on_conflict_not_null": None,
+ "on_conflict_unique": None,
+ },
+ ),
+ (sa_schema.Constraint, {"on_conflict": None}),
+ ]
+
+ _broken_fk_pragma_quotes = False
+ _broken_dotted_colnames = False
+
+ @util.deprecated_params(
+ _json_serializer=(
+ "1.3.7",
+ "The _json_serializer argument to the SQLite dialect has "
+ "been renamed to the correct name of json_serializer. The old "
+ "argument name will be removed in a future release.",
+ ),
+ _json_deserializer=(
+ "1.3.7",
+ "The _json_deserializer argument to the SQLite dialect has "
+ "been renamed to the correct name of json_deserializer. The old "
+ "argument name will be removed in a future release.",
+ ),
+ )
+ def __init__(
+ self,
+ isolation_level=None,
+ native_datetime=False,
+ json_serializer=None,
+ json_deserializer=None,
+ _json_serializer=None,
+ _json_deserializer=None,
+ **kwargs
+ ):
+ default.DefaultDialect.__init__(self, **kwargs)
+ self.isolation_level = isolation_level
+
+ if _json_serializer:
+ json_serializer = _json_serializer
+ if _json_deserializer:
+ json_deserializer = _json_deserializer
+ self._json_serializer = json_serializer
+ self._json_deserializer = json_deserializer
+
+        # this flag is used by the pysqlite dialect, and perhaps others in the
+ # future, to indicate the driver is handling date/timestamp
+ # conversions (and perhaps datetime/time as well on some hypothetical
+ # driver ?)
+ self.native_datetime = native_datetime
+
+ if self.dbapi is not None:
+ if self.dbapi.sqlite_version_info < (3, 7, 16):
+ util.warn(
+ "SQLite version %s is older than 3.7.16, and will not "
+ "support right nested joins, as are sometimes used in "
+ "more complex ORM scenarios. SQLAlchemy 1.4 and above "
+ "no longer tries to rewrite these joins."
+ % (self.dbapi.sqlite_version_info,)
+ )
+
+ self._broken_dotted_colnames = self.dbapi.sqlite_version_info < (
+ 3,
+ 10,
+ 0,
+ )
+ self.supports_default_values = self.dbapi.sqlite_version_info >= (
+ 3,
+ 3,
+ 8,
+ )
+ self.supports_cast = self.dbapi.sqlite_version_info >= (3, 2, 3)
+ self.supports_multivalues_insert = (
+ # https://www.sqlite.org/releaselog/3_7_11.html
+ self.dbapi.sqlite_version_info
+ >= (3, 7, 11)
+ )
+ # see https://www.sqlalchemy.org/trac/ticket/2568
+ # as well as https://www.sqlite.org/src/info/600482d161
+ self._broken_fk_pragma_quotes = self.dbapi.sqlite_version_info < (
+ 3,
+ 6,
+ 14,
+ )
+
+ _isolation_lookup = util.immutabledict(
+ {"READ UNCOMMITTED": 1, "SERIALIZABLE": 0}
+ )
+
+ def set_isolation_level(self, connection, level):
+ try:
+ isolation_level = self._isolation_lookup[level.replace("_", " ")]
+ except KeyError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s"
+ % (
+ level,
+ self.name,
+ ", ".join(self._isolation_lookup),
+ )
+ ),
+ replace_context=err,
+ )
+ cursor = connection.cursor()
+ cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level)
+ cursor.close()
+
+ def get_isolation_level(self, connection):
+ cursor = connection.cursor()
+ cursor.execute("PRAGMA read_uncommitted")
+ res = cursor.fetchone()
+ if res:
+ value = res[0]
+ else:
+ # https://www.sqlite.org/changes.html#version_3_3_3
+ # "Optional READ UNCOMMITTED isolation (instead of the
+ # default isolation level of SERIALIZABLE) and
+ # table level locking when database connections
+            # share a common cache."
+            # pre-SQLite 3.3.0 defaults to 0
+ value = 0
+ cursor.close()
+ if value == 0:
+ return "SERIALIZABLE"
+ elif value == 1:
+ return "READ UNCOMMITTED"
+ else:
+ assert False, "Unknown isolation level %s" % value
+
+ def on_connect(self):
+ if self.isolation_level is not None:
+
+ def connect(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ return connect
+ else:
+ return None
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ s = "PRAGMA database_list"
+ dl = connection.exec_driver_sql(s)
+
+ return [db[1] for db in dl if db[1] != "temp"]
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ if schema is not None:
+ qschema = self.identifier_preparer.quote_identifier(schema)
+ master = "%s.sqlite_master" % qschema
+ else:
+ master = "sqlite_master"
+ s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % (
+ master,
+ )
+ rs = connection.exec_driver_sql(s)
+ return [row[0] for row in rs]
+
+ @reflection.cache
+ def get_temp_table_names(self, connection, **kw):
+ s = (
+ "SELECT name FROM sqlite_temp_master "
+ "WHERE type='table' ORDER BY name "
+ )
+ rs = connection.exec_driver_sql(s)
+
+ return [row[0] for row in rs]
+
+ @reflection.cache
+ def get_temp_view_names(self, connection, **kw):
+ s = (
+ "SELECT name FROM sqlite_temp_master "
+ "WHERE type='view' ORDER BY name "
+ )
+ rs = connection.exec_driver_sql(s)
+
+ return [row[0] for row in rs]
+
+ def has_table(self, connection, table_name, schema=None):
+ self._ensure_has_table_connection(connection)
+
+ info = self._get_table_pragma(
+ connection, "table_info", table_name, schema=schema
+ )
+ return bool(info)
+
+ def _get_default_schema_name(self, connection):
+ return "main"
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ if schema is not None:
+ qschema = self.identifier_preparer.quote_identifier(schema)
+ master = "%s.sqlite_master" % qschema
+ else:
+ master = "sqlite_master"
+ s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % (
+ master,
+ )
+ rs = connection.exec_driver_sql(s)
+
+ return [row[0] for row in rs]
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ if schema is not None:
+ qschema = self.identifier_preparer.quote_identifier(schema)
+ master = "%s.sqlite_master" % qschema
+ s = ("SELECT sql FROM %s WHERE name = ? AND type='view'") % (
+ master,
+ )
+ rs = connection.exec_driver_sql(s, (view_name,))
+ else:
+ try:
+ s = (
+ "SELECT sql FROM "
+ " (SELECT * FROM sqlite_master UNION ALL "
+ " SELECT * FROM sqlite_temp_master) "
+ "WHERE name = ? "
+ "AND type='view'"
+ )
+ rs = connection.exec_driver_sql(s, (view_name,))
+ except exc.DBAPIError:
+ s = (
+ "SELECT sql FROM sqlite_master WHERE name = ? "
+ "AND type='view'"
+ )
+ rs = connection.exec_driver_sql(s, (view_name,))
+
+ result = rs.fetchall()
+ if result:
+ return result[0].sql
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ pragma = "table_info"
+        # computed columns are treated as hidden; they require table_xinfo
+ if self.server_version_info >= (3, 31):
+ pragma = "table_xinfo"
+ info = self._get_table_pragma(
+ connection, pragma, table_name, schema=schema
+ )
+ columns = []
+ tablesql = None
+ for row in info:
+ name = row[1]
+ type_ = row[2].upper()
+ nullable = not row[3]
+ default = row[4]
+ primary_key = row[5]
+ hidden = row[6] if pragma == "table_xinfo" else 0
+
+ # hidden has value 0 for normal columns, 1 for hidden columns,
+ # 2 for computed virtual columns and 3 for computed stored columns
+ # https://www.sqlite.org/src/info/069351b85f9a706f60d3e98fbc8aaf40c374356b967c0464aede30ead3d9d18b
+ if hidden == 1:
+ continue
+
+ generated = bool(hidden)
+ persisted = hidden == 3
+
+ if tablesql is None and generated:
+ tablesql = self._get_table_sql(
+ connection, table_name, schema, **kw
+ )
+
+ columns.append(
+ self._get_column_info(
+ name,
+ type_,
+ nullable,
+ default,
+ primary_key,
+ generated,
+ persisted,
+ tablesql,
+ )
+ )
+ return columns
+
+ def _get_column_info(
+ self,
+ name,
+ type_,
+ nullable,
+ default,
+ primary_key,
+ generated,
+ persisted,
+ tablesql,
+ ):
+
+ if generated:
+ # the type of a column "cc INTEGER GENERATED ALWAYS AS (1 + 42)"
+ # somehow is "INTEGER GENERATED ALWAYS"
+ type_ = re.sub("generated", "", type_, flags=re.IGNORECASE)
+ type_ = re.sub("always", "", type_, flags=re.IGNORECASE).strip()
+
+ coltype = self._resolve_type_affinity(type_)
+
+ if default is not None:
+ default = util.text_type(default)
+
+ colspec = {
+ "name": name,
+ "type": coltype,
+ "nullable": nullable,
+ "default": default,
+ "autoincrement": "auto",
+ "primary_key": primary_key,
+ }
+ if generated:
+ sqltext = ""
+ if tablesql:
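+                # pull the generation expression out of the original CREATE
+                # TABLE text; e.g. (illustrative) for a column defined as
+                # "cc INTEGER GENERATED ALWAYS AS (x + 42) VIRTUAL", the
+                # captured sqltext would be "x + 42"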
+ pattern = r"[^,]*\s+AS\s+\(([^,]*)\)\s*(?:virtual|stored)?"
+ match = re.search(
+ re.escape(name) + pattern, tablesql, re.IGNORECASE
+ )
+ if match:
+ sqltext = match.group(1)
+ colspec["computed"] = {"sqltext": sqltext, "persisted": persisted}
+ return colspec
+
+ def _resolve_type_affinity(self, type_):
+ """Return a data type from a reflected column, using affinity rules.
+
+ SQLite's goal for universal compatibility introduces some complexity
+ during reflection, as a column's defined type might not actually be a
+        type that SQLite understands - or indeed may not be defined *at all*.
+ Internally, SQLite handles this with a 'data type affinity' for each
+ column definition, mapping to one of 'TEXT', 'NUMERIC', 'INTEGER',
+ 'REAL', or 'NONE' (raw bits). The algorithm that determines this is
+ listed in https://www.sqlite.org/datatype3.html section 2.1.
+
+ This method allows SQLAlchemy to support that algorithm, while still
+ providing access to smarter reflection utilities by recognizing
+ column definitions that SQLite only supports through affinity (like
+ DATE and DOUBLE).
+
+ """
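+        # illustrative examples of the affinity fallback (not exhaustive):
+        #   "BIGINT UNSIGNED"    -> INTEGER  (contains "INT")
+        #   "NVARCHAR2(50)"      -> TEXT     (contains "CHAR")
+        #   "" (no type given)   -> NullType
+        #   "XYZZY"              -> NUMERIC  (no other rule matched)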
+ match = re.match(r"([\w ]+)(\(.*?\))?", type_)
+ if match:
+ coltype = match.group(1)
+ args = match.group(2)
+ else:
+ coltype = ""
+ args = ""
+
+ if coltype in self.ischema_names:
+ coltype = self.ischema_names[coltype]
+ elif "INT" in coltype:
+ coltype = sqltypes.INTEGER
+ elif "CHAR" in coltype or "CLOB" in coltype or "TEXT" in coltype:
+ coltype = sqltypes.TEXT
+ elif "BLOB" in coltype or not coltype:
+ coltype = sqltypes.NullType
+ elif "REAL" in coltype or "FLOA" in coltype or "DOUB" in coltype:
+ coltype = sqltypes.REAL
+ else:
+ coltype = sqltypes.NUMERIC
+
+ if args is not None:
+ args = re.findall(r"(\d+)", args)
+ try:
+ coltype = coltype(*[int(a) for a in args])
+ except TypeError:
+ util.warn(
+ "Could not instantiate type %s with "
+ "reflected arguments %s; using no arguments."
+ % (coltype, args)
+ )
+ coltype = coltype()
+ else:
+ coltype = coltype()
+
+ return coltype
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ constraint_name = None
+ table_data = self._get_table_sql(connection, table_name, schema=schema)
+ if table_data:
+ PK_PATTERN = r"CONSTRAINT (\w+) PRIMARY KEY"
+ result = re.search(PK_PATTERN, table_data, re.I)
+ constraint_name = result.group(1) if result else None
+
+ cols = self.get_columns(connection, table_name, schema, **kw)
+ cols.sort(key=lambda col: col.get("primary_key"))
+ pkeys = []
+ for col in cols:
+ if col["primary_key"]:
+ pkeys.append(col["name"])
+
+ return {"constrained_columns": pkeys, "name": constraint_name}
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ # sqlite makes this *extremely difficult*.
+ # First, use the pragma to get the actual FKs.
+ pragma_fks = self._get_table_pragma(
+ connection, "foreign_key_list", table_name, schema=schema
+ )
+
+ fks = {}
+
+ for row in pragma_fks:
+ (numerical_id, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4])
+
+ if not rcol:
+ # no referred column, which means it was not named in the
+ # original DDL. The referred columns of the foreign key
+ # constraint are therefore the primary key of the referred
+ # table.
+ referred_pk = self.get_pk_constraint(
+ connection, rtbl, schema=schema, **kw
+ )
+                # note that if the table doesn't exist, we still get back a
+                # record; it just has no columns in it
+ referred_columns = referred_pk["constrained_columns"]
+ else:
+ # note we use this list only if this is the first column
+ # in the constraint. for subsequent columns we ignore the
+ # list and append "rcol" if present.
+ referred_columns = []
+
+ if self._broken_fk_pragma_quotes:
+ rtbl = re.sub(r"^[\"\[`\']|[\"\]`\']$", "", rtbl)
+
+ if numerical_id in fks:
+ fk = fks[numerical_id]
+ else:
+ fk = fks[numerical_id] = {
+ "name": None,
+ "constrained_columns": [],
+ "referred_schema": schema,
+ "referred_table": rtbl,
+ "referred_columns": referred_columns,
+ "options": {},
+ }
+ fks[numerical_id] = fk
+
+ fk["constrained_columns"].append(lcol)
+
+ if rcol:
+ fk["referred_columns"].append(rcol)
+
+ def fk_sig(constrained_columns, referred_table, referred_columns):
+ return (
+ tuple(constrained_columns)
+ + (referred_table,)
+ + tuple(referred_columns)
+ )
+
+ # then, parse the actual SQL and attempt to find DDL that matches
+ # the names as well. SQLite saves the DDL in whatever format
+        # it was typed in as, so we need to be liberal here.
+
+ keys_by_signature = dict(
+ (
+ fk_sig(
+ fk["constrained_columns"],
+ fk["referred_table"],
+ fk["referred_columns"],
+ ),
+ fk,
+ )
+ for fk in fks.values()
+ )
+
+ table_data = self._get_table_sql(connection, table_name, schema=schema)
+ if table_data is None:
+ # system tables, etc.
+ return []
+
+ def parse_fks():
+ FK_PATTERN = (
+ r"(?:CONSTRAINT (\w+) +)?"
+ r"FOREIGN KEY *\( *(.+?) *\) +"
+ r'REFERENCES +(?:(?:"(.+?)")|([a-z0-9_]+)) *\((.+?)\) *'
+ r"((?:ON (?:DELETE|UPDATE) "
+ r"(?:SET NULL|SET DEFAULT|CASCADE|RESTRICT|NO ACTION) *)*)"
+ )
+ for match in re.finditer(FK_PATTERN, table_data, re.I):
+ (
+ constraint_name,
+ constrained_columns,
+ referred_quoted_name,
+ referred_name,
+ referred_columns,
+ onupdatedelete,
+ ) = match.group(1, 2, 3, 4, 5, 6)
+ constrained_columns = list(
+ self._find_cols_in_sig(constrained_columns)
+ )
+ if not referred_columns:
+ referred_columns = constrained_columns
+ else:
+ referred_columns = list(
+ self._find_cols_in_sig(referred_columns)
+ )
+ referred_name = referred_quoted_name or referred_name
+ options = {}
+
+ for token in re.split(r" *\bON\b *", onupdatedelete.upper()):
+ if token.startswith("DELETE"):
+ ondelete = token[6:].strip()
+ if ondelete and ondelete != "NO ACTION":
+ options["ondelete"] = ondelete
+ elif token.startswith("UPDATE"):
+ onupdate = token[6:].strip()
+ if onupdate and onupdate != "NO ACTION":
+ options["onupdate"] = onupdate
+ yield (
+ constraint_name,
+ constrained_columns,
+ referred_name,
+ referred_columns,
+ options,
+ )
+
+ fkeys = []
+
+ for (
+ constraint_name,
+ constrained_columns,
+ referred_name,
+ referred_columns,
+ options,
+ ) in parse_fks():
+ sig = fk_sig(constrained_columns, referred_name, referred_columns)
+ if sig not in keys_by_signature:
+ util.warn(
+ "WARNING: SQL-parsed foreign key constraint "
+ "'%s' could not be located in PRAGMA "
+ "foreign_keys for table %s" % (sig, table_name)
+ )
+ continue
+ key = keys_by_signature.pop(sig)
+ key["name"] = constraint_name
+ key["options"] = options
+ fkeys.append(key)
+ # assume the remainders are the unnamed, inline constraints, just
+ # use them as is as it's extremely difficult to parse inline
+ # constraints
+ fkeys.extend(keys_by_signature.values())
+ return fkeys
+
+ def _find_cols_in_sig(self, sig):
+ for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I):
+ yield match.group(1) or match.group(2)
+
+ @reflection.cache
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
+
+ auto_index_by_sig = {}
+ for idx in self.get_indexes(
+ connection,
+ table_name,
+ schema=schema,
+ include_auto_indexes=True,
+ **kw
+ ):
+ if not idx["name"].startswith("sqlite_autoindex"):
+ continue
+ sig = tuple(idx["column_names"])
+ auto_index_by_sig[sig] = idx
+
+ table_data = self._get_table_sql(
+ connection, table_name, schema=schema, **kw
+ )
+ if not table_data:
+ return []
+
+ unique_constraints = []
+
+ def parse_uqs():
+ UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)'
+ INLINE_UNIQUE_PATTERN = (
+ r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) '
+ r"+[a-z0-9_ ]+? +UNIQUE"
+ )
+
+ for match in re.finditer(UNIQUE_PATTERN, table_data, re.I):
+ name, cols = match.group(1, 2)
+ yield name, list(self._find_cols_in_sig(cols))
+
+ # we need to match inlines as well, as we seek to differentiate
+ # a UNIQUE constraint from a UNIQUE INDEX, even though these
+ # are kind of the same thing :)
+ for match in re.finditer(INLINE_UNIQUE_PATTERN, table_data, re.I):
+ cols = list(
+ self._find_cols_in_sig(match.group(1) or match.group(2))
+ )
+ yield None, cols
+
+ for name, cols in parse_uqs():
+ sig = tuple(cols)
+ if sig in auto_index_by_sig:
+ auto_index_by_sig.pop(sig)
+ parsed_constraint = {"name": name, "column_names": cols}
+ unique_constraints.append(parsed_constraint)
+ # NOTE: auto_index_by_sig might not be empty here,
+ # the PRIMARY KEY may have an entry.
+ return unique_constraints
+
+ @reflection.cache
+ def get_check_constraints(self, connection, table_name, schema=None, **kw):
+ table_data = self._get_table_sql(
+ connection, table_name, schema=schema, **kw
+ )
+ if not table_data:
+ return []
+
+ CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *"
+ check_constraints = []
+ # NOTE: we aren't using re.S here because we actually are
+ # taking advantage of each CHECK constraint being all on one
+ # line in the table definition in order to delineate. This
+ # necessarily makes assumptions as to how the CREATE TABLE
+ # was emitted.
+
+ for match in re.finditer(CHECK_PATTERN, table_data, re.I):
+ name = match.group(1)
+
+ if name:
+ name = re.sub(r'^"|"$', "", name)
+
+ check_constraints.append({"sqltext": match.group(2), "name": name})
+
+ return check_constraints
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ pragma_indexes = self._get_table_pragma(
+ connection, "index_list", table_name, schema=schema
+ )
+ indexes = []
+
+ include_auto_indexes = kw.pop("include_auto_indexes", False)
+ for row in pragma_indexes:
+ # ignore implicit primary key index.
+ # https://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html
+ if not include_auto_indexes and row[1].startswith(
+ "sqlite_autoindex"
+ ):
+ continue
+ indexes.append(dict(name=row[1], column_names=[], unique=row[2]))
+
+ # loop thru unique indexes to get the column names.
+ for idx in list(indexes):
+ pragma_index = self._get_table_pragma(
+ connection, "index_info", idx["name"]
+ )
+
+ for row in pragma_index:
+ if row[2] is None:
+ util.warn(
+ "Skipped unsupported reflection of "
+ "expression-based index %s" % idx["name"]
+ )
+ indexes.remove(idx)
+ break
+ else:
+ idx["column_names"].append(row[2])
+ return indexes
+
+ @reflection.cache
+ def _get_table_sql(self, connection, table_name, schema=None, **kw):
+ if schema:
+ schema_expr = "%s." % (
+ self.identifier_preparer.quote_identifier(schema)
+ )
+ else:
+ schema_expr = ""
+ try:
+ s = (
+ "SELECT sql FROM "
+ " (SELECT * FROM %(schema)ssqlite_master UNION ALL "
+ " SELECT * FROM %(schema)ssqlite_temp_master) "
+ "WHERE name = ? "
+ "AND type = 'table'" % {"schema": schema_expr}
+ )
+ rs = connection.exec_driver_sql(s, (table_name,))
+ except exc.DBAPIError:
+ s = (
+ "SELECT sql FROM %(schema)ssqlite_master "
+ "WHERE name = ? "
+ "AND type = 'table'" % {"schema": schema_expr}
+ )
+ rs = connection.exec_driver_sql(s, (table_name,))
+ return rs.scalar()
+
+ def _get_table_pragma(self, connection, pragma, table_name, schema=None):
+ quote = self.identifier_preparer.quote_identifier
+ if schema is not None:
+ statements = ["PRAGMA %s." % quote(schema)]
+ else:
+            # because PRAGMA looks in all attached databases if no schema is
+            # given, we need to specify the "main" schema; however, since we
+            # want 'temp' tables in the same namespace as 'main', we need to
+            # run the PRAGMA twice
+ statements = ["PRAGMA main.", "PRAGMA temp."]
+
+ qtable = quote(table_name)
+ for statement in statements:
+ statement = "%s%s(%s)" % (statement, pragma, qtable)
+ cursor = connection.exec_driver_sql(statement)
+ if not cursor._soft_closed:
+ # work around SQLite issue whereby cursor.description
+ # is blank when PRAGMA returns no rows:
+ # https://www.sqlite.org/cvstrac/tktview?tn=1884
+ result = cursor.fetchall()
+ else:
+ result = []
+ if result:
+ return result
+ else:
+ return []
diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py
new file mode 100644
index 0000000..b04a5e6
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/dml.py
@@ -0,0 +1,200 @@
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from ... import util
+from ...sql import coercions
+from ...sql import roles
+from ...sql.base import _exclusive_against
+from ...sql.base import _generative
+from ...sql.base import ColumnCollection
+from ...sql.dml import Insert as StandardInsert
+from ...sql.elements import ClauseElement
+from ...sql.expression import alias
+from ...util.langhelpers import public_factory
+
+
+__all__ = ("Insert", "insert")
+
+
+class Insert(StandardInsert):
+ """SQLite-specific implementation of INSERT.
+
+ Adds methods for SQLite-specific syntaxes such as ON CONFLICT.
+
+ The :class:`_sqlite.Insert` object is created using the
+ :func:`sqlalchemy.dialects.sqlite.insert` function.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`sqlite_on_conflict_insert`
+
+ """
+
+ stringify_dialect = "sqlite"
+ inherit_cache = False
+
+ @util.memoized_property
+ def excluded(self):
+ """Provide the ``excluded`` namespace for an ON CONFLICT statement
+
+ SQLite's ON CONFLICT clause allows reference to the row that would
+ be inserted, known as ``excluded``. This attribute provides
+ all columns in this row to be referenceable.
+
+ .. tip:: The :attr:`_sqlite.Insert.excluded` attribute is an instance
+ of :class:`_expression.ColumnCollection`, which provides an
+ interface the same as that of the :attr:`_schema.Table.c`
+ collection described at :ref:`metadata_tables_and_columns`.
+ With this collection, ordinary names are accessible like attributes
+ (e.g. ``stmt.excluded.some_column``), but special names and
+ dictionary method names should be accessed using indexed access,
+ such as ``stmt.excluded["column name"]`` or
+ ``stmt.excluded["values"]``. See the docstring for
+ :class:`_expression.ColumnCollection` for further examples.
+
+ """
+ return alias(self.table, name="excluded").columns
+
+ _on_conflict_exclusive = _exclusive_against(
+ "_post_values_clause",
+ msgs={
+ "_post_values_clause": "This Insert construct already has "
+ "an ON CONFLICT clause established"
+ },
+ )
+
+ @_generative
+ @_on_conflict_exclusive
+ def on_conflict_do_update(
+ self,
+ index_elements=None,
+ index_where=None,
+ set_=None,
+ where=None,
+ ):
+ r"""
+ Specifies a DO UPDATE SET action for ON CONFLICT clause.
+
+ :param index_elements:
+ A sequence consisting of string column names, :class:`_schema.Column`
+ objects, or other column expression objects that will be used
+ to infer a target index or unique constraint.
+
+ :param index_where:
+ Additional WHERE criterion that can be used to infer a
+ conditional target index.
+
+ :param set\_:
+ A dictionary or other mapping object
+ where the keys are either names of columns in the target table,
+ or :class:`_schema.Column` objects or other ORM-mapped columns
+ matching that of the target table, and expressions or literals
+ as values, specifying the ``SET`` actions to take.
+
+ .. versionadded:: 1.4 The
+ :paramref:`_sqlite.Insert.on_conflict_do_update.set_`
+ parameter supports :class:`_schema.Column` objects from the target
+ :class:`_schema.Table` as keys.
+
+ .. warning:: This dictionary does **not** take into account
+ Python-specified default UPDATE values or generation functions,
+ e.g. those specified using :paramref:`_schema.Column.onupdate`.
+ These values will not be exercised for an ON CONFLICT style of
+ UPDATE, unless they are manually specified in the
+ :paramref:`.Insert.on_conflict_do_update.set_` dictionary.
+
+ :param where:
+ Optional argument. If present, can be a literal SQL
+ string or an acceptable expression for a ``WHERE`` clause
+ that restricts the rows affected by ``DO UPDATE SET``. Rows
+ not meeting the ``WHERE`` condition will not be updated
+ (effectively a ``DO NOTHING`` for those rows).
+
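+        E.g. (an illustrative sketch only; it assumes a table ``my_table``
+        with a unique or primary key column ``id``, a column ``data``, and
+        an active connection ``conn``)::
+
+            from sqlalchemy.dialects.sqlite import insert
+
+            stmt = insert(my_table).values(id=1, data="inserted value")
+            stmt = stmt.on_conflict_do_update(
+                index_elements=[my_table.c.id],
+                set_=dict(data=stmt.excluded.data),
+            )
+            conn.execute(stmt)
+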
+ """
+
+ self._post_values_clause = OnConflictDoUpdate(
+ index_elements, index_where, set_, where
+ )
+
+ @_generative
+ @_on_conflict_exclusive
+ def on_conflict_do_nothing(self, index_elements=None, index_where=None):
+ """
+ Specifies a DO NOTHING action for ON CONFLICT clause.
+
+ :param index_elements:
+ A sequence consisting of string column names, :class:`_schema.Column`
+ objects, or other column expression objects that will be used
+ to infer a target index or unique constraint.
+
+ :param index_where:
+ Additional WHERE criterion that can be used to infer a
+ conditional target index.
+
+ """
+
+ self._post_values_clause = OnConflictDoNothing(
+ index_elements, index_where
+ )
+
+
+insert = public_factory(
+ Insert, ".dialects.sqlite.insert", ".dialects.sqlite.Insert"
+)
+
+
+class OnConflictClause(ClauseElement):
+ stringify_dialect = "sqlite"
+
+ def __init__(self, index_elements=None, index_where=None):
+
+ if index_elements is not None:
+ self.constraint_target = None
+ self.inferred_target_elements = index_elements
+ self.inferred_target_whereclause = index_where
+ else:
+ self.constraint_target = (
+ self.inferred_target_elements
+ ) = self.inferred_target_whereclause = None
+
+
+class OnConflictDoNothing(OnConflictClause):
+ __visit_name__ = "on_conflict_do_nothing"
+
+
+class OnConflictDoUpdate(OnConflictClause):
+ __visit_name__ = "on_conflict_do_update"
+
+ def __init__(
+ self,
+ index_elements=None,
+ index_where=None,
+ set_=None,
+ where=None,
+ ):
+ super(OnConflictDoUpdate, self).__init__(
+ index_elements=index_elements,
+ index_where=index_where,
+ )
+
+ if isinstance(set_, dict):
+ if not set_:
+ raise ValueError("set parameter dictionary must not be empty")
+ elif isinstance(set_, ColumnCollection):
+ set_ = dict(set_)
+ else:
+ raise ValueError(
+ "set parameter must be a non-empty dictionary "
+ "or a ColumnCollection such as the `.c.` collection "
+ "of a Table object"
+ )
+ self.update_values_to_set = [
+ (coercions.expect(roles.DMLColumnRole, key), value)
+ for key, value in set_.items()
+ ]
+ self.update_whereclause = where
diff --git a/lib/sqlalchemy/dialects/sqlite/json.py b/lib/sqlalchemy/dialects/sqlite/json.py
new file mode 100644
index 0000000..614f954
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/json.py
@@ -0,0 +1,84 @@
+from ... import types as sqltypes
+
+
+class JSON(sqltypes.JSON):
+ """SQLite JSON type.
+
+ SQLite supports JSON as of version 3.9 through its JSON1_ extension. Note
+ that JSON1_ is a
+ `loadable extension <https://www.sqlite.org/loadext.html>`_ and as such
+ may not be available, or may require run-time loading.
+
+ :class:`_sqlite.JSON` is used automatically whenever the base
+ :class:`_types.JSON` datatype is used against a SQLite backend.
+
+ .. seealso::
+
+ :class:`_types.JSON` - main documentation for the generic
+ cross-platform JSON datatype.
+
+ The :class:`_sqlite.JSON` type supports persistence of JSON values
+ as well as the core index operations provided by :class:`_types.JSON`
+ datatype, by adapting the operations to render the ``JSON_EXTRACT``
+ function wrapped in the ``JSON_QUOTE`` function at the database level.
+ Extracted values are quoted in order to ensure that the results are
+ always JSON string values.
+
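+    For example (an illustrative rendering only), an index expression such
+    as ``data_table.c.data["some key"]`` compiles to approximately::
+
+        JSON_QUOTE(JSON_EXTRACT(data_table.data, '$."some key"'))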
+
+ .. versionadded:: 1.3
+
+
+ .. _JSON1: https://www.sqlite.org/json1.html
+
+ """
+
+
+# Note: these objects currently match exactly those of MySQL; however, since
+# these are not generalizable to all JSON implementations, they remain
+# separately implemented for each dialect.
+class _FormatTypeMixin(object):
+ def _format_value(self, value):
+ raise NotImplementedError()
+
+ def bind_processor(self, dialect):
+ super_proc = self.string_bind_processor(dialect)
+
+ def process(value):
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+ def literal_processor(self, dialect):
+ super_proc = self.string_literal_processor(dialect)
+
+ def process(value):
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+
+class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
+ def _format_value(self, value):
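+        # integer indexes render as a JSON array path such as "$[5]";
+        # string keys render as an object path such as '$."some key"'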
+ if isinstance(value, int):
+ value = "$[%s]" % value
+ else:
+ value = '$."%s"' % value
+ return value
+
+
+class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
+ def _format_value(self, value):
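+        # e.g. a path of ("key1", 5, "key2") renders as the JSON path
+        # string '$."key1"[5]."key2"'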
+ return "$%s" % (
+ "".join(
+ [
+ "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
+ for elem in value
+ ]
+ )
+ )
diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py
new file mode 100644
index 0000000..e5b17e8
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/provision.py
@@ -0,0 +1,142 @@
+import os
+import re
+
+from ... import exc
+from ...engine import url as sa_url
+from ...testing.provision import create_db
+from ...testing.provision import drop_db
+from ...testing.provision import follower_url_from_main
+from ...testing.provision import generate_driver_url
+from ...testing.provision import log
+from ...testing.provision import post_configure_engine
+from ...testing.provision import run_reap_dbs
+from ...testing.provision import stop_test_class_outside_fixtures
+from ...testing.provision import temp_table_keyword_args
+
+
+# TODO: I can't get this to build dynamically with pytest-xdist procs
+_drivernames = {"pysqlite", "aiosqlite", "pysqlcipher"}
+
+
+@generate_driver_url.for_db("sqlite")
+def generate_driver_url(url, driver, query_str):
+ if driver == "pysqlcipher" and url.get_driver_name() != "pysqlcipher":
+ if url.database:
+ url = url.set(database=url.database + ".enc")
+ url = url.set(password="test")
+ url = url.set(drivername="sqlite+%s" % (driver,))
+ try:
+ url.get_dialect()
+ except exc.NoSuchModuleError:
+ return None
+ else:
+ return url
+
+
+@follower_url_from_main.for_db("sqlite")
+def _sqlite_follower_url_from_main(url, ident):
+ url = sa_url.make_url(url)
+
+ if not url.database or url.database == ":memory:":
+ return url
+ else:
+
+ m = re.match(r"(.+?)\.(.+)$", url.database)
+ name, ext = m.group(1, 2)
+ drivername = url.get_driver_name()
+ return sa_url.make_url(
+ "sqlite+%s:///%s_%s.%s" % (drivername, drivername, ident, ext)
+ )
+
+
+@post_configure_engine.for_db("sqlite")
+def _sqlite_post_configure_engine(url, engine, follower_ident):
+ from sqlalchemy import event
+
+ @event.listens_for(engine, "connect")
+ def connect(dbapi_connection, connection_record):
+        # use file DBs in all cases; memory acts kind of strangely
+        # as an attached database
+ if not follower_ident:
+ # note this test_schema.db gets created for all test runs.
+ # there's not any dedicated cleanup step for it. it in some
+ # ways corresponds to the "test.test_schema" schema that's
+ # expected to be already present, so for now it just stays
+ # in a given checkout directory.
+ dbapi_connection.execute(
+ 'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
+ % (engine.driver,)
+ )
+ else:
+ dbapi_connection.execute(
+ 'ATTACH DATABASE "%s_%s_test_schema.db" AS test_schema'
+ % (follower_ident, engine.driver)
+ )
+
+
+@create_db.for_db("sqlite")
+def _sqlite_create_db(cfg, eng, ident):
+ pass
+
+
+@drop_db.for_db("sqlite")
+def _sqlite_drop_db(cfg, eng, ident):
+ for path in [
+ "%s.db" % ident,
+ "%s_%s_test_schema.db" % (ident, eng.driver),
+ ]:
+ if os.path.exists(path):
+ log.info("deleting SQLite database file: %s" % path)
+ os.remove(path)
+
+
+@stop_test_class_outside_fixtures.for_db("sqlite")
+def stop_test_class_outside_fixtures(config, db, cls):
+ with db.connect() as conn:
+ files = [
+ row.file
+ for row in conn.exec_driver_sql("PRAGMA database_list")
+ if row.file
+ ]
+
+ if files:
+ db.dispose()
+ # some sqlite file tests are not cleaning up well yet, so do this
+ # just to make things simple for now
+ for file_ in files:
+ if file_ and os.path.exists(file_):
+ os.remove(file_)
+
+
+@temp_table_keyword_args.for_db("sqlite")
+def _sqlite_temp_table_keyword_args(cfg, eng):
+ return {"prefixes": ["TEMPORARY"]}
+
+
+@run_reap_dbs.for_db("sqlite")
+def _reap_sqlite_dbs(url, idents):
+ log.info("db reaper connecting to %r", url)
+
+ log.info("identifiers in file: %s", ", ".join(idents))
+ for ident in idents:
+ # we don't have a config so we can't call _sqlite_drop_db due to the
+ # decorator
+ for ext in ("db", "db.enc"):
+ for path in (
+ ["%s.%s" % (ident, ext)]
+ + [
+ "%s_%s.%s" % (drivername, ident, ext)
+ for drivername in _drivernames
+ ]
+ + [
+ "%s_test_schema.%s" % (drivername, ext)
+ for drivername in _drivernames
+ ]
+ + [
+ "%s_%s_test_schema.%s" % (ident, drivername, ext)
+ for drivername in _drivernames
+ ]
+ ):
+ if os.path.exists(path):
+ log.info("deleting SQLite database file: %s" % path)
+ os.remove(path)
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py
new file mode 100644
index 0000000..65f94c8
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py
@@ -0,0 +1,164 @@
+# sqlite/pysqlcipher.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: sqlite+pysqlcipher
+ :name: pysqlcipher
+ :dbapi: sqlcipher 3 or pysqlcipher
+ :connectstring: sqlite+pysqlcipher://:passphrase@/file_path[?kdf_iter=<iter>]
+
+ Dialect for support of DBAPIs that make use of the
+ `SQLCipher <https://www.zetetic.net/sqlcipher>`_ backend.
+
+
+Driver
+------
+
+Current dialect selection logic is:
+
+* If the :paramref:`_sa.create_engine.module` parameter supplies a DBAPI module,
+ that module is used.
+* Otherwise for Python 3, choose https://pypi.org/project/sqlcipher3/
+* If not available, fall back to https://pypi.org/project/pysqlcipher3/
+* For Python 2, https://pypi.org/project/pysqlcipher/ is used.
+
+.. warning:: The ``pysqlcipher3`` and ``pysqlcipher`` DBAPI drivers are no
+ longer maintained; the ``sqlcipher3`` driver as of this writing appears
+ to be current. For future compatibility, any pysqlcipher-compatible DBAPI
+ may be used as follows::
+
+ import sqlcipher_compatible_driver
+
+ from sqlalchemy import create_engine
+
+ e = create_engine(
+ "sqlite+pysqlcipher://:password@/dbname.db",
+ module=sqlcipher_compatible_driver
+ )
+
+These drivers make use of the SQLCipher engine. This system essentially
+introduces new PRAGMA commands to SQLite which allows the setting of a
+passphrase and other encryption parameters, allowing the database file to be
+encrypted.
+
+
+Connect Strings
+---------------
+
+The format of the connect string is in every way the same as that
+of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the
+"password" field is now accepted, which should contain a passphrase::
+
+ e = create_engine('sqlite+pysqlcipher://:testing@/foo.db')
+
+For an absolute file path, two leading slashes should be used for the
+database name::
+
+ e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db')
+
+A selection of additional encryption-related pragmas supported by SQLCipher
+as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed
+in the query string, and will result in that PRAGMA being called for each
+new connection. Currently, ``cipher``, ``kdf_iter``,
+``cipher_page_size`` and ``cipher_use_hmac`` are supported::
+
+ e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000')
+
+.. warning:: Previous versions of SQLAlchemy did not take into consideration
+   the encryption-related pragmas passed in the url string; they were
+   silently ignored. This may cause errors when opening files saved by a
+   previous SQLAlchemy version if the encryption options do not match.
+
+
+Pooling Behavior
+----------------
+
+The driver makes a change to the default pool behavior of pysqlite
+as described in :ref:`pysqlite_threading_pooling`. The pysqlcipher driver
+has been observed to be significantly slower on connection than the
+pysqlite driver, most likely due to the encryption overhead, so the
+dialect here defaults to using the :class:`.SingletonThreadPool`
+implementation,
+instead of the :class:`.NullPool` pool used by pysqlite. As always, the pool
+implementation is entirely configurable using the
+:paramref:`_sa.create_engine.poolclass` parameter; the
+:class:`.StaticPool` may
+be more feasible for single-threaded use, or :class:`.NullPool` may be used
+to prevent unencrypted connections from being held open for long periods of
+time, at the expense of slower startup time for new connections.
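+
+For example, a minimal sketch (illustrative only) that opts back into
+:class:`.NullPool` for a SQLCipher database file::
+
+    from sqlalchemy import create_engine
+    from sqlalchemy.pool import NullPool
+
+    e = create_engine(
+        "sqlite+pysqlcipher://:testing@/foo.db",
+        poolclass=NullPool,
+    )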
+
+
+""" # noqa
+
+from __future__ import absolute_import
+
+from .pysqlite import SQLiteDialect_pysqlite
+from ... import pool
+from ... import util
+
+
+class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite):
+ driver = "pysqlcipher"
+ supports_statement_cache = True
+
+ pragmas = ("kdf_iter", "cipher", "cipher_page_size", "cipher_use_hmac")
+
+ @classmethod
+ def dbapi(cls):
+ if util.py3k:
+ try:
+ import sqlcipher3 as sqlcipher
+ except ImportError:
+ pass
+ else:
+ return sqlcipher
+
+ from pysqlcipher3 import dbapi2 as sqlcipher
+
+ else:
+ from pysqlcipher import dbapi2 as sqlcipher
+
+ return sqlcipher
+
+ @classmethod
+ def get_pool_class(cls, url):
+ return pool.SingletonThreadPool
+
+ def on_connect_url(self, url):
+ super_on_connect = super(
+ SQLiteDialect_pysqlcipher, self
+ ).on_connect_url(url)
+
+ # pull the info we need from the URL early. Even though URL
+ # is immutable, we don't want any in-place changes to the URL
+ # to affect things
+ passphrase = url.password or ""
+ url_query = dict(url.query)
+
+ def on_connect(conn):
+ cursor = conn.cursor()
+ cursor.execute('pragma key="%s"' % passphrase)
+ for prag in self.pragmas:
+ value = url_query.get(prag, None)
+ if value is not None:
+ cursor.execute('pragma %s="%s"' % (prag, value))
+ cursor.close()
+
+ if super_on_connect:
+ super_on_connect(conn)
+
+ return on_connect
+
+ def create_connect_args(self, url):
+ plain_url = url._replace(password=None)
+ plain_url = plain_url.difference_update_query(self.pragmas)
+ return super(SQLiteDialect_pysqlcipher, self).create_connect_args(
+ plain_url
+ )
+
+
+dialect = SQLiteDialect_pysqlcipher
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
new file mode 100644
index 0000000..1aae561
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
@@ -0,0 +1,613 @@
+# sqlite/pysqlite.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+.. dialect:: sqlite+pysqlite
+ :name: pysqlite
+ :dbapi: sqlite3
+ :connectstring: sqlite+pysqlite:///file_path
+ :url: https://docs.python.org/library/sqlite3.html
+
+ Note that ``pysqlite`` is the same driver as the ``sqlite3``
+ module included with the Python distribution.
+
+Driver
+------
+
+The ``sqlite3`` Python DBAPI is standard on all modern Python versions;
+for cPython and Pypy, no additional installation is necessary.
+
+
+Connect Strings
+---------------
+
+The file specification for the SQLite database is taken as the "database"
+portion of the URL. Note that the format of a SQLAlchemy url is::
+
+ driver://user:pass@host/database
+
+This means that the actual filename to be used starts with the characters to
+the **right** of the third slash. So connecting to a relative filepath
+looks like::
+
+ # relative path
+ e = create_engine('sqlite:///path/to/database.db')
+
+An absolute path, which is denoted by starting with a slash, means you
+need **four** slashes::
+
+ # absolute path
+ e = create_engine('sqlite:////path/to/database.db')
+
+To use a Windows path, regular drive specifications and backslashes can be
+used. Double backslashes are probably needed::
+
+ # absolute path on Windows
+ e = create_engine('sqlite:///C:\\path\\to\\database.db')
+
+The sqlite ``:memory:`` identifier is the default if no filepath is
+present. Specify ``sqlite://`` and nothing else::
+
+ # in-memory database
+ e = create_engine('sqlite://')
+
+.. _pysqlite_uri_connections:
+
+URI Connections
+^^^^^^^^^^^^^^^
+
+Modern versions of SQLite support an alternative system of connecting using a
+`driver level URI <https://www.sqlite.org/uri.html>`_, which has the advantage
+that additional driver-level arguments can be passed including options such as
+"read only". The Python sqlite3 driver supports this mode under modern Python
+3 versions. The SQLAlchemy pysqlite driver supports this mode of use by
+specifying "uri=true" in the URL query string. The SQLite-level "URI" is kept
+as the "database" portion of the SQLAlchemy url (that is, following a slash)::
+
+ e = create_engine("sqlite:///file:path/to/database?mode=ro&uri=true")
+
+.. note:: The "uri=true" parameter must appear in the **query string**
+ of the URL. It will not currently work as expected if it is only
+ present in the :paramref:`_sa.create_engine.connect_args`
+ parameter dictionary.
+
+The logic reconciles the simultaneous presence of SQLAlchemy's query string and
+SQLite's query string by separating out the parameters that belong to the
+Python sqlite3 driver vs. those that belong to the SQLite URI. This is
+achieved through the use of a fixed list of parameters known to be accepted by
+the Python side of the driver. For example, to include a URL that indicates
+the Python sqlite3 "timeout" and "check_same_thread" parameters, along with the
+SQLite "mode" and "nolock" parameters, they can all be passed together on the
+query string::
+
+ e = create_engine(
+ "sqlite:///file:path/to/database?"
+ "check_same_thread=true&timeout=10&mode=ro&nolock=1&uri=true"
+ )
+
+Above, the pysqlite / sqlite3 DBAPI would be passed arguments as::
+
+ sqlite3.connect(
+ "file:path/to/database?mode=ro&nolock=1",
+ check_same_thread=True, timeout=10, uri=True
+ )
+
+Regarding future parameters added to either the Python or native drivers: new
+parameter names added to the SQLite URI scheme should be automatically
+accommodated by this scheme. New parameter names added to the Python driver
+side can be accommodated by specifying them in the
+:paramref:`_sa.create_engine.connect_args` dictionary,
+until dialect support is
+added by SQLAlchemy. For the less likely case that the native SQLite driver
+adds a new parameter name that overlaps with one of the existing, known Python
+driver parameters (such as "timeout" perhaps), SQLAlchemy's dialect would
+require adjustment for the URL scheme to continue to support this.
+
+As is always the case for all SQLAlchemy dialects, the entire "URL" process
+can be bypassed in :func:`_sa.create_engine` through the use of the
+:paramref:`_sa.create_engine.creator`
+parameter which allows for a custom callable
+that creates a Python sqlite3 driver level connection directly.
+
+.. versionadded:: 1.3.9
+
+.. seealso::
+
+ `Uniform Resource Identifiers <https://www.sqlite.org/uri.html>`_ - in
+ the SQLite documentation
+
+.. _pysqlite_regexp:
+
+Regular Expression Support
+---------------------------
+
+.. versionadded:: 1.4
+
+Support for the :meth:`_sql.ColumnOperators.regexp_match` operator is provided
+using Python's re.search_ function. SQLite itself does not include a working
+regular expression operator; instead, it includes a non-implemented placeholder
+operator ``REGEXP`` that calls a user-defined function that must be provided.
+
+SQLAlchemy's implementation makes use of the pysqlite create_function_ hook
+as follows::
+
+
+ def regexp(a, b):
+ return re.search(a, b) is not None
+
+ sqlite_connection.create_function(
+ "regexp", 2, regexp,
+ )
+
+There is currently no support for regular expression flags as a separate
+argument, as these are not supported by SQLite's REGEXP operator; however,
+they may be included inline within the regular expression string. See
+`Python regular expressions`_ for
+details.
+
+.. seealso::
+
+ `Python regular expressions`_: Documentation for Python's regular expression syntax.
+
+.. _create_function: https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function
+
+.. _re.search: https://docs.python.org/3/library/re.html#re.search
+
+.. _Python regular expressions: https://docs.python.org/3/library/re.html#re.search
+
+
+
+Compatibility with sqlite3 "native" date and datetime types
+-----------------------------------------------------------
+
+The pysqlite driver includes the sqlite3.PARSE_DECLTYPES and
+sqlite3.PARSE_COLNAMES options, which have the effect that any column
+or expression explicitly cast as "date" or "timestamp" will be converted
+to a Python date or datetime object. The date and datetime types provided
+with the pysqlite dialect are not currently compatible with these options,
+since they render the ISO date/datetime including microseconds, which
+pysqlite's driver does not. Additionally, SQLAlchemy does not at
+this time automatically render the "cast" syntax required for the
+freestanding functions "current_timestamp" and "current_date" to return
+datetime/date types natively. Unfortunately, pysqlite
+does not provide the standard DBAPI types in ``cursor.description``,
+leaving SQLAlchemy with no way to detect these types on the fly
+without expensive per-row type checks.
+
+Keeping in mind that pysqlite's parsing option is neither recommended nor
+necessary for use with SQLAlchemy, usage of PARSE_DECLTYPES can be forced by
+configuring ``native_datetime=True`` on :func:`_sa.create_engine`::
+
+    engine = create_engine(
+        "sqlite://",
+        connect_args={
+            "detect_types": sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
+        },
+        native_datetime=True,
+    )
+
+With this flag enabled, the DATE and TIMESTAMP types (but note - not the
+DATETIME or TIME types... confused yet?) will not perform any bind parameter
+or result processing. Execution of "func.current_date()" will return a string.
+"func.current_timestamp()" is registered as returning a DATETIME type in
+SQLAlchemy, so this function still receives SQLAlchemy-level result
+processing.
+
+.. _pysqlite_threading_pooling:
+
+Threading/Pooling Behavior
+---------------------------
+
+Pysqlite's default behavior is to prohibit the usage of a single connection
+in more than one thread. This was originally intended to work with older
+versions of SQLite that did not support multithreaded operation under
+various circumstances. In particular, older SQLite versions
+did not allow a ``:memory:`` database to be used in multiple threads
+under any circumstances.
+
+Pysqlite does include a now-undocumented flag known as
+``check_same_thread`` which will disable this check; however, note that
+pysqlite connections are still not safe to use concurrently in multiple
+threads. In particular, any statement execution calls would need to be
+externally mutexed, as Pysqlite does not provide for thread-safe propagation
+of error messages among other things. So while even ``:memory:`` databases
+can be shared among threads in modern SQLite, Pysqlite doesn't provide enough
+thread-safety to make this usage worth it.
+
+SQLAlchemy sets up pooling to work with Pysqlite's default behavior:
+
+* When a ``:memory:`` SQLite database is specified, the dialect by default
+ will use :class:`.SingletonThreadPool`. This pool maintains a single
+ connection per thread, so that all access to the engine within the current
+  thread uses the same ``:memory:`` database - other threads would access a
+ different ``:memory:`` database.
+* When a file-based database is specified, the dialect will use
+ :class:`.NullPool` as the source of connections. This pool closes and
+ discards connections which are returned to the pool immediately. SQLite
+ file-based connections have extremely low overhead, so pooling is not
+ necessary. The scheme also prevents a connection from being used again in
+ a different thread and works best with SQLite's coarse-grained file locking.
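+
+As a short illustration of these two defaults, neither of which needs to be
+configured explicitly, the pool class follows from the form of the URL::
+
+    from sqlalchemy import create_engine
+
+    # file-based database - NullPool is used by default
+    file_engine = create_engine("sqlite:///some.db")
+
+    # in-memory database - SingletonThreadPool is used by default
+    memory_engine = create_engine("sqlite://")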
+
+Using a Memory Database in Multiple Threads
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To use a ``:memory:`` database in a multithreaded scenario, the same
+connection object must be shared among threads, since the database exists
+only within the scope of that connection. The
+:class:`.StaticPool` implementation will maintain a single connection
+globally, and the ``check_same_thread`` flag can be passed to Pysqlite
+as ``False``::
+
+    from sqlalchemy.pool import StaticPool
+
+    engine = create_engine(
+        "sqlite://",
+        connect_args={"check_same_thread": False},
+        poolclass=StaticPool,
+    )
+
+Note that using a ``:memory:`` database in multiple threads requires a recent
+version of SQLite.
+
+Using Temporary Tables with SQLite
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Due to the way SQLite deals with temporary tables, if you wish to use a
+temporary table in a file-based SQLite database across multiple checkouts
+from the connection pool, such as when using an ORM :class:`.Session` where
+the temporary table should continue to remain after :meth:`.Session.commit` or
+:meth:`.Session.rollback` is called, a pool which maintains a single
+connection must be used. Use :class:`.SingletonThreadPool` if the scope is
+only needed within the current thread, or :class:`.StaticPool` if scope is
+needed within multiple threads for this case::
+
+    # maintain the same connection per thread
+    from sqlalchemy.pool import SingletonThreadPool
+
+    engine = create_engine(
+        "sqlite:///mydb.db", poolclass=SingletonThreadPool
+    )
+
+
+    # maintain the same connection across all threads
+    from sqlalchemy.pool import StaticPool
+
+    engine = create_engine(
+        "sqlite:///mydb.db", poolclass=StaticPool
+    )
+
+Note that :class:`.SingletonThreadPool` should be configured for the number
+of threads that are to be used; beyond that number, connections will be
+closed out in a non-deterministic way.
+
+Unicode
+-------
+
+The pysqlite driver only returns Python ``unicode`` objects in result sets,
+never plain strings, and accommodates ``unicode`` objects within bound
+parameter values in all cases. Regardless of the SQLAlchemy string type in
+use, string-based result values will be Python ``unicode`` in Python 2.
+The :class:`.Unicode` type should still be used to indicate those columns that
+require unicode, however, so that non-``unicode`` values passed inadvertently
+will emit a warning. Pysqlite will emit an error if a non-``unicode`` string
+is passed containing non-ASCII characters.
+
+Dealing with Mixed String / Binary Columns in Python 3
+------------------------------------------------------
+
+The SQLite database is weakly typed, and as such it is possible when using
+binary values, which in Python 3 are represented as ``b'some string'``, that a
+particular SQLite database can have data values within different rows where
+some of them will be returned as a ``b''`` value by the Pysqlite driver, and
+others will be returned as Python strings, e.g. ``''`` values. This situation
+is not known to occur if the SQLAlchemy :class:`.LargeBinary` datatype is used
+consistently, however if a particular SQLite database has data that was
+inserted using the Pysqlite driver directly, or when using the SQLAlchemy
+:class:`.String` type which was later changed to :class:`.LargeBinary`, the
+table will not be consistently readable because SQLAlchemy's
+:class:`.LargeBinary` datatype does not handle strings, so it has no way of
+"encoding" a value that is in string format.
+
+To deal with a SQLite table that has mixed string / binary data in the
+same column, use a custom type that will check each row individually::
+
+ # note this is Python 3 only
+
+ from sqlalchemy import String
+ from sqlalchemy import TypeDecorator
+
+ class MixedBinary(TypeDecorator):
+ impl = String
+ cache_ok = True
+
+ def process_result_value(self, value, dialect):
+ if isinstance(value, str):
+ value = bytes(value, 'utf-8')
+ elif value is not None:
+ value = bytes(value)
+
+ return value
+
+Then use the above ``MixedBinary`` datatype in the place where
+:class:`.LargeBinary` would normally be used.
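+
+A brief sketch of that substitution, assuming a hypothetical table
+definition, might look like::
+
+    from sqlalchemy import Column, Integer, MetaData, Table
+
+    metadata = MetaData()
+
+    my_table = Table(
+        "my_table",
+        metadata,
+        Column("id", Integer, primary_key=True),
+        # MixedBinary used where LargeBinary would otherwise appear
+        Column("data", MixedBinary),
+    )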
+
+.. _pysqlite_serializable:
+
+Serializable isolation / Savepoints / Transactional DDL
+-------------------------------------------------------
+
+In the section :ref:`sqlite_concurrency`, we refer to the pysqlite
+driver's assortment of issues that prevent several features of SQLite
+from working correctly. The pysqlite DBAPI driver has several
+long-standing bugs which impact the correctness of its transactional
+behavior. In its default mode of operation, SQLite features such as
+SERIALIZABLE isolation, transactional DDL, and SAVEPOINT support are
+non-functional, and in order to use these features, workarounds must
+be taken.
+
+The issue is essentially that the driver attempts to second-guess the user's
+intent, failing to start transactions and sometimes ending them prematurely, in
+an effort to minimize the SQLite database's file locking behavior, even
+though SQLite itself uses "shared" locks for read-only activities.
+
+SQLAlchemy chooses not to alter this behavior by default, as it is the
+long-expected behavior of the pysqlite driver; if and when the pysqlite
+driver attempts to repair these issues, that will be more of a driver
+towards changing SQLAlchemy's defaults.
+
+The good news is that with a few events, we can implement transactional
+support fully, by disabling pysqlite's transactional integration entirely
+and emitting BEGIN ourselves. This is achieved using two event listeners::
+
+ from sqlalchemy import create_engine, event
+
+ engine = create_engine("sqlite:///myfile.db")
+
+ @event.listens_for(engine, "connect")
+ def do_connect(dbapi_connection, connection_record):
+ # disable pysqlite's emitting of the BEGIN statement entirely.
+ # also stops it from emitting COMMIT before any DDL.
+ dbapi_connection.isolation_level = None
+
+ @event.listens_for(engine, "begin")
+ def do_begin(conn):
+ # emit our own BEGIN
+ conn.exec_driver_sql("BEGIN")
+
+.. warning:: When using the above recipe, it is advised to not use the
+ :paramref:`.Connection.execution_options.isolation_level` setting on
+ :class:`_engine.Connection` and :func:`_sa.create_engine`
+ with the SQLite driver,
+   as these options necessarily will also alter the pysqlite connection's
+   ``.isolation_level`` attribute.
+
+
+Above, we intercept a new pysqlite connection and disable any transactional
+integration. Then, at the point at which SQLAlchemy knows that transaction
+scope is to begin, we emit ``"BEGIN"`` ourselves.
+
+When we take control of ``"BEGIN"``, we can also directly control SQLite's
+locking modes, described at
+`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_,
+by adding the desired locking mode to our ``"BEGIN"``::
+
+ @event.listens_for(engine, "begin")
+ def do_begin(conn):
+ conn.exec_driver_sql("BEGIN EXCLUSIVE")
+
+.. seealso::
+
+ `BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_ -
+ on the SQLite site
+
+ `sqlite3 SELECT does not BEGIN a transaction <https://bugs.python.org/issue9924>`_ -
+ on the Python bug tracker
+
+ `sqlite3 module breaks transactions and potentially corrupts data <https://bugs.python.org/issue10740>`_ -
+ on the Python bug tracker
+
+
+""" # noqa
+
+import os
+import re
+
+from .base import DATE
+from .base import DATETIME
+from .base import SQLiteDialect
+from ... import exc
+from ... import pool
+from ... import types as sqltypes
+from ... import util
+
+
+class _SQLite_pysqliteTimeStamp(DATETIME):
+ def bind_processor(self, dialect):
+ if dialect.native_datetime:
+ return None
+ else:
+ return DATETIME.bind_processor(self, dialect)
+
+ def result_processor(self, dialect, coltype):
+ if dialect.native_datetime:
+ return None
+ else:
+ return DATETIME.result_processor(self, dialect, coltype)
+
+
+class _SQLite_pysqliteDate(DATE):
+ def bind_processor(self, dialect):
+ if dialect.native_datetime:
+ return None
+ else:
+ return DATE.bind_processor(self, dialect)
+
+ def result_processor(self, dialect, coltype):
+ if dialect.native_datetime:
+ return None
+ else:
+ return DATE.result_processor(self, dialect, coltype)
+
+
+class SQLiteDialect_pysqlite(SQLiteDialect):
+ default_paramstyle = "qmark"
+ supports_statement_cache = True
+
+ colspecs = util.update_copy(
+ SQLiteDialect.colspecs,
+ {
+ sqltypes.Date: _SQLite_pysqliteDate,
+ sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp,
+ },
+ )
+
+ if not util.py2k:
+ description_encoding = None
+
+ driver = "pysqlite"
+
+ @classmethod
+ def dbapi(cls):
+ if util.py2k:
+ try:
+ from pysqlite2 import dbapi2 as sqlite
+ except ImportError:
+ try:
+ from sqlite3 import dbapi2 as sqlite
+ except ImportError as e:
+ raise e
+ else:
+ from sqlite3 import dbapi2 as sqlite
+ return sqlite
+
+ @classmethod
+ def _is_url_file_db(cls, url):
+ if (url.database and url.database != ":memory:") and (
+ url.query.get("mode", None) != "memory"
+ ):
+ return True
+ else:
+ return False
+
+ @classmethod
+ def get_pool_class(cls, url):
+ if cls._is_url_file_db(url):
+ return pool.NullPool
+ else:
+ return pool.SingletonThreadPool
+
+ def _get_server_version_info(self, connection):
+ return self.dbapi.sqlite_version_info
+
+ _isolation_lookup = SQLiteDialect._isolation_lookup.union(
+ {
+ "AUTOCOMMIT": None,
+ }
+ )
+
+ def set_isolation_level(self, connection, level):
+ if hasattr(connection, "dbapi_connection"):
+ dbapi_connection = connection.dbapi_connection
+ else:
+ dbapi_connection = connection
+
+ if level == "AUTOCOMMIT":
+ dbapi_connection.isolation_level = None
+ else:
+ dbapi_connection.isolation_level = ""
+ return super(SQLiteDialect_pysqlite, self).set_isolation_level(
+ connection, level
+ )
+
+ def on_connect(self):
+ connect = super(SQLiteDialect_pysqlite, self).on_connect()
+
+ def regexp(a, b):
+ if b is None:
+ return None
+ return re.search(a, b) is not None
+
+ def set_regexp(connection):
+ if hasattr(connection, "dbapi_connection"):
+ dbapi_connection = connection.dbapi_connection
+ else:
+ dbapi_connection = connection
+ dbapi_connection.create_function(
+ "regexp",
+ 2,
+ regexp,
+ )
+
+ fns = [set_regexp]
+
+ if self.isolation_level is not None:
+
+ def iso_level(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ fns.append(iso_level)
+
+ def connect(conn):
+ for fn in fns:
+ fn(conn)
+
+ return connect
+
+ def create_connect_args(self, url):
+ if url.username or url.password or url.host or url.port:
+ raise exc.ArgumentError(
+ "Invalid SQLite URL: %s\n"
+ "Valid SQLite URL forms are:\n"
+ " sqlite:///:memory: (or, sqlite://)\n"
+ " sqlite:///relative/path/to/file.db\n"
+ " sqlite:////absolute/path/to/file.db" % (url,)
+ )
+
+ # theoretically, this list can be augmented, at least as far as
+ # parameter names accepted by sqlite3/pysqlite, using
+ # inspect.getfullargspec(). for the moment this seems like overkill
+ # as these parameters don't change very often, and as always,
+ # parameters passed to connect_args will always go to the
+ # sqlite3/pysqlite driver.
+ pysqlite_args = [
+ ("uri", bool),
+ ("timeout", float),
+ ("isolation_level", str),
+ ("detect_types", int),
+ ("check_same_thread", bool),
+ ("cached_statements", int),
+ ]
+ opts = url.query
+ pysqlite_opts = {}
+ for key, type_ in pysqlite_args:
+ util.coerce_kw_type(opts, key, type_, dest=pysqlite_opts)
+
+ if pysqlite_opts.get("uri", False):
+ uri_opts = dict(opts)
+ # here, we are actually separating the parameters that go to
+ # sqlite3/pysqlite vs. those that go the SQLite URI. What if
+ # two names conflict? again, this seems to be not the case right
+ # now, and in the case that new names are added to
+ # either side which overlap, again the sqlite3/pysqlite parameters
+ # can be passed through connect_args instead of in the URL.
+ # If SQLite native URIs add a parameter like "timeout" that
+ # we already have listed here for the python driver, then we need
+ # to adjust for that here.
+ for key, type_ in pysqlite_args:
+ uri_opts.pop(key, None)
+ filename = url.database
+ if uri_opts:
+ # sorting of keys is for unit test support
+ filename += "?" + (
+ "&".join(
+ "%s=%s" % (key, uri_opts[key])
+ for key in sorted(uri_opts)
+ )
+ )
+ else:
+ filename = url.database or ":memory:"
+ if filename != ":memory:":
+ filename = os.path.abspath(filename)
+
+ return ([filename], pysqlite_opts)
+
+ def is_disconnect(self, e, connection, cursor):
+ return isinstance(
+ e, self.dbapi.ProgrammingError
+ ) and "Cannot operate on a closed database." in str(e)
+
+
+dialect = SQLiteDialect_pysqlite
diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py
new file mode 100644
index 0000000..c7755c8
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sybase/__init__.py
@@ -0,0 +1,67 @@
+# sybase/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import base # noqa
+from . import pyodbc # noqa
+from . import pysybase # noqa
+from .base import BIGINT
+from .base import BINARY
+from .base import BIT
+from .base import CHAR
+from .base import DATE
+from .base import DATETIME
+from .base import FLOAT
+from .base import IMAGE
+from .base import INT
+from .base import INTEGER
+from .base import MONEY
+from .base import NCHAR
+from .base import NUMERIC
+from .base import NVARCHAR
+from .base import SMALLINT
+from .base import SMALLMONEY
+from .base import TEXT
+from .base import TIME
+from .base import TINYINT
+from .base import UNICHAR
+from .base import UNITEXT
+from .base import UNIVARCHAR
+from .base import VARBINARY
+from .base import VARCHAR
+
+
+# default dialect
+base.dialect = dialect = pyodbc.dialect
+
+
+__all__ = (
+ "CHAR",
+ "VARCHAR",
+ "TIME",
+ "NCHAR",
+ "NVARCHAR",
+ "TEXT",
+ "DATE",
+ "DATETIME",
+ "FLOAT",
+ "NUMERIC",
+ "BIGINT",
+ "INT",
+ "INTEGER",
+ "SMALLINT",
+ "BINARY",
+ "VARBINARY",
+ "UNITEXT",
+ "UNICHAR",
+ "UNIVARCHAR",
+ "IMAGE",
+ "BIT",
+ "MONEY",
+ "SMALLMONEY",
+ "TINYINT",
+ "dialect",
+)
diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py
new file mode 100644
index 0000000..83248d1
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sybase/base.py
@@ -0,0 +1,1100 @@
+# sybase/base.py
+# Copyright (C) 2010-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+# get_select_precolumns(), limit_clause() implementation
+# copyright (C) 2007 Fisch Asset Management
+# AG https://www.fam.ch, with coding by Alexander Houben
+# alexander.houben@thor-solutions.ch
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+
+.. dialect:: sybase
+ :name: Sybase
+
+.. note::
+
+ The Sybase dialect within SQLAlchemy **is not currently supported**.
+ It is not tested within continuous integration and is likely to have
+ many issues and caveats not currently handled. Consider using the
+ `external dialect <https://github.com/gordthompson/sqlalchemy-sybase>`_
+ instead.
+
+.. deprecated:: 1.4 The internal Sybase dialect is deprecated and will be
+ removed in a future version. Use the external dialect.
+
+"""
+
+import re
+
+from sqlalchemy import exc
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import types as sqltypes
+from sqlalchemy import util
+from sqlalchemy.engine import default
+from sqlalchemy.engine import reflection
+from sqlalchemy.sql import compiler
+from sqlalchemy.sql import text
+from sqlalchemy.types import BIGINT
+from sqlalchemy.types import BINARY
+from sqlalchemy.types import CHAR
+from sqlalchemy.types import DATE
+from sqlalchemy.types import DATETIME
+from sqlalchemy.types import DECIMAL
+from sqlalchemy.types import FLOAT
+from sqlalchemy.types import INT # noqa
+from sqlalchemy.types import INTEGER
+from sqlalchemy.types import NCHAR
+from sqlalchemy.types import NUMERIC
+from sqlalchemy.types import NVARCHAR
+from sqlalchemy.types import REAL
+from sqlalchemy.types import SMALLINT
+from sqlalchemy.types import TEXT
+from sqlalchemy.types import TIME
+from sqlalchemy.types import TIMESTAMP
+from sqlalchemy.types import Unicode
+from sqlalchemy.types import VARBINARY
+from sqlalchemy.types import VARCHAR
+
+
+RESERVED_WORDS = set(
+ [
+ "add",
+ "all",
+ "alter",
+ "and",
+ "any",
+ "as",
+ "asc",
+ "backup",
+ "begin",
+ "between",
+ "bigint",
+ "binary",
+ "bit",
+ "bottom",
+ "break",
+ "by",
+ "call",
+ "capability",
+ "cascade",
+ "case",
+ "cast",
+ "char",
+ "char_convert",
+ "character",
+ "check",
+ "checkpoint",
+ "close",
+ "comment",
+ "commit",
+ "connect",
+ "constraint",
+ "contains",
+ "continue",
+ "convert",
+ "create",
+ "cross",
+ "cube",
+ "current",
+ "current_timestamp",
+ "current_user",
+ "cursor",
+ "date",
+ "dbspace",
+ "deallocate",
+ "dec",
+ "decimal",
+ "declare",
+ "default",
+ "delete",
+ "deleting",
+ "desc",
+ "distinct",
+ "do",
+ "double",
+ "drop",
+ "dynamic",
+ "else",
+ "elseif",
+ "encrypted",
+ "end",
+ "endif",
+ "escape",
+ "except",
+ "exception",
+ "exec",
+ "execute",
+ "existing",
+ "exists",
+ "externlogin",
+ "fetch",
+ "first",
+ "float",
+ "for",
+ "force",
+ "foreign",
+ "forward",
+ "from",
+ "full",
+ "goto",
+ "grant",
+ "group",
+ "having",
+ "holdlock",
+ "identified",
+ "if",
+ "in",
+ "index",
+ "index_lparen",
+ "inner",
+ "inout",
+ "insensitive",
+ "insert",
+ "inserting",
+ "install",
+ "instead",
+ "int",
+ "integer",
+ "integrated",
+ "intersect",
+ "into",
+ "iq",
+ "is",
+ "isolation",
+ "join",
+ "key",
+ "lateral",
+ "left",
+ "like",
+ "lock",
+ "login",
+ "long",
+ "match",
+ "membership",
+ "message",
+ "mode",
+ "modify",
+ "natural",
+ "new",
+ "no",
+ "noholdlock",
+ "not",
+ "notify",
+ "null",
+ "numeric",
+ "of",
+ "off",
+ "on",
+ "open",
+ "option",
+ "options",
+ "or",
+ "order",
+ "others",
+ "out",
+ "outer",
+ "over",
+ "passthrough",
+ "precision",
+ "prepare",
+ "primary",
+ "print",
+ "privileges",
+ "proc",
+ "procedure",
+ "publication",
+ "raiserror",
+ "readtext",
+ "real",
+ "reference",
+ "references",
+ "release",
+ "remote",
+ "remove",
+ "rename",
+ "reorganize",
+ "resource",
+ "restore",
+ "restrict",
+ "return",
+ "revoke",
+ "right",
+ "rollback",
+ "rollup",
+ "save",
+ "savepoint",
+ "scroll",
+ "select",
+ "sensitive",
+ "session",
+ "set",
+ "setuser",
+ "share",
+ "smallint",
+ "some",
+ "sqlcode",
+ "sqlstate",
+ "start",
+ "stop",
+ "subtrans",
+ "subtransaction",
+ "synchronize",
+ "syntax_error",
+ "table",
+ "temporary",
+ "then",
+ "time",
+ "timestamp",
+ "tinyint",
+ "to",
+ "top",
+ "tran",
+ "trigger",
+ "truncate",
+ "tsequal",
+ "unbounded",
+ "union",
+ "unique",
+ "unknown",
+ "unsigned",
+ "update",
+ "updating",
+ "user",
+ "using",
+ "validate",
+ "values",
+ "varbinary",
+ "varchar",
+ "variable",
+ "varying",
+ "view",
+ "wait",
+ "waitfor",
+ "when",
+ "where",
+ "while",
+ "window",
+ "with",
+ "with_cube",
+ "with_lparen",
+ "with_rollup",
+ "within",
+ "work",
+ "writetext",
+ ]
+)
+
+
+class _SybaseUnitypeMixin(object):
+ """these types appear to return a buffer object."""
+
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ if value is not None:
+ return str(value) # decode("ucs-2")
+ else:
+ return None
+
+ return process
+
+
+class UNICHAR(_SybaseUnitypeMixin, sqltypes.Unicode):
+ __visit_name__ = "UNICHAR"
+
+
+class UNIVARCHAR(_SybaseUnitypeMixin, sqltypes.Unicode):
+ __visit_name__ = "UNIVARCHAR"
+
+
+class UNITEXT(_SybaseUnitypeMixin, sqltypes.UnicodeText):
+ __visit_name__ = "UNITEXT"
+
+
+class TINYINT(sqltypes.Integer):
+ __visit_name__ = "TINYINT"
+
+
+class BIT(sqltypes.TypeEngine):
+ __visit_name__ = "BIT"
+
+
+class MONEY(sqltypes.TypeEngine):
+ __visit_name__ = "MONEY"
+
+
+class SMALLMONEY(sqltypes.TypeEngine):
+ __visit_name__ = "SMALLMONEY"
+
+
+class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
+ __visit_name__ = "UNIQUEIDENTIFIER"
+
+
+class IMAGE(sqltypes.LargeBinary):
+ __visit_name__ = "IMAGE"
+
+
+class SybaseTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_IMAGE(type_)
+
+ def visit_boolean(self, type_, **kw):
+ return self.visit_BIT(type_)
+
+ def visit_unicode(self, type_, **kw):
+ return self.visit_NVARCHAR(type_)
+
+ def visit_UNICHAR(self, type_, **kw):
+ return "UNICHAR(%d)" % type_.length
+
+ def visit_UNIVARCHAR(self, type_, **kw):
+ return "UNIVARCHAR(%d)" % type_.length
+
+ def visit_UNITEXT(self, type_, **kw):
+ return "UNITEXT"
+
+ def visit_TINYINT(self, type_, **kw):
+ return "TINYINT"
+
+ def visit_IMAGE(self, type_, **kw):
+ return "IMAGE"
+
+ def visit_BIT(self, type_, **kw):
+ return "BIT"
+
+ def visit_MONEY(self, type_, **kw):
+ return "MONEY"
+
+ def visit_SMALLMONEY(self, type_, **kw):
+ return "SMALLMONEY"
+
+ def visit_UNIQUEIDENTIFIER(self, type_, **kw):
+ return "UNIQUEIDENTIFIER"
+
+
+ischema_names = {
+ "bigint": BIGINT,
+ "int": INTEGER,
+ "integer": INTEGER,
+ "smallint": SMALLINT,
+ "tinyint": TINYINT,
+ "unsigned bigint": BIGINT, # TODO: unsigned flags
+ "unsigned int": INTEGER, # TODO: unsigned flags
+ "unsigned smallint": SMALLINT, # TODO: unsigned flags
+ "numeric": NUMERIC,
+ "decimal": DECIMAL,
+ "dec": DECIMAL,
+ "float": FLOAT,
+ "double": NUMERIC, # TODO
+ "double precision": NUMERIC, # TODO
+ "real": REAL,
+ "smallmoney": SMALLMONEY,
+ "money": MONEY,
+ "smalldatetime": DATETIME,
+ "datetime": DATETIME,
+ "date": DATE,
+ "time": TIME,
+ "char": CHAR,
+ "character": CHAR,
+ "varchar": VARCHAR,
+ "character varying": VARCHAR,
+ "char varying": VARCHAR,
+ "unichar": UNICHAR,
+ "unicode character": UNIVARCHAR,
+ "nchar": NCHAR,
+ "national char": NCHAR,
+ "national character": NCHAR,
+ "nvarchar": NVARCHAR,
+ "nchar varying": NVARCHAR,
+ "national char varying": NVARCHAR,
+ "national character varying": NVARCHAR,
+ "text": TEXT,
+ "unitext": UNITEXT,
+ "binary": BINARY,
+ "varbinary": VARBINARY,
+ "image": IMAGE,
+ "bit": BIT,
+ # not in documentation for ASE 15.7
+ "long varchar": TEXT, # TODO
+ "timestamp": TIMESTAMP,
+ "uniqueidentifier": UNIQUEIDENTIFIER,
+}
+
+
+class SybaseInspector(reflection.Inspector):
+ def __init__(self, conn):
+ reflection.Inspector.__init__(self, conn)
+
+ def get_table_id(self, table_name, schema=None):
+ """Return the table id from `table_name` and `schema`."""
+
+ return self.dialect.get_table_id(
+ self.bind, table_name, schema, info_cache=self.info_cache
+ )
+
+
+class SybaseExecutionContext(default.DefaultExecutionContext):
+ _enable_identity_insert = False
+
+ def set_ddl_autocommit(self, connection, value):
+ """Must be implemented by subclasses to accommodate DDL executions.
+
+ "connection" is the raw unwrapped DBAPI connection. "value"
+        is True or False. When True, the connection should be configured
+        such that a DDL can take place subsequently. When False,
+ a DDL has taken place and the connection should be resumed
+ into non-autocommit mode.
+
+ """
+ raise NotImplementedError()
+
+ def pre_exec(self):
+ if self.isinsert:
+ tbl = self.compiled.statement.table
+ seq_column = tbl._autoincrement_column
+ insert_has_sequence = seq_column is not None
+
+ if insert_has_sequence:
+ self._enable_identity_insert = (
+ seq_column.key in self.compiled_parameters[0]
+ )
+ else:
+ self._enable_identity_insert = False
+
+ if self._enable_identity_insert:
+ self.cursor.execute(
+ "SET IDENTITY_INSERT %s ON"
+ % self.dialect.identifier_preparer.format_table(tbl)
+ )
+
+ if self.isddl:
+ # TODO: to enhance this, we can detect "ddl in tran" on the
+ # database settings. this error message should be improved to
+ # include a note about that.
+ if not self.should_autocommit:
+ raise exc.InvalidRequestError(
+ "The Sybase dialect only supports "
+ "DDL in 'autocommit' mode at this time."
+ )
+
+ self.root_connection.engine.logger.info(
+ "AUTOCOMMIT (Assuming no Sybase 'ddl in tran')"
+ )
+
+ self.set_ddl_autocommit(
+ self.root_connection.connection.connection, True
+ )
+
+ def post_exec(self):
+ if self.isddl:
+ self.set_ddl_autocommit(self.root_connection, False)
+
+ if self._enable_identity_insert:
+ self.cursor.execute(
+ "SET IDENTITY_INSERT %s OFF"
+ % self.dialect.identifier_preparer.format_table(
+ self.compiled.statement.table
+ )
+ )
+
+ def get_lastrowid(self):
+ cursor = self.create_cursor()
+ cursor.execute("SELECT @@identity AS lastrowid")
+ lastrowid = cursor.fetchone()[0]
+ cursor.close()
+ return lastrowid
+
+
+class SybaseSQLCompiler(compiler.SQLCompiler):
+ ansi_bind_rules = True
+
+ extract_map = util.update_copy(
+ compiler.SQLCompiler.extract_map,
+ {"doy": "dayofyear", "dow": "weekday", "milliseconds": "millisecond"},
+ )
+
+ def get_from_hint_text(self, table, text):
+ return text
+
+ def limit_clause(self, select, **kw):
+ text = ""
+ if select._limit_clause is not None:
+ text += " ROWS LIMIT " + self.process(select._limit_clause, **kw)
+ if select._offset_clause is not None:
+ if select._limit_clause is None:
+ text += " ROWS"
+ text += " OFFSET " + self.process(select._offset_clause, **kw)
+ return text
+
+ def visit_extract(self, extract, **kw):
+ field = self.extract_map.get(extract.field, extract.field)
+ return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw))
+
+ def visit_now_func(self, fn, **kw):
+ return "GETDATE()"
+
+ def for_update_clause(self, select):
+ # "FOR UPDATE" is only allowed on "DECLARE CURSOR"
+ # which SQLAlchemy doesn't use
+ return ""
+
+ def order_by_clause(self, select, **kw):
+ kw["literal_binds"] = True
+ order_by = self.process(select._order_by_clause, **kw)
+
+ # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
+ if order_by and (not self.is_subquery() or select._limit):
+ return " ORDER BY " + order_by
+ else:
+ return ""
+
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
+ """If we have extra froms make sure we render any alias as hint."""
+ ashint = False
+ if extra_froms:
+ ashint = True
+ return from_table._compiler_dispatch(
+ self, asfrom=True, iscrud=True, ashint=ashint
+ )
+
+ def delete_extra_from_clause(
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ """Render the DELETE .. FROM clause specific to Sybase."""
+ kw["asfrom"] = True
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
+
+
+class SybaseDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+ colspec = (
+ self.preparer.format_column(column)
+ + " "
+ + self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ )
+ )
+
+ if column.table is None:
+ raise exc.CompileError(
+ "The Sybase dialect requires Table-bound "
+ "columns in order to generate DDL"
+ )
+ seq_col = column.table._autoincrement_column
+
+ # install a IDENTITY Sequence if we have an implicit IDENTITY column
+ if seq_col is column:
+ sequence = (
+ isinstance(column.default, sa_schema.Sequence)
+ and column.default
+ )
+ if sequence:
+ start, increment = sequence.start or 1, sequence.increment or 1
+ else:
+ start, increment = 1, 1
+ if (start, increment) == (1, 1):
+ colspec += " IDENTITY"
+ else:
+ # TODO: need correct syntax for this
+ colspec += " IDENTITY(%s,%s)" % (start, increment)
+ else:
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ if column.nullable is not None:
+ if not column.nullable or column.primary_key:
+ colspec += " NOT NULL"
+ else:
+ colspec += " NULL"
+
+ return colspec
+
+ def visit_drop_index(self, drop):
+ index = drop.element
+ return "\nDROP INDEX %s.%s" % (
+ self.preparer.quote_identifier(index.table.name),
+ self._prepared_index_name(drop.element, include_schema=False),
+ )
+
+
+class SybaseIdentifierPreparer(compiler.IdentifierPreparer):
+ reserved_words = RESERVED_WORDS
+
+
+class SybaseDialect(default.DefaultDialect):
+ name = "sybase"
+ supports_unicode_statements = False
+ supports_sane_rowcount = False
+ supports_sane_multi_rowcount = False
+ supports_statement_cache = True
+
+ supports_native_boolean = False
+ supports_unicode_binds = False
+ postfetch_lastrowid = True
+
+ colspecs = {}
+ ischema_names = ischema_names
+
+ type_compiler = SybaseTypeCompiler
+ statement_compiler = SybaseSQLCompiler
+ ddl_compiler = SybaseDDLCompiler
+ preparer = SybaseIdentifierPreparer
+ inspector = SybaseInspector
+
+ construct_arguments = []
+
+ def __init__(self, *args, **kwargs):
+ util.warn_deprecated(
+ "The Sybase dialect is deprecated and will be removed "
+ "in a future version. This dialect is superseded by the external "
+ "dialect https://github.com/gordthompson/sqlalchemy-sybase.",
+ version="1.4",
+ )
+ super(SybaseDialect, self).__init__(*args, **kwargs)
+
+ def _get_default_schema_name(self, connection):
+ return connection.scalar(
+ text("SELECT user_name() as user_name").columns(username=Unicode)
+ )
+
+ def initialize(self, connection):
+ super(SybaseDialect, self).initialize(connection)
+ if (
+ self.server_version_info is not None
+ and self.server_version_info < (15,)
+ ):
+ self.max_identifier_length = 30
+ else:
+ self.max_identifier_length = 255
+
+ def get_table_id(self, connection, table_name, schema=None, **kw):
+ """Fetch the id for schema.table_name.
+
+ Several reflection methods require the table id. The idea for using
+ this method is that it can be fetched one time and cached for
+ subsequent calls.
+
+ """
+
+ table_id = None
+ if schema is None:
+ schema = self.default_schema_name
+
+ TABLEID_SQL = text(
+ """
+ SELECT o.id AS id
+ FROM sysobjects o JOIN sysusers u ON o.uid=u.uid
+ WHERE u.name = :schema_name
+ AND o.name = :table_name
+ AND o.type in ('U', 'V')
+ """
+ )
+
+ if util.py2k:
+ if isinstance(schema, unicode): # noqa
+ schema = schema.encode("ascii")
+ if isinstance(table_name, unicode): # noqa
+ table_name = table_name.encode("ascii")
+ result = connection.execute(
+ TABLEID_SQL, schema_name=schema, table_name=table_name
+ )
+ table_id = result.scalar()
+ if table_id is None:
+ raise exc.NoSuchTableError(table_name)
+ return table_id
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ table_id = self.get_table_id(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ COLUMN_SQL = text(
+ """
+ SELECT col.name AS name,
+ t.name AS type,
+ (col.status & 8) AS nullable,
+ (col.status & 128) AS autoincrement,
+ com.text AS 'default',
+ col.prec AS precision,
+ col.scale AS scale,
+ col.length AS length
+ FROM systypes t, syscolumns col LEFT OUTER JOIN syscomments com ON
+ col.cdefault = com.id
+ WHERE col.usertype = t.usertype
+ AND col.id = :table_id
+ ORDER BY col.colid
+ """
+ )
+
+ results = connection.execute(COLUMN_SQL, table_id=table_id)
+
+ columns = []
+ for (
+ name,
+ type_,
+ nullable,
+ autoincrement,
+ default_,
+ precision,
+ scale,
+ length,
+ ) in results:
+ col_info = self._get_column_info(
+ name,
+ type_,
+ bool(nullable),
+ bool(autoincrement),
+ default_,
+ precision,
+ scale,
+ length,
+ )
+ columns.append(col_info)
+
+ return columns
+
+ def _get_column_info(
+ self,
+ name,
+ type_,
+ nullable,
+ autoincrement,
+ default,
+ precision,
+ scale,
+ length,
+ ):
+
+ coltype = self.ischema_names.get(type_, None)
+
+ kwargs = {}
+
+ if coltype in (NUMERIC, DECIMAL):
+ args = (precision, scale)
+ elif coltype == FLOAT:
+ args = (precision,)
+ elif coltype in (CHAR, VARCHAR, UNICHAR, UNIVARCHAR, NCHAR, NVARCHAR):
+ args = (length,)
+ else:
+ args = ()
+
+ if coltype:
+ coltype = coltype(*args, **kwargs)
+ # is this necessary
+ # if is_array:
+ # coltype = ARRAY(coltype)
+ else:
+ util.warn(
+ "Did not recognize type '%s' of column '%s'" % (type_, name)
+ )
+ coltype = sqltypes.NULLTYPE
+
+ if default:
+ default = default.replace("DEFAULT", "").strip()
+ default = re.sub("^'(.*)'$", lambda m: m.group(1), default)
+ else:
+ default = None
+
+ column_info = dict(
+ name=name,
+ type=coltype,
+ nullable=nullable,
+ default=default,
+ autoincrement=autoincrement,
+ )
+ return column_info
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+
+ table_id = self.get_table_id(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ table_cache = {}
+ column_cache = {}
+ foreign_keys = []
+
+ table_cache[table_id] = {"name": table_name, "schema": schema}
+
+ COLUMN_SQL = text(
+ """
+ SELECT c.colid AS id, c.name AS name
+ FROM syscolumns c
+ WHERE c.id = :table_id
+ """
+ )
+
+ results = connection.execute(COLUMN_SQL, table_id=table_id)
+ columns = {}
+ for col in results:
+ columns[col["id"]] = col["name"]
+ column_cache[table_id] = columns
+
+ REFCONSTRAINT_SQL = text(
+ """
+ SELECT o.name AS name, r.reftabid AS reftable_id,
+ r.keycnt AS 'count',
+ r.fokey1 AS fokey1, r.fokey2 AS fokey2, r.fokey3 AS fokey3,
+ r.fokey4 AS fokey4, r.fokey5 AS fokey5, r.fokey6 AS fokey6,
+            r.fokey7 AS fokey7, r.fokey8 AS fokey8, r.fokey9 AS fokey9,
+ r.fokey10 AS fokey10, r.fokey11 AS fokey11, r.fokey12 AS fokey12,
+ r.fokey13 AS fokey13, r.fokey14 AS fokey14, r.fokey15 AS fokey15,
+ r.fokey16 AS fokey16,
+ r.refkey1 AS refkey1, r.refkey2 AS refkey2, r.refkey3 AS refkey3,
+ r.refkey4 AS refkey4, r.refkey5 AS refkey5, r.refkey6 AS refkey6,
+            r.refkey7 AS refkey7, r.refkey8 AS refkey8, r.refkey9 AS refkey9,
+ r.refkey10 AS refkey10, r.refkey11 AS refkey11,
+ r.refkey12 AS refkey12, r.refkey13 AS refkey13,
+ r.refkey14 AS refkey14, r.refkey15 AS refkey15,
+ r.refkey16 AS refkey16
+ FROM sysreferences r JOIN sysobjects o on r.tableid = o.id
+ WHERE r.tableid = :table_id
+ """
+ )
+ referential_constraints = connection.execute(
+ REFCONSTRAINT_SQL, table_id=table_id
+ ).fetchall()
+
+ REFTABLE_SQL = text(
+ """
+ SELECT o.name AS name, u.name AS 'schema'
+ FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+ WHERE o.id = :table_id
+ """
+ )
+
+ for r in referential_constraints:
+ reftable_id = r["reftable_id"]
+
+ if reftable_id not in table_cache:
+ c = connection.execute(REFTABLE_SQL, table_id=reftable_id)
+ reftable = c.fetchone()
+ c.close()
+ table_info = {"name": reftable["name"], "schema": None}
+ if (
+ schema is not None
+ or reftable["schema"] != self.default_schema_name
+ ):
+ table_info["schema"] = reftable["schema"]
+
+ table_cache[reftable_id] = table_info
+ results = connection.execute(COLUMN_SQL, table_id=reftable_id)
+ reftable_columns = {}
+ for col in results:
+ reftable_columns[col["id"]] = col["name"]
+ column_cache[reftable_id] = reftable_columns
+
+ reftable = table_cache[reftable_id]
+ reftable_columns = column_cache[reftable_id]
+
+ constrained_columns = []
+ referred_columns = []
+ for i in range(1, r["count"] + 1):
+ constrained_columns.append(columns[r["fokey%i" % i]])
+ referred_columns.append(reftable_columns[r["refkey%i" % i]])
+
+ fk_info = {
+ "constrained_columns": constrained_columns,
+ "referred_schema": reftable["schema"],
+ "referred_table": reftable["name"],
+ "referred_columns": referred_columns,
+ "name": r["name"],
+ }
+
+ foreign_keys.append(fk_info)
+
+ return foreign_keys
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ table_id = self.get_table_id(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ INDEX_SQL = text(
+ """
+ SELECT object_name(i.id) AS table_name,
+ i.keycnt AS 'count',
+ i.name AS name,
+ (i.status & 0x2) AS 'unique',
+ index_col(object_name(i.id), i.indid, 1) AS col_1,
+ index_col(object_name(i.id), i.indid, 2) AS col_2,
+ index_col(object_name(i.id), i.indid, 3) AS col_3,
+ index_col(object_name(i.id), i.indid, 4) AS col_4,
+ index_col(object_name(i.id), i.indid, 5) AS col_5,
+ index_col(object_name(i.id), i.indid, 6) AS col_6,
+ index_col(object_name(i.id), i.indid, 7) AS col_7,
+ index_col(object_name(i.id), i.indid, 8) AS col_8,
+ index_col(object_name(i.id), i.indid, 9) AS col_9,
+ index_col(object_name(i.id), i.indid, 10) AS col_10,
+ index_col(object_name(i.id), i.indid, 11) AS col_11,
+ index_col(object_name(i.id), i.indid, 12) AS col_12,
+ index_col(object_name(i.id), i.indid, 13) AS col_13,
+ index_col(object_name(i.id), i.indid, 14) AS col_14,
+ index_col(object_name(i.id), i.indid, 15) AS col_15,
+ index_col(object_name(i.id), i.indid, 16) AS col_16
+ FROM sysindexes i, sysobjects o
+ WHERE o.id = i.id
+ AND o.id = :table_id
+ AND (i.status & 2048) = 0
+ AND i.indid BETWEEN 1 AND 254
+ """
+ )
+
+ results = connection.execute(INDEX_SQL, table_id=table_id)
+ indexes = []
+ for r in results:
+ column_names = []
+ for i in range(1, r["count"]):
+ column_names.append(r["col_%i" % (i,)])
+ index_info = {
+ "name": r["name"],
+ "unique": bool(r["unique"]),
+ "column_names": column_names,
+ }
+ indexes.append(index_info)
+
+ return indexes
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ table_id = self.get_table_id(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ PK_SQL = text(
+ """
+ SELECT object_name(i.id) AS table_name,
+ i.keycnt AS 'count',
+ i.name AS name,
+ index_col(object_name(i.id), i.indid, 1) AS pk_1,
+ index_col(object_name(i.id), i.indid, 2) AS pk_2,
+ index_col(object_name(i.id), i.indid, 3) AS pk_3,
+ index_col(object_name(i.id), i.indid, 4) AS pk_4,
+ index_col(object_name(i.id), i.indid, 5) AS pk_5,
+ index_col(object_name(i.id), i.indid, 6) AS pk_6,
+ index_col(object_name(i.id), i.indid, 7) AS pk_7,
+ index_col(object_name(i.id), i.indid, 8) AS pk_8,
+ index_col(object_name(i.id), i.indid, 9) AS pk_9,
+ index_col(object_name(i.id), i.indid, 10) AS pk_10,
+ index_col(object_name(i.id), i.indid, 11) AS pk_11,
+ index_col(object_name(i.id), i.indid, 12) AS pk_12,
+ index_col(object_name(i.id), i.indid, 13) AS pk_13,
+ index_col(object_name(i.id), i.indid, 14) AS pk_14,
+ index_col(object_name(i.id), i.indid, 15) AS pk_15,
+ index_col(object_name(i.id), i.indid, 16) AS pk_16
+ FROM sysindexes i, sysobjects o
+ WHERE o.id = i.id
+ AND o.id = :table_id
+ AND (i.status & 2048) = 2048
+ AND i.indid BETWEEN 1 AND 254
+ """
+ )
+
+ results = connection.execute(PK_SQL, table_id=table_id)
+ pks = results.fetchone()
+ results.close()
+
+ constrained_columns = []
+ if pks:
+ for i in range(1, pks["count"] + 1):
+ constrained_columns.append(pks["pk_%i" % (i,)])
+ return {
+ "constrained_columns": constrained_columns,
+ "name": pks["name"],
+ }
+ else:
+ return {"constrained_columns": [], "name": None}
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+
+ SCHEMA_SQL = text("SELECT u.name AS name FROM sysusers u")
+
+ schemas = connection.execute(SCHEMA_SQL)
+
+ return [s["name"] for s in schemas]
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ if schema is None:
+ schema = self.default_schema_name
+
+ TABLE_SQL = text(
+ """
+ SELECT o.name AS name
+ FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+ WHERE u.name = :schema_name
+ AND o.type = 'U'
+ """
+ )
+
+ if util.py2k:
+ if isinstance(schema, unicode): # noqa
+ schema = schema.encode("ascii")
+
+ tables = connection.execute(TABLE_SQL, schema_name=schema)
+
+ return [t["name"] for t in tables]
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ if schema is None:
+ schema = self.default_schema_name
+
+ VIEW_DEF_SQL = text(
+ """
+ SELECT c.text
+ FROM syscomments c JOIN sysobjects o ON c.id = o.id
+ WHERE o.name = :view_name
+ AND o.type = 'V'
+ """
+ )
+
+ if util.py2k:
+ if isinstance(view_name, unicode): # noqa
+ view_name = view_name.encode("ascii")
+
+ view = connection.execute(VIEW_DEF_SQL, view_name=view_name)
+
+ return view.scalar()
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ if schema is None:
+ schema = self.default_schema_name
+
+ VIEW_SQL = text(
+ """
+ SELECT o.name AS name
+ FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+ WHERE u.name = :schema_name
+ AND o.type = 'V'
+ """
+ )
+
+ if util.py2k:
+ if isinstance(schema, unicode): # noqa
+ schema = schema.encode("ascii")
+ views = connection.execute(VIEW_SQL, schema_name=schema)
+
+ return [v["name"] for v in views]
+
+ def has_table(self, connection, table_name, schema=None):
+ self._ensure_has_table_connection(connection)
+
+ try:
+ self.get_table_id(connection, table_name, schema)
+ except exc.NoSuchTableError:
+ return False
+ else:
+ return True
diff --git a/lib/sqlalchemy/dialects/sybase/mxodbc.py b/lib/sqlalchemy/dialects/sybase/mxodbc.py
new file mode 100644
index 0000000..fe5a614
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sybase/mxodbc.py
@@ -0,0 +1,34 @@
+# sybase/mxodbc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""
+
+.. dialect:: sybase+mxodbc
+ :name: mxODBC
+ :dbapi: mxodbc
+ :connectstring: sybase+mxodbc://<username>:<password>@<dsnname>
+ :url: https://www.egenix.com/
+
+.. note::
+
+ This dialect is a stub only and is likely non functional at this time.
+
+"""
+from sqlalchemy.connectors.mxodbc import MxODBCConnector
+from sqlalchemy.dialects.sybase.base import SybaseDialect
+from sqlalchemy.dialects.sybase.base import SybaseExecutionContext
+
+
+class SybaseExecutionContext_mxodbc(SybaseExecutionContext):
+ pass
+
+
+class SybaseDialect_mxodbc(MxODBCConnector, SybaseDialect):
+ execution_ctx_cls = SybaseExecutionContext_mxodbc
+ supports_statement_cache = True
+
+
+dialect = SybaseDialect_mxodbc
diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py
new file mode 100644
index 0000000..f408e8f
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py
@@ -0,0 +1,89 @@
+# sybase/pyodbc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: sybase+pyodbc
+ :name: PyODBC
+ :dbapi: pyodbc
+ :connectstring: sybase+pyodbc://<username>:<password>@<dsnname>[/<database>]
+ :url: https://pypi.org/project/pyodbc/
+
+Unicode Support
+---------------
+
+The pyodbc driver currently supports usage of these Sybase types with
+Unicode or multibyte strings::
+
+ CHAR
+ NCHAR
+ NVARCHAR
+ TEXT
+ VARCHAR
+
+Currently *not* supported are::
+
+ UNICHAR
+ UNITEXT
+ UNIVARCHAR
+
+""" # noqa
+
+import decimal
+
+from sqlalchemy import processors
+from sqlalchemy import types as sqltypes
+from sqlalchemy.connectors.pyodbc import PyODBCConnector
+from sqlalchemy.dialects.sybase.base import SybaseDialect
+from sqlalchemy.dialects.sybase.base import SybaseExecutionContext
+
+
+class _SybNumeric_pyodbc(sqltypes.Numeric):
+ """Turns Decimals with adjusted() < -6 into floats.
+
+ It's not yet known how to get decimals with many
+ significant digits or very large adjusted() into Sybase
+ via pyodbc.
+
+ """
+
+ def bind_processor(self, dialect):
+ super_process = super(_SybNumeric_pyodbc, self).bind_processor(dialect)
+
+ def process(value):
+ if self.asdecimal and isinstance(value, decimal.Decimal):
+
+ if value.adjusted() < -6:
+ return processors.to_float(value)
+
+ if super_process:
+ return super_process(value)
+ else:
+ return value
+
+ return process
+
+
+class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
+ def set_ddl_autocommit(self, connection, value):
+ if value:
+ connection.autocommit = True
+ else:
+ connection.autocommit = False
+
+
+class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect):
+ execution_ctx_cls = SybaseExecutionContext_pyodbc
+ supports_statement_cache = True
+
+ colspecs = {sqltypes.Numeric: _SybNumeric_pyodbc}
+
+ @classmethod
+ def dbapi(cls):
+ return PyODBCConnector.dbapi()
+
+
+dialect = SybaseDialect_pyodbc
diff --git a/lib/sqlalchemy/dialects/sybase/pysybase.py b/lib/sqlalchemy/dialects/sybase/pysybase.py
new file mode 100644
index 0000000..4c96aac
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sybase/pysybase.py
@@ -0,0 +1,106 @@
+# sybase/pysybase.py
+# Copyright (C) 2010-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: sybase+pysybase
+ :name: Python-Sybase
+ :dbapi: Sybase
+ :connectstring: sybase+pysybase://<username>:<password>@<dsn>/[database name]
+ :url: https://python-sybase.sourceforge.net/
+
+Unicode Support
+---------------
+
+The python-sybase driver does not appear to support non-ASCII strings of any
+kind at this time.
+
+""" # noqa
+
+from sqlalchemy import processors
+from sqlalchemy import types as sqltypes
+from sqlalchemy.dialects.sybase.base import SybaseDialect
+from sqlalchemy.dialects.sybase.base import SybaseExecutionContext
+from sqlalchemy.dialects.sybase.base import SybaseSQLCompiler
+
+
+class _SybNumeric(sqltypes.Numeric):
+ def result_processor(self, dialect, type_):
+ if not self.asdecimal:
+ return processors.to_float
+ else:
+ return sqltypes.Numeric.result_processor(self, dialect, type_)
+
+
+class SybaseExecutionContext_pysybase(SybaseExecutionContext):
+ def set_ddl_autocommit(self, dbapi_connection, value):
+ if value:
+ # call commit() on the Sybase connection directly,
+ # to avoid any side effects of calling a Connection
+ # transactional method inside of pre_exec()
+ dbapi_connection.commit()
+
+ def pre_exec(self):
+ SybaseExecutionContext.pre_exec(self)
+
+ for param in self.parameters:
+ for key in list(param):
+ param["@" + key] = param[key]
+ del param[key]
+
+
+class SybaseSQLCompiler_pysybase(SybaseSQLCompiler):
+ def bindparam_string(self, name, **kw):
+ return "@" + name
+
+
+class SybaseDialect_pysybase(SybaseDialect):
+ driver = "pysybase"
+ execution_ctx_cls = SybaseExecutionContext_pysybase
+ statement_compiler = SybaseSQLCompiler_pysybase
+
+ supports_statement_cache = True
+
+ colspecs = {sqltypes.Numeric: _SybNumeric, sqltypes.Float: sqltypes.Float}
+
+ @classmethod
+ def dbapi(cls):
+ import Sybase
+
+ return Sybase
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user", password="passwd")
+
+ return ([opts.pop("host")], opts)
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ # calling python-sybase executemany yields:
+ # TypeError: string too long for buffer
+ for param in parameters:
+ cursor.execute(statement, param)
+
+ def _get_server_version_info(self, connection):
+ vers = connection.exec_driver_sql("select @@version_number").scalar()
+ # i.e. 15500, 15000, 12500 == (15, 5, 0, 0), (15, 0, 0, 0),
+ # (12, 5, 0, 0)
+        return (vers // 1000, vers % 1000 // 100, vers % 100 // 10, vers % 10)
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(
+ e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)
+ ):
+ msg = str(e)
+ return (
+ "Unable to complete network request to host" in msg
+ or "Invalid connection state" in msg
+ or "Invalid cursor state" in msg
+ )
+ else:
+ return False
+
+
+dialect = SybaseDialect_pysybase
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
new file mode 100644
index 0000000..2437e17
--- /dev/null
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -0,0 +1,62 @@
+# engine/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""SQL connections, SQL execution and high-level DB-API interface.
+
+The engine package defines the basic components used to interface
+DB-API modules with higher-level statement construction,
+connection-management, execution and result contexts. The primary
+"entry point" class into this package is the Engine and its public
+constructor ``create_engine()``.
+
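+A minimal sketch of that entry point, shown here purely for illustration with
+an in-memory SQLite database, might look like::
+
+    from sqlalchemy import create_engine, text
+
+    engine = create_engine("sqlite://")
+    with engine.connect() as conn:
+        result = conn.execute(text("select 'hello world'"))
+        print(result.scalar())
+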
+"""
+
+from . import events
+from . import util
+from .base import Connection
+from .base import Engine
+from .base import NestedTransaction
+from .base import RootTransaction
+from .base import Transaction
+from .base import TwoPhaseTransaction
+from .create import create_engine
+from .create import engine_from_config
+from .cursor import BaseCursorResult
+from .cursor import BufferedColumnResultProxy
+from .cursor import BufferedColumnRow
+from .cursor import BufferedRowResultProxy
+from .cursor import CursorResult
+from .cursor import FullyBufferedResultProxy
+from .cursor import LegacyCursorResult
+from .cursor import ResultProxy
+from .interfaces import AdaptedConnection
+from .interfaces import Compiled
+from .interfaces import Connectable
+from .interfaces import CreateEnginePlugin
+from .interfaces import Dialect
+from .interfaces import ExceptionContext
+from .interfaces import ExecutionContext
+from .interfaces import TypeCompiler
+from .mock import create_mock_engine
+from .reflection import Inspector
+from .result import ChunkedIteratorResult
+from .result import FilterResult
+from .result import FrozenResult
+from .result import IteratorResult
+from .result import MappingResult
+from .result import MergedResult
+from .result import Result
+from .result import result_tuple
+from .result import ScalarResult
+from .row import BaseRow
+from .row import LegacyRow
+from .row import Row
+from .row import RowMapping
+from .url import make_url
+from .url import URL
+from .util import connection_memoize
+from ..sql import ddl
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
new file mode 100644
index 0000000..f126eb0
--- /dev/null
+++ b/lib/sqlalchemy/engine/base.py
@@ -0,0 +1,3450 @@
+# engine/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import with_statement
+
+import contextlib
+import sys
+
+from .interfaces import Connectable
+from .interfaces import ExceptionContext
+from .util import _distill_params
+from .util import _distill_params_20
+from .util import TransactionalContext
+from .. import exc
+from .. import inspection
+from .. import log
+from .. import util
+from ..sql import compiler
+from ..sql import util as sql_util
+
+
+"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`.
+
+"""
+
+_EMPTY_EXECUTION_OPTS = util.immutabledict()
+
+
+class Connection(Connectable):
+ """Provides high-level functionality for a wrapped DB-API connection.
+
+ **This is the SQLAlchemy 1.x.x version** of the :class:`_engine.Connection`
+ class. For the :term:`2.0 style` version, which features some API
+ differences, see :class:`_future.Connection`.
+
+ The :class:`_engine.Connection` object is procured by calling
+ the :meth:`_engine.Engine.connect` method of the :class:`_engine.Engine`
+ object, and provides services for execution of SQL statements as well
+ as transaction control.
+
+ The Connection object is **not** thread-safe. While a Connection can be
+ shared among threads using properly synchronized access, it is still
+ possible that the underlying DBAPI connection may not support shared
+ access between threads. Check the DBAPI documentation for details.
+
+ The Connection object represents a single DBAPI connection checked out
+    from the connection pool. In this state, the connection pool has no effect
+ upon the connection, including its expiration or timeout state. For the
+ connection pool to properly manage connections, connections should be
+ returned to the connection pool (i.e. ``connection.close()``) whenever the
+ connection is not in use.
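+
+    A typical checkout / release cycle, sketched here with the context
+    manager form so that ``close()`` is invoked automatically (assuming an
+    :class:`_engine.Engine` named ``engine``), might look like::
+
+        from sqlalchemy import text
+
+        with engine.connect() as connection:
+            result = connection.execute(text("select 'hello'"))
+            print(result.scalar())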
+
+ .. index::
+ single: thread safety; Connection
+
+ """
+
+ _is_future = False
+ _sqla_logger_namespace = "sqlalchemy.engine.Connection"
+
+ # used by sqlalchemy.engine.util.TransactionalContext
+ _trans_context_manager = None
+
+ def __init__(
+ self,
+ engine,
+ connection=None,
+ close_with_result=False,
+ _branch_from=None,
+ _execution_options=None,
+ _dispatch=None,
+ _has_events=None,
+ _allow_revalidate=True,
+ ):
+ """Construct a new Connection."""
+ self.engine = engine
+ self.dialect = engine.dialect
+ self.__branch_from = _branch_from
+
+ if _branch_from:
+ # branching is always "from" the root connection
+ assert _branch_from.__branch_from is None
+ self._dbapi_connection = connection
+ self._execution_options = _execution_options
+ self._echo = _branch_from._echo
+ self.should_close_with_result = False
+ self.dispatch = _dispatch
+ self._has_events = _branch_from._has_events
+ else:
+ self._dbapi_connection = (
+ connection
+ if connection is not None
+ else engine.raw_connection()
+ )
+
+ self._transaction = self._nested_transaction = None
+ self.__savepoint_seq = 0
+ self.__in_begin = False
+ self.should_close_with_result = close_with_result
+
+ self.__can_reconnect = _allow_revalidate
+ self._echo = self.engine._should_log_info()
+
+ if _has_events is None:
+ # if _has_events is sent explicitly as False,
+ # then don't join the dispatch of the engine; we don't
+ # want to handle any of the engine's events in that case.
+ self.dispatch = self.dispatch._join(engine.dispatch)
+ self._has_events = _has_events or (
+ _has_events is None and engine._has_events
+ )
+
+ assert not _execution_options
+ self._execution_options = engine._execution_options
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.engine_connect(self, _branch_from is not None)
+
+ @util.memoized_property
+ def _message_formatter(self):
+ if "logging_token" in self._execution_options:
+ token = self._execution_options["logging_token"]
+ return lambda msg: "[%s] %s" % (token, msg)
+ else:
+ return None
+
+ def _log_info(self, message, *arg, **kw):
+ fmt = self._message_formatter
+
+ if fmt:
+ message = fmt(message)
+
+ if log.STACKLEVEL:
+ kw["stacklevel"] = 1 + log.STACKLEVEL_OFFSET
+
+ self.engine.logger.info(message, *arg, **kw)
+
+ def _log_debug(self, message, *arg, **kw):
+ fmt = self._message_formatter
+
+ if fmt:
+ message = fmt(message)
+
+ if log.STACKLEVEL:
+ kw["stacklevel"] = 1 + log.STACKLEVEL_OFFSET
+
+ self.engine.logger.debug(message, *arg, **kw)
+
+ @property
+ def _schema_translate_map(self):
+ return self._execution_options.get("schema_translate_map", None)
+
+ def schema_for_object(self, obj):
+ """Return the schema name for the given schema item taking into
+        account the current schema translate map.
+
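+        A minimal sketch (``connection`` is an existing
+        :class:`_engine.Connection`; the map key and value, and the use of
+        ``Table`` / ``MetaData`` imported from ``sqlalchemy``, are
+        illustrative only)::
+
+            conn = connection.execution_options(
+                schema_translate_map={"per_tenant": "tenant_a"}
+            )
+            user_table = Table("user", MetaData(), schema="per_tenant")
+            conn.schema_for_object(user_table)  # returns "tenant_a"
+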
+ """
+
+ name = obj.schema
+ schema_translate_map = self._execution_options.get(
+ "schema_translate_map", None
+ )
+
+ if (
+ schema_translate_map
+ and name in schema_translate_map
+ and obj._use_schema_map
+ ):
+ return schema_translate_map[name]
+ else:
+ return name
+
+ def _branch(self):
+ """Return a new Connection which references this Connection's
+        engine and connection, but does not have close_with_result enabled,
+ and also whose close() method does nothing.
+
+ .. deprecated:: 1.4 the "branching" concept will be removed in
+ SQLAlchemy 2.0 as well as the "Connection.connect()" method which
+ is the only consumer for this.
+
+ The Core uses this very sparingly, only in the case of
+ custom SQL default functions that are to be INSERTed as the
+ primary key of a row where we need to get the value back, so we have
+ to invoke it distinctly - this is a very uncommon case.
+
+ Userland code accesses _branch() when the connect()
+ method is called. The branched connection
+ acts as much as possible like the parent, except that it stays
+ connected when a close() event occurs.
+
+ """
+ return self.engine._connection_cls(
+ self.engine,
+ self._dbapi_connection,
+ _branch_from=self.__branch_from if self.__branch_from else self,
+ _execution_options=self._execution_options,
+ _has_events=self._has_events,
+ _dispatch=self.dispatch,
+ )
+
+ def _generate_for_options(self):
+ """define connection method chaining behavior for execution_options"""
+
+ if self._is_future:
+ return self
+ else:
+ c = self.__class__.__new__(self.__class__)
+ c.__dict__ = self.__dict__.copy()
+ return c
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ self.close()
+
+ def execution_options(self, **opt):
+ r""" Set non-SQL options for the connection which take effect
+ during execution.
+
+ For a "future" style connection, this method returns this same
+ :class:`_future.Connection` object with the new options added.
+
+ For a legacy connection, this method returns a copy of this
+ :class:`_engine.Connection` which references the same underlying DBAPI
+ connection, but also defines the given execution options which will
+ take effect for a call to
+ :meth:`execute`. As the new :class:`_engine.Connection` references the
+ same underlying resource, it's usually a good idea to ensure that
+ the copies will be discarded immediately, which is implicit if used
+ as in::
+
+ result = connection.execution_options(stream_results=True).\
+ execute(stmt)
+
+ Note that any key/value can be passed to
+ :meth:`_engine.Connection.execution_options`,
+ and it will be stored in the
+ ``_execution_options`` dictionary of the :class:`_engine.Connection`.
+ It
+ is suitable for usage by end-user schemes to communicate with
+ event listeners, for example.
+
+ The keywords that are currently recognized by SQLAlchemy itself
+ include all those listed under :meth:`.Executable.execution_options`,
+ as well as others that are specific to :class:`_engine.Connection`.
+
+ :param autocommit: Available on: Connection, statement.
+ When True, a COMMIT will be invoked after execution
+ when executed in 'autocommit' mode, i.e. when an explicit
+ transaction is not begun on the connection. Note that this
+ is **library level, not DBAPI level autocommit**. The DBAPI
+ connection will remain in a real transaction unless the
+ "AUTOCOMMIT" isolation level is used.
+
+ .. deprecated:: 1.4 The "autocommit" execution option is deprecated
+ and will be removed in SQLAlchemy 2.0. See
+ :ref:`migration_20_autocommit` for discussion.
+
+ :param compiled_cache: Available on: Connection.
+ A dictionary where :class:`.Compiled` objects
+ will be cached when the :class:`_engine.Connection`
+ compiles a clause
+ expression into a :class:`.Compiled` object. This dictionary will
+ supersede the statement cache that may be configured on the
+ :class:`_engine.Engine` itself. If set to None, caching
+ is disabled, even if the engine has a configured cache size.
+
+ Note that the ORM makes use of its own "compiled" caches for
+ some operations, including flush operations. The caching
+ used by the ORM internally supersedes a cache dictionary
+ specified here.
+
+ :param logging_token: Available on: :class:`_engine.Connection`,
+ :class:`_engine.Engine`.
+
+ Adds the specified string token surrounded by brackets in log
+ messages logged by the connection, i.e. the logging that's enabled
+ either via the :paramref:`_sa.create_engine.echo` flag or via the
+ ``logging.getLogger("sqlalchemy.engine")`` logger. This allows a
+ per-connection or per-sub-engine token to be available which is
+ useful for debugging concurrent connection scenarios.
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :ref:`dbengine_logging_tokens` - usage example
+
+ :paramref:`_sa.create_engine.logging_name` - adds a name to the
+ name used by the Python logger object itself.
+
+ :param isolation_level: Available on: :class:`_engine.Connection`.
+
+ Set the transaction isolation level for the lifespan of this
+ :class:`_engine.Connection` object.
+ Valid values include those string
+ values accepted by the :paramref:`_sa.create_engine.isolation_level`
+ parameter passed to :func:`_sa.create_engine`. These levels are
+ semi-database specific; see individual dialect documentation for
+ valid levels.
+
+ The isolation level option applies the isolation level by emitting
+ statements on the DBAPI connection, and **necessarily affects the
+ original Connection object overall**, not just the copy that is
+ returned by the call to :meth:`_engine.Connection.execution_options`
+ method. The isolation level will remain at the given setting until
+ the DBAPI connection itself is returned to the connection pool, i.e.
+ the :meth:`_engine.Connection.close` method on the original
+ :class:`_engine.Connection` is called,
+ where an event handler will emit
+ additional statements on the DBAPI connection in order to revert the
+ isolation level change.
+
+ .. warning:: The ``isolation_level`` execution option should
+ **not** be used when a transaction is already established, that
+ is, the :meth:`_engine.Connection.begin`
+ method or similar has been
+ called. A database cannot change the isolation level on a
+ transaction in progress, and different DBAPIs and/or
+ SQLAlchemy dialects may implicitly roll back or commit
+ the transaction, or not affect the connection at all.
+
+ .. note:: The ``isolation_level`` execution option is implicitly
+ reset if the :class:`_engine.Connection` is invalidated, e.g. via
+ the :meth:`_engine.Connection.invalidate` method, or if a
+ disconnection error occurs. The new connection produced after
+ the invalidation will not have the isolation level re-applied
+ to it automatically.
+
+ .. seealso::
+
+ :paramref:`_sa.create_engine.isolation_level`
+ - set per :class:`_engine.Engine` isolation level
+
+ :meth:`_engine.Connection.get_isolation_level`
+ - view current level
+
+ :ref:`SQLite Transaction Isolation <sqlite_isolation_level>`
+
+ :ref:`PostgreSQL Transaction Isolation <postgresql_isolation_level>`
+
+ :ref:`MySQL Transaction Isolation <mysql_isolation_level>`
+
+ :ref:`SQL Server Transaction Isolation <mssql_isolation_level>`
+
+ :ref:`session_transaction_isolation` - for the ORM
+
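+          A brief usage sketch (the level shown is illustrative and
+          backend-dependent; ``engine`` and ``text`` are assumed to be
+          an existing :class:`_engine.Engine` and the ``sqlalchemy.text``
+          construct)::
+
+              with engine.connect().execution_options(
+                  isolation_level="SERIALIZABLE"
+              ) as conn:
+                  conn.execute(text("select * from some_table"))
+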
+ :param no_parameters: When ``True``, if the final parameter
+ list or dictionary is totally empty, will invoke the
+ statement on the cursor as ``cursor.execute(statement)``,
+ not passing the parameter collection at all.
+ Some DBAPIs such as psycopg2 and mysql-python consider
+ percent signs as significant only when parameters are
+ present; this option allows code to generate SQL
+ containing percent signs (and possibly other characters)
+ that is neutral regarding whether it's executed by the DBAPI
+ or piped into a script that's later invoked by
+ command line tools.
+
+ :param stream_results: Available on: Connection, statement.
+ Indicate to the dialect that results should be
+ "streamed" and not pre-buffered, if possible. For backends
+ such as PostgreSQL, MySQL and MariaDB, this indicates the use of
+ a "server side cursor" as opposed to a client side cursor.
+ Other backends such as that of Oracle may already use server
+ side cursors by default.
+
+ The usage of
+ :paramref:`_engine.Connection.execution_options.stream_results` is
+          usually combined with setting a fixed number of rows to be fetched
+ in batches, to allow for efficient iteration of database rows while
+ at the same time not loading all result rows into memory at once;
+ this can be configured on a :class:`_engine.Result` object using the
+ :meth:`_engine.Result.yield_per` method, after execution has
+ returned a new :class:`_engine.Result`. If
+ :meth:`_engine.Result.yield_per` is not used,
+ the :paramref:`_engine.Connection.execution_options.stream_results`
+ mode of operation will instead use a dynamically sized buffer
+ which buffers sets of rows at a time, growing on each batch
+ based on a fixed growth size up until a limit which may
+ be configured using the
+ :paramref:`_engine.Connection.execution_options.max_row_buffer`
+ parameter.
+
+ When using the ORM to fetch ORM mapped objects from a result,
+ :meth:`_engine.Result.yield_per` should always be used with
+ :paramref:`_engine.Connection.execution_options.stream_results`,
+ so that the ORM does not fetch all rows into new ORM objects at once.
+
+ For typical use, the
+ :paramref:`_engine.Connection.execution_options.yield_per` execution
+ option should be preferred, which sets up both
+ :paramref:`_engine.Connection.execution_options.stream_results` and
+ :meth:`_engine.Result.yield_per` at once. This option is supported
+ both at a core level by :class:`_engine.Connection` as well as by the
+ ORM :class:`_engine.Session`; the latter is described at
+ :ref:`orm_queryguide_yield_per`.
+
+ .. seealso::
+
+ :ref:`engine_stream_results` - background on
+ :paramref:`_engine.Connection.execution_options.stream_results`
+
+ :paramref:`_engine.Connection.execution_options.max_row_buffer`
+
+ :paramref:`_engine.Connection.execution_options.yield_per`
+
+ :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel`
+ describing the ORM version of ``yield_per``
+
+ :param max_row_buffer: Available on: :class:`_engine.Connection`,
+ :class:`_sql.Executable`. Sets a maximum
+ buffer size to use when the
+ :paramref:`_engine.Connection.execution_options.stream_results`
+ execution option is used on a backend that supports server side
+ cursors. The default value if not specified is 1000.
+
+ .. seealso::
+
+ :paramref:`_engine.Connection.execution_options.stream_results`
+
+ :ref:`engine_stream_results`
+
+
+ :param yield_per: Available on: :class:`_engine.Connection`,
+ :class:`_sql.Executable`. Integer value applied which will
+ set the :paramref:`_engine.Connection.execution_options.stream_results`
+ execution option and invoke :meth:`_engine.Result.yield_per`
+ automatically at once. Allows equivalent functionality as
+ is present when using this parameter with the ORM.
+
+ .. versionadded:: 1.4.40
+
+ .. seealso::
+
+ :ref:`engine_stream_results` - background and examples
+ on using server side cursors with Core.
+
+ :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel`
+ describing the ORM version of ``yield_per``
+
+        :param schema_translate_map: Available on: :class:`_engine.Connection`,
+          :class:`_engine.Engine`, :class:`_sql.Executable`.
+
+          A dictionary mapping schema names to schema names that will be
+ applied to the :paramref:`_schema.Table.schema` element of each
+ :class:`_schema.Table`
+ encountered when SQL or DDL expression elements
+ are compiled into strings; the resulting schema name will be
+          converted based on the presence of the original name in the map.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`schema_translating`
+
+ .. seealso::
+
+ :meth:`_engine.Engine.execution_options`
+
+ :meth:`.Executable.execution_options`
+
+ :meth:`_engine.Connection.get_execution_options`
+
+
+ """ # noqa
+ c = self._generate_for_options()
+ c._execution_options = c._execution_options.union(opt)
+ if self._has_events or self.engine._has_events:
+ self.dispatch.set_connection_execution_options(c, opt)
+ self.dialect.set_connection_execution_options(c, opt)
+ return c
+
+ def get_execution_options(self):
+ """Get the non-SQL options which will take effect during execution.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :meth:`_engine.Connection.execution_options`
+ """
+ return self._execution_options
+
+ @property
+ def closed(self):
+ """Return True if this connection is closed."""
+
+ # note this is independent for a "branched" connection vs.
+ # the base
+
+ return self._dbapi_connection is None and not self.__can_reconnect
+
+ @property
+ def invalidated(self):
+ """Return True if this connection was invalidated."""
+
+ # prior to 1.4, "invalid" was stored as a state independent of
+ # "closed", meaning an invalidated connection could be "closed",
+ # the _dbapi_connection would be None and closed=True, yet the
+ # "invalid" flag would stay True. This meant that there were
+ # three separate states (open/valid, closed/valid, closed/invalid)
+ # when there is really no reason for that; a connection that's
+ # "closed" does not need to be "invalid". So the state is now
+ # represented by the two facts alone.
+
+ if self.__branch_from:
+ return self.__branch_from.invalidated
+
+ return self._dbapi_connection is None and not self.closed
+
+ @property
+ def connection(self):
+ """The underlying DB-API connection managed by this Connection.
+
+ This is a SQLAlchemy connection-pool proxied connection
+ which then has the attribute
+ :attr:`_pool._ConnectionFairy.dbapi_connection` that refers to the
+ actual driver connection.
+
+ .. seealso::
+
+
+ :ref:`dbapi_connections`
+
+ """
+
+ if self._dbapi_connection is None:
+ try:
+ return self._revalidate_connection()
+ except (exc.PendingRollbackError, exc.ResourceClosedError):
+ raise
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+ else:
+ return self._dbapi_connection
+
+ def get_isolation_level(self):
+ """Return the current isolation level assigned to this
+ :class:`_engine.Connection`.
+
+ This will typically be the default isolation level as determined
+        by the dialect, unless the
+ :paramref:`.Connection.execution_options.isolation_level`
+ feature has been used to alter the isolation level on a
+ per-:class:`_engine.Connection` basis.
+
+        This method will typically perform a live SQL operation in order
+ to procure the current isolation level, so the value returned is the
+ actual level on the underlying DBAPI connection regardless of how
+ this state was set. Compare to the
+ :attr:`_engine.Connection.default_isolation_level` accessor
+ which returns the dialect-level setting without performing a SQL
+ query.
+
+ .. versionadded:: 0.9.9
+
+ .. seealso::
+
+ :attr:`_engine.Connection.default_isolation_level`
+ - view default level
+
+ :paramref:`_sa.create_engine.isolation_level`
+ - set per :class:`_engine.Engine` isolation level
+
+ :paramref:`.Connection.execution_options.isolation_level`
+ - set per :class:`_engine.Connection` isolation level
+
+ """
+ try:
+ return self.dialect.get_isolation_level(self.connection)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+
+ @property
+ def default_isolation_level(self):
+ """The default isolation level assigned to this
+ :class:`_engine.Connection`.
+
+ This is the isolation level setting that the
+ :class:`_engine.Connection`
+ has when first procured via the :meth:`_engine.Engine.connect` method.
+ This level stays in place until the
+ :paramref:`.Connection.execution_options.isolation_level` is used
+ to change the setting on a per-:class:`_engine.Connection` basis.
+
+ Unlike :meth:`_engine.Connection.get_isolation_level`,
+ this attribute is set
+ ahead of time from the first connection procured by the dialect,
+        so a SQL query is not invoked when this accessor is called.
+
+ .. versionadded:: 0.9.9
+
+ .. seealso::
+
+ :meth:`_engine.Connection.get_isolation_level`
+ - view current level
+
+ :paramref:`_sa.create_engine.isolation_level`
+ - set per :class:`_engine.Engine` isolation level
+
+ :paramref:`.Connection.execution_options.isolation_level`
+ - set per :class:`_engine.Connection` isolation level
+
+ """
+ return self.dialect.default_isolation_level
+
+ def _invalid_transaction(self):
+ if self.invalidated:
+ raise exc.PendingRollbackError(
+ "Can't reconnect until invalid %stransaction is rolled "
+ "back."
+ % (
+ "savepoint "
+ if self._nested_transaction is not None
+ else ""
+ ),
+ code="8s2b",
+ )
+ else:
+ assert not self._is_future
+ raise exc.PendingRollbackError(
+ "This connection is on an inactive %stransaction. "
+ "Please rollback() fully before proceeding."
+ % (
+ "savepoint "
+ if self._nested_transaction is not None
+ else ""
+ ),
+ code="8s2a",
+ )
+
+ def _revalidate_connection(self):
+ if self.__branch_from:
+ return self.__branch_from._revalidate_connection()
+ if self.__can_reconnect and self.invalidated:
+ if self._transaction is not None:
+ self._invalid_transaction()
+ self._dbapi_connection = self.engine.raw_connection(
+ _connection=self
+ )
+ return self._dbapi_connection
+ raise exc.ResourceClosedError("This Connection is closed")
+
+ @property
+ def _still_open_and_dbapi_connection_is_valid(self):
+ return self._dbapi_connection is not None and getattr(
+ self._dbapi_connection, "is_valid", False
+ )
+
+ @property
+ def info(self):
+ """Info dictionary associated with the underlying DBAPI connection
+ referred to by this :class:`_engine.Connection`, allowing user-defined
+ data to be associated with the connection.
+
+ The data here will follow along with the DBAPI connection including
+ after it is returned to the connection pool and used again
+ in subsequent instances of :class:`_engine.Connection`.
+
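+        E.g., a sketch of attaching application data (the key name is
+        arbitrary and illustrative only)::
+
+            connection.info["request_token"] = "abc123"
+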
+ """
+
+ return self.connection.info
+
+ @util.deprecated_20(":meth:`.Connection.connect`")
+ def connect(self, close_with_result=False):
+ """Returns a branched version of this :class:`_engine.Connection`.
+
+ The :meth:`_engine.Connection.close` method on the returned
+ :class:`_engine.Connection` can be called and this
+ :class:`_engine.Connection` will remain open.
+
+ This method provides usage symmetry with
+ :meth:`_engine.Engine.connect`, including for usage
+ with context managers.
+
+ """
+
+ return self._branch()
+
+ def invalidate(self, exception=None):
+ """Invalidate the underlying DBAPI connection associated with
+ this :class:`_engine.Connection`.
+
+ An attempt will be made to close the underlying DBAPI connection
+ immediately; however if this operation fails, the error is logged
+ but not raised. The connection is then discarded whether or not
+ close() succeeded.
+
+ Upon the next use (where "use" typically means using the
+ :meth:`_engine.Connection.execute` method or similar),
+ this :class:`_engine.Connection` will attempt to
+ procure a new DBAPI connection using the services of the
+ :class:`_pool.Pool` as a source of connectivity (e.g.
+ a "reconnection").
+
+ If a transaction was in progress (e.g. the
+ :meth:`_engine.Connection.begin` method has been called) when
+ :meth:`_engine.Connection.invalidate` method is called, at the DBAPI
+ level all state associated with this transaction is lost, as
+ the DBAPI connection is closed. The :class:`_engine.Connection`
+ will not allow a reconnection to proceed until the
+ :class:`.Transaction` object is ended, by calling the
+ :meth:`.Transaction.rollback` method; until that point, any attempt at
+ continuing to use the :class:`_engine.Connection` will raise an
+ :class:`~sqlalchemy.exc.InvalidRequestError`.
+ This is to prevent applications from accidentally
+        continuing ongoing transactional operations despite the
+ fact that the transaction has been lost due to an
+ invalidation.
+
+ The :meth:`_engine.Connection.invalidate` method,
+ just like auto-invalidation,
+ will at the connection pool level invoke the
+ :meth:`_events.PoolEvents.invalidate` event.
+
+ :param exception: an optional ``Exception`` instance that's the
+         reason for the invalidation. It is passed along to event handlers
+ and logging functions.
+
+ .. seealso::
+
+ :ref:`pool_connection_invalidation`
+
+ """
+
+ if self.__branch_from:
+ return self.__branch_from.invalidate(exception=exception)
+
+ if self.invalidated:
+ return
+
+ if self.closed:
+ raise exc.ResourceClosedError("This Connection is closed")
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ self._dbapi_connection.invalidate(exception)
+ self._dbapi_connection = None
+
+ def detach(self):
+ """Detach the underlying DB-API connection from its connection pool.
+
+ E.g.::
+
+ with engine.connect() as conn:
+ conn.detach()
+ conn.execute(text("SET search_path TO schema1, schema2"))
+
+ # work with connection
+
+ # connection is fully closed (since we used "with:", can
+ # also call .close())
+
+ This :class:`_engine.Connection` instance will remain usable.
+ When closed
+ (or exited from a context manager context as above),
+ the DB-API connection will be literally closed and not
+ returned to its originating pool.
+
+ This method can be used to insulate the rest of an application
+ from a modified state on a connection (such as a transaction
+ isolation level or similar).
+
+ """
+
+ self._dbapi_connection.detach()
+
+ def _autobegin(self):
+ self.begin()
+
+ def begin(self):
+ """Begin a transaction and return a transaction handle.
+
+ The returned object is an instance of :class:`.Transaction`.
+ This object represents the "scope" of the transaction,
+ which completes when either the :meth:`.Transaction.rollback`
+ or :meth:`.Transaction.commit` method is called.
+
+ .. tip::
+
+ The :meth:`_engine.Connection.begin` method is invoked when using
+ the :meth:`_engine.Engine.begin` context manager method as well.
+ All documentation that refers to behaviors specific to the
+ :meth:`_engine.Connection.begin` method also apply to use of the
+ :meth:`_engine.Engine.begin` method.
+
+ Legacy use: nested calls to :meth:`.begin` on the same
+ :class:`_engine.Connection` will return new :class:`.Transaction`
+ objects that represent an emulated transaction within the scope of the
+ enclosing transaction, that is::
+
+ trans = conn.begin() # outermost transaction
+ trans2 = conn.begin() # "nested"
+ trans2.commit() # does nothing
+ trans.commit() # actually commits
+
+ Calls to :meth:`.Transaction.commit` only have an effect
+ when invoked via the outermost :class:`.Transaction` object, though the
+ :meth:`.Transaction.rollback` method of any of the
+ :class:`.Transaction` objects will roll back the
+ transaction.
+
+ .. tip::
+
+ The above "nesting" behavior is a legacy behavior specific to
+ :term:`1.x style` use and will be removed in SQLAlchemy 2.0. For
+ notes on :term:`2.0 style` use, see
+ :meth:`_future.Connection.begin`.
+
+
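+        A context-manager oriented sketch (``engine`` and ``table`` are
+        assumed to be an existing :class:`_engine.Engine` and
+        :class:`_schema.Table`)::
+
+            with engine.connect() as conn:
+                with conn.begin():
+                    conn.execute(table.insert(), {"data": "some data"})
+                # committed here if no exception was raised
+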
+ .. seealso::
+
+ :meth:`_engine.Connection.begin_nested` - use a SAVEPOINT
+
+ :meth:`_engine.Connection.begin_twophase` -
+ use a two phase /XID transaction
+
+ :meth:`_engine.Engine.begin` - context manager available from
+ :class:`_engine.Engine`
+
+ """
+ if self._is_future:
+ assert not self.__branch_from
+ elif self.__branch_from:
+ return self.__branch_from.begin()
+
+ if self.__in_begin:
+ # for dialects that emit SQL within the process of
+ # dialect.do_begin() or dialect.do_begin_twophase(), this
+ # flag prevents "autobegin" from being emitted within that
+ # process, while allowing self._transaction to remain at None
+ # until it's complete.
+ return
+ elif self._transaction is None:
+ self._transaction = RootTransaction(self)
+ return self._transaction
+ else:
+ if self._is_future:
+ raise exc.InvalidRequestError(
+ "This connection has already initialized a SQLAlchemy "
+ "Transaction() object via begin() or autobegin; can't "
+ "call begin() here unless rollback() or commit() "
+ "is called first."
+ )
+ else:
+ return MarkerTransaction(self)
+
+ def begin_nested(self):
+ """Begin a nested transaction (i.e. SAVEPOINT) and return a
+ transaction handle, assuming an outer transaction is already
+ established.
+
+ Nested transactions require SAVEPOINT support in the
+ underlying database. Any transaction in the hierarchy may
+ ``commit`` and ``rollback``, however the outermost transaction
+ still controls the overall ``commit`` or ``rollback`` of the
+        transaction as a whole.
+
+        The legacy form of the :meth:`_engine.Connection.begin_nested` method has
+ alternate behaviors based on whether or not the
+ :meth:`_engine.Connection.begin` method was called previously. If
+ :meth:`_engine.Connection.begin` was not called, then this method will
+ behave the same as the :meth:`_engine.Connection.begin` method and
+ return a :class:`.RootTransaction` object that begins and commits a
+ real transaction - **no savepoint is invoked**. If
+ :meth:`_engine.Connection.begin` **has** been called, and a
+ :class:`.RootTransaction` is already established, then this method
+ returns an instance of :class:`.NestedTransaction` which will invoke
+ and manage the scope of a SAVEPOINT.
+
+ .. tip::
+
+ The above mentioned behavior of
+ :meth:`_engine.Connection.begin_nested` is a legacy behavior
+ specific to :term:`1.x style` use. In :term:`2.0 style` use, the
+ :meth:`_future.Connection.begin_nested` method instead autobegins
+ the outer transaction that can be committed using
+ "commit-as-you-go" style; see
+ :meth:`_future.Connection.begin_nested` for migration details.
+
+ .. versionchanged:: 1.4.13 The behavior of
+ :meth:`_engine.Connection.begin_nested`
+ as returning a :class:`.RootTransaction` if
+ :meth:`_engine.Connection.begin` were not called has been restored
+ as was the case in 1.3.x versions; in previous 1.4.x versions, an
+ outer transaction would be "autobegun" but would not be committed.
+
+
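+        A sketch of 1.x style SAVEPOINT use (``conn`` is assumed to be a
+        :class:`_engine.Connection`, ``table`` an existing
+        :class:`_schema.Table`)::
+
+            trans = conn.begin()             # outer transaction
+            conn.execute(table.insert(), {"data": "d1"})
+            savepoint = conn.begin_nested()  # emits SAVEPOINT
+            conn.execute(table.insert(), {"data": "d2"})
+            savepoint.rollback()             # rolls back to the SAVEPOINT only
+            trans.commit()                   # "d1" is committed
+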
+ .. seealso::
+
+ :meth:`_engine.Connection.begin`
+
+ :ref:`session_begin_nested` - ORM support for SAVEPOINT
+
+ """
+ if self._is_future:
+ assert not self.__branch_from
+ elif self.__branch_from:
+ return self.__branch_from.begin_nested()
+
+ if self._transaction is None:
+ if not self._is_future:
+ util.warn_deprecated_20(
+ "Calling Connection.begin_nested() in 2.0 style use will "
+ "return a NestedTransaction (SAVEPOINT) in all cases, "
+ "that will not commit the outer transaction. For code "
+ "that is cross-compatible between 1.x and 2.0 style use, "
+ "ensure Connection.begin() is called before calling "
+ "Connection.begin_nested()."
+ )
+ return self.begin()
+ else:
+ self._autobegin()
+
+ return NestedTransaction(self)
+
+ def begin_twophase(self, xid=None):
+ """Begin a two-phase or XA transaction and return a transaction
+ handle.
+
+ The returned object is an instance of :class:`.TwoPhaseTransaction`,
+ which in addition to the methods provided by
+ :class:`.Transaction`, also provides a
+ :meth:`~.TwoPhaseTransaction.prepare` method.
+
+ :param xid: the two phase transaction id. If not supplied, a
+ random id will be generated.
+
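+        A minimal sketch (requires a backend and DBAPI with two-phase
+        support; ``table`` is assumed to be an existing
+        :class:`_schema.Table`)::
+
+            xa = conn.begin_twophase()
+            conn.execute(table.insert(), {"data": "value"})
+            xa.prepare()
+            xa.commit()
+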
+ .. seealso::
+
+ :meth:`_engine.Connection.begin`
+
+            :meth:`_engine.Connection.begin_nested`
+
+ """
+
+ if self.__branch_from:
+ return self.__branch_from.begin_twophase(xid=xid)
+
+ if self._transaction is not None:
+ raise exc.InvalidRequestError(
+ "Cannot start a two phase transaction when a transaction "
+ "is already in progress."
+ )
+ if xid is None:
+ xid = self.engine.dialect.create_xid()
+ return TwoPhaseTransaction(self, xid)
+
+ def recover_twophase(self):
+ return self.engine.dialect.do_recover_twophase(self)
+
+ def rollback_prepared(self, xid, recover=False):
+ self.engine.dialect.do_rollback_twophase(self, xid, recover=recover)
+
+ def commit_prepared(self, xid, recover=False):
+ self.engine.dialect.do_commit_twophase(self, xid, recover=recover)
+
+ def in_transaction(self):
+ """Return True if a transaction is in progress."""
+ if self.__branch_from is not None:
+ return self.__branch_from.in_transaction()
+
+ return self._transaction is not None and self._transaction.is_active
+
+ def in_nested_transaction(self):
+ """Return True if a transaction is in progress."""
+ if self.__branch_from is not None:
+ return self.__branch_from.in_nested_transaction()
+
+ return (
+ self._nested_transaction is not None
+ and self._nested_transaction.is_active
+ )
+
+ def _is_autocommit_isolation(self):
+ opt_iso = self._execution_options.get("isolation_level", None)
+ return bool(
+ opt_iso == "AUTOCOMMIT"
+ or (
+ opt_iso is None
+ and getattr(self.engine.dialect, "isolation_level", None)
+ == "AUTOCOMMIT"
+ )
+ )
+
+ def get_transaction(self):
+ """Return the current root transaction in progress, if any.
+
+ .. versionadded:: 1.4
+
+ """
+
+ if self.__branch_from is not None:
+ return self.__branch_from.get_transaction()
+
+ return self._transaction
+
+ def get_nested_transaction(self):
+ """Return the current nested transaction in progress, if any.
+
+ .. versionadded:: 1.4
+
+ """
+ if self.__branch_from is not None:
+
+ return self.__branch_from.get_nested_transaction()
+
+ return self._nested_transaction
+
+ def _begin_impl(self, transaction):
+ assert not self.__branch_from
+
+ if self._echo:
+ if self._is_autocommit_isolation():
+ self._log_info(
+ "BEGIN (implicit; DBAPI should not BEGIN due to "
+ "autocommit mode)"
+ )
+ else:
+ self._log_info("BEGIN (implicit)")
+
+ self.__in_begin = True
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.begin(self)
+
+ try:
+ self.engine.dialect.do_begin(self.connection)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+ finally:
+ self.__in_begin = False
+
+ def _rollback_impl(self):
+ assert not self.__branch_from
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.rollback(self)
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ if self._echo:
+ if self._is_autocommit_isolation():
+ self._log_info(
+ "ROLLBACK using DBAPI connection.rollback(), "
+ "DBAPI should ignore due to autocommit mode"
+ )
+ else:
+ self._log_info("ROLLBACK")
+ try:
+ self.engine.dialect.do_rollback(self.connection)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+
+ def _commit_impl(self, autocommit=False):
+ assert not self.__branch_from
+
+ # AUTOCOMMIT isolation-level is a dialect-specific concept, however
+ # if a connection has this set as the isolation level, we can skip
+ # the "autocommit" warning as the operation will do "autocommit"
+ # in any case
+ if autocommit and not self._is_autocommit_isolation():
+ util.warn_deprecated_20(
+ "The current statement is being autocommitted using "
+ "implicit autocommit, which will be removed in "
+ "SQLAlchemy 2.0. "
+ "Use the .begin() method of Engine or Connection in order to "
+ "use an explicit transaction for DML and DDL statements."
+ )
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.commit(self)
+
+ if self._echo:
+ if self._is_autocommit_isolation():
+ self._log_info(
+ "COMMIT using DBAPI connection.commit(), "
+ "DBAPI should ignore due to autocommit mode"
+ )
+ else:
+ self._log_info("COMMIT")
+ try:
+ self.engine.dialect.do_commit(self.connection)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+
+ def _savepoint_impl(self, name=None):
+ assert not self.__branch_from
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.savepoint(self, name)
+
+ if name is None:
+ self.__savepoint_seq += 1
+ name = "sa_savepoint_%s" % self.__savepoint_seq
+ if self._still_open_and_dbapi_connection_is_valid:
+ self.engine.dialect.do_savepoint(self, name)
+ return name
+
+ def _rollback_to_savepoint_impl(self, name):
+ assert not self.__branch_from
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.rollback_savepoint(self, name, None)
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ self.engine.dialect.do_rollback_to_savepoint(self, name)
+
+ def _release_savepoint_impl(self, name):
+ assert not self.__branch_from
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.release_savepoint(self, name, None)
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ self.engine.dialect.do_release_savepoint(self, name)
+
+ def _begin_twophase_impl(self, transaction):
+ assert not self.__branch_from
+
+ if self._echo:
+ self._log_info("BEGIN TWOPHASE (implicit)")
+ if self._has_events or self.engine._has_events:
+ self.dispatch.begin_twophase(self, transaction.xid)
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ self.__in_begin = True
+ try:
+ self.engine.dialect.do_begin_twophase(self, transaction.xid)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+ finally:
+ self.__in_begin = False
+
+ def _prepare_twophase_impl(self, xid):
+ assert not self.__branch_from
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.prepare_twophase(self, xid)
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ assert isinstance(self._transaction, TwoPhaseTransaction)
+ try:
+ self.engine.dialect.do_prepare_twophase(self, xid)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+
+ def _rollback_twophase_impl(self, xid, is_prepared):
+ assert not self.__branch_from
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.rollback_twophase(self, xid, is_prepared)
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ assert isinstance(self._transaction, TwoPhaseTransaction)
+ try:
+ self.engine.dialect.do_rollback_twophase(
+ self, xid, is_prepared
+ )
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+
+ def _commit_twophase_impl(self, xid, is_prepared):
+ assert not self.__branch_from
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.commit_twophase(self, xid, is_prepared)
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ assert isinstance(self._transaction, TwoPhaseTransaction)
+ try:
+ self.engine.dialect.do_commit_twophase(self, xid, is_prepared)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+
+ def _autorollback(self):
+ if self.__branch_from:
+ self.__branch_from._autorollback()
+
+ if not self.in_transaction():
+ self._rollback_impl()
+
+ def _warn_for_legacy_exec_format(self):
+ util.warn_deprecated_20(
+ "The connection.execute() method in "
+ "SQLAlchemy 2.0 will accept parameters as a single "
+ "dictionary or a "
+ "single sequence of dictionaries only. "
+ "Parameters passed as keyword arguments, tuples or positionally "
+ "oriented dictionaries and/or tuples "
+ "will no longer be accepted."
+ )
+
+ def close(self):
+ """Close this :class:`_engine.Connection`.
+
+ This results in a release of the underlying database
+ resources, that is, the DBAPI connection referenced
+ internally. The DBAPI connection is typically restored
+ back to the connection-holding :class:`_pool.Pool` referenced
+ by the :class:`_engine.Engine` that produced this
+ :class:`_engine.Connection`. Any transactional state present on
+ the DBAPI connection is also unconditionally released via
+ the DBAPI connection's ``rollback()`` method, regardless
+ of any :class:`.Transaction` object that may be
+ outstanding with regards to this :class:`_engine.Connection`.
+
+ After :meth:`_engine.Connection.close` is called, the
+ :class:`_engine.Connection` is permanently in a closed state,
+ and will allow no further operations.
+
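+        A sketch of explicit close, equivalent to using the connection as a
+        context manager (``engine`` is an assumed existing
+        :class:`_engine.Engine`; ``text`` is the ``sqlalchemy.text``
+        construct)::
+
+            conn = engine.connect()
+            try:
+                result = conn.execute(text("select 1"))
+            finally:
+                conn.close()
+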
+ """
+
+ if self.__branch_from:
+ assert not self._is_future
+ util.warn_deprecated_20(
+ "The .close() method on a so-called 'branched' connection is "
+ "deprecated as of 1.4, as are 'branched' connections overall, "
+ "and will be removed in a future release. If this is a "
+ "default-handling function, don't close the connection."
+ )
+ self._dbapi_connection = None
+ self.__can_reconnect = False
+ return
+
+ if self._transaction:
+ self._transaction.close()
+ skip_reset = True
+ else:
+ skip_reset = False
+
+ if self._dbapi_connection is not None:
+ conn = self._dbapi_connection
+
+ # as we just closed the transaction, close the connection
+ # pool connection without doing an additional reset
+ if skip_reset:
+ conn._close_no_reset()
+ else:
+ conn.close()
+
+ # There is a slight chance that conn.close() may have
+ # triggered an invalidation here in which case
+ # _dbapi_connection would already be None, however usually
+ # it will be non-None here and in a "closed" state.
+ self._dbapi_connection = None
+ self.__can_reconnect = False
+
+ def scalar(self, object_, *multiparams, **params):
+ """Executes and returns the first column of the first row.
+
+ The underlying result/cursor is closed after execution.
+
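+        E.g., a minimal sketch (``user_table`` is an assumed existing
+        :class:`_schema.Table`; ``select`` and ``func`` are imported from
+        ``sqlalchemy``)::
+
+            count = connection.scalar(
+                select(func.count()).select_from(user_table)
+            )
+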
+ """
+
+ return self.execute(object_, *multiparams, **params).scalar()
+
+ def scalars(self, object_, *multiparams, **params):
+ """Executes and returns a scalar result set, which yields scalar values
+ from the first column of each row.
+
+ This method is equivalent to calling :meth:`_engine.Connection.execute`
+ to receive a :class:`_result.Result` object, then invoking the
+ :meth:`_result.Result.scalars` method to produce a
+ :class:`_result.ScalarResult` instance.
+
+ :return: a :class:`_result.ScalarResult`
+
+ .. versionadded:: 1.4.24
+
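+        E.g., a brief sketch (``user_table`` is an assumed existing
+        :class:`_schema.Table`; ``select`` is imported from
+        ``sqlalchemy``)::
+
+            names = connection.scalars(
+                select(user_table.c.name)
+            ).all()
+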
+ """
+
+ return self.execute(object_, *multiparams, **params).scalars()
+
+ def execute(self, statement, *multiparams, **params):
+ r"""Executes a SQL statement construct and returns a
+ :class:`_engine.CursorResult`.
+
+ :param statement: The statement to be executed. May be
+ one of:
+
+ * a plain string (deprecated)
+ * any :class:`_expression.ClauseElement` construct that is also
+ a subclass of :class:`.Executable`, such as a
+ :func:`_expression.select` construct
+ * a :class:`.FunctionElement`, such as that generated
+ by :data:`.func`, will be automatically wrapped in
+ a SELECT statement, which is then executed.
+ * a :class:`.DDLElement` object
+ * a :class:`.DefaultGenerator` object
+ * a :class:`.Compiled` object
+
+ .. deprecated:: 2.0 passing a string to
+ :meth:`_engine.Connection.execute` is
+ deprecated and will be removed in version 2.0. Use the
+ :func:`_expression.text` construct with
+ :meth:`_engine.Connection.execute`, or the
+ :meth:`_engine.Connection.exec_driver_sql`
+ method to invoke a driver-level
+ SQL string.
+
+ :param \*multiparams/\**params: represent bound parameter
+ values to be used in the execution. Typically,
+ the format is either a collection of one or more
+ dictionaries passed to \*multiparams::
+
+ conn.execute(
+ table.insert(),
+ {"id":1, "value":"v1"},
+ {"id":2, "value":"v2"}
+ )
+
+ ...or individual key/values interpreted by \**params::
+
+ conn.execute(
+ table.insert(), id=1, value="v1"
+ )
+
+ In the case that a plain SQL string is passed, and the underlying
+ DBAPI accepts positional bind parameters, a collection of tuples
+ or individual values in \*multiparams may be passed::
+
+ conn.execute(
+ "INSERT INTO table (id, value) VALUES (?, ?)",
+ (1, "v1"), (2, "v2")
+ )
+
+ conn.execute(
+ "INSERT INTO table (id, value) VALUES (?, ?)",
+ 1, "v1"
+ )
+
+ Note above, the usage of a question mark "?" or other
+ symbol is contingent upon the "paramstyle" accepted by the DBAPI
+ in use, which may be any of "qmark", "named", "pyformat", "format",
+ "numeric". See `pep-249
+ <https://www.python.org/dev/peps/pep-0249/>`_ for details on
+ paramstyle.
+
+ To execute a textual SQL statement which uses bound parameters in a
+ DBAPI-agnostic way, use the :func:`_expression.text` construct.
+
+ .. deprecated:: 2.0 use of tuple or scalar positional parameters
+ is deprecated. All params should be dicts or sequences of dicts.
+ Use :meth:`.exec_driver_sql` to execute a plain string with
+ tuple or scalar positional parameters.
+
+ """
+
+ if isinstance(statement, util.string_types):
+ util.warn_deprecated_20(
+ "Passing a string to Connection.execute() is "
+ "deprecated and will be removed in version 2.0. Use the "
+ "text() construct, "
+ "or the Connection.exec_driver_sql() method to invoke a "
+ "driver-level SQL string."
+ )
+
+ return self._exec_driver_sql(
+ statement,
+ multiparams,
+ params,
+ _EMPTY_EXECUTION_OPTS,
+ future=False,
+ )
+
+ try:
+ meth = statement._execute_on_connection
+ except AttributeError as err:
+ util.raise_(
+ exc.ObjectNotExecutableError(statement), replace_context=err
+ )
+ else:
+ return meth(self, multiparams, params, _EMPTY_EXECUTION_OPTS)
+
+ def _execute_function(self, func, multiparams, params, execution_options):
+ """Execute a sql.FunctionElement object."""
+
+ return self._execute_clauseelement(
+ func.select(), multiparams, params, execution_options
+ )
+
+ def _execute_default(
+ self,
+ default,
+ multiparams,
+ params,
+ # migrate is calling this directly :(
+ execution_options=_EMPTY_EXECUTION_OPTS,
+ ):
+ """Execute a schema.ColumnDefault object."""
+
+ execution_options = self._execution_options.merge_with(
+ execution_options
+ )
+
+ distilled_parameters = _distill_params(self, multiparams, params)
+
+ if self._has_events or self.engine._has_events:
+ (
+ default,
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ default, distilled_parameters, execution_options
+ )
+
+ try:
+ conn = self._dbapi_connection
+ if conn is None:
+ conn = self._revalidate_connection()
+
+ dialect = self.dialect
+ ctx = dialect.execution_ctx_cls._init_default(
+ dialect, self, conn, execution_options
+ )
+ except (exc.PendingRollbackError, exc.ResourceClosedError):
+ raise
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+
+ ret = ctx._exec_default(None, default, None)
+ if self.should_close_with_result:
+ self.close()
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.after_execute(
+ self,
+ default,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
+ )
+
+ return ret
+
+ def _execute_ddl(self, ddl, multiparams, params, execution_options):
+ """Execute a schema.DDL object."""
+
+ execution_options = ddl._execution_options.merge_with(
+ self._execution_options, execution_options
+ )
+
+ distilled_parameters = _distill_params(self, multiparams, params)
+
+ if self._has_events or self.engine._has_events:
+ (
+ ddl,
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ ddl, distilled_parameters, execution_options
+ )
+
+ exec_opts = self._execution_options.merge_with(execution_options)
+ schema_translate_map = exec_opts.get("schema_translate_map", None)
+
+ dialect = self.dialect
+
+ compiled = ddl.compile(
+ dialect=dialect, schema_translate_map=schema_translate_map
+ )
+ ret = self._execute_context(
+ dialect,
+ dialect.execution_ctx_cls._init_ddl,
+ compiled,
+ None,
+ execution_options,
+ compiled,
+ )
+ if self._has_events or self.engine._has_events:
+ self.dispatch.after_execute(
+ self,
+ ddl,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
+ )
+ return ret
+
+ def _invoke_before_exec_event(
+ self, elem, distilled_params, execution_options
+ ):
+
+ if len(distilled_params) == 1:
+ event_multiparams, event_params = [], distilled_params[0]
+ else:
+ event_multiparams, event_params = distilled_params, {}
+
+ for fn in self.dispatch.before_execute:
+ elem, event_multiparams, event_params = fn(
+ self,
+ elem,
+ event_multiparams,
+ event_params,
+ execution_options,
+ )
+
+ if event_multiparams:
+ distilled_params = list(event_multiparams)
+ if event_params:
+ raise exc.InvalidRequestError(
+ "Event handler can't return non-empty multiparams "
+ "and params at the same time"
+ )
+ elif event_params:
+ distilled_params = [event_params]
+ else:
+ distilled_params = []
+
+ return elem, distilled_params, event_multiparams, event_params
+
+ def _execute_clauseelement(
+ self, elem, multiparams, params, execution_options
+ ):
+ """Execute a sql.ClauseElement object."""
+
+ execution_options = elem._execution_options.merge_with(
+ self._execution_options, execution_options
+ )
+
+ distilled_params = _distill_params(self, multiparams, params)
+
+ has_events = self._has_events or self.engine._has_events
+ if has_events:
+ (
+ elem,
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ elem, distilled_params, execution_options
+ )
+
+ if distilled_params:
+ # ensure we don't retain a link to the view object for keys()
+ # which links to the values, which we don't want to cache
+ keys = sorted(distilled_params[0])
+ for_executemany = len(distilled_params) > 1
+ else:
+ keys = []
+ for_executemany = False
+
+ dialect = self.dialect
+
+ schema_translate_map = execution_options.get(
+ "schema_translate_map", None
+ )
+
+ compiled_cache = execution_options.get(
+ "compiled_cache", self.engine._compiled_cache
+ )
+
+ compiled_sql, extracted_params, cache_hit = elem._compile_w_cache(
+ dialect=dialect,
+ compiled_cache=compiled_cache,
+ column_keys=keys,
+ for_executemany=for_executemany,
+ schema_translate_map=schema_translate_map,
+ linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
+ )
+ ret = self._execute_context(
+ dialect,
+ dialect.execution_ctx_cls._init_compiled,
+ compiled_sql,
+ distilled_params,
+ execution_options,
+ compiled_sql,
+ distilled_params,
+ elem,
+ extracted_params,
+ cache_hit=cache_hit,
+ )
+ if has_events:
+ self.dispatch.after_execute(
+ self,
+ elem,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
+ )
+ return ret
+
+ def _execute_compiled(
+ self,
+ compiled,
+ multiparams,
+ params,
+ execution_options=_EMPTY_EXECUTION_OPTS,
+ ):
+ """Execute a sql.Compiled object.
+
+ TODO: why do we have this? likely deprecate or remove
+
+ """
+
+ execution_options = compiled.execution_options.merge_with(
+ self._execution_options, execution_options
+ )
+ distilled_parameters = _distill_params(self, multiparams, params)
+
+ if self._has_events or self.engine._has_events:
+ (
+ compiled,
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ compiled, distilled_parameters, execution_options
+ )
+
+ dialect = self.dialect
+
+ ret = self._execute_context(
+ dialect,
+ dialect.execution_ctx_cls._init_compiled,
+ compiled,
+ distilled_parameters,
+ execution_options,
+ compiled,
+ distilled_parameters,
+ None,
+ None,
+ )
+ if self._has_events or self.engine._has_events:
+ self.dispatch.after_execute(
+ self,
+ compiled,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
+ )
+ return ret
+
+ def _exec_driver_sql(
+ self, statement, multiparams, params, execution_options, future
+ ):
+
+ execution_options = self._execution_options.merge_with(
+ execution_options
+ )
+
+ distilled_parameters = _distill_params(self, multiparams, params)
+
+ if not future:
+ if self._has_events or self.engine._has_events:
+ (
+ statement,
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ statement, distilled_parameters, execution_options
+ )
+
+ dialect = self.dialect
+ ret = self._execute_context(
+ dialect,
+ dialect.execution_ctx_cls._init_statement,
+ statement,
+ distilled_parameters,
+ execution_options,
+ statement,
+ distilled_parameters,
+ )
+
+ if not future:
+ if self._has_events or self.engine._has_events:
+ self.dispatch.after_execute(
+ self,
+ statement,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
+ )
+ return ret
+
+ def _execute_20(
+ self,
+ statement,
+ parameters=None,
+ execution_options=_EMPTY_EXECUTION_OPTS,
+ ):
+ args_10style, kwargs_10style = _distill_params_20(parameters)
+ try:
+ meth = statement._execute_on_connection
+ except AttributeError as err:
+ util.raise_(
+ exc.ObjectNotExecutableError(statement), replace_context=err
+ )
+ else:
+ return meth(self, args_10style, kwargs_10style, execution_options)
+
+ def exec_driver_sql(
+ self, statement, parameters=None, execution_options=None
+ ):
+ r"""Executes a SQL statement construct and returns a
+ :class:`_engine.CursorResult`.
+
+ :param statement: The statement str to be executed. Bound parameters
+ must use the underlying DBAPI's paramstyle, such as "qmark",
+ "pyformat", "format", etc.
+
+ :param parameters: represent bound parameter values to be used in the
+ execution. The format is one of: a dictionary of named parameters,
+ a tuple of positional parameters, or a list containing either
+ dictionaries or tuples for multiple-execute support.
+
+ E.g. multiple dictionaries::
+
+
+ conn.exec_driver_sql(
+ "INSERT INTO table (id, value) VALUES (%(id)s, %(value)s)",
+ [{"id":1, "value":"v1"}, {"id":2, "value":"v2"}]
+ )
+
+ Single dictionary::
+
+ conn.exec_driver_sql(
+ "INSERT INTO table (id, value) VALUES (%(id)s, %(value)s)",
+ dict(id=1, value="v1")
+ )
+
+ Single tuple::
+
+ conn.exec_driver_sql(
+ "INSERT INTO table (id, value) VALUES (?, ?)",
+ (1, 'v1')
+ )
+
+ .. note:: The :meth:`_engine.Connection.exec_driver_sql` method does
+ not participate in the
+ :meth:`_events.ConnectionEvents.before_execute` and
+ :meth:`_events.ConnectionEvents.after_execute` events. To
+ intercept calls to :meth:`_engine.Connection.exec_driver_sql`, use
+ :meth:`_events.ConnectionEvents.before_cursor_execute` and
+ :meth:`_events.ConnectionEvents.after_cursor_execute`.
+
+ .. seealso::
+
+ :pep:`249`
+
+ """
+
+ args_10style, kwargs_10style = _distill_params_20(parameters)
+
+ return self._exec_driver_sql(
+ statement,
+ args_10style,
+ kwargs_10style,
+ execution_options,
+ future=True,
+ )
+
+ def _execute_context(
+ self,
+ dialect,
+ constructor,
+ statement,
+ parameters,
+ execution_options,
+ *args,
+ **kw
+ ):
+ """Create an :class:`.ExecutionContext` and execute, returning
+ a :class:`_engine.CursorResult`."""
+
+ branched = self
+ if self.__branch_from:
+ # if this is a "branched" connection, do everything in terms
+ # of the "root" connection, *except* for .close(), which is
+ # the only feature that branching provides
+ self = self.__branch_from
+
+ if execution_options:
+ yp = execution_options.get("yield_per", None)
+ if yp:
+ execution_options = execution_options.union(
+ {"stream_results": True, "max_row_buffer": yp}
+ )
+
+ try:
+ conn = self._dbapi_connection
+ if conn is None:
+ conn = self._revalidate_connection()
+
+ context = constructor(
+ dialect, self, conn, execution_options, *args, **kw
+ )
+ except (exc.PendingRollbackError, exc.ResourceClosedError):
+ raise
+ except BaseException as e:
+ self._handle_dbapi_exception(
+ e, util.text_type(statement), parameters, None, None
+ )
+
+ if (
+ self._transaction
+ and not self._transaction.is_active
+ or (
+ self._nested_transaction
+ and not self._nested_transaction.is_active
+ )
+ ):
+ self._invalid_transaction()
+
+ elif self._trans_context_manager:
+ TransactionalContext._trans_ctx_check(self)
+
+ if self._is_future and self._transaction is None:
+ self._autobegin()
+
+ context.pre_exec()
+
+ if dialect.use_setinputsizes:
+ context._set_input_sizes()
+
+ cursor, statement, parameters = (
+ context.cursor,
+ context.statement,
+ context.parameters,
+ )
+
+ if not context.executemany:
+ parameters = parameters[0]
+
+ if self._has_events or self.engine._has_events:
+ for fn in self.dispatch.before_cursor_execute:
+ statement, parameters = fn(
+ self,
+ cursor,
+ statement,
+ parameters,
+ context,
+ context.executemany,
+ )
+
+ if self._echo:
+
+ self._log_info(statement)
+
+ stats = context._get_cache_stats()
+
+ if not self.engine.hide_parameters:
+ self._log_info(
+ "[%s] %r",
+ stats,
+ sql_util._repr_params(
+ parameters, batches=10, ismulti=context.executemany
+ ),
+ )
+ else:
+ self._log_info(
+ "[%s] [SQL parameters hidden due to hide_parameters=True]"
+ % (stats,)
+ )
+
+ evt_handled = False
+ try:
+ if context.executemany:
+ if self.dialect._has_events:
+ for fn in self.dialect.dispatch.do_executemany:
+ if fn(cursor, statement, parameters, context):
+ evt_handled = True
+ break
+ if not evt_handled:
+ self.dialect.do_executemany(
+ cursor, statement, parameters, context
+ )
+ elif not parameters and context.no_parameters:
+ if self.dialect._has_events:
+ for fn in self.dialect.dispatch.do_execute_no_params:
+ if fn(cursor, statement, context):
+ evt_handled = True
+ break
+ if not evt_handled:
+ self.dialect.do_execute_no_params(
+ cursor, statement, context
+ )
+ else:
+ if self.dialect._has_events:
+ for fn in self.dialect.dispatch.do_execute:
+ if fn(cursor, statement, parameters, context):
+ evt_handled = True
+ break
+ if not evt_handled:
+ self.dialect.do_execute(
+ cursor, statement, parameters, context
+ )
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.after_cursor_execute(
+ self,
+ cursor,
+ statement,
+ parameters,
+ context,
+ context.executemany,
+ )
+
+ context.post_exec()
+
+ result = context._setup_result_proxy()
+
+ if not self._is_future:
+ should_close_with_result = branched.should_close_with_result
+
+ if not result._soft_closed and should_close_with_result:
+ result._autoclose_connection = True
+
+ if (
+ # usually we're in a transaction so avoid relatively
+ # expensive / legacy should_autocommit call
+ self._transaction is None
+ and context.should_autocommit
+ ):
+ self._commit_impl(autocommit=True)
+
+ # for "connectionless" execution, we have to close this
+ # Connection after the statement is complete.
+ # legacy stuff.
+ if should_close_with_result and context._soft_closed:
+ assert not self._is_future
+
+ # CursorResult already exhausted rows / has no rows.
+ # close us now
+ branched.close()
+
+ except BaseException as e:
+ self._handle_dbapi_exception(
+ e, statement, parameters, cursor, context
+ )
+
+ return result
+
+ def _cursor_execute(self, cursor, statement, parameters, context=None):
+ """Execute a statement + params on the given cursor.
+
+ Adds appropriate logging and exception handling.
+
+ This method is used by DefaultDialect for special-case
+ executions, such as for sequences and column defaults.
+ The path of statement execution in the majority of cases
+ terminates at _execute_context().
+
+ """
+ if self._has_events or self.engine._has_events:
+ for fn in self.dispatch.before_cursor_execute:
+ statement, parameters = fn(
+ self, cursor, statement, parameters, context, False
+ )
+
+ if self._echo:
+ self._log_info(statement)
+ self._log_info("[raw sql] %r", parameters)
+ try:
+ for fn in (
+ ()
+ if not self.dialect._has_events
+ else self.dialect.dispatch.do_execute
+ ):
+ if fn(cursor, statement, parameters, context):
+ break
+ else:
+ self.dialect.do_execute(cursor, statement, parameters, context)
+ except BaseException as e:
+ self._handle_dbapi_exception(
+ e, statement, parameters, cursor, context
+ )
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.after_cursor_execute(
+ self, cursor, statement, parameters, context, False
+ )
+
+ def _safe_close_cursor(self, cursor):
+ """Close the given cursor, catching exceptions
+ and turning into log warnings.
+
+ """
+ try:
+ cursor.close()
+ except Exception:
+ # log the error through the connection pool's logger.
+ self.engine.pool.logger.error(
+ "Error closing cursor", exc_info=True
+ )
+
+ _reentrant_error = False
+ _is_disconnect = False
+
+ def _handle_dbapi_exception(
+ self, e, statement, parameters, cursor, context
+ ):
+ exc_info = sys.exc_info()
+
+ is_exit_exception = util.is_exit_exception(e)
+
+ if not self._is_disconnect:
+ self._is_disconnect = (
+ isinstance(e, self.dialect.dbapi.Error)
+ and not self.closed
+ and self.dialect.is_disconnect(
+ e,
+ self._dbapi_connection if not self.invalidated else None,
+ cursor,
+ )
+ ) or (is_exit_exception and not self.closed)
+
+ invalidate_pool_on_disconnect = not is_exit_exception
+
+ if self._reentrant_error:
+ util.raise_(
+ exc.DBAPIError.instance(
+ statement,
+ parameters,
+ e,
+ self.dialect.dbapi.Error,
+ hide_parameters=self.engine.hide_parameters,
+ dialect=self.dialect,
+ ismulti=context.executemany
+ if context is not None
+ else None,
+ ),
+ with_traceback=exc_info[2],
+ from_=e,
+ )
+ self._reentrant_error = True
+ try:
+ # non-DBAPI error - if we already got a context,
+ # or there's no string statement, don't wrap it
+ should_wrap = isinstance(e, self.dialect.dbapi.Error) or (
+ statement is not None
+ and context is None
+ and not is_exit_exception
+ )
+
+ if should_wrap:
+ sqlalchemy_exception = exc.DBAPIError.instance(
+ statement,
+ parameters,
+ e,
+ self.dialect.dbapi.Error,
+ hide_parameters=self.engine.hide_parameters,
+ connection_invalidated=self._is_disconnect,
+ dialect=self.dialect,
+ ismulti=context.executemany
+ if context is not None
+ else None,
+ )
+ else:
+ sqlalchemy_exception = None
+
+ newraise = None
+
+ if (
+ self._has_events or self.engine._has_events
+ ) and not self._execution_options.get(
+ "skip_user_error_events", False
+ ):
+ ctx = ExceptionContextImpl(
+ e,
+ sqlalchemy_exception,
+ self.engine,
+ self,
+ cursor,
+ statement,
+ parameters,
+ context,
+ self._is_disconnect,
+ invalidate_pool_on_disconnect,
+ )
+
+ for fn in self.dispatch.handle_error:
+ try:
+ # handler returns an exception;
+ # call next handler in a chain
+ per_fn = fn(ctx)
+ if per_fn is not None:
+ ctx.chained_exception = newraise = per_fn
+ except Exception as _raised:
+ # handler raises an exception - stop processing
+ newraise = _raised
+ break
+
+ if self._is_disconnect != ctx.is_disconnect:
+ self._is_disconnect = ctx.is_disconnect
+ if sqlalchemy_exception:
+ sqlalchemy_exception.connection_invalidated = (
+ ctx.is_disconnect
+ )
+
+ # set up potentially user-defined value for
+ # invalidate pool.
+ invalidate_pool_on_disconnect = (
+ ctx.invalidate_pool_on_disconnect
+ )
+
+ if should_wrap and context:
+ context.handle_dbapi_exception(e)
+
+ if not self._is_disconnect:
+ if cursor:
+ self._safe_close_cursor(cursor)
+ with util.safe_reraise(warn_only=True):
+ self._autorollback()
+
+ if newraise:
+ util.raise_(newraise, with_traceback=exc_info[2], from_=e)
+ elif should_wrap:
+ util.raise_(
+ sqlalchemy_exception, with_traceback=exc_info[2], from_=e
+ )
+ else:
+ util.raise_(exc_info[1], with_traceback=exc_info[2])
+
+ finally:
+ del self._reentrant_error
+ if self._is_disconnect:
+ del self._is_disconnect
+ if not self.invalidated:
+ dbapi_conn_wrapper = self._dbapi_connection
+ if invalidate_pool_on_disconnect:
+ self.engine.pool._invalidate(dbapi_conn_wrapper, e)
+ self.invalidate(e)
+ if self.should_close_with_result:
+ assert not self._is_future
+ self.close()
+
+ @classmethod
+ def _handle_dbapi_exception_noconnection(cls, e, dialect, engine):
+ exc_info = sys.exc_info()
+
+ is_disconnect = dialect.is_disconnect(e, None, None)
+
+ should_wrap = isinstance(e, dialect.dbapi.Error)
+
+ if should_wrap:
+ sqlalchemy_exception = exc.DBAPIError.instance(
+ None,
+ None,
+ e,
+ dialect.dbapi.Error,
+ hide_parameters=engine.hide_parameters,
+ connection_invalidated=is_disconnect,
+ )
+ else:
+ sqlalchemy_exception = None
+
+ newraise = None
+
+ if engine._has_events:
+ ctx = ExceptionContextImpl(
+ e,
+ sqlalchemy_exception,
+ engine,
+ None,
+ None,
+ None,
+ None,
+ None,
+ is_disconnect,
+ True,
+ )
+ for fn in engine.dispatch.handle_error:
+ try:
+ # handler returns an exception;
+ # call next handler in a chain
+ per_fn = fn(ctx)
+ if per_fn is not None:
+ ctx.chained_exception = newraise = per_fn
+ except Exception as _raised:
+ # handler raises an exception - stop processing
+ newraise = _raised
+ break
+
+ if sqlalchemy_exception and is_disconnect != ctx.is_disconnect:
+ sqlalchemy_exception.connection_invalidated = (
+ is_disconnect
+ ) = ctx.is_disconnect
+
+ if newraise:
+ util.raise_(newraise, with_traceback=exc_info[2], from_=e)
+ elif should_wrap:
+ util.raise_(
+ sqlalchemy_exception, with_traceback=exc_info[2], from_=e
+ )
+ else:
+ util.raise_(exc_info[1], with_traceback=exc_info[2])
+
+ def _run_ddl_visitor(self, visitorcallable, element, **kwargs):
+ """run a DDL visitor.
+
+ This method is only here so that the MockConnection can change the
+ options given to the visitor so that "checkfirst" is skipped.
+
+ """
+ visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_engine.Connection.transaction` "
+ "method is deprecated and will be "
+ "removed in a future release. Use the :meth:`_engine.Engine.begin` "
+ "context manager instead.",
+ )
+ def transaction(self, callable_, *args, **kwargs):
+ r"""Execute the given function within a transaction boundary.
+
+ The function is passed this :class:`_engine.Connection`
+ as the first argument, followed by the given \*args and \**kwargs,
+ e.g.::
+
+ def do_something(conn, x, y):
+ conn.execute(text("some statement"), {'x':x, 'y':y})
+
+ conn.transaction(do_something, 5, 10)
+
+ The operations inside the function are all invoked within the
+ context of a single :class:`.Transaction`.
+ Upon success, the transaction is committed. If an
+ exception is raised, the transaction is rolled back
+ before propagating the exception.
+
+ .. note::
+
+ The :meth:`.transaction` method is superseded by
+ the usage of the Python ``with:`` statement, which can
+ be used with :meth:`_engine.Connection.begin`::
+
+ with conn.begin():
+ conn.execute(text("some statement"), {'x':5, 'y':10})
+
+ As well as with :meth:`_engine.Engine.begin`::
+
+ with engine.begin() as conn:
+ conn.execute(text("some statement"), {'x':5, 'y':10})
+
+ .. seealso::
+
+ :meth:`_engine.Engine.begin` - engine-level transactional
+ context
+
+ :meth:`_engine.Engine.transaction` - engine-level version of
+ :meth:`_engine.Connection.transaction`
+
+ """
+
+ kwargs["_sa_skip_warning"] = True
+ trans = self.begin()
+ try:
+ ret = self.run_callable(callable_, *args, **kwargs)
+ trans.commit()
+ return ret
+ except:
+ with util.safe_reraise():
+ trans.rollback()
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_engine.Connection.run_callable` "
+ "method is deprecated and will "
+ "be removed in a future release. Invoke the callable function "
+ "directly, passing the Connection.",
+ )
+ def run_callable(self, callable_, *args, **kwargs):
+ r"""Given a callable object or function, execute it, passing
+ a :class:`_engine.Connection` as the first argument.
+
+ The given \*args and \**kwargs are passed subsequent
+ to the :class:`_engine.Connection` argument.
+
+ This function, along with :meth:`_engine.Engine.run_callable`,
+ allows a function to be run with a :class:`_engine.Connection`
+ or :class:`_engine.Engine` object without the need to know
+ which one is being dealt with.
+
+ """
+ return callable_(self, *args, **kwargs)
+
+
+class ExceptionContextImpl(ExceptionContext):
+ """Implement the :class:`.ExceptionContext` interface."""
+
+ def __init__(
+ self,
+ exception,
+ sqlalchemy_exception,
+ engine,
+ connection,
+ cursor,
+ statement,
+ parameters,
+ context,
+ is_disconnect,
+ invalidate_pool_on_disconnect,
+ ):
+ self.engine = engine
+ self.connection = connection
+ self.sqlalchemy_exception = sqlalchemy_exception
+ self.original_exception = exception
+ self.execution_context = context
+ self.statement = statement
+ self.parameters = parameters
+ self.is_disconnect = is_disconnect
+ self.invalidate_pool_on_disconnect = invalidate_pool_on_disconnect
+
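+# Illustrative sketch only (comments, not executed): the handle_error loop in
+# Connection._handle_dbapi_exception above hands an ExceptionContextImpl to
+# each listener; a listener may return a replacement exception, which is
+# chained for later listeners and ultimately raised.  Assuming an Engine
+# named ``some_engine`` created elsewhere:
+#
+#     from sqlalchemy import event, exc
+#
+#     @event.listens_for(some_engine, "handle_error")
+#     def wrap_timeouts(ctx):
+#         # ctx.original_exception / ctx.statement / ctx.parameters come
+#         # from the ExceptionContextImpl constructor above
+#         if isinstance(ctx.original_exception, TimeoutError):
+#             return exc.OperationalError(
+#                 ctx.statement, ctx.parameters, ctx.original_exception
+#             )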
+
+class Transaction(TransactionalContext):
+ """Represent a database transaction in progress.
+
+ The :class:`.Transaction` object is procured by
+ calling the :meth:`_engine.Connection.begin` method of
+ :class:`_engine.Connection`::
+
+ from sqlalchemy import create_engine
+ engine = create_engine("postgresql://scott:tiger@localhost/test")
+ connection = engine.connect()
+ trans = connection.begin()
+ connection.execute(text("insert into x (a, b) values (1, 2)"))
+ trans.commit()
+
+ The object provides :meth:`.rollback` and :meth:`.commit`
+ methods in order to control transaction boundaries. It
+ also implements a context manager interface so that
+ the Python ``with`` statement can be used with the
+ :meth:`_engine.Connection.begin` method::
+
+ with connection.begin():
+ connection.execute(text("insert into x (a, b) values (1, 2)"))
+
+ The Transaction object is **not** threadsafe.
+
+ .. seealso::
+
+ :meth:`_engine.Connection.begin`
+
+ :meth:`_engine.Connection.begin_twophase`
+
+ :meth:`_engine.Connection.begin_nested`
+
+ .. index::
+ single: thread safety; Transaction
+ """
+
+ __slots__ = ()
+
+ _is_root = False
+
+ def __init__(self, connection):
+ raise NotImplementedError()
+
+ def _do_deactivate(self):
+ """Do whatever steps are necessary to mark this transaction as
+ "deactivated", while leaving the transaction object in place as far
+ as the connection's state is concerned.
+
+ For a "real" transaction this should roll back the transaction
+ and ensure this transaction is no longer a reset agent.
+
+ This is used for nesting of marker transactions, where the marker
+ can set the "real" transaction as rolled back while the transaction
+ object itself stays in place.
+
+ For 2.0 the hope is to remove this nesting feature.
+
+ """
+ raise NotImplementedError()
+
+ @property
+ def _deactivated_from_connection(self):
+ """True if this transaction is totally deactivated from the connection
+ and therefore can no longer affect its state.
+
+ """
+ raise NotImplementedError()
+
+ def _do_close(self):
+ raise NotImplementedError()
+
+ def _do_rollback(self):
+ raise NotImplementedError()
+
+ def _do_commit(self):
+ raise NotImplementedError()
+
+ @property
+ def is_valid(self):
+ return self.is_active and not self.connection.invalidated
+
+ def close(self):
+ """Close this :class:`.Transaction`.
+
+ If this transaction is the base transaction in a begin/commit
+ nesting, the transaction will be rolled back. Otherwise, the
+ method simply returns.
+
+ This is used to cancel a Transaction without affecting the scope of
+ an enclosing transaction.
+
+ """
+ try:
+ self._do_close()
+ finally:
+ assert not self.is_active
+
+ def rollback(self):
+ """Roll back this :class:`.Transaction`.
+
+ The implementation of this may vary based on the type of transaction in
+ use:
+
+ * For a simple database transaction (e.g. :class:`.RootTransaction`),
+ it corresponds to a ROLLBACK.
+
+ * For a :class:`.NestedTransaction`, it corresponds to a
+ "ROLLBACK TO SAVEPOINT" operation.
+
+ * For a :class:`.TwoPhaseTransaction`, DBAPI-specific methods for two
+ phase transactions may be used.
+
+
+ """
+ try:
+ self._do_rollback()
+ finally:
+ assert not self.is_active
+
+ def commit(self):
+ """Commit this :class:`.Transaction`.
+
+ The implementation of this may vary based on the type of transaction in
+ use:
+
+ * For a simple database transaction (e.g. :class:`.RootTransaction`),
+ it corresponds to a COMMIT.
+
+ * For a :class:`.NestedTransaction`, it corresponds to a
+ "RELEASE SAVEPOINT" operation.
+
+ * For a :class:`.TwoPhaseTransaction`, DBAPI-specific methods for two
+ phase transactions may be used.
+
+ """
+ try:
+ self._do_commit()
+ finally:
+ assert not self.is_active
+
+ def _get_subject(self):
+ return self.connection
+
+ def _transaction_is_active(self):
+ return self.is_active
+
+ def _transaction_is_closed(self):
+ return not self._deactivated_from_connection
+
+ def _rollback_can_be_called(self):
+ # for RootTransaction / NestedTransaction, it's safe to call
+ # rollback() even if the transaction is deactive and no warnings
+ # will be emitted. tested in
+ # test_transaction.py -> test_no_rollback_in_deactive(?:_savepoint)?
+ return True
+
+
+class MarkerTransaction(Transaction):
+ """A 'marker' transaction that is used for nested begin() calls.
+
+ .. deprecated:: 1.4 The 2.0 style "future" connection will not support this pattern.
+
+ """
+
+ __slots__ = ("connection", "_is_active", "_transaction")
+
+ def __init__(self, connection):
+ assert connection._transaction is not None
+ if not connection._transaction.is_active:
+ raise exc.InvalidRequestError(
+ "the current transaction on this connection is inactive. "
+ "Please issue a rollback first."
+ )
+
+ assert not connection._is_future
+ util.warn_deprecated_20(
+ "Calling .begin() when a transaction is already begun, creating "
+ "a 'sub' transaction, is deprecated "
+ "and will be removed in 2.0. See the documentation section "
+ "'Migrating from the nesting pattern' for background on how "
+ "to migrate from this pattern."
+ )
+
+ self.connection = connection
+
+ if connection._trans_context_manager:
+ TransactionalContext._trans_ctx_check(connection)
+
+ if connection._nested_transaction is not None:
+ self._transaction = connection._nested_transaction
+ else:
+ self._transaction = connection._transaction
+ self._is_active = True
+
+ @property
+ def _deactivated_from_connection(self):
+ return not self.is_active
+
+ @property
+ def is_active(self):
+ return self._is_active and self._transaction.is_active
+
+ def _deactivate(self):
+ self._is_active = False
+
+ def _do_close(self):
+ # does not actually roll back the root
+ self._deactivate()
+
+ def _do_rollback(self):
+ # does roll back the root
+ if self._is_active:
+ try:
+ self._transaction._do_deactivate()
+ finally:
+ self._deactivate()
+
+ def _do_commit(self):
+ self._deactivate()
+
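+# Illustrative sketch only (comments, not executed): the deprecated pattern
+# that produces a MarkerTransaction is a second begin() issued while a
+# transaction is already present on a non-future Connection.  Assuming
+# ``conn`` is such a legacy Connection:
+#
+#     outer = conn.begin()   # RootTransaction
+#     inner = conn.begin()   # MarkerTransaction; emits a deprecation warning
+#     inner.commit()         # no COMMIT emitted; the marker is deactivated
+#     outer.commit()         # the actual COMMIT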
+
+class RootTransaction(Transaction):
+ """Represent the "root" transaction on a :class:`_engine.Connection`.
+
+ This corresponds to the current "BEGIN/COMMIT/ROLLBACK" that's occurring
+ for the :class:`_engine.Connection`. The :class:`_engine.RootTransaction`
+ is created by calling upon the :meth:`_engine.Connection.begin` method, and
+ remains associated with the :class:`_engine.Connection` throughout its
+ active span. The current :class:`_engine.RootTransaction` in use is
+ accessible via the :meth:`_engine.Connection.get_transaction` method of
+ :class:`_engine.Connection`.
+
+ In :term:`2.0 style` use, the :class:`_future.Connection` also employs
+ "autobegin" behavior that will create a new
+ :class:`_engine.RootTransaction` whenever a connection in a
+ non-transactional state is used to emit commands on the DBAPI connection.
+ The scope of the :class:`_engine.RootTransaction` in 2.0 style
+ use can be controlled using the :meth:`_future.Connection.commit` and
+ :meth:`_future.Connection.rollback` methods.
+
+
+ """
+
+ _is_root = True
+
+ __slots__ = ("connection", "is_active")
+
+ def __init__(self, connection):
+ assert connection._transaction is None
+ if connection._trans_context_manager:
+ TransactionalContext._trans_ctx_check(connection)
+ self.connection = connection
+ self._connection_begin_impl()
+ connection._transaction = self
+
+ self.is_active = True
+
+ def _deactivate_from_connection(self):
+ if self.is_active:
+ assert self.connection._transaction is self
+ self.is_active = False
+
+ elif self.connection._transaction is not self:
+ util.warn("transaction already deassociated from connection")
+
+ @property
+ def _deactivated_from_connection(self):
+ return self.connection._transaction is not self
+
+ def _do_deactivate(self):
+ # called from a MarkerTransaction to cancel this root transaction.
+ # the transaction stays in place as connection._transaction, but
+ # is no longer active and is no longer the reset agent for the
+ # pooled connection. the connection won't support a new begin()
+ # until this transaction is explicitly closed, rolled back,
+ # or committed.
+
+ assert self.connection._transaction is self
+
+ if self.is_active:
+ self._connection_rollback_impl()
+
+ # handle case where a savepoint was created inside of a marker
+ # transaction that refers to a root. nested has to be cancelled
+ # also.
+ if self.connection._nested_transaction:
+ self.connection._nested_transaction._cancel()
+
+ self._deactivate_from_connection()
+
+ def _connection_begin_impl(self):
+ self.connection._begin_impl(self)
+
+ def _connection_rollback_impl(self):
+ self.connection._rollback_impl()
+
+ def _connection_commit_impl(self):
+ self.connection._commit_impl()
+
+ def _close_impl(self, try_deactivate=False):
+ try:
+ if self.is_active:
+ self._connection_rollback_impl()
+
+ if self.connection._nested_transaction:
+ self.connection._nested_transaction._cancel()
+ finally:
+ if self.is_active or try_deactivate:
+ self._deactivate_from_connection()
+ if self.connection._transaction is self:
+ self.connection._transaction = None
+
+ assert not self.is_active
+ assert self.connection._transaction is not self
+
+ def _do_close(self):
+ self._close_impl()
+
+ def _do_rollback(self):
+ self._close_impl(try_deactivate=True)
+
+ def _do_commit(self):
+ if self.is_active:
+ assert self.connection._transaction is self
+
+ try:
+ self._connection_commit_impl()
+ finally:
+ # whether or not commit succeeds, cancel any
+ # nested transactions, make this transaction "inactive"
+ # and remove it as a reset agent
+ if self.connection._nested_transaction:
+ self.connection._nested_transaction._cancel()
+
+ self._deactivate_from_connection()
+
+ # ...however only remove as the connection's current transaction
+ # if commit succeeded. otherwise it stays on so that a rollback
+ # needs to occur.
+ self.connection._transaction = None
+ else:
+ if self.connection._transaction is self:
+ self.connection._invalid_transaction()
+ else:
+ raise exc.InvalidRequestError("This transaction is inactive")
+
+ assert not self.is_active
+ assert self.connection._transaction is not self
+
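+# Illustrative sketch only (comments, not executed): with the 2.0 style
+# "future" Connection mentioned in the RootTransaction docstring, the root
+# transaction is created implicitly ("autobegin") on first execution and is
+# ended via Connection.commit() or Connection.rollback():
+#
+#     from sqlalchemy import create_engine, text
+#
+#     future_engine = create_engine("sqlite://", future=True)
+#     with future_engine.connect() as conn:
+#         conn.execute(text("create table t (x integer)"))  # autobegins
+#         conn.commit()  # ends the current RootTransaction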
+
+class NestedTransaction(Transaction):
+ """Represent a 'nested', or SAVEPOINT transaction.
+
+ The :class:`.NestedTransaction` object is created by calling the
+ :meth:`_engine.Connection.begin_nested` method of
+ :class:`_engine.Connection`.
+
+ When using :class:`.NestedTransaction`, the semantics of "begin" /
+ "commit" / "rollback" are as follows:
+
+ * the "begin" operation corresponds to the "BEGIN SAVEPOINT" command, where
+ the savepoint is given an explicit name that is part of the state
+ of this object.
+
+ * The :meth:`.NestedTransaction.commit` method corresponds to a
+ "RELEASE SAVEPOINT" operation, using the savepoint identifier associated
+ with this :class:`.NestedTransaction`.
+
+ * The :meth:`.NestedTransaction.rollback` method corresponds to a
+ "ROLLBACK TO SAVEPOINT" operation, using the savepoint identifier
+ associated with this :class:`.NestedTransaction`.
+
+ The rationale for mimicking the semantics of an outer transaction in
+ terms of savepoints is so that code may deal with a "savepoint" transaction
+ and an "outer" transaction in an agnostic way.
+
+ .. seealso::
+
+ :ref:`session_begin_nested` - ORM version of the SAVEPOINT API.
+
+ """
+
+ __slots__ = ("connection", "is_active", "_savepoint", "_previous_nested")
+
+ def __init__(self, connection):
+ assert connection._transaction is not None
+ if connection._trans_context_manager:
+ TransactionalContext._trans_ctx_check(connection)
+ self.connection = connection
+ self._savepoint = self.connection._savepoint_impl()
+ self.is_active = True
+ self._previous_nested = connection._nested_transaction
+ connection._nested_transaction = self
+
+ def _deactivate_from_connection(self, warn=True):
+ if self.connection._nested_transaction is self:
+ self.connection._nested_transaction = self._previous_nested
+ elif warn:
+ util.warn(
+ "nested transaction already deassociated from connection"
+ )
+
+ @property
+ def _deactivated_from_connection(self):
+ return self.connection._nested_transaction is not self
+
+ def _cancel(self):
+ # called by RootTransaction when the outer transaction is
+ # committed, rolled back, or closed to cancel all savepoints
+ # without any action being taken
+ self.is_active = False
+ self._deactivate_from_connection()
+ if self._previous_nested:
+ self._previous_nested._cancel()
+
+ def _close_impl(self, deactivate_from_connection, warn_already_deactive):
+ try:
+ if self.is_active and self.connection._transaction.is_active:
+ self.connection._rollback_to_savepoint_impl(self._savepoint)
+ finally:
+ self.is_active = False
+
+ if deactivate_from_connection:
+ self._deactivate_from_connection(warn=warn_already_deactive)
+
+ assert not self.is_active
+ if deactivate_from_connection:
+ assert self.connection._nested_transaction is not self
+
+ def _do_deactivate(self):
+ self._close_impl(False, False)
+
+ def _do_close(self):
+ self._close_impl(True, False)
+
+ def _do_rollback(self):
+ self._close_impl(True, True)
+
+ def _do_commit(self):
+ if self.is_active:
+ try:
+ self.connection._release_savepoint_impl(self._savepoint)
+ finally:
+ # nested trans becomes inactive on failed release
+ # unconditionally. this prevents it from trying to
+ # emit SQL when it rolls back.
+ self.is_active = False
+
+ # but only de-associate from connection if it succeeded
+ self._deactivate_from_connection()
+ else:
+ if self.connection._nested_transaction is self:
+ self.connection._invalid_transaction()
+ else:
+ raise exc.InvalidRequestError(
+ "This nested transaction is inactive"
+ )
+
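+# Illustrative sketch only (comments, not executed): a NestedTransaction is
+# obtained from Connection.begin_nested() and maps commit() / rollback() onto
+# RELEASE SAVEPOINT / ROLLBACK TO SAVEPOINT, leaving the enclosing
+# transaction open.  Assuming ``conn`` is a Connection and ``some_table`` is
+# any Table object:
+#
+#     outer = conn.begin()
+#     savepoint = conn.begin_nested()      # SAVEPOINT <name>
+#     conn.execute(some_table.insert(), {"x": 1})
+#     savepoint.rollback()                 # ROLLBACK TO SAVEPOINT <name>
+#     outer.commit()                       # outer transaction still commits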
+
+class TwoPhaseTransaction(RootTransaction):
+ """Represent a two-phase transaction.
+
+ A new :class:`.TwoPhaseTransaction` object may be procured
+ using the :meth:`_engine.Connection.begin_twophase` method.
+
+ The interface is the same as that of :class:`.Transaction`
+ with the addition of the :meth:`prepare` method.
+
+ """
+
+ __slots__ = ("connection", "is_active", "xid", "_is_prepared")
+
+ def __init__(self, connection, xid):
+ self._is_prepared = False
+ self.xid = xid
+ super(TwoPhaseTransaction, self).__init__(connection)
+
+ def prepare(self):
+ """Prepare this :class:`.TwoPhaseTransaction`.
+
+ After a PREPARE, the transaction can be committed.
+
+ """
+ if not self.is_active:
+ raise exc.InvalidRequestError("This transaction is inactive")
+ self.connection._prepare_twophase_impl(self.xid)
+ self._is_prepared = True
+
+ def _connection_begin_impl(self):
+ self.connection._begin_twophase_impl(self)
+
+ def _connection_rollback_impl(self):
+ self.connection._rollback_twophase_impl(self.xid, self._is_prepared)
+
+ def _connection_commit_impl(self):
+ self.connection._commit_twophase_impl(self.xid, self._is_prepared)
+
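+# Illustrative sketch only (comments, not executed): a TwoPhaseTransaction
+# adds prepare() between the work and the final commit, for backends and
+# DBAPIs that support two-phase commit.  Assuming ``conn`` is a Connection on
+# such a backend and ``text`` is imported from sqlalchemy:
+#
+#     xa = conn.begin_twophase()   # an xid is generated if none is passed
+#     conn.execute(text("insert into t (x) values (1)"))
+#     xa.prepare()                 # first phase
+#     xa.commit()                  # second phase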
+
+class Engine(Connectable, log.Identified):
+ """
+ Connects a :class:`~sqlalchemy.pool.Pool` and
+ :class:`~sqlalchemy.engine.interfaces.Dialect` together to provide a
+ source of database connectivity and behavior.
+
+ This is the **SQLAlchemy 1.x version** of :class:`_engine.Engine`. For
+ the :term:`2.0 style` version, which includes some API differences,
+ see :class:`_future.Engine`.
+
+ An :class:`_engine.Engine` object is instantiated publicly using the
+ :func:`~sqlalchemy.create_engine` function.
+
+ .. seealso::
+
+ :doc:`/core/engines`
+
+ :ref:`connections_toplevel`
+
+ """
+
+ _execution_options = _EMPTY_EXECUTION_OPTS
+ _has_events = False
+ _connection_cls = Connection
+ _sqla_logger_namespace = "sqlalchemy.engine.Engine"
+ _is_future = False
+
+ _schema_translate_map = None
+
+ def __init__(
+ self,
+ pool,
+ dialect,
+ url,
+ logging_name=None,
+ echo=None,
+ query_cache_size=500,
+ execution_options=None,
+ hide_parameters=False,
+ ):
+ self.pool = pool
+ self.url = url
+ self.dialect = dialect
+ if logging_name:
+ self.logging_name = logging_name
+ self.echo = echo
+ self.hide_parameters = hide_parameters
+ if query_cache_size != 0:
+ self._compiled_cache = util.LRUCache(
+ query_cache_size, size_alert=self._lru_size_alert
+ )
+ else:
+ self._compiled_cache = None
+ log.instance_logger(self, echoflag=echo)
+ if execution_options:
+ self.update_execution_options(**execution_options)
+
+ def _lru_size_alert(self, cache):
+ if self._should_log_info:
+ self.logger.info(
+ "Compiled cache size pruning from %d items to %d. "
+ "Increase cache size to reduce the frequency of pruning.",
+ len(cache),
+ cache.capacity,
+ )
+
+ @property
+ def engine(self):
+ return self
+
+ def clear_compiled_cache(self):
+ """Clear the compiled cache associated with the dialect.
+
+ This applies **only** to the built-in cache that is established
+ via the :paramref:`_engine.create_engine.query_cache_size` parameter.
+ It will not impact any dictionary caches that were passed via the
+ :paramref:`.Connection.execution_options.query_cache` parameter.
+
+ .. versionadded:: 1.4
+
+ """
+ if self._compiled_cache:
+ self._compiled_cache.clear()
+
+ def update_execution_options(self, **opt):
+ r"""Update the default execution_options dictionary
+ of this :class:`_engine.Engine`.
+
+ The given keys/values in \**opt are added to the
+ default execution options that will be used for
+ all connections. The initial contents of this dictionary
+ can be sent via the ``execution_options`` parameter
+ to :func:`_sa.create_engine`.
+
+ .. seealso::
+
+ :meth:`_engine.Connection.execution_options`
+
+ :meth:`_engine.Engine.execution_options`
+
+ """
+ self._execution_options = self._execution_options.union(opt)
+ self.dispatch.set_engine_execution_options(self, opt)
+ self.dialect.set_engine_execution_options(self, opt)
+
+ def execution_options(self, **opt):
+ """Return a new :class:`_engine.Engine` that will provide
+ :class:`_engine.Connection` objects with the given execution options.
+
+ The returned :class:`_engine.Engine` remains related to the original
+ :class:`_engine.Engine` in that it shares the same connection pool and
+ other state:
+
+ * The :class:`_pool.Pool` used by the new :class:`_engine.Engine`
+ is the
+ same instance. The :meth:`_engine.Engine.dispose`
+ method will replace
+ the connection pool instance for the parent engine as well
+ as this one.
+ * Event listeners are "cascaded" - meaning, the new
+ :class:`_engine.Engine`
+ inherits the events of the parent, and new events can be associated
+ with the new :class:`_engine.Engine` individually.
+ * The logging configuration and logging_name is copied from the parent
+ :class:`_engine.Engine`.
+
+ The intent of the :meth:`_engine.Engine.execution_options` method is
+ to implement "sharding" schemes where multiple :class:`_engine.Engine`
+ objects refer to the same connection pool, but are differentiated
+ by options that would be consumed by a custom event::
+
+ primary_engine = create_engine("mysql://")
+ shard1 = primary_engine.execution_options(shard_id="shard1")
+ shard2 = primary_engine.execution_options(shard_id="shard2")
+
+ Above, the ``shard1`` engine serves as a factory for
+ :class:`_engine.Connection`
+ objects that will contain the execution option
+ ``shard_id=shard1``, and ``shard2`` will produce
+ :class:`_engine.Connection`
+ objects that contain the execution option ``shard_id=shard2``.
+
+ An event handler can consume the above execution option to perform
+ a schema switch or other operation, given a connection. Below
+ we emit a MySQL ``use`` statement to switch databases, at the same
+ time keeping track of which database we've established using the
+ :attr:`_engine.Connection.info` dictionary,
+ which gives us a persistent
+ storage space that follows the DBAPI connection::
+
+ from sqlalchemy import event
+ from sqlalchemy.engine import Engine
+
+ shards = {"default": "base", "shard_1": "db1", "shard_2": "db2"}
+
+ @event.listens_for(Engine, "before_cursor_execute")
+ def _switch_shard(conn, cursor, stmt,
+ params, context, executemany):
+ shard_id = conn._execution_options.get('shard_id', "default")
+ current_shard = conn.info.get("current_shard", None)
+
+ if current_shard != shard_id:
+ cursor.execute("use %s" % shards[shard_id])
+ conn.info["current_shard"] = shard_id
+
+ .. seealso::
+
+ :meth:`_engine.Connection.execution_options`
+ - update execution options
+ on a :class:`_engine.Connection` object.
+
+ :meth:`_engine.Engine.update_execution_options`
+ - update the execution
+ options for a given :class:`_engine.Engine` in place.
+
+ :meth:`_engine.Engine.get_execution_options`
+
+
+ """
+ return self._option_cls(self, opt)
+
+ def get_execution_options(self):
+ """Get the non-SQL options which will take effect during execution.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :meth:`_engine.Engine.execution_options`
+ """
+ return self._execution_options
+
+ @property
+ def name(self):
+ """String name of the :class:`~sqlalchemy.engine.interfaces.Dialect`
+ in use by this :class:`Engine`."""
+
+ return self.dialect.name
+
+ @property
+ def driver(self):
+ """Driver name of the :class:`~sqlalchemy.engine.interfaces.Dialect`
+ in use by this :class:`Engine`."""
+
+ return self.dialect.driver
+
+ echo = log.echo_property()
+
+ def __repr__(self):
+ return "Engine(%r)" % (self.url,)
+
+ def dispose(self, close=True):
+ """Dispose of the connection pool used by this
+ :class:`_engine.Engine`.
+
+ A new connection pool is created immediately after the old one has been
+ disposed. The previous connection pool is disposed either actively, by
+ closing out all currently checked-in connections in that pool, or
+ passively, by losing references to it but otherwise not closing any
+ connections. The latter strategy is more appropriate for an initializer
+ in a forked Python process.
+
+ :param close: if left at its default of ``True``, has the
+ effect of fully closing all **currently checked in**
+ database connections. Connections that are still checked out
+ will **not** be closed, however they will no longer be associated
+ with this :class:`_engine.Engine`,
+ so when they are closed individually, eventually the
+ :class:`_pool.Pool` which they are associated with will
+ be garbage collected and they will be closed out fully, if
+ not already closed on checkin.
+
+ If set to ``False``, the previous connection pool is de-referenced,
+ and otherwise not touched in any way.
+
+ .. versionadded:: 1.4.33 Added the :paramref:`.Engine.dispose.close`
+ parameter to allow the replacement of a connection pool in a child
+ process without interfering with the connections used by the parent
+ process.
+
+
+ .. seealso::
+
+ :ref:`engine_disposal`
+
+ :ref:`pooling_multiprocessing`
+
+ """
+ if close:
+ self.pool.dispose()
+ self.pool = self.pool.recreate()
+ self.dispatch.engine_disposed(self)
+
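+ # Illustrative sketch only (comments, not executed): the close=False mode of
+ # dispose() above is aimed at forked child processes, which must not close
+ # connections still owned by the parent.  Assuming a module-level ``engine``:
+ #
+ #     import os
+ #
+ #     pid = os.fork()
+ #     if pid == 0:
+ #         # child: drop the inherited pool without closing the parent's
+ #         # checked-out connections
+ #         engine.dispose(close=False)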
+ def _execute_default(
+ self, default, multiparams=(), params=util.EMPTY_DICT
+ ):
+ with self.connect() as conn:
+ return conn._execute_default(default, multiparams, params)
+
+ @contextlib.contextmanager
+ def _optional_conn_ctx_manager(self, connection=None):
+ if connection is None:
+ with self.connect() as conn:
+ yield conn
+ else:
+ yield connection
+
+ class _trans_ctx(object):
+ def __init__(self, conn, transaction, close_with_result):
+ self.conn = conn
+ self.transaction = transaction
+ self.close_with_result = close_with_result
+
+ def __enter__(self):
+ self.transaction.__enter__()
+ return self.conn
+
+ def __exit__(self, type_, value, traceback):
+ try:
+ self.transaction.__exit__(type_, value, traceback)
+ finally:
+ if not self.close_with_result:
+ self.conn.close()
+
+ def begin(self, close_with_result=False):
+ """Return a context manager delivering a :class:`_engine.Connection`
+ with a :class:`.Transaction` established.
+
+ E.g.::
+
+ with engine.begin() as conn:
+ conn.execute(
+ text("insert into table (x, y, z) values (1, 2, 3)")
+ )
+ conn.execute(text("my_special_procedure(5)"))
+
+ Upon successful operation, the :class:`.Transaction`
+ is committed. If an error is raised, the :class:`.Transaction`
+ is rolled back.
+
+ Legacy use only: the ``close_with_result`` flag is normally ``False``,
+ and indicates that the :class:`_engine.Connection` will be closed when
+ the operation is complete. When set to ``True``, it indicates the
+ :class:`_engine.Connection` is in "single use" mode, where the
+ :class:`_engine.CursorResult` returned by the first call to
+ :meth:`_engine.Connection.execute` will close the
+ :class:`_engine.Connection` when that :class:`_engine.CursorResult` has
+ exhausted all result rows.
+
+ .. seealso::
+
+ :meth:`_engine.Engine.connect` - procure a
+ :class:`_engine.Connection` from
+ an :class:`_engine.Engine`.
+
+ :meth:`_engine.Connection.begin` - start a :class:`.Transaction`
+ for a particular :class:`_engine.Connection`.
+
+ """
+ if self._connection_cls._is_future:
+ conn = self.connect()
+ else:
+ conn = self.connect(close_with_result=close_with_result)
+ try:
+ trans = conn.begin()
+ except:
+ with util.safe_reraise():
+ conn.close()
+ return Engine._trans_ctx(conn, trans, close_with_result)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_engine.Engine.transaction` "
+ "method is deprecated and will be "
+ "removed in a future release. Use the :meth:`_engine.Engine.begin` "
+ "context "
+ "manager instead.",
+ )
+ def transaction(self, callable_, *args, **kwargs):
+ r"""Execute the given function within a transaction boundary.
+
+ The function is passed a :class:`_engine.Connection` newly procured
+ from :meth:`_engine.Engine.connect` as the first argument,
+ followed by the given \*args and \**kwargs.
+
+ e.g.::
+
+ def do_something(conn, x, y):
+ conn.execute(text("some statement"), {'x':x, 'y':y})
+
+ engine.transaction(do_something, 5, 10)
+
+ The operations inside the function are all invoked within the
+ context of a single :class:`.Transaction`.
+ Upon success, the transaction is committed. If an
+ exception is raised, the transaction is rolled back
+ before propagating the exception.
+
+ .. note::
+
+ The :meth:`.transaction` method is superseded by
+ the usage of the Python ``with:`` statement, which can
+ be used with :meth:`_engine.Engine.begin`::
+
+ with engine.begin() as conn:
+ conn.execute(text("some statement"), {'x':5, 'y':10})
+
+ .. seealso::
+
+ :meth:`_engine.Engine.begin` - engine-level transactional
+ context
+
+ :meth:`_engine.Connection.transaction`
+ - connection-level version of
+ :meth:`_engine.Engine.transaction`
+
+ """
+ kwargs["_sa_skip_warning"] = True
+ with self.connect() as conn:
+ return conn.transaction(callable_, *args, **kwargs)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_engine.Engine.run_callable` "
+ "method is deprecated and will be "
+ "removed in a future release. Use the :meth:`_engine.Engine.begin` "
+ "context manager instead.",
+ )
+ def run_callable(self, callable_, *args, **kwargs):
+ r"""Given a callable object or function, execute it, passing
+ a :class:`_engine.Connection` as the first argument.
+
+ The given \*args and \**kwargs are passed subsequent
+ to the :class:`_engine.Connection` argument.
+
+ This function, along with :meth:`_engine.Connection.run_callable`,
+ allows a function to be run with a :class:`_engine.Connection`
+ or :class:`_engine.Engine` object without the need to know
+ which one is being dealt with.
+
+ """
+ kwargs["_sa_skip_warning"] = True
+ with self.connect() as conn:
+ return conn.run_callable(callable_, *args, **kwargs)
+
+ def _run_ddl_visitor(self, visitorcallable, element, **kwargs):
+ with self.begin() as conn:
+ conn._run_ddl_visitor(visitorcallable, element, **kwargs)
+
+ @util.deprecated_20(
+ ":meth:`_engine.Engine.execute`",
+ alternative="All statement execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.execute` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.execute` method of "
+ ":class:`.Session`.",
+ )
+ def execute(self, statement, *multiparams, **params):
+ """Executes the given construct and returns a
+ :class:`_engine.CursorResult`.
+
+ The arguments are the same as those used by
+ :meth:`_engine.Connection.execute`.
+
+ Here, a :class:`_engine.Connection` is acquired using the
+ :meth:`_engine.Engine.connect` method, and the statement executed
+ with that connection. The returned :class:`_engine.CursorResult`
+ is flagged
+ such that when the :class:`_engine.CursorResult` is exhausted and its
+ underlying cursor is closed, the :class:`_engine.Connection`
+ created here
+ will also be closed, which allows its associated DBAPI connection
+ resource to be returned to the connection pool.
+
+ """
+ connection = self.connect(close_with_result=True)
+ return connection.execute(statement, *multiparams, **params)
+
+ @util.deprecated_20(
+ ":meth:`_engine.Engine.scalar`",
+ alternative="All statement execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.execute` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.execute` method of "
+ ":class:`.Session`; the :meth:`_future.Result.scalar` "
+ "method can then be "
+ "used to return a scalar result.",
+ )
+ def scalar(self, statement, *multiparams, **params):
+ """Executes and returns the first column of the first row.
+
+ The underlying result/cursor is closed after execution.
+ """
+ return self.execute(statement, *multiparams, **params).scalar()
+
+ def _execute_clauseelement(
+ self,
+ elem,
+ multiparams=None,
+ params=None,
+ execution_options=_EMPTY_EXECUTION_OPTS,
+ ):
+ connection = self.connect(close_with_result=True)
+ return connection._execute_clauseelement(
+ elem, multiparams, params, execution_options
+ )
+
+ def _execute_compiled(
+ self,
+ compiled,
+ multiparams,
+ params,
+ execution_options=_EMPTY_EXECUTION_OPTS,
+ ):
+ connection = self.connect(close_with_result=True)
+ return connection._execute_compiled(
+ compiled, multiparams, params, execution_options
+ )
+
+ def connect(self, close_with_result=False):
+ """Return a new :class:`_engine.Connection` object.
+
+ The :class:`_engine.Connection` object is a facade that uses a DBAPI
+ connection internally in order to communicate with the database. This
+ connection is procured from the connection-holding :class:`_pool.Pool`
+ referenced by this :class:`_engine.Engine`. When the
+ :meth:`_engine.Connection.close` method of the
+ :class:`_engine.Connection` object
+ is called, the underlying DBAPI connection is then returned to the
+ connection pool, where it may be used again in a subsequent call to
+ :meth:`_engine.Engine.connect`.
+
+ """
+
+ return self._connection_cls(self, close_with_result=close_with_result)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_engine.Engine.table_names` "
+ "method is deprecated and will be "
+ "removed in a future release. Please refer to "
+ ":meth:`_reflection.Inspector.get_table_names`.",
+ )
+ def table_names(self, schema=None, connection=None):
+ """Return a list of all table names available in the database.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+
+ :param connection: Optional, use a specified connection.
+ """
+ with self._optional_conn_ctx_manager(connection) as conn:
+ insp = inspection.inspect(conn)
+ return insp.get_table_names(schema)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_engine.Engine.has_table` "
+ "method is deprecated and will be "
+ "removed in a future release. Please refer to "
+ ":meth:`_reflection.Inspector.has_table`.",
+ )
+ def has_table(self, table_name, schema=None):
+ """Return True if the given backend has a table of the given name.
+
+ .. seealso::
+
+ :ref:`metadata_reflection_inspector` - detailed schema inspection
+ using the :class:`_reflection.Inspector` interface.
+
+ :class:`.quoted_name` - used to pass quoting information along
+ with a schema identifier.
+
+ """
+ with self._optional_conn_ctx_manager(None) as conn:
+ insp = inspection.inspect(conn)
+ return insp.has_table(table_name, schema=schema)
+
+ def _wrap_pool_connect(self, fn, connection):
+ dialect = self.dialect
+ try:
+ return fn()
+ except dialect.dbapi.Error as e:
+ if connection is None:
+ Connection._handle_dbapi_exception_noconnection(
+ e, dialect, self
+ )
+ else:
+ util.raise_(
+ sys.exc_info()[1], with_traceback=sys.exc_info()[2]
+ )
+
+ def raw_connection(self, _connection=None):
+ """Return a "raw" DBAPI connection from the connection pool.
+
+ The returned object is a proxied version of the DBAPI
+ connection object used by the underlying driver in use.
+ The object will have all the same behavior as the real DBAPI
+ connection, except that its ``close()`` method will result in the
+ connection being returned to the pool, rather than being closed
+ for real.
+
+ This method provides direct DBAPI connection access for
+ special situations when the API provided by
+ :class:`_engine.Connection`
+ is not needed. When a :class:`_engine.Connection` object is already
+ present, the DBAPI connection is available using
+ the :attr:`_engine.Connection.connection` accessor.
+
+ .. seealso::
+
+ :ref:`dbapi_connections`
+
+ """
+ return self._wrap_pool_connect(self.pool.connect, _connection)
+
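+# Illustrative sketch only (comments, not executed): Engine.raw_connection()
+# hands back the pooled DBAPI connection wrapper; its close() returns it to
+# the pool rather than closing it.  Assuming an Engine named ``some_engine``:
+#
+#     dbapi_conn = some_engine.raw_connection()
+#     try:
+#         cursor = dbapi_conn.cursor()
+#         cursor.execute("select 1")
+#         print(cursor.fetchall())
+#         cursor.close()
+#     finally:
+#         dbapi_conn.close()   # returned to the pool, not actually closed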
+
+class OptionEngineMixin(object):
+ _sa_propagate_class_events = False
+
+ def __init__(self, proxied, execution_options):
+ self._proxied = proxied
+ self.url = proxied.url
+ self.dialect = proxied.dialect
+ self.logging_name = proxied.logging_name
+ self.echo = proxied.echo
+ self._compiled_cache = proxied._compiled_cache
+ self.hide_parameters = proxied.hide_parameters
+ log.instance_logger(self, echoflag=self.echo)
+
+ # note: this will propagate events that are assigned to the parent
+ # engine after this OptionEngine is created. Since we share
+ # the events of the parent we also disallow class-level events
+ # to apply to the OptionEngine class directly.
+ #
+ # the other way this can work would be to transfer existing
+ # events only, using:
+ # self.dispatch._update(proxied.dispatch)
+ #
+ # that might be more appropriate however it would be a behavioral
+ # change for logic that assigns events to the parent engine and
+ # would like it to take effect for the already-created sub-engine.
+ self.dispatch = self.dispatch._join(proxied.dispatch)
+
+ self._execution_options = proxied._execution_options
+ self.update_execution_options(**execution_options)
+
+ def _get_pool(self):
+ return self._proxied.pool
+
+ def _set_pool(self, pool):
+ self._proxied.pool = pool
+
+ pool = property(_get_pool, _set_pool)
+
+ def _get_has_events(self):
+ return self._proxied._has_events or self.__dict__.get(
+ "_has_events", False
+ )
+
+ def _set_has_events(self, value):
+ self.__dict__["_has_events"] = value
+
+ _has_events = property(_get_has_events, _set_has_events)
+
+
+class OptionEngine(OptionEngineMixin, Engine):
+ pass
+
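+# Illustrative sketch only (comments, not executed): Engine.execution_options()
+# returns an OptionEngine that shares the parent's pool and dialect while
+# layering on its own default execution options (see OptionEngineMixin above).
+# Assuming an Engine named ``some_engine`` and ``text`` imported from
+# sqlalchemy:
+#
+#     autocommit_engine = some_engine.execution_options(
+#         isolation_level="AUTOCOMMIT"
+#     )
+#     assert autocommit_engine.pool is some_engine.pool
+#     with autocommit_engine.connect() as conn:
+#         conn.execute(text("select 1"))   # runs in autocommit mode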
+
+Engine._option_cls = OptionEngine
diff --git a/lib/sqlalchemy/engine/characteristics.py b/lib/sqlalchemy/engine/characteristics.py
new file mode 100644
index 0000000..c00bff4
--- /dev/null
+++ b/lib/sqlalchemy/engine/characteristics.py
@@ -0,0 +1,56 @@
+import abc
+
+from ..util import ABC
+
+
+class ConnectionCharacteristic(ABC):
+ """An abstract base for an object that can set, get and reset a
+ per-connection characteristic, typically one that gets reset when the
+ connection is returned to the connection pool.
+
+ Transaction isolation is the canonical example, and the
+ ``IsolationLevelCharacteristic`` implementation provides this for the
+ ``DefaultDialect``.
+
+ The ``ConnectionCharacteristic`` class should call upon the ``Dialect`` for
+ the implementation of each method. The object exists strictly to serve as
+ a dialect visitor that can be placed into the
+ ``DefaultDialect.connection_characteristics`` dictionary where it will take
+ effect for calls to :meth:`_engine.Connection.execution_options` and
+ related APIs.
+
+ .. versionadded:: 1.4
+
+ """
+
+ __slots__ = ()
+
+ transactional = False
+
+ @abc.abstractmethod
+ def reset_characteristic(self, dialect, dbapi_conn):
+ """Reset the characteristic on the connection to its default value."""
+
+ @abc.abstractmethod
+ def set_characteristic(self, dialect, dbapi_conn, value):
+ """Set the characteristic on the connection to a given value."""
+
+ @abc.abstractmethod
+ def get_characteristic(self, dialect, dbapi_conn):
+ """Given a DBAPI connection, get the current value of the
+ characteristic.
+
+ """
+
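+# Illustrative sketch only (comments, not executed): a hypothetical
+# characteristic, not part of SQLAlchemy, showing the shape the docstring
+# above describes.  The dialect methods referenced (``set_read_only`` and
+# ``get_read_only``) are assumptions; a real dialect would have to provide
+# them before registering the class in
+# ``DefaultDialect.connection_characteristics``.
+#
+#     class ReadOnlyCharacteristic(ConnectionCharacteristic):
+#         transactional = True
+#
+#         def reset_characteristic(self, dialect, dbapi_conn):
+#             dialect.set_read_only(dbapi_conn, False)
+#
+#         def set_characteristic(self, dialect, dbapi_conn, value):
+#             dialect.set_read_only(dbapi_conn, value)
+#
+#         def get_characteristic(self, dialect, dbapi_conn):
+#             return dialect.get_read_only(dbapi_conn)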
+
+class IsolationLevelCharacteristic(ConnectionCharacteristic):
+ transactional = True
+
+ def reset_characteristic(self, dialect, dbapi_conn):
+ dialect.reset_isolation_level(dbapi_conn)
+
+ def set_characteristic(self, dialect, dbapi_conn, value):
+ dialect.set_isolation_level(dbapi_conn, value)
+
+ def get_characteristic(self, dialect, dbapi_conn):
+ return dialect.get_isolation_level(dbapi_conn)
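+
+# Illustrative sketch only (comments, not executed): IsolationLevelCharacteristic
+# is what makes the "isolation_level" execution option take effect per
+# connection.  Assuming an Engine named ``some_engine`` on a backend that
+# supports the requested level:
+#
+#     with some_engine.connect().execution_options(
+#         isolation_level="READ COMMITTED"
+#     ) as conn:
+#         ...  # statements here run under READ COMMITTED
+#     # the characteristic is reset when the connection returns to the pool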
diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py
new file mode 100644
index 0000000..b9886b7
--- /dev/null
+++ b/lib/sqlalchemy/engine/create.py
@@ -0,0 +1,743 @@
+# engine/create.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+from . import base
+from . import url as _url
+from .mock import create_mock_engine
+from .. import event
+from .. import exc
+from .. import pool as poollib
+from .. import util
+from ..sql import compiler
+
+
+@util.deprecated_params(
+ strategy=(
+ "1.4",
+ "The :paramref:`_sa.create_engine.strategy` keyword is deprecated, "
+ "and the only argument accepted is 'mock'; please use "
+ ":func:`.create_mock_engine` going forward. For general "
+ "customization of create_engine which may have been accomplished "
+ "using strategies, see :class:`.CreateEnginePlugin`.",
+ ),
+ empty_in_strategy=(
+ "1.4",
+ "The :paramref:`_sa.create_engine.empty_in_strategy` keyword is "
+ "deprecated, and no longer has any effect. All IN expressions "
+ "are now rendered using "
+ 'the "expanding parameter" strategy which renders a set of bound '
+ 'expressions, or an "empty set" SELECT, at statement execution '
+ "time.",
+ ),
+ case_sensitive=(
+ "1.4",
+ "The :paramref:`_sa.create_engine.case_sensitive` parameter "
+ "is deprecated and will be removed in a future release. "
+ "Applications should work with result column names in a case "
+ "sensitive fashion.",
+ ),
+)
+def create_engine(url, **kwargs):
+ """Create a new :class:`_engine.Engine` instance.
+
+ The standard calling form is to send the :ref:`URL <database_urls>` as the
+ first positional argument, usually a string
+ that indicates database dialect and connection arguments::
+
+ engine = create_engine("postgresql://scott:tiger@localhost/test")
+
+ .. note::
+
+ Please review :ref:`database_urls` for general guidelines in composing
+ URL strings. In particular, special characters, such as those often
+ part of passwords, must be URL encoded to be properly parsed.
+
+ Additional keyword arguments may then follow it which
+ establish various options on the resulting :class:`_engine.Engine`
+ and its underlying :class:`.Dialect` and :class:`_pool.Pool`
+ constructs::
+
+ engine = create_engine("mysql://scott:tiger@hostname/dbname",
+ encoding='latin1', echo=True)
+
+ The string form of the URL is
+ ``dialect[+driver]://user:password@host/dbname[?key=value..]``, where
+ ``dialect`` is a database name such as ``mysql``, ``oracle``,
+ ``postgresql``, etc., and ``driver`` the name of a DBAPI, such as
+ ``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively,
+ the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`.
+
+ ``**kwargs`` takes a wide variety of options which are routed
+ towards their appropriate components. Arguments may be specific to
+ the :class:`_engine.Engine`, the underlying :class:`.Dialect`,
+ as well as the
+ :class:`_pool.Pool`. Specific dialects also accept keyword arguments that
+ are unique to that dialect. Here, we describe the parameters
+ that are common to most :func:`_sa.create_engine()` usage.
+
+ Once established, the newly resulting :class:`_engine.Engine` will
+ request a connection from the underlying :class:`_pool.Pool` once
+ :meth:`_engine.Engine.connect` is called, or a method which depends on it
+ such as :meth:`_engine.Engine.execute` is invoked. The
+ :class:`_pool.Pool` in turn
+ will establish the first actual DBAPI connection when this request
+ is received. The :func:`_sa.create_engine` call itself does **not**
+ establish any actual DBAPI connections directly.
+
+ .. seealso::
+
+ :doc:`/core/engines`
+
+ :doc:`/dialects/index`
+
+ :ref:`connections_toplevel`
+
+ :param case_sensitive: if False, result column names
+ will match in a case-insensitive fashion, that is,
+ ``row['SomeColumn']``.
+
+ :param connect_args: a dictionary of options which will be
+ passed directly to the DBAPI's ``connect()`` method as
+ additional keyword arguments. See the example
+ at :ref:`custom_dbapi_args`.
+
+ :param convert_unicode=False: if set to True, causes
+ all :class:`.String` datatypes to act as though the
+ :paramref:`.String.convert_unicode` flag has been set to ``True``,
+ regardless of a setting of ``False`` on an individual :class:`.String`
+ type. This has the effect of causing all :class:`.String` -based
+ columns to accommodate Python Unicode objects directly as though the
+ datatype were the :class:`.Unicode` type.
+
+ .. deprecated:: 1.3
+
+ The :paramref:`_sa.create_engine.convert_unicode` parameter
+ is deprecated and will be removed in a future release.
+ All modern DBAPIs now support Python Unicode directly and this
+ parameter is unnecessary.
+
+ :param creator: a callable which returns a DBAPI connection.
+ This creation function will be passed to the underlying
+ connection pool and will be used to create all new database
+ connections. Usage of this function causes connection
+ parameters specified in the URL argument to be bypassed.
+
+ This hook is not as flexible as the newer
+ :meth:`_events.DialectEvents.do_connect` hook which allows complete
+ control over how a connection is made to the database, given the full
+ set of URL arguments and state beforehand.
+
+ .. seealso::
+
+ :meth:`_events.DialectEvents.do_connect` - event hook that allows
+ full control over DBAPI connection mechanics.
+
+ :ref:`custom_dbapi_args`
+
+ :param echo=False: if True, the Engine will log all statements
+ as well as a ``repr()`` of their parameter lists to the default log
+ handler, which defaults to ``sys.stdout`` for output. If set to the
+ string ``"debug"``, result rows will be printed to the standard output
+ as well. The ``echo`` attribute of ``Engine`` can be modified at any
+ time to turn logging on and off; direct control of logging is also
+ available using the standard Python ``logging`` module.
+
+ .. seealso::
+
+ :ref:`dbengine_logging` - further detail on how to configure
+ logging.
+
+
+ :param echo_pool=False: if True, the connection pool will log
+ informational output such as when connections are invalidated
+ as well as when connections are recycled to the default log handler,
+ which defaults to ``sys.stdout`` for output. If set to the string
+ ``"debug"``, the logging will include pool checkouts and checkins.
+ Direct control of logging is also available using the standard Python
+ ``logging`` module.
+
+ .. seealso::
+
+ :ref:`dbengine_logging` - further detail on how to configure
+ logging.
+
+
+ :param empty_in_strategy: No longer used; SQLAlchemy now uses
+ "empty set" behavior for IN in all cases.
+
+ :param enable_from_linting: defaults to True. Will emit a warning
+ if a given SELECT statement is found to have un-linked FROM elements
+ which would cause a cartesian product.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`change_4737`
+
+ :param encoding: **legacy Python 2 value only, where it only applies to
+ specific DBAPIs, not used in Python 3 for any modern DBAPI driver.
+ Please refer to individual dialect documentation for client encoding
+ behaviors.** Defaults to the string value ``utf-8``. This value
+ refers **only** to the character encoding that is used when SQLAlchemy
+ sends or receives data from a :term:`DBAPI` that does not support
+ Python Unicode and **is only used under Python 2**, only for certain
+ DBAPI drivers, and only in certain circumstances. **Python 3 users
+ please DISREGARD this parameter and refer to the documentation for the
+ specific dialect in use in order to configure character encoding
+ behavior.**
+
+ .. note:: The ``encoding`` parameter deals only with in-Python
+ encoding issues that were prevalent with **some DBAPIS only**
+ under **Python 2 only**. Under Python 3 it is not used by
+ any modern dialect. For DBAPIs that require
+ client encoding configurations, which are most of those outside
+ of SQLite, please consult specific :ref:`dialect documentation
+ <dialect_toplevel>` for details.
+
+ All modern DBAPIs that work in Python 3 necessarily feature direct
+ support for Python unicode strings. Under Python 2, this was not
+ always the case. For those scenarios where the DBAPI is detected as
+ not supporting a Python ``unicode`` object under Python 2, this
+ encoding is used to determine the source/destination encoding. It is
+ **not used** for those cases where the DBAPI handles unicode directly.
+
+ To properly configure a system to accommodate Python ``unicode``
+ objects, the DBAPI should be configured to handle unicode to the
+ greatest degree as is appropriate - see the notes on unicode pertaining
+ to the specific target database in use at :ref:`dialect_toplevel`.
+
+ Areas where string encoding may need to be accommodated
+ outside of the DBAPI, nearly always under **Python 2 only**,
+ include zero or more of:
+
+ * the values passed to bound parameters, corresponding to
+ the :class:`.Unicode` type or the :class:`.String` type
+ when ``convert_unicode`` is ``True``;
+ * the values returned in result set columns corresponding
+ to the :class:`.Unicode` type or the :class:`.String`
+ type when ``convert_unicode`` is ``True``;
+ * the string SQL statement passed to the DBAPI's
+ ``cursor.execute()`` method;
+ * the string names of the keys in the bound parameter
+ dictionary passed to the DBAPI's ``cursor.execute()``
+ as well as ``cursor.setinputsizes()`` methods;
+ * the string column names retrieved from the DBAPI's
+ ``cursor.description`` attribute.
+
+ When using Python 3, the DBAPI is required to support all of the above
+ values as Python ``unicode`` objects, which in Python 3 are just known
+ as ``str``. In Python 2, the DBAPI does not specify unicode behavior
+ at all, so SQLAlchemy must make decisions for each of the above values
+ on a per-DBAPI basis - implementations are completely inconsistent in
+ their behavior.
+
+ :param execution_options: Dictionary execution options which will
+ be applied to all connections. See
+ :meth:`~sqlalchemy.engine.Connection.execution_options`
+
+ :param future: Use the 2.0 style :class:`_future.Engine` and
+ :class:`_future.Connection` API.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`migration_20_toplevel`
+
+ :param hide_parameters: Boolean, when set to True, SQL statement parameters
+ will not be displayed in INFO logging nor will they be formatted into
+ the string representation of :class:`.StatementError` objects.
+
+ .. versionadded:: 1.3.8
+
+ .. seealso::
+
+ :ref:`dbengine_logging` - further detail on how to configure
+ logging.
+
+ :param implicit_returning=True: Legacy flag that when set to ``False``
+ will disable the use of ``RETURNING`` on supporting backends where it
+ would normally be used to fetch newly generated primary key values for
+ single-row INSERT statements that do not otherwise specify a RETURNING
+ clause. This behavior applies primarily to the PostgreSQL, Oracle,
+ SQL Server backends.
+
+ .. warning:: this flag originally allowed the "implicit returning"
+ feature to be *enabled* back when it was very new and there was not
+ well-established database support. In modern SQLAlchemy, this flag
+ should **always be set to True**. Some SQLAlchemy features will
+ fail to function properly if this flag is set to ``False``.
+
+ :param isolation_level: this string parameter is interpreted by various
+ dialects in order to affect the transaction isolation level of the
+ database connection. The parameter essentially accepts some subset of
+ these string arguments: ``"SERIALIZABLE"``, ``"REPEATABLE READ"``,
+ ``"READ COMMITTED"``, ``"READ UNCOMMITTED"`` and ``"AUTOCOMMIT"``.
+ Behavior here varies per backend, and
+ individual dialects should be consulted directly.
+
+ Note that the isolation level can also be set on a
+ per-:class:`_engine.Connection` basis as well, using the
+ :paramref:`.Connection.execution_options.isolation_level`
+ feature.
+
+ .. seealso::
+
+ :ref:`dbapi_autocommit`
+
+ :param json_deserializer: for dialects that support the
+ :class:`_types.JSON`
+ datatype, this is a Python callable that will convert a JSON string
+ to a Python object. By default, the Python ``json.loads`` function is
+ used.
+
+ .. versionchanged:: 1.3.7 The SQLite dialect renamed this from
+ ``_json_deserializer``.
+
+ :param json_serializer: for dialects that support the :class:`_types.JSON`
+ datatype, this is a Python callable that will render a given object
+ as JSON. By default, the Python ``json.dumps`` function is used.
+
+ .. versionchanged:: 1.3.7 The SQLite dialect renamed this from
+ ``_json_serializer``.
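+
+     For example, a sketch passing custom callables built on the standard
+     library ``json`` module; the ``sort_keys`` option and the SQLite URL
+     are arbitrary choices for illustration::
+
+         import json
+
+         from sqlalchemy import create_engine
+
+         eng = create_engine(
+             "sqlite://",
+             json_serializer=lambda obj: json.dumps(obj, sort_keys=True),
+             json_deserializer=json.loads,
+         )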
+
+
+ :param label_length=None: optional integer value which limits
+ the size of dynamically generated column labels to that many
+ characters. If less than 6, labels are generated as
+ "_(counter)". If ``None``, the value of
+ ``dialect.max_identifier_length``, which may be affected via the
+ :paramref:`_sa.create_engine.max_identifier_length` parameter,
+ is used instead. The value of
+ :paramref:`_sa.create_engine.label_length`
+ may not be larger than that of
+     :paramref:`_sa.create_engine.max_identifier_length`.
+
+ .. seealso::
+
+ :paramref:`_sa.create_engine.max_identifier_length`
+
+ :param listeners: A list of one or more
+ :class:`~sqlalchemy.interfaces.PoolListener` objects which will
+ receive connection pool events.
+
+ :param logging_name: String identifier which will be used within
+ the "name" field of logging records generated within the
+ "sqlalchemy.engine" logger. Defaults to a hexstring of the
+ object's id.
+
+ .. seealso::
+
+ :ref:`dbengine_logging` - further detail on how to configure
+ logging.
+
+ :paramref:`_engine.Connection.execution_options.logging_token`
+
+
+
+ :param max_identifier_length: integer; override the max_identifier_length
+     determined by the dialect. If ``None`` or zero, it has no effect. This
+ is the database's configured maximum number of characters that may be
+ used in a SQL identifier such as a table name, column name, or label
+ name. All dialects determine this value automatically, however in the
+ case of a new database version for which this value has changed but
+ SQLAlchemy's dialect has not been adjusted, the value may be passed
+ here.
+
+ .. versionadded:: 1.3.9
+
+ .. seealso::
+
+ :paramref:`_sa.create_engine.label_length`
+
+ :param max_overflow=10: the number of connections to allow in
+ connection pool "overflow", that is connections that can be
+ opened above and beyond the pool_size setting, which defaults
+     to five. This is only used with :class:`~sqlalchemy.pool.QueuePool`.
+
+ :param module=None: reference to a Python module object (the module
+ itself, not its string name). Specifies an alternate DBAPI module to
+ be used by the engine's dialect. Each sub-dialect references a
+ specific DBAPI which will be imported before first connect. This
+ parameter causes the import to be bypassed, and the given module to
+ be used instead. Can be used for testing of DBAPIs as well as to
+ inject "mock" DBAPI implementations into the :class:`_engine.Engine`.
+
+ :param paramstyle=None: The `paramstyle <https://legacy.python.org/dev/peps/pep-0249/#paramstyle>`_
+ to use when rendering bound parameters. This style defaults to the
+ one recommended by the DBAPI itself, which is retrieved from the
+ ``.paramstyle`` attribute of the DBAPI. However, most DBAPIs accept
+ more than one paramstyle, and in particular it may be desirable
+ to change a "named" paramstyle into a "positional" one, or vice versa.
+ When this attribute is passed, it should be one of the values
+ ``"qmark"``, ``"numeric"``, ``"named"``, ``"format"`` or
+ ``"pyformat"``, and should correspond to a parameter style known
+ to be supported by the DBAPI in use.
+
+ :param pool=None: an already-constructed instance of
+ :class:`~sqlalchemy.pool.Pool`, such as a
+ :class:`~sqlalchemy.pool.QueuePool` instance. If non-None, this
+ pool will be used directly as the underlying connection pool
+ for the engine, bypassing whatever connection parameters are
+ present in the URL argument. For information on constructing
+ connection pools manually, see :ref:`pooling_toplevel`.
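+
+     For example, a sketch that builds a :class:`~sqlalchemy.pool.QueuePool`
+     manually around a DBAPI-level ``creator`` function; the use of
+     ``psycopg2`` and the connection arguments shown are assumptions for
+     illustration only::
+
+         import psycopg2
+
+         from sqlalchemy import create_engine
+         from sqlalchemy.pool import QueuePool
+
+         def getconn():
+             # raw DBAPI connection, bypassing URL-derived parameters
+             return psycopg2.connect(
+                 user="scott", host="localhost", dbname="test"
+             )
+
+         pool = QueuePool(getconn, pool_size=10)
+         eng = create_engine("postgresql+psycopg2://", pool=pool)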
+
+ :param poolclass=None: a :class:`~sqlalchemy.pool.Pool`
+ subclass, which will be used to create a connection pool
+ instance using the connection parameters given in the URL. Note
+ this differs from ``pool`` in that you don't actually
+     instantiate the pool in this case; you just indicate what type
+     of pool to use.
+
+ :param pool_logging_name: String identifier which will be used within
+ the "name" field of logging records generated within the
+ "sqlalchemy.pool" logger. Defaults to a hexstring of the object's
+ id.
+
+
+ .. seealso::
+
+ :ref:`dbengine_logging` - further detail on how to configure
+ logging.
+
+
+ :param pool_pre_ping: boolean, if True will enable the connection pool
+ "pre-ping" feature that tests connections for liveness upon
+ each checkout.
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`pool_disconnects_pessimistic`
+
+ :param pool_size=5: the number of connections to keep open
+     inside the connection pool. This is used with
+ :class:`~sqlalchemy.pool.QueuePool` as
+ well as :class:`~sqlalchemy.pool.SingletonThreadPool`. With
+ :class:`~sqlalchemy.pool.QueuePool`, a ``pool_size`` setting
+ of 0 indicates no limit; to disable pooling, set ``poolclass`` to
+ :class:`~sqlalchemy.pool.NullPool` instead.
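+
+     For example, a sketch combining several of the pool-related parameters
+     described in this section, with a placeholder database URL::
+
+         from sqlalchemy import create_engine
+
+         eng = create_engine(
+             "postgresql://scott:tiger@localhost/test",
+             pool_size=10,
+             max_overflow=20,
+             pool_recycle=3600,
+             pool_pre_ping=True,
+         )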
+
+ :param pool_recycle=-1: this setting causes the pool to recycle
+ connections after the given number of seconds has passed. It
+ defaults to -1, or no timeout. For example, setting to 3600
+ means connections will be recycled after one hour. Note that
+ MySQL in particular will disconnect automatically if no
+ activity is detected on a connection for eight hours (although
+ this is configurable with the MySQLDB connection itself and the
+ server configuration as well).
+
+ .. seealso::
+
+ :ref:`pool_setting_recycle`
+
+ :param pool_reset_on_return='rollback': set the
+ :paramref:`_pool.Pool.reset_on_return` parameter of the underlying
+ :class:`_pool.Pool` object, which can be set to the values
+ ``"rollback"``, ``"commit"``, or ``None``.
+
+ .. seealso::
+
+ :paramref:`_pool.Pool.reset_on_return`
+
+ :param pool_timeout=30: number of seconds to wait before giving
+ up on getting a connection from the pool. This is only used
+ with :class:`~sqlalchemy.pool.QueuePool`. This can be a float but is
+ subject to the limitations of Python time functions which may not be
+ reliable in the tens of milliseconds.
+
+ .. note: don't use 30.0 above, it seems to break with the :param tag
+
+ :param pool_use_lifo=False: use LIFO (last-in-first-out) when retrieving
+ connections from :class:`.QueuePool` instead of FIFO
+ (first-in-first-out). Using LIFO, a server-side timeout scheme can
+     reduce the number of connections used during non-peak periods of
+ use. When planning for server-side timeouts, ensure that a recycle or
+ pre-ping strategy is in use to gracefully handle stale connections.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :ref:`pool_use_lifo`
+
+ :ref:`pool_disconnects`
+
+ :param plugins: string list of plugin names to load. See
+ :class:`.CreateEnginePlugin` for background.
+
+ .. versionadded:: 1.2.3
+
+ :param query_cache_size: size of the cache used to cache the SQL string
+ form of queries. Set to zero to disable caching.
+
+ The cache is pruned of its least recently used items when its size reaches
+ N * 1.5. Defaults to 500, meaning the cache will always store at least
+ 500 SQL statements when filled, and will grow up to 750 items at which
+ point it is pruned back down to 500 by removing the 250 least recently
+ used items.
+
+ Caching is accomplished on a per-statement basis by generating a
+ cache key that represents the statement's structure, then generating
+ string SQL for the current dialect only if that key is not present
+ in the cache. All statements support caching, however some features
+ such as an INSERT with a large set of parameters will intentionally
+ bypass the cache. SQL logging will indicate statistics for each
+     statement whether or not it was pulled from the cache.
+
+ .. note:: some ORM functions related to unit-of-work persistence as well
+ as some attribute loading strategies will make use of individual
+ per-mapper caches outside of the main cache.
+
+
+ .. seealso::
+
+ :ref:`sql_caching`
+
+ .. versionadded:: 1.4
+
+ """ # noqa
+
+ if "strategy" in kwargs:
+ strat = kwargs.pop("strategy")
+ if strat == "mock":
+ return create_mock_engine(url, **kwargs)
+ else:
+ raise exc.ArgumentError("unknown strategy: %r" % strat)
+
+ kwargs.pop("empty_in_strategy", None)
+
+ # create url.URL object
+ u = _url.make_url(url)
+
+ u, plugins, kwargs = u._instantiate_plugins(kwargs)
+
+ entrypoint = u._get_entrypoint()
+ dialect_cls = entrypoint.get_dialect_cls(u)
+
+ if kwargs.pop("_coerce_config", False):
+
+ def pop_kwarg(key, default=None):
+ value = kwargs.pop(key, default)
+ if key in dialect_cls.engine_config_types:
+ value = dialect_cls.engine_config_types[key](value)
+ return value
+
+ else:
+ pop_kwarg = kwargs.pop
+
+ dialect_args = {}
+ # consume dialect arguments from kwargs
+ for k in util.get_cls_kwargs(dialect_cls):
+ if k in kwargs:
+ dialect_args[k] = pop_kwarg(k)
+
+ dbapi = kwargs.pop("module", None)
+ if dbapi is None:
+ dbapi_args = {}
+ for k in util.get_func_kwargs(dialect_cls.dbapi):
+ if k in kwargs:
+ dbapi_args[k] = pop_kwarg(k)
+ dbapi = dialect_cls.dbapi(**dbapi_args)
+
+ dialect_args["dbapi"] = dbapi
+
+ dialect_args.setdefault("compiler_linting", compiler.NO_LINTING)
+ enable_from_linting = kwargs.pop("enable_from_linting", True)
+ if enable_from_linting:
+ dialect_args["compiler_linting"] ^= compiler.COLLECT_CARTESIAN_PRODUCTS
+
+ for plugin in plugins:
+ plugin.handle_dialect_kwargs(dialect_cls, dialect_args)
+
+ # create dialect
+ dialect = dialect_cls(**dialect_args)
+
+ # assemble connection arguments
+ (cargs, cparams) = dialect.create_connect_args(u)
+ cparams.update(pop_kwarg("connect_args", {}))
+ cargs = list(cargs) # allow mutability
+
+ # look for existing pool or create
+ pool = pop_kwarg("pool", None)
+ if pool is None:
+
+ def connect(connection_record=None):
+ if dialect._has_events:
+ for fn in dialect.dispatch.do_connect:
+ connection = fn(dialect, connection_record, cargs, cparams)
+ if connection is not None:
+ return connection
+ return dialect.connect(*cargs, **cparams)
+
+ creator = pop_kwarg("creator", connect)
+
+ poolclass = pop_kwarg("poolclass", None)
+ if poolclass is None:
+ poolclass = dialect.get_dialect_pool_class(u)
+ pool_args = {"dialect": dialect}
+
+ # consume pool arguments from kwargs, translating a few of
+ # the arguments
+ translate = {
+ "logging_name": "pool_logging_name",
+ "echo": "echo_pool",
+ "timeout": "pool_timeout",
+ "recycle": "pool_recycle",
+ "events": "pool_events",
+ "reset_on_return": "pool_reset_on_return",
+ "pre_ping": "pool_pre_ping",
+ "use_lifo": "pool_use_lifo",
+ }
+ for k in util.get_cls_kwargs(poolclass):
+ tk = translate.get(k, k)
+ if tk in kwargs:
+ pool_args[k] = pop_kwarg(tk)
+
+ for plugin in plugins:
+ plugin.handle_pool_kwargs(poolclass, pool_args)
+
+ pool = poolclass(creator, **pool_args)
+ else:
+ if isinstance(pool, poollib.dbapi_proxy._DBProxy):
+ pool = pool.get_pool(*cargs, **cparams)
+
+ pool._dialect = dialect
+
+ # create engine.
+ if pop_kwarg("future", False):
+ from sqlalchemy import future
+
+ default_engine_class = future.Engine
+ else:
+ default_engine_class = base.Engine
+
+ engineclass = kwargs.pop("_future_engine_class", default_engine_class)
+
+ engine_args = {}
+ for k in util.get_cls_kwargs(engineclass):
+ if k in kwargs:
+ engine_args[k] = pop_kwarg(k)
+
+ # internal flags used by the test suite for instrumenting / proxying
+ # engines with mocks etc.
+ _initialize = kwargs.pop("_initialize", True)
+ _wrap_do_on_connect = kwargs.pop("_wrap_do_on_connect", None)
+
+ # all kwargs should be consumed
+ if kwargs:
+ raise TypeError(
+ "Invalid argument(s) %s sent to create_engine(), "
+ "using configuration %s/%s/%s. Please check that the "
+ "keyword arguments are appropriate for this combination "
+ "of components."
+ % (
+ ",".join("'%s'" % k for k in kwargs),
+ dialect.__class__.__name__,
+ pool.__class__.__name__,
+ engineclass.__name__,
+ )
+ )
+
+ engine = engineclass(pool, dialect, u, **engine_args)
+
+ if _initialize:
+ do_on_connect = dialect.on_connect_url(u)
+ if do_on_connect:
+ if _wrap_do_on_connect:
+ do_on_connect = _wrap_do_on_connect(do_on_connect)
+
+ def on_connect(dbapi_connection, connection_record):
+ do_on_connect(dbapi_connection)
+
+ event.listen(pool, "connect", on_connect)
+
+ def first_connect(dbapi_connection, connection_record):
+ c = base.Connection(
+ engine,
+ connection=dbapi_connection,
+ _has_events=False,
+ # reconnecting will be a reentrant condition, so if the
+ # connection goes away, Connection is then closed
+ _allow_revalidate=False,
+ )
+ c._execution_options = util.EMPTY_DICT
+
+ try:
+ dialect.initialize(c)
+ finally:
+ # note that "invalidated" and "closed" are mutually
+ # exclusive in 1.4 Connection.
+ if not c.invalidated and not c.closed:
+ # transaction is rolled back otherwise, tested by
+ # test/dialect/postgresql/test_dialect.py
+ # ::MiscBackendTest::test_initial_transaction_state
+ dialect.do_rollback(c.connection)
+
+ # previously, the "first_connect" event was used here, which was then
+     # scaled back if the "on_connect" handler was present. now,
+ # since "on_connect" is virtually always present, just use
+ # "connect" event with once_unless_exception in all cases so that
+ # the connection event flow is consistent in all cases.
+ event.listen(
+ pool, "connect", first_connect, _once_unless_exception=True
+ )
+
+ dialect_cls.engine_created(engine)
+ if entrypoint is not dialect_cls:
+ entrypoint.engine_created(engine)
+
+ for plugin in plugins:
+ plugin.engine_created(engine)
+
+ return engine
+
+
+def engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
+ """Create a new Engine instance using a configuration dictionary.
+
+ The dictionary is typically produced from a config file.
+
+ The keys of interest to ``engine_from_config()`` should be prefixed, e.g.
+ ``sqlalchemy.url``, ``sqlalchemy.echo``, etc. The 'prefix' argument
+ indicates the prefix to be searched for. Each matching key (after the
+ prefix is stripped) is treated as though it were the corresponding keyword
+ argument to a :func:`_sa.create_engine` call.
+
+ The only required key is (assuming the default prefix) ``sqlalchemy.url``,
+ which provides the :ref:`database URL <database_urls>`.
+
+ A select set of keyword arguments will be "coerced" to their
+ expected type based on string values. The set of arguments
+ is extensible per-dialect using the ``engine_config_types`` accessor.
+
+ :param configuration: A dictionary (typically produced from a config file,
+ but this is not a requirement). Items whose keys start with the value
+ of 'prefix' will have that prefix stripped, and will then be passed to
+ :func:`_sa.create_engine`.
+
+ :param prefix: Prefix to match and then strip from keys
+ in 'configuration'.
+
+ :param kwargs: Each keyword argument to ``engine_from_config()`` itself
+ overrides the corresponding item taken from the 'configuration'
+ dictionary. Keyword arguments should *not* be prefixed.
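+
+    For example, a minimal sketch using an in-memory configuration
+    dictionary; the URL and option values are placeholders, and string
+    values such as ``"true"`` and ``"3600"`` are coerced as described
+    above::
+
+        config = {
+            "sqlalchemy.url": "postgresql://scott:tiger@localhost/test",
+            "sqlalchemy.echo": "true",
+            "sqlalchemy.pool_recycle": "3600",
+        }
+
+        engine = engine_from_config(config, prefix="sqlalchemy.")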
+
+ """
+
+ options = dict(
+ (key[len(prefix) :], configuration[key])
+ for key in configuration
+ if key.startswith(prefix)
+ )
+ options["_coerce_config"] = True
+ options.update(kwargs)
+ url = options.pop("url")
+ return create_engine(url, **options)
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
new file mode 100644
index 0000000..774916d
--- /dev/null
+++ b/lib/sqlalchemy/engine/cursor.py
@@ -0,0 +1,1942 @@
+# engine/cursor.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Define cursor-specific result set constructs including
+:class:`.BaseCursorResult`, :class:`.CursorResult`."""
+
+
+import collections
+import functools
+
+from .result import Result
+from .result import ResultMetaData
+from .result import SimpleResultMetaData
+from .result import tuplegetter
+from .row import LegacyRow
+from .. import exc
+from .. import util
+from ..sql import expression
+from ..sql import sqltypes
+from ..sql import util as sql_util
+from ..sql.base import _generative
+from ..sql.compiler import RM_NAME
+from ..sql.compiler import RM_OBJECTS
+from ..sql.compiler import RM_RENDERED_NAME
+from ..sql.compiler import RM_TYPE
+
+_UNPICKLED = util.symbol("unpickled")
+
+
+# metadata entry tuple indexes.
+# using raw tuple is faster than namedtuple.
+MD_INDEX = 0 # integer index in cursor.description
+MD_RESULT_MAP_INDEX = 1 # integer index in compiled._result_columns
+MD_OBJECTS = 2 # other string keys and ColumnElement obj that can match
+MD_LOOKUP_KEY = 3 # string key we usually expect for key-based lookup
+MD_RENDERED_NAME = 4 # name that is usually in cursor.description
+MD_PROCESSOR = 5 # callable to process a result value into a row
+MD_UNTRANSLATED = 6 # raw name from cursor.description
+
+
+class CursorResultMetaData(ResultMetaData):
+ """Result metadata for DBAPI cursors."""
+
+ __slots__ = (
+ "_keymap",
+ "case_sensitive",
+ "_processors",
+ "_keys",
+ "_keymap_by_result_column_idx",
+ "_tuplefilter",
+ "_translated_indexes",
+ "_safe_for_cache"
+ # don't need _unique_filters support here for now. Can be added
+ # if a need arises.
+ )
+
+ returns_rows = True
+
+ def _has_key(self, key):
+ return key in self._keymap
+
+ def _for_freeze(self):
+ return SimpleResultMetaData(
+ self._keys,
+ extra=[self._keymap[key][MD_OBJECTS] for key in self._keys],
+ )
+
+ def _reduce(self, keys):
+ recs = list(self._metadata_for_keys(keys))
+
+ indexes = [rec[MD_INDEX] for rec in recs]
+ new_keys = [rec[MD_LOOKUP_KEY] for rec in recs]
+
+ if self._translated_indexes:
+ indexes = [self._translated_indexes[idx] for idx in indexes]
+
+ tup = tuplegetter(*indexes)
+
+ new_metadata = self.__class__.__new__(self.__class__)
+ new_metadata.case_sensitive = self.case_sensitive
+ new_metadata._processors = self._processors
+ new_metadata._keys = new_keys
+ new_metadata._tuplefilter = tup
+ new_metadata._translated_indexes = indexes
+
+ new_recs = [
+ (index,) + rec[1:]
+ for index, rec in enumerate(self._metadata_for_keys(keys))
+ ]
+ new_metadata._keymap = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs}
+
+ # TODO: need unit test for:
+ # result = connection.execute("raw sql, no columns").scalars()
+ # without the "or ()" it's failing because MD_OBJECTS is None
+ new_metadata._keymap.update(
+ {
+ e: new_rec
+ for new_rec in new_recs
+ for e in new_rec[MD_OBJECTS] or ()
+ }
+ )
+
+ return new_metadata
+
+ def _adapt_to_context(self, context):
+ """When using a cached Compiled construct that has a _result_map,
+ for a new statement that used the cached Compiled, we need to ensure
+ the keymap has the Column objects from our new statement as keys.
+ So here we rewrite keymap with new entries for the new columns
+ as matched to those of the cached statement.
+
+ """
+
+ if not context.compiled._result_columns:
+ return self
+
+ compiled_statement = context.compiled.statement
+ invoked_statement = context.invoked_statement
+
+ if compiled_statement is invoked_statement:
+ return self
+
+ # make a copy and add the columns from the invoked statement
+ # to the result map.
+ md = self.__class__.__new__(self.__class__)
+
+ md._keymap = dict(self._keymap)
+
+ keymap_by_position = self._keymap_by_result_column_idx
+
+ for idx, new in enumerate(invoked_statement._all_selected_columns):
+ try:
+ rec = keymap_by_position[idx]
+ except KeyError:
+ # this can happen when there are bogus column entries
+ # in a TextualSelect
+ pass
+ else:
+ md._keymap[new] = rec
+
+ md.case_sensitive = self.case_sensitive
+ md._processors = self._processors
+ assert not self._tuplefilter
+ md._tuplefilter = None
+ md._translated_indexes = None
+ md._keys = self._keys
+ md._keymap_by_result_column_idx = self._keymap_by_result_column_idx
+ md._safe_for_cache = self._safe_for_cache
+ return md
+
+ def __init__(self, parent, cursor_description):
+ context = parent.context
+ dialect = context.dialect
+ self._tuplefilter = None
+ self._translated_indexes = None
+ self.case_sensitive = dialect.case_sensitive
+ self._safe_for_cache = False
+
+ if context.result_column_struct:
+ (
+ result_columns,
+ cols_are_ordered,
+ textual_ordered,
+ loose_column_name_matching,
+ ) = context.result_column_struct
+ num_ctx_cols = len(result_columns)
+ else:
+ result_columns = (
+ cols_are_ordered
+ ) = (
+ num_ctx_cols
+ ) = loose_column_name_matching = textual_ordered = False
+
+ # merge cursor.description with the column info
+ # present in the compiled structure, if any
+ raw = self._merge_cursor_description(
+ context,
+ cursor_description,
+ result_columns,
+ num_ctx_cols,
+ cols_are_ordered,
+ textual_ordered,
+ loose_column_name_matching,
+ )
+
+ self._keymap = {}
+
+ # processors in key order for certain per-row
+ # views like __iter__ and slices
+ self._processors = [
+ metadata_entry[MD_PROCESSOR] for metadata_entry in raw
+ ]
+
+ if context.compiled:
+ self._keymap_by_result_column_idx = {
+ metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry
+ for metadata_entry in raw
+ }
+
+ # keymap by primary string...
+ by_key = dict(
+ [
+ (metadata_entry[MD_LOOKUP_KEY], metadata_entry)
+ for metadata_entry in raw
+ ]
+ )
+
+ # for compiled SQL constructs, copy additional lookup keys into
+ # the key lookup map, such as Column objects, labels,
+ # column keys and other names
+ if num_ctx_cols:
+
+ # if by-primary-string dictionary smaller (or bigger?!) than
+ # number of columns, assume we have dupes, rewrite
+ # dupe records with "None" for index which results in
+ # ambiguous column exception when accessed.
+ if len(by_key) != num_ctx_cols:
+ # new in 1.4: get the complete set of all possible keys,
+ # strings, objects, whatever, that are dupes across two
+ # different records, first.
+ index_by_key = {}
+ dupes = set()
+ for metadata_entry in raw:
+ for key in (metadata_entry[MD_RENDERED_NAME],) + (
+ metadata_entry[MD_OBJECTS] or ()
+ ):
+ if not self.case_sensitive and isinstance(
+ key, util.string_types
+ ):
+ key = key.lower()
+ idx = metadata_entry[MD_INDEX]
+ # if this key has been associated with more than one
+ # positional index, it's a dupe
+ if index_by_key.setdefault(key, idx) != idx:
+ dupes.add(key)
+
+ # then put everything we have into the keymap excluding only
+ # those keys that are dupes.
+ self._keymap.update(
+ [
+ (obj_elem, metadata_entry)
+ for metadata_entry in raw
+ if metadata_entry[MD_OBJECTS]
+ for obj_elem in metadata_entry[MD_OBJECTS]
+ if obj_elem not in dupes
+ ]
+ )
+
+ # then for the dupe keys, put the "ambiguous column"
+ # record into by_key.
+ by_key.update({key: (None, None, (), key) for key in dupes})
+
+ else:
+ # no dupes - copy secondary elements from compiled
+ # columns into self._keymap
+ self._keymap.update(
+ [
+ (obj_elem, metadata_entry)
+ for metadata_entry in raw
+ if metadata_entry[MD_OBJECTS]
+ for obj_elem in metadata_entry[MD_OBJECTS]
+ ]
+ )
+
+ # update keymap with primary string names taking
+ # precedence
+ self._keymap.update(by_key)
+
+ # update keymap with "translated" names (sqlite-only thing)
+ if not num_ctx_cols and context._translate_colname:
+ self._keymap.update(
+ [
+ (
+ metadata_entry[MD_UNTRANSLATED],
+ self._keymap[metadata_entry[MD_LOOKUP_KEY]],
+ )
+ for metadata_entry in raw
+ if metadata_entry[MD_UNTRANSLATED]
+ ]
+ )
+
+ def _merge_cursor_description(
+ self,
+ context,
+ cursor_description,
+ result_columns,
+ num_ctx_cols,
+ cols_are_ordered,
+ textual_ordered,
+ loose_column_name_matching,
+ ):
+ """Merge a cursor.description with compiled result column information.
+
+ There are at least four separate strategies used here, selected
+ depending on the type of SQL construct used to start with.
+
+ The most common case is that of the compiled SQL expression construct,
+ which generated the column names present in the raw SQL string and
+ which has the identical number of columns as were reported by
+ cursor.description. In this case, we assume a 1-1 positional mapping
+ between the entries in cursor.description and the compiled object.
+ This is also the most performant case as we disregard extracting /
+ decoding the column names present in cursor.description since we
+ already have the desired name we generated in the compiled SQL
+ construct.
+
+ The next common case is that of the completely raw string SQL,
+ such as passed to connection.execute(). In this case we have no
+ compiled construct to work with, so we extract and decode the
+ names from cursor.description and index those as the primary
+ result row target keys.
+
+ The remaining fairly common case is that of the textual SQL
+ that includes at least partial column information; this is when
+ we use a :class:`_expression.TextualSelect` construct.
+ This construct may have
+ unordered or ordered column information. In the ordered case, we
+ merge the cursor.description and the compiled construct's information
+ positionally, and warn if there are additional description names
+ present, however we still decode the names in cursor.description
+ as we don't have a guarantee that the names in the columns match
+ on these. In the unordered case, we match names in cursor.description
+ to that of the compiled construct based on name matching.
+ In both of these cases, the cursor.description names and the column
+ expression objects and names are indexed as result row target keys.
+
+ The final case is much less common, where we have a compiled
+ non-textual SQL expression construct, but the number of columns
+ in cursor.description doesn't match what's in the compiled
+ construct. We make the guess here that there might be textual
+ column expressions in the compiled construct that themselves include
+ a comma in them causing them to split. We do the same name-matching
+ as with textual non-ordered columns.
+
+ The name-matched system of merging is the same as that used by
+ SQLAlchemy for all cases up through the 0.9 series. Positional
+ matching for compiled SQL expressions was introduced in 1.0 as a
+ major performance feature, and positional matching for textual
+ :class:`_expression.TextualSelect` objects in 1.1.
+ As name matching is no longer
+ a common case, it was acceptable to factor it into smaller generator-
+ oriented methods that are easier to understand, but incur slightly
+ more performance overhead.
+
+ """
+
+ case_sensitive = context.dialect.case_sensitive
+
+ if (
+ num_ctx_cols
+ and cols_are_ordered
+ and not textual_ordered
+ and num_ctx_cols == len(cursor_description)
+ ):
+ self._keys = [elem[0] for elem in result_columns]
+ # pure positional 1-1 case; doesn't need to read
+ # the names from cursor.description
+
+ # this metadata is safe to cache because we are guaranteed
+ # to have the columns in the same order for new executions
+ self._safe_for_cache = True
+ return [
+ (
+ idx,
+ idx,
+ rmap_entry[RM_OBJECTS],
+ rmap_entry[RM_NAME].lower()
+ if not case_sensitive
+ else rmap_entry[RM_NAME],
+ rmap_entry[RM_RENDERED_NAME],
+ context.get_result_processor(
+ rmap_entry[RM_TYPE],
+ rmap_entry[RM_RENDERED_NAME],
+ cursor_description[idx][1],
+ ),
+ None,
+ )
+ for idx, rmap_entry in enumerate(result_columns)
+ ]
+ else:
+
+ # name-based or text-positional cases, where we need
+ # to read cursor.description names
+
+ if textual_ordered:
+ self._safe_for_cache = True
+ # textual positional case
+ raw_iterator = self._merge_textual_cols_by_position(
+ context, cursor_description, result_columns
+ )
+ elif num_ctx_cols:
+ # compiled SQL with a mismatch of description cols
+ # vs. compiled cols, or textual w/ unordered columns
+ # the order of columns can change if the query is
+ # against a "select *", so not safe to cache
+ self._safe_for_cache = False
+ raw_iterator = self._merge_cols_by_name(
+ context,
+ cursor_description,
+ result_columns,
+ loose_column_name_matching,
+ )
+ else:
+ # no compiled SQL, just a raw string, order of columns
+ # can change for "select *"
+ self._safe_for_cache = False
+ raw_iterator = self._merge_cols_by_none(
+ context, cursor_description
+ )
+
+ return [
+ (
+ idx,
+ ridx,
+ obj,
+ cursor_colname,
+ cursor_colname,
+ context.get_result_processor(
+ mapped_type, cursor_colname, coltype
+ ),
+ untranslated,
+ )
+ for (
+ idx,
+ ridx,
+ cursor_colname,
+ mapped_type,
+ coltype,
+ obj,
+ untranslated,
+ ) in raw_iterator
+ ]
+
+ def _colnames_from_description(self, context, cursor_description):
+ """Extract column names and data types from a cursor.description.
+
+ Applies unicode decoding, column translation, "normalization",
+ and case sensitivity rules to the names based on the dialect.
+
+ """
+
+ dialect = context.dialect
+ case_sensitive = dialect.case_sensitive
+ translate_colname = context._translate_colname
+ description_decoder = (
+ dialect._description_decoder
+ if dialect.description_encoding
+ else None
+ )
+ normalize_name = (
+ dialect.normalize_name if dialect.requires_name_normalize else None
+ )
+ untranslated = None
+
+ self._keys = []
+
+ for idx, rec in enumerate(cursor_description):
+ colname = rec[0]
+ coltype = rec[1]
+
+ if description_decoder:
+ colname = description_decoder(colname)
+
+ if translate_colname:
+ colname, untranslated = translate_colname(colname)
+
+ if normalize_name:
+ colname = normalize_name(colname)
+
+ self._keys.append(colname)
+ if not case_sensitive:
+ colname = colname.lower()
+
+ yield idx, colname, untranslated, coltype
+
+ def _merge_textual_cols_by_position(
+ self, context, cursor_description, result_columns
+ ):
+ num_ctx_cols = len(result_columns) if result_columns else None
+
+ if num_ctx_cols > len(cursor_description):
+ util.warn(
+ "Number of columns in textual SQL (%d) is "
+ "smaller than number of columns requested (%d)"
+ % (num_ctx_cols, len(cursor_description))
+ )
+ seen = set()
+ for (
+ idx,
+ colname,
+ untranslated,
+ coltype,
+ ) in self._colnames_from_description(context, cursor_description):
+ if idx < num_ctx_cols:
+ ctx_rec = result_columns[idx]
+ obj = ctx_rec[RM_OBJECTS]
+ ridx = idx
+ mapped_type = ctx_rec[RM_TYPE]
+ if obj[0] in seen:
+ raise exc.InvalidRequestError(
+ "Duplicate column expression requested "
+ "in textual SQL: %r" % obj[0]
+ )
+ seen.add(obj[0])
+ else:
+ mapped_type = sqltypes.NULLTYPE
+ obj = None
+ ridx = None
+ yield idx, ridx, colname, mapped_type, coltype, obj, untranslated
+
+ def _merge_cols_by_name(
+ self,
+ context,
+ cursor_description,
+ result_columns,
+ loose_column_name_matching,
+ ):
+ dialect = context.dialect
+ case_sensitive = dialect.case_sensitive
+ match_map = self._create_description_match_map(
+ result_columns, case_sensitive, loose_column_name_matching
+ )
+ for (
+ idx,
+ colname,
+ untranslated,
+ coltype,
+ ) in self._colnames_from_description(context, cursor_description):
+ try:
+ ctx_rec = match_map[colname]
+ except KeyError:
+ mapped_type = sqltypes.NULLTYPE
+ obj = None
+ result_columns_idx = None
+ else:
+ obj = ctx_rec[1]
+ mapped_type = ctx_rec[2]
+ result_columns_idx = ctx_rec[3]
+ yield (
+ idx,
+ result_columns_idx,
+ colname,
+ mapped_type,
+ coltype,
+ obj,
+ untranslated,
+ )
+
+ @classmethod
+ def _create_description_match_map(
+ cls,
+ result_columns,
+ case_sensitive=True,
+ loose_column_name_matching=False,
+ ):
+ """when matching cursor.description to a set of names that are present
+ in a Compiled object, as is the case with TextualSelect, get all the
+ names we expect might match those in cursor.description.
+ """
+
+ d = {}
+ for ridx, elem in enumerate(result_columns):
+ key = elem[RM_RENDERED_NAME]
+
+ if not case_sensitive:
+ key = key.lower()
+ if key in d:
+ # conflicting keyname - just add the column-linked objects
+ # to the existing record. if there is a duplicate column
+ # name in the cursor description, this will allow all of those
+ # objects to raise an ambiguous column error
+ e_name, e_obj, e_type, e_ridx = d[key]
+ d[key] = e_name, e_obj + elem[RM_OBJECTS], e_type, ridx
+ else:
+ d[key] = (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx)
+
+ if loose_column_name_matching:
+ # when using a textual statement with an unordered set
+ # of columns that line up, we are expecting the user
+ # to be using label names in the SQL that match to the column
+ # expressions. Enable more liberal matching for this case;
+ # duplicate keys that are ambiguous will be fixed later.
+ for r_key in elem[RM_OBJECTS]:
+ d.setdefault(
+ r_key,
+ (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx),
+ )
+
+ return d
+
+ def _merge_cols_by_none(self, context, cursor_description):
+ for (
+ idx,
+ colname,
+ untranslated,
+ coltype,
+ ) in self._colnames_from_description(context, cursor_description):
+ yield (
+ idx,
+ None,
+ colname,
+ sqltypes.NULLTYPE,
+ coltype,
+ None,
+ untranslated,
+ )
+
+ def _key_fallback(self, key, err, raiseerr=True):
+ if raiseerr:
+ util.raise_(
+ exc.NoSuchColumnError(
+ "Could not locate column in row for column '%s'"
+ % util.string_or_unprintable(key)
+ ),
+ replace_context=err,
+ )
+ else:
+ return None
+
+ def _raise_for_ambiguous_column_name(self, rec):
+ raise exc.InvalidRequestError(
+ "Ambiguous column name '%s' in "
+ "result set column descriptions" % rec[MD_LOOKUP_KEY]
+ )
+
+ def _index_for_key(self, key, raiseerr=True):
+ # TODO: can consider pre-loading ints and negative ints
+ # into _keymap - also no coverage here
+ if isinstance(key, int):
+ key = self._keys[key]
+
+ try:
+ rec = self._keymap[key]
+ except KeyError as ke:
+ rec = self._key_fallback(key, ke, raiseerr)
+ if rec is None:
+ return None
+
+ index = rec[0]
+
+ if index is None:
+ self._raise_for_ambiguous_column_name(rec)
+ return index
+
+ def _indexes_for_keys(self, keys):
+
+ try:
+ return [self._keymap[key][0] for key in keys]
+ except KeyError as ke:
+ # ensure it raises
+ CursorResultMetaData._key_fallback(self, ke.args[0], ke)
+
+ def _metadata_for_keys(self, keys):
+ for key in keys:
+ if int in key.__class__.__mro__:
+ key = self._keys[key]
+
+ try:
+ rec = self._keymap[key]
+ except KeyError as ke:
+ # ensure it raises
+ CursorResultMetaData._key_fallback(self, ke.args[0], ke)
+
+ index = rec[0]
+
+ if index is None:
+ self._raise_for_ambiguous_column_name(rec)
+
+ yield rec
+
+ def __getstate__(self):
+ return {
+ "_keymap": {
+ key: (rec[MD_INDEX], rec[MD_RESULT_MAP_INDEX], _UNPICKLED, key)
+ for key, rec in self._keymap.items()
+ if isinstance(key, util.string_types + util.int_types)
+ },
+ "_keys": self._keys,
+ "case_sensitive": self.case_sensitive,
+ "_translated_indexes": self._translated_indexes,
+ "_tuplefilter": self._tuplefilter,
+ }
+
+ def __setstate__(self, state):
+ self._processors = [None for _ in range(len(state["_keys"]))]
+ self._keymap = state["_keymap"]
+
+ self._keymap_by_result_column_idx = {
+ rec[MD_RESULT_MAP_INDEX]: rec for rec in self._keymap.values()
+ }
+ self._keys = state["_keys"]
+ self.case_sensitive = state["case_sensitive"]
+
+ if state["_translated_indexes"]:
+ self._translated_indexes = state["_translated_indexes"]
+ self._tuplefilter = tuplegetter(*self._translated_indexes)
+ else:
+ self._translated_indexes = self._tuplefilter = None
+
+
+class LegacyCursorResultMetaData(CursorResultMetaData):
+ __slots__ = ()
+
+ def _contains(self, value, row):
+ key = value
+ if key in self._keymap:
+ util.warn_deprecated_20(
+ "Using the 'in' operator to test for string or column "
+ "keys, or integer indexes, in a :class:`.Row` object is "
+ "deprecated and will "
+ "be removed in a future release. "
+ "Use the `Row._fields` or `Row._mapping` attribute, i.e. "
+ "'key in row._fields'",
+ )
+ return True
+ else:
+ return self._key_fallback(key, None, False) is not None
+
+ def _key_fallback(self, key, err, raiseerr=True):
+ map_ = self._keymap
+ result = None
+
+ if isinstance(key, util.string_types):
+ result = map_.get(key if self.case_sensitive else key.lower())
+ elif isinstance(key, expression.ColumnElement):
+ if (
+ key._tq_label
+ and (
+ key._tq_label
+ if self.case_sensitive
+ else key._tq_label.lower()
+ )
+ in map_
+ ):
+ result = map_[
+ key._tq_label
+ if self.case_sensitive
+ else key._tq_label.lower()
+ ]
+ elif (
+ hasattr(key, "name")
+ and (key.name if self.case_sensitive else key.name.lower())
+ in map_
+ ):
+ # match is only on name.
+ result = map_[
+ key.name if self.case_sensitive else key.name.lower()
+ ]
+
+ # search extra hard to make sure this
+ # isn't a column/label name overlap.
+ # this check isn't currently available if the row
+ # was unpickled.
+ if result is not None and result[MD_OBJECTS] not in (
+ None,
+ _UNPICKLED,
+ ):
+ for obj in result[MD_OBJECTS]:
+ if key._compare_name_for_result(obj):
+ break
+ else:
+ result = None
+ if result is not None:
+ if result[MD_OBJECTS] is _UNPICKLED:
+ util.warn_deprecated(
+ "Retrieving row values using Column objects from a "
+ "row that was unpickled is deprecated; adequate "
+ "state cannot be pickled for this to be efficient. "
+ "This usage will raise KeyError in a future release.",
+ version="1.4",
+ )
+ else:
+ util.warn_deprecated(
+ "Retrieving row values using Column objects with only "
+ "matching names as keys is deprecated, and will raise "
+ "KeyError in a future release; only Column "
+ "objects that are explicitly part of the statement "
+ "object should be used.",
+ version="1.4",
+ )
+ if result is None:
+ if raiseerr:
+ util.raise_(
+ exc.NoSuchColumnError(
+ "Could not locate column in row for column '%s'"
+ % util.string_or_unprintable(key)
+ ),
+ replace_context=err,
+ )
+ else:
+ return None
+ else:
+ map_[key] = result
+ return result
+
+ def _warn_for_nonint(self, key):
+ util.warn_deprecated_20(
+ "Using non-integer/slice indices on Row is deprecated and will "
+ "be removed in version 2.0; please use row._mapping[<key>], or "
+ "the mappings() accessor on the Result object.",
+ stacklevel=4,
+ )
+
+ def _has_key(self, key):
+ if key in self._keymap:
+ return True
+ else:
+ return self._key_fallback(key, None, False) is not None
+
+
+class ResultFetchStrategy(object):
+ """Define a fetching strategy for a result object.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ __slots__ = ()
+
+ alternate_cursor_description = None
+
+ def soft_close(self, result, dbapi_cursor):
+ raise NotImplementedError()
+
+ def hard_close(self, result, dbapi_cursor):
+ raise NotImplementedError()
+
+ def yield_per(self, result, dbapi_cursor, num):
+ return
+
+ def fetchone(self, result, dbapi_cursor, hard_close=False):
+ raise NotImplementedError()
+
+ def fetchmany(self, result, dbapi_cursor, size=None):
+ raise NotImplementedError()
+
+ def fetchall(self, result, dbapi_cursor):
+ raise NotImplementedError()
+
+ def handle_exception(self, result, dbapi_cursor, err):
+ raise err
+
+
+class NoCursorFetchStrategy(ResultFetchStrategy):
+ """Cursor strategy for a result that has no open cursor.
+
+ There are two varieties of this strategy, one for DQL and one for
+ DML (and also DDL), each of which represent a result that had a cursor
+ but no longer has one.
+
+ """
+
+ __slots__ = ()
+
+ def soft_close(self, result, dbapi_cursor):
+ pass
+
+ def hard_close(self, result, dbapi_cursor):
+ pass
+
+ def fetchone(self, result, dbapi_cursor, hard_close=False):
+ return self._non_result(result, None)
+
+ def fetchmany(self, result, dbapi_cursor, size=None):
+ return self._non_result(result, [])
+
+ def fetchall(self, result, dbapi_cursor):
+ return self._non_result(result, [])
+
+ def _non_result(self, result, default, err=None):
+ raise NotImplementedError()
+
+
+class NoCursorDQLFetchStrategy(NoCursorFetchStrategy):
+ """Cursor strategy for a DQL result that has no open cursor.
+
+ This is a result set that can return rows, i.e. for a SELECT, or for an
+ INSERT, UPDATE, DELETE that includes RETURNING. However it is in the state
+ where the cursor is closed and no rows remain available. The owning result
+ object may or may not be "hard closed", which determines if the fetch
+ methods send empty results or raise for closed result.
+
+ """
+
+ __slots__ = ()
+
+ def _non_result(self, result, default, err=None):
+ if result.closed:
+ util.raise_(
+ exc.ResourceClosedError("This result object is closed."),
+ replace_context=err,
+ )
+ else:
+ return default
+
+
+_NO_CURSOR_DQL = NoCursorDQLFetchStrategy()
+
+
+class NoCursorDMLFetchStrategy(NoCursorFetchStrategy):
+ """Cursor strategy for a DML result that has no open cursor.
+
+ This is a result set that does not return rows, i.e. for an INSERT,
+ UPDATE, DELETE that does not include RETURNING.
+
+ """
+
+ __slots__ = ()
+
+ def _non_result(self, result, default, err=None):
+ # we only expect to have a _NoResultMetaData() here right now.
+ assert not result._metadata.returns_rows
+ result._metadata._we_dont_return_rows(err)
+
+
+_NO_CURSOR_DML = NoCursorDMLFetchStrategy()
+
+
+class CursorFetchStrategy(ResultFetchStrategy):
+ """Call fetch methods from a DBAPI cursor.
+
+ Alternate versions of this class may instead buffer the rows from
+ cursors or not use cursors at all.
+
+ """
+
+ __slots__ = ()
+
+ def soft_close(self, result, dbapi_cursor):
+ result.cursor_strategy = _NO_CURSOR_DQL
+
+ def hard_close(self, result, dbapi_cursor):
+ result.cursor_strategy = _NO_CURSOR_DQL
+
+ def handle_exception(self, result, dbapi_cursor, err):
+ result.connection._handle_dbapi_exception(
+ err, None, None, dbapi_cursor, result.context
+ )
+
+ def yield_per(self, result, dbapi_cursor, num):
+ result.cursor_strategy = BufferedRowCursorFetchStrategy(
+ dbapi_cursor,
+ {"max_row_buffer": num},
+ initial_buffer=collections.deque(),
+ growth_factor=0,
+ )
+
+ def fetchone(self, result, dbapi_cursor, hard_close=False):
+ try:
+ row = dbapi_cursor.fetchone()
+ if row is None:
+ result._soft_close(hard=hard_close)
+ return row
+ except BaseException as e:
+ self.handle_exception(result, dbapi_cursor, e)
+
+ def fetchmany(self, result, dbapi_cursor, size=None):
+ try:
+ if size is None:
+ l = dbapi_cursor.fetchmany()
+ else:
+ l = dbapi_cursor.fetchmany(size)
+
+ if not l:
+ result._soft_close()
+ return l
+ except BaseException as e:
+ self.handle_exception(result, dbapi_cursor, e)
+
+ def fetchall(self, result, dbapi_cursor):
+ try:
+ rows = dbapi_cursor.fetchall()
+ result._soft_close()
+ return rows
+ except BaseException as e:
+ self.handle_exception(result, dbapi_cursor, e)
+
+
+_DEFAULT_FETCH = CursorFetchStrategy()
+
+
+class BufferedRowCursorFetchStrategy(CursorFetchStrategy):
+ """A cursor fetch strategy with row buffering behavior.
+
+ This strategy buffers the contents of a selection of rows
+ before ``fetchone()`` is called. This is to allow the results of
+ ``cursor.description`` to be available immediately, when
+ interfacing with a DB-API that requires rows to be consumed before
+ this information is available (currently psycopg2, when used with
+ server-side cursors).
+
+ The pre-fetching behavior fetches only one row initially, and then
+ grows its buffer size by a fixed amount with each successive need
+ for additional rows up to the ``max_row_buffer`` size, which defaults
+ to 1000::
+
+ with psycopg2_engine.connect() as conn:
+
+ result = conn.execution_options(
+ stream_results=True, max_row_buffer=50
+ ).execute(text("select * from table"))
+
+ .. versionadded:: 1.4 ``max_row_buffer`` may now exceed 1000 rows.
+
+ .. seealso::
+
+ :ref:`psycopg2_execution_options`
+ """
+
+ __slots__ = ("_max_row_buffer", "_rowbuffer", "_bufsize", "_growth_factor")
+
+ def __init__(
+ self,
+ dbapi_cursor,
+ execution_options,
+ growth_factor=5,
+ initial_buffer=None,
+ ):
+ self._max_row_buffer = execution_options.get("max_row_buffer", 1000)
+
+ if initial_buffer is not None:
+ self._rowbuffer = initial_buffer
+ else:
+ self._rowbuffer = collections.deque(dbapi_cursor.fetchmany(1))
+ self._growth_factor = growth_factor
+
+ if growth_factor:
+ self._bufsize = min(self._max_row_buffer, self._growth_factor)
+ else:
+ self._bufsize = self._max_row_buffer
+
+ @classmethod
+ def create(cls, result):
+ return BufferedRowCursorFetchStrategy(
+ result.cursor,
+ result.context.execution_options,
+ )
+
+ def _buffer_rows(self, result, dbapi_cursor):
+ """this is currently used only by fetchone()."""
+
+ size = self._bufsize
+ try:
+ if size < 1:
+ new_rows = dbapi_cursor.fetchall()
+ else:
+ new_rows = dbapi_cursor.fetchmany(size)
+ except BaseException as e:
+ self.handle_exception(result, dbapi_cursor, e)
+
+ if not new_rows:
+ return
+ self._rowbuffer = collections.deque(new_rows)
+ if self._growth_factor and size < self._max_row_buffer:
+ self._bufsize = min(
+ self._max_row_buffer, size * self._growth_factor
+ )
+
+ def yield_per(self, result, dbapi_cursor, num):
+ self._growth_factor = 0
+ self._max_row_buffer = self._bufsize = num
+
+ def soft_close(self, result, dbapi_cursor):
+ self._rowbuffer.clear()
+ super(BufferedRowCursorFetchStrategy, self).soft_close(
+ result, dbapi_cursor
+ )
+
+ def hard_close(self, result, dbapi_cursor):
+ self._rowbuffer.clear()
+ super(BufferedRowCursorFetchStrategy, self).hard_close(
+ result, dbapi_cursor
+ )
+
+ def fetchone(self, result, dbapi_cursor, hard_close=False):
+ if not self._rowbuffer:
+ self._buffer_rows(result, dbapi_cursor)
+ if not self._rowbuffer:
+ try:
+ result._soft_close(hard=hard_close)
+ except BaseException as e:
+ self.handle_exception(result, dbapi_cursor, e)
+ return None
+ return self._rowbuffer.popleft()
+
+ def fetchmany(self, result, dbapi_cursor, size=None):
+ if size is None:
+ return self.fetchall(result, dbapi_cursor)
+
+ buf = list(self._rowbuffer)
+ lb = len(buf)
+ if size > lb:
+ try:
+ new = dbapi_cursor.fetchmany(size - lb)
+ except BaseException as e:
+ self.handle_exception(result, dbapi_cursor, e)
+ else:
+ if not new:
+ result._soft_close()
+ else:
+ buf.extend(new)
+
+ result = buf[0:size]
+ self._rowbuffer = collections.deque(buf[size:])
+ return result
+
+ def fetchall(self, result, dbapi_cursor):
+ try:
+ ret = list(self._rowbuffer) + list(dbapi_cursor.fetchall())
+ self._rowbuffer.clear()
+ result._soft_close()
+ return ret
+ except BaseException as e:
+ self.handle_exception(result, dbapi_cursor, e)
+
+
+class FullyBufferedCursorFetchStrategy(CursorFetchStrategy):
+ """A cursor strategy that buffers rows fully upon creation.
+
+ Used for operations where a result is to be delivered
+ after the database conversation cannot be continued,
+ such as MSSQL INSERT...OUTPUT after an autocommit.
+
+ """
+
+ __slots__ = ("_rowbuffer", "alternate_cursor_description")
+
+ def __init__(
+ self, dbapi_cursor, alternate_description=None, initial_buffer=None
+ ):
+ self.alternate_cursor_description = alternate_description
+ if initial_buffer is not None:
+ self._rowbuffer = collections.deque(initial_buffer)
+ else:
+ self._rowbuffer = collections.deque(dbapi_cursor.fetchall())
+
+ def yield_per(self, result, dbapi_cursor, num):
+ pass
+
+ def soft_close(self, result, dbapi_cursor):
+ self._rowbuffer.clear()
+ super(FullyBufferedCursorFetchStrategy, self).soft_close(
+ result, dbapi_cursor
+ )
+
+ def hard_close(self, result, dbapi_cursor):
+ self._rowbuffer.clear()
+ super(FullyBufferedCursorFetchStrategy, self).hard_close(
+ result, dbapi_cursor
+ )
+
+ def fetchone(self, result, dbapi_cursor, hard_close=False):
+ if self._rowbuffer:
+ return self._rowbuffer.popleft()
+ else:
+ result._soft_close(hard=hard_close)
+ return None
+
+ def fetchmany(self, result, dbapi_cursor, size=None):
+ if size is None:
+ return self.fetchall(result, dbapi_cursor)
+
+ buf = list(self._rowbuffer)
+ rows = buf[0:size]
+ self._rowbuffer = collections.deque(buf[size:])
+ if not rows:
+ result._soft_close()
+ return rows
+
+ def fetchall(self, result, dbapi_cursor):
+ ret = self._rowbuffer
+ self._rowbuffer = collections.deque()
+ result._soft_close()
+ return ret
+
+
+class _NoResultMetaData(ResultMetaData):
+ __slots__ = ()
+
+ returns_rows = False
+
+ def _we_dont_return_rows(self, err=None):
+ util.raise_(
+ exc.ResourceClosedError(
+ "This result object does not return rows. "
+ "It has been closed automatically."
+ ),
+ replace_context=err,
+ )
+
+ def _index_for_key(self, keys, raiseerr):
+ self._we_dont_return_rows()
+
+ def _metadata_for_keys(self, key):
+ self._we_dont_return_rows()
+
+ def _reduce(self, keys):
+ self._we_dont_return_rows()
+
+ @property
+ def _keymap(self):
+ self._we_dont_return_rows()
+
+ @property
+ def keys(self):
+ self._we_dont_return_rows()
+
+
+class _LegacyNoResultMetaData(_NoResultMetaData):
+ @property
+ def keys(self):
+ util.warn_deprecated_20(
+ "Calling the .keys() method on a result set that does not return "
+ "rows is deprecated and will raise ResourceClosedError in "
+ "SQLAlchemy 2.0.",
+ )
+ return []
+
+
+_NO_RESULT_METADATA = _NoResultMetaData()
+_LEGACY_NO_RESULT_METADATA = _LegacyNoResultMetaData()
+
+
+class BaseCursorResult(object):
+ """Base class for database result objects."""
+
+ out_parameters = None
+ _metadata = None
+ _soft_closed = False
+ closed = False
+
+ def __init__(self, context, cursor_strategy, cursor_description):
+ self.context = context
+ self.dialect = context.dialect
+ self.cursor = context.cursor
+ self.cursor_strategy = cursor_strategy
+ self.connection = context.root_connection
+ self._echo = echo = (
+ self.connection._echo and context.engine._should_log_debug()
+ )
+
+ if cursor_description is not None:
+ # inline of Result._row_getter(), set up an initial row
+ # getter assuming no transformations will be called as this
+ # is the most common case
+
+ if echo:
+ log = self.context.connection._log_debug
+
+ def log_row(row):
+ log("Row %r", sql_util._repr_row(row))
+ return row
+
+ self._row_logging_fn = log_row
+ else:
+ log_row = None
+
+ metadata = self._init_metadata(context, cursor_description)
+
+ keymap = metadata._keymap
+ processors = metadata._processors
+ process_row = self._process_row
+ key_style = process_row._default_key_style
+ _make_row = functools.partial(
+ process_row, metadata, processors, keymap, key_style
+ )
+ if log_row:
+
+ def make_row(row):
+ made_row = _make_row(row)
+ log_row(made_row)
+ return made_row
+
+ else:
+ make_row = _make_row
+ self._set_memoized_attribute("_row_getter", make_row)
+
+ else:
+ self._metadata = self._no_result_metadata
+
+ def _init_metadata(self, context, cursor_description):
+
+ if context.compiled:
+ if context.compiled._cached_metadata:
+ metadata = self.context.compiled._cached_metadata
+ else:
+ metadata = self._cursor_metadata(self, cursor_description)
+ if metadata._safe_for_cache:
+ context.compiled._cached_metadata = metadata
+
+ # result rewrite/ adapt step. this is to suit the case
+ # when we are invoked against a cached Compiled object, we want
+ # to rewrite the ResultMetaData to reflect the Column objects
+ # that are in our current SQL statement object, not the one
+ # that is associated with the cached Compiled object.
+ # the Compiled object may also tell us to not
+ # actually do this step; this is to support the ORM where
+ # it is to produce a new Result object in any case, and will
+ # be using the cached Column objects against this database result
+ # so we don't want to rewrite them.
+ #
+ # Basically this step suits the use case where the end user
+ # is using Core SQL expressions and is accessing columns in the
+ # result row using row._mapping[table.c.column].
+ compiled = context.compiled
+ if (
+ compiled
+ and compiled._result_columns
+ and context.cache_hit is context.dialect.CACHE_HIT
+ and not context.execution_options.get(
+ "_result_disable_adapt_to_context", False
+ )
+ and compiled.statement is not context.invoked_statement
+ ):
+ metadata = metadata._adapt_to_context(context)
+
+ self._metadata = metadata
+
+ else:
+ self._metadata = metadata = self._cursor_metadata(
+ self, cursor_description
+ )
+ if self._echo:
+ context.connection._log_debug(
+ "Col %r", tuple(x[0] for x in cursor_description)
+ )
+ return metadata
+
+ def _soft_close(self, hard=False):
+ """Soft close this :class:`_engine.CursorResult`.
+
+ This releases all DBAPI cursor resources, but leaves the
+ CursorResult "open" from a semantic perspective, meaning the
+ fetchXXX() methods will continue to return empty results.
+
+ This method is called automatically when:
+
+ * all result rows are exhausted using the fetchXXX() methods.
+ * cursor.description is None.
+
+ This method is **not public**, but is documented in order to clarify
+ the "autoclose" process used.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`_engine.CursorResult.close`
+
+
+ """
+ if (not hard and self._soft_closed) or (hard and self.closed):
+ return
+
+ if hard:
+ self.closed = True
+ self.cursor_strategy.hard_close(self, self.cursor)
+ else:
+ self.cursor_strategy.soft_close(self, self.cursor)
+
+ if not self._soft_closed:
+ cursor = self.cursor
+ self.cursor = None
+ self.connection._safe_close_cursor(cursor)
+ self._soft_closed = True
+
+ @property
+ def inserted_primary_key_rows(self):
+ """Return the value of
+ :attr:`_engine.CursorResult.inserted_primary_key`
+ as a row contained within a list; some dialects may support a
+ multiple row form as well.
+
+ .. note:: As indicated below, in current SQLAlchemy versions this
+ accessor is only useful beyond what's already supplied by
+ :attr:`_engine.CursorResult.inserted_primary_key` when using the
+ :ref:`postgresql_psycopg2` dialect. Future versions hope to
+ generalize this feature to more dialects.
+
+ This accessor is added to support dialects that offer the feature
+ that is currently implemented by the :ref:`psycopg2_executemany_mode`
+ feature, currently **only the psycopg2 dialect**, which provides
+ for many rows to be INSERTed at once while still retaining the
+ behavior of being able to return server-generated primary key values.
+
+ * **When using the psycopg2 dialect, or other dialects that may support
+ "fast executemany" style inserts in upcoming releases** : When
+ invoking an INSERT statement while passing a list of rows as the
+ second argument to :meth:`_engine.Connection.execute`, this accessor
+ will then provide a list of rows, where each row contains the primary
+ key value for each row that was INSERTed.
+
+ * **When using all other dialects / backends that don't yet support
+ this feature**: This accessor is only useful for **single row INSERT
+ statements**, and returns the same information as that of the
+ :attr:`_engine.CursorResult.inserted_primary_key` within a
+ single-element list. When an INSERT statement is executed in
+ conjunction with a list of rows to be INSERTed, the list will contain
+ one row per row inserted in the statement, however it will contain
+ ``None`` for any server-generated values.
+
+ Future releases of SQLAlchemy will further generalize the
+ "fast execution helper" feature of psycopg2 to suit other dialects,
+ thus allowing this accessor to be of more general use.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_engine.CursorResult.inserted_primary_key`
+
+ """
+ if not self.context.compiled:
+ raise exc.InvalidRequestError(
+ "Statement is not a compiled " "expression construct."
+ )
+ elif not self.context.isinsert:
+ raise exc.InvalidRequestError(
+ "Statement is not an insert() " "expression construct."
+ )
+ elif self.context._is_explicit_returning:
+ raise exc.InvalidRequestError(
+ "Can't call inserted_primary_key "
+ "when returning() "
+ "is used."
+ )
+ return self.context.inserted_primary_key_rows
+
+ @property
+ def inserted_primary_key(self):
+ """Return the primary key for the row just inserted.
+
+ The return value is a :class:`_result.Row` object representing
+ a named tuple of primary key values in the order in which the
+ primary key columns are configured in the source
+ :class:`_schema.Table`.
+
+ .. versionchanged:: 1.4.8 - the
+ :attr:`_engine.CursorResult.inserted_primary_key`
+ value is now a named tuple via the :class:`_result.Row` class,
+ rather than a plain tuple.
+
+ This accessor only applies to single row :func:`_expression.insert`
+ constructs which did not explicitly specify
+ :meth:`_expression.Insert.returning`. Support for multirow inserts,
+ while not yet available for most backends, would be accessed using
+ the :attr:`_engine.CursorResult.inserted_primary_key_rows` accessor.
+
+ Note that primary key columns which specify a server_default clause, or
+ otherwise do not qualify as "autoincrement" columns (see the notes at
+ :class:`_schema.Column`), and were generated using the database-side
+ default, will appear in this list as ``None`` unless the backend
+ supports "returning" and the insert statement was executed with
+ "implicit returning" enabled.
+
+ Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed
+ statement is not a compiled expression construct
+ or is not an insert() construct.
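+
+        A brief sketch, assuming a :class:`_schema.Table` named
+        ``user_table`` with a ``name`` column and an autoincrementing
+        primary key column::
+
+            with engine.connect() as conn:
+                result = conn.execute(
+                    user_table.insert().values(name="some name")
+                )
+                pk = result.inserted_primary_key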
+
+ """
+
+ if self.context.executemany:
+ raise exc.InvalidRequestError(
+ "This statement was an executemany call; if primary key "
+ "returning is supported, please "
+ "use .inserted_primary_key_rows."
+ )
+
+ ikp = self.inserted_primary_key_rows
+ if ikp:
+ return ikp[0]
+ else:
+ return None
+
+ def last_updated_params(self):
+ """Return the collection of updated parameters from this
+ execution.
+
+ Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed
+ statement is not a compiled expression construct
+ or is not an update() construct.
+
+ """
+ if not self.context.compiled:
+ raise exc.InvalidRequestError(
+ "Statement is not a compiled " "expression construct."
+ )
+ elif not self.context.isupdate:
+ raise exc.InvalidRequestError(
+ "Statement is not an update() " "expression construct."
+ )
+ elif self.context.executemany:
+ return self.context.compiled_parameters
+ else:
+ return self.context.compiled_parameters[0]
+
+ def last_inserted_params(self):
+ """Return the collection of inserted parameters from this
+ execution.
+
+ Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed
+ statement is not a compiled expression construct
+ or is not an insert() construct.
+
+ """
+ if not self.context.compiled:
+ raise exc.InvalidRequestError(
+ "Statement is not a compiled " "expression construct."
+ )
+ elif not self.context.isinsert:
+ raise exc.InvalidRequestError(
+ "Statement is not an insert() " "expression construct."
+ )
+ elif self.context.executemany:
+ return self.context.compiled_parameters
+ else:
+ return self.context.compiled_parameters[0]
+
+ @property
+ def returned_defaults_rows(self):
+ """Return a list of rows each containing the values of default
+ columns that were fetched using
+ the :meth:`.ValuesBase.return_defaults` feature.
+
+ The return value is a list of :class:`.Row` objects.
+
+ .. versionadded:: 1.4
+
+ """
+ return self.context.returned_default_rows
+
+ @property
+ def returned_defaults(self):
+ """Return the values of default columns that were fetched using
+ the :meth:`.ValuesBase.return_defaults` feature.
+
+ The value is an instance of :class:`.Row`, or ``None``
+ if :meth:`.ValuesBase.return_defaults` was not used or if the
+ backend does not support RETURNING.
+
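+ For example, a sketch using :meth:`.ValuesBase.return_defaults` on a
+ backend that supports RETURNING (the table and column names are
+ illustrative)::
+
+     stmt = some_table.insert().values(data="some data").return_defaults()
+     result = conn.execute(stmt)
+     created_at = result.returned_defaults._mapping["created_at"]
+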
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :meth:`.ValuesBase.return_defaults`
+
+ """
+
+ if self.context.executemany:
+ raise exc.InvalidRequestError(
+ "This statement was an executemany call; if return defaults "
+ "is supported, please use .returned_defaults_rows."
+ )
+
+ rows = self.context.returned_default_rows
+ if rows:
+ return rows[0]
+ else:
+ return None
+
+ def lastrow_has_defaults(self):
+ """Return ``lastrow_has_defaults()`` from the underlying
+ :class:`.ExecutionContext`.
+
+ See :class:`.ExecutionContext` for details.
+
+ """
+
+ return self.context.lastrow_has_defaults()
+
+ def postfetch_cols(self):
+ """Return ``postfetch_cols()`` from the underlying
+ :class:`.ExecutionContext`.
+
+ See :class:`.ExecutionContext` for details.
+
+ Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed
+ statement is not a compiled expression construct
+ or is not an insert() or update() construct.
+
+ """
+
+ if not self.context.compiled:
+ raise exc.InvalidRequestError(
+ "Statement is not a compiled " "expression construct."
+ )
+ elif not self.context.isinsert and not self.context.isupdate:
+ raise exc.InvalidRequestError(
+ "Statement is not an insert() or update() "
+ "expression construct."
+ )
+ return self.context.postfetch_cols
+
+ def prefetch_cols(self):
+ """Return ``prefetch_cols()`` from the underlying
+ :class:`.ExecutionContext`.
+
+ See :class:`.ExecutionContext` for details.
+
+ Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed
+ statement is not a compiled expression construct
+ or is not an insert() or update() construct.
+
+ """
+
+ if not self.context.compiled:
+ raise exc.InvalidRequestError(
+ "Statement is not a compiled " "expression construct."
+ )
+ elif not self.context.isinsert and not self.context.isupdate:
+ raise exc.InvalidRequestError(
+ "Statement is not an insert() or update() "
+ "expression construct."
+ )
+ return self.context.prefetch_cols
+
+ def supports_sane_rowcount(self):
+ """Return ``supports_sane_rowcount`` from the dialect.
+
+ See :attr:`_engine.CursorResult.rowcount` for background.
+
+ """
+
+ return self.dialect.supports_sane_rowcount
+
+ def supports_sane_multi_rowcount(self):
+ """Return ``supports_sane_multi_rowcount`` from the dialect.
+
+ See :attr:`_engine.CursorResult.rowcount` for background.
+
+ """
+
+ return self.dialect.supports_sane_multi_rowcount
+
+ @util.memoized_property
+ def rowcount(self):
+ """Return the 'rowcount' for this result.
+
+ The 'rowcount' reports the number of rows *matched*
+ by the WHERE criterion of an UPDATE or DELETE statement.
+
+ .. note::
+
+ Notes regarding :attr:`_engine.CursorResult.rowcount`:
+
+
+ * This attribute returns the number of rows *matched*,
+ which is not necessarily the same as the number of rows
+ that were actually *modified* - an UPDATE statement, for example,
+ may have no net change on a given row if the SET values
+ given are the same as those present in the row already.
+ Such a row would be matched but not modified.
+ On backends that feature both styles, such as MySQL,
+ rowcount is configured by default to return the match
+ count in all cases.
+
+ * :attr:`_engine.CursorResult.rowcount`
+ is *only* useful in conjunction
+ with an UPDATE or DELETE statement. Contrary to what the Python
+ DBAPI says, it does *not* return the
+ number of rows available from the results of a SELECT statement
+ as DBAPIs cannot support this functionality when rows are
+ unbuffered.
+
+ * :attr:`_engine.CursorResult.rowcount`
+ may not be fully implemented by
+ all dialects. In particular, most DBAPIs do not support an
+ aggregate rowcount result from an executemany call.
+ The :meth:`_engine.CursorResult.supports_sane_rowcount` and
+ :meth:`_engine.CursorResult.supports_sane_multi_rowcount` methods
+ will report from the dialect if each usage is known to be
+ supported.
+
+ * Statements that use RETURNING may not return a correct
+ rowcount.
+
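+ For example, a sketch of the UPDATE-oriented usage this attribute is
+ intended for (table and column names illustrative)::
+
+     result = conn.execute(
+         some_table.update()
+         .where(some_table.c.name == "old name")
+         .values(name="new name")
+     )
+     matched = result.rowcount  # rows matched by the WHERE criterion
+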
+ .. seealso::
+
+ :ref:`tutorial_update_delete_rowcount` - in the :ref:`unified_tutorial`
+
+ """ # noqa: E501
+
+ try:
+ return self.context.rowcount
+ except BaseException as e:
+ self.cursor_strategy.handle_exception(self, self.cursor, e)
+
+ @property
+ def lastrowid(self):
+ """Return the 'lastrowid' accessor on the DBAPI cursor.
+
+ This is a DBAPI specific method and is only functional
+ for those backends which support it, for statements
+ where it is appropriate. Its behavior is not
+ consistent across backends.
+
+ Usage of this method is normally unnecessary when
+ using insert() expression constructs; the
+ :attr:`~CursorResult.inserted_primary_key` attribute provides a
+ tuple of primary key values for a newly inserted row,
+ regardless of database backend.
+
+ """
+ try:
+ return self.context.get_lastrowid()
+ except BaseException as e:
+ self.cursor_strategy.handle_exception(self, self.cursor, e)
+
+ @property
+ def returns_rows(self):
+ """True if this :class:`_engine.CursorResult` returns zero or more
+ rows.
+
+ I.e. if it is legal to call the methods
+ :meth:`_engine.CursorResult.fetchone`,
+ :meth:`_engine.CursorResult.fetchmany`, and
+ :meth:`_engine.CursorResult.fetchall`.
+
+ Overall, the value of :attr:`_engine.CursorResult.returns_rows` should
+ always be synonymous with whether or not the DBAPI cursor had a
+ ``.description`` attribute, indicating the presence of result columns,
+ noting that a cursor that returns zero rows still has a
+ ``.description`` if a row-returning statement was emitted.
+
+ This attribute should be True for all results that are against
+ SELECT statements, as well as for DML statements INSERT/UPDATE/DELETE
+ that use RETURNING. For INSERT/UPDATE/DELETE statements that were
+ not using RETURNING, the value will usually be False, however
+ there are some dialect-specific exceptions to this, such as when
+ using the MSSQL / pyodbc dialect a SELECT is emitted inline in
+ order to retrieve an inserted primary key value.
+
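+ For example, a sketch (``some_table`` is illustrative)::
+
+     result = conn.execute(text("SELECT 1"))
+     result.returns_rows  # True - a cursor .description is present
+
+     result = conn.execute(some_table.delete())
+     result.returns_rows  # typically False - no RETURNING in use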
+
+ """
+ return self._metadata.returns_rows
+
+ @property
+ def is_insert(self):
+ """True if this :class:`_engine.CursorResult` is the result
+ of executing an expression language compiled
+ :func:`_expression.insert` construct.
+
+ When True, this implies that the
+ :attr:`inserted_primary_key` attribute is accessible,
+ assuming the statement did not include
+ a user defined "returning" construct.
+
+ """
+ return self.context.isinsert
+
+
+class CursorResult(BaseCursorResult, Result):
+ """A Result that is representing state from a DBAPI cursor.
+
+ .. versionchanged:: 1.4 The :class:`.CursorResult` and
+ :class:`.LegacyCursorResult`
+ classes replace the previous :class:`.ResultProxy` interface.
+ These classes are based on the :class:`.Result` calling API
+ which provides an updated usage model and calling facade for
+ SQLAlchemy Core and SQLAlchemy ORM.
+
+ Returns database rows via the :class:`.Row` class, which provides
+ additional API features and behaviors on top of the raw data returned by
+ the DBAPI. Through the use of filters such as the :meth:`.Result.scalars`
+ method, other kinds of objects may also be returned.
+
+ Within the scope of the 1.x series of SQLAlchemy, Core SQL results in
+ version 1.4 return an instance of :class:`._engine.LegacyCursorResult`
+ which takes the place of the ``CursorResult`` class used for the 1.3 series
+ and previously. This object returns rows as :class:`.LegacyRow` objects,
+ which maintain Python mapping (i.e. dictionary-like) behaviors upon the
+ object itself. Going forward, the :attr:`.Row._mapping` attribute should
+ be used for dictionary behaviors.
+
+ .. seealso::
+
+ :ref:`coretutorial_selecting` - introductory material for accessing
+ :class:`_engine.CursorResult` and :class:`.Row` objects.
+
+ """
+
+ _cursor_metadata = CursorResultMetaData
+ _cursor_strategy_cls = CursorFetchStrategy
+ _no_result_metadata = _NO_RESULT_METADATA
+ _is_cursor = True
+
+ def _fetchiter_impl(self):
+ fetchone = self.cursor_strategy.fetchone
+
+ while True:
+ row = fetchone(self, self.cursor)
+ if row is None:
+ break
+ yield row
+
+ def _fetchone_impl(self, hard_close=False):
+ return self.cursor_strategy.fetchone(self, self.cursor, hard_close)
+
+ def _fetchall_impl(self):
+ return self.cursor_strategy.fetchall(self, self.cursor)
+
+ def _fetchmany_impl(self, size=None):
+ return self.cursor_strategy.fetchmany(self, self.cursor, size)
+
+ def _raw_row_iterator(self):
+ return self._fetchiter_impl()
+
+ def merge(self, *others):
+ merged_result = super(CursorResult, self).merge(*others)
+ setup_rowcounts = not self._metadata.returns_rows
+ if setup_rowcounts:
+ merged_result.rowcount = sum(
+ result.rowcount for result in (self,) + others
+ )
+ return merged_result
+
+ def close(self):
+ """Close this :class:`_engine.CursorResult`.
+
+ This closes out the underlying DBAPI cursor corresponding to the
+ statement execution, if one is still present. Note that the DBAPI
+ cursor is automatically released when the :class:`_engine.CursorResult`
+ exhausts all available rows. :meth:`_engine.CursorResult.close` is
+ generally an optional method except in the case when discarding a
+ :class:`_engine.CursorResult` that still has additional rows pending
+ for fetch.
+
+ After this method is called, it is no longer valid to call upon
+ the fetch methods, which will raise a :class:`.ResourceClosedError`
+ on subsequent use.
+
+ .. seealso::
+
+ :ref:`connections_toplevel`
+
+ """
+ self._soft_close(hard=True)
+
+ @_generative
+ def yield_per(self, num):
+ self._yield_per = num
+ self.cursor_strategy.yield_per(self, self.cursor, num)
+
+
+class LegacyCursorResult(CursorResult):
+ """Legacy version of :class:`.CursorResult`.
+
+ This class includes "connection autoclose" behavior for use with
+ "connectionless" execution, and delivers rows using the
+ :class:`.LegacyRow` row implementation.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _autoclose_connection = False
+ _process_row = LegacyRow
+ _cursor_metadata = LegacyCursorResultMetaData
+ _cursor_strategy_cls = CursorFetchStrategy
+
+ _no_result_metadata = _LEGACY_NO_RESULT_METADATA
+
+ def close(self):
+ """Close this :class:`_engine.LegacyCursorResult`.
+
+ This method has the same behavior as that of
+ :meth:`_engine.CursorResult.close`, but it also may close
+ the underlying :class:`.Connection` for the case of "connectionless"
+ execution.
+
+ .. deprecated:: 2.0 "connectionless" execution is deprecated and will
+ be removed in version 2.0. Version 2.0 will feature the
+ :class:`_future.Result`
+ object that will no longer affect the status
+ of the originating connection in any case.
+
+ After this method is called, it is no longer valid to call upon
+ the fetch methods, which will raise a :class:`.ResourceClosedError`
+ on subsequent use.
+
+ .. seealso::
+
+ :ref:`connections_toplevel`
+
+ :ref:`dbengine_implicit`
+ """
+ self._soft_close(hard=True)
+
+ def _soft_close(self, hard=False):
+ soft_closed = self._soft_closed
+ super(LegacyCursorResult, self)._soft_close(hard=hard)
+ if (
+ not soft_closed
+ and self._soft_closed
+ and self._autoclose_connection
+ ):
+ self.connection.close()
+
+
+ResultProxy = LegacyCursorResult
+
+
+class BufferedRowResultProxy(ResultProxy):
+ """A ResultProxy with row buffering behavior.
+
+ .. deprecated:: 1.4 this class is now supplied using a strategy object.
+ See :class:`.BufferedRowCursorFetchStrategy`.
+
+ """
+
+ _cursor_strategy_cls = BufferedRowCursorFetchStrategy
+
+
+class FullyBufferedResultProxy(ResultProxy):
+ """A result proxy that buffers rows fully upon creation.
+
+ .. deprecated:: 1.4 this class is now supplied using a strategy object.
+ See :class:`.FullyBufferedCursorFetchStrategy`.
+
+ """
+
+ _cursor_strategy_cls = FullyBufferedCursorFetchStrategy
+
+
+class BufferedColumnRow(LegacyRow):
+ """Row is now BufferedColumn in all cases"""
+
+
+class BufferedColumnResultProxy(ResultProxy):
+ """A ResultProxy with column buffering behavior.
+
+ .. versionchanged:: 1.4 This is now the default behavior of the Row
+ and this class does not change behavior in any way.
+
+ """
+
+ _process_row = BufferedColumnRow
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
new file mode 100644
index 0000000..268a2d6
--- /dev/null
+++ b/lib/sqlalchemy/engine/default.py
@@ -0,0 +1,1936 @@
+# engine/default.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Default implementations of per-dialect sqlalchemy.engine classes.
+
+These are semi-private implementation classes which are only of importance
+to database dialect authors; dialects will usually use the classes here
+as the base class for their own corresponding classes.
+
+"""
+
+import codecs
+import functools
+import random
+import re
+import weakref
+
+from . import characteristics
+from . import cursor as _cursor
+from . import interfaces
+from .base import Connection
+from .. import event
+from .. import exc
+from .. import pool
+from .. import processors
+from .. import types as sqltypes
+from .. import util
+from ..sql import compiler
+from ..sql import expression
+from ..sql.elements import quoted_name
+
+AUTOCOMMIT_REGEXP = re.compile(
+ r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)", re.I | re.UNICODE
+)
+
+# When we're handed literal SQL, ensure it's a SELECT query
+SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE)
+
+
+CACHE_HIT = util.symbol("CACHE_HIT")
+CACHE_MISS = util.symbol("CACHE_MISS")
+CACHING_DISABLED = util.symbol("CACHING_DISABLED")
+NO_CACHE_KEY = util.symbol("NO_CACHE_KEY")
+NO_DIALECT_SUPPORT = util.symbol("NO_DIALECT_SUPPORT")
+
+
+class DefaultDialect(interfaces.Dialect):
+ """Default implementation of Dialect"""
+
+ statement_compiler = compiler.SQLCompiler
+ ddl_compiler = compiler.DDLCompiler
+ type_compiler = compiler.GenericTypeCompiler
+ preparer = compiler.IdentifierPreparer
+ supports_alter = True
+ supports_comments = False
+ inline_comments = False
+ use_setinputsizes = False
+ supports_statement_cache = True
+
+ # the first value we'd get for an autoincrement
+ # column.
+ default_sequence_base = 1
+
+ # most DBAPIs happy with this for execute().
+ # not cx_oracle.
+ execute_sequence_format = tuple
+
+ supports_schemas = True
+ supports_views = True
+ supports_sequences = False
+ sequences_optional = False
+ preexecute_autoincrement_sequences = False
+ supports_identity_columns = False
+ postfetch_lastrowid = True
+ implicit_returning = False
+ full_returning = False
+ insert_executemany_returning = False
+
+ cte_follows_insert = False
+
+ supports_native_enum = False
+ supports_native_boolean = False
+ non_native_boolean_check_constraint = True
+
+ supports_simple_order_by_label = True
+
+ tuple_in_values = False
+
+ connection_characteristics = util.immutabledict(
+ {"isolation_level": characteristics.IsolationLevelCharacteristic()}
+ )
+
+ engine_config_types = util.immutabledict(
+ [
+ ("convert_unicode", util.bool_or_str("force")),
+ ("pool_timeout", util.asint),
+ ("echo", util.bool_or_str("debug")),
+ ("echo_pool", util.bool_or_str("debug")),
+ ("pool_recycle", util.asint),
+ ("pool_size", util.asint),
+ ("max_overflow", util.asint),
+ ("future", util.asbool),
+ ]
+ )
+
+ # if the NUMERIC type
+ # returns decimal.Decimal.
+ # *not* the FLOAT type however.
+ supports_native_decimal = False
+
+ if util.py3k:
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+ returns_unicode_strings = sqltypes.String.RETURNS_UNICODE
+ description_encoding = None
+ else:
+ supports_unicode_statements = False
+ supports_unicode_binds = False
+ returns_unicode_strings = sqltypes.String.RETURNS_UNKNOWN
+ description_encoding = "use_encoding"
+
+ name = "default"
+
+ # length at which to truncate
+ # any identifier.
+ max_identifier_length = 9999
+ _user_defined_max_identifier_length = None
+
+ isolation_level = None
+
+ # sub-categories of max_identifier_length.
+ # currently these accommodate MySQL, which allows alias names
+ # of 255 characters but DDL names of only 64.
+ max_index_name_length = None
+ max_constraint_name_length = None
+
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = True
+ colspecs = {}
+ default_paramstyle = "named"
+
+ supports_default_values = False
+ """dialect supports INSERT... DEFAULT VALUES syntax"""
+
+ supports_default_metavalue = False
+ """dialect supports INSERT... VALUES (DEFAULT) syntax"""
+
+ # not sure if this is a real thing but the compiler will deliver it
+ # if this is the only flag enabled.
+ supports_empty_insert = True
+ """dialect supports INSERT () VALUES ()"""
+
+ supports_multivalues_insert = False
+
+ supports_is_distinct_from = True
+
+ supports_server_side_cursors = False
+
+ server_side_cursors = False
+
+ # extra record-level locking features (#4860)
+ supports_for_update_of = False
+
+ server_version_info = None
+
+ default_schema_name = None
+
+ construct_arguments = None
+ """Optional set of argument specifiers for various SQLAlchemy
+ constructs, typically schema items.
+
+ To implement, establish as a series of tuples, as in::
+
+ construct_arguments = [
+ (schema.Index, {
+ "using": False,
+ "where": None,
+ "ops": None
+ })
+ ]
+
+ If the above construct is established on the PostgreSQL dialect,
+ the :class:`.Index` construct will now accept the keyword arguments
+ ``postgresql_using``, ``postgresql_where``, and ``postgresql_ops``.
+ Any other argument specified to the constructor of :class:`.Index`
+ which is prefixed with ``postgresql_`` will raise :class:`.ArgumentError`.
+
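+ For example, with the above specification in place, a construct such as
+ the following would pass argument validation (a sketch; the index and
+ column names are illustrative)::
+
+     schema.Index(
+         "ix_data",
+         some_table.c.data,
+         postgresql_using="gin",
+         postgresql_where=some_table.c.data.isnot(None),
+     )
+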
+ A dialect which does not include a ``construct_arguments`` member will
+ not participate in the argument validation system. For such a dialect,
+ any argument name is accepted by all participating constructs, within
+ the namespace of arguments prefixed with that dialect name. The rationale
+ here is so that third-party dialects that haven't yet implemented this
+ feature continue to function in the old way.
+
+ .. versionadded:: 0.9.2
+
+ .. seealso::
+
+ :class:`.DialectKWArgs` - implementing base class which consumes
+ :attr:`.DefaultDialect.construct_arguments`
+
+
+ """
+
+ # indicates symbol names are
+ # UPPERCASEd if they are case insensitive
+ # within the database.
+ # if this is True, the methods normalize_name()
+ # and denormalize_name() must be provided.
+ requires_name_normalize = False
+
+ reflection_options = ()
+
+ dbapi_exception_translation_map = util.immutabledict()
+ """mapping used in the extremely unusual case that a DBAPI's
+ published exceptions don't actually have the __name__ that they
+ are linked towards.
+
+ .. versionadded:: 1.0.5
+
+ """
+
+ is_async = False
+
+ CACHE_HIT = CACHE_HIT
+ CACHE_MISS = CACHE_MISS
+ CACHING_DISABLED = CACHING_DISABLED
+ NO_CACHE_KEY = NO_CACHE_KEY
+ NO_DIALECT_SUPPORT = NO_DIALECT_SUPPORT
+
+ @util.deprecated_params(
+ convert_unicode=(
+ "1.3",
+ "The :paramref:`_sa.create_engine.convert_unicode` parameter "
+ "and corresponding dialect-level parameters are deprecated, "
+ "and will be removed in a future release. Modern DBAPIs support "
+ "Python Unicode natively and this parameter is unnecessary.",
+ ),
+ empty_in_strategy=(
+ "1.4",
+ "The :paramref:`_sa.create_engine.empty_in_strategy` keyword is "
+ "deprecated, and no longer has any effect. All IN expressions "
+ "are now rendered using "
+ 'the "expanding parameter" strategy which renders a set of bound'
+ 'expressions, or an "empty set" SELECT, at statement execution'
+ "time.",
+ ),
+ case_sensitive=(
+ "1.4",
+ "The :paramref:`_sa.create_engine.case_sensitive` parameter "
+ "is deprecated and will be removed in a future release. "
+ "Applications should work with result column names in a case "
+ "sensitive fashion.",
+ ),
+ server_side_cursors=(
+ "1.4",
+ "The :paramref:`_sa.create_engine.server_side_cursors` parameter "
+ "is deprecated and will be removed in a future release. Please "
+ "use the "
+ ":paramref:`_engine.Connection.execution_options.stream_results` "
+ "parameter.",
+ ),
+ )
+ def __init__(
+ self,
+ convert_unicode=False,
+ encoding="utf-8",
+ paramstyle=None,
+ dbapi=None,
+ implicit_returning=None,
+ case_sensitive=True,
+ supports_native_boolean=None,
+ max_identifier_length=None,
+ label_length=None,
+ # int() is because the @deprecated_params decorator cannot accommodate
+ # the direct reference to the "NO_LINTING" object
+ compiler_linting=int(compiler.NO_LINTING),
+ server_side_cursors=False,
+ **kwargs
+ ):
+
+ if not getattr(self, "ported_sqla_06", True):
+ util.warn(
+ "The %s dialect is not yet ported to the 0.6 format"
+ % self.name
+ )
+
+ if server_side_cursors:
+ if not self.supports_server_side_cursors:
+ raise exc.ArgumentError(
+ "Dialect %s does not support server side cursors" % self
+ )
+ else:
+ self.server_side_cursors = True
+
+ self.convert_unicode = convert_unicode
+ self.encoding = encoding
+ self.positional = False
+ self._ischema = None
+ self.dbapi = dbapi
+ if paramstyle is not None:
+ self.paramstyle = paramstyle
+ elif self.dbapi is not None:
+ self.paramstyle = self.dbapi.paramstyle
+ else:
+ self.paramstyle = self.default_paramstyle
+ if implicit_returning is not None:
+ self.implicit_returning = implicit_returning
+ self.positional = self.paramstyle in ("qmark", "format", "numeric")
+ self.identifier_preparer = self.preparer(self)
+ self.type_compiler = self.type_compiler(self)
+ if supports_native_boolean is not None:
+ self.supports_native_boolean = supports_native_boolean
+ self.case_sensitive = case_sensitive
+
+ self._user_defined_max_identifier_length = max_identifier_length
+ if self._user_defined_max_identifier_length:
+ self.max_identifier_length = (
+ self._user_defined_max_identifier_length
+ )
+ self.label_length = label_length
+ self.compiler_linting = compiler_linting
+ if self.description_encoding == "use_encoding":
+ self._description_decoder = (
+ processors.to_unicode_processor_factory
+ )(encoding)
+ elif self.description_encoding is not None:
+ self._description_decoder = (
+ processors.to_unicode_processor_factory
+ )(self.description_encoding)
+ self._encoder = codecs.getencoder(self.encoding)
+ self._decoder = processors.to_unicode_processor_factory(self.encoding)
+
+ def _ensure_has_table_connection(self, arg):
+
+ if not isinstance(arg, Connection):
+ raise exc.ArgumentError(
+ "The argument passed to Dialect.has_table() should be a "
+ "%s, got %s. "
+ "Additionally, the Dialect.has_table() method is for "
+ "internal dialect "
+ "use only; please use "
+ "``inspect(some_engine).has_table(<tablename>>)`` "
+ "for public API use." % (Connection, type(arg))
+ )
+
+ @util.memoized_property
+ def _supports_statement_cache(self):
+ ssc = self.__class__.__dict__.get("supports_statement_cache", None)
+ if ssc is None:
+ util.warn(
+ "Dialect %s:%s will not make use of SQL compilation caching "
+ "as it does not set the 'supports_statement_cache' attribute "
+ "to ``True``. This can have "
+ "significant performance implications including some "
+ "performance degradations in comparison to prior SQLAlchemy "
+ "versions. Dialect maintainers should seek to set this "
+ "attribute to True after appropriate development and testing "
+ "for SQLAlchemy 1.4 caching support. Alternatively, this "
+ "attribute may be set to False which will disable this "
+ "warning." % (self.name, self.driver),
+ code="cprf",
+ )
+
+ return bool(ssc)
+
+ @util.memoized_property
+ def _type_memos(self):
+ return weakref.WeakKeyDictionary()
+
+ @property
+ def dialect_description(self):
+ return self.name + "+" + self.driver
+
+ @property
+ def supports_sane_rowcount_returning(self):
+ """True if this dialect supports sane rowcount even if RETURNING is
+ in use.
+
+ For dialects that don't support RETURNING, this is synonymous with
+ ``supports_sane_rowcount``.
+
+ """
+ return self.supports_sane_rowcount
+
+ @classmethod
+ def get_pool_class(cls, url):
+ return getattr(cls, "poolclass", pool.QueuePool)
+
+ def get_dialect_pool_class(self, url):
+ return self.get_pool_class(url)
+
+ @classmethod
+ def load_provisioning(cls):
+ package = ".".join(cls.__module__.split(".")[0:-1])
+ try:
+ __import__(package + ".provision")
+ except ImportError:
+ pass
+
+ def initialize(self, connection):
+ try:
+ self.server_version_info = self._get_server_version_info(
+ connection
+ )
+ except NotImplementedError:
+ self.server_version_info = None
+ try:
+ self.default_schema_name = self._get_default_schema_name(
+ connection
+ )
+ except NotImplementedError:
+ self.default_schema_name = None
+
+ try:
+ self.default_isolation_level = self.get_default_isolation_level(
+ connection.connection
+ )
+ except NotImplementedError:
+ self.default_isolation_level = None
+
+ if self.returns_unicode_strings is sqltypes.String.RETURNS_UNKNOWN:
+ if util.py3k:
+ raise exc.InvalidRequestError(
+ "RETURNS_UNKNOWN is unsupported in Python 3"
+ )
+ self.returns_unicode_strings = self._check_unicode_returns(
+ connection
+ )
+
+ if (
+ self.description_encoding is not None
+ and self._check_unicode_description(connection)
+ ):
+ self._description_decoder = self.description_encoding = None
+
+ if not self._user_defined_max_identifier_length:
+ max_ident_length = self._check_max_identifier_length(connection)
+ if max_ident_length:
+ self.max_identifier_length = max_ident_length
+
+ if (
+ self.label_length
+ and self.label_length > self.max_identifier_length
+ ):
+ raise exc.ArgumentError(
+ "Label length of %d is greater than this dialect's"
+ " maximum identifier length of %d"
+ % (self.label_length, self.max_identifier_length)
+ )
+
+ def on_connect(self):
+ # inherits the docstring from interfaces.Dialect.on_connect
+ return None
+
+ def _check_max_identifier_length(self, connection):
+ """Perform a connection / server version specific check to determine
+ the max_identifier_length.
+
+ May return None if the dialect's class-level max_identifier_length
+ should be used.
+
+ .. versionadded:: 1.3.9
+
+ """
+ return None
+
+ def get_default_isolation_level(self, dbapi_conn):
+ """Given a DBAPI connection, return its isolation level, or
+ a default isolation level if one cannot be retrieved.
+
+ May be overridden by subclasses in order to provide a
+ "fallback" isolation level for databases that cannot reliably
+ retrieve the actual isolation level.
+
+ By default, calls the :meth:`_engine.Dialect.get_isolation_level`
+ method, propagating any exceptions raised.
+
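+ For example, a dialect that cannot reliably report the level might
+ override this hook along these lines (a sketch; the fallback value is
+ illustrative)::
+
+     class MyDialect(DefaultDialect):
+         def get_default_isolation_level(self, dbapi_conn):
+             try:
+                 return self.get_isolation_level(dbapi_conn)
+             except NotImplementedError:
+                 # fall back to a fixed, assumed server default
+                 return "READ COMMITTED"
+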
+ .. versionadded:: 1.3.22
+
+ """
+ return self.get_isolation_level(dbapi_conn)
+
+ def _check_unicode_returns(self, connection, additional_tests=None):
+ # this now runs in py2k only and will be removed in 2.0; disabled for
+ # Python 3 in all cases under #5315
+ if util.py2k and not self.supports_unicode_statements:
+ cast_to = util.binary_type
+ else:
+ cast_to = util.text_type
+
+ if self.positional:
+ parameters = self.execute_sequence_format()
+ else:
+ parameters = {}
+
+ def check_unicode(test):
+ statement = cast_to(expression.select(test).compile(dialect=self))
+ try:
+ cursor = connection.connection.cursor()
+ connection._cursor_execute(cursor, statement, parameters)
+ row = cursor.fetchone()
+ cursor.close()
+ except exc.DBAPIError as de:
+ # note that _cursor_execute() will have closed the cursor
+ # if an exception is thrown.
+ util.warn(
+ "Exception attempting to "
+ "detect unicode returns: %r" % de
+ )
+ return False
+ else:
+ return isinstance(row[0], util.text_type)
+
+ tests = [
+ # detect plain VARCHAR
+ expression.cast(
+ expression.literal_column("'test plain returns'"),
+ sqltypes.VARCHAR(60),
+ ),
+ # detect if there's an NVARCHAR type with different behavior
+ # available
+ expression.cast(
+ expression.literal_column("'test unicode returns'"),
+ sqltypes.Unicode(60),
+ ),
+ ]
+
+ if additional_tests:
+ tests += additional_tests
+
+ results = {check_unicode(test) for test in tests}
+
+ if results.issuperset([True, False]):
+ return sqltypes.String.RETURNS_CONDITIONAL
+ else:
+ return (
+ sqltypes.String.RETURNS_UNICODE
+ if results == {True}
+ else sqltypes.String.RETURNS_BYTES
+ )
+
+ def _check_unicode_description(self, connection):
+ # all DBAPIs on Py2K return cursor.description as encoded
+
+ if util.py2k and not self.supports_unicode_statements:
+ cast_to = util.binary_type
+ else:
+ cast_to = util.text_type
+
+ cursor = connection.connection.cursor()
+ try:
+ cursor.execute(
+ cast_to(
+ expression.select(
+ expression.literal_column("'x'").label("some_label")
+ ).compile(dialect=self)
+ )
+ )
+ return isinstance(cursor.description[0][0], util.text_type)
+ finally:
+ cursor.close()
+
+ def type_descriptor(self, typeobj):
+ """Provide a database-specific :class:`.TypeEngine` object, given
+ the generic object which comes from the types module.
+
+ This method looks for a dictionary called
+ ``colspecs`` as a class or instance-level variable,
+ and passes on to :func:`_types.adapt_type`.
+
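+ For example, a sketch of typical use (``some_dialect`` is any dialect
+ instance)::
+
+     from sqlalchemy import types
+
+     impl = some_dialect.type_descriptor(types.DateTime())
+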
+ """
+ return sqltypes.adapt_type(typeobj, self.colspecs)
+
+ def has_index(self, connection, table_name, index_name, schema=None):
+ if not self.has_table(connection, table_name, schema=schema):
+ return False
+ for idx in self.get_indexes(connection, table_name, schema=schema):
+ if idx["name"] == index_name:
+ return True
+ else:
+ return False
+
+ def validate_identifier(self, ident):
+ if len(ident) > self.max_identifier_length:
+ raise exc.IdentifierError(
+ "Identifier '%s' exceeds maximum length of %d characters"
+ % (ident, self.max_identifier_length)
+ )
+
+ def connect(self, *cargs, **cparams):
+ # inherits the docstring from interfaces.Dialect.connect
+ return self.dbapi.connect(*cargs, **cparams)
+
+ def create_connect_args(self, url):
+ # inherits the docstring from interfaces.Dialect.create_connect_args
+ opts = url.translate_connect_args()
+ opts.update(url.query)
+ return [[], opts]
+
+ def set_engine_execution_options(self, engine, opts):
+ supported_names = set(self.connection_characteristics).intersection(
+ opts
+ )
+ if supported_names:
+ characteristics = util.immutabledict(
+ (name, opts[name]) for name in supported_names
+ )
+
+ @event.listens_for(engine, "engine_connect")
+ def set_connection_characteristics(connection, branch):
+ if not branch:
+ self._set_connection_characteristics(
+ connection, characteristics
+ )
+
+ def set_connection_execution_options(self, connection, opts):
+ supported_names = set(self.connection_characteristics).intersection(
+ opts
+ )
+ if supported_names:
+ characteristics = util.immutabledict(
+ (name, opts[name]) for name in supported_names
+ )
+ self._set_connection_characteristics(connection, characteristics)
+
+ def _set_connection_characteristics(self, connection, characteristics):
+
+ characteristic_values = [
+ (name, self.connection_characteristics[name], value)
+ for name, value in characteristics.items()
+ ]
+
+ if connection.in_transaction():
+ trans_objs = [
+ (name, obj)
+ for name, obj, value in characteristic_values
+ if obj.transactional
+ ]
+ if trans_objs:
+ if connection._is_future:
+ raise exc.InvalidRequestError(
+ "This connection has already initialized a SQLAlchemy "
+ "Transaction() object via begin() or autobegin; "
+ "%s may not be altered unless rollback() or commit() "
+ "is called first."
+ % (", ".join(name for name, obj in trans_objs))
+ )
+ else:
+ util.warn(
+ "Connection is already established with a "
+ "Transaction; "
+ "setting %s may implicitly rollback or "
+ "commit "
+ "the existing transaction, or have no effect until "
+ "next transaction"
+ % (", ".join(name for name, obj in trans_objs))
+ )
+
+ dbapi_connection = connection.connection.dbapi_connection
+ for name, characteristic, value in characteristic_values:
+ characteristic.set_characteristic(self, dbapi_connection, value)
+ connection.connection._connection_record.finalize_callback.append(
+ functools.partial(self._reset_characteristics, characteristics)
+ )
+
+ def _reset_characteristics(self, characteristics, dbapi_connection):
+ for characteristic_name in characteristics:
+ characteristic = self.connection_characteristics[
+ characteristic_name
+ ]
+ characteristic.reset_characteristic(self, dbapi_connection)
+
+ def do_begin(self, dbapi_connection):
+ pass
+
+ def do_rollback(self, dbapi_connection):
+ dbapi_connection.rollback()
+
+ def do_commit(self, dbapi_connection):
+ dbapi_connection.commit()
+
+ def do_close(self, dbapi_connection):
+ dbapi_connection.close()
+
+ @util.memoized_property
+ def _dialect_specific_select_one(self):
+ return str(expression.select(1).compile(dialect=self))
+
+ def do_ping(self, dbapi_connection):
+ cursor = None
+ try:
+ cursor = dbapi_connection.cursor()
+ try:
+ cursor.execute(self._dialect_specific_select_one)
+ finally:
+ cursor.close()
+ except self.dbapi.Error as err:
+ if self.is_disconnect(err, dbapi_connection, cursor):
+ return False
+ else:
+ raise
+ else:
+ return True
+
+ def create_xid(self):
+ """Create a random two-phase transaction ID.
+
+ This id will be passed to do_begin_twophase(), do_rollback_twophase(),
+ do_commit_twophase(). Its format is unspecified.
+ """
+
+ return "_sa_%032x" % random.randint(0, 2 ** 128)
+
+ def do_savepoint(self, connection, name):
+ connection.execute(expression.SavepointClause(name))
+
+ def do_rollback_to_savepoint(self, connection, name):
+ connection.execute(expression.RollbackToSavepointClause(name))
+
+ def do_release_savepoint(self, connection, name):
+ connection.execute(expression.ReleaseSavepointClause(name))
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ cursor.executemany(statement, parameters)
+
+ def do_execute(self, cursor, statement, parameters, context=None):
+ cursor.execute(statement, parameters)
+
+ def do_execute_no_params(self, cursor, statement, context=None):
+ cursor.execute(statement)
+
+ def is_disconnect(self, e, connection, cursor):
+ return False
+
+ def reset_isolation_level(self, dbapi_conn):
+ # default_isolation_level is read from the first connection
+ # after the initial set of 'isolation_level', if any, so is
+ # the configured default of this dialect.
+ self.set_isolation_level(dbapi_conn, self.default_isolation_level)
+
+ def normalize_name(self, name):
+ if name is None:
+ return None
+ if util.py2k:
+ if isinstance(name, str):
+ name = name.decode(self.encoding)
+
+ name_lower = name.lower()
+ name_upper = name.upper()
+
+ if name_upper == name_lower:
+ # name has no upper/lower conversion, e.g. non-european characters.
+ # return unchanged
+ return name
+ elif name_upper == name and not (
+ self.identifier_preparer._requires_quotes
+ )(name_lower):
+ # name is all uppercase and doesn't require quoting; normalize
+ # to all lower case
+ return name_lower
+ elif name_lower == name:
+ # name is all lower case, which if denormalized means we need to
+ # force quoting on it
+ return quoted_name(name, quote=True)
+ else:
+ # name is mixed case, means it will be quoted in SQL when used
+ # later; no normalization is performed
+ return name
+
+ def denormalize_name(self, name):
+ if name is None:
+ return None
+
+ name_lower = name.lower()
+ name_upper = name.upper()
+
+ if name_upper == name_lower:
+ # name has no upper/lower conversion, e.g. non-european characters.
+ # return unchanged
+ return name
+ elif name_lower == name and not (
+ self.identifier_preparer._requires_quotes
+ )(name_lower):
+ name = name_upper
+ if util.py2k:
+ if not self.supports_unicode_binds:
+ name = name.encode(self.encoding)
+ else:
+ name = unicode(name) # noqa
+ return name
+
+ def get_driver_connection(self, connection):
+ return connection
+
+
+class _RendersLiteral(object):
+ def literal_processor(self, dialect):
+ def process(value):
+ return "'%s'" % value
+
+ return process
+
+
+class _StrDateTime(_RendersLiteral, sqltypes.DateTime):
+ pass
+
+
+class _StrDate(_RendersLiteral, sqltypes.Date):
+ pass
+
+
+class _StrTime(_RendersLiteral, sqltypes.Time):
+ pass
+
+
+class StrCompileDialect(DefaultDialect):
+
+ statement_compiler = compiler.StrSQLCompiler
+ ddl_compiler = compiler.DDLCompiler
+ type_compiler = compiler.StrSQLTypeCompiler
+ preparer = compiler.IdentifierPreparer
+
+ supports_statement_cache = True
+
+ supports_identity_columns = True
+
+ supports_sequences = True
+ sequences_optional = True
+ preexecute_autoincrement_sequences = False
+ implicit_returning = False
+
+ supports_native_boolean = True
+
+ supports_multivalues_insert = True
+ supports_simple_order_by_label = True
+
+ colspecs = {
+ sqltypes.DateTime: _StrDateTime,
+ sqltypes.Date: _StrDate,
+ sqltypes.Time: _StrTime,
+ }
+
+
+class DefaultExecutionContext(interfaces.ExecutionContext):
+ isinsert = False
+ isupdate = False
+ isdelete = False
+ is_crud = False
+ is_text = False
+ isddl = False
+ executemany = False
+ compiled = None
+ statement = None
+ result_column_struct = None
+ returned_default_rows = None
+ execution_options = util.immutabledict()
+
+ include_set_input_sizes = None
+ exclude_set_input_sizes = None
+
+ cursor_fetch_strategy = _cursor._DEFAULT_FETCH
+
+ cache_stats = None
+ invoked_statement = None
+
+ _is_implicit_returning = False
+ _is_explicit_returning = False
+ _is_future_result = False
+ _is_server_side = False
+
+ _soft_closed = False
+
+ # a hook for SQLite's translation of
+ # result column names
+ # NOTE: pyhive is using this hook, can't remove it :(
+ _translate_colname = None
+
+ _expanded_parameters = util.immutabledict()
+
+ cache_hit = NO_CACHE_KEY
+
+ @classmethod
+ def _init_ddl(
+ cls,
+ dialect,
+ connection,
+ dbapi_connection,
+ execution_options,
+ compiled_ddl,
+ ):
+ """Initialize execution context for a DDLElement construct."""
+
+ self = cls.__new__(cls)
+ self.root_connection = connection
+ self._dbapi_connection = dbapi_connection
+ self.dialect = connection.dialect
+
+ self.compiled = compiled = compiled_ddl
+ self.isddl = True
+
+ self.execution_options = execution_options
+
+ self._is_future_result = (
+ connection._is_future
+ or self.execution_options.get("future_result", False)
+ )
+
+ self.unicode_statement = util.text_type(compiled)
+ if compiled.schema_translate_map:
+ schema_translate_map = self.execution_options.get(
+ "schema_translate_map", {}
+ )
+
+ rst = compiled.preparer._render_schema_translates
+ self.unicode_statement = rst(
+ self.unicode_statement, schema_translate_map
+ )
+
+ if not dialect.supports_unicode_statements:
+ self.statement = dialect._encoder(self.unicode_statement)[0]
+ else:
+ self.statement = self.unicode_statement
+
+ self.cursor = self.create_cursor()
+ self.compiled_parameters = []
+
+ if dialect.positional:
+ self.parameters = [dialect.execute_sequence_format()]
+ else:
+ self.parameters = [{}]
+
+ return self
+
+ @classmethod
+ def _init_compiled(
+ cls,
+ dialect,
+ connection,
+ dbapi_connection,
+ execution_options,
+ compiled,
+ parameters,
+ invoked_statement,
+ extracted_parameters,
+ cache_hit=CACHING_DISABLED,
+ ):
+ """Initialize execution context for a Compiled construct."""
+
+ self = cls.__new__(cls)
+ self.root_connection = connection
+ self._dbapi_connection = dbapi_connection
+ self.dialect = connection.dialect
+ self.extracted_parameters = extracted_parameters
+ self.invoked_statement = invoked_statement
+ self.compiled = compiled
+ self.cache_hit = cache_hit
+
+ self.execution_options = execution_options
+
+ self._is_future_result = (
+ connection._is_future
+ or self.execution_options.get("future_result", False)
+ )
+
+ self.result_column_struct = (
+ compiled._result_columns,
+ compiled._ordered_columns,
+ compiled._textual_ordered_columns,
+ compiled._loose_column_name_matching,
+ )
+ self.isinsert = compiled.isinsert
+ self.isupdate = compiled.isupdate
+ self.isdelete = compiled.isdelete
+ self.is_text = compiled.isplaintext
+
+ if self.isinsert or self.isupdate or self.isdelete:
+ self.is_crud = True
+ self._is_explicit_returning = bool(compiled.statement._returning)
+ self._is_implicit_returning = bool(
+ compiled.returning and not compiled.statement._returning
+ )
+
+ if not parameters:
+ self.compiled_parameters = [
+ compiled.construct_params(
+ extracted_parameters=extracted_parameters,
+ escape_names=False,
+ )
+ ]
+ else:
+ self.compiled_parameters = [
+ compiled.construct_params(
+ m,
+ escape_names=False,
+ _group_number=grp,
+ extracted_parameters=extracted_parameters,
+ )
+ for grp, m in enumerate(parameters)
+ ]
+
+ self.executemany = len(parameters) > 1
+
+ # this must occur before create_cursor() since the statement
+ # has to be regexed in some cases for server side cursor
+ if util.py2k:
+ self.unicode_statement = util.text_type(compiled.string)
+ else:
+ self.unicode_statement = compiled.string
+
+ self.cursor = self.create_cursor()
+
+ if self.compiled.insert_prefetch or self.compiled.update_prefetch:
+ if self.executemany:
+ self._process_executemany_defaults()
+ else:
+ self._process_executesingle_defaults()
+
+ processors = compiled._bind_processors
+
+ if compiled.literal_execute_params or compiled.post_compile_params:
+ if self.executemany:
+ raise exc.InvalidRequestError(
+ "'literal_execute' or 'expanding' parameters can't be "
+ "used with executemany()"
+ )
+
+ expanded_state = compiled._process_parameters_for_postcompile(
+ self.compiled_parameters[0]
+ )
+
+ # re-assign self.unicode_statement
+ self.unicode_statement = expanded_state.statement
+
+ # used by set_input_sizes() which is needed for Oracle
+ self._expanded_parameters = expanded_state.parameter_expansion
+
+ processors = dict(processors)
+ processors.update(expanded_state.processors)
+ positiontup = expanded_state.positiontup
+ elif compiled.positional:
+ positiontup = self.compiled.positiontup
+
+ if compiled.schema_translate_map:
+ schema_translate_map = self.execution_options.get(
+ "schema_translate_map", {}
+ )
+ rst = compiled.preparer._render_schema_translates
+ self.unicode_statement = rst(
+ self.unicode_statement, schema_translate_map
+ )
+
+ # final self.unicode_statement is now assigned, encode if needed
+ # by dialect
+ if not dialect.supports_unicode_statements:
+ self.statement = self.unicode_statement.encode(
+ self.dialect.encoding
+ )
+ else:
+ self.statement = self.unicode_statement
+
+ # Convert the dictionary of bind parameter values
+ # into a dict or list to be sent to the DBAPI's
+ # execute() or executemany() method.
+ parameters = []
+ if compiled.positional:
+ for compiled_params in self.compiled_parameters:
+ param = [
+ processors[key](compiled_params[key])
+ if key in processors
+ else compiled_params[key]
+ for key in positiontup
+ ]
+ parameters.append(dialect.execute_sequence_format(param))
+ else:
+ encode = not dialect.supports_unicode_statements
+ if encode:
+ encoder = dialect._encoder
+ for compiled_params in self.compiled_parameters:
+ escaped_bind_names = compiled.escaped_bind_names
+
+ if encode:
+ if escaped_bind_names:
+ param = {
+ encoder(escaped_bind_names.get(key, key))[
+ 0
+ ]: processors[key](compiled_params[key])
+ if key in processors
+ else compiled_params[key]
+ for key in compiled_params
+ }
+ else:
+ param = {
+ encoder(key)[0]: processors[key](
+ compiled_params[key]
+ )
+ if key in processors
+ else compiled_params[key]
+ for key in compiled_params
+ }
+ else:
+ if escaped_bind_names:
+ param = {
+ escaped_bind_names.get(key, key): processors[key](
+ compiled_params[key]
+ )
+ if key in processors
+ else compiled_params[key]
+ for key in compiled_params
+ }
+ else:
+ param = {
+ key: processors[key](compiled_params[key])
+ if key in processors
+ else compiled_params[key]
+ for key in compiled_params
+ }
+
+ parameters.append(param)
+
+ self.parameters = dialect.execute_sequence_format(parameters)
+
+ return self
+
+ @classmethod
+ def _init_statement(
+ cls,
+ dialect,
+ connection,
+ dbapi_connection,
+ execution_options,
+ statement,
+ parameters,
+ ):
+ """Initialize execution context for a string SQL statement."""
+
+ self = cls.__new__(cls)
+ self.root_connection = connection
+ self._dbapi_connection = dbapi_connection
+ self.dialect = connection.dialect
+ self.is_text = True
+
+ self.execution_options = execution_options
+
+ self._is_future_result = (
+ connection._is_future
+ or self.execution_options.get("future_result", False)
+ )
+
+ if not parameters:
+ if self.dialect.positional:
+ self.parameters = [dialect.execute_sequence_format()]
+ else:
+ self.parameters = [{}]
+ elif isinstance(parameters[0], dialect.execute_sequence_format):
+ self.parameters = parameters
+ elif isinstance(parameters[0], dict):
+ if dialect.supports_unicode_statements:
+ self.parameters = parameters
+ else:
+ self.parameters = [
+ {dialect._encoder(k)[0]: d[k] for k in d}
+ for d in parameters
+ ] or [{}]
+ else:
+ self.parameters = [
+ dialect.execute_sequence_format(p) for p in parameters
+ ]
+
+ self.executemany = len(parameters) > 1
+
+ if not dialect.supports_unicode_statements and isinstance(
+ statement, util.text_type
+ ):
+ self.unicode_statement = statement
+ self.statement = dialect._encoder(statement)[0]
+ else:
+ self.statement = self.unicode_statement = statement
+
+ self.cursor = self.create_cursor()
+ return self
+
+ @classmethod
+ def _init_default(
+ cls, dialect, connection, dbapi_connection, execution_options
+ ):
+ """Initialize execution context for a ColumnDefault construct."""
+
+ self = cls.__new__(cls)
+ self.root_connection = connection
+ self._dbapi_connection = dbapi_connection
+ self.dialect = connection.dialect
+
+ self.execution_options = execution_options
+
+ self._is_future_result = (
+ connection._is_future
+ or self.execution_options.get("future_result", False)
+ )
+
+ self.cursor = self.create_cursor()
+ return self
+
+ def _get_cache_stats(self):
+ if self.compiled is None:
+ return "raw sql"
+
+ now = util.perf_counter()
+
+ ch = self.cache_hit
+
+ if ch is NO_CACHE_KEY:
+ return "no key %.5fs" % (now - self.compiled._gen_time,)
+ elif ch is CACHE_HIT:
+ return "cached since %.4gs ago" % (now - self.compiled._gen_time,)
+ elif ch is CACHE_MISS:
+ return "generated in %.5fs" % (now - self.compiled._gen_time,)
+ elif ch is CACHING_DISABLED:
+ return "caching disabled %.5fs" % (now - self.compiled._gen_time,)
+ elif ch is NO_DIALECT_SUPPORT:
+ return "dialect %s+%s does not support caching %.5fs" % (
+ self.dialect.name,
+ self.dialect.driver,
+ now - self.compiled._gen_time,
+ )
+ else:
+ return "unknown"
+
+ @util.memoized_property
+ def identifier_preparer(self):
+ if self.compiled:
+ return self.compiled.preparer
+ elif "schema_translate_map" in self.execution_options:
+ return self.dialect.identifier_preparer._with_schema_translate(
+ self.execution_options["schema_translate_map"]
+ )
+ else:
+ return self.dialect.identifier_preparer
+
+ @util.memoized_property
+ def engine(self):
+ return self.root_connection.engine
+
+ @util.memoized_property
+ def postfetch_cols(self):
+ return self.compiled.postfetch
+
+ @util.memoized_property
+ def prefetch_cols(self):
+ if self.isinsert:
+ return self.compiled.insert_prefetch
+ elif self.isupdate:
+ return self.compiled.update_prefetch
+ else:
+ return ()
+
+ @util.memoized_property
+ def returning_cols(self):
+ return self.compiled.returning
+
+ @util.memoized_property
+ def no_parameters(self):
+ return self.execution_options.get("no_parameters", False)
+
+ @util.memoized_property
+ def should_autocommit(self):
+ autocommit = self.execution_options.get(
+ "autocommit",
+ not self.compiled
+ and self.statement
+ and expression.PARSE_AUTOCOMMIT
+ or False,
+ )
+
+ if autocommit is expression.PARSE_AUTOCOMMIT:
+ return self.should_autocommit_text(self.unicode_statement)
+ else:
+ return autocommit
+
+ def _execute_scalar(self, stmt, type_, parameters=None):
+ """Execute a string statement on the current cursor, returning a
+ scalar result.
+
+ Used to fire off sequences, default phrases, and "select lastrowid"
+ types of statements individually or in the context of a parent INSERT
+ or UPDATE statement.
+
+ """
+
+ conn = self.root_connection
+ if (
+ isinstance(stmt, util.text_type)
+ and not self.dialect.supports_unicode_statements
+ ):
+ stmt = self.dialect._encoder(stmt)[0]
+
+ if "schema_translate_map" in self.execution_options:
+ schema_translate_map = self.execution_options.get(
+ "schema_translate_map", {}
+ )
+
+ rst = self.identifier_preparer._render_schema_translates
+ stmt = rst(stmt, schema_translate_map)
+
+ if not parameters:
+ if self.dialect.positional:
+ parameters = self.dialect.execute_sequence_format()
+ else:
+ parameters = {}
+
+ conn._cursor_execute(self.cursor, stmt, parameters, context=self)
+ r = self.cursor.fetchone()[0]
+ if type_ is not None:
+ # apply type post processors to the result
+ proc = type_._cached_result_processor(
+ self.dialect, self.cursor.description[0][1]
+ )
+ if proc:
+ return proc(r)
+ return r
+
+ @property
+ def connection(self):
+ conn = self.root_connection
+ if conn._is_future:
+ return conn
+ else:
+ return conn._branch()
+
+ def should_autocommit_text(self, statement):
+ return AUTOCOMMIT_REGEXP.match(statement)
+
+ def _use_server_side_cursor(self):
+ if not self.dialect.supports_server_side_cursors:
+ return False
+
+ if self.dialect.server_side_cursors:
+ # this is deprecated
+ use_server_side = self.execution_options.get(
+ "stream_results", True
+ ) and (
+ (
+ self.compiled
+ and isinstance(
+ self.compiled.statement, expression.Selectable
+ )
+ or (
+ (
+ not self.compiled
+ or isinstance(
+ self.compiled.statement, expression.TextClause
+ )
+ )
+ and self.unicode_statement
+ and SERVER_SIDE_CURSOR_RE.match(self.unicode_statement)
+ )
+ )
+ )
+ else:
+ use_server_side = self.execution_options.get(
+ "stream_results", False
+ )
+
+ return use_server_side
+
+ def create_cursor(self):
+ if (
+ # inlining initial preference checks for SS cursors
+ self.dialect.supports_server_side_cursors
+ and (
+ self.execution_options.get("stream_results", False)
+ or (
+ self.dialect.server_side_cursors
+ and self._use_server_side_cursor()
+ )
+ )
+ ):
+ self._is_server_side = True
+ return self.create_server_side_cursor()
+ else:
+ self._is_server_side = False
+ return self.create_default_cursor()
+
+ def create_default_cursor(self):
+ return self._dbapi_connection.cursor()
+
+ def create_server_side_cursor(self):
+ raise NotImplementedError()
+
+ def pre_exec(self):
+ pass
+
+ def get_out_parameter_values(self, names):
+ raise NotImplementedError(
+ "This dialect does not support OUT parameters"
+ )
+
+ def post_exec(self):
+ pass
+
+ def get_result_processor(self, type_, colname, coltype):
+ """Return a 'result processor' for a given type as present in
+ cursor.description.
+
+ This has a default implementation that dialects can override
+ for context-sensitive result type handling.
+
+ """
+ return type_._cached_result_processor(self.dialect, coltype)
+
+ def get_lastrowid(self):
+ """return self.cursor.lastrowid, or equivalent, after an INSERT.
+
+ This may involve calling special cursor functions, issuing a new SELECT
+ on the cursor (or a new one), or returning a stored value that was
+ calculated within post_exec().
+
+ This function will only be called for dialects which support "implicit"
+ primary key generation, which keep preexecute_autoincrement_sequences set
+ to False, and only when no explicit id value was bound to the statement.
+
+ The function is called once for an INSERT statement that would need to
+ return the last inserted primary key for those dialects that make use
+ of the lastrowid concept. In these cases, it is called directly after
+ :meth:`.ExecutionContext.post_exec`.
+
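+ For example, a dialect whose DBAPI lacks ``cursor.lastrowid`` might
+ override this hook to issue its own SELECT (a sketch; the function name
+ shown is illustrative)::
+
+     def get_lastrowid(self):
+         self.cursor.execute("SELECT my_last_identity_function()")
+         return self.cursor.fetchone()[0]
+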
+ """
+ return self.cursor.lastrowid
+
+ def handle_dbapi_exception(self, e):
+ pass
+
+ @property
+ def rowcount(self):
+ return self.cursor.rowcount
+
+ def supports_sane_rowcount(self):
+ return self.dialect.supports_sane_rowcount
+
+ def supports_sane_multi_rowcount(self):
+ return self.dialect.supports_sane_multi_rowcount
+
+ def _setup_result_proxy(self):
+ exec_opt = self.execution_options
+
+ if self.is_crud or self.is_text:
+ result = self._setup_dml_or_text_result()
+ yp = sr = False
+ else:
+ yp = exec_opt.get("yield_per", None)
+ sr = self._is_server_side or exec_opt.get("stream_results", False)
+ strategy = self.cursor_fetch_strategy
+ if sr and strategy is _cursor._DEFAULT_FETCH:
+ strategy = _cursor.BufferedRowCursorFetchStrategy(
+ self.cursor, self.execution_options
+ )
+ cursor_description = (
+ strategy.alternate_cursor_description
+ or self.cursor.description
+ )
+ if cursor_description is None:
+ strategy = _cursor._NO_CURSOR_DQL
+
+ if self._is_future_result:
+ if self.root_connection.should_close_with_result:
+ raise exc.InvalidRequestError(
+ "can't use future_result=True with close_with_result"
+ )
+ result = _cursor.CursorResult(
+ self, strategy, cursor_description
+ )
+ else:
+ result = _cursor.LegacyCursorResult(
+ self, strategy, cursor_description
+ )
+
+ if (
+ self.compiled
+ and not self.isddl
+ and self.compiled.has_out_parameters
+ ):
+ self._setup_out_parameters(result)
+
+ self._soft_closed = result._soft_closed
+
+ if yp:
+ result = result.yield_per(yp)
+
+ return result
+
+ def _setup_out_parameters(self, result):
+
+ out_bindparams = [
+ (param, name)
+ for param, name in self.compiled.bind_names.items()
+ if param.isoutparam
+ ]
+ out_parameters = {}
+
+ for bindparam, raw_value in zip(
+ [param for param, name in out_bindparams],
+ self.get_out_parameter_values(
+ [name for param, name in out_bindparams]
+ ),
+ ):
+
+ type_ = bindparam.type
+ impl_type = type_.dialect_impl(self.dialect)
+ dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi)
+ result_processor = impl_type.result_processor(
+ self.dialect, dbapi_type
+ )
+ if result_processor is not None:
+ raw_value = result_processor(raw_value)
+ out_parameters[bindparam.key] = raw_value
+
+ result.out_parameters = out_parameters
+
+ def _setup_dml_or_text_result(self):
+ if self.isinsert:
+ if self.compiled.postfetch_lastrowid:
+ self.inserted_primary_key_rows = (
+ self._setup_ins_pk_from_lastrowid()
+ )
+ # else if not self._is_implicit_returning,
+ # the default inserted_primary_key_rows accessor will
+ # return an "empty" primary key collection when accessed.
+
+ strategy = self.cursor_fetch_strategy
+ if self._is_server_side and strategy is _cursor._DEFAULT_FETCH:
+ strategy = _cursor.BufferedRowCursorFetchStrategy(
+ self.cursor, self.execution_options
+ )
+ cursor_description = (
+ strategy.alternate_cursor_description or self.cursor.description
+ )
+ if cursor_description is None:
+ strategy = _cursor._NO_CURSOR_DML
+
+ if self._is_future_result:
+ result = _cursor.CursorResult(self, strategy, cursor_description)
+ else:
+ result = _cursor.LegacyCursorResult(
+ self, strategy, cursor_description
+ )
+
+ if self.isinsert:
+ if self._is_implicit_returning:
+ rows = result.all()
+
+ self.returned_default_rows = rows
+
+ self.inserted_primary_key_rows = (
+ self._setup_ins_pk_from_implicit_returning(result, rows)
+ )
+
+ # test that it has a cursor metadata that is accurate. the
+ # first row will have been fetched and current assumptions
+ # are that the result has only one row, until executemany()
+ # support is added here.
+ assert result._metadata.returns_rows
+ result._soft_close()
+ elif not self._is_explicit_returning:
+ result._soft_close()
+
+ # we assume here the result does not return any rows.
+ # *usually*, this will be true. However, some dialects
+ # such as that of MSSQL/pyodbc need to SELECT a post fetch
+ # function so this is not necessarily true.
+ # assert not result.returns_rows
+
+ elif self.isupdate and self._is_implicit_returning:
+ row = result.fetchone()
+ self.returned_default_rows = [row]
+ result._soft_close()
+
+ # test that it has a cursor metadata that is accurate.
+ # the rows have all been fetched however.
+ assert result._metadata.returns_rows
+
+ elif not result._metadata.returns_rows:
+ # no results, get rowcount
+ # (which requires open cursor on some drivers
+            # such as kinterbasdb, mxodbc)
+ result.rowcount
+ result._soft_close()
+ return result
+
+ @util.memoized_property
+ def inserted_primary_key_rows(self):
+ # if no specific "get primary key" strategy was set up
+ # during execution, return a "default" primary key based
+ # on what's in the compiled_parameters and nothing else.
+ return self._setup_ins_pk_from_empty()
+
+ def _setup_ins_pk_from_lastrowid(self):
+ getter = self.compiled._inserted_primary_key_from_lastrowid_getter
+
+ lastrowid = self.get_lastrowid()
+ return [getter(lastrowid, self.compiled_parameters[0])]
+
+ def _setup_ins_pk_from_empty(self):
+ getter = self.compiled._inserted_primary_key_from_lastrowid_getter
+ return [getter(None, param) for param in self.compiled_parameters]
+
+ def _setup_ins_pk_from_implicit_returning(self, result, rows):
+
+ if not rows:
+ return []
+
+ getter = self.compiled._inserted_primary_key_from_returning_getter
+ compiled_params = self.compiled_parameters
+
+ return [
+ getter(row, param) for row, param in zip(rows, compiled_params)
+ ]
+
+ def lastrow_has_defaults(self):
+ return (self.isinsert or self.isupdate) and bool(
+ self.compiled.postfetch
+ )
+
+ def _set_input_sizes(self):
+ """Given a cursor and ClauseParameters, call the appropriate
+ style of ``setinputsizes()`` on the cursor, using DB-API types
+ from the bind parameter's ``TypeEngine`` objects.
+
+        This method is only called by those dialects which require it,
+ currently cx_oracle, asyncpg and pg8000.
+
+ """
+ if self.isddl or self.is_text:
+ return
+
+ inputsizes = self.compiled._get_set_input_sizes_lookup(
+ include_types=self.include_set_input_sizes,
+ exclude_types=self.exclude_set_input_sizes,
+ )
+
+ if inputsizes is None:
+ return
+
+ if self.dialect._has_events:
+ inputsizes = dict(inputsizes)
+ self.dialect.dispatch.do_setinputsizes(
+ inputsizes, self.cursor, self.statement, self.parameters, self
+ )
+
+ has_escaped_names = bool(self.compiled.escaped_bind_names)
+ if has_escaped_names:
+ escaped_bind_names = self.compiled.escaped_bind_names
+
+ if self.dialect.positional:
+ items = [
+ (key, self.compiled.binds[key])
+ for key in self.compiled.positiontup
+ ]
+ else:
+ items = [
+ (key, bindparam)
+ for bindparam, key in self.compiled.bind_names.items()
+ ]
+
+ generic_inputsizes = []
+ for key, bindparam in items:
+ if bindparam in self.compiled.literal_execute_params:
+ continue
+
+ if key in self._expanded_parameters:
+ if bindparam.type._is_tuple_type:
+ num = len(bindparam.type.types)
+ dbtypes = inputsizes[bindparam]
+ generic_inputsizes.extend(
+ (
+ (
+ escaped_bind_names.get(paramname, paramname)
+ if has_escaped_names
+ else paramname
+ ),
+ dbtypes[idx % num],
+ bindparam.type.types[idx % num],
+ )
+ for idx, paramname in enumerate(
+ self._expanded_parameters[key]
+ )
+ )
+ else:
+ dbtype = inputsizes.get(bindparam, None)
+ generic_inputsizes.extend(
+ (
+ (
+ escaped_bind_names.get(paramname, paramname)
+ if has_escaped_names
+ else paramname
+ ),
+ dbtype,
+ bindparam.type,
+ )
+ for paramname in self._expanded_parameters[key]
+ )
+ else:
+ dbtype = inputsizes.get(bindparam, None)
+
+ escaped_name = (
+ escaped_bind_names.get(key, key)
+ if has_escaped_names
+ else key
+ )
+
+ generic_inputsizes.append(
+ (escaped_name, dbtype, bindparam.type)
+ )
+ try:
+ self.dialect.do_set_input_sizes(
+ self.cursor, generic_inputsizes, self
+ )
+ except BaseException as e:
+ self.root_connection._handle_dbapi_exception(
+ e, None, None, None, self
+ )
+
+ def _exec_default(self, column, default, type_):
+ if default.is_sequence:
+ return self.fire_sequence(default, type_)
+ elif default.is_callable:
+ self.current_column = column
+ return default.arg(self)
+ elif default.is_clause_element:
+ return self._exec_default_clause_element(column, default, type_)
+ else:
+ return default.arg
+
+ def _exec_default_clause_element(self, column, default, type_):
+ # execute a default that's a complete clause element. Here, we have
+ # to re-implement a miniature version of the compile->parameters->
+ # cursor.execute() sequence, since we don't want to modify the state
+ # of the connection / result in progress or create new connection/
+ # result objects etc.
+ # .. versionchanged:: 1.4
+
+ if not default._arg_is_typed:
+ default_arg = expression.type_coerce(default.arg, type_)
+ else:
+ default_arg = default.arg
+ compiled = expression.select(default_arg).compile(dialect=self.dialect)
+ compiled_params = compiled.construct_params()
+ processors = compiled._bind_processors
+ if compiled.positional:
+ positiontup = compiled.positiontup
+ parameters = self.dialect.execute_sequence_format(
+ [
+ processors[key](compiled_params[key])
+ if key in processors
+ else compiled_params[key]
+ for key in positiontup
+ ]
+ )
+ else:
+ parameters = dict(
+ (
+ key,
+ processors[key](compiled_params[key])
+ if key in processors
+ else compiled_params[key],
+ )
+ for key in compiled_params
+ )
+ return self._execute_scalar(
+ util.text_type(compiled), type_, parameters=parameters
+ )
+
+ current_parameters = None
+ """A dictionary of parameters applied to the current row.
+
+ This attribute is only available in the context of a user-defined default
+ generation function, e.g. as described at :ref:`context_default_functions`.
+ It consists of a dictionary which includes entries for each column/value
+ pair that is to be part of the INSERT or UPDATE statement. The keys of the
+ dictionary will be the key value of each :class:`_schema.Column`,
+ which is usually
+ synonymous with the name.
+
+ Note that the :attr:`.DefaultExecutionContext.current_parameters` attribute
+ does not accommodate for the "multi-values" feature of the
+ :meth:`_expression.Insert.values` method. The
+ :meth:`.DefaultExecutionContext.get_current_parameters` method should be
+ preferred.
+
+ .. seealso::
+
+ :meth:`.DefaultExecutionContext.get_current_parameters`
+
+ :ref:`context_default_functions`
+
+ """
+
+ def get_current_parameters(self, isolate_multiinsert_groups=True):
+ """Return a dictionary of parameters applied to the current row.
+
+ This method can only be used in the context of a user-defined default
+ generation function, e.g. as described at
+ :ref:`context_default_functions`. When invoked, a dictionary is
+ returned which includes entries for each column/value pair that is part
+ of the INSERT or UPDATE statement. The keys of the dictionary will be
+ the key value of each :class:`_schema.Column`,
+ which is usually synonymous
+ with the name.
+
+ :param isolate_multiinsert_groups=True: indicates that multi-valued
+ INSERT constructs created using :meth:`_expression.Insert.values`
+ should be
+ handled by returning only the subset of parameters that are local
+ to the current column default invocation. When ``False``, the
+ raw parameters of the statement are returned including the
+ naming convention used in the case of multi-valued INSERT.
+
+ .. versionadded:: 1.2 added
+ :meth:`.DefaultExecutionContext.get_current_parameters`
+ which provides more functionality over the existing
+ :attr:`.DefaultExecutionContext.current_parameters`
+ attribute.
+
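+        As a minimal sketch, a context-sensitive default function might
+        derive its value from a sibling column of the row being inserted
+        (the table and column names here are hypothetical)::
+
+            from sqlalchemy import Column, Integer, MetaData, String, Table
+
+            def slug_default(context):
+                # read the full set of parameters for the current row
+                params = context.get_current_parameters()
+                return params["name"].lower().replace(" ", "-")
+
+            documents = Table(
+                "documents",
+                MetaData(),
+                Column("id", Integer, primary_key=True),
+                Column("name", String(50)),
+                Column("slug", String(60), default=slug_default),
+            )
+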
+ .. seealso::
+
+ :attr:`.DefaultExecutionContext.current_parameters`
+
+ :ref:`context_default_functions`
+
+ """
+ try:
+ parameters = self.current_parameters
+ column = self.current_column
+ except AttributeError:
+ raise exc.InvalidRequestError(
+ "get_current_parameters() can only be invoked in the "
+ "context of a Python side column default function"
+ )
+
+ compile_state = self.compiled.compile_state
+ if (
+ isolate_multiinsert_groups
+ and self.isinsert
+ and compile_state._has_multi_parameters
+ ):
+ if column._is_multiparam_column:
+ index = column.index + 1
+ d = {column.original.key: parameters[column.key]}
+ else:
+ d = {column.key: parameters[column.key]}
+ index = 0
+ keys = compile_state._dict_parameters.keys()
+ d.update(
+ (key, parameters["%s_m%d" % (key, index)]) for key in keys
+ )
+ return d
+ else:
+ return parameters
+
+ def get_insert_default(self, column):
+ if column.default is None:
+ return None
+ else:
+ return self._exec_default(column, column.default, column.type)
+
+ def get_update_default(self, column):
+ if column.onupdate is None:
+ return None
+ else:
+ return self._exec_default(column, column.onupdate, column.type)
+
+ def _process_executemany_defaults(self):
+ key_getter = self.compiled._within_exec_param_key_getter
+
+ scalar_defaults = {}
+
+ insert_prefetch = self.compiled.insert_prefetch
+ update_prefetch = self.compiled.update_prefetch
+
+ # pre-determine scalar Python-side defaults
+ # to avoid many calls of get_insert_default()/
+ # get_update_default()
+ for c in insert_prefetch:
+ if c.default and not c.default.is_sequence and c.default.is_scalar:
+ scalar_defaults[c] = c.default.arg
+
+ for c in update_prefetch:
+ if c.onupdate and c.onupdate.is_scalar:
+ scalar_defaults[c] = c.onupdate.arg
+
+ for param in self.compiled_parameters:
+ self.current_parameters = param
+ for c in insert_prefetch:
+ if c in scalar_defaults:
+ val = scalar_defaults[c]
+ else:
+ val = self.get_insert_default(c)
+ if val is not None:
+ param[key_getter(c)] = val
+ for c in update_prefetch:
+ if c in scalar_defaults:
+ val = scalar_defaults[c]
+ else:
+ val = self.get_update_default(c)
+ if val is not None:
+ param[key_getter(c)] = val
+
+ del self.current_parameters
+
+ def _process_executesingle_defaults(self):
+ key_getter = self.compiled._within_exec_param_key_getter
+ self.current_parameters = (
+ compiled_parameters
+ ) = self.compiled_parameters[0]
+
+ for c in self.compiled.insert_prefetch:
+ if c.default and not c.default.is_sequence and c.default.is_scalar:
+ val = c.default.arg
+ else:
+ val = self.get_insert_default(c)
+
+ if val is not None:
+ compiled_parameters[key_getter(c)] = val
+
+ for c in self.compiled.update_prefetch:
+ val = self.get_update_default(c)
+
+ if val is not None:
+ compiled_parameters[key_getter(c)] = val
+ del self.current_parameters
+
+
+DefaultDialect.execution_ctx_cls = DefaultExecutionContext
diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py
new file mode 100644
index 0000000..286c4d4
--- /dev/null
+++ b/lib/sqlalchemy/engine/events.py
@@ -0,0 +1,835 @@
+# sqlalchemy/engine/events.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+from .base import Engine
+from .interfaces import Connectable
+from .interfaces import Dialect
+from .. import event
+from .. import exc
+
+
+class ConnectionEvents(event.Events):
+ """Available events for :class:`.Connectable`, which includes
+ :class:`_engine.Connection` and :class:`_engine.Engine`.
+
+ The methods here define the name of an event as well as the names of
+ members that are passed to listener functions.
+
+ An event listener can be associated with any :class:`.Connectable`
+ class or instance, such as an :class:`_engine.Engine`, e.g.::
+
+ from sqlalchemy import event, create_engine
+
+ def before_cursor_execute(conn, cursor, statement, parameters, context,
+ executemany):
+ log.info("Received statement: %s", statement)
+
+ engine = create_engine('postgresql://scott:tiger@localhost/test')
+ event.listen(engine, "before_cursor_execute", before_cursor_execute)
+
+ or with a specific :class:`_engine.Connection`::
+
+ with engine.begin() as conn:
+ @event.listens_for(conn, 'before_cursor_execute')
+ def before_cursor_execute(conn, cursor, statement, parameters,
+ context, executemany):
+ log.info("Received statement: %s", statement)
+
+ When the methods are called with a `statement` parameter, such as in
+ :meth:`.after_cursor_execute` or :meth:`.before_cursor_execute`,
+ the statement is the exact SQL string that was prepared for transmission
+ to the DBAPI ``cursor`` in the connection's :class:`.Dialect`.
+
+ The :meth:`.before_execute` and :meth:`.before_cursor_execute`
+ events can also be established with the ``retval=True`` flag, which
+ allows modification of the statement and parameters to be sent
+ to the database. The :meth:`.before_cursor_execute` event is
+ particularly useful here to add ad-hoc string transformations, such
+ as comments, to all executions::
+
+ from sqlalchemy.engine import Engine
+ from sqlalchemy import event
+
+ @event.listens_for(Engine, "before_cursor_execute", retval=True)
+ def comment_sql_calls(conn, cursor, statement, parameters,
+ context, executemany):
+ statement = statement + " -- some comment"
+ return statement, parameters
+
+ .. note:: :class:`_events.ConnectionEvents` can be established on any
+ combination of :class:`_engine.Engine`, :class:`_engine.Connection`,
+ as well
+ as instances of each of those classes. Events across all
+ four scopes will fire off for a given instance of
+ :class:`_engine.Connection`. However, for performance reasons, the
+ :class:`_engine.Connection` object determines at instantiation time
+ whether or not its parent :class:`_engine.Engine` has event listeners
+ established. Event listeners added to the :class:`_engine.Engine`
+ class or to an instance of :class:`_engine.Engine`
+ *after* the instantiation
+ of a dependent :class:`_engine.Connection` instance will usually
+ *not* be available on that :class:`_engine.Connection` instance.
+ The newly
+ added listeners will instead take effect for
+ :class:`_engine.Connection`
+ instances created subsequent to those event listeners being
+ established on the parent :class:`_engine.Engine` class or instance.
+
+ :param retval=False: Applies to the :meth:`.before_execute` and
+ :meth:`.before_cursor_execute` events only. When True, the
+ user-defined event function must have a return value, which
+ is a tuple of parameters that replace the given statement
+ and parameters. See those methods for a description of
+ specific return arguments.
+
+ """
+
+ _target_class_doc = "SomeEngine"
+ _dispatch_target = Connectable
+
+ @classmethod
+ def _listen(cls, event_key, retval=False):
+ target, identifier, fn = (
+ event_key.dispatch_target,
+ event_key.identifier,
+ event_key._listen_fn,
+ )
+
+ target._has_events = True
+
+ if not retval:
+ if identifier == "before_execute":
+ orig_fn = fn
+
+ def wrap_before_execute(
+ conn, clauseelement, multiparams, params, execution_options
+ ):
+ orig_fn(
+ conn,
+ clauseelement,
+ multiparams,
+ params,
+ execution_options,
+ )
+ return clauseelement, multiparams, params
+
+ fn = wrap_before_execute
+ elif identifier == "before_cursor_execute":
+ orig_fn = fn
+
+ def wrap_before_cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
+ orig_fn(
+ conn,
+ cursor,
+ statement,
+ parameters,
+ context,
+ executemany,
+ )
+ return statement, parameters
+
+ fn = wrap_before_cursor_execute
+ elif retval and identifier not in (
+ "before_execute",
+ "before_cursor_execute",
+ "handle_error",
+ ):
+ raise exc.ArgumentError(
+ "Only the 'before_execute', "
+ "'before_cursor_execute' and 'handle_error' engine "
+ "event listeners accept the 'retval=True' "
+ "argument."
+ )
+ event_key.with_wrapper(fn).base_listen()
+
+ @event._legacy_signature(
+ "1.4",
+ ["conn", "clauseelement", "multiparams", "params"],
+ lambda conn, clauseelement, multiparams, params, execution_options: (
+ conn,
+ clauseelement,
+ multiparams,
+ params,
+ ),
+ )
+ def before_execute(
+ self, conn, clauseelement, multiparams, params, execution_options
+ ):
+ """Intercept high level execute() events, receiving uncompiled
+ SQL constructs and other objects prior to rendering into SQL.
+
+ This event is good for debugging SQL compilation issues as well
+ as early manipulation of the parameters being sent to the database,
+ as the parameter lists will be in a consistent format here.
+
+ This event can be optionally established with the ``retval=True``
+ flag. The ``clauseelement``, ``multiparams``, and ``params``
+ arguments should be returned as a three-tuple in this case::
+
+ @event.listens_for(Engine, "before_execute", retval=True)
+ def before_execute(conn, clauseelement, multiparams, params):
+ # do something with clauseelement, multiparams, params
+ return clauseelement, multiparams, params
+
+ :param conn: :class:`_engine.Connection` object
+ :param clauseelement: SQL expression construct, :class:`.Compiled`
+ instance, or string statement passed to
+ :meth:`_engine.Connection.execute`.
+ :param multiparams: Multiple parameter sets, a list of dictionaries.
+ :param params: Single parameter set, a single dictionary.
+ :param execution_options: dictionary of execution
+ options passed along with the statement, if any. This is a merge
+ of all options that will be used, including those of the statement,
+ the connection, and those passed in to the method itself for
+ the 2.0 style of execution.
+
+        .. versionadded:: 1.4
+
+ .. seealso::
+
+ :meth:`.before_cursor_execute`
+
+ """
+
+ @event._legacy_signature(
+ "1.4",
+ ["conn", "clauseelement", "multiparams", "params", "result"],
+ lambda conn, clauseelement, multiparams, params, execution_options, result: ( # noqa
+ conn,
+ clauseelement,
+ multiparams,
+ params,
+ result,
+ ),
+ )
+ def after_execute(
+ self,
+ conn,
+ clauseelement,
+ multiparams,
+ params,
+ execution_options,
+ result,
+ ):
+ """Intercept high level execute() events after execute.
+
+
+ :param conn: :class:`_engine.Connection` object
+ :param clauseelement: SQL expression construct, :class:`.Compiled`
+ instance, or string statement passed to
+ :meth:`_engine.Connection.execute`.
+ :param multiparams: Multiple parameter sets, a list of dictionaries.
+ :param params: Single parameter set, a single dictionary.
+ :param execution_options: dictionary of execution
+ options passed along with the statement, if any. This is a merge
+ of all options that will be used, including those of the statement,
+ the connection, and those passed in to the method itself for
+ the 2.0 style of execution.
+
+        .. versionadded:: 1.4
+
+ :param result: :class:`_engine.CursorResult` generated by the
+ execution.
+
+ """
+
+ def before_cursor_execute(
+ self, conn, cursor, statement, parameters, context, executemany
+ ):
+ """Intercept low-level cursor execute() events before execution,
+ receiving the string SQL statement and DBAPI-specific parameter list to
+ be invoked against a cursor.
+
+ This event is a good choice for logging as well as late modifications
+ to the SQL string. It's less ideal for parameter modifications except
+ for those which are specific to a target backend.
+
+ This event can be optionally established with the ``retval=True``
+ flag. The ``statement`` and ``parameters`` arguments should be
+ returned as a two-tuple in this case::
+
+ @event.listens_for(Engine, "before_cursor_execute", retval=True)
+ def before_cursor_execute(conn, cursor, statement,
+ parameters, context, executemany):
+ # do something with statement, parameters
+ return statement, parameters
+
+ See the example at :class:`_events.ConnectionEvents`.
+
+ :param conn: :class:`_engine.Connection` object
+ :param cursor: DBAPI cursor object
+ :param statement: string SQL statement, as to be passed to the DBAPI
+ :param parameters: Dictionary, tuple, or list of parameters being
+ passed to the ``execute()`` or ``executemany()`` method of the
+ DBAPI ``cursor``. In some cases may be ``None``.
+ :param context: :class:`.ExecutionContext` object in use. May
+ be ``None``.
+ :param executemany: boolean, if ``True``, this is an ``executemany()``
+ call, if ``False``, this is an ``execute()`` call.
+
+ .. seealso::
+
+ :meth:`.before_execute`
+
+ :meth:`.after_cursor_execute`
+
+ """
+
+ def after_cursor_execute(
+ self, conn, cursor, statement, parameters, context, executemany
+ ):
+ """Intercept low-level cursor execute() events after execution.
+
+ :param conn: :class:`_engine.Connection` object
+ :param cursor: DBAPI cursor object. Will have results pending
+ if the statement was a SELECT, but these should not be consumed
+ as they will be needed by the :class:`_engine.CursorResult`.
+ :param statement: string SQL statement, as passed to the DBAPI
+ :param parameters: Dictionary, tuple, or list of parameters being
+ passed to the ``execute()`` or ``executemany()`` method of the
+ DBAPI ``cursor``. In some cases may be ``None``.
+ :param context: :class:`.ExecutionContext` object in use. May
+ be ``None``.
+ :param executemany: boolean, if ``True``, this is an ``executemany()``
+ call, if ``False``, this is an ``execute()`` call.
+
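+        A minimal sketch of a common use, timing statements in combination
+        with :meth:`.before_cursor_execute` (``log`` is assumed to be a
+        pre-configured logger)::
+
+            import time
+
+            from sqlalchemy import event
+            from sqlalchemy.engine import Engine
+
+            @event.listens_for(Engine, "before_cursor_execute")
+            def _start_timer(
+                conn, cursor, statement, parameters, context, executemany
+            ):
+                conn.info.setdefault("query_start_time", []).append(
+                    time.time()
+                )
+
+            @event.listens_for(Engine, "after_cursor_execute")
+            def _log_elapsed(
+                conn, cursor, statement, parameters, context, executemany
+            ):
+                elapsed = time.time() - conn.info["query_start_time"].pop()
+                log.debug("query took %f seconds", elapsed)
+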
+ """
+
+ def handle_error(self, exception_context):
+ r"""Intercept all exceptions processed by the
+ :class:`_engine.Connection`.
+
+ This includes all exceptions emitted by the DBAPI as well as
+ within SQLAlchemy's statement invocation process, including
+ encoding errors and other statement validation errors. Other areas
+ in which the event is invoked include transaction begin and end,
+ result row fetching, cursor creation.
+
+ Note that :meth:`.handle_error` may support new kinds of exceptions
+ and new calling scenarios at *any time*. Code which uses this
+ event must expect new calling patterns to be present in minor
+ releases.
+
+ To support the wide variety of members that correspond to an exception,
+        as well as to allow the event to be extended without
+        backwards-incompatible changes, the sole argument received is an
+        instance of
+ :class:`.ExceptionContext`. This object contains data members
+ representing detail about the exception.
+
+ Use cases supported by this hook include:
+
+ * read-only, low-level exception handling for logging and
+ debugging purposes
+ * exception re-writing
+ * Establishing or disabling whether a connection or the owning
+ connection pool is invalidated or expired in response to a
+ specific exception [1]_.
+
+ The hook is called while the cursor from the failed operation
+ (if any) is still open and accessible. Special cleanup operations
+ can be called on this cursor; SQLAlchemy will attempt to close
+ this cursor subsequent to this hook being invoked. If the connection
+ is in "autocommit" mode, the transaction also remains open within
+ the scope of this hook; the rollback of the per-statement transaction
+ also occurs after the hook is called.
+
+ .. note::
+
+ .. [1] The pool "pre_ping" handler enabled using the
+ :paramref:`_sa.create_engine.pool_pre_ping` parameter does
+ **not** consult this event before deciding if the "ping"
+ returned false, as opposed to receiving an unhandled error.
+ For this use case, the :ref:`legacy recipe based on
+ engine_connect() may be used
+            <pool_disconnects_pessimistic_custom>`. A future API will allow
+ more comprehensive customization of the "disconnect"
+ detection mechanism across all functions.
+
+ A handler function has two options for replacing
+ the SQLAlchemy-constructed exception into one that is user
+ defined. It can either raise this new exception directly, in
+ which case all further event listeners are bypassed and the
+        exception will be raised, after appropriate cleanup has taken
+        place::
+
+ @event.listens_for(Engine, "handle_error")
+ def handle_exception(context):
+ if isinstance(context.original_exception,
+ psycopg2.OperationalError) and \
+ "failed" in str(context.original_exception):
+ raise MySpecialException("failed operation")
+
+ .. warning:: Because the
+ :meth:`_events.ConnectionEvents.handle_error`
+ event specifically provides for exceptions to be re-thrown as
+ the ultimate exception raised by the failed statement,
+ **stack traces will be misleading** if the user-defined event
+ handler itself fails and throws an unexpected exception;
+ the stack trace may not illustrate the actual code line that
+ failed! It is advised to code carefully here and use
+ logging and/or inline debugging if unexpected exceptions are
+ occurring.
+
+ Alternatively, a "chained" style of event handling can be
+ used, by configuring the handler with the ``retval=True``
+ modifier and returning the new exception instance from the
+ function. In this case, event handling will continue onto the
+ next handler. The "chained" exception is available using
+ :attr:`.ExceptionContext.chained_exception`::
+
+ @event.listens_for(Engine, "handle_error", retval=True)
+ def handle_exception(context):
+ if context.chained_exception is not None and \
+ "special" in context.chained_exception.message:
+ return MySpecialException("failed",
+ cause=context.chained_exception)
+
+ Handlers that return ``None`` may be used within the chain; when
+ a handler returns ``None``, the previous exception instance,
+ if any, is maintained as the current exception that is passed onto the
+ next handler.
+
+ When a custom exception is raised or returned, SQLAlchemy raises
+        this new exception as-is; it is not wrapped by any SQLAlchemy
+ object. If the exception is not a subclass of
+ :class:`sqlalchemy.exc.StatementError`,
+ certain features may not be available; currently this includes
+ the ORM's feature of adding a detail hint about "autoflush" to
+ exceptions raised within the autoflush process.
+
+ :param context: an :class:`.ExceptionContext` object. See this
+ class for details on all available members.
+
+ .. versionadded:: 0.9.7 Added the
+ :meth:`_events.ConnectionEvents.handle_error` hook.
+
+ .. versionchanged:: 1.1 The :meth:`.handle_error` event will now
+ receive all exceptions that inherit from ``BaseException``,
+ including ``SystemExit`` and ``KeyboardInterrupt``. The setting for
+ :attr:`.ExceptionContext.is_disconnect` is ``True`` in this case and
+ the default for
+ :attr:`.ExceptionContext.invalidate_pool_on_disconnect` is
+ ``False``.
+
+ .. versionchanged:: 1.0.0 The :meth:`.handle_error` event is now
+ invoked when an :class:`_engine.Engine` fails during the initial
+ call to :meth:`_engine.Engine.connect`, as well as when a
+ :class:`_engine.Connection` object encounters an error during a
+ reconnect operation.
+
+ .. versionchanged:: 1.0.0 The :meth:`.handle_error` event is
+ not fired off when a dialect makes use of the
+ ``skip_user_error_events`` execution option. This is used
+ by dialects which intend to catch SQLAlchemy-specific exceptions
+ within specific operations, such as when the MySQL dialect detects
+ a table not present within the ``has_table()`` dialect method.
+ Prior to 1.0.0, code which implements :meth:`.handle_error` needs
+ to ensure that exceptions thrown in these scenarios are re-raised
+ without modification.
+
+ """
+
+ def engine_connect(self, conn, branch):
+ """Intercept the creation of a new :class:`_engine.Connection`.
+
+ This event is called typically as the direct result of calling
+ the :meth:`_engine.Engine.connect` method.
+
+ It differs from the :meth:`_events.PoolEvents.connect` method, which
+ refers to the actual connection to a database at the DBAPI level;
+ a DBAPI connection may be pooled and reused for many operations.
+ In contrast, this event refers only to the production of a higher level
+ :class:`_engine.Connection` wrapper around such a DBAPI connection.
+
+ It also differs from the :meth:`_events.PoolEvents.checkout` event
+ in that it is specific to the :class:`_engine.Connection` object,
+ not the
+ DBAPI connection that :meth:`_events.PoolEvents.checkout` deals with,
+ although
+ this DBAPI connection is available here via the
+ :attr:`_engine.Connection.connection` attribute.
+ But note there can in fact
+ be multiple :meth:`_events.PoolEvents.checkout`
+ events within the lifespan
+ of a single :class:`_engine.Connection` object, if that
+ :class:`_engine.Connection`
+ is invalidated and re-established. There can also be multiple
+ :class:`_engine.Connection`
+ objects generated for the same already-checked-out
+ DBAPI connection, in the case that a "branch" of a
+ :class:`_engine.Connection`
+ is produced.
+
+ :param conn: :class:`_engine.Connection` object.
+ :param branch: if True, this is a "branch" of an existing
+ :class:`_engine.Connection`. A branch is generated within the course
+ of a statement execution to invoke supplemental statements, most
+ typically to pre-execute a SELECT of a default value for the purposes
+ of an INSERT statement.
+
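+        As a minimal sketch (``engine`` is assumed to be an existing
+        :class:`_engine.Engine`), a listener might act only on top-level,
+        non-branched connections::
+
+            import time
+
+            from sqlalchemy import event
+
+            @event.listens_for(engine, "engine_connect")
+            def receive_engine_connect(conn, branch):
+                if branch:
+                    # skip supplemental "branch" connections
+                    return
+                # record when this Connection wrapper was produced
+                conn.info["connected_at"] = time.time()
+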
+ .. seealso::
+
+ :meth:`_events.PoolEvents.checkout`
+ the lower-level pool checkout event
+ for an individual DBAPI connection
+
+ """
+
+ def set_connection_execution_options(self, conn, opts):
+ """Intercept when the :meth:`_engine.Connection.execution_options`
+ method is called.
+
+ This method is called after the new :class:`_engine.Connection`
+ has been
+ produced, with the newly updated execution options collection, but
+ before the :class:`.Dialect` has acted upon any of those new options.
+
+ Note that this method is not called when a new
+ :class:`_engine.Connection`
+ is produced which is inheriting execution options from its parent
+ :class:`_engine.Engine`; to intercept this condition, use the
+ :meth:`_events.ConnectionEvents.engine_connect` event.
+
+ :param conn: The newly copied :class:`_engine.Connection` object
+
+ :param opts: dictionary of options that were passed to the
+ :meth:`_engine.Connection.execution_options` method.
+
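+        A minimal sketch of a listener that only inspects the new options
+        (``log`` is assumed to be a pre-configured logger)::
+
+            from sqlalchemy import event
+            from sqlalchemy.engine import Engine
+
+            @event.listens_for(Engine, "set_connection_execution_options")
+            def receive_options(conn, opts):
+                if "isolation_level" in opts:
+                    log.info(
+                        "isolation level set to %s", opts["isolation_level"]
+                    )
+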
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :meth:`_events.ConnectionEvents.set_engine_execution_options`
+ - event
+ which is called when :meth:`_engine.Engine.execution_options`
+ is called.
+
+
+ """
+
+ def set_engine_execution_options(self, engine, opts):
+ """Intercept when the :meth:`_engine.Engine.execution_options`
+ method is called.
+
+ The :meth:`_engine.Engine.execution_options` method produces a shallow
+ copy of the :class:`_engine.Engine` which stores the new options.
+ That new
+ :class:`_engine.Engine` is passed here.
+ A particular application of this
+ method is to add a :meth:`_events.ConnectionEvents.engine_connect`
+ event
+ handler to the given :class:`_engine.Engine`
+ which will perform some per-
+ :class:`_engine.Connection` task specific to these execution options.
+
+        :param engine: The newly copied :class:`_engine.Engine` object
+
+        :param opts: dictionary of options that were passed to the
+          :meth:`_engine.Engine.execution_options` method.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :meth:`_events.ConnectionEvents.set_connection_execution_options`
+ - event
+ which is called when :meth:`_engine.Connection.execution_options`
+ is
+ called.
+
+ """
+
+ def engine_disposed(self, engine):
+ """Intercept when the :meth:`_engine.Engine.dispose` method is called.
+
+ The :meth:`_engine.Engine.dispose` method instructs the engine to
+ "dispose" of it's connection pool (e.g. :class:`_pool.Pool`), and
+ replaces it with a new one. Disposing of the old pool has the
+ effect that existing checked-in connections are closed. The new
+ pool does not establish any new connections until it is first used.
+
+ This event can be used to indicate that resources related to the
+ :class:`_engine.Engine` should also be cleaned up,
+ keeping in mind that the
+ :class:`_engine.Engine`
+ can still be used for new requests in which case
+ it re-acquires connection resources.
+
+ .. versionadded:: 1.0.5
+
+ """
+
+ def begin(self, conn):
+ """Intercept begin() events.
+
+ :param conn: :class:`_engine.Connection` object
+
+ """
+
+ def rollback(self, conn):
+ """Intercept rollback() events, as initiated by a
+ :class:`.Transaction`.
+
+ Note that the :class:`_pool.Pool` also "auto-rolls back"
+ a DBAPI connection upon checkin, if the ``reset_on_return``
+ flag is set to its default value of ``'rollback'``.
+ To intercept this
+ rollback, use the :meth:`_events.PoolEvents.reset` hook.
+
+ :param conn: :class:`_engine.Connection` object
+
+ .. seealso::
+
+ :meth:`_events.PoolEvents.reset`
+
+ """
+
+ def commit(self, conn):
+ """Intercept commit() events, as initiated by a
+ :class:`.Transaction`.
+
+ Note that the :class:`_pool.Pool` may also "auto-commit"
+ a DBAPI connection upon checkin, if the ``reset_on_return``
+ flag is set to the value ``'commit'``. To intercept this
+ commit, use the :meth:`_events.PoolEvents.reset` hook.
+
+ :param conn: :class:`_engine.Connection` object
+ """
+
+ def savepoint(self, conn, name):
+ """Intercept savepoint() events.
+
+ :param conn: :class:`_engine.Connection` object
+ :param name: specified name used for the savepoint.
+
+ """
+
+ def rollback_savepoint(self, conn, name, context):
+ """Intercept rollback_savepoint() events.
+
+ :param conn: :class:`_engine.Connection` object
+ :param name: specified name used for the savepoint.
+ :param context: not used
+
+ """
+ # TODO: deprecate "context"
+
+ def release_savepoint(self, conn, name, context):
+ """Intercept release_savepoint() events.
+
+ :param conn: :class:`_engine.Connection` object
+ :param name: specified name used for the savepoint.
+ :param context: not used
+
+ """
+ # TODO: deprecate "context"
+
+ def begin_twophase(self, conn, xid):
+ """Intercept begin_twophase() events.
+
+ :param conn: :class:`_engine.Connection` object
+ :param xid: two-phase XID identifier
+
+ """
+
+ def prepare_twophase(self, conn, xid):
+ """Intercept prepare_twophase() events.
+
+ :param conn: :class:`_engine.Connection` object
+ :param xid: two-phase XID identifier
+ """
+
+ def rollback_twophase(self, conn, xid, is_prepared):
+ """Intercept rollback_twophase() events.
+
+ :param conn: :class:`_engine.Connection` object
+ :param xid: two-phase XID identifier
+ :param is_prepared: boolean, indicates if
+ :meth:`.TwoPhaseTransaction.prepare` was called.
+
+ """
+
+ def commit_twophase(self, conn, xid, is_prepared):
+ """Intercept commit_twophase() events.
+
+ :param conn: :class:`_engine.Connection` object
+ :param xid: two-phase XID identifier
+ :param is_prepared: boolean, indicates if
+ :meth:`.TwoPhaseTransaction.prepare` was called.
+
+ """
+
+
+class DialectEvents(event.Events):
+ """event interface for execution-replacement functions.
+
+ These events allow direct instrumentation and replacement
+ of key dialect functions which interact with the DBAPI.
+
+ .. note::
+
+ :class:`.DialectEvents` hooks should be considered **semi-public**
+ and experimental.
+ These hooks are not for general use and are only for those situations
+ where intricate re-statement of DBAPI mechanics must be injected onto
+ an existing dialect. For general-use statement-interception events,
+ please use the :class:`_events.ConnectionEvents` interface.
+
+ .. seealso::
+
+ :meth:`_events.ConnectionEvents.before_cursor_execute`
+
+ :meth:`_events.ConnectionEvents.before_execute`
+
+ :meth:`_events.ConnectionEvents.after_cursor_execute`
+
+ :meth:`_events.ConnectionEvents.after_execute`
+
+
+ .. versionadded:: 0.9.4
+
+ """
+
+ _target_class_doc = "SomeEngine"
+ _dispatch_target = Dialect
+
+ @classmethod
+ def _listen(cls, event_key, retval=False):
+ target = event_key.dispatch_target
+
+ target._has_events = True
+ event_key.base_listen()
+
+ @classmethod
+ def _accept_with(cls, target):
+ if isinstance(target, type):
+ if issubclass(target, Engine):
+ return Dialect
+ elif issubclass(target, Dialect):
+ return target
+ elif isinstance(target, Engine):
+ return target.dialect
+ elif isinstance(target, Dialect):
+ return target
+ elif hasattr(target, "dispatch") and hasattr(
+ target.dispatch._events, "_no_async_engine_events"
+ ):
+ target.dispatch._events._no_async_engine_events()
+ else:
+ return None
+
+ def do_connect(self, dialect, conn_rec, cargs, cparams):
+ """Receive connection arguments before a connection is made.
+
+ This event is useful in that it allows the handler to manipulate the
+ cargs and/or cparams collections that control how the DBAPI
+ ``connect()`` function will be called. ``cargs`` will always be a
+ Python list that can be mutated in-place, and ``cparams`` a Python
+ dictionary that may also be mutated::
+
+ e = create_engine("postgresql+psycopg2://user@host/dbname")
+
+ @event.listens_for(e, 'do_connect')
+ def receive_do_connect(dialect, conn_rec, cargs, cparams):
+ cparams["password"] = "some_password"
+
+ The event hook may also be used to override the call to ``connect()``
+ entirely, by returning a non-``None`` DBAPI connection object::
+
+ e = create_engine("postgresql+psycopg2://user@host/dbname")
+
+ @event.listens_for(e, 'do_connect')
+ def receive_do_connect(dialect, conn_rec, cargs, cparams):
+ return psycopg2.connect(*cargs, **cparams)
+
+
+ .. versionadded:: 1.0.3
+
+ .. seealso::
+
+ :ref:`custom_dbapi_args`
+
+ """
+
+ def do_executemany(self, cursor, statement, parameters, context):
+ """Receive a cursor to have executemany() called.
+
+ Return the value True to halt further events from invoking,
+ and to indicate that the cursor execution has already taken
+ place within the event handler.
+
+ """
+
+ def do_execute_no_params(self, cursor, statement, context):
+ """Receive a cursor to have execute() with no parameters called.
+
+ Return the value True to halt further events from invoking,
+ and to indicate that the cursor execution has already taken
+ place within the event handler.
+
+ """
+
+ def do_execute(self, cursor, statement, parameters, context):
+ """Receive a cursor to have execute() called.
+
+ Return the value True to halt further events from invoking,
+ and to indicate that the cursor execution has already taken
+ place within the event handler.
+
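+        A minimal sketch that takes over statement execution while logging
+        it (``some_engine`` and ``log`` are assumed to exist already)::
+
+            from sqlalchemy import event
+
+            @event.listens_for(some_engine, "do_execute")
+            def receive_do_execute(cursor, statement, parameters, context):
+                log.debug("executing: %s", statement)
+                cursor.execute(statement, parameters)
+                return True
+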
+ """
+
+ def do_setinputsizes(
+ self, inputsizes, cursor, statement, parameters, context
+ ):
+ """Receive the setinputsizes dictionary for possible modification.
+
+ This event is emitted in the case where the dialect makes use of the
+ DBAPI ``cursor.setinputsizes()`` method which passes information about
+ parameter binding for a particular statement. The given
+ ``inputsizes`` dictionary will contain :class:`.BindParameter` objects
+ as keys, linked to DBAPI-specific type objects as values; for
+ parameters that are not bound, they are added to the dictionary with
+ ``None`` as the value, which means the parameter will not be included
+ in the ultimate setinputsizes call. The event may be used to inspect
+ and/or log the datatypes that are being bound, as well as to modify the
+ dictionary in place. Parameters can be added, modified, or removed
+ from this dictionary. Callers will typically want to inspect the
+ :attr:`.BindParameter.type` attribute of the given bind objects in
+ order to make decisions about the DBAPI object.
+
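+        As a minimal sketch (assuming the cx_Oracle dialect and its ``CLOB``
+        DBAPI type, with ``engine`` already created), a listener could remove
+        a particular DBAPI type so that no setinputsizes entry is generated
+        for it::
+
+            from cx_Oracle import CLOB
+
+            from sqlalchemy import event
+
+            @event.listens_for(engine, "do_setinputsizes")
+            def _remove_clob(
+                inputsizes, cursor, statement, parameters, context
+            ):
+                for bindparam, dbapitype in list(inputsizes.items()):
+                    if dbapitype is CLOB:
+                        del inputsizes[bindparam]
+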
+ After the event, the ``inputsizes`` dictionary is converted into
+ an appropriate datastructure to be passed to ``cursor.setinputsizes``;
+ either a list for a positional bound parameter execution style,
+ or a dictionary of string parameter keys to DBAPI type objects for
+ a named bound parameter execution style.
+
+ The setinputsizes hook overall is only used for dialects which include
+ the flag ``use_setinputsizes=True``. Dialects which use this
+ include cx_Oracle, pg8000, asyncpg, and pyodbc dialects.
+
+ .. note::
+
+ For use with pyodbc, the ``use_setinputsizes`` flag
+ must be passed to the dialect, e.g.::
+
+ create_engine("mssql+pyodbc://...", use_setinputsizes=True)
+
+ .. seealso::
+
+ :ref:`mssql_pyodbc_setinputsizes`
+
+ .. versionadded:: 1.2.9
+
+ .. seealso::
+
+ :ref:`cx_oracle_setinputsizes`
+
+ """
+ pass
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
new file mode 100644
index 0000000..4f2524a
--- /dev/null
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -0,0 +1,1719 @@
+# engine/interfaces.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Define core interfaces used by the engine system."""
+
+from .. import util
+from ..sql.compiler import Compiled # noqa
+from ..sql.compiler import TypeCompiler # noqa
+from ..util.concurrency import await_only
+
+
+class Dialect(object):
+ """Define the behavior of a specific database and DB-API combination.
+
+ Any aspect of metadata definition, SQL query generation,
+ execution, result-set handling, or anything else which varies
+ between databases is defined under the general category of the
+ Dialect. The Dialect acts as a factory for other
+ database-specific object implementations including
+ ExecutionContext, Compiled, DefaultGenerator, and TypeEngine.
+
+ .. note:: Third party dialects should not subclass :class:`.Dialect`
+ directly. Instead, subclass :class:`.default.DefaultDialect` or
+ descendant class.
+
+ All dialects include the following attributes. There are many other
+ attributes that may be supported as well:
+
+ ``name``
+ identifying name for the dialect from a DBAPI-neutral point of view
+ (i.e. 'sqlite')
+
+ ``driver``
+ identifying name for the dialect's DBAPI
+
+ ``positional``
+ True if the paramstyle for this Dialect is positional.
+
+ ``paramstyle``
+ the paramstyle to be used (some DB-APIs support multiple
+ paramstyles).
+
+ ``encoding``
+ type of encoding to use for unicode, usually defaults to
+ 'utf-8'.
+
+ ``statement_compiler``
+ a :class:`.Compiled` class used to compile SQL statements
+
+ ``ddl_compiler``
+ a :class:`.Compiled` class used to compile DDL statements
+
+ ``server_version_info``
+ a tuple containing a version number for the DB backend in use.
+ This value is only available for supporting dialects, and is
+ typically populated during the initial connection to the database.
+
+ ``default_schema_name``
+ the name of the default schema. This value is only available for
+ supporting dialects, and is typically populated during the
+ initial connection to the database.
+
+ ``execution_ctx_cls``
+ a :class:`.ExecutionContext` class used to handle statement execution
+
+ ``execute_sequence_format``
+ either the 'tuple' or 'list' type, depending on what cursor.execute()
+ accepts for the second argument (they vary).
+
+ ``preparer``
+ a :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` class used to
+ quote identifiers.
+
+ ``supports_alter``
+ ``True`` if the database supports ``ALTER TABLE`` - used only for
+ generating foreign key constraints in certain circumstances
+
+ ``max_identifier_length``
+ The maximum length of identifier names.
+
+ ``supports_sane_rowcount``
+ Indicate whether the dialect properly implements rowcount for
+ ``UPDATE`` and ``DELETE`` statements.
+
+ ``supports_sane_multi_rowcount``
+ Indicate whether the dialect properly implements rowcount for
+ ``UPDATE`` and ``DELETE`` statements when executed via
+ executemany.
+
+ ``preexecute_autoincrement_sequences``
+ True if 'implicit' primary key functions must be executed separately
+ in order to get their value. This is currently oriented towards
+ PostgreSQL.
+
+ ``implicit_returning``
+ use RETURNING or equivalent during INSERT execution in order to load
+ newly generated primary keys and other column defaults in one execution,
+ which are then available via inserted_primary_key.
+ If an insert statement has returning() specified explicitly,
+ the "implicit" functionality is not used and inserted_primary_key
+ will not be available.
+
+ ``colspecs``
+ A dictionary of TypeEngine classes from sqlalchemy.types mapped
+ to subclasses that are specific to the dialect class. This
+ dictionary is class-level only and is not accessed from the
+ dialect instance itself.
+
+ ``supports_default_values``
+ Indicates if the construct ``INSERT INTO tablename DEFAULT
+ VALUES`` is supported
+
+ ``supports_sequences``
+ Indicates if the dialect supports CREATE SEQUENCE or similar.
+
+ ``sequences_optional``
+ If True, indicates if the "optional" flag on the Sequence() construct
+ should signal to not generate a CREATE SEQUENCE. Applies only to
+ dialects that support sequences. Currently used only to allow PostgreSQL
+ SERIAL to be used on a column that specifies Sequence() for usage on
+ other backends.
+
+ ``supports_native_enum``
+ Indicates if the dialect supports a native ENUM construct.
+ This will prevent types.Enum from generating a CHECK
+ constraint when that type is used.
+
+ ``supports_native_boolean``
+ Indicates if the dialect supports a native boolean construct.
+ This will prevent types.Boolean from generating a CHECK
+ constraint when that type is used.
+
+ ``dbapi_exception_translation_map``
+ A dictionary of names that will contain as values the names of
+ pep-249 exceptions ("IntegrityError", "OperationalError", etc)
+ keyed to alternate class names, to support the case where a
+ DBAPI has exception classes that aren't named as they are
+ referred to (e.g. IntegrityError = MyException). In the vast
+ majority of cases this dictionary is empty.
+
+ .. versionadded:: 1.0.5
+
+ """
+
+ _has_events = False
+
+ supports_statement_cache = True
+ """indicates if this dialect supports caching.
+
+ All dialects that are compatible with statement caching should set this
+ flag to True directly on each dialect class and subclass that supports
+ it. SQLAlchemy tests that this flag is locally present on each dialect
+ subclass before it will use statement caching. This is to provide
+ safety for legacy or new dialects that are not yet fully tested to be
+ compliant with SQL statement caching.
+
+ .. versionadded:: 1.4.5
+
+ .. seealso::
+
+ :ref:`engine_thirdparty_caching`
+
+ """
+
+ def create_connect_args(self, url):
+ """Build DB-API compatible connection arguments.
+
+ Given a :class:`.URL` object, returns a tuple
+ consisting of a ``(*args, **kwargs)`` suitable to send directly
+ to the dbapi's connect function. The arguments are sent to the
+ :meth:`.Dialect.connect` method which then runs the DBAPI-level
+ ``connect()`` function.
+
+ The method typically makes use of the
+ :meth:`.URL.translate_connect_args`
+ method in order to generate a dictionary of options.
+
+ The default implementation is::
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args()
+ opts.update(url.query)
+ return [[], opts]
+
+ :param url: a :class:`.URL` object
+
+ :return: a tuple of ``(*args, **kwargs)`` which will be passed to the
+ :meth:`.Dialect.connect` method.
+
+ .. seealso::
+
+ :meth:`.URL.translate_connect_args`
+
+ """
+
+ raise NotImplementedError()
+
+ @classmethod
+ def type_descriptor(cls, typeobj):
+ """Transform a generic type to a dialect-specific type.
+
+ Dialect classes will usually use the
+ :func:`_types.adapt_type` function in the types module to
+ accomplish this.
+
+ The returned result is cached *per dialect class* so can
+ contain no dialect-instance state.
+
+ """
+
+ raise NotImplementedError()
+
+ def initialize(self, connection):
+ """Called during strategized creation of the dialect with a
+ connection.
+
+ Allows dialects to configure options based on server version info or
+ other properties.
+
+ The connection passed here is a SQLAlchemy Connection object,
+ with full capabilities.
+
+ The initialize() method of the base dialect should be called via
+ super().
+
+ .. note:: as of SQLAlchemy 1.4, this method is called **before**
+ any :meth:`_engine.Dialect.on_connect` hooks are called.
+
+ """
+
+ pass
+
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ """Return information about columns in `table_name`.
+
+ Given a :class:`_engine.Connection`, a string
+ `table_name`, and an optional string `schema`, return column
+ information as a list of dictionaries with these keys:
+
+ name
+ the column's name
+
+ type
+ [sqlalchemy.types#TypeEngine]
+
+ nullable
+ boolean
+
+ default
+ the column's default value
+
+ autoincrement
+ boolean
+
+ sequence
+ a dictionary of the form
+ {'name' : str, 'start' :int, 'increment': int, 'minvalue': int,
+ 'maxvalue': int, 'nominvalue': bool, 'nomaxvalue': bool,
+ 'cycle': bool, 'cache': int, 'order': bool}
+
+ Additional column attributes may be present.
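+
+        A minimal sketch of the corresponding public-facing call, which goes
+        through the :class:`.Inspector` rather than the dialect directly
+        (``engine`` and the table name are placeholders)::
+
+            from sqlalchemy import inspect
+
+            for col in inspect(engine).get_columns("user_account"):
+                print(col["name"], col["type"], col["nullable"])
+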
+ """
+
+ raise NotImplementedError()
+
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ """Return information about the primary key constraint on
+        `table_name`.
+
+ Given a :class:`_engine.Connection`, a string
+ `table_name`, and an optional string `schema`, return primary
+ key information as a dictionary with these keys:
+
+ constrained_columns
+ a list of column names that make up the primary key
+
+ name
+ optional name of the primary key constraint.
+
+ """
+ raise NotImplementedError()
+
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ """Return information about foreign_keys in `table_name`.
+
+ Given a :class:`_engine.Connection`, a string
+ `table_name`, and an optional string `schema`, return foreign
+ key information as a list of dicts with these keys:
+
+ name
+ the constraint's name
+
+ constrained_columns
+ a list of column names that make up the foreign key
+
+ referred_schema
+ the name of the referred schema
+
+ referred_table
+ the name of the referred table
+
+ referred_columns
+ a list of column names in the referred table that correspond to
+ constrained_columns
+ """
+
+ raise NotImplementedError()
+
+ def get_table_names(self, connection, schema=None, **kw):
+ """Return a list of table names for `schema`."""
+
+ raise NotImplementedError()
+
+ def get_temp_table_names(self, connection, schema=None, **kw):
+ """Return a list of temporary table names on the given connection,
+ if supported by the underlying backend.
+
+ """
+
+ raise NotImplementedError()
+
+ def get_view_names(self, connection, schema=None, **kw):
+ """Return a list of all view names available in the database.
+
+ :param schema: schema name to query, if not the default schema.
+ """
+
+ raise NotImplementedError()
+
+ def get_sequence_names(self, connection, schema=None, **kw):
+ """Return a list of all sequence names available in the database.
+
+ :param schema: schema name to query, if not the default schema.
+
+ .. versionadded:: 1.4
+ """
+
+ raise NotImplementedError()
+
+ def get_temp_view_names(self, connection, schema=None, **kw):
+ """Return a list of temporary view names on the given connection,
+ if supported by the underlying backend.
+
+ """
+
+ raise NotImplementedError()
+
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ """Return view definition.
+
+ Given a :class:`_engine.Connection`, a string
+ `view_name`, and an optional string `schema`, return the view
+ definition.
+ """
+
+ raise NotImplementedError()
+
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ """Return information about indexes in `table_name`.
+
+ Given a :class:`_engine.Connection`, a string
+ `table_name` and an optional string `schema`, return index
+ information as a list of dictionaries with these keys:
+
+ name
+ the index's name
+
+ column_names
+ list of column names in order
+
+ unique
+ boolean
+ """
+
+ raise NotImplementedError()
+
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
+ r"""Return information about unique constraints in `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ unique constraint information as a list of dicts with these keys:
+
+ name
+ the unique constraint's name
+
+ column_names
+ list of column names in order
+
+ \**kw
+ other options passed to the dialect's get_unique_constraints()
+ method.
+
+ .. versionadded:: 0.9.0
+
+ """
+
+ raise NotImplementedError()
+
+ def get_check_constraints(self, connection, table_name, schema=None, **kw):
+ r"""Return information about check constraints in `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ check constraint information as a list of dicts with these keys:
+
+ * ``name`` -
+ the check constraint's name
+
+ * ``sqltext`` -
+ the check constraint's SQL expression
+
+ * ``**kw`` -
+ other options passed to the dialect's get_check_constraints()
+ method.
+
+ .. versionadded:: 1.1.0
+
+ """
+
+ raise NotImplementedError()
+
+ def get_table_comment(self, connection, table_name, schema=None, **kw):
+ r"""Return the "comment" for the table identified by `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ table comment information as a dictionary with this key:
+
+ text
+ text of the comment
+
+ Raises ``NotImplementedError`` for dialects that don't support
+ comments.
+
+ .. versionadded:: 1.2
+
+ """
+
+ raise NotImplementedError()
+
+ def normalize_name(self, name):
+ """convert the given name to lowercase if it is detected as
+ case insensitive.
+
+ This method is only used if the dialect defines
+ requires_name_normalize=True.
+
+ """
+ raise NotImplementedError()
+
+ def denormalize_name(self, name):
+ """convert the given name to a case insensitive identifier
+ for the backend if it is an all-lowercase name.
+
+ This method is only used if the dialect defines
+ requires_name_normalize=True.
+
+ """
+ raise NotImplementedError()
+
+ def has_table(self, connection, table_name, schema=None, **kw):
+ """For internal dialect use, check the existence of a particular table
+ in the database.
+
+ Given a :class:`_engine.Connection` object, a string table_name and
+ optional schema name, return True if the given table exists in the
+ database, False otherwise.
+
+ This method serves as the underlying implementation of the
+ public facing :meth:`.Inspector.has_table` method, and is also used
+ internally to implement the "checkfirst" behavior for methods like
+ :meth:`_schema.Table.create` and :meth:`_schema.MetaData.create_all`.
+
+ .. note:: This method is used internally by SQLAlchemy, and is
+ published so that third-party dialects may provide an
+ implementation. It is **not** the public API for checking for table
+ presence. Please use the :meth:`.Inspector.has_table` method.
+ Alternatively, for legacy cross-compatibility, the
+ :meth:`_engine.Engine.has_table` method may be used.
+
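+        A minimal sketch of the public-facing check (``engine`` and the
+        table name are placeholders)::
+
+            from sqlalchemy import inspect
+
+            if inspect(engine).has_table("user_account"):
+                ...
+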
+ """
+
+ raise NotImplementedError()
+
+ def has_index(self, connection, table_name, index_name, schema=None):
+ """Check the existence of a particular index name in the database.
+
+ Given a :class:`_engine.Connection` object, a string
+ `table_name` and string index name, return True if an index of the
+        given name on the given table exists, False otherwise.
+
+ The :class:`.DefaultDialect` implements this in terms of the
+ :meth:`.Dialect.has_table` and :meth:`.Dialect.get_indexes` methods,
+ however dialects can implement a more performant version.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ raise NotImplementedError()
+
+ def has_sequence(self, connection, sequence_name, schema=None, **kw):
+ """Check the existence of a particular sequence in the database.
+
+ Given a :class:`_engine.Connection` object and a string
+ `sequence_name`, return True if the given sequence exists in
+ the database, False otherwise.
+ """
+
+ raise NotImplementedError()
+
+ def _get_server_version_info(self, connection):
+ """Retrieve the server version info from the given connection.
+
+ This is used by the default implementation to populate the
+ "server_version_info" attribute and is called exactly
+ once upon first connect.
+
+ """
+
+ raise NotImplementedError()
+
+ def _get_default_schema_name(self, connection):
+ """Return the string name of the currently selected schema from
+ the given connection.
+
+ This is used by the default implementation to populate the
+ "default_schema_name" attribute and is called exactly
+ once upon first connect.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_begin(self, dbapi_connection):
+ """Provide an implementation of ``connection.begin()``, given a
+ DB-API connection.
+
+ The DBAPI has no dedicated "begin" method and it is expected
+ that transactions are implicit. This hook is provided for those
+ DBAPIs that might need additional help in this area.
+
+ Note that :meth:`.Dialect.do_begin` is not called unless a
+ :class:`.Transaction` object is in use. The
+ :meth:`.Dialect.do_autocommit`
+ hook is provided for DBAPIs that need some extra commands emitted
+ after a commit in order to enter the next transaction, when the
+ SQLAlchemy :class:`_engine.Connection`
+ is used in its default "autocommit"
+ mode.
+
+ :param dbapi_connection: a DBAPI connection, typically
+ proxied within a :class:`.ConnectionFairy`.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_rollback(self, dbapi_connection):
+ """Provide an implementation of ``connection.rollback()``, given
+ a DB-API connection.
+
+ :param dbapi_connection: a DBAPI connection, typically
+ proxied within a :class:`.ConnectionFairy`.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_commit(self, dbapi_connection):
+ """Provide an implementation of ``connection.commit()``, given a
+ DB-API connection.
+
+ :param dbapi_connection: a DBAPI connection, typically
+ proxied within a :class:`.ConnectionFairy`.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_close(self, dbapi_connection):
+ """Provide an implementation of ``connection.close()``, given a DBAPI
+ connection.
+
+ This hook is called by the :class:`_pool.Pool`
+ when a connection has been
+ detached from the pool, or is being returned beyond the normal
+ capacity of the pool.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ """invoke the cursor.setinputsizes() method with appropriate arguments
+
+ This hook is called if the dialect.use_inputsizes flag is set to True.
+ Parameter data is passed in a list of tuples (paramname, dbtype,
+ sqltype), where ``paramname`` is the key of the parameter in the
+ statement, ``dbtype`` is the DBAPI datatype and ``sqltype`` is the
+ SQLAlchemy type. The order of tuples is in the correct parameter order.
+
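+ A hedged sketch of one possible implementation, assuming a driver
+ whose ``setinputsizes()`` accepts keyword arguments mapping parameter
+ names to DBAPI types (drivers differ; this is illustrative only)::
+
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+     cursor.setinputsizes(
+         **{key: dbtype for key, dbtype, sqltype in list_of_tuples if dbtype}
+     )
+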
+ .. versionadded:: 1.4
+
+
+ """
+ raise NotImplementedError()
+
+ def create_xid(self):
+ """Create a two-phase transaction ID.
+
+ This id will be passed to do_begin_twophase(),
+ do_rollback_twophase(), do_commit_twophase(). Its format is
+ unspecified.
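+
+ As an illustration, a sketch that simply generates a random hex string
+ (the format shown is an assumption for the example, not a
+ requirement)::
+
+ import random
+
+ def create_xid(self):
+     return "_sa_%032x" % random.randint(0, 2 ** 128)
+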
+ """
+
+ raise NotImplementedError()
+
+ def do_savepoint(self, connection, name):
+ """Create a savepoint with the given name.
+
+ :param connection: a :class:`_engine.Connection`.
+ :param name: savepoint name.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_rollback_to_savepoint(self, connection, name):
+ """Rollback a connection to the named savepoint.
+
+ :param connection: a :class:`_engine.Connection`.
+ :param name: savepoint name.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_release_savepoint(self, connection, name):
+ """Release the named savepoint on a connection.
+
+ :param connection: a :class:`_engine.Connection`.
+ :param name: savepoint name.
+ """
+
+ raise NotImplementedError()
+
+ def do_begin_twophase(self, connection, xid):
+ """Begin a two phase transaction on the given connection.
+
+ :param connection: a :class:`_engine.Connection`.
+ :param xid: xid
+
+ """
+
+ raise NotImplementedError()
+
+ def do_prepare_twophase(self, connection, xid):
+ """Prepare a two phase transaction on the given connection.
+
+ :param connection: a :class:`_engine.Connection`.
+ :param xid: xid
+
+ """
+
+ raise NotImplementedError()
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ """Rollback a two phase transaction on the given connection.
+
+ :param connection: a :class:`_engine.Connection`.
+ :param xid: xid
+ :param is_prepared: whether or not
+ :meth:`.TwoPhaseTransaction.prepare` was called.
+ :param recover: if the recover flag was passed.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ """Commit a two phase transaction on the given connection.
+
+
+ :param connection: a :class:`_engine.Connection`.
+ :param xid: xid
+ :param is_prepared: whether or not
+ :meth:`.TwoPhaseTransaction.prepare` was called.
+ :param recover: if the recover flag was passed.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_recover_twophase(self, connection):
+ """Recover list of uncommitted prepared two phase transaction
+ identifiers on the given connection.
+
+ :param connection: a :class:`_engine.Connection`.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ """Provide an implementation of ``cursor.executemany(statement,
+ parameters)``."""
+
+ raise NotImplementedError()
+
+ def do_execute(self, cursor, statement, parameters, context=None):
+ """Provide an implementation of ``cursor.execute(statement,
+ parameters)``."""
+
+ raise NotImplementedError()
+
+ def do_execute_no_params(
+ self, cursor, statement, parameters, context=None
+ ):
+ """Provide an implementation of ``cursor.execute(statement)``.
+
+ The parameter collection should not be sent.
+
+ """
+
+ raise NotImplementedError()
+
+ def is_disconnect(self, e, connection, cursor):
+ """Return True if the given DB-API error indicates an invalid
+ connection"""
+
+ raise NotImplementedError()
+
+ def connect(self, *cargs, **cparams):
+ r"""Establish a connection using this dialect's DBAPI.
+
+ The default implementation of this method is::
+
+ def connect(self, *cargs, **cparams):
+ return self.dbapi.connect(*cargs, **cparams)
+
+ The ``*cargs, **cparams`` parameters are generated directly
+ from this dialect's :meth:`.Dialect.create_connect_args` method.
+
+ This method may be used for dialects that need to perform programmatic
+ per-connection steps when a new connection is procured from the
+ DBAPI.
+
+
+ :param \*cargs: positional parameters returned from the
+ :meth:`.Dialect.create_connect_args` method
+
+ :param \*\*cparams: keyword parameters returned from the
+ :meth:`.Dialect.create_connect_args` method.
+
+ :return: a DBAPI connection, typically from the :pep:`249` module
+ level ``.connect()`` function.
+
+ .. seealso::
+
+ :meth:`.Dialect.create_connect_args`
+
+ :meth:`.Dialect.on_connect`
+
+ """
+
+ def on_connect_url(self, url):
+ """return a callable which sets up a newly created DBAPI connection.
+
+ This method is a new hook that supersedes the
+ :meth:`_engine.Dialect.on_connect` method when implemented by a
+ dialect. When not implemented by a dialect, it invokes the
+ :meth:`_engine.Dialect.on_connect` method directly to maintain
+ compatibility with existing dialects. There is no deprecation
+ for :meth:`_engine.Dialect.on_connect` expected.
+
+ The callable should accept a single argument "conn" which is the
+ DBAPI connection itself. The inner callable has no
+ return value.
+
+ E.g.::
+
+ class MyDialect(default.DefaultDialect):
+ # ...
+
+ def on_connect_url(self, url):
+ def do_on_connect(connection):
+ connection.execute("SET SPECIAL FLAGS etc")
+
+ return do_on_connect
+
+ This is used to set dialect-wide per-connection options such as
+ isolation modes, Unicode modes, etc.
+
+ This method differs from :meth:`_engine.Dialect.on_connect` in that
+ it is passed the :class:`_engine.URL` object that's relevant to the
+ connect args. Normally the only way to get this from the
+ :meth:`_engine.Dialect.on_connect` hook is to look on the
+ :class:`_engine.Engine` itself; however, this URL object may have been
+ replaced by plugins.
+
+ .. note::
+
+ The default implementation of
+ :meth:`_engine.Dialect.on_connect_url` is to invoke the
+ :meth:`_engine.Dialect.on_connect` method. Therefore if a dialect
+ implements this method, the :meth:`_engine.Dialect.on_connect`
+ method **will not be called** unless the overriding dialect calls
+ it directly from here.
+
+ .. versionadded:: 1.4.3 added :meth:`_engine.Dialect.on_connect_url`
+ which normally calls into :meth:`_engine.Dialect.on_connect`.
+
+ :param url: a :class:`_engine.URL` object representing the
+ :class:`_engine.URL` that was passed to the
+ :meth:`_engine.Dialect.create_connect_args` method.
+
+ :return: a callable that accepts a single DBAPI connection as an
+ argument, or None.
+
+ .. seealso::
+
+ :meth:`_engine.Dialect.on_connect`
+
+ """
+ return self.on_connect()
+
+ def on_connect(self):
+ """return a callable which sets up a newly created DBAPI connection.
+
+ The callable should accept a single argument "conn" which is the
+ DBAPI connection itself. The inner callable has no
+ return value.
+
+ E.g.::
+
+ class MyDialect(default.DefaultDialect):
+ # ...
+
+ def on_connect(self):
+ def do_on_connect(connection):
+ connection.execute("SET SPECIAL FLAGS etc")
+
+ return do_on_connect
+
+ This is used to set dialect-wide per-connection options such as
+ isolation modes, Unicode modes, etc.
+
+ The "do_on_connect" callable is invoked by using the
+ :meth:`_events.PoolEvents.connect` event
+ hook, then unwrapping the DBAPI connection and passing it into the
+ callable.
+
+ .. versionchanged:: 1.4 the on_connect hook is no longer called twice
+ for the first connection of a dialect. The on_connect hook is still
+ called before the :meth:`_engine.Dialect.initialize` method however.
+
+ .. versionchanged:: 1.4.3 the on_connect hook is invoked from a new
+ method on_connect_url that passes the URL that was used to create
+ the connect args. Dialects can implement on_connect_url instead
+ of on_connect if they need the URL object that was used for the
+ connection in order to get additional context.
+
+ If None is returned, no event listener is generated.
+
+ :return: a callable that accepts a single DBAPI connection as an
+ argument, or None.
+
+ .. seealso::
+
+ :meth:`.Dialect.connect` - allows the DBAPI ``connect()`` sequence
+ itself to be controlled.
+
+ :meth:`.Dialect.on_connect_url` - supersedes
+ :meth:`.Dialect.on_connect` to also receive the
+ :class:`_engine.URL` object in context.
+
+ """
+ return None
+
+ def reset_isolation_level(self, dbapi_conn):
+ """Given a DBAPI connection, revert its isolation to the default.
+
+ Note that this is a dialect-level method which is used as part
+ of the implementation of the :class:`_engine.Connection` and
+ :class:`_engine.Engine`
+ isolation level facilities; these APIs should be preferred for
+ most typical use cases.
+
+ .. seealso::
+
+ :meth:`_engine.Connection.get_isolation_level`
+ - view current level
+
+ :attr:`_engine.Connection.default_isolation_level`
+ - view default level
+
+ :paramref:`.Connection.execution_options.isolation_level` -
+ set per :class:`_engine.Connection` isolation level
+
+ :paramref:`_sa.create_engine.isolation_level` -
+ set per :class:`_engine.Engine` isolation level
+
+ """
+
+ raise NotImplementedError()
+
+ def set_isolation_level(self, dbapi_conn, level):
+ """Given a DBAPI connection, set its isolation level.
+
+ Note that this is a dialect-level method which is used as part
+ of the implementation of the :class:`_engine.Connection` and
+ :class:`_engine.Engine`
+ isolation level facilities; these APIs should be preferred for
+ most typical use cases.
+
+ .. seealso::
+
+ :meth:`_engine.Connection.get_isolation_level`
+ - view current level
+
+ :attr:`_engine.Connection.default_isolation_level`
+ - view default level
+
+ :paramref:`.Connection.execution_options.isolation_level` -
+ set per :class:`_engine.Connection` isolation level
+
+ :paramref:`_sa.create_engine.isolation_level` -
+ set per :class:`_engine.Engine` isolation level
+
+ """
+
+ raise NotImplementedError()
+
+ def get_isolation_level(self, dbapi_conn):
+ """Given a DBAPI connection, return its isolation level.
+
+ When working with a :class:`_engine.Connection` object,
+ the corresponding
+ DBAPI connection may be procured using the
+ :attr:`_engine.Connection.connection` accessor.
+
+ Note that this is a dialect-level method which is used as part
+ of the implementation of the :class:`_engine.Connection` and
+ :class:`_engine.Engine` isolation level facilities;
+ these APIs should be preferred for most typical use cases.
+
+
+ .. seealso::
+
+ :meth:`_engine.Connection.get_isolation_level`
+ - view current level
+
+ :attr:`_engine.Connection.default_isolation_level`
+ - view default level
+
+ :paramref:`.Connection.execution_options.isolation_level` -
+ set per :class:`_engine.Connection` isolation level
+
+ :paramref:`_sa.create_engine.isolation_level` -
+ set per :class:`_engine.Engine` isolation level
+
+
+ """
+
+ raise NotImplementedError()
+
+ def get_default_isolation_level(self, dbapi_conn):
+ """Given a DBAPI connection, return its isolation level, or
+ a default isolation level if one cannot be retrieved.
+
+ This method may only raise NotImplementedError and
+ **must not raise any other exception**, as it is used implicitly upon
+ first connect.
+
+ The method **must return a value** for a dialect that supports
+ isolation level settings, as this level is what will be reverted
+ towards when a per-connection isolation level change is made.
+
+ The method defaults to using the :meth:`.Dialect.get_isolation_level`
+ method unless overridden by a dialect.
+
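+ As a sketch, the default behavior is essentially equivalent to
+ delegating to :meth:`.Dialect.get_isolation_level` (error handling
+ omitted here for brevity)::
+
+ def get_default_isolation_level(self, dbapi_conn):
+     return self.get_isolation_level(dbapi_conn)
+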
+ .. versionadded:: 1.3.22
+
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def get_dialect_cls(cls, url):
+ """Given a URL, return the :class:`.Dialect` that will be used.
+
+ This is a hook that allows an external plugin to provide functionality
+ around an existing dialect, by allowing the plugin to be loaded
+ from the url based on an entrypoint, and then the plugin returns
+ the actual dialect to be used.
+
+ By default this just returns the cls.
+
+ .. versionadded:: 1.0.3
+
+ """
+ return cls
+
+ @classmethod
+ def load_provisioning(cls):
+ """set up the provision.py module for this dialect.
+
+ For dialects that include a provision.py module that sets up
+ provisioning followers, this method should initiate that process.
+
+ A typical implementation would be::
+
+ @classmethod
+ def load_provisioning(cls):
+ __import__("mydialect.provision")
+
+ The default method assumes a module named ``provision.py`` inside
+ the owning package of the current dialect, based on the ``__module__``
+ attribute::
+
+ @classmethod
+ def load_provisioning(cls):
+ package = ".".join(cls.__module__.split(".")[0:-1])
+ try:
+ __import__(package + ".provision")
+ except ImportError:
+ pass
+
+ .. versionadded:: 1.3.14
+
+ """
+
+ @classmethod
+ def engine_created(cls, engine):
+ """A convenience hook called before returning the final
+ :class:`_engine.Engine`.
+
+ If the dialect returned a different class from the
+ :meth:`.get_dialect_cls`
+ method, then the hook is called on both classes, first on
+ the dialect class returned by the :meth:`.get_dialect_cls` method and
+ then on the class on which the method was called.
+
+ The hook should be used by dialects and/or wrappers to apply special
+ events to the engine or its components. In particular, it allows
+ a dialect-wrapping class to apply dialect-level events.
+
+ .. versionadded:: 1.0.3
+
+ """
+
+ def get_driver_connection(self, connection):
+ """Returns the connection object as returned by the external driver
+ package.
+
+ For normal dialects that use a DBAPI compliant driver this call
+ will just return the ``connection`` passed as argument.
+ For dialects that instead adapt a non DBAPI compliant driver, like
+ when adapting an asyncio driver, this call will return the
+ connection-like object as returned by the driver.
+
+ .. versionadded:: 1.4.24
+
+ """
+ raise NotImplementedError()
+
+
+class CreateEnginePlugin(object):
+ """A set of hooks intended to augment the construction of an
+ :class:`_engine.Engine` object based on entrypoint names in a URL.
+
+ The purpose of :class:`_engine.CreateEnginePlugin` is to allow third-party
+ systems to apply engine, pool and dialect level event listeners without
+ the need for the target application to be modified; instead, the plugin
+ names can be added to the database URL. Target applications for
+ :class:`_engine.CreateEnginePlugin` include:
+
+ * connection and SQL performance tools, e.g. which use events to track
+ number of checkouts and/or time spent with statements
+
+ * connectivity plugins such as proxies
+
+ A rudimentary :class:`_engine.CreateEnginePlugin` that attaches a logger
+ to an :class:`_engine.Engine` object might look like::
+
+
+ import logging
+
+ from sqlalchemy.engine import CreateEnginePlugin
+ from sqlalchemy import event
+
+ class LogCursorEventsPlugin(CreateEnginePlugin):
+ def __init__(self, url, kwargs):
+ # consume the parameter "log_cursor_logging_name" from the
+ # URL query
+ logging_name = url.query.get("log_cursor_logging_name", "log_cursor")
+
+ self.log = logging.getLogger(logging_name)
+
+ def update_url(self, url):
+ "update the URL to one that no longer includes our parameters"
+ return url.difference_update_query(["log_cursor_logging_name"])
+
+ def engine_created(self, engine):
+ "attach an event listener after the new Engine is constructed"
+ event.listen(engine, "before_cursor_execute", self._log_event)
+
+
+ def _log_event(
+ self,
+ conn,
+ cursor,
+ statement,
+ parameters,
+ context,
+ executemany):
+
+ self.log.info("Plugin logged cursor event: %s", statement)
+
+
+
+ Plugins are registered using entry points in a similar way as that
+ of dialects::
+
+ entry_points={
+ 'sqlalchemy.plugins': [
+ 'log_cursor_plugin = myapp.plugins:LogCursorEventsPlugin'
+ ]
+ }
+
+ A plugin that uses the above names would be invoked from a database
+ URL as in::
+
+ from sqlalchemy import create_engine
+
+ engine = create_engine(
+ "mysql+pymysql://scott:tiger@localhost/test?"
+ "plugin=log_cursor_plugin&log_cursor_logging_name=mylogger"
+ )
+
+ The ``plugin`` URL parameter supports multiple instances, so that a URL
+ may specify multiple plugins; they are loaded in the order stated
+ in the URL::
+
+ engine = create_engine(
+ "mysql+pymysql://scott:tiger@localhost/test?"
+ "plugin=plugin_one&plugin=plugin_twp&plugin=plugin_three")
+
+ The plugin names may also be passed directly to :func:`_sa.create_engine`
+ using the :paramref:`_sa.create_engine.plugins` argument::
+
+ engine = create_engine(
+ "mysql+pymysql://scott:tiger@localhost/test",
+ plugins=["myplugin"])
+
+ .. versionadded:: 1.2.3 plugin names can also be specified
+ to :func:`_sa.create_engine` as a list
+
+ A plugin may consume plugin-specific arguments from the
+ :class:`_engine.URL` object as well as the ``kwargs`` dictionary, which is
+ the dictionary of arguments passed to the :func:`_sa.create_engine`
+ call. "Consuming" these arguments includes that they must be removed
+ when the plugin initializes, so that the arguments are not passed along
+ to the :class:`_engine.Dialect` constructor, where they will raise an
+ :class:`_exc.ArgumentError` because they are not known by the dialect.
+
+ As of version 1.4 of SQLAlchemy, arguments should continue to be consumed
+ from the ``kwargs`` dictionary directly, by removing the values with a
+ method such as ``dict.pop``. Arguments from the :class:`_engine.URL` object
+ should be consumed by implementing the
+ :meth:`_engine.CreateEnginePlugin.update_url` method, returning a new copy
+ of the :class:`_engine.URL` with plugin-specific parameters removed::
+
+ class MyPlugin(CreateEnginePlugin):
+ def __init__(self, url, kwargs):
+ self.my_argument_one = url.query['my_argument_one']
+ self.my_argument_two = url.query['my_argument_two']
+ self.my_argument_three = kwargs.pop('my_argument_three', None)
+
+ def update_url(self, url):
+ return url.difference_update_query(
+ ["my_argument_one", "my_argument_two"]
+ )
+
+ Arguments like those illustrated above would be consumed from a
+ :func:`_sa.create_engine` call such as::
+
+ from sqlalchemy import create_engine
+
+ engine = create_engine(
+ "mysql+pymysql://scott:tiger@localhost/test?"
+ "plugin=myplugin&my_argument_one=foo&my_argument_two=bar",
+ my_argument_three='bat'
+ )
+
+ .. versionchanged:: 1.4
+
+ The :class:`_engine.URL` object is now immutable; a
+ :class:`_engine.CreateEnginePlugin` that needs to alter the
+ :class:`_engine.URL` should implement the newly added
+ :meth:`_engine.CreateEnginePlugin.update_url` method, which
+ is invoked after the plugin is constructed.
+
+ For migration, construct the plugin in the following way, checking
+ for the existence of the :meth:`_engine.CreateEnginePlugin.update_url`
+ method to detect which version is running::
+
+ class MyPlugin(CreateEnginePlugin):
+ def __init__(self, url, kwargs):
+ if hasattr(CreateEnginePlugin, "update_url"):
+ # detect the 1.4 API
+ self.my_argument_one = url.query['my_argument_one']
+ self.my_argument_two = url.query['my_argument_two']
+ else:
+ # detect the 1.3 and earlier API - mutate the
+ # URL directly
+ self.my_argument_one = url.query.pop('my_argument_one')
+ self.my_argument_two = url.query.pop('my_argument_two')
+
+ self.my_argument_three = kwargs.pop('my_argument_three', None)
+
+ def update_url(self, url):
+ # this method is only called in the 1.4 version
+ return url.difference_update_query(
+ ["my_argument_one", "my_argument_two"]
+ )
+
+ .. seealso::
+
+ :ref:`change_5526` - overview of the :class:`_engine.URL` change which
+ also includes notes regarding :class:`_engine.CreateEnginePlugin`.
+
+
+ When the engine creation process completes and produces the
+ :class:`_engine.Engine` object, it is again passed to the plugin via the
+ :meth:`_engine.CreateEnginePlugin.engine_created` hook. In this hook, additional
+ changes can be made to the engine, most typically involving setup of
+ events (e.g. those defined in :ref:`core_event_toplevel`).
+
+ .. versionadded:: 1.1
+
+ """ # noqa: E501
+
+ def __init__(self, url, kwargs):
+ """Construct a new :class:`.CreateEnginePlugin`.
+
+ The plugin object is instantiated individually for each call
+ to :func:`_sa.create_engine`. A single :class:`_engine.Engine` will be
+ passed to the :meth:`.CreateEnginePlugin.engine_created` method
+ corresponding to this URL.
+
+ :param url: the :class:`_engine.URL` object. The plugin may inspect
+ the :class:`_engine.URL` for arguments. Arguments used by the
+ plugin should be removed, by returning an updated :class:`_engine.URL`
+ from the :meth:`_engine.CreateEnginePlugin.update_url` method.
+
+ .. versionchanged:: 1.4
+
+ The :class:`_engine.URL` object is now immutable, so a
+ :class:`_engine.CreateEnginePlugin` that needs to alter the
+ :class:`_engine.URL` object should implement the
+ :meth:`_engine.CreateEnginePlugin.update_url` method.
+
+ :param kwargs: The keyword arguments passed to
+ :func:`_sa.create_engine`.
+
+ """
+ self.url = url
+
+ def update_url(self, url):
+ """Update the :class:`_engine.URL`.
+
+ A new :class:`_engine.URL` should be returned. This method is
+ typically used to consume configuration arguments from the
+ :class:`_engine.URL` which must be removed, as they will not be
+ recognized by the dialect. The
+ :meth:`_engine.URL.difference_update_query` method is available
+ to remove these arguments. See the docstring at
+ :class:`_engine.CreateEnginePlugin` for an example.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ def handle_dialect_kwargs(self, dialect_cls, dialect_args):
+ """parse and modify dialect kwargs"""
+
+ def handle_pool_kwargs(self, pool_cls, pool_args):
+ """parse and modify pool kwargs"""
+
+ def engine_created(self, engine):
+ """Receive the :class:`_engine.Engine`
+ object when it is fully constructed.
+
+ The plugin may make additional changes to the engine, such as
+ registering engine or connection pool events.
+
+ """
+
+
+class ExecutionContext(object):
+ """A messenger object for a Dialect that corresponds to a single
+ execution.
+
+ ExecutionContext should have these data members:
+
+ connection
+ Connection object which can be freely used by default value
+ generators to execute SQL. This Connection should reference the
+ same underlying connection/transactional resources of
+ root_connection.
+
+ root_connection
+ Connection object which is the source of this ExecutionContext. This
+ Connection may have close_with_result=True set, in which case it can
+ only be used once.
+
+ dialect
+ dialect which created this ExecutionContext.
+
+ cursor
+ DB-API cursor procured from the connection.
+
+ compiled
+ if passed to the constructor, the sqlalchemy.engine.base.Compiled
+ object being executed.
+
+ statement
+ string version of the statement to be executed. Is either
+ passed to the constructor, or must be created from the
+ sql.Compiled object by the time pre_exec() has completed.
+
+ parameters
+ bind parameters passed to the execute() method. For compiled
+ statements, this is a dictionary or list of dictionaries. For
+ textual statements, it should be in a format suitable for the
+ dialect's paramstyle (i.e. dict or list of dicts for non
+ positional, list or list of lists/tuples for positional).
+
+ isinsert
+ True if the statement is an INSERT.
+
+ isupdate
+ True if the statement is an UPDATE.
+
+ should_autocommit
+ True if the statement is a "committable" statement.
+
+ prefetch_cols
+ a list of Column objects for which a client-side default
+ was fired off. Applies to inserts and updates.
+
+ postfetch_cols
+ a list of Column objects for which a server-side default or
+ inline SQL expression value was fired off. Applies to inserts
+ and updates.
+ """
+
+ def create_cursor(self):
+ """Return a new cursor generated from this ExecutionContext's
+ connection.
+
+ Some dialects may wish to change the behavior of
+ connection.cursor(), such as postgresql which may return a PG
+ "server side" cursor.
+ """
+
+ raise NotImplementedError()
+
+ def pre_exec(self):
+ """Called before an execution of a compiled statement.
+
+ If a compiled statement was passed to this ExecutionContext,
+ the `statement` and `parameters` datamembers must be
+ initialized after this statement is complete.
+ """
+
+ raise NotImplementedError()
+
+ def get_out_parameter_values(self, out_param_names):
+ """Return a sequence of OUT parameter values from a cursor.
+
+ For dialects that support OUT parameters, this method will be called
+ when there is a :class:`.SQLCompiler` object which has the
+ :attr:`.SQLCompiler.has_out_parameters` flag set. This flag in turn
+ will be set to True if the statement itself has :class:`.BindParameter`
+ objects that have the ``.isoutparam`` flag set which are consumed by
+ the :meth:`.SQLCompiler.visit_bindparam` method. If the dialect
+ compiler produces :class:`.BindParameter` objects with ``.isoutparam``
+ set which are not handled by :meth:`.SQLCompiler.visit_bindparam`, it
+ should set this flag explicitly.
+
+ The list of names that were rendered for each bound parameter
+ is passed to the method. The method should then return a sequence of
+ values corresponding to the list of parameter objects. Unlike in
+ previous SQLAlchemy versions, the values can be the **raw values** from
+ the DBAPI; the execution context will apply the appropriate type
+ handler based on what's present in self.compiled.binds and update the
+ values. The processed dictionary will then be made available via the
+ ``.out_parameters`` collection on the result object. Note that
+ SQLAlchemy 1.4 has multiple kinds of result object as part of the 2.0
+ transition.
+
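+ As an illustrative sketch only, a dialect that stores DBAPI output
+ variables in a hypothetical ``self.out_parameters`` dictionary keyed
+ on parameter name might implement this as::
+
+ def get_out_parameter_values(self, out_param_names):
+     return [
+         self.out_parameters[name].getvalue() for name in out_param_names
+     ]
+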
+ .. versionadded:: 1.4 - added
+ :meth:`.ExecutionContext.get_out_parameter_values`, which is invoked
+ automatically by the :class:`.DefaultExecutionContext` when there
+ are :class:`.BindParameter` objects with the ``.isoutparam`` flag
+ set. This replaces the practice of setting out parameters within
+ the now-removed ``get_result_proxy()`` method.
+
+ """
+ raise NotImplementedError()
+
+ def post_exec(self):
+ """Called after the execution of a compiled statement.
+
+ If a compiled statement was passed to this ExecutionContext,
+ the `last_insert_ids`, `last_inserted_params`, etc.
+ datamembers should be available after this method completes.
+ """
+
+ raise NotImplementedError()
+
+ def handle_dbapi_exception(self, e):
+ """Receive a DBAPI exception which occurred upon execute, result
+ fetch, etc."""
+
+ raise NotImplementedError()
+
+ def should_autocommit_text(self, statement):
+ """Parse the given textual statement and return True if it refers to
+ a "committable" statement"""
+
+ raise NotImplementedError()
+
+ def lastrow_has_defaults(self):
+ """Return True if the last INSERT or UPDATE row contained
+ inlined or database-side defaults.
+ """
+
+ raise NotImplementedError()
+
+ def get_rowcount(self):
+ """Return the DBAPI ``cursor.rowcount`` value, or in some
+ cases an interpreted value.
+
+ See :attr:`_engine.CursorResult.rowcount` for details on this.
+
+ """
+
+ raise NotImplementedError()
+
+
+@util.deprecated_20_cls(
+ ":class:`.Connectable`",
+ alternative=(
+ "The :class:`_engine.Engine` will be the only Core "
+ "object that features a .connect() method, and the "
+ ":class:`_engine.Connection` will be the only object that features "
+ "an .execute() method."
+ ),
+ constructor=None,
+)
+class Connectable(object):
+ """Interface for an object which supports execution of SQL constructs.
+
+ The two implementations of :class:`.Connectable` are
+ :class:`_engine.Connection` and :class:`_engine.Engine`.
+
+ Connectable must also implement the 'dialect' member which references a
+ :class:`.Dialect` instance.
+
+ """
+
+ def connect(self, **kwargs):
+ """Return a :class:`_engine.Connection` object.
+
+ Depending on context, this may be ``self`` if this object
+ is already an instance of :class:`_engine.Connection`, or a newly
+ procured :class:`_engine.Connection` if this object is an instance
+ of :class:`_engine.Engine`.
+
+ """
+
+ engine = None
+ """The :class:`_engine.Engine` instance referred to by this
+ :class:`.Connectable`.
+
+ May be ``self`` if this is already an :class:`_engine.Engine`.
+
+ """
+
+ def execute(self, object_, *multiparams, **params):
+ """Executes the given construct and returns a
+ :class:`_engine.CursorResult`.
+ """
+ raise NotImplementedError()
+
+ def scalar(self, object_, *multiparams, **params):
+ """Executes and returns the first column of the first row.
+
+ The underlying cursor is closed after execution.
+ """
+ raise NotImplementedError()
+
+ def _run_visitor(self, visitorcallable, element, **kwargs):
+ raise NotImplementedError()
+
+ def _execute_clauseelement(self, elem, multiparams=None, params=None):
+ raise NotImplementedError()
+
+
+class ExceptionContext(object):
+ """Encapsulate information about an error condition in progress.
+
+ This object exists solely to be passed to the
+ :meth:`_events.ConnectionEvents.handle_error` event,
+ supporting an interface that
+ can be extended without backwards-incompatibility.
+
+ .. versionadded:: 0.9.7
+
+ """
+
+ connection = None
+ """The :class:`_engine.Connection` in use during the exception.
+
+ This member is present, except in the case of a failure when
+ first connecting.
+
+ .. seealso::
+
+ :attr:`.ExceptionContext.engine`
+
+
+ """
+
+ engine = None
+ """The :class:`_engine.Engine` in use during the exception.
+
+ This member should always be present, even in the case of a failure
+ when first connecting.
+
+ .. versionadded:: 1.0.0
+
+ """
+
+ cursor = None
+ """The DBAPI cursor object.
+
+ May be None.
+
+ """
+
+ statement = None
+ """String SQL statement that was emitted directly to the DBAPI.
+
+ May be None.
+
+ """
+
+ parameters = None
+ """Parameter collection that was emitted directly to the DBAPI.
+
+ May be None.
+
+ """
+
+ original_exception = None
+ """The exception object which was caught.
+
+ This member is always present.
+
+ """
+
+ sqlalchemy_exception = None
+ """The :class:`sqlalchemy.exc.StatementError` which wraps the original,
+ and will be raised if exception handling is not circumvented by the event.
+
+ May be None, as not all exception types are wrapped by SQLAlchemy.
+ For DBAPI-level exceptions that subclass the dbapi's Error class, this
+ field will always be present.
+
+ """
+
+ chained_exception = None
+ """The exception that was returned by the previous handler in the
+ exception chain, if any.
+
+ If present, this exception will be the one ultimately raised by
+ SQLAlchemy unless a subsequent handler replaces it.
+
+ May be None.
+
+ """
+
+ execution_context = None
+ """The :class:`.ExecutionContext` corresponding to the execution
+ operation in progress.
+
+ This is present for statement execution operations, but not for
+ operations such as transaction begin/end. It also is not present when
+ the exception was raised before the :class:`.ExecutionContext`
+ could be constructed.
+
+ Note that the :attr:`.ExceptionContext.statement` and
+ :attr:`.ExceptionContext.parameters` members may represent a
+ different value than that of the :class:`.ExecutionContext`,
+ potentially in the case where a
+ :meth:`_events.ConnectionEvents.before_cursor_execute` event or similar
+ modified the statement/parameters to be sent.
+
+ May be None.
+
+ """
+
+ is_disconnect = None
+ """Represent whether the exception as occurred represents a "disconnect"
+ condition.
+
+ This flag will always be True or False within the scope of the
+ :meth:`_events.ConnectionEvents.handle_error` handler.
+
+ SQLAlchemy will defer to this flag in order to determine whether or not
+ the connection should be invalidated subsequently. That is, by
+ assigning to this flag, a "disconnect" event which then results in
+ a connection and pool invalidation can be invoked or prevented by
+ changing this flag.
+
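+ E.g., a hedged sketch of a handler that forces a particular (purely
+ illustrative) driver error code to be treated as a disconnect,
+ assuming an :class:`_engine.Engine` named ``engine``::
+
+ from sqlalchemy import event
+
+ @event.listens_for(engine, "handle_error")
+ def force_disconnect(context):
+     if getattr(context.original_exception, "args", [None])[0] == 12345:
+         context.is_disconnect = True
+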
+
+ .. note:: The pool "pre_ping" handler enabled using the
+ :paramref:`_sa.create_engine.pool_pre_ping` parameter does **not**
+ consult this event before deciding if the "ping" returned false,
+ as opposed to receiving an unhandled error. For this use case, the
+ :ref:`legacy recipe based on engine_connect() may be used
+ <pool_disconnects_pessimistic_custom>`. A future API will allow more
+ comprehensive customization of the "disconnect" detection mechanism
+ across all functions.
+
+ """
+
+ invalidate_pool_on_disconnect = True
+ """Represent whether all connections in the pool should be invalidated
+ when a "disconnect" condition is in effect.
+
+ Setting this flag to False within the scope of the
+ :meth:`_events.ConnectionEvents.handle_error`
+ event will have the effect such
+ that the full collection of connections in the pool will not be
+ invalidated during a disconnect; only the current connection that is the
+ subject of the error will actually be invalidated.
+
+ The purpose of this flag is for custom disconnect-handling schemes where
+ the invalidation of other connections in the pool is to be performed
+ based on other conditions, or even on a per-connection basis.
+
+ .. versionadded:: 1.0.3
+
+ """
+
+
+class AdaptedConnection(object):
+ """Interface of an adapted connection object to support the DBAPI protocol.
+
+ Used by asyncio dialects to provide a sync-style pep-249 facade on top
+ of the asyncio connection/cursor API provided by the driver.
+
+ .. versionadded:: 1.4.24
+
+ """
+
+ __slots__ = ("_connection",)
+
+ @property
+ def driver_connection(self):
+ """The connection object as returned by the driver after a connect."""
+ return self._connection
+
+ def run_async(self, fn):
+ """Run the awaitable returned by the given function, which is passed
+ the raw asyncio driver connection.
+
+ This is used to invoke awaitable-only methods on the driver connection
+ within the context of a "synchronous" method, like a connection
+ pool event handler.
+
+ E.g.::
+
+ engine = create_async_engine(...)
+
+ @event.listens_for(engine.sync_engine, "connect")
+ def register_custom_types(dbapi_connection, ...):
+ dbapi_connection.run_async(
+ lambda connection: connection.set_type_codec(
+ 'MyCustomType', encoder, decoder, ...
+ )
+ )
+
+ .. versionadded:: 1.4.30
+
+ .. seealso::
+
+ :ref:`asyncio_events_run_async`
+
+ """
+ return await_only(fn(self._connection))
+
+ def __repr__(self):
+ return "<AdaptedConnection %s>" % self._connection
diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py
new file mode 100644
index 0000000..6fcb09f
--- /dev/null
+++ b/lib/sqlalchemy/engine/mock.py
@@ -0,0 +1,118 @@
+# engine/mock.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from operator import attrgetter
+
+from . import base
+from . import url as _url
+from .. import util
+from ..sql import ddl
+
+
+class MockConnection(base.Connectable):
+ def __init__(self, dialect, execute):
+ self._dialect = dialect
+ self.execute = execute
+
+ engine = property(lambda s: s)
+ dialect = property(attrgetter("_dialect"))
+ name = property(lambda s: s._dialect.name)
+
+ def schema_for_object(self, obj):
+ return obj.schema
+
+ def connect(self, **kwargs):
+ return self
+
+ def execution_options(self, **kw):
+ return self
+
+ def compiler(self, statement, parameters, **kwargs):
+ return self._dialect.compiler(
+ statement, parameters, engine=self, **kwargs
+ )
+
+ def create(self, entity, **kwargs):
+ kwargs["checkfirst"] = False
+
+ ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse_single(
+ entity
+ )
+
+ def drop(self, entity, **kwargs):
+ kwargs["checkfirst"] = False
+
+ ddl.SchemaDropper(self.dialect, self, **kwargs).traverse_single(entity)
+
+ def _run_ddl_visitor(
+ self, visitorcallable, element, connection=None, **kwargs
+ ):
+ kwargs["checkfirst"] = False
+ visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
+
+ def execute(self, object_, *multiparams, **params):
+ raise NotImplementedError()
+
+
+def create_mock_engine(url, executor, **kw):
+ """Create a "mock" engine used for echoing DDL.
+
+ This is a utility function used for debugging or storing the output of DDL
+ sequences as generated by :meth:`_schema.MetaData.create_all`
+ and related methods.
+
+ The function accepts a URL which is used only to determine the kind of
+ dialect to be used, as well as an "executor" callable function which
+ will receive a SQL expression object and parameters, which can then be
+ echoed or otherwise printed. The executor's return value is not handled,
+ nor does the engine allow regular string statements to be invoked; it
+ is therefore only useful for DDL that is sent to the database without
+ receiving any results.
+
+ E.g.::
+
+ from sqlalchemy import create_mock_engine
+
+ def dump(sql, *multiparams, **params):
+ print(sql.compile(dialect=engine.dialect))
+
+ engine = create_mock_engine('postgresql://', dump)
+ metadata.create_all(engine, checkfirst=False)
+
+ :param url: A string URL which typically needs to contain only the
+ database backend name.
+
+ :param executor: a callable which receives the arguments ``sql``,
+ ``*multiparams`` and ``**params``. The ``sql`` parameter is typically
+ an instance of :class:`.DDLElement`, which can then be compiled into a
+ string using :meth:`.DDLElement.compile`.
+
+ .. versionadded:: 1.4 - the :func:`.create_mock_engine` function replaces
+ the previous "mock" engine strategy used with
+ :func:`_sa.create_engine`.
+
+ .. seealso::
+
+ :ref:`faq_ddl_as_string`
+
+ """
+
+ # create url.URL object
+ u = _url.make_url(url)
+
+ dialect_cls = u.get_dialect()
+
+ dialect_args = {}
+ # consume dialect arguments from kwargs
+ for k in util.get_cls_kwargs(dialect_cls):
+ if k in kw:
+ dialect_args[k] = kw.pop(k)
+
+ # create dialect
+ dialect = dialect_cls(**dialect_args)
+
+ return MockConnection(dialect, executor)
diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py
new file mode 100644
index 0000000..b475228
--- /dev/null
+++ b/lib/sqlalchemy/engine/reflection.py
@@ -0,0 +1,1160 @@
+# engine/reflection.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Provides an abstraction for obtaining database schema information.
+
+Usage Notes:
+
+Here are some general conventions when accessing the low level inspector
+methods such as get_table_names, get_columns, etc.
+
+1. Inspector methods return lists of dicts in most cases for the following
+ reasons:
+
+ * They're both standard types that can be serialized.
+ * Using a dict instead of a tuple allows easy expansion of attributes.
+ * Using a list for the outer structure maintains order and is easy to work
+ with (e.g. list comprehension [d['name'] for d in cols]).
+
+2. Records that contain a name, such as the column name in a column record
+ use the key 'name'. So for most return values, each record will have a
+ 'name' attribute.
+"""
+
+import contextlib
+
+from .base import Connectable
+from .base import Connection
+from .base import Engine
+from .. import exc
+from .. import inspection
+from .. import sql
+from .. import util
+from ..sql import operators
+from ..sql import schema as sa_schema
+from ..sql.type_api import TypeEngine
+from ..util import topological
+
+
+@util.decorator
+def cache(fn, self, con, *args, **kw):
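+ # Memoize dialect-level reflection calls in the "info_cache" dictionary
+ # passed through kwargs; the cache key is the wrapped function's name
+ # plus its string positional args and non-cache keyword args. When no
+ # info_cache is supplied, the call passes through uncached.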
+ info_cache = kw.get("info_cache", None)
+ if info_cache is None:
+ return fn(self, con, *args, **kw)
+ key = (
+ fn.__name__,
+ tuple(a for a in args if isinstance(a, util.string_types)),
+ tuple((k, v) for k, v in kw.items() if k != "info_cache"),
+ )
+ ret = info_cache.get(key)
+ if ret is None:
+ ret = fn(self, con, *args, **kw)
+ info_cache[key] = ret
+ return ret
+
+
+@inspection._self_inspects
+class Inspector(object):
+ """Performs database schema inspection.
+
+ The Inspector acts as a proxy to the reflection methods of the
+ :class:`~sqlalchemy.engine.interfaces.Dialect`, providing a
+ consistent interface as well as caching support for previously
+ fetched metadata.
+
+ A :class:`_reflection.Inspector` object is usually created via the
+ :func:`_sa.inspect` function, which may be passed an
+ :class:`_engine.Engine`
+ or a :class:`_engine.Connection`::
+
+ from sqlalchemy import inspect, create_engine
+ engine = create_engine('...')
+ insp = inspect(engine)
+
+ Where above, the :class:`~sqlalchemy.engine.interfaces.Dialect` associated
+ with the engine may opt to return an :class:`_reflection.Inspector`
+ subclass that
+ provides additional methods specific to the dialect's target database.
+
+ """
+
+ @util.deprecated(
+ "1.4",
+ "The __init__() method on :class:`_reflection.Inspector` "
+ "is deprecated and "
+ "will be removed in a future release. Please use the "
+ ":func:`.sqlalchemy.inspect` "
+ "function on an :class:`_engine.Engine` or "
+ ":class:`_engine.Connection` "
+ "in order to "
+ "acquire an :class:`_reflection.Inspector`.",
+ )
+ def __init__(self, bind):
+ """Initialize a new :class:`_reflection.Inspector`.
+
+ :param bind: a :class:`~sqlalchemy.engine.Connectable`,
+ which is typically an instance of
+ :class:`~sqlalchemy.engine.Engine` or
+ :class:`~sqlalchemy.engine.Connection`.
+
+ For a dialect-specific instance of :class:`_reflection.Inspector`, see
+ :meth:`_reflection.Inspector.from_engine`
+
+ """
+ return self._init_legacy(bind)
+
+ @classmethod
+ def _construct(cls, init, bind):
+
+ if hasattr(bind.dialect, "inspector"):
+ cls = bind.dialect.inspector
+
+ self = cls.__new__(cls)
+ init(self, bind)
+ return self
+
+ def _init_legacy(self, bind):
+ if hasattr(bind, "exec_driver_sql"):
+ self._init_connection(bind)
+ else:
+ self._init_engine(bind)
+
+ def _init_engine(self, engine):
+ self.bind = self.engine = engine
+ engine.connect().close()
+ self._op_context_requires_connect = True
+ self.dialect = self.engine.dialect
+ self.info_cache = {}
+
+ def _init_connection(self, connection):
+ self.bind = connection
+ self.engine = connection.engine
+ self._op_context_requires_connect = False
+ self.dialect = self.engine.dialect
+ self.info_cache = {}
+
+ @classmethod
+ @util.deprecated(
+ "1.4",
+ "The from_engine() method on :class:`_reflection.Inspector` "
+ "is deprecated and "
+ "will be removed in a future release. Please use the "
+ ":func:`.sqlalchemy.inspect` "
+ "function on an :class:`_engine.Engine` or "
+ ":class:`_engine.Connection` "
+ "in order to "
+ "acquire an :class:`_reflection.Inspector`.",
+ )
+ def from_engine(cls, bind):
+ """Construct a new dialect-specific Inspector object from the given
+ engine or connection.
+
+ :param bind: a :class:`~sqlalchemy.engine.Connectable`,
+ which is typically an instance of
+ :class:`~sqlalchemy.engine.Engine` or
+ :class:`~sqlalchemy.engine.Connection`.
+
+ This method differs from a direct constructor call of
+ :class:`_reflection.Inspector` in that the
+ :class:`~sqlalchemy.engine.interfaces.Dialect` is given a chance to
+ provide a dialect-specific :class:`_reflection.Inspector` instance,
+ which may
+ provide additional methods.
+
+ See the example at :class:`_reflection.Inspector`.
+
+ """
+ return cls._construct(cls._init_legacy, bind)
+
+ @inspection._inspects(Connectable)
+ def _connectable_insp(bind):
+ # this method should not be used unless some unusual case
+ # has subclassed "Connectable"
+
+ return Inspector._construct(Inspector._init_legacy, bind)
+
+ @inspection._inspects(Engine)
+ def _engine_insp(bind):
+ return Inspector._construct(Inspector._init_engine, bind)
+
+ @inspection._inspects(Connection)
+ def _connection_insp(bind):
+ return Inspector._construct(Inspector._init_connection, bind)
+
+ @contextlib.contextmanager
+ def _operation_context(self):
+ """Return a context that optimizes for multiple operations on a single
+ transaction.
+
+ This essentially allows connect()/close() to be called if we detected
+ that we're against an :class:`_engine.Engine` and not a
+ :class:`_engine.Connection`.
+
+ """
+ if self._op_context_requires_connect:
+ conn = self.bind.connect()
+ else:
+ conn = self.bind
+ try:
+ yield conn
+ finally:
+ if self._op_context_requires_connect:
+ conn.close()
+
+ @contextlib.contextmanager
+ def _inspection_context(self):
+ """Return an :class:`_reflection.Inspector`
+ from this one that will run all
+ operations on a single connection.
+
+ """
+
+ with self._operation_context() as conn:
+ sub_insp = self._construct(self.__class__._init_connection, conn)
+ sub_insp.info_cache = self.info_cache
+ yield sub_insp
+
+ @property
+ def default_schema_name(self):
+ """Return the default schema name presented by the dialect
+ for the current engine's database user.
+
+ E.g. this is typically ``public`` for PostgreSQL and ``dbo``
+ for SQL Server.
+
+ """
+ return self.dialect.default_schema_name
+
+ def get_schema_names(self):
+ """Return all schema names."""
+
+ if hasattr(self.dialect, "get_schema_names"):
+ with self._operation_context() as conn:
+ return self.dialect.get_schema_names(
+ conn, info_cache=self.info_cache
+ )
+ return []
+
+ def get_table_names(self, schema=None):
+ """Return all table names in referred to within a particular schema.
+
+ The names are expected to be real tables only, not views.
+ Views are instead returned using the
+ :meth:`_reflection.Inspector.get_view_names`
+ method.
+
+
+ :param schema: Schema name. If ``schema`` is left at ``None``, the
+ database's default schema is
+ used, else the named schema is searched. If the database does not
+ support named schemas, behavior is undefined if ``schema`` is not
+ passed as ``None``. For special quoting, use :class:`.quoted_name`.
+
+ .. seealso::
+
+ :meth:`_reflection.Inspector.get_sorted_table_and_fkc_names`
+
+ :attr:`_schema.MetaData.sorted_tables`
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_table_names(
+ conn, schema, info_cache=self.info_cache
+ )
+
+ def has_table(self, table_name, schema=None):
+ """Return True if the backend has a table of the given name.
+
+
+ :param table_name: name of the table to check
+ :param schema: schema name to query, if not the default schema.
+
+ .. versionadded:: 1.4 - the :meth:`.Inspector.has_table` method
+ replaces the :meth:`_engine.Engine.has_table` method.
+
+ """
+ # TODO: info_cache?
+ with self._operation_context() as conn:
+ return self.dialect.has_table(conn, table_name, schema)
+
+ def has_sequence(self, sequence_name, schema=None):
+ """Return True if the backend has a table of the given name.
+
+ :param sequence_name: name of the table to check
+ :param schema: schema name to query, if not the default schema.
+
+ .. versionadded:: 1.4
+
+ """
+ # TODO: info_cache?
+ with self._operation_context() as conn:
+ return self.dialect.has_sequence(conn, sequence_name, schema)
+
+ def get_sorted_table_and_fkc_names(self, schema=None):
+ """Return dependency-sorted table and foreign key constraint names in
+ referred to within a particular schema.
+
+ This will yield 2-tuples of
+ ``(tablename, [(tname, fkname), (tname, fkname), ...])``
+ consisting of table names in CREATE order grouped with the foreign key
+ constraint names that are not detected as belonging to a cycle.
+ The final element
+ will be ``(None, [(tname, fkname), (tname, fkname), ..])``
+ which will consist of remaining
+ foreign key constraint names that would require a separate CREATE
+ step after-the-fact, based on dependencies between tables.
+
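+ A brief usage sketch (``insp`` here is assumed to be an existing
+ :class:`_reflection.Inspector`)::
+
+ for tname, fkcs in insp.get_sorted_table_and_fkc_names():
+     if tname is not None:
+         print("CREATE table %s; constraints: %s" % (tname, fkcs))
+     else:
+         print("remaining constraints to add afterwards: %s" % (fkcs,))
+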
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`_reflection.Inspector.get_table_names`
+
+ :func:`.sort_tables_and_constraints` - similar method which works
+ with an already-given :class:`_schema.MetaData`.
+
+ """
+
+ with self._operation_context() as conn:
+ tnames = self.dialect.get_table_names(
+ conn, schema, info_cache=self.info_cache
+ )
+
+ tuples = set()
+ remaining_fkcs = set()
+
+ fknames_for_table = {}
+ for tname in tnames:
+ fkeys = self.get_foreign_keys(tname, schema)
+ fknames_for_table[tname] = set([fk["name"] for fk in fkeys])
+ for fkey in fkeys:
+ if tname != fkey["referred_table"]:
+ tuples.add((fkey["referred_table"], tname))
+ try:
+ candidate_sort = list(topological.sort(tuples, tnames))
+ except exc.CircularDependencyError as err:
+ for edge in err.edges:
+ tuples.remove(edge)
+ remaining_fkcs.update(
+ (edge[1], fkc) for fkc in fknames_for_table[edge[1]]
+ )
+
+ candidate_sort = list(topological.sort(tuples, tnames))
+ return [
+ (tname, fknames_for_table[tname].difference(remaining_fkcs))
+ for tname in candidate_sort
+ ] + [(None, list(remaining_fkcs))]
+
+ def get_temp_table_names(self):
+ """Return a list of temporary table names for the current bind.
+
+ This method is unsupported by most dialects; currently
+ only SQLite implements it.
+
+ .. versionadded:: 1.0.0
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_temp_table_names(
+ conn, info_cache=self.info_cache
+ )
+
+ def get_temp_view_names(self):
+ """Return a list of temporary view names for the current bind.
+
+ This method is unsupported by most dialects; currently
+ only SQLite implements it.
+
+ .. versionadded:: 1.0.0
+
+ """
+ with self._operation_context() as conn:
+ return self.dialect.get_temp_view_names(
+ conn, info_cache=self.info_cache
+ )
+
+ def get_table_options(self, table_name, schema=None, **kw):
+ """Return a dictionary of options specified when the table of the
+ given name was created.
+
+ This currently includes some options that apply to MySQL tables.
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ """
+ if hasattr(self.dialect, "get_table_options"):
+ with self._operation_context() as conn:
+ return self.dialect.get_table_options(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+ return {}
+
+ def get_view_names(self, schema=None):
+ """Return all view names in `schema`.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+ For special quoting, use :class:`.quoted_name`.
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_view_names(
+ conn, schema, info_cache=self.info_cache
+ )
+
+ def get_sequence_names(self, schema=None):
+ """Return all sequence names in `schema`.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+ For special quoting, use :class:`.quoted_name`.
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_sequence_names(
+ conn, schema, info_cache=self.info_cache
+ )
+
+ def get_view_definition(self, view_name, schema=None):
+ """Return definition for `view_name`.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+ For special quoting, use :class:`.quoted_name`.
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_view_definition(
+ conn, view_name, schema, info_cache=self.info_cache
+ )
+
+ def get_columns(self, table_name, schema=None, **kw):
+ """Return information about columns in `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ column information as a list of dicts with these keys:
+
+ * ``name`` - the column's name
+
+ * ``type`` - the type of this column; an instance of
+ :class:`~sqlalchemy.types.TypeEngine`
+
+ * ``nullable`` - boolean flag if the column is NULL or NOT NULL
+
+ * ``default`` - the column's server default value - this is returned
+ as a string SQL expression.
+
+ * ``autoincrement`` - indicates that the column is auto incremented -
+ this is returned as a boolean or 'auto'
+
+ * ``comment`` - (optional) the comment on the column. Only some
+ dialects return this key
+
+ * ``computed`` - (optional) when present it indicates that this column
+ is computed by the database. Only some dialects return this key.
+ Returned as a dict with the keys:
+
+ * ``sqltext`` - the expression used to generate this column returned
+ as a string SQL expression
+
+ * ``persisted`` - (optional) boolean that indicates if the column is
+ stored in the table
+
+ .. versionadded:: 1.3.16 - added support for computed reflection.
+
+ * ``identity`` - (optional) when present it indicates that this column
+ is a generated always column. Only some dialects return this key.
+ For a list of keywords on this dict see :class:`_schema.Identity`.
+
+ .. versionadded:: 1.4 - added support for identity column reflection.
+
+ * ``dialect_options`` - (optional) a dict with dialect specific options
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :return: list of dictionaries, each representing the definition of
+ a database column.
+
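+ A brief usage sketch (``insp`` is assumed to be an existing
+ :class:`_reflection.Inspector` and ``"user_account"`` an existing
+ table name)::
+
+ for col in insp.get_columns("user_account"):
+     print(col["name"], col["type"], col["nullable"])
+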
+ """
+
+ with self._operation_context() as conn:
+ col_defs = self.dialect.get_columns(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+ for col_def in col_defs:
+ # make this easy and only return instances for coltype
+ coltype = col_def["type"]
+ if not isinstance(coltype, TypeEngine):
+ col_def["type"] = coltype()
+ return col_defs
+
+ def get_pk_constraint(self, table_name, schema=None, **kw):
+ """Return information about primary key constraint on `table_name`.
+
+ Given a string `table_name`, and an optional string `schema`, return
+ primary key information as a dictionary with these keys:
+
+ * ``constrained_columns`` -
+ a list of column names that make up the primary key
+
+ * ``name`` -
+ optional name of the primary key constraint.
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ """
+ with self._operation_context() as conn:
+ return self.dialect.get_pk_constraint(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ def get_foreign_keys(self, table_name, schema=None, **kw):
+ """Return information about foreign_keys in `table_name`.
+
+ Given a string `table_name`, and an optional string `schema`, return
+ foreign key information as a list of dicts with these keys:
+
+ * ``constrained_columns`` -
+ a list of column names that make up the foreign key
+
+ * ``referred_schema`` -
+ the name of the referred schema
+
+ * ``referred_table`` -
+ the name of the referred table
+
+ * ``referred_columns`` -
+ a list of column names in the referred table that correspond to
+ constrained_columns
+
+ * ``name`` -
+ optional name of the foreign key constraint.
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_foreign_keys(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ def get_indexes(self, table_name, schema=None, **kw):
+ """Return information about indexes in `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ index information as a list of dicts with these keys:
+
+ * ``name`` -
+ the index's name
+
+ * ``column_names`` -
+ list of column names in order
+
+ * ``unique`` -
+ boolean
+
+ * ``column_sorting`` -
+          optional dict mapping column names to a tuple of sort keywords,
+ which may include ``asc``, ``desc``, ``nulls_first``, ``nulls_last``.
+
+ .. versionadded:: 1.3.5
+
+ * ``dialect_options`` -
+ dict of dialect-specific index options. May not be present
+ for all dialects.
+
+ .. versionadded:: 1.0.0
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
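+        For example, a usage sketch, assuming an :class:`_engine.Engine`
+        named ``engine`` and an existing table named ``user_account``
+        (names are illustrative)::
+
+            from sqlalchemy import inspect
+
+            insp = inspect(engine)
+            for idx in insp.get_indexes("user_account"):
+                print(idx["name"], idx["column_names"], idx["unique"])
+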
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_indexes(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ def get_unique_constraints(self, table_name, schema=None, **kw):
+ """Return information about unique constraints in `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ unique constraint information as a list of dicts with these keys:
+
+ * ``name`` -
+ the unique constraint's name
+
+ * ``column_names`` -
+ list of column names in order
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_unique_constraints(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ def get_table_comment(self, table_name, schema=None, **kw):
+ """Return information about the table comment for ``table_name``.
+
+ Given a string ``table_name`` and an optional string ``schema``,
+ return table comment information as a dictionary with these keys:
+
+ * ``text`` -
+ text of the comment.
+
+ Raises ``NotImplementedError`` for a dialect that does not support
+ comments.
+
+ .. versionadded:: 1.2
+
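+        For example, a usage sketch, assuming an :class:`_engine.Engine`
+        named ``engine`` connected to a backend that supports comments
+        (the table name is illustrative)::
+
+            from sqlalchemy import inspect
+
+            insp = inspect(engine)
+            comment_text = insp.get_table_comment("user_account")["text"]
+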
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_table_comment(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ def get_check_constraints(self, table_name, schema=None, **kw):
+ """Return information about check constraints in `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ check constraint information as a list of dicts with these keys:
+
+ * ``name`` -
+ the check constraint's name
+
+ * ``sqltext`` -
+ the check constraint's SQL expression
+
+ * ``dialect_options`` -
+ may or may not be present; a dictionary with additional
+ dialect-specific options for this CHECK constraint
+
+ .. versionadded:: 1.3.8
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ .. versionadded:: 1.1.0
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_check_constraints(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ @util.deprecated_20(
+ ":meth:`_reflection.Inspector.reflecttable`",
+ "The :meth:`_reflection.Inspector.reflecttable` "
+ "method was renamed to "
+ ":meth:`_reflection.Inspector.reflect_table`. This deprecated alias "
+ "will be removed in a future release.",
+ )
+ def reflecttable(self, *args, **kwargs):
+ "See reflect_table. This method name is deprecated"
+ return self.reflect_table(*args, **kwargs)
+
+ def reflect_table(
+ self,
+ table,
+ include_columns,
+ exclude_columns=(),
+ resolve_fks=True,
+ _extend_on=None,
+ ):
+ """Given a :class:`_schema.Table` object, load its internal
+ constructs based on introspection.
+
+ This is the underlying method used by most dialects to produce
+ table reflection. Direct usage is like::
+
+ from sqlalchemy import create_engine, MetaData, Table
+ from sqlalchemy import inspect
+
+ engine = create_engine('...')
+ meta = MetaData()
+ user_table = Table('user', meta)
+ insp = inspect(engine)
+ insp.reflect_table(user_table, None)
+
+ .. versionchanged:: 1.4 Renamed from ``reflecttable`` to
+ ``reflect_table``
+
+ :param table: a :class:`~sqlalchemy.schema.Table` instance.
+ :param include_columns: a list of string column names to include
+ in the reflection process. If ``None``, all columns are reflected.
+
+ """
+
+ if _extend_on is not None:
+ if table in _extend_on:
+ return
+ else:
+ _extend_on.add(table)
+
+ dialect = self.bind.dialect
+
+ with self._operation_context() as conn:
+ schema = conn.schema_for_object(table)
+
+ table_name = table.name
+
+ # get table-level arguments that are specifically
+ # intended for reflection, e.g. oracle_resolve_synonyms.
+ # these are unconditionally passed to related Table
+ # objects
+ reflection_options = dict(
+ (k, table.dialect_kwargs.get(k))
+ for k in dialect.reflection_options
+ if k in table.dialect_kwargs
+ )
+
+ # reflect table options, like mysql_engine
+ tbl_opts = self.get_table_options(
+ table_name, schema, **table.dialect_kwargs
+ )
+ if tbl_opts:
+ # add additional kwargs to the Table if the dialect
+ # returned them
+ table._validate_dialect_kwargs(tbl_opts)
+
+ if util.py2k:
+ if isinstance(schema, str):
+ schema = schema.decode(dialect.encoding)
+ if isinstance(table_name, str):
+ table_name = table_name.decode(dialect.encoding)
+
+ found_table = False
+ cols_by_orig_name = {}
+
+ for col_d in self.get_columns(
+ table_name, schema, **table.dialect_kwargs
+ ):
+ found_table = True
+
+ self._reflect_column(
+ table,
+ col_d,
+ include_columns,
+ exclude_columns,
+ cols_by_orig_name,
+ )
+
+ # NOTE: support tables/views with no columns
+ if not found_table and not self.has_table(table_name, schema):
+ raise exc.NoSuchTableError(table_name)
+
+ self._reflect_pk(
+ table_name, schema, table, cols_by_orig_name, exclude_columns
+ )
+
+ self._reflect_fk(
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ resolve_fks,
+ _extend_on,
+ reflection_options,
+ )
+
+ self._reflect_indexes(
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ )
+
+ self._reflect_unique_constraints(
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ )
+
+ self._reflect_check_constraints(
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ )
+
+ self._reflect_table_comment(
+ table_name, schema, table, reflection_options
+ )
+
+ def _reflect_column(
+ self, table, col_d, include_columns, exclude_columns, cols_by_orig_name
+ ):
+
+ orig_name = col_d["name"]
+
+ table.metadata.dispatch.column_reflect(self, table, col_d)
+ table.dispatch.column_reflect(self, table, col_d)
+
+ # fetch name again as column_reflect is allowed to
+ # change it
+ name = col_d["name"]
+ if (include_columns and name not in include_columns) or (
+ exclude_columns and name in exclude_columns
+ ):
+ return
+
+ coltype = col_d["type"]
+
+ col_kw = dict(
+ (k, col_d[k])
+ for k in [
+ "nullable",
+ "autoincrement",
+ "quote",
+ "info",
+ "key",
+ "comment",
+ ]
+ if k in col_d
+ )
+
+ if "dialect_options" in col_d:
+ col_kw.update(col_d["dialect_options"])
+
+ colargs = []
+ if col_d.get("default") is not None:
+ default = col_d["default"]
+ if isinstance(default, sql.elements.TextClause):
+ default = sa_schema.DefaultClause(default, _reflected=True)
+ elif not isinstance(default, sa_schema.FetchedValue):
+ default = sa_schema.DefaultClause(
+ sql.text(col_d["default"]), _reflected=True
+ )
+
+ colargs.append(default)
+
+ if "computed" in col_d:
+ computed = sa_schema.Computed(**col_d["computed"])
+ colargs.append(computed)
+
+ if "identity" in col_d:
+            identity = sa_schema.Identity(**col_d["identity"])
+            colargs.append(identity)
+
+ if "sequence" in col_d:
+ self._reflect_col_sequence(col_d, colargs)
+
+ cols_by_orig_name[orig_name] = col = sa_schema.Column(
+ name, coltype, *colargs, **col_kw
+ )
+
+ if col.key in table.primary_key:
+ col.primary_key = True
+ table.append_column(col, replace_existing=True)
+
+ def _reflect_col_sequence(self, col_d, colargs):
+ if "sequence" in col_d:
+ # TODO: mssql and sybase are using this.
+ seq = col_d["sequence"]
+ sequence = sa_schema.Sequence(seq["name"], 1, 1)
+ if "start" in seq:
+ sequence.start = seq["start"]
+ if "increment" in seq:
+ sequence.increment = seq["increment"]
+ colargs.append(sequence)
+
+ def _reflect_pk(
+ self, table_name, schema, table, cols_by_orig_name, exclude_columns
+ ):
+ pk_cons = self.get_pk_constraint(
+ table_name, schema, **table.dialect_kwargs
+ )
+ if pk_cons:
+ pk_cols = [
+ cols_by_orig_name[pk]
+ for pk in pk_cons["constrained_columns"]
+ if pk in cols_by_orig_name and pk not in exclude_columns
+ ]
+
+ # update pk constraint name
+ table.primary_key.name = pk_cons.get("name")
+
+ # tell the PKConstraint to re-initialize
+ # its column collection
+ table.primary_key._reload(pk_cols)
+
+ def _reflect_fk(
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ resolve_fks,
+ _extend_on,
+ reflection_options,
+ ):
+ fkeys = self.get_foreign_keys(
+ table_name, schema, **table.dialect_kwargs
+ )
+ for fkey_d in fkeys:
+ conname = fkey_d["name"]
+ # look for columns by orig name in cols_by_orig_name,
+ # but support columns that are in-Python only as fallback
+ constrained_columns = [
+ cols_by_orig_name[c].key if c in cols_by_orig_name else c
+ for c in fkey_d["constrained_columns"]
+ ]
+
+ if (
+ exclude_columns
+ and set(constrained_columns).intersection(exclude_columns)
+ or (
+ include_columns
+ and set(constrained_columns).difference(include_columns)
+ )
+ ):
+ continue
+
+ referred_schema = fkey_d["referred_schema"]
+ referred_table = fkey_d["referred_table"]
+ referred_columns = fkey_d["referred_columns"]
+ refspec = []
+ if referred_schema is not None:
+ if resolve_fks:
+ sa_schema.Table(
+ referred_table,
+ table.metadata,
+ schema=referred_schema,
+ autoload_with=self.bind,
+ _extend_on=_extend_on,
+ **reflection_options
+ )
+ for column in referred_columns:
+ refspec.append(
+ ".".join([referred_schema, referred_table, column])
+ )
+ else:
+ if resolve_fks:
+ sa_schema.Table(
+ referred_table,
+ table.metadata,
+ autoload_with=self.bind,
+ schema=sa_schema.BLANK_SCHEMA,
+ _extend_on=_extend_on,
+ **reflection_options
+ )
+ for column in referred_columns:
+ refspec.append(".".join([referred_table, column]))
+ if "options" in fkey_d:
+ options = fkey_d["options"]
+ else:
+ options = {}
+
+ table.append_constraint(
+ sa_schema.ForeignKeyConstraint(
+ constrained_columns,
+ refspec,
+ conname,
+ link_to_name=True,
+ **options
+ )
+ )
+
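+    # maps the reflected "column_sorting" keywords to the operator
+    # functions applied to index column expressions in _reflect_indexes()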
+ _index_sort_exprs = [
+ ("asc", operators.asc_op),
+ ("desc", operators.desc_op),
+ ("nulls_first", operators.nulls_first_op),
+ ("nulls_last", operators.nulls_last_op),
+ ]
+
+ def _reflect_indexes(
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ ):
+ # Indexes
+ indexes = self.get_indexes(table_name, schema)
+ for index_d in indexes:
+ name = index_d["name"]
+ columns = index_d["column_names"]
+ column_sorting = index_d.get("column_sorting", {})
+ unique = index_d["unique"]
+ flavor = index_d.get("type", "index")
+ dialect_options = index_d.get("dialect_options", {})
+
+ duplicates = index_d.get("duplicates_constraint")
+ if include_columns and not set(columns).issubset(include_columns):
+ util.warn(
+ "Omitting %s key for (%s), key covers omitted columns."
+ % (flavor, ", ".join(columns))
+ )
+ continue
+ if duplicates:
+ continue
+ # look for columns by orig name in cols_by_orig_name,
+ # but support columns that are in-Python only as fallback
+ idx_cols = []
+ for c in columns:
+ try:
+ idx_col = (
+ cols_by_orig_name[c]
+ if c in cols_by_orig_name
+ else table.c[c]
+ )
+ except KeyError:
+ util.warn(
+ "%s key '%s' was not located in "
+ "columns for table '%s'" % (flavor, c, table_name)
+ )
+ continue
+ c_sorting = column_sorting.get(c, ())
+ for k, op in self._index_sort_exprs:
+ if k in c_sorting:
+ idx_col = op(idx_col)
+ idx_cols.append(idx_col)
+
+ sa_schema.Index(
+ name,
+ *idx_cols,
+ _table=table,
+ **dict(list(dialect_options.items()) + [("unique", unique)])
+ )
+
+ def _reflect_unique_constraints(
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ ):
+
+ # Unique Constraints
+ try:
+ constraints = self.get_unique_constraints(table_name, schema)
+ except NotImplementedError:
+ # optional dialect feature
+ return
+
+ for const_d in constraints:
+ conname = const_d["name"]
+ columns = const_d["column_names"]
+ duplicates = const_d.get("duplicates_index")
+ if include_columns and not set(columns).issubset(include_columns):
+ util.warn(
+ "Omitting unique constraint key for (%s), "
+ "key covers omitted columns." % ", ".join(columns)
+ )
+ continue
+ if duplicates:
+ continue
+ # look for columns by orig name in cols_by_orig_name,
+ # but support columns that are in-Python only as fallback
+ constrained_cols = []
+ for c in columns:
+ try:
+ constrained_col = (
+ cols_by_orig_name[c]
+ if c in cols_by_orig_name
+ else table.c[c]
+ )
+ except KeyError:
+ util.warn(
+ "unique constraint key '%s' was not located in "
+ "columns for table '%s'" % (c, table_name)
+ )
+ else:
+ constrained_cols.append(constrained_col)
+ table.append_constraint(
+ sa_schema.UniqueConstraint(*constrained_cols, name=conname)
+ )
+
+ def _reflect_check_constraints(
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ ):
+ try:
+ constraints = self.get_check_constraints(table_name, schema)
+ except NotImplementedError:
+ # optional dialect feature
+ return
+
+ for const_d in constraints:
+ table.append_constraint(sa_schema.CheckConstraint(**const_d))
+
+ def _reflect_table_comment(
+ self, table_name, schema, table, reflection_options
+ ):
+ try:
+ comment_dict = self.get_table_comment(table_name, schema)
+ except NotImplementedError:
+ return
+ else:
+ table.comment = comment_dict.get("text", None)
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
new file mode 100644
index 0000000..1fd4e1c
--- /dev/null
+++ b/lib/sqlalchemy/engine/result.py
@@ -0,0 +1,1857 @@
+# engine/result.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Define generic result set constructs."""
+
+
+import functools
+import itertools
+import operator
+
+from .row import _baserow_usecext
+from .row import Row
+from .. import exc
+from .. import util
+from ..sql.base import _generative
+from ..sql.base import HasMemoized
+from ..sql.base import InPlaceGenerative
+from ..util import collections_abc
+from ..util import py2k
+
+
+if _baserow_usecext:
+ from sqlalchemy.cresultproxy import tuplegetter
+
+ _row_as_tuple = tuplegetter
+else:
+
+ def tuplegetter(*indexes):
+ it = operator.itemgetter(*indexes)
+
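+        # operator.itemgetter with a single index returns a scalar rather
+        # than a 1-tuple, so wrap the single-index case to preserve
+        # tuple-of-values semantics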
+ if len(indexes) > 1:
+ return it
+ else:
+ return lambda row: (it(row),)
+
+ def _row_as_tuple(*indexes):
+ # circumvent LegacyRow.__getitem__ pointing to
+ # _get_by_key_impl_mapping for now. otherwise we could
+ # use itemgetter
+ getters = [
+ operator.methodcaller("_get_by_int_impl", index)
+ for index in indexes
+ ]
+ return lambda rec: tuple([getter(rec) for getter in getters])
+
+
+class ResultMetaData(object):
+ """Base for metadata about result rows."""
+
+ __slots__ = ()
+
+ _tuplefilter = None
+ _translated_indexes = None
+ _unique_filters = None
+
+ @property
+ def keys(self):
+ return RMKeyView(self)
+
+ def _has_key(self, key):
+ raise NotImplementedError()
+
+ def _for_freeze(self):
+ raise NotImplementedError()
+
+ def _key_fallback(self, key, err, raiseerr=True):
+ assert raiseerr
+ util.raise_(KeyError(key), replace_context=err)
+
+ def _warn_for_nonint(self, key):
+ util.warn_deprecated_20(
+ "Retrieving row members using strings or other non-integers is "
+ "deprecated; use row._mapping for a dictionary interface "
+ "to the row"
+ )
+
+ def _raise_for_nonint(self, key):
+ raise TypeError(
+ "TypeError: tuple indices must be integers or slices, not %s"
+ % type(key).__name__
+ )
+
+ def _index_for_key(self, keys, raiseerr):
+ raise NotImplementedError()
+
+ def _metadata_for_keys(self, key):
+ raise NotImplementedError()
+
+ def _reduce(self, keys):
+ raise NotImplementedError()
+
+ def _getter(self, key, raiseerr=True):
+
+ index = self._index_for_key(key, raiseerr)
+
+ if index is not None:
+ return operator.itemgetter(index)
+ else:
+ return None
+
+ def _row_as_tuple_getter(self, keys):
+ indexes = self._indexes_for_keys(keys)
+ return _row_as_tuple(*indexes)
+
+
+class RMKeyView(collections_abc.KeysView):
+ __slots__ = ("_parent", "_keys")
+
+ def __init__(self, parent):
+ self._parent = parent
+ self._keys = [k for k in parent._keys if k is not None]
+
+ def __len__(self):
+ return len(self._keys)
+
+ def __repr__(self):
+ return "{0.__class__.__name__}({0._keys!r})".format(self)
+
+ def __iter__(self):
+ return iter(self._keys)
+
+ def __contains__(self, item):
+ if not _baserow_usecext and isinstance(item, int):
+ return False
+
+ # note this also includes special key fallback behaviors
+ # which also don't seem to be tested in test_resultset right now
+ return self._parent._has_key(item)
+
+ def __eq__(self, other):
+ return list(other) == list(self)
+
+ def __ne__(self, other):
+ return list(other) != list(self)
+
+
+class SimpleResultMetaData(ResultMetaData):
+ """result metadata for in-memory collections."""
+
+ __slots__ = (
+ "_keys",
+ "_keymap",
+ "_processors",
+ "_tuplefilter",
+ "_translated_indexes",
+ "_unique_filters",
+ )
+
+ def __init__(
+ self,
+ keys,
+ extra=None,
+ _processors=None,
+ _tuplefilter=None,
+ _translated_indexes=None,
+ _unique_filters=None,
+ ):
+ self._keys = list(keys)
+ self._tuplefilter = _tuplefilter
+ self._translated_indexes = _translated_indexes
+ self._unique_filters = _unique_filters
+
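+        # build (index, name, extras) records; each record is reachable in
+        # _keymap via the string name as well as via any alternate keys
+        # supplied positionally through the "extra" argument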
+ if extra:
+ recs_names = [
+ (
+ (name,) + extras,
+ (index, name, extras),
+ )
+ for index, (name, extras) in enumerate(zip(self._keys, extra))
+ ]
+ else:
+ recs_names = [
+ ((name,), (index, name, ()))
+ for index, name in enumerate(self._keys)
+ ]
+
+ self._keymap = {key: rec for keys, rec in recs_names for key in keys}
+
+ self._processors = _processors
+
+ def _has_key(self, key):
+ return key in self._keymap
+
+ def _for_freeze(self):
+ unique_filters = self._unique_filters
+ if unique_filters and self._tuplefilter:
+ unique_filters = self._tuplefilter(unique_filters)
+
+ # TODO: are we freezing the result with or without uniqueness
+ # applied?
+ return SimpleResultMetaData(
+ self._keys,
+ extra=[self._keymap[key][2] for key in self._keys],
+ _unique_filters=unique_filters,
+ )
+
+ def __getstate__(self):
+ return {
+ "_keys": self._keys,
+ "_translated_indexes": self._translated_indexes,
+ }
+
+ def __setstate__(self, state):
+ if state["_translated_indexes"]:
+ _translated_indexes = state["_translated_indexes"]
+ _tuplefilter = tuplegetter(*_translated_indexes)
+ else:
+ _translated_indexes = _tuplefilter = None
+ self.__init__(
+ state["_keys"],
+ _translated_indexes=_translated_indexes,
+ _tuplefilter=_tuplefilter,
+ )
+
+ def _contains(self, value, row):
+ return value in row._data
+
+ def _index_for_key(self, key, raiseerr=True):
+ if int in key.__class__.__mro__:
+ key = self._keys[key]
+ try:
+ rec = self._keymap[key]
+ except KeyError as ke:
+ rec = self._key_fallback(key, ke, raiseerr)
+
+ return rec[0]
+
+ def _indexes_for_keys(self, keys):
+ return [self._keymap[key][0] for key in keys]
+
+ def _metadata_for_keys(self, keys):
+ for key in keys:
+ if int in key.__class__.__mro__:
+ key = self._keys[key]
+
+ try:
+ rec = self._keymap[key]
+ except KeyError as ke:
+ rec = self._key_fallback(key, ke, True)
+
+ yield rec
+
+ def _reduce(self, keys):
+ try:
+ metadata_for_keys = [
+ self._keymap[
+ self._keys[key] if int in key.__class__.__mro__ else key
+ ]
+ for key in keys
+ ]
+ except KeyError as ke:
+ self._key_fallback(ke.args[0], ke, True)
+
+ indexes, new_keys, extra = zip(*metadata_for_keys)
+
+ if self._translated_indexes:
+ indexes = [self._translated_indexes[idx] for idx in indexes]
+
+ tup = tuplegetter(*indexes)
+
+ new_metadata = SimpleResultMetaData(
+ new_keys,
+ extra=extra,
+ _tuplefilter=tup,
+ _translated_indexes=indexes,
+ _processors=self._processors,
+ _unique_filters=self._unique_filters,
+ )
+
+ return new_metadata
+
+
+def result_tuple(fields, extra=None):
+ parent = SimpleResultMetaData(fields, extra)
+ return functools.partial(
+ Row, parent, parent._processors, parent._keymap, Row._default_key_style
+ )
+
+
+# a symbol that indicates to internal Result methods that
+# "no row is returned". We can't use None for those cases where a scalar
+# filter is applied to rows.
+_NO_ROW = util.symbol("NO_ROW")
+
+
+class ResultInternal(InPlaceGenerative):
+ _real_result = None
+ _generate_rows = True
+ _unique_filter_state = None
+ _post_creational_filter = None
+ _is_cursor = False
+
+ @HasMemoized.memoized_attribute
+ def _row_getter(self):
+ real_result = self._real_result if self._real_result else self
+
+ if real_result._source_supports_scalars:
+ if not self._generate_rows:
+ return None
+ else:
+ _proc = real_result._process_row
+
+ def process_row(
+ metadata, processors, keymap, key_style, scalar_obj
+ ):
+ return _proc(
+ metadata, processors, keymap, key_style, (scalar_obj,)
+ )
+
+ else:
+ process_row = real_result._process_row
+
+ key_style = real_result._process_row._default_key_style
+ metadata = self._metadata
+
+ keymap = metadata._keymap
+ processors = metadata._processors
+ tf = metadata._tuplefilter
+
+ if tf and not real_result._source_supports_scalars:
+ if processors:
+ processors = tf(processors)
+
+ _make_row_orig = functools.partial(
+ process_row, metadata, processors, keymap, key_style
+ )
+
+ def make_row(row):
+ return _make_row_orig(tf(row))
+
+ else:
+ make_row = functools.partial(
+ process_row, metadata, processors, keymap, key_style
+ )
+
+        if real_result._row_logging_fn:
+            fns = (real_result._row_logging_fn,)
+        else:
+            fns = ()
+
+ if fns:
+ _make_row = make_row
+
+ def make_row(row):
+ row = _make_row(row)
+ for fn in fns:
+ row = fn(row)
+ return row
+
+ return make_row
+
+ @HasMemoized.memoized_attribute
+ def _iterator_getter(self):
+
+ make_row = self._row_getter
+
+ post_creational_filter = self._post_creational_filter
+
+ if self._unique_filter_state:
+ uniques, strategy = self._unique_strategy
+
+ def iterrows(self):
+ for row in self._fetchiter_impl():
+ obj = make_row(row) if make_row else row
+ hashed = strategy(obj) if strategy else obj
+ if hashed in uniques:
+ continue
+ uniques.add(hashed)
+ if post_creational_filter:
+ obj = post_creational_filter(obj)
+ yield obj
+
+ else:
+
+ def iterrows(self):
+ for row in self._fetchiter_impl():
+ row = make_row(row) if make_row else row
+ if post_creational_filter:
+ row = post_creational_filter(row)
+ yield row
+
+ return iterrows
+
+ def _raw_all_rows(self):
+ make_row = self._row_getter
+ rows = self._fetchall_impl()
+ return [make_row(row) for row in rows]
+
+ def _allrows(self):
+
+ post_creational_filter = self._post_creational_filter
+
+ make_row = self._row_getter
+
+ rows = self._fetchall_impl()
+ if make_row:
+ made_rows = [make_row(row) for row in rows]
+ else:
+ made_rows = rows
+
+ if self._unique_filter_state:
+ uniques, strategy = self._unique_strategy
+
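+            # set.add() returns None, so "not uniques.add(sig_row)" records
+            # the signature as a side effect while evaluating truthy; only
+            # the first row seen for each unique signature is kept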
+ rows = [
+ made_row
+ for made_row, sig_row in [
+ (
+ made_row,
+ strategy(made_row) if strategy else made_row,
+ )
+ for made_row in made_rows
+ ]
+ if sig_row not in uniques and not uniques.add(sig_row)
+ ]
+ else:
+ rows = made_rows
+
+ if post_creational_filter:
+ rows = [post_creational_filter(row) for row in rows]
+ return rows
+
+ @HasMemoized.memoized_attribute
+ def _onerow_getter(self):
+ make_row = self._row_getter
+
+ post_creational_filter = self._post_creational_filter
+
+ if self._unique_filter_state:
+ uniques, strategy = self._unique_strategy
+
+ def onerow(self):
+ _onerow = self._fetchone_impl
+ while True:
+ row = _onerow()
+ if row is None:
+ return _NO_ROW
+ else:
+ obj = make_row(row) if make_row else row
+ hashed = strategy(obj) if strategy else obj
+ if hashed in uniques:
+ continue
+ else:
+ uniques.add(hashed)
+ if post_creational_filter:
+ obj = post_creational_filter(obj)
+ return obj
+
+ else:
+
+ def onerow(self):
+ row = self._fetchone_impl()
+ if row is None:
+ return _NO_ROW
+ else:
+ row = make_row(row) if make_row else row
+ if post_creational_filter:
+ row = post_creational_filter(row)
+ return row
+
+ return onerow
+
+ @HasMemoized.memoized_attribute
+ def _manyrow_getter(self):
+ make_row = self._row_getter
+
+ post_creational_filter = self._post_creational_filter
+
+ if self._unique_filter_state:
+ uniques, strategy = self._unique_strategy
+
+ def filterrows(make_row, rows, strategy, uniques):
+ if make_row:
+ rows = [make_row(row) for row in rows]
+
+ if strategy:
+ made_rows = (
+ (made_row, strategy(made_row)) for made_row in rows
+ )
+ else:
+ made_rows = ((made_row, made_row) for made_row in rows)
+ return [
+ made_row
+ for made_row, sig_row in made_rows
+ if sig_row not in uniques and not uniques.add(sig_row)
+ ]
+
+ def manyrows(self, num):
+ collect = []
+
+ _manyrows = self._fetchmany_impl
+
+ if num is None:
+                # if None is passed, we don't know the default
+                # manyrows number; the DBAPI has this as cursor.arraysize,
+                # and different DBAPIs / fetch strategies may differ.
+                # do a fetch to find out what the number is; if fewer
+                # rows are left, that doesn't matter.
+ real_result = (
+ self._real_result if self._real_result else self
+ )
+ if real_result._yield_per:
+ num_required = num = real_result._yield_per
+ else:
+ rows = _manyrows(num)
+ num = len(rows)
+ collect.extend(
+ filterrows(make_row, rows, strategy, uniques)
+ )
+ num_required = num - len(collect)
+ else:
+ num_required = num
+
+ while num_required:
+ rows = _manyrows(num_required)
+ if not rows:
+ break
+
+ collect.extend(
+ filterrows(make_row, rows, strategy, uniques)
+ )
+ num_required = num - len(collect)
+
+ if post_creational_filter:
+ collect = [post_creational_filter(row) for row in collect]
+ return collect
+
+ else:
+
+ def manyrows(self, num):
+ if num is None:
+ real_result = (
+ self._real_result if self._real_result else self
+ )
+ num = real_result._yield_per
+
+ rows = self._fetchmany_impl(num)
+ if make_row:
+ rows = [make_row(row) for row in rows]
+ if post_creational_filter:
+ rows = [post_creational_filter(row) for row in rows]
+ return rows
+
+ return manyrows
+
+ def _only_one_row(
+ self,
+ raise_for_second_row,
+ raise_for_none,
+ scalar,
+ ):
+ onerow = self._fetchone_impl
+
+ row = onerow(hard_close=True)
+ if row is None:
+ if raise_for_none:
+ raise exc.NoResultFound(
+ "No row was found when one was required"
+ )
+ else:
+ return None
+
+ if scalar and self._source_supports_scalars:
+ self._generate_rows = False
+ make_row = None
+ else:
+ make_row = self._row_getter
+
+ try:
+ row = make_row(row) if make_row else row
+ except:
+ self._soft_close(hard=True)
+ raise
+
+ if raise_for_second_row:
+ if self._unique_filter_state:
+ # for no second row but uniqueness, need to essentially
+ # consume the entire result :(
+ uniques, strategy = self._unique_strategy
+
+ existing_row_hash = strategy(row) if strategy else row
+
+ while True:
+ next_row = onerow(hard_close=True)
+ if next_row is None:
+ next_row = _NO_ROW
+ break
+
+ try:
+ next_row = make_row(next_row) if make_row else next_row
+
+ if strategy:
+ if existing_row_hash == strategy(next_row):
+ continue
+ elif row == next_row:
+ continue
+ # here, we have a row and it's different
+ break
+ except:
+ self._soft_close(hard=True)
+ raise
+ else:
+ next_row = onerow(hard_close=True)
+ if next_row is None:
+ next_row = _NO_ROW
+
+ if next_row is not _NO_ROW:
+ self._soft_close(hard=True)
+ raise exc.MultipleResultsFound(
+ "Multiple rows were found when exactly one was required"
+ if raise_for_none
+ else "Multiple rows were found when one or none "
+ "was required"
+ )
+ else:
+ next_row = _NO_ROW
+ # if we checked for second row then that would have
+ # closed us :)
+ self._soft_close(hard=True)
+
+ if not scalar:
+ post_creational_filter = self._post_creational_filter
+ if post_creational_filter:
+ row = post_creational_filter(row)
+
+ if scalar and make_row:
+ return row[0]
+ else:
+ return row
+
+ def _iter_impl(self):
+ return self._iterator_getter(self)
+
+ def _next_impl(self):
+ row = self._onerow_getter(self)
+ if row is _NO_ROW:
+ raise StopIteration()
+ else:
+ return row
+
+ @_generative
+ def _column_slices(self, indexes):
+ real_result = self._real_result if self._real_result else self
+
+ if real_result._source_supports_scalars and len(indexes) == 1:
+ util.warn_deprecated(
+ "The Result.columns() method has a bug in SQLAlchemy 1.4 that "
+ "is causing it to yield scalar values, rather than Row "
+ "objects, in the case where a single index is passed and the "
+ "result is against ORM mapped objects. In SQLAlchemy 2.0, "
+ "Result will continue yield Row objects in this scenario. "
+ "Use the Result.scalars() method to yield scalar values.",
+ "2.0",
+ )
+ self._generate_rows = False
+ else:
+ self._generate_rows = True
+ self._metadata = self._metadata._reduce(indexes)
+
+ @HasMemoized.memoized_attribute
+ def _unique_strategy(self):
+ uniques, strategy = self._unique_filter_state
+
+ real_result = (
+ self._real_result if self._real_result is not None else self
+ )
+
+ if not strategy and self._metadata._unique_filters:
+ if (
+ real_result._source_supports_scalars
+ and not self._generate_rows
+ ):
+ strategy = self._metadata._unique_filters[0]
+ else:
+ filters = self._metadata._unique_filters
+ if self._metadata._tuplefilter:
+ filters = self._metadata._tuplefilter(filters)
+
+ strategy = operator.methodcaller("_filter_on_values", filters)
+ return uniques, strategy
+
+
+class _WithKeys(object):
+ # used mainly to share documentation on the keys method.
+ # py2k does not allow overriding the __doc__ attribute.
+ def keys(self):
+ """Return an iterable view which yields the string keys that would
+ be represented by each :class:`.Row`.
+
+ The keys can represent the labels of the columns returned by a core
+ statement or the names of the orm classes returned by an orm
+ execution.
+
+ The view also can be tested for key containment using the Python
+ ``in`` operator, which will test both for the string keys represented
+ in the view, as well as for alternate keys such as column objects.
+
+ .. versionchanged:: 1.4 a key view object is returned rather than a
+ plain list.
+
+
+ """
+ return self._metadata.keys
+
+
+class Result(_WithKeys, ResultInternal):
+ """Represent a set of database results.
+
+ .. versionadded:: 1.4 The :class:`.Result` object provides a completely
+ updated usage model and calling facade for SQLAlchemy Core and
+ SQLAlchemy ORM. In Core, it forms the basis of the
+ :class:`.CursorResult` object which replaces the previous
+ :class:`.ResultProxy` interface. When using the ORM, a higher level
+ object called :class:`.ChunkedIteratorResult` is normally used.
+
+ .. note:: In SQLAlchemy 1.4 and above, this object is
+ used for ORM results returned by :meth:`_orm.Session.execute`, which can
+ yield instances of ORM mapped objects either individually or within
+ tuple-like rows. Note that the :class:`_result.Result` object does not
+ deduplicate instances or rows automatically as is the case with the
+ legacy :class:`_orm.Query` object. For in-Python de-duplication of
+ instances or rows, use the :meth:`_result.Result.unique` modifier
+ method.
+
+ .. seealso::
+
+ :ref:`tutorial_fetching_rows` - in the :doc:`/tutorial/index`
+
+ """
+
+ _process_row = Row
+
+ _row_logging_fn = None
+
+ _source_supports_scalars = False
+
+ _yield_per = None
+
+ _attributes = util.immutabledict()
+
+ def __init__(self, cursor_metadata):
+ self._metadata = cursor_metadata
+
+ def _soft_close(self, hard=False):
+ raise NotImplementedError()
+
+ def close(self):
+ """close this :class:`_result.Result`.
+
+ The behavior of this method is implementation specific, and is
+        not implemented by default. The method should generally release
+        the resources in use by the result object and also cause any
+ subsequent iteration or row fetching to raise
+ :class:`.ResourceClosedError`.
+
+ .. versionadded:: 1.4.27 - ``.close()`` was previously not generally
+ available for all :class:`_result.Result` classes, instead only
+ being available on the :class:`_engine.CursorResult` returned for
+ Core statement executions. As most other result objects, namely the
+ ones used by the ORM, are proxying a :class:`_engine.CursorResult`
+ in any case, this allows the underlying cursor result to be closed
+ from the outside facade for the case when the ORM query is using
+ the ``yield_per`` execution option where it does not immediately
+ exhaust and autoclose the database cursor.
+
+ """
+ self._soft_close(hard=True)
+
+ @_generative
+ def yield_per(self, num):
+ """Configure the row-fetching strategy to fetch ``num`` rows at a time.
+
+ This impacts the underlying behavior of the result when iterating over
+ the result object, or otherwise making use of methods such as
+ :meth:`_engine.Result.fetchone` that return one row at a time. Data
+ from the underlying cursor or other data source will be buffered up to
+ this many rows in memory, and the buffered collection will then be
+        yielded out one row at a time or as many rows as are requested. Each
+        time the buffer is cleared, it will be refreshed to this many rows,
+        or as many rows as remain if fewer remain.
+
+ The :meth:`_engine.Result.yield_per` method is generally used in
+ conjunction with the
+ :paramref:`_engine.Connection.execution_options.stream_results`
+ execution option, which will allow the database dialect in use to make
+ use of a server side cursor, if the DBAPI supports a specific "server
+ side cursor" mode separate from its default mode of operation.
+
+ .. tip::
+
+ Consider using the
+ :paramref:`_engine.Connection.execution_options.yield_per`
+ execution option, which will simultaneously set
+ :paramref:`_engine.Connection.execution_options.stream_results`
+ to ensure the use of server side cursors, as well as automatically
+ invoke the :meth:`_engine.Result.yield_per` method to establish
+ a fixed row buffer size at once.
+
+ The :paramref:`_engine.Connection.execution_options.yield_per`
+ execution option is available for ORM operations, with
+ :class:`_orm.Session`-oriented use described at
+ :ref:`orm_queryguide_yield_per`. The Core-only version which works
+ with :class:`_engine.Connection` is new as of SQLAlchemy 1.4.40.
+
+ .. versionadded:: 1.4
+
+ :param num: number of rows to fetch each time the buffer is refilled.
+ If set to a value below 1, fetches all rows for the next buffer.
+
+ .. seealso::
+
+ :ref:`engine_stream_results` - describes Core behavior for
+ :meth:`_engine.Result.yield_per`
+
+ :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel`
+
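+        A usage sketch, assuming an :class:`_engine.Engine` named ``engine``,
+        a ``stmt`` statement object against a large table, and an
+        illustrative ``process_row`` function::
+
+            with engine.connect() as conn:
+                result = conn.execution_options(
+                    stream_results=True
+                ).execute(stmt)
+
+                for row in result.yield_per(500):
+                    process_row(row)
+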
+ """
+ self._yield_per = num
+
+ @_generative
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_engine.Result`.
+
+ When this filter is applied with no arguments, the rows or objects
+        returned will be filtered such that each row is returned uniquely.
+        The algorithm used to determine this uniqueness is by default the
+        Python hashing identity of the whole tuple. In some cases a
+        specialized per-entity hashing scheme may be used; when using the
+        ORM, for example, a scheme is applied which works against the
+        primary key identity of returned objects.
+
+ The unique filter is applied **after all other filters**, which means
+ if the columns returned have been refined using a method such as the
+ :meth:`_engine.Result.columns` or :meth:`_engine.Result.scalars`
+ method, the uniquing is applied to **only the column or columns
+ returned**. This occurs regardless of the order in which these
+ methods have been called upon the :class:`_engine.Result` object.
+
+ The unique filter also changes the calculus used for methods like
+ :meth:`_engine.Result.fetchmany` and :meth:`_engine.Result.partitions`.
+ When using :meth:`_engine.Result.unique`, these methods will continue
+ to yield the number of rows or objects requested, after uniquing
+ has been applied. However, this necessarily impacts the buffering
+ behavior of the underlying cursor or datasource, such that multiple
+ underlying calls to ``cursor.fetchmany()`` may be necessary in order
+ to accumulate enough objects in order to provide a unique collection
+ of the requested size.
+
+ :param strategy: a callable that will be applied to rows or objects
+ being iterated, which should return an object that represents the
+ unique value of the row. A Python ``set()`` is used to store
+ these identities. If not passed, a default uniqueness strategy
+ is used which may have been assembled by the source of this
+ :class:`_engine.Result` object.
+
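+        A usage sketch, assuming an ORM :class:`_orm.Session` named
+        ``session`` and a mapped class ``User`` with an ``addresses``
+        relationship (all names are illustrative)::
+
+            from sqlalchemy import select
+
+            stmt = select(User).join(User.addresses)
+
+            # without unique(), a User would be repeated once per
+            # joined address row
+            users = session.execute(stmt).unique().scalars().all()
+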
+ """
+ self._unique_filter_state = (set(), strategy)
+
+ def columns(self, *col_expressions):
+ r"""Establish the columns that should be returned in each row.
+
+ This method may be used to limit the columns returned as well
+ as to reorder them. The given list of expressions are normally
+ a series of integers or string key names. They may also be
+ appropriate :class:`.ColumnElement` objects which correspond to
+ a given statement construct.
+
+ E.g.::
+
+ statement = select(table.c.x, table.c.y, table.c.z)
+ result = connection.execute(statement)
+
+ for z, y in result.columns('z', 'y'):
+ # ...
+
+
+ Example of using the column objects from the statement itself::
+
+ for z, y in result.columns(
+ statement.selected_columns.c.z,
+ statement.selected_columns.c.y
+ ):
+ # ...
+
+ .. versionadded:: 1.4
+
+ :param \*col_expressions: indicates columns to be returned. Elements
+ may be integer row indexes, string column names, or appropriate
+ :class:`.ColumnElement` objects corresponding to a select construct.
+
+ :return: this :class:`_engine.Result` object with the modifications
+ given.
+
+ """
+ return self._column_slices(col_expressions)
+
+ def scalars(self, index=0):
+ """Return a :class:`_result.ScalarResult` filtering object which
+ will return single elements rather than :class:`_row.Row` objects.
+
+ E.g.::
+
+ >>> result = conn.execute(text("select int_id from table"))
+ >>> result.scalars().all()
+ [1, 2, 3]
+
+ When results are fetched from the :class:`_result.ScalarResult`
+ filtering object, the single column-row that would be returned by the
+ :class:`_result.Result` is instead returned as the column's value.
+
+ .. versionadded:: 1.4
+
+ :param index: integer or row key indicating the column to be fetched
+ from each row, defaults to ``0`` indicating the first column.
+
+ :return: a new :class:`_result.ScalarResult` filtering object referring
+ to this :class:`_result.Result` object.
+
+ """
+ return ScalarResult(self, index)
+
+ def _getter(self, key, raiseerr=True):
+ """return a callable that will retrieve the given key from a
+ :class:`.Row`.
+
+ """
+ if self._source_supports_scalars:
+ raise NotImplementedError(
+ "can't use this function in 'only scalars' mode"
+ )
+ return self._metadata._getter(key, raiseerr)
+
+ def _tuple_getter(self, keys):
+ """return a callable that will retrieve the given keys from a
+ :class:`.Row`.
+
+ """
+ if self._source_supports_scalars:
+ raise NotImplementedError(
+ "can't use this function in 'only scalars' mode"
+ )
+ return self._metadata._row_as_tuple_getter(keys)
+
+ def mappings(self):
+ """Apply a mappings filter to returned rows, returning an instance of
+ :class:`_result.MappingResult`.
+
+ When this filter is applied, fetching rows will return
+ :class:`.RowMapping` objects instead of :class:`.Row` objects.
+
+ .. versionadded:: 1.4
+
+ :return: a new :class:`_result.MappingResult` filtering object
+ referring to this :class:`_result.Result` object.
+
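+        A usage sketch, assuming a :class:`_engine.Connection` named ``conn``
+        and a ``user_account`` table with ``id`` and ``name`` columns
+        (names are illustrative)::
+
+            from sqlalchemy import text
+
+            result = conn.execute(text("SELECT id, name FROM user_account"))
+            for row_mapping in result.mappings():
+                print(row_mapping["id"], row_mapping["name"])
+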
+ """
+
+ return MappingResult(self)
+
+ def _raw_row_iterator(self):
+ """Return a safe iterator that yields raw row data.
+
+ This is used by the :meth:`._engine.Result.merge` method
+ to merge multiple compatible results together.
+
+ """
+ raise NotImplementedError()
+
+ def _fetchiter_impl(self):
+ raise NotImplementedError()
+
+ def _fetchone_impl(self, hard_close=False):
+ raise NotImplementedError()
+
+ def _fetchall_impl(self):
+ raise NotImplementedError()
+
+ def _fetchmany_impl(self, size=None):
+ raise NotImplementedError()
+
+ def __iter__(self):
+ return self._iter_impl()
+
+ def __next__(self):
+ return self._next_impl()
+
+ if py2k:
+
+ def next(self): # noqa
+ return self._next_impl()
+
+ def partitions(self, size=None):
+ """Iterate through sub-lists of rows of the size given.
+
+ Each list will be of the size given, excluding the last list to
+        be yielded, which may have a smaller number of rows. No empty
+ lists will be yielded.
+
+ The result object is automatically closed when the iterator
+ is fully consumed.
+
+ Note that the backend driver will usually buffer the entire result
+ ahead of time unless the
+ :paramref:`.Connection.execution_options.stream_results` execution
+        option is used, indicating that the driver should not pre-buffer
+        results, if possible. Not all drivers support this option, and
+        the option is silently ignored for those that do not.
+
+ When using the ORM, the :meth:`_engine.Result.partitions` method
+ is typically more effective from a memory perspective when it is
+ combined with use of the
+ :ref:`yield_per execution option <orm_queryguide_yield_per>`,
+        which instructs the DBAPI driver to use server side cursors,
+        if available, and also instructs the ORM loading internals to
+        build only a certain number of ORM objects from a result at a
+        time before yielding them out.
+
+ .. versionadded:: 1.4
+
+        :param size: indicate the maximum number of rows to be present
+          in each list yielded. If None, makes use of the value set by
+          the :meth:`_engine.Result.yield_per` method, if it was called,
+          or the :paramref:`_engine.Connection.execution_options.yield_per`
+          execution option, which is equivalent in this regard. If
+          yield_per was not set, it makes use of the
+          :meth:`_engine.Result.fetchmany` default, which may be backend
+          specific and not well defined.
+
+ :return: iterator of lists
+
+ .. seealso::
+
+ :ref:`engine_stream_results`
+
+ :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel`
+
+
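+        A usage sketch, assuming an :class:`_engine.Engine` named ``engine``,
+        a ``stmt`` statement object and an illustrative ``process_rows``
+        function::
+
+            with engine.connect() as conn:
+                result = conn.execution_options(
+                    stream_results=True
+                ).execute(stmt)
+
+                for partition in result.partitions(100):
+                    process_rows(partition)
+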
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = getter(self, size)
+ if partition:
+ yield partition
+ else:
+ break
+
+ def fetchall(self):
+ """A synonym for the :meth:`_engine.Result.all` method."""
+
+ return self._allrows()
+
+ def fetchone(self):
+ """Fetch one row.
+
+ When all rows are exhausted, returns None.
+
+ This method is provided for backwards compatibility with
+ SQLAlchemy 1.x.x.
+
+ To fetch the first row of a result only, use the
+ :meth:`_engine.Result.first` method. To iterate through all
+ rows, iterate the :class:`_engine.Result` object directly.
+
+ :return: a :class:`.Row` object if no filters are applied, or None
+ if no rows remain.
+
+ """
+ row = self._onerow_getter(self)
+ if row is _NO_ROW:
+ return None
+ else:
+ return row
+
+ def fetchmany(self, size=None):
+ """Fetch many rows.
+
+ When all rows are exhausted, returns an empty list.
+
+ This method is provided for backwards compatibility with
+ SQLAlchemy 1.x.x.
+
+ To fetch rows in groups, use the :meth:`._result.Result.partitions`
+ method.
+
+ :return: a list of :class:`.Row` objects.
+
+ """
+
+ return self._manyrow_getter(self, size)
+
+ def all(self):
+ """Return all rows in a list.
+
+ Closes the result set after invocation. Subsequent invocations
+ will return an empty list.
+
+ .. versionadded:: 1.4
+
+ :return: a list of :class:`.Row` objects.
+
+ """
+
+ return self._allrows()
+
+ def first(self):
+ """Fetch the first row or None if no row is present.
+
+ Closes the result set and discards remaining rows.
+
+ .. note:: This method returns one **row**, e.g. tuple, by default.
+ To return exactly one single scalar value, that is, the first
+ column of the first row, use the :meth:`.Result.scalar` method,
+ or combine :meth:`.Result.scalars` and :meth:`.Result.first`.
+
+ Additionally, in contrast to the behavior of the legacy ORM
+ :meth:`_orm.Query.first` method, **no limit is applied** to the
+ SQL query which was invoked to produce this :class:`_engine.Result`;
+ for a DBAPI driver that buffers results in memory before yielding
+ rows, all rows will be sent to the Python process and all but
+ the first row will be discarded.
+
+ .. seealso::
+
+ :ref:`migration_20_unify_select`
+
+ :return: a :class:`.Row` object, or None
+ if no rows remain.
+
+ .. seealso::
+
+ :meth:`_result.Result.scalar`
+
+ :meth:`_result.Result.one`
+
+ """
+
+ return self._only_one_row(
+ raise_for_second_row=False, raise_for_none=False, scalar=False
+ )
+
+ def one_or_none(self):
+ """Return at most one result or raise an exception.
+
+ Returns ``None`` if the result has no rows.
+ Raises :class:`.MultipleResultsFound`
+ if multiple rows are returned.
+
+ .. versionadded:: 1.4
+
+ :return: The first :class:`.Row` or None if no row is available.
+
+ :raises: :class:`.MultipleResultsFound`
+
+ .. seealso::
+
+ :meth:`_result.Result.first`
+
+ :meth:`_result.Result.one`
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=False, scalar=False
+ )
+
+ def scalar_one(self):
+ """Return exactly one scalar result or raise an exception.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one`.
+
+ .. seealso::
+
+ :meth:`.Result.one`
+
+ :meth:`.Result.scalars`
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=True, scalar=True
+ )
+
+ def scalar_one_or_none(self):
+ """Return exactly one or no scalar result.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one_or_none`.
+
+ .. seealso::
+
+ :meth:`.Result.one_or_none`
+
+ :meth:`.Result.scalars`
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=False, scalar=True
+ )
+
+ def one(self):
+ """Return exactly one row or raise an exception.
+
+ Raises :class:`.NoResultFound` if the result returns no
+ rows, or :class:`.MultipleResultsFound` if multiple rows
+ would be returned.
+
+ .. note:: This method returns one **row**, e.g. tuple, by default.
+ To return exactly one single scalar value, that is, the first
+ column of the first row, use the :meth:`.Result.scalar_one` method,
+ or combine :meth:`.Result.scalars` and :meth:`.Result.one`.
+
+ .. versionadded:: 1.4
+
+ :return: The first :class:`.Row`.
+
+ :raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound`
+
+ .. seealso::
+
+ :meth:`_result.Result.first`
+
+ :meth:`_result.Result.one_or_none`
+
+ :meth:`_result.Result.scalar_one`
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=True, scalar=False
+ )
+
+ def scalar(self):
+ """Fetch the first column of the first row, and close the result set.
+
+ Returns None if there are no rows to fetch.
+
+ No validation is performed to test if additional rows remain.
+
+ After calling this method, the object is fully closed,
+ e.g. the :meth:`_engine.CursorResult.close`
+ method will have been called.
+
+        :return: a Python scalar value, or None if no rows remain.
+
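+        A usage sketch, assuming a :class:`_engine.Connection` named ``conn``
+        (the table name is illustrative)::
+
+            from sqlalchemy import text
+
+            user_count = conn.execute(
+                text("SELECT count(*) FROM user_account")
+            ).scalar()
+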
+ """
+ return self._only_one_row(
+ raise_for_second_row=False, raise_for_none=False, scalar=True
+ )
+
+ def freeze(self):
+ """Return a callable object that will produce copies of this
+ :class:`.Result` when invoked.
+
+ The callable object returned is an instance of
+ :class:`_engine.FrozenResult`.
+
+        This is used for result set caching. The method must be called
+        on the result while it is still unconsumed, and calling the method
+        will consume the result fully. When the :class:`_engine.FrozenResult`
+ is retrieved from a cache, it can be called any number of times where
+ it will produce a new :class:`_engine.Result` object each time
+ against its stored set of rows.
+
+ .. seealso::
+
+ :ref:`do_orm_execute_re_executing` - example usage within the
+ ORM to implement a result-set cache.
+
+ """
+
+ return FrozenResult(self)
+
+ def merge(self, *others):
+ """Merge this :class:`.Result` with other compatible result
+ objects.
+
+ The object returned is an instance of :class:`_engine.MergedResult`,
+ which will be composed of iterators from the given result
+ objects.
+
+ The new result will use the metadata from this result object.
+ The subsequent result objects must be against an identical
+ set of result / cursor metadata, otherwise the behavior is
+ undefined.
+
+ """
+ return MergedResult(self._metadata, (self,) + others)
+
+
+class FilterResult(ResultInternal):
+ """A wrapper for a :class:`_engine.Result` that returns objects other than
+ :class:`_result.Row` objects, such as dictionaries or scalar objects.
+
+ :class:`.FilterResult` is the common base for additional result
+ APIs including :class:`.MappingResult`, :class:`.ScalarResult`
+ and :class:`.AsyncResult`.
+
+ """
+
+ _post_creational_filter = None
+
+ @_generative
+ def yield_per(self, num):
+ """Configure the row-fetching strategy to fetch ``num`` rows at a time.
+
+ The :meth:`_engine.FilterResult.yield_per` method is a pass through
+ to the :meth:`_engine.Result.yield_per` method. See that method's
+ documentation for usage notes.
+
+ .. versionadded:: 1.4.40 - added :meth:`_engine.FilterResult.yield_per`
+ so that the method is available on all result set implementations
+
+ .. seealso::
+
+ :ref:`engine_stream_results` - describes Core behavior for
+ :meth:`_engine.Result.yield_per`
+
+ :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel`
+
+ """
+ self._real_result = self._real_result.yield_per(num)
+
+ def _soft_close(self, hard=False):
+ self._real_result._soft_close(hard=hard)
+
+ @property
+ def _attributes(self):
+ return self._real_result._attributes
+
+ def _fetchiter_impl(self):
+ return self._real_result._fetchiter_impl()
+
+ def _fetchone_impl(self, hard_close=False):
+ return self._real_result._fetchone_impl(hard_close=hard_close)
+
+ def _fetchall_impl(self):
+ return self._real_result._fetchall_impl()
+
+ def _fetchmany_impl(self, size=None):
+ return self._real_result._fetchmany_impl(size=size)
+
+
+class ScalarResult(FilterResult):
+ """A wrapper for a :class:`_result.Result` that returns scalar values
+ rather than :class:`_row.Row` values.
+
+ The :class:`_result.ScalarResult` object is acquired by calling the
+ :meth:`_result.Result.scalars` method.
+
+ A special limitation of :class:`_result.ScalarResult` is that it has
+    no ``fetchone()`` method; the semantics of ``fetchone()`` are that the
+    ``None`` value indicates no more results, which is not compatible with
+    :class:`_result.ScalarResult` since there would be no way to distinguish
+    between ``None`` as a row value and ``None`` as an end-of-results
+    indicator. Use ``next(result)`` to receive values individually.
+
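+    A usage sketch, assuming a :class:`_engine.Connection` named ``conn``
+    and a ``user_account`` table (names are illustrative)::
+
+        from sqlalchemy import text
+
+        result = conn.execute(text("SELECT id FROM user_account"))
+        ids = result.scalars().all()
+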
+ """
+
+ _generate_rows = False
+
+ def __init__(self, real_result, index):
+ self._real_result = real_result
+
+ if real_result._source_supports_scalars:
+ self._metadata = real_result._metadata
+ self._post_creational_filter = None
+ else:
+ self._metadata = real_result._metadata._reduce([index])
+ self._post_creational_filter = operator.itemgetter(0)
+
+ self._unique_filter_state = real_result._unique_filter_state
+
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_engine.ScalarResult`.
+
+ See :meth:`_engine.Result.unique` for usage details.
+
+ """
+ self._unique_filter_state = (set(), strategy)
+ return self
+
+ def partitions(self, size=None):
+ """Iterate through sub-lists of elements of the size given.
+
+ Equivalent to :meth:`_result.Result.partitions` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = getter(self, size)
+ if partition:
+ yield partition
+ else:
+ break
+
+ def fetchall(self):
+ """A synonym for the :meth:`_engine.ScalarResult.all` method."""
+
+ return self._allrows()
+
+ def fetchmany(self, size=None):
+ """Fetch many objects.
+
+ Equivalent to :meth:`_result.Result.fetchmany` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return self._manyrow_getter(self, size)
+
+ def all(self):
+ """Return all scalar values in a list.
+
+ Equivalent to :meth:`_result.Result.all` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return self._allrows()
+
+ def __iter__(self):
+ return self._iter_impl()
+
+ def __next__(self):
+ return self._next_impl()
+
+ if py2k:
+
+ def next(self): # noqa
+ return self._next_impl()
+
+ def first(self):
+ """Fetch the first object or None if no object is present.
+
+ Equivalent to :meth:`_result.Result.first` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=False, raise_for_none=False, scalar=False
+ )
+
+ def one_or_none(self):
+ """Return at most one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one_or_none` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=False, scalar=False
+ )
+
+ def one(self):
+ """Return exactly one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=True, scalar=False
+ )
+
+
+class MappingResult(_WithKeys, FilterResult):
+ """A wrapper for a :class:`_engine.Result` that returns dictionary values
+ rather than :class:`_engine.Row` values.
+
+ The :class:`_engine.MappingResult` object is acquired by calling the
+ :meth:`_engine.Result.mappings` method.
+
+ """
+
+ _generate_rows = True
+
+ _post_creational_filter = operator.attrgetter("_mapping")
+
+ def __init__(self, result):
+ self._real_result = result
+ self._unique_filter_state = result._unique_filter_state
+ self._metadata = result._metadata
+ if result._source_supports_scalars:
+ self._metadata = self._metadata._reduce([0])
+
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_engine.MappingResult`.
+
+ See :meth:`_engine.Result.unique` for usage details.
+
+ """
+ self._unique_filter_state = (set(), strategy)
+ return self
+
+ def columns(self, *col_expressions):
+ r"""Establish the columns that should be returned in each row."""
+ return self._column_slices(col_expressions)
+
+ def partitions(self, size=None):
+ """Iterate through sub-lists of elements of the size given.
+
+ Equivalent to :meth:`_result.Result.partitions` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = getter(self, size)
+ if partition:
+ yield partition
+ else:
+ break
+
+ def fetchall(self):
+ """A synonym for the :meth:`_engine.MappingResult.all` method."""
+
+ return self._allrows()
+
+ def fetchone(self):
+ """Fetch one object.
+
+ Equivalent to :meth:`_result.Result.fetchone` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ row = self._onerow_getter(self)
+ if row is _NO_ROW:
+ return None
+ else:
+ return row
+
+ def fetchmany(self, size=None):
+ """Fetch many objects.
+
+ Equivalent to :meth:`_result.Result.fetchmany` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ return self._manyrow_getter(self, size)
+
+ def all(self):
+ """Return all scalar values in a list.
+
+ Equivalent to :meth:`_result.Result.all` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ return self._allrows()
+
+ def __iter__(self):
+ return self._iter_impl()
+
+ def __next__(self):
+ return self._next_impl()
+
+ if py2k:
+
+ def next(self): # noqa
+ return self._next_impl()
+
+ def first(self):
+ """Fetch the first object or None if no object is present.
+
+ Equivalent to :meth:`_result.Result.first` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=False, raise_for_none=False, scalar=False
+ )
+
+ def one_or_none(self):
+ """Return at most one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one_or_none` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=False, scalar=False
+ )
+
+ def one(self):
+ """Return exactly one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=True, scalar=False
+ )
+
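+# Editorial sketch, not part of the original changeset: MappingResult as
+# obtained from Result.mappings().  The column names used below
+# ("id", "name") are placeholders.
+def _example_mapping_result(connection, stmt):
+    """Illustrative only: iterate dictionary-like rows."""
+    for mapping in connection.execute(stmt).mappings():
+        print(mapping["id"], mapping.get("name"))
+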
+
+class FrozenResult(object):
+ """Represents a :class:`.Result` object in a "frozen" state suitable
+ for caching.
+
+ The :class:`_engine.FrozenResult` object is returned from the
+ :meth:`_engine.Result.freeze` method of any :class:`_engine.Result`
+ object.
+
+ A new iterable :class:`.Result` object is generated from a fixed
+ set of data each time the :class:`.FrozenResult` is invoked as
+ a callable::
+
+
+ result = connection.execute(query)
+
+ frozen = result.freeze()
+
+ unfrozen_result_one = frozen()
+
+ for row in unfrozen_result_one:
+ print(row)
+
+ unfrozen_result_two = frozen()
+ rows = unfrozen_result_two.all()
+
+ # ... etc
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`do_orm_execute_re_executing` - example usage within the
+ ORM to implement a result-set cache.
+
+ :func:`_orm.loading.merge_frozen_result` - ORM function to merge
+ a frozen result back into a :class:`_orm.Session`.
+
+ """
+
+ def __init__(self, result):
+ self.metadata = result._metadata._for_freeze()
+ self._source_supports_scalars = result._source_supports_scalars
+ self._attributes = result._attributes
+
+ if self._source_supports_scalars:
+ self.data = list(result._raw_row_iterator())
+ else:
+ self.data = result.fetchall()
+
+ def rewrite_rows(self):
+ if self._source_supports_scalars:
+ return [[elem] for elem in self.data]
+ else:
+ return [list(row) for row in self.data]
+
+ def with_new_rows(self, tuple_data):
+ fr = FrozenResult.__new__(FrozenResult)
+ fr.metadata = self.metadata
+ fr._attributes = self._attributes
+ fr._source_supports_scalars = self._source_supports_scalars
+
+ if self._source_supports_scalars:
+ fr.data = [d[0] for d in tuple_data]
+ else:
+ fr.data = tuple_data
+ return fr
+
+ def __call__(self):
+ result = IteratorResult(self.metadata, iter(self.data))
+ result._attributes = self._attributes
+ result._source_supports_scalars = self._source_supports_scalars
+ return result
+
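+# Editorial sketch, not part of the original changeset: a FrozenResult may
+# be replayed any number of times, and with_new_rows() builds a new frozen
+# payload that reuses the same metadata (as the ORM result cache does).
+def _example_frozen_result(result):
+    """Illustrative only: freeze a Result and replay it twice."""
+    frozen = result.freeze()
+    first_pass = frozen().all()
+    second_pass = frozen().all()
+    return first_pass == second_pass
+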
+
+class IteratorResult(Result):
+ """A :class:`.Result` that gets data from a Python iterator of
+ :class:`.Row` objects.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _hard_closed = False
+
+ def __init__(
+ self,
+ cursor_metadata,
+ iterator,
+ raw=None,
+ _source_supports_scalars=False,
+ ):
+ self._metadata = cursor_metadata
+ self.iterator = iterator
+ self.raw = raw
+ self._source_supports_scalars = _source_supports_scalars
+
+ def _soft_close(self, hard=False, **kw):
+ if hard:
+ self._hard_closed = True
+ if self.raw is not None:
+ self.raw._soft_close(hard=hard, **kw)
+ self.iterator = iter([])
+ self._reset_memoizations()
+
+ def _raise_hard_closed(self):
+ raise exc.ResourceClosedError("This result object is closed.")
+
+ def _raw_row_iterator(self):
+ return self.iterator
+
+ def _fetchiter_impl(self):
+ if self._hard_closed:
+ self._raise_hard_closed()
+ return self.iterator
+
+ def _fetchone_impl(self, hard_close=False):
+ if self._hard_closed:
+ self._raise_hard_closed()
+
+ row = next(self.iterator, _NO_ROW)
+ if row is _NO_ROW:
+ self._soft_close(hard=hard_close)
+ return None
+ else:
+ return row
+
+ def _fetchall_impl(self):
+ if self._hard_closed:
+ self._raise_hard_closed()
+
+ try:
+ return list(self.iterator)
+ finally:
+ self._soft_close()
+
+ def _fetchmany_impl(self, size=None):
+ if self._hard_closed:
+ self._raise_hard_closed()
+
+ return list(itertools.islice(self.iterator, 0, size))
+
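+# Editorial sketch, not part of the original changeset: an IteratorResult
+# can be constructed directly from in-memory rows, which is how
+# null_result() below and FrozenResult.__call__() above build results.
+def _example_iterator_result():
+    """Illustrative only: wrap plain tuples in the Result interface."""
+    metadata = SimpleResultMetaData(["a", "b"])
+    result = IteratorResult(metadata, iter([(1, 2), (3, 4)]))
+    return result.mappings().all()
+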
+
+def null_result():
+ return IteratorResult(SimpleResultMetaData([]), iter([]))
+
+
+class ChunkedIteratorResult(IteratorResult):
+ """An :class:`.IteratorResult` that works from an iterator-producing
+ callable.
+
+ The given ``chunks`` argument is a function that is given a number of rows
+ to return in each chunk, or ``None`` for all rows. The function should
+ then return an un-consumed iterator of lists, each list of the requested
+ size.
+
+    The function can be called again at any time, in which case it should
+ continue from the same result set but adjust the chunk size as given.
+
+ .. versionadded:: 1.4
+
+ """
+
+ def __init__(
+ self,
+ cursor_metadata,
+ chunks,
+ source_supports_scalars=False,
+ raw=None,
+ dynamic_yield_per=False,
+ ):
+ self._metadata = cursor_metadata
+ self.chunks = chunks
+ self._source_supports_scalars = source_supports_scalars
+ self.raw = raw
+ self.iterator = itertools.chain.from_iterable(self.chunks(None))
+ self.dynamic_yield_per = dynamic_yield_per
+
+ @_generative
+ def yield_per(self, num):
+ # TODO: this throws away the iterator which may be holding
+ # onto a chunk. the yield_per cannot be changed once any
+ # rows have been fetched. either find a way to enforce this,
+ # or we can't use itertools.chain and will instead have to
+ # keep track.
+
+ self._yield_per = num
+ self.iterator = itertools.chain.from_iterable(self.chunks(num))
+
+ def _soft_close(self, **kw):
+ super(ChunkedIteratorResult, self)._soft_close(**kw)
+ self.chunks = lambda size: []
+
+ def _fetchmany_impl(self, size=None):
+ if self.dynamic_yield_per:
+ self.iterator = itertools.chain.from_iterable(self.chunks(size))
+ return super(ChunkedIteratorResult, self)._fetchmany_impl(size=size)
+
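+# Editorial sketch, not part of the original changeset: a ``chunks``
+# callable over pre-fetched rows.  It resumes from its current position
+# when re-invoked with a different size, matching the contract above.
+def _example_chunks_factory(rows):
+    """Illustrative only: build a resumable ``chunks`` callable."""
+    state = {"pos": 0}
+
+    def chunks(size):
+        while state["pos"] < len(rows):
+            upper = len(rows) if size is None else state["pos"] + size
+            chunk = rows[state["pos"] : upper]
+            state["pos"] = min(upper, len(rows))
+            yield chunk
+
+    return chunks
+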
+
+class MergedResult(IteratorResult):
+ """A :class:`_engine.Result` that is merged from any number of
+ :class:`_engine.Result` objects.
+
+ Returned by the :meth:`_engine.Result.merge` method.
+
+ .. versionadded:: 1.4
+
+ """
+
+ closed = False
+
+ def __init__(self, cursor_metadata, results):
+ self._results = results
+ super(MergedResult, self).__init__(
+ cursor_metadata,
+ itertools.chain.from_iterable(
+ r._raw_row_iterator() for r in results
+ ),
+ )
+
+ self._unique_filter_state = results[0]._unique_filter_state
+ self._yield_per = results[0]._yield_per
+
+ # going to try something w/ this in next rev
+ self._source_supports_scalars = results[0]._source_supports_scalars
+
+ self._attributes = self._attributes.merge_with(
+ *[r._attributes for r in results]
+ )
+
+ def _soft_close(self, hard=False, **kw):
+ for r in self._results:
+ r._soft_close(hard=hard, **kw)
+ if hard:
+ self.closed = True
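+
+
+# Editorial sketch, not part of the original changeset: Result.merge()
+# returns a MergedResult that chains the remaining rows of several
+# compatible Result objects, using the first result's metadata.
+def _example_merged_result(r1, r2, r3):
+    """Illustrative only: merge three Result objects and fetch everything."""
+    return r1.merge(r2, r3).all()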
diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py
new file mode 100644
index 0000000..e80e8c6
--- /dev/null
+++ b/lib/sqlalchemy/engine/row.py
@@ -0,0 +1,621 @@
+# engine/row.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Define row constructs including :class:`.Row`."""
+
+
+import operator
+
+from .. import util
+from ..sql import util as sql_util
+from ..util.compat import collections_abc
+
+MD_INDEX = 0 # integer index in cursor.description
+
+# This reconstructor is necessary so that pickles with the C extension or
+# without use the same Binary format.
+try:
+ # We need a different reconstructor on the C extension so that we can
+ # add extra checks that fields have correctly been initialized by
+ # __setstate__.
+ from sqlalchemy.cresultproxy import safe_rowproxy_reconstructor
+
+ # The extra function embedding is needed so that the
+ # reconstructor function has the same signature whether or not
+ # the extension is present.
+ def rowproxy_reconstructor(cls, state):
+ return safe_rowproxy_reconstructor(cls, state)
+
+
+except ImportError:
+
+ def rowproxy_reconstructor(cls, state):
+ obj = cls.__new__(cls)
+ obj.__setstate__(state)
+ return obj
+
+
+KEY_INTEGER_ONLY = 0
+"""__getitem__ only allows integer values, raises TypeError otherwise"""
+
+KEY_OBJECTS_ONLY = 1
+"""__getitem__ only allows string/object values, raises TypeError otherwise"""
+
+KEY_OBJECTS_BUT_WARN = 2
+"""__getitem__ allows integer or string/object values, but emits a 2.0
+deprecation warning if string/object is passed"""
+
+KEY_OBJECTS_NO_WARN = 3
+"""__getitem__ allows integer or string/object values with no warnings
+or errors."""
+
+try:
+ from sqlalchemy.cresultproxy import BaseRow
+
+ _baserow_usecext = True
+except ImportError:
+ _baserow_usecext = False
+
+ class BaseRow(object):
+ __slots__ = ("_parent", "_data", "_keymap", "_key_style")
+
+ def __init__(self, parent, processors, keymap, key_style, data):
+ """Row objects are constructed by CursorResult objects."""
+
+ object.__setattr__(self, "_parent", parent)
+
+ if processors:
+ object.__setattr__(
+ self,
+ "_data",
+ tuple(
+ [
+ proc(value) if proc else value
+ for proc, value in zip(processors, data)
+ ]
+ ),
+ )
+ else:
+ object.__setattr__(self, "_data", tuple(data))
+
+ object.__setattr__(self, "_keymap", keymap)
+
+ object.__setattr__(self, "_key_style", key_style)
+
+ def __reduce__(self):
+ return (
+ rowproxy_reconstructor,
+ (self.__class__, self.__getstate__()),
+ )
+
+ def _filter_on_values(self, filters):
+ return Row(
+ self._parent,
+ filters,
+ self._keymap,
+ self._key_style,
+ self._data,
+ )
+
+ def _values_impl(self):
+ return list(self)
+
+ def __iter__(self):
+ return iter(self._data)
+
+ def __len__(self):
+ return len(self._data)
+
+ def __hash__(self):
+ return hash(self._data)
+
+ def _get_by_int_impl(self, key):
+ return self._data[key]
+
+ def _get_by_key_impl(self, key):
+ if int in key.__class__.__mro__:
+ return self._data[key]
+
+ if self._key_style == KEY_INTEGER_ONLY:
+ self._parent._raise_for_nonint(key)
+
+ # the following is all LegacyRow support. none of this
+ # should be called if not LegacyRow
+ # assert isinstance(self, LegacyRow)
+
+ try:
+ rec = self._keymap[key]
+ except KeyError as ke:
+ rec = self._parent._key_fallback(key, ke)
+ except TypeError:
+ if isinstance(key, slice):
+ return tuple(self._data[key])
+ else:
+ raise
+
+ mdindex = rec[MD_INDEX]
+ if mdindex is None:
+ self._parent._raise_for_ambiguous_column_name(rec)
+
+ elif self._key_style == KEY_OBJECTS_BUT_WARN and mdindex != key:
+ self._parent._warn_for_nonint(key)
+
+ return self._data[mdindex]
+
+ # The original 1.4 plan was that Row would not allow row["str"]
+ # access, however as the C extensions were inadvertently allowing
+ # this coupled with the fact that orm Session sets future=True,
+ # this allows a softer upgrade path. see #6218
+ __getitem__ = _get_by_key_impl
+
+ def _get_by_key_impl_mapping(self, key):
+ try:
+ rec = self._keymap[key]
+ except KeyError as ke:
+ rec = self._parent._key_fallback(key, ke)
+
+ mdindex = rec[MD_INDEX]
+ if mdindex is None:
+ self._parent._raise_for_ambiguous_column_name(rec)
+ elif (
+ self._key_style == KEY_OBJECTS_ONLY
+ and int in key.__class__.__mro__
+ ):
+ raise KeyError(key)
+
+ return self._data[mdindex]
+
+ def __getattr__(self, name):
+ try:
+ return self._get_by_key_impl_mapping(name)
+ except KeyError as e:
+ util.raise_(AttributeError(e.args[0]), replace_context=e)
+
+
+class Row(BaseRow, collections_abc.Sequence):
+ """Represent a single result row.
+
+ The :class:`.Row` object represents a row of a database result. It is
+ typically associated in the 1.x series of SQLAlchemy with the
+ :class:`_engine.CursorResult` object, however is also used by the ORM for
+ tuple-like results as of SQLAlchemy 1.4.
+
+ The :class:`.Row` object seeks to act as much like a Python named
+ tuple as possible. For mapping (i.e. dictionary) behavior on a row,
+ such as testing for containment of keys, refer to the :attr:`.Row._mapping`
+ attribute.
+
+ .. seealso::
+
+ :ref:`tutorial_selecting_data` - includes examples of selecting
+ rows from SELECT statements.
+
+ :class:`.LegacyRow` - Compatibility interface introduced in SQLAlchemy
+ 1.4.
+
+ .. versionchanged:: 1.4
+
+ Renamed ``RowProxy`` to :class:`.Row`. :class:`.Row` is no longer a
+ "proxy" object in that it contains the final form of data within it,
+ and now acts mostly like a named tuple. Mapping-like functionality is
+ moved to the :attr:`.Row._mapping` attribute, but will remain available
+ in SQLAlchemy 1.x series via the :class:`.LegacyRow` class that is used
+ by :class:`_engine.LegacyCursorResult`.
+ See :ref:`change_4710_core` for background
+ on this change.
+
+ """
+
+ __slots__ = ()
+
+ # in 2.0, this should be KEY_INTEGER_ONLY
+ _default_key_style = KEY_OBJECTS_BUT_WARN
+
+ def __setattr__(self, name, value):
+ raise AttributeError("can't set attribute")
+
+ def __delattr__(self, name):
+ raise AttributeError("can't delete attribute")
+
+ @property
+ def _mapping(self):
+ """Return a :class:`.RowMapping` for this :class:`.Row`.
+
+ This object provides a consistent Python mapping (i.e. dictionary)
+ interface for the data contained within the row. The :class:`.Row`
+ by itself behaves like a named tuple, however in the 1.4 series of
+ SQLAlchemy, the :class:`.LegacyRow` class is still used by Core which
+ continues to have mapping-like behaviors against the row object
+ itself.
+
+ .. seealso::
+
+ :attr:`.Row._fields`
+
+ .. versionadded:: 1.4
+
+ """
+ return RowMapping(
+ self._parent,
+ None,
+ self._keymap,
+ RowMapping._default_key_style,
+ self._data,
+ )
+
+ def _special_name_accessor(name):
+ """Handle ambiguous names such as "count" and "index" """
+
+ @property
+ def go(self):
+ if self._parent._has_key(name):
+ return self.__getattr__(name)
+ else:
+
+ def meth(*arg, **kw):
+ return getattr(collections_abc.Sequence, name)(
+ self, *arg, **kw
+ )
+
+ return meth
+
+ return go
+
+ count = _special_name_accessor("count")
+ index = _special_name_accessor("index")
+
+ def __contains__(self, key):
+ return key in self._data
+
+ def __getstate__(self):
+ return {
+ "_parent": self._parent,
+ "_data": self._data,
+ "_key_style": self._key_style,
+ }
+
+ def __setstate__(self, state):
+ parent = state["_parent"]
+ object.__setattr__(self, "_parent", parent)
+ object.__setattr__(self, "_data", state["_data"])
+ object.__setattr__(self, "_keymap", parent._keymap)
+ object.__setattr__(self, "_key_style", state["_key_style"])
+
+ def _op(self, other, op):
+ return (
+ op(tuple(self), tuple(other))
+ if isinstance(other, Row)
+ else op(tuple(self), other)
+ )
+
+ __hash__ = BaseRow.__hash__
+
+ def __lt__(self, other):
+ return self._op(other, operator.lt)
+
+ def __le__(self, other):
+ return self._op(other, operator.le)
+
+ def __ge__(self, other):
+ return self._op(other, operator.ge)
+
+ def __gt__(self, other):
+ return self._op(other, operator.gt)
+
+ def __eq__(self, other):
+ return self._op(other, operator.eq)
+
+ def __ne__(self, other):
+ return self._op(other, operator.ne)
+
+ def __repr__(self):
+ return repr(sql_util._repr_row(self))
+
+ @util.deprecated_20(
+ ":meth:`.Row.keys`",
+ alternative="Use the namedtuple standard accessor "
+ ":attr:`.Row._fields`, or for full mapping behavior use "
+ "row._mapping.keys() ",
+ )
+ def keys(self):
+ """Return the list of keys as strings represented by this
+ :class:`.Row`.
+
+ The keys can represent the labels of the columns returned by a core
+ statement or the names of the orm classes returned by an orm
+ execution.
+
+ This method is analogous to the Python dictionary ``.keys()`` method,
+ except that it returns a list, not an iterator.
+
+ .. seealso::
+
+ :attr:`.Row._fields`
+
+ :attr:`.Row._mapping`
+
+ """
+ return self._parent.keys
+
+ @property
+ def _fields(self):
+ """Return a tuple of string keys as represented by this
+ :class:`.Row`.
+
+ The keys can represent the labels of the columns returned by a core
+ statement or the names of the orm classes returned by an orm
+ execution.
+
+ This attribute is analogous to the Python named tuple ``._fields``
+ attribute.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`.Row._mapping`
+
+ """
+ return tuple([k for k in self._parent.keys if k is not None])
+
+ def _asdict(self):
+ """Return a new dict which maps field names to their corresponding
+ values.
+
+ This method is analogous to the Python named tuple ``._asdict()``
+ method, and works by applying the ``dict()`` constructor to the
+ :attr:`.Row._mapping` attribute.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`.Row._mapping`
+
+ """
+ return dict(self._mapping)
+
+ def _replace(self):
+ raise NotImplementedError()
+
+ @property
+ def _field_defaults(self):
+ raise NotImplementedError()
+
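+# Editorial sketch, not part of the original changeset: tuple-style versus
+# mapping-style access on a Row.  The column name "name" is a placeholder.
+def _example_row_access(row):
+    """Illustrative only: the named-tuple-like surface of Row."""
+    first_value = row[0]             # positional access
+    as_dict = row._asdict()          # plain dict keyed by column name
+    by_name = row._mapping["name"]   # mapping access, 1.4 style
+    return first_value, as_dict, by_name
+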
+
+class LegacyRow(Row):
+ """A subclass of :class:`.Row` that delivers 1.x SQLAlchemy behaviors
+ for Core.
+
+ The :class:`.LegacyRow` class is where most of the Python mapping
+ (i.e. dictionary-like)
+ behaviors are implemented for the row object. The mapping behavior
+ of :class:`.Row` going forward is accessible via the :class:`.Row._mapping`
+ attribute.
+
+ .. versionadded:: 1.4 - added :class:`.LegacyRow` which encapsulates most
+ of the deprecated behaviors of :class:`.Row`.
+
+ """
+
+ __slots__ = ()
+
+ if util.SQLALCHEMY_WARN_20:
+ _default_key_style = KEY_OBJECTS_BUT_WARN
+ else:
+ _default_key_style = KEY_OBJECTS_NO_WARN
+
+ def __contains__(self, key):
+ return self._parent._contains(key, self)
+
+ # prior to #6218, LegacyRow would redirect the behavior of __getitem__
+ # for the non C version of BaseRow. This is now set up by Python BaseRow
+ # in all cases
+ # if not _baserow_usecext:
+ # __getitem__ = BaseRow._get_by_key_impl
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`.LegacyRow.has_key` method is deprecated and will be "
+ "removed in a future release. To test for key membership, use "
+ "the :attr:`Row._mapping` attribute, i.e. 'key in row._mapping`.",
+ )
+ def has_key(self, key):
+ """Return True if this :class:`.LegacyRow` contains the given key.
+
+ Through the SQLAlchemy 1.x series, the ``__contains__()`` method of
+ :class:`.Row` (or :class:`.LegacyRow` as of SQLAlchemy 1.4) also links
+ to :meth:`.Row.has_key`, in that an expression such as ::
+
+ "some_col" in row
+
+        will return True if the row contains a column named ``"some_col"``,
+ in the way that a Python mapping works.
+
+ However, it is planned that the 2.0 series of SQLAlchemy will reverse
+ this behavior so that ``__contains__()`` will refer to a value being
+ present in the row, in the way that a Python tuple works.
+
+ .. seealso::
+
+ :ref:`change_4710_core`
+
+ """
+
+ return self._parent._has_key(key)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`.LegacyRow.items` method is deprecated and will be "
+ "removed in a future release. Use the :attr:`Row._mapping` "
+ "attribute, i.e., 'row._mapping.items()'.",
+ )
+ def items(self):
+ """Return a list of tuples, each tuple containing a key/value pair.
+
+ This method is analogous to the Python dictionary ``.items()`` method,
+ except that it returns a list, not an iterator.
+
+ """
+
+ return [(key, self[key]) for key in self.keys()]
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`.LegacyRow.iterkeys` method is deprecated and will be "
+ "removed in a future release. Use the :attr:`Row._mapping` "
+ "attribute, i.e., 'row._mapping.keys()'.",
+ )
+ def iterkeys(self):
+        """Return an iterator against the :meth:`.Row.keys` method.
+
+ This method is analogous to the Python-2-only dictionary
+ ``.iterkeys()`` method.
+
+ """
+ return iter(self._parent.keys)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`.LegacyRow.itervalues` method is deprecated and will be "
+ "removed in a future release. Use the :attr:`Row._mapping` "
+ "attribute, i.e., 'row._mapping.values()'.",
+ )
+ def itervalues(self):
+        """Return an iterator against the :meth:`.Row.values` method.
+
+ This method is analogous to the Python-2-only dictionary
+ ``.itervalues()`` method.
+
+ """
+ return iter(self)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`.LegacyRow.values` method is deprecated and will be "
+ "removed in a future release. Use the :attr:`Row._mapping` "
+ "attribute, i.e., 'row._mapping.values()'.",
+ )
+ def values(self):
+ """Return the values represented by this :class:`.Row` as a list.
+
+ This method is analogous to the Python dictionary ``.values()`` method,
+ except that it returns a list, not an iterator.
+
+ """
+
+ return self._values_impl()
+
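+# Editorial sketch, not part of the original changeset: migrating from the
+# deprecated LegacyRow mapping API to the Row._mapping namespace.
+def _example_legacy_row_migration(row):
+    """Illustrative only: prefer row._mapping over the LegacyRow methods."""
+    legacy_keys = row.keys()                  # deprecated in 1.4
+    modern_keys = list(row._mapping.keys())   # 1.4 / 2.0 style
+    return legacy_keys, modern_keys
+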
+
+BaseRowProxy = BaseRow
+RowProxy = Row
+
+
+class ROMappingView(
+ collections_abc.KeysView,
+ collections_abc.ValuesView,
+ collections_abc.ItemsView,
+):
+ __slots__ = (
+ "_mapping",
+ "_items",
+ )
+
+ def __init__(self, mapping, items):
+ self._mapping = mapping
+ self._items = items
+
+ def __len__(self):
+ return len(self._items)
+
+ def __repr__(self):
+ return "{0.__class__.__name__}({0._mapping!r})".format(self)
+
+ def __iter__(self):
+ return iter(self._items)
+
+ def __contains__(self, item):
+ return item in self._items
+
+ def __eq__(self, other):
+ return list(other) == list(self)
+
+ def __ne__(self, other):
+ return list(other) != list(self)
+
+
+class RowMapping(BaseRow, collections_abc.Mapping):
+ """A ``Mapping`` that maps column names and objects to :class:`.Row`
+ values.
+
+ The :class:`.RowMapping` is available from a :class:`.Row` via the
+ :attr:`.Row._mapping` attribute, as well as from the iterable interface
+ provided by the :class:`.MappingResult` object returned by the
+ :meth:`_engine.Result.mappings` method.
+
+ :class:`.RowMapping` supplies Python mapping (i.e. dictionary) access to
+ the contents of the row. This includes support for testing of
+ containment of specific keys (string column names or objects), as well
+ as iteration of keys, values, and items::
+
+ for row in result:
+ if 'a' in row._mapping:
+ print("Column 'a': %s" % row._mapping['a'])
+
+ print("Column b: %s" % row._mapping[table.c.b])
+
+
+ .. versionadded:: 1.4 The :class:`.RowMapping` object replaces the
+ mapping-like access previously provided by a database result row,
+ which now seeks to behave mostly like a named tuple.
+
+ """
+
+ __slots__ = ()
+
+ _default_key_style = KEY_OBJECTS_ONLY
+
+ if not _baserow_usecext:
+
+ __getitem__ = BaseRow._get_by_key_impl_mapping
+
+ def _values_impl(self):
+ return list(self._data)
+
+ def __iter__(self):
+ return (k for k in self._parent.keys if k is not None)
+
+ def __len__(self):
+ return len(self._data)
+
+ def __contains__(self, key):
+ return self._parent._has_key(key)
+
+ def __repr__(self):
+ return repr(dict(self))
+
+ def items(self):
+ """Return a view of key/value tuples for the elements in the
+ underlying :class:`.Row`.
+
+ """
+ return ROMappingView(self, [(key, self[key]) for key in self.keys()])
+
+ def keys(self):
+ """Return a view of 'keys' for string column names represented
+ by the underlying :class:`.Row`.
+
+ """
+
+ return self._parent.keys
+
+ def values(self):
+ """Return a view of values for the values represented in the
+ underlying :class:`.Row`.
+
+ """
+ return ROMappingView(self, self._values_impl())
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
new file mode 100644
index 0000000..54a5e51
--- /dev/null
+++ b/lib/sqlalchemy/engine/strategies.py
@@ -0,0 +1,17 @@
+# engine/strategies.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Deprecated mock engine strategy used by Alembic.
+
+
+"""
+
+from .mock import MockConnection # noqa
+
+
+class MockEngineStrategy(object):
+ MockConnection = MockConnection
diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py
new file mode 100644
index 0000000..db971c2
--- /dev/null
+++ b/lib/sqlalchemy/engine/url.py
@@ -0,0 +1,806 @@
+# engine/url.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Provides the :class:`~sqlalchemy.engine.url.URL` class which encapsulates
+information about a database connection specification.
+
+The URL object is created automatically when
+:func:`~sqlalchemy.engine.create_engine` is called with a string
+argument; alternatively, the URL is a public-facing construct which can
+be used directly and is also accepted directly by ``create_engine()``.
+"""
+
+import re
+
+from .interfaces import Dialect
+from .. import exc
+from .. import util
+from ..dialects import plugins
+from ..dialects import registry
+from ..util import collections_abc
+from ..util import compat
+
+
+class URL(
+ util.namedtuple(
+ "URL",
+ [
+ "drivername",
+ "username",
+ "password",
+ "host",
+ "port",
+ "database",
+ "query",
+ ],
+ )
+):
+ """
+ Represent the components of a URL used to connect to a database.
+
+ This object is suitable to be passed directly to a
+ :func:`_sa.create_engine` call. The fields of the URL are parsed
+ from a string by the :func:`.make_url` function. The string
+ format of the URL is an RFC-1738-style string.
+
+ To create a new :class:`_engine.URL` object, use the
+ :func:`_engine.url.make_url` function. To construct a :class:`_engine.URL`
+ programmatically, use the :meth:`_engine.URL.create` constructor.
+
+ .. versionchanged:: 1.4
+
+ The :class:`_engine.URL` object is now an immutable object. To
+ create a URL, use the :func:`_engine.make_url` or
+ :meth:`_engine.URL.create` function / method. To modify
+ a :class:`_engine.URL`, use methods like
+ :meth:`_engine.URL.set` and
+ :meth:`_engine.URL.update_query_dict` to return a new
+ :class:`_engine.URL` object with modifications. See notes for this
+ change at :ref:`change_5526`.
+
+ :class:`_engine.URL` contains the following attributes:
+
+ * :attr:`_engine.URL.drivername`: database backend and driver name, such as
+ ``postgresql+psycopg2``
+ * :attr:`_engine.URL.username`: username string
+ * :attr:`_engine.URL.password`: password string
+ * :attr:`_engine.URL.host`: string hostname
+ * :attr:`_engine.URL.port`: integer port number
+ * :attr:`_engine.URL.database`: string database name
+ * :attr:`_engine.URL.query`: an immutable mapping representing the query
+ string. contains strings for keys and either strings or tuples of
+ strings for values.
+
+
+ """
+
+ def __new__(self, *arg, **kw):
+ if kw.pop("_new_ok", False):
+ return super(URL, self).__new__(self, *arg, **kw)
+ else:
+ util.warn_deprecated(
+ "Calling URL() directly is deprecated and will be disabled "
+ "in a future release. The public constructor for URL is "
+ "now the URL.create() method.",
+ "1.4",
+ )
+ return URL.create(*arg, **kw)
+
+ @classmethod
+ def create(
+ cls,
+ drivername,
+ username=None,
+ password=None,
+ host=None,
+ port=None,
+ database=None,
+ query=util.EMPTY_DICT,
+ ):
+ """Create a new :class:`_engine.URL` object.
+
+ :param drivername: the name of the database backend. This name will
+ correspond to a module in sqlalchemy/databases or a third party
+ plug-in.
+ :param username: The user name.
+ :param password: database password. Is typically a string, but may
+ also be an object that can be stringified with ``str()``.
+
+ .. note:: A password-producing object will be stringified only
+ **once** per :class:`_engine.Engine` object. For dynamic password
+ generation per connect, see :ref:`engines_dynamic_tokens`.
+
+ :param host: The name of the host.
+ :param port: The port number.
+ :param database: The database name.
+ :param query: A dictionary of string keys to string values to be passed
+ to the dialect and/or the DBAPI upon connect. To specify non-string
+ parameters to a Python DBAPI directly, use the
+ :paramref:`_sa.create_engine.connect_args` parameter to
+ :func:`_sa.create_engine`. See also
+ :attr:`_engine.URL.normalized_query` for a dictionary that is
+ consistently string->list of string.
+ :return: new :class:`_engine.URL` object.
+
+ .. versionadded:: 1.4
+
+ The :class:`_engine.URL` object is now an **immutable named
+ tuple**. In addition, the ``query`` dictionary is also immutable.
+ To create a URL, use the :func:`_engine.url.make_url` or
+ :meth:`_engine.URL.create` function/ method. To modify a
+ :class:`_engine.URL`, use the :meth:`_engine.URL.set` and
+ :meth:`_engine.URL.update_query` methods.
+
+ """
+
+ return cls(
+ cls._assert_str(drivername, "drivername"),
+ cls._assert_none_str(username, "username"),
+ password,
+ cls._assert_none_str(host, "host"),
+ cls._assert_port(port),
+ cls._assert_none_str(database, "database"),
+ cls._str_dict(query),
+ _new_ok=True,
+ )
+
+ @classmethod
+ def _assert_port(cls, port):
+ if port is None:
+ return None
+ try:
+ return int(port)
+ except TypeError:
+ raise TypeError("Port argument must be an integer or None")
+
+ @classmethod
+ def _assert_str(cls, v, paramname):
+ if not isinstance(v, compat.string_types):
+ raise TypeError("%s must be a string" % paramname)
+ return v
+
+ @classmethod
+ def _assert_none_str(cls, v, paramname):
+ if v is None:
+ return v
+
+ return cls._assert_str(v, paramname)
+
+ @classmethod
+ def _str_dict(cls, dict_):
+ if dict_ is None:
+ return util.EMPTY_DICT
+
+ def _assert_value(val):
+ if isinstance(val, compat.string_types):
+ return val
+ elif isinstance(val, collections_abc.Sequence):
+ return tuple(_assert_value(elem) for elem in val)
+ else:
+ raise TypeError(
+ "Query dictionary values must be strings or "
+ "sequences of strings"
+ )
+
+ def _assert_str(v):
+ if not isinstance(v, compat.string_types):
+ raise TypeError("Query dictionary keys must be strings")
+ return v
+
+ if isinstance(dict_, collections_abc.Sequence):
+ dict_items = dict_
+ else:
+ dict_items = dict_.items()
+
+ return util.immutabledict(
+ {
+ _assert_str(key): _assert_value(
+ value,
+ )
+ for key, value in dict_items
+ }
+ )
+
+ def set(
+ self,
+ drivername=None,
+ username=None,
+ password=None,
+ host=None,
+ port=None,
+ database=None,
+ query=None,
+ ):
+ """return a new :class:`_engine.URL` object with modifications.
+
+ Values are used if they are non-None. To set a value to ``None``
+ explicitly, use the :meth:`_engine.URL._replace` method adapted
+ from ``namedtuple``.
+
+ :param drivername: new drivername
+ :param username: new username
+ :param password: new password
+ :param host: new hostname
+ :param port: new port
+ :param query: new query parameters, passed a dict of string keys
+ referring to string or sequence of string values. Fully
+ replaces the previous list of arguments.
+
+ :return: new :class:`_engine.URL` object.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :meth:`_engine.URL.update_query_dict`
+
+ """
+
+ kw = {}
+ if drivername is not None:
+ kw["drivername"] = drivername
+ if username is not None:
+ kw["username"] = username
+ if password is not None:
+ kw["password"] = password
+ if host is not None:
+ kw["host"] = host
+ if port is not None:
+ kw["port"] = port
+ if database is not None:
+ kw["database"] = database
+ if query is not None:
+ kw["query"] = query
+
+ return self._replace(**kw)
+
+ def _replace(self, **kw):
+ """Override ``namedtuple._replace()`` to provide argument checking."""
+
+ if "drivername" in kw:
+ self._assert_str(kw["drivername"], "drivername")
+ for name in "username", "host", "database":
+ if name in kw:
+ self._assert_none_str(kw[name], name)
+ if "port" in kw:
+ self._assert_port(kw["port"])
+ if "query" in kw:
+ kw["query"] = self._str_dict(kw["query"])
+
+ return super(URL, self)._replace(**kw)
+
+ def update_query_string(self, query_string, append=False):
+ """Return a new :class:`_engine.URL` object with the :attr:`_engine.URL.query`
+ parameter dictionary updated by the given query string.
+
+ E.g.::
+
+ >>> from sqlalchemy.engine import make_url
+ >>> url = make_url("postgresql://user:pass@host/dbname")
+ >>> url = url.update_query_string("alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt")
+ >>> str(url)
+ 'postgresql://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt'
+
+ :param query_string: a URL escaped query string, not including the
+ question mark.
+
+ :param append: if True, parameters in the existing query string will
+ not be removed; new parameters will be in addition to those present.
+ If left at its default of False, keys present in the given query
+ parameters will replace those of the existing query string.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_engine.URL.query`
+
+ :meth:`_engine.URL.update_query_dict`
+
+ """ # noqa: E501
+ return self.update_query_pairs(
+ util.parse_qsl(query_string), append=append
+ )
+
+ def update_query_pairs(self, key_value_pairs, append=False):
+ """Return a new :class:`_engine.URL` object with the
+ :attr:`_engine.URL.query`
+ parameter dictionary updated by the given sequence of key/value pairs
+
+ E.g.::
+
+ >>> from sqlalchemy.engine import make_url
+ >>> url = make_url("postgresql://user:pass@host/dbname")
+ >>> url = url.update_query_pairs([("alt_host", "host1"), ("alt_host", "host2"), ("ssl_cipher", "/path/to/crt")])
+ >>> str(url)
+ 'postgresql://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt'
+
+ :param key_value_pairs: A sequence of tuples containing two strings
+ each.
+
+ :param append: if True, parameters in the existing query string will
+ not be removed; new parameters will be in addition to those present.
+ If left at its default of False, keys present in the given query
+ parameters will replace those of the existing query string.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_engine.URL.query`
+
+ :meth:`_engine.URL.difference_update_query`
+
+ :meth:`_engine.URL.set`
+
+ """ # noqa: E501
+
+ existing_query = self.query
+ new_keys = {}
+
+ for key, value in key_value_pairs:
+ if key in new_keys:
+ new_keys[key] = util.to_list(new_keys[key])
+ new_keys[key].append(value)
+ else:
+ new_keys[key] = value
+
+ if append:
+ new_query = {}
+
+ for k in new_keys:
+ if k in existing_query:
+ new_query[k] = util.to_list(
+ existing_query[k]
+ ) + util.to_list(new_keys[k])
+ else:
+ new_query[k] = new_keys[k]
+
+ new_query.update(
+ {
+ k: existing_query[k]
+ for k in set(existing_query).difference(new_keys)
+ }
+ )
+ else:
+ new_query = self.query.union(new_keys)
+ return self.set(query=new_query)
+
+ def update_query_dict(self, query_parameters, append=False):
+ """Return a new :class:`_engine.URL` object with the
+ :attr:`_engine.URL.query` parameter dictionary updated by the given
+ dictionary.
+
+ The dictionary typically contains string keys and string values.
+ In order to represent a query parameter that is expressed multiple
+ times, pass a sequence of string values.
+
+ E.g.::
+
+
+ >>> from sqlalchemy.engine import make_url
+ >>> url = make_url("postgresql://user:pass@host/dbname")
+ >>> url = url.update_query_dict({"alt_host": ["host1", "host2"], "ssl_cipher": "/path/to/crt"})
+ >>> str(url)
+ 'postgresql://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt'
+
+
+ :param query_parameters: A dictionary with string keys and values
+ that are either strings, or sequences of strings.
+
+ :param append: if True, parameters in the existing query string will
+ not be removed; new parameters will be in addition to those present.
+ If left at its default of False, keys present in the given query
+ parameters will replace those of the existing query string.
+
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_engine.URL.query`
+
+ :meth:`_engine.URL.update_query_string`
+
+ :meth:`_engine.URL.update_query_pairs`
+
+ :meth:`_engine.URL.difference_update_query`
+
+ :meth:`_engine.URL.set`
+
+ """ # noqa: E501
+ return self.update_query_pairs(query_parameters.items(), append=append)
+
+ def difference_update_query(self, names):
+ """
+ Remove the given names from the :attr:`_engine.URL.query` dictionary,
+ returning the new :class:`_engine.URL`.
+
+ E.g.::
+
+ url = url.difference_update_query(['foo', 'bar'])
+
+ Equivalent to using :meth:`_engine.URL.set` as follows::
+
+ url = url.set(
+ query={
+ key: url.query[key]
+ for key in set(url.query).difference(['foo', 'bar'])
+ }
+ )
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_engine.URL.query`
+
+ :meth:`_engine.URL.update_query_dict`
+
+ :meth:`_engine.URL.set`
+
+ """
+
+ if not set(names).intersection(self.query):
+ return self
+
+ return URL(
+ self.drivername,
+ self.username,
+ self.password,
+ self.host,
+ self.port,
+ self.database,
+ util.immutabledict(
+ {
+ key: self.query[key]
+ for key in set(self.query).difference(names)
+ }
+ ),
+ _new_ok=True,
+ )
+
+ @util.memoized_property
+ def normalized_query(self):
+ """Return the :attr:`_engine.URL.query` dictionary with values normalized
+ into sequences.
+
+ As the :attr:`_engine.URL.query` dictionary may contain either
+ string values or sequences of string values to differentiate between
+ parameters that are specified multiple times in the query string,
+ code that needs to handle multiple parameters generically will wish
+ to use this attribute so that all parameters present are presented
+ as sequences. Inspiration is from Python's ``urllib.parse.parse_qs``
+ function. E.g.::
+
+
+ >>> from sqlalchemy.engine import make_url
+ >>> url = make_url("postgresql://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt")
+ >>> url.query
+ immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': '/path/to/crt'})
+ >>> url.normalized_query
+ immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': ('/path/to/crt',)})
+
+ """ # noqa: E501
+
+ return util.immutabledict(
+ {
+ k: (v,) if not isinstance(v, tuple) else v
+ for k, v in self.query.items()
+ }
+ )
+
+ @util.deprecated(
+ "1.4",
+        "The :meth:`_engine.URL.__to_string__` method is deprecated and will "
+ "be removed in a future release. Please use the "
+ ":meth:`_engine.URL.render_as_string` method.",
+ )
+ def __to_string__(self, hide_password=True):
+ """Render this :class:`_engine.URL` object as a string.
+
+ :param hide_password: Defaults to True. The password is not shown
+ in the string unless this is set to False.
+
+ """
+ return self.render_as_string(hide_password=hide_password)
+
+ def render_as_string(self, hide_password=True):
+ """Render this :class:`_engine.URL` object as a string.
+
+ This method is used when the ``__str__()`` or ``__repr__()``
+ methods are used. The method directly includes additional options.
+
+ :param hide_password: Defaults to True. The password is not shown
+ in the string unless this is set to False.
+
+ """
+ s = self.drivername + "://"
+ if self.username is not None:
+ s += _rfc_1738_quote(self.username)
+ if self.password is not None:
+ s += ":" + (
+ "***"
+ if hide_password
+ else _rfc_1738_quote(str(self.password))
+ )
+ s += "@"
+ if self.host is not None:
+ if ":" in self.host:
+ s += "[%s]" % self.host
+ else:
+ s += self.host
+ if self.port is not None:
+ s += ":" + str(self.port)
+ if self.database is not None:
+ s += "/" + self.database
+ if self.query:
+ keys = list(self.query)
+ keys.sort()
+ s += "?" + "&".join(
+ "%s=%s" % (util.quote_plus(k), util.quote_plus(element))
+ for k in keys
+ for element in util.to_list(self.query[k])
+ )
+ return s
+
+ def __str__(self):
+ return self.render_as_string(hide_password=False)
+
+ def __repr__(self):
+ return self.render_as_string()
+
+ def __copy__(self):
+ return self.__class__.create(
+ self.drivername,
+ self.username,
+ self.password,
+ self.host,
+ self.port,
+ self.database,
+ # note this is an immutabledict of str-> str / tuple of str,
+ # also fully immutable. does not require deepcopy
+ self.query,
+ )
+
+ def __deepcopy__(self, memo):
+ return self.__copy__()
+
+ def __hash__(self):
+ return hash(str(self))
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, URL)
+ and self.drivername == other.drivername
+ and self.username == other.username
+ and self.password == other.password
+ and self.host == other.host
+ and self.database == other.database
+ and self.query == other.query
+ and self.port == other.port
+ )
+
+ def __ne__(self, other):
+ return not self == other
+
+ def get_backend_name(self):
+ """Return the backend name.
+
+ This is the name that corresponds to the database backend in
+ use, and is the portion of the :attr:`_engine.URL.drivername`
+ that is to the left of the plus sign.
+
+ """
+ if "+" not in self.drivername:
+ return self.drivername
+ else:
+ return self.drivername.split("+")[0]
+
+ def get_driver_name(self):
+        """Return the DBAPI driver name.
+
+ This is the name that corresponds to the DBAPI driver in
+ use, and is the portion of the :attr:`_engine.URL.drivername`
+ that is to the right of the plus sign.
+
+ If the :attr:`_engine.URL.drivername` does not include a plus sign,
+ then the default :class:`_engine.Dialect` for this :class:`_engine.URL`
+ is imported in order to get the driver name.
+
+ """
+
+ if "+" not in self.drivername:
+ return self.get_dialect().driver
+ else:
+ return self.drivername.split("+")[1]
+
+ def _instantiate_plugins(self, kwargs):
+ plugin_names = util.to_list(self.query.get("plugin", ()))
+ plugin_names += kwargs.get("plugins", [])
+
+ kwargs = dict(kwargs)
+
+ loaded_plugins = [
+ plugins.load(plugin_name)(self, kwargs)
+ for plugin_name in plugin_names
+ ]
+
+ u = self.difference_update_query(["plugin", "plugins"])
+
+ for plugin in loaded_plugins:
+ new_u = plugin.update_url(u)
+ if new_u is not None:
+ u = new_u
+
+ kwargs.pop("plugins", None)
+
+ return u, loaded_plugins, kwargs
+
+ def _get_entrypoint(self):
+ """Return the "entry point" dialect class.
+
+ This is normally the dialect itself except in the case when the
+ returned class implements the get_dialect_cls() method.
+
+ """
+ if "+" not in self.drivername:
+ name = self.drivername
+ else:
+ name = self.drivername.replace("+", ".")
+ cls = registry.load(name)
+ # check for legacy dialects that
+ # would return a module with 'dialect' as the
+ # actual class
+ if (
+ hasattr(cls, "dialect")
+ and isinstance(cls.dialect, type)
+ and issubclass(cls.dialect, Dialect)
+ ):
+ return cls.dialect
+ else:
+ return cls
+
+ def get_dialect(self):
+ """Return the SQLAlchemy :class:`_engine.Dialect` class corresponding
+ to this URL's driver name.
+
+ """
+ entrypoint = self._get_entrypoint()
+ dialect_cls = entrypoint.get_dialect_cls(self)
+ return dialect_cls
+
+ def translate_connect_args(self, names=None, **kw):
+ r"""Translate url attributes into a dictionary of connection arguments.
+
+ Returns attributes of this url (`host`, `database`, `username`,
+ `password`, `port`) as a plain dictionary. The attribute names are
+ used as the keys by default. Unset or false attributes are omitted
+ from the final dictionary.
+
+ :param \**kw: Optional, alternate key names for url attributes.
+
+ :param names: Deprecated. Same purpose as the keyword-based alternate
+ names, but correlates the name to the original positionally.
+ """
+
+ if names is not None:
+ util.warn_deprecated(
+                "The `URL.translate_connect_args.names` parameter is "
+ "deprecated. Please pass the "
+ "alternate names as kw arguments.",
+ "1.4",
+ )
+
+ translated = {}
+ attribute_names = ["host", "database", "username", "password", "port"]
+ for sname in attribute_names:
+ if names:
+ name = names.pop(0)
+ elif sname in kw:
+ name = kw[sname]
+ else:
+ name = sname
+ if name is not None and getattr(self, sname, False):
+ if sname == "password":
+ translated[name] = str(getattr(self, sname))
+ else:
+ translated[name] = getattr(self, sname)
+
+ return translated
+
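+# Editorial sketch, not part of the original changeset: programmatic URL
+# construction and modification.  The credentials and hosts are placeholders.
+def _example_url_create():
+    """Illustrative only: build, then amend, an immutable URL."""
+    url = URL.create(
+        drivername="postgresql+psycopg2",
+        username="scott",
+        password="tiger",
+        host="localhost",
+        database="test",
+    )
+    url = url.set(host="db.example.org")
+    url = url.update_query_dict({"sslmode": "require"})
+    return url.render_as_string(hide_password=True)
+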
+
+def make_url(name_or_url):
+ """Given a string or unicode instance, produce a new URL instance.
+
+ The given string is parsed according to the RFC 1738 spec. If an
+ existing URL object is passed, just returns the object.
+ """
+
+ if isinstance(name_or_url, util.string_types):
+ return _parse_rfc1738_args(name_or_url)
+ else:
+ return name_or_url
+
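+# Editorial sketch, not part of the original changeset: make_url() parses an
+# RFC 1738 style string; the resulting URL exposes its components as
+# attributes.  The connection string shown is a placeholder.
+def _example_make_url():
+    """Illustrative only: parse a database URL string."""
+    url = make_url("postgresql+psycopg2://scott:tiger@localhost:5432/test")
+    return url.get_backend_name(), url.host, url.port, url.database
+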
+
+def _parse_rfc1738_args(name):
+ pattern = re.compile(
+ r"""
+ (?P<name>[\w\+]+)://
+ (?:
+ (?P<username>[^:/]*)
+ (?::(?P<password>[^@]*))?
+ @)?
+ (?:
+ (?:
+ \[(?P<ipv6host>[^/\?]+)\] |
+ (?P<ipv4host>[^/:\?]+)
+ )?
+ (?::(?P<port>[^/\?]*))?
+ )?
+ (?:/(?P<database>[^\?]*))?
+ (?:\?(?P<query>.*))?
+ """,
+ re.X,
+ )
+
+ m = pattern.match(name)
+ if m is not None:
+ components = m.groupdict()
+ if components["query"] is not None:
+ query = {}
+
+ for key, value in util.parse_qsl(components["query"]):
+ if util.py2k:
+ key = key.encode("ascii")
+ if key in query:
+ query[key] = util.to_list(query[key])
+ query[key].append(value)
+ else:
+ query[key] = value
+ else:
+ query = None
+ components["query"] = query
+
+ if components["username"] is not None:
+ components["username"] = _rfc_1738_unquote(components["username"])
+
+ if components["password"] is not None:
+ components["password"] = _rfc_1738_unquote(components["password"])
+
+ ipv4host = components.pop("ipv4host")
+ ipv6host = components.pop("ipv6host")
+ components["host"] = ipv4host or ipv6host
+ name = components.pop("name")
+
+ if components["port"]:
+ components["port"] = int(components["port"])
+
+ return URL.create(name, **components)
+
+ else:
+ raise exc.ArgumentError(
+ "Could not parse rfc1738 URL from string '%s'" % name
+ )
+
+
+def _rfc_1738_quote(text):
+ return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text)
+
+
+def _rfc_1738_unquote(text):
+ return util.unquote(text)
+
+
+def _parse_keyvalue_args(name):
+ m = re.match(r"(\w+)://(.*)", name)
+ if m is not None:
+ (name, args) = m.group(1, 2)
+ opts = dict(util.parse_qsl(args))
+ return URL(name, *opts)
+ else:
+ return None
diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py
new file mode 100644
index 0000000..1b03ebb
--- /dev/null
+++ b/lib/sqlalchemy/engine/util.py
@@ -0,0 +1,253 @@
+# engine/util.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .. import exc
+from .. import util
+from ..util import collections_abc
+from ..util import immutabledict
+
+
+def connection_memoize(key):
+ """Decorator, memoize a function in a connection.info stash.
+
+ Only applicable to functions which take no arguments other than a
+ connection. The memo will be stored in ``connection.info[key]``.
+ """
+
+ @util.decorator
+ def decorated(fn, self, connection):
+ connection = connection.connect()
+ try:
+ return connection.info[key]
+ except KeyError:
+ connection.info[key] = val = fn(self, connection)
+ return val
+
+ return decorated
+
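+# Editorial sketch, not part of the original changeset: connection_memoize
+# caches a per-connection computation in connection.info.  The mixin class,
+# the info key, and the SQL text below are hypothetical.
+class _ExampleMemoizedLookup(object):
+    """Illustrative only: memoize an expensive per-connection lookup."""
+
+    @connection_memoize("_example_server_version")
+    def _get_server_version(self, connection):
+        return connection.exec_driver_sql("SELECT version()").scalar()
+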
+
+_no_tuple = ()
+_no_kw = util.immutabledict()
+
+
+def _distill_params(connection, multiparams, params):
+ r"""Given arguments from the calling form \*multiparams, \**params,
+ return a list of bind parameter structures, usually a list of
+ dictionaries.
+
+ In the case of 'raw' execution which accepts positional parameters,
+ it may be a list of tuples or lists.
+
+ """
+
+ if not multiparams:
+ if params:
+ connection._warn_for_legacy_exec_format()
+ return [params]
+ else:
+ return []
+ elif len(multiparams) == 1:
+ zero = multiparams[0]
+ if isinstance(zero, (list, tuple)):
+ if (
+ not zero
+ or hasattr(zero[0], "__iter__")
+ and not hasattr(zero[0], "strip")
+ ):
+ # execute(stmt, [{}, {}, {}, ...])
+ # execute(stmt, [(), (), (), ...])
+ return zero
+ else:
+ # this is used by exec_driver_sql only, so a deprecation
+ # warning would already be coming from passing a plain
+ # textual statement with positional parameters to
+ # execute().
+ # execute(stmt, ("value", "value"))
+ return [zero]
+ elif hasattr(zero, "keys"):
+ # execute(stmt, {"key":"value"})
+ return [zero]
+ else:
+ connection._warn_for_legacy_exec_format()
+ # execute(stmt, "value")
+ return [[zero]]
+ else:
+ connection._warn_for_legacy_exec_format()
+ if hasattr(multiparams[0], "__iter__") and not hasattr(
+ multiparams[0], "strip"
+ ):
+ return multiparams
+ else:
+ return [multiparams]
+
+
+def _distill_cursor_params(connection, multiparams, params):
+    """_distill_params without any warnings.  More appropriate for
+ "cursor" params that can include tuple arguments, lists of tuples,
+ etc.
+
+ """
+
+ if not multiparams:
+ if params:
+ return [params]
+ else:
+ return []
+ elif len(multiparams) == 1:
+ zero = multiparams[0]
+ if isinstance(zero, (list, tuple)):
+ if (
+ not zero
+ or hasattr(zero[0], "__iter__")
+ and not hasattr(zero[0], "strip")
+ ):
+ # execute(stmt, [{}, {}, {}, ...])
+ # execute(stmt, [(), (), (), ...])
+ return zero
+ else:
+ # this is used by exec_driver_sql only, so a deprecation
+ # warning would already be coming from passing a plain
+ # textual statement with positional parameters to
+ # execute().
+ # execute(stmt, ("value", "value"))
+
+ return [zero]
+ elif hasattr(zero, "keys"):
+ # execute(stmt, {"key":"value"})
+ return [zero]
+ else:
+ # execute(stmt, "value")
+ return [[zero]]
+ else:
+ if hasattr(multiparams[0], "__iter__") and not hasattr(
+ multiparams[0], "strip"
+ ):
+ return multiparams
+ else:
+ return [multiparams]
+
+
+def _distill_params_20(params):
+ if params is None:
+ return _no_tuple, _no_kw
+ elif isinstance(params, list):
+ # collections_abc.MutableSequence): # avoid abc.__instancecheck__
+ if params and not isinstance(
+ params[0], (collections_abc.Mapping, tuple)
+ ):
+ raise exc.ArgumentError(
+ "List argument must consist only of tuples or dictionaries"
+ )
+
+ return (params,), _no_kw
+ elif isinstance(
+ params,
+ (tuple, dict, immutabledict),
+ # only do abc.__instancecheck__ for Mapping after we've checked
+ # for plain dictionaries and would otherwise raise
+ ) or isinstance(params, collections_abc.Mapping):
+ return (params,), _no_kw
+ else:
+ raise exc.ArgumentError("mapping or sequence expected for parameters")
+
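+# Editorial sketch, not part of the original changeset: the normalized
+# (positional, keyword) forms produced by _distill_params_20 for the two
+# most common calling styles.
+def _example_distill_params_20():
+    """Illustrative only: a single mapping vs. an executemany list."""
+    single = _distill_params_20({"x": 5})            # -> (({'x': 5},), <empty kw>)
+    many = _distill_params_20([{"x": 5}, {"x": 6}])  # -> (([...],), <empty kw>)
+    return single, many
+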
+
+class TransactionalContext(object):
+ """Apply Python context manager behavior to transaction objects.
+
+ Performs validation to ensure the subject of the transaction is not
+    used after the transaction has been ended prematurely.
+
+ """
+
+ _trans_subject = None
+
+ def _transaction_is_active(self):
+ raise NotImplementedError()
+
+ def _transaction_is_closed(self):
+ raise NotImplementedError()
+
+ def _rollback_can_be_called(self):
+ """indicates the object is in a state that is known to be acceptable
+ for rollback() to be called.
+
+ This does not necessarily mean rollback() will succeed or not raise
+ an error, just that there is currently no state detected that indicates
+ rollback() would fail or emit warnings.
+
+ It also does not mean that there's a transaction in progress, as
+ it is usually safe to call rollback() even if no transaction is
+ present.
+
+ .. versionadded:: 1.4.28
+
+ """
+ raise NotImplementedError()
+
+ def _get_subject(self):
+ raise NotImplementedError()
+
+ @classmethod
+ def _trans_ctx_check(cls, subject):
+ trans_context = subject._trans_context_manager
+ if trans_context:
+ if not trans_context._transaction_is_active():
+ raise exc.InvalidRequestError(
+ "Can't operate on closed transaction inside context "
+ "manager. Please complete the context manager "
+ "before emitting further commands."
+ )
+
+ def __enter__(self):
+ subject = self._get_subject()
+
+ # none for outer transaction, may be non-None for nested
+ # savepoint, legacy nesting cases
+ trans_context = subject._trans_context_manager
+ self._outer_trans_ctx = trans_context
+
+ self._trans_subject = subject
+ subject._trans_context_manager = self
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ subject = self._trans_subject
+
+ # simplistically we could assume that
+ # "subject._trans_context_manager is self". However, any calling
+ # code that is manipulating __exit__ directly would break this
+ # assumption. alembic context manager
+ # is an example of partial use that just calls __exit__ and
+ # not __enter__ at the moment. it's safe to assume this is being done
+ # in the wild also
+ out_of_band_exit = (
+ subject is None or subject._trans_context_manager is not self
+ )
+
+ if type_ is None and self._transaction_is_active():
+ try:
+ self.commit()
+ except:
+ with util.safe_reraise():
+ if self._rollback_can_be_called():
+ self.rollback()
+ finally:
+ if not out_of_band_exit:
+ subject._trans_context_manager = self._outer_trans_ctx
+ self._trans_subject = self._outer_trans_ctx = None
+ else:
+ try:
+ if not self._transaction_is_active():
+ if not self._transaction_is_closed():
+ self.close()
+ else:
+ if self._rollback_can_be_called():
+ self.rollback()
+ finally:
+ if not out_of_band_exit:
+ subject._trans_context_manager = self._outer_trans_ctx
+ self._trans_subject = self._outer_trans_ctx = None
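+
+
+# Editorial sketch, not part of the original changeset: TransactionalContext
+# is what lets transaction objects act as context managers; the familiar
+# pattern below commits on success and rolls back on error.  The ``engine``
+# and ``stmt`` names are placeholders.
+def _example_transactional_context(engine, stmt):
+    """Illustrative only: rely on the transaction context-manager protocol."""
+    with engine.begin() as conn:
+        conn.execute(stmt)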
diff --git a/lib/sqlalchemy/event/__init__.py b/lib/sqlalchemy/event/__init__.py
new file mode 100644
index 0000000..a89bea8
--- /dev/null
+++ b/lib/sqlalchemy/event/__init__.py
@@ -0,0 +1,17 @@
+# event/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .api import CANCEL
+from .api import contains
+from .api import listen
+from .api import listens_for
+from .api import NO_RETVAL
+from .api import remove
+from .attr import RefCollection
+from .base import dispatcher
+from .base import Events
+from .legacy import _legacy_signature
diff --git a/lib/sqlalchemy/event/api.py b/lib/sqlalchemy/event/api.py
new file mode 100644
index 0000000..ce44f57
--- /dev/null
+++ b/lib/sqlalchemy/event/api.py
@@ -0,0 +1,219 @@
+# event/api.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Public API functions for the event system.
+
+"""
+from __future__ import absolute_import
+
+from .base import _registrars
+from .registry import _EventKey
+from .. import exc
+from .. import util
+
+
+CANCEL = util.symbol("CANCEL")
+NO_RETVAL = util.symbol("NO_RETVAL")
+
+
+def _event_key(target, identifier, fn):
+ for evt_cls in _registrars[identifier]:
+ tgt = evt_cls._accept_with(target)
+ if tgt is not None:
+ return _EventKey(target, identifier, fn, tgt)
+ else:
+ raise exc.InvalidRequestError(
+ "No such event '%s' for target '%s'" % (identifier, target)
+ )
+
+
+def listen(target, identifier, fn, *args, **kw):
+ """Register a listener function for the given target.
+
+ The :func:`.listen` function is part of the primary interface for the
+ SQLAlchemy event system, documented at :ref:`event_toplevel`.
+
+ e.g.::
+
+ from sqlalchemy import event
+ from sqlalchemy.schema import UniqueConstraint
+
+ def unique_constraint_name(const, table):
+ const.name = "uq_%s_%s" % (
+ table.name,
+ list(const.columns)[0].name
+ )
+ event.listen(
+ UniqueConstraint,
+ "after_parent_attach",
+ unique_constraint_name)
+
+ :param bool insert: The default behavior for event handlers is to append
+ the decorated user defined function to an internal list of registered
+ event listeners upon discovery. If a user registers a function with
+ ``insert=True``, SQLAlchemy will insert (prepend) the function to the
+ internal list upon discovery. This feature is not typically used or
+ recommended by the SQLAlchemy maintainers, but is provided to ensure
+ certain user defined functions can run before others, such as when
+ :ref:`Changing the sql_mode in MySQL <mysql_sql_mode>`.
+
+ :param bool named: When using named argument passing, the names listed in
+ the function argument specification will be used as keys in the
+ dictionary.
+ See :ref:`event_named_argument_styles`.
+
+ :param bool once: Private/Internal API usage. Deprecated. This parameter
+      causes an event function to run only once per given
+      target. It does not, however, imply automatic de-registration of the
+ listener function; associating an arbitrarily high number of listeners
+ without explicitly removing them will cause memory to grow unbounded even
+ if ``once=True`` is specified.
+
+ :param bool propagate: The ``propagate`` kwarg is available when working
+ with ORM instrumentation and mapping events.
+ See :class:`_ormevent.MapperEvents` and
+ :meth:`_ormevent.MapperEvents.before_mapper_configured` for examples.
+
+ :param bool retval: This flag applies only to specific event listeners,
+ each of which includes documentation explaining when it should be used.
+ By default, no listener ever requires a return value.
+ However, some listeners do support special behaviors for return values,
+ and include in their documentation that the ``retval=True`` flag is
+ necessary for a return value to be processed.
+
+ Event listener suites that make use of :paramref:`_event.listen.retval`
+ include :class:`_events.ConnectionEvents` and
+ :class:`_ormevent.AttributeEvents`.
+
+ .. note::
+
+ The :func:`.listen` function cannot be called at the same time
+ that the target event is being run. This has implications
+ for thread safety, and also means an event cannot be added
+ from inside the listener function for itself. The list of
+      events to be run is present inside of a mutable collection
+ that can't be changed during iteration.
+
+ Event registration and removal is not intended to be a "high
+ velocity" operation; it is a configurational operation. For
+ systems that need to quickly associate and deassociate with
+ events at high scale, use a mutable structure that is handled
+ from inside of a single listener.
+
+ .. seealso::
+
+ :func:`.listens_for`
+
+ :func:`.remove`
+
+ """
+
+ _event_key(target, identifier, fn).listen(*args, **kw)
+
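+
+# Editorial illustration, not part of SQLAlchemy: a minimal sketch of the
+# ``named`` and ``insert`` keywords documented above, assuming an in-memory
+# SQLite engine; the engine and listener names are hypothetical.  Wrapped in
+# a function so that nothing runs at import time.
+def _example_listen_named_kwargs():
+    from sqlalchemy import create_engine, event
+
+    engine = create_engine("sqlite://")
+
+    def on_connect(**kw):
+        # with named=True, event arguments arrive keyed by parameter name
+        assert "dbapi_connection" in kw and "connection_record" in kw
+
+    # insert=True prepends the listener so it runs ahead of existing ones
+    event.listen(engine, "connect", on_connect, named=True, insert=True)
+    engine.connect().close()
+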
+
+def listens_for(target, identifier, *args, **kw):
+ """Decorate a function as a listener for the given target + identifier.
+
+ The :func:`.listens_for` decorator is part of the primary interface for the
+ SQLAlchemy event system, documented at :ref:`event_toplevel`.
+
+    This function generally shares the same kwargs as :func:`.listen`.
+
+ e.g.::
+
+ from sqlalchemy import event
+ from sqlalchemy.schema import UniqueConstraint
+
+ @event.listens_for(UniqueConstraint, "after_parent_attach")
+ def unique_constraint_name(const, table):
+ const.name = "uq_%s_%s" % (
+ table.name,
+ list(const.columns)[0].name
+ )
+
+ A given function can also be invoked for only the first invocation
+ of the event using the ``once`` argument::
+
+ @event.listens_for(Mapper, "before_configure", once=True)
+ def on_config():
+ do_config()
+
+
+ .. warning:: The ``once`` argument does not imply automatic de-registration
+ of the listener function after it has been invoked a first time; a
+ listener entry will remain associated with the target object.
+ Associating an arbitrarily high number of listeners without explicitly
+ removing them will cause memory to grow unbounded even if ``once=True``
+ is specified.
+
+ .. seealso::
+
+ :func:`.listen` - general description of event listening
+
+ """
+
+ def decorate(fn):
+ listen(target, identifier, fn, *args, **kw)
+ return fn
+
+ return decorate
+
+
+def remove(target, identifier, fn):
+ """Remove an event listener.
+
+ The arguments here should match exactly those which were sent to
+ :func:`.listen`; all the event registration which proceeded as a result
+ of this call will be reverted by calling :func:`.remove` with the same
+ arguments.
+
+ e.g.::
+
+ # if a function was registered like this...
+ @event.listens_for(SomeMappedClass, "before_insert", propagate=True)
+ def my_listener_function(*arg):
+ pass
+
+ # ... it's removed like this
+ event.remove(SomeMappedClass, "before_insert", my_listener_function)
+
+ Above, the listener function associated with ``SomeMappedClass`` was also
+ propagated to subclasses of ``SomeMappedClass``; the :func:`.remove`
+ function will revert all of these operations.
+
+ .. note::
+
+ The :func:`.remove` function cannot be called at the same time
+ that the target event is being run. This has implications
+ for thread safety, and also means an event cannot be removed
+ from inside the listener function for itself. The list of
+        events to be run is present inside of a mutable collection
+ that can't be changed during iteration.
+
+ Event registration and removal is not intended to be a "high
+ velocity" operation; it is a configurational operation. For
+ systems that need to quickly associate and deassociate with
+ events at high scale, use a mutable structure that is handled
+ from inside of a single listener.
+
+ .. versionchanged:: 1.0.0 - a ``collections.deque()`` object is now
+ used as the container for the list of events, which explicitly
+ disallows collection mutation while the collection is being
+ iterated.
+
+ .. seealso::
+
+ :func:`.listen`
+
+ """
+ _event_key(target, identifier, fn).remove()
+
+
+def contains(target, identifier, fn):
+ """Return True if the given target/ident/fn is set up to listen."""
+
+ return _event_key(target, identifier, fn).contains()
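+
+
+# Editorial illustration, not part of SQLAlchemy: a register / query /
+# deregister round trip through the functions above, assuming an in-memory
+# SQLite engine as a convenient event target.
+def _example_listen_contains_remove():
+    from sqlalchemy import create_engine, event
+
+    engine = create_engine("sqlite://")
+
+    def on_checkout(dbapi_connection, connection_record, connection_proxy):
+        pass
+
+    event.listen(engine, "checkout", on_checkout)
+    assert event.contains(engine, "checkout", on_checkout)
+
+    # remove() must receive the same target / identifier / fn as listen()
+    event.remove(engine, "checkout", on_checkout)
+    assert not event.contains(engine, "checkout", on_checkout)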
diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py
new file mode 100644
index 0000000..0d16165
--- /dev/null
+++ b/lib/sqlalchemy/event/attr.py
@@ -0,0 +1,468 @@
+# event/attr.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Attribute implementation for _Dispatch classes.
+
+The various listener targets for a particular event class are represented
+as attributes, which refer to collections of listeners to be fired off.
+These collections can exist at the class level as well as at the instance
+level. An event is fired off using code like this::
+
+ some_object.dispatch.first_connect(arg1, arg2)
+
+Above, ``some_object.dispatch`` would be an instance of ``_Dispatch`` and
+``first_connect`` is typically an instance of ``_ListenerCollection``
+if event listeners are present, or ``_EmptyListener`` if none are present.
+
+The attribute mechanics here spend effort trying to ensure listener functions
+are available with a minimum of function call overhead, that unnecessary
+objects aren't created (e.g. many empty per-instance listener collections),
+as well as that everything is garbage collectable when owning references are
+lost. Other features such as "propagation" of listener functions across
+many ``_Dispatch`` instances, "joining" of multiple ``_Dispatch`` instances,
+as well as support for subclass propagation (e.g. events assigned to
+``Pool`` vs. ``QueuePool``) are all implemented here.
+
+"""
+
+from __future__ import absolute_import
+from __future__ import with_statement
+
+import collections
+from itertools import chain
+import weakref
+
+from . import legacy
+from . import registry
+from .. import exc
+from .. import util
+from ..util import threading
+from ..util.concurrency import AsyncAdaptedLock
+
+
+class RefCollection(util.MemoizedSlots):
+ __slots__ = ("ref",)
+
+ def _memoized_attr_ref(self):
+ return weakref.ref(self, registry._collection_gced)
+
+
+class _empty_collection(object):
+ def append(self, element):
+ pass
+
+ def extend(self, other):
+ pass
+
+ def remove(self, element):
+ pass
+
+ def __iter__(self):
+ return iter([])
+
+ def clear(self):
+ pass
+
+
+class _ClsLevelDispatch(RefCollection):
+ """Class-level events on :class:`._Dispatch` classes."""
+
+ __slots__ = (
+ "clsname",
+ "name",
+ "arg_names",
+ "has_kw",
+ "legacy_signatures",
+ "_clslevel",
+ "__weakref__",
+ )
+
+ def __init__(self, parent_dispatch_cls, fn):
+ self.name = fn.__name__
+ self.clsname = parent_dispatch_cls.__name__
+ argspec = util.inspect_getfullargspec(fn)
+ self.arg_names = argspec.args[1:]
+ self.has_kw = bool(argspec.varkw)
+ self.legacy_signatures = list(
+ reversed(
+ sorted(
+ getattr(fn, "_legacy_signatures", []), key=lambda s: s[0]
+ )
+ )
+ )
+ fn.__doc__ = legacy._augment_fn_docs(self, parent_dispatch_cls, fn)
+
+ self._clslevel = weakref.WeakKeyDictionary()
+
+ def _adjust_fn_spec(self, fn, named):
+ if named:
+ fn = self._wrap_fn_for_kw(fn)
+ if self.legacy_signatures:
+ try:
+ argspec = util.get_callable_argspec(fn, no_self=True)
+ except TypeError:
+ pass
+ else:
+ fn = legacy._wrap_fn_for_legacy(self, fn, argspec)
+ return fn
+
+ def _wrap_fn_for_kw(self, fn):
+ def wrap_kw(*args, **kw):
+ argdict = dict(zip(self.arg_names, args))
+ argdict.update(kw)
+ return fn(**argdict)
+
+ return wrap_kw
+
+ def insert(self, event_key, propagate):
+ target = event_key.dispatch_target
+ assert isinstance(
+ target, type
+ ), "Class-level Event targets must be classes."
+ if not getattr(target, "_sa_propagate_class_events", True):
+ raise exc.InvalidRequestError(
+ "Can't assign an event directly to the %s class" % target
+ )
+
+ for cls in util.walk_subclasses(target):
+ if cls is not target and cls not in self._clslevel:
+ self.update_subclass(cls)
+ else:
+ if cls not in self._clslevel:
+ self._assign_cls_collection(cls)
+ self._clslevel[cls].appendleft(event_key._listen_fn)
+ registry._stored_in_collection(event_key, self)
+
+ def append(self, event_key, propagate):
+ target = event_key.dispatch_target
+ assert isinstance(
+ target, type
+ ), "Class-level Event targets must be classes."
+ if not getattr(target, "_sa_propagate_class_events", True):
+ raise exc.InvalidRequestError(
+ "Can't assign an event directly to the %s class" % target
+ )
+ for cls in util.walk_subclasses(target):
+ if cls is not target and cls not in self._clslevel:
+ self.update_subclass(cls)
+ else:
+ if cls not in self._clslevel:
+ self._assign_cls_collection(cls)
+ self._clslevel[cls].append(event_key._listen_fn)
+ registry._stored_in_collection(event_key, self)
+
+ def _assign_cls_collection(self, target):
+ if getattr(target, "_sa_propagate_class_events", True):
+ self._clslevel[target] = collections.deque()
+ else:
+ self._clslevel[target] = _empty_collection()
+
+ def update_subclass(self, target):
+ if target not in self._clslevel:
+ self._assign_cls_collection(target)
+ clslevel = self._clslevel[target]
+ for cls in target.__mro__[1:]:
+ if cls in self._clslevel:
+ clslevel.extend(
+ [fn for fn in self._clslevel[cls] if fn not in clslevel]
+ )
+
+ def remove(self, event_key):
+ target = event_key.dispatch_target
+ for cls in util.walk_subclasses(target):
+ if cls in self._clslevel:
+ self._clslevel[cls].remove(event_key._listen_fn)
+ registry._removed_from_collection(event_key, self)
+
+ def clear(self):
+ """Clear all class level listeners"""
+
+ to_clear = set()
+ for dispatcher in self._clslevel.values():
+ to_clear.update(dispatcher)
+ dispatcher.clear()
+ registry._clear(self, to_clear)
+
+ def for_modify(self, obj):
+ """Return an event collection which can be modified.
+
+ For _ClsLevelDispatch at the class level of
+ a dispatcher, this returns self.
+
+ """
+ return self
+
+
+class _InstanceLevelDispatch(RefCollection):
+ __slots__ = ()
+
+ def _adjust_fn_spec(self, fn, named):
+ return self.parent._adjust_fn_spec(fn, named)
+
+
+class _EmptyListener(_InstanceLevelDispatch):
+ """Serves as a proxy interface to the events
+ served by a _ClsLevelDispatch, when there are no
+ instance-level events present.
+
+ Is replaced by _ListenerCollection when instance-level
+ events are added.
+
+ """
+
+ propagate = frozenset()
+ listeners = ()
+
+ __slots__ = "parent", "parent_listeners", "name"
+
+ def __init__(self, parent, target_cls):
+ if target_cls not in parent._clslevel:
+ parent.update_subclass(target_cls)
+ self.parent = parent # _ClsLevelDispatch
+ self.parent_listeners = parent._clslevel[target_cls]
+ self.name = parent.name
+
+ def for_modify(self, obj):
+ """Return an event collection which can be modified.
+
+ For _EmptyListener at the instance level of
+ a dispatcher, this generates a new
+ _ListenerCollection, applies it to the instance,
+ and returns it.
+
+ """
+ result = _ListenerCollection(self.parent, obj._instance_cls)
+ if getattr(obj, self.name) is self:
+ setattr(obj, self.name, result)
+ else:
+ assert isinstance(getattr(obj, self.name), _JoinedListener)
+ return result
+
+ def _needs_modify(self, *args, **kw):
+ raise NotImplementedError("need to call for_modify()")
+
+ exec_once = (
+ exec_once_unless_exception
+ ) = insert = append = remove = clear = _needs_modify
+
+ def __call__(self, *args, **kw):
+ """Execute this event."""
+
+ for fn in self.parent_listeners:
+ fn(*args, **kw)
+
+ def __len__(self):
+ return len(self.parent_listeners)
+
+ def __iter__(self):
+ return iter(self.parent_listeners)
+
+ def __bool__(self):
+ return bool(self.parent_listeners)
+
+ __nonzero__ = __bool__
+
+
+class _CompoundListener(_InstanceLevelDispatch):
+ __slots__ = "_exec_once_mutex", "_exec_once", "_exec_w_sync_once"
+
+ def _set_asyncio(self):
+ self._exec_once_mutex = AsyncAdaptedLock()
+
+ def _memoized_attr__exec_once_mutex(self):
+ return threading.Lock()
+
+ def _exec_once_impl(self, retry_on_exception, *args, **kw):
+ with self._exec_once_mutex:
+ if not self._exec_once:
+ try:
+ self(*args, **kw)
+ exception = False
+ except:
+ exception = True
+ raise
+ finally:
+ if not exception or not retry_on_exception:
+ self._exec_once = True
+
+ def exec_once(self, *args, **kw):
+ """Execute this event, but only if it has not been
+ executed already for this collection."""
+
+ if not self._exec_once:
+ self._exec_once_impl(False, *args, **kw)
+
+ def exec_once_unless_exception(self, *args, **kw):
+ """Execute this event, but only if it has not been
+ executed already for this collection, or was called
+ by a previous exec_once_unless_exception call and
+ raised an exception.
+
+ If exec_once was already called, then this method will never run
+ the callable regardless of whether it raised or not.
+
+ .. versionadded:: 1.3.8
+
+ """
+ if not self._exec_once:
+ self._exec_once_impl(True, *args, **kw)
+
+ def _exec_w_sync_on_first_run(self, *args, **kw):
+ """Execute this event, and use a mutex if it has not been
+ executed already for this collection, or was called
+ by a previous _exec_w_sync_on_first_run call and
+ raised an exception.
+
+ If _exec_w_sync_on_first_run was already called and didn't raise an
+ exception, then a mutex is not used.
+
+ .. versionadded:: 1.4.11
+
+ """
+ if not self._exec_w_sync_once:
+ with self._exec_once_mutex:
+ try:
+ self(*args, **kw)
+ except:
+ raise
+ else:
+ self._exec_w_sync_once = True
+ else:
+ self(*args, **kw)
+
+ def __call__(self, *args, **kw):
+ """Execute this event."""
+
+ for fn in self.parent_listeners:
+ fn(*args, **kw)
+ for fn in self.listeners:
+ fn(*args, **kw)
+
+ def __len__(self):
+ return len(self.parent_listeners) + len(self.listeners)
+
+ def __iter__(self):
+ return chain(self.parent_listeners, self.listeners)
+
+ def __bool__(self):
+ return bool(self.listeners or self.parent_listeners)
+
+ __nonzero__ = __bool__
+
+
+class _ListenerCollection(_CompoundListener):
+ """Instance-level attributes on instances of :class:`._Dispatch`.
+
+ Represents a collection of listeners.
+
+ As of 0.7.9, _ListenerCollection is only first
+ created via the _EmptyListener.for_modify() method.
+
+ """
+
+ __slots__ = (
+ "parent_listeners",
+ "parent",
+ "name",
+ "listeners",
+ "propagate",
+ "__weakref__",
+ )
+
+ def __init__(self, parent, target_cls):
+ if target_cls not in parent._clslevel:
+ parent.update_subclass(target_cls)
+ self._exec_once = False
+ self._exec_w_sync_once = False
+ self.parent_listeners = parent._clslevel[target_cls]
+ self.parent = parent
+ self.name = parent.name
+ self.listeners = collections.deque()
+ self.propagate = set()
+
+ def for_modify(self, obj):
+ """Return an event collection which can be modified.
+
+ For _ListenerCollection at the instance level of
+ a dispatcher, this returns self.
+
+ """
+ return self
+
+ def _update(self, other, only_propagate=True):
+ """Populate from the listeners in another :class:`_Dispatch`
+ object."""
+
+ existing_listeners = self.listeners
+ existing_listener_set = set(existing_listeners)
+ self.propagate.update(other.propagate)
+ other_listeners = [
+ l
+ for l in other.listeners
+ if l not in existing_listener_set
+ and not only_propagate
+ or l in self.propagate
+ ]
+
+ existing_listeners.extend(other_listeners)
+
+ to_associate = other.propagate.union(other_listeners)
+ registry._stored_in_collection_multi(self, other, to_associate)
+
+ def insert(self, event_key, propagate):
+ if event_key.prepend_to_list(self, self.listeners):
+ if propagate:
+ self.propagate.add(event_key._listen_fn)
+
+ def append(self, event_key, propagate):
+ if event_key.append_to_list(self, self.listeners):
+ if propagate:
+ self.propagate.add(event_key._listen_fn)
+
+ def remove(self, event_key):
+ self.listeners.remove(event_key._listen_fn)
+ self.propagate.discard(event_key._listen_fn)
+ registry._removed_from_collection(event_key, self)
+
+ def clear(self):
+ registry._clear(self, self.listeners)
+ self.propagate.clear()
+ self.listeners.clear()
+
+
+class _JoinedListener(_CompoundListener):
+ __slots__ = "parent", "name", "local", "parent_listeners"
+
+ def __init__(self, parent, name, local):
+ self._exec_once = False
+ self.parent = parent
+ self.name = name
+ self.local = local
+ self.parent_listeners = self.local
+
+ @property
+ def listeners(self):
+ return getattr(self.parent, self.name)
+
+ def _adjust_fn_spec(self, fn, named):
+ return self.local._adjust_fn_spec(fn, named)
+
+ def for_modify(self, obj):
+ self.local = self.parent_listeners = self.local.for_modify(obj)
+ return self
+
+ def insert(self, event_key, propagate):
+ self.local.insert(event_key, propagate)
+
+ def append(self, event_key, propagate):
+ self.local.append(event_key, propagate)
+
+ def remove(self, event_key):
+ self.local.remove(event_key)
+
+ def clear(self):
+ raise NotImplementedError()
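+
+
+# Editorial illustration, not part of SQLAlchemy: a sketch of the subclass
+# propagation described in the module docstring, assuming the full library
+# is importable.  It peeks at the instance-level collection (an
+# _EmptyListener here, since no instance-level listeners exist), which is
+# internal API and may change between versions.
+def _example_subclass_propagation():
+    import sqlite3
+
+    from sqlalchemy import event
+    from sqlalchemy.pool import Pool, QueuePool
+
+    def on_connect(dbapi_connection, connection_record):
+        pass
+
+    # listen at the class level of the Pool base class
+    event.listen(Pool, "connect", on_connect)
+    try:
+        pool = QueuePool(lambda: sqlite3.connect(":memory:"))
+        # the QueuePool instance sees the listener through its own
+        # per-class collection, populated from Pool's collection
+        assert on_connect in pool.dispatch.connect
+    finally:
+        event.remove(Pool, "connect", on_connect)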
diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py
new file mode 100644
index 0000000..510e16b
--- /dev/null
+++ b/lib/sqlalchemy/event/base.py
@@ -0,0 +1,345 @@
+# event/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Base implementation classes.
+
+The public-facing ``Events`` serves as the base class for an event interface;
+its public attributes represent different kinds of events. These attributes
+are mirrored onto a ``_Dispatch`` class, which serves as a container for
+collections of listener functions. These collections are represented both
+at the class level of a particular ``_Dispatch`` class as well as within
+instances of ``_Dispatch``.
+
+"""
+from __future__ import absolute_import
+
+import weakref
+
+from .attr import _ClsLevelDispatch
+from .attr import _EmptyListener
+from .attr import _JoinedListener
+from .. import util
+
+
+_registrars = util.defaultdict(list)
+
+
+def _is_event_name(name):
+ # _sa_event prefix is special to support internal-only event names.
+ # most event names are just plain method names that aren't
+ # underscored.
+
+ return (
+ not name.startswith("_") and name != "dispatch"
+ ) or name.startswith("_sa_event")
+
+
+class _UnpickleDispatch(object):
+ """Serializable callable that re-generates an instance of
+ :class:`_Dispatch` given a particular :class:`.Events` subclass.
+
+ """
+
+ def __call__(self, _instance_cls):
+ for cls in _instance_cls.__mro__:
+ if "dispatch" in cls.__dict__:
+ return cls.__dict__["dispatch"].dispatch._for_class(
+ _instance_cls
+ )
+ else:
+ raise AttributeError("No class with a 'dispatch' member present.")
+
+
+class _Dispatch(object):
+ """Mirror the event listening definitions of an Events class with
+ listener collections.
+
+ Classes which define a "dispatch" member will return a
+ non-instantiated :class:`._Dispatch` subclass when the member
+ is accessed at the class level. When the "dispatch" member is
+ accessed at the instance level of its owner, an instance
+ of the :class:`._Dispatch` class is returned.
+
+ A :class:`._Dispatch` class is generated for each :class:`.Events`
+ class defined, by the :func:`._create_dispatcher_class` function.
+ The original :class:`.Events` classes remain untouched.
+ This decouples the construction of :class:`.Events` subclasses from
+ the implementation used by the event internals, and allows
+ inspecting tools like Sphinx to work in an unsurprising
+ way against the public API.
+
+ """
+
+ # In one ORM edge case, an attribute is added to _Dispatch,
+ # so __dict__ is used in just that case and potentially others.
+ __slots__ = "_parent", "_instance_cls", "__dict__", "_empty_listeners"
+
+ _empty_listener_reg = weakref.WeakKeyDictionary()
+
+ def __init__(self, parent, instance_cls=None):
+ self._parent = parent
+ self._instance_cls = instance_cls
+
+ if instance_cls:
+ try:
+ self._empty_listeners = self._empty_listener_reg[instance_cls]
+ except KeyError:
+ self._empty_listeners = self._empty_listener_reg[
+ instance_cls
+ ] = {
+ ls.name: _EmptyListener(ls, instance_cls)
+ for ls in parent._event_descriptors
+ }
+ else:
+ self._empty_listeners = {}
+
+ def __getattr__(self, name):
+ # Assign EmptyListeners as attributes on demand
+ # to reduce startup time for new dispatch objects.
+ try:
+ ls = self._empty_listeners[name]
+ except KeyError:
+ raise AttributeError(name)
+ else:
+ setattr(self, ls.name, ls)
+ return ls
+
+ @property
+ def _event_descriptors(self):
+ for k in self._event_names:
+ # Yield _ClsLevelDispatch related
+ # to relevant event name.
+ yield getattr(self, k)
+
+ @property
+ def _listen(self):
+ return self._events._listen
+
+ def _for_class(self, instance_cls):
+ return self.__class__(self, instance_cls)
+
+ def _for_instance(self, instance):
+ instance_cls = instance.__class__
+ return self._for_class(instance_cls)
+
+ def _join(self, other):
+ """Create a 'join' of this :class:`._Dispatch` and another.
+
+ This new dispatcher will dispatch events to both
+ :class:`._Dispatch` objects.
+
+ """
+ if "_joined_dispatch_cls" not in self.__class__.__dict__:
+ cls = type(
+ "Joined%s" % self.__class__.__name__,
+ (_JoinedDispatcher,),
+ {"__slots__": self._event_names},
+ )
+
+ self.__class__._joined_dispatch_cls = cls
+ return self._joined_dispatch_cls(self, other)
+
+ def __reduce__(self):
+ return _UnpickleDispatch(), (self._instance_cls,)
+
+ def _update(self, other, only_propagate=True):
+ """Populate from the listeners in another :class:`_Dispatch`
+ object."""
+ for ls in other._event_descriptors:
+ if isinstance(ls, _EmptyListener):
+ continue
+ getattr(self, ls.name).for_modify(self)._update(
+ ls, only_propagate=only_propagate
+ )
+
+ def _clear(self):
+ for ls in self._event_descriptors:
+ ls.for_modify(self).clear()
+
+
+class _EventMeta(type):
+ """Intercept new Event subclasses and create
+ associated _Dispatch classes."""
+
+ def __init__(cls, classname, bases, dict_):
+ _create_dispatcher_class(cls, classname, bases, dict_)
+ type.__init__(cls, classname, bases, dict_)
+
+
+def _create_dispatcher_class(cls, classname, bases, dict_):
+ """Create a :class:`._Dispatch` class corresponding to an
+ :class:`.Events` class."""
+
+    # there are all kinds of ways to do this,
+    # e.g. make a Dispatch class that shares the '_listen' method
+    # of the Event class; this is the straight monkeypatch.
+ if hasattr(cls, "dispatch"):
+ dispatch_base = cls.dispatch.__class__
+ else:
+ dispatch_base = _Dispatch
+
+ event_names = [k for k in dict_ if _is_event_name(k)]
+ dispatch_cls = type(
+ "%sDispatch" % classname, (dispatch_base,), {"__slots__": event_names}
+ )
+
+ dispatch_cls._event_names = event_names
+
+ dispatch_inst = cls._set_dispatch(cls, dispatch_cls)
+ for k in dispatch_cls._event_names:
+ setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k]))
+ _registrars[k].append(cls)
+
+ for super_ in dispatch_cls.__bases__:
+ if issubclass(super_, _Dispatch) and super_ is not _Dispatch:
+ for ls in super_._events.dispatch._event_descriptors:
+ setattr(dispatch_inst, ls.name, ls)
+ dispatch_cls._event_names.append(ls.name)
+
+ if getattr(cls, "_dispatch_target", None):
+ the_cls = cls._dispatch_target
+ if (
+ hasattr(the_cls, "__slots__")
+ and "_slots_dispatch" in the_cls.__slots__
+ ):
+ cls._dispatch_target.dispatch = slots_dispatcher(cls)
+ else:
+ cls._dispatch_target.dispatch = dispatcher(cls)
+
+
+def _remove_dispatcher(cls):
+ for k in cls.dispatch._event_names:
+ _registrars[k].remove(cls)
+ if not _registrars[k]:
+ del _registrars[k]
+
+
+class Events(util.with_metaclass(_EventMeta, object)):
+ """Define event listening functions for a particular target type."""
+
+ @staticmethod
+ def _set_dispatch(cls, dispatch_cls):
+ # This allows an Events subclass to define additional utility
+ # methods made available to the target via
+ # "self.dispatch._events.<utilitymethod>"
+ # @staticmethod to allow easy "super" calls while in a metaclass
+ # constructor.
+ cls.dispatch = dispatch_cls(None)
+ dispatch_cls._events = cls
+ return cls.dispatch
+
+ @classmethod
+ def _accept_with(cls, target):
+ def dispatch_is(*types):
+ return all(isinstance(target.dispatch, t) for t in types)
+
+ def dispatch_parent_is(t):
+ return isinstance(target.dispatch.parent, t)
+
+ # Mapper, ClassManager, Session override this to
+ # also accept classes, scoped_sessions, sessionmakers, etc.
+ if hasattr(target, "dispatch"):
+ if (
+ dispatch_is(cls.dispatch.__class__)
+ or dispatch_is(type, cls.dispatch.__class__)
+ or (
+ dispatch_is(_JoinedDispatcher)
+ and dispatch_parent_is(cls.dispatch.__class__)
+ )
+ ):
+ return target
+
+ @classmethod
+ def _listen(
+ cls,
+ event_key,
+ propagate=False,
+ insert=False,
+ named=False,
+ asyncio=False,
+ ):
+ event_key.base_listen(
+ propagate=propagate, insert=insert, named=named, asyncio=asyncio
+ )
+
+ @classmethod
+ def _remove(cls, event_key):
+ event_key.remove()
+
+ @classmethod
+ def _clear(cls):
+ cls.dispatch._clear()
+
+
+class _JoinedDispatcher(object):
+ """Represent a connection between two _Dispatch objects."""
+
+ __slots__ = "local", "parent", "_instance_cls"
+
+ def __init__(self, local, parent):
+ self.local = local
+ self.parent = parent
+ self._instance_cls = self.local._instance_cls
+
+ def __getattr__(self, name):
+ # Assign _JoinedListeners as attributes on demand
+ # to reduce startup time for new dispatch objects.
+ ls = getattr(self.local, name)
+ jl = _JoinedListener(self.parent, ls.name, ls)
+ setattr(self, ls.name, jl)
+ return jl
+
+ @property
+ def _listen(self):
+ return self.parent._listen
+
+ @property
+ def _events(self):
+ return self.parent._events
+
+
+class dispatcher(object):
+ """Descriptor used by target classes to
+ deliver the _Dispatch class at the class level
+ and produce new _Dispatch instances for target
+ instances.
+
+ """
+
+ def __init__(self, events):
+ self.dispatch = events.dispatch
+ self.events = events
+
+ def __get__(self, obj, cls):
+ if obj is None:
+ return self.dispatch
+
+ disp = self.dispatch._for_instance(obj)
+ try:
+ obj.__dict__["dispatch"] = disp
+ except AttributeError as ae:
+ util.raise_(
+ TypeError(
+ "target %r doesn't have __dict__, should it be "
+ "defining _slots_dispatch?" % (obj,)
+ ),
+ replace_context=ae,
+ )
+ return disp
+
+
+class slots_dispatcher(dispatcher):
+ def __get__(self, obj, cls):
+ if obj is None:
+ return self.dispatch
+
+ if hasattr(obj, "_slots_dispatch"):
+ return obj._slots_dispatch
+
+ disp = self.dispatch._for_instance(obj)
+ obj._slots_dispatch = disp
+ return disp
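+
+
+# Editorial illustration, not part of SQLAlchemy: a minimal custom event
+# target wired together with the classes above, in the same spirit as the
+# library's own PoolEvents / ConnectionEvents.  All names are hypothetical
+# and nothing here runs at import time.
+def _example_custom_event_target():
+    from .. import event
+
+    class JobRunnerEvents(Events):
+        """Events for the hypothetical ``JobRunner`` class."""
+
+        _target_class_doc = "SomeJobRunner"
+
+        def before_run(self, runner, job_name):
+            """Intercept a job before it runs."""
+
+    class JobRunner(object):
+        # class access returns the class-level _Dispatch instance;
+        # instance access returns a per-instance _Dispatch
+        dispatch = dispatcher(JobRunnerEvents)
+
+        def run(self, job_name):
+            self.dispatch.before_run(self, job_name)
+
+    calls = []
+
+    @event.listens_for(JobRunner, "before_run")
+    def on_before_run(runner, job_name):
+        calls.append(job_name)
+
+    JobRunner().run("nightly-report")
+    assert calls == ["nightly-report"]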
diff --git a/lib/sqlalchemy/event/legacy.py b/lib/sqlalchemy/event/legacy.py
new file mode 100644
index 0000000..d9f6ce5
--- /dev/null
+++ b/lib/sqlalchemy/event/legacy.py
@@ -0,0 +1,185 @@
+# event/legacy.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Routines to handle adaption of legacy call signatures,
+generation of deprecation notes and docstrings.
+
+"""
+
+from .. import util
+
+
+def _legacy_signature(since, argnames, converter=None):
+ def leg(fn):
+ if not hasattr(fn, "_legacy_signatures"):
+ fn._legacy_signatures = []
+ fn._legacy_signatures.append((since, argnames, converter))
+ return fn
+
+ return leg
+
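+
+# Editorial illustration, not part of SQLAlchemy: a hypothetical event whose
+# signature gained an ``execution_options`` argument in version "1.4".  The
+# decorator above records the older two-argument form so listeners written
+# against it can be adapted (with a deprecation warning) by
+# _wrap_fn_for_legacy below.
+def _example_legacy_signature():
+    from .base import Events
+
+    class SomeEvents(Events):
+        @_legacy_signature("1.4", ["conn", "statement"])
+        def before_run(self, conn, statement, execution_options):
+            """Intercept a statement before it runs."""
+
+    # the decorator simply records the old signature on the function
+    assert SomeEvents.before_run._legacy_signatures == [
+        ("1.4", ["conn", "statement"], None)
+    ]
+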
+
+def _wrap_fn_for_legacy(dispatch_collection, fn, argspec):
+ for since, argnames, conv in dispatch_collection.legacy_signatures:
+ if argnames[-1] == "**kw":
+ has_kw = True
+ argnames = argnames[0:-1]
+ else:
+ has_kw = False
+
+ if len(argnames) == len(argspec.args) and has_kw is bool(
+ argspec.varkw
+ ):
+
+ formatted_def = "def %s(%s%s)" % (
+ dispatch_collection.name,
+ ", ".join(dispatch_collection.arg_names),
+ ", **kw" if has_kw else "",
+ )
+ warning_txt = (
+ 'The argument signature for the "%s.%s" event listener '
+ "has changed as of version %s, and conversion for "
+ "the old argument signature will be removed in a "
+ 'future release. The new signature is "%s"'
+ % (
+ dispatch_collection.clsname,
+ dispatch_collection.name,
+ since,
+ formatted_def,
+ )
+ )
+
+ if conv:
+ assert not has_kw
+
+ def wrap_leg(*args):
+ util.warn_deprecated(warning_txt, version=since)
+ return fn(*conv(*args))
+
+ else:
+
+ def wrap_leg(*args, **kw):
+ util.warn_deprecated(warning_txt, version=since)
+ argdict = dict(zip(dispatch_collection.arg_names, args))
+ args = [argdict[name] for name in argnames]
+ if has_kw:
+ return fn(*args, **kw)
+ else:
+ return fn(*args)
+
+ return wrap_leg
+ else:
+ return fn
+
+
+def _indent(text, indent):
+ return "\n".join(indent + line for line in text.split("\n"))
+
+
+def _standard_listen_example(dispatch_collection, sample_target, fn):
+ example_kw_arg = _indent(
+ "\n".join(
+ "%(arg)s = kw['%(arg)s']" % {"arg": arg}
+ for arg in dispatch_collection.arg_names[0:2]
+ ),
+ " ",
+ )
+ if dispatch_collection.legacy_signatures:
+ current_since = max(
+ since
+ for since, args, conv in dispatch_collection.legacy_signatures
+ )
+ else:
+ current_since = None
+ text = (
+ "from sqlalchemy import event\n\n\n"
+ "@event.listens_for(%(sample_target)s, '%(event_name)s')\n"
+ "def receive_%(event_name)s("
+ "%(named_event_arguments)s%(has_kw_arguments)s):\n"
+ " \"listen for the '%(event_name)s' event\"\n"
+ "\n # ... (event handling logic) ...\n"
+ )
+
+ text %= {
+ "current_since": " (arguments as of %s)" % current_since
+ if current_since
+ else "",
+ "event_name": fn.__name__,
+ "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "",
+ "named_event_arguments": ", ".join(dispatch_collection.arg_names),
+ "example_kw_arg": example_kw_arg,
+ "sample_target": sample_target,
+ }
+ return text
+
+
+def _legacy_listen_examples(dispatch_collection, sample_target, fn):
+ text = ""
+ for since, args, conv in dispatch_collection.legacy_signatures:
+ text += (
+ "\n# DEPRECATED calling style (pre-%(since)s, "
+ "will be removed in a future release)\n"
+ "@event.listens_for(%(sample_target)s, '%(event_name)s')\n"
+ "def receive_%(event_name)s("
+ "%(named_event_arguments)s%(has_kw_arguments)s):\n"
+ " \"listen for the '%(event_name)s' event\"\n"
+ "\n # ... (event handling logic) ...\n"
+ % {
+ "since": since,
+ "event_name": fn.__name__,
+ "has_kw_arguments": " **kw"
+ if dispatch_collection.has_kw
+ else "",
+ "named_event_arguments": ", ".join(args),
+ "sample_target": sample_target,
+ }
+ )
+ return text
+
+
+def _version_signature_changes(parent_dispatch_cls, dispatch_collection):
+ since, args, conv = dispatch_collection.legacy_signatures[0]
+ return (
+ "\n.. deprecated:: %(since)s\n"
+ " The :class:`.%(clsname)s.%(event_name)s` event now accepts the \n"
+ " arguments ``%(named_event_arguments)s%(has_kw_arguments)s``.\n"
+ " Support for listener functions which accept the previous \n"
+ ' argument signature(s) listed above as "deprecated" will be \n'
+ " removed in a future release."
+ % {
+ "since": since,
+ "clsname": parent_dispatch_cls.__name__,
+ "event_name": dispatch_collection.name,
+ "named_event_arguments": ", ".join(dispatch_collection.arg_names),
+ "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "",
+ }
+ )
+
+
+def _augment_fn_docs(dispatch_collection, parent_dispatch_cls, fn):
+ header = (
+ ".. container:: event_signatures\n\n"
+ " Example argument forms::\n"
+ "\n"
+ )
+
+ sample_target = getattr(parent_dispatch_cls, "_target_class_doc", "obj")
+ text = header + _indent(
+ _standard_listen_example(dispatch_collection, sample_target, fn),
+ " " * 8,
+ )
+ if dispatch_collection.legacy_signatures:
+ text += _indent(
+ _legacy_listen_examples(dispatch_collection, sample_target, fn),
+ " " * 8,
+ )
+
+        text += _version_signature_changes(
+            parent_dispatch_cls, dispatch_collection
+        )
+
+ return util.inject_docstring_text(fn.__doc__, text, 1)
diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py
new file mode 100644
index 0000000..ac143c4
--- /dev/null
+++ b/lib/sqlalchemy/event/registry.py
@@ -0,0 +1,297 @@
+# event/registry.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Provides managed registration services on behalf of :func:`.listen`
+arguments.
+
+By "managed registration", we mean that event listening functions and
+other objects can be added to various collections in such a way that their
+membership in all those collections can be revoked at once, based on
+an equivalent :class:`._EventKey`.
+
+"""
+
+from __future__ import absolute_import
+
+import collections
+import types
+import weakref
+
+from .. import exc
+from .. import util
+
+
+_key_to_collection = collections.defaultdict(dict)
+"""
+Given an original listen() argument, can locate all
+listener collections and the listener fn contained
+
+(target, identifier, fn) -> {
+ ref(listenercollection) -> ref(listener_fn)
+ ref(listenercollection) -> ref(listener_fn)
+ ref(listenercollection) -> ref(listener_fn)
+ }
+"""
+
+_collection_to_key = collections.defaultdict(dict)
+"""
+Given a _ListenerCollection or _ClsLevelListener, can locate
+all the original listen() arguments and the listener fn contained
+
+ref(listenercollection) -> {
+ ref(listener_fn) -> (target, identifier, fn),
+ ref(listener_fn) -> (target, identifier, fn),
+ ref(listener_fn) -> (target, identifier, fn),
+ }
+"""
+
+
+def _collection_gced(ref):
+ # defaultdict, so can't get a KeyError
+ if not _collection_to_key or ref not in _collection_to_key:
+ return
+ listener_to_key = _collection_to_key.pop(ref)
+ for key in listener_to_key.values():
+ if key in _key_to_collection:
+ # defaultdict, so can't get a KeyError
+ dispatch_reg = _key_to_collection[key]
+ dispatch_reg.pop(ref)
+ if not dispatch_reg:
+ _key_to_collection.pop(key)
+
+
+def _stored_in_collection(event_key, owner):
+ key = event_key._key
+
+ dispatch_reg = _key_to_collection[key]
+
+ owner_ref = owner.ref
+ listen_ref = weakref.ref(event_key._listen_fn)
+
+ if owner_ref in dispatch_reg:
+ return False
+
+ dispatch_reg[owner_ref] = listen_ref
+
+ listener_to_key = _collection_to_key[owner_ref]
+ listener_to_key[listen_ref] = key
+
+ return True
+
+
+def _removed_from_collection(event_key, owner):
+ key = event_key._key
+
+ dispatch_reg = _key_to_collection[key]
+
+ listen_ref = weakref.ref(event_key._listen_fn)
+
+ owner_ref = owner.ref
+ dispatch_reg.pop(owner_ref, None)
+ if not dispatch_reg:
+ del _key_to_collection[key]
+
+ if owner_ref in _collection_to_key:
+ listener_to_key = _collection_to_key[owner_ref]
+ listener_to_key.pop(listen_ref)
+
+
+def _stored_in_collection_multi(newowner, oldowner, elements):
+ if not elements:
+ return
+
+ oldowner = oldowner.ref
+ newowner = newowner.ref
+
+ old_listener_to_key = _collection_to_key[oldowner]
+ new_listener_to_key = _collection_to_key[newowner]
+
+ for listen_fn in elements:
+ listen_ref = weakref.ref(listen_fn)
+ try:
+ key = old_listener_to_key[listen_ref]
+ except KeyError:
+ # can occur during interpreter shutdown.
+ # see #6740
+ continue
+
+ try:
+ dispatch_reg = _key_to_collection[key]
+ except KeyError:
+ continue
+
+ if newowner in dispatch_reg:
+ assert dispatch_reg[newowner] == listen_ref
+ else:
+ dispatch_reg[newowner] = listen_ref
+
+ new_listener_to_key[listen_ref] = key
+
+
+def _clear(owner, elements):
+ if not elements:
+ return
+
+ owner = owner.ref
+ listener_to_key = _collection_to_key[owner]
+ for listen_fn in elements:
+ listen_ref = weakref.ref(listen_fn)
+ key = listener_to_key[listen_ref]
+ dispatch_reg = _key_to_collection[key]
+ dispatch_reg.pop(owner, None)
+
+ if not dispatch_reg:
+ del _key_to_collection[key]
+
+
+class _EventKey(object):
+ """Represent :func:`.listen` arguments."""
+
+ __slots__ = (
+ "target",
+ "identifier",
+ "fn",
+ "fn_key",
+ "fn_wrap",
+ "dispatch_target",
+ )
+
+ def __init__(self, target, identifier, fn, dispatch_target, _fn_wrap=None):
+ self.target = target
+ self.identifier = identifier
+ self.fn = fn
+ if isinstance(fn, types.MethodType):
+ self.fn_key = id(fn.__func__), id(fn.__self__)
+ else:
+ self.fn_key = id(fn)
+ self.fn_wrap = _fn_wrap
+ self.dispatch_target = dispatch_target
+
+ @property
+ def _key(self):
+ return (id(self.target), self.identifier, self.fn_key)
+
+ def with_wrapper(self, fn_wrap):
+ if fn_wrap is self._listen_fn:
+ return self
+ else:
+ return _EventKey(
+ self.target,
+ self.identifier,
+ self.fn,
+ self.dispatch_target,
+ _fn_wrap=fn_wrap,
+ )
+
+ def with_dispatch_target(self, dispatch_target):
+ if dispatch_target is self.dispatch_target:
+ return self
+ else:
+ return _EventKey(
+ self.target,
+ self.identifier,
+ self.fn,
+ dispatch_target,
+ _fn_wrap=self.fn_wrap,
+ )
+
+ def listen(self, *args, **kw):
+ once = kw.pop("once", False)
+ once_unless_exception = kw.pop("_once_unless_exception", False)
+ named = kw.pop("named", False)
+
+ target, identifier, fn = (
+ self.dispatch_target,
+ self.identifier,
+ self._listen_fn,
+ )
+
+ dispatch_collection = getattr(target.dispatch, identifier)
+
+ adjusted_fn = dispatch_collection._adjust_fn_spec(fn, named)
+
+ self = self.with_wrapper(adjusted_fn)
+
+ stub_function = getattr(
+ self.dispatch_target.dispatch._events, self.identifier
+ )
+ if hasattr(stub_function, "_sa_warn"):
+ stub_function._sa_warn()
+
+ if once or once_unless_exception:
+ self.with_wrapper(
+ util.only_once(
+ self._listen_fn, retry_on_exception=once_unless_exception
+ )
+ ).listen(*args, **kw)
+ else:
+ self.dispatch_target.dispatch._listen(self, *args, **kw)
+
+ def remove(self):
+ key = self._key
+
+ if key not in _key_to_collection:
+ raise exc.InvalidRequestError(
+ "No listeners found for event %s / %r / %s "
+ % (self.target, self.identifier, self.fn)
+ )
+
+ dispatch_reg = _key_to_collection.pop(key)
+
+ for collection_ref, listener_ref in dispatch_reg.items():
+ collection = collection_ref()
+ listener_fn = listener_ref()
+ if collection is not None and listener_fn is not None:
+ collection.remove(self.with_wrapper(listener_fn))
+
+ def contains(self):
+ """Return True if this event key is registered to listen."""
+ return self._key in _key_to_collection
+
+ def base_listen(
+ self,
+ propagate=False,
+ insert=False,
+ named=False,
+ retval=None,
+ asyncio=False,
+ ):
+
+ target, identifier = self.dispatch_target, self.identifier
+
+ dispatch_collection = getattr(target.dispatch, identifier)
+
+ for_modify = dispatch_collection.for_modify(target.dispatch)
+ if asyncio:
+ for_modify._set_asyncio()
+
+ if insert:
+ for_modify.insert(self, propagate)
+ else:
+ for_modify.append(self, propagate)
+
+ @property
+ def _listen_fn(self):
+ return self.fn_wrap or self.fn
+
+ def append_to_list(self, owner, list_):
+ if _stored_in_collection(self, owner):
+ list_.append(self._listen_fn)
+ return True
+ else:
+ return False
+
+ def remove_from_list(self, owner, list_):
+ _removed_from_collection(self, owner)
+ list_.remove(self._listen_fn)
+
+ def prepend_to_list(self, owner, list_):
+ if _stored_in_collection(self, owner):
+ list_.appendleft(self._listen_fn)
+ return True
+ else:
+ return False
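+
+
+# Editorial illustration, not part of SQLAlchemy: the bookkeeping described
+# in the module docstring, observed through the private _key_to_collection
+# mapping (internal API; subject to change).  Assumes an in-memory SQLite
+# engine as a convenient event target.
+def _example_registry_round_trip():
+    from sqlalchemy import create_engine, event
+
+    engine = create_engine("sqlite://")
+
+    def on_checkout(dbapi_connection, connection_record, connection_proxy):
+        pass
+
+    event.listen(engine, "checkout", on_checkout)
+    # the key is (id(target), identifier, id(fn)) for a plain function
+    key = (id(engine), "checkout", id(on_checkout))
+    assert key in _key_to_collection
+
+    event.remove(engine, "checkout", on_checkout)
+    assert key not in _key_to_collection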
diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py
new file mode 100644
index 0000000..d17b0b1
--- /dev/null
+++ b/lib/sqlalchemy/events.py
@@ -0,0 +1,14 @@
+# sqlalchemy/events.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Core event interfaces."""
+
+from .engine.events import ConnectionEvents
+from .engine.events import DialectEvents
+from .pool.events import PoolEvents
+from .sql.base import SchemaEventTarget
+from .sql.events import DDLEvents
diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py
new file mode 100644
index 0000000..78bcef3
--- /dev/null
+++ b/lib/sqlalchemy/exc.py
@@ -0,0 +1,733 @@
+# sqlalchemy/exc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Exceptions used with SQLAlchemy.
+
+The base exception class is :exc:`.SQLAlchemyError`. Exceptions which are
+raised as a result of DBAPI exceptions are all subclasses of
+:exc:`.DBAPIError`.
+
+"""
+
+from .util import _preloaded
+from .util import compat
+
+_version_token = None
+
+
+class HasDescriptionCode(object):
+ """helper which adds 'code' as an attribute and '_code_str' as a method"""
+
+ code = None
+
+ def __init__(self, *arg, **kw):
+ code = kw.pop("code", None)
+ if code is not None:
+ self.code = code
+ super(HasDescriptionCode, self).__init__(*arg, **kw)
+
+ def _code_str(self):
+ if not self.code:
+ return ""
+ else:
+ return (
+ "(Background on this error at: "
+ "https://sqlalche.me/e/%s/%s)"
+ % (
+ _version_token,
+ self.code,
+ )
+ )
+
+ def __str__(self):
+ message = super(HasDescriptionCode, self).__str__()
+ if self.code:
+ message = "%s %s" % (message, self._code_str())
+ return message
+
+
+class SQLAlchemyError(HasDescriptionCode, Exception):
+ """Generic error class."""
+
+ def _message(self, as_unicode=compat.py3k):
+ # rules:
+ #
+ # 1. under py2k, for __str__ return single string arg as it was
+ # given without converting to unicode. for __unicode__
+ # do a conversion but check that it's not unicode already just in
+ # case
+ #
+ # 2. under py3k, single arg string will usually be a unicode
+ # object, but since __str__() must return unicode, check for
+ # bytestring just in case
+ #
+ # 3. for multiple self.args, this is not a case in current
+ # SQLAlchemy though this is happening in at least one known external
+ # library, call str() which does a repr().
+ #
+ if len(self.args) == 1:
+ text = self.args[0]
+
+ if as_unicode and isinstance(text, compat.binary_types):
+ text = compat.decode_backslashreplace(text, "utf-8")
+ # This is for when the argument is not a string of any sort.
+ # Otherwise, converting this exception to string would fail for
+ # non-string arguments.
+ elif compat.py3k or not as_unicode:
+ text = str(text)
+ else:
+ text = compat.text_type(text)
+
+ return text
+ else:
+ # this is not a normal case within SQLAlchemy but is here for
+ # compatibility with Exception.args - the str() comes out as
+ # a repr() of the tuple
+ return str(self.args)
+
+ def _sql_message(self, as_unicode):
+ message = self._message(as_unicode)
+
+ if self.code:
+ message = "%s %s" % (message, self._code_str())
+
+ return message
+
+ def __str__(self):
+ return self._sql_message(compat.py3k)
+
+ def __unicode__(self):
+ return self._sql_message(as_unicode=True)
+
+
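+# Editorial illustration, not part of SQLAlchemy: how the ``code`` handled by
+# HasDescriptionCode is rendered into the exception string.  The code value
+# below is made up; real codes are class-level attributes such as
+# OperationalError.code, and the version token in the URL is filled in by
+# the top-level sqlalchemy package at import time.
+def _example_error_code_string():
+    err = SQLAlchemyError("something went wrong", code="xyz1")
+    assert "sqlalche.me" in str(err)
+    assert "xyz1" in str(err)
+
+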
+class ArgumentError(SQLAlchemyError):
+ """Raised when an invalid or conflicting function argument is supplied.
+
+ This error generally corresponds to construction time state errors.
+
+ """
+
+
+class ObjectNotExecutableError(ArgumentError):
+ """Raised when an object is passed to .execute() that can't be
+ executed as SQL.
+
+ .. versionadded:: 1.1
+
+ """
+
+ def __init__(self, target):
+ super(ObjectNotExecutableError, self).__init__(
+ "Not an executable object: %r" % target
+ )
+ self.target = target
+
+ def __reduce__(self):
+ return self.__class__, (self.target,)
+
+
+class NoSuchModuleError(ArgumentError):
+ """Raised when a dynamically-loaded module (usually a database dialect)
+ of a particular name cannot be located."""
+
+
+class NoForeignKeysError(ArgumentError):
+ """Raised when no foreign keys can be located between two selectables
+ during a join."""
+
+
+class AmbiguousForeignKeysError(ArgumentError):
+ """Raised when more than one foreign key matching can be located
+ between two selectables during a join."""
+
+
+class CircularDependencyError(SQLAlchemyError):
+ """Raised by topological sorts when a circular dependency is detected.
+
+ There are two scenarios where this error occurs:
+
+ * In a Session flush operation, if two objects are mutually dependent
+ on each other, they can not be inserted or deleted via INSERT or
+ DELETE statements alone; an UPDATE will be needed to post-associate
+ or pre-deassociate one of the foreign key constrained values.
+ The ``post_update`` flag described at :ref:`post_update` can resolve
+ this cycle.
+ * In a :attr:`_schema.MetaData.sorted_tables` operation, two
+ :class:`_schema.ForeignKey`
+ or :class:`_schema.ForeignKeyConstraint` objects mutually refer to each
+ other. Apply the ``use_alter=True`` flag to one or both,
+ see :ref:`use_alter`.
+
+ """
+
+ def __init__(self, message, cycles, edges, msg=None, code=None):
+ if msg is None:
+ message += " (%s)" % ", ".join(repr(s) for s in cycles)
+ else:
+ message = msg
+ SQLAlchemyError.__init__(self, message, code=code)
+ self.cycles = cycles
+ self.edges = edges
+
+ def __reduce__(self):
+ return (
+ self.__class__,
+ (None, self.cycles, self.edges, self.args[0]),
+ {"code": self.code} if self.code is not None else {},
+ )
+
+
+class CompileError(SQLAlchemyError):
+ """Raised when an error occurs during SQL compilation"""
+
+
+class UnsupportedCompilationError(CompileError):
+ """Raised when an operation is not supported by the given compiler.
+
+ .. seealso::
+
+ :ref:`faq_sql_expression_string`
+
+ :ref:`error_l7de`
+ """
+
+ code = "l7de"
+
+ def __init__(self, compiler, element_type, message=None):
+ super(UnsupportedCompilationError, self).__init__(
+ "Compiler %r can't render element of type %s%s"
+ % (compiler, element_type, ": %s" % message if message else "")
+ )
+ self.compiler = compiler
+ self.element_type = element_type
+ self.message = message
+
+ def __reduce__(self):
+ return self.__class__, (self.compiler, self.element_type, self.message)
+
+
+class IdentifierError(SQLAlchemyError):
+ """Raised when a schema name is beyond the max character limit"""
+
+
+class DisconnectionError(SQLAlchemyError):
+ """A disconnect is detected on a raw DB-API connection.
+
+ This error is raised and consumed internally by a connection pool. It can
+ be raised by the :meth:`_events.PoolEvents.checkout`
+ event so that the host pool
+ forces a retry; the exception will be caught three times in a row before
+ the pool gives up and raises :class:`~sqlalchemy.exc.InvalidRequestError`
+ regarding the connection attempt.
+
+ """
+
+ invalidate_pool = False
+
+
+class InvalidatePoolError(DisconnectionError):
+ """Raised when the connection pool should invalidate all stale connections.
+
+ A subclass of :class:`_exc.DisconnectionError` that indicates that the
+ disconnect situation encountered on the connection probably means the
+ entire pool should be invalidated, as the database has been restarted.
+
+    This exception is otherwise handled in the same way as
+ :class:`_exc.DisconnectionError`, allowing three attempts to reconnect
+ before giving up.
+
+ .. versionadded:: 1.2
+
+ """
+
+ invalidate_pool = True
+
+
+class TimeoutError(SQLAlchemyError): # noqa
+ """Raised when a connection pool times out on getting a connection."""
+
+
+class InvalidRequestError(SQLAlchemyError):
+ """SQLAlchemy was asked to do something it can't do.
+
+ This error generally corresponds to runtime state errors.
+
+ """
+
+
+class NoInspectionAvailable(InvalidRequestError):
+ """A subject passed to :func:`sqlalchemy.inspection.inspect` produced
+ no context for inspection."""
+
+
+class PendingRollbackError(InvalidRequestError):
+ """A transaction has failed and needs to be rolled back before
+ continuing.
+
+ .. versionadded:: 1.4
+
+ """
+
+
+class ResourceClosedError(InvalidRequestError):
+ """An operation was requested from a connection, cursor, or other
+ object that's in a closed state."""
+
+
+class NoSuchColumnError(InvalidRequestError, KeyError):
+ """A nonexistent column is requested from a ``Row``."""
+
+
+class NoResultFound(InvalidRequestError):
+ """A database result was required but none was found.
+
+
+ .. versionchanged:: 1.4 This exception is now part of the
+ ``sqlalchemy.exc`` module in Core, moved from the ORM. The symbol
+ remains importable from ``sqlalchemy.orm.exc``.
+
+
+ """
+
+
+class MultipleResultsFound(InvalidRequestError):
+ """A single database result was required but more than one were found.
+
+ .. versionchanged:: 1.4 This exception is now part of the
+ ``sqlalchemy.exc`` module in Core, moved from the ORM. The symbol
+ remains importable from ``sqlalchemy.orm.exc``.
+
+
+ """
+
+
+class NoReferenceError(InvalidRequestError):
+ """Raised by ``ForeignKey`` to indicate a reference cannot be resolved."""
+
+
+class AwaitRequired(InvalidRequestError):
+ """Error raised by the async greenlet spawn if no async operation
+ was awaited when it required one.
+
+ """
+
+ code = "xd1r"
+
+
+class MissingGreenlet(InvalidRequestError):
+ r"""Error raised by the async greenlet await\_ if called while not inside
+ the greenlet spawn context.
+
+ """
+
+ code = "xd2s"
+
+
+class NoReferencedTableError(NoReferenceError):
+ """Raised by ``ForeignKey`` when the referred ``Table`` cannot be
+ located.
+
+ """
+
+ def __init__(self, message, tname):
+ NoReferenceError.__init__(self, message)
+ self.table_name = tname
+
+ def __reduce__(self):
+ return self.__class__, (self.args[0], self.table_name)
+
+
+class NoReferencedColumnError(NoReferenceError):
+ """Raised by ``ForeignKey`` when the referred ``Column`` cannot be
+ located.
+
+ """
+
+ def __init__(self, message, tname, cname):
+ NoReferenceError.__init__(self, message)
+ self.table_name = tname
+ self.column_name = cname
+
+ def __reduce__(self):
+ return (
+ self.__class__,
+ (self.args[0], self.table_name, self.column_name),
+ )
+
+
+class NoSuchTableError(InvalidRequestError):
+ """Table does not exist or is not visible to a connection."""
+
+
+class UnreflectableTableError(InvalidRequestError):
+ """Table exists but can't be reflected for some reason.
+
+ .. versionadded:: 1.2
+
+ """
+
+
+class UnboundExecutionError(InvalidRequestError):
+ """SQL was attempted without a database connection to execute it on."""
+
+
+class DontWrapMixin(object):
+ """A mixin class which, when applied to a user-defined Exception class,
+ will not be wrapped inside of :exc:`.StatementError` if the error is
+ emitted within the process of executing a statement.
+
+ E.g.::
+
+ from sqlalchemy.exc import DontWrapMixin
+
+ class MyCustomException(Exception, DontWrapMixin):
+ pass
+
+ class MySpecialType(TypeDecorator):
+ impl = String
+
+ def process_bind_param(self, value, dialect):
+ if value == 'invalid':
+ raise MyCustomException("invalid!")
+
+ """
+
+
+class StatementError(SQLAlchemyError):
+ """An error occurred during execution of a SQL statement.
+
+ :class:`StatementError` wraps the exception raised
+ during execution, and features :attr:`.statement`
+ and :attr:`.params` attributes which supply context regarding
+ the specifics of the statement which had an issue.
+
+ The wrapped exception object is available in
+ the :attr:`.orig` attribute.
+
+ """
+
+ statement = None
+ """The string SQL statement being invoked when this exception occurred."""
+
+ params = None
+ """The parameter list being used when this exception occurred."""
+
+ orig = None
+ """The DBAPI exception object."""
+
+ ismulti = None
+
+ def __init__(
+ self,
+ message,
+ statement,
+ params,
+ orig,
+ hide_parameters=False,
+ code=None,
+ ismulti=None,
+ ):
+ SQLAlchemyError.__init__(self, message, code=code)
+ self.statement = statement
+ self.params = params
+ self.orig = orig
+ self.ismulti = ismulti
+ self.hide_parameters = hide_parameters
+ self.detail = []
+
+ def add_detail(self, msg):
+ self.detail.append(msg)
+
+ def __reduce__(self):
+ return (
+ self.__class__,
+ (
+ self.args[0],
+ self.statement,
+ self.params,
+ self.orig,
+ self.hide_parameters,
+ self.__dict__.get("code"),
+ self.ismulti,
+ ),
+ {"detail": self.detail},
+ )
+
+ @_preloaded.preload_module("sqlalchemy.sql.util")
+ def _sql_message(self, as_unicode):
+ util = _preloaded.preloaded.sql_util
+
+ details = [self._message(as_unicode=as_unicode)]
+ if self.statement:
+ if not as_unicode and not compat.py3k:
+ stmt_detail = "[SQL: %s]" % compat.safe_bytestring(
+ self.statement
+ )
+ else:
+ stmt_detail = "[SQL: %s]" % self.statement
+ details.append(stmt_detail)
+ if self.params:
+ if self.hide_parameters:
+ details.append(
+ "[SQL parameters hidden due to hide_parameters=True]"
+ )
+ else:
+ params_repr = util._repr_params(
+ self.params, 10, ismulti=self.ismulti
+ )
+ details.append("[parameters: %r]" % params_repr)
+ code_str = self._code_str()
+ if code_str:
+ details.append(code_str)
+ return "\n".join(["(%s)" % det for det in self.detail] + details)
+
+
+class DBAPIError(StatementError):
+ """Raised when the execution of a database operation fails.
+
+ Wraps exceptions raised by the DB-API underlying the
+ database operation. Driver-specific implementations of the standard
+ DB-API exception types are wrapped by matching sub-types of SQLAlchemy's
+ :class:`DBAPIError` when possible. DB-API's ``Error`` type maps to
+ :class:`DBAPIError` in SQLAlchemy, otherwise the names are identical. Note
+ that there is no guarantee that different DB-API implementations will
+ raise the same exception type for any given error condition.
+
+ :class:`DBAPIError` features :attr:`~.StatementError.statement`
+ and :attr:`~.StatementError.params` attributes which supply context
+ regarding the specifics of the statement which had an issue, for the
+ typical case when the error was raised within the context of
+ emitting a SQL statement.
+
+ The wrapped exception object is available in the
+ :attr:`~.StatementError.orig` attribute. Its type and properties are
+ DB-API implementation specific.
+
+ """
+
+ code = "dbapi"
+
+ @classmethod
+ def instance(
+ cls,
+ statement,
+ params,
+ orig,
+ dbapi_base_err,
+ hide_parameters=False,
+ connection_invalidated=False,
+ dialect=None,
+ ismulti=None,
+ ):
+ # Don't ever wrap these, just return them directly as if
+ # DBAPIError didn't exist.
+ if (
+ isinstance(orig, BaseException) and not isinstance(orig, Exception)
+ ) or isinstance(orig, DontWrapMixin):
+ return orig
+
+ if orig is not None:
+ # not a DBAPI error, statement is present.
+ # raise a StatementError
+ if isinstance(orig, SQLAlchemyError) and statement:
+ return StatementError(
+ "(%s.%s) %s"
+ % (
+ orig.__class__.__module__,
+ orig.__class__.__name__,
+ orig.args[0],
+ ),
+ statement,
+ params,
+ orig,
+ hide_parameters=hide_parameters,
+ code=orig.code,
+ ismulti=ismulti,
+ )
+ elif not isinstance(orig, dbapi_base_err) and statement:
+ return StatementError(
+ "(%s.%s) %s"
+ % (
+ orig.__class__.__module__,
+ orig.__class__.__name__,
+ orig,
+ ),
+ statement,
+ params,
+ orig,
+ hide_parameters=hide_parameters,
+ ismulti=ismulti,
+ )
+
+ glob = globals()
+ for super_ in orig.__class__.__mro__:
+ name = super_.__name__
+ if dialect:
+ name = dialect.dbapi_exception_translation_map.get(
+ name, name
+ )
+ if name in glob and issubclass(glob[name], DBAPIError):
+ cls = glob[name]
+ break
+
+ return cls(
+ statement,
+ params,
+ orig,
+ connection_invalidated=connection_invalidated,
+ hide_parameters=hide_parameters,
+ code=cls.code,
+ ismulti=ismulti,
+ )
+
+ def __reduce__(self):
+ return (
+ self.__class__,
+ (
+ self.statement,
+ self.params,
+ self.orig,
+ self.hide_parameters,
+ self.connection_invalidated,
+ self.__dict__.get("code"),
+ self.ismulti,
+ ),
+ {"detail": self.detail},
+ )
+
+ def __init__(
+ self,
+ statement,
+ params,
+ orig,
+ hide_parameters=False,
+ connection_invalidated=False,
+ code=None,
+ ismulti=None,
+ ):
+ try:
+ text = str(orig)
+ except Exception as e:
+ text = "Error in str() of DB-API-generated exception: " + str(e)
+ StatementError.__init__(
+ self,
+ "(%s.%s) %s"
+ % (orig.__class__.__module__, orig.__class__.__name__, text),
+ statement,
+ params,
+ orig,
+ hide_parameters,
+ code=code,
+ ismulti=ismulti,
+ )
+ self.connection_invalidated = connection_invalidated
+
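+
+# Editorial illustration, not part of SQLAlchemy: how instance() chooses the
+# wrapper class by walking the driver exception's MRO and matching class
+# names against this module.  The "driver" exception classes are stand-ins
+# created here for the sketch.
+def _example_dbapi_error_wrapping():
+    class DriverError(Exception):
+        """stand-in for a DB-API driver's base ``Error`` class"""
+
+    # built with type() so the local name does not shadow SQLAlchemy's own
+    # OperationalError; its __name__ matches the DB-API convention
+    DriverOperationalError = type("OperationalError", (DriverError,), {})
+
+    wrapped = DBAPIError.instance(
+        "SELECT 1",
+        {},
+        DriverOperationalError("server closed the connection"),
+        DriverError,
+    )
+    # the name "OperationalError" in the MRO selects the matching subclass
+    assert isinstance(wrapped, OperationalError)
+    assert wrapped.orig.args[0] == "server closed the connection"
+    assert wrapped.statement == "SELECT 1"
+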
+
+class InterfaceError(DBAPIError):
+ """Wraps a DB-API InterfaceError."""
+
+ code = "rvf5"
+
+
+class DatabaseError(DBAPIError):
+ """Wraps a DB-API DatabaseError."""
+
+ code = "4xp6"
+
+
+class DataError(DatabaseError):
+ """Wraps a DB-API DataError."""
+
+ code = "9h9h"
+
+
+class OperationalError(DatabaseError):
+ """Wraps a DB-API OperationalError."""
+
+ code = "e3q8"
+
+
+class IntegrityError(DatabaseError):
+ """Wraps a DB-API IntegrityError."""
+
+ code = "gkpj"
+
+
+class InternalError(DatabaseError):
+ """Wraps a DB-API InternalError."""
+
+ code = "2j85"
+
+
+class ProgrammingError(DatabaseError):
+ """Wraps a DB-API ProgrammingError."""
+
+ code = "f405"
+
+
+class NotSupportedError(DatabaseError):
+ """Wraps a DB-API NotSupportedError."""
+
+ code = "tw8g"
+
+
+# Warnings
+
+
+class SADeprecationWarning(HasDescriptionCode, DeprecationWarning):
+ """Issued for usage of deprecated APIs."""
+
+ deprecated_since = None
+ "Indicates the version that started raising this deprecation warning"
+
+
+class Base20DeprecationWarning(SADeprecationWarning):
+ """Issued for usage of APIs specifically deprecated or legacy in
+ SQLAlchemy 2.0.
+
+ .. seealso::
+
+ :ref:`error_b8d9`.
+
+ :ref:`deprecation_20_mode`
+
+ """
+
+ deprecated_since = "1.4"
+ "Indicates the version that started raising this deprecation warning"
+
+ def __str__(self):
+ return (
+ super(Base20DeprecationWarning, self).__str__()
+ + " (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)"
+ )
+
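+ # Hypothetical sketch (not part of the original source): these warnings can
+ # be surfaced during a test run with the standard library warnings filters:
+ #
+ #     import warnings
+ #     warnings.simplefilter("always", Base20DeprecationWarning)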
+
+class LegacyAPIWarning(Base20DeprecationWarning):
+ """indicates an API that is in 'legacy' status, a long term deprecation."""
+
+
+class RemovedIn20Warning(Base20DeprecationWarning):
+ """indicates an API that will be fully removed in SQLAlchemy 2.0."""
+
+
+class MovedIn20Warning(RemovedIn20Warning):
+ """Subtype of RemovedIn20Warning to indicate an API that moved only."""
+
+
+class SAPendingDeprecationWarning(PendingDeprecationWarning):
+ """A similar warning as :class:`_exc.SADeprecationWarning`, this warning
+ is not used in modern versions of SQLAlchemy.
+
+ """
+
+ deprecated_since = None
+ "Indicates the version that started raising this deprecation warning"
+
+
+class SAWarning(HasDescriptionCode, RuntimeWarning):
+ """Issued at runtime."""
diff --git a/lib/sqlalchemy/ext/__init__.py b/lib/sqlalchemy/ext/__init__.py
new file mode 100644
index 0000000..62bbbf3
--- /dev/null
+++ b/lib/sqlalchemy/ext/__init__.py
@@ -0,0 +1,11 @@
+# ext/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .. import util as _sa_util
+
+
+_sa_util.preloaded.import_prefix("sqlalchemy.ext")
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
new file mode 100644
index 0000000..fbf377a
--- /dev/null
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -0,0 +1,1627 @@
+# ext/associationproxy.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Contain the ``AssociationProxy`` class.
+
+The ``AssociationProxy`` is a Python property object which provides
+transparent proxied access to the endpoint of an association object.
+
+See the example ``examples/association/proxied_association.py``.
+
+"""
+import operator
+
+from .. import exc
+from .. import inspect
+from .. import orm
+from .. import util
+from ..orm import collections
+from ..orm import interfaces
+from ..sql import or_
+from ..sql.operators import ColumnOperators
+
+
+def association_proxy(target_collection, attr, **kw):
+ r"""Return a Python property implementing a view of a target
+ attribute which references an attribute on members of the
+ target.
+
+ The returned value is an instance of :class:`.AssociationProxy`.
+
+ Implements a Python property representing a relationship as a collection
+ of simpler values, or a scalar value. The proxied property will mimic
+ the collection type of the target (list, dict or set), or, in the case of
+ a one to one relationship, a simple scalar value.
+
+ :param target_collection: Name of the attribute we'll proxy to.
+ This attribute is typically mapped by
+ :func:`~sqlalchemy.orm.relationship` to link to a target collection, but
+ can also be a many-to-one or non-scalar relationship.
+
+ :param attr: Attribute on the associated instance or instances we'll
+ proxy for.
+
+ For example, given a target collection of [obj1, obj2], a list created
+ by this proxy property would look like [getattr(obj1, *attr*),
+ getattr(obj2, *attr*)]
+
+ If the relationship is one-to-one or otherwise uselist=False, then
+ simply: getattr(obj, *attr*)
+
+ :param creator: optional.
+
+ When new items are added to this proxied collection, new instances of
+ the class collected by the target collection will be created. For list
+ and set collections, the target class constructor will be called with
+ the 'value' for the new instance. For dict types, two arguments are
+ passed: key and value.
+
+ If you want to construct instances differently, supply a *creator*
+ function that takes arguments as above and returns instances.
+
+ For scalar relationships, creator() will be called if the target is None.
+ If the target is present, set operations are proxied to setattr() on the
+ associated object.
+
+ If you have an associated object with multiple attributes, you may set
+ up multiple association proxies mapping to different attributes. See
+ the unit tests for examples, and for examples of how creator() functions
+ can be used to construct the scalar relationship on-demand in this
+ situation.
+
+ :param \*\*kw: Passes along any other keyword arguments to
+ :class:`.AssociationProxy`.
+
+ """
+ return AssociationProxy(target_collection, attr, **kw)
+
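+ # Hypothetical usage sketch; User, Keyword, user_keyword_table and the
+ # "kws" relationship are assumptions for illustration, not part of this
+ # module. The proxy exposes each related Keyword's string value directly
+ # on User:
+ #
+ #     class User(Base):
+ #         __tablename__ = "user"
+ #         id = Column(Integer, primary_key=True)
+ #         kws = relationship("Keyword", secondary=user_keyword_table)
+ #         keywords = association_proxy(
+ #             "kws", "keyword", creator=lambda kw: Keyword(keyword=kw))
+ #
+ #     user.keywords.append("editor")  # creates Keyword(keyword="editor")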
+
+ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY")
+"""Symbol indicating an :class:`.InspectionAttr` that's
+ of type :class:`.AssociationProxy`.
+
+ Is assigned to the :attr:`.InspectionAttr.extension_type`
+ attribute.
+
+"""
+
+
+class AssociationProxy(interfaces.InspectionAttrInfo):
+ """A descriptor that presents a read/write view of an object attribute."""
+
+ is_attribute = True
+ extension_type = ASSOCIATION_PROXY
+
+ def __init__(
+ self,
+ target_collection,
+ attr,
+ creator=None,
+ getset_factory=None,
+ proxy_factory=None,
+ proxy_bulk_set=None,
+ info=None,
+ cascade_scalar_deletes=False,
+ ):
+ """Construct a new :class:`.AssociationProxy`.
+
+ The :func:`.association_proxy` function is provided as the usual
+ entrypoint here, though :class:`.AssociationProxy` can be instantiated
+ and/or subclassed directly.
+
+ :param target_collection: Name of the collection we'll proxy to,
+ usually created with :func:`_orm.relationship`.
+
+ :param attr: Attribute on the collected instances we'll proxy
+ for. For example, given a target collection of [obj1, obj2], a
+ list created by this proxy property would look like
+ [getattr(obj1, attr), getattr(obj2, attr)]
+
+ :param creator: Optional. When new items are added to this proxied
+ collection, new instances of the class collected by the target
+ collection will be created. For list and set collections, the
+ target class constructor will be called with the 'value' for the
+ new instance. For dict types, two arguments are passed:
+ key and value.
+
+ If you want to construct instances differently, supply a 'creator'
+ function that takes arguments as above and returns instances.
+
+ :param cascade_scalar_deletes: when True, indicates that setting
+ the proxied value to ``None``, or deleting it via ``del``, should
+ also remove the source object. Only applies to scalar attributes.
+ Normally, removing the proxied target will not remove the proxy
+ source, as this object may have other state that is still to be
+ kept.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :ref:`cascade_scalar_deletes` - complete usage example
+
+ :param getset_factory: Optional. Proxied attribute access is
+ automatically handled by routines that get and set values based on
+ the `attr` argument for this proxy.
+
+ If you would like to customize this behavior, you may supply a
+ `getset_factory` callable that produces a tuple of `getter` and
+ `setter` functions. The factory is called with two arguments, the
+ abstract type of the underlying collection and this proxy instance.
+
+ :param proxy_factory: Optional. The type of collection to emulate is
+ determined by sniffing the target collection. If your collection
+ type can't be determined by duck typing or you'd like to use a
+ different collection implementation, you may supply a factory
+ function to produce those collections. Only applicable to
+ non-scalar relationships.
+
+ :param proxy_bulk_set: Optional, use with proxy_factory. See
+ the _set() method for details.
+
+ :param info: optional, will be assigned to
+ :attr:`.AssociationProxy.info` if present.
+
+ .. versionadded:: 1.0.9
+
+ """
+ self.target_collection = target_collection
+ self.value_attr = attr
+ self.creator = creator
+ self.getset_factory = getset_factory
+ self.proxy_factory = proxy_factory
+ self.proxy_bulk_set = proxy_bulk_set
+ self.cascade_scalar_deletes = cascade_scalar_deletes
+
+ self.key = "_%s_%s_%s" % (
+ type(self).__name__,
+ target_collection,
+ id(self),
+ )
+ if info:
+ self.info = info
+
+ def __get__(self, obj, class_):
+ if class_ is None:
+ return self
+ inst = self._as_instance(class_, obj)
+ if inst:
+ return inst.get(obj)
+
+ # obj has to be None here
+ # assert obj is None
+
+ return self
+
+ def __set__(self, obj, values):
+ class_ = type(obj)
+ return self._as_instance(class_, obj).set(obj, values)
+
+ def __delete__(self, obj):
+ class_ = type(obj)
+ return self._as_instance(class_, obj).delete(obj)
+
+ def for_class(self, class_, obj=None):
+ r"""Return the internal state local to a specific mapped class.
+
+ E.g., given a class ``User``::
+
+ class User(Base):
+ # ...
+
+ keywords = association_proxy('kws', 'keyword')
+
+ If we access this :class:`.AssociationProxy` from
+ :attr:`_orm.Mapper.all_orm_descriptors`, and we want to view the
+ target class for this proxy as mapped by ``User``::
+
+ inspect(User).all_orm_descriptors["keywords"].for_class(User).target_class
+
+ This returns an instance of :class:`.AssociationProxyInstance` that
+ is specific to the ``User`` class. The :class:`.AssociationProxy`
+ object remains agnostic of its parent class.
+
+ :param class\_: the class that we are returning state for.
+
+ :param obj: optional, an instance of the class that is required
+ if the attribute refers to a polymorphic target, e.g. where we have
+ to look at the type of the actual destination object to get the
+ complete path.
+
+ .. versionadded:: 1.3 - :class:`.AssociationProxy` no longer stores
+ any state specific to a particular parent class; the state is now
+ stored in per-class :class:`.AssociationProxyInstance` objects.
+
+
+ """
+ return self._as_instance(class_, obj)
+
+ def _as_instance(self, class_, obj):
+ try:
+ inst = class_.__dict__[self.key + "_inst"]
+ except KeyError:
+ inst = None
+
+ # avoid exception context
+ if inst is None:
+ owner = self._calc_owner(class_)
+ if owner is not None:
+ inst = AssociationProxyInstance.for_proxy(self, owner, obj)
+ setattr(class_, self.key + "_inst", inst)
+ else:
+ inst = None
+
+ if inst is not None and not inst._is_canonical:
+ # the AssociationProxyInstance can't be generalized
+ # since the proxied attribute is not on the targeted
+ # class, only on subclasses of it, which might be
+ # different. only return for the specific
+ # object's current value
+ return inst._non_canonical_get_for_object(obj)
+ else:
+ return inst
+
+ def _calc_owner(self, target_cls):
+ # we might be getting invoked for a subclass
+ # that is not mapped yet, in some declarative situations.
+ # save until we are mapped
+ try:
+ insp = inspect(target_cls)
+ except exc.NoInspectionAvailable:
+ # can't find a mapper, don't set owner. if we are a not-yet-mapped
+ # subclass, we can also scan through __mro__ to find a mapped
+ # class, but instead just wait for us to be called again against a
+ # mapped class normally.
+ return None
+ else:
+ return insp.mapper.class_manager.class_
+
+ def _default_getset(self, collection_class):
+ attr = self.value_attr
+ _getter = operator.attrgetter(attr)
+
+ def getter(target):
+ return _getter(target) if target is not None else None
+
+ if collection_class is dict:
+
+ def setter(o, k, v):
+ setattr(o, attr, v)
+
+ else:
+
+ def setter(o, v):
+ setattr(o, attr, v)
+
+ return getter, setter
+
+ def __repr__(self):
+ return "AssociationProxy(%r, %r)" % (
+ self.target_collection,
+ self.value_attr,
+ )
+
+
+class AssociationProxyInstance(object):
+ """A per-class object that serves class- and object-specific results.
+
+ This is used by :class:`.AssociationProxy` when it is invoked
+ in terms of a specific class or instance of a class, i.e. when it is
+ used as a regular Python descriptor.
+
+ When referring to the :class:`.AssociationProxy` as a normal Python
+ descriptor, the :class:`.AssociationProxyInstance` is the object that
+ actually serves the information. Under normal circumstances, its presence
+ is transparent::
+
+ >>> User.keywords.scalar
+ False
+
+ In the special case that the :class:`.AssociationProxy` object is being
+ accessed directly, in order to get an explicit handle to the
+ :class:`.AssociationProxyInstance`, use the
+ :meth:`.AssociationProxy.for_class` method::
+
+ proxy_state = inspect(User).all_orm_descriptors["keywords"].for_class(User)
+
+ # view if proxy object is scalar or not
+ >>> proxy_state.scalar
+ False
+
+ .. versionadded:: 1.3
+
+ """ # noqa
+
+ def __init__(self, parent, owning_class, target_class, value_attr):
+ self.parent = parent
+ self.key = parent.key
+ self.owning_class = owning_class
+ self.target_collection = parent.target_collection
+ self.collection_class = None
+ self.target_class = target_class
+ self.value_attr = value_attr
+
+ target_class = None
+ """The intermediary class handled by this
+ :class:`.AssociationProxyInstance`.
+
+ Intercepted append/set/assignment events will result
+ in the generation of new instances of this class.
+
+ """
+
+ @classmethod
+ def for_proxy(cls, parent, owning_class, parent_instance):
+ target_collection = parent.target_collection
+ value_attr = parent.value_attr
+ prop = orm.class_mapper(owning_class).get_property(target_collection)
+
+ # this was never asserted before but this should be made clear.
+ if not isinstance(prop, orm.RelationshipProperty):
+ util.raise_(
+ NotImplementedError(
+ "association proxy to a non-relationship "
+ "intermediary is not supported"
+ ),
+ replace_context=None,
+ )
+
+ target_class = prop.mapper.class_
+
+ try:
+ target_assoc = cls._cls_unwrap_target_assoc_proxy(
+ target_class, value_attr
+ )
+ except AttributeError:
+ # the proxied attribute doesn't exist on the target class;
+ # return an "ambiguous" instance that will work on a per-object
+ # basis
+ return AmbiguousAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+ except Exception as err:
+ util.raise_(
+ exc.InvalidRequestError(
+ "Association proxy received an unexpected error when "
+ "trying to retreive attribute "
+ '"%s.%s" from '
+ 'class "%s": %s'
+ % (
+ target_class.__name__,
+ parent.value_attr,
+ target_class.__name__,
+ err,
+ )
+ ),
+ from_=err,
+ )
+ else:
+ return cls._construct_for_assoc(
+ target_assoc, parent, owning_class, target_class, value_attr
+ )
+
+ @classmethod
+ def _construct_for_assoc(
+ cls, target_assoc, parent, owning_class, target_class, value_attr
+ ):
+ if target_assoc is not None:
+ return ObjectAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+
+ attr = getattr(target_class, value_attr)
+ if not hasattr(attr, "_is_internal_proxy"):
+ return AmbiguousAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+ is_object = attr._impl_uses_objects
+ if is_object:
+ return ObjectAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+ else:
+ return ColumnAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+
+ def _get_property(self):
+ return orm.class_mapper(self.owning_class).get_property(
+ self.target_collection
+ )
+
+ @property
+ def _comparator(self):
+ return self._get_property().comparator
+
+ def __clause_element__(self):
+ raise NotImplementedError(
+ "The association proxy can't be used as a plain column "
+ "expression; it only works inside of a comparison expression"
+ )
+
+ @classmethod
+ def _cls_unwrap_target_assoc_proxy(cls, target_class, value_attr):
+ attr = getattr(target_class, value_attr)
+ if isinstance(attr, (AssociationProxy, AssociationProxyInstance)):
+ return attr
+ return None
+
+ @util.memoized_property
+ def _unwrap_target_assoc_proxy(self):
+ return self._cls_unwrap_target_assoc_proxy(
+ self.target_class, self.value_attr
+ )
+
+ @property
+ def remote_attr(self):
+ """The 'remote' class attribute referenced by this
+ :class:`.AssociationProxyInstance`.
+
+ .. seealso::
+
+ :attr:`.AssociationProxyInstance.attr`
+
+ :attr:`.AssociationProxyInstance.local_attr`
+
+ """
+ return getattr(self.target_class, self.value_attr)
+
+ @property
+ def local_attr(self):
+ """The 'local' class attribute referenced by this
+ :class:`.AssociationProxyInstance`.
+
+ .. seealso::
+
+ :attr:`.AssociationProxyInstance.attr`
+
+ :attr:`.AssociationProxyInstance.remote_attr`
+
+ """
+ return getattr(self.owning_class, self.target_collection)
+
+ @property
+ def attr(self):
+ """Return a tuple of ``(local_attr, remote_attr)``.
+
+ This attribute was originally intended to facilitate using the
+ :meth:`_query.Query.join` method to join across the two relationships
+ at once; however, this makes use of a deprecated calling style.
+
+ To use :meth:`_sql.select.join` or :meth:`_orm.Query.join` with
+ an association proxy, the current method is to make use of the
+ :attr:`.AssociationProxyInstance.local_attr` and
+ :attr:`.AssociationProxyInstance.remote_attr` attributes separately::
+
+ stmt = (
+ select(Parent).
+ join(Parent.proxied.local_attr).
+ join(Parent.proxied.remote_attr)
+ )
+
+ A future release may seek to provide a more succinct join pattern
+ for association proxy attributes.
+
+ .. seealso::
+
+ :attr:`.AssociationProxyInstance.local_attr`
+
+ :attr:`.AssociationProxyInstance.remote_attr`
+
+ """
+ return (self.local_attr, self.remote_attr)
+
+ @util.memoized_property
+ def scalar(self):
+ """Return ``True`` if this :class:`.AssociationProxyInstance`
+ proxies a scalar relationship on the local side."""
+
+ scalar = not self._get_property().uselist
+ if scalar:
+ self._initialize_scalar_accessors()
+ return scalar
+
+ @util.memoized_property
+ def _value_is_scalar(self):
+ return (
+ not self._get_property()
+ .mapper.get_property(self.value_attr)
+ .uselist
+ )
+
+ @property
+ def _target_is_object(self):
+ raise NotImplementedError()
+
+ def _initialize_scalar_accessors(self):
+ if self.parent.getset_factory:
+ get, set_ = self.parent.getset_factory(None, self)
+ else:
+ get, set_ = self.parent._default_getset(None)
+ self._scalar_get, self._scalar_set = get, set_
+
+ def _default_getset(self, collection_class):
+ attr = self.value_attr
+ _getter = operator.attrgetter(attr)
+
+ def getter(target):
+ return _getter(target) if target is not None else None
+
+ if collection_class is dict:
+
+ def setter(o, k, v):
+ return setattr(o, attr, v)
+
+ else:
+
+ def setter(o, v):
+ return setattr(o, attr, v)
+
+ return getter, setter
+
+ @property
+ def info(self):
+ return self.parent.info
+
+ def get(self, obj):
+ if obj is None:
+ return self
+
+ if self.scalar:
+ target = getattr(obj, self.target_collection)
+ return self._scalar_get(target)
+ else:
+ try:
+ # If the owning instance is reborn (orm session resurrect,
+ # etc.), refresh the proxy cache.
+ creator_id, self_id, proxy = getattr(obj, self.key)
+ except AttributeError:
+ pass
+ else:
+ if id(obj) == creator_id and id(self) == self_id:
+ assert self.collection_class is not None
+ return proxy
+
+ self.collection_class, proxy = self._new(
+ _lazy_collection(obj, self.target_collection)
+ )
+ setattr(obj, self.key, (id(obj), id(self), proxy))
+ return proxy
+
+ def set(self, obj, values):
+ if self.scalar:
+ creator = (
+ self.parent.creator
+ if self.parent.creator
+ else self.target_class
+ )
+ target = getattr(obj, self.target_collection)
+ if target is None:
+ if values is None:
+ return
+ setattr(obj, self.target_collection, creator(values))
+ else:
+ self._scalar_set(target, values)
+ if values is None and self.parent.cascade_scalar_deletes:
+ setattr(obj, self.target_collection, None)
+ else:
+ proxy = self.get(obj)
+ assert self.collection_class is not None
+ if proxy is not values:
+ proxy._bulk_replace(self, values)
+
+ def delete(self, obj):
+ if self.owning_class is None:
+ self._calc_owner(obj, None)
+
+ if self.scalar:
+ target = getattr(obj, self.target_collection)
+ if target is not None:
+ delattr(target, self.value_attr)
+ delattr(obj, self.target_collection)
+
+ def _new(self, lazy_collection):
+ creator = (
+ self.parent.creator if self.parent.creator else self.target_class
+ )
+ collection_class = util.duck_type_collection(lazy_collection())
+
+ if self.parent.proxy_factory:
+ return (
+ collection_class,
+ self.parent.proxy_factory(
+ lazy_collection, creator, self.value_attr, self
+ ),
+ )
+
+ if self.parent.getset_factory:
+ getter, setter = self.parent.getset_factory(collection_class, self)
+ else:
+ getter, setter = self.parent._default_getset(collection_class)
+
+ if collection_class is list:
+ return (
+ collection_class,
+ _AssociationList(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
+ elif collection_class is dict:
+ return (
+ collection_class,
+ _AssociationDict(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
+ elif collection_class is set:
+ return (
+ collection_class,
+ _AssociationSet(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
+ else:
+ raise exc.ArgumentError(
+ "could not guess which interface to use for "
+ 'collection_class "%s" backing "%s"; specify a '
+ "proxy_factory and proxy_bulk_set manually"
+ % (self.collection_class.__name__, self.target_collection)
+ )
+
+ def _set(self, proxy, values):
+ if self.parent.proxy_bulk_set:
+ self.parent.proxy_bulk_set(proxy, values)
+ elif self.collection_class is list:
+ proxy.extend(values)
+ elif self.collection_class is dict:
+ proxy.update(values)
+ elif self.collection_class is set:
+ proxy.update(values)
+ else:
+ raise exc.ArgumentError(
+ "no proxy_bulk_set supplied for custom "
+ "collection_class implementation"
+ )
+
+ def _inflate(self, proxy):
+ creator = (
+ self.parent.creator and self.parent.creator or self.target_class
+ )
+
+ if self.parent.getset_factory:
+ getter, setter = self.parent.getset_factory(
+ self.collection_class, self
+ )
+ else:
+ getter, setter = self.parent._default_getset(self.collection_class)
+
+ proxy.creator = creator
+ proxy.getter = getter
+ proxy.setter = setter
+
+ def _criterion_exists(self, criterion=None, **kwargs):
+ is_has = kwargs.pop("is_has", None)
+
+ target_assoc = self._unwrap_target_assoc_proxy
+ if target_assoc is not None:
+ inner = target_assoc._criterion_exists(
+ criterion=criterion, **kwargs
+ )
+ return self._comparator._criterion_exists(inner)
+
+ if self._target_is_object:
+ prop = getattr(self.target_class, self.value_attr)
+ value_expr = prop._criterion_exists(criterion, **kwargs)
+ else:
+ if kwargs:
+ raise exc.ArgumentError(
+ "Can't apply keyword arguments to column-targeted "
+ "association proxy; use =="
+ )
+ elif is_has and criterion is not None:
+ raise exc.ArgumentError(
+ "Non-empty has() not allowed for "
+ "column-targeted association proxy; use =="
+ )
+
+ value_expr = criterion
+
+ return self._comparator._criterion_exists(value_expr)
+
+ def any(self, criterion=None, **kwargs):
+ """Produce a proxied 'any' expression using EXISTS.
+
+ This expression will be a composed product
+ using the :meth:`.RelationshipProperty.Comparator.any`
+ and/or :meth:`.RelationshipProperty.Comparator.has`
+ operators of the underlying proxied attributes.
+
+ """
+ if self._unwrap_target_assoc_proxy is None and (
+ self.scalar
+ and (not self._target_is_object or self._value_is_scalar)
+ ):
+ raise exc.InvalidRequestError(
+ "'any()' not implemented for scalar " "attributes. Use has()."
+ )
+ return self._criterion_exists(
+ criterion=criterion, is_has=False, **kwargs
+ )
+
+ def has(self, criterion=None, **kwargs):
+ """Produce a proxied 'has' expression using EXISTS.
+
+ This expression will be a composed product
+ using the :meth:`.RelationshipProperty.Comparator.any`
+ and/or :meth:`.RelationshipProperty.Comparator.has`
+ operators of the underlying proxied attributes.
+
+ """
+ if self._unwrap_target_assoc_proxy is None and (
+ not self.scalar
+ or (self._target_is_object and not self._value_is_scalar)
+ ):
+ raise exc.InvalidRequestError(
+ "'has()' not implemented for collections. " "Use any()."
+ )
+ return self._criterion_exists(
+ criterion=criterion, is_has=True, **kwargs
+ )
+
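+ # Hypothetical filtering sketch, assuming the User.keywords proxy outlined
+ # near association_proxy() above (Keyword being the intermediary class):
+ #
+ #     session.query(User).filter(
+ #         User.keywords.any(Keyword.keyword == "editor"))
+ #
+ # has() is the counterpart for scalar endpoints, e.g. a proxy across a
+ # many-to-one relationship.
+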
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, self.parent)
+
+
+class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
+ """an :class:`.AssociationProxyInstance` where we cannot determine
+ the type of target object.
+ """
+
+ _is_canonical = False
+
+ def _ambiguous(self):
+ raise AttributeError(
+ "Association proxy %s.%s refers to an attribute '%s' that is not "
+ "directly mapped on class %s; therefore this operation cannot "
+ "proceed since we don't know what type of object is referred "
+ "towards"
+ % (
+ self.owning_class.__name__,
+ self.target_collection,
+ self.value_attr,
+ self.target_class,
+ )
+ )
+
+ def get(self, obj):
+ if obj is None:
+ return self
+ else:
+ return super(AmbiguousAssociationProxyInstance, self).get(obj)
+
+ def __eq__(self, obj):
+ self._ambiguous()
+
+ def __ne__(self, obj):
+ self._ambiguous()
+
+ def any(self, criterion=None, **kwargs):
+ self._ambiguous()
+
+ def has(self, criterion=None, **kwargs):
+ self._ambiguous()
+
+ @util.memoized_property
+ def _lookup_cache(self):
+ # mapping of <subclass>->AssociationProxyInstance.
+ # e.g. proxy is A-> A.b -> B -> B.b_attr, but B.b_attr doesn't exist;
+ # only B1(B) and B2(B) have "b_attr", keys in here would be B1, B2
+ return {}
+
+ def _non_canonical_get_for_object(self, parent_instance):
+ if parent_instance is not None:
+ actual_obj = getattr(parent_instance, self.target_collection)
+ if actual_obj is not None:
+ try:
+ insp = inspect(actual_obj)
+ except exc.NoInspectionAvailable:
+ pass
+ else:
+ mapper = insp.mapper
+ instance_class = mapper.class_
+ if instance_class not in self._lookup_cache:
+ self._populate_cache(instance_class, mapper)
+
+ try:
+ return self._lookup_cache[instance_class]
+ except KeyError:
+ pass
+
+ # no object or ambiguous object given, so return "self", which
+ # is a proxy with generally only instance-level functionality
+ return self
+
+ def _populate_cache(self, instance_class, mapper):
+ prop = orm.class_mapper(self.owning_class).get_property(
+ self.target_collection
+ )
+
+ if mapper.isa(prop.mapper):
+ target_class = instance_class
+ try:
+ target_assoc = self._cls_unwrap_target_assoc_proxy(
+ target_class, self.value_attr
+ )
+ except AttributeError:
+ pass
+ else:
+ self._lookup_cache[instance_class] = self._construct_for_assoc(
+ target_assoc,
+ self.parent,
+ self.owning_class,
+ target_class,
+ self.value_attr,
+ )
+
+
+class ObjectAssociationProxyInstance(AssociationProxyInstance):
+ """an :class:`.AssociationProxyInstance` that has an object as a target."""
+
+ _target_is_object = True
+ _is_canonical = True
+
+ def contains(self, obj):
+ """Produce a proxied 'contains' expression using EXISTS.
+
+ This expression will be a composed product
+ using the :meth:`.RelationshipProperty.Comparator.any`,
+ :meth:`.RelationshipProperty.Comparator.has`,
+ and/or :meth:`.RelationshipProperty.Comparator.contains`
+ operators of the underlying proxied attributes.
+ """
+
+ target_assoc = self._unwrap_target_assoc_proxy
+ if target_assoc is not None:
+ return self._comparator._criterion_exists(
+ target_assoc.contains(obj)
+ if not target_assoc.scalar
+ else target_assoc == obj
+ )
+ elif (
+ self._target_is_object
+ and self.scalar
+ and not self._value_is_scalar
+ ):
+ return self._comparator.has(
+ getattr(self.target_class, self.value_attr).contains(obj)
+ )
+ elif self._target_is_object and self.scalar and self._value_is_scalar:
+ raise exc.InvalidRequestError(
+ "contains() doesn't apply to a scalar object endpoint; use =="
+ )
+ else:
+
+ return self._comparator._criterion_exists(**{self.value_attr: obj})
+
+ def __eq__(self, obj):
+ # note the has() here will fail for collections; eq_()
+ # is only allowed with a scalar.
+ if obj is None:
+ return or_(
+ self._comparator.has(**{self.value_attr: obj}),
+ self._comparator == None,
+ )
+ else:
+ return self._comparator.has(**{self.value_attr: obj})
+
+ def __ne__(self, obj):
+ # note the has() here will fail for collections; eq_()
+ # is only allowed with a scalar.
+ return self._comparator.has(
+ getattr(self.target_class, self.value_attr) != obj
+ )
+
+
+class ColumnAssociationProxyInstance(
+ ColumnOperators, AssociationProxyInstance
+):
+ """an :class:`.AssociationProxyInstance` that has a database column as a
+ target.
+ """
+
+ _target_is_object = False
+ _is_canonical = True
+
+ def __eq__(self, other):
+ # special case "is None" to check for no related row as well
+ expr = self._criterion_exists(
+ self.remote_attr.operate(operator.eq, other)
+ )
+ if other is None:
+ return or_(expr, self._comparator == None)
+ else:
+ return expr
+
+ def operate(self, op, *other, **kwargs):
+ return self._criterion_exists(
+ self.remote_attr.operate(op, *other, **kwargs)
+ )
+
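+ # Hypothetical sketch: because this class mixes in ColumnOperators, plain
+ # column operators on the proxy are rendered as an EXISTS against the
+ # remote column (again assuming the User.keywords proxy from above):
+ #
+ #     session.query(User).filter(User.keywords.like("%edit%"))
+ #     # roughly: EXISTS (SELECT 1 ... WHERE keyword.keyword LIKE '%edit%')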
+
+class _lazy_collection(object):
+ def __init__(self, obj, target):
+ self.parent = obj
+ self.target = target
+
+ def __call__(self):
+ return getattr(self.parent, self.target)
+
+ def __getstate__(self):
+ return {"obj": self.parent, "target": self.target}
+
+ def __setstate__(self, state):
+ self.parent = state["obj"]
+ self.target = state["target"]
+
+
+class _AssociationCollection(object):
+ def __init__(self, lazy_collection, creator, getter, setter, parent):
+ """Constructs an _AssociationCollection.
+
+ This will always be a subclass of either _AssociationList,
+ _AssociationSet, or _AssociationDict.
+
+ lazy_collection
+ A callable returning a list-based collection of entities (usually an
+ object attribute managed by a SQLAlchemy relationship())
+
+ creator
+ A function that creates new target entities. Given one parameter:
+ value. This assertion is assumed::
+
+ obj = creator(somevalue)
+ assert getter(obj) == somevalue
+
+ getter
+ A function. Given an associated object, return the 'value'.
+
+ setter
+ A function. Given an associated object and a value, store that
+ value on the object.
+
+ """
+ self.lazy_collection = lazy_collection
+ self.creator = creator
+ self.getter = getter
+ self.setter = setter
+ self.parent = parent
+
+ col = property(lambda self: self.lazy_collection())
+
+ def __len__(self):
+ return len(self.col)
+
+ def __bool__(self):
+ return bool(self.col)
+
+ __nonzero__ = __bool__
+
+ def __getstate__(self):
+ return {"parent": self.parent, "lazy_collection": self.lazy_collection}
+
+ def __setstate__(self, state):
+ self.parent = state["parent"]
+ self.lazy_collection = state["lazy_collection"]
+ self.parent._inflate(self)
+
+ def _bulk_replace(self, assoc_proxy, values):
+ self.clear()
+ assoc_proxy._set(self, values)
+
+
+class _AssociationList(_AssociationCollection):
+ """Generic, converting, list-to-list proxy."""
+
+ def _create(self, value):
+ return self.creator(value)
+
+ def _get(self, object_):
+ return self.getter(object_)
+
+ def _set(self, object_, value):
+ return self.setter(object_, value)
+
+ def __getitem__(self, index):
+ if not isinstance(index, slice):
+ return self._get(self.col[index])
+ else:
+ return [self._get(member) for member in self.col[index]]
+
+ def __setitem__(self, index, value):
+ if not isinstance(index, slice):
+ self._set(self.col[index], value)
+ else:
+ if index.stop is None:
+ stop = len(self)
+ elif index.stop < 0:
+ stop = len(self) + index.stop
+ else:
+ stop = index.stop
+ step = index.step or 1
+
+ start = index.start or 0
+ rng = list(range(index.start or 0, stop, step))
+ if step == 1:
+ for i in rng:
+ del self[start]
+ i = start
+ for item in value:
+ self.insert(i, item)
+ i += 1
+ else:
+ if len(value) != len(rng):
+ raise ValueError(
+ "attempt to assign sequence of size %s to "
+ "extended slice of size %s" % (len(value), len(rng))
+ )
+ for i, item in zip(rng, value):
+ self._set(self.col[i], item)
+
+ def __delitem__(self, index):
+ del self.col[index]
+
+ def __contains__(self, value):
+ for member in self.col:
+ # testlib.pragma exempt:__eq__
+ if self._get(member) == value:
+ return True
+ return False
+
+ def __getslice__(self, start, end):
+ return [self._get(member) for member in self.col[start:end]]
+
+ def __setslice__(self, start, end, values):
+ members = [self._create(v) for v in values]
+ self.col[start:end] = members
+
+ def __delslice__(self, start, end):
+ del self.col[start:end]
+
+ def __iter__(self):
+ """Iterate over proxied values.
+
+ For the actual domain objects, iterate over .col instead or
+ just use the underlying collection directly from its property
+ on the parent.
+ """
+
+ for member in self.col:
+ yield self._get(member)
+ return
+
+ def append(self, value):
+ col = self.col
+ item = self._create(value)
+ col.append(item)
+
+ def count(self, value):
+ return sum(
+ [
+ 1
+ for _ in util.itertools_filter(
+ lambda v: v == value, iter(self)
+ )
+ ]
+ )
+
+ def extend(self, values):
+ for v in values:
+ self.append(v)
+
+ def insert(self, index, value):
+ self.col[index:index] = [self._create(value)]
+
+ def pop(self, index=-1):
+ return self.getter(self.col.pop(index))
+
+ def remove(self, value):
+ for i, val in enumerate(self):
+ if val == value:
+ del self.col[i]
+ return
+ raise ValueError("value not in list")
+
+ def reverse(self):
+ """Not supported, use reversed(mylist)"""
+
+ raise NotImplementedError
+
+ def sort(self):
+ """Not supported, use sorted(mylist)"""
+
+ raise NotImplementedError
+
+ def clear(self):
+ del self.col[0 : len(self.col)]
+
+ def __eq__(self, other):
+ return list(self) == other
+
+ def __ne__(self, other):
+ return list(self) != other
+
+ def __lt__(self, other):
+ return list(self) < other
+
+ def __le__(self, other):
+ return list(self) <= other
+
+ def __gt__(self, other):
+ return list(self) > other
+
+ def __ge__(self, other):
+ return list(self) >= other
+
+ def __cmp__(self, other):
+ return util.cmp(list(self), other)
+
+ def __add__(self, iterable):
+ try:
+ other = list(iterable)
+ except TypeError:
+ return NotImplemented
+ return list(self) + other
+
+ def __radd__(self, iterable):
+ try:
+ other = list(iterable)
+ except TypeError:
+ return NotImplemented
+ return other + list(self)
+
+ def __mul__(self, n):
+ if not isinstance(n, int):
+ return NotImplemented
+ return list(self) * n
+
+ __rmul__ = __mul__
+
+ def __iadd__(self, iterable):
+ self.extend(iterable)
+ return self
+
+ def __imul__(self, n):
+ # unlike a regular list *=, proxied __imul__ will generate unique
+ # backing objects for each copy. *= on proxied lists is a bit of
+ # a stretch anyhow, and this interpretation of the __imul__ contract
+ # is more plausibly useful than copying the backing objects.
+ if not isinstance(n, int):
+ return NotImplemented
+ if n == 0:
+ self.clear()
+ elif n > 1:
+ self.extend(list(self) * (n - 1))
+ return self
+
+ def index(self, item, *args):
+ return list(self).index(item, *args)
+
+ def copy(self):
+ return list(self)
+
+ def __repr__(self):
+ return repr(list(self))
+
+ def __hash__(self):
+ raise TypeError("%s objects are unhashable" % type(self).__name__)
+
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(list, func_name)
+ ):
+ func.__doc__ = getattr(list, func_name).__doc__
+ del func_name, func
+
+
+_NotProvided = util.symbol("_NotProvided")
+
+
+class _AssociationDict(_AssociationCollection):
+ """Generic, converting, dict-to-dict proxy."""
+
+ def _create(self, key, value):
+ return self.creator(key, value)
+
+ def _get(self, object_):
+ return self.getter(object_)
+
+ def _set(self, object_, key, value):
+ return self.setter(object_, key, value)
+
+ def __getitem__(self, key):
+ return self._get(self.col[key])
+
+ def __setitem__(self, key, value):
+ if key in self.col:
+ self._set(self.col[key], key, value)
+ else:
+ self.col[key] = self._create(key, value)
+
+ def __delitem__(self, key):
+ del self.col[key]
+
+ def __contains__(self, key):
+ # testlib.pragma exempt:__hash__
+ return key in self.col
+
+ def has_key(self, key):
+ # testlib.pragma exempt:__hash__
+ return key in self.col
+
+ def __iter__(self):
+ return iter(self.col.keys())
+
+ def clear(self):
+ self.col.clear()
+
+ def __eq__(self, other):
+ return dict(self) == other
+
+ def __ne__(self, other):
+ return dict(self) != other
+
+ def __lt__(self, other):
+ return dict(self) < other
+
+ def __le__(self, other):
+ return dict(self) <= other
+
+ def __gt__(self, other):
+ return dict(self) > other
+
+ def __ge__(self, other):
+ return dict(self) >= other
+
+ def __cmp__(self, other):
+ return util.cmp(dict(self), other)
+
+ def __repr__(self):
+ return repr(dict(self.items()))
+
+ def get(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError:
+ return default
+
+ def setdefault(self, key, default=None):
+ if key not in self.col:
+ self.col[key] = self._create(key, default)
+ return default
+ else:
+ return self[key]
+
+ def keys(self):
+ return self.col.keys()
+
+ if util.py2k:
+
+ def iteritems(self):
+ return ((key, self._get(self.col[key])) for key in self.col)
+
+ def itervalues(self):
+ return (self._get(self.col[key]) for key in self.col)
+
+ def iterkeys(self):
+ return self.col.iterkeys()
+
+ def values(self):
+ return [self._get(member) for member in self.col.values()]
+
+ def items(self):
+ return [(k, self._get(self.col[k])) for k in self]
+
+ else:
+
+ def items(self):
+ return ((key, self._get(self.col[key])) for key in self.col)
+
+ def values(self):
+ return (self._get(self.col[key]) for key in self.col)
+
+ def pop(self, key, default=_NotProvided):
+ if default is _NotProvided:
+ member = self.col.pop(key)
+ else:
+ member = self.col.pop(key, default)
+ return self._get(member)
+
+ def popitem(self):
+ item = self.col.popitem()
+ return (item[0], self._get(item[1]))
+
+ def update(self, *a, **kw):
+ if len(a) > 1:
+ raise TypeError(
+ "update expected at most 1 arguments, got %i" % len(a)
+ )
+ elif len(a) == 1:
+ seq_or_map = a[0]
+ # discern dict from sequence - took the advice from
+ # https://www.voidspace.org.uk/python/articles/duck_typing.shtml
+ # still not perfect :(
+ if hasattr(seq_or_map, "keys"):
+ for item in seq_or_map:
+ self[item] = seq_or_map[item]
+ else:
+ try:
+ for k, v in seq_or_map:
+ self[k] = v
+ except ValueError as err:
+ util.raise_(
+ ValueError(
+ "dictionary update sequence "
+ "requires 2-element tuples"
+ ),
+ replace_context=err,
+ )
+
+ for key, value in kw.items():
+ self[key] = value
+
+ def _bulk_replace(self, assoc_proxy, values):
+ existing = set(self)
+ constants = existing.intersection(values or ())
+ additions = set(values or ()).difference(constants)
+ removals = existing.difference(constants)
+
+ for key, member in values.items() or ():
+ if key in additions:
+ self[key] = member
+ elif key in constants:
+ self[key] = member
+
+ for key in removals:
+ del self[key]
+
+ def copy(self):
+ return dict(self.items())
+
+ def __hash__(self):
+ raise TypeError("%s objects are unhashable" % type(self).__name__)
+
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(dict, func_name)
+ ):
+ func.__doc__ = getattr(dict, func_name).__doc__
+ del func_name, func
+
+
+class _AssociationSet(_AssociationCollection):
+ """Generic, converting, set-to-set proxy."""
+
+ def _create(self, value):
+ return self.creator(value)
+
+ def _get(self, object_):
+ return self.getter(object_)
+
+ def __len__(self):
+ return len(self.col)
+
+ def __bool__(self):
+ if self.col:
+ return True
+ else:
+ return False
+
+ __nonzero__ = __bool__
+
+ def __contains__(self, value):
+ for member in self.col:
+ # testlib.pragma exempt:__eq__
+ if self._get(member) == value:
+ return True
+ return False
+
+ def __iter__(self):
+ """Iterate over proxied values.
+
+ For the actual domain objects, iterate over .col instead or just use
+ the underlying collection directly from its property on the parent.
+
+ """
+ for member in self.col:
+ yield self._get(member)
+ return
+
+ def add(self, value):
+ if value not in self:
+ self.col.add(self._create(value))
+
+ # for discard and remove, choosing a more expensive check strategy rather
+ # than call self.creator()
+ def discard(self, value):
+ for member in self.col:
+ if self._get(member) == value:
+ self.col.discard(member)
+ break
+
+ def remove(self, value):
+ for member in self.col:
+ if self._get(member) == value:
+ self.col.discard(member)
+ return
+ raise KeyError(value)
+
+ def pop(self):
+ if not self.col:
+ raise KeyError("pop from an empty set")
+ member = self.col.pop()
+ return self._get(member)
+
+ def update(self, other):
+ for value in other:
+ self.add(value)
+
+ def _bulk_replace(self, assoc_proxy, values):
+ existing = set(self)
+ constants = existing.intersection(values or ())
+ additions = set(values or ()).difference(constants)
+ removals = existing.difference(constants)
+
+ appender = self.add
+ remover = self.remove
+
+ for member in values or ():
+ if member in additions:
+ appender(member)
+ elif member in constants:
+ appender(member)
+
+ for member in removals:
+ remover(member)
+
+ def __ior__(self, other):
+ if not collections._set_binops_check_strict(self, other):
+ return NotImplemented
+ for value in other:
+ self.add(value)
+ return self
+
+ def _set(self):
+ return set(iter(self))
+
+ def union(self, other):
+ return set(self).union(other)
+
+ __or__ = union
+
+ def difference(self, other):
+ return set(self).difference(other)
+
+ __sub__ = difference
+
+ def difference_update(self, other):
+ for value in other:
+ self.discard(value)
+
+ def __isub__(self, other):
+ if not collections._set_binops_check_strict(self, other):
+ return NotImplemented
+ for value in other:
+ self.discard(value)
+ return self
+
+ def intersection(self, other):
+ return set(self).intersection(other)
+
+ __and__ = intersection
+
+ def intersection_update(self, other):
+ want, have = self.intersection(other), set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+
+ def __iand__(self, other):
+ if not collections._set_binops_check_strict(self, other):
+ return NotImplemented
+ want, have = self.intersection(other), set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+ return self
+
+ def symmetric_difference(self, other):
+ return set(self).symmetric_difference(other)
+
+ __xor__ = symmetric_difference
+
+ def symmetric_difference_update(self, other):
+ want, have = self.symmetric_difference(other), set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+
+ def __ixor__(self, other):
+ if not collections._set_binops_check_strict(self, other):
+ return NotImplemented
+ want, have = self.symmetric_difference(other), set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+ return self
+
+ def issubset(self, other):
+ return set(self).issubset(other)
+
+ def issuperset(self, other):
+ return set(self).issuperset(other)
+
+ def clear(self):
+ self.col.clear()
+
+ def copy(self):
+ return set(self)
+
+ def __eq__(self, other):
+ return set(self) == other
+
+ def __ne__(self, other):
+ return set(self) != other
+
+ def __lt__(self, other):
+ return set(self) < other
+
+ def __le__(self, other):
+ return set(self) <= other
+
+ def __gt__(self, other):
+ return set(self) > other
+
+ def __ge__(self, other):
+ return set(self) >= other
+
+ def __repr__(self):
+ return repr(set(self))
+
+ def __hash__(self):
+ raise TypeError("%s objects are unhashable" % type(self).__name__)
+
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(set, func_name)
+ ):
+ func.__doc__ = getattr(set, func_name).__doc__
+ del func_name, func
diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py
new file mode 100644
index 0000000..15b2cb0
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/__init__.py
@@ -0,0 +1,22 @@
+# ext/asyncio/__init__.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .engine import async_engine_from_config
+from .engine import AsyncConnection
+from .engine import AsyncEngine
+from .engine import AsyncTransaction
+from .engine import create_async_engine
+from .events import AsyncConnectionEvents
+from .events import AsyncSessionEvents
+from .result import AsyncMappingResult
+from .result import AsyncResult
+from .result import AsyncScalarResult
+from .scoping import async_scoped_session
+from .session import async_object_session
+from .session import async_session
+from .session import AsyncSession
+from .session import AsyncSessionTransaction
diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py
new file mode 100644
index 0000000..3f77f55
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/base.py
@@ -0,0 +1,89 @@
+import abc
+import functools
+import weakref
+
+from . import exc as async_exc
+
+
+class ReversibleProxy:
+ # weakref.ref(async proxy object) -> weakref.ref(sync proxied object)
+ _proxy_objects = {}
+ __slots__ = ("__weakref__",)
+
+ def _assign_proxied(self, target):
+ if target is not None:
+ target_ref = weakref.ref(target, ReversibleProxy._target_gced)
+ proxy_ref = weakref.ref(
+ self,
+ functools.partial(ReversibleProxy._target_gced, target_ref),
+ )
+ ReversibleProxy._proxy_objects[target_ref] = proxy_ref
+
+ return target
+
+ @classmethod
+ def _target_gced(cls, ref, proxy_ref=None):
+ cls._proxy_objects.pop(ref, None)
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ raise NotImplementedError()
+
+ @classmethod
+ def _retrieve_proxy_for_target(cls, target, regenerate=True):
+ try:
+ proxy_ref = cls._proxy_objects[weakref.ref(target)]
+ except KeyError:
+ pass
+ else:
+ proxy = proxy_ref()
+ if proxy is not None:
+ return proxy
+
+ if regenerate:
+ return cls._regenerate_proxy_for_target(target)
+ else:
+ return None
+
+
+class StartableContext(abc.ABC):
+ __slots__ = ()
+
+ @abc.abstractmethod
+ async def start(self, is_ctxmanager=False):
+ pass
+
+ def __await__(self):
+ return self.start().__await__()
+
+ async def __aenter__(self):
+ return await self.start(is_ctxmanager=True)
+
+ @abc.abstractmethod
+ async def __aexit__(self, type_, value, traceback):
+ pass
+
+ def _raise_for_not_started(self):
+ raise async_exc.AsyncContextNotStarted(
+ "%s context has not been started and object has not been awaited."
+ % (self.__class__.__name__)
+ )
+
+
+class ProxyComparable(ReversibleProxy):
+ __slots__ = ()
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, self.__class__)
+ and self._proxied == other._proxied
+ )
+
+ def __ne__(self, other):
+ return (
+ not isinstance(other, self.__class__)
+ or self._proxied != other._proxied
+ )
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
new file mode 100644
index 0000000..4fbe4f7
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -0,0 +1,828 @@
+# ext/asyncio/engine.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+import asyncio
+
+from . import exc as async_exc
+from .base import ProxyComparable
+from .base import StartableContext
+from .result import _ensure_sync_result
+from .result import AsyncResult
+from ... import exc
+from ... import inspection
+from ... import util
+from ...engine import create_engine as _create_engine
+from ...engine.base import NestedTransaction
+from ...future import Connection
+from ...future import Engine
+from ...util.concurrency import greenlet_spawn
+
+
+def create_async_engine(*arg, **kw):
+ """Create a new async engine instance.
+
+ Arguments passed to :func:`_asyncio.create_async_engine` are mostly
+ identical to those passed to the :func:`_sa.create_engine` function.
+ The specified dialect must be an asyncio-compatible dialect
+ such as :ref:`dialect-postgresql-asyncpg`.
+
+ .. versionadded:: 1.4
+
+ """
+
+ if kw.get("server_side_cursors", False):
+ raise async_exc.AsyncMethodRequired(
+ "Can't set server_side_cursors for async engine globally; "
+ "use the connection.stream() method for an async "
+ "streaming result set"
+ )
+ kw["future"] = True
+ sync_engine = _create_engine(*arg, **kw)
+ return AsyncEngine(sync_engine)
+
+
+def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
+ """Create a new AsyncEngine instance using a configuration dictionary.
+
+ This function is analogous to the :func:`_sa.engine_from_config` function
+ in SQLAlchemy Core, except that the requested dialect must be an
+ asyncio-compatible dialect such as :ref:`dialect-postgresql-asyncpg`.
+ The argument signature of the function is identical to that
+ of :func:`_sa.engine_from_config`.
+
+ .. versionadded:: 1.4.29
+
+ """
+ options = {
+ key[len(prefix) :]: value
+ for key, value in configuration.items()
+ if key.startswith(prefix)
+ }
+ options["_coerce_config"] = True
+ options.update(kwargs)
+ url = options.pop("url")
+ return create_async_engine(url, **options)
+
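+ # Hypothetical usage sketch; the URL and keys shown are assumptions:
+ #
+ #     cfg = {
+ #         "sqlalchemy.url": "sqlite+aiosqlite:///./demo.db",
+ #         "sqlalchemy.echo": "true",
+ #     }
+ #     engine = async_engine_from_config(cfg)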
+
+class AsyncConnectable:
+ __slots__ = "_slots_dispatch", "__weakref__"
+
+
+@util.create_proxy_methods(
+ Connection,
+ ":class:`_future.Connection`",
+ ":class:`_asyncio.AsyncConnection`",
+ classmethods=[],
+ methods=[],
+ attributes=[
+ "closed",
+ "invalidated",
+ "dialect",
+ "default_isolation_level",
+ ],
+)
+class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
+ """An asyncio proxy for a :class:`_engine.Connection`.
+
+ :class:`_asyncio.AsyncConnection` is acquired using the
+ :meth:`_asyncio.AsyncEngine.connect`
+ method of :class:`_asyncio.AsyncEngine`::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
+
+ async with engine.connect() as conn:
+ result = await conn.execute(select(table))
+
+ .. versionadded:: 1.4
+
+ """ # noqa
+
+ # AsyncConnection is a thin proxy; no state should be added here
+ # that is not retrievable from the "sync" engine / connection, e.g.
+ # current transaction, info, etc. It should be possible to
+ # create a new AsyncConnection that matches this one given only the
+ # "sync" elements.
+ __slots__ = (
+ "engine",
+ "sync_engine",
+ "sync_connection",
+ )
+
+ def __init__(self, async_engine, sync_connection=None):
+ self.engine = async_engine
+ self.sync_engine = async_engine.sync_engine
+ self.sync_connection = self._assign_proxied(sync_connection)
+
+ sync_connection: Connection
+ """Reference to the sync-style :class:`_engine.Connection` this
+ :class:`_asyncio.AsyncConnection` proxies requests towards.
+
+ This instance can be used as an event target.
+
+ .. seealso::
+
+ :ref:`asyncio_events`
+ """
+
+ sync_engine: Engine
+ """Reference to the sync-style :class:`_engine.Engine` this
+ :class:`_asyncio.AsyncConnection` is associated with via its underlying
+ :class:`_engine.Connection`.
+
+ This instance can be used as an event target.
+
+ .. seealso::
+
+ :ref:`asyncio_events`
+ """
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ return AsyncConnection(
+ AsyncEngine._retrieve_proxy_for_target(target.engine), target
+ )
+
+ async def start(self, is_ctxmanager=False):
+ """Start this :class:`_asyncio.AsyncConnection` object's context
+ outside of using a Python ``with:`` block.
+
+ """
+ if self.sync_connection:
+ raise exc.InvalidRequestError("connection is already started")
+ self.sync_connection = self._assign_proxied(
+ await (greenlet_spawn(self.sync_engine.connect))
+ )
+ return self
+
+ @property
+ def connection(self):
+ """Not implemented for async; call
+ :meth:`_asyncio.AsyncConnection.get_raw_connection`.
+ """
+ raise exc.InvalidRequestError(
+ "AsyncConnection.connection accessor is not implemented as the "
+ "attribute may need to reconnect on an invalidated connection. "
+ "Use the get_raw_connection() method."
+ )
+
+ async def get_raw_connection(self):
+ """Return the pooled DBAPI-level connection in use by this
+ :class:`_asyncio.AsyncConnection`.
+
+ This is a SQLAlchemy connection-pool proxied connection
+ which then has the attribute
+ :attr:`_pool._ConnectionFairy.driver_connection` that refers to the
+ actual driver connection. Its
+ :attr:`_pool._ConnectionFairy.dbapi_connection` refers instead
+ to an :class:`_engine.AdaptedConnection` instance that
+ adapts the driver connection to the DBAPI protocol.
+
+ """
+ conn = self._sync_connection()
+
+ return await greenlet_spawn(getattr, conn, "connection")
+
+ @property
+ def _proxied(self):
+ return self.sync_connection
+
+ @property
+ def info(self):
+ """Return the :attr:`_engine.Connection.info` dictionary of the
+ underlying :class:`_engine.Connection`.
+
+ This dictionary is freely writable for user-defined state to be
+ associated with the database connection.
+
+ This attribute is only available if the :class:`.AsyncConnection` is
+ currently connected. If the :attr:`.AsyncConnection.closed` attribute
+ is ``True``, then accessing this attribute will raise
+ :class:`.ResourceClosedError`.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ return self.sync_connection.info
+
+ def _sync_connection(self):
+ if not self.sync_connection:
+ self._raise_for_not_started()
+ return self.sync_connection
+
+ def begin(self):
+ """Begin a transaction prior to autobegin occurring."""
+ self._sync_connection()
+ return AsyncTransaction(self)
+
+ def begin_nested(self):
+ """Begin a nested transaction and return a transaction handle."""
+ self._sync_connection()
+ return AsyncTransaction(self, nested=True)
+
+ async def invalidate(self, exception=None):
+ """Invalidate the underlying DBAPI connection associated with
+ this :class:`_engine.Connection`.
+
+ See the method :meth:`_engine.Connection.invalidate` for full
+ detail on this method.
+
+ """
+
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.invalidate, exception=exception)
+
+ async def get_isolation_level(self):
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.get_isolation_level)
+
+ async def set_isolation_level(self):
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.get_isolation_level)
+
+ def in_transaction(self):
+ """Return True if a transaction is in progress.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+
+ conn = self._sync_connection()
+
+ return conn.in_transaction()
+
+ def in_nested_transaction(self):
+ """Return True if a transaction is in progress.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ conn = self._sync_connection()
+
+ return conn.in_nested_transaction()
+
+ def get_transaction(self):
+ """Return an :class:`.AsyncTransaction` representing the current
+ transaction, if any.
+
+ This makes use of the underlying synchronous connection's
+ :meth:`_engine.Connection.get_transaction` method to get the current
+ :class:`_engine.Transaction`, which is then proxied in a new
+ :class:`.AsyncTransaction` object.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ conn = self._sync_connection()
+
+ trans = conn.get_transaction()
+ if trans is not None:
+ return AsyncTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
+ def get_nested_transaction(self):
+ """Return an :class:`.AsyncTransaction` representing the current
+ nested (savepoint) transaction, if any.
+
+ This makes use of the underlying synchronous connection's
+ :meth:`_engine.Connection.get_nested_transaction` method to get the
+ current :class:`_engine.Transaction`, which is then proxied in a new
+ :class:`.AsyncTransaction` object.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ conn = self._sync_connection()
+
+ trans = conn.get_nested_transaction()
+ if trans is not None:
+ return AsyncTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
+ async def execution_options(self, **opt):
+ r"""Set non-SQL options for the connection which take effect
+ during execution.
+
+ This returns this :class:`_asyncio.AsyncConnection` object with
+ the new options added.
+
+ See :meth:`_future.Connection.execution_options` for full details
+ on this method.
+
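+        E.g., a minimal sketch; ``engine`` is an assumed
+        :class:`_asyncio.AsyncEngine` and the ``"AUTOCOMMIT"`` isolation
+        level assumes a backend / driver that supports it::
+
+            conn = await engine.connect()
+            conn = await conn.execution_options(isolation_level="AUTOCOMMIT")
+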
+ """
+
+ conn = self._sync_connection()
+ c2 = await greenlet_spawn(conn.execution_options, **opt)
+ assert c2 is conn
+ return self
+
+ async def commit(self):
+ """Commit the transaction that is currently in progress.
+
+ This method commits the current transaction if one has been started.
+ If no transaction was started, the method has no effect, assuming
+ the connection is in a non-invalidated state.
+
+ A transaction is begun on a :class:`_future.Connection` automatically
+ whenever a statement is first executed, or when the
+ :meth:`_future.Connection.begin` method is called.
+
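+        E.g., a minimal "commit as you go" sketch; ``engine`` and
+        ``user_table`` are illustrative assumptions::
+
+            async with engine.connect() as conn:
+                await conn.execute(
+                    user_table.insert().values(name="some name")
+                )
+                await conn.commit()
+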
+ """
+ conn = self._sync_connection()
+ await greenlet_spawn(conn.commit)
+
+ async def rollback(self):
+ """Roll back the transaction that is currently in progress.
+
+ This method rolls back the current transaction if one has been started.
+ If no transaction was started, the method has no effect. If a
+ transaction was started and the connection is in an invalidated state,
+ the transaction is cleared using this method.
+
+ A transaction is begun on a :class:`_future.Connection` automatically
+ whenever a statement is first executed, or when the
+ :meth:`_future.Connection.begin` method is called.
+
+
+ """
+ conn = self._sync_connection()
+ await greenlet_spawn(conn.rollback)
+
+ async def close(self):
+ """Close this :class:`_asyncio.AsyncConnection`.
+
+ This has the effect of also rolling back the transaction if one
+ is in place.
+
+ """
+ conn = self._sync_connection()
+ await greenlet_spawn(conn.close)
+
+ async def exec_driver_sql(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+        r"""Execute a driver-level SQL string and return a buffered
+        :class:`_engine.Result`.
+
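+        E.g., a minimal sketch; the SQL string is illustrative only and
+        must use the paramstyle of the driver in use::
+
+            result = await conn.exec_driver_sql("SELECT 42")
+            value = result.scalar()
+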
+ """
+
+ conn = self._sync_connection()
+
+ result = await greenlet_spawn(
+ conn.exec_driver_sql,
+ statement,
+ parameters,
+ execution_options,
+ _require_await=True,
+ )
+
+ return await _ensure_sync_result(result, self.exec_driver_sql)
+
+ async def stream(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+ """Execute a statement and return a streaming
+ :class:`_asyncio.AsyncResult` object."""
+
+ conn = self._sync_connection()
+
+ result = await greenlet_spawn(
+ conn._execute_20,
+ statement,
+ parameters,
+ util.EMPTY_DICT.merge_with(
+ execution_options, {"stream_results": True}
+ ),
+ _require_await=True,
+ )
+ if not result.context._is_server_side:
+ # TODO: real exception here
+ assert False, "server side result expected"
+ return AsyncResult(result)
+
+ async def execute(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+        r"""Execute a SQL statement construct and return a buffered
+ :class:`_engine.Result`.
+
+        :param statement: The statement to be executed. This is always
+ an object that is in both the :class:`_expression.ClauseElement` and
+ :class:`_expression.Executable` hierarchies, including:
+
+ * :class:`_expression.Select`
+ * :class:`_expression.Insert`, :class:`_expression.Update`,
+ :class:`_expression.Delete`
+ * :class:`_expression.TextClause` and
+ :class:`_expression.TextualSelect`
+ * :class:`_schema.DDL` and objects which inherit from
+ :class:`_schema.DDLElement`
+
+ :param parameters: parameters which will be bound into the statement.
+ This may be either a dictionary of parameter names to values,
+ or a mutable sequence (e.g. a list) of dictionaries. When a
+ list of dictionaries is passed, the underlying statement execution
+ will make use of the DBAPI ``cursor.executemany()`` method.
+ When a single dictionary is passed, the DBAPI ``cursor.execute()``
+ method will be used.
+
+ :param execution_options: optional dictionary of execution options,
+ which will be associated with the statement execution. This
+ dictionary can provide a subset of the options that are accepted
+ by :meth:`_future.Connection.execution_options`.
+
+ :return: a :class:`_engine.Result` object.
+
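+        E.g., a minimal sketch; ``engine`` and ``user_table`` are
+        illustrative assumptions::
+
+            from sqlalchemy import select
+
+            async with engine.connect() as conn:
+                result = await conn.execute(
+                    select(user_table).where(user_table.c.id == 5)
+                )
+                row = result.first()
+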
+ """
+ conn = self._sync_connection()
+
+ result = await greenlet_spawn(
+ conn._execute_20,
+ statement,
+ parameters,
+ execution_options,
+ _require_await=True,
+ )
+ return await _ensure_sync_result(result, self.execute)
+
+ async def scalar(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+        r"""Execute a SQL statement construct and return a scalar value.
+
+ This method is shorthand for invoking the
+ :meth:`_engine.Result.scalar` method after invoking the
+ :meth:`_future.Connection.execute` method. Parameters are equivalent.
+
+ :return: a scalar Python value representing the first column of the
+ first row returned.
+
+ """
+ result = await self.execute(statement, parameters, execution_options)
+ return result.scalar()
+
+ async def scalars(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+        r"""Execute a SQL statement construct and return scalar results.
+
+ This method is shorthand for invoking the
+ :meth:`_engine.Result.scalars` method after invoking the
+ :meth:`_future.Connection.execute` method. Parameters are equivalent.
+
+ :return: a :class:`_engine.ScalarResult` object.
+
+ .. versionadded:: 1.4.24
+
+ """
+ result = await self.execute(statement, parameters, execution_options)
+ return result.scalars()
+
+ async def stream_scalars(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+        r"""Execute a SQL statement and return a streaming scalar result
+        object.
+
+ This method is shorthand for invoking the
+ :meth:`_engine.AsyncResult.scalars` method after invoking the
+ :meth:`_future.Connection.stream` method. Parameters are equivalent.
+
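+        E.g., a minimal sketch; ``user_table`` is an illustrative
+        :class:`_schema.Table`::
+
+            result = await conn.stream_scalars(
+                select(user_table.c.name)
+            )
+            async for name in result:
+                print(name)
+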
+ :return: an :class:`_asyncio.AsyncScalarResult` object.
+
+ .. versionadded:: 1.4.24
+
+ """
+ result = await self.stream(statement, parameters, execution_options)
+ return result.scalars()
+
+ async def run_sync(self, fn, *arg, **kw):
+ """Invoke the given sync callable passing self as the first argument.
+
+ This method maintains the asyncio event loop all the way through
+ to the database connection by running the given callable in a
+ specially instrumented greenlet.
+
+ E.g.::
+
+ with async_engine.begin() as conn:
+ await conn.run_sync(metadata.create_all)
+
+ .. note::
+
+ The provided callable is invoked inline within the asyncio event
+ loop, and will block on traditional IO calls. IO within this
+ callable should only call into SQLAlchemy's asyncio database
+ APIs which will be properly adapted to the greenlet context.
+
+ .. seealso::
+
+ :ref:`session_run_sync`
+ """
+
+ conn = self._sync_connection()
+
+ return await greenlet_spawn(fn, conn, *arg, **kw)
+
+ def __await__(self):
+ return self.start().__await__()
+
+ async def __aexit__(self, type_, value, traceback):
+ await asyncio.shield(self.close())
+
+
+@util.create_proxy_methods(
+ Engine,
+ ":class:`_future.Engine`",
+ ":class:`_asyncio.AsyncEngine`",
+ classmethods=[],
+ methods=[
+ "clear_compiled_cache",
+ "update_execution_options",
+ "get_execution_options",
+ ],
+ attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
+)
+class AsyncEngine(ProxyComparable, AsyncConnectable):
+ """An asyncio proxy for a :class:`_engine.Engine`.
+
+ :class:`_asyncio.AsyncEngine` is acquired using the
+ :func:`_asyncio.create_async_engine` function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
+
+ .. versionadded:: 1.4
+
+ """ # noqa
+
+ # AsyncEngine is a thin proxy; no state should be added here
+ # that is not retrievable from the "sync" engine / connection, e.g.
+ # current transaction, info, etc. It should be possible to
+ # create a new AsyncEngine that matches this one given only the
+ # "sync" elements.
+ __slots__ = ("sync_engine", "_proxied")
+
+ _connection_cls = AsyncConnection
+
+ _option_cls: type
+
+ class _trans_ctx(StartableContext):
+ def __init__(self, conn):
+ self.conn = conn
+
+ async def start(self, is_ctxmanager=False):
+ await self.conn.start(is_ctxmanager=is_ctxmanager)
+ self.transaction = self.conn.begin()
+ await self.transaction.__aenter__()
+
+ return self.conn
+
+ async def __aexit__(self, type_, value, traceback):
+ async def go():
+ await self.transaction.__aexit__(type_, value, traceback)
+ await self.conn.close()
+
+ await asyncio.shield(go())
+
+ def __init__(self, sync_engine):
+ if not sync_engine.dialect.is_async:
+ raise exc.InvalidRequestError(
+ "The asyncio extension requires an async driver to be used. "
+ f"The loaded {sync_engine.dialect.driver!r} is not async."
+ )
+ self.sync_engine = self._proxied = self._assign_proxied(sync_engine)
+
+ sync_engine: Engine
+ """Reference to the sync-style :class:`_engine.Engine` this
+ :class:`_asyncio.AsyncEngine` proxies requests towards.
+
+ This instance can be used as an event target.
+
+ .. seealso::
+
+ :ref:`asyncio_events`
+ """
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ return AsyncEngine(target)
+
+ def begin(self):
+ """Return a context manager which when entered will deliver an
+ :class:`_asyncio.AsyncConnection` with an
+ :class:`_asyncio.AsyncTransaction` established.
+
+ E.g.::
+
+ async with async_engine.begin() as conn:
+ await conn.execute(
+ text("insert into table (x, y, z) values (1, 2, 3)")
+ )
+ await conn.execute(text("my_special_procedure(5)"))
+
+
+ """
+ conn = self.connect()
+ return self._trans_ctx(conn)
+
+ def connect(self):
+ """Return an :class:`_asyncio.AsyncConnection` object.
+
+ The :class:`_asyncio.AsyncConnection` will procure a database
+ connection from the underlying connection pool when it is entered
+ as an async context manager::
+
+ async with async_engine.connect() as conn:
+ result = await conn.execute(select(user_table))
+
+ The :class:`_asyncio.AsyncConnection` may also be started outside of a
+ context manager by invoking its :meth:`_asyncio.AsyncConnection.start`
+ method.
+
+ """
+
+ return self._connection_cls(self)
+
+ async def raw_connection(self):
+ """Return a "raw" DBAPI connection from the connection pool.
+
+ .. seealso::
+
+ :ref:`dbapi_connections`
+
+ """
+ return await greenlet_spawn(self.sync_engine.raw_connection)
+
+ def execution_options(self, **opt):
+ """Return a new :class:`_asyncio.AsyncEngine` that will provide
+ :class:`_asyncio.AsyncConnection` objects with the given execution
+ options.
+
+ Proxied from :meth:`_future.Engine.execution_options`. See that
+ method for details.
+
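+        E.g., a minimal sketch; the ``isolation_level`` value assumes a
+        backend that supports it::
+
+            autocommit_engine = engine.execution_options(
+                isolation_level="AUTOCOMMIT"
+            )
+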
+ """
+
+ return AsyncEngine(self.sync_engine.execution_options(**opt))
+
+ async def dispose(self):
+ """Dispose of the connection pool used by this
+ :class:`_asyncio.AsyncEngine`.
+
+ This will close all connection pool connections that are
+ **currently checked in**. See the documentation for the underlying
+ :meth:`_future.Engine.dispose` method for further notes.
+
+ .. seealso::
+
+ :meth:`_future.Engine.dispose`
+
+ """
+
+ await greenlet_spawn(self.sync_engine.dispose)
+
+
+class AsyncTransaction(ProxyComparable, StartableContext):
+ """An asyncio proxy for a :class:`_engine.Transaction`."""
+
+ __slots__ = ("connection", "sync_transaction", "nested")
+
+ def __init__(self, connection, nested=False):
+ self.connection = connection # AsyncConnection
+ self.sync_transaction = None # sqlalchemy.engine.Transaction
+ self.nested = nested
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ sync_connection = target.connection
+ sync_transaction = target
+ nested = isinstance(target, NestedTransaction)
+
+ async_connection = AsyncConnection._retrieve_proxy_for_target(
+ sync_connection
+ )
+ assert async_connection is not None
+
+ obj = cls.__new__(cls)
+ obj.connection = async_connection
+ obj.sync_transaction = obj._assign_proxied(sync_transaction)
+ obj.nested = nested
+ return obj
+
+ def _sync_transaction(self):
+ if not self.sync_transaction:
+ self._raise_for_not_started()
+ return self.sync_transaction
+
+ @property
+ def _proxied(self):
+ return self.sync_transaction
+
+ @property
+ def is_valid(self):
+ return self._sync_transaction().is_valid
+
+ @property
+ def is_active(self):
+ return self._sync_transaction().is_active
+
+ async def close(self):
+ """Close this :class:`.Transaction`.
+
+ If this transaction is the base transaction in a begin/commit
+        nesting, the transaction will issue a rollback().  Otherwise, the
+        method returns without taking any action.
+
+ This is used to cancel a Transaction without affecting the scope of
+ an enclosing transaction.
+
+ """
+ await greenlet_spawn(self._sync_transaction().close)
+
+ async def rollback(self):
+ """Roll back this :class:`.Transaction`."""
+ await greenlet_spawn(self._sync_transaction().rollback)
+
+ async def commit(self):
+ """Commit this :class:`.Transaction`."""
+
+ await greenlet_spawn(self._sync_transaction().commit)
+
+ async def start(self, is_ctxmanager=False):
+ """Start this :class:`_asyncio.AsyncTransaction` object's context
+ outside of using a Python ``with:`` block.
+
+ """
+
+ self.sync_transaction = self._assign_proxied(
+ await greenlet_spawn(
+ self.connection._sync_connection().begin_nested
+ if self.nested
+ else self.connection._sync_connection().begin
+ )
+ )
+ if is_ctxmanager:
+ self.sync_transaction.__enter__()
+ return self
+
+ async def __aexit__(self, type_, value, traceback):
+ await greenlet_spawn(
+ self._sync_transaction().__exit__, type_, value, traceback
+ )
+
+
+def _get_sync_engine_or_connection(async_engine):
+ if isinstance(async_engine, AsyncConnection):
+ return async_engine.sync_connection
+
+ try:
+ return async_engine.sync_engine
+ except AttributeError as e:
+ raise exc.ArgumentError(
+ "AsyncEngine expected, got %r" % async_engine
+ ) from e
+
+
+@inspection._inspects(AsyncConnection)
+def _no_insp_for_async_conn_yet(subject):
+ raise exc.NoInspectionAvailable(
+ "Inspection on an AsyncConnection is currently not supported. "
+ "Please use ``run_sync`` to pass a callable where it's possible "
+ "to call ``inspect`` on the passed connection.",
+ code="xd3s",
+ )
+
+
+@inspection._inspects(AsyncEngine)
+def _no_insp_for_async_engine_yet(subject):
+ raise exc.NoInspectionAvailable(
+ "Inspection on an AsyncEngine is currently not supported. "
+ "Please obtain a connection then use ``conn.run_sync`` to pass a "
+ "callable where it's possible to call ``inspect`` on the "
+ "passed connection.",
+ code="xd3s",
+ )
diff --git a/lib/sqlalchemy/ext/asyncio/events.py b/lib/sqlalchemy/ext/asyncio/events.py
new file mode 100644
index 0000000..c5d5e01
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/events.py
@@ -0,0 +1,44 @@
+# ext/asyncio/events.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .engine import AsyncConnectable
+from .session import AsyncSession
+from ...engine import events as engine_event
+from ...orm import events as orm_event
+
+
+class AsyncConnectionEvents(engine_event.ConnectionEvents):
+ _target_class_doc = "SomeEngine"
+ _dispatch_target = AsyncConnectable
+
+ @classmethod
+ def _no_async_engine_events(cls):
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncEngine.sync_engine or "
+ "AsyncConnection.sync_connection attributes."
+ )
+
+ @classmethod
+ def _listen(cls, event_key, retval=False):
+ cls._no_async_engine_events()
+
+
+class AsyncSessionEvents(orm_event.SessionEvents):
+ _target_class_doc = "SomeSession"
+ _dispatch_target = AsyncSession
+
+ @classmethod
+ def _no_async_engine_events(cls):
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncSession.sync_session."
+ )
+
+ @classmethod
+ def _listen(cls, event_key, retval=False):
+ cls._no_async_engine_events()
diff --git a/lib/sqlalchemy/ext/asyncio/exc.py b/lib/sqlalchemy/ext/asyncio/exc.py
new file mode 100644
index 0000000..cf0d9a8
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/exc.py
@@ -0,0 +1,21 @@
+# ext/asyncio/exc.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from ... import exc
+
+
+class AsyncMethodRequired(exc.InvalidRequestError):
+ """an API can't be used because its result would not be
+ compatible with async"""
+
+
+class AsyncContextNotStarted(exc.InvalidRequestError):
+ """a startable context manager has not been started."""
+
+
+class AsyncContextAlreadyStarted(exc.InvalidRequestError):
+ """a startable context manager is already started."""
diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py
new file mode 100644
index 0000000..a77b6a8
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/result.py
@@ -0,0 +1,671 @@
+# ext/asyncio/result.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import operator
+
+from . import exc as async_exc
+from ...engine.result import _NO_ROW
+from ...engine.result import FilterResult
+from ...engine.result import FrozenResult
+from ...engine.result import MergedResult
+from ...sql.base import _generative
+from ...util.concurrency import greenlet_spawn
+
+
+class AsyncCommon(FilterResult):
+ async def close(self):
+ """Close this result."""
+
+ await greenlet_spawn(self._real_result.close)
+
+
+class AsyncResult(AsyncCommon):
+ """An asyncio wrapper around a :class:`_result.Result` object.
+
+ The :class:`_asyncio.AsyncResult` only applies to statement executions that
+ use a server-side cursor. It is returned only from the
+ :meth:`_asyncio.AsyncConnection.stream` and
+ :meth:`_asyncio.AsyncSession.stream` methods.
+
+ .. note:: As is the case with :class:`_engine.Result`, this object is
+ used for ORM results returned by :meth:`_asyncio.AsyncSession.execute`,
+ which can yield instances of ORM mapped objects either individually or
+ within tuple-like rows. Note that these result objects do not
+ deduplicate instances or rows automatically as is the case with the
+ legacy :class:`_orm.Query` object. For in-Python de-duplication of
+ instances or rows, use the :meth:`_asyncio.AsyncResult.unique` modifier
+ method.
+
+ .. versionadded:: 1.4
+
+ """
+
+ def __init__(self, real_result):
+ self._real_result = real_result
+
+ self._metadata = real_result._metadata
+ self._unique_filter_state = real_result._unique_filter_state
+
+ # BaseCursorResult pre-generates the "_row_getter". Use that
+ # if available rather than building a second one
+ if "_row_getter" in real_result.__dict__:
+ self._set_memoized_attribute(
+ "_row_getter", real_result.__dict__["_row_getter"]
+ )
+
+ def keys(self):
+ """Return the :meth:`_engine.Result.keys` collection from the
+ underlying :class:`_engine.Result`.
+
+ """
+ return self._metadata.keys
+
+ @_generative
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_asyncio.AsyncResult`.
+
+ Refer to :meth:`_engine.Result.unique` in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+
+ """
+ self._unique_filter_state = (set(), strategy)
+
+ def columns(self, *col_expressions):
+ r"""Establish the columns that should be returned in each row.
+
+ Refer to :meth:`_engine.Result.columns` in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+
+ """
+ return self._column_slices(col_expressions)
+
+ async def partitions(self, size=None):
+ """Iterate through sub-lists of rows of the size given.
+
+ An async iterator is returned::
+
+ async def scroll_results(connection):
+ result = await connection.stream(select(users_table))
+
+ async for partition in result.partitions(100):
+ print("list of rows: %s" % partition)
+
+ .. seealso::
+
+ :meth:`_engine.Result.partitions`
+
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = await greenlet_spawn(getter, self, size)
+ if partition:
+ yield partition
+ else:
+ break
+
+ async def fetchone(self):
+ """Fetch one row.
+
+ When all rows are exhausted, returns None.
+
+ This method is provided for backwards compatibility with
+ SQLAlchemy 1.x.x.
+
+ To fetch the first row of a result only, use the
+ :meth:`_engine.Result.first` method. To iterate through all
+ rows, iterate the :class:`_engine.Result` object directly.
+
+ :return: a :class:`.Row` object if no filters are applied, or None
+ if no rows remain.
+
+ """
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ return None
+ else:
+ return row
+
+ async def fetchmany(self, size=None):
+ """Fetch many rows.
+
+ When all rows are exhausted, returns an empty list.
+
+ This method is provided for backwards compatibility with
+ SQLAlchemy 1.x.x.
+
+ To fetch rows in groups, use the
+ :meth:`._asyncio.AsyncResult.partitions` method.
+
+ :return: a list of :class:`.Row` objects.
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.partitions`
+
+ """
+
+ return await greenlet_spawn(self._manyrow_getter, self, size)
+
+ async def all(self):
+ """Return all rows in a list.
+
+ Closes the result set after invocation. Subsequent invocations
+ will return an empty list.
+
+ :return: a list of :class:`.Row` objects.
+
+ """
+
+ return await greenlet_spawn(self._allrows)
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ raise StopAsyncIteration()
+ else:
+ return row
+
+ async def first(self):
+ """Fetch the first row or None if no row is present.
+
+ Closes the result set and discards remaining rows.
+
+ .. note:: This method returns one **row**, e.g. tuple, by default. To
+ return exactly one single scalar value, that is, the first column of
+ the first row, use the :meth:`_asyncio.AsyncResult.scalar` method,
+ or combine :meth:`_asyncio.AsyncResult.scalars` and
+ :meth:`_asyncio.AsyncResult.first`.
+
+ :return: a :class:`.Row` object, or None
+ if no rows remain.
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.scalar`
+
+ :meth:`_asyncio.AsyncResult.one`
+
+ """
+ return await greenlet_spawn(self._only_one_row, False, False, False)
+
+ async def one_or_none(self):
+ """Return at most one result or raise an exception.
+
+ Returns ``None`` if the result has no rows.
+ Raises :class:`.MultipleResultsFound`
+ if multiple rows are returned.
+
+ .. versionadded:: 1.4
+
+ :return: The first :class:`.Row` or None if no row is available.
+
+ :raises: :class:`.MultipleResultsFound`
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.first`
+
+ :meth:`_asyncio.AsyncResult.one`
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, False, False)
+
+ async def scalar_one(self):
+ """Return exactly one scalar result or raise an exception.
+
+ This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
+ then :meth:`_asyncio.AsyncResult.one`.
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.one`
+
+ :meth:`_asyncio.AsyncResult.scalars`
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, True, True)
+
+ async def scalar_one_or_none(self):
+ """Return exactly one or no scalar result.
+
+ This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
+ then :meth:`_asyncio.AsyncResult.one_or_none`.
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.one_or_none`
+
+ :meth:`_asyncio.AsyncResult.scalars`
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, False, True)
+
+ async def one(self):
+ """Return exactly one row or raise an exception.
+
+ Raises :class:`.NoResultFound` if the result returns no
+ rows, or :class:`.MultipleResultsFound` if multiple rows
+ would be returned.
+
+ .. note:: This method returns one **row**, e.g. tuple, by default.
+ To return exactly one single scalar value, that is, the first
+ column of the first row, use the
+ :meth:`_asyncio.AsyncResult.scalar_one` method, or combine
+ :meth:`_asyncio.AsyncResult.scalars` and
+ :meth:`_asyncio.AsyncResult.one`.
+
+ .. versionadded:: 1.4
+
+ :return: The first :class:`.Row`.
+
+ :raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound`
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.first`
+
+ :meth:`_asyncio.AsyncResult.one_or_none`
+
+ :meth:`_asyncio.AsyncResult.scalar_one`
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, True, False)
+
+ async def scalar(self):
+ """Fetch the first column of the first row, and close the result set.
+
+ Returns None if there are no rows to fetch.
+
+ No validation is performed to test if additional rows remain.
+
+ After calling this method, the object is fully closed,
+ e.g. the :meth:`_engine.CursorResult.close`
+ method will have been called.
+
+        :return: a Python scalar value, or None if no rows remain.
+
+ """
+ return await greenlet_spawn(self._only_one_row, False, False, True)
+
+ async def freeze(self):
+ """Return a callable object that will produce copies of this
+ :class:`_asyncio.AsyncResult` when invoked.
+
+ The callable object returned is an instance of
+ :class:`_engine.FrozenResult`.
+
+ This is used for result set caching. The method must be called
+ on the result when it has been unconsumed, and calling the method
+ will consume the result fully. When the :class:`_engine.FrozenResult`
+ is retrieved from a cache, it can be called any number of times where
+ it will produce a new :class:`_engine.Result` object each time
+ against its stored set of rows.
+
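+        E.g., a minimal sketch of re-using a frozen result::
+
+            frozen = await result.freeze()
+
+            # each invocation delivers a new result object against the
+            # same stored set of rows
+            r1 = frozen()
+            r2 = frozen()
+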
+ .. seealso::
+
+ :ref:`do_orm_execute_re_executing` - example usage within the
+ ORM to implement a result-set cache.
+
+ """
+
+ return await greenlet_spawn(FrozenResult, self)
+
+ def merge(self, *others):
+ """Merge this :class:`_asyncio.AsyncResult` with other compatible
+ result objects.
+
+ The object returned is an instance of :class:`_engine.MergedResult`,
+ which will be composed of iterators from the given result
+ objects.
+
+ The new result will use the metadata from this result object.
+ The subsequent result objects must be against an identical
+ set of result / cursor metadata, otherwise the behavior is
+ undefined.
+
+ """
+ return MergedResult(self._metadata, (self,) + others)
+
+ def scalars(self, index=0):
+ """Return an :class:`_asyncio.AsyncScalarResult` filtering object which
+ will return single elements rather than :class:`_row.Row` objects.
+
+ Refer to :meth:`_result.Result.scalars` in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+ :param index: integer or row key indicating the column to be fetched
+ from each row, defaults to ``0`` indicating the first column.
+
+ :return: a new :class:`_asyncio.AsyncScalarResult` filtering object
+ referring to this :class:`_asyncio.AsyncResult` object.
+
+ """
+ return AsyncScalarResult(self._real_result, index)
+
+ def mappings(self):
+ """Apply a mappings filter to returned rows, returning an instance of
+ :class:`_asyncio.AsyncMappingResult`.
+
+ When this filter is applied, fetching rows will return
+ :class:`.RowMapping` objects instead of :class:`.Row` objects.
+
+ Refer to :meth:`_result.Result.mappings` in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+ :return: a new :class:`_asyncio.AsyncMappingResult` filtering object
+ referring to the underlying :class:`_result.Result` object.
+
+ """
+
+ return AsyncMappingResult(self._real_result)
+
+
+class AsyncScalarResult(AsyncCommon):
+ """A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
+ rather than :class:`_row.Row` values.
+
+ The :class:`_asyncio.AsyncScalarResult` object is acquired by calling the
+ :meth:`_asyncio.AsyncResult.scalars` method.
+
+ Refer to the :class:`_result.ScalarResult` object in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _generate_rows = False
+
+ def __init__(self, real_result, index):
+ self._real_result = real_result
+
+ if real_result._source_supports_scalars:
+ self._metadata = real_result._metadata
+ self._post_creational_filter = None
+ else:
+ self._metadata = real_result._metadata._reduce([index])
+ self._post_creational_filter = operator.itemgetter(0)
+
+ self._unique_filter_state = real_result._unique_filter_state
+
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_asyncio.AsyncScalarResult`.
+
+ See :meth:`_asyncio.AsyncResult.unique` for usage details.
+
+ """
+ self._unique_filter_state = (set(), strategy)
+ return self
+
+ async def partitions(self, size=None):
+ """Iterate through sub-lists of elements of the size given.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = await greenlet_spawn(getter, self, size)
+ if partition:
+ yield partition
+ else:
+ break
+
+ async def fetchall(self):
+ """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method."""
+
+ return await greenlet_spawn(self._allrows)
+
+ async def fetchmany(self, size=None):
+ """Fetch many objects.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._manyrow_getter, self, size)
+
+ async def all(self):
+ """Return all scalar values in a list.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.all` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._allrows)
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ raise StopAsyncIteration()
+ else:
+ return row
+
+ async def first(self):
+ """Fetch the first object or None if no object is present.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.first` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, False, False, False)
+
+ async def one_or_none(self):
+ """Return at most one object or raise an exception.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, False, False)
+
+ async def one(self):
+ """Return exactly one object or raise an exception.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.one` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, True, False)
+
+
+class AsyncMappingResult(AsyncCommon):
+ """A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary
+ values rather than :class:`_engine.Row` values.
+
+ The :class:`_asyncio.AsyncMappingResult` object is acquired by calling the
+ :meth:`_asyncio.AsyncResult.mappings` method.
+
+ Refer to the :class:`_result.MappingResult` object in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _generate_rows = True
+
+ _post_creational_filter = operator.attrgetter("_mapping")
+
+ def __init__(self, result):
+ self._real_result = result
+ self._unique_filter_state = result._unique_filter_state
+ self._metadata = result._metadata
+ if result._source_supports_scalars:
+ self._metadata = self._metadata._reduce([0])
+
+ def keys(self):
+ """Return an iterable view which yields the string keys that would
+ be represented by each :class:`.Row`.
+
+ The view also can be tested for key containment using the Python
+ ``in`` operator, which will test both for the string keys represented
+ in the view, as well as for alternate keys such as column objects.
+
+ .. versionchanged:: 1.4 a key view object is returned rather than a
+ plain list.
+
+
+ """
+ return self._metadata.keys
+
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_asyncio.AsyncMappingResult`.
+
+ See :meth:`_asyncio.AsyncResult.unique` for usage details.
+
+ """
+ self._unique_filter_state = (set(), strategy)
+ return self
+
+ def columns(self, *col_expressions):
+ r"""Establish the columns that should be returned in each row."""
+ return self._column_slices(col_expressions)
+
+ async def partitions(self, size=None):
+ """Iterate through sub-lists of elements of the size given.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = await greenlet_spawn(getter, self, size)
+ if partition:
+ yield partition
+ else:
+ break
+
+ async def fetchall(self):
+ """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
+
+ return await greenlet_spawn(self._allrows)
+
+ async def fetchone(self):
+ """Fetch one object.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ return None
+ else:
+ return row
+
+ async def fetchmany(self, size=None):
+ """Fetch many objects.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ return await greenlet_spawn(self._manyrow_getter, self, size)
+
+ async def all(self):
+        """Return all mapping values in a list.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.all` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ return await greenlet_spawn(self._allrows)
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ raise StopAsyncIteration()
+ else:
+ return row
+
+ async def first(self):
+ """Fetch the first object or None if no object is present.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.first` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+
+ """
+ return await greenlet_spawn(self._only_one_row, False, False, False)
+
+ async def one_or_none(self):
+ """Return at most one object or raise an exception.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, False, False)
+
+ async def one(self):
+ """Return exactly one object or raise an exception.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.one` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, True, False)
+
+
+async def _ensure_sync_result(result, calling_method):
+ if not result._is_cursor:
+ cursor_result = getattr(result, "raw", None)
+ else:
+ cursor_result = result
+ if cursor_result and cursor_result.context._is_server_side:
+ await greenlet_spawn(cursor_result.close)
+ raise async_exc.AsyncMethodRequired(
+ "Can't use the %s.%s() method with a "
+ "server-side cursor. "
+ "Use the %s.stream() method for an async "
+ "streaming result set."
+ % (
+ calling_method.__self__.__class__.__name__,
+ calling_method.__name__,
+ calling_method.__self__.__class__.__name__,
+ )
+ )
+ return result
diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py
new file mode 100644
index 0000000..8eca8c5
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/scoping.py
@@ -0,0 +1,107 @@
+# ext/asyncio/scoping.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .session import AsyncSession
+from ...orm.scoping import ScopedSessionMixin
+from ...util import create_proxy_methods
+from ...util import ScopedRegistry
+
+
+@create_proxy_methods(
+ AsyncSession,
+ ":class:`_asyncio.AsyncSession`",
+ ":class:`_asyncio.scoping.async_scoped_session`",
+ classmethods=["close_all", "object_session", "identity_key"],
+ methods=[
+ "__contains__",
+ "__iter__",
+ "add",
+ "add_all",
+ "begin",
+ "begin_nested",
+ "close",
+ "commit",
+ "connection",
+ "delete",
+ "execute",
+ "expire",
+ "expire_all",
+ "expunge",
+ "expunge_all",
+ "flush",
+ "get",
+ "get_bind",
+ "is_modified",
+ "invalidate",
+ "merge",
+ "refresh",
+ "rollback",
+ "scalar",
+ "scalars",
+ "stream",
+ "stream_scalars",
+ ],
+ attributes=[
+ "bind",
+ "dirty",
+ "deleted",
+ "new",
+ "identity_map",
+ "is_active",
+ "autoflush",
+ "no_autoflush",
+ "info",
+ ],
+)
+class async_scoped_session(ScopedSessionMixin):
+ """Provides scoped management of :class:`.AsyncSession` objects.
+
+ See the section :ref:`asyncio_scoped_session` for usage details.
+
+ .. versionadded:: 1.4.19
+
+
+ """
+
+ _support_async = True
+
+ def __init__(self, session_factory, scopefunc):
+ """Construct a new :class:`_asyncio.async_scoped_session`.
+
+ :param session_factory: a factory to create new :class:`_asyncio.AsyncSession`
+ instances. This is usually, but not necessarily, an instance
+ of :class:`_orm.sessionmaker` which itself was passed the
+ :class:`_asyncio.AsyncSession` to its :paramref:`_orm.sessionmaker.class_`
+ parameter::
+
+            async_session_factory = sessionmaker(some_async_engine, class_=AsyncSession)
+ AsyncSession = async_scoped_session(async_session_factory, scopefunc=current_task)
+
+ :param scopefunc: function which defines
+ the current scope. A function such as ``asyncio.current_task``
+ may be useful here.
+
+ """ # noqa: E501
+
+ self.session_factory = session_factory
+ self.registry = ScopedRegistry(session_factory, scopefunc)
+
+ @property
+ def _proxied(self):
+ return self.registry()
+
+ async def remove(self):
+ """Dispose of the current :class:`.AsyncSession`, if present.
+
+        Unlike scoped_session's remove method, this method is a coroutine
+        and awaits the close method of the current :class:`.AsyncSession`.
+
+ """
+
+ if self.registry.has():
+ await self.registry().close()
+ self.registry.clear()
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
new file mode 100644
index 0000000..378cbcb
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/session.py
@@ -0,0 +1,759 @@
+# ext/asyncio/session.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import asyncio
+
+from . import engine
+from . import result as _result
+from .base import ReversibleProxy
+from .base import StartableContext
+from .result import _ensure_sync_result
+from ... import util
+from ...orm import object_session
+from ...orm import Session
+from ...orm import state as _instance_state
+from ...util.concurrency import greenlet_spawn
+
+_EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True})
+_STREAM_OPTIONS = util.immutabledict({"stream_results": True})
+
+
+@util.create_proxy_methods(
+ Session,
+ ":class:`_orm.Session`",
+ ":class:`_asyncio.AsyncSession`",
+ classmethods=["object_session", "identity_key"],
+ methods=[
+ "__contains__",
+ "__iter__",
+ "add",
+ "add_all",
+ "expire",
+ "expire_all",
+ "expunge",
+ "expunge_all",
+ "is_modified",
+ "in_transaction",
+ "in_nested_transaction",
+ ],
+ attributes=[
+ "dirty",
+ "deleted",
+ "new",
+ "identity_map",
+ "is_active",
+ "autoflush",
+ "no_autoflush",
+ "info",
+ ],
+)
+class AsyncSession(ReversibleProxy):
+ """Asyncio version of :class:`_orm.Session`.
+
+ The :class:`_asyncio.AsyncSession` is a proxy for a traditional
+ :class:`_orm.Session` instance.
+
+ .. versionadded:: 1.4
+
+ To use an :class:`_asyncio.AsyncSession` with custom :class:`_orm.Session`
+ implementations, see the
+ :paramref:`_asyncio.AsyncSession.sync_session_class` parameter.
+
+
+ """
+
+ _is_asyncio = True
+
+ dispatch = None
+
+ def __init__(self, bind=None, binds=None, sync_session_class=None, **kw):
+ r"""Construct a new :class:`_asyncio.AsyncSession`.
+
+ All parameters other than ``sync_session_class`` are passed to the
+ ``sync_session_class`` callable directly to instantiate a new
+ :class:`_orm.Session`. Refer to :meth:`_orm.Session.__init__` for
+ parameter documentation.
+
+ :param sync_session_class:
+ A :class:`_orm.Session` subclass or other callable which will be used
+ to construct the :class:`_orm.Session` which will be proxied. This
+ parameter may be used to provide custom :class:`_orm.Session`
+ subclasses. Defaults to the
+ :attr:`_asyncio.AsyncSession.sync_session_class` class-level
+ attribute.
+
+ .. versionadded:: 1.4.24
+
+ """
+ kw["future"] = True
+ if bind:
+ self.bind = bind
+ bind = engine._get_sync_engine_or_connection(bind)
+
+ if binds:
+ self.binds = binds
+ binds = {
+ key: engine._get_sync_engine_or_connection(b)
+ for key, b in binds.items()
+ }
+
+ if sync_session_class:
+ self.sync_session_class = sync_session_class
+
+ self.sync_session = self._proxied = self._assign_proxied(
+ self.sync_session_class(bind=bind, binds=binds, **kw)
+ )
+
+ sync_session_class = Session
+ """The class or callable that provides the
+ underlying :class:`_orm.Session` instance for a particular
+ :class:`_asyncio.AsyncSession`.
+
+ At the class level, this attribute is the default value for the
+ :paramref:`_asyncio.AsyncSession.sync_session_class` parameter. Custom
+ subclasses of :class:`_asyncio.AsyncSession` can override this.
+
+ At the instance level, this attribute indicates the current class or
+ callable that was used to provide the :class:`_orm.Session` instance for
+ this :class:`_asyncio.AsyncSession` instance.
+
+ .. versionadded:: 1.4.24
+
+ """
+
+ sync_session: Session
+ """Reference to the underlying :class:`_orm.Session` this
+ :class:`_asyncio.AsyncSession` proxies requests towards.
+
+ This instance can be used as an event target.
+
+ .. seealso::
+
+ :ref:`asyncio_events`
+
+ """
+
+ async def refresh(
+ self, instance, attribute_names=None, with_for_update=None
+ ):
+ """Expire and refresh the attributes on the given instance.
+
+ A query will be issued to the database and all attributes will be
+ refreshed with their current database value.
+
+ This is the async version of the :meth:`_orm.Session.refresh` method.
+ See that method for a complete description of all options.
+
+ .. seealso::
+
+ :meth:`_orm.Session.refresh` - main documentation for refresh
+
+ """
+
+ return await greenlet_spawn(
+ self.sync_session.refresh,
+ instance,
+ attribute_names=attribute_names,
+ with_for_update=with_for_update,
+ )
+
+ async def run_sync(self, fn, *arg, **kw):
+ """Invoke the given sync callable passing sync self as the first
+ argument.
+
+ This method maintains the asyncio event loop all the way through
+ to the database connection by running the given callable in a
+ specially instrumented greenlet.
+
+ E.g.::
+
+ with AsyncSession(async_engine) as session:
+ await session.run_sync(some_business_method)
+
+ .. note::
+
+ The provided callable is invoked inline within the asyncio event
+ loop, and will block on traditional IO calls. IO within this
+ callable should only call into SQLAlchemy's asyncio database
+ APIs which will be properly adapted to the greenlet context.
+
+ .. seealso::
+
+ :ref:`session_run_sync`
+ """
+
+ return await greenlet_spawn(fn, self.sync_session, *arg, **kw)
+
+ async def execute(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return a buffered
+ :class:`_engine.Result` object.
+
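+        E.g., a minimal sketch; ``User`` is an illustrative mapped class::
+
+            result = await session.execute(
+                select(User).where(User.id == 5)
+            )
+            user = result.scalars().first()
+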
+ .. seealso::
+
+ :meth:`_orm.Session.execute` - main documentation for execute
+
+ """
+
+ if execution_options:
+ execution_options = util.immutabledict(execution_options).union(
+ _EXECUTE_OPTIONS
+ )
+ else:
+ execution_options = _EXECUTE_OPTIONS
+
+ result = await greenlet_spawn(
+ self.sync_session.execute,
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return await _ensure_sync_result(result, self.execute)
+
+ async def scalar(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return a scalar result.
+
+ .. seealso::
+
+ :meth:`_orm.Session.scalar` - main documentation for scalar
+
+ """
+
+ result = await self.execute(
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return result.scalar()
+
+ async def scalars(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return scalar results.
+
+ :return: a :class:`_result.ScalarResult` object
+
+ .. versionadded:: 1.4.24
+
+ .. seealso::
+
+ :meth:`_orm.Session.scalars` - main documentation for scalars
+
+ :meth:`_asyncio.AsyncSession.stream_scalars` - streaming version
+
+ """
+
+ result = await self.execute(
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return result.scalars()
+
+ async def get(
+ self,
+ entity,
+ ident,
+ options=None,
+ populate_existing=False,
+ with_for_update=None,
+ identity_token=None,
+ ):
+ """Return an instance based on the given primary key identifier,
+ or ``None`` if not found.
+
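+        E.g., a minimal sketch; ``User`` is an illustrative mapped class::
+
+            user = await session.get(User, 5)
+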
+ .. seealso::
+
+ :meth:`_orm.Session.get` - main documentation for get
+
+
+ """
+ return await greenlet_spawn(
+ self.sync_session.get,
+ entity,
+ ident,
+ options=options,
+ populate_existing=populate_existing,
+ with_for_update=with_for_update,
+ identity_token=identity_token,
+ )
+
+ async def stream(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return a streaming
+ :class:`_asyncio.AsyncResult` object.
+
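+        E.g., a minimal sketch; ``User`` is an illustrative mapped class::
+
+            result = await session.stream(select(User))
+            async for row in result:
+                print(row)
+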
+ """
+
+ if execution_options:
+ execution_options = util.immutabledict(execution_options).union(
+ _STREAM_OPTIONS
+ )
+ else:
+ execution_options = _STREAM_OPTIONS
+
+ result = await greenlet_spawn(
+ self.sync_session.execute,
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return _result.AsyncResult(result)
+
+ async def stream_scalars(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return a stream of scalar results.
+
+ :return: an :class:`_asyncio.AsyncScalarResult` object
+
+ .. versionadded:: 1.4.24
+
+ .. seealso::
+
+ :meth:`_orm.Session.scalars` - main documentation for scalars
+
+ :meth:`_asyncio.AsyncSession.scalars` - non streaming version
+
+ """
+
+ result = await self.stream(
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return result.scalars()
+
+ async def delete(self, instance):
+ """Mark an instance as deleted.
+
+ The database delete operation occurs upon ``flush()``.
+
+ As this operation may need to cascade along unloaded relationships,
+ it is awaitable to allow for those queries to take place.
+
+ .. seealso::
+
+ :meth:`_orm.Session.delete` - main documentation for delete
+
+ """
+ return await greenlet_spawn(self.sync_session.delete, instance)
+
+ async def merge(self, instance, load=True, options=None):
+ """Copy the state of a given instance into a corresponding instance
+ within this :class:`_asyncio.AsyncSession`.
+
+ .. seealso::
+
+ :meth:`_orm.Session.merge` - main documentation for merge
+
+ """
+ return await greenlet_spawn(
+ self.sync_session.merge, instance, load=load, options=options
+ )
+
+ async def flush(self, objects=None):
+ """Flush all the object changes to the database.
+
+ .. seealso::
+
+ :meth:`_orm.Session.flush` - main documentation for flush
+
+ """
+ await greenlet_spawn(self.sync_session.flush, objects=objects)
+
+ def get_transaction(self):
+ """Return the current root transaction in progress, if any.
+
+ :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
+ ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+ trans = self.sync_session.get_transaction()
+ if trans is not None:
+ return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
+ def get_nested_transaction(self):
+ """Return the current nested transaction in progress, if any.
+
+ :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
+ ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+
+ trans = self.sync_session.get_nested_transaction()
+ if trans is not None:
+ return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
+ def get_bind(self, mapper=None, clause=None, bind=None, **kw):
+ """Return a "bind" to which the synchronous proxied :class:`_orm.Session`
+ is bound.
+
+ Unlike the :meth:`_orm.Session.get_bind` method, this method is
+ currently **not** used by this :class:`.AsyncSession` in any way
+ in order to resolve engines for requests.
+
+ .. note::
+
+ This method proxies directly to the :meth:`_orm.Session.get_bind`
+ method, however is currently **not** useful as an override target,
+ in contrast to that of the :meth:`_orm.Session.get_bind` method.
+ The example below illustrates how to implement custom
+ :meth:`_orm.Session.get_bind` schemes that work with
+ :class:`.AsyncSession` and :class:`.AsyncEngine`.
+
+ The pattern introduced at :ref:`session_custom_partitioning`
+ illustrates how to apply a custom bind-lookup scheme to a
+ :class:`_orm.Session` given a set of :class:`_engine.Engine` objects.
+ To apply a corresponding :meth:`_orm.Session.get_bind` implementation
+ for use with a :class:`.AsyncSession` and :class:`.AsyncEngine`
+ objects, continue to subclass :class:`_orm.Session` and apply it to
+ :class:`.AsyncSession` using
+ :paramref:`.AsyncSession.sync_session_class`. The inner method must
+ continue to return :class:`_engine.Engine` instances, which can be
+ acquired from a :class:`_asyncio.AsyncEngine` using the
+ :attr:`_asyncio.AsyncEngine.sync_engine` attribute::
+
+ # using example from "Custom Vertical Partitioning"
+
+
+ import random
+
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from sqlalchemy.ext.asyncio import create_async_engine
+ from sqlalchemy.orm import Session, sessionmaker
+
+ # construct async engines w/ async drivers
+ engines = {
+ 'leader':create_async_engine("sqlite+aiosqlite:///leader.db"),
+ 'other':create_async_engine("sqlite+aiosqlite:///other.db"),
+ 'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"),
+ 'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"),
+ }
+
+ class RoutingSession(Session):
+ def get_bind(self, mapper=None, clause=None, **kw):
+ # within get_bind(), return sync engines
+ if mapper and issubclass(mapper.class_, MyOtherClass):
+ return engines['other'].sync_engine
+ elif self._flushing or isinstance(clause, (Update, Delete)):
+ return engines['leader'].sync_engine
+ else:
+ return engines[
+ random.choice(['follower1','follower2'])
+ ].sync_engine
+
+ # apply to AsyncSession using sync_session_class
+ AsyncSessionMaker = sessionmaker(
+ class_=AsyncSession,
+ sync_session_class=RoutingSession
+ )
+
+ The :meth:`_orm.Session.get_bind` method is called in a non-asyncio,
+ implicitly non-blocking context in the same manner as ORM event hooks
+ and functions that are invoked via :meth:`.AsyncSession.run_sync`, so
+ routines that wish to run SQL commands inside of
+ :meth:`_orm.Session.get_bind` can continue to do so using
+ blocking-style code, which will be translated to implicitly async calls
+ at the point of invoking IO on the database drivers.
+
+ """ # noqa: E501
+
+ return self.sync_session.get_bind(
+ mapper=mapper, clause=clause, bind=bind, **kw
+ )
+
+ async def connection(self, **kw):
+ r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
+ this :class:`.Session` object's transactional state.
+
+ This method may also be used to establish execution options for the
+ database connection used by the current transaction.
+
+ .. versionadded:: 1.4.24 Added \**kw arguments which are passed
+ through to the underlying :meth:`_orm.Session.connection` method.
+
+ .. seealso::
+
+ :meth:`_orm.Session.connection` - main documentation for
+ "connection"
+
+ """
+
+ sync_connection = await greenlet_spawn(
+ self.sync_session.connection, **kw
+ )
+ return engine.AsyncConnection._retrieve_proxy_for_target(
+ sync_connection
+ )
+
+ def begin(self, **kw):
+ """Return an :class:`_asyncio.AsyncSessionTransaction` object.
+
+ The underlying :class:`_orm.Session` will perform the
+ "begin" action when the :class:`_asyncio.AsyncSessionTransaction`
+ object is entered::
+
+ async with async_session.begin():
+ # .. ORM transaction is begun
+
+ Note that database IO will not normally occur when the session-level
+ transaction is begun, as database transactions begin on an
+        on-demand basis.  However, the begin block is async in order to
+        accommodate a :meth:`_orm.SessionEvents.after_transaction_create`
+ event hook that may perform IO.
+
+ For a general description of ORM begin, see
+ :meth:`_orm.Session.begin`.
+
+ """
+
+ return AsyncSessionTransaction(self)
+
+ def begin_nested(self, **kw):
+ """Return an :class:`_asyncio.AsyncSessionTransaction` object
+ which will begin a "nested" transaction, e.g. SAVEPOINT.
+
+ Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`.
+
+ For a general description of ORM begin nested, see
+ :meth:`_orm.Session.begin_nested`.
+
+ """
+
+ return AsyncSessionTransaction(self, nested=True)
+
+ async def rollback(self):
+        """Roll back the current transaction in progress."""
+ return await greenlet_spawn(self.sync_session.rollback)
+
+ async def commit(self):
+ """Commit the current transaction in progress."""
+ return await greenlet_spawn(self.sync_session.commit)
+
+ async def close(self):
+ """Close out the transactional resources and ORM objects used by this
+ :class:`_asyncio.AsyncSession`.
+
+ This expunges all ORM objects associated with this
+ :class:`_asyncio.AsyncSession`, ends any transaction in progress and
+ :term:`releases` any :class:`_asyncio.AsyncConnection` objects which
+ this :class:`_asyncio.AsyncSession` itself has checked out from
+ associated :class:`_asyncio.AsyncEngine` objects. The operation then
+ leaves the :class:`_asyncio.AsyncSession` in a state which it may be
+ used again.
+
+ .. tip::
+
+ The :meth:`_asyncio.AsyncSession.close` method **does not prevent
+ the Session from being used again**. The
+ :class:`_asyncio.AsyncSession` itself does not actually have a
+ distinct "closed" state; it merely means the
+ :class:`_asyncio.AsyncSession` will release all database
+ connections and ORM objects.
+
+
+ .. seealso::
+
+ :ref:`session_closing` - detail on the semantics of
+ :meth:`_asyncio.AsyncSession.close`
+
+ """
+ await greenlet_spawn(self.sync_session.close)
+
+ async def invalidate(self):
+ """Close this Session, using connection invalidation.
+
+ For a complete description, see :meth:`_orm.Session.invalidate`.
+ """
+ return await greenlet_spawn(self.sync_session.invalidate)
+
+ @classmethod
+    async def close_all(cls):
+        """Close all :class:`_asyncio.AsyncSession` sessions."""
+        return await greenlet_spawn(Session.close_all)
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, type_, value, traceback):
+ await asyncio.shield(self.close())
+
+ def _maker_context_manager(self):
+ # no @contextlib.asynccontextmanager until python3.7, gr
+ return _AsyncSessionContextManager(self)
+
+
+class _AsyncSessionContextManager:
+ def __init__(self, async_session):
+ self.async_session = async_session
+
+ async def __aenter__(self):
+ self.trans = self.async_session.begin()
+ await self.trans.__aenter__()
+ return self.async_session
+
+ async def __aexit__(self, type_, value, traceback):
+ async def go():
+ await self.trans.__aexit__(type_, value, traceback)
+ await self.async_session.__aexit__(type_, value, traceback)
+
+ await asyncio.shield(go())
+
+
+class AsyncSessionTransaction(ReversibleProxy, StartableContext):
+ """A wrapper for the ORM :class:`_orm.SessionTransaction` object.
+
+ This object is provided so that a transaction-holding object
+ for the :meth:`_asyncio.AsyncSession.begin` may be returned.
+
+ The object supports both explicit calls to
+ :meth:`_asyncio.AsyncSessionTransaction.commit` and
+ :meth:`_asyncio.AsyncSessionTransaction.rollback`, as well as use as an
+ async context manager.
+
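+    E.g., a sketch of explicit use outside of an ``async with`` block,
+    assuming ``async_session`` is an existing :class:`_asyncio.AsyncSession`
+    and ``some_object`` is a mapped instance::
+
+        trans = await async_session.begin()
+        try:
+            async_session.add(some_object)
+            await trans.commit()
+        except Exception:
+            await trans.rollback()
+            raise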
+
+ .. versionadded:: 1.4
+
+ """
+
+ __slots__ = ("session", "sync_transaction", "nested")
+
+ def __init__(self, session, nested=False):
+ self.session = session
+ self.nested = nested
+ self.sync_transaction = None
+
+ @property
+ def is_active(self):
+ return (
+ self._sync_transaction() is not None
+ and self._sync_transaction().is_active
+ )
+
+ def _sync_transaction(self):
+ if not self.sync_transaction:
+ self._raise_for_not_started()
+ return self.sync_transaction
+
+ async def rollback(self):
+        """Roll back this :class:`_asyncio.AsyncSessionTransaction`."""
+ await greenlet_spawn(self._sync_transaction().rollback)
+
+ async def commit(self):
+        """Commit this :class:`_asyncio.AsyncSessionTransaction`."""
+
+ await greenlet_spawn(self._sync_transaction().commit)
+
+ async def start(self, is_ctxmanager=False):
+ self.sync_transaction = self._assign_proxied(
+ await greenlet_spawn(
+ self.session.sync_session.begin_nested
+ if self.nested
+ else self.session.sync_session.begin
+ )
+ )
+ if is_ctxmanager:
+ self.sync_transaction.__enter__()
+ return self
+
+ async def __aexit__(self, type_, value, traceback):
+ await greenlet_spawn(
+ self._sync_transaction().__exit__, type_, value, traceback
+ )
+
+
+def async_object_session(instance):
+ """Return the :class:`_asyncio.AsyncSession` to which the given instance
+ belongs.
+
+ This function makes use of the sync-API function
+    :func:`_orm.object_session` to retrieve the :class:`_orm.Session` which
+ refers to the given instance, and from there links it to the original
+ :class:`_asyncio.AsyncSession`.
+
+ If the :class:`_asyncio.AsyncSession` has been garbage collected, the
+ return value is ``None``.
+
+ This functionality is also available from the
+ :attr:`_orm.InstanceState.async_session` accessor.
+
+ :param instance: an ORM mapped instance
+ :return: an :class:`_asyncio.AsyncSession` object, or ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+
+ session = object_session(instance)
+ if session is not None:
+ return async_session(session)
+ else:
+ return None
+
+
+def async_session(session):
+ """Return the :class:`_asyncio.AsyncSession` which is proxying the given
+ :class:`_orm.Session` object, if any.
+
+ :param session: a :class:`_orm.Session` instance.
+ :return: a :class:`_asyncio.AsyncSession` instance, or ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+ return AsyncSession._retrieve_proxy_for_target(session, regenerate=False)
+
+
+_instance_state._async_provider = async_session
diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py
new file mode 100644
index 0000000..a5d7267
--- /dev/null
+++ b/lib/sqlalchemy/ext/automap.py
@@ -0,0 +1,1234 @@
+# ext/automap.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""Define an extension to the :mod:`sqlalchemy.ext.declarative` system
+which automatically generates mapped classes and relationships from a database
+schema, typically though not necessarily one which is reflected.
+
+It is hoped that the :class:`.AutomapBase` system provides a quick
+and modernized solution to the problem that the very famous
+`SQLSoup <https://sqlsoup.readthedocs.io/en/latest/>`_
+also tries to solve, that of generating a quick and rudimentary object
+model from an existing database on the fly. By addressing the issue strictly
+at the mapper configuration level, and integrating fully with existing
+Declarative class techniques, :class:`.AutomapBase` seeks to provide
+a well-integrated approach to expediently auto-generating ad-hoc
+mappings.
+
+.. tip:: The :ref:`automap_toplevel` extension is geared towards a
+ "zero declaration" approach, where a complete ORM model including classes
+ and pre-named relationships can be generated on the fly from a database
+ schema. For applications that still want to use explicit class declarations
+ including explicit relationship definitions in conjunction with reflection
+ of tables, the :class:`.DeferredReflection` class, described at
+ :ref:`orm_declarative_reflected_deferred_reflection`, is a better choice.
+
+
+
+Basic Use
+=========
+
+The simplest usage is to reflect an existing database into a new model.
+We create a new :class:`.AutomapBase` class in a manner similar to how
+we create a declarative base class, using :func:`.automap_base`.
+We then call :meth:`.AutomapBase.prepare` on the resulting base class,
+asking it to reflect the schema and produce mappings::
+
+ from sqlalchemy.ext.automap import automap_base
+ from sqlalchemy.orm import Session
+ from sqlalchemy import create_engine
+
+ Base = automap_base()
+
+ # engine, suppose it has two tables 'user' and 'address' set up
+ engine = create_engine("sqlite:///mydatabase.db")
+
+ # reflect the tables
+ Base.prepare(autoload_with=engine)
+
+ # mapped classes are now created with names by default
+ # matching that of the table name.
+ User = Base.classes.user
+ Address = Base.classes.address
+
+ session = Session(engine)
+
+ # rudimentary relationships are produced
+    u1 = User(name="foo")
+    session.add(Address(email_address="foo@bar.com", user=u1))
+ session.commit()
+
+ # collection-based relationships are by default named
+ # "<classname>_collection"
+    print(u1.address_collection)
+
+Above, calling :meth:`.AutomapBase.prepare` while passing along the
+:paramref:`.AutomapBase.prepare.autoload_with` parameter indicates that the
+:meth:`_schema.MetaData.reflect`
+method will be called on this declarative base
+class's :class:`_schema.MetaData` collection; then, each **viable**
+:class:`_schema.Table` within the :class:`_schema.MetaData`
+will get a new mapped class
+generated automatically. The :class:`_schema.ForeignKeyConstraint`
+objects which
+link the various tables together will be used to produce new, bidirectional
+:func:`_orm.relationship` objects between classes.
+The classes and relationships
+follow along a default naming scheme that we can customize. At this point,
+our basic mapping consisting of related ``User`` and ``Address`` classes is
+ready to use in the traditional way.
+
+.. note:: By **viable**, we mean that for a table to be mapped, it must
+ specify a primary key. Additionally, if the table is detected as being
+ a pure association table between two other tables, it will not be directly
+ mapped and will instead be configured as a many-to-many table between
+ the mappings for the two referring tables.
+
+Generating Mappings from an Existing MetaData
+=============================================
+
+We can pass a pre-declared :class:`_schema.MetaData` object to
+:func:`.automap_base`.
+This object can be constructed in any way, including programmatically, from
+a serialized file, or from itself being reflected using
+:meth:`_schema.MetaData.reflect`.
+Below we illustrate a combination of reflection and
+explicit table declaration::
+
+    from sqlalchemy import create_engine, MetaData, Table, Column, Integer, ForeignKey
+ from sqlalchemy.ext.automap import automap_base
+ engine = create_engine("sqlite:///mydatabase.db")
+
+ # produce our own MetaData object
+ metadata = MetaData()
+
+ # we can reflect it ourselves from a database, using options
+ # such as 'only' to limit what tables we look at...
+ metadata.reflect(engine, only=['user', 'address'])
+
+ # ... or just define our own Table objects with it (or combine both)
+ Table('user_order', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('user_id', ForeignKey('user.id'))
+ )
+
+ # we can then produce a set of mappings from this MetaData.
+ Base = automap_base(metadata=metadata)
+
+ # calling prepare() just sets up mapped classes and relationships.
+ Base.prepare()
+
+ # mapped classes are ready
+ User, Address, Order = Base.classes.user, Base.classes.address,\
+ Base.classes.user_order
+
+Specifying Classes Explicitly
+=============================
+
+.. tip:: If explicit classes are expected to be prominent in an application,
+ consider using :class:`.DeferredReflection` instead.
+
+The :mod:`.sqlalchemy.ext.automap` extension allows classes to be defined
+explicitly, in a way similar to that of the :class:`.DeferredReflection` class.
+Classes that extend from :class:`.AutomapBase` act like regular declarative
+classes, but are not immediately mapped after their construction, and are
+instead mapped when we call :meth:`.AutomapBase.prepare`. The
+:meth:`.AutomapBase.prepare` method will make use of the classes we've
+established based on the table name we use. If our schema contains tables
+``user`` and ``address``, we can define one or both of the classes to be used::
+
+    from sqlalchemy.ext.automap import automap_base
+    from sqlalchemy.orm import relationship, Session
+    from sqlalchemy import create_engine, Column, String
+
+ # automap base
+ Base = automap_base()
+
+ # pre-declare User for the 'user' table
+ class User(Base):
+ __tablename__ = 'user'
+
+ # override schema elements like Columns
+ user_name = Column('name', String)
+
+ # override relationships too, if desired.
+ # we must use the same name that automap would use for the
+ # relationship, and also must refer to the class name that automap will
+ # generate for "address"
+ address_collection = relationship("address", collection_class=set)
+
+ # reflect
+ engine = create_engine("sqlite:///mydatabase.db")
+ Base.prepare(autoload_with=engine)
+
+ # we still have Address generated from the tablename "address",
+ # but User is the same as Base.classes.User now
+
+ Address = Base.classes.address
+
+    session = Session(engine)
+
+    u1 = session.query(User).first()
+    print(u1.address_collection)
+
+ # the backref is still there:
+ a1 = session.query(Address).first()
+    print(a1.user)
+
+Above, one of the more intricate details is that we illustrated overriding
+one of the :func:`_orm.relationship` objects that automap would have created.
+To do this, we needed to make sure the names match up with what automap
+would normally generate, in that the relationship name would be
+``User.address_collection`` and the name of the class referred to, from
+automap's perspective, is called ``address``, even though we are referring to
+it as ``Address`` within our usage of this class.
+
+Overriding Naming Schemes
+=========================
+
+:mod:`.sqlalchemy.ext.automap` is tasked with producing mapped classes and
+relationship names based on a schema, which means it has decision points in how
+these names are determined. These three decision points are provided using
+functions which can be passed to the :meth:`.AutomapBase.prepare` method, and
+are known as :func:`.classname_for_table`,
+:func:`.name_for_scalar_relationship`,
+and :func:`.name_for_collection_relationship`. Any or all of these
+functions are provided as in the example below, where we use a "camel case"
+scheme for class names and a "pluralizer" for collection names using the
+`Inflect <https://pypi.org/project/inflect>`_ package::
+
+ import re
+ import inflect
+
+ def camelize_classname(base, tablename, table):
+        """Produce a 'camelized' class name, e.g.
+        'words_and_underscores' -> 'WordsAndUnderscores'"""
+
+ return str(tablename[0].upper() + \
+ re.sub(r'_([a-z])', lambda m: m.group(1).upper(), tablename[1:]))
+
+ _pluralizer = inflect.engine()
+ def pluralize_collection(base, local_cls, referred_cls, constraint):
+        """Produce an 'uncamelized', 'pluralized' class name, e.g.
+        'SomeTerm' -> 'some_terms'"""
+
+ referred_name = referred_cls.__name__
+ uncamelized = re.sub(r'[A-Z]',
+ lambda m: "_%s" % m.group(0).lower(),
+ referred_name)[1:]
+ pluralized = _pluralizer.plural(uncamelized)
+ return pluralized
+
+    from sqlalchemy.ext.automap import automap_base
+    from sqlalchemy import create_engine
+
+ Base = automap_base()
+
+ engine = create_engine("sqlite:///mydatabase.db")
+
+ Base.prepare(autoload_with=engine,
+ classname_for_table=camelize_classname,
+ name_for_collection_relationship=pluralize_collection
+ )
+
+From the above mapping, we would now have classes ``User`` and ``Address``,
+where the collection from ``User`` to ``Address`` is called
+``User.addresses``::
+
+ User, Address = Base.classes.User, Base.classes.Address
+
+ u1 = User(addresses=[Address(email="foo@bar.com")])
+
+Relationship Detection
+======================
+
+The vast majority of what automap accomplishes is the generation of
+:func:`_orm.relationship` structures based on foreign keys. The mechanism
+by which this works for many-to-one and one-to-many relationships is as
+follows:
+
+1. A given :class:`_schema.Table`, known to be mapped to a particular class,
+ is examined for :class:`_schema.ForeignKeyConstraint` objects.
+
+2. From each :class:`_schema.ForeignKeyConstraint`, the remote
+ :class:`_schema.Table`
+ object present is matched up to the class to which it is to be mapped,
+ if any, else it is skipped.
+
+3. As the :class:`_schema.ForeignKeyConstraint`
+ we are examining corresponds to a
+ reference from the immediate mapped class, the relationship will be set up
+ as a many-to-one referring to the referred class; a corresponding
+ one-to-many backref will be created on the referred class referring
+ to this class.
+
+4. If any of the columns that are part of the
+ :class:`_schema.ForeignKeyConstraint`
+ are not nullable (e.g. ``nullable=False``), a
+ :paramref:`_orm.relationship.cascade` keyword argument
+ of ``all, delete-orphan`` will be added to the keyword arguments to
+ be passed to the relationship or backref. If the
+ :class:`_schema.ForeignKeyConstraint` reports that
+ :paramref:`_schema.ForeignKeyConstraint.ondelete`
+ is set to ``CASCADE`` for a not null or ``SET NULL`` for a nullable
+ set of columns, the option :paramref:`_orm.relationship.passive_deletes`
+ flag is set to ``True`` in the set of relationship keyword arguments.
+ Note that not all backends support reflection of ON DELETE.
+
+ .. versionadded:: 1.0.0 - automap will detect non-nullable foreign key
+ constraints when producing a one-to-many relationship and establish
+ a default cascade of ``all, delete-orphan`` if so; additionally,
+ if the constraint specifies
+ :paramref:`_schema.ForeignKeyConstraint.ondelete`
+ of ``CASCADE`` for non-nullable or ``SET NULL`` for nullable columns,
+ the ``passive_deletes=True`` option is also added.
+
+5. The names of the relationships are determined using the
+ :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` and
+ :paramref:`.AutomapBase.prepare.name_for_collection_relationship`
+ callable functions. It is important to note that the default relationship
+   naming derives the name from the **actual class name**. If you've
+ given a particular class an explicit name by declaring it, or specified an
+ alternate class naming scheme, that's the name from which the relationship
+ name will be derived.
+
+6. The classes are inspected for an existing mapped property matching these
+ names. If one is detected on one side, but none on the other side,
+ :class:`.AutomapBase` attempts to create a relationship on the missing side,
+ then uses the :paramref:`_orm.relationship.back_populates`
+ parameter in order to
+ point the new relationship to the other side.
+
+7. In the usual case where no relationship is on either side,
+ :meth:`.AutomapBase.prepare` produces a :func:`_orm.relationship` on the
+ "many-to-one" side and matches it to the other using the
+ :paramref:`_orm.relationship.backref` parameter.
+
+8. Production of the :func:`_orm.relationship` and optionally the
+ :func:`.backref`
+ is handed off to the :paramref:`.AutomapBase.prepare.generate_relationship`
+ function, which can be supplied by the end-user in order to augment
+ the arguments passed to :func:`_orm.relationship` or :func:`.backref` or to
+ make use of custom implementations of these functions.
+
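+As a concrete sketch of the above steps, suppose a hypothetical schema with
+tables ``user`` and ``address``, where ``address.user_id`` is a non-nullable
+foreign key referring to ``user.id``.  Using the default naming functions,
+the prepared mapping would look roughly like::
+
+    Base = automap_base()
+    Base.prepare(autoload_with=engine)
+
+    # many-to-one from the 'address' class to the 'user' class
+    Base.classes.address.user
+
+    # one-to-many collection on the 'user' class; as the foreign key
+    # is non-nullable, cascade="all, delete-orphan" is applied
+    Base.classes.user.address_collection
+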
+Custom Relationship Arguments
+-----------------------------
+
+The :paramref:`.AutomapBase.prepare.generate_relationship` hook can be used
+to add parameters to relationships. For most cases, we can make use of the
+existing :func:`.automap.generate_relationship` function to return
+the object, after augmenting the given keyword dictionary with our own
+arguments.
+
+Below is an illustration of how to send
+:paramref:`_orm.relationship.cascade` and
+:paramref:`_orm.relationship.passive_deletes`
+options along to all one-to-many relationships::
+
+    from sqlalchemy.ext.automap import generate_relationship
+    from sqlalchemy.orm import interfaces
+
+ def _gen_relationship(base, direction, return_fn,
+ attrname, local_cls, referred_cls, **kw):
+ if direction is interfaces.ONETOMANY:
+ kw['cascade'] = 'all, delete-orphan'
+ kw['passive_deletes'] = True
+ # make use of the built-in function to actually return
+ # the result.
+ return generate_relationship(base, direction, return_fn,
+ attrname, local_cls, referred_cls, **kw)
+
+ from sqlalchemy.ext.automap import automap_base
+ from sqlalchemy import create_engine
+
+ # automap base
+ Base = automap_base()
+
+ engine = create_engine("sqlite:///mydatabase.db")
+ Base.prepare(autoload_with=engine,
+ generate_relationship=_gen_relationship)
+
+Many-to-Many relationships
+--------------------------
+
+:mod:`.sqlalchemy.ext.automap` will generate many-to-many relationships, e.g.
+those which contain a ``secondary`` argument. The process for producing these
+is as follows:
+
+1. A given :class:`_schema.Table` is examined for
+ :class:`_schema.ForeignKeyConstraint`
+ objects, before any mapped class has been assigned to it.
+
+2. If the table contains exactly two
+ :class:`_schema.ForeignKeyConstraint`
+ objects, and all columns within this table are members of these two
+ :class:`_schema.ForeignKeyConstraint` objects, the table is assumed to be a
+ "secondary" table, and will **not be mapped directly**.
+
+3. The two (or one, for self-referential) external tables to which the
+ :class:`_schema.Table`
+   refers are matched to the classes to which they will be
+ mapped, if any.
+
+4. If mapped classes for both sides are located, a many-to-many bi-directional
+ :func:`_orm.relationship` / :func:`.backref`
+ pair is created between the two
+ classes.
+
+5. The override logic for many-to-many works the same as that of one-to-many/
+ many-to-one; the :func:`.generate_relationship` function is called upon
+ to generate the structures and existing attributes will be maintained.
+
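+For example, given hypothetical tables ``user`` and ``keyword`` plus an
+association table ``user_keyword`` whose columns consist solely of foreign
+keys to ``user.id`` and ``keyword.id``, a sketch of the outcome would be::
+
+    Base = automap_base()
+    Base.prepare(autoload_with=engine)
+
+    User = Base.classes.user
+    Keyword = Base.classes.keyword
+
+    # 'user_keyword' is not mapped directly and is absent from
+    # Base.classes; it is used as the "secondary" table linking
+    # User.keyword_collection and Keyword.user_collection
+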
+Relationships with Inheritance
+------------------------------
+
+:mod:`.sqlalchemy.ext.automap` will not generate any relationships between
+two classes that are in an inheritance relationship. That is, with two
+classes given as follows::
+
+ class Employee(Base):
+ __tablename__ = 'employee'
+ id = Column(Integer, primary_key=True)
+ type = Column(String(50))
+ __mapper_args__ = {
+ 'polymorphic_identity':'employee', 'polymorphic_on': type
+ }
+
+ class Engineer(Employee):
+ __tablename__ = 'engineer'
+ id = Column(Integer, ForeignKey('employee.id'), primary_key=True)
+ __mapper_args__ = {
+ 'polymorphic_identity':'engineer',
+ }
+
+The foreign key from ``Engineer`` to ``Employee`` is used not for a
+relationship, but to establish joined inheritance between the two classes.
+
+Note that this means automap will not generate *any* relationships
+for foreign keys that link from a subclass to a superclass. If a mapping
+has actual relationships from subclass to superclass as well, those
+need to be explicit. Below, as we have two separate foreign keys
+from ``Engineer`` to ``Employee``, we need to set up both the relationship
+we want as well as the ``inherit_condition``, as these are not things
+SQLAlchemy can guess::
+
+ class Employee(Base):
+ __tablename__ = 'employee'
+ id = Column(Integer, primary_key=True)
+ type = Column(String(50))
+
+ __mapper_args__ = {
+ 'polymorphic_identity':'employee', 'polymorphic_on':type
+ }
+
+ class Engineer(Employee):
+ __tablename__ = 'engineer'
+ id = Column(Integer, ForeignKey('employee.id'), primary_key=True)
+ favorite_employee_id = Column(Integer, ForeignKey('employee.id'))
+
+ favorite_employee = relationship(Employee,
+ foreign_keys=favorite_employee_id)
+
+ __mapper_args__ = {
+ 'polymorphic_identity':'engineer',
+ 'inherit_condition': id == Employee.id
+ }
+
+Handling Simple Naming Conflicts
+--------------------------------
+
+In the case of naming conflicts during mapping, override any of
+:func:`.classname_for_table`, :func:`.name_for_scalar_relationship`,
+and :func:`.name_for_collection_relationship` as needed. For example, if
+automap is attempting to name a many-to-one relationship the same as an
+existing column, an alternate convention can be conditionally selected. Given
+a schema:
+
+.. sourcecode:: sql
+
+ CREATE TABLE table_a (
+ id INTEGER PRIMARY KEY
+ );
+
+ CREATE TABLE table_b (
+ id INTEGER PRIMARY KEY,
+ table_a INTEGER,
+ FOREIGN KEY(table_a) REFERENCES table_a(id)
+ );
+
+The above schema will first automap the ``table_a`` table as a class named
+``table_a``; it will then automap a relationship onto the class for ``table_b``
+with the same name as this related class, e.g. ``table_a``. This
+relationship name conflicts with the mapped column ``table_b.table_a``,
+and will emit an error on mapping.
+
+We can resolve this conflict by using an underscore as follows::
+
+    import warnings
+
+    def name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
+ name = referred_cls.__name__.lower()
+ local_table = local_cls.__table__
+ if name in local_table.columns:
+ newname = name + "_"
+ warnings.warn(
+                "Already detected name %s present, using %s" %
+ (name, newname))
+ return newname
+ return name
+
+
+ Base.prepare(autoload_with=engine,
+ name_for_scalar_relationship=name_for_scalar_relationship)
+
+Alternatively, we can change the name on the column side. The columns
+that are mapped can be modified using the technique described at
+:ref:`mapper_column_distinct_names`, by assigning the column explicitly
+to a new name::
+
+ Base = automap_base()
+
+ class TableB(Base):
+ __tablename__ = 'table_b'
+ _table_a = Column('table_a', ForeignKey('table_a.id'))
+
+ Base.prepare(autoload_with=engine)
+
+
+Using Automap with Explicit Declarations
+========================================
+
+As noted previously, automap has no dependency on reflection, and can make
+use of any collection of :class:`_schema.Table` objects within a
+:class:`_schema.MetaData`
+collection. From this, it follows that automap can also be used
+to generate missing relationships given an otherwise complete model that fully
+defines table metadata::
+
+ from sqlalchemy.ext.automap import automap_base
+ from sqlalchemy import Column, Integer, String, ForeignKey
+
+ Base = automap_base()
+
+ class User(Base):
+ __tablename__ = 'user'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ class Address(Base):
+ __tablename__ = 'address'
+
+ id = Column(Integer, primary_key=True)
+ email = Column(String)
+ user_id = Column(ForeignKey('user.id'))
+
+ # produce relationships
+ Base.prepare()
+
+ # mapping is complete, with "address_collection" and
+ # "user" relationships
+ a1 = Address(email='u1')
+ a2 = Address(email='u2')
+ u1 = User(address_collection=[a1, a2])
+ assert a1.user is u1
+
+Above, given mostly complete ``User`` and ``Address`` mappings, the
+:class:`_schema.ForeignKey` which we defined on ``Address.user_id`` allowed a
+bidirectional relationship pair ``Address.user`` and
+``User.address_collection`` to be generated on the mapped classes.
+
+Note that when subclassing :class:`.AutomapBase`,
+the :meth:`.AutomapBase.prepare` method is required; if not called, the classes
+we've declared are in an un-mapped state.
+
+
+.. _automap_intercepting_columns:
+
+Intercepting Column Definitions
+===============================
+
+The :class:`_schema.MetaData` and :class:`_schema.Table` objects support an
+event hook :meth:`_events.DDLEvents.column_reflect` that may be used to intercept
+the information reflected about a database column before the :class:`_schema.Column`
+object is constructed. For example, if we wanted to map columns using a
+naming convention such as ``"attr_<columnname>"``, the event could
+be applied as::
+
+    from sqlalchemy import event
+
+    @event.listens_for(Base.metadata, "column_reflect")
+ def column_reflect(inspector, table, column_info):
+ # set column.key = "attr_<lower_case_name>"
+ column_info['key'] = "attr_%s" % column_info['name'].lower()
+
+ # run reflection
+ Base.prepare(autoload_with=engine)
+
+.. versionadded:: 1.4.0b2 the :meth:`_events.DDLEvents.column_reflect` event
+ may be applied to a :class:`_schema.MetaData` object.
+
+.. seealso::
+
+ :meth:`_events.DDLEvents.column_reflect`
+
+ :ref:`mapper_automated_reflection_schemes` - in the ORM mapping documentation
+
+
+""" # noqa
+from .. import util
+from ..orm import backref
+from ..orm import declarative_base as _declarative_base
+from ..orm import exc as orm_exc
+from ..orm import interfaces
+from ..orm import relationship
+from ..orm.decl_base import _DeferredMapperConfig
+from ..orm.mapper import _CONFIGURE_MUTEX
+from ..schema import ForeignKeyConstraint
+from ..sql import and_
+
+
+def classname_for_table(base, tablename, table):
+ """Return the class name that should be used, given the name
+ of a table.
+
+ The default implementation is::
+
+ return str(tablename)
+
+ Alternate implementations can be specified using the
+ :paramref:`.AutomapBase.prepare.classname_for_table`
+ parameter.
+
+ :param base: the :class:`.AutomapBase` class doing the prepare.
+
+ :param tablename: string name of the :class:`_schema.Table`.
+
+ :param table: the :class:`_schema.Table` object itself.
+
+ :return: a string class name.
+
+ .. note::
+
+ In Python 2, the string used for the class name **must** be a
+ non-Unicode object, e.g. a ``str()`` object. The ``.name`` attribute
+ of :class:`_schema.Table` is typically a Python unicode subclass,
+ so the
+ ``str()`` function should be applied to this name, after accounting for
+ any non-ASCII characters.
+
+ """
+ return str(tablename)
+
+
+def name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
+ """Return the attribute name that should be used to refer from one
+ class to another, for a scalar object reference.
+
+ The default implementation is::
+
+ return referred_cls.__name__.lower()
+
+ Alternate implementations can be specified using the
+ :paramref:`.AutomapBase.prepare.name_for_scalar_relationship`
+ parameter.
+
+ :param base: the :class:`.AutomapBase` class doing the prepare.
+
+ :param local_cls: the class to be mapped on the local side.
+
+ :param referred_cls: the class to be mapped on the referring side.
+
+ :param constraint: the :class:`_schema.ForeignKeyConstraint` that is being
+ inspected to produce this relationship.
+
+ """
+ return referred_cls.__name__.lower()
+
+
+def name_for_collection_relationship(
+ base, local_cls, referred_cls, constraint
+):
+ """Return the attribute name that should be used to refer from one
+ class to another, for a collection reference.
+
+ The default implementation is::
+
+ return referred_cls.__name__.lower() + "_collection"
+
+ Alternate implementations
+ can be specified using the
+ :paramref:`.AutomapBase.prepare.name_for_collection_relationship`
+ parameter.
+
+ :param base: the :class:`.AutomapBase` class doing the prepare.
+
+ :param local_cls: the class to be mapped on the local side.
+
+ :param referred_cls: the class to be mapped on the referring side.
+
+ :param constraint: the :class:`_schema.ForeignKeyConstraint` that is being
+ inspected to produce this relationship.
+
+ """
+ return referred_cls.__name__.lower() + "_collection"
+
+
+def generate_relationship(
+ base, direction, return_fn, attrname, local_cls, referred_cls, **kw
+):
+ r"""Generate a :func:`_orm.relationship` or :func:`.backref`
+ on behalf of two
+ mapped classes.
+
+ An alternate implementation of this function can be specified using the
+ :paramref:`.AutomapBase.prepare.generate_relationship` parameter.
+
+ The default implementation of this function is as follows::
+
+ if return_fn is backref:
+ return return_fn(attrname, **kw)
+ elif return_fn is relationship:
+ return return_fn(referred_cls, **kw)
+ else:
+ raise TypeError("Unknown relationship function: %s" % return_fn)
+
+ :param base: the :class:`.AutomapBase` class doing the prepare.
+
+ :param direction: indicate the "direction" of the relationship; this will
+ be one of :data:`.ONETOMANY`, :data:`.MANYTOONE`, :data:`.MANYTOMANY`.
+
+ :param return_fn: the function that is used by default to create the
+ relationship. This will be either :func:`_orm.relationship` or
+ :func:`.backref`. The :func:`.backref` function's result will be used to
+ produce a new :func:`_orm.relationship` in a second step,
+ so it is critical
+ that user-defined implementations correctly differentiate between the two
+ functions, if a custom relationship function is being used.
+
+ :param attrname: the attribute name to which this relationship is being
+ assigned. If the value of :paramref:`.generate_relationship.return_fn` is
+ the :func:`.backref` function, then this name is the name that is being
+ assigned to the backref.
+
+ :param local_cls: the "local" class to which this relationship or backref
+ will be locally present.
+
+ :param referred_cls: the "referred" class to which the relationship or
+ backref refers to.
+
+ :param \**kw: all additional keyword arguments are passed along to the
+ function.
+
+ :return: a :func:`_orm.relationship` or :func:`.backref` construct,
+ as dictated
+ by the :paramref:`.generate_relationship.return_fn` parameter.
+
+ """
+ if return_fn is backref:
+ return return_fn(attrname, **kw)
+ elif return_fn is relationship:
+ return return_fn(referred_cls, **kw)
+ else:
+ raise TypeError("Unknown relationship function: %s" % return_fn)
+
+
+class AutomapBase(object):
+ """Base class for an "automap" schema.
+
+ The :class:`.AutomapBase` class can be compared to the "declarative base"
+ class that is produced by the :func:`.declarative.declarative_base`
+ function. In practice, the :class:`.AutomapBase` class is always used
+ as a mixin along with an actual declarative base.
+
+ A new subclassable :class:`.AutomapBase` is typically instantiated
+ using the :func:`.automap_base` function.
+
+ .. seealso::
+
+ :ref:`automap_toplevel`
+
+ """
+
+ __abstract__ = True
+
+ classes = None
+ """An instance of :class:`.util.Properties` containing classes.
+
+ This object behaves much like the ``.c`` collection on a table. Classes
+ are present under the name they were given, e.g.::
+
+ Base = automap_base()
+ Base.prepare(autoload_with=some_engine)
+
+ User, Address = Base.classes.User, Base.classes.Address
+
+ """
+
+ @classmethod
+ @util.deprecated_params(
+ engine=(
+ "2.0",
+ "The :paramref:`_automap.AutomapBase.prepare.engine` parameter "
+ "is deprecated and will be removed in a future release. "
+ "Please use the "
+ ":paramref:`_automap.AutomapBase.prepare.autoload_with` "
+ "parameter.",
+ ),
+ reflect=(
+ "2.0",
+ "The :paramref:`_automap.AutomapBase.prepare.reflect` "
+ "parameter is deprecated and will be removed in a future "
+ "release. Reflection is enabled when "
+ ":paramref:`_automap.AutomapBase.prepare.autoload_with` "
+ "is passed.",
+ ),
+ )
+ def prepare(
+ cls,
+ autoload_with=None,
+ engine=None,
+ reflect=False,
+ schema=None,
+ classname_for_table=None,
+ collection_class=None,
+ name_for_scalar_relationship=None,
+ name_for_collection_relationship=None,
+ generate_relationship=None,
+ reflection_options=util.EMPTY_DICT,
+ ):
+ """Extract mapped classes and relationships from the
+ :class:`_schema.MetaData` and
+ perform mappings.
+
+ :param engine: an :class:`_engine.Engine` or
+ :class:`_engine.Connection` with which
+ to perform schema reflection, if specified.
+ If the :paramref:`.AutomapBase.prepare.reflect` argument is False,
+ this object is not used.
+
+ :param reflect: if True, the :meth:`_schema.MetaData.reflect`
+ method is called
+ on the :class:`_schema.MetaData` associated with this
+ :class:`.AutomapBase`.
+ The :class:`_engine.Engine` passed via
+ :paramref:`.AutomapBase.prepare.engine` will be used to perform the
+ reflection if present; else, the :class:`_schema.MetaData`
+ should already be
+          bound to some engine, or the operation will fail.
+
+ :param classname_for_table: callable function which will be used to
+ produce new class names, given a table name. Defaults to
+ :func:`.classname_for_table`.
+
+ :param name_for_scalar_relationship: callable function which will be
+ used to produce relationship names for scalar relationships. Defaults
+ to :func:`.name_for_scalar_relationship`.
+
+ :param name_for_collection_relationship: callable function which will
+ be used to produce relationship names for collection-oriented
+ relationships. Defaults to :func:`.name_for_collection_relationship`.
+
+ :param generate_relationship: callable function which will be used to
+ actually generate :func:`_orm.relationship` and :func:`.backref`
+ constructs. Defaults to :func:`.generate_relationship`.
+
+ :param collection_class: the Python collection class that will be used
+ when a new :func:`_orm.relationship`
+ object is created that represents a
+ collection. Defaults to ``list``.
+
+ :param schema: When present in conjunction with the
+ :paramref:`.AutomapBase.prepare.reflect` flag, is passed to
+ :meth:`_schema.MetaData.reflect`
+ to indicate the primary schema where tables
+ should be reflected from. When omitted, the default schema in use
+ by the database connection is used.
+
+ .. versionadded:: 1.1
+
+ :param reflection_options: When present, this dictionary of options
+ will be passed to :meth:`_schema.MetaData.reflect`
+ to supply general reflection-specific options like ``only`` and/or
+ dialect-specific options like ``oracle_resolve_synonyms``.
+
+ .. versionadded:: 1.4
+
+ """
+ glbls = globals()
+ if classname_for_table is None:
+ classname_for_table = glbls["classname_for_table"]
+ if name_for_scalar_relationship is None:
+ name_for_scalar_relationship = glbls[
+ "name_for_scalar_relationship"
+ ]
+ if name_for_collection_relationship is None:
+ name_for_collection_relationship = glbls[
+ "name_for_collection_relationship"
+ ]
+ if generate_relationship is None:
+ generate_relationship = glbls["generate_relationship"]
+ if collection_class is None:
+ collection_class = list
+
+ if autoload_with:
+ reflect = True
+
+ if engine:
+ autoload_with = engine
+
+ if reflect:
+ opts = dict(
+ schema=schema,
+ extend_existing=True,
+ autoload_replace=False,
+ )
+ if reflection_options:
+ opts.update(reflection_options)
+ cls.metadata.reflect(autoload_with, **opts)
+
+ with _CONFIGURE_MUTEX:
+ table_to_map_config = dict(
+ (m.local_table, m)
+ for m in _DeferredMapperConfig.classes_for_base(
+ cls, sort=False
+ )
+ )
+
+ many_to_many = []
+
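+            # first pass: create a mapped class for each viable table;
+            # association tables are set aside for the many-to-many pass,
+            # and tables without a primary key are skipped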
+ for table in cls.metadata.tables.values():
+ lcl_m2m, rem_m2m, m2m_const = _is_many_to_many(cls, table)
+ if lcl_m2m is not None:
+ many_to_many.append((lcl_m2m, rem_m2m, m2m_const, table))
+ elif not table.primary_key:
+ continue
+ elif table not in table_to_map_config:
+ mapped_cls = type(
+ classname_for_table(cls, table.name, table),
+ (cls,),
+ {"__table__": table},
+ )
+ map_config = _DeferredMapperConfig.config_for_cls(
+ mapped_cls
+ )
+ cls.classes[map_config.cls.__name__] = mapped_cls
+ table_to_map_config[table] = map_config
+
+ for map_config in table_to_map_config.values():
+ _relationships_for_fks(
+ cls,
+ map_config,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+ )
+
+ for lcl_m2m, rem_m2m, m2m_const, table in many_to_many:
+ _m2m_relationship(
+ cls,
+ lcl_m2m,
+ rem_m2m,
+ m2m_const,
+ table,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+ )
+
+ for map_config in _DeferredMapperConfig.classes_for_base(cls):
+ map_config.map()
+
+ _sa_decl_prepare = True
+ """Indicate that the mapping of classes should be deferred.
+
+ The presence of this attribute name indicates to declarative
+ that the call to mapper() should not occur immediately; instead,
+ information about the table and attributes to be mapped are gathered
+ into an internal structure called _DeferredMapperConfig. These
+ objects can be collected later using classes_for_base(), additional
+ mapping decisions can be made, and then the map() method will actually
+ apply the mapping.
+
+ The only real reason this deferral of the whole
+ thing is needed is to support primary key columns that aren't reflected
+ yet when the class is declared; everything else can theoretically be
+ added to the mapper later. However, the _DeferredMapperConfig is a
+    nice interface in any case, as it exists at the not-usually-exposed
+    point at which declarative has the class and the Table but has not yet
+    called mapper().
+
+ """
+
+ @classmethod
+ def _sa_raise_deferred_config(cls):
+ raise orm_exc.UnmappedClassError(
+ cls,
+ msg="Class %s is a subclass of AutomapBase. "
+ "Mappings are not produced until the .prepare() "
+ "method is called on the class hierarchy."
+ % orm_exc._safe_cls_name(cls),
+ )
+
+
+def automap_base(declarative_base=None, **kw):
+ r"""Produce a declarative automap base.
+
+ This function produces a new base class that is a product of the
+ :class:`.AutomapBase` class as well a declarative base produced by
+ :func:`.declarative.declarative_base`.
+
+ All parameters other than ``declarative_base`` are keyword arguments
+ that are passed directly to the :func:`.declarative.declarative_base`
+ function.
+
+ :param declarative_base: an existing class produced by
+ :func:`.declarative.declarative_base`. When this is passed, the function
+ no longer invokes :func:`.declarative.declarative_base` itself, and all
+ other keyword arguments are ignored.
+
+ :param \**kw: keyword arguments are passed along to
+ :func:`.declarative.declarative_base`.
+
+ """
+ if declarative_base is None:
+ Base = _declarative_base(**kw)
+ else:
+ Base = declarative_base
+
+ return type(
+ Base.__name__,
+ (AutomapBase, Base),
+ {"__abstract__": True, "classes": util.Properties({})},
+ )
+
+
+def _is_many_to_many(automap_base, table):
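+    # detect a pure association table: exactly two ForeignKeyConstraint
+    # objects whose columns together account for every column of the table.
+    # returns (left table, right table, constraints) or (None, None, None)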
+ fk_constraints = [
+ const
+ for const in table.constraints
+ if isinstance(const, ForeignKeyConstraint)
+ ]
+ if len(fk_constraints) != 2:
+ return None, None, None
+
+ cols = sum(
+ [
+ [fk.parent for fk in fk_constraint.elements]
+ for fk_constraint in fk_constraints
+ ],
+ [],
+ )
+
+ if set(cols) != set(table.c):
+ return None, None, None
+
+ return (
+ fk_constraints[0].elements[0].column.table,
+ fk_constraints[1].elements[0].column.table,
+ fk_constraints,
+ )
+
+
+def _relationships_for_fks(
+ automap_base,
+ map_config,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+):
+ local_table = map_config.local_table
+ local_cls = map_config.cls # derived from a weakref, may be None
+
+ if local_table is None or local_cls is None:
+ return
+ for constraint in local_table.constraints:
+ if isinstance(constraint, ForeignKeyConstraint):
+ fks = constraint.elements
+ referred_table = fks[0].column.table
+ referred_cfg = table_to_map_config.get(referred_table, None)
+ if referred_cfg is None:
+ continue
+ referred_cls = referred_cfg.cls
+
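+            # skip foreign keys that link a subclass to its superclass;
+            # these establish joined inheritance, not a relationship()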
+ if local_cls is not referred_cls and issubclass(
+ local_cls, referred_cls
+ ):
+ continue
+
+ relationship_name = name_for_scalar_relationship(
+ automap_base, local_cls, referred_cls, constraint
+ )
+ backref_name = name_for_collection_relationship(
+ automap_base, referred_cls, local_cls, constraint
+ )
+
+ o2m_kws = {}
+ nullable = False not in {fk.parent.nullable for fk in fks}
+ if not nullable:
+ o2m_kws["cascade"] = "all, delete-orphan"
+
+ if (
+ constraint.ondelete
+ and constraint.ondelete.lower() == "cascade"
+ ):
+ o2m_kws["passive_deletes"] = True
+ else:
+ if (
+ constraint.ondelete
+ and constraint.ondelete.lower() == "set null"
+ ):
+ o2m_kws["passive_deletes"] = True
+
+ create_backref = backref_name not in referred_cfg.properties
+
+ if relationship_name not in map_config.properties:
+ if create_backref:
+ backref_obj = generate_relationship(
+ automap_base,
+ interfaces.ONETOMANY,
+ backref,
+ backref_name,
+ referred_cls,
+ local_cls,
+ collection_class=collection_class,
+ **o2m_kws
+ )
+ else:
+ backref_obj = None
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOONE,
+ relationship,
+ relationship_name,
+ local_cls,
+ referred_cls,
+ foreign_keys=[fk.parent for fk in constraint.elements],
+ backref=backref_obj,
+ remote_side=[fk.column for fk in constraint.elements],
+ )
+ if rel is not None:
+ map_config.properties[relationship_name] = rel
+ if not create_backref:
+ referred_cfg.properties[
+ backref_name
+ ].back_populates = relationship_name
+ elif create_backref:
+ rel = generate_relationship(
+ automap_base,
+ interfaces.ONETOMANY,
+ relationship,
+ backref_name,
+ referred_cls,
+ local_cls,
+ foreign_keys=[fk.parent for fk in constraint.elements],
+ back_populates=relationship_name,
+ collection_class=collection_class,
+ **o2m_kws
+ )
+ if rel is not None:
+ referred_cfg.properties[backref_name] = rel
+ map_config.properties[
+ relationship_name
+ ].back_populates = backref_name
+
+
+def _m2m_relationship(
+ automap_base,
+ lcl_m2m,
+ rem_m2m,
+ m2m_const,
+ table,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+):
+
+ map_config = table_to_map_config.get(lcl_m2m, None)
+ referred_cfg = table_to_map_config.get(rem_m2m, None)
+ if map_config is None or referred_cfg is None:
+ return
+
+ local_cls = map_config.cls
+ referred_cls = referred_cfg.cls
+
+ relationship_name = name_for_collection_relationship(
+ automap_base, local_cls, referred_cls, m2m_const[0]
+ )
+ backref_name = name_for_collection_relationship(
+ automap_base, referred_cls, local_cls, m2m_const[1]
+ )
+
+ create_backref = backref_name not in referred_cfg.properties
+
+ if table in table_to_map_config:
+ overlaps = "__*"
+ else:
+ overlaps = None
+
+ if relationship_name not in map_config.properties:
+ if create_backref:
+ backref_obj = generate_relationship(
+ automap_base,
+ interfaces.MANYTOMANY,
+ backref,
+ backref_name,
+ referred_cls,
+ local_cls,
+ collection_class=collection_class,
+ overlaps=overlaps,
+ )
+ else:
+ backref_obj = None
+
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOMANY,
+ relationship,
+ relationship_name,
+ local_cls,
+ referred_cls,
+ overlaps=overlaps,
+ secondary=table,
+ primaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[0].elements
+ ),
+ secondaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[1].elements
+ ),
+ backref=backref_obj,
+ collection_class=collection_class,
+ )
+ if rel is not None:
+ map_config.properties[relationship_name] = rel
+
+ if not create_backref:
+ referred_cfg.properties[
+ backref_name
+ ].back_populates = relationship_name
+ elif create_backref:
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOMANY,
+ relationship,
+ backref_name,
+ referred_cls,
+ local_cls,
+ overlaps=overlaps,
+ secondary=table,
+ primaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[1].elements
+ ),
+ secondaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[0].elements
+ ),
+ back_populates=relationship_name,
+ collection_class=collection_class,
+ )
+ if rel is not None:
+ referred_cfg.properties[backref_name] = rel
+ map_config.properties[
+ relationship_name
+ ].back_populates = backref_name
diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py
new file mode 100644
index 0000000..109e0c0
--- /dev/null
+++ b/lib/sqlalchemy/ext/baked.py
@@ -0,0 +1,648 @@
+# sqlalchemy/ext/baked.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Baked query extension.
+
+Provides a creational pattern for the :class:`.query.Query` object which
+allows the fully constructed object, Core select statement, and string
+compiled result to be fully cached.
+
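+A minimal usage sketch, assuming a mapped class ``User`` with a ``name``
+column::
+
+    from sqlalchemy import bindparam
+    from sqlalchemy.ext.baked import BakedQuery
+
+    bakery = BakedQuery.bakery()
+
+    def search_for_user(session, username):
+        baked_query = bakery(lambda session: session.query(User))
+        baked_query += lambda q: q.filter(
+            User.name == bindparam("username"))
+        return baked_query(session).params(username=username).all()
+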
+
+"""
+
+import logging
+
+from .. import exc as sa_exc
+from .. import util
+from ..orm import exc as orm_exc
+from ..orm import strategy_options
+from ..orm.query import Query
+from ..orm.session import Session
+from ..sql import func
+from ..sql import literal_column
+from ..sql import util as sql_util
+from ..util import collections_abc
+
+
+log = logging.getLogger(__name__)
+
+
+class Bakery(object):
+ """Callable which returns a :class:`.BakedQuery`.
+
+ This object is returned by the class method
+ :meth:`.BakedQuery.bakery`. It exists as an object
+ so that the "cache" can be easily inspected.
+
+ .. versionadded:: 1.2
+
+
+ """
+
+ __slots__ = "cls", "cache"
+
+ def __init__(self, cls_, cache):
+ self.cls = cls_
+ self.cache = cache
+
+ def __call__(self, initial_fn, *args):
+ return self.cls(self.cache, initial_fn, args)
+
+
+class BakedQuery(object):
+ """A builder object for :class:`.query.Query` objects."""
+
+ __slots__ = "steps", "_bakery", "_cache_key", "_spoiled"
+
+ def __init__(self, bakery, initial_fn, args=()):
+ self._cache_key = ()
+ self._update_cache_key(initial_fn, args)
+ self.steps = [initial_fn]
+ self._spoiled = False
+ self._bakery = bakery
+
+ @classmethod
+ def bakery(cls, size=200, _size_alert=None):
+ """Construct a new bakery.
+
+ :return: an instance of :class:`.Bakery`
+
+ """
+
+ return Bakery(cls, util.LRUCache(size, size_alert=_size_alert))
+
+ def _clone(self):
+ b1 = BakedQuery.__new__(BakedQuery)
+ b1._cache_key = self._cache_key
+ b1.steps = list(self.steps)
+ b1._bakery = self._bakery
+ b1._spoiled = self._spoiled
+ return b1
+
+ def _update_cache_key(self, fn, args=()):
+ self._cache_key += (fn.__code__,) + args
+
+ def __iadd__(self, other):
+ if isinstance(other, tuple):
+ self.add_criteria(*other)
+ else:
+ self.add_criteria(other)
+ return self
+
+ def __add__(self, other):
+ if isinstance(other, tuple):
+ return self.with_criteria(*other)
+ else:
+ return self.with_criteria(other)
+
+ def add_criteria(self, fn, *args):
+ """Add a criteria function to this :class:`.BakedQuery`.
+
+ This is equivalent to using the ``+=`` operator to
+ modify a :class:`.BakedQuery` in-place.
+
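+        E.g., the following two forms are equivalent, with ``bq`` being an
+        existing :class:`.BakedQuery` and ``User`` a hypothetical mapped
+        class::
+
+            bq.add_criteria(lambda q: q.filter(User.name == "some name"))
+
+            bq += lambda q: q.filter(User.name == "some name")
+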
+ """
+ self._update_cache_key(fn, args)
+ self.steps.append(fn)
+ return self
+
+ def with_criteria(self, fn, *args):
+ """Add a criteria function to a :class:`.BakedQuery` cloned from this
+ one.
+
+ This is equivalent to using the ``+`` operator to
+ produce a new :class:`.BakedQuery` with modifications.
+
+ """
+ return self._clone().add_criteria(fn, *args)
+
+ def for_session(self, session):
+ """Return a :class:`_baked.Result` object for this
+ :class:`.BakedQuery`.
+
+ This is equivalent to calling the :class:`.BakedQuery` as a
+ Python callable, e.g. ``result = my_baked_query(session)``.
+
+ """
+ return Result(self, session)
+
+ def __call__(self, session):
+ return self.for_session(session)
+
+ def spoil(self, full=False):
+ """Cancel any query caching that will occur on this BakedQuery object.
+
+        The BakedQuery can continue to be used normally; however, additional
+ creational functions will not be cached; they will be called
+ on every invocation.
+
+ This is to support the case where a particular step in constructing
+ a baked query disqualifies the query from being cacheable, such
+ as a variant that relies upon some uncacheable value.
+
+ :param full: if False, only functions added to this
+ :class:`.BakedQuery` object subsequent to the spoil step will be
+ non-cached; the state of the :class:`.BakedQuery` up until
+ this point will be pulled from the cache. If True, then the
+ entire :class:`_query.Query` object is built from scratch each
+ time, with all creational functions being called on each
+ invocation.
+
+ """
+ if not full and not self._spoiled:
+ _spoil_point = self._clone()
+ _spoil_point._cache_key += ("_query_only",)
+ self.steps = [_spoil_point._retrieve_baked_query]
+ self._spoiled = True
+ return self
+
+ def _effective_key(self, session):
+ """Return the key that actually goes into the cache dictionary for
+ this :class:`.BakedQuery`, taking into account the given
+ :class:`.Session`.
+
+ This basically means we also will include the session's query_class,
+ as the actual :class:`_query.Query` object is part of what's cached
+ and needs to match the type of :class:`_query.Query` that a later
+ session will want to use.
+
+ """
+ return self._cache_key + (session._query_cls,)
+
+ def _with_lazyload_options(self, options, effective_path, cache_path=None):
+ """Cloning version of _add_lazyload_options."""
+ q = self._clone()
+ q._add_lazyload_options(options, effective_path, cache_path=cache_path)
+ return q
+
+ def _add_lazyload_options(self, options, effective_path, cache_path=None):
+ """Used by per-state lazy loaders to add options to the
+ "lazy load" query from a parent query.
+
+ Creates a cache key based on given load path and query options;
+ if a repeatable cache key cannot be generated, the query is
+ "spoiled" so that it won't use caching.
+
+ """
+
+ key = ()
+
+ if not cache_path:
+ cache_path = effective_path
+
+ for opt in options:
+ if opt._is_legacy_option or opt._is_compile_state:
+ ck = opt._generate_cache_key()
+ if ck is None:
+ self.spoil(full=True)
+ else:
+ assert not ck[1], (
+ "loader options with variable bound parameters "
+ "not supported with baked queries. Please "
+ "use new-style select() statements for cached "
+ "ORM queries."
+ )
+ key += ck[0]
+
+ self.add_criteria(
+ lambda q: q._with_current_path(effective_path).options(*options),
+ cache_path.path,
+ key,
+ )
+
+ def _retrieve_baked_query(self, session):
+ query = self._bakery.get(self._effective_key(session), None)
+ if query is None:
+ query = self._as_query(session)
+ self._bakery[self._effective_key(session)] = query.with_session(
+ None
+ )
+ return query.with_session(session)
+
+ def _bake(self, session):
+ query = self._as_query(session)
+ query.session = None
+
+ # in 1.4, this is where before_compile() event is
+ # invoked
+ statement = query._statement_20()
+
+ # if the query is not safe to cache, we still do everything as though
+ # we did cache it, since the receiver of _bake() assumes subqueryload
+ # context was set up, etc.
+ #
+ # note also we want to cache the statement itself because this
+ # allows the statement itself to hold onto its cache key that is
+ # used by the Connection, which in itself is more expensive to
+ # generate than what BakedQuery was able to provide in 1.3 and prior
+
+ if statement._compile_options._bake_ok:
+ self._bakery[self._effective_key(session)] = (
+ query,
+ statement,
+ )
+
+ return query, statement
+
+ def to_query(self, query_or_session):
+ """Return the :class:`_query.Query` object for use as a subquery.
+
+ This method should be used within the lambda callable being used
+ to generate a step of an enclosing :class:`.BakedQuery`. The
+ parameter should normally be the :class:`_query.Query` object that
+ is passed to the lambda::
+
+ sub_bq = self.bakery(lambda s: s.query(User.name))
+ sub_bq += lambda q: q.filter(
+ User.id == Address.user_id).correlate(Address)
+
+ main_bq = self.bakery(lambda s: s.query(Address))
+ main_bq += lambda q: q.filter(
+ sub_bq.to_query(q).exists())
+
+ In the case where the subquery is used in the first callable against
+ a :class:`.Session`, the :class:`.Session` is also accepted::
+
+ sub_bq = self.bakery(lambda s: s.query(User.name))
+ sub_bq += lambda q: q.filter(
+ User.id == Address.user_id).correlate(Address)
+
+ main_bq = self.bakery(
+ lambda s: s.query(
+                    Address.id, sub_bq.to_query(s).scalar_subquery())
+ )
+
+ :param query_or_session: a :class:`_query.Query` object or a class
+ :class:`.Session` object, that is assumed to be within the context
+ of an enclosing :class:`.BakedQuery` callable.
+
+
+ .. versionadded:: 1.3
+
+
+ """
+
+ if isinstance(query_or_session, Session):
+ session = query_or_session
+ elif isinstance(query_or_session, Query):
+ session = query_or_session.session
+ if session is None:
+ raise sa_exc.ArgumentError(
+ "Given Query needs to be associated with a Session"
+ )
+ else:
+ raise TypeError(
+ "Query or Session object expected, got %r."
+ % type(query_or_session)
+ )
+ return self._as_query(session)
+
+ def _as_query(self, session):
+ query = self.steps[0](session)
+
+ for step in self.steps[1:]:
+ query = step(query)
+
+ return query
+
+
+class Result(object):
+ """Invokes a :class:`.BakedQuery` against a :class:`.Session`.
+
+ The :class:`_baked.Result` object is where the actual :class:`.query.Query`
+ object gets created, or retrieved from the cache,
+ against a target :class:`.Session`, and is then invoked for results.
+
+ """
+
+ __slots__ = "bq", "session", "_params", "_post_criteria"
+
+ def __init__(self, bq, session):
+ self.bq = bq
+ self.session = session
+ self._params = {}
+ self._post_criteria = []
+
+ def params(self, *args, **kw):
+ """Specify parameters to be replaced into the string SQL statement."""
+
+ if len(args) == 1:
+ kw.update(args[0])
+ elif len(args) > 0:
+ raise sa_exc.ArgumentError(
+ "params() takes zero or one positional argument, "
+ "which is a dictionary."
+ )
+ self._params.update(kw)
+ return self
+
+ def _using_post_criteria(self, fns):
+ if fns:
+ self._post_criteria.extend(fns)
+ return self
+
+ def with_post_criteria(self, fn):
+ """Add a criteria function that will be applied post-cache.
+
+ This adds a function that will be run against the
+ :class:`_query.Query` object after it is retrieved from the
+ cache. This currently includes **only** the
+ :meth:`_query.Query.params` and :meth:`_query.Query.execution_options`
+ methods.
+
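+        E.g., a sketch that applies bound parameter values after the query
+        is retrieved from the cache, with ``bq`` being an existing
+        :class:`.BakedQuery` and the parameter name illustrative::
+
+            result = bq(session).with_post_criteria(
+                lambda q: q.params(some_parameter="some value")
+            )
+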
+ .. warning:: :meth:`_baked.Result.with_post_criteria`
+ functions are applied
+ to the :class:`_query.Query`
+ object **after** the query's SQL statement
+ object has been retrieved from the cache. Only
+ :meth:`_query.Query.params` and
+ :meth:`_query.Query.execution_options`
+ methods should be used.
+
+
+ .. versionadded:: 1.2
+
+
+ """
+ return self._using_post_criteria([fn])
+
+ def _as_query(self):
+ q = self.bq._as_query(self.session).params(self._params)
+ for fn in self._post_criteria:
+ q = fn(q)
+ return q
+
+ def __str__(self):
+ return str(self._as_query())
+
+ def __iter__(self):
+ return self._iter().__iter__()
+
+ def _iter(self):
+ bq = self.bq
+
+ if not self.session.enable_baked_queries or bq._spoiled:
+ return self._as_query()._iter()
+
+ query, statement = bq._bakery.get(
+ bq._effective_key(self.session), (None, None)
+ )
+ if query is None:
+ query, statement = bq._bake(self.session)
+
+ if self._params:
+ q = query.params(self._params)
+ else:
+ q = query
+ for fn in self._post_criteria:
+ q = fn(q)
+
+ params = q._params
+ execution_options = dict(q._execution_options)
+ execution_options.update(
+ {
+ "_sa_orm_load_options": q.load_options,
+ "compiled_cache": bq._bakery,
+ }
+ )
+
+ result = self.session.execute(
+ statement, params, execution_options=execution_options
+ )
+ if result._attributes.get("is_single_entity", False):
+ result = result.scalars()
+
+ if result._attributes.get("filtered", False):
+ result = result.unique()
+
+ return result
+
+ def count(self):
+        """Return the 'count'.
+
+ Equivalent to :meth:`_query.Query.count`.
+
+ Note this uses a subquery to ensure an accurate count regardless
+ of the structure of the original statement.
+
+ .. versionadded:: 1.1.6
+
+ """
+
+ col = func.count(literal_column("*"))
+ bq = self.bq.with_criteria(lambda q: q._from_self(col))
+ return bq.for_session(self.session).params(self._params).scalar()
+
+ def scalar(self):
+ """Return the first element of the first result or None
+        if no rows are present. If multiple rows are returned,
+ raises MultipleResultsFound.
+
+ Equivalent to :meth:`_query.Query.scalar`.
+
+ .. versionadded:: 1.1.6
+
+ """
+ try:
+ ret = self.one()
+ if not isinstance(ret, collections_abc.Sequence):
+ return ret
+ return ret[0]
+ except orm_exc.NoResultFound:
+ return None
+
+ def first(self):
+ """Return the first row.
+
+ Equivalent to :meth:`_query.Query.first`.
+
+ """
+
+ bq = self.bq.with_criteria(lambda q: q.slice(0, 1))
+ return (
+ bq.for_session(self.session)
+ .params(self._params)
+ ._using_post_criteria(self._post_criteria)
+ ._iter()
+ .first()
+ )
+
+ def one(self):
+ """Return exactly one result or raise an exception.
+
+ Equivalent to :meth:`_query.Query.one`.
+
+ """
+ return self._iter().one()
+
+ def one_or_none(self):
+ """Return one or zero results, or raise an exception for multiple
+ rows.
+
+ Equivalent to :meth:`_query.Query.one_or_none`.
+
+ .. versionadded:: 1.0.9
+
+ """
+ return self._iter().one_or_none()
+
+ def all(self):
+ """Return all rows.
+
+ Equivalent to :meth:`_query.Query.all`.
+
+ """
+ return self._iter().all()
+
+ def get(self, ident):
+ """Retrieve an object based on identity.
+
+ Equivalent to :meth:`_query.Query.get`.
+
+ """
+
+ query = self.bq.steps[0](self.session)
+ return query._get_impl(ident, self._load_on_pk_identity)
+
+ def _load_on_pk_identity(self, session, query, primary_key_identity, **kw):
+ """Load the given primary key identity from the database."""
+
+ mapper = query._raw_columns[0]._annotations["parententity"]
+
+ _get_clause, _get_params = mapper._get_clause
+
+ def setup(query):
+ _lcl_get_clause = _get_clause
+ q = query._clone()
+ q._get_condition()
+ q._order_by = None
+
+ # None present in ident - turn those comparisons
+ # into "IS NULL"
+ if None in primary_key_identity:
+ nones = set(
+ [
+ _get_params[col].key
+ for col, value in zip(
+ mapper.primary_key, primary_key_identity
+ )
+ if value is None
+ ]
+ )
+ _lcl_get_clause = sql_util.adapt_criterion_to_null(
+ _lcl_get_clause, nones
+ )
+
+ # TODO: can mapper._get_clause be pre-adapted?
+ q._where_criteria = (
+ sql_util._deep_annotate(_lcl_get_clause, {"_orm_adapt": True}),
+ )
+
+ for fn in self._post_criteria:
+ q = fn(q)
+ return q
+
+ # cache the query against a key that includes
+ # which positions in the primary key are NULL
+ # (remember, we can map to an OUTER JOIN)
+ bq = self.bq
+
+ # add the clause we got from mapper._get_clause to the cache
+ # key so that if a race causes multiple calls to _get_clause,
+ # we've cached on ours
+ bq = bq._clone()
+ bq._cache_key += (_get_clause,)
+
+ bq = bq.with_criteria(
+ setup, tuple(elem is None for elem in primary_key_identity)
+ )
+
+ params = dict(
+ [
+ (_get_params[primary_key].key, id_val)
+ for id_val, primary_key in zip(
+ primary_key_identity, mapper.primary_key
+ )
+ ]
+ )
+
+ result = list(bq.for_session(self.session).params(**params))
+ l = len(result)
+ if l > 1:
+ raise orm_exc.MultipleResultsFound()
+ elif l:
+ return result[0]
+ else:
+ return None
+
+
+@util.deprecated(
+ "1.2", "Baked lazy loading is now the default implementation."
+)
+def bake_lazy_loaders():
+ """Enable the use of baked queries for all lazyloaders systemwide.
+
+ The "baked" implementation of lazy loading is now the sole implementation
+ for the base lazy loader; this method has no effect except for a warning.
+
+ """
+ pass
+
+
+@util.deprecated(
+ "1.2", "Baked lazy loading is now the default implementation."
+)
+def unbake_lazy_loaders():
+ """Disable the use of baked queries for all lazyloaders systemwide.
+
+ This method now raises NotImplementedError() as the "baked" implementation
+ is the only lazy load implementation. The
+ :paramref:`_orm.relationship.bake_queries` flag may be used to disable
+ the caching of queries on a per-relationship basis.
+
+ """
+ raise NotImplementedError(
+ "Baked lazy loading is now the default implementation"
+ )
+
+
+@strategy_options.loader_option()
+def baked_lazyload(loadopt, attr):
+ """Indicate that the given attribute should be loaded using "lazy"
+ loading with a "baked" query used in the load.
+
+ """
+ return loadopt.set_relationship_strategy(attr, {"lazy": "baked_select"})
+
+
+@baked_lazyload._add_unbound_fn
+@util.deprecated(
+ "1.2",
+ "Baked lazy loading is now the default "
+ "implementation for lazy loading.",
+)
+def baked_lazyload(*keys):
+ return strategy_options._UnboundLoad._from_keys(
+ strategy_options._UnboundLoad.baked_lazyload, keys, False, {}
+ )
+
+
+@baked_lazyload._add_unbound_all_fn
+@util.deprecated(
+ "1.2",
+ "Baked lazy loading is now the default "
+ "implementation for lazy loading.",
+)
+def baked_lazyload_all(*keys):
+ return strategy_options._UnboundLoad._from_keys(
+ strategy_options._UnboundLoad.baked_lazyload, keys, True, {}
+ )
+
+
+baked_lazyload = baked_lazyload._unbound_fn
+baked_lazyload_all = baked_lazyload_all._unbound_all_fn
+
+bakery = BakedQuery.bakery
diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py
new file mode 100644
index 0000000..76b59ea
--- /dev/null
+++ b/lib/sqlalchemy/ext/compiler.py
@@ -0,0 +1,613 @@
+# ext/compiler.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""Provides an API for creation of custom ClauseElements and compilers.
+
+Synopsis
+========
+
+Usage involves the creation of one or more
+:class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or
+more callables defining its compilation::
+
+ from sqlalchemy.ext.compiler import compiles
+ from sqlalchemy.sql.expression import ColumnClause
+
+ class MyColumn(ColumnClause):
+ inherit_cache = True
+
+ @compiles(MyColumn)
+ def compile_mycolumn(element, compiler, **kw):
+ return "[%s]" % element.name
+
+Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`,
+the base expression element for named column objects. The ``compiles``
+decorator registers itself with the ``MyColumn`` class so that it is invoked
+when the object is compiled to a string::
+
+ from sqlalchemy import select
+
+ s = select(MyColumn('x'), MyColumn('y'))
+ print(str(s))
+
+Produces::
+
+ SELECT [x], [y]
+
+Dialect-specific compilation rules
+==================================
+
+Compilers can also be made dialect-specific. The appropriate compiler will be
+invoked for the dialect in use::
+
+ from sqlalchemy.schema import DDLElement
+
+ class AlterColumn(DDLElement):
+ inherit_cache = False
+
+ def __init__(self, column, cmd):
+ self.column = column
+ self.cmd = cmd
+
+ @compiles(AlterColumn)
+ def visit_alter_column(element, compiler, **kw):
+ return "ALTER COLUMN %s ..." % element.column.name
+
+ @compiles(AlterColumn, 'postgresql')
+ def visit_alter_column(element, compiler, **kw):
+ return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name,
+ element.column.name)
+
+The second ``visit_alter_column`` will be invoked when any ``postgresql``
+dialect is used.
+
+.. _compilerext_compiling_subelements:
+
+Compiling sub-elements of a custom expression construct
+=======================================================
+
+The ``compiler`` argument is the
+:class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object
+can be inspected for any information about the in-progress compilation,
+including ``compiler.dialect``, ``compiler.statement`` etc. The
+:class:`~sqlalchemy.sql.compiler.SQLCompiler` and
+:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()``
+method which can be used for compilation of embedded attributes::
+
+ from sqlalchemy.sql.expression import Executable, ClauseElement
+
+ class InsertFromSelect(Executable, ClauseElement):
+ inherit_cache = False
+
+ def __init__(self, table, select):
+ self.table = table
+ self.select = select
+
+ @compiles(InsertFromSelect)
+ def visit_insert_from_select(element, compiler, **kw):
+ return "INSERT INTO %s (%s)" % (
+ compiler.process(element.table, asfrom=True, **kw),
+ compiler.process(element.select, **kw)
+ )
+
+ insert = InsertFromSelect(t1, select(t1).where(t1.c.x>5))
+ print(insert)
+
+Produces::
+
+ "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z
+ FROM mytable WHERE mytable.x > :x_1)"
+
+.. note::
+
+    The above ``InsertFromSelect`` construct is only an example; this
+    functionality is already available using the
+    :meth:`_expression.Insert.from_select` method.
+
+.. note::
+
+ The above ``InsertFromSelect`` construct probably wants to have "autocommit"
+ enabled. See :ref:`enabling_compiled_autocommit` for this step.
+
+Cross Compiling between SQL and DDL compilers
+---------------------------------------------
+
+SQL and DDL constructs are each compiled using different base compilers -
+``SQLCompiler`` and ``DDLCompiler``. A common need is to access the
+compilation rules of SQL expressions from within a DDL expression. The
+``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as
+below where we generate a CHECK constraint that embeds a SQL expression::
+
+ @compiles(MyConstraint)
+ def compile_my_constraint(constraint, ddlcompiler, **kw):
+ kw['literal_binds'] = True
+ return "CONSTRAINT %s CHECK (%s)" % (
+ constraint.name,
+ ddlcompiler.sql_compiler.process(
+ constraint.expression, **kw)
+ )
+
+Above, we add an additional flag, ``literal_binds``, to the process step as
+called by :meth:`.SQLCompiler.process`. This
+indicates that any SQL expression which refers to a :class:`.BindParameter`
+object or other "literal" object such as those which refer to strings or
+integers should be rendered **in-place**, rather than being referred to as
+a bound parameter; when emitting DDL, bound parameters are typically not
+supported.
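+
+``MyConstraint`` itself is not defined above; a hypothetical minimal version,
+shown only for illustration, could be a small DDL construct holding a name and
+a SQL expression::
+
+    from sqlalchemy.schema import DDLElement
+
+    class MyConstraint(DDLElement):
+        inherit_cache = False
+
+        def __init__(self, name, expression):
+            self.name = name
+            self.expression = expression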
+
+
+.. _enabling_compiled_autocommit:
+
+Enabling Autocommit on a Construct
+==================================
+
+Recall from the section :ref:`autocommit` that the :class:`_engine.Engine`,
+when
+asked to execute a construct in the absence of a user-defined transaction,
+detects if the given construct represents DML or DDL, that is, a data
+modification or data definition statement, which requires (or may require,
+in the case of DDL) that the transaction generated by the DBAPI be committed
+(recall that DBAPI always has a transaction going on regardless of what
+SQLAlchemy does). Checking for this is actually accomplished by checking for
+the "autocommit" execution option on the construct. When building a
+construct like an INSERT derivation, a new DDL type, or perhaps a stored
+procedure that alters data, the "autocommit" option needs to be set in order
+for the statement to function with "connectionless" execution
+(as described in :ref:`dbengine_implicit`).
+
+Currently a quick way to do this is to subclass :class:`.Executable`, then
+add the "autocommit" flag to the ``_execution_options`` dictionary (note this
+is a "frozen" dictionary which supplies a generative ``union()`` method)::
+
+ from sqlalchemy.sql.expression import Executable, ClauseElement
+
+ class MyInsertThing(Executable, ClauseElement):
+ _execution_options = \
+ Executable._execution_options.union({'autocommit': True})
+
+More succinctly, if the construct is truly similar to an INSERT, UPDATE, or
+DELETE, :class:`.UpdateBase` can be used, which already is a subclass
+of :class:`.Executable`, :class:`_expression.ClauseElement` and includes the
+``autocommit`` flag::
+
+ from sqlalchemy.sql.expression import UpdateBase
+
+ class MyInsertThing(UpdateBase):
+ def __init__(self, ...):
+ ...
+
+
+
+
+DDL elements that subclass :class:`.DDLElement` already have the
+"autocommit" flag turned on.
+
+
+
+
+Changing the default compilation of existing constructs
+=======================================================
+
+The compiler extension applies just as well to the existing constructs. When
+overriding the compilation of a built-in SQL construct, the @compiles
+decorator is invoked upon the appropriate class (be sure to use the class,
+i.e. ``Insert`` or ``Select``, instead of the creation function such
+as ``insert()`` or ``select()``).
+
+Within the new compilation function, to get at the "original" compilation
+routine, use the appropriate visit_XXX method - this is because
+compiler.process() will call upon the overriding routine and cause
+an endless loop. For example, to add a "prefix" to all insert statements::
+
+ from sqlalchemy.sql.expression import Insert
+
+ @compiles(Insert)
+ def prefix_inserts(insert, compiler, **kw):
+ return compiler.visit_insert(insert.prefix_with("some prefix"), **kw)
+
+The above compiler will prefix all INSERT statements with "some prefix" when
+compiled.
+
+.. _type_compilation_extension:
+
+Changing Compilation of Types
+=============================
+
+The ``compiles`` decorator works for types, too, such as below where we
+implement the MS-SQL specific 'max' keyword for ``String``/``VARCHAR``::
+
+ @compiles(String, 'mssql')
+ @compiles(VARCHAR, 'mssql')
+ def compile_varchar(element, compiler, **kw):
+ if element.length == 'max':
+ return "VARCHAR('max')"
+ else:
+ return compiler.visit_VARCHAR(element, **kw)
+
+ foo = Table('foo', metadata,
+ Column('data', VARCHAR('max'))
+ )
+
+Subclassing Guidelines
+======================
+
+A big part of using the compiler extension is subclassing SQLAlchemy
+expression constructs. To make this easier, the expression and
+schema packages feature a set of "bases" intended for common tasks.
+A synopsis is as follows:
+
+* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root
+ expression class. Any SQL expression can be derived from this base, and is
+ probably the best choice for longer constructs such as specialized INSERT
+ statements.
+
+* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all
+ "column-like" elements. Anything that you'd place in the "columns" clause of
+ a SELECT statement (as well as order by and group by) can derive from this -
+ the object will automatically have Python "comparison" behavior.
+
+  :class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a
+  ``type`` member which is the expression's return type. This can be
+  established at the instance level in the constructor, or at the class
+  level if it is generally constant::
+
+ class timestamp(ColumnElement):
+ type = TIMESTAMP()
+ inherit_cache = True
+
+* :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a
+ ``ColumnElement`` and a "from clause" like object, and represents a SQL
+ function or stored procedure type of call. Since most databases support
+ statements along the line of "SELECT FROM <some function>"
+ ``FunctionElement`` adds in the ability to be used in the FROM clause of a
+ ``select()`` construct::
+
+ from sqlalchemy.sql.expression import FunctionElement
+
+ class coalesce(FunctionElement):
+ name = 'coalesce'
+ inherit_cache = True
+
+ @compiles(coalesce)
+ def compile(element, compiler, **kw):
+ return "coalesce(%s)" % compiler.process(element.clauses, **kw)
+
+ @compiles(coalesce, 'oracle')
+ def compile(element, compiler, **kw):
+ if len(element.clauses) > 2:
+ raise TypeError("coalesce only supports two arguments on Oracle")
+ return "nvl(%s)" % compiler.process(element.clauses, **kw)
+
+* :class:`.DDLElement` - The root of all DDL expressions,
+ like CREATE TABLE, ALTER TABLE, etc. Compilation of :class:`.DDLElement`
+ subclasses is issued by a :class:`.DDLCompiler` instead of a
+ :class:`.SQLCompiler`. :class:`.DDLElement` can also be used as an event hook
+ in conjunction with event hooks like :meth:`.DDLEvents.before_create` and
+ :meth:`.DDLEvents.after_create`, allowing the construct to be invoked
+ automatically during CREATE TABLE and DROP TABLE sequences.
+
+ .. seealso::
+
+ :ref:`metadata_ddl_toplevel` - contains examples of associating
+ :class:`.DDL` objects (which are themselves :class:`.DDLElement`
+ instances) with :class:`.DDLEvents` event hooks.
+
+* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which
+ should be used with any expression class that represents a "standalone"
+ SQL statement that can be passed directly to an ``execute()`` method. It
+ is already implicit within ``DDLElement`` and ``FunctionElement``.
+
+Most of the above constructs also respond to SQL statement caching. A
+subclassed construct will want to define the caching behavior for the object,
+which usually means setting the flag ``inherit_cache`` to the value of
+``False`` or ``True``. See the next section :ref:`compilerext_caching`
+for background.
+
+
+.. _compilerext_caching:
+
+Enabling Caching Support for Custom Constructs
+==============================================
+
+SQLAlchemy as of version 1.4 includes a
+:ref:`SQL compilation caching facility <sql_caching>` which will allow
+equivalent SQL constructs to cache their stringified form, along with other
+structural information used to fetch results from the statement.
+
+For reasons discussed at :ref:`caching_caveats`, the implementation of this
+caching system takes a conservative approach towards including custom SQL
+constructs and/or subclasses within the caching system. This includes that
+any user-defined SQL constructs, including all the examples for this
+extension, will not participate in caching by default unless they positively
+assert that they are able to do so. The :attr:`.HasCacheKey.inherit_cache`
+attribute when set to ``True`` at the class level of a specific subclass
+will indicate that instances of this class may be safely cached, using the
+cache key generation scheme of the immediate superclass. This applies
+for example to the "synopsis" example indicated previously::
+
+ class MyColumn(ColumnClause):
+ inherit_cache = True
+
+ @compiles(MyColumn)
+ def compile_mycolumn(element, compiler, **kw):
+ return "[%s]" % element.name
+
+Above, the ``MyColumn`` class does not include any new state that
+affects its SQL compilation; the cache key of ``MyColumn`` instances will
+make use of that of the ``ColumnClause`` superclass, meaning it will take
+into account the class of the object (``MyColumn``), the string name and
+datatype of the object::
+
+ >>> MyColumn("some_name", String())._generate_cache_key()
+ CacheKey(
+ key=('0', <class '__main__.MyColumn'>,
+ 'name', 'some_name',
+ 'type', (<class 'sqlalchemy.sql.sqltypes.String'>,
+ ('length', None), ('collation', None))
+ ), bindparams=[])
+
+For objects that are likely to be **used liberally as components within many
+larger statements**, such as :class:`_schema.Column` subclasses and custom SQL
+datatypes, it's important that **caching be enabled as much as possible**, as
+this may otherwise negatively affect performance.
+
+An example of an object that **does** contain state which affects its SQL
+compilation is the one illustrated at :ref:`compilerext_compiling_subelements`;
+this is an "INSERT FROM SELECT" construct that combines together a
+:class:`_schema.Table` as well as a :class:`_sql.Select` construct, each of
+which independently affect the SQL string generation of the construct. For
+this class, the example illustrates that it simply does not participate in
+caching::
+
+ class InsertFromSelect(Executable, ClauseElement):
+ inherit_cache = False
+
+ def __init__(self, table, select):
+ self.table = table
+ self.select = select
+
+ @compiles(InsertFromSelect)
+ def visit_insert_from_select(element, compiler, **kw):
+ return "INSERT INTO %s (%s)" % (
+ compiler.process(element.table, asfrom=True, **kw),
+ compiler.process(element.select, **kw)
+ )
+
+While it is also possible that the above ``InsertFromSelect`` could be made to
+produce a cache key that is composed of that of the :class:`_schema.Table` and
+:class:`_sql.Select` components together, the API for this is not at the moment
+fully public. However, for an "INSERT FROM SELECT" construct, which is only
+used by itself for specific operations, caching is not as critical as in the
+previous example.
+
+For objects that are **used in relative isolation and are generally
+standalone**, such as custom :term:`DML` constructs like an "INSERT FROM
+SELECT", **caching is generally less critical** as the lack of caching for such
+a construct will have only localized implications for that specific operation.
+
+
+Further Examples
+================
+
+"UTC timestamp" function
+-------------------------
+
+A function that works like "CURRENT_TIMESTAMP" except that it applies the
+appropriate conversions so that the time is in UTC time. Timestamps are best
+stored in relational databases as UTC, without time zones. UTC so that your
+database doesn't think time has gone backwards in the hour when daylight
+saving time ends, without timezones because timezones are like character
+encodings - they're best applied only at the endpoints of an application
+(i.e. convert to UTC upon user input, re-apply the desired timezone upon
+display).
+
+For PostgreSQL and Microsoft SQL Server::
+
+ from sqlalchemy.sql import expression
+ from sqlalchemy.ext.compiler import compiles
+ from sqlalchemy.types import DateTime
+
+ class utcnow(expression.FunctionElement):
+ type = DateTime()
+ inherit_cache = True
+
+ @compiles(utcnow, 'postgresql')
+ def pg_utcnow(element, compiler, **kw):
+ return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
+
+ @compiles(utcnow, 'mssql')
+ def ms_utcnow(element, compiler, **kw):
+ return "GETUTCDATE()"
+
+Example usage::
+
+ from sqlalchemy import (
+ Table, Column, Integer, String, DateTime, MetaData
+ )
+ metadata = MetaData()
+ event = Table("event", metadata,
+ Column("id", Integer, primary_key=True),
+ Column("description", String(50), nullable=False),
+ Column("timestamp", DateTime, server_default=utcnow())
+ )
+
+"GREATEST" function
+-------------------
+
+The "GREATEST" function is given any number of arguments and returns the one
+that is of the highest value - its equivalent to Python's ``max``
+function. A SQL standard version versus a CASE based version which only
+accommodates two arguments::
+
+ from sqlalchemy.sql import expression, case
+ from sqlalchemy.ext.compiler import compiles
+ from sqlalchemy.types import Numeric
+
+ class greatest(expression.FunctionElement):
+ type = Numeric()
+ name = 'greatest'
+ inherit_cache = True
+
+ @compiles(greatest)
+ def default_greatest(element, compiler, **kw):
+ return compiler.visit_function(element)
+
+ @compiles(greatest, 'sqlite')
+ @compiles(greatest, 'mssql')
+ @compiles(greatest, 'oracle')
+ def case_greatest(element, compiler, **kw):
+ arg1, arg2 = list(element.clauses)
+ return compiler.process(case([(arg1 > arg2, arg1)], else_=arg2), **kw)
+
+Example usage::
+
+ Session.query(Account).\
+ filter(
+ greatest(
+ Account.checking_balance,
+ Account.savings_balance) > 10000
+ )
+
+"false" expression
+------------------
+
+Render a "false" constant expression, rendering as "0" on platforms that
+don't have a "false" constant::
+
+ from sqlalchemy.sql import expression
+ from sqlalchemy.ext.compiler import compiles
+
+ class sql_false(expression.ColumnElement):
+ inherit_cache = True
+
+ @compiles(sql_false)
+ def default_false(element, compiler, **kw):
+ return "false"
+
+ @compiles(sql_false, 'mssql')
+ @compiles(sql_false, 'mysql')
+ @compiles(sql_false, 'oracle')
+ def int_false(element, compiler, **kw):
+ return "0"
+
+Example usage::
+
+ from sqlalchemy import select, union_all
+
+ exp = union_all(
+ select(users.c.name, sql_false().label("enrolled")),
+ select(customers.c.name, customers.c.enrolled)
+ )
+
+"""
+from .. import exc
+from .. import util
+from ..sql import sqltypes
+
+
+def compiles(class_, *specs):
+ """Register a function as a compiler for a
+ given :class:`_expression.ClauseElement` type."""
+
+ def decorate(fn):
+ # get an existing @compiles handler
+ existing = class_.__dict__.get("_compiler_dispatcher", None)
+
+ # get the original handler. All ClauseElement classes have one
+ # of these, but some TypeEngine classes will not.
+ existing_dispatch = getattr(class_, "_compiler_dispatch", None)
+
+        if not existing:
+            existing = _dispatcher()
+
+            if existing_dispatch:
+
+                def _wrap_existing_dispatch(element, compiler, **kw):
+                    try:
+                        return existing_dispatch(element, compiler, **kw)
+                    except exc.UnsupportedCompilationError as uce:
+                        util.raise_(
+                            exc.UnsupportedCompilationError(
+                                compiler,
+                                type(element),
+                                message="%s construct has no default "
+                                "compilation handler." % type(element),
+                            ),
+                            from_=uce,
+                        )
+
+                existing.specs["default"] = _wrap_existing_dispatch
+
+            # TODO: why is the lambda needed ?
+            setattr(
+                class_,
+                "_compiler_dispatch",
+                lambda *arg, **kw: existing(*arg, **kw),
+            )
+            setattr(class_, "_compiler_dispatcher", existing)
+
+ if specs:
+ for s in specs:
+ existing.specs[s] = fn
+
+ else:
+ existing.specs["default"] = fn
+ return fn
+
+ return decorate
+
+
+def deregister(class_):
+ """Remove all custom compilers associated with a given
+ :class:`_expression.ClauseElement` type.
+
+ """
+
+ if hasattr(class_, "_compiler_dispatcher"):
+ class_._compiler_dispatch = class_._original_compiler_dispatch
+ del class_._compiler_dispatcher
+
+
+class _dispatcher(object):
+ def __init__(self):
+ self.specs = {}
+
+ def __call__(self, element, compiler, **kw):
+ # TODO: yes, this could also switch off of DBAPI in use.
+ fn = self.specs.get(compiler.dialect.name, None)
+ if not fn:
+ try:
+ fn = self.specs["default"]
+ except KeyError as ke:
+ util.raise_(
+ exc.UnsupportedCompilationError(
+ compiler,
+ type(element),
+ message="%s construct has no default "
+ "compilation handler." % type(element),
+ ),
+ replace_context=ke,
+ )
+
+ # if compilation includes add_to_result_map, collect add_to_result_map
+ # arguments from the user-defined callable, which are probably none
+ # because this is not public API. if it wasn't called, then call it
+ # ourselves.
+ arm = kw.get("add_to_result_map", None)
+ if arm:
+ arm_collection = []
+ kw["add_to_result_map"] = lambda *args: arm_collection.append(args)
+
+ expr = fn(element, compiler, **kw)
+
+ if arm:
+ if not arm_collection:
+ arm_collection.append(
+ (None, None, (element,), sqltypes.NULLTYPE)
+ )
+ for tup in arm_collection:
+ arm(*tup)
+ return expr
diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py
new file mode 100644
index 0000000..6215e35
--- /dev/null
+++ b/lib/sqlalchemy/ext/declarative/__init__.py
@@ -0,0 +1,64 @@
+# ext/declarative/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .extensions import AbstractConcreteBase
+from .extensions import ConcreteBase
+from .extensions import DeferredReflection
+from .extensions import instrument_declarative
+from ... import util
+from ...orm.decl_api import as_declarative as _as_declarative
+from ...orm.decl_api import declarative_base as _declarative_base
+from ...orm.decl_api import DeclarativeMeta
+from ...orm.decl_api import declared_attr
+from ...orm.decl_api import has_inherited_table as _has_inherited_table
+from ...orm.decl_api import synonym_for as _synonym_for
+
+
+@util.moved_20(
+ "The ``declarative_base()`` function is now available as "
+ ":func:`sqlalchemy.orm.declarative_base`."
+)
+def declarative_base(*arg, **kw):
+ return _declarative_base(*arg, **kw)
+
+
+@util.moved_20(
+ "The ``as_declarative()`` function is now available as "
+ ":func:`sqlalchemy.orm.as_declarative`"
+)
+def as_declarative(*arg, **kw):
+ return _as_declarative(*arg, **kw)
+
+
+@util.moved_20(
+ "The ``has_inherited_table()`` function is now available as "
+ ":func:`sqlalchemy.orm.has_inherited_table`."
+)
+def has_inherited_table(*arg, **kw):
+ return _has_inherited_table(*arg, **kw)
+
+
+@util.moved_20(
+ "The ``synonym_for()`` function is now available as "
+ ":func:`sqlalchemy.orm.synonym_for`"
+)
+def synonym_for(*arg, **kw):
+ return _synonym_for(*arg, **kw)
+
+
+__all__ = [
+ "declarative_base",
+ "synonym_for",
+ "has_inherited_table",
+ "instrument_declarative",
+ "declared_attr",
+ "as_declarative",
+ "ConcreteBase",
+ "AbstractConcreteBase",
+ "DeclarativeMeta",
+ "DeferredReflection",
+]
diff --git a/lib/sqlalchemy/ext/declarative/extensions.py b/lib/sqlalchemy/ext/declarative/extensions.py
new file mode 100644
index 0000000..7818841
--- /dev/null
+++ b/lib/sqlalchemy/ext/declarative/extensions.py
@@ -0,0 +1,463 @@
+# ext/declarative/extensions.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Public API functions and helpers for declarative."""
+
+
+from ... import inspection
+from ... import util
+from ...orm import exc as orm_exc
+from ...orm import registry
+from ...orm import relationships
+from ...orm.base import _mapper_or_none
+from ...orm.clsregistry import _resolver
+from ...orm.decl_base import _DeferredMapperConfig
+from ...orm.util import polymorphic_union
+from ...schema import Table
+from ...util import OrderedDict
+
+
+@util.deprecated(
+ "2.0",
+ "the instrument_declarative function is deprecated "
+ "and will be removed in SQLAlhcemy 2.0. Please use "
+ ":meth:`_orm.registry.map_declaratively",
+)
+def instrument_declarative(cls, cls_registry, metadata):
+ """Given a class, configure the class declaratively,
+ using the given registry, which can be any dictionary, and
+ MetaData object.
+
+ """
+ registry(metadata=metadata, class_registry=cls_registry).map_declaratively(
+ cls
+ )
+
+
+class ConcreteBase(object):
+ """A helper class for 'concrete' declarative mappings.
+
+ :class:`.ConcreteBase` will use the :func:`.polymorphic_union`
+ function automatically, against all tables mapped as a subclass
+ to this class. The function is called via the
+ ``__declare_last__()`` function, which is essentially
+ a hook for the :meth:`.after_configured` event.
+
+ :class:`.ConcreteBase` produces a mapped
+ table for the class itself. Compare to :class:`.AbstractConcreteBase`,
+ which does not.
+
+ Example::
+
+ from sqlalchemy.ext.declarative import ConcreteBase
+
+ class Employee(ConcreteBase, Base):
+ __tablename__ = 'employee'
+ employee_id = Column(Integer, primary_key=True)
+ name = Column(String(50))
+ __mapper_args__ = {
+ 'polymorphic_identity':'employee',
+ 'concrete':True}
+
+ class Manager(Employee):
+ __tablename__ = 'manager'
+ employee_id = Column(Integer, primary_key=True)
+ name = Column(String(50))
+ manager_data = Column(String(40))
+ __mapper_args__ = {
+ 'polymorphic_identity':'manager',
+ 'concrete':True}
+
+
+ The name of the discriminator column used by :func:`.polymorphic_union`
+ defaults to the name ``type``. To suit the use case of a mapping where an
+ actual column in a mapped table is already named ``type``, the
+ discriminator name can be configured by setting the
+ ``_concrete_discriminator_name`` attribute::
+
+ class Employee(ConcreteBase, Base):
+ _concrete_discriminator_name = '_concrete_discriminator'
+
+ .. versionadded:: 1.3.19 Added the ``_concrete_discriminator_name``
+ attribute to :class:`_declarative.ConcreteBase` so that the
+ virtual discriminator column name can be customized.
+
+ .. versionchanged:: 1.4.2 The ``_concrete_discriminator_name`` attribute
+ need only be placed on the basemost class to take correct effect for
+ all subclasses. An explicit error message is now raised if the
+ mapped column names conflict with the discriminator name, whereas
+ in the 1.3.x series there would be some warnings and then a non-useful
+ query would be generated.
+
+ .. seealso::
+
+ :class:`.AbstractConcreteBase`
+
+ :ref:`concrete_inheritance`
+
+
+ """
+
+ @classmethod
+ def _create_polymorphic_union(cls, mappers, discriminator_name):
+ return polymorphic_union(
+ OrderedDict(
+ (mp.polymorphic_identity, mp.local_table) for mp in mappers
+ ),
+ discriminator_name,
+ "pjoin",
+ )
+
+ @classmethod
+ def __declare_first__(cls):
+ m = cls.__mapper__
+ if m.with_polymorphic:
+ return
+
+ discriminator_name = (
+ getattr(cls, "_concrete_discriminator_name", None) or "type"
+ )
+
+ mappers = list(m.self_and_descendants)
+ pjoin = cls._create_polymorphic_union(mappers, discriminator_name)
+ m._set_with_polymorphic(("*", pjoin))
+ m._set_polymorphic_on(pjoin.c[discriminator_name])
+
+
+class AbstractConcreteBase(ConcreteBase):
+ """A helper class for 'concrete' declarative mappings.
+
+ :class:`.AbstractConcreteBase` will use the :func:`.polymorphic_union`
+ function automatically, against all tables mapped as a subclass
+ to this class. The function is called via the
+    ``__declare_first__()`` function, which is essentially
+    a hook for the :meth:`.before_configured` event.
+
+ :class:`.AbstractConcreteBase` does produce a mapped class
+    for the base class, however it is not persisted to any table; it
+    is instead mapped directly to the "polymorphic" selectable
+    and is only used for selecting. Compare to :class:`.ConcreteBase`,
+ which does create a persisted table for the base class.
+
+ .. note::
+
+ The :class:`.AbstractConcreteBase` class does not intend to set up the
+ mapping for the base class until all the subclasses have been defined,
+ as it needs to create a mapping against a selectable that will include
+ all subclass tables. In order to achieve this, it waits for the
+ **mapper configuration event** to occur, at which point it scans
+ through all the configured subclasses and sets up a mapping that will
+ query against all subclasses at once.
+
+ While this event is normally invoked automatically, in the case of
+ :class:`.AbstractConcreteBase`, it may be necessary to invoke it
+ explicitly after **all** subclass mappings are defined, if the first
+ operation is to be a query against this base class. To do so, invoke
+ :func:`.configure_mappers` once all the desired classes have been
+ configured::
+
+ from sqlalchemy.orm import configure_mappers
+
+ configure_mappers()
+
+ .. seealso::
+
+ :func:`_orm.configure_mappers`
+
+
+ Example::
+
+ from sqlalchemy.ext.declarative import AbstractConcreteBase
+
+ class Employee(AbstractConcreteBase, Base):
+ pass
+
+ class Manager(Employee):
+ __tablename__ = 'manager'
+ employee_id = Column(Integer, primary_key=True)
+ name = Column(String(50))
+ manager_data = Column(String(40))
+
+ __mapper_args__ = {
+ 'polymorphic_identity':'manager',
+ 'concrete':True}
+
+ configure_mappers()
+
+ The abstract base class is handled by declarative in a special way;
+ at class configuration time, it behaves like a declarative mixin
+ or an ``__abstract__`` base class. Once classes are configured
+ and mappings are produced, it then gets mapped itself, but
+    after all of its descendants. This is a unique system of mapping
+    not found elsewhere in SQLAlchemy.
+
+ Using this approach, we can specify columns and properties
+ that will take place on mapped subclasses, in the way that
+ we normally do as in :ref:`declarative_mixins`::
+
+ class Company(Base):
+ __tablename__ = 'company'
+ id = Column(Integer, primary_key=True)
+
+ class Employee(AbstractConcreteBase, Base):
+ employee_id = Column(Integer, primary_key=True)
+
+ @declared_attr
+ def company_id(cls):
+ return Column(ForeignKey('company.id'))
+
+ @declared_attr
+ def company(cls):
+ return relationship("Company")
+
+ class Manager(Employee):
+ __tablename__ = 'manager'
+
+ name = Column(String(50))
+ manager_data = Column(String(40))
+
+ __mapper_args__ = {
+ 'polymorphic_identity':'manager',
+ 'concrete':True}
+
+ configure_mappers()
+
+ When we make use of our mappings however, both ``Manager`` and
+ ``Employee`` will have an independently usable ``.company`` attribute::
+
+ session.query(Employee).filter(Employee.company.has(id=5))
+
+ .. versionchanged:: 1.0.0 - The mechanics of :class:`.AbstractConcreteBase`
+ have been reworked to support relationships established directly
+ on the abstract base, without any special configurational steps.
+
+ .. seealso::
+
+ :class:`.ConcreteBase`
+
+ :ref:`concrete_inheritance`
+
+ """
+
+ __no_table__ = True
+
+ @classmethod
+ def __declare_first__(cls):
+ cls._sa_decl_prepare_nocascade()
+
+ @classmethod
+ def _sa_decl_prepare_nocascade(cls):
+ if getattr(cls, "__mapper__", None):
+ return
+
+ to_map = _DeferredMapperConfig.config_for_cls(cls)
+
+ # can't rely on 'self_and_descendants' here
+ # since technically an immediate subclass
+ # might not be mapped, but a subclass
+ # may be.
+ mappers = []
+ stack = list(cls.__subclasses__())
+ while stack:
+ klass = stack.pop()
+ stack.extend(klass.__subclasses__())
+ mn = _mapper_or_none(klass)
+ if mn is not None:
+ mappers.append(mn)
+
+ discriminator_name = (
+ getattr(cls, "_concrete_discriminator_name", None) or "type"
+ )
+ pjoin = cls._create_polymorphic_union(mappers, discriminator_name)
+
+ # For columns that were declared on the class, these
+ # are normally ignored with the "__no_table__" mapping,
+ # unless they have a different attribute key vs. col name
+ # and are in the properties argument.
+ # In that case, ensure we update the properties entry
+ # to the correct column from the pjoin target table.
+ declared_cols = set(to_map.declared_columns)
+ for k, v in list(to_map.properties.items()):
+ if v in declared_cols:
+ to_map.properties[k] = pjoin.c[v.key]
+
+ to_map.local_table = pjoin
+
+ m_args = to_map.mapper_args_fn or dict
+
+ def mapper_args():
+ args = m_args()
+ args["polymorphic_on"] = pjoin.c[discriminator_name]
+ return args
+
+ to_map.mapper_args_fn = mapper_args
+
+ m = to_map.map()
+
+ for scls in cls.__subclasses__():
+ sm = _mapper_or_none(scls)
+ if sm and sm.concrete and cls in scls.__bases__:
+ sm._set_concrete_base(m)
+
+ @classmethod
+ def _sa_raise_deferred_config(cls):
+ raise orm_exc.UnmappedClassError(
+ cls,
+ msg="Class %s is a subclass of AbstractConcreteBase and "
+ "has a mapping pending until all subclasses are defined. "
+ "Call the sqlalchemy.orm.configure_mappers() function after "
+ "all subclasses have been defined to "
+ "complete the mapping of this class."
+ % orm_exc._safe_cls_name(cls),
+ )
+
+
+class DeferredReflection(object):
+ """A helper class for construction of mappings based on
+ a deferred reflection step.
+
+ Normally, declarative can be used with reflection by
+ setting a :class:`_schema.Table` object using autoload_with=engine
+ as the ``__table__`` attribute on a declarative class.
+ The caveat is that the :class:`_schema.Table` must be fully
+ reflected, or at the very least have a primary key column,
+ at the point at which a normal declarative mapping is
+ constructed, meaning the :class:`_engine.Engine` must be available
+ at class declaration time.
+
+ The :class:`.DeferredReflection` mixin moves the construction
+ of mappers to be at a later point, after a specific
+ method is called which first reflects all :class:`_schema.Table`
+ objects created so far. Classes can define it as such::
+
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.declarative import DeferredReflection
+ Base = declarative_base()
+
+ class MyClass(DeferredReflection, Base):
+ __tablename__ = 'mytable'
+
+ Above, ``MyClass`` is not yet mapped. After a series of
+ classes have been defined in the above fashion, all tables
+ can be reflected and mappings created using
+ :meth:`.prepare`::
+
+ engine = create_engine("someengine://...")
+ DeferredReflection.prepare(engine)
+
+ The :class:`.DeferredReflection` mixin can be applied to individual
+ classes, used as the base for the declarative base itself,
+    or used in a custom abstract class. Using an abstract base
+    allows only a subset of classes to be prepared for a
+    particular prepare step, which is necessary for applications
+ that use more than one engine. For example, if an application
+ has two engines, you might use two bases, and prepare each
+ separately, e.g.::
+
+ class ReflectedOne(DeferredReflection, Base):
+ __abstract__ = True
+
+ class ReflectedTwo(DeferredReflection, Base):
+ __abstract__ = True
+
+ class MyClass(ReflectedOne):
+ __tablename__ = 'mytable'
+
+ class MyOtherClass(ReflectedOne):
+ __tablename__ = 'myothertable'
+
+ class YetAnotherClass(ReflectedTwo):
+ __tablename__ = 'yetanothertable'
+
+ # ... etc.
+
+ Above, the class hierarchies for ``ReflectedOne`` and
+ ``ReflectedTwo`` can be configured separately::
+
+ ReflectedOne.prepare(engine_one)
+ ReflectedTwo.prepare(engine_two)
+
+ .. seealso::
+
+ :ref:`orm_declarative_reflected_deferred_reflection` - in the
+ :ref:`orm_declarative_table_config_toplevel` section.
+
+ """
+
+ @classmethod
+ def prepare(cls, engine):
+ """Reflect all :class:`_schema.Table` objects for all current
+ :class:`.DeferredReflection` subclasses"""
+
+ to_map = _DeferredMapperConfig.classes_for_base(cls)
+
+ with inspection.inspect(engine)._inspection_context() as insp:
+ for thingy in to_map:
+ cls._sa_decl_prepare(thingy.local_table, insp)
+ thingy.map()
+ mapper = thingy.cls.__mapper__
+ metadata = mapper.class_.metadata
+ for rel in mapper._props.values():
+ if (
+ isinstance(rel, relationships.RelationshipProperty)
+ and rel.secondary is not None
+ ):
+ if isinstance(rel.secondary, Table):
+ cls._reflect_table(rel.secondary, insp)
+ elif isinstance(rel.secondary, str):
+
+ _, resolve_arg = _resolver(rel.parent.class_, rel)
+
+ rel.secondary = resolve_arg(rel.secondary)
+ rel.secondary._resolvers += (
+ cls._sa_deferred_table_resolver(
+ insp, metadata
+ ),
+ )
+
+ # controversy! do we resolve it here? or leave
+ # it deferred? I think doing it here is necessary
+ # so the connection does not leak.
+ rel.secondary = rel.secondary()
+
+ @classmethod
+ def _sa_deferred_table_resolver(cls, inspector, metadata):
+ def _resolve(key):
+ t1 = Table(key, metadata)
+ cls._reflect_table(t1, inspector)
+ return t1
+
+ return _resolve
+
+ @classmethod
+ def _sa_decl_prepare(cls, local_table, inspector):
+ # autoload Table, which is already
+ # present in the metadata. This
+ # will fill in db-loaded columns
+ # into the existing Table object.
+ if local_table is not None:
+ cls._reflect_table(local_table, inspector)
+
+ @classmethod
+ def _sa_raise_deferred_config(cls):
+ raise orm_exc.UnmappedClassError(
+ cls,
+ msg="Class %s is a subclass of DeferredReflection. "
+ "Mappings are not produced until the .prepare() "
+ "method is called on the class hierarchy."
+ % orm_exc._safe_cls_name(cls),
+ )
+
+ @classmethod
+ def _reflect_table(cls, table, inspector):
+ Table(
+ table.name,
+ table.metadata,
+ extend_existing=True,
+ autoload_replace=False,
+ autoload_with=inspector,
+ schema=table.schema,
+ )
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
new file mode 100644
index 0000000..bad076e
--- /dev/null
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -0,0 +1,256 @@
+# ext/horizontal_shard.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Horizontal sharding support.
+
+Defines a rudimentary 'horizontal sharding' system which allows a Session to
+distribute queries and persistence operations across multiple databases.
+
+For a usage example, see the :ref:`examples_sharding` example included in
+the source distribution.
+
+"""
+
+from .. import event
+from .. import exc
+from .. import inspect
+from .. import util
+from ..orm.query import Query
+from ..orm.session import Session
+
+__all__ = ["ShardedSession", "ShardedQuery"]
+
+
+class ShardedQuery(Query):
+ def __init__(self, *args, **kwargs):
+ super(ShardedQuery, self).__init__(*args, **kwargs)
+ self.id_chooser = self.session.id_chooser
+ self.query_chooser = self.session.query_chooser
+ self.execute_chooser = self.session.execute_chooser
+ self._shard_id = None
+
+ def set_shard(self, shard_id):
+ """Return a new query, limited to a single shard ID.
+
+ All subsequent operations with the returned query will
+ be against the single shard regardless of other state.
+
+ The shard_id can be passed for a 2.0 style execution to the
+ bind_arguments dictionary of :meth:`.Session.execute`::
+
+ results = session.execute(
+ stmt,
+ bind_arguments={"shard_id": "my_shard"}
+ )
+
+ """
+ return self.execution_options(_sa_shard_id=shard_id)
+
+
+class ShardedSession(Session):
+ def __init__(
+ self,
+ shard_chooser,
+ id_chooser,
+ execute_chooser=None,
+ shards=None,
+ query_cls=ShardedQuery,
+ **kwargs
+ ):
+ """Construct a ShardedSession.
+
+ :param shard_chooser: A callable which, passed a Mapper, a mapped
+ instance, and possibly a SQL clause, returns a shard ID. This id
+ may be based off of the attributes present within the object, or on
+ some round-robin scheme. If the scheme is based on a selection, it
+ should set whatever state on the instance to mark it in the future as
+ participating in that shard.
+
+ :param id_chooser: A callable, passed a query and a tuple of identity
+ values, which should return a list of shard ids where the ID might
+ reside. The databases will be queried in the order of this listing.
+
+ :param execute_chooser: For a given :class:`.ORMExecuteState`,
+ returns the list of shard_ids
+ where the query should be issued. Results from all shards returned
+ will be combined together into a single listing.
+
+ .. versionchanged:: 1.4 The ``execute_chooser`` parameter
+ supersedes the ``query_chooser`` parameter.
+
+ :param shards: A dictionary of string shard names
+ to :class:`~sqlalchemy.engine.Engine` objects.
+
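+        A minimal sketch; the shard names, the engine URLs and the
+        hypothetical ``region`` attribute below are illustrative only::
+
+            from sqlalchemy import create_engine
+
+            shards = {
+                "north": create_engine("sqlite://"),
+                "south": create_engine("sqlite://"),
+            }
+
+            def shard_chooser(mapper, instance, clause=None):
+                # route new objects by a hypothetical "region" attribute
+                return "north" if instance is None else instance.region
+
+            def id_chooser(query, ident):
+                # a given primary key could live in any shard
+                return ["north", "south"]
+
+            def execute_chooser(orm_context):
+                # run SELECT statements against all shards
+                return ["north", "south"]
+
+            session = ShardedSession(
+                shard_chooser=shard_chooser,
+                id_chooser=id_chooser,
+                execute_chooser=execute_chooser,
+                shards=shards,
+            )
+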
+ """
+ query_chooser = kwargs.pop("query_chooser", None)
+ super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs)
+
+ event.listen(
+ self, "do_orm_execute", execute_and_instances, retval=True
+ )
+ self.shard_chooser = shard_chooser
+ self.id_chooser = id_chooser
+
+ if query_chooser:
+ util.warn_deprecated(
+ "The ``query_choser`` parameter is deprecated; "
+ "please use ``execute_chooser``.",
+ "1.4",
+ )
+ if execute_chooser:
+ raise exc.ArgumentError(
+ "Can't pass query_chooser and execute_chooser "
+ "at the same time."
+ )
+
+ def execute_chooser(orm_context):
+ return query_chooser(orm_context.statement)
+
+ self.execute_chooser = execute_chooser
+ else:
+ self.execute_chooser = execute_chooser
+ self.query_chooser = query_chooser
+ self.__binds = {}
+ if shards is not None:
+ for k in shards:
+ self.bind_shard(k, shards[k])
+
+ def _identity_lookup(
+ self,
+ mapper,
+ primary_key_identity,
+ identity_token=None,
+ lazy_loaded_from=None,
+ **kw
+ ):
+ """override the default :meth:`.Session._identity_lookup` method so
+ that we search for a given non-token primary key identity across all
+ possible identity tokens (e.g. shard ids).
+
+ .. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from
+ the :class:`_query.Query` object to the :class:`.Session`.
+
+ """
+
+ if identity_token is not None:
+ return super(ShardedSession, self)._identity_lookup(
+ mapper,
+ primary_key_identity,
+ identity_token=identity_token,
+ **kw
+ )
+ else:
+ q = self.query(mapper)
+ if lazy_loaded_from:
+ q = q._set_lazyload_from(lazy_loaded_from)
+ for shard_id in self.id_chooser(q, primary_key_identity):
+ obj = super(ShardedSession, self)._identity_lookup(
+ mapper,
+ primary_key_identity,
+ identity_token=shard_id,
+ lazy_loaded_from=lazy_loaded_from,
+ **kw
+ )
+ if obj is not None:
+ return obj
+
+ return None
+
+ def _choose_shard_and_assign(self, mapper, instance, **kw):
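+        # if the instance already carries an identity token (a shard id),
+        # reuse it; otherwise ask shard_chooser() and stamp the chosen id
+        # onto the instance's identity token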
+ if instance is not None:
+ state = inspect(instance)
+ if state.key:
+ token = state.key[2]
+ assert token is not None
+ return token
+ elif state.identity_token:
+ return state.identity_token
+
+ shard_id = self.shard_chooser(mapper, instance, **kw)
+ if instance is not None:
+ state.identity_token = shard_id
+ return shard_id
+
+ def connection_callable(
+ self, mapper=None, instance=None, shard_id=None, **kwargs
+ ):
+ """Provide a :class:`_engine.Connection` to use in the unit of work
+ flush process.
+
+ """
+
+ if shard_id is None:
+ shard_id = self._choose_shard_and_assign(mapper, instance)
+
+ if self.in_transaction():
+ return self.get_transaction().connection(mapper, shard_id=shard_id)
+ else:
+ return self.get_bind(
+ mapper, shard_id=shard_id, instance=instance
+ ).connect(**kwargs)
+
+ def get_bind(
+ self, mapper=None, shard_id=None, instance=None, clause=None, **kw
+ ):
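+        # resolve a shard id if one was not passed, then return the bind
+        # (typically an Engine) registered for that shard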
+ if shard_id is None:
+ shard_id = self._choose_shard_and_assign(
+ mapper, instance, clause=clause
+ )
+ return self.__binds[shard_id]
+
+ def bind_shard(self, shard_id, bind):
+ self.__binds[shard_id] = bind
+
+
+def execute_and_instances(orm_context):
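+    # "do_orm_execute" event handler: determine a target shard from the
+    # ORM options, execution options or bind arguments; if none is present,
+    # run the statement against every shard from execute_chooser() and
+    # merge the results.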
+ if orm_context.is_select:
+ load_options = active_options = orm_context.load_options
+ update_options = None
+
+ elif orm_context.is_update or orm_context.is_delete:
+ load_options = None
+ update_options = active_options = orm_context.update_delete_options
+ else:
+ load_options = update_options = active_options = None
+
+ session = orm_context.session
+
+ def iter_for_shard(shard_id, load_options, update_options):
+ execution_options = dict(orm_context.local_execution_options)
+
+ bind_arguments = dict(orm_context.bind_arguments)
+ bind_arguments["shard_id"] = shard_id
+
+ if orm_context.is_select:
+ load_options += {"_refresh_identity_token": shard_id}
+ execution_options["_sa_orm_load_options"] = load_options
+ elif orm_context.is_update or orm_context.is_delete:
+ update_options += {"_refresh_identity_token": shard_id}
+ execution_options["_sa_orm_update_options"] = update_options
+
+ return orm_context.invoke_statement(
+ bind_arguments=bind_arguments, execution_options=execution_options
+ )
+
+ if active_options and active_options._refresh_identity_token is not None:
+ shard_id = active_options._refresh_identity_token
+ elif "_sa_shard_id" in orm_context.execution_options:
+ shard_id = orm_context.execution_options["_sa_shard_id"]
+ elif "shard_id" in orm_context.bind_arguments:
+ shard_id = orm_context.bind_arguments["shard_id"]
+ else:
+ shard_id = None
+
+ if shard_id is not None:
+ return iter_for_shard(shard_id, load_options, update_options)
+ else:
+ partial = []
+ for shard_id in session.execute_chooser(orm_context):
+ result_ = iter_for_shard(shard_id, load_options, update_options)
+ partial.append(result_)
+
+ return partial[0].merge(*partial[1:])
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py
new file mode 100644
index 0000000..cc0aca6
--- /dev/null
+++ b/lib/sqlalchemy/ext/hybrid.py
@@ -0,0 +1,1206 @@
+# ext/hybrid.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""Define attributes on ORM-mapped classes that have "hybrid" behavior.
+
+"hybrid" means the attribute has distinct behaviors defined at the
+class level and at the instance level.
+
+The :mod:`~sqlalchemy.ext.hybrid` extension provides a special form of
+method decorator, is around 50 lines of code, and has almost no
+dependencies on the rest of SQLAlchemy. It can, in theory, work with
+any descriptor-based expression system.
+
+Consider a mapping ``Interval``, representing integer ``start`` and ``end``
+values. We can define higher level functions on mapped classes that produce SQL
+expressions at the class level, and Python expression evaluation at the
+instance level. Below, each function decorated with :class:`.hybrid_method` or
+:class:`.hybrid_property` may receive ``self`` as an instance of the class, or
+as the class itself::
+
+ from sqlalchemy import Column, Integer
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.orm import Session, aliased
+ from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method
+
+ Base = declarative_base()
+
+ class Interval(Base):
+ __tablename__ = 'interval'
+
+ id = Column(Integer, primary_key=True)
+ start = Column(Integer, nullable=False)
+ end = Column(Integer, nullable=False)
+
+ def __init__(self, start, end):
+ self.start = start
+ self.end = end
+
+ @hybrid_property
+ def length(self):
+ return self.end - self.start
+
+ @hybrid_method
+ def contains(self, point):
+ return (self.start <= point) & (point <= self.end)
+
+ @hybrid_method
+ def intersects(self, other):
+ return self.contains(other.start) | self.contains(other.end)
+
+Above, the ``length`` property returns the difference between the
+``end`` and ``start`` attributes. With an instance of ``Interval``,
+this subtraction occurs in Python, using normal Python descriptor
+mechanics::
+
+ >>> i1 = Interval(5, 10)
+ >>> i1.length
+ 5
+
+When dealing with the ``Interval`` class itself, the :class:`.hybrid_property`
+descriptor evaluates the function body given the ``Interval`` class as
+the argument, which when evaluated with SQLAlchemy expression mechanics
+(here using the :attr:`.QueryableAttribute.expression` accessor)
+returns a new SQL expression::
+
+ >>> print(Interval.length.expression)
+ interval."end" - interval.start
+
+ >>> print(Session().query(Interval).filter(Interval.length > 10))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end
+ FROM interval
+ WHERE interval."end" - interval.start > :param_1
+
+ORM methods such as :meth:`_query.Query.filter_by`
+generally use ``getattr()`` to
+locate attributes, so can also be used with hybrid attributes::
+
+ >>> print(Session().query(Interval).filter_by(length=5))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end
+ FROM interval
+ WHERE interval."end" - interval.start = :param_1
+
+The ``Interval`` class example also illustrates two methods,
+``contains()`` and ``intersects()``, decorated with
+:class:`.hybrid_method`. This decorator applies the same idea to
+methods that :class:`.hybrid_property` applies to attributes. The
+methods return boolean values, and take advantage of the Python ``|``
+and ``&`` bitwise operators to produce equivalent instance-level and
+SQL expression-level boolean behavior::
+
+ >>> i1.contains(6)
+ True
+ >>> i1.contains(15)
+ False
+ >>> i1.intersects(Interval(7, 18))
+ True
+ >>> i1.intersects(Interval(25, 29))
+ False
+
+ >>> print(Session().query(Interval).filter(Interval.contains(15)))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end
+ FROM interval
+ WHERE interval.start <= :start_1 AND interval."end" > :end_1
+
+ >>> ia = aliased(Interval)
+ >>> print(Session().query(Interval, ia).filter(Interval.intersects(ia)))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end, interval_1.id AS interval_1_id,
+ interval_1.start AS interval_1_start, interval_1."end" AS interval_1_end
+ FROM interval, interval AS interval_1
+ WHERE interval.start <= interval_1.start
+ AND interval."end" > interval_1.start
+ OR interval.start <= interval_1."end"
+ AND interval."end" > interval_1."end"
+
+.. _hybrid_distinct_expression:
+
+Defining Expression Behavior Distinct from Attribute Behavior
+--------------------------------------------------------------
+
+Our usage of the ``&`` and ``|`` bitwise operators above was
+fortunate, considering our functions operated on two boolean values to
+return a new one. In many cases, the construction of an in-Python
+function and a SQLAlchemy SQL expression have enough differences that
+two separate Python expressions should be defined. The
+:mod:`~sqlalchemy.ext.hybrid` decorators define the
+:meth:`.hybrid_property.expression` modifier for this purpose. As an
+example we'll define the radius of the interval, which requires the
+usage of the absolute value function::
+
+ from sqlalchemy import func
+
+ class Interval(object):
+ # ...
+
+ @hybrid_property
+ def radius(self):
+ return abs(self.length) / 2
+
+ @radius.expression
+ def radius(cls):
+ return func.abs(cls.length) / 2
+
+Above, the Python function ``abs()`` is used for instance-level
+operations, while the SQL function ``ABS()`` is used via the :data:`.func`
+object for class-level expressions::
+
+ >>> i1.radius
+ 2
+
+ >>> print(Session().query(Interval).filter(Interval.radius > 5))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end
+ FROM interval
+ WHERE abs(interval."end" - interval.start) / :abs_1 > :param_1
+
+.. note:: When defining an expression for a hybrid property or method, the
+ expression method **must** retain the name of the original hybrid, else
+ the new hybrid with the additional state will be attached to the class
+ with the non-matching name. To use the example above::
+
+ class Interval(object):
+ # ...
+
+ @hybrid_property
+ def radius(self):
+ return abs(self.length) / 2
+
+ # WRONG - the non-matching name will cause this function to be
+ # ignored
+ @radius.expression
+ def radius_expression(cls):
+ return func.abs(cls.length) / 2
+
+ This is also true for other mutator methods, such as
+ :meth:`.hybrid_property.update_expression`. This is the same behavior
+ as that of the ``@property`` construct that is part of standard Python.
+
+Defining Setters
+----------------
+
+Hybrid properties can also define setter methods. If we wanted
+``length`` above, when set, to modify the endpoint value::
+
+ class Interval(object):
+ # ...
+
+ @hybrid_property
+ def length(self):
+ return self.end - self.start
+
+ @length.setter
+ def length(self, value):
+ self.end = self.start + value
+
+The ``length(self, value)`` method is now called upon set::
+
+ >>> i1 = Interval(5, 10)
+ >>> i1.length
+ 5
+ >>> i1.length = 12
+ >>> i1.end
+ 17
+
+.. _hybrid_bulk_update:
+
+Allowing Bulk ORM Update
+------------------------
+
+A hybrid can define a custom "UPDATE" handler for when using the
+:meth:`_query.Query.update` method, allowing the hybrid to be used in the
+SET clause of the update.
+
+Normally, when using a hybrid with :meth:`_query.Query.update`, the SQL
+expression is used as the column that's the target of the SET. If our
+``Interval`` class had a hybrid ``start_point`` that linked to
+``Interval.start``, this could be substituted directly::
+
+ session.query(Interval).update({Interval.start_point: 10})
+
+However, when using a composite hybrid like ``Interval.length``, this
+hybrid represents more than one column. We can set up a handler that will
+accommodate a value passed to :meth:`_query.Query.update` which can affect
+this, using the :meth:`.hybrid_property.update_expression` decorator.
+A handler that works similarly to our setter would be::
+
+ class Interval(object):
+ # ...
+
+ @hybrid_property
+ def length(self):
+ return self.end - self.start
+
+ @length.setter
+ def length(self, value):
+ self.end = self.start + value
+
+ @length.update_expression
+ def length(cls, value):
+ return [
+ (cls.end, cls.start + value)
+ ]
+
+Above, if we use ``Interval.length`` in an UPDATE expression as::
+
+ session.query(Interval).update(
+ {Interval.length: 25}, synchronize_session='fetch')
+
+We'll get an UPDATE statement along the lines of::
+
+ UPDATE interval SET end=start + :value
+
+In some cases, the default "evaluate" strategy can't perform the SET
+expression in Python; while the addition operator we're using above
+is supported, for more complex SET expressions it will usually be necessary
+to use either the "fetch" or False synchronization strategy as illustrated
+above.
+
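+For example, a minimal sketch of the same bulk update issued with the
+in-Python synchronization step disabled entirely (objects already present
+in the :class:`.Session` are not updated in place)::
+
+    session.query(Interval).update(
+        {Interval.length: 25}, synchronize_session=False)
+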
+.. note:: For ORM bulk updates to work with hybrids, the function name
+   of the hybrid must match the attribute name under which it is
+   accessed. Something like this wouldn't work::
+
+ class Interval(object):
+ # ...
+
+ def _get(self):
+ return self.end - self.start
+
+ def _set(self, value):
+ self.end = self.start + value
+
+ def _update_expr(cls, value):
+ return [
+ (cls.end, cls.start + value)
+ ]
+
+ length = hybrid_property(
+ fget=_get, fset=_set, update_expr=_update_expr
+ )
+
+ The Python descriptor protocol does not provide any reliable way for
+ a descriptor to know what attribute name it was accessed as, and
+ the UPDATE scheme currently relies upon being able to access the
+ attribute from an instance by name in order to perform the instance
+ synchronization step.
+
+.. versionadded:: 1.2 added support for bulk updates to hybrid properties.
+
+Working with Relationships
+--------------------------
+
+There's no essential difference when creating hybrids that work with
+related objects as opposed to column-based data. The need for distinct
+expressions tends to be greater. The two variants we'll illustrate
+are the "join-dependent" hybrid, and the "correlated subquery" hybrid.
+
+Join-Dependent Relationship Hybrid
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Consider the following declarative
+mapping which relates a ``User`` to a ``SavingsAccount``::
+
+ from sqlalchemy import Column, Integer, ForeignKey, Numeric, String
+ from sqlalchemy.orm import relationship
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.hybrid import hybrid_property
+
+ Base = declarative_base()
+
+ class SavingsAccount(Base):
+ __tablename__ = 'account'
+ id = Column(Integer, primary_key=True)
+ user_id = Column(Integer, ForeignKey('user.id'), nullable=False)
+ balance = Column(Numeric(15, 5))
+
+ class User(Base):
+ __tablename__ = 'user'
+ id = Column(Integer, primary_key=True)
+ name = Column(String(100), nullable=False)
+
+ accounts = relationship("SavingsAccount", backref="owner")
+
+ @hybrid_property
+ def balance(self):
+ if self.accounts:
+ return self.accounts[0].balance
+ else:
+ return None
+
+ @balance.setter
+ def balance(self, value):
+ if not self.accounts:
+                account = SavingsAccount(owner=self)
+ else:
+ account = self.accounts[0]
+ account.balance = value
+
+ @balance.expression
+ def balance(cls):
+ return SavingsAccount.balance
+
+The above hybrid property ``balance`` works with the first
+``SavingsAccount`` entry in the list of accounts for this user. The
+in-Python getter/setter methods can treat ``accounts`` as a Python
+list available on ``self``.
+
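+As a rough instance-level sketch (the output shown is illustrative, and
+assumes the ``owner`` backref populates ``accounts`` in Python before any
+flush takes place)::
+
+    >>> user = User(name="some user")
+    >>> user.balance = 5000
+    >>> user.balance
+    5000
+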
+However, at the expression level, it's expected that the ``User`` class will
+be used in an appropriate context such that an appropriate join to
+``SavingsAccount`` will be present::
+
+ >>> print(Session().query(User, User.balance).
+ ... join(User.accounts).filter(User.balance > 5000))
+ SELECT "user".id AS user_id, "user".name AS user_name,
+ account.balance AS account_balance
+ FROM "user" JOIN account ON "user".id = account.user_id
+ WHERE account.balance > :balance_1
+
+Note however, that while the instance level accessors need to worry
+about whether ``self.accounts`` is even present, this issue expresses
+itself differently at the SQL expression level, where we basically
+would use an outer join::
+
+ >>> from sqlalchemy import or_
+    >>> print(Session().query(User, User.balance).outerjoin(User.accounts).
+ ... filter(or_(User.balance < 5000, User.balance == None)))
+ SELECT "user".id AS user_id, "user".name AS user_name,
+ account.balance AS account_balance
+ FROM "user" LEFT OUTER JOIN account ON "user".id = account.user_id
+ WHERE account.balance < :balance_1 OR account.balance IS NULL
+
+Correlated Subquery Relationship Hybrid
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+We can, of course, forego being dependent on the enclosing query's usage
+of joins in favor of the correlated subquery, which can portably be packed
+into a single column expression. A correlated subquery is more portable, but
+often performs more poorly at the SQL level. Using the same technique
+illustrated at :ref:`mapper_column_property_sql_expressions`,
+we can adjust our ``SavingsAccount`` example to aggregate the balances for
+*all* accounts, and use a correlated subquery for the column expression::
+
+ from sqlalchemy import Column, Integer, ForeignKey, Numeric, String
+ from sqlalchemy.orm import relationship
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.hybrid import hybrid_property
+ from sqlalchemy import select, func
+
+ Base = declarative_base()
+
+ class SavingsAccount(Base):
+ __tablename__ = 'account'
+ id = Column(Integer, primary_key=True)
+ user_id = Column(Integer, ForeignKey('user.id'), nullable=False)
+ balance = Column(Numeric(15, 5))
+
+ class User(Base):
+ __tablename__ = 'user'
+ id = Column(Integer, primary_key=True)
+ name = Column(String(100), nullable=False)
+
+ accounts = relationship("SavingsAccount", backref="owner")
+
+ @hybrid_property
+ def balance(self):
+ return sum(acc.balance for acc in self.accounts)
+
+ @balance.expression
+ def balance(cls):
+ return select(func.sum(SavingsAccount.balance)).\
+ where(SavingsAccount.user_id==cls.id).\
+ label('total_balance')
+
+The above recipe will give us the ``balance`` column which renders
+a correlated SELECT::
+
+    >>> print(Session().query(User).filter(User.balance > 400))
+ SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user"
+ WHERE (SELECT sum(account.balance) AS sum_1
+ FROM account
+ WHERE account.user_id = "user".id) > :param_1
+
+.. _hybrid_custom_comparators:
+
+Building Custom Comparators
+---------------------------
+
+The hybrid property also includes a helper that allows construction of
+custom comparators. A comparator object allows one to customize the
+behavior of each SQLAlchemy expression operator individually. They
+are useful when creating custom types that have some highly
+idiosyncratic behavior on the SQL side.
+
+.. note:: The :meth:`.hybrid_property.comparator` decorator introduced
+ in this section **replaces** the use of the
+ :meth:`.hybrid_property.expression` decorator.
+ They cannot be used together.
+
+The example class below allows case-insensitive comparisons on the attribute
+named ``word_insensitive``::
+
+ from sqlalchemy.ext.hybrid import Comparator, hybrid_property
+ from sqlalchemy import func, Column, Integer, String
+ from sqlalchemy.orm import Session
+ from sqlalchemy.ext.declarative import declarative_base
+
+ Base = declarative_base()
+
+ class CaseInsensitiveComparator(Comparator):
+ def __eq__(self, other):
+ return func.lower(self.__clause_element__()) == func.lower(other)
+
+ class SearchWord(Base):
+ __tablename__ = 'searchword'
+ id = Column(Integer, primary_key=True)
+ word = Column(String(255), nullable=False)
+
+ @hybrid_property
+ def word_insensitive(self):
+ return self.word.lower()
+
+ @word_insensitive.comparator
+ def word_insensitive(cls):
+ return CaseInsensitiveComparator(cls.word)
+
+Above, SQL expressions against ``word_insensitive`` will apply the ``LOWER()``
+SQL function to both sides::
+
+ >>> print(Session().query(SearchWord).filter_by(word_insensitive="Trucks"))
+ SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
+ FROM searchword
+ WHERE lower(searchword.word) = lower(:lower_1)
+
+The ``CaseInsensitiveComparator`` above implements part of the
+:class:`.ColumnOperators` interface. A "coercion" operation like
+lowercasing can be applied to all comparison operations (i.e. ``eq``,
+``lt``, ``gt``, etc.) using :meth:`.Operators.operate`::
+
+ class CaseInsensitiveComparator(Comparator):
+ def operate(self, op, other, **kwargs):
+ return op(
+ func.lower(self.__clause_element__()),
+ func.lower(other),
+ **kwargs,
+ )
+
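+With this version of the comparator in place, other comparison operators
+also receive the lowercasing treatment; a hypothetical inequality filter
+would render along these lines::
+
+    >>> print(Session().query(SearchWord).filter(
+    ...     SearchWord.word_insensitive > "Trucks"))
+    SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
+    FROM searchword
+    WHERE lower(searchword.word) > lower(:lower_1)
+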
+.. _hybrid_reuse_subclass:
+
+Reusing Hybrid Properties across Subclasses
+-------------------------------------------
+
+A hybrid can be referred to from a superclass, to allow modifying
+methods like :meth:`.hybrid_property.getter`, :meth:`.hybrid_property.setter`
+to be used to redefine those methods on a subclass. This is similar to
+how the standard Python ``@property`` object works::
+
+ class FirstNameOnly(Base):
+ # ...
+
+ first_name = Column(String)
+
+ @hybrid_property
+ def name(self):
+ return self.first_name
+
+ @name.setter
+ def name(self, value):
+ self.first_name = value
+
+ class FirstNameLastName(FirstNameOnly):
+ # ...
+
+ last_name = Column(String)
+
+ @FirstNameOnly.name.getter
+ def name(self):
+ return self.first_name + ' ' + self.last_name
+
+ @name.setter
+ def name(self, value):
+ self.first_name, self.last_name = value.split(' ', 1)
+
+Above, the ``FirstNameLastName`` class refers to the hybrid from
+``FirstNameOnly.name`` to repurpose its getter and setter for the subclass.
+
+When overriding :meth:`.hybrid_property.expression` and
+:meth:`.hybrid_property.comparator` alone as the first reference to the
+superclass, these names conflict with the same-named accessors on the
+class-level :class:`.QueryableAttribute` object returned at the class
+level. To override these methods when referring directly to the parent
+class descriptor, add the special qualifier
+:attr:`.hybrid_property.overrides`, which will de-reference the
+instrumented attribute back to the hybrid object::
+
+ class FirstNameLastName(FirstNameOnly):
+ # ...
+
+ last_name = Column(String)
+
+ @FirstNameOnly.name.overrides.expression
+ def name(cls):
+ return func.concat(cls.first_name, ' ', cls.last_name)
+
+.. versionadded:: 1.2 Added :meth:`.hybrid_property.getter` as well as the
+ ability to redefine accessors per-subclass.
+
+
+Hybrid Value Objects
+--------------------
+
+Note in our previous example, if we were to compare the ``word_insensitive``
+attribute of a ``SearchWord`` instance to a plain Python string, the plain
+Python string would not be coerced to lower case - the
+``CaseInsensitiveComparator`` we built, being returned by
+``@word_insensitive.comparator``, only applies to the SQL side.
+
+A more comprehensive form of the custom comparator is to construct a *Hybrid
+Value Object*. This technique applies the target value or expression to a value
+object which is then returned by the accessor in all cases. The value object
+allows control of all operations upon the value as well as how compared values
+are treated, both on the SQL expression side as well as the Python value side.
+Replacing the previous ``CaseInsensitiveComparator`` class with a new
+``CaseInsensitiveWord`` class::
+
+ class CaseInsensitiveWord(Comparator):
+ "Hybrid value representing a lower case representation of a word."
+
+ def __init__(self, word):
+            if isinstance(word, str):
+ self.word = word.lower()
+ elif isinstance(word, CaseInsensitiveWord):
+ self.word = word.word
+ else:
+ self.word = func.lower(word)
+
+ def operate(self, op, other, **kwargs):
+ if not isinstance(other, CaseInsensitiveWord):
+ other = CaseInsensitiveWord(other)
+ return op(self.word, other.word, **kwargs)
+
+ def __clause_element__(self):
+ return self.word
+
+ def __str__(self):
+ return self.word
+
+ key = 'word'
+ "Label to apply to Query tuple results"
+
+Above, the ``CaseInsensitiveWord`` object represents ``self.word``, which may
+be a SQL function, or may be a Python native. By overriding ``operate()`` and
+``__clause_element__()`` to work in terms of ``self.word``, all comparison
+operations will work against the "converted" form of ``word``, whether it be
+SQL side or Python side. Our ``SearchWord`` class can now deliver the
+``CaseInsensitiveWord`` object unconditionally from a single hybrid call::
+
+ class SearchWord(Base):
+ __tablename__ = 'searchword'
+ id = Column(Integer, primary_key=True)
+ word = Column(String(255), nullable=False)
+
+ @hybrid_property
+ def word_insensitive(self):
+ return CaseInsensitiveWord(self.word)
+
+The ``word_insensitive`` attribute now has case-insensitive comparison behavior
+universally, including SQL expression vs. Python expression (note the Python
+value is converted to lower case on the Python side here)::
+
+ >>> print(Session().query(SearchWord).filter_by(word_insensitive="Trucks"))
+ SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
+ FROM searchword
+ WHERE lower(searchword.word) = :lower_1
+
+SQL expression versus SQL expression::
+
+ >>> sw1 = aliased(SearchWord)
+ >>> sw2 = aliased(SearchWord)
+ >>> print(Session().query(
+ ... sw1.word_insensitive,
+ ... sw2.word_insensitive).\
+ ... filter(
+ ... sw1.word_insensitive > sw2.word_insensitive
+ ... ))
+ SELECT lower(searchword_1.word) AS lower_1,
+ lower(searchword_2.word) AS lower_2
+ FROM searchword AS searchword_1, searchword AS searchword_2
+ WHERE lower(searchword_1.word) > lower(searchword_2.word)
+
+Python only expression::
+
+ >>> ws1 = SearchWord(word="SomeWord")
+ >>> ws1.word_insensitive == "sOmEwOrD"
+ True
+ >>> ws1.word_insensitive == "XOmEwOrX"
+ False
+ >>> print(ws1.word_insensitive)
+ someword
+
+The Hybrid Value pattern is very useful for any kind of value that may have
+multiple representations, such as timestamps, time deltas, units of
+measurement, currencies and encrypted passwords.
+
+.. seealso::
+
+ `Hybrids and Value Agnostic Types
+ <https://techspot.zzzeek.org/2011/10/21/hybrids-and-value-agnostic-types/>`_
+ - on the techspot.zzzeek.org blog
+
+ `Value Agnostic Types, Part II
+ <https://techspot.zzzeek.org/2011/10/29/value-agnostic-types-part-ii/>`_ -
+ on the techspot.zzzeek.org blog
+
+.. _hybrid_transformers:
+
+Building Transformers
+----------------------
+
+A *transformer* is an object which can receive a :class:`_query.Query`
+object and
+return a new one. The :class:`_query.Query` object includes a method
+:meth:`.with_transformation` that returns a new :class:`_query.Query`
+transformed by
+the given function.
+
+We can combine this with the :class:`.Comparator` class to produce one type
+of recipe which can both set up the FROM clause of a query as well as assign
+filtering criterion.
+
+Consider a mapped class ``Node``, which assembles into a hierarchical tree
+pattern using an adjacency list::
+
+ from sqlalchemy import Column, Integer, ForeignKey
+ from sqlalchemy.orm import relationship
+ from sqlalchemy.ext.declarative import declarative_base
+ Base = declarative_base()
+
+ class Node(Base):
+ __tablename__ = 'node'
+ id = Column(Integer, primary_key=True)
+ parent_id = Column(Integer, ForeignKey('node.id'))
+ parent = relationship("Node", remote_side=id)
+
+Suppose we wanted to add an accessor ``grandparent``. This would return the
+``parent`` of ``Node.parent``. When we have an instance of ``Node``, this is
+simple::
+
+ from sqlalchemy.ext.hybrid import hybrid_property
+
+ class Node(Base):
+ # ...
+
+ @hybrid_property
+ def grandparent(self):
+ return self.parent.parent
+
+For the expression, things are not so clear. We'd need to construct a
+:class:`_query.Query` where we :meth:`_query.Query.join` twice along
+``Node.parent`` to get to the ``grandparent``. We can instead return a
+transforming callable that we'll combine with the :class:`.Comparator` class to
+receive any :class:`_query.Query` object, and return a new one that's joined to
+the ``Node.parent`` attribute and filtered based on the given criterion::
+
+ from sqlalchemy.ext.hybrid import Comparator
+
+ class GrandparentTransformer(Comparator):
+ def operate(self, op, other, **kwargs):
+ def transform(q):
+ cls = self.__clause_element__()
+ parent_alias = aliased(cls)
+ return q.join(parent_alias, cls.parent).filter(
+ op(parent_alias.parent, other, **kwargs)
+ )
+
+ return transform
+
+ Base = declarative_base()
+
+ class Node(Base):
+ __tablename__ = 'node'
+ id = Column(Integer, primary_key=True)
+ parent_id = Column(Integer, ForeignKey('node.id'))
+ parent = relationship("Node", remote_side=id)
+
+ @hybrid_property
+ def grandparent(self):
+ return self.parent.parent
+
+ @grandparent.comparator
+ def grandparent(cls):
+ return GrandparentTransformer(cls)
+
+The ``GrandparentTransformer`` overrides the core :meth:`.Operators.operate`
+method at the base of the :class:`.Comparator` hierarchy to return a
+query-transforming callable, which then runs the given comparison operation
+in a particular context. For example, in the case above, the ``operate``
+method is called, given the :attr:`.Operators.eq` callable as well as the
+right side of the comparison, ``Node(id=5)``. A function ``transform`` is
+then returned which will transform a :class:`_query.Query` first to join to
+``Node.parent``, then to compare ``parent_alias`` using :attr:`.Operators.eq`
+against the left and right sides, passing into :meth:`_query.Query.filter`:
+
+.. sourcecode:: pycon+sql
+
+ >>> from sqlalchemy.orm import Session
+ >>> session = Session()
+ {sql}>>> session.query(Node).\
+ ... with_transformation(Node.grandparent==Node(id=5)).\
+ ... all()
+ SELECT node.id AS node_id, node.parent_id AS node_parent_id
+ FROM node JOIN node AS node_1 ON node_1.id = node.parent_id
+ WHERE :param_1 = node_1.parent_id
+ {stop}
+
+We can modify the pattern to be more verbose but flexible by separating the
+"join" step from the "filter" step. The tricky part here is ensuring that
+successive instances of ``GrandparentTransformer`` use the same
+:class:`.AliasedClass` object against ``Node``. Below we use a simple
+memoizing approach that associates a ``GrandparentTransformer`` with each
+class::
+
+ class Node(Base):
+
+ # ...
+
+ @grandparent.comparator
+ def grandparent(cls):
+ # memoize a GrandparentTransformer
+ # per class
+ if '_gp' not in cls.__dict__:
+ cls._gp = GrandparentTransformer(cls)
+ return cls._gp
+
+ class GrandparentTransformer(Comparator):
+
+ def __init__(self, cls):
+ self.parent_alias = aliased(cls)
+
+ @property
+ def join(self):
+ def go(q):
+ return q.join(self.parent_alias, Node.parent)
+ return go
+
+ def operate(self, op, other, **kwargs):
+ return op(self.parent_alias.parent, other, **kwargs)
+
+.. sourcecode:: pycon+sql
+
+ {sql}>>> session.query(Node).\
+ ... with_transformation(Node.grandparent.join).\
+ ... filter(Node.grandparent==Node(id=5))
+ SELECT node.id AS node_id, node.parent_id AS node_parent_id
+ FROM node JOIN node AS node_1 ON node_1.id = node.parent_id
+ WHERE :param_1 = node_1.parent_id
+ {stop}
+
+The "transformer" pattern is an experimental pattern that starts to make usage
+of some functional programming paradigms. While it's only recommended for
+advanced and/or patient developers, there's probably a whole lot of amazing
+things it can be used for.
+
+""" # noqa
+from .. import util
+from ..orm import attributes
+from ..orm import interfaces
+
+HYBRID_METHOD = util.symbol("HYBRID_METHOD")
+"""Symbol indicating an :class:`InspectionAttr` that's
+ of type :class:`.hybrid_method`.
+
+ Is assigned to the :attr:`.InspectionAttr.extension_type`
+ attribute.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.all_orm_attributes`
+
+"""
+
+HYBRID_PROPERTY = util.symbol("HYBRID_PROPERTY")
+"""Symbol indicating an :class:`InspectionAttr` that's
+   of type :class:`.hybrid_property`.
+
+ Is assigned to the :attr:`.InspectionAttr.extension_type`
+ attribute.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.all_orm_attributes`
+
+"""
+
+
+class hybrid_method(interfaces.InspectionAttrInfo):
+ """A decorator which allows definition of a Python object method with both
+ instance-level and class-level behavior.
+
+ """
+
+ is_attribute = True
+ extension_type = HYBRID_METHOD
+
+ def __init__(self, func, expr=None):
+ """Create a new :class:`.hybrid_method`.
+
+ Usage is typically via decorator::
+
+ from sqlalchemy.ext.hybrid import hybrid_method
+
+ class SomeClass(object):
+ @hybrid_method
+ def value(self, x, y):
+ return self._value + x + y
+
+ @value.expression
+                def value(cls, x, y):
+                    return func.some_function(cls._value, x, y)
+
+ """
+ self.func = func
+ self.expression(expr or func)
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self.expr.__get__(owner, owner.__class__)
+ else:
+ return self.func.__get__(instance, owner)
+
+ def expression(self, expr):
+ """Provide a modifying decorator that defines a
+ SQL-expression producing method."""
+
+ self.expr = expr
+ if not self.expr.__doc__:
+ self.expr.__doc__ = self.func.__doc__
+ return self
+
+
+class hybrid_property(interfaces.InspectionAttrInfo):
+ """A decorator which allows definition of a Python descriptor with both
+ instance-level and class-level behavior.
+
+ """
+
+ is_attribute = True
+ extension_type = HYBRID_PROPERTY
+
+ def __init__(
+ self,
+ fget,
+ fset=None,
+ fdel=None,
+ expr=None,
+ custom_comparator=None,
+ update_expr=None,
+ ):
+ """Create a new :class:`.hybrid_property`.
+
+ Usage is typically via decorator::
+
+ from sqlalchemy.ext.hybrid import hybrid_property
+
+ class SomeClass(object):
+ @hybrid_property
+ def value(self):
+ return self._value
+
+ @value.setter
+ def value(self, value):
+ self._value = value
+
+ """
+ self.fget = fget
+ self.fset = fset
+ self.fdel = fdel
+ self.expr = expr
+ self.custom_comparator = custom_comparator
+ self.update_expr = update_expr
+ util.update_wrapper(self, fget)
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self._expr_comparator(owner)
+ else:
+ return self.fget(instance)
+
+ def __set__(self, instance, value):
+ if self.fset is None:
+ raise AttributeError("can't set attribute")
+ self.fset(instance, value)
+
+ def __delete__(self, instance):
+ if self.fdel is None:
+ raise AttributeError("can't delete attribute")
+ self.fdel(instance)
+
+ def _copy(self, **kw):
+ defaults = {
+ key: value
+ for key, value in self.__dict__.items()
+ if not key.startswith("_")
+ }
+ defaults.update(**kw)
+ return type(self)(**defaults)
+
+ @property
+ def overrides(self):
+ """Prefix for a method that is overriding an existing attribute.
+
+ The :attr:`.hybrid_property.overrides` accessor just returns
+ this hybrid object, which when called at the class level from
+ a parent class, will de-reference the "instrumented attribute"
+ normally returned at this level, and allow modifying decorators
+ like :meth:`.hybrid_property.expression` and
+ :meth:`.hybrid_property.comparator`
+ to be used without conflicting with the same-named attributes
+ normally present on the :class:`.QueryableAttribute`::
+
+ class SuperClass(object):
+ # ...
+
+ @hybrid_property
+ def foobar(self):
+ return self._foobar
+
+ class SubClass(SuperClass):
+ # ...
+
+ @SuperClass.foobar.overrides.expression
+ def foobar(cls):
+                    return func.subfoobar(cls._foobar)
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`hybrid_reuse_subclass`
+
+ """
+ return self
+
+ def getter(self, fget):
+ """Provide a modifying decorator that defines a getter method.
+
+ .. versionadded:: 1.2
+
+ """
+
+ return self._copy(fget=fget)
+
+ def setter(self, fset):
+ """Provide a modifying decorator that defines a setter method."""
+
+ return self._copy(fset=fset)
+
+ def deleter(self, fdel):
+ """Provide a modifying decorator that defines a deletion method."""
+
+ return self._copy(fdel=fdel)
+
+ def expression(self, expr):
+ """Provide a modifying decorator that defines a SQL-expression
+ producing method.
+
+ When a hybrid is invoked at the class level, the SQL expression given
+ here is wrapped inside of a specialized :class:`.QueryableAttribute`,
+ which is the same kind of object used by the ORM to represent other
+ mapped attributes. The reason for this is so that other class-level
+ attributes such as docstrings and a reference to the hybrid itself may
+ be maintained within the structure that's returned, without any
+ modifications to the original SQL expression passed in.
+
+ .. note::
+
+ When referring to a hybrid property from an owning class (e.g.
+ ``SomeClass.some_hybrid``), an instance of
+ :class:`.QueryableAttribute` is returned, representing the
+ expression or comparator object as well as this hybrid object.
+ However, that object itself has accessors called ``expression`` and
+ ``comparator``; so when attempting to override these decorators on a
+ subclass, it may be necessary to qualify it using the
+ :attr:`.hybrid_property.overrides` modifier first. See that
+ modifier for details.
+
+ .. seealso::
+
+ :ref:`hybrid_distinct_expression`
+
+ """
+
+ return self._copy(expr=expr)
+
+ def comparator(self, comparator):
+ """Provide a modifying decorator that defines a custom
+ comparator producing method.
+
+ The return value of the decorated method should be an instance of
+ :class:`~.hybrid.Comparator`.
+
+ .. note:: The :meth:`.hybrid_property.comparator` decorator
+ **replaces** the use of the :meth:`.hybrid_property.expression`
+ decorator. They cannot be used together.
+
+ When a hybrid is invoked at the class level, the
+ :class:`~.hybrid.Comparator` object given here is wrapped inside of a
+ specialized :class:`.QueryableAttribute`, which is the same kind of
+ object used by the ORM to represent other mapped attributes. The
+ reason for this is so that other class-level attributes such as
+ docstrings and a reference to the hybrid itself may be maintained
+ within the structure that's returned, without any modifications to the
+ original comparator object passed in.
+
+ .. note::
+
+ When referring to a hybrid property from an owning class (e.g.
+ ``SomeClass.some_hybrid``), an instance of
+ :class:`.QueryableAttribute` is returned, representing the
+            expression or comparator object as well as this hybrid object.
+            However, that object itself has accessors called ``expression`` and
+ ``comparator``; so when attempting to override these decorators on a
+ subclass, it may be necessary to qualify it using the
+ :attr:`.hybrid_property.overrides` modifier first. See that
+ modifier for details.
+
+ """
+ return self._copy(custom_comparator=comparator)
+
+ def update_expression(self, meth):
+ """Provide a modifying decorator that defines an UPDATE tuple
+ producing method.
+
+ The method accepts a single value, which is the value to be
+ rendered into the SET clause of an UPDATE statement. The method
+ should then process this value into individual column expressions
+ that fit into the ultimate SET clause, and return them as a
+ sequence of 2-tuples. Each tuple
+ contains a column expression as the key and a value to be rendered.
+
+ E.g.::
+
+ class Person(Base):
+ # ...
+
+ first_name = Column(String)
+ last_name = Column(String)
+
+ @hybrid_property
+ def fullname(self):
+                    return self.first_name + " " + self.last_name
+
+ @fullname.update_expression
+ def fullname(cls, value):
+ fname, lname = value.split(" ", 1)
+ return [
+ (cls.first_name, fname),
+ (cls.last_name, lname)
+ ]
+
+ .. versionadded:: 1.2
+
+ """
+ return self._copy(update_expr=meth)
+
+ @util.memoized_property
+ def _expr_comparator(self):
+ if self.custom_comparator is not None:
+ return self._get_comparator(self.custom_comparator)
+ elif self.expr is not None:
+ return self._get_expr(self.expr)
+ else:
+ return self._get_expr(self.fget)
+
+ def _get_expr(self, expr):
+ def _expr(cls):
+ return ExprComparator(cls, expr(cls), self)
+
+ util.update_wrapper(_expr, expr)
+
+ return self._get_comparator(_expr)
+
+ def _get_comparator(self, comparator):
+
+ proxy_attr = attributes.create_proxied_attribute(self)
+
+ def expr_comparator(owner):
+ # because this is the descriptor protocol, we don't really know
+ # what our attribute name is. so search for it through the
+ # MRO.
+ for lookup in owner.__mro__:
+ if self.__name__ in lookup.__dict__:
+ if lookup.__dict__[self.__name__] is self:
+ name = self.__name__
+ break
+ else:
+ name = attributes.NO_KEY
+
+ return proxy_attr(
+ owner,
+ name,
+ self,
+ comparator(owner),
+ doc=comparator.__doc__ or self.__doc__,
+ )
+
+ return expr_comparator
+
+
+class Comparator(interfaces.PropComparator):
+ """A helper class that allows easy construction of custom
+ :class:`~.orm.interfaces.PropComparator`
+ classes for usage with hybrids."""
+
+ property = None
+
+ def __init__(self, expression):
+ self.expression = expression
+
+ def __clause_element__(self):
+ expr = self.expression
+ if hasattr(expr, "__clause_element__"):
+ expr = expr.__clause_element__()
+ return expr
+
+ def adapt_to_entity(self, adapt_to_entity):
+ # interesting....
+ return self
+
+
+class ExprComparator(Comparator):
+ def __init__(self, cls, expression, hybrid):
+ self.cls = cls
+ self.expression = expression
+ self.hybrid = hybrid
+
+ def __getattr__(self, key):
+ return getattr(self.expression, key)
+
+ @property
+ def info(self):
+ return self.hybrid.info
+
+ def _bulk_update_tuples(self, value):
+ if isinstance(self.expression, attributes.QueryableAttribute):
+ return self.expression._bulk_update_tuples(value)
+ elif self.hybrid.update_expr is not None:
+ return self.hybrid.update_expr(self.cls, value)
+ else:
+ return [(self.expression, value)]
+
+ @property
+ def property(self):
+ return self.expression.property
+
+ def operate(self, op, *other, **kwargs):
+ return op(self.expression, *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ return op(other, self.expression, **kwargs)
diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py
new file mode 100644
index 0000000..7cbac54
--- /dev/null
+++ b/lib/sqlalchemy/ext/indexable.py
@@ -0,0 +1,352 @@
+# ext/indexable.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Define attributes on ORM-mapped classes that have "index" attributes for
+columns with :class:`_types.Indexable` types.
+
+"index" means the attribute is associated with an element of an
+:class:`_types.Indexable` column with the predefined index to access it.
+The :class:`_types.Indexable` types include types such as
+:class:`_types.ARRAY`, :class:`_types.JSON` and
+:class:`_postgresql.HSTORE`.
+
+
+
+The :mod:`~sqlalchemy.ext.indexable` extension provides
+:class:`_schema.Column`-like interface for any element of an
+:class:`_types.Indexable` typed column. In simple cases, it can be
+treated as a :class:`_schema.Column` - mapped attribute.
+
+
+.. versionadded:: 1.1
+
+Synopsis
+========
+
+Consider a ``Person`` model with a primary key and a JSON data field.
+While this field may have any number of elements encoded within it,
+we would like to refer to the element called ``name`` individually
+as a dedicated attribute which behaves like a standalone column::
+
+ from sqlalchemy import Column, JSON, Integer
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.indexable import index_property
+
+ Base = declarative_base()
+
+ class Person(Base):
+ __tablename__ = 'person'
+
+ id = Column(Integer, primary_key=True)
+ data = Column(JSON)
+
+ name = index_property('data', 'name')
+
+
+Above, the ``name`` attribute now behaves like a mapped column. We
+can compose a new ``Person`` and set the value of ``name``::
+
+ >>> person = Person(name='Alchemist')
+
+The value is now accessible::
+
+ >>> person.name
+ 'Alchemist'
+
+Behind the scenes, the JSON field was initialized to a new blank dictionary
+and the field was set::
+
+ >>> person.data
+    {'name': 'Alchemist'}
+
+The field is mutable in place::
+
+ >>> person.name = 'Renamed'
+ >>> person.name
+ 'Renamed'
+ >>> person.data
+ {'name': 'Renamed'}
+
+When using :class:`.index_property`, the change that we make to the indexable
+structure is also automatically tracked as history; we no longer need
+to use :class:`~.mutable.MutableDict` in order to track this change
+for the unit of work.
+
+Deletions work normally as well::
+
+ >>> del person.name
+ >>> person.data
+ {}
+
+Above, deletion of ``person.name`` deletes the value from the dictionary,
+but not the dictionary itself.
+
+A missing key will produce ``AttributeError``::
+
+ >>> person = Person()
+ >>> person.name
+ ...
+ AttributeError: 'name'
+
+Unless you set a default value::
+
+    >>> class Person(Base):
+    ...     __tablename__ = 'person'
+    ...
+    ...     id = Column(Integer, primary_key=True)
+    ...     data = Column(JSON)
+    ...
+    ...     name = index_property('data', 'name', default=None)  # See default
+
+ >>> person = Person()
+ >>> print(person.name)
+ None
+
+
+The attributes are also accessible at the class level.
+Below, we illustrate ``Person.name`` used to generate
+indexed SQL criteria::
+
+ >>> from sqlalchemy.orm import Session
+ >>> session = Session()
+ >>> query = session.query(Person).filter(Person.name == 'Alchemist')
+
+The above query is equivalent to::
+
+ >>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist')
+
+Multiple :class:`.index_property` objects can be chained to produce
+multiple levels of indexing::
+
+ from sqlalchemy import Column, JSON, Integer
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.indexable import index_property
+
+ Base = declarative_base()
+
+ class Person(Base):
+ __tablename__ = 'person'
+
+ id = Column(Integer, primary_key=True)
+ data = Column(JSON)
+
+ birthday = index_property('data', 'birthday')
+ year = index_property('birthday', 'year')
+ month = index_property('birthday', 'month')
+ day = index_property('birthday', 'day')
+
+With the above definition, a query such as the following::
+
+ q = session.query(Person).filter(Person.year == '1980')
+
+On a PostgreSQL backend, the above query will render as::
+
+ SELECT person.id, person.data
+ FROM person
+ WHERE person.data -> %(data_1)s -> %(param_1)s = %(param_2)s
+
+Default Values
+==============
+
+:class:`.index_property` includes special behaviors for when the indexed
+data structure does not exist and a set operation is called; a brief
+sketch of the integer case follows this list:
+
+* For an :class:`.index_property` that is given an integer index value,
+ the default data structure will be a Python list of ``None`` values,
+ at least as long as the index value; the value is then set at its
+ place in the list. This means for an index value of zero, the list
+ will be initialized to ``[None]`` before setting the given value,
+ and for an index value of five, the list will be initialized to
+ ``[None, None, None, None, None]`` before setting the fifth element
+ to the given value. Note that an existing list is **not** extended
+ in place to receive a value.
+
+* for an :class:`.index_property` that is given any other kind of index
+ value (e.g. strings usually), a Python dictionary is used as the
+ default data structure.
+
+* The default data structure can be set to any Python callable using the
+ :paramref:`.index_property.datatype` parameter, overriding the previous
+ rules.
+
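+For instance, a minimal sketch of the integer-index case; the ``Vector``
+model and its ``third_value`` attribute here are hypothetical::
+
+    from sqlalchemy import Column, Integer, JSON
+    from sqlalchemy.ext.declarative import declarative_base
+    from sqlalchemy.ext.indexable import index_property
+
+    Base = declarative_base()
+
+    class Vector(Base):
+        __tablename__ = 'vector'
+
+        id = Column(Integer, primary_key=True)
+        data = Column(JSON)
+
+        # integer index of 2; when ``data`` is not yet present, it is
+        # initialized to [None, None, None] before the value is assigned
+        third_value = index_property('data', 2)
+
+    v = Vector()
+    v.third_value = 'three'
+    # v.data is now [None, None, 'three']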
+
+Subclassing
+===========
+
+:class:`.index_property` can be subclassed, in particular for the common
+use case of providing coercion of values or SQL expressions as they are
+accessed. Below is a common recipe for use with a PostgreSQL JSON type,
+where we want to also include automatic casting plus ``astext()``::
+
+ class pg_json_property(index_property):
+ def __init__(self, attr_name, index, cast_type):
+ super(pg_json_property, self).__init__(attr_name, index)
+ self.cast_type = cast_type
+
+ def expr(self, model):
+ expr = super(pg_json_property, self).expr(model)
+ return expr.astext.cast(self.cast_type)
+
+The above subclass can be used with the PostgreSQL-specific
+version of :class:`_postgresql.JSON`::
+
+ from sqlalchemy import Column, Integer
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.dialects.postgresql import JSON
+
+ Base = declarative_base()
+
+ class Person(Base):
+ __tablename__ = 'person'
+
+ id = Column(Integer, primary_key=True)
+ data = Column(JSON)
+
+ age = pg_json_property('data', 'age', Integer)
+
+The ``age`` attribute at the instance level works as before; however
+when rendering SQL, PostgreSQL's ``->>`` operator will be used
+for indexed access, instead of the usual index operator of ``->``::
+
+ >>> query = session.query(Person).filter(Person.age < 20)
+
+The above query will render::
+
+ SELECT person.id, person.data
+ FROM person
+ WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s
+
+""" # noqa
+from __future__ import absolute_import
+
+from .. import inspect
+from .. import util
+from ..ext.hybrid import hybrid_property
+from ..orm.attributes import flag_modified
+
+
+__all__ = ["index_property"]
+
+
+class index_property(hybrid_property): # noqa
+ """A property generator. The generated property describes an object
+ attribute that corresponds to an :class:`_types.Indexable`
+ column.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :mod:`sqlalchemy.ext.indexable`
+
+ """
+
+ _NO_DEFAULT_ARGUMENT = object()
+
+ def __init__(
+ self,
+ attr_name,
+ index,
+ default=_NO_DEFAULT_ARGUMENT,
+ datatype=None,
+ mutable=True,
+ onebased=True,
+ ):
+ """Create a new :class:`.index_property`.
+
+ :param attr_name:
+ An attribute name of an `Indexable` typed column, or other
+ attribute that returns an indexable structure.
+ :param index:
+ The index to be used for getting and setting this value. This
+ should be the Python-side index value for integers.
+ :param default:
+            A value which will be returned instead of raising `AttributeError`
+            when there is no value at the given index.
+ :param datatype: default datatype to use when the field is empty.
+ By default, this is derived from the type of index used; a
+ Python list for an integer index, or a Python dictionary for
+ any other style of index. For a list, the list will be
+ initialized to a list of None values that is at least
+ ``index`` elements long.
+ :param mutable: if False, writes and deletes to the attribute will
+ be disallowed.
+ :param onebased: assume the SQL representation of this value is
+ one-based; that is, the first index in SQL is 1, not zero.
+ """
+
+ if mutable:
+ super(index_property, self).__init__(
+ self.fget, self.fset, self.fdel, self.expr
+ )
+ else:
+ super(index_property, self).__init__(
+ self.fget, None, None, self.expr
+ )
+ self.attr_name = attr_name
+ self.index = index
+ self.default = default
+ is_numeric = isinstance(index, int)
+ onebased = is_numeric and onebased
+
+ if datatype is not None:
+ self.datatype = datatype
+ else:
+ if is_numeric:
+ self.datatype = lambda: [None for x in range(index + 1)]
+ else:
+ self.datatype = dict
+ self.onebased = onebased
+
+ def _fget_default(self, err=None):
+ if self.default == self._NO_DEFAULT_ARGUMENT:
+ util.raise_(AttributeError(self.attr_name), replace_context=err)
+ else:
+ return self.default
+
+ def fget(self, instance):
+ attr_name = self.attr_name
+ column_value = getattr(instance, attr_name)
+ if column_value is None:
+ return self._fget_default()
+ try:
+ value = column_value[self.index]
+ except (KeyError, IndexError) as err:
+ return self._fget_default(err)
+ else:
+ return value
+
+ def fset(self, instance, value):
+ attr_name = self.attr_name
+ column_value = getattr(instance, attr_name, None)
+ if column_value is None:
+ column_value = self.datatype()
+ setattr(instance, attr_name, column_value)
+ column_value[self.index] = value
+ setattr(instance, attr_name, column_value)
+ if attr_name in inspect(instance).mapper.attrs:
+ flag_modified(instance, attr_name)
+
+ def fdel(self, instance):
+ attr_name = self.attr_name
+ column_value = getattr(instance, attr_name)
+ if column_value is None:
+ raise AttributeError(self.attr_name)
+ try:
+ del column_value[self.index]
+ except KeyError as err:
+ util.raise_(AttributeError(self.attr_name), replace_context=err)
+ else:
+ setattr(instance, attr_name, column_value)
+ flag_modified(instance, attr_name)
+
+ def expr(self, model):
+ column = getattr(model, self.attr_name)
+ index = self.index
+ if self.onebased:
+ index += 1
+ return column[index]
diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py
new file mode 100644
index 0000000..54f3e64
--- /dev/null
+++ b/lib/sqlalchemy/ext/instrumentation.py
@@ -0,0 +1,416 @@
+"""Extensible class instrumentation.
+
+The :mod:`sqlalchemy.ext.instrumentation` package provides for alternate
+systems of class instrumentation within the ORM. Class instrumentation
+refers to how the ORM places attributes on the class which maintain
+data and track changes to that data, as well as event hooks installed
+on the class.
+
+.. note::
+ The extension package is provided for the benefit of integration
+ with other object management packages, which already perform
+ their own instrumentation. It is not intended for general use.
+
+For examples of how the instrumentation extension is used,
+see the example :ref:`examples_instrumentation`.
+
+"""
+import weakref
+
+from .. import util
+from ..orm import attributes
+from ..orm import base as orm_base
+from ..orm import collections
+from ..orm import exc as orm_exc
+from ..orm import instrumentation as orm_instrumentation
+from ..orm.instrumentation import _default_dict_getter
+from ..orm.instrumentation import _default_manager_getter
+from ..orm.instrumentation import _default_state_getter
+from ..orm.instrumentation import ClassManager
+from ..orm.instrumentation import InstrumentationFactory
+
+
+INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__"
+"""Attribute, elects custom instrumentation when present on a mapped class.
+
+Allows a class to specify a slightly or wildly different technique for
+tracking changes made to mapped attributes and collections.
+
+Only one instrumentation implementation is allowed in a given object
+inheritance hierarchy.
+
+The value of this attribute must be a callable and will be passed a class
+object. The callable must return one of:
+
+ - An instance of an :class:`.InstrumentationManager` or subclass
+ - An object implementing all or some of InstrumentationManager (TODO)
+ - A dictionary of callables, implementing all or some of the above (TODO)
+ - An instance of a :class:`.ClassManager` or subclass
+
+This attribute is consulted by SQLAlchemy instrumentation
+resolution, once the :mod:`sqlalchemy.ext.instrumentation` module
+has been imported. If custom finders are installed in the global
+instrumentation_finders list, they may or may not choose to honor this
+attribute.
+
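+A minimal sketch of a mapped class electing a custom manager (class names
+here are hypothetical)::
+
+    from sqlalchemy.ext.instrumentation import InstrumentationManager
+
+    class MyInstrumentationManager(InstrumentationManager):
+        pass
+
+    class MyMappedClass(object):
+        __sa_instrumentation_manager__ = MyInstrumentationManager
+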
+"""
+
+
+def find_native_user_instrumentation_hook(cls):
+ """Find user-specified instrumentation management for a class."""
+ return getattr(cls, INSTRUMENTATION_MANAGER, None)
+
+
+instrumentation_finders = [find_native_user_instrumentation_hook]
+"""An extensible sequence of callables which return instrumentation
+implementations
+
+When a class is registered, each callable will be passed a class object.
+If None is returned, the
+next finder in the sequence is consulted. Otherwise the return must be an
+instrumentation factory that follows the same guidelines as
+sqlalchemy.ext.instrumentation.INSTRUMENTATION_MANAGER.
+
+By default, the only finder is find_native_user_instrumentation_hook, which
+searches for INSTRUMENTATION_MANAGER. If all finders return None, standard
+ClassManager instrumentation is used.
+
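+A hypothetical finder might look like the following sketch (the attribute
+name consulted is illustrative only)::
+
+    from sqlalchemy.ext.instrumentation import instrumentation_finders
+
+    def find_alternate_hook(cls):
+        # return an instrumentation factory if the class opts in,
+        # otherwise None to defer to the next finder
+        return getattr(cls, "_alternate_instrumentation_factory", None)
+
+    instrumentation_finders.append(find_alternate_hook)
+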
+"""
+
+
+class ExtendedInstrumentationRegistry(InstrumentationFactory):
+ """Extends :class:`.InstrumentationFactory` with additional
+ bookkeeping, to accommodate multiple types of
+ class managers.
+
+ """
+
+ _manager_finders = weakref.WeakKeyDictionary()
+ _state_finders = weakref.WeakKeyDictionary()
+ _dict_finders = weakref.WeakKeyDictionary()
+ _extended = False
+
+ def _locate_extended_factory(self, class_):
+ for finder in instrumentation_finders:
+ factory = finder(class_)
+ if factory is not None:
+ manager = self._extended_class_manager(class_, factory)
+ return manager, factory
+ else:
+ return None, None
+
+ def _check_conflicts(self, class_, factory):
+ existing_factories = self._collect_management_factories_for(
+ class_
+ ).difference([factory])
+ if existing_factories:
+ raise TypeError(
+ "multiple instrumentation implementations specified "
+ "in %s inheritance hierarchy: %r"
+ % (class_.__name__, list(existing_factories))
+ )
+
+ def _extended_class_manager(self, class_, factory):
+ manager = factory(class_)
+ if not isinstance(manager, ClassManager):
+ manager = _ClassInstrumentationAdapter(class_, manager)
+
+ if factory != ClassManager and not self._extended:
+ # somebody invoked a custom ClassManager.
+ # reinstall global "getter" functions with the more
+ # expensive ones.
+ self._extended = True
+ _install_instrumented_lookups()
+
+ self._manager_finders[class_] = manager.manager_getter()
+ self._state_finders[class_] = manager.state_getter()
+ self._dict_finders[class_] = manager.dict_getter()
+ return manager
+
+ def _collect_management_factories_for(self, cls):
+ """Return a collection of factories in play or specified for a
+ hierarchy.
+
+ Traverses the entire inheritance graph of a cls and returns a
+ collection of instrumentation factories for those classes. Factories
+ are extracted from active ClassManagers, if available, otherwise
+ instrumentation_finders is consulted.
+
+ """
+ hierarchy = util.class_hierarchy(cls)
+ factories = set()
+ for member in hierarchy:
+ manager = self.manager_of_class(member)
+ if manager is not None:
+ factories.add(manager.factory)
+ else:
+ for finder in instrumentation_finders:
+ factory = finder(member)
+ if factory is not None:
+ break
+ else:
+ factory = None
+ factories.add(factory)
+ factories.discard(None)
+ return factories
+
+ def unregister(self, class_):
+ super(ExtendedInstrumentationRegistry, self).unregister(class_)
+ if class_ in self._manager_finders:
+ del self._manager_finders[class_]
+ del self._state_finders[class_]
+ del self._dict_finders[class_]
+
+ def manager_of_class(self, cls):
+ if cls is None:
+ return None
+ try:
+ finder = self._manager_finders.get(cls, _default_manager_getter)
+ except TypeError:
+ # due to weakref lookup on invalid object
+ return None
+ else:
+ return finder(cls)
+
+ def state_of(self, instance):
+ if instance is None:
+ raise AttributeError("None has no persistent state.")
+ return self._state_finders.get(
+ instance.__class__, _default_state_getter
+ )(instance)
+
+ def dict_of(self, instance):
+ if instance is None:
+ raise AttributeError("None has no persistent state.")
+ return self._dict_finders.get(
+ instance.__class__, _default_dict_getter
+ )(instance)
+
+
+orm_instrumentation._instrumentation_factory = (
+ _instrumentation_factory
+) = ExtendedInstrumentationRegistry()
+orm_instrumentation.instrumentation_finders = instrumentation_finders
+
+
+class InstrumentationManager(object):
+ """User-defined class instrumentation extension.
+
+ :class:`.InstrumentationManager` can be subclassed in order
+ to change
+ how class instrumentation proceeds. This class exists for
+ the purposes of integration with other object management
+ frameworks which would like to entirely modify the
+ instrumentation methodology of the ORM, and is not intended
+ for regular usage. For interception of class instrumentation
+ events, see :class:`.InstrumentationEvents`.
+
+ The API for this class should be considered as semi-stable,
+ and may change slightly with new releases.
+
+ """
+
+ # r4361 added a mandatory (cls) constructor to this interface.
+ # given that, perhaps class_ should be dropped from all of these
+ # signatures.
+
+ def __init__(self, class_):
+ pass
+
+ def manage(self, class_, manager):
+ setattr(class_, "_default_class_manager", manager)
+
+ def unregister(self, class_, manager):
+ delattr(class_, "_default_class_manager")
+
+ def manager_getter(self, class_):
+ def get(cls):
+ return cls._default_class_manager
+
+ return get
+
+ def instrument_attribute(self, class_, key, inst):
+ pass
+
+ def post_configure_attribute(self, class_, key, inst):
+ pass
+
+ def install_descriptor(self, class_, key, inst):
+ setattr(class_, key, inst)
+
+ def uninstall_descriptor(self, class_, key):
+ delattr(class_, key)
+
+ def install_member(self, class_, key, implementation):
+ setattr(class_, key, implementation)
+
+ def uninstall_member(self, class_, key):
+ delattr(class_, key)
+
+ def instrument_collection_class(self, class_, key, collection_class):
+ return collections.prepare_instrumentation(collection_class)
+
+ def get_instance_dict(self, class_, instance):
+ return instance.__dict__
+
+ def initialize_instance_dict(self, class_, instance):
+ pass
+
+ def install_state(self, class_, instance, state):
+ setattr(instance, "_default_state", state)
+
+ def remove_state(self, class_, instance):
+ delattr(instance, "_default_state")
+
+ def state_getter(self, class_):
+ return lambda instance: getattr(instance, "_default_state")
+
+ def dict_getter(self, class_):
+ return lambda inst: self.get_instance_dict(class_, inst)
+
+
+class _ClassInstrumentationAdapter(ClassManager):
+ """Adapts a user-defined InstrumentationManager to a ClassManager."""
+
+ def __init__(self, class_, override):
+ self._adapted = override
+ self._get_state = self._adapted.state_getter(class_)
+ self._get_dict = self._adapted.dict_getter(class_)
+
+ ClassManager.__init__(self, class_)
+
+ def manage(self):
+ self._adapted.manage(self.class_, self)
+
+ def unregister(self):
+ self._adapted.unregister(self.class_, self)
+
+ def manager_getter(self):
+ return self._adapted.manager_getter(self.class_)
+
+ def instrument_attribute(self, key, inst, propagated=False):
+ ClassManager.instrument_attribute(self, key, inst, propagated)
+ if not propagated:
+ self._adapted.instrument_attribute(self.class_, key, inst)
+
+ def post_configure_attribute(self, key):
+ super(_ClassInstrumentationAdapter, self).post_configure_attribute(key)
+ self._adapted.post_configure_attribute(self.class_, key, self[key])
+
+ def install_descriptor(self, key, inst):
+ self._adapted.install_descriptor(self.class_, key, inst)
+
+ def uninstall_descriptor(self, key):
+ self._adapted.uninstall_descriptor(self.class_, key)
+
+ def install_member(self, key, implementation):
+ self._adapted.install_member(self.class_, key, implementation)
+
+ def uninstall_member(self, key):
+ self._adapted.uninstall_member(self.class_, key)
+
+ def instrument_collection_class(self, key, collection_class):
+ return self._adapted.instrument_collection_class(
+ self.class_, key, collection_class
+ )
+
+ def initialize_collection(self, key, state, factory):
+ delegate = getattr(self._adapted, "initialize_collection", None)
+ if delegate:
+ return delegate(key, state, factory)
+ else:
+ return ClassManager.initialize_collection(
+ self, key, state, factory
+ )
+
+ def new_instance(self, state=None):
+ instance = self.class_.__new__(self.class_)
+ self.setup_instance(instance, state)
+ return instance
+
+ def _new_state_if_none(self, instance):
+ """Install a default InstanceState if none is present.
+
+ A private convenience method used by the __init__ decorator.
+ """
+ if self.has_state(instance):
+ return False
+ else:
+ return self.setup_instance(instance)
+
+ def setup_instance(self, instance, state=None):
+ self._adapted.initialize_instance_dict(self.class_, instance)
+
+ if state is None:
+ state = self._state_constructor(instance, self)
+
+ # the given instance is assumed to have no state
+ self._adapted.install_state(self.class_, instance, state)
+ return state
+
+ def teardown_instance(self, instance):
+ self._adapted.remove_state(self.class_, instance)
+
+ def has_state(self, instance):
+ try:
+ self._get_state(instance)
+ except orm_exc.NO_STATE:
+ return False
+ else:
+ return True
+
+ def state_getter(self):
+ return self._get_state
+
+ def dict_getter(self):
+ return self._get_dict
+
+
+def _install_instrumented_lookups():
+ """Replace global class/object management functions
+ with ExtendedInstrumentationRegistry implementations, which
+ allow multiple types of class managers to be present,
+ at the cost of performance.
+
+ This function is called only by ExtendedInstrumentationRegistry
+ and unit tests specific to this behavior.
+
+ The _reinstall_default_lookups() function can be called
+ after this one to re-establish the default functions.
+
+ """
+ _install_lookups(
+ dict(
+ instance_state=_instrumentation_factory.state_of,
+ instance_dict=_instrumentation_factory.dict_of,
+ manager_of_class=_instrumentation_factory.manager_of_class,
+ )
+ )
+
+
+def _reinstall_default_lookups():
+ """Restore simplified lookups."""
+ _install_lookups(
+ dict(
+ instance_state=_default_state_getter,
+ instance_dict=_default_dict_getter,
+ manager_of_class=_default_manager_getter,
+ )
+ )
+ _instrumentation_factory._extended = False
+
+
+def _install_lookups(lookups):
+ global instance_state, instance_dict, manager_of_class
+ instance_state = lookups["instance_state"]
+ instance_dict = lookups["instance_dict"]
+ manager_of_class = lookups["manager_of_class"]
+ orm_base.instance_state = (
+ attributes.instance_state
+ ) = orm_instrumentation.instance_state = instance_state
+ orm_base.instance_dict = (
+ attributes.instance_dict
+ ) = orm_instrumentation.instance_dict = instance_dict
+ orm_base.manager_of_class = (
+ attributes.manager_of_class
+ ) = orm_instrumentation.manager_of_class = manager_of_class
diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py
new file mode 100644
index 0000000..cbec06a
--- /dev/null
+++ b/lib/sqlalchemy/ext/mutable.py
@@ -0,0 +1,958 @@
+# ext/mutable.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""Provide support for tracking of in-place changes to scalar values,
+which are propagated into ORM change events on owning parent objects.
+
+.. _mutable_scalars:
+
+Establishing Mutability on Scalar Column Values
+===============================================
+
+A typical example of a "mutable" structure is a Python dictionary.
+Following the example introduced in :ref:`types_toplevel`, we
+begin with a custom type that marshals Python dictionaries into
+JSON strings before being persisted::
+
+ from sqlalchemy.types import TypeDecorator, VARCHAR
+ import json
+
+ class JSONEncodedDict(TypeDecorator):
+ "Represents an immutable structure as a json-encoded string."
+
+ impl = VARCHAR
+
+ def process_bind_param(self, value, dialect):
+ if value is not None:
+ value = json.dumps(value)
+ return value
+
+ def process_result_value(self, value, dialect):
+ if value is not None:
+ value = json.loads(value)
+ return value
+
+The usage of ``json`` is only for the purposes of example. The
+:mod:`sqlalchemy.ext.mutable` extension can be used
+with any type whose target Python type may be mutable, including
+:class:`.PickleType`, :class:`_postgresql.ARRAY`, etc.
+
+When using the :mod:`sqlalchemy.ext.mutable` extension, the value itself
+tracks all parents which reference it. Below, we illustrate a simple
+version of the :class:`.MutableDict` dictionary object, which applies
+the :class:`.Mutable` mixin to a plain Python dictionary::
+
+ from sqlalchemy.ext.mutable import Mutable
+
+ class MutableDict(Mutable, dict):
+ @classmethod
+ def coerce(cls, key, value):
+ "Convert plain dictionaries to MutableDict."
+
+ if not isinstance(value, MutableDict):
+ if isinstance(value, dict):
+ return MutableDict(value)
+
+ # this call will raise ValueError
+ return Mutable.coerce(key, value)
+ else:
+ return value
+
+ def __setitem__(self, key, value):
+ "Detect dictionary set events and emit change events."
+
+ dict.__setitem__(self, key, value)
+ self.changed()
+
+ def __delitem__(self, key):
+ "Detect dictionary del events and emit change events."
+
+ dict.__delitem__(self, key)
+ self.changed()
+
+The above dictionary class takes the approach of subclassing the Python
+built-in ``dict`` to produce a dict
+subclass which routes all mutation events through ``__setitem__``. There are
+variants on this approach, such as subclassing ``collections.UserDict`` or
+``collections.abc.MutableMapping``; the part that's important to this example is
+that the :meth:`.Mutable.changed` method is called whenever an in-place
+change to the datastructure takes place.
+
+We also redefine the :meth:`.Mutable.coerce` method which will be used to
+convert any values that are not instances of ``MutableDict``, such
+as the plain dictionaries returned by the ``json`` module, into the
+appropriate type. Defining this method is optional; we could just as well
+have created our ``JSONEncodedDict`` such that it always returns an instance
+of ``MutableDict``, and additionally ensured that all calling code
+uses ``MutableDict`` explicitly. When :meth:`.Mutable.coerce` is not
+overridden, any values applied to a parent object which are not instances
+of the mutable type will raise a ``ValueError``.
+
+Our new ``MutableDict`` type offers a class method
+:meth:`~.Mutable.as_mutable` which we can use within column metadata
+to associate with types. This method grabs the given type object or
+class and associates a listener that will detect all future mappings
+of this type, applying event listening instrumentation to the mapped
+attribute. For example, with classical table metadata::
+
+ from sqlalchemy import Table, Column, Integer
+
+ my_data = Table('my_data', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', MutableDict.as_mutable(JSONEncodedDict))
+ )
+
+Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict``
+(if the type object was not an instance already), which will intercept any
+attributes which are mapped against this type. Below we establish a simple
+mapping against the ``my_data`` table::
+
+ from sqlalchemy import mapper
+
+ class MyDataClass(object):
+ pass
+
+ # associates mutation listeners with MyDataClass.data
+ mapper(MyDataClass, my_data)
+
+The ``MyDataClass.data`` member will now be notified of in place changes
+to its value.
+
+There's no difference in usage when using declarative::
+
+ from sqlalchemy.ext.declarative import declarative_base
+
+ Base = declarative_base()
+
+ class MyDataClass(Base):
+ __tablename__ = 'my_data'
+ id = Column(Integer, primary_key=True)
+ data = Column(MutableDict.as_mutable(JSONEncodedDict))
+
+Any in-place changes to the ``MyDataClass.data`` member
+will flag the attribute as "dirty" on the parent object::
+
+ >>> from sqlalchemy.orm import Session
+
+ >>> sess = Session()
+ >>> m1 = MyDataClass(data={'value1':'foo'})
+ >>> sess.add(m1)
+ >>> sess.commit()
+
+ >>> m1.data['value1'] = 'bar'
+ >>> m1 in sess.dirty
+ True
+
+The ``MutableDict`` can be associated with all future instances
+of ``JSONEncodedDict`` in one step, using
+:meth:`~.Mutable.associate_with`. This is similar to
+:meth:`~.Mutable.as_mutable` except it will intercept all occurrences
+of ``JSONEncodedDict`` in all mappings unconditionally, without
+the need to declare it individually::
+
+ MutableDict.associate_with(JSONEncodedDict)
+
+ class MyDataClass(Base):
+ __tablename__ = 'my_data'
+ id = Column(Integer, primary_key=True)
+ data = Column(JSONEncodedDict)
+
+
+Supporting Pickling
+--------------------
+
+The key to the :mod:`sqlalchemy.ext.mutable` extension is the
+placement of a ``weakref.WeakKeyDictionary`` upon the value object, which
+stores a mapping of parent mapped objects keyed to the attribute name under
+which they are associated with this value. ``WeakKeyDictionary`` objects are
+not picklable, due to the fact that they contain weakrefs and function
+callbacks. In our case, this is a good thing, since if this dictionary were
+picklable, it could lead to an excessively large pickle size for our value
+objects that are pickled by themselves outside of the context of the parent.
+The developer's responsibility here is only to provide a ``__getstate__`` method
+that excludes the :meth:`~MutableBase._parents` collection from the pickle
+stream::
+
+ class MyMutableType(Mutable):
+ def __getstate__(self):
+ d = self.__dict__.copy()
+ d.pop('_parents', None)
+ return d
+
+With our dictionary example, we need to return the contents of the dict itself
+(and also restore them on __setstate__)::
+
+ class MutableDict(Mutable, dict):
+ # ....
+
+ def __getstate__(self):
+ return dict(self)
+
+ def __setstate__(self, state):
+ self.update(state)
+
+In the case that our mutable value object is pickled as it is attached to one
+or more parent objects that are also part of the pickle, the :class:`.Mutable`
+mixin will re-establish the :attr:`.Mutable._parents` collection on each value
+object as the owning parents themselves are unpickled.
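+
+For example, assuming the ``MyDataClass`` mapping from the earlier examples,
+a pickle round trip of the parent object re-associates the value (a minimal
+sketch; ``pickle`` additionally requires that the mapped class and the custom
+type be importable by module path)::
+
+ import pickle
+
+ m1 = MyDataClass(data={'value1': 'foo'})
+
+ # the MutableDict value is pickled along with its parent
+ m2 = pickle.loads(pickle.dumps(m1))
+
+ # the restored value is again a MutableDict, and its _parents
+ # collection now refers to the new parent, so in-place changes
+ # continue to generate change events
+ assert isinstance(m2.data, MutableDict)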
+
+Receiving Events
+----------------
+
+The :meth:`.AttributeEvents.modified` event handler may be used to receive
+an event when a mutable scalar emits a change event. This event handler
+is called when the :func:`.attributes.flag_modified` function is called
+from within the mutable extension::
+
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy import event
+
+ Base = declarative_base()
+
+ class MyDataClass(Base):
+ __tablename__ = 'my_data'
+ id = Column(Integer, primary_key=True)
+ data = Column(MutableDict.as_mutable(JSONEncodedDict))
+
+ @event.listens_for(MyDataClass.data, "modified")
+ def modified_json(instance, initiator):
+ print("json value modified:", instance.data)
+
+.. _mutable_composites:
+
+Establishing Mutability on Composites
+=====================================
+
+Composites are a special ORM feature which allows a single scalar attribute to
+be assigned an object value which represents information "composed" from one
+or more columns from the underlying mapped table. The usual example is that of
+a geometric "point", and is introduced in :ref:`mapper_composite`.
+
+As is the case with :class:`.Mutable`, the user-defined composite class
+subclasses :class:`.MutableComposite` as a mixin, and detects and delivers
+change events to its parents via the :meth:`.MutableComposite.changed` method.
+In the case of a composite class, the detection is usually via the usage of
+Python descriptors (i.e. ``@property``), or alternatively via the special
+Python method ``__setattr__()``. Below we expand upon the ``Point`` class
+introduced in :ref:`mapper_composite` to subclass :class:`.MutableComposite`
+and to also route attribute set events via ``__setattr__`` to the
+:meth:`.MutableComposite.changed` method::
+
+ from sqlalchemy.ext.mutable import MutableComposite
+
+ class Point(MutableComposite):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __setattr__(self, key, value):
+ "Intercept set events"
+
+ # set the attribute
+ object.__setattr__(self, key, value)
+
+ # alert all parents to the change
+ self.changed()
+
+ def __composite_values__(self):
+ return self.x, self.y
+
+ def __eq__(self, other):
+ return isinstance(other, Point) and \
+ other.x == self.x and \
+ other.y == self.y
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+The :class:`.MutableComposite` class makes use of class mapping events to
+automatically establish listeners for any usage of :func:`_orm.composite`
+that specifies our
+``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` class,
+listeners are established which will route change events from ``Point``
+objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes::
+
+ from sqlalchemy.orm import composite, mapper
+ from sqlalchemy import Table, Column
+
+ vertices = Table('vertices', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('x1', Integer),
+ Column('y1', Integer),
+ Column('x2', Integer),
+ Column('y2', Integer),
+ )
+
+ class Vertex(object):
+ pass
+
+ mapper(Vertex, vertices, properties={
+ 'start': composite(Point, vertices.c.x1, vertices.c.y1),
+ 'end': composite(Point, vertices.c.x2, vertices.c.y2)
+ })
+
+Any in-place changes to the ``Vertex.start`` or ``Vertex.end`` members
+will flag the attribute as "dirty" on the parent object::
+
+ >>> from sqlalchemy.orm import Session
+
+ >>> sess = Session()
+ >>> v1 = Vertex(start=Point(3, 4), end=Point(12, 15))
+ >>> sess.add(v1)
+ >>> sess.commit()
+
+ >>> v1.end.x = 8
+ >>> v1 in sess.dirty
+ True
+
+Coercing Mutable Composites
+---------------------------
+
+The :meth:`.MutableBase.coerce` method is also supported on composite types.
+In the case of :class:`.MutableComposite`, the :meth:`.MutableBase.coerce`
+method is only called for attribute set operations, not load operations.
+Overriding the :meth:`.MutableBase.coerce` method is essentially equivalent
+to using a :func:`.validates` validation routine for all attributes which
+make use of the custom composite type::
+
+ class Point(MutableComposite):
+ # other Point methods
+ # ...
+
+ @classmethod
+ def coerce(cls, key, value):
+ if isinstance(value, tuple):
+ value = Point(*value)
+ elif not isinstance(value, Point):
+ raise ValueError("tuple or Point expected")
+ return value
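+
+With a ``coerce`` method such as the above in place, assigning a plain tuple
+to a mapped composite attribute passes through the coercion step before it is
+applied (an illustrative sketch, reusing the ``Vertex`` mapping from the
+previous section)::
+
+ v1 = Vertex(start=Point(3, 4), end=Point(12, 15))
+
+ # the incoming tuple is handed to Point.coerce() and converted
+ # into a Point before being applied to the attribute
+ v1.end = (17, 8)
+
+ assert isinstance(v1.end, Point)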
+
+Supporting Pickling
+--------------------
+
+As is the case with :class:`.Mutable`, the :class:`.MutableComposite` helper
+class uses a ``weakref.WeakKeyDictionary`` available via the
+:meth:`MutableBase._parents` attribute, which isn't picklable. If we need to
+pickle instances of ``Point`` or its owning class ``Vertex``, we at least need
+to define a ``__getstate__`` that doesn't include the ``_parents`` dictionary.
+Below we define both a ``__getstate__`` and a ``__setstate__`` that package up
+the minimal form of our ``Point`` class::
+
+ class Point(MutableComposite):
+ # ...
+
+ def __getstate__(self):
+ return self.x, self.y
+
+ def __setstate__(self, state):
+ self.x, self.y = state
+
+As with :class:`.Mutable`, the :class:`.MutableComposite` augments the
+pickling process of the parent's object-relational state so that the
+:meth:`MutableBase._parents` collection is restored to all ``Point`` objects.
+
+"""
+from collections import defaultdict
+import weakref
+
+from .. import event
+from .. import inspect
+from .. import types
+from ..orm import Mapper
+from ..orm import mapper
+from ..orm.attributes import flag_modified
+from ..sql.base import SchemaEventTarget
+from ..util import memoized_property
+
+
+class MutableBase(object):
+ """Common base class to :class:`.Mutable`
+ and :class:`.MutableComposite`.
+
+ """
+
+ @memoized_property
+ def _parents(self):
+ """Dictionary of parent object's :class:`.InstanceState`->attribute
+ name on the parent.
+
+ This attribute is a so-called "memoized" property. It initializes
+ itself with a new ``weakref.WeakKeyDictionary`` the first time
+ it is accessed, returning the same object upon subsequent access.
+
+ .. versionchanged:: 1.4 the :class:`.InstanceState` is now used
+ as the key in the weak dictionary rather than the instance
+ itself.
+
+ """
+
+ return weakref.WeakKeyDictionary()
+
+ @classmethod
+ def coerce(cls, key, value):
+ """Given a value, coerce it into the target type.
+
+ Can be overridden by custom subclasses to coerce incoming
+ data into a particular type.
+
+ By default, raises ``ValueError``.
+
+ This method is called in different scenarios depending on if
+ the parent class is of type :class:`.Mutable` or of type
+ :class:`.MutableComposite`. In the case of the former, it is called
+ for both attribute-set operations as well as during ORM loading
+ operations. For the latter, it is only called during attribute-set
+ operations; the mechanics of the :func:`.composite` construct
+ handle coercion during load operations.
+
+
+ :param key: string name of the ORM-mapped attribute being set.
+ :param value: the incoming value.
+ :return: the method should return the coerced value, or raise
+ ``ValueError`` if the coercion cannot be completed.
+
+ """
+ if value is None:
+ return None
+ msg = "Attribute '%s' does not accept objects of type %s"
+ raise ValueError(msg % (key, type(value)))
+
+ @classmethod
+ def _get_listen_keys(cls, attribute):
+ """Given a descriptor attribute, return a ``set()`` of the attribute
+ keys which indicate a change in the state of this attribute.
+
+ This is normally just ``set([attribute.key])``, but can be overridden
+ to provide for additional keys. E.g. a :class:`.MutableComposite`
+ augments this set with the attribute keys associated with the columns
+ that comprise the composite value.
+
+ This collection is consulted in the case of intercepting the
+ :meth:`.InstanceEvents.refresh` and
+ :meth:`.InstanceEvents.refresh_flush` events, which pass along a list
+ of attribute names that have been refreshed; the list is compared
+ against this set to determine if action needs to be taken.
+
+ .. versionadded:: 1.0.5
+
+ """
+ return {attribute.key}
+
+ @classmethod
+ def _listen_on_attribute(cls, attribute, coerce, parent_cls):
+ """Establish this type as a mutation listener for the given
+ mapped descriptor.
+
+ """
+ key = attribute.key
+ if parent_cls is not attribute.class_:
+ return
+
+ # rely on "propagate" here
+ parent_cls = attribute.class_
+
+ listen_keys = cls._get_listen_keys(attribute)
+
+ def load(state, *args):
+ """Listen for objects loaded or refreshed.
+
+ Wrap the target data member's value with
+ ``Mutable``.
+
+ """
+ val = state.dict.get(key, None)
+ if val is not None:
+ if coerce:
+ val = cls.coerce(key, val)
+ state.dict[key] = val
+ val._parents[state] = key
+
+ def load_attrs(state, ctx, attrs):
+ if not attrs or listen_keys.intersection(attrs):
+ load(state)
+
+ def set_(target, value, oldvalue, initiator):
+ """Listen for set/replace events on the target
+ data member.
+
+ Establish a weak reference to the parent object
+ on the incoming value, remove it for the one
+ outgoing.
+
+ """
+ if value is oldvalue:
+ return value
+
+ if not isinstance(value, cls):
+ value = cls.coerce(key, value)
+ if value is not None:
+ value._parents[target] = key
+ if isinstance(oldvalue, cls):
+ oldvalue._parents.pop(inspect(target), None)
+ return value
+
+ def pickle(state, state_dict):
+ val = state.dict.get(key, None)
+ if val is not None:
+ if "ext.mutable.values" not in state_dict:
+ state_dict["ext.mutable.values"] = defaultdict(list)
+ state_dict["ext.mutable.values"][key].append(val)
+
+ def unpickle(state, state_dict):
+ if "ext.mutable.values" in state_dict:
+ collection = state_dict["ext.mutable.values"]
+ if isinstance(collection, list):
+ # legacy format
+ for val in collection:
+ val._parents[state] = key
+ else:
+ for val in state_dict["ext.mutable.values"][key]:
+ val._parents[state] = key
+
+ event.listen(parent_cls, "load", load, raw=True, propagate=True)
+ event.listen(
+ parent_cls, "refresh", load_attrs, raw=True, propagate=True
+ )
+ event.listen(
+ parent_cls, "refresh_flush", load_attrs, raw=True, propagate=True
+ )
+ event.listen(
+ attribute, "set", set_, raw=True, retval=True, propagate=True
+ )
+ event.listen(parent_cls, "pickle", pickle, raw=True, propagate=True)
+ event.listen(
+ parent_cls, "unpickle", unpickle, raw=True, propagate=True
+ )
+
+
+class Mutable(MutableBase):
+ """Mixin that defines transparent propagation of change
+ events to a parent object.
+
+ See the example in :ref:`mutable_scalars` for usage information.
+
+ """
+
+ def changed(self):
+ """Subclasses should call this method whenever change events occur."""
+
+ for parent, key in self._parents.items():
+ flag_modified(parent.obj(), key)
+
+ @classmethod
+ def associate_with_attribute(cls, attribute):
+ """Establish this type as a mutation listener for the given
+ mapped descriptor.
+
+ """
+ cls._listen_on_attribute(attribute, True, attribute.class_)
+
+ @classmethod
+ def associate_with(cls, sqltype):
+ """Associate this wrapper with all future mapped columns
+ of the given type.
+
+ This is a convenience method that calls
+ ``associate_with_attribute`` automatically.
+
+ .. warning::
+
+ The listeners established by this method are *global*
+ to all mappers, and are *not* garbage collected. Only use
+ :meth:`.associate_with` for types that are permanent to an
+ application, not with ad-hoc types, otherwise this will cause unbounded
+ growth in memory usage.
+
+ """
+
+ def listen_for_type(mapper, class_):
+ if mapper.non_primary:
+ return
+ for prop in mapper.column_attrs:
+ if isinstance(prop.columns[0].type, sqltype):
+ cls.associate_with_attribute(getattr(class_, prop.key))
+
+ event.listen(mapper, "mapper_configured", listen_for_type)
+
+ @classmethod
+ def as_mutable(cls, sqltype):
+ """Associate a SQL type with this mutable Python type.
+
+ This establishes listeners that will detect ORM mappings against
+ the given type, adding mutation event trackers to those mappings.
+
+ The type is returned, unconditionally as an instance, so that
+ :meth:`.as_mutable` can be used inline::
+
+ Table('mytable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', MyMutableType.as_mutable(PickleType))
+ )
+
+ Note that the returned type is always an instance, even if a class
+ is given, and that only columns which are declared specifically with
+ that type instance receive additional instrumentation.
+
+ To associate a particular mutable type with all occurrences of a
+ particular type, use the :meth:`.Mutable.associate_with` classmethod
+ of the particular :class:`.Mutable` subclass to establish a global
+ association.
+
+ .. warning::
+
+ The listeners established by this method are *global*
+ to all mappers, and are *not* garbage collected. Only use
+ :meth:`.as_mutable` for types that are permanent to an application,
+ not with ad-hoc types, otherwise this will cause unbounded growth
+ in memory usage.
+
+ """
+ sqltype = types.to_instance(sqltype)
+
+ # a SchemaType will be copied when the Column is copied,
+ # and we'll lose our ability to link that type back to the original.
+ # so track our original type w/ columns
+ if isinstance(sqltype, SchemaEventTarget):
+
+ @event.listens_for(sqltype, "before_parent_attach")
+ def _add_column_memo(sqltyp, parent):
+ parent.info["_ext_mutable_orig_type"] = sqltyp
+
+ schema_event_check = True
+ else:
+ schema_event_check = False
+
+ def listen_for_type(mapper, class_):
+ if mapper.non_primary:
+ return
+ for prop in mapper.column_attrs:
+ if (
+ schema_event_check
+ and hasattr(prop.expression, "info")
+ and prop.expression.info.get("_ext_mutable_orig_type")
+ is sqltype
+ ) or (prop.columns[0].type is sqltype):
+ cls.associate_with_attribute(getattr(class_, prop.key))
+
+ event.listen(mapper, "mapper_configured", listen_for_type)
+
+ return sqltype
+
+
+class MutableComposite(MutableBase):
+ """Mixin that defines transparent propagation of change
+ events on a SQLAlchemy "composite" object to its
+ owning parent or parents.
+
+ See the example in :ref:`mutable_composites` for usage information.
+
+ """
+
+ @classmethod
+ def _get_listen_keys(cls, attribute):
+ return {attribute.key}.union(attribute.property._attribute_keys)
+
+ def changed(self):
+ """Subclasses should call this method whenever change events occur."""
+
+ for parent, key in self._parents.items():
+
+ prop = parent.mapper.get_property(key)
+ for value, attr_name in zip(
+ self.__composite_values__(), prop._attribute_keys
+ ):
+ setattr(parent.obj(), attr_name, value)
+
+
+def _setup_composite_listener():
+ def _listen_for_type(mapper, class_):
+ for prop in mapper.iterate_properties:
+ if (
+ hasattr(prop, "composite_class")
+ and isinstance(prop.composite_class, type)
+ and issubclass(prop.composite_class, MutableComposite)
+ ):
+ prop.composite_class._listen_on_attribute(
+ getattr(class_, prop.key), False, class_
+ )
+
+ if not event.contains(Mapper, "mapper_configured", _listen_for_type):
+ event.listen(Mapper, "mapper_configured", _listen_for_type)
+
+
+_setup_composite_listener()
+
+
+class MutableDict(Mutable, dict):
+ """A dictionary type that implements :class:`.Mutable`.
+
+ The :class:`.MutableDict` object implements a dictionary that will
+ emit change events to the underlying mapping when the contents of
+ the dictionary are altered, including when values are added or removed.
+
+ Note that :class:`.MutableDict` does **not** apply mutable tracking to the
+ *values themselves* inside the dictionary. Therefore it is not a sufficient
+ solution for the use case of tracking deep changes to a *recursive*
+ dictionary structure, such as a JSON structure. To support this use case,
+ build a subclass of :class:`.MutableDict` that provides appropriate
+ coercion to the values placed in the dictionary so that they too are
+ "mutable", and emit events up to their parent structure.
+
+ .. seealso::
+
+ :class:`.MutableList`
+
+ :class:`.MutableSet`
+
+ """
+
+ def __setitem__(self, key, value):
+ """Detect dictionary set events and emit change events."""
+ dict.__setitem__(self, key, value)
+ self.changed()
+
+ def setdefault(self, key, value):
+ result = dict.setdefault(self, key, value)
+ self.changed()
+ return result
+
+ def __delitem__(self, key):
+ """Detect dictionary del events and emit change events."""
+ dict.__delitem__(self, key)
+ self.changed()
+
+ def update(self, *a, **kw):
+ dict.update(self, *a, **kw)
+ self.changed()
+
+ def pop(self, *arg):
+ result = dict.pop(self, *arg)
+ self.changed()
+ return result
+
+ def popitem(self):
+ result = dict.popitem(self)
+ self.changed()
+ return result
+
+ def clear(self):
+ dict.clear(self)
+ self.changed()
+
+ @classmethod
+ def coerce(cls, key, value):
+ """Convert plain dictionary to instance of this class."""
+ if not isinstance(value, cls):
+ if isinstance(value, dict):
+ return cls(value)
+ return Mutable.coerce(key, value)
+ else:
+ return value
+
+ def __getstate__(self):
+ return dict(self)
+
+ def __setstate__(self, state):
+ self.update(state)
+
+
+class MutableList(Mutable, list):
+ """A list type that implements :class:`.Mutable`.
+
+ The :class:`.MutableList` object implements a list that will
+ emit change events to the underlying mapping when the contents of
+ the list are altered, including when values are added or removed.
+
+ Note that :class:`.MutableList` does **not** apply mutable tracking to the
+ *values themselves* inside the list. Therefore it is not a sufficient
+ solution for the use case of tracking deep changes to a *recursive*
+ mutable structure, such as a JSON structure. To support this use case,
+ build a subclass of :class:`.MutableList` that provides appropriate
+ coercion to the values placed in the list so that they too are
+ "mutable", and emit events up to their parent structure.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :class:`.MutableDict`
+
+ :class:`.MutableSet`
+
+ """
+
+ def __reduce_ex__(self, proto):
+ return (self.__class__, (list(self),))
+
+ # needed for backwards compatibility with
+ # older pickles
+ def __setstate__(self, state):
+ self[:] = state
+
+ def __setitem__(self, index, value):
+ """Detect list set events and emit change events."""
+ list.__setitem__(self, index, value)
+ self.changed()
+
+ def __setslice__(self, start, end, value):
+ """Detect list set events and emit change events."""
+ list.__setslice__(self, start, end, value)
+ self.changed()
+
+ def __delitem__(self, index):
+ """Detect list del events and emit change events."""
+ list.__delitem__(self, index)
+ self.changed()
+
+ def __delslice__(self, start, end):
+ """Detect list del events and emit change events."""
+ list.__delslice__(self, start, end)
+ self.changed()
+
+ def pop(self, *arg):
+ result = list.pop(self, *arg)
+ self.changed()
+ return result
+
+ def append(self, x):
+ list.append(self, x)
+ self.changed()
+
+ def extend(self, x):
+ list.extend(self, x)
+ self.changed()
+
+ def __iadd__(self, x):
+ self.extend(x)
+ return self
+
+ def insert(self, i, x):
+ list.insert(self, i, x)
+ self.changed()
+
+ def remove(self, i):
+ list.remove(self, i)
+ self.changed()
+
+ def clear(self):
+ list.clear(self)
+ self.changed()
+
+ def sort(self, **kw):
+ list.sort(self, **kw)
+ self.changed()
+
+ def reverse(self):
+ list.reverse(self)
+ self.changed()
+
+ @classmethod
+ def coerce(cls, index, value):
+ """Convert plain list to instance of this class."""
+ if not isinstance(value, cls):
+ if isinstance(value, list):
+ return cls(value)
+ return Mutable.coerce(index, value)
+ else:
+ return value
+
+
+class MutableSet(Mutable, set):
+ """A set type that implements :class:`.Mutable`.
+
+ The :class:`.MutableSet` object implements a set that will
+ emit change events to the underlying mapping when the contents of
+ the set are altered, including when values are added or removed.
+
+ Note that :class:`.MutableSet` does **not** apply mutable tracking to the
+ *values themselves* inside the set. Therefore it is not a sufficient
+ solution for the use case of tracking deep changes to a *recursive*
+ mutable structure. To support this use case,
+ build a subclass of :class:`.MutableSet` that provides appropriate
+ coercion to the values placed in the set so that they too are
+ "mutable", and emit events up to their parent structure.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :class:`.MutableDict`
+
+ :class:`.MutableList`
+
+
+ """
+
+ def update(self, *arg):
+ set.update(self, *arg)
+ self.changed()
+
+ def intersection_update(self, *arg):
+ set.intersection_update(self, *arg)
+ self.changed()
+
+ def difference_update(self, *arg):
+ set.difference_update(self, *arg)
+ self.changed()
+
+ def symmetric_difference_update(self, *arg):
+ set.symmetric_difference_update(self, *arg)
+ self.changed()
+
+ def __ior__(self, other):
+ self.update(other)
+ return self
+
+ def __iand__(self, other):
+ self.intersection_update(other)
+ return self
+
+ def __ixor__(self, other):
+ self.symmetric_difference_update(other)
+ return self
+
+ def __isub__(self, other):
+ self.difference_update(other)
+ return self
+
+ def add(self, elem):
+ set.add(self, elem)
+ self.changed()
+
+ def remove(self, elem):
+ set.remove(self, elem)
+ self.changed()
+
+ def discard(self, elem):
+ set.discard(self, elem)
+ self.changed()
+
+ def pop(self, *arg):
+ result = set.pop(self, *arg)
+ self.changed()
+ return result
+
+ def clear(self):
+ set.clear(self)
+ self.changed()
+
+ @classmethod
+ def coerce(cls, index, value):
+ """Convert plain set to instance of this class."""
+ if not isinstance(value, cls):
+ if isinstance(value, set):
+ return cls(value)
+ return Mutable.coerce(index, value)
+ else:
+ return value
+
+ def __getstate__(self):
+ return set(self)
+
+ def __setstate__(self, state):
+ self.update(state)
+
+ def __reduce_ex__(self, proto):
+ return (self.__class__, (list(self),))
diff --git a/lib/sqlalchemy/ext/mypy/__init__.py b/lib/sqlalchemy/ext/mypy/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/__init__.py
diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py
new file mode 100644
index 0000000..99be194
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/apply.py
@@ -0,0 +1,299 @@
+# ext/mypy/apply.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from typing import List
+from typing import Optional
+from typing import Union
+
+from mypy.nodes import ARG_NAMED_OPT
+from mypy.nodes import Argument
+from mypy.nodes import AssignmentStmt
+from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
+from mypy.nodes import MDEF
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import RefExpr
+from mypy.nodes import StrExpr
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TempNode
+from mypy.nodes import TypeInfo
+from mypy.nodes import Var
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.plugins.common import add_method_to_class
+from mypy.types import AnyType
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import NoneTyp
+from mypy.types import ProperType
+from mypy.types import TypeOfAny
+from mypy.types import UnboundType
+from mypy.types import UnionType
+
+from . import infer
+from . import util
+from .names import NAMED_TYPE_SQLA_MAPPED
+
+
+def apply_mypy_mapped_attr(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ item: Union[NameExpr, StrExpr],
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
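+ """Apply typing to an attribute that was named in a class-level
+ ``_mypy_mapped_attrs`` sequence.
+
+ The attribute is located as an assignment statement in the class body;
+ the statement must carry an explicit type annotation, which is recorded
+ and re-applied to the statement as ``Mapped[<type>]``.
+
+ """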
+ if isinstance(item, NameExpr):
+ name = item.name
+ elif isinstance(item, StrExpr):
+ name = item.value
+ else:
+ return None
+
+ for stmt in cls.defs.body:
+ if (
+ isinstance(stmt, AssignmentStmt)
+ and isinstance(stmt.lvalues[0], NameExpr)
+ and stmt.lvalues[0].name == name
+ ):
+ break
+ else:
+ util.fail(api, "Can't find mapped attribute {}".format(name), cls)
+ return None
+
+ if stmt.type is None:
+ util.fail(
+ api,
+ "Statement linked from _mypy_mapped_attrs has no "
+ "typing information",
+ stmt,
+ )
+ return None
+
+ left_hand_explicit_type = get_proper_type(stmt.type)
+ assert isinstance(
+ left_hand_explicit_type, (Instance, UnionType, UnboundType)
+ )
+
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=name,
+ line=item.line,
+ column=item.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
+ )
+
+ apply_type_to_mapped_statement(
+ api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
+ )
+
+
+def re_apply_declarative_assignments(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """For multiple class passes, re-apply our left-hand side types as mypy
+ seems to reset them in place.
+
+ """
+ mapped_attr_lookup = {attr.name: attr for attr in attributes}
+ update_cls_metadata = False
+
+ for stmt in cls.defs.body:
+ # for a re-apply, all of our statements are AssignmentStmt;
+ # @declared_attr calls will have been converted and this
+ # currently seems to be preserved by mypy (but who knows if this
+ # will change).
+ if (
+ isinstance(stmt, AssignmentStmt)
+ and isinstance(stmt.lvalues[0], NameExpr)
+ and stmt.lvalues[0].name in mapped_attr_lookup
+ and isinstance(stmt.lvalues[0].node, Var)
+ ):
+
+ left_node = stmt.lvalues[0].node
+ python_type_for_type = mapped_attr_lookup[
+ stmt.lvalues[0].name
+ ].type
+
+ left_node_proper_type = get_proper_type(left_node.type)
+
+ # if we have scanned an UnboundType and now there's a more
+ # specific type than UnboundType, call the re-scan so we
+ # can get that set up correctly
+ if (
+ isinstance(python_type_for_type, UnboundType)
+ and not isinstance(left_node_proper_type, UnboundType)
+ and (
+ isinstance(stmt.rvalue, CallExpr)
+ and isinstance(stmt.rvalue.callee, MemberExpr)
+ and isinstance(stmt.rvalue.callee.expr, NameExpr)
+ and stmt.rvalue.callee.expr.node is not None
+ and stmt.rvalue.callee.expr.node.fullname
+ == NAMED_TYPE_SQLA_MAPPED
+ and stmt.rvalue.callee.name == "_empty_constructor"
+ and isinstance(stmt.rvalue.args[0], CallExpr)
+ and isinstance(stmt.rvalue.args[0].callee, RefExpr)
+ )
+ ):
+
+ python_type_for_type = (
+ infer.infer_type_from_right_hand_nameexpr(
+ api,
+ stmt,
+ left_node,
+ left_node_proper_type,
+ stmt.rvalue.args[0].callee,
+ )
+ )
+
+ if python_type_for_type is None or isinstance(
+ python_type_for_type, UnboundType
+ ):
+ continue
+
+ # update the SQLAlchemyAttribute with the better information
+ mapped_attr_lookup[
+ stmt.lvalues[0].name
+ ].type = python_type_for_type
+
+ update_cls_metadata = True
+
+ if python_type_for_type is not None:
+ left_node.type = api.named_type(
+ NAMED_TYPE_SQLA_MAPPED, [python_type_for_type]
+ )
+
+ if update_cls_metadata:
+ util.set_mapped_attributes(cls.info, attributes)
+
+
+def apply_type_to_mapped_statement(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ lvalue: NameExpr,
+ left_hand_explicit_type: Optional[ProperType],
+ python_type_for_type: Optional[ProperType],
+) -> None:
+ """Apply the Mapped[<type>] annotation and right hand object to a
+ declarative assignment statement.
+
+ This converts a Python declarative class statement such as::
+
+ class User(Base):
+ # ...
+
+ attrname = Column(Integer)
+
+ To one that describes the final Python behavior to Mypy::
+
+ class User(Base):
+ # ...
+
+ attrname : Mapped[Optional[int]] = <meaningless temp node>
+
+ """
+ left_node = lvalue.node
+ assert isinstance(left_node, Var)
+
+ if left_hand_explicit_type is not None:
+ left_node.type = api.named_type(
+ NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
+ )
+ else:
+ lvalue.is_inferred_def = False
+ left_node.type = api.named_type(
+ NAMED_TYPE_SQLA_MAPPED,
+ [] if python_type_for_type is None else [python_type_for_type],
+ )
+
+ # so to have it skip the right side totally, we can do this:
+ # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
+
+ # however, if we instead manufacture a new node that uses the old
+ # one, then we can still get type checking for the call itself,
+ # e.g. the Column, relationship() call, etc.
+
+ # rewrite the node as:
+ # <attr> : Mapped[<typ>] =
+ # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
+ # the original right-hand side is maintained so it gets type checked
+ # internally
+ stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue)
+
+
+def add_additional_orm_attributes(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """Apply __init__, __table__ and other attributes to the mapped class."""
+
+ info = util.info_for_cls(cls, api)
+
+ if info is None:
+ return
+
+ is_base = util.get_is_base(info)
+
+ if "__init__" not in info.names and not is_base:
+ mapped_attr_names = {attr.name: attr.type for attr in attributes}
+
+ for base in info.mro[1:-1]:
+ if "sqlalchemy" not in info.metadata:
+ continue
+
+ base_cls_attributes = util.get_mapped_attributes(base, api)
+ if base_cls_attributes is None:
+ continue
+
+ for attr in base_cls_attributes:
+ mapped_attr_names.setdefault(attr.name, attr.type)
+
+ arguments = []
+ for name, typ in mapped_attr_names.items():
+ if typ is None:
+ typ = AnyType(TypeOfAny.special_form)
+ arguments.append(
+ Argument(
+ variable=Var(name, typ),
+ type_annotation=typ,
+ initializer=TempNode(typ),
+ kind=ARG_NAMED_OPT,
+ )
+ )
+
+ add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
+
+ if "__table__" not in info.names and util.get_has_table(info):
+ _apply_placeholder_attr_to_class(
+ api, cls, "sqlalchemy.sql.schema.Table", "__table__"
+ )
+ if not is_base:
+ _apply_placeholder_attr_to_class(
+ api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
+ )
+
+
+def _apply_placeholder_attr_to_class(
+ api: SemanticAnalyzerPluginInterface,
+ cls: ClassDef,
+ qualified_name: str,
+ attrname: str,
+) -> None:
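+ """Add ``attrname`` to the class as a variable of the given
+ fully-qualified type, falling back to ``Any`` if that type cannot
+ be resolved.
+ """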
+ sym = api.lookup_fully_qualified_or_none(qualified_name)
+ if sym:
+ assert isinstance(sym.node, TypeInfo)
+ type_: ProperType = Instance(sym.node, [])
+ else:
+ type_ = AnyType(TypeOfAny.special_form)
+ var = Var(attrname)
+ var._fullname = cls.fullname + "." + attrname
+ var.info = cls.info
+ var.type = type_
+ cls.info.names[attrname] = SymbolTableNode(MDEF, var)
diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py
new file mode 100644
index 0000000..c33c30e
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/decl_class.py
@@ -0,0 +1,516 @@
+# ext/mypy/decl_class.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from typing import List
+from typing import Optional
+from typing import Union
+
+from mypy.nodes import AssignmentStmt
+from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
+from mypy.nodes import Decorator
+from mypy.nodes import LambdaExpr
+from mypy.nodes import ListExpr
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import PlaceholderNode
+from mypy.nodes import RefExpr
+from mypy.nodes import StrExpr
+from mypy.nodes import SymbolNode
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TempNode
+from mypy.nodes import TypeInfo
+from mypy.nodes import Var
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import AnyType
+from mypy.types import CallableType
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import NoneType
+from mypy.types import ProperType
+from mypy.types import Type
+from mypy.types import TypeOfAny
+from mypy.types import UnboundType
+from mypy.types import UnionType
+
+from . import apply
+from . import infer
+from . import names
+from . import util
+
+
+def scan_declarative_assignments_and_apply_types(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ is_mixin_scan: bool = False,
+) -> Optional[List[util.SQLAlchemyAttribute]]:
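+ """Scan a mapped class or mixin for ORM-mapped attributes, applying
+ ``Mapped[]`` typing to each attribute found, and return the collection
+ of attributes.
+
+ Returns None when no usable TypeInfo is available (e.g. during cached
+ passes) or when the class is a builtin.
+
+ """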
+
+ info = util.info_for_cls(cls, api)
+
+ if info is None:
+ # this can occur during cached passes
+ return None
+ elif cls.fullname.startswith("builtins"):
+ return None
+
+ mapped_attributes: Optional[
+ List[util.SQLAlchemyAttribute]
+ ] = util.get_mapped_attributes(info, api)
+
+ # used by assign.add_additional_orm_attributes among others
+ util.establish_as_sqlalchemy(info)
+
+ if mapped_attributes is not None:
+ # ensure that a class that's mapped is always picked up by
+ # its mapped() decorator or declarative metaclass before
+ # it would be detected as an unmapped mixin class
+
+ if not is_mixin_scan:
+ # mypy can call us more than once. it then *may* have reset the
+ # left hand side of everything, but not the right that we removed,
+ # removing our ability to re-scan. but we have the types
+ # here, so lets re-apply them, or if we have an UnboundType,
+ # we can re-scan
+
+ apply.re_apply_declarative_assignments(cls, api, mapped_attributes)
+
+ return mapped_attributes
+
+ mapped_attributes = []
+
+ if not cls.defs.body:
+ # when we get a mixin class from another file, the body is
+ # empty (!) but the names are in the symbol table. so use that.
+
+ for sym_name, sym in info.names.items():
+ _scan_symbol_table_entry(
+ cls, api, sym_name, sym, mapped_attributes
+ )
+ else:
+ for stmt in util.flatten_typechecking(cls.defs.body):
+ if isinstance(stmt, AssignmentStmt):
+ _scan_declarative_assignment_stmt(
+ cls, api, stmt, mapped_attributes
+ )
+ elif isinstance(stmt, Decorator):
+ _scan_declarative_decorator_stmt(
+ cls, api, stmt, mapped_attributes
+ )
+ _scan_for_mapped_bases(cls, api)
+
+ if not is_mixin_scan:
+ apply.add_additional_orm_attributes(cls, api, mapped_attributes)
+
+ util.set_mapped_attributes(info, mapped_attributes)
+
+ return mapped_attributes
+
+
+def _scan_symbol_table_entry(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ name: str,
+ value: SymbolTableNode,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """Extract mapping information from a SymbolTableNode that's in the
+ type.names dictionary.
+
+ """
+ value_type = get_proper_type(value.type)
+ if not isinstance(value_type, Instance):
+ return
+
+ left_hand_explicit_type = None
+ type_id = names.type_id_for_named_node(value_type.type)
+ # type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
+
+ err = False
+
+ # TODO: this is nearly the same logic as that of
+ # _scan_declarative_decorator_stmt, likely can be merged
+ if type_id in {
+ names.MAPPED,
+ names.RELATIONSHIP,
+ names.COMPOSITE_PROPERTY,
+ names.MAPPER_PROPERTY,
+ names.SYNONYM_PROPERTY,
+ names.COLUMN_PROPERTY,
+ }:
+ if value_type.args:
+ left_hand_explicit_type = get_proper_type(value_type.args[0])
+ else:
+ err = True
+ elif type_id is names.COLUMN:
+ if not value_type.args:
+ err = True
+ else:
+ typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type(
+ value_type.args[0]
+ )
+ if isinstance(typeengine_arg, Instance):
+ typeengine_arg = typeengine_arg.type
+
+ if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
+ sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
+ if sym is not None and isinstance(sym.node, TypeInfo):
+ if names.has_base_type_id(sym.node, names.TYPEENGINE):
+
+ left_hand_explicit_type = UnionType(
+ [
+ infer.extract_python_type_from_typeengine(
+ api, sym.node, []
+ ),
+ NoneType(),
+ ]
+ )
+ else:
+ util.fail(
+ api,
+ "Column type should be a TypeEngine "
+ "subclass not '{}'".format(sym.node.fullname),
+ value_type,
+ )
+
+ if err:
+ msg = (
+ "Can't infer type from attribute {} on class {}. "
+ "please specify a return type from this function that is "
+ "one of: Mapped[<python type>], relationship[<target class>], "
+ "Column[<TypeEngine>], MapperProperty[<python type>]"
+ )
+ util.fail(api, msg.format(name, cls.name), cls)
+
+ left_hand_explicit_type = AnyType(TypeOfAny.special_form)
+
+ if left_hand_explicit_type is not None:
+ assert value.node is not None
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=name,
+ line=value.node.line,
+ column=value.node.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
+ )
+
+
+def _scan_declarative_decorator_stmt(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ stmt: Decorator,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """Extract mapping information from a @declared_attr in a declarative
+ class.
+
+ E.g.::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ @declared_attr
+ def updated_at(cls) -> Column[DateTime]:
+ return Column(DateTime)
+
+ Will resolve in mypy as::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ updated_at: Mapped[Optional[datetime.datetime]]
+
+ """
+ for dec in stmt.decorators:
+ if (
+ isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
+ and names.type_id_for_named_node(dec) is names.DECLARED_ATTR
+ ):
+ break
+ else:
+ return
+
+ dec_index = cls.defs.body.index(stmt)
+
+ left_hand_explicit_type: Optional[ProperType] = None
+
+ if util.name_is_dunder(stmt.name):
+ # for dunder names like __table_args__, __tablename__,
+ # __mapper_args__ etc., rewrite these as simple assignment
+ # statements; otherwise mypy doesn't like if the decorated
+ # function has an annotation like ``cls: Type[Foo]`` because
+ # it isn't @classmethod
+ any_ = AnyType(TypeOfAny.special_form)
+ left_node = NameExpr(stmt.var.name)
+ left_node.node = stmt.var
+ new_stmt = AssignmentStmt([left_node], TempNode(any_))
+ new_stmt.type = left_node.node.type
+ cls.defs.body[dec_index] = new_stmt
+ return
+ elif isinstance(stmt.func.type, CallableType):
+ func_type = stmt.func.type.ret_type
+ if isinstance(func_type, UnboundType):
+ type_id = names.type_id_for_unbound_type(func_type, cls, api)
+ else:
+ # this does not seem to occur unless the type argument is
+ # incorrect
+ return
+
+ if (
+ type_id
+ in {
+ names.MAPPED,
+ names.RELATIONSHIP,
+ names.COMPOSITE_PROPERTY,
+ names.MAPPER_PROPERTY,
+ names.SYNONYM_PROPERTY,
+ names.COLUMN_PROPERTY,
+ }
+ and func_type.args
+ ):
+ left_hand_explicit_type = get_proper_type(func_type.args[0])
+ elif type_id is names.COLUMN and func_type.args:
+ typeengine_arg = func_type.args[0]
+ if isinstance(typeengine_arg, UnboundType):
+ sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
+ if sym is not None and isinstance(sym.node, TypeInfo):
+ if names.has_base_type_id(sym.node, names.TYPEENGINE):
+ left_hand_explicit_type = UnionType(
+ [
+ infer.extract_python_type_from_typeengine(
+ api, sym.node, []
+ ),
+ NoneType(),
+ ]
+ )
+ else:
+ util.fail(
+ api,
+ "Column type should be a TypeEngine "
+ "subclass not '{}'".format(sym.node.fullname),
+ func_type,
+ )
+
+ if left_hand_explicit_type is None:
+ # no type on the decorated function. our option here is to
+ # dig into the function body and get the return type, but they
+ # should just have an annotation.
+ msg = (
+ "Can't infer type from @declared_attr on function '{}'; "
+ "please specify a return type from this function that is "
+ "one of: Mapped[<python type>], relationship[<target class>], "
+ "Column[<TypeEngine>], MapperProperty[<python type>]"
+ )
+ util.fail(api, msg.format(stmt.var.name), stmt)
+
+ left_hand_explicit_type = AnyType(TypeOfAny.special_form)
+
+ left_node = NameExpr(stmt.var.name)
+ left_node.node = stmt.var
+
+ # totally feeling around in the dark here as I don't totally understand
+ # the significance of UnboundType. It seems to be something that is
+ # not going to do what's expected when it is applied as the type of
+ # an AssignmentStatement. So do a feeling-around-in-the-dark version
+ # of converting it to the regular Instance/TypeInfo/UnionType structures
+ # we see everywhere else.
+ if isinstance(left_hand_explicit_type, UnboundType):
+ left_hand_explicit_type = get_proper_type(
+ util.unbound_to_instance(api, left_hand_explicit_type)
+ )
+
+ left_node.node.type = api.named_type(
+ names.NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
+ )
+
+ # this will ignore the rvalue entirely
+ # rvalue = TempNode(AnyType(TypeOfAny.special_form))
+
+ # rewrite the node as:
+ # <attr> : Mapped[<typ>] =
+ # _sa_Mapped._empty_constructor(lambda: <function body>)
+ # the function body is maintained so it gets type checked internally
+ rvalue = util.expr_to_mapped_constructor(
+ LambdaExpr(stmt.func.arguments, stmt.func.body)
+ )
+
+ new_stmt = AssignmentStmt([left_node], rvalue)
+ new_stmt.type = left_node.node.type
+
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=left_node.name,
+ line=stmt.line,
+ column=stmt.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
+ )
+ cls.defs.body[dec_index] = new_stmt
+
+
+def _scan_declarative_assignment_stmt(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """Extract mapping information from an assignment statement in a
+ declarative class.
+
+ """
+ lvalue = stmt.lvalues[0]
+ if not isinstance(lvalue, NameExpr):
+ return
+
+ sym = cls.info.names.get(lvalue.name)
+
+ # this establishes that semantic analysis has taken place, which
+ # means the nodes are populated and we are called from an appropriate
+ # hook.
+ assert sym is not None
+ node = sym.node
+
+ if isinstance(node, PlaceholderNode):
+ return
+
+ assert node is lvalue.node
+ assert isinstance(node, Var)
+
+ if node.name == "__abstract__":
+ if api.parse_bool(stmt.rvalue) is True:
+ util.set_is_base(cls.info)
+ return
+ elif node.name == "__tablename__":
+ util.set_has_table(cls.info)
+ elif node.name.startswith("__"):
+ return
+ elif node.name == "_mypy_mapped_attrs":
+ if not isinstance(stmt.rvalue, ListExpr):
+ util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt)
+ else:
+ for item in stmt.rvalue.items:
+ if isinstance(item, (NameExpr, StrExpr)):
+ apply.apply_mypy_mapped_attr(cls, api, item, attributes)
+
+ left_hand_mapped_type: Optional[Type] = None
+ left_hand_explicit_type: Optional[ProperType] = None
+
+ if node.is_inferred or node.type is None:
+ if isinstance(stmt.type, UnboundType):
+ # look for an explicit Mapped[] type annotation on the left
+ # side with nothing on the right
+
+ # print(stmt.type)
+ # Mapped?[Optional?[A?]]
+
+ left_hand_explicit_type = stmt.type
+
+ if stmt.type.name == "Mapped":
+ mapped_sym = api.lookup_qualified("Mapped", cls)
+ if (
+ mapped_sym is not None
+ and mapped_sym.node is not None
+ and names.type_id_for_named_node(mapped_sym.node)
+ is names.MAPPED
+ ):
+ left_hand_explicit_type = get_proper_type(
+ stmt.type.args[0]
+ )
+ left_hand_mapped_type = stmt.type
+
+ # TODO: do we need to convert from unbound for this case?
+ # left_hand_explicit_type = util._unbound_to_instance(
+ # api, left_hand_explicit_type
+ # )
+ else:
+ node_type = get_proper_type(node.type)
+ if (
+ isinstance(node_type, Instance)
+ and names.type_id_for_named_node(node_type.type) is names.MAPPED
+ ):
+ # print(node.type)
+ # sqlalchemy.orm.attributes.Mapped[<python type>]
+ left_hand_explicit_type = get_proper_type(node_type.args[0])
+ left_hand_mapped_type = node_type
+ else:
+ # print(node.type)
+ # <python type>
+ left_hand_explicit_type = node_type
+ left_hand_mapped_type = None
+
+ if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None:
+ # annotation without assignment and Mapped is present
+ # as type annotation
+ # equivalent to using _infer_type_from_left_hand_type_only.
+
+ python_type_for_type = left_hand_explicit_type
+ elif isinstance(stmt.rvalue, CallExpr) and isinstance(
+ stmt.rvalue.callee, RefExpr
+ ):
+
+ python_type_for_type = infer.infer_type_from_right_hand_nameexpr(
+ api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
+ )
+
+ if python_type_for_type is None:
+ return
+
+ else:
+ return
+
+ assert python_type_for_type is not None
+
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=node.name,
+ line=stmt.line,
+ column=stmt.column,
+ typ=python_type_for_type,
+ info=cls.info,
+ )
+ )
+
+ apply.apply_type_to_mapped_statement(
+ api,
+ stmt,
+ lvalue,
+ left_hand_explicit_type,
+ python_type_for_type,
+ )
+
+
+def _scan_for_mapped_bases(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+) -> None:
+ """Given a class, iterate through its superclass hierarchy to find
+ all other classes that are considered as ORM-significant.
+
+ Locates non-mapped mixins and scans them for mapped attributes to be
+ applied to subclasses.
+
+ """
+
+ info = util.info_for_cls(cls, api)
+
+ if info is None:
+ return
+
+ for base_info in info.mro[1:-1]:
+ if base_info.fullname.startswith("builtins"):
+ continue
+
+ # scan each base for mapped attributes. if they are not already
+ # scanned (but have all their type info), that means they are unmapped
+ # mixins
+ scan_declarative_assignments_and_apply_types(
+ base_info.defn, api, is_mixin_scan=True
+ )
diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py
new file mode 100644
index 0000000..f88a960
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/infer.py
@@ -0,0 +1,556 @@
+# ext/mypy/infer.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from typing import Optional
+from typing import Sequence
+
+from mypy.maptype import map_instance_to_supertype
+from mypy.messages import format_type
+from mypy.nodes import AssignmentStmt
+from mypy.nodes import CallExpr
+from mypy.nodes import Expression
+from mypy.nodes import FuncDef
+from mypy.nodes import LambdaExpr
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import RefExpr
+from mypy.nodes import StrExpr
+from mypy.nodes import TypeInfo
+from mypy.nodes import Var
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.subtypes import is_subtype
+from mypy.types import AnyType
+from mypy.types import CallableType
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import NoneType
+from mypy.types import ProperType
+from mypy.types import TypeOfAny
+from mypy.types import UnionType
+
+from . import names
+from . import util
+
+
+def infer_type_from_right_hand_nameexpr(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+ infer_from_right_side: RefExpr,
+) -> Optional[ProperType]:
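+ """Dispatch to a specific type inference routine based on the callee
+ of the right-hand side expression.
+
+ Handles ``Column()``, ``relationship()``, ``column_property()`` /
+ ``deferred()``, ``synonym()`` and ``composite()``; returns None for
+ any other callee.
+
+ """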
+
+ type_id = names.type_id_for_callee(infer_from_right_side)
+
+ if type_id is None:
+ return None
+ elif type_id is names.COLUMN:
+ python_type_for_type = _infer_type_from_decl_column(
+ api, stmt, node, left_hand_explicit_type
+ )
+ elif type_id is names.RELATIONSHIP:
+ python_type_for_type = _infer_type_from_relationship(
+ api, stmt, node, left_hand_explicit_type
+ )
+ elif type_id is names.COLUMN_PROPERTY:
+ python_type_for_type = _infer_type_from_decl_column_property(
+ api, stmt, node, left_hand_explicit_type
+ )
+ elif type_id is names.SYNONYM_PROPERTY:
+ python_type_for_type = infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+ elif type_id is names.COMPOSITE_PROPERTY:
+ python_type_for_type = _infer_type_from_decl_composite_property(
+ api, stmt, node, left_hand_explicit_type
+ )
+ else:
+ return None
+
+ return python_type_for_type
+
+
+def _infer_type_from_relationship(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
+ """Infer the type of mapping from a relationship.
+
+ E.g.::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ addresses = relationship(Address, uselist=True)
+
+ order: Mapped["Order"] = relationship("Order")
+
+ Will resolve in mypy as::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ addresses: Mapped[List[Address]]
+
+ order: Mapped["Order"]
+
+ """
+
+ assert isinstance(stmt.rvalue, CallExpr)
+ target_cls_arg = stmt.rvalue.args[0]
+ python_type_for_type: Optional[ProperType] = None
+
+ if isinstance(target_cls_arg, NameExpr) and isinstance(
+ target_cls_arg.node, TypeInfo
+ ):
+ # type
+ related_object_type = target_cls_arg.node
+ python_type_for_type = Instance(related_object_type, [])
+
+ # other cases not covered - an error message directs the user
+ # to set an explicit type annotation
+ #
+ # node.type == str, it's a string
+ # if isinstance(target_cls_arg, NameExpr) and isinstance(
+ # target_cls_arg.node, Var
+ # )
+ # points to a type
+ # isinstance(target_cls_arg, NameExpr) and isinstance(
+ # target_cls_arg.node, TypeAlias
+ # )
+ # string expression
+ # isinstance(target_cls_arg, StrExpr)
+
+ uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist")
+ collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg(
+ stmt.rvalue, "collection_class"
+ )
+ type_is_a_collection = False
+
+ # this can be used to determine Optional for a many-to-one
+ # in the same way nullable=False could be used, if we start supporting
+ # that.
+ # innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin")
+
+ if (
+ uselist_arg is not None
+ and api.parse_bool(uselist_arg) is True
+ and collection_cls_arg is None
+ ):
+ type_is_a_collection = True
+ if python_type_for_type is not None:
+ python_type_for_type = api.named_type(
+ names.NAMED_TYPE_BUILTINS_LIST, [python_type_for_type]
+ )
+ elif (
+ uselist_arg is None or api.parse_bool(uselist_arg) is True
+ ) and collection_cls_arg is not None:
+ type_is_a_collection = True
+ if isinstance(collection_cls_arg, CallExpr):
+ collection_cls_arg = collection_cls_arg.callee
+
+ if isinstance(collection_cls_arg, NameExpr) and isinstance(
+ collection_cls_arg.node, TypeInfo
+ ):
+ if python_type_for_type is not None:
+ # this can still be overridden by the left hand side
+ # within _infer_type_from_left_and_inferred_right
+ python_type_for_type = Instance(
+ collection_cls_arg.node, [python_type_for_type]
+ )
+ elif (
+ isinstance(collection_cls_arg, NameExpr)
+ and isinstance(collection_cls_arg.node, FuncDef)
+ and collection_cls_arg.node.type is not None
+ ):
+ if python_type_for_type is not None:
+ # this can still be overridden by the left hand side
+ # within _infer_type_from_left_and_inferred_right
+
+ # TODO: handle mypy.types.Overloaded
+ if isinstance(collection_cls_arg.node.type, CallableType):
+ rt = get_proper_type(collection_cls_arg.node.type.ret_type)
+
+ if isinstance(rt, CallableType):
+ callable_ret_type = get_proper_type(rt.ret_type)
+ if isinstance(callable_ret_type, Instance):
+ python_type_for_type = Instance(
+ callable_ret_type.type,
+ [python_type_for_type],
+ )
+ else:
+ util.fail(
+ api,
+ "Expected Python collection type for "
+ "collection_class parameter",
+ stmt.rvalue,
+ )
+ python_type_for_type = None
+ elif uselist_arg is not None and api.parse_bool(uselist_arg) is False:
+ if collection_cls_arg is not None:
+ util.fail(
+ api,
+ "Sending uselist=False and collection_class at the same time "
+ "does not make sense",
+ stmt.rvalue,
+ )
+ if python_type_for_type is not None:
+ python_type_for_type = UnionType(
+ [python_type_for_type, NoneType()]
+ )
+
+ else:
+ if left_hand_explicit_type is None:
+ msg = (
+ "Can't infer scalar or collection for ORM mapped expression "
+ "assigned to attribute '{}' if both 'uselist' and "
+ "'collection_class' arguments are absent from the "
+ "relationship(); please specify a "
+ "type annotation on the left hand side."
+ )
+ util.fail(api, msg.format(node.name), node)
+
+ if python_type_for_type is None:
+ return infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+ elif left_hand_explicit_type is not None:
+ if type_is_a_collection:
+ assert isinstance(left_hand_explicit_type, Instance)
+ assert isinstance(python_type_for_type, Instance)
+ return _infer_collection_type_from_left_and_inferred_right(
+ api, node, left_hand_explicit_type, python_type_for_type
+ )
+ else:
+ return _infer_type_from_left_and_inferred_right(
+ api,
+ node,
+ left_hand_explicit_type,
+ python_type_for_type,
+ )
+ else:
+ return python_type_for_type
+
+
+def _infer_type_from_decl_composite_property(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
+ """Infer the type of mapping from a CompositeProperty."""
+
+ assert isinstance(stmt.rvalue, CallExpr)
+ target_cls_arg = stmt.rvalue.args[0]
+ python_type_for_type = None
+
+ if isinstance(target_cls_arg, NameExpr) and isinstance(
+ target_cls_arg.node, TypeInfo
+ ):
+ related_object_type = target_cls_arg.node
+ python_type_for_type = Instance(related_object_type, [])
+ else:
+ python_type_for_type = None
+
+ if python_type_for_type is None:
+ return infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+ elif left_hand_explicit_type is not None:
+ return _infer_type_from_left_and_inferred_right(
+ api, node, left_hand_explicit_type, python_type_for_type
+ )
+ else:
+ return python_type_for_type
+
+
+def _infer_type_from_decl_column_property(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
+ """Infer the type of mapping from a ColumnProperty.
+
+ This includes mappings against ``column_property()`` as well as the
+ ``deferred()`` function.
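+
+    E.g. (an illustrative mapping; the attribute and column names here are
+    hypothetical)::
+
+        @reg.mapped
+        class MyClass:
+            # ...
+
+            summary = column_property(Column("summary", String))
+
+            data = deferred(Column("data", Text))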
+
+ """
+ assert isinstance(stmt.rvalue, CallExpr)
+
+ if stmt.rvalue.args:
+ first_prop_arg = stmt.rvalue.args[0]
+
+ if isinstance(first_prop_arg, CallExpr):
+ type_id = names.type_id_for_callee(first_prop_arg.callee)
+
+ # look for column_property() / deferred() etc with Column as first
+ # argument
+ if type_id is names.COLUMN:
+ return _infer_type_from_decl_column(
+ api,
+ stmt,
+ node,
+ left_hand_explicit_type,
+ right_hand_expression=first_prop_arg,
+ )
+
+ if isinstance(stmt.rvalue, CallExpr):
+ type_id = names.type_id_for_callee(stmt.rvalue.callee)
+ # this is probably not strictly necessary as we have to use the left
+ # hand type for query expression in any case. any other no-arg
+ # column prop objects would go here also
+ if type_id is names.QUERY_EXPRESSION:
+ return _infer_type_from_decl_column(
+ api,
+ stmt,
+ node,
+ left_hand_explicit_type,
+ )
+
+ return infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+
+
+def _infer_type_from_decl_column(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+ right_hand_expression: Optional[CallExpr] = None,
+) -> Optional[ProperType]:
+ """Infer the type of mapping from a Column.
+
+ E.g.::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ a = Column(Integer)
+
+ b = Column("b", String)
+
+ c: Mapped[int] = Column(Integer)
+
+ d: bool = Column(Boolean)
+
+    Will resolve in mypy as::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+        a: Mapped[int]
+
+        b: Mapped[str]
+
+ c: Mapped[int]
+
+ d: Mapped[bool]
+
+ """
+ assert isinstance(node, Var)
+
+ callee = None
+
+ if right_hand_expression is None:
+ if not isinstance(stmt.rvalue, CallExpr):
+ return None
+
+ right_hand_expression = stmt.rvalue
+
+ for column_arg in right_hand_expression.args[0:2]:
+ if isinstance(column_arg, CallExpr):
+ if isinstance(column_arg.callee, RefExpr):
+ # x = Column(String(50))
+ callee = column_arg.callee
+ type_args: Sequence[Expression] = column_arg.args
+ break
+ elif isinstance(column_arg, (NameExpr, MemberExpr)):
+ if isinstance(column_arg.node, TypeInfo):
+ # x = Column(String)
+ callee = column_arg
+ type_args = ()
+ break
+ else:
+ # x = Column(some_name, String), go to next argument
+ continue
+ elif isinstance(column_arg, (StrExpr,)):
+ # x = Column("name", String), go to next argument
+ continue
+ elif isinstance(column_arg, (LambdaExpr,)):
+ # x = Column("name", String, default=lambda: uuid.uuid4())
+ # go to next argument
+ continue
+ else:
+ assert False
+
+ if callee is None:
+ return None
+
+ if isinstance(callee.node, TypeInfo) and names.mro_has_id(
+ callee.node.mro, names.TYPEENGINE
+ ):
+ python_type_for_type = extract_python_type_from_typeengine(
+ api, callee.node, type_args
+ )
+
+ if left_hand_explicit_type is not None:
+
+ return _infer_type_from_left_and_inferred_right(
+ api, node, left_hand_explicit_type, python_type_for_type
+ )
+
+ else:
+ return UnionType([python_type_for_type, NoneType()])
+ else:
+ # it's not TypeEngine, it's typically implicitly typed
+ # like ForeignKey. we can't infer from the right side.
+ return infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+
+
+def _infer_type_from_left_and_inferred_right(
+ api: SemanticAnalyzerPluginInterface,
+ node: Var,
+ left_hand_explicit_type: ProperType,
+ python_type_for_type: ProperType,
+ orig_left_hand_type: Optional[ProperType] = None,
+ orig_python_type_for_type: Optional[ProperType] = None,
+) -> Optional[ProperType]:
+ """Validate type when a left hand annotation is present and we also
+ could infer the right hand side::
+
+ attrname: SomeType = Column(SomeDBType)
+
+ """
+
+ if orig_left_hand_type is None:
+ orig_left_hand_type = left_hand_explicit_type
+ if orig_python_type_for_type is None:
+ orig_python_type_for_type = python_type_for_type
+
+ if not is_subtype(left_hand_explicit_type, python_type_for_type):
+ effective_type = api.named_type(
+ names.NAMED_TYPE_SQLA_MAPPED, [orig_python_type_for_type]
+ )
+
+ msg = (
+ "Left hand assignment '{}: {}' not compatible "
+ "with ORM mapped expression of type {}"
+ )
+ util.fail(
+ api,
+ msg.format(
+ node.name,
+ format_type(orig_left_hand_type),
+ format_type(effective_type),
+ ),
+ node,
+ )
+
+ return orig_left_hand_type
+
+
+def _infer_collection_type_from_left_and_inferred_right(
+ api: SemanticAnalyzerPluginInterface,
+ node: Var,
+ left_hand_explicit_type: Instance,
+ python_type_for_type: Instance,
+) -> Optional[ProperType]:
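+    """Validate a collection type on the left hand side, e.g.
+    ``List["Address"]``, against the collection type inferred from the
+    right hand ``relationship()``, comparing the element type of each."""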
+ orig_left_hand_type = left_hand_explicit_type
+ orig_python_type_for_type = python_type_for_type
+
+ if left_hand_explicit_type.args:
+ left_hand_arg = get_proper_type(left_hand_explicit_type.args[0])
+ python_type_arg = get_proper_type(python_type_for_type.args[0])
+ else:
+ left_hand_arg = left_hand_explicit_type
+ python_type_arg = python_type_for_type
+
+ assert isinstance(left_hand_arg, (Instance, UnionType))
+ assert isinstance(python_type_arg, (Instance, UnionType))
+
+ return _infer_type_from_left_and_inferred_right(
+ api,
+ node,
+ left_hand_arg,
+ python_type_arg,
+ orig_left_hand_type=orig_left_hand_type,
+ orig_python_type_for_type=orig_python_type_for_type,
+ )
+
+
+def infer_type_from_left_hand_type_only(
+ api: SemanticAnalyzerPluginInterface,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
+ """Determine the type based on explicit annotation only.
+
+    If no annotation is present, an error is emitted noting that one is
+    required in order to determine the type.
+
+ """
+ if left_hand_explicit_type is None:
+ msg = (
+ "Can't infer type from ORM mapped expression "
+ "assigned to attribute '{}'; please specify a "
+ "Python type or "
+ "Mapped[<python type>] on the left hand side."
+ )
+ util.fail(api, msg.format(node.name), node)
+
+ return api.named_type(
+ names.NAMED_TYPE_SQLA_MAPPED, [AnyType(TypeOfAny.special_form)]
+ )
+
+ else:
+ # use type from the left hand side
+ return left_hand_explicit_type
+
+
+def extract_python_type_from_typeengine(
+ api: SemanticAnalyzerPluginInterface,
+ node: TypeInfo,
+ type_args: Sequence[Expression],
+) -> ProperType:
+ if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
+ first_arg = type_args[0]
+ if isinstance(first_arg, RefExpr) and isinstance(
+ first_arg.node, TypeInfo
+ ):
+ for base_ in first_arg.node.mro:
+ if base_.fullname == "enum.Enum":
+ return Instance(first_arg.node, [])
+ # TODO: support other pep-435 types here
+ else:
+ return api.named_type(names.NAMED_TYPE_BUILTINS_STR, [])
+
+ assert node.has_base("sqlalchemy.sql.type_api.TypeEngine"), (
+ "could not extract Python type from node: %s" % node
+ )
+
+ type_engine_sym = api.lookup_fully_qualified_or_none(
+ "sqlalchemy.sql.type_api.TypeEngine"
+ )
+
+ assert type_engine_sym is not None and isinstance(
+ type_engine_sym.node, TypeInfo
+ )
+ type_engine = map_instance_to_supertype(
+ Instance(node, []),
+ type_engine_sym.node,
+ )
+ return get_proper_type(type_engine.args[-1])
diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py
new file mode 100644
index 0000000..8ec15a6
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/names.py
@@ -0,0 +1,253 @@
+# ext/mypy/names.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import Union
+
+from mypy.nodes import ClassDef
+from mypy.nodes import Expression
+from mypy.nodes import FuncDef
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import SymbolNode
+from mypy.nodes import TypeAlias
+from mypy.nodes import TypeInfo
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import CallableType
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import UnboundType
+
+from ... import util
+
+COLUMN: int = util.symbol("COLUMN") # type: ignore
+RELATIONSHIP: int = util.symbol("RELATIONSHIP") # type: ignore
+REGISTRY: int = util.symbol("REGISTRY") # type: ignore
+COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY") # type: ignore
+TYPEENGINE: int = util.symbol("TYPEENGNE") # type: ignore
+MAPPED: int = util.symbol("MAPPED") # type: ignore
+DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE") # type: ignore
+DECLARATIVE_META: int = util.symbol("DECLARATIVE_META") # type: ignore
+MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR") # type: ignore
+SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY") # type: ignore
+COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY") # type: ignore
+DECLARED_ATTR: int = util.symbol("DECLARED_ATTR") # type: ignore
+MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY") # type: ignore
+AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE") # type: ignore
+AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE") # type: ignore
+DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN") # type: ignore
+QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION") # type: ignore
+
+# names that must succeed with mypy.api.named_type
+NAMED_TYPE_BUILTINS_OBJECT = "builtins.object"
+NAMED_TYPE_BUILTINS_STR = "builtins.str"
+NAMED_TYPE_BUILTINS_LIST = "builtins.list"
+NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.attributes.Mapped"
+
+_lookup: Dict[str, Tuple[int, Set[str]]] = {
+ "Column": (
+ COLUMN,
+ {
+ "sqlalchemy.sql.schema.Column",
+ "sqlalchemy.sql.Column",
+ },
+ ),
+ "RelationshipProperty": (
+ RELATIONSHIP,
+ {
+ "sqlalchemy.orm.relationships.RelationshipProperty",
+ "sqlalchemy.orm.RelationshipProperty",
+ },
+ ),
+ "registry": (
+ REGISTRY,
+ {
+ "sqlalchemy.orm.decl_api.registry",
+ "sqlalchemy.orm.registry",
+ },
+ ),
+ "ColumnProperty": (
+ COLUMN_PROPERTY,
+ {
+ "sqlalchemy.orm.properties.ColumnProperty",
+ "sqlalchemy.orm.ColumnProperty",
+ },
+ ),
+ "SynonymProperty": (
+ SYNONYM_PROPERTY,
+ {
+ "sqlalchemy.orm.descriptor_props.SynonymProperty",
+ "sqlalchemy.orm.SynonymProperty",
+ },
+ ),
+ "CompositeProperty": (
+ COMPOSITE_PROPERTY,
+ {
+ "sqlalchemy.orm.descriptor_props.CompositeProperty",
+ "sqlalchemy.orm.CompositeProperty",
+ },
+ ),
+ "MapperProperty": (
+ MAPPER_PROPERTY,
+ {
+ "sqlalchemy.orm.interfaces.MapperProperty",
+ "sqlalchemy.orm.MapperProperty",
+ },
+ ),
+ "TypeEngine": (TYPEENGINE, {"sqlalchemy.sql.type_api.TypeEngine"}),
+ "Mapped": (MAPPED, {"sqlalchemy.orm.attributes.Mapped"}),
+ "declarative_base": (
+ DECLARATIVE_BASE,
+ {
+ "sqlalchemy.ext.declarative.declarative_base",
+ "sqlalchemy.orm.declarative_base",
+ "sqlalchemy.orm.decl_api.declarative_base",
+ },
+ ),
+ "DeclarativeMeta": (
+ DECLARATIVE_META,
+ {
+ "sqlalchemy.ext.declarative.DeclarativeMeta",
+ "sqlalchemy.orm.DeclarativeMeta",
+ "sqlalchemy.orm.decl_api.DeclarativeMeta",
+ },
+ ),
+ "mapped": (
+ MAPPED_DECORATOR,
+ {
+ "sqlalchemy.orm.decl_api.registry.mapped",
+ "sqlalchemy.orm.registry.mapped",
+ },
+ ),
+ "as_declarative": (
+ AS_DECLARATIVE,
+ {
+ "sqlalchemy.ext.declarative.as_declarative",
+ "sqlalchemy.orm.decl_api.as_declarative",
+ "sqlalchemy.orm.as_declarative",
+ },
+ ),
+ "as_declarative_base": (
+ AS_DECLARATIVE_BASE,
+ {
+ "sqlalchemy.orm.decl_api.registry.as_declarative_base",
+ "sqlalchemy.orm.registry.as_declarative_base",
+ },
+ ),
+ "declared_attr": (
+ DECLARED_ATTR,
+ {
+ "sqlalchemy.orm.decl_api.declared_attr",
+ "sqlalchemy.orm.declared_attr",
+ },
+ ),
+ "declarative_mixin": (
+ DECLARATIVE_MIXIN,
+ {
+ "sqlalchemy.orm.decl_api.declarative_mixin",
+ "sqlalchemy.orm.declarative_mixin",
+ },
+ ),
+ "query_expression": (
+ QUERY_EXPRESSION,
+ {"sqlalchemy.orm.query_expression"},
+ ),
+}
+
+
+def has_base_type_id(info: TypeInfo, type_id: int) -> bool:
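+    """Return True if the given class or any class in its MRO corresponds
+    to the given type id in the ``_lookup`` table."""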
+ for mr in info.mro:
+ check_type_id, fullnames = _lookup.get(mr.name, (None, None))
+ if check_type_id == type_id:
+ break
+ else:
+ return False
+
+ if fullnames is None:
+ return False
+
+ return mr.fullname in fullnames
+
+
+def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
+ for mr in mro:
+ check_type_id, fullnames = _lookup.get(mr.name, (None, None))
+ if check_type_id == type_id:
+ break
+ else:
+ return False
+
+ if fullnames is None:
+ return False
+
+ return mr.fullname in fullnames
+
+
+def type_id_for_unbound_type(
+ type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface
+) -> Optional[int]:
+ sym = api.lookup_qualified(type_.name, type_)
+ if sym is not None:
+ if isinstance(sym.node, TypeAlias):
+ target_type = get_proper_type(sym.node.target)
+ if isinstance(target_type, Instance):
+ return type_id_for_named_node(target_type.type)
+ elif isinstance(sym.node, TypeInfo):
+ return type_id_for_named_node(sym.node)
+
+ return None
+
+
+def type_id_for_callee(callee: Expression) -> Optional[int]:
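+    """Return the type id for the object that a call expression refers to,
+    resolving functions and type aliases to the class they produce."""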
+ if isinstance(callee, (MemberExpr, NameExpr)):
+ if isinstance(callee.node, FuncDef):
+ if callee.node.type and isinstance(callee.node.type, CallableType):
+ ret_type = get_proper_type(callee.node.type.ret_type)
+
+ if isinstance(ret_type, Instance):
+ return type_id_for_fullname(ret_type.type.fullname)
+
+ return None
+ elif isinstance(callee.node, TypeAlias):
+ target_type = get_proper_type(callee.node.target)
+ if isinstance(target_type, Instance):
+ return type_id_for_fullname(target_type.type.fullname)
+ elif isinstance(callee.node, TypeInfo):
+ return type_id_for_named_node(callee)
+ return None
+
+
+def type_id_for_named_node(
+ node: Union[NameExpr, MemberExpr, SymbolNode]
+) -> Optional[int]:
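+    """Return the type id for a name or member expression, based on its
+    ``name`` and ``fullname`` attributes."""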
+ type_id, fullnames = _lookup.get(node.name, (None, None))
+
+ if type_id is None or fullnames is None:
+ return None
+ elif node.fullname in fullnames:
+ return type_id
+ else:
+ return None
+
+
+def type_id_for_fullname(fullname: str) -> Optional[int]:
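+    """Return the type id for a fully qualified name such as
+    ``sqlalchemy.sql.schema.Column``."""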
+ tokens = fullname.split(".")
+ immediate = tokens[-1]
+
+ type_id, fullnames = _lookup.get(immediate, (None, None))
+
+ if type_id is None or fullnames is None:
+ return None
+ elif fullname in fullnames:
+ return type_id
+ else:
+ return None
diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py
new file mode 100644
index 0000000..8687012
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/plugin.py
@@ -0,0 +1,284 @@
+# ext/mypy/plugin.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+Mypy plugin for SQLAlchemy ORM.
+
+"""
+from typing import Callable
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Type as TypingType
+from typing import Union
+
+from mypy import nodes
+from mypy.mro import calculate_mro
+from mypy.mro import MroError
+from mypy.nodes import Block
+from mypy.nodes import ClassDef
+from mypy.nodes import GDEF
+from mypy.nodes import MypyFile
+from mypy.nodes import NameExpr
+from mypy.nodes import SymbolTable
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TypeInfo
+from mypy.plugin import AttributeContext
+from mypy.plugin import ClassDefContext
+from mypy.plugin import DynamicClassDefContext
+from mypy.plugin import Plugin
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import Type
+
+from . import decl_class
+from . import names
+from . import util
+
+
+class SQLAlchemyPlugin(Plugin):
+ def get_dynamic_class_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[DynamicClassDefContext], None]]:
+ if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
+ return _dynamic_class_hook
+ return None
+
+ def get_customize_class_mro_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[ClassDefContext], None]]:
+ return _fill_in_decorators
+
+ def get_class_decorator_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[ClassDefContext], None]]:
+
+ sym = self.lookup_fully_qualified(fullname)
+
+ if sym is not None and sym.node is not None:
+ type_id = names.type_id_for_named_node(sym.node)
+ if type_id is names.MAPPED_DECORATOR:
+ return _cls_decorator_hook
+ elif type_id in (
+ names.AS_DECLARATIVE,
+ names.AS_DECLARATIVE_BASE,
+ ):
+ return _base_cls_decorator_hook
+ elif type_id is names.DECLARATIVE_MIXIN:
+ return _declarative_mixin_hook
+
+ return None
+
+ def get_metaclass_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[ClassDefContext], None]]:
+ if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META:
+ # Set any classes that explicitly have metaclass=DeclarativeMeta
+ # as declarative so the check in `get_base_class_hook()` works
+ return _metaclass_cls_hook
+
+ return None
+
+ def get_base_class_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[ClassDefContext], None]]:
+ sym = self.lookup_fully_qualified(fullname)
+
+ if (
+ sym
+ and isinstance(sym.node, TypeInfo)
+ and util.has_declarative_base(sym.node)
+ ):
+ return _base_cls_hook
+
+ return None
+
+ def get_attribute_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[AttributeContext], Type]]:
+ if fullname.startswith(
+ "sqlalchemy.orm.attributes.QueryableAttribute."
+ ):
+ return _queryable_getattr_hook
+
+ return None
+
+ def get_additional_deps(
+ self, file: MypyFile
+ ) -> List[Tuple[int, str, int]]:
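+        # each entry is a (priority, module name, line number) tuple per the
+        # mypy plugin API; these modules are required so that lookups such
+        # as named_type() and _add_globals() against them succeed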
+ return [
+ (10, "sqlalchemy.orm.attributes", -1),
+ (10, "sqlalchemy.orm.decl_api", -1),
+ ]
+
+
+def plugin(version: str) -> TypingType[SQLAlchemyPlugin]:
+ return SQLAlchemyPlugin
+
+
+def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
+ """Generate a declarative Base class when the declarative_base() function
+ is encountered."""
+
+ _add_globals(ctx)
+
+ cls = ClassDef(ctx.name, Block([]))
+ cls.fullname = ctx.api.qualified_name(ctx.name)
+
+ info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
+ cls.info = info
+ _set_declarative_metaclass(ctx.api, cls)
+
+ cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
+ if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
+ util.set_is_base(cls_arg.node)
+ decl_class.scan_declarative_assignments_and_apply_types(
+ cls_arg.node.defn, ctx.api, is_mixin_scan=True
+ )
+ info.bases = [Instance(cls_arg.node, [])]
+ else:
+ obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
+
+ info.bases = [obj]
+
+ try:
+ calculate_mro(info)
+ except MroError:
+ util.fail(
+ ctx.api, "Not able to calculate MRO for declarative base", ctx.call
+ )
+ obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
+ info.bases = [obj]
+ info.fallback_to_any = True
+
+ ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
+ util.set_is_base(info)
+
+
+def _fill_in_decorators(ctx: ClassDefContext) -> None:
+ for decorator in ctx.cls.decorators:
+ # set the ".fullname" attribute of a class decorator
+ # that is a MemberExpr. This causes the logic in
+ # semanal.py->apply_class_plugin_hooks to invoke the
+        # get_class_decorator_hook for our "registry.mapped()"
+        # and "registry.as_declarative_base()" methods.
+        # This seems like a bug in mypy in that these decorators are
+        # otherwise skipped.
+
+ if (
+ isinstance(decorator, nodes.CallExpr)
+ and isinstance(decorator.callee, nodes.MemberExpr)
+ and decorator.callee.name == "as_declarative_base"
+ ):
+ target = decorator.callee
+ elif (
+ isinstance(decorator, nodes.MemberExpr)
+ and decorator.name == "mapped"
+ ):
+ target = decorator
+ else:
+ continue
+
+ assert isinstance(target.expr, NameExpr)
+ sym = ctx.api.lookup_qualified(
+ target.expr.name, target, suppress_errors=True
+ )
+ if sym and sym.node:
+ sym_type = get_proper_type(sym.type)
+ if isinstance(sym_type, Instance):
+ target.fullname = f"{sym_type.type.fullname}.{target.name}"
+ else:
+ # if the registry is in the same file as where the
+ # decorator is used, it might not have semantic
+ # symbols applied and we can't get a fully qualified
+ # name or an inferred type, so we are actually going to
+ # flag an error in this case that they need to annotate
+ # it. The "registry" is declared just
+ # once (or few times), so they have to just not use
+ # type inference for its assignment in this one case.
+ util.fail(
+ ctx.api,
+ "Class decorator called %s(), but we can't "
+ "tell if it's from an ORM registry. Please "
+ "annotate the registry assignment, e.g. "
+ "my_registry: registry = registry()" % target.name,
+ sym.node,
+ )
+
+
+def _cls_decorator_hook(ctx: ClassDefContext) -> None:
+ _add_globals(ctx)
+ assert isinstance(ctx.reason, nodes.MemberExpr)
+ expr = ctx.reason.expr
+
+ assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var)
+
+ node_type = get_proper_type(expr.node.type)
+
+ assert (
+ isinstance(node_type, Instance)
+ and names.type_id_for_named_node(node_type.type) is names.REGISTRY
+ )
+
+ decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+
+
+def _base_cls_decorator_hook(ctx: ClassDefContext) -> None:
+ _add_globals(ctx)
+
+ cls = ctx.cls
+
+ _set_declarative_metaclass(ctx.api, cls)
+
+ util.set_is_base(ctx.cls.info)
+ decl_class.scan_declarative_assignments_and_apply_types(
+ cls, ctx.api, is_mixin_scan=True
+ )
+
+
+def _declarative_mixin_hook(ctx: ClassDefContext) -> None:
+ _add_globals(ctx)
+ util.set_is_base(ctx.cls.info)
+ decl_class.scan_declarative_assignments_and_apply_types(
+ ctx.cls, ctx.api, is_mixin_scan=True
+ )
+
+
+def _metaclass_cls_hook(ctx: ClassDefContext) -> None:
+ util.set_is_base(ctx.cls.info)
+
+
+def _base_cls_hook(ctx: ClassDefContext) -> None:
+ _add_globals(ctx)
+ decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+
+
+def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
+    # TODO: figure out how to tell mypy that an attribute of a certain
+    # name does not exist; can't find any Type that seems to match that
+ return ctx.default_attr_type
+
+
+def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
+ """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space
+ for all class defs
+
+ """
+
+ util.add_global(ctx, "sqlalchemy.orm.attributes", "Mapped", "__sa_Mapped")
+
+
+def _set_declarative_metaclass(
+ api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
+) -> None:
+ info = target_cls.info
+ sym = api.lookup_fully_qualified_or_none(
+ "sqlalchemy.orm.decl_api.DeclarativeMeta"
+ )
+ assert sym is not None and isinstance(sym.node, TypeInfo)
+ info.declared_metaclass = info.metaclass_type = Instance(sym.node, [])
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py
new file mode 100644
index 0000000..16b365e
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/util.py
@@ -0,0 +1,305 @@
+import re
+from typing import Any
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import Type as TypingType
+from typing import TypeVar
+from typing import Union
+
+from mypy.nodes import ARG_POS
+from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
+from mypy.nodes import CLASSDEF_NO_INFO
+from mypy.nodes import Context
+from mypy.nodes import Expression
+from mypy.nodes import IfStmt
+from mypy.nodes import JsonDict
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import Statement
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TypeInfo
+from mypy.plugin import ClassDefContext
+from mypy.plugin import DynamicClassDefContext
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.plugins.common import deserialize_and_fixup_type
+from mypy.typeops import map_type_from_supertype
+from mypy.types import Instance
+from mypy.types import NoneType
+from mypy.types import Type
+from mypy.types import TypeVarType
+from mypy.types import UnboundType
+from mypy.types import UnionType
+
+
+_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
+
+
+class SQLAlchemyAttribute:
+ def __init__(
+ self,
+ name: str,
+ line: int,
+ column: int,
+ typ: Optional[Type],
+ info: TypeInfo,
+ ) -> None:
+ self.name = name
+ self.line = line
+ self.column = column
+ self.type = typ
+ self.info = info
+
+ def serialize(self) -> JsonDict:
+ assert self.type
+ return {
+ "name": self.name,
+ "line": self.line,
+ "column": self.column,
+ "type": self.type.serialize(),
+ }
+
+ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
+ """Expands type vars in the context of a subtype when an attribute is
+ inherited from a generic super type.
+ """
+ if not isinstance(self.type, TypeVarType):
+ return
+
+ self.type = map_type_from_supertype(self.type, sub_type, self.info)
+
+ @classmethod
+ def deserialize(
+ cls,
+ info: TypeInfo,
+ data: JsonDict,
+ api: SemanticAnalyzerPluginInterface,
+ ) -> "SQLAlchemyAttribute":
+ data = data.copy()
+ typ = deserialize_and_fixup_type(data.pop("type"), api)
+ return cls(typ=typ, info=info, **data)
+
+
+def name_is_dunder(name: str) -> bool:
+ return bool(re.match(r"^__.+?__$", name))
+
+
+def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None:
+ info.metadata.setdefault("sqlalchemy", {})[key] = data
+
+
+def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]:
+ return info.metadata.get("sqlalchemy", {}).get(key, None)
+
+
+def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]:
+ if info.mro:
+ for base in info.mro:
+ metadata = _get_info_metadata(base, key)
+ if metadata is not None:
+ return metadata
+ return None
+
+
+def establish_as_sqlalchemy(info: TypeInfo) -> None:
+ info.metadata.setdefault("sqlalchemy", {})
+
+
+def set_is_base(info: TypeInfo) -> None:
+ _set_info_metadata(info, "is_base", True)
+
+
+def get_is_base(info: TypeInfo) -> bool:
+ is_base = _get_info_metadata(info, "is_base")
+ return is_base is True
+
+
+def has_declarative_base(info: TypeInfo) -> bool:
+ is_base = _get_info_mro_metadata(info, "is_base")
+ return is_base is True
+
+
+def set_has_table(info: TypeInfo) -> None:
+ _set_info_metadata(info, "has_table", True)
+
+
+def get_has_table(info: TypeInfo) -> bool:
+ is_base = _get_info_metadata(info, "has_table")
+ return is_base is True
+
+
+def get_mapped_attributes(
+ info: TypeInfo, api: SemanticAnalyzerPluginInterface
+) -> Optional[List[SQLAlchemyAttribute]]:
+ mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata(
+ info, "mapped_attributes"
+ )
+ if mapped_attributes is None:
+ return None
+
+ attributes: List[SQLAlchemyAttribute] = []
+
+ for data in mapped_attributes:
+ attr = SQLAlchemyAttribute.deserialize(info, data, api)
+ attr.expand_typevar_from_subtype(info)
+ attributes.append(attr)
+
+ return attributes
+
+
+def set_mapped_attributes(
+ info: TypeInfo, attributes: List[SQLAlchemyAttribute]
+) -> None:
+ _set_info_metadata(
+ info,
+ "mapped_attributes",
+ [attribute.serialize() for attribute in attributes],
+ )
+
+
+def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
+ msg = "[SQLAlchemy Mypy plugin] %s" % msg
+ return api.fail(msg, ctx)
+
+
+def add_global(
+ ctx: Union[ClassDefContext, DynamicClassDefContext],
+ module: str,
+ symbol_name: str,
+ asname: str,
+) -> None:
+ module_globals = ctx.api.modules[ctx.api.cur_mod_id].names
+
+ if asname not in module_globals:
+ lookup_sym: SymbolTableNode = ctx.api.modules[module].names[
+ symbol_name
+ ]
+
+ module_globals[asname] = lookup_sym
+
+
+@overload
+def get_callexpr_kwarg(
+ callexpr: CallExpr, name: str, *, expr_types: None = ...
+) -> Optional[Union[CallExpr, NameExpr]]:
+ ...
+
+
+@overload
+def get_callexpr_kwarg(
+ callexpr: CallExpr,
+ name: str,
+ *,
+ expr_types: Tuple[TypingType[_TArgType], ...]
+) -> Optional[_TArgType]:
+ ...
+
+
+def get_callexpr_kwarg(
+ callexpr: CallExpr,
+ name: str,
+ *,
+ expr_types: Optional[Tuple[TypingType[Any], ...]] = None
+) -> Optional[Any]:
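+    """Return the given keyword argument from ``callexpr`` if it is present
+    and is one of the expected expression types; ``NameExpr`` and
+    ``CallExpr`` are accepted by default."""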
+ try:
+ arg_idx = callexpr.arg_names.index(name)
+ except ValueError:
+ return None
+
+ kwarg = callexpr.args[arg_idx]
+ if isinstance(
+ kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr)
+ ):
+ return kwarg
+
+ return None
+
+
+def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
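+    """Yield the given statements, replacing any ``if TYPE_CHECKING:``
+    block with the statements inside its body."""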
+ for stmt in stmts:
+ if (
+ isinstance(stmt, IfStmt)
+ and isinstance(stmt.expr[0], NameExpr)
+ and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
+ ):
+ for substmt in stmt.body[0].body:
+ yield substmt
+ else:
+ yield stmt
+
+
+def unbound_to_instance(
+ api: SemanticAnalyzerPluginInterface, typ: Type
+) -> Type:
+ """Take the UnboundType that we seem to get as the ret_type from a FuncDef
+ and convert it into an Instance/TypeInfo kind of structure that seems
+ to work as the left-hand type of an AssignmentStatement.
+
+ """
+
+ if not isinstance(typ, UnboundType):
+ return typ
+
+ # TODO: figure out a more robust way to check this. The node is some
+ # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
+    # but I can't figure out how to get them to match up
+ if typ.name == "Optional":
+ # convert from "Optional?" to the more familiar
+ # UnionType[..., NoneType()]
+ return unbound_to_instance(
+ api,
+ UnionType(
+ [unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
+ + [NoneType()]
+ ),
+ )
+
+ node = api.lookup_qualified(typ.name, typ)
+
+ if (
+ node is not None
+ and isinstance(node, SymbolTableNode)
+ and isinstance(node.node, TypeInfo)
+ ):
+ bound_type = node.node
+
+ return Instance(
+ bound_type,
+ [
+ unbound_to_instance(api, arg)
+ if isinstance(arg, UnboundType)
+ else arg
+ for arg in typ.args
+ ],
+ )
+ else:
+ return typ
+
+
+def info_for_cls(
+ cls: ClassDef, api: SemanticAnalyzerPluginInterface
+) -> Optional[TypeInfo]:
+ if cls.info is CLASSDEF_NO_INFO:
+ sym = api.lookup_qualified(cls.name, cls)
+ if sym is None:
+ return None
+ assert sym and isinstance(sym.node, TypeInfo)
+ return sym.node
+
+ return cls.info
+
+
+def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
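+    """Wrap an expression in a call to ``Mapped._empty_constructor()``,
+    using the ``__sa_Mapped`` symbol established by ``add_global()``."""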
+ column_descriptor = NameExpr("__sa_Mapped")
+ column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
+ member_expr = MemberExpr(column_descriptor, "_empty_constructor")
+ return CallExpr(
+ member_expr,
+ [expr],
+ [ARG_POS],
+ ["arg1"],
+ )
diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py
new file mode 100644
index 0000000..5a327d1
--- /dev/null
+++ b/lib/sqlalchemy/ext/orderinglist.py
@@ -0,0 +1,388 @@
+# ext/orderinglist.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""A custom list that manages index/position information for contained
+elements.
+
+:author: Jason Kirtland
+
+``orderinglist`` is a helper for mutable ordered relationships. It will
+intercept list operations performed on a :func:`_orm.relationship`-managed
+collection and
+automatically synchronize changes in list position onto a target scalar
+attribute.
+
+Example: A ``slide`` table, where each row refers to zero or more entries
+in a related ``bullet`` table. The bullets within a slide are
+displayed in order based on the value of the ``position`` column in the
+``bullet`` table. As entries are reordered in memory, the value of the
+``position`` attribute should be updated to reflect the new sort order::
+
+
+ Base = declarative_base()
+
+ class Slide(Base):
+ __tablename__ = 'slide'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ bullets = relationship("Bullet", order_by="Bullet.position")
+
+ class Bullet(Base):
+ __tablename__ = 'bullet'
+ id = Column(Integer, primary_key=True)
+ slide_id = Column(Integer, ForeignKey('slide.id'))
+ position = Column(Integer)
+ text = Column(String)
+
+The standard relationship mapping will produce a list-like attribute on each
+``Slide`` containing all related ``Bullet`` objects,
+but coping with changes in ordering is not handled automatically.
+When appending a ``Bullet`` into ``Slide.bullets``, the ``Bullet.position``
+attribute will remain unset until manually assigned. When the ``Bullet``
+is inserted into the middle of the list, the following ``Bullet`` objects
+will also need to be renumbered.
+
+The :class:`.OrderingList` object automates this task, managing the
+``position`` attribute on all ``Bullet`` objects in the collection. It is
+constructed using the :func:`.ordering_list` factory::
+
+ from sqlalchemy.ext.orderinglist import ordering_list
+
+ Base = declarative_base()
+
+ class Slide(Base):
+ __tablename__ = 'slide'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ bullets = relationship("Bullet", order_by="Bullet.position",
+ collection_class=ordering_list('position'))
+
+ class Bullet(Base):
+ __tablename__ = 'bullet'
+ id = Column(Integer, primary_key=True)
+ slide_id = Column(Integer, ForeignKey('slide.id'))
+ position = Column(Integer)
+ text = Column(String)
+
+With the above mapping the ``Bullet.position`` attribute is managed::
+
+ s = Slide()
+ s.bullets.append(Bullet())
+ s.bullets.append(Bullet())
+ s.bullets[1].position
+ >>> 1
+ s.bullets.insert(1, Bullet())
+ s.bullets[2].position
+ >>> 2
+
+The :class:`.OrderingList` construct only works with **changes** to a
+collection, and not the initial load from the database, and requires that the
+list be sorted when loaded. Therefore, be sure to specify ``order_by`` on the
+:func:`_orm.relationship` against the target ordering attribute, so that the
+ordering is correct when first loaded.
+
+.. warning::
+
+ :class:`.OrderingList` only provides limited functionality when a primary
+ key column or unique column is the target of the sort. Operations
+ that are unsupported or are problematic include:
+
+ * two entries must trade values. This is not supported directly in the
+ case of a primary key or unique constraint because it means at least
+ one row would need to be temporarily removed first, or changed to
+ a third, neutral value while the switch occurs.
+
+ * an entry must be deleted in order to make room for a new entry.
+ SQLAlchemy's unit of work performs all INSERTs before DELETEs within a
+ single flush. In the case of a primary key, it will trade
+ an INSERT/DELETE of the same primary key for an UPDATE statement in order
+ to lessen the impact of this limitation, however this does not take place
+ for a UNIQUE column.
+ A future feature will allow the "DELETE before INSERT" behavior to be
+ possible, alleviating this limitation, though this feature will require
+ explicit configuration at the mapper level for sets of columns that
+ are to be handled in this way.
+
+:func:`.ordering_list` takes the name of the related object's ordering
+attribute as an argument. By default, the zero-based integer index of the
+object's position in the :func:`.ordering_list` is synchronized with the
+ordering attribute: index 0 will get position 0, index 1 position 1, etc. To
+start numbering at 1 or some other integer, provide ``count_from=1``.
+
+
+"""
+from ..orm.collections import collection
+from ..orm.collections import collection_adapter
+
+
+__all__ = ["ordering_list"]
+
+
+def ordering_list(attr, count_from=None, **kw):
+ """Prepares an :class:`OrderingList` factory for use in mapper definitions.
+
+ Returns an object suitable for use as an argument to a Mapper
+ relationship's ``collection_class`` option. e.g.::
+
+ from sqlalchemy.ext.orderinglist import ordering_list
+
+ class Slide(Base):
+ __tablename__ = 'slide'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ bullets = relationship("Bullet", order_by="Bullet.position",
+ collection_class=ordering_list('position'))
+
+ :param attr:
+ Name of the mapped attribute to use for storage and retrieval of
+ ordering information
+
+ :param count_from:
+ Set up an integer-based ordering, starting at ``count_from``. For
+ example, ``ordering_list('pos', count_from=1)`` would create a 1-based
+ list in SQL, storing the value in the 'pos' column. Ignored if
+ ``ordering_func`` is supplied.
+
+ Additional arguments are passed to the :class:`.OrderingList` constructor.
+
+ """
+
+ kw = _unsugar_count_from(count_from=count_from, **kw)
+ return lambda: OrderingList(attr, **kw)
+
+
+# Ordering utility functions
+
+
+def count_from_0(index, collection):
+ """Numbering function: consecutive integers starting at 0."""
+
+ return index
+
+
+def count_from_1(index, collection):
+ """Numbering function: consecutive integers starting at 1."""
+
+ return index + 1
+
+
+def count_from_n_factory(start):
+ """Numbering function: consecutive integers starting at arbitrary start."""
+
+ def f(index, collection):
+ return index + start
+
+ try:
+ f.__name__ = "count_from_%i" % start
+ except TypeError:
+ pass
+ return f
+
+
+def _unsugar_count_from(**kw):
+ """Builds counting functions from keyword arguments.
+
+    A keyword argument filter that prepares a simple ``ordering_func`` from
+    a ``count_from`` argument, otherwise passes ``ordering_func`` on
+    unchanged.
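+
+    E.g. ``_unsugar_count_from(count_from=1)`` returns
+    ``{"ordering_func": count_from_1}``.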
+ """
+
+ count_from = kw.pop("count_from", None)
+ if kw.get("ordering_func", None) is None and count_from is not None:
+ if count_from == 0:
+ kw["ordering_func"] = count_from_0
+ elif count_from == 1:
+ kw["ordering_func"] = count_from_1
+ else:
+ kw["ordering_func"] = count_from_n_factory(count_from)
+ return kw
+
+
+class OrderingList(list):
+ """A custom list that manages position information for its children.
+
+ The :class:`.OrderingList` object is normally set up using the
+ :func:`.ordering_list` factory function, used in conjunction with
+ the :func:`_orm.relationship` function.
+
+ """
+
+ def __init__(
+ self, ordering_attr=None, ordering_func=None, reorder_on_append=False
+ ):
+ """A custom list that manages position information for its children.
+
+ ``OrderingList`` is a ``collection_class`` list implementation that
+ syncs position in a Python list with a position attribute on the
+ mapped objects.
+
+ This implementation relies on the list starting in the proper order,
+ so be **sure** to put an ``order_by`` on your relationship.
+
+ :param ordering_attr:
+ Name of the attribute that stores the object's order in the
+ relationship.
+
+ :param ordering_func: Optional. A function that maps the position in
+ the Python list to a value to store in the
+ ``ordering_attr``. Values returned are usually (but need not be!)
+ integers.
+
+ An ``ordering_func`` is called with two positional parameters: the
+ index of the element in the list, and the list itself.
+
+ If omitted, Python list indexes are used for the attribute values.
+ Two basic pre-built numbering functions are provided in this module:
+ ``count_from_0`` and ``count_from_1``. For more exotic examples
+ like stepped numbering, alphabetical and Fibonacci numbering, see
+ the unit tests.
+
+ :param reorder_on_append:
+ Default False. When appending an object with an existing (non-None)
+ ordering value, that value will be left untouched unless
+ ``reorder_on_append`` is true. This is an optimization to avoid a
+ variety of dangerous unexpected database writes.
+
+ SQLAlchemy will add instances to the list via append() when your
+ object loads. If for some reason the result set from the database
+ skips a step in the ordering (say, row '1' is missing but you get
+ '2', '3', and '4'), reorder_on_append=True would immediately
+ renumber the items to '1', '2', '3'. If you have multiple sessions
+ making changes, any of whom happen to load this collection even in
+ passing, all of the sessions would try to "clean up" the numbering
+ in their commits, possibly causing all but one to fail with a
+ concurrent modification error.
+
+            It is recommended to leave this at the default of False, and to
+            just call ``reorder()`` if you're doing ``append()`` operations
+            with previously ordered instances or doing some housekeeping
+            after manual SQL operations.
+
+ """
+ self.ordering_attr = ordering_attr
+ if ordering_func is None:
+ ordering_func = count_from_0
+ self.ordering_func = ordering_func
+ self.reorder_on_append = reorder_on_append
+
+ # More complex serialization schemes (multi column, e.g.) are possible by
+ # subclassing and reimplementing these two methods.
+ def _get_order_value(self, entity):
+ return getattr(entity, self.ordering_attr)
+
+ def _set_order_value(self, entity, value):
+ setattr(entity, self.ordering_attr, value)
+
+ def reorder(self):
+ """Synchronize ordering for the entire collection.
+
+ Sweeps through the list and ensures that each object has accurate
+ ordering information set.
+
+ """
+ for index, entity in enumerate(self):
+ self._order_entity(index, entity, True)
+
+ # As of 0.5, _reorder is no longer semi-private
+ _reorder = reorder
+
+ def _order_entity(self, index, entity, reorder=True):
+ have = self._get_order_value(entity)
+
+ # Don't disturb existing ordering if reorder is False
+ if have is not None and not reorder:
+ return
+
+ should_be = self.ordering_func(index, self)
+ if have != should_be:
+ self._set_order_value(entity, should_be)
+
+ def append(self, entity):
+ super(OrderingList, self).append(entity)
+ self._order_entity(len(self) - 1, entity, self.reorder_on_append)
+
+ def _raw_append(self, entity):
+ """Append without any ordering behavior."""
+
+ super(OrderingList, self).append(entity)
+
+ _raw_append = collection.adds(1)(_raw_append)
+
+ def insert(self, index, entity):
+ super(OrderingList, self).insert(index, entity)
+ self._reorder()
+
+ def remove(self, entity):
+ super(OrderingList, self).remove(entity)
+
+ adapter = collection_adapter(self)
+ if adapter and adapter._referenced_by_owner:
+ self._reorder()
+
+ def pop(self, index=-1):
+ entity = super(OrderingList, self).pop(index)
+ self._reorder()
+ return entity
+
+ def __setitem__(self, index, entity):
+ if isinstance(index, slice):
+ step = index.step or 1
+ start = index.start or 0
+ if start < 0:
+ start += len(self)
+ stop = index.stop or len(self)
+ if stop < 0:
+ stop += len(self)
+
+ for i in range(start, stop, step):
+ self.__setitem__(i, entity[i])
+ else:
+ self._order_entity(index, entity, True)
+ super(OrderingList, self).__setitem__(index, entity)
+
+ def __delitem__(self, index):
+ super(OrderingList, self).__delitem__(index)
+ self._reorder()
+
+ def __setslice__(self, start, end, values):
+ super(OrderingList, self).__setslice__(start, end, values)
+ self._reorder()
+
+ def __delslice__(self, start, end):
+ super(OrderingList, self).__delslice__(start, end)
+ self._reorder()
+
+ def __reduce__(self):
+ return _reconstitute, (self.__class__, self.__dict__, list(self))
+
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(list, func_name)
+ ):
+ func.__doc__ = getattr(list, func_name).__doc__
+ del func_name, func
+
+
+def _reconstitute(cls, dict_, items):
+ """Reconstitute an :class:`.OrderingList`.
+
+ This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for
+ unpickling :class:`.OrderingList` objects.
+
+ """
+ obj = cls.__new__(cls)
+ obj.__dict__.update(dict_)
+ list.extend(obj, items)
+ return obj
diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py
new file mode 100644
index 0000000..094b71b
--- /dev/null
+++ b/lib/sqlalchemy/ext/serializer.py
@@ -0,0 +1,177 @@
+# ext/serializer.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Serializer/Deserializer objects for usage with SQLAlchemy query structures,
+allowing "contextual" deserialization.
+
+Any SQLAlchemy query structure, either based on sqlalchemy.sql.*
+or sqlalchemy.orm.* can be used. The mappers, Tables, Columns, Session
+etc. which are referenced by the structure are not persisted in serialized
+form, but are instead re-associated with the query structure
+when it is deserialized.
+
+Usage is nearly the same as that of the standard Python pickle module::
+
+ from sqlalchemy.ext.serializer import loads, dumps
+ metadata = MetaData(bind=some_engine)
+ Session = scoped_session(sessionmaker())
+
+ # ... define mappers
+
+    query = Session.query(MyClass).\
+        filter(MyClass.somedata == 'foo').order_by(MyClass.sortkey)
+
+ # pickle the query
+ serialized = dumps(query)
+
+ # unpickle. Pass in metadata + scoped_session
+ query2 = loads(serialized, metadata, Session)
+
+    print(query2.all())
+
+The same restrictions that apply to raw pickle apply here; mapped classes
+must themselves be pickleable, meaning they are importable from a
+module-level namespace.
+
+The serializer module is only appropriate for query structures. It is not
+needed for:
+
+* instances of user-defined classes. These contain no references to engines,
+ sessions or expression constructs in the typical case and can be serialized
+ directly.
+
+* Table metadata that is to be loaded entirely from the serialized structure
+ (i.e. is not already declared in the application). Regular
+ pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object,
+ typically one which was reflected from an existing database at some previous
+ point in time. The serializer module is specifically for the opposite case,
+ where the Table metadata is already present in memory.
+
+"""
+
+import re
+
+from .. import Column
+from .. import Table
+from ..engine import Engine
+from ..orm import class_mapper
+from ..orm.interfaces import MapperProperty
+from ..orm.mapper import Mapper
+from ..orm.session import Session
+from ..util import b64decode
+from ..util import b64encode
+from ..util import byte_buffer
+from ..util import pickle
+from ..util import text_type
+
+
+__all__ = ["Serializer", "Deserializer", "dumps", "loads"]
+
+
+def Serializer(*args, **kw):
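+    """Return a ``pickle.Pickler`` whose ``persistent_id`` hook replaces
+    mappers, mapper properties, tables, columns, sessions and engines with
+    symbolic identifiers, to be resolved again by :func:`.Deserializer`."""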
+ pickler = pickle.Pickler(*args, **kw)
+
+ def persistent_id(obj):
+ # print "serializing:", repr(obj)
+ if isinstance(obj, Mapper) and not obj.non_primary:
+ id_ = "mapper:" + b64encode(pickle.dumps(obj.class_))
+ elif isinstance(obj, MapperProperty) and not obj.parent.non_primary:
+ id_ = (
+ "mapperprop:"
+ + b64encode(pickle.dumps(obj.parent.class_))
+ + ":"
+ + obj.key
+ )
+ elif isinstance(obj, Table):
+ if "parententity" in obj._annotations:
+ id_ = "mapper_selectable:" + b64encode(
+ pickle.dumps(obj._annotations["parententity"].class_)
+ )
+ else:
+ id_ = "table:" + text_type(obj.key)
+ elif isinstance(obj, Column) and isinstance(obj.table, Table):
+ id_ = (
+ "column:" + text_type(obj.table.key) + ":" + text_type(obj.key)
+ )
+ elif isinstance(obj, Session):
+ id_ = "session:"
+ elif isinstance(obj, Engine):
+ id_ = "engine:"
+ else:
+ return None
+ return id_
+
+ pickler.persistent_id = persistent_id
+ return pickler
+
+
+our_ids = re.compile(
+ r"(mapperprop|mapper|mapper_selectable|table|column|"
+ r"session|attribute|engine):(.*)"
+)
+
+
+def Deserializer(file, metadata=None, scoped_session=None, engine=None):
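+    """Return a ``pickle.Unpickler`` whose ``persistent_load`` hook resolves
+    the symbolic identifiers written by :func:`.Serializer` back into live
+    objects, using the given ``metadata``, ``scoped_session`` and/or
+    ``engine``."""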
+ unpickler = pickle.Unpickler(file)
+
+ def get_engine():
+ if engine:
+ return engine
+ elif scoped_session and scoped_session().bind:
+ return scoped_session().bind
+ elif metadata and metadata.bind:
+ return metadata.bind
+ else:
+ return None
+
+ def persistent_load(id_):
+ m = our_ids.match(text_type(id_))
+ if not m:
+ return None
+ else:
+ type_, args = m.group(1, 2)
+ if type_ == "attribute":
+ key, clsarg = args.split(":")
+ cls = pickle.loads(b64decode(clsarg))
+ return getattr(cls, key)
+ elif type_ == "mapper":
+ cls = pickle.loads(b64decode(args))
+ return class_mapper(cls)
+ elif type_ == "mapper_selectable":
+ cls = pickle.loads(b64decode(args))
+ return class_mapper(cls).__clause_element__()
+ elif type_ == "mapperprop":
+ mapper, keyname = args.split(":")
+ cls = pickle.loads(b64decode(mapper))
+ return class_mapper(cls).attrs[keyname]
+ elif type_ == "table":
+ return metadata.tables[args]
+ elif type_ == "column":
+ table, colname = args.split(":")
+ return metadata.tables[table].c[colname]
+ elif type_ == "session":
+ return scoped_session()
+ elif type_ == "engine":
+ return get_engine()
+ else:
+ raise Exception("Unknown token: %s" % type_)
+
+ unpickler.persistent_load = persistent_load
+ return unpickler
+
+
+def dumps(obj, protocol=pickle.HIGHEST_PROTOCOL):
+ buf = byte_buffer()
+ pickler = Serializer(buf, protocol)
+ pickler.dump(obj)
+ return buf.getvalue()
+
+
+def loads(data, metadata=None, scoped_session=None, engine=None):
+ buf = byte_buffer(data)
+ unpickler = Deserializer(buf, metadata, scoped_session, engine)
+ return unpickler.load()
diff --git a/lib/sqlalchemy/future/__init__.py b/lib/sqlalchemy/future/__init__.py
new file mode 100644
index 0000000..a2bed07
--- /dev/null
+++ b/lib/sqlalchemy/future/__init__.py
@@ -0,0 +1,18 @@
+# future/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Future 2.0 API features.
+
+"""
+from .engine import Connection
+from .engine import create_engine
+from .engine import Engine
+from ..sql.selectable import Select
+from ..util.langhelpers import public_factory
+
+
+select = public_factory(Select._create_future_select, ".future.select")
diff --git a/lib/sqlalchemy/future/engine.py b/lib/sqlalchemy/future/engine.py
new file mode 100644
index 0000000..3235529
--- /dev/null
+++ b/lib/sqlalchemy/future/engine.py
@@ -0,0 +1,413 @@
+from .. import util
+from ..engine import Connection as _LegacyConnection
+from ..engine import create_engine as _create_engine
+from ..engine import Engine as _LegacyEngine
+from ..engine.base import OptionEngineMixin
+
+NO_OPTIONS = util.immutabledict()
+
+
+def create_engine(*arg, **kw):
+ """Create a new :class:`_future.Engine` instance.
+
+ Arguments passed to :func:`_future.create_engine` are mostly identical
+ to those passed to the 1.x :func:`_sa.create_engine` function.
+    The difference is that the object returned is the :class:`_future.Engine`
+ which has the 2.0 version of the API.
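+
+    E.g. (illustrative connection URL)::
+
+        from sqlalchemy.future import create_engine
+
+        engine = create_engine("postgresql://scott:tiger@localhost/test")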
+
+ """
+
+ kw["_future_engine_class"] = Engine
+ return _create_engine(*arg, **kw)
+
+
+class Connection(_LegacyConnection):
+ """Provides high-level functionality for a wrapped DB-API connection.
+
+ The :class:`_future.Connection` object is procured by calling
+ the :meth:`_future.Engine.connect` method of the :class:`_future.Engine`
+ object, and provides services for execution of SQL statements as well
+ as transaction control.
+
+ **This is the SQLAlchemy 2.0 version** of the :class:`_engine.Connection`
+ class. The API and behavior of this object is largely the same, with the
+ following differences in behavior:
+
+ * The result object returned for results is the
+ :class:`_engine.CursorResult`
+ object, which is a subclass of the :class:`_engine.Result`.
+ This object has a slightly different API and behavior than the
+ :class:`_engine.LegacyCursorResult` returned for 1.x style usage.
+
+ * The object has :meth:`_future.Connection.commit` and
+ :meth:`_future.Connection.rollback` methods which commit or roll back
+ the current transaction in progress, if any.
+
+ * The object features "autobegin" behavior, such that any call to
+ :meth:`_future.Connection.execute` will
+ unconditionally start a
+ transaction which can be controlled using the above mentioned
+ :meth:`_future.Connection.commit` and
+ :meth:`_future.Connection.rollback` methods.
+
+ * The object does not have any "autocommit" functionality. Any SQL
+ statement or DDL statement will not be followed by any COMMIT until
+ the transaction is explicitly committed, either via the
+ :meth:`_future.Connection.commit` method, or if the connection is
+ being used in a context manager that commits such as the one
+ returned by :meth:`_future.Engine.begin`.
+
+ * The SAVEPOINT method :meth:`_future.Connection.begin_nested` returns
+ a :class:`_engine.NestedTransaction` as was always the case, and the
+ savepoint can be controlled by invoking
+ :meth:`_engine.NestedTransaction.commit` or
+ :meth:`_engine.NestedTransaction.rollback` as was the case before.
+ However, this savepoint "transaction" is not associated with the
+ transaction that is controlled by the connection itself; the overall
+ transaction can be committed or rolled back directly which will not emit
+ any special instructions for the SAVEPOINT (this will typically have the
+ effect that one desires).
+
+ * The :class:`_future.Connection` object does not support "branching",
+ which was a pattern by which a sub "connection" would be used that
+ refers to this connection as a parent.
+
+
+
+ """
+
+ _is_future = True
+
+ def _branch(self):
+ raise NotImplementedError(
+ "sqlalchemy.future.Connection does not support "
+ "'branching' of new connections."
+ )
+
+ def begin(self):
+ """Begin a transaction prior to autobegin occurring.
+
+ The returned object is an instance of :class:`_engine.RootTransaction`.
+ This object represents the "scope" of the transaction,
+ which completes when either the :meth:`_engine.Transaction.rollback`
+ or :meth:`_engine.Transaction.commit` method is called.
+
+ The :meth:`_future.Connection.begin` method in SQLAlchemy 2.0 begins a
+ transaction that normally will be begun in any case when the connection
+ is first used to execute a statement. The reason this method might be
+ used would be to invoke the :meth:`_events.ConnectionEvents.begin`
+ event at a specific time, or to organize code within the scope of a
+ connection checkout in terms of context managed blocks, such as::
+
+ with engine.connect() as conn:
+ with conn.begin():
+ conn.execute(...)
+ conn.execute(...)
+
+ with conn.begin():
+ conn.execute(...)
+ conn.execute(...)
+
+ The above code is not fundamentally any different in its behavior than
+ the following code which does not use
+        :meth:`_future.Connection.begin`; the style below is referred to
+        as "commit as you go" style::
+
+ with engine.connect() as conn:
+ conn.execute(...)
+ conn.execute(...)
+ conn.commit()
+
+ conn.execute(...)
+ conn.execute(...)
+ conn.commit()
+
+ From a database point of view, the :meth:`_future.Connection.begin`
+ method does not emit any SQL or change the state of the underlying
+ DBAPI connection in any way; the Python DBAPI does not have any
+ concept of explicit transaction begin.
+
+ .. seealso::
+
+ :ref:`tutorial_working_with_transactions` - in the
+ :ref:`unified_tutorial`
+
+ :meth:`_future.Connection.begin_nested` - use a SAVEPOINT
+
+ :meth:`_engine.Connection.begin_twophase` -
+ use a two phase /XID transaction
+
+ :meth:`_future.Engine.begin` - context manager available from
+ :class:`_future.Engine`
+
+ """
+ return super(Connection, self).begin()
+
+ def begin_nested(self):
+ """Begin a nested transaction (i.e. SAVEPOINT) and return a transaction
+ handle.
+
+ The returned object is an instance of
+ :class:`_engine.NestedTransaction`.
+
+ Nested transactions require SAVEPOINT support in the
+ underlying database. Any transaction in the hierarchy may
+ ``commit`` and ``rollback``, however the outermost transaction
+ still controls the overall ``commit`` or ``rollback`` of the
+        transaction as a whole.
+
+ If an outer :class:`.RootTransaction` is not present on this
+ :class:`_future.Connection`, a new one is created using "autobegin".
+ This outer transaction may be completed using "commit-as-you-go" style
+ usage, by calling upon :meth:`_future.Connection.commit` or
+ :meth:`_future.Connection.rollback`.
+
+ .. tip::
+
+ The "autobegin" behavior of :meth:`_future.Connection.begin_nested`
+ is specific to :term:`2.0 style` use; for legacy behaviors, see
+ :meth:`_engine.Connection.begin_nested`.
+
+ The :class:`_engine.NestedTransaction` remains independent of the
+ :class:`_future.Connection` object itself. Calling the
+ :meth:`_future.Connection.commit` or
+ :meth:`_future.Connection.rollback` will always affect the actual
+ containing database transaction itself, and not the SAVEPOINT itself.
+ When a database transaction is committed, any SAVEPOINTs that have been
+        established are cleared and the data changes within their scope are
+        also committed.
+
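+        A minimal sketch of this interaction, assuming a hypothetical
+        ``some_table`` :class:`_schema.Table` object, might look like::
+
+            with engine.connect() as conn:
+                conn.execute(some_table.insert(), {"data": "row 1"})
+
+                savepoint = conn.begin_nested()
+                conn.execute(some_table.insert(), {"data": "row 2"})
+                savepoint.rollback()  # discards only "row 2"
+
+                conn.commit()  # commits "row 1"
+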
+ .. seealso::
+
+ :meth:`_future.Connection.begin`
+
+
+ """
+ return super(Connection, self).begin_nested()
+
+ def commit(self):
+ """Commit the transaction that is currently in progress.
+
+ This method commits the current transaction if one has been started.
+ If no transaction was started, the method has no effect, assuming
+ the connection is in a non-invalidated state.
+
+ A transaction is begun on a :class:`_future.Connection` automatically
+ whenever a statement is first executed, or when the
+ :meth:`_future.Connection.begin` method is called.
+
+ .. note:: The :meth:`_future.Connection.commit` method only acts upon
+ the primary database transaction that is linked to the
+ :class:`_future.Connection` object. It does not operate upon a
+ SAVEPOINT that would have been invoked from the
+ :meth:`_future.Connection.begin_nested` method; for control of a
+ SAVEPOINT, call :meth:`_engine.NestedTransaction.commit` on the
+ :class:`_engine.NestedTransaction` that is returned by the
+ :meth:`_future.Connection.begin_nested` method itself.
+
+
+ """
+ if self._transaction:
+ self._transaction.commit()
+
+ def rollback(self):
+ """Roll back the transaction that is currently in progress.
+
+ This method rolls back the current transaction if one has been started.
+ If no transaction was started, the method has no effect. If a
+ transaction was started and the connection is in an invalidated state,
+ the transaction is cleared using this method.
+
+ A transaction is begun on a :class:`_future.Connection` automatically
+ whenever a statement is first executed, or when the
+ :meth:`_future.Connection.begin` method is called.
+
+ .. note:: The :meth:`_future.Connection.rollback` method only acts
+ upon the primary database transaction that is linked to the
+ :class:`_future.Connection` object. It does not operate upon a
+ SAVEPOINT that would have been invoked from the
+ :meth:`_future.Connection.begin_nested` method; for control of a
+ SAVEPOINT, call :meth:`_engine.NestedTransaction.rollback` on the
+ :class:`_engine.NestedTransaction` that is returned by the
+ :meth:`_future.Connection.begin_nested` method itself.
+
+
+ """
+ if self._transaction:
+ self._transaction.rollback()
+
+ def close(self):
+ """Close this :class:`_future.Connection`.
+
+ This has the effect of also calling :meth:`_future.Connection.rollback`
+ if any transaction is in place.
+
+ """
+ super(Connection, self).close()
+
+ def execute(self, statement, parameters=None, execution_options=None):
+        r"""Execute a SQL statement construct and return a
+        :class:`_engine.Result`.
+
+ :param statement: The statement to be executed. This is always
+ an object that is in both the :class:`_expression.ClauseElement` and
+ :class:`_expression.Executable` hierarchies, including:
+
+ * :class:`_expression.Select`
+ * :class:`_expression.Insert`, :class:`_expression.Update`,
+ :class:`_expression.Delete`
+ * :class:`_expression.TextClause` and
+ :class:`_expression.TextualSelect`
+ * :class:`_schema.DDL` and objects which inherit from
+ :class:`_schema.DDLElement`
+
+ :param parameters: parameters which will be bound into the statement.
+ This may be either a dictionary of parameter names to values,
+ or a mutable sequence (e.g. a list) of dictionaries. When a
+ list of dictionaries is passed, the underlying statement execution
+ will make use of the DBAPI ``cursor.executemany()`` method.
+ When a single dictionary is passed, the DBAPI ``cursor.execute()``
+ method will be used.
+
+ :param execution_options: optional dictionary of execution options,
+ which will be associated with the statement execution. This
+ dictionary can provide a subset of the options that are accepted
+ by :meth:`_future.Connection.execution_options`.
+
+ :return: a :class:`_engine.Result` object.
+
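+        E.g., a bulk INSERT using a list of parameter dictionaries, sketched
+        with a hypothetical ``user_table`` :class:`_schema.Table` object::
+
+            with engine.connect() as conn:
+                result = conn.execute(
+                    user_table.insert(),
+                    [{"name": "spongebob"}, {"name": "sandy"}],
+                )
+                conn.commit()
+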
+ """
+ return self._execute_20(
+ statement, parameters, execution_options or NO_OPTIONS
+ )
+
+ def scalar(self, statement, parameters=None, execution_options=None):
+        r"""Execute a SQL statement construct and return a scalar object.
+
+ This method is shorthand for invoking the
+ :meth:`_engine.Result.scalar` method after invoking the
+ :meth:`_future.Connection.execute` method. Parameters are equivalent.
+
+ :return: a scalar Python value representing the first column of the
+ first row returned.
+
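+        E.g., fetching a single count value, sketched with a hypothetical
+        ``user_table`` :class:`_schema.Table` object::
+
+            with engine.connect() as conn:
+                user_count = conn.scalar(
+                    select(func.count()).select_from(user_table)
+                )
+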
+ """
+ return self.execute(statement, parameters, execution_options).scalar()
+
+
+class Engine(_LegacyEngine):
+ """Connects a :class:`_pool.Pool` and
+ :class:`_engine.Dialect` together to provide a
+ source of database connectivity and behavior.
+
+ **This is the SQLAlchemy 2.0 version** of the :class:`~.engine.Engine`.
+
+ An :class:`.future.Engine` object is instantiated publicly using the
+ :func:`~sqlalchemy.future.create_engine` function.
+
+ .. seealso::
+
+ :doc:`/core/engines`
+
+ :ref:`connections_toplevel`
+
+ """
+
+ _connection_cls = Connection
+ _is_future = True
+
+ def _not_implemented(self, *arg, **kw):
+ raise NotImplementedError(
+ "This method is not implemented for SQLAlchemy 2.0."
+ )
+
+ transaction = (
+ run_callable
+ ) = (
+ execute
+ ) = (
+ scalar
+ ) = (
+ _execute_clauseelement
+ ) = _execute_compiled = table_names = has_table = _not_implemented
+
+ def _run_ddl_visitor(self, visitorcallable, element, **kwargs):
+ # TODO: this is for create_all support etc. not clear if we
+ # want to provide this in 2.0, that is, a way to execute SQL where
+ # they aren't calling "engine.begin()" explicitly, however, DDL
+ # may be a special case for which we want to continue doing it this
+ # way. A big win here is that the full DDL sequence is inside of a
+ # single transaction rather than COMMIT for each statement.
+ with self.begin() as conn:
+ conn._run_ddl_visitor(visitorcallable, element, **kwargs)
+
+ @classmethod
+    def _future_facade(cls, legacy_engine):
+ return Engine(
+ legacy_engine.pool,
+ legacy_engine.dialect,
+ legacy_engine.url,
+ logging_name=legacy_engine.logging_name,
+ echo=legacy_engine.echo,
+ hide_parameters=legacy_engine.hide_parameters,
+ execution_options=legacy_engine._execution_options,
+ )
+
+ @util.contextmanager
+ def begin(self):
+ """Return a :class:`_future.Connection` object with a transaction
+ begun.
+
+ Use of this method is similar to that of
+ :meth:`_future.Engine.connect`, typically as a context manager, which
+ will automatically maintain the state of the transaction when the block
+ ends, either by calling :meth:`_future.Connection.commit` when the
+ block succeeds normally, or :meth:`_future.Connection.rollback` when an
+ exception is raised, before propagating the exception outwards::
+
+ with engine.begin() as connection:
+ connection.execute(text("insert into table values ('foo')"))
+
+
+ .. seealso::
+
+ :meth:`_future.Engine.connect`
+
+ :meth:`_future.Connection.begin`
+
+ """
+ with self.connect() as conn:
+ with conn.begin():
+ yield conn
+
+ def connect(self):
+ """Return a new :class:`_future.Connection` object.
+
+ The :class:`_future.Connection` acts as a Python context manager, so
+ the typical use of this method looks like::
+
+ with engine.connect() as connection:
+ connection.execute(text("insert into table values ('foo')"))
+ connection.commit()
+
+ Where above, after the block is completed, the connection is "closed"
+ and its underlying DBAPI resources are returned to the connection pool.
+ This also has the effect of rolling back any transaction that
+ was explicitly begun or was begun via autobegin, and will
+ emit the :meth:`_events.ConnectionEvents.rollback` event if one was
+ started and is still in progress.
+
+ .. seealso::
+
+ :meth:`_future.Engine.begin`
+
+
+ """
+ return super(Engine, self).connect()
+
+
+class OptionEngine(OptionEngineMixin, Engine):
+ pass
+
+
+Engine._option_cls = OptionEngine
diff --git a/lib/sqlalchemy/future/orm/__init__.py b/lib/sqlalchemy/future/orm/__init__.py
new file mode 100644
index 0000000..629631b
--- /dev/null
+++ b/lib/sqlalchemy/future/orm/__init__.py
@@ -0,0 +1,10 @@
+# future/orm/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Future 2.0 API features for the ORM.
+
+"""
diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py
new file mode 100644
index 0000000..7f9822d
--- /dev/null
+++ b/lib/sqlalchemy/inspection.py
@@ -0,0 +1,93 @@
+# sqlalchemy/inspection.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""The inspection module provides the :func:`_sa.inspect` function,
+which delivers runtime information about a wide variety
+of SQLAlchemy objects, both within the Core as well as the
+ORM.
+
+The :func:`_sa.inspect` function is the entry point to SQLAlchemy's
+public API for viewing the configuration and construction
+of in-memory objects. Depending on the type of object
+passed to :func:`_sa.inspect`, the return value will either be
+a related object which provides a known interface, or in many
+cases it will return the object itself.
+
+The rationale for :func:`_sa.inspect` is twofold. One is that
+it replaces the need to be aware of a large variety of "information
+getting" functions in SQLAlchemy, such as
+:meth:`_reflection.Inspector.from_engine` (deprecated in 1.4),
+:func:`.orm.attributes.instance_state`, :func:`_orm.class_mapper`,
+and others. The other is that the return value of :func:`_sa.inspect`
+is guaranteed to obey a documented API, thus allowing third party
+tools which build on top of SQLAlchemy configurations to be constructed
+in a forwards-compatible way.
+
+"""
+
+from . import exc
+from . import util
+
+
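+# mapping of class -> inspection callable, or the value True to indicate
+# that instances of the class are returned from inspect() unchanged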
+_registrars = util.defaultdict(list)
+
+
+def inspect(subject, raiseerr=True):
+ """Produce an inspection object for the given target.
+
+ The returned value in some cases may be the
+ same object as the one given, such as if a
+ :class:`_orm.Mapper` object is passed. In other
+ cases, it will be an instance of the registered
+    inspection type for the given object; for example, if an
+    :class:`_engine.Engine` is passed, an
+    :class:`_reflection.Inspector` object is returned.
+
+ :param subject: the subject to be inspected.
+ :param raiseerr: When ``True``, if the given subject
+ does not
+ correspond to a known SQLAlchemy inspected type,
+ :class:`sqlalchemy.exc.NoInspectionAvailable`
+ is raised. If ``False``, ``None`` is returned.
+
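+    E.g., a sketch of typical use against an :class:`_engine.Engine`,
+    assuming an in-memory SQLite database::
+
+        from sqlalchemy import create_engine, inspect
+
+        engine = create_engine("sqlite://")
+        insp = inspect(engine)  # returns an Inspector
+        print(insp.get_table_names())
+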
+ """
+ type_ = type(subject)
+ for cls in type_.__mro__:
+ if cls in _registrars:
+ reg = _registrars[cls]
+ if reg is True:
+ return subject
+ ret = reg(subject)
+ if ret is not None:
+ break
+ else:
+ reg = ret = None
+
+ if raiseerr and (reg is None or ret is None):
+ raise exc.NoInspectionAvailable(
+ "No inspection system is "
+ "available for object of type %s" % type_
+ )
+ return ret
+
+
+def _inspects(*types):
+ def decorate(fn_or_cls):
+ for type_ in types:
+ if type_ in _registrars:
+ raise AssertionError(
+                "Type %s is already registered" % type_
+ )
+ _registrars[type_] = fn_or_cls
+ return fn_or_cls
+
+ return decorate
+
+
+def _self_inspects(cls):
+ _inspects(cls)(True)
+ return cls
diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py
new file mode 100644
index 0000000..cc662ec
--- /dev/null
+++ b/lib/sqlalchemy/log.py
@@ -0,0 +1,241 @@
+# sqlalchemy/log.py
+# Copyright (C) 2006-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+# Includes alterations by Vinay Sajip vinay_sajip@yahoo.co.uk
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Logging control and utilities.
+
+Control of logging for SA can be performed from the regular python logging
+module. The regular dotted module namespace is used, starting at
+'sqlalchemy'. For class-level logging, the class name is appended.
+
+The "echo" keyword parameter, available on SQLA :class:`_engine.Engine`
+and :class:`_pool.Pool` objects, corresponds to a logger specific to that
+instance only.
+
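+For example, SQL statement logging for all engines can be enabled using the
+standard library ``logging`` module directly (a minimal sketch)::
+
+    import logging
+
+    logging.basicConfig()
+    logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)
+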
+"""
+
+import logging
+import sys
+
+from .util import py311
+from .util import py38
+
+if py38:
+ STACKLEVEL = True
+ # needed as of py3.11.0b1
+ # #8019
+ STACKLEVEL_OFFSET = 2 if py311 else 1
+else:
+ STACKLEVEL = False
+ STACKLEVEL_OFFSET = 0
+
+# set initial level to WARN.  This is so that
+# log statements don't occur in the absence of explicit
+# logging being enabled for 'sqlalchemy'.
+rootlogger = logging.getLogger("sqlalchemy")
+if rootlogger.level == logging.NOTSET:
+ rootlogger.setLevel(logging.WARN)
+
+
+def _add_default_handler(logger):
+ handler = logging.StreamHandler(sys.stdout)
+ handler.setFormatter(
+ logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
+ )
+ logger.addHandler(handler)
+
+
+_logged_classes = set()
+
+
+def _qual_logger_name_for_cls(cls):
+ return (
+ getattr(cls, "_sqla_logger_namespace", None)
+ or cls.__module__ + "." + cls.__name__
+ )
+
+
+def class_logger(cls):
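+    """Class decorator which assigns a module-qualified logger to ``cls``,
+    along with ``_should_log_debug()`` / ``_should_log_info()`` helpers."""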
+ logger = logging.getLogger(_qual_logger_name_for_cls(cls))
+ cls._should_log_debug = lambda self: logger.isEnabledFor(logging.DEBUG)
+ cls._should_log_info = lambda self: logger.isEnabledFor(logging.INFO)
+ cls.logger = logger
+ _logged_classes.add(cls)
+ return cls
+
+
+class Identified(object):
+ logging_name = None
+
+ def _should_log_debug(self):
+ return self.logger.isEnabledFor(logging.DEBUG)
+
+ def _should_log_info(self):
+ return self.logger.isEnabledFor(logging.INFO)
+
+
+class InstanceLogger(object):
+ """A logger adapter (wrapper) for :class:`.Identified` subclasses.
+
+ This allows multiple instances (e.g. Engine or Pool instances)
+ to share a logger, but have its verbosity controlled on a
+ per-instance basis.
+
+ The basic functionality is to return a logging level
+ which is based on an instance's echo setting.
+
+ Default implementation is:
+
+ 'debug' -> logging.DEBUG
+ True -> logging.INFO
+ False -> Effective level of underlying logger (
+ logging.WARNING by default)
+ None -> same as False
+ """
+
+ # Map echo settings to logger levels
+ _echo_map = {
+ None: logging.NOTSET,
+ False: logging.NOTSET,
+ True: logging.INFO,
+ "debug": logging.DEBUG,
+ }
+
+ def __init__(self, echo, name):
+ self.echo = echo
+ self.logger = logging.getLogger(name)
+
+ # if echo flag is enabled and no handlers,
+ # add a handler to the list
+ if self._echo_map[echo] <= logging.INFO and not self.logger.handlers:
+ _add_default_handler(self.logger)
+
+ #
+ # Boilerplate convenience methods
+ #
+ def debug(self, msg, *args, **kwargs):
+ """Delegate a debug call to the underlying logger."""
+
+ self.log(logging.DEBUG, msg, *args, **kwargs)
+
+ def info(self, msg, *args, **kwargs):
+ """Delegate an info call to the underlying logger."""
+
+ self.log(logging.INFO, msg, *args, **kwargs)
+
+ def warning(self, msg, *args, **kwargs):
+ """Delegate a warning call to the underlying logger."""
+
+ self.log(logging.WARNING, msg, *args, **kwargs)
+
+ warn = warning
+
+ def error(self, msg, *args, **kwargs):
+ """
+ Delegate an error call to the underlying logger.
+ """
+ self.log(logging.ERROR, msg, *args, **kwargs)
+
+ def exception(self, msg, *args, **kwargs):
+ """Delegate an exception call to the underlying logger."""
+
+ kwargs["exc_info"] = 1
+ self.log(logging.ERROR, msg, *args, **kwargs)
+
+ def critical(self, msg, *args, **kwargs):
+ """Delegate a critical call to the underlying logger."""
+
+ self.log(logging.CRITICAL, msg, *args, **kwargs)
+
+ def log(self, level, msg, *args, **kwargs):
+ """Delegate a log call to the underlying logger.
+
+ The level here is determined by the echo
+ flag as well as that of the underlying logger, and
+ logger._log() is called directly.
+
+ """
+
+ # inline the logic from isEnabledFor(),
+ # getEffectiveLevel(), to avoid overhead.
+
+ if self.logger.manager.disable >= level:
+ return
+
+ selected_level = self._echo_map[self.echo]
+ if selected_level == logging.NOTSET:
+ selected_level = self.logger.getEffectiveLevel()
+
+ if level >= selected_level:
+ if STACKLEVEL:
+ kwargs["stacklevel"] = (
+ kwargs.get("stacklevel", 1) + STACKLEVEL_OFFSET
+ )
+
+ self.logger._log(level, msg, args, **kwargs)
+
+ def isEnabledFor(self, level):
+ """Is this logger enabled for level 'level'?"""
+
+ if self.logger.manager.disable >= level:
+ return False
+ return level >= self.getEffectiveLevel()
+
+ def getEffectiveLevel(self):
+ """What's the effective level for this logger?"""
+
+ level = self._echo_map[self.echo]
+ if level == logging.NOTSET:
+ level = self.logger.getEffectiveLevel()
+ return level
+
+
+def instance_logger(instance, echoflag=None):
+    """Create a logger for an instance that implements :class:`.Identified`."""
+
+ if instance.logging_name:
+ name = "%s.%s" % (
+ _qual_logger_name_for_cls(instance.__class__),
+ instance.logging_name,
+ )
+ else:
+ name = _qual_logger_name_for_cls(instance.__class__)
+
+ instance._echo = echoflag
+
+ if echoflag in (False, None):
+ # if no echo setting or False, return a Logger directly,
+ # avoiding overhead of filtering
+ logger = logging.getLogger(name)
+ else:
+ # if a specified echo flag, return an EchoLogger,
+ # which checks the flag, overrides normal log
+ # levels by calling logger._log()
+ logger = InstanceLogger(echoflag, name)
+
+ instance.logger = logger
+
+
+class echo_property(object):
+ __doc__ = """\
+ When ``True``, enable log output for this element.
+
+ This has the effect of setting the Python logging level for the namespace
+ of this element's class and object reference. A value of boolean ``True``
+ indicates that the loglevel ``logging.INFO`` will be set for the logger,
+ whereas the string value ``debug`` will set the loglevel to
+ ``logging.DEBUG``.
+ """
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self
+ else:
+ return instance._echo
+
+ def __set__(self, instance, value):
+ instance_logger(instance, echoflag=value)
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
new file mode 100644
index 0000000..6e0de05
--- /dev/null
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -0,0 +1,344 @@
+# orm/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+Functional constructs for ORM configuration.
+
+See the SQLAlchemy object relational tutorial and mapper configuration
+documentation for an overview of how this module is used.
+
+"""
+
+from . import exc
+from . import mapper as mapperlib
+from . import strategy_options
+from .attributes import AttributeEvent
+from .attributes import InstrumentedAttribute
+from .attributes import Mapped
+from .attributes import QueryableAttribute
+from .context import QueryContext
+from .decl_api import as_declarative
+from .decl_api import declarative_base
+from .decl_api import declarative_mixin
+from .decl_api import DeclarativeMeta
+from .decl_api import declared_attr
+from .decl_api import has_inherited_table
+from .decl_api import registry
+from .decl_api import synonym_for
+from .descriptor_props import CompositeProperty
+from .descriptor_props import SynonymProperty
+from .identity import IdentityMap
+from .instrumentation import ClassManager
+from .interfaces import EXT_CONTINUE
+from .interfaces import EXT_SKIP
+from .interfaces import EXT_STOP
+from .interfaces import InspectionAttr
+from .interfaces import InspectionAttrInfo
+from .interfaces import MANYTOMANY
+from .interfaces import MANYTOONE
+from .interfaces import MapperProperty
+from .interfaces import NOT_EXTENSION
+from .interfaces import ONETOMANY
+from .interfaces import PropComparator
+from .interfaces import UserDefinedOption
+from .loading import merge_frozen_result
+from .loading import merge_result
+from .mapper import class_mapper
+from .mapper import configure_mappers
+from .mapper import Mapper
+from .mapper import reconstructor
+from .mapper import validates
+from .properties import ColumnProperty
+from .query import AliasOption
+from .query import FromStatement
+from .query import Query
+from .relationships import foreign
+from .relationships import RelationshipProperty
+from .relationships import remote
+from .scoping import scoped_session
+from .session import close_all_sessions
+from .session import make_transient
+from .session import make_transient_to_detached
+from .session import object_session
+from .session import ORMExecuteState
+from .session import Session
+from .session import sessionmaker
+from .session import SessionTransaction
+from .state import AttributeState
+from .state import InstanceState
+from .strategy_options import Load
+from .unitofwork import UOWTransaction
+from .util import aliased
+from .util import Bundle
+from .util import CascadeOptions
+from .util import join
+from .util import LoaderCriteriaOption
+from .util import object_mapper
+from .util import outerjoin
+from .util import polymorphic_union
+from .util import was_deleted
+from .util import with_parent
+from .util import with_polymorphic
+from .. import sql as _sql
+from .. import util as _sa_util
+from ..util.langhelpers import public_factory
+
+
+def create_session(bind=None, **kwargs):
+ r"""Create a new :class:`.Session`
+ with no automation enabled by default.
+
+ This function is used primarily for testing. The usual
+ route to :class:`.Session` creation is via its constructor
+ or the :func:`.sessionmaker` function.
+
+ :param bind: optional, a single Connectable to use for all
+ database access in the created
+ :class:`~sqlalchemy.orm.session.Session`.
+
+ :param \*\*kwargs: optional, passed through to the
+ :class:`.Session` constructor.
+
+ :returns: an :class:`~sqlalchemy.orm.session.Session` instance
+
+    The defaults of create_session() are the opposite of those of
+    :func:`sessionmaker`; ``autoflush`` and ``expire_on_commit`` are
+    False, and ``autocommit`` is True. In this sense the session acts
+    more like the "classic" SQLAlchemy 0.3 session.
+
+ .. deprecated:: 1.4 The "autocommit" parameter will be removed in
+ SQLAlchemy 2.0. :func:`_orm.create_session` will return a
+    :class:`_orm.Session` that does not include "autocommit" behavior
+ in release 2.0.
+
+ Usage::
+
+ >>> from sqlalchemy.orm import create_session
+ >>> session = create_session()
+
+ It is recommended to use :func:`sessionmaker` instead of
+ create_session().
+
+ """
+
+ if kwargs.get("future", False):
+ kwargs.setdefault("autocommit", False)
+ else:
+ kwargs.setdefault("autocommit", True)
+
+ kwargs.setdefault("autoflush", False)
+ kwargs.setdefault("expire_on_commit", False)
+ return Session(bind=bind, **kwargs)
+
+
+with_loader_criteria = public_factory(LoaderCriteriaOption, ".orm")
+
+relationship = public_factory(RelationshipProperty, ".orm.relationship")
+
+
+@_sa_util.deprecated_20("relation", "Please use :func:`.relationship`.")
+def relation(*arg, **kw):
+ """A synonym for :func:`relationship`."""
+
+ return relationship(*arg, **kw)
+
+
+def dynamic_loader(argument, **kw):
+ """Construct a dynamically-loading mapper property.
+
+ This is essentially the same as
+ using the ``lazy='dynamic'`` argument with :func:`relationship`::
+
+ dynamic_loader(SomeClass)
+
+ # is the same as
+
+ relationship(SomeClass, lazy="dynamic")
+
+ See the section :ref:`dynamic_relationship` for more details
+ on dynamic loading.
+
+ """
+ kw["lazy"] = "dynamic"
+ return relationship(argument, **kw)
+
+
+column_property = public_factory(ColumnProperty, ".orm.column_property")
+composite = public_factory(CompositeProperty, ".orm.composite")
+
+
+def backref(name, **kwargs):
+ """When using the :paramref:`_orm.relationship.backref` parameter,
+ provides specific parameters to be used when the new
+ :func:`_orm.relationship` is generated.
+
+ E.g.::
+
+ 'items':relationship(
+ SomeItem, backref=backref('parent', lazy='subquery'))
+
+ The :paramref:`_orm.relationship.backref` parameter is generally
+ considered to be legacy; for modern applications, using
+ explicit :func:`_orm.relationship` constructs linked together using
+ the :paramref:`_orm.relationship.back_populates` parameter should be
+ preferred.
+
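+    An explicit two-way configuration using
+    :paramref:`_orm.relationship.back_populates`, sketched with a
+    hypothetical Declarative ``Base`` and ``Parent`` / ``Child`` classes,
+    might look like::
+
+        class Parent(Base):
+            __tablename__ = "parent"
+
+            id = Column(Integer, primary_key=True)
+            children = relationship("Child", back_populates="parent")
+
+        class Child(Base):
+            __tablename__ = "child"
+
+            id = Column(Integer, primary_key=True)
+            parent_id = Column(ForeignKey("parent.id"))
+            parent = relationship("Parent", back_populates="children")
+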
+ .. seealso::
+
+ :ref:`relationships_backref` - background on backrefs
+
+ """
+
+ return (name, kwargs)
+
+
+def deferred(*columns, **kw):
+ r"""Indicate a column-based mapped attribute that by default will
+ not load unless accessed.
+
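+    E.g., a sketch mapping a large text column as deferred, assuming a
+    hypothetical Declarative ``Base``::
+
+        class Book(Base):
+            __tablename__ = "book"
+
+            id = Column(Integer, primary_key=True)
+            summary = Column(String(2000))
+            excerpt = deferred(Column(Text))
+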
+ :param \*columns: columns to be mapped. This is typically a single
+ :class:`_schema.Column` object,
+ however a collection is supported in order
+ to support multiple columns mapped under the same attribute.
+
+ :param raiseload: boolean, if True, indicates an exception should be raised
+ if the load operation is to take place.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`deferred_raiseload`
+
+ :param \**kw: additional keyword arguments passed to
+ :class:`.ColumnProperty`.
+
+ .. seealso::
+
+ :ref:`deferred`
+
+ """
+ return ColumnProperty(deferred=True, *columns, **kw)
+
+
+def query_expression(default_expr=_sql.null()):
+ """Indicate an attribute that populates from a query-time SQL expression.
+
+ :param default_expr: Optional SQL expression object that will be used in
+ all cases if not assigned later with :func:`_orm.with_expression`.
+ E.g.::
+
+ from sqlalchemy.sql import literal
+
+ class C(Base):
+ #...
+ my_expr = query_expression(literal(1))
+
+ .. versionadded:: 1.3.18
+
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`mapper_querytime_expression`
+
+ """
+ prop = ColumnProperty(default_expr)
+ prop.strategy_key = (("query_expression", True),)
+ return prop
+
+
+mapper = public_factory(Mapper, ".orm.mapper")
+
+synonym = public_factory(SynonymProperty, ".orm.synonym")
+
+
+def clear_mappers():
+ """Remove all mappers from all classes.
+
+ .. versionchanged:: 1.4 This function now locates all
+ :class:`_orm.registry` objects and calls upon the
+ :meth:`_orm.registry.dispose` method of each.
+
+ This function removes all instrumentation from classes and disposes
+ of their associated mappers. Once called, the classes are unmapped
+ and can be later re-mapped with new mappers.
+
+ :func:`.clear_mappers` is *not* for normal use, as there is literally no
+ valid usage for it outside of very specific testing scenarios. Normally,
+ mappers are permanent structural components of user-defined classes, and
+ are never discarded independently of their class. If a mapped class
+ itself is garbage collected, its mapper is automatically disposed of as
+ well. As such, :func:`.clear_mappers` is only for usage in test suites
+ that re-use the same classes with different mappings, which is itself an
+ extremely rare use case - the only such use case is in fact SQLAlchemy's
+ own test suite, and possibly the test suites of other ORM extension
+ libraries which intend to test various combinations of mapper construction
+ upon a fixed set of classes.
+
+ """
+
+ mapperlib._dispose_registries(mapperlib._all_registries(), False)
+
+
+joinedload = strategy_options.joinedload._unbound_fn
+contains_eager = strategy_options.contains_eager._unbound_fn
+defer = strategy_options.defer._unbound_fn
+undefer = strategy_options.undefer._unbound_fn
+undefer_group = strategy_options.undefer_group._unbound_fn
+with_expression = strategy_options.with_expression._unbound_fn
+load_only = strategy_options.load_only._unbound_fn
+lazyload = strategy_options.lazyload._unbound_fn
+subqueryload = strategy_options.subqueryload._unbound_fn
+selectinload = strategy_options.selectinload._unbound_fn
+immediateload = strategy_options.immediateload._unbound_fn
+noload = strategy_options.noload._unbound_fn
+raiseload = strategy_options.raiseload._unbound_fn
+defaultload = strategy_options.defaultload._unbound_fn
+selectin_polymorphic = strategy_options.selectin_polymorphic._unbound_fn
+
+
+@_sa_util.deprecated_20("eagerload", "Please use :func:`_orm.joinedload`.")
+def eagerload(*args, **kwargs):
+ """A synonym for :func:`joinedload()`."""
+ return joinedload(*args, **kwargs)
+
+
+contains_alias = public_factory(AliasOption, ".orm.contains_alias")
+
+if True:
+ from .events import AttributeEvents
+ from .events import MapperEvents
+ from .events import InstanceEvents
+ from .events import InstrumentationEvents
+ from .events import QueryEvents
+ from .events import SessionEvents
+
+
+def __go(lcls):
+ global __all__
+ global AppenderQuery
+ from .. import util as sa_util
+ from . import dynamic
+ from . import events
+ from . import loading
+ import inspect as _inspect
+
+ from .dynamic import AppenderQuery
+
+ __all__ = sorted(
+ name
+ for name, obj in lcls.items()
+ if not (name.startswith("_") or _inspect.ismodule(obj))
+ )
+
+ _sa_util.preloaded.import_prefix("sqlalchemy.orm")
+ _sa_util.preloaded.import_prefix("sqlalchemy.ext")
+
+
+__go(locals())
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
new file mode 100644
index 0000000..efa20fb
--- /dev/null
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -0,0 +1,2331 @@
+# orm/attributes.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Defines instrumentation for class attributes and their interaction
+with instances.
+
+This module is usually not directly visible to user applications, but
+defines a large part of the ORM's interactivity.
+
+
+"""
+
+import operator
+
+from . import collections
+from . import exc as orm_exc
+from . import interfaces
+from .base import ATTR_EMPTY
+from .base import ATTR_WAS_SET
+from .base import CALLABLES_OK
+from .base import DEFERRED_HISTORY_LOAD
+from .base import INIT_OK
+from .base import instance_dict
+from .base import instance_state
+from .base import instance_str
+from .base import LOAD_AGAINST_COMMITTED
+from .base import manager_of_class
+from .base import NEVER_SET # noqa
+from .base import NO_AUTOFLUSH
+from .base import NO_CHANGE # noqa
+from .base import NO_RAISE
+from .base import NO_VALUE
+from .base import NON_PERSISTENT_OK # noqa
+from .base import PASSIVE_CLASS_MISMATCH # noqa
+from .base import PASSIVE_NO_FETCH
+from .base import PASSIVE_NO_FETCH_RELATED # noqa
+from .base import PASSIVE_NO_INITIALIZE
+from .base import PASSIVE_NO_RESULT
+from .base import PASSIVE_OFF
+from .base import PASSIVE_ONLY_PERSISTENT
+from .base import PASSIVE_RETURN_NO_VALUE
+from .base import RELATED_OBJECT_OK # noqa
+from .base import SQL_OK # noqa
+from .base import state_str
+from .. import event
+from .. import exc
+from .. import inspection
+from .. import util
+from ..sql import base as sql_base
+from ..sql import roles
+from ..sql import traversals
+from ..sql import visitors
+
+
+class NoKey(str):
+ pass
+
+
+NO_KEY = NoKey("no name")
+
+
+@inspection._self_inspects
+class QueryableAttribute(
+ interfaces._MappedAttribute,
+ interfaces.InspectionAttr,
+ interfaces.PropComparator,
+ traversals.HasCopyInternals,
+ roles.JoinTargetRole,
+ roles.OnClauseRole,
+ sql_base.Immutable,
+ sql_base.MemoizedHasCacheKey,
+):
+ """Base class for :term:`descriptor` objects that intercept
+ attribute events on behalf of a :class:`.MapperProperty`
+ object. The actual :class:`.MapperProperty` is accessible
+ via the :attr:`.QueryableAttribute.property`
+ attribute.
+
+
+ .. seealso::
+
+ :class:`.InstrumentedAttribute`
+
+ :class:`.MapperProperty`
+
+ :attr:`_orm.Mapper.all_orm_descriptors`
+
+ :attr:`_orm.Mapper.attrs`
+ """
+
+ is_attribute = True
+
+ # PropComparator has a __visit_name__ to participate within
+ # traversals. Disambiguate the attribute vs. a comparator.
+ __visit_name__ = "orm_instrumented_attribute"
+
+ def __init__(
+ self,
+ class_,
+ key,
+ parententity,
+ impl=None,
+ comparator=None,
+ of_type=None,
+ extra_criteria=(),
+ ):
+ self.class_ = class_
+ self.key = key
+ self._parententity = parententity
+ self.impl = impl
+ self.comparator = comparator
+ self._of_type = of_type
+ self._extra_criteria = extra_criteria
+
+ manager = manager_of_class(class_)
+ # manager is None in the case of AliasedClass
+ if manager:
+ # propagate existing event listeners from
+ # immediate superclass
+ for base in manager._bases:
+ if key in base:
+ self.dispatch._update(base[key].dispatch)
+ if base[key].dispatch._active_history:
+ self.dispatch._active_history = True
+
+ _cache_key_traversal = [
+ ("key", visitors.ExtendedInternalTraversal.dp_string),
+ ("_parententity", visitors.ExtendedInternalTraversal.dp_multi),
+ ("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+ ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
+ ]
+
+ def __reduce__(self):
+ # this method is only used in terms of the
+ # sqlalchemy.ext.serializer extension
+ return (
+ _queryable_attribute_unreduce,
+ (
+ self.key,
+ self._parententity.mapper.class_,
+ self._parententity,
+ self._parententity.entity,
+ ),
+ )
+
+ @util.memoized_property
+ def _supports_population(self):
+ return self.impl.supports_population
+
+ @property
+ def _impl_uses_objects(self):
+ return self.impl.uses_objects
+
+ def get_history(self, instance, passive=PASSIVE_OFF):
+ return self.impl.get_history(
+ instance_state(instance), instance_dict(instance), passive
+ )
+
+ @util.memoized_property
+ def info(self):
+ """Return the 'info' dictionary for the underlying SQL element.
+
+ The behavior here is as follows:
+
+ * If the attribute is a column-mapped property, i.e.
+ :class:`.ColumnProperty`, which is mapped directly
+ to a schema-level :class:`_schema.Column` object, this attribute
+ will return the :attr:`.SchemaItem.info` dictionary associated
+ with the core-level :class:`_schema.Column` object.
+
+ * If the attribute is a :class:`.ColumnProperty` but is mapped to
+ any other kind of SQL expression other than a
+ :class:`_schema.Column`,
+ the attribute will refer to the :attr:`.MapperProperty.info`
+ dictionary associated directly with the :class:`.ColumnProperty`,
+ assuming the SQL expression itself does not have its own ``.info``
+ attribute (which should be the case, unless a user-defined SQL
+ construct has defined one).
+
+ * If the attribute refers to any other kind of
+ :class:`.MapperProperty`, including :class:`.RelationshipProperty`,
+ the attribute will refer to the :attr:`.MapperProperty.info`
+ dictionary associated with that :class:`.MapperProperty`.
+
+ * To access the :attr:`.MapperProperty.info` dictionary of the
+ :class:`.MapperProperty` unconditionally, including for a
+ :class:`.ColumnProperty` that's associated directly with a
+ :class:`_schema.Column`, the attribute can be referred to using
+ :attr:`.QueryableAttribute.property` attribute, as
+ ``MyClass.someattribute.property.info``.
+
+ .. seealso::
+
+ :attr:`.SchemaItem.info`
+
+ :attr:`.MapperProperty.info`
+
+ """
+ return self.comparator.info
+
+ @util.memoized_property
+ def parent(self):
+ """Return an inspection instance representing the parent.
+
+ This will be either an instance of :class:`_orm.Mapper`
+ or :class:`.AliasedInsp`, depending upon the nature
+ of the parent entity which this attribute is associated
+ with.
+
+ """
+ return inspection.inspect(self._parententity)
+
+ @util.memoized_property
+ def expression(self):
+ """The SQL expression object represented by this
+ :class:`.QueryableAttribute`.
+
+ This will typically be an instance of a :class:`_sql.ColumnElement`
+ subclass representing a column expression.
+
+ """
+ if self.key is NO_KEY:
+ annotations = {"entity_namespace": self._entity_namespace}
+ else:
+ annotations = {
+ "proxy_key": self.key,
+ "proxy_owner": self._parententity,
+ "entity_namespace": self._entity_namespace,
+ }
+
+ ce = self.comparator.__clause_element__()
+ try:
+ anno = ce._annotate
+ except AttributeError as ae:
+ util.raise_(
+ exc.InvalidRequestError(
+ 'When interpreting attribute "%s" as a SQL expression, '
+ "expected __clause_element__() to return "
+ "a ClauseElement object, got: %r" % (self, ce)
+ ),
+ from_=ae,
+ )
+ else:
+ return anno(annotations)
+
+ @property
+ def _entity_namespace(self):
+ return self._parententity
+
+ @property
+ def _annotations(self):
+ return self.__clause_element__()._annotations
+
+ def __clause_element__(self):
+ return self.expression
+
+ @property
+ def _from_objects(self):
+ return self.expression._from_objects
+
+ def _bulk_update_tuples(self, value):
+ """Return setter tuples for a bulk UPDATE."""
+
+ return self.comparator._bulk_update_tuples(value)
+
+ def adapt_to_entity(self, adapt_to_entity):
+ assert not self._of_type
+ return self.__class__(
+ adapt_to_entity.entity,
+ self.key,
+ impl=self.impl,
+ comparator=self.comparator.adapt_to_entity(adapt_to_entity),
+ parententity=adapt_to_entity,
+ )
+
+ def of_type(self, entity):
+ return QueryableAttribute(
+ self.class_,
+ self.key,
+ self._parententity,
+ impl=self.impl,
+ comparator=self.comparator.of_type(entity),
+ of_type=inspection.inspect(entity),
+ extra_criteria=self._extra_criteria,
+ )
+
+ def and_(self, *other):
+ return QueryableAttribute(
+ self.class_,
+ self.key,
+ self._parententity,
+ impl=self.impl,
+ comparator=self.comparator.and_(*other),
+ of_type=self._of_type,
+ extra_criteria=self._extra_criteria + other,
+ )
+
+ def _clone(self, **kw):
+ return QueryableAttribute(
+ self.class_,
+ self.key,
+ self._parententity,
+ impl=self.impl,
+ comparator=self.comparator,
+ of_type=self._of_type,
+ extra_criteria=self._extra_criteria,
+ )
+
+ def label(self, name):
+ return self.__clause_element__().label(name)
+
+ def operate(self, op, *other, **kwargs):
+ return op(self.comparator, *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ return op(other, self.comparator, **kwargs)
+
+ def hasparent(self, state, optimistic=False):
+ return self.impl.hasparent(state, optimistic=optimistic) is not False
+
+ def __getattr__(self, key):
+ try:
+ return getattr(self.comparator, key)
+ except AttributeError as err:
+ util.raise_(
+ AttributeError(
+ "Neither %r object nor %r object associated with %s "
+ "has an attribute %r"
+ % (
+ type(self).__name__,
+ type(self.comparator).__name__,
+ self,
+ key,
+ )
+ ),
+ replace_context=err,
+ )
+
+ def __str__(self):
+ return "%s.%s" % (self.class_.__name__, self.key)
+
+ @util.memoized_property
+ def property(self):
+ """Return the :class:`.MapperProperty` associated with this
+ :class:`.QueryableAttribute`.
+
+
+ Return values here will commonly be instances of
+ :class:`.ColumnProperty` or :class:`.RelationshipProperty`.
+
+
+ """
+ return self.comparator.property
+
+
+def _queryable_attribute_unreduce(key, mapped_class, parententity, entity):
+ # this method is only used in terms of the
+ # sqlalchemy.ext.serializer extension
+ if parententity.is_aliased_class:
+ return entity._get_from_serialized(key, mapped_class, parententity)
+ else:
+ return getattr(entity, key)
+
+
+if util.py3k:
+ from typing import TypeVar, Generic
+
+ _T = TypeVar("_T")
+ _Generic_T = Generic[_T]
+else:
+ _Generic_T = type("_Generic_T", (), {})
+
+
+class Mapped(QueryableAttribute, _Generic_T):
+ """Represent an ORM mapped :term:`descriptor` attribute for typing
+ purposes.
+
+ This class represents the complete descriptor interface for any class
+ attribute that will have been :term:`instrumented` by the ORM
+ :class:`_orm.Mapper` class. When used with typing stubs, it is the final
+ type that would be used by a type checker such as mypy to provide the full
+ behavioral contract for the attribute.
+
+ .. tip::
+
+ The :class:`_orm.Mapped` class represents attributes that are handled
+ directly by the :class:`_orm.Mapper` class. It does not include other
+ Python descriptor classes that are provided as extensions, including
+ :ref:`hybrids_toplevel` and the :ref:`associationproxy_toplevel`.
+ While these systems still make use of ORM-specific superclasses
+ and structures, they are not :term:`instrumented` by the
+ :class:`_orm.Mapper` and instead provide their own functionality
+ when they are accessed on a class.
+
+ When using the :ref:`SQLAlchemy Mypy plugin <mypy_toplevel>`, the
+ :class:`_orm.Mapped` construct is used in typing annotations to indicate to
+ the plugin those attributes that are expected to be mapped; the plugin also
+ applies :class:`_orm.Mapped` as an annotation automatically when it scans
+ through declarative mappings in :ref:`orm_declarative_table` style. For
+ more indirect mapping styles such as
+ :ref:`imperative table <orm_imperative_table_configuration>` it is
+ typically applied explicitly to class level attributes that expect
+ to be mapped based on a given :class:`_schema.Table` configuration.
+
+ :class:`_orm.Mapped` is defined in the
+ `sqlalchemy2-stubs <https://pypi.org/project/sqlalchemy2-stubs>`_ project
+ as a :pep:`484` generic class which may subscribe to any arbitrary Python
+ type, which represents the Python type handled by the attribute::
+
+ class MyMappedClass(Base):
+        __table__ = Table(
+ "some_table", Base.metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("created_at", DateTime)
+ )
+
+        id: Mapped[int]
+ data: Mapped[str]
+ created_at: Mapped[datetime]
+
+ For complete background on how to use :class:`_orm.Mapped` with
+ pep-484 tools like Mypy, see the link below for background on SQLAlchemy's
+ Mypy plugin.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`mypy_toplevel` - complete background on Mypy integration
+
+ """
+
+ def __get__(self, instance, owner):
+ raise NotImplementedError()
+
+ def __set__(self, instance, value):
+ raise NotImplementedError()
+
+ def __delete__(self, instance):
+ raise NotImplementedError()
+
+
+class InstrumentedAttribute(Mapped):
+ """Class bound instrumented attribute which adds basic
+ :term:`descriptor` methods.
+
+ See :class:`.QueryableAttribute` for a description of most features.
+
+
+ """
+
+ inherit_cache = True
+
+ def __set__(self, instance, value):
+ self.impl.set(
+ instance_state(instance), instance_dict(instance), value, None
+ )
+
+ def __delete__(self, instance):
+ self.impl.delete(instance_state(instance), instance_dict(instance))
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self
+
+ dict_ = instance_dict(instance)
+ if self._supports_population and self.key in dict_:
+ return dict_[self.key]
+ else:
+ try:
+ state = instance_state(instance)
+ except AttributeError as err:
+ util.raise_(
+ orm_exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+ return self.impl.get(state, dict_)
+
+
+HasEntityNamespace = util.namedtuple(
+ "HasEntityNamespace", ["entity_namespace"]
+)
+HasEntityNamespace.is_mapper = HasEntityNamespace.is_aliased_class = False
+
+
+def create_proxied_attribute(descriptor):
+    """Create a QueryableAttribute / user descriptor hybrid.
+
+ Returns a new QueryableAttribute type that delegates descriptor
+ behavior and getattr() to the given descriptor.
+ """
+
+ # TODO: can move this to descriptor_props if the need for this
+ # function is removed from ext/hybrid.py
+
+ class Proxy(QueryableAttribute):
+ """Presents the :class:`.QueryableAttribute` interface as a
+ proxy on top of a Python descriptor / :class:`.PropComparator`
+ combination.
+
+ """
+
+ _extra_criteria = ()
+
+ def __init__(
+ self,
+ class_,
+ key,
+ descriptor,
+ comparator,
+ adapt_to_entity=None,
+ doc=None,
+ original_property=None,
+ ):
+ self.class_ = class_
+ self.key = key
+ self.descriptor = descriptor
+ self.original_property = original_property
+ self._comparator = comparator
+ self._adapt_to_entity = adapt_to_entity
+ self.__doc__ = doc
+
+ _is_internal_proxy = True
+
+ _cache_key_traversal = [
+ ("key", visitors.ExtendedInternalTraversal.dp_string),
+ ("_parententity", visitors.ExtendedInternalTraversal.dp_multi),
+ ]
+
+ @property
+ def _impl_uses_objects(self):
+ return (
+ self.original_property is not None
+ and getattr(self.class_, self.key).impl.uses_objects
+ )
+
+ @property
+ def _parententity(self):
+ return inspection.inspect(self.class_, raiseerr=False)
+
+ @property
+ def _entity_namespace(self):
+ if hasattr(self._comparator, "_parententity"):
+ return self._comparator._parententity
+ else:
+ # used by hybrid attributes which try to remain
+ # agnostic of any ORM concepts like mappers
+ return HasEntityNamespace(self.class_)
+
+ @property
+ def property(self):
+ return self.comparator.property
+
+ @util.memoized_property
+ def comparator(self):
+ if callable(self._comparator):
+ self._comparator = self._comparator()
+ if self._adapt_to_entity:
+ self._comparator = self._comparator.adapt_to_entity(
+ self._adapt_to_entity
+ )
+ return self._comparator
+
+ def adapt_to_entity(self, adapt_to_entity):
+ return self.__class__(
+ adapt_to_entity.entity,
+ self.key,
+ self.descriptor,
+ self._comparator,
+ adapt_to_entity,
+ )
+
+ def _clone(self, **kw):
+ return self.__class__(
+ self.class_,
+ self.key,
+ self.descriptor,
+ self._comparator,
+ adapt_to_entity=self._adapt_to_entity,
+ original_property=self.original_property,
+ )
+
+ def __get__(self, instance, owner):
+ retval = self.descriptor.__get__(instance, owner)
+ # detect if this is a plain Python @property, which just returns
+ # itself for class level access. If so, then return us.
+ # Otherwise, return the object returned by the descriptor.
+ if retval is self.descriptor and instance is None:
+ return self
+ else:
+ return retval
+
+ def __str__(self):
+ return "%s.%s" % (self.class_.__name__, self.key)
+
+ def __getattr__(self, attribute):
+ """Delegate __getattr__ to the original descriptor and/or
+ comparator."""
+ try:
+ return getattr(descriptor, attribute)
+ except AttributeError as err:
+ if attribute == "comparator":
+ util.raise_(
+ AttributeError("comparator"), replace_context=err
+ )
+ try:
+ # comparator itself might be unreachable
+ comparator = self.comparator
+ except AttributeError as err2:
+ util.raise_(
+ AttributeError(
+ "Neither %r object nor unconfigured comparator "
+ "object associated with %s has an attribute %r"
+ % (type(descriptor).__name__, self, attribute)
+ ),
+ replace_context=err2,
+ )
+ else:
+ try:
+ return getattr(comparator, attribute)
+ except AttributeError as err3:
+ util.raise_(
+ AttributeError(
+ "Neither %r object nor %r object "
+ "associated with %s has an attribute %r"
+ % (
+ type(descriptor).__name__,
+ type(comparator).__name__,
+ self,
+ attribute,
+ )
+ ),
+ replace_context=err3,
+ )
+
+ Proxy.__name__ = type(descriptor).__name__ + "Proxy"
+
+ util.monkeypatch_proxied_specials(
+ Proxy, type(descriptor), name="descriptor", from_instance=descriptor
+ )
+ return Proxy
+
+
+OP_REMOVE = util.symbol("REMOVE")
+OP_APPEND = util.symbol("APPEND")
+OP_REPLACE = util.symbol("REPLACE")
+OP_BULK_REPLACE = util.symbol("BULK_REPLACE")
+OP_MODIFIED = util.symbol("MODIFIED")
+
+
+class AttributeEvent(object):
+ """A token propagated throughout the course of a chain of attribute
+ events.
+
+ Serves as an indicator of the source of the event and also provides
+ a means of controlling propagation across a chain of attribute
+ operations.
+
+ The :class:`.Event` object is sent as the ``initiator`` argument
+ when dealing with events such as :meth:`.AttributeEvents.append`,
+ :meth:`.AttributeEvents.set`,
+ and :meth:`.AttributeEvents.remove`.
+
+ The :class:`.Event` object is currently interpreted by the backref
+ event handlers, and is used to control the propagation of operations
+ across two mutually-dependent attributes.
+
+ .. versionadded:: 0.9.0
+
+ :attribute impl: The :class:`.AttributeImpl` which is the current event
+ initiator.
+
+ :attribute op: The symbol :attr:`.OP_APPEND`, :attr:`.OP_REMOVE`,
+ :attr:`.OP_REPLACE`, or :attr:`.OP_BULK_REPLACE`, indicating the
+ source operation.
+
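+    A sketch of receiving this token in an attribute event listener,
+    assuming a hypothetical mapped class ``User`` with an ``addresses``
+    collection::
+
+        from sqlalchemy import event
+
+        @event.listens_for(User.addresses, "append")
+        def receive_append(target, value, initiator):
+            # "initiator" here is an AttributeEvent instance
+            print(initiator.op, initiator.key)
+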
+ """
+
+ __slots__ = "impl", "op", "parent_token"
+
+ def __init__(self, attribute_impl, op):
+ self.impl = attribute_impl
+ self.op = op
+ self.parent_token = self.impl.parent_token
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, AttributeEvent)
+ and other.impl is self.impl
+ and other.op == self.op
+ )
+
+ @property
+ def key(self):
+ return self.impl.key
+
+ def hasparent(self, state):
+ return self.impl.hasparent(state)
+
+
+Event = AttributeEvent
+
+
+class AttributeImpl(object):
+ """internal implementation for instrumented attributes."""
+
+ def __init__(
+ self,
+ class_,
+ key,
+ callable_,
+ dispatch,
+ trackparent=False,
+ compare_function=None,
+ active_history=False,
+ parent_token=None,
+ load_on_unexpire=True,
+ send_modified_events=True,
+ accepts_scalar_loader=None,
+ **kwargs
+ ):
+ r"""Construct an AttributeImpl.
+
+ :param \class_: associated class
+
+ :param key: string name of the attribute
+
+ :param \callable_:
+ optional function which generates a callable based on a parent
+ instance, which produces the "default" values for a scalar or
+ collection attribute when it's first accessed, if not present
+ already.
+
+ :param trackparent:
+ if True, attempt to track if an instance has a parent attached
+ to it via this attribute.
+
+ :param compare_function:
+ a function that compares two values which are normally
+ assignable to this attribute.
+
+ :param active_history:
+ indicates that get_history() should always return the "old" value,
+ even if it means executing a lazy callable upon attribute change.
+
+ :param parent_token:
+ Usually references the MapperProperty, used as a key for
+ the hasparent() function to identify an "owning" attribute.
+ Allows multiple AttributeImpls to all match a single
+ owner attribute.
+
+ :param load_on_unexpire:
+ if False, don't include this attribute in a load-on-expired
+ operation, i.e. the "expired_attribute_loader" process.
+ The attribute can still be in the "expired" list and be
+ considered to be "expired". Previously, this flag was called
+ "expire_missing" and is only used by a deferred column
+ attribute.
+
+ :param send_modified_events:
+ if False, the InstanceState._modified_event method will have no
+ effect; this means the attribute will never show up as changed in a
+ history entry.
+
+ """
+ self.class_ = class_
+ self.key = key
+ self.callable_ = callable_
+ self.dispatch = dispatch
+ self.trackparent = trackparent
+ self.parent_token = parent_token or self
+ self.send_modified_events = send_modified_events
+ if compare_function is None:
+ self.is_equal = operator.eq
+ else:
+ self.is_equal = compare_function
+
+ if accepts_scalar_loader is not None:
+ self.accepts_scalar_loader = accepts_scalar_loader
+ else:
+ self.accepts_scalar_loader = self.default_accepts_scalar_loader
+
+ _deferred_history = kwargs.pop("_deferred_history", False)
+ self._deferred_history = _deferred_history
+
+ if active_history:
+ self.dispatch._active_history = True
+
+ self.load_on_unexpire = load_on_unexpire
+ self._modified_token = Event(self, OP_MODIFIED)
+
+ __slots__ = (
+ "class_",
+ "key",
+ "callable_",
+ "dispatch",
+ "trackparent",
+ "parent_token",
+ "send_modified_events",
+ "is_equal",
+ "load_on_unexpire",
+ "_modified_token",
+ "accepts_scalar_loader",
+ "_deferred_history",
+ )
+
+ def __str__(self):
+ return "%s.%s" % (self.class_.__name__, self.key)
+
+ def _get_active_history(self):
+ """Backwards compat for impl.active_history"""
+
+ return self.dispatch._active_history
+
+ def _set_active_history(self, value):
+ self.dispatch._active_history = value
+
+ active_history = property(_get_active_history, _set_active_history)
+
+ def hasparent(self, state, optimistic=False):
+ """Return the boolean value of a `hasparent` flag attached to
+ the given state.
+
+ The `optimistic` flag determines what the default return value
+ should be if no `hasparent` flag can be located.
+
+ As this function is used to determine if an instance is an
+ *orphan*, instances that were loaded from storage should be
+ assumed to not be orphans, until a True/False value for this
+ flag is set.
+
+ An instance attribute that is loaded by a callable function
+ will also not have a `hasparent` flag.
+
+ """
+ msg = "This AttributeImpl is not configured to track parents."
+ assert self.trackparent, msg
+
+ return (
+ state.parents.get(id(self.parent_token), optimistic) is not False
+ )
+
+ def sethasparent(self, state, parent_state, value):
+ """Set a boolean flag on the given item corresponding to
+ whether or not it is attached to a parent object via the
+ attribute represented by this ``InstrumentedAttribute``.
+
+ """
+ msg = "This AttributeImpl is not configured to track parents."
+ assert self.trackparent, msg
+
+ id_ = id(self.parent_token)
+ if value:
+ state.parents[id_] = parent_state
+ else:
+ if id_ in state.parents:
+ last_parent = state.parents[id_]
+
+ if (
+ last_parent is not False
+ and last_parent.key != parent_state.key
+ ):
+
+ if last_parent.obj() is None:
+ raise orm_exc.StaleDataError(
+ "Removing state %s from parent "
+ "state %s along attribute '%s', "
+ "but the parent record "
+ "has gone stale, can't be sure this "
+ "is the most recent parent."
+ % (
+ state_str(state),
+ state_str(parent_state),
+ self.key,
+ )
+ )
+
+ return
+
+ state.parents[id_] = False
+
+ def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ raise NotImplementedError()
+
+ def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+ """Return a list of tuples of (state, obj)
+ for all objects in this attribute's current state
+ + history.
+
+ Only applies to object-based attributes.
+
+ This is an inlining of existing functionality
+ which roughly corresponds to:
+
+ get_state_history(
+ state,
+ key,
+ passive=PASSIVE_NO_INITIALIZE).sum()
+
+ """
+ raise NotImplementedError()
+
+ def _default_value(self, state, dict_):
+ """Produce an empty value for an uninitialized scalar attribute."""
+
+ assert self.key not in dict_, (
+ "_default_value should only be invoked for an "
+ "uninitialized or expired attribute"
+ )
+
+ value = None
+ for fn in self.dispatch.init_scalar:
+ ret = fn(state, value, dict_)
+ if ret is not ATTR_EMPTY:
+ value = ret
+
+ return value
+
+ def get(self, state, dict_, passive=PASSIVE_OFF):
+ """Retrieve a value from the given object.
+ If a callable is assembled on this object's attribute, and
+ passive is False, the callable will be executed and the
+ resulting value will be set as the new value for this attribute.
+ """
+ if self.key in dict_:
+ return dict_[self.key]
+ else:
+ # if history present, don't load
+ key = self.key
+ if (
+ key not in state.committed_state
+ or state.committed_state[key] is NO_VALUE
+ ):
+ if not passive & CALLABLES_OK:
+ return PASSIVE_NO_RESULT
+
+ value = self._fire_loader_callables(state, key, passive)
+
+ if value is PASSIVE_NO_RESULT or value is NO_VALUE:
+ return value
+ elif value is ATTR_WAS_SET:
+ try:
+ return dict_[key]
+ except KeyError as err:
+ # TODO: no test coverage here.
+ util.raise_(
+ KeyError(
+ "Deferred loader for attribute "
+ "%r failed to populate "
+ "correctly" % key
+ ),
+ replace_context=err,
+ )
+ elif value is not ATTR_EMPTY:
+ return self.set_committed_value(state, dict_, value)
+
+ if not passive & INIT_OK:
+ return NO_VALUE
+ else:
+ return self._default_value(state, dict_)
+
+ def _fire_loader_callables(self, state, key, passive):
+ if (
+ self.accepts_scalar_loader
+ and self.load_on_unexpire
+ and key in state.expired_attributes
+ ):
+ return state._load_expired(state, passive)
+ elif key in state.callables:
+ callable_ = state.callables[key]
+ return callable_(state, passive)
+ elif self.callable_:
+ return self.callable_(state, passive)
+ else:
+ return ATTR_EMPTY
+
+ def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ self.set(state, dict_, value, initiator, passive=passive)
+
+ def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ self.set(
+ state, dict_, None, initiator, passive=passive, check_old=value
+ )
+
+ def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ self.set(
+ state,
+ dict_,
+ None,
+ initiator,
+ passive=passive,
+ check_old=value,
+ pop=True,
+ )
+
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator,
+ passive=PASSIVE_OFF,
+ check_old=None,
+ pop=False,
+ ):
+ raise NotImplementedError()
+
+ def get_committed_value(self, state, dict_, passive=PASSIVE_OFF):
+ """return the unchanged value of this attribute"""
+
+ if self.key in state.committed_state:
+ value = state.committed_state[self.key]
+ if value is NO_VALUE:
+ return None
+ else:
+ return value
+ else:
+ return self.get(state, dict_, passive=passive)
+
+ def set_committed_value(self, state, dict_, value):
+ """set an attribute value on the given instance and 'commit' it."""
+
+ dict_[self.key] = value
+ state._commit(dict_, [self.key])
+ return value
+
+
+class ScalarAttributeImpl(AttributeImpl):
+ """represents a scalar value-holding InstrumentedAttribute."""
+
+ default_accepts_scalar_loader = True
+ uses_objects = False
+ supports_population = True
+ collection = False
+ dynamic = False
+
+ __slots__ = "_replace_token", "_append_token", "_remove_token"
+
+ def __init__(self, *arg, **kw):
+ super(ScalarAttributeImpl, self).__init__(*arg, **kw)
+ self._replace_token = self._append_token = Event(self, OP_REPLACE)
+ self._remove_token = Event(self, OP_REMOVE)
+
+ def delete(self, state, dict_):
+ if self.dispatch._active_history:
+ old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE)
+ else:
+ old = dict_.get(self.key, NO_VALUE)
+
+ if self.dispatch.remove:
+ self.fire_remove_event(state, dict_, old, self._remove_token)
+ state._modified_event(dict_, self, old)
+
+ existing = dict_.pop(self.key, NO_VALUE)
+ if (
+ existing is NO_VALUE
+ and old is NO_VALUE
+ and not state.expired
+ and self.key not in state.expired_attributes
+ ):
+ raise AttributeError("%s object does not have a value" % self)
+
+ def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ if self.key in dict_:
+ return History.from_scalar_attribute(self, state, dict_[self.key])
+ elif self.key in state.committed_state:
+ return History.from_scalar_attribute(self, state, NO_VALUE)
+ else:
+ if passive & INIT_OK:
+ passive ^= INIT_OK
+ current = self.get(state, dict_, passive=passive)
+ if current is PASSIVE_NO_RESULT:
+ return HISTORY_BLANK
+ else:
+ return History.from_scalar_attribute(self, state, current)
+
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator,
+ passive=PASSIVE_OFF,
+ check_old=None,
+ pop=False,
+ ):
+ if self.dispatch._active_history:
+ old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE)
+ else:
+ old = dict_.get(self.key, NO_VALUE)
+
+ if self.dispatch.set:
+ value = self.fire_replace_event(
+ state, dict_, value, old, initiator
+ )
+ state._modified_event(dict_, self, old)
+ dict_[self.key] = value
+
+ def fire_replace_event(self, state, dict_, value, previous, initiator):
+ for fn in self.dispatch.set:
+ value = fn(
+ state, value, previous, initiator or self._replace_token
+ )
+ return value
+
+ def fire_remove_event(self, state, dict_, value, initiator):
+ for fn in self.dispatch.remove:
+ fn(state, value, initiator or self._remove_token)
+
+ @property
+ def type(self):
+ return self.property.columns[0].type
+
+
+class ScalarObjectAttributeImpl(ScalarAttributeImpl):
+ """represents a scalar-holding InstrumentedAttribute,
+ where the target object is also instrumented.
+
+ Adds events to delete/set operations.
+
+ """
+
+ default_accepts_scalar_loader = False
+ uses_objects = True
+ supports_population = True
+ collection = False
+
+ __slots__ = ()
+
+ def delete(self, state, dict_):
+ if self.dispatch._active_history:
+ old = self.get(
+ state,
+ dict_,
+ passive=PASSIVE_ONLY_PERSISTENT
+ | NO_AUTOFLUSH
+ | LOAD_AGAINST_COMMITTED,
+ )
+ else:
+ old = self.get(
+ state,
+ dict_,
+ passive=PASSIVE_NO_FETCH ^ INIT_OK
+ | LOAD_AGAINST_COMMITTED
+ | NO_RAISE,
+ )
+
+ self.fire_remove_event(state, dict_, old, self._remove_token)
+
+ existing = dict_.pop(self.key, NO_VALUE)
+
+ # if the attribute is expired, we currently have no way to tell
+ # that an object-attribute was expired vs. not loaded. So
+ # for this test, we look to see if the object has a DB identity.
+ if (
+ existing is NO_VALUE
+ and old is not PASSIVE_NO_RESULT
+ and state.key is None
+ ):
+ raise AttributeError("%s object does not have a value" % self)
+
+ def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ if self.key in dict_:
+ current = dict_[self.key]
+ else:
+ if passive & INIT_OK:
+ passive ^= INIT_OK
+ current = self.get(state, dict_, passive=passive)
+ if current is PASSIVE_NO_RESULT:
+ return HISTORY_BLANK
+
+ if not self._deferred_history:
+ return History.from_object_attribute(self, state, current)
+ else:
+ original = state.committed_state.get(self.key, _NO_HISTORY)
+ if original is PASSIVE_NO_RESULT:
+
+ loader_passive = passive | (
+ PASSIVE_ONLY_PERSISTENT
+ | NO_AUTOFLUSH
+ | LOAD_AGAINST_COMMITTED
+ | NO_RAISE
+ | DEFERRED_HISTORY_LOAD
+ )
+ original = self._fire_loader_callables(
+ state, self.key, loader_passive
+ )
+ return History.from_object_attribute(
+ self, state, current, original=original
+ )
+
+ def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+ if self.key in dict_:
+ current = dict_[self.key]
+ elif passive & CALLABLES_OK:
+ current = self.get(state, dict_, passive=passive)
+ else:
+ return []
+
+ # can't use __hash__(), can't use __eq__() here
+ if (
+ current is not None
+ and current is not PASSIVE_NO_RESULT
+ and current is not NO_VALUE
+ ):
+ ret = [(instance_state(current), current)]
+ else:
+ ret = [(None, None)]
+
+ if self.key in state.committed_state:
+ original = state.committed_state[self.key]
+ if (
+ original is not None
+ and original is not PASSIVE_NO_RESULT
+ and original is not NO_VALUE
+ and original is not current
+ ):
+
+ ret.append((instance_state(original), original))
+ return ret
+
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator,
+ passive=PASSIVE_OFF,
+ check_old=None,
+ pop=False,
+ ):
+ """Set a value on the given InstanceState."""
+
+ if self.dispatch._active_history:
+ old = self.get(
+ state,
+ dict_,
+ passive=PASSIVE_ONLY_PERSISTENT
+ | NO_AUTOFLUSH
+ | LOAD_AGAINST_COMMITTED,
+ )
+ else:
+ old = self.get(
+ state,
+ dict_,
+ passive=PASSIVE_NO_FETCH ^ INIT_OK
+ | LOAD_AGAINST_COMMITTED
+ | NO_RAISE,
+ )
+
+ if (
+ check_old is not None
+ and old is not PASSIVE_NO_RESULT
+ and check_old is not old
+ ):
+ if pop:
+ return
+ else:
+ raise ValueError(
+ "Object %s not associated with %s on attribute '%s'"
+ % (instance_str(check_old), state_str(state), self.key)
+ )
+
+ value = self.fire_replace_event(state, dict_, value, old, initiator)
+ dict_[self.key] = value
+
+ def fire_remove_event(self, state, dict_, value, initiator):
+ if self.trackparent and value not in (
+ None,
+ PASSIVE_NO_RESULT,
+ NO_VALUE,
+ ):
+ self.sethasparent(instance_state(value), state, False)
+
+ for fn in self.dispatch.remove:
+ fn(state, value, initiator or self._remove_token)
+
+ state._modified_event(dict_, self, value)
+
+ def fire_replace_event(self, state, dict_, value, previous, initiator):
+ if self.trackparent:
+ if previous is not value and previous not in (
+ None,
+ PASSIVE_NO_RESULT,
+ NO_VALUE,
+ ):
+ self.sethasparent(instance_state(previous), state, False)
+
+ for fn in self.dispatch.set:
+ value = fn(
+ state, value, previous, initiator or self._replace_token
+ )
+
+ state._modified_event(dict_, self, previous)
+
+ if self.trackparent:
+ if value is not None:
+ self.sethasparent(instance_state(value), state, True)
+
+ return value
+
+
+class CollectionAttributeImpl(AttributeImpl):
+ """A collection-holding attribute that instruments changes in membership.
+
+ Only handles collections of instrumented objects.
+
+ InstrumentedCollectionAttribute holds an arbitrary, user-specified
+ container object (defaulting to a list) and brokers access to the
+ CollectionAdapter, a "view" onto that object that presents consistent bag
+ semantics to the orm layer independent of the user data implementation.
+
+ """
+
+ default_accepts_scalar_loader = False
+ uses_objects = True
+ supports_population = True
+ collection = True
+ dynamic = False
+
+ __slots__ = (
+ "copy",
+ "collection_factory",
+ "_append_token",
+ "_remove_token",
+ "_bulk_replace_token",
+ "_duck_typed_as",
+ )
+
+ def __init__(
+ self,
+ class_,
+ key,
+ callable_,
+ dispatch,
+ typecallable=None,
+ trackparent=False,
+ copy_function=None,
+ compare_function=None,
+ **kwargs
+ ):
+ super(CollectionAttributeImpl, self).__init__(
+ class_,
+ key,
+ callable_,
+ dispatch,
+ trackparent=trackparent,
+ compare_function=compare_function,
+ **kwargs
+ )
+
+ if copy_function is None:
+ copy_function = self.__copy
+ self.copy = copy_function
+ self.collection_factory = typecallable
+ self._append_token = Event(self, OP_APPEND)
+ self._remove_token = Event(self, OP_REMOVE)
+ self._bulk_replace_token = Event(self, OP_BULK_REPLACE)
+ self._duck_typed_as = util.duck_type_collection(
+ self.collection_factory()
+ )
+
+ if getattr(self.collection_factory, "_sa_linker", None):
+
+ @event.listens_for(self, "init_collection")
+ def link(target, collection, collection_adapter):
+ collection._sa_linker(collection_adapter)
+
+ @event.listens_for(self, "dispose_collection")
+ def unlink(target, collection, collection_adapter):
+ collection._sa_linker(None)
+
+ def __copy(self, item):
+ return [y for y in collections.collection_adapter(item)]
+
+ def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ current = self.get(state, dict_, passive=passive)
+ if current is PASSIVE_NO_RESULT:
+ return HISTORY_BLANK
+ else:
+ return History.from_collection(self, state, current)
+
+ def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+ # NOTE: passive is ignored here at the moment
+
+ if self.key not in dict_:
+ return []
+
+ current = dict_[self.key]
+ current = getattr(current, "_sa_adapter")
+
+ if self.key in state.committed_state:
+ original = state.committed_state[self.key]
+ if original is not NO_VALUE:
+ current_states = [
+ ((c is not None) and instance_state(c) or None, c)
+ for c in current
+ ]
+ original_states = [
+ ((c is not None) and instance_state(c) or None, c)
+ for c in original
+ ]
+
+ current_set = dict(current_states)
+ original_set = dict(original_states)
+
+ return (
+ [
+ (s, o)
+ for s, o in current_states
+ if s not in original_set
+ ]
+ + [(s, o) for s, o in current_states if s in original_set]
+ + [
+ (s, o)
+ for s, o in original_states
+ if s not in current_set
+ ]
+ )
+
+ return [(instance_state(o), o) for o in current]
+
+ def fire_append_event(self, state, dict_, value, initiator):
+ for fn in self.dispatch.append:
+ value = fn(state, value, initiator or self._append_token)
+
+ state._modified_event(dict_, self, NO_VALUE, True)
+
+ if self.trackparent and value is not None:
+ self.sethasparent(instance_state(value), state, True)
+
+ return value
+
+ def fire_append_wo_mutation_event(self, state, dict_, value, initiator):
+ for fn in self.dispatch.append_wo_mutation:
+ value = fn(state, value, initiator or self._append_token)
+
+ return value
+
+ def fire_pre_remove_event(self, state, dict_, initiator):
+ """A special event used for pop() operations.
+
+ The "remove" event needs to have the item to be removed passed to
+ it, but in the case of a pop from a set there is no way to access
+ the item before the operation. This event is therefore used for all pop()
+ operations (even though set.pop is the one where it is really needed).
+
+ """
+ state._modified_event(dict_, self, NO_VALUE, True)
+
+ def fire_remove_event(self, state, dict_, value, initiator):
+ if self.trackparent and value is not None:
+ self.sethasparent(instance_state(value), state, False)
+
+ for fn in self.dispatch.remove:
+ fn(state, value, initiator or self._remove_token)
+
+ state._modified_event(dict_, self, NO_VALUE, True)
+
+ def delete(self, state, dict_):
+ if self.key not in dict_:
+ return
+
+ state._modified_event(dict_, self, NO_VALUE, True)
+
+ collection = self.get_collection(state, state.dict)
+ collection.clear_with_event()
+
+ # the key is always present here because of the check above;
+ # delete() already returned early if the collection was not present.
+ del dict_[self.key]
+
+ def _default_value(self, state, dict_):
+ """Produce an empty collection for an un-initialized attribute"""
+
+ assert self.key not in dict_, (
+ "_default_value should only be invoked for an "
+ "uninitialized or expired attribute"
+ )
+
+ if self.key in state._empty_collections:
+ return state._empty_collections[self.key]
+
+ adapter, user_data = self._initialize_collection(state)
+ adapter._set_empty(user_data)
+ return user_data
+
+ def _initialize_collection(self, state):
+
+ adapter, collection = state.manager.initialize_collection(
+ self.key, state, self.collection_factory
+ )
+
+ self.dispatch.init_collection(state, collection, adapter)
+
+ return adapter, collection
+
+ def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ collection = self.get_collection(state, dict_, passive=passive)
+ if collection is PASSIVE_NO_RESULT:
+ value = self.fire_append_event(state, dict_, value, initiator)
+ assert (
+ self.key not in dict_
+ ), "Collection was loaded during event handling."
+ state._get_pending_mutation(self.key).append(value)
+ else:
+ collection.append_with_event(value, initiator)
+
+ def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ collection = self.get_collection(state, state.dict, passive=passive)
+ if collection is PASSIVE_NO_RESULT:
+ self.fire_remove_event(state, dict_, value, initiator)
+ assert (
+ self.key not in dict_
+ ), "Collection was loaded during event handling."
+ state._get_pending_mutation(self.key).remove(value)
+ else:
+ collection.remove_with_event(value, initiator)
+
+ def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ try:
+ # TODO: better solution here would be to add
+ # a "popper" role to collections.py to complement
+ # "remover".
+ self.remove(state, dict_, value, initiator, passive=passive)
+ except (ValueError, KeyError, IndexError):
+ pass
+
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator=None,
+ passive=PASSIVE_OFF,
+ check_old=None,
+ pop=False,
+ _adapt=True,
+ ):
+ iterable = orig_iterable = value
+
+ # pulling a new collection first so that an adaptation exception does
+ # not trigger a lazy load of the old collection.
+ new_collection, user_data = self._initialize_collection(state)
+ if _adapt:
+ if new_collection._converter is not None:
+ iterable = new_collection._converter(iterable)
+ else:
+ setting_type = util.duck_type_collection(iterable)
+ receiving_type = self._duck_typed_as
+
+ if setting_type is not receiving_type:
+ given = (
+ iterable is None
+ and "None"
+ or iterable.__class__.__name__
+ )
+ wanted = self._duck_typed_as.__name__
+ raise TypeError(
+ "Incompatible collection type: %s is not %s-like"
+ % (given, wanted)
+ )
+
+ # If the object is an adapted collection, return the (iterable)
+ # adapter.
+ if hasattr(iterable, "_sa_iterator"):
+ iterable = iterable._sa_iterator()
+ elif setting_type is dict:
+ if util.py3k:
+ iterable = iterable.values()
+ else:
+ iterable = getattr(
+ iterable, "itervalues", iterable.values
+ )()
+ else:
+ iterable = iter(iterable)
+ new_values = list(iterable)
+
+ evt = self._bulk_replace_token
+
+ self.dispatch.bulk_replace(state, new_values, evt)
+
+ old = self.get(state, dict_, passive=PASSIVE_ONLY_PERSISTENT)
+ if old is PASSIVE_NO_RESULT:
+ old = self._default_value(state, dict_)
+ elif old is orig_iterable:
+ # ignore re-assignment of the current collection, as happens
+ # implicitly with in-place operators (foo.collection |= other)
+ return
+
+ # place a copy of "old" in state.committed_state
+ state._modified_event(dict_, self, old, True)
+
+ old_collection = old._sa_adapter
+
+ dict_[self.key] = user_data
+
+ collections.bulk_replace(
+ new_values, old_collection, new_collection, initiator=evt
+ )
+
+ self._dispose_previous_collection(state, old, old_collection, True)
+
+ def _dispose_previous_collection(
+ self, state, collection, adapter, fire_event
+ ):
+ del collection._sa_adapter
+
+ # the old collection is being discarded, so make sure it is no longer
+ # referenced by the state's empty-collections registry.
+ state._empty_collections.pop(self.key, None)
+ if fire_event:
+ self.dispatch.dispose_collection(state, collection, adapter)
+
+ def _invalidate_collection(self, collection):
+ adapter = getattr(collection, "_sa_adapter")
+ adapter.invalidated = True
+
+ def set_committed_value(self, state, dict_, value):
+ """Set an attribute value on the given instance and 'commit' it."""
+
+ collection, user_data = self._initialize_collection(state)
+
+ if value:
+ collection.append_multiple_without_event(value)
+
+ state.dict[self.key] = user_data
+
+ state._commit(dict_, [self.key])
+
+ if self.key in state._pending_mutations:
+ # pending items exist. issue a modified event,
+ # add/remove new items.
+ state._modified_event(dict_, self, user_data, True)
+
+ pending = state._pending_mutations.pop(self.key)
+ added = pending.added_items
+ removed = pending.deleted_items
+ for item in added:
+ collection.append_without_event(item)
+ for item in removed:
+ collection.remove_without_event(item)
+
+ return user_data
+
+ def get_collection(
+ self, state, dict_, user_data=None, passive=PASSIVE_OFF
+ ):
+ """Retrieve the CollectionAdapter associated with the given state.
+
+ if user_data is None, retrieves it from the state using normal
+ "get()" rules, which will fire lazy callables or return the "empty"
+ collection value.
+
+ """
+ if user_data is None:
+ user_data = self.get(state, dict_, passive=passive)
+ if user_data is PASSIVE_NO_RESULT:
+ return user_data
+
+ return user_data._sa_adapter
+
+
+def backref_listeners(attribute, key, uselist):
+ """Apply listeners to synchronize a two-way relationship."""
+
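+ # Illustrative effect (hypothetical one-to-many mapping such as
+ # Parent.children <-> Child.parent with a backref; names are not from
+ # this module):
+ #
+ #     parent.children.append(child)  # "append" listener sets child.parent
+ #     child.parent = None            # "set" listener pops child from
+ #                                    # parent.children
+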
+ # use easily recognizable names for stack traces.
+
+ # in the sections marked "tokens to test for a recursive loop",
+ # this is somewhat brittle and very performance-sensitive logic
+ # that is specific to how we might arrive at each event. a marker
+ # that can target us directly to arguments being invoked against
+ # the impl might be simpler, but could interfere with other systems.
+
+ parent_token = attribute.impl.parent_token
+ parent_impl = attribute.impl
+
+ def _acceptable_key_err(child_state, initiator, child_impl):
+ raise ValueError(
+ "Bidirectional attribute conflict detected: "
+ 'Passing object %s to attribute "%s" '
+ 'triggers a modify event on attribute "%s" '
+ 'via the backref "%s".'
+ % (
+ state_str(child_state),
+ initiator.parent_token,
+ child_impl.parent_token,
+ attribute.impl.parent_token,
+ )
+ )
+
+ def emit_backref_from_scalar_set_event(state, child, oldchild, initiator):
+ if oldchild is child:
+ return child
+ if (
+ oldchild is not None
+ and oldchild is not PASSIVE_NO_RESULT
+ and oldchild is not NO_VALUE
+ ):
+ # With lazy=None, there's no guarantee that the full collection is
+ # present when updating via a backref.
+ old_state, old_dict = (
+ instance_state(oldchild),
+ instance_dict(oldchild),
+ )
+ impl = old_state.manager[key].impl
+
+ # tokens to test for a recursive loop.
+ if not impl.collection and not impl.dynamic:
+ check_recursive_token = impl._replace_token
+ else:
+ check_recursive_token = impl._remove_token
+
+ if initiator is not check_recursive_token:
+ impl.pop(
+ old_state,
+ old_dict,
+ state.obj(),
+ parent_impl._append_token,
+ passive=PASSIVE_NO_FETCH,
+ )
+
+ if child is not None:
+ child_state, child_dict = (
+ instance_state(child),
+ instance_dict(child),
+ )
+ child_impl = child_state.manager[key].impl
+
+ if (
+ initiator.parent_token is not parent_token
+ and initiator.parent_token is not child_impl.parent_token
+ ):
+ _acceptable_key_err(state, initiator, child_impl)
+
+ # tokens to test for a recursive loop.
+ check_append_token = child_impl._append_token
+ check_bulk_replace_token = (
+ child_impl._bulk_replace_token
+ if child_impl.collection
+ else None
+ )
+
+ if (
+ initiator is not check_append_token
+ and initiator is not check_bulk_replace_token
+ ):
+ child_impl.append(
+ child_state,
+ child_dict,
+ state.obj(),
+ initiator,
+ passive=PASSIVE_NO_FETCH,
+ )
+ return child
+
+ def emit_backref_from_collection_append_event(state, child, initiator):
+ if child is None:
+ return
+
+ child_state, child_dict = instance_state(child), instance_dict(child)
+ child_impl = child_state.manager[key].impl
+
+ if (
+ initiator.parent_token is not parent_token
+ and initiator.parent_token is not child_impl.parent_token
+ ):
+ _acceptable_key_err(state, initiator, child_impl)
+
+ # tokens to test for a recursive loop.
+ check_append_token = child_impl._append_token
+ check_bulk_replace_token = (
+ child_impl._bulk_replace_token if child_impl.collection else None
+ )
+
+ if (
+ initiator is not check_append_token
+ and initiator is not check_bulk_replace_token
+ ):
+ child_impl.append(
+ child_state,
+ child_dict,
+ state.obj(),
+ initiator,
+ passive=PASSIVE_NO_FETCH,
+ )
+ return child
+
+ def emit_backref_from_collection_remove_event(state, child, initiator):
+ if (
+ child is not None
+ and child is not PASSIVE_NO_RESULT
+ and child is not NO_VALUE
+ ):
+ child_state, child_dict = (
+ instance_state(child),
+ instance_dict(child),
+ )
+ child_impl = child_state.manager[key].impl
+
+ # tokens to test for a recursive loop.
+ if not child_impl.collection and not child_impl.dynamic:
+ check_remove_token = child_impl._remove_token
+ check_replace_token = child_impl._replace_token
+ check_for_dupes_on_remove = uselist and not parent_impl.dynamic
+ else:
+ check_remove_token = child_impl._remove_token
+ check_replace_token = (
+ child_impl._bulk_replace_token
+ if child_impl.collection
+ else None
+ )
+ check_for_dupes_on_remove = False
+
+ if (
+ initiator is not check_remove_token
+ and initiator is not check_replace_token
+ ):
+
+ if not check_for_dupes_on_remove or not util.has_dupes(
+ # when this event is called, the item is usually
+ # present in the list, except for a pop() operation.
+ state.dict[parent_impl.key],
+ child,
+ ):
+ child_impl.pop(
+ child_state,
+ child_dict,
+ state.obj(),
+ initiator,
+ passive=PASSIVE_NO_FETCH,
+ )
+
+ if uselist:
+ event.listen(
+ attribute,
+ "append",
+ emit_backref_from_collection_append_event,
+ retval=True,
+ raw=True,
+ )
+ else:
+ event.listen(
+ attribute,
+ "set",
+ emit_backref_from_scalar_set_event,
+ retval=True,
+ raw=True,
+ )
+ # TODO: need coverage in test/orm/ of remove event
+ event.listen(
+ attribute,
+ "remove",
+ emit_backref_from_collection_remove_event,
+ retval=True,
+ raw=True,
+ )
+
+
+_NO_HISTORY = util.symbol("NO_HISTORY")
+_NO_STATE_SYMBOLS = frozenset([id(PASSIVE_NO_RESULT), id(NO_VALUE)])
+
+
+class History(util.namedtuple("History", ["added", "unchanged", "deleted"])):
+ """A 3-tuple of added, unchanged and deleted values,
+ representing the changes which have occurred on an instrumented
+ attribute.
+
+ The easiest way to get a :class:`.History` object for a particular
+ attribute on an object is to use the :func:`_sa.inspect` function::
+
+ from sqlalchemy import inspect
+
+ hist = inspect(myobject).attrs.myattribute.history
+
+ Each tuple member is an iterable sequence:
+
+ * ``added`` - the collection of items added to the attribute (the first
+ tuple element).
+
+ * ``unchanged`` - the collection of items that have not changed on the
+ attribute (the second tuple element).
+
+ * ``deleted`` - the collection of items that have been removed from the
+ attribute (the third tuple element).
+
+ """
+
+ def __bool__(self):
+ return self != HISTORY_BLANK
+
+ __nonzero__ = __bool__
+
+ def empty(self):
+ """Return True if this :class:`.History` has no changes
+ and no existing, unchanged state.
+
+ """
+
+ return not bool((self.added or self.deleted) or self.unchanged)
+
+ def sum(self):
+ """Return a collection of added + unchanged + deleted."""
+
+ return (
+ (self.added or []) + (self.unchanged or []) + (self.deleted or [])
+ )
+
+ def non_deleted(self):
+ """Return a collection of added + unchanged."""
+
+ return (self.added or []) + (self.unchanged or [])
+
+ def non_added(self):
+ """Return a collection of unchanged + deleted."""
+
+ return (self.unchanged or []) + (self.deleted or [])
+
+ def has_changes(self):
+ """Return True if this :class:`.History` has changes."""
+
+ return bool(self.added or self.deleted)
+
+ def as_state(self):
+ return History(
+ [
+ (c is not None) and instance_state(c) or None
+ for c in self.added
+ ],
+ [
+ (c is not None) and instance_state(c) or None
+ for c in self.unchanged
+ ],
+ [
+ (c is not None) and instance_state(c) or None
+ for c in self.deleted
+ ],
+ )
+
+ @classmethod
+ def from_scalar_attribute(cls, attribute, state, current):
+ original = state.committed_state.get(attribute.key, _NO_HISTORY)
+
+ if original is _NO_HISTORY:
+ if current is NO_VALUE:
+ return cls((), (), ())
+ else:
+ return cls((), [current], ())
+ # don't let ClauseElement expressions here trip things up
+ elif (
+ current is not NO_VALUE
+ and attribute.is_equal(current, original) is True
+ ):
+ return cls((), [current], ())
+ else:
+ # current convention on native scalars is to not
+ # include information
+ # about missing previous value in "deleted", but
+ # we do include None, which helps in some primary
+ # key situations
+ if id(original) in _NO_STATE_SYMBOLS:
+ deleted = ()
+ # indicate a "del" operation occurred when we don't have
+ # the previous value as: ([None], (), ())
+ if id(current) in _NO_STATE_SYMBOLS:
+ current = None
+ else:
+ deleted = [original]
+ if current is NO_VALUE:
+ return cls((), (), deleted)
+ else:
+ return cls([current], (), deleted)
+
+ @classmethod
+ def from_object_attribute(
+ cls, attribute, state, current, original=_NO_HISTORY
+ ):
+ if original is _NO_HISTORY:
+ original = state.committed_state.get(attribute.key, _NO_HISTORY)
+
+ if original is _NO_HISTORY:
+ if current is NO_VALUE:
+ return cls((), (), ())
+ else:
+ return cls((), [current], ())
+ elif current is original and current is not NO_VALUE:
+ return cls((), [current], ())
+ else:
+ # current convention on related objects is to not
+ # include information
+ # about missing previous value in "deleted", and
+ # to also not include None - the dependency.py rules
+ # ignore the None in any case.
+ if id(original) in _NO_STATE_SYMBOLS or original is None:
+ deleted = ()
+ # indicate a "del" operation occurred when we don't have
+ # the previous value as: ([None], (), ())
+ if id(current) in _NO_STATE_SYMBOLS:
+ current = None
+ else:
+ deleted = [original]
+ if current is NO_VALUE:
+ return cls((), (), deleted)
+ else:
+ return cls([current], (), deleted)
+
+ @classmethod
+ def from_collection(cls, attribute, state, current):
+ original = state.committed_state.get(attribute.key, _NO_HISTORY)
+ if current is NO_VALUE:
+ return cls((), (), ())
+
+ current = getattr(current, "_sa_adapter")
+ if original is NO_VALUE:
+ return cls(list(current), (), ())
+ elif original is _NO_HISTORY:
+ return cls((), list(current), ())
+ else:
+
+ current_states = [
+ ((c is not None) and instance_state(c) or None, c)
+ for c in current
+ ]
+ original_states = [
+ ((c is not None) and instance_state(c) or None, c)
+ for c in original
+ ]
+
+ current_set = dict(current_states)
+ original_set = dict(original_states)
+
+ return cls(
+ [o for s, o in current_states if s not in original_set],
+ [o for s, o in current_states if s in original_set],
+ [o for s, o in original_states if s not in current_set],
+ )
+
+
+HISTORY_BLANK = History(None, None, None)
+
+
+def get_history(obj, key, passive=PASSIVE_OFF):
+ """Return a :class:`.History` record for the given object
+ and attribute key.
+
+ This is the **pre-flush** history for a given attribute, which is
+ reset each time the :class:`.Session` flushes changes to the
+ current database transaction.
+
+ .. note::
+
+ Prefer to use the :attr:`.AttributeState.history` and
+ :meth:`.AttributeState.load_history` accessors to retrieve the
+ :class:`.History` for instance attributes.
+
+
+ :param obj: an object whose class is instrumented by the
+ attributes package.
+
+ :param key: string attribute name.
+
+ :param passive: indicates loading behavior for the attribute
+ if the value is not already present. This is a
+ bitflag attribute, which defaults to the symbol
+ :attr:`.PASSIVE_OFF` indicating all necessary SQL
+ should be emitted.
+
+ .. seealso::
+
+ :attr:`.AttributeState.history`
+
+ :meth:`.AttributeState.load_history` - retrieve history
+ using loader callables if the value is not locally present.
+
+ """
+
+ return get_state_history(instance_state(obj), key, passive)
+
+
+def get_state_history(state, key, passive=PASSIVE_OFF):
+ return state.get_history(key, passive)
+
+
+def has_parent(cls, obj, key, optimistic=False):
+ """TODO"""
+ manager = manager_of_class(cls)
+ state = instance_state(obj)
+ return manager.has_parent(state, key, optimistic)
+
+
+def register_attribute(class_, key, **kw):
+ comparator = kw.pop("comparator", None)
+ parententity = kw.pop("parententity", None)
+ doc = kw.pop("doc", None)
+ desc = register_descriptor(class_, key, comparator, parententity, doc=doc)
+ register_attribute_impl(class_, key, **kw)
+ return desc
+
+
+def register_attribute_impl(
+ class_,
+ key,
+ uselist=False,
+ callable_=None,
+ useobject=False,
+ impl_class=None,
+ backref=None,
+ **kw
+):
+
+ manager = manager_of_class(class_)
+ if uselist:
+ factory = kw.pop("typecallable", None)
+ typecallable = manager.instrument_collection_class(
+ key, factory or list
+ )
+ else:
+ typecallable = kw.pop("typecallable", None)
+
+ dispatch = manager[key].dispatch
+
+ if impl_class:
+ impl = impl_class(class_, key, typecallable, dispatch, **kw)
+ elif uselist:
+ impl = CollectionAttributeImpl(
+ class_, key, callable_, dispatch, typecallable=typecallable, **kw
+ )
+ elif useobject:
+ impl = ScalarObjectAttributeImpl(
+ class_, key, callable_, dispatch, **kw
+ )
+ else:
+ impl = ScalarAttributeImpl(class_, key, callable_, dispatch, **kw)
+
+ manager[key].impl = impl
+
+ if backref:
+ backref_listeners(manager[key], backref, uselist)
+
+ manager.post_configure_attribute(key)
+ return manager[key]
+
+
+def register_descriptor(
+ class_, key, comparator=None, parententity=None, doc=None
+):
+ manager = manager_of_class(class_)
+
+ descriptor = InstrumentedAttribute(
+ class_, key, comparator=comparator, parententity=parententity
+ )
+
+ descriptor.__doc__ = doc
+
+ manager.instrument_attribute(key, descriptor)
+ return descriptor
+
+
+def unregister_attribute(class_, key):
+ manager_of_class(class_).uninstrument_attribute(key)
+
+
+def init_collection(obj, key):
+ """Initialize a collection attribute and return the collection adapter.
+
+ This function is used to provide direct access to collection internals
+ for a previously unloaded attribute. e.g.::
+
+ collection_adapter = init_collection(someobject, 'elements')
+ for elem in values:
+ collection_adapter.append_without_event(elem)
+
+ For an easier way to do the above, see
+ :func:`~sqlalchemy.orm.attributes.set_committed_value`.
+
+ :param obj: a mapped object
+
+ :param key: string attribute name where the collection is located.
+
+ """
+ state = instance_state(obj)
+ dict_ = state.dict
+ return init_state_collection(state, dict_, key)
+
+
+def init_state_collection(state, dict_, key):
+ """Initialize a collection attribute and return the collection adapter.
+
+ Discards any existing collection which may be there.
+
+ """
+ attr = state.manager[key].impl
+
+ old = dict_.pop(key, None) # discard old collection
+ if old is not None:
+ old_collection = old._sa_adapter
+ attr._dispose_previous_collection(state, old, old_collection, False)
+
+ user_data = attr._default_value(state, dict_)
+ adapter = attr.get_collection(state, dict_, user_data)
+ adapter._reset_empty()
+
+ return adapter
+
+
+def set_committed_value(instance, key, value):
+ """Set the value of an attribute with no history events.
+
+ Cancels any previous history present. The value should be
+ a scalar value for scalar-holding attributes, or
+ an iterable for any collection-holding attribute.
+
+ This is the same underlying method used when a lazy loader
+ fires off and loads additional data from the database.
+ In particular, this method can be used by application code
+ which has loaded additional attributes or collections through
+ separate queries, which can then be attached to an instance
+ as though it were part of its original loaded state.
+
+ """
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ state.manager[key].impl.set_committed_value(state, dict_, value)
+
+
+def set_attribute(instance, key, value, initiator=None):
+ """Set the value of an attribute, firing history events.
+
+ This function may be used regardless of instrumentation
+ applied directly to the class, i.e. no descriptors are required.
+ Custom attribute management schemes will need to use this
+ method to establish attribute state as understood
+ by SQLAlchemy.
+
+ :param instance: the object that will be modified
+
+ :param key: string name of the attribute
+
+ :param value: value to assign
+
+ :param initiator: an instance of :class:`.Event` that would have
+ been propagated from a previous event listener. This argument
+ is used when the :func:`.set_attribute` function is being used within
+ an existing event listening function where an :class:`.Event` object
+ is being supplied; the object may be used to track the origin of the
+ chain of events.
+
+ .. versionadded:: 1.2.3
+
+ """
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ state.manager[key].impl.set(state, dict_, value, initiator)
+
+
+def get_attribute(instance, key):
+ """Get the value of an attribute, firing any callables required.
+
+ This function may be used regardless of instrumentation
+ applied directly to the class, i.e. no descriptors are required.
+ Custom attribute management schemes will need to use this
+ method in order to access attribute state as understood
+ by SQLAlchemy.
+
+ """
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ return state.manager[key].impl.get(state, dict_)
+
+
+def del_attribute(instance, key):
+ """Delete the value of an attribute, firing history events.
+
+ This function may be used regardless of instrumentation
+ applied directly to the class, i.e. no descriptors are required.
+ Custom attribute management schemes will need to use this
+ method to establish attribute state as understood
+ by SQLAlchemy.
+
+ """
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ state.manager[key].impl.delete(state, dict_)
+
+
+def flag_modified(instance, key):
+ """Mark an attribute on an instance as 'modified'.
+
+ This sets the 'modified' flag on the instance and
+ establishes an unconditional change event for the given attribute.
+ The attribute must have a value present, else an
+ :class:`.InvalidRequestError` is raised.
+
+ To mark an object "dirty" without referring to any specific attribute
+ so that it is considered within a flush, use the
+ :func:`.attributes.flag_dirty` call.
+
+ .. seealso::
+
+ :func:`.attributes.flag_dirty`
+
+ """
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ impl = state.manager[key].impl
+ impl.dispatch.modified(state, impl._modified_token)
+ state._modified_event(dict_, impl, NO_VALUE, is_userland=True)
+
+
+def flag_dirty(instance):
+ """Mark an instance as 'dirty' without any specific attribute mentioned.
+
+ This is a special operation that will allow the object to travel through
+ the flush process for interception by events such as
+ :meth:`.SessionEvents.before_flush`. Note that no SQL will be emitted in
+ the flush process for an object that has no changes, even if marked dirty
+ via this method. However, a :meth:`.SessionEvents.before_flush` handler
+ will be able to see the object in the :attr:`.Session.dirty` collection and
+ may establish changes on it, which will then be included in the SQL
+ emitted.
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :func:`.attributes.flag_modified`
+
+ """
+
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ state._modified_event(dict_, None, NO_VALUE, is_userland=True)
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
new file mode 100644
index 0000000..8e94d7b
--- /dev/null
+++ b/lib/sqlalchemy/orm/base.py
@@ -0,0 +1,572 @@
+# orm/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Constants and rudimental functions used throughout the ORM.
+
+"""
+
+import operator
+
+from . import exc
+from .. import exc as sa_exc
+from .. import inspection
+from .. import util
+
+
+PASSIVE_NO_RESULT = util.symbol(
+ "PASSIVE_NO_RESULT",
+ """Symbol returned by a loader callable or other attribute/history
+ retrieval operation when a value could not be determined, based
+ on loader callable flags.
+ """,
+)
+
+PASSIVE_CLASS_MISMATCH = util.symbol(
+ "PASSIVE_CLASS_MISMATCH",
+ """Symbol indicating that an object is locally present for a given
+ primary key identity but it is not of the requested class. The
+ return value is therefore None and no SQL should be emitted.""",
+)
+
+ATTR_WAS_SET = util.symbol(
+ "ATTR_WAS_SET",
+ """Symbol returned by a loader callable to indicate the
+ retrieved value, or values, were assigned to their attributes
+ on the target object.
+ """,
+)
+
+ATTR_EMPTY = util.symbol(
+ "ATTR_EMPTY",
+ """Symbol used internally to indicate an attribute had no callable.""",
+)
+
+NO_VALUE = util.symbol(
+ "NO_VALUE",
+ """Symbol which may be placed as the 'previous' value of an attribute,
+ indicating no value was loaded for an attribute when it was modified,
+ and flags indicated we were not to load it.
+ """,
+)
+NEVER_SET = NO_VALUE
+"""
+Synonymous with NO_VALUE
+
+.. versionchanged:: 1.4 NEVER_SET was merged with NO_VALUE
+"""
+
+NO_CHANGE = util.symbol(
+ "NO_CHANGE",
+ """No callables or SQL should be emitted on attribute access
+ and no state should change
+ """,
+ canonical=0,
+)
+
+CALLABLES_OK = util.symbol(
+ "CALLABLES_OK",
+ """Loader callables can be fired off if a value
+ is not present.
+ """,
+ canonical=1,
+)
+
+SQL_OK = util.symbol(
+ "SQL_OK",
+ """Loader callables can emit SQL at least on scalar value attributes.""",
+ canonical=2,
+)
+
+RELATED_OBJECT_OK = util.symbol(
+ "RELATED_OBJECT_OK",
+ """Callables can use SQL to load related objects as well
+ as scalar value attributes.
+ """,
+ canonical=4,
+)
+
+INIT_OK = util.symbol(
+ "INIT_OK",
+ """Attributes should be initialized with a blank
+ value (None or an empty collection) upon get, if no other
+ value can be obtained.
+ """,
+ canonical=8,
+)
+
+NON_PERSISTENT_OK = util.symbol(
+ "NON_PERSISTENT_OK",
+ """Callables can be emitted if the parent is not persistent.""",
+ canonical=16,
+)
+
+LOAD_AGAINST_COMMITTED = util.symbol(
+ "LOAD_AGAINST_COMMITTED",
+ """Callables should use committed values as primary/foreign keys during a
+ load.
+ """,
+ canonical=32,
+)
+
+NO_AUTOFLUSH = util.symbol(
+ "NO_AUTOFLUSH",
+ """Loader callables should disable autoflush.""",
+ canonical=64,
+)
+
+NO_RAISE = util.symbol(
+ "NO_RAISE",
+ """Loader callables should not raise any assertions""",
+ canonical=128,
+)
+
+DEFERRED_HISTORY_LOAD = util.symbol(
+ "DEFERRED_HISTORY_LOAD",
+ """indicates special load of the previous value of an attribute""",
+ canonical=256,
+)
+
+# pre-packaged sets of flags used as inputs
+PASSIVE_OFF = util.symbol(
+ "PASSIVE_OFF",
+ "Callables can be emitted in all cases.",
+ canonical=(
+ RELATED_OBJECT_OK | NON_PERSISTENT_OK | INIT_OK | CALLABLES_OK | SQL_OK
+ ),
+)
+PASSIVE_RETURN_NO_VALUE = util.symbol(
+ "PASSIVE_RETURN_NO_VALUE",
+ """PASSIVE_OFF ^ INIT_OK""",
+ canonical=PASSIVE_OFF ^ INIT_OK,
+)
+PASSIVE_NO_INITIALIZE = util.symbol(
+ "PASSIVE_NO_INITIALIZE",
+ "PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK",
+ canonical=PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK,
+)
+PASSIVE_NO_FETCH = util.symbol(
+ "PASSIVE_NO_FETCH", "PASSIVE_OFF ^ SQL_OK", canonical=PASSIVE_OFF ^ SQL_OK
+)
+PASSIVE_NO_FETCH_RELATED = util.symbol(
+ "PASSIVE_NO_FETCH_RELATED",
+ "PASSIVE_OFF ^ RELATED_OBJECT_OK",
+ canonical=PASSIVE_OFF ^ RELATED_OBJECT_OK,
+)
+PASSIVE_ONLY_PERSISTENT = util.symbol(
+ "PASSIVE_ONLY_PERSISTENT",
+ "PASSIVE_OFF ^ NON_PERSISTENT_OK",
+ canonical=PASSIVE_OFF ^ NON_PERSISTENT_OK,
+)
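+
+# Example composition (descriptive comment only): PASSIVE_NO_FETCH is
+# PASSIVE_OFF ^ SQL_OK, i.e. loader callables may fire but should not emit
+# SQL; removing INIT_OK as well (PASSIVE_NO_FETCH ^ INIT_OK) additionally
+# suppresses initializing a blank value on access.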
+
+DEFAULT_MANAGER_ATTR = "_sa_class_manager"
+DEFAULT_STATE_ATTR = "_sa_instance_state"
+
+EXT_CONTINUE = util.symbol("EXT_CONTINUE")
+EXT_STOP = util.symbol("EXT_STOP")
+EXT_SKIP = util.symbol("EXT_SKIP")
+
+ONETOMANY = util.symbol(
+ "ONETOMANY",
+ """Indicates the one-to-many direction for a :func:`_orm.relationship`.
+
+ This symbol is typically used by the internals but may be exposed within
+ certain API features.
+
+ """,
+)
+
+MANYTOONE = util.symbol(
+ "MANYTOONE",
+ """Indicates the many-to-one direction for a :func:`_orm.relationship`.
+
+ This symbol is typically used by the internals but may be exposed within
+ certain API features.
+
+ """,
+)
+
+MANYTOMANY = util.symbol(
+ "MANYTOMANY",
+ """Indicates the many-to-many direction for a :func:`_orm.relationship`.
+
+ This symbol is typically used by the internals but may be exposed within
+ certain API features.
+
+ """,
+)
+
+NOT_EXTENSION = util.symbol(
+ "NOT_EXTENSION",
+ """Symbol indicating an :class:`InspectionAttr` that's
+ not part of sqlalchemy.ext.
+
+ Is assigned to the :attr:`.InspectionAttr.extension_type`
+ attribute.
+
+ """,
+)
+
+_never_set = frozenset([NEVER_SET])
+
+_none_set = frozenset([None, NEVER_SET, PASSIVE_NO_RESULT])
+
+_SET_DEFERRED_EXPIRED = util.symbol("SET_DEFERRED_EXPIRED")
+
+_DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE")
+
+_RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE")
+
+
+def _assertions(*assertions):
+ @util.decorator
+ def generate(fn, *args, **kw):
+ self = args[0]
+ for assertion in assertions:
+ assertion(self, fn.__name__)
+ fn(self, *args[1:], **kw)
+
+ return generate
+
+
+# these can be replaced by sqlalchemy.ext.instrumentation
+# if augmented class instrumentation is enabled.
+def manager_of_class(cls):
+ return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None)
+
+
+instance_state = operator.attrgetter(DEFAULT_STATE_ATTR)
+
+instance_dict = operator.attrgetter("__dict__")
+
+
+def instance_str(instance):
+ """Return a string describing an instance."""
+
+ return state_str(instance_state(instance))
+
+
+def state_str(state):
+ """Return a string describing an instance via its InstanceState."""
+
+ if state is None:
+ return "None"
+ else:
+ return "<%s at 0x%x>" % (state.class_.__name__, id(state.obj()))
+
+
+def state_class_str(state):
+ """Return a string describing an instance's class via its
+ InstanceState.
+ """
+
+ if state is None:
+ return "None"
+ else:
+ return "<%s>" % (state.class_.__name__,)
+
+
+def attribute_str(instance, attribute):
+ return instance_str(instance) + "." + attribute
+
+
+def state_attribute_str(state, attribute):
+ return state_str(state) + "." + attribute
+
+
+def object_mapper(instance):
+ """Given an object, return the primary Mapper associated with the object
+ instance.
+
+ Raises :class:`sqlalchemy.orm.exc.UnmappedInstanceError`
+ if no mapping is configured.
+
+ This function is available via the inspection system as::
+
+ inspect(instance).mapper
+
+ Using the inspection system will raise
+ :class:`sqlalchemy.exc.NoInspectionAvailable` if the instance is
+ not part of a mapping.
+
+ """
+ return object_state(instance).mapper
+
+
+def object_state(instance):
+ """Given an object, return the :class:`.InstanceState`
+ associated with the object.
+
+ Raises :class:`sqlalchemy.orm.exc.UnmappedInstanceError`
+ if no mapping is configured.
+
+ Equivalent functionality is available via the :func:`_sa.inspect`
+ function as::
+
+ inspect(instance)
+
+ Using the inspection system will raise
+ :class:`sqlalchemy.exc.NoInspectionAvailable` if the instance is
+ not part of a mapping.
+
+ """
+ state = _inspect_mapped_object(instance)
+ if state is None:
+ raise exc.UnmappedInstanceError(instance)
+ else:
+ return state
+
+
+@inspection._inspects(object)
+def _inspect_mapped_object(instance):
+ try:
+ return instance_state(instance)
+ except (exc.UnmappedClassError,) + exc.NO_STATE:
+ return None
+
+
+def _class_to_mapper(class_or_mapper):
+ insp = inspection.inspect(class_or_mapper, False)
+ if insp is not None:
+ return insp.mapper
+ else:
+ raise exc.UnmappedClassError(class_or_mapper)
+
+
+def _mapper_or_none(entity):
+ """Return the :class:`_orm.Mapper` for the given class or None if the
+ class is not mapped.
+ """
+
+ insp = inspection.inspect(entity, False)
+ if insp is not None:
+ return insp.mapper
+ else:
+ return None
+
+
+def _is_mapped_class(entity):
+ """Return True if the given object is a mapped class,
+ :class:`_orm.Mapper`, or :class:`.AliasedClass`.
+ """
+
+ insp = inspection.inspect(entity, False)
+ return (
+ insp is not None
+ and not insp.is_clause_element
+ and (insp.is_mapper or insp.is_aliased_class)
+ )
+
+
+def _orm_columns(entity):
+ insp = inspection.inspect(entity, False)
+ if hasattr(insp, "selectable") and hasattr(insp.selectable, "c"):
+ return [c for c in insp.selectable.c]
+ else:
+ return [entity]
+
+
+def _is_aliased_class(entity):
+ insp = inspection.inspect(entity, False)
+ return insp is not None and getattr(insp, "is_aliased_class", False)
+
+
+def _entity_descriptor(entity, key):
+ """Return a class attribute given an entity and string name.
+
+ May return :class:`.InstrumentedAttribute` or user-defined
+ attribute.
+
+ """
+ insp = inspection.inspect(entity)
+ if insp.is_selectable:
+ description = entity
+ entity = insp.c
+ elif insp.is_aliased_class:
+ entity = insp.entity
+ description = entity
+ elif hasattr(insp, "mapper"):
+ description = entity = insp.mapper.class_
+ else:
+ description = entity
+
+ try:
+ return getattr(entity, key)
+ except AttributeError as err:
+ util.raise_(
+ sa_exc.InvalidRequestError(
+ "Entity '%s' has no property '%s'" % (description, key)
+ ),
+ replace_context=err,
+ )
+
+
+_state_mapper = util.dottedgetter("manager.mapper")
+
+
+@inspection._inspects(type)
+def _inspect_mapped_class(class_, configure=False):
+ try:
+ class_manager = manager_of_class(class_)
+ if not class_manager.is_mapped:
+ return None
+ mapper = class_manager.mapper
+ except exc.NO_STATE:
+ return None
+ else:
+ if configure:
+ mapper._check_configure()
+ return mapper
+
+
+def class_mapper(class_, configure=True):
+ """Given a class, return the primary :class:`_orm.Mapper` associated
+ with the key.
+
+ Raises :exc:`.UnmappedClassError` if no mapping is configured
+ on the given class, or :exc:`.ArgumentError` if a non-class
+ object is passed.
+
+ Equivalent functionality is available via the :func:`_sa.inspect`
+ function as::
+
+ inspect(some_mapped_class)
+
+ Using the inspection system will raise
+ :class:`sqlalchemy.exc.NoInspectionAvailable` if the class is not mapped.
+
+ """
+ mapper = _inspect_mapped_class(class_, configure=configure)
+ if mapper is None:
+ if not isinstance(class_, type):
+ raise sa_exc.ArgumentError(
+ "Class object expected, got '%r'." % (class_,)
+ )
+ raise exc.UnmappedClassError(class_)
+ else:
+ return mapper
+
+
+class InspectionAttr(object):
+ """A base class applied to all ORM objects that can be returned
+ by the :func:`_sa.inspect` function.
+
+ The attributes defined here allow the usage of simple boolean
+ checks to test basic facts about the object returned.
+
+ While the boolean checks here are basically the same as using
+ the Python isinstance() function, the flags here can be used without
+ the need to import all of these classes, and also such that
+ the SQLAlchemy class system can change while leaving the flags
+ here intact for forwards-compatibility.
+
+ """
+
+ __slots__ = ()
+
+ is_selectable = False
+ """Return True if this object is an instance of
+ :class:`_expression.Selectable`."""
+
+ is_aliased_class = False
+ """True if this object is an instance of :class:`.AliasedClass`."""
+
+ is_instance = False
+ """True if this object is an instance of :class:`.InstanceState`."""
+
+ is_mapper = False
+ """True if this object is an instance of :class:`_orm.Mapper`."""
+
+ is_bundle = False
+ """True if this object is an instance of :class:`.Bundle`."""
+
+ is_property = False
+ """True if this object is an instance of :class:`.MapperProperty`."""
+
+ is_attribute = False
+ """True if this object is a Python :term:`descriptor`.
+
+ This can refer to one of many types. Usually a
+ :class:`.QueryableAttribute` which handles attributes events on behalf
+ of a :class:`.MapperProperty`. But can also be an extension type
+ such as :class:`.AssociationProxy` or :class:`.hybrid_property`.
+ The :attr:`.InspectionAttr.extension_type` will refer to a constant
+ identifying the specific subtype.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.all_orm_descriptors`
+
+ """
+
+ _is_internal_proxy = False
+ """True if this object is an internal proxy object.
+
+ .. versionadded:: 1.2.12
+
+ """
+
+ is_clause_element = False
+ """True if this object is an instance of
+ :class:`_expression.ClauseElement`."""
+
+ extension_type = NOT_EXTENSION
+ """The extension type, if any.
+ Defaults to :data:`.interfaces.NOT_EXTENSION`
+
+ .. seealso::
+
+ :data:`.HYBRID_METHOD`
+
+ :data:`.HYBRID_PROPERTY`
+
+ :data:`.ASSOCIATION_PROXY`
+
+ """
+
+
+class InspectionAttrInfo(InspectionAttr):
+ """Adds the ``.info`` attribute to :class:`.InspectionAttr`.
+
+ The rationale for :class:`.InspectionAttr` vs. :class:`.InspectionAttrInfo`
+ is that the former is compatible as a mixin for classes that specify
+ ``__slots__``; this is essentially an implementation artifact.
+
+ """
+
+ @util.memoized_property
+ def info(self):
+ """Info dictionary associated with the object, allowing user-defined
+ data to be associated with this :class:`.InspectionAttr`.
+
+ The dictionary is generated when first accessed. Alternatively,
+ it can be specified as a constructor argument to the
+ :func:`.column_property`, :func:`_orm.relationship`, or
+ :func:`.composite`
+ functions.
+
+ .. versionchanged:: 1.0.0 :attr:`.MapperProperty.info` is also
+ available on extension types via the
+ :attr:`.InspectionAttrInfo.info` attribute, so that it can apply
+ to a wider variety of ORM and extension constructs.
+
+ .. seealso::
+
+ :attr:`.QueryableAttribute.info`
+
+ :attr:`.SchemaItem.info`
+
+ """
+ return {}
+
+
+class _MappedAttribute(object):
+ """Mixin for attributes which should be replaced by mapper-assigned
+ attributes.
+
+ """
+
+ __slots__ = ()
diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py
new file mode 100644
index 0000000..2c21498
--- /dev/null
+++ b/lib/sqlalchemy/orm/clsregistry.py
@@ -0,0 +1,441 @@
+# ext/declarative/clsregistry.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Routines to handle the string class registry used by declarative.
+
+This system allows specification of classes and expressions used in
+:func:`_orm.relationship` using strings.
+
+"""
+import weakref
+
+from . import attributes
+from . import interfaces
+from .descriptor_props import SynonymProperty
+from .properties import ColumnProperty
+from .util import class_mapper
+from .. import exc
+from .. import inspection
+from .. import util
+from ..sql.schema import _get_table_key
+
+# strong references to registries which we place in
+# the _decl_class_registry, which is usually weak referencing.
+# the internal registries here link to classes with weakrefs and remove
+# themselves when all references to contained classes are removed.
+_registries = set()
+
+
+def add_class(classname, cls, decl_class_registry):
+ """Add a class to the _decl_class_registry associated with the
+ given declarative class.
+
+ """
+ if classname in decl_class_registry:
+ # class already exists.
+ existing = decl_class_registry[classname]
+ if not isinstance(existing, _MultipleClassMarker):
+ existing = decl_class_registry[classname] = _MultipleClassMarker(
+ [cls, existing]
+ )
+ else:
+ decl_class_registry[classname] = cls
+
+ try:
+ root_module = decl_class_registry["_sa_module_registry"]
+ except KeyError:
+ decl_class_registry[
+ "_sa_module_registry"
+ ] = root_module = _ModuleMarker("_sa_module_registry", None)
+
+ tokens = cls.__module__.split(".")
+
+ # build up a tree like this:
+ # modulename: myapp.snacks.nuts
+ #
+ # myapp->snack->nuts->(classes)
+ # snack->nuts->(classes)
+ # nuts->(classes)
+ #
+ # this allows partial token paths to be used.
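+ # e.g. with the tree above, a relationship() string argument of
+ # "nuts.SomeClass" or "snacks.nuts.SomeClass" (class name hypothetical)
+ # resolves to the same class as the fully qualified
+ # "myapp.snacks.nuts.SomeClass".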
+ while tokens:
+ token = tokens.pop(0)
+ module = root_module.get_module(token)
+ for token in tokens:
+ module = module.get_module(token)
+ module.add_class(classname, cls)
+
+
+def remove_class(classname, cls, decl_class_registry):
+ if classname in decl_class_registry:
+ existing = decl_class_registry[classname]
+ if isinstance(existing, _MultipleClassMarker):
+ existing.remove_item(cls)
+ else:
+ del decl_class_registry[classname]
+
+ try:
+ root_module = decl_class_registry["_sa_module_registry"]
+ except KeyError:
+ return
+
+ tokens = cls.__module__.split(".")
+
+ while tokens:
+ token = tokens.pop(0)
+ module = root_module.get_module(token)
+ for token in tokens:
+ module = module.get_module(token)
+ module.remove_class(classname, cls)
+
+
+def _key_is_empty(key, decl_class_registry, test):
+ """test if a key is empty of a certain object.
+
+ used for unit tests against the registry to see if garbage collection
+ is working.
+
+ "test" is a callable that will be passed an object should return True
+ if the given object is the one we were looking for.
+
+ We can't pass the actual object itself b.c. this is for testing garbage
+ collection; the caller will have to have removed references to the
+ object itself.
+
+ """
+ if key not in decl_class_registry:
+ return True
+
+ thing = decl_class_registry[key]
+ if isinstance(thing, _MultipleClassMarker):
+ for sub_thing in thing.contents:
+ if test(sub_thing):
+ return False
+ else:
+ return not test(thing)
+
+
+class _MultipleClassMarker(object):
+ """refers to multiple classes of the same name
+ within _decl_class_registry.
+
+ """
+
+ __slots__ = "on_remove", "contents", "__weakref__"
+
+ def __init__(self, classes, on_remove=None):
+ self.on_remove = on_remove
+ self.contents = set(
+ [weakref.ref(item, self._remove_item) for item in classes]
+ )
+ _registries.add(self)
+
+ def remove_item(self, cls):
+ self._remove_item(weakref.ref(cls))
+
+ def __iter__(self):
+ return (ref() for ref in self.contents)
+
+ def attempt_get(self, path, key):
+ if len(self.contents) > 1:
+ raise exc.InvalidRequestError(
+ 'Multiple classes found for path "%s" '
+ "in the registry of this declarative "
+ "base. Please use a fully module-qualified path."
+ % (".".join(path + [key]))
+ )
+ else:
+ ref = list(self.contents)[0]
+ cls = ref()
+ if cls is None:
+ raise NameError(key)
+ return cls
+
+ def _remove_item(self, ref):
+ self.contents.discard(ref)
+ if not self.contents:
+ _registries.discard(self)
+ if self.on_remove:
+ self.on_remove()
+
+ def add_item(self, item):
+ # protect against class registration race condition against
+ # asynchronous garbage collection calling _remove_item,
+ # [ticket:3208]
+ modules = set(
+ [
+ cls.__module__
+ for cls in [ref() for ref in self.contents]
+ if cls is not None
+ ]
+ )
+ if item.__module__ in modules:
+ util.warn(
+ "This declarative base already contains a class with the "
+ "same class name and module name as %s.%s, and will "
+ "be replaced in the string-lookup table."
+ % (item.__module__, item.__name__)
+ )
+ self.contents.add(weakref.ref(item, self._remove_item))
+
+
+class _ModuleMarker(object):
+ """Refers to a module name within
+ _decl_class_registry.
+
+ """
+
+ __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"
+
+ def __init__(self, name, parent):
+ self.parent = parent
+ self.name = name
+ self.contents = {}
+ self.mod_ns = _ModNS(self)
+ if self.parent:
+ self.path = self.parent.path + [self.name]
+ else:
+ self.path = []
+ _registries.add(self)
+
+ def __contains__(self, name):
+ return name in self.contents
+
+ def __getitem__(self, name):
+ return self.contents[name]
+
+ def _remove_item(self, name):
+ self.contents.pop(name, None)
+ if not self.contents and self.parent is not None:
+ self.parent._remove_item(self.name)
+ _registries.discard(self)
+
+ def resolve_attr(self, key):
+ return getattr(self.mod_ns, key)
+
+ def get_module(self, name):
+ if name not in self.contents:
+ marker = _ModuleMarker(name, self)
+ self.contents[name] = marker
+ else:
+ marker = self.contents[name]
+ return marker
+
+ def add_class(self, name, cls):
+ if name in self.contents:
+ existing = self.contents[name]
+ existing.add_item(cls)
+ else:
+ existing = self.contents[name] = _MultipleClassMarker(
+ [cls], on_remove=lambda: self._remove_item(name)
+ )
+
+ def remove_class(self, name, cls):
+ if name in self.contents:
+ existing = self.contents[name]
+ existing.remove_item(cls)
+
+
+class _ModNS(object):
+ __slots__ = ("__parent",)
+
+ def __init__(self, parent):
+ self.__parent = parent
+
+ def __getattr__(self, key):
+ try:
+ value = self.__parent.contents[key]
+ except KeyError:
+ pass
+ else:
+ if value is not None:
+ if isinstance(value, _ModuleMarker):
+ return value.mod_ns
+ else:
+ assert isinstance(value, _MultipleClassMarker)
+ return value.attempt_get(self.__parent.path, key)
+ raise NameError(
+ "Module %r has no mapped classes "
+ "registered under the name %r" % (self.__parent.name, key)
+ )
+
+
+class _GetColumns(object):
+ __slots__ = ("cls",)
+
+ def __init__(self, cls):
+ self.cls = cls
+
+ def __getattr__(self, key):
+ mp = class_mapper(self.cls, configure=False)
+ if mp:
+ if key not in mp.all_orm_descriptors:
+ raise AttributeError(
+ "Class %r does not have a mapped column named %r"
+ % (self.cls, key)
+ )
+
+ desc = mp.all_orm_descriptors[key]
+ if desc.extension_type is interfaces.NOT_EXTENSION:
+ prop = desc.property
+ if isinstance(prop, SynonymProperty):
+ key = prop.name
+ elif not isinstance(prop, ColumnProperty):
+ raise exc.InvalidRequestError(
+ "Property %r is not an instance of"
+ " ColumnProperty (i.e. does not correspond"
+ " directly to a Column)." % key
+ )
+ return getattr(self.cls, key)
+
+
+inspection._inspects(_GetColumns)(
+ lambda target: inspection.inspect(target.cls)
+)
+
+
+class _GetTable(object):
+ __slots__ = "key", "metadata"
+
+ def __init__(self, key, metadata):
+ self.key = key
+ self.metadata = metadata
+
+ def __getattr__(self, key):
+ return self.metadata.tables[_get_table_key(key, self.key)]
+
+
+def _determine_container(key, value):
+ if isinstance(value, _MultipleClassMarker):
+ value = value.attempt_get([], key)
+ return _GetColumns(value)
+
+
+class _class_resolver(object):
+ __slots__ = (
+ "cls",
+ "prop",
+ "arg",
+ "fallback",
+ "_dict",
+ "_resolvers",
+ "favor_tables",
+ )
+
+ def __init__(self, cls, prop, fallback, arg, favor_tables=False):
+ self.cls = cls
+ self.prop = prop
+ self.arg = arg
+ self.fallback = fallback
+ self._dict = util.PopulateDict(self._access_cls)
+ self._resolvers = ()
+ self.favor_tables = favor_tables
+
+ def _access_cls(self, key):
+ cls = self.cls
+
+ manager = attributes.manager_of_class(cls)
+ decl_base = manager.registry
+ decl_class_registry = decl_base._class_registry
+ metadata = decl_base.metadata
+
+ if self.favor_tables:
+ if key in metadata.tables:
+ return metadata.tables[key]
+ elif key in metadata._schemas:
+ return _GetTable(key, cls.metadata)
+
+ if key in decl_class_registry:
+ return _determine_container(key, decl_class_registry[key])
+
+ if not self.favor_tables:
+ if key in metadata.tables:
+ return metadata.tables[key]
+ elif key in metadata._schemas:
+ return _GetTable(key, cls.metadata)
+
+ if (
+ "_sa_module_registry" in decl_class_registry
+ and key in decl_class_registry["_sa_module_registry"]
+ ):
+ registry = decl_class_registry["_sa_module_registry"]
+ return registry.resolve_attr(key)
+ elif self._resolvers:
+ for resolv in self._resolvers:
+ value = resolv(key)
+ if value is not None:
+ return value
+
+ return self.fallback[key]
+
+ def _raise_for_name(self, name, err):
+ util.raise_(
+ exc.InvalidRequestError(
+ "When initializing mapper %s, expression %r failed to "
+ "locate a name (%r). If this is a class name, consider "
+ "adding this relationship() to the %r class after "
+ "both dependent classes have been defined."
+ % (self.prop.parent, self.arg, name, self.cls)
+ ),
+ from_=err,
+ )
+
+ def _resolve_name(self):
+ name = self.arg
+ d = self._dict
+ rval = None
+ try:
+ for token in name.split("."):
+ if rval is None:
+ rval = d[token]
+ else:
+ rval = getattr(rval, token)
+ except KeyError as err:
+ self._raise_for_name(name, err)
+ except NameError as n:
+ self._raise_for_name(n.args[0], n)
+ else:
+ if isinstance(rval, _GetColumns):
+ return rval.cls
+ else:
+ return rval
+
+ def __call__(self):
+ try:
+ x = eval(self.arg, globals(), self._dict)
+
+ if isinstance(x, _GetColumns):
+ return x.cls
+ else:
+ return x
+ except NameError as n:
+ self._raise_for_name(n.args[0], n)
+
+
+_fallback_dict = None
+
+
+def _resolver(cls, prop):
+
+ global _fallback_dict
+
+ if _fallback_dict is None:
+ import sqlalchemy
+ from sqlalchemy.orm import foreign, remote
+
+ _fallback_dict = util.immutabledict(sqlalchemy.__dict__).union(
+ {"foreign": foreign, "remote": remote}
+ )
+
+ def resolve_arg(arg, favor_tables=False):
+ return _class_resolver(
+ cls, prop, _fallback_dict, arg, favor_tables=favor_tables
+ )
+
+ def resolve_name(arg):
+ return _class_resolver(cls, prop, _fallback_dict, arg)._resolve_name
+
+ return resolve_name, resolve_arg
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
new file mode 100644
index 0000000..a189f02
--- /dev/null
+++ b/lib/sqlalchemy/orm/collections.py
@@ -0,0 +1,1706 @@
+# orm/collections.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Support for collections of mapped entities.
+
+The collections package supplies the machinery used to inform the ORM of
+collection membership changes. An instrumentation via decoration approach is
+used, allowing arbitrary types (including built-ins) to be used as entity
+collections without requiring inheritance from a base class.
+
+Instrumentation decoration relays membership change events to the
+:class:`.CollectionAttributeImpl` that is currently managing the collection.
+The decorators observe function call arguments and return values, tracking
+entities entering or leaving the collection. Two decorator approaches are
+provided. One is a bundle of generic decorators that map function arguments
+and return values to events::
+
+ from sqlalchemy.orm.collections import collection
+ class MyClass(object):
+ # ...
+
+ @collection.adds(1)
+ def store(self, item):
+ self.data.append(item)
+
+ @collection.removes_return()
+ def pop(self):
+ return self.data.pop()
+
+
+The second approach is a bundle of targeted decorators that wrap appropriate
+append and remove notifiers around the mutation methods present in the
+standard Python ``list``, ``set`` and ``dict`` interfaces. These could be
+specified in terms of generic decorator recipes, but are instead hand-tooled
+for increased efficiency. The targeted decorators occasionally implement
+adapter-like behavior, such as mapping bulk-set methods (``extend``,
+``update``, ``__setslice__``, etc.) into the series of atomic mutation events
+that the ORM requires.
+
+The targeted decorators are used internally for automatic instrumentation of
+entity collection classes. Every collection class goes through a
+transformation process roughly like so:
+
+1. If the class is a built-in, substitute a trivial sub-class
+2. Is this class already instrumented?
+3. Add in generic decorators
+4. Sniff out the collection interface through duck-typing
+5. Add targeted decoration to any undecorated interface method
+
+This process modifies the class at runtime, decorating methods and adding some
+bookkeeping properties. This isn't possible (or desirable) for built-in
+classes like ``list``, so trivial sub-classes are substituted to hold
+decoration::
+
+ class InstrumentedList(list):
+ pass
+
+Collection classes can be specified in ``relationship(collection_class=)`` as
+types or a function that returns an instance. Collection classes are
+inspected and instrumented during the mapper compilation phase. The
+collection_class callable will be executed once to produce a specimen
+instance, and the type of that specimen will be instrumented. Functions that
+return built-in types like ``list`` will be adapted to produce instrumented
+instances.
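+
+For example, a minimal sketch, assuming a declarative ``Base`` and a mapped
+``Child`` class (the names here are hypothetical and used only for
+illustration)::
+
+    from sqlalchemy import Column, Integer
+    from sqlalchemy.orm import relationship
+
+    class Parent(Base):
+        __tablename__ = "parent"
+        id = Column(Integer, primary_key=True)
+
+        # a plain set is converted to InstrumentedSet at
+        # instrumentation time
+        children = relationship("Child", collection_class=set)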
+
+When extending a known type like ``list``, additional decorations are
+generally not needed. Odds are, the extension method will delegate to a
+method that's already instrumented. For example::
+
+ class QueueIsh(list):
+ def push(self, item):
+ self.append(item)
+ def shift(self):
+ return self.pop(0)
+
+There's no need to decorate these methods. ``append`` and ``pop`` are already
+instrumented as part of the ``list`` interface. Decorating them would fire
+duplicate events, which should be avoided.
+
+The targeted decoration tries not to rely on other methods in the underlying
+collection class, but some are unavoidable. Many depend on 'read' methods
+being present to properly instrument a 'write', for example, ``__setitem__``
+needs ``__getitem__``. "Bulk" methods like ``update`` and ``extend`` may also
+be reimplemented in terms of atomic appends and removes, so the ``extend``
+decoration will actually perform many ``append`` operations and not call the
+underlying method at all.
+
+Tight control over bulk operation and the firing of events is also possible by
+implementing the instrumentation internally in your methods. The basic
+instrumentation package works under the general assumption that collection
+mutation will not raise unusual exceptions. If you want to closely
+orchestrate append and remove events with exception management, internal
+instrumentation may be the answer. Within your method,
+``collection_adapter(self)`` will retrieve an object that you can use for
+explicit control over triggering append and remove events.
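+
+A minimal sketch of that approach (``MyList`` is a hypothetical name; the
+adapter may be ``None`` before the collection has been associated with an
+attribute)::
+
+    from sqlalchemy.orm.collections import collection, collection_adapter
+
+    class MyList(list):
+        @collection.internally_instrumented
+        def append(self, item, _sa_initiator=None):
+            adapter = collection_adapter(self)
+            if adapter is not None:
+                # fire the append event ourselves, then mutate the list
+                item = adapter.fire_append_event(item, _sa_initiator)
+            list.append(self, item)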
+
+The owning object and :class:`.CollectionAttributeImpl` are also reachable
+through the adapter, allowing for some very sophisticated behavior.
+
+"""
+
+import operator
+import weakref
+
+from sqlalchemy.util.compat import inspect_getfullargspec
+from . import base
+from .. import exc as sa_exc
+from .. import util
+from ..sql import coercions
+from ..sql import expression
+from ..sql import roles
+
+__all__ = [
+ "collection",
+ "collection_adapter",
+ "mapped_collection",
+ "column_mapped_collection",
+ "attribute_mapped_collection",
+]
+
+__instrumentation_mutex = util.threading.Lock()
+
+
+class _PlainColumnGetter(object):
+ """Plain column getter, stores collection of Column objects
+ directly.
+
+ Serializes to a :class:`._SerializableColumnGetterV2`
+ which has more expensive __call__() performance
+ and some rare caveats.
+
+ """
+
+ def __init__(self, cols):
+ self.cols = cols
+ self.composite = len(cols) > 1
+
+ def __reduce__(self):
+ return _SerializableColumnGetterV2._reduce_from_cols(self.cols)
+
+ def _cols(self, mapper):
+ return self.cols
+
+ def __call__(self, value):
+ state = base.instance_state(value)
+ m = base._state_mapper(state)
+
+ key = [
+ m._get_state_attr_by_column(state, state.dict, col)
+ for col in self._cols(m)
+ ]
+
+ if self.composite:
+ return tuple(key)
+ else:
+ return key[0]
+
+
+class _SerializableColumnGetter(object):
+ """Column-based getter used in version 0.7.6 only.
+
+ Remains here for pickle compatibility with 0.7.6.
+
+ """
+
+ def __init__(self, colkeys):
+ self.colkeys = colkeys
+ self.composite = len(colkeys) > 1
+
+ def __reduce__(self):
+ return _SerializableColumnGetter, (self.colkeys,)
+
+ def __call__(self, value):
+ state = base.instance_state(value)
+ m = base._state_mapper(state)
+ key = [
+ m._get_state_attr_by_column(
+ state, state.dict, m.mapped_table.columns[k]
+ )
+ for k in self.colkeys
+ ]
+ if self.composite:
+ return tuple(key)
+ else:
+ return key[0]
+
+
+class _SerializableColumnGetterV2(_PlainColumnGetter):
+ """Updated serializable getter which deals with
+ multi-table mapped classes.
+
+ Two extremely unusual cases are not supported.
+ Mappings which have tables across multiple metadata
+ objects, or which are mapped to non-Table selectables
+ linked across inheriting mappers may fail to function
+ here.
+
+ """
+
+ def __init__(self, colkeys):
+ self.colkeys = colkeys
+ self.composite = len(colkeys) > 1
+
+ def __reduce__(self):
+ return self.__class__, (self.colkeys,)
+
+ @classmethod
+ def _reduce_from_cols(cls, cols):
+ def _table_key(c):
+ if not isinstance(c.table, expression.TableClause):
+ return None
+ else:
+ return c.table.key
+
+ colkeys = [(c.key, _table_key(c)) for c in cols]
+ return _SerializableColumnGetterV2, (colkeys,)
+
+ def _cols(self, mapper):
+ cols = []
+ metadata = getattr(mapper.local_table, "metadata", None)
+ for (ckey, tkey) in self.colkeys:
+ if tkey is None or metadata is None or tkey not in metadata:
+ cols.append(mapper.local_table.c[ckey])
+ else:
+ cols.append(metadata.tables[tkey].c[ckey])
+ return cols
+
+
+def column_mapped_collection(mapping_spec):
+ """A dictionary-based collection type with column-based keying.
+
+ Returns a :class:`.MappedCollection` factory with a keying function
+ generated from mapping_spec, which may be a Column or a sequence
+ of Columns.
+
+ The key value must be immutable for the lifetime of the object. You
+ can not, for example, map on foreign key values if those key values will
+ change during the session, i.e. from None to a database-assigned integer
+ after a session flush.
+
+ """
+ cols = [
+ coercions.expect(roles.ColumnArgumentRole, q, argname="mapping_spec")
+ for q in util.to_list(mapping_spec)
+ ]
+ keyfunc = _PlainColumnGetter(cols)
+ return lambda: MappedCollection(keyfunc)
+
+
+class _SerializableAttrGetter(object):
+ def __init__(self, name):
+ self.name = name
+ self.getter = operator.attrgetter(name)
+
+ def __call__(self, target):
+ return self.getter(target)
+
+ def __reduce__(self):
+ return _SerializableAttrGetter, (self.name,)
+
+
+def attribute_mapped_collection(attr_name):
+ """A dictionary-based collection type with attribute-based keying.
+
+ Returns a :class:`.MappedCollection` factory with a keying based on the
+ 'attr_name' attribute of entities in the collection, where ``attr_name``
+ is the string name of the attribute.
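+
+    For example, a sketch assuming a declarative ``Base`` and a mapped
+    ``Note`` class with a ``keyword`` attribute (hypothetical names used
+    for illustration only)::
+
+        from sqlalchemy import Column, Integer
+        from sqlalchemy.orm import relationship
+        from sqlalchemy.orm.collections import attribute_mapped_collection
+
+        class Item(Base):
+            __tablename__ = "item"
+            id = Column(Integer, primary_key=True)
+
+            # notes are keyed in the dict by each Note's .keyword value
+            notes = relationship(
+                "Note",
+                collection_class=attribute_mapped_collection("keyword"),
+            )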
+
+ .. warning:: the key value must be assigned to its final value
+ **before** it is accessed by the attribute mapped collection.
+ Additionally, changes to the key attribute are **not tracked**
+ automatically, which means the key in the dictionary is not
+ automatically synchronized with the key value on the target object
+ itself. See the section :ref:`key_collections_mutations`
+ for an example.
+
+ """
+ getter = _SerializableAttrGetter(attr_name)
+ return lambda: MappedCollection(getter)
+
+
+def mapped_collection(keyfunc):
+ """A dictionary-based collection type with arbitrary keying.
+
+ Returns a :class:`.MappedCollection` factory with a keying function
+ generated from keyfunc, a callable that takes an entity and returns a
+ key value.
+
+ The key value must be immutable for the lifetime of the object. You
+ can not, for example, map on foreign key values if those key values will
+ change during the session, i.e. from None to a database-assigned integer
+ after a session flush.
+
+ """
+ return lambda: MappedCollection(keyfunc)
+
+
+class collection(object):
+ """Decorators for entity collection classes.
+
+ The decorators fall into two groups: annotations and interception recipes.
+
+ The annotating decorators (appender, remover, iterator, converter,
+ internally_instrumented) indicate the method's purpose and take no
+ arguments. They are not written with parens::
+
+ @collection.appender
+ def append(self, append): ...
+
+ The recipe decorators all require parens, even those that take no
+ arguments::
+
+ @collection.adds('entity')
+ def insert(self, position, entity): ...
+
+ @collection.removes_return()
+ def popitem(self): ...
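+
+    A sketch of a complete custom collection class using these decorators
+    (``MyBag`` is a hypothetical name; an appender, a remover and an
+    iterator are the three roles every collection class must provide)::
+
+        from sqlalchemy.orm.collections import collection
+
+        class MyBag(object):
+            def __init__(self):
+                self.data = []
+
+            # role: appender
+            @collection.appender
+            def append(self, item):
+                self.data.append(item)
+
+            # role: remover
+            @collection.remover
+            def remove(self, item):
+                self.data.remove(item)
+
+            # role: iterator
+            @collection.iterator
+            def __iter__(self):
+                return iter(self.data)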
+
+ """
+
+ # Bundled as a class solely for ease of use: packaging, doc strings,
+ # importability.
+
+ @staticmethod
+ def appender(fn):
+ """Tag the method as the collection appender.
+
+ The appender method is called with one positional argument: the value
+ to append. The method will be automatically decorated with 'adds(1)'
+ if not already decorated::
+
+ @collection.appender
+ def add(self, append): ...
+
+ # or, equivalently
+ @collection.appender
+ @collection.adds(1)
+ def add(self, append): ...
+
+ # for mapping type, an 'append' may kick out a previous value
+            # that occupies that slot.  consider d['a'] = 'foo' - any previous
+ # value in d['a'] is discarded.
+ @collection.appender
+ @collection.replaces(1)
+ def add(self, entity):
+ key = some_key_func(entity)
+ previous = None
+ if key in self:
+ previous = self[key]
+ self[key] = entity
+ return previous
+
+ If the value to append is not allowed in the collection, you may
+ raise an exception. Something to remember is that the appender
+ will be called for each object mapped by a database query. If the
+ database contains rows that violate your collection semantics, you
+ will need to get creative to fix the problem, as access via the
+ collection will not work.
+
+ If the appender method is internally instrumented, you must also
+ receive the keyword argument '_sa_initiator' and ensure its
+ promulgation to collection events.
+
+ """
+ fn._sa_instrument_role = "appender"
+ return fn
+
+ @staticmethod
+ def remover(fn):
+ """Tag the method as the collection remover.
+
+ The remover method is called with one positional argument: the value
+ to remove. The method will be automatically decorated with
+ :meth:`removes_return` if not already decorated::
+
+ @collection.remover
+ def zap(self, entity): ...
+
+ # or, equivalently
+ @collection.remover
+ @collection.removes_return()
+            def zap(self): ...
+
+ If the value to remove is not present in the collection, you may
+ raise an exception or return None to ignore the error.
+
+ If the remove method is internally instrumented, you must also
+ receive the keyword argument '_sa_initiator' and ensure its
+ promulgation to collection events.
+
+ """
+ fn._sa_instrument_role = "remover"
+ return fn
+
+ @staticmethod
+ def iterator(fn):
+        """Tag the method as the collection iterator.
+
+ The iterator method is called with no arguments. It is expected to
+ return an iterator over all collection members::
+
+ @collection.iterator
+ def __iter__(self): ...
+
+ """
+ fn._sa_instrument_role = "iterator"
+ return fn
+
+ @staticmethod
+ def internally_instrumented(fn):
+ """Tag the method as instrumented.
+
+ This tag will prevent any decoration from being applied to the
+ method. Use this if you are orchestrating your own calls to
+ :func:`.collection_adapter` in one of the basic SQLAlchemy
+ interface methods, or to prevent an automatic ABC method
+ decoration from wrapping your implementation::
+
+ # normally an 'extend' method on a list-like class would be
+ # automatically intercepted and re-implemented in terms of
+ # SQLAlchemy events and append(). your implementation will
+ # never be called, unless:
+ @collection.internally_instrumented
+ def extend(self, items): ...
+
+ """
+ fn._sa_instrumented = True
+ return fn
+
+ @staticmethod
+ @util.deprecated(
+ "1.3",
+ "The :meth:`.collection.converter` handler is deprecated and will "
+ "be removed in a future release. Please refer to the "
+ ":class:`.AttributeEvents.bulk_replace` listener interface in "
+ "conjunction with the :func:`.event.listen` function.",
+ )
+ def converter(fn):
+ """Tag the method as the collection converter.
+
+ This optional method will be called when a collection is being
+ replaced entirely, as in::
+
+ myobj.acollection = [newvalue1, newvalue2]
+
+ The converter method will receive the object being assigned and should
+ return an iterable of values suitable for use by the ``appender``
+ method. A converter must not assign values or mutate the collection,
+        method.  A converter must not assign values or mutate the collection;
+ of values for the ORM's use.
+
+ The default converter implementation will use duck-typing to do the
+        conversion.  A dict-like collection will be converted into an iterable
+ of dictionary values, and other types will simply be iterated::
+
+ @collection.converter
+ def convert(self, other): ...
+
+ If the duck-typing of the object does not match the type of this
+ collection, a TypeError is raised.
+
+ Supply an implementation of this method if you want to expand the
+ range of possible types that can be assigned in bulk or perform
+ validation on the values about to be assigned.
+
+ """
+ fn._sa_instrument_role = "converter"
+ return fn
+
+ @staticmethod
+ def adds(arg):
+ """Mark the method as adding an entity to the collection.
+
+ Adds "add to collection" handling to the method. The decorator
+ argument indicates which method argument holds the SQLAlchemy-relevant
+ value. Arguments can be specified positionally (i.e. integer) or by
+ name::
+
+ @collection.adds(1)
+ def push(self, item): ...
+
+ @collection.adds('entity')
+ def do_stuff(self, thing, entity=None): ...
+
+ """
+
+ def decorator(fn):
+ fn._sa_instrument_before = ("fire_append_event", arg)
+ return fn
+
+ return decorator
+
+ @staticmethod
+ def replaces(arg):
+ """Mark the method as replacing an entity in the collection.
+
+ Adds "add to collection" and "remove from collection" handling to
+ the method. The decorator argument indicates which method argument
+        holds the SQLAlchemy-relevant value to be added; the return value,
+        if any, will be considered the value to remove.
+
+ Arguments can be specified positionally (i.e. integer) or by name::
+
+ @collection.replaces(2)
+ def __setitem__(self, index, item): ...
+
+ """
+
+ def decorator(fn):
+ fn._sa_instrument_before = ("fire_append_event", arg)
+ fn._sa_instrument_after = "fire_remove_event"
+ return fn
+
+ return decorator
+
+ @staticmethod
+ def removes(arg):
+ """Mark the method as removing an entity in the collection.
+
+ Adds "remove from collection" handling to the method. The decorator
+ argument indicates which method argument holds the SQLAlchemy-relevant
+ value to be removed. Arguments can be specified positionally (i.e.
+ integer) or by name::
+
+ @collection.removes(1)
+ def zap(self, item): ...
+
+ For methods where the value to remove is not known at call-time, use
+ collection.removes_return.
+
+ """
+
+ def decorator(fn):
+ fn._sa_instrument_before = ("fire_remove_event", arg)
+ return fn
+
+ return decorator
+
+ @staticmethod
+ def removes_return():
+ """Mark the method as removing an entity in the collection.
+
+ Adds "remove from collection" handling to the method. The return
+ value of the method, if any, is considered the value to remove. The
+ method arguments are not inspected::
+
+ @collection.removes_return()
+ def pop(self): ...
+
+ For methods where the value to remove is known at call-time, use
+        collection.removes.
+
+ """
+
+ def decorator(fn):
+ fn._sa_instrument_after = "fire_remove_event"
+ return fn
+
+ return decorator
+
+
+collection_adapter = operator.attrgetter("_sa_adapter")
+"""Fetch the :class:`.CollectionAdapter` for a collection."""
+
+
+class CollectionAdapter(object):
+ """Bridges between the ORM and arbitrary Python collections.
+
+ Proxies base-level collection operations (append, remove, iterate)
+ to the underlying Python collection, and emits add/remove events for
+ entities entering or leaving the collection.
+
+ The ORM uses :class:`.CollectionAdapter` exclusively for interaction with
+ entity collections.
+
+
+ """
+
+ __slots__ = (
+ "attr",
+ "_key",
+ "_data",
+ "owner_state",
+ "_converter",
+ "invalidated",
+ "empty",
+ )
+
+ def __init__(self, attr, owner_state, data):
+ self.attr = attr
+ self._key = attr.key
+ self._data = weakref.ref(data)
+ self.owner_state = owner_state
+ data._sa_adapter = self
+ self._converter = data._sa_converter
+ self.invalidated = False
+ self.empty = False
+
+ def _warn_invalidated(self):
+ util.warn("This collection has been invalidated.")
+
+ @property
+ def data(self):
+ "The entity collection being adapted."
+ return self._data()
+
+ @property
+ def _referenced_by_owner(self):
+ """return True if the owner state still refers to this collection.
+
+ This will return False within a bulk replace operation,
+ where this collection is the one being replaced.
+
+ """
+ return self.owner_state.dict[self._key] is self._data()
+
+ def bulk_appender(self):
+ return self._data()._sa_appender
+
+ def append_with_event(self, item, initiator=None):
+ """Add an entity to the collection, firing mutation events."""
+
+ self._data()._sa_appender(item, _sa_initiator=initiator)
+
+ def _set_empty(self, user_data):
+ assert (
+ not self.empty
+ ), "This collection adapter is already in the 'empty' state"
+ self.empty = True
+ self.owner_state._empty_collections[self._key] = user_data
+
+ def _reset_empty(self):
+ assert (
+ self.empty
+ ), "This collection adapter is not in the 'empty' state"
+ self.empty = False
+ self.owner_state.dict[
+ self._key
+ ] = self.owner_state._empty_collections.pop(self._key)
+
+ def _refuse_empty(self):
+ raise sa_exc.InvalidRequestError(
+ "This is a special 'empty' collection which cannot accommodate "
+ "internal mutation operations"
+ )
+
+ def append_without_event(self, item):
+ """Add or restore an entity to the collection, firing no events."""
+
+ if self.empty:
+ self._refuse_empty()
+ self._data()._sa_appender(item, _sa_initiator=False)
+
+ def append_multiple_without_event(self, items):
+        """Add or restore entities to the collection, firing no events."""
+ if self.empty:
+ self._refuse_empty()
+ appender = self._data()._sa_appender
+ for item in items:
+ appender(item, _sa_initiator=False)
+
+ def bulk_remover(self):
+ return self._data()._sa_remover
+
+ def remove_with_event(self, item, initiator=None):
+ """Remove an entity from the collection, firing mutation events."""
+ self._data()._sa_remover(item, _sa_initiator=initiator)
+
+ def remove_without_event(self, item):
+ """Remove an entity from the collection, firing no events."""
+ if self.empty:
+ self._refuse_empty()
+ self._data()._sa_remover(item, _sa_initiator=False)
+
+ def clear_with_event(self, initiator=None):
+ """Empty the collection, firing a mutation event for each entity."""
+
+ if self.empty:
+ self._refuse_empty()
+ remover = self._data()._sa_remover
+ for item in list(self):
+ remover(item, _sa_initiator=initiator)
+
+ def clear_without_event(self):
+ """Empty the collection, firing no events."""
+
+ if self.empty:
+ self._refuse_empty()
+ remover = self._data()._sa_remover
+ for item in list(self):
+ remover(item, _sa_initiator=False)
+
+ def __iter__(self):
+ """Iterate over entities in the collection."""
+
+ return iter(self._data()._sa_iterator())
+
+ def __len__(self):
+ """Count entities in the collection."""
+ return len(list(self._data()._sa_iterator()))
+
+ def __bool__(self):
+ return True
+
+ __nonzero__ = __bool__
+
+ def fire_append_wo_mutation_event(self, item, initiator=None):
+        """Notify that an entity is entering the collection but is already
+ present.
+
+
+ Initiator is a token owned by the InstrumentedAttribute that
+ initiated the membership mutation, and should be left as None
+ unless you are passing along an initiator value from a chained
+ operation.
+
+ .. versionadded:: 1.4.15
+
+ """
+ if initiator is not False:
+ if self.invalidated:
+ self._warn_invalidated()
+
+ if self.empty:
+ self._reset_empty()
+
+ return self.attr.fire_append_wo_mutation_event(
+ self.owner_state, self.owner_state.dict, item, initiator
+ )
+ else:
+ return item
+
+ def fire_append_event(self, item, initiator=None):
+        """Notify that an entity has entered the collection.
+
+ Initiator is a token owned by the InstrumentedAttribute that
+ initiated the membership mutation, and should be left as None
+ unless you are passing along an initiator value from a chained
+ operation.
+
+ """
+ if initiator is not False:
+ if self.invalidated:
+ self._warn_invalidated()
+
+ if self.empty:
+ self._reset_empty()
+
+ return self.attr.fire_append_event(
+ self.owner_state, self.owner_state.dict, item, initiator
+ )
+ else:
+ return item
+
+ def fire_remove_event(self, item, initiator=None):
+        """Notify that an entity has been removed from the collection.
+
+ Initiator is the InstrumentedAttribute that initiated the membership
+ mutation, and should be left as None unless you are passing along
+ an initiator value from a chained operation.
+
+ """
+ if initiator is not False:
+ if self.invalidated:
+ self._warn_invalidated()
+
+ if self.empty:
+ self._reset_empty()
+
+ self.attr.fire_remove_event(
+ self.owner_state, self.owner_state.dict, item, initiator
+ )
+
+ def fire_pre_remove_event(self, initiator=None):
+ """Notify that an entity is about to be removed from the collection.
+
+ Only called if the entity cannot be removed after calling
+ fire_remove_event().
+
+ """
+ if self.invalidated:
+ self._warn_invalidated()
+ self.attr.fire_pre_remove_event(
+ self.owner_state, self.owner_state.dict, initiator=initiator
+ )
+
+ def __getstate__(self):
+ return {
+ "key": self._key,
+ "owner_state": self.owner_state,
+ "owner_cls": self.owner_state.class_,
+ "data": self.data,
+ "invalidated": self.invalidated,
+ "empty": self.empty,
+ }
+
+ def __setstate__(self, d):
+ self._key = d["key"]
+ self.owner_state = d["owner_state"]
+ self._data = weakref.ref(d["data"])
+ self._converter = d["data"]._sa_converter
+ d["data"]._sa_adapter = self
+ self.invalidated = d["invalidated"]
+ self.attr = getattr(d["owner_cls"], self._key).impl
+ self.empty = d.get("empty", False)
+
+
+def bulk_replace(values, existing_adapter, new_adapter, initiator=None):
+    """Load a new collection, firing events based on prior membership.
+
+ Appends instances in ``values`` onto the ``new_adapter``. Events will be
+ fired for any instance not present in the ``existing_adapter``. Any
+ instances in ``existing_adapter`` not present in ``values`` will have
+ remove events fired upon them.
+
+ :param values: An iterable of collection member instances
+
+ :param existing_adapter: A :class:`.CollectionAdapter` of
+ instances to be replaced
+
+ :param new_adapter: An empty :class:`.CollectionAdapter`
+ to load with ``values``
+
+
+ """
+
+ assert isinstance(values, list)
+
+ idset = util.IdentitySet
+ existing_idset = idset(existing_adapter or ())
+ constants = existing_idset.intersection(values or ())
+ additions = idset(values or ()).difference(constants)
+ removals = existing_idset.difference(constants)
+
+ appender = new_adapter.bulk_appender()
+
+ for member in values or ():
+ if member in additions:
+ appender(member, _sa_initiator=initiator)
+ elif member in constants:
+ appender(member, _sa_initiator=False)
+
+ if existing_adapter:
+ for member in removals:
+ existing_adapter.fire_remove_event(member, initiator=initiator)
+
+
+def prepare_instrumentation(factory):
+ """Prepare a callable for future use as a collection class factory.
+
+ Given a collection class factory (either a type or no-arg callable),
+ return another factory that will produce compatible instances when
+ called.
+
+ This function is responsible for converting collection_class=list
+ into the run-time behavior of collection_class=InstrumentedList.
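+
+    For example, an illustrative sketch of that conversion::
+
+        factory = prepare_instrumentation(list)
+        collection = factory()  # an InstrumentedList instance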
+
+ """
+ # Convert a builtin to 'Instrumented*'
+ if factory in __canned_instrumentation:
+ factory = __canned_instrumentation[factory]
+
+ # Create a specimen
+ cls = type(factory())
+
+ # Did factory callable return a builtin?
+ if cls in __canned_instrumentation:
+ # Wrap it so that it returns our 'Instrumented*'
+ factory = __converting_factory(cls, factory)
+ cls = factory()
+
+ # Instrument the class if needed.
+ if __instrumentation_mutex.acquire():
+ try:
+ if getattr(cls, "_sa_instrumented", None) != id(cls):
+ _instrument_class(cls)
+ finally:
+ __instrumentation_mutex.release()
+
+ return factory
+
+
+def __converting_factory(specimen_cls, original_factory):
+ """Return a wrapper that converts a "canned" collection like
+ set, dict, list into the Instrumented* version.
+
+ """
+
+ instrumented_cls = __canned_instrumentation[specimen_cls]
+
+ def wrapper():
+ collection = original_factory()
+ return instrumented_cls(collection)
+
+ # often flawed but better than nothing
+ wrapper.__name__ = "%sWrapper" % original_factory.__name__
+ wrapper.__doc__ = original_factory.__doc__
+
+ return wrapper
+
+
+def _instrument_class(cls):
+ """Modify methods in a class and install instrumentation."""
+
+ # In the normal call flow, a request for any of the 3 basic collection
+ # types is transformed into one of our trivial subclasses
+ # (e.g. InstrumentedList). Catch anything else that sneaks in here...
+ if cls.__module__ == "__builtin__":
+ raise sa_exc.ArgumentError(
+ "Can not instrument a built-in type. Use a "
+ "subclass, even a trivial one."
+ )
+
+ roles, methods = _locate_roles_and_methods(cls)
+
+ _setup_canned_roles(cls, roles, methods)
+
+ _assert_required_roles(cls, roles, methods)
+
+ _set_collection_attributes(cls, roles, methods)
+
+
+def _locate_roles_and_methods(cls):
+ """search for _sa_instrument_role-decorated methods in
+ method resolution order, assign to roles.
+
+ """
+
+ roles = {}
+ methods = {}
+
+ for supercls in cls.__mro__:
+ for name, method in vars(supercls).items():
+ if not callable(method):
+ continue
+
+ # note role declarations
+ if hasattr(method, "_sa_instrument_role"):
+ role = method._sa_instrument_role
+ assert role in (
+ "appender",
+ "remover",
+ "iterator",
+ "converter",
+ )
+ roles.setdefault(role, name)
+
+ # transfer instrumentation requests from decorated function
+ # to the combined queue
+ before, after = None, None
+ if hasattr(method, "_sa_instrument_before"):
+ op, argument = method._sa_instrument_before
+ assert op in ("fire_append_event", "fire_remove_event")
+ before = op, argument
+ if hasattr(method, "_sa_instrument_after"):
+ op = method._sa_instrument_after
+ assert op in ("fire_append_event", "fire_remove_event")
+ after = op
+ if before:
+ methods[name] = before + (after,)
+ elif after:
+ methods[name] = None, None, after
+ return roles, methods
+
+
+def _setup_canned_roles(cls, roles, methods):
+ """see if this class has "canned" roles based on a known
+ collection type (dict, set, list). Apply those roles
+ as needed to the "roles" dictionary, and also
+ prepare "decorator" methods
+
+ """
+ collection_type = util.duck_type_collection(cls)
+ if collection_type in __interfaces:
+ canned_roles, decorators = __interfaces[collection_type]
+ for role, name in canned_roles.items():
+ roles.setdefault(role, name)
+
+ # apply ABC auto-decoration to methods that need it
+ for method, decorator in decorators.items():
+ fn = getattr(cls, method, None)
+ if (
+ fn
+ and method not in methods
+ and not hasattr(fn, "_sa_instrumented")
+ ):
+ setattr(cls, method, decorator(fn))
+
+
+def _assert_required_roles(cls, roles, methods):
+ """ensure all roles are present, and apply implicit instrumentation if
+ needed
+
+ """
+ if "appender" not in roles or not hasattr(cls, roles["appender"]):
+ raise sa_exc.ArgumentError(
+ "Type %s must elect an appender method to be "
+ "a collection class" % cls.__name__
+ )
+ elif roles["appender"] not in methods and not hasattr(
+ getattr(cls, roles["appender"]), "_sa_instrumented"
+ ):
+ methods[roles["appender"]] = ("fire_append_event", 1, None)
+
+ if "remover" not in roles or not hasattr(cls, roles["remover"]):
+ raise sa_exc.ArgumentError(
+ "Type %s must elect a remover method to be "
+ "a collection class" % cls.__name__
+ )
+ elif roles["remover"] not in methods and not hasattr(
+ getattr(cls, roles["remover"]), "_sa_instrumented"
+ ):
+ methods[roles["remover"]] = ("fire_remove_event", 1, None)
+
+ if "iterator" not in roles or not hasattr(cls, roles["iterator"]):
+ raise sa_exc.ArgumentError(
+ "Type %s must elect an iterator method to be "
+ "a collection class" % cls.__name__
+ )
+
+
+def _set_collection_attributes(cls, roles, methods):
+ """apply ad-hoc instrumentation from decorators, class-level defaults
+ and implicit role declarations
+
+ """
+ for method_name, (before, argument, after) in methods.items():
+ setattr(
+ cls,
+ method_name,
+ _instrument_membership_mutator(
+ getattr(cls, method_name), before, argument, after
+ ),
+ )
+ # intern the role map
+ for role, method_name in roles.items():
+ setattr(cls, "_sa_%s" % role, getattr(cls, method_name))
+
+ cls._sa_adapter = None
+
+ if not hasattr(cls, "_sa_converter"):
+ cls._sa_converter = None
+ cls._sa_instrumented = id(cls)
+
+
+def _instrument_membership_mutator(method, before, argument, after):
+ """Route method args and/or return value through the collection
+ adapter."""
+ # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))'
+ if before:
+ fn_args = list(
+ util.flatten_iterator(inspect_getfullargspec(method)[0])
+ )
+ if isinstance(argument, int):
+ pos_arg = argument
+ named_arg = len(fn_args) > argument and fn_args[argument] or None
+ else:
+ if argument in fn_args:
+ pos_arg = fn_args.index(argument)
+ else:
+ pos_arg = None
+ named_arg = argument
+ del fn_args
+
+ def wrapper(*args, **kw):
+ if before:
+ if pos_arg is None:
+ if named_arg not in kw:
+ raise sa_exc.ArgumentError(
+ "Missing argument %s" % argument
+ )
+ value = kw[named_arg]
+ else:
+ if len(args) > pos_arg:
+ value = args[pos_arg]
+ elif named_arg in kw:
+ value = kw[named_arg]
+ else:
+ raise sa_exc.ArgumentError(
+ "Missing argument %s" % argument
+ )
+
+ initiator = kw.pop("_sa_initiator", None)
+ if initiator is False:
+ executor = None
+ else:
+ executor = args[0]._sa_adapter
+
+ if before and executor:
+ getattr(executor, before)(value, initiator)
+
+ if not after or not executor:
+ return method(*args, **kw)
+ else:
+ res = method(*args, **kw)
+ if res is not None:
+ getattr(executor, after)(res, initiator)
+ return res
+
+ wrapper._sa_instrumented = True
+ if hasattr(method, "_sa_instrument_role"):
+ wrapper._sa_instrument_role = method._sa_instrument_role
+ wrapper.__name__ = method.__name__
+ wrapper.__doc__ = method.__doc__
+ return wrapper
+
+
+def __set_wo_mutation(collection, item, _sa_initiator=None):
+    """Run set-without-mutation events.
+
+ The collection is not mutated.
+
+ """
+ if _sa_initiator is not False:
+ executor = collection._sa_adapter
+ if executor:
+ executor.fire_append_wo_mutation_event(item, _sa_initiator)
+
+
+def __set(collection, item, _sa_initiator=None):
+ """Run set events.
+
+ This event always occurs before the collection is actually mutated.
+
+ """
+
+ if _sa_initiator is not False:
+ executor = collection._sa_adapter
+ if executor:
+ item = executor.fire_append_event(item, _sa_initiator)
+ return item
+
+
+def __del(collection, item, _sa_initiator=None):
+ """Run del events.
+
+ This event occurs before the collection is actually mutated, *except*
+ in the case of a pop operation, in which case it occurs afterwards.
+ For pop operations, the __before_pop hook is called before the
+ operation occurs.
+
+ """
+ if _sa_initiator is not False:
+ executor = collection._sa_adapter
+ if executor:
+ executor.fire_remove_event(item, _sa_initiator)
+
+
+def __before_pop(collection, _sa_initiator=None):
+    """An event which occurs before a pop() operation."""
+ executor = collection._sa_adapter
+ if executor:
+ executor.fire_pre_remove_event(_sa_initiator)
+
+
+def _list_decorators():
+ """Tailored instrumentation wrappers for any list-like class."""
+
+ def _tidy(fn):
+ fn._sa_instrumented = True
+ fn.__doc__ = getattr(list, fn.__name__).__doc__
+
+ def append(fn):
+ def append(self, item, _sa_initiator=None):
+ item = __set(self, item, _sa_initiator)
+ fn(self, item)
+
+ _tidy(append)
+ return append
+
+ def remove(fn):
+ def remove(self, value, _sa_initiator=None):
+ __del(self, value, _sa_initiator)
+ # testlib.pragma exempt:__eq__
+ fn(self, value)
+
+ _tidy(remove)
+ return remove
+
+ def insert(fn):
+ def insert(self, index, value):
+ value = __set(self, value)
+ fn(self, index, value)
+
+ _tidy(insert)
+ return insert
+
+ def __setitem__(fn):
+ def __setitem__(self, index, value):
+ if not isinstance(index, slice):
+ existing = self[index]
+ if existing is not None:
+ __del(self, existing)
+ value = __set(self, value)
+ fn(self, index, value)
+ else:
+ # slice assignment requires __delitem__, insert, __len__
+ step = index.step or 1
+ start = index.start or 0
+ if start < 0:
+ start += len(self)
+ if index.stop is not None:
+ stop = index.stop
+ else:
+ stop = len(self)
+ if stop < 0:
+ stop += len(self)
+
+ if step == 1:
+ if value is self:
+ return
+ for i in range(start, stop, step):
+ if len(self) > start:
+ del self[start]
+
+ for i, item in enumerate(value):
+ self.insert(i + start, item)
+ else:
+ rng = list(range(start, stop, step))
+ if len(value) != len(rng):
+ raise ValueError(
+ "attempt to assign sequence of size %s to "
+ "extended slice of size %s"
+ % (len(value), len(rng))
+ )
+ for i, item in zip(rng, value):
+ self.__setitem__(i, item)
+
+ _tidy(__setitem__)
+ return __setitem__
+
+ def __delitem__(fn):
+ def __delitem__(self, index):
+ if not isinstance(index, slice):
+ item = self[index]
+ __del(self, item)
+ fn(self, index)
+ else:
+ # slice deletion requires __getslice__ and a slice-groking
+ # __getitem__ for stepped deletion
+ # note: not breaking this into atomic dels
+ for item in self[index]:
+ __del(self, item)
+ fn(self, index)
+
+ _tidy(__delitem__)
+ return __delitem__
+
+ if util.py2k:
+
+ def __setslice__(fn):
+ def __setslice__(self, start, end, values):
+ for value in self[start:end]:
+ __del(self, value)
+ values = [__set(self, value) for value in values]
+ fn(self, start, end, values)
+
+ _tidy(__setslice__)
+ return __setslice__
+
+ def __delslice__(fn):
+ def __delslice__(self, start, end):
+ for value in self[start:end]:
+ __del(self, value)
+ fn(self, start, end)
+
+ _tidy(__delslice__)
+ return __delslice__
+
+ def extend(fn):
+ def extend(self, iterable):
+ for value in list(iterable):
+ self.append(value)
+
+ _tidy(extend)
+ return extend
+
+ def __iadd__(fn):
+ def __iadd__(self, iterable):
+ # list.__iadd__ takes any iterable and seems to let TypeError
+ # raise as-is instead of returning NotImplemented
+ for value in list(iterable):
+ self.append(value)
+ return self
+
+ _tidy(__iadd__)
+ return __iadd__
+
+ def pop(fn):
+ def pop(self, index=-1):
+ __before_pop(self)
+ item = fn(self, index)
+ __del(self, item)
+ return item
+
+ _tidy(pop)
+ return pop
+
+ if not util.py2k:
+
+ def clear(fn):
+ def clear(self, index=-1):
+ for item in self:
+ __del(self, item)
+ fn(self)
+
+ _tidy(clear)
+ return clear
+
+ # __imul__ : not wrapping this. all members of the collection are already
+ # present, so no need to fire appends... wrapping it with an explicit
+ # decorator is still possible, so events on *= can be had if they're
+ # desired. hard to imagine a use case for __imul__, though.
+
+ l = locals().copy()
+ l.pop("_tidy")
+ return l
+
+
+def _dict_decorators():
+ """Tailored instrumentation wrappers for any dict-like mapping class."""
+
+ def _tidy(fn):
+ fn._sa_instrumented = True
+ fn.__doc__ = getattr(dict, fn.__name__).__doc__
+
+ Unspecified = util.symbol("Unspecified")
+
+ def __setitem__(fn):
+ def __setitem__(self, key, value, _sa_initiator=None):
+ if key in self:
+ __del(self, self[key], _sa_initiator)
+ value = __set(self, value, _sa_initiator)
+ fn(self, key, value)
+
+ _tidy(__setitem__)
+ return __setitem__
+
+ def __delitem__(fn):
+ def __delitem__(self, key, _sa_initiator=None):
+ if key in self:
+ __del(self, self[key], _sa_initiator)
+ fn(self, key)
+
+ _tidy(__delitem__)
+ return __delitem__
+
+ def clear(fn):
+ def clear(self):
+ for key in self:
+ __del(self, self[key])
+ fn(self)
+
+ _tidy(clear)
+ return clear
+
+ def pop(fn):
+ def pop(self, key, default=Unspecified):
+ __before_pop(self)
+ _to_del = key in self
+ if default is Unspecified:
+ item = fn(self, key)
+ else:
+ item = fn(self, key, default)
+ if _to_del:
+ __del(self, item)
+ return item
+
+ _tidy(pop)
+ return pop
+
+ def popitem(fn):
+ def popitem(self):
+ __before_pop(self)
+ item = fn(self)
+ __del(self, item[1])
+ return item
+
+ _tidy(popitem)
+ return popitem
+
+ def setdefault(fn):
+ def setdefault(self, key, default=None):
+ if key not in self:
+ self.__setitem__(key, default)
+ return default
+ else:
+ value = self.__getitem__(key)
+ if value is default:
+ __set_wo_mutation(self, value, None)
+
+ return value
+
+ _tidy(setdefault)
+ return setdefault
+
+ def update(fn):
+ def update(self, __other=Unspecified, **kw):
+ if __other is not Unspecified:
+ if hasattr(__other, "keys"):
+ for key in list(__other):
+ if key not in self or self[key] is not __other[key]:
+ self[key] = __other[key]
+ else:
+ __set_wo_mutation(self, __other[key], None)
+ else:
+ for key, value in __other:
+ if key not in self or self[key] is not value:
+ self[key] = value
+ else:
+ __set_wo_mutation(self, value, None)
+ for key in kw:
+ if key not in self or self[key] is not kw[key]:
+ self[key] = kw[key]
+ else:
+ __set_wo_mutation(self, kw[key], None)
+
+ _tidy(update)
+ return update
+
+ l = locals().copy()
+ l.pop("_tidy")
+ l.pop("Unspecified")
+ return l
+
+
+_set_binop_bases = (set, frozenset)
+
+
+def _set_binops_check_strict(self, obj):
+ """Allow only set, frozenset and self.__class__-derived
+ objects in binops."""
+ return isinstance(obj, _set_binop_bases + (self.__class__,))
+
+
+def _set_binops_check_loose(self, obj):
+ """Allow anything set-like to participate in set binops."""
+ return (
+ isinstance(obj, _set_binop_bases + (self.__class__,))
+ or util.duck_type_collection(obj) == set
+ )
+
+
+def _set_decorators():
+ """Tailored instrumentation wrappers for any set-like class."""
+
+ def _tidy(fn):
+ fn._sa_instrumented = True
+ fn.__doc__ = getattr(set, fn.__name__).__doc__
+
+ Unspecified = util.symbol("Unspecified")
+
+ def add(fn):
+ def add(self, value, _sa_initiator=None):
+ if value not in self:
+ value = __set(self, value, _sa_initiator)
+ else:
+ __set_wo_mutation(self, value, _sa_initiator)
+ # testlib.pragma exempt:__hash__
+ fn(self, value)
+
+ _tidy(add)
+ return add
+
+ def discard(fn):
+ def discard(self, value, _sa_initiator=None):
+ # testlib.pragma exempt:__hash__
+ if value in self:
+ __del(self, value, _sa_initiator)
+ # testlib.pragma exempt:__hash__
+ fn(self, value)
+
+ _tidy(discard)
+ return discard
+
+ def remove(fn):
+ def remove(self, value, _sa_initiator=None):
+ # testlib.pragma exempt:__hash__
+ if value in self:
+ __del(self, value, _sa_initiator)
+ # testlib.pragma exempt:__hash__
+ fn(self, value)
+
+ _tidy(remove)
+ return remove
+
+ def pop(fn):
+ def pop(self):
+ __before_pop(self)
+ item = fn(self)
+ # for set in particular, we have no way to access the item
+ # that will be popped before pop is called.
+ __del(self, item)
+ return item
+
+ _tidy(pop)
+ return pop
+
+ def clear(fn):
+ def clear(self):
+ for item in list(self):
+ self.remove(item)
+
+ _tidy(clear)
+ return clear
+
+ def update(fn):
+ def update(self, value):
+ for item in value:
+ self.add(item)
+
+ _tidy(update)
+ return update
+
+ def __ior__(fn):
+ def __ior__(self, value):
+ if not _set_binops_check_strict(self, value):
+ return NotImplemented
+ for item in value:
+ self.add(item)
+ return self
+
+ _tidy(__ior__)
+ return __ior__
+
+ def difference_update(fn):
+ def difference_update(self, value):
+ for item in value:
+ self.discard(item)
+
+ _tidy(difference_update)
+ return difference_update
+
+ def __isub__(fn):
+ def __isub__(self, value):
+ if not _set_binops_check_strict(self, value):
+ return NotImplemented
+ for item in value:
+ self.discard(item)
+ return self
+
+ _tidy(__isub__)
+ return __isub__
+
+ def intersection_update(fn):
+ def intersection_update(self, other):
+ want, have = self.intersection(other), set(self)
+ remove, add = have - want, want - have
+
+ for item in remove:
+ self.remove(item)
+ for item in add:
+ self.add(item)
+
+ _tidy(intersection_update)
+ return intersection_update
+
+ def __iand__(fn):
+ def __iand__(self, other):
+ if not _set_binops_check_strict(self, other):
+ return NotImplemented
+ want, have = self.intersection(other), set(self)
+ remove, add = have - want, want - have
+
+ for item in remove:
+ self.remove(item)
+ for item in add:
+ self.add(item)
+ return self
+
+ _tidy(__iand__)
+ return __iand__
+
+ def symmetric_difference_update(fn):
+ def symmetric_difference_update(self, other):
+ want, have = self.symmetric_difference(other), set(self)
+ remove, add = have - want, want - have
+
+ for item in remove:
+ self.remove(item)
+ for item in add:
+ self.add(item)
+
+ _tidy(symmetric_difference_update)
+ return symmetric_difference_update
+
+ def __ixor__(fn):
+ def __ixor__(self, other):
+ if not _set_binops_check_strict(self, other):
+ return NotImplemented
+ want, have = self.symmetric_difference(other), set(self)
+ remove, add = have - want, want - have
+
+ for item in remove:
+ self.remove(item)
+ for item in add:
+ self.add(item)
+ return self
+
+ _tidy(__ixor__)
+ return __ixor__
+
+ l = locals().copy()
+ l.pop("_tidy")
+ l.pop("Unspecified")
+ return l
+
+
+class InstrumentedList(list):
+ """An instrumented version of the built-in list."""
+
+
+class InstrumentedSet(set):
+ """An instrumented version of the built-in set."""
+
+
+class InstrumentedDict(dict):
+ """An instrumented version of the built-in dict."""
+
+
+__canned_instrumentation = {
+ list: InstrumentedList,
+ set: InstrumentedSet,
+ dict: InstrumentedDict,
+}
+
+__interfaces = {
+ list: (
+ {"appender": "append", "remover": "remove", "iterator": "__iter__"},
+ _list_decorators(),
+ ),
+ set: (
+ {"appender": "add", "remover": "remove", "iterator": "__iter__"},
+ _set_decorators(),
+ ),
+ # decorators are required for dicts and object collections.
+ dict: ({"iterator": "values"}, _dict_decorators())
+ if util.py3k
+ else ({"iterator": "itervalues"}, _dict_decorators()),
+}
+
+
+class MappedCollection(dict):
+ """A basic dictionary-based collection class.
+
+ Extends dict with the minimal bag semantics that collection
+ classes require. ``set`` and ``remove`` are implemented in terms
+ of a keying function: any callable that takes an object and
+ returns an object for use as a dictionary key.
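+
+    For example, a sketch constructing one directly with a keying function
+    (the ``keyword`` attribute is a hypothetical name); in practice such a
+    collection is usually produced by one of the factory functions above,
+    such as ``attribute_mapped_collection("keyword")``::
+
+        notes = MappedCollection(keyfunc=lambda note: note.keyword)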
+
+ """
+
+ def __init__(self, keyfunc):
+ """Create a new collection with keying provided by keyfunc.
+
+ keyfunc may be any callable that takes an object and returns an object
+ for use as a dictionary key.
+
+ The keyfunc will be called every time the ORM needs to add a member by
+ value-only (such as when loading instances from the database) or
+        remove a member.  The usual cautions about dictionary keying apply:
+ ``keyfunc(object)`` should return the same output for the life of the
+ collection. Keying based on mutable properties can result in
+ unreachable instances "lost" in the collection.
+
+ """
+ self.keyfunc = keyfunc
+
+ @collection.appender
+ @collection.internally_instrumented
+ def set(self, value, _sa_initiator=None):
+ """Add an item by value, consulting the keyfunc for the key."""
+
+ key = self.keyfunc(value)
+ self.__setitem__(key, value, _sa_initiator)
+
+ @collection.remover
+ @collection.internally_instrumented
+ def remove(self, value, _sa_initiator=None):
+ """Remove an item by value, consulting the keyfunc for the key."""
+
+ key = self.keyfunc(value)
+ # Let self[key] raise if key is not in this collection
+ # testlib.pragma exempt:__ne__
+ if self[key] != value:
+ raise sa_exc.InvalidRequestError(
+ "Can not remove '%s': collection holds '%s' for key '%s'. "
+ "Possible cause: is the MappedCollection key function "
+ "based on mutable properties or properties that only obtain "
+ "values after flush?" % (value, self[key], key)
+ )
+ self.__delitem__(key, _sa_initiator)
+
+
+# ensure instrumentation is associated with
+# these built-in classes; if a user-defined class
+# subclasses these and uses @internally_instrumented,
+# the superclass is otherwise not instrumented.
+# see [ticket:2406].
+_instrument_class(MappedCollection)
+_instrument_class(InstrumentedList)
+_instrument_class(InstrumentedSet)
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
new file mode 100644
index 0000000..9d4f652
--- /dev/null
+++ b/lib/sqlalchemy/orm/context.py
@@ -0,0 +1,3136 @@
+# orm/context.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+import itertools
+
+from . import attributes
+from . import interfaces
+from . import loading
+from .base import _is_aliased_class
+from .interfaces import ORMColumnsClauseRole
+from .path_registry import PathRegistry
+from .util import _entity_corresponds_to
+from .util import _ORMJoin
+from .util import aliased
+from .util import Bundle
+from .util import ORMAdapter
+from .. import exc as sa_exc
+from .. import future
+from .. import inspect
+from .. import sql
+from .. import util
+from ..sql import ClauseElement
+from ..sql import coercions
+from ..sql import expression
+from ..sql import roles
+from ..sql import util as sql_util
+from ..sql import visitors
+from ..sql.base import _entity_namespace_key
+from ..sql.base import _select_iterables
+from ..sql.base import CacheableOptions
+from ..sql.base import CompileState
+from ..sql.base import Options
+from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
+from ..sql.selectable import LABEL_STYLE_NONE
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..sql.selectable import SelectState
+from ..sql.visitors import ExtendedInternalTraversal
+from ..sql.visitors import InternalTraversal
+
+_path_registry = PathRegistry.root
+
+_EMPTY_DICT = util.immutabledict()
+
+
+LABEL_STYLE_LEGACY_ORM = util.symbol("LABEL_STYLE_LEGACY_ORM")
+
+
+class QueryContext(object):
+ __slots__ = (
+ "compile_state",
+ "query",
+ "params",
+ "load_options",
+ "bind_arguments",
+ "execution_options",
+ "session",
+ "autoflush",
+ "populate_existing",
+ "invoke_all_eagers",
+ "version_check",
+ "refresh_state",
+ "create_eager_joins",
+ "propagated_loader_options",
+ "attributes",
+ "runid",
+ "partials",
+ "post_load_paths",
+ "identity_token",
+ "yield_per",
+ "loaders_require_buffering",
+ "loaders_require_uniquing",
+ )
+
+ class default_load_options(Options):
+ _only_return_tuples = False
+ _populate_existing = False
+ _version_check = False
+ _invoke_all_eagers = True
+ _autoflush = True
+ _refresh_identity_token = None
+ _yield_per = None
+ _refresh_state = None
+ _lazy_loaded_from = None
+ _legacy_uniquing = False
+
+ def __init__(
+ self,
+ compile_state,
+ statement,
+ params,
+ session,
+ load_options,
+ execution_options=None,
+ bind_arguments=None,
+ ):
+ self.load_options = load_options
+ self.execution_options = execution_options or _EMPTY_DICT
+ self.bind_arguments = bind_arguments or _EMPTY_DICT
+ self.compile_state = compile_state
+ self.query = statement
+ self.session = session
+ self.loaders_require_buffering = False
+ self.loaders_require_uniquing = False
+ self.params = params
+
+ self.propagated_loader_options = {
+ # issue 7447.
+ # propagated loader options will be present on loaded InstanceState
+ # objects under state.load_options and are typically used by
+ # LazyLoader to apply options to the SELECT statement it emits.
+ # For compile state options (i.e. loader strategy options), these
+ # need to line up with the ".load_path" attribute which in
+ # loader.py is pulled from context.compile_state.current_path.
+ # so, this means these options have to be the ones from the
+ # *cached* statement that's travelling with compile_state, not the
+ # *current* statement which won't match up for an ad-hoc
+ # AliasedClass
+ cached_o
+ for cached_o in compile_state.select_statement._with_options
+ if cached_o.propagate_to_loaders and cached_o._is_compile_state
+ } | {
+ # for user defined loader options that are not "compile state",
+ # those just need to be present as they are
+ uncached_o
+ for uncached_o in statement._with_options
+ if uncached_o.propagate_to_loaders
+ and not uncached_o._is_compile_state
+ }
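+        # as a rough illustration: a loader strategy option such as
+        # selectinload(A.bs) carried by the cached statement is collected
+        # by the first set above, while a plain UserDefinedOption with
+        # propagate_to_loaders set, attached to the current statement, is
+        # collected by the second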
+
+ self.attributes = dict(compile_state.attributes)
+
+ self.autoflush = load_options._autoflush
+ self.populate_existing = load_options._populate_existing
+ self.invoke_all_eagers = load_options._invoke_all_eagers
+ self.version_check = load_options._version_check
+ self.refresh_state = load_options._refresh_state
+ self.yield_per = load_options._yield_per
+ self.identity_token = load_options._refresh_identity_token
+
+ if self.yield_per and compile_state._no_yield_pers:
+ raise sa_exc.InvalidRequestError(
+ "The yield_per Query option is currently not "
+ "compatible with %s eager loading. Please "
+ "specify lazyload('*') or query.enable_eagerloads(False) in "
+ "order to "
+ "proceed with query.yield_per()."
+ % ", ".join(compile_state._no_yield_pers)
+ )
+
+
+_orm_load_exec_options = util.immutabledict(
+ {"_result_disable_adapt_to_context": True, "future_result": True}
+)
+
+
+class ORMCompileState(CompileState):
+ # note this is a dictionary, but the
+ # default_compile_options._with_polymorphic_adapt_map is a tuple
+ _with_polymorphic_adapt_map = _EMPTY_DICT
+
+ class default_compile_options(CacheableOptions):
+ _cache_key_traversal = [
+ ("_use_legacy_query_style", InternalTraversal.dp_boolean),
+ ("_for_statement", InternalTraversal.dp_boolean),
+ ("_bake_ok", InternalTraversal.dp_boolean),
+ (
+ "_with_polymorphic_adapt_map",
+ ExtendedInternalTraversal.dp_has_cache_key_tuples,
+ ),
+ ("_current_path", InternalTraversal.dp_has_cache_key),
+ ("_enable_single_crit", InternalTraversal.dp_boolean),
+ ("_enable_eagerloads", InternalTraversal.dp_boolean),
+ ("_orm_only_from_obj_alias", InternalTraversal.dp_boolean),
+ ("_only_load_props", InternalTraversal.dp_plain_obj),
+ ("_set_base_alias", InternalTraversal.dp_boolean),
+ ("_for_refresh_state", InternalTraversal.dp_boolean),
+ ("_render_for_subquery", InternalTraversal.dp_boolean),
+ ("_is_star", InternalTraversal.dp_boolean),
+ ]
+
+ # set to True by default from Query._statement_20(), to indicate
+ # the rendered query should look like a legacy ORM query. right
+ # now this basically indicates we should use tablename_columnname
+ # style labels. Generally indicates the statement originated
+ # from a Query object.
+ _use_legacy_query_style = False
+
+ # set *only* when we are coming from the Query.statement
+ # accessor, or a Query-level equivalent such as
+ # query.subquery(). this supersedes "toplevel".
+ _for_statement = False
+
+ _bake_ok = True
+ _with_polymorphic_adapt_map = ()
+ _current_path = _path_registry
+ _enable_single_crit = True
+ _enable_eagerloads = True
+ _orm_only_from_obj_alias = True
+ _only_load_props = None
+ _set_base_alias = False
+ _for_refresh_state = False
+ _render_for_subquery = False
+ _is_star = False
+
+ current_path = _path_registry
+
+ def __init__(self, *arg, **kw):
+ raise NotImplementedError()
+
+ def _append_dedupe_col_collection(self, obj, col_collection):
+ dedupe = self.dedupe_columns
+ if obj not in dedupe:
+ dedupe.add(obj)
+ col_collection.append(obj)
+
+ @classmethod
+ def _column_naming_convention(cls, label_style, legacy):
+
+ if legacy:
+
+ def name(col, col_name=None):
+ if col_name:
+ return col_name
+ else:
+ return getattr(col, "key")
+
+ return name
+ else:
+ return SelectState._column_naming_convention(label_style)
+
+ @classmethod
+ def create_for_statement(cls, statement_container, compiler, **kw):
+ """Create a context for a statement given a :class:`.Compiler`.
+
+ This method is always invoked in the context of SQLCompiler.process().
+
+ For a Select object, this would be invoked from
+ SQLCompiler.visit_select(). For the special FromStatement object used
+ by Query to indicate "Query.from_statement()", this is called by
+ FromStatement._compiler_dispatch() that would be called by
+ SQLCompiler.process().
+
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def get_column_descriptions(cls, statement):
+ return _column_descriptions(statement)
+
+ @classmethod
+ def orm_pre_session_exec(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ is_reentrant_invoke,
+ ):
+ if is_reentrant_invoke:
+ return statement, execution_options
+
+ (
+ load_options,
+ execution_options,
+ ) = QueryContext.default_load_options.from_execution_options(
+ "_sa_orm_load_options",
+ {"populate_existing", "autoflush", "yield_per"},
+ execution_options,
+ statement._execution_options,
+ )
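+        # roughly, an execution option such as
+        # execution_options={"populate_existing": True} passed to
+        # Session.execute() is mapped here onto
+        # load_options._populate_existing, with the assembled options
+        # travelling onward under the "_sa_orm_load_options" key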
+
+ # default execution options for ORM results:
+ # 1. _result_disable_adapt_to_context=True
+ # this will disable the ResultSetMetadata._adapt_to_context()
+ # step which we don't need, as we have result processors cached
+ # against the original SELECT statement before caching.
+ # 2. future_result=True. The ORM should **never** resolve columns
+ # in a result set based on names, only on Column objects that
+        #    are correctly adapted to the context. With the legacy result
+ # it will still attempt name-based resolution and also emit a
+ # warning.
+ if not execution_options:
+ execution_options = _orm_load_exec_options
+ else:
+ execution_options = execution_options.union(_orm_load_exec_options)
+
+ if load_options._yield_per:
+ execution_options = execution_options.union(
+ {"yield_per": load_options._yield_per}
+ )
+
+ bind_arguments["clause"] = statement
+
+ # new in 1.4 - the coercions system is leveraged to allow the
+        # "subject" mapper of a statement to be propagated to the top
+ # as the statement is built. "subject" mapper is the generally
+ # standard object used as an identifier for multi-database schemes.
+
+ # we are here based on the fact that _propagate_attrs contains
+ # "compile_state_plugin": "orm". The "plugin_subject"
+ # needs to be present as well.
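+        # as a sketch, for a statement like select(User) the coercions
+        # system will have populated _propagate_attrs along the lines of
+        # {"compile_state_plugin": "orm", "plugin_subject": inspect(User)}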
+
+ try:
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+ except KeyError:
+ assert False, "statement had 'orm' plugin but no plugin_subject"
+ else:
+ if plugin_subject:
+ bind_arguments["mapper"] = plugin_subject.mapper
+
+ if load_options._autoflush:
+ session._autoflush()
+
+ return statement, execution_options
+
+ @classmethod
+ def orm_setup_cursor_result(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ ):
+ execution_context = result.context
+ compile_state = execution_context.compiled.compile_state
+
+ # cover edge case where ORM entities used in legacy select
+ # were passed to session.execute:
+ # session.execute(legacy_select([User.id, User.name]))
+ # see test_query->test_legacy_tuple_old_select
+
+ load_options = execution_options.get(
+ "_sa_orm_load_options", QueryContext.default_load_options
+ )
+ if compile_state.compile_options._is_star:
+ return result
+
+ querycontext = QueryContext(
+ compile_state,
+ statement,
+ params,
+ session,
+ load_options,
+ execution_options,
+ bind_arguments,
+ )
+ return loading.instances(result, querycontext)
+
+ @property
+ def _lead_mapper_entities(self):
+ """return all _MapperEntity objects in the lead entities collection.
+
+ Does **not** include entities that have been replaced by
+ with_entities(), with_only_columns()
+
+ """
+ return [
+ ent for ent in self._entities if isinstance(ent, _MapperEntity)
+ ]
+
+ def _create_with_polymorphic_adapter(self, ext_info, selectable):
+ if (
+ not ext_info.is_aliased_class
+ and ext_info.mapper.persist_selectable
+ not in self._polymorphic_adapters
+ ):
+ for mp in ext_info.mapper.iterate_to_root():
+ self._mapper_loads_polymorphically_with(
+ mp,
+ sql_util.ColumnAdapter(selectable, mp._equivalent_columns),
+ )
+
+ def _mapper_loads_polymorphically_with(self, mapper, adapter):
+ for m2 in mapper._with_polymorphic_mappers or [mapper]:
+ self._polymorphic_adapters[m2] = adapter
+ for m in m2.iterate_to_root(): # TODO: redundant ?
+ self._polymorphic_adapters[m.local_table] = adapter
+
+ @classmethod
+ def _create_entities_collection(cls, query, legacy):
+ raise NotImplementedError(
+ "this method only works for ORMSelectCompileState"
+ )
+
+
+@sql.base.CompileState.plugin_for("orm", "orm_from_statement")
+class ORMFromStatementCompileState(ORMCompileState):
+ _aliased_generations = util.immutabledict()
+ _from_obj_alias = None
+ _has_mapper_entities = False
+
+ _has_orm_entities = False
+ multi_row_eager_loaders = False
+ compound_eager_adapter = None
+
+ extra_criteria_entities = _EMPTY_DICT
+ eager_joins = _EMPTY_DICT
+
+ @classmethod
+ def create_for_statement(cls, statement_container, compiler, **kw):
+
+ if compiler is not None:
+ toplevel = not compiler.stack
+ else:
+ toplevel = True
+
+ self = cls.__new__(cls)
+ self._primary_entity = None
+
+ self.use_legacy_query_style = (
+ statement_container._compile_options._use_legacy_query_style
+ )
+ self.statement_container = self.select_statement = statement_container
+ self.requested_statement = statement = statement_container.element
+
+ if statement.is_dml:
+ self.dml_table = statement.table
+
+ self._entities = []
+ self._polymorphic_adapters = {}
+ self._no_yield_pers = set()
+
+ self.compile_options = statement_container._compile_options
+
+ if (
+ self.use_legacy_query_style
+ and isinstance(statement, expression.SelectBase)
+ and not statement._is_textual
+ and not statement.is_dml
+ and statement._label_style is LABEL_STYLE_NONE
+ ):
+ self.statement = statement.set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ )
+ else:
+ self.statement = statement
+
+ self._label_convention = self._column_naming_convention(
+ statement._label_style
+ if not statement._is_textual and not statement.is_dml
+ else LABEL_STYLE_NONE,
+ self.use_legacy_query_style,
+ )
+
+ _QueryEntity.to_compile_state(
+ self,
+ statement_container._raw_columns,
+ self._entities,
+ is_current_entities=True,
+ )
+
+ self.current_path = statement_container._compile_options._current_path
+
+ if toplevel and statement_container._with_options:
+ self.attributes = {"_unbound_load_dedupes": set()}
+ self.global_attributes = compiler._global_attributes
+
+ for opt in statement_container._with_options:
+ if opt._is_compile_state:
+ opt.process_compile_state(self)
+
+ else:
+ self.attributes = {}
+ self.global_attributes = compiler._global_attributes
+
+ if statement_container._with_context_options:
+ for fn, key in statement_container._with_context_options:
+ fn(self)
+
+ self.primary_columns = []
+ self.secondary_columns = []
+ self.dedupe_columns = set()
+ self.create_eager_joins = []
+ self._fallback_from_clauses = []
+
+ self.order_by = None
+
+ if isinstance(
+ self.statement, (expression.TextClause, expression.UpdateBase)
+ ):
+
+ self.extra_criteria_entities = {}
+
+ # setup for all entities. Currently, this is not useful
+ # for eager loaders, as the eager loaders that work are able
+ # to do their work entirely in row_processor.
+ for entity in self._entities:
+ entity.setup_compile_state(self)
+
+ # we did the setup just to get primary columns.
+ self.statement = _AdHocColumnsStatement(
+ self.statement, self.primary_columns
+ )
+ else:
+ # allow TextualSelect with implicit columns as well
+ # as select() with ad-hoc columns, see test_query::TextTest
+ self._from_obj_alias = sql.util.ColumnAdapter(
+ self.statement, adapt_on_names=True
+ )
+ # set up for eager loaders, however if we fix subqueryload
+ # it should not need to do this here. the model of eager loaders
+ # that can work entirely in row_processor might be interesting
+ # here though subqueryloader has a lot of upfront work to do
+ # see test/orm/test_query.py -> test_related_eagerload_against_text
+ # for where this part makes a difference. would rather have
+ # subqueryload figure out what it needs more intelligently.
+ # for entity in self._entities:
+ # entity.setup_compile_state(self)
+
+ return self
+
+ def _adapt_col_list(self, cols, current_adapter):
+ return cols
+
+ def _get_current_adapter(self):
+ return None
+
+
+class _AdHocColumnsStatement(ClauseElement):
+ """internal object created to somewhat act like a SELECT when we
+ are selecting columns from a DML RETURNING.
+
+    """
+
+ __visit_name__ = None
+
+ def __init__(self, text, columns):
+ self.element = text
+ self.column_args = [
+ coercions.expect(roles.ColumnsClauseRole, c) for c in columns
+ ]
+
+ def _generate_cache_key(self):
+ raise NotImplementedError()
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ raise NotImplementedError()
+
+ def _compiler_dispatch(
+ self, compiler, compound_index=None, asfrom=False, **kw
+ ):
+ """provide a fixed _compiler_dispatch method."""
+
+ toplevel = not compiler.stack
+ entry = (
+ compiler._default_stack_entry if toplevel else compiler.stack[-1]
+ )
+
+ populate_result_map = (
+ toplevel
+ # these two might not be needed
+ or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
+ or entry.get("need_result_map_for_nested", False)
+ )
+
+ if populate_result_map:
+ compiler._ordered_columns = (
+ compiler._textual_ordered_columns
+ ) = False
+
+ # enable looser result column matching. this is shown to be
+ # needed by test_query.py::TextTest
+ compiler._loose_column_name_matching = True
+
+ for c in self.column_args:
+ compiler.process(
+ c,
+ within_columns_clause=True,
+ add_to_result_map=compiler._add_to_result_map,
+ )
+ return compiler.process(self.element, **kw)
+
+
+@sql.base.CompileState.plugin_for("orm", "select")
+class ORMSelectCompileState(ORMCompileState, SelectState):
+ _joinpath = _joinpoint = _EMPTY_DICT
+
+ _memoized_entities = _EMPTY_DICT
+
+ _from_obj_alias = None
+ _has_mapper_entities = False
+
+ _has_orm_entities = False
+ multi_row_eager_loaders = False
+ compound_eager_adapter = None
+
+ correlate = None
+ correlate_except = None
+ _where_criteria = ()
+ _having_criteria = ()
+
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw):
+ """compiler hook, we arrive here from compiler.visit_select() only."""
+
+ self = cls.__new__(cls)
+
+ if compiler is not None:
+ toplevel = not compiler.stack
+ self.global_attributes = compiler._global_attributes
+ else:
+ toplevel = True
+ self.global_attributes = {}
+
+ select_statement = statement
+
+ # if we are a select() that was never a legacy Query, we won't
+ # have ORM level compile options.
+ statement._compile_options = cls.default_compile_options.safe_merge(
+ statement._compile_options
+ )
+
+ if select_statement._execution_options:
+ # execution options should not impact the compilation of a
+ # query, and at the moment subqueryloader is putting some things
+ # in here that we explicitly don't want stuck in a cache.
+ self.select_statement = select_statement._clone()
+ self.select_statement._execution_options = util.immutabledict()
+ else:
+ self.select_statement = select_statement
+
+ # indicates this select() came from Query.statement
+ self.for_statement = select_statement._compile_options._for_statement
+
+ # generally if we are from Query or directly from a select()
+ self.use_legacy_query_style = (
+ select_statement._compile_options._use_legacy_query_style
+ )
+
+ self._entities = []
+ self._primary_entity = None
+ self._aliased_generations = {}
+ self._polymorphic_adapters = {}
+ self._no_yield_pers = set()
+
+ # legacy: only for query.with_polymorphic()
+ if select_statement._compile_options._with_polymorphic_adapt_map:
+ self._with_polymorphic_adapt_map = dict(
+ select_statement._compile_options._with_polymorphic_adapt_map
+ )
+ self._setup_with_polymorphics()
+
+ self.compile_options = select_statement._compile_options
+
+ if not toplevel:
+ # for subqueries, turn off eagerloads and set
+ # "render_for_subquery".
+ self.compile_options += {
+ "_enable_eagerloads": False,
+ "_render_for_subquery": True,
+ }
+
+ # determine label style. we can make different decisions here.
+ # at the moment, trying to see if we can always use DISAMBIGUATE_ONLY
+ # rather than LABEL_STYLE_NONE, and if we can use disambiguate style
+ # for new style ORM selects too.
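+        # rough summary of the branches below: a legacy Query with the
+        # default label style renders tablename_columnname labels, unless
+        # it arrived via Query.statement, in which case disambiguate-only
+        # labeling is used; a 2.0-style select() keeps whatever label
+        # style it already carries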
+ if (
+ self.use_legacy_query_style
+ and self.select_statement._label_style is LABEL_STYLE_LEGACY_ORM
+ ):
+ if not self.for_statement:
+ self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL
+ else:
+ self.label_style = LABEL_STYLE_DISAMBIGUATE_ONLY
+ else:
+ self.label_style = self.select_statement._label_style
+
+ if select_statement._memoized_select_entities:
+ self._memoized_entities = {
+ memoized_entities: _QueryEntity.to_compile_state(
+ self,
+ memoized_entities._raw_columns,
+ [],
+ is_current_entities=False,
+ )
+ for memoized_entities in (
+ select_statement._memoized_select_entities
+ )
+ }
+
+ # label_convention is stateful and will yield deduping keys if it
+ # sees the same key twice. therefore it's important that it is not
+ # invoked for the above "memoized" entities that aren't actually
+ # in the columns clause
+ self._label_convention = self._column_naming_convention(
+ statement._label_style, self.use_legacy_query_style
+ )
+
+ _QueryEntity.to_compile_state(
+ self,
+ select_statement._raw_columns,
+ self._entities,
+ is_current_entities=True,
+ )
+
+ self.current_path = select_statement._compile_options._current_path
+
+ self.eager_order_by = ()
+
+ if toplevel and (
+ select_statement._with_options
+ or select_statement._memoized_select_entities
+ ):
+ self.attributes = {"_unbound_load_dedupes": set()}
+
+ for (
+ memoized_entities
+ ) in select_statement._memoized_select_entities:
+ for opt in memoized_entities._with_options:
+ if opt._is_compile_state:
+ opt.process_compile_state_replaced_entities(
+ self,
+ [
+ ent
+ for ent in self._memoized_entities[
+ memoized_entities
+ ]
+ if isinstance(ent, _MapperEntity)
+ ],
+ )
+
+ for opt in self.select_statement._with_options:
+ if opt._is_compile_state:
+ opt.process_compile_state(self)
+ else:
+ self.attributes = {}
+
+ if select_statement._with_context_options:
+ for fn, key in select_statement._with_context_options:
+ fn(self)
+
+ self.primary_columns = []
+ self.secondary_columns = []
+ self.dedupe_columns = set()
+ self.eager_joins = {}
+ self.extra_criteria_entities = {}
+ self.create_eager_joins = []
+ self._fallback_from_clauses = []
+
+ # normalize the FROM clauses early by themselves, as this makes
+ # it an easier job when we need to assemble a JOIN onto these,
+ # for select.join() as well as joinedload(). As of 1.4 there are now
+ # potentially more complex sets of FROM objects here as the use
+ # of lambda statements for lazyload, load_on_pk etc. uses more
+ # cloning of the select() construct. See #6495
+ self.from_clauses = self._normalize_froms(
+ info.selectable for info in select_statement._from_obj
+ )
+
+ # this is a fairly arbitrary break into a second method,
+ # so it might be nicer to break up create_for_statement()
+ # and _setup_for_generate into three or four logical sections
+ self._setup_for_generate()
+
+ SelectState.__init__(self, self.statement, compiler, **kw)
+
+ return self
+
+ def _setup_for_generate(self):
+ query = self.select_statement
+
+ self.statement = None
+ self._join_entities = ()
+
+ if self.compile_options._set_base_alias:
+ self._set_select_from_alias()
+
+ for memoized_entities in query._memoized_select_entities:
+ if memoized_entities._setup_joins:
+ self._join(
+ memoized_entities._setup_joins,
+ self._memoized_entities[memoized_entities],
+ )
+ if memoized_entities._legacy_setup_joins:
+ self._legacy_join(
+ memoized_entities._legacy_setup_joins,
+ self._memoized_entities[memoized_entities],
+ )
+
+ if query._setup_joins:
+ self._join(query._setup_joins, self._entities)
+
+ if query._legacy_setup_joins:
+ self._legacy_join(query._legacy_setup_joins, self._entities)
+
+ current_adapter = self._get_current_adapter()
+
+ if query._where_criteria:
+ self._where_criteria = query._where_criteria
+
+ if current_adapter:
+ self._where_criteria = tuple(
+ current_adapter(crit, True)
+ for crit in self._where_criteria
+ )
+
+ # TODO: some complexity with order_by here was due to mapper.order_by.
+ # now that this is removed we can hopefully make order_by /
+ # group_by act identically to how they are in Core select.
+ self.order_by = (
+ self._adapt_col_list(query._order_by_clauses, current_adapter)
+ if current_adapter and query._order_by_clauses not in (None, False)
+ else query._order_by_clauses
+ )
+
+ if query._having_criteria:
+ self._having_criteria = tuple(
+ current_adapter(crit, True) if current_adapter else crit
+ for crit in query._having_criteria
+ )
+
+ self.group_by = (
+ self._adapt_col_list(
+ util.flatten_iterator(query._group_by_clauses), current_adapter
+ )
+ if current_adapter and query._group_by_clauses not in (None, False)
+ else query._group_by_clauses or None
+ )
+
+ if self.eager_order_by:
+ adapter = self.from_clauses[0]._target_adapter
+ self.eager_order_by = adapter.copy_and_process(self.eager_order_by)
+
+ if query._distinct_on:
+ self.distinct_on = self._adapt_col_list(
+ query._distinct_on, current_adapter
+ )
+ else:
+ self.distinct_on = ()
+
+ self.distinct = query._distinct
+
+ if query._correlate:
+ # ORM mapped entities that are mapped to joins can be passed
+ # to .correlate, so here they are broken into their component
+ # tables.
+ self.correlate = tuple(
+ util.flatten_iterator(
+ sql_util.surface_selectables(s) if s is not None else None
+ for s in query._correlate
+ )
+ )
+ elif query._correlate_except is not None:
+ self.correlate_except = tuple(
+ util.flatten_iterator(
+ sql_util.surface_selectables(s) if s is not None else None
+ for s in query._correlate_except
+ )
+ )
+ elif not query._auto_correlate:
+ self.correlate = (None,)
+
+ # PART II
+
+ self._for_update_arg = query._for_update_arg
+
+ if self.compile_options._is_star and (len(self._entities) != 1):
+ raise sa_exc.CompileError(
+ "Can't generate ORM query that includes multiple expressions "
+ "at the same time as '*'; query for '*' alone if present"
+ )
+ for entity in self._entities:
+ entity.setup_compile_state(self)
+
+ for rec in self.create_eager_joins:
+ strategy = rec[0]
+ strategy(self, *rec[1:])
+
+        # else "load from discrete FROMs" mode,
+        # i.e. when each _MapperEntity has its own FROM
+
+ if self.compile_options._enable_single_crit:
+ self._adjust_for_extra_criteria()
+
+ if not self.primary_columns:
+ if self.compile_options._only_load_props:
+ raise sa_exc.InvalidRequestError(
+ "No column-based properties specified for "
+ "refresh operation. Use session.expire() "
+ "to reload collections and related items."
+ )
+ else:
+ raise sa_exc.InvalidRequestError(
+ "Query contains no columns with which to SELECT from."
+ )
+
+ if not self.from_clauses:
+ self.from_clauses = list(self._fallback_from_clauses)
+
+ if self.order_by is False:
+ self.order_by = None
+
+ if self.multi_row_eager_loaders and self._should_nest_selectable:
+ self.statement = self._compound_eager_statement()
+ else:
+ self.statement = self._simple_statement()
+
+ if self.for_statement:
+ ezero = self._mapper_zero()
+ if ezero is not None:
+ # TODO: this goes away once we get rid of the deep entity
+ # thing
+ self.statement = self.statement._annotate(
+ {"deepentity": ezero}
+ )
+
+ @classmethod
+ def _create_entities_collection(cls, query, legacy):
+ """Creates a partial ORMSelectCompileState that includes
+ the full collection of _MapperEntity and other _QueryEntity objects.
+
+ Supports a few remaining use cases that are pre-compilation
+ but still need to gather some of the column / adaption information.
+
+ """
+ self = cls.__new__(cls)
+
+ self._entities = []
+ self._primary_entity = None
+ self._aliased_generations = {}
+ self._polymorphic_adapters = {}
+
+ compile_options = cls.default_compile_options.safe_merge(
+ query._compile_options
+ )
+ # legacy: only for query.with_polymorphic()
+ if compile_options._with_polymorphic_adapt_map:
+ self._with_polymorphic_adapt_map = dict(
+ compile_options._with_polymorphic_adapt_map
+ )
+ self._setup_with_polymorphics()
+
+ self._label_convention = self._column_naming_convention(
+ query._label_style, legacy
+ )
+
+ # entities will also set up polymorphic adapters for mappers
+ # that have with_polymorphic configured
+ _QueryEntity.to_compile_state(
+ self, query._raw_columns, self._entities, is_current_entities=True
+ )
+ return self
+
+ @classmethod
+ def determine_last_joined_entity(cls, statement):
+ setup_joins = statement._setup_joins
+
+ if not setup_joins:
+ return None
+
+ (target, onclause, from_, flags) = setup_joins[-1]
+
+ if isinstance(target, interfaces.PropComparator):
+ return target.entity
+ else:
+ return target
+
+ @classmethod
+ def all_selected_columns(cls, statement):
+ for element in statement._raw_columns:
+ if (
+ element.is_selectable
+ and "entity_namespace" in element._annotations
+ ):
+ ens = element._annotations["entity_namespace"]
+ if not ens.is_mapper and not ens.is_aliased_class:
+ for elem in _select_iterables([element]):
+ yield elem
+ else:
+ for elem in _select_iterables(ens._all_column_expressions):
+ yield elem
+ else:
+ for elem in _select_iterables([element]):
+ yield elem
+
+ @classmethod
+ def get_columns_clause_froms(cls, statement):
+ return cls._normalize_froms(
+ itertools.chain.from_iterable(
+ element._from_objects
+ if "parententity" not in element._annotations
+ else [
+ element._annotations["parententity"].__clause_element__()
+ ]
+ for element in statement._raw_columns
+ )
+ )
+
+ @classmethod
+ @util.preload_module("sqlalchemy.orm.query")
+ def from_statement(cls, statement, from_statement):
+ query = util.preloaded.orm_query
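+        # illustrative usage that reaches this path (assuming a mapped
+        # User class):
+        #   select(User).from_statement(text("SELECT * FROM users"))
+        # where "statement" carries the entities and "from_statement"
+        # supplies the rows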
+
+ from_statement = coercions.expect(
+ roles.ReturnsRowsRole,
+ from_statement,
+ apply_propagate_attrs=statement,
+ )
+
+ stmt = query.FromStatement(statement._raw_columns, from_statement)
+
+ stmt.__dict__.update(
+ _with_options=statement._with_options,
+ _with_context_options=statement._with_context_options,
+ _execution_options=statement._execution_options,
+ _propagate_attrs=statement._propagate_attrs,
+ )
+ return stmt
+
+ def _setup_with_polymorphics(self):
+ # legacy: only for query.with_polymorphic()
+ for ext_info, wp in self._with_polymorphic_adapt_map.items():
+ self._mapper_loads_polymorphically_with(ext_info, wp._adapter)
+
+ def _set_select_from_alias(self):
+
+ query = self.select_statement # query
+
+ assert self.compile_options._set_base_alias
+ assert len(query._from_obj) == 1
+
+ adapter = self._get_select_from_alias_from_obj(query._from_obj[0])
+ if adapter:
+ self.compile_options += {"_enable_single_crit": False}
+ self._from_obj_alias = adapter
+
+ def _get_select_from_alias_from_obj(self, from_obj):
+ info = from_obj
+
+ if "parententity" in info._annotations:
+ info = info._annotations["parententity"]
+
+ if hasattr(info, "mapper"):
+ if not info.is_aliased_class:
+ raise sa_exc.ArgumentError(
+ "A selectable (FromClause) instance is "
+ "expected when the base alias is being set."
+ )
+ else:
+ return info._adapter
+
+ elif isinstance(info.selectable, sql.selectable.AliasedReturnsRows):
+ equivs = self._all_equivs()
+ return sql_util.ColumnAdapter(info, equivs)
+ else:
+ return None
+
+ def _mapper_zero(self):
+ """return the Mapper associated with the first QueryEntity."""
+ return self._entities[0].mapper
+
+ def _entity_zero(self):
+ """Return the 'entity' (mapper or AliasedClass) associated
+ with the first QueryEntity, or alternatively the 'select from'
+ entity if specified."""
+
+ for ent in self.from_clauses:
+ if "parententity" in ent._annotations:
+ return ent._annotations["parententity"]
+ for qent in self._entities:
+ if qent.entity_zero:
+ return qent.entity_zero
+
+ return None
+
+ def _only_full_mapper_zero(self, methname):
+ if self._entities != [self._primary_entity]:
+ raise sa_exc.InvalidRequestError(
+ "%s() can only be used against "
+ "a single mapped class." % methname
+ )
+ return self._primary_entity.entity_zero
+
+ def _only_entity_zero(self, rationale=None):
+ if len(self._entities) > 1:
+ raise sa_exc.InvalidRequestError(
+ rationale
+ or "This operation requires a Query "
+ "against a single mapper."
+ )
+ return self._entity_zero()
+
+ def _all_equivs(self):
+ equivs = {}
+
+ for memoized_entities in self._memoized_entities.values():
+ for ent in [
+ ent
+ for ent in memoized_entities
+ if isinstance(ent, _MapperEntity)
+ ]:
+ equivs.update(ent.mapper._equivalent_columns)
+
+ for ent in [
+ ent for ent in self._entities if isinstance(ent, _MapperEntity)
+ ]:
+ equivs.update(ent.mapper._equivalent_columns)
+ return equivs
+
+ def _compound_eager_statement(self):
+ # for eager joins present and LIMIT/OFFSET/DISTINCT,
+ # wrap the query inside a select,
+ # then append eager joins onto that
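+        # the end result is roughly of the form (sketch, assuming a
+        # joinedload() of a collection combined with LIMIT):
+        #   SELECT anon_1.*, address_1.* FROM
+        #     (SELECT users.* FROM users ... LIMIT :n) AS anon_1
+        #     LEFT OUTER JOIN addresses AS address_1 ON ...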
+
+ if self.order_by:
+ # the default coercion for ORDER BY is now the OrderByRole,
+ # which adds an additional post coercion to ByOfRole in that
+ # elements are converted into label references. For the
+ # eager load / subquery wrapping case, we need to un-coerce
+ # the original expressions outside of the label references
+ # in order to have them render.
+ unwrapped_order_by = [
+ elem.element
+ if isinstance(elem, sql.elements._label_reference)
+ else elem
+ for elem in self.order_by
+ ]
+
+ order_by_col_expr = sql_util.expand_column_list_from_order_by(
+ self.primary_columns, unwrapped_order_by
+ )
+ else:
+ order_by_col_expr = []
+ unwrapped_order_by = None
+
+ # put FOR UPDATE on the inner query, where MySQL will honor it,
+ # as well as if it has an OF so PostgreSQL can use it.
+ inner = self._select_statement(
+ self.primary_columns
+ + [c for c in order_by_col_expr if c not in self.dedupe_columns],
+ self.from_clauses,
+ self._where_criteria,
+ self._having_criteria,
+ self.label_style,
+ self.order_by,
+ for_update=self._for_update_arg,
+ hints=self.select_statement._hints,
+ statement_hints=self.select_statement._statement_hints,
+ correlate=self.correlate,
+ correlate_except=self.correlate_except,
+ **self._select_args
+ )
+
+ inner = inner.alias()
+
+ equivs = self._all_equivs()
+
+ self.compound_eager_adapter = sql_util.ColumnAdapter(inner, equivs)
+
+ statement = future.select(
+ *([inner] + self.secondary_columns) # use_labels=self.labels
+ )
+ statement._label_style = self.label_style
+
+ # Oracle however does not allow FOR UPDATE on the subquery,
+ # and the Oracle dialect ignores it, plus for PostgreSQL, MySQL
+ # we expect that all elements of the row are locked, so also put it
+ # on the outside (except in the case of PG when OF is used)
+ if (
+ self._for_update_arg is not None
+ and self._for_update_arg.of is None
+ ):
+ statement._for_update_arg = self._for_update_arg
+
+ from_clause = inner
+ for eager_join in self.eager_joins.values():
+ # EagerLoader places a 'stop_on' attribute on the join,
+ # giving us a marker as to where the "splice point" of
+ # the join should be
+ from_clause = sql_util.splice_joins(
+ from_clause, eager_join, eager_join.stop_on
+ )
+
+ statement.select_from.non_generative(statement, from_clause)
+
+ if unwrapped_order_by:
+ statement.order_by.non_generative(
+ statement,
+ *self.compound_eager_adapter.copy_and_process(
+ unwrapped_order_by
+ )
+ )
+
+ statement.order_by.non_generative(statement, *self.eager_order_by)
+ return statement
+
+ def _simple_statement(self):
+
+ if (
+ self.compile_options._use_legacy_query_style
+ and (self.distinct and not self.distinct_on)
+ and self.order_by
+ ):
+ to_add = sql_util.expand_column_list_from_order_by(
+ self.primary_columns, self.order_by
+ )
+ if to_add:
+ util.warn_deprecated_20(
+ "ORDER BY columns added implicitly due to "
+ "DISTINCT is deprecated and will be removed in "
+ "SQLAlchemy 2.0. SELECT statements with DISTINCT "
+ "should be written to explicitly include the appropriate "
+ "columns in the columns clause"
+ )
+ self.primary_columns += to_add
+
+ statement = self._select_statement(
+ self.primary_columns + self.secondary_columns,
+ tuple(self.from_clauses) + tuple(self.eager_joins.values()),
+ self._where_criteria,
+ self._having_criteria,
+ self.label_style,
+ self.order_by,
+ for_update=self._for_update_arg,
+ hints=self.select_statement._hints,
+ statement_hints=self.select_statement._statement_hints,
+ correlate=self.correlate,
+ correlate_except=self.correlate_except,
+ **self._select_args
+ )
+
+ if self.eager_order_by:
+ statement.order_by.non_generative(statement, *self.eager_order_by)
+ return statement
+
+ def _select_statement(
+ self,
+ raw_columns,
+ from_obj,
+ where_criteria,
+ having_criteria,
+ label_style,
+ order_by,
+ for_update,
+ hints,
+ statement_hints,
+ correlate,
+ correlate_except,
+ limit_clause,
+ offset_clause,
+ fetch_clause,
+ fetch_clause_options,
+ distinct,
+ distinct_on,
+ prefixes,
+ suffixes,
+ group_by,
+ ):
+
+ Select = future.Select
+ statement = Select._create_raw_select(
+ _raw_columns=raw_columns,
+ _from_obj=from_obj,
+ _label_style=label_style,
+ )
+
+ if where_criteria:
+ statement._where_criteria = where_criteria
+ if having_criteria:
+ statement._having_criteria = having_criteria
+
+ if order_by:
+ statement._order_by_clauses += tuple(order_by)
+
+ if distinct_on:
+ statement.distinct.non_generative(statement, *distinct_on)
+ elif distinct:
+ statement.distinct.non_generative(statement)
+
+ if group_by:
+ statement._group_by_clauses += tuple(group_by)
+
+ statement._limit_clause = limit_clause
+ statement._offset_clause = offset_clause
+ statement._fetch_clause = fetch_clause
+ statement._fetch_clause_options = fetch_clause_options
+
+ if prefixes:
+ statement._prefixes = prefixes
+
+ if suffixes:
+ statement._suffixes = suffixes
+
+ statement._for_update_arg = for_update
+
+ if hints:
+ statement._hints = hints
+ if statement_hints:
+ statement._statement_hints = statement_hints
+
+ if correlate:
+ statement.correlate.non_generative(statement, *correlate)
+
+ if correlate_except is not None:
+ statement.correlate_except.non_generative(
+ statement, *correlate_except
+ )
+
+ return statement
+
+ def _adapt_polymorphic_element(self, element):
+ if "parententity" in element._annotations:
+ search = element._annotations["parententity"]
+ alias = self._polymorphic_adapters.get(search, None)
+ if alias:
+ return alias.adapt_clause(element)
+
+ if isinstance(element, expression.FromClause):
+ search = element
+ elif hasattr(element, "table"):
+ search = element.table
+ else:
+ return None
+
+ alias = self._polymorphic_adapters.get(search, None)
+ if alias:
+ return alias.adapt_clause(element)
+
+ def _adapt_aliased_generation(self, element):
+ # this is crazy logic that I look forward to blowing away
+ # when aliased=True is gone :)
+ if "aliased_generation" in element._annotations:
+ for adapter in self._aliased_generations.get(
+ element._annotations["aliased_generation"], ()
+ ):
+ replaced_elem = adapter.replace(element)
+ if replaced_elem is not None:
+ return replaced_elem
+
+ return None
+
+ def _adapt_col_list(self, cols, current_adapter):
+ if current_adapter:
+ return [current_adapter(o, True) for o in cols]
+ else:
+ return cols
+
+ def _get_current_adapter(self):
+
+ adapters = []
+
+ if self._from_obj_alias:
+ # used for legacy going forward for query set_ops, e.g.
+ # union(), union_all(), etc.
+ # 1.4 and previously, also used for from_self(),
+ # select_entity_from()
+ #
+ # for the "from obj" alias, apply extra rule to the
+ # 'ORM only' check, if this query were generated from a
+ # subquery of itself, i.e. _from_selectable(), apply adaption
+ # to all SQL constructs.
+ adapters.append(
+ (
+ False
+ if self.compile_options._orm_only_from_obj_alias
+ else True,
+ self._from_obj_alias.replace,
+ )
+ )
+
+ # vvvvvvvvvvvvvvv legacy vvvvvvvvvvvvvvvvvv
+ # this can totally go away when we remove join(..., aliased=True)
+ if self._aliased_generations:
+ adapters.append((False, self._adapt_aliased_generation))
+ # ^^^^^^^^^^^^^ legacy ^^^^^^^^^^^^^^^^^^^^^
+
+ # this was *hopefully* the only adapter we were going to need
+ # going forward...however, we unfortunately need _from_obj_alias
+ # for query.union(), which we can't drop
+ if self._polymorphic_adapters:
+ adapters.append((False, self._adapt_polymorphic_element))
+
+ if not adapters:
+ return None
+
+ def _adapt_clause(clause, as_filter):
+ # do we adapt all expression elements or only those
+ # tagged as 'ORM' constructs ?
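+            # as an example (sketch): in a legacy query.union(), an ORM
+            # column such as User.id carries "parententity" in its
+            # annotations and is replaced by the _from_obj_alias adapter
+            # with the corresponding column of the wrapping subquery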
+
+ def replace(elem):
+ is_orm_adapt = (
+ "_orm_adapt" in elem._annotations
+ or "parententity" in elem._annotations
+ )
+ for always_adapt, adapter in adapters:
+ if is_orm_adapt or always_adapt:
+ e = adapter(elem)
+ if e is not None:
+ return e
+
+ return visitors.replacement_traverse(clause, {}, replace)
+
+ return _adapt_clause
+
+ def _join(self, args, entities_collection):
+ for (right, onclause, from_, flags) in args:
+ isouter = flags["isouter"]
+ full = flags["full"]
+ # maybe?
+ self._reset_joinpoint()
+
+ right = inspect(right)
+ if onclause is not None:
+ onclause = inspect(onclause)
+
+ if onclause is None and isinstance(
+ right, interfaces.PropComparator
+ ):
+ # determine onclause/right_entity. still need to think
+ # about how to best organize this since we are getting:
+ #
+ #
+ # q.join(Entity, Parent.property)
+ # q.join(Parent.property)
+ # q.join(Parent.property.of_type(Entity))
+ # q.join(some_table)
+ # q.join(some_table, some_parent.c.id==some_table.c.parent_id)
+ #
+ # is this still too many choices? how do we handle this
+ # when sometimes "right" is implied and sometimes not?
+ #
+ onclause = right
+ right = None
+ elif "parententity" in right._annotations:
+ right = right._annotations["parententity"]
+
+ if onclause is None:
+ if not right.is_selectable and not hasattr(right, "mapper"):
+ raise sa_exc.ArgumentError(
+ "Expected mapped entity or "
+ "selectable/table as join target"
+ )
+
+ of_type = None
+
+ if isinstance(onclause, interfaces.PropComparator):
+ # descriptor/property given (or determined); this tells us
+ # explicitly what the expected "left" side of the join is.
+
+ of_type = getattr(onclause, "_of_type", None)
+
+ if right is None:
+ if of_type:
+ right = of_type
+ else:
+ right = onclause.property
+
+ try:
+ right = right.entity
+ except AttributeError as err:
+ util.raise_(
+ sa_exc.ArgumentError(
+ "Join target %s does not refer to a "
+ "mapped entity" % right
+ ),
+ replace_context=err,
+ )
+
+ left = onclause._parententity
+
+ alias = self._polymorphic_adapters.get(left, None)
+
+ # could be None or could be ColumnAdapter also
+ if isinstance(alias, ORMAdapter) and alias.mapper.isa(left):
+ left = alias.aliased_class
+ onclause = getattr(left, onclause.key)
+
+ prop = onclause.property
+ if not isinstance(onclause, attributes.QueryableAttribute):
+ onclause = prop
+
+ # TODO: this is where "check for path already present"
+ # would occur. see if this still applies?
+
+ if from_ is not None:
+ if (
+ from_ is not left
+ and from_._annotations.get("parententity", None)
+ is not left
+ ):
+ raise sa_exc.InvalidRequestError(
+ "explicit from clause %s does not match left side "
+ "of relationship attribute %s"
+ % (
+ from_._annotations.get("parententity", from_),
+ onclause,
+ )
+ )
+ elif from_ is not None:
+ prop = None
+ left = from_
+ else:
+ # no descriptor/property given; we will need to figure out
+ # what the effective "left" side is
+ prop = left = None
+
+ # figure out the final "left" and "right" sides and create an
+ # ORMJoin to add to our _from_obj tuple
+ self._join_left_to_right(
+ entities_collection,
+ left,
+ right,
+ onclause,
+ prop,
+ False,
+ False,
+ isouter,
+ full,
+ )
+
+ def _legacy_join(self, args, entities_collection):
+ """consumes arguments from join() or outerjoin(), places them into a
+ consistent format with which to form the actual JOIN constructs.
+
+ """
+ for (right, onclause, left, flags) in args:
+
+ outerjoin = flags["isouter"]
+ create_aliases = flags["aliased"]
+ from_joinpoint = flags["from_joinpoint"]
+ full = flags["full"]
+ aliased_generation = flags["aliased_generation"]
+
+ # do a quick inspect to accommodate for a lambda
+ if right is not None and not isinstance(right, util.string_types):
+ right = inspect(right)
+ if onclause is not None and not isinstance(
+ onclause, util.string_types
+ ):
+ onclause = inspect(onclause)
+
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvv
+ if not from_joinpoint:
+ self._reset_joinpoint()
+ else:
+ prev_aliased_generation = self._joinpoint.get(
+ "aliased_generation", None
+ )
+ if not aliased_generation:
+ aliased_generation = prev_aliased_generation
+ elif prev_aliased_generation:
+ self._aliased_generations[
+ aliased_generation
+ ] = self._aliased_generations.get(
+ prev_aliased_generation, ()
+ )
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ if (
+ isinstance(
+ right, (interfaces.PropComparator, util.string_types)
+ )
+ and onclause is None
+ ):
+ onclause = right
+ right = None
+ elif "parententity" in right._annotations:
+ right = right._annotations["parententity"]
+
+ if onclause is None:
+ if not right.is_selectable and not hasattr(right, "mapper"):
+ raise sa_exc.ArgumentError(
+ "Expected mapped entity or "
+ "selectable/table as join target"
+ )
+
+ if isinstance(onclause, interfaces.PropComparator):
+ of_type = getattr(onclause, "_of_type", None)
+ else:
+ of_type = None
+
+ if isinstance(onclause, util.string_types):
+ # string given, e.g. query(Foo).join("bar").
+ # we look to the left entity or what we last joined
+ # towards
+ onclause = _entity_namespace_key(
+ inspect(self._joinpoint_zero()), onclause
+ )
+
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
+ # check for q.join(Class.propname, from_joinpoint=True)
+ # and Class corresponds at the mapper level to the current
+ # joinpoint. this match intentionally looks for a non-aliased
+ # class-bound descriptor as the onclause and if it matches the
+ # current joinpoint at the mapper level, it's used. This
+ # is a very old use case that is intended to make it easier
+ # to work with the aliased=True flag, which is also something
+ # that probably shouldn't exist on join() due to its high
+ # complexity/usefulness ratio
+ elif from_joinpoint and isinstance(
+ onclause, interfaces.PropComparator
+ ):
+ jp0 = self._joinpoint_zero()
+ info = inspect(jp0)
+
+ if getattr(info, "mapper", None) is onclause._parententity:
+ onclause = _entity_namespace_key(info, onclause.key)
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ if isinstance(onclause, interfaces.PropComparator):
+ # descriptor/property given (or determined); this tells us
+ # explicitly what the expected "left" side of the join is.
+ if right is None:
+ if of_type:
+ right = of_type
+ else:
+ right = onclause.property
+
+ try:
+ right = right.entity
+ except AttributeError as err:
+ util.raise_(
+ sa_exc.ArgumentError(
+ "Join target %s does not refer to a "
+ "mapped entity" % right
+ ),
+ replace_context=err,
+ )
+
+ left = onclause._parententity
+
+ alias = self._polymorphic_adapters.get(left, None)
+
+ # could be None or could be ColumnAdapter also
+ if isinstance(alias, ORMAdapter) and alias.mapper.isa(left):
+ left = alias.aliased_class
+ onclause = getattr(left, onclause.key)
+
+ prop = onclause.property
+ if not isinstance(onclause, attributes.QueryableAttribute):
+ onclause = prop
+
+ if not create_aliases:
+ # check for this path already present.
+ # don't render in that case.
+ edge = (left, right, prop.key)
+ if edge in self._joinpoint:
+ # The child's prev reference might be stale --
+ # it could point to a parent older than the
+ # current joinpoint. If this is the case,
+ # then we need to update it and then fix the
+ # tree's spine with _update_joinpoint. Copy
+ # and then mutate the child, which might be
+ # shared by a different query object.
+ jp = self._joinpoint[edge].copy()
+ jp["prev"] = (edge, self._joinpoint)
+ self._update_joinpoint(jp)
+
+ continue
+
+ else:
+ # no descriptor/property given; we will need to figure out
+ # what the effective "left" side is
+ prop = left = None
+
+ # figure out the final "left" and "right" sides and create an
+ # ORMJoin to add to our _from_obj tuple
+ self._join_left_to_right(
+ entities_collection,
+ left,
+ right,
+ onclause,
+ prop,
+ create_aliases,
+ aliased_generation,
+ outerjoin,
+ full,
+ )
+
+ def _joinpoint_zero(self):
+ return self._joinpoint.get("_joinpoint_entity", self._entity_zero())
+
+ def _join_left_to_right(
+ self,
+ entities_collection,
+ left,
+ right,
+ onclause,
+ prop,
+ create_aliases,
+ aliased_generation,
+ outerjoin,
+ full,
+ ):
+ """given raw "left", "right", "onclause" parameters consumed from
+ a particular key within _join(), add a real ORMJoin object to
+ our _from_obj list (or augment an existing one)
+
+ """
+
+ if left is None:
+ # left not given (e.g. no relationship object/name specified)
+ # figure out the best "left" side based on our existing froms /
+ # entities
+ assert prop is None
+ (
+ left,
+ replace_from_obj_index,
+ use_entity_index,
+ ) = self._join_determine_implicit_left_side(
+ entities_collection, left, right, onclause
+ )
+ else:
+ # left is given via a relationship/name, or as explicit left side.
+ # Determine where in our
+ # "froms" list it should be spliced/appended as well as what
+ # existing entity it corresponds to.
+ (
+ replace_from_obj_index,
+ use_entity_index,
+ ) = self._join_place_explicit_left_side(entities_collection, left)
+
+ if left is right and not create_aliases:
+ raise sa_exc.InvalidRequestError(
+ "Can't construct a join from %s to %s, they "
+ "are the same entity" % (left, right)
+ )
+
+ # the right side as given often needs to be adapted. additionally
+ # a lot of things can be wrong with it. handle all that and
+ # get back the new effective "right" side
+ r_info, right, onclause = self._join_check_and_adapt_right_side(
+ left, right, onclause, prop, create_aliases, aliased_generation
+ )
+
+ if not r_info.is_selectable:
+ extra_criteria = self._get_extra_criteria(r_info)
+ else:
+ extra_criteria = ()
+
+ if replace_from_obj_index is not None:
+ # splice into an existing element in the
+ # self._from_obj list
+ left_clause = self.from_clauses[replace_from_obj_index]
+
+ self.from_clauses = (
+ self.from_clauses[:replace_from_obj_index]
+ + [
+ _ORMJoin(
+ left_clause,
+ right,
+ onclause,
+ isouter=outerjoin,
+ full=full,
+ _extra_criteria=extra_criteria,
+ )
+ ]
+ + self.from_clauses[replace_from_obj_index + 1 :]
+ )
+ else:
+ # add a new element to the self._from_obj list
+ if use_entity_index is not None:
+ # make use of _MapperEntity selectable, which is usually
+ # entity_zero.selectable, but if with_polymorphic() were used
+ # might be distinct
+ assert isinstance(
+ entities_collection[use_entity_index], _MapperEntity
+ )
+ left_clause = entities_collection[use_entity_index].selectable
+ else:
+ left_clause = left
+
+ self.from_clauses = self.from_clauses + [
+ _ORMJoin(
+ left_clause,
+ r_info,
+ onclause,
+ isouter=outerjoin,
+ full=full,
+ _extra_criteria=extra_criteria,
+ )
+ ]
+
+ def _join_determine_implicit_left_side(
+ self, entities_collection, left, right, onclause
+ ):
+ """When join conditions don't express the left side explicitly,
+ determine if an existing FROM or entity in this query
+ can serve as the left hand side.
+
+ """
+
+ # when we are here, it means join() was called without an ORM-
+ # specific way of telling us what the "left" side is, e.g.:
+ #
+ # join(RightEntity)
+ #
+ # or
+ #
+ # join(RightEntity, RightEntity.foo == LeftEntity.bar)
+ #
+
+ r_info = inspect(right)
+
+ replace_from_obj_index = use_entity_index = None
+
+ if self.from_clauses:
+ # we have a list of FROMs already. So by definition this
+ # join has to connect to one of those FROMs.
+
+ indexes = sql_util.find_left_clause_to_join_from(
+ self.from_clauses, r_info.selectable, onclause
+ )
+
+ if len(indexes) == 1:
+ replace_from_obj_index = indexes[0]
+ left = self.from_clauses[replace_from_obj_index]
+ elif len(indexes) > 1:
+ raise sa_exc.InvalidRequestError(
+ "Can't determine which FROM clause to join "
+ "from, there are multiple FROMS which can "
+ "join to this entity. Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explicit ON clause if not present already "
+ "to help resolve the ambiguity."
+ )
+ else:
+ raise sa_exc.InvalidRequestError(
+ "Don't know how to join to %r. "
+ "Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explicit ON clause if not present already "
+ "to help resolve the ambiguity." % (right,)
+ )
+
+ elif entities_collection:
+ # we have no explicit FROMs, so the implicit left has to
+ # come from our list of entities.
+
+ potential = {}
+ for entity_index, ent in enumerate(entities_collection):
+ entity = ent.entity_zero_or_selectable
+ if entity is None:
+ continue
+ ent_info = inspect(entity)
+ if ent_info is r_info: # left and right are the same, skip
+ continue
+
+ # by using a dictionary with the selectables as keys this
+ # de-duplicates those selectables as occurs when the query is
+ # against a series of columns from the same selectable
+ if isinstance(ent, _MapperEntity):
+ potential[ent.selectable] = (entity_index, entity)
+ else:
+ potential[ent_info.selectable] = (None, entity)
+
+ all_clauses = list(potential.keys())
+ indexes = sql_util.find_left_clause_to_join_from(
+ all_clauses, r_info.selectable, onclause
+ )
+
+ if len(indexes) == 1:
+ use_entity_index, left = potential[all_clauses[indexes[0]]]
+ elif len(indexes) > 1:
+ raise sa_exc.InvalidRequestError(
+ "Can't determine which FROM clause to join "
+ "from, there are multiple FROMS which can "
+ "join to this entity. Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explicit ON clause if not present already "
+ "to help resolve the ambiguity."
+ )
+ else:
+ raise sa_exc.InvalidRequestError(
+ "Don't know how to join to %r. "
+ "Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explicit ON clause if not present already "
+ "to help resolve the ambiguity." % (right,)
+ )
+ else:
+ raise sa_exc.InvalidRequestError(
+ "No entities to join from; please use "
+ "select_from() to establish the left "
+ "entity/selectable of this join"
+ )
+
+ return left, replace_from_obj_index, use_entity_index
+
+ def _join_place_explicit_left_side(self, entities_collection, left):
+ """When join conditions express a left side explicitly, determine
+ where in our existing list of FROM clauses we should join towards,
+ or if we need to make a new join, and if so is it from one of our
+ existing entities.
+
+ """
+
+ # when we are here, it means join() was called with an indicator
+ # as to an exact left side, which means a path to a
+ # RelationshipProperty was given, e.g.:
+ #
+ # join(RightEntity, LeftEntity.right)
+ #
+ # or
+ #
+ # join(LeftEntity.right)
+ #
+ # as well as string forms:
+ #
+ # join(RightEntity, "right")
+ #
+ # etc.
+ #
+
+ replace_from_obj_index = use_entity_index = None
+
+ l_info = inspect(left)
+ if self.from_clauses:
+ indexes = sql_util.find_left_clause_that_matches_given(
+ self.from_clauses, l_info.selectable
+ )
+
+ if len(indexes) > 1:
+ raise sa_exc.InvalidRequestError(
+ "Can't identify which entity in which to assign the "
+ "left side of this join. Please use a more specific "
+ "ON clause."
+ )
+
+ # have an index, means the left side is already present in
+ # an existing FROM in the self._from_obj tuple
+ if indexes:
+ replace_from_obj_index = indexes[0]
+
+ # no index, means we need to add a new element to the
+ # self._from_obj tuple
+
+ # no from element present, so we will have to add to the
+ # self._from_obj tuple. Determine if this left side matches up
+ # with existing mapper entities, in which case we want to apply the
+ # aliasing / adaptation rules present on that entity if any
+ if (
+ replace_from_obj_index is None
+ and entities_collection
+ and hasattr(l_info, "mapper")
+ ):
+ for idx, ent in enumerate(entities_collection):
+ # TODO: should we be checking for multiple mapper entities
+ # matching?
+ if isinstance(ent, _MapperEntity) and ent.corresponds_to(left):
+ use_entity_index = idx
+ break
+
+ return replace_from_obj_index, use_entity_index
+
+ def _join_check_and_adapt_right_side(
+ self, left, right, onclause, prop, create_aliases, aliased_generation
+ ):
+ """transform the "right" side of the join as well as the onclause
+ according to polymorphic mapping translations, aliasing on the query
+ or on the join, special cases where the right and left side have
+ overlapping tables.
+
+ """
+
+ l_info = inspect(left)
+ r_info = inspect(right)
+
+ overlap = False
+ if not create_aliases:
+ right_mapper = getattr(r_info, "mapper", None)
+ # if the target is a joined inheritance mapping,
+ # be more liberal about auto-aliasing.
+ if right_mapper and (
+ right_mapper.with_polymorphic
+ or isinstance(right_mapper.persist_selectable, expression.Join)
+ ):
+ for from_obj in self.from_clauses or [l_info.selectable]:
+ if sql_util.selectables_overlap(
+ l_info.selectable, from_obj
+ ) and sql_util.selectables_overlap(
+ from_obj, r_info.selectable
+ ):
+ overlap = True
+ break
+
+ if (
+ overlap or not create_aliases
+ ) and l_info.selectable is r_info.selectable:
+ raise sa_exc.InvalidRequestError(
+ "Can't join table/selectable '%s' to itself"
+ % l_info.selectable
+ )
+
+ right_mapper, right_selectable, right_is_aliased = (
+ getattr(r_info, "mapper", None),
+ r_info.selectable,
+ getattr(r_info, "is_aliased_class", False),
+ )
+
+ if (
+ right_mapper
+ and prop
+ and not right_mapper.common_parent(prop.mapper)
+ ):
+ raise sa_exc.InvalidRequestError(
+ "Join target %s does not correspond to "
+ "the right side of join condition %s" % (right, onclause)
+ )
+
+ # _join_entities is used as a hint for single-table inheritance
+ # purposes at the moment
+ if hasattr(r_info, "mapper"):
+ self._join_entities += (r_info,)
+
+ need_adapter = False
+
+ # test for joining to an unmapped selectable as the target
+ if r_info.is_clause_element:
+
+ if prop:
+ right_mapper = prop.mapper
+
+ if right_selectable._is_lateral:
+ # orm_only is disabled to suit the case where we have to
+ # adapt an explicit correlate(Entity) - the select() loses
+ # the ORM-ness in this case right now, ideally it would not
+ current_adapter = self._get_current_adapter()
+ if current_adapter is not None:
+ # TODO: we had orm_only=False here before, removing
+ # it didn't break things. if we identify the rationale,
+ # may need to apply "_orm_only" annotation here.
+ right = current_adapter(right, True)
+
+ elif prop:
+ # joining to selectable with a mapper property given
+ # as the ON clause
+
+ if not right_selectable.is_derived_from(
+ right_mapper.persist_selectable
+ ):
+ raise sa_exc.InvalidRequestError(
+ "Selectable '%s' is not derived from '%s'"
+ % (
+ right_selectable.description,
+ right_mapper.persist_selectable.description,
+ )
+ )
+
+ # if the destination selectable is a plain select(),
+ # turn it into an alias().
+ if isinstance(right_selectable, expression.SelectBase):
+ right_selectable = coercions.expect(
+ roles.FromClauseRole, right_selectable
+ )
+ need_adapter = True
+
+ # make the right hand side target into an ORM entity
+ right = aliased(right_mapper, right_selectable)
+
+ util.warn_deprecated(
+ "An alias is being generated automatically against "
+ "joined entity %s for raw clauseelement, which is "
+ "deprecated and will be removed in a later release. "
+ "Use the aliased() "
+ "construct explicitly, see the linked example."
+ % right_mapper,
+ "1.4",
+ code="xaj1",
+ )
+
+ elif create_aliases:
+ # it *could* work, but it doesn't right now and I'd rather
+ # get rid of aliased=True completely
+ raise sa_exc.InvalidRequestError(
+ "The aliased=True parameter on query.join() only works "
+ "with an ORM entity, not a plain selectable, as the "
+ "target."
+ )
+
+ # test for overlap:
+ # orm/inheritance/relationships.py
+ # SelfReferentialM2MTest
+ aliased_entity = right_mapper and not right_is_aliased and overlap
+
+ if not need_adapter and (create_aliases or aliased_entity):
+ # there are a few places in the ORM that automatic aliasing
+ # is still desirable, and can't be automatic with a Core
+ # only approach. For illustrations of "overlaps" see
+ # test/orm/inheritance/test_relationships.py. There are also
+ # general overlap cases with many-to-many tables where automatic
+ # aliasing is desirable.
+ right = aliased(right, flat=True)
+ need_adapter = True
+
+ if not create_aliases:
+ util.warn(
+ "An alias is being generated automatically against "
+ "joined entity %s due to overlapping tables. This is a "
+ "legacy pattern which may be "
+ "deprecated in a later release. Use the "
+ "aliased(<entity>, flat=True) "
+ "construct explicitly, see the linked example."
+ % right_mapper,
+ code="xaj2",
+ )
+
+ if need_adapter:
+ assert right_mapper
+
+ adapter = ORMAdapter(
+ right, equivalents=right_mapper._equivalent_columns
+ )
+
+ # if an alias() on the right side was generated,
+            # which is intended to wrap the right side in a subquery,
+ # ensure that columns retrieved from this target in the result
+ # set are also adapted.
+ if not create_aliases:
+ self._mapper_loads_polymorphically_with(right_mapper, adapter)
+ elif aliased_generation:
+ adapter._debug = True
+ self._aliased_generations[aliased_generation] = (
+ adapter,
+ ) + self._aliased_generations.get(aliased_generation, ())
+ elif (
+ not r_info.is_clause_element
+ and not right_is_aliased
+ and right_mapper.with_polymorphic
+ and isinstance(
+ right_mapper._with_polymorphic_selectable,
+ expression.AliasedReturnsRows,
+ )
+ ):
+ # for the case where the target mapper has a with_polymorphic
+ # set up, ensure an adapter is set up for criteria that works
+ # against this mapper. Previously, this logic used to
+ # use the "create_aliases or aliased_entity" case to generate
+ # an aliased() object, but this creates an alias that isn't
+ # strictly necessary.
+ # see test/orm/test_core_compilation.py
+ # ::RelNaturalAliasedJoinsTest::test_straight
+ # and similar
+ self._mapper_loads_polymorphically_with(
+ right_mapper,
+ sql_util.ColumnAdapter(
+ right_mapper.selectable,
+ right_mapper._equivalent_columns,
+ ),
+ )
+ # if the onclause is a ClauseElement, adapt it with any
+ # adapters that are in place right now
+ if isinstance(onclause, expression.ClauseElement):
+ current_adapter = self._get_current_adapter()
+ if current_adapter:
+ onclause = current_adapter(onclause, True)
+
+ # if joining on a MapperProperty path,
+ # track the path to prevent redundant joins
+ if not create_aliases and prop:
+ self._update_joinpoint(
+ {
+ "_joinpoint_entity": right,
+ "prev": ((left, right, prop.key), self._joinpoint),
+ "aliased_generation": aliased_generation,
+ }
+ )
+ else:
+ self._joinpoint = {
+ "_joinpoint_entity": right,
+ "aliased_generation": aliased_generation,
+ }
+
+ return inspect(right), right, onclause
+
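+    # Editorial note (not part of the original SQLAlchemy source): the "xaj1"
+    # and "xaj2" warnings above steer users toward constructing the alias
+    # explicitly instead of relying on automatic aliasing.  A minimal sketch
+    # of that pattern, assuming hypothetical mapped classes ``User`` and
+    # ``Address`` where ``Address.user_id`` references ``User.id``:
+    #
+    #     from sqlalchemy import select
+    #     from sqlalchemy.orm import aliased
+    #
+    #     addr_subq = (
+    #         select(Address)
+    #         .where(Address.email_address.like("%@example.com"))
+    #         .subquery()
+    #     )
+    #     addr_alias = aliased(Address, addr_subq)
+    #
+    #     stmt = select(User).join(addr_alias, User.id == addr_alias.user_id)
+    #
+    # Joining to ``addr_alias`` hands this method an ORM entity as the join
+    # target, so no automatic alias has to be generated on its behalf.
+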
+ def _update_joinpoint(self, jp):
+ self._joinpoint = jp
+ # copy backwards to the root of the _joinpath
+ # dict, so that no existing dict in the path is mutated
+ while "prev" in jp:
+ f, prev = jp["prev"]
+ prev = dict(prev)
+ prev[f] = jp.copy()
+ jp["prev"] = (f, prev)
+ jp = prev
+ self._joinpath = jp
+
+ def _reset_joinpoint(self):
+ self._joinpoint = self._joinpath
+
+ @property
+ def _select_args(self):
+ return {
+ "limit_clause": self.select_statement._limit_clause,
+ "offset_clause": self.select_statement._offset_clause,
+ "distinct": self.distinct,
+ "distinct_on": self.distinct_on,
+ "prefixes": self.select_statement._prefixes,
+ "suffixes": self.select_statement._suffixes,
+ "group_by": self.group_by or None,
+ "fetch_clause": self.select_statement._fetch_clause,
+ "fetch_clause_options": (
+ self.select_statement._fetch_clause_options
+ ),
+ }
+
+ @property
+ def _should_nest_selectable(self):
+ kwargs = self._select_args
+ return (
+ kwargs.get("limit_clause") is not None
+ or kwargs.get("offset_clause") is not None
+ or kwargs.get("distinct", False)
+ or kwargs.get("distinct_on", ())
+ or kwargs.get("group_by", False)
+ )
+
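+    # Editorial note (not part of the original source): when any of the
+    # criteria above are present together with joined eager loading, the ORM
+    # wraps the row-limiting statement in a subquery so the eager-load JOINs
+    # do not change which primary rows come back.  For hypothetical ``User``
+    # and ``Address`` mappings, something like
+    #
+    #     session.query(User).options(joinedload(User.addresses)).limit(5)
+    #
+    # renders roughly as
+    #
+    #     SELECT ... FROM (SELECT ... FROM user_account LIMIT 5) AS anon_1
+    #     LEFT OUTER JOIN address ON anon_1.id = address.user_id
+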
+ def _get_extra_criteria(self, ext_info):
+ if (
+ "additional_entity_criteria",
+ ext_info.mapper,
+ ) in self.global_attributes:
+ return tuple(
+ ae._resolve_where_criteria(ext_info)
+ for ae in self.global_attributes[
+ ("additional_entity_criteria", ext_info.mapper)
+ ]
+ if (ae.include_aliases or ae.entity is ext_info)
+ and ae._should_include(self)
+ )
+ else:
+ return ()
+
+ def _adjust_for_extra_criteria(self):
+ """Apply extra criteria filtering.
+
+ For all distinct single-table-inheritance mappers represented in
+ the columns clause of this query, as well as the "select from entity",
+        add criteria to the WHERE clause of the given QueryContext such
+        that only the appropriate subtypes are selected from the total
+        results.
+
+ Additionally, add WHERE criteria originating from LoaderCriteriaOptions
+ associated with the global context.
+
+ """
+
+ for fromclause in self.from_clauses:
+ ext_info = fromclause._annotations.get("parententity", None)
+ if (
+ ext_info
+ and (
+ ext_info.mapper._single_table_criterion is not None
+ or ("additional_entity_criteria", ext_info.mapper)
+ in self.global_attributes
+ )
+ and ext_info not in self.extra_criteria_entities
+ ):
+
+ self.extra_criteria_entities[ext_info] = (
+ ext_info,
+ ext_info._adapter if ext_info.is_aliased_class else None,
+ )
+
+ search = set(self.extra_criteria_entities.values())
+
+ for (ext_info, adapter) in search:
+ if ext_info in self._join_entities:
+ continue
+
+ single_crit = ext_info.mapper._single_table_criterion
+
+ if self.compile_options._for_refresh_state:
+ additional_entity_criteria = []
+ else:
+ additional_entity_criteria = self._get_extra_criteria(ext_info)
+
+ if single_crit is not None:
+ additional_entity_criteria += (single_crit,)
+
+ current_adapter = self._get_current_adapter()
+ for crit in additional_entity_criteria:
+ if adapter:
+ crit = adapter.traverse(crit)
+
+ if current_adapter:
+ crit = sql_util._deep_annotate(crit, {"_orm_adapt": True})
+ crit = current_adapter(crit, False)
+ self._where_criteria += (crit,)
+
+
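+# Editorial note (not part of the original source): the "additional entity
+# criteria" consumed above by _get_extra_criteria() and
+# _adjust_for_extra_criteria() are what the public ``with_loader_criteria()``
+# option produces in SQLAlchemy 1.4.  A minimal sketch, assuming a
+# hypothetical ``User`` mapping with a boolean ``deleted`` column:
+#
+#     from sqlalchemy import event
+#     from sqlalchemy.orm import Session, with_loader_criteria
+#
+#     @event.listens_for(Session, "do_orm_execute")
+#     def _add_filtering_criteria(execute_state):
+#         if execute_state.is_select:
+#             execute_state.statement = execute_state.statement.options(
+#                 with_loader_criteria(User, User.deleted == False)
+#             )
+#
+# The option is stored under the ("additional_entity_criteria", mapper) key
+# of the statement's global attributes, which the methods above merge into
+# the WHERE clause alongside any single-table-inheritance criterion.
+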
+def _column_descriptions(
+ query_or_select_stmt, compile_state=None, legacy=False
+):
+ if compile_state is None:
+ compile_state = ORMSelectCompileState._create_entities_collection(
+ query_or_select_stmt, legacy=legacy
+ )
+ ctx = compile_state
+ return [
+ {
+ "name": ent._label_name,
+ "type": ent.type,
+ "aliased": getattr(insp_ent, "is_aliased_class", False),
+ "expr": ent.expr,
+ "entity": getattr(insp_ent, "entity", None)
+ if ent.entity_zero is not None and not insp_ent.is_clause_element
+ else None,
+ }
+ for ent, insp_ent in [
+ (
+ _ent,
+ (
+ inspect(_ent.entity_zero)
+ if _ent.entity_zero is not None
+ else None
+ ),
+ )
+ for _ent in ctx._entities
+ ]
+ ]
+
+
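+# Editorial note (not part of the original source): _column_descriptions()
+# backs the public ``Query.column_descriptions`` accessor.  For a
+# hypothetical ``User`` mapping, the shape of the result is roughly:
+#
+#     session.query(User, User.id).column_descriptions
+#     # [
+#     #   {"name": "User", "type": User, "aliased": False,
+#     #    "expr": User, "entity": User},
+#     #   {"name": "id", "type": Integer(), "aliased": False,
+#     #    "expr": User.id, "entity": User},
+#     # ]
+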
+def _legacy_filter_by_entity_zero(query_or_augmented_select):
+ self = query_or_augmented_select
+ if self._legacy_setup_joins:
+ _last_joined_entity = self._last_joined_entity
+ if _last_joined_entity is not None:
+ return _last_joined_entity
+
+ if self._from_obj and "parententity" in self._from_obj[0]._annotations:
+ return self._from_obj[0]._annotations["parententity"]
+
+ return _entity_from_pre_ent_zero(self)
+
+
+def _entity_from_pre_ent_zero(query_or_augmented_select):
+ self = query_or_augmented_select
+ if not self._raw_columns:
+ return None
+
+ ent = self._raw_columns[0]
+
+ if "parententity" in ent._annotations:
+ return ent._annotations["parententity"]
+ elif isinstance(ent, ORMColumnsClauseRole):
+ return ent.entity
+ elif "bundle" in ent._annotations:
+ return ent._annotations["bundle"]
+ else:
+ return ent
+
+
+def _legacy_determine_last_joined_entity(setup_joins, entity_zero):
+ """given the legacy_setup_joins collection at a point in time,
+ figure out what the "filter by entity" would be in terms
+ of those joins.
+
+ in 2.0 this logic should hopefully be much simpler as there will
+ be far fewer ways to specify joins with the ORM
+
+ """
+
+ if not setup_joins:
+ return entity_zero
+
+ # CAN BE REMOVED IN 2.0:
+ # 1. from_joinpoint
+ # 2. aliased_generation
+ # 3. aliased
+ # 4. any treating of prop as str
+ # 5. tuple madness
+ # 6. won't need recursive call anymore without #4
+ # 7. therefore can pass in just the last setup_joins record,
+ # don't need entity_zero
+
+ (right, onclause, left_, flags) = setup_joins[-1]
+
+ from_joinpoint = flags["from_joinpoint"]
+
+ if onclause is None and isinstance(
+ right, (str, interfaces.PropComparator)
+ ):
+ onclause = right
+ right = None
+
+ if right is not None and "parententity" in right._annotations:
+ right = right._annotations["parententity"].entity
+
+ if right is not None:
+ last_entity = right
+ insp = inspect(last_entity)
+ if insp.is_clause_element or insp.is_aliased_class or insp.is_mapper:
+ return insp
+
+ last_entity = onclause
+ if isinstance(last_entity, interfaces.PropComparator):
+ return last_entity.entity
+
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv
+ if isinstance(onclause, str):
+ if from_joinpoint:
+ prev = _legacy_determine_last_joined_entity(
+ setup_joins[0:-1], entity_zero
+ )
+ else:
+ prev = entity_zero
+
+ if prev is None:
+ return None
+
+ prev = inspect(prev)
+ attr = getattr(prev.entity, onclause, None)
+ if attr is not None:
+ return attr.property.entity
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ return None
+
+
+class _QueryEntity(object):
+ """represent an entity column returned within a Query result."""
+
+ __slots__ = ()
+
+ _non_hashable_value = False
+ _null_column_type = False
+ use_id_for_hash = False
+
+ @classmethod
+ def to_compile_state(
+ cls, compile_state, entities, entities_collection, is_current_entities
+ ):
+
+ for idx, entity in enumerate(entities):
+ if entity._is_lambda_element:
+ if entity._is_sequence:
+ cls.to_compile_state(
+ compile_state,
+ entity._resolved,
+ entities_collection,
+ is_current_entities,
+ )
+ continue
+ else:
+ entity = entity._resolved
+
+ if entity.is_clause_element:
+ if entity.is_selectable:
+ if "parententity" in entity._annotations:
+ _MapperEntity(
+ compile_state,
+ entity,
+ entities_collection,
+ is_current_entities,
+ )
+ else:
+ _ColumnEntity._for_columns(
+ compile_state,
+ entity._select_iterable,
+ entities_collection,
+ idx,
+ is_current_entities,
+ )
+ else:
+ if entity._annotations.get("bundle", False):
+ _BundleEntity(
+ compile_state,
+ entity,
+ entities_collection,
+ is_current_entities,
+ )
+ elif entity._is_clause_list:
+ # this is legacy only - test_composites.py
+ # test_query_cols_legacy
+ _ColumnEntity._for_columns(
+ compile_state,
+ entity._select_iterable,
+ entities_collection,
+ idx,
+ is_current_entities,
+ )
+ else:
+ _ColumnEntity._for_columns(
+ compile_state,
+ [entity],
+ entities_collection,
+ idx,
+ is_current_entities,
+ )
+ elif entity.is_bundle:
+ _BundleEntity(compile_state, entity, entities_collection)
+
+ return entities_collection
+
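+# Editorial note (not part of the original source): a rough sketch of how the
+# dispatch in to_compile_state() classifies the elements of a statement,
+# using a hypothetical ``User`` mapping:
+#
+#     from sqlalchemy import func, select
+#     from sqlalchemy.orm import Bundle
+#
+#     stmt = select(
+#         User,                              # -> _MapperEntity
+#         User.name,                         # -> _ORMColumnEntity
+#         func.now(),                        # -> _RawColumnEntity
+#         Bundle("b", User.id, User.name),   # -> _BundleEntity
+#     )
+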
+
+class _MapperEntity(_QueryEntity):
+ """mapper/class/AliasedClass entity"""
+
+ __slots__ = (
+ "expr",
+ "mapper",
+ "entity_zero",
+ "is_aliased_class",
+ "path",
+ "_extra_entities",
+ "_label_name",
+ "_with_polymorphic_mappers",
+ "selectable",
+ "_polymorphic_discriminator",
+ )
+
+ def __init__(
+ self, compile_state, entity, entities_collection, is_current_entities
+ ):
+ entities_collection.append(self)
+ if is_current_entities:
+ if compile_state._primary_entity is None:
+ compile_state._primary_entity = self
+ compile_state._has_mapper_entities = True
+ compile_state._has_orm_entities = True
+
+ entity = entity._annotations["parententity"]
+ entity._post_inspect
+ ext_info = self.entity_zero = entity
+ entity = ext_info.entity
+
+ self.expr = entity
+ self.mapper = mapper = ext_info.mapper
+
+ self._extra_entities = (self.expr,)
+
+ if ext_info.is_aliased_class:
+ self._label_name = ext_info.name
+ else:
+ self._label_name = mapper.class_.__name__
+
+ self.is_aliased_class = ext_info.is_aliased_class
+ self.path = ext_info._path_registry
+
+ if ext_info in compile_state._with_polymorphic_adapt_map:
+ # this codepath occurs only if query.with_polymorphic() were
+ # used
+
+ wp = inspect(compile_state._with_polymorphic_adapt_map[ext_info])
+
+ if self.is_aliased_class:
+ # TODO: invalidrequest ?
+ raise NotImplementedError(
+ "Can't use with_polymorphic() against an Aliased object"
+ )
+
+ mappers, from_obj = mapper._with_polymorphic_args(
+ wp.with_polymorphic_mappers, wp.selectable
+ )
+
+ self._with_polymorphic_mappers = mappers
+ self.selectable = from_obj
+ self._polymorphic_discriminator = wp.polymorphic_on
+
+ else:
+ self.selectable = ext_info.selectable
+ self._with_polymorphic_mappers = ext_info.with_polymorphic_mappers
+ self._polymorphic_discriminator = ext_info.polymorphic_on
+
+ if (
+ mapper.with_polymorphic
+ # controversy - only if inheriting mapper is also
+ # polymorphic?
+ # or (mapper.inherits and mapper.inherits.with_polymorphic)
+ or mapper.inherits
+ or mapper._requires_row_aliasing
+ ):
+ compile_state._create_with_polymorphic_adapter(
+ ext_info, self.selectable
+ )
+
+ supports_single_entity = True
+
+ _non_hashable_value = True
+ use_id_for_hash = True
+
+ @property
+ def type(self):
+ return self.mapper.class_
+
+ @property
+ def entity_zero_or_selectable(self):
+ return self.entity_zero
+
+ def corresponds_to(self, entity):
+ return _entity_corresponds_to(self.entity_zero, entity)
+
+ def _get_entity_clauses(self, compile_state):
+
+ adapter = None
+
+ if not self.is_aliased_class:
+ if compile_state._polymorphic_adapters:
+ adapter = compile_state._polymorphic_adapters.get(
+ self.mapper, None
+ )
+ else:
+ adapter = self.entity_zero._adapter
+
+ if adapter:
+ if compile_state._from_obj_alias:
+ ret = adapter.wrap(compile_state._from_obj_alias)
+ else:
+ ret = adapter
+ else:
+ ret = compile_state._from_obj_alias
+
+ return ret
+
+ def row_processor(self, context, result):
+ compile_state = context.compile_state
+ adapter = self._get_entity_clauses(compile_state)
+
+ if compile_state.compound_eager_adapter and adapter:
+ adapter = adapter.wrap(compile_state.compound_eager_adapter)
+ elif not adapter:
+ adapter = compile_state.compound_eager_adapter
+
+ if compile_state._primary_entity is self:
+ only_load_props = compile_state.compile_options._only_load_props
+ refresh_state = context.refresh_state
+ else:
+ only_load_props = refresh_state = None
+
+ _instance = loading._instance_processor(
+ self,
+ self.mapper,
+ context,
+ result,
+ self.path,
+ adapter,
+ only_load_props=only_load_props,
+ refresh_state=refresh_state,
+ polymorphic_discriminator=self._polymorphic_discriminator,
+ )
+
+ return _instance, self._label_name, self._extra_entities
+
+ def setup_compile_state(self, compile_state):
+
+ adapter = self._get_entity_clauses(compile_state)
+
+ single_table_crit = self.mapper._single_table_criterion
+ if (
+ single_table_crit is not None
+ or ("additional_entity_criteria", self.mapper)
+ in compile_state.global_attributes
+ ):
+ ext_info = self.entity_zero
+ compile_state.extra_criteria_entities[ext_info] = (
+ ext_info,
+ ext_info._adapter if ext_info.is_aliased_class else None,
+ )
+
+ loading._setup_entity_query(
+ compile_state,
+ self.mapper,
+ self,
+ self.path,
+ adapter,
+ compile_state.primary_columns,
+ with_polymorphic=self._with_polymorphic_mappers,
+ only_load_props=compile_state.compile_options._only_load_props,
+ polymorphic_discriminator=self._polymorphic_discriminator,
+ )
+
+ compile_state._fallback_from_clauses.append(self.selectable)
+
+
+class _BundleEntity(_QueryEntity):
+
+ _extra_entities = ()
+
+ __slots__ = (
+ "bundle",
+ "expr",
+ "type",
+ "_label_name",
+ "_entities",
+ "supports_single_entity",
+ )
+
+ def __init__(
+ self,
+ compile_state,
+ expr,
+ entities_collection,
+ is_current_entities,
+ setup_entities=True,
+ parent_bundle=None,
+ ):
+ compile_state._has_orm_entities = True
+
+ expr = expr._annotations["bundle"]
+ if parent_bundle:
+ parent_bundle._entities.append(self)
+ else:
+ entities_collection.append(self)
+
+ if isinstance(
+ expr, (attributes.QueryableAttribute, interfaces.PropComparator)
+ ):
+ bundle = expr.__clause_element__()
+ else:
+ bundle = expr
+
+ self.bundle = self.expr = bundle
+ self.type = type(bundle)
+ self._label_name = bundle.name
+ self._entities = []
+
+ if setup_entities:
+ for expr in bundle.exprs:
+ if "bundle" in expr._annotations:
+ _BundleEntity(
+ compile_state,
+ expr,
+ entities_collection,
+ is_current_entities,
+ parent_bundle=self,
+ )
+ elif isinstance(expr, Bundle):
+ _BundleEntity(
+ compile_state,
+ expr,
+ entities_collection,
+ is_current_entities,
+ parent_bundle=self,
+ )
+ else:
+ _ORMColumnEntity._for_columns(
+ compile_state,
+ [expr],
+ entities_collection,
+ None,
+ is_current_entities,
+ parent_bundle=self,
+ )
+
+ self.supports_single_entity = self.bundle.single_entity
+ if (
+ self.supports_single_entity
+ and not compile_state.compile_options._use_legacy_query_style
+ ):
+ util.warn_deprecated_20(
+ "The Bundle.single_entity flag has no effect when "
+ "using 2.0 style execution."
+ )
+
+ @property
+ def mapper(self):
+ ezero = self.entity_zero
+ if ezero is not None:
+ return ezero.mapper
+ else:
+ return None
+
+ @property
+ def entity_zero(self):
+ for ent in self._entities:
+ ezero = ent.entity_zero
+ if ezero is not None:
+ return ezero
+ else:
+ return None
+
+ def corresponds_to(self, entity):
+ # TODO: we might be able to implement this but for now
+ # we are working around it
+ return False
+
+ @property
+ def entity_zero_or_selectable(self):
+ for ent in self._entities:
+ ezero = ent.entity_zero_or_selectable
+ if ezero is not None:
+ return ezero
+ else:
+ return None
+
+ def setup_compile_state(self, compile_state):
+ for ent in self._entities:
+ ent.setup_compile_state(compile_state)
+
+ def row_processor(self, context, result):
+ procs, labels, extra = zip(
+ *[ent.row_processor(context, result) for ent in self._entities]
+ )
+
+ proc = self.bundle.create_row_processor(context.query, procs, labels)
+
+ return proc, self._label_name, self._extra_entities
+
+
+class _ColumnEntity(_QueryEntity):
+ __slots__ = (
+ "_fetch_column",
+ "_row_processor",
+ "raw_column_index",
+ "translate_raw_column",
+ )
+
+ @classmethod
+ def _for_columns(
+ cls,
+ compile_state,
+ columns,
+ entities_collection,
+ raw_column_index,
+ is_current_entities,
+ parent_bundle=None,
+ ):
+ for column in columns:
+ annotations = column._annotations
+ if "parententity" in annotations:
+ _entity = annotations["parententity"]
+ else:
+ _entity = sql_util.extract_first_column_annotation(
+ column, "parententity"
+ )
+
+ if _entity:
+ if "identity_token" in column._annotations:
+ _IdentityTokenEntity(
+ compile_state,
+ column,
+ entities_collection,
+ _entity,
+ raw_column_index,
+ is_current_entities,
+ parent_bundle=parent_bundle,
+ )
+ else:
+ _ORMColumnEntity(
+ compile_state,
+ column,
+ entities_collection,
+ _entity,
+ raw_column_index,
+ is_current_entities,
+ parent_bundle=parent_bundle,
+ )
+ else:
+ _RawColumnEntity(
+ compile_state,
+ column,
+ entities_collection,
+ raw_column_index,
+ is_current_entities,
+ parent_bundle=parent_bundle,
+ )
+
+ @property
+ def type(self):
+ return self.column.type
+
+ @property
+ def _non_hashable_value(self):
+ return not self.column.type.hashable
+
+ @property
+ def _null_column_type(self):
+ return self.column.type._isnull
+
+ def row_processor(self, context, result):
+ compile_state = context.compile_state
+
+ # the resulting callable is entirely cacheable so just return
+ # it if we already made one
+ if self._row_processor is not None:
+ getter, label_name, extra_entities = self._row_processor
+ if self.translate_raw_column:
+ extra_entities += (
+ result.context.invoked_statement._raw_columns[
+ self.raw_column_index
+ ],
+ )
+
+ return getter, label_name, extra_entities
+
+ # retrieve the column that would have been set up in
+ # setup_compile_state, to avoid doing redundant work
+ if self._fetch_column is not None:
+ column = self._fetch_column
+ else:
+ # fetch_column will be None when we are doing a from_statement
+ # and setup_compile_state may not have been called.
+ column = self.column
+
+ # previously, the RawColumnEntity didn't look for from_obj_alias
+ # however I can't think of a case where we would be here and
+ # we'd want to ignore it if this is the from_statement use case.
+ # it's not really a use case to have raw columns + from_statement
+ if compile_state._from_obj_alias:
+ column = compile_state._from_obj_alias.columns[column]
+
+ if column._annotations:
+ # annotated columns perform more slowly in compiler and
+ # result due to the __eq__() method, so use deannotated
+ column = column._deannotate()
+
+ if compile_state.compound_eager_adapter:
+ column = compile_state.compound_eager_adapter.columns[column]
+
+ getter = result._getter(column)
+
+ ret = getter, self._label_name, self._extra_entities
+ self._row_processor = ret
+
+ if self.translate_raw_column:
+ extra_entities = self._extra_entities + (
+ result.context.invoked_statement._raw_columns[
+ self.raw_column_index
+ ],
+ )
+ return getter, self._label_name, extra_entities
+ else:
+ return ret
+
+
+class _RawColumnEntity(_ColumnEntity):
+ entity_zero = None
+ mapper = None
+ supports_single_entity = False
+
+ __slots__ = (
+ "expr",
+ "column",
+ "_label_name",
+ "entity_zero_or_selectable",
+ "_extra_entities",
+ )
+
+ def __init__(
+ self,
+ compile_state,
+ column,
+ entities_collection,
+ raw_column_index,
+ is_current_entities,
+ parent_bundle=None,
+ ):
+ self.expr = column
+ self.raw_column_index = raw_column_index
+ self.translate_raw_column = raw_column_index is not None
+
+ if column._is_star:
+ compile_state.compile_options += {"_is_star": True}
+
+ if not is_current_entities or column._is_text_clause:
+ self._label_name = None
+ else:
+ self._label_name = compile_state._label_convention(column)
+
+ if parent_bundle:
+ parent_bundle._entities.append(self)
+ else:
+ entities_collection.append(self)
+
+ self.column = column
+ self.entity_zero_or_selectable = (
+ self.column._from_objects[0] if self.column._from_objects else None
+ )
+ self._extra_entities = (self.expr, self.column)
+ self._fetch_column = self._row_processor = None
+
+ def corresponds_to(self, entity):
+ return False
+
+ def setup_compile_state(self, compile_state):
+ current_adapter = compile_state._get_current_adapter()
+ if current_adapter:
+ column = current_adapter(self.column, False)
+ else:
+ column = self.column
+
+ if column._annotations:
+ # annotated columns perform more slowly in compiler and
+ # result due to the __eq__() method, so use deannotated
+ column = column._deannotate()
+
+ compile_state.dedupe_columns.add(column)
+ compile_state.primary_columns.append(column)
+ self._fetch_column = column
+
+
+class _ORMColumnEntity(_ColumnEntity):
+ """Column/expression based entity."""
+
+ supports_single_entity = False
+
+ __slots__ = (
+ "expr",
+ "mapper",
+ "column",
+ "_label_name",
+ "entity_zero_or_selectable",
+ "entity_zero",
+ "_extra_entities",
+ )
+
+ def __init__(
+ self,
+ compile_state,
+ column,
+ entities_collection,
+ parententity,
+ raw_column_index,
+ is_current_entities,
+ parent_bundle=None,
+ ):
+ annotations = column._annotations
+
+ _entity = parententity
+
+ # an AliasedClass won't have proxy_key in the annotations for
+ # a column if it was acquired using the class' adapter directly,
+ # such as using AliasedInsp._adapt_element(). this occurs
+ # within internal loaders.
+
+ orm_key = annotations.get("proxy_key", None)
+ proxy_owner = annotations.get("proxy_owner", _entity)
+ if orm_key:
+ self.expr = getattr(proxy_owner.entity, orm_key)
+ self.translate_raw_column = False
+ else:
+ # if orm_key is not present, that means this is an ad-hoc
+ # SQL ColumnElement, like a CASE() or other expression.
+ # include this column position from the invoked statement
+ # in the ORM-level ResultSetMetaData on each execute, so that
+ # it can be targeted by identity after caching
+ self.expr = column
+ self.translate_raw_column = raw_column_index is not None
+
+ self.raw_column_index = raw_column_index
+
+ if is_current_entities:
+ self._label_name = compile_state._label_convention(
+ column, col_name=orm_key
+ )
+ else:
+ self._label_name = None
+
+ _entity._post_inspect
+ self.entity_zero = self.entity_zero_or_selectable = ezero = _entity
+ self.mapper = mapper = _entity.mapper
+
+ if parent_bundle:
+ parent_bundle._entities.append(self)
+ else:
+ entities_collection.append(self)
+
+ compile_state._has_orm_entities = True
+
+ self.column = column
+
+ self._fetch_column = self._row_processor = None
+
+ self._extra_entities = (self.expr, self.column)
+
+ if (
+ mapper.with_polymorphic
+ or mapper.inherits
+ or mapper._requires_row_aliasing
+ ):
+ compile_state._create_with_polymorphic_adapter(
+ ezero, ezero.selectable
+ )
+
+ def corresponds_to(self, entity):
+ if _is_aliased_class(entity):
+ # TODO: polymorphic subclasses ?
+ return entity is self.entity_zero
+ else:
+ return not _is_aliased_class(
+ self.entity_zero
+ ) and entity.common_parent(self.entity_zero)
+
+ def setup_compile_state(self, compile_state):
+ current_adapter = compile_state._get_current_adapter()
+ if current_adapter:
+ column = current_adapter(self.column, False)
+ else:
+ column = self.column
+
+ ezero = self.entity_zero
+
+ single_table_crit = self.mapper._single_table_criterion
+ if (
+ single_table_crit is not None
+ or ("additional_entity_criteria", self.mapper)
+ in compile_state.global_attributes
+ ):
+
+ compile_state.extra_criteria_entities[ezero] = (
+ ezero,
+ ezero._adapter if ezero.is_aliased_class else None,
+ )
+
+ if column._annotations and not column._expression_label:
+ # annotated columns perform more slowly in compiler and
+ # result due to the __eq__() method, so use deannotated
+ column = column._deannotate()
+
+ # use entity_zero as the from if we have it. this is necessary
+ # for polymorphic scenarios where our FROM is based on ORM entity,
+ # not the FROM of the column. but also, don't use it if our column
+        # doesn't actually have any FROMs that line up, such as when it's
+ # a scalar subquery.
+ if set(self.column._from_objects).intersection(
+ ezero.selectable._from_objects
+ ):
+ compile_state._fallback_from_clauses.append(ezero.selectable)
+
+ compile_state.dedupe_columns.add(column)
+ compile_state.primary_columns.append(column)
+ self._fetch_column = column
+
+
+class _IdentityTokenEntity(_ORMColumnEntity):
+ translate_raw_column = False
+
+ def setup_compile_state(self, compile_state):
+ pass
+
+ def row_processor(self, context, result):
+ def getter(row):
+ return context.load_options._refresh_identity_token
+
+ return getter, self._label_name, self._extra_entities
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py
new file mode 100644
index 0000000..16f91c6
--- /dev/null
+++ b/lib/sqlalchemy/orm/decl_api.py
@@ -0,0 +1,1062 @@
+# orm/decl_api.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Public API functions and helpers for declarative."""
+from __future__ import absolute_import
+
+import itertools
+import re
+import weakref
+
+from . import attributes
+from . import clsregistry
+from . import exc as orm_exc
+from . import instrumentation
+from . import interfaces
+from . import mapper as mapperlib
+from .base import _inspect_mapped_class
+from .decl_base import _add_attribute
+from .decl_base import _as_declarative
+from .decl_base import _declarative_constructor
+from .decl_base import _DeferredMapperConfig
+from .decl_base import _del_attribute
+from .decl_base import _mapper
+from .descriptor_props import SynonymProperty as _orm_synonym
+from .. import exc
+from .. import inspection
+from .. import util
+from ..sql.schema import MetaData
+from ..util import hybridmethod
+from ..util import hybridproperty
+
+
+def has_inherited_table(cls):
+ """Given a class, return True if any of the classes it inherits from has a
+ mapped table, otherwise return False.
+
+ This is used in declarative mixins to build attributes that behave
+ differently for the base class vs. a subclass in an inheritance
+ hierarchy.
+
+ .. seealso::
+
+ :ref:`decl_mixin_inheritance`
+
+ """
+ for class_ in cls.__mro__[1:]:
+ if getattr(class_, "__table__", None) is not None:
+ return True
+ return False
+
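+# Editorial note (not part of the original source): a typical use of
+# has_inherited_table() inside a declarative mixin.  ``Tablename`` is a
+# hypothetical mixin; subclasses that already inherit a mapped table get no
+# ``__tablename__`` of their own and therefore map against the base table
+# (single-table inheritance):
+#
+#     from sqlalchemy.orm import declared_attr, has_inherited_table
+#
+#     class Tablename:
+#         @declared_attr
+#         def __tablename__(cls):
+#             if has_inherited_table(cls):
+#                 return None
+#             return cls.__name__.lower()
+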
+
+class DeclarativeMeta(type):
+ def __init__(cls, classname, bases, dict_, **kw):
+ # use cls.__dict__, which can be modified by an
+ # __init_subclass__() method (#7900)
+ dict_ = cls.__dict__
+
+ # early-consume registry from the initial declarative base,
+ # assign privately to not conflict with subclass attributes named
+ # "registry"
+ reg = getattr(cls, "_sa_registry", None)
+ if reg is None:
+ reg = dict_.get("registry", None)
+ if not isinstance(reg, registry):
+ raise exc.InvalidRequestError(
+ "Declarative base class has no 'registry' attribute, "
+ "or registry is not a sqlalchemy.orm.registry() object"
+ )
+ else:
+ cls._sa_registry = reg
+
+ if not cls.__dict__.get("__abstract__", False):
+ _as_declarative(reg, cls, dict_)
+ type.__init__(cls, classname, bases, dict_)
+
+ def __setattr__(cls, key, value):
+ _add_attribute(cls, key, value)
+
+ def __delattr__(cls, key):
+ _del_attribute(cls, key)
+
+
+def synonym_for(name, map_column=False):
+ """Decorator that produces an :func:`_orm.synonym`
+ attribute in conjunction with a Python descriptor.
+
+ The function being decorated is passed to :func:`_orm.synonym` as the
+ :paramref:`.orm.synonym.descriptor` parameter::
+
+ class MyClass(Base):
+ __tablename__ = 'my_table'
+
+ id = Column(Integer, primary_key=True)
+ _job_status = Column("job_status", String(50))
+
+ @synonym_for("job_status")
+ @property
+ def job_status(self):
+ return "Status: %s" % self._job_status
+
+ The :ref:`hybrid properties <mapper_hybrids>` feature of SQLAlchemy
+    is typically preferred over synonyms, which are a more legacy
+    feature.
+
+ .. seealso::
+
+ :ref:`synonyms` - Overview of synonyms
+
+ :func:`_orm.synonym` - the mapper-level function
+
+ :ref:`mapper_hybrids` - The Hybrid Attribute extension provides an
+ updated approach to augmenting attribute behavior more flexibly than
+ can be achieved with synonyms.
+
+ """
+
+ def decorate(fn):
+ return _orm_synonym(name, map_column=map_column, descriptor=fn)
+
+ return decorate
+
+
+class declared_attr(interfaces._MappedAttribute, property):
+ """Mark a class-level method as representing the definition of
+ a mapped property or special declarative member name.
+
+ :class:`_orm.declared_attr` is typically applied as a decorator to a class
+ level method, turning the attribute into a scalar-like property that can be
+ invoked from the uninstantiated class. The Declarative mapping process
+ looks for these :class:`_orm.declared_attr` callables as it scans classes,
+ and assumes any attribute marked with :class:`_orm.declared_attr` will be a
+ callable that will produce an object specific to the Declarative mapping or
+ table configuration.
+
+ :class:`_orm.declared_attr` is usually applicable to mixins, to define
+ relationships that are to be applied to different implementors of the
+ class. It is also used to define :class:`_schema.Column` objects that
+ include the :class:`_schema.ForeignKey` construct, as these cannot be
+ easily reused across different mappings. The example below illustrates
+ both::
+
+ class ProvidesUser(object):
+ "A mixin that adds a 'user' relationship to classes."
+
+ @declared_attr
+ def user_id(self):
+ return Column(ForeignKey("user_account.id"))
+
+ @declared_attr
+ def user(self):
+ return relationship("User")
+
+ :class:`_orm.declared_attr` can also be applied to mapped classes, such as
+ to provide a "polymorphic" scheme for inheritance::
+
+ class Employee(Base):
+ id = Column(Integer, primary_key=True)
+ type = Column(String(50), nullable=False)
+
+ @declared_attr
+ def __tablename__(cls):
+ return cls.__name__.lower()
+
+ @declared_attr
+ def __mapper_args__(cls):
+ if cls.__name__ == 'Employee':
+ return {
+ "polymorphic_on":cls.type,
+ "polymorphic_identity":"Employee"
+ }
+ else:
+ return {"polymorphic_identity":cls.__name__}
+
+ To use :class:`_orm.declared_attr` inside of a Python dataclass
+ as discussed at :ref:`orm_declarative_dataclasses_declarative_table`,
+ it may be placed directly inside the field metadata using a lambda::
+
+ @dataclass
+ class AddressMixin:
+ __sa_dataclass_metadata_key__ = "sa"
+
+ user_id: int = field(
+ init=False, metadata={"sa": declared_attr(lambda: Column(ForeignKey("user.id")))}
+ )
+ user: User = field(
+ init=False, metadata={"sa": declared_attr(lambda: relationship(User))}
+ )
+
+ :class:`_orm.declared_attr` also may be omitted from this form using a
+ lambda directly, as in::
+
+ user: User = field(
+ init=False, metadata={"sa": lambda: relationship(User)}
+ )
+
+ .. seealso::
+
+ :ref:`orm_mixins_toplevel` - illustrates how to use Declarative Mixins
+ which is the primary use case for :class:`_orm.declared_attr`
+
+ :ref:`orm_declarative_dataclasses_mixin` - illustrates special forms
+ for use with Python dataclasses
+
+ """ # noqa: E501
+
+ def __init__(self, fget, cascading=False):
+ super(declared_attr, self).__init__(fget)
+ self.__doc__ = fget.__doc__
+ self._cascading = cascading
+
+ def __get__(desc, self, cls):
+ # the declared_attr needs to make use of a cache that exists
+ # for the span of the declarative scan_attributes() phase.
+ # to achieve this we look at the class manager that's configured.
+ manager = attributes.manager_of_class(cls)
+ if manager is None:
+ if not re.match(r"^__.+__$", desc.fget.__name__):
+ # if there is no manager at all, then this class hasn't been
+ # run through declarative or mapper() at all, emit a warning.
+ util.warn(
+ "Unmanaged access of declarative attribute %s from "
+ "non-mapped class %s" % (desc.fget.__name__, cls.__name__)
+ )
+ return desc.fget(cls)
+ elif manager.is_mapped:
+ # the class is mapped, which means we're outside of the declarative
+ # scan setup, just run the function.
+ return desc.fget(cls)
+
+ # here, we are inside of the declarative scan. use the registry
+ # that is tracking the values of these attributes.
+ declarative_scan = manager.declarative_scan()
+ assert declarative_scan is not None
+ reg = declarative_scan.declared_attr_reg
+
+ if desc in reg:
+ return reg[desc]
+ else:
+ reg[desc] = obj = desc.fget(cls)
+ return obj
+
+ @hybridmethod
+ def _stateful(cls, **kw):
+ return _stateful_declared_attr(**kw)
+
+ @hybridproperty
+ def cascading(cls):
+ """Mark a :class:`.declared_attr` as cascading.
+
+ This is a special-use modifier which indicates that a column
+ or MapperProperty-based declared attribute should be configured
+ distinctly per mapped subclass, within a mapped-inheritance scenario.
+
+ .. warning::
+
+ The :attr:`.declared_attr.cascading` modifier has several
+ limitations:
+
+ * The flag **only** applies to the use of :class:`.declared_attr`
+ on declarative mixin classes and ``__abstract__`` classes; it
+ currently has no effect when used on a mapped class directly.
+
+ * The flag **only** applies to normally-named attributes, e.g.
+ not any special underscore attributes such as ``__tablename__``.
+ On these attributes it has **no** effect.
+
+ * The flag currently **does not allow further overrides** down
+ the class hierarchy; if a subclass tries to override the
+ attribute, a warning is emitted and the overridden attribute
+ is skipped. This is a limitation that it is hoped will be
+ resolved at some point.
+
+ Below, both MyClass as well as MySubClass will have a distinct
+ ``id`` Column object established::
+
+ class HasIdMixin(object):
+ @declared_attr.cascading
+ def id(cls):
+ if has_inherited_table(cls):
+ return Column(
+ ForeignKey('myclass.id'), primary_key=True
+ )
+ else:
+ return Column(Integer, primary_key=True)
+
+ class MyClass(HasIdMixin, Base):
+ __tablename__ = 'myclass'
+ # ...
+
+ class MySubClass(MyClass):
+ ""
+ # ...
+
+        The behavior of the above configuration is that ``MySubClass``
+        will refer to both its own ``id`` column as well as that of
+        ``MyClass`` underneath the attribute named ``id``.
+
+ .. seealso::
+
+ :ref:`declarative_inheritance`
+
+ :ref:`mixin_inheritance_columns`
+
+
+ """
+ return cls._stateful(cascading=True)
+
+
+class _stateful_declared_attr(declared_attr):
+ def __init__(self, **kw):
+ self.kw = kw
+
+ def _stateful(self, **kw):
+ new_kw = self.kw.copy()
+ new_kw.update(kw)
+ return _stateful_declared_attr(**new_kw)
+
+ def __call__(self, fn):
+ return declared_attr(fn, **self.kw)
+
+
+def declarative_mixin(cls):
+ """Mark a class as providing the feature of "declarative mixin".
+
+ E.g.::
+
+ from sqlalchemy.orm import declared_attr
+ from sqlalchemy.orm import declarative_mixin
+
+ @declarative_mixin
+ class MyMixin:
+
+ @declared_attr
+ def __tablename__(cls):
+ return cls.__name__.lower()
+
+ __table_args__ = {'mysql_engine': 'InnoDB'}
+ __mapper_args__= {'always_refresh': True}
+
+ id = Column(Integer, primary_key=True)
+
+ class MyModel(MyMixin, Base):
+ name = Column(String(1000))
+
+ The :func:`_orm.declarative_mixin` decorator currently does not modify
+    the given class in any way; its current purpose is strictly to assist
+ the :ref:`Mypy plugin <mypy_toplevel>` in being able to identify
+ SQLAlchemy declarative mixin classes when no other context is present.
+
+ .. versionadded:: 1.4.6
+
+ .. seealso::
+
+ :ref:`orm_mixins_toplevel`
+
+ :ref:`mypy_declarative_mixins` - in the
+ :ref:`Mypy plugin documentation <mypy_toplevel>`
+
+ """ # noqa: E501
+
+ return cls
+
+
+def declarative_base(
+ bind=None,
+ metadata=None,
+ mapper=None,
+ cls=object,
+ name="Base",
+ constructor=_declarative_constructor,
+ class_registry=None,
+ metaclass=DeclarativeMeta,
+):
+ r"""Construct a base class for declarative class definitions.
+
+ The new base class will be given a metaclass that produces
+ appropriate :class:`~sqlalchemy.schema.Table` objects and makes
+ the appropriate :func:`~sqlalchemy.orm.mapper` calls based on the
+ information provided declaratively in the class and any subclasses
+ of the class.
+
+ The :func:`_orm.declarative_base` function is a shorthand version
+ of using the :meth:`_orm.registry.generate_base`
+ method. That is, the following::
+
+ from sqlalchemy.orm import declarative_base
+
+ Base = declarative_base()
+
+ Is equivalent to::
+
+ from sqlalchemy.orm import registry
+
+ mapper_registry = registry()
+ Base = mapper_registry.generate_base()
+
+ See the docstring for :class:`_orm.registry`
+ and :meth:`_orm.registry.generate_base`
+ for more details.
+
+ .. versionchanged:: 1.4 The :func:`_orm.declarative_base`
+ function is now a specialization of the more generic
+ :class:`_orm.registry` class. The function also moves to the
+ ``sqlalchemy.orm`` package from the ``declarative.ext`` package.
+
+
+ :param bind: An optional
+ :class:`~sqlalchemy.engine.Connectable`, will be assigned
+ the ``bind`` attribute on the :class:`~sqlalchemy.schema.MetaData`
+ instance.
+
+ .. deprecated:: 1.4 The "bind" argument to declarative_base is
+ deprecated and will be removed in SQLAlchemy 2.0.
+
+ :param metadata:
+ An optional :class:`~sqlalchemy.schema.MetaData` instance. All
+ :class:`~sqlalchemy.schema.Table` objects implicitly declared by
+ subclasses of the base will share this MetaData. A MetaData instance
+ will be created if none is provided. The
+ :class:`~sqlalchemy.schema.MetaData` instance will be available via the
+ ``metadata`` attribute of the generated declarative base class.
+
+ :param mapper:
+ An optional callable, defaults to :func:`~sqlalchemy.orm.mapper`. Will
+ be used to map subclasses to their Tables.
+
+ :param cls:
+ Defaults to :class:`object`. A type to use as the base for the generated
+ declarative base class. May be a class or tuple of classes.
+
+ :param name:
+ Defaults to ``Base``. The display name for the generated
+ class. Customizing this is not required, but can improve clarity in
+ tracebacks and debugging.
+
+ :param constructor:
+ Specify the implementation for the ``__init__`` function on a mapped
+ class that has no ``__init__`` of its own. Defaults to an
+ implementation that assigns \**kwargs for declared
+ fields and relationships to an instance. If ``None`` is supplied,
+ no __init__ will be provided and construction will fall back to
+ cls.__init__ by way of the normal Python semantics.
+
+ :param class_registry: optional dictionary that will serve as the
+ registry of class names-> mapped classes when string names
+ are used to identify classes inside of :func:`_orm.relationship`
+ and others. Allows two or more declarative base classes
+ to share the same registry of class names for simplified
+ inter-base relationships.
+
+ :param metaclass:
+ Defaults to :class:`.DeclarativeMeta`. A metaclass or __metaclass__
+ compatible callable to use as the meta type of the generated
+ declarative base class.
+
+ .. seealso::
+
+ :class:`_orm.registry`
+
+ """
+
+ if bind is not None:
+ # util.deprecated_params does not work
+ util.warn_deprecated_20(
+ "The ``bind`` argument to declarative_base is "
+ "deprecated and will be removed in SQLAlchemy 2.0.",
+ )
+
+ return registry(
+ _bind=bind,
+ metadata=metadata,
+ class_registry=class_registry,
+ constructor=constructor,
+ ).generate_base(
+ mapper=mapper,
+ cls=cls,
+ name=name,
+ metaclass=metaclass,
+ )
+
+
+class registry(object):
+ """Generalized registry for mapping classes.
+
+ The :class:`_orm.registry` serves as the basis for maintaining a collection
+ of mappings, and provides configurational hooks used to map classes.
+
+ The three general kinds of mappings supported are Declarative Base,
+ Declarative Decorator, and Imperative Mapping. All of these mapping
+ styles may be used interchangeably:
+
+ * :meth:`_orm.registry.generate_base` returns a new declarative base
+ class, and is the underlying implementation of the
+ :func:`_orm.declarative_base` function.
+
+ * :meth:`_orm.registry.mapped` provides a class decorator that will
+ apply declarative mapping to a class without the use of a declarative
+ base class.
+
+ * :meth:`_orm.registry.map_imperatively` will produce a
+ :class:`_orm.Mapper` for a class without scanning the class for
+ declarative class attributes. This method suits the use case historically
+ provided by the
+ :func:`_orm.mapper` classical mapping function.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`orm_mapping_classes_toplevel` - overview of class mapping
+ styles.
+
+ """
+
+ def __init__(
+ self,
+ metadata=None,
+ class_registry=None,
+ constructor=_declarative_constructor,
+ _bind=None,
+ ):
+ r"""Construct a new :class:`_orm.registry`
+
+ :param metadata:
+ An optional :class:`_schema.MetaData` instance. All
+ :class:`_schema.Table` objects generated using declarative
+ table mapping will make use of this :class:`_schema.MetaData`
+ collection. If this argument is left at its default of ``None``,
+ a blank :class:`_schema.MetaData` collection is created.
+
+ :param constructor:
+ Specify the implementation for the ``__init__`` function on a mapped
+ class that has no ``__init__`` of its own. Defaults to an
+ implementation that assigns \**kwargs for declared
+ fields and relationships to an instance. If ``None`` is supplied,
+ no __init__ will be provided and construction will fall back to
+ cls.__init__ by way of the normal Python semantics.
+
+ :param class_registry: optional dictionary that will serve as the
+ registry of class names-> mapped classes when string names
+ are used to identify classes inside of :func:`_orm.relationship`
+ and others. Allows two or more declarative base classes
+ to share the same registry of class names for simplified
+ inter-base relationships.
+
+ """
+ lcl_metadata = metadata or MetaData()
+ if _bind:
+ lcl_metadata.bind = _bind
+
+ if class_registry is None:
+ class_registry = weakref.WeakValueDictionary()
+
+ self._class_registry = class_registry
+ self._managers = weakref.WeakKeyDictionary()
+ self._non_primary_mappers = weakref.WeakKeyDictionary()
+ self.metadata = lcl_metadata
+ self.constructor = constructor
+
+ self._dependents = set()
+ self._dependencies = set()
+
+ self._new_mappers = False
+
+ with mapperlib._CONFIGURE_MUTEX:
+ mapperlib._mapper_registries[self] = True
+
+ @property
+ def mappers(self):
+ """read only collection of all :class:`_orm.Mapper` objects."""
+
+ return frozenset(manager.mapper for manager in self._managers).union(
+ self._non_primary_mappers
+ )
+
+ def _set_depends_on(self, registry):
+ if registry is self:
+ return
+ registry._dependents.add(self)
+ self._dependencies.add(registry)
+
+ def _flag_new_mapper(self, mapper):
+ mapper._ready_for_configure = True
+ if self._new_mappers:
+ return
+
+ for reg in self._recurse_with_dependents({self}):
+ reg._new_mappers = True
+
+ @classmethod
+ def _recurse_with_dependents(cls, registries):
+ todo = registries
+ done = set()
+ while todo:
+ reg = todo.pop()
+ done.add(reg)
+
+ # if yielding would remove dependents, make sure we have
+ # them before
+ todo.update(reg._dependents.difference(done))
+ yield reg
+
+ # if yielding would add dependents, make sure we have them
+ # after
+ todo.update(reg._dependents.difference(done))
+
+ @classmethod
+ def _recurse_with_dependencies(cls, registries):
+ todo = registries
+ done = set()
+ while todo:
+ reg = todo.pop()
+ done.add(reg)
+
+ # if yielding would remove dependencies, make sure we have
+ # them before
+ todo.update(reg._dependencies.difference(done))
+
+ yield reg
+
+            # if yielding would add dependencies, make sure we have
+            # them after
+ todo.update(reg._dependencies.difference(done))
+
+ def _mappers_to_configure(self):
+ return itertools.chain(
+ (
+ manager.mapper
+ for manager in list(self._managers)
+ if manager.is_mapped
+ and not manager.mapper.configured
+ and manager.mapper._ready_for_configure
+ ),
+ (
+ npm
+ for npm in list(self._non_primary_mappers)
+ if not npm.configured and npm._ready_for_configure
+ ),
+ )
+
+ def _add_non_primary_mapper(self, np_mapper):
+ self._non_primary_mappers[np_mapper] = True
+
+ def _dispose_cls(self, cls):
+ clsregistry.remove_class(cls.__name__, cls, self._class_registry)
+
+ def _add_manager(self, manager):
+ self._managers[manager] = True
+ if manager.registry is not None and manager.is_mapped:
+ raise exc.ArgumentError(
+ "Class '%s' already has a primary mapper defined. "
+ % manager.class_
+ )
+ manager.registry = self
+
+ def configure(self, cascade=False):
+ """Configure all as-yet unconfigured mappers in this
+ :class:`_orm.registry`.
+
+ The configure step is used to reconcile and initialize the
+ :func:`_orm.relationship` linkages between mapped classes, as well as
+ to invoke configuration events such as the
+ :meth:`_orm.MapperEvents.before_configured` and
+ :meth:`_orm.MapperEvents.after_configured`, which may be used by ORM
+ extensions or user-defined extension hooks.
+
+ If one or more mappers in this registry contain
+ :func:`_orm.relationship` constructs that refer to mapped classes in
+ other registries, this registry is said to be *dependent* on those
+ registries. In order to configure those dependent registries
+ automatically, the :paramref:`_orm.registry.configure.cascade` flag
+ should be set to ``True``. Otherwise, if they are not configured, an
+ exception will be raised. The rationale behind this behavior is to
+ allow an application to programmatically invoke configuration of
+ registries while controlling whether or not the process implicitly
+ reaches other registries.
+
+ As an alternative to invoking :meth:`_orm.registry.configure`, the ORM
+        function :func:`_orm.configure_mappers` may be used to ensure
+ configuration is complete for all :class:`_orm.registry` objects in
+ memory. This is generally simpler to use and also predates the usage of
+ :class:`_orm.registry` objects overall. However, this function will
+ impact all mappings throughout the running Python process and may be
+ more memory/time consuming for an application that has many registries
+ in use for different purposes that may not be needed immediately.
+
+ .. seealso::
+
+ :func:`_orm.configure_mappers`
+
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ mapperlib._configure_registries({self}, cascade=cascade)
+
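+    # Editorial note (not part of the original source): a sketch of the
+    # cross-registry case described above.  ``Parent`` and ``Child`` are
+    # hypothetical; the two registries share one MetaData so the foreign key
+    # between their tables can resolve:
+    #
+    #     from sqlalchemy import Column, ForeignKey, Integer
+    #     from sqlalchemy.orm import registry, relationship
+    #
+    #     reg_one = registry()
+    #     reg_two = registry(metadata=reg_one.metadata)
+    #
+    #     @reg_one.mapped
+    #     class Parent:
+    #         __tablename__ = "parent"
+    #         id = Column(Integer, primary_key=True)
+    #
+    #     @reg_two.mapped
+    #     class Child:
+    #         __tablename__ = "child"
+    #         id = Column(Integer, primary_key=True)
+    #         parent_id = Column(ForeignKey("parent.id"))
+    #         parent = relationship(Parent)
+    #
+    #     # Child links to a class mapped in reg_one, so reg_two depends on
+    #     # that registry; cascade=True configures it as well.
+    #     reg_two.configure(cascade=True)
+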
+ def dispose(self, cascade=False):
+ """Dispose of all mappers in this :class:`_orm.registry`.
+
+ After invocation, all the classes that were mapped within this registry
+ will no longer have class instrumentation associated with them. This
+ method is the per-:class:`_orm.registry` analogue to the
+ application-wide :func:`_orm.clear_mappers` function.
+
+ If this registry contains mappers that are dependencies of other
+ registries, typically via :func:`_orm.relationship` links, then those
+ registries must be disposed as well. When such registries exist in
+ relation to this one, their :meth:`_orm.registry.dispose` method will
+ also be called, if the :paramref:`_orm.registry.dispose.cascade` flag
+ is set to ``True``; otherwise, an error is raised if those registries
+ were not already disposed.
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :func:`_orm.clear_mappers`
+
+ """
+
+ mapperlib._dispose_registries({self}, cascade=cascade)
+
+ def _dispose_manager_and_mapper(self, manager):
+ if "mapper" in manager.__dict__:
+ mapper = manager.mapper
+
+ mapper._set_dispose_flags()
+
+ class_ = manager.class_
+ self._dispose_cls(class_)
+ instrumentation._instrumentation_factory.unregister(class_)
+
+ def generate_base(
+ self,
+ mapper=None,
+ cls=object,
+ name="Base",
+ metaclass=DeclarativeMeta,
+ ):
+ """Generate a declarative base class.
+
+ Classes that inherit from the returned class object will be
+ automatically mapped using declarative mapping.
+
+ E.g.::
+
+ from sqlalchemy.orm import registry
+
+ mapper_registry = registry()
+
+ Base = mapper_registry.generate_base()
+
+ class MyClass(Base):
+ __tablename__ = "my_table"
+ id = Column(Integer, primary_key=True)
+
+ The above dynamically generated class is equivalent to the
+ non-dynamic example below::
+
+ from sqlalchemy.orm import registry
+ from sqlalchemy.orm.decl_api import DeclarativeMeta
+
+ mapper_registry = registry()
+
+ class Base(metaclass=DeclarativeMeta):
+ __abstract__ = True
+ registry = mapper_registry
+ metadata = mapper_registry.metadata
+
+ __init__ = mapper_registry.constructor
+
+ The :meth:`_orm.registry.generate_base` method provides the
+ implementation for the :func:`_orm.declarative_base` function, which
+ creates the :class:`_orm.registry` and base class all at once.
+
+ See the section :ref:`orm_declarative_mapping` for background and
+ examples.
+
+ :param mapper:
+ An optional callable, defaults to :func:`~sqlalchemy.orm.mapper`.
+ This function is used to generate new :class:`_orm.Mapper` objects.
+
+ :param cls:
+ Defaults to :class:`object`. A type to use as the base for the
+ generated declarative base class. May be a class or tuple of classes.
+
+ :param name:
+ Defaults to ``Base``. The display name for the generated
+ class. Customizing this is not required, but can improve clarity in
+ tracebacks and debugging.
+
+ :param metaclass:
+ Defaults to :class:`.DeclarativeMeta`. A metaclass or __metaclass__
+ compatible callable to use as the meta type of the generated
+ declarative base class.
+
+ .. seealso::
+
+ :ref:`orm_declarative_mapping`
+
+ :func:`_orm.declarative_base`
+
+ """
+ metadata = self.metadata
+
+        bases = (cls,) if not isinstance(cls, tuple) else cls
+
+ class_dict = dict(registry=self, metadata=metadata)
+ if isinstance(cls, type):
+ class_dict["__doc__"] = cls.__doc__
+
+ if self.constructor:
+ class_dict["__init__"] = self.constructor
+
+ class_dict["__abstract__"] = True
+ if mapper:
+ class_dict["__mapper_cls__"] = mapper
+
+ if hasattr(cls, "__class_getitem__"):
+
+ def __class_getitem__(cls, key):
+ # allow generic classes in py3.9+
+ return cls
+
+ class_dict["__class_getitem__"] = __class_getitem__
+
+ return metaclass(name, bases, class_dict)
+
+ def mapped(self, cls):
+ """Class decorator that will apply the Declarative mapping process
+ to a given class.
+
+ E.g.::
+
+ from sqlalchemy.orm import registry
+
+ mapper_registry = registry()
+
+ @mapper_registry.mapped
+ class Foo:
+ __tablename__ = 'some_table'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ See the section :ref:`orm_declarative_mapping` for complete
+ details and examples.
+
+ :param cls: class to be mapped.
+
+ :return: the class that was passed.
+
+ .. seealso::
+
+ :ref:`orm_declarative_mapping`
+
+ :meth:`_orm.registry.generate_base` - generates a base class
+ that will apply Declarative mapping to subclasses automatically
+ using a Python metaclass.
+
+ """
+ _as_declarative(self, cls, cls.__dict__)
+ return cls
+
+ def as_declarative_base(self, **kw):
+ """
+ Class decorator which will invoke
+ :meth:`_orm.registry.generate_base`
+ for a given base class.
+
+ E.g.::
+
+ from sqlalchemy.orm import registry
+
+ mapper_registry = registry()
+
+ @mapper_registry.as_declarative_base()
+ class Base(object):
+ @declared_attr
+ def __tablename__(cls):
+ return cls.__name__.lower()
+ id = Column(Integer, primary_key=True)
+
+ class MyMappedClass(Base):
+ # ...
+
+ All keyword arguments passed to
+ :meth:`_orm.registry.as_declarative_base` are passed
+ along to :meth:`_orm.registry.generate_base`.
+
+ """
+
+ def decorate(cls):
+ kw["cls"] = cls
+ kw["name"] = cls.__name__
+ return self.generate_base(**kw)
+
+ return decorate
+
+ def map_declaratively(self, cls):
+ """Map a class declaratively.
+
+ In this form of mapping, the class is scanned for mapping information,
+ including for columns to be associated with a table, and/or an
+ actual table object.
+
+ Returns the :class:`_orm.Mapper` object.
+
+ E.g.::
+
+ from sqlalchemy.orm import registry
+
+ mapper_registry = registry()
+
+ class Foo:
+ __tablename__ = 'some_table'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ mapper = mapper_registry.map_declaratively(Foo)
+
+ This function is more conveniently invoked indirectly via either the
+ :meth:`_orm.registry.mapped` class decorator or by subclassing a
+ declarative metaclass generated from
+ :meth:`_orm.registry.generate_base`.
+
+ See the section :ref:`orm_declarative_mapping` for complete
+ details and examples.
+
+ :param cls: class to be mapped.
+
+ :return: a :class:`_orm.Mapper` object.
+
+ .. seealso::
+
+ :ref:`orm_declarative_mapping`
+
+ :meth:`_orm.registry.mapped` - more common decorator interface
+ to this function.
+
+ :meth:`_orm.registry.map_imperatively`
+
+ """
+ return _as_declarative(self, cls, cls.__dict__)
+
+ def map_imperatively(self, class_, local_table=None, **kw):
+ r"""Map a class imperatively.
+
+ In this form of mapping, the class is not scanned for any mapping
+ information. Instead, all mapping constructs are passed as
+ arguments.
+
+ This method is intended to be fully equivalent to the classic
+ SQLAlchemy :func:`_orm.mapper` function, except that it's in terms of
+ a particular registry.
+
+ E.g.::
+
+ from sqlalchemy.orm import registry
+
+ mapper_registry = registry()
+
+ my_table = Table(
+ "my_table",
+ mapper_registry.metadata,
+ Column('id', Integer, primary_key=True)
+ )
+
+ class MyClass:
+ pass
+
+ mapper_registry.map_imperatively(MyClass, my_table)
+
+ See the section :ref:`orm_imperative_mapping` for complete background
+ and usage examples.
+
+ :param class\_: The class to be mapped. Corresponds to the
+ :paramref:`_orm.mapper.class_` parameter.
+
+ :param local_table: the :class:`_schema.Table` or other
+ :class:`_sql.FromClause` object that is the subject of the mapping.
+ Corresponds to the
+ :paramref:`_orm.mapper.local_table` parameter.
+
+ :param \**kw: all other keyword arguments are passed to the
+ :func:`_orm.mapper` function directly.
+
+ .. seealso::
+
+ :ref:`orm_imperative_mapping`
+
+ :ref:`orm_declarative_mapping`
+
+ """
+ return _mapper(self, class_, local_table, kw)
+
+
+mapperlib._legacy_registry = registry()
+
+
+@util.deprecated_params(
+ bind=(
+ "2.0",
+ "The ``bind`` argument to as_declarative is "
+ "deprecated and will be removed in SQLAlchemy 2.0.",
+ )
+)
+def as_declarative(**kw):
+ """
+ Class decorator which will adapt a given class into a
+ :func:`_orm.declarative_base`.
+
+ This function makes use of the :meth:`_orm.registry.as_declarative_base`
+ method, by first creating a :class:`_orm.registry` automatically
+ and then invoking the decorator.
+
+ E.g.::
+
+ from sqlalchemy.orm import as_declarative
+
+ @as_declarative()
+ class Base(object):
+ @declared_attr
+ def __tablename__(cls):
+ return cls.__name__.lower()
+ id = Column(Integer, primary_key=True)
+
+ class MyMappedClass(Base):
+ # ...
+
+ .. seealso::
+
+ :meth:`_orm.registry.as_declarative_base`
+
+ """
+ bind, metadata, class_registry = (
+ kw.pop("bind", None),
+ kw.pop("metadata", None),
+ kw.pop("class_registry", None),
+ )
+
+ return registry(
+ _bind=bind, metadata=metadata, class_registry=class_registry
+ ).as_declarative_base(**kw)
+
+
+@inspection._inspects(DeclarativeMeta)
+def _inspect_decl_meta(cls):
+ mp = _inspect_mapped_class(cls)
+ if mp is None:
+ if _DeferredMapperConfig.has_cls(cls):
+ _DeferredMapperConfig.raise_unmapped_for_cls(cls)
+ raise orm_exc.UnmappedClassError(
+ cls,
+ msg="Class %s has a deferred mapping on it. It is not yet "
+ "usable as a mapped class." % orm_exc._safe_cls_name(cls),
+ )
+ return mp
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
new file mode 100644
index 0000000..6e1c797
--- /dev/null
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -0,0 +1,1210 @@
+# orm/decl_base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Internal implementation for declarative."""
+from __future__ import absolute_import
+
+import collections
+import weakref
+
+from sqlalchemy.orm import attributes
+from sqlalchemy.orm import instrumentation
+from . import clsregistry
+from . import exc as orm_exc
+from . import mapper as mapperlib
+from .attributes import InstrumentedAttribute
+from .attributes import QueryableAttribute
+from .base import _is_mapped_class
+from .base import InspectionAttr
+from .descriptor_props import CompositeProperty
+from .descriptor_props import SynonymProperty
+from .interfaces import MapperProperty
+from .mapper import Mapper as mapper
+from .properties import ColumnProperty
+from .util import class_mapper
+from .. import event
+from .. import exc
+from .. import util
+from ..sql import expression
+from ..sql.schema import Column
+from ..sql.schema import Table
+from ..util import topological
+
+
+def _declared_mapping_info(cls):
+ # deferred mapping
+ if _DeferredMapperConfig.has_cls(cls):
+ return _DeferredMapperConfig.config_for_cls(cls)
+ # regular mapping
+ elif _is_mapped_class(cls):
+ return class_mapper(cls, configure=False)
+ else:
+ return None
+
+
+def _resolve_for_abstract_or_classical(cls):
+ if cls is object:
+ return None
+
+ if cls.__dict__.get("__abstract__", False):
+ for sup in cls.__bases__:
+ sup = _resolve_for_abstract_or_classical(sup)
+ if sup is not None:
+ return sup
+ else:
+ return None
+ else:
+ clsmanager = _dive_for_cls_manager(cls)
+
+ if clsmanager:
+ return clsmanager.class_
+ else:
+ return cls
+
+
+def _get_immediate_cls_attr(cls, attrname, strict=False):
+ """return an attribute of the class that is either present directly
+ on the class, e.g. not on a superclass, or is from a superclass but
+ this superclass is a non-mapped mixin, that is, not a descendant of
+ the declarative base and is also not classically mapped.
+
+ This is used to detect attributes that indicate something about
+ a mapped class independently from any mapped classes that it may
+ inherit from.
+
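+ E.g., a rough sketch, where ``HasFlag`` is an illustrative non-mapped
+ mixin and ``MyMapped`` a declaratively mapped class::
+
+ class HasFlag:
+ __some_flag__ = True
+
+ class MyMapped(HasFlag, Base):
+ __tablename__ = "my_mapped"
+ id = Column(Integer, primary_key=True)
+
+ here, ``_get_immediate_cls_attr(MyMapped, "__some_flag__")`` returns
+ ``True``, since ``HasFlag`` is a plain mixin; the same attribute coming
+ only from a mapped superclass would not be returned.
+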
+ """
+
+ # the rules are different for this name than others,
+ # make sure __abstract__ handling has been moved out of this function. transitional
+ assert attrname != "__abstract__"
+
+ if not issubclass(cls, object):
+ return None
+
+ if attrname in cls.__dict__:
+ return getattr(cls, attrname)
+
+ for base in cls.__mro__[1:]:
+ _is_classical_inherits = _dive_for_cls_manager(base)
+
+ if attrname in base.__dict__ and (
+ base is cls
+ or (
+ (base in cls.__bases__ if strict else True)
+ and not _is_classical_inherits
+ )
+ ):
+ return getattr(base, attrname)
+ else:
+ return None
+
+
+def _dive_for_cls_manager(cls):
+ # because the class manager registration is pluggable,
+ # we need to do the search for every class in the hierarchy,
+ # rather than just a simple "cls._sa_class_manager"
+
+ # python 2 old style class
+ if not hasattr(cls, "__mro__"):
+ return None
+
+ for base in cls.__mro__:
+ manager = attributes.manager_of_class(base)
+ if manager:
+ return manager
+ return None
+
+
+def _as_declarative(registry, cls, dict_):
+
+ # declarative scans the class for attributes. no table or mapper
+ # args passed separately.
+
+ return _MapperConfig.setup_mapping(registry, cls, dict_, None, {})
+
+
+def _mapper(registry, cls, table, mapper_kw):
+ _ImperativeMapperConfig(registry, cls, table, mapper_kw)
+ return cls.__mapper__
+
+
+@util.preload_module("sqlalchemy.orm.decl_api")
+def _is_declarative_props(obj):
+ declared_attr = util.preloaded.orm_decl_api.declared_attr
+
+ return isinstance(obj, (declared_attr, util.classproperty))
+
+
+def _check_declared_props_nocascade(obj, name, cls):
+ if _is_declarative_props(obj):
+ if getattr(obj, "_cascading", False):
+ util.warn(
+ "@declared_attr.cascading is not supported on the %s "
+ "attribute on class %s. This attribute invokes for "
+ "subclasses in any case." % (name, cls)
+ )
+ return True
+ else:
+ return False
+
+
+class _MapperConfig(object):
+ __slots__ = (
+ "cls",
+ "classname",
+ "properties",
+ "declared_attr_reg",
+ "__weakref__",
+ )
+
+ @classmethod
+ def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw):
+ manager = attributes.manager_of_class(cls_)
+ if manager and manager.class_ is cls_:
+ raise exc.InvalidRequestError(
+ "Class %r already has been " "instrumented declaratively" % cls
+ )
+
+ if cls_.__dict__.get("__abstract__", False):
+ return
+
+ defer_map = _get_immediate_cls_attr(
+ cls_, "_sa_decl_prepare_nocascade", strict=True
+ ) or hasattr(cls_, "_sa_decl_prepare")
+
+ if defer_map:
+ cfg_cls = _DeferredMapperConfig
+ else:
+ cfg_cls = _ClassScanMapperConfig
+
+ return cfg_cls(registry, cls_, dict_, table, mapper_kw)
+
+ def __init__(self, registry, cls_, mapper_kw):
+ self.cls = util.assert_arg_type(cls_, type, "cls_")
+ self.classname = cls_.__name__
+ self.properties = util.OrderedDict()
+ self.declared_attr_reg = {}
+
+ if not mapper_kw.get("non_primary", False):
+ instrumentation.register_class(
+ self.cls,
+ finalize=False,
+ registry=registry,
+ declarative_scan=self,
+ init_method=registry.constructor,
+ )
+ else:
+ manager = attributes.manager_of_class(self.cls)
+ if not manager or not manager.is_mapped:
+ raise exc.InvalidRequestError(
+ "Class %s has no primary mapper configured. Configure "
+ "a primary mapper first before setting up a non primary "
+ "Mapper." % self.cls
+ )
+
+ def set_cls_attribute(self, attrname, value):
+
+ manager = instrumentation.manager_of_class(self.cls)
+ manager.install_member(attrname, value)
+ return value
+
+ def _early_mapping(self, mapper_kw):
+ self.map(mapper_kw)
+
+
+class _ImperativeMapperConfig(_MapperConfig):
+ __slots__ = ("dict_", "local_table", "inherits")
+
+ def __init__(
+ self,
+ registry,
+ cls_,
+ table,
+ mapper_kw,
+ ):
+ super(_ImperativeMapperConfig, self).__init__(
+ registry, cls_, mapper_kw
+ )
+
+ self.dict_ = {}
+ self.local_table = self.set_cls_attribute("__table__", table)
+
+ with mapperlib._CONFIGURE_MUTEX:
+ if not mapper_kw.get("non_primary", False):
+ clsregistry.add_class(
+ self.classname, self.cls, registry._class_registry
+ )
+
+ self._setup_inheritance(mapper_kw)
+
+ self._early_mapping(mapper_kw)
+
+ def map(self, mapper_kw=util.EMPTY_DICT):
+ mapper_cls = mapper
+
+ return self.set_cls_attribute(
+ "__mapper__",
+ mapper_cls(self.cls, self.local_table, **mapper_kw),
+ )
+
+ def _setup_inheritance(self, mapper_kw):
+ cls = self.cls
+
+ inherits = mapper_kw.get("inherits", None)
+
+ if inherits is None:
+ # since we search for classical mappings now, search for
+ # multiple mapped bases as well and raise an error.
+ inherits_search = []
+ for c in cls.__bases__:
+ c = _resolve_for_abstract_or_classical(c)
+ if c is None:
+ continue
+ if _declared_mapping_info(
+ c
+ ) is not None and not _get_immediate_cls_attr(
+ c, "_sa_decl_prepare_nocascade", strict=True
+ ):
+ inherits_search.append(c)
+
+ if inherits_search:
+ if len(inherits_search) > 1:
+ raise exc.InvalidRequestError(
+ "Class %s has multiple mapped bases: %r"
+ % (cls, inherits_search)
+ )
+ inherits = inherits_search[0]
+ elif isinstance(inherits, mapper):
+ inherits = inherits.class_
+
+ self.inherits = inherits
+
+
+class _ClassScanMapperConfig(_MapperConfig):
+ __slots__ = (
+ "dict_",
+ "local_table",
+ "persist_selectable",
+ "declared_columns",
+ "column_copies",
+ "table_args",
+ "tablename",
+ "mapper_args",
+ "mapper_args_fn",
+ "inherits",
+ )
+
+ def __init__(
+ self,
+ registry,
+ cls_,
+ dict_,
+ table,
+ mapper_kw,
+ ):
+
+ # grab class dict before the instrumentation manager has been added.
+ # reduces cycles
+ self.dict_ = dict(dict_) if dict_ else {}
+
+ super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw)
+
+ self.persist_selectable = None
+ self.declared_columns = set()
+ self.column_copies = {}
+ self._setup_declared_events()
+
+ self._scan_attributes()
+
+ with mapperlib._CONFIGURE_MUTEX:
+ clsregistry.add_class(
+ self.classname, self.cls, registry._class_registry
+ )
+
+ self._extract_mappable_attributes()
+
+ self._extract_declared_columns()
+
+ self._setup_table(table)
+
+ self._setup_inheritance(mapper_kw)
+
+ self._early_mapping(mapper_kw)
+
+ def _setup_declared_events(self):
+ if _get_immediate_cls_attr(self.cls, "__declare_last__"):
+
+ @event.listens_for(mapper, "after_configured")
+ def after_configured():
+ self.cls.__declare_last__()
+
+ if _get_immediate_cls_attr(self.cls, "__declare_first__"):
+
+ @event.listens_for(mapper, "before_configured")
+ def before_configured():
+ self.cls.__declare_first__()
+
+ def _cls_attr_override_checker(self, cls):
+ """Produce a function that checks if a class has overridden an
+ attribute, taking SQLAlchemy-enabled dataclass fields into account.
+
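+ The dataclass integration is keyed off of the
+ ``__sa_dataclass_metadata_key__`` class attribute; as a rough,
+ illustrative sketch::
+
+ @mapper_registry.mapped
+ @dataclasses.dataclass
+ class User:
+ __tablename__ = "user"
+ __sa_dataclass_metadata_key__ = "sa"
+
+ id: int = dataclasses.field(
+ init=False,
+ metadata={"sa": Column(Integer, primary_key=True)},
+ )
+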
+ """
+ sa_dataclass_metadata_key = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__", None
+ )
+
+ if sa_dataclass_metadata_key is None:
+
+ def attribute_is_overridden(key, obj):
+ return getattr(cls, key) is not obj
+
+ else:
+
+ all_datacls_fields = {
+ f.name: f.metadata[sa_dataclass_metadata_key]
+ for f in util.dataclass_fields(cls)
+ if sa_dataclass_metadata_key in f.metadata
+ }
+ local_datacls_fields = {
+ f.name: f.metadata[sa_dataclass_metadata_key]
+ for f in util.local_dataclass_fields(cls)
+ if sa_dataclass_metadata_key in f.metadata
+ }
+
+ absent = object()
+
+ def attribute_is_overridden(key, obj):
+ if _is_declarative_props(obj):
+ obj = obj.fget
+
+ # this function likely has some failure modes still if
+ # someone is doing a deep mixing of the same attribute
+ # name as plain Python attribute vs. dataclass field.
+
+ ret = local_datacls_fields.get(key, absent)
+ if _is_declarative_props(ret):
+ ret = ret.fget
+
+ if ret is obj:
+ return False
+ elif ret is not absent:
+ return True
+
+ all_field = all_datacls_fields.get(key, absent)
+
+ ret = getattr(cls, key, obj)
+
+ if ret is obj:
+ return False
+
+ # for dataclasses, this could be the
+ # 'default' of the field. so filter more specifically
+ # for an already-mapped InstrumentedAttribute
+ if ret is not absent and isinstance(
+ ret, InstrumentedAttribute
+ ):
+ return True
+
+ if all_field is obj:
+ return False
+ elif all_field is not absent:
+ return True
+
+ # can't find another attribute
+ return False
+
+ return attribute_is_overridden
+
+ def _cls_attr_resolver(self, cls):
+ """produce a function to iterate the "attributes" of a class,
+ adjusting for SQLAlchemy fields embedded in dataclass fields.
+
+ """
+ sa_dataclass_metadata_key = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__", None
+ )
+
+ if sa_dataclass_metadata_key is None:
+
+ def local_attributes_for_class():
+ for name, obj in vars(cls).items():
+ yield name, obj, False
+
+ else:
+ field_names = set()
+
+ def local_attributes_for_class():
+ for field in util.local_dataclass_fields(cls):
+ if sa_dataclass_metadata_key in field.metadata:
+ field_names.add(field.name)
+ yield field.name, _as_dc_declaredattr(
+ field.metadata, sa_dataclass_metadata_key
+ ), True
+ for name, obj in vars(cls).items():
+ if name not in field_names:
+ yield name, obj, False
+
+ return local_attributes_for_class
+
+ def _scan_attributes(self):
+ cls = self.cls
+ dict_ = self.dict_
+ column_copies = self.column_copies
+ mapper_args_fn = None
+ table_args = inherited_table_args = None
+ tablename = None
+
+ attribute_is_overridden = self._cls_attr_override_checker(self.cls)
+
+ bases = []
+
+ for base in cls.__mro__:
+ # collect bases and make sure standalone columns are copied
+ # to be the column they will ultimately be on the class,
+ # so that declared_attr functions use the right columns.
+ # need to do this all the way up the hierarchy first
+ # (see #8190)
+
+ class_mapped = (
+ base is not cls
+ and _declared_mapping_info(base) is not None
+ and not _get_immediate_cls_attr(
+ base, "_sa_decl_prepare_nocascade", strict=True
+ )
+ )
+
+ local_attributes_for_class = self._cls_attr_resolver(base)
+
+ if not class_mapped and base is not cls:
+ locally_collected_columns = self._produce_column_copies(
+ local_attributes_for_class,
+ attribute_is_overridden,
+ )
+ else:
+ locally_collected_columns = {}
+
+ bases.append(
+ (
+ base,
+ class_mapped,
+ local_attributes_for_class,
+ locally_collected_columns,
+ )
+ )
+
+ for (
+ base,
+ class_mapped,
+ local_attributes_for_class,
+ locally_collected_columns,
+ ) in bases:
+
+ # this transfer can also take place as we scan each name
+ # for finer-grained control of how dict_ is
+ # populated, as this is what impacts column ordering.
+ # however it's simpler to get it out of the way here.
+ dict_.update(locally_collected_columns)
+
+ for name, obj, is_dataclass in local_attributes_for_class():
+ if name == "__mapper_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not mapper_args_fn and (not class_mapped or check_decl):
+ # don't even invoke __mapper_args__ until
+ # after we've determined everything about the
+ # mapped table.
+ # make a copy of it so a class-level dictionary
+ # is not overwritten when we update column-based
+ # arguments.
+ def mapper_args_fn():
+ return dict(cls.__mapper_args__)
+
+ elif name == "__tablename__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not tablename and (not class_mapped or check_decl):
+ tablename = cls.__tablename__
+ elif name == "__table_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not table_args and (not class_mapped or check_decl):
+ table_args = cls.__table_args__
+ if not isinstance(
+ table_args, (tuple, dict, type(None))
+ ):
+ raise exc.ArgumentError(
+ "__table_args__ value must be a tuple, "
+ "dict, or None"
+ )
+ if base is not cls:
+ inherited_table_args = True
+ elif class_mapped:
+ if _is_declarative_props(obj):
+ util.warn(
+ "Regular (i.e. not __special__) "
+ "attribute '%s.%s' uses @declared_attr, "
+ "but owning class %s is mapped - "
+ "not applying to subclass %s."
+ % (base.__name__, name, base, cls)
+ )
+ continue
+ elif base is not cls:
+ # we're a mixin, abstract base, or something that is
+ # acting like that for now.
+ if isinstance(obj, Column):
+ # already copied columns to the mapped class.
+ continue
+ elif isinstance(obj, MapperProperty):
+ raise exc.InvalidRequestError(
+ "Mapper properties (i.e. deferred,"
+ "column_property(), relationship(), etc.) must "
+ "be declared as @declared_attr callables "
+ "on declarative mixin classes. For dataclass "
+ "field() objects, use a lambda:"
+ )
+ elif _is_declarative_props(obj):
+ if obj._cascading:
+ if name in dict_:
+ # unfortunately, while we can use the user-
+ # defined attribute here to allow a clean
+ # override, if there's another
+ # subclass below then it still tries to use
+ # this. not sure if there is enough
+ # information here to add this as a feature
+ # later on.
+ util.warn(
+ "Attribute '%s' on class %s cannot be "
+ "processed due to "
+ "@declared_attr.cascading; "
+ "skipping" % (name, cls)
+ )
+ dict_[name] = column_copies[
+ obj
+ ] = ret = obj.__get__(obj, cls)
+ setattr(cls, name, ret)
+ else:
+ if is_dataclass:
+ # access attribute using normal class access
+ # first, to see if it's been mapped on a
+ # superclass. note if the dataclasses.field()
+ # has "default", this value can be anything.
+ ret = getattr(cls, name, None)
+
+ # so, if it's anything that's not ORM
+ # mapped, assume we should invoke the
+ # declared_attr
+ if not isinstance(ret, InspectionAttr):
+ ret = obj.fget()
+ else:
+ # access attribute using normal class access.
+ # if the declared attr already took place
+ # on a superclass that is mapped, then
+ # this is no longer a declared_attr, it will
+ # be the InstrumentedAttribute
+ ret = getattr(cls, name)
+
+ # correct for proxies created from hybrid_property
+ # or similar. note there is no known case that
+ # produces nested proxies, so we are only
+ # looking one level deep right now.
+ if (
+ isinstance(ret, InspectionAttr)
+ and ret._is_internal_proxy
+ and not isinstance(
+ ret.original_property, MapperProperty
+ )
+ ):
+ ret = ret.descriptor
+
+ dict_[name] = column_copies[obj] = ret
+ if (
+ isinstance(ret, (Column, MapperProperty))
+ and ret.doc is None
+ ):
+ ret.doc = obj.__doc__
+ # here, the attribute is some other kind of property that
+ # we assume is not part of the declarative mapping.
+ # however, check for some more common mistakes
+ else:
+ self._warn_for_decl_attributes(base, name, obj)
+ elif is_dataclass and (
+ name not in dict_ or dict_[name] is not obj
+ ):
+ # here, we are definitely looking at the target class
+ # and not a superclass. this is currently a
+ # dataclass-only path. if the name is only
+ # a dataclass field and isn't in local cls.__dict__,
+ # put the object there.
+ # assert that the dataclass-enabled resolver agrees
+ # with what we are seeing
+
+ assert not attribute_is_overridden(name, obj)
+
+ if _is_declarative_props(obj):
+ obj = obj.fget()
+
+ dict_[name] = obj
+
+ if inherited_table_args and not tablename:
+ table_args = None
+
+ self.table_args = table_args
+ self.tablename = tablename
+ self.mapper_args_fn = mapper_args_fn
+
+ def _warn_for_decl_attributes(self, cls, key, c):
+ if isinstance(c, expression.ColumnClause):
+ util.warn(
+ "Attribute '%s' on class %s appears to be a non-schema "
+ "'sqlalchemy.sql.column()' "
+ "object; this won't be part of the declarative mapping"
+ % (key, cls)
+ )
+
+ def _produce_column_copies(
+ self, attributes_for_class, attribute_is_overridden
+ ):
+ cls = self.cls
+ dict_ = self.dict_
+ locally_collected_attributes = {}
+ column_copies = self.column_copies
+ # copy mixin columns to the mapped class
+
+ for name, obj, is_dataclass in attributes_for_class():
+ if isinstance(obj, Column):
+ if attribute_is_overridden(name, obj):
+ # if column has been overridden
+ # (like by the InstrumentedAttribute of the
+ # superclass), skip
+ continue
+ elif obj.foreign_keys:
+ raise exc.InvalidRequestError(
+ "Columns with foreign keys to other columns "
+ "must be declared as @declared_attr callables "
+ "on declarative mixin classes. For dataclass "
+ "field() objects, use a lambda:."
+ )
+ elif name not in dict_ and not (
+ "__table__" in dict_
+ and (obj.name or name) in dict_["__table__"].c
+ ):
+ column_copies[obj] = copy_ = obj._copy()
+ copy_._creation_order = obj._creation_order
+ setattr(cls, name, copy_)
+ locally_collected_attributes[name] = copy_
+ return locally_collected_attributes
+
+ def _extract_mappable_attributes(self):
+ cls = self.cls
+ dict_ = self.dict_
+
+ our_stuff = self.properties
+
+ late_mapped = _get_immediate_cls_attr(
+ cls, "_sa_decl_prepare_nocascade", strict=True
+ )
+
+ for k in list(dict_):
+
+ if k in ("__table__", "__tablename__", "__mapper_args__"):
+ continue
+
+ value = dict_[k]
+ if _is_declarative_props(value):
+ if value._cascading:
+ util.warn(
+ "Use of @declared_attr.cascading only applies to "
+ "Declarative 'mixin' and 'abstract' classes. "
+ "Currently, this flag is ignored on mapped class "
+ "%s" % self.cls
+ )
+
+ value = getattr(cls, k)
+
+ elif (
+ isinstance(value, QueryableAttribute)
+ and value.class_ is not cls
+ and value.key != k
+ ):
+ # detect a QueryableAttribute that's already mapped being
+ # assigned elsewhere in userland, turn into a synonym()
+ value = SynonymProperty(value.key)
+ setattr(cls, k, value)
+
+ if (
+ isinstance(value, tuple)
+ and len(value) == 1
+ and isinstance(value[0], (Column, MapperProperty))
+ ):
+ util.warn(
+ "Ignoring declarative-like tuple value of attribute "
+ "'%s': possibly a copy-and-paste error with a comma "
+ "accidentally placed at the end of the line?" % k
+ )
+ continue
+ elif not isinstance(value, (Column, MapperProperty)):
+ # using @declared_attr for some object that
+ # isn't Column/MapperProperty; remove from the dict_
+ # and place the evaluated value onto the class.
+ if not k.startswith("__"):
+ dict_.pop(k)
+ self._warn_for_decl_attributes(cls, k, value)
+ if not late_mapped:
+ setattr(cls, k, value)
+ continue
+ # we expect to see the name 'metadata' in some valid cases;
+ # however at this point we see it's assigned to something trying
+ # to be mapped, so raise for that.
+ elif k == "metadata":
+ raise exc.InvalidRequestError(
+ "Attribute name 'metadata' is reserved "
+ "for the MetaData instance when using a "
+ "declarative base class."
+ )
+ our_stuff[k] = value
+
+ def _extract_declared_columns(self):
+ our_stuff = self.properties
+
+ # set up attributes in the order they were created
+ util.sort_dictionary(
+ our_stuff, key=lambda key: our_stuff[key]._creation_order
+ )
+
+ # extract columns from the class dict
+ declared_columns = self.declared_columns
+ name_to_prop_key = collections.defaultdict(set)
+ for key, c in list(our_stuff.items()):
+ if isinstance(c, (ColumnProperty, CompositeProperty)):
+ for col in c.columns:
+ if isinstance(col, Column) and col.table is None:
+ _undefer_column_name(key, col)
+ if not isinstance(c, CompositeProperty):
+ name_to_prop_key[col.name].add(key)
+ declared_columns.add(col)
+ elif isinstance(c, Column):
+ _undefer_column_name(key, c)
+ name_to_prop_key[c.name].add(key)
+ declared_columns.add(c)
+ # if the column is the same name as the key,
+ # remove it from the explicit properties dict.
+ # the normal rules for assigning column-based properties
+ # will take over, including precedence of columns
+ # in multi-column ColumnProperties.
+ if key == c.key:
+ del our_stuff[key]
+
+ for name, keys in name_to_prop_key.items():
+ if len(keys) > 1:
+ util.warn(
+ "On class %r, Column object %r named "
+ "directly multiple times, "
+ "only one will be used: %s. "
+ "Consider using orm.synonym instead"
+ % (self.classname, name, (", ".join(sorted(keys))))
+ )
+
+ def _setup_table(self, table=None):
+ cls = self.cls
+ tablename = self.tablename
+ table_args = self.table_args
+ dict_ = self.dict_
+ declared_columns = self.declared_columns
+
+ manager = attributes.manager_of_class(cls)
+
+ declared_columns = self.declared_columns = sorted(
+ declared_columns, key=lambda c: c._creation_order
+ )
+
+ if "__table__" not in dict_ and table is None:
+ if hasattr(cls, "__table_cls__"):
+ table_cls = util.unbound_method_to_callable(cls.__table_cls__)
+ else:
+ table_cls = Table
+
+ if tablename is not None:
+
+ args, table_kw = (), {}
+ if table_args:
+ if isinstance(table_args, dict):
+ table_kw = table_args
+ elif isinstance(table_args, tuple):
+ if isinstance(table_args[-1], dict):
+ args, table_kw = table_args[0:-1], table_args[-1]
+ else:
+ args = table_args
+
+ autoload_with = dict_.get("__autoload_with__")
+ if autoload_with:
+ table_kw["autoload_with"] = autoload_with
+
+ autoload = dict_.get("__autoload__")
+ if autoload:
+ table_kw["autoload"] = True
+
+ table = self.set_cls_attribute(
+ "__table__",
+ table_cls(
+ tablename,
+ self._metadata_for_cls(manager),
+ *(tuple(declared_columns) + tuple(args)),
+ **table_kw
+ ),
+ )
+ else:
+ if table is None:
+ table = cls.__table__
+ if declared_columns:
+ for c in declared_columns:
+ if not table.c.contains_column(c):
+ raise exc.ArgumentError(
+ "Can't add additional column %r when "
+ "specifying __table__" % c.key
+ )
+ self.local_table = table
+
+ def _metadata_for_cls(self, manager):
+ if hasattr(self.cls, "metadata"):
+ return self.cls.metadata
+ else:
+ return manager.registry.metadata
+
+ def _setup_inheritance(self, mapper_kw):
+ table = self.local_table
+ cls = self.cls
+ table_args = self.table_args
+ declared_columns = self.declared_columns
+
+ inherits = mapper_kw.get("inherits", None)
+
+ if inherits is None:
+ # since we search for classical mappings now, search for
+ # multiple mapped bases as well and raise an error.
+ inherits_search = []
+ for c in cls.__bases__:
+ c = _resolve_for_abstract_or_classical(c)
+ if c is None:
+ continue
+ if _declared_mapping_info(
+ c
+ ) is not None and not _get_immediate_cls_attr(
+ c, "_sa_decl_prepare_nocascade", strict=True
+ ):
+ if c not in inherits_search:
+ inherits_search.append(c)
+
+ if inherits_search:
+ if len(inherits_search) > 1:
+ raise exc.InvalidRequestError(
+ "Class %s has multiple mapped bases: %r"
+ % (cls, inherits_search)
+ )
+ inherits = inherits_search[0]
+ elif isinstance(inherits, mapper):
+ inherits = inherits.class_
+
+ self.inherits = inherits
+
+ if (
+ table is None
+ and self.inherits is None
+ and not _get_immediate_cls_attr(cls, "__no_table__")
+ ):
+
+ raise exc.InvalidRequestError(
+ "Class %r does not have a __table__ or __tablename__ "
+ "specified and does not inherit from an existing "
+ "table-mapped class." % cls
+ )
+ elif self.inherits:
+ inherited_mapper = _declared_mapping_info(self.inherits)
+ inherited_table = inherited_mapper.local_table
+ inherited_persist_selectable = inherited_mapper.persist_selectable
+
+ if table is None:
+ # single table inheritance.
+ # ensure no table args
+ if table_args:
+ raise exc.ArgumentError(
+ "Can't place __table_args__ on an inherited class "
+ "with no table."
+ )
+ # add any columns declared here to the inherited table.
+ for c in declared_columns:
+ if c.name in inherited_table.c:
+ if inherited_table.c[c.name] is c:
+ continue
+ raise exc.ArgumentError(
+ "Column '%s' on class %s conflicts with "
+ "existing column '%s'"
+ % (c, cls, inherited_table.c[c.name])
+ )
+ if c.primary_key:
+ raise exc.ArgumentError(
+ "Can't place primary key columns on an inherited "
+ "class with no table."
+ )
+ inherited_table.append_column(c)
+ if (
+ inherited_persist_selectable is not None
+ and inherited_persist_selectable is not inherited_table
+ ):
+ inherited_persist_selectable._refresh_for_new_column(c)
+
+ def _prepare_mapper_arguments(self, mapper_kw):
+ properties = self.properties
+
+ if self.mapper_args_fn:
+ mapper_args = self.mapper_args_fn()
+ else:
+ mapper_args = {}
+
+ if mapper_kw:
+ mapper_args.update(mapper_kw)
+
+ if "properties" in mapper_args:
+ properties = dict(properties)
+ properties.update(mapper_args["properties"])
+
+ # make sure that column copies are used rather
+ # than the original columns from any mixins
+ for k in ("version_id_col", "polymorphic_on"):
+ if k in mapper_args:
+ v = mapper_args[k]
+ mapper_args[k] = self.column_copies.get(v, v)
+
+ if "inherits" in mapper_args:
+ inherits_arg = mapper_args["inherits"]
+ if isinstance(inherits_arg, mapper):
+ inherits_arg = inherits_arg.class_
+
+ if inherits_arg is not self.inherits:
+ raise exc.InvalidRequestError(
+ "mapper inherits argument given for non-inheriting "
+ "class %s" % (mapper_args["inherits"])
+ )
+
+ if self.inherits:
+ mapper_args["inherits"] = self.inherits
+
+ if self.inherits and not mapper_args.get("concrete", False):
+ # single or joined inheritance
+ # exclude any cols on the inherited table which are
+ # not mapped on the parent class, to avoid
+ # mapping columns specific to sibling/nephew classes
+ inherited_mapper = _declared_mapping_info(self.inherits)
+ inherited_table = inherited_mapper.local_table
+
+ if "exclude_properties" not in mapper_args:
+ mapper_args["exclude_properties"] = exclude_properties = set(
+ [
+ c.key
+ for c in inherited_table.c
+ if c not in inherited_mapper._columntoproperty
+ ]
+ ).union(inherited_mapper.exclude_properties or ())
+ exclude_properties.difference_update(
+ [c.key for c in self.declared_columns]
+ )
+
+ # look through columns in the current mapper that
+ # are keyed to a propname different than the colname
+ # (if names were the same, we'd have popped it out above,
+ # in which case the mapper makes this combination).
+ # See if the superclass has a similar column property.
+ # If so, join them together.
+ for k, col in list(properties.items()):
+ if not isinstance(col, expression.ColumnElement):
+ continue
+ if k in inherited_mapper._props:
+ p = inherited_mapper._props[k]
+ if isinstance(p, ColumnProperty):
+ # note here we place the subclass column
+ # first. See [ticket:1892] for background.
+ properties[k] = [col] + p.columns
+ result_mapper_args = mapper_args.copy()
+ result_mapper_args["properties"] = properties
+ self.mapper_args = result_mapper_args
+
+ def map(self, mapper_kw=util.EMPTY_DICT):
+ self._prepare_mapper_arguments(mapper_kw)
+ if hasattr(self.cls, "__mapper_cls__"):
+ mapper_cls = util.unbound_method_to_callable(
+ self.cls.__mapper_cls__
+ )
+ else:
+ mapper_cls = mapper
+
+ return self.set_cls_attribute(
+ "__mapper__",
+ mapper_cls(self.cls, self.local_table, **self.mapper_args),
+ )
+
+
+@util.preload_module("sqlalchemy.orm.decl_api")
+def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key):
+ # wrap lambdas inside dataclass fields inside an ad-hoc declared_attr.
+ # we can't write the wrapped value back, because field.metadata is
+ # immutable :( so we have to go through extra trouble to compare these
+ decl_api = util.preloaded.orm_decl_api
+ obj = field_metadata[sa_dataclass_metadata_key]
+ if callable(obj) and not isinstance(obj, decl_api.declared_attr):
+ return decl_api.declared_attr(obj)
+ else:
+ return obj
+
+
+class _DeferredMapperConfig(_ClassScanMapperConfig):
+ _configs = util.OrderedDict()
+
+ def _early_mapping(self, mapper_kw):
+ pass
+
+ @property
+ def cls(self):
+ return self._cls()
+
+ @cls.setter
+ def cls(self, class_):
+ self._cls = weakref.ref(class_, self._remove_config_cls)
+ self._configs[self._cls] = self
+
+ @classmethod
+ def _remove_config_cls(cls, ref):
+ cls._configs.pop(ref, None)
+
+ @classmethod
+ def has_cls(cls, class_):
+ # Python 2.6 fails on weakref if class_ is an old style class
+ return isinstance(class_, type) and weakref.ref(class_) in cls._configs
+
+ @classmethod
+ def raise_unmapped_for_cls(cls, class_):
+ if hasattr(class_, "_sa_raise_deferred_config"):
+ class_._sa_raise_deferred_config()
+
+ raise orm_exc.UnmappedClassError(
+ class_,
+ msg="Class %s has a deferred mapping on it. It is not yet "
+ "usable as a mapped class." % orm_exc._safe_cls_name(class_),
+ )
+
+ @classmethod
+ def config_for_cls(cls, class_):
+ return cls._configs[weakref.ref(class_)]
+
+ @classmethod
+ def classes_for_base(cls, base_cls, sort=True):
+ classes_for_base = [
+ m
+ for m, cls_ in [(m, m.cls) for m in cls._configs.values()]
+ if cls_ is not None and issubclass(cls_, base_cls)
+ ]
+
+ if not sort:
+ return classes_for_base
+
+ all_m_by_cls = dict((m.cls, m) for m in classes_for_base)
+
+ tuples = []
+ for m_cls in all_m_by_cls:
+ tuples.extend(
+ (all_m_by_cls[base_cls], all_m_by_cls[m_cls])
+ for base_cls in m_cls.__bases__
+ if base_cls in all_m_by_cls
+ )
+ return list(topological.sort(tuples, classes_for_base))
+
+ def map(self, mapper_kw=util.EMPTY_DICT):
+ self._configs.pop(self._cls, None)
+ return super(_DeferredMapperConfig, self).map(mapper_kw)
+
+
+def _add_attribute(cls, key, value):
+ """add an attribute to an existing declarative class.
+
+ This runs through the logic to determine MapperProperty,
+ adds it to the Mapper, adds a column to the mapped Table, etc.
+
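+ E.g., assigning a new :class:`_schema.Column` to an already-mapped
+ declarative class (using the declarative metaclass) routes through
+ this function; a rough sketch with illustrative names::
+
+ class User(Base):
+ __tablename__ = "user"
+ id = Column(Integer, primary_key=True)
+
+ # later, after the class has been mapped:
+ User.nickname = Column(String(50))
+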
+ """
+
+ if "__mapper__" in cls.__dict__:
+ if isinstance(value, Column):
+ _undefer_column_name(key, value)
+ cls.__table__.append_column(value, replace_existing=True)
+ cls.__mapper__.add_property(key, value)
+ elif isinstance(value, ColumnProperty):
+ for col in value.columns:
+ if isinstance(col, Column) and col.table is None:
+ _undefer_column_name(key, col)
+ cls.__table__.append_column(col, replace_existing=True)
+ cls.__mapper__.add_property(key, value)
+ elif isinstance(value, MapperProperty):
+ cls.__mapper__.add_property(key, value)
+ elif isinstance(value, QueryableAttribute) and value.key != key:
+ # detect a QueryableAttribute that's already mapped being
+ # assigned elsewhere in userland, turn into a synonym()
+ value = SynonymProperty(value.key)
+ cls.__mapper__.add_property(key, value)
+ else:
+ type.__setattr__(cls, key, value)
+ cls.__mapper__._expire_memoizations()
+ else:
+ type.__setattr__(cls, key, value)
+
+
+def _del_attribute(cls, key):
+
+ if (
+ "__mapper__" in cls.__dict__
+ and key in cls.__dict__
+ and not cls.__mapper__._dispose_called
+ ):
+ value = cls.__dict__[key]
+ if isinstance(
+ value, (Column, ColumnProperty, MapperProperty, QueryableAttribute)
+ ):
+ raise NotImplementedError(
+ "Can't un-map individual mapped attributes on a mapped class."
+ )
+ else:
+ type.__delattr__(cls, key)
+ cls.__mapper__._expire_memoizations()
+ else:
+ type.__delattr__(cls, key)
+
+
+def _declarative_constructor(self, **kwargs):
+ """A simple constructor that allows initialization from kwargs.
+
+ Sets attributes on the constructed instance using the names and
+ values in ``kwargs``.
+
+ Only keys that are present as
+ attributes of the instance's class are allowed. These could be,
+ for example, any mapped columns or relationships.
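+
+ E.g., roughly, for a mapped ``User`` class with ``name`` and
+ ``fullname`` attributes (illustrative names)::
+
+ u1 = User(name="spongebob", fullname="Spongebob Squarepants")
+
+ # an unknown keyword raises TypeError:
+ # User(no_such_attribute="x")
+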
+ """
+ cls_ = type(self)
+ for k in kwargs:
+ if not hasattr(cls_, k):
+ raise TypeError(
+ "%r is an invalid keyword argument for %s" % (k, cls_.__name__)
+ )
+ setattr(self, k, kwargs[k])
+
+
+_declarative_constructor.__name__ = "__init__"
+
+
+def _undefer_column_name(key, column):
+ if column.key is None:
+ column.key = key
+ if column.name is None:
+ column.name = key
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
new file mode 100644
index 0000000..1b5be9a
--- /dev/null
+++ b/lib/sqlalchemy/orm/dependency.py
@@ -0,0 +1,1290 @@
+# orm/dependency.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Relationship dependencies.
+
+"""
+
+from . import attributes
+from . import exc
+from . import sync
+from . import unitofwork
+from . import util as mapperutil
+from .interfaces import MANYTOMANY
+from .interfaces import MANYTOONE
+from .interfaces import ONETOMANY
+from .. import exc as sa_exc
+from .. import sql
+from .. import util
+
+
+class DependencyProcessor(object):
+ def __init__(self, prop):
+ self.prop = prop
+ self.cascade = prop.cascade
+ self.mapper = prop.mapper
+ self.parent = prop.parent
+ self.secondary = prop.secondary
+ self.direction = prop.direction
+ self.post_update = prop.post_update
+ self.passive_deletes = prop.passive_deletes
+ self.passive_updates = prop.passive_updates
+ self.enable_typechecks = prop.enable_typechecks
+ if self.passive_deletes:
+ self._passive_delete_flag = attributes.PASSIVE_NO_INITIALIZE
+ else:
+ self._passive_delete_flag = attributes.PASSIVE_OFF
+ if self.passive_updates:
+ self._passive_update_flag = attributes.PASSIVE_NO_INITIALIZE
+ else:
+ self._passive_update_flag = attributes.PASSIVE_OFF
+
+ self.sort_key = "%s_%s" % (self.parent._sort_key, prop.key)
+ self.key = prop.key
+ if not self.prop.synchronize_pairs:
+ raise sa_exc.ArgumentError(
+ "Can't build a DependencyProcessor for relationship %s. "
+ "No target attributes to populate between parent and "
+ "child are present" % self.prop
+ )
+
+ @classmethod
+ def from_relationship(cls, prop):
+ return _direction_to_processor[prop.direction](prop)
+
+ def hasparent(self, state):
+ """return True if the given object instance has a parent,
+ according to the ``InstrumentedAttribute`` handled by this
+ ``DependencyProcessor``.
+
+ """
+ return self.parent.class_manager.get_impl(self.key).hasparent(state)
+
+ def per_property_preprocessors(self, uow):
+ """establish actions and dependencies related to a flush.
+
+ These actions will operate on all relevant states in
+ the aggregate.
+
+ """
+ uow.register_preprocessor(self, True)
+
+ def per_property_flush_actions(self, uow):
+ after_save = unitofwork.ProcessAll(uow, self, False, True)
+ before_delete = unitofwork.ProcessAll(uow, self, True, True)
+
+ parent_saves = unitofwork.SaveUpdateAll(
+ uow, self.parent.primary_base_mapper
+ )
+ child_saves = unitofwork.SaveUpdateAll(
+ uow, self.mapper.primary_base_mapper
+ )
+
+ parent_deletes = unitofwork.DeleteAll(
+ uow, self.parent.primary_base_mapper
+ )
+ child_deletes = unitofwork.DeleteAll(
+ uow, self.mapper.primary_base_mapper
+ )
+
+ self.per_property_dependencies(
+ uow,
+ parent_saves,
+ child_saves,
+ parent_deletes,
+ child_deletes,
+ after_save,
+ before_delete,
+ )
+
+ def per_state_flush_actions(self, uow, states, isdelete):
+ """establish actions and dependencies related to a flush.
+
+ These actions will operate on all relevant states
+ individually. This occurs only if there are cycles
+ in the 'aggregated' version of events.
+
+ """
+
+ child_base_mapper = self.mapper.primary_base_mapper
+ child_saves = unitofwork.SaveUpdateAll(uow, child_base_mapper)
+ child_deletes = unitofwork.DeleteAll(uow, child_base_mapper)
+
+ # locate and disable the aggregate processors
+ # for this dependency
+
+ if isdelete:
+ before_delete = unitofwork.ProcessAll(uow, self, True, True)
+ before_delete.disabled = True
+ else:
+ after_save = unitofwork.ProcessAll(uow, self, False, True)
+ after_save.disabled = True
+
+ # check if the "child" side is part of the cycle
+
+ if child_saves not in uow.cycles:
+ # based on the current dependencies we use, the saves/
+ # deletes should always be in the 'cycles' collection
+ # together. if this changes, we will have to break up
+ # this method a bit more.
+ assert child_deletes not in uow.cycles
+
+ # child side is not part of the cycle, so we will link per-state
+ # actions to the aggregate "saves", "deletes" actions
+ child_actions = [(child_saves, False), (child_deletes, True)]
+ child_in_cycles = False
+ else:
+ child_in_cycles = True
+
+ # check if the "parent" side is part of the cycle
+ if not isdelete:
+ parent_saves = unitofwork.SaveUpdateAll(
+ uow, self.parent.base_mapper
+ )
+ parent_deletes = before_delete = None
+ if parent_saves in uow.cycles:
+ parent_in_cycles = True
+ else:
+ parent_deletes = unitofwork.DeleteAll(uow, self.parent.base_mapper)
+ parent_saves = after_save = None
+ if parent_deletes in uow.cycles:
+ parent_in_cycles = True
+
+ # now create actions / dependencies for each state.
+
+ for state in states:
+ # detect if there's anything changed or loaded
+ # by a preprocessor on this state/attribute. In the
+ # case of deletes we may try to load missing items here as well.
+ sum_ = state.manager[self.key].impl.get_all_pending(
+ state,
+ state.dict,
+ self._passive_delete_flag
+ if isdelete
+ else attributes.PASSIVE_NO_INITIALIZE,
+ )
+
+ if not sum_:
+ continue
+
+ if isdelete:
+ before_delete = unitofwork.ProcessState(uow, self, True, state)
+ if parent_in_cycles:
+ parent_deletes = unitofwork.DeleteState(uow, state)
+ else:
+ after_save = unitofwork.ProcessState(uow, self, False, state)
+ if parent_in_cycles:
+ parent_saves = unitofwork.SaveUpdateState(uow, state)
+
+ if child_in_cycles:
+ child_actions = []
+ for child_state, child in sum_:
+ if child_state not in uow.states:
+ child_action = (None, None)
+ else:
+ (deleted, listonly) = uow.states[child_state]
+ if deleted:
+ child_action = (
+ unitofwork.DeleteState(uow, child_state),
+ True,
+ )
+ else:
+ child_action = (
+ unitofwork.SaveUpdateState(uow, child_state),
+ False,
+ )
+ child_actions.append(child_action)
+
+ # establish dependencies between our possibly per-state
+ # parent action and our possibly per-state child action.
+ for child_action, childisdelete in child_actions:
+ self.per_state_dependencies(
+ uow,
+ parent_saves,
+ parent_deletes,
+ child_action,
+ after_save,
+ before_delete,
+ isdelete,
+ childisdelete,
+ )
+
+ def presort_deletes(self, uowcommit, states):
+ return False
+
+ def presort_saves(self, uowcommit, states):
+ return False
+
+ def process_deletes(self, uowcommit, states):
+ pass
+
+ def process_saves(self, uowcommit, states):
+ pass
+
+ def prop_has_changes(self, uowcommit, states, isdelete):
+ if not isdelete or self.passive_deletes:
+ passive = attributes.PASSIVE_NO_INITIALIZE
+ elif self.direction is MANYTOONE:
+ # here, we were hoping to optimize having to fetch many-to-one
+ # for history and ignore it, if there are no further cascades
+ # to take place. however there are too many less common conditions
+ # that still take place and tests in test_relationships /
+ # test_cascade etc. will still fail.
+ passive = attributes.PASSIVE_NO_FETCH_RELATED
+ else:
+ passive = attributes.PASSIVE_OFF
+
+ for s in states:
+ # TODO: add a high speed method
+ # to InstanceState which returns: attribute
+ # has a non-None value, or had one
+ history = uowcommit.get_attribute_history(s, self.key, passive)
+ if history and not history.empty():
+ return True
+ else:
+ return (
+ states
+ and not self.prop._is_self_referential
+ and self.mapper in uowcommit.mappers
+ )
+
+ def _verify_canload(self, state):
+ if self.prop.uselist and state is None:
+ raise exc.FlushError(
+ "Can't flush None value found in "
+ "collection %s" % (self.prop,)
+ )
+ elif state is not None and not self.mapper._canload(
+ state, allow_subtypes=not self.enable_typechecks
+ ):
+ if self.mapper._canload(state, allow_subtypes=True):
+ raise exc.FlushError(
+ "Attempting to flush an item of type "
+ "%(x)s as a member of collection "
+ '"%(y)s". Expected an object of type '
+ "%(z)s or a polymorphic subclass of "
+ "this type. If %(x)s is a subclass of "
+ '%(z)s, configure mapper "%(zm)s" to '
+ "load this subtype polymorphically, or "
+ "set enable_typechecks=False to allow "
+ "any subtype to be accepted for flush. "
+ % {
+ "x": state.class_,
+ "y": self.prop,
+ "z": self.mapper.class_,
+ "zm": self.mapper,
+ }
+ )
+ else:
+ raise exc.FlushError(
+ "Attempting to flush an item of type "
+ "%(x)s as a member of collection "
+ '"%(y)s". Expected an object of type '
+ "%(z)s or a polymorphic subclass of "
+ "this type."
+ % {
+ "x": state.class_,
+ "y": self.prop,
+ "z": self.mapper.class_,
+ }
+ )
+
+ def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
+ raise NotImplementedError()
+
+ def _get_reversed_processed_set(self, uow):
+ if not self.prop._reverse_property:
+ return None
+
+ process_key = tuple(
+ sorted([self.key] + [p.key for p in self.prop._reverse_property])
+ )
+ return uow.memo(("reverse_key", process_key), set)
+
+ def _post_update(self, state, uowcommit, related, is_m2o_delete=False):
+ for x in related:
+ if not is_m2o_delete or x is not None:
+ uowcommit.register_post_update(
+ state, [r for l, r in self.prop.synchronize_pairs]
+ )
+ break
+
+ def _pks_changed(self, uowcommit, state):
+ raise NotImplementedError()
+
+ def __repr__(self):
+ return "%s(%s)" % (self.__class__.__name__, self.prop)
+
+
+class OneToManyDP(DependencyProcessor):
+ def per_property_dependencies(
+ self,
+ uow,
+ parent_saves,
+ child_saves,
+ parent_deletes,
+ child_deletes,
+ after_save,
+ before_delete,
+ ):
+ if self.post_update:
+ child_post_updates = unitofwork.PostUpdateAll(
+ uow, self.mapper.primary_base_mapper, False
+ )
+ child_pre_updates = unitofwork.PostUpdateAll(
+ uow, self.mapper.primary_base_mapper, True
+ )
+
+ uow.dependencies.update(
+ [
+ (child_saves, after_save),
+ (parent_saves, after_save),
+ (after_save, child_post_updates),
+ (before_delete, child_pre_updates),
+ (child_pre_updates, parent_deletes),
+ (child_pre_updates, child_deletes),
+ ]
+ )
+ else:
+ uow.dependencies.update(
+ [
+ (parent_saves, after_save),
+ (after_save, child_saves),
+ (after_save, child_deletes),
+ (child_saves, parent_deletes),
+ (child_deletes, parent_deletes),
+ (before_delete, child_saves),
+ (before_delete, child_deletes),
+ ]
+ )
+
+ def per_state_dependencies(
+ self,
+ uow,
+ save_parent,
+ delete_parent,
+ child_action,
+ after_save,
+ before_delete,
+ isdelete,
+ childisdelete,
+ ):
+
+ if self.post_update:
+
+ child_post_updates = unitofwork.PostUpdateAll(
+ uow, self.mapper.primary_base_mapper, False
+ )
+ child_pre_updates = unitofwork.PostUpdateAll(
+ uow, self.mapper.primary_base_mapper, True
+ )
+
+ # TODO: this whole block is not covered
+ # by any tests
+ if not isdelete:
+ if childisdelete:
+ uow.dependencies.update(
+ [
+ (child_action, after_save),
+ (after_save, child_post_updates),
+ ]
+ )
+ else:
+ uow.dependencies.update(
+ [
+ (save_parent, after_save),
+ (child_action, after_save),
+ (after_save, child_post_updates),
+ ]
+ )
+ else:
+ if childisdelete:
+ uow.dependencies.update(
+ [
+ (before_delete, child_pre_updates),
+ (child_pre_updates, delete_parent),
+ ]
+ )
+ else:
+ uow.dependencies.update(
+ [
+ (before_delete, child_pre_updates),
+ (child_pre_updates, delete_parent),
+ ]
+ )
+ elif not isdelete:
+ uow.dependencies.update(
+ [
+ (save_parent, after_save),
+ (after_save, child_action),
+ (save_parent, child_action),
+ ]
+ )
+ else:
+ uow.dependencies.update(
+ [(before_delete, child_action), (child_action, delete_parent)]
+ )
+
+ def presort_deletes(self, uowcommit, states):
+ # head object is being deleted, and we manage its list of
+ # child objects; the child objects have to have their
+ # foreign key to the parent set to NULL
+ should_null_fks = (
+ not self.cascade.delete and not self.passive_deletes == "all"
+ )
+
+ for state in states:
+ history = uowcommit.get_attribute_history(
+ state, self.key, self._passive_delete_flag
+ )
+ if history:
+ for child in history.deleted:
+ if child is not None and self.hasparent(child) is False:
+ if self.cascade.delete_orphan:
+ uowcommit.register_object(child, isdelete=True)
+ else:
+ uowcommit.register_object(child)
+
+ if should_null_fks:
+ for child in history.unchanged:
+ if child is not None:
+ uowcommit.register_object(
+ child, operation="delete", prop=self.prop
+ )
+
+ def presort_saves(self, uowcommit, states):
+ children_added = uowcommit.memo(("children_added", self), set)
+
+ should_null_fks = (
+ not self.cascade.delete_orphan
+ and not self.passive_deletes == "all"
+ )
+
+ for state in states:
+ pks_changed = self._pks_changed(uowcommit, state)
+
+ if not pks_changed or self.passive_updates:
+ passive = attributes.PASSIVE_NO_INITIALIZE
+ else:
+ passive = attributes.PASSIVE_OFF
+
+ history = uowcommit.get_attribute_history(state, self.key, passive)
+ if history:
+ for child in history.added:
+ if child is not None:
+ uowcommit.register_object(
+ child,
+ cancel_delete=True,
+ operation="add",
+ prop=self.prop,
+ )
+
+ children_added.update(history.added)
+
+ for child in history.deleted:
+ if not self.cascade.delete_orphan:
+ if should_null_fks:
+ uowcommit.register_object(
+ child,
+ isdelete=False,
+ operation="delete",
+ prop=self.prop,
+ )
+ elif self.hasparent(child) is False:
+ uowcommit.register_object(
+ child,
+ isdelete=True,
+ operation="delete",
+ prop=self.prop,
+ )
+ for c, m, st_, dct_ in self.mapper.cascade_iterator(
+ "delete", child
+ ):
+ uowcommit.register_object(st_, isdelete=True)
+
+ if pks_changed:
+ if history:
+ for child in history.unchanged:
+ if child is not None:
+ uowcommit.register_object(
+ child,
+ False,
+ self.passive_updates,
+ operation="pk change",
+ prop=self.prop,
+ )
+
+ def process_deletes(self, uowcommit, states):
+ # head object is being deleted, and we manage its list of
+ # child objects; the child objects have to have their foreign
+ # key to the parent set to NULL. this phase can be called
+ # safely for any cascade but is unnecessary if delete cascade
+ # is on.
+
+ if self.post_update or not self.passive_deletes == "all":
+ children_added = uowcommit.memo(("children_added", self), set)
+
+ for state in states:
+ history = uowcommit.get_attribute_history(
+ state, self.key, self._passive_delete_flag
+ )
+ if history:
+ for child in history.deleted:
+ if (
+ child is not None
+ and self.hasparent(child) is False
+ ):
+ self._synchronize(
+ state, child, None, True, uowcommit, False
+ )
+ if self.post_update and child:
+ self._post_update(child, uowcommit, [state])
+
+ if self.post_update or not self.cascade.delete:
+ for child in set(history.unchanged).difference(
+ children_added
+ ):
+ if child is not None:
+ self._synchronize(
+ state, child, None, True, uowcommit, False
+ )
+ if self.post_update and child:
+ self._post_update(
+ child, uowcommit, [state]
+ )
+
+ # technically, we can even remove each child from the
+ # collection here too. but this would be a somewhat
+ # inconsistent behavior since it wouldn't happen
+ # if the old parent wasn't deleted but the child was moved.
+
+ def process_saves(self, uowcommit, states):
+ should_null_fks = (
+ not self.cascade.delete_orphan
+ and not self.passive_deletes == "all"
+ )
+
+ for state in states:
+ history = uowcommit.get_attribute_history(
+ state, self.key, attributes.PASSIVE_NO_INITIALIZE
+ )
+ if history:
+ for child in history.added:
+ self._synchronize(
+ state, child, None, False, uowcommit, False
+ )
+ if child is not None and self.post_update:
+ self._post_update(child, uowcommit, [state])
+
+ for child in history.deleted:
+ if (
+ should_null_fks
+ and not self.cascade.delete_orphan
+ and not self.hasparent(child)
+ ):
+ self._synchronize(
+ state, child, None, True, uowcommit, False
+ )
+
+ if self._pks_changed(uowcommit, state):
+ for child in history.unchanged:
+ self._synchronize(
+ state, child, None, False, uowcommit, True
+ )
+
+ def _synchronize(
+ self, state, child, associationrow, clearkeys, uowcommit, pks_changed
+ ):
+ source = state
+ dest = child
+ self._verify_canload(child)
+ if dest is None or (
+ not self.post_update and uowcommit.is_deleted(dest)
+ ):
+ return
+ if clearkeys:
+ sync.clear(dest, self.mapper, self.prop.synchronize_pairs)
+ else:
+ sync.populate(
+ source,
+ self.parent,
+ dest,
+ self.mapper,
+ self.prop.synchronize_pairs,
+ uowcommit,
+ self.passive_updates and pks_changed,
+ )
+
+ def _pks_changed(self, uowcommit, state):
+ return sync.source_modified(
+ uowcommit, state, self.parent, self.prop.synchronize_pairs
+ )
+
+
+class ManyToOneDP(DependencyProcessor):
+ def __init__(self, prop):
+ DependencyProcessor.__init__(self, prop)
+ for mapper in self.mapper.self_and_descendants:
+ mapper._dependency_processors.append(DetectKeySwitch(prop))
+
+ def per_property_dependencies(
+ self,
+ uow,
+ parent_saves,
+ child_saves,
+ parent_deletes,
+ child_deletes,
+ after_save,
+ before_delete,
+ ):
+
+ if self.post_update:
+ parent_post_updates = unitofwork.PostUpdateAll(
+ uow, self.parent.primary_base_mapper, False
+ )
+ parent_pre_updates = unitofwork.PostUpdateAll(
+ uow, self.parent.primary_base_mapper, True
+ )
+
+ uow.dependencies.update(
+ [
+ (child_saves, after_save),
+ (parent_saves, after_save),
+ (after_save, parent_post_updates),
+ (after_save, parent_pre_updates),
+ (before_delete, parent_pre_updates),
+ (parent_pre_updates, child_deletes),
+ (parent_pre_updates, parent_deletes),
+ ]
+ )
+ else:
+ uow.dependencies.update(
+ [
+ (child_saves, after_save),
+ (after_save, parent_saves),
+ (parent_saves, child_deletes),
+ (parent_deletes, child_deletes),
+ ]
+ )
+
+ def per_state_dependencies(
+ self,
+ uow,
+ save_parent,
+ delete_parent,
+ child_action,
+ after_save,
+ before_delete,
+ isdelete,
+ childisdelete,
+ ):
+
+ if self.post_update:
+
+ if not isdelete:
+ parent_post_updates = unitofwork.PostUpdateAll(
+ uow, self.parent.primary_base_mapper, False
+ )
+ if childisdelete:
+ uow.dependencies.update(
+ [
+ (after_save, parent_post_updates),
+ (parent_post_updates, child_action),
+ ]
+ )
+ else:
+ uow.dependencies.update(
+ [
+ (save_parent, after_save),
+ (child_action, after_save),
+ (after_save, parent_post_updates),
+ ]
+ )
+ else:
+ parent_pre_updates = unitofwork.PostUpdateAll(
+ uow, self.parent.primary_base_mapper, True
+ )
+
+ uow.dependencies.update(
+ [
+ (before_delete, parent_pre_updates),
+ (parent_pre_updates, delete_parent),
+ (parent_pre_updates, child_action),
+ ]
+ )
+
+ elif not isdelete:
+ if not childisdelete:
+ uow.dependencies.update(
+ [(child_action, after_save), (after_save, save_parent)]
+ )
+ else:
+ uow.dependencies.update([(after_save, save_parent)])
+
+ else:
+ if childisdelete:
+ uow.dependencies.update([(delete_parent, child_action)])
+
+ def presort_deletes(self, uowcommit, states):
+ if self.cascade.delete or self.cascade.delete_orphan:
+ for state in states:
+ history = uowcommit.get_attribute_history(
+ state, self.key, self._passive_delete_flag
+ )
+ if history:
+ if self.cascade.delete_orphan:
+ todelete = history.sum()
+ else:
+ todelete = history.non_deleted()
+ for child in todelete:
+ if child is None:
+ continue
+ uowcommit.register_object(
+ child,
+ isdelete=True,
+ operation="delete",
+ prop=self.prop,
+ )
+ t = self.mapper.cascade_iterator("delete", child)
+ for c, m, st_, dct_ in t:
+ uowcommit.register_object(st_, isdelete=True)
+
+ def presort_saves(self, uowcommit, states):
+ for state in states:
+ uowcommit.register_object(state, operation="add", prop=self.prop)
+ if self.cascade.delete_orphan:
+ history = uowcommit.get_attribute_history(
+ state, self.key, self._passive_delete_flag
+ )
+ if history:
+ for child in history.deleted:
+ if self.hasparent(child) is False:
+ uowcommit.register_object(
+ child,
+ isdelete=True,
+ operation="delete",
+ prop=self.prop,
+ )
+
+ t = self.mapper.cascade_iterator("delete", child)
+ for c, m, st_, dct_ in t:
+ uowcommit.register_object(st_, isdelete=True)
+
+ def process_deletes(self, uowcommit, states):
+ if (
+ self.post_update
+ and not self.cascade.delete_orphan
+ and not self.passive_deletes == "all"
+ ):
+
+ # post_update means we have to update our
+ # row to not reference the child object
+ # before we can DELETE the row
+ for state in states:
+ self._synchronize(state, None, None, True, uowcommit)
+ if state and self.post_update:
+ history = uowcommit.get_attribute_history(
+ state, self.key, self._passive_delete_flag
+ )
+ if history:
+ self._post_update(
+ state, uowcommit, history.sum(), is_m2o_delete=True
+ )
+
+ def process_saves(self, uowcommit, states):
+ for state in states:
+ history = uowcommit.get_attribute_history(
+ state, self.key, attributes.PASSIVE_NO_INITIALIZE
+ )
+ if history:
+ if history.added:
+ for child in history.added:
+ self._synchronize(
+ state, child, None, False, uowcommit, "add"
+ )
+ elif history.deleted:
+ self._synchronize(
+ state, None, None, True, uowcommit, "delete"
+ )
+ if self.post_update:
+ self._post_update(state, uowcommit, history.sum())
+
+ def _synchronize(
+ self,
+ state,
+ child,
+ associationrow,
+ clearkeys,
+ uowcommit,
+ operation=None,
+ ):
+ if state is None or (
+ not self.post_update and uowcommit.is_deleted(state)
+ ):
+ return
+
+ if (
+ operation is not None
+ and child is not None
+ and not uowcommit.session._contains_state(child)
+ ):
+ util.warn(
+ "Object of type %s not in session, %s "
+ "operation along '%s' won't proceed"
+ % (mapperutil.state_class_str(child), operation, self.prop)
+ )
+ return
+
+ if clearkeys or child is None:
+ sync.clear(state, self.parent, self.prop.synchronize_pairs)
+ else:
+ self._verify_canload(child)
+ sync.populate(
+ child,
+ self.mapper,
+ state,
+ self.parent,
+ self.prop.synchronize_pairs,
+ uowcommit,
+ False,
+ )
+
+
+class DetectKeySwitch(DependencyProcessor):
+ """For many-to-one relationships with no one-to-many backref,
+ searches for parents through the unit of work when a primary
+ key has changed and updates them.
+
+ Theoretically, this approach could be expanded to support transparent
+ deletion of objects referenced via many-to-one as well, although
+ the current attribute system doesn't do enough bookkeeping for this
+ to be efficient.
+
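+ As a rough illustration (class and column names are hypothetical),
+ given a many-to-one with no reverse relationship::
+
+ class Child(Base):
+ __tablename__ = "child"
+ id = Column(Integer, primary_key=True)
+ parent_id = Column(ForeignKey("parent.id"))
+ parent = relationship("Parent")
+
+ a change to the primary key of a ``Parent`` object being flushed is
+ located here and synchronized onto the ``parent_id`` attribute of
+ ``Child`` objects present in the :class:`.Session`.
+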
+ """
+
+ def per_property_preprocessors(self, uow):
+ if self.prop._reverse_property:
+ if self.passive_updates:
+ return
+ else:
+ if False in (
+ prop.passive_updates
+ for prop in self.prop._reverse_property
+ ):
+ return
+
+ uow.register_preprocessor(self, False)
+
+ def per_property_flush_actions(self, uow):
+ parent_saves = unitofwork.SaveUpdateAll(uow, self.parent.base_mapper)
+ after_save = unitofwork.ProcessAll(uow, self, False, False)
+ uow.dependencies.update([(parent_saves, after_save)])
+
+ def per_state_flush_actions(self, uow, states, isdelete):
+ pass
+
+ def presort_deletes(self, uowcommit, states):
+ pass
+
+ def presort_saves(self, uow, states):
+ if not self.passive_updates:
+ # for non-passive updates, register in the preprocess stage
+ # so that mapper save_obj() gets a hold of changes
+ self._process_key_switches(states, uow)
+
+ def prop_has_changes(self, uow, states, isdelete):
+ if not isdelete and self.passive_updates:
+ d = self._key_switchers(uow, states)
+ return bool(d)
+
+ return False
+
+ def process_deletes(self, uowcommit, states):
+ assert False
+
+ def process_saves(self, uowcommit, states):
+ # for passive updates, register objects in the process stage
+ # so that we avoid ManyToOneDP's registering the object without
+ # the listonly flag in its own preprocess stage (which results
+ # in UPDATE statements being emitted)
+ assert self.passive_updates
+ self._process_key_switches(states, uowcommit)
+
+ def _key_switchers(self, uow, states):
+ switched, notswitched = uow.memo(
+ ("pk_switchers", self), lambda: (set(), set())
+ )
+
+ allstates = switched.union(notswitched)
+ for s in states:
+ if s not in allstates:
+ if self._pks_changed(uow, s):
+ switched.add(s)
+ else:
+ notswitched.add(s)
+ return switched
+
+ def _process_key_switches(self, deplist, uowcommit):
+ switchers = self._key_switchers(uowcommit, deplist)
+ if switchers:
+ # if primary key values have actually changed somewhere, perform
+ # a linear search through the UOW in search of a parent.
+ for state in uowcommit.session.identity_map.all_states():
+ if not issubclass(state.class_, self.parent.class_):
+ continue
+ dict_ = state.dict
+ related = state.get_impl(self.key).get(
+ state, dict_, passive=self._passive_update_flag
+ )
+ if (
+ related is not attributes.PASSIVE_NO_RESULT
+ and related is not None
+ ):
+ if self.prop.uselist:
+ if not related:
+ continue
+ related_obj = related[0]
+ else:
+ related_obj = related
+ related_state = attributes.instance_state(related_obj)
+ if related_state in switchers:
+ uowcommit.register_object(
+ state, False, self.passive_updates
+ )
+ sync.populate(
+ related_state,
+ self.mapper,
+ state,
+ self.parent,
+ self.prop.synchronize_pairs,
+ uowcommit,
+ self.passive_updates,
+ )
+
+ def _pks_changed(self, uowcommit, state):
+ return bool(state.key) and sync.source_modified(
+ uowcommit, state, self.mapper, self.prop.synchronize_pairs
+ )
+
+
+class ManyToManyDP(DependencyProcessor):
+ def per_property_dependencies(
+ self,
+ uow,
+ parent_saves,
+ child_saves,
+ parent_deletes,
+ child_deletes,
+ after_save,
+ before_delete,
+ ):
+
+ uow.dependencies.update(
+ [
+ (parent_saves, after_save),
+ (child_saves, after_save),
+ (after_save, child_deletes),
+ # a rowswitch on the parent from deleted to saved
+ # can make this one occur, as the "save" may remove
+ # an element from the
+ # "deleted" list before we have a chance to
+ # process its child rows
+ (before_delete, parent_saves),
+ (before_delete, parent_deletes),
+ (before_delete, child_deletes),
+ (before_delete, child_saves),
+ ]
+ )
+
+ def per_state_dependencies(
+ self,
+ uow,
+ save_parent,
+ delete_parent,
+ child_action,
+ after_save,
+ before_delete,
+ isdelete,
+ childisdelete,
+ ):
+ if not isdelete:
+ if childisdelete:
+ uow.dependencies.update(
+ [(save_parent, after_save), (after_save, child_action)]
+ )
+ else:
+ uow.dependencies.update(
+ [(save_parent, after_save), (child_action, after_save)]
+ )
+ else:
+ uow.dependencies.update(
+ [(before_delete, child_action), (before_delete, delete_parent)]
+ )
+
+ def presort_deletes(self, uowcommit, states):
+ # TODO: no tests fail if this whole
+ # thing is removed !!!!
+ if not self.passive_deletes:
+ # if no passive deletes, load history on
+ # the collection, so that prop_has_changes()
+ # returns True
+ for state in states:
+ uowcommit.get_attribute_history(
+ state, self.key, self._passive_delete_flag
+ )
+
+ def presort_saves(self, uowcommit, states):
+ if not self.passive_updates:
+ # if no passive updates, load history on
+ # each collection where parent has changed PK,
+ # so that prop_has_changes() returns True
+ for state in states:
+ if self._pks_changed(uowcommit, state):
+ history = uowcommit.get_attribute_history(
+ state, self.key, attributes.PASSIVE_OFF
+ )
+
+ if not self.cascade.delete_orphan:
+ return
+
+ # check for child items removed from the collection
+ # if delete_orphan check is turned on.
+ for state in states:
+ history = uowcommit.get_attribute_history(
+ state, self.key, attributes.PASSIVE_NO_INITIALIZE
+ )
+ if history:
+ for child in history.deleted:
+ if self.hasparent(child) is False:
+ uowcommit.register_object(
+ child,
+ isdelete=True,
+ operation="delete",
+ prop=self.prop,
+ )
+ for c, m, st_, dct_ in self.mapper.cascade_iterator(
+ "delete", child
+ ):
+ uowcommit.register_object(st_, isdelete=True)
+
+ def process_deletes(self, uowcommit, states):
+ secondary_delete = []
+ secondary_insert = []
+ secondary_update = []
+
+ processed = self._get_reversed_processed_set(uowcommit)
+ tmp = set()
+ for state in states:
+ # this history should be cached already, as
+            # we loaded it in presort_deletes
+ history = uowcommit.get_attribute_history(
+ state, self.key, self._passive_delete_flag
+ )
+ if history:
+ for child in history.non_added():
+ if child is None or (
+ processed is not None and (state, child) in processed
+ ):
+ continue
+ associationrow = {}
+ if not self._synchronize(
+ state,
+ child,
+ associationrow,
+ False,
+ uowcommit,
+ "delete",
+ ):
+ continue
+ secondary_delete.append(associationrow)
+
+ tmp.update((c, state) for c in history.non_added())
+
+ if processed is not None:
+ processed.update(tmp)
+
+ self._run_crud(
+ uowcommit, secondary_insert, secondary_update, secondary_delete
+ )
+
+ def process_saves(self, uowcommit, states):
+ secondary_delete = []
+ secondary_insert = []
+ secondary_update = []
+
+ processed = self._get_reversed_processed_set(uowcommit)
+ tmp = set()
+
+ for state in states:
+ need_cascade_pks = not self.passive_updates and self._pks_changed(
+ uowcommit, state
+ )
+ if need_cascade_pks:
+ passive = attributes.PASSIVE_OFF
+ else:
+ passive = attributes.PASSIVE_NO_INITIALIZE
+ history = uowcommit.get_attribute_history(state, self.key, passive)
+ if history:
+ for child in history.added:
+ if processed is not None and (state, child) in processed:
+ continue
+ associationrow = {}
+ if not self._synchronize(
+ state, child, associationrow, False, uowcommit, "add"
+ ):
+ continue
+ secondary_insert.append(associationrow)
+ for child in history.deleted:
+ if processed is not None and (state, child) in processed:
+ continue
+ associationrow = {}
+ if not self._synchronize(
+ state,
+ child,
+ associationrow,
+ False,
+ uowcommit,
+ "delete",
+ ):
+ continue
+ secondary_delete.append(associationrow)
+
+ tmp.update((c, state) for c in history.added + history.deleted)
+
+ if need_cascade_pks:
+
+ for child in history.unchanged:
+ associationrow = {}
+ sync.update(
+ state,
+ self.parent,
+ associationrow,
+ "old_",
+ self.prop.synchronize_pairs,
+ )
+ sync.update(
+ child,
+ self.mapper,
+ associationrow,
+ "old_",
+ self.prop.secondary_synchronize_pairs,
+ )
+
+ secondary_update.append(associationrow)
+
+ if processed is not None:
+ processed.update(tmp)
+
+ self._run_crud(
+ uowcommit, secondary_insert, secondary_update, secondary_delete
+ )
+
+ def _run_crud(
+ self, uowcommit, secondary_insert, secondary_update, secondary_delete
+ ):
+ connection = uowcommit.transaction.connection(self.mapper)
+
+ if secondary_delete:
+ associationrow = secondary_delete[0]
+ statement = self.secondary.delete().where(
+ sql.and_(
+ *[
+ c == sql.bindparam(c.key, type_=c.type)
+ for c in self.secondary.c
+ if c.key in associationrow
+ ]
+ )
+ )
+ result = connection.execute(statement, secondary_delete)
+
+ if (
+ result.supports_sane_multi_rowcount()
+ ) and result.rowcount != len(secondary_delete):
+ raise exc.StaleDataError(
+ "DELETE statement on table '%s' expected to delete "
+ "%d row(s); Only %d were matched."
+ % (
+ self.secondary.description,
+ len(secondary_delete),
+ result.rowcount,
+ )
+ )
+
+ if secondary_update:
+ associationrow = secondary_update[0]
+ statement = self.secondary.update().where(
+ sql.and_(
+ *[
+ c == sql.bindparam("old_" + c.key, type_=c.type)
+ for c in self.secondary.c
+ if c.key in associationrow
+ ]
+ )
+ )
+ result = connection.execute(statement, secondary_update)
+
+ if (
+ result.supports_sane_multi_rowcount()
+ ) and result.rowcount != len(secondary_update):
+ raise exc.StaleDataError(
+ "UPDATE statement on table '%s' expected to update "
+ "%d row(s); Only %d were matched."
+ % (
+ self.secondary.description,
+ len(secondary_update),
+ result.rowcount,
+ )
+ )
+
+ if secondary_insert:
+ statement = self.secondary.insert()
+ connection.execute(statement, secondary_insert)
+
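+    # As a minimal sketch, for a hypothetical association table
+    # ``user_keyword(user_id, keyword_id)``, the DELETE emitted above is
+    # roughly equivalent to::
+    #
+    #     stmt = user_keyword.delete().where(
+    #         and_(
+    #             user_keyword.c.user_id == bindparam("user_id"),
+    #             user_keyword.c.keyword_id == bindparam("keyword_id"),
+    #         )
+    #     )
+    #     connection.execute(
+    #         stmt,
+    #         [{"user_id": 1, "keyword_id": 2}, {"user_id": 1, "keyword_id": 3}],
+    #     )
+    #
+    # i.e. an executemany() against the secondary table; StaleDataError is
+    # raised when the backend reports a reliable rowcount that disagrees
+    # with the number of association rows expected to be matched.
+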
+ def _synchronize(
+ self, state, child, associationrow, clearkeys, uowcommit, operation
+ ):
+
+ # this checks for None if uselist=True
+ self._verify_canload(child)
+
+ # but if uselist=False we get here. If child is None,
+ # no association row can be generated, so return.
+ if child is None:
+ return False
+
+ if child is not None and not uowcommit.session._contains_state(child):
+ if not child.deleted:
+ util.warn(
+ "Object of type %s not in session, %s "
+ "operation along '%s' won't proceed"
+ % (mapperutil.state_class_str(child), operation, self.prop)
+ )
+ return False
+
+ sync.populate_dict(
+ state, self.parent, associationrow, self.prop.synchronize_pairs
+ )
+ sync.populate_dict(
+ child,
+ self.mapper,
+ associationrow,
+ self.prop.secondary_synchronize_pairs,
+ )
+
+ return True
+
+ def _pks_changed(self, uowcommit, state):
+ return sync.source_modified(
+ uowcommit, state, self.parent, self.prop.synchronize_pairs
+ )
+
+
+_direction_to_processor = {
+ ONETOMANY: OneToManyDP,
+ MANYTOONE: ManyToOneDP,
+ MANYTOMANY: ManyToManyDP,
+}
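+
+# As a minimal sketch, the mapping above is used to select a dependency
+# processor class from a relationship's direction, roughly::
+#
+#     processor_cls = _direction_to_processor[prop.direction]
+#     processor = processor_cls(prop)
+#
+# so one-to-many relationships are handled by OneToManyDP, many-to-one by
+# ManyToOneDP, and many-to-many (those using a "secondary" table) by
+# ManyToManyDP.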
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
new file mode 100644
index 0000000..3d7f23b
--- /dev/null
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -0,0 +1,745 @@
+# orm/descriptor_props.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Descriptor properties are more "auxiliary" properties
+that exist as configurational elements, but don't participate
+as actively in the load/persist ORM loop.
+
+"""
+
+from . import attributes
+from . import util as orm_util
+from .interfaces import MapperProperty
+from .interfaces import PropComparator
+from .util import _none_set
+from .. import event
+from .. import exc as sa_exc
+from .. import schema
+from .. import sql
+from .. import util
+from ..sql import expression
+from ..sql import operators
+
+
+class DescriptorProperty(MapperProperty):
+ """:class:`.MapperProperty` which proxies access to a
+ user-defined descriptor."""
+
+ doc = None
+
+ uses_objects = False
+ _links_to_entity = False
+
+ def instrument_class(self, mapper):
+ prop = self
+
+ class _ProxyImpl(object):
+ accepts_scalar_loader = False
+ load_on_unexpire = True
+ collection = False
+
+ @property
+ def uses_objects(self):
+ return prop.uses_objects
+
+ def __init__(self, key):
+ self.key = key
+
+ if hasattr(prop, "get_history"):
+
+ def get_history(
+ self, state, dict_, passive=attributes.PASSIVE_OFF
+ ):
+ return prop.get_history(state, dict_, passive)
+
+ if self.descriptor is None:
+ desc = getattr(mapper.class_, self.key, None)
+ if mapper._is_userland_descriptor(self.key, desc):
+ self.descriptor = desc
+
+ if self.descriptor is None:
+
+ def fset(obj, value):
+ setattr(obj, self.name, value)
+
+ def fdel(obj):
+ delattr(obj, self.name)
+
+ def fget(obj):
+ return getattr(obj, self.name)
+
+ self.descriptor = property(fget=fget, fset=fset, fdel=fdel)
+
+ proxy_attr = attributes.create_proxied_attribute(self.descriptor)(
+ self.parent.class_,
+ self.key,
+ self.descriptor,
+ lambda: self._comparator_factory(mapper),
+ doc=self.doc,
+ original_property=self,
+ )
+ proxy_attr.impl = _ProxyImpl(self.key)
+ mapper.class_manager.instrument_attribute(self.key, proxy_attr)
+
+
+class CompositeProperty(DescriptorProperty):
+ """Defines a "composite" mapped attribute, representing a collection
+ of columns as one attribute.
+
+ :class:`.CompositeProperty` is constructed using the :func:`.composite`
+ function.
+
+ .. seealso::
+
+ :ref:`mapper_composite`
+
+ """
+
+ def __init__(self, class_, *attrs, **kwargs):
+ r"""Return a composite column-based property for use with a Mapper.
+
+ See the mapping documentation section :ref:`mapper_composite` for a
+ full usage example.
+
+ The :class:`.MapperProperty` returned by :func:`.composite`
+ is the :class:`.CompositeProperty`.
+
+ :param class\_:
+ The "composite type" class, or any classmethod or callable which
+ will produce a new instance of the composite object given the
+ column values in order.
+
+ :param \*cols:
+ List of Column objects to be mapped.
+
+ :param active_history=False:
+ When ``True``, indicates that the "previous" value for a
+ scalar attribute should be loaded when replaced, if not
+ already loaded. See the same flag on :func:`.column_property`.
+
+ :param group:
+ A group name for this property when marked as deferred.
+
+ :param deferred:
+ When True, the column property is "deferred", meaning that it does
+ not load immediately, and is instead loaded when the attribute is
+ first accessed on an instance. See also
+ :func:`~sqlalchemy.orm.deferred`.
+
+ :param comparator_factory: a class which extends
+ :class:`.CompositeProperty.Comparator` which provides custom SQL
+ clause generation for comparison operations.
+
+ :param doc:
+ optional string that will be applied as the doc on the
+ class-bound descriptor.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.MapperProperty.info` attribute of this object.
+
+ """
+ super(CompositeProperty, self).__init__()
+
+ self.attrs = attrs
+ self.composite_class = class_
+ self.active_history = kwargs.get("active_history", False)
+ self.deferred = kwargs.get("deferred", False)
+ self.group = kwargs.get("group", None)
+ self.comparator_factory = kwargs.pop(
+ "comparator_factory", self.__class__.Comparator
+ )
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
+
+ util.set_creation_order(self)
+ self._create_descriptor()
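+
+    # A minimal sketch of a composite mapping, assuming a hypothetical
+    # ``Point`` value class, a ``Vertex`` class and a ``vertices`` table
+    # with x1/y1/x2/y2 columns (classical ``mapper()`` style)::
+    #
+    #     class Point(object):
+    #         def __init__(self, x, y):
+    #             self.x = x
+    #             self.y = y
+    #
+    #         def __composite_values__(self):
+    #             return self.x, self.y
+    #
+    #         def __eq__(self, other):
+    #             return (
+    #                 isinstance(other, Point)
+    #                 and other.x == self.x
+    #                 and other.y == self.y
+    #             )
+    #
+    #     mapper(Vertex, vertices, properties={
+    #         "start": composite(Point, vertices.c.x1, vertices.c.y1),
+    #         "end": composite(Point, vertices.c.x2, vertices.c.y2),
+    #     })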
+
+ def instrument_class(self, mapper):
+ super(CompositeProperty, self).instrument_class(mapper)
+ self._setup_event_handlers()
+
+ def do_init(self):
+ """Initialization which occurs after the :class:`.CompositeProperty`
+ has been associated with its parent mapper.
+
+ """
+ self._setup_arguments_on_columns()
+
+ _COMPOSITE_FGET = object()
+
+ def _create_descriptor(self):
+ """Create the Python descriptor that will serve as
+ the access point on instances of the mapped class.
+
+ """
+
+ def fget(instance):
+ dict_ = attributes.instance_dict(instance)
+ state = attributes.instance_state(instance)
+
+ if self.key not in dict_:
+ # key not present. Iterate through related
+ # attributes, retrieve their values. This
+ # ensures they all load.
+ values = [
+ getattr(instance, key) for key in self._attribute_keys
+ ]
+
+ # current expected behavior here is that the composite is
+ # created on access if the object is persistent or if
+                # col attributes have non-None values. This would be better
+ # if the composite were created unconditionally,
+ # but that would be a behavioral change.
+ if self.key not in dict_ and (
+ state.key is not None or not _none_set.issuperset(values)
+ ):
+ dict_[self.key] = self.composite_class(*values)
+ state.manager.dispatch.refresh(
+ state, self._COMPOSITE_FGET, [self.key]
+ )
+
+ return dict_.get(self.key, None)
+
+ def fset(instance, value):
+ dict_ = attributes.instance_dict(instance)
+ state = attributes.instance_state(instance)
+ attr = state.manager[self.key]
+ previous = dict_.get(self.key, attributes.NO_VALUE)
+ for fn in attr.dispatch.set:
+ value = fn(state, value, previous, attr.impl)
+ dict_[self.key] = value
+ if value is None:
+ for key in self._attribute_keys:
+ setattr(instance, key, None)
+ else:
+ for key, value in zip(
+ self._attribute_keys, value.__composite_values__()
+ ):
+ setattr(instance, key, value)
+
+ def fdel(instance):
+ state = attributes.instance_state(instance)
+ dict_ = attributes.instance_dict(instance)
+ previous = dict_.pop(self.key, attributes.NO_VALUE)
+ attr = state.manager[self.key]
+ attr.dispatch.remove(state, previous, attr.impl)
+ for key in self._attribute_keys:
+ setattr(instance, key, None)
+
+ self.descriptor = property(fget, fset, fdel)
+
+ @util.memoized_property
+ def _comparable_elements(self):
+ return [getattr(self.parent.class_, prop.key) for prop in self.props]
+
+ @util.memoized_property
+ def props(self):
+ props = []
+ for attr in self.attrs:
+ if isinstance(attr, str):
+ prop = self.parent.get_property(attr, _configure_mappers=False)
+ elif isinstance(attr, schema.Column):
+ prop = self.parent._columntoproperty[attr]
+ elif isinstance(attr, attributes.InstrumentedAttribute):
+ prop = attr.property
+ else:
+ raise sa_exc.ArgumentError(
+ "Composite expects Column objects or mapped "
+ "attributes/attribute names as arguments, got: %r"
+ % (attr,)
+ )
+ props.append(prop)
+ return props
+
+ @property
+ def columns(self):
+ return [a for a in self.attrs if isinstance(a, schema.Column)]
+
+ def _setup_arguments_on_columns(self):
+ """Propagate configuration arguments made on this composite
+ to the target columns, for those that apply.
+
+ """
+ for prop in self.props:
+ prop.active_history = self.active_history
+ if self.deferred:
+ prop.deferred = self.deferred
+ prop.strategy_key = (("deferred", True), ("instrument", True))
+ prop.group = self.group
+
+ def _setup_event_handlers(self):
+ """Establish events that populate/expire the composite attribute."""
+
+ def load_handler(state, context):
+ _load_refresh_handler(state, context, None, is_refresh=False)
+
+ def refresh_handler(state, context, to_load):
+ # note this corresponds to sqlalchemy.ext.mutable load_attrs()
+
+ if not to_load or (
+ {self.key}.union(self._attribute_keys)
+ ).intersection(to_load):
+ _load_refresh_handler(state, context, to_load, is_refresh=True)
+
+ def _load_refresh_handler(state, context, to_load, is_refresh):
+ dict_ = state.dict
+
+ # if context indicates we are coming from the
+ # fget() handler, this already set the value; skip the
+ # handler here. (other handlers like mutablecomposite will still
+ # want to catch it)
+ # there's an insufficiency here in that the fget() handler
+ # really should not be using the refresh event and there should
+ # be some other event that mutablecomposite can subscribe
+            # to for this.
+
+ if (
+ not is_refresh or context is self._COMPOSITE_FGET
+ ) and self.key in dict_:
+ return
+
+ # if column elements aren't loaded, skip.
+ # __get__() will initiate a load for those
+ # columns
+ for k in self._attribute_keys:
+ if k not in dict_:
+ return
+
+ dict_[self.key] = self.composite_class(
+ *[state.dict[key] for key in self._attribute_keys]
+ )
+
+ def expire_handler(state, keys):
+ if keys is None or set(self._attribute_keys).intersection(keys):
+ state.dict.pop(self.key, None)
+
+ def insert_update_handler(mapper, connection, state):
+ """After an insert or update, some columns may be expired due
+ to server side defaults, or re-populated due to client side
+ defaults. Pop out the composite value here so that it
+ recreates.
+
+ """
+
+ state.dict.pop(self.key, None)
+
+ event.listen(
+ self.parent, "after_insert", insert_update_handler, raw=True
+ )
+ event.listen(
+ self.parent, "after_update", insert_update_handler, raw=True
+ )
+ event.listen(
+ self.parent, "load", load_handler, raw=True, propagate=True
+ )
+ event.listen(
+ self.parent, "refresh", refresh_handler, raw=True, propagate=True
+ )
+ event.listen(
+ self.parent, "expire", expire_handler, raw=True, propagate=True
+ )
+
+ # TODO: need a deserialize hook here
+
+ @util.memoized_property
+ def _attribute_keys(self):
+ return [prop.key for prop in self.props]
+
+ def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
+ """Provided for userland code that uses attributes.get_history()."""
+
+ added = []
+ deleted = []
+
+ has_history = False
+ for prop in self.props:
+ key = prop.key
+ hist = state.manager[key].impl.get_history(state, dict_)
+ if hist.has_changes():
+ has_history = True
+
+ non_deleted = hist.non_deleted()
+ if non_deleted:
+ added.extend(non_deleted)
+ else:
+ added.append(None)
+ if hist.deleted:
+ deleted.extend(hist.deleted)
+ else:
+ deleted.append(None)
+
+ if has_history:
+ return attributes.History(
+ [self.composite_class(*added)],
+ (),
+ [self.composite_class(*deleted)],
+ )
+ else:
+ return attributes.History((), [self.composite_class(*added)], ())
+
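+    # A minimal sketch, reusing the hypothetical Vertex/Point mapping::
+    #
+    #     from sqlalchemy.orm.attributes import get_history
+    #
+    #     v1.start = Point(5, 6)
+    #     get_history(v1, "start")
+    #     # -> History with the new Point in .added and the
+    #     #    replaced Point in .deleted
+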
+ def _comparator_factory(self, mapper):
+ return self.comparator_factory(self, mapper)
+
+ class CompositeBundle(orm_util.Bundle):
+ def __init__(self, property_, expr):
+ self.property = property_
+ super(CompositeProperty.CompositeBundle, self).__init__(
+ property_.key, *expr
+ )
+
+ def create_row_processor(self, query, procs, labels):
+ def proc(row):
+ return self.property.composite_class(
+ *[proc(row) for proc in procs]
+ )
+
+ return proc
+
+ class Comparator(PropComparator):
+ """Produce boolean, comparison, and other operators for
+ :class:`.CompositeProperty` attributes.
+
+ See the example in :ref:`composite_operations` for an overview
+        of usage, as well as the documentation for :class:`.PropComparator`.
+
+ .. seealso::
+
+ :class:`.PropComparator`
+
+ :class:`.ColumnOperators`
+
+ :ref:`types_operators`
+
+ :attr:`.TypeEngine.comparator_factory`
+
+ """
+
+ __hash__ = None
+
+ @util.memoized_property
+ def clauses(self):
+ return expression.ClauseList(
+ group=False, *self._comparable_elements
+ )
+
+ def __clause_element__(self):
+ return self.expression
+
+ @util.memoized_property
+ def expression(self):
+ clauses = self.clauses._annotate(
+ {
+ "parententity": self._parententity,
+ "parentmapper": self._parententity,
+ "proxy_key": self.prop.key,
+ }
+ )
+ return CompositeProperty.CompositeBundle(self.prop, clauses)
+
+ def _bulk_update_tuples(self, value):
+ if isinstance(value, sql.elements.BindParameter):
+ value = value.value
+
+ if value is None:
+ values = [None for key in self.prop._attribute_keys]
+ elif isinstance(value, self.prop.composite_class):
+ values = value.__composite_values__()
+ else:
+ raise sa_exc.ArgumentError(
+ "Can't UPDATE composite attribute %s to %r"
+ % (self.prop, value)
+ )
+
+ return zip(self._comparable_elements, values)
+
+ @util.memoized_property
+ def _comparable_elements(self):
+ if self._adapt_to_entity:
+ return [
+ getattr(self._adapt_to_entity.entity, prop.key)
+ for prop in self.prop._comparable_elements
+ ]
+ else:
+ return self.prop._comparable_elements
+
+ def __eq__(self, other):
+ if other is None:
+ values = [None] * len(self.prop._comparable_elements)
+ else:
+ values = other.__composite_values__()
+ comparisons = [
+ a == b for a, b in zip(self.prop._comparable_elements, values)
+ ]
+ if self._adapt_to_entity:
+ comparisons = [self.adapter(x) for x in comparisons]
+ return sql.and_(*comparisons)
+
+ def __ne__(self, other):
+ return sql.not_(self.__eq__(other))
+
+ def __str__(self):
+ return str(self.parent.class_.__name__) + "." + self.key
+
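+# A minimal sketch, reusing the hypothetical Vertex/Point mapping above:
+# comparison and bulk UPDATE of a composite expand to the underlying
+# columns, roughly::
+#
+#     session.query(Vertex).filter(Vertex.start == Point(3, 4))
+#     # ... WHERE vertices.x1 = :x1_1 AND vertices.y1 = :y1_1
+#
+#     session.query(Vertex).update({Vertex.start: Point(5, 6)})
+#     # ... SET x1=:x1, y1=:y1  (via _bulk_update_tuples() above)
+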
+
+class ConcreteInheritedProperty(DescriptorProperty):
+ """A 'do nothing' :class:`.MapperProperty` that disables
+ an attribute on a concrete subclass that is only present
+ on the inherited mapper, not the concrete classes' mapper.
+
+ Cases where this occurs include:
+
+ * When the superclass mapper is mapped against a
+ "polymorphic union", which includes all attributes from
+ all subclasses.
+ * When a relationship() is configured on an inherited mapper,
+ but not on the subclass mapper. Concrete mappers require
+ that relationship() is configured explicitly on each
+ subclass.
+
+ """
+
+ def _comparator_factory(self, mapper):
+ comparator_callable = None
+
+ for m in self.parent.iterate_to_root():
+ p = m._props[self.key]
+ if not isinstance(p, ConcreteInheritedProperty):
+ comparator_callable = p.comparator_factory
+ break
+ return comparator_callable
+
+ def __init__(self):
+ super(ConcreteInheritedProperty, self).__init__()
+
+ def warn():
+ raise AttributeError(
+ "Concrete %s does not implement "
+ "attribute %r at the instance level. Add "
+ "this property explicitly to %s."
+ % (self.parent, self.key, self.parent)
+ )
+
+ class NoninheritedConcreteProp(object):
+ def __set__(s, obj, value):
+ warn()
+
+ def __delete__(s, obj):
+ warn()
+
+ def __get__(s, obj, owner):
+ if obj is None:
+ return self.descriptor
+ warn()
+
+ self.descriptor = NoninheritedConcreteProp()
+
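+# A minimal sketch, assuming hypothetical concrete classes Employee and
+# Manager: instance-level access to an attribute mapped only on the base
+# raises the AttributeError constructed above::
+#
+#     manager = Manager(name="m1")   # concrete subclass of Employee
+#     manager.company                # relationship() configured on Employee
+#                                    # only -> AttributeError suggesting the
+#                                    # property be added to Manager explicitly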
+
+class SynonymProperty(DescriptorProperty):
+ def __init__(
+ self,
+ name,
+ map_column=None,
+ descriptor=None,
+ comparator_factory=None,
+ doc=None,
+ info=None,
+ ):
+ """Denote an attribute name as a synonym to a mapped property,
+ in that the attribute will mirror the value and expression behavior
+ of another attribute.
+
+ e.g.::
+
+ class MyClass(Base):
+ __tablename__ = 'my_table'
+
+ id = Column(Integer, primary_key=True)
+ job_status = Column(String(50))
+
+ status = synonym("job_status")
+
+
+ :param name: the name of the existing mapped property. This
+      can refer to the string name of an ORM-mapped attribute
+ configured on the class, including column-bound attributes
+ and relationships.
+
+ :param descriptor: a Python :term:`descriptor` that will be used
+ as a getter (and potentially a setter) when this attribute is
+ accessed at the instance level.
+
+ :param map_column: **For classical mappings and mappings against
+      an existing Table object only**. If ``True``, the :func:`.synonym`
+ construct will locate the :class:`_schema.Column`
+ object upon the mapped
+ table that would normally be associated with the attribute name of
+ this synonym, and produce a new :class:`.ColumnProperty` that instead
+ maps this :class:`_schema.Column`
+ to the alternate name given as the "name"
+ argument of the synonym; in this way, the usual step of redefining
+ the mapping of the :class:`_schema.Column`
+ to be under a different name is
+ unnecessary. This is usually intended to be used when a
+ :class:`_schema.Column`
+ is to be replaced with an attribute that also uses a
+ descriptor, that is, in conjunction with the
+ :paramref:`.synonym.descriptor` parameter::
+
+ my_table = Table(
+ "my_table", metadata,
+ Column('id', Integer, primary_key=True),
+ Column('job_status', String(50))
+ )
+
+ class MyClass(object):
+ @property
+ def _job_status_descriptor(self):
+ return "Status: %s" % self._job_status
+
+
+ mapper(
+ MyClass, my_table, properties={
+ "job_status": synonym(
+ "_job_status", map_column=True,
+ descriptor=MyClass._job_status_descriptor)
+ }
+ )
+
+ Above, the attribute named ``_job_status`` is automatically
+ mapped to the ``job_status`` column::
+
+ >>> j1 = MyClass()
+ >>> j1._job_status = "employed"
+ >>> j1.job_status
+ Status: employed
+
+ When using Declarative, in order to provide a descriptor in
+ conjunction with a synonym, use the
+ :func:`sqlalchemy.ext.declarative.synonym_for` helper. However,
+ note that the :ref:`hybrid properties <mapper_hybrids>` feature
+ should usually be preferred, particularly when redefining attribute
+ behavior.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.InspectionAttr.info` attribute of this object.
+
+ .. versionadded:: 1.0.0
+
+ :param comparator_factory: A subclass of :class:`.PropComparator`
+ that will provide custom comparison behavior at the SQL expression
+ level.
+
+ .. note::
+
+ For the use case of providing an attribute which redefines both
+ Python-level and SQL-expression level behavior of an attribute,
+ please refer to the Hybrid attribute introduced at
+ :ref:`mapper_hybrids` for a more effective technique.
+
+ .. seealso::
+
+ :ref:`synonyms` - Overview of synonyms
+
+ :func:`.synonym_for` - a helper oriented towards Declarative
+
+ :ref:`mapper_hybrids` - The Hybrid Attribute extension provides an
+ updated approach to augmenting attribute behavior more flexibly
+ than can be achieved with synonyms.
+
+ """
+ super(SynonymProperty, self).__init__()
+
+ self.name = name
+ self.map_column = map_column
+ self.descriptor = descriptor
+ self.comparator_factory = comparator_factory
+ self.doc = doc or (descriptor and descriptor.__doc__) or None
+ if info:
+ self.info = info
+
+ util.set_creation_order(self)
+
+ @property
+ def uses_objects(self):
+ return getattr(self.parent.class_, self.name).impl.uses_objects
+
+ # TODO: when initialized, check _proxied_object,
+ # emit a warning if its not a column-based property
+
+ @util.memoized_property
+ def _proxied_object(self):
+ attr = getattr(self.parent.class_, self.name)
+ if not hasattr(attr, "property") or not isinstance(
+ attr.property, MapperProperty
+ ):
+            # attribute is a non-MapperProperty proxy such as
+ # hybrid or association proxy
+ if isinstance(attr, attributes.QueryableAttribute):
+ return attr.comparator
+ elif isinstance(attr, operators.ColumnOperators):
+ return attr
+
+ raise sa_exc.InvalidRequestError(
+ """synonym() attribute "%s.%s" only supports """
+ """ORM mapped attributes, got %r"""
+ % (self.parent.class_.__name__, self.name, attr)
+ )
+ return attr.property
+
+ def _comparator_factory(self, mapper):
+ prop = self._proxied_object
+
+ if isinstance(prop, MapperProperty):
+ if self.comparator_factory:
+ comp = self.comparator_factory(prop, mapper)
+ else:
+ comp = prop.comparator_factory(prop, mapper)
+ return comp
+ else:
+ return prop
+
+ def get_history(self, *arg, **kw):
+ attr = getattr(self.parent.class_, self.name)
+ return attr.impl.get_history(*arg, **kw)
+
+ @util.preload_module("sqlalchemy.orm.properties")
+ def set_parent(self, parent, init):
+ properties = util.preloaded.orm_properties
+
+ if self.map_column:
+ # implement the 'map_column' option.
+ if self.key not in parent.persist_selectable.c:
+ raise sa_exc.ArgumentError(
+ "Can't compile synonym '%s': no column on table "
+ "'%s' named '%s'"
+ % (
+ self.name,
+ parent.persist_selectable.description,
+ self.key,
+ )
+ )
+ elif (
+ parent.persist_selectable.c[self.key]
+ in parent._columntoproperty
+ and parent._columntoproperty[
+ parent.persist_selectable.c[self.key]
+ ].key
+ == self.name
+ ):
+ raise sa_exc.ArgumentError(
+ "Can't call map_column=True for synonym %r=%r, "
+ "a ColumnProperty already exists keyed to the name "
+ "%r for column %r"
+ % (self.key, self.name, self.name, self.key)
+ )
+ p = properties.ColumnProperty(
+ parent.persist_selectable.c[self.key]
+ )
+ parent._configure_property(self.name, p, init=init, setparent=True)
+ p._mapped_by_synonym = self.key
+
+ self.parent = parent
diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py
new file mode 100644
index 0000000..ec62560
--- /dev/null
+++ b/lib/sqlalchemy/orm/dynamic.py
@@ -0,0 +1,491 @@
+# orm/dynamic.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Dynamic collection API.
+
+Dynamic collections act like Query() objects for read operations and support
+basic add/delete mutation.
+
+"""
+
+from . import attributes
+from . import exc as orm_exc
+from . import interfaces
+from . import object_mapper
+from . import object_session
+from . import relationships
+from . import strategies
+from . import util as orm_util
+from .query import Query
+from .. import exc
+from .. import log
+from .. import util
+from ..engine import result
+
+
+@log.class_logger
+@relationships.RelationshipProperty.strategy_for(lazy="dynamic")
+class DynaLoader(strategies.AbstractRelationshipLoader):
+ def init_class_attribute(self, mapper):
+ self.is_class_level = True
+ if not self.uselist:
+ raise exc.InvalidRequestError(
+ "On relationship %s, 'dynamic' loaders cannot be used with "
+ "many-to-one/one-to-one relationships and/or "
+ "uselist=False." % self.parent_property
+ )
+ elif self.parent_property.direction not in (
+ interfaces.ONETOMANY,
+ interfaces.MANYTOMANY,
+ ):
+ util.warn(
+ "On relationship %s, 'dynamic' loaders cannot be used with "
+ "many-to-one/one-to-one relationships and/or "
+ "uselist=False. This warning will be an exception in a "
+ "future release." % self.parent_property
+ )
+
+ strategies._register_attribute(
+ self.parent_property,
+ mapper,
+ useobject=True,
+ impl_class=DynamicAttributeImpl,
+ target_mapper=self.parent_property.mapper,
+ order_by=self.parent_property.order_by,
+ query_class=self.parent_property.query_class,
+ )
+
+
+class DynamicAttributeImpl(attributes.AttributeImpl):
+ uses_objects = True
+ default_accepts_scalar_loader = False
+ supports_population = False
+ collection = False
+ dynamic = True
+ order_by = ()
+
+ def __init__(
+ self,
+ class_,
+ key,
+ typecallable,
+ dispatch,
+ target_mapper,
+ order_by,
+ query_class=None,
+ **kw
+ ):
+ super(DynamicAttributeImpl, self).__init__(
+ class_, key, typecallable, dispatch, **kw
+ )
+ self.target_mapper = target_mapper
+ if order_by:
+ self.order_by = tuple(order_by)
+ if not query_class:
+ self.query_class = AppenderQuery
+ elif AppenderMixin in query_class.mro():
+ self.query_class = query_class
+ else:
+ self.query_class = mixin_user_query(query_class)
+
+ def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
+ if not passive & attributes.SQL_OK:
+ return self._get_collection_history(
+ state, attributes.PASSIVE_NO_INITIALIZE
+ ).added_items
+ else:
+ return self.query_class(self, state)
+
+ def get_collection(
+ self,
+ state,
+ dict_,
+ user_data=None,
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ ):
+ if not passive & attributes.SQL_OK:
+ data = self._get_collection_history(state, passive).added_items
+ else:
+ history = self._get_collection_history(state, passive)
+ data = history.added_plus_unchanged
+ return DynamicCollectionAdapter(data)
+
+ @util.memoized_property
+ def _append_token(self):
+ return attributes.Event(self, attributes.OP_APPEND)
+
+ @util.memoized_property
+ def _remove_token(self):
+ return attributes.Event(self, attributes.OP_REMOVE)
+
+ def fire_append_event(
+ self, state, dict_, value, initiator, collection_history=None
+ ):
+ if collection_history is None:
+ collection_history = self._modified_event(state, dict_)
+
+ collection_history.add_added(value)
+
+ for fn in self.dispatch.append:
+ value = fn(state, value, initiator or self._append_token)
+
+ if self.trackparent and value is not None:
+ self.sethasparent(attributes.instance_state(value), state, True)
+
+ def fire_remove_event(
+ self, state, dict_, value, initiator, collection_history=None
+ ):
+ if collection_history is None:
+ collection_history = self._modified_event(state, dict_)
+
+ collection_history.add_removed(value)
+
+ if self.trackparent and value is not None:
+ self.sethasparent(attributes.instance_state(value), state, False)
+
+ for fn in self.dispatch.remove:
+ fn(state, value, initiator or self._remove_token)
+
+ def _modified_event(self, state, dict_):
+
+ if self.key not in state.committed_state:
+ state.committed_state[self.key] = CollectionHistory(self, state)
+
+ state._modified_event(dict_, self, attributes.NEVER_SET)
+
+ # this is a hack to allow the fixtures.ComparableEntity fixture
+ # to work
+ dict_[self.key] = True
+ return state.committed_state[self.key]
+
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator=None,
+ passive=attributes.PASSIVE_OFF,
+ check_old=None,
+ pop=False,
+ _adapt=True,
+ ):
+ if initiator and initiator.parent_token is self.parent_token:
+ return
+
+ if pop and value is None:
+ return
+
+ iterable = value
+ new_values = list(iterable)
+ if state.has_identity:
+ old_collection = util.IdentitySet(self.get(state, dict_))
+
+ collection_history = self._modified_event(state, dict_)
+ if not state.has_identity:
+ old_collection = collection_history.added_items
+ else:
+ old_collection = old_collection.union(
+ collection_history.added_items
+ )
+
+ idset = util.IdentitySet
+ constants = old_collection.intersection(new_values)
+ additions = idset(new_values).difference(constants)
+ removals = old_collection.difference(constants)
+
+ for member in new_values:
+ if member in additions:
+ self.fire_append_event(
+ state,
+ dict_,
+ member,
+ None,
+ collection_history=collection_history,
+ )
+
+ for member in removals:
+ self.fire_remove_event(
+ state,
+ dict_,
+ member,
+ None,
+ collection_history=collection_history,
+ )
+
+ def delete(self, *args, **kwargs):
+ raise NotImplementedError()
+
+ def set_committed_value(self, state, dict_, value):
+ raise NotImplementedError(
+            "Dynamic attributes don't support collection population."
+ )
+
+ def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
+ c = self._get_collection_history(state, passive)
+ return c.as_history()
+
+ def get_all_pending(
+ self, state, dict_, passive=attributes.PASSIVE_NO_INITIALIZE
+ ):
+ c = self._get_collection_history(state, passive)
+ return [(attributes.instance_state(x), x) for x in c.all_items]
+
+ def _get_collection_history(self, state, passive=attributes.PASSIVE_OFF):
+ if self.key in state.committed_state:
+ c = state.committed_state[self.key]
+ else:
+ c = CollectionHistory(self, state)
+
+ if state.has_identity and (passive & attributes.INIT_OK):
+ return CollectionHistory(self, state, apply_to=c)
+ else:
+ return c
+
+ def append(
+ self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF
+ ):
+ if initiator is not self:
+ self.fire_append_event(state, dict_, value, initiator)
+
+ def remove(
+ self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF
+ ):
+ if initiator is not self:
+ self.fire_remove_event(state, dict_, value, initiator)
+
+ def pop(
+ self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF
+ ):
+ self.remove(state, dict_, value, initiator, passive=passive)
+
+
+class DynamicCollectionAdapter(object):
+ """simplified CollectionAdapter for internal API consistency"""
+
+ def __init__(self, data):
+ self.data = data
+
+ def __iter__(self):
+ return iter(self.data)
+
+ def _reset_empty(self):
+ pass
+
+ def __len__(self):
+ return len(self.data)
+
+ def __bool__(self):
+ return True
+
+ __nonzero__ = __bool__
+
+
+class AppenderMixin(object):
+ query_class = None
+
+ def __init__(self, attr, state):
+ super(AppenderMixin, self).__init__(attr.target_mapper, None)
+ self.instance = instance = state.obj()
+ self.attr = attr
+
+ mapper = object_mapper(instance)
+ prop = mapper._props[self.attr.key]
+
+ if prop.secondary is not None:
+ # this is a hack right now. The Query only knows how to
+ # make subsequent joins() without a given left-hand side
+ # from self._from_obj[0]. We need to ensure prop.secondary
+ # is in the FROM. So we purposely put the mapper selectable
+ # in _from_obj[0] to ensure a user-defined join() later on
+ # doesn't fail, and secondary is then in _from_obj[1].
+
+ # note also, we are using the official ORM-annotated selectable
+ # from __clause_element__(), see #7868
+ self._from_obj = (prop.mapper.__clause_element__(), prop.secondary)
+
+ self._where_criteria = (
+ prop._with_parent(instance, alias_secondary=False),
+ )
+
+ if self.attr.order_by:
+ self._order_by_clauses = self.attr.order_by
+
+ def session(self):
+ sess = object_session(self.instance)
+ if (
+ sess is not None
+ and self.autoflush
+ and sess.autoflush
+ and self.instance in sess
+ ):
+ sess.flush()
+ if not orm_util.has_identity(self.instance):
+ return None
+ else:
+ return sess
+
+ session = property(session, lambda s, x: None)
+
+ def _iter(self):
+ sess = self.session
+ if sess is None:
+ state = attributes.instance_state(self.instance)
+ if state.detached:
+ util.warn(
+ "Instance %s is detached, dynamic relationship cannot "
+ "return a correct result. This warning will become "
+ "a DetachedInstanceError in a future release."
+ % (orm_util.state_str(state))
+ )
+
+ return result.IteratorResult(
+ result.SimpleResultMetaData([self.attr.class_.__name__]),
+ self.attr._get_collection_history(
+ attributes.instance_state(self.instance),
+ attributes.PASSIVE_NO_INITIALIZE,
+ ).added_items,
+ _source_supports_scalars=True,
+ ).scalars()
+ else:
+ return self._generate(sess)._iter()
+
+ def __getitem__(self, index):
+ sess = self.session
+ if sess is None:
+ return self.attr._get_collection_history(
+ attributes.instance_state(self.instance),
+ attributes.PASSIVE_NO_INITIALIZE,
+ ).indexed(index)
+ else:
+ return self._generate(sess).__getitem__(index)
+
+ def count(self):
+ sess = self.session
+ if sess is None:
+ return len(
+ self.attr._get_collection_history(
+ attributes.instance_state(self.instance),
+ attributes.PASSIVE_NO_INITIALIZE,
+ ).added_items
+ )
+ else:
+ return self._generate(sess).count()
+
+ def _generate(self, sess=None):
+ # note we're returning an entirely new Query class instance
+ # here without any assignment capabilities; the class of this
+ # query is determined by the session.
+ instance = self.instance
+ if sess is None:
+ sess = object_session(instance)
+ if sess is None:
+ raise orm_exc.DetachedInstanceError(
+ "Parent instance %s is not bound to a Session, and no "
+ "contextual session is established; lazy load operation "
+ "of attribute '%s' cannot proceed"
+ % (orm_util.instance_str(instance), self.attr.key)
+ )
+
+ if self.query_class:
+ query = self.query_class(self.attr.target_mapper, session=sess)
+ else:
+ query = sess.query(self.attr.target_mapper)
+
+ query._where_criteria = self._where_criteria
+ query._from_obj = self._from_obj
+ query._order_by_clauses = self._order_by_clauses
+
+ return query
+
+ def extend(self, iterator):
+ for item in iterator:
+ self.attr.append(
+ attributes.instance_state(self.instance),
+ attributes.instance_dict(self.instance),
+ item,
+ None,
+ )
+
+ def append(self, item):
+ self.attr.append(
+ attributes.instance_state(self.instance),
+ attributes.instance_dict(self.instance),
+ item,
+ None,
+ )
+
+ def remove(self, item):
+ self.attr.remove(
+ attributes.instance_state(self.instance),
+ attributes.instance_dict(self.instance),
+ item,
+ None,
+ )
+
+
+class AppenderQuery(AppenderMixin, Query):
+ """A dynamic query that supports basic collection storage operations."""
+
+
+def mixin_user_query(cls):
+ """Return a new class with AppenderQuery functionality layered over."""
+ name = "Appender" + cls.__name__
+ return type(name, (AppenderMixin, cls), {"query_class": cls})
+
+
+class CollectionHistory(object):
+ """Overrides AttributeHistory to receive append/remove events directly."""
+
+ def __init__(self, attr, state, apply_to=None):
+ if apply_to:
+ coll = AppenderQuery(attr, state).autoflush(False)
+ self.unchanged_items = util.OrderedIdentitySet(coll)
+ self.added_items = apply_to.added_items
+ self.deleted_items = apply_to.deleted_items
+ self._reconcile_collection = True
+ else:
+ self.deleted_items = util.OrderedIdentitySet()
+ self.added_items = util.OrderedIdentitySet()
+ self.unchanged_items = util.OrderedIdentitySet()
+ self._reconcile_collection = False
+
+ @property
+ def added_plus_unchanged(self):
+ return list(self.added_items.union(self.unchanged_items))
+
+ @property
+ def all_items(self):
+ return list(
+ self.added_items.union(self.unchanged_items).union(
+ self.deleted_items
+ )
+ )
+
+ def as_history(self):
+ if self._reconcile_collection:
+ added = self.added_items.difference(self.unchanged_items)
+ deleted = self.deleted_items.intersection(self.unchanged_items)
+ unchanged = self.unchanged_items.difference(deleted)
+ else:
+ added, unchanged, deleted = (
+ self.added_items,
+ self.unchanged_items,
+ self.deleted_items,
+ )
+ return attributes.History(list(added), list(unchanged), list(deleted))
+
+ def indexed(self, index):
+ return list(self.added_items)[index]
+
+ def add_added(self, value):
+ self.added_items.add(value)
+
+ def add_removed(self, value):
+ if value in self.added_items:
+ self.added_items.remove(value)
+ else:
+ self.deleted_items.add(value)
diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py
new file mode 100644
index 0000000..dbbfba0
--- /dev/null
+++ b/lib/sqlalchemy/orm/evaluator.py
@@ -0,0 +1,241 @@
+# orm/evaluator.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import operator
+
+from .. import inspect
+from .. import util
+from ..sql import and_
+from ..sql import operators
+
+
+class UnevaluatableError(Exception):
+ pass
+
+
+class _NoObject(operators.ColumnOperators):
+ def operate(self, *arg, **kw):
+ return None
+
+ def reverse_operate(self, *arg, **kw):
+ return None
+
+
+_NO_OBJECT = _NoObject()
+
+_straight_ops = set(
+ getattr(operators, op)
+ for op in (
+ "add",
+ "mul",
+ "sub",
+ "div",
+ "mod",
+ "truediv",
+ "lt",
+ "le",
+ "ne",
+ "gt",
+ "ge",
+ "eq",
+ )
+)
+
+_extended_ops = {
+ operators.in_op: (lambda a, b: a in b if a is not _NO_OBJECT else None),
+ operators.not_in_op: (
+ lambda a, b: a not in b if a is not _NO_OBJECT else None
+ ),
+}
+
+_notimplemented_ops = set(
+ getattr(operators, op)
+ for op in (
+ "like_op",
+ "not_like_op",
+ "ilike_op",
+ "not_ilike_op",
+ "startswith_op",
+ "between_op",
+ "endswith_op",
+ "concat_op",
+ )
+)
+
+
+class EvaluatorCompiler(object):
+ def __init__(self, target_cls=None):
+ self.target_cls = target_cls
+
+ def process(self, *clauses):
+ if len(clauses) > 1:
+ clause = and_(*clauses)
+ elif clauses:
+ clause = clauses[0]
+
+ meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
+ if not meth:
+ raise UnevaluatableError(
+ "Cannot evaluate %s" % type(clause).__name__
+ )
+ return meth(clause)
+
+ def visit_grouping(self, clause):
+ return self.process(clause.element)
+
+ def visit_null(self, clause):
+ return lambda obj: None
+
+ def visit_false(self, clause):
+ return lambda obj: False
+
+ def visit_true(self, clause):
+ return lambda obj: True
+
+ def visit_column(self, clause):
+ if "parentmapper" in clause._annotations:
+ parentmapper = clause._annotations["parentmapper"]
+ if self.target_cls and not issubclass(
+ self.target_cls, parentmapper.class_
+ ):
+ raise UnevaluatableError(
+ "Can't evaluate criteria against alternate class %s"
+ % parentmapper.class_
+ )
+ key = parentmapper._columntoproperty[clause].key
+ else:
+ key = clause.key
+ if (
+ self.target_cls
+ and key in inspect(self.target_cls).column_attrs
+ ):
+ util.warn(
+ "Evaluating non-mapped column expression '%s' onto "
+ "ORM instances; this is a deprecated use case. Please "
+ "make use of the actual mapped columns in ORM-evaluated "
+ "UPDATE / DELETE expressions." % clause
+ )
+ else:
+ raise UnevaluatableError("Cannot evaluate column: %s" % clause)
+
+ get_corresponding_attr = operator.attrgetter(key)
+ return (
+ lambda obj: get_corresponding_attr(obj)
+ if obj is not None
+ else _NO_OBJECT
+ )
+
+ def visit_tuple(self, clause):
+ return self.visit_clauselist(clause)
+
+ def visit_clauselist(self, clause):
+ evaluators = list(map(self.process, clause.clauses))
+ if clause.operator is operators.or_:
+
+ def evaluate(obj):
+ has_null = False
+ for sub_evaluate in evaluators:
+ value = sub_evaluate(obj)
+ if value:
+ return True
+ has_null = has_null or value is None
+ if has_null:
+ return None
+ return False
+
+ elif clause.operator is operators.and_:
+
+ def evaluate(obj):
+ for sub_evaluate in evaluators:
+ value = sub_evaluate(obj)
+ if not value:
+ if value is None or value is _NO_OBJECT:
+ return None
+ return False
+ return True
+
+ elif clause.operator is operators.comma_op:
+
+ def evaluate(obj):
+ values = []
+ for sub_evaluate in evaluators:
+ value = sub_evaluate(obj)
+ if value is None or value is _NO_OBJECT:
+ return None
+ values.append(value)
+ return tuple(values)
+
+ else:
+ raise UnevaluatableError(
+ "Cannot evaluate clauselist with operator %s" % clause.operator
+ )
+
+ return evaluate
+
+ def visit_binary(self, clause):
+ eval_left, eval_right = list(
+ map(self.process, [clause.left, clause.right])
+ )
+ operator = clause.operator
+ if operator is operators.is_:
+
+ def evaluate(obj):
+ return eval_left(obj) == eval_right(obj)
+
+ elif operator is operators.is_not:
+
+ def evaluate(obj):
+ return eval_left(obj) != eval_right(obj)
+
+ elif operator in _extended_ops:
+
+ def evaluate(obj):
+ left_val = eval_left(obj)
+ right_val = eval_right(obj)
+ if left_val is None or right_val is None:
+ return None
+
+ return _extended_ops[operator](left_val, right_val)
+
+ elif operator in _straight_ops:
+
+ def evaluate(obj):
+ left_val = eval_left(obj)
+ right_val = eval_right(obj)
+ if left_val is None or right_val is None:
+ return None
+                return operator(left_val, right_val)
+
+ else:
+ raise UnevaluatableError(
+ "Cannot evaluate %s with operator %s"
+ % (type(clause).__name__, clause.operator)
+ )
+ return evaluate
+
+ def visit_unary(self, clause):
+ eval_inner = self.process(clause.element)
+ if clause.operator is operators.inv:
+
+ def evaluate(obj):
+ value = eval_inner(obj)
+ if value is None:
+ return None
+ return not value
+
+ return evaluate
+ raise UnevaluatableError(
+ "Cannot evaluate %s with operator %s"
+ % (type(clause).__name__, clause.operator)
+ )
+
+ def visit_bindparam(self, clause):
+ if clause.callable:
+ val = clause.callable()
+ else:
+ val = clause.value
+ return lambda obj: val
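+
+
+# A minimal sketch, assuming a hypothetical mapped class ``User`` with a
+# ``name`` column: the compiler turns ORM criteria into plain Python
+# callables (as used by the "evaluate" synchronization strategy of bulk
+# UPDATE / DELETE)::
+#
+#     compiler = EvaluatorCompiler(User)
+#     is_ed = compiler.process(User.name == "ed")
+#
+#     is_ed(some_user)   # True / False, or None for SQL-NULL semantics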
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
new file mode 100644
index 0000000..39659c7
--- /dev/null
+++ b/lib/sqlalchemy/orm/events.py
@@ -0,0 +1,2876 @@
+# orm/events.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""ORM event interfaces.
+
+"""
+import weakref
+
+from . import instrumentation
+from . import interfaces
+from . import mapperlib
+from .attributes import QueryableAttribute
+from .base import _mapper_or_none
+from .query import Query
+from .scoping import scoped_session
+from .session import Session
+from .session import sessionmaker
+from .. import event
+from .. import exc
+from .. import util
+from ..util.compat import inspect_getfullargspec
+
+
+class InstrumentationEvents(event.Events):
+ """Events related to class instrumentation events.
+
+ The listeners here support being established against
+ any new style class, that is any object that is a subclass
+ of 'type'. Events will then be fired off for events
+ against that class. If the "propagate=True" flag is passed
+ to event.listen(), the event will fire off for subclasses
+ of that class as well.
+
+ The Python ``type`` builtin is also accepted as a target,
+ which when used has the effect of events being emitted
+ for all classes.
+
+ Note the "propagate" flag here is defaulted to ``True``,
+ unlike the other class level events where it defaults
+ to ``False``. This means that new subclasses will also
+ be the subject of these events, when a listener
+ is established on a superclass.
+
+ """
+
+ _target_class_doc = "SomeBaseClass"
+ _dispatch_target = instrumentation.InstrumentationFactory
+
+ @classmethod
+ def _accept_with(cls, target):
+ if isinstance(target, type):
+ return _InstrumentationEventsHold(target)
+ else:
+ return None
+
+ @classmethod
+ def _listen(cls, event_key, propagate=True, **kw):
+ target, identifier, fn = (
+ event_key.dispatch_target,
+ event_key.identifier,
+ event_key._listen_fn,
+ )
+
+ def listen(target_cls, *arg):
+ listen_cls = target()
+
+            # listen_cls is None if the weakref has been collected; this is
+            # not something that normally happens. it was occurring during
+            # test teardown between mapper/registry/instrumentation_manager,
+            # however that interaction was changed to not rely upon the
+            # event system.
+ if listen_cls is None:
+ return None
+
+ if propagate and issubclass(target_cls, listen_cls):
+ return fn(target_cls, *arg)
+ elif not propagate and target_cls is listen_cls:
+ return fn(target_cls, *arg)
+
+ def remove(ref):
+ key = event.registry._EventKey(
+ None,
+ identifier,
+ listen,
+ instrumentation._instrumentation_factory,
+ )
+ getattr(
+ instrumentation._instrumentation_factory.dispatch, identifier
+ ).remove(key)
+
+ target = weakref.ref(target.class_, remove)
+
+ event_key.with_dispatch_target(
+ instrumentation._instrumentation_factory
+ ).with_wrapper(listen).base_listen(**kw)
+
+ @classmethod
+ def _clear(cls):
+ super(InstrumentationEvents, cls)._clear()
+ instrumentation._instrumentation_factory.dispatch._clear()
+
+ def class_instrument(self, cls):
+ """Called after the given class is instrumented.
+
+ To get at the :class:`.ClassManager`, use
+ :func:`.manager_of_class`.
+
+ """
+
+ def class_uninstrument(self, cls):
+ """Called before the given class is uninstrumented.
+
+ To get at the :class:`.ClassManager`, use
+ :func:`.manager_of_class`.
+
+ """
+
+ def attribute_instrument(self, cls, key, inst):
+ """Called when an attribute is instrumented."""
+
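+
+# A minimal sketch: listening for instrumentation of any subclass of a
+# hypothetical base class::
+#
+#     from sqlalchemy import event
+#
+#     def on_class_instrument(cls):
+#         print("instrumented:", cls)
+#
+#     event.listen(SomeBaseClass, "class_instrument", on_class_instrument)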
+
+class _InstrumentationEventsHold(object):
+ """temporary marker object used to transfer from _accept_with() to
+ _listen() on the InstrumentationEvents class.
+
+ """
+
+ def __init__(self, class_):
+ self.class_ = class_
+
+ dispatch = event.dispatcher(InstrumentationEvents)
+
+
+class InstanceEvents(event.Events):
+ """Define events specific to object lifecycle.
+
+ e.g.::
+
+ from sqlalchemy import event
+
+ def my_load_listener(target, context):
+ print("on load!")
+
+ event.listen(SomeClass, 'load', my_load_listener)
+
+ Available targets include:
+
+ * mapped classes
+ * unmapped superclasses of mapped or to-be-mapped classes
+ (using the ``propagate=True`` flag)
+ * :class:`_orm.Mapper` objects
+ * the :class:`_orm.Mapper` class itself and the :func:`.mapper`
+ function indicate listening for all mappers.
+
+ Instance events are closely related to mapper events, but
+ are more specific to the instance and its instrumentation,
+ rather than its system of persistence.
+
+ When using :class:`.InstanceEvents`, several modifiers are
+ available to the :func:`.event.listen` function.
+
+ :param propagate=False: When True, the event listener should
+ be applied to all inheriting classes as well as the
+ class which is the target of this listener.
+ :param raw=False: When True, the "target" argument passed
+ to applicable event listener functions will be the
+ instance's :class:`.InstanceState` management
+ object, rather than the mapped instance itself.
+ :param restore_load_context=False: Applies to the
+ :meth:`.InstanceEvents.load` and :meth:`.InstanceEvents.refresh`
+ events. Restores the loader context of the object when the event
+ hook is complete, so that ongoing eager load operations continue
+ to target the object appropriately. A warning is emitted if the
+ object is moved to a new loader context from within one of these
+ events if this flag is not set.
+
+ .. versionadded:: 1.3.14
+
+
+ """
+
+ _target_class_doc = "SomeClass"
+
+ _dispatch_target = instrumentation.ClassManager
+
+ @classmethod
+ def _new_classmanager_instance(cls, class_, classmanager):
+ _InstanceEventsHold.populate(class_, classmanager)
+
+ @classmethod
+ @util.preload_module("sqlalchemy.orm")
+ def _accept_with(cls, target):
+ orm = util.preloaded.orm
+
+ if isinstance(target, instrumentation.ClassManager):
+ return target
+ elif isinstance(target, mapperlib.Mapper):
+ return target.class_manager
+ elif target is orm.mapper:
+ return instrumentation.ClassManager
+ elif isinstance(target, type):
+ if issubclass(target, mapperlib.Mapper):
+ return instrumentation.ClassManager
+ else:
+ manager = instrumentation.manager_of_class(target)
+ if manager:
+ return manager
+ else:
+ return _InstanceEventsHold(target)
+ return None
+
+ @classmethod
+ def _listen(
+ cls,
+ event_key,
+ raw=False,
+ propagate=False,
+ restore_load_context=False,
+ **kw
+ ):
+ target, fn = (event_key.dispatch_target, event_key._listen_fn)
+
+ if not raw or restore_load_context:
+
+ def wrap(state, *arg, **kw):
+ if not raw:
+ target = state.obj()
+ else:
+ target = state
+ if restore_load_context:
+ runid = state.runid
+ try:
+ return fn(target, *arg, **kw)
+ finally:
+ if restore_load_context:
+ state.runid = runid
+
+ event_key = event_key.with_wrapper(wrap)
+
+ event_key.base_listen(propagate=propagate, **kw)
+
+ if propagate:
+ for mgr in target.subclass_managers(True):
+ event_key.with_dispatch_target(mgr).base_listen(propagate=True)
+
+ @classmethod
+ def _clear(cls):
+ super(InstanceEvents, cls)._clear()
+ _InstanceEventsHold._clear()
+
+ def first_init(self, manager, cls):
+        """Called when the first instance of a particular mapping is created.
+
+ This event is called when the ``__init__`` method of a class
+ is called the first time for that particular class. The event
+        is called the first time for that particular class. The event
+        is invoked before ``__init__`` actually proceeds as well as before
+
+ """
+
+ def init(self, target, args, kwargs):
+ """Receive an instance when its constructor is called.
+
+ This method is only called during a userland construction of
+ an object, in conjunction with the object's constructor, e.g.
+ its ``__init__`` method. It is not called when an object is
+ loaded from the database; see the :meth:`.InstanceEvents.load`
+ event in order to intercept a database load.
+
+ The event is called before the actual ``__init__`` constructor
+ of the object is called. The ``kwargs`` dictionary may be
+ modified in-place in order to affect what is passed to
+ ``__init__``.
+
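+        As an illustrative sketch only, where ``SomeClass`` and the
+        ``status`` keyword are hypothetical, a default keyword argument
+        could be supplied to the constructor this way::
+
+            @event.listens_for(SomeClass, "init")
+            def supply_default(target, args, kwargs):
+                # hypothetical: inject a default "status" keyword if the
+                # caller did not provide one
+                kwargs.setdefault("status", "new")
+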
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param args: positional arguments passed to the ``__init__`` method.
+ This is passed as a tuple and is currently immutable.
+ :param kwargs: keyword arguments passed to the ``__init__`` method.
+ This structure *can* be altered in place.
+
+ .. seealso::
+
+ :meth:`.InstanceEvents.init_failure`
+
+ :meth:`.InstanceEvents.load`
+
+ """
+
+ def init_failure(self, target, args, kwargs):
+ """Receive an instance when its constructor has been called,
+ and raised an exception.
+
+ This method is only called during a userland construction of
+ an object, in conjunction with the object's constructor, e.g.
+ its ``__init__`` method. It is not called when an object is loaded
+ from the database.
+
+ The event is invoked after an exception raised by the ``__init__``
+ method is caught. After the event
+ is invoked, the original exception is re-raised outwards, so that
+ the construction of the object still raises an exception. The
+ actual exception and stack trace raised should be present in
+ ``sys.exc_info()``.
+
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param args: positional arguments that were passed to the ``__init__``
+ method.
+ :param kwargs: keyword arguments that were passed to the ``__init__``
+ method.
+
+ .. seealso::
+
+ :meth:`.InstanceEvents.init`
+
+ :meth:`.InstanceEvents.load`
+
+ """
+
+ def load(self, target, context):
+ """Receive an object instance after it has been created via
+ ``__new__``, and after initial attribute population has
+ occurred.
+
+ This typically occurs when the instance is created based on
+ incoming result rows, and is only called once for that
+ instance's lifetime.
+
+ .. warning::
+
+ During a result-row load, this event is invoked when the
+ first row received for this instance is processed. When using
+ eager loading with collection-oriented attributes, the additional
+            rows that are to be loaded / processed in order to load subsequent
+            collection items have not been processed yet.  This has the effect
+ both that collections will not be fully loaded, as well as that
+ if an operation occurs within this event handler that emits
+ another database load operation for the object, the "loading
+ context" for the object can change and interfere with the
+ existing eager loaders still in progress.
+
+ Examples of what can cause the "loading context" to change within
+ the event handler include, but are not necessarily limited to:
+
+ * accessing deferred attributes that weren't part of the row,
+ will trigger an "undefer" operation and refresh the object
+
+ * accessing attributes on a joined-inheritance subclass that
+ weren't part of the row, will trigger a refresh operation.
+
+ As of SQLAlchemy 1.3.14, a warning is emitted when this occurs. The
+ :paramref:`.InstanceEvents.restore_load_context` option may be
+ used on the event to prevent this warning; this will ensure that
+ the existing loading context is maintained for the object after the
+ event is called::
+
+ @event.listens_for(
+ SomeClass, "load", restore_load_context=True)
+ def on_load(instance, context):
+ instance.some_unloaded_attribute
+
+ .. versionchanged:: 1.3.14 Added
+ :paramref:`.InstanceEvents.restore_load_context`
+ and :paramref:`.SessionEvents.restore_load_context` flags which
+ apply to "on load" events, which will ensure that the loading
+ context for an object is restored when the event hook is
+ complete; a warning is emitted if the load context of the object
+ changes without this flag being set.
+
+
+ The :meth:`.InstanceEvents.load` event is also available in a
+ class-method decorator format called :func:`_orm.reconstructor`.
+
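+        As an illustrative sketch only (the ``init_on_load`` name and the
+        cache attribute are hypothetical)::
+
+            from sqlalchemy.orm import reconstructor
+
+            class SomeClass(Base):
+                # ... mapped columns elided ...
+
+                @reconstructor
+                def init_on_load(self):
+                    # hypothetical re-initialization of unmapped state,
+                    # invoked when the object is loaded from the database
+                    self._cache = {}
+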
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param context: the :class:`.QueryContext` corresponding to the
+ current :class:`_query.Query` in progress. This argument may be
+ ``None`` if the load does not correspond to a :class:`_query.Query`,
+ such as during :meth:`.Session.merge`.
+
+ .. seealso::
+
+ :meth:`.InstanceEvents.init`
+
+ :meth:`.InstanceEvents.refresh`
+
+ :meth:`.SessionEvents.loaded_as_persistent`
+
+ :ref:`mapping_constructors`
+
+ """
+
+ def refresh(self, target, context, attrs):
+ """Receive an object instance after one or more attributes have
+ been refreshed from a query.
+
+ Contrast this to the :meth:`.InstanceEvents.load` method, which
+ is invoked when the object is first loaded from a query.
+
+ .. note:: This event is invoked within the loader process before
+ eager loaders may have been completed, and the object's state may
+ not be complete. Additionally, invoking row-level refresh
+ operations on the object will place the object into a new loader
+ context, interfering with the existing load context. See the note
+ on :meth:`.InstanceEvents.load` for background on making use of the
+ :paramref:`.InstanceEvents.restore_load_context` parameter, in
+ order to resolve this scenario.
+
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param context: the :class:`.QueryContext` corresponding to the
+ current :class:`_query.Query` in progress.
+ :param attrs: sequence of attribute names which
+ were populated, or None if all column-mapped, non-deferred
+ attributes were populated.
+
+ .. seealso::
+
+ :meth:`.InstanceEvents.load`
+
+ """
+
+ def refresh_flush(self, target, flush_context, attrs):
+ """Receive an object instance after one or more attributes that
+ contain a column-level default or onupdate handler have been refreshed
+ during persistence of the object's state.
+
+ This event is the same as :meth:`.InstanceEvents.refresh` except
+ it is invoked within the unit of work flush process, and includes
+ only non-primary-key columns that have column level default or
+ onupdate handlers, including Python callables as well as server side
+ defaults and triggers which may be fetched via the RETURNING clause.
+
+ .. note::
+
+ While the :meth:`.InstanceEvents.refresh_flush` event is triggered
+ for an object that was INSERTed as well as for an object that was
+ UPDATEd, the event is geared primarily towards the UPDATE process;
+ it is mostly an internal artifact that INSERT actions can also
+ trigger this event, and note that **primary key columns for an
+ INSERTed row are explicitly omitted** from this event. In order to
+ intercept the newly INSERTed state of an object, the
+ :meth:`.SessionEvents.pending_to_persistent` and
+ :meth:`.MapperEvents.after_insert` are better choices.
+
+ .. versionadded:: 1.0.5
+
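+        A minimal listener sketch, where ``SomeClass`` is hypothetical
+        and ``attrs`` names the columns that were just fetched::
+
+            @event.listens_for(SomeClass, "refresh_flush")
+            def receive_refresh_flush(target, flush_context, attrs):
+                # hypothetical: log which server-generated columns were
+                # just fetched for this object during the flush
+                print("refreshed during flush:", attrs)
+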
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param flush_context: Internal :class:`.UOWTransaction` object
+ which handles the details of the flush.
+ :param attrs: sequence of attribute names which
+ were populated.
+
+ .. seealso::
+
+ :ref:`orm_server_defaults`
+
+ :ref:`metadata_defaults_toplevel`
+
+ """
+
+ def expire(self, target, attrs):
+ """Receive an object instance after its attributes or some subset
+ have been expired.
+
+        ``attrs`` is a sequence of attribute names; if ``None``, the
+        entire state of the object was expired.
+
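+        A minimal listener sketch, where ``SomeClass`` is hypothetical::
+
+            @event.listens_for(SomeClass, "expire")
+            def receive_expire(target, attrs):
+                # attrs is None when the full object state was expired
+                print("expired attributes:", attrs or "all")
+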
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param attrs: sequence of attribute
+ names which were expired, or None if all attributes were
+ expired.
+
+ """
+
+ def pickle(self, target, state_dict):
+ """Receive an object instance when its associated state is
+ being pickled.
+
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param state_dict: the dictionary returned by
+ :class:`.InstanceState.__getstate__`, containing the state
+ to be pickled.
+
+ """
+
+ def unpickle(self, target, state_dict):
+ """Receive an object instance after its associated state has
+ been unpickled.
+
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param state_dict: the dictionary sent to
+ :class:`.InstanceState.__setstate__`, containing the state
+ dictionary which was pickled.
+
+ """
+
+
+class _EventsHold(event.RefCollection):
+ """Hold onto listeners against unmapped, uninstrumented classes.
+
+ Establish _listen() for that class' mapper/instrumentation when
+ those objects are created for that class.
+
+ """
+
+ def __init__(self, class_):
+ self.class_ = class_
+
+ @classmethod
+ def _clear(cls):
+ cls.all_holds.clear()
+
+ class HoldEvents(object):
+ _dispatch_target = None
+
+ @classmethod
+ def _listen(
+ cls, event_key, raw=False, propagate=False, retval=False, **kw
+ ):
+ target = event_key.dispatch_target
+
+ if target.class_ in target.all_holds:
+ collection = target.all_holds[target.class_]
+ else:
+ collection = target.all_holds[target.class_] = {}
+
+ event.registry._stored_in_collection(event_key, target)
+ collection[event_key._key] = (
+ event_key,
+ raw,
+ propagate,
+ retval,
+ kw,
+ )
+
+ if propagate:
+ stack = list(target.class_.__subclasses__())
+ while stack:
+ subclass = stack.pop(0)
+ stack.extend(subclass.__subclasses__())
+ subject = target.resolve(subclass)
+ if subject is not None:
+ # we are already going through __subclasses__()
+ # so leave generic propagate flag False
+ event_key.with_dispatch_target(subject).listen(
+ raw=raw, propagate=False, retval=retval, **kw
+ )
+
+ def remove(self, event_key):
+ target = event_key.dispatch_target
+
+ if isinstance(target, _EventsHold):
+ collection = target.all_holds[target.class_]
+ del collection[event_key._key]
+
+ @classmethod
+ def populate(cls, class_, subject):
+ for subclass in class_.__mro__:
+ if subclass in cls.all_holds:
+ collection = cls.all_holds[subclass]
+ for (
+ event_key,
+ raw,
+ propagate,
+ retval,
+ kw,
+ ) in collection.values():
+ if propagate or subclass is class_:
+ # since we can't be sure in what order different
+ # classes in a hierarchy are triggered with
+ # populate(), we rely upon _EventsHold for all event
+ # assignment, instead of using the generic propagate
+ # flag.
+ event_key.with_dispatch_target(subject).listen(
+ raw=raw, propagate=False, retval=retval, **kw
+ )
+
+
+class _InstanceEventsHold(_EventsHold):
+ all_holds = weakref.WeakKeyDictionary()
+
+ def resolve(self, class_):
+ return instrumentation.manager_of_class(class_)
+
+ class HoldInstanceEvents(_EventsHold.HoldEvents, InstanceEvents):
+ pass
+
+ dispatch = event.dispatcher(HoldInstanceEvents)
+
+
+class MapperEvents(event.Events):
+ """Define events specific to mappings.
+
+ e.g.::
+
+ from sqlalchemy import event
+
+ def my_before_insert_listener(mapper, connection, target):
+ # execute a stored procedure upon INSERT,
+ # apply the value to the row to be inserted
+ target.calculated_value = connection.execute(
+ text("select my_special_function(%d)" % target.special_number)
+ ).scalar()
+
+ # associate the listener function with SomeClass,
+ # to execute during the "before_insert" hook
+ event.listen(
+ SomeClass, 'before_insert', my_before_insert_listener)
+
+ Available targets include:
+
+ * mapped classes
+ * unmapped superclasses of mapped or to-be-mapped classes
+ (using the ``propagate=True`` flag)
+ * :class:`_orm.Mapper` objects
+ * the :class:`_orm.Mapper` class itself and the :func:`.mapper`
+ function indicate listening for all mappers.
+
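+    For example, to listen across all mappers at once, the
+    :class:`_orm.Mapper` class itself may be used as the target; an
+    illustrative sketch only::
+
+        from sqlalchemy import event
+        from sqlalchemy.orm import Mapper
+
+        @event.listens_for(Mapper, "before_insert")
+        def before_insert_for_all(mapper, connection, target):
+            # applies to every mapped class in the application
+            print("about to INSERT:", target)
+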
+ Mapper events provide hooks into critical sections of the
+ mapper, including those related to object instrumentation,
+ object loading, and object persistence. In particular, the
+ persistence methods :meth:`~.MapperEvents.before_insert`,
+ and :meth:`~.MapperEvents.before_update` are popular
+ places to augment the state being persisted - however, these
+ methods operate with several significant restrictions. The
+ user is encouraged to evaluate the
+ :meth:`.SessionEvents.before_flush` and
+ :meth:`.SessionEvents.after_flush` methods as more
+ flexible and user-friendly hooks in which to apply
+ additional database state during a flush.
+
+ When using :class:`.MapperEvents`, several modifiers are
+ available to the :func:`.event.listen` function.
+
+ :param propagate=False: When True, the event listener should
+ be applied to all inheriting mappers and/or the mappers of
+ inheriting classes, as well as any
+ mapper which is the target of this listener.
+ :param raw=False: When True, the "target" argument passed
+ to applicable event listener functions will be the
+ instance's :class:`.InstanceState` management
+ object, rather than the mapped instance itself.
+ :param retval=False: when True, the user-defined event function
+ must have a return value, the purpose of which is either to
+ control subsequent event propagation, or to otherwise alter
+ the operation in progress by the mapper. Possible return
+ values are:
+
+ * ``sqlalchemy.orm.interfaces.EXT_CONTINUE`` - continue event
+ processing normally.
+ * ``sqlalchemy.orm.interfaces.EXT_STOP`` - cancel all subsequent
+ event handlers in the chain.
+ * other values - the return value specified by specific listeners.
+
+ """
+
+ _target_class_doc = "SomeClass"
+ _dispatch_target = mapperlib.Mapper
+
+ @classmethod
+ def _new_mapper_instance(cls, class_, mapper):
+ _MapperEventsHold.populate(class_, mapper)
+
+ @classmethod
+ @util.preload_module("sqlalchemy.orm")
+ def _accept_with(cls, target):
+ orm = util.preloaded.orm
+
+ if target is orm.mapper:
+ return mapperlib.Mapper
+ elif isinstance(target, type):
+ if issubclass(target, mapperlib.Mapper):
+ return target
+ else:
+ mapper = _mapper_or_none(target)
+ if mapper is not None:
+ return mapper
+ else:
+ return _MapperEventsHold(target)
+ else:
+ return target
+
+ @classmethod
+ def _listen(
+ cls, event_key, raw=False, retval=False, propagate=False, **kw
+ ):
+ target, identifier, fn = (
+ event_key.dispatch_target,
+ event_key.identifier,
+ event_key._listen_fn,
+ )
+
+ if (
+ identifier in ("before_configured", "after_configured")
+ and target is not mapperlib.Mapper
+ ):
+ util.warn(
+ "'before_configured' and 'after_configured' ORM events "
+ "only invoke with the mapper() function or Mapper class "
+ "as the target."
+ )
+
+ if not raw or not retval:
+ if not raw:
+ meth = getattr(cls, identifier)
+ try:
+ target_index = (
+ inspect_getfullargspec(meth)[0].index("target") - 1
+ )
+ except ValueError:
+ target_index = None
+
+ def wrap(*arg, **kw):
+ if not raw and target_index is not None:
+ arg = list(arg)
+ arg[target_index] = arg[target_index].obj()
+ if not retval:
+ fn(*arg, **kw)
+ return interfaces.EXT_CONTINUE
+ else:
+ return fn(*arg, **kw)
+
+ event_key = event_key.with_wrapper(wrap)
+
+ if propagate:
+ for mapper in target.self_and_descendants:
+ event_key.with_dispatch_target(mapper).base_listen(
+ propagate=True, **kw
+ )
+ else:
+ event_key.base_listen(**kw)
+
+ @classmethod
+ def _clear(cls):
+ super(MapperEvents, cls)._clear()
+ _MapperEventsHold._clear()
+
+ def instrument_class(self, mapper, class_):
+ r"""Receive a class when the mapper is first constructed,
+ before instrumentation is applied to the mapped class.
+
+ This event is the earliest phase of mapper construction.
+ Most attributes of the mapper are not yet initialized.
+
+ This listener can either be applied to the :class:`_orm.Mapper`
+ class overall, or to any un-mapped class which serves as a base
+ for classes that will be mapped (using the ``propagate=True`` flag)::
+
+ Base = declarative_base()
+
+ @event.listens_for(Base, "instrument_class", propagate=True)
+ def on_new_class(mapper, cls_):
+ " ... "
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param class\_: the mapped class.
+
+ """
+
+ def before_mapper_configured(self, mapper, class_):
+ """Called right before a specific mapper is to be configured.
+
+ This event is intended to allow a specific mapper to be skipped during
+ the configure step, by returning the :attr:`.orm.interfaces.EXT_SKIP`
+ symbol which indicates to the :func:`.configure_mappers` call that this
+ particular mapper (or hierarchy of mappers, if ``propagate=True`` is
+ used) should be skipped in the current configuration run. When one or
+        more mappers are skipped, the "new mappers" flag will remain set,
+ meaning the :func:`.configure_mappers` function will continue to be
+ called when mappers are used, to continue to try to configure all
+ available mappers.
+
+ In comparison to the other configure-level events,
+ :meth:`.MapperEvents.before_configured`,
+ :meth:`.MapperEvents.after_configured`, and
+ :meth:`.MapperEvents.mapper_configured`, the
+        :meth:`.MapperEvents.before_mapper_configured` event provides for a
+ meaningful return value when it is registered with the ``retval=True``
+ parameter.
+
+ .. versionadded:: 1.3
+
+ e.g.::
+
+ from sqlalchemy.orm import EXT_SKIP
+
+ Base = declarative_base()
+
+ DontConfigureBase = declarative_base()
+
+ @event.listens_for(
+ DontConfigureBase,
+ "before_mapper_configured", retval=True, propagate=True)
+ def dont_configure(mapper, cls):
+ return EXT_SKIP
+
+
+ .. seealso::
+
+ :meth:`.MapperEvents.before_configured`
+
+ :meth:`.MapperEvents.after_configured`
+
+ :meth:`.MapperEvents.mapper_configured`
+
+ """
+
+ def mapper_configured(self, mapper, class_):
+ r"""Called when a specific mapper has completed its own configuration
+ within the scope of the :func:`.configure_mappers` call.
+
+ The :meth:`.MapperEvents.mapper_configured` event is invoked
+ for each mapper that is encountered when the
+ :func:`_orm.configure_mappers` function proceeds through the current
+ list of not-yet-configured mappers.
+ :func:`_orm.configure_mappers` is typically invoked
+ automatically as mappings are first used, as well as each time
+ new mappers have been made available and new mapper use is
+ detected.
+
+ When the event is called, the mapper should be in its final
+ state, but **not including backrefs** that may be invoked from
+ other mappers; they might still be pending within the
+ configuration operation. Bidirectional relationships that
+ are instead configured via the
+ :paramref:`.orm.relationship.back_populates` argument
+ *will* be fully available, since this style of relationship does not
+ rely upon other possibly-not-configured mappers to know that they
+ exist.
+
+ For an event that is guaranteed to have **all** mappers ready
+ to go including backrefs that are defined only on other
+ mappings, use the :meth:`.MapperEvents.after_configured`
+ event; this event invokes only after all known mappings have been
+ fully configured.
+
+ The :meth:`.MapperEvents.mapper_configured` event, unlike
+ :meth:`.MapperEvents.before_configured` or
+ :meth:`.MapperEvents.after_configured`,
+ is called for each mapper/class individually, and the mapper is
+ passed to the event itself. It also is called exactly once for
+ a particular mapper. The event is therefore useful for
+ configurational steps that benefit from being invoked just once
+ on a specific mapper basis, which don't require that "backref"
+ configurations are necessarily ready yet.
+
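+        An illustrative sketch, using a hypothetical declarative ``Base``
+        as the propagation target so that every mapped subclass is
+        covered::
+
+            @event.listens_for(Base, "mapper_configured", propagate=True)
+            def receive_mapper_configured(mapper, class_):
+                # hypothetical one-time, per-class configuration step
+                print("configured mapping for", class_.__name__)
+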
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param class\_: the mapped class.
+
+ .. seealso::
+
+ :meth:`.MapperEvents.before_configured`
+
+ :meth:`.MapperEvents.after_configured`
+
+ :meth:`.MapperEvents.before_mapper_configured`
+
+ """
+ # TODO: need coverage for this event
+
+ def before_configured(self):
+ """Called before a series of mappers have been configured.
+
+ The :meth:`.MapperEvents.before_configured` event is invoked
+ each time the :func:`_orm.configure_mappers` function is
+ invoked, before the function has done any of its work.
+ :func:`_orm.configure_mappers` is typically invoked
+ automatically as mappings are first used, as well as each time
+ new mappers have been made available and new mapper use is
+ detected.
+
+ This event can **only** be applied to the :class:`_orm.Mapper` class
+ or :func:`.mapper` function, and not to individual mappings or
+ mapped classes. It is only invoked for all mappings as a whole::
+
+ from sqlalchemy.orm import mapper
+
+ @event.listens_for(mapper, "before_configured")
+ def go():
+ # ...
+
+ Contrast this event to :meth:`.MapperEvents.after_configured`,
+ which is invoked after the series of mappers has been configured,
+ as well as :meth:`.MapperEvents.before_mapper_configured`
+ and :meth:`.MapperEvents.mapper_configured`, which are both invoked
+ on a per-mapper basis.
+
+ Theoretically this event is called once per
+ application, but is actually called any time new mappers
+ are to be affected by a :func:`_orm.configure_mappers`
+ call. If new mappings are constructed after existing ones have
+ already been used, this event will likely be called again. To ensure
+ that a particular event is only called once and no further, the
+ ``once=True`` argument (new in 0.9.4) can be applied::
+
+ from sqlalchemy.orm import mapper
+
+ @event.listens_for(mapper, "before_configured", once=True)
+ def go():
+ # ...
+
+
+ .. versionadded:: 0.9.3
+
+
+ .. seealso::
+
+ :meth:`.MapperEvents.before_mapper_configured`
+
+ :meth:`.MapperEvents.mapper_configured`
+
+ :meth:`.MapperEvents.after_configured`
+
+ """
+
+ def after_configured(self):
+ """Called after a series of mappers have been configured.
+
+ The :meth:`.MapperEvents.after_configured` event is invoked
+ each time the :func:`_orm.configure_mappers` function is
+ invoked, after the function has completed its work.
+ :func:`_orm.configure_mappers` is typically invoked
+ automatically as mappings are first used, as well as each time
+ new mappers have been made available and new mapper use is
+ detected.
+
+ Contrast this event to the :meth:`.MapperEvents.mapper_configured`
+ event, which is called on a per-mapper basis while the configuration
+ operation proceeds; unlike that event, when this event is invoked,
+ all cross-configurations (e.g. backrefs) will also have been made
+ available for any mappers that were pending.
+ Also contrast to :meth:`.MapperEvents.before_configured`,
+ which is invoked before the series of mappers has been configured.
+
+ This event can **only** be applied to the :class:`_orm.Mapper` class
+ or :func:`.mapper` function, and not to individual mappings or
+ mapped classes. It is only invoked for all mappings as a whole::
+
+ from sqlalchemy.orm import mapper
+
+ @event.listens_for(mapper, "after_configured")
+ def go():
+ # ...
+
+ Theoretically this event is called once per
+ application, but is actually called any time new mappers
+ have been affected by a :func:`_orm.configure_mappers`
+ call. If new mappings are constructed after existing ones have
+ already been used, this event will likely be called again. To ensure
+ that a particular event is only called once and no further, the
+ ``once=True`` argument (new in 0.9.4) can be applied::
+
+ from sqlalchemy.orm import mapper
+
+ @event.listens_for(mapper, "after_configured", once=True)
+ def go():
+ # ...
+
+ .. seealso::
+
+ :meth:`.MapperEvents.before_mapper_configured`
+
+ :meth:`.MapperEvents.mapper_configured`
+
+ :meth:`.MapperEvents.before_configured`
+
+ """
+
+ def before_insert(self, mapper, connection, target):
+ """Receive an object instance before an INSERT statement
+ is emitted corresponding to that instance.
+
+ This event is used to modify local, non-object related
+ attributes on the instance before an INSERT occurs, as well
+ as to emit additional SQL statements on the given
+ connection.
+
+ The event is often called for a batch of objects of the
+ same class before their INSERT statements are emitted at
+ once in a later step. In the extremely rare case that
+ this is not desirable, the :func:`.mapper` can be
+ configured with ``batch=False``, which will cause
+ batches of instances to be broken up into individual
+ (and more poorly performing) event->persist->event
+ steps.
+
+ .. warning::
+
+ Mapper-level flush events only allow **very limited operations**,
+ on attributes local to the row being operated upon only,
+ as well as allowing any SQL to be emitted on the given
+ :class:`_engine.Connection`. **Please read fully** the notes
+ at :ref:`session_persistence_mapper` for guidelines on using
+ these methods; generally, the :meth:`.SessionEvents.before_flush`
+ method should be preferred for general on-flush changes.
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param connection: the :class:`_engine.Connection` being used to
+ emit INSERT statements for this instance. This
+ provides a handle into the current transaction on the
+ target database specific to this instance.
+ :param target: the mapped instance being persisted. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :return: No return value is supported by this event.
+
+ .. seealso::
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def after_insert(self, mapper, connection, target):
+ """Receive an object instance after an INSERT statement
+ is emitted corresponding to that instance.
+
+ This event is used to modify in-Python-only
+ state on the instance after an INSERT occurs, as well
+ as to emit additional SQL statements on the given
+ connection.
+
+ The event is often called for a batch of objects of the
+ same class after their INSERT statements have been
+ emitted at once in a previous step. In the extremely
+ rare case that this is not desirable, the
+ :func:`.mapper` can be configured with ``batch=False``,
+ which will cause batches of instances to be broken up
+ into individual (and more poorly performing)
+ event->persist->event steps.
+
+ .. warning::
+
+ Mapper-level flush events only allow **very limited operations**,
+ on attributes local to the row being operated upon only,
+ as well as allowing any SQL to be emitted on the given
+ :class:`_engine.Connection`. **Please read fully** the notes
+ at :ref:`session_persistence_mapper` for guidelines on using
+ these methods; generally, the :meth:`.SessionEvents.before_flush`
+ method should be preferred for general on-flush changes.
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param connection: the :class:`_engine.Connection` being used to
+ emit INSERT statements for this instance. This
+ provides a handle into the current transaction on the
+ target database specific to this instance.
+ :param target: the mapped instance being persisted. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :return: No return value is supported by this event.
+
+ .. seealso::
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def before_update(self, mapper, connection, target):
+ """Receive an object instance before an UPDATE statement
+ is emitted corresponding to that instance.
+
+ This event is used to modify local, non-object related
+ attributes on the instance before an UPDATE occurs, as well
+ as to emit additional SQL statements on the given
+ connection.
+
+ This method is called for all instances that are
+ marked as "dirty", *even those which have no net changes
+ to their column-based attributes*. An object is marked
+ as dirty when any of its column-based attributes have a
+ "set attribute" operation called or when any of its
+ collections are modified. If, at update time, no
+ column-based attributes have any net changes, no UPDATE
+ statement will be issued. This means that an instance
+ being sent to :meth:`~.MapperEvents.before_update` is
+ *not* a guarantee that an UPDATE statement will be
+ issued, although you can affect the outcome here by
+ modifying attributes so that a net change in value does
+ exist.
+
+ To detect if the column-based attributes on the object have net
+ changes, and will therefore generate an UPDATE statement, use
+ ``object_session(instance).is_modified(instance,
+ include_collections=False)``.
+
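+        An illustrative sketch only, where ``SomeClass`` and the
+        ``updated_at`` attribute are hypothetical::
+
+            import datetime
+
+            from sqlalchemy.orm import object_session
+
+            @event.listens_for(SomeClass, "before_update")
+            def receive_before_update(mapper, connection, target):
+                # hypothetical: stamp the row only when a net change will
+                # actually produce an UPDATE statement
+                if object_session(target).is_modified(
+                        target, include_collections=False):
+                    target.updated_at = datetime.datetime.utcnow()
+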
+ The event is often called for a batch of objects of the
+ same class before their UPDATE statements are emitted at
+ once in a later step. In the extremely rare case that
+ this is not desirable, the :func:`.mapper` can be
+ configured with ``batch=False``, which will cause
+ batches of instances to be broken up into individual
+ (and more poorly performing) event->persist->event
+ steps.
+
+ .. warning::
+
+ Mapper-level flush events only allow **very limited operations**,
+ on attributes local to the row being operated upon only,
+ as well as allowing any SQL to be emitted on the given
+ :class:`_engine.Connection`. **Please read fully** the notes
+ at :ref:`session_persistence_mapper` for guidelines on using
+ these methods; generally, the :meth:`.SessionEvents.before_flush`
+ method should be preferred for general on-flush changes.
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param connection: the :class:`_engine.Connection` being used to
+ emit UPDATE statements for this instance. This
+ provides a handle into the current transaction on the
+ target database specific to this instance.
+ :param target: the mapped instance being persisted. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :return: No return value is supported by this event.
+
+ .. seealso::
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def after_update(self, mapper, connection, target):
+ """Receive an object instance after an UPDATE statement
+ is emitted corresponding to that instance.
+
+ This event is used to modify in-Python-only
+ state on the instance after an UPDATE occurs, as well
+ as to emit additional SQL statements on the given
+ connection.
+
+ This method is called for all instances that are
+ marked as "dirty", *even those which have no net changes
+ to their column-based attributes*, and for which
+ no UPDATE statement has proceeded. An object is marked
+ as dirty when any of its column-based attributes have a
+ "set attribute" operation called or when any of its
+ collections are modified. If, at update time, no
+ column-based attributes have any net changes, no UPDATE
+ statement will be issued. This means that an instance
+ being sent to :meth:`~.MapperEvents.after_update` is
+ *not* a guarantee that an UPDATE statement has been
+ issued.
+
+ To detect if the column-based attributes on the object have net
+ changes, and therefore resulted in an UPDATE statement, use
+ ``object_session(instance).is_modified(instance,
+ include_collections=False)``.
+
+ The event is often called for a batch of objects of the
+ same class after their UPDATE statements have been emitted at
+ once in a previous step. In the extremely rare case that
+ this is not desirable, the :func:`.mapper` can be
+ configured with ``batch=False``, which will cause
+ batches of instances to be broken up into individual
+ (and more poorly performing) event->persist->event
+ steps.
+
+ .. warning::
+
+ Mapper-level flush events only allow **very limited operations**,
+ on attributes local to the row being operated upon only,
+ as well as allowing any SQL to be emitted on the given
+ :class:`_engine.Connection`. **Please read fully** the notes
+ at :ref:`session_persistence_mapper` for guidelines on using
+ these methods; generally, the :meth:`.SessionEvents.before_flush`
+ method should be preferred for general on-flush changes.
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param connection: the :class:`_engine.Connection` being used to
+ emit UPDATE statements for this instance. This
+ provides a handle into the current transaction on the
+ target database specific to this instance.
+ :param target: the mapped instance being persisted. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :return: No return value is supported by this event.
+
+ .. seealso::
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def before_delete(self, mapper, connection, target):
+ """Receive an object instance before a DELETE statement
+ is emitted corresponding to that instance.
+
+ This event is used to emit additional SQL statements on
+ the given connection as well as to perform application
+ specific bookkeeping related to a deletion event.
+
+ The event is often called for a batch of objects of the
+ same class before their DELETE statements are emitted at
+ once in a later step.
+
+ .. warning::
+
+ Mapper-level flush events only allow **very limited operations**,
+ on attributes local to the row being operated upon only,
+ as well as allowing any SQL to be emitted on the given
+ :class:`_engine.Connection`. **Please read fully** the notes
+ at :ref:`session_persistence_mapper` for guidelines on using
+ these methods; generally, the :meth:`.SessionEvents.before_flush`
+ method should be preferred for general on-flush changes.
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param connection: the :class:`_engine.Connection` being used to
+ emit DELETE statements for this instance. This
+ provides a handle into the current transaction on the
+ target database specific to this instance.
+ :param target: the mapped instance being deleted. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :return: No return value is supported by this event.
+
+ .. seealso::
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def after_delete(self, mapper, connection, target):
+ """Receive an object instance after a DELETE statement
+ has been emitted corresponding to that instance.
+
+ This event is used to emit additional SQL statements on
+ the given connection as well as to perform application
+ specific bookkeeping related to a deletion event.
+
+ The event is often called for a batch of objects of the
+ same class after their DELETE statements have been emitted at
+ once in a previous step.
+
+ .. warning::
+
+ Mapper-level flush events only allow **very limited operations**,
+ on attributes local to the row being operated upon only,
+ as well as allowing any SQL to be emitted on the given
+ :class:`_engine.Connection`. **Please read fully** the notes
+ at :ref:`session_persistence_mapper` for guidelines on using
+ these methods; generally, the :meth:`.SessionEvents.before_flush`
+ method should be preferred for general on-flush changes.
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param connection: the :class:`_engine.Connection` being used to
+ emit DELETE statements for this instance. This
+ provides a handle into the current transaction on the
+ target database specific to this instance.
+ :param target: the mapped instance being deleted. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :return: No return value is supported by this event.
+
+ .. seealso::
+
+ :ref:`session_persistence_events`
+
+ """
+
+
+class _MapperEventsHold(_EventsHold):
+ all_holds = weakref.WeakKeyDictionary()
+
+ def resolve(self, class_):
+ return _mapper_or_none(class_)
+
+ class HoldMapperEvents(_EventsHold.HoldEvents, MapperEvents):
+ pass
+
+ dispatch = event.dispatcher(HoldMapperEvents)
+
+
+_sessionevents_lifecycle_event_names = set()
+
+
+class SessionEvents(event.Events):
+ """Define events specific to :class:`.Session` lifecycle.
+
+ e.g.::
+
+ from sqlalchemy import event
+ from sqlalchemy.orm import sessionmaker
+
+ def my_before_commit(session):
+ print("before commit!")
+
+ Session = sessionmaker()
+
+ event.listen(Session, "before_commit", my_before_commit)
+
+ The :func:`~.event.listen` function will accept
+ :class:`.Session` objects as well as the return result
+ of :class:`~.sessionmaker()` and :class:`~.scoped_session()`.
+
+ Additionally, it accepts the :class:`.Session` class which
+ will apply listeners to all :class:`.Session` instances
+ globally.
+
+ :param raw=False: When True, the "target" argument passed
+ to applicable event listener functions that work on individual
+ objects will be the instance's :class:`.InstanceState` management
+ object, rather than the mapped instance itself.
+
+ .. versionadded:: 1.3.14
+
+ :param restore_load_context=False: Applies to the
+ :meth:`.SessionEvents.loaded_as_persistent` event. Restores the loader
+ context of the object when the event hook is complete, so that ongoing
+ eager load operations continue to target the object appropriately. A
+ warning is emitted if the object is moved to a new loader context from
+ within this event if this flag is not set.
+
+ .. versionadded:: 1.3.14
+
+ """
+
+ _target_class_doc = "SomeSessionClassOrObject"
+
+ _dispatch_target = Session
+
+ def _lifecycle_event(fn):
+ _sessionevents_lifecycle_event_names.add(fn.__name__)
+ return fn
+
+ @classmethod
+ def _accept_with(cls, target):
+ if isinstance(target, scoped_session):
+
+ target = target.session_factory
+ if not isinstance(target, sessionmaker) and (
+ not isinstance(target, type) or not issubclass(target, Session)
+ ):
+ raise exc.ArgumentError(
+ "Session event listen on a scoped_session "
+ "requires that its creation callable "
+ "is associated with the Session class."
+ )
+
+ if isinstance(target, sessionmaker):
+ return target.class_
+ elif isinstance(target, type):
+ if issubclass(target, scoped_session):
+ return Session
+ elif issubclass(target, Session):
+ return target
+ elif isinstance(target, Session):
+ return target
+ else:
+ # allows alternate SessionEvents-like-classes to be consulted
+ return event.Events._accept_with(target)
+
+ @classmethod
+ def _listen(cls, event_key, raw=False, restore_load_context=False, **kw):
+ is_instance_event = (
+ event_key.identifier in _sessionevents_lifecycle_event_names
+ )
+
+ if is_instance_event:
+ if not raw or restore_load_context:
+
+ fn = event_key._listen_fn
+
+ def wrap(session, state, *arg, **kw):
+ if not raw:
+ target = state.obj()
+ if target is None:
+ # existing behavior is that if the object is
+ # garbage collected, no event is emitted
+ return
+ else:
+ target = state
+ if restore_load_context:
+ runid = state.runid
+ try:
+ return fn(session, target, *arg, **kw)
+ finally:
+ if restore_load_context:
+ state.runid = runid
+
+ event_key = event_key.with_wrapper(wrap)
+
+ event_key.base_listen(**kw)
+
+ def do_orm_execute(self, orm_execute_state):
+ """Intercept statement executions that occur on behalf of an
+ ORM :class:`.Session` object.
+
+ This event is invoked for all top-level SQL statements invoked from the
+ :meth:`_orm.Session.execute` method, as well as related methods such as
+ :meth:`_orm.Session.scalars` and :meth:`_orm.Session.scalar`. As of
+ SQLAlchemy 1.4, all ORM queries emitted on behalf of a
+ :class:`_orm.Session` will flow through this method, so this event hook
+ provides the single point at which ORM queries of all types may be
+ intercepted before they are invoked, and additionally to replace their
+ execution with a different process.
+
+ .. note:: The :meth:`_orm.SessionEvents.do_orm_execute` event hook
+ is triggered **for ORM statement executions only**, meaning those
+ invoked via the :meth:`_orm.Session.execute` and similar methods on
+ the :class:`_orm.Session` object. It does **not** trigger for
+ statements that are invoked by SQLAlchemy Core only, i.e. statements
+ invoked directly using :meth:`_engine.Connection.execute` or
+ otherwise originating from an :class:`_engine.Engine` object without
+ any :class:`_orm.Session` involved. To intercept **all** SQL
+ executions regardless of whether the Core or ORM APIs are in use,
+ see the event hooks at
+ :class:`.ConnectionEvents`, such as
+ :meth:`.ConnectionEvents.before_execute` and
+ :meth:`.ConnectionEvents.before_cursor_execute`.
+
+ This event is a ``do_`` event, meaning it has the capability to replace
+ the operation that the :meth:`_orm.Session.execute` method normally
+ performs. The intended use for this includes sharding and
+ result-caching schemes which may seek to invoke the same statement
+ across multiple database connections, returning a result that is
+ merged from each of them, or which don't invoke the statement at all,
+ instead returning data from a cache.
+
+ The hook intends to replace the use of the
+ ``Query._execute_and_instances`` method that could be subclassed prior
+ to SQLAlchemy 1.4.
+
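+        A minimal listener sketch; the logging action shown is
+        illustrative only::
+
+            from sqlalchemy import event
+            from sqlalchemy.orm import Session
+
+            @event.listens_for(Session, "do_orm_execute")
+            def receive_do_orm_execute(orm_execute_state):
+                # inspect the statement about to be invoked; returning
+                # nothing here allows normal execution to proceed
+                if orm_execute_state.is_select:
+                    print(orm_execute_state.statement)
+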
+ :param orm_execute_state: an instance of :class:`.ORMExecuteState`
+ which contains all information about the current execution, as well
+ as helper functions used to derive other commonly required
+ information. See that object for details.
+
+ .. seealso::
+
+ :ref:`session_execute_events` - top level documentation on how
+ to use :meth:`_orm.SessionEvents.do_orm_execute`
+
+ :class:`.ORMExecuteState` - the object passed to the
+ :meth:`_orm.SessionEvents.do_orm_execute` event which contains
+ all information about the statement to be invoked. It also
+ provides an interface to extend the current statement, options,
+ and parameters as well as an option that allows programmatic
+ invocation of the statement at any point.
+
+ :ref:`examples_session_orm_events` - includes examples of using
+ :meth:`_orm.SessionEvents.do_orm_execute`
+
+ :ref:`examples_caching` - an example of how to integrate
+ Dogpile caching with the ORM :class:`_orm.Session` making use
+ of the :meth:`_orm.SessionEvents.do_orm_execute` event hook.
+
+ :ref:`examples_sharding` - the Horizontal Sharding example /
+ extension relies upon the
+ :meth:`_orm.SessionEvents.do_orm_execute` event hook to invoke a
+ SQL statement on multiple backends and return a merged result.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ def after_transaction_create(self, session, transaction):
+ """Execute when a new :class:`.SessionTransaction` is created.
+
+ This event differs from :meth:`~.SessionEvents.after_begin`
+ in that it occurs for each :class:`.SessionTransaction`
+ overall, as opposed to when transactions are begun
+ on individual database connections. It is also invoked
+ for nested transactions and subtransactions, and is always
+ matched by a corresponding
+ :meth:`~.SessionEvents.after_transaction_end` event
+ (assuming normal operation of the :class:`.Session`).
+
+ :param session: the target :class:`.Session`.
+ :param transaction: the target :class:`.SessionTransaction`.
+
+ To detect if this is the outermost
+ :class:`.SessionTransaction`, as opposed to a "subtransaction" or a
+ SAVEPOINT, test that the :attr:`.SessionTransaction.parent` attribute
+ is ``None``::
+
+ @event.listens_for(session, "after_transaction_create")
+ def after_transaction_create(session, transaction):
+ if transaction.parent is None:
+ # work with top-level transaction
+
+ To detect if the :class:`.SessionTransaction` is a SAVEPOINT, use the
+ :attr:`.SessionTransaction.nested` attribute::
+
+ @event.listens_for(session, "after_transaction_create")
+ def after_transaction_create(session, transaction):
+ if transaction.nested:
+ # work with SAVEPOINT transaction
+
+
+ .. seealso::
+
+ :class:`.SessionTransaction`
+
+ :meth:`~.SessionEvents.after_transaction_end`
+
+ """
+
+ def after_transaction_end(self, session, transaction):
+ """Execute when the span of a :class:`.SessionTransaction` ends.
+
+ This event differs from :meth:`~.SessionEvents.after_commit`
+ in that it corresponds to all :class:`.SessionTransaction`
+ objects in use, including those for nested transactions
+ and subtransactions, and is always matched by a corresponding
+ :meth:`~.SessionEvents.after_transaction_create` event.
+
+ :param session: the target :class:`.Session`.
+ :param transaction: the target :class:`.SessionTransaction`.
+
+ To detect if this is the outermost
+ :class:`.SessionTransaction`, as opposed to a "subtransaction" or a
+ SAVEPOINT, test that the :attr:`.SessionTransaction.parent` attribute
+ is ``None``::
+
+            @event.listens_for(session, "after_transaction_end")
+ def after_transaction_end(session, transaction):
+ if transaction.parent is None:
+ # work with top-level transaction
+
+ To detect if the :class:`.SessionTransaction` is a SAVEPOINT, use the
+ :attr:`.SessionTransaction.nested` attribute::
+
+            @event.listens_for(session, "after_transaction_end")
+ def after_transaction_end(session, transaction):
+ if transaction.nested:
+ # work with SAVEPOINT transaction
+
+
+ .. seealso::
+
+ :class:`.SessionTransaction`
+
+ :meth:`~.SessionEvents.after_transaction_create`
+
+ """
+
+ def before_commit(self, session):
+ """Execute before commit is called.
+
+ .. note::
+
+ The :meth:`~.SessionEvents.before_commit` hook is *not* per-flush,
+ that is, the :class:`.Session` can emit SQL to the database
+ many times within the scope of a transaction.
+ For interception of these events, use the
+ :meth:`~.SessionEvents.before_flush`,
+ :meth:`~.SessionEvents.after_flush`, or
+ :meth:`~.SessionEvents.after_flush_postexec`
+ events.
+
+ :param session: The target :class:`.Session`.
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.after_commit`
+
+ :meth:`~.SessionEvents.after_begin`
+
+ :meth:`~.SessionEvents.after_transaction_create`
+
+ :meth:`~.SessionEvents.after_transaction_end`
+
+ """
+
+ def after_commit(self, session):
+ """Execute after a commit has occurred.
+
+ .. note::
+
+ The :meth:`~.SessionEvents.after_commit` hook is *not* per-flush,
+ that is, the :class:`.Session` can emit SQL to the database
+ many times within the scope of a transaction.
+ For interception of these events, use the
+ :meth:`~.SessionEvents.before_flush`,
+ :meth:`~.SessionEvents.after_flush`, or
+ :meth:`~.SessionEvents.after_flush_postexec`
+ events.
+
+ .. note::
+
+ The :class:`.Session` is not in an active transaction
+ when the :meth:`~.SessionEvents.after_commit` event is invoked,
+ and therefore can not emit SQL. To emit SQL corresponding to
+ every transaction, use the :meth:`~.SessionEvents.before_commit`
+ event.
+
+ :param session: The target :class:`.Session`.
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.before_commit`
+
+ :meth:`~.SessionEvents.after_begin`
+
+ :meth:`~.SessionEvents.after_transaction_create`
+
+ :meth:`~.SessionEvents.after_transaction_end`
+
+ """
+
+ def after_rollback(self, session):
+ """Execute after a real DBAPI rollback has occurred.
+
+ Note that this event only fires when the *actual* rollback against
+ the database occurs - it does *not* fire each time the
+ :meth:`.Session.rollback` method is called, if the underlying
+ DBAPI transaction has already been rolled back. In many
+ cases, the :class:`.Session` will not be in
+ an "active" state during this event, as the current
+ transaction is not valid. To acquire a :class:`.Session`
+ which is active after the outermost rollback has proceeded,
+ use the :meth:`.SessionEvents.after_soft_rollback` event, checking the
+ :attr:`.Session.is_active` flag.
+
+ :param session: The target :class:`.Session`.
+
+ """
+
+ def after_soft_rollback(self, session, previous_transaction):
+ """Execute after any rollback has occurred, including "soft"
+ rollbacks that don't actually emit at the DBAPI level.
+
+ This corresponds to both nested and outer rollbacks, i.e.
+ the innermost rollback that calls the DBAPI's
+ rollback() method, as well as the enclosing rollback
+ calls that only pop themselves from the transaction stack.
+
+ The given :class:`.Session` can be used to invoke SQL and
+ :meth:`.Session.query` operations after an outermost rollback
+ by first checking the :attr:`.Session.is_active` flag::
+
+ @event.listens_for(Session, "after_soft_rollback")
+ def do_something(session, previous_transaction):
+ if session.is_active:
+ session.execute("select * from some_table")
+
+ :param session: The target :class:`.Session`.
+ :param previous_transaction: The :class:`.SessionTransaction`
+ transactional marker object which was just closed. The current
+ :class:`.SessionTransaction` for the given :class:`.Session` is
+ available via the :attr:`.Session.transaction` attribute.
+
+ """
+
+ def before_flush(self, session, flush_context, instances):
+ """Execute before flush process has started.
+
+ :param session: The target :class:`.Session`.
+ :param flush_context: Internal :class:`.UOWTransaction` object
+ which handles the details of the flush.
+ :param instances: Usually ``None``, this is the collection of
+ objects which can be passed to the :meth:`.Session.flush` method
+ (note this usage is deprecated).
+
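+        An illustrative sketch; inspecting the :attr:`.Session.new` and
+        :attr:`.Session.dirty` collections here is a common pattern, and
+        the print calls are placeholders only::
+
+            @event.listens_for(Session, "before_flush")
+            def receive_before_flush(session, flush_context, instances):
+                for obj in session.new:
+                    # objects pending INSERT in this flush
+                    print("pending for INSERT:", obj)
+                for obj in session.dirty:
+                    # objects with changes pending UPDATE
+                    print("pending for UPDATE:", obj)
+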
+ .. seealso::
+
+ :meth:`~.SessionEvents.after_flush`
+
+ :meth:`~.SessionEvents.after_flush_postexec`
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def after_flush(self, session, flush_context):
+ """Execute after flush has completed, but before commit has been
+ called.
+
+ Note that the session's state is still in pre-flush, i.e. 'new',
+ 'dirty', and 'deleted' lists still show pre-flush state as well
+ as the history settings on instance attributes.
+
+ .. warning:: This event runs after the :class:`.Session` has emitted
+ SQL to modify the database, but **before** it has altered its
+ internal state to reflect those changes, including that newly
+ inserted objects are placed into the identity map. ORM operations
+ emitted within this event such as loads of related items
+ may produce new identity map entries that will immediately
+ be replaced, sometimes causing confusing results. SQLAlchemy will
+ emit a warning for this condition as of version 1.3.9.
+
+ :param session: The target :class:`.Session`.
+ :param flush_context: Internal :class:`.UOWTransaction` object
+ which handles the details of the flush.
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.before_flush`
+
+ :meth:`~.SessionEvents.after_flush_postexec`
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def after_flush_postexec(self, session, flush_context):
+ """Execute after flush has completed, and after the post-exec
+ state occurs.
+
+ This will be when the 'new', 'dirty', and 'deleted' lists are in
+ their final state. An actual commit() may or may not have
+ occurred, depending on whether or not the flush started its own
+ transaction or participated in a larger transaction.
+
+ :param session: The target :class:`.Session`.
+ :param flush_context: Internal :class:`.UOWTransaction` object
+ which handles the details of the flush.
+
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.before_flush`
+
+ :meth:`~.SessionEvents.after_flush`
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def after_begin(self, session, transaction, connection):
+ """Execute after a transaction is begun on a connection
+
+ :param session: The target :class:`.Session`.
+ :param transaction: The :class:`.SessionTransaction`.
+ :param connection: The :class:`_engine.Connection` object
+ which will be used for SQL statements.
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.before_commit`
+
+ :meth:`~.SessionEvents.after_commit`
+
+ :meth:`~.SessionEvents.after_transaction_create`
+
+ :meth:`~.SessionEvents.after_transaction_end`
+
+ """
+
+ @_lifecycle_event
+ def before_attach(self, session, instance):
+ """Execute before an instance is attached to a session.
+
+ This is called before an add, delete or merge causes
+ the object to be part of the session.
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.after_attach`
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def after_attach(self, session, instance):
+ """Execute after an instance is attached to a session.
+
+ This is called after an add, delete or merge.
+
+ .. note::
+
+ As of 0.8, this event fires off *after* the item
+ has been fully associated with the session, which is
+ different than previous releases. For event
+ handlers that require the object not yet
+ be part of session state (such as handlers which
+ may autoflush while the target object is not
+ yet complete) consider the
+ new :meth:`.before_attach` event.
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.before_attach`
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @event._legacy_signature(
+ "0.9",
+ ["session", "query", "query_context", "result"],
+ lambda update_context: (
+ update_context.session,
+ update_context.query,
+ None,
+ update_context.result,
+ ),
+ )
+ def after_bulk_update(self, update_context):
+ """Execute after an ORM UPDATE against a WHERE expression has been
+ invoked.
+
+ This is called as a result of the :meth:`_query.Query.update` method.
+
+ :param update_context: an "update context" object which contains
+ details about the update, including these attributes:
+
+ * ``session`` - the :class:`.Session` involved
+          * ``query`` - the :class:`_query.Query`
+            object that this update operation
+            was called upon.
+          * ``values`` - the "values" dictionary that was passed to
+            :meth:`_query.Query.update`.
+          * ``result`` - the :class:`_engine.CursorResult`
+            returned as a result of the
+            bulk UPDATE operation.
+
+ .. versionchanged:: 1.4 the update_context no longer has a
+ ``QueryContext`` object associated with it.
+
+ .. seealso::
+
+ :meth:`.QueryEvents.before_compile_update`
+
+ :meth:`.SessionEvents.after_bulk_delete`
+
+ """
+
+ @event._legacy_signature(
+ "0.9",
+ ["session", "query", "query_context", "result"],
+ lambda delete_context: (
+ delete_context.session,
+ delete_context.query,
+ None,
+ delete_context.result,
+ ),
+ )
+ def after_bulk_delete(self, delete_context):
+ """Execute after ORM DELETE against a WHERE expression has been
+ invoked.
+
+ This is called as a result of the :meth:`_query.Query.delete` method.
+
+ :param delete_context: a "delete context" object which contains
+         details about the delete, including these attributes:
+
+          * ``session`` - the :class:`.Session` involved
+          * ``query`` - the :class:`_query.Query`
+            object that this delete operation
+            was called upon.
+          * ``result`` - the :class:`_engine.CursorResult`
+            returned as a result of the
+            bulk DELETE operation.
+
+        .. versionchanged:: 1.4 the delete_context no longer has a
+ ``QueryContext`` object associated with it.
+
+ .. seealso::
+
+ :meth:`.QueryEvents.before_compile_delete`
+
+ :meth:`.SessionEvents.after_bulk_update`
+
+ """
+
+ @_lifecycle_event
+ def transient_to_pending(self, session, instance):
+ """Intercept the "transient to pending" transition for a specific
+ object.
+
+ This event is a specialization of the
+ :meth:`.SessionEvents.after_attach` event which is only invoked
+ for this specific transition. It is invoked typically during the
+ :meth:`.Session.add` call.
+
+ :param session: target :class:`.Session`
+
+ :param instance: the ORM-mapped instance being operated upon.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def pending_to_transient(self, session, instance):
+ """Intercept the "pending to transient" transition for a specific
+ object.
+
+        This less common transition occurs when a pending object that has
+ not been flushed is evicted from the session; this can occur
+ when the :meth:`.Session.rollback` method rolls back the transaction,
+ or when the :meth:`.Session.expunge` method is used.
+
+ :param session: target :class:`.Session`
+
+ :param instance: the ORM-mapped instance being operated upon.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def persistent_to_transient(self, session, instance):
+ """Intercept the "persistent to transient" transition for a specific
+ object.
+
+        This less common transition occurs when a pending object that
+        has been flushed is evicted from the session; this can occur
+ when the :meth:`.Session.rollback` method rolls back the transaction.
+
+ :param session: target :class:`.Session`
+
+ :param instance: the ORM-mapped instance being operated upon.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def pending_to_persistent(self, session, instance):
+ """Intercept the "pending to persistent"" transition for a specific
+ object.
+
+ This event is invoked within the flush process, and is
+ similar to scanning the :attr:`.Session.new` collection within
+ the :meth:`.SessionEvents.after_flush` event. However, in this
+ case the object has already been moved to the persistent state
+ when the event is called.
+
+ :param session: target :class:`.Session`
+
+ :param instance: the ORM-mapped instance being operated upon.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def detached_to_persistent(self, session, instance):
+ """Intercept the "detached to persistent" transition for a specific
+ object.
+
+ This event is a specialization of the
+ :meth:`.SessionEvents.after_attach` event which is only invoked
+ for this specific transition. It is invoked typically during the
+ :meth:`.Session.add` call, as well as during the
+ :meth:`.Session.delete` call if the object was not previously
+ associated with the
+ :class:`.Session` (note that an object marked as "deleted" remains
+ in the "persistent" state until the flush proceeds).
+
+ .. note::
+
+ If the object becomes persistent as part of a call to
+ :meth:`.Session.delete`, the object is **not** yet marked as
+ deleted when this event is called. To detect deleted objects,
+ check the ``deleted`` flag sent to the
+ :meth:`.SessionEvents.persistent_to_detached` event after the
+ flush proceeds, or check the :attr:`.Session.deleted` collection
+ within the :meth:`.SessionEvents.before_flush` event if deleted
+ objects need to be intercepted before the flush.
+
+ :param session: target :class:`.Session`
+
+ :param instance: the ORM-mapped instance being operated upon.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def loaded_as_persistent(self, session, instance):
+ """Intercept the "loaded as persistent" transition for a specific
+ object.
+
+ This event is invoked within the ORM loading process, and is invoked
+ very similarly to the :meth:`.InstanceEvents.load` event. However,
+ the event here is linkable to a :class:`.Session` class or instance,
+ rather than to a mapper or class hierarchy, and integrates
+ with the other session lifecycle events smoothly. The object
+ is guaranteed to be present in the session's identity map when
+ this event is called.
+
+ .. note:: This event is invoked within the loader process before
+ eager loaders may have been completed, and the object's state may
+ not be complete. Additionally, invoking row-level refresh
+ operations on the object will place the object into a new loader
+ context, interfering with the existing load context. See the note
+ on :meth:`.InstanceEvents.load` for background on making use of the
+ :paramref:`.SessionEvents.restore_load_context` parameter, which
+ works in the same manner as that of
+ :paramref:`.InstanceEvents.restore_load_context`, in order to
+ resolve this scenario.
+
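+ A listener that makes use of this parameter might be registered as
+ follows (a minimal sketch; the listener name and its body are
+ illustrative)::
+
+ from sqlalchemy import event
+ from sqlalchemy.orm import Session
+
+ @event.listens_for(
+ Session, "loaded_as_persistent", restore_load_context=True)
+ def on_loaded_as_persistent(session, instance):
+ # work that may disturb the object's load context can be
+ # performed here; the load context is restored afterwards
+ print("loaded: %r" % instance)
+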
+ :param session: target :class:`.Session`
+
+ :param instance: the ORM-mapped instance being operated upon.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def persistent_to_deleted(self, session, instance):
+ """Intercept the "persistent to deleted" transition for a specific
+ object.
+
+ This event is invoked when a persistent object's identity
+ is deleted from the database within a flush; however, the object
+ still remains associated with the :class:`.Session` until the
+ transaction completes.
+
+ If the transaction is rolled back, the object moves again
+ to the persistent state, and the
+ :meth:`.SessionEvents.deleted_to_persistent` event is called.
+ If the transaction is committed, the object becomes detached,
+ which will emit the :meth:`.SessionEvents.deleted_to_detached`
+ event.
+
+ Note that while the :meth:`.Session.delete` method is the primary
+ public interface to mark an object as deleted, many objects
+ get deleted due to cascade rules, which are not always determined
+ until flush time. Therefore, there's no way to catch
+ every object that will be deleted until the flush has proceeded;
+ for this reason, the :meth:`.SessionEvents.persistent_to_deleted`
+ event is invoked at the end of a flush.
+
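+ For example, deleted objects may be collected per flush as follows
+ (a minimal sketch; the listener name and the set are illustrative)::
+
+ from sqlalchemy import event
+ from sqlalchemy.orm import Session
+
+ deleted_objects = set()
+
+ @event.listens_for(Session, "persistent_to_deleted")
+ def on_persistent_to_deleted(session, instance):
+ # invoked at the end of the flush that deleted the row
+ deleted_objects.add(instance)
+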
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def deleted_to_persistent(self, session, instance):
+ """Intercept the "deleted to persistent" transition for a specific
+ object.
+
+ This transition occurs only when an object that's been deleted
+ successfully in a flush is restored due to a call to
+ :meth:`.Session.rollback`. The event is not called under
+ any other circumstances.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def deleted_to_detached(self, session, instance):
+ """Intercept the "deleted to detached" transition for a specific
+ object.
+
+ This event is invoked when a deleted object is evicted
+ from the session. The typical case when this occurs is when
+ the transaction for a :class:`.Session` in which the object
+ was deleted is committed; the object moves from the deleted
+ state to the detached state.
+
+ It is also invoked for objects that were deleted in a flush
+ when the :meth:`.Session.expunge_all` or :meth:`.Session.close`
+ methods are called, as well as if the object is individually
+ expunged from its deleted state via :meth:`.Session.expunge`.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def persistent_to_detached(self, session, instance):
+ """Intercept the "persistent to detached" transition for a specific
+ object.
+
+ This event is invoked when a persistent object is evicted
+ from the session. There are many conditions that cause this
+ to happen, including:
+
+ * using a method such as :meth:`.Session.expunge`
+ or :meth:`.Session.close`
+
+ * Calling the :meth:`.Session.rollback` method, when the object
+ was part of an INSERT statement for that session's transaction
+
+
+ :param session: target :class:`.Session`
+
+ :param instance: the ORM-mapped instance being operated upon.
+
+ :param deleted: boolean. If True, indicates this object moved
+ to the detached state because it was marked as deleted and flushed.
+
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+
+class AttributeEvents(event.Events):
+ r"""Define events for object attributes.
+
+ These are typically defined on the class-bound descriptor for the
+ target class.
+
+ For example, to register a listener that will receive the
+ :meth:`_orm.AttributeEvents.append` event::
+
+ from sqlalchemy import event
+
+ @event.listens_for(MyClass.collection, 'append', propagate=True)
+ def my_append_listener(target, value, initiator):
+ print("received append event for target: %s" % target)
+
+
+ Listeners have the option to return a possibly modified version of the
+ value, when the :paramref:`.AttributeEvents.retval` flag is passed to
+ :func:`.event.listen` or :func:`.event.listens_for`, such as below,
+ illustrated using the :meth:`_orm.AttributeEvents.set` event::
+
+ def validate_phone(target, value, oldvalue, initiator):
+ "Strip non-numeric characters from a phone number"
+
+ return re.sub(r'\D', '', value)
+
+ # setup listener on UserContact.phone attribute, instructing
+ # it to use the return value
+ listen(UserContact.phone, 'set', validate_phone, retval=True)
+
+ A validation function like the above can also raise an exception
+ such as :exc:`ValueError` to halt the operation.
+
+ The :paramref:`.AttributeEvents.propagate` flag is also important when
+ applying listeners to mapped classes that also have mapped subclasses,
+ as when using mapper inheritance patterns::
+
+
+ @event.listens_for(MySuperClass.attr, 'set', propagate=True)
+ def receive_set(target, value, initiator):
+ print("value set: %s" % target)
+
+ The full list of modifiers available to the :func:`.event.listen`
+ and :func:`.event.listens_for` functions is below.
+
+ :param active_history=False: When True, indicates that the
+ "set" event would like to receive the "old" value being
+ replaced unconditionally, even if this requires firing off
+ database loads. Note that ``active_history`` can also be
+ set directly via :func:`.column_property` and
+ :func:`_orm.relationship`.
+
+ :param propagate=False: When True, the listener function will
+ be established not just for the class attribute given, but
+ for attributes of the same name on all current subclasses
+ of that class, as well as all future subclasses of that
+ class, using an additional listener that listens for
+ instrumentation events.
+ :param raw=False: When True, the "target" argument to the
+ event will be the :class:`.InstanceState` management
+ object, rather than the mapped instance itself.
+ :param retval=False: when True, the user-defined event
+ listening function must return the "value" argument from the
+ function. This gives the listening function the opportunity
+ to change the value that is ultimately used for a "set"
+ or "append" event.
+
+ """
+
+ _target_class_doc = "SomeClass.some_attribute"
+ _dispatch_target = QueryableAttribute
+
+ @staticmethod
+ def _set_dispatch(cls, dispatch_cls):
+ dispatch = event.Events._set_dispatch(cls, dispatch_cls)
+ dispatch_cls._active_history = False
+ return dispatch
+
+ @classmethod
+ def _accept_with(cls, target):
+ # TODO: coverage
+ if isinstance(target, interfaces.MapperProperty):
+ return getattr(target.parent.class_, target.key)
+ else:
+ return target
+
+ @classmethod
+ def _listen(
+ cls,
+ event_key,
+ active_history=False,
+ raw=False,
+ retval=False,
+ propagate=False,
+ ):
+
+ target, fn = event_key.dispatch_target, event_key._listen_fn
+
+ if active_history:
+ target.dispatch._active_history = True
+
+ if not raw or not retval:
+
+ def wrap(target, *arg):
+ if not raw:
+ target = target.obj()
+ if not retval:
+ if arg:
+ value = arg[0]
+ else:
+ value = None
+ fn(target, *arg)
+ return value
+ else:
+ return fn(target, *arg)
+
+ event_key = event_key.with_wrapper(wrap)
+
+ event_key.base_listen(propagate=propagate)
+
+ if propagate:
+ manager = instrumentation.manager_of_class(target.class_)
+
+ for mgr in manager.subclass_managers(True):
+ event_key.with_dispatch_target(mgr[target.key]).base_listen(
+ propagate=True
+ )
+ if active_history:
+ mgr[target.key].dispatch._active_history = True
+
+ def append(self, target, value, initiator):
+ """Receive a collection append event.
+
+ The append event is invoked for each element as it is appended
+ to the collection. This occurs for single-item appends as well
+ as for a "bulk replace" operation.
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param value: the value being appended. If this listener
+ is registered with ``retval=True``, the listener
+ function must return this value, or a new value which
+ replaces it.
+ :param initiator: An instance of :class:`.attributes.Event`
+ representing the initiation of the event. May be modified
+ from its original value by backref handlers in order to control
+ chained event propagation, as well as be inspected for information
+ about the source of the event.
+ :return: if the event was registered with ``retval=True``,
+ the given value, or a new effective value, should be returned.
+
+ .. seealso::
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+ :meth:`.AttributeEvents.bulk_replace`
+
+ """
+
+ def append_wo_mutation(self, target, value, initiator):
+ """Receive a collection append event where the collection was not
+ actually mutated.
+
+ This event differs from :meth:`_orm.AttributeEvents.append` in that
+ it is fired off for de-duplicating collections such as sets and
+ dictionaries, when the object already exists in the target collection.
+ The event does not have a return value and the identity of the
+ given object cannot be changed.
+
+ The event is used for cascading objects into a :class:`_orm.Session`
+ when the collection has already been mutated via a backref event.
+
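+ A listener might use this hook to observe duplicate appends against a
+ de-duplicating collection (a minimal sketch; ``SomeObject.collection``
+ and the listener name are illustrative)::
+
+ from sqlalchemy import event
+
+ @event.listens_for(SomeObject.collection, "append_wo_mutation")
+ def on_duplicate_append(target, value, initiator):
+ # ``value`` was already present; the collection is unchanged
+ print("duplicate append of %r on %r" % (value, target))
+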
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param value: the value that would be appended if the object did not
+ already exist in the collection.
+ :param initiator: An instance of :class:`.attributes.Event`
+ representing the initiation of the event. May be modified
+ from its original value by backref handlers in order to control
+ chained event propagation, as well as be inspected for information
+ about the source of the event.
+
+ :return: No return value is defined for this event.
+
+ .. versionadded:: 1.4.15
+
+ """
+
+ def bulk_replace(self, target, values, initiator):
+ """Receive a collection 'bulk replace' event.
+
+ This event is invoked for a sequence of values as they are incoming
+ to a bulk collection set operation, which can be
+ modified in place before the values are treated as ORM objects.
+ This is an "early hook" that runs before the bulk replace routine
+ attempts to reconcile which objects are already present in the
+ collection and which are being removed by the net replace operation.
+
+ It is typical that this method be combined with use of the
+ :meth:`.AttributeEvents.append` event. When using both of these
+ events, note that a bulk replace operation will invoke
+ the :meth:`.AttributeEvents.append` event for all new items,
+ even after :meth:`.AttributeEvents.bulk_replace` has been invoked
+ for the collection as a whole. In order to determine if an
+ :meth:`.AttributeEvents.append` event is part of a bulk replace,
+ use the symbol :attr:`~.attributes.OP_BULK_REPLACE` to test the
+ incoming initiator::
+
+ from sqlalchemy.orm.attributes import OP_BULK_REPLACE
+
+ @event.listens_for(SomeObject.collection, "bulk_replace")
+ def process_collection(target, values, initiator):
+ values[:] = [_make_value(value) for value in values]
+
+ @event.listens_for(SomeObject.collection, "append", retval=True)
+ def process_collection(target, value, initiator):
+ # make sure bulk_replace didn't already do it
+ if initiator is None or initiator.op is not OP_BULK_REPLACE:
+ return _make_value(value)
+ else:
+ return value
+
+ .. versionadded:: 1.2
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param values: a sequence (e.g. a list) of the values being set. The
+ handler can modify this list in place.
+ :param initiator: An instance of :class:`.attributes.Event`
+ representing the initiation of the event.
+
+ .. seealso::
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+
+ """
+
+ def remove(self, target, value, initiator):
+ """Receive a collection remove event.
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param value: the value being removed.
+ :param initiator: An instance of :class:`.attributes.Event`
+ representing the initiation of the event. May be modified
+ from its original value by backref handlers in order to control
+ chained event propagation.
+
+ .. versionchanged:: 0.9.0 the ``initiator`` argument is now
+ passed as a :class:`.attributes.Event` object, and may be
+ modified by backref handlers within a chain of backref-linked
+ events.
+
+ :return: No return value is defined for this event.
+
+
+ .. seealso::
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+ """
+
+ def set(self, target, value, oldvalue, initiator):
+ """Receive a scalar set event.
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param value: the value being set. If this listener
+ is registered with ``retval=True``, the listener
+ function must return this value, or a new value which
+ replaces it.
+ :param oldvalue: the previous value being replaced. This
+ may also be the symbol ``NEVER_SET`` or ``NO_VALUE``.
+ If the listener is registered with ``active_history=True``,
+ the previous value of the attribute will be loaded from
+ the database if the existing value is currently unloaded
+ or expired.
+ :param initiator: An instance of :class:`.attributes.Event`
+ representing the initiation of the event. May be modified
+ from its original value by backref handlers in order to control
+ chained event propagation.
+
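+ For example, a listener registered with ``active_history=True`` can
+ reliably compare the old and new values (a minimal sketch;
+ ``MyClass.some_attribute`` is illustrative)::
+
+ from sqlalchemy import event
+
+ @event.listens_for(
+ MyClass.some_attribute, "set", active_history=True)
+ def on_set(target, value, oldvalue, initiator):
+ if value is not oldvalue:
+ print("changing from %r to %r" % (oldvalue, value))
+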
+ .. versionchanged:: 0.9.0 the ``initiator`` argument is now
+ passed as a :class:`.attributes.Event` object, and may be
+ modified by backref handlers within a chain of backref-linked
+ events.
+
+ :return: if the event was registered with ``retval=True``,
+ the given value, or a new effective value, should be returned.
+
+ .. seealso::
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+ """
+
+ def init_scalar(self, target, value, dict_):
+ r"""Receive a scalar "init" event.
+
+ This event is invoked when an uninitialized, unpersisted scalar
+ attribute is accessed, e.g. read::
+
+
+ x = my_object.some_attribute
+
+ The ORM's default behavior when this occurs for an un-initialized
+ attribute is to return the value ``None``; note this differs from
+ Python's usual behavior of raising ``AttributeError``. The
+ event here can be used to customize what value is actually returned,
+ with the assumption that the event listener would be mirroring
+ a default generator that is configured on the Core
+ :class:`_schema.Column`
+ object as well.
+
+ Since a default generator on a :class:`_schema.Column`
+ might also produce
+ a changing value such as a timestamp, the
+ :meth:`.AttributeEvents.init_scalar`
+ event handler can also be used to **set** the newly returned value, so
+ that a Core-level default generation function effectively fires off
+ only once, but at the moment the attribute is accessed on the
+ non-persisted object. Normally, no change to the object's state
+ is made when an uninitialized attribute is accessed (much older
+ SQLAlchemy versions did in fact change the object's state).
+
+ If a default generator on a column returned a particular constant,
+ a handler might be used as follows::
+
+ SOME_CONSTANT = 3.1415926
+
+ class MyClass(Base):
+ # ...
+
+ some_attribute = Column(Numeric, default=SOME_CONSTANT)
+
+ @event.listens_for(
+ MyClass.some_attribute, "init_scalar",
+ retval=True, propagate=True)
+ def _init_some_attribute(target, value, dict_):
+ dict_['some_attribute'] = SOME_CONSTANT
+ return SOME_CONSTANT
+
+ Above, we initialize the attribute ``MyClass.some_attribute`` to the
+ value of ``SOME_CONSTANT``. The above code includes the following
+ features:
+
+ * By setting the value ``SOME_CONSTANT`` in the given ``dict_``,
+ we indicate that this value is to be persisted to the database.
+ This supersedes the use of ``SOME_CONSTANT`` in the default generator
+ for the :class:`_schema.Column`. The ``active_column_defaults.py``
+ example given at :ref:`examples_instrumentation` illustrates using
+ the same approach for a changing default, e.g. a timestamp
+ generator. In this particular example, it is not strictly
+ necessary to do this since ``SOME_CONSTANT`` would be part of the
+ INSERT statement in either case.
+
+ * By establishing the ``retval=True`` flag, the value we return
+ from the function will be returned by the attribute getter.
+ Without this flag, the event is assumed to be a passive observer
+ and the return value of our function is ignored.
+
+ * The ``propagate=True`` flag is significant if the mapped class
+ includes inheriting subclasses, which would also make use of this
+ event listener. Without this flag, an inheriting subclass will
+ not use our event handler.
+
+ In the above example, the attribute set event
+ :meth:`.AttributeEvents.set` as well as the related validation feature
+ provided by :obj:`_orm.validates` is **not** invoked when we apply our
+ value to the given ``dict_``. To have these events invoked in
+ response to our newly generated value, apply the value to the given
+ object as a normal attribute set operation::
+
+ SOME_CONSTANT = 3.1415926
+
+ @event.listens_for(
+ MyClass.some_attribute, "init_scalar",
+ retval=True, propagate=True)
+ def _init_some_attribute(target, value, dict_):
+ # will also fire off attribute set events
+ target.some_attribute = SOME_CONSTANT
+ return SOME_CONSTANT
+
+ When multiple listeners are set up, the generation of the value
+ is "chained" from one listener to the next by passing the value
+ returned by the previous listener that specifies ``retval=True``
+ as the ``value`` argument of the next listener.
+
+ .. versionadded:: 1.1
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param value: the value that is to be returned before this event
+ listener is invoked. This value begins as the value ``None``,
+ however it will be the return value of the previous event handler
+ function if multiple listeners are present.
+ :param dict\_: the attribute dictionary of this mapped object.
+ This is normally the ``__dict__`` of the object, but in all cases
+ represents the destination that the attribute system uses to get
+ at the actual value of this attribute. Placing the value in this
+ dictionary has the effect that the value will be used in the
+ INSERT statement generated by the unit of work.
+
+
+ .. seealso::
+
+ :meth:`.AttributeEvents.init_collection` - collection version
+ of this event
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+ :ref:`examples_instrumentation` - see the
+ ``active_column_defaults.py`` example.
+
+ """
+
+ def init_collection(self, target, collection, collection_adapter):
+ """Receive a 'collection init' event.
+
+ This event is triggered for a collection-based attribute, when
+ the initial "empty collection" is first generated for a blank
+ attribute, as well as for when the collection is replaced with
+ a new one, such as via a set event.
+
+ E.g., given that ``User.addresses`` is a relationship-based
+ collection, the event is triggered here::
+
+ u1 = User()
+ u1.addresses.append(a1) # <- new collection
+
+ and also during replace operations::
+
+ u1.addresses = [a2, a3] # <- new collection
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param collection: the new collection. This will always be generated
+ from what was specified as
+ :paramref:`_orm.relationship.collection_class`, and will always
+ be empty.
+ :param collection_adapter: the :class:`.CollectionAdapter` that will
+ mediate internal access to the collection.
+
+ .. versionadded:: 1.0.0 :meth:`.AttributeEvents.init_collection`
+ and :meth:`.AttributeEvents.dispose_collection` events.
+
+ .. seealso::
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+ :meth:`.AttributeEvents.init_scalar` - "scalar" version of this
+ event.
+
+ """
+
+ def dispose_collection(self, target, collection, collection_adapter):
+ """Receive a 'collection dispose' event.
+
+ This event is triggered for a collection-based attribute when
+ a collection is replaced, that is::
+
+ u1.addresses.append(a1)
+
+ u1.addresses = [a2, a3] # <- old collection is disposed
+
+ The old collection received will contain its previous contents.
+
+ .. versionchanged:: 1.2 The collection passed to
+ :meth:`.AttributeEvents.dispose_collection` will now have its
+ contents before the dispose intact; previously, the collection
+ would be empty.
+
+ .. versionadded:: 1.0.0 the :meth:`.AttributeEvents.init_collection`
+ and :meth:`.AttributeEvents.dispose_collection` events.
+
+ .. seealso::
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+ """
+
+ def modified(self, target, initiator):
+ """Receive a 'modified' event.
+
+ This event is triggered when the :func:`.attributes.flag_modified`
+ function is used to trigger a modify event on an attribute without
+ any specific value being set.
+
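+ As a minimal sketch, a listener paired with a call to
+ :func:`.attributes.flag_modified` might look like the following
+ (``MyClass.data`` and ``some_object`` are illustrative)::
+
+ from sqlalchemy import event
+ from sqlalchemy.orm.attributes import flag_modified
+
+ @event.listens_for(MyClass.data, "modified")
+ def on_modified(target, initiator):
+ print("attribute 'data' flagged as modified on %r" % target)
+
+ # elsewhere; this triggers the listener above
+ flag_modified(some_object, "data")
+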
+ .. versionadded:: 1.2
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+
+ :param initiator: An instance of :class:`.attributes.Event`
+ representing the initiation of the event.
+
+ .. seealso::
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+ """
+
+
+class QueryEvents(event.Events):
+ """Represent events within the construction of a :class:`_query.Query`
+ object.
+
+ The :class:`_orm.QueryEvents` hooks are now superseded by the
+ :meth:`_orm.SessionEvents.do_orm_execute` event hook.
+
+ """
+
+ _target_class_doc = "SomeQuery"
+ _dispatch_target = Query
+
+ def before_compile(self, query):
+ """Receive the :class:`_query.Query`
+ object before it is composed into a
+ core :class:`_expression.Select` object.
+
+ .. deprecated:: 1.4 The :meth:`_orm.QueryEvents.before_compile` event
+ is superseded by the much more capable
+ :meth:`_orm.SessionEvents.do_orm_execute` hook. In version 1.4,
+ the :meth:`_orm.QueryEvents.before_compile` event is **no longer
+ used** for ORM-level attribute loads, such as loads of deferred
+ or expired attributes as well as relationship loaders. See the
+ new examples in :ref:`examples_session_orm_events` which
+ illustrate new ways of intercepting and modifying ORM queries
+ for the most common purpose of adding arbitrary filter criteria.
+
+
+ This event is intended to allow changes to the query given::
+
+ @event.listens_for(Query, "before_compile", retval=True)
+ def no_deleted(query):
+ for desc in query.column_descriptions:
+ if desc['type'] is User:
+ entity = desc['entity']
+ query = query.filter(entity.deleted == False)
+ return query
+
+ The event should normally be listened with the ``retval=True``
+ parameter set, so that the modified query may be returned.
+
+ The :meth:`.QueryEvents.before_compile` event by default
+ will disallow "baked" queries from caching a query, if the event
+ hook returns a new :class:`_query.Query` object.
+ This affects both direct
+ use of the baked query extension as well as its operation within
+ lazy loaders and eager loaders for relationships. In order to
+ re-establish the query being cached, apply the event adding the
+ ``bake_ok`` flag::
+
+ @event.listens_for(
+ Query, "before_compile", retval=True, bake_ok=True)
+ def my_event(query):
+ for desc in query.column_descriptions:
+ if desc['type'] is User:
+ entity = desc['entity']
+ query = query.filter(entity.deleted == False)
+ return query
+
+ When ``bake_ok`` is set to True, the event hook will only be invoked
+ once, and not called for subsequent invocations of a particular query
+ that is being cached.
+
+ .. versionadded:: 1.3.11 - added the "bake_ok" flag to the
+ :meth:`.QueryEvents.before_compile` event and disallowed caching via
+ the "baked" extension from occurring for event handlers that
+ return a new :class:`_query.Query` object if this flag is not set.
+
+ .. seealso::
+
+ :meth:`.QueryEvents.before_compile_update`
+
+ :meth:`.QueryEvents.before_compile_delete`
+
+ :ref:`baked_with_before_compile`
+
+ """
+
+ def before_compile_update(self, query, update_context):
+ """Allow modifications to the :class:`_query.Query` object within
+ :meth:`_query.Query.update`.
+
+ .. deprecated:: 1.4 The :meth:`_orm.QueryEvents.before_compile_update`
+ event is superseded by the much more capable
+ :meth:`_orm.SessionEvents.do_orm_execute` hook.
+
+ Like the :meth:`.QueryEvents.before_compile` event, if the event
+ is to be used to alter the :class:`_query.Query` object, it should
+ be configured with ``retval=True``, and the modified
+ :class:`_query.Query` object returned, as in ::
+
+ @event.listens_for(Query, "before_compile_update", retval=True)
+ def no_deleted(query, update_context):
+ for desc in query.column_descriptions:
+ if desc['type'] is User:
+ entity = desc['entity']
+ query = query.filter(entity.deleted == False)
+
+ update_context.values['timestamp'] = datetime.utcnow()
+ return query
+
+ The ``.values`` dictionary of the "update context" object can also
+ be modified in place as illustrated above.
+
+ :param query: a :class:`_query.Query` instance; this is also
+ the ``.query`` attribute of the given "update context"
+ object.
+
+ :param update_context: an "update context" object which is
+ the same kind of object as described in
+ :paramref:`.QueryEvents.after_bulk_update.update_context`.
+ The object has a ``.values`` attribute in an UPDATE context which is
+ the dictionary of parameters passed to :meth:`_query.Query.update`.
+ This
+ dictionary can be modified to alter the SET clause of the
+ resulting UPDATE statement.
+
+ .. versionadded:: 1.2.17
+
+ .. seealso::
+
+ :meth:`.QueryEvents.before_compile`
+
+ :meth:`.QueryEvents.before_compile_delete`
+
+
+ """
+
+ def before_compile_delete(self, query, delete_context):
+ """Allow modifications to the :class:`_query.Query` object within
+ :meth:`_query.Query.delete`.
+
+ .. deprecated:: 1.4 The :meth:`_orm.QueryEvents.before_compile_delete`
+ event is superseded by the much more capable
+ :meth:`_orm.SessionEvents.do_orm_execute` hook.
+
+ Like the :meth:`.QueryEvents.before_compile` event, this event
+ should be configured with ``retval=True``, and the modified
+ :class:`_query.Query` object returned, as in ::
+
+ @event.listens_for(Query, "before_compile_delete", retval=True)
+ def no_deleted(query, delete_context):
+ for desc in query.column_descriptions:
+ if desc['type'] is User:
+ entity = desc['entity']
+ query = query.filter(entity.deleted == False)
+ return query
+
+ :param query: a :class:`_query.Query` instance; this is also
+ the ``.query`` attribute of the given "delete context"
+ object.
+
+ :param delete_context: a "delete context" object which is
+ the same kind of object as described in
+ :paramref:`.QueryEvents.after_bulk_delete.delete_context`.
+
+ .. versionadded:: 1.2.17
+
+ .. seealso::
+
+ :meth:`.QueryEvents.before_compile`
+
+ :meth:`.QueryEvents.before_compile_update`
+
+
+ """
+
+ @classmethod
+ def _listen(cls, event_key, retval=False, bake_ok=False, **kw):
+ fn = event_key._listen_fn
+
+ if not retval:
+
+ def wrap(*arg, **kw):
+ if not retval:
+ query = arg[0]
+ fn(*arg, **kw)
+ return query
+ else:
+ return fn(*arg, **kw)
+
+ event_key = event_key.with_wrapper(wrap)
+ else:
+ # don't assume we can apply an attribute to the callable
+ def wrap(*arg, **kw):
+ return fn(*arg, **kw)
+
+ event_key = event_key.with_wrapper(wrap)
+
+ wrap._bake_ok = bake_ok
+
+ event_key.base_listen(**kw)
diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py
new file mode 100644
index 0000000..8dd4d90
--- /dev/null
+++ b/lib/sqlalchemy/orm/exc.py
@@ -0,0 +1,204 @@
+# orm/exc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""SQLAlchemy ORM exceptions."""
+from .. import exc as sa_exc
+from .. import util
+from ..exc import MultipleResultsFound # noqa
+from ..exc import NoResultFound # noqa
+
+
+NO_STATE = (AttributeError, KeyError)
+"""Exception types that may be raised by instrumentation implementations."""
+
+
+class StaleDataError(sa_exc.SQLAlchemyError):
+ """An operation encountered database state that is unaccounted for.
+
+ Conditions which cause this to happen include:
+
+ * A flush may have attempted to update or delete rows
+ and an unexpected number of rows were matched during
+ the UPDATE or DELETE statement. Note that when
+ version_id_col is used, rows in UPDATE or DELETE statements
+ are also matched against the current known version
+ identifier.
+
+ * A mapped object with version_id_col was refreshed,
+ and the version number coming back from the database does
+ not match that of the object itself.
+
+ * An object is detached from its parent object, however
+ the object was previously attached to a different parent
+ identity which was garbage collected, and a decision
+ cannot be made if the new parent was really the most
+ recent "parent".
+
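+ Such a condition typically surfaces at flush time and is handled
+ around the flush point, e.g. (a minimal sketch; ``session`` is
+ illustrative)::
+
+ from sqlalchemy.orm.exc import StaleDataError
+
+ try:
+ session.commit()
+ except StaleDataError:
+ session.rollback()
+ # reload current database state and retry, or report
+ # a concurrency error to the caller
+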
+ """
+
+
+ConcurrentModificationError = StaleDataError
+
+
+class FlushError(sa_exc.SQLAlchemyError):
+ """A invalid condition was detected during flush()."""
+
+
+class UnmappedError(sa_exc.InvalidRequestError):
+ """Base for exceptions that involve expected mappings not present."""
+
+
+class ObjectDereferencedError(sa_exc.SQLAlchemyError):
+ """An operation cannot complete due to an object being garbage
+ collected.
+
+ """
+
+
+class DetachedInstanceError(sa_exc.SQLAlchemyError):
+ """An attempt to access unloaded attributes on a
+ mapped instance that is detached."""
+
+ code = "bhk3"
+
+
+class UnmappedInstanceError(UnmappedError):
+ """An mapping operation was requested for an unknown instance."""
+
+ @util.preload_module("sqlalchemy.orm.base")
+ def __init__(self, obj, msg=None):
+ base = util.preloaded.orm_base
+
+ if not msg:
+ try:
+ base.class_mapper(type(obj))
+ name = _safe_cls_name(type(obj))
+ msg = (
+ "Class %r is mapped, but this instance lacks "
+ "instrumentation. This occurs when the instance "
+ "is created before sqlalchemy.orm.mapper(%s) "
+ "was called." % (name, name)
+ )
+ except UnmappedClassError:
+ msg = _default_unmapped(type(obj))
+ if isinstance(obj, type):
+ msg += (
+ "; was a class (%s) supplied where an instance was "
+ "required?" % _safe_cls_name(obj)
+ )
+ UnmappedError.__init__(self, msg)
+
+ def __reduce__(self):
+ return self.__class__, (None, self.args[0])
+
+
+class UnmappedClassError(UnmappedError):
+ """An mapping operation was requested for an unknown class."""
+
+ def __init__(self, cls, msg=None):
+ if not msg:
+ msg = _default_unmapped(cls)
+ UnmappedError.__init__(self, msg)
+
+ def __reduce__(self):
+ return self.__class__, (None, self.args[0])
+
+
+class ObjectDeletedError(sa_exc.InvalidRequestError):
+ """A refresh operation failed to retrieve the database
+ row corresponding to an object's known primary key identity.
+
+ A refresh operation proceeds when an expired attribute is
+ accessed on an object, or when :meth:`_query.Query.get` is
+ used to retrieve an object which is, upon retrieval, detected
+ as expired. A SELECT is emitted for the target row
+ based on primary key; if no row is returned, this
+ exception is raised.
+
+ The true meaning of this exception is simply that
+ no row exists for the primary key identifier associated
+ with a persistent object. The row may have been
+ deleted, or in some cases the primary key updated
+ to a new value, outside of the ORM's management of the target
+ object.
+
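+ A minimal sketch of handling this condition (``session`` and
+ ``some_object`` are illustrative)::
+
+ from sqlalchemy.orm.exc import ObjectDeletedError
+
+ try:
+ # accessing an expired attribute emits a SELECT by primary key
+ some_object.name
+ except ObjectDeletedError:
+ # the row is gone; remove the object from the session
+ session.expunge(some_object)
+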
+ """
+
+ @util.preload_module("sqlalchemy.orm.base")
+ def __init__(self, state, msg=None):
+ base = util.preloaded.orm_base
+
+ if not msg:
+ msg = (
+ "Instance '%s' has been deleted, or its "
+ "row is otherwise not present." % base.state_str(state)
+ )
+
+ sa_exc.InvalidRequestError.__init__(self, msg)
+
+ def __reduce__(self):
+ return self.__class__, (None, self.args[0])
+
+
+class UnmappedColumnError(sa_exc.InvalidRequestError):
+ """Mapping operation was requested on an unknown column."""
+
+
+class LoaderStrategyException(sa_exc.InvalidRequestError):
+ """A loader strategy for an attribute does not exist."""
+
+ def __init__(
+ self,
+ applied_to_property_type,
+ requesting_property,
+ applies_to,
+ actual_strategy_type,
+ strategy_key,
+ ):
+ if actual_strategy_type is None:
+ sa_exc.InvalidRequestError.__init__(
+ self,
+ "Can't find strategy %s for %s"
+ % (strategy_key, requesting_property),
+ )
+ else:
+ sa_exc.InvalidRequestError.__init__(
+ self,
+ 'Can\'t apply "%s" strategy to property "%s", '
+ 'which is a "%s"; this loader strategy is intended '
+ 'to be used with a "%s".'
+ % (
+ util.clsname_as_plain_name(actual_strategy_type),
+ requesting_property,
+ util.clsname_as_plain_name(applied_to_property_type),
+ util.clsname_as_plain_name(applies_to),
+ ),
+ )
+
+
+def _safe_cls_name(cls):
+ try:
+ cls_name = ".".join((cls.__module__, cls.__name__))
+ except AttributeError:
+ cls_name = getattr(cls, "__name__", None)
+ if cls_name is None:
+ cls_name = repr(cls)
+ return cls_name
+
+
+@util.preload_module("sqlalchemy.orm.base")
+def _default_unmapped(cls):
+ base = util.preloaded.orm_base
+
+ try:
+ mappers = base.manager_of_class(cls).mappers
+ except (TypeError,) + NO_STATE:
+ mappers = {}
+ name = _safe_cls_name(cls)
+
+ if not mappers:
+ return "Class '%s' is not mapped" % name
diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py
new file mode 100644
index 0000000..7de8e2c
--- /dev/null
+++ b/lib/sqlalchemy/orm/identity.py
@@ -0,0 +1,254 @@
+# orm/identity.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import weakref
+
+from . import util as orm_util
+from .. import exc as sa_exc
+from .. import util
+
+
+class IdentityMap(object):
+ def __init__(self):
+ self._dict = {}
+ self._modified = set()
+ self._wr = weakref.ref(self)
+
+ def _kill(self):
+ self._add_unpresent = _killed
+
+ def keys(self):
+ return self._dict.keys()
+
+ def replace(self, state):
+ raise NotImplementedError()
+
+ def add(self, state):
+ raise NotImplementedError()
+
+ def _add_unpresent(self, state, key):
+ """optional inlined form of add() which can assume item isn't present
+ in the map"""
+ self.add(state)
+
+ def update(self, dict_):
+ raise NotImplementedError("IdentityMap uses add() to insert data")
+
+ def clear(self):
+ raise NotImplementedError("IdentityMap uses remove() to remove data")
+
+ def _manage_incoming_state(self, state):
+ state._instance_dict = self._wr
+
+ if state.modified:
+ self._modified.add(state)
+
+ def _manage_removed_state(self, state):
+ del state._instance_dict
+ if state.modified:
+ self._modified.discard(state)
+
+ def _dirty_states(self):
+ return self._modified
+
+ def check_modified(self):
+ """return True if any InstanceStates present have been marked
+ as 'modified'.
+
+ """
+ return bool(self._modified)
+
+ def has_key(self, key):
+ return key in self
+
+ def popitem(self):
+ raise NotImplementedError("IdentityMap uses remove() to remove data")
+
+ def pop(self, key, *args):
+ raise NotImplementedError("IdentityMap uses remove() to remove data")
+
+ def setdefault(self, key, default=None):
+ raise NotImplementedError("IdentityMap uses add() to insert data")
+
+ def __len__(self):
+ return len(self._dict)
+
+ def copy(self):
+ raise NotImplementedError()
+
+ def __setitem__(self, key, value):
+ raise NotImplementedError("IdentityMap uses add() to insert data")
+
+ def __delitem__(self, key):
+ raise NotImplementedError("IdentityMap uses remove() to remove data")
+
+
+class WeakInstanceDict(IdentityMap):
+ def __getitem__(self, key):
+ state = self._dict[key]
+ o = state.obj()
+ if o is None:
+ raise KeyError(key)
+ return o
+
+ def __contains__(self, key):
+ try:
+ if key in self._dict:
+ state = self._dict[key]
+ o = state.obj()
+ else:
+ return False
+ except KeyError:
+ return False
+ else:
+ return o is not None
+
+ def contains_state(self, state):
+ if state.key in self._dict:
+ try:
+ return self._dict[state.key] is state
+ except KeyError:
+ return False
+ else:
+ return False
+
+ def replace(self, state):
+ if state.key in self._dict:
+ try:
+ existing = self._dict[state.key]
+ except KeyError:
+ # the key was removed by gc after we just checked for it
+ pass
+ else:
+ if existing is not state:
+ self._manage_removed_state(existing)
+ else:
+ return None
+ else:
+ existing = None
+
+ self._dict[state.key] = state
+ self._manage_incoming_state(state)
+ return existing
+
+ def add(self, state):
+ key = state.key
+ # inline of self.__contains__
+ if key in self._dict:
+ try:
+ existing_state = self._dict[key]
+ except KeyError:
+ # the key was removed by gc after we just checked for it
+ pass
+ else:
+ if existing_state is not state:
+ o = existing_state.obj()
+ if o is not None:
+ raise sa_exc.InvalidRequestError(
+ "Can't attach instance "
+ "%s; another instance with key %s is already "
+ "present in this session."
+ % (orm_util.state_str(state), state.key)
+ )
+ else:
+ return False
+ self._dict[key] = state
+ self._manage_incoming_state(state)
+ return True
+
+ def _add_unpresent(self, state, key):
+ # inlined form of add() called by loading.py
+ self._dict[key] = state
+ state._instance_dict = self._wr
+
+ def get(self, key, default=None):
+ if key not in self._dict:
+ return default
+ try:
+ state = self._dict[key]
+ except KeyError:
+ # the key was removed by gc after we just checked for it
+ return default
+ else:
+ o = state.obj()
+ if o is None:
+ return default
+ return o
+
+ def items(self):
+ values = self.all_states()
+ result = []
+ for state in values:
+ value = state.obj()
+ if value is not None:
+ result.append((state.key, value))
+ return result
+
+ def values(self):
+ values = self.all_states()
+ result = []
+ for state in values:
+ value = state.obj()
+ if value is not None:
+ result.append(value)
+
+ return result
+
+ def __iter__(self):
+ return iter(self.keys())
+
+ if util.py2k:
+
+ def iteritems(self):
+ return iter(self.items())
+
+ def itervalues(self):
+ return iter(self.values())
+
+ def all_states(self):
+ if util.py2k:
+ return self._dict.values()
+ else:
+ return list(self._dict.values())
+
+ def _fast_discard(self, state):
+ # used by InstanceState for state being
+ # GC'ed, inlines _manage_removed_state
+ try:
+ st = self._dict[state.key]
+ except KeyError:
+ # the key was removed by gc after we just checked for it
+ pass
+ else:
+ if st is state:
+ self._dict.pop(state.key, None)
+
+ def discard(self, state):
+ self.safe_discard(state)
+
+ def safe_discard(self, state):
+ if state.key in self._dict:
+ try:
+ st = self._dict[state.key]
+ except KeyError:
+ # the key was removed by gc after we just checked for it
+ pass
+ else:
+ if st is state:
+ self._dict.pop(state.key, None)
+ self._manage_removed_state(state)
+
+
+def _killed(state, key):
+ # external function to avoid creating cycles when assigned to
+ # the IdentityMap
+ raise sa_exc.InvalidRequestError(
+ "Object %s cannot be converted to 'persistent' state, as this "
+ "identity map is no longer valid. Has the owning Session "
+ "been closed?" % orm_util.state_str(state),
+ code="lkrp",
+ )
diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py
new file mode 100644
index 0000000..a7023a2
--- /dev/null
+++ b/lib/sqlalchemy/orm/instrumentation.py
@@ -0,0 +1,652 @@
+# orm/instrumentation.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Defines SQLAlchemy's system of class instrumentation.
+
+This module is usually not directly visible to user applications, but
+defines a large part of the ORM's interactivity.
+
+instrumentation.py deals with registration of end-user classes
+for state tracking. It interacts closely with state.py
+and attributes.py which establish per-instance and per-class-attribute
+instrumentation, respectively.
+
+The class instrumentation system can be customized on a per-class
+or global basis using the :mod:`sqlalchemy.ext.instrumentation`
+module, which provides the means to build and specify
+alternate instrumentation forms.
+
+.. versionchanged:: 0.8
+ The instrumentation extension system was moved out of the
+ ORM and into the external :mod:`sqlalchemy.ext.instrumentation`
+ package. When that package is imported, it installs
+ itself within sqlalchemy.orm so that its more comprehensive
+ resolution mechanics take effect.
+
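+A custom instrumentation manager is typically associated with a class
+via the ``__sa_instrumentation_manager__`` attribute once the extension
+module is imported (a minimal sketch; the class names are
+illustrative)::
+
+ from sqlalchemy.ext.instrumentation import InstrumentationManager
+
+ class MyClassManager(InstrumentationManager):
+ pass
+
+ class MyEntity(object):
+ __sa_instrumentation_manager__ = MyClassManager
+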
+"""
+
+
+import weakref
+
+from . import base
+from . import collections
+from . import exc
+from . import interfaces
+from . import state
+from .. import util
+from ..util import HasMemoized
+
+
+DEL_ATTR = util.symbol("DEL_ATTR")
+
+
+class ClassManager(HasMemoized, dict):
+ """Tracks state information at the class level."""
+
+ MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR
+ STATE_ATTR = base.DEFAULT_STATE_ATTR
+
+ _state_setter = staticmethod(util.attrsetter(STATE_ATTR))
+
+ expired_attribute_loader = None
+ "previously known as deferred_scalar_loader"
+
+ init_method = None
+
+ factory = None
+ mapper = None
+ declarative_scan = None
+ registry = None
+
+ @property
+ @util.deprecated(
+ "1.4",
+ message="The ClassManager.deferred_scalar_loader attribute is now "
+ "named expired_attribute_loader",
+ )
+ def deferred_scalar_loader(self):
+ return self.expired_attribute_loader
+
+ @deferred_scalar_loader.setter
+ @util.deprecated(
+ "1.4",
+ message="The ClassManager.deferred_scalar_loader attribute is now "
+ "named expired_attribute_loader",
+ )
+ def deferred_scalar_loader(self, obj):
+ self.expired_attribute_loader = obj
+
+ def __init__(self, class_):
+ self.class_ = class_
+ self.info = {}
+ self.new_init = None
+ self.local_attrs = {}
+ self.originals = {}
+ self._finalized = False
+
+ self._bases = [
+ mgr
+ for mgr in [
+ manager_of_class(base)
+ for base in self.class_.__bases__
+ if isinstance(base, type)
+ ]
+ if mgr is not None
+ ]
+
+ for base_ in self._bases:
+ self.update(base_)
+
+ self.dispatch._events._new_classmanager_instance(class_, self)
+
+ for basecls in class_.__mro__:
+ mgr = manager_of_class(basecls)
+ if mgr is not None:
+ self.dispatch._update(mgr.dispatch)
+
+ self.manage()
+
+ if "__del__" in class_.__dict__:
+ util.warn(
+ "__del__() method on class %s will "
+ "cause unreachable cycles and memory leaks, "
+ "as SQLAlchemy instrumentation often creates "
+ "reference cycles. Please remove this method." % class_
+ )
+
+ def _update_state(
+ self,
+ finalize=False,
+ mapper=None,
+ registry=None,
+ declarative_scan=None,
+ expired_attribute_loader=None,
+ init_method=None,
+ ):
+
+ if mapper:
+ self.mapper = mapper
+ if registry:
+ registry._add_manager(self)
+ if declarative_scan:
+ self.declarative_scan = weakref.ref(declarative_scan)
+ if expired_attribute_loader:
+ self.expired_attribute_loader = expired_attribute_loader
+
+ if init_method:
+ assert not self._finalized, (
+ "class is already instrumented, "
+ "init_method %s can't be applied" % init_method
+ )
+ self.init_method = init_method
+
+ if not self._finalized:
+ self.original_init = (
+ self.init_method
+ if self.init_method is not None
+ and self.class_.__init__ is object.__init__
+ else self.class_.__init__
+ )
+
+ if finalize and not self._finalized:
+ self._finalize()
+
+ def _finalize(self):
+ if self._finalized:
+ return
+ self._finalized = True
+
+ self._instrument_init()
+
+ _instrumentation_factory.dispatch.class_instrument(self.class_)
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ return other is self
+
+ @property
+ def is_mapped(self):
+ return "mapper" in self.__dict__
+
+ @HasMemoized.memoized_attribute
+ def _all_key_set(self):
+ return frozenset(self)
+
+ @HasMemoized.memoized_attribute
+ def _collection_impl_keys(self):
+ return frozenset(
+ [attr.key for attr in self.values() if attr.impl.collection]
+ )
+
+ @HasMemoized.memoized_attribute
+ def _scalar_loader_impls(self):
+ return frozenset(
+ [
+ attr.impl
+ for attr in self.values()
+ if attr.impl.accepts_scalar_loader
+ ]
+ )
+
+ @HasMemoized.memoized_attribute
+ def _loader_impls(self):
+ return frozenset([attr.impl for attr in self.values()])
+
+ @util.memoized_property
+ def mapper(self):
+ # raises unless self.mapper has been assigned
+ raise exc.UnmappedClassError(self.class_)
+
+ def _all_sqla_attributes(self, exclude=None):
+ """return an iterator of all classbound attributes that are
+ implement :class:`.InspectionAttr`.
+
+ This includes :class:`.QueryableAttribute` as well as extension
+ types such as :class:`.hybrid_property` and
+ :class:`.AssociationProxy`.
+
+ """
+
+ found = {}
+
+ # constraints:
+ # 1. yield keys in cls.__dict__ order
+ # 2. if a subclass has the same key as a superclass, include that
+ # key as part of the ordering of the superclass, because an
+ # overridden key is usually installed by the mapper which is going
+ # on a different ordering
+ # 3. don't use getattr() as this fires off descriptors
+
+ for supercls in self.class_.__mro__[0:-1]:
+ inherits = supercls.__mro__[1]
+ for key in supercls.__dict__:
+ found.setdefault(key, supercls)
+ if key in inherits.__dict__:
+ continue
+ val = found[key].__dict__[key]
+ if (
+ isinstance(val, interfaces.InspectionAttr)
+ and val.is_attribute
+ ):
+ yield key, val
+
+ def _get_class_attr_mro(self, key, default=None):
+ """return an attribute on the class without tripping it."""
+
+ for supercls in self.class_.__mro__:
+ if key in supercls.__dict__:
+ return supercls.__dict__[key]
+ else:
+ return default
+
+ def _attr_has_impl(self, key):
+ """Return True if the given attribute is fully initialized.
+
+ i.e. has an impl.
+ """
+
+ return key in self and self[key].impl is not None
+
+ def _subclass_manager(self, cls):
+ """Create a new ClassManager for a subclass of this ClassManager's
+ class.
+
+ This is called automatically when attributes are instrumented so that
+ the attributes can be propagated to subclasses against their own
+ class-local manager, without the need for mappers etc. to have already
+ pre-configured managers for the full class hierarchy. Mappers
+ can post-configure the auto-generated ClassManager when needed.
+
+ """
+ return register_class(cls, finalize=False)
+
+ def _instrument_init(self):
+ self.new_init = _generate_init(self.class_, self, self.original_init)
+ self.install_member("__init__", self.new_init)
+
+ @util.memoized_property
+ def _state_constructor(self):
+ self.dispatch.first_init(self, self.class_)
+ return state.InstanceState
+
+ def manage(self):
+ """Mark this instance as the manager for its class."""
+
+ setattr(self.class_, self.MANAGER_ATTR, self)
+
+ @util.hybridmethod
+ def manager_getter(self):
+ return _default_manager_getter
+
+ @util.hybridmethod
+ def state_getter(self):
+ """Return a (instance) -> InstanceState callable.
+
+ "state getter" callables should raise either KeyError or
+ AttributeError if no InstanceState could be found for the
+ instance.
+ """
+
+ return _default_state_getter
+
+ @util.hybridmethod
+ def dict_getter(self):
+ return _default_dict_getter
+
+ def instrument_attribute(self, key, inst, propagated=False):
+ if propagated:
+ if key in self.local_attrs:
+ return # don't override local attr with inherited attr
+ else:
+ self.local_attrs[key] = inst
+ self.install_descriptor(key, inst)
+ self._reset_memoizations()
+ self[key] = inst
+
+ for cls in self.class_.__subclasses__():
+ manager = self._subclass_manager(cls)
+ manager.instrument_attribute(key, inst, True)
+
+ def subclass_managers(self, recursive):
+ for cls in self.class_.__subclasses__():
+ mgr = manager_of_class(cls)
+ if mgr is not None and mgr is not self:
+ yield mgr
+ if recursive:
+ for m in mgr.subclass_managers(True):
+ yield m
+
+ def post_configure_attribute(self, key):
+ _instrumentation_factory.dispatch.attribute_instrument(
+ self.class_, key, self[key]
+ )
+
+ def uninstrument_attribute(self, key, propagated=False):
+ if key not in self:
+ return
+ if propagated:
+ if key in self.local_attrs:
+ return # don't get rid of local attr
+ else:
+ del self.local_attrs[key]
+ self.uninstall_descriptor(key)
+ self._reset_memoizations()
+ del self[key]
+ for cls in self.class_.__subclasses__():
+ manager = manager_of_class(cls)
+ if manager:
+ manager.uninstrument_attribute(key, True)
+
+ def unregister(self):
+ """remove all instrumentation established by this ClassManager."""
+
+ for key in list(self.originals):
+ self.uninstall_member(key)
+
+ self.mapper = self.dispatch = self.new_init = None
+ self.info.clear()
+
+ for key in list(self):
+ if key in self.local_attrs:
+ self.uninstrument_attribute(key)
+
+ if self.MANAGER_ATTR in self.class_.__dict__:
+ delattr(self.class_, self.MANAGER_ATTR)
+
+ def install_descriptor(self, key, inst):
+ if key in (self.STATE_ATTR, self.MANAGER_ATTR):
+ raise KeyError(
+ "%r: requested attribute name conflicts with "
+ "instrumentation attribute of the same name." % key
+ )
+ setattr(self.class_, key, inst)
+
+ def uninstall_descriptor(self, key):
+ delattr(self.class_, key)
+
+ def install_member(self, key, implementation):
+ if key in (self.STATE_ATTR, self.MANAGER_ATTR):
+ raise KeyError(
+ "%r: requested attribute name conflicts with "
+ "instrumentation attribute of the same name." % key
+ )
+ self.originals.setdefault(key, self.class_.__dict__.get(key, DEL_ATTR))
+ setattr(self.class_, key, implementation)
+
+ def uninstall_member(self, key):
+ original = self.originals.pop(key, None)
+ if original is not DEL_ATTR:
+ setattr(self.class_, key, original)
+ else:
+ delattr(self.class_, key)
+
+ def instrument_collection_class(self, key, collection_class):
+ return collections.prepare_instrumentation(collection_class)
+
+ def initialize_collection(self, key, state, factory):
+ user_data = factory()
+ adapter = collections.CollectionAdapter(
+ self.get_impl(key), state, user_data
+ )
+ return adapter, user_data
+
+ def is_instrumented(self, key, search=False):
+ if search:
+ return key in self
+ else:
+ return key in self.local_attrs
+
+ def get_impl(self, key):
+ return self[key].impl
+
+ @property
+ def attributes(self):
+ return iter(self.values())
+
+ # InstanceState management
+
+ def new_instance(self, state=None):
+ instance = self.class_.__new__(self.class_)
+ if state is None:
+ state = self._state_constructor(instance, self)
+ self._state_setter(instance, state)
+ return instance
+
+ def setup_instance(self, instance, state=None):
+ if state is None:
+ state = self._state_constructor(instance, self)
+ self._state_setter(instance, state)
+
+ def teardown_instance(self, instance):
+ delattr(instance, self.STATE_ATTR)
+
+ def _serialize(self, state, state_dict):
+ return _SerializeManager(state, state_dict)
+
+ def _new_state_if_none(self, instance):
+ """Install a default InstanceState if none is present.
+
+ A private convenience method used by the __init__ decorator.
+
+ """
+ if hasattr(instance, self.STATE_ATTR):
+ return False
+ elif self.class_ is not instance.__class__ and self.is_mapped:
+ # this will create a new ClassManager for the
+ # subclass, without a mapper. This is likely a
+ # user error situation but allow the object
+ # to be constructed, so that it is usable
+ # in a non-ORM context at least.
+ return self._subclass_manager(
+ instance.__class__
+ )._new_state_if_none(instance)
+ else:
+ state = self._state_constructor(instance, self)
+ self._state_setter(instance, state)
+ return state
+
+ def has_state(self, instance):
+ return hasattr(instance, self.STATE_ATTR)
+
+ def has_parent(self, state, key, optimistic=False):
+ """TODO"""
+ return self.get_impl(key).hasparent(state, optimistic=optimistic)
+
+ def __bool__(self):
+ """All ClassManagers are non-zero regardless of attribute state."""
+ return True
+
+ __nonzero__ = __bool__
+
+ def __repr__(self):
+ return "<%s of %r at %x>" % (
+ self.__class__.__name__,
+ self.class_,
+ id(self),
+ )
+
+
+class _SerializeManager(object):
+ """Provide serialization of a :class:`.ClassManager`.
+
+ The :class:`.InstanceState` uses ``__init__()`` on serialize
+ and ``__call__()`` on deserialize.
+
+ """
+
+ def __init__(self, state, d):
+ self.class_ = state.class_
+ manager = state.manager
+ manager.dispatch.pickle(state, d)
+
+ def __call__(self, state, inst, state_dict):
+ state.manager = manager = manager_of_class(self.class_)
+ if manager is None:
+ raise exc.UnmappedInstanceError(
+ inst,
+ "Cannot deserialize object of type %r - "
+ "no mapper() has "
+ "been configured for this class within the current "
+ "Python process!" % self.class_,
+ )
+ elif manager.is_mapped and not manager.mapper.configured:
+ manager.mapper._check_configure()
+
+ # setup _sa_instance_state ahead of time so that
+ # unpickle events can access the object normally.
+ # see [ticket:2362]
+ if inst is not None:
+ manager.setup_instance(inst, state)
+ manager.dispatch.unpickle(state, state_dict)
+
+
+class InstrumentationFactory(object):
+ """Factory for new ClassManager instances."""
+
+ def create_manager_for_cls(self, class_):
+ assert class_ is not None
+ assert manager_of_class(class_) is None
+
+ # give a more complicated subclass
+ # a chance to do what it wants here
+ manager, factory = self._locate_extended_factory(class_)
+
+ if factory is None:
+ factory = ClassManager
+ manager = factory(class_)
+
+ self._check_conflicts(class_, factory)
+
+ manager.factory = factory
+
+ return manager
+
+ def _locate_extended_factory(self, class_):
+ """Overridden by a subclass to do an extended lookup."""
+ return None, None
+
+ def _check_conflicts(self, class_, factory):
+ """Overridden by a subclass to test for conflicting factories."""
+ return
+
+ def unregister(self, class_):
+ manager = manager_of_class(class_)
+ manager.unregister()
+ self.dispatch.class_uninstrument(class_)
+
+
+# this attribute is replaced by sqlalchemy.ext.instrumentation
+# when imported.
+_instrumentation_factory = InstrumentationFactory()
+
+# these attributes are replaced by sqlalchemy.ext.instrumentation
+# when a non-standard InstrumentationManager class is first
+# used to instrument a class.
+instance_state = _default_state_getter = base.instance_state
+
+instance_dict = _default_dict_getter = base.instance_dict
+
+manager_of_class = _default_manager_getter = base.manager_of_class
+
+
+def register_class(
+ class_,
+ finalize=True,
+ mapper=None,
+ registry=None,
+ declarative_scan=None,
+ expired_attribute_loader=None,
+ init_method=None,
+):
+ """Register class instrumentation.
+
+ Returns the existing or newly created class manager.
+
+ """
+
+ manager = manager_of_class(class_)
+ if manager is None:
+ manager = _instrumentation_factory.create_manager_for_cls(class_)
+ manager._update_state(
+ mapper=mapper,
+ registry=registry,
+ declarative_scan=declarative_scan,
+ expired_attribute_loader=expired_attribute_loader,
+ init_method=init_method,
+ finalize=finalize,
+ )
+
+ return manager
+
+
+def unregister_class(class_):
+ """Unregister class instrumentation."""
+
+ _instrumentation_factory.unregister(class_)
+
+
+def is_instrumented(instance, key):
+ """Return True if the given attribute on the given instance is
+ instrumented by the attributes package.
+
+ This function may be used regardless of instrumentation
+ applied directly to the class, i.e. no descriptors are required.
+
+ """
+ return manager_of_class(instance.__class__).is_instrumented(
+ key, search=True
+ )
+
+
+def _generate_init(class_, class_manager, original_init):
+ """Build an __init__ decorator that triggers ClassManager events."""
+
+ # TODO: we should use the ClassManager's notion of the
+ # original '__init__' method, once ClassManager is fixed
+ # to always reference that.
+
+ if original_init is None:
+ original_init = class_.__init__
+
+ # Go through some effort here and don't change the user's __init__
+ # calling signature, including the unlikely case that it has
+ # a return value.
+ # FIXME: need to juggle local names to avoid constructor argument
+ # clashes.
+ func_body = """\
+def __init__(%(apply_pos)s):
+ new_state = class_manager._new_state_if_none(%(self_arg)s)
+ if new_state:
+ return new_state._initialize_instance(%(apply_kw)s)
+ else:
+ return original_init(%(apply_kw)s)
+"""
+ func_vars = util.format_argspec_init(original_init, grouped=False)
+ func_text = func_body % func_vars
+
+ if util.py2k:
+ func = getattr(original_init, "im_func", original_init)
+ func_defaults = getattr(func, "func_defaults", None)
+ else:
+ func_defaults = getattr(original_init, "__defaults__", None)
+ func_kw_defaults = getattr(original_init, "__kwdefaults__", None)
+
+ env = locals().copy()
+ env["__name__"] = __name__
+ exec(func_text, env)
+ __init__ = env["__init__"]
+ __init__.__doc__ = original_init.__doc__
+ __init__._sa_original_init = original_init
+
+ if func_defaults:
+ __init__.__defaults__ = func_defaults
+ if not util.py2k and func_kw_defaults:
+ __init__.__kwdefaults__ = func_kw_defaults
+
+ return __init__
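For orientation, a hedged sketch of what the template above renders to for a hypothetical original constructor ``def __init__(self, name, email=None)``. The exact substitution strings come from ``util.format_argspec_init``, so the values shown in the comments are illustrative assumptions; ``class_manager`` and ``original_init`` are supplied through the ``exec`` environment built from ``locals()``:

    # assumed substitutions for "def __init__(self, name, email=None)":
    #   apply_pos -> "self, name, email"   (defaults restored afterwards
    #                                       via __init__.__defaults__)
    #   self_arg  -> "self"
    #   apply_kw  -> "self, name=name, email=email"
    def __init__(self, name, email):
        new_state = class_manager._new_state_if_none(self)
        if new_state:
            return new_state._initialize_instance(self, name=name, email=email)
        else:
            return original_init(self, name=name, email=email)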
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
new file mode 100644
index 0000000..63295d0
--- /dev/null
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -0,0 +1,978 @@
+# orm/interfaces.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+
+Contains various base classes used throughout the ORM.
+
+Defines some key base classes prominent within the internals.
+
+This module and the classes within are mostly private, though some attributes
+are exposed when inspecting mappings.
+
+"""
+
+from __future__ import absolute_import
+
+import collections
+
+from . import exc as orm_exc
+from . import path_registry
+from .base import _MappedAttribute # noqa
+from .base import EXT_CONTINUE
+from .base import EXT_SKIP
+from .base import EXT_STOP
+from .base import InspectionAttr # noqa
+from .base import InspectionAttrInfo # noqa
+from .base import MANYTOMANY
+from .base import MANYTOONE
+from .base import NOT_EXTENSION
+from .base import ONETOMANY
+from .. import inspect
+from .. import inspection
+from .. import util
+from ..sql import operators
+from ..sql import roles
+from ..sql import visitors
+from ..sql.base import ExecutableOption
+from ..sql.traversals import HasCacheKey
+
+
+__all__ = (
+ "EXT_CONTINUE",
+ "EXT_STOP",
+ "EXT_SKIP",
+ "ONETOMANY",
+ "MANYTOMANY",
+ "MANYTOONE",
+ "NOT_EXTENSION",
+ "LoaderStrategy",
+ "MapperOption",
+ "LoaderOption",
+ "MapperProperty",
+ "PropComparator",
+ "StrategizedProperty",
+)
+
+
+class ORMStatementRole(roles.StatementRole):
+ _role_name = (
+ "Executable SQL or text() construct, including ORM " "aware objects"
+ )
+
+
+class ORMColumnsClauseRole(roles.ColumnsClauseRole):
+ _role_name = "ORM mapped entity, aliased entity, or Column expression"
+
+
+class ORMEntityColumnsClauseRole(ORMColumnsClauseRole):
+ _role_name = "ORM mapped or aliased entity"
+
+
+class ORMFromClauseRole(roles.StrictFromClauseRole):
+ _role_name = "ORM mapped entity, aliased entity, or FROM expression"
+
+
+@inspection._self_inspects
+class MapperProperty(
+ HasCacheKey, _MappedAttribute, InspectionAttr, util.MemoizedSlots
+):
+ """Represent a particular class attribute mapped by :class:`_orm.Mapper`.
+
+ The most common occurrences of :class:`.MapperProperty` are the
+ mapped :class:`_schema.Column`, which is represented in a mapping as
+ an instance of :class:`.ColumnProperty`,
+ and a reference to another class produced by :func:`_orm.relationship`,
+ represented in the mapping as an instance of
+ :class:`.RelationshipProperty`.
+
+ """
+
+ __slots__ = (
+ "_configure_started",
+ "_configure_finished",
+ "parent",
+ "key",
+ "info",
+ )
+
+ _cache_key_traversal = [
+ ("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+ ("key", visitors.ExtendedInternalTraversal.dp_string),
+ ]
+
+ cascade = frozenset()
+ """The set of 'cascade' attribute names.
+
+ This collection is checked before the 'cascade_iterator' method is called.
+
+ The collection typically only applies to a RelationshipProperty.
+
+ """
+
+ is_property = True
+ """Part of the InspectionAttr interface; states this object is a
+ mapper property.
+
+ """
+
+ @property
+ def _links_to_entity(self):
+ """True if this MapperProperty refers to a mapped entity.
+
+ Should only be True for RelationshipProperty, False for all others.
+
+ """
+ raise NotImplementedError()
+
+ def _memoized_attr_info(self):
+ """Info dictionary associated with the object, allowing user-defined
+ data to be associated with this :class:`.InspectionAttr`.
+
+ The dictionary is generated when first accessed. Alternatively,
+ it can be specified as a constructor argument to the
+ :func:`.column_property`, :func:`_orm.relationship`, or
+ :func:`.composite`
+ functions.
+
+ .. versionchanged:: 1.0.0 :attr:`.MapperProperty.info` is also
+ available on extension types via the
+ :attr:`.InspectionAttrInfo.info` attribute, so that it can apply
+ to a wider variety of ORM and extension constructs.
+
+ .. seealso::
+
+ :attr:`.QueryableAttribute.info`
+
+ :attr:`.SchemaItem.info`
+
+ """
+ return {}
+
+ def setup(self, context, query_entity, path, adapter, **kwargs):
+ """Called by Query for the purposes of constructing a SQL statement.
+
+ Each MapperProperty associated with the target mapper processes the
+ statement referenced by the query context, adding columns and/or
+ criterion as appropriate.
+
+ """
+
+ def create_row_processor(
+ self, context, query_entity, path, mapper, result, adapter, populators
+ ):
+ """Produce row processing functions and append to the given
+ set of populators lists.
+
+ """
+
+ def cascade_iterator(
+ self, type_, state, dict_, visited_states, halt_on=None
+ ):
+ """Iterate through instances related to the given instance for
+ a particular 'cascade', starting with this MapperProperty.
+
+        Return an iterator of 3-tuples (instance, mapper, state).
+
+ Note that the 'cascade' collection on this MapperProperty is
+ checked first for the given type before cascade_iterator is called.
+
+ This method typically only applies to RelationshipProperty.
+
+ """
+
+ return iter(())
+
+ def set_parent(self, parent, init):
+ """Set the parent mapper that references this MapperProperty.
+
+ This method is overridden by some subclasses to perform extra
+ setup when the mapper is first known.
+
+ """
+ self.parent = parent
+
+ def instrument_class(self, mapper):
+ """Hook called by the Mapper to the property to initiate
+ instrumentation of the class attribute managed by this
+ MapperProperty.
+
+ The MapperProperty here will typically call out to the
+ attributes module to set up an InstrumentedAttribute.
+
+ This step is the first of two steps to set up an InstrumentedAttribute,
+ and is called early in the mapper setup process.
+
+ The second step is typically the init_class_attribute step,
+ called from StrategizedProperty via the post_instrument_class()
+ hook. This step assigns additional state to the InstrumentedAttribute
+ (specifically the "impl") which has been determined after the
+ MapperProperty has determined what kind of persistence
+ management it needs to do (e.g. scalar, object, collection, etc).
+
+ """
+
+ def __init__(self):
+ self._configure_started = False
+ self._configure_finished = False
+
+ def init(self):
+ """Called after all mappers are created to assemble
+ relationships between mappers and perform other post-mapper-creation
+ initialization steps.
+
+
+ """
+ self._configure_started = True
+ self.do_init()
+ self._configure_finished = True
+
+ @property
+ def class_attribute(self):
+ """Return the class-bound descriptor corresponding to this
+ :class:`.MapperProperty`.
+
+ This is basically a ``getattr()`` call::
+
+ return getattr(self.parent.class_, self.key)
+
+ I.e. if this :class:`.MapperProperty` were named ``addresses``,
+ and the class to which it is mapped is ``User``, this sequence
+ is possible::
+
+ >>> from sqlalchemy import inspect
+ >>> mapper = inspect(User)
+ >>> addresses_property = mapper.attrs.addresses
+ >>> addresses_property.class_attribute is User.addresses
+ True
+ >>> User.addresses.property is addresses_property
+ True
+
+
+ """
+
+ return getattr(self.parent.class_, self.key)
+
+ def do_init(self):
+ """Perform subclass-specific initialization post-mapper-creation
+ steps.
+
+ This is a template method called by the ``MapperProperty``
+ object's init() method.
+
+ """
+
+ def post_instrument_class(self, mapper):
+ """Perform instrumentation adjustments that need to occur
+ after init() has completed.
+
+ The given Mapper is the Mapper invoking the operation, which
+ may not be the same Mapper as self.parent in an inheritance
+ scenario; however, Mapper will always at least be a sub-mapper of
+ self.parent.
+
+ This method is typically used by StrategizedProperty, which delegates
+ it to LoaderStrategy.init_class_attribute() to perform final setup
+ on the class-bound InstrumentedAttribute.
+
+ """
+
+ def merge(
+ self,
+ session,
+ source_state,
+ source_dict,
+ dest_state,
+ dest_dict,
+ load,
+ _recursive,
+ _resolve_conflict_map,
+ ):
+ """Merge the attribute represented by this ``MapperProperty``
+ from source to destination object.
+
+ """
+
+ def __repr__(self):
+ return "<%s at 0x%x; %s>" % (
+ self.__class__.__name__,
+ id(self),
+ getattr(self, "key", "no key"),
+ )
+
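For orientation, a brief usage sketch (not part of this module) showing how ``MapperProperty`` objects and the ``info`` dictionary described above are typically reached through runtime inspection; the ``User`` class below is hypothetical:

    from sqlalchemy import Column, Integer, String, inspect
    from sqlalchemy.orm import column_property, declarative_base

    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        # seed the .info dictionary at mapping configuration time
        nickname = column_property(
            Column("nickname", String(50)), info={"audit": True}
        )

    mapper = inspect(User)         # the Mapper for User
    prop = mapper.attrs.nickname   # a ColumnProperty, i.e. a MapperProperty
    assert prop.info["audit"] is True
    assert prop.class_attribute is User.nickname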
+
+@inspection._self_inspects
+class PropComparator(operators.ColumnOperators):
+ r"""Defines SQL operators for :class:`.MapperProperty` objects.
+
+ SQLAlchemy allows for operators to
+ be redefined at both the Core and ORM level. :class:`.PropComparator`
+ is the base class of operator redefinition for ORM-level operations,
+ including those of :class:`.ColumnProperty`,
+ :class:`.RelationshipProperty`, and :class:`.CompositeProperty`.
+
+ .. note:: With the advent of Hybrid properties introduced in SQLAlchemy
+ 0.7, as well as Core-level operator redefinition in
+ SQLAlchemy 0.8, the use case for user-defined :class:`.PropComparator`
+ instances is extremely rare. See :ref:`hybrids_toplevel` as well
+ as :ref:`types_operators`.
+
+ User-defined subclasses of :class:`.PropComparator` may be created. The
+ built-in Python comparison and math operator methods, such as
+ :meth:`.operators.ColumnOperators.__eq__`,
+ :meth:`.operators.ColumnOperators.__lt__`, and
+ :meth:`.operators.ColumnOperators.__add__`, can be overridden to provide
+ new operator behavior. The custom :class:`.PropComparator` is passed to
+ the :class:`.MapperProperty` instance via the ``comparator_factory``
+ argument. In each case,
+ the appropriate subclass of :class:`.PropComparator` should be used::
+
+ # definition of custom PropComparator subclasses
+
+ from sqlalchemy.orm.properties import \
+ ColumnProperty,\
+ CompositeProperty,\
+ RelationshipProperty
+
+ class MyColumnComparator(ColumnProperty.Comparator):
+ def __eq__(self, other):
+ return self.__clause_element__() == other
+
+ class MyRelationshipComparator(RelationshipProperty.Comparator):
+ def any(self, expression):
+ "define the 'any' operation"
+ # ...
+
+ class MyCompositeComparator(CompositeProperty.Comparator):
+ def __gt__(self, other):
+ "redefine the 'greater than' operation"
+
+ return sql.and_(*[a>b for a, b in
+ zip(self.__clause_element__().clauses,
+ other.__composite_values__())])
+
+
+ # application of custom PropComparator subclasses
+
+ from sqlalchemy.orm import column_property, relationship, composite
+ from sqlalchemy import Column, String
+
+ class SomeMappedClass(Base):
+ some_column = column_property(Column("some_column", String),
+ comparator_factory=MyColumnComparator)
+
+ some_relationship = relationship(SomeOtherClass,
+ comparator_factory=MyRelationshipComparator)
+
+ some_composite = composite(
+ Column("a", String), Column("b", String),
+ comparator_factory=MyCompositeComparator
+ )
+
+ Note that for column-level operator redefinition, it's usually
+ simpler to define the operators at the Core level, using the
+ :attr:`.TypeEngine.comparator_factory` attribute. See
+ :ref:`types_operators` for more detail.
+
+ .. seealso::
+
+ :class:`.ColumnProperty.Comparator`
+
+ :class:`.RelationshipProperty.Comparator`
+
+ :class:`.CompositeProperty.Comparator`
+
+ :class:`.ColumnOperators`
+
+ :ref:`types_operators`
+
+ :attr:`.TypeEngine.comparator_factory`
+
+ """
+
+ __slots__ = "prop", "property", "_parententity", "_adapt_to_entity"
+
+ __visit_name__ = "orm_prop_comparator"
+
+ def __init__(
+ self,
+ prop,
+ parentmapper,
+ adapt_to_entity=None,
+ ):
+ self.prop = self.property = prop
+ self._parententity = adapt_to_entity or parentmapper
+ self._adapt_to_entity = adapt_to_entity
+
+ def __clause_element__(self):
+ raise NotImplementedError("%r" % self)
+
+ def _bulk_update_tuples(self, value):
+ """Receive a SQL expression that represents a value in the SET
+ clause of an UPDATE statement.
+
+ Return a tuple that can be passed to a :class:`_expression.Update`
+ construct.
+
+ """
+
+ return [(self.__clause_element__(), value)]
+
+ def adapt_to_entity(self, adapt_to_entity):
+ """Return a copy of this PropComparator which will use the given
+ :class:`.AliasedInsp` to produce corresponding expressions.
+ """
+ return self.__class__(self.prop, self._parententity, adapt_to_entity)
+
+ @property
+ def _parentmapper(self):
+ """legacy; this is renamed to _parententity to be
+ compatible with QueryableAttribute."""
+ return inspect(self._parententity).mapper
+
+ @property
+ def _propagate_attrs(self):
+ # this suits the case in coercions where we don't actually
+ # call ``__clause_element__()`` but still need to get
+ # resolved._propagate_attrs. See #6558.
+ return util.immutabledict(
+ {
+ "compile_state_plugin": "orm",
+ "plugin_subject": self._parentmapper,
+ }
+ )
+
+ @property
+ def adapter(self):
+ """Produce a callable that adapts column expressions
+ to suit an aliased version of this comparator.
+
+ """
+ if self._adapt_to_entity is None:
+ return None
+ else:
+ return self._adapt_to_entity._adapt_element
+
+ @property
+ def info(self):
+ return self.property.info
+
+ @staticmethod
+ def any_op(a, b, **kwargs):
+ return a.any(b, **kwargs)
+
+ @staticmethod
+ def has_op(a, b, **kwargs):
+ return a.has(b, **kwargs)
+
+ @staticmethod
+ def of_type_op(a, class_):
+ return a.of_type(class_)
+
+ def of_type(self, class_):
+ r"""Redefine this object in terms of a polymorphic subclass,
+ :func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased`
+ construct.
+
+ Returns a new PropComparator from which further criterion can be
+ evaluated.
+
+ e.g.::
+
+ query.join(Company.employees.of_type(Engineer)).\
+ filter(Engineer.name=='foo')
+
+ :param \class_: a class or mapper indicating that criterion will be
+ against this specific subclass.
+
+ .. seealso::
+
+ :ref:`queryguide_join_onclause` - in the :ref:`queryguide_toplevel`
+
+ :ref:`inheritance_of_type`
+
+ """
+
+ return self.operate(PropComparator.of_type_op, class_)
+
+ def and_(self, *criteria):
+ """Add additional criteria to the ON clause that's represented by this
+ relationship attribute.
+
+ E.g.::
+
+
+ stmt = select(User).join(
+ User.addresses.and_(Address.email_address != 'foo')
+ )
+
+ stmt = select(User).options(
+ joinedload(User.addresses.and_(Address.email_address != 'foo'))
+ )
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`orm_queryguide_join_on_augmented`
+
+ :ref:`loader_option_criteria`
+
+ :func:`.with_loader_criteria`
+
+ """
+ return self.operate(operators.and_, *criteria)
+
+ def any(self, criterion=None, **kwargs):
+ r"""Return true if this collection contains any member that meets the
+ given criterion.
+
+ The usual implementation of ``any()`` is
+ :meth:`.RelationshipProperty.Comparator.any`.
+
+ :param criterion: an optional ClauseElement formulated against the
+ member class' table or attributes.
+
+ :param \**kwargs: key/value pairs corresponding to member class
+ attribute names which will be compared via equality to the
+ corresponding values.
+
+ """
+
+ return self.operate(PropComparator.any_op, criterion, **kwargs)
+
+ def has(self, criterion=None, **kwargs):
+ r"""Return true if this element references a member which meets the
+ given criterion.
+
+ The usual implementation of ``has()`` is
+ :meth:`.RelationshipProperty.Comparator.has`.
+
+ :param criterion: an optional ClauseElement formulated against the
+ member class' table or attributes.
+
+ :param \**kwargs: key/value pairs corresponding to member class
+ attribute names which will be compared via equality to the
+ corresponding values.
+
+ """
+
+ return self.operate(PropComparator.has_op, criterion, **kwargs)
+
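A short usage sketch of ``any()`` and ``has()`` as they appear in queries, assuming a hypothetical ``User``/``Address`` mapping where ``User.addresses`` is a collection and ``Address.user`` is a many-to-one reference:

    from sqlalchemy import select

    # collection membership: at least one related Address matches
    stmt = select(User).where(
        User.addresses.any(Address.email_address.like("%@example.com"))
    )

    # scalar reference: the single related User matches
    stmt = select(Address).where(Address.user.has(User.name == "ed"))

    # keyword form: equality against member-class attribute names
    stmt = select(User).where(User.addresses.any(email_address="ed@example.com"))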
+
+class StrategizedProperty(MapperProperty):
+ """A MapperProperty which uses selectable strategies to affect
+ loading behavior.
+
+ There is a single strategy selected by default. Alternate
+ strategies can be selected at Query time through the usage of
+ ``StrategizedOption`` objects via the Query.options() method.
+
+ The mechanics of StrategizedProperty are used for every Query
+ invocation for every mapped attribute participating in that Query,
+ to determine first how the attribute will be rendered in SQL
+ and secondly how the attribute will retrieve a value from a result
+ row and apply it to a mapped object. The routines here are very
+ performance-critical.
+
+ """
+
+ __slots__ = (
+ "_strategies",
+ "strategy",
+ "_wildcard_token",
+ "_default_path_loader_key",
+ )
+ inherit_cache = True
+ strategy_wildcard_key = None
+
+ def _memoized_attr__wildcard_token(self):
+ return (
+ "%s:%s"
+ % (self.strategy_wildcard_key, path_registry._WILDCARD_TOKEN),
+ )
+
+ def _memoized_attr__default_path_loader_key(self):
+ return (
+ "loader",
+ (
+ "%s:%s"
+ % (self.strategy_wildcard_key, path_registry._DEFAULT_TOKEN),
+ ),
+ )
+
+ def _get_context_loader(self, context, path):
+ load = None
+
+ search_path = path[self]
+
+ # search among: exact match, "attr.*", "default" strategy
+ # if any.
+ for path_key in (
+ search_path._loader_key,
+ search_path._wildcard_path_loader_key,
+ search_path._default_path_loader_key,
+ ):
+ if path_key in context.attributes:
+ load = context.attributes[path_key]
+ break
+
+ return load
+
+ def _get_strategy(self, key):
+ try:
+ return self._strategies[key]
+ except KeyError:
+ pass
+
+ # run outside to prevent transfer of exception context
+ cls = self._strategy_lookup(self, *key)
+ # this previously was setting self._strategies[cls], that's
+ # a bad idea; should use strategy key at all times because every
+ # strategy has multiple keys at this point
+ self._strategies[key] = strategy = cls(self, key)
+ return strategy
+
+ def setup(self, context, query_entity, path, adapter, **kwargs):
+ loader = self._get_context_loader(context, path)
+ if loader and loader.strategy:
+ strat = self._get_strategy(loader.strategy)
+ else:
+ strat = self.strategy
+ strat.setup_query(
+ context, query_entity, path, loader, adapter, **kwargs
+ )
+
+ def create_row_processor(
+ self, context, query_entity, path, mapper, result, adapter, populators
+ ):
+ loader = self._get_context_loader(context, path)
+ if loader and loader.strategy:
+ strat = self._get_strategy(loader.strategy)
+ else:
+ strat = self.strategy
+ strat.create_row_processor(
+ context,
+ query_entity,
+ path,
+ loader,
+ mapper,
+ result,
+ adapter,
+ populators,
+ )
+
+ def do_init(self):
+ self._strategies = {}
+ self.strategy = self._get_strategy(self.strategy_key)
+
+ def post_instrument_class(self, mapper):
+ if (
+ not self.parent.non_primary
+ and not mapper.class_manager._attr_has_impl(self.key)
+ ):
+ self.strategy.init_class_attribute(mapper)
+
+ _all_strategies = collections.defaultdict(dict)
+
+ @classmethod
+ def strategy_for(cls, **kw):
+ def decorate(dec_cls):
+ # ensure each subclass of the strategy has its
+ # own _strategy_keys collection
+ if "_strategy_keys" not in dec_cls.__dict__:
+ dec_cls._strategy_keys = []
+ key = tuple(sorted(kw.items()))
+ cls._all_strategies[cls][key] = dec_cls
+ dec_cls._strategy_keys.append(key)
+ return dec_cls
+
+ return decorate
+
+ @classmethod
+ def _strategy_lookup(cls, requesting_property, *key):
+ requesting_property.parent._with_polymorphic_mappers
+
+ for prop_cls in cls.__mro__:
+ if prop_cls in cls._all_strategies:
+ strategies = cls._all_strategies[prop_cls]
+ try:
+ return strategies[key]
+ except KeyError:
+ pass
+
+ for property_type, strats in cls._all_strategies.items():
+ if key in strats:
+ intended_property_type = property_type
+ actual_strategy = strats[key]
+ break
+ else:
+ intended_property_type = None
+ actual_strategy = None
+
+ raise orm_exc.LoaderStrategyException(
+ cls,
+ requesting_property,
+ intended_property_type,
+ actual_strategy,
+ key,
+ )
+
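To make the registry mechanics above concrete, a hedged sketch of how a loader strategy might register itself and later be selected; the built-in strategies in ``sqlalchemy.orm.strategies`` follow this general pattern, though ``MyNoLoad`` and the ``lazy="my_noload"`` key here are illustrative only:

    from sqlalchemy.orm.interfaces import LoaderStrategy
    from sqlalchemy.orm.relationships import RelationshipProperty

    # registration: stored in _all_strategies under the key
    # (("lazy", "my_noload"),), scoped to RelationshipProperty
    @RelationshipProperty.strategy_for(lazy="my_noload")
    class MyNoLoad(LoaderStrategy):
        def create_row_processor(self, *args, **kw):
            pass  # never populate the attribute from the row

    # selection: relationship(Target, lazy="my_noload") would then resolve
    # to MyNoLoad through _get_strategy() / _strategy_lookup()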
+
+class ORMOption(ExecutableOption):
+ """Base class for option objects that are passed to ORM queries.
+
+ These options may be consumed by :meth:`.Query.options`,
+ :meth:`.Select.options`, or in a more general sense by any
+ :meth:`.Executable.options` method. They are interpreted at
+ statement compile time or execution time in modern use. The
+ deprecated :class:`.MapperOption` is consumed at ORM query construction
+ time.
+
+ .. versionadded:: 1.4
+
+ """
+
+ __slots__ = ()
+
+ _is_legacy_option = False
+
+ propagate_to_loaders = False
+ """if True, indicate this option should be carried along
+ to "secondary" SELECT statements that occur for relationship
+ lazy loaders as well as attribute load / refresh operations.
+
+ """
+
+ _is_compile_state = False
+
+ _is_criteria_option = False
+
+ _is_strategy_option = False
+
+
+class CompileStateOption(HasCacheKey, ORMOption):
+ """base for :class:`.ORMOption` classes that affect the compilation of
+ a SQL query and therefore need to be part of the cache key.
+
+ .. note:: :class:`.CompileStateOption` is generally non-public and
+ should not be used as a base class for user-defined options; instead,
+ use :class:`.UserDefinedOption`, which is easier to use as it does not
+ interact with ORM compilation internals or caching.
+
+ :class:`.CompileStateOption` defines an internal attribute
+    ``_is_compile_state=True`` which has the effect that the ORM compilation
+ routines for SELECT and other statements will call upon these options when
+ a SQL string is being compiled. As such, these classes implement
+ :class:`.HasCacheKey` and need to provide robust ``_cache_key_traversal``
+ structures.
+
+ The :class:`.CompileStateOption` class is used to implement the ORM
+ :class:`.LoaderOption` and :class:`.CriteriaOption` classes.
+
+ .. versionadded:: 1.4.28
+
+
+ """
+
+ _is_compile_state = True
+
+ def process_compile_state(self, compile_state):
+ """Apply a modification to a given :class:`.CompileState`."""
+
+ def process_compile_state_replaced_entities(
+ self, compile_state, mapper_entities
+ ):
+ """Apply a modification to a given :class:`.CompileState`,
+ given entities that were replaced by with_only_columns() or
+ with_entities().
+
+ .. versionadded:: 1.4.19
+
+ """
+
+
+class LoaderOption(CompileStateOption):
+ """Describe a loader modification to an ORM statement at compilation time.
+
+ .. versionadded:: 1.4
+
+ """
+
+ def process_compile_state_replaced_entities(
+ self, compile_state, mapper_entities
+ ):
+ """Apply a modification to a given :class:`.CompileState`,
+ given entities that were replaced by with_only_columns() or
+ with_entities().
+
+ .. versionadded:: 1.4.19
+
+ """
+ self.process_compile_state(compile_state)
+
+ def process_compile_state(self, compile_state):
+ """Apply a modification to a given :class:`.CompileState`."""
+
+
+class CriteriaOption(CompileStateOption):
+ """Describe a WHERE criteria modification to an ORM statement at
+ compilation time.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _is_criteria_option = True
+
+ def process_compile_state(self, compile_state):
+ """Apply a modification to a given :class:`.CompileState`."""
+
+ def get_global_criteria(self, attributes):
+ """update additional entity criteria options in the given
+ attributes dictionary.
+
+ """
+
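The public construct built on this option is ``with_loader_criteria()``; a brief sketch of typical use, where ``User`` and its ``deleted`` column are hypothetical:

    from sqlalchemy import select
    from sqlalchemy.orm import with_loader_criteria

    # apply "deleted == False" wherever User appears in the statement,
    # including aliases and, by default, relationship lazy loads
    stmt = select(User).options(
        with_loader_criteria(User, User.deleted == False, include_aliases=True)
    )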
+
+class UserDefinedOption(ORMOption):
+ """Base class for a user-defined option that can be consumed from the
+ :meth:`.SessionEvents.do_orm_execute` event hook.
+
+ """
+
+ _is_legacy_option = False
+
+ propagate_to_loaders = False
+ """if True, indicate this option should be carried along
+ to "secondary" Query objects produced during lazy loads
+ or refresh operations.
+
+ """
+
+ def __init__(self, payload=None):
+ self.payload = payload
+
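A condensed sketch of the consumption pattern described above; ``CacheControl`` and ``User`` are illustrative names:

    from sqlalchemy import event, select
    from sqlalchemy.orm import Session, UserDefinedOption

    class CacheControl(UserDefinedOption):
        # carries a per-execution flag into the do_orm_execute hook
        pass

    @event.listens_for(Session, "do_orm_execute")
    def _receive_options(orm_execute_state):
        for opt in orm_execute_state.user_defined_options:
            if isinstance(opt, CacheControl) and opt.payload == "no-cache":
                pass  # e.g. bypass a caching layer for this execution

    stmt = select(User).options(CacheControl("no-cache"))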
+
+@util.deprecated_cls(
+ "1.4",
+ "The :class:`.MapperOption class is deprecated and will be removed "
+ "in a future release. For "
+ "modifications to queries on a per-execution basis, use the "
+ ":class:`.UserDefinedOption` class to establish state within a "
+ ":class:`.Query` or other Core statement, then use the "
+ ":meth:`.SessionEvents.before_orm_execute` hook to consume them.",
+ constructor=None,
+)
+class MapperOption(ORMOption):
+ """Describe a modification to a Query"""
+
+ _is_legacy_option = True
+
+ propagate_to_loaders = False
+ """if True, indicate this option should be carried along
+ to "secondary" Query objects produced during lazy loads
+ or refresh operations.
+
+ """
+
+ def process_query(self, query):
+ """Apply a modification to the given :class:`_query.Query`."""
+
+ def process_query_conditionally(self, query):
+ """same as process_query(), except that this option may not
+ apply to the given query.
+
+ This is typically applied during a lazy load or scalar refresh
+ operation to propagate options stated in the original Query to the
+ new Query being used for the load. It occurs for those options that
+ specify propagate_to_loaders=True.
+
+ """
+
+ self.process_query(query)
+
+
+class LoaderStrategy(object):
+ """Describe the loading behavior of a StrategizedProperty object.
+
+ The ``LoaderStrategy`` interacts with the querying process in three
+ ways:
+
+ * it controls the configuration of the ``InstrumentedAttribute``
+ placed on a class to handle the behavior of the attribute. this
+ may involve setting up class-level callable functions to fire
+ off a select operation when the attribute is first accessed
+ (i.e. a lazy load)
+
+ * it processes the ``QueryContext`` at statement construction time,
+ where it can modify the SQL statement that is being produced.
+ For example, simple column attributes will add their represented
+ column to the list of selected columns, a joined eager loader
+ may establish join clauses to add to the statement.
+
+ * It produces "row processor" functions at result fetching time.
+ These "row processor" functions populate a particular attribute
+ on a particular mapped instance.
+
+ """
+
+ __slots__ = (
+ "parent_property",
+ "is_class_level",
+ "parent",
+ "key",
+ "strategy_key",
+ "strategy_opts",
+ )
+
+ def __init__(self, parent, strategy_key):
+ self.parent_property = parent
+ self.is_class_level = False
+ self.parent = self.parent_property.parent
+ self.key = self.parent_property.key
+ self.strategy_key = strategy_key
+ self.strategy_opts = dict(strategy_key)
+
+ def init_class_attribute(self, mapper):
+ pass
+
+ def setup_query(
+ self, compile_state, query_entity, path, loadopt, adapter, **kwargs
+ ):
+ """Establish column and other state for a given QueryContext.
+
+ This method fulfills the contract specified by MapperProperty.setup().
+
+ StrategizedProperty delegates its setup() method
+ directly to this method.
+
+ """
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ """Establish row processing functions for a given QueryContext.
+
+ This method fulfills the contract specified by
+ MapperProperty.create_row_processor().
+
+ StrategizedProperty delegates its create_row_processor() method
+ directly to this method.
+
+ """
+
+ def __str__(self):
+ return str(self.parent_property)
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
new file mode 100644
index 0000000..b5691c0
--- /dev/null
+++ b/lib/sqlalchemy/orm/loading.py
@@ -0,0 +1,1465 @@
+# orm/loading.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""private module containing functions used to convert database
+rows into object instances and associated state.
+
+the functions here are called primarily by Query, Mapper,
+as well as some of the attribute loading strategies.
+
+"""
+from __future__ import absolute_import
+
+from . import attributes
+from . import exc as orm_exc
+from . import path_registry
+from . import strategy_options
+from .base import _DEFER_FOR_STATE
+from .base import _RAISE_FOR_STATE
+from .base import _SET_DEFERRED_EXPIRED
+from .util import _none_set
+from .util import state_str
+from .. import exc as sa_exc
+from .. import future
+from .. import util
+from ..engine import result_tuple
+from ..engine.result import ChunkedIteratorResult
+from ..engine.result import FrozenResult
+from ..engine.result import SimpleResultMetaData
+from ..sql import util as sql_util
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..sql.selectable import SelectState
+
+_new_runid = util.counter()
+
+
+def instances(cursor, context):
+ """Return a :class:`.Result` given an ORM query context.
+
+ :param cursor: a :class:`.CursorResult`, generated by a statement
+ which came from :class:`.ORMCompileState`
+
+ :param context: a :class:`.QueryContext` object
+
+ :return: a :class:`.Result` object representing ORM results
+
+ .. versionchanged:: 1.4 The instances() function now uses
+ :class:`.Result` objects and has an all new interface.
+
+ """
+
+ context.runid = _new_runid()
+ context.post_load_paths = {}
+
+ compile_state = context.compile_state
+ filtered = compile_state._has_mapper_entities
+ single_entity = (
+ not context.load_options._only_return_tuples
+ and len(compile_state._entities) == 1
+ and compile_state._entities[0].supports_single_entity
+ )
+
+ try:
+ (process, labels, extra) = list(
+ zip(
+ *[
+ query_entity.row_processor(context, cursor)
+ for query_entity in context.compile_state._entities
+ ]
+ )
+ )
+
+ if context.yield_per and (
+ context.loaders_require_buffering
+ or context.loaders_require_uniquing
+ ):
+ raise sa_exc.InvalidRequestError(
+ "Can't use yield_per with eager loaders that require uniquing "
+ "or row buffering, e.g. joinedload() against collections "
+ "or subqueryload(). Consider the selectinload() strategy "
+ "for better flexibility in loading objects."
+ )
+
+ except Exception:
+ with util.safe_reraise():
+ cursor.close()
+
+ def _no_unique(entry):
+ raise sa_exc.InvalidRequestError(
+ "Can't use the ORM yield_per feature in conjunction with unique()"
+ )
+
+ def _not_hashable(datatype):
+ def go(obj):
+ raise sa_exc.InvalidRequestError(
+ "Can't apply uniqueness to row tuple containing value of "
+ "type %r; this datatype produces non-hashable values"
+ % datatype
+ )
+
+ return go
+
+ if context.load_options._legacy_uniquing:
+ unique_filters = [
+ _no_unique
+ if context.yield_per
+ else id
+ if (
+ ent.use_id_for_hash
+ or ent._non_hashable_value
+ or ent._null_column_type
+ )
+ else None
+ for ent in context.compile_state._entities
+ ]
+ else:
+ unique_filters = [
+ _no_unique
+ if context.yield_per
+ else _not_hashable(ent.column.type)
+ if (not ent.use_id_for_hash and ent._non_hashable_value)
+ else id
+ if ent.use_id_for_hash
+ else None
+ for ent in context.compile_state._entities
+ ]
+
+ row_metadata = SimpleResultMetaData(
+ labels, extra, _unique_filters=unique_filters
+ )
+
+ def chunks(size):
+ while True:
+ yield_per = size
+
+ context.partials = {}
+
+ if yield_per:
+ fetch = cursor.fetchmany(yield_per)
+
+ if not fetch:
+ break
+ else:
+ fetch = cursor._raw_all_rows()
+
+ if single_entity:
+ proc = process[0]
+ rows = [proc(row) for row in fetch]
+ else:
+ rows = [
+ tuple([proc(row) for proc in process]) for row in fetch
+ ]
+
+ for path, post_load in context.post_load_paths.items():
+ post_load.invoke(context, path)
+
+ yield rows
+
+ if not yield_per:
+ break
+
+ if context.execution_options.get("prebuffer_rows", False):
+ # this is a bit of a hack at the moment.
+ # I would rather have some option in the result to pre-buffer
+ # internally.
+ _prebuffered = list(chunks(None))
+
+ def chunks(size):
+ return iter(_prebuffered)
+
+ result = ChunkedIteratorResult(
+ row_metadata,
+ chunks,
+ source_supports_scalars=single_entity,
+ raw=cursor,
+ dynamic_yield_per=cursor.context._is_server_side,
+ )
+
+ # filtered and single_entity are used to indicate to legacy Query that the
+ # query has ORM entities, so legacy deduping and scalars should be called
+ # on the result.
+ result._attributes = result._attributes.union(
+ dict(filtered=filtered, is_single_entity=single_entity)
+ )
+
+ # multi_row_eager_loaders OTOH is specific to joinedload.
+ if context.compile_state.multi_row_eager_loaders:
+
+ def require_unique(obj):
+ raise sa_exc.InvalidRequestError(
+ "The unique() method must be invoked on this Result, "
+ "as it contains results that include joined eager loads "
+ "against collections"
+ )
+
+ result._unique_filter_state = (None, require_unique)
+
+ if context.yield_per:
+ result.yield_per(context.yield_per)
+
+ return result
+
+
+@util.preload_module("sqlalchemy.orm.context")
+def merge_frozen_result(session, statement, frozen_result, load=True):
+ """Merge a :class:`_engine.FrozenResult` back into a :class:`_orm.Session`,
+ returning a new :class:`_engine.Result` object with :term:`persistent`
+ objects.
+
+ See the section :ref:`do_orm_execute_re_executing` for an example.
+
+ .. seealso::
+
+ :ref:`do_orm_execute_re_executing`
+
+ :meth:`_engine.Result.freeze`
+
+ :class:`_engine.FrozenResult`
+
+ """
+ querycontext = util.preloaded.orm_context
+
+ if load:
+ # flush current contents if we expect to load data
+ session._autoflush()
+
+ ctx = querycontext.ORMSelectCompileState._create_entities_collection(
+ statement, legacy=False
+ )
+
+ autoflush = session.autoflush
+ try:
+ session.autoflush = False
+ mapped_entities = [
+ i
+ for i, e in enumerate(ctx._entities)
+ if isinstance(e, querycontext._MapperEntity)
+ ]
+ keys = [ent._label_name for ent in ctx._entities]
+
+ keyed_tuple = result_tuple(
+ keys, [ent._extra_entities for ent in ctx._entities]
+ )
+
+ result = []
+ for newrow in frozen_result.rewrite_rows():
+ for i in mapped_entities:
+ if newrow[i] is not None:
+ newrow[i] = session._merge(
+ attributes.instance_state(newrow[i]),
+ attributes.instance_dict(newrow[i]),
+ load=load,
+ _recursive={},
+ _resolve_conflict_map={},
+ )
+
+ result.append(keyed_tuple(newrow))
+
+ return frozen_result.with_new_rows(result)
+ finally:
+ session.autoflush = autoflush
+
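A hedged sketch of the re-execution pattern referenced above, assuming a simple dict-based cache and a caller that passes a ``cache_key`` execution option; note that the ``FrozenResult`` returned by ``merge_frozen_result()`` is invoked to produce the live ``Result`` handed back from the event hook:

    from sqlalchemy import event
    from sqlalchemy.orm import Session, loading

    _cache = {}  # cache_key -> FrozenResult (illustrative)

    @event.listens_for(Session, "do_orm_execute")
    def _cached_execute(orm_execute_state):
        key = orm_execute_state.execution_options.get("cache_key")
        if key is None:
            return None
        if key not in _cache:
            _cache[key] = orm_execute_state.invoke_statement().freeze()
        return loading.merge_frozen_result(
            orm_execute_state.session,
            orm_execute_state.statement,
            _cache[key],
            load=False,
        )()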
+
+@util.deprecated_20(
+ ":func:`_orm.merge_result`",
+ alternative="The function as well as the method on :class:`_orm.Query` "
+ "is superseded by the :func:`_orm.merge_frozen_result` function.",
+ becomes_legacy=True,
+)
+@util.preload_module("sqlalchemy.orm.context")
+def merge_result(query, iterator, load=True):
+ """Merge a result into the given :class:`.Query` object's Session.
+
+ See :meth:`_orm.Query.merge_result` for top-level documentation on this
+ function.
+
+ """
+
+ querycontext = util.preloaded.orm_context
+
+ session = query.session
+ if load:
+ # flush current contents if we expect to load data
+ session._autoflush()
+
+ # TODO: need test coverage and documentation for the FrozenResult
+ # use case.
+ if isinstance(iterator, FrozenResult):
+ frozen_result = iterator
+ iterator = iter(frozen_result.data)
+ else:
+ frozen_result = None
+
+ ctx = querycontext.ORMSelectCompileState._create_entities_collection(
+ query, legacy=True
+ )
+
+ autoflush = session.autoflush
+ try:
+ session.autoflush = False
+ single_entity = not frozen_result and len(ctx._entities) == 1
+
+ if single_entity:
+ if isinstance(ctx._entities[0], querycontext._MapperEntity):
+ result = [
+ session._merge(
+ attributes.instance_state(instance),
+ attributes.instance_dict(instance),
+ load=load,
+ _recursive={},
+ _resolve_conflict_map={},
+ )
+ for instance in iterator
+ ]
+ else:
+ result = list(iterator)
+ else:
+ mapped_entities = [
+ i
+ for i, e in enumerate(ctx._entities)
+ if isinstance(e, querycontext._MapperEntity)
+ ]
+ result = []
+ keys = [ent._label_name for ent in ctx._entities]
+
+ keyed_tuple = result_tuple(
+ keys, [ent._extra_entities for ent in ctx._entities]
+ )
+
+ for row in iterator:
+ newrow = list(row)
+ for i in mapped_entities:
+ if newrow[i] is not None:
+ newrow[i] = session._merge(
+ attributes.instance_state(newrow[i]),
+ attributes.instance_dict(newrow[i]),
+ load=load,
+ _recursive={},
+ _resolve_conflict_map={},
+ )
+ result.append(keyed_tuple(newrow))
+
+ if frozen_result:
+ return frozen_result.with_data(result)
+ else:
+ return iter(result)
+ finally:
+ session.autoflush = autoflush
+
+
+def get_from_identity(session, mapper, key, passive):
+ """Look up the given key in the given session's identity map,
+ check the object for expired state if found.
+
+ """
+ instance = session.identity_map.get(key)
+ if instance is not None:
+
+ state = attributes.instance_state(instance)
+
+ if mapper.inherits and not state.mapper.isa(mapper):
+ return attributes.PASSIVE_CLASS_MISMATCH
+
+ # expired - ensure it still exists
+ if state.expired:
+ if not passive & attributes.SQL_OK:
+ # TODO: no coverage here
+ return attributes.PASSIVE_NO_RESULT
+ elif not passive & attributes.RELATED_OBJECT_OK:
+ # this mode is used within a flush and the instance's
+ # expired state will be checked soon enough, if necessary.
+ # also used by immediateloader for a mutually-dependent
+ # o2m->m2m load, :ticket:`6301`
+ return instance
+ try:
+ state._load_expired(state, passive)
+ except orm_exc.ObjectDeletedError:
+ session._remove_newly_deleted([state])
+ return None
+ return instance
+ else:
+ return None
+
+
+def load_on_ident(
+ session,
+ statement,
+ key,
+ load_options=None,
+ refresh_state=None,
+ with_for_update=None,
+ only_load_props=None,
+ no_autoflush=False,
+ bind_arguments=util.EMPTY_DICT,
+ execution_options=util.EMPTY_DICT,
+):
+ """Load the given identity key from the database."""
+ if key is not None:
+ ident = key[1]
+ identity_token = key[2]
+ else:
+ ident = identity_token = None
+
+ return load_on_pk_identity(
+ session,
+ statement,
+ ident,
+ load_options=load_options,
+ refresh_state=refresh_state,
+ with_for_update=with_for_update,
+ only_load_props=only_load_props,
+ identity_token=identity_token,
+ no_autoflush=no_autoflush,
+ bind_arguments=bind_arguments,
+ execution_options=execution_options,
+ )
+
+
+def load_on_pk_identity(
+ session,
+ statement,
+ primary_key_identity,
+ load_options=None,
+ refresh_state=None,
+ with_for_update=None,
+ only_load_props=None,
+ identity_token=None,
+ no_autoflush=False,
+ bind_arguments=util.EMPTY_DICT,
+ execution_options=util.EMPTY_DICT,
+):
+
+ """Load the given primary key identity from the database."""
+
+ query = statement
+ q = query._clone()
+
+ assert not q._is_lambda_element
+
+ # TODO: fix these imports ....
+ from .context import QueryContext, ORMCompileState
+
+ if load_options is None:
+ load_options = QueryContext.default_load_options
+
+ if (
+ statement._compile_options
+ is SelectState.default_select_compile_options
+ ):
+ compile_options = ORMCompileState.default_compile_options
+ else:
+ compile_options = statement._compile_options
+
+ if primary_key_identity is not None:
+ mapper = query._propagate_attrs["plugin_subject"]
+
+ (_get_clause, _get_params) = mapper._get_clause
+
+ # None present in ident - turn those comparisons
+ # into "IS NULL"
+ if None in primary_key_identity:
+ nones = set(
+ [
+ _get_params[col].key
+ for col, value in zip(
+ mapper.primary_key, primary_key_identity
+ )
+ if value is None
+ ]
+ )
+
+ _get_clause = sql_util.adapt_criterion_to_null(_get_clause, nones)
+
+ if len(nones) == len(primary_key_identity):
+ util.warn(
+ "fully NULL primary key identity cannot load any "
+ "object. This condition may raise an error in a future "
+ "release."
+ )
+
+ q._where_criteria = (
+ sql_util._deep_annotate(_get_clause, {"_orm_adapt": True}),
+ )
+
+ params = dict(
+ [
+ (_get_params[primary_key].key, id_val)
+ for id_val, primary_key in zip(
+ primary_key_identity, mapper.primary_key
+ )
+ ]
+ )
+ else:
+ params = None
+
+ if with_for_update is not None:
+ version_check = True
+ q._for_update_arg = with_for_update
+ elif query._for_update_arg is not None:
+ version_check = True
+ q._for_update_arg = query._for_update_arg
+ else:
+ version_check = False
+
+ if refresh_state and refresh_state.load_options:
+ compile_options += {"_current_path": refresh_state.load_path.parent}
+ q = q.options(*refresh_state.load_options)
+
+ new_compile_options, load_options = _set_get_options(
+ compile_options,
+ load_options,
+ version_check=version_check,
+ only_load_props=only_load_props,
+ refresh_state=refresh_state,
+ identity_token=identity_token,
+ )
+ q._compile_options = new_compile_options
+ q._order_by = None
+
+ if no_autoflush:
+ load_options += {"_autoflush": False}
+
+ execution_options = util.EMPTY_DICT.merge_with(
+ execution_options, {"_sa_orm_load_options": load_options}
+ )
+ result = (
+ session.execute(
+ q,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ )
+ .unique()
+ .scalars()
+ )
+
+ try:
+ return result.one()
+ except orm_exc.NoResultFound:
+ return None
+
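For orientation, the most common public entry point that reaches ``load_on_pk_identity()`` is ``Session.get()``; ``User`` below is a hypothetical mapped class:

    # identity-map hit, or a SELECT by primary key
    user = session.get(User, 5)

    # FOR UPDATE variant; also enables the version_check branch above
    user = session.get(User, 5, with_for_update=True)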
+
+def _set_get_options(
+ compile_opt,
+ load_opt,
+ populate_existing=None,
+ version_check=None,
+ only_load_props=None,
+ refresh_state=None,
+ identity_token=None,
+):
+
+ compile_options = {}
+ load_options = {}
+ if version_check:
+ load_options["_version_check"] = version_check
+ if populate_existing:
+ load_options["_populate_existing"] = populate_existing
+ if refresh_state:
+ load_options["_refresh_state"] = refresh_state
+ compile_options["_for_refresh_state"] = True
+ if only_load_props:
+ compile_options["_only_load_props"] = frozenset(only_load_props)
+ if identity_token:
+ load_options["_refresh_identity_token"] = identity_token
+
+ if load_options:
+ load_opt += load_options
+ if compile_options:
+ compile_opt += compile_options
+
+ return compile_opt, load_opt
+
+
+def _setup_entity_query(
+ compile_state,
+ mapper,
+ query_entity,
+ path,
+ adapter,
+ column_collection,
+ with_polymorphic=None,
+ only_load_props=None,
+ polymorphic_discriminator=None,
+ **kw
+):
+
+ if with_polymorphic:
+ poly_properties = mapper._iterate_polymorphic_properties(
+ with_polymorphic
+ )
+ else:
+ poly_properties = mapper._polymorphic_properties
+
+ quick_populators = {}
+
+ path.set(compile_state.attributes, "memoized_setups", quick_populators)
+
+ # for the lead entities in the path, e.g. not eager loads, and
+ # assuming a user-passed aliased class, e.g. not a from_self() or any
+ # implicit aliasing, don't add columns to the SELECT that aren't
+ # in the thing that's aliased.
+ check_for_adapt = adapter and len(path) == 1 and path[-1].is_aliased_class
+
+ for value in poly_properties:
+ if only_load_props and value.key not in only_load_props:
+ continue
+
+ value.setup(
+ compile_state,
+ query_entity,
+ path,
+ adapter,
+ only_load_props=only_load_props,
+ column_collection=column_collection,
+ memoized_populators=quick_populators,
+ check_for_adapt=check_for_adapt,
+ **kw
+ )
+
+ if (
+ polymorphic_discriminator is not None
+ and polymorphic_discriminator is not mapper.polymorphic_on
+ ):
+
+ if adapter:
+ pd = adapter.columns[polymorphic_discriminator]
+ else:
+ pd = polymorphic_discriminator
+ column_collection.append(pd)
+
+
+def _warn_for_runid_changed(state):
+ util.warn(
+ "Loading context for %s has changed within a load/refresh "
+ "handler, suggesting a row refresh operation took place. If this "
+ "event handler is expected to be "
+ "emitting row refresh operations within an existing load or refresh "
+ "operation, set restore_load_context=True when establishing the "
+ "listener to ensure the context remains unchanged when the event "
+ "handler completes." % (state_str(state),)
+ )
+
+
+def _instance_processor(
+ query_entity,
+ mapper,
+ context,
+ result,
+ path,
+ adapter,
+ only_load_props=None,
+ refresh_state=None,
+ polymorphic_discriminator=None,
+ _polymorphic_from=None,
+):
+ """Produce a mapper level row processor callable
+ which processes rows into mapped instances."""
+
+ # note that this method, most of which exists in a closure
+ # called _instance(), resists being broken out, as
+ # attempts to do so tend to add significant function
+ # call overhead. _instance() is the most
+ # performance-critical section in the whole ORM.
+
+ identity_class = mapper._identity_class
+ compile_state = context.compile_state
+
+ # look for "row getter" functions that have been assigned along
+ # with the compile state that were cached from a previous load.
+ # these are operator.itemgetter() objects that each will extract a
+ # particular column from each row.
+
+ getter_key = ("getters", mapper)
+ getters = path.get(compile_state.attributes, getter_key, None)
+
+ if getters is None:
+ # no getters, so go through a list of attributes we are loading for,
+ # and the ones that are column based will have already put information
+ # for us in another collection "memoized_setups", which represents the
+ # output of the LoaderStrategy.setup_query() method. We can just as
+ # easily call LoaderStrategy.create_row_processor for each, but by
+ # getting it all at once from setup_query we save another method call
+ # per attribute.
+ props = mapper._prop_set
+ if only_load_props is not None:
+ props = props.intersection(
+ mapper._props[k] for k in only_load_props
+ )
+
+ quick_populators = path.get(
+ context.attributes, "memoized_setups", _none_set
+ )
+
+ todo = []
+ cached_populators = {
+ "new": [],
+ "quick": [],
+ "deferred": [],
+ "expire": [],
+ "delayed": [],
+ "existing": [],
+ "eager": [],
+ }
+
+ if refresh_state is None:
+ # we can also get the "primary key" tuple getter function
+ pk_cols = mapper.primary_key
+
+ if adapter:
+ pk_cols = [adapter.columns[c] for c in pk_cols]
+ primary_key_getter = result._tuple_getter(pk_cols)
+ else:
+ primary_key_getter = None
+
+ getters = {
+ "cached_populators": cached_populators,
+ "todo": todo,
+ "primary_key_getter": primary_key_getter,
+ }
+ for prop in props:
+ if prop in quick_populators:
+ # this is an inlined path just for column-based attributes.
+ col = quick_populators[prop]
+ if col is _DEFER_FOR_STATE:
+ cached_populators["new"].append(
+ (prop.key, prop._deferred_column_loader)
+ )
+ elif col is _SET_DEFERRED_EXPIRED:
+ # note that in this path, we are no longer
+ # searching in the result to see if the column might
+ # be present in some unexpected way.
+ cached_populators["expire"].append((prop.key, False))
+ elif col is _RAISE_FOR_STATE:
+ cached_populators["new"].append(
+ (prop.key, prop._raise_column_loader)
+ )
+ else:
+ getter = None
+ if adapter:
+ # this logic had been removed for all 1.4 releases
+ # up until 1.4.18; the adapter here is particularly
+ # the compound eager adapter which isn't accommodated
+ # in the quick_populators right now. The "fallback"
+ # logic below instead took over in many more cases
+ # until issue #6596 was identified.
+
+ # note there is still an issue where this codepath
+ # produces no "getter" for cases where a joined-inh
+ # mapping includes a labeled column property, meaning
+ # KeyError is caught internally and we fall back to
+ # _getter(col), which works anyway. The adapter
+ # here for joined inh without any aliasing might not
+ # be useful. Tests which see this include
+ # test.orm.inheritance.test_basic ->
+ # EagerTargetingTest.test_adapt_stringency
+ # OptimizedLoadTest.test_column_expression_joined
+ # PolymorphicOnNotLocalTest.test_polymorphic_on_column_prop # noqa: E501
+ #
+
+ adapted_col = adapter.columns[col]
+ if adapted_col is not None:
+ getter = result._getter(adapted_col, False)
+ if not getter:
+ getter = result._getter(col, False)
+ if getter:
+ cached_populators["quick"].append((prop.key, getter))
+ else:
+ # fall back to the ColumnProperty itself, which
+ # will iterate through all of its columns
+ # to see if one fits
+ prop.create_row_processor(
+ context,
+ query_entity,
+ path,
+ mapper,
+ result,
+ adapter,
+ cached_populators,
+ )
+ else:
+ # loader strategies like subqueryload, selectinload,
+ # joinedload, basically relationships, these need to interact
+ # with the context each time to work correctly.
+ todo.append(prop)
+
+ path.set(compile_state.attributes, getter_key, getters)
+
+ cached_populators = getters["cached_populators"]
+
+ populators = {key: list(value) for key, value in cached_populators.items()}
+ for prop in getters["todo"]:
+ prop.create_row_processor(
+ context, query_entity, path, mapper, result, adapter, populators
+ )
+
+ propagated_loader_options = context.propagated_loader_options
+ load_path = (
+ context.compile_state.current_path + path
+ if context.compile_state.current_path.path
+ else path
+ )
+
+ session_identity_map = context.session.identity_map
+
+ populate_existing = context.populate_existing or mapper.always_refresh
+ load_evt = bool(mapper.class_manager.dispatch.load)
+ refresh_evt = bool(mapper.class_manager.dispatch.refresh)
+ persistent_evt = bool(context.session.dispatch.loaded_as_persistent)
+ if persistent_evt:
+ loaded_as_persistent = context.session.dispatch.loaded_as_persistent
+ instance_state = attributes.instance_state
+ instance_dict = attributes.instance_dict
+ session_id = context.session.hash_key
+ runid = context.runid
+ identity_token = context.identity_token
+
+ version_check = context.version_check
+ if version_check:
+ version_id_col = mapper.version_id_col
+ if version_id_col is not None:
+ if adapter:
+ version_id_col = adapter.columns[version_id_col]
+ version_id_getter = result._getter(version_id_col)
+ else:
+ version_id_getter = None
+
+ if not refresh_state and _polymorphic_from is not None:
+ key = ("loader", path.path)
+ if key in context.attributes and context.attributes[key].strategy == (
+ ("selectinload_polymorphic", True),
+ ):
+ selectin_load_via = mapper._should_selectin_load(
+ context.attributes[key].local_opts["entities"],
+ _polymorphic_from,
+ )
+ else:
+ selectin_load_via = mapper._should_selectin_load(
+ None, _polymorphic_from
+ )
+
+ if selectin_load_via and selectin_load_via is not _polymorphic_from:
+ # only_load_props goes w/ refresh_state only, and in a refresh
+ # we are a single row query for the exact entity; polymorphic
+ # loading does not apply
+ assert only_load_props is None
+
+ callable_ = _load_subclass_via_in(context, path, selectin_load_via)
+
+ PostLoad.callable_for_path(
+ context,
+ load_path,
+ selectin_load_via.mapper,
+ selectin_load_via,
+ callable_,
+ selectin_load_via,
+ )
+
+ post_load = PostLoad.for_context(context, load_path, only_load_props)
+
+ if refresh_state:
+ refresh_identity_key = refresh_state.key
+ if refresh_identity_key is None:
+ # super-rare condition; a refresh is being called
+ # on a non-instance-key instance; this is meant to only
+ # occur within a flush()
+ refresh_identity_key = mapper._identity_key_from_state(
+ refresh_state
+ )
+ else:
+ refresh_identity_key = None
+
+ primary_key_getter = getters["primary_key_getter"]
+
+ if mapper.allow_partial_pks:
+ is_not_primary_key = _none_set.issuperset
+ else:
+ is_not_primary_key = _none_set.intersection
+
+ def _instance(row):
+
+ # determine the state that we'll be populating
+ if refresh_identity_key:
+ # fixed state that we're refreshing
+ state = refresh_state
+ instance = state.obj()
+ dict_ = instance_dict(instance)
+ isnew = state.runid != runid
+ currentload = True
+ loaded_instance = False
+ else:
+ # look at the row, see if that identity is in the
+ # session, or we have to create a new one
+ identitykey = (
+ identity_class,
+ primary_key_getter(row),
+ identity_token,
+ )
+
+ instance = session_identity_map.get(identitykey)
+
+ if instance is not None:
+ # existing instance
+ state = instance_state(instance)
+ dict_ = instance_dict(instance)
+
+ isnew = state.runid != runid
+ currentload = not isnew
+ loaded_instance = False
+
+ if version_check and version_id_getter and not currentload:
+ _validate_version_id(
+ mapper, state, dict_, row, version_id_getter
+ )
+
+ else:
+ # create a new instance
+
+ # check for non-NULL values in the primary key columns,
+ # else no entity is returned for the row
+ if is_not_primary_key(identitykey[1]):
+ return None
+
+ isnew = True
+ currentload = True
+ loaded_instance = True
+
+ instance = mapper.class_manager.new_instance()
+
+ dict_ = instance_dict(instance)
+ state = instance_state(instance)
+ state.key = identitykey
+ state.identity_token = identity_token
+
+ # attach instance to session.
+ state.session_id = session_id
+ session_identity_map._add_unpresent(state, identitykey)
+
+ effective_populate_existing = populate_existing
+ if refresh_state is state:
+ effective_populate_existing = True
+
+ # populate. this looks at whether this state is new
+ # for this load or was existing, and whether or not this
+ # row is the first row with this identity.
+ if currentload or effective_populate_existing:
+ # full population routines. Objects here are either
+ # just created, or we are doing a populate_existing
+
+ # be conservative about setting load_path when populate_existing
+ # is in effect; want to maintain options from the original
+ # load. see test_expire->test_refresh_maintains_deferred_options
+ if isnew and (
+ propagated_loader_options or not effective_populate_existing
+ ):
+ state.load_options = propagated_loader_options
+ state.load_path = load_path
+
+ _populate_full(
+ context,
+ row,
+ state,
+ dict_,
+ isnew,
+ load_path,
+ loaded_instance,
+ effective_populate_existing,
+ populators,
+ )
+
+ if isnew:
+ # state.runid should be equal to context.runid / runid
+ # here, however for event checks we are being more conservative
+ # and checking against existing run id
+ # assert state.runid == runid
+
+ existing_runid = state.runid
+
+ if loaded_instance:
+ if load_evt:
+ state.manager.dispatch.load(state, context)
+ if state.runid != existing_runid:
+ _warn_for_runid_changed(state)
+ if persistent_evt:
+ loaded_as_persistent(context.session, state)
+ if state.runid != existing_runid:
+ _warn_for_runid_changed(state)
+ elif refresh_evt:
+ state.manager.dispatch.refresh(
+ state, context, only_load_props
+ )
+ if state.runid != runid:
+ _warn_for_runid_changed(state)
+
+ if effective_populate_existing or state.modified:
+ if refresh_state and only_load_props:
+ state._commit(dict_, only_load_props)
+ else:
+ state._commit_all(dict_, session_identity_map)
+
+ if post_load:
+ post_load.add_state(state, True)
+
+ else:
+ # partial population routines, for objects that were already
+ # in the Session, but a row matches them; apply eager loaders
+ # on existing objects, etc.
+ unloaded = state.unloaded
+ isnew = state not in context.partials
+
+ if not isnew or unloaded or populators["eager"]:
+ # state is having a partial set of its attributes
+ # refreshed. Populate those attributes,
+ # and add to the "context.partials" collection.
+
+ to_load = _populate_partial(
+ context,
+ row,
+ state,
+ dict_,
+ isnew,
+ load_path,
+ unloaded,
+ populators,
+ )
+
+ if isnew:
+ if refresh_evt:
+ existing_runid = state.runid
+ state.manager.dispatch.refresh(state, context, to_load)
+ if state.runid != existing_runid:
+ _warn_for_runid_changed(state)
+
+ state._commit(dict_, to_load)
+
+ if post_load and context.invoke_all_eagers:
+ post_load.add_state(state, False)
+
+ return instance
+
+ if mapper.polymorphic_map and not _polymorphic_from and not refresh_state:
+ # if we are doing polymorphic, dispatch to a different _instance()
+ # method specific to the subclass mapper
+ def ensure_no_pk(row):
+ identitykey = (
+ identity_class,
+ primary_key_getter(row),
+ identity_token,
+ )
+ if not is_not_primary_key(identitykey[1]):
+ return identitykey
+ else:
+ return None
+
+ _instance = _decorate_polymorphic_switch(
+ _instance,
+ context,
+ query_entity,
+ mapper,
+ result,
+ path,
+ polymorphic_discriminator,
+ adapter,
+ ensure_no_pk,
+ )
+
+ return _instance
+
+
+def _load_subclass_via_in(context, path, entity):
+ mapper = entity.mapper
+
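+    # when the base mapper has a single-column primary key, scalar key
+    # values are passed to the IN expression rather than one-element
+    # tuples (see do_load() below)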
+ zero_idx = len(mapper.base_mapper.primary_key) == 1
+
+ if entity.is_aliased_class:
+ q, enable_opt, disable_opt = mapper._subclass_load_via_in(entity)
+ else:
+ q, enable_opt, disable_opt = mapper._subclass_load_via_in_mapper
+
+ def do_load(context, path, states, load_only, effective_entity):
+ orig_query = context.query
+
+ options = (enable_opt,) + orig_query._with_options + (disable_opt,)
+ q2 = q.options(*options)
+
+ q2._compile_options = context.compile_state.default_compile_options
+ q2._compile_options += {"_current_path": path.parent}
+
+ if context.populate_existing:
+ q2 = q2.execution_options(populate_existing=True)
+
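+        # the result rows are not used directly; executing the statement
+        # populates the subclass-specific attributes onto the parent
+        # objects already present in the identity map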
+ context.session.execute(
+ q2,
+ dict(
+ primary_keys=[
+ state.key[1][0] if zero_idx else state.key[1]
+ for state, load_attrs in states
+ ]
+ ),
+ ).unique().scalars().all()
+
+ return do_load
+
+
+def _populate_full(
+ context,
+ row,
+ state,
+ dict_,
+ isnew,
+ load_path,
+ loaded_instance,
+ populate_existing,
+ populators,
+):
+ if isnew:
+ # first time we are seeing a row with this identity.
+ state.runid = context.runid
+
+ for key, getter in populators["quick"]:
+ dict_[key] = getter(row)
+ if populate_existing:
+ for key, set_callable in populators["expire"]:
+ dict_.pop(key, None)
+ if set_callable:
+ state.expired_attributes.add(key)
+ else:
+ for key, set_callable in populators["expire"]:
+ if set_callable:
+ state.expired_attributes.add(key)
+
+ for key, populator in populators["new"]:
+ populator(state, dict_, row)
+ for key, populator in populators["delayed"]:
+ populator(state, dict_, row)
+ elif load_path != state.load_path:
+ # new load path, e.g. object is present in more than one
+ # column position in a series of rows
+ state.load_path = load_path
+
+ # if we have data, and the data isn't in the dict, OK, let's put
+ # it in.
+ for key, getter in populators["quick"]:
+ if key not in dict_:
+ dict_[key] = getter(row)
+
+ # otherwise treat like an "already seen" row
+ for key, populator in populators["existing"]:
+ populator(state, dict_, row)
+ # TODO: allow "existing" populator to know this is
+ # a new path for the state:
+ # populator(state, dict_, row, new_path=True)
+
+ else:
+ # have already seen rows with this identity in this same path.
+ for key, populator in populators["existing"]:
+ populator(state, dict_, row)
+
+ # TODO: same path
+ # populator(state, dict_, row, new_path=False)
+
+
+def _populate_partial(
+ context, row, state, dict_, isnew, load_path, unloaded, populators
+):
+
+ if not isnew:
+ to_load = context.partials[state]
+ for key, populator in populators["existing"]:
+ if key in to_load:
+ populator(state, dict_, row)
+ else:
+ to_load = unloaded
+ context.partials[state] = to_load
+
+ for key, getter in populators["quick"]:
+ if key in to_load:
+ dict_[key] = getter(row)
+ for key, set_callable in populators["expire"]:
+ if key in to_load:
+ dict_.pop(key, None)
+ if set_callable:
+ state.expired_attributes.add(key)
+ for key, populator in populators["new"]:
+ if key in to_load:
+ populator(state, dict_, row)
+ for key, populator in populators["delayed"]:
+ if key in to_load:
+ populator(state, dict_, row)
+ for key, populator in populators["eager"]:
+ if key not in unloaded:
+ populator(state, dict_, row)
+
+ return to_load
+
+
+def _validate_version_id(mapper, state, dict_, row, getter):
+
+ if mapper._get_state_attr_by_column(
+ state, dict_, mapper.version_id_col
+ ) != getter(row):
+ raise orm_exc.StaleDataError(
+ "Instance '%s' has version id '%s' which "
+ "does not match database-loaded version id '%s'."
+ % (
+ state_str(state),
+ mapper._get_state_attr_by_column(
+ state, dict_, mapper.version_id_col
+ ),
+ getter(row),
+ )
+ )
+
+
+def _decorate_polymorphic_switch(
+ instance_fn,
+ context,
+ query_entity,
+ mapper,
+ result,
+ path,
+ polymorphic_discriminator,
+ adapter,
+ ensure_no_pk,
+):
+ if polymorphic_discriminator is not None:
+ polymorphic_on = polymorphic_discriminator
+ else:
+ polymorphic_on = mapper.polymorphic_on
+ if polymorphic_on is None:
+ return instance_fn
+
+ if adapter:
+ polymorphic_on = adapter.columns[polymorphic_on]
+
+ def configure_subclass_mapper(discriminator):
+ try:
+ sub_mapper = mapper.polymorphic_map[discriminator]
+ except KeyError:
+ raise AssertionError(
+ "No such polymorphic_identity %r is defined" % discriminator
+ )
+ else:
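+            # None means the row belongs to this mapper itself and the
+            # default _instance() function handles it; False flags a
+            # discriminator whose mapper is not a subclass of this one,
+            # which polymorphic_instance() below treats as an error when
+            # the row carries a primary key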
+ if sub_mapper is mapper:
+ return None
+ elif not sub_mapper.isa(mapper):
+ return False
+
+ return _instance_processor(
+ query_entity,
+ sub_mapper,
+ context,
+ result,
+ path,
+ adapter,
+ _polymorphic_from=mapper,
+ )
+
+ polymorphic_instances = util.PopulateDict(configure_subclass_mapper)
+
+ getter = result._getter(polymorphic_on)
+
+ def polymorphic_instance(row):
+ discriminator = getter(row)
+ if discriminator is not None:
+ _instance = polymorphic_instances[discriminator]
+ if _instance:
+ return _instance(row)
+ elif _instance is False:
+ identitykey = ensure_no_pk(row)
+
+ if identitykey:
+ raise sa_exc.InvalidRequestError(
+ "Row with identity key %s can't be loaded into an "
+ "object; the polymorphic discriminator column '%s' "
+ "refers to %s, which is not a sub-mapper of "
+ "the requested %s"
+ % (
+ identitykey,
+ polymorphic_on,
+ mapper.polymorphic_map[discriminator],
+ mapper,
+ )
+ )
+ else:
+ return None
+ else:
+ return instance_fn(row)
+ else:
+ identitykey = ensure_no_pk(row)
+
+ if identitykey:
+ raise sa_exc.InvalidRequestError(
+ "Row with identity key %s can't be loaded into an "
+ "object; the polymorphic discriminator column '%s' is "
+ "NULL" % (identitykey, polymorphic_on)
+ )
+ else:
+ return None
+
+ return polymorphic_instance
+
+
+class PostLoad(object):
+ """Track loaders and states for "post load" operations."""
+
+ __slots__ = "loaders", "states", "load_keys"
+
+ def __init__(self):
+ self.loaders = {}
+ self.states = util.OrderedDict()
+ self.load_keys = None
+
+ def add_state(self, state, overwrite):
+ # the states for a polymorphic load here are all shared
+ # within a single PostLoad object among multiple subtypes.
+ # Filtering of callables on a per-subclass basis needs to be done at
+ # the invocation level
+ self.states[state] = overwrite
+
+ def invoke(self, context, path):
+ if not self.states:
+ return
+ path = path_registry.PathRegistry.coerce(path)
+ for token, limit_to_mapper, loader, arg, kw in self.loaders.values():
+ states = [
+ (state, overwrite)
+ for state, overwrite in self.states.items()
+ if state.manager.mapper.isa(limit_to_mapper)
+ ]
+ if states:
+ loader(context, path, states, self.load_keys, *arg, **kw)
+ self.states.clear()
+
+ @classmethod
+ def for_context(cls, context, path, only_load_props):
+ pl = context.post_load_paths.get(path.path)
+ if pl is not None and only_load_props:
+ pl.load_keys = only_load_props
+ return pl
+
+ @classmethod
+    def path_exists(cls, context, path, key):
+ return (
+ path.path in context.post_load_paths
+ and key in context.post_load_paths[path.path].loaders
+ )
+
+ @classmethod
+ def callable_for_path(
+ cls, context, path, limit_to_mapper, token, loader_callable, *arg, **kw
+ ):
+ if path.path in context.post_load_paths:
+ pl = context.post_load_paths[path.path]
+ else:
+ pl = context.post_load_paths[path.path] = PostLoad()
+ pl.loaders[token] = (token, limit_to_mapper, loader_callable, arg, kw)
+
+
+def load_scalar_attributes(mapper, state, attribute_names, passive):
+ """initiate a column-based attribute refresh operation."""
+
+ # assert mapper is _state_mapper(state)
+ session = state.session
+ if not session:
+ raise orm_exc.DetachedInstanceError(
+ "Instance %s is not bound to a Session; "
+ "attribute refresh operation cannot proceed" % (state_str(state))
+ )
+
+ has_key = bool(state.key)
+
+ result = False
+
+ no_autoflush = (
+ bool(passive & attributes.NO_AUTOFLUSH) or state.session.autocommit
+ )
+
+ # in the case of inheritance, particularly concrete and abstract
+ # concrete inheritance, the class manager might have some keys
+ # of attributes on the superclass that we didn't actually map.
+ # These could be mapped as "concrete, don't load" or could be completely
+ # excluded from the mapping and we know nothing about them. Filter them
+ # here to prevent them from coming through.
+ if attribute_names:
+ attribute_names = attribute_names.intersection(mapper.attrs.keys())
+
+ if mapper.inherits and not mapper.concrete:
+ # because we are using Core to produce a select() that we
+ # pass to the Query, we aren't calling setup() for mapped
+ # attributes; in 1.0 this means deferred attrs won't get loaded
+ # by default
+ statement = mapper._optimized_get_statement(state, attribute_names)
+ if statement is not None:
+ # this was previously aliased(mapper, statement), however,
+ # statement is a select() and Query's coercion now raises for this
+ # since you can't "select" from a "SELECT" statement. only
+ # from_statement() allows this.
+ # note: using from_statement() here means there is an adaption
+ # with adapt_on_names set up. the other option is to make the
+ # aliased() against a subquery which affects the SQL.
+
+ from .query import FromStatement
+
+ stmt = FromStatement(mapper, statement).options(
+ strategy_options.Load(mapper).undefer("*")
+ )
+
+ result = load_on_ident(
+ session,
+ stmt,
+ None,
+ only_load_props=attribute_names,
+ refresh_state=state,
+ no_autoflush=no_autoflush,
+ )
+
+ if result is False:
+ if has_key:
+ identity_key = state.key
+ else:
+ # this codepath is rare - only valid when inside a flush, and the
+ # object is becoming persistent but hasn't yet been assigned
+ # an identity_key.
+ # check here to ensure we have the attrs we need.
+ pk_attrs = [
+ mapper._columntoproperty[col].key for col in mapper.primary_key
+ ]
+ if state.expired_attributes.intersection(pk_attrs):
+ raise sa_exc.InvalidRequestError(
+ "Instance %s cannot be refreshed - it's not "
+                    "persistent and does not "
+ "contain a full primary key." % state_str(state)
+ )
+ identity_key = mapper._identity_key_from_state(state)
+
+ if (
+ _none_set.issubset(identity_key) and not mapper.allow_partial_pks
+ ) or _none_set.issuperset(identity_key):
+ util.warn_limited(
+ "Instance %s to be refreshed doesn't "
+ "contain a full primary key - can't be refreshed "
+ "(and shouldn't be expired, either).",
+ state_str(state),
+ )
+ return
+
+ result = load_on_ident(
+ session,
+ future.select(mapper).set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ ),
+ identity_key,
+ refresh_state=state,
+ only_load_props=attribute_names,
+ no_autoflush=no_autoflush,
+ )
+
+ # if instance is pending, a refresh operation
+ # may not complete (even if PK attributes are assigned)
+ if has_key and result is None:
+ raise orm_exc.ObjectDeletedError(state)
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
new file mode 100644
index 0000000..ed221a9
--- /dev/null
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -0,0 +1,3658 @@
+# orm/mapper.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Logic to map Python classes to and from selectables.
+
+Defines the :class:`~sqlalchemy.orm.mapper.Mapper` class, the central
+configurational unit which associates a class with a database table.
+
+This is a semi-private module; the main configurational API of the ORM is
+available in the :mod:`sqlalchemy.orm` package.
+
+"""
+from __future__ import absolute_import
+
+from collections import deque
+from itertools import chain
+import sys
+import weakref
+
+from . import attributes
+from . import exc as orm_exc
+from . import instrumentation
+from . import loading
+from . import properties
+from . import util as orm_util
+from .base import _class_to_mapper
+from .base import _state_mapper
+from .base import class_mapper
+from .base import state_str
+from .interfaces import _MappedAttribute
+from .interfaces import EXT_SKIP
+from .interfaces import InspectionAttr
+from .interfaces import MapperProperty
+from .interfaces import ORMEntityColumnsClauseRole
+from .interfaces import ORMFromClauseRole
+from .interfaces import StrategizedProperty
+from .path_registry import PathRegistry
+from .. import event
+from .. import exc as sa_exc
+from .. import inspection
+from .. import log
+from .. import schema
+from .. import sql
+from .. import util
+from ..sql import base as sql_base
+from ..sql import coercions
+from ..sql import expression
+from ..sql import operators
+from ..sql import roles
+from ..sql import util as sql_util
+from ..sql import visitors
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..util import HasMemoized
+
+_mapper_registries = weakref.WeakKeyDictionary()
+
+_legacy_registry = None
+
+
+def _all_registries():
+ with _CONFIGURE_MUTEX:
+ return set(_mapper_registries)
+
+
+def _unconfigured_mappers():
+ for reg in _all_registries():
+ for mapper in reg._mappers_to_configure():
+ yield mapper
+
+
+_already_compiling = False
+
+
+# a constant returned by _get_attr_by_column to indicate
+# this mapper is not handling an attribute for a particular
+# column
+NO_ATTRIBUTE = util.symbol("NO_ATTRIBUTE")
+
+# lock used to synchronize the "mapper configure" step
+_CONFIGURE_MUTEX = util.threading.RLock()
+
+
+@inspection._self_inspects
+@log.class_logger
+class Mapper(
+ ORMFromClauseRole,
+ ORMEntityColumnsClauseRole,
+ sql_base.MemoizedHasCacheKey,
+ InspectionAttr,
+):
+ """Defines an association between a Python class and a database table or
+ other relational structure, so that ORM operations against the class may
+ proceed.
+
+ The :class:`_orm.Mapper` object is instantiated using mapping methods
+ present on the :class:`_orm.registry` object. For information
+ about instantiating new :class:`_orm.Mapper` objects, see
+ :ref:`orm_mapping_classes_toplevel`.
+
+ """
+
+ _dispose_called = False
+ _ready_for_configure = False
+
+ @util.deprecated_params(
+ non_primary=(
+ "1.3",
+ "The :paramref:`.mapper.non_primary` parameter is deprecated, "
+ "and will be removed in a future release. The functionality "
+ "of non primary mappers is now better suited using the "
+ ":class:`.AliasedClass` construct, which can also be used "
+ "as the target of a :func:`_orm.relationship` in 1.3.",
+ ),
+ )
+ def __init__(
+ self,
+ class_,
+ local_table=None,
+ properties=None,
+ primary_key=None,
+ non_primary=False,
+ inherits=None,
+ inherit_condition=None,
+ inherit_foreign_keys=None,
+ always_refresh=False,
+ version_id_col=None,
+ version_id_generator=None,
+ polymorphic_on=None,
+ _polymorphic_map=None,
+ polymorphic_identity=None,
+ concrete=False,
+ with_polymorphic=None,
+ polymorphic_load=None,
+ allow_partial_pks=True,
+ batch=True,
+ column_prefix=None,
+ include_properties=None,
+ exclude_properties=None,
+ passive_updates=True,
+ passive_deletes=False,
+ confirm_deleted_rows=True,
+ eager_defaults=False,
+ legacy_is_orphan=False,
+ _compiled_cache_size=100,
+ ):
+ r"""Direct constructor for a new :class:`_orm.Mapper` object.
+
+ The :func:`_orm.mapper` function is normally invoked through the
+ use of the :class:`_orm.registry` object through either the
+ :ref:`Declarative <orm_declarative_mapping>` or
+ :ref:`Imperative <orm_imperative_mapping>` mapping styles.
+
+ .. versionchanged:: 1.4 The :func:`_orm.mapper` function should not
+ be called directly for classical mapping; for a classical mapping
+ configuration, use the :meth:`_orm.registry.map_imperatively`
+ method. The :func:`_orm.mapper` function may become private in a
+ future release.
+
+ Parameters documented below may be passed to either the
+ :meth:`_orm.registry.map_imperatively` method, or may be passed in the
+ ``__mapper_args__`` declarative class attribute described at
+ :ref:`orm_declarative_mapper_options`.
+
+ :param class\_: The class to be mapped. When using Declarative,
+ this argument is automatically passed as the declared class
+ itself.
+
+ :param local_table: The :class:`_schema.Table` or other selectable
+ to which the class is mapped. May be ``None`` if
+ this mapper inherits from another mapper using single-table
+ inheritance. When using Declarative, this argument is
+ automatically passed by the extension, based on what
+ is configured via the ``__table__`` argument or via the
+ :class:`_schema.Table`
+ produced as a result of the ``__tablename__``
+ and :class:`_schema.Column` arguments present.
+
+ :param always_refresh: If True, all query operations for this mapped
+ class will overwrite all data within object instances that already
+ exist within the session, erasing any in-memory changes with
+ whatever information was loaded from the database. Usage of this
+ flag is highly discouraged; as an alternative, see the method
+ :meth:`_query.Query.populate_existing`.
+
+ :param allow_partial_pks: Defaults to True. Indicates that a
+ composite primary key with some NULL values should be considered as
+ possibly existing within the database. This affects whether a
+ mapper will assign an incoming row to an existing identity, as well
+          as whether :meth:`.Session.merge` will check the database first for a
+ particular primary key value. A "partial primary key" can occur if
+ one has mapped to an OUTER JOIN, for example.
+
+ :param batch: Defaults to ``True``, indicating that save operations
+ of multiple entities can be batched together for efficiency.
+ Setting to False indicates
+ that an instance will be fully saved before saving the next
+ instance. This is used in the extremely rare case that a
+ :class:`.MapperEvents` listener requires being called
+ in between individual row persistence operations.
+
+ :param column_prefix: A string which will be prepended
+ to the mapped attribute name when :class:`_schema.Column`
+ objects are automatically assigned as attributes to the
+ mapped class. Does not affect :class:`.Column` objects that
+ are mapped explicitly in the :paramref:`.mapper.properties`
+ dictionary.
+
+ This parameter is typically useful with imperative mappings
+ that keep the :class:`.Table` object separate. Below, assuming
+ the ``user_table`` :class:`.Table` object has columns named
+ ``user_id``, ``user_name``, and ``password``::
+
+ class User(Base):
+ __table__ = user_table
+ __mapper_args__ = {'column_prefix':'_'}
+
+ The above mapping will assign the ``user_id``, ``user_name``, and
+ ``password`` columns to attributes named ``_user_id``,
+ ``_user_name``, and ``_password`` on the mapped ``User`` class.
+
+ The :paramref:`.mapper.column_prefix` parameter is uncommon in
+ modern use. For dealing with reflected tables, a more flexible
+ approach to automating a naming scheme is to intercept the
+ :class:`.Column` objects as they are reflected; see the section
+ :ref:`mapper_automated_reflection_schemes` for notes on this usage
+ pattern.
+
+ :param concrete: If True, indicates this mapper should use concrete
+ table inheritance with its parent mapper.
+
+ See the section :ref:`concrete_inheritance` for an example.
+
+ :param confirm_deleted_rows: defaults to True; when a DELETE occurs
+          of one or more rows based on specific primary keys, a warning is
+ emitted when the number of rows matched does not equal the number
+ of rows expected. This parameter may be set to False to handle the
+ case where database ON DELETE CASCADE rules may be deleting some of
+ those rows automatically. The warning may be changed to an
+ exception in a future release.
+
+ .. versionadded:: 0.9.4 - added
+ :paramref:`.mapper.confirm_deleted_rows` as well as conditional
+ matched row checking on delete.
+
+        :param eager_defaults: if True, the ORM will immediately fetch
+          server-generated default values after an INSERT or UPDATE,
+ rather than leaving them as expired to be fetched on next access.
+ This can be used for event schemes where the server-generated values
+ are needed immediately before the flush completes. By default,
+ this scheme will emit an individual ``SELECT`` statement per row
+          inserted or updated, which can add significant performance
+ overhead. However, if the
+ target database supports :term:`RETURNING`, the default values will
+ be returned inline with the INSERT or UPDATE statement, which can
+ greatly enhance performance for an application that needs frequent
+ access to just-generated server defaults.
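+
+          An illustrative declarative sketch (the ``Widget`` class and its
+          server-generated column are hypothetical)::
+
+            class Widget(Base):
+                __tablename__ = 'widget'
+
+                id = Column(Integer, primary_key=True)
+                timestamp = Column(DateTime, server_default=func.now())
+
+                __mapper_args__ = {"eager_defaults": True}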
+
+ .. seealso::
+
+ :ref:`orm_server_defaults`
+
+ .. versionchanged:: 0.9.0 The ``eager_defaults`` option can now
+ make use of :term:`RETURNING` for backends which support it.
+
+ :param exclude_properties: A list or set of string column names to
+ be excluded from mapping.
+
+ See :ref:`include_exclude_cols` for an example.
+
+ :param include_properties: An inclusive list or set of string column
+ names to map.
+
+ See :ref:`include_exclude_cols` for an example.
+
+ :param inherits: A mapped class or the corresponding
+ :class:`_orm.Mapper`
+ of one indicating a superclass to which this :class:`_orm.Mapper`
+ should *inherit* from. The mapped class here must be a subclass
+ of the other mapper's class. When using Declarative, this argument
+ is passed automatically as a result of the natural class
+ hierarchy of the declared classes.
+
+ .. seealso::
+
+ :ref:`inheritance_toplevel`
+
+ :param inherit_condition: For joined table inheritance, a SQL
+ expression which will
+ define how the two tables are joined; defaults to a natural join
+ between the two tables.
+
+ :param inherit_foreign_keys: When ``inherit_condition`` is used and
+ the columns present are missing a :class:`_schema.ForeignKey`
+ configuration, this parameter can be used to specify which columns
+ are "foreign". In most cases can be left as ``None``.
+
+ :param legacy_is_orphan: Boolean, defaults to ``False``.
+ When ``True``, specifies that "legacy" orphan consideration
+ is to be applied to objects mapped by this mapper, which means
+ that a pending (that is, not persistent) object is auto-expunged
+ from an owning :class:`.Session` only when it is de-associated
+ from *all* parents that specify a ``delete-orphan`` cascade towards
+ this mapper. The new default behavior is that the object is
+ auto-expunged when it is de-associated with *any* of its parents
+ that specify ``delete-orphan`` cascade. This behavior is more
+ consistent with that of a persistent object, and allows behavior to
+ be consistent in more scenarios independently of whether or not an
+          orphan object has been flushed yet.
+
+ See the change note and example at :ref:`legacy_is_orphan_addition`
+ for more detail on this change.
+
+ :param non_primary: Specify that this :class:`_orm.Mapper`
+ is in addition
+ to the "primary" mapper, that is, the one used for persistence.
+ The :class:`_orm.Mapper` created here may be used for ad-hoc
+ mapping of the class to an alternate selectable, for loading
+ only.
+
+ .. seealso::
+
+ :ref:`relationship_aliased_class` - the new pattern that removes
+ the need for the :paramref:`_orm.Mapper.non_primary` flag.
+
+ :param passive_deletes: Indicates DELETE behavior of foreign key
+ columns when a joined-table inheritance entity is being deleted.
+ Defaults to ``False`` for a base mapper; for an inheriting mapper,
+ defaults to ``False`` unless the value is set to ``True``
+ on the superclass mapper.
+
+ When ``True``, it is assumed that ON DELETE CASCADE is configured
+ on the foreign key relationships that link this mapper's table
+ to its superclass table, so that when the unit of work attempts
+ to delete the entity, it need only emit a DELETE statement for the
+ superclass table, and not this table.
+
+ When ``False``, a DELETE statement is emitted for this mapper's
+ table individually. If the primary key attributes local to this
+ table are unloaded, then a SELECT must be emitted in order to
+ validate these attributes; note that the primary key columns
+ of a joined-table subclass are not part of the "primary key" of
+ the object as a whole.
+
+ Note that a value of ``True`` is **always** forced onto the
+ subclass mappers; that is, it's not possible for a superclass
+ to specify passive_deletes without this taking effect for
+ all subclass mappers.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`passive_deletes` - description of similar feature as
+ used with :func:`_orm.relationship`
+
+ :paramref:`.mapper.passive_updates` - supporting ON UPDATE
+ CASCADE for joined-table inheritance mappers
+
+ :param passive_updates: Indicates UPDATE behavior of foreign key
+ columns when a primary key column changes on a joined-table
+ inheritance mapping. Defaults to ``True``.
+
+ When True, it is assumed that ON UPDATE CASCADE is configured on
+ the foreign key in the database, and that the database will handle
+ propagation of an UPDATE from a source column to dependent columns
+ on joined-table rows.
+
+ When False, it is assumed that the database does not enforce
+ referential integrity and will not be issuing its own CASCADE
+ operation for an update. The unit of work process will
+ emit an UPDATE statement for the dependent columns during a
+ primary key change.
+
+ .. seealso::
+
+ :ref:`passive_updates` - description of a similar feature as
+ used with :func:`_orm.relationship`
+
+ :paramref:`.mapper.passive_deletes` - supporting ON DELETE
+ CASCADE for joined-table inheritance mappers
+
+ :param polymorphic_load: Specifies "polymorphic loading" behavior
+ for a subclass in an inheritance hierarchy (joined and single
+ table inheritance only). Valid values are:
+
+          * ``"inline"`` - specifies this class should be part of the
+ "with_polymorphic" mappers, e.g. its columns will be included
+ in a SELECT query against the base.
+
+          * ``"selectin"`` - specifies that when instances of this class
+ are loaded, an additional SELECT will be emitted to retrieve
+ the columns specific to this subclass. The SELECT uses
+ IN to fetch multiple subclasses at once.
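+
+          An illustrative sketch of requesting "selectin" loading for a
+          subclass in a declarative hierarchy (``Engineer`` and
+          ``Employee`` are assumed example classes)::
+
+            class Engineer(Employee):
+                __tablename__ = 'engineer'
+
+                id = Column(Integer, ForeignKey('employee.id'), primary_key=True)
+
+                __mapper_args__ = {
+                    "polymorphic_identity": "engineer",
+                    "polymorphic_load": "selectin",
+                }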
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`with_polymorphic_mapper_config`
+
+ :ref:`polymorphic_selectin`
+
+ :param polymorphic_on: Specifies the column, attribute, or
+ SQL expression used to determine the target class for an
+ incoming row, when inheriting classes are present.
+
+ This value is commonly a :class:`_schema.Column` object that's
+ present in the mapped :class:`_schema.Table`::
+
+ class Employee(Base):
+ __tablename__ = 'employee'
+
+ id = Column(Integer, primary_key=True)
+ discriminator = Column(String(50))
+
+ __mapper_args__ = {
+ "polymorphic_on":discriminator,
+ "polymorphic_identity":"employee"
+ }
+
+ It may also be specified
+ as a SQL expression, as in this example where we
+ use the :func:`.case` construct to provide a conditional
+ approach::
+
+ class Employee(Base):
+ __tablename__ = 'employee'
+
+ id = Column(Integer, primary_key=True)
+ discriminator = Column(String(50))
+
+ __mapper_args__ = {
+ "polymorphic_on":case([
+ (discriminator == "EN", "engineer"),
+ (discriminator == "MA", "manager"),
+ ], else_="employee"),
+ "polymorphic_identity":"employee"
+ }
+
+ It may also refer to any attribute
+ configured with :func:`.column_property`, or to the
+ string name of one::
+
+ class Employee(Base):
+ __tablename__ = 'employee'
+
+ id = Column(Integer, primary_key=True)
+ discriminator = Column(String(50))
+ employee_type = column_property(
+ case([
+ (discriminator == "EN", "engineer"),
+ (discriminator == "MA", "manager"),
+ ], else_="employee")
+ )
+
+ __mapper_args__ = {
+ "polymorphic_on":employee_type,
+ "polymorphic_identity":"employee"
+ }
+
+ When setting ``polymorphic_on`` to reference an
+ attribute or expression that's not present in the
+ locally mapped :class:`_schema.Table`, yet the value
+ of the discriminator should be persisted to the database,
+ the value of the
+ discriminator is not automatically set on new
+ instances; this must be handled by the user,
+ either through manual means or via event listeners.
+ A typical approach to establishing such a listener
+ looks like::
+
+ from sqlalchemy import event
+ from sqlalchemy.orm import object_mapper
+
+ @event.listens_for(Employee, "init", propagate=True)
+ def set_identity(instance, *arg, **kw):
+ mapper = object_mapper(instance)
+ instance.discriminator = mapper.polymorphic_identity
+
+ Where above, we assign the value of ``polymorphic_identity``
+ for the mapped class to the ``discriminator`` attribute,
+ thus persisting the value to the ``discriminator`` column
+ in the database.
+
+ .. warning::
+
+ Currently, **only one discriminator column may be set**, typically
+ on the base-most class in the hierarchy. "Cascading" polymorphic
+ columns are not yet supported.
+
+ .. seealso::
+
+ :ref:`inheritance_toplevel`
+
+ :param polymorphic_identity: Specifies the value which
+ identifies this particular class as returned by the
+ column expression referred to by the ``polymorphic_on``
+ setting. As rows are received, the value corresponding
+ to the ``polymorphic_on`` column expression is compared
+ to this value, indicating which subclass should
+ be used for the newly reconstructed object.
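+
+          An illustrative sketch of a single-table subclass of the
+          ``Employee`` mapping shown previously::
+
+            class Manager(Employee):
+                __mapper_args__ = {
+                    "polymorphic_identity": "manager"
+                }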
+
+ :param properties: A dictionary mapping the string names of object
+ attributes to :class:`.MapperProperty` instances, which define the
+ persistence behavior of that attribute. Note that
+ :class:`_schema.Column`
+ objects present in
+ the mapped :class:`_schema.Table` are automatically placed into
+ ``ColumnProperty`` instances upon mapping, unless overridden.
+ When using Declarative, this argument is passed automatically,
+ based on all those :class:`.MapperProperty` instances declared
+ in the declared class body.
+
+ .. seealso::
+
+ :ref:`orm_mapping_properties` - in the
+ :ref:`orm_mapping_classes_toplevel`
+
+ :param primary_key: A list of :class:`_schema.Column`
+ objects which define
+ the primary key to be used against this mapper's selectable unit.
+ This is normally simply the primary key of the ``local_table``, but
+ can be overridden here.
+
+ .. seealso::
+
+ :ref:`mapper_primary_key` - background and example use
+
+ :param version_id_col: A :class:`_schema.Column`
+ that will be used to keep a running version id of rows
+ in the table. This is used to detect concurrent updates or
+          the presence of stale data in a flush. The methodology is to
+          detect when an UPDATE statement does not match the last known
+          version id; in that case, a
+ :class:`~sqlalchemy.orm.exc.StaleDataError` exception is
+ thrown.
+ By default, the column must be of :class:`.Integer` type,
+ unless ``version_id_generator`` specifies an alternative version
+ generator.
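+
+          An illustrative sketch of the default integer-counter
+          configuration (the ``Widget`` class is hypothetical)::
+
+            class Widget(Base):
+                __tablename__ = 'widget'
+
+                id = Column(Integer, primary_key=True)
+                version_id = Column(Integer, nullable=False)
+
+                __mapper_args__ = {"version_id_col": version_id}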
+
+ .. seealso::
+
+ :ref:`mapper_version_counter` - discussion of version counting
+ and rationale.
+
+ :param version_id_generator: Define how new version ids should
+ be generated. Defaults to ``None``, which indicates that
+ a simple integer counting scheme be employed. To provide a custom
+ versioning scheme, provide a callable function of the form::
+
+ def generate_version(version):
+ return next_version
+
+ Alternatively, server-side versioning functions such as triggers,
+ or programmatic versioning schemes outside of the version id
+ generator may be used, by specifying the value ``False``.
+ Please see :ref:`server_side_version_counter` for a discussion
+ of important points when using this option.
+
+ .. versionadded:: 0.9.0 ``version_id_generator`` supports
+ server-side version number generation.
+
+ .. seealso::
+
+ :ref:`custom_version_counter`
+
+ :ref:`server_side_version_counter`
+
+
+ :param with_polymorphic: A tuple in the form ``(<classes>,
+ <selectable>)`` indicating the default style of "polymorphic"
+ loading, that is, which tables are queried at once. <classes> is
+ any single or list of mappers and/or classes indicating the
+ inherited classes that should be loaded at once. The special value
+ ``'*'`` may be used to indicate all descending classes should be
+ loaded immediately. The second tuple argument <selectable>
+ indicates a selectable that will be used to query for multiple
+ classes.
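+
+          An illustrative sketch of the common shorthand that loads all
+          subclasses against the default selectable::
+
+            __mapper_args__ = {"with_polymorphic": "*"}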
+
+ .. seealso::
+
+ :ref:`with_polymorphic` - discussion of polymorphic querying
+ techniques.
+
+ """
+ self.class_ = util.assert_arg_type(class_, type, "class_")
+ self._sort_key = "%s.%s" % (
+ self.class_.__module__,
+ self.class_.__name__,
+ )
+
+ self.class_manager = None
+
+ self._primary_key_argument = util.to_list(primary_key)
+ self.non_primary = non_primary
+
+ self.always_refresh = always_refresh
+
+ if isinstance(version_id_col, MapperProperty):
+ self.version_id_prop = version_id_col
+ self.version_id_col = None
+ else:
+ self.version_id_col = version_id_col
+ if version_id_generator is False:
+ self.version_id_generator = False
+ elif version_id_generator is None:
+ self.version_id_generator = lambda x: (x or 0) + 1
+ else:
+ self.version_id_generator = version_id_generator
+
+ self.concrete = concrete
+ self.single = False
+ self.inherits = inherits
+ if local_table is not None:
+ self.local_table = coercions.expect(
+ roles.StrictFromClauseRole, local_table
+ )
+ else:
+ self.local_table = None
+
+ self.inherit_condition = inherit_condition
+ self.inherit_foreign_keys = inherit_foreign_keys
+ self._init_properties = properties or {}
+ self._delete_orphans = []
+ self.batch = batch
+ self.eager_defaults = eager_defaults
+ self.column_prefix = column_prefix
+ self.polymorphic_on = (
+ coercions.expect(
+ roles.ColumnArgumentOrKeyRole,
+ polymorphic_on,
+ argname="polymorphic_on",
+ )
+ if polymorphic_on is not None
+ else None
+ )
+ self._dependency_processors = []
+ self.validators = util.EMPTY_DICT
+ self.passive_updates = passive_updates
+ self.passive_deletes = passive_deletes
+ self.legacy_is_orphan = legacy_is_orphan
+ self._clause_adapter = None
+ self._requires_row_aliasing = False
+ self._inherits_equated_pairs = None
+ self._memoized_values = {}
+ self._compiled_cache_size = _compiled_cache_size
+ self._reconstructor = None
+ self.allow_partial_pks = allow_partial_pks
+
+ if self.inherits and not self.concrete:
+ self.confirm_deleted_rows = False
+ else:
+ self.confirm_deleted_rows = confirm_deleted_rows
+
+ self._set_with_polymorphic(with_polymorphic)
+ self.polymorphic_load = polymorphic_load
+
+ # our 'polymorphic identity', a string name that when located in a
+ # result set row indicates this Mapper should be used to construct
+ # the object instance for that row.
+ self.polymorphic_identity = polymorphic_identity
+
+ # a dictionary of 'polymorphic identity' names, associating those
+ # names with Mappers that will be used to construct object instances
+ # upon a select operation.
+ if _polymorphic_map is None:
+ self.polymorphic_map = {}
+ else:
+ self.polymorphic_map = _polymorphic_map
+
+ if include_properties is not None:
+ self.include_properties = util.to_set(include_properties)
+ else:
+ self.include_properties = None
+ if exclude_properties:
+ self.exclude_properties = util.to_set(exclude_properties)
+ else:
+ self.exclude_properties = None
+
+ # prevent this mapper from being constructed
+ # while a configure_mappers() is occurring (and defer a
+ # configure_mappers() until construction succeeds)
+ with _CONFIGURE_MUTEX:
+ self.dispatch._events._new_mapper_instance(class_, self)
+ self._configure_inheritance()
+ self._configure_class_instrumentation()
+ self._configure_properties()
+ self._configure_polymorphic_setter()
+ self._configure_pks()
+ self.registry._flag_new_mapper(self)
+ self._log("constructed")
+ self._expire_memoizations()
+
+ # major attributes initialized at the classlevel so that
+ # they can be Sphinx-documented.
+
+ is_mapper = True
+ """Part of the inspection API."""
+
+ represents_outer_join = False
+
+ @property
+ def mapper(self):
+ """Part of the inspection API.
+
+ Returns self.
+
+ """
+ return self
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ return (self,)
+
+ @property
+ def entity(self):
+ r"""Part of the inspection API.
+
+ Returns self.class\_.
+
+ """
+ return self.class_
+
+ local_table = None
+ """The :class:`_expression.Selectable` which this :class:`_orm.Mapper`
+ manages.
+
+ Typically is an instance of :class:`_schema.Table` or
+ :class:`_expression.Alias`.
+ May also be ``None``.
+
+ The "local" table is the
+ selectable that the :class:`_orm.Mapper` is directly responsible for
+ managing from an attribute access and flush perspective. For
+ non-inheriting mappers, the local table is the same as the
+ "mapped" table. For joined-table inheritance mappers, local_table
+ will be the particular sub-table of the overall "join" which
+ this :class:`_orm.Mapper` represents. If this mapper is a
+ single-table inheriting mapper, local_table will be ``None``.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.persist_selectable`.
+
+ """
+
+ persist_selectable = None
+ """The :class:`_expression.Selectable` to which this :class:`_orm.Mapper`
+ is mapped.
+
+ Typically an instance of :class:`_schema.Table`,
+ :class:`_expression.Join`, or :class:`_expression.Alias`.
+
+ The :attr:`_orm.Mapper.persist_selectable` is separate from
+ :attr:`_orm.Mapper.selectable` in that the former represents columns
+ that are mapped on this class or its superclasses, whereas the
+ latter may be a "polymorphic" selectable that contains additional columns
+ which are in fact mapped on subclasses only.
+
+ "persist selectable" is the "thing the mapper writes to" and
+ "selectable" is the "thing the mapper selects from".
+
+ :attr:`_orm.Mapper.persist_selectable` is also separate from
+ :attr:`_orm.Mapper.local_table`, which represents the set of columns that
+ are locally mapped on this class directly.
+
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.selectable`.
+
+ :attr:`_orm.Mapper.local_table`.
+
+ """
+
+ inherits = None
+ """References the :class:`_orm.Mapper` which this :class:`_orm.Mapper`
+ inherits from, if any.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ configured = False
+ """Represent ``True`` if this :class:`_orm.Mapper` has been configured.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ .. seealso::
+
+ :func:`.configure_mappers`.
+
+ """
+
+ concrete = None
+ """Represent ``True`` if this :class:`_orm.Mapper` is a concrete
+ inheritance mapper.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ tables = None
+ """An iterable containing the collection of :class:`_schema.Table` objects
+ which this :class:`_orm.Mapper` is aware of.
+
+ If the mapper is mapped to a :class:`_expression.Join`, or an
+ :class:`_expression.Alias`
+ representing a :class:`_expression.Select`, the individual
+ :class:`_schema.Table`
+ objects that comprise the full construct will be represented here.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ primary_key = None
+ """An iterable containing the collection of :class:`_schema.Column`
+ objects
+ which comprise the 'primary key' of the mapped table, from the
+ perspective of this :class:`_orm.Mapper`.
+
+ This list is against the selectable in
+ :attr:`_orm.Mapper.persist_selectable`.
+ In the case of inheriting mappers, some columns may be managed by a
+ superclass mapper. For example, in the case of a
+ :class:`_expression.Join`, the
+ primary key is determined by all of the primary key columns across all
+ tables referenced by the :class:`_expression.Join`.
+
+ The list is also not necessarily the same as the primary key column
+ collection associated with the underlying tables; the :class:`_orm.Mapper`
+ features a ``primary_key`` argument that can override what the
+ :class:`_orm.Mapper` considers as primary key columns.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ class_ = None
+ """The Python class which this :class:`_orm.Mapper` maps.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ class_manager = None
+ """The :class:`.ClassManager` which maintains event listeners
+ and class-bound descriptors for this :class:`_orm.Mapper`.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ single = None
+ """Represent ``True`` if this :class:`_orm.Mapper` is a single table
+ inheritance mapper.
+
+ :attr:`_orm.Mapper.local_table` will be ``None`` if this flag is set.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ non_primary = None
+ """Represent ``True`` if this :class:`_orm.Mapper` is a "non-primary"
+ mapper, e.g. a mapper that is used only to select rows but not for
+ persistence management.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ polymorphic_on = None
+ """The :class:`_schema.Column` or SQL expression specified as the
+ ``polymorphic_on`` argument
+ for this :class:`_orm.Mapper`, within an inheritance scenario.
+
+ This attribute is normally a :class:`_schema.Column` instance but
+ may also be an expression, such as one derived from
+ :func:`.cast`.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ polymorphic_map = None
+ """A mapping of "polymorphic identity" identifiers mapped to
+ :class:`_orm.Mapper` instances, within an inheritance scenario.
+
+ The identifiers can be of any type which is comparable to the
+ type of column represented by :attr:`_orm.Mapper.polymorphic_on`.
+
+ An inheritance chain of mappers will all reference the same
+ polymorphic map object. The object is used to correlate incoming
+ result rows to target mappers.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ polymorphic_identity = None
+ """Represent an identifier which is matched against the
+ :attr:`_orm.Mapper.polymorphic_on` column during result row loading.
+
+ Used only with inheritance, this object can be of any type which is
+ comparable to the type of column represented by
+ :attr:`_orm.Mapper.polymorphic_on`.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ base_mapper = None
+ """The base-most :class:`_orm.Mapper` in an inheritance chain.
+
+ In a non-inheriting scenario, this attribute will always be this
+ :class:`_orm.Mapper`. In an inheritance scenario, it references
+ the :class:`_orm.Mapper` which is parent to all other :class:`_orm.Mapper`
+ objects in the inheritance chain.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ columns = None
+ """A collection of :class:`_schema.Column` or other scalar expression
+ objects maintained by this :class:`_orm.Mapper`.
+
+ The collection behaves the same as that of the ``c`` attribute on
+ any :class:`_schema.Table` object,
+ except that only those columns included in
+ this mapping are present, and are keyed based on the attribute name
+ defined in the mapping, not necessarily the ``key`` attribute of the
+ :class:`_schema.Column` itself. Additionally, scalar expressions mapped
+ by :func:`.column_property` are also present here.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ validators = None
+ """An immutable dictionary of attributes which have been decorated
+ using the :func:`_orm.validates` decorator.
+
+ The dictionary contains string attribute names as keys
+ mapped to the actual validation method.
+
+ """
+
+ c = None
+ """A synonym for :attr:`_orm.Mapper.columns`."""
+
+ @property
+ @util.deprecated("1.3", "Use .persist_selectable")
+ def mapped_table(self):
+ return self.persist_selectable
+
+ @util.memoized_property
+ def _path_registry(self):
+ return PathRegistry.per_mapper(self)
+
+ def _configure_inheritance(self):
+ """Configure settings related to inheriting and/or inherited mappers
+ being present."""
+
+ # a set of all mappers which inherit from this one.
+ self._inheriting_mappers = util.WeakSequence()
+
+ if self.inherits:
+ if isinstance(self.inherits, type):
+ self.inherits = class_mapper(self.inherits, configure=False)
+ if not issubclass(self.class_, self.inherits.class_):
+ raise sa_exc.ArgumentError(
+ "Class '%s' does not inherit from '%s'"
+ % (self.class_.__name__, self.inherits.class_.__name__)
+ )
+
+ self.dispatch._update(self.inherits.dispatch)
+
+ if self.non_primary != self.inherits.non_primary:
+                np = "primary" if not self.non_primary else "non-primary"
+ raise sa_exc.ArgumentError(
+ "Inheritance of %s mapper for class '%s' is "
+ "only allowed from a %s mapper"
+ % (np, self.class_.__name__, np)
+ )
+ # inherit_condition is optional.
+ if self.local_table is None:
+ self.local_table = self.inherits.local_table
+ self.persist_selectable = self.inherits.persist_selectable
+ self.single = True
+ elif self.local_table is not self.inherits.local_table:
+ if self.concrete:
+ self.persist_selectable = self.local_table
+ for mapper in self.iterate_to_root():
+ if mapper.polymorphic_on is not None:
+ mapper._requires_row_aliasing = True
+ else:
+ if self.inherit_condition is None:
+ # figure out inherit condition from our table to the
+ # immediate table of the inherited mapper, not its
+ # full table which could pull in other stuff we don't
+ # want (allows test/inheritance.InheritTest4 to pass)
+ try:
+ self.inherit_condition = sql_util.join_condition(
+ self.inherits.local_table, self.local_table
+ )
+ except sa_exc.NoForeignKeysError as nfe:
+ assert self.inherits.local_table is not None
+ assert self.local_table is not None
+ util.raise_(
+ sa_exc.NoForeignKeysError(
+ "Can't determine the inherit condition "
+ "between inherited table '%s' and "
+ "inheriting "
+ "table '%s'; tables have no "
+ "foreign key relationships established. "
+ "Please ensure the inheriting table has "
+ "a foreign key relationship to the "
+ "inherited "
+ "table, or provide an "
+ "'on clause' using "
+ "the 'inherit_condition' mapper argument."
+ % (
+ self.inherits.local_table.description,
+ self.local_table.description,
+ )
+ ),
+ replace_context=nfe,
+ )
+ except sa_exc.AmbiguousForeignKeysError as afe:
+ assert self.inherits.local_table is not None
+ assert self.local_table is not None
+ util.raise_(
+ sa_exc.AmbiguousForeignKeysError(
+ "Can't determine the inherit condition "
+ "between inherited table '%s' and "
+ "inheriting "
+ "table '%s'; tables have more than one "
+ "foreign key relationship established. "
+ "Please specify the 'on clause' using "
+ "the 'inherit_condition' mapper argument."
+ % (
+ self.inherits.local_table.description,
+ self.local_table.description,
+ )
+ ),
+ replace_context=afe,
+ )
+ self.persist_selectable = sql.join(
+ self.inherits.persist_selectable,
+ self.local_table,
+ self.inherit_condition,
+ )
+
+ fks = util.to_set(self.inherit_foreign_keys)
+ self._inherits_equated_pairs = sql_util.criterion_as_pairs(
+ self.persist_selectable.onclause,
+ consider_as_foreign_keys=fks,
+ )
+ else:
+ self.persist_selectable = self.local_table
+
+ if self.polymorphic_identity is not None and not self.concrete:
+ self._identity_class = self.inherits._identity_class
+ else:
+ self._identity_class = self.class_
+
+ if self.version_id_col is None:
+ self.version_id_col = self.inherits.version_id_col
+ self.version_id_generator = self.inherits.version_id_generator
+ elif (
+ self.inherits.version_id_col is not None
+ and self.version_id_col is not self.inherits.version_id_col
+ ):
+ util.warn(
+ "Inheriting version_id_col '%s' does not match inherited "
+ "version_id_col '%s' and will not automatically populate "
+ "the inherited versioning column. "
+ "version_id_col should only be specified on "
+ "the base-most mapper that includes versioning."
+ % (
+ self.version_id_col.description,
+ self.inherits.version_id_col.description,
+ )
+ )
+
+ self.polymorphic_map = self.inherits.polymorphic_map
+ self.batch = self.inherits.batch
+ self.inherits._inheriting_mappers.append(self)
+ self.base_mapper = self.inherits.base_mapper
+ self.passive_updates = self.inherits.passive_updates
+ self.passive_deletes = (
+ self.inherits.passive_deletes or self.passive_deletes
+ )
+ self._all_tables = self.inherits._all_tables
+
+ if self.polymorphic_identity is not None:
+ if self.polymorphic_identity in self.polymorphic_map:
+ util.warn(
+ "Reassigning polymorphic association for identity %r "
+ "from %r to %r: Check for duplicate use of %r as "
+ "value for polymorphic_identity."
+ % (
+ self.polymorphic_identity,
+ self.polymorphic_map[self.polymorphic_identity],
+ self,
+ self.polymorphic_identity,
+ )
+ )
+ self.polymorphic_map[self.polymorphic_identity] = self
+
+ if self.polymorphic_load and self.concrete:
+ raise sa_exc.ArgumentError(
+ "polymorphic_load is not currently supported "
+ "with concrete table inheritance"
+ )
+ if self.polymorphic_load == "inline":
+ self.inherits._add_with_polymorphic_subclass(self)
+ elif self.polymorphic_load == "selectin":
+ pass
+ elif self.polymorphic_load is not None:
+ raise sa_exc.ArgumentError(
+ "unknown argument for polymorphic_load: %r"
+ % self.polymorphic_load
+ )
+
+ else:
+ self._all_tables = set()
+ self.base_mapper = self
+ self.persist_selectable = self.local_table
+ if self.polymorphic_identity is not None:
+ self.polymorphic_map[self.polymorphic_identity] = self
+ self._identity_class = self.class_
+
+ if self.persist_selectable is None:
+ raise sa_exc.ArgumentError(
+ "Mapper '%s' does not have a persist_selectable specified."
+ % self
+ )
+
+ def _set_with_polymorphic(self, with_polymorphic):
+ if with_polymorphic == "*":
+ self.with_polymorphic = ("*", None)
+ elif isinstance(with_polymorphic, (tuple, list)):
+ if isinstance(
+ with_polymorphic[0], util.string_types + (tuple, list)
+ ):
+ self.with_polymorphic = with_polymorphic
+ else:
+ self.with_polymorphic = (with_polymorphic, None)
+ elif with_polymorphic is not None:
+ raise sa_exc.ArgumentError("Invalid setting for with_polymorphic")
+ else:
+ self.with_polymorphic = None
+
+ if self.with_polymorphic and self.with_polymorphic[1] is not None:
+ self.with_polymorphic = (
+ self.with_polymorphic[0],
+ coercions.expect(
+ roles.StrictFromClauseRole,
+ self.with_polymorphic[1],
+ allow_select=True,
+ ),
+ )
+
+ if self.configured:
+ self._expire_memoizations()
+
+ def _add_with_polymorphic_subclass(self, mapper):
+ subcl = mapper.class_
+ if self.with_polymorphic is None:
+ self._set_with_polymorphic((subcl,))
+ elif self.with_polymorphic[0] != "*":
+ self._set_with_polymorphic(
+ (self.with_polymorphic[0] + (subcl,), self.with_polymorphic[1])
+ )
+
+ def _set_concrete_base(self, mapper):
+ """Set the given :class:`_orm.Mapper` as the 'inherits' for this
+ :class:`_orm.Mapper`, assuming this :class:`_orm.Mapper` is concrete
+ and does not already have an inherits."""
+
+ assert self.concrete
+ assert not self.inherits
+ assert isinstance(mapper, Mapper)
+ self.inherits = mapper
+ self.inherits.polymorphic_map.update(self.polymorphic_map)
+ self.polymorphic_map = self.inherits.polymorphic_map
+ for mapper in self.iterate_to_root():
+ if mapper.polymorphic_on is not None:
+ mapper._requires_row_aliasing = True
+ self.batch = self.inherits.batch
+ for mp in self.self_and_descendants:
+ mp.base_mapper = self.inherits.base_mapper
+ self.inherits._inheriting_mappers.append(self)
+ self.passive_updates = self.inherits.passive_updates
+ self._all_tables = self.inherits._all_tables
+
+ for key, prop in mapper._props.items():
+ if key not in self._props and not self._should_exclude(
+ key, key, local=False, column=None
+ ):
+ self._adapt_inherited_property(key, prop, False)
+
+ def _set_polymorphic_on(self, polymorphic_on):
+ self.polymorphic_on = polymorphic_on
+ self._configure_polymorphic_setter(True)
+
+ def _configure_class_instrumentation(self):
+ """If this mapper is to be a primary mapper (i.e. the
+ non_primary flag is not set), associate this Mapper with the
+ given class and entity name.
+
+ Subsequent calls to ``class_mapper()`` for the ``class_`` / ``entity``
+ name combination will return this mapper. Also decorate the
+ `__init__` method on the mapped class to include optional
+ auto-session attachment logic.
+
+ """
+
+ # we expect that declarative has applied the class manager
+ # already and set up a registry. if this is None,
+ # we will emit a deprecation warning below when we also see that
+ # it has no registry.
+ manager = attributes.manager_of_class(self.class_)
+
+ if self.non_primary:
+ if not manager or not manager.is_mapped:
+ raise sa_exc.InvalidRequestError(
+ "Class %s has no primary mapper configured. Configure "
+ "a primary mapper first before setting up a non primary "
+ "Mapper." % self.class_
+ )
+ self.class_manager = manager
+ self.registry = manager.registry
+ self._identity_class = manager.mapper._identity_class
+ manager.registry._add_non_primary_mapper(self)
+ return
+
+ if manager is not None:
+ assert manager.class_ is self.class_
+ if manager.is_mapped:
+ # changed in #7579:
+ # this message is defined in two places as of this change,
+ # also in decl_api -> _add_manager(). in 2.0, this codepath
+ # is removed as any calls to mapper() / Mapper without
+ # the registry setting up first will be rejected.
+ raise sa_exc.ArgumentError(
+ "Class '%s' already has a primary mapper defined. "
+ % self.class_
+ )
+ # else:
+ # a ClassManager may already exist as
+ # ClassManager.instrument_attribute() creates
+ # new managers for each subclass if they don't yet exist.
+
+ self.dispatch.instrument_class(self, self.class_)
+
+ # this invokes the class_instrument event and sets up
+ # the __init__ method. documented behavior is that this must
+ # occur after the instrument_class event above.
+ # yes two events with the same two words reversed and different APIs.
+ # :(
+
+ manager = instrumentation.register_class(
+ self.class_,
+ mapper=self,
+ expired_attribute_loader=util.partial(
+ loading.load_scalar_attributes, self
+ ),
+ # finalize flag means instrument the __init__ method
+ # and call the class_instrument event
+ finalize=True,
+ )
+
+ if not manager.registry:
+ util.warn_deprecated_20(
+ "Calling the mapper() function directly outside of a "
+ "declarative registry is deprecated."
+ " Please use the sqlalchemy.orm.registry.map_imperatively() "
+ "function for a classical mapping."
+ )
+ assert _legacy_registry is not None
+ _legacy_registry._add_manager(manager)
+
+ self.class_manager = manager
+ self.registry = manager.registry
+
+ # The remaining members can be added by any mapper,
+ # whether the entity name is None or not.
+ if manager.mapper is None:
+ return
+
+ event.listen(manager, "init", _event_on_init, raw=True)
+
+ for key, method in util.iterate_attributes(self.class_):
+ if key == "__init__" and hasattr(method, "_sa_original_init"):
+ method = method._sa_original_init
+ if hasattr(method, "__func__"):
+ method = method.__func__
+ if callable(method):
+ if hasattr(method, "__sa_reconstructor__"):
+ self._reconstructor = method
+ event.listen(manager, "load", _event_on_load, raw=True)
+ elif hasattr(method, "__sa_validators__"):
+ validation_opts = method.__sa_validation_opts__
+ for name in method.__sa_validators__:
+ if name in self.validators:
+ raise sa_exc.InvalidRequestError(
+ "A validation function for mapped "
+ "attribute %r on mapper %s already exists."
+ % (name, self)
+ )
+ self.validators = self.validators.union(
+ {name: (method, validation_opts)}
+ )
+
+ def _set_dispose_flags(self):
+ self.configured = True
+ self._ready_for_configure = True
+ self._dispose_called = True
+
+ self.__dict__.pop("_configure_failed", None)
+
+ def _configure_pks(self):
+ self.tables = sql_util.find_tables(self.persist_selectable)
+
+ self._pks_by_table = {}
+ self._cols_by_table = {}
+
+ all_cols = util.column_set(
+ chain(*[col.proxy_set for col in self._columntoproperty])
+ )
+
+ pk_cols = util.column_set(c for c in all_cols if c.primary_key)
+
+ # identify primary key columns which are also mapped by this mapper.
+ tables = set(self.tables + [self.persist_selectable])
+ self._all_tables.update(tables)
+ for t in tables:
+ if t.primary_key and pk_cols.issuperset(t.primary_key):
+ # ordering is important since it determines the ordering of
+ # mapper.primary_key (and therefore query.get())
+ self._pks_by_table[t] = util.ordered_column_set(
+ t.primary_key
+ ).intersection(pk_cols)
+ self._cols_by_table[t] = util.ordered_column_set(t.c).intersection(
+ all_cols
+ )
+
+ # if explicit PK argument sent, add those columns to the
+ # primary key mappings
+ if self._primary_key_argument:
+ for k in self._primary_key_argument:
+ if k.table not in self._pks_by_table:
+ self._pks_by_table[k.table] = util.OrderedSet()
+ self._pks_by_table[k.table].add(k)
+
+ # otherwise, see that we got a full PK for the mapped table
+ elif (
+ self.persist_selectable not in self._pks_by_table
+ or len(self._pks_by_table[self.persist_selectable]) == 0
+ ):
+ raise sa_exc.ArgumentError(
+ "Mapper %s could not assemble any primary "
+ "key columns for mapped table '%s'"
+ % (self, self.persist_selectable.description)
+ )
+ elif self.local_table not in self._pks_by_table and isinstance(
+ self.local_table, schema.Table
+ ):
+ util.warn(
+ "Could not assemble any primary "
+ "keys for locally mapped table '%s' - "
+ "no rows will be persisted in this Table."
+ % self.local_table.description
+ )
+
+ if (
+ self.inherits
+ and not self.concrete
+ and not self._primary_key_argument
+ ):
+ # if inheriting, the "primary key" for this mapper is
+ # that of the inheriting (unless concrete or explicit)
+ self.primary_key = self.inherits.primary_key
+ else:
+ # determine primary key from argument or persist_selectable pks
+ if self._primary_key_argument:
+ primary_key = [
+ self.persist_selectable.corresponding_column(c)
+ for c in self._primary_key_argument
+ ]
+ else:
+ # if heuristically determined PKs, reduce to the minimal set
+ # of columns by eliminating FK->PK pairs for a multi-table
+ # expression. May over-reduce for some kinds of UNIONs
+ # / CTEs; use explicit PK argument for these special cases
+ primary_key = sql_util.reduce_columns(
+ self._pks_by_table[self.persist_selectable],
+ ignore_nonexistent_tables=True,
+ )
+
+ if len(primary_key) == 0:
+ raise sa_exc.ArgumentError(
+ "Mapper %s could not assemble any primary "
+ "key columns for mapped table '%s'"
+ % (self, self.persist_selectable.description)
+ )
+
+ self.primary_key = tuple(primary_key)
+ self._log("Identified primary key columns: %s", primary_key)
+
+ # determine cols that aren't expressed within our tables; mark these
+ # as "read only" properties which are refreshed upon INSERT/UPDATE
+ self._readonly_props = set(
+ self._columntoproperty[col]
+ for col in self._columntoproperty
+ if self._columntoproperty[col] not in self._identity_key_props
+ and (
+ not hasattr(col, "table")
+ or col.table not in self._cols_by_table
+ )
+ )
+
+ def _configure_properties(self):
+
+ # TODO: consider using DedupeColumnCollection
+ self.columns = self.c = sql_base.ColumnCollection()
+
+ # object attribute names mapped to MapperProperty objects
+ self._props = util.OrderedDict()
+
+ # table columns mapped to MapperProperty
+ self._columntoproperty = _ColumnMapping(self)
+
+ # load custom properties
+ if self._init_properties:
+ for key, prop in self._init_properties.items():
+ self._configure_property(key, prop, False)
+
+ # pull properties from the inherited mapper if any.
+ if self.inherits:
+ for key, prop in self.inherits._props.items():
+ if key not in self._props and not self._should_exclude(
+ key, key, local=False, column=None
+ ):
+ self._adapt_inherited_property(key, prop, False)
+
+ # create properties for each column in the mapped table,
+ # for those columns which don't already map to a property
+ for column in self.persist_selectable.columns:
+ if column in self._columntoproperty:
+ continue
+
+ column_key = (self.column_prefix or "") + column.key
+
+ if self._should_exclude(
+ column.key,
+ column_key,
+ local=self.local_table.c.contains_column(column),
+ column=column,
+ ):
+ continue
+
+ # adjust the "key" used for this column to that
+ # of the inheriting mapper
+ for mapper in self.iterate_to_root():
+ if column in mapper._columntoproperty:
+ column_key = mapper._columntoproperty[column].key
+
+ self._configure_property(
+ column_key, column, init=False, setparent=True
+ )
+
+ def _configure_polymorphic_setter(self, init=False):
+ """Configure an attribute on the mapper representing the
+ 'polymorphic_on' column, if applicable, and not
+ already generated by _configure_properties (which is typical).
+
+ Also create a setter function which will assign this
+ attribute to the value of the 'polymorphic_identity'
+ upon instance construction, also if applicable. This
+ routine will run when an instance is created.
+
+ """
+ setter = False
+
+ if self.polymorphic_on is not None:
+ setter = True
+
+ if isinstance(self.polymorphic_on, util.string_types):
+ # polymorphic_on specified as a string - link
+ # it to mapped ColumnProperty
+ try:
+ self.polymorphic_on = self._props[self.polymorphic_on]
+ except KeyError as err:
+ util.raise_(
+ sa_exc.ArgumentError(
+ "Can't determine polymorphic_on "
+ "value '%s' - no attribute is "
+ "mapped to this name." % self.polymorphic_on
+ ),
+ replace_context=err,
+ )
+
+ if self.polymorphic_on in self._columntoproperty:
+ # polymorphic_on is a column that is already mapped
+ # to a ColumnProperty
+ prop = self._columntoproperty[self.polymorphic_on]
+ elif isinstance(self.polymorphic_on, MapperProperty):
+ # polymorphic_on is directly a MapperProperty,
+ # ensure it's a ColumnProperty
+ if not isinstance(
+ self.polymorphic_on, properties.ColumnProperty
+ ):
+ raise sa_exc.ArgumentError(
+ "Only direct column-mapped "
+ "property or SQL expression "
+ "can be passed for polymorphic_on"
+ )
+ prop = self.polymorphic_on
+ else:
+ # polymorphic_on is a Column or SQL expression and
+ # doesn't appear to be mapped. this means it can be 1.
+ # only present in the with_polymorphic selectable or
+ # 2. a totally standalone SQL expression which we'd
+ # hope is compatible with this mapper's persist_selectable
+ col = self.persist_selectable.corresponding_column(
+ self.polymorphic_on
+ )
+ if col is None:
+ # polymorphic_on doesn't derive from any
+ # column/expression isn't present in the mapped
+ # table. we will make a "hidden" ColumnProperty
+ # for it. Just check that if it's directly a
+ # schema.Column and we have with_polymorphic, it's
+ # likely a user error if the schema.Column isn't
+ # represented somehow in either persist_selectable or
+ # with_polymorphic. Otherwise as of 0.7.4 we
+ # just go with it and assume the user wants it
+ # that way (i.e. a CASE statement)
+ setter = False
+ instrument = False
+ col = self.polymorphic_on
+ if isinstance(col, schema.Column) and (
+ self.with_polymorphic is None
+ or self.with_polymorphic[1].corresponding_column(col)
+ is None
+ ):
+ raise sa_exc.InvalidRequestError(
+ "Could not map polymorphic_on column "
+ "'%s' to the mapped table - polymorphic "
+ "loads will not function properly"
+ % col.description
+ )
+ else:
+ # column/expression that polymorphic_on derives from
+ # is present in our mapped table
+ # and is probably mapped, but polymorphic_on itself
+ # is not. This happens when
+ # the polymorphic_on is only directly present in the
+ # with_polymorphic selectable, as when using
+ # polymorphic_union.
+ # we'll make a separate ColumnProperty for it.
+ instrument = True
+ key = getattr(col, "key", None)
+ if key:
+ if self._should_exclude(col.key, col.key, False, col):
+ raise sa_exc.InvalidRequestError(
+ "Cannot exclude or override the "
+ "discriminator column %r" % col.key
+ )
+ else:
+ self.polymorphic_on = col = col.label("_sa_polymorphic_on")
+ key = col.key
+
+ prop = properties.ColumnProperty(col, _instrument=instrument)
+ self._configure_property(key, prop, init=init, setparent=True)
+
+ # the actual polymorphic_on should be the first public-facing
+ # column in the property
+ self.polymorphic_on = prop.columns[0]
+ polymorphic_key = prop.key
+
+ else:
+ # no polymorphic_on was set.
+ # check inheriting mappers for one.
+ for mapper in self.iterate_to_root():
+ # determine if polymorphic_on of the parent
+ # should be propagated here. If the col
+ # is present in our mapped table, or if our mapped
+ # table is the same as the parent (i.e. single table
+ # inheritance), we can use it
+ if mapper.polymorphic_on is not None:
+ if self.persist_selectable is mapper.persist_selectable:
+ self.polymorphic_on = mapper.polymorphic_on
+ else:
+ self.polymorphic_on = (
+ self.persist_selectable
+ ).corresponding_column(mapper.polymorphic_on)
+ # we can use the parent mapper's _set_polymorphic_identity
+ # directly; it ensures the polymorphic_identity of the
+ # instance's mapper is used so is portable to subclasses.
+ if self.polymorphic_on is not None:
+ self._set_polymorphic_identity = (
+ mapper._set_polymorphic_identity
+ )
+ self._validate_polymorphic_identity = (
+ mapper._validate_polymorphic_identity
+ )
+ else:
+ self._set_polymorphic_identity = None
+ return
+
+ if setter:
+
+ def _set_polymorphic_identity(state):
+ dict_ = state.dict
+ state.get_impl(polymorphic_key).set(
+ state,
+ dict_,
+ state.manager.mapper.polymorphic_identity,
+ None,
+ )
+
+ def _validate_polymorphic_identity(mapper, state, dict_):
+ if (
+ polymorphic_key in dict_
+ and dict_[polymorphic_key]
+ not in mapper._acceptable_polymorphic_identities
+ ):
+ util.warn_limited(
+ "Flushing object %s with "
+ "incompatible polymorphic identity %r; the "
+ "object may not refresh and/or load correctly",
+ (state_str(state), dict_[polymorphic_key]),
+ )
+
+ self._set_polymorphic_identity = _set_polymorphic_identity
+ self._validate_polymorphic_identity = (
+ _validate_polymorphic_identity
+ )
+ else:
+ self._set_polymorphic_identity = None
+
+ _validate_polymorphic_identity = None
+
+ @HasMemoized.memoized_attribute
+ def _version_id_prop(self):
+ if self.version_id_col is not None:
+ return self._columntoproperty[self.version_id_col]
+ else:
+ return None
+
+ @HasMemoized.memoized_attribute
+ def _acceptable_polymorphic_identities(self):
+ identities = set()
+
+ stack = deque([self])
+ while stack:
+ item = stack.popleft()
+ if item.persist_selectable is self.persist_selectable:
+ identities.add(item.polymorphic_identity)
+ stack.extend(item._inheriting_mappers)
+
+ return identities
+
+ @HasMemoized.memoized_attribute
+ def _prop_set(self):
+ return frozenset(self._props.values())
+
+ @util.preload_module("sqlalchemy.orm.descriptor_props")
+ def _adapt_inherited_property(self, key, prop, init):
+ descriptor_props = util.preloaded.orm_descriptor_props
+
+ if not self.concrete:
+ self._configure_property(key, prop, init=False, setparent=False)
+ elif key not in self._props:
+ # determine if the class implements this attribute; if not,
+ # or if it is implemented by the attribute that is handling the
+ # given superclass-mapped property, then we need to report that we
+ # can't use this at the instance level since we are a concrete
+ # mapper and we don't map this. don't trip user-defined
+ # descriptors that might have side effects when invoked.
+ implementing_attribute = self.class_manager._get_class_attr_mro(
+ key, prop
+ )
+ if implementing_attribute is prop or (
+ isinstance(
+ implementing_attribute, attributes.InstrumentedAttribute
+ )
+ and implementing_attribute._parententity is prop.parent
+ ):
+ self._configure_property(
+ key,
+ descriptor_props.ConcreteInheritedProperty(),
+ init=init,
+ setparent=True,
+ )
+
+ @util.preload_module("sqlalchemy.orm.descriptor_props")
+ def _configure_property(self, key, prop, init=True, setparent=True):
+ descriptor_props = util.preloaded.orm_descriptor_props
+ self._log("_configure_property(%s, %s)", key, prop.__class__.__name__)
+
+ if not isinstance(prop, MapperProperty):
+ prop = self._property_from_column(key, prop)
+
+ if isinstance(prop, properties.ColumnProperty):
+ col = self.persist_selectable.corresponding_column(prop.columns[0])
+
+ # if the column is not present in the mapped table,
+ # test if a column has been added after the fact to the
+ # parent table (or their parent, etc.) [ticket:1570]
+ if col is None and self.inherits:
+ path = [self]
+ for m in self.inherits.iterate_to_root():
+ col = m.local_table.corresponding_column(prop.columns[0])
+ if col is not None:
+ for m2 in path:
+ m2.persist_selectable._refresh_for_new_column(col)
+ col = self.persist_selectable.corresponding_column(
+ prop.columns[0]
+ )
+ break
+ path.append(m)
+
+ # subquery expression, column not present in the mapped
+ # selectable.
+ if col is None:
+ col = prop.columns[0]
+
+ # column is coming in after _readonly_props was
+ # initialized; check for 'readonly'
+ if hasattr(self, "_readonly_props") and (
+ not hasattr(col, "table")
+ or col.table not in self._cols_by_table
+ ):
+ self._readonly_props.add(prop)
+
+ else:
+ # if column is coming in after _cols_by_table was
+ # initialized, ensure the col is in the right set
+ if (
+ hasattr(self, "_cols_by_table")
+ and col.table in self._cols_by_table
+ and col not in self._cols_by_table[col.table]
+ ):
+ self._cols_by_table[col.table].add(col)
+
+ # if this properties.ColumnProperty represents the "polymorphic
+ # discriminator" column, mark it. We'll need this when rendering
+ # columns in SELECT statements.
+ if not hasattr(prop, "_is_polymorphic_discriminator"):
+ prop._is_polymorphic_discriminator = (
+ col is self.polymorphic_on
+ or prop.columns[0] is self.polymorphic_on
+ )
+
+ if isinstance(col, expression.Label):
+ # new in 1.4, get column property against expressions
+ # to be addressable in subqueries
+ col.key = col._tq_key_label = key
+
+ self.columns.add(col, key)
+ for col in prop.columns:
+ for col in col.proxy_set:
+ self._columntoproperty[col] = prop
+
+ prop.key = key
+
+ if setparent:
+ prop.set_parent(self, init)
+
+ if key in self._props and getattr(
+ self._props[key], "_mapped_by_synonym", False
+ ):
+ syn = self._props[key]._mapped_by_synonym
+ raise sa_exc.ArgumentError(
+ "Can't call map_column=True for synonym %r=%r, "
+ "a ColumnProperty already exists keyed to the name "
+ "%r for column %r" % (syn, key, key, syn)
+ )
+
+ if (
+ key in self._props
+ and not isinstance(prop, properties.ColumnProperty)
+ and not isinstance(
+ self._props[key],
+ (
+ properties.ColumnProperty,
+ descriptor_props.ConcreteInheritedProperty,
+ ),
+ )
+ ):
+ util.warn(
+ "Property %s on %s being replaced with new "
+ "property %s; the old property will be discarded"
+ % (self._props[key], self, prop)
+ )
+ oldprop = self._props[key]
+ self._path_registry.pop(oldprop, None)
+
+ self._props[key] = prop
+
+ if not self.non_primary:
+ prop.instrument_class(self)
+
+ for mapper in self._inheriting_mappers:
+ mapper._adapt_inherited_property(key, prop, init)
+
+ if init:
+ prop.init()
+ prop.post_instrument_class(self)
+
+ if self.configured:
+ self._expire_memoizations()
+
+ @util.preload_module("sqlalchemy.orm.descriptor_props")
+ def _property_from_column(self, key, prop):
+ """generate/update a :class:`.ColumnProperty` given a
+ :class:`_schema.Column` object."""
+ descriptor_props = util.preloaded.orm_descriptor_props
+ # we were passed a Column or a list of Columns;
+ # generate a properties.ColumnProperty
+ columns = util.to_list(prop)
+ column = columns[0]
+ assert isinstance(column, expression.ColumnElement)
+
+ prop = self._props.get(key, None)
+
+ if isinstance(prop, properties.ColumnProperty):
+ if (
+ (
+ not self._inherits_equated_pairs
+ or (prop.columns[0], column)
+ not in self._inherits_equated_pairs
+ )
+ and not prop.columns[0].shares_lineage(column)
+ and prop.columns[0] is not self.version_id_col
+ and column is not self.version_id_col
+ ):
+ warn_only = prop.parent is not self
+ msg = (
+ "Implicitly combining column %s with column "
+ "%s under attribute '%s'. Please configure one "
+ "or more attributes for these same-named columns "
+ "explicitly." % (prop.columns[-1], column, key)
+ )
+ if warn_only:
+ util.warn(msg)
+ else:
+ raise sa_exc.InvalidRequestError(msg)
+
+ # existing properties.ColumnProperty from an inheriting
+ # mapper. make a copy and append our column to it
+ prop = prop.copy()
+ prop.columns.insert(0, column)
+ self._log(
+ "inserting column to existing list "
+ "in properties.ColumnProperty %s" % (key)
+ )
+ return prop
+ elif prop is None or isinstance(
+ prop, descriptor_props.ConcreteInheritedProperty
+ ):
+ mapped_column = []
+ for c in columns:
+ mc = self.persist_selectable.corresponding_column(c)
+ if mc is None:
+ mc = self.local_table.corresponding_column(c)
+ if mc is not None:
+ # if the column is in the local table but not the
+ # mapped table, this corresponds to adding a
+ # column after the fact to the local table.
+ # [ticket:1523]
+ self.persist_selectable._refresh_for_new_column(mc)
+ mc = self.persist_selectable.corresponding_column(c)
+ if mc is None:
+ raise sa_exc.ArgumentError(
+ "When configuring property '%s' on %s, "
+ "column '%s' is not represented in the mapper's "
+ "table. Use the `column_property()` function to "
+ "force this column to be mapped as a read-only "
+ "attribute." % (key, self, c)
+ )
+ mapped_column.append(mc)
+ return properties.ColumnProperty(*mapped_column)
+ else:
+ raise sa_exc.ArgumentError(
+ "WARNING: when configuring property '%s' on %s, "
+ "column '%s' conflicts with property '%r'. "
+ "To resolve this, map the column to the class under a "
+ "different name in the 'properties' dictionary. Or, "
+ "to remove all awareness of the column entirely "
+ "(including its availability as a foreign key), "
+ "use the 'include_properties' or 'exclude_properties' "
+ "mapper arguments to control specifically which table "
+ "columns get mapped." % (key, self, column.key, prop)
+ )
+
+ def _check_configure(self):
+ if self.registry._new_mappers:
+ _configure_registries({self.registry}, cascade=True)
+
+ def _post_configure_properties(self):
+ """Call the ``init()`` method on all ``MapperProperties``
+ attached to this mapper.
+
+ This is a deferred configuration step which is intended
+ to execute once all mappers have been constructed.
+
+ """
+
+ self._log("_post_configure_properties() started")
+ l = [(key, prop) for key, prop in self._props.items()]
+ for key, prop in l:
+ self._log("initialize prop %s", key)
+
+ if prop.parent is self and not prop._configure_started:
+ prop.init()
+
+ if prop._configure_finished:
+ prop.post_instrument_class(self)
+
+ self._log("_post_configure_properties() complete")
+ self.configured = True
+
+ def add_properties(self, dict_of_properties):
+ """Add the given dictionary of properties to this mapper,
+ using ``add_property``.
+
+ """
+ for key, value in dict_of_properties.items():
+ self.add_property(key, value)
+
+ def add_property(self, key, prop):
+ """Add an individual MapperProperty to this mapper.
+
+ If the mapper has not been configured yet, just adds the
+ property to the initial properties dictionary sent to the
+ constructor. If this Mapper has already been configured, then
+ the given MapperProperty is configured immediately.
+
+ """
+ self._init_properties[key] = prop
+ self._configure_property(key, prop, init=self.configured)
+
+ def _expire_memoizations(self):
+ for mapper in self.iterate_to_root():
+ mapper._reset_memoizations()
+
+ @property
+ def _log_desc(self):
+ return (
+ "("
+ + self.class_.__name__
+ + "|"
+ + (
+ self.local_table is not None
+ and self.local_table.description
+ or str(self.local_table)
+ )
+ + (self.non_primary and "|non-primary" or "")
+ + ")"
+ )
+
+ def _log(self, msg, *args):
+ self.logger.info("%s " + msg, *((self._log_desc,) + args))
+
+ def _log_debug(self, msg, *args):
+ self.logger.debug("%s " + msg, *((self._log_desc,) + args))
+
+ def __repr__(self):
+ return "<Mapper at 0x%x; %s>" % (id(self), self.class_.__name__)
+
+ def __str__(self):
+ return "mapped class %s%s->%s" % (
+ self.class_.__name__,
+ self.non_primary and " (non-primary)" or "",
+ self.local_table.description
+ if self.local_table is not None
+ else self.persist_selectable.description,
+ )
+
+ def _is_orphan(self, state):
+ orphan_possible = False
+ for mapper in self.iterate_to_root():
+ for (key, cls) in mapper._delete_orphans:
+ orphan_possible = True
+
+ has_parent = attributes.manager_of_class(cls).has_parent(
+ state, key, optimistic=state.has_identity
+ )
+
+ if self.legacy_is_orphan and has_parent:
+ return False
+ elif not self.legacy_is_orphan and not has_parent:
+ return True
+
+ if self.legacy_is_orphan:
+ return orphan_possible
+ else:
+ return False
+
+ def has_property(self, key):
+ return key in self._props
+
+ def get_property(self, key, _configure_mappers=True):
+ """return a MapperProperty associated with the given key."""
+
+ if _configure_mappers:
+ self._check_configure()
+
+ try:
+ return self._props[key]
+ except KeyError as err:
+ util.raise_(
+ sa_exc.InvalidRequestError(
+ "Mapper '%s' has no property '%s'" % (self, key)
+ ),
+ replace_context=err,
+ )
+
+ def get_property_by_column(self, column):
+ """Given a :class:`_schema.Column` object, return the
+ :class:`.MapperProperty` which maps this column."""
+
+ return self._columntoproperty[column]
+
+ @property
+ def iterate_properties(self):
+ """return an iterator of all MapperProperty objects."""
+
+ self._check_configure()
+ return iter(self._props.values())
+
+ def _mappers_from_spec(self, spec, selectable):
+ """given a with_polymorphic() argument, return the set of mappers it
+ represents.
+
+ Trims the list of mappers to just those represented within the given
+ selectable, if present. This helps some more legacy-ish mappings.
+
+ """
+ if spec == "*":
+ mappers = list(self.self_and_descendants)
+ elif spec:
+ mappers = set()
+ for m in util.to_list(spec):
+ m = _class_to_mapper(m)
+ if not m.isa(self):
+ raise sa_exc.InvalidRequestError(
+ "%r does not inherit from %r" % (m, self)
+ )
+
+ if selectable is None:
+ mappers.update(m.iterate_to_root())
+ else:
+ mappers.add(m)
+ mappers = [m for m in self.self_and_descendants if m in mappers]
+ else:
+ mappers = []
+
+ if selectable is not None:
+ tables = set(
+ sql_util.find_tables(selectable, include_aliases=True)
+ )
+ mappers = [m for m in mappers if m.local_table in tables]
+ return mappers
+
+ def _selectable_from_mappers(self, mappers, innerjoin):
+ """given a list of mappers (assumed to be within this mapper's
+ inheritance hierarchy), construct an outerjoin amongst those mapper's
+ mapped tables.
+
+ """
+ from_obj = self.persist_selectable
+ for m in mappers:
+ if m is self:
+ continue
+ if m.concrete:
+ raise sa_exc.InvalidRequestError(
+ "'with_polymorphic()' requires 'selectable' argument "
+ "when concrete-inheriting mappers are used."
+ )
+ elif not m.single:
+ if innerjoin:
+ from_obj = from_obj.join(
+ m.local_table, m.inherit_condition
+ )
+ else:
+ from_obj = from_obj.outerjoin(
+ m.local_table, m.inherit_condition
+ )
+
+ return from_obj
+
+ @HasMemoized.memoized_attribute
+ def _single_table_criterion(self):
+ if self.single and self.inherits and self.polymorphic_on is not None:
+ return self.polymorphic_on._annotate(
+ {"parententity": self, "parentmapper": self}
+ ).in_(m.polymorphic_identity for m in self.self_and_descendants)
+ else:
+ return None
+
+ @HasMemoized.memoized_attribute
+ def _with_polymorphic_mappers(self):
+ self._check_configure()
+
+ if not self.with_polymorphic:
+ return []
+ return self._mappers_from_spec(*self.with_polymorphic)
+
+ @HasMemoized.memoized_attribute
+ def _post_inspect(self):
+ """This hook is invoked by attribute inspection.
+
+ E.g. when Query calls:
+
+ coercions.expect(roles.ColumnsClauseRole, ent, keep_inspect=True)
+
+ This allows the inspection process to run the configure-mappers hook.
+
+ """
+ self._check_configure()
+
+ @HasMemoized.memoized_attribute
+ def _with_polymorphic_selectable(self):
+ if not self.with_polymorphic:
+ return self.persist_selectable
+
+ spec, selectable = self.with_polymorphic
+ if selectable is not None:
+ return selectable
+ else:
+ return self._selectable_from_mappers(
+ self._mappers_from_spec(spec, selectable), False
+ )
+
+ with_polymorphic_mappers = _with_polymorphic_mappers
+ """The list of :class:`_orm.Mapper` objects included in the
+ default "polymorphic" query.
+
+ """
+
+ @HasMemoized.memoized_attribute
+ def _insert_cols_evaluating_none(self):
+ return dict(
+ (
+ table,
+ frozenset(
+ col for col in columns if col.type.should_evaluate_none
+ ),
+ )
+ for table, columns in self._cols_by_table.items()
+ )
+
+ @HasMemoized.memoized_attribute
+ def _insert_cols_as_none(self):
+ return dict(
+ (
+ table,
+ frozenset(
+ col.key
+ for col in columns
+ if not col.primary_key
+ and not col.server_default
+ and not col.default
+ and not col.type.should_evaluate_none
+ ),
+ )
+ for table, columns in self._cols_by_table.items()
+ )
+
+ @HasMemoized.memoized_attribute
+ def _propkey_to_col(self):
+ return dict(
+ (
+ table,
+ dict(
+ (self._columntoproperty[col].key, col) for col in columns
+ ),
+ )
+ for table, columns in self._cols_by_table.items()
+ )
+
+ @HasMemoized.memoized_attribute
+ def _pk_keys_by_table(self):
+ return dict(
+ (table, frozenset([col.key for col in pks]))
+ for table, pks in self._pks_by_table.items()
+ )
+
+ @HasMemoized.memoized_attribute
+ def _pk_attr_keys_by_table(self):
+ return dict(
+ (
+ table,
+ frozenset([self._columntoproperty[col].key for col in pks]),
+ )
+ for table, pks in self._pks_by_table.items()
+ )
+
+ @HasMemoized.memoized_attribute
+ def _server_default_cols(self):
+ return dict(
+ (
+ table,
+ frozenset(
+ [
+ col.key
+ for col in columns
+ if col.server_default is not None
+ ]
+ ),
+ )
+ for table, columns in self._cols_by_table.items()
+ )
+
+ @HasMemoized.memoized_attribute
+ def _server_default_plus_onupdate_propkeys(self):
+ result = set()
+
+ for table, columns in self._cols_by_table.items():
+ for col in columns:
+ if (
+ col.server_default is not None
+ or col.server_onupdate is not None
+ ) and col in self._columntoproperty:
+ result.add(self._columntoproperty[col].key)
+
+ return result
+
+ @HasMemoized.memoized_attribute
+ def _server_onupdate_default_cols(self):
+ return dict(
+ (
+ table,
+ frozenset(
+ [
+ col.key
+ for col in columns
+ if col.server_onupdate is not None
+ ]
+ ),
+ )
+ for table, columns in self._cols_by_table.items()
+ )
+
+ @HasMemoized.memoized_instancemethod
+ def __clause_element__(self):
+
+ annotations = {
+ "entity_namespace": self,
+ "parententity": self,
+ "parentmapper": self,
+ }
+ if self.persist_selectable is not self.local_table:
+ # joined table inheritance, with polymorphic selectable,
+ # etc.
+ annotations["dml_table"] = self.local_table._annotate(
+ {
+ "entity_namespace": self,
+ "parententity": self,
+ "parentmapper": self,
+ }
+ )._set_propagate_attrs(
+ {"compile_state_plugin": "orm", "plugin_subject": self}
+ )
+
+ return self.selectable._annotate(annotations)._set_propagate_attrs(
+ {"compile_state_plugin": "orm", "plugin_subject": self}
+ )
+
+ @util.memoized_property
+ def select_identity_token(self):
+ return (
+ expression.null()
+ ._annotate(
+ {
+ "entity_namespace": self,
+ "parententity": self,
+ "parentmapper": self,
+ "identity_token": True,
+ }
+ )
+ ._set_propagate_attrs(
+ {"compile_state_plugin": "orm", "plugin_subject": self}
+ )
+ )
+
+ @property
+ def selectable(self):
+ """The :class:`_schema.FromClause` construct this
+ :class:`_orm.Mapper` selects from by default.
+
+ Normally, this is equivalent to :attr:`.persist_selectable`, unless
+ the ``with_polymorphic`` feature is in use, in which case the
+ full "polymorphic" selectable is returned.
+
+ """
+ return self._with_polymorphic_selectable
+
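+ # [illustrative usage sketch -- not part of this module] A minimal example of
+ # the ``Mapper.selectable`` accessor, assuming a hypothetical single-table
+ # ``User`` mapping; with no ``with_polymorphic`` configured, ``selectable``,
+ # ``persist_selectable`` and ``local_table`` all refer to the same Table:
+ #
+ #     from sqlalchemy import Column, Integer, String, inspect
+ #     from sqlalchemy.orm import declarative_base
+ #
+ #     Base = declarative_base()
+ #
+ #     class User(Base):
+ #         __tablename__ = "user"
+ #         id = Column(Integer, primary_key=True)
+ #         name = Column(String(50))
+ #
+ #     user_mapper = inspect(User)   # inspect() on a mapped class returns its Mapper
+ #     assert user_mapper.selectable is User.__table__
+ #     assert user_mapper.selectable is user_mapper.persist_selectable
+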
+ def _with_polymorphic_args(
+ self, spec=None, selectable=False, innerjoin=False
+ ):
+ if selectable not in (None, False):
+ selectable = coercions.expect(
+ roles.StrictFromClauseRole, selectable, allow_select=True
+ )
+
+ if self.with_polymorphic:
+ if not spec:
+ spec = self.with_polymorphic[0]
+ if selectable is False:
+ selectable = self.with_polymorphic[1]
+ elif selectable is False:
+ selectable = None
+ mappers = self._mappers_from_spec(spec, selectable)
+ if selectable is not None:
+ return mappers, selectable
+ else:
+ return mappers, self._selectable_from_mappers(mappers, innerjoin)
+
+ @HasMemoized.memoized_attribute
+ def _polymorphic_properties(self):
+ return list(
+ self._iterate_polymorphic_properties(
+ self._with_polymorphic_mappers
+ )
+ )
+
+ @property
+ def _all_column_expressions(self):
+ poly_properties = self._polymorphic_properties
+ adapter = self._polymorphic_adapter
+
+ return [
+ adapter.columns[prop.columns[0]] if adapter else prop.columns[0]
+ for prop in poly_properties
+ if isinstance(prop, properties.ColumnProperty)
+ and prop._renders_in_subqueries
+ ]
+
+ def _columns_plus_keys(self, polymorphic_mappers=()):
+ if polymorphic_mappers:
+ poly_properties = self._iterate_polymorphic_properties(
+ polymorphic_mappers
+ )
+ else:
+ poly_properties = self._polymorphic_properties
+
+ return [
+ (prop.key, prop.columns[0])
+ for prop in poly_properties
+ if isinstance(prop, properties.ColumnProperty)
+ ]
+
+ @HasMemoized.memoized_attribute
+ def _polymorphic_adapter(self):
+ if self.with_polymorphic:
+ return sql_util.ColumnAdapter(
+ self.selectable, equivalents=self._equivalent_columns
+ )
+ else:
+ return None
+
+ def _iterate_polymorphic_properties(self, mappers=None):
+ """Return an iterator of MapperProperty objects which will render into
+ a SELECT."""
+ if mappers is None:
+ mappers = self._with_polymorphic_mappers
+
+ if not mappers:
+ for c in self.iterate_properties:
+ yield c
+ else:
+ # in the polymorphic case, filter out discriminator columns
+ # from other mappers, as these are sometimes dependent on that
+ # mapper's polymorphic selectable (which we don't want rendered)
+ for c in util.unique_list(
+ chain(
+ *[
+ list(mapper.iterate_properties)
+ for mapper in [self] + mappers
+ ]
+ )
+ ):
+ if getattr(c, "_is_polymorphic_discriminator", False) and (
+ self.polymorphic_on is None
+ or c.columns[0] is not self.polymorphic_on
+ ):
+ continue
+ yield c
+
+ @HasMemoized.memoized_attribute
+ def attrs(self):
+ """A namespace of all :class:`.MapperProperty` objects
+ associated with this mapper.
+
+ This is an object that provides each property based on
+ its key name. For instance, the mapper for a
+ ``User`` class which has ``User.name`` attribute would
+ provide ``mapper.attrs.name``, which would be the
+ :class:`.ColumnProperty` representing the ``name``
+ column. The namespace object can also be iterated,
+ which would yield each :class:`.MapperProperty`.
+
+ :class:`_orm.Mapper` has several pre-filtered views
+ of this attribute which limit the types of properties
+ returned, including :attr:`.synonyms`, :attr:`.column_attrs`,
+ :attr:`.relationships`, and :attr:`.composites`.
+
+ .. warning::
+
+ The :attr:`_orm.Mapper.attrs` accessor namespace is an
+ instance of :class:`.OrderedProperties`. This is
+ a dictionary-like object which includes a small number of
+ named methods such as :meth:`.OrderedProperties.items`
+ and :meth:`.OrderedProperties.values`. When
+ accessing attributes dynamically, favor using the dict-access
+ scheme, e.g. ``mapper.attrs[somename]`` over
+ ``getattr(mapper.attrs, somename)`` to avoid name collisions.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.all_orm_descriptors`
+
+ """
+
+ self._check_configure()
+ return util.ImmutableProperties(self._props)
+
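+ # [illustrative usage sketch -- not part of this module] Accessing the
+ # ``Mapper.attrs`` namespace, assuming a hypothetical ``User`` mapping with
+ # ``id`` and ``name`` columns:
+ #
+ #     from sqlalchemy import Column, Integer, String, inspect
+ #     from sqlalchemy.orm import declarative_base
+ #
+ #     Base = declarative_base()
+ #
+ #     class User(Base):
+ #         __tablename__ = "user"
+ #         id = Column(Integer, primary_key=True)
+ #         name = Column(String(50))
+ #
+ #     m = inspect(User)
+ #     name_prop = m.attrs["name"]        # ColumnProperty for the "name" column
+ #     assert m.attrs.name is name_prop   # attribute access; dict access preferred
+ #     assert [p.key for p in m.attrs] == ["id", "name"]
+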
+ @HasMemoized.memoized_attribute
+ def all_orm_descriptors(self):
+ """A namespace of all :class:`.InspectionAttr` attributes associated
+ with the mapped class.
+
+ These attributes are in all cases Python :term:`descriptors`
+ associated with the mapped class or its superclasses.
+
+ This namespace includes attributes that are mapped to the class
+ as well as attributes declared by extension modules.
+ It includes any Python descriptor type that inherits from
+ :class:`.InspectionAttr`. This includes
+ :class:`.QueryableAttribute`, as well as extension types such as
+ :class:`.hybrid_property`, :class:`.hybrid_method` and
+ :class:`.AssociationProxy`.
+
+ To distinguish between mapped attributes and extension attributes,
+ the attribute :attr:`.InspectionAttr.extension_type` will refer
+ to a constant that distinguishes between different extension types.
+
+ The sorting of the attributes is based on the following rules:
+
+ 1. Iterate through the class and its superclasses in order from
+ subclass to superclass (i.e. iterate through ``cls.__mro__``)
+
+ 2. For each class, yield the attributes in the order in which they
+ appear in ``__dict__``, with the exception of those in step
+ 3 below. In Python 3.6 and above this ordering will be the
+ same as that of the class' construction, with the exception
+ of attributes that were added after the fact by the application
+ or the mapper.
+
+ 3. If a certain attribute key is also in the superclass ``__dict__``,
+ then it's included in the iteration for that class, and not the
+ class in which it first appeared.
+
+ The above process produces an ordering that is deterministic in terms
+ of the order in which attributes were assigned to the class.
+
+ .. versionchanged:: 1.3.19 ensured deterministic ordering for
+ :meth:`_orm.Mapper.all_orm_descriptors`.
+
+ When dealing with a :class:`.QueryableAttribute`, the
+ :attr:`.QueryableAttribute.property` attribute refers to the
+ :class:`.MapperProperty` property, which is what you get when
+ referring to the collection of mapped properties via
+ :attr:`_orm.Mapper.attrs`.
+
+ .. warning::
+
+ The :attr:`_orm.Mapper.all_orm_descriptors`
+ accessor namespace is an
+ instance of :class:`.OrderedProperties`. This is
+ a dictionary-like object which includes a small number of
+ named methods such as :meth:`.OrderedProperties.items`
+ and :meth:`.OrderedProperties.values`. When
+ accessing attributes dynamically, favor using the dict-access
+ scheme, e.g. ``mapper.all_orm_descriptors[somename]`` over
+ ``getattr(mapper.all_orm_descriptors, somename)`` to avoid name
+ collisions.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.attrs`
+
+ """
+ return util.ImmutableProperties(
+ dict(self.class_manager._all_sqla_attributes())
+ )
+
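+ # [illustrative usage sketch -- not part of this module]
+ # ``all_orm_descriptors`` also exposes extension descriptors such as
+ # ``hybrid_property``, which do not appear in ``Mapper.attrs``; the mapping
+ # below is hypothetical:
+ #
+ #     from sqlalchemy import Column, Integer, inspect
+ #     from sqlalchemy.ext.hybrid import HYBRID_PROPERTY, hybrid_property
+ #     from sqlalchemy.orm import declarative_base
+ #
+ #     Base = declarative_base()
+ #
+ #     class Interval(Base):
+ #         __tablename__ = "interval"
+ #         id = Column(Integer, primary_key=True)
+ #         start = Column(Integer)
+ #         end = Column(Integer)
+ #
+ #         @hybrid_property
+ #         def length(self):
+ #             return self.end - self.start
+ #
+ #     m = inspect(Interval)
+ #     assert "length" in m.all_orm_descriptors       # hybrid is visible here ...
+ #     assert "length" not in m.attrs                 # ... but is not a MapperProperty
+ #     assert m.all_orm_descriptors["length"].extension_type is HYBRID_PROPERTY
+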
+ @HasMemoized.memoized_attribute
+ @util.preload_module("sqlalchemy.orm.descriptor_props")
+ def synonyms(self):
+ """Return a namespace of all :class:`.SynonymProperty`
+ properties maintained by this :class:`_orm.Mapper`.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.attrs` - namespace of all
+ :class:`.MapperProperty`
+ objects.
+
+ """
+ descriptor_props = util.preloaded.orm_descriptor_props
+
+ return self._filter_properties(descriptor_props.SynonymProperty)
+
+ @property
+ def entity_namespace(self):
+ return self.class_
+
+ @HasMemoized.memoized_attribute
+ def column_attrs(self):
+ """Return a namespace of all :class:`.ColumnProperty`
+ properties maintained by this :class:`_orm.Mapper`.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.attrs` - namespace of all
+ :class:`.MapperProperty`
+ objects.
+
+ """
+ return self._filter_properties(properties.ColumnProperty)
+
+ @util.preload_module("sqlalchemy.orm.relationships")
+ @HasMemoized.memoized_attribute
+ def relationships(self):
+ """A namespace of all :class:`.RelationshipProperty` properties
+ maintained by this :class:`_orm.Mapper`.
+
+ .. warning::
+
+ the :attr:`_orm.Mapper.relationships` accessor namespace is an
+ instance of :class:`.OrderedProperties`. This is
+ a dictionary-like object which includes a small number of
+ named methods such as :meth:`.OrderedProperties.items`
+ and :meth:`.OrderedProperties.values`. When
+ accessing attributes dynamically, favor using the dict-access
+ scheme, e.g. ``mapper.relationships[somename]`` over
+ ``getattr(mapper.relationships, somename)`` to avoid name
+ collisions.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.attrs` - namespace of all
+ :class:`.MapperProperty`
+ objects.
+
+ """
+ return self._filter_properties(
+ util.preloaded.orm_relationships.RelationshipProperty
+ )
+
+ @HasMemoized.memoized_attribute
+ @util.preload_module("sqlalchemy.orm.descriptor_props")
+ def composites(self):
+ """Return a namespace of all :class:`.CompositeProperty`
+ properties maintained by this :class:`_orm.Mapper`.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.attrs` - namespace of all
+ :class:`.MapperProperty`
+ objects.
+
+ """
+ return self._filter_properties(
+ util.preloaded.orm_descriptor_props.CompositeProperty
+ )
+
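+ # [illustrative usage sketch -- not part of this module] The pre-filtered
+ # views of ``Mapper.attrs`` in action, assuming a hypothetical
+ # ``User``/``Address`` mapping with a synonym and a relationship:
+ #
+ #     from sqlalchemy import Column, ForeignKey, Integer, String, inspect
+ #     from sqlalchemy.orm import declarative_base, relationship, synonym
+ #
+ #     Base = declarative_base()
+ #
+ #     class User(Base):
+ #         __tablename__ = "user"
+ #         id = Column(Integer, primary_key=True)
+ #         name = Column(String(50))
+ #         name_syn = synonym("name")
+ #         addresses = relationship("Address", back_populates="user")
+ #
+ #     class Address(Base):
+ #         __tablename__ = "address"
+ #         id = Column(Integer, primary_key=True)
+ #         user_id = Column(ForeignKey("user.id"))
+ #         user = relationship("User", back_populates="addresses")
+ #
+ #     m = inspect(User)
+ #     assert set(m.column_attrs.keys()) == {"id", "name"}
+ #     assert set(m.relationships.keys()) == {"addresses"}
+ #     assert set(m.synonyms.keys()) == {"name_syn"}
+ #     assert not list(m.composites)                  # no composite properties here
+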
+ def _filter_properties(self, type_):
+ self._check_configure()
+ return util.ImmutableProperties(
+ util.OrderedDict(
+ (k, v) for k, v in self._props.items() if isinstance(v, type_)
+ )
+ )
+
+ @HasMemoized.memoized_attribute
+ def _get_clause(self):
+ """create a "get clause" based on the primary key. this is used
+ by query.get() and many-to-one lazyloads to load this item
+ by primary key.
+
+ """
+ params = [
+ (
+ primary_key,
+ sql.bindparam("pk_%d" % idx, type_=primary_key.type),
+ )
+ for idx, primary_key in enumerate(self.primary_key, 1)
+ ]
+ return (
+ sql.and_(*[k == v for (k, v) in params]),
+ util.column_dict(params),
+ )
+
+ @HasMemoized.memoized_attribute
+ def _equivalent_columns(self):
+ """Create a map of all equivalent columns, based on
+ the determination of column pairs that are equated to
+ one another based on inherit condition. This is designed
+ to work with the queries that util.polymorphic_union
+ comes up with, which often don't include the columns from
+ the base table directly (including the subclass table columns
+ only).
+
+ The resulting structure is a dictionary of columns mapped
+ to lists of equivalent columns, e.g.::
+
+ {
+ tablea.col1:
+ {tableb.col1, tablec.col1},
+ tablea.col2:
+ {tabled.col2}
+ }
+
+ """
+ result = util.column_dict()
+
+ def visit_binary(binary):
+ if binary.operator == operators.eq:
+ if binary.left in result:
+ result[binary.left].add(binary.right)
+ else:
+ result[binary.left] = util.column_set((binary.right,))
+ if binary.right in result:
+ result[binary.right].add(binary.left)
+ else:
+ result[binary.right] = util.column_set((binary.left,))
+
+ for mapper in self.base_mapper.self_and_descendants:
+ if mapper.inherit_condition is not None:
+ visitors.traverse(
+ mapper.inherit_condition, {}, {"binary": visit_binary}
+ )
+
+ return result
+
+ def _is_userland_descriptor(self, assigned_name, obj):
+ if isinstance(
+ obj,
+ (
+ _MappedAttribute,
+ instrumentation.ClassManager,
+ expression.ColumnElement,
+ ),
+ ):
+ return False
+ else:
+ return assigned_name not in self._dataclass_fields
+
+ @HasMemoized.memoized_attribute
+ def _dataclass_fields(self):
+ return [f.name for f in util.dataclass_fields(self.class_)]
+
+ def _should_exclude(self, name, assigned_name, local, column):
+ """determine whether a particular property should be implicitly
+ present on the class.
+
+ This occurs when properties are propagated from an inherited class, or
+ are applied from the columns present in the mapped table.
+
+ """
+
+ # check for class-bound attributes and/or descriptors,
+ # either local or from an inherited class
+ # ignore dataclass field default values
+ if local:
+ if self.class_.__dict__.get(
+ assigned_name, None
+ ) is not None and self._is_userland_descriptor(
+ assigned_name, self.class_.__dict__[assigned_name]
+ ):
+ return True
+ else:
+ attr = self.class_manager._get_class_attr_mro(assigned_name, None)
+ if attr is not None and self._is_userland_descriptor(
+ assigned_name, attr
+ ):
+ return True
+
+ if (
+ self.include_properties is not None
+ and name not in self.include_properties
+ and (column is None or column not in self.include_properties)
+ ):
+ self._log("not including property %s" % (name))
+ return True
+
+ if self.exclude_properties is not None and (
+ name in self.exclude_properties
+ or (column is not None and column in self.exclude_properties)
+ ):
+ self._log("excluding property %s" % (name))
+ return True
+
+ return False
+
+ def common_parent(self, other):
+ """Return true if the given mapper shares a
+ common inherited parent with this mapper."""
+
+ return self.base_mapper is other.base_mapper
+
+ def is_sibling(self, other):
+ """return true if the other mapper is an inheriting sibling to this
+ one. common parent but different branch
+
+ """
+ return (
+ self.base_mapper is other.base_mapper
+ and not self.isa(other)
+ and not other.isa(self)
+ )
+
+ def _canload(self, state, allow_subtypes):
+ s = self.primary_mapper()
+ if self.polymorphic_on is not None or allow_subtypes:
+ return _state_mapper(state).isa(s)
+ else:
+ return _state_mapper(state) is s
+
+ def isa(self, other):
+ """Return True if the this mapper inherits from the given mapper."""
+
+ m = self
+ while m and m is not other:
+ m = m.inherits
+ return bool(m)
+
+ def iterate_to_root(self):
+ m = self
+ while m:
+ yield m
+ m = m.inherits
+
+ @HasMemoized.memoized_attribute
+ def self_and_descendants(self):
+ """The collection including this mapper and all descendant mappers.
+
+ This includes not just the immediately inheriting mappers but
+ all their inheriting mappers as well.
+
+ """
+ descendants = []
+ stack = deque([self])
+ while stack:
+ item = stack.popleft()
+ descendants.append(item)
+ stack.extend(item._inheriting_mappers)
+ return util.WeakSequence(descendants)
+
+ def polymorphic_iterator(self):
+ """Iterate through the collection including this mapper and
+ all descendant mappers.
+
+ This includes not just the immediately inheriting mappers but
+ all their inheriting mappers as well.
+
+ To iterate through an entire hierarchy, use
+ ``mapper.base_mapper.polymorphic_iterator()``.
+
+ """
+ return iter(self.self_and_descendants)
+
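+ # [illustrative usage sketch -- not part of this module] Traversing a
+ # hypothetical joined-table inheritance hierarchy with
+ # ``polymorphic_iterator()``, ``isa()``, ``iterate_to_root()``,
+ # ``common_parent()`` and ``is_sibling()``:
+ #
+ #     from sqlalchemy import Column, ForeignKey, Integer, String, inspect
+ #     from sqlalchemy.orm import declarative_base
+ #
+ #     Base = declarative_base()
+ #
+ #     class Employee(Base):
+ #         __tablename__ = "employee"
+ #         id = Column(Integer, primary_key=True)
+ #         type = Column(String(20))
+ #         __mapper_args__ = {
+ #             "polymorphic_on": type,
+ #             "polymorphic_identity": "employee",
+ #         }
+ #
+ #     class Manager(Employee):
+ #         __tablename__ = "manager"
+ #         id = Column(ForeignKey("employee.id"), primary_key=True)
+ #         __mapper_args__ = {"polymorphic_identity": "manager"}
+ #
+ #     class Engineer(Employee):
+ #         __tablename__ = "engineer"
+ #         id = Column(ForeignKey("employee.id"), primary_key=True)
+ #         __mapper_args__ = {"polymorphic_identity": "engineer"}
+ #
+ #     emp_m, mgr_m, eng_m = inspect(Employee), inspect(Manager), inspect(Engineer)
+ #
+ #     # the full hierarchy, starting from the base mapper
+ #     assert [m.class_ for m in emp_m.polymorphic_iterator()] == [
+ #         Employee,
+ #         Manager,
+ #         Engineer,
+ #     ]
+ #
+ #     # inheritance relationships between individual mappers
+ #     assert mgr_m.isa(emp_m) and not emp_m.isa(mgr_m)
+ #     assert [m.class_ for m in mgr_m.iterate_to_root()] == [Manager, Employee]
+ #     assert mgr_m.common_parent(eng_m) and mgr_m.is_sibling(eng_m)
+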
+ def primary_mapper(self):
+ """Return the primary mapper corresponding to this mapper's class key
+ (class)."""
+
+ return self.class_manager.mapper
+
+ @property
+ def primary_base_mapper(self):
+ return self.class_manager.mapper.base_mapper
+
+ def _result_has_identity_key(self, result, adapter=None):
+ pk_cols = self.primary_key
+ if adapter:
+ pk_cols = [adapter.columns[c] for c in pk_cols]
+ rk = result.keys()
+ for col in pk_cols:
+ if col not in rk:
+ return False
+ else:
+ return True
+
+ def identity_key_from_row(self, row, identity_token=None, adapter=None):
+ """Return an identity-map key for use in storing/retrieving an
+ item from the identity map.
+
+ :param row: A :class:`.Row` instance. The columns which are
+ mapped by this :class:`_orm.Mapper` should be locatable in the row,
+ preferably via the :class:`_schema.Column`
+ object directly (as is the case
+ when a :func:`_expression.select` construct is executed), or
+ via string names of the form ``<tablename>_<colname>``.
+
+ """
+ pk_cols = self.primary_key
+ if adapter:
+ pk_cols = [adapter.columns[c] for c in pk_cols]
+
+ return (
+ self._identity_class,
+ tuple(row[column] for column in pk_cols),
+ identity_token,
+ )
+
+ def identity_key_from_primary_key(self, primary_key, identity_token=None):
+ """Return an identity-map key for use in storing/retrieving an
+ item from an identity map.
+
+ :param primary_key: A list of values indicating the identifier.
+
+ """
+ return self._identity_class, tuple(primary_key), identity_token
+
+ def identity_key_from_instance(self, instance):
+ """Return the identity key for the given instance, based on
+ its primary key attributes.
+
+ If the instance's state is expired, calling this method
+ will result in a database check to see if the object has been deleted.
+ If the row no longer exists,
+ :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised.
+
+ This value is typically also found on the instance state under the
+ attribute name ``key``.
+
+ """
+ state = attributes.instance_state(instance)
+ return self._identity_key_from_state(state, attributes.PASSIVE_OFF)
+
+ def _identity_key_from_state(
+ self, state, passive=attributes.PASSIVE_RETURN_NO_VALUE
+ ):
+ dict_ = state.dict
+ manager = state.manager
+ return (
+ self._identity_class,
+ tuple(
+ [
+ manager[prop.key].impl.get(state, dict_, passive)
+ for prop in self._identity_key_props
+ ]
+ ),
+ state.identity_token,
+ )
+
+ def primary_key_from_instance(self, instance):
+ """Return the list of primary key values for the given
+ instance.
+
+ If the instance's state is expired, calling this method
+ will result in a database check to see if the object has been deleted.
+ If the row no longer exists,
+ :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised.
+
+ """
+ state = attributes.instance_state(instance)
+ identity_key = self._identity_key_from_state(
+ state, attributes.PASSIVE_OFF
+ )
+ return identity_key[1]
+
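+ # [illustrative usage sketch -- not part of this module] The identity-key
+ # helpers, assuming a hypothetical ``User`` mapping with an integer primary
+ # key; for a non-inheriting mapper the identity class is the mapped class
+ # itself and the identity token defaults to None:
+ #
+ #     from sqlalchemy import Column, Integer, String, inspect
+ #     from sqlalchemy.orm import declarative_base
+ #
+ #     Base = declarative_base()
+ #
+ #     class User(Base):
+ #         __tablename__ = "user"
+ #         id = Column(Integer, primary_key=True)
+ #         name = Column(String(50))
+ #
+ #     m = inspect(User)
+ #
+ #     assert m.identity_key_from_primary_key([5]) == (User, (5,), None)
+ #
+ #     someuser = User(id=5, name="spongebob")
+ #     assert m.identity_key_from_instance(someuser) == (User, (5,), None)
+ #     assert tuple(m.primary_key_from_instance(someuser)) == (5,)
+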
+ @HasMemoized.memoized_attribute
+ def _persistent_sortkey_fn(self):
+ key_fns = [col.type.sort_key_function for col in self.primary_key]
+
+ if set(key_fns).difference([None]):
+
+ def key(state):
+ return tuple(
+ key_fn(val) if key_fn is not None else val
+ for key_fn, val in zip(key_fns, state.key[1])
+ )
+
+ else:
+
+ def key(state):
+ return state.key[1]
+
+ return key
+
+ @HasMemoized.memoized_attribute
+ def _identity_key_props(self):
+ return [self._columntoproperty[col] for col in self.primary_key]
+
+ @HasMemoized.memoized_attribute
+ def _all_pk_cols(self):
+ collection = set()
+ for table in self.tables:
+ collection.update(self._pks_by_table[table])
+ return collection
+
+ @HasMemoized.memoized_attribute
+ def _should_undefer_in_wildcard(self):
+ cols = set(self.primary_key)
+ if self.polymorphic_on is not None:
+ cols.add(self.polymorphic_on)
+ return cols
+
+ @HasMemoized.memoized_attribute
+ def _primary_key_propkeys(self):
+ return {self._columntoproperty[col].key for col in self._all_pk_cols}
+
+ def _get_state_attr_by_column(
+ self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NO_VALUE
+ ):
+ prop = self._columntoproperty[column]
+ return state.manager[prop.key].impl.get(state, dict_, passive=passive)
+
+ def _set_committed_state_attr_by_column(self, state, dict_, column, value):
+ prop = self._columntoproperty[column]
+ state.manager[prop.key].impl.set_committed_value(state, dict_, value)
+
+ def _set_state_attr_by_column(self, state, dict_, column, value):
+ prop = self._columntoproperty[column]
+ state.manager[prop.key].impl.set(state, dict_, value, None)
+
+ def _get_committed_attr_by_column(self, obj, column):
+ state = attributes.instance_state(obj)
+ dict_ = attributes.instance_dict(obj)
+ return self._get_committed_state_attr_by_column(
+ state, dict_, column, passive=attributes.PASSIVE_OFF
+ )
+
+ def _get_committed_state_attr_by_column(
+ self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NO_VALUE
+ ):
+
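+ # [illustrative usage sketch -- not part of this module] Adding properties to
+ # an already-set-up, hypothetical ``User`` mapping via ``add_property()``; the
+ # synonym and the SQL-expression column property are illustrative names only:
+ #
+ #     from sqlalchemy import Column, Integer, String, inspect
+ #     from sqlalchemy.orm import column_property, declarative_base, synonym
+ #
+ #     Base = declarative_base()
+ #
+ #     class User(Base):
+ #         __tablename__ = "user"
+ #         id = Column(Integer, primary_key=True)
+ #         first = Column(String(50))
+ #         last = Column(String(50))
+ #
+ #     m = inspect(User)
+ #     m.add_property("last_syn", synonym("last"))
+ #     m.add_property("full_name", column_property(User.first + " " + User.last))
+ #
+ #     assert "last_syn" in m.synonyms
+ #     assert "full_name" in m.column_attrs
+ #     print(User.full_name)        # now a mapped, queryable attribute
+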
+ prop = self._columntoproperty[column]
+ return state.manager[prop.key].impl.get_committed_value(
+ state, dict_, passive=passive
+ )
+
+ def _optimized_get_statement(self, state, attribute_names):
+ """assemble a WHERE clause which retrieves a given state by primary
+ key, using a minimized set of tables.
+
+ Applies to a joined-table inheritance mapper where the
+ requested attribute names are only present on joined tables,
+ not the base table. The WHERE clause attempts to include
+ only those tables to minimize joins.
+
+ """
+ props = self._props
+
+ col_attribute_names = set(attribute_names).intersection(
+ state.mapper.column_attrs.keys()
+ )
+ tables = set(
+ chain(
+ *[
+ sql_util.find_tables(c, check_columns=True)
+ for key in col_attribute_names
+ for c in props[key].columns
+ ]
+ )
+ )
+
+ if self.base_mapper.local_table in tables:
+ return None
+
+ def visit_binary(binary):
+ leftcol = binary.left
+ rightcol = binary.right
+ if leftcol is None or rightcol is None:
+ return
+
+ if leftcol.table not in tables:
+ leftval = self._get_committed_state_attr_by_column(
+ state,
+ state.dict,
+ leftcol,
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ )
+ if leftval in orm_util._none_set:
+ raise _OptGetColumnsNotAvailable()
+ binary.left = sql.bindparam(
+ None, leftval, type_=binary.right.type
+ )
+ elif rightcol.table not in tables:
+ rightval = self._get_committed_state_attr_by_column(
+ state,
+ state.dict,
+ rightcol,
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ )
+ if rightval in orm_util._none_set:
+ raise _OptGetColumnsNotAvailable()
+ binary.right = sql.bindparam(
+ None, rightval, type_=binary.right.type
+ )
+
+ allconds = []
+
+ start = False
+
+ # as of #7507, from the lowest base table on upwards,
+ # we include all intermediary tables.
+
+ for mapper in reversed(list(self.iterate_to_root())):
+ if mapper.local_table in tables:
+ start = True
+ elif not isinstance(mapper.local_table, expression.TableClause):
+ return None
+ if start and not mapper.single:
+ allconds.append(mapper.inherit_condition)
+ tables.add(mapper.local_table)
+
+ # only the bottom table needs its criteria to be altered to fit
+ # the primary key ident - the rest of the tables upwards to the
+ # descendant-most class should all be present and joined to each
+ # other.
+ try:
+ allconds[0] = visitors.cloned_traverse(
+ allconds[0], {}, {"binary": visit_binary}
+ )
+ except _OptGetColumnsNotAvailable:
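+ # [illustrative usage sketch -- not part of this module] Property lookup APIs
+ # against a hypothetical ``User`` mapping:
+ #
+ #     from sqlalchemy import Column, Integer, String, inspect
+ #     from sqlalchemy.orm import declarative_base
+ #
+ #     Base = declarative_base()
+ #
+ #     class User(Base):
+ #         __tablename__ = "user"
+ #         id = Column(Integer, primary_key=True)
+ #         name = Column(String(50))
+ #
+ #     m = inspect(User)
+ #
+ #     name_prop = m.get_property("name")          # lookup by attribute name
+ #     assert m.get_property_by_column(User.__table__.c.name) is name_prop
+ #     assert m.has_property("name") and not m.has_property("nickname")
+ #     assert {p.key for p in m.iterate_properties} == {"id", "name"}
+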
+ return None
+
+ cond = sql.and_(*allconds)
+
+ cols = []
+ for key in col_attribute_names:
+ cols.extend(props[key].columns)
+ return (
+ sql.select(*cols)
+ .where(cond)
+ .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ )
+
+ def _iterate_to_target_viawpoly(self, mapper):
+ if self.isa(mapper):
+ prev = self
+ for m in self.iterate_to_root():
+ yield m
+
+ if m is not prev and prev not in m._with_polymorphic_mappers:
+ break
+
+ prev = m
+ if m is mapper:
+ break
+
+ def _should_selectin_load(self, enabled_via_opt, polymorphic_from):
+ if not enabled_via_opt:
+ # common case, takes place for all polymorphic loads
+ mapper = polymorphic_from
+ for m in self._iterate_to_target_viawpoly(mapper):
+ if m.polymorphic_load == "selectin":
+ return m
+ else:
+ # uncommon case, selectin load options were used
+ enabled_via_opt = set(enabled_via_opt)
+ enabled_via_opt_mappers = {e.mapper: e for e in enabled_via_opt}
+ for entity in enabled_via_opt.union([polymorphic_from]):
+ mapper = entity.mapper
+ for m in self._iterate_to_target_viawpoly(mapper):
+ if (
+ m.polymorphic_load == "selectin"
+ or m in enabled_via_opt_mappers
+ ):
+ return enabled_via_opt_mappers.get(m, m)
+
+ return None
+
+ @util.preload_module("sqlalchemy.orm.strategy_options")
+ def _subclass_load_via_in(self, entity):
+ """Assemble a that can load the columns local to
+ this subclass as a SELECT with IN.
+
+ """
+ strategy_options = util.preloaded.orm_strategy_options
+
+ assert self.inherits
+
+ if self.polymorphic_on is not None:
+ polymorphic_prop = self._columntoproperty[self.polymorphic_on]
+ keep_props = set([polymorphic_prop] + self._identity_key_props)
+ else:
+ keep_props = set(self._identity_key_props)
+
+ disable_opt = strategy_options.Load(entity)
+ enable_opt = strategy_options.Load(entity)
+
+ for prop in self.attrs:
+ if prop.parent is self or prop in keep_props:
+ # "enable" options, to turn on the properties that we want to
+ # load by default (subject to options from the query)
+ if not isinstance(prop, StrategizedProperty):
+ continue
+
+ enable_opt.set_generic_strategy(
+ # convert string name to an attribute before passing
+ # to loader strategy
+ (getattr(entity.entity_namespace, prop.key),),
+ dict(prop.strategy_key),
+ )
+ else:
+ # "disable" options, to turn off the properties from the
+ # superclass that we *don't* want to load, applied after
+ # the options from the query to override them
+ disable_opt.set_generic_strategy(
+ # convert string name to an attribute before passing
+ # to loader strategy
+ (getattr(entity.entity_namespace, prop.key),),
+ {"do_nothing": True},
+ )
+
+ primary_key = [
+ sql_util._deep_annotate(pk, {"_orm_adapt": True})
+ for pk in self.primary_key
+ ]
+
+ if len(primary_key) > 1:
+ in_expr = sql.tuple_(*primary_key)
+ else:
+ in_expr = primary_key[0]
+
+ if entity.is_aliased_class:
+ assert entity.mapper is self
+
+ q = sql.select(entity).set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ )
+
+ in_expr = entity._adapter.traverse(in_expr)
+ primary_key = [entity._adapter.traverse(k) for k in primary_key]
+ q = q.where(
+ in_expr.in_(sql.bindparam("primary_keys", expanding=True))
+ ).order_by(*primary_key)
+ else:
+
+ q = sql.select(self).set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ )
+ q = q.where(
+ in_expr.in_(sql.bindparam("primary_keys", expanding=True))
+ ).order_by(*primary_key)
+
+ return q, enable_opt, disable_opt
+
+ @HasMemoized.memoized_attribute
+ def _subclass_load_via_in_mapper(self):
+ return self._subclass_load_via_in(self)
+
+ def cascade_iterator(self, type_, state, halt_on=None):
+ r"""Iterate each element and its mapper in an object graph,
+ for all relationships that meet the given cascade rule.
+
+ :param type\_:
+ The name of the cascade rule (i.e. ``"save-update"``, ``"delete"``,
+ etc.).
+
+ .. note:: the ``"all"`` cascade is not accepted here. For a generic
+ object traversal function, see :ref:`faq_walk_objects`.
+
+ :param state:
+ The lead InstanceState. Child items will be processed per
+ the relationships defined for this object's mapper.
+
+ :return: the method yields individual object instances.
+
+ .. seealso::
+
+ :ref:`unitofwork_cascades`
+
+ :ref:`faq_walk_objects` - illustrates a generic function to
+ traverse all objects without relying on cascades.
+
+ """
+ visited_states = set()
+ prp, mpp = object(), object()
+
+ assert state.mapper.isa(self)
+
+ visitables = deque(
+ [(deque(state.mapper._props.values()), prp, state, state.dict)]
+ )
+
+ while visitables:
+ iterator, item_type, parent_state, parent_dict = visitables[-1]
+ if not iterator:
+ visitables.pop()
+ continue
+
+ if item_type is prp:
+ prop = iterator.popleft()
+ if type_ not in prop.cascade:
+ continue
+ queue = deque(
+ prop.cascade_iterator(
+ type_,
+ parent_state,
+ parent_dict,
+ visited_states,
+ halt_on,
+ )
+ )
+ if queue:
+ visitables.append((queue, mpp, None, None))
+ elif item_type is mpp:
+ (
+ instance,
+ instance_mapper,
+ corresponding_state,
+ corresponding_dict,
+ ) = iterator.popleft()
+ yield (
+ instance,
+ instance_mapper,
+ corresponding_state,
+ corresponding_dict,
+ )
+ visitables.append(
+ (
+ deque(instance_mapper._props.values()),
+ prp,
+ corresponding_state,
+ corresponding_dict,
+ )
+ )
+
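+ # [illustrative usage sketch -- not part of this module] A small example of
+ # ``cascade_iterator()``, assuming a hypothetical ``User``/``Address`` mapping
+ # where ``User.addresses`` uses the default "save-update" cascade; the related
+ # ``Address`` objects are reached through the object graph without any
+ # database access:
+ #
+ #     from sqlalchemy import Column, ForeignKey, Integer, inspect
+ #     from sqlalchemy.orm import declarative_base, relationship
+ #
+ #     Base = declarative_base()
+ #
+ #     class User(Base):
+ #         __tablename__ = "user"
+ #         id = Column(Integer, primary_key=True)
+ #         addresses = relationship("Address")
+ #
+ #     class Address(Base):
+ #         __tablename__ = "address"
+ #         id = Column(Integer, primary_key=True)
+ #         user_id = Column(ForeignKey("user.id"))
+ #
+ #     user = User(addresses=[Address(), Address()])
+ #     state = inspect(user)                 # InstanceState for this instance
+ #
+ #     user_mapper = inspect(User)
+ #     reached = [
+ #         obj
+ #         for obj, obj_mapper, obj_state, obj_dict
+ #         in user_mapper.cascade_iterator("save-update", state)
+ #     ]
+ #     assert set(reached) == set(user.addresses)
+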
+ @HasMemoized.memoized_attribute
+ def _compiled_cache(self):
+ return util.LRUCache(self._compiled_cache_size)
+
+ @HasMemoized.memoized_attribute
+ def _sorted_tables(self):
+ table_to_mapper = {}
+
+ for mapper in self.base_mapper.self_and_descendants:
+ for t in mapper.tables:
+ table_to_mapper.setdefault(t, mapper)
+
+ extra_dependencies = []
+ for table, mapper in table_to_mapper.items():
+ super_ = mapper.inherits
+ if super_:
+ extra_dependencies.extend(
+ [(super_table, table) for super_table in super_.tables]
+ )
+
+ def skip(fk):
+ # attempt to skip dependencies that are not
+ # significant to the inheritance chain
+ # for two tables that are related by inheritance.
+ # while that dependency may be important, it's technically
+ # not what we mean to sort on here.
+ parent = table_to_mapper.get(fk.parent.table)
+ dep = table_to_mapper.get(fk.column.table)
+ if (
+ parent is not None
+ and dep is not None
+ and dep is not parent
+ and dep.inherit_condition is not None
+ ):
+ cols = set(sql_util._find_columns(dep.inherit_condition))
+ if parent.inherit_condition is not None:
+ cols = cols.union(
+ sql_util._find_columns(parent.inherit_condition)
+ )
+ return fk.parent not in cols and fk.column not in cols
+ else:
+ return fk.parent not in cols
+ return False
+
+ sorted_ = sql_util.sort_tables(
+ table_to_mapper,
+ skip_fn=skip,
+ extra_dependencies=extra_dependencies,
+ )
+
+ ret = util.OrderedDict()
+ for t in sorted_:
+ ret[t] = table_to_mapper[t]
+ return ret
+
+ def _memo(self, key, callable_):
+ if key in self._memoized_values:
+ return self._memoized_values[key]
+ else:
+ self._memoized_values[key] = value = callable_()
+ return value
+
+ @util.memoized_property
+ def _table_to_equated(self):
+ """memoized map of tables to collections of columns to be
+ synchronized upwards to the base mapper."""
+
+ result = util.defaultdict(list)
+
+ for table in self._sorted_tables:
+ cols = set(table.c)
+ for m in self.iterate_to_root():
+ if m._inherits_equated_pairs and cols.intersection(
+ util.reduce(
+ set.union,
+ [l.proxy_set for l, r in m._inherits_equated_pairs],
+ )
+ ):
+ result[table].append((m, m._inherits_equated_pairs))
+
+ return result
+
+
+class _OptGetColumnsNotAvailable(Exception):
+ pass
+
+
+def configure_mappers():
+ """Initialize the inter-mapper relationships of all mappers that
+ have been constructed thus far across all :class:`_orm.registry`
+ collections.
+
+ The configure step is used to reconcile and initialize the
+ :func:`_orm.relationship` linkages between mapped classes, as well as to
+ invoke configuration events such as the
+ :meth:`_orm.MapperEvents.before_configured` and
+ :meth:`_orm.MapperEvents.after_configured`, which may be used by ORM
+ extensions or user-defined extension hooks.
+
+ Mapper configuration is normally invoked automatically, the first time
+ mappings from a particular :class:`_orm.registry` are used, as well as
+ whenever mappings are used and additional not-yet-configured mappers have
+ been constructed. The automatic configuration process, however, is local only
+ to the :class:`_orm.registry` involving the target mapper and any related
+ :class:`_orm.registry` objects which it may depend on; this is
+ equivalent to invoking the :meth:`_orm.registry.configure` method
+ on a particular :class:`_orm.registry`.
+
+ By contrast, the :func:`_orm.configure_mappers` function will invoke the
+ configuration process on all :class:`_orm.registry` objects that
+ exist in memory, and may be useful for scenarios where many individual
+ :class:`_orm.registry` objects that are nonetheless interrelated are
+ in use.
+
+ .. versionchanged:: 1.4
+
+ As of SQLAlchemy 1.4.0b2, this function works on a
+ per-:class:`_orm.registry` basis, locating all :class:`_orm.registry`
+ objects present and invoking the :meth:`_orm.registry.configure` method
+ on each. The :meth:`_orm.registry.configure` method may be preferred to
+ limit the configuration of mappers to those local to a particular
+ :class:`_orm.registry` and/or declarative base class.
+
+ Points at which automatic configuration is invoked include when a mapped
+ class is instantiated into an instance, as well as when ORM queries
+ are emitted using :meth:`.Session.query` or :meth:`_orm.Session.execute`
+ with an ORM-enabled statement.
+
+ The mapper configure process, whether invoked by
+ :func:`_orm.configure_mappers` or from :meth:`_orm.registry.configure`,
+ provides several event hooks that can be used to augment the mapper
+ configuration step. These hooks include:
+
+ * :meth:`.MapperEvents.before_configured` - called once before
+ :func:`.configure_mappers` or :meth:`_orm.registry.configure` does any
+ work; this can be used to establish additional options, properties, or
+ related mappings before the operation proceeds.
+
+ * :meth:`.MapperEvents.mapper_configured` - called as each individual
+ :class:`_orm.Mapper` is configured within the process; will include all
+ mapper state except for backrefs set up by other mappers that are still
+ to be configured.
+
+ * :meth:`.MapperEvents.after_configured` - called once after
+ :func:`.configure_mappers` or :meth:`_orm.registry.configure` is
+ complete; at this stage, all :class:`_orm.Mapper` objects that fall
+ within the scope of the configuration operation will be fully configured.
+ Note that the calling application may still have other mappings that
+ haven't been produced yet, such as if they are in modules as yet
+ unimported, and may also have mappings that are still to be configured,
+ if they are in other :class:`_orm.registry` collections not part of the
+ current scope of configuration.
+
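+ E.g., a minimal sketch of invoking configuration explicitly (the
+ registry, table name and class shown here are illustrative only)::
+
+     from sqlalchemy import Column, Integer
+     from sqlalchemy.orm import configure_mappers, registry
+
+     mapper_registry = registry()
+
+     @mapper_registry.mapped
+     class User:
+         __tablename__ = "user"
+
+         id = Column(Integer, primary_key=True)
+
+     # configure every registry currently in memory ...
+     configure_mappers()
+
+     # ... or only this registry and its dependencies
+     mapper_registry.configure()
+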
+ """
+
+ _configure_registries(_all_registries(), cascade=True)
+
+
+def _configure_registries(registries, cascade):
+ for reg in registries:
+ if reg._new_mappers:
+ break
+ else:
+ return
+
+ with _CONFIGURE_MUTEX:
+ global _already_compiling
+ if _already_compiling:
+ return
+ _already_compiling = True
+ try:
+
+ # double-check inside mutex
+ for reg in registries:
+ if reg._new_mappers:
+ break
+ else:
+ return
+
+ Mapper.dispatch._for_class(Mapper).before_configured()
+ # initialize properties on all mappers
+ # note that _mapper_registry is unordered, which
+ # may randomly conceal/reveal issues related to
+ # the order of mapper compilation
+
+ _do_configure_registries(registries, cascade)
+ finally:
+ _already_compiling = False
+ Mapper.dispatch._for_class(Mapper).after_configured()
+
+
+@util.preload_module("sqlalchemy.orm.decl_api")
+def _do_configure_registries(registries, cascade):
+
+ registry = util.preloaded.orm_decl_api.registry
+
+ orig = set(registries)
+
+ for reg in registry._recurse_with_dependencies(registries):
+ has_skip = False
+
+ for mapper in reg._mappers_to_configure():
+ run_configure = None
+ for fn in mapper.dispatch.before_mapper_configured:
+ run_configure = fn(mapper, mapper.class_)
+ if run_configure is EXT_SKIP:
+ has_skip = True
+ break
+ if run_configure is EXT_SKIP:
+ continue
+
+ if getattr(mapper, "_configure_failed", False):
+ e = sa_exc.InvalidRequestError(
+ "One or more mappers failed to initialize - "
+ "can't proceed with initialization of other "
+ "mappers. Triggering mapper: '%s'. "
+ "Original exception was: %s"
+ % (mapper, mapper._configure_failed)
+ )
+ e._configure_failed = mapper._configure_failed
+ raise e
+
+ if not mapper.configured:
+ try:
+ mapper._post_configure_properties()
+ mapper._expire_memoizations()
+ mapper.dispatch.mapper_configured(mapper, mapper.class_)
+ except Exception:
+ exc = sys.exc_info()[1]
+ if not hasattr(exc, "_configure_failed"):
+ mapper._configure_failed = exc
+ raise
+ if not has_skip:
+ reg._new_mappers = False
+
+ if not cascade and reg._dependencies.difference(orig):
+ raise sa_exc.InvalidRequestError(
+ "configure was called with cascade=False but "
+ "additional registries remain"
+ )
+
+
+@util.preload_module("sqlalchemy.orm.decl_api")
+def _dispose_registries(registries, cascade):
+
+ registry = util.preloaded.orm_decl_api.registry
+
+ orig = set(registries)
+
+ for reg in registry._recurse_with_dependents(registries):
+ if not cascade and reg._dependents.difference(orig):
+ raise sa_exc.InvalidRequestError(
+ "Registry has dependent registries that are not disposed; "
+ "pass cascade=True to clear these also"
+ )
+
+ while reg._managers:
+ try:
+ manager, _ = reg._managers.popitem()
+ except KeyError:
+ # guard against race between while and popitem
+ pass
+ else:
+ reg._dispose_manager_and_mapper(manager)
+
+ reg._non_primary_mappers.clear()
+ reg._dependents.clear()
+ for dep in reg._dependencies:
+ dep._dependents.discard(reg)
+ reg._dependencies.clear()
+ # this wasn't done in the 1.3 clear_mappers() and in fact it
+ # was a bug, as it could cause configure_mappers() to invoke
+ # the "before_configured" event even though mappers had all been
+ # disposed.
+ reg._new_mappers = False
+
+
+def reconstructor(fn):
+ """Decorate a method as the 'reconstructor' hook.
+
+ Designates a single method as the "reconstructor", an ``__init__``-like
+ method that will be called by the ORM after the instance has been
+ loaded from the database or otherwise reconstituted.
+
+ The reconstructor will be invoked with no arguments. Scalar
+ (non-collection) database-mapped attributes of the instance will
+ be available for use within the function. Eagerly-loaded
+ collections are generally not yet available and will usually only
+ contain the first element. ORM state changes made to objects at
+ this stage will not be recorded for the next flush() operation, so
+ the activity within a reconstructor should be conservative.
+
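+ E.g., a minimal sketch, assuming the usual declarative ``Base`` and
+ column imports are in place (the ``User`` class shown is illustrative
+ only)::
+
+     from sqlalchemy.orm import reconstructor
+
+     class User(Base):
+         __tablename__ = "user"
+
+         id = Column(Integer, primary_key=True)
+         name = Column(String)
+
+         @reconstructor
+         def init_on_load(self):
+             # invoked after a row is loaded; scalar attributes such
+             # as ``name`` are already populated at this point
+             self.display_name = (self.name or "").title()
+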
+ .. seealso::
+
+ :ref:`mapping_constructors`
+
+ :meth:`.InstanceEvents.load`
+
+ """
+ fn.__sa_reconstructor__ = True
+ return fn
+
+
+def validates(*names, **kw):
+ r"""Decorate a method as a 'validator' for one or more named properties.
+
+ Designates a method as a validator, a method which receives the
+ name of the attribute as well as a value to be assigned, or in the
+ case of a collection, the value to be added to the collection.
+ The function can then raise validation exceptions to halt the
+ process from continuing (where Python's built-in ``ValueError``
+ and ``AssertionError`` exceptions are reasonable choices), or can
+ modify or replace the value before proceeding. The function should
+ otherwise return the given value.
+
+ Note that a validator for a collection **cannot** issue a load of that
+ collection within the validation routine - this usage raises
+ an assertion to avoid recursion overflows. This is a reentrant
+ condition which is not supported.
+
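+ E.g., a minimal sketch, assuming the usual declarative ``Base`` and
+ column imports are in place (the ``EmailAddress`` class shown is
+ illustrative only)::
+
+     from sqlalchemy.orm import validates
+
+     class EmailAddress(Base):
+         __tablename__ = "address"
+
+         id = Column(Integer, primary_key=True)
+         email = Column(String)
+
+         @validates("email")
+         def validate_email(self, key, value):
+             if "@" not in value:
+                 raise ValueError("failed simple email validation")
+             return value
+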
+ :param \*names: list of attribute names to be validated.
+ :param include_removes: if True, "remove" events will be
+ sent as well - the validation function must accept an additional
+ argument "is_remove" which will be a boolean.
+
+ :param include_backrefs: defaults to ``True``; if ``False``, the
+ validation function will not emit if the originator is an attribute
+ event related via a backref. This can be used for bi-directional
+ :func:`.validates` usage where only one validator should emit per
+ attribute operation.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :ref:`simple_validators` - usage examples for :func:`.validates`
+
+ """
+ include_removes = kw.pop("include_removes", False)
+ include_backrefs = kw.pop("include_backrefs", True)
+
+ def wrap(fn):
+ fn.__sa_validators__ = names
+ fn.__sa_validation_opts__ = {
+ "include_removes": include_removes,
+ "include_backrefs": include_backrefs,
+ }
+ return fn
+
+ return wrap
+
+
+def _event_on_load(state, ctx):
+ instrumenting_mapper = state.manager.mapper
+
+ if instrumenting_mapper._reconstructor:
+ instrumenting_mapper._reconstructor(state.obj())
+
+
+def _event_on_init(state, args, kwargs):
+ """Run init_instance hooks.
+
+ This also includes mapper compilation, normally not needed
+ here but helps with some piecemeal configuration
+ scenarios (such as in the ORM tutorial).
+
+ """
+
+ instrumenting_mapper = state.manager.mapper
+ if instrumenting_mapper:
+ instrumenting_mapper._check_configure()
+ if instrumenting_mapper._set_polymorphic_identity:
+ instrumenting_mapper._set_polymorphic_identity(state)
+
+
+class _ColumnMapping(dict):
+ """Error reporting helper for mapper._columntoproperty."""
+
+ __slots__ = ("mapper",)
+
+ def __init__(self, mapper):
+ # TODO: weakref would be a good idea here
+ self.mapper = mapper
+
+ def __missing__(self, column):
+ prop = self.mapper._props.get(column)
+ if prop:
+ raise orm_exc.UnmappedColumnError(
+ "Column '%s.%s' is not available, due to "
+ "conflicting property '%s':%r"
+ % (column.table.name, column.name, column.key, prop)
+ )
+ raise orm_exc.UnmappedColumnError(
+ "No column %s is configured on mapper %s..."
+ % (column, self.mapper)
+ )
diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py
new file mode 100644
index 0000000..331ddd7
--- /dev/null
+++ b/lib/sqlalchemy/orm/path_registry.py
@@ -0,0 +1,519 @@
+# orm/path_registry.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Path tracking utilities, representing mapper graph traversals.
+
+"""
+
+from itertools import chain
+import logging
+
+from . import base as orm_base
+from .. import exc
+from .. import inspection
+from .. import util
+from ..sql import visitors
+from ..sql.traversals import HasCacheKey
+
+log = logging.getLogger(__name__)
+
+
+def _unreduce_path(path):
+ return PathRegistry.deserialize(path)
+
+
+_WILDCARD_TOKEN = "*"
+_DEFAULT_TOKEN = "_sa_default"
+
+
+class PathRegistry(HasCacheKey):
+ """Represent query load paths and registry functions.
+
+ Basically represents structures like:
+
+ (<User mapper>, "orders", <Order mapper>, "items", <Item mapper>)
+
+ These structures are generated by things like
+ query options (joinedload(), subqueryload(), etc.) and are
+ used to compose keys stored in the query._attributes dictionary
+ for various options.
+
+ They are then re-composed at query compile/result row time as
+ the query is formed and as rows are fetched, where they again
+ serve to compose keys to look up options in the context.attributes
+ dictionary, which is copied from query._attributes.
+
+ The path structure has a limited amount of caching, where each
+ "root" ultimately pulls from a fixed registry associated with
+ the first mapper, that also contains elements for each of its
+ property keys. However, paths longer than two elements, which
+ are the exception rather than the rule, are generated on an
+ as-needed basis.
+
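+ E.g., a rough sketch of how a two-element path for the ``User.orders``
+ relationship would be built up via indexing (the ``User`` class is
+ illustrative only)::
+
+     from sqlalchemy import inspect
+
+     user_mapper = inspect(User)
+     path = PathRegistry.root[user_mapper][
+         user_mapper.relationships["orders"]
+     ]
+
+     # path.path is (user_mapper, <User.orders RelationshipProperty>)
+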
+ """
+
+ __slots__ = ()
+
+ is_token = False
+ is_root = False
+
+ _cache_key_traversal = [
+ ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list)
+ ]
+
+ def __eq__(self, other):
+ try:
+ return other is not None and self.path == other._path_for_compare
+ except AttributeError:
+ util.warn(
+ "Comparison of PathRegistry to %r is not supported"
+ % (type(other))
+ )
+ return False
+
+ def __ne__(self, other):
+ try:
+ return other is None or self.path != other._path_for_compare
+ except AttributeError:
+ util.warn(
+ "Comparison of PathRegistry to %r is not supported"
+ % (type(other))
+ )
+ return True
+
+ @property
+ def _path_for_compare(self):
+ return self.path
+
+ def set(self, attributes, key, value):
+ log.debug("set '%s' on path '%s' to '%s'", key, self, value)
+ attributes[(key, self.natural_path)] = value
+
+ def setdefault(self, attributes, key, value):
+ log.debug("setdefault '%s' on path '%s' to '%s'", key, self, value)
+ attributes.setdefault((key, self.natural_path), value)
+
+ def get(self, attributes, key, value=None):
+ key = (key, self.natural_path)
+ if key in attributes:
+ return attributes[key]
+ else:
+ return value
+
+ def __len__(self):
+ return len(self.path)
+
+ def __hash__(self):
+ return id(self)
+
+ @property
+ def length(self):
+ return len(self.path)
+
+ def pairs(self):
+ path = self.path
+ for i in range(0, len(path), 2):
+ yield path[i], path[i + 1]
+
+ def contains_mapper(self, mapper):
+ for path_mapper in [self.path[i] for i in range(0, len(self.path), 2)]:
+ if path_mapper.is_mapper and path_mapper.isa(mapper):
+ return True
+ else:
+ return False
+
+ def contains(self, attributes, key):
+ return (key, self.path) in attributes
+
+ def __reduce__(self):
+ return _unreduce_path, (self.serialize(),)
+
+ @classmethod
+ def _serialize_path(cls, path):
+ return list(
+ zip(
+ [
+ m.class_ if (m.is_mapper or m.is_aliased_class) else str(m)
+ for m in [path[i] for i in range(0, len(path), 2)]
+ ],
+ [
+ path[i].key if (path[i].is_property) else str(path[i])
+ for i in range(1, len(path), 2)
+ ]
+ + [None],
+ )
+ )
+
+ @classmethod
+ def _deserialize_path(cls, path):
+ def _deserialize_mapper_token(mcls):
+ return (
+ # note: we likely don't want configure=True here; however,
+ # this is maintained at the moment for backwards compatibility
+ orm_base._inspect_mapped_class(mcls, configure=True)
+ if mcls not in PathToken._intern
+ else PathToken._intern[mcls]
+ )
+
+ def _deserialize_key_token(mcls, key):
+ if key is None:
+ return None
+ elif key in PathToken._intern:
+ return PathToken._intern[key]
+ else:
+ return orm_base._inspect_mapped_class(
+ mcls, configure=True
+ ).attrs[key]
+
+ p = tuple(
+ chain(
+ *[
+ (
+ _deserialize_mapper_token(mcls),
+ _deserialize_key_token(mcls, key),
+ )
+ for mcls, key in path
+ ]
+ )
+ )
+ if p and p[-1] is None:
+ p = p[0:-1]
+ return p
+
+ @classmethod
+ def serialize_context_dict(cls, dict_, tokens):
+ return [
+ ((key, cls._serialize_path(path)), value)
+ for (key, path), value in [
+ (k, v)
+ for k, v in dict_.items()
+ if isinstance(k, tuple) and k[0] in tokens
+ ]
+ ]
+
+ @classmethod
+ def deserialize_context_dict(cls, serialized):
+ return util.OrderedDict(
+ ((key, tuple(cls._deserialize_path(path))), value)
+ for (key, path), value in serialized
+ )
+
+ def serialize(self):
+ path = self.path
+ return self._serialize_path(path)
+
+ @classmethod
+ def deserialize(cls, path):
+ if path is None:
+ return None
+ p = cls._deserialize_path(path)
+ return cls.coerce(p)
+
+ @classmethod
+ def per_mapper(cls, mapper):
+ if mapper.is_mapper:
+ return CachingEntityRegistry(cls.root, mapper)
+ else:
+ return SlotsEntityRegistry(cls.root, mapper)
+
+ @classmethod
+ def coerce(cls, raw):
+ return util.reduce(lambda prev, next: prev[next], raw, cls.root)
+
+ def token(self, token):
+ if token.endswith(":" + _WILDCARD_TOKEN):
+ return TokenRegistry(self, token)
+ elif token.endswith(":" + _DEFAULT_TOKEN):
+ return TokenRegistry(self.root, token)
+ else:
+ raise exc.ArgumentError("invalid token: %s" % token)
+
+ def __add__(self, other):
+ return util.reduce(lambda prev, next: prev[next], other.path, self)
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, self.path)
+
+
+class RootRegistry(PathRegistry):
+ """Root registry, defers to mappers so that
+ paths are maintained per-root-mapper.
+
+ """
+
+ inherit_cache = True
+
+ path = natural_path = ()
+ has_entity = False
+ is_aliased_class = False
+ is_root = True
+
+ def __getitem__(self, entity):
+ if entity in PathToken._intern:
+ return PathToken._intern[entity]
+ else:
+ return entity._path_registry
+
+
+PathRegistry.root = RootRegistry()
+
+
+class PathToken(orm_base.InspectionAttr, HasCacheKey, str):
+ """cacheable string token"""
+
+ _intern = {}
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ return (str(self),)
+
+ @property
+ def _path_for_compare(self):
+ return None
+
+ @classmethod
+ def intern(cls, strvalue):
+ if strvalue in cls._intern:
+ return cls._intern[strvalue]
+ else:
+ cls._intern[strvalue] = result = PathToken(strvalue)
+ return result
+
+
+class TokenRegistry(PathRegistry):
+ __slots__ = ("token", "parent", "path", "natural_path")
+
+ inherit_cache = True
+
+ def __init__(self, parent, token):
+ token = PathToken.intern(token)
+
+ self.token = token
+ self.parent = parent
+ self.path = parent.path + (token,)
+ self.natural_path = parent.natural_path + (token,)
+
+ has_entity = False
+
+ is_token = True
+
+ def generate_for_superclasses(self):
+ if not self.parent.is_aliased_class and not self.parent.is_root:
+ for ent in self.parent.mapper.iterate_to_root():
+ yield TokenRegistry(self.parent.parent[ent], self.token)
+ elif (
+ self.parent.is_aliased_class
+ and self.parent.entity._is_with_polymorphic
+ ):
+ yield self
+ for ent in self.parent.entity._with_polymorphic_entities:
+ yield TokenRegistry(self.parent.parent[ent], self.token)
+ else:
+ yield self
+
+ def __getitem__(self, entity):
+ raise NotImplementedError()
+
+
+class PropRegistry(PathRegistry):
+ is_unnatural = False
+ inherit_cache = True
+
+ def __init__(self, parent, prop):
+ # restate this path in terms of the
+ # given MapperProperty's parent.
+ insp = inspection.inspect(parent[-1])
+ natural_parent = parent
+
+ if not insp.is_aliased_class or insp._use_mapper_path:
+ parent = natural_parent = parent.parent[prop.parent]
+ elif (
+ insp.is_aliased_class
+ and insp.with_polymorphic_mappers
+ and prop.parent in insp.with_polymorphic_mappers
+ ):
+ subclass_entity = parent[-1]._entity_for_mapper(prop.parent)
+ parent = parent.parent[subclass_entity]
+
+ # when building a path where with_polymorphic() is in use,
+ # special logic to determine the "natural path" when subclass
+ # entities are used.
+ #
+ # here we are trying to distinguish between a path that starts
+ # on the with_polymorphic entity vs. one that starts on a
+ # normal entity that introduces a with_polymorphic() in the
+ # middle using of_type():
+ #
+ # # as in test_polymorphic_rel->
+ # # test_subqueryload_on_subclass_uses_path_correctly
+ # wp = with_polymorphic(RegularEntity, "*")
+ # sess.query(wp).options(someload(wp.SomeSubEntity.foos))
+ #
+ # vs
+ #
+ # # as in test_relationship->JoinedloadWPolyOfTypeContinued
+ # wp = with_polymorphic(SomeFoo, "*")
+ # sess.query(RegularEntity).options(
+ # someload(RegularEntity.foos.of_type(wp))
+ # .someload(wp.SubFoo.bar)
+ # )
+ #
+ # in the former case, the path that the Query generates, which we
+ # want to match, will be in terms of the with_polymorphic at the
+ # beginning. in the latter case, Query will generate simple
+ # paths that don't know about this with_polymorphic, so we must
+ # use a separate natural path.
+ #
+ #
+ if parent.parent:
+ natural_parent = parent.parent[subclass_entity.mapper]
+ self.is_unnatural = True
+ else:
+ natural_parent = parent
+ elif (
+ natural_parent.parent
+ and insp.is_aliased_class
+ and prop.parent # this should always be the case here
+ is not insp.mapper
+ and insp.mapper.isa(prop.parent)
+ ):
+ natural_parent = parent.parent[prop.parent]
+
+ self.prop = prop
+ self.parent = parent
+ self.path = parent.path + (prop,)
+ self.natural_path = natural_parent.natural_path + (prop,)
+
+ self._wildcard_path_loader_key = (
+ "loader",
+ parent.path + self.prop._wildcard_token,
+ )
+ self._default_path_loader_key = self.prop._default_path_loader_key
+ self._loader_key = ("loader", self.natural_path)
+
+ def __str__(self):
+ return " -> ".join(str(elem) for elem in self.path)
+
+ @util.memoized_property
+ def has_entity(self):
+ return self.prop._links_to_entity
+
+ @util.memoized_property
+ def entity(self):
+ return self.prop.entity
+
+ @property
+ def mapper(self):
+ return self.prop.mapper
+
+ @property
+ def entity_path(self):
+ return self[self.entity]
+
+ def __getitem__(self, entity):
+ if isinstance(entity, (int, slice)):
+ return self.path[entity]
+ else:
+ return SlotsEntityRegistry(self, entity)
+
+
+class AbstractEntityRegistry(PathRegistry):
+ __slots__ = ()
+
+ has_entity = True
+
+ def __init__(self, parent, entity):
+ self.key = entity
+ self.parent = parent
+ self.is_aliased_class = entity.is_aliased_class
+ self.entity = entity
+ self.path = parent.path + (entity,)
+
+ # the "natural path" is the path that we get when Query is traversing
+ # from the lead entities into the various relationships; it corresponds
+ # to the structure of mappers and relationships. when we are given a
+ # path that comes from loader options; as of 1.3 it can have ad-hoc
+ # with_polymorphic() and other AliasedInsp objects inside of it, which
+ # are usually not present in mappings. So here we track both the
+ # "enhanced" path in self.path and the "natural" path that doesn't
+ # include those objects so these two traversals can be matched up.
+
+ # the test here for "(self.is_aliased_class or parent.is_unnatural)"
+ # is to avoid the more expensive conditional logic that follows if we
+ # know we don't have to do it. This conditional could just as well be
+ # "if parent.path:"; it just incurs more function calls.
+ if parent.path and (self.is_aliased_class or parent.is_unnatural):
+ # this is an infrequent code path used only for loader strategies
+ # that also make use of of_type().
+ if entity.mapper.isa(parent.natural_path[-1].entity):
+ self.natural_path = parent.natural_path + (entity.mapper,)
+ else:
+ self.natural_path = parent.natural_path + (
+ parent.natural_path[-1].entity,
+ )
+ # it seems to make sense that since these paths get mixed up
+ # with statements that are cached or not, we should make
+ # sure the natural path is cacheable across different occurrences
+ # of equivalent AliasedClass objects. however, so far this
+ # does not seem to be needed for whatever reason.
+ # elif not parent.path and self.is_aliased_class:
+ # self.natural_path = (self.entity._generate_cache_key()[0], )
+ else:
+ # self.natural_path = parent.natural_path + (entity, )
+ self.natural_path = self.path
+
+ @property
+ def entity_path(self):
+ return self
+
+ @property
+ def mapper(self):
+ return inspection.inspect(self.entity).mapper
+
+ def __bool__(self):
+ return True
+
+ __nonzero__ = __bool__
+
+ def __getitem__(self, entity):
+ if isinstance(entity, (int, slice)):
+ return self.path[entity]
+ elif entity in PathToken._intern:
+ return TokenRegistry(self, PathToken._intern[entity])
+ else:
+ return PropRegistry(self, entity)
+
+
+class SlotsEntityRegistry(AbstractEntityRegistry):
+ # for aliased class, return lightweight, no-cycles created
+ # version
+ inherit_cache = True
+
+ __slots__ = (
+ "key",
+ "parent",
+ "is_aliased_class",
+ "entity",
+ "path",
+ "natural_path",
+ )
+
+
+class CachingEntityRegistry(AbstractEntityRegistry, dict):
+ # for long lived mapper, return dict based caching
+ # version that creates reference cycles
+
+ inherit_cache = True
+
+ def __getitem__(self, entity):
+ if isinstance(entity, (int, slice)):
+ return self.path[entity]
+ else:
+ return dict.__getitem__(self, entity)
+
+ def __missing__(self, key):
+ self[key] = item = PropRegistry(self, key)
+
+ return item
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
new file mode 100644
index 0000000..a17b24a
--- /dev/null
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -0,0 +1,2517 @@
+# orm/persistence.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""private module containing functions used to emit INSERT, UPDATE
+and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending
+mappers.
+
+The functions here are called only by the unit of work functions
+in unitofwork.py.
+
+"""
+
+from itertools import chain
+from itertools import groupby
+import operator
+
+from . import attributes
+from . import evaluator
+from . import exc as orm_exc
+from . import loading
+from . import sync
+from .base import NO_VALUE
+from .base import state_str
+from .. import exc as sa_exc
+from .. import future
+from .. import sql
+from .. import util
+from ..engine import result as _result
+from ..sql import coercions
+from ..sql import expression
+from ..sql import operators
+from ..sql import roles
+from ..sql import select
+from ..sql import sqltypes
+from ..sql.base import _entity_namespace_key
+from ..sql.base import CompileState
+from ..sql.base import Options
+from ..sql.dml import DeleteDMLState
+from ..sql.dml import InsertDMLState
+from ..sql.dml import UpdateDMLState
+from ..sql.elements import BooleanClauseList
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+
+
+def _bulk_insert(
+ mapper,
+ mappings,
+ session_transaction,
+ isstates,
+ return_defaults,
+ render_nulls,
+):
+ base_mapper = mapper.base_mapper
+
+ if session_transaction.session.connection_callable:
+ raise NotImplementedError(
+ "connection_callable / per-instance sharding "
+ "not supported in bulk_insert()"
+ )
+
+ if isstates:
+ if return_defaults:
+ states = [(state, state.dict) for state in mappings]
+ mappings = [dict_ for (state, dict_) in states]
+ else:
+ mappings = [state.dict for state in mappings]
+ else:
+ mappings = list(mappings)
+
+ connection = session_transaction.connection(base_mapper)
+ for table, super_mapper in base_mapper._sorted_tables.items():
+ if not mapper.isa(super_mapper):
+ continue
+
+ records = (
+ (
+ None,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ )
+ for (
+ state,
+ state_dict,
+ params,
+ mp,
+ conn,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ) in _collect_insert_commands(
+ table,
+ ((None, mapping, mapper, connection) for mapping in mappings),
+ bulk=True,
+ return_defaults=return_defaults,
+ render_nulls=render_nulls,
+ )
+ )
+ _emit_insert_statements(
+ base_mapper,
+ None,
+ super_mapper,
+ table,
+ records,
+ bookkeeping=return_defaults,
+ )
+
+ if return_defaults and isstates:
+ identity_cls = mapper._identity_class
+ identity_props = [p.key for p in mapper._identity_key_props]
+ for state, dict_ in states:
+ state.key = (
+ identity_cls,
+ tuple([dict_[key] for key in identity_props]),
+ )
+
+
+def _bulk_update(
+ mapper, mappings, session_transaction, isstates, update_changed_only
+):
+ base_mapper = mapper.base_mapper
+
+ search_keys = mapper._primary_key_propkeys
+ if mapper._version_id_prop:
+ search_keys = {mapper._version_id_prop.key}.union(search_keys)
+
+ def _changed_dict(mapper, state):
+ return dict(
+ (k, v)
+ for k, v in state.dict.items()
+ if k in state.committed_state or k in search_keys
+ )
+
+ if isstates:
+ if update_changed_only:
+ mappings = [_changed_dict(mapper, state) for state in mappings]
+ else:
+ mappings = [state.dict for state in mappings]
+ else:
+ mappings = list(mappings)
+
+ if session_transaction.session.connection_callable:
+ raise NotImplementedError(
+ "connection_callable / per-instance sharding "
+ "not supported in bulk_update()"
+ )
+
+ connection = session_transaction.connection(base_mapper)
+
+ for table, super_mapper in base_mapper._sorted_tables.items():
+ if not mapper.isa(super_mapper):
+ continue
+
+ records = _collect_update_commands(
+ None,
+ table,
+ (
+ (
+ None,
+ mapping,
+ mapper,
+ connection,
+ (
+ mapping[mapper._version_id_prop.key]
+ if mapper._version_id_prop
+ else None
+ ),
+ )
+ for mapping in mappings
+ ),
+ bulk=True,
+ )
+
+ _emit_update_statements(
+ base_mapper,
+ None,
+ super_mapper,
+ table,
+ records,
+ bookkeeping=False,
+ )
+
+
+def save_obj(base_mapper, states, uowtransaction, single=False):
+ """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
+ of objects.
+
+ This is called within the context of a UOWTransaction during a
+ flush operation, given a list of states to be flushed. The
+ base mapper in an inheritance hierarchy handles the inserts/
+ updates for all descendant mappers.
+
+ """
+
+ # if batch=false, call _save_obj separately for each object
+ if not single and not base_mapper.batch:
+ for state in _sort_states(base_mapper, states):
+ save_obj(base_mapper, [state], uowtransaction, single=True)
+ return
+
+ states_to_update = []
+ states_to_insert = []
+
+ for (
+ state,
+ dict_,
+ mapper,
+ connection,
+ has_identity,
+ row_switch,
+ update_version_id,
+ ) in _organize_states_for_save(base_mapper, states, uowtransaction):
+ if has_identity or row_switch:
+ states_to_update.append(
+ (state, dict_, mapper, connection, update_version_id)
+ )
+ else:
+ states_to_insert.append((state, dict_, mapper, connection))
+
+ for table, mapper in base_mapper._sorted_tables.items():
+ if table not in mapper._pks_by_table:
+ continue
+ insert = _collect_insert_commands(table, states_to_insert)
+
+ update = _collect_update_commands(
+ uowtransaction, table, states_to_update
+ )
+
+ _emit_update_statements(
+ base_mapper,
+ uowtransaction,
+ mapper,
+ table,
+ update,
+ )
+
+ _emit_insert_statements(
+ base_mapper,
+ uowtransaction,
+ mapper,
+ table,
+ insert,
+ )
+
+ _finalize_insert_update_commands(
+ base_mapper,
+ uowtransaction,
+ chain(
+ (
+ (state, state_dict, mapper, connection, False)
+ for (state, state_dict, mapper, connection) in states_to_insert
+ ),
+ (
+ (state, state_dict, mapper, connection, True)
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_update
+ ),
+ ),
+ )
+
+
+def post_update(base_mapper, states, uowtransaction, post_update_cols):
+ """Issue UPDATE statements on behalf of a relationship() which
+ specifies post_update.
+
+ """
+
+ states_to_update = list(
+ _organize_states_for_post_update(base_mapper, states, uowtransaction)
+ )
+
+ for table, mapper in base_mapper._sorted_tables.items():
+ if table not in mapper._pks_by_table:
+ continue
+
+ update = (
+ (
+ state,
+ state_dict,
+ sub_mapper,
+ connection,
+ mapper._get_committed_state_attr_by_column(
+ state, state_dict, mapper.version_id_col
+ )
+ if mapper.version_id_col is not None
+ else None,
+ )
+ for state, state_dict, sub_mapper, connection in states_to_update
+ if table in sub_mapper._pks_by_table
+ )
+
+ update = _collect_post_update_commands(
+ base_mapper, uowtransaction, table, update, post_update_cols
+ )
+
+ _emit_post_update_statements(
+ base_mapper,
+ uowtransaction,
+ mapper,
+ table,
+ update,
+ )
+
+
+def delete_obj(base_mapper, states, uowtransaction):
+ """Issue ``DELETE`` statements for a list of objects.
+
+ This is called within the context of a UOWTransaction during a
+ flush operation.
+
+ """
+
+ states_to_delete = list(
+ _organize_states_for_delete(base_mapper, states, uowtransaction)
+ )
+
+ table_to_mapper = base_mapper._sorted_tables
+
+ for table in reversed(list(table_to_mapper.keys())):
+ mapper = table_to_mapper[table]
+ if table not in mapper._pks_by_table:
+ continue
+ elif mapper.inherits and mapper.passive_deletes:
+ continue
+
+ delete = _collect_delete_commands(
+ base_mapper, uowtransaction, table, states_to_delete
+ )
+
+ _emit_delete_statements(
+ base_mapper,
+ uowtransaction,
+ mapper,
+ table,
+ delete,
+ )
+
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_delete:
+ mapper.dispatch.after_delete(mapper, connection, state)
+
+
+def _organize_states_for_save(base_mapper, states, uowtransaction):
+ """Make an initial pass across a set of states for INSERT or
+ UPDATE.
+
+ This includes splitting out into distinct lists for
+ each, calling before_insert/before_update, obtaining
+ key information for each state including its dictionary,
+ mapper, the connection to use for the execution per state,
+ and the identity flag.
+
+ """
+
+ for state, dict_, mapper, connection in _connections_for_states(
+ base_mapper, uowtransaction, states
+ ):
+
+ has_identity = bool(state.key)
+
+ instance_key = state.key or mapper._identity_key_from_state(state)
+
+ row_switch = update_version_id = None
+
+ # call before_XXX extensions
+ if not has_identity:
+ mapper.dispatch.before_insert(mapper, connection, state)
+ else:
+ mapper.dispatch.before_update(mapper, connection, state)
+
+ if mapper._validate_polymorphic_identity:
+ mapper._validate_polymorphic_identity(mapper, state, dict_)
+
+ # detect if we have a "pending" instance (i.e. has
+ # no instance_key attached to it), and another instance
+ # with the same identity key already exists as persistent.
+ # convert to an UPDATE if so.
+ if (
+ not has_identity
+ and instance_key in uowtransaction.session.identity_map
+ ):
+ instance = uowtransaction.session.identity_map[instance_key]
+ existing = attributes.instance_state(instance)
+
+ if not uowtransaction.was_already_deleted(existing):
+ if not uowtransaction.is_deleted(existing):
+ util.warn(
+ "New instance %s with identity key %s conflicts "
+ "with persistent instance %s"
+ % (state_str(state), instance_key, state_str(existing))
+ )
+ else:
+ base_mapper._log_debug(
+ "detected row switch for identity %s. "
+ "will update %s, remove %s from "
+ "transaction",
+ instance_key,
+ state_str(state),
+ state_str(existing),
+ )
+
+ # remove the "delete" flag from the existing element
+ uowtransaction.remove_state_actions(existing)
+ row_switch = existing
+
+ if (has_identity or row_switch) and mapper.version_id_col is not None:
+ update_version_id = mapper._get_committed_state_attr_by_column(
+ row_switch if row_switch else state,
+ row_switch.dict if row_switch else dict_,
+ mapper.version_id_col,
+ )
+
+ yield (
+ state,
+ dict_,
+ mapper,
+ connection,
+ has_identity,
+ row_switch,
+ update_version_id,
+ )
+
+
+def _organize_states_for_post_update(base_mapper, states, uowtransaction):
+ """Make an initial pass across a set of states for UPDATE
+ corresponding to post_update.
+
+ This includes obtaining key information for each state
+ including its dictionary, mapper, the connection to use for
+ the execution per state.
+
+ """
+ return _connections_for_states(base_mapper, uowtransaction, states)
+
+
+def _organize_states_for_delete(base_mapper, states, uowtransaction):
+ """Make an initial pass across a set of states for DELETE.
+
+ This includes calling out before_delete and obtaining
+ key information for each state including its dictionary,
+ mapper, the connection to use for the execution per state.
+
+ """
+ for state, dict_, mapper, connection in _connections_for_states(
+ base_mapper, uowtransaction, states
+ ):
+
+ mapper.dispatch.before_delete(mapper, connection, state)
+
+ if mapper.version_id_col is not None:
+ update_version_id = mapper._get_committed_state_attr_by_column(
+ state, dict_, mapper.version_id_col
+ )
+ else:
+ update_version_id = None
+
+ yield (state, dict_, mapper, connection, update_version_id)
+
+
+def _collect_insert_commands(
+ table,
+ states_to_insert,
+ bulk=False,
+ return_defaults=False,
+ render_nulls=False,
+):
+ """Identify sets of values to use in INSERT statements for a
+ list of states.
+
+ """
+ for state, state_dict, mapper, connection in states_to_insert:
+ if table not in mapper._pks_by_table:
+ continue
+
+ params = {}
+ value_params = {}
+
+ propkey_to_col = mapper._propkey_to_col[table]
+
+ eval_none = mapper._insert_cols_evaluating_none[table]
+
+ for propkey in set(propkey_to_col).intersection(state_dict):
+ value = state_dict[propkey]
+ col = propkey_to_col[propkey]
+ if value is None and col not in eval_none and not render_nulls:
+ continue
+ elif not bulk and (
+ hasattr(value, "__clause_element__")
+ or isinstance(value, sql.ClauseElement)
+ ):
+ value_params[col] = (
+ value.__clause_element__()
+ if hasattr(value, "__clause_element__")
+ else value
+ )
+ else:
+ params[col.key] = value
+
+ if not bulk:
+ # for all the columns that have no default and we don't have
+ # a value and where "None" is not a special value, add
+ # explicit None to the INSERT. This is a legacy behavior
+ # which might be worth removing, as it should not be necessary
+ # and also produces confusion, given that "missing" and None
+ # now have distinct meanings
+ for colkey in (
+ mapper._insert_cols_as_none[table]
+ .difference(params)
+ .difference([c.key for c in value_params])
+ ):
+ params[colkey] = None
+
+ if not bulk or return_defaults:
+ # params are in terms of Column key objects, so
+ # compare to pk_keys_by_table
+ has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
+
+ if mapper.base_mapper.eager_defaults:
+ has_all_defaults = mapper._server_default_cols[table].issubset(
+ params
+ )
+ else:
+ has_all_defaults = True
+ else:
+ has_all_defaults = has_all_pks = True
+
+ if (
+ mapper.version_id_generator is not False
+ and mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+ params[mapper.version_id_col.key] = mapper.version_id_generator(
+ None
+ )
+
+ yield (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ )
+
+
+def _collect_update_commands(
+ uowtransaction, table, states_to_update, bulk=False
+):
+ """Identify sets of values to use in UPDATE statements for a
+ list of states.
+
+ This function works intricately with the history system
+ to determine exactly what values should be updated
+ as well as how the row should be matched within an UPDATE
+ statement. Includes some tricky scenarios where the primary
+ key of an object might have been changed.
+
+ """
+
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_update:
+
+ if table not in mapper._pks_by_table:
+ continue
+
+ pks = mapper._pks_by_table[table]
+
+ value_params = {}
+
+ propkey_to_col = mapper._propkey_to_col[table]
+
+ if bulk:
+ # keys here are mapped attribute keys, so
+ # look at mapper attribute keys for pk
+ params = dict(
+ (propkey_to_col[propkey].key, state_dict[propkey])
+ for propkey in set(propkey_to_col)
+ .intersection(state_dict)
+ .difference(mapper._pk_attr_keys_by_table[table])
+ )
+ has_all_defaults = True
+ else:
+ params = {}
+ for propkey in set(propkey_to_col).intersection(
+ state.committed_state
+ ):
+ value = state_dict[propkey]
+ col = propkey_to_col[propkey]
+
+ if hasattr(value, "__clause_element__") or isinstance(
+ value, sql.ClauseElement
+ ):
+ value_params[col] = (
+ value.__clause_element__()
+ if hasattr(value, "__clause_element__")
+ else value
+ )
+ # guard against values that generate non-__nonzero__
+ # objects for __eq__()
+ elif (
+ state.manager[propkey].impl.is_equal(
+ value, state.committed_state[propkey]
+ )
+ is not True
+ ):
+ params[col.key] = value
+
+ if mapper.base_mapper.eager_defaults:
+ has_all_defaults = (
+ mapper._server_onupdate_default_cols[table]
+ ).issubset(params)
+ else:
+ has_all_defaults = True
+
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+
+ if not bulk and not (params or value_params):
+ # HACK: check for history in other tables, in case the
+ # history is only in a different table than the one
+ # where the version_id_col is. This logic was lost
+ # from 0.9 -> 1.0.0 and restored in 1.0.6.
+ for prop in mapper._columntoproperty.values():
+ history = state.manager[prop.key].impl.get_history(
+ state, state_dict, attributes.PASSIVE_NO_INITIALIZE
+ )
+ if history.added:
+ break
+ else:
+ # no net change, break
+ continue
+
+ col = mapper.version_id_col
+ no_params = not params and not value_params
+ params[col._label] = update_version_id
+
+ if (
+ bulk or col.key not in params
+ ) and mapper.version_id_generator is not False:
+ val = mapper.version_id_generator(update_version_id)
+ params[col.key] = val
+ elif mapper.version_id_generator is False and no_params:
+ # no version id generator, no values set on the table,
+ # and version id wasn't manually incremented.
+ # set version id to itself so we get an UPDATE
+ # statement
+ params[col.key] = update_version_id
+
+ elif not (params or value_params):
+ continue
+
+ has_all_pks = True
+ expect_pk_cascaded = False
+ if bulk:
+ # keys here are mapped attribute keys, so
+ # look at mapper attribute keys for pk
+ pk_params = dict(
+ (propkey_to_col[propkey]._label, state_dict.get(propkey))
+ for propkey in set(propkey_to_col).intersection(
+ mapper._pk_attr_keys_by_table[table]
+ )
+ )
+ else:
+ pk_params = {}
+ for col in pks:
+ propkey = mapper._columntoproperty[col].key
+
+ history = state.manager[propkey].impl.get_history(
+ state, state_dict, attributes.PASSIVE_OFF
+ )
+
+ if history.added:
+ if (
+ not history.deleted
+ or ("pk_cascaded", state, col)
+ in uowtransaction.attributes
+ ):
+ expect_pk_cascaded = True
+ pk_params[col._label] = history.added[0]
+ params.pop(col.key, None)
+ else:
+ # else, use the old value to locate the row
+ pk_params[col._label] = history.deleted[0]
+ if col in value_params:
+ has_all_pks = False
+ else:
+ pk_params[col._label] = history.unchanged[0]
+ if pk_params[col._label] is None:
+ raise orm_exc.FlushError(
+ "Can't update table %s using NULL for primary "
+ "key value on column %s" % (table, col)
+ )
+
+ if params or value_params:
+ params.update(pk_params)
+ yield (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ )
+ elif expect_pk_cascaded:
+ # no UPDATE occurs on this table, but we expect that CASCADE rules
+ # have changed the primary key of the row; propagate this event to
+ # other columns that expect to have been modified. this normally
+ # occurs after the UPDATE is emitted; however, we invoke it here
+ # explicitly in the absence of an UPDATE being emitted
+ for m, equated_pairs in mapper._table_to_equated[table]:
+ sync.populate(
+ state,
+ m,
+ state,
+ m,
+ equated_pairs,
+ uowtransaction,
+ mapper.passive_updates,
+ )
+
+
+def _collect_post_update_commands(
+ base_mapper, uowtransaction, table, states_to_update, post_update_cols
+):
+ """Identify sets of values to use in UPDATE statements for a
+ list of states within a post_update operation.
+
+ """
+
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_update:
+
+ # assert table in mapper._pks_by_table
+
+ pks = mapper._pks_by_table[table]
+ params = {}
+ hasdata = False
+
+ for col in mapper._cols_by_table[table]:
+ if col in pks:
+ params[col._label] = mapper._get_state_attr_by_column(
+ state, state_dict, col, passive=attributes.PASSIVE_OFF
+ )
+
+ elif col in post_update_cols or col.onupdate is not None:
+ prop = mapper._columntoproperty[col]
+ history = state.manager[prop.key].impl.get_history(
+ state, state_dict, attributes.PASSIVE_NO_INITIALIZE
+ )
+ if history.added:
+ value = history.added[0]
+ params[col.key] = value
+ hasdata = True
+ if hasdata:
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+
+ col = mapper.version_id_col
+ params[col._label] = update_version_id
+
+ if (
+ bool(state.key)
+ and col.key not in params
+ and mapper.version_id_generator is not False
+ ):
+ val = mapper.version_id_generator(update_version_id)
+ params[col.key] = val
+ yield state, state_dict, mapper, connection, params
+
+
+def _collect_delete_commands(
+ base_mapper, uowtransaction, table, states_to_delete
+):
+ """Identify values to use in DELETE statements for a list of
+ states to be deleted."""
+
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_delete:
+
+ if table not in mapper._pks_by_table:
+ continue
+
+ params = {}
+ for col in mapper._pks_by_table[table]:
+ params[
+ col.key
+ ] = value = mapper._get_committed_state_attr_by_column(
+ state, state_dict, col
+ )
+ if value is None:
+ raise orm_exc.FlushError(
+ "Can't delete from table %s "
+ "using NULL for primary "
+ "key value on column %s" % (table, col)
+ )
+
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+ params[mapper.version_id_col.key] = update_version_id
+ yield params, connection
+
+
+def _emit_update_statements(
+ base_mapper,
+ uowtransaction,
+ mapper,
+ table,
+ update,
+ bookkeeping=True,
+):
+ """Emit UPDATE statements corresponding to value lists collected
+ by _collect_update_commands()."""
+
+ needs_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
+
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
+ def update_stmt():
+ clauses = BooleanClauseList._construct_raw(operators.and_)
+
+ for col in mapper._pks_by_table[table]:
+ clauses.clauses.append(
+ col == sql.bindparam(col._label, type_=col.type)
+ )
+
+ if needs_version_id:
+ clauses.clauses.append(
+ mapper.version_id_col
+ == sql.bindparam(
+ mapper.version_id_col._label,
+ type_=mapper.version_id_col.type,
+ )
+ )
+
+ stmt = table.update().where(clauses)
+ return stmt
+
+ cached_stmt = base_mapper._memo(("update", table), update_stmt)
+
+ for (
+ (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
+ records,
+ ) in groupby(
+ update,
+ lambda rec: (
+ rec[4], # connection
+ set(rec[2]), # set of parameter keys
+ bool(rec[5]), # whether or not we have "value" parameters
+ rec[6], # has_all_defaults
+ rec[7], # has all pks
+ ),
+ ):
+ rows = 0
+ records = list(records)
+
+ statement = cached_stmt
+ return_defaults = False
+
+ if not has_all_pks:
+ statement = statement.return_defaults()
+ return_defaults = True
+ elif (
+ bookkeeping
+ and not has_all_defaults
+ and mapper.base_mapper.eager_defaults
+ ):
+ statement = statement.return_defaults()
+ return_defaults = True
+ elif mapper.version_id_col is not None:
+ statement = statement.return_defaults(mapper.version_id_col)
+ return_defaults = True
+
+ assert_singlerow = (
+ connection.dialect.supports_sane_rowcount
+ if not return_defaults
+ else connection.dialect.supports_sane_rowcount_returning
+ )
+
+ assert_multirow = (
+ assert_singlerow
+ and connection.dialect.supports_sane_multi_rowcount
+ )
+ allow_multirow = has_all_defaults and not needs_version_id
+
+ if hasvalue:
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
+ c = connection._execute_20(
+ statement.values(value_params),
+ params,
+ execution_options=execution_options,
+ )
+ if bookkeeping:
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params,
+ True,
+ c.returned_defaults,
+ )
+ rows += c.rowcount
+ check_rowcount = assert_singlerow
+ else:
+ if not allow_multirow:
+ check_rowcount = assert_singlerow
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
+ c = connection._execute_20(
+ statement, params, execution_options=execution_options
+ )
+
+ # TODO: why with bookkeeping=False?
+ if bookkeeping:
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params,
+ True,
+ c.returned_defaults,
+ )
+ rows += c.rowcount
+ else:
+ multiparams = [rec[2] for rec in records]
+
+ check_rowcount = assert_multirow or (
+ assert_singlerow and len(multiparams) == 1
+ )
+
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
+ )
+
+ rows += c.rowcount
+
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
+ if bookkeeping:
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params,
+ True,
+ c.returned_defaults
+ if not c.context.executemany
+ else None,
+ )
+
+ if check_rowcount:
+ if rows != len(records):
+ raise orm_exc.StaleDataError(
+ "UPDATE statement on table '%s' expected to "
+ "update %d row(s); %d were matched."
+ % (table.description, len(records), rows)
+ )
+
+ elif needs_version_id:
+ util.warn(
+ "Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified."
+ % c.dialect.dialect_description
+ )
+
+
+def _emit_insert_statements(
+ base_mapper,
+ uowtransaction,
+ mapper,
+ table,
+ insert,
+ bookkeeping=True,
+):
+ """Emit INSERT statements corresponding to value lists collected
+ by _collect_insert_commands()."""
+
+ cached_stmt = base_mapper._memo(("insert", table), table.insert)
+
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
+ for (
+ (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
+ records,
+ ) in groupby(
+ insert,
+ lambda rec: (
+ rec[4], # connection
+ set(rec[2]), # parameter keys
+ bool(rec[5]), # whether we have "value" parameters
+ rec[6],
+ rec[7],
+ ),
+ ):
+
+ statement = cached_stmt
+
+ if (
+ not bookkeeping
+ or (
+ has_all_defaults
+ or not base_mapper.eager_defaults
+ or not connection.dialect.implicit_returning
+ )
+ and has_all_pks
+ and not hasvalue
+ ):
+ # the "we don't need newly generated values back" section.
+ # here we have all the PKs, all the defaults or we don't want
+ # to fetch them, or the dialect doesn't support RETURNING at all
+ # so we have to post-fetch / use lastrowid anyway.
+ records = list(records)
+ multiparams = [rec[2] for rec in records]
+
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
+ )
+
+ if bookkeeping:
+ for (
+ (
+ state,
+ state_dict,
+ params,
+ mapper_rec,
+ conn,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ),
+ last_inserted_params,
+ ) in zip(records, c.context.compiled_parameters):
+ if state:
+ _postfetch(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ last_inserted_params,
+ value_params,
+ False,
+ c.returned_defaults
+ if not c.context.executemany
+ else None,
+ )
+ else:
+ _postfetch_bulk_save(mapper_rec, state_dict, table)
+
+ else:
+ # here, we need defaults and/or pk values back.
+
+ records = list(records)
+ if (
+ not hasvalue
+ and connection.dialect.insert_executemany_returning
+ and len(records) > 1
+ ):
+ do_executemany = True
+ else:
+ do_executemany = False
+
+ if not has_all_defaults and base_mapper.eager_defaults:
+ statement = statement.return_defaults()
+ elif mapper.version_id_col is not None:
+ statement = statement.return_defaults(mapper.version_id_col)
+ elif do_executemany:
+ statement = statement.return_defaults(*table.primary_key)
+
+ if do_executemany:
+ multiparams = [rec[2] for rec in records]
+
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
+ )
+
+ if bookkeeping:
+ for (
+ (
+ state,
+ state_dict,
+ params,
+ mapper_rec,
+ conn,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ),
+ last_inserted_params,
+ inserted_primary_key,
+ returned_defaults,
+ ) in util.zip_longest(
+ records,
+ c.context.compiled_parameters,
+ c.inserted_primary_key_rows,
+ c.returned_defaults_rows or (),
+ ):
+ if inserted_primary_key is None:
+ # this is a real problem and means that we didn't
+ # get back as many PK rows. we can't continue
+ # since this indicates PK rows were missing, which
+ # means we likely mis-populated records starting
+ # at that point with incorrectly matched PK
+ # values.
+ raise orm_exc.FlushError(
+ "Multi-row INSERT statement for %s did not "
+ "produce "
+ "the correct number of INSERTed rows for "
+ "RETURNING. Ensure there are no triggers or "
+ "special driver issues preventing INSERT from "
+ "functioning properly." % mapper_rec
+ )
+
+ for pk, col in zip(
+ inserted_primary_key,
+ mapper._pks_by_table[table],
+ ):
+ prop = mapper_rec._columntoproperty[col]
+ if state_dict.get(prop.key) is None:
+ state_dict[prop.key] = pk
+
+ if state:
+ _postfetch(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ last_inserted_params,
+ value_params,
+ False,
+ returned_defaults,
+ )
+ else:
+ _postfetch_bulk_save(mapper_rec, state_dict, table)
+ else:
+ for (
+ state,
+ state_dict,
+ params,
+ mapper_rec,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ) in records:
+ if value_params:
+ result = connection._execute_20(
+ statement.values(value_params),
+ params,
+ execution_options=execution_options,
+ )
+ else:
+ result = connection._execute_20(
+ statement,
+ params,
+ execution_options=execution_options,
+ )
+
+ primary_key = result.inserted_primary_key
+ if primary_key is None:
+ raise orm_exc.FlushError(
+ "Single-row INSERT statement for %s "
+ "did not produce a "
+ "new primary key result "
+ "being invoked. Ensure there are no triggers or "
+ "special driver issues preventing INSERT from "
+ "functioning properly." % (mapper_rec,)
+ )
+ for pk, col in zip(
+ primary_key, mapper._pks_by_table[table]
+ ):
+ prop = mapper_rec._columntoproperty[col]
+ if (
+ col in value_params
+ or state_dict.get(prop.key) is None
+ ):
+ state_dict[prop.key] = pk
+ if bookkeeping:
+ if state:
+ _postfetch(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ result,
+ result.context.compiled_parameters[0],
+ value_params,
+ False,
+ result.returned_defaults
+ if not result.context.executemany
+ else None,
+ )
+ else:
+ _postfetch_bulk_save(mapper_rec, state_dict, table)
+
+
+def _emit_post_update_statements(
+ base_mapper, uowtransaction, mapper, table, update
+):
+ """Emit UPDATE statements corresponding to value lists collected
+ by _collect_post_update_commands()."""
+
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
+ needs_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
+
+ def update_stmt():
+ clauses = BooleanClauseList._construct_raw(operators.and_)
+
+ for col in mapper._pks_by_table[table]:
+ clauses.clauses.append(
+ col == sql.bindparam(col._label, type_=col.type)
+ )
+
+ if needs_version_id:
+ clauses.clauses.append(
+ mapper.version_id_col
+ == sql.bindparam(
+ mapper.version_id_col._label,
+ type_=mapper.version_id_col.type,
+ )
+ )
+
+ stmt = table.update().where(clauses)
+
+ if mapper.version_id_col is not None:
+ stmt = stmt.return_defaults(mapper.version_id_col)
+
+ return stmt
+
+ statement = base_mapper._memo(("post_update", table), update_stmt)
+
+ # execute each UPDATE in the order according to the original
+ # list of states to guarantee row access order, but
+ # also group them into common (connection, cols) sets
+ # to support executemany().
+ for key, records in groupby(
+ update,
+        lambda rec: (rec[3], set(rec[4])),  # (connection, parameter keys)
+ ):
+ rows = 0
+
+ records = list(records)
+ connection = key[0]
+
+ assert_singlerow = (
+ connection.dialect.supports_sane_rowcount
+ if mapper.version_id_col is None
+ else connection.dialect.supports_sane_rowcount_returning
+ )
+ assert_multirow = (
+ assert_singlerow
+ and connection.dialect.supports_sane_multi_rowcount
+ )
+ allow_multirow = not needs_version_id or assert_multirow
+
+ if not allow_multirow:
+ check_rowcount = assert_singlerow
+ for state, state_dict, mapper_rec, connection, params in records:
+
+ c = connection._execute_20(
+ statement, params, execution_options=execution_options
+ )
+
+ _postfetch_post_update(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ )
+ rows += c.rowcount
+ else:
+ multiparams = [
+ params
+ for state, state_dict, mapper_rec, conn, params in records
+ ]
+
+ check_rowcount = assert_multirow or (
+ assert_singlerow and len(multiparams) == 1
+ )
+
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
+ )
+
+ rows += c.rowcount
+ for state, state_dict, mapper_rec, connection, params in records:
+ _postfetch_post_update(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ )
+
+ if check_rowcount:
+ if rows != len(records):
+ raise orm_exc.StaleDataError(
+ "UPDATE statement on table '%s' expected to "
+ "update %d row(s); %d were matched."
+ % (table.description, len(records), rows)
+ )
+
+ elif needs_version_id:
+ util.warn(
+ "Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified."
+ % c.dialect.dialect_description
+ )
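+
+# Illustrative sketch only, not part of this module: the groupby() above
+# batches post-update records that share both a connection and the same set
+# of parameter keys, so each batch can go through a single executemany()
+# UPDATE.  A standalone approximation with hypothetical data:
+#
+#     from itertools import groupby
+#
+#     records = [
+#         ("conn1", {"x": 1, "id": 10}),
+#         ("conn1", {"x": 2, "id": 11}),  # same parameter keys -> same batch
+#         ("conn1", {"y": 3, "id": 12}),  # different keys -> new batch
+#     ]
+#     for (conn, keys), batch in groupby(
+#         records, lambda rec: (rec[0], set(rec[1]))
+#     ):
+#         print(conn, sorted(keys), [params for _, params in batch])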
+
+
+def _emit_delete_statements(
+ base_mapper, uowtransaction, mapper, table, delete
+):
+ """Emit DELETE statements corresponding to value lists collected
+ by _collect_delete_commands()."""
+
+ need_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
+
+ def delete_stmt():
+ clauses = BooleanClauseList._construct_raw(operators.and_)
+
+ for col in mapper._pks_by_table[table]:
+ clauses.clauses.append(
+ col == sql.bindparam(col.key, type_=col.type)
+ )
+
+ if need_version_id:
+ clauses.clauses.append(
+ mapper.version_id_col
+ == sql.bindparam(
+ mapper.version_id_col.key, type_=mapper.version_id_col.type
+ )
+ )
+
+ return table.delete().where(clauses)
+
+ statement = base_mapper._memo(("delete", table), delete_stmt)
+ for connection, recs in groupby(delete, lambda rec: rec[1]): # connection
+ del_objects = [params for params, connection in recs]
+
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
+ expected = len(del_objects)
+ rows_matched = -1
+ only_warn = False
+
+ if (
+ need_version_id
+ and not connection.dialect.supports_sane_multi_rowcount
+ ):
+ if connection.dialect.supports_sane_rowcount:
+ rows_matched = 0
+ # execute deletes individually so that versioned
+ # rows can be verified
+ for params in del_objects:
+
+ c = connection._execute_20(
+ statement, params, execution_options=execution_options
+ )
+ rows_matched += c.rowcount
+ else:
+ util.warn(
+ "Dialect %s does not support deleted rowcount "
+ "- versioning cannot be verified."
+ % connection.dialect.dialect_description
+ )
+ connection._execute_20(
+ statement, del_objects, execution_options=execution_options
+ )
+ else:
+ c = connection._execute_20(
+ statement, del_objects, execution_options=execution_options
+ )
+
+ if not need_version_id:
+ only_warn = True
+
+ rows_matched = c.rowcount
+
+ if (
+ base_mapper.confirm_deleted_rows
+ and rows_matched > -1
+ and expected != rows_matched
+ and (
+ connection.dialect.supports_sane_multi_rowcount
+ or len(del_objects) == 1
+ )
+ ):
+ # TODO: why does this "only warn" if versioning is turned off,
+ # whereas the UPDATE raises?
+ if only_warn:
+ util.warn(
+ "DELETE statement on table '%s' expected to "
+ "delete %d row(s); %d were matched. Please set "
+ "confirm_deleted_rows=False within the mapper "
+ "configuration to prevent this warning."
+ % (table.description, expected, rows_matched)
+ )
+ else:
+ raise orm_exc.StaleDataError(
+ "DELETE statement on table '%s' expected to "
+ "delete %d row(s); %d were matched. Please set "
+ "confirm_deleted_rows=False within the mapper "
+                    "configuration to prevent this error."
+ % (table.description, expected, rows_matched)
+ )
+
+
+def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
+ """finalize state on states that have been inserted or updated,
+ including calling after_insert/after_update events.
+
+ """
+ for state, state_dict, mapper, connection, has_identity in states:
+
+ if mapper._readonly_props:
+ readonly = state.unmodified_intersection(
+ [
+ p.key
+ for p in mapper._readonly_props
+ if (
+ p.expire_on_flush
+ and (not p.deferred or p.key in state.dict)
+ )
+ or (
+ not p.expire_on_flush
+ and not p.deferred
+ and p.key not in state.dict
+ )
+ ]
+ )
+ if readonly:
+ state._expire_attributes(state.dict, readonly)
+
+ # if eager_defaults option is enabled, load
+ # all expired cols. Else if we have a version_id_col, make sure
+ # it isn't expired.
+ toload_now = []
+
+ if base_mapper.eager_defaults:
+ toload_now.extend(
+ state._unloaded_non_object.intersection(
+ mapper._server_default_plus_onupdate_propkeys
+ )
+ )
+
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_generator is False
+ ):
+ if mapper._version_id_prop.key in state.unloaded:
+ toload_now.extend([mapper._version_id_prop.key])
+
+ if toload_now:
+ state.key = base_mapper._identity_key_from_state(state)
+ stmt = future.select(mapper).set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ )
+ loading.load_on_ident(
+ uowtransaction.session,
+ stmt,
+ state.key,
+ refresh_state=state,
+ only_load_props=toload_now,
+ )
+
+ # call after_XXX extensions
+ if not has_identity:
+ mapper.dispatch.after_insert(mapper, connection, state)
+ else:
+ mapper.dispatch.after_update(mapper, connection, state)
+
+ if (
+ mapper.version_id_generator is False
+ and mapper.version_id_col is not None
+ ):
+ if state_dict[mapper._version_id_prop.key] is None:
+ raise orm_exc.FlushError(
+ "Instance does not contain a non-NULL version value"
+ )
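+
+# Illustrative sketch only, not part of this module: the eager_defaults
+# mapper option consulted above causes server-generated values to be loaded
+# during the flush rather than on first attribute access.  'Base' is assumed
+# to be a declarative base and 'Widget' is a hypothetical mapped class:
+#
+#     from sqlalchemy import Column, DateTime, Integer, func
+#
+#     class Widget(Base):
+#         __tablename__ = "widget"
+#         __mapper_args__ = {"eager_defaults": True}
+#         id = Column(Integer, primary_key=True)
+#         created_at = Column(DateTime, server_default=func.now())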
+
+
+def _postfetch_post_update(
+ mapper, uowtransaction, table, state, dict_, result, params
+):
+ if uowtransaction.is_deleted(state):
+ return
+
+ prefetch_cols = result.context.compiled.prefetch
+ postfetch_cols = result.context.compiled.postfetch
+
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+ prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
+
+ refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
+ if refresh_flush:
+ load_evt_attrs = []
+
+ for c in prefetch_cols:
+ if c.key in params and c in mapper._columntoproperty:
+ dict_[mapper._columntoproperty[c].key] = params[c.key]
+ if refresh_flush:
+ load_evt_attrs.append(mapper._columntoproperty[c].key)
+
+ if refresh_flush and load_evt_attrs:
+ mapper.class_manager.dispatch.refresh_flush(
+ state, uowtransaction, load_evt_attrs
+ )
+
+ if postfetch_cols:
+ state._expire_attributes(
+ state.dict,
+ [
+ mapper._columntoproperty[c].key
+ for c in postfetch_cols
+ if c in mapper._columntoproperty
+ ],
+ )
+
+
+def _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ dict_,
+ result,
+ params,
+ value_params,
+ isupdate,
+ returned_defaults,
+):
+ """Expire attributes in need of newly persisted database state,
+ after an INSERT or UPDATE statement has proceeded for that
+ state."""
+
+ prefetch_cols = result.context.compiled.prefetch
+ postfetch_cols = result.context.compiled.postfetch
+ returning_cols = result.context.compiled.returning
+
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+ prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
+
+ refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
+ if refresh_flush:
+ load_evt_attrs = []
+
+ if returning_cols:
+ row = returned_defaults
+ if row is not None:
+ for row_value, col in zip(row, returning_cols):
+ # pk cols returned from insert are handled
+ # distinctly, don't step on the values here
+ if col.primary_key and result.context.isinsert:
+ continue
+
+ # note that columns can be in the "return defaults" that are
+ # not mapped to this mapper, typically because they are
+ # "excluded", which can be specified directly or also occurs
+ # when using declarative w/ single table inheritance
+ prop = mapper._columntoproperty.get(col)
+ if prop:
+ dict_[prop.key] = row_value
+ if refresh_flush:
+ load_evt_attrs.append(prop.key)
+
+ for c in prefetch_cols:
+ if c.key in params and c in mapper._columntoproperty:
+ dict_[mapper._columntoproperty[c].key] = params[c.key]
+ if refresh_flush:
+ load_evt_attrs.append(mapper._columntoproperty[c].key)
+
+ if refresh_flush and load_evt_attrs:
+ mapper.class_manager.dispatch.refresh_flush(
+ state, uowtransaction, load_evt_attrs
+ )
+
+ if isupdate and value_params:
+ # explicitly suit the use case specified by
+ # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
+        # databases which are set to themselves in order to do a version bump.
+ postfetch_cols.extend(
+ [
+ col
+ for col in value_params
+ if col.primary_key and col not in returning_cols
+ ]
+ )
+
+ if postfetch_cols:
+ state._expire_attributes(
+ state.dict,
+ [
+ mapper._columntoproperty[c].key
+ for c in postfetch_cols
+ if c in mapper._columntoproperty
+ ],
+ )
+
+ # synchronize newly inserted ids from one table to the next
+ # TODO: this still goes a little too often. would be nice to
+ # have definitive list of "columns that changed" here
+ for m, equated_pairs in mapper._table_to_equated[table]:
+ sync.populate(
+ state,
+ m,
+ state,
+ m,
+ equated_pairs,
+ uowtransaction,
+ mapper.passive_updates,
+ )
+
+
+def _postfetch_bulk_save(mapper, dict_, table):
+ for m, equated_pairs in mapper._table_to_equated[table]:
+ sync.bulk_populate_inherit_keys(dict_, m, equated_pairs)
+
+
+def _connections_for_states(base_mapper, uowtransaction, states):
+ """Return an iterator of (state, state.dict, mapper, connection).
+
+ The states are sorted according to _sort_states, then paired
+ with the connection they should be using for the given
+ unit of work transaction.
+
+ """
+ # if session has a connection callable,
+ # organize individual states with the connection
+ # to use for update
+ if uowtransaction.session.connection_callable:
+ connection_callable = uowtransaction.session.connection_callable
+ else:
+ connection = uowtransaction.transaction.connection(base_mapper)
+ connection_callable = None
+
+ for state in _sort_states(base_mapper, states):
+ if connection_callable:
+ connection = connection_callable(base_mapper, state.obj())
+
+ mapper = state.manager.mapper
+
+ yield state, state.dict, mapper, connection
+
+
+def _sort_states(mapper, states):
+ pending = set(states)
+ persistent = set(s for s in pending if s.key is not None)
+ pending.difference_update(persistent)
+
+ try:
+ persistent_sorted = sorted(
+ persistent, key=mapper._persistent_sortkey_fn
+ )
+ except TypeError as err:
+ util.raise_(
+ sa_exc.InvalidRequestError(
+ "Could not sort objects by primary key; primary key "
+ "values must be sortable in Python (was: %s)" % err
+ ),
+ replace_context=err,
+ )
+ return (
+ sorted(pending, key=operator.attrgetter("insert_order"))
+ + persistent_sorted
+ )
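+
+# Illustrative sketch only, not part of this module: _sort_states() yields
+# pending (never-flushed) states first, ordered by insert_order, followed by
+# persistent states ordered by primary key, keeping flush order
+# deterministic.  A standalone approximation using hypothetical objects that
+# expose .insert_order / .primary_key attributes:
+#
+#     def sort_flush_order(pending, persistent):
+#         return sorted(pending, key=lambda o: o.insert_order) + sorted(
+#             persistent, key=lambda o: o.primary_key
+#         )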
+
+
+_EMPTY_DICT = util.immutabledict()
+
+
+class BulkUDCompileState(CompileState):
+ class default_update_options(Options):
+ _synchronize_session = "evaluate"
+ _autoflush = True
+ _subject_mapper = None
+ _resolved_values = _EMPTY_DICT
+ _resolved_keys_as_propnames = _EMPTY_DICT
+ _value_evaluators = _EMPTY_DICT
+ _matched_objects = None
+ _matched_rows = None
+ _refresh_identity_token = None
+
+ @classmethod
+ def orm_pre_session_exec(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ is_reentrant_invoke,
+ ):
+ if is_reentrant_invoke:
+ return statement, execution_options
+
+ (
+ update_options,
+ execution_options,
+ ) = BulkUDCompileState.default_update_options.from_execution_options(
+ "_sa_orm_update_options",
+ {"synchronize_session"},
+ execution_options,
+ statement._execution_options,
+ )
+
+ sync = update_options._synchronize_session
+ if sync is not None:
+ if sync not in ("evaluate", "fetch", False):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for session synchronization "
+ "are 'evaluate', 'fetch', False"
+ )
+
+ bind_arguments["clause"] = statement
+ try:
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+ except KeyError:
+ assert False, "statement had 'orm' plugin but no plugin_subject"
+ else:
+ bind_arguments["mapper"] = plugin_subject.mapper
+
+ update_options += {"_subject_mapper": plugin_subject.mapper}
+
+ if update_options._autoflush:
+ session._autoflush()
+
+ statement = statement._annotate(
+ {"synchronize_session": update_options._synchronize_session}
+ )
+
+ # this stage of the execution is called before the do_orm_execute event
+ # hook. meaning for an extension like horizontal sharding, this step
+ # happens before the extension splits out into multiple backends and
+ # runs only once. if we do pre_sync_fetch, we execute a SELECT
+ # statement, which the horizontal sharding extension splits amongst the
+ # shards and combines the results together.
+
+ if update_options._synchronize_session == "evaluate":
+ update_options = cls._do_pre_synchronize_evaluate(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._synchronize_session == "fetch":
+ update_options = cls._do_pre_synchronize_fetch(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+
+ return (
+ statement,
+ util.immutabledict(execution_options).union(
+ {"_sa_orm_update_options": update_options}
+ ),
+ )
+
+ @classmethod
+ def orm_setup_cursor_result(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ ):
+
+ # this stage of the execution is called after the
+ # do_orm_execute event hook. meaning for an extension like
+ # horizontal sharding, this step happens *within* the horizontal
+ # sharding event handler which calls session.execute() re-entrantly
+ # and will occur for each backend individually.
+ # the sharding extension then returns its own merged result from the
+ # individual ones we return here.
+
+ update_options = execution_options["_sa_orm_update_options"]
+ if update_options._synchronize_session == "evaluate":
+ cls._do_post_synchronize_evaluate(session, result, update_options)
+ elif update_options._synchronize_session == "fetch":
+ cls._do_post_synchronize_fetch(session, result, update_options)
+
+ return result
+
+ @classmethod
+ def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
+ """Apply extra criteria filtering.
+
+ For all distinct single-table-inheritance mappers represented in the
+ table being updated or deleted, produce additional WHERE criteria such
+ that only the appropriate subtypes are selected from the total results.
+
+ Additionally, add WHERE criteria originating from LoaderCriteriaOptions
+ collected from the statement.
+
+ """
+
+ return_crit = ()
+
+ adapter = ext_info._adapter if ext_info.is_aliased_class else None
+
+ if (
+ "additional_entity_criteria",
+ ext_info.mapper,
+ ) in global_attributes:
+ return_crit += tuple(
+ ae._resolve_where_criteria(ext_info)
+ for ae in global_attributes[
+ ("additional_entity_criteria", ext_info.mapper)
+ ]
+ if ae.include_aliases or ae.entity is ext_info
+ )
+
+ if ext_info.mapper._single_table_criterion is not None:
+ return_crit += (ext_info.mapper._single_table_criterion,)
+
+ if adapter:
+ return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
+
+ return return_crit
+
+ @classmethod
+ def _do_pre_synchronize_evaluate(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ ):
+ mapper = update_options._subject_mapper
+ target_cls = mapper.class_
+
+ value_evaluators = resolved_keys_as_propnames = _EMPTY_DICT
+
+ try:
+ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+ crit = ()
+ if statement._where_criteria:
+ crit += statement._where_criteria
+
+ global_attributes = {}
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(global_attributes)
+
+ if global_attributes:
+ crit += cls._adjust_for_extra_criteria(
+ global_attributes, mapper
+ )
+
+ if crit:
+ eval_condition = evaluator_compiler.process(*crit)
+ else:
+
+ def eval_condition(obj):
+ return True
+
+ except evaluator.UnevaluatableError as err:
+ util.raise_(
+ sa_exc.InvalidRequestError(
+ 'Could not evaluate current criteria in Python: "%s". '
+ "Specify 'fetch' or False for the "
+ "synchronize_session execution option." % err
+ ),
+ from_=err,
+ )
+
+ if statement.__visit_name__ == "lambda_element":
+ # ._resolved is called on every LambdaElement in order to
+ # generate the cache key, so this access does not add
+ # additional expense
+ effective_statement = statement._resolved
+ else:
+ effective_statement = statement
+
+ if effective_statement.__visit_name__ == "update":
+ resolved_values = cls._get_resolved_values(
+ mapper, effective_statement
+ )
+ value_evaluators = {}
+ resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+ mapper, resolved_values
+ )
+ for key, value in resolved_keys_as_propnames:
+ try:
+ _evaluator = evaluator_compiler.process(
+ coercions.expect(roles.ExpressionElementRole, value)
+ )
+ except evaluator.UnevaluatableError:
+ pass
+ else:
+ value_evaluators[key] = _evaluator
+
+ # TODO: detect when the where clause is a trivial primary key match.
+ matched_objects = [
+ state.obj()
+ for state in session.identity_map.all_states()
+ if state.mapper.isa(mapper)
+ and not state.expired
+ and eval_condition(state.obj())
+ and (
+ update_options._refresh_identity_token is None
+ # TODO: coverage for the case where horizontal sharding
+ # invokes an update() or delete() given an explicit identity
+ # token up front
+ or state.identity_token
+ == update_options._refresh_identity_token
+ )
+ ]
+ return update_options + {
+ "_matched_objects": matched_objects,
+ "_value_evaluators": value_evaluators,
+ "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ }
+
+ @classmethod
+ def _get_resolved_values(cls, mapper, statement):
+ if statement._multi_values:
+ return []
+ elif statement._ordered_values:
+ return list(statement._ordered_values)
+ elif statement._values:
+ return list(statement._values.items())
+ else:
+ return []
+
+ @classmethod
+ def _resolved_keys_as_propnames(cls, mapper, resolved_values):
+ values = []
+ for k, v in resolved_values:
+ if isinstance(k, attributes.QueryableAttribute):
+ values.append((k.key, v))
+ continue
+ elif hasattr(k, "__clause_element__"):
+ k = k.__clause_element__()
+
+ if mapper and isinstance(k, expression.ColumnElement):
+ try:
+ attr = mapper._columntoproperty[k]
+ except orm_exc.UnmappedColumnError:
+ pass
+ else:
+ values.append((attr.key, v))
+ else:
+ raise sa_exc.InvalidRequestError(
+ "Invalid expression type: %r" % k
+ )
+ return values
+
+ @classmethod
+ def _do_pre_synchronize_fetch(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ ):
+ mapper = update_options._subject_mapper
+
+ select_stmt = (
+ select(*(mapper.primary_key + (mapper.select_identity_token,)))
+ .select_from(mapper)
+ .options(*statement._with_options)
+ )
+ select_stmt._where_criteria = statement._where_criteria
+
+ def skip_for_full_returning(orm_context):
+ bind = orm_context.session.get_bind(**orm_context.bind_arguments)
+ if bind.dialect.full_returning:
+ return _result.null_result()
+ else:
+ return None
+
+ result = session.execute(
+ select_stmt,
+ params,
+ execution_options,
+ bind_arguments,
+ _add_event=skip_for_full_returning,
+ )
+ matched_rows = result.fetchall()
+
+ value_evaluators = _EMPTY_DICT
+
+ if statement.__visit_name__ == "lambda_element":
+ # ._resolved is called on every LambdaElement in order to
+ # generate the cache key, so this access does not add
+ # additional expense
+ effective_statement = statement._resolved
+ else:
+ effective_statement = statement
+
+ if effective_statement.__visit_name__ == "update":
+ target_cls = mapper.class_
+ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+ resolved_values = cls._get_resolved_values(
+ mapper, effective_statement
+ )
+ resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+ mapper, resolved_values
+ )
+ value_evaluators = {}
+ for key, value in resolved_keys_as_propnames:
+ try:
+ _evaluator = evaluator_compiler.process(
+ coercions.expect(roles.ExpressionElementRole, value)
+ )
+ except evaluator.UnevaluatableError:
+ pass
+ else:
+ value_evaluators[key] = _evaluator
+
+ else:
+ resolved_keys_as_propnames = _EMPTY_DICT
+
+ return update_options + {
+ "_value_evaluators": value_evaluators,
+ "_matched_rows": matched_rows,
+ "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ }
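+
+    # Illustrative usage sketch only, not part of this module: the
+    # "evaluate" / "fetch" / False strategies handled by the two
+    # _do_pre_synchronize_* methods above are selected via the
+    # synchronize_session execution option.  'session' and the mapped class
+    # 'User' are assumed to exist:
+    #
+    #     from sqlalchemy import update
+    #
+    #     session.execute(
+    #         update(User).where(User.name == "old").values(name="new"),
+    #         execution_options={"synchronize_session": "fetch"},
+    #     )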
+
+
+class ORMDMLState:
+ @classmethod
+ def get_entity_description(cls, statement):
+ ext_info = statement.table._annotations["parententity"]
+ mapper = ext_info.mapper
+ if ext_info.is_aliased_class:
+ _label_name = ext_info.name
+ else:
+ _label_name = mapper.class_.__name__
+
+ return {
+ "name": _label_name,
+ "type": mapper.class_,
+ "expr": ext_info.entity,
+ "entity": ext_info.entity,
+ "table": mapper.local_table,
+ }
+
+ @classmethod
+ def get_returning_column_descriptions(cls, statement):
+ def _ent_for_col(c):
+ return c._annotations.get("parententity", None)
+
+ def _attr_for_col(c, ent):
+ if ent is None:
+ return c
+ proxy_key = c._annotations.get("proxy_key", None)
+ if not proxy_key:
+ return c
+ else:
+ return getattr(ent.entity, proxy_key, c)
+
+ return [
+ {
+ "name": c.key,
+ "type": c.type,
+ "expr": _attr_for_col(c, ent),
+ "aliased": ent.is_aliased_class,
+ "entity": ent.entity,
+ }
+ for c, ent in [
+ (c, _ent_for_col(c)) for c in statement._all_selected_columns
+ ]
+ ]
+
+
+@CompileState.plugin_for("orm", "insert")
+class ORMInsert(ORMDMLState, InsertDMLState):
+ @classmethod
+ def orm_pre_session_exec(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ is_reentrant_invoke,
+ ):
+ bind_arguments["clause"] = statement
+ try:
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+ except KeyError:
+ assert False, "statement had 'orm' plugin but no plugin_subject"
+ else:
+ bind_arguments["mapper"] = plugin_subject.mapper
+
+ return (
+ statement,
+ util.immutabledict(execution_options),
+ )
+
+ @classmethod
+ def orm_setup_cursor_result(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ ):
+ return result
+
+
+@CompileState.plugin_for("orm", "update")
+class BulkORMUpdate(ORMDMLState, UpdateDMLState, BulkUDCompileState):
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw):
+
+ self = cls.__new__(cls)
+
+ ext_info = statement.table._annotations["parententity"]
+
+ self.mapper = mapper = ext_info.mapper
+
+ self.extra_criteria_entities = {}
+
+ self._resolved_values = cls._get_resolved_values(mapper, statement)
+
+ extra_criteria_attributes = {}
+
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(extra_criteria_attributes)
+
+ if not statement._preserve_parameter_order and statement._values:
+ self._resolved_values = dict(self._resolved_values)
+
+ new_stmt = sql.Update.__new__(sql.Update)
+ new_stmt.__dict__.update(statement.__dict__)
+ new_stmt.table = mapper.local_table
+
+ # note if the statement has _multi_values, these
+ # are passed through to the new statement, which will then raise
+ # InvalidRequestError because UPDATE doesn't support multi_values
+ # right now.
+ if statement._ordered_values:
+ new_stmt._ordered_values = self._resolved_values
+ elif statement._values:
+ new_stmt._values = self._resolved_values
+
+ new_crit = cls._adjust_for_extra_criteria(
+ extra_criteria_attributes, mapper
+ )
+ if new_crit:
+ new_stmt = new_stmt.where(*new_crit)
+
+ # if we are against a lambda statement we might not be the
+ # topmost object that received per-execute annotations
+
+ if (
+ compiler._annotations.get("synchronize_session", None) == "fetch"
+ and compiler.dialect.full_returning
+ ):
+ if new_stmt._returning:
+ raise sa_exc.InvalidRequestError(
+ "Can't use synchronize_session='fetch' "
+ "with explicit returning()"
+ )
+ new_stmt = new_stmt.returning(*mapper.primary_key)
+
+ UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
+
+ return self
+
+ @classmethod
+ def _get_crud_kv_pairs(cls, statement, kv_iterator):
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+
+ core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
+
+ if not plugin_subject or not plugin_subject.mapper:
+ return core_get_crud_kv_pairs(statement, kv_iterator)
+
+ mapper = plugin_subject.mapper
+
+ values = []
+
+ for k, v in kv_iterator:
+ k = coercions.expect(roles.DMLColumnRole, k)
+
+ if isinstance(k, util.string_types):
+ desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
+ if desc is NO_VALUE:
+ values.append(
+ (
+ k,
+ coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=sqltypes.NullType(),
+ is_crud=True,
+ ),
+ )
+ )
+ else:
+ values.extend(
+ core_get_crud_kv_pairs(
+ statement, desc._bulk_update_tuples(v)
+ )
+ )
+ elif "entity_namespace" in k._annotations:
+ k_anno = k._annotations
+ attr = _entity_namespace_key(
+ k_anno["entity_namespace"], k_anno["proxy_key"]
+ )
+ values.extend(
+ core_get_crud_kv_pairs(
+ statement, attr._bulk_update_tuples(v)
+ )
+ )
+ else:
+ values.append(
+ (
+ k,
+ coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=sqltypes.NullType(),
+ is_crud=True,
+ ),
+ )
+ )
+ return values
+
+ @classmethod
+ def _do_post_synchronize_evaluate(cls, session, result, update_options):
+
+ states = set()
+ evaluated_keys = list(update_options._value_evaluators.keys())
+ values = update_options._resolved_keys_as_propnames
+ attrib = set(k for k, v in values)
+ for obj in update_options._matched_objects:
+
+ state, dict_ = (
+ attributes.instance_state(obj),
+ attributes.instance_dict(obj),
+ )
+
+ # the evaluated states were gathered across all identity tokens.
+ # however the post_sync events are called per identity token,
+ # so filter.
+ if (
+ update_options._refresh_identity_token is not None
+ and state.identity_token
+ != update_options._refresh_identity_token
+ ):
+ continue
+
+ # only evaluate unmodified attributes
+ to_evaluate = state.unmodified.intersection(evaluated_keys)
+ for key in to_evaluate:
+ if key in dict_:
+ dict_[key] = update_options._value_evaluators[key](obj)
+
+ state.manager.dispatch.refresh(state, None, to_evaluate)
+
+ state._commit(dict_, list(to_evaluate))
+
+ to_expire = attrib.intersection(dict_).difference(to_evaluate)
+ if to_expire:
+ state._expire_attributes(dict_, to_expire)
+
+ states.add(state)
+ session._register_altered(states)
+
+ @classmethod
+ def _do_post_synchronize_fetch(cls, session, result, update_options):
+ target_mapper = update_options._subject_mapper
+
+ states = set()
+ evaluated_keys = list(update_options._value_evaluators.keys())
+
+ if result.returns_rows:
+ matched_rows = [
+ tuple(row) + (update_options._refresh_identity_token,)
+ for row in result.all()
+ ]
+ else:
+ matched_rows = update_options._matched_rows
+
+ objs = [
+ session.identity_map[identity_key]
+ for identity_key in [
+ target_mapper.identity_key_from_primary_key(
+ list(primary_key),
+ identity_token=identity_token,
+ )
+ for primary_key, identity_token in [
+ (row[0:-1], row[-1]) for row in matched_rows
+ ]
+ if update_options._refresh_identity_token is None
+ or identity_token == update_options._refresh_identity_token
+ ]
+ if identity_key in session.identity_map
+ ]
+
+ values = update_options._resolved_keys_as_propnames
+ attrib = set(k for k, v in values)
+
+ for obj in objs:
+ state, dict_ = (
+ attributes.instance_state(obj),
+ attributes.instance_dict(obj),
+ )
+
+ to_evaluate = state.unmodified.intersection(evaluated_keys)
+ for key in to_evaluate:
+ if key in dict_:
+ dict_[key] = update_options._value_evaluators[key](obj)
+ state.manager.dispatch.refresh(state, None, to_evaluate)
+
+ state._commit(dict_, list(to_evaluate))
+
+ to_expire = attrib.intersection(dict_).difference(to_evaluate)
+ if to_expire:
+ state._expire_attributes(dict_, to_expire)
+
+ states.add(state)
+ session._register_altered(states)
+
+
+@CompileState.plugin_for("orm", "delete")
+class BulkORMDelete(ORMDMLState, DeleteDMLState, BulkUDCompileState):
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw):
+ self = cls.__new__(cls)
+
+ ext_info = statement.table._annotations["parententity"]
+ self.mapper = mapper = ext_info.mapper
+
+ self.extra_criteria_entities = {}
+
+ extra_criteria_attributes = {}
+
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(extra_criteria_attributes)
+
+ new_crit = cls._adjust_for_extra_criteria(
+ extra_criteria_attributes, mapper
+ )
+ if new_crit:
+ statement = statement.where(*new_crit)
+
+ if (
+ mapper
+ and compiler._annotations.get("synchronize_session", None)
+ == "fetch"
+ and compiler.dialect.full_returning
+ ):
+ statement = statement.returning(*mapper.primary_key)
+
+ DeleteDMLState.__init__(self, statement, compiler, **kw)
+
+ return self
+
+ @classmethod
+ def _do_post_synchronize_evaluate(cls, session, result, update_options):
+
+ session._remove_newly_deleted(
+ [
+ attributes.instance_state(obj)
+ for obj in update_options._matched_objects
+ ]
+ )
+
+ @classmethod
+ def _do_post_synchronize_fetch(cls, session, result, update_options):
+ target_mapper = update_options._subject_mapper
+
+ if result.returns_rows:
+ matched_rows = [
+ tuple(row) + (update_options._refresh_identity_token,)
+ for row in result.all()
+ ]
+ else:
+ matched_rows = update_options._matched_rows
+
+ for row in matched_rows:
+ primary_key = row[0:-1]
+ identity_token = row[-1]
+
+ # TODO: inline this and call remove_newly_deleted
+ # once
+ identity_key = target_mapper.identity_key_from_primary_key(
+ list(primary_key),
+ identity_token=identity_token,
+ )
+ if identity_key in session.identity_map:
+ session._remove_newly_deleted(
+ [
+ attributes.instance_state(
+ session.identity_map[identity_key]
+ )
+ ]
+ )
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
new file mode 100644
index 0000000..d32af17
--- /dev/null
+++ b/lib/sqlalchemy/orm/properties.py
@@ -0,0 +1,430 @@
+# orm/properties.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""MapperProperty implementations.
+
+This is a private module which defines the behavior of individual ORM-
+mapped attributes.
+
+"""
+from __future__ import absolute_import
+
+from . import attributes
+from .descriptor_props import CompositeProperty
+from .descriptor_props import ConcreteInheritedProperty
+from .descriptor_props import SynonymProperty
+from .interfaces import PropComparator
+from .interfaces import StrategizedProperty
+from .relationships import RelationshipProperty
+from .. import log
+from .. import util
+from ..sql import coercions
+from ..sql import roles
+
+
+__all__ = [
+ "ColumnProperty",
+ "CompositeProperty",
+ "ConcreteInheritedProperty",
+ "RelationshipProperty",
+ "SynonymProperty",
+]
+
+
+@log.class_logger
+class ColumnProperty(StrategizedProperty):
+ """Describes an object attribute that corresponds to a table column.
+
+ Public constructor is the :func:`_orm.column_property` function.
+
+ """
+
+ strategy_wildcard_key = "column"
+ inherit_cache = True
+ _links_to_entity = False
+
+ __slots__ = (
+ "columns",
+ "group",
+ "deferred",
+ "instrument",
+ "comparator_factory",
+ "descriptor",
+ "active_history",
+ "expire_on_flush",
+ "info",
+ "doc",
+ "strategy_key",
+ "_creation_order",
+ "_is_polymorphic_discriminator",
+ "_mapped_by_synonym",
+ "_deferred_column_loader",
+ "_raise_column_loader",
+ "_renders_in_subqueries",
+ "raiseload",
+ )
+
+ def __init__(self, *columns, **kwargs):
+ r"""Provide a column-level property for use with a mapping.
+
+ Column-based properties can normally be applied to the mapper's
+ ``properties`` dictionary using the :class:`_schema.Column`
+ element directly.
+ Use this function when the given column is not directly present within
+ the mapper's selectable; examples include SQL expressions, functions,
+ and scalar SELECT queries.
+
+ The :func:`_orm.column_property` function returns an instance of
+ :class:`.ColumnProperty`.
+
+ Columns that aren't present in the mapper's selectable won't be
+ persisted by the mapper and are effectively "read-only" attributes.
+
+ :param \*cols:
+ list of Column objects to be mapped.
+
+ :param active_history=False:
+ When ``True``, indicates that the "previous" value for a
+ scalar attribute should be loaded when replaced, if not
+ already loaded. Normally, history tracking logic for
+ simple non-primary-key scalar values only needs to be
+ aware of the "new" value in order to perform a flush. This
+ flag is available for applications that make use of
+ :func:`.attributes.get_history` or :meth:`.Session.is_modified`
+ which also need to know
+ the "previous" value of the attribute.
+
+ :param comparator_factory: a class which extends
+ :class:`.ColumnProperty.Comparator` which provides custom SQL
+ clause generation for comparison operations.
+
+ :param group:
+ a group name for this property when marked as deferred.
+
+ :param deferred:
+ when True, the column property is "deferred", meaning that
+ it does not load immediately, and is instead loaded when the
+ attribute is first accessed on an instance. See also
+ :func:`~sqlalchemy.orm.deferred`.
+
+ :param doc:
+ optional string that will be applied as the doc on the
+ class-bound descriptor.
+
+ :param expire_on_flush=True:
+ Disable expiry on flush. A column_property() which refers
+ to a SQL expression (and not a single table-bound column)
+ is considered to be a "read only" property; populating it
+ has no effect on the state of data, and it can only return
+ database state. For this reason a column_property()'s value
+ is expired whenever the parent object is involved in a
+ flush, that is, has any kind of "dirty" state within a flush.
+ Setting this parameter to ``False`` will have the effect of
+ leaving any existing value present after the flush proceeds.
+            Note, however, that a :class:`.Session` with default expiration
+            settings will still expire
+            all attributes after a :meth:`.Session.commit` call.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.MapperProperty.info` attribute of this object.
+
+ :param raiseload: if True, indicates the column should raise an error
+ when undeferred, rather than loading the value. This can be
+ altered at query time by using the :func:`.deferred` option with
+ raiseload=False.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`deferred_raiseload`
+
+ .. seealso::
+
+ :ref:`column_property_options` - to map columns while including
+ mapping options
+
+ :ref:`mapper_column_property_sql_expressions` - to map SQL
+ expressions
+
+ """
+ super(ColumnProperty, self).__init__()
+ self.columns = [
+ coercions.expect(roles.LabeledColumnExprRole, c) for c in columns
+ ]
+ self.group = kwargs.pop("group", None)
+ self.deferred = kwargs.pop("deferred", False)
+ self.raiseload = kwargs.pop("raiseload", False)
+ self.instrument = kwargs.pop("_instrument", True)
+ self.comparator_factory = kwargs.pop(
+ "comparator_factory", self.__class__.Comparator
+ )
+ self.descriptor = kwargs.pop("descriptor", None)
+ self.active_history = kwargs.pop("active_history", False)
+ self.expire_on_flush = kwargs.pop("expire_on_flush", True)
+
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
+
+ if "doc" in kwargs:
+ self.doc = kwargs.pop("doc")
+ else:
+ for col in reversed(self.columns):
+ doc = getattr(col, "doc", None)
+ if doc is not None:
+ self.doc = doc
+ break
+ else:
+ self.doc = None
+
+ if kwargs:
+ raise TypeError(
+ "%s received unexpected keyword argument(s): %s"
+ % (self.__class__.__name__, ", ".join(sorted(kwargs.keys())))
+ )
+
+ util.set_creation_order(self)
+
+ self.strategy_key = (
+ ("deferred", self.deferred),
+ ("instrument", self.instrument),
+ )
+ if self.raiseload:
+ self.strategy_key += (("raiseload", True),)
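+
+    # Illustrative usage sketch only, not part of this module: a typical
+    # declarative use of the public column_property() function documented
+    # above.  'Base' is assumed to be a declarative base and 'User' is a
+    # hypothetical mapped class:
+    #
+    #     from sqlalchemy import Column, Integer, String
+    #     from sqlalchemy.orm import column_property
+    #
+    #     class User(Base):
+    #         __tablename__ = "user"
+    #         id = Column(Integer, primary_key=True)
+    #         firstname = Column(String(50))
+    #         lastname = Column(String(50))
+    #         fullname = column_property(firstname + " " + lastname)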
+
+ def _memoized_attr__renders_in_subqueries(self):
+ return ("deferred", True) not in self.strategy_key or (
+ self not in self.parent._readonly_props
+ )
+
+ @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies")
+ def _memoized_attr__deferred_column_loader(self):
+ state = util.preloaded.orm_state
+ strategies = util.preloaded.orm_strategies
+ return state.InstanceState._instance_level_callable_processor(
+ self.parent.class_manager,
+ strategies.LoadDeferredColumns(self.key),
+ self.key,
+ )
+
+ @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies")
+ def _memoized_attr__raise_column_loader(self):
+ state = util.preloaded.orm_state
+ strategies = util.preloaded.orm_strategies
+ return state.InstanceState._instance_level_callable_processor(
+ self.parent.class_manager,
+ strategies.LoadDeferredColumns(self.key, True),
+ self.key,
+ )
+
+ def __clause_element__(self):
+        """Allow the ColumnProperty to work in an expression context before
+        it is turned into an instrumented attribute.
+ """
+
+ return self.expression
+
+ @property
+ def expression(self):
+ """Return the primary column or expression for this ColumnProperty.
+
+ E.g.::
+
+
+ class File(Base):
+ # ...
+
+ name = Column(String(64))
+ extension = Column(String(8))
+ filename = column_property(name + '.' + extension)
+ path = column_property('C:/' + filename.expression)
+
+ .. seealso::
+
+ :ref:`mapper_column_property_sql_expressions_composed`
+
+ """
+ return self.columns[0]
+
+ def instrument_class(self, mapper):
+ if not self.instrument:
+ return
+
+ attributes.register_descriptor(
+ mapper.class_,
+ self.key,
+ comparator=self.comparator_factory(self, mapper),
+ parententity=mapper,
+ doc=self.doc,
+ )
+
+ def do_init(self):
+ super(ColumnProperty, self).do_init()
+
+ if len(self.columns) > 1 and set(self.parent.primary_key).issuperset(
+ self.columns
+ ):
+ util.warn(
+ (
+ "On mapper %s, primary key column '%s' is being combined "
+ "with distinct primary key column '%s' in attribute '%s'. "
+ "Use explicit properties to give each column its own "
+ "mapped attribute name."
+ )
+ % (self.parent, self.columns[1], self.columns[0], self.key)
+ )
+
+ def copy(self):
+ return ColumnProperty(
+ deferred=self.deferred,
+ group=self.group,
+ active_history=self.active_history,
+ *self.columns
+ )
+
+ def _getcommitted(
+ self, state, dict_, column, passive=attributes.PASSIVE_OFF
+ ):
+ return state.get_impl(self.key).get_committed_value(
+ state, dict_, passive=passive
+ )
+
+ def merge(
+ self,
+ session,
+ source_state,
+ source_dict,
+ dest_state,
+ dest_dict,
+ load,
+ _recursive,
+ _resolve_conflict_map,
+ ):
+ if not self.instrument:
+ return
+ elif self.key in source_dict:
+ value = source_dict[self.key]
+
+ if not load:
+ dest_dict[self.key] = value
+ else:
+ impl = dest_state.get_impl(self.key)
+ impl.set(dest_state, dest_dict, value, None)
+ elif dest_state.has_identity and self.key not in dest_dict:
+ dest_state._expire_attributes(
+ dest_dict, [self.key], no_loader=True
+ )
+
+ class Comparator(util.MemoizedSlots, PropComparator):
+ """Produce boolean, comparison, and other operators for
+ :class:`.ColumnProperty` attributes.
+
+ See the documentation for :class:`.PropComparator` for a brief
+ overview.
+
+ .. seealso::
+
+ :class:`.PropComparator`
+
+ :class:`.ColumnOperators`
+
+ :ref:`types_operators`
+
+ :attr:`.TypeEngine.comparator_factory`
+
+ """
+
+ __slots__ = "__clause_element__", "info", "expressions"
+
+ def _orm_annotate_column(self, column):
+ """annotate and possibly adapt a column to be returned
+ as the mapped-attribute exposed version of the column.
+
+ The column in this context needs to act as much like the
+ column in an ORM mapped context as possible, so includes
+ annotations to give hints to various ORM functions as to
+ the source entity of this column. It also adapts it
+ to the mapper's with_polymorphic selectable if one is
+ present.
+
+ """
+
+ pe = self._parententity
+ annotations = {
+ "entity_namespace": pe,
+ "parententity": pe,
+ "parentmapper": pe,
+ "proxy_key": self.prop.key,
+ }
+
+ col = column
+
+ # for a mapper with polymorphic_on and an adapter, return
+ # the column against the polymorphic selectable.
+ # see also orm.util._orm_downgrade_polymorphic_columns
+ # for the reverse operation.
+ if self._parentmapper._polymorphic_adapter:
+ mapper_local_col = col
+ col = self._parentmapper._polymorphic_adapter.traverse(col)
+
+ # this is a clue to the ORM Query etc. that this column
+ # was adapted to the mapper's polymorphic_adapter. the
+                # ORM uses this hint to know which column it's adapting.
+ annotations["adapt_column"] = mapper_local_col
+
+ return col._annotate(annotations)._set_propagate_attrs(
+ {"compile_state_plugin": "orm", "plugin_subject": pe}
+ )
+
+ def _memoized_method___clause_element__(self):
+ if self.adapter:
+ return self.adapter(self.prop.columns[0], self.prop.key)
+ else:
+ return self._orm_annotate_column(self.prop.columns[0])
+
+ def _memoized_attr_info(self):
+ """The .info dictionary for this attribute."""
+
+ ce = self.__clause_element__()
+ try:
+ return ce.info
+ except AttributeError:
+ return self.prop.info
+
+ def _memoized_attr_expressions(self):
+ """The full sequence of columns referenced by this
+ attribute, adjusted for any aliasing in progress.
+
+ .. versionadded:: 1.3.17
+
+ """
+ if self.adapter:
+ return [
+ self.adapter(col, self.prop.key)
+ for col in self.prop.columns
+ ]
+ else:
+ return [
+ self._orm_annotate_column(col) for col in self.prop.columns
+ ]
+
+ def _fallback_getattr(self, key):
+ """proxy attribute access down to the mapped column.
+
+ this allows user-defined comparison methods to be accessed.
+ """
+ return getattr(self.__clause_element__(), key)
+
+ def operate(self, op, *other, **kwargs):
+ return op(self.__clause_element__(), *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ col = self.__clause_element__()
+ return op(col._bind_param(op, other), col, **kwargs)
+
+ def __str__(self):
+ return str(self.parent.class_.__name__) + "." + self.key
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
new file mode 100644
index 0000000..99e4591
--- /dev/null
+++ b/lib/sqlalchemy/orm/query.py
@@ -0,0 +1,3508 @@
+# orm/query.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""The Query class and support.
+
+Defines the :class:`_query.Query` class, the central
+construct used by the ORM to construct database queries.
+
+The :class:`_query.Query` class should not be confused with the
+:class:`_expression.Select` class, which defines database
+SELECT operations at the SQL (non-ORM) level. ``Query`` differs from
+``Select`` in that it returns ORM-mapped objects and interacts with an
+ORM session, whereas the ``Select`` construct interacts directly with the
+database to return iterable result sets.
+
+"""
+import itertools
+import operator
+import types
+
+from . import exc as orm_exc
+from . import interfaces
+from . import loading
+from . import util as orm_util
+from .base import _assertions
+from .context import _column_descriptions
+from .context import _legacy_determine_last_joined_entity
+from .context import _legacy_filter_by_entity_zero
+from .context import FromStatement
+from .context import LABEL_STYLE_LEGACY_ORM
+from .context import ORMCompileState
+from .context import ORMFromStatementCompileState
+from .context import QueryContext
+from .interfaces import ORMColumnsClauseRole
+from .util import aliased
+from .util import AliasedClass
+from .util import object_mapper
+from .util import with_parent
+from .util import with_polymorphic
+from .. import exc as sa_exc
+from .. import inspect
+from .. import inspection
+from .. import log
+from .. import sql
+from .. import util
+from ..sql import coercions
+from ..sql import elements
+from ..sql import expression
+from ..sql import roles
+from ..sql import Select
+from ..sql import util as sql_util
+from ..sql import visitors
+from ..sql.annotation import SupportsCloneAnnotations
+from ..sql.base import _entity_namespace_key
+from ..sql.base import _generative
+from ..sql.base import Executable
+from ..sql.selectable import _MemoizedSelectEntities
+from ..sql.selectable import _SelectFromElements
+from ..sql.selectable import ForUpdateArg
+from ..sql.selectable import GroupedElement
+from ..sql.selectable import HasHints
+from ..sql.selectable import HasPrefixes
+from ..sql.selectable import HasSuffixes
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..sql.selectable import SelectBase
+from ..sql.selectable import SelectStatementGrouping
+from ..sql.visitors import InternalTraversal
+from ..util import collections_abc
+
+__all__ = ["Query", "QueryContext", "aliased"]
+
+
+@inspection._self_inspects
+@log.class_logger
+class Query(
+ _SelectFromElements,
+ SupportsCloneAnnotations,
+ HasPrefixes,
+ HasSuffixes,
+ HasHints,
+ Executable,
+):
+
+ """ORM-level SQL construction object.
+
+ :class:`_query.Query`
+ is the source of all SELECT statements generated by the
+ ORM, both those formulated by end-user query operations as well as by
+ high level internal operations such as related collection loading. It
+ features a generative interface whereby successive calls return a new
+ :class:`_query.Query` object, a copy of the former with additional
+ criteria and options associated with it.
+
+ :class:`_query.Query` objects are normally initially generated using the
+ :meth:`~.Session.query` method of :class:`.Session`, and in
+ less common cases by instantiating the :class:`_query.Query` directly and
+ associating with a :class:`.Session` using the
+ :meth:`_query.Query.with_session`
+ method.
+
+ For a full walk through of :class:`_query.Query` usage, see the
+ :ref:`ormtutorial_toplevel`.
+
+ """
+
+ # elements that are in Core and can be cached in the same way
+ _where_criteria = ()
+ _having_criteria = ()
+
+ _order_by_clauses = ()
+ _group_by_clauses = ()
+ _limit_clause = None
+ _offset_clause = None
+
+ _distinct = False
+ _distinct_on = ()
+
+ _for_update_arg = None
+ _correlate = ()
+ _auto_correlate = True
+ _from_obj = ()
+ _setup_joins = ()
+ _legacy_setup_joins = ()
+ _label_style = LABEL_STYLE_LEGACY_ORM
+
+ _memoized_select_entities = ()
+
+ _compile_options = ORMCompileState.default_compile_options
+
+ load_options = QueryContext.default_load_options + {
+ "_legacy_uniquing": True
+ }
+
+ _params = util.EMPTY_DICT
+
+ # local Query builder state, not needed for
+ # compilation or execution
+ _aliased_generation = None
+ _enable_assertions = True
+ _last_joined_entity = None
+ _statement = None
+
+ # mirrors that of ClauseElement, used to propagate the "orm"
+ # plugin as well as the "subject" of the plugin, e.g. the mapper
+ # we are querying against.
+ _propagate_attrs = util.immutabledict()
+
+ def __init__(self, entities, session=None):
+ """Construct a :class:`_query.Query` directly.
+
+ E.g.::
+
+ q = Query([User, Address], session=some_session)
+
+ The above is equivalent to::
+
+ q = some_session.query(User, Address)
+
+ :param entities: a sequence of entities and/or SQL expressions.
+
+ :param session: a :class:`.Session` with which the
+ :class:`_query.Query`
+ will be associated. Optional; a :class:`_query.Query`
+ can be associated
+ with a :class:`.Session` generatively via the
+ :meth:`_query.Query.with_session` method as well.
+
+ .. seealso::
+
+ :meth:`.Session.query`
+
+ :meth:`_query.Query.with_session`
+
+ """
+
+ self.session = session
+ self._set_entities(entities)
+
+ def _set_propagate_attrs(self, values):
+ self._propagate_attrs = util.immutabledict(values)
+ return self
+
+ def _set_entities(self, entities):
+ self._raw_columns = [
+ coercions.expect(
+ roles.ColumnsClauseRole,
+ ent,
+ apply_propagate_attrs=self,
+ post_inspect=True,
+ )
+ for ent in util.to_list(entities)
+ ]
+
+ def _entity_from_pre_ent_zero(self):
+ if not self._raw_columns:
+ return None
+
+ ent = self._raw_columns[0]
+
+ if "parententity" in ent._annotations:
+ return ent._annotations["parententity"]
+ elif isinstance(ent, ORMColumnsClauseRole):
+ return ent.entity
+ elif "bundle" in ent._annotations:
+ return ent._annotations["bundle"]
+ else:
+ # label, other SQL expression
+ for element in visitors.iterate(ent):
+ if "parententity" in element._annotations:
+ return element._annotations["parententity"]
+ else:
+ return None
+
+ def _only_full_mapper_zero(self, methname):
+ if (
+ len(self._raw_columns) != 1
+ or "parententity" not in self._raw_columns[0]._annotations
+ or not self._raw_columns[0].is_selectable
+ ):
+ raise sa_exc.InvalidRequestError(
+ "%s() can only be used against "
+ "a single mapped class." % methname
+ )
+
+ return self._raw_columns[0]._annotations["parententity"]
+
+ def _set_select_from(self, obj, set_base_alias):
+ fa = [
+ coercions.expect(
+ roles.StrictFromClauseRole,
+ elem,
+ allow_select=True,
+ apply_propagate_attrs=self,
+ )
+ for elem in obj
+ ]
+
+ self._compile_options += {"_set_base_alias": set_base_alias}
+ self._from_obj = tuple(fa)
+
+ @_generative
+ def _set_lazyload_from(self, state):
+ self.load_options += {"_lazy_loaded_from": state}
+
+ def _get_condition(self):
+ return self._no_criterion_condition(
+ "get", order_by=False, distinct=False
+ )
+
+ def _get_existing_condition(self):
+ self._no_criterion_assertion("get", order_by=False, distinct=False)
+
+ def _no_criterion_assertion(self, meth, order_by=True, distinct=True):
+ if not self._enable_assertions:
+ return
+ if (
+ self._where_criteria
+ or self._statement is not None
+ or self._from_obj
+ or self._legacy_setup_joins
+ or self._limit_clause is not None
+ or self._offset_clause is not None
+ or self._group_by_clauses
+ or (order_by and self._order_by_clauses)
+ or (distinct and self._distinct)
+ ):
+ raise sa_exc.InvalidRequestError(
+ "Query.%s() being called on a "
+ "Query with existing criterion. " % meth
+ )
+
+ def _no_criterion_condition(self, meth, order_by=True, distinct=True):
+ self._no_criterion_assertion(meth, order_by, distinct)
+
+ self._from_obj = self._legacy_setup_joins = ()
+ if self._statement is not None:
+ self._compile_options += {"_statement": None}
+ self._where_criteria = ()
+ self._distinct = False
+
+ self._order_by_clauses = self._group_by_clauses = ()
+
+ def _no_clauseelement_condition(self, meth):
+ if not self._enable_assertions:
+ return
+ if self._order_by_clauses:
+ raise sa_exc.InvalidRequestError(
+ "Query.%s() being called on a "
+ "Query with existing criterion. " % meth
+ )
+ self._no_criterion_condition(meth)
+
+ def _no_statement_condition(self, meth):
+ if not self._enable_assertions:
+ return
+ if self._statement is not None:
+ raise sa_exc.InvalidRequestError(
+ (
+ "Query.%s() being called on a Query with an existing full "
+ "statement - can't apply criterion."
+ )
+ % meth
+ )
+
+ def _no_limit_offset(self, meth):
+ if not self._enable_assertions:
+ return
+ if self._limit_clause is not None or self._offset_clause is not None:
+ raise sa_exc.InvalidRequestError(
+ "Query.%s() being called on a Query which already has LIMIT "
+ "or OFFSET applied. Call %s() before limit() or offset() "
+ "are applied." % (meth, meth)
+ )
+
+ @property
+ def _has_row_limiting_clause(self):
+ return (
+ self._limit_clause is not None or self._offset_clause is not None
+ )
+
+ def _get_options(
+ self,
+ populate_existing=None,
+ version_check=None,
+ only_load_props=None,
+ refresh_state=None,
+ identity_token=None,
+ ):
+ load_options = {}
+ compile_options = {}
+
+ if version_check:
+ load_options["_version_check"] = version_check
+ if populate_existing:
+ load_options["_populate_existing"] = populate_existing
+ if refresh_state:
+ load_options["_refresh_state"] = refresh_state
+ compile_options["_for_refresh_state"] = True
+ if only_load_props:
+ compile_options["_only_load_props"] = frozenset(only_load_props)
+ if identity_token:
+ load_options["_refresh_identity_token"] = identity_token
+
+ if load_options:
+ self.load_options += load_options
+ if compile_options:
+ self._compile_options += compile_options
+
+ return self
+
+ def _clone(self):
+ return self._generate()
+
+ @property
+ def statement(self):
+ """The full SELECT statement represented by this Query.
+
+ The statement by default will not have disambiguating labels
+ applied to the construct unless with_labels(True) is called
+ first.
+
+ """
+
+ # .statement can return the direct future.Select() construct here, as
+ # long as we are not using subsequent adaption features that
+ # are made against raw entities, e.g. from_self(), with_polymorphic(),
+ # select_entity_from(). If these features are being used, then
+ # the Select() we return will not have the correct .selected_columns
+ # collection and will not embed in subsequent queries correctly.
+ # We could find a way to make this collection "correct", however
+ # this would not be too different from doing the full compile as
+ # we are doing in any case, the Select() would still not have the
+ # proper state for other attributes like whereclause, order_by,
+ # and these features are all deprecated in any case.
+ #
+ # for these reasons, Query is not a Select, it remains an ORM
+ # object for which __clause_element__() must be called in order for
+ # it to provide a real expression object.
+ #
+ # from there, it starts to look much like Query itself won't be
+        # passed into the execute process and won't generate its own cache
+ # key; this will all occur in terms of the ORM-enabled Select.
+ if (
+ not self._compile_options._set_base_alias
+ and not self._compile_options._with_polymorphic_adapt_map
+ ):
+ # if we don't have legacy top level aliasing features in use
+ # then convert to a future select() directly
+ stmt = self._statement_20(for_statement=True)
+ else:
+ stmt = self._compile_state(for_statement=True).statement
+
+ if self._params:
+ stmt = stmt.params(self._params)
+
+ return stmt
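+
+    # Illustrative usage sketch only, not part of this module: 'session' and
+    # the mapped class 'User' are assumed to exist.
+    #
+    #     query = session.query(User).filter(User.id == 5)
+    #     print(query.statement)  # the underlying SELECT construct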
+
+ def _final_statement(self, legacy_query_style=True):
+ """Return the 'final' SELECT statement for this :class:`.Query`.
+
+ This is the Core-only select() that will be rendered by a complete
+ compilation of this query, and is what .statement used to return
+ in 1.3.
+
+ This method creates a complete compile state so is fairly expensive.
+
+ """
+
+ q = self._clone()
+
+ return q._compile_state(
+ use_legacy_query_style=legacy_query_style
+ ).statement
+
+ def _statement_20(self, for_statement=False, use_legacy_query_style=True):
+ # TODO: this event needs to be deprecated, as it currently applies
+ # only to ORM query and occurs at this spot that is now more
+ # or less an artificial spot
+ if self.dispatch.before_compile:
+ for fn in self.dispatch.before_compile:
+ new_query = fn(self)
+ if new_query is not None and new_query is not self:
+ self = new_query
+ if not fn._bake_ok:
+ self._compile_options += {"_bake_ok": False}
+
+ compile_options = self._compile_options
+ compile_options += {
+ "_for_statement": for_statement,
+ "_use_legacy_query_style": use_legacy_query_style,
+ }
+
+ if self._statement is not None:
+ stmt = FromStatement(self._raw_columns, self._statement)
+ stmt.__dict__.update(
+ _with_options=self._with_options,
+ _with_context_options=self._with_context_options,
+ _compile_options=compile_options,
+ _execution_options=self._execution_options,
+ _propagate_attrs=self._propagate_attrs,
+ )
+ else:
+ # Query / select() internal attributes are 99% cross-compatible
+ stmt = Select._create_raw_select(**self.__dict__)
+ stmt.__dict__.update(
+ _label_style=self._label_style,
+ _compile_options=compile_options,
+ _propagate_attrs=self._propagate_attrs,
+ )
+ stmt.__dict__.pop("session", None)
+
+ # ensure the ORM context is used to compile the statement, even
+ # if it has no ORM entities. This is so ORM-only things like
+ # _legacy_joins are picked up that wouldn't be picked up by the
+ # Core statement context
+ if "compile_state_plugin" not in stmt._propagate_attrs:
+ stmt._propagate_attrs = stmt._propagate_attrs.union(
+ {"compile_state_plugin": "orm", "plugin_subject": None}
+ )
+
+ return stmt
+
+ def subquery(
+ self,
+ name=None,
+ with_labels=False,
+ reduce_columns=False,
+ ):
+ """Return the full SELECT statement represented by
+ this :class:`_query.Query`, embedded within an
+ :class:`_expression.Alias`.
+
+ Eager JOIN generation within the query is disabled.
+
+ :param name: string name to be assigned as the alias;
+ this is passed through to :meth:`_expression.FromClause.alias`.
+ If ``None``, a name will be deterministically generated
+ at compile time.
+
+ :param with_labels: if True, :meth:`.with_labels` will be called
+ on the :class:`_query.Query` first to apply table-qualified labels
+ to all columns.
+
+ :param reduce_columns: if True,
+ :meth:`_expression.Select.reduce_columns` will
+ be called on the resulting :func:`_expression.select` construct,
+ to remove same-named columns where one also refers to the other
+ via foreign key or WHERE clause equivalence.
+
+ """
+ q = self.enable_eagerloads(False)
+ if with_labels:
+ q = q.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+
+ q = q.statement
+
+ if reduce_columns:
+ q = q.reduce_columns()
+ return q.alias(name=name)
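+
+ # Illustrative sketch, assuming mapped ``User`` / ``Address`` classes
+ # with an ``Address.user_id`` foreign key: embed the query as a
+ # subquery and join the parent entity against it:
+ #
+ #     subq = (
+ #         session.query(Address.user_id, func.count(Address.id).label("count"))
+ #         .group_by(Address.user_id)
+ #         .subquery()
+ #     )
+ #     q = session.query(User, subq.c.count).join(subq, User.id == subq.c.user_id)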
+
+ def cte(self, name=None, recursive=False, nesting=False):
+ r"""Return the full SELECT statement represented by this
+ :class:`_query.Query` represented as a common table expression (CTE).
+
+ Parameters and usage are the same as those of the
+ :meth:`_expression.SelectBase.cte` method; see that method for
+ further details.
+
+ Here is the `PostgreSQL WITH
+ RECURSIVE example
+ <https://www.postgresql.org/docs/current/static/queries-with.html>`_.
+ Note that, in this example, the ``included_parts`` cte and the
+ ``incl_alias`` alias of it are Core selectables, which
+ means the columns are accessed via the ``.c.`` attribute. The
+ ``parts_alias`` object is an :func:`_orm.aliased` instance of the
+ ``Part`` entity, so column-mapped attributes are available
+ directly::
+
+ from sqlalchemy.orm import aliased
+
+ class Part(Base):
+ __tablename__ = 'part'
+ part = Column(String, primary_key=True)
+ sub_part = Column(String, primary_key=True)
+ quantity = Column(Integer)
+
+ included_parts = session.query(
+ Part.sub_part,
+ Part.part,
+ Part.quantity).\
+ filter(Part.part=="our part").\
+ cte(name="included_parts", recursive=True)
+
+ incl_alias = aliased(included_parts, name="pr")
+ parts_alias = aliased(Part, name="p")
+ included_parts = included_parts.union_all(
+ session.query(
+ parts_alias.sub_part,
+ parts_alias.part,
+ parts_alias.quantity).\
+ filter(parts_alias.part==incl_alias.c.sub_part)
+ )
+
+ q = session.query(
+ included_parts.c.sub_part,
+ func.sum(included_parts.c.quantity).
+ label('total_quantity')
+ ).\
+ group_by(included_parts.c.sub_part)
+
+ .. seealso::
+
+ :meth:`_expression.HasCTE.cte`
+
+ """
+ return self.enable_eagerloads(False).statement.cte(
+ name=name, recursive=recursive, nesting=nesting
+ )
+
+ def label(self, name):
+ """Return the full SELECT statement represented by this
+ :class:`_query.Query`, converted
+ to a scalar subquery with a label of the given name.
+
+ Analogous to :meth:`sqlalchemy.sql.expression.SelectBase.label`.
+
+ """
+
+ return self.enable_eagerloads(False).statement.label(name)
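+
+ # Illustrative sketch, assuming mapped ``User`` / ``Address`` classes:
+ # use the labeled scalar subquery as an additional column in an
+ # enclosing query:
+ #
+ #     address_count = (
+ #         session.query(func.count(Address.id))
+ #         .filter(Address.user_id == User.id)
+ #         .correlate(User)
+ #         .label("address_count")
+ #     )
+ #     rows = session.query(User.name, address_count).all()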
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_query.Query.as_scalar` method is deprecated and will be "
+ "removed in a future release. Please refer to "
+ ":meth:`_query.Query.scalar_subquery`.",
+ )
+ def as_scalar(self):
+ """Return the full SELECT statement represented by this
+ :class:`_query.Query`, converted to a scalar subquery.
+
+ """
+ return self.scalar_subquery()
+
+ def scalar_subquery(self):
+ """Return the full SELECT statement represented by this
+ :class:`_query.Query`, converted to a scalar subquery.
+
+ Analogous to
+ :meth:`sqlalchemy.sql.expression.SelectBase.scalar_subquery`.
+
+ .. versionchanged:: 1.4 The :meth:`_query.Query.scalar_subquery`
+ method replaces the :meth:`_query.Query.as_scalar` method.
+
+ """
+
+ return self.enable_eagerloads(False).statement.scalar_subquery()
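+
+ # Illustrative sketch, assuming a mapped ``User`` class: compare a
+ # column against the scalar subquery inside filter():
+ #
+ #     max_id = session.query(func.max(User.id)).scalar_subquery()
+ #     newest_user = session.query(User).filter(User.id == max_id).one()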
+
+ @property
+ def selectable(self):
+ """Return the :class:`_expression.Select` object emitted by this
+ :class:`_query.Query`.
+
+ Used for :func:`_sa.inspect` compatibility, this is equivalent to::
+
+ query.enable_eagerloads(False).with_labels().statement
+
+ """
+ return self.__clause_element__()
+
+ def __clause_element__(self):
+ return (
+ self._with_compile_options(
+ _enable_eagerloads=False, _render_for_subquery=True
+ )
+ .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ .statement
+ )
+
+ @_generative
+ def only_return_tuples(self, value):
+ """When set to True, the query results will always be a tuple.
+
+ This is specifically for single element queries. The default is False.
+
+ .. versionadded:: 1.2.5
+
+ .. seealso::
+
+ :meth:`_query.Query.is_single_entity`
+
+ """
+ self.load_options += dict(_only_return_tuples=value)
+
+ @property
+ def is_single_entity(self):
+ """Indicates if this :class:`_query.Query`
+ returns tuples or single entities.
+
+ Returns True if this query returns a single entity for each instance
+ in its result list, and False if this query returns a tuple of entities
+ for each result.
+
+ .. versionadded:: 1.3.11
+
+ .. seealso::
+
+ :meth:`_query.Query.only_return_tuples`
+
+ """
+ return (
+ not self.load_options._only_return_tuples
+ and len(self._raw_columns) == 1
+ and "parententity" in self._raw_columns[0]._annotations
+ and isinstance(
+ self._raw_columns[0]._annotations["parententity"],
+ ORMColumnsClauseRole,
+ )
+ )
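+
+ # Illustrative sketch, assuming a mapped ``User`` class:
+ #
+ #     session.query(User).is_single_entity                           # True
+ #     session.query(User.id, User.name).is_single_entity             # False
+ #     session.query(User).only_return_tuples(True).is_single_entity  # False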
+
+ @_generative
+ def enable_eagerloads(self, value):
+ """Control whether or not eager joins and subqueries are
+ rendered.
+
+ When set to False, the returned Query will not render
+ eager joins regardless of :func:`~sqlalchemy.orm.joinedload`,
+ :func:`~sqlalchemy.orm.subqueryload` options
+ or mapper-level ``lazy='joined'``/``lazy='subquery'``
+ configurations.
+
+ This is used primarily when nesting the Query's
+ statement into a subquery or other
+ selectable, or when using :meth:`_query.Query.yield_per`.
+
+ """
+ self._compile_options += {"_enable_eagerloads": value}
+
+ @_generative
+ def _with_compile_options(self, **opt):
+ self._compile_options += opt
+
+ @util.deprecated_20(
+ ":meth:`_orm.Query.with_labels` and :meth:`_orm.Query.apply_labels`",
+ alternative="Use set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) "
+ "instead.",
+ )
+ def with_labels(self):
+ return self.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+
+ apply_labels = with_labels
+
+ @property
+ def get_label_style(self):
+ """
+ Retrieve the current label style.
+
+ .. versionadded:: 1.4
+
+ """
+ return self._label_style
+
+ def set_label_style(self, style):
+ """Apply column labels to the return value of Query.statement.
+
+ Indicates that this Query's `statement` accessor should return
+ a SELECT statement that applies labels to all columns in the
+ form <tablename>_<columnname>; this is commonly used to
+ disambiguate columns from multiple tables which have the same
+ name.
+
+ When the `Query` actually issues SQL to load rows, it always
+ uses column labeling.
+
+ .. note:: The :meth:`_query.Query.set_label_style` method *only* applies
+ to the output of :attr:`_query.Query.statement`, and *not* to any of
+ the result-row invoking systems of :class:`_query.Query` itself,
+ e.g.
+ :meth:`_query.Query.first`, :meth:`_query.Query.all`, etc.
+ To execute
+ a query using :meth:`_query.Query.set_label_style`, invoke the
+ :attr:`_query.Query.statement` using :meth:`.Session.execute`::
+
+ result = session.execute(
+ query
+ .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ .statement
+ )
+
+ .. versionadded:: 1.4
+
+ """ # noqa
+ if self._label_style is not style:
+ self = self._generate()
+ self._label_style = style
+ return self
+
+ @_generative
+ def enable_assertions(self, value):
+ """Control whether assertions are generated.
+
+ When set to False, the returned Query will
+ not assert its state before certain operations,
+ including that LIMIT/OFFSET has not been applied
+ when filter() is called, no criterion exists
+ when get() is called, and no "from_statement()"
+ exists when filter()/order_by()/group_by() etc.
+ is called. This more permissive mode is used by
+ custom Query subclasses to specify criterion or
+ other modifiers outside of the usual usage patterns.
+
+ Care should be taken to ensure that the usage
+ pattern is even possible. A statement applied
+ by from_statement() will override any criterion
+ set by filter() or order_by(), for example.
+
+ """
+ self._enable_assertions = value
+
+ @property
+ def whereclause(self):
+ """A readonly attribute which returns the current WHERE criterion for
+ this Query.
+
+ This returned value is a SQL expression construct, or ``None`` if no
+ criterion has been established.
+
+ """
+ return sql.elements.BooleanClauseList._construct_for_whereclause(
+ self._where_criteria
+ )
+
+ @_generative
+ def _with_current_path(self, path):
+ """indicate that this query applies to objects loaded
+ within a certain path.
+
+ Used by deferred loaders (see strategies.py) which transfer
+ query options from an originating query to a newly generated
+ query intended for the deferred load.
+
+ """
+ self._compile_options += {"_current_path": path}
+
+ @_generative
+ @_assertions(_no_clauseelement_condition)
+ @util.deprecated_20(
+ ":meth:`_orm.Query.with_polymorphic`",
+ alternative="Use the orm.with_polymorphic() standalone function",
+ )
+ def with_polymorphic(
+ self, cls_or_mappers, selectable=None, polymorphic_on=None
+ ):
+ """Load columns for inheriting classes.
+
+ This is a legacy method which is replaced by the
+ :func:`_orm.with_polymorphic` function.
+
+ .. warning:: The :meth:`_orm.Query.with_polymorphic` method does
+ **not** support 1.4/2.0 style features including
+ :func:`_orm.with_loader_criteria`. Please migrate code
+ to use :func:`_orm.with_polymorphic`.
+
+ :meth:`_query.Query.with_polymorphic` applies transformations
+ to the "main" mapped class represented by this :class:`_query.Query`.
+ The "main" mapped class here means the :class:`_query.Query`
+ object's first argument is a full class, i.e.
+ ``session.query(SomeClass)``. These transformations allow additional
+ tables to be present in the FROM clause so that columns for a
+ joined-inheritance subclass are available in the query, both for the
+ purposes of load-time efficiency as well as the ability to use
+ these columns at query time.
+
+ .. seealso::
+
+ :ref:`with_polymorphic` - illustrates current patterns
+
+ """
+
+ entity = _legacy_filter_by_entity_zero(self)
+
+ wp = with_polymorphic(
+ entity,
+ cls_or_mappers,
+ selectable=selectable,
+ polymorphic_on=polymorphic_on,
+ )
+
+ self._compile_options = self._compile_options.add_to_element(
+ "_with_polymorphic_adapt_map", ((entity, inspect(wp)),)
+ )
+
+ @_generative
+ def yield_per(self, count):
+ r"""Yield only ``count`` rows at a time.
+
+ The purpose of this method is, when fetching very large result sets
+ (> 10K rows), to batch results into sub-collections and yield them
+ out partially, so that the Python interpreter doesn't need to allocate
+ very large areas of memory at once, which is both time consuming and
+ leads to excessive memory use. The performance from fetching hundreds of
+ thousands of rows can often double when a suitable yield-per setting
+ (e.g. approximately 1000) is used, even with DBAPIs that buffer
+ rows (which are most).
+
+ As of SQLAlchemy 1.4, the :meth:`_orm.Query.yield_per` method is
+ equivalent to using the ``yield_per`` execution option at the ORM
+ level. See the section :ref:`orm_queryguide_yield_per` for further
+ background on this option.
+
+ .. seealso::
+
+ :ref:`orm_queryguide_yield_per`
+
+ """
+ self.load_options += {"_yield_per": count}
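+
+ # Illustrative sketch, assuming a mapped ``User`` class and a
+ # hypothetical per-row ``handle()`` function: rows are fetched in
+ # batches of 1000 rather than fully buffered:
+ #
+ #     for user in session.query(User).yield_per(1000):
+ #         handle(user)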
+
+ @util.deprecated_20(
+ ":meth:`_orm.Query.get`",
+ alternative="The method is now available as :meth:`_orm.Session.get`",
+ becomes_legacy=True,
+ )
+ def get(self, ident):
+ """Return an instance based on the given primary key identifier,
+ or ``None`` if not found.
+
+ E.g.::
+
+ my_user = session.query(User).get(5)
+
+ some_object = session.query(VersionedFoo).get((5, 10))
+
+ some_object = session.query(VersionedFoo).get(
+ {"id": 5, "version_id": 10})
+
+ :meth:`_query.Query.get` is special in that it provides direct
+ access to the identity map of the owning :class:`.Session`.
+ If the given primary key identifier is present
+ in the local identity map, the object is returned
+ directly from this collection and no SQL is emitted,
+ unless the object has been marked fully expired.
+ If not present,
+ a SELECT is performed in order to locate the object.
+
+ :meth:`_query.Query.get` also will perform a check if
+ the object is present in the identity map and
+ marked as expired - a SELECT
+ is emitted to refresh the object as well as to
+ ensure that the row is still present.
+ If not, :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised.
+
+ :meth:`_query.Query.get` is only used to return a single
+ mapped instance, not multiple instances or
+ individual column constructs, and strictly
+ on a single primary key value. The originating
+ :class:`_query.Query` must be constructed in this way,
+ i.e. against a single mapped entity,
+ with no additional filtering criterion. Loading
+ options via :meth:`_query.Query.options` may be applied
+ however, and will be used if the object is not
+ yet locally present.
+
+ :param ident: A scalar, tuple, or dictionary representing the
+ primary key. For a composite (e.g. multiple column) primary key,
+ a tuple or dictionary should be passed.
+
+ For a single-column primary key, the scalar calling form is typically
+ the most expedient. If the primary key of a row is the value "5",
+ the call looks like::
+
+ my_object = query.get(5)
+
+ The tuple form contains primary key values typically in
+ the order in which they correspond to the mapped
+ :class:`_schema.Table`
+ object's primary key columns, or if the
+ :paramref:`_orm.Mapper.primary_key` configuration parameter were
+ used, in
+ the order used for that parameter. For example, if the primary key
+ of a row is represented by the integer
+ digits "5, 10" the call would look like::
+
+ my_object = query.get((5, 10))
+
+ The dictionary form should include as keys the mapped attribute names
+ corresponding to each element of the primary key. If the mapped class
+ has the attributes ``id``, ``version_id`` as the attributes which
+ store the object's primary key value, the call would look like::
+
+ my_object = query.get({"id": 5, "version_id": 10})
+
+ .. versionadded:: 1.3 the :meth:`_query.Query.get`
+ method now optionally
+ accepts a dictionary of attribute names to values in order to
+ indicate a primary key identifier.
+
+
+ :return: The object instance, or ``None``.
+
+ """
+ self._no_criterion_assertion("get", order_by=False, distinct=False)
+
+ # we still implement _get_impl() so that baked query can override
+ # it
+ return self._get_impl(ident, loading.load_on_pk_identity)
+
+ def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None):
+ mapper = self._only_full_mapper_zero("get")
+ return self.session._get_impl(
+ mapper,
+ primary_key_identity,
+ db_load_fn,
+ populate_existing=self.load_options._populate_existing,
+ with_for_update=self._for_update_arg,
+ options=self._with_options,
+ identity_token=identity_token,
+ execution_options=self._execution_options,
+ )
+
+ @property
+ def lazy_loaded_from(self):
+ """An :class:`.InstanceState` that is using this :class:`_query.Query`
+ for a lazy load operation.
+
+ .. deprecated:: 1.4 This attribute should be viewed via the
+ :attr:`.ORMExecuteState.lazy_loaded_from` attribute, within
+ the context of the :meth:`.SessionEvents.do_orm_execute`
+ event.
+
+ .. seealso::
+
+ :attr:`.ORMExecuteState.lazy_loaded_from`
+
+ """
+ return self.load_options._lazy_loaded_from
+
+ @property
+ def _current_path(self):
+ return self._compile_options._current_path
+
+ @_generative
+ def correlate(self, *fromclauses):
+ """Return a :class:`.Query` construct which will correlate the given
+ FROM clauses to that of an enclosing :class:`.Query` or
+ :func:`~.expression.select`.
+
+ The method here accepts mapped classes, :func:`.aliased` constructs,
+ and :func:`.mapper` constructs as arguments, which are resolved into
+ expression constructs, in addition to appropriate expression
+ constructs.
+
+ The correlation arguments are ultimately passed to
+ :meth:`_expression.Select.correlate`
+ after coercion to expression constructs.
+
+ The correlation arguments take effect in such cases
+ as when :meth:`_query.Query.from_self` is used, or when
+ a subquery as returned by :meth:`_query.Query.subquery` is
+ embedded in another :func:`_expression.select` construct.
+
+ """
+
+ self._auto_correlate = False
+ if fromclauses and fromclauses[0] in {None, False}:
+ self._correlate = ()
+ else:
+ self._correlate = set(self._correlate).union(
+ coercions.expect(roles.FromClauseRole, f) for f in fromclauses
+ )
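+
+ # Illustrative sketch, assuming mapped ``User`` / ``Address`` classes:
+ # correlate an EXISTS subquery against the enclosing query's entity:
+ #
+ #     has_address = (
+ #         session.query(Address.id)
+ #         .filter(Address.user_id == User.id)
+ #         .correlate(User)
+ #         .exists()
+ #     )
+ #     users = session.query(User).filter(has_address).all()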
+
+ @_generative
+ def autoflush(self, setting):
+ """Return a Query with a specific 'autoflush' setting.
+
+ As of SQLAlchemy 1.4, the :meth:`_orm.Query.autoflush` method
+ is equivalent to using the ``autoflush`` execution option at the
+ ORM level. See the section :ref:`orm_queryguide_autoflush` for
+ further background on this option.
+
+ """
+ self.load_options += {"_autoflush": setting}
+
+ @_generative
+ def populate_existing(self):
+ """Return a :class:`_query.Query`
+ that will expire and refresh all instances
+ as they are loaded, or reused from the current :class:`.Session`.
+
+ As of SQLAlchemy 1.4, the :meth:`_orm.Query.populate_existing` method
+ is equivalent to using the ``populate_existing`` execution option at
+ the ORM level. See the section :ref:`orm_queryguide_populate_existing`
+ for further background on this option.
+
+ """
+ self.load_options += {"_populate_existing": True}
+
+ @_generative
+ def _with_invoke_all_eagers(self, value):
+ """Set the 'invoke all eagers' flag which causes joined- and
+ subquery loaders to traverse into already-loaded related objects
+ and collections.
+
+ Default is that of :attr:`_query.Query._invoke_all_eagers`.
+
+ """
+ self.load_options += {"_invoke_all_eagers": value}
+
+ @util.deprecated_20(
+ ":meth:`_orm.Query.with_parent`",
+ alternative="Use the :func:`_orm.with_parent` standalone construct.",
+ becomes_legacy=True,
+ )
+ @util.preload_module("sqlalchemy.orm.relationships")
+ def with_parent(self, instance, property=None, from_entity=None): # noqa
+ """Add filtering criterion that relates the given instance
+ to a child object or collection, using its attribute state
+ as well as an established :func:`_orm.relationship()`
+ configuration.
+
+ The method uses the :func:`.with_parent` function to generate
+ the clause, the result of which is passed to
+ :meth:`_query.Query.filter`.
+
+ Parameters are the same as :func:`.with_parent`, with the exception
+ that the given property can be None, in which case a search is
+ performed against this :class:`_query.Query` object's target mapper.
+
+ :param instance:
+ An instance which has some :func:`_orm.relationship`.
+
+ :param property:
+ String property name, or class-bound attribute, which indicates
+ what relationship from the instance should be used to reconcile the
+ parent/child relationship.
+
+ :param from_entity:
+ Entity in which to consider as the left side. This defaults to the
+ "zero" entity of the :class:`_query.Query` itself.
+
+ """
+ relationships = util.preloaded.orm_relationships
+
+ if from_entity:
+ entity_zero = inspect(from_entity)
+ else:
+ entity_zero = _legacy_filter_by_entity_zero(self)
+ if property is None:
+ # TODO: deprecate, property has to be supplied
+ mapper = object_mapper(instance)
+
+ for prop in mapper.iterate_properties:
+ if (
+ isinstance(prop, relationships.RelationshipProperty)
+ and prop.mapper is entity_zero.mapper
+ ):
+ property = prop # noqa
+ break
+ else:
+ raise sa_exc.InvalidRequestError(
+ "Could not locate a property which relates instances "
+ "of class '%s' to instances of class '%s'"
+ % (
+ entity_zero.mapper.class_.__name__,
+ instance.__class__.__name__,
+ )
+ )
+
+ return self.filter(with_parent(instance, property, entity_zero.entity))
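+
+ # Illustrative sketch, assuming mapped ``User`` / ``Address`` classes
+ # with a ``User.addresses`` relationship:
+ #
+ #     some_user = session.query(User).first()
+ #     addresses = (
+ #         session.query(Address)
+ #         .with_parent(some_user, User.addresses)
+ #         .all()
+ #     )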
+
+ @_generative
+ def add_entity(self, entity, alias=None):
+ """add a mapped entity to the list of result columns
+ to be returned."""
+
+ if alias is not None:
+ # TODO: deprecate
+ entity = aliased(entity, alias)
+
+ self._raw_columns = list(self._raw_columns)
+
+ self._raw_columns.append(
+ coercions.expect(
+ roles.ColumnsClauseRole, entity, apply_propagate_attrs=self
+ )
+ )
+
+ @_generative
+ def with_session(self, session):
+ """Return a :class:`_query.Query` that will use the given
+ :class:`.Session`.
+
+ While the :class:`_query.Query`
+ object is normally instantiated using the
+ :meth:`.Session.query` method, it is legal to build the
+ :class:`_query.Query`
+ directly without necessarily using a :class:`.Session`. Such a
+ :class:`_query.Query` object, or any :class:`_query.Query`
+ already associated
+ with a different :class:`.Session`, can produce a new
+ :class:`_query.Query`
+ object associated with a target session using this method::
+
+ from sqlalchemy.orm import Query
+
+ query = Query([MyClass]).filter(MyClass.id == 5)
+
+ result = query.with_session(my_session).one()
+
+ """
+
+ self.session = session
+
+ @util.deprecated_20(
+ ":meth:`_query.Query.from_self`",
+ alternative="The new approach is to use the :func:`.orm.aliased` "
+ "construct in conjunction with a subquery. See the section "
+ ":ref:`Selecting from the query itself as a subquery "
+ "<migration_20_query_from_self>` in the 2.0 migration notes for an "
+ "example.",
+ )
+ def from_self(self, *entities):
+ r"""return a Query that selects from this Query's
+ SELECT statement.
+
+ :meth:`_query.Query.from_self` essentially turns the SELECT statement
+ into a SELECT of itself. Given a query such as::
+
+ q = session.query(User).filter(User.name.like('e%'))
+
+ Given the :meth:`_query.Query.from_self` version::
+
+ q = session.query(User).filter(User.name.like('e%')).from_self()
+
+ This query renders as:
+
+ .. sourcecode:: sql
+
+ SELECT anon_1.user_id AS anon_1_user_id,
+ anon_1.user_name AS anon_1_user_name
+ FROM (SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user"
+ WHERE "user".name LIKE :name_1) AS anon_1
+
+ There are lots of cases where :meth:`_query.Query.from_self`
+ may be useful.
+ A simple one is the case above, where we may want to apply a row LIMIT to
+ the set of user objects we query against, and then apply additional
+ joins against that row-limited set::
+
+ q = session.query(User).filter(User.name.like('e%')).\
+ limit(5).from_self().\
+ join(User.addresses).filter(Address.email.like('q%'))
+
+ The above query joins to the ``Address`` entity but only against the
+ first five results of the ``User`` query:
+
+ .. sourcecode:: sql
+
+ SELECT anon_1.user_id AS anon_1_user_id,
+ anon_1.user_name AS anon_1_user_name
+ FROM (SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user"
+ WHERE "user".name LIKE :name_1
+ LIMIT :param_1) AS anon_1
+ JOIN address ON anon_1.user_id = address.user_id
+ WHERE address.email LIKE :email_1
+
+ **Automatic Aliasing**
+
+ Another key behavior of :meth:`_query.Query.from_self`
+ is that it applies
+ **automatic aliasing** to the entities inside the subquery, when
+ they are referenced on the outside. Above, if we continue to
+ refer to the ``User`` entity without any additional aliasing applied
+ to it, those references will be in terms of the subquery::
+
+ q = session.query(User).filter(User.name.like('e%')).\
+ limit(5).from_self().\
+ join(User.addresses).filter(Address.email.like('q%')).\
+ order_by(User.name)
+
+ The ORDER BY against ``User.name`` is aliased to be in terms of the
+ inner subquery:
+
+ .. sourcecode:: sql
+
+ SELECT anon_1.user_id AS anon_1_user_id,
+ anon_1.user_name AS anon_1_user_name
+ FROM (SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user"
+ WHERE "user".name LIKE :name_1
+ LIMIT :param_1) AS anon_1
+ JOIN address ON anon_1.user_id = address.user_id
+ WHERE address.email LIKE :email_1 ORDER BY anon_1.user_name
+
+ The automatic aliasing feature only works in a **limited** way,
+ for simple filters and orderings. More ambitious constructions
+ such as referring to the entity in joins should prefer to use
+ explicit subquery objects, typically making use of the
+ :meth:`_query.Query.subquery`
+ method to produce an explicit subquery object.
+ Always test the structure of queries by viewing the SQL to ensure
+ a particular structure does what's expected!
+
+ **Changing the Entities**
+
+ :meth:`_query.Query.from_self`
+ also includes the ability to modify what
+ columns are being queried. In our example, we want ``User.id``
+ to be queried by the inner query, so that we can join to the
+ ``Address`` entity on the outside, but we only want the outer
+ query to return the ``Address.email`` column::
+
+ q = session.query(User).filter(User.name.like('e%')).\
+ limit(5).from_self(Address.email).\
+ join(User.addresses).filter(Address.email.like('q%'))
+
+ yielding:
+
+ .. sourcecode:: sql
+
+ SELECT address.email AS address_email
+ FROM (SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user"
+ WHERE "user".name LIKE :name_1
+ LIMIT :param_1) AS anon_1
+ JOIN address ON anon_1.user_id = address.user_id
+ WHERE address.email LIKE :email_1
+
+ **Looking out for Inner / Outer Columns**
+
+ Keep in mind that when referring to columns that originate from
+ inside the subquery, we need to ensure they are present in the
+ columns clause of the subquery itself; this is an ordinary aspect of
+ SQL. For example, if we wanted to load from a joined entity inside
+ the subquery using :func:`.contains_eager`, we need to add those
+ columns. Below illustrates a join of ``Address`` to ``User``,
+ then a subquery, and then we'd like :func:`.contains_eager` to access
+ the ``User`` columns::
+
+ q = session.query(Address).join(Address.user).\
+ filter(User.name.like('e%'))
+
+ q = q.add_entity(User).from_self().\
+ options(contains_eager(Address.user))
+
+ We use :meth:`_query.Query.add_entity` above **before** we call
+ :meth:`_query.Query.from_self`
+ so that the ``User`` columns are present
+ in the inner subquery, so that they are available to the
+ :func:`.contains_eager` modifier we are using on the outside,
+ producing:
+
+ .. sourcecode:: sql
+
+ SELECT anon_1.address_id AS anon_1_address_id,
+ anon_1.address_email AS anon_1_address_email,
+ anon_1.address_user_id AS anon_1_address_user_id,
+ anon_1.user_id AS anon_1_user_id,
+ anon_1.user_name AS anon_1_user_name
+ FROM (
+ SELECT address.id AS address_id,
+ address.email AS address_email,
+ address.user_id AS address_user_id,
+ "user".id AS user_id,
+ "user".name AS user_name
+ FROM address JOIN "user" ON "user".id = address.user_id
+ WHERE "user".name LIKE :name_1) AS anon_1
+
+ If we didn't call ``add_entity(User)``, but still asked
+ :func:`.contains_eager` to load the ``User`` entity, it would be
+ forced to add the table on the outside without the correct
+ join criteria - note the ``anon_1, "user"`` phrase at
+ the end:
+
+ .. sourcecode:: sql
+
+ -- incorrect query
+ SELECT anon_1.address_id AS anon_1_address_id,
+ anon_1.address_email AS anon_1_address_email,
+ anon_1.address_user_id AS anon_1_address_user_id,
+ "user".id AS user_id,
+ "user".name AS user_name
+ FROM (
+ SELECT address.id AS address_id,
+ address.email AS address_email,
+ address.user_id AS address_user_id
+ FROM address JOIN "user" ON "user".id = address.user_id
+ WHERE "user".name LIKE :name_1) AS anon_1, "user"
+
+ :param \*entities: optional list of entities which will replace
+ those being selected.
+
+ """
+ return self._from_self(*entities)
+
+ def _from_self(self, *entities):
+ fromclause = (
+ self.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ .correlate(None)
+ .subquery()
+ ._anonymous_fromclause()
+ )
+
+ q = self._from_selectable(fromclause)
+
+ if entities:
+ q._set_entities(entities)
+ return q
+
+ @_generative
+ def _set_enable_single_crit(self, val):
+ self._compile_options += {"_enable_single_crit": val}
+
+ @_generative
+ def _from_selectable(self, fromclause, set_entity_from=True):
+ for attr in (
+ "_where_criteria",
+ "_order_by_clauses",
+ "_group_by_clauses",
+ "_limit_clause",
+ "_offset_clause",
+ "_last_joined_entity",
+ "_legacy_setup_joins",
+ "_memoized_select_entities",
+ "_distinct",
+ "_distinct_on",
+ "_having_criteria",
+ "_prefixes",
+ "_suffixes",
+ ):
+ self.__dict__.pop(attr, None)
+ self._set_select_from([fromclause], set_entity_from)
+ self._compile_options += {
+ "_enable_single_crit": False,
+ }
+
+ # this enables clause adaptation for non-ORM
+ # expressions.
+ # legacy. see test/orm/test_froms.py for various
+ # "oldstyle" tests that rely on this and the corresponding
+ # "newtyle" that do not.
+ self._compile_options += {"_orm_only_from_obj_alias": False}
+
+ @util.deprecated(
+ "1.4",
+ ":meth:`_query.Query.values` "
+ "is deprecated and will be removed in a "
+ "future release. Please use :meth:`_query.Query.with_entities`",
+ )
+ def values(self, *columns):
+ """Return an iterator yielding result tuples corresponding
+ to the given list of columns.
+
+ """
+
+ if not columns:
+ return iter(())
+ q = self._clone().enable_eagerloads(False)
+ q._set_entities(columns)
+ if not q.load_options._yield_per:
+ q.load_options += {"_yield_per": 10}
+ return iter(q)
+
+ _values = values
+
+ @util.deprecated(
+ "1.4",
+ ":meth:`_query.Query.value` "
+ "is deprecated and will be removed in a "
+ "future release. Please use :meth:`_query.Query.with_entities` "
+ "in combination with :meth:`_query.Query.scalar`",
+ )
+ def value(self, column):
+ """Return a scalar result corresponding to the given
+ column expression.
+
+ """
+ try:
+ return next(self.values(column))[0]
+ except StopIteration:
+ return None
+
+ @_generative
+ def with_entities(self, *entities):
+ r"""Return a new :class:`_query.Query`
+ replacing the SELECT list with the
+ given entities.
+
+ e.g.::
+
+ # Users, filtered on some arbitrary criterion
+ # and then ordered by related email address
+ q = session.query(User).\
+ join(User.address).\
+ filter(User.name.like('%ed%')).\
+ order_by(Address.email)
+
+ # given *only* User.id==5, Address.email, and 'q', what
+ # would the *next* User in the result be ?
+ subq = q.with_entities(Address.email).\
+ order_by(None).\
+ filter(User.id==5).\
+ subquery()
+ q = q.join((subq, subq.c.email < Address.email)).\
+ limit(1)
+
+ """
+ _MemoizedSelectEntities._generate_for_statement(self)
+ self._set_entities(entities)
+
+ @_generative
+ def add_columns(self, *column):
+ """Add one or more column expressions to the list
+ of result columns to be returned."""
+
+ self._raw_columns = list(self._raw_columns)
+
+ self._raw_columns.extend(
+ coercions.expect(
+ roles.ColumnsClauseRole,
+ c,
+ apply_propagate_attrs=self,
+ post_inspect=True,
+ )
+ for c in column
+ )
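+
+ # Illustrative sketch, assuming mapped ``User`` / ``Address`` classes
+ # with an ``Address.email`` column attribute:
+ #
+ #     q = (
+ #         session.query(User)
+ #         .join(User.addresses)
+ #         .add_columns(Address.email)
+ #     )
+ #     for user, email in q:
+ #         ...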
+
+ @util.deprecated(
+ "1.4",
+ ":meth:`_query.Query.add_column` "
+ "is deprecated and will be removed in a "
+ "future release. Please use :meth:`_query.Query.add_columns`",
+ )
+ def add_column(self, column):
+ """Add a column expression to the list of result columns to be
+ returned.
+
+ """
+ return self.add_columns(column)
+
+ @_generative
+ def options(self, *args):
+ """Return a new :class:`_query.Query` object,
+ applying the given list of
+ mapper options.
+
+ Most supplied options regard changing how column- and
+ relationship-mapped attributes are loaded.
+
+ .. seealso::
+
+ :ref:`deferred_options`
+
+ :ref:`relationship_loader_options`
+
+ """
+
+ opts = tuple(util.flatten_iterator(args))
+ if self._compile_options._current_path:
+ for opt in opts:
+ if opt._is_legacy_option:
+ opt.process_query_conditionally(self)
+ else:
+ for opt in opts:
+ if opt._is_legacy_option:
+ opt.process_query(self)
+
+ self._with_options += opts
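+
+ # Illustrative sketch, assuming a mapped ``User`` class with a
+ # ``User.addresses`` relationship:
+ #
+ #     from sqlalchemy.orm import joinedload
+ #
+ #     users = (
+ #         session.query(User)
+ #         .options(joinedload(User.addresses))
+ #         .all()
+ #     )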
+
+ def with_transformation(self, fn):
+ """Return a new :class:`_query.Query` object transformed by
+ the given function.
+
+ E.g.::
+
+ def filter_something(criterion):
+ def transform(q):
+ return q.filter(criterion)
+ return transform
+
+ q = q.with_transformation(filter_something(x==5))
+
+ This allows ad-hoc recipes to be created for :class:`_query.Query`
+ objects. See the example at :ref:`hybrid_transformers`.
+
+ """
+ return fn(self)
+
+ def get_execution_options(self):
+ """Get the non-SQL options which will take effect during execution.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :meth:`_query.Query.execution_options`
+ """
+ return self._execution_options
+
+ @_generative
+ def execution_options(self, **kwargs):
+ """Set non-SQL options which take effect during execution.
+
+ Options allowed here include all of those accepted by
+ :meth:`_engine.Connection.execution_options`, as well as a series
+ of ORM specific options:
+
+ ``populate_existing=True`` - equivalent to using
+ :meth:`_orm.Query.populate_existing`
+
+ ``autoflush=True|False`` - equivalent to using
+ :meth:`_orm.Query.autoflush`
+
+ ``yield_per=<value>`` - equivalent to using
+ :meth:`_orm.Query.yield_per`
+
+ Note that the ``stream_results`` execution option is enabled
+ automatically if the :meth:`~sqlalchemy.orm.query.Query.yield_per()`
+ method or execution option is used.
+
+ .. versionadded:: 1.4 - added ORM options to
+ :meth:`_orm.Query.execution_options`
+
+ The execution options may also be specified on a per execution basis
+ when using :term:`2.0 style` queries via the
+ :paramref:`_orm.Session.execution_options` parameter.
+
+ .. warning:: The
+ :paramref:`_engine.Connection.execution_options.schema_translate_map`
+ parameter should not be used at the level of individual ORM
+ statement executions, as the :class:`_orm.Session` will not track
+ objects from different schema translate maps within a single
+ session. For multiple schema translate maps within the scope of a
+ single :class:`_orm.Session`, see :ref:`examples_sharding`.
+
+
+ .. seealso::
+
+ :ref:`engine_stream_results`
+
+ :meth:`_query.Query.get_execution_options`
+
+ """
+ self._execution_options = self._execution_options.union(kwargs)
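+
+ # Illustrative sketch, assuming a mapped ``User`` class: combine the
+ # ORM-level options described above in a single call:
+ #
+ #     q = session.query(User).execution_options(
+ #         populate_existing=True, yield_per=500
+ #     )
+ #     for user in q:
+ #         ...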
+
+ @_generative
+ def with_for_update(
+ self,
+ read=False,
+ nowait=False,
+ of=None,
+ skip_locked=False,
+ key_share=False,
+ ):
+ """return a new :class:`_query.Query`
+ with the specified options for the
+ ``FOR UPDATE`` clause.
+
+ The behavior of this method is identical to that of
+ :meth:`_expression.GenerativeSelect.with_for_update`.
+ When called with no arguments,
+ the resulting ``SELECT`` statement will have a ``FOR UPDATE`` clause
+ appended. When additional arguments are specified, backend-specific
+ options such as ``FOR UPDATE NOWAIT`` or ``LOCK IN SHARE MODE``
+ can take effect.
+
+ E.g.::
+
+ q = sess.query(User).populate_existing().with_for_update(nowait=True, of=User)
+
+ The above query on a PostgreSQL backend will render like::
+
+ SELECT users.id AS users_id FROM users FOR UPDATE OF users NOWAIT
+
+ .. warning::
+
+ Using ``with_for_update`` in the context of eager loading
+ relationships is not officially supported or recommended by
+ SQLAlchemy and may not work with certain queries on various
+ database backends. When ``with_for_update`` is successfully used
+ with a query that involves :func:`_orm.joinedload`, SQLAlchemy will
+ attempt to emit SQL that locks all involved tables.
+
+ .. note:: It is generally a good idea to combine the use of the
+ :meth:`_orm.Query.populate_existing` method when using the
+ :meth:`_orm.Query.with_for_update` method. The purpose of
+ :meth:`_orm.Query.populate_existing` is to force all the data read
+ from the SELECT to be populated into the ORM objects returned,
+ even if these objects are already in the :term:`identity map`.
+
+ .. seealso::
+
+ :meth:`_expression.GenerativeSelect.with_for_update`
+ - Core level method with
+ full argument and behavioral description.
+
+ :meth:`_orm.Query.populate_existing` - overwrites attributes of
+ objects already loaded in the identity map.
+
+ """ # noqa: E501
+
+ self._for_update_arg = ForUpdateArg(
+ read=read,
+ nowait=nowait,
+ of=of,
+ skip_locked=skip_locked,
+ key_share=key_share,
+ )
+
+ @_generative
+ def params(self, *args, **kwargs):
+ r"""Add values for bind parameters which may have been
+ specified in filter().
+
+ Parameters may be specified using \**kwargs, or optionally a single
+ dictionary as the first positional argument. The reason for both is
+ that \**kwargs is convenient, however some parameter dictionaries
+ contain unicode keys in which case \**kwargs cannot be used.
+
+ """
+ if len(args) == 1:
+ kwargs.update(args[0])
+ elif len(args) > 0:
+ raise sa_exc.ArgumentError(
+ "params() takes zero or one positional argument, "
+ "which is a dictionary."
+ )
+ self._params = self._params.union(kwargs)
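+
+ # Illustrative sketch, assuming a mapped ``User`` class whose table is
+ # named ``users``: bind parameters named in a text() filter are
+ # supplied afterwards via params():
+ #
+ #     from sqlalchemy import text
+ #
+ #     q = session.query(User).filter(text("users.name = :name"))
+ #     ed = q.params(name="ed").one()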
+
+ def where(self, *criterion):
+ """A synonym for :meth:`.Query.filter`.
+
+ .. versionadded:: 1.4
+
+ """
+ return self.filter(*criterion)
+
+ @_generative
+ @_assertions(_no_statement_condition, _no_limit_offset)
+ def filter(self, *criterion):
+ r"""Apply the given filtering criterion to a copy
+ of this :class:`_query.Query`, using SQL expressions.
+
+ e.g.::
+
+ session.query(MyClass).filter(MyClass.name == 'some name')
+
+ Multiple criteria may be specified as comma separated; the effect
+ is that they will be joined together using the :func:`.and_`
+ function::
+
+ session.query(MyClass).\
+ filter(MyClass.name == 'some name', MyClass.id > 5)
+
+ The criterion is any SQL expression object applicable to the
+ WHERE clause of a select. String expressions are coerced
+ into SQL expression constructs via the :func:`_expression.text`
+ construct.
+
+ .. seealso::
+
+ :meth:`_query.Query.filter_by` - filter on keyword expressions.
+
+ """
+ for criterion in list(criterion):
+ criterion = coercions.expect(
+ roles.WhereHavingRole, criterion, apply_propagate_attrs=self
+ )
+
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv
+ if self._aliased_generation:
+ criterion = sql_util._deep_annotate(
+ criterion, {"aliased_generation": self._aliased_generation}
+ )
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ self._where_criteria += (criterion,)
+
+ @util.memoized_property
+ def _last_joined_entity(self):
+ if self._legacy_setup_joins:
+ return _legacy_determine_last_joined_entity(
+ self._legacy_setup_joins, self._entity_from_pre_ent_zero()
+ )
+ else:
+ return None
+
+ def _filter_by_zero(self):
+ """for the filter_by() method, return the target entity for which
+ we will attempt to derive an expression from based on string name.
+
+ """
+
+ if self._legacy_setup_joins:
+ _last_joined_entity = self._last_joined_entity
+ if _last_joined_entity is not None:
+ return _last_joined_entity
+
+ # discussion related to #7239
+ # special check determines if we should try to derive attributes
+ # for filter_by() from the "from object", i.e., if the user
+ # called query.select_from(some selectable).filter_by(some_attr=value).
+ # We don't want to do that in the case that methods like
+ # from_self(), select_entity_from(), or a set op like union() were
+ # called; while these methods also place a
+ # selectable in the _from_obj collection, they also set up
+ # the _set_base_alias boolean which turns on the whole "adapt the
+ # entity to this selectable" thing, meaning the query still continues
+ # to construct itself in terms of the lead entity that was passed
+ # to query(), e.g. query(User).from_self() is still in terms of User,
+ # and not the subquery that from_self() created. This feature of
+ # "implicitly adapt all occurrences of entity X to some arbitrary
+ # subquery" is the main thing I am trying to do away with in 2.0 as
+ # users should now use aliased() for that, but I can't entirely get
+ # rid of it due to query.union() and other set ops relying upon it.
+ #
+ # compare this to the base Select()._filter_by_zero() which can
+ # just return self._from_obj[0] if present, because there is no
+ # "_set_base_alias" feature.
+ #
+ # IOW, this conditional essentially detects if
+ # "select_from(some_selectable)" has been called, as opposed to
+ # "select_entity_from()", "from_self()"
+ # or "union() / some_set_op()".
+ if self._from_obj and not self._compile_options._set_base_alias:
+ return self._from_obj[0]
+
+ return self._raw_columns[0]
+
+ def filter_by(self, **kwargs):
+ r"""Apply the given filtering criterion to a copy
+ of this :class:`_query.Query`, using keyword expressions.
+
+ e.g.::
+
+ session.query(MyClass).filter_by(name = 'some name')
+
+ Multiple criteria may be specified as comma separated; the effect
+ is that they will be joined together using the :func:`.and_`
+ function::
+
+ session.query(MyClass).\
+ filter_by(name = 'some name', id = 5)
+
+ The keyword expressions are extracted from the primary
+ entity of the query, or the last entity that was the
+ target of a call to :meth:`_query.Query.join`.
+
+ .. seealso::
+
+ :meth:`_query.Query.filter` - filter on SQL expressions.
+
+ """
+ from_entity = self._filter_by_zero()
+ if from_entity is None:
+ raise sa_exc.InvalidRequestError(
+ "Can't use filter_by when the first entity '%s' of a query "
+ "is not a mapped class. Please use the filter method instead, "
+ "or change the order of the entities in the query"
+ % self._query_entity_zero()
+ )
+
+ clauses = [
+ _entity_namespace_key(from_entity, key) == value
+ for key, value in kwargs.items()
+ ]
+ return self.filter(*clauses)
+
+ @_generative
+ @_assertions(_no_statement_condition, _no_limit_offset)
+ def order_by(self, *clauses):
+ """Apply one or more ORDER BY criteria to the query and return
+ the newly resulting :class:`_query.Query`.
+
+ e.g.::
+
+ q = session.query(Entity).order_by(Entity.id, Entity.name)
+
+ All existing ORDER BY criteria may be cancelled by passing
+ ``None`` by itself. New ORDER BY criteria may then be added by
+ invoking :meth:`_orm.Query.order_by` again, e.g.::
+
+ # will erase all ORDER BY and ORDER BY new_col alone
+ q = q.order_by(None).order_by(new_col)
+
+ .. seealso::
+
+ These sections describe ORDER BY in terms of :term:`2.0 style`
+ invocation but apply to :class:`_orm.Query` as well:
+
+ :ref:`tutorial_order_by` - in the :ref:`unified_tutorial`
+
+ :ref:`tutorial_order_by_label` - in the :ref:`unified_tutorial`
+
+ """
+
+ if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False):
+ self._order_by_clauses = ()
+ else:
+ criterion = tuple(
+ coercions.expect(roles.OrderByRole, clause)
+ for clause in clauses
+ )
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv
+ if self._aliased_generation:
+ criterion = tuple(
+ [
+ sql_util._deep_annotate(
+ o, {"aliased_generation": self._aliased_generation}
+ )
+ for o in criterion
+ ]
+ )
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ self._order_by_clauses += criterion
+
+ @_generative
+ @_assertions(_no_statement_condition, _no_limit_offset)
+ def group_by(self, *clauses):
+ """Apply one or more GROUP BY criterion to the query and return
+ the newly resulting :class:`_query.Query`.
+
+ All existing GROUP BY settings can be suppressed by
+ passing ``None`` - this will suppress any GROUP BY configured
+ on mappers as well.
+
+ .. seealso::
+
+ These sections describe GROUP BY in terms of :term:`2.0 style`
+ invocation but apply to :class:`_orm.Query` as well:
+
+ :ref:`tutorial_group_by_w_aggregates` - in the
+ :ref:`unified_tutorial`
+
+ :ref:`tutorial_order_by_label` - in the :ref:`unified_tutorial`
+
+ """
+
+ if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False):
+ self._group_by_clauses = ()
+ else:
+ criterion = tuple(
+ coercions.expect(roles.GroupByRole, clause)
+ for clause in clauses
+ )
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv
+ if self._aliased_generation:
+ criterion = tuple(
+ [
+ sql_util._deep_annotate(
+ o, {"aliased_generation": self._aliased_generation}
+ )
+ for o in criterion
+ ]
+ )
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ self._group_by_clauses += criterion
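+
+ # Illustrative sketch, assuming mapped ``User`` / ``Address`` classes:
+ # count addresses per user name:
+ #
+ #     q = (
+ #         session.query(User.name, func.count(Address.id))
+ #         .join(User.addresses)
+ #         .group_by(User.name)
+ #     )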
+
+ @_generative
+ @_assertions(_no_statement_condition, _no_limit_offset)
+ def having(self, criterion):
+ r"""Apply a HAVING criterion to the query and return the
+ newly resulting :class:`_query.Query`.
+
+ :meth:`_query.Query.having` is used in conjunction with
+ :meth:`_query.Query.group_by`.
+
+ HAVING criterion makes it possible to use filters on aggregate
+ functions like COUNT, SUM, AVG, MAX, and MIN, e.g.::
+
+ q = session.query(User.id).\
+ join(User.addresses).\
+ group_by(User.id).\
+ having(func.count(Address.id) > 2)
+
+ """
+
+ self._having_criteria += (
+ coercions.expect(
+ roles.WhereHavingRole, criterion, apply_propagate_attrs=self
+ ),
+ )
+
+ def _set_op(self, expr_fn, *q):
+ return self._from_selectable(expr_fn(*([self] + list(q))).subquery())
+
+ def union(self, *q):
+ """Produce a UNION of this Query against one or more queries.
+
+ e.g.::
+
+ q1 = sess.query(SomeClass).filter(SomeClass.foo=='bar')
+ q2 = sess.query(SomeClass).filter(SomeClass.bar=='foo')
+
+ q3 = q1.union(q2)
+
+ The method accepts multiple Query objects so as to control
+ the level of nesting. A series of ``union()`` calls such as::
+
+ x.union(y).union(z).all()
+
+ will nest on each ``union()``, and produces::
+
+ SELECT * FROM (SELECT * FROM (SELECT * FROM x UNION
+ SELECT * FROM y) UNION SELECT * FROM z)
+
+ Whereas::
+
+ x.union(y, z).all()
+
+ produces::
+
+ SELECT * FROM (SELECT * FROM x UNION SELECT * FROM y UNION
+ SELECT * FROM z)
+
+ Note that many database backends do not allow ORDER BY to
+ be rendered on a query called within UNION, EXCEPT, etc.
+ To disable all ORDER BY clauses including those configured
+ on mappers, issue ``query.order_by(None)`` - the resulting
+ :class:`_query.Query` object will not render ORDER BY within
+ its SELECT statement.
+
+ """
+ return self._set_op(expression.union, *q)
+
+ def union_all(self, *q):
+ """Produce a UNION ALL of this Query against one or more queries.
+
+ Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
+ that method for usage examples.
+
+ """
+ return self._set_op(expression.union_all, *q)
+
+ def intersect(self, *q):
+ """Produce an INTERSECT of this Query against one or more queries.
+
+ Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
+ that method for usage examples.
+
+ """
+ return self._set_op(expression.intersect, *q)
+
+ def intersect_all(self, *q):
+ """Produce an INTERSECT ALL of this Query against one or more queries.
+
+ Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
+ that method for usage examples.
+
+ """
+ return self._set_op(expression.intersect_all, *q)
+
+ def except_(self, *q):
+ """Produce an EXCEPT of this Query against one or more queries.
+
+ Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
+ that method for usage examples.
+
+ """
+ return self._set_op(expression.except_, *q)
+
+ def except_all(self, *q):
+ """Produce an EXCEPT ALL of this Query against one or more queries.
+
+ Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
+ that method for usage examples.
+
+ """
+ return self._set_op(expression.except_all, *q)
+
+ def _next_aliased_generation(self):
+ if "_aliased_generation_counter" not in self.__dict__:
+ self._aliased_generation_counter = 0
+ self._aliased_generation_counter += 1
+ return self._aliased_generation_counter
+
+ @_generative
+ @_assertions(_no_statement_condition, _no_limit_offset)
+ def join(self, target, *props, **kwargs):
+ r"""Create a SQL JOIN against this :class:`_query.Query`
+ object's criterion
+ and apply generatively, returning the newly resulting
+ :class:`_query.Query`.
+
+ **Simple Relationship Joins**
+
+ Consider a mapping between two classes ``User`` and ``Address``,
+ with a relationship ``User.addresses`` representing a collection
+ of ``Address`` objects associated with each ``User``. The most
+ common usage of :meth:`_query.Query.join`
+ is to create a JOIN along this
+ relationship, using the ``User.addresses`` attribute as an indicator
+ for how this should occur::
+
+ q = session.query(User).join(User.addresses)
+
+ Where above, the call to :meth:`_query.Query.join` along
+ ``User.addresses`` will result in SQL approximately equivalent to::
+
+ SELECT user.id, user.name
+ FROM user JOIN address ON user.id = address.user_id
+
+ In the above example we refer to ``User.addresses`` as passed to
+ :meth:`_query.Query.join` as the "on clause", that is, it indicates
+ how the "ON" portion of the JOIN should be constructed.
+
+ To construct a chain of joins, multiple :meth:`_query.Query.join`
+ calls may be used. The relationship-bound attribute implies both
+ the left and right side of the join at once::
+
+ q = session.query(User).\
+ join(User.orders).\
+ join(Order.items).\
+ join(Item.keywords)
+
+ .. note:: as seen in the above example, **the order in which each
+ call to the join() method occurs is important**. Query would not,
+ for example, know how to join correctly if we were to specify
+ ``User``, then ``Item``, then ``Order``, in our chain of joins; in
+ such a case, depending on the arguments passed, it may raise an
+ error that it doesn't know how to join, or it may produce invalid
+ SQL in which case the database will raise an error. In correct
+ practice, the
+ :meth:`_query.Query.join` method is invoked in such a way that lines
+ up with how we would want the JOIN clauses in SQL to be
+ rendered, and each call should represent a clear link from what
+ precedes it.
+
+ **Joins to a Target Entity or Selectable**
+
+ A second form of :meth:`_query.Query.join` allows any mapped entity or
+ core selectable construct as a target. In this usage,
+ :meth:`_query.Query.join` will attempt to create a JOIN along the
+ natural foreign key relationship between two entities::
+
+ q = session.query(User).join(Address)
+
+ In the above calling form, :meth:`_query.Query.join` is called upon to
+ create the "on clause" automatically for us. This calling form will
+ ultimately raise an error if either there are no foreign keys between
+ the two entities, or if there are multiple foreign key linkages between
+ the target entity and the entity or entities already present on the
+ left side such that creating a join requires more information. Note
+ that when indicating a join to a target without any ON clause, ORM
+ configured relationships are not taken into account.
+
+ **Joins to a Target with an ON Clause**
+
+ The third calling form allows both the target entity as well
+ as the ON clause to be passed explicitly. An example that includes
+ a SQL expression as the ON clause is as follows::
+
+ q = session.query(User).join(Address, User.id==Address.user_id)
+
+ The above form may also use a relationship-bound attribute as the
+ ON clause as well::
+
+ q = session.query(User).join(Address, User.addresses)
+
+ The above syntax can be useful for the case where we wish
+ to join to an alias of a particular target entity. If we wanted
+ to join to ``Address`` twice, it could be achieved using two
+ aliases set up using the :func:`~sqlalchemy.orm.aliased` function::
+
+ a1 = aliased(Address)
+ a2 = aliased(Address)
+
+ q = session.query(User).\
+ join(a1, User.addresses).\
+ join(a2, User.addresses).\
+ filter(a1.email_address=='ed@foo.com').\
+ filter(a2.email_address=='ed@bar.com')
+
+ The relationship-bound calling form can also specify a target entity
+ using the :meth:`_orm.PropComparator.of_type` method; a query
+ equivalent to the one above would be::
+
+ a1 = aliased(Address)
+ a2 = aliased(Address)
+
+ q = session.query(User).\
+ join(User.addresses.of_type(a1)).\
+ join(User.addresses.of_type(a2)).\
+ filter(a1.email_address == 'ed@foo.com').\
+ filter(a2.email_address == 'ed@bar.com')
+
+ **Augmenting Built-in ON Clauses**
+
+ As a substitute for providing a full custom ON condition for an
+ existing relationship, the :meth:`_orm.PropComparator.and_` function
+ may be applied to a relationship attribute to augment additional
+ criteria into the ON clause; the additional criteria will be combined
+ with the default criteria using AND::
+
+ q = session.query(User).join(
+ User.addresses.and_(Address.email_address != 'foo@bar.com')
+ )
+
+ .. versionadded:: 1.4
+
+ **Joining to Tables and Subqueries**
+
+
+ The target of a join may also be any table or SELECT statement,
+ which may be related to a target entity or not. Use the
+ appropriate ``.subquery()`` method in order to make a subquery
+ out of a query::
+
+ subq = session.query(Address).\
+ filter(Address.email_address == 'ed@foo.com').\
+ subquery()
+
+
+ q = session.query(User).join(
+ subq, User.id == subq.c.user_id
+ )
+
+ Joining to a subquery in terms of a specific relationship and/or
+ target entity may be achieved by linking the subquery to the
+ entity using :func:`_orm.aliased`::
+
+ subq = session.query(Address).\
+ filter(Address.email_address == 'ed@foo.com').\
+ subquery()
+
+ address_subq = aliased(Address, subq)
+
+ q = session.query(User).join(
+ User.addresses.of_type(address_subq)
+ )
+
+
+ **Controlling what to Join From**
+
+ In cases where the left side of the current state of
+ :class:`_query.Query` is not in line with what we want to join from,
+ the :meth:`_query.Query.select_from` method may be used::
+
+ q = session.query(Address).select_from(User).\
+ join(User.addresses).\
+ filter(User.name == 'ed')
+
+ Which will produce SQL similar to::
+
+ SELECT address.* FROM user
+ JOIN address ON user.id=address.user_id
+ WHERE user.name = :name_1
+
+ **Legacy Features of Query.join()**
+
+ .. deprecated:: 1.4 The following features are deprecated and will
+ be removed in SQLAlchemy 2.0.
+
+ The :meth:`_query.Query.join` method currently supports several
+ usage patterns and arguments that are considered to be legacy
+ as of SQLAlchemy 1.3. A deprecation path will follow
+ in the 1.4 series for the following features:
+
+
+ * Joining on relationship names rather than attributes::
+
+ session.query(User).join("addresses")
+
+ **Why it's legacy**: the string name does not provide enough context
+ for :meth:`_query.Query.join` to always know what is desired,
+ notably in that there is no indication of what the left side
+ of the join should be. This gives rise to flags like
+ ``from_joinpoint`` as well as the ability to place several
+ join clauses in a single :meth:`_query.Query.join` call
+ which don't solve the problem fully while also
+ adding new calling styles that are unnecessary and expensive to
+ accommodate internally.
+
+ **Modern calling pattern**: Use the actual relationship,
+ e.g. ``User.addresses`` in the above case::
+
+ session.query(User).join(User.addresses)
+
+ * Automatic aliasing with the ``aliased=True`` flag::
+
+ session.query(Node).join(Node.children, aliased=True).\
+ filter(Node.name == 'some name')
+
+ **Why it's legacy**: the automatic aliasing feature of
+ :class:`_query.Query` is intensely complicated, both in its internal
+ implementation as well as in its observed behavior, and is almost
+ never used. It is difficult to know upon inspection where and when
+ its aliasing of a target entity, ``Node`` in the above case, will be
+ applied and when it won't, and additionally the feature has to use
+ very elaborate heuristics to achieve this implicit behavior.
+
+ **Modern calling pattern**: Use the :func:`_orm.aliased` construct
+ explicitly::
+
+ from sqlalchemy.orm import aliased
+
+ n1 = aliased(Node)
+
+ session.query(Node).join(Node.children.of_type(n1)).\
+ filter(n1.name == 'some name')
+
+ * Multiple joins in one call::
+
+ session.query(User).join("orders", "items")
+
+ session.query(User).join(User.orders, Order.items)
+
+ session.query(User).join(
+ (Order, User.orders),
+ (Item, Item.order_id == Order.id)
+ )
+
+ session.query(User).join(Order, Item)
+
+ # ... and several more forms actually
+
+ **Why it's legacy**: being able to chain multiple ON clauses in one
+ call to :meth:`_query.Query.join` is yet another attempt to solve
+ the problem of being able to specify what entity to join from,
+ and is the source of a large variety of potential calling patterns
+ that are internally expensive and complicated to parse and
+ accommodate.
+
+ **Modern calling pattern**: Use relationship-bound attributes
+ or SQL-oriented ON clauses within separate calls, so that
+ each call to :meth:`_query.Query.join` knows what the left
+ side should be::
+
+ session.query(User).join(User.orders).join(
+ Item, Item.order_id == Order.id)
+
+
+ :param \*props: Incoming arguments for :meth:`_query.Query.join`.
+ In modern use, the props collection should be considered to be a
+ one- or two-argument form: either a single "target" entity or ORM
+ attribute-bound relationship, or a target entity plus an "on
+ clause", which may be a SQL expression or an ORM attribute-bound
+ relationship.
+
+ :param isouter=False: If True, the join used will be a left outer join,
+ just as if the :meth:`_query.Query.outerjoin` method were called.
+
+ :param full=False: render FULL OUTER JOIN; implies ``isouter``.
+
+ .. versionadded:: 1.1
+
+ :param from_joinpoint=False: When using ``aliased=True``, a setting
+ of True here will cause the join to be from the most recent
+ joined target, rather than starting back from the original
+ FROM clauses of the query.
+
+ .. note:: This flag is considered legacy.
+
+ :param aliased=False: If True, indicate that the JOIN target should be
+ anonymously aliased. Subsequent calls to :meth:`_query.Query.filter`
+ and similar will adapt the incoming criterion to the target
+ alias, until :meth:`_query.Query.reset_joinpoint` is called.
+
+ .. note:: This flag is considered legacy.
+
+ .. seealso::
+
+ :ref:`ormtutorial_joins` in the ORM tutorial.
+
+ :ref:`inheritance_toplevel` for details on how
+ :meth:`_query.Query.join` is used for inheritance relationships.
+
+ :func:`_orm.join` - a standalone ORM-level join function,
+ used internally by :meth:`_query.Query.join`, which in previous
+ SQLAlchemy versions was the primary ORM-level joining interface.
+
+ """
+
+ aliased, from_joinpoint, isouter, full = (
+ kwargs.pop("aliased", False),
+ kwargs.pop("from_joinpoint", False),
+ kwargs.pop("isouter", False),
+ kwargs.pop("full", False),
+ )
+
+ if aliased or from_joinpoint:
+ util.warn_deprecated_20(
+ "The ``aliased`` and ``from_joinpoint`` keyword arguments "
+ "to Query.join() are deprecated and will be removed "
+ "in SQLAlchemy 2.0."
+ )
+
+ if kwargs:
+ raise TypeError(
+ "unknown arguments: %s" % ", ".join(sorted(kwargs))
+ )
+
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv
+ if not from_joinpoint:
+ self._last_joined_entity = None
+ self._aliased_generation = None
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ if props:
+ onclause, legacy = props[0], props[1:]
+ else:
+ onclause = legacy = None
+
+ if not legacy and onclause is None and not isinstance(target, tuple):
+ # non legacy argument form
+ _props = [(target,)]
+ elif (
+ not legacy
+ and isinstance(
+ target,
+ (
+ expression.Selectable,
+ type,
+ AliasedClass,
+ types.FunctionType,
+ ),
+ )
+ and isinstance(
+ onclause,
+ (
+ elements.ColumnElement,
+ str,
+ interfaces.PropComparator,
+ types.FunctionType,
+ ),
+ )
+ ):
+ # non legacy argument form
+ _props = [(target, onclause)]
+ else:
+ # legacy forms. more time consuming :)
+ _props = []
+ _single = []
+ for prop in (target,) + props:
+ if isinstance(prop, tuple):
+ util.warn_deprecated_20(
+ "Query.join() will no longer accept tuples as "
+ "arguments in SQLAlchemy 2.0."
+ )
+ if _single:
+ _props.extend((_s,) for _s in _single)
+ _single = []
+
+ # this checks for an extremely ancient calling form of
+ # reversed tuples.
+ if isinstance(prop[0], (str, interfaces.PropComparator)):
+ prop = (prop[1], prop[0])
+
+ _props.append(prop)
+ else:
+ _single.append(prop)
+ if _single:
+ _props.extend((_s,) for _s in _single)
+
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv
+ if aliased:
+ self._aliased_generation = self._next_aliased_generation()
+
+ if self._aliased_generation:
+ _props = [
+ (
+ prop[0],
+ sql_util._deep_annotate(
+ prop[1],
+ {"aliased_generation": self._aliased_generation},
+ )
+ if isinstance(prop[1], expression.ClauseElement)
+ else prop[1],
+ )
+ if len(prop) == 2
+ else prop
+ for prop in _props
+ ]
+
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ joins_to_add = tuple(
+ (
+ coercions.expect(
+ roles.JoinTargetRole,
+ prop[0],
+ legacy=True,
+ apply_propagate_attrs=self,
+ ),
+ (
+ coercions.expect(roles.OnClauseRole, prop[1], legacy=True)
+ # if not isinstance(prop[1], str)
+ # else prop[1]
+ )
+ if len(prop) == 2
+ else None,
+ None,
+ {
+ "isouter": isouter,
+ "aliased": aliased,
+ "from_joinpoint": True if i > 0 else from_joinpoint,
+ "full": full,
+ "aliased_generation": self._aliased_generation,
+ },
+ )
+ for i, prop in enumerate(_props)
+ )
+
+ if len(joins_to_add) > 1:
+ util.warn_deprecated_20(
+ "Passing a chain of multiple join conditions to Query.join() "
+ "is deprecated and will be removed in SQLAlchemy 2.0. "
+ "Please use individual join() calls per relationship."
+ )
+
+ self._legacy_setup_joins += joins_to_add
+
+ self.__dict__.pop("_last_joined_entity", None)
+
+ def outerjoin(self, target, *props, **kwargs):
+ """Create a left outer join against this ``Query`` object's criterion
+ and apply generatively, returning the newly resulting ``Query``.
+
+ Usage is the same as the ``join()`` method.
+
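+ For example, a brief sketch assuming the ``User`` / ``Address``
+ mapping used in the :meth:`_query.Query.join` examples::
+
+ q = session.query(User).outerjoin(User.addresses).\
+ filter(User.name == 'ed')
+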
+ """
+ kwargs["isouter"] = True
+ return self.join(target, *props, **kwargs)
+
+ @_generative
+ @_assertions(_no_statement_condition)
+ def reset_joinpoint(self):
+ """Return a new :class:`.Query`, where the "join point" has
+ been reset back to the base FROM entities of the query.
+
+ This method is usually used in conjunction with the
+ ``aliased=True`` feature of the :meth:`~.Query.join`
+ method. See the example in :meth:`~.Query.join` for how
+ this is used.
+
+ """
+ self._last_joined_entity = None
+ self._aliased_generation = None
+
+ @_generative
+ @_assertions(_no_clauseelement_condition)
+ def select_from(self, *from_obj):
+ r"""Set the FROM clause of this :class:`.Query` explicitly.
+
+ :meth:`.Query.select_from` is often used in conjunction with
+ :meth:`.Query.join` in order to control which entity is selected
+ from on the "left" side of the join.
+
+ The entity or selectable object here effectively replaces the
+ "left edge" of any calls to :meth:`~.Query.join`, when no
+ joinpoint is otherwise established - usually, the default "join
+ point" is the leftmost entity in the :class:`~.Query` object's
+ list of entities to be selected.
+
+ A typical example::
+
+ q = session.query(Address).select_from(User).\
+ join(User.addresses).\
+ filter(User.name == 'ed')
+
+ Which produces SQL equivalent to::
+
+ SELECT address.* FROM user
+ JOIN address ON user.id=address.user_id
+ WHERE user.name = :name_1
+
+ :param \*from_obj: collection of one or more entities to apply
+ to the FROM clause. Entities can be mapped classes,
+ :class:`.AliasedClass` objects, :class:`.Mapper` objects
+ as well as core :class:`.FromClause` elements like subqueries.
+
+ .. versionchanged:: 0.9
+ This method no longer applies the given FROM object
+ to be the selectable from which matching entities
+ select from; the :meth:`.select_entity_from` method
+ now accomplishes this. See that method for a description
+ of this behavior.
+
+ .. seealso::
+
+ :meth:`~.Query.join`
+
+ :meth:`.Query.select_entity_from`
+
+ """
+
+ self._set_select_from(from_obj, False)
+
+ @util.deprecated_20(
+ ":meth:`_orm.Query.select_entity_from`",
+ alternative="Use the :func:`_orm.aliased` construct instead",
+ )
+ @_generative
+ @_assertions(_no_clauseelement_condition)
+ def select_entity_from(self, from_obj):
+ r"""Set the FROM clause of this :class:`_query.Query` to a
+ core selectable, applying it as a replacement FROM clause
+ for corresponding mapped entities.
+
+ The :meth:`_query.Query.select_entity_from`
+ method supplies an alternative
+ approach to the use case of applying an :func:`.aliased` construct
+ explicitly throughout a query. Instead of referring to the
+ :func:`.aliased` construct explicitly,
+ :meth:`_query.Query.select_entity_from` automatically *adapts* all
+ occurrences of the entity to the target selectable.
+
+ Given a case for :func:`.aliased` such as selecting ``User``
+ objects from a SELECT statement::
+
+ select_stmt = select(User).where(User.id == 7)
+ user_alias = aliased(User, select_stmt)
+
+ q = session.query(user_alias).\
+ filter(user_alias.name == 'ed')
+
+ Above, we apply the ``user_alias`` object explicitly throughout the
+ query. When it's not feasible for ``user_alias`` to be referenced
+ explicitly in many places, :meth:`_query.Query.select_entity_from`
+ may be
+ used at the start of the query to adapt the existing ``User`` entity::
+
+ q = session.query(User).\
+ select_entity_from(select_stmt.subquery()).\
+ filter(User.name == 'ed')
+
+ Above, the generated SQL will show that the ``User`` entity is
+ adapted to our statement, even in the case of the WHERE clause:
+
+ .. sourcecode:: sql
+
+ SELECT anon_1.id AS anon_1_id, anon_1.name AS anon_1_name
+ FROM (SELECT "user".id AS id, "user".name AS name
+ FROM "user"
+ WHERE "user".id = :id_1) AS anon_1
+ WHERE anon_1.name = :name_1
+
+ The :meth:`_query.Query.select_entity_from` method is similar to the
+ :meth:`_query.Query.select_from` method,
+ in that it sets the FROM clause
+ of the query. The difference is that it additionally applies
+ adaptation to the other parts of the query that refer to the
+ primary entity. If above we had used :meth:`_query.Query.select_from`
+ instead, the SQL generated would have been:
+
+ .. sourcecode:: sql
+
+ -- uses plain select_from(), not select_entity_from()
+ SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user", (SELECT "user".id AS id, "user".name AS name
+ FROM "user"
+ WHERE "user".id = :id_1) AS anon_1
+ WHERE "user".name = :name_1
+
+ To supply textual SQL to the :meth:`_query.Query.select_entity_from`
+ method,
+ we can make use of the :func:`_expression.text` construct. However,
+ the
+ :func:`_expression.text`
+ construct needs to be aligned with the columns of our
+ entity, which is achieved by making use of the
+ :meth:`_expression.TextClause.columns` method::
+
+ text_stmt = text("select id, name from user").columns(
+ User.id, User.name).subquery()
+ q = session.query(User).select_entity_from(text_stmt)
+
+ :meth:`_query.Query.select_entity_from` itself accepts an
+ :func:`.aliased`
+ object, so that the special options of :func:`.aliased` such as
+ :paramref:`.aliased.adapt_on_names` may be used within the
+ scope of the :meth:`_query.Query.select_entity_from`
+ method's adaptation
+ services. Suppose
+ a view ``user_view`` also returns rows from ``user``. If
+ we reflect this view into a :class:`_schema.Table`, this view has no
+ relationship to the :class:`_schema.Table` to which we are mapped,
+ however
+ we can use name matching to select from it::
+
+ user_view = Table('user_view', metadata,
+ autoload_with=engine)
+ user_view_alias = aliased(
+ User, user_view, adapt_on_names=True)
+ q = session.query(User).\
+ select_entity_from(user_view_alias).\
+ order_by(User.name)
+
+ .. versionchanged:: 1.1.7 The :meth:`_query.Query.select_entity_from`
+ method now accepts an :func:`.aliased` object as an alternative
+ to a :class:`_expression.FromClause` object.
+
+ :param from_obj: a :class:`_expression.FromClause`
+ object that will replace
+ the FROM clause of this :class:`_query.Query`.
+ It also may be an instance
+ of :func:`.aliased`.
+
+
+
+ .. seealso::
+
+ :meth:`_query.Query.select_from`
+
+ """
+
+ self._set_select_from([from_obj], True)
+ self._compile_options += {"_enable_single_crit": False}
+
+ def __getitem__(self, item):
+ return orm_util._getitem(
+ self,
+ item,
+ allow_negative=not self.session or not self.session.future,
+ )
+
+ @_generative
+ @_assertions(_no_statement_condition)
+ def slice(self, start, stop):
+ """Computes the "slice" of the :class:`_query.Query` represented by
+ the given indices and returns the resulting :class:`_query.Query`.
+
+ The start and stop indices behave like the argument to Python's
+ built-in :func:`range` function. This method provides an
+ alternative to using ``LIMIT``/``OFFSET`` to get a slice of the
+ query.
+
+ For example, ::
+
+ session.query(User).order_by(User.id).slice(1, 3)
+
+ renders as
+
+ .. sourcecode:: sql
+
+ SELECT users.id AS users_id,
+ users.name AS users_name
+ FROM users ORDER BY users.id
+ LIMIT ? OFFSET ?
+ (2, 1)
+
+ .. seealso::
+
+ :meth:`_query.Query.limit`
+
+ :meth:`_query.Query.offset`
+
+ """
+
+ self._limit_clause, self._offset_clause = sql_util._make_slice(
+ self._limit_clause, self._offset_clause, start, stop
+ )
+
+ @_generative
+ @_assertions(_no_statement_condition)
+ def limit(self, limit):
+ """Apply a ``LIMIT`` to the query and return the newly resulting
+ ``Query``.
+
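+ E.g., a minimal sketch assuming the ``User`` mapping used in other
+ examples here::
+
+ session.query(User).order_by(User.id).limit(10)
+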
+ """
+ self._limit_clause = sql_util._offset_or_limit_clause(limit)
+
+ @_generative
+ @_assertions(_no_statement_condition)
+ def offset(self, offset):
+ """Apply an ``OFFSET`` to the query and return the newly resulting
+ ``Query``.
+
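+ E.g., a minimal sketch, typically combined with
+ :meth:`_query.Query.limit`::
+
+ session.query(User).order_by(User.id).limit(10).offset(20)
+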
+ """
+ self._offset_clause = sql_util._offset_or_limit_clause(offset)
+
+ @_generative
+ @_assertions(_no_statement_condition)
+ def distinct(self, *expr):
+ r"""Apply a ``DISTINCT`` to the query and return the newly resulting
+ ``Query``.
+
+
+ .. note::
+
+ The ORM-level :meth:`.distinct` call includes logic that will
+ automatically add columns from the ORDER BY of the query to the
+ columns clause of the SELECT statement, to satisfy the common need
+ of the database backend that ORDER BY columns be part of the SELECT
+ list when DISTINCT is used. These columns *are not* added to the
+ list of columns actually fetched by the :class:`_query.Query`,
+ however,
+ so would not affect results. The columns are passed through when
+ using the :attr:`_query.Query.statement` accessor, however.
+
+ .. deprecated:: 2.0 This logic is deprecated and will be removed
+ in SQLAlchemy 2.0. See :ref:`migration_20_query_distinct`
+ for a description of this use case in 2.0.
+
+ :param \*expr: optional column expressions. When present,
+ the PostgreSQL dialect will render a ``DISTINCT ON (<expressions>)``
+ construct.
+
+ .. deprecated:: 1.4 Using \*expr in other dialects is deprecated
+ and will raise :class:`_exc.CompileError` in a future version.
+
+ """
+ if expr:
+ self._distinct = True
+ self._distinct_on = self._distinct_on + tuple(
+ coercions.expect(roles.ByOfRole, e) for e in expr
+ )
+ else:
+ self._distinct = True
+
+ def all(self):
+ """Return the results represented by this :class:`_query.Query`
+ as a list.
+
+ This results in an execution of the underlying SQL statement.
+
+ .. warning:: The :class:`_query.Query` object,
+ when asked to return either
+ a sequence or iterator that consists of full ORM-mapped entities,
+ will **deduplicate entries based on primary key**. See the FAQ for
+ more details.
+
+ .. seealso::
+
+ :ref:`faq_query_deduplicating`
+ """
+ return self._iter().all()
+
+ @_generative
+ @_assertions(_no_clauseelement_condition)
+ def from_statement(self, statement):
+ """Execute the given SELECT statement and return results.
+
+ This method bypasses all internal statement compilation, and the
+ statement is executed without modification.
+
+ The statement is typically either a :func:`_expression.text`
+ or :func:`_expression.select` construct, and should return the set
+ of columns
+ appropriate to the entity class represented by this
+ :class:`_query.Query`.
+
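+ A brief sketch using :func:`_expression.text`, assuming the ``User``
+ mapping used in other examples, where the textual statement returns
+ the columns of the ``users`` table::
+
+ from sqlalchemy import text
+
+ q = session.query(User).from_statement(
+ text("SELECT * FROM users WHERE users.name = :name")
+ ).params(name='ed')
+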
+ .. seealso::
+
+ :ref:`orm_tutorial_literal_sql` - usage examples in the
+ ORM tutorial
+
+ """
+ statement = coercions.expect(
+ roles.SelectStatementRole, statement, apply_propagate_attrs=self
+ )
+ self._statement = statement
+
+ def first(self):
+ """Return the first result of this ``Query`` or
+ None if the result doesn't contain any row.
+
+ first() applies a limit of one within the generated SQL, so that
+ only one primary entity row is generated on the server side
+ (note this may consist of multiple result rows if join-loaded
+ collections are present).
+
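+ E.g., a minimal sketch assuming the ``User`` mapping used in other
+ examples::
+
+ user = session.query(User).\
+ filter(User.name == 'ed').\
+ order_by(User.id).first()
+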
+ Calling :meth:`_query.Query.first`
+ results in an execution of the underlying
+ query.
+
+ .. seealso::
+
+ :meth:`_query.Query.one`
+
+ :meth:`_query.Query.one_or_none`
+
+ """
+ # replicates limit(1) behavior
+ if self._statement is not None:
+ return self._iter().first()
+ else:
+ return self.limit(1)._iter().first()
+
+ def one_or_none(self):
+ """Return at most one result or raise an exception.
+
+ Returns ``None`` if the query selects
+ no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound``
+ if multiple object identities are returned, or if multiple
+ rows are returned for a query that returns only scalar values
+ as opposed to full identity-mapped entities.
+
+ Calling :meth:`_query.Query.one_or_none`
+ results in an execution of the
+ underlying query.
+
+ .. versionadded:: 1.0.9
+
+ Added :meth:`_query.Query.one_or_none`
+
+ .. seealso::
+
+ :meth:`_query.Query.first`
+
+ :meth:`_query.Query.one`
+
+ """
+ return self._iter().one_or_none()
+
+ def one(self):
+ """Return exactly one result or raise an exception.
+
+ Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects
+ no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound``
+ if multiple object identities are returned, or if multiple
+ rows are returned for a query that returns only scalar values
+ as opposed to full identity-mapped entities.
+
+ Calling :meth:`.one` results in an execution of the underlying query.
+
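+ A brief sketch of typical use, assuming the ``User`` mapping used
+ in other examples::
+
+ from sqlalchemy.orm.exc import NoResultFound
+
+ try:
+ user = session.query(User).filter(User.id == 7).one()
+ except NoResultFound:
+ user = None
+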
+ .. seealso::
+
+ :meth:`_query.Query.first`
+
+ :meth:`_query.Query.one_or_none`
+
+ """
+ return self._iter().one()
+
+ def scalar(self):
+ """Return the first element of the first result or None
+ if no rows are present. If multiple rows are returned,
+ raises MultipleResultsFound.
+
+ >>> session.query(Item).scalar()
+ <Item>
+ >>> session.query(Item.id).scalar()
+ 1
+ >>> session.query(Item.id).filter(Item.id < 0).scalar()
+ None
+ >>> session.query(Item.id, Item.name).scalar()
+ 1
+ >>> session.query(func.count(Parent.id)).scalar()
+ 20
+
+ This results in an execution of the underlying query.
+
+ """
+ # TODO: not sure why we can't use result.scalar() here
+ try:
+ ret = self.one()
+ if not isinstance(ret, collections_abc.Sequence):
+ return ret
+ return ret[0]
+ except orm_exc.NoResultFound:
+ return None
+
+ def __iter__(self):
+ return self._iter().__iter__()
+
+ def _iter(self):
+ # new style execution.
+ params = self._params
+
+ statement = self._statement_20()
+ result = self.session.execute(
+ statement,
+ params,
+ execution_options={"_sa_orm_load_options": self.load_options},
+ )
+
+ # legacy: automatically set scalars, unique
+ if result._attributes.get("is_single_entity", False):
+ result = result.scalars()
+
+ if (
+ result._attributes.get("filtered", False)
+ and not self.load_options._yield_per
+ ):
+ result = result.unique()
+
+ return result
+
+ def __str__(self):
+ statement = self._statement_20()
+
+ try:
+ bind = (
+ self._get_bind_args(statement, self.session.get_bind)
+ if self.session
+ else None
+ )
+ except sa_exc.UnboundExecutionError:
+ bind = None
+
+ return str(statement.compile(bind))
+
+ def _get_bind_args(self, statement, fn, **kw):
+ return fn(clause=statement, **kw)
+
+ @property
+ def column_descriptions(self):
+ """Return metadata about the columns which would be
+ returned by this :class:`_query.Query`.
+
+ Format is a list of dictionaries::
+
+ user_alias = aliased(User, name='user2')
+ q = sess.query(User, User.id, user_alias)
+
+ # this expression:
+ q.column_descriptions
+
+ # would return:
+ [
+ {
+ 'name':'User',
+ 'type':User,
+ 'aliased':False,
+ 'expr':User,
+ 'entity': User
+ },
+ {
+ 'name':'id',
+ 'type':Integer(),
+ 'aliased':False,
+ 'expr':User.id,
+ 'entity': User
+ },
+ {
+ 'name':'user2',
+ 'type':User,
+ 'aliased':True,
+ 'expr':user_alias,
+ 'entity': user_alias
+ }
+ ]
+
+ .. seealso::
+
+ This API is available using :term:`2.0 style` queries as well,
+ documented at:
+
+ * :ref:`queryguide_inspection`
+
+ * :attr:`.Select.column_descriptions`
+
+ """
+
+ return _column_descriptions(self, legacy=True)
+
+ def instances(self, result_proxy, context=None):
+ """Return an ORM result given a :class:`_engine.CursorResult` and
+ :class:`.QueryContext`.
+
+ """
+ if context is None:
+ util.warn_deprecated(
+ "Using the Query.instances() method without a context "
+ "is deprecated and will be disallowed in a future release. "
+ "Please make use of :meth:`_query.Query.from_statement` "
+ "for linking ORM results to arbitrary select constructs.",
+ version="1.4",
+ )
+ compile_state = self._compile_state(for_statement=False)
+
+ context = QueryContext(
+ compile_state,
+ compile_state.statement,
+ self._params,
+ self.session,
+ self.load_options,
+ )
+
+ result = loading.instances(result_proxy, context)
+
+ # legacy: automatically set scalars, unique
+ if result._attributes.get("is_single_entity", False):
+ result = result.scalars()
+
+ if result._attributes.get("filtered", False):
+ result = result.unique()
+
+ return result
+
+ @util.deprecated_20(
+ ":meth:`_orm.Query.merge_result`",
+ alternative="The method is superseded by the "
+ ":func:`_orm.merge_frozen_result` function.",
+ becomes_legacy=True,
+ enable_warnings=False, # warnings occur via loading.merge_result
+ )
+ def merge_result(self, iterator, load=True):
+ """Merge a result into this :class:`_query.Query` object's Session.
+
+ Given an iterator returned by a :class:`_query.Query`
+ of the same structure
+ as this one, return an identical iterator of results, with all mapped
+ instances merged into the session using :meth:`.Session.merge`. This
+ is an optimized method which will merge all mapped instances,
+ preserving the structure of the result rows and unmapped columns with
+ less method overhead than that of calling :meth:`.Session.merge`
+ explicitly for each value.
+
+ The structure of the results is determined based on the column list of
+ this :class:`_query.Query` - if these do not correspond,
+ unchecked errors
+ will occur.
+
+ The 'load' argument is the same as that of :meth:`.Session.merge`.
+
+ For an example of how :meth:`_query.Query.merge_result` is used, see
+ the source code for the example :ref:`examples_caching`, where
+ :meth:`_query.Query.merge_result` is used to efficiently restore state
+ from a cache back into a target :class:`.Session`.
+
+ """
+
+ return loading.merge_result(self, iterator, load)
+
+ def exists(self):
+ """A convenience method that turns a query into an EXISTS subquery
+ of the form EXISTS (SELECT 1 FROM ... WHERE ...).
+
+ e.g.::
+
+ q = session.query(User).filter(User.name == 'fred')
+ session.query(q.exists())
+
+ Producing SQL similar to::
+
+ SELECT EXISTS (
+ SELECT 1 FROM users WHERE users.name = :name_1
+ ) AS anon_1
+
+ The EXISTS construct is usually used in the WHERE clause::
+
+ session.query(User.id).filter(q.exists()).scalar()
+
+ Note that some databases such as SQL Server don't allow an
+ EXISTS expression to be present in the columns clause of a
+ SELECT. To select a simple boolean value based on the exists
+ as a WHERE, use :func:`.literal`::
+
+ from sqlalchemy import literal
+
+ session.query(literal(True)).filter(q.exists()).scalar()
+
+ """
+
+ # .add_columns() for the case that we are a query().select_from(X),
+ # so that ".statement" can be produced (#2995) but also without
+ # omitting the FROM clause from a query(X) (#2818);
+ # .with_only_columns() after we have a core select() so that
+ # we get just "SELECT 1" without any entities.
+
+ inner = (
+ self.enable_eagerloads(False)
+ .add_columns(sql.literal_column("1"))
+ .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ .statement.with_only_columns(1)
+ )
+
+ ezero = self._entity_from_pre_ent_zero()
+ if ezero is not None:
+ inner = inner.select_from(ezero)
+
+ return sql.exists(inner)
+
+ def count(self):
+ r"""Return a count of rows this the SQL formed by this :class:`Query`
+ would return.
+
+ This generates the SQL for this Query as follows::
+
+ SELECT count(*) AS count_1 FROM (
+ SELECT <rest of query follows...>
+ ) AS anon_1
+
+ The above SQL returns a single row, which is the aggregate value
+ of the count function; the :meth:`_query.Query.count`
+ method then returns
+ that single integer value.
+
+ .. warning::
+
+ It is important to note that the value returned by
+ count() is **not the same as the number of ORM objects that this
+ Query would return from a method such as the .all() method**.
+ The :class:`_query.Query` object,
+ when asked to return full entities,
+ will **deduplicate entries based on primary key**, meaning if the
+ same primary key value would appear in the results more than once,
+ only one object of that primary key would be present. This does
+ not apply to a query that is against individual columns.
+
+ .. seealso::
+
+ :ref:`faq_query_deduplicating`
+
+ :ref:`orm_tutorial_query_returning`
+
+ For fine grained control over specific columns to count, to skip the
+ usage of a subquery or otherwise control of the FROM clause, or to use
+ other aggregate functions, use :attr:`~sqlalchemy.sql.expression.func`
+ expressions in conjunction with :meth:`~.Session.query`, i.e.::
+
+ from sqlalchemy import func
+
+ # count User records, without
+ # using a subquery.
+ session.query(func.count(User.id))
+
+ # return count of user "id" grouped
+ # by "name"
+ session.query(func.count(User.id)).\
+ group_by(User.name)
+
+ from sqlalchemy import distinct
+
+ # count distinct "name" values
+ session.query(func.count(distinct(User.name)))
+
+ """
+ col = sql.func.count(sql.literal_column("*"))
+ return self._from_self(col).enable_eagerloads(False).scalar()
+
+ def delete(self, synchronize_session="evaluate"):
+ r"""Perform a DELETE with an arbitrary WHERE clause.
+
+ Deletes rows matched by this query from the database.
+
+ E.g.::
+
+ sess.query(User).filter(User.age == 25).\
+ delete(synchronize_session=False)
+
+ sess.query(User).filter(User.age == 25).\
+ delete(synchronize_session='evaluate')
+
+ .. warning::
+
+ See the section :ref:`orm_expression_update_delete` for important
+ caveats and warnings, including limitations when using bulk UPDATE
+ and DELETE with mapper inheritance configurations.
+
+ :param synchronize_session: chooses the strategy to update the
+ attributes on objects in the session. See the section
+ :ref:`orm_expression_update_delete` for a discussion of these
+ strategies.
+
+ :return: the count of rows matched as returned by the database's
+ "row count" feature.
+
+ .. seealso::
+
+ :ref:`orm_expression_update_delete`
+
+ """
+
+ bulk_del = BulkDelete(self)
+ if self.dispatch.before_compile_delete:
+ for fn in self.dispatch.before_compile_delete:
+ new_query = fn(bulk_del.query, bulk_del)
+ if new_query is not None:
+ bulk_del.query = new_query
+
+ self = bulk_del.query
+
+ delete_ = sql.delete(*self._raw_columns)
+ delete_._where_criteria = self._where_criteria
+ result = self.session.execute(
+ delete_,
+ self._params,
+ execution_options={"synchronize_session": synchronize_session},
+ )
+ bulk_del.result = result
+ self.session.dispatch.after_bulk_delete(bulk_del)
+ result.close()
+
+ return result.rowcount
+
+ def update(self, values, synchronize_session="evaluate", update_args=None):
+ r"""Perform an UPDATE with an arbitrary WHERE clause.
+
+ Updates rows matched by this query in the database.
+
+ E.g.::
+
+ sess.query(User).filter(User.age == 25).\
+ update({User.age: User.age - 10}, synchronize_session=False)
+
+ sess.query(User).filter(User.age == 25).\
+ update({"age": User.age - 10}, synchronize_session='evaluate')
+
+ .. warning::
+
+ See the section :ref:`orm_expression_update_delete` for important
+ caveats and warnings, including limitations when using arbitrary
+ UPDATE and DELETE with mapper inheritance configurations.
+
+ :param values: a dictionary with attributes names, or alternatively
+ mapped attributes or SQL expressions, as keys, and literal
+ values or sql expressions as values. If :ref:`parameter-ordered
+ mode <tutorial_parameter_ordered_updates>` is desired, the values can
+ be passed as a list of 2-tuples; this requires that the
+ :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`
+ flag is passed to the :paramref:`.Query.update.update_args` dictionary
+ as well.
+
+ :param synchronize_session: chooses the strategy to update the
+ attributes on objects in the session. See the section
+ :ref:`orm_expression_update_delete` for a discussion of these
+ strategies.
+
+ :param update_args: Optional dictionary, if present will be passed
+ to the underlying :func:`_expression.update`
+ construct as the ``**kw`` for
+ the object. May be used to pass dialect-specific arguments such
+ as ``mysql_limit``, as well as other special arguments such as
+ :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`.
+
+ :return: the count of rows matched as returned by the database's
+ "row count" feature.
+
+
+ .. seealso::
+
+ :ref:`orm_expression_update_delete`
+
+
+ """
+
+ update_args = update_args or {}
+
+ bulk_ud = BulkUpdate(self, values, update_args)
+
+ if self.dispatch.before_compile_update:
+ for fn in self.dispatch.before_compile_update:
+ new_query = fn(bulk_ud.query, bulk_ud)
+ if new_query is not None:
+ bulk_ud.query = new_query
+ self = bulk_ud.query
+
+ upd = sql.update(*self._raw_columns)
+
+ ppo = update_args.pop("preserve_parameter_order", False)
+ if ppo:
+ upd = upd.ordered_values(*values)
+ else:
+ upd = upd.values(values)
+ if update_args:
+ upd = upd.with_dialect_options(**update_args)
+
+ upd._where_criteria = self._where_criteria
+ result = self.session.execute(
+ upd,
+ self._params,
+ execution_options={"synchronize_session": synchronize_session},
+ )
+ bulk_ud.result = result
+ self.session.dispatch.after_bulk_update(bulk_ud)
+ result.close()
+ return result.rowcount
+
+ def _compile_state(self, for_statement=False, **kw):
+ """Create an out-of-compiler ORMCompileState object.
+
+ The ORMCompileState object is normally created directly as a result
+ of the SQLCompiler.process() method being handed a Select()
+ or FromStatement() object that uses the "orm" plugin. This method
+ provides a means of creating this ORMCompileState object directly
+ without using the compiler.
+
+ This method is used only for deprecated cases, which include
+ the .from_self() method for a Query that has multiple levels
+ of .from_self() in use, as well as the instances() method. It is
+ also used within the test suite to generate ORMCompileState objects
+ for test purposes.
+
+ """
+
+ stmt = self._statement_20(for_statement=for_statement, **kw)
+ assert for_statement == stmt._compile_options._for_statement
+
+ # this chooses between ORMFromStatementCompileState and
+ # ORMSelectCompileState. We could also base this on
+ # query._statement is not None as we have the ORM Query here
+ # however this is the more general path.
+ compile_state_cls = ORMCompileState._get_plugin_class_for_plugin(
+ stmt, "orm"
+ )
+
+ return compile_state_cls.create_for_statement(stmt, None)
+
+ def _compile_context(self, for_statement=False):
+ compile_state = self._compile_state(for_statement=for_statement)
+ context = QueryContext(
+ compile_state,
+ compile_state.statement,
+ self._params,
+ self.session,
+ self.load_options,
+ )
+
+ return context
+
+
+class FromStatement(GroupedElement, SelectBase, Executable):
+ """Core construct that represents a load of ORM objects from a finished
+ select or text construct.
+
+ """
+
+ __visit_name__ = "orm_from_statement"
+
+ _compile_options = ORMFromStatementCompileState.default_compile_options
+
+ _compile_state_factory = ORMFromStatementCompileState.create_for_statement
+
+ _for_update_arg = None
+
+ _traverse_internals = [
+ ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ ("element", InternalTraversal.dp_clauseelement),
+ ] + Executable._executable_traverse_internals
+
+ _cache_key_traversal = _traverse_internals + [
+ ("_compile_options", InternalTraversal.dp_has_cache_key)
+ ]
+
+ def __init__(self, entities, element):
+ self._raw_columns = [
+ coercions.expect(
+ roles.ColumnsClauseRole,
+ ent,
+ apply_propagate_attrs=self,
+ post_inspect=True,
+ )
+ for ent in util.to_list(entities)
+ ]
+ self.element = element
+
+ def get_label_style(self):
+ return self._label_style
+
+ def set_label_style(self, label_style):
+ return SelectStatementGrouping(
+ self.element.set_label_style(label_style)
+ )
+
+ @property
+ def _label_style(self):
+ return self.element._label_style
+
+ def _compiler_dispatch(self, compiler, **kw):
+
+ """provide a fixed _compiler_dispatch method.
+
+ This is roughly similar to using the sqlalchemy.ext.compiler
+ ``@compiles`` extension.
+
+ """
+
+ compile_state = self._compile_state_factory(self, compiler, **kw)
+
+ toplevel = not compiler.stack
+
+ if toplevel:
+ compiler.compile_state = compile_state
+
+ return compiler.process(compile_state.statement, **kw)
+
+ def _ensure_disambiguated_names(self):
+ return self
+
+ def get_children(self, **kw):
+ for elem in itertools.chain.from_iterable(
+ element._from_objects for element in self._raw_columns
+ ):
+ yield elem
+ for elem in super(FromStatement, self).get_children(**kw):
+ yield elem
+
+ @property
+ def _returning(self):
+ return self.element._returning if self.element.is_dml else None
+
+ @property
+ def _inline(self):
+ return self.element._inline if self.element.is_dml else None
+
+
+class AliasOption(interfaces.LoaderOption):
+ @util.deprecated(
+ "1.4",
+ "The :class:`.AliasOption` is not necessary "
+ "for entities to be matched up to a query that is established "
+ "via :meth:`.Query.from_statement` and now does nothing.",
+ )
+ def __init__(self, alias):
+ r"""Return a :class:`.MapperOption` that will indicate to the
+ :class:`_query.Query`
+ that the main table has been aliased.
+
+ """
+
+ inherit_cache = False
+
+ def process_compile_state(self, compile_state):
+ pass
+
+
+class BulkUD(object):
+ """State used for the orm.Query version of update() / delete().
+
+ This object is now specific to Query only.
+
+ """
+
+ def __init__(self, query):
+ self.query = query.enable_eagerloads(False)
+ self._validate_query_state()
+ self.mapper = self.query._entity_from_pre_ent_zero()
+
+ def _validate_query_state(self):
+ for attr, methname, notset, op in (
+ ("_limit_clause", "limit()", None, operator.is_),
+ ("_offset_clause", "offset()", None, operator.is_),
+ ("_order_by_clauses", "order_by()", (), operator.eq),
+ ("_group_by_clauses", "group_by()", (), operator.eq),
+ ("_distinct", "distinct()", False, operator.is_),
+ (
+ "_from_obj",
+ "join(), outerjoin(), select_from(), or from_self()",
+ (),
+ operator.eq,
+ ),
+ (
+ "_legacy_setup_joins",
+ "join(), outerjoin(), select_from(), or from_self()",
+ (),
+ operator.eq,
+ ),
+ ):
+ if not op(getattr(self.query, attr), notset):
+ raise sa_exc.InvalidRequestError(
+ "Can't call Query.update() or Query.delete() "
+ "when %s has been called" % (methname,)
+ )
+
+ @property
+ def session(self):
+ return self.query.session
+
+
+class BulkUpdate(BulkUD):
+ """BulkUD which handles UPDATEs."""
+
+ def __init__(self, query, values, update_kwargs):
+ super(BulkUpdate, self).__init__(query)
+ self.values = values
+ self.update_kwargs = update_kwargs
+
+
+class BulkDelete(BulkUD):
+ """BulkUD which handles DELETEs."""
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
new file mode 100644
index 0000000..b51ea0e
--- /dev/null
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -0,0 +1,3684 @@
+# orm/relationships.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Heuristics related to join conditions as used in
+:func:`_orm.relationship`.
+
+Provides the :class:`.JoinCondition` object, which encapsulates
+SQL annotation and aliasing behavior focused on the `primaryjoin`
+and `secondaryjoin` aspects of :func:`_orm.relationship`.
+
+"""
+from __future__ import absolute_import
+
+import collections
+import re
+import weakref
+
+from . import attributes
+from .base import _is_mapped_class
+from .base import state_str
+from .interfaces import MANYTOMANY
+from .interfaces import MANYTOONE
+from .interfaces import ONETOMANY
+from .interfaces import PropComparator
+from .interfaces import StrategizedProperty
+from .util import _orm_annotate
+from .util import _orm_deannotate
+from .util import CascadeOptions
+from .. import exc as sa_exc
+from .. import log
+from .. import schema
+from .. import sql
+from .. import util
+from ..inspection import inspect
+from ..sql import coercions
+from ..sql import expression
+from ..sql import operators
+from ..sql import roles
+from ..sql import visitors
+from ..sql.util import _deep_deannotate
+from ..sql.util import _shallow_annotate
+from ..sql.util import adapt_criterion_to_null
+from ..sql.util import ClauseAdapter
+from ..sql.util import join_condition
+from ..sql.util import selectables_overlap
+from ..sql.util import visit_binary_product
+
+
+def remote(expr):
+ """Annotate a portion of a primaryjoin expression
+ with a 'remote' annotation.
+
+ See the section :ref:`relationship_custom_foreign` for a
+ description of use.
+
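+ A minimal sketch, assuming a hypothetical self-referential ``Node``
+ declarative mapping (standard ``sqlalchemy`` imports assumed) where
+ the "remote" side of the join is annotated explicitly, similar to the
+ adjacency-list pattern described in the linked section::
+
+ class Node(Base):
+ __tablename__ = 'node'
+ id = Column(Integer, primary_key=True)
+ parent_id = Column(Integer, ForeignKey('node.id'))
+
+ parent = relationship(
+ "Node",
+ primaryjoin=remote(id) == foreign(parent_id),
+ )
+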
+ .. seealso::
+
+ :ref:`relationship_custom_foreign`
+
+ :func:`.foreign`
+
+ """
+ return _annotate_columns(
+ coercions.expect(roles.ColumnArgumentRole, expr), {"remote": True}
+ )
+
+
+def foreign(expr):
+ """Annotate a portion of a primaryjoin expression
+ with a 'foreign' annotation.
+
+ See the section :ref:`relationship_custom_foreign` for a
+ description of use.
+
+ .. seealso::
+
+ :ref:`relationship_custom_foreign`
+
+ :func:`.remote`
+
+ """
+
+ return _annotate_columns(
+ coercions.expect(roles.ColumnArgumentRole, expr), {"foreign": True}
+ )
+
+
+@log.class_logger
+class RelationshipProperty(StrategizedProperty):
+ """Describes an object property that holds a single item or list
+ of items that correspond to a related database table.
+
+ Public constructor is the :func:`_orm.relationship` function.
+
+ .. seealso::
+
+ :ref:`relationship_config_toplevel`
+
+ """
+
+ strategy_wildcard_key = "relationship"
+ inherit_cache = True
+
+ _links_to_entity = True
+
+ _persistence_only = dict(
+ passive_deletes=False,
+ passive_updates=True,
+ enable_typechecks=True,
+ active_history=False,
+ cascade_backrefs=True,
+ )
+
+ _dependency_processor = None
+
+ def __init__(
+ self,
+ argument,
+ secondary=None,
+ primaryjoin=None,
+ secondaryjoin=None,
+ foreign_keys=None,
+ uselist=None,
+ order_by=False,
+ backref=None,
+ back_populates=None,
+ overlaps=None,
+ post_update=False,
+ cascade=False,
+ viewonly=False,
+ lazy="select",
+ collection_class=None,
+ passive_deletes=_persistence_only["passive_deletes"],
+ passive_updates=_persistence_only["passive_updates"],
+ remote_side=None,
+ enable_typechecks=_persistence_only["enable_typechecks"],
+ join_depth=None,
+ comparator_factory=None,
+ single_parent=False,
+ innerjoin=False,
+ distinct_target_key=None,
+ doc=None,
+ active_history=_persistence_only["active_history"],
+ cascade_backrefs=_persistence_only["cascade_backrefs"],
+ load_on_pending=False,
+ bake_queries=True,
+ _local_remote_pairs=None,
+ query_class=None,
+ info=None,
+ omit_join=None,
+ sync_backref=None,
+ _legacy_inactive_history_style=False,
+ ):
+ """Provide a relationship between two mapped classes.
+
+ This corresponds to a parent-child or associative table relationship.
+ The constructed class is an instance of
+ :class:`.RelationshipProperty`.
+
+ A typical :func:`_orm.relationship`, used in a classical mapping::
+
+ mapper(Parent, properties={
+ 'children': relationship(Child)
+ })
+
+ Some arguments accepted by :func:`_orm.relationship`
+ optionally accept a
+ callable function, which when called produces the desired value.
+ The callable is invoked by the parent :class:`_orm.Mapper` at "mapper
+ initialization" time, which happens only when mappers are first used,
+ and is assumed to be after all mappings have been constructed. This
+ can be used to resolve order-of-declaration and other dependency
+ issues, such as if ``Child`` is declared below ``Parent`` in the same
+ file::
+
+ mapper(Parent, properties={
+ "children":relationship(lambda: Child,
+ order_by=lambda: Child.id)
+ })
+
+ When using the :ref:`declarative_toplevel` extension, the Declarative
+ initializer allows string arguments to be passed to
+ :func:`_orm.relationship`. These string arguments are converted into
+ callables that evaluate the string as Python code, using the
+ Declarative class-registry as a namespace. This allows the lookup of
+ related classes to be automatic via their string name, and removes the
+ need for related classes to be imported into the local module space
+ before the dependent classes have been declared. It is still required
+ that the modules in which these related classes appear are imported
+ anywhere in the application at some point before the related mappings
+ are actually used, else a lookup error will be raised when the
+ :func:`_orm.relationship`
+ attempts to resolve the string reference to the
+ related class. An example of a string-resolved class is as
+ follows::
+
+ from sqlalchemy.ext.declarative import declarative_base
+
+ Base = declarative_base()
+
+ class Parent(Base):
+ __tablename__ = 'parent'
+ id = Column(Integer, primary_key=True)
+ children = relationship("Child", order_by="Child.id")
+
+ .. seealso::
+
+ :ref:`relationship_config_toplevel` - Full introductory and
+ reference documentation for :func:`_orm.relationship`.
+
+ :ref:`tutorial_orm_related_objects` - ORM tutorial introduction.
+
+ :param argument:
+ A mapped class, or actual :class:`_orm.Mapper` instance,
+ representing
+ the target of the relationship.
+
+ :paramref:`_orm.relationship.argument`
+ may also be passed as a callable
+ function which is evaluated at mapper initialization time, and may
+ be passed as a string name when using Declarative.
+
+ .. warning:: Prior to SQLAlchemy 1.3.16, this value is interpreted
+ using Python's ``eval()`` function.
+ **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
+ See :ref:`declarative_relationship_eval` for details on
+ declarative evaluation of :func:`_orm.relationship` arguments.
+
+ .. versionchanged:: 1.3.16
+
+ The string evaluation of the main "argument" no longer accepts an
+ open ended Python expression, instead only accepting a string
+ class name or dotted package-qualified name.
+
+ .. seealso::
+
+ :ref:`declarative_configuring_relationships` - further detail
+ on relationship configuration when using Declarative.
+
+ :param secondary:
+ For a many-to-many relationship, specifies the intermediary
+ table, and is typically an instance of :class:`_schema.Table`.
+ In less common circumstances, the argument may also be specified
+ as an :class:`_expression.Alias` construct, or even a
+ :class:`_expression.Join` construct.
+
+ :paramref:`_orm.relationship.secondary` may
+ also be passed as a callable function which is evaluated at
+ mapper initialization time. When using Declarative, it may also
+ be a string argument noting the name of a :class:`_schema.Table`
+ that is
+ present in the :class:`_schema.MetaData`
+ collection associated with the
+ parent-mapped :class:`_schema.Table`.
+
+ .. warning:: When passed as a Python-evaluable string, the
+ argument is interpreted using Python's ``eval()`` function.
+ **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
+ See :ref:`declarative_relationship_eval` for details on
+ declarative evaluation of :func:`_orm.relationship` arguments.
+
+ The :paramref:`_orm.relationship.secondary` keyword argument is
+ typically applied in the case where the intermediary
+ :class:`_schema.Table`
+ is not otherwise expressed in any direct class mapping. If the
+ "secondary" table is also explicitly mapped elsewhere (e.g. as in
+ :ref:`association_pattern`), one should consider applying the
+ :paramref:`_orm.relationship.viewonly` flag so that this
+ :func:`_orm.relationship`
+ is not used for persistence operations which
+ may conflict with those of the association object pattern.
+
+ .. seealso::
+
+ :ref:`relationships_many_to_many` - Reference example of "many
+ to many".
+
+ :ref:`self_referential_many_to_many` - Specifics on using
+ many-to-many in a self-referential case.
+
+ :ref:`declarative_many_to_many` - Additional options when using
+ Declarative.
+
+ :ref:`association_pattern` - an alternative to
+ :paramref:`_orm.relationship.secondary`
+ when composing association
+ table relationships, allowing additional attributes to be
+ specified on the association table.
+
+ :ref:`composite_secondary_join` - a lesser-used pattern which
+ in some cases can enable complex :func:`_orm.relationship` SQL
+ conditions to be used.
+
+ .. versionadded:: 0.9.2 :paramref:`_orm.relationship.secondary`
+ works
+ more effectively when referring to a :class:`_expression.Join`
+ instance.
+
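+ A minimal many-to-many sketch illustrating
+ :paramref:`_orm.relationship.secondary`, assuming hypothetical
+ ``Parent`` and ``Child`` declarative classes and an association
+ :class:`_schema.Table` (standard ``sqlalchemy`` imports assumed)::
+
+ association_table = Table(
+ 'parent_child', Base.metadata,
+ Column('parent_id', ForeignKey('parent.id'), primary_key=True),
+ Column('child_id', ForeignKey('child.id'), primary_key=True)
+ )
+
+ class Parent(Base):
+ __tablename__ = 'parent'
+ id = Column(Integer, primary_key=True)
+ children = relationship("Child", secondary=association_table)
+
+ class Child(Base):
+ __tablename__ = 'child'
+ id = Column(Integer, primary_key=True)
+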
+ :param active_history=False:
+ When ``True``, indicates that the "previous" value for a
+ many-to-one reference should be loaded when replaced, if
+ not already loaded. Normally, history tracking logic for
+ simple many-to-ones only needs to be aware of the "new"
+ value in order to perform a flush. This flag is available
+ for applications that make use of
+ :func:`.attributes.get_history` which also need to know
+ the "previous" value of the attribute.
+
+ :param backref:
+ A reference to a string relationship name, or a :func:`_orm.backref`
+ construct, which will be used to automatically generate a new
+ :func:`_orm.relationship` on the related class, which then refers to
+ this one using a bi-directional
+ :paramref:`_orm.relationship.back_populates` configuration.
+
+ In modern Python, explicit use of :func:`_orm.relationship` with
+ :paramref:`_orm.relationship.back_populates` should be preferred, as
+ it is more robust in terms of mapper configuration as well as more
+ conceptually straightforward. It also integrates with new :pep:`484`
+ typing features introduced in SQLAlchemy 2.0 which is not possible
+ with dynamically generated attributes.
+
+ .. seealso::
+
+ :ref:`relationships_backref` - notes on using
+ :paramref:`_orm.relationship.backref`
+
+ :ref:`tutorial_orm_related_objects` - in the
+ :ref:`unified_tutorial`, presents an overview of bi-directional
+ relationship configuration and behaviors using
+ :paramref:`_orm.relationship.back_populates`
+
+ :func:`.backref` - allows control over :func:`_orm.relationship`
+ configuration when using :paramref:`_orm.relationship.backref`.
+
+
+ :param back_populates:
+ Indicates the name of a :func:`_orm.relationship` on the related
+ class that will be synchronized with this one. It is usually
+ expected that the :func:`_orm.relationship` on the related class
+ also refer to this one. This allows objects on both sides of
+ each :func:`_orm.relationship` to synchronize in-Python state
+ changes and also provides directives to the :term:`unit of work`
+ flush process how changes along these relationships should
+ be persisted.
+
+ .. seealso::
+
+ :ref:`tutorial_orm_related_objects` - in the
+ :ref:`unified_tutorial`, presents an overview of bi-directional
+ relationship configuration and behaviors.
+
+ :ref:`relationship_patterns` - includes many examples of
+ :paramref:`_orm.relationship.back_populates`.
+
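+ A brief sketch of a bi-directional configuration using
+ :paramref:`_orm.relationship.back_populates`, assuming hypothetical
+ ``Parent`` / ``Child`` declarative classes::
+
+ class Parent(Base):
+ __tablename__ = 'parent'
+ id = Column(Integer, primary_key=True)
+ children = relationship("Child", back_populates="parent")
+
+ class Child(Base):
+ __tablename__ = 'child'
+ id = Column(Integer, primary_key=True)
+ parent_id = Column(Integer, ForeignKey('parent.id'))
+ parent = relationship("Parent", back_populates="children")
+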
+ :param overlaps:
+ A string name or comma-delimited set of names of other relationships
+ on either this mapper, a descendant mapper, or a target mapper with
+ which this relationship may write to the same foreign keys upon
+ persistence. The only effect this has is to eliminate the
+ warning that this relationship will conflict with another upon
+ persistence. This is used for relationships that are truly
+ capable of conflicting with each other on write, where the
+ application will ensure that no such conflicts actually occur.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`error_qzyx` - usage example
+
+ :param bake_queries=True:
+ Legacy parameter, not used.
+
+ .. versionchanged:: 1.4.23 the "lambda caching" system is no longer
+ used by loader strategies and the ``bake_queries`` parameter
+ has no effect.
+
+ :param cascade:
+ A comma-separated list of cascade rules which determines how
+ Session operations should be "cascaded" from parent to child.
+ This defaults to ``False``, which means the default cascade
+ should be used - this default cascade is ``"save-update, merge"``.
+
+ The available cascades are ``save-update``, ``merge``,
+ ``expunge``, ``delete``, ``delete-orphan``, and ``refresh-expire``.
+ An additional option, ``all``, indicates shorthand for
+ ``"save-update, merge, refresh-expire,
+ expunge, delete"``, and is often used as in ``"all, delete-orphan"``
+ to indicate that related objects should follow along with the
+ parent object in all cases, and be deleted when de-associated.
+
+ .. seealso::
+
+ :ref:`unitofwork_cascades` - Full detail on each of the available
+ cascade options.
+
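+ E.g., a brief sketch of a collection that should be deleted along
+ with its parent, using the ``"all, delete-orphan"`` setting noted
+ above::
+
+ children = relationship("Child", cascade="all, delete-orphan")
+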
+ :param cascade_backrefs=True:
+ A boolean value indicating if the ``save-update`` cascade should
+ operate along an assignment event intercepted by a backref.
+ When set to ``False``, the attribute managed by this relationship
+ will not cascade an incoming transient object into the session of a
+ persistent parent, if the event is received via backref.
+
+ .. deprecated:: 1.4 The
+ :paramref:`_orm.relationship.cascade_backrefs`
+ flag will default to False in all cases in SQLAlchemy 2.0.
+
+ .. seealso::
+
+ :ref:`backref_cascade` - Full discussion and examples on how
+ the :paramref:`_orm.relationship.cascade_backrefs` option is used.
+
+ :param collection_class:
+ A class or callable that returns a new list-holding object, which
+ will be used in place of a plain list for storing elements.
+
+ .. seealso::
+
+ :ref:`custom_collections` - Introductory documentation and
+ examples.
+
+ :param comparator_factory:
+ A class which extends :class:`.RelationshipProperty.Comparator`
+ which provides custom SQL clause generation for comparison
+ operations.
+
+ .. seealso::
+
+ :class:`.PropComparator` - some detail on redefining comparators
+ at this level.
+
+ :ref:`custom_comparators` - Brief intro to this feature.
+
+
+ :param distinct_target_key=None:
+ Indicate if a "subquery" eager load should apply the DISTINCT
+ keyword to the innermost SELECT statement. When left as ``None``,
+ the DISTINCT keyword will be applied in those cases when the target
+ columns do not comprise the full primary key of the target table.
+ When set to ``True``, the DISTINCT keyword is applied to the
+ innermost SELECT unconditionally.
+
+ It may be desirable to set this flag to False when the DISTINCT is
+ reducing performance of the innermost subquery more than the
+ presence of duplicate innermost rows would.
+
+ .. versionchanged:: 0.9.0 -
+ :paramref:`_orm.relationship.distinct_target_key` now defaults to
+ ``None``, so that the feature enables itself automatically for
+ those cases where the innermost query targets a non-unique
+ key.
+
+ .. seealso::
+
+ :ref:`loading_toplevel` - includes an introduction to subquery
+ eager loading.
+
+ :param doc:
+ Docstring which will be applied to the resulting descriptor.
+
+ :param foreign_keys:
+
+ A list of columns which are to be used as "foreign key"
+ columns, or columns which refer to the value in a remote
+ column, within the context of this :func:`_orm.relationship`
+ object's :paramref:`_orm.relationship.primaryjoin` condition.
+ That is, if the :paramref:`_orm.relationship.primaryjoin`
+ condition of this :func:`_orm.relationship` is ``a.id ==
+ b.a_id``, and the values in ``b.a_id`` are required to be
+ present in ``a.id``, then the "foreign key" column of this
+ :func:`_orm.relationship` is ``b.a_id``.
+
+ In normal cases, the :paramref:`_orm.relationship.foreign_keys`
+ parameter is **not required.** :func:`_orm.relationship` will
+ automatically determine which columns in the
+ :paramref:`_orm.relationship.primaryjoin` condition are to be
+ considered "foreign key" columns based on those
+ :class:`_schema.Column` objects that specify
+ :class:`_schema.ForeignKey`,
+ or are otherwise listed as referencing columns in a
+ :class:`_schema.ForeignKeyConstraint` construct.
+ :paramref:`_orm.relationship.foreign_keys` is only needed when:
+
+ 1. There is more than one way to construct a join from the local
+ table to the remote table, as there are multiple foreign key
+ references present. Setting ``foreign_keys`` will limit the
+ :func:`_orm.relationship`
+ to consider just those columns specified
+ here as "foreign".
+
+ 2. The :class:`_schema.Table` being mapped does not actually have
+ :class:`_schema.ForeignKey` or
+ :class:`_schema.ForeignKeyConstraint`
+ constructs present, often because the table
+ was reflected from a database that does not support foreign key
+ reflection (MySQL MyISAM).
+
+ 3. The :paramref:`_orm.relationship.primaryjoin`
+ argument is used to
+ construct a non-standard join condition, which makes use of
+ columns or expressions that do not normally refer to their
+ "parent" column, such as a join condition expressed by a
+ complex comparison using a SQL function.
+
+ The :func:`_orm.relationship` construct will raise informative
+ error messages that suggest the use of the
+ :paramref:`_orm.relationship.foreign_keys` parameter when
+ presented with an ambiguous condition. In typical cases,
+ if :func:`_orm.relationship` doesn't raise any exceptions, the
+ :paramref:`_orm.relationship.foreign_keys` parameter is usually
+ not needed.
+
+ :paramref:`_orm.relationship.foreign_keys` may also be passed as a
+ callable function which is evaluated at mapper initialization time,
+ and may be passed as a Python-evaluable string when using
+ Declarative.
+
+ .. warning:: When passed as a Python-evaluable string, the
+ argument is interpreted using Python's ``eval()`` function.
+ **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
+ See :ref:`declarative_relationship_eval` for details on
+ declarative evaluation of :func:`_orm.relationship` arguments.
+
+ .. seealso::
+
+ :ref:`relationship_foreign_keys`
+
+ :ref:`relationship_custom_foreign`
+
+ :func:`.foreign` - allows direct annotation of the "foreign"
+ columns within a :paramref:`_orm.relationship.primaryjoin`
+ condition.
+
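+ A brief sketch of case 1 above, assuming a hypothetical ``Invoice``
+ declarative mapping with two separate foreign keys referring to
+ ``address.id``::
+
+ class Invoice(Base):
+ __tablename__ = 'invoice'
+ id = Column(Integer, primary_key=True)
+ billing_address_id = Column(Integer, ForeignKey('address.id'))
+ shipping_address_id = Column(Integer, ForeignKey('address.id'))
+
+ billing_address = relationship(
+ "Address", foreign_keys=[billing_address_id])
+ shipping_address = relationship(
+ "Address", foreign_keys=[shipping_address_id])
+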
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.MapperProperty.info` attribute of this object.
+
+ :param innerjoin=False:
+ When ``True``, joined eager loads will use an inner join to join
+ against related tables instead of an outer join. The purpose
+ of this option is generally one of performance, as inner joins
+ generally perform better than outer joins.
+
+ This flag can be set to ``True`` when the relationship references an
+ object via many-to-one using local foreign keys that are not
+ nullable, or when the reference is one-to-one or a collection that
+ is guaranteed to have at least one entry.
+
+ The option supports the same "nested" and "unnested" options as
+ that of :paramref:`_orm.joinedload.innerjoin`. See that flag
+ for details on nested / unnested behaviors.
+
+ .. seealso::
+
+ :paramref:`_orm.joinedload.innerjoin` - the option as specified by
+ loader option, including detail on nesting behavior.
+
+ :ref:`what_kind_of_loading` - Discussion of some details of
+ various loader options.
+
+
+ :param join_depth:
+ When non-``None``, an integer value indicating how many levels
+ deep "eager" loaders should join on a self-referring or cyclical
+ relationship. The number counts how many times the same Mapper
+ shall be present in the loading condition along a particular join
+ branch. When left at its default of ``None``, eager loaders
+ will stop chaining when they encounter the same target mapper
+ that is already higher up in the chain. This option applies
+ both to joined and subquery eager loaders.
+
+ .. seealso::
+
+ :ref:`self_referential_eager_loading` - Introductory documentation
+ and examples.
+
+ :param lazy='select': specifies
+ how the related items should be loaded. Default value is
+ ``select``. Values include:
+
+ * ``select`` - items should be loaded lazily when the property is
+ first accessed, using a separate SELECT statement, or identity map
+ fetch for simple many-to-one references.
+
+ * ``immediate`` - items should be loaded as the parents are loaded,
+ using a separate SELECT statement, or identity map fetch for
+ simple many-to-one references.
+
+ * ``joined`` - items should be loaded "eagerly" in the same query as
+ that of the parent, using a JOIN or LEFT OUTER JOIN. Whether
+ the join is "outer" or not is determined by the
+ :paramref:`_orm.relationship.innerjoin` parameter.
+
+ * ``subquery`` - items should be loaded "eagerly" as the parents are
+ loaded, using one additional SQL statement, which issues a JOIN to
+ a subquery of the original statement, for each collection
+ requested.
+
+ * ``selectin`` - items should be loaded "eagerly" as the parents
+ are loaded, using one or more additional SQL statements, which
+ issues a JOIN to the immediate parent object, specifying primary
+ key identifiers using an IN clause.
+
+ .. versionadded:: 1.2
+
+ * ``noload`` - no loading should occur at any time. This is to
+ support "write-only" attributes, or attributes which are
+ populated in some manner specific to the application.
+
+ * ``raise`` - lazy loading is disallowed; accessing
+ the attribute, if its value were not already loaded via eager
+ loading, will raise an :exc:`~sqlalchemy.exc.InvalidRequestError`.
+ This strategy can be used when objects are to be detached from
+ their attached :class:`.Session` after they are loaded.
+
+ .. versionadded:: 1.1
+
+ * ``raise_on_sql`` - lazy loading that emits SQL is disallowed;
+ accessing the attribute, if its value were not already loaded via
+ eager loading, will raise an
+ :exc:`~sqlalchemy.exc.InvalidRequestError`, **if the lazy load
+ needs to emit SQL**. If the lazy load can pull the related value
+ from the identity map or determine that it should be None, the
+ value is loaded. This strategy can be used when objects will
+ remain associated with the attached :class:`.Session`, however
+ additional SELECT statements should be blocked.
+
+ .. versionadded:: 1.1
+
+ * ``dynamic`` - the attribute will return a pre-configured
+ :class:`_query.Query` object for all read
+ operations, onto which further filtering operations can be
+ applied before iterating the results. See
+ the section :ref:`dynamic_relationship` for more details.
+
+ * True - a synonym for 'select'
+
+ * False - a synonym for 'joined'
+
+ * None - a synonym for 'noload'
+
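+ For example, a collection that should use the ``selectin`` strategy
+ by default might be declared as (a minimal sketch; names are
+ illustrative only)::
+
+     class Parent(Base):
+         # hypothetical mapping for illustration
+         __tablename__ = "parent"
+         id = Column(Integer, primary_key=True)
+         children = relationship("Child", lazy="selectin")
+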
+ .. seealso::
+
+ :doc:`/orm/loading_relationships` - Full documentation on
+ relationship loader configuration.
+
+ :ref:`dynamic_relationship` - detail on the ``dynamic`` option.
+
+ :ref:`collections_noload_raiseload` - notes on "noload" and "raise"
+
+ :param load_on_pending=False:
+ Indicates loading behavior for transient or pending parent objects.
+
+ When set to ``True``, causes the lazy-loader to
+ issue a query for a parent object that is not persistent, meaning it
+ has never been flushed. This may take effect for a pending object
+ when autoflush is disabled, or for a transient object that has been
+ "attached" to a :class:`.Session` but is not part of its pending
+ collection.
+
+ The :paramref:`_orm.relationship.load_on_pending`
+ flag does not improve
+ behavior when the ORM is used normally - object references should be
+ constructed at the object level, not at the foreign key level, so
+ that they are present in an ordinary way before a flush proceeds.
+ This flag is not intended for general use.
+
+ .. seealso::
+
+ :meth:`.Session.enable_relationship_loading` - this method
+ establishes "load on pending" behavior for the whole object, and
+ also allows loading on objects that remain transient or
+ detached.
+
+ :param order_by:
+ Indicates the ordering that should be applied when loading these
+ items. :paramref:`_orm.relationship.order_by`
+ is expected to refer to
+ one of the :class:`_schema.Column`
+ objects to which the target class is
+ mapped, or the attribute itself bound to the target class which
+ refers to the column.
+
+ :paramref:`_orm.relationship.order_by`
+ may also be passed as a callable
+ function which is evaluated at mapper initialization time, and may
+ be passed as a Python-evaluable string when using Declarative.
+
+ .. warning:: When passed as a Python-evaluable string, the
+ argument is interpreted using Python's ``eval()`` function.
+ **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
+ See :ref:`declarative_relationship_eval` for details on
+ declarative evaluation of :func:`_orm.relationship` arguments.
+
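+ For example, assuming a hypothetical ``Child`` class with a
+ ``position`` column, the collection may be kept ordered as::
+
+     # names are illustrative only
+     children = relationship("Child", order_by="Child.position")
+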
+ :param passive_deletes=False:
+ Indicates loading behavior during delete operations.
+
+ A value of True indicates that unloaded child items should not
+ be loaded during a delete operation on the parent. Normally,
+ when a parent item is deleted, all child items are loaded so
+ that they can either be marked as deleted, or have their
+ foreign key to the parent set to NULL. Marking this flag as
+ True usually implies an ON DELETE <CASCADE|SET NULL> rule is in
+ place which will handle updating/deleting child rows on the
+ database side.
+
+ Additionally, setting the flag to the string value 'all' will
+ disable the "nulling out" of the child foreign keys, when the parent
+ object is deleted and there is no delete or delete-orphan cascade
+ enabled. This is typically used when a trigger or an error-raising
+ scheme is in place on the database side. Note that the foreign
+ key attributes on in-session child objects will not be changed after
+ a flush occurs, so this is a very special use-case setting.
+ Additionally, the "nulling out" will still occur if the child
+ object is de-associated with the parent.
+
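+ For example, assuming ``ON DELETE CASCADE`` is configured on the
+ child table's foreign key, a minimal sketch might look like
+ (names are illustrative only)::
+
+     # relies on the database to delete child rows
+     children = relationship(
+         "Child",
+         cascade="all, delete-orphan",
+         passive_deletes=True,
+     )
+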
+ .. seealso::
+
+ :ref:`passive_deletes` - Introductory documentation
+ and examples.
+
+ :param passive_updates=True:
+ Indicates the persistence behavior to take when a referenced
+ primary key value changes in place, indicating that the referencing
+ foreign key columns will also need their value changed.
+
+ When True, it is assumed that ``ON UPDATE CASCADE`` is configured on
+ the foreign key in the database, and that the database will
+ handle propagation of an UPDATE from a source column to
+ dependent rows. When False, the SQLAlchemy
+ :func:`_orm.relationship`
+ construct will attempt to emit its own UPDATE statements to
+ modify related targets. However note that SQLAlchemy **cannot**
+ emit an UPDATE for more than one level of cascade. Also,
+ setting this flag to False is not compatible in the case where
+ the database is in fact enforcing referential integrity, unless
+ those constraints are explicitly "deferred", if the target backend
+ supports it.
+
+ It is highly advised that an application which is employing
+ mutable primary keys keeps ``passive_updates`` set to True,
+ and instead uses the referential integrity features of the database
+ itself in order to handle the change efficiently and fully.
+
+ .. seealso::
+
+ :ref:`passive_updates` - Introductory documentation and
+ examples.
+
+ :paramref:`.mapper.passive_updates` - a similar flag which
+ takes effect for joined-table inheritance mappings.
+
+ :param post_update:
+ This indicates that the relationship should be handled by a
+ second UPDATE statement after an INSERT or before a
+ DELETE. Currently, it also will issue an UPDATE after the
+ instance has been UPDATEd as well, although this technically should
+ be improved. This flag is used to handle saving bi-directional
+ dependencies between two individual rows (i.e. each row
+ references the other), where it would otherwise be impossible to
+ INSERT or DELETE both rows fully since one row exists before the
+ other. Use this flag when a particular mapping arrangement will
+ incur two rows that are dependent on each other, such as a table
+ that has a one-to-many relationship to a set of child rows, and
+ also has a column that references a single child row within that
+ list (i.e. both tables contain a foreign key to each other). If
+ a flush operation returns an error that a "cyclical
+ dependency" was detected, this is a cue that you might want to
+ use :paramref:`_orm.relationship.post_update` to "break" the cycle.
+
+ .. seealso::
+
+ :ref:`post_update` - Introductory documentation and examples.
+
+ :param primaryjoin:
+ A SQL expression that will be used as the primary
+ join of the child object against the parent object, or in a
+ many-to-many relationship the join of the parent object to the
+ association table. By default, this value is computed based on the
+ foreign key relationships of the parent and child tables (or
+ association table).
+
+ :paramref:`_orm.relationship.primaryjoin` may also be passed as a
+ callable function which is evaluated at mapper initialization time,
+ and may be passed as a Python-evaluable string when using
+ Declarative.
+
+ .. warning:: When passed as a Python-evaluable string, the
+ argument is interpreted using Python's ``eval()`` function.
+ **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
+ See :ref:`declarative_relationship_eval` for details on
+ declarative evaluation of :func:`_orm.relationship` arguments.
+
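+ For example, a collection limited by a criterion beyond plain
+ foreign key equality might be sketched as (names are illustrative
+ only)::
+
+     # string form is evaluated at mapper configuration time
+     boston_addresses = relationship(
+         "Address",
+         primaryjoin="and_(User.id == Address.user_id, "
+         "Address.city == 'Boston')",
+     )
+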
+ .. seealso::
+
+ :ref:`relationship_primaryjoin`
+
+ :param remote_side:
+ Used for self-referential relationships, indicates the column or
+ list of columns that form the "remote side" of the relationship.
+
+ :paramref:`_orm.relationship.remote_side` may also be passed as a
+ callable function which is evaluated at mapper initialization time,
+ and may be passed as a Python-evaluable string when using
+ Declarative.
+
+ .. warning:: When passed as a Python-evaluable string, the
+ argument is interpreted using Python's ``eval()`` function.
+ **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
+ See :ref:`declarative_relationship_eval` for details on
+ declarative evaluation of :func:`_orm.relationship` arguments.
+
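+ For example, a hypothetical adjacency list mapping (names are
+ illustrative only) places the "remote side" on the parent's
+ primary key::
+
+     class Node(Base):
+         # hypothetical self-referential mapping for illustration
+         __tablename__ = "node"
+         id = Column(Integer, primary_key=True)
+         parent_id = Column(Integer, ForeignKey("node.id"))
+         parent = relationship("Node", remote_side=[id])
+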
+ .. seealso::
+
+ :ref:`self_referential` - in-depth explanation of how
+ :paramref:`_orm.relationship.remote_side`
+ is used to configure self-referential relationships.
+
+ :func:`.remote` - an annotation function that accomplishes the
+ same purpose as :paramref:`_orm.relationship.remote_side`,
+ typically
+ when a custom :paramref:`_orm.relationship.primaryjoin` condition
+ is used.
+
+ :param query_class:
+ A :class:`_query.Query`
+ subclass that will be used internally by the
+ ``AppenderQuery`` returned by a "dynamic" relationship, that
+ is, a relationship that specifies ``lazy="dynamic"`` or was
+ otherwise constructed using the :func:`_orm.dynamic_loader`
+ function.
+
+ .. seealso::
+
+ :ref:`dynamic_relationship` - Introduction to "dynamic"
+ relationship loaders.
+
+ :param secondaryjoin:
+ A SQL expression that will be used as the join of
+ an association table to the child object. By default, this value is
+ computed based on the foreign key relationships of the association
+ and child tables.
+
+ :paramref:`_orm.relationship.secondaryjoin` may also be passed as a
+ callable function which is evaluated at mapper initialization time,
+ and may be passed as a Python-evaluable string when using
+ Declarative.
+
+ .. warning:: When passed as a Python-evaluable string, the
+ argument is interpreted using Python's ``eval()`` function.
+ **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
+ See :ref:`declarative_relationship_eval` for details on
+ declarative evaluation of :func:`_orm.relationship` arguments.
+
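+ For example, assuming a hypothetical ``node_to_node`` association
+ table with ``left_id`` and ``right_id`` columns, a self-referential
+ many-to-many might spell out both join conditions (names are
+ illustrative only)::
+
+     right_nodes = relationship(
+         "Node",
+         secondary="node_to_node",
+         primaryjoin="Node.id == node_to_node.c.left_id",
+         secondaryjoin="Node.id == node_to_node.c.right_id",
+     )
+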
+ .. seealso::
+
+ :ref:`relationship_primaryjoin`
+
+ :param single_parent:
+ When True, installs a validator which will prevent objects
+ from being associated with more than one parent at a time.
+ This is used for many-to-one or many-to-many relationships that
+ should be treated either as one-to-one or one-to-many. Its usage
+ is optional, except for :func:`_orm.relationship` constructs which
+ are many-to-one or many-to-many and also
+ specify the ``delete-orphan`` cascade option. The
+ :func:`_orm.relationship` construct itself will raise an error
+ indicating when this option is required.
+
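+ For example, a many-to-one that also specifies ``delete-orphan``
+ must state the flag explicitly (a minimal sketch; names are
+ illustrative only)::
+
+     preference = relationship(
+         "Preference",
+         cascade="all, delete-orphan",
+         single_parent=True,
+     )
+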
+ .. seealso::
+
+ :ref:`unitofwork_cascades` - includes detail on when the
+ :paramref:`_orm.relationship.single_parent`
+ flag may be appropriate.
+
+ :param uselist:
+ A boolean that indicates if this property should be loaded as a
+ list or a scalar. In most cases, this value is determined
+ automatically by :func:`_orm.relationship` at mapper configuration
+ time, based on the type and direction
+ of the relationship - one to many forms a list, many to one
+ forms a scalar, many to many is a list. If a scalar is desired
+ where normally a list would be present, such as a bi-directional
+ one-to-one relationship, set :paramref:`_orm.relationship.uselist`
+ to
+ False.
+
+ The :paramref:`_orm.relationship.uselist`
+ flag is also available on an
+ existing :func:`_orm.relationship`
+ construct as a read-only attribute,
+ which can be used to determine if this :func:`_orm.relationship`
+ deals
+ with collections or scalar attributes::
+
+ >>> User.addresses.property.uselist
+ True
+
+ .. seealso::
+
+ :ref:`relationships_one_to_one` - Introduction to the "one to
+ one" relationship pattern, which is typically when the
+ :paramref:`_orm.relationship.uselist` flag is needed.
+
+ :param viewonly=False:
+ When set to ``True``, the relationship is used only for loading
+ objects, and not for any persistence operation. A
+ :func:`_orm.relationship` which specifies
+ :paramref:`_orm.relationship.viewonly` can work
+ with a wider range of SQL operations within the
+ :paramref:`_orm.relationship.primaryjoin` condition, including
+ operations that feature the use of a variety of comparison operators
+ as well as SQL functions such as :func:`_expression.cast`. The
+ :paramref:`_orm.relationship.viewonly`
+ flag is also of general use when defining any kind of
+ :func:`_orm.relationship` that doesn't represent
+ the full set of related objects, to prevent modifications of the
+ collection from resulting in persistence operations.
+
+ When using the :paramref:`_orm.relationship.viewonly` flag in
+ conjunction with backrefs, the originating relationship for a
+ particular state change will not produce state changes within the
+ viewonly relationship. This is the behavior implied by
+ :paramref:`_orm.relationship.sync_backref` being set to False.
+
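+ For example, a read-only variant of an existing collection might be
+ sketched as (names are illustrative only)::
+
+     # loads Child objects but never flushes changes through them
+     children_view = relationship(
+         "Child", order_by="Child.id", viewonly=True
+     )
+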
+ .. versionchanged:: 1.3.17 - the
+ :paramref:`_orm.relationship.sync_backref` flag is set to False
+ when using viewonly in conjunction with backrefs.
+
+ .. seealso::
+
+ :paramref:`_orm.relationship.sync_backref`
+
+ :param sync_backref:
+ A boolean that enables the events used to synchronize the in-Python
+ attributes when this relationship is target of either
+ :paramref:`_orm.relationship.backref` or
+ :paramref:`_orm.relationship.back_populates`.
+
+ Defaults to ``None``, which indicates that an automatic value should
+ be selected based on the value of the
+ :paramref:`_orm.relationship.viewonly` flag. When left at its
+ default, changes in state will be back-populated only if neither
+ side of the relationship is viewonly.
+
+ .. versionadded:: 1.3.17
+
+ .. versionchanged:: 1.4 - A relationship that specifies
+ :paramref:`_orm.relationship.viewonly` automatically implies
+ that :paramref:`_orm.relationship.sync_backref` is ``False``.
+
+ .. seealso::
+
+ :paramref:`_orm.relationship.viewonly`
+
+ :param omit_join:
+ Allows manual control over the "selectin" automatic join
+ optimization. Set to ``False`` to disable the "omit join" feature
+ added in SQLAlchemy 1.3; or leave as ``None`` to leave automatic
+ optimization in place.
+
+ .. note:: This flag may only be set to ``False``. It is not
+ necessary to set it to ``True`` as the "omit_join" optimization is
+ automatically detected; if it is not detected, then the
+ optimization is not supported.
+
+ .. versionchanged:: 1.3.11 setting ``omit_join`` to True will now
+ emit a warning as this was not the intended use of this flag.
+
+ .. versionadded:: 1.3
+
+
+ """
+ super(RelationshipProperty, self).__init__()
+
+ self.uselist = uselist
+ self.argument = argument
+ self.secondary = secondary
+ self.primaryjoin = primaryjoin
+ self.secondaryjoin = secondaryjoin
+ self.post_update = post_update
+ self.direction = None
+ self.viewonly = viewonly
+ if viewonly:
+ self._warn_for_persistence_only_flags(
+ passive_deletes=passive_deletes,
+ passive_updates=passive_updates,
+ enable_typechecks=enable_typechecks,
+ active_history=active_history,
+ cascade_backrefs=cascade_backrefs,
+ )
+ if viewonly and sync_backref:
+ raise sa_exc.ArgumentError(
+ "sync_backref and viewonly cannot both be True"
+ )
+ self.sync_backref = sync_backref
+ self.lazy = lazy
+ self.single_parent = single_parent
+ self._user_defined_foreign_keys = foreign_keys
+ self.collection_class = collection_class
+ self.passive_deletes = passive_deletes
+ self.cascade_backrefs = cascade_backrefs
+ self.passive_updates = passive_updates
+ self.remote_side = remote_side
+ self.enable_typechecks = enable_typechecks
+ self.query_class = query_class
+ self.innerjoin = innerjoin
+ self.distinct_target_key = distinct_target_key
+ self.doc = doc
+ self.active_history = active_history
+ self._legacy_inactive_history_style = _legacy_inactive_history_style
+
+ self.join_depth = join_depth
+ if omit_join:
+ util.warn(
+ "setting omit_join to True is not supported; selectin "
+ "loading of this relationship may not work correctly if this "
+ "flag is set explicitly. omit_join optimization is "
+ "automatically detected for conditions under which it is "
+ "supported."
+ )
+
+ self.omit_join = omit_join
+ self.local_remote_pairs = _local_remote_pairs
+ self.bake_queries = bake_queries
+ self.load_on_pending = load_on_pending
+ self.comparator_factory = (
+ comparator_factory or RelationshipProperty.Comparator
+ )
+ self.comparator = self.comparator_factory(self, None)
+ util.set_creation_order(self)
+
+ if info is not None:
+ self.info = info
+
+ self.strategy_key = (("lazy", self.lazy),)
+
+ self._reverse_property = set()
+ if overlaps:
+ self._overlaps = set(re.split(r"\s*,\s*", overlaps))
+ else:
+ self._overlaps = ()
+
+ if cascade is not False:
+ self.cascade = cascade
+ elif self.viewonly:
+ self.cascade = "none"
+ else:
+ self.cascade = "save-update, merge"
+
+ self.order_by = order_by
+
+ self.back_populates = back_populates
+
+ if self.back_populates:
+ if backref:
+ raise sa_exc.ArgumentError(
+ "backref and back_populates keyword arguments "
+ "are mutually exclusive"
+ )
+ self.backref = None
+ else:
+ self.backref = backref
+
+ def _warn_for_persistence_only_flags(self, **kw):
+ for k, v in kw.items():
+ if v != self._persistence_only[k]:
+ # we are warning here rather than warn deprecated as this is a
+ # configuration mistake, and Python shows regular warnings more
+ # aggressively than deprecation warnings by default. Unlike the
+ # case of setting viewonly with cascade, the settings being
+ # warned about here are not actively doing the wrong thing
+ # against viewonly=True, so it is not as urgent to have these
+ # raise an error.
+ util.warn(
+ "Setting %s on relationship() while also "
+ "setting viewonly=True does not make sense, as a "
+ "viewonly=True relationship does not perform persistence "
+ "operations. This configuration may raise an error "
+ "in a future release." % (k,)
+ )
+
+ def instrument_class(self, mapper):
+ attributes.register_descriptor(
+ mapper.class_,
+ self.key,
+ comparator=self.comparator_factory(self, mapper),
+ parententity=mapper,
+ doc=self.doc,
+ )
+
+ class Comparator(PropComparator):
+ """Produce boolean, comparison, and other operators for
+ :class:`.RelationshipProperty` attributes.
+
+ See the documentation for :class:`.PropComparator` for a brief
+ overview of ORM level operator definition.
+
+ .. seealso::
+
+ :class:`.PropComparator`
+
+ :class:`.ColumnProperty.Comparator`
+
+ :class:`.ColumnOperators`
+
+ :ref:`types_operators`
+
+ :attr:`.TypeEngine.comparator_factory`
+
+ """
+
+ _of_type = None
+ _extra_criteria = ()
+
+ def __init__(
+ self,
+ prop,
+ parentmapper,
+ adapt_to_entity=None,
+ of_type=None,
+ extra_criteria=(),
+ ):
+ """Construction of :class:`.RelationshipProperty.Comparator`
+ is internal to the ORM's attribute mechanics.
+
+ """
+ self.prop = prop
+ self._parententity = parentmapper
+ self._adapt_to_entity = adapt_to_entity
+ if of_type:
+ self._of_type = of_type
+ self._extra_criteria = extra_criteria
+
+ def adapt_to_entity(self, adapt_to_entity):
+ return self.__class__(
+ self.property,
+ self._parententity,
+ adapt_to_entity=adapt_to_entity,
+ of_type=self._of_type,
+ )
+
+ @util.memoized_property
+ def entity(self):
+ """The target entity referred to by this
+ :class:`.RelationshipProperty.Comparator`.
+
+ This is either a :class:`_orm.Mapper` or :class:`.AliasedInsp`
+ object.
+
+ This is the "target" or "remote" side of the
+ :func:`_orm.relationship`.
+
+ """
+ # this is a relatively recent change made for
+ # 1.4.27 as part of #7244.
+ # TODO: shouldn't _of_type be inspected up front when received?
+ if self._of_type is not None:
+ return inspect(self._of_type)
+ else:
+ return self.property.entity
+
+ @util.memoized_property
+ def mapper(self):
+ """The target :class:`_orm.Mapper` referred to by this
+ :class:`.RelationshipProperty.Comparator`.
+
+ This is the "target" or "remote" side of the
+ :func:`_orm.relationship`.
+
+ """
+ return self.property.mapper
+
+ @util.memoized_property
+ def _parententity(self):
+ return self.property.parent
+
+ def _source_selectable(self):
+ if self._adapt_to_entity:
+ return self._adapt_to_entity.selectable
+ else:
+ return self.property.parent._with_polymorphic_selectable
+
+ def __clause_element__(self):
+ adapt_from = self._source_selectable()
+ if self._of_type:
+ of_type_entity = inspect(self._of_type)
+ else:
+ of_type_entity = None
+
+ (
+ pj,
+ sj,
+ source,
+ dest,
+ secondary,
+ target_adapter,
+ ) = self.property._create_joins(
+ source_selectable=adapt_from,
+ source_polymorphic=True,
+ of_type_entity=of_type_entity,
+ alias_secondary=True,
+ extra_criteria=self._extra_criteria,
+ )
+ if sj is not None:
+ return pj & sj
+ else:
+ return pj
+
+ def of_type(self, cls):
+ r"""Redefine this object in terms of a polymorphic subclass.
+
+ See :meth:`.PropComparator.of_type` for an example.
+
+
+ """
+ return RelationshipProperty.Comparator(
+ self.property,
+ self._parententity,
+ adapt_to_entity=self._adapt_to_entity,
+ of_type=cls,
+ extra_criteria=self._extra_criteria,
+ )
+
+ def and_(self, *other):
+ """Add AND criteria.
+
+ See :meth:`.PropComparator.and_` for an example.
+
+ .. versionadded:: 1.4
+
+ """
+ return RelationshipProperty.Comparator(
+ self.property,
+ self._parententity,
+ adapt_to_entity=self._adapt_to_entity,
+ of_type=self._of_type,
+ extra_criteria=self._extra_criteria + other,
+ )
+
+ def in_(self, other):
+ """Produce an IN clause - this is not implemented
+ for :func:`_orm.relationship`-based attributes at this time.
+
+ """
+ raise NotImplementedError(
+ "in_() not yet supported for "
+ "relationships. For a simple "
+ "many-to-one, use in_() against "
+ "the set of foreign key values."
+ )
+
+ __hash__ = None
+
+ def __eq__(self, other):
+ """Implement the ``==`` operator.
+
+ In a many-to-one context, such as::
+
+ MyClass.some_prop == <some object>
+
+ this will typically produce a
+ clause such as::
+
+ mytable.related_id == <some id>
+
+ Where ``<some id>`` is the primary key of the given
+ object.
+
+ The ``==`` operator provides partial functionality for non-
+ many-to-one comparisons:
+
+ * Comparisons against collections are not supported.
+ Use :meth:`~.RelationshipProperty.Comparator.contains`.
+ * Compared to a scalar one-to-many, will produce a
+ clause that compares the target columns in the parent to
+ the given target.
+ * Compared to a scalar many-to-many, an alias
+ of the association table will be rendered as
+ well, forming a natural join that is part of the
+ main body of the query. This will not work for
+ queries that go beyond simple AND conjunctions of
+ comparisons, such as those which use OR. Use
+ explicit joins, outerjoins, or
+ :meth:`~.RelationshipProperty.Comparator.has` for
+ more comprehensive non-many-to-one scalar
+ membership tests.
+ * Comparisons against ``None`` given in a one-to-many
+ or many-to-many context produce a NOT EXISTS clause.
+
+ """
+ if isinstance(other, (util.NoneType, expression.Null)):
+ if self.property.direction in [ONETOMANY, MANYTOMANY]:
+ return ~self._criterion_exists()
+ else:
+ return _orm_annotate(
+ self.property._optimized_compare(
+ None, adapt_source=self.adapter
+ )
+ )
+ elif self.property.uselist:
+ raise sa_exc.InvalidRequestError(
+ "Can't compare a collection to an object or collection; "
+ "use contains() to test for membership."
+ )
+ else:
+ return _orm_annotate(
+ self.property._optimized_compare(
+ other, adapt_source=self.adapter
+ )
+ )
+
+ def _criterion_exists(self, criterion=None, **kwargs):
+ if getattr(self, "_of_type", None):
+ info = inspect(self._of_type)
+ target_mapper, to_selectable, is_aliased_class = (
+ info.mapper,
+ info.selectable,
+ info.is_aliased_class,
+ )
+ if self.property._is_self_referential and not is_aliased_class:
+ to_selectable = to_selectable._anonymous_fromclause()
+
+ single_crit = target_mapper._single_table_criterion
+ if single_crit is not None:
+ if criterion is not None:
+ criterion = single_crit & criterion
+ else:
+ criterion = single_crit
+ else:
+ is_aliased_class = False
+ to_selectable = None
+
+ if self.adapter:
+ source_selectable = self._source_selectable()
+ else:
+ source_selectable = None
+
+ (
+ pj,
+ sj,
+ source,
+ dest,
+ secondary,
+ target_adapter,
+ ) = self.property._create_joins(
+ dest_selectable=to_selectable,
+ source_selectable=source_selectable,
+ )
+
+ for k in kwargs:
+ crit = getattr(self.property.mapper.class_, k) == kwargs[k]
+ if criterion is None:
+ criterion = crit
+ else:
+ criterion = criterion & crit
+
+ # annotate the *local* side of the join condition; in the case
+ # of pj + sj this is the full primaryjoin, in the case of just
+ # pj it's the local side of the primaryjoin.
+ if sj is not None:
+ j = _orm_annotate(pj) & sj
+ else:
+ j = _orm_annotate(pj, exclude=self.property.remote_side)
+
+ if (
+ criterion is not None
+ and target_adapter
+ and not is_aliased_class
+ ):
+ # limit this adapter to annotated only?
+ criterion = target_adapter.traverse(criterion)
+
+ # only have the "joined left side" of what we
+ # return be subject to Query adaptation. The right
+ # side of it is used for an exists() subquery and
+ # should not correlate or otherwise reach out
+ # to anything in the enclosing query.
+ if criterion is not None:
+ criterion = criterion._annotate(
+ {"no_replacement_traverse": True}
+ )
+
+ crit = j & sql.True_._ifnone(criterion)
+
+ if secondary is not None:
+ ex = (
+ sql.exists(1)
+ .where(crit)
+ .select_from(dest, secondary)
+ .correlate_except(dest, secondary)
+ )
+ else:
+ ex = (
+ sql.exists(1)
+ .where(crit)
+ .select_from(dest)
+ .correlate_except(dest)
+ )
+ return ex
+
+ def any(self, criterion=None, **kwargs):
+ """Produce an expression that tests a collection against
+ particular criterion, using EXISTS.
+
+ An expression like::
+
+ session.query(MyClass).filter(
+ MyClass.somereference.any(SomeRelated.x==2)
+ )
+
+
+ Will produce a query like::
+
+ SELECT * FROM my_table WHERE
+ EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id
+ AND related.x=2)
+
+ Because :meth:`~.RelationshipProperty.Comparator.any` uses
+ a correlated subquery, its performance is not nearly as
+ good when compared against large target tables as that of
+ using a join.
+
+ :meth:`~.RelationshipProperty.Comparator.any` is particularly
+ useful for testing for empty collections::
+
+ session.query(MyClass).filter(
+ ~MyClass.somereference.any()
+ )
+
+ will produce::
+
+ SELECT * FROM my_table WHERE
+ NOT (EXISTS (SELECT 1 FROM related WHERE
+ related.my_id=my_table.id))
+
+ :meth:`~.RelationshipProperty.Comparator.any` is only
+ valid for collections, i.e. a :func:`_orm.relationship`
+ that has ``uselist=True``. For scalar references,
+ use :meth:`~.RelationshipProperty.Comparator.has`.
+
+ """
+ if not self.property.uselist:
+ raise sa_exc.InvalidRequestError(
+ "'any()' not implemented for scalar "
+ "attributes. Use has()."
+ )
+
+ return self._criterion_exists(criterion, **kwargs)
+
+ def has(self, criterion=None, **kwargs):
+ """Produce an expression that tests a scalar reference against
+ particular criterion, using EXISTS.
+
+ An expression like::
+
+ session.query(MyClass).filter(
+ MyClass.somereference.has(SomeRelated.x==2)
+ )
+
+
+ Will produce a query like::
+
+ SELECT * FROM my_table WHERE
+ EXISTS (SELECT 1 FROM related WHERE
+ related.id==my_table.related_id AND related.x=2)
+
+ Because :meth:`~.RelationshipProperty.Comparator.has` uses
+ a correlated subquery, its performance is not nearly as
+ good when compared against large target tables as that of
+ using a join.
+
+ :meth:`~.RelationshipProperty.Comparator.has` is only
+ valid for scalar references, i.e. a :func:`_orm.relationship`
+ that has ``uselist=False``. For collection references,
+ use :meth:`~.RelationshipProperty.Comparator.any`.
+
+ """
+ if self.property.uselist:
+ raise sa_exc.InvalidRequestError(
+ "'has()' not implemented for collections. " "Use any()."
+ )
+ return self._criterion_exists(criterion, **kwargs)
+
+ def contains(self, other, **kwargs):
+ """Return a simple expression that tests a collection for
+ containment of a particular item.
+
+ :meth:`~.RelationshipProperty.Comparator.contains` is
+ only valid for a collection, i.e. a
+ :func:`_orm.relationship` that implements
+ one-to-many or many-to-many with ``uselist=True``.
+
+ When used in a simple one-to-many context, an
+ expression like::
+
+ MyClass.contains(other)
+
+ Produces a clause like::
+
+ mytable.id == <some id>
+
+ Where ``<some id>`` is the value of the foreign key
+ attribute on ``other`` which refers to the primary
+ key of its parent object. From this it follows that
+ :meth:`~.RelationshipProperty.Comparator.contains` is
+ very useful when used with simple one-to-many
+ operations.
+
+ For many-to-many operations, the behavior of
+ :meth:`~.RelationshipProperty.Comparator.contains`
+ has more caveats. The association table will be
+ rendered in the statement, producing an "implicit"
+ join, that is, includes multiple tables in the FROM
+ clause which are equated in the WHERE clause::
+
+ query(MyClass).filter(MyClass.contains(other))
+
+ Produces a query like::
+
+ SELECT * FROM my_table, my_association_table AS
+ my_association_table_1 WHERE
+ my_table.id = my_association_table_1.parent_id
+ AND my_association_table_1.child_id = <some id>
+
+ Where ``<some id>`` would be the primary key of
+ ``other``. From the above, it is clear that
+ :meth:`~.RelationshipProperty.Comparator.contains`
+ will **not** work with many-to-many collections when
+ used in queries that move beyond simple AND
+ conjunctions, such as multiple
+ :meth:`~.RelationshipProperty.Comparator.contains`
+ expressions joined by OR. In such cases subqueries or
+ explicit "outer joins" will need to be used instead.
+ See :meth:`~.RelationshipProperty.Comparator.any` for
+ a less-performant alternative using EXISTS, or refer
+ to :meth:`_query.Query.outerjoin`
+ as well as :ref:`orm_queryguide_joins`
+ for more details on constructing outer joins.
+
+ kwargs may be ignored by this operator but are required for API
+ conformance.
+ """
+ if not self.property.uselist:
+ raise sa_exc.InvalidRequestError(
+ "'contains' not implemented for scalar "
+ "attributes. Use =="
+ )
+ clause = self.property._optimized_compare(
+ other, adapt_source=self.adapter
+ )
+
+ if self.property.secondaryjoin is not None:
+ clause.negation_clause = self.__negated_contains_or_equals(
+ other
+ )
+
+ return clause
+
+ def __negated_contains_or_equals(self, other):
+ if self.property.direction == MANYTOONE:
+ state = attributes.instance_state(other)
+
+ def state_bindparam(local_col, state, remote_col):
+ dict_ = state.dict
+ return sql.bindparam(
+ local_col.key,
+ type_=local_col.type,
+ unique=True,
+ callable_=self.property._get_attr_w_warn_on_none(
+ self.property.mapper, state, dict_, remote_col
+ ),
+ )
+
+ def adapt(col):
+ if self.adapter:
+ return self.adapter(col)
+ else:
+ return col
+
+ if self.property._use_get:
+ return sql.and_(
+ *[
+ sql.or_(
+ adapt(x)
+ != state_bindparam(adapt(x), state, y),
+ adapt(x) == None,
+ )
+ for (x, y) in self.property.local_remote_pairs
+ ]
+ )
+
+ criterion = sql.and_(
+ *[
+ x == y
+ for (x, y) in zip(
+ self.property.mapper.primary_key,
+ self.property.mapper.primary_key_from_instance(other),
+ )
+ ]
+ )
+
+ return ~self._criterion_exists(criterion)
+
+ def __ne__(self, other):
+ """Implement the ``!=`` operator.
+
+ In a many-to-one context, such as::
+
+ MyClass.some_prop != <some object>
+
+ This will typically produce a clause such as::
+
+ mytable.related_id != <some id>
+
+ Where ``<some id>`` is the primary key of the
+ given object.
+
+ The ``!=`` operator provides partial functionality for non-
+ many-to-one comparisons:
+
+ * Comparisons against collections are not supported.
+ Use
+ :meth:`~.RelationshipProperty.Comparator.contains`
+ in conjunction with :func:`_expression.not_`.
+ * Compared to a scalar one-to-many, will produce a
+ clause that compares the target columns in the parent to
+ the given target.
+ * Compared to a scalar many-to-many, an alias
+ of the association table will be rendered as
+ well, forming a natural join that is part of the
+ main body of the query. This will not work for
+ queries that go beyond simple AND conjunctions of
+ comparisons, such as those which use OR. Use
+ explicit joins, outerjoins, or
+ :meth:`~.RelationshipProperty.Comparator.has` in
+ conjunction with :func:`_expression.not_` for
+ more comprehensive non-many-to-one scalar
+ membership tests.
+ * Comparisons against ``None`` given in a one-to-many
+ or many-to-many context produce an EXISTS clause.
+
+ """
+ if isinstance(other, (util.NoneType, expression.Null)):
+ if self.property.direction == MANYTOONE:
+ return _orm_annotate(
+ ~self.property._optimized_compare(
+ None, adapt_source=self.adapter
+ )
+ )
+
+ else:
+ return self._criterion_exists()
+ elif self.property.uselist:
+ raise sa_exc.InvalidRequestError(
+ "Can't compare a collection"
+ " to an object or collection; use "
+ "contains() to test for membership."
+ )
+ else:
+ return _orm_annotate(self.__negated_contains_or_equals(other))
+
+ @util.memoized_property
+ def property(self):
+ self.prop.parent._check_configure()
+ return self.prop
+
+ def _with_parent(self, instance, alias_secondary=True, from_entity=None):
+ assert instance is not None
+ adapt_source = None
+ if from_entity is not None:
+ insp = inspect(from_entity)
+ if insp.is_aliased_class:
+ adapt_source = insp._adapter.adapt_clause
+ return self._optimized_compare(
+ instance,
+ value_is_parent=True,
+ adapt_source=adapt_source,
+ alias_secondary=alias_secondary,
+ )
+
+ def _optimized_compare(
+ self,
+ state,
+ value_is_parent=False,
+ adapt_source=None,
+ alias_secondary=True,
+ ):
+ if state is not None:
+ try:
+ state = inspect(state)
+ except sa_exc.NoInspectionAvailable:
+ state = None
+
+ if state is None or not getattr(state, "is_instance", False):
+ raise sa_exc.ArgumentError(
+ "Mapped instance expected for relationship "
+ "comparison to object. Classes, queries and other "
+ "SQL elements are not accepted in this context; for "
+ "comparison with a subquery, "
+ "use %s.has(**criteria)." % self
+ )
+ reverse_direction = not value_is_parent
+
+ if state is None:
+ return self._lazy_none_clause(
+ reverse_direction, adapt_source=adapt_source
+ )
+
+ if not reverse_direction:
+ criterion, bind_to_col = (
+ self._lazy_strategy._lazywhere,
+ self._lazy_strategy._bind_to_col,
+ )
+ else:
+ criterion, bind_to_col = (
+ self._lazy_strategy._rev_lazywhere,
+ self._lazy_strategy._rev_bind_to_col,
+ )
+
+ if reverse_direction:
+ mapper = self.mapper
+ else:
+ mapper = self.parent
+
+ dict_ = attributes.instance_dict(state.obj())
+
+ def visit_bindparam(bindparam):
+ if bindparam._identifying_key in bind_to_col:
+ bindparam.callable = self._get_attr_w_warn_on_none(
+ mapper,
+ state,
+ dict_,
+ bind_to_col[bindparam._identifying_key],
+ )
+
+ if self.secondary is not None and alias_secondary:
+ criterion = ClauseAdapter(
+ self.secondary._anonymous_fromclause()
+ ).traverse(criterion)
+
+ criterion = visitors.cloned_traverse(
+ criterion, {}, {"bindparam": visit_bindparam}
+ )
+
+ if adapt_source:
+ criterion = adapt_source(criterion)
+ return criterion
+
+ def _get_attr_w_warn_on_none(self, mapper, state, dict_, column):
+ """Create the callable that is used in a many-to-one expression.
+
+ E.g.::
+
+ u1 = s.query(User).get(5)
+
+ expr = Address.user == u1
+
+ Above, the SQL should be "address.user_id = 5". The callable
+ returned by this method produces the value "5" based on the identity
+ of ``u1``.
+
+ """
+
+ # in this callable, we're trying to thread the needle through
+ # a wide variety of scenarios, including:
+ #
+ # * the object hasn't been flushed yet and there's no value for
+ # the attribute as of yet
+ #
+ # * the object hasn't been flushed yet but it has a user-defined
+ # value
+ #
+ # * the object has a value but it's expired and not locally present
+ #
+ # * the object has a value but it's expired and not locally present,
+ # and the object is also detached
+ #
+ # * The object hadn't been flushed yet, there was no value, but
+ # later, the object has been expired and detached, and *now*
+ # they're trying to evaluate it
+ #
+ # * the object had a value, but it was changed to a new value, and
+ # then expired
+ #
+ # * the object had a value, but it was changed to a new value, and
+ # then expired, then the object was detached
+ #
+ # * the object has a user-set value, but it's None and we don't do
+ # the comparison correctly for that so warn
+ #
+
+ prop = mapper.get_property_by_column(column)
+
+ # by invoking this method, InstanceState will track the last known
+ # value for this key each time the attribute is to be expired.
+ # this feature was added explicitly for use in this method.
+ state._track_last_known_value(prop.key)
+
+ def _go():
+ last_known = to_return = state._last_known_values[prop.key]
+ existing_is_available = last_known is not attributes.NO_VALUE
+
+ # we support that the value may have changed. so here we
+ # try to get the most recent value including re-fetching.
+ # only if we can't get a value now due to detachment do we return
+ # the last known value
+ current_value = mapper._get_state_attr_by_column(
+ state,
+ dict_,
+ column,
+ passive=attributes.PASSIVE_OFF
+ if state.persistent
+ else attributes.PASSIVE_NO_FETCH ^ attributes.INIT_OK,
+ )
+
+ if current_value is attributes.NEVER_SET:
+ if not existing_is_available:
+ raise sa_exc.InvalidRequestError(
+ "Can't resolve value for column %s on object "
+ "%s; no value has been set for this column"
+ % (column, state_str(state))
+ )
+ elif current_value is attributes.PASSIVE_NO_RESULT:
+ if not existing_is_available:
+ raise sa_exc.InvalidRequestError(
+ "Can't resolve value for column %s on object "
+ "%s; the object is detached and the value was "
+ "expired" % (column, state_str(state))
+ )
+ else:
+ to_return = current_value
+ if to_return is None:
+ util.warn(
+ "Got None for value of column %s; this is unsupported "
+ "for a relationship comparison and will not "
+ "currently produce an IS comparison "
+ "(but may in a future release)" % column
+ )
+ return to_return
+
+ return _go
+
+ def _lazy_none_clause(self, reverse_direction=False, adapt_source=None):
+ if not reverse_direction:
+ criterion, bind_to_col = (
+ self._lazy_strategy._lazywhere,
+ self._lazy_strategy._bind_to_col,
+ )
+ else:
+ criterion, bind_to_col = (
+ self._lazy_strategy._rev_lazywhere,
+ self._lazy_strategy._rev_bind_to_col,
+ )
+
+ criterion = adapt_criterion_to_null(criterion, bind_to_col)
+
+ if adapt_source:
+ criterion = adapt_source(criterion)
+ return criterion
+
+ def __str__(self):
+ return str(self.parent.class_.__name__) + "." + self.key
+
+ def merge(
+ self,
+ session,
+ source_state,
+ source_dict,
+ dest_state,
+ dest_dict,
+ load,
+ _recursive,
+ _resolve_conflict_map,
+ ):
+
+ if load:
+ for r in self._reverse_property:
+ if (source_state, r) in _recursive:
+ return
+
+ if "merge" not in self._cascade:
+ return
+
+ if self.key not in source_dict:
+ return
+
+ if self.uselist:
+ impl = source_state.get_impl(self.key)
+ instances_iterable = impl.get_collection(source_state, source_dict)
+
+ # if this is a CollectionAttributeImpl, then empty should
+ # be False, otherwise "self.key in source_dict" should not be
+ # True
+ assert not instances_iterable.empty if impl.collection else True
+
+ if load:
+ # for a full merge, pre-load the destination collection,
+ # so that individual _merge of each item pulls from identity
+ # map for those already present.
+ # also assumes CollectionAttributeImpl behavior of loading
+ # "old" list in any case
+ dest_state.get_impl(self.key).get(dest_state, dest_dict)
+
+ dest_list = []
+ for current in instances_iterable:
+ current_state = attributes.instance_state(current)
+ current_dict = attributes.instance_dict(current)
+ _recursive[(current_state, self)] = True
+ obj = session._merge(
+ current_state,
+ current_dict,
+ load=load,
+ _recursive=_recursive,
+ _resolve_conflict_map=_resolve_conflict_map,
+ )
+ if obj is not None:
+ dest_list.append(obj)
+
+ if not load:
+ coll = attributes.init_state_collection(
+ dest_state, dest_dict, self.key
+ )
+ for c in dest_list:
+ coll.append_without_event(c)
+ else:
+ dest_state.get_impl(self.key).set(
+ dest_state, dest_dict, dest_list, _adapt=False
+ )
+ else:
+ current = source_dict[self.key]
+ if current is not None:
+ current_state = attributes.instance_state(current)
+ current_dict = attributes.instance_dict(current)
+ _recursive[(current_state, self)] = True
+ obj = session._merge(
+ current_state,
+ current_dict,
+ load=load,
+ _recursive=_recursive,
+ _resolve_conflict_map=_resolve_conflict_map,
+ )
+ else:
+ obj = None
+
+ if not load:
+ dest_dict[self.key] = obj
+ else:
+ dest_state.get_impl(self.key).set(
+ dest_state, dest_dict, obj, None
+ )
+
+ def _value_as_iterable(
+ self, state, dict_, key, passive=attributes.PASSIVE_OFF
+ ):
+ """Return a list of tuples (state, obj) for the given
+ key.
+
+ returns an empty list if the value is None/empty/PASSIVE_NO_RESULT
+ """
+
+ impl = state.manager[key].impl
+ x = impl.get(state, dict_, passive=passive)
+ if x is attributes.PASSIVE_NO_RESULT or x is None:
+ return []
+ elif hasattr(impl, "get_collection"):
+ return [
+ (attributes.instance_state(o), o)
+ for o in impl.get_collection(state, dict_, x, passive=passive)
+ ]
+ else:
+ return [(attributes.instance_state(x), x)]
+
+ def cascade_iterator(
+ self, type_, state, dict_, visited_states, halt_on=None
+ ):
+ # assert type_ in self._cascade
+
+ # only actively lazy load on the 'delete' cascade
+ if type_ != "delete" or self.passive_deletes:
+ passive = attributes.PASSIVE_NO_INITIALIZE
+ else:
+ passive = attributes.PASSIVE_OFF
+
+ if type_ == "save-update":
+ tuples = state.manager[self.key].impl.get_all_pending(state, dict_)
+
+ else:
+ tuples = self._value_as_iterable(
+ state, dict_, self.key, passive=passive
+ )
+
+ skip_pending = (
+ type_ == "refresh-expire" and "delete-orphan" not in self._cascade
+ )
+
+ for instance_state, c in tuples:
+ if instance_state in visited_states:
+ continue
+
+ if c is None:
+ # would like to emit a warning here, but this
+ # would not be consistent with collection.append(None)'s
+ # current behavior of silently skipping.
+ # see [ticket:2229]
+ continue
+
+ instance_dict = attributes.instance_dict(c)
+
+ if halt_on and halt_on(instance_state):
+ continue
+
+ if skip_pending and not instance_state.key:
+ continue
+
+ instance_mapper = instance_state.manager.mapper
+
+ if not instance_mapper.isa(self.mapper.class_manager.mapper):
+ raise AssertionError(
+ "Attribute '%s' on class '%s' "
+ "doesn't handle objects "
+ "of type '%s'"
+ % (self.key, self.parent.class_, c.__class__)
+ )
+
+ visited_states.add(instance_state)
+
+ yield c, instance_mapper, instance_state, instance_dict
+
+ @property
+ def _effective_sync_backref(self):
+ if self.viewonly:
+ return False
+ else:
+ return self.sync_backref is not False
+
+ @staticmethod
+ def _check_sync_backref(rel_a, rel_b):
+ if rel_a.viewonly and rel_b.sync_backref:
+ raise sa_exc.InvalidRequestError(
+ "Relationship %s cannot specify sync_backref=True since %s "
+ "includes viewonly=True." % (rel_b, rel_a)
+ )
+ if (
+ rel_a.viewonly
+ and not rel_b.viewonly
+ and rel_b.sync_backref is not False
+ ):
+ rel_b.sync_backref = False
+
+ def _add_reverse_property(self, key):
+ other = self.mapper.get_property(key, _configure_mappers=False)
+ if not isinstance(other, RelationshipProperty):
+ raise sa_exc.InvalidRequestError(
+ "back_populates on relationship '%s' refers to attribute '%s' "
+ "that is not a relationship. The back_populates parameter "
+ "should refer to the name of a relationship on the target "
+ "class." % (self, other)
+ )
+ # viewonly and sync_backref cases
+ # 1. self.viewonly==True and other.sync_backref==True -> error
+ # 2. self.viewonly==True and other.viewonly==False and
+ # other.sync_backref==None -> warn sync_backref=False, set to False
+ self._check_sync_backref(self, other)
+ # 3. other.viewonly==True and self.sync_backref==True -> error
+ # 4. other.viewonly==True and self.viewonly==False and
+ # self.sync_backref==None -> warn sync_backref=False, set to False
+ self._check_sync_backref(other, self)
+
+ self._reverse_property.add(other)
+ other._reverse_property.add(self)
+
+ if not other.mapper.common_parent(self.parent):
+ raise sa_exc.ArgumentError(
+ "reverse_property %r on "
+ "relationship %s references relationship %s, which "
+ "does not reference mapper %s"
+ % (key, self, other, self.parent)
+ )
+
+ if (
+ self.direction in (ONETOMANY, MANYTOONE)
+ and self.direction == other.direction
+ ):
+ raise sa_exc.ArgumentError(
+ "%s and back-reference %s are "
+ "both of the same direction %r. Did you mean to "
+ "set remote_side on the many-to-one side ?"
+ % (other, self, self.direction)
+ )
+
+ @util.memoized_property
+ @util.preload_module("sqlalchemy.orm.mapper")
+ def entity(self):
+ """Return the target mapped entity, which is an inspect() of the
+ class or aliased class that is referred to.
+
+ """
+
+ mapperlib = util.preloaded.orm_mapper
+
+ if isinstance(self.argument, util.string_types):
+ argument = self._clsregistry_resolve_name(self.argument)()
+
+ elif callable(self.argument) and not isinstance(
+ self.argument, (type, mapperlib.Mapper)
+ ):
+ argument = self.argument()
+ else:
+ argument = self.argument
+
+ if isinstance(argument, type):
+ return mapperlib.class_mapper(argument, configure=False)
+
+ try:
+ entity = inspect(argument)
+ except sa_exc.NoInspectionAvailable:
+ pass
+ else:
+ if hasattr(entity, "mapper"):
+ return entity
+
+ raise sa_exc.ArgumentError(
+ "relationship '%s' expects "
+ "a class or a mapper argument (received: %s)"
+ % (self.key, type(argument))
+ )
+
+ @util.memoized_property
+ def mapper(self):
+ """Return the targeted :class:`_orm.Mapper` for this
+ :class:`.RelationshipProperty`.
+
+ This is a lazy-initializing static attribute.
+
+ """
+ return self.entity.mapper
+
+ def do_init(self):
+ self._check_conflicts()
+ self._process_dependent_arguments()
+ self._setup_registry_dependencies()
+ self._setup_join_conditions()
+ self._check_cascade_settings(self._cascade)
+ self._post_init()
+ self._generate_backref()
+ self._join_condition._warn_for_conflicting_sync_targets()
+ super(RelationshipProperty, self).do_init()
+ self._lazy_strategy = self._get_strategy((("lazy", "select"),))
+
+ def _setup_registry_dependencies(self):
+ self.parent.mapper.registry._set_depends_on(
+ self.entity.mapper.registry
+ )
+
+ def _process_dependent_arguments(self):
+ """Convert incoming configuration arguments to their
+ proper form.
+
+ Callables are resolved, ORM annotations removed.
+
+ """
+
+ # accept callables for other attributes which may require
+ # deferred initialization. This technique is used
+ # by declarative "string configs" and some recipes.
+ for attr in (
+ "order_by",
+ "primaryjoin",
+ "secondaryjoin",
+ "secondary",
+ "_user_defined_foreign_keys",
+ "remote_side",
+ ):
+ attr_value = getattr(self, attr)
+
+ if isinstance(attr_value, util.string_types):
+ setattr(
+ self,
+ attr,
+ self._clsregistry_resolve_arg(
+ attr_value, favor_tables=attr == "secondary"
+ )(),
+ )
+ elif callable(attr_value) and not _is_mapped_class(attr_value):
+ setattr(self, attr, attr_value())
+
+ # remove "annotations" which are present if mapped class
+ # descriptors are used to create the join expression.
+ for attr in "primaryjoin", "secondaryjoin":
+ val = getattr(self, attr)
+ if val is not None:
+ setattr(
+ self,
+ attr,
+ _orm_deannotate(
+ coercions.expect(
+ roles.ColumnArgumentRole, val, argname=attr
+ )
+ ),
+ )
+
+ if self.secondary is not None and _is_mapped_class(self.secondary):
+ raise sa_exc.ArgumentError(
+ "secondary argument %s passed to to relationship() %s must "
+ "be a Table object or other FROM clause; can't send a mapped "
+ "class directly as rows in 'secondary' are persisted "
+ "independently of a class that is mapped "
+ "to that same table." % (self.secondary, self)
+ )
+
+ # ensure expressions in self.order_by, foreign_keys,
+ # remote_side are all columns, not strings.
+ if self.order_by is not False and self.order_by is not None:
+ self.order_by = tuple(
+ coercions.expect(
+ roles.ColumnArgumentRole, x, argname="order_by"
+ )
+ for x in util.to_list(self.order_by)
+ )
+
+ self._user_defined_foreign_keys = util.column_set(
+ coercions.expect(
+ roles.ColumnArgumentRole, x, argname="foreign_keys"
+ )
+ for x in util.to_column_set(self._user_defined_foreign_keys)
+ )
+
+ self.remote_side = util.column_set(
+ coercions.expect(
+ roles.ColumnArgumentRole, x, argname="remote_side"
+ )
+ for x in util.to_column_set(self.remote_side)
+ )
+
+ self.target = self.entity.persist_selectable
+
+ def _setup_join_conditions(self):
+ self._join_condition = jc = JoinCondition(
+ parent_persist_selectable=self.parent.persist_selectable,
+ child_persist_selectable=self.entity.persist_selectable,
+ parent_local_selectable=self.parent.local_table,
+ child_local_selectable=self.entity.local_table,
+ primaryjoin=self.primaryjoin,
+ secondary=self.secondary,
+ secondaryjoin=self.secondaryjoin,
+ parent_equivalents=self.parent._equivalent_columns,
+ child_equivalents=self.mapper._equivalent_columns,
+ consider_as_foreign_keys=self._user_defined_foreign_keys,
+ local_remote_pairs=self.local_remote_pairs,
+ remote_side=self.remote_side,
+ self_referential=self._is_self_referential,
+ prop=self,
+ support_sync=not self.viewonly,
+ can_be_synced_fn=self._columns_are_mapped,
+ )
+ self.primaryjoin = jc.primaryjoin
+ self.secondaryjoin = jc.secondaryjoin
+ self.direction = jc.direction
+ self.local_remote_pairs = jc.local_remote_pairs
+ self.remote_side = jc.remote_columns
+ self.local_columns = jc.local_columns
+ self.synchronize_pairs = jc.synchronize_pairs
+ self._calculated_foreign_keys = jc.foreign_key_columns
+ self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs
+
+ @property
+ def _clsregistry_resolve_arg(self):
+ return self._clsregistry_resolvers[1]
+
+ @property
+ def _clsregistry_resolve_name(self):
+ return self._clsregistry_resolvers[0]
+
+ @util.memoized_property
+ @util.preload_module("sqlalchemy.orm.clsregistry")
+ def _clsregistry_resolvers(self):
+ _resolver = util.preloaded.orm_clsregistry._resolver
+
+ return _resolver(self.parent.class_, self)
+
+ @util.preload_module("sqlalchemy.orm.mapper")
+ def _check_conflicts(self):
+ """Test that this relationship is legal, warn about
+ inheritance conflicts."""
+ mapperlib = util.preloaded.orm_mapper
+ if self.parent.non_primary and not mapperlib.class_mapper(
+ self.parent.class_, configure=False
+ ).has_property(self.key):
+ raise sa_exc.ArgumentError(
+ "Attempting to assign a new "
+ "relationship '%s' to a non-primary mapper on "
+ "class '%s'. New relationships can only be added "
+ "to the primary mapper, i.e. the very first mapper "
+ "created for class '%s' "
+ % (
+ self.key,
+ self.parent.class_.__name__,
+ self.parent.class_.__name__,
+ )
+ )
+
+ @property
+ def cascade(self):
+ """Return the current cascade setting for this
+ :class:`.RelationshipProperty`.
+ """
+ return self._cascade
+
+ @cascade.setter
+ def cascade(self, cascade):
+ self._set_cascade(cascade)
+
+ def _set_cascade(self, cascade):
+ cascade = CascadeOptions(cascade)
+
+ if self.viewonly:
+ non_viewonly = set(cascade).difference(
+ CascadeOptions._viewonly_cascades
+ )
+ if non_viewonly:
+ raise sa_exc.ArgumentError(
+ 'Cascade settings "%s" apply to persistence operations '
+ "and should not be combined with a viewonly=True "
+ "relationship." % (", ".join(sorted(non_viewonly)))
+ )
+
+ if "mapper" in self.__dict__:
+ self._check_cascade_settings(cascade)
+ self._cascade = cascade
+
+ if self._dependency_processor:
+ self._dependency_processor.cascade = cascade
+
+ def _check_cascade_settings(self, cascade):
+ if (
+ cascade.delete_orphan
+ and not self.single_parent
+ and (self.direction is MANYTOMANY or self.direction is MANYTOONE)
+ ):
+ raise sa_exc.ArgumentError(
+ "For %(direction)s relationship %(rel)s, delete-orphan "
+ "cascade is normally "
+ 'configured only on the "one" side of a one-to-many '
+ "relationship, "
+ 'and not on the "many" side of a many-to-one or many-to-many '
+ "relationship. "
+ "To force this relationship to allow a particular "
+ '"%(relatedcls)s" object to be referred towards by only '
+ 'a single "%(clsname)s" object at a time via the '
+ "%(rel)s relationship, which "
+ "would allow "
+ "delete-orphan cascade to take place in this direction, set "
+ "the single_parent=True flag."
+ % {
+ "rel": self,
+ "direction": "many-to-one"
+ if self.direction is MANYTOONE
+ else "many-to-many",
+ "clsname": self.parent.class_.__name__,
+ "relatedcls": self.mapper.class_.__name__,
+ },
+ code="bbf0",
+ )
+
+ if self.passive_deletes == "all" and (
+ "delete" in cascade or "delete-orphan" in cascade
+ ):
+ raise sa_exc.ArgumentError(
+ "On %s, can't set passive_deletes='all' in conjunction "
+ "with 'delete' or 'delete-orphan' cascade" % self
+ )
+
+ if cascade.delete_orphan:
+ self.mapper.primary_mapper()._delete_orphans.append(
+ (self.key, self.parent.class_)
+ )
+
+ def _persists_for(self, mapper):
+ """Return True if this property will persist values on behalf
+ of the given mapper.
+
+ """
+
+ return (
+ self.key in mapper.relationships
+ and mapper.relationships[self.key] is self
+ )
+
+ def _columns_are_mapped(self, *cols):
+ """Return True if all columns in the given collection are
+ mapped by the tables referenced by this :class:`.Relationship`.
+
+ """
+ for c in cols:
+ if (
+ self.secondary is not None
+ and self.secondary.c.contains_column(c)
+ ):
+ continue
+ if not self.parent.persist_selectable.c.contains_column(
+ c
+ ) and not self.target.c.contains_column(c):
+ return False
+ return True
+
+ def _generate_backref(self):
+ """Interpret the 'backref' instruction to create a
+ :func:`_orm.relationship` complementary to this one."""
+
+ if self.parent.non_primary:
+ return
+ if self.backref is not None and not self.back_populates:
+ if isinstance(self.backref, util.string_types):
+ backref_key, kwargs = self.backref, {}
+ else:
+ backref_key, kwargs = self.backref
+ mapper = self.mapper.primary_mapper()
+
+ if not mapper.concrete:
+ check = set(mapper.iterate_to_root()).union(
+ mapper.self_and_descendants
+ )
+ for m in check:
+ if m.has_property(backref_key) and not m.concrete:
+ raise sa_exc.ArgumentError(
+ "Error creating backref "
+ "'%s' on relationship '%s': property of that "
+ "name exists on mapper '%s'"
+ % (backref_key, self, m)
+ )
+
+ # determine primaryjoin/secondaryjoin for the
+ # backref. Use the one we had, so that
+ # a custom join doesn't have to be specified in
+ # both directions.
+ if self.secondary is not None:
+ # for many to many, just switch primaryjoin/
+ # secondaryjoin. use the annotated
+ # pj/sj on the _join_condition.
+ pj = kwargs.pop(
+ "primaryjoin",
+ self._join_condition.secondaryjoin_minus_local,
+ )
+ sj = kwargs.pop(
+ "secondaryjoin",
+ self._join_condition.primaryjoin_minus_local,
+ )
+ else:
+ pj = kwargs.pop(
+ "primaryjoin",
+ self._join_condition.primaryjoin_reverse_remote,
+ )
+ sj = kwargs.pop("secondaryjoin", None)
+ if sj:
+ raise sa_exc.InvalidRequestError(
+ "Can't assign 'secondaryjoin' on a backref "
+ "against a non-secondary relationship."
+ )
+
+ foreign_keys = kwargs.pop(
+ "foreign_keys", self._user_defined_foreign_keys
+ )
+ parent = self.parent.primary_mapper()
+ kwargs.setdefault("viewonly", self.viewonly)
+ kwargs.setdefault("post_update", self.post_update)
+ kwargs.setdefault("passive_updates", self.passive_updates)
+ kwargs.setdefault("sync_backref", self.sync_backref)
+ self.back_populates = backref_key
+ relationship = RelationshipProperty(
+ parent,
+ self.secondary,
+ pj,
+ sj,
+ foreign_keys=foreign_keys,
+ back_populates=self.key,
+ **kwargs
+ )
+ mapper._configure_property(backref_key, relationship)
+
+ if self.back_populates:
+ self._add_reverse_property(self.back_populates)
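+
+ # Illustrative sketch (not part of the library): the 'backref' argument
+ # interpreted above may be a plain string or a backref() construct that
+ # carries extra keyword arguments for the generated reverse relationship,
+ # e.g. (hypothetical Parent/Child mapping assumed):
+ #
+ #   from sqlalchemy.orm import relationship, backref
+ #
+ #   children = relationship(
+ #       "Child", backref=backref("parent", lazy="joined")
+ #   )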
+
+ @util.preload_module("sqlalchemy.orm.dependency")
+ def _post_init(self):
+ dependency = util.preloaded.orm_dependency
+
+ if self.uselist is None:
+ self.uselist = self.direction is not MANYTOONE
+ if not self.viewonly:
+ self._dependency_processor = (
+ dependency.DependencyProcessor.from_relationship
+ )(self)
+
+ @util.memoized_property
+ def _use_get(self):
+ """memoize the 'use_get' attribute of this RelationshipLoader's
+ lazyloader."""
+
+ strategy = self._lazy_strategy
+ return strategy.use_get
+
+ @util.memoized_property
+ def _is_self_referential(self):
+ return self.mapper.common_parent(self.parent)
+
+ def _create_joins(
+ self,
+ source_polymorphic=False,
+ source_selectable=None,
+ dest_selectable=None,
+ of_type_entity=None,
+ alias_secondary=False,
+ extra_criteria=(),
+ ):
+
+ aliased = False
+
+ if alias_secondary and self.secondary is not None:
+ aliased = True
+
+ if source_selectable is None:
+ if source_polymorphic and self.parent.with_polymorphic:
+ source_selectable = self.parent._with_polymorphic_selectable
+
+ if of_type_entity:
+ dest_mapper = of_type_entity.mapper
+ if dest_selectable is None:
+ dest_selectable = of_type_entity.selectable
+ aliased = True
+ else:
+ dest_mapper = self.mapper
+
+ if dest_selectable is None:
+ dest_selectable = self.entity.selectable
+ if self.mapper.with_polymorphic:
+ aliased = True
+
+ if self._is_self_referential and source_selectable is None:
+ dest_selectable = dest_selectable._anonymous_fromclause()
+ aliased = True
+ elif (
+ dest_selectable is not self.mapper._with_polymorphic_selectable
+ or self.mapper.with_polymorphic
+ ):
+ aliased = True
+
+ single_crit = dest_mapper._single_table_criterion
+ aliased = aliased or (
+ source_selectable is not None
+ and (
+ source_selectable
+ is not self.parent._with_polymorphic_selectable
+ or source_selectable._is_subquery
+ )
+ )
+
+ (
+ primaryjoin,
+ secondaryjoin,
+ secondary,
+ target_adapter,
+ dest_selectable,
+ ) = self._join_condition.join_targets(
+ source_selectable,
+ dest_selectable,
+ aliased,
+ single_crit,
+ extra_criteria,
+ )
+ if source_selectable is None:
+ source_selectable = self.parent.local_table
+ if dest_selectable is None:
+ dest_selectable = self.entity.local_table
+ return (
+ primaryjoin,
+ secondaryjoin,
+ source_selectable,
+ dest_selectable,
+ secondary,
+ target_adapter,
+ )
+
+
+def _annotate_columns(element, annotations):
+ def clone(elem):
+ if isinstance(elem, expression.ColumnClause):
+ elem = elem._annotate(annotations.copy())
+ elem._copy_internals(clone=clone)
+ return elem
+
+ if element is not None:
+ element = clone(element)
+ clone = None # remove gc cycles
+ return element
+
+
+class JoinCondition(object):
+ def __init__(
+ self,
+ parent_persist_selectable,
+ child_persist_selectable,
+ parent_local_selectable,
+ child_local_selectable,
+ primaryjoin=None,
+ secondary=None,
+ secondaryjoin=None,
+ parent_equivalents=None,
+ child_equivalents=None,
+ consider_as_foreign_keys=None,
+ local_remote_pairs=None,
+ remote_side=None,
+ self_referential=False,
+ prop=None,
+ support_sync=True,
+ can_be_synced_fn=lambda *c: True,
+ ):
+ self.parent_persist_selectable = parent_persist_selectable
+ self.parent_local_selectable = parent_local_selectable
+ self.child_persist_selectable = child_persist_selectable
+ self.child_local_selectable = child_local_selectable
+ self.parent_equivalents = parent_equivalents
+ self.child_equivalents = child_equivalents
+ self.primaryjoin = primaryjoin
+ self.secondaryjoin = secondaryjoin
+ self.secondary = secondary
+ self.consider_as_foreign_keys = consider_as_foreign_keys
+ self._local_remote_pairs = local_remote_pairs
+ self._remote_side = remote_side
+ self.prop = prop
+ self.self_referential = self_referential
+ self.support_sync = support_sync
+ self.can_be_synced_fn = can_be_synced_fn
+ self._determine_joins()
+ self._sanitize_joins()
+ self._annotate_fks()
+ self._annotate_remote()
+ self._annotate_local()
+ self._annotate_parentmapper()
+ self._setup_pairs()
+ self._check_foreign_cols(self.primaryjoin, True)
+ if self.secondaryjoin is not None:
+ self._check_foreign_cols(self.secondaryjoin, False)
+ self._determine_direction()
+ self._check_remote_side()
+ self._log_joins()
+
+ def _log_joins(self):
+ if self.prop is None:
+ return
+ log = self.prop.logger
+ log.info("%s setup primary join %s", self.prop, self.primaryjoin)
+ log.info("%s setup secondary join %s", self.prop, self.secondaryjoin)
+ log.info(
+ "%s synchronize pairs [%s]",
+ self.prop,
+ ",".join(
+ "(%s => %s)" % (l, r) for (l, r) in self.synchronize_pairs
+ ),
+ )
+ log.info(
+ "%s secondary synchronize pairs [%s]",
+ self.prop,
+ ",".join(
+ "(%s => %s)" % (l, r)
+ for (l, r) in self.secondary_synchronize_pairs or []
+ ),
+ )
+ log.info(
+ "%s local/remote pairs [%s]",
+ self.prop,
+ ",".join(
+ "(%s / %s)" % (l, r) for (l, r) in self.local_remote_pairs
+ ),
+ )
+ log.info(
+ "%s remote columns [%s]",
+ self.prop,
+ ",".join("%s" % col for col in self.remote_columns),
+ )
+ log.info(
+ "%s local columns [%s]",
+ self.prop,
+ ",".join("%s" % col for col in self.local_columns),
+ )
+ log.info("%s relationship direction %s", self.prop, self.direction)
+
+ def _sanitize_joins(self):
+ """remove the parententity annotation from our join conditions which
+ can leak in here based on some declarative patterns and maybe others.
+
+ We'd want to remove "parentmapper" also, but apparently there's
+ an exotic use case in _join_fixture_inh_selfref_w_entity
+ that relies upon it being present, see :ticket:`3364`.
+
+ """
+
+ self.primaryjoin = _deep_deannotate(
+ self.primaryjoin, values=("parententity", "proxy_key")
+ )
+ if self.secondaryjoin is not None:
+ self.secondaryjoin = _deep_deannotate(
+ self.secondaryjoin, values=("parententity", "proxy_key")
+ )
+
+ def _determine_joins(self):
+ """Determine the 'primaryjoin' and 'secondaryjoin' attributes,
+ if not passed to the constructor already.
+
+ This is based on analysis of the foreign key relationships
+ between the parent and target mapped selectables.
+
+ """
+ if self.secondaryjoin is not None and self.secondary is None:
+ raise sa_exc.ArgumentError(
+ "Property %s specified with secondary "
+ "join condition but "
+ "no secondary argument" % self.prop
+ )
+
+ # find a join between the given mapper's mapped table and
+ # the given table. will try the mapper's local table first
+ # for more specificity, then if not found will try the more
+ # general mapped table, which in the case of inheritance is
+ # a join.
+ try:
+ consider_as_foreign_keys = self.consider_as_foreign_keys or None
+ if self.secondary is not None:
+ if self.secondaryjoin is None:
+ self.secondaryjoin = join_condition(
+ self.child_persist_selectable,
+ self.secondary,
+ a_subset=self.child_local_selectable,
+ consider_as_foreign_keys=consider_as_foreign_keys,
+ )
+ if self.primaryjoin is None:
+ self.primaryjoin = join_condition(
+ self.parent_persist_selectable,
+ self.secondary,
+ a_subset=self.parent_local_selectable,
+ consider_as_foreign_keys=consider_as_foreign_keys,
+ )
+ else:
+ if self.primaryjoin is None:
+ self.primaryjoin = join_condition(
+ self.parent_persist_selectable,
+ self.child_persist_selectable,
+ a_subset=self.parent_local_selectable,
+ consider_as_foreign_keys=consider_as_foreign_keys,
+ )
+ except sa_exc.NoForeignKeysError as nfe:
+ if self.secondary is not None:
+ util.raise_(
+ sa_exc.NoForeignKeysError(
+ "Could not determine join "
+ "condition between parent/child tables on "
+ "relationship %s - there are no foreign keys "
+ "linking these tables via secondary table '%s'. "
+ "Ensure that referencing columns are associated "
+ "with a ForeignKey or ForeignKeyConstraint, or "
+ "specify 'primaryjoin' and 'secondaryjoin' "
+ "expressions." % (self.prop, self.secondary)
+ ),
+ from_=nfe,
+ )
+ else:
+ util.raise_(
+ sa_exc.NoForeignKeysError(
+ "Could not determine join "
+ "condition between parent/child tables on "
+ "relationship %s - there are no foreign keys "
+ "linking these tables. "
+ "Ensure that referencing columns are associated "
+ "with a ForeignKey or ForeignKeyConstraint, or "
+ "specify a 'primaryjoin' expression." % self.prop
+ ),
+ from_=nfe,
+ )
+ except sa_exc.AmbiguousForeignKeysError as afe:
+ if self.secondary is not None:
+ util.raise_(
+ sa_exc.AmbiguousForeignKeysError(
+ "Could not determine join "
+ "condition between parent/child tables on "
+ "relationship %s - there are multiple foreign key "
+ "paths linking the tables via secondary table '%s'. "
+ "Specify the 'foreign_keys' "
+ "argument, providing a list of those columns which "
+ "should be counted as containing a foreign key "
+ "reference from the secondary table to each of the "
+ "parent and child tables."
+ % (self.prop, self.secondary)
+ ),
+ from_=afe,
+ )
+ else:
+ util.raise_(
+ sa_exc.AmbiguousForeignKeysError(
+ "Could not determine join "
+ "condition between parent/child tables on "
+ "relationship %s - there are multiple foreign key "
+ "paths linking the tables. Specify the "
+ "'foreign_keys' argument, providing a list of those "
+ "columns which should be counted as containing a "
+ "foreign key reference to the parent table."
+ % self.prop
+ ),
+ from_=afe,
+ )
+
+ @property
+ def primaryjoin_minus_local(self):
+ return _deep_deannotate(self.primaryjoin, values=("local", "remote"))
+
+ @property
+ def secondaryjoin_minus_local(self):
+ return _deep_deannotate(self.secondaryjoin, values=("local", "remote"))
+
+ @util.memoized_property
+ def primaryjoin_reverse_remote(self):
+ """Return the primaryjoin condition suitable for the
+ "reverse" direction.
+
+ If the primaryjoin was delivered here with pre-existing
+ "remote" annotations, the local/remote annotations
+ are reversed. Otherwise, the local/remote annotations
+ are removed.
+
+ """
+ if self._has_remote_annotations:
+
+ def replace(element):
+ if "remote" in element._annotations:
+ v = dict(element._annotations)
+ del v["remote"]
+ v["local"] = True
+ return element._with_annotations(v)
+ elif "local" in element._annotations:
+ v = dict(element._annotations)
+ del v["local"]
+ v["remote"] = True
+ return element._with_annotations(v)
+
+ return visitors.replacement_traverse(self.primaryjoin, {}, replace)
+ else:
+ if self._has_foreign_annotations:
+ # TODO: coverage
+ return _deep_deannotate(
+ self.primaryjoin, values=("local", "remote")
+ )
+ else:
+ return _deep_deannotate(self.primaryjoin)
+
+ def _has_annotation(self, clause, annotation):
+ for col in visitors.iterate(clause, {}):
+ if annotation in col._annotations:
+ return True
+ else:
+ return False
+
+ @util.memoized_property
+ def _has_foreign_annotations(self):
+ return self._has_annotation(self.primaryjoin, "foreign")
+
+ @util.memoized_property
+ def _has_remote_annotations(self):
+ return self._has_annotation(self.primaryjoin, "remote")
+
+ def _annotate_fks(self):
+ """Annotate the primaryjoin and secondaryjoin
+ structures with 'foreign' annotations marking columns
+ considered as foreign.
+
+ """
+ if self._has_foreign_annotations:
+ return
+
+ if self.consider_as_foreign_keys:
+ self._annotate_from_fk_list()
+ else:
+ self._annotate_present_fks()
+
+ def _annotate_from_fk_list(self):
+ def check_fk(col):
+ if col in self.consider_as_foreign_keys:
+ return col._annotate({"foreign": True})
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, check_fk
+ )
+ if self.secondaryjoin is not None:
+ self.secondaryjoin = visitors.replacement_traverse(
+ self.secondaryjoin, {}, check_fk
+ )
+
+ def _annotate_present_fks(self):
+ if self.secondary is not None:
+ secondarycols = util.column_set(self.secondary.c)
+ else:
+ secondarycols = set()
+
+ def is_foreign(a, b):
+ if isinstance(a, schema.Column) and isinstance(b, schema.Column):
+ if a.references(b):
+ return a
+ elif b.references(a):
+ return b
+
+ if secondarycols:
+ if a in secondarycols and b not in secondarycols:
+ return a
+ elif b in secondarycols and a not in secondarycols:
+ return b
+
+ def visit_binary(binary):
+ if not isinstance(
+ binary.left, sql.ColumnElement
+ ) or not isinstance(binary.right, sql.ColumnElement):
+ return
+
+ if (
+ "foreign" not in binary.left._annotations
+ and "foreign" not in binary.right._annotations
+ ):
+ col = is_foreign(binary.left, binary.right)
+ if col is not None:
+ if col.compare(binary.left):
+ binary.left = binary.left._annotate({"foreign": True})
+ elif col.compare(binary.right):
+ binary.right = binary.right._annotate(
+ {"foreign": True}
+ )
+
+ self.primaryjoin = visitors.cloned_traverse(
+ self.primaryjoin, {}, {"binary": visit_binary}
+ )
+ if self.secondaryjoin is not None:
+ self.secondaryjoin = visitors.cloned_traverse(
+ self.secondaryjoin, {}, {"binary": visit_binary}
+ )
+
+ def _refers_to_parent_table(self):
+ """Return True if the join condition contains column
+ comparisons where both columns are in both tables.
+
+ """
+ pt = self.parent_persist_selectable
+ mt = self.child_persist_selectable
+ result = [False]
+
+ def visit_binary(binary):
+ c, f = binary.left, binary.right
+ if (
+ isinstance(c, expression.ColumnClause)
+ and isinstance(f, expression.ColumnClause)
+ and pt.is_derived_from(c.table)
+ and pt.is_derived_from(f.table)
+ and mt.is_derived_from(c.table)
+ and mt.is_derived_from(f.table)
+ ):
+ result[0] = True
+
+ visitors.traverse(self.primaryjoin, {}, {"binary": visit_binary})
+ return result[0]
+
+ def _tables_overlap(self):
+ """Return True if parent/child tables have some overlap."""
+
+ return selectables_overlap(
+ self.parent_persist_selectable, self.child_persist_selectable
+ )
+
+ def _annotate_remote(self):
+ """Annotate the primaryjoin and secondaryjoin
+ structures with 'remote' annotations marking columns
+ considered as part of the 'remote' side.
+
+ """
+ if self._has_remote_annotations:
+ return
+
+ if self.secondary is not None:
+ self._annotate_remote_secondary()
+ elif self._local_remote_pairs or self._remote_side:
+ self._annotate_remote_from_args()
+ elif self._refers_to_parent_table():
+ self._annotate_selfref(
+ lambda col: "foreign" in col._annotations, False
+ )
+ elif self._tables_overlap():
+ self._annotate_remote_with_overlap()
+ else:
+ self._annotate_remote_distinct_selectables()
+
+ def _annotate_remote_secondary(self):
+ """annotate 'remote' in primaryjoin, secondaryjoin
+ when 'secondary' is present.
+
+ """
+
+ def repl(element):
+ if self.secondary.c.contains_column(element):
+ return element._annotate({"remote": True})
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, repl
+ )
+ self.secondaryjoin = visitors.replacement_traverse(
+ self.secondaryjoin, {}, repl
+ )
+
+ def _annotate_selfref(self, fn, remote_side_given):
+ """annotate 'remote' in primaryjoin, secondaryjoin
+ when the relationship is detected as self-referential.
+
+ """
+
+ def visit_binary(binary):
+ equated = binary.left.compare(binary.right)
+ if isinstance(binary.left, expression.ColumnClause) and isinstance(
+ binary.right, expression.ColumnClause
+ ):
+ # assume one to many - FKs are "remote"
+ if fn(binary.left):
+ binary.left = binary.left._annotate({"remote": True})
+ if fn(binary.right) and not equated:
+ binary.right = binary.right._annotate({"remote": True})
+ elif not remote_side_given:
+ self._warn_non_column_elements()
+
+ self.primaryjoin = visitors.cloned_traverse(
+ self.primaryjoin, {}, {"binary": visit_binary}
+ )
+
+ def _annotate_remote_from_args(self):
+ """annotate 'remote' in primaryjoin, secondaryjoin
+ when the 'remote_side' or '_local_remote_pairs'
+ arguments are used.
+
+ """
+ if self._local_remote_pairs:
+ if self._remote_side:
+ raise sa_exc.ArgumentError(
+ "remote_side argument is redundant "
+ "against more detailed _local_remote_side "
+ "argument."
+ )
+
+ remote_side = [r for (l, r) in self._local_remote_pairs]
+ else:
+ remote_side = self._remote_side
+
+ if self._refers_to_parent_table():
+ self._annotate_selfref(lambda col: col in remote_side, True)
+ else:
+
+ def repl(element):
+ # use set() to avoid generating ``__eq__()`` expressions
+ # against each element
+ if element in set(remote_side):
+ return element._annotate({"remote": True})
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, repl
+ )
+
+ def _annotate_remote_with_overlap(self):
+ """annotate 'remote' in primaryjoin, secondaryjoin
+ when the parent/child tables have some set of
+ tables in common, though the relationship is not fully
+ self-referential.
+
+ """
+
+ def visit_binary(binary):
+ binary.left, binary.right = proc_left_right(
+ binary.left, binary.right
+ )
+ binary.right, binary.left = proc_left_right(
+ binary.right, binary.left
+ )
+
+ check_entities = (
+ self.prop is not None and self.prop.mapper is not self.prop.parent
+ )
+
+ def proc_left_right(left, right):
+ if isinstance(left, expression.ColumnClause) and isinstance(
+ right, expression.ColumnClause
+ ):
+ if self.child_persist_selectable.c.contains_column(
+ right
+ ) and self.parent_persist_selectable.c.contains_column(left):
+ right = right._annotate({"remote": True})
+ elif (
+ check_entities
+ and right._annotations.get("parentmapper") is self.prop.mapper
+ ):
+ right = right._annotate({"remote": True})
+ elif (
+ check_entities
+ and left._annotations.get("parentmapper") is self.prop.mapper
+ ):
+ left = left._annotate({"remote": True})
+ else:
+ self._warn_non_column_elements()
+
+ return left, right
+
+ self.primaryjoin = visitors.cloned_traverse(
+ self.primaryjoin, {}, {"binary": visit_binary}
+ )
+
+ def _annotate_remote_distinct_selectables(self):
+ """annotate 'remote' in primaryjoin, secondaryjoin
+ when the parent/child tables are entirely
+ separate.
+
+ """
+
+ def repl(element):
+ if self.child_persist_selectable.c.contains_column(element) and (
+ not self.parent_local_selectable.c.contains_column(element)
+ or self.child_local_selectable.c.contains_column(element)
+ ):
+ return element._annotate({"remote": True})
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, repl
+ )
+
+ def _warn_non_column_elements(self):
+ util.warn(
+ "Non-simple column elements in primary "
+ "join condition for property %s - consider using "
+ "remote() annotations to mark the remote side." % self.prop
+ )
+
+ def _annotate_local(self):
+ """Annotate the primaryjoin and secondaryjoin
+ structures with 'local' annotations.
+
+ This annotates all column elements found
+ simultaneously in the parent table
+ and the join condition that don't have a
+ 'remote' annotation set up from
+ _annotate_remote() or user-defined.
+
+ """
+ if self._has_annotation(self.primaryjoin, "local"):
+ return
+
+ if self._local_remote_pairs:
+ local_side = util.column_set(
+ [l for (l, r) in self._local_remote_pairs]
+ )
+ else:
+ local_side = util.column_set(self.parent_persist_selectable.c)
+
+ def locals_(elem):
+ if "remote" not in elem._annotations and elem in local_side:
+ return elem._annotate({"local": True})
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, locals_
+ )
+
+ def _annotate_parentmapper(self):
+ if self.prop is None:
+ return
+
+ def parentmappers_(elem):
+ if "remote" in elem._annotations:
+ return elem._annotate({"parentmapper": self.prop.mapper})
+ elif "local" in elem._annotations:
+ return elem._annotate({"parentmapper": self.prop.parent})
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, parentmappers_
+ )
+
+ def _check_remote_side(self):
+ if not self.local_remote_pairs:
+ raise sa_exc.ArgumentError(
+ "Relationship %s could "
+ "not determine any unambiguous local/remote column "
+ "pairs based on join condition and remote_side "
+ "arguments. "
+ "Consider using the remote() annotation to "
+ "accurately mark those elements of the join "
+ "condition that are on the remote side of "
+ "the relationship." % (self.prop,)
+ )
+
+ def _check_foreign_cols(self, join_condition, primary):
+ """Check the foreign key columns collected and emit error
+ messages."""
+
+ can_sync = False
+
+ foreign_cols = self._gather_columns_with_annotation(
+ join_condition, "foreign"
+ )
+
+ has_foreign = bool(foreign_cols)
+
+ if primary:
+ can_sync = bool(self.synchronize_pairs)
+ else:
+ can_sync = bool(self.secondary_synchronize_pairs)
+
+ if (
+ self.support_sync
+ and can_sync
+ or (not self.support_sync and has_foreign)
+ ):
+ return
+
+ # from here below is just determining the best error message
+ # to report. Check for a join condition using any operator
+ # (not just ==), perhaps they need to turn on "viewonly=True".
+ if self.support_sync and has_foreign and not can_sync:
+ err = (
+ "Could not locate any simple equality expressions "
+ "involving locally mapped foreign key columns for "
+ "%s join condition "
+ "'%s' on relationship %s."
+ % (
+ primary and "primary" or "secondary",
+ join_condition,
+ self.prop,
+ )
+ )
+ err += (
+ " Ensure that referencing columns are associated "
+ "with a ForeignKey or ForeignKeyConstraint, or are "
+ "annotated in the join condition with the foreign() "
+ "annotation. To allow comparison operators other than "
+ "'==', the relationship can be marked as viewonly=True."
+ )
+
+ raise sa_exc.ArgumentError(err)
+ else:
+ err = (
+ "Could not locate any relevant foreign key columns "
+ "for %s join condition '%s' on relationship %s."
+ % (
+ primary and "primary" or "secondary",
+ join_condition,
+ self.prop,
+ )
+ )
+ err += (
+ " Ensure that referencing columns are associated "
+ "with a ForeignKey or ForeignKeyConstraint, or are "
+ "annotated in the join condition with the foreign() "
+ "annotation."
+ )
+ raise sa_exc.ArgumentError(err)
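+
+ # Illustrative sketch (not part of the library): per the messages above, a
+ # join condition without usable ForeignKey metadata can mark its foreign
+ # (and remote) side explicitly, and use viewonly=True when the comparison
+ # is not a plain '=='; hypothetical self-referential Node.path assumed:
+ #
+ #   from sqlalchemy.orm import relationship, foreign, remote
+ #
+ #   descendants = relationship(
+ #       "Node",
+ #       primaryjoin=remote(foreign(Node.path)).like(Node.path + "/%"),
+ #       viewonly=True,
+ #   )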
+
+ def _determine_direction(self):
+ """Determine if this relationship is one to many, many to one,
+ many to many.
+
+ """
+ if self.secondaryjoin is not None:
+ self.direction = MANYTOMANY
+ else:
+ parentcols = util.column_set(self.parent_persist_selectable.c)
+ targetcols = util.column_set(self.child_persist_selectable.c)
+
+ # fk collection which suggests ONETOMANY.
+ onetomany_fk = targetcols.intersection(self.foreign_key_columns)
+
+ # fk collection which suggests MANYTOONE.
+
+ manytoone_fk = parentcols.intersection(self.foreign_key_columns)
+
+ if onetomany_fk and manytoone_fk:
+ # fks on both sides. test for overlap of local/remote
+ # with foreign key.
+ # we will gather columns directly from their annotations
+ # without deannotating, so that we can distinguish on a column
+ # that refers to itself.
+
+ # 1. columns that are both remote and FK suggest
+ # onetomany.
+ onetomany_local = self._gather_columns_with_annotation(
+ self.primaryjoin, "remote", "foreign"
+ )
+
+ # 2. columns that are FK but are not remote (e.g. local)
+ # suggest manytoone.
+ manytoone_local = set(
+ [
+ c
+ for c in self._gather_columns_with_annotation(
+ self.primaryjoin, "foreign"
+ )
+ if "remote" not in c._annotations
+ ]
+ )
+
+ # 3. if both collections are present, remove columns that
+ # refer to themselves. This is for the case of
+ # and_(Me.id == Me.remote_id, Me.version == Me.version)
+ if onetomany_local and manytoone_local:
+ self_equated = self.remote_columns.intersection(
+ self.local_columns
+ )
+ onetomany_local = onetomany_local.difference(self_equated)
+ manytoone_local = manytoone_local.difference(self_equated)
+
+ # at this point, if only one or the other collection is
+ # present, we know the direction, otherwise it's still
+ # ambiguous.
+
+ if onetomany_local and not manytoone_local:
+ self.direction = ONETOMANY
+ elif manytoone_local and not onetomany_local:
+ self.direction = MANYTOONE
+ else:
+ raise sa_exc.ArgumentError(
+ "Can't determine relationship"
+ " direction for relationship '%s' - foreign "
+ "key columns within the join condition are present "
+ "in both the parent and the child's mapped tables. "
+ "Ensure that only those columns referring "
+ "to a parent column are marked as foreign, "
+ "either via the foreign() annotation or "
+ "via the foreign_keys argument." % self.prop
+ )
+ elif onetomany_fk:
+ self.direction = ONETOMANY
+ elif manytoone_fk:
+ self.direction = MANYTOONE
+ else:
+ raise sa_exc.ArgumentError(
+ "Can't determine relationship "
+ "direction for relationship '%s' - foreign "
+ "key columns are present in neither the parent "
+ "nor the child's mapped tables" % self.prop
+ )
+
+ def _deannotate_pairs(self, collection):
+ """provide deannotation for the various lists of
+ pairs, so that using them in hashes doesn't incur
+ high-overhead __eq__() comparisons against
+ original columns mapped.
+
+ """
+ return [(x._deannotate(), y._deannotate()) for x, y in collection]
+
+ def _setup_pairs(self):
+ sync_pairs = []
+ lrp = util.OrderedSet([])
+ secondary_sync_pairs = []
+
+ def go(joincond, collection):
+ def visit_binary(binary, left, right):
+ if (
+ "remote" in right._annotations
+ and "remote" not in left._annotations
+ and self.can_be_synced_fn(left)
+ ):
+ lrp.add((left, right))
+ elif (
+ "remote" in left._annotations
+ and "remote" not in right._annotations
+ and self.can_be_synced_fn(right)
+ ):
+ lrp.add((right, left))
+ if binary.operator is operators.eq and self.can_be_synced_fn(
+ left, right
+ ):
+ if "foreign" in right._annotations:
+ collection.append((left, right))
+ elif "foreign" in left._annotations:
+ collection.append((right, left))
+
+ visit_binary_product(visit_binary, joincond)
+
+ for joincond, collection in [
+ (self.primaryjoin, sync_pairs),
+ (self.secondaryjoin, secondary_sync_pairs),
+ ]:
+ if joincond is None:
+ continue
+ go(joincond, collection)
+
+ self.local_remote_pairs = self._deannotate_pairs(lrp)
+ self.synchronize_pairs = self._deannotate_pairs(sync_pairs)
+ self.secondary_synchronize_pairs = self._deannotate_pairs(
+ secondary_sync_pairs
+ )
+
+ _track_overlapping_sync_targets = weakref.WeakKeyDictionary()
+
+ def _warn_for_conflicting_sync_targets(self):
+ if not self.support_sync:
+ return
+
+ # we would like to detect if we are synchronizing any column
+ # pairs in conflict with another relationship that wishes to sync
+ # an entirely different column to the same target. This is a
+ # very rare edge case so we will try to minimize the memory/overhead
+ # impact of this check
+ for from_, to_ in [
+ (from_, to_) for (from_, to_) in self.synchronize_pairs
+ ] + [
+ (from_, to_) for (from_, to_) in self.secondary_synchronize_pairs
+ ]:
+ # save ourselves a ton of memory and overhead by only
+ # considering columns that are subject to overlapping
+ # FK constraints at the core level. This condition can arise
+ # if multiple relationships overlap foreign() directly, but
+ # we're going to assume it's typically a ForeignKeyConstraint-
+ # level configuration that benefits from this warning.
+
+ if to_ not in self._track_overlapping_sync_targets:
+ self._track_overlapping_sync_targets[
+ to_
+ ] = weakref.WeakKeyDictionary({self.prop: from_})
+ else:
+ other_props = []
+ prop_to_from = self._track_overlapping_sync_targets[to_]
+
+ for pr, fr_ in prop_to_from.items():
+ if (
+ not pr.mapper._dispose_called
+ and pr not in self.prop._reverse_property
+ and pr.key not in self.prop._overlaps
+ and self.prop.key not in pr._overlaps
+ # note: the "__*" symbol is used internally by
+ # SQLAlchemy as a general means of suppressing the
+ # overlaps warning for some extension cases, however
+ # this is not currently
+ # a publicly supported symbol and may change at
+ # any time.
+ and "__*" not in self.prop._overlaps
+ and "__*" not in pr._overlaps
+ and not self.prop.parent.is_sibling(pr.parent)
+ and not self.prop.mapper.is_sibling(pr.mapper)
+ and not self.prop.parent.is_sibling(pr.mapper)
+ and not self.prop.mapper.is_sibling(pr.parent)
+ and (
+ self.prop.key != pr.key
+ or not self.prop.parent.common_parent(pr.parent)
+ )
+ ):
+
+ other_props.append((pr, fr_))
+
+ if other_props:
+ util.warn(
+ "relationship '%s' will copy column %s to column %s, "
+ "which conflicts with relationship(s): %s. "
+ "If this is not the intention, consider if these "
+ "relationships should be linked with "
+ "back_populates, or if viewonly=True should be "
+ "applied to one or more if they are read-only. "
+ "For the less common case that foreign key "
+ "constraints are partially overlapping, the "
+ "orm.foreign() "
+ "annotation can be used to isolate the columns that "
+ "should be written towards. To silence this "
+ "warning, add the parameter 'overlaps=\"%s\"' to the "
+ "'%s' relationship."
+ % (
+ self.prop,
+ from_,
+ to_,
+ ", ".join(
+ sorted(
+ "'%s' (copies %s to %s)" % (pr, fr_, to_)
+ for (pr, fr_) in other_props
+ )
+ ),
+ ",".join(sorted(pr.key for pr, fr in other_props)),
+ self.prop,
+ ),
+ code="qzyx",
+ )
+ self._track_overlapping_sync_targets[to_][self.prop] = from_
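+
+ # Illustrative sketch (not part of the library): the warning above can be
+ # silenced, per its own suggestion, on relationships that intentionally
+ # write to the same column; hypothetical 'foo' relationship name assumed:
+ #
+ #   bar = relationship("Target", overlaps="foo")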
+
+ @util.memoized_property
+ def remote_columns(self):
+ return self._gather_join_annotations("remote")
+
+ @util.memoized_property
+ def local_columns(self):
+ return self._gather_join_annotations("local")
+
+ @util.memoized_property
+ def foreign_key_columns(self):
+ return self._gather_join_annotations("foreign")
+
+ def _gather_join_annotations(self, annotation):
+ s = set(
+ self._gather_columns_with_annotation(self.primaryjoin, annotation)
+ )
+ if self.secondaryjoin is not None:
+ s.update(
+ self._gather_columns_with_annotation(
+ self.secondaryjoin, annotation
+ )
+ )
+ return {x._deannotate() for x in s}
+
+ def _gather_columns_with_annotation(self, clause, *annotation):
+ annotation = set(annotation)
+ return set(
+ [
+ col
+ for col in visitors.iterate(clause, {})
+ if annotation.issubset(col._annotations)
+ ]
+ )
+
+ def join_targets(
+ self,
+ source_selectable,
+ dest_selectable,
+ aliased,
+ single_crit=None,
+ extra_criteria=(),
+ ):
+ """Given a source and destination selectable, create a
+ join between them.
+
+ This takes into account aliasing the join clause
+ to reference the appropriate corresponding columns
+ in the target objects, as well as the extra child
+ criterion, equivalent column sets, etc.
+
+ """
+ # place a barrier on the destination such that
+ # replacement traversals won't ever dig into it.
+ # its internal structure remains fixed
+ # regardless of context.
+ dest_selectable = _shallow_annotate(
+ dest_selectable, {"no_replacement_traverse": True}
+ )
+
+ primaryjoin, secondaryjoin, secondary = (
+ self.primaryjoin,
+ self.secondaryjoin,
+ self.secondary,
+ )
+
+ # adjust the join condition for single table inheritance,
+ # in the case that the join is to a subclass
+ # this is analogous to the
+ # "_adjust_for_single_table_inheritance()" method in Query.
+
+ if single_crit is not None:
+ if secondaryjoin is not None:
+ secondaryjoin = secondaryjoin & single_crit
+ else:
+ primaryjoin = primaryjoin & single_crit
+
+ if extra_criteria:
+ if secondaryjoin is not None:
+ secondaryjoin = secondaryjoin & sql.and_(*extra_criteria)
+ else:
+ primaryjoin = primaryjoin & sql.and_(*extra_criteria)
+
+ if aliased:
+ if secondary is not None:
+ secondary = secondary._anonymous_fromclause(flat=True)
+ primary_aliasizer = ClauseAdapter(
+ secondary, exclude_fn=_ColInAnnotations("local")
+ )
+ secondary_aliasizer = ClauseAdapter(
+ dest_selectable, equivalents=self.child_equivalents
+ ).chain(primary_aliasizer)
+ if source_selectable is not None:
+ primary_aliasizer = ClauseAdapter(
+ secondary, exclude_fn=_ColInAnnotations("local")
+ ).chain(
+ ClauseAdapter(
+ source_selectable,
+ equivalents=self.parent_equivalents,
+ )
+ )
+
+ secondaryjoin = secondary_aliasizer.traverse(secondaryjoin)
+ else:
+ primary_aliasizer = ClauseAdapter(
+ dest_selectable,
+ exclude_fn=_ColInAnnotations("local"),
+ equivalents=self.child_equivalents,
+ )
+ if source_selectable is not None:
+ primary_aliasizer.chain(
+ ClauseAdapter(
+ source_selectable,
+ exclude_fn=_ColInAnnotations("remote"),
+ equivalents=self.parent_equivalents,
+ )
+ )
+ secondary_aliasizer = None
+
+ primaryjoin = primary_aliasizer.traverse(primaryjoin)
+ target_adapter = secondary_aliasizer or primary_aliasizer
+ target_adapter.exclude_fn = None
+ else:
+ target_adapter = None
+ return (
+ primaryjoin,
+ secondaryjoin,
+ secondary,
+ target_adapter,
+ dest_selectable,
+ )
+
+ def create_lazy_clause(self, reverse_direction=False):
+ binds = util.column_dict()
+ equated_columns = util.column_dict()
+
+ has_secondary = self.secondaryjoin is not None
+
+ if has_secondary:
+ lookup = collections.defaultdict(list)
+ for l, r in self.local_remote_pairs:
+ lookup[l].append((l, r))
+ equated_columns[r] = l
+ elif not reverse_direction:
+ for l, r in self.local_remote_pairs:
+ equated_columns[r] = l
+ else:
+ for l, r in self.local_remote_pairs:
+ equated_columns[l] = r
+
+ def col_to_bind(col):
+
+ if (
+ (not reverse_direction and "local" in col._annotations)
+ or reverse_direction
+ and (
+ (has_secondary and col in lookup)
+ or (not has_secondary and "remote" in col._annotations)
+ )
+ ):
+ if col not in binds:
+ binds[col] = sql.bindparam(
+ None, None, type_=col.type, unique=True
+ )
+ return binds[col]
+ return None
+
+ lazywhere = self.primaryjoin
+ if self.secondaryjoin is None or not reverse_direction:
+ lazywhere = visitors.replacement_traverse(
+ lazywhere, {}, col_to_bind
+ )
+
+ if self.secondaryjoin is not None:
+ secondaryjoin = self.secondaryjoin
+ if reverse_direction:
+ secondaryjoin = visitors.replacement_traverse(
+ secondaryjoin, {}, col_to_bind
+ )
+ lazywhere = sql.and_(lazywhere, secondaryjoin)
+
+ bind_to_col = {binds[col].key: col for col in binds}
+
+ return lazywhere, bind_to_col, equated_columns
+
+
+class _ColInAnnotations(object):
+ """Serializable object that tests for a name in c._annotations."""
+
+ __slots__ = ("name",)
+
+ def __init__(self, name):
+ self.name = name
+
+ def __call__(self, c):
+ return self.name in c._annotations
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py
new file mode 100644
index 0000000..f323233
--- /dev/null
+++ b/lib/sqlalchemy/orm/scoping.py
@@ -0,0 +1,228 @@
+# orm/scoping.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import class_mapper
+from . import exc as orm_exc
+from .session import Session
+from .. import exc as sa_exc
+from ..util import create_proxy_methods
+from ..util import ScopedRegistry
+from ..util import ThreadLocalRegistry
+from ..util import warn
+from ..util import warn_deprecated
+
+__all__ = ["scoped_session", "ScopedSessionMixin"]
+
+
+class ScopedSessionMixin(object):
+ @property
+ def _proxied(self):
+ return self.registry()
+
+ def __call__(self, **kw):
+ r"""Return the current :class:`.Session`, creating it
+ using the :attr:`.scoped_session.session_factory` if not present.
+
+ :param \**kw: Keyword arguments will be passed to the
+ :attr:`.scoped_session.session_factory` callable, if an existing
+ :class:`.Session` is not present. If the :class:`.Session` is present
+ and keyword arguments have been passed,
+ :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
+
+ """
+ if kw:
+ if self.registry.has():
+ raise sa_exc.InvalidRequestError(
+ "Scoped session is already present; "
+ "no new arguments may be specified."
+ )
+ else:
+ sess = self.session_factory(**kw)
+ self.registry.set(sess)
+ else:
+ sess = self.registry()
+ if not self._support_async and sess._is_asyncio:
+ warn_deprecated(
+ "Using `scoped_session` with asyncio is deprecated and "
+ "will raise an error in a future version. "
+ "Please use `async_scoped_session` instead.",
+ "1.4.23",
+ )
+ return sess
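+
+ # Illustrative sketch (not part of the library): typical use of the
+ # callable protocol above, with a hypothetical in-memory engine:
+ #
+ #   from sqlalchemy import create_engine
+ #   from sqlalchemy.orm import scoped_session, sessionmaker
+ #
+ #   Session = scoped_session(sessionmaker(bind=create_engine("sqlite://")))
+ #   session = Session()          # creates the scope's Session on first call
+ #   assert Session() is session  # same Session returned within the scope
+ #   Session(autoflush=False)     # raises InvalidRequestError: already present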
+
+ def configure(self, **kwargs):
+ """reconfigure the :class:`.sessionmaker` used by this
+ :class:`.scoped_session`.
+
+ See :meth:`.sessionmaker.configure`.
+
+ """
+
+ if self.registry.has():
+ warn(
+ "At least one scoped session is already present. "
+ " configure() can not affect sessions that have "
+ "already been created."
+ )
+
+ self.session_factory.configure(**kwargs)
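+
+ # Illustrative sketch (not part of the library): configure() forwards to
+ # sessionmaker.configure(), e.g. attaching a bind before any scoped
+ # Session exists (hypothetical engine URL assumed):
+ #
+ #   Session = scoped_session(sessionmaker())
+ #   Session.configure(bind=create_engine("sqlite://"))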
+
+
+@create_proxy_methods(
+ Session,
+ ":class:`_orm.Session`",
+ ":class:`_orm.scoping.scoped_session`",
+ classmethods=["close_all", "object_session", "identity_key"],
+ methods=[
+ "__contains__",
+ "__iter__",
+ "add",
+ "add_all",
+ "begin",
+ "begin_nested",
+ "close",
+ "commit",
+ "connection",
+ "delete",
+ "execute",
+ "expire",
+ "expire_all",
+ "expunge",
+ "expunge_all",
+ "flush",
+ "get",
+ "get_bind",
+ "is_modified",
+ "bulk_save_objects",
+ "bulk_insert_mappings",
+ "bulk_update_mappings",
+ "merge",
+ "query",
+ "refresh",
+ "rollback",
+ "scalar",
+ "scalars",
+ ],
+ attributes=[
+ "bind",
+ "dirty",
+ "deleted",
+ "new",
+ "identity_map",
+ "is_active",
+ "autoflush",
+ "no_autoflush",
+ "info",
+ "autocommit",
+ ],
+)
+class scoped_session(ScopedSessionMixin):
+ """Provides scoped management of :class:`.Session` objects.
+
+ See :ref:`unitofwork_contextual` for a tutorial.
+
+ .. note::
+
+ When using :ref:`asyncio_toplevel`, the async-compatible
+ :class:`_asyncio.async_scoped_session` class should be
+ used in place of :class:`.scoped_session`.
+
+ """
+
+ _support_async = False
+
+ session_factory = None
+ """The `session_factory` provided to `__init__` is stored in this
+ attribute and may be accessed at a later time. This can be useful when
+ a new non-scoped :class:`.Session` or :class:`_engine.Connection` to the
+ database is needed."""
+
+ def __init__(self, session_factory, scopefunc=None):
+ """Construct a new :class:`.scoped_session`.
+
+ :param session_factory: a factory to create new :class:`.Session`
+ instances. This is usually, but not necessarily, an instance
+ of :class:`.sessionmaker`.
+ :param scopefunc: optional function which defines
+ the current scope. If not passed, the :class:`.scoped_session`
+ object assumes "thread-local" scope, and will use
+ a Python ``threading.local()`` in order to maintain the current
+ :class:`.Session`. If passed, the function should return
+ a hashable token; this token will be used as the key in a
+ dictionary in order to store and retrieve the current
+ :class:`.Session`.
+
+ """
+ self.session_factory = session_factory
+
+ if scopefunc:
+ self.registry = ScopedRegistry(session_factory, scopefunc)
+ else:
+ self.registry = ThreadLocalRegistry(session_factory)
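+
+ # Illustrative sketch (not part of the library): passing a custom
+ # scopefunc, here a hypothetical function that returns the current web
+ # request, so each request receives its own Session:
+ #
+ #   Session = scoped_session(sessionmaker(), scopefunc=get_current_request)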
+
+ def remove(self):
+ """Dispose of the current :class:`.Session`, if present.
+
+ This will first call :meth:`.Session.close` method
+ on the current :class:`.Session`, which releases any existing
+ transactional/connection resources still being held; transactions
+ specifically are rolled back. The :class:`.Session` is then
+ discarded. Upon next usage within the same scope,
+ the :class:`.scoped_session` will produce a new
+ :class:`.Session` object.
+
+ """
+
+ if self.registry.has():
+ self.registry().close()
+ self.registry.clear()
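+
+ # Illustrative sketch (not part of the library): remove() is typically
+ # called when a scope ends, e.g. in a hypothetical web framework's
+ # teardown hook:
+ #
+ #   @app.teardown_appcontext
+ #   def shutdown_session(exception=None):
+ #       Session.remove()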
+
+ def query_property(self, query_cls=None):
+ """return a class property which produces a :class:`_query.Query`
+ object
+ against the class and the current :class:`.Session` when called.
+
+ e.g.::
+
+ Session = scoped_session(sessionmaker())
+
+ class MyClass(object):
+ query = Session.query_property()
+
+ # after mappers are defined
+ result = MyClass.query.filter(MyClass.name=='foo').all()
+
+ Produces instances of the session's configured query class by
+ default. To override and use a custom implementation, provide
+ a ``query_cls`` callable. The callable will be invoked with
+ the class's mapper as a positional argument and a session
+ keyword argument.
+
+ There is no limit to the number of query properties placed on
+ a class.
+
+ """
+
+ class query(object):
+ def __get__(s, instance, owner):
+ try:
+ mapper = class_mapper(owner)
+ if mapper:
+ if query_cls:
+ # custom query class
+ return query_cls(mapper, session=self.registry())
+ else:
+ # session's configured query class
+ return self.registry().query(mapper)
+ except orm_exc.UnmappedClassError:
+ return None
+
+ return query()
+
+
+ScopedSession = scoped_session
+"""Old name for backwards compatibility."""
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
new file mode 100644
index 0000000..c6a9169
--- /dev/null
+++ b/lib/sqlalchemy/orm/session.py
@@ -0,0 +1,4386 @@
+# orm/session.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Provides the Session class and related utilities."""
+
+
+import itertools
+import sys
+import weakref
+
+from . import attributes
+from . import context
+from . import exc
+from . import identity
+from . import loading
+from . import persistence
+from . import query
+from . import state as statelib
+from .base import _class_to_mapper
+from .base import _none_set
+from .base import _state_mapper
+from .base import instance_str
+from .base import object_mapper
+from .base import object_state
+from .base import state_str
+from .unitofwork import UOWTransaction
+from .. import engine
+from .. import exc as sa_exc
+from .. import sql
+from .. import util
+from ..engine.util import TransactionalContext
+from ..inspection import inspect
+from ..sql import coercions
+from ..sql import dml
+from ..sql import roles
+from ..sql import visitors
+from ..sql.base import CompileState
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+
+__all__ = [
+ "Session",
+ "SessionTransaction",
+ "sessionmaker",
+ "ORMExecuteState",
+ "close_all_sessions",
+ "make_transient",
+ "make_transient_to_detached",
+ "object_session",
+]
+
+_sessions = weakref.WeakValueDictionary()
+"""Weak-referencing dictionary of :class:`.Session` objects.
+"""
+
+statelib._sessions = _sessions
+
+
+def _state_session(state):
+ """Given an :class:`.InstanceState`, return the :class:`.Session`
+ associated, if any.
+ """
+ return state.session
+
+
+class _SessionClassMethods(object):
+ """Class-level methods for :class:`.Session`, :class:`.sessionmaker`."""
+
+ @classmethod
+ @util.deprecated(
+ "1.3",
+ "The :meth:`.Session.close_all` method is deprecated and will be "
+ "removed in a future release. Please refer to "
+ ":func:`.session.close_all_sessions`.",
+ )
+ def close_all(cls):
+ """Close *all* sessions in memory."""
+
+ close_all_sessions()
+
+ @classmethod
+ @util.preload_module("sqlalchemy.orm.util")
+ def identity_key(cls, *args, **kwargs):
+ """Return an identity key.
+
+ This is an alias of :func:`.util.identity_key`.
+
+ """
+ return util.preloaded.orm_util.identity_key(*args, **kwargs)
+
+ @classmethod
+ def object_session(cls, instance):
+ """Return the :class:`.Session` to which an object belongs.
+
+ This is an alias of :func:`.object_session`.
+
+ """
+
+ return object_session(instance)
+
+
+ACTIVE = util.symbol("ACTIVE")
+PREPARED = util.symbol("PREPARED")
+COMMITTED = util.symbol("COMMITTED")
+DEACTIVE = util.symbol("DEACTIVE")
+CLOSED = util.symbol("CLOSED")
+
+
+class ORMExecuteState(util.MemoizedSlots):
+ """Represents a call to the :meth:`_orm.Session.execute` method, as passed
+ to the :meth:`.SessionEvents.do_orm_execute` event hook.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`session_execute_events` - top level documentation on how
+ to use :meth:`_orm.SessionEvents.do_orm_execute`
+
+ """
+
+ __slots__ = (
+ "session",
+ "statement",
+ "parameters",
+ "execution_options",
+ "local_execution_options",
+ "bind_arguments",
+ "_compile_state_cls",
+ "_starting_event_idx",
+ "_events_todo",
+ "_update_execution_options",
+ )
+
+ def __init__(
+ self,
+ session,
+ statement,
+ parameters,
+ execution_options,
+ bind_arguments,
+ compile_state_cls,
+ events_todo,
+ ):
+ self.session = session
+ self.statement = statement
+ self.parameters = parameters
+ self.local_execution_options = execution_options
+ self.execution_options = statement._execution_options.union(
+ execution_options
+ )
+ self.bind_arguments = bind_arguments
+ self._compile_state_cls = compile_state_cls
+ self._events_todo = list(events_todo)
+
+ def _remaining_events(self):
+ return self._events_todo[self._starting_event_idx + 1 :]
+
+ def invoke_statement(
+ self,
+ statement=None,
+ params=None,
+ execution_options=None,
+ bind_arguments=None,
+ ):
+ """Execute the statement represented by this
+ :class:`.ORMExecuteState`, without re-invoking events that have
+ already proceeded.
+
+ This method essentially performs a re-entrant execution of the current
+ statement for which the :meth:`.SessionEvents.do_orm_execute` event is
+ being currently invoked. The use case for this is for event handlers
+ that want to override how the ultimate
+ :class:`_engine.Result` object is returned, such as for schemes that
+ retrieve results from an offline cache or which concatenate results
+ from multiple executions.
+
+ When the :class:`_engine.Result` object is returned by the actual
+ handler function within :meth:`_orm.SessionEvents.do_orm_execute` and
+ is propagated to the calling
+ :meth:`_orm.Session.execute` method, the remainder of the
+ :meth:`_orm.Session.execute` method is preempted and the
+ :class:`_engine.Result` object is returned to the caller of
+ :meth:`_orm.Session.execute` immediately.
+
+ :param statement: optional statement to be invoked, in place of the
+ statement currently represented by :attr:`.ORMExecuteState.statement`.
+
+ :param params: optional dictionary of parameters which will be merged
+ into the existing :attr:`.ORMExecuteState.parameters` of this
+ :class:`.ORMExecuteState`.
+
+ :param execution_options: optional dictionary of execution options
+ will be merged into the existing
+ :attr:`.ORMExecuteState.execution_options` of this
+ :class:`.ORMExecuteState`.
+
+ :param bind_arguments: optional dictionary of bind_arguments
+ which will be merged amongst the current
+ :attr:`.ORMExecuteState.bind_arguments`
+ of this :class:`.ORMExecuteState`.
+
+ :return: a :class:`_engine.Result` object with ORM-level results.
+
+ .. seealso::
+
+ :ref:`do_orm_execute_re_executing` - background and examples on the
+ appropriate usage of :meth:`_orm.ORMExecuteState.invoke_statement`.
+
+
+ """
+
+ if statement is None:
+ statement = self.statement
+
+ _bind_arguments = dict(self.bind_arguments)
+ if bind_arguments:
+ _bind_arguments.update(bind_arguments)
+ _bind_arguments["_sa_skip_events"] = True
+
+ if params:
+ _params = dict(self.parameters)
+ _params.update(params)
+ else:
+ _params = self.parameters
+
+ _execution_options = self.local_execution_options
+ if execution_options:
+ _execution_options = _execution_options.union(execution_options)
+
+ return self.session.execute(
+ statement,
+ _params,
+ _execution_options,
+ _bind_arguments,
+ _parent_execute_state=self,
+ )
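+
+ # Illustrative sketch (not part of the library): an event handler that uses
+ # invoke_statement() to re-run the current statement with an added
+ # execution option; listener registration on a hypothetical Session target:
+ #
+ #   from sqlalchemy import event
+ #
+ #   @event.listens_for(Session, "do_orm_execute")
+ #   def _stream_all(orm_execute_state):
+ #       return orm_execute_state.invoke_statement(
+ #           execution_options={"stream_results": True}
+ #       )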
+
+ @property
+ def bind_mapper(self):
+ """Return the :class:`_orm.Mapper` that is the primary "bind" mapper.
+
+ For an :class:`_orm.ORMExecuteState` object invoking an ORM
+ statement, that is, the :attr:`_orm.ORMExecuteState.is_orm_statement`
+ attribute is ``True``, this attribute will return the
+ :class:`_orm.Mapper` that is considered to be the "primary" mapper
+ of the statement. The term "bind mapper" refers to the fact that
+ a :class:`_orm.Session` object may be "bound" to multiple
+ :class:`_engine.Engine` objects keyed to mapped classes, and the
+ "bind mapper" determines which of those :class:`_engine.Engine` objects
+ would be selected.
+
+ For a statement that is invoked against a single mapped class,
+ :attr:`_orm.ORMExecuteState.bind_mapper` is intended to be a reliable
+ way of getting this mapper.
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :attr:`_orm.ORMExecuteState.all_mappers`
+
+
+ """
+ return self.bind_arguments.get("mapper", None)
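+
+ # Illustrative sketch (not part of the library): the "bind mapper" above is
+ # what selects among multiple engines when a Session is bound per-class,
+ # e.g. (hypothetical mapped classes and engines assumed):
+ #
+ #   session = Session(binds={User: engine_one, Account: engine_two})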
+
+ @property
+ def all_mappers(self):
+ """Return a sequence of all :class:`_orm.Mapper` objects that are
+ involved at the top level of this statement.
+
+ By "top level" we mean those :class:`_orm.Mapper` objects that would
+ be represented in the result set rows for a :func:`_sql.select`
+ query, or for a :func:`_dml.update` or :func:`_dml.delete` query,
+ the mapper that is the main subject of the UPDATE or DELETE.
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :attr:`_orm.ORMExecuteState.bind_mapper`
+
+
+
+ """
+ if not self.is_orm_statement:
+ return []
+ elif self.is_select:
+ result = []
+ seen = set()
+ for d in self.statement.column_descriptions:
+ ent = d["entity"]
+ if ent:
+ insp = inspect(ent, raiseerr=False)
+ if insp and insp.mapper and insp.mapper not in seen:
+ seen.add(insp.mapper)
+ result.append(insp.mapper)
+ return result
+ elif self.is_update or self.is_delete:
+ return [self.bind_mapper]
+ else:
+ return []
+
+ @property
+ def is_orm_statement(self):
+ """return True if the operation is an ORM statement.
+
+ This indicates that the select(), update(), or delete() being
+ invoked contains ORM entities as subjects. For a statement
+ that does not have ORM entities and instead refers only to
+ :class:`.Table` metadata, it is invoked as a Core SQL statement
+ and no ORM-level automation takes place.
+
+ """
+ return self._compile_state_cls is not None
+
+ @property
+ def is_select(self):
+ """return True if this is a SELECT operation."""
+ return self.statement.is_select
+
+ @property
+ def is_insert(self):
+ """return True if this is an INSERT operation."""
+ return self.statement.is_dml and self.statement.is_insert
+
+ @property
+ def is_update(self):
+ """return True if this is an UPDATE operation."""
+ return self.statement.is_dml and self.statement.is_update
+
+ @property
+ def is_delete(self):
+ """return True if this is a DELETE operation."""
+ return self.statement.is_dml and self.statement.is_delete
+
+ @property
+ def _is_crud(self):
+ return isinstance(self.statement, (dml.Update, dml.Delete))
+
+ def update_execution_options(self, **opts):
+ # TODO: no coverage
+ self.local_execution_options = self.local_execution_options.union(opts)
+
+ def _orm_compile_options(self):
+ if not self.is_select:
+ return None
+ opts = self.statement._compile_options
+ if opts.isinstance(context.ORMCompileState.default_compile_options):
+ return opts
+ else:
+ return None
+
+ @property
+ def lazy_loaded_from(self):
+ """An :class:`.InstanceState` that is using this statement execution
+ for a lazy load operation.
+
+ The primary rationale for this attribute is to support the horizontal
+ sharding extension, where it is available within specific query
+ execution time hooks created by this extension. To that end, the
+ attribute is only intended to be meaningful at **query execution
+ time**, and importantly not any time prior to that, including query
+ compilation time.
+
+ """
+ return self.load_options._lazy_loaded_from
+
+ @property
+ def loader_strategy_path(self):
+ """Return the :class:`.PathRegistry` for the current load path.
+
+ This object represents the "path" in a query along relationships
+ when a particular object or collection is being loaded.
+
+ """
+ opts = self._orm_compile_options()
+ if opts is not None:
+ return opts._current_path
+ else:
+ return None
+
+ @property
+ def is_column_load(self):
+ """Return True if the operation is refreshing column-oriented
+ attributes on an existing ORM object.
+
+ This occurs during operations such as :meth:`_orm.Session.refresh`,
+ as well as when an attribute deferred by :func:`_orm.defer` is
+ being loaded, or an attribute that was expired either directly
+ by :meth:`_orm.Session.expire` or via a commit operation is being
+ loaded.
+
+ Handlers will very likely not want to add any options to queries
+ when such an operation is occurring as the query should be a straight
+ primary key fetch which should not have any additional WHERE criteria,
+ and loader options travelling with the instance
+ will have already been added to the query.
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :attr:`_orm.ORMExecuteState.is_relationship_load`
+
+ """
+ opts = self._orm_compile_options()
+ return opts is not None and opts._for_refresh_state
+
+ @property
+ def is_relationship_load(self):
+ """Return True if this load is loading objects on behalf of a
+ relationship.
+
+ This means, the loader in effect is either a LazyLoader,
+ SelectInLoader, SubqueryLoader, or similar, and the entire
+ SELECT statement being emitted is on behalf of a relationship
+ load.
+
+ Handlers will very likely not want to add any options to queries
+ when such an operation is occurring, as loader options are already
+ capable of being propagated to relationship loaders and should
+ be already present.
+
+ .. seealso::
+
+ :attr:`_orm.ORMExecuteState.is_column_load`
+
+ """
+ opts = self._orm_compile_options()
+ if opts is None:
+ return False
+ path = self.loader_strategy_path
+ return path is not None and not path.is_root
+
+ @property
+ def load_options(self):
+ """Return the load_options that will be used for this execution."""
+
+ if not self.is_select:
+ raise sa_exc.InvalidRequestError(
+ "This ORM execution is not against a SELECT statement "
+ "so there are no load options."
+ )
+ return self.execution_options.get(
+ "_sa_orm_load_options", context.QueryContext.default_load_options
+ )
+
+ @property
+ def update_delete_options(self):
+ """Return the update_delete_options that will be used for this
+ execution."""
+
+ if not self._is_crud:
+ raise sa_exc.InvalidRequestError(
+ "This ORM execution is not against an UPDATE or DELETE "
+ "statement so there are no update options."
+ )
+ return self.execution_options.get(
+ "_sa_orm_update_options",
+ persistence.BulkUDCompileState.default_update_options,
+ )
+
+ @property
+ def user_defined_options(self):
+ """The sequence of :class:`.UserDefinedOptions` that have been
+ associated with the statement being invoked.
+
+ """
+ return [
+ opt
+ for opt in self.statement._with_options
+ if not opt._is_compile_state and not opt._is_legacy_option
+ ]
+
+
+class SessionTransaction(TransactionalContext):
+ """A :class:`.Session`-level transaction.
+
+ :class:`.SessionTransaction` is produced from the
+ :meth:`_orm.Session.begin`
+ and :meth:`_orm.Session.begin_nested` methods. It's largely an internal
+ object that in modern use provides a context manager for session
+ transactions.
+
+ Documentation on interacting with :class:`_orm.SessionTransaction` is
+ at: :ref:`unitofwork_transaction`.
+
+
+ .. versionchanged:: 1.4 The scoping and API methods to work with the
+ :class:`_orm.SessionTransaction` object directly have been simplified.
+
+ .. seealso::
+
+ :ref:`unitofwork_transaction`
+
+ :meth:`.Session.begin`
+
+ :meth:`.Session.begin_nested`
+
+ :meth:`.Session.rollback`
+
+ :meth:`.Session.commit`
+
+ :meth:`.Session.in_transaction`
+
+ :meth:`.Session.in_nested_transaction`
+
+ :meth:`.Session.get_transaction`
+
+ :meth:`.Session.get_nested_transaction`
+
+
+ """
+
+ _rollback_exception = None
+
+ def __init__(
+ self,
+ session,
+ parent=None,
+ nested=False,
+ autobegin=False,
+ ):
+ TransactionalContext._trans_ctx_check(session)
+
+ self.session = session
+ self._connections = {}
+ self._parent = parent
+ self.nested = nested
+ if nested:
+ self._previous_nested_transaction = session._nested_transaction
+ self._state = ACTIVE
+ if not parent and nested:
+ raise sa_exc.InvalidRequestError(
+ "Can't start a SAVEPOINT transaction when no existing "
+ "transaction is in progress"
+ )
+
+ self._take_snapshot(autobegin=autobegin)
+
+ # make sure transaction is assigned before we call the
+ # dispatch
+ self.session._transaction = self
+
+ self.session.dispatch.after_transaction_create(self.session, self)
+
+ @property
+ def parent(self):
+ """The parent :class:`.SessionTransaction` of this
+ :class:`.SessionTransaction`.
+
+ If this attribute is ``None``, indicates this
+ :class:`.SessionTransaction` is at the top of the stack, and
+ corresponds to a real "COMMIT"/"ROLLBACK"
+ block. If non-``None``, then this is either a "subtransaction"
+ or a "nested" / SAVEPOINT transaction. If the
+ :attr:`.SessionTransaction.nested` attribute is ``True``, then
+ this is a SAVEPOINT, and if ``False``, indicates this is a subtransaction.
+
+ .. versionadded:: 1.0.16 - use ._parent for previous versions
+
+ """
+ return self._parent
+
+ nested = False
+ """Indicates if this is a nested, or SAVEPOINT, transaction.
+
+ When :attr:`.SessionTransaction.nested` is True, it is expected
+ that :attr:`.SessionTransaction.parent` will be present as well.
+
+ """
+
+ @property
+ def is_active(self):
+ return self.session is not None and self._state is ACTIVE
+
+ def _assert_active(
+ self,
+ prepared_ok=False,
+ rollback_ok=False,
+ deactive_ok=False,
+ closed_msg="This transaction is closed",
+ ):
+ if self._state is COMMITTED:
+ raise sa_exc.InvalidRequestError(
+ "This session is in 'committed' state; no further "
+ "SQL can be emitted within this transaction."
+ )
+ elif self._state is PREPARED:
+ if not prepared_ok:
+ raise sa_exc.InvalidRequestError(
+ "This session is in 'prepared' state; no further "
+ "SQL can be emitted within this transaction."
+ )
+ elif self._state is DEACTIVE:
+ if not deactive_ok and not rollback_ok:
+ if self._rollback_exception:
+ raise sa_exc.PendingRollbackError(
+ "This Session's transaction has been rolled back "
+ "due to a previous exception during flush."
+ " To begin a new transaction with this Session, "
+ "first issue Session.rollback()."
+ " Original exception was: %s"
+ % self._rollback_exception,
+ code="7s2a",
+ )
+ elif not deactive_ok:
+ raise sa_exc.InvalidRequestError(
+ "This session is in 'inactive' state, due to the "
+ "SQL transaction being rolled back; no further "
+ "SQL can be emitted within this transaction."
+ )
+ elif self._state is CLOSED:
+ raise sa_exc.ResourceClosedError(closed_msg)
+
+ @property
+ def _is_transaction_boundary(self):
+ return self.nested or not self._parent
+
+ def connection(self, bindkey, execution_options=None, **kwargs):
+ self._assert_active()
+ bind = self.session.get_bind(bindkey, **kwargs)
+ return self._connection_for_bind(bind, execution_options)
+
+ def _begin(self, nested=False):
+ self._assert_active()
+ return SessionTransaction(self.session, self, nested=nested)
+
+ def _iterate_self_and_parents(self, upto=None):
+
+ current = self
+ result = ()
+ while current:
+ result += (current,)
+ if current._parent is upto:
+ break
+ elif current._parent is None:
+ raise sa_exc.InvalidRequestError(
+ "Transaction %s is not on the active transaction list"
+ % (upto)
+ )
+ else:
+ current = current._parent
+
+ return result
+
+ def _take_snapshot(self, autobegin=False):
+ if not self._is_transaction_boundary:
+ self._new = self._parent._new
+ self._deleted = self._parent._deleted
+ self._dirty = self._parent._dirty
+ self._key_switches = self._parent._key_switches
+ return
+
+ if not autobegin and not self.session._flushing:
+ self.session.flush()
+
+ self._new = weakref.WeakKeyDictionary()
+ self._deleted = weakref.WeakKeyDictionary()
+ self._dirty = weakref.WeakKeyDictionary()
+ self._key_switches = weakref.WeakKeyDictionary()
+
+ def _restore_snapshot(self, dirty_only=False):
+ """Restore the restoration state taken before a transaction began.
+
+ Corresponds to a rollback.
+
+ """
+ assert self._is_transaction_boundary
+
+ to_expunge = set(self._new).union(self.session._new)
+ self.session._expunge_states(to_expunge, to_transient=True)
+
+ for s, (oldkey, newkey) in self._key_switches.items():
+ # we probably can do this conditionally based on
+ # if we expunged or not, but safe_discard does that anyway
+ self.session.identity_map.safe_discard(s)
+
+ # restore the old key
+ s.key = oldkey
+
+ # now restore the object, but only if we didn't expunge
+ if s not in to_expunge:
+ self.session.identity_map.replace(s)
+
+ for s in set(self._deleted).union(self.session._deleted):
+ self.session._update_impl(s, revert_deletion=True)
+
+ assert not self.session._deleted
+
+ for s in self.session.identity_map.all_states():
+ if not dirty_only or s.modified or s in self._dirty:
+ s._expire(s.dict, self.session.identity_map._modified)
+
+ def _remove_snapshot(self):
+ """Remove the restoration state taken before a transaction began.
+
+ Corresponds to a commit.
+
+ """
+ assert self._is_transaction_boundary
+
+ if not self.nested and self.session.expire_on_commit:
+ for s in self.session.identity_map.all_states():
+ s._expire(s.dict, self.session.identity_map._modified)
+
+ statelib.InstanceState._detach_states(
+ list(self._deleted), self.session
+ )
+ self._deleted.clear()
+ elif self.nested:
+ self._parent._new.update(self._new)
+ self._parent._dirty.update(self._dirty)
+ self._parent._deleted.update(self._deleted)
+ self._parent._key_switches.update(self._key_switches)
+
+ def _connection_for_bind(self, bind, execution_options):
+ self._assert_active()
+
+ if bind in self._connections:
+ if execution_options:
+ util.warn(
+ "Connection is already established for the "
+ "given bind; execution_options ignored"
+ )
+ return self._connections[bind][0]
+
+ local_connect = False
+ should_commit = True
+
+ if self._parent:
+ conn = self._parent._connection_for_bind(bind, execution_options)
+ if not self.nested:
+ return conn
+ else:
+ if isinstance(bind, engine.Connection):
+ conn = bind
+ if conn.engine in self._connections:
+ raise sa_exc.InvalidRequestError(
+ "Session already has a Connection associated for the "
+ "given Connection's Engine"
+ )
+ else:
+ conn = bind.connect()
+ local_connect = True
+
+ try:
+ if execution_options:
+ conn = conn.execution_options(**execution_options)
+
+ if self.session.twophase and self._parent is None:
+ transaction = conn.begin_twophase()
+ elif self.nested:
+ transaction = conn.begin_nested()
+ elif conn.in_transaction():
+ # if given a future connection already in a transaction, don't
+ # commit that transaction unless it is a savepoint
+ if conn.in_nested_transaction():
+ transaction = conn.get_nested_transaction()
+ else:
+ transaction = conn.get_transaction()
+ should_commit = False
+ else:
+ transaction = conn.begin()
+ except:
+            # connection will not be associated with this Session;
+ # close it immediately so that it isn't closed under GC
+ if local_connect:
+ conn.close()
+ raise
+ else:
+ bind_is_connection = isinstance(bind, engine.Connection)
+
+ self._connections[conn] = self._connections[conn.engine] = (
+ conn,
+ transaction,
+ should_commit,
+ not bind_is_connection,
+ )
+ self.session.dispatch.after_begin(self.session, self, conn)
+ return conn
+
+ def prepare(self):
+ if self._parent is not None or not self.session.twophase:
+ raise sa_exc.InvalidRequestError(
+ "'twophase' mode not enabled, or not root transaction; "
+ "can't prepare."
+ )
+ self._prepare_impl()
+
+ def _prepare_impl(self):
+ self._assert_active()
+ if self._parent is None or self.nested:
+ self.session.dispatch.before_commit(self.session)
+
+ stx = self.session._transaction
+ if stx is not self:
+ for subtransaction in stx._iterate_self_and_parents(upto=self):
+ subtransaction.commit()
+
+ if not self.session._flushing:
+ for _flush_guard in range(100):
+ if self.session._is_clean():
+ break
+ self.session.flush()
+ else:
+ raise exc.FlushError(
+ "Over 100 subsequent flushes have occurred within "
+ "session.commit() - is an after_flush() hook "
+ "creating new objects?"
+ )
+
+ if self._parent is None and self.session.twophase:
+ try:
+ for t in set(self._connections.values()):
+ t[1].prepare()
+ except:
+ with util.safe_reraise():
+ self.rollback()
+
+ self._state = PREPARED
+
+ def commit(self, _to_root=False):
+ self._assert_active(prepared_ok=True)
+ if self._state is not PREPARED:
+ self._prepare_impl()
+
+ if self._parent is None or self.nested:
+ for conn, trans, should_commit, autoclose in set(
+ self._connections.values()
+ ):
+ if should_commit:
+ trans.commit()
+
+ self._state = COMMITTED
+ self.session.dispatch.after_commit(self.session)
+
+ self._remove_snapshot()
+
+ self.close()
+
+ if _to_root and self._parent:
+ return self._parent.commit(_to_root=True)
+
+ return self._parent
+
+ def rollback(self, _capture_exception=False, _to_root=False):
+ self._assert_active(prepared_ok=True, rollback_ok=True)
+
+ stx = self.session._transaction
+ if stx is not self:
+ for subtransaction in stx._iterate_self_and_parents(upto=self):
+ subtransaction.close()
+
+ boundary = self
+ rollback_err = None
+ if self._state in (ACTIVE, PREPARED):
+ for transaction in self._iterate_self_and_parents():
+ if transaction._parent is None or transaction.nested:
+ try:
+ for t in set(transaction._connections.values()):
+ t[1].rollback()
+
+ transaction._state = DEACTIVE
+ self.session.dispatch.after_rollback(self.session)
+ except:
+ rollback_err = sys.exc_info()
+ finally:
+ transaction._state = DEACTIVE
+ transaction._restore_snapshot(
+ dirty_only=transaction.nested
+ )
+ boundary = transaction
+ break
+ else:
+ transaction._state = DEACTIVE
+
+ sess = self.session
+
+ if not rollback_err and not sess._is_clean():
+
+ # if items were added, deleted, or mutated
+ # here, we need to re-restore the snapshot
+ util.warn(
+ "Session's state has been changed on "
+ "a non-active transaction - this state "
+ "will be discarded."
+ )
+ boundary._restore_snapshot(dirty_only=boundary.nested)
+
+ self.close()
+
+ if self._parent and _capture_exception:
+ self._parent._rollback_exception = sys.exc_info()[1]
+
+ if rollback_err:
+ util.raise_(rollback_err[1], with_traceback=rollback_err[2])
+
+ sess.dispatch.after_soft_rollback(sess, self)
+
+ if _to_root and self._parent:
+ return self._parent.rollback(_to_root=True)
+ return self._parent
+
+ def close(self, invalidate=False):
+ if self.nested:
+ self.session._nested_transaction = (
+ self._previous_nested_transaction
+ )
+
+ self.session._transaction = self._parent
+
+ if self._parent is None:
+ for connection, transaction, should_commit, autoclose in set(
+ self._connections.values()
+ ):
+ if invalidate:
+ connection.invalidate()
+ if should_commit and transaction.is_active:
+ transaction.close()
+ if autoclose:
+ connection.close()
+
+ self._state = CLOSED
+ self.session.dispatch.after_transaction_end(self.session, self)
+
+ self.session = None
+ self._connections = None
+
+ def _get_subject(self):
+ return self.session
+
+ def _transaction_is_active(self):
+ return self._state is ACTIVE
+
+ def _transaction_is_closed(self):
+ return self._state is CLOSED
+
+ def _rollback_can_be_called(self):
+ return self._state not in (COMMITTED, CLOSED)
+
+
+class Session(_SessionClassMethods):
+ """Manages persistence operations for ORM-mapped objects.
+
+ The Session's usage paradigm is described at :doc:`/orm/session`.
+
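+    A minimal usage sketch, assuming an :class:`_engine.Engine` named
+    ``engine`` and a hypothetical mapped class ``User``::
+
+        with Session(engine) as session:
+            session.add(User(name="example"))
+            session.commit()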
+
+ """
+
+ _is_asyncio = False
+
+ @util.deprecated_params(
+ autocommit=(
+ "2.0",
+ "The :paramref:`.Session.autocommit` parameter is deprecated "
+ "and will be removed in SQLAlchemy version 2.0. The "
+ ':class:`_orm.Session` now features "autobegin" behavior '
+ "such that the :meth:`.Session.begin` method may be called "
+ "if a transaction has not yet been started yet. See the section "
+ ":ref:`session_explicit_begin` for background.",
+ ),
+ )
+ def __init__(
+ self,
+ bind=None,
+ autoflush=True,
+ future=False,
+ expire_on_commit=True,
+ autocommit=False,
+ twophase=False,
+ binds=None,
+ enable_baked_queries=True,
+ info=None,
+ query_cls=None,
+ ):
+ r"""Construct a new Session.
+
+ See also the :class:`.sessionmaker` function which is used to
+ generate a :class:`.Session`-producing callable with a given
+ set of arguments.
+
+ :param autocommit:
+ Defaults to ``False``. When ``True``, the
+ :class:`.Session` does not automatically begin transactions for
+ individual statement executions, will acquire connections from the
+ engine on an as-needed basis, releasing to the connection pool
+ after each statement. Flushes will begin and commit (or possibly
+            roll back) their own transaction if no transaction is present.
+ When using this mode, the
+ :meth:`.Session.begin` method may be used to explicitly start
+ transactions, but the usual "autobegin" behavior is not present.
+
+ :param autoflush: When ``True``, all query operations will issue a
+ :meth:`~.Session.flush` call to this ``Session`` before proceeding.
+ This is a convenience feature so that :meth:`~.Session.flush` need
+ not be called repeatedly in order for database queries to retrieve
+ results. It's typical that ``autoflush`` is used in conjunction
+ with ``autocommit=False``. In this scenario, explicit calls to
+ :meth:`~.Session.flush` are rarely needed; you usually only need to
+ call :meth:`~.Session.commit` (which flushes) to finalize changes.
+
+ .. seealso::
+
+ :ref:`session_flushing` - additional background on autoflush
+
+ :param bind: An optional :class:`_engine.Engine` or
+ :class:`_engine.Connection` to
+ which this ``Session`` should be bound. When specified, all SQL
+ operations performed by this session will execute via this
+ connectable.
+
+ :param binds: A dictionary which may specify any number of
+ :class:`_engine.Engine` or :class:`_engine.Connection`
+ objects as the source of
+ connectivity for SQL operations on a per-entity basis. The keys
+ of the dictionary consist of any series of mapped classes,
+ arbitrary Python classes that are bases for mapped classes,
+ :class:`_schema.Table` objects and :class:`_orm.Mapper` objects.
+ The
+ values of the dictionary are then instances of
+ :class:`_engine.Engine`
+ or less commonly :class:`_engine.Connection` objects.
+ Operations which
+ proceed relative to a particular mapped class will consult this
+ dictionary for the closest matching entity in order to determine
+ which :class:`_engine.Engine` should be used for a particular SQL
+ operation. The complete heuristics for resolution are
+ described at :meth:`.Session.get_bind`. Usage looks like::
+
+ Session = sessionmaker(binds={
+ SomeMappedClass: create_engine('postgresql://engine1'),
+ SomeDeclarativeBase: create_engine('postgresql://engine2'),
+ some_mapper: create_engine('postgresql://engine3'),
+ some_table: create_engine('postgresql://engine4'),
+ })
+
+ .. seealso::
+
+ :ref:`session_partitioning`
+
+ :meth:`.Session.bind_mapper`
+
+ :meth:`.Session.bind_table`
+
+ :meth:`.Session.get_bind`
+
+
+ :param \class_: Specify an alternate class other than
+ ``sqlalchemy.orm.session.Session`` which should be used by the
+ returned class. This is the only argument that is local to the
+ :class:`.sessionmaker` function, and is not sent directly to the
+ constructor for ``Session``.
+
+ :param enable_baked_queries: defaults to ``True``. A flag consumed
+ by the :mod:`sqlalchemy.ext.baked` extension to determine if
+ "baked queries" should be cached, as is the normal operation
+ of this extension. When set to ``False``, caching as used by
+ this particular extension is disabled.
+
+ .. versionchanged:: 1.4 The ``sqlalchemy.ext.baked`` extension is
+ legacy and is not used by any of SQLAlchemy's internals. This
+ flag therefore only affects applications that are making explicit
+ use of this extension within their own code.
+
+ :param expire_on_commit: Defaults to ``True``. When ``True``, all
+ instances will be fully expired after each :meth:`~.commit`,
+ so that all attribute/object access subsequent to a completed
+ transaction will load from the most recent database state.
+
+ .. seealso::
+
+ :ref:`session_committing`
+
+ :param future: if True, use 2.0 style transactional and engine
+ behavior. Future mode includes the following behaviors:
+
+ * The :class:`_orm.Session` will not use "bound" metadata in order
+ to locate an :class:`_engine.Engine`; the engine or engines in use
+ must be specified to the constructor of :class:`_orm.Session` or
+ otherwise be configured against the :class:`_orm.sessionmaker`
+ in use
+
+ * The "subtransactions" feature of :meth:`_orm.Session.begin` is
+ removed in version 2.0 and is disabled when the future flag is
+ set.
+
+ * The behavior of the :paramref:`_orm.relationship.cascade_backrefs`
+ flag on a :func:`_orm.relationship` will always assume
+ "False" behavior.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`migration_20_toplevel`
+
+ :param info: optional dictionary of arbitrary data to be associated
+ with this :class:`.Session`. Is available via the
+ :attr:`.Session.info` attribute. Note the dictionary is copied at
+ construction time so that modifications to the per-
+ :class:`.Session` dictionary will be local to that
+ :class:`.Session`.
+
+ :param query_cls: Class which should be used to create new Query
+ objects, as returned by the :meth:`~.Session.query` method.
+ Defaults to :class:`_query.Query`.
+
+ :param twophase: When ``True``, all transactions will be started as
+ a "two phase" transaction, i.e. using the "two phase" semantics
+ of the database in use along with an XID. During a
+ :meth:`~.commit`, after :meth:`~.flush` has been issued for all
+ attached databases, the :meth:`~.TwoPhaseTransaction.prepare`
+ method on each database's :class:`.TwoPhaseTransaction` will be
+ called. This allows each database to roll back the entire
+ transaction, before each transaction is committed.
+
+ """
+ self.identity_map = identity.WeakInstanceDict()
+
+ self._new = {} # InstanceState->object, strong refs object
+ self._deleted = {} # same
+ self.bind = bind
+ self.__binds = {}
+ self._flushing = False
+ self._warn_on_events = False
+ self._transaction = None
+ self._nested_transaction = None
+ self.future = future
+ self.hash_key = _new_sessionid()
+ self.autoflush = autoflush
+ self.expire_on_commit = expire_on_commit
+ self.enable_baked_queries = enable_baked_queries
+
+ if autocommit:
+ if future:
+ raise sa_exc.ArgumentError(
+ "Cannot use autocommit mode with future=True."
+ )
+ self.autocommit = True
+ else:
+ self.autocommit = False
+
+ self.twophase = twophase
+ self._query_cls = query_cls if query_cls else query.Query
+ if info:
+ self.info.update(info)
+
+ if binds is not None:
+ for key, bind in binds.items():
+ self._add_bind(key, bind)
+
+ _sessions[self.hash_key] = self
+
+ # used by sqlalchemy.engine.util.TransactionalContext
+ _trans_context_manager = None
+
+ connection_callable = None
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ self.close()
+
+ @util.contextmanager
+ def _maker_context_manager(self):
+ with self:
+ with self.begin():
+ yield self
+
+ @property
+ @util.deprecated_20(
+ ":attr:`_orm.Session.transaction`",
+ alternative="For context manager use, use "
+ ":meth:`_orm.Session.begin`. To access "
+ "the current root transaction, use "
+ ":meth:`_orm.Session.get_transaction`.",
+ warn_on_attribute_access=True,
+ )
+ def transaction(self):
+ """The current active or inactive :class:`.SessionTransaction`.
+
+ May be None if no transaction has begun yet.
+
+ .. versionchanged:: 1.4 the :attr:`.Session.transaction` attribute
+ is now a read-only descriptor that also may return None if no
+ transaction has begun yet.
+
+
+ """
+ return self._legacy_transaction()
+
+ def _legacy_transaction(self):
+ if not self.future:
+ self._autobegin()
+ return self._transaction
+
+ def in_transaction(self):
+ """Return True if this :class:`_orm.Session` has begun a transaction.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_orm.Session.is_active`
+
+
+ """
+ return self._transaction is not None
+
+ def in_nested_transaction(self):
+ """Return True if this :class:`_orm.Session` has begun a nested
+ transaction, e.g. SAVEPOINT.
+
+ .. versionadded:: 1.4
+
+ """
+ return self._nested_transaction is not None
+
+ def get_transaction(self):
+ """Return the current root transaction in progress, if any.
+
+ .. versionadded:: 1.4
+
+ """
+ trans = self._transaction
+ while trans is not None and trans._parent is not None:
+ trans = trans._parent
+ return trans
+
+ def get_nested_transaction(self):
+ """Return the current nested transaction in progress, if any.
+
+ .. versionadded:: 1.4
+
+ """
+
+ return self._nested_transaction
+
+ @util.memoized_property
+ def info(self):
+ """A user-modifiable dictionary.
+
+ The initial value of this dictionary can be populated using the
+ ``info`` argument to the :class:`.Session` constructor or
+ :class:`.sessionmaker` constructor or factory methods. The dictionary
+ here is always local to this :class:`.Session` and can be modified
+ independently of all other :class:`.Session` objects.
+
+ """
+ return {}
+
+ def _autobegin(self):
+ if not self.autocommit and self._transaction is None:
+
+ trans = SessionTransaction(self, autobegin=True)
+ assert self._transaction is trans
+ return True
+
+ return False
+
+ @util.deprecated_params(
+ subtransactions=(
+ "2.0",
+ "The :paramref:`_orm.Session.begin.subtransactions` flag is "
+ "deprecated and "
+ "will be removed in SQLAlchemy version 2.0. See "
+ "the documentation at :ref:`session_subtransactions` for "
+ "background on a compatible alternative pattern.",
+ )
+ )
+ def begin(self, subtransactions=False, nested=False, _subtrans=False):
+ """Begin a transaction, or nested transaction,
+ on this :class:`.Session`, if one is not already begun.
+
+ The :class:`_orm.Session` object features **autobegin** behavior,
+ so that normally it is not necessary to call the
+ :meth:`_orm.Session.begin`
+ method explicitly. However, it may be used in order to control
+ the scope of when the transactional state is begun.
+
+ When used to begin the outermost transaction, an error is raised
+ if this :class:`.Session` is already inside of a transaction.
+
+ :param nested: if True, begins a SAVEPOINT transaction and is
+ equivalent to calling :meth:`~.Session.begin_nested`. For
+ documentation on SAVEPOINT transactions, please see
+ :ref:`session_begin_nested`.
+
+ :param subtransactions: if True, indicates that this
+ :meth:`~.Session.begin` can create a "subtransaction".
+
+ :return: the :class:`.SessionTransaction` object. Note that
+ :class:`.SessionTransaction`
+ acts as a Python context manager, allowing :meth:`.Session.begin`
+ to be used in a "with" block. See :ref:`session_autocommit` for
+ an example.
+
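+        A brief sketch of context manager use, where ``User`` is a
+        hypothetical mapped class::
+
+            with session.begin():
+                session.add(User(name="some user"))
+            # commits on successful exit, rolls back on exception
+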
+ .. seealso::
+
+ :ref:`session_autobegin`
+
+ :ref:`unitofwork_transaction`
+
+ :meth:`.Session.begin_nested`
+
+
+ """
+
+ if subtransactions and self.future:
+ raise NotImplementedError(
+ "subtransactions are not implemented in future "
+ "Session objects."
+ )
+
+ if self._autobegin():
+ if not subtransactions and not nested and not _subtrans:
+ return self._transaction
+
+ if self._transaction is not None:
+ if subtransactions or _subtrans or nested:
+ trans = self._transaction._begin(nested=nested)
+ assert self._transaction is trans
+ if nested:
+ self._nested_transaction = trans
+ else:
+ raise sa_exc.InvalidRequestError(
+ "A transaction is already begun on this Session."
+ )
+ elif not self.autocommit:
+            # outermost transaction.  must not be nested and not
+            # a subtransaction
+
+ assert not nested and not _subtrans and not subtransactions
+ trans = SessionTransaction(self)
+ assert self._transaction is trans
+ else:
+ # legacy autocommit mode
+ assert not self.future
+ trans = SessionTransaction(self, nested=nested)
+ assert self._transaction is trans
+
+ return self._transaction # needed for __enter__/__exit__ hook
+
+ def begin_nested(self):
+ """Begin a "nested" transaction on this Session, e.g. SAVEPOINT.
+
+ The target database(s) and associated drivers must support SQL
+ SAVEPOINT for this method to function correctly.
+
+ For documentation on SAVEPOINT
+ transactions, please see :ref:`session_begin_nested`.
+
+ :return: the :class:`.SessionTransaction` object. Note that
+ :class:`.SessionTransaction` acts as a context manager, allowing
+ :meth:`.Session.begin_nested` to be used in a "with" block.
+ See :ref:`session_begin_nested` for a usage example.
+
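+        A brief sketch of the SAVEPOINT pattern, where ``User`` is a
+        hypothetical mapped class::
+
+            session.add(User(name="u1"))
+            with session.begin_nested():
+                session.add(User(name="u2"))
+            # on success the SAVEPOINT is released; on error it is
+            # rolled back and the exception propagates
+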
+ .. seealso::
+
+ :ref:`session_begin_nested`
+
+ :ref:`pysqlite_serializable` - special workarounds required
+ with the SQLite driver in order for SAVEPOINT to work
+ correctly.
+
+ """
+ return self.begin(nested=True)
+
+ def rollback(self):
+ """Rollback the current transaction in progress.
+
+ If no transaction is in progress, this method is a pass-through.
+
+        In :term:`1.x-style` use, this method rolls back the topmost
+        database transaction if no nested transactions are in effect, or
+        rolls back to the current nested transaction if one is in effect.
+
+ When
+ :term:`2.0-style` use is in effect via the
+ :paramref:`_orm.Session.future` flag, the method always rolls back
+ the topmost database transaction, discarding any nested
+ transactions that may be in progress.
+
+ .. seealso::
+
+ :ref:`session_rollback`
+
+ :ref:`unitofwork_transaction`
+
+ """
+ if self._transaction is None:
+ pass
+ else:
+ self._transaction.rollback(_to_root=self.future)
+
+ def commit(self):
+ """Flush pending changes and commit the current transaction.
+
+ When the COMMIT operation is complete, all objects are fully
+ :term:`expired`, erasing their internal contents, which will be
+ automatically re-loaded when the objects are next accessed. In the
+ interim, these objects are in an expired state and will not function if
+ they are :term:`detached` from the :class:`.Session`. Additionally,
+ this re-load operation is not supported when using asyncio-oriented
+ APIs. The :paramref:`.Session.expire_on_commit` parameter may be used
+ to disable this behavior.
+
+ When there is no transaction in place for the :class:`.Session`,
+ indicating that no operations were invoked on this :class:`.Session`
+ since the previous call to :meth:`.Session.commit`, the method will
+        begin and commit an internal-only "logical" transaction that does not
+ normally affect the database unless pending flush changes were
+ detected, but will still invoke event handlers and object expiration
+ rules.
+
+ If :term:`1.x-style` use is in effect and there are currently
+ SAVEPOINTs in progress via :meth:`_orm.Session.begin_nested`,
+ the operation will release the current SAVEPOINT but not commit
+ the outermost database transaction.
+
+ If :term:`2.0-style` use is in effect via the
+ :paramref:`_orm.Session.future` flag, the outermost database
+ transaction is committed unconditionally, automatically releasing any
+ SAVEPOINTs in effect.
+
+ When using legacy "autocommit" mode, this method is only
+ valid to call if a transaction is actually in progress, else
+ an error is raised. Similarly, when using legacy "subtransactions",
+ the method will instead close out the current "subtransaction",
+ rather than the actual database transaction, if a transaction
+ is in progress.
+
+ .. seealso::
+
+ :ref:`session_committing`
+
+ :ref:`unitofwork_transaction`
+
+ :ref:`asyncio_orm_avoid_lazyloads`
+
+ """
+ if self._transaction is None:
+ if not self._autobegin():
+ raise sa_exc.InvalidRequestError("No transaction is begun.")
+
+ self._transaction.commit(_to_root=self.future)
+
+ def prepare(self):
+ """Prepare the current transaction in progress for two phase commit.
+
+ If no transaction is in progress, this method raises an
+ :exc:`~sqlalchemy.exc.InvalidRequestError`.
+
+ Only root transactions of two phase sessions can be prepared. If the
+ current transaction is not such, an
+ :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
+
+ """
+ if self._transaction is None:
+ if not self._autobegin():
+ raise sa_exc.InvalidRequestError("No transaction is begun.")
+
+ self._transaction.prepare()
+
+ def connection(
+ self,
+ bind_arguments=None,
+ close_with_result=False,
+ execution_options=None,
+ **kw
+ ):
+ r"""Return a :class:`_engine.Connection` object corresponding to this
+ :class:`.Session` object's transactional state.
+
+ If this :class:`.Session` is configured with ``autocommit=False``,
+ either the :class:`_engine.Connection` corresponding to the current
+ transaction is returned, or if no transaction is in progress, a new
+ one is begun and the :class:`_engine.Connection`
+ returned (note that no
+ transactional state is established with the DBAPI until the first
+ SQL statement is emitted).
+
+ Alternatively, if this :class:`.Session` is configured with
+ ``autocommit=True``, an ad-hoc :class:`_engine.Connection` is returned
+ using :meth:`_engine.Engine.connect` on the underlying
+ :class:`_engine.Engine`.
+
+ Ambiguity in multi-bind or unbound :class:`.Session` objects can be
+ resolved through any of the optional keyword arguments. This
+        ultimately makes use of the :meth:`.get_bind` method for resolution.
+
+ :param bind_arguments: dictionary of bind arguments. May include
+ "mapper", "bind", "clause", other custom arguments that are passed
+ to :meth:`.Session.get_bind`.
+
+ :param bind:
+ deprecated; use bind_arguments
+
+ :param mapper:
+ deprecated; use bind_arguments
+
+ :param clause:
+ deprecated; use bind_arguments
+
+ :param close_with_result: Passed to :meth:`_engine.Engine.connect`,
+ indicating the :class:`_engine.Connection` should be considered
+ "single use", automatically closing when the first result set is
+ closed. This flag only has an effect if this :class:`.Session` is
+ configured with ``autocommit=True`` and does not already have a
+ transaction in progress.
+
+ .. deprecated:: 1.4 this parameter is deprecated and will be removed
+ in SQLAlchemy 2.0
+
+ :param execution_options: a dictionary of execution options that will
+ be passed to :meth:`_engine.Connection.execution_options`, **when the
+ connection is first procured only**. If the connection is already
+ present within the :class:`.Session`, a warning is emitted and
+ the arguments are ignored.
+
+ .. seealso::
+
+ :ref:`session_transaction_isolation`
+
+ :param \**kw:
+ deprecated; use bind_arguments
+
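+        E.g., a brief sketch of applying isolation options as the
+        connection is first procured (the value shown is illustrative)::
+
+            conn = session.connection(
+                execution_options={"isolation_level": "SERIALIZABLE"}
+            )
+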
+ """
+
+ if not bind_arguments:
+ bind_arguments = kw
+
+ bind = bind_arguments.pop("bind", None)
+ if bind is None:
+ bind = self.get_bind(**bind_arguments)
+
+ return self._connection_for_bind(
+ bind,
+ close_with_result=close_with_result,
+ execution_options=execution_options,
+ )
+
+ def _connection_for_bind(self, engine, execution_options=None, **kw):
+ TransactionalContext._trans_ctx_check(self)
+
+ if self._transaction is not None or self._autobegin():
+ return self._transaction._connection_for_bind(
+ engine, execution_options
+ )
+
+ assert self._transaction is None
+ assert self.autocommit
+ conn = engine.connect(**kw)
+ if execution_options:
+ conn = conn.execution_options(**execution_options)
+ return conn
+
+ def execute(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ _parent_execute_state=None,
+ _add_event=None,
+ **kw
+ ):
+ r"""Execute a SQL expression construct.
+
+ Returns a :class:`_engine.Result` object representing
+ results of the statement execution.
+
+ E.g.::
+
+ from sqlalchemy import select
+ result = session.execute(
+ select(User).where(User.id == 5)
+ )
+
+ The API contract of :meth:`_orm.Session.execute` is similar to that
+ of :meth:`_future.Connection.execute`, the :term:`2.0 style` version
+ of :class:`_future.Connection`.
+
+ .. versionchanged:: 1.4 the :meth:`_orm.Session.execute` method is
+ now the primary point of ORM statement execution when using
+ :term:`2.0 style` ORM usage.
+
+ :param statement:
+ An executable statement (i.e. an :class:`.Executable` expression
+ such as :func:`_expression.select`).
+
+ :param params:
+ Optional dictionary, or list of dictionaries, containing
+ bound parameter values. If a single dictionary, single-row
+ execution occurs; if a list of dictionaries, an
+ "executemany" will be invoked. The keys in each dictionary
+ must correspond to parameter names present in the statement.
+
+ :param execution_options: optional dictionary of execution options,
+ which will be associated with the statement execution. This
+ dictionary can provide a subset of the options that are accepted
+ by :meth:`_engine.Connection.execution_options`, and may also
+ provide additional options understood only in an ORM context.
+
+ :param bind_arguments: dictionary of additional arguments to determine
+ the bind. May include "mapper", "bind", or other custom arguments.
+ Contents of this dictionary are passed to the
+ :meth:`.Session.get_bind` method.
+
+ :param mapper:
+ deprecated; use the bind_arguments dictionary
+
+ :param bind:
+ deprecated; use the bind_arguments dictionary
+
+ :param \**kw:
+ deprecated; use the bind_arguments dictionary
+
+ :return: a :class:`_engine.Result` object.
+
+
+ """
+ statement = coercions.expect(roles.StatementRole, statement)
+
+ if kw:
+ util.warn_deprecated_20(
+ "Passing bind arguments to Session.execute() as keyword "
+ "arguments is deprecated and will be removed SQLAlchemy 2.0. "
+ "Please use the bind_arguments parameter."
+ )
+ if not bind_arguments:
+ bind_arguments = kw
+ else:
+ bind_arguments.update(kw)
+ elif not bind_arguments:
+ bind_arguments = {}
+
+ if (
+ statement._propagate_attrs.get("compile_state_plugin", None)
+ == "orm"
+ ):
+            # note that even without "future" mode, we need the ORM
+            # compile state plugin here
+ compile_state_cls = CompileState._get_plugin_class_for_plugin(
+ statement, "orm"
+ )
+ else:
+ compile_state_cls = None
+
+ execution_options = util.coerce_to_immutabledict(execution_options)
+
+ if compile_state_cls is not None:
+ (
+ statement,
+ execution_options,
+ ) = compile_state_cls.orm_pre_session_exec(
+ self,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ _parent_execute_state is not None,
+ )
+ else:
+ bind_arguments.setdefault("clause", statement)
+ execution_options = execution_options.union(
+ {"future_result": True}
+ )
+
+ if _parent_execute_state:
+ events_todo = _parent_execute_state._remaining_events()
+ else:
+ events_todo = self.dispatch.do_orm_execute
+ if _add_event:
+ events_todo = list(events_todo) + [_add_event]
+
+ if events_todo:
+ orm_exec_state = ORMExecuteState(
+ self,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ compile_state_cls,
+ events_todo,
+ )
+ for idx, fn in enumerate(events_todo):
+ orm_exec_state._starting_event_idx = idx
+ result = fn(orm_exec_state)
+ if result:
+ return result
+
+ statement = orm_exec_state.statement
+ execution_options = orm_exec_state.local_execution_options
+
+ bind = self.get_bind(**bind_arguments)
+
+ if self.autocommit:
+ # legacy stuff, we can't use future_result w/ autocommit because
+ # we rely upon close_with_result, also legacy. it's all
+ # interrelated
+ conn = self._connection_for_bind(bind, close_with_result=True)
+ execution_options = execution_options.union(
+ dict(future_result=False)
+ )
+ else:
+ conn = self._connection_for_bind(bind)
+ result = conn._execute_20(statement, params or {}, execution_options)
+
+ if compile_state_cls:
+ result = compile_state_cls.orm_setup_cursor_result(
+ self,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
+
+ return result
+
+ def scalar(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return a scalar result.
+
+ Usage and parameters are the same as that of
+ :meth:`_orm.Session.execute`; the return result is a scalar Python
+ value.
+
+ """
+
+ return self.execute(
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ ).scalar()
+
+ def scalars(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return the results as scalars.
+
+ Usage and parameters are the same as that of
+ :meth:`_orm.Session.execute`; the return result is a
+ :class:`_result.ScalarResult` filtering object which
+ will return single elements rather than :class:`_row.Row` objects.
+
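+        E.g., a brief sketch, where ``User`` is a hypothetical mapped
+        class::
+
+            from sqlalchemy import select
+
+            names = session.scalars(select(User.name)).all()
+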
+ :return: a :class:`_result.ScalarResult` object
+
+ .. versionadded:: 1.4.24
+
+ """
+
+ return self.execute(
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ ).scalars()
+
+ def close(self):
+ """Close out the transactional resources and ORM objects used by this
+ :class:`_orm.Session`.
+
+ This expunges all ORM objects associated with this
+ :class:`_orm.Session`, ends any transaction in progress and
+ :term:`releases` any :class:`_engine.Connection` objects which this
+ :class:`_orm.Session` itself has checked out from associated
+ :class:`_engine.Engine` objects. The operation then leaves the
+        :class:`_orm.Session` in a state in which it may be used again.
+
+ .. tip::
+
+ The :meth:`_orm.Session.close` method **does not prevent the
+ Session from being used again**. The :class:`_orm.Session` itself
+ does not actually have a distinct "closed" state; it merely means
+ the :class:`_orm.Session` will release all database connections
+ and ORM objects.
+
+ .. versionchanged:: 1.4 The :meth:`.Session.close` method does not
+ immediately create a new :class:`.SessionTransaction` object;
+ instead, the new :class:`.SessionTransaction` is created only if
+ the :class:`.Session` is used again for a database operation.
+
+ .. seealso::
+
+ :ref:`session_closing` - detail on the semantics of
+ :meth:`_orm.Session.close`
+
+ """
+ self._close_impl(invalidate=False)
+
+ def invalidate(self):
+ """Close this Session, using connection invalidation.
+
+ This is a variant of :meth:`.Session.close` that will additionally
+ ensure that the :meth:`_engine.Connection.invalidate`
+ method will be called on each :class:`_engine.Connection` object
+ that is currently in use for a transaction (typically there is only
+ one connection unless the :class:`_orm.Session` is used with
+ multiple engines).
+
+ This can be called when the database is known to be in a state where
+ the connections are no longer safe to be used.
+
+        The example below illustrates a scenario when using `gevent
+ <https://www.gevent.org/>`_, which can produce ``Timeout`` exceptions
+ that may mean the underlying connection should be discarded::
+
+ import gevent
+
+ try:
+ sess = Session()
+ sess.add(User())
+ sess.commit()
+ except gevent.Timeout:
+ sess.invalidate()
+ raise
+ except:
+ sess.rollback()
+ raise
+
+ The method additionally does everything that :meth:`_orm.Session.close`
+ does, including that all ORM objects are expunged.
+
+ """
+ self._close_impl(invalidate=True)
+
+ def _close_impl(self, invalidate):
+ self.expunge_all()
+ if self._transaction is not None:
+ for transaction in self._transaction._iterate_self_and_parents():
+ transaction.close(invalidate)
+
+ def expunge_all(self):
+ """Remove all object instances from this ``Session``.
+
+ This is equivalent to calling ``expunge(obj)`` on all objects in this
+ ``Session``.
+
+ """
+
+ all_states = self.identity_map.all_states() + list(self._new)
+ self.identity_map._kill()
+ self.identity_map = identity.WeakInstanceDict()
+ self._new = {}
+ self._deleted = {}
+
+ statelib.InstanceState._detach_states(all_states, self)
+
+ def _add_bind(self, key, bind):
+ try:
+ insp = inspect(key)
+ except sa_exc.NoInspectionAvailable as err:
+ if not isinstance(key, type):
+ util.raise_(
+ sa_exc.ArgumentError(
+ "Not an acceptable bind target: %s" % key
+ ),
+ replace_context=err,
+ )
+ else:
+ self.__binds[key] = bind
+ else:
+ if insp.is_selectable:
+ self.__binds[insp] = bind
+ elif insp.is_mapper:
+ self.__binds[insp.class_] = bind
+ for _selectable in insp._all_tables:
+ self.__binds[_selectable] = bind
+ else:
+ raise sa_exc.ArgumentError(
+ "Not an acceptable bind target: %s" % key
+ )
+
+ def bind_mapper(self, mapper, bind):
+ """Associate a :class:`_orm.Mapper` or arbitrary Python class with a
+ "bind", e.g. an :class:`_engine.Engine` or
+ :class:`_engine.Connection`.
+
+ The given entity is added to a lookup used by the
+ :meth:`.Session.get_bind` method.
+
+ :param mapper: a :class:`_orm.Mapper` object,
+ or an instance of a mapped
+ class, or any Python class that is the base of a set of mapped
+ classes.
+
+ :param bind: an :class:`_engine.Engine` or :class:`_engine.Connection`
+ object.
+
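+        E.g., a brief sketch, where ``SomeClass`` is a hypothetical mapped
+        class and ``engine_two`` a second :class:`_engine.Engine`::
+
+            session.bind_mapper(SomeClass, engine_two)
+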
+ .. seealso::
+
+ :ref:`session_partitioning`
+
+ :paramref:`.Session.binds`
+
+ :meth:`.Session.bind_table`
+
+
+ """
+ self._add_bind(mapper, bind)
+
+ def bind_table(self, table, bind):
+ """Associate a :class:`_schema.Table` with a "bind", e.g. an
+ :class:`_engine.Engine`
+ or :class:`_engine.Connection`.
+
+ The given :class:`_schema.Table` is added to a lookup used by the
+ :meth:`.Session.get_bind` method.
+
+ :param table: a :class:`_schema.Table` object,
+ which is typically the target
+ of an ORM mapping, or is present within a selectable that is
+ mapped.
+
+ :param bind: an :class:`_engine.Engine` or :class:`_engine.Connection`
+ object.
+
+ .. seealso::
+
+ :ref:`session_partitioning`
+
+ :paramref:`.Session.binds`
+
+ :meth:`.Session.bind_mapper`
+
+
+ """
+ self._add_bind(table, bind)
+
+ def get_bind(
+ self,
+ mapper=None,
+ clause=None,
+ bind=None,
+ _sa_skip_events=None,
+ _sa_skip_for_implicit_returning=False,
+ ):
+ """Return a "bind" to which this :class:`.Session` is bound.
+
+ The "bind" is usually an instance of :class:`_engine.Engine`,
+ except in the case where the :class:`.Session` has been
+ explicitly bound directly to a :class:`_engine.Connection`.
+
+ For a multiply-bound or unbound :class:`.Session`, the
+ ``mapper`` or ``clause`` arguments are used to determine the
+ appropriate bind to return.
+
+ Note that the "mapper" argument is usually present
+ when :meth:`.Session.get_bind` is called via an ORM
+ operation such as a :meth:`.Session.query`, each
+ individual INSERT/UPDATE/DELETE operation within a
+        :meth:`.Session.flush` call, etc.
+
+ The order of resolution is:
+
+ 1. if mapper given and :paramref:`.Session.binds` is present,
+ locate a bind based first on the mapper in use, then
+ on the mapped class in use, then on any base classes that are
+ present in the ``__mro__`` of the mapped class, from more specific
+ superclasses to more general.
+ 2. if clause given and ``Session.binds`` is present,
+ locate a bind based on :class:`_schema.Table` objects
+ found in the given clause present in ``Session.binds``.
+ 3. if ``Session.binds`` is present, return that.
+ 4. if clause given, attempt to return a bind
+ linked to the :class:`_schema.MetaData` ultimately
+ associated with the clause.
+ 5. if mapper given, attempt to return a bind
+ linked to the :class:`_schema.MetaData` ultimately
+ associated with the :class:`_schema.Table` or other
+ selectable to which the mapper is mapped.
+ 6. No bind can be found, :exc:`~sqlalchemy.exc.UnboundExecutionError`
+ is raised.
+
+ Note that the :meth:`.Session.get_bind` method can be overridden on
+ a user-defined subclass of :class:`.Session` to provide any kind
+ of bind resolution scheme. See the example at
+ :ref:`session_custom_partitioning`.
+
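+        As a brief sketch of that subclassing hook, where ``engines`` is
+        a hypothetical dictionary of :class:`_engine.Engine` objects::
+
+            class RoutingSession(Session):
+                def get_bind(self, mapper=None, clause=None, **kw):
+                    if self._flushing:
+                        return engines["leader"]
+                    else:
+                        return engines["follower"]
+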
+ :param mapper:
+ Optional :func:`.mapper` mapped class or instance of
+ :class:`_orm.Mapper`. The bind can be derived from a
+ :class:`_orm.Mapper`
+ first by consulting the "binds" map associated with this
+ :class:`.Session`, and secondly by consulting the
+ :class:`_schema.MetaData`
+ associated with the :class:`_schema.Table` to which the
+ :class:`_orm.Mapper`
+ is mapped for a bind.
+
+ :param clause:
+ A :class:`_expression.ClauseElement` (i.e.
+ :func:`_expression.select`,
+ :func:`_expression.text`,
+ etc.). If the ``mapper`` argument is not present or could not
+ produce a bind, the given expression construct will be searched
+ for a bound element, typically a :class:`_schema.Table`
+ associated with
+ bound :class:`_schema.MetaData`.
+
+ .. seealso::
+
+ :ref:`session_partitioning`
+
+ :paramref:`.Session.binds`
+
+ :meth:`.Session.bind_mapper`
+
+ :meth:`.Session.bind_table`
+
+ """
+
+ # this function is documented as a subclassing hook, so we have
+ # to call this method even if the return is simple
+ if bind:
+ return bind
+ elif not self.__binds and self.bind:
+ # simplest and most common case, we have a bind and no
+ # per-mapper/table binds, we're done
+ return self.bind
+
+ # we don't have self.bind and either have self.__binds
+ # or we don't have self.__binds (which is legacy). Look at the
+ # mapper and the clause
+ if mapper is clause is None:
+ if self.bind:
+ return self.bind
+ else:
+ raise sa_exc.UnboundExecutionError(
+ "This session is not bound to a single Engine or "
+ "Connection, and no context was provided to locate "
+ "a binding."
+ )
+
+ # look more closely at the mapper.
+ if mapper is not None:
+ try:
+ mapper = inspect(mapper)
+ except sa_exc.NoInspectionAvailable as err:
+ if isinstance(mapper, type):
+ util.raise_(
+ exc.UnmappedClassError(mapper),
+ replace_context=err,
+ )
+ else:
+ raise
+
+ # match up the mapper or clause in the __binds
+ if self.__binds:
+ # matching mappers and selectables to entries in the
+ # binds dictionary; supported use case.
+ if mapper:
+ for cls in mapper.class_.__mro__:
+ if cls in self.__binds:
+ return self.__binds[cls]
+ if clause is None:
+ clause = mapper.persist_selectable
+
+ if clause is not None:
+ plugin_subject = clause._propagate_attrs.get(
+ "plugin_subject", None
+ )
+
+ if plugin_subject is not None:
+ for cls in plugin_subject.mapper.class_.__mro__:
+ if cls in self.__binds:
+ return self.__binds[cls]
+
+ for obj in visitors.iterate(clause):
+ if obj in self.__binds:
+ return self.__binds[obj]
+
+ # none of the __binds matched, but we have a fallback bind.
+ # return that
+ if self.bind:
+ return self.bind
+
+ # now we are in legacy territory. looking for "bind" on tables
+ # that are via bound metadata. this goes away in 2.0.
+
+ future_msg = ""
+ future_code = ""
+
+ if mapper and clause is None:
+ clause = mapper.persist_selectable
+
+ if clause is not None:
+ if clause.bind:
+ if self.future:
+ future_msg = (
+ " A bind was located via legacy bound metadata, but "
+ "since future=True is set on this Session, this "
+ "bind is ignored."
+ )
+ else:
+ util.warn_deprecated_20(
+ "This Session located a target engine via bound "
+ "metadata; as this functionality will be removed in "
+ "SQLAlchemy 2.0, an Engine object should be passed "
+ "to the Session() constructor directly."
+ )
+ return clause.bind
+
+ if mapper:
+ if mapper.persist_selectable.bind:
+ if self.future:
+ future_msg = (
+ " A bind was located via legacy bound metadata, but "
+ "since future=True is set on this Session, this "
+ "bind is ignored."
+ )
+ else:
+ util.warn_deprecated_20(
+ "This Session located a target engine via bound "
+ "metadata; as this functionality will be removed in "
+ "SQLAlchemy 2.0, an Engine object should be passed "
+ "to the Session() constructor directly."
+ )
+ return mapper.persist_selectable.bind
+
+ context = []
+ if mapper is not None:
+ context.append("mapper %s" % mapper)
+ if clause is not None:
+ context.append("SQL expression")
+
+ raise sa_exc.UnboundExecutionError(
+ "Could not locate a bind configured on %s or this Session.%s"
+ % (", ".join(context), future_msg),
+ code=future_code,
+ )
+
+ def query(self, *entities, **kwargs):
+ """Return a new :class:`_query.Query` object corresponding to this
+ :class:`_orm.Session`.
+
+ """
+
+ return self._query_cls(entities, self, **kwargs)
+
+ def _identity_lookup(
+ self,
+ mapper,
+ primary_key_identity,
+ identity_token=None,
+ passive=attributes.PASSIVE_OFF,
+ lazy_loaded_from=None,
+ ):
+ """Locate an object in the identity map.
+
+ Given a primary key identity, constructs an identity key and then
+ looks in the session's identity map. If present, the object may
+ be run through unexpiration rules (e.g. load unloaded attributes,
+ check if was deleted).
+
+ e.g.::
+
+ obj = session._identity_lookup(inspect(SomeClass), (1, ))
+
+ :param mapper: mapper in use
+ :param primary_key_identity: the primary key we are searching for, as
+ a tuple.
+ :param identity_token: identity token that should be used to create
+ the identity key. Used as is, however overriding subclasses can
+ repurpose this in order to interpret the value in a special way,
+ such as if None then look among multiple target tokens.
+ :param passive: passive load flag passed to
+ :func:`.loading.get_from_identity`, which impacts the behavior if
+ the object is found; the object may be validated and/or unexpired
+ if the flag allows for SQL to be emitted.
+ :param lazy_loaded_from: an :class:`.InstanceState` that is
+ specifically asking for this identity as a related identity. Used
+ for sharding schemes where there is a correspondence between an object
+ and a related object being lazy-loaded (or otherwise
+ relationship-loaded).
+
+ :return: None if the object is not found in the identity map, *or*
+ if the object was unexpired and found to have been deleted.
+          If passive flags disallow SQL and the object is expired, returns
+ PASSIVE_NO_RESULT. In all other cases the instance is returned.
+
+ .. versionchanged:: 1.4.0 - the :meth:`.Session._identity_lookup`
+ method was moved from :class:`_query.Query` to
+ :class:`.Session`, to avoid having to instantiate the
+ :class:`_query.Query` object.
+
+
+ """
+
+ key = mapper.identity_key_from_primary_key(
+ primary_key_identity, identity_token=identity_token
+ )
+ return loading.get_from_identity(self, mapper, key, passive)
+
+ @property
+ @util.contextmanager
+ def no_autoflush(self):
+ """Return a context manager that disables autoflush.
+
+ e.g.::
+
+ with session.no_autoflush:
+
+ some_object = SomeClass()
+ session.add(some_object)
+ # won't autoflush
+ some_object.related_thing = session.query(SomeRelated).first()
+
+ Operations that proceed within the ``with:`` block
+ will not be subject to flushes occurring upon query
+ access. This is useful when initializing a series
+ of objects which involve existing database queries,
+ where the uncompleted object should not yet be flushed.
+
+ """
+ autoflush = self.autoflush
+ self.autoflush = False
+ try:
+ yield self
+ finally:
+ self.autoflush = autoflush
+
+ def _autoflush(self):
+ if self.autoflush and not self._flushing:
+ try:
+ self.flush()
+ except sa_exc.StatementError as e:
+ # note we are reraising StatementError as opposed to
+ # raising FlushError with "chaining" to remain compatible
+ # with code that catches StatementError, IntegrityError,
+ # etc.
+ e.add_detail(
+ "raised as a result of Query-invoked autoflush; "
+ "consider using a session.no_autoflush block if this "
+ "flush is occurring prematurely"
+ )
+ util.raise_(e, with_traceback=sys.exc_info()[2])
+
+ def refresh(self, instance, attribute_names=None, with_for_update=None):
+ """Expire and refresh attributes on the given instance.
+
+ The selected attributes will first be expired as they would when using
+ :meth:`_orm.Session.expire`; then a SELECT statement will be issued to
+ the database to refresh column-oriented attributes with the current
+ value available in the current transaction.
+
+ :func:`_orm.relationship` oriented attributes will also be immediately
+ loaded if they were already eagerly loaded on the object, using the
+ same eager loading strategy that they were loaded with originally.
+ Unloaded relationship attributes will remain unloaded, as will
+ relationship attributes that were originally lazy loaded.
+
+ .. versionadded:: 1.4 - the :meth:`_orm.Session.refresh` method
+ can also refresh eagerly loaded attributes.
+
+ .. tip::
+
+ While the :meth:`_orm.Session.refresh` method is capable of
+ refreshing both column and relationship oriented attributes, its
+ primary focus is on refreshing of local column-oriented attributes
+            on a single instance. For more open-ended "refresh" functionality,
+ including the ability to refresh the attributes on many objects at
+ once while having explicit control over relationship loader
+ strategies, use the
+ :ref:`populate existing <orm_queryguide_populate_existing>` feature
+ instead.
+
+ Note that a highly isolated transaction will return the same values as
+ were previously read in that same transaction, regardless of changes
+ in database state outside of that transaction. Refreshing
+ attributes usually only makes sense at the start of a transaction
+ where database rows have not yet been accessed.
+
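+        E.g., a brief sketch, where ``some_user`` is an illustrative
+        persistent instance::
+
+            session.refresh(some_user, attribute_names=["name", "email"])
+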
+ :param attribute_names: optional. An iterable collection of
+ string attribute names indicating a subset of attributes to
+ be refreshed.
+
+ :param with_for_update: optional boolean ``True`` indicating FOR UPDATE
+ should be used, or may be a dictionary containing flags to
+ indicate a more specific set of FOR UPDATE flags for the SELECT;
+ flags should match the parameters of
+ :meth:`_query.Query.with_for_update`.
+ Supersedes the :paramref:`.Session.refresh.lockmode` parameter.
+
+ .. seealso::
+
+ :ref:`session_expire` - introductory material
+
+ :meth:`.Session.expire`
+
+ :meth:`.Session.expire_all`
+
+ :ref:`orm_queryguide_populate_existing` - allows any ORM query
+ to refresh objects as they would be loaded normally.
+
+ """
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+
+ self._expire_state(state, attribute_names)
+
+ if with_for_update == {}:
+ raise sa_exc.ArgumentError(
+ "with_for_update should be the boolean value "
+ "True, or a dictionary with options. "
+ "A blank dictionary is ambiguous."
+ )
+
+ with_for_update = query.ForUpdateArg._from_argument(with_for_update)
+
+ stmt = sql.select(object_mapper(instance))
+ if (
+ loading.load_on_ident(
+ self,
+ stmt,
+ state.key,
+ refresh_state=state,
+ with_for_update=with_for_update,
+ only_load_props=attribute_names,
+ )
+ is None
+ ):
+ raise sa_exc.InvalidRequestError(
+ "Could not refresh instance '%s'" % instance_str(instance)
+ )
+
+ def expire_all(self):
+ """Expires all persistent instances within this Session.
+
+        When any attribute on a persistent instance is next accessed,
+ a query will be issued using the
+ :class:`.Session` object's current transactional context in order to
+ load all expired attributes for the given instance. Note that
+ a highly isolated transaction will return the same values as were
+ previously read in that same transaction, regardless of changes
+ in database state outside of that transaction.
+
+ To expire individual objects and individual attributes
+ on those objects, use :meth:`Session.expire`.
+
+ The :class:`.Session` object's default behavior is to
+ expire all state whenever the :meth:`Session.rollback`
+ or :meth:`Session.commit` methods are called, so that new
+ state can be loaded for the new transaction. For this reason,
+ calling :meth:`Session.expire_all` should not be needed when
+ autocommit is ``False``, assuming the transaction is isolated.
+
+ .. seealso::
+
+ :ref:`session_expire` - introductory material
+
+ :meth:`.Session.expire`
+
+ :meth:`.Session.refresh`
+
+ :meth:`_orm.Query.populate_existing`
+
+ """
+ for state in self.identity_map.all_states():
+ state._expire(state.dict, self.identity_map._modified)
+
+ def expire(self, instance, attribute_names=None):
+ """Expire the attributes on an instance.
+
+ Marks the attributes of an instance as out of date. When an expired
+ attribute is next accessed, a query will be issued to the
+ :class:`.Session` object's current transactional context in order to
+ load all expired attributes for the given instance. Note that
+ a highly isolated transaction will return the same values as were
+ previously read in that same transaction, regardless of changes
+ in database state outside of that transaction.
+
+ To expire all objects in the :class:`.Session` simultaneously,
+ use :meth:`Session.expire_all`.
+
+ The :class:`.Session` object's default behavior is to
+ expire all state whenever the :meth:`Session.rollback`
+ or :meth:`Session.commit` methods are called, so that new
+ state can be loaded for the new transaction. For this reason,
+ calling :meth:`Session.expire` only makes sense for the specific
+ case that a non-ORM SQL statement was emitted in the current
+ transaction.
+
+ :param instance: The instance to be refreshed.
+ :param attribute_names: optional list of string attribute names
+ indicating a subset of attributes to be expired.
+
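+        E.g., a brief sketch, where ``some_user`` is an illustrative
+        persistent instance::
+
+            # expire all attributes on the instance
+            session.expire(some_user)
+
+            # expire only a subset of attributes
+            session.expire(some_user, ["name"])
+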
+ .. seealso::
+
+ :ref:`session_expire` - introductory material
+
+ :meth:`.Session.expire`
+
+ :meth:`.Session.refresh`
+
+ :meth:`_orm.Query.populate_existing`
+
+ """
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+ self._expire_state(state, attribute_names)
+
+ def _expire_state(self, state, attribute_names):
+ self._validate_persistent(state)
+ if attribute_names:
+ state._expire_attributes(state.dict, attribute_names)
+ else:
+ # pre-fetch the full cascade since the expire is going to
+ # remove associations
+ cascaded = list(
+ state.manager.mapper.cascade_iterator("refresh-expire", state)
+ )
+ self._conditional_expire(state)
+ for o, m, st_, dct_ in cascaded:
+ self._conditional_expire(st_)
+
+ def _conditional_expire(self, state, autoflush=None):
+ """Expire a state if persistent, else expunge if pending"""
+
+ if state.key:
+ state._expire(state.dict, self.identity_map._modified)
+ elif state in self._new:
+ self._new.pop(state)
+ state._detach(self)
+
+ def expunge(self, instance):
+ """Remove the `instance` from this ``Session``.
+
+ This will free all internal references to the instance. Cascading
+ will be applied according to the *expunge* cascade rule.
+
+ """
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+ if state.session_id is not self.hash_key:
+ raise sa_exc.InvalidRequestError(
+ "Instance %s is not present in this Session" % state_str(state)
+ )
+
+ cascaded = list(
+ state.manager.mapper.cascade_iterator("expunge", state)
+ )
+ self._expunge_states([state] + [st_ for o, m, st_, dct_ in cascaded])
+
+ def _expunge_states(self, states, to_transient=False):
+ for state in states:
+ if state in self._new:
+ self._new.pop(state)
+ elif self.identity_map.contains_state(state):
+ self.identity_map.safe_discard(state)
+ self._deleted.pop(state, None)
+ elif self._transaction:
+ # state is "detached" from being deleted, but still present
+ # in the transaction snapshot
+ self._transaction._deleted.pop(state, None)
+ statelib.InstanceState._detach_states(
+ states, self, to_transient=to_transient
+ )
+
+ def _register_persistent(self, states):
+ """Register all persistent objects from a flush.
+
+ This is used both for pending objects moving to the persistent
+ state as well as already persistent objects.
+
+ """
+
+ pending_to_persistent = self.dispatch.pending_to_persistent or None
+ for state in states:
+ mapper = _state_mapper(state)
+
+ # guard against last-minute dereferences of the object
+ obj = state.obj()
+ if obj is not None:
+
+ instance_key = mapper._identity_key_from_state(state)
+
+ if (
+ _none_set.intersection(instance_key[1])
+ and not mapper.allow_partial_pks
+ or _none_set.issuperset(instance_key[1])
+ ):
+ raise exc.FlushError(
+ "Instance %s has a NULL identity key. If this is an "
+ "auto-generated value, check that the database table "
+ "allows generation of new primary key values, and "
+ "that the mapped Column object is configured to "
+ "expect these generated values. Ensure also that "
+ "this flush() is not occurring at an inappropriate "
+ "time, such as within a load() event."
+ % state_str(state)
+ )
+
+ if state.key is None:
+ state.key = instance_key
+ elif state.key != instance_key:
+ # primary key switch. use safe_discard() in case another
+ # state has already replaced this one in the identity
+ # map (see test/orm/test_naturalpks.py ReversePKsTest)
+ self.identity_map.safe_discard(state)
+ if state in self._transaction._key_switches:
+ orig_key = self._transaction._key_switches[state][0]
+ else:
+ orig_key = state.key
+ self._transaction._key_switches[state] = (
+ orig_key,
+ instance_key,
+ )
+ state.key = instance_key
+
+ # there can be an existing state in the identity map
+ # that is replaced when the primary keys of two instances
+ # are swapped; see test/orm/test_naturalpks.py -> test_reverse
+ old = self.identity_map.replace(state)
+ if (
+ old is not None
+ and mapper._identity_key_from_state(old) == instance_key
+ and old.obj() is not None
+ ):
+ util.warn(
+ "Identity map already had an identity for %s, "
+ "replacing it with newly flushed object. Are there "
+ "load operations occurring inside of an event handler "
+ "within the flush?" % (instance_key,)
+ )
+ state._orphaned_outside_of_session = False
+
+ statelib.InstanceState._commit_all_states(
+ ((state, state.dict) for state in states), self.identity_map
+ )
+
+ self._register_altered(states)
+
+ if pending_to_persistent is not None:
+ for state in states.intersection(self._new):
+ pending_to_persistent(self, state)
+
+ # remove from new last, might be the last strong ref
+ for state in set(states).intersection(self._new):
+ self._new.pop(state)
+
+ def _register_altered(self, states):
+ if self._transaction:
+ for state in states:
+ if state in self._new:
+ self._transaction._new[state] = True
+ else:
+ self._transaction._dirty[state] = True
+
+ def _remove_newly_deleted(self, states):
+ persistent_to_deleted = self.dispatch.persistent_to_deleted or None
+ for state in states:
+ if self._transaction:
+ self._transaction._deleted[state] = True
+
+ if persistent_to_deleted is not None:
+ # get a strong reference before we pop out of
+ # self._deleted
+ obj = state.obj() # noqa
+
+ self.identity_map.safe_discard(state)
+ self._deleted.pop(state, None)
+ state._deleted = True
+ # can't call state._detach() here, because this state
+ # is still in the transaction snapshot and needs to be
+ # tracked as part of that
+ if persistent_to_deleted is not None:
+ persistent_to_deleted(self, state)
+
+ def add(self, instance, _warn=True):
+ """Place an object in the ``Session``.
+
+ Its state will be persisted to the database on the next flush
+ operation.
+
+ Repeated calls to ``add()`` will be ignored. The opposite of ``add()``
+ is ``expunge()``.
+
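+ For example, a minimal sketch (``User`` is a hypothetical mapped
+ class)::
+
+     user = User(name="spongebob")
+     session.add(user)    # object is now "pending"
+     session.flush()      # INSERT is emitted here
+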
+ """
+ if _warn and self._warn_on_events:
+ self._flush_warning("Session.add()")
+
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+
+ self._save_or_update_state(state)
+
+ def add_all(self, instances):
+ """Add the given collection of instances to this ``Session``."""
+
+ if self._warn_on_events:
+ self._flush_warning("Session.add_all()")
+
+ for instance in instances:
+ self.add(instance, _warn=False)
+
+ def _save_or_update_state(self, state):
+ state._orphaned_outside_of_session = False
+ self._save_or_update_impl(state)
+
+ mapper = _state_mapper(state)
+ for o, m, st_, dct_ in mapper.cascade_iterator(
+ "save-update", state, halt_on=self._contains_state
+ ):
+ self._save_or_update_impl(st_)
+
+ def delete(self, instance):
+ """Mark an instance as deleted.
+
+ The database delete operation occurs upon ``flush()``.
+
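+ For example, a minimal sketch (``user`` is a hypothetical persistent
+ instance)::
+
+     session.delete(user)  # marked as deleted
+     session.flush()       # DELETE is emitted here
+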
+ """
+ if self._warn_on_events:
+ self._flush_warning("Session.delete()")
+
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+
+ self._delete_impl(state, instance, head=True)
+
+ def _delete_impl(self, state, obj, head):
+
+ if state.key is None:
+ if head:
+ raise sa_exc.InvalidRequestError(
+ "Instance '%s' is not persisted" % state_str(state)
+ )
+ else:
+ return
+
+ to_attach = self._before_attach(state, obj)
+
+ if state in self._deleted:
+ return
+
+ self.identity_map.add(state)
+
+ if to_attach:
+ self._after_attach(state, obj)
+
+ if head:
+ # grab the cascades before adding the item to the deleted list
+ # so that autoflush does not delete the item
+ # the strong reference to the instance itself is significant here
+ cascade_states = list(
+ state.manager.mapper.cascade_iterator("delete", state)
+ )
+
+ self._deleted[state] = obj
+
+ if head:
+ for o, m, st_, dct_ in cascade_states:
+ self._delete_impl(st_, o, False)
+
+ def get(
+ self,
+ entity,
+ ident,
+ options=None,
+ populate_existing=False,
+ with_for_update=None,
+ identity_token=None,
+ execution_options=None,
+ ):
+ """Return an instance based on the given primary key identifier,
+ or ``None`` if not found.
+
+ E.g.::
+
+ my_user = session.get(User, 5)
+
+ some_object = session.get(VersionedFoo, (5, 10))
+
+ some_object = session.get(
+ VersionedFoo,
+ {"id": 5, "version_id": 10}
+ )
+
+ .. versionadded:: 1.4 Added :meth:`_orm.Session.get`, which is moved
+ from the now deprecated :meth:`_orm.Query.get` method.
+
+ :meth:`_orm.Session.get` is special in that it provides direct
+ access to the identity map of the :class:`.Session`.
+ If the given primary key identifier is present
+ in the local identity map, the object is returned
+ directly from this collection and no SQL is emitted,
+ unless the object has been marked fully expired.
+ If not present,
+ a SELECT is performed in order to locate the object.
+
+ :meth:`_orm.Session.get` also will perform a check if
+ the object is present in the identity map and
+ marked as expired - a SELECT
+ is emitted to refresh the object as well as to
+ ensure that the row is still present.
+ If not, :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised.
+
+ :param entity: a mapped class or :class:`.Mapper` indicating the
+ type of entity to be loaded.
+
+ :param ident: A scalar, tuple, or dictionary representing the
+ primary key. For a composite (e.g. multiple column) primary key,
+ a tuple or dictionary should be passed.
+
+ For a single-column primary key, the scalar calling form is typically
+ the most expedient. If the primary key of a row is the value "5",
+ the call looks like::
+
+ my_object = session.get(SomeClass, 5)
+
+ The tuple form contains primary key values typically in
+ the order in which they correspond to the mapped
+ :class:`_schema.Table`
+ object's primary key columns, or if the
+ :paramref:`_orm.Mapper.primary_key` configuration parameter were
+ used, in
+ the order used for that parameter. For example, if the primary key
+ of a row is represented by the integer
+ digits "5, 10" the call would look like::
+
+ my_object = session.get(SomeClass, (5, 10))
+
+ The dictionary form should include as keys the mapped attribute names
+ corresponding to each element of the primary key. If the mapped class
+ has the attributes ``id``, ``version_id`` as the attributes which
+ store the object's primary key value, the call would look like::
+
+ my_object = session.get(SomeClass, {"id": 5, "version_id": 10})
+
+ :param options: optional sequence of loader options which will be
+ applied to the query, if one is emitted.
+
+ :param populate_existing: causes the method to unconditionally emit
+ a SQL query and refresh the object with the newly loaded data,
+ regardless of whether or not the object is already present.
+
+ :param with_for_update: optional boolean ``True`` indicating FOR UPDATE
+ should be used, or may be a dictionary containing flags to
+ indicate a more specific set of FOR UPDATE flags for the SELECT;
+ flags should match the parameters of
+ :meth:`_query.Query.with_for_update`.
+ Supersedes the :paramref:`.Session.refresh.lockmode` parameter.
+
+ :param execution_options: optional dictionary of execution options,
+ which will be associated with the query execution if one is emitted.
+ This dictionary can provide a subset of the options that are
+ accepted by :meth:`_engine.Connection.execution_options`, and may
+ also provide additional options understood only in an ORM context.
+
+ .. versionadded:: 1.4.29
+
+ .. seealso::
+
+ :ref:`orm_queryguide_execution_options` - ORM-specific execution
+ options
+
+ :return: The object instance, or ``None``.
+
+ """
+ return self._get_impl(
+ entity,
+ ident,
+ loading.load_on_pk_identity,
+ options,
+ populate_existing=populate_existing,
+ with_for_update=with_for_update,
+ identity_token=identity_token,
+ execution_options=execution_options,
+ )
+
+ def _get_impl(
+ self,
+ entity,
+ primary_key_identity,
+ db_load_fn,
+ options=None,
+ populate_existing=False,
+ with_for_update=None,
+ identity_token=None,
+ execution_options=None,
+ ):
+
+ # convert composite types to individual args
+ if hasattr(primary_key_identity, "__composite_values__"):
+ primary_key_identity = primary_key_identity.__composite_values__()
+
+ mapper = inspect(entity)
+
+ if not mapper or not mapper.is_mapper:
+ raise sa_exc.ArgumentError(
+ "Expected mapped class or mapper, got: %r" % entity
+ )
+
+ is_dict = isinstance(primary_key_identity, dict)
+ if not is_dict:
+ primary_key_identity = util.to_list(
+ primary_key_identity, default=(None,)
+ )
+
+ if len(primary_key_identity) != len(mapper.primary_key):
+ raise sa_exc.InvalidRequestError(
+ "Incorrect number of values in identifier to formulate "
+ "primary key for session.get(); primary key columns "
+ "are %s" % ",".join("'%s'" % c for c in mapper.primary_key)
+ )
+
+ if is_dict:
+ try:
+ primary_key_identity = list(
+ primary_key_identity[prop.key]
+ for prop in mapper._identity_key_props
+ )
+
+ except KeyError as err:
+ util.raise_(
+ sa_exc.InvalidRequestError(
+ "Incorrect names of values in identifier to formulate "
+ "primary key for session.get(); primary key attribute "
+ "names are %s"
+ % ",".join(
+ "'%s'" % prop.key
+ for prop in mapper._identity_key_props
+ )
+ ),
+ replace_context=err,
+ )
+
+ if (
+ not populate_existing
+ and not mapper.always_refresh
+ and with_for_update is None
+ ):
+
+ instance = self._identity_lookup(
+ mapper, primary_key_identity, identity_token=identity_token
+ )
+
+ if instance is not None:
+ # reject calls for id in identity map but class
+ # mismatch.
+ if not issubclass(instance.__class__, mapper.class_):
+ return None
+ return instance
+ elif instance is attributes.PASSIVE_CLASS_MISMATCH:
+ return None
+
+ # set_label_style() not strictly necessary, however this will ensure
+ # that tablename_colname style is used which at the moment is
+ # asserted in a lot of unit tests :)
+
+ load_options = context.QueryContext.default_load_options
+
+ if populate_existing:
+ load_options += {"_populate_existing": populate_existing}
+ statement = sql.select(mapper).set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ )
+ if with_for_update is not None:
+ statement._for_update_arg = query.ForUpdateArg._from_argument(
+ with_for_update
+ )
+
+ if options:
+ statement = statement.options(*options)
+ if execution_options:
+ statement = statement.execution_options(**execution_options)
+ return db_load_fn(
+ self,
+ statement,
+ primary_key_identity,
+ load_options=load_options,
+ )
+
+ def merge(self, instance, load=True, options=None):
+ """Copy the state of a given instance into a corresponding instance
+ within this :class:`.Session`.
+
+ :meth:`.Session.merge` examines the primary key attributes of the
+ source instance, and attempts to reconcile it with an instance of the
+ same primary key in the session. If not found locally, it attempts
+ to load the object from the database based on primary key, and if
+ none can be located, creates a new instance. The state of each
+ attribute on the source instance is then copied to the target
+ instance. The resulting target instance is then returned by the
+ method; the original source instance is left unmodified, and
+ un-associated with the :class:`.Session` if not already.
+
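+ For example, a minimal sketch (``detached_user`` is a hypothetical
+ detached instance)::
+
+     merged_user = session.merge(detached_user)
+     # ``merged_user`` is the in-session copy; ``detached_user`` itself
+     # remains unmodified and is not added to the session
+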
+ This operation cascades to associated instances if the association is
+ mapped with ``cascade="merge"``.
+
+ See :ref:`unitofwork_merging` for a detailed discussion of merging.
+
+ .. versionchanged:: 1.1 - :meth:`.Session.merge` will now reconcile
+ pending objects with overlapping primary keys in the same way
+ as persistent. See :ref:`change_3601` for discussion.
+
+ :param instance: Instance to be merged.
+ :param load: Boolean, when False, :meth:`.merge` switches into
+ a "high performance" mode which causes it to forego emitting history
+ events as well as all database access. This flag is used for
+ cases such as transferring graphs of objects into a :class:`.Session`
+ from a second level cache, or to transfer just-loaded objects
+ into the :class:`.Session` owned by a worker thread or process
+ without re-querying the database.
+
+ The ``load=False`` use case adds the caveat that the given
+ object has to be in a "clean" state, that is, has no pending changes
+ to be flushed - even if the incoming object is detached from any
+ :class:`.Session`. This is so that when
+ the merge operation populates local attributes and
+ cascades to related objects and
+ collections, the values can be "stamped" onto the
+ target object as is, without generating any history or attribute
+ events, and without the need to reconcile the incoming data with
+ any existing related objects or collections that might not
+ be loaded. The resulting objects from ``load=False`` are always
+ produced as "clean", so it is only appropriate that the given objects
+ should be "clean" as well, else this suggests a mis-use of the
+ method.
+ :param options: optional sequence of loader options which will be
+ applied to the :meth:`_orm.Session.get` method when the merge
+ operation loads the existing version of the object from the database.
+
+ .. versionadded:: 1.4.24
+
+
+ .. seealso::
+
+ :func:`.make_transient_to_detached` - provides for an alternative
+ means of "merging" a single object into the :class:`.Session`
+
+ """
+
+ if self._warn_on_events:
+ self._flush_warning("Session.merge()")
+
+ _recursive = {}
+ _resolve_conflict_map = {}
+
+ if load:
+ # flush current contents if we expect to load data
+ self._autoflush()
+
+ object_mapper(instance) # verify mapped
+ autoflush = self.autoflush
+ try:
+ self.autoflush = False
+ return self._merge(
+ attributes.instance_state(instance),
+ attributes.instance_dict(instance),
+ load=load,
+ options=options,
+ _recursive=_recursive,
+ _resolve_conflict_map=_resolve_conflict_map,
+ )
+ finally:
+ self.autoflush = autoflush
+
+ def _merge(
+ self,
+ state,
+ state_dict,
+ load=True,
+ options=None,
+ _recursive=None,
+ _resolve_conflict_map=None,
+ ):
+ mapper = _state_mapper(state)
+ if state in _recursive:
+ return _recursive[state]
+
+ new_instance = False
+ key = state.key
+
+ if key is None:
+ if state in self._new:
+ util.warn(
+ "Instance %s is already pending in this Session yet is "
+ "being merged again; this is probably not what you want "
+ "to do" % state_str(state)
+ )
+
+ if not load:
+ raise sa_exc.InvalidRequestError(
+ "merge() with load=False option does not support "
+ "objects transient (i.e. unpersisted) objects. flush() "
+ "all changes on mapped instances before merging with "
+ "load=False."
+ )
+ key = mapper._identity_key_from_state(state)
+ key_is_persistent = attributes.NEVER_SET not in key[1] and (
+ not _none_set.intersection(key[1])
+ or (
+ mapper.allow_partial_pks
+ and not _none_set.issuperset(key[1])
+ )
+ )
+ else:
+ key_is_persistent = True
+
+ if key in self.identity_map:
+ try:
+ merged = self.identity_map[key]
+ except KeyError:
+ # object was GC'ed right as we checked for it
+ merged = None
+ else:
+ merged = None
+
+ if merged is None:
+ if key_is_persistent and key in _resolve_conflict_map:
+ merged = _resolve_conflict_map[key]
+
+ elif not load:
+ if state.modified:
+ raise sa_exc.InvalidRequestError(
+ "merge() with load=False option does not support "
+ "objects marked as 'dirty'. flush() all changes on "
+ "mapped instances before merging with load=False."
+ )
+ merged = mapper.class_manager.new_instance()
+ merged_state = attributes.instance_state(merged)
+ merged_state.key = key
+ self._update_impl(merged_state)
+ new_instance = True
+
+ elif key_is_persistent:
+ merged = self.get(
+ mapper.class_,
+ key[1],
+ identity_token=key[2],
+ options=options,
+ )
+
+ if merged is None:
+ merged = mapper.class_manager.new_instance()
+ merged_state = attributes.instance_state(merged)
+ merged_dict = attributes.instance_dict(merged)
+ new_instance = True
+ self._save_or_update_state(merged_state)
+ else:
+ merged_state = attributes.instance_state(merged)
+ merged_dict = attributes.instance_dict(merged)
+
+ _recursive[state] = merged
+ _resolve_conflict_map[key] = merged
+
+ # check that we didn't just pull the exact same
+ # state out.
+ if state is not merged_state:
+ # version check if applicable
+ if mapper.version_id_col is not None:
+ existing_version = mapper._get_state_attr_by_column(
+ state,
+ state_dict,
+ mapper.version_id_col,
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ )
+
+ merged_version = mapper._get_state_attr_by_column(
+ merged_state,
+ merged_dict,
+ mapper.version_id_col,
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ )
+
+ if (
+ existing_version is not attributes.PASSIVE_NO_RESULT
+ and merged_version is not attributes.PASSIVE_NO_RESULT
+ and existing_version != merged_version
+ ):
+ raise exc.StaleDataError(
+ "Version id '%s' on merged state %s "
+ "does not match existing version '%s'. "
+ "Leave the version attribute unset when "
+ "merging to update the most recent version."
+ % (
+ existing_version,
+ state_str(merged_state),
+ merged_version,
+ )
+ )
+
+ merged_state.load_path = state.load_path
+ merged_state.load_options = state.load_options
+
+ # since we are copying load_options, we need to copy
+ # the callables_ that would have been generated by those
+ # load_options.
+ # assumes that the callables we put in state.callables_
+ # are not instance-specific (which they should not be)
+ merged_state._copy_callables(state)
+
+ for prop in mapper.iterate_properties:
+ prop.merge(
+ self,
+ state,
+ state_dict,
+ merged_state,
+ merged_dict,
+ load,
+ _recursive,
+ _resolve_conflict_map,
+ )
+
+ if not load:
+ # remove any history
+ merged_state._commit_all(merged_dict, self.identity_map)
+
+ if new_instance:
+ merged_state.manager.dispatch.load(merged_state, None)
+ return merged
+
+ def _validate_persistent(self, state):
+ if not self.identity_map.contains_state(state):
+ raise sa_exc.InvalidRequestError(
+ "Instance '%s' is not persistent within this Session"
+ % state_str(state)
+ )
+
+ def _save_impl(self, state):
+ if state.key is not None:
+ raise sa_exc.InvalidRequestError(
+ "Object '%s' already has an identity - "
+ "it can't be registered as pending" % state_str(state)
+ )
+
+ obj = state.obj()
+ to_attach = self._before_attach(state, obj)
+ if state not in self._new:
+ self._new[state] = obj
+ state.insert_order = len(self._new)
+ if to_attach:
+ self._after_attach(state, obj)
+
+ def _update_impl(self, state, revert_deletion=False):
+ if state.key is None:
+ raise sa_exc.InvalidRequestError(
+ "Instance '%s' is not persisted" % state_str(state)
+ )
+
+ if state._deleted:
+ if revert_deletion:
+ if not state._attached:
+ return
+ del state._deleted
+ else:
+ raise sa_exc.InvalidRequestError(
+ "Instance '%s' has been deleted. "
+ "Use the make_transient() "
+ "function to send this object back "
+ "to the transient state." % state_str(state)
+ )
+
+ obj = state.obj()
+
+ # check for late gc
+ if obj is None:
+ return
+
+ to_attach = self._before_attach(state, obj)
+
+ self._deleted.pop(state, None)
+ if revert_deletion:
+ self.identity_map.replace(state)
+ else:
+ self.identity_map.add(state)
+
+ if to_attach:
+ self._after_attach(state, obj)
+ elif revert_deletion:
+ self.dispatch.deleted_to_persistent(self, state)
+
+ def _save_or_update_impl(self, state):
+ if state.key is None:
+ self._save_impl(state)
+ else:
+ self._update_impl(state)
+
+ def enable_relationship_loading(self, obj):
+ """Associate an object with this :class:`.Session` for related
+ object loading.
+
+ .. warning::
+
+ :meth:`.enable_relationship_loading` exists to serve special
+ use cases and is not recommended for general use.
+
+ Accesses of attributes mapped with :func:`_orm.relationship`
+ will attempt to load a value from the database using this
+ :class:`.Session` as the source of connectivity. The values
+ will be loaded based on foreign key and primary key values
+ present on this object - if not present, then those relationships
+ will be unavailable.
+
+ The object will be attached to this session, but will
+ **not** participate in any persistence operations; its state
+ for almost all purposes will remain either "transient" or
+ "detached", except for the case of relationship loading.
+
+ Also note that backrefs will often not work as expected.
+ Altering a relationship-bound attribute on the target object
+ may not fire off a backref event, if the effective value
+ is what was already loaded from a foreign-key-holding value.
+
+ The :meth:`.Session.enable_relationship_loading` method is
+ similar to the ``load_on_pending`` flag on :func:`_orm.relationship`.
+ Unlike that flag, :meth:`.Session.enable_relationship_loading` allows
+ an object to remain transient while still being able to load
+ related items.
+
+ To make a transient object associated with a :class:`.Session`
+ via :meth:`.Session.enable_relationship_loading` pending, add
+ it to the :class:`.Session` using :meth:`.Session.add` normally.
+ If the object instead represents an existing identity in the database,
+ it should be merged using :meth:`.Session.merge`.
+
+ :meth:`.Session.enable_relationship_loading` does not improve
+ behavior when the ORM is used normally - object references should be
+ constructed at the object level, not at the foreign key level, so
+ that they are present in an ordinary way before flush()
+ proceeds. This method is not intended for general use.
+
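+ A minimal sketch of the intended pattern (``Address``, ``user_id`` and
+ ``user`` are hypothetical)::
+
+     a1 = Address()
+     a1.user_id = 5                            # set the FK value directly
+     session.enable_relationship_loading(a1)
+     print(a1.user)                            # many-to-one loads via this Session
+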
+ .. seealso::
+
+ :paramref:`_orm.relationship.load_on_pending` - this flag
+ allows per-relationship loading of many-to-ones on items that
+ are pending.
+
+ :func:`.make_transient_to_detached` - allows for an object to
+ be added to a :class:`.Session` without SQL emitted, which then
+ will unexpire attributes on access.
+
+ """
+ try:
+ state = attributes.instance_state(obj)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(obj),
+ replace_context=err,
+ )
+
+ to_attach = self._before_attach(state, obj)
+ state._load_pending = True
+ if to_attach:
+ self._after_attach(state, obj)
+
+ def _before_attach(self, state, obj):
+ self._autobegin()
+
+ if state.session_id == self.hash_key:
+ return False
+
+ if state.session_id and state.session_id in _sessions:
+ raise sa_exc.InvalidRequestError(
+ "Object '%s' is already attached to session '%s' "
+ "(this is '%s')"
+ % (state_str(state), state.session_id, self.hash_key)
+ )
+
+ self.dispatch.before_attach(self, state)
+
+ return True
+
+ def _after_attach(self, state, obj):
+ state.session_id = self.hash_key
+ if state.modified and state._strong_obj is None:
+ state._strong_obj = obj
+ self.dispatch.after_attach(self, state)
+
+ if state.key:
+ self.dispatch.detached_to_persistent(self, state)
+ else:
+ self.dispatch.transient_to_pending(self, state)
+
+ def __contains__(self, instance):
+ """Return True if the instance is associated with this session.
+
+ The instance may be pending or persistent within the Session for a
+ result of True.
+
+ """
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+ return self._contains_state(state)
+
+ def __iter__(self):
+ """Iterate over all pending or persistent instances within this
+ Session.
+
+ """
+ return iter(
+ list(self._new.values()) + list(self.identity_map.values())
+ )
+
+ def _contains_state(self, state):
+ return state in self._new or self.identity_map.contains_state(state)
+
+ def flush(self, objects=None):
+ """Flush all the object changes to the database.
+
+ Writes out all pending object creations, deletions and modifications
+ to the database as INSERTs, DELETEs, UPDATEs, etc. Operations are
+ automatically ordered by the Session's unit of work dependency
+ solver.
+
+ Database operations will be issued in the current transactional
+ context and do not affect the state of the transaction, unless an
+ error occurs, in which case the entire transaction is rolled back.
+ You may flush() as often as you like within a transaction to move
+ changes from Python to the database's transaction buffer.
+
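+ For example, a minimal sketch (``User`` is a hypothetical mapped
+ class)::
+
+     session.add(User(name="spongebob"))
+     session.flush()    # INSERT emitted; transaction remains open
+     session.commit()   # transaction is committed
+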
+ For ``autocommit`` Sessions with no active manual transaction, flush()
+ will create a transaction on the fly that surrounds the entire set of
+ operations in the flush.
+
+ :param objects: Optional; restricts the flush operation to operate
+ only on elements that are in the given collection.
+
+ This feature is for an extremely narrow set of use cases where
+ particular objects may need to be operated upon before the
+ full flush() occurs. It is not intended for general use.
+
+ """
+
+ if self._flushing:
+ raise sa_exc.InvalidRequestError("Session is already flushing")
+
+ if self._is_clean():
+ return
+ try:
+ self._flushing = True
+ self._flush(objects)
+ finally:
+ self._flushing = False
+
+ def _flush_warning(self, method):
+ util.warn(
+ "Usage of the '%s' operation is not currently supported "
+ "within the execution stage of the flush process. "
+ "Results may not be consistent. Consider using alternative "
+ "event listeners or connection-level operations instead." % method
+ )
+
+ def _is_clean(self):
+ return (
+ not self.identity_map.check_modified()
+ and not self._deleted
+ and not self._new
+ )
+
+ def _flush(self, objects=None):
+
+ dirty = self._dirty_states
+ if not dirty and not self._deleted and not self._new:
+ self.identity_map._modified.clear()
+ return
+
+ flush_context = UOWTransaction(self)
+
+ if self.dispatch.before_flush:
+ self.dispatch.before_flush(self, flush_context, objects)
+ # re-establish "dirty states" in case the listeners
+ # added or modified objects
+ dirty = self._dirty_states
+
+ deleted = set(self._deleted)
+ new = set(self._new)
+
+ dirty = set(dirty).difference(deleted)
+
+ # create the set of all objects we want to operate upon
+ if objects:
+ # specific list passed in
+ objset = set()
+ for o in objects:
+ try:
+ state = attributes.instance_state(o)
+
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(o),
+ replace_context=err,
+ )
+ objset.add(state)
+ else:
+ objset = None
+
+ # store objects whose fate has been decided
+ processed = set()
+
+ # put all saves/updates into the flush context. detect top-level
+ # orphans and throw them into deleted.
+ if objset:
+ proc = new.union(dirty).intersection(objset).difference(deleted)
+ else:
+ proc = new.union(dirty).difference(deleted)
+
+ for state in proc:
+ is_orphan = _state_mapper(state)._is_orphan(state)
+
+ is_persistent_orphan = is_orphan and state.has_identity
+
+ if (
+ is_orphan
+ and not is_persistent_orphan
+ and state._orphaned_outside_of_session
+ ):
+ self._expunge_states([state])
+ else:
+ _reg = flush_context.register_object(
+ state, isdelete=is_persistent_orphan
+ )
+ assert _reg, "Failed to add object to the flush context!"
+ processed.add(state)
+
+ # put all remaining deletes into the flush context.
+ if objset:
+ proc = deleted.intersection(objset).difference(processed)
+ else:
+ proc = deleted.difference(processed)
+ for state in proc:
+ _reg = flush_context.register_object(state, isdelete=True)
+ assert _reg, "Failed to add object to the flush context!"
+
+ if not flush_context.has_work:
+ return
+
+ flush_context.transaction = transaction = self.begin(_subtrans=True)
+ try:
+ self._warn_on_events = True
+ try:
+ flush_context.execute()
+ finally:
+ self._warn_on_events = False
+
+ self.dispatch.after_flush(self, flush_context)
+
+ flush_context.finalize_flush_changes()
+
+ if not objects and self.identity_map._modified:
+ len_ = len(self.identity_map._modified)
+
+ statelib.InstanceState._commit_all_states(
+ [
+ (state, state.dict)
+ for state in self.identity_map._modified
+ ],
+ instance_dict=self.identity_map,
+ )
+ util.warn(
+ "Attribute history events accumulated on %d "
+ "previously clean instances "
+ "within inner-flush event handlers have been "
+ "reset, and will not result in database updates. "
+ "Consider using set_committed_value() within "
+ "inner-flush event handlers to avoid this warning." % len_
+ )
+
+ # useful assertions:
+ # if not objects:
+ # assert not self.identity_map._modified
+ # else:
+ # assert self.identity_map._modified == \
+ # self.identity_map._modified.difference(objects)
+
+ self.dispatch.after_flush_postexec(self, flush_context)
+
+ transaction.commit()
+
+ except:
+ with util.safe_reraise():
+ transaction.rollback(_capture_exception=True)
+
+ def bulk_save_objects(
+ self,
+ objects,
+ return_defaults=False,
+ update_changed_only=True,
+ preserve_order=True,
+ ):
+ """Perform a bulk save of the given list of objects.
+
+ The bulk save feature allows mapped objects to be used as the
+ source of simple INSERT and UPDATE operations which can be more easily
+ grouped together into higher performing "executemany"
+ operations; the extraction of data from the objects is also performed
+ using a lower-latency process that ignores whether or not attributes
+ have actually been modified in the case of UPDATEs, and also ignores
+ SQL expressions.
+
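+ For example, a minimal sketch (``User`` is a hypothetical mapped
+ class)::
+
+     session.bulk_save_objects(
+         [User(name="spongebob"), User(name="sandy")]
+     )
+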
+ The objects as given are not added to the session and no additional
+ state is established on them. If the
+ :paramref:`_orm.Session.bulk_save_objects.return_defaults` flag is set,
+ then server-generated primary key values will be assigned to the
+ returned objects, but **not server side defaults**; this is a
+ limitation in the implementation. If stateful objects are desired,
+ please use the standard :meth:`_orm.Session.add_all` approach or,
+ as an alternative, newer mass-insert features such as
+ :ref:`orm_dml_returning_objects`.
+
+ .. warning::
+
+ The bulk save feature allows for a lower-latency INSERT/UPDATE
+ of rows at the expense of most other unit-of-work features.
+ Features such as object management, relationship handling,
+ and SQL clause support are **silently omitted** in favor of raw
+ INSERT/UPDATES of records.
+
+ Please note that newer versions of SQLAlchemy are **greatly
+ improving the efficiency** of the standard flush process. It is
+ **strongly recommended** to not use the bulk methods as they
+ represent a forking of SQLAlchemy's functionality and are slowly
+ being moved into legacy status. New features such as
+ :ref:`orm_dml_returning_objects` are both more efficient than
+ the "bulk" methods and provide more predictable functionality.
+
+ **Please read the list of caveats at**
+ :ref:`bulk_operations_caveats` **before using this method, and
+ fully test and confirm the functionality of all code developed
+ using these systems.**
+
+ :param objects: a sequence of mapped object instances. The mapped
+ objects are persisted as is, and are **not** associated with the
+ :class:`.Session` afterwards.
+
+ For each object, whether the object is sent as an INSERT or an
+ UPDATE is dependent on the same rules used by the :class:`.Session`
+ in traditional operation; if the object has the
+ :attr:`.InstanceState.key`
+ attribute set, then the object is assumed to be "detached" and
+ will result in an UPDATE. Otherwise, an INSERT is used.
+
+ In the case of an UPDATE, statements are grouped based on which
+ attributes have changed, and are thus to be the subject of each
+ SET clause. If ``update_changed_only`` is False, then all
+ attributes present within each object are applied to the UPDATE
+ statement, which may help in allowing the statements to be grouped
+ together into a larger executemany(), and will also reduce the
+ overhead of checking history on attributes.
+
+ :param return_defaults: when True, rows that are missing values which
+ generate defaults, namely integer primary key defaults and sequences,
+ will be inserted **one at a time**, so that the primary key value
+ is available. In particular this will allow joined-inheritance
+ and other multi-table mappings to insert correctly without the need
+ to provide primary key values ahead of time; however,
+ :paramref:`.Session.bulk_save_objects.return_defaults` **greatly
+ reduces the performance gains** of the method overall. It is strongly
+ advised to use the standard :meth:`_orm.Session.add_all`
+ approach.
+
+ :param update_changed_only: when True, UPDATE statements are rendered
+ based on those attributes in each state that have logged changes.
+ When False, all attributes present are rendered into the SET clause
+ with the exception of primary key attributes.
+
+ :param preserve_order: when True, the order of inserts and updates
+ matches exactly the order in which the objects are given. When
+ False, common types of objects are grouped into inserts
+ and updates, to allow for more batching opportunities.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :ref:`bulk_operations`
+
+ :meth:`.Session.bulk_insert_mappings`
+
+ :meth:`.Session.bulk_update_mappings`
+
+ """
+
+ obj_states = (attributes.instance_state(obj) for obj in objects)
+
+ if not preserve_order:
+ # the purpose of this sort is just so that common mappers
+ # and persistence states are grouped together, so that groupby
+ # will return a single group for a particular type of mapper.
+ # it's not trying to be deterministic beyond that.
+ obj_states = sorted(
+ obj_states,
+ key=lambda state: (id(state.mapper), state.key is not None),
+ )
+
+ def grouping_key(state):
+ return (state.mapper, state.key is not None)
+
+ for (mapper, isupdate), states in itertools.groupby(
+ obj_states, grouping_key
+ ):
+ self._bulk_save_mappings(
+ mapper,
+ states,
+ isupdate,
+ True,
+ return_defaults,
+ update_changed_only,
+ False,
+ )
+
+ def bulk_insert_mappings(
+ self, mapper, mappings, return_defaults=False, render_nulls=False
+ ):
+ """Perform a bulk insert of the given list of mapping dictionaries.
+
+ The bulk insert feature allows plain Python dictionaries to be used as
+ the source of simple INSERT operations which can be more easily
+ grouped together into higher performing "executemany"
+ operations. Using dictionaries, there are no "history" or session
+ state management features in use, reducing latency when inserting
+ large numbers of simple rows.
+
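+ For example, a minimal sketch (``User`` and its column names are
+ hypothetical)::
+
+     session.bulk_insert_mappings(
+         User,
+         [
+             {"id": 1, "name": "spongebob"},
+             {"id": 2, "name": "sandy"},
+         ],
+     )
+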
+ The values within the dictionaries as given are typically passed
+ without modification into Core :meth:`_expression.Insert` constructs,
+ after
+ organizing the values within them across the tables to which
+ the given mapper is mapped.
+
+ .. versionadded:: 1.0.0
+
+ .. warning::
+
+ The bulk insert feature allows for a lower-latency INSERT
+ of rows at the expense of most other unit-of-work features.
+ Features such as object management, relationship handling,
+ and SQL clause support are **silently omitted** in favor of raw
+ INSERT of records.
+
+ Please note that newer versions of SQLAlchemy are **greatly
+ improving the efficiency** of the standard flush process. It is
+ **strongly recommended** to not use the bulk methods as they
+ represent a forking of SQLAlchemy's functionality and are slowly
+ being moved into legacy status. New features such as
+ :ref:`orm_dml_returning_objects` are both more efficient than
+ the "bulk" methods and provide more predictable functionality.
+
+ **Please read the list of caveats at**
+ :ref:`bulk_operations_caveats` **before using this method, and
+ fully test and confirm the functionality of all code developed
+ using these systems.**
+
+ :param mapper: a mapped class, or the actual :class:`_orm.Mapper`
+ object,
+ representing the single kind of object represented within the mapping
+ list.
+
+ :param mappings: a sequence of dictionaries, each one containing the
+ state of the mapped row to be inserted, in terms of the attribute
+ names on the mapped class. If the mapping refers to multiple tables,
+ such as a joined-inheritance mapping, each dictionary must contain all
+ keys to be populated into all tables.
+
+ :param return_defaults: when True, rows that are missing values which
+ generate defaults, namely integer primary key defaults and sequences,
+ will be inserted **one at a time**, so that the primary key value
+ is available. In particular this will allow joined-inheritance
+ and other multi-table mappings to insert correctly without the need
+ to provide primary
+ key values ahead of time; however,
+ :paramref:`.Session.bulk_insert_mappings.return_defaults`
+ **greatly reduces the performance gains** of the method overall.
+ If the rows
+ to be inserted only refer to a single table, then there is no
+ reason this flag should be set as the returned default information
+ is not used.
+
+ :param render_nulls: When True, a value of ``None`` will result
+ in a NULL value being included in the INSERT statement, rather
+ than the column being omitted from the INSERT. This allows all
+ the rows being INSERTed to have the identical set of columns, so that
+ the full set of rows can be batched to the DBAPI. Normally,
+ each column-set that contains a different combination of NULL values
+ than the previous row must omit a different series of columns from
+ the rendered INSERT statement, which means it must be emitted as a
+ separate statement. By passing this flag, the full set of rows
+ are guaranteed to be batchable into one batch; the cost however is
+ that server-side defaults which are invoked by an omitted column will
+ be skipped, so care must be taken to ensure that these are not
+ necessary.
+
+ .. warning::
+
+ When this flag is set, **server side default SQL values will
+ not be invoked** for those columns that are inserted as NULL;
+ the NULL value will be sent explicitly. Care must be taken
+ to ensure that no server-side default functions need to be
+ invoked for the operation as a whole.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`bulk_operations`
+
+ :meth:`.Session.bulk_save_objects`
+
+ :meth:`.Session.bulk_update_mappings`
+
+ """
+ self._bulk_save_mappings(
+ mapper,
+ mappings,
+ False,
+ False,
+ return_defaults,
+ False,
+ render_nulls,
+ )
+
+ def bulk_update_mappings(self, mapper, mappings):
+ """Perform a bulk update of the given list of mapping dictionaries.
+
+ The bulk update feature allows plain Python dictionaries to be used as
+ the source of simple UPDATE operations which can be more easily
+ grouped together into higher performing "executemany"
+ operations. Using dictionaries, there are no "history" or session
+ state management features in use, reducing latency when updating
+ large numbers of simple rows.
+
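+ For example, a minimal sketch (``User`` and its column names are
+ hypothetical; the primary key values select the rows to UPDATE)::
+
+     session.bulk_update_mappings(
+         User,
+         [
+             {"id": 1, "name": "spongebob squarepants"},
+             {"id": 2, "name": "sandy cheeks"},
+         ],
+     )
+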
+ .. versionadded:: 1.0.0
+
+ .. warning::
+
+ The bulk update feature allows for a lower-latency UPDATE
+ of rows at the expense of most other unit-of-work features.
+ Features such as object management, relationship handling,
+ and SQL clause support are **silently omitted** in favor of raw
+ UPDATES of records.
+
+ Please note that newer versions of SQLAlchemy are **greatly
+ improving the efficiency** of the standard flush process. It is
+ **strongly recommended** to not use the bulk methods as they
+ represent a forking of SQLAlchemy's functionality and are slowly
+ being moved into legacy status. New features such as
+ :ref:`orm_dml_returning_objects` are both more efficient than
+ the "bulk" methods and provide more predictable functionality.
+
+ **Please read the list of caveats at**
+ :ref:`bulk_operations_caveats` **before using this method, and
+ fully test and confirm the functionality of all code developed
+ using these systems.**
+
+ :param mapper: a mapped class, or the actual :class:`_orm.Mapper`
+ object,
+ representing the single kind of object represented within the mapping
+ list.
+
+ :param mappings: a sequence of dictionaries, each one containing the
+ state of the mapped row to be updated, in terms of the attribute names
+ on the mapped class. If the mapping refers to multiple tables, such
+ as a joined-inheritance mapping, each dictionary may contain keys
+ corresponding to all tables. All those keys which are present and
+ are not part of the primary key are applied to the SET clause of the
+ UPDATE statement; the primary key values, which are required, are
+ applied to the WHERE clause.
+
+
+ .. seealso::
+
+ :ref:`bulk_operations`
+
+ :meth:`.Session.bulk_insert_mappings`
+
+ :meth:`.Session.bulk_save_objects`
+
+ """
+ self._bulk_save_mappings(
+ mapper, mappings, True, False, False, False, False
+ )
+
+ def _bulk_save_mappings(
+ self,
+ mapper,
+ mappings,
+ isupdate,
+ isstates,
+ return_defaults,
+ update_changed_only,
+ render_nulls,
+ ):
+ mapper = _class_to_mapper(mapper)
+ self._flushing = True
+
+ transaction = self.begin(_subtrans=True)
+ try:
+ if isupdate:
+ persistence._bulk_update(
+ mapper,
+ mappings,
+ transaction,
+ isstates,
+ update_changed_only,
+ )
+ else:
+ persistence._bulk_insert(
+ mapper,
+ mappings,
+ transaction,
+ isstates,
+ return_defaults,
+ render_nulls,
+ )
+ transaction.commit()
+
+ except:
+ with util.safe_reraise():
+ transaction.rollback(_capture_exception=True)
+ finally:
+ self._flushing = False
+
+ def is_modified(self, instance, include_collections=True):
+ r"""Return ``True`` if the given instance has locally
+ modified attributes.
+
+ This method retrieves the history for each instrumented
+ attribute on the instance and performs a comparison of the current
+ value to its previously committed value, if any.
+
+ It is in effect a more expensive and accurate
+ version of checking for the given instance in the
+ :attr:`.Session.dirty` collection; a full test for
+ each attribute's net "dirty" status is performed.
+
+ E.g.::
+
+ return session.is_modified(someobject)
+
+ A few caveats to this method apply:
+
+ * Instances present in the :attr:`.Session.dirty` collection may
+ report ``False`` when tested with this method. This is because
+ the object may have received change events via attribute mutation,
+ thus placing it in :attr:`.Session.dirty`, but ultimately the state
+ is the same as that loaded from the database, resulting in no net
+ change here.
+ * Scalar attributes may not have recorded the previously set
+ value when a new value was applied, if the attribute was not loaded,
+ or was expired, at the time the new value was received - in these
+ cases, the attribute is assumed to have a change, even if there is
+ ultimately no net change against its database value. SQLAlchemy in
+ most cases does not need the "old" value when a set event occurs, so
+ it skips the expense of a SQL call if the old value isn't present,
+ based on the assumption that an UPDATE of the scalar value is
+ usually needed, and in those few cases where it isn't, is less
+ expensive on average than issuing a defensive SELECT.
+
+ The "old" value is fetched unconditionally upon set only if the
+ attribute container has the ``active_history`` flag set to ``True``.
+ This flag is set typically for primary key attributes and scalar
+ object references that are not a simple many-to-one. To set this
+ flag for any arbitrary mapped column, use the ``active_history``
+ argument with :func:`.column_property`.
+
+ :param instance: mapped instance to be tested for pending changes.
+ :param include_collections: Indicates if multivalued collections
+ should be included in the operation. Setting this to ``False`` is a
+ way to detect only local-column based properties (i.e. scalar columns
+ or many-to-one foreign keys) that would result in an UPDATE for this
+ instance upon flush.
+
+ """
+ state = object_state(instance)
+
+ if not state.modified:
+ return False
+
+ dict_ = state.dict
+
+ for attr in state.manager.attributes:
+ if (
+ not include_collections
+ and hasattr(attr.impl, "get_collection")
+ ) or not hasattr(attr.impl, "get_history"):
+ continue
+
+ (added, unchanged, deleted) = attr.impl.get_history(
+ state, dict_, passive=attributes.NO_CHANGE
+ )
+
+ if added or deleted:
+ return True
+ else:
+ return False
+
+ @property
+ def is_active(self):
+ """True if this :class:`.Session` not in "partial rollback" state.
+
+ .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins
+ a new transaction immediately, so this attribute will be False
+ when the :class:`_orm.Session` is first instantiated.
+
+ "partial rollback" state typically indicates that the flush process
+ of the :class:`_orm.Session` has failed, and that the
+ :meth:`_orm.Session.rollback` method must be emitted in order to
+ fully roll back the transaction.
+
+ If this :class:`_orm.Session` is not in a transaction at all, the
+ :class:`_orm.Session` will autobegin when it is first used, so in this
+ case :attr:`_orm.Session.is_active` will return True.
+
+ Otherwise, if this :class:`_orm.Session` is within a transaction,
+ and that transaction has not been rolled back internally, the
+ :attr:`_orm.Session.is_active` will also return True.
+
+ .. seealso::
+
+ :ref:`faq_session_rollback`
+
+ :meth:`_orm.Session.in_transaction`
+
+ """
+ if self.autocommit:
+ return (
+ self._transaction is not None and self._transaction.is_active
+ )
+ else:
+ return self._transaction is None or self._transaction.is_active
+
+ identity_map = None
+ """A mapping of object identities to objects themselves.
+
+ Iterating through ``Session.identity_map.values()`` provides
+ access to the full set of persistent objects (i.e., those
+ that have row identity) currently in the session.
+
+ .. seealso::
+
+ :func:`.identity_key` - helper function to produce the keys used
+ in this dictionary.
+
+ """
+
+ @property
+ def _dirty_states(self):
+ """The set of all persistent states considered dirty.
+
+ This method returns all states that were modified including
+ those that were possibly deleted.
+
+ """
+ return self.identity_map._dirty_states()
+
+ @property
+ def dirty(self):
+ """The set of all persistent instances considered dirty.
+
+ E.g.::
+
+ some_mapped_object in session.dirty
+
+ Instances are considered dirty when they were modified but not
+ deleted.
+
+ Note that this 'dirty' calculation is 'optimistic'; most
+ attribute-setting or collection modification operations will
+ mark an instance as 'dirty' and place it in this set, even if
+ there is no net change to the attribute's value. At flush
+ time, the value of each attribute is compared to its
+ previously saved value, and if there's no net change, no SQL
+ operation will occur (this is a more expensive operation so
+ it's only done at flush time).
+
+ To check if an instance has actionable net changes to its
+ attributes, use the :meth:`.Session.is_modified` method.
+
+ """
+ return util.IdentitySet(
+ [
+ state.obj()
+ for state in self._dirty_states
+ if state not in self._deleted
+ ]
+ )
+
+ @property
+ def deleted(self):
+ "The set of all instances marked as 'deleted' within this ``Session``"
+
+ return util.IdentitySet(list(self._deleted.values()))
+
+ @property
+ def new(self):
+ "The set of all instances marked as 'new' within this ``Session``."
+
+ return util.IdentitySet(list(self._new.values()))
+
+
+class sessionmaker(_SessionClassMethods):
+ """A configurable :class:`.Session` factory.
+
+ The :class:`.sessionmaker` factory generates new
+ :class:`.Session` objects when called, creating them given
+ the configurational arguments established here.
+
+ e.g.::
+
+ from sqlalchemy import create_engine
+ from sqlalchemy.orm import sessionmaker
+
+ # an Engine, which the Session will use for connection
+ # resources
+ engine = create_engine('postgresql://scott:tiger@localhost/')
+
+ Session = sessionmaker(engine)
+
+ with Session() as session:
+ session.add(some_object)
+ session.add(some_other_object)
+ session.commit()
+
+ Context manager use is optional; otherwise, the returned
+ :class:`_orm.Session` object may be closed explicitly via the
+ :meth:`_orm.Session.close` method. Using a
+ ``try:/finally:`` block is optional, however will ensure that the close
+ takes place even if there are database errors::
+
+ session = Session()
+ try:
+ session.add(some_object)
+ session.add(some_other_object)
+ session.commit()
+ finally:
+ session.close()
+
+ :class:`.sessionmaker` acts as a factory for :class:`_orm.Session`
+ objects in the same way as an :class:`_engine.Engine` acts as a factory
+ for :class:`_engine.Connection` objects. In this way it also includes
+ a :meth:`_orm.sessionmaker.begin` method, that provides a context
+ manager which both begins and commits a transaction, as well as closes
+ out the :class:`_orm.Session` when complete, rolling back the transaction
+ if any errors occur::
+
+ Session = sessionmaker(engine)
+
+ with Session.begin() as session:
+ session.add(some_object)
+ session.add(some_other_object)
+ # commits transaction, closes session
+
+ .. versionadded:: 1.4
+
+ When calling upon :class:`_orm.sessionmaker` to construct a
+ :class:`_orm.Session`, keyword arguments may also be passed to the
+ method; these arguments will override that of the globally configured
+ parameters. Below we use a :class:`_orm.sessionmaker` bound to a certain
+ :class:`_engine.Engine` to produce a :class:`_orm.Session` that is instead
+ bound to a specific :class:`_engine.Connection` procured from that engine::
+
+ Session = sessionmaker(engine)
+
+ # bind an individual session to a connection
+
+ with engine.connect() as connection:
+ with Session(bind=connection) as session:
+ # work with session
+
+ The class also includes a method :meth:`_orm.sessionmaker.configure`, which
+ can be used to specify additional keyword arguments to the factory, which
+ will take effect for subsequent :class:`.Session` objects generated. This
+ is usually used to associate one or more :class:`_engine.Engine` objects
+ with an existing
+ :class:`.sessionmaker` factory before it is first used::
+
+ # application starts, sessionmaker does not have
+ # an engine bound yet
+ Session = sessionmaker()
+
+ # ... later, when an engine URL is read from a configuration
+ # file or other events allow the engine to be created
+ engine = create_engine('sqlite:///foo.db')
+ Session.configure(bind=engine)
+
+ sess = Session()
+ # work with session
+
+ .. seealso::
+
+ :ref:`session_getting` - introductory text on creating
+ sessions using :class:`.sessionmaker`.
+
+ """
+
+ def __init__(
+ self,
+ bind=None,
+ class_=Session,
+ autoflush=True,
+ autocommit=False,
+ expire_on_commit=True,
+ info=None,
+ **kw
+ ):
+ r"""Construct a new :class:`.sessionmaker`.
+
+ All arguments here except for ``class_`` correspond to arguments
+ accepted by :class:`.Session` directly. See the
+ :meth:`.Session.__init__` docstring for more details on parameters.
+
+ :param bind: a :class:`_engine.Engine` or other :class:`.Connectable`
+ with
+ which newly created :class:`.Session` objects will be associated.
+ :param class\_: class to use in order to create new :class:`.Session`
+ objects. Defaults to :class:`.Session`.
+ :param autoflush: The autoflush setting to use with newly created
+ :class:`.Session` objects.
+ :param autocommit: The autocommit setting to use with newly created
+ :class:`.Session` objects.
+ :param expire_on_commit=True: the
+ :paramref:`_orm.Session.expire_on_commit` setting to use
+ with newly created :class:`.Session` objects.
+
+ :param info: optional dictionary of information that will be available
+ via :attr:`.Session.info`. Note this dictionary is *updated*, not
+ replaced, when the ``info`` parameter is specified to the specific
+ :class:`.Session` construction operation.
+
+ :param \**kw: all other keyword arguments are passed to the
+ constructor of newly created :class:`.Session` objects.
+
+ """
+ kw["bind"] = bind
+ kw["autoflush"] = autoflush
+ kw["autocommit"] = autocommit
+ kw["expire_on_commit"] = expire_on_commit
+ if info is not None:
+ kw["info"] = info
+ self.kw = kw
+ # make our own subclass of the given class, so that
+ # events can be associated with it specifically.
+ self.class_ = type(class_.__name__, (class_,), {})
+
+ def begin(self):
+ """Produce a context manager that both provides a new
+ :class:`_orm.Session` as well as a transaction that commits.
+
+
+ e.g.::
+
+ Session = sessionmaker(some_engine)
+
+ with Session.begin() as session:
+ session.add(some_object)
+
+ # commits transaction, closes session
+
+ .. versionadded:: 1.4
+
+
+ """
+
+ session = self()
+ return session._maker_context_manager()
+
+ def __call__(self, **local_kw):
+ """Produce a new :class:`.Session` object using the configuration
+ established in this :class:`.sessionmaker`.
+
+ In Python, the ``__call__`` method is invoked on an object when
+ it is "called" in the same way as a function::
+
+ Session = sessionmaker()
+ session = Session() # invokes sessionmaker.__call__()
+
+ """
+ for k, v in self.kw.items():
+ if k == "info" and "info" in local_kw:
+ d = v.copy()
+ d.update(local_kw["info"])
+ local_kw["info"] = d
+ else:
+ local_kw.setdefault(k, v)
+ return self.class_(**local_kw)
+
+ def configure(self, **new_kw):
+ """(Re)configure the arguments for this sessionmaker.
+
+ e.g.::
+
+ Session = sessionmaker()
+
+ Session.configure(bind=create_engine('sqlite://'))
+ """
+ self.kw.update(new_kw)
+
+ def __repr__(self):
+ return "%s(class_=%r, %s)" % (
+ self.__class__.__name__,
+ self.class_.__name__,
+ ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()),
+ )
+
+
+def close_all_sessions():
+ """Close all sessions in memory.
+
+ This function consults a global registry of all :class:`.Session` objects
+ and calls :meth:`.Session.close` on them, which resets them to a clean
+ state.
+
+ This function is not for general use but may be useful for test suites
+ within the teardown scheme.
+
+ .. versionadded:: 1.3
+
+ """
+
+ for sess in _sessions.values():
+ sess.close()
+
+
+def make_transient(instance):
+ """Alter the state of the given instance so that it is :term:`transient`.
+
+ .. note::
+
+ :func:`.make_transient` is a special-case function for
+ advanced use cases only.
+
+ The given mapped instance is assumed to be in the :term:`persistent` or
+ :term:`detached` state. The function will remove its association with any
+ :class:`.Session` as well as its :attr:`.InstanceState.identity`. The
+ effect is that the object will behave as though it were newly constructed,
+ except retaining any attribute / collection values that were loaded at the
+ time of the call. The :attr:`.InstanceState.deleted` flag is also reset
+ if this object had been deleted as a result of using
+ :meth:`.Session.delete`.
+
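+ For example, a minimal sketch (``user`` is a hypothetical persistent
+ instance)::
+
+     make_transient(user)
+     # ``user`` is now detached from its Session and has no identity key;
+     # adding it back and flushing would INSERT a new row
+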
+ .. warning::
+
+ :func:`.make_transient` does **not** "unexpire" or otherwise eagerly
+ load ORM-mapped attributes that are not currently loaded at the time
+ the function is called. This includes attributes which:
+
+ * were expired via :meth:`.Session.expire`
+
+ * were expired as the natural effect of committing a session
+ transaction, e.g. :meth:`.Session.commit`
+
+ * are normally :term:`lazy loaded` but are not currently loaded
+
+ * are "deferred" via :ref:`deferred` and are not yet loaded
+
+ * were not present in the query which loaded this object, as is
+ common in joined table inheritance and other scenarios.
+
+ After :func:`.make_transient` is called, unloaded attributes such
+ as those above will normally resolve to the value ``None`` when
+ accessed, or an empty collection for a collection-oriented attribute.
+ As the object is transient and un-associated with any database
+ identity, it will no longer retrieve these values.
+
+ .. seealso::
+
+ :func:`.make_transient_to_detached`
+
+ """
+ state = attributes.instance_state(instance)
+ s = _state_session(state)
+ if s:
+ s._expunge_states([state])
+
+ # remove expired state
+ state.expired_attributes.clear()
+
+ # remove deferred callables
+ if state.callables:
+ del state.callables
+
+ if state.key:
+ del state.key
+ if state._deleted:
+ del state._deleted
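+
+ # Illustrative sketch (added commentary, not part of the original
+ # source): a persistent object loses its identity key and Session
+ # association; adding it to a session again makes it pending, to be
+ # INSERTed on the next flush.  ``session`` / ``User`` are assumed.
+ #
+ #   user = session.get(User, 5)      # persistent
+ #   make_transient(user)
+ #   assert inspect(user).transient
+ #   session.add(user)                # pending again; next flush INSERTs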
+
+
+def make_transient_to_detached(instance):
+ """Make the given transient instance :term:`detached`.
+
+ .. note::
+
+ :func:`.make_transient_to_detached` is a special-case function for
+ advanced use cases only.
+
+ All attribute history on the given instance
+ will be reset as though the instance were freshly loaded
+ from a query. Missing attributes will be marked as expired.
+ The primary key attributes of the object, which are required, will be made
+ into the "key" of the instance.
+
+ The object can then be added to a session, or merged,
+ possibly with the ``load=False`` flag, at which point it will look
+ as if it were loaded that way, without emitting SQL.
+
+ This is a special use case function that differs from a normal
+ call to :meth:`.Session.merge` in that a given persistent state
+ can be manufactured without any SQL calls.
+
+ .. seealso::
+
+ :func:`.make_transient`
+
+ :meth:`.Session.enable_relationship_loading`
+
+ """
+ state = attributes.instance_state(instance)
+ if state.session_id or state.key:
+ raise sa_exc.InvalidRequestError("Given object must be transient")
+ state.key = state.mapper._identity_key_from_state(state)
+ if state._deleted:
+ del state._deleted
+ state._commit_all(state.dict)
+ state._expire_attributes(state.dict, state.unloaded_expirable)
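+
+ # Illustrative sketch (added commentary, not part of the original
+ # source): an object constructed with its primary key already set can
+ # be turned into a "detached" instance and placed into a Session
+ # without any SELECT being emitted; ``User`` / ``session`` are assumed.
+ #
+ #   user = User(id=5, name="ed")          # transient
+ #   make_transient_to_detached(user)       # detached, identity key assigned
+ #   session.add(user)                      # persistent; no SQL emitted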
+
+
+def object_session(instance):
+ """Return the :class:`.Session` to which the given instance belongs.
+
+ This is essentially the same as the :attr:`.InstanceState.session`
+ accessor. See that attribute for details.
+
+ """
+
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+ else:
+ return _state_session(state)
+
+
+_new_sessionid = util.counter()
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
new file mode 100644
index 0000000..9718024
--- /dev/null
+++ b/lib/sqlalchemy/orm/state.py
@@ -0,0 +1,1025 @@
+# orm/state.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Defines instrumentation of instances.
+
+This module is usually not directly visible to user applications, but
+defines a large part of the ORM's interactivity.
+
+"""
+
+import weakref
+
+from . import base
+from . import exc as orm_exc
+from . import interfaces
+from .base import ATTR_WAS_SET
+from .base import INIT_OK
+from .base import NEVER_SET
+from .base import NO_VALUE
+from .base import PASSIVE_NO_INITIALIZE
+from .base import PASSIVE_NO_RESULT
+from .base import PASSIVE_OFF
+from .base import SQL_OK
+from .path_registry import PathRegistry
+from .. import exc as sa_exc
+from .. import inspection
+from .. import util
+
+
+# late-populated by session.py
+_sessions = None
+
+# optionally late-provided by sqlalchemy.ext.asyncio.session
+_async_provider = None
+
+
+@inspection._self_inspects
+class InstanceState(interfaces.InspectionAttrInfo):
+ """tracks state information at the instance level.
+
+ The :class:`.InstanceState` is a key object used by the
+ SQLAlchemy ORM in order to track the state of an object;
+ it is created the moment an object is instantiated, typically
+ as a result of :term:`instrumentation` which SQLAlchemy applies
+ to the ``__init__()`` method of the class.
+
+ :class:`.InstanceState` is also a semi-public object,
+ available for runtime inspection as to the state of a
+ mapped instance, including information such as its current
+ status within a particular :class:`.Session` and details
+ about data on individual attributes. The public API
+ for acquiring a :class:`.InstanceState` object
+ is the :func:`_sa.inspect` system::
+
+ >>> from sqlalchemy import inspect
+ >>> insp = inspect(some_mapped_object)
+ >>> insp.attrs.nickname.history
+ History(added=['new nickname'], unchanged=(), deleted=['nickname'])
+
+ .. seealso::
+
+ :ref:`orm_mapper_inspection_instancestate`
+
+ """
+
+ session_id = None
+ key = None
+ runid = None
+ load_options = util.EMPTY_SET
+ load_path = PathRegistry.root
+ insert_order = None
+ _strong_obj = None
+ modified = False
+ expired = False
+ _deleted = False
+ _load_pending = False
+ _orphaned_outside_of_session = False
+ is_instance = True
+ identity_token = None
+ _last_known_values = ()
+
+ callables = ()
+ """A namespace where a per-state loader callable can be associated.
+
+ In SQLAlchemy 1.0, this is only used for lazy loaders / deferred
+ loaders that were set up via query option.
+
+ Previously, callables was also used to indicate expired attributes,
+ by storing a link to the InstanceState itself in this dictionary.
+ This role is now handled by the expired_attributes set.
+
+ """
+
+ def __init__(self, obj, manager):
+ self.class_ = obj.__class__
+ self.manager = manager
+ self.obj = weakref.ref(obj, self._cleanup)
+ self.committed_state = {}
+ self.expired_attributes = set()
+
+ expired_attributes = None
+ """The set of keys which are 'expired' to be loaded by
+ the manager's deferred scalar loader, assuming no pending
+ changes.
+
+ see also the ``unmodified`` collection which is intersected
+ against this set when a refresh operation occurs."""
+
+ @util.memoized_property
+ def attrs(self):
+ """Return a namespace representing each attribute on
+ the mapped object, including its current value
+ and history.
+
+ The returned object is an instance of :class:`.AttributeState`.
+ This object allows inspection of the current data
+ within an attribute as well as attribute history
+ since the last flush.
+
+ """
+ return util.ImmutableProperties(
+ dict((key, AttributeState(self, key)) for key in self.manager)
+ )
+
+ @property
+ def transient(self):
+ """Return ``True`` if the object is :term:`transient`.
+
+ .. seealso::
+
+ :ref:`session_object_states`
+
+ """
+ return self.key is None and not self._attached
+
+ @property
+ def pending(self):
+ """Return ``True`` if the object is :term:`pending`.
+
+
+ .. seealso::
+
+ :ref:`session_object_states`
+
+ """
+ return self.key is None and self._attached
+
+ @property
+ def deleted(self):
+ """Return ``True`` if the object is :term:`deleted`.
+
+ An object that is in the deleted state is guaranteed to
+ not be within the :attr:`.Session.identity_map` of its parent
+ :class:`.Session`; however if the session's transaction is rolled
+ back, the object will be restored to the persistent state and
+ the identity map.
+
+ .. note::
+
+ The :attr:`.InstanceState.deleted` attribute refers to a specific
+ state of the object that occurs between the "persistent" and
+ "detached" states; once the object is :term:`detached`, the
+ :attr:`.InstanceState.deleted` attribute **no longer returns
+ True**; in order to detect that a state was deleted, regardless
+ of whether or not the object is associated with a
+ :class:`.Session`, use the :attr:`.InstanceState.was_deleted`
+ accessor.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_object_states`
+
+ """
+ return self.key is not None and self._attached and self._deleted
+
+ @property
+ def was_deleted(self):
+ """Return True if this object is or was previously in the
+ "deleted" state and has not been reverted to persistent.
+
+ This flag returns True once the object was deleted in flush.
+ When the object is expunged from the session either explicitly
+ or via transaction commit and enters the "detached" state,
+ this flag will continue to report True.
+
+ .. versionadded:: 1.1 - added a local method form of
+ :func:`.orm.util.was_deleted`.
+
+ .. seealso::
+
+ :attr:`.InstanceState.deleted` - refers to the "deleted" state
+
+ :func:`.orm.util.was_deleted` - standalone function
+
+ :ref:`session_object_states`
+
+ """
+ return self._deleted
+
+ @property
+ def persistent(self):
+ """Return ``True`` if the object is :term:`persistent`.
+
+ An object that is in the persistent state is guaranteed to
+ be within the :attr:`.Session.identity_map` of its parent
+ :class:`.Session`.
+
+ .. versionchanged:: 1.1 The :attr:`.InstanceState.persistent`
+ accessor no longer returns True for an object that was
+ "deleted" within a flush; use the :attr:`.InstanceState.deleted`
+ accessor to detect this state. This allows the "persistent"
+ state to guarantee membership in the identity map.
+
+ .. seealso::
+
+ :ref:`session_object_states`
+
+ """
+ return self.key is not None and self._attached and not self._deleted
+
+ @property
+ def detached(self):
+ """Return ``True`` if the object is :term:`detached`.
+
+ .. seealso::
+
+ :ref:`session_object_states`
+
+ """
+ return self.key is not None and not self._attached
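+
+ # Illustrative sketch (added commentary, not part of the original
+ # source): the state accessors above can be checked at runtime on any
+ # mapped instance; ``obj`` is an assumed placeholder.
+ #
+ #   from sqlalchemy import inspect
+ #
+ #   insp = inspect(obj)
+ #   insp.transient    # constructed, no identity key, no Session
+ #   insp.pending      # in a Session, not yet flushed
+ #   insp.persistent   # identity key present and in a Session
+ #   insp.detached     # identity key present, no Session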
+
+ @property
+ @util.preload_module("sqlalchemy.orm.session")
+ def _attached(self):
+ return (
+ self.session_id is not None
+ and self.session_id in util.preloaded.orm_session._sessions
+ )
+
+ def _track_last_known_value(self, key):
+ """Track the last known value of a particular key after expiration
+ operations.
+
+ .. versionadded:: 1.3
+
+ """
+
+ if key not in self._last_known_values:
+ self._last_known_values = dict(self._last_known_values)
+ self._last_known_values[key] = NO_VALUE
+
+ @property
+ def session(self):
+ """Return the owning :class:`.Session` for this instance,
+ or ``None`` if none available.
+
+ Note that the result here can in some cases be *different*
+ from that of ``obj in session``; an object that's been deleted
+ will report as not ``in session``, however if the transaction is
+ still in progress, this attribute will still refer to that session.
+ Only when the transaction is completed does the object become
+ fully detached under normal circumstances.
+
+ .. seealso::
+
+ :attr:`_orm.InstanceState.async_session`
+
+ """
+ if self.session_id:
+ try:
+ return _sessions[self.session_id]
+ except KeyError:
+ pass
+ return None
+
+ @property
+ def async_session(self):
+ """Return the owning :class:`_asyncio.AsyncSession` for this instance,
+ or ``None`` if none available.
+
+ This attribute is only non-None when the :mod:`sqlalchemy.ext.asyncio`
+ API is in use for this ORM object. The returned
+ :class:`_asyncio.AsyncSession` object will be a proxy for the
+ :class:`_orm.Session` object that would be returned from the
+ :attr:`_orm.InstanceState.session` attribute for this
+ :class:`_orm.InstanceState`.
+
+ .. versionadded:: 1.4.18
+
+ .. seealso::
+
+ :ref:`asyncio_toplevel`
+
+ """
+ if _async_provider is None:
+ return None
+
+ sess = self.session
+ if sess is not None:
+ return _async_provider(sess)
+ else:
+ return None
+
+ @property
+ def object(self):
+ """Return the mapped object represented by this
+ :class:`.InstanceState`."""
+ return self.obj()
+
+ @property
+ def identity(self):
+ """Return the mapped identity of the mapped object.
+ This is the primary key identity as persisted by the ORM,
+ which can always be passed directly to
+ :meth:`_query.Query.get`.
+
+ Returns ``None`` if the object has no primary key identity.
+
+ .. note::
+ An object which is :term:`transient` or :term:`pending`
+ does **not** have a mapped identity until it is flushed,
+ even if its attributes include primary key values.
+
+ """
+ if self.key is None:
+ return None
+ else:
+ return self.key[1]
+
+ @property
+ def identity_key(self):
+ """Return the identity key for the mapped object.
+
+ This is the key used to locate the object within
+ the :attr:`.Session.identity_map` mapping. It contains
+ the identity returned by :attr:`.identity`.
+
+
+ """
+ # TODO: just change .key to .identity_key across
+ # the board ? probably
+ return self.key
+
+ @util.memoized_property
+ def parents(self):
+ return {}
+
+ @util.memoized_property
+ def _pending_mutations(self):
+ return {}
+
+ @util.memoized_property
+ def _empty_collections(self):
+ return {}
+
+ @util.memoized_property
+ def mapper(self):
+ """Return the :class:`_orm.Mapper` used for this mapped object."""
+ return self.manager.mapper
+
+ @property
+ def has_identity(self):
+ """Return ``True`` if this object has an identity key.
+
+ This should always have the same value as the
+ expression ``state.persistent`` or ``state.detached``.
+
+ """
+ return bool(self.key)
+
+ @classmethod
+ def _detach_states(self, states, session, to_transient=False):
+ persistent_to_detached = (
+ session.dispatch.persistent_to_detached or None
+ )
+ deleted_to_detached = session.dispatch.deleted_to_detached or None
+ pending_to_transient = session.dispatch.pending_to_transient or None
+ persistent_to_transient = (
+ session.dispatch.persistent_to_transient or None
+ )
+
+ for state in states:
+ deleted = state._deleted
+ pending = state.key is None
+ persistent = not pending and not deleted
+
+ state.session_id = None
+
+ if to_transient and state.key:
+ del state.key
+ if persistent:
+ if to_transient:
+ if persistent_to_transient is not None:
+ persistent_to_transient(session, state)
+ elif persistent_to_detached is not None:
+ persistent_to_detached(session, state)
+ elif deleted and deleted_to_detached is not None:
+ deleted_to_detached(session, state)
+ elif pending and pending_to_transient is not None:
+ pending_to_transient(session, state)
+
+ state._strong_obj = None
+
+ def _detach(self, session=None):
+ if session:
+ InstanceState._detach_states([self], session)
+ else:
+ self.session_id = self._strong_obj = None
+
+ def _dispose(self):
+ self._detach()
+ del self.obj
+
+ def _cleanup(self, ref):
+ """Weakref callback cleanup.
+
+ This callable cleans out the state when it is being garbage
+ collected.
+
+ this _cleanup **assumes** that there are no strong refs to us!
+ Will not work otherwise!
+
+ """
+
+ # Python builtins become undefined during interpreter shutdown.
+ # Guard against exceptions during this phase, as the method cannot
+ # proceed in any case if builtins have been undefined.
+ if dict is None:
+ return
+
+ instance_dict = self._instance_dict()
+ if instance_dict is not None:
+ instance_dict._fast_discard(self)
+ del self._instance_dict
+
+ # we can't possibly be in instance_dict._modified
+ # b.c. this is weakref cleanup only, that set
+ # is strong referencing!
+ # assert self not in instance_dict._modified
+
+ self.session_id = self._strong_obj = None
+ del self.obj
+
+ def obj(self):
+ return None
+
+ @property
+ def dict(self):
+ """Return the instance dict used by the object.
+
+ Under normal circumstances, this is always synonymous
+ with the ``__dict__`` attribute of the mapped object,
+ unless an alternative instrumentation system has been
+ configured.
+
+ In the case that the actual object has been garbage
+ collected, this accessor returns a blank dictionary.
+
+ """
+ o = self.obj()
+ if o is not None:
+ return base.instance_dict(o)
+ else:
+ return {}
+
+ def _initialize_instance(*mixed, **kwargs):
+ self, instance, args = mixed[0], mixed[1], mixed[2:] # noqa
+ manager = self.manager
+
+ manager.dispatch.init(self, args, kwargs)
+
+ try:
+ return manager.original_init(*mixed[1:], **kwargs)
+ except:
+ with util.safe_reraise():
+ manager.dispatch.init_failure(self, args, kwargs)
+
+ def get_history(self, key, passive):
+ return self.manager[key].impl.get_history(self, self.dict, passive)
+
+ def get_impl(self, key):
+ return self.manager[key].impl
+
+ def _get_pending_mutation(self, key):
+ if key not in self._pending_mutations:
+ self._pending_mutations[key] = PendingCollection()
+ return self._pending_mutations[key]
+
+ def __getstate__(self):
+ state_dict = {"instance": self.obj()}
+ state_dict.update(
+ (k, self.__dict__[k])
+ for k in (
+ "committed_state",
+ "_pending_mutations",
+ "modified",
+ "expired",
+ "callables",
+ "key",
+ "parents",
+ "load_options",
+ "class_",
+ "expired_attributes",
+ "info",
+ )
+ if k in self.__dict__
+ )
+ if self.load_path:
+ state_dict["load_path"] = self.load_path.serialize()
+
+ state_dict["manager"] = self.manager._serialize(self, state_dict)
+
+ return state_dict
+
+ def __setstate__(self, state_dict):
+ inst = state_dict["instance"]
+ if inst is not None:
+ self.obj = weakref.ref(inst, self._cleanup)
+ self.class_ = inst.__class__
+ else:
+ # None being possible here is generally new as of 0.7.4,
+ # due to storage of state in "parents". "class_"
+ # is also new.
+ self.obj = None
+ self.class_ = state_dict["class_"]
+
+ self.committed_state = state_dict.get("committed_state", {})
+ self._pending_mutations = state_dict.get("_pending_mutations", {})
+ self.parents = state_dict.get("parents", {})
+ self.modified = state_dict.get("modified", False)
+ self.expired = state_dict.get("expired", False)
+ if "info" in state_dict:
+ self.info.update(state_dict["info"])
+ if "callables" in state_dict:
+ self.callables = state_dict["callables"]
+
+ try:
+ self.expired_attributes = state_dict["expired_attributes"]
+ except KeyError:
+ self.expired_attributes = set()
+ # 0.9 and earlier compat
+ for k in list(self.callables):
+ if self.callables[k] is self:
+ self.expired_attributes.add(k)
+ del self.callables[k]
+ else:
+ if "expired_attributes" in state_dict:
+ self.expired_attributes = state_dict["expired_attributes"]
+ else:
+ self.expired_attributes = set()
+
+ self.__dict__.update(
+ [
+ (k, state_dict[k])
+ for k in ("key", "load_options")
+ if k in state_dict
+ ]
+ )
+ if self.key:
+ try:
+ self.identity_token = self.key[2]
+ except IndexError:
+ # 1.1 and earlier compat before identity_token
+ assert len(self.key) == 2
+ self.key = self.key + (None,)
+ self.identity_token = None
+
+ if "load_path" in state_dict:
+ self.load_path = PathRegistry.deserialize(state_dict["load_path"])
+
+ state_dict["manager"](self, inst, state_dict)
+
+ def _reset(self, dict_, key):
+ """Remove the given attribute and any
+ callables associated with it."""
+
+ old = dict_.pop(key, None)
+ if old is not None and self.manager[key].impl.collection:
+ self.manager[key].impl._invalidate_collection(old)
+ self.expired_attributes.discard(key)
+ if self.callables:
+ self.callables.pop(key, None)
+
+ def _copy_callables(self, from_):
+ if "callables" in from_.__dict__:
+ self.callables = dict(from_.callables)
+
+ @classmethod
+ def _instance_level_callable_processor(cls, manager, fn, key):
+ impl = manager[key].impl
+ if impl.collection:
+
+ def _set_callable(state, dict_, row):
+ if "callables" not in state.__dict__:
+ state.callables = {}
+ old = dict_.pop(key, None)
+ if old is not None:
+ impl._invalidate_collection(old)
+ state.callables[key] = fn
+
+ else:
+
+ def _set_callable(state, dict_, row):
+ if "callables" not in state.__dict__:
+ state.callables = {}
+ state.callables[key] = fn
+
+ return _set_callable
+
+ def _expire(self, dict_, modified_set):
+ self.expired = True
+ if self.modified:
+ modified_set.discard(self)
+ self.committed_state.clear()
+ self.modified = False
+
+ self._strong_obj = None
+
+ if "_pending_mutations" in self.__dict__:
+ del self.__dict__["_pending_mutations"]
+
+ if "parents" in self.__dict__:
+ del self.__dict__["parents"]
+
+ self.expired_attributes.update(
+ [impl.key for impl in self.manager._loader_impls]
+ )
+
+ if self.callables:
+ # the per state loader callables we can remove here are
+ # LoadDeferredColumns, which undefers a column at the instance
+ # level that is mapped with deferred, and LoadLazyAttribute,
+ # which lazy loads a relationship at the instance level that
+ # is mapped with "noload" or perhaps "immediateload".
+ # Before 1.4, only column-based
+ # attributes could be considered to be "expired", so here they
+ # were the only ones "unexpired", which means to make them deferred
+ # again. For the moment, as of 1.4 we also apply the same
+ # treatment to relationships now, that is, an instance level lazy
+ # loader is reset in the same way as a column loader.
+ for k in self.expired_attributes.intersection(self.callables):
+ del self.callables[k]
+
+ for k in self.manager._collection_impl_keys.intersection(dict_):
+ collection = dict_.pop(k)
+ collection._sa_adapter.invalidated = True
+
+ if self._last_known_values:
+ self._last_known_values.update(
+ (k, dict_[k]) for k in self._last_known_values if k in dict_
+ )
+
+ for key in self.manager._all_key_set.intersection(dict_):
+ del dict_[key]
+
+ self.manager.dispatch.expire(self, None)
+
+ def _expire_attributes(self, dict_, attribute_names, no_loader=False):
+ pending = self.__dict__.get("_pending_mutations", None)
+
+ callables = self.callables
+
+ for key in attribute_names:
+ impl = self.manager[key].impl
+ if impl.accepts_scalar_loader:
+ if no_loader and (impl.callable_ or key in callables):
+ continue
+
+ self.expired_attributes.add(key)
+ if callables and key in callables:
+ del callables[key]
+ old = dict_.pop(key, NO_VALUE)
+ if impl.collection and old is not NO_VALUE:
+ impl._invalidate_collection(old)
+
+ if (
+ self._last_known_values
+ and key in self._last_known_values
+ and old is not NO_VALUE
+ ):
+ self._last_known_values[key] = old
+
+ self.committed_state.pop(key, None)
+ if pending:
+ pending.pop(key, None)
+
+ self.manager.dispatch.expire(self, attribute_names)
+
+ def _load_expired(self, state, passive):
+ """__call__ allows the InstanceState to act as a deferred
+ callable for loading expired attributes, which is also
+ serializable (picklable).
+
+ """
+
+ if not passive & SQL_OK:
+ return PASSIVE_NO_RESULT
+
+ toload = self.expired_attributes.intersection(self.unmodified)
+ toload = toload.difference(
+ attr
+ for attr in toload
+ if not self.manager[attr].impl.load_on_unexpire
+ )
+
+ self.manager.expired_attribute_loader(self, toload, passive)
+
+ # if the loader failed, or this
+ # instance state didn't have an identity,
+ # the attributes still might be in the callables
+ # dict. ensure they are removed.
+ self.expired_attributes.clear()
+
+ return ATTR_WAS_SET
+
+ @property
+ def unmodified(self):
+ """Return the set of keys which have no uncommitted changes"""
+
+ return set(self.manager).difference(self.committed_state)
+
+ def unmodified_intersection(self, keys):
+ """Return self.unmodified.intersection(keys)."""
+
+ return (
+ set(keys)
+ .intersection(self.manager)
+ .difference(self.committed_state)
+ )
+
+ @property
+ def unloaded(self):
+ """Return the set of keys which do not have a loaded value.
+
+ This includes expired attributes and any other attribute that
+ was never populated or modified.
+
+ """
+ return (
+ set(self.manager)
+ .difference(self.committed_state)
+ .difference(self.dict)
+ )
+
+ @property
+ def unloaded_expirable(self):
+ """Return the set of keys which do not have a loaded value.
+
+ This includes expired attributes and any other attribute that
+ was never populated or modified.
+
+ """
+ return self.unloaded
+
+ @property
+ def _unloaded_non_object(self):
+ return self.unloaded.intersection(
+ attr
+ for attr in self.manager
+ if self.manager[attr].impl.accepts_scalar_loader
+ )
+
+ def _instance_dict(self):
+ return None
+
+ def _modified_event(
+ self, dict_, attr, previous, collection=False, is_userland=False
+ ):
+ if attr:
+ if not attr.send_modified_events:
+ return
+ if is_userland and attr.key not in dict_:
+ raise sa_exc.InvalidRequestError(
+ "Can't flag attribute '%s' modified; it's not present in "
+ "the object state" % attr.key
+ )
+ if attr.key not in self.committed_state or is_userland:
+ if collection:
+ if previous is NEVER_SET:
+ if attr.key in dict_:
+ previous = dict_[attr.key]
+
+ if previous not in (None, NO_VALUE, NEVER_SET):
+ previous = attr.copy(previous)
+ self.committed_state[attr.key] = previous
+
+ if attr.key in self._last_known_values:
+ self._last_known_values[attr.key] = NO_VALUE
+
+ # assert self._strong_obj is None or self.modified
+
+ if (self.session_id and self._strong_obj is None) or not self.modified:
+ self.modified = True
+ instance_dict = self._instance_dict()
+ if instance_dict:
+ has_modified = bool(instance_dict._modified)
+ instance_dict._modified.add(self)
+ else:
+ has_modified = False
+
+ # only create _strong_obj link if attached
+ # to a session
+
+ inst = self.obj()
+ if self.session_id:
+ self._strong_obj = inst
+
+ # if identity map already had modified objects,
+ # assume autobegin already occurred, else check
+ # for autobegin
+ if not has_modified:
+ # inline of autobegin, to ensure session transaction
+ # snapshot is established
+ try:
+ session = _sessions[self.session_id]
+ except KeyError:
+ pass
+ else:
+ if session._transaction is None:
+ session._autobegin()
+
+ if inst is None and attr:
+ raise orm_exc.ObjectDereferencedError(
+ "Can't emit change event for attribute '%s' - "
+ "parent object of type %s has been garbage "
+ "collected."
+ % (self.manager[attr.key], base.state_class_str(self))
+ )
+
+ def _commit(self, dict_, keys):
+ """Commit attributes.
+
+ This is used by a partial-attribute load operation to mark committed
+ those attributes which were refreshed from the database.
+
+ Attributes marked as "expired" can potentially remain "expired" after
+ this step if a value was not populated in state.dict.
+
+ """
+ for key in keys:
+ self.committed_state.pop(key, None)
+
+ self.expired = False
+
+ self.expired_attributes.difference_update(
+ set(keys).intersection(dict_)
+ )
+
+ # the per-keys commit removes object-level callables,
+ # while that of commit_all does not. it's not clear
+ # if this behavior has a clear rationale, however tests do
+ # ensure this is what it does.
+ if self.callables:
+ for key in (
+ set(self.callables).intersection(keys).intersection(dict_)
+ ):
+ del self.callables[key]
+
+ def _commit_all(self, dict_, instance_dict=None):
+ """commit all attributes unconditionally.
+
+ This is used after a flush() or a full load/refresh
+ to remove all pending state from the instance.
+
+ - all attributes are marked as "committed"
+ - the "strong dirty reference" is removed
+ - the "modified" flag is set to False
+ - any "expired" markers for scalar attributes loaded are removed.
+ - lazy load callables for objects / collections *stay*
+
+ Attributes marked as "expired" can potentially remain
+ "expired" after this step if a value was not populated in state.dict.
+
+ """
+ self._commit_all_states([(self, dict_)], instance_dict)
+
+ @classmethod
+ def _commit_all_states(self, iter_, instance_dict=None):
+ """Mass / highly inlined version of commit_all()."""
+
+ for state, dict_ in iter_:
+ state_dict = state.__dict__
+
+ state.committed_state.clear()
+
+ if "_pending_mutations" in state_dict:
+ del state_dict["_pending_mutations"]
+
+ state.expired_attributes.difference_update(dict_)
+
+ if instance_dict and state.modified:
+ instance_dict._modified.discard(state)
+
+ state.modified = state.expired = False
+ state._strong_obj = None
+
+
+class AttributeState(object):
+ """Provide an inspection interface corresponding
+ to a particular attribute on a particular mapped object.
+
+ The :class:`.AttributeState` object is accessed
+ via the :attr:`.InstanceState.attrs` collection
+ of a particular :class:`.InstanceState`::
+
+ from sqlalchemy import inspect
+
+ insp = inspect(some_mapped_object)
+ attr_state = insp.attrs.some_attribute
+
+ """
+
+ def __init__(self, state, key):
+ self.state = state
+ self.key = key
+
+ @property
+ def loaded_value(self):
+ """The current value of this attribute as loaded from the database.
+
+ If the value has not been loaded, or is otherwise not present
+ in the object's dictionary, returns NO_VALUE.
+
+ """
+ return self.state.dict.get(self.key, NO_VALUE)
+
+ @property
+ def value(self):
+ """Return the value of this attribute.
+
+ This operation is equivalent to accessing the object's
+ attribute directly or via ``getattr()``, and will fire
+ off any pending loader callables if needed.
+
+ """
+ return self.state.manager[self.key].__get__(
+ self.state.obj(), self.state.class_
+ )
+
+ @property
+ def history(self):
+ """Return the current **pre-flush** change history for
+ this attribute, via the :class:`.History` interface.
+
+ This method will **not** emit loader callables if the value of the
+ attribute is unloaded.
+
+ .. note::
+
+ The attribute history system tracks changes on a **per flush
+ basis**. Each time the :class:`.Session` is flushed, the history
+ of each attribute is reset to empty. The :class:`.Session` by
+ default autoflushes each time a :class:`_query.Query` is invoked.
+ For
+ options on how to control this, see :ref:`session_flushing`.
+
+
+ .. seealso::
+
+ :meth:`.AttributeState.load_history` - retrieve history
+ using loader callables if the value is not locally present.
+
+ :func:`.attributes.get_history` - underlying function
+
+ """
+ return self.state.get_history(self.key, PASSIVE_NO_INITIALIZE)
+
+ def load_history(self):
+ """Return the current **pre-flush** change history for
+ this attribute, via the :class:`.History` interface.
+
+ This method **will** emit loader callables if the value of the
+ attribute is unloaded.
+
+ .. note::
+
+ The attribute history system tracks changes on a **per flush
+ basis**. Each time the :class:`.Session` is flushed, the history
+ of each attribute is reset to empty. The :class:`.Session` by
+ default autoflushes each time a :class:`_query.Query` is invoked.
+ For
+ options on how to control this, see :ref:`session_flushing`.
+
+ .. seealso::
+
+ :attr:`.AttributeState.history`
+
+ :func:`.attributes.get_history` - underlying function
+
+ .. versionadded:: 0.9.0
+
+ """
+ return self.state.get_history(self.key, PASSIVE_OFF ^ INIT_OK)
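+
+ # Illustrative sketch (added commentary, not part of the original
+ # source): how the accessors above differ; ``obj`` and the "name"
+ # attribute are assumed placeholders.
+ #
+ #   from sqlalchemy import inspect
+ #
+ #   attr = inspect(obj).attrs.name
+ #   attr.loaded_value    # value in __dict__, or NO_VALUE if unloaded
+ #   attr.value           # like getattr(obj, "name"); may trigger a load
+ #   attr.history         # pre-flush History; will not emit a load
+ #   attr.load_history()  # same, but will emit loader callables if needed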
+
+
+class PendingCollection(object):
+ """A writable placeholder for an unloaded collection.
+
+ Stores items appended to and removed from a collection that has not yet
+ been loaded. When the collection is loaded, the changes stored in
+ PendingCollection are applied to it to produce the final result.
+
+ """
+
+ def __init__(self):
+ self.deleted_items = util.IdentitySet()
+ self.added_items = util.OrderedIdentitySet()
+
+ def append(self, value):
+ if value in self.deleted_items:
+ self.deleted_items.remove(value)
+ else:
+ self.added_items.add(value)
+
+ def remove(self, value):
+ if value in self.added_items:
+ self.added_items.remove(value)
+ else:
+ self.deleted_items.add(value)
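+
+ # Illustrative sketch (added commentary, not part of the original
+ # source): PendingCollection records membership changes made while a
+ # collection is still unloaded so they can be replayed when it loads;
+ # ``a``, ``b``, ``c`` are assumed placeholder objects.
+ #
+ #   pc = PendingCollection()
+ #   pc.append(a)
+ #   pc.append(b)
+ #   pc.remove(a)     # cancels the earlier append of a
+ #   pc.remove(c)     # c will be removed from the loaded result
+ #   # added_items == {b}, deleted_items == {c}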
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
new file mode 100644
index 0000000..71aae00
--- /dev/null
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -0,0 +1,3141 @@
+# orm/strategies.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""sqlalchemy.orm.interfaces.LoaderStrategy
+ implementations, and related MapperOptions."""
+from __future__ import absolute_import
+
+import collections
+import itertools
+
+from . import attributes
+from . import exc as orm_exc
+from . import interfaces
+from . import loading
+from . import path_registry
+from . import properties
+from . import query
+from . import relationships
+from . import unitofwork
+from . import util as orm_util
+from .base import _DEFER_FOR_STATE
+from .base import _RAISE_FOR_STATE
+from .base import _SET_DEFERRED_EXPIRED
+from .context import _column_descriptions
+from .context import ORMCompileState
+from .context import ORMSelectCompileState
+from .context import QueryContext
+from .interfaces import LoaderStrategy
+from .interfaces import StrategizedProperty
+from .session import _state_session
+from .state import InstanceState
+from .util import _none_set
+from .util import aliased
+from .. import event
+from .. import exc as sa_exc
+from .. import inspect
+from .. import log
+from .. import sql
+from .. import util
+from ..sql import util as sql_util
+from ..sql import visitors
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..sql.selectable import Select
+
+
+def _register_attribute(
+ prop,
+ mapper,
+ useobject,
+ compare_function=None,
+ typecallable=None,
+ callable_=None,
+ proxy_property=None,
+ active_history=False,
+ impl_class=None,
+ **kw
+):
+
+ listen_hooks = []
+
+ uselist = useobject and prop.uselist
+
+ if useobject and prop.single_parent:
+ listen_hooks.append(single_parent_validator)
+
+ if prop.key in prop.parent.validators:
+ fn, opts = prop.parent.validators[prop.key]
+ listen_hooks.append(
+ lambda desc, prop: orm_util._validator_events(
+ desc, prop.key, fn, **opts
+ )
+ )
+
+ if useobject:
+ listen_hooks.append(unitofwork.track_cascade_events)
+
+ # need to assemble backref listeners
+ # after the singleparentvalidator, mapper validator
+ if useobject:
+ backref = prop.back_populates
+ if backref and prop._effective_sync_backref:
+ listen_hooks.append(
+ lambda desc, prop: attributes.backref_listeners(
+ desc, backref, uselist
+ )
+ )
+
+ # a single MapperProperty is shared down a class inheritance
+ # hierarchy, so we set up attribute instrumentation and backref event
+ # for each mapper down the hierarchy.
+
+ # typically, "mapper" is the same as prop.parent, due to the way
+ # the configure_mappers() process runs, however this is not strongly
+ # enforced, and in the case of a second configure_mappers() run the
+ # mapper here might not be prop.parent; also, a subclass mapper may
+ # be called here before a superclass mapper. That is, can't depend
+ # on mappers not already being set up so we have to check each one.
+
+ for m in mapper.self_and_descendants:
+ if prop is m._props.get(
+ prop.key
+ ) and not m.class_manager._attr_has_impl(prop.key):
+
+ desc = attributes.register_attribute_impl(
+ m.class_,
+ prop.key,
+ parent_token=prop,
+ uselist=uselist,
+ compare_function=compare_function,
+ useobject=useobject,
+ trackparent=useobject
+ and (
+ prop.single_parent
+ or prop.direction is interfaces.ONETOMANY
+ ),
+ typecallable=typecallable,
+ callable_=callable_,
+ active_history=active_history,
+ impl_class=impl_class,
+ send_modified_events=not useobject or not prop.viewonly,
+ doc=prop.doc,
+ **kw
+ )
+
+ for hook in listen_hooks:
+ hook(desc, prop)
+
+
+@properties.ColumnProperty.strategy_for(instrument=False, deferred=False)
+class UninstrumentedColumnLoader(LoaderStrategy):
+ """Represent a non-instrumented MapperProperty.
+
+ The polymorphic_on argument of mapper() often results in this,
+ if the argument is against the with_polymorphic selectable.
+
+ """
+
+ __slots__ = ("columns",)
+
+ def __init__(self, parent, strategy_key):
+ super(UninstrumentedColumnLoader, self).__init__(parent, strategy_key)
+ self.columns = self.parent_property.columns
+
+ def setup_query(
+ self,
+ compile_state,
+ query_entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection=None,
+ **kwargs
+ ):
+ for c in self.columns:
+ if adapter:
+ c = adapter.columns[c]
+ compile_state._append_dedupe_col_collection(c, column_collection)
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ pass
+
+
+@log.class_logger
+@properties.ColumnProperty.strategy_for(instrument=True, deferred=False)
+class ColumnLoader(LoaderStrategy):
+ """Provide loading behavior for a :class:`.ColumnProperty`."""
+
+ __slots__ = "columns", "is_composite"
+
+ def __init__(self, parent, strategy_key):
+ super(ColumnLoader, self).__init__(parent, strategy_key)
+ self.columns = self.parent_property.columns
+ self.is_composite = hasattr(self.parent_property, "composite_class")
+
+ def setup_query(
+ self,
+ compile_state,
+ query_entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ memoized_populators,
+ check_for_adapt=False,
+ **kwargs
+ ):
+ for c in self.columns:
+ if adapter:
+ if check_for_adapt:
+ c = adapter.adapt_check_present(c)
+ if c is None:
+ return
+ else:
+ c = adapter.columns[c]
+
+ compile_state._append_dedupe_col_collection(c, column_collection)
+
+ fetch = self.columns[0]
+ if adapter:
+ fetch = adapter.columns[fetch]
+ memoized_populators[self.parent_property] = fetch
+
+ def init_class_attribute(self, mapper):
+ self.is_class_level = True
+ coltype = self.columns[0].type
+ # TODO: check all columns ? check for foreign key as well?
+ active_history = (
+ self.parent_property.active_history
+ or self.columns[0].primary_key
+ or (
+ mapper.version_id_col is not None
+ and mapper._columntoproperty.get(mapper.version_id_col, None)
+ is self.parent_property
+ )
+ )
+
+ _register_attribute(
+ self.parent_property,
+ mapper,
+ useobject=False,
+ compare_function=coltype.compare_values,
+ active_history=active_history,
+ )
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ # look through list of columns represented here
+ # to see which, if any, is present in the row.
+ for col in self.columns:
+ if adapter:
+ col = adapter.columns[col]
+ getter = result._getter(col, False)
+ if getter:
+ populators["quick"].append((self.key, getter))
+ break
+ else:
+ populators["expire"].append((self.key, True))
+
+
+@log.class_logger
+@properties.ColumnProperty.strategy_for(query_expression=True)
+class ExpressionColumnLoader(ColumnLoader):
+ def __init__(self, parent, strategy_key):
+ super(ExpressionColumnLoader, self).__init__(parent, strategy_key)
+
+ # compare to the "default" expression that is mapped in
+ # the column. If it's sql.null, we don't need to render
+ # unless an expr is passed in the options.
+ null = sql.null().label(None)
+ self._have_default_expression = any(
+ not c.compare(null) for c in self.parent_property.columns
+ )
+
+ def setup_query(
+ self,
+ compile_state,
+ query_entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ memoized_populators,
+ **kwargs
+ ):
+ columns = None
+ if loadopt and "expression" in loadopt.local_opts:
+ columns = [loadopt.local_opts["expression"]]
+ elif self._have_default_expression:
+ columns = self.parent_property.columns
+
+ if columns is None:
+ return
+
+ for c in columns:
+ if adapter:
+ c = adapter.columns[c]
+ compile_state._append_dedupe_col_collection(c, column_collection)
+
+ fetch = columns[0]
+ if adapter:
+ fetch = adapter.columns[fetch]
+ memoized_populators[self.parent_property] = fetch
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ # look through list of columns represented here
+ # to see which, if any, is present in the row.
+ if loadopt and "expression" in loadopt.local_opts:
+ columns = [loadopt.local_opts["expression"]]
+
+ for col in columns:
+ if adapter:
+ col = adapter.columns[col]
+ getter = result._getter(col, False)
+ if getter:
+ populators["quick"].append((self.key, getter))
+ break
+ else:
+ populators["expire"].append((self.key, True))
+
+ def init_class_attribute(self, mapper):
+ self.is_class_level = True
+
+ _register_attribute(
+ self.parent_property,
+ mapper,
+ useobject=False,
+ compare_function=self.columns[0].type.compare_values,
+ accepts_scalar_loader=False,
+ )
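+
+ # Illustrative sketch (added commentary, not part of the original
+ # source): ExpressionColumnLoader backs query_expression() attributes,
+ # which stay unset unless an expression is supplied per query with
+ # with_expression(); ``User`` and ``some_count_expr`` are assumed.
+ #
+ #   from sqlalchemy.orm import query_expression, with_expression
+ #
+ #   class User(Base):
+ #       # ... mapped columns ...
+ #       book_count = query_expression()
+ #
+ #   stmt = select(User).options(
+ #       with_expression(User.book_count, some_count_expr)
+ #   )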
+
+
+@log.class_logger
+@properties.ColumnProperty.strategy_for(deferred=True, instrument=True)
+@properties.ColumnProperty.strategy_for(
+ deferred=True, instrument=True, raiseload=True
+)
+@properties.ColumnProperty.strategy_for(do_nothing=True)
+class DeferredColumnLoader(LoaderStrategy):
+ """Provide loading behavior for a deferred :class:`.ColumnProperty`."""
+
+ __slots__ = "columns", "group", "raiseload"
+
+ def __init__(self, parent, strategy_key):
+ super(DeferredColumnLoader, self).__init__(parent, strategy_key)
+ if hasattr(self.parent_property, "composite_class"):
+ raise NotImplementedError(
+ "Deferred loading for composite " "types not implemented yet"
+ )
+ self.raiseload = self.strategy_opts.get("raiseload", False)
+ self.columns = self.parent_property.columns
+ self.group = self.parent_property.group
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+
+ # for a DeferredColumnLoader, this method is only used during a
+ # "row processor only" query; see test_deferred.py ->
+ # tests with "rowproc_only" in their name. As of the 1.0 series,
+ # loading._instance_processor doesn't use a "row processing" function
+ # to populate columns, instead it uses data in the "populators"
+ # dictionary. Normally, the DeferredColumnLoader.setup_query()
+ # sets up that data in the "memoized_populators" dictionary
+ # and "create_row_processor()" here is never invoked.
+
+ if (
+ context.refresh_state
+ and context.query._compile_options._only_load_props
+ and self.key in context.query._compile_options._only_load_props
+ ):
+ self.parent_property._get_strategy(
+ (("deferred", False), ("instrument", True))
+ ).create_row_processor(
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ )
+
+ elif not self.is_class_level:
+ if self.raiseload:
+ set_deferred_for_local_state = (
+ self.parent_property._raise_column_loader
+ )
+ else:
+ set_deferred_for_local_state = (
+ self.parent_property._deferred_column_loader
+ )
+ populators["new"].append((self.key, set_deferred_for_local_state))
+ else:
+ populators["expire"].append((self.key, False))
+
+ def init_class_attribute(self, mapper):
+ self.is_class_level = True
+
+ _register_attribute(
+ self.parent_property,
+ mapper,
+ useobject=False,
+ compare_function=self.columns[0].type.compare_values,
+ callable_=self._load_for_state,
+ load_on_unexpire=False,
+ )
+
+ def setup_query(
+ self,
+ compile_state,
+ query_entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ memoized_populators,
+ only_load_props=None,
+ **kw
+ ):
+
+ if (
+ (
+ compile_state.compile_options._render_for_subquery
+ and self.parent_property._renders_in_subqueries
+ )
+ or (
+ loadopt
+ and "undefer_pks" in loadopt.local_opts
+ and set(self.columns).intersection(
+ self.parent._should_undefer_in_wildcard
+ )
+ )
+ or (
+ loadopt
+ and self.group
+ and loadopt.local_opts.get(
+ "undefer_group_%s" % self.group, False
+ )
+ )
+ or (only_load_props and self.key in only_load_props)
+ ):
+ self.parent_property._get_strategy(
+ (("deferred", False), ("instrument", True))
+ ).setup_query(
+ compile_state,
+ query_entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ memoized_populators,
+ **kw
+ )
+ elif self.is_class_level:
+ memoized_populators[self.parent_property] = _SET_DEFERRED_EXPIRED
+ elif not self.raiseload:
+ memoized_populators[self.parent_property] = _DEFER_FOR_STATE
+ else:
+ memoized_populators[self.parent_property] = _RAISE_FOR_STATE
+
+ def _load_for_state(self, state, passive):
+ if not state.key:
+ return attributes.ATTR_EMPTY
+
+ if not passive & attributes.SQL_OK:
+ return attributes.PASSIVE_NO_RESULT
+
+ localparent = state.manager.mapper
+
+ if self.group:
+ toload = [
+ p.key
+ for p in localparent.iterate_properties
+ if isinstance(p, StrategizedProperty)
+ and isinstance(p.strategy, DeferredColumnLoader)
+ and p.group == self.group
+ ]
+ else:
+ toload = [self.key]
+
+ # narrow the keys down to just those which have no history
+ group = [k for k in toload if k in state.unmodified]
+
+ session = _state_session(state)
+ if session is None:
+ raise orm_exc.DetachedInstanceError(
+ "Parent instance %s is not bound to a Session; "
+ "deferred load operation of attribute '%s' cannot proceed"
+ % (orm_util.state_str(state), self.key)
+ )
+
+ if self.raiseload:
+ self._invoke_raise_load(state, passive, "raise")
+
+ if (
+ loading.load_on_ident(
+ session,
+ sql.select(localparent).set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ ),
+ state.key,
+ only_load_props=group,
+ refresh_state=state,
+ )
+ is None
+ ):
+ raise orm_exc.ObjectDeletedError(state)
+
+ return attributes.ATTR_WAS_SET
+
+ def _invoke_raise_load(self, state, passive, lazy):
+ raise sa_exc.InvalidRequestError(
+ "'%s' is not available due to raiseload=True" % (self,)
+ )
+
+
+class LoadDeferredColumns(object):
+ """serializable loader object used by DeferredColumnLoader"""
+
+ def __init__(self, key, raiseload=False):
+ self.key = key
+ self.raiseload = raiseload
+
+ def __call__(self, state, passive=attributes.PASSIVE_OFF):
+ key = self.key
+
+ localparent = state.manager.mapper
+ prop = localparent._props[key]
+ if self.raiseload:
+ strategy_key = (
+ ("deferred", True),
+ ("instrument", True),
+ ("raiseload", True),
+ )
+ else:
+ strategy_key = (("deferred", True), ("instrument", True))
+ strategy = prop._get_strategy(strategy_key)
+ return strategy._load_for_state(state, passive)
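+
+ # Illustrative sketch (added commentary, not part of the original
+ # source): DeferredColumnLoader and LoadDeferredColumns together
+ # implement deferred column loading; ``Book`` / ``session`` are
+ # assumed placeholders.
+ #
+ #   from sqlalchemy.orm import deferred, undefer
+ #
+ #   class Book(Base):
+ #       # ...
+ #       summary = deferred(Column(Text))
+ #
+ #   book = session.query(Book).first()  # "summary" left out of the SELECT
+ #   book.summary                         # attribute access emits a second SELECT
+ #   session.query(Book).options(undefer(Book.summary))  # load it up front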
+
+
+class AbstractRelationshipLoader(LoaderStrategy):
+ """LoaderStratgies which deal with related objects."""
+
+ __slots__ = "mapper", "target", "uselist", "entity"
+
+ def __init__(self, parent, strategy_key):
+ super(AbstractRelationshipLoader, self).__init__(parent, strategy_key)
+ self.mapper = self.parent_property.mapper
+ self.entity = self.parent_property.entity
+ self.target = self.parent_property.target
+ self.uselist = self.parent_property.uselist
+
+
+@log.class_logger
+@relationships.RelationshipProperty.strategy_for(do_nothing=True)
+class DoNothingLoader(LoaderStrategy):
+ """Relationship loader that makes no change to the object's state.
+
+ Compared to NoLoader, this loader does not initialize the
+ collection/attribute to empty/none; the usual default LazyLoader will
+ take effect.
+
+ """
+
+
+@log.class_logger
+@relationships.RelationshipProperty.strategy_for(lazy="noload")
+@relationships.RelationshipProperty.strategy_for(lazy=None)
+class NoLoader(AbstractRelationshipLoader):
+ """Provide loading behavior for a :class:`.RelationshipProperty`
+ with "lazy=None".
+
+ """
+
+ __slots__ = ()
+
+ def init_class_attribute(self, mapper):
+ self.is_class_level = True
+
+ _register_attribute(
+ self.parent_property,
+ mapper,
+ useobject=True,
+ typecallable=self.parent_property.collection_class,
+ )
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ def invoke_no_load(state, dict_, row):
+ if self.uselist:
+ attributes.init_state_collection(state, dict_, self.key)
+ else:
+ dict_[self.key] = None
+
+ populators["new"].append((self.key, invoke_no_load))
+
+
+@log.class_logger
+@relationships.RelationshipProperty.strategy_for(lazy=True)
+@relationships.RelationshipProperty.strategy_for(lazy="select")
+@relationships.RelationshipProperty.strategy_for(lazy="raise")
+@relationships.RelationshipProperty.strategy_for(lazy="raise_on_sql")
+@relationships.RelationshipProperty.strategy_for(lazy="baked_select")
+class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
+ """Provide loading behavior for a :class:`.RelationshipProperty`
+ with "lazy=True", that is loads when first accessed.
+
+ """
+
+ __slots__ = (
+ "_lazywhere",
+ "_rev_lazywhere",
+ "_lazyload_reverse_option",
+ "_order_by",
+ "use_get",
+ "is_aliased_class",
+ "_bind_to_col",
+ "_equated_columns",
+ "_rev_bind_to_col",
+ "_rev_equated_columns",
+ "_simple_lazy_clause",
+ "_raise_always",
+ "_raise_on_sql",
+ )
+
+ def __init__(self, parent, strategy_key):
+ super(LazyLoader, self).__init__(parent, strategy_key)
+ self._raise_always = self.strategy_opts["lazy"] == "raise"
+ self._raise_on_sql = self.strategy_opts["lazy"] == "raise_on_sql"
+
+ self.is_aliased_class = inspect(self.entity).is_aliased_class
+
+ join_condition = self.parent_property._join_condition
+ (
+ self._lazywhere,
+ self._bind_to_col,
+ self._equated_columns,
+ ) = join_condition.create_lazy_clause()
+
+ (
+ self._rev_lazywhere,
+ self._rev_bind_to_col,
+ self._rev_equated_columns,
+ ) = join_condition.create_lazy_clause(reverse_direction=True)
+
+ if self.parent_property.order_by:
+ self._order_by = [
+ sql_util._deep_annotate(elem, {"_orm_adapt": True})
+ for elem in util.to_list(self.parent_property.order_by)
+ ]
+ else:
+ self._order_by = None
+
+ self.logger.info("%s lazy loading clause %s", self, self._lazywhere)
+
+ # determine if our "lazywhere" clause is the same as the mapper's
+ # get() clause. then we can just use mapper.get()
+ #
+ # TODO: the "not self.uselist" can be taken out entirely; a m2o
+ # load that populates for a list (very unusual, but is possible with
+ # the API) can still set for "None" and the attribute system will
+ # populate as an empty list.
+ self.use_get = (
+ not self.is_aliased_class
+ and not self.uselist
+ and self.entity._get_clause[0].compare(
+ self._lazywhere,
+ use_proxies=True,
+ compare_keys=False,
+ equivalents=self.mapper._equivalent_columns,
+ )
+ )
+
+ if self.use_get:
+ for col in list(self._equated_columns):
+ if col in self.mapper._equivalent_columns:
+ for c in self.mapper._equivalent_columns[col]:
+ self._equated_columns[c] = self._equated_columns[col]
+
+ self.logger.info(
+ "%s will use Session.get() to " "optimize instance loads", self
+ )
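+
+ # Illustrative sketch (added commentary, not part of the original
+ # source): when use_get is True (a simple many-to-one whose lazy
+ # clause matches the mapper's primary-key "get" clause), a lazy load
+ # can be satisfied straight from the identity map; ``Address`` and
+ # ``User`` are assumed mapped classes.
+ #
+ #   address = session.query(Address).first()
+ #   address.user   # no SELECT if the User row is already in the
+ #                  # identity map; otherwise a Session.get()-style load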
+
+ def init_class_attribute(self, mapper):
+ self.is_class_level = True
+
+ _legacy_inactive_history_style = (
+ self.parent_property._legacy_inactive_history_style
+ )
+
+ if self.parent_property.active_history:
+ active_history = True
+ _deferred_history = False
+
+ elif (
+ self.parent_property.direction is not interfaces.MANYTOONE
+ or not self.use_get
+ ):
+ if _legacy_inactive_history_style:
+ active_history = True
+ _deferred_history = False
+ else:
+ active_history = False
+ _deferred_history = True
+ else:
+ active_history = _deferred_history = False
+
+ _register_attribute(
+ self.parent_property,
+ mapper,
+ useobject=True,
+ callable_=self._load_for_state,
+ typecallable=self.parent_property.collection_class,
+ active_history=active_history,
+ _deferred_history=_deferred_history,
+ )
+
+ def _memoized_attr__simple_lazy_clause(self):
+
+ lazywhere = sql_util._deep_annotate(
+ self._lazywhere, {"_orm_adapt": True}
+ )
+
+ criterion, bind_to_col = (lazywhere, self._bind_to_col)
+
+ params = []
+
+ def visit_bindparam(bindparam):
+ bindparam.unique = False
+
+ visitors.traverse(criterion, {}, {"bindparam": visit_bindparam})
+
+ def visit_bindparam(bindparam):
+ if bindparam._identifying_key in bind_to_col:
+ params.append(
+ (
+ bindparam.key,
+ bind_to_col[bindparam._identifying_key],
+ None,
+ )
+ )
+ elif bindparam.callable is None:
+ params.append((bindparam.key, None, bindparam.value))
+
+ criterion = visitors.cloned_traverse(
+ criterion, {}, {"bindparam": visit_bindparam}
+ )
+
+ return criterion, params
+
+ def _generate_lazy_clause(self, state, passive):
+ criterion, param_keys = self._simple_lazy_clause
+
+ if state is None:
+ return sql_util.adapt_criterion_to_null(
+ criterion, [key for key, ident, value in param_keys]
+ )
+
+ mapper = self.parent_property.parent
+
+ o = state.obj() # strong ref
+ dict_ = attributes.instance_dict(o)
+
+ if passive & attributes.INIT_OK:
+ passive ^= attributes.INIT_OK
+
+ params = {}
+ for key, ident, value in param_keys:
+ if ident is not None:
+ if passive and passive & attributes.LOAD_AGAINST_COMMITTED:
+ value = mapper._get_committed_state_attr_by_column(
+ state, dict_, ident, passive
+ )
+ else:
+ value = mapper._get_state_attr_by_column(
+ state, dict_, ident, passive
+ )
+
+ params[key] = value
+
+ return criterion, params
+
+ def _invoke_raise_load(self, state, passive, lazy):
+ raise sa_exc.InvalidRequestError(
+ "'%s' is not available due to lazy='%s'" % (self, lazy)
+ )
+
+ def _load_for_state(self, state, passive, loadopt=None, extra_criteria=()):
+ if not state.key and (
+ (
+ not self.parent_property.load_on_pending
+ and not state._load_pending
+ )
+ or not state.session_id
+ ):
+ return attributes.ATTR_EMPTY
+
+ pending = not state.key
+ primary_key_identity = None
+
+ use_get = self.use_get and (not loadopt or not loadopt._extra_criteria)
+
+ if (not passive & attributes.SQL_OK and not use_get) or (
+ not passive & attributes.NON_PERSISTENT_OK and pending
+ ):
+ return attributes.PASSIVE_NO_RESULT
+
+ if (
+ # we were given lazy="raise"
+ self._raise_always
+ # the no_raise history-related flag was not passed
+ and not passive & attributes.NO_RAISE
+ and (
+ # if we are use_get and related_object_ok is disabled,
+ # which means we are at most looking in the identity map
+ # for history purposes or otherwise returning
+ # PASSIVE_NO_RESULT, don't raise. This is also a
+ # history-related flag
+ not use_get
+ or passive & attributes.RELATED_OBJECT_OK
+ )
+ ):
+
+ self._invoke_raise_load(state, passive, "raise")
+
+ session = _state_session(state)
+ if not session:
+ if passive & attributes.NO_RAISE:
+ return attributes.PASSIVE_NO_RESULT
+
+ raise orm_exc.DetachedInstanceError(
+ "Parent instance %s is not bound to a Session; "
+ "lazy load operation of attribute '%s' cannot proceed"
+ % (orm_util.state_str(state), self.key)
+ )
+
+ # if we have a simple primary key load, check the
+ # identity map without generating a Query at all
+ if use_get:
+ primary_key_identity = self._get_ident_for_use_get(
+ session, state, passive
+ )
+ if attributes.PASSIVE_NO_RESULT in primary_key_identity:
+ return attributes.PASSIVE_NO_RESULT
+ elif attributes.NEVER_SET in primary_key_identity:
+ return attributes.NEVER_SET
+
+ if _none_set.issuperset(primary_key_identity):
+ return None
+
+ if (
+ self.key in state.dict
+ and not passive & attributes.DEFERRED_HISTORY_LOAD
+ ):
+ return attributes.ATTR_WAS_SET
+
+ # look for this identity in the identity map. Delegate to the
+ # Query class in use, as it may have special rules for how it
+ # does this, including how it decides what the correct
+ # identity_token would be for this identity.
+
+ instance = session._identity_lookup(
+ self.entity,
+ primary_key_identity,
+ passive=passive,
+ lazy_loaded_from=state,
+ )
+
+ if instance is not None:
+ if instance is attributes.PASSIVE_CLASS_MISMATCH:
+ return None
+ else:
+ return instance
+ elif (
+ not passive & attributes.SQL_OK
+ or not passive & attributes.RELATED_OBJECT_OK
+ ):
+ return attributes.PASSIVE_NO_RESULT
+
+ return self._emit_lazyload(
+ session,
+ state,
+ primary_key_identity,
+ passive,
+ loadopt,
+ extra_criteria,
+ )
+
+ def _get_ident_for_use_get(self, session, state, passive):
+ instance_mapper = state.manager.mapper
+
+ if passive & attributes.LOAD_AGAINST_COMMITTED:
+ get_attr = instance_mapper._get_committed_state_attr_by_column
+ else:
+ get_attr = instance_mapper._get_state_attr_by_column
+
+ dict_ = state.dict
+
+ return [
+ get_attr(state, dict_, self._equated_columns[pk], passive=passive)
+ for pk in self.mapper.primary_key
+ ]
+
+ @util.preload_module("sqlalchemy.orm.strategy_options")
+ def _emit_lazyload(
+ self,
+ session,
+ state,
+ primary_key_identity,
+ passive,
+ loadopt,
+ extra_criteria,
+ ):
+ strategy_options = util.preloaded.orm_strategy_options
+
+ clauseelement = self.entity.__clause_element__()
+ stmt = Select._create_raw_select(
+ _raw_columns=[clauseelement],
+ _propagate_attrs=clauseelement._propagate_attrs,
+ _label_style=LABEL_STYLE_TABLENAME_PLUS_COL,
+ _compile_options=ORMCompileState.default_compile_options,
+ )
+ load_options = QueryContext.default_load_options
+
+ load_options += {
+ "_invoke_all_eagers": False,
+ "_lazy_loaded_from": state,
+ }
+
+ if self.parent_property.secondary is not None:
+ stmt = stmt.select_from(
+ self.mapper, self.parent_property.secondary
+ )
+
+ pending = not state.key
+
+ # don't autoflush on pending
+ if pending or passive & attributes.NO_AUTOFLUSH:
+ stmt._execution_options = util.immutabledict({"autoflush": False})
+
+ use_get = self.use_get
+
+ if state.load_options or (loadopt and loadopt._extra_criteria):
+ effective_path = state.load_path[self.parent_property]
+
+ opts = tuple(state.load_options)
+
+ if loadopt and loadopt._extra_criteria:
+ use_get = False
+ opts += (
+ orm_util.LoaderCriteriaOption(self.entity, extra_criteria),
+ )
+
+ stmt._with_options = opts
+ else:
+ # this path is used if there are not already any options
+ # in the query, but an event may want to add them
+ effective_path = state.mapper._path_registry[self.parent_property]
+
+ stmt._compile_options += {"_current_path": effective_path}
+
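+        # simple many-to-one "use get" case: load the related object by
+        # primary key identity, emitting a SELECT against the target
+        # primary key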
+ if use_get:
+ if self._raise_on_sql and not passive & attributes.NO_RAISE:
+ self._invoke_raise_load(state, passive, "raise_on_sql")
+
+ return loading.load_on_pk_identity(
+ session, stmt, primary_key_identity, load_options=load_options
+ )
+
+ if self._order_by:
+ stmt._order_by_clauses = self._order_by
+
+ def _lazyload_reverse(compile_context):
+ for rev in self.parent_property._reverse_property:
+ # reverse props that are MANYTOONE are loading *this*
+ # object from get(), so don't need to eager out to those.
+ if (
+ rev.direction is interfaces.MANYTOONE
+ and rev._use_get
+ and not isinstance(rev.strategy, LazyLoader)
+ ):
+ strategy_options.Load.for_existing_path(
+ compile_context.compile_options._current_path[
+ rev.parent
+ ]
+ ).lazyload(rev).process_compile_state(compile_context)
+
+ stmt._with_context_options += (
+ (_lazyload_reverse, self.parent_property),
+ )
+
+ lazy_clause, params = self._generate_lazy_clause(state, passive)
+
+ execution_options = {
+ "_sa_orm_load_options": load_options,
+ }
+
+ if (
+ self.key in state.dict
+ and not passive & attributes.DEFERRED_HISTORY_LOAD
+ ):
+ return attributes.ATTR_WAS_SET
+
+ if pending:
+ if util.has_intersection(orm_util._none_set, params.values()):
+ return None
+
+ elif util.has_intersection(orm_util._never_set, params.values()):
+ return None
+
+ if self._raise_on_sql and not passive & attributes.NO_RAISE:
+ self._invoke_raise_load(state, passive, "raise_on_sql")
+
+ stmt._where_criteria = (lazy_clause,)
+
+ result = session.execute(
+ stmt, params, execution_options=execution_options
+ )
+
+ result = result.unique().scalars().all()
+
+ if self.uselist:
+ return result
+ else:
+            num_rows = len(result)
+            if num_rows:
+                if num_rows > 1:
+ util.warn(
+ "Multiple rows returned with "
+ "uselist=False for lazily-loaded attribute '%s' "
+ % self.parent_property
+ )
+
+ return result[0]
+ else:
+ return None
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ key = self.key
+
+ if not self.is_class_level or (loadopt and loadopt._extra_criteria):
+ # we are not the primary manager for this attribute
+ # on this class - set up a
+ # per-instance lazyloader, which will override the
+ # class-level behavior.
+ # this currently only happens when using a
+ # "lazyload" option on a "no load"
+ # attribute - "eager" attributes always have a
+ # class-level lazyloader installed.
+ set_lazy_callable = (
+ InstanceState._instance_level_callable_processor
+ )(
+ mapper.class_manager,
+ LoadLazyAttribute(
+ key,
+ self,
+ loadopt,
+ loadopt._generate_extra_criteria(context)
+ if loadopt._extra_criteria
+ else None,
+ ),
+ key,
+ )
+
+ populators["new"].append((self.key, set_lazy_callable))
+ elif context.populate_existing or mapper.always_refresh:
+
+ def reset_for_lazy_callable(state, dict_, row):
+ # we are the primary manager for this attribute on
+ # this class - reset its
+ # per-instance attribute state, so that the class-level
+ # lazy loader is
+ # executed when next referenced on this instance.
+ # this is needed in
+ # populate_existing() types of scenarios to reset
+ # any existing state.
+ state._reset(dict_, key)
+
+ populators["new"].append((self.key, reset_for_lazy_callable))
+
+
+class LoadLazyAttribute(object):
+ """semi-serializable loader object used by LazyLoader
+
+ Historically, this object would be carried along with instances that
+ needed to run lazyloaders, so it had to be serializable to support
+ cached instances.
+
+    This is no longer a general requirement, and the case where this object
+    is used is exactly the case where we can't easily serialize: when extra
+    criteria are present in the loader option.
+
+    We can't reliably serialize the extra criteria, as they refer to mapped
+    entities and AliasedClass objects that are local to the current process
+    and would need to be matched up on deserialization, e.g. via the
+    sqlalchemy.ext.serializer approach.
+
+ """
+
+ def __init__(self, key, initiating_strategy, loadopt, extra_criteria):
+ self.key = key
+ self.strategy_key = initiating_strategy.strategy_key
+ self.loadopt = loadopt
+ self.extra_criteria = extra_criteria
+
+ def __getstate__(self):
+ if self.extra_criteria is not None:
+ util.warn(
+ "Can't reliably serialize a lazyload() option that "
+ "contains additional criteria; please use eager loading "
+ "for this case"
+ )
+ return {
+ "key": self.key,
+ "strategy_key": self.strategy_key,
+ "loadopt": self.loadopt,
+ "extra_criteria": (),
+ }
+
+ def __call__(self, state, passive=attributes.PASSIVE_OFF):
+ key = self.key
+ instance_mapper = state.manager.mapper
+ prop = instance_mapper._props[key]
+ strategy = prop._strategies[self.strategy_key]
+
+ return strategy._load_for_state(
+ state,
+ passive,
+ loadopt=self.loadopt,
+ extra_criteria=self.extra_criteria,
+ )
+
+
+class PostLoader(AbstractRelationshipLoader):
+ """A relationship loader that emits a second SELECT statement."""
+
+ def _check_recursive_postload(self, context, path, join_depth=None):
+ effective_path = (
+ context.compile_state.current_path or orm_util.PathRegistry.root
+ ) + path
+
+ if loading.PostLoad.path_exists(
+ context, effective_path, self.parent_property
+ ):
+ return True
+
+ path_w_prop = path[self.parent_property]
+ effective_path_w_prop = effective_path[self.parent_property]
+
+ if not path_w_prop.contains(context.attributes, "loader"):
+ if join_depth:
+ if effective_path_w_prop.length / 2 > join_depth:
+ return True
+ elif effective_path_w_prop.contains_mapper(self.mapper):
+ return True
+
+ return False
+
+ def _immediateload_create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ return self.parent_property._get_strategy(
+ (("lazy", "immediate"),)
+ ).create_row_processor(
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ )
+
+
+@relationships.RelationshipProperty.strategy_for(lazy="immediate")
+class ImmediateLoader(PostLoader):
+ __slots__ = ()
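+    # usage sketch (hypothetical Parent/Child models, not part of this
+    # module):
+    #
+    #     children = relationship("Child", lazy="immediate")
+    #
+    # each parent object in a result has the related attribute loaded
+    # immediately as rows are processed, via the lazy loader, rather than
+    # waiting for the first attribute access.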
+
+ def init_class_attribute(self, mapper):
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).init_class_attribute(mapper)
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ def load_immediate(state, dict_, row):
+ state.get_impl(self.key).get(state, dict_, flags)
+
+ if self._check_recursive_postload(context, path):
+            # this will not emit SQL and will only emit for a many-to-one
+            # "use get" load. the "_RELATED" part means it may return the
+            # instance even if it's expired, since this is a mutually-recursive
+            # load operation.
+ flags = attributes.PASSIVE_NO_FETCH_RELATED | attributes.NO_RAISE
+ else:
+ flags = attributes.PASSIVE_OFF | attributes.NO_RAISE
+
+ populators["delayed"].append((self.key, load_immediate))
+
+
+@log.class_logger
+@relationships.RelationshipProperty.strategy_for(lazy="subquery")
+class SubqueryLoader(PostLoader):
+ __slots__ = ("join_depth",)
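+    # usage sketch (hypothetical Parent/Child models, not part of this
+    # module):
+    #
+    #     session.query(Parent).options(subqueryload(Parent.children))
+    #
+    # the related collections for all Parent rows in the result are loaded
+    # by a second SELECT that restates the original query as a subquery and
+    # joins from it to the related table.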
+
+ def __init__(self, parent, strategy_key):
+ super(SubqueryLoader, self).__init__(parent, strategy_key)
+ self.join_depth = self.parent_property.join_depth
+
+ def init_class_attribute(self, mapper):
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).init_class_attribute(mapper)
+
+ def _get_leftmost(
+ self,
+ orig_query_entity_index,
+ subq_path,
+ current_compile_state,
+ is_root,
+ ):
+ given_subq_path = subq_path
+ subq_path = subq_path.path
+ subq_mapper = orm_util._class_to_mapper(subq_path[0])
+
+ # determine attributes of the leftmost mapper
+ if (
+ self.parent.isa(subq_mapper)
+ and self.parent_property is subq_path[1]
+ ):
+ leftmost_mapper, leftmost_prop = self.parent, self.parent_property
+ else:
+ leftmost_mapper, leftmost_prop = subq_mapper, subq_path[1]
+
+ if is_root:
+ # the subq_path is also coming from cached state, so when we start
+ # building up this path, it has to also be converted to be in terms
+ # of the current state. this is for the specific case of the entity
+ # is an AliasedClass against a subquery that's not otherwise going
+ # to adapt
+ new_subq_path = current_compile_state._entities[
+ orig_query_entity_index
+ ].entity_zero._path_registry[leftmost_prop]
+ additional = len(subq_path) - len(new_subq_path)
+ if additional:
+ new_subq_path += path_registry.PathRegistry.coerce(
+ subq_path[-additional:]
+ )
+ else:
+ new_subq_path = given_subq_path
+
+ leftmost_cols = leftmost_prop.local_columns
+
+ leftmost_attr = [
+ getattr(
+ new_subq_path.path[0].entity,
+ leftmost_mapper._columntoproperty[c].key,
+ )
+ for c in leftmost_cols
+ ]
+
+ return leftmost_mapper, leftmost_attr, leftmost_prop, new_subq_path
+
+ def _generate_from_original_query(
+ self,
+ orig_compile_state,
+ orig_query,
+ leftmost_mapper,
+ leftmost_attr,
+ leftmost_relationship,
+ orig_entity,
+ ):
+ # reformat the original query
+ # to look only for significant columns
+ q = orig_query._clone().correlate(None)
+
+ # LEGACY: make a Query back from the select() !!
+ # This suits at least two legacy cases:
+ # 1. applications which expect before_compile() to be called
+ # below when we run .subquery() on this query (Keystone)
+ # 2. applications which are doing subqueryload with complex
+ # from_self() queries, as query.subquery() / .statement
+ # has to do the full compile context for multiply-nested
+ # from_self() (Neutron) - see test_subqload_from_self
+ # for demo.
+ q2 = query.Query.__new__(query.Query)
+ q2.__dict__.update(q.__dict__)
+ q = q2
+
+ # set the query's "FROM" list explicitly to what the
+ # FROM list would be in any case, as we will be limiting
+ # the columns in the SELECT list which may no longer include
+ # all entities mentioned in things like WHERE, JOIN, etc.
+ if not q._from_obj:
+ q._enable_assertions = False
+ q.select_from.non_generative(
+ q,
+ *{
+ ent["entity"]
+ for ent in _column_descriptions(
+ orig_query, compile_state=orig_compile_state
+ )
+ if ent["entity"] is not None
+ }
+ )
+
+ # select from the identity columns of the outer (specifically, these
+ # are the 'local_cols' of the property). This will remove other
+ # columns from the query that might suggest the right entity which is
+ # why we do set select_from above. The attributes we have are
+ # coerced and adapted using the original query's adapter, which is
+ # needed only for the case of adapting a subclass column to
+ # that of a polymorphic selectable, e.g. we have
+ # Engineer.primary_language and the entity is Person. All other
+ # adaptations, e.g. from_self, select_entity_from(), will occur
+ # within the new query when it compiles, as the compile_state we are
+ # using here is only a partial one. If the subqueryload is from a
+ # with_polymorphic() or other aliased() object, left_attr will already
+ # be the correct attributes so no adaptation is needed.
+ target_cols = orig_compile_state._adapt_col_list(
+ [
+ sql.coercions.expect(sql.roles.ColumnsClauseRole, o)
+ for o in leftmost_attr
+ ],
+ orig_compile_state._get_current_adapter(),
+ )
+ q._raw_columns = target_cols
+
+ distinct_target_key = leftmost_relationship.distinct_target_key
+
+ if distinct_target_key is True:
+ q._distinct = True
+ elif distinct_target_key is None:
+ # if target_cols refer to a non-primary key or only
+ # part of a composite primary key, set the q as distinct
+ for t in set(c.table for c in target_cols):
+ if not set(target_cols).issuperset(t.primary_key):
+ q._distinct = True
+ break
+
+ # don't need ORDER BY if no limit/offset
+ if not q._has_row_limiting_clause:
+ q._order_by_clauses = ()
+
+ if q._distinct is True and q._order_by_clauses:
+ # the logic to automatically add the order by columns to the query
+ # when distinct is True is deprecated in the query
+ to_add = sql_util.expand_column_list_from_order_by(
+ target_cols, q._order_by_clauses
+ )
+ if to_add:
+ q._set_entities(target_cols + to_add)
+
+ # the original query now becomes a subquery
+ # which we'll join onto.
+ # LEGACY: as "q" is a Query, the before_compile() event is invoked
+ # here.
+ embed_q = q.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL).subquery()
+ left_alias = orm_util.AliasedClass(
+ leftmost_mapper, embed_q, use_mapper_path=True
+ )
+ return left_alias
+
+ def _prep_for_joins(self, left_alias, subq_path):
+ # figure out what's being joined. a.k.a. the fun part
+ to_join = []
+ pairs = list(subq_path.pairs())
+
+ for i, (mapper, prop) in enumerate(pairs):
+ if i > 0:
+ # look at the previous mapper in the chain -
+ # if it is as or more specific than this prop's
+ # mapper, use that instead.
+ # note we have an assumption here that
+ # the non-first element is always going to be a mapper,
+ # not an AliasedClass
+
+ prev_mapper = pairs[i - 1][1].mapper
+ to_append = prev_mapper if prev_mapper.isa(mapper) else mapper
+ else:
+ to_append = mapper
+
+ to_join.append((to_append, prop.key))
+
+ # determine the immediate parent class we are joining from,
+ # which needs to be aliased.
+
+ if len(to_join) < 2:
+ # in the case of a one level eager load, this is the
+ # leftmost "left_alias".
+ parent_alias = left_alias
+ else:
+ info = inspect(to_join[-1][0])
+ if info.is_aliased_class:
+ parent_alias = info.entity
+ else:
+ # alias a plain mapper as we may be
+ # joining multiple times
+ parent_alias = orm_util.AliasedClass(
+ info.entity, use_mapper_path=True
+ )
+
+ local_cols = self.parent_property.local_columns
+
+ local_attr = [
+ getattr(parent_alias, self.parent._columntoproperty[c].key)
+ for c in local_cols
+ ]
+ return to_join, local_attr, parent_alias
+
+ def _apply_joins(
+ self, q, to_join, left_alias, parent_alias, effective_entity
+ ):
+
+ ltj = len(to_join)
+ if ltj == 1:
+ to_join = [
+ getattr(left_alias, to_join[0][1]).of_type(effective_entity)
+ ]
+ elif ltj == 2:
+ to_join = [
+ getattr(left_alias, to_join[0][1]).of_type(parent_alias),
+ getattr(parent_alias, to_join[-1][1]).of_type(
+ effective_entity
+ ),
+ ]
+ elif ltj > 2:
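+            # for longer chains, alias each intermediate mapper (unless it
+            # is already aliased) and chain of_type() calls so that each
+            # join attribute targets the next alias, ending at parent_alias
+            # and finally effective_entity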
+ middle = [
+ (
+ orm_util.AliasedClass(item[0])
+ if not inspect(item[0]).is_aliased_class
+ else item[0].entity,
+ item[1],
+ )
+ for item in to_join[1:-1]
+ ]
+ inner = []
+
+ while middle:
+ item = middle.pop(0)
+ attr = getattr(item[0], item[1])
+ if middle:
+ attr = attr.of_type(middle[0][0])
+ else:
+ attr = attr.of_type(parent_alias)
+
+ inner.append(attr)
+
+ to_join = (
+ [getattr(left_alias, to_join[0][1]).of_type(inner[0].parent)]
+ + inner
+ + [
+ getattr(parent_alias, to_join[-1][1]).of_type(
+ effective_entity
+ )
+ ]
+ )
+
+ for attr in to_join:
+ q = q.join(attr)
+
+ return q
+
+ def _setup_options(
+ self,
+ context,
+ q,
+ subq_path,
+ rewritten_path,
+ orig_query,
+ effective_entity,
+ loadopt,
+ ):
+
+ # note that because the subqueryload object
+ # does not re-use the cached query, instead always making
+ # use of the current invoked query, while we have two queries
+ # here (orig and context.query), they are both non-cached
+ # queries and we can transfer the options as is without
+ # adjusting for new criteria. Some work on #6881 / #6889
+ # brought this into question.
+ new_options = orig_query._with_options
+
+ if loadopt and loadopt._extra_criteria:
+
+ new_options += (
+ orm_util.LoaderCriteriaOption(
+ self.entity,
+ loadopt._generate_extra_criteria(context),
+ ),
+ )
+
+ # propagate loader options etc. to the new query.
+ # these will fire relative to subq_path.
+ q = q._with_current_path(rewritten_path)
+ q = q.options(*new_options)
+
+ return q
+
+ def _setup_outermost_orderby(self, q):
+ if self.parent_property.order_by:
+
+ def _setup_outermost_orderby(compile_context):
+ compile_context.eager_order_by += tuple(
+ util.to_list(self.parent_property.order_by)
+ )
+
+ q = q._add_context_option(
+ _setup_outermost_orderby, self.parent_property
+ )
+
+ return q
+
+ class _SubqCollections(object):
+ """Given a :class:`_query.Query` used to emit the "subquery load",
+ provide a load interface that executes the query at the
+ first moment a value is needed.
+
+ """
+
+ __slots__ = (
+ "session",
+ "execution_options",
+ "load_options",
+ "params",
+ "subq",
+ "_data",
+ )
+
+ def __init__(self, context, subq):
+            # don't store the context itself, to avoid creating a reference
+            # cycle, even though having it would be more convenient
+ self.session = context.session
+ self.execution_options = context.execution_options
+ self.load_options = context.load_options
+ self.params = context.params or {}
+ self.subq = subq
+ self._data = None
+
+ def get(self, key, default):
+ if self._data is None:
+ self._load()
+ return self._data.get(key, default)
+
+ def _load(self):
+ self._data = collections.defaultdict(list)
+
+ q = self.subq
+ assert q.session is None
+
+ q = q.with_session(self.session)
+
+ if self.load_options._populate_existing:
+ q = q.populate_existing()
+ # to work with baked query, the parameters may have been
+ # updated since this query was created, so take these into account
+
+ rows = list(q.params(self.params))
+ for k, v in itertools.groupby(rows, lambda x: x[1:]):
+ self._data[k].extend(vv[0] for vv in v)
+
+ def loader(self, state, dict_, row):
+ if self._data is None:
+ self._load()
+
+ def _setup_query_from_rowproc(
+ self,
+ context,
+ query_entity,
+ path,
+ entity,
+ loadopt,
+ adapter,
+ ):
+ compile_state = context.compile_state
+ if (
+ not compile_state.compile_options._enable_eagerloads
+ or compile_state.compile_options._for_refresh_state
+ ):
+ return
+
+ orig_query_entity_index = compile_state._entities.index(query_entity)
+ context.loaders_require_buffering = True
+
+ path = path[self.parent_property]
+
+ # build up a path indicating the path from the leftmost
+ # entity to the thing we're subquery loading.
+ with_poly_entity = path.get(
+ compile_state.attributes, "path_with_polymorphic", None
+ )
+ if with_poly_entity is not None:
+ effective_entity = with_poly_entity
+ else:
+ effective_entity = self.entity
+
+ subq_path, rewritten_path = context.query._execution_options.get(
+ ("subquery_paths", None),
+ (orm_util.PathRegistry.root, orm_util.PathRegistry.root),
+ )
+ is_root = subq_path is orm_util.PathRegistry.root
+ subq_path = subq_path + path
+ rewritten_path = rewritten_path + path
+
+ # if not via query option, check for
+ # a cycle
+ # TODO: why is this here??? this is now handled
+ # by the _check_recursive_postload call
+ if not path.contains(compile_state.attributes, "loader"):
+ if self.join_depth:
+ if (
+ (
+ compile_state.current_path.length
+ if compile_state.current_path
+ else 0
+ )
+ + path.length
+ ) / 2 > self.join_depth:
+ return
+ elif subq_path.contains_mapper(self.mapper):
+ return
+
+ # use the current query being invoked, not the compile state
+ # one. this is so that we get the current parameters. however,
+ # it means we can't use the existing compile state, we have to make
+ # a new one. other approaches include possibly using the
+ # compiled query but swapping the params, seems only marginally
+ # less time spent but more complicated
+ orig_query = context.query._execution_options.get(
+ ("orig_query", SubqueryLoader), context.query
+ )
+
+ # make a new compile_state for the query that's probably cached, but
+ # we're sort of undoing a bit of that caching :(
+ compile_state_cls = ORMCompileState._get_plugin_class_for_plugin(
+ orig_query, "orm"
+ )
+
+ if orig_query._is_lambda_element:
+ if context.load_options._lazy_loaded_from is None:
+ util.warn(
+ 'subqueryloader for "%s" must invoke lambda callable '
+ "at %r in "
+ "order to produce a new query, decreasing the efficiency "
+ "of caching for this statement. Consider using "
+ "selectinload() for more effective full-lambda caching"
+ % (self, orig_query)
+ )
+ orig_query = orig_query._resolved
+
+ # this is the more "quick" version, however it's not clear how
+ # much of this we need. in particular I can't get a test to
+ # fail if the "set_base_alias" is missing and not sure why that is.
+ orig_compile_state = compile_state_cls._create_entities_collection(
+ orig_query, legacy=False
+ )
+
+ (
+ leftmost_mapper,
+ leftmost_attr,
+ leftmost_relationship,
+ rewritten_path,
+ ) = self._get_leftmost(
+ orig_query_entity_index,
+ rewritten_path,
+ orig_compile_state,
+ is_root,
+ )
+
+ # generate a new Query from the original, then
+ # produce a subquery from it.
+ left_alias = self._generate_from_original_query(
+ orig_compile_state,
+ orig_query,
+ leftmost_mapper,
+ leftmost_attr,
+ leftmost_relationship,
+ entity,
+ )
+
+ # generate another Query that will join the
+ # left alias to the target relationships.
+ # basically doing a longhand
+ # "from_self()". (from_self() itself not quite industrial
+ # strength enough for all contingencies...but very close)
+
+ q = query.Query(effective_entity)
+
+ q._execution_options = q._execution_options.union(
+ {
+ ("orig_query", SubqueryLoader): orig_query,
+ ("subquery_paths", None): (subq_path, rewritten_path),
+ }
+ )
+
+ q = q._set_enable_single_crit(False)
+ to_join, local_attr, parent_alias = self._prep_for_joins(
+ left_alias, subq_path
+ )
+
+ q = q.add_columns(*local_attr)
+ q = self._apply_joins(
+ q, to_join, left_alias, parent_alias, effective_entity
+ )
+
+ q = self._setup_options(
+ context,
+ q,
+ subq_path,
+ rewritten_path,
+ orig_query,
+ effective_entity,
+ loadopt,
+ )
+ q = self._setup_outermost_orderby(q)
+
+ return q
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+
+ if context.refresh_state:
+ return self._immediateload_create_row_processor(
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ )
+        # unlike the other post loaders, the subqueryloader does a similar
+        # check in setup_query(); it is repeated here for consistency
+ elif self._check_recursive_postload(context, path, self.join_depth):
+ return
+ elif not isinstance(context.compile_state, ORMSelectCompileState):
+ # issue 7505 - subqueryload() in 1.3 and previous would silently
+ # degrade for from_statement() without warning. this behavior
+ # is restored here
+ return
+
+ if not self.parent.class_manager[self.key].impl.supports_population:
+ raise sa_exc.InvalidRequestError(
+ "'%s' does not support object "
+ "population - eager loading cannot be applied." % self
+ )
+
+ # a little dance here as the "path" is still something that only
+ # semi-tracks the exact series of things we are loading, still not
+ # telling us about with_polymorphic() and stuff like that when it's at
+ # the root.. the initial MapperEntity is more accurate for this case.
+ if len(path) == 1:
+ if not orm_util._entity_isa(query_entity.entity_zero, self.parent):
+ return
+ elif not orm_util._entity_isa(path[-1], self.parent):
+ return
+
+ subq = self._setup_query_from_rowproc(
+ context,
+ query_entity,
+ path,
+ path[-1],
+ loadopt,
+ adapter,
+ )
+
+ if subq is None:
+ return
+
+ assert subq.session is None
+
+ path = path[self.parent_property]
+
+ local_cols = self.parent_property.local_columns
+
+ # cache the loaded collections in the context
+ # so that inheriting mappers don't re-load when they
+ # call upon create_row_processor again
+ collections = path.get(context.attributes, "collections")
+ if collections is None:
+ collections = self._SubqCollections(context, subq)
+ path.set(context.attributes, "collections", collections)
+
+ if adapter:
+ local_cols = [adapter.columns[c] for c in local_cols]
+
+ if self.uselist:
+ self._create_collection_loader(
+ context, result, collections, local_cols, populators
+ )
+ else:
+ self._create_scalar_loader(
+ context, result, collections, local_cols, populators
+ )
+
+ def _create_collection_loader(
+ self, context, result, collections, local_cols, populators
+ ):
+ tuple_getter = result._tuple_getter(local_cols)
+
+ def load_collection_from_subq(state, dict_, row):
+ collection = collections.get(tuple_getter(row), ())
+ state.get_impl(self.key).set_committed_value(
+ state, dict_, collection
+ )
+
+ def load_collection_from_subq_existing_row(state, dict_, row):
+ if self.key not in dict_:
+ load_collection_from_subq(state, dict_, row)
+
+ populators["new"].append((self.key, load_collection_from_subq))
+ populators["existing"].append(
+ (self.key, load_collection_from_subq_existing_row)
+ )
+
+ if context.invoke_all_eagers:
+ populators["eager"].append((self.key, collections.loader))
+
+ def _create_scalar_loader(
+ self, context, result, collections, local_cols, populators
+ ):
+ tuple_getter = result._tuple_getter(local_cols)
+
+ def load_scalar_from_subq(state, dict_, row):
+ collection = collections.get(tuple_getter(row), (None,))
+ if len(collection) > 1:
+ util.warn(
+ "Multiple rows returned with "
+ "uselist=False for eagerly-loaded attribute '%s' " % self
+ )
+
+ scalar = collection[0]
+ state.get_impl(self.key).set_committed_value(state, dict_, scalar)
+
+ def load_scalar_from_subq_existing_row(state, dict_, row):
+ if self.key not in dict_:
+ load_scalar_from_subq(state, dict_, row)
+
+ populators["new"].append((self.key, load_scalar_from_subq))
+ populators["existing"].append(
+ (self.key, load_scalar_from_subq_existing_row)
+ )
+ if context.invoke_all_eagers:
+ populators["eager"].append((self.key, collections.loader))
+
+
+@log.class_logger
+@relationships.RelationshipProperty.strategy_for(lazy="joined")
+@relationships.RelationshipProperty.strategy_for(lazy=False)
+class JoinedLoader(AbstractRelationshipLoader):
+ """Provide loading behavior for a :class:`.RelationshipProperty`
+ using joined eager loading.
+
+ """
+
+ __slots__ = "join_depth", "_aliased_class_pool"
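+    # usage sketch (hypothetical Parent/Child models, not part of this
+    # module):
+    #
+    #     session.query(Parent).options(joinedload(Parent.children))
+    #
+    # the related rows are delivered in the same result as the parent by
+    # adding a LEFT OUTER JOIN (or an inner join, when innerjoin is set)
+    # against an alias of the related table.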
+
+ def __init__(self, parent, strategy_key):
+ super(JoinedLoader, self).__init__(parent, strategy_key)
+ self.join_depth = self.parent_property.join_depth
+ self._aliased_class_pool = []
+
+ def init_class_attribute(self, mapper):
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).init_class_attribute(mapper)
+
+ def setup_query(
+ self,
+ compile_state,
+ query_entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection=None,
+ parentmapper=None,
+ chained_from_outerjoin=False,
+ **kwargs
+ ):
+ """Add a left outer join to the statement that's being constructed."""
+
+ if not compile_state.compile_options._enable_eagerloads:
+ return
+ elif self.uselist:
+ compile_state.multi_row_eager_loaders = True
+
+ path = path[self.parent_property]
+
+ with_polymorphic = None
+
+ user_defined_adapter = (
+ self._init_user_defined_eager_proc(
+ loadopt, compile_state, compile_state.attributes
+ )
+ if loadopt
+ else False
+ )
+
+ if user_defined_adapter is not False:
+ (
+ clauses,
+ adapter,
+ add_to_collection,
+ ) = self._setup_query_on_user_defined_adapter(
+ compile_state,
+ query_entity,
+ path,
+ adapter,
+ user_defined_adapter,
+ )
+ else:
+ # if not via query option, check for
+ # a cycle
+ if not path.contains(compile_state.attributes, "loader"):
+ if self.join_depth:
+ if path.length / 2 > self.join_depth:
+ return
+ elif path.contains_mapper(self.mapper):
+ return
+
+ (
+ clauses,
+ adapter,
+ add_to_collection,
+ chained_from_outerjoin,
+ ) = self._generate_row_adapter(
+ compile_state,
+ query_entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ parentmapper,
+ chained_from_outerjoin,
+ )
+
+ with_poly_entity = path.get(
+ compile_state.attributes, "path_with_polymorphic", None
+ )
+ if with_poly_entity is not None:
+ with_polymorphic = inspect(
+ with_poly_entity
+ ).with_polymorphic_mappers
+ else:
+ with_polymorphic = None
+
+ path = path[self.entity]
+
+ loading._setup_entity_query(
+ compile_state,
+ self.mapper,
+ query_entity,
+ path,
+ clauses,
+ add_to_collection,
+ with_polymorphic=with_polymorphic,
+ parentmapper=self.mapper,
+ chained_from_outerjoin=chained_from_outerjoin,
+ )
+
+ if with_poly_entity is not None and None in set(
+ compile_state.secondary_columns
+ ):
+ raise sa_exc.InvalidRequestError(
+ "Detected unaliased columns when generating joined "
+ "load. Make sure to use aliased=True or flat=True "
+ "when using joined loading with with_polymorphic()."
+ )
+
+ def _init_user_defined_eager_proc(
+ self, loadopt, compile_state, target_attributes
+ ):
+
+ # check if the opt applies at all
+ if "eager_from_alias" not in loadopt.local_opts:
+ # nope
+ return False
+
+ path = loadopt.path.parent
+
+ # the option applies. check if the "user_defined_eager_row_processor"
+ # has been built up.
+ adapter = path.get(
+ compile_state.attributes, "user_defined_eager_row_processor", False
+ )
+ if adapter is not False:
+ # just return it
+ return adapter
+
+ # otherwise figure it out.
+ alias = loadopt.local_opts["eager_from_alias"]
+ root_mapper, prop = path[-2:]
+
+ if alias is not None:
+ if isinstance(alias, str):
+ alias = prop.target.alias(alias)
+ adapter = sql_util.ColumnAdapter(
+ alias, equivalents=prop.mapper._equivalent_columns
+ )
+ else:
+ if path.contains(
+ compile_state.attributes, "path_with_polymorphic"
+ ):
+ with_poly_entity = path.get(
+ compile_state.attributes, "path_with_polymorphic"
+ )
+ adapter = orm_util.ORMAdapter(
+ with_poly_entity,
+ equivalents=prop.mapper._equivalent_columns,
+ )
+ else:
+ adapter = compile_state._polymorphic_adapters.get(
+ prop.mapper, None
+ )
+ path.set(
+ target_attributes,
+ "user_defined_eager_row_processor",
+ adapter,
+ )
+
+ return adapter
+
+ def _setup_query_on_user_defined_adapter(
+ self, context, entity, path, adapter, user_defined_adapter
+ ):
+
+ # apply some more wrapping to the "user defined adapter"
+ # if we are setting up the query for SQL render.
+ adapter = entity._get_entity_clauses(context)
+
+ if adapter and user_defined_adapter:
+ user_defined_adapter = user_defined_adapter.wrap(adapter)
+ path.set(
+ context.attributes,
+ "user_defined_eager_row_processor",
+ user_defined_adapter,
+ )
+ elif adapter:
+ user_defined_adapter = adapter
+ path.set(
+ context.attributes,
+ "user_defined_eager_row_processor",
+ user_defined_adapter,
+ )
+
+ add_to_collection = context.primary_columns
+ return user_defined_adapter, adapter, add_to_collection
+
+ def _gen_pooled_aliased_class(self, context):
+ # keep a local pool of AliasedClass objects that get re-used.
+ # we need one unique AliasedClass per query per appearance of our
+ # entity in the query.
+
+ if inspect(self.entity).is_aliased_class:
+ alt_selectable = inspect(self.entity).selectable
+ else:
+ alt_selectable = None
+
+ key = ("joinedloader_ac", self)
+ if key not in context.attributes:
+ context.attributes[key] = idx = 0
+ else:
+ context.attributes[key] = idx = context.attributes[key] + 1
+
+ if idx >= len(self._aliased_class_pool):
+ to_adapt = orm_util.AliasedClass(
+ self.mapper,
+ alias=alt_selectable._anonymous_fromclause(flat=True)
+ if alt_selectable is not None
+ else None,
+ flat=True,
+ use_mapper_path=True,
+ )
+
+ # load up the .columns collection on the Alias() before
+ # the object becomes shared among threads. this prevents
+ # races for column identities.
+ inspect(to_adapt).selectable.c
+ self._aliased_class_pool.append(to_adapt)
+
+ return self._aliased_class_pool[idx]
+
+ def _generate_row_adapter(
+ self,
+ compile_state,
+ entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ parentmapper,
+ chained_from_outerjoin,
+ ):
+ with_poly_entity = path.get(
+ compile_state.attributes, "path_with_polymorphic", None
+ )
+ if with_poly_entity:
+ to_adapt = with_poly_entity
+ else:
+ to_adapt = self._gen_pooled_aliased_class(compile_state)
+
+ clauses = inspect(to_adapt)._memo(
+ ("joinedloader_ormadapter", self),
+ orm_util.ORMAdapter,
+ to_adapt,
+ equivalents=self.mapper._equivalent_columns,
+ adapt_required=True,
+ allow_label_resolve=False,
+ anonymize_labels=True,
+ )
+
+ assert clauses.aliased_class is not None
+
+ innerjoin = (
+ loadopt.local_opts.get("innerjoin", self.parent_property.innerjoin)
+ if loadopt is not None
+ else self.parent_property.innerjoin
+ )
+
+ if not innerjoin:
+ # if this is an outer join, all non-nested eager joins from
+ # this path must also be outer joins
+ chained_from_outerjoin = True
+
+ compile_state.create_eager_joins.append(
+ (
+ self._create_eager_join,
+ entity,
+ path,
+ adapter,
+ parentmapper,
+ clauses,
+ innerjoin,
+ chained_from_outerjoin,
+ loadopt._extra_criteria if loadopt else (),
+ )
+ )
+
+ add_to_collection = compile_state.secondary_columns
+ path.set(compile_state.attributes, "eager_row_processor", clauses)
+
+ return clauses, adapter, add_to_collection, chained_from_outerjoin
+
+ def _create_eager_join(
+ self,
+ compile_state,
+ query_entity,
+ path,
+ adapter,
+ parentmapper,
+ clauses,
+ innerjoin,
+ chained_from_outerjoin,
+ extra_criteria,
+ ):
+ if parentmapper is None:
+ localparent = query_entity.mapper
+ else:
+ localparent = parentmapper
+
+ # whether or not the Query will wrap the selectable in a subquery,
+ # and then attach eager load joins to that (i.e., in the case of
+ # LIMIT/OFFSET etc.)
+ should_nest_selectable = (
+ compile_state.multi_row_eager_loaders
+ and compile_state._should_nest_selectable
+ )
+
+ query_entity_key = None
+
+ if (
+ query_entity not in compile_state.eager_joins
+ and not should_nest_selectable
+ and compile_state.from_clauses
+ ):
+
+ indexes = sql_util.find_left_clause_that_matches_given(
+ compile_state.from_clauses, query_entity.selectable
+ )
+
+ if len(indexes) > 1:
+ # for the eager load case, I can't reproduce this right
+ # now. For query.join() I can.
+                raise sa_exc.InvalidRequestError(
+                    "Can't identify which query entity from which to perform "
+                    "the joined eager load. Please use an exact match when "
+                    "specifying the join path."
+                )
+
+ if indexes:
+ clause = compile_state.from_clauses[indexes[0]]
+ # join to an existing FROM clause on the query.
+ # key it to its list index in the eager_joins dict.
+ # Query._compile_context will adapt as needed and
+ # append to the FROM clause of the select().
+ query_entity_key, default_towrap = indexes[0], clause
+
+ if query_entity_key is None:
+ query_entity_key, default_towrap = (
+ query_entity,
+ query_entity.selectable,
+ )
+
+ towrap = compile_state.eager_joins.setdefault(
+ query_entity_key, default_towrap
+ )
+
+ if adapter:
+ if getattr(adapter, "aliased_class", None):
+ # joining from an adapted entity. The adapted entity
+ # might be a "with_polymorphic", so resolve that to our
+ # specific mapper's entity before looking for our attribute
+ # name on it.
+ efm = inspect(adapter.aliased_class)._entity_for_mapper(
+ localparent
+ if localparent.isa(self.parent)
+ else self.parent
+ )
+
+ # look for our attribute on the adapted entity, else fall back
+ # to our straight property
+ onclause = getattr(efm.entity, self.key, self.parent_property)
+ else:
+ onclause = getattr(
+ orm_util.AliasedClass(
+ self.parent, adapter.selectable, use_mapper_path=True
+ ),
+ self.key,
+ self.parent_property,
+ )
+
+ else:
+ onclause = self.parent_property
+
+ assert clauses.aliased_class is not None
+
+ attach_on_outside = (
+ not chained_from_outerjoin
+ or not innerjoin
+ or innerjoin == "unnested"
+ or query_entity.entity_zero.represents_outer_join
+ )
+
+ extra_join_criteria = extra_criteria
+ additional_entity_criteria = compile_state.global_attributes.get(
+ ("additional_entity_criteria", self.mapper), ()
+ )
+ if additional_entity_criteria:
+ extra_join_criteria += tuple(
+ ae._resolve_where_criteria(self.mapper)
+ for ae in additional_entity_criteria
+ if ae.propagate_to_loaders
+ )
+
+ if attach_on_outside:
+ # this is the "classic" eager join case.
+ eagerjoin = orm_util._ORMJoin(
+ towrap,
+ clauses.aliased_class,
+ onclause,
+ isouter=not innerjoin
+ or query_entity.entity_zero.represents_outer_join
+ or (chained_from_outerjoin and isinstance(towrap, sql.Join)),
+ _left_memo=self.parent,
+ _right_memo=self.mapper,
+ _extra_criteria=extra_join_criteria,
+ )
+ else:
+ # all other cases are innerjoin=='nested' approach
+ eagerjoin = self._splice_nested_inner_join(
+ path, towrap, clauses, onclause, extra_join_criteria
+ )
+
+ compile_state.eager_joins[query_entity_key] = eagerjoin
+
+ # send a hint to the Query as to where it may "splice" this join
+ eagerjoin.stop_on = query_entity.selectable
+
+ if not parentmapper:
+ # for parentclause that is the non-eager end of the join,
+ # ensure all the parent cols in the primaryjoin are actually
+ # in the
+ # columns clause (i.e. are not deferred), so that aliasing applied
+ # by the Query propagates those columns outward.
+ # This has the effect
+ # of "undefering" those columns.
+ for col in sql_util._find_columns(
+ self.parent_property.primaryjoin
+ ):
+ if localparent.persist_selectable.c.contains_column(col):
+ if adapter:
+ col = adapter.columns[col]
+ compile_state._append_dedupe_col_collection(
+ col, compile_state.primary_columns
+ )
+
+ if self.parent_property.order_by:
+ compile_state.eager_order_by += tuple(
+ (eagerjoin._target_adapter.copy_and_process)(
+ util.to_list(self.parent_property.order_by)
+ )
+ )
+
+ def _splice_nested_inner_join(
+ self, path, join_obj, clauses, onclause, extra_criteria, splicing=False
+ ):
+
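+        # recursively descend the existing join structure to locate the join
+        # corresponding to our parent in the path, and splice the new INNER
+        # JOIN inside it, so that the "nested" inner join stays within the
+        # enclosing outer join rather than being attached outside of it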
+ if splicing is False:
+ # first call is always handed a join object
+ # from the outside
+ assert isinstance(join_obj, orm_util._ORMJoin)
+ elif isinstance(join_obj, sql.selectable.FromGrouping):
+ return self._splice_nested_inner_join(
+ path,
+ join_obj.element,
+ clauses,
+ onclause,
+ extra_criteria,
+ splicing,
+ )
+ elif not isinstance(join_obj, orm_util._ORMJoin):
+ if path[-2] is splicing:
+ return orm_util._ORMJoin(
+ join_obj,
+ clauses.aliased_class,
+ onclause,
+ isouter=False,
+ _left_memo=splicing,
+ _right_memo=path[-1].mapper,
+ _extra_criteria=extra_criteria,
+ )
+ else:
+ # only here if splicing == True
+ return None
+
+ target_join = self._splice_nested_inner_join(
+ path,
+ join_obj.right,
+ clauses,
+ onclause,
+ extra_criteria,
+ join_obj._right_memo,
+ )
+ if target_join is None:
+ right_splice = False
+ target_join = self._splice_nested_inner_join(
+ path,
+ join_obj.left,
+ clauses,
+ onclause,
+ extra_criteria,
+ join_obj._left_memo,
+ )
+ if target_join is None:
+ # should only return None when recursively called,
+ # e.g. splicing==True
+ assert (
+ splicing is not False
+ ), "assertion failed attempting to produce joined eager loads"
+ return None
+ else:
+ right_splice = True
+
+ if right_splice:
+ # for a right splice, attempt to flatten out
+ # a JOIN b JOIN c JOIN .. to avoid needless
+ # parenthesis nesting
+ if not join_obj.isouter and not target_join.isouter:
+ eagerjoin = join_obj._splice_into_center(target_join)
+ else:
+ eagerjoin = orm_util._ORMJoin(
+ join_obj.left,
+ target_join,
+ join_obj.onclause,
+ isouter=join_obj.isouter,
+ _left_memo=join_obj._left_memo,
+ )
+ else:
+ eagerjoin = orm_util._ORMJoin(
+ target_join,
+ join_obj.right,
+ join_obj.onclause,
+ isouter=join_obj.isouter,
+ _right_memo=join_obj._right_memo,
+ )
+
+ eagerjoin._target_adapter = target_join._target_adapter
+ return eagerjoin
+
+ def _create_eager_adapter(self, context, result, adapter, path, loadopt):
+ compile_state = context.compile_state
+
+ user_defined_adapter = (
+ self._init_user_defined_eager_proc(
+ loadopt, compile_state, context.attributes
+ )
+ if loadopt
+ else False
+ )
+
+ if user_defined_adapter is not False:
+ decorator = user_defined_adapter
+ # user defined eagerloads are part of the "primary"
+ # portion of the load.
+ # the adapters applied to the Query should be honored.
+ if compile_state.compound_eager_adapter and decorator:
+ decorator = decorator.wrap(
+ compile_state.compound_eager_adapter
+ )
+ elif compile_state.compound_eager_adapter:
+ decorator = compile_state.compound_eager_adapter
+ else:
+ decorator = path.get(
+ compile_state.attributes, "eager_row_processor"
+ )
+ if decorator is None:
+ return False
+
+ if self.mapper._result_has_identity_key(result, decorator):
+ return decorator
+ else:
+ # no identity key - don't return a row
+ # processor, will cause a degrade to lazy
+ return False
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ if not self.parent.class_manager[self.key].impl.supports_population:
+ raise sa_exc.InvalidRequestError(
+ "'%s' does not support object "
+ "population - eager loading cannot be applied." % self
+ )
+
+ if self.uselist:
+ context.loaders_require_uniquing = True
+
+ our_path = path[self.parent_property]
+
+ eager_adapter = self._create_eager_adapter(
+ context, result, adapter, our_path, loadopt
+ )
+
+ if eager_adapter is not False:
+ key = self.key
+
+ _instance = loading._instance_processor(
+ query_entity,
+ self.mapper,
+ context,
+ result,
+ our_path[self.entity],
+ eager_adapter,
+ )
+
+ if not self.uselist:
+ self._create_scalar_loader(context, key, _instance, populators)
+ else:
+ self._create_collection_loader(
+ context, key, _instance, populators
+ )
+ else:
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).create_row_processor(
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ )
+
+ def _create_collection_loader(self, context, key, _instance, populators):
+ def load_collection_from_joined_new_row(state, dict_, row):
+ # note this must unconditionally clear out any existing collection.
+ # an existing collection would be present only in the case of
+ # populate_existing().
+ collection = attributes.init_state_collection(state, dict_, key)
+ result_list = util.UniqueAppender(
+ collection, "append_without_event"
+ )
+ context.attributes[(state, key)] = result_list
+ inst = _instance(row)
+ if inst is not None:
+ result_list.append(inst)
+
+ def load_collection_from_joined_existing_row(state, dict_, row):
+ if (state, key) in context.attributes:
+ result_list = context.attributes[(state, key)]
+ else:
+ # appender_key can be absent from context.attributes
+ # with isnew=False when self-referential eager loading
+ # is used; the same instance may be present in two
+ # distinct sets of result columns
+ collection = attributes.init_state_collection(
+ state, dict_, key
+ )
+ result_list = util.UniqueAppender(
+ collection, "append_without_event"
+ )
+ context.attributes[(state, key)] = result_list
+ inst = _instance(row)
+ if inst is not None:
+ result_list.append(inst)
+
+ def load_collection_from_joined_exec(state, dict_, row):
+ _instance(row)
+
+ populators["new"].append(
+ (self.key, load_collection_from_joined_new_row)
+ )
+ populators["existing"].append(
+ (self.key, load_collection_from_joined_existing_row)
+ )
+ if context.invoke_all_eagers:
+ populators["eager"].append(
+ (self.key, load_collection_from_joined_exec)
+ )
+
+ def _create_scalar_loader(self, context, key, _instance, populators):
+ def load_scalar_from_joined_new_row(state, dict_, row):
+ # set a scalar object instance directly on the parent
+ # object, bypassing InstrumentedAttribute event handlers.
+ dict_[key] = _instance(row)
+
+ def load_scalar_from_joined_existing_row(state, dict_, row):
+ # call _instance on the row, even though the object has
+ # been created, so that we further descend into properties
+ existing = _instance(row)
+
+ # conflicting value already loaded, this shouldn't happen
+ if key in dict_:
+ if existing is not dict_[key]:
+ util.warn(
+ "Multiple rows returned with "
+ "uselist=False for eagerly-loaded attribute '%s' "
+ % self
+ )
+ else:
+ # this case is when one row has multiple loads of the
+ # same entity (e.g. via aliasing), one has an attribute
+ # that the other doesn't.
+ dict_[key] = existing
+
+ def load_scalar_from_joined_exec(state, dict_, row):
+ _instance(row)
+
+ populators["new"].append((self.key, load_scalar_from_joined_new_row))
+ populators["existing"].append(
+ (self.key, load_scalar_from_joined_existing_row)
+ )
+ if context.invoke_all_eagers:
+ populators["eager"].append(
+ (self.key, load_scalar_from_joined_exec)
+ )
+
+
+@log.class_logger
+@relationships.RelationshipProperty.strategy_for(lazy="selectin")
+class SelectInLoader(PostLoader, util.MemoizedSlots):
+ __slots__ = (
+ "join_depth",
+ "omit_join",
+ "_parent_alias",
+ "_query_info",
+ "_fallback_query_info",
+ )
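+    # usage sketch (hypothetical Parent/Child models, not part of this
+    # module):
+    #
+    #     session.query(Parent).options(selectinload(Parent.children))
+    #
+    # related objects are loaded by a second SELECT using an IN expression
+    # against key values gathered from the parent result (the parent primary
+    # key, or the foreign key columns in the "omit_join" cases).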
+
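+    # query_info describes how the "SELECT IN" load is performed:
+    #   load_only_child - many-to-one "omit join" case; results are keyed by
+    #       the related (child) primary key
+    #   load_with_join  - related rows are reached by joining from an alias
+    #       of the parent
+    #   in_expr/pk_cols - the IN expression and the key columns it is built
+    #       from
+    #   zero_idx        - the key is a single column, so key values are
+    #       passed unwrapped rather than as one-element tuples
+    #   child_lookup_cols - parent-side columns used to derive the child key
+    #       in the load_only_child case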
+ query_info = collections.namedtuple(
+ "queryinfo",
+ [
+ "load_only_child",
+ "load_with_join",
+ "in_expr",
+ "pk_cols",
+ "zero_idx",
+ "child_lookup_cols",
+ ],
+ )
+
+ _chunksize = 500
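+    # key values gathered from the parent result are batched into IN
+    # expressions of at most _chunksize parameters per emitted SELECT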
+
+ def __init__(self, parent, strategy_key):
+ super(SelectInLoader, self).__init__(parent, strategy_key)
+ self.join_depth = self.parent_property.join_depth
+ is_m2o = self.parent_property.direction is interfaces.MANYTOONE
+
+ if self.parent_property.omit_join is not None:
+ self.omit_join = self.parent_property.omit_join
+ else:
+ lazyloader = self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ )
+ if is_m2o:
+ self.omit_join = lazyloader.use_get
+ else:
+ self.omit_join = self.parent._get_clause[0].compare(
+ lazyloader._rev_lazywhere,
+ use_proxies=True,
+ compare_keys=False,
+ equivalents=self.parent._equivalent_columns,
+ )
+
+ if self.omit_join:
+ if is_m2o:
+ self._query_info = self._init_for_omit_join_m2o()
+ self._fallback_query_info = self._init_for_join()
+ else:
+ self._query_info = self._init_for_omit_join()
+ else:
+ self._query_info = self._init_for_join()
+
+ def _init_for_omit_join(self):
+ pk_to_fk = dict(
+ self.parent_property._join_condition.local_remote_pairs
+ )
+ pk_to_fk.update(
+ (equiv, pk_to_fk[k])
+ for k in list(pk_to_fk)
+ for equiv in self.parent._equivalent_columns.get(k, ())
+ )
+
+ pk_cols = fk_cols = [
+ pk_to_fk[col] for col in self.parent.primary_key if col in pk_to_fk
+ ]
+ if len(fk_cols) > 1:
+ in_expr = sql.tuple_(*fk_cols)
+ zero_idx = False
+ else:
+ in_expr = fk_cols[0]
+ zero_idx = True
+
+ return self.query_info(False, False, in_expr, pk_cols, zero_idx, None)
+
+ def _init_for_omit_join_m2o(self):
+ pk_cols = self.mapper.primary_key
+ if len(pk_cols) > 1:
+ in_expr = sql.tuple_(*pk_cols)
+ zero_idx = False
+ else:
+ in_expr = pk_cols[0]
+ zero_idx = True
+
+ lazyloader = self.parent_property._get_strategy((("lazy", "select"),))
+ lookup_cols = [lazyloader._equated_columns[pk] for pk in pk_cols]
+
+ return self.query_info(
+ True, False, in_expr, pk_cols, zero_idx, lookup_cols
+ )
+
+ def _init_for_join(self):
+ self._parent_alias = aliased(self.parent.class_)
+ pa_insp = inspect(self._parent_alias)
+ pk_cols = [
+ pa_insp._adapt_element(col) for col in self.parent.primary_key
+ ]
+ if len(pk_cols) > 1:
+ in_expr = sql.tuple_(*pk_cols)
+ zero_idx = False
+ else:
+ in_expr = pk_cols[0]
+ zero_idx = True
+ return self.query_info(False, True, in_expr, pk_cols, zero_idx, None)
+
+ def init_class_attribute(self, mapper):
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).init_class_attribute(mapper)
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+
+ if context.refresh_state:
+ return self._immediateload_create_row_processor(
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ )
+ elif self._check_recursive_postload(context, path, self.join_depth):
+ return
+
+ if not self.parent.class_manager[self.key].impl.supports_population:
+ raise sa_exc.InvalidRequestError(
+ "'%s' does not support object "
+ "population - eager loading cannot be applied." % self
+ )
+
+ # a little dance here as the "path" is still something that only
+ # semi-tracks the exact series of things we are loading, still not
+ # telling us about with_polymorphic() and stuff like that when it's at
+ # the root.. the initial MapperEntity is more accurate for this case.
+ if len(path) == 1:
+ if not orm_util._entity_isa(query_entity.entity_zero, self.parent):
+ return
+ elif not orm_util._entity_isa(path[-1], self.parent):
+ return
+
+ selectin_path = (
+ context.compile_state.current_path or orm_util.PathRegistry.root
+ ) + path
+
+ path_w_prop = path[self.parent_property]
+
+ # build up a path indicating the path from the leftmost
+ # entity to the thing we're subquery loading.
+ with_poly_entity = path_w_prop.get(
+ context.attributes, "path_with_polymorphic", None
+ )
+ if with_poly_entity is not None:
+ effective_entity = inspect(with_poly_entity)
+ else:
+ effective_entity = self.entity
+
+ loading.PostLoad.callable_for_path(
+ context,
+ selectin_path,
+ self.parent,
+ self.parent_property,
+ self._load_for_path,
+ effective_entity,
+ loadopt,
+ )
+
+ def _load_for_path(
+ self, context, path, states, load_only, effective_entity, loadopt
+ ):
+ if load_only and self.key not in load_only:
+ return
+
+ query_info = self._query_info
+
+ if query_info.load_only_child:
+ our_states = collections.defaultdict(list)
+ none_states = []
+
+ mapper = self.parent
+
+ for state, overwrite in states:
+ state_dict = state.dict
+ related_ident = tuple(
+ mapper._get_state_attr_by_column(
+ state,
+ state_dict,
+ lk,
+ passive=attributes.PASSIVE_NO_FETCH,
+ )
+ for lk in query_info.child_lookup_cols
+ )
+ # if the loaded parent objects do not have the foreign key
+ # to the related item loaded, then degrade into the joined
+ # version of selectinload
+ if attributes.PASSIVE_NO_RESULT in related_ident:
+ query_info = self._fallback_query_info
+ break
+
+ # organize states into lists keyed to particular foreign
+ # key values.
+ if None not in related_ident:
+ our_states[related_ident].append(
+ (state, state_dict, overwrite)
+ )
+ else:
+ # For FK values that have None, add them to a
+ # separate collection that will be populated separately
+ none_states.append((state, state_dict, overwrite))
+
+ # note the above conditional may have changed query_info
+ if not query_info.load_only_child:
+ our_states = [
+ (state.key[1], state, state.dict, overwrite)
+ for state, overwrite in states
+ ]
+
+ pk_cols = query_info.pk_cols
+ in_expr = query_info.in_expr
+
+ if not query_info.load_with_join:
+ # in "omit join" mode, the primary key column and the
+ # "in" expression are in terms of the related entity. So
+ # if the related entity is polymorphic or otherwise aliased,
+ # we need to adapt our "pk_cols" and "in_expr" to that
+ # entity. in non-"omit join" mode, these are against the
+ # parent entity and do not need adaption.
+ if effective_entity.is_aliased_class:
+ pk_cols = [
+ effective_entity._adapt_element(col) for col in pk_cols
+ ]
+ in_expr = effective_entity._adapt_element(in_expr)
+
+ bundle_ent = orm_util.Bundle("pk", *pk_cols)
+ bundle_sql = bundle_ent.__clause_element__()
+
+ entity_sql = effective_entity.__clause_element__()
+ q = Select._create_raw_select(
+ _raw_columns=[bundle_sql, entity_sql],
+ _label_style=LABEL_STYLE_TABLENAME_PLUS_COL,
+ _compile_options=ORMCompileState.default_compile_options,
+ _propagate_attrs={
+ "compile_state_plugin": "orm",
+ "plugin_subject": effective_entity,
+ },
+ )
+
+ if not query_info.load_with_join:
+ # the Bundle we have in the "omit_join" case is against raw, non
+ # annotated columns, so to ensure the Query knows its primary
+ # entity, we add it explicitly. If we made the Bundle against
+ # annotated columns, we hit a performance issue in this specific
+ # case, which is detailed in issue #4347.
+ q = q.select_from(effective_entity)
+ else:
+ # in the non-omit_join case, the Bundle is against the annotated/
+ # mapped column of the parent entity, but the #4347 issue does not
+ # occur in this case.
+ q = q.select_from(self._parent_alias).join(
+ getattr(self._parent_alias, self.parent_property.key).of_type(
+ effective_entity
+ )
+ )
+
+ q = q.filter(in_expr.in_(sql.bindparam("primary_keys")))
+
+ # a test which exercises what these comments talk about is
+ # test_selectin_relations.py -> test_twolevel_selectin_w_polymorphic
+ #
+ # effective_entity above is given to us in terms of the cached
+ # statement, namely this one:
+ orig_query = context.compile_state.select_statement
+
+ # the actual statement that was requested is this one:
+ # context_query = context.query
+ #
+ # that's not the cached one, however. So while it is of the identical
+ # structure, if it has entities like AliasedInsp, which we get from
+ # aliased() or with_polymorphic(), the AliasedInsp will likely be a
+ # different object identity each time, and will not match up
+ # hashing-wise to the corresponding AliasedInsp that's in the
+ # cached query, meaning it won't match on paths and loader lookups
+ # and loaders like this one will be skipped if it is used in options.
+ #
+ # Now we want to transfer loader options from the parent query to the
+ # "selectinload" query we're about to run. Which query do we transfer
+ # the options from? We use the cached query, because the options in
+ # that query will be in terms of the effective entity we were just
+ # handed.
+ #
+ # But now the selectinload query we are running is *also*
+ # cached. What if it's cached and running from some previous iteration
+ # of that AliasedInsp? Well in that case it will also use the previous
+ # iteration of the loader options. If the query expires and
+ # gets generated again, it will be handed the current effective_entity
+ # and the current _with_options, again in terms of whatever
+ # compile_state.select_statement happens to be right now, so the
+ # query will still be internally consistent and loader callables
+ # will be correctly invoked.
+
+ effective_path = path[self.parent_property]
+
+ if orig_query is context.query:
+ options = new_options = orig_query._with_options
+ user_defined_options = []
+ else:
+ options = orig_query._with_options
+
+ # propagate compile state options from the original query,
+ # updating their "extra_criteria" as necessary.
+ # note this will create a different cache key than
+ # "orig" options if extra_criteria is present, because the copy
+ # of extra_criteria will have different boundparam than that of
+ # the QueryableAttribute in the path
+
+ new_options = [
+ orig_opt._adjust_for_extra_criteria(context)
+ if orig_opt._is_strategy_option
+ else orig_opt
+ for orig_opt in options
+ if orig_opt._is_compile_state or orig_opt._is_legacy_option
+ ]
+
+ # propagate user defined options from the current query
+ user_defined_options = [
+ opt
+ for opt in context.query._with_options
+ if not opt._is_compile_state and not opt._is_legacy_option
+ ]
+
+ if loadopt and loadopt._extra_criteria:
+ new_options += (
+ orm_util.LoaderCriteriaOption(
+ effective_entity,
+ loadopt._generate_extra_criteria(context),
+ ),
+ )
+
+ q = q.options(*new_options)._update_compile_options(
+ {"_current_path": effective_path}
+ )
+ if user_defined_options:
+ q = q.options(*user_defined_options)
+
+ if context.populate_existing:
+ q = q.execution_options(populate_existing=True)
+
+ if self.parent_property.order_by:
+ if not query_info.load_with_join:
+ eager_order_by = self.parent_property.order_by
+ if effective_entity.is_aliased_class:
+ eager_order_by = [
+ effective_entity._adapt_element(elem)
+ for elem in eager_order_by
+ ]
+ q = q.order_by(*eager_order_by)
+ else:
+
+ def _setup_outermost_orderby(compile_context):
+ compile_context.eager_order_by += tuple(
+ util.to_list(self.parent_property.order_by)
+ )
+
+ q = q._add_context_option(
+ _setup_outermost_orderby, self.parent_property
+ )
+
+ if query_info.load_only_child:
+ self._load_via_child(
+ our_states, none_states, query_info, q, context
+ )
+ else:
+ self._load_via_parent(our_states, query_info, q, context)
+
+ def _load_via_child(self, our_states, none_states, query_info, q, context):
+ uselist = self.uselist
+
+ # this sort is really for the benefit of the unit tests
+ our_keys = sorted(our_states)
+ while our_keys:
+ chunk = our_keys[0 : self._chunksize]
+ our_keys = our_keys[self._chunksize :]
+ data = {
+ k: v
+ for k, v in context.session.execute(
+ q,
+ params={
+ "primary_keys": [
+ key[0] if query_info.zero_idx else key
+ for key in chunk
+ ]
+ },
+ ).unique()
+ }
+
+ for key in chunk:
+ # for a real foreign key and no concurrent changes to the
+ # DB while running this method, "key" is always present in
+ # data. However, for primaryjoins without real foreign keys
+ # a non-None primaryjoin condition may still refer to no
+ # related object.
+ related_obj = data.get(key, None)
+ for state, dict_, overwrite in our_states[key]:
+ if not overwrite and self.key in dict_:
+ continue
+
+ state.get_impl(self.key).set_committed_value(
+ state,
+ dict_,
+ related_obj if not uselist else [related_obj],
+ )
+ # populate none states with empty value / collection
+ for state, dict_, overwrite in none_states:
+ if not overwrite and self.key in dict_:
+ continue
+
+ # note it's OK if this is a uselist=True attribute, the empty
+ # collection will be populated
+ state.get_impl(self.key).set_committed_value(state, dict_, None)
+
+ def _load_via_parent(self, our_states, query_info, q, context):
+ uselist = self.uselist
+ _empty_result = () if uselist else None
+
+ while our_states:
+ chunk = our_states[0 : self._chunksize]
+ our_states = our_states[self._chunksize :]
+
+ primary_keys = [
+ key[0] if query_info.zero_idx else key
+ for key, state, state_dict, overwrite in chunk
+ ]
+
+ data = collections.defaultdict(list)
+ for k, v in itertools.groupby(
+ context.session.execute(
+ q, params={"primary_keys": primary_keys}
+ ).unique(),
+ lambda x: x[0],
+ ):
+ data[k].extend(vv[1] for vv in v)
+
+ for key, state, state_dict, overwrite in chunk:
+
+ if not overwrite and self.key in state_dict:
+ continue
+
+ collection = data.get(key, _empty_result)
+
+ if not uselist and collection:
+ if len(collection) > 1:
+ util.warn(
+ "Multiple rows returned with "
+ "uselist=False for eagerly-loaded "
+ "attribute '%s' " % self
+ )
+ state.get_impl(self.key).set_committed_value(
+ state, state_dict, collection[0]
+ )
+ else:
+ # note that empty tuple set on uselist=False sets the
+ # value to None
+ state.get_impl(self.key).set_committed_value(
+ state, state_dict, collection
+ )
+
+
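+# Validator used for relationships that enforce a single parent (e.g.
+# single_parent=True): it attaches "append" and "set" listeners to the
+# attribute and raises InvalidRequestError when a value that already has a
+# different parent is assigned.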
+def single_parent_validator(desc, prop):
+ def _do_check(state, value, oldvalue, initiator):
+ if value is not None and initiator.key == prop.key:
+ hasparent = initiator.hasparent(attributes.instance_state(value))
+ if hasparent and oldvalue is not value:
+ raise sa_exc.InvalidRequestError(
+ "Instance %s is already associated with an instance "
+ "of %s via its %s attribute, and is only allowed a "
+ "single parent."
+ % (orm_util.instance_str(value), state.class_, prop),
+ code="bbf1",
+ )
+ return value
+
+ def append(state, value, initiator):
+ return _do_check(state, value, None, initiator)
+
+ def set_(state, value, oldvalue, initiator):
+ return _do_check(state, value, oldvalue, initiator)
+
+ event.listen(
+ desc, "append", append, raw=True, retval=True, active_history=True
+ )
+ event.listen(desc, "set", set_, raw=True, retval=True, active_history=True)
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py
new file mode 100644
index 0000000..c3dd5df
--- /dev/null
+++ b/lib/sqlalchemy/orm/strategy_options.py
@@ -0,0 +1,2008 @@
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+
+"""
+
+from . import util as orm_util
+from .attributes import QueryableAttribute
+from .base import _class_to_mapper
+from .base import _is_aliased_class
+from .base import _is_mapped_class
+from .base import InspectionAttr
+from .interfaces import LoaderOption
+from .interfaces import MapperProperty
+from .interfaces import PropComparator
+from .path_registry import _DEFAULT_TOKEN
+from .path_registry import _WILDCARD_TOKEN
+from .path_registry import PathRegistry
+from .path_registry import TokenRegistry
+from .util import _orm_full_deannotate
+from .. import exc as sa_exc
+from .. import inspect
+from .. import util
+from ..sql import and_
+from ..sql import coercions
+from ..sql import roles
+from ..sql import traversals
+from ..sql import visitors
+from ..sql.base import _generative
+from ..sql.base import Generative
+
+
+class Load(Generative, LoaderOption):
+ """Represents loader options which modify the state of a
+ :class:`_query.Query` in order to affect how various mapped attributes are
+ loaded.
+
+ The :class:`_orm.Load` object is in most cases used implicitly behind the
+ scenes when one makes use of a query option like :func:`_orm.joinedload`,
+ :func:`.defer`, or similar. However, the :class:`_orm.Load` object
+ can also be used directly, and in some cases can be useful.
+
+ To use :class:`_orm.Load` directly, instantiate it with the target mapped
+ class as the argument. This style of usage is
+ useful when dealing with a :class:`_query.Query`
+ that has multiple entities::
+
+ myopt = Load(MyClass).joinedload("widgets")
+
+ The above ``myopt`` can now be used with :meth:`_query.Query.options`,
+ where it
+ will only take effect for the ``MyClass`` entity::
+
+ session.query(MyClass, MyOtherClass).options(myopt)
+
+ One case where :class:`_orm.Load`
+ is useful as public API is when specifying
+ "wildcard" options that only take effect for a certain class::
+
+ session.query(Order).options(Load(Order).lazyload('*'))
+
+ Above, all relationships on ``Order`` will be lazy-loaded, but other
+ attributes on those descendant objects will load using their normal
+ loader strategy.
+
+ .. seealso::
+
+ :ref:`deferred_options`
+
+ :ref:`deferred_loading_w_multiple`
+
+ :ref:`relationship_loader_options`
+
+ """
+
+ _is_strategy_option = True
+
+ _cache_key_traversal = [
+ ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+ ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ ("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+ ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
+ (
+ "_context_cache_key",
+ visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples,
+ ),
+ (
+ "local_opts",
+ visitors.ExtendedInternalTraversal.dp_string_multi_dict,
+ ),
+ ]
+
+ def __init__(self, entity):
+ insp = inspect(entity)
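+ # bare attribute access: evaluating _post_inspect gives the inspected
+ # entity a chance to complete mapper configuration before the path
+ # registry is built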
+ insp._post_inspect
+
+ self.path = insp._path_registry
+ # note that this .context is shared among all descendant
+ # Load objects
+ self.context = util.OrderedDict()
+ self.local_opts = {}
+ self.is_class_strategy = False
+
+ @classmethod
+ def for_existing_path(cls, path):
+ load = cls.__new__(cls)
+ load.path = path
+ load.context = {}
+ load.local_opts = {}
+ load._of_type = None
+ load._extra_criteria = ()
+ return load
+
+ def _generate_extra_criteria(self, context):
+ """Apply the current bound parameters in a QueryContext to the
+ immediate "extra_criteria" stored with this Load object.
+
+ Load objects are typically pulled from the cached version of
+ the statement from a QueryContext. The statement currently being
+ executed will have new values (and keys) for bound parameters in the
+ extra criteria which need to be applied by loader strategies when
+ they handle this criteria for a result set.
+
+ """
+
+ assert (
+ self._extra_criteria
+ ), "this should only be called if _extra_criteria is present"
+
+ orig_query = context.compile_state.select_statement
+ current_query = context.query
+
+ # NOTE: while it seems like we should not do the "apply" operation
+ # here if orig_query is current_query, skipping it in the "optimized"
+ # case causes the query to be different from a cache key perspective,
+ # because we are creating a copy of the criteria which is no longer
+ # the same identity of the _extra_criteria in the loader option
+ # itself. cache key logic produces a different key for
+ # (A, copy_of_A) vs. (A, A), because in the latter case it shortens
+ # the second part of the key to just indicate identity.
+
+ # if orig_query is current_query:
+ # not cached yet. just do the and_()
+ # return and_(*self._extra_criteria)
+
+ k1 = orig_query._generate_cache_key()
+ k2 = current_query._generate_cache_key()
+
+ return k2._apply_params_to_element(k1, and_(*self._extra_criteria))
+
+ def _adjust_for_extra_criteria(self, context):
+ """Apply the current bound parameters in a QueryContext to all
+ occurrences "extra_criteria" stored within al this Load object;
+ copying in place.
+
+ """
+ orig_query = context.compile_state.select_statement
+
+ applied = {}
+
+ ck = [None, None]
+
+ def process(opt):
+ if not opt._extra_criteria:
+ return
+
+ if ck[0] is None:
+ ck[:] = (
+ orig_query._generate_cache_key(),
+ context.query._generate_cache_key(),
+ )
+ k1, k2 = ck
+
+ opt._extra_criteria = tuple(
+ k2._apply_params_to_element(k1, crit)
+ for crit in opt._extra_criteria
+ )
+
+ return self._deep_clone(applied, process)
+
+ def _deep_clone(self, applied, process):
+ if self in applied:
+ return applied[self]
+
+ cloned = self._generate()
+
+ applied[self] = cloned
+
+ cloned.strategy = self.strategy
+
+ assert cloned.propagate_to_loaders == self.propagate_to_loaders
+ assert cloned.is_class_strategy == self.is_class_strategy
+ assert cloned.is_opts_only == self.is_opts_only
+
+ if self.context:
+ cloned.context = util.OrderedDict(
+ [
+ (
+ key,
+ value._deep_clone(applied, process)
+ if isinstance(value, Load)
+ else value,
+ )
+ for key, value in self.context.items()
+ ]
+ )
+
+ cloned.local_opts.update(self.local_opts)
+
+ process(cloned)
+
+ return cloned
+
+ @property
+ def _context_cache_key(self):
+ serialized = []
+ if self.context is None:
+ return []
+ for (key, loader_path), obj in self.context.items():
+ if key != "loader":
+ continue
+ serialized.append(loader_path + (obj,))
+ return serialized
+
+ def _generate(self):
+ cloned = super(Load, self)._generate()
+ cloned.local_opts = {}
+ return cloned
+
+ is_opts_only = False
+ is_class_strategy = False
+ strategy = None
+ propagate_to_loaders = False
+ _of_type = None
+ _extra_criteria = ()
+
+ def process_compile_state_replaced_entities(
+ self, compile_state, mapper_entities
+ ):
+ if not compile_state.compile_options._enable_eagerloads:
+ return
+
+ # process is being run here so that the options given are validated
+ # against what the lead entities were, as well as to accommodate
+ # for the entities having been replaced with equivalents
+ self._process(
+ compile_state,
+ mapper_entities,
+ not bool(compile_state.current_path),
+ )
+
+ def process_compile_state(self, compile_state):
+ if not compile_state.compile_options._enable_eagerloads:
+ return
+
+ self._process(
+ compile_state,
+ compile_state._lead_mapper_entities,
+ not bool(compile_state.current_path)
+ and not compile_state.compile_options._for_refresh_state,
+ )
+
+ def _process(self, compile_state, mapper_entities, raiseerr):
+ is_refresh = compile_state.compile_options._for_refresh_state
+ current_path = compile_state.current_path
+ if current_path:
+ for (token, start_path), loader in self.context.items():
+ if is_refresh and not loader.propagate_to_loaders:
+ continue
+ chopped_start_path = self._chop_path(start_path, current_path)
+ if chopped_start_path is not None:
+ compile_state.attributes[
+ (token, chopped_start_path)
+ ] = loader
+ else:
+ compile_state.attributes.update(self.context)
+
+ def _generate_path(
+ self,
+ path,
+ attr,
+ for_strategy,
+ wildcard_key,
+ raiseerr=True,
+ polymorphic_entity_context=None,
+ ):
+ existing_of_type = self._of_type
+ self._of_type = None
+ if raiseerr and not path.has_entity:
+ if isinstance(path, TokenRegistry):
+ raise sa_exc.ArgumentError(
+ "Wildcard token cannot be followed by another entity"
+ )
+ else:
+ raise sa_exc.ArgumentError(
+ "Mapped attribute '%s' does not "
+ "refer to a mapped entity" % (path.prop,)
+ )
+
+ if isinstance(attr, util.string_types):
+
+ default_token = attr.endswith(_DEFAULT_TOKEN)
+ attr_str_name = attr
+ if attr.endswith(_WILDCARD_TOKEN) or default_token:
+ if default_token:
+ self.propagate_to_loaders = False
+ if wildcard_key:
+ attr = "%s:%s" % (wildcard_key, attr)
+
+ # TODO: AliasedInsp inside the path for of_type is not
+ # working for a with_polymorphic entity because the
+ # relationship loaders don't render the with_poly into the
+ # path. See #4469 which will try to improve this
+ if existing_of_type and not existing_of_type.is_aliased_class:
+ path = path.parent[existing_of_type]
+ path = path.token(attr)
+ self.path = path
+ return path
+
+ if existing_of_type:
+ ent = inspect(existing_of_type)
+ else:
+ ent = path.entity
+
+ util.warn_deprecated_20(
+ "Using strings to indicate column or "
+ "relationship paths in loader options is deprecated "
+ "and will be removed in SQLAlchemy 2.0. Please use "
+ "the class-bound attribute directly.",
+ )
+ try:
+ # use getattr on the class to work around
+ # synonyms, hybrids, etc.
+ attr = getattr(ent.class_, attr)
+ except AttributeError as err:
+ if raiseerr:
+ util.raise_(
+ sa_exc.ArgumentError(
+ 'Can\'t find property named "%s" on '
+ "%s in this Query." % (attr, ent)
+ ),
+ replace_context=err,
+ )
+ else:
+ return None
+ else:
+ try:
+ attr = found_property = attr.property
+ except AttributeError as ae:
+ if not isinstance(attr, MapperProperty):
+ util.raise_(
+ sa_exc.ArgumentError(
+ 'Expected attribute "%s" on %s to be a '
+ "mapped attribute; "
+ "instead got %s object."
+ % (attr_str_name, ent, type(attr))
+ ),
+ replace_context=ae,
+ )
+ else:
+ raise
+
+ path = path[attr]
+ else:
+ insp = inspect(attr)
+
+ if insp.is_mapper or insp.is_aliased_class:
+ # TODO: this does not appear to be a valid codepath. "attr"
+ # would never be a mapper. This block is present in 1.2
+ # as well, however it does not seem to be accessed in any tests.
+ if not orm_util._entity_corresponds_to_use_path_impl(
+ attr.parent, path[-1]
+ ):
+ if raiseerr:
+ raise sa_exc.ArgumentError(
+ "Attribute '%s' does not "
+ "link from element '%s'" % (attr, path.entity)
+ )
+ else:
+ return None
+ elif insp.is_property:
+ prop = found_property = attr
+ path = path[prop]
+ elif insp.is_attribute:
+ prop = found_property = attr.property
+
+ if not orm_util._entity_corresponds_to_use_path_impl(
+ attr.parent, path[-1]
+ ):
+ if raiseerr:
+ raise sa_exc.ArgumentError(
+ 'Attribute "%s" does not '
+ 'link from element "%s".%s'
+ % (
+ attr,
+ path.entity,
+ (
+ " Did you mean to use "
+ "%s.of_type(%s)?"
+ % (path[-2], attr.class_.__name__)
+ if len(path) > 1
+ and path.entity.is_mapper
+ and attr.parent.is_aliased_class
+ else ""
+ ),
+ )
+ )
+ else:
+ return None
+
+ if attr._extra_criteria and not self._extra_criteria:
+ # in most cases, the process that brings us here will have
+ # already established _extra_criteria. however if not,
+ # and it's present on the attribute, then use that.
+ self._extra_criteria = attr._extra_criteria
+
+ if getattr(attr, "_of_type", None):
+ ac = attr._of_type
+ ext_info = of_type_info = inspect(ac)
+
+ if polymorphic_entity_context is None:
+ polymorphic_entity_context = self.context
+
+ existing = path.entity_path[prop].get(
+ polymorphic_entity_context, "path_with_polymorphic"
+ )
+
+ if not ext_info.is_aliased_class:
+ ac = orm_util.with_polymorphic(
+ ext_info.mapper.base_mapper,
+ ext_info.mapper,
+ aliased=True,
+ _use_mapper_path=True,
+ _existing_alias=inspect(existing)
+ if existing is not None
+ else None,
+ )
+
+ ext_info = inspect(ac)
+
+ path.entity_path[prop].set(
+ polymorphic_entity_context, "path_with_polymorphic", ac
+ )
+
+ path = path[prop][ext_info]
+
+ self._of_type = of_type_info
+
+ else:
+ path = path[prop]
+
+ if for_strategy is not None:
+ found_property._get_strategy(for_strategy)
+ if path.has_entity:
+ path = path.entity_path
+ self.path = path
+ return path
+
+ def __str__(self):
+ return "Load(strategy=%r)" % (self.strategy,)
+
+ def _coerce_strat(self, strategy):
+ if strategy is not None:
+ strategy = tuple(sorted(strategy.items()))
+ return strategy
+
+ def _apply_to_parent(self, parent, applied, bound):
+ raise NotImplementedError(
+ "Only 'unbound' loader options may be used with the "
+ "Load.options() method"
+ )
+
+ @_generative
+ def options(self, *opts):
+ r"""Apply a series of options as sub-options to this
+ :class:`_orm.Load`
+ object.
+
+ E.g.::
+
+ query = session.query(Author)
+ query = query.options(
+ joinedload(Author.book).options(
+ load_only(Book.summary, Book.excerpt),
+ joinedload(Book.citations).options(
+ joinedload(Citation.author)
+ )
+ )
+ )
+
+ :param \*opts: A series of loader option objects (ultimately
+ :class:`_orm.Load` objects) which should be applied to the path
+ specified by this :class:`_orm.Load` object.
+
+ .. versionadded:: 1.3.6
+
+ .. seealso::
+
+ :func:`.defaultload`
+
+ :ref:`relationship_loader_options`
+
+ :ref:`deferred_loading_w_multiple`
+
+ """
+ apply_cache = {}
+ bound = not isinstance(self, _UnboundLoad)
+ if bound:
+ raise NotImplementedError(
+ "The options() method is currently only supported "
+ "for 'unbound' loader options"
+ )
+ for opt in opts:
+ opt._apply_to_parent(self, apply_cache, bound)
+
+ @_generative
+ def set_relationship_strategy(
+ self, attr, strategy, propagate_to_loaders=True
+ ):
+ strategy = self._coerce_strat(strategy)
+ self.propagate_to_loaders = propagate_to_loaders
+ cloned = self._clone_for_bind_strategy(attr, strategy, "relationship")
+ self.path = cloned.path
+ self._of_type = cloned._of_type
+ self._extra_criteria = cloned._extra_criteria
+ cloned.is_class_strategy = self.is_class_strategy = False
+ self.propagate_to_loaders = cloned.propagate_to_loaders
+
+ @_generative
+ def set_column_strategy(self, attrs, strategy, opts=None, opts_only=False):
+ strategy = self._coerce_strat(strategy)
+ self.is_class_strategy = False
+ for attr in attrs:
+ cloned = self._clone_for_bind_strategy(
+ attr, strategy, "column", opts_only=opts_only, opts=opts
+ )
+ cloned.propagate_to_loaders = True
+
+ @_generative
+ def set_generic_strategy(self, attrs, strategy):
+ strategy = self._coerce_strat(strategy)
+ for attr in attrs:
+ cloned = self._clone_for_bind_strategy(attr, strategy, None)
+ cloned.propagate_to_loaders = True
+
+ @_generative
+ def set_class_strategy(self, strategy, opts):
+ strategy = self._coerce_strat(strategy)
+ cloned = self._clone_for_bind_strategy(None, strategy, None)
+ cloned.is_class_strategy = True
+ cloned.propagate_to_loaders = True
+ cloned.local_opts.update(opts)
+
+ def _clone_for_bind_strategy(
+ self, attr, strategy, wildcard_key, opts_only=False, opts=None
+ ):
+ """Create an anonymous clone of the Load/_UnboundLoad that is suitable
+ to be placed in the context / _to_bind collection of this Load
+ object. The clone will then lose references to context/_to_bind
+ in order to not create reference cycles.
+
+ """
+ cloned = self._generate()
+ cloned._generate_path(self.path, attr, strategy, wildcard_key)
+ cloned.strategy = strategy
+
+ cloned.local_opts = self.local_opts
+ if opts:
+ cloned.local_opts.update(opts)
+ if opts_only:
+ cloned.is_opts_only = True
+
+ if strategy or cloned.is_opts_only:
+ cloned._set_path_strategy()
+ return cloned
+
+ def _set_for_path(self, context, path, replace=True, merge_opts=False):
+ if merge_opts or not replace:
+ existing = path.get(context, "loader")
+ if existing:
+ if merge_opts:
+ existing.local_opts.update(self.local_opts)
+ existing._extra_criteria += self._extra_criteria
+ else:
+ path.set(context, "loader", self)
+ else:
+ existing = path.get(context, "loader")
+ path.set(context, "loader", self)
+ if existing and existing.is_opts_only:
+ self.local_opts.update(existing.local_opts)
+ existing._extra_criteria += self._extra_criteria
+
+ def _set_path_strategy(self):
+ if not self.is_class_strategy and self.path.has_entity:
+ effective_path = self.path.parent
+ else:
+ effective_path = self.path
+
+ if effective_path.is_token:
+ for path in effective_path.generate_for_superclasses():
+ self._set_for_path(
+ self.context,
+ path,
+ replace=True,
+ merge_opts=self.is_opts_only,
+ )
+ else:
+ self._set_for_path(
+ self.context,
+ effective_path,
+ replace=True,
+ merge_opts=self.is_opts_only,
+ )
+
+ # remove cycles; _set_path_strategy is always invoked on an
+ # anonymous clone of the Load / UnboundLoad object since #5056
+ self.context = None
+
+ def __getstate__(self):
+ d = self.__dict__.copy()
+
+ # can't pickle this right now; warning is raised by strategies
+ d["_extra_criteria"] = ()
+
+ if d["context"] is not None:
+ d["context"] = PathRegistry.serialize_context_dict(
+ d["context"], ("loader",)
+ )
+ d["path"] = self.path.serialize()
+ return d
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ self.path = PathRegistry.deserialize(self.path)
+ if self.context is not None:
+ self.context = PathRegistry.deserialize_context_dict(self.context)
+
+ def _chop_path(self, to_chop, path):
+ i = -1
+
+ for i, (c_token, p_token) in enumerate(zip(to_chop, path.path)):
+ if isinstance(c_token, util.string_types):
+ # TODO: this is approximated from the _UnboundLoad
+ # version and probably has issues, not fully covered.
+
+ if i == 0 and c_token.endswith(":" + _DEFAULT_TOKEN):
+ return to_chop
+ elif (
+ c_token != "relationship:%s" % (_WILDCARD_TOKEN,)
+ and c_token != p_token.key
+ ):
+ return None
+
+ if c_token is p_token:
+ continue
+ elif (
+ isinstance(c_token, InspectionAttr)
+ and c_token.is_mapper
+ and p_token.is_mapper
+ and c_token.isa(p_token)
+ ):
+ continue
+ else:
+ return None
+ return to_chop[i + 1 :]
+
+
+class _UnboundLoad(Load):
+ """Represent a loader option that isn't tied to a root entity.
+
+ The loader option will produce an entity-linked :class:`_orm.Load`
+ object when it is passed :meth:`_query.Query.options`.
+
+ This provides compatibility with the traditional system
+ of freestanding options, e.g. ``joinedload('x.y.z')``.
+
+ """
+
+ def __init__(self):
+ self.path = ()
+ self._to_bind = []
+ self.local_opts = {}
+ self._extra_criteria = ()
+
+ def _gen_cache_key(self, anon_map, bindparams, _unbound_option_seen=None):
+ """Inlined gen_cache_key
+
+ Original traversal is::
+
+
+ _cache_key_traversal = [
+ ("path", visitors.ExtendedInternalTraversal.dp_multi_list),
+ ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ (
+ "_to_bind",
+ visitors.ExtendedInternalTraversal.dp_has_cache_key_list,
+ ),
+ (
+ "_extra_criteria",
+ visitors.InternalTraversal.dp_clauseelement_list),
+ (
+ "local_opts",
+ visitors.ExtendedInternalTraversal.dp_string_multi_dict,
+ ),
+ ]
+
+ The inlining is so that the "_to_bind" list can be flattened to not
+ repeat the same UnboundLoad options over and over again.
+
+ See #6869
+
+ """
+
+ idself = id(self)
+ cls = self.__class__
+
+ if idself in anon_map:
+ return (anon_map[idself], cls)
+ else:
+ id_ = anon_map[idself]
+
+ vis = traversals._cache_key_traversal_visitor
+
+ seen = _unbound_option_seen
+ if seen is None:
+ seen = set()
+
+ return (
+ (id_, cls)
+ + vis.visit_multi_list(
+ "path", self.path, self, anon_map, bindparams
+ )
+ + ("strategy", self.strategy)
+ + (
+ (
+ "_to_bind",
+ tuple(
+ elem._gen_cache_key(
+ anon_map, bindparams, _unbound_option_seen=seen
+ )
+ for elem in self._to_bind
+ if elem not in seen and not seen.add(elem)
+ ),
+ )
+ if self._to_bind
+ else ()
+ )
+ + (
+ (
+ "_extra_criteria",
+ tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in self._extra_criteria
+ ),
+ )
+ if self._extra_criteria
+ else ()
+ )
+ + (
+ vis.visit_string_multi_dict(
+ "local_opts", self.local_opts, self, anon_map, bindparams
+ )
+ if self.local_opts
+ else ()
+ )
+ )
+
+ _is_chain_link = False
+
+ def _set_path_strategy(self):
+ self._to_bind.append(self)
+
+ # remove cycles; _set_path_strategy is always invoked on an
+ # anonymous clone of the Load / UnboundLoad object since #5056
+ self._to_bind = None
+
+ def _deep_clone(self, applied, process):
+ if self in applied:
+ return applied[self]
+
+ cloned = self._generate()
+
+ applied[self] = cloned
+
+ cloned.strategy = self.strategy
+
+ assert cloned.propagate_to_loaders == self.propagate_to_loaders
+ assert cloned.is_class_strategy == self.is_class_strategy
+ assert cloned.is_opts_only == self.is_opts_only
+
+ cloned._to_bind = [
+ elem._deep_clone(applied, process) for elem in self._to_bind or ()
+ ]
+
+ cloned.local_opts.update(self.local_opts)
+
+ process(cloned)
+
+ return cloned
+
+ def _apply_to_parent(self, parent, applied, bound, to_bind=None):
+ if self in applied:
+ return applied[self]
+
+ if to_bind is None:
+ to_bind = self._to_bind
+
+ cloned = self._generate()
+
+ applied[self] = cloned
+
+ cloned.strategy = self.strategy
+ if self.path:
+ attr = self.path[-1]
+ if isinstance(attr, util.string_types) and attr.endswith(
+ _DEFAULT_TOKEN
+ ):
+ attr = attr.split(":")[0] + ":" + _WILDCARD_TOKEN
+ cloned._generate_path(
+ parent.path + self.path[0:-1], attr, self.strategy, None
+ )
+
+ # these assertions can go away once the "sub options" API is
+ # mature
+ assert cloned.propagate_to_loaders == self.propagate_to_loaders
+ assert cloned.is_class_strategy == self.is_class_strategy
+ assert cloned.is_opts_only == self.is_opts_only
+
+ uniq = set()
+
+ cloned._to_bind = parent._to_bind
+
+ cloned._to_bind[:] = [
+ elem
+ for elem in cloned._to_bind
+ if elem not in uniq and not uniq.add(elem)
+ ] + [
+ elem._apply_to_parent(parent, applied, bound, to_bind)
+ for elem in to_bind
+ if elem not in uniq and not uniq.add(elem)
+ ]
+
+ cloned.local_opts.update(self.local_opts)
+
+ return cloned
+
+ def _generate_path(self, path, attr, for_strategy, wildcard_key):
+ if (
+ wildcard_key
+ and isinstance(attr, util.string_types)
+ and attr in (_WILDCARD_TOKEN, _DEFAULT_TOKEN)
+ ):
+ if attr == _DEFAULT_TOKEN:
+ self.propagate_to_loaders = False
+ attr = "%s:%s" % (wildcard_key, attr)
+ if path and _is_mapped_class(path[-1]) and not self.is_class_strategy:
+ path = path[0:-1]
+ if attr:
+ path = path + (attr,)
+ self.path = path
+ self._extra_criteria = getattr(attr, "_extra_criteria", ())
+
+ return path
+
+ def __getstate__(self):
+ d = self.__dict__.copy()
+
+ # can't pickle this right now; warning is raised by strategies
+ d["_extra_criteria"] = ()
+
+ d["path"] = self._serialize_path(self.path, filter_aliased_class=True)
+ return d
+
+ def __setstate__(self, state):
+ ret = []
+ for key in state["path"]:
+ if isinstance(key, tuple):
+ if len(key) == 2:
+ # support legacy
+ cls, propkey = key
+ of_type = None
+ else:
+ cls, propkey, of_type = key
+ prop = getattr(cls, propkey)
+ if of_type:
+ prop = prop.of_type(of_type)
+ ret.append(prop)
+ else:
+ ret.append(key)
+ state["path"] = tuple(ret)
+ self.__dict__ = state
+
+ def _process(self, compile_state, mapper_entities, raiseerr):
+ dedupes = compile_state.attributes["_unbound_load_dedupes"]
+ is_refresh = compile_state.compile_options._for_refresh_state
+ for val in self._to_bind:
+ if val not in dedupes:
+ dedupes.add(val)
+ if is_refresh and not val.propagate_to_loaders:
+ continue
+ val._bind_loader(
+ [ent.entity_zero for ent in mapper_entities],
+ compile_state.current_path,
+ compile_state.attributes,
+ raiseerr,
+ )
+
+ @classmethod
+ def _from_keys(cls, meth, keys, chained, kw):
+ opt = _UnboundLoad()
+
+ def _split_key(key):
+ if isinstance(key, util.string_types):
+ # coerce fooload('*') into "default loader strategy"
+ if key == _WILDCARD_TOKEN:
+ return (_DEFAULT_TOKEN,)
+ # coerce fooload(".*") into "wildcard on default entity"
+ elif key.startswith("." + _WILDCARD_TOKEN):
+ util.warn_deprecated(
+ "The undocumented `.{WILDCARD}` format is deprecated "
+ "and will be removed in a future version as it is "
+ "believed to be unused. "
+ "If you have been using this functionality, please "
+ "comment on Issue #4390 on the SQLAlchemy project "
+ "tracker.",
+ version="1.4",
+ )
+ key = key[1:]
+ return key.split(".")
+ else:
+ return (key,)
+
+ all_tokens = [token for key in keys for token in _split_key(key)]
+
+ for token in all_tokens[0:-1]:
+ # set _is_chain_link first so that clones of the
+ # object also inherit this flag
+ opt._is_chain_link = True
+ if chained:
+ opt = meth(opt, token, **kw)
+ else:
+ opt = opt.defaultload(token)
+
+ opt = meth(opt, all_tokens[-1], **kw)
+ opt._is_chain_link = False
+ return opt
+
+ def _chop_path(self, to_chop, path):
+ i = -1
+ for i, (c_token, (p_entity, p_prop)) in enumerate(
+ zip(to_chop, path.pairs())
+ ):
+ if isinstance(c_token, util.string_types):
+ if i == 0 and c_token.endswith(":" + _DEFAULT_TOKEN):
+ return to_chop
+ elif (
+ c_token != "relationship:%s" % (_WILDCARD_TOKEN,)
+ and c_token != p_prop.key
+ ):
+ return None
+ elif isinstance(c_token, PropComparator):
+ if c_token.property is not p_prop or (
+ c_token._parententity is not p_entity
+ and (
+ not c_token._parententity.is_mapper
+ or not c_token._parententity.isa(p_entity)
+ )
+ ):
+ return None
+ else:
+ i += 1
+
+ return to_chop[i:]
+
+ def _serialize_path(self, path, filter_aliased_class=False):
+ ret = []
+ for token in path:
+ if isinstance(token, QueryableAttribute):
+ if (
+ filter_aliased_class
+ and token._of_type
+ and inspect(token._of_type).is_aliased_class
+ ):
+ ret.append((token._parentmapper.class_, token.key, None))
+ else:
+ ret.append(
+ (
+ token._parentmapper.class_,
+ token.key,
+ token._of_type.entity if token._of_type else None,
+ )
+ )
+ elif isinstance(token, PropComparator):
+ ret.append((token._parentmapper.class_, token.key, None))
+ else:
+ ret.append(token)
+ return ret
+
+ def _bind_loader(self, entities, current_path, context, raiseerr):
+ """Convert from an _UnboundLoad() object into a Load() object.
+
+ The _UnboundLoad() uses an informal "path" and does not necessarily
+ refer to a lead entity as it may use string tokens. The Load()
+ OTOH refers to a complete path. This method reconciles from a
+ given Query into a Load.
+
+ Example::
+
+
+ query = session.query(User).options(
+ joinedload("orders").joinedload("items"))
+
+ The above options will be an _UnboundLoad object along the lines
+ of (note this is not the exact API of _UnboundLoad)::
+
+ _UnboundLoad(
+ _to_bind=[
+ _UnboundLoad(["orders"], {"lazy": "joined"}),
+ _UnboundLoad(["orders", "items"], {"lazy": "joined"}),
+ ]
+ )
+
+ After this method, we get something more like this (again this is
+ not exact API)::
+
+ Load(
+ User,
+ (User, User.orders.property))
+ Load(
+ User,
+ (User, User.orders.property, Order, Order.items.property))
+
+ """
+
+ start_path = self.path
+
+ if self.is_class_strategy and current_path:
+ start_path += (entities[0],)
+
+ # _current_path implies we're in a
+ # secondary load with an existing path
+
+ if current_path:
+ start_path = self._chop_path(start_path, current_path)
+
+ if not start_path:
+ return None
+
+ # look at the first token and try to locate within the Query
+ # what entity we are referring towards.
+ token = start_path[0]
+
+ if isinstance(token, util.string_types):
+ entity = self._find_entity_basestring(entities, token, raiseerr)
+ elif isinstance(token, PropComparator):
+ prop = token.property
+ entity = self._find_entity_prop_comparator(
+ entities, prop, token._parententity, raiseerr
+ )
+ elif self.is_class_strategy and _is_mapped_class(token):
+ entity = inspect(token)
+ if entity not in entities:
+ entity = None
+ else:
+ raise sa_exc.ArgumentError(
+ "mapper option expects string key or list of attributes"
+ )
+
+ if not entity:
+ return
+
+ path_element = entity
+
+ # transfer our entity-less state into a Load() object
+ # with a real entity path. Start with the lead entity
+ # we just located, then go through the rest of our path
+ # tokens and populate into the Load().
+ loader = Load(path_element)
+
+ if context is None:
+ context = loader.context
+
+ loader.strategy = self.strategy
+ loader.is_opts_only = self.is_opts_only
+ loader.is_class_strategy = self.is_class_strategy
+ loader._extra_criteria = self._extra_criteria
+
+ path = loader.path
+
+ if not loader.is_class_strategy:
+ for idx, token in enumerate(start_path):
+ if not loader._generate_path(
+ loader.path,
+ token,
+ self.strategy if idx == len(start_path) - 1 else None,
+ None,
+ raiseerr,
+ polymorphic_entity_context=context,
+ ):
+ return
+
+ loader.local_opts.update(self.local_opts)
+
+ if not loader.is_class_strategy and loader.path.has_entity:
+ effective_path = loader.path.parent
+ else:
+ effective_path = loader.path
+
+ # prioritize "first class" options over those
+ # that were "links in the chain", e.g. "x" and "y" in
+ # someload("x.y.z") versus someload("x") / someload("x.y")
+
+ if effective_path.is_token:
+ for path in effective_path.generate_for_superclasses():
+ loader._set_for_path(
+ context,
+ path,
+ replace=not self._is_chain_link,
+ merge_opts=self.is_opts_only,
+ )
+ else:
+ loader._set_for_path(
+ context,
+ effective_path,
+ replace=not self._is_chain_link,
+ merge_opts=self.is_opts_only,
+ )
+
+ return loader
+
+ def _find_entity_prop_comparator(self, entities, prop, mapper, raiseerr):
+ if _is_aliased_class(mapper):
+ searchfor = mapper
+ else:
+ searchfor = _class_to_mapper(mapper)
+ for ent in entities:
+ if orm_util._entity_corresponds_to(ent, searchfor):
+ return ent
+ else:
+ if raiseerr:
+ if not list(entities):
+ raise sa_exc.ArgumentError(
+ "Query has only expression-based entities, "
+ 'which do not apply to %s "%s"'
+ % (util.clsname_as_plain_name(type(prop)), prop)
+ )
+ else:
+ raise sa_exc.ArgumentError(
+ 'Mapped attribute "%s" does not apply to any of the '
+ "root entities in this query, e.g. %s. Please "
+ "specify the full path "
+ "from one of the root entities to the target "
+ "attribute. "
+ % (prop, ", ".join(str(x) for x in entities))
+ )
+ else:
+ return None
+
+ def _find_entity_basestring(self, entities, token, raiseerr):
+ if token.endswith(":" + _WILDCARD_TOKEN):
+ if len(list(entities)) != 1:
+ if raiseerr:
+ raise sa_exc.ArgumentError(
+ "Can't apply wildcard ('*') or load_only() "
+ "loader option to multiple entities %s. Specify "
+ "loader options for each entity individually, such "
+ "as %s."
+ % (
+ ", ".join(str(ent) for ent in entities),
+ ", ".join(
+ "Load(%s).some_option('*')" % ent
+ for ent in entities
+ ),
+ )
+ )
+ elif token.endswith(_DEFAULT_TOKEN):
+ raiseerr = False
+
+ for ent in entities:
+ # return only the first _MapperEntity when searching
+ # based on string prop name. Ideally object
+ # attributes are used to specify more exactly.
+ return ent
+ else:
+ if raiseerr:
+ raise sa_exc.ArgumentError(
+ "Query has only expression-based entities - "
+ 'can\'t find property named "%s".' % (token,)
+ )
+ else:
+ return None
+
+
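+# Decorator used below: it installs each decorated function as a method on
+# Load, while _add_unbound_fn / _add_unbound_all_fn register the module-level
+# (standalone) variants that produce _UnboundLoad objects.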
+class loader_option(object):
+ def __init__(self):
+ pass
+
+ def __call__(self, fn):
+ self.name = name = fn.__name__
+ self.fn = fn
+ if hasattr(Load, name):
+ raise TypeError("Load class already has a %s method." % (name))
+ setattr(Load, name, fn)
+
+ return self
+
+ def _add_unbound_fn(self, fn):
+ self._unbound_fn = fn
+ fn_doc = self.fn.__doc__
+ self.fn.__doc__ = """Produce a new :class:`_orm.Load` object with the
+:func:`_orm.%(name)s` option applied.
+
+See :func:`_orm.%(name)s` for usage examples.
+
+""" % {
+ "name": self.name
+ }
+
+ fn.__doc__ = fn_doc
+ return self
+
+ def _add_unbound_all_fn(self, fn):
+ fn.__doc__ = """Produce a standalone "all" option for
+:func:`_orm.%(name)s`.
+
+.. deprecated:: 0.9
+
+ The :func:`_orm.%(name)s_all` function is deprecated, and will be removed
+ in a future release. Please use method chaining with
+ :func:`_orm.%(name)s` instead, as in::
+
+ session.query(MyClass).options(
+ %(name)s("someattribute").%(name)s("anotherattribute")
+ )
+
+""" % {
+ "name": self.name
+ }
+ fn = util.deprecated(
+ # This is used by `baked_lazyload_all`, which was only deprecated in
+ # version 1.2, so this must stick around until that is removed
+ "0.9",
+ "The :func:`.%(name)s_all` function is deprecated, and will be "
+ "removed in a future release. Please use method chaining with "
+ ":func:`.%(name)s` instead" % {"name": self.name},
+ add_deprecation_to_docstring=False,
+ )(fn)
+
+ self._unbound_all_fn = fn
+ return self
+
+
+@loader_option()
+def contains_eager(loadopt, attr, alias=None):
+ r"""Indicate that the given attribute should be eagerly loaded from
+ columns stated manually in the query.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ The option is used in conjunction with an explicit join that loads
+ the desired rows, i.e.::
+
+ sess.query(Order).\
+ join(Order.user).\
+ options(contains_eager(Order.user))
+
+ The above query would join from the ``Order`` entity to its related
+ ``User`` entity, and the returned ``Order`` objects would have the
+ ``Order.user`` attribute pre-populated.
+
+ It may also be used for customizing the entries in an eagerly loaded
+ collection; queries will normally want to use the
+ :meth:`_query.Query.populate_existing` method assuming the primary
+ collection of parent objects may already have been loaded::
+
+ sess.query(User).\
+ join(User.addresses).\
+ filter(Address.email_address.like('%@aol.com')).\
+ options(contains_eager(User.addresses)).\
+ populate_existing()
+
+ See the section :ref:`contains_eager` for complete usage details.
+
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ :ref:`contains_eager`
+
+ """
+ if alias is not None:
+ if not isinstance(alias, str):
+ info = inspect(alias)
+ alias = info.selectable
+
+ else:
+ util.warn_deprecated(
+ "Passing a string name for the 'alias' argument to "
+ "'contains_eager()` is deprecated, and will not work in a "
+ "future release. Please use a sqlalchemy.alias() or "
+ "sqlalchemy.orm.aliased() construct.",
+ version="1.4",
+ )
+
+ elif getattr(attr, "_of_type", None):
+ ot = inspect(attr._of_type)
+ alias = ot.selectable
+
+ cloned = loadopt.set_relationship_strategy(
+ attr, {"lazy": "joined"}, propagate_to_loaders=False
+ )
+ cloned.local_opts["eager_from_alias"] = alias
+ return cloned
+
+
+@contains_eager._add_unbound_fn
+def contains_eager(*keys, **kw):
+ return _UnboundLoad()._from_keys(
+ _UnboundLoad.contains_eager, keys, True, kw
+ )
+
+
+@loader_option()
+def load_only(loadopt, *attrs):
+ """Indicate that for a particular entity, only the given list
+ of column-based attribute names should be loaded; all others will be
+ deferred.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ Example - given a class ``User``, load only the ``name`` and ``fullname``
+ attributes::
+
+ session.query(User).options(load_only(User.name, User.fullname))
+
+ Example - given a relationship ``User.addresses -> Address``, specify
+ subquery loading for the ``User.addresses`` collection, but on each
+ ``Address`` object load only the ``email_address`` attribute::
+
+ session.query(User).options(
+ subqueryload(User.addresses).load_only(Address.email_address)
+ )
+
+ For a :class:`_query.Query` that has multiple entities,
+ the lead entity can be
+ specifically referred to using the :class:`_orm.Load` constructor::
+
+ session.query(User, Address).join(User.addresses).options(
+ Load(User).load_only(User.name, User.fullname),
+ Load(Address).load_only(Address.email_address)
+ )
+
+ .. note:: This method will still load a :class:`_schema.Column` even
+ if the column property is defined with ``deferred=True``
+ for the :func:`.column_property` function.
+
+ .. versionadded:: 0.9.0
+
+ """
+ cloned = loadopt.set_column_strategy(
+ attrs, {"deferred": False, "instrument": True}
+ )
+ cloned.set_column_strategy(
+ "*", {"deferred": True, "instrument": True}, {"undefer_pks": True}
+ )
+ return cloned
+
+
+@load_only._add_unbound_fn
+def load_only(*attrs):
+ return _UnboundLoad().load_only(*attrs)
+
+
+@loader_option()
+def joinedload(loadopt, attr, innerjoin=None):
+ """Indicate that the given attribute should be loaded using joined
+ eager loading.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ examples::
+
+ # joined-load the "orders" collection on "User"
+ query(User).options(joinedload(User.orders))
+
+ # joined-load Order.items and then Item.keywords
+ query(Order).options(
+ joinedload(Order.items).joinedload(Item.keywords))
+
+ # lazily load Order.items, but when Items are loaded,
+ # joined-load the keywords collection
+ query(Order).options(
+ lazyload(Order.items).joinedload(Item.keywords))
+
+ :param innerjoin: if ``True``, indicates that the joined eager load should
+ use an inner join instead of the default of left outer join::
+
+ query(Order).options(joinedload(Order.user, innerjoin=True))
+
+ In order to chain multiple eager joins together where some may be
+ OUTER and others INNER, right-nested joins are used to link them::
+
+ query(A).options(
+ joinedload(A.bs, innerjoin=False).
+ joinedload(B.cs, innerjoin=True)
+ )
+
+ The above query, linking A.bs via "outer" join and B.cs via "inner" join
+ would render the joins as "a LEFT OUTER JOIN (b JOIN c)". When using
+ older versions of SQLite (< 3.7.16), this form of JOIN is translated to
+ use full subqueries as this syntax is otherwise not directly supported.
+
+ The ``innerjoin`` flag can also be stated with the term ``"unnested"``.
+ This indicates that an INNER JOIN should be used, *unless* the join
+ is linked to a LEFT OUTER JOIN to the left, in which case it
+ will render as LEFT OUTER JOIN. For example, supposing ``A.bs``
+ is an outerjoin::
+
+ query(A).options(
+ joinedload(A.bs).
+ joinedload(B.cs, innerjoin="unnested")
+ )
+
+ The above join will render as "a LEFT OUTER JOIN b LEFT OUTER JOIN c",
+ rather than as "a LEFT OUTER JOIN (b JOIN c)".
+
+ .. note:: The "unnested" flag does **not** affect the JOIN rendered
+ from a many-to-many association table, e.g. a table configured
+ as :paramref:`_orm.relationship.secondary`, to the target table; for
+ correctness of results, these joins are always INNER and are
+ therefore right-nested if linked to an OUTER join.
+
+ .. versionchanged:: 1.0.0 ``innerjoin=True`` now implies
+ ``innerjoin="nested"``, whereas in 0.9 it implied
+ ``innerjoin="unnested"``. In order to achieve the pre-1.0 "unnested"
+ inner join behavior, use the value ``innerjoin="unnested"``.
+ See :ref:`migration_3008`.
+
+ .. note::
+
+ The joins produced by :func:`_orm.joinedload` are **anonymously
+ aliased**. The criteria by which the join proceeds cannot be
+ modified, nor can the :class:`_query.Query`
+ refer to these joins in any way,
+ including ordering. See :ref:`zen_of_eager_loading` for further
+ detail.
+
+ To produce a specific SQL JOIN which is explicitly available, use
+ :meth:`_query.Query.join`.
+ To combine explicit JOINs with eager loading
+ of collections, use :func:`_orm.contains_eager`; see
+ :ref:`contains_eager`.
+
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ :ref:`joined_eager_loading`
+
+ """
+ loader = loadopt.set_relationship_strategy(attr, {"lazy": "joined"})
+ if innerjoin is not None:
+ loader.local_opts["innerjoin"] = innerjoin
+ return loader
+
+
+@joinedload._add_unbound_fn
+def joinedload(*keys, **kw):
+ return _UnboundLoad._from_keys(_UnboundLoad.joinedload, keys, False, kw)
+
+
+@loader_option()
+def subqueryload(loadopt, attr):
+ """Indicate that the given attribute should be loaded using
+ subquery eager loading.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ examples::
+
+ # subquery-load the "orders" collection on "User"
+ query(User).options(subqueryload(User.orders))
+
+ # subquery-load Order.items and then Item.keywords
+ query(Order).options(
+ subqueryload(Order.items).subqueryload(Item.keywords))
+
+ # lazily load Order.items, but when Items are loaded,
+ # subquery-load the keywords collection
+ query(Order).options(
+ lazyload(Order.items).subqueryload(Item.keywords))
+
+
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ :ref:`subquery_eager_loading`
+
+ """
+ return loadopt.set_relationship_strategy(attr, {"lazy": "subquery"})
+
+
+@subqueryload._add_unbound_fn
+def subqueryload(*keys):
+ return _UnboundLoad._from_keys(_UnboundLoad.subqueryload, keys, False, {})
+
+
+@loader_option()
+def selectinload(loadopt, attr):
+ """Indicate that the given attribute should be loaded using
+ SELECT IN eager loading.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ examples::
+
+ # selectin-load the "orders" collection on "User"
+ query(User).options(selectinload(User.orders))
+
+ # selectin-load Order.items and then Item.keywords
+ query(Order).options(
+ selectinload(Order.items).selectinload(Item.keywords))
+
+ # lazily load Order.items, but when Items are loaded,
+ # selectin-load the keywords collection
+ query(Order).options(
+ lazyload(Order.items).selectinload(Item.keywords))
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ :ref:`selectin_eager_loading`
+
+ """
+ return loadopt.set_relationship_strategy(attr, {"lazy": "selectin"})
+
+
+@selectinload._add_unbound_fn
+def selectinload(*keys):
+ return _UnboundLoad._from_keys(_UnboundLoad.selectinload, keys, False, {})
+
+
+@loader_option()
+def lazyload(loadopt, attr):
+ """Indicate that the given attribute should be loaded using "lazy"
+ loading.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
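+ E.g., assuming a hypothetical ``User.addresses`` relationship, the
+ collection below is left unloaded until it is first accessed, at which
+ point a separate SELECT is emitted::
+
+ session.query(User).options(lazyload(User.addresses))
+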
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ :ref:`lazy_loading`
+
+ """
+ return loadopt.set_relationship_strategy(attr, {"lazy": "select"})
+
+
+@lazyload._add_unbound_fn
+def lazyload(*keys):
+ return _UnboundLoad._from_keys(_UnboundLoad.lazyload, keys, False, {})
+
+
+@loader_option()
+def immediateload(loadopt, attr):
+ """Indicate that the given attribute should be loaded using
+ an immediate load with a per-attribute SELECT statement.
+
+ The load is achieved using the "lazyloader" strategy and does not
+ fire off any additional eager loaders.
+
+ The :func:`.immediateload` option is superseded in general
+ by the :func:`.selectinload` option, which performs the same task
+ more efficiently by emitting a SELECT for all loaded objects.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
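+ E.g., assuming a hypothetical ``User.addresses`` relationship, each
+ ``User`` loaded by the query below immediately emits an additional
+ SELECT for its ``addresses`` collection as the object is loaded::
+
+ session.query(User).options(immediateload(User.addresses))
+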
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ :ref:`selectin_eager_loading`
+
+ """
+ loader = loadopt.set_relationship_strategy(attr, {"lazy": "immediate"})
+ return loader
+
+
+@immediateload._add_unbound_fn
+def immediateload(*keys):
+ return _UnboundLoad._from_keys(_UnboundLoad.immediateload, keys, False, {})
+
+
+@loader_option()
+def noload(loadopt, attr):
+ """Indicate that the given relationship attribute should remain unloaded.
+
+ The relationship attribute will return ``None`` when accessed without
+ producing any loading effect.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ :func:`_orm.noload` applies to :func:`_orm.relationship` attributes; for
+ column-based attributes, see :func:`_orm.defer`.
+
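+ E.g., assuming a hypothetical ``User.addresses`` relationship, the
+ collection below is never loaded and no SQL is emitted when it is
+ accessed::
+
+ session.query(User).options(noload(User.addresses))
+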
+ .. note:: Setting this loading strategy as the default strategy
+ for a relationship using the :paramref:`.orm.relationship.lazy`
+ parameter may cause issues with flushes, such as if a delete operation
+ needs to load related objects and instead ``None`` was returned.
+
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ """
+
+ return loadopt.set_relationship_strategy(attr, {"lazy": "noload"})
+
+
+@noload._add_unbound_fn
+def noload(*keys):
+ return _UnboundLoad._from_keys(_UnboundLoad.noload, keys, False, {})
+
+
+@loader_option()
+def raiseload(loadopt, attr, sql_only=False):
+ """Indicate that the given attribute should raise an error if accessed.
+
+ A relationship attribute configured with :func:`_orm.raiseload` will
+ raise an :exc:`~sqlalchemy.exc.InvalidRequestError` upon access. The
+ typical way this is useful is when an application is attempting to ensure
+ that all relationship attributes that are accessed in a particular context
+ would have been already loaded via eager loading. Instead of having
+ to read through SQL logs to ensure lazy loads aren't occurring, this
+ strategy will cause them to raise immediately.
+
+ :func:`_orm.raiseload` applies to :func:`_orm.relationship`
+ attributes only.
+ In order to apply raise-on-SQL behavior to a column-based attribute,
+ use the :paramref:`.orm.defer.raiseload` parameter on the :func:`.defer`
+ loader option.
+
+ :param sql_only: if True, raise only if the lazy load would emit SQL, but
+ not if it is only checking the identity map, or determining that the
+ related value should just be None due to missing keys. When False, the
+ strategy will raise for all varieties of relationship loading.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
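+ E.g., assuming a hypothetical ``User.addresses`` relationship, any
+ attempt to lazy load the collection below raises
+ :exc:`~sqlalchemy.exc.InvalidRequestError` instead of emitting SQL::
+
+ session.query(User).options(raiseload(User.addresses))
+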
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ :ref:`prevent_lazy_with_raiseload`
+
+ :ref:`deferred_raiseload`
+
+ """
+
+ return loadopt.set_relationship_strategy(
+ attr, {"lazy": "raise_on_sql" if sql_only else "raise"}
+ )
+
+
+@raiseload._add_unbound_fn
+def raiseload(*keys, **kw):
+ return _UnboundLoad._from_keys(_UnboundLoad.raiseload, keys, False, kw)
+
+
+@loader_option()
+def defaultload(loadopt, attr):
+ """Indicate an attribute should load using its default loader style.
+
+ This method is used to link to other loader options further into
+ a chain of attributes without altering the loader style of the links
+ along the chain. For example, to set joined eager loading for an
+ element of an element::
+
+ session.query(MyClass).options(
+ defaultload(MyClass.someattribute).
+ joinedload(MyOtherClass.someotherattribute)
+ )
+
+ :func:`.defaultload` is also useful for setting column-level options
+ on a related class, namely that of :func:`.defer` and :func:`.undefer`::
+
+ session.query(MyClass).options(
+ defaultload(MyClass.someattribute).
+ defer("some_column").
+ undefer("some_other_column")
+ )
+
+ .. seealso::
+
+ :meth:`_orm.Load.options` - allows for complex hierarchical
+ loader option structures with less verbosity than with individual
+ :func:`.defaultload` directives.
+
+ :ref:`relationship_loader_options`
+
+ :ref:`deferred_loading_w_multiple`
+
+ """
+ return loadopt.set_relationship_strategy(attr, None)
+
+
+@defaultload._add_unbound_fn
+def defaultload(*keys):
+ return _UnboundLoad._from_keys(_UnboundLoad.defaultload, keys, False, {})
+
+
+@loader_option()
+def defer(loadopt, key, raiseload=False):
+ r"""Indicate that the given column-oriented attribute should be deferred,
+ e.g. not loaded until accessed.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ e.g.::
+
+ from sqlalchemy.orm import defer
+
+ session.query(MyClass).options(
+ defer("attribute_one"),
+ defer("attribute_two"))
+
+ session.query(MyClass).options(
+ defer(MyClass.attribute_one),
+ defer(MyClass.attribute_two))
+
+ To specify a deferred load of an attribute on a related class,
+ the path can be specified one token at a time, specifying the loading
+ style for each link along the chain. To leave the loading style
+ for a link unchanged, use :func:`_orm.defaultload`::
+
+ session.query(MyClass).options(defaultload("someattr").defer("some_column"))
+
+ A :class:`_orm.Load` object that is present on a certain path can have
+ :meth:`_orm.Load.defer` called multiple times,
+ each will operate on the same
+ parent entity::
+
+
+ session.query(MyClass).options(
+ defaultload("someattr").
+ defer("some_column").
+ defer("some_other_column").
+ defer("another_column")
+ )
+
+ :param key: Attribute to be deferred.
+
+ :param raiseload: raise :class:`.InvalidRequestError` if the column
+ value would be loaded by emitting SQL. Used to prevent unwanted
+ SQL from being emitted.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`deferred_raiseload`
+
+ :param \*addl_attrs: This option supports the old 0.8 style
+ of specifying a path as a series of attributes, which is now superseded
+ by the method-chained style.
+
+ .. deprecated:: 0.9 The \*addl_attrs on :func:`_orm.defer` is
+ deprecated and will be removed in a future release. Please
+ use method chaining in conjunction with defaultload() to
+ indicate a path.
+
+
+ .. seealso::
+
+ :ref:`deferred`
+
+ :func:`_orm.undefer`
+
+ """
+ strategy = {"deferred": True, "instrument": True}
+ if raiseload:
+ strategy["raiseload"] = True
+ return loadopt.set_column_strategy((key,), strategy)
+
+
+@defer._add_unbound_fn
+def defer(key, *addl_attrs, **kw):
+ if addl_attrs:
+ util.warn_deprecated(
+ "The *addl_attrs on orm.defer is deprecated. Please use "
+ "method chaining in conjunction with defaultload() to "
+ "indicate a path.",
+ version="1.3",
+ )
+ return _UnboundLoad._from_keys(
+ _UnboundLoad.defer, (key,) + addl_attrs, False, kw
+ )
+
+
+@loader_option()
+def undefer(loadopt, key):
+ r"""Indicate that the given column-oriented attribute should be undeferred,
+ e.g. specified within the SELECT statement of the entity as a whole.
+
+ The column being undeferred is typically set up on the mapping as a
+ :func:`.deferred` attribute.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ Examples::
+
+ # undefer two columns
+ session.query(MyClass).options(undefer("col1"), undefer("col2"))
+
+ # undefer all columns specific to a single class using Load + *
+ session.query(MyClass, MyOtherClass).options(
+ Load(MyClass).undefer("*"))
+
+ # undefer a column on a related object
+ session.query(MyClass).options(
+ defaultload(MyClass.items).undefer('text'))
+
+ :param key: Attribute to be undeferred.
+
+ :param \*addl_attrs: This option supports the old 0.8 style
+ of specifying a path as a series of attributes, which is now superseded
+ by the method-chained style.
+
+ .. deprecated:: 0.9 The \*addl_attrs on :func:`_orm.undefer` is
+ deprecated and will be removed in a future release. Please
+ use method chaining in conjunction with defaultload() to
+ indicate a path.
+
+ .. seealso::
+
+ :ref:`deferred`
+
+ :func:`_orm.defer`
+
+ :func:`_orm.undefer_group`
+
+ """
+ return loadopt.set_column_strategy(
+ (key,), {"deferred": False, "instrument": True}
+ )
+
+
+@undefer._add_unbound_fn
+def undefer(key, *addl_attrs):
+ if addl_attrs:
+ util.warn_deprecated(
+ "The *addl_attrs on orm.undefer is deprecated. Please use "
+ "method chaining in conjunction with defaultload() to "
+ "indicate a path.",
+ version="1.3",
+ )
+ return _UnboundLoad._from_keys(
+ _UnboundLoad.undefer, (key,) + addl_attrs, False, {}
+ )
+
+
+@loader_option()
+def undefer_group(loadopt, name):
+ """Indicate that columns within the given deferred group name should be
+ undeferred.
+
+ The columns being undeferred are set up on the mapping as
+ :func:`.deferred` attributes and include a "group" name.
+
+ E.g.::
+
+ session.query(MyClass).options(undefer_group("large_attrs"))
+
+ To undefer a group of attributes on a related entity, the path can be
+ spelled out using relationship loader options, such as
+ :func:`_orm.defaultload`::
+
+ session.query(MyClass).options(
+ defaultload("someattr").undefer_group("large_attrs"))
+
+ .. versionchanged:: 0.9.0 :func:`_orm.undefer_group` is now specific to a
+ particular entity load path.
+
+ .. seealso::
+
+ :ref:`deferred`
+
+ :func:`_orm.defer`
+
+ :func:`_orm.undefer`
+
+ """
+ return loadopt.set_column_strategy(
+ "*", None, {"undefer_group_%s" % name: True}, opts_only=True
+ )
+
+
+@undefer_group._add_unbound_fn
+def undefer_group(name):
+ return _UnboundLoad().undefer_group(name)
+
+
+@loader_option()
+def with_expression(loadopt, key, expression):
+ r"""Apply an ad-hoc SQL expression to a "deferred expression" attribute.
+
+ This option is used in conjunction with the :func:`_orm.query_expression`
+ mapper-level construct that indicates an attribute which should be the
+ target of an ad-hoc SQL expression.
+
+ E.g.::
+
+
+ sess.query(SomeClass).options(
+ with_expression(SomeClass.x_y_expr, SomeClass.x + SomeClass.y)
+ )
+
+ .. versionadded:: 1.2
+
+ :param key: Attribute to be populated.
+
+ :param expression: SQL expression to be applied to the attribute.
+
+ .. note:: the target attribute is populated only if the target object
+ is **not currently loaded** in the current :class:`_orm.Session`
+ unless the :meth:`_query.Query.populate_existing` method is used.
+ Please refer to :ref:`mapper_querytime_expression` for complete
+ usage details.
+
+ .. seealso::
+
+ :ref:`mapper_querytime_expression`
+
+ """
+
+ expression = coercions.expect(
+ roles.LabeledColumnExprRole, _orm_full_deannotate(expression)
+ )
+
+ return loadopt.set_column_strategy(
+ (key,), {"query_expression": True}, opts={"expression": expression}
+ )
+
+
+@with_expression._add_unbound_fn
+def with_expression(key, expression):
+ return _UnboundLoad._from_keys(
+ _UnboundLoad.with_expression, (key,), False, {"expression": expression}
+ )
+
+
+@loader_option()
+def selectin_polymorphic(loadopt, classes):
+ """Indicate an eager load should take place for all attributes
+ specific to a subclass.
+
+ This uses an additional SELECT with IN against all matched primary
+ key values, and is the per-query analogue to the ``"selectin"``
+ setting on the :paramref:`.mapper.polymorphic_load` parameter.
+
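+ E.g., assuming hypothetical ``Engineer`` and ``Manager`` subclasses of an
+ ``Employee`` base, the query below loads the subclass-specific columns
+ of matched rows up front with one extra SELECT..IN per subclass::
+
+ session.query(Employee).options(
+ selectin_polymorphic(Employee, [Engineer, Manager])
+ )
+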
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`polymorphic_selectin`
+
+ """
+ loadopt.set_class_strategy(
+ {"selectinload_polymorphic": True},
+ opts={
+ "entities": tuple(
+ sorted((inspect(cls) for cls in classes), key=id)
+ )
+ },
+ )
+ return loadopt
+
+
+@selectin_polymorphic._add_unbound_fn
+def selectin_polymorphic(base_cls, classes):
+ ul = _UnboundLoad()
+ ul.is_class_strategy = True
+ ul.path = (inspect(base_cls),)
+ ul.selectin_polymorphic(classes)
+ return ul
diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py
new file mode 100644
index 0000000..c041804
--- /dev/null
+++ b/lib/sqlalchemy/orm/sync.py
@@ -0,0 +1,167 @@
+# orm/sync.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""private module containing functions used for copying data
+between instances based on join conditions.
+
+"""
+
+from . import attributes
+from . import exc
+from . import util as orm_util
+from .. import util
+
+
+def populate(
+ source,
+ source_mapper,
+ dest,
+ dest_mapper,
+ synchronize_pairs,
+ uowcommit,
+ flag_cascaded_pks,
+):
+ source_dict = source.dict
+ dest_dict = dest.dict
+
+ for l, r in synchronize_pairs:
+ try:
+ # inline of source_mapper._get_state_attr_by_column
+ prop = source_mapper._columntoproperty[l]
+ value = source.manager[prop.key].impl.get(
+ source, source_dict, attributes.PASSIVE_OFF
+ )
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(False, source_mapper, l, dest_mapper, r, err)
+
+ try:
+ # inline of dest_mapper._set_state_attr_by_column
+ prop = dest_mapper._columntoproperty[r]
+ dest.manager[prop.key].impl.set(dest, dest_dict, value, None)
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(True, source_mapper, l, dest_mapper, r, err)
+
+ # technically the "r.primary_key" check isn't
+ # needed here, but we check for this condition to limit
+ # how often this logic is invoked for memory/performance
+ # reasons, since we only need this info for a primary key
+ # destination.
+ if (
+ flag_cascaded_pks
+ and l.primary_key
+ and r.primary_key
+ and r.references(l)
+ ):
+ uowcommit.attributes[("pk_cascaded", dest, r)] = True
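+
+# Conceptual sketch of what populate() achieves during flush (illustrative
+# only; the ``Parent``/``Child`` mappings and their relationship are assumed).
+# For a pair (parent.id, child.parent_id) in synchronize_pairs, the parent's
+# primary key value is copied onto the child's foreign key attribute:
+#
+#     parent = Parent(id=5)
+#     child = Child()
+#     parent.children.append(child)
+#     session.add(parent)
+#     session.flush()
+#     assert child.parent_id == 5   # value synchronized from parent.id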
+
+
+def bulk_populate_inherit_keys(source_dict, source_mapper, synchronize_pairs):
+ # a simplified version of populate() used by bulk insert mode
+ for l, r in synchronize_pairs:
+ try:
+ prop = source_mapper._columntoproperty[l]
+ value = source_dict[prop.key]
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(False, source_mapper, l, source_mapper, r, err)
+
+ try:
+ prop = source_mapper._columntoproperty[r]
+ source_dict[prop.key] = value
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(True, source_mapper, l, source_mapper, r, err)
+
+
+def clear(dest, dest_mapper, synchronize_pairs):
+ for l, r in synchronize_pairs:
+ if (
+ r.primary_key
+ and dest_mapper._get_state_attr_by_column(dest, dest.dict, r)
+ not in orm_util._none_set
+ ):
+
+ raise AssertionError(
+ "Dependency rule tried to blank-out primary key "
+ "column '%s' on instance '%s'" % (r, orm_util.state_str(dest))
+ )
+ try:
+ dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None)
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(True, None, l, dest_mapper, r, err)
+
+
+def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
+ for l, r in synchronize_pairs:
+ try:
+ oldvalue = source_mapper._get_committed_attr_by_column(
+ source.obj(), l
+ )
+ value = source_mapper._get_state_attr_by_column(
+ source, source.dict, l, passive=attributes.PASSIVE_OFF
+ )
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(False, source_mapper, l, None, r, err)
+ dest[r.key] = value
+ dest[old_prefix + r.key] = oldvalue
+
+
+def populate_dict(source, source_mapper, dict_, synchronize_pairs):
+ for l, r in synchronize_pairs:
+ try:
+ value = source_mapper._get_state_attr_by_column(
+ source, source.dict, l, passive=attributes.PASSIVE_OFF
+ )
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(False, source_mapper, l, None, r, err)
+
+ dict_[r.key] = value
+
+
+def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
+ """return true if the source object has changes from an old to a
+ new value on the given synchronize pairs
+
+ """
+ for l, r in synchronize_pairs:
+ try:
+ prop = source_mapper._columntoproperty[l]
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(False, source_mapper, l, None, r, err)
+ history = uowcommit.get_attribute_history(
+ source, prop.key, attributes.PASSIVE_NO_INITIALIZE
+ )
+        if bool(history.deleted):
+            return True
+    return False
+
+
+def _raise_col_to_prop(
+ isdest, source_mapper, source_column, dest_mapper, dest_column, err
+):
+ if isdest:
+ util.raise_(
+ exc.UnmappedColumnError(
+ "Can't execute sync rule for "
+ "destination column '%s'; mapper '%s' does not map "
+ "this column. Try using an explicit `foreign_keys` "
+ "collection which does not include this column (or use "
+ "a viewonly=True relation)." % (dest_column, dest_mapper)
+ ),
+ replace_context=err,
+ )
+ else:
+ util.raise_(
+ exc.UnmappedColumnError(
+ "Can't execute sync rule for "
+ "source column '%s'; mapper '%s' does not map this "
+ "column. Try using an explicit `foreign_keys` "
+ "collection which does not include destination column "
+ "'%s' (or use a viewonly=True relation)."
+ % (source_column, source_mapper, dest_column)
+ ),
+ replace_context=err,
+ )
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
new file mode 100644
index 0000000..2257637
--- /dev/null
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -0,0 +1,784 @@
+# orm/unitofwork.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""The internals for the unit of work system.
+
+The session's flush() process passes objects to a contextual object
+here, which assembles flush tasks based on mappers and their properties,
+organizes them in order of dependency, and executes.
+
+"""
+
+from . import attributes
+from . import exc as orm_exc
+from . import util as orm_util
+from .. import event
+from .. import util
+from ..util import topological
+
+
+def _warn_for_cascade_backrefs(state, prop):
+ util.warn_deprecated_20(
+ '"%s" object is being merged into a Session along the backref '
+ 'cascade path for relationship "%s"; in SQLAlchemy 2.0, this '
+ "reverse cascade will not take place. Set cascade_backrefs to "
+ "False in either the relationship() or backref() function for "
+        "the 2.0 behavior; or, to set this globally for the whole "
+        "Session, set the future=True flag" % (state.class_.__name__, prop),
+ code="s9r1",
+ )
+
+
+def track_cascade_events(descriptor, prop):
+ """Establish event listeners on object attributes which handle
+ cascade-on-set/append.
+
+ """
+ key = prop.key
+
+ def append(state, item, initiator):
+ # process "save_update" cascade rules for when
+ # an instance is appended to the list of another instance
+
+ if item is None:
+ return
+
+ sess = state.session
+ if sess:
+ if sess._warn_on_events:
+ sess._flush_warning("collection append")
+
+ prop = state.manager.mapper._props[key]
+ item_state = attributes.instance_state(item)
+
+ if (
+ prop._cascade.save_update
+ and (
+ (prop.cascade_backrefs and not sess.future)
+ or key == initiator.key
+ )
+ and not sess._contains_state(item_state)
+ ):
+ if key != initiator.key:
+ _warn_for_cascade_backrefs(item_state, prop)
+ sess._save_or_update_state(item_state)
+ return item
+
+ def remove(state, item, initiator):
+ if item is None:
+ return
+
+ sess = state.session
+
+ prop = state.manager.mapper._props[key]
+
+ if sess and sess._warn_on_events:
+ sess._flush_warning(
+ "collection remove"
+ if prop.uselist
+ else "related attribute delete"
+ )
+
+ if (
+ item is not None
+ and item is not attributes.NEVER_SET
+ and item is not attributes.PASSIVE_NO_RESULT
+ and prop._cascade.delete_orphan
+ ):
+ # expunge pending orphans
+ item_state = attributes.instance_state(item)
+
+ if prop.mapper._is_orphan(item_state):
+ if sess and item_state in sess._new:
+ sess.expunge(item)
+ else:
+ # the related item may or may not itself be in a
+ # Session, however the parent for which we are catching
+ # the event is not in a session, so memoize this on the
+ # item
+ item_state._orphaned_outside_of_session = True
+
+ def set_(state, newvalue, oldvalue, initiator):
+ # process "save_update" cascade rules for when an instance
+ # is attached to another instance
+ if oldvalue is newvalue:
+ return newvalue
+
+ sess = state.session
+ if sess:
+
+ if sess._warn_on_events:
+ sess._flush_warning("related attribute set")
+
+ prop = state.manager.mapper._props[key]
+ if newvalue is not None:
+ newvalue_state = attributes.instance_state(newvalue)
+ if (
+ prop._cascade.save_update
+ and (
+ (prop.cascade_backrefs and not sess.future)
+ or key == initiator.key
+ )
+ and not sess._contains_state(newvalue_state)
+ ):
+ if key != initiator.key:
+ _warn_for_cascade_backrefs(newvalue_state, prop)
+ sess._save_or_update_state(newvalue_state)
+
+ if (
+ oldvalue is not None
+ and oldvalue is not attributes.NEVER_SET
+ and oldvalue is not attributes.PASSIVE_NO_RESULT
+ and prop._cascade.delete_orphan
+ ):
+ # possible to reach here with attributes.NEVER_SET ?
+ oldvalue_state = attributes.instance_state(oldvalue)
+
+ if oldvalue_state in sess._new and prop.mapper._is_orphan(
+ oldvalue_state
+ ):
+ sess.expunge(oldvalue)
+ return newvalue
+
+ event.listen(descriptor, "append_wo_mutation", append, raw=True)
+ event.listen(descriptor, "append", append, raw=True, retval=True)
+ event.listen(descriptor, "remove", remove, raw=True, retval=True)
+ event.listen(descriptor, "set", set_, raw=True, retval=True)
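+
+# Behavioral sketch of the "save-update" cascade handled by the listeners
+# above (illustrative; ``User``/``Address`` mappings with a ``User.addresses``
+# relationship are assumed):
+#
+#     user = User(name="u1")
+#     session.add(user)
+#     address = Address(email="a@example.com")
+#     # appending to the collection fires the "append" listener, which
+#     # cascades the pending Address into the same Session
+#     user.addresses.append(address)
+#     assert address in session.new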
+
+
+class UOWTransaction(object):
+ def __init__(self, session):
+ self.session = session
+
+ # dictionary used by external actors to
+ # store arbitrary state information.
+ self.attributes = {}
+
+ # dictionary of mappers to sets of
+ # DependencyProcessors, which are also
+ # set to be part of the sorted flush actions,
+ # which have that mapper as a parent.
+ self.deps = util.defaultdict(set)
+
+ # dictionary of mappers to sets of InstanceState
+ # items pending for flush which have that mapper
+ # as a parent.
+ self.mappers = util.defaultdict(set)
+
+ # a dictionary of Preprocess objects, which gather
+ # additional states impacted by the flush
+ # and determine if a flush action is needed
+ self.presort_actions = {}
+
+ # dictionary of PostSortRec objects, each
+ # one issues work during the flush within
+ # a certain ordering.
+ self.postsort_actions = {}
+
+ # a set of 2-tuples, each containing two
+ # PostSortRec objects where the second
+ # is dependent on the first being executed
+ # first
+ self.dependencies = set()
+
+ # dictionary of InstanceState-> (isdelete, listonly)
+ # tuples, indicating if this state is to be deleted
+ # or insert/updated, or just refreshed
+ self.states = {}
+
+ # tracks InstanceStates which will be receiving
+ # a "post update" call. Keys are mappers,
+ # values are a set of states and a set of the
+ # columns which should be included in the update.
+ self.post_update_states = util.defaultdict(lambda: (set(), set()))
+
+ @property
+ def has_work(self):
+ return bool(self.states)
+
+ def was_already_deleted(self, state):
+ """Return ``True`` if the given state is expired and was deleted
+ previously.
+ """
+ if state.expired:
+ try:
+ state._load_expired(state, attributes.PASSIVE_OFF)
+ except orm_exc.ObjectDeletedError:
+ self.session._remove_newly_deleted([state])
+ return True
+ return False
+
+ def is_deleted(self, state):
+ """Return ``True`` if the given state is marked as deleted
+ within this uowtransaction."""
+
+ return state in self.states and self.states[state][0]
+
+ def memo(self, key, callable_):
+ if key in self.attributes:
+ return self.attributes[key]
+ else:
+ self.attributes[key] = ret = callable_()
+ return ret
+
+ def remove_state_actions(self, state):
+ """Remove pending actions for a state from the uowtransaction."""
+
+ isdelete = self.states[state][0]
+
+ self.states[state] = (isdelete, True)
+
+ def get_attribute_history(
+ self, state, key, passive=attributes.PASSIVE_NO_INITIALIZE
+ ):
+ """Facade to attributes.get_state_history(), including
+ caching of results."""
+
+ hashkey = ("history", state, key)
+
+ # cache the objects, not the states; the strong reference here
+ # prevents newly loaded objects from being dereferenced during the
+ # flush process
+
+ if hashkey in self.attributes:
+ history, state_history, cached_passive = self.attributes[hashkey]
+ # if the cached lookup was "passive" and now
+ # we want non-passive, do a non-passive lookup and re-cache
+
+ if (
+ not cached_passive & attributes.SQL_OK
+ and passive & attributes.SQL_OK
+ ):
+ impl = state.manager[key].impl
+ history = impl.get_history(
+ state,
+ state.dict,
+ attributes.PASSIVE_OFF
+ | attributes.LOAD_AGAINST_COMMITTED
+ | attributes.NO_RAISE,
+ )
+ if history and impl.uses_objects:
+ state_history = history.as_state()
+ else:
+ state_history = history
+ self.attributes[hashkey] = (history, state_history, passive)
+ else:
+ impl = state.manager[key].impl
+ # TODO: store the history as (state, object) tuples
+ # so we don't have to keep converting here
+ history = impl.get_history(
+ state,
+ state.dict,
+ passive
+ | attributes.LOAD_AGAINST_COMMITTED
+ | attributes.NO_RAISE,
+ )
+ if history and impl.uses_objects:
+ state_history = history.as_state()
+ else:
+ state_history = history
+ self.attributes[hashkey] = (history, state_history, passive)
+
+ return state_history
+
+ def has_dep(self, processor):
+ return (processor, True) in self.presort_actions
+
+ def register_preprocessor(self, processor, fromparent):
+ key = (processor, fromparent)
+ if key not in self.presort_actions:
+ self.presort_actions[key] = Preprocess(processor, fromparent)
+
+ def register_object(
+ self,
+ state,
+ isdelete=False,
+ listonly=False,
+ cancel_delete=False,
+ operation=None,
+ prop=None,
+ ):
+ if not self.session._contains_state(state):
+ # this condition is normal when objects are registered
+ # as part of a relationship cascade operation. it should
+ # not occur for the top-level register from Session.flush().
+ if not state.deleted and operation is not None:
+ util.warn(
+ "Object of type %s not in session, %s operation "
+ "along '%s' will not proceed"
+ % (orm_util.state_class_str(state), operation, prop)
+ )
+ return False
+
+ if state not in self.states:
+ mapper = state.manager.mapper
+
+ if mapper not in self.mappers:
+ self._per_mapper_flush_actions(mapper)
+
+ self.mappers[mapper].add(state)
+ self.states[state] = (isdelete, listonly)
+ else:
+ if not listonly and (isdelete or cancel_delete):
+ self.states[state] = (isdelete, False)
+ return True
+
+ def register_post_update(self, state, post_update_cols):
+ mapper = state.manager.mapper.base_mapper
+ states, cols = self.post_update_states[mapper]
+ states.add(state)
+ cols.update(post_update_cols)
+
+ def _per_mapper_flush_actions(self, mapper):
+ saves = SaveUpdateAll(self, mapper.base_mapper)
+ deletes = DeleteAll(self, mapper.base_mapper)
+ self.dependencies.add((saves, deletes))
+
+ for dep in mapper._dependency_processors:
+ dep.per_property_preprocessors(self)
+
+ for prop in mapper.relationships:
+ if prop.viewonly:
+ continue
+ dep = prop._dependency_processor
+ dep.per_property_preprocessors(self)
+
+ @util.memoized_property
+ def _mapper_for_dep(self):
+ """return a dynamic mapping of (Mapper, DependencyProcessor) to
+ True or False, indicating if the DependencyProcessor operates
+ on objects of that Mapper.
+
+ The result is stored in the dictionary persistently once
+ calculated.
+
+ """
+ return util.PopulateDict(
+ lambda tup: tup[0]._props.get(tup[1].key) is tup[1].prop
+ )
+
+ def filter_states_for_dep(self, dep, states):
+ """Filter the given list of InstanceStates to those relevant to the
+ given DependencyProcessor.
+
+ """
+ mapper_for_dep = self._mapper_for_dep
+ return [s for s in states if mapper_for_dep[(s.manager.mapper, dep)]]
+
+ def states_for_mapper_hierarchy(self, mapper, isdelete, listonly):
+ checktup = (isdelete, listonly)
+ for mapper in mapper.base_mapper.self_and_descendants:
+ for state in self.mappers[mapper]:
+ if self.states[state] == checktup:
+ yield state
+
+ def _generate_actions(self):
+ """Generate the full, unsorted collection of PostSortRecs as
+ well as dependency pairs for this UOWTransaction.
+
+ """
+ # execute presort_actions, until all states
+ # have been processed. a presort_action might
+ # add new states to the uow.
+ while True:
+ ret = False
+ for action in list(self.presort_actions.values()):
+ if action.execute(self):
+ ret = True
+ if not ret:
+ break
+
+ # see if the graph of mapper dependencies has cycles.
+ self.cycles = cycles = topological.find_cycles(
+ self.dependencies, list(self.postsort_actions.values())
+ )
+
+ if cycles:
+ # if yes, break the per-mapper actions into
+ # per-state actions
+ convert = dict(
+ (rec, set(rec.per_state_flush_actions(self))) for rec in cycles
+ )
+
+ # rewrite the existing dependencies to point to
+ # the per-state actions for those per-mapper actions
+ # that were broken up.
+ for edge in list(self.dependencies):
+ if (
+ None in edge
+ or edge[0].disabled
+ or edge[1].disabled
+ or cycles.issuperset(edge)
+ ):
+ self.dependencies.remove(edge)
+ elif edge[0] in cycles:
+ self.dependencies.remove(edge)
+ for dep in convert[edge[0]]:
+ self.dependencies.add((dep, edge[1]))
+ elif edge[1] in cycles:
+ self.dependencies.remove(edge)
+ for dep in convert[edge[1]]:
+ self.dependencies.add((edge[0], dep))
+
+ return set(
+ [a for a in self.postsort_actions.values() if not a.disabled]
+ ).difference(cycles)
+
+ def execute(self):
+ postsort_actions = self._generate_actions()
+
+ postsort_actions = sorted(
+ postsort_actions,
+ key=lambda item: item.sort_key,
+ )
+ # sort = topological.sort(self.dependencies, postsort_actions)
+ # print "--------------"
+ # print "\ndependencies:", self.dependencies
+ # print "\ncycles:", self.cycles
+ # print "\nsort:", list(sort)
+ # print "\nCOUNT OF POSTSORT ACTIONS", len(postsort_actions)
+
+ # execute
+ if self.cycles:
+ for subset in topological.sort_as_subsets(
+ self.dependencies, postsort_actions
+ ):
+ set_ = set(subset)
+ while set_:
+ n = set_.pop()
+ n.execute_aggregate(self, set_)
+ else:
+ for rec in topological.sort(self.dependencies, postsort_actions):
+ rec.execute(self)
+
+ def finalize_flush_changes(self):
+ """Mark processed objects as clean / deleted after a successful
+ flush().
+
+ This method is called within the flush() method after the
+ execute() method has succeeded and the transaction has been committed.
+
+ """
+ if not self.states:
+ return
+
+ states = set(self.states)
+ isdel = set(
+ s for (s, (isdelete, listonly)) in self.states.items() if isdelete
+ )
+ other = states.difference(isdel)
+ if isdel:
+ self.session._remove_newly_deleted(isdel)
+ if other:
+ self.session._register_persistent(other)
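+
+# Rough lifecycle sketch of a UOWTransaction during Session.flush()
+# (illustrative pseudocode, not an exact transcript of session.py):
+#
+#     uow = UOWTransaction(session)
+#     for state in dirty_new_and_deleted_states:
+#         uow.register_object(state, isdelete=...)
+#     if uow.has_work:
+#         uow.execute()                 # topologically ordered persistence
+#         uow.finalize_flush_changes()  # mark states clean / deleted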
+
+
+class IterateMappersMixin(object):
+ def _mappers(self, uow):
+ if self.fromparent:
+ return iter(
+ m
+ for m in self.dependency_processor.parent.self_and_descendants
+ if uow._mapper_for_dep[(m, self.dependency_processor)]
+ )
+ else:
+ return self.dependency_processor.mapper.self_and_descendants
+
+
+class Preprocess(IterateMappersMixin):
+ __slots__ = (
+ "dependency_processor",
+ "fromparent",
+ "processed",
+ "setup_flush_actions",
+ )
+
+ def __init__(self, dependency_processor, fromparent):
+ self.dependency_processor = dependency_processor
+ self.fromparent = fromparent
+ self.processed = set()
+ self.setup_flush_actions = False
+
+ def execute(self, uow):
+ delete_states = set()
+ save_states = set()
+
+ for mapper in self._mappers(uow):
+ for state in uow.mappers[mapper].difference(self.processed):
+ (isdelete, listonly) = uow.states[state]
+ if not listonly:
+ if isdelete:
+ delete_states.add(state)
+ else:
+ save_states.add(state)
+
+ if delete_states:
+ self.dependency_processor.presort_deletes(uow, delete_states)
+ self.processed.update(delete_states)
+ if save_states:
+ self.dependency_processor.presort_saves(uow, save_states)
+ self.processed.update(save_states)
+
+ if delete_states or save_states:
+ if not self.setup_flush_actions and (
+ self.dependency_processor.prop_has_changes(
+ uow, delete_states, True
+ )
+ or self.dependency_processor.prop_has_changes(
+ uow, save_states, False
+ )
+ ):
+ self.dependency_processor.per_property_flush_actions(uow)
+ self.setup_flush_actions = True
+ return True
+ else:
+ return False
+
+
+class PostSortRec(object):
+ __slots__ = ("disabled",)
+
+ def __new__(cls, uow, *args):
+ key = (cls,) + args
+ if key in uow.postsort_actions:
+ return uow.postsort_actions[key]
+ else:
+ uow.postsort_actions[key] = ret = object.__new__(cls)
+ ret.disabled = False
+ return ret
+
+ def execute_aggregate(self, uow, recs):
+ self.execute(uow)
+
+
+class ProcessAll(IterateMappersMixin, PostSortRec):
+ __slots__ = "dependency_processor", "isdelete", "fromparent", "sort_key"
+
+ def __init__(self, uow, dependency_processor, isdelete, fromparent):
+ self.dependency_processor = dependency_processor
+ self.sort_key = (
+ "ProcessAll",
+ self.dependency_processor.sort_key,
+ isdelete,
+ )
+ self.isdelete = isdelete
+ self.fromparent = fromparent
+ uow.deps[dependency_processor.parent.base_mapper].add(
+ dependency_processor
+ )
+
+ def execute(self, uow):
+ states = self._elements(uow)
+ if self.isdelete:
+ self.dependency_processor.process_deletes(uow, states)
+ else:
+ self.dependency_processor.process_saves(uow, states)
+
+ def per_state_flush_actions(self, uow):
+ # this is handled by SaveUpdateAll and DeleteAll,
+ # since a ProcessAll should unconditionally be pulled
+ # into per-state if either the parent/child mappers
+ # are part of a cycle
+ return iter([])
+
+ def __repr__(self):
+ return "%s(%s, isdelete=%s)" % (
+ self.__class__.__name__,
+ self.dependency_processor,
+ self.isdelete,
+ )
+
+ def _elements(self, uow):
+ for mapper in self._mappers(uow):
+ for state in uow.mappers[mapper]:
+ (isdelete, listonly) = uow.states[state]
+ if isdelete == self.isdelete and not listonly:
+ yield state
+
+
+class PostUpdateAll(PostSortRec):
+ __slots__ = "mapper", "isdelete", "sort_key"
+
+ def __init__(self, uow, mapper, isdelete):
+ self.mapper = mapper
+ self.isdelete = isdelete
+ self.sort_key = ("PostUpdateAll", mapper._sort_key, isdelete)
+
+ @util.preload_module("sqlalchemy.orm.persistence")
+ def execute(self, uow):
+ persistence = util.preloaded.orm_persistence
+ states, cols = uow.post_update_states[self.mapper]
+ states = [s for s in states if uow.states[s][0] == self.isdelete]
+
+ persistence.post_update(self.mapper, states, uow, cols)
+
+
+class SaveUpdateAll(PostSortRec):
+ __slots__ = ("mapper", "sort_key")
+
+ def __init__(self, uow, mapper):
+ self.mapper = mapper
+ self.sort_key = ("SaveUpdateAll", mapper._sort_key)
+ assert mapper is mapper.base_mapper
+
+ @util.preload_module("sqlalchemy.orm.persistence")
+ def execute(self, uow):
+ util.preloaded.orm_persistence.save_obj(
+ self.mapper,
+ uow.states_for_mapper_hierarchy(self.mapper, False, False),
+ uow,
+ )
+
+ def per_state_flush_actions(self, uow):
+ states = list(
+ uow.states_for_mapper_hierarchy(self.mapper, False, False)
+ )
+ base_mapper = self.mapper.base_mapper
+ delete_all = DeleteAll(uow, base_mapper)
+ for state in states:
+ # keep saves before deletes -
+ # this ensures 'row switch' operations work
+ action = SaveUpdateState(uow, state)
+ uow.dependencies.add((action, delete_all))
+ yield action
+
+ for dep in uow.deps[self.mapper]:
+ states_for_prop = uow.filter_states_for_dep(dep, states)
+ dep.per_state_flush_actions(uow, states_for_prop, False)
+
+ def __repr__(self):
+ return "%s(%s)" % (self.__class__.__name__, self.mapper)
+
+
+class DeleteAll(PostSortRec):
+ __slots__ = ("mapper", "sort_key")
+
+ def __init__(self, uow, mapper):
+ self.mapper = mapper
+ self.sort_key = ("DeleteAll", mapper._sort_key)
+ assert mapper is mapper.base_mapper
+
+ @util.preload_module("sqlalchemy.orm.persistence")
+ def execute(self, uow):
+ util.preloaded.orm_persistence.delete_obj(
+ self.mapper,
+ uow.states_for_mapper_hierarchy(self.mapper, True, False),
+ uow,
+ )
+
+ def per_state_flush_actions(self, uow):
+ states = list(
+ uow.states_for_mapper_hierarchy(self.mapper, True, False)
+ )
+ base_mapper = self.mapper.base_mapper
+ save_all = SaveUpdateAll(uow, base_mapper)
+ for state in states:
+ # keep saves before deletes -
+ # this ensures 'row switch' operations work
+ action = DeleteState(uow, state)
+ uow.dependencies.add((save_all, action))
+ yield action
+
+ for dep in uow.deps[self.mapper]:
+ states_for_prop = uow.filter_states_for_dep(dep, states)
+ dep.per_state_flush_actions(uow, states_for_prop, True)
+
+ def __repr__(self):
+ return "%s(%s)" % (self.__class__.__name__, self.mapper)
+
+
+class ProcessState(PostSortRec):
+ __slots__ = "dependency_processor", "isdelete", "state", "sort_key"
+
+ def __init__(self, uow, dependency_processor, isdelete, state):
+ self.dependency_processor = dependency_processor
+ self.sort_key = ("ProcessState", dependency_processor.sort_key)
+ self.isdelete = isdelete
+ self.state = state
+
+ def execute_aggregate(self, uow, recs):
+ cls_ = self.__class__
+ dependency_processor = self.dependency_processor
+ isdelete = self.isdelete
+ our_recs = [
+ r
+ for r in recs
+ if r.__class__ is cls_
+ and r.dependency_processor is dependency_processor
+ and r.isdelete is isdelete
+ ]
+ recs.difference_update(our_recs)
+ states = [self.state] + [r.state for r in our_recs]
+ if isdelete:
+ dependency_processor.process_deletes(uow, states)
+ else:
+ dependency_processor.process_saves(uow, states)
+
+ def __repr__(self):
+ return "%s(%s, %s, delete=%s)" % (
+ self.__class__.__name__,
+ self.dependency_processor,
+ orm_util.state_str(self.state),
+ self.isdelete,
+ )
+
+
+class SaveUpdateState(PostSortRec):
+ __slots__ = "state", "mapper", "sort_key"
+
+ def __init__(self, uow, state):
+ self.state = state
+ self.mapper = state.mapper.base_mapper
+ self.sort_key = ("ProcessState", self.mapper._sort_key)
+
+ @util.preload_module("sqlalchemy.orm.persistence")
+ def execute_aggregate(self, uow, recs):
+ persistence = util.preloaded.orm_persistence
+ cls_ = self.__class__
+ mapper = self.mapper
+ our_recs = [
+ r for r in recs if r.__class__ is cls_ and r.mapper is mapper
+ ]
+ recs.difference_update(our_recs)
+ persistence.save_obj(
+ mapper, [self.state] + [r.state for r in our_recs], uow
+ )
+
+ def __repr__(self):
+ return "%s(%s)" % (
+ self.__class__.__name__,
+ orm_util.state_str(self.state),
+ )
+
+
+class DeleteState(PostSortRec):
+ __slots__ = "state", "mapper", "sort_key"
+
+ def __init__(self, uow, state):
+ self.state = state
+ self.mapper = state.mapper.base_mapper
+ self.sort_key = ("DeleteState", self.mapper._sort_key)
+
+ @util.preload_module("sqlalchemy.orm.persistence")
+ def execute_aggregate(self, uow, recs):
+ persistence = util.preloaded.orm_persistence
+ cls_ = self.__class__
+ mapper = self.mapper
+ our_recs = [
+ r for r in recs if r.__class__ is cls_ and r.mapper is mapper
+ ]
+ recs.difference_update(our_recs)
+ states = [self.state] + [r.state for r in our_recs]
+ persistence.delete_obj(
+ mapper, [s for s in states if uow.states[s][0]], uow
+ )
+
+ def __repr__(self):
+ return "%s(%s)" % (
+ self.__class__.__name__,
+ orm_util.state_str(self.state),
+ )
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
new file mode 100644
index 0000000..56aa9ff
--- /dev/null
+++ b/lib/sqlalchemy/orm/util.py
@@ -0,0 +1,2149 @@
+# orm/util.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+import re
+import types
+import weakref
+
+from . import attributes # noqa
+from .base import _class_to_mapper # noqa
+from .base import _never_set # noqa
+from .base import _none_set # noqa
+from .base import attribute_str # noqa
+from .base import class_mapper # noqa
+from .base import InspectionAttr # noqa
+from .base import instance_str # noqa
+from .base import object_mapper # noqa
+from .base import object_state # noqa
+from .base import state_attribute_str # noqa
+from .base import state_class_str # noqa
+from .base import state_str # noqa
+from .interfaces import CriteriaOption
+from .interfaces import MapperProperty # noqa
+from .interfaces import ORMColumnsClauseRole
+from .interfaces import ORMEntityColumnsClauseRole
+from .interfaces import ORMFromClauseRole
+from .interfaces import PropComparator # noqa
+from .path_registry import PathRegistry # noqa
+from .. import event
+from .. import exc as sa_exc
+from .. import inspection
+from .. import sql
+from .. import util
+from ..engine.result import result_tuple
+from ..sql import base as sql_base
+from ..sql import coercions
+from ..sql import expression
+from ..sql import lambdas
+from ..sql import roles
+from ..sql import util as sql_util
+from ..sql import visitors
+from ..sql.annotation import SupportsCloneAnnotations
+from ..sql.base import ColumnCollection
+
+
+all_cascades = frozenset(
+ (
+ "delete",
+ "delete-orphan",
+ "all",
+ "merge",
+ "expunge",
+ "save-update",
+ "refresh-expire",
+ "none",
+ )
+)
+
+
+class CascadeOptions(frozenset):
+ """Keeps track of the options sent to
+ :paramref:`.relationship.cascade`"""
+
+ _add_w_all_cascades = all_cascades.difference(
+ ["all", "none", "delete-orphan"]
+ )
+ _allowed_cascades = all_cascades
+
+ _viewonly_cascades = ["expunge", "all", "none", "refresh-expire"]
+
+ __slots__ = (
+ "save_update",
+ "delete",
+ "refresh_expire",
+ "merge",
+ "expunge",
+ "delete_orphan",
+ )
+
+ def __new__(cls, value_list):
+ if isinstance(value_list, util.string_types) or value_list is None:
+ return cls.from_string(value_list)
+ values = set(value_list)
+ if values.difference(cls._allowed_cascades):
+ raise sa_exc.ArgumentError(
+ "Invalid cascade option(s): %s"
+ % ", ".join(
+ [
+ repr(x)
+ for x in sorted(
+ values.difference(cls._allowed_cascades)
+ )
+ ]
+ )
+ )
+
+ if "all" in values:
+ values.update(cls._add_w_all_cascades)
+ if "none" in values:
+ values.clear()
+ values.discard("all")
+
+ self = frozenset.__new__(CascadeOptions, values)
+ self.save_update = "save-update" in values
+ self.delete = "delete" in values
+ self.refresh_expire = "refresh-expire" in values
+ self.merge = "merge" in values
+ self.expunge = "expunge" in values
+ self.delete_orphan = "delete-orphan" in values
+
+ if self.delete_orphan and not self.delete:
+ util.warn(
+                "The 'delete-orphan' cascade option requires 'delete'."
+ )
+ return self
+
+ def __repr__(self):
+ return "CascadeOptions(%r)" % (",".join([x for x in sorted(self)]))
+
+ @classmethod
+ def from_string(cls, arg):
+ values = [c for c in re.split(r"\s*,\s*", arg or "") if c]
+ return cls(values)
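+
+# A quick sketch of how CascadeOptions resolves its flags (illustrative):
+#
+#     opts = CascadeOptions("all, delete-orphan")
+#     assert opts.save_update and opts.delete and opts.merge
+#     assert opts.expunge and opts.refresh_expire
+#     assert opts.delete_orphan
+#     # "all" expands to every cascade except delete-orphan, which must be
+#     # named explicitly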
+
+
+def _validator_events(desc, key, validator, include_removes, include_backrefs):
+ """Runs a validation method on an attribute value to be set or
+ appended.
+ """
+
+ if not include_backrefs:
+
+ def detect_is_backref(state, initiator):
+ impl = state.manager[key].impl
+ return initiator.impl is not impl
+
+ if include_removes:
+
+ def append(state, value, initiator):
+ if initiator.op is not attributes.OP_BULK_REPLACE and (
+ include_backrefs or not detect_is_backref(state, initiator)
+ ):
+ return validator(state.obj(), key, value, False)
+ else:
+ return value
+
+ def bulk_set(state, values, initiator):
+ if include_backrefs or not detect_is_backref(state, initiator):
+ obj = state.obj()
+ values[:] = [
+ validator(obj, key, value, False) for value in values
+ ]
+
+ def set_(state, value, oldvalue, initiator):
+ if include_backrefs or not detect_is_backref(state, initiator):
+ return validator(state.obj(), key, value, False)
+ else:
+ return value
+
+ def remove(state, value, initiator):
+ if include_backrefs or not detect_is_backref(state, initiator):
+ validator(state.obj(), key, value, True)
+
+ else:
+
+ def append(state, value, initiator):
+ if initiator.op is not attributes.OP_BULK_REPLACE and (
+ include_backrefs or not detect_is_backref(state, initiator)
+ ):
+ return validator(state.obj(), key, value)
+ else:
+ return value
+
+ def bulk_set(state, values, initiator):
+ if include_backrefs or not detect_is_backref(state, initiator):
+ obj = state.obj()
+ values[:] = [validator(obj, key, value) for value in values]
+
+ def set_(state, value, oldvalue, initiator):
+ if include_backrefs or not detect_is_backref(state, initiator):
+ return validator(state.obj(), key, value)
+ else:
+ return value
+
+ event.listen(desc, "append", append, raw=True, retval=True)
+ event.listen(desc, "bulk_replace", bulk_set, raw=True)
+ event.listen(desc, "set", set_, raw=True, retval=True)
+ if include_removes:
+ event.listen(desc, "remove", remove, raw=True, retval=True)
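+
+# These listeners implement the behavior behind the orm.validates() decorator.
+# A minimal sketch of the user-facing usage (illustrative; the declarative
+# ``Base`` and Column/Integer/String imports are assumed):
+#
+#     from sqlalchemy.orm import validates
+#
+#     class User(Base):
+#         __tablename__ = "user"
+#         id = Column(Integer, primary_key=True)
+#         email = Column(String)
+#
+#         @validates("email")
+#         def validate_email(self, key, value):
+#             assert "@" in value
+#             return value.lower()
+#
+#     # User(email="Foo@Example.com").email == "foo@example.com"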
+
+
+def polymorphic_union(
+ table_map, typecolname, aliasname="p_union", cast_nulls=True
+):
+ """Create a ``UNION`` statement used by a polymorphic mapper.
+
+ See :ref:`concrete_inheritance` for an example of how
+ this is used.
+
+ :param table_map: mapping of polymorphic identities to
+ :class:`_schema.Table` objects.
+ :param typecolname: string name of a "discriminator" column, which will be
+ derived from the query, producing the polymorphic identity for
+ each row. If ``None``, no polymorphic discriminator is generated.
+ :param aliasname: name of the :func:`~sqlalchemy.sql.expression.alias()`
+ construct generated.
+ :param cast_nulls: if True, non-existent columns, which are represented
+ as labeled NULLs, will be passed into CAST. This is a legacy behavior
+ that is problematic on some backends such as Oracle - in which case it
+ can be set to False.
+
+ """
+
+ colnames = util.OrderedSet()
+ colnamemaps = {}
+ types = {}
+ for key in table_map:
+ table = table_map[key]
+
+ table = coercions.expect(
+ roles.StrictFromClauseRole, table, allow_select=True
+ )
+ table_map[key] = table
+
+ m = {}
+ for c in table.c:
+ if c.key == typecolname:
+ raise sa_exc.InvalidRequestError(
+ "Polymorphic union can't use '%s' as the discriminator "
+ "column due to mapped column %r; please apply the "
+ "'typecolname' "
+ "argument; this is available on "
+ "ConcreteBase as '_concrete_discriminator_name'"
+ % (typecolname, c)
+ )
+ colnames.add(c.key)
+ m[c.key] = c
+ types[c.key] = c.type
+ colnamemaps[table] = m
+
+ def col(name, table):
+ try:
+ return colnamemaps[table][name]
+ except KeyError:
+ if cast_nulls:
+ return sql.cast(sql.null(), types[name]).label(name)
+ else:
+ return sql.type_coerce(sql.null(), types[name]).label(name)
+
+ result = []
+ for type_, table in table_map.items():
+ if typecolname is not None:
+ result.append(
+ sql.select(
+ *(
+ [col(name, table) for name in colnames]
+ + [
+ sql.literal_column(
+ sql_util._quote_ddl_expr(type_)
+ ).label(typecolname)
+ ]
+ )
+ ).select_from(table)
+ )
+ else:
+ result.append(
+ sql.select(
+ *[col(name, table) for name in colnames]
+ ).select_from(table)
+ )
+ return sql.union_all(*result).alias(aliasname)
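+
+# A minimal sketch of polymorphic_union() with two concrete tables
+# (illustrative; the Table objects are assumed to be defined elsewhere):
+#
+#     punion = polymorphic_union(
+#         {"engineer": engineers_table, "manager": managers_table},
+#         "type",
+#         aliasname="pjoin",
+#     )
+#     # punion is an aliased UNION ALL of both tables; columns missing from a
+#     # table are rendered as typed NULLs, and the "type" column carries the
+#     # polymorphic identity ("engineer" or "manager") for each row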
+
+
+def identity_key(*args, **kwargs):
+ r"""Generate "identity key" tuples, as are used as keys in the
+ :attr:`.Session.identity_map` dictionary.
+
+ This function has several call styles:
+
+ * ``identity_key(class, ident, identity_token=token)``
+
+ This form receives a mapped class and a primary key scalar or
+ tuple as an argument.
+
+ E.g.::
+
+ >>> identity_key(MyClass, (1, 2))
+ (<class '__main__.MyClass'>, (1, 2), None)
+
+ :param class: mapped class (must be a positional argument)
+ :param ident: primary key, may be a scalar or tuple argument.
+ :param identity_token: optional identity token
+
+ .. versionadded:: 1.2 added identity_token
+
+
+ * ``identity_key(instance=instance)``
+
+ This form will produce the identity key for a given instance. The
+ instance need not be persistent, only that its primary key attributes
+ are populated (else the key will contain ``None`` for those missing
+ values).
+
+ E.g.::
+
+ >>> instance = MyClass(1, 2)
+ >>> identity_key(instance=instance)
+ (<class '__main__.MyClass'>, (1, 2), None)
+
+      In this form, the given instance is ultimately run through
+ :meth:`_orm.Mapper.identity_key_from_instance`, which will have the
+ effect of performing a database check for the corresponding row
+ if the object is expired.
+
+ :param instance: object instance (must be given as a keyword arg)
+
+ * ``identity_key(class, row=row, identity_token=token)``
+
+ This form is similar to the class/tuple form, except is passed a
+ database result row as a :class:`.Row` object.
+
+ E.g.::
+
+ >>> row = engine.execute(\
+ text("select * from table where a=1 and b=2")\
+ ).first()
+ >>> identity_key(MyClass, row=row)
+ (<class '__main__.MyClass'>, (1, 2), None)
+
+ :param class: mapped class (must be a positional argument)
+ :param row: :class:`.Row` row returned by a :class:`_engine.CursorResult`
+ (must be given as a keyword arg)
+ :param identity_token: optional identity token
+
+ .. versionadded:: 1.2 added identity_token
+
+ """
+ if args:
+ row = None
+ largs = len(args)
+ if largs == 1:
+ class_ = args[0]
+ try:
+ row = kwargs.pop("row")
+ except KeyError:
+ ident = kwargs.pop("ident")
+ elif largs in (2, 3):
+ class_, ident = args
+ else:
+ raise sa_exc.ArgumentError(
+ "expected up to three positional arguments, " "got %s" % largs
+ )
+
+ identity_token = kwargs.pop("identity_token", None)
+ if kwargs:
+ raise sa_exc.ArgumentError(
+ "unknown keyword arguments: %s" % ", ".join(kwargs)
+ )
+ mapper = class_mapper(class_)
+ if row is None:
+ return mapper.identity_key_from_primary_key(
+ util.to_list(ident), identity_token=identity_token
+ )
+ else:
+ return mapper.identity_key_from_row(
+ row, identity_token=identity_token
+ )
+ else:
+ instance = kwargs.pop("instance")
+ if kwargs:
+ raise sa_exc.ArgumentError(
+                "unknown keyword arguments: %s" % ", ".join(kwargs)
+ )
+ mapper = object_mapper(instance)
+ return mapper.identity_key_from_instance(instance)
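+
+# A short sketch of the three call styles documented above (illustrative;
+# ``MyClass`` is an assumed mapped class with a two-column primary key):
+#
+#     identity_key(MyClass, (1, 2))            # class + primary key
+#     identity_key(instance=some_instance)     # from a loaded/constructed object
+#     identity_key(MyClass, row=result_row)    # class + Row from a result
+#     # each returns a tuple of (class, primary-key-tuple, identity_token)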
+
+
+class ORMAdapter(sql_util.ColumnAdapter):
+ """ColumnAdapter subclass which excludes adaptation of entities from
+ non-matching mappers.
+
+ """
+
+ def __init__(
+ self,
+ entity,
+ equivalents=None,
+ adapt_required=False,
+ allow_label_resolve=True,
+ anonymize_labels=False,
+ ):
+ info = inspection.inspect(entity)
+
+ self.mapper = info.mapper
+ selectable = info.selectable
+ is_aliased_class = info.is_aliased_class
+ if is_aliased_class:
+ self.aliased_class = entity
+ else:
+ self.aliased_class = None
+
+ sql_util.ColumnAdapter.__init__(
+ self,
+ selectable,
+ equivalents,
+ adapt_required=adapt_required,
+ allow_label_resolve=allow_label_resolve,
+ anonymize_labels=anonymize_labels,
+ include_fn=self._include_fn,
+ )
+
+ def _include_fn(self, elem):
+ entity = elem._annotations.get("parentmapper", None)
+
+ return not entity or entity.isa(self.mapper) or self.mapper.isa(entity)
+
+
+class AliasedClass(object):
+ r"""Represents an "aliased" form of a mapped class for usage with Query.
+
+ The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias`
+ construct, this object mimics the mapped class using a
+ ``__getattr__`` scheme and maintains a reference to a
+ real :class:`~sqlalchemy.sql.expression.Alias` object.
+
+ A primary purpose of :class:`.AliasedClass` is to serve as an alternate
+ within a SQL statement generated by the ORM, such that an existing
+ mapped entity can be used in multiple contexts. A simple example::
+
+ # find all pairs of users with the same name
+ user_alias = aliased(User)
+ session.query(User, user_alias).\
+ join((user_alias, User.id > user_alias.id)).\
+ filter(User.name == user_alias.name)
+
+ :class:`.AliasedClass` is also capable of mapping an existing mapped
+ class to an entirely new selectable, provided this selectable is column-
+ compatible with the existing mapped selectable, and it can also be
+ configured in a mapping as the target of a :func:`_orm.relationship`.
+ See the links below for examples.
+
+ The :class:`.AliasedClass` object is constructed typically using the
+ :func:`_orm.aliased` function. It also is produced with additional
+ configuration when using the :func:`_orm.with_polymorphic` function.
+
+ The resulting object is an instance of :class:`.AliasedClass`.
+ This object implements an attribute scheme which produces the
+ same attribute and method interface as the original mapped
+ class, allowing :class:`.AliasedClass` to be compatible
+ with any attribute technique which works on the original class,
+ including hybrid attributes (see :ref:`hybrids_toplevel`).
+
+ The :class:`.AliasedClass` can be inspected for its underlying
+ :class:`_orm.Mapper`, aliased selectable, and other information
+ using :func:`_sa.inspect`::
+
+ from sqlalchemy import inspect
+ my_alias = aliased(MyClass)
+ insp = inspect(my_alias)
+
+ The resulting inspection object is an instance of :class:`.AliasedInsp`.
+
+
+ .. seealso::
+
+ :func:`.aliased`
+
+ :func:`.with_polymorphic`
+
+ :ref:`relationship_aliased_class`
+
+ :ref:`relationship_to_window_function`
+
+
+ """
+
+ def __init__(
+ self,
+ mapped_class_or_ac,
+ alias=None,
+ name=None,
+ flat=False,
+ adapt_on_names=False,
+ # TODO: None for default here?
+ with_polymorphic_mappers=(),
+ with_polymorphic_discriminator=None,
+ base_alias=None,
+ use_mapper_path=False,
+ represents_outer_join=False,
+ ):
+ insp = inspection.inspect(mapped_class_or_ac)
+ mapper = insp.mapper
+
+ nest_adapters = False
+
+ if alias is None:
+ if insp.is_aliased_class and insp.selectable._is_subquery:
+ alias = insp.selectable.alias()
+ else:
+ alias = (
+ mapper._with_polymorphic_selectable._anonymous_fromclause(
+ name=name,
+ flat=flat,
+ )
+ )
+ elif insp.is_aliased_class:
+ nest_adapters = True
+
+ self._aliased_insp = AliasedInsp(
+ self,
+ insp,
+ alias,
+ name,
+ with_polymorphic_mappers
+ if with_polymorphic_mappers
+ else mapper.with_polymorphic_mappers,
+ with_polymorphic_discriminator
+ if with_polymorphic_discriminator is not None
+ else mapper.polymorphic_on,
+ base_alias,
+ use_mapper_path,
+ adapt_on_names,
+ represents_outer_join,
+ nest_adapters,
+ )
+
+ self.__name__ = "AliasedClass_%s" % mapper.class_.__name__
+
+ @classmethod
+ def _reconstitute_from_aliased_insp(cls, aliased_insp):
+ obj = cls.__new__(cls)
+ obj.__name__ = "AliasedClass_%s" % aliased_insp.mapper.class_.__name__
+ obj._aliased_insp = aliased_insp
+
+ if aliased_insp._is_with_polymorphic:
+ for sub_aliased_insp in aliased_insp._with_polymorphic_entities:
+ if sub_aliased_insp is not aliased_insp:
+ ent = AliasedClass._reconstitute_from_aliased_insp(
+ sub_aliased_insp
+ )
+ setattr(obj, sub_aliased_insp.class_.__name__, ent)
+
+ return obj
+
+ def __getattr__(self, key):
+ try:
+ _aliased_insp = self.__dict__["_aliased_insp"]
+ except KeyError:
+ raise AttributeError()
+ else:
+ target = _aliased_insp._target
+ # maintain all getattr mechanics
+ attr = getattr(target, key)
+
+ # attribute is a method, that will be invoked against a
+ # "self"; so just return a new method with the same function and
+ # new self
+ if hasattr(attr, "__call__") and hasattr(attr, "__self__"):
+ return types.MethodType(attr.__func__, self)
+
+ # attribute is a descriptor, that will be invoked against a
+ # "self"; so invoke the descriptor against this self
+ if hasattr(attr, "__get__"):
+ attr = attr.__get__(None, self)
+
+ # attributes within the QueryableAttribute system will want this
+ # to be invoked so the object can be adapted
+ if hasattr(attr, "adapt_to_entity"):
+ attr = attr.adapt_to_entity(_aliased_insp)
+ setattr(self, key, attr)
+
+ return attr
+
+ def _get_from_serialized(self, key, mapped_class, aliased_insp):
+ # this method is only used in terms of the
+ # sqlalchemy.ext.serializer extension
+ attr = getattr(mapped_class, key)
+ if hasattr(attr, "__call__") and hasattr(attr, "__self__"):
+ return types.MethodType(attr.__func__, self)
+
+ # attribute is a descriptor, that will be invoked against a
+ # "self"; so invoke the descriptor against this self
+ if hasattr(attr, "__get__"):
+ attr = attr.__get__(None, self)
+
+ # attributes within the QueryableAttribute system will want this
+ # to be invoked so the object can be adapted
+ if hasattr(attr, "adapt_to_entity"):
+ aliased_insp._weak_entity = weakref.ref(self)
+ attr = attr.adapt_to_entity(aliased_insp)
+ setattr(self, key, attr)
+
+ return attr
+
+ def __repr__(self):
+ return "<AliasedClass at 0x%x; %s>" % (
+ id(self),
+ self._aliased_insp._target.__name__,
+ )
+
+ def __str__(self):
+ return str(self._aliased_insp)
+
+
+class AliasedInsp(
+ ORMEntityColumnsClauseRole,
+ ORMFromClauseRole,
+ sql_base.MemoizedHasCacheKey,
+ InspectionAttr,
+):
+ """Provide an inspection interface for an
+ :class:`.AliasedClass` object.
+
+ The :class:`.AliasedInsp` object is returned
+ given an :class:`.AliasedClass` using the
+ :func:`_sa.inspect` function::
+
+ from sqlalchemy import inspect
+ from sqlalchemy.orm import aliased
+
+ my_alias = aliased(MyMappedClass)
+ insp = inspect(my_alias)
+
+ Attributes on :class:`.AliasedInsp`
+ include:
+
+ * ``entity`` - the :class:`.AliasedClass` represented.
+ * ``mapper`` - the :class:`_orm.Mapper` mapping the underlying class.
+ * ``selectable`` - the :class:`_expression.Alias`
+ construct which ultimately
+ represents an aliased :class:`_schema.Table` or
+ :class:`_expression.Select`
+ construct.
+ * ``name`` - the name of the alias. Also is used as the attribute
+ name when returned in a result tuple from :class:`_query.Query`.
+ * ``with_polymorphic_mappers`` - collection of :class:`_orm.Mapper`
+ objects
+ indicating all those mappers expressed in the select construct
+ for the :class:`.AliasedClass`.
+ * ``polymorphic_on`` - an alternate column or SQL expression which
+ will be used as the "discriminator" for a polymorphic load.
+
+ .. seealso::
+
+ :ref:`inspection_toplevel`
+
+ """
+
+ def __init__(
+ self,
+ entity,
+ inspected,
+ selectable,
+ name,
+ with_polymorphic_mappers,
+ polymorphic_on,
+ _base_alias,
+ _use_mapper_path,
+ adapt_on_names,
+ represents_outer_join,
+ nest_adapters,
+ ):
+
+ mapped_class_or_ac = inspected.entity
+ mapper = inspected.mapper
+
+ self._weak_entity = weakref.ref(entity)
+ self.mapper = mapper
+ self.selectable = (
+ self.persist_selectable
+ ) = self.local_table = selectable
+ self.name = name
+ self.polymorphic_on = polymorphic_on
+ self._base_alias = weakref.ref(_base_alias or self)
+ self._use_mapper_path = _use_mapper_path
+ self.represents_outer_join = represents_outer_join
+ self._nest_adapters = nest_adapters
+
+ if with_polymorphic_mappers:
+ self._is_with_polymorphic = True
+ self.with_polymorphic_mappers = with_polymorphic_mappers
+ self._with_polymorphic_entities = []
+ for poly in self.with_polymorphic_mappers:
+ if poly is not mapper:
+ ent = AliasedClass(
+ poly.class_,
+ selectable,
+ base_alias=self,
+ adapt_on_names=adapt_on_names,
+ use_mapper_path=_use_mapper_path,
+ )
+
+ setattr(self.entity, poly.class_.__name__, ent)
+ self._with_polymorphic_entities.append(ent._aliased_insp)
+
+ else:
+ self._is_with_polymorphic = False
+ self.with_polymorphic_mappers = [mapper]
+
+ self._adapter = sql_util.ColumnAdapter(
+ selectable,
+ equivalents=mapper._equivalent_columns,
+ adapt_on_names=adapt_on_names,
+ anonymize_labels=True,
+ # make sure the adapter doesn't try to grab other tables that
+ # are not even the thing we are mapping, such as embedded
+ # selectables in subqueries or CTEs. See issue #6060
+ adapt_from_selectables={
+ m.selectable
+ for m in self.with_polymorphic_mappers
+ if not adapt_on_names
+ },
+ )
+
+ if nest_adapters:
+ self._adapter = inspected._adapter.wrap(self._adapter)
+
+ self._adapt_on_names = adapt_on_names
+ self._target = mapped_class_or_ac
+ # self._target = mapper.class_ # mapped_class_or_ac
+
+ @property
+ def entity(self):
+ # to eliminate reference cycles, the AliasedClass is held weakly.
+ # this produces some situations where the AliasedClass gets lost,
+ # particularly when one is created internally and only the AliasedInsp
+ # is passed around.
+ # to work around this case, we just generate a new one when we need
+ # it, as it is a simple class with very little initial state on it.
+ ent = self._weak_entity()
+ if ent is None:
+ ent = AliasedClass._reconstitute_from_aliased_insp(self)
+ self._weak_entity = weakref.ref(ent)
+ return ent
+
+ is_aliased_class = True
+ "always returns True"
+
+ @util.memoized_instancemethod
+ def __clause_element__(self):
+ return self.selectable._annotate(
+ {
+ "parentmapper": self.mapper,
+ "parententity": self,
+ "entity_namespace": self,
+ }
+ )._set_propagate_attrs(
+ {"compile_state_plugin": "orm", "plugin_subject": self}
+ )
+
+ @property
+ def entity_namespace(self):
+ return self.entity
+
+ _cache_key_traversal = [
+ ("name", visitors.ExtendedInternalTraversal.dp_string),
+ ("_adapt_on_names", visitors.ExtendedInternalTraversal.dp_boolean),
+ ("selectable", visitors.ExtendedInternalTraversal.dp_clauseelement),
+ ]
+
+ @property
+ def class_(self):
+ """Return the mapped class ultimately represented by this
+ :class:`.AliasedInsp`."""
+ return self.mapper.class_
+
+ @property
+ def _path_registry(self):
+ if self._use_mapper_path:
+ return self.mapper._path_registry
+ else:
+ return PathRegistry.per_mapper(self)
+
+ def __getstate__(self):
+ return {
+ "entity": self.entity,
+ "mapper": self.mapper,
+ "alias": self.selectable,
+ "name": self.name,
+ "adapt_on_names": self._adapt_on_names,
+ "with_polymorphic_mappers": self.with_polymorphic_mappers,
+ "with_polymorphic_discriminator": self.polymorphic_on,
+ "base_alias": self._base_alias(),
+ "use_mapper_path": self._use_mapper_path,
+ "represents_outer_join": self.represents_outer_join,
+ "nest_adapters": self._nest_adapters,
+ }
+
+ def __setstate__(self, state):
+ self.__init__(
+ state["entity"],
+ state["mapper"],
+ state["alias"],
+ state["name"],
+ state["with_polymorphic_mappers"],
+ state["with_polymorphic_discriminator"],
+ state["base_alias"],
+ state["use_mapper_path"],
+ state["adapt_on_names"],
+ state["represents_outer_join"],
+ state["nest_adapters"],
+ )
+
+ def _adapt_element(self, elem, key=None):
+ d = {
+ "parententity": self,
+ "parentmapper": self.mapper,
+ }
+ if key:
+ d["proxy_key"] = key
+ return (
+ self._adapter.traverse(elem)
+ ._annotate(d)
+ ._set_propagate_attrs(
+ {"compile_state_plugin": "orm", "plugin_subject": self}
+ )
+ )
+
+ def _entity_for_mapper(self, mapper):
+ self_poly = self.with_polymorphic_mappers
+ if mapper in self_poly:
+ if mapper is self.mapper:
+ return self
+ else:
+ return getattr(
+ self.entity, mapper.class_.__name__
+ )._aliased_insp
+ elif mapper.isa(self.mapper):
+ return self
+ else:
+ assert False, "mapper %s doesn't correspond to %s" % (mapper, self)
+
+ @util.memoized_property
+ def _get_clause(self):
+ onclause, replacemap = self.mapper._get_clause
+ return (
+ self._adapter.traverse(onclause),
+ {
+ self._adapter.traverse(col): param
+ for col, param in replacemap.items()
+ },
+ )
+
+ @util.memoized_property
+ def _memoized_values(self):
+ return {}
+
+ @util.memoized_property
+ def _all_column_expressions(self):
+ if self._is_with_polymorphic:
+ cols_plus_keys = self.mapper._columns_plus_keys(
+ [ent.mapper for ent in self._with_polymorphic_entities]
+ )
+ else:
+ cols_plus_keys = self.mapper._columns_plus_keys()
+
+ cols_plus_keys = [
+ (key, self._adapt_element(col)) for key, col in cols_plus_keys
+ ]
+
+ return ColumnCollection(cols_plus_keys)
+
+ def _memo(self, key, callable_, *args, **kw):
+ if key in self._memoized_values:
+ return self._memoized_values[key]
+ else:
+ self._memoized_values[key] = value = callable_(*args, **kw)
+ return value
+
+ def __repr__(self):
+ if self.with_polymorphic_mappers:
+ with_poly = "(%s)" % ", ".join(
+ mp.class_.__name__ for mp in self.with_polymorphic_mappers
+ )
+ else:
+ with_poly = ""
+ return "<AliasedInsp at 0x%x; %s%s>" % (
+ id(self),
+ self.class_.__name__,
+ with_poly,
+ )
+
+ def __str__(self):
+ if self._is_with_polymorphic:
+ return "with_polymorphic(%s, [%s])" % (
+ self._target.__name__,
+ ", ".join(
+ mp.class_.__name__
+ for mp in self.with_polymorphic_mappers
+ if mp is not self.mapper
+ ),
+ )
+ else:
+ return "aliased(%s)" % (self._target.__name__,)
+
+
+class _WrapUserEntity(object):
+ """A wrapper used within the loader_criteria lambda caller so that
+ we can bypass declared_attr descriptors on unmapped mixins, which
+ normally emit a warning for such use.
+
+ might also be useful for other per-lambda instrumentations should
+ the need arise.
+
+ """
+
+ __slots__ = ("subject",)
+
+ def __init__(self, subject):
+ self.subject = subject
+
+ @util.preload_module("sqlalchemy.orm.decl_api")
+ def __getattribute__(self, name):
+ decl_api = util.preloaded.orm.decl_api
+
+ subject = object.__getattribute__(self, "subject")
+ if name in subject.__dict__ and isinstance(
+ subject.__dict__[name], decl_api.declared_attr
+ ):
+ return subject.__dict__[name].fget(subject)
+ else:
+ return getattr(subject, name)
+
+
+class LoaderCriteriaOption(CriteriaOption):
+ """Add additional WHERE criteria to the load for all occurrences of
+ a particular entity.
+
+ :class:`_orm.LoaderCriteriaOption` is invoked using the
+ :func:`_orm.with_loader_criteria` function; see that function for
+ details.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _traverse_internals = [
+ ("root_entity", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ ("entity", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+ ("where_criteria", visitors.InternalTraversal.dp_clauseelement),
+ ("include_aliases", visitors.InternalTraversal.dp_boolean),
+ ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean),
+ ]
+
+ def __init__(
+ self,
+ entity_or_base,
+ where_criteria,
+ loader_only=False,
+ include_aliases=False,
+ propagate_to_loaders=True,
+ track_closure_variables=True,
+ ):
+ """Add additional WHERE criteria to the load for all occurrences of
+ a particular entity.
+
+ .. versionadded:: 1.4
+
+ The :func:`_orm.with_loader_criteria` option is intended to add
+ limiting criteria to a particular kind of entity in a query,
+ **globally**, meaning it will apply to the entity as it appears
+ in the SELECT query as well as within any subqueries, join
+ conditions, and relationship loads, including both eager and lazy
+ loaders, without the need for it to be specified in any particular
+ part of the query. The rendering logic uses the same system used by
+ single table inheritance to ensure a certain discriminator is applied
+ to a table.
+
+ E.g., using :term:`2.0-style` queries, we can limit the way the
+ ``User.addresses`` collection is loaded, regardless of the kind
+ of loading used::
+
+ from sqlalchemy.orm import with_loader_criteria
+
+ stmt = select(User).options(
+ selectinload(User.addresses),
+                with_loader_criteria(Address, Address.email_address != 'foo'),
+            )
+
+ Above, the "selectinload" for ``User.addresses`` will apply the
+ given filtering criteria to the WHERE clause.
+
+ Another example, where the filtering will be applied to the
+ ON clause of the join, in this example using :term:`1.x style`
+ queries::
+
+ q = session.query(User).outerjoin(User.addresses).options(
+                with_loader_criteria(Address, Address.email_address != 'foo')
+            )
+
+ The primary purpose of :func:`_orm.with_loader_criteria` is to use
+ it in the :meth:`_orm.SessionEvents.do_orm_execute` event handler
+ to ensure that all occurrences of a particular entity are filtered
+ in a certain way, such as filtering for access control roles. It
+ also can be used to apply criteria to relationship loads. In the
+ example below, we can apply a certain set of rules to all queries
+ emitted by a particular :class:`_orm.Session`::
+
+ session = Session(bind=engine)
+
+            @event.listens_for(session, "do_orm_execute")
+ def _add_filtering_criteria(execute_state):
+
+ if (
+ execute_state.is_select
+ and not execute_state.is_column_load
+ and not execute_state.is_relationship_load
+ ):
+ execute_state.statement = execute_state.statement.options(
+ with_loader_criteria(
+ SecurityRole,
+ lambda cls: cls.role.in_(['some_role']),
+ include_aliases=True
+ )
+ )
+
+ In the above example, the :meth:`_orm.SessionEvents.do_orm_execute`
+ event will intercept all queries emitted using the
+ :class:`_orm.Session`. For those queries which are SELECT statements
+        and are not attribute or relationship loads, a custom
+ :func:`_orm.with_loader_criteria` option is added to the query. The
+ :func:`_orm.with_loader_criteria` option will be used in the given
+ statement and will also be automatically propagated to all relationship
+ loads that descend from this query.
+
+ The criteria argument given is a ``lambda`` that accepts a ``cls``
+        argument. The given class will expand to include all mapped subclasses
+        and need not itself be a mapped class.
+
+ .. tip::
+
+ When using the :func:`_orm.with_loader_criteria` option in
+ conjunction with the :func:`_orm.contains_eager` loader option,
+ it's important to note that :func:`_orm.with_loader_criteria` only
+ affects the part of the query that determines what SQL is rendered
+ in terms of the WHERE and FROM clauses. The
+ :func:`_orm.contains_eager` option does not affect the rendering of
+ the SELECT statement outside of the columns clause, so does not have
+ any interaction with the :func:`_orm.with_loader_criteria` option.
+ However, the way things "work" is that :func:`_orm.contains_eager`
+ is meant to be used with a query that is already selecting from the
+ additional entities in some way, where
+ :func:`_orm.with_loader_criteria` can apply its additional
+ criteria.
+
+ In the example below, assuming a mapping relationship as
+ ``A -> A.bs -> B``, the given :func:`_orm.with_loader_criteria`
+ option will affect the way in which the JOIN is rendered::
+
+ stmt = select(A).join(A.bs).options(
+ contains_eager(A.bs),
+ with_loader_criteria(B, B.flag == 1)
+ )
+
+ Above, the given :func:`_orm.with_loader_criteria` option will
+ affect the ON clause of the JOIN that is specified by
+ ``.join(A.bs)``, so is applied as expected. The
+ :func:`_orm.contains_eager` option has the effect that columns from
+ ``B`` are added to the columns clause::
+
+ SELECT
+ b.id, b.a_id, b.data, b.flag,
+ a.id AS id_1,
+ a.data AS data_1
+ FROM a JOIN b ON a.id = b.a_id AND b.flag = :flag_1
+
+
+ The use of the :func:`_orm.contains_eager` option within the above
+ statement has no effect on the behavior of the
+ :func:`_orm.with_loader_criteria` option. If the
+ :func:`_orm.contains_eager` option were omitted, the SQL would be
+ the same as regards the FROM and WHERE clauses, where
+ :func:`_orm.with_loader_criteria` continues to add its criteria to
+ the ON clause of the JOIN. The addition of
+ :func:`_orm.contains_eager` only affects the columns clause, in that
+ additional columns against ``b`` are added which are then consumed
+ by the ORM to produce ``B`` instances.
+
+ .. warning:: A lambda used inside of the call to
+ :func:`_orm.with_loader_criteria` is invoked only **once per unique
+ class**. Custom functions should not be invoked within this lambda.
+ See :ref:`engine_lambda_caching` for an overview of the "lambda SQL"
+ feature, which is for advanced use only.
+
+ :param entity_or_base: a mapped class, or a class that is a super
+ class of a particular set of mapped classes, to which the rule
+ will apply.
+
+ :param where_criteria: a Core SQL expression that applies limiting
+ criteria. This may also be a "lambda:" or Python function that
+ accepts a target class as an argument, when the given class is
+ a base with many different mapped subclasses.
+
+ .. note:: To support pickling, use a module-level Python function to
+ produce the SQL expression instead of a lambda or a fixed SQL
+ expression, which tend not to be picklable.
+
+ :param include_aliases: if True, apply the rule to :func:`_orm.aliased`
+ constructs as well.
+
+ :param propagate_to_loaders: defaults to True, apply to relationship
+ loaders such as lazy loaders. This indicates that the
+ option object itself, including its SQL expression, is carried along
+ with each loaded instance. Set to ``False`` to prevent the object from
+ being assigned to individual instances.
+
+ .. seealso::
+
+ :ref:`examples_session_orm_events` - includes examples of using
+ :func:`_orm.with_loader_criteria`.
+
+ :ref:`do_orm_execute_global_criteria` - basic example on how to
+ combine :func:`_orm.with_loader_criteria` with the
+ :meth:`_orm.SessionEvents.do_orm_execute` event.
+
+ :param track_closure_variables: when False, closure variables inside
+ of a lambda expression will not be used as part of
+ any cache key. This allows more complex expressions to be used
+ inside of a lambda expression, but requires that the lambda
+ returns identical SQL every time for a given class.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ entity = inspection.inspect(entity_or_base, False)
+ if entity is None:
+ self.root_entity = entity_or_base
+ self.entity = None
+ else:
+ self.root_entity = None
+ self.entity = entity
+
+ self._where_crit_orig = where_criteria
+ if callable(where_criteria):
+ self.deferred_where_criteria = True
+ self.where_criteria = lambdas.DeferredLambdaElement(
+ where_criteria,
+ roles.WhereHavingRole,
+ lambda_args=(
+ _WrapUserEntity(
+ self.root_entity
+ if self.root_entity is not None
+ else self.entity.entity,
+ ),
+ ),
+ opts=lambdas.LambdaOptions(
+ track_closure_variables=track_closure_variables
+ ),
+ )
+ else:
+ self.deferred_where_criteria = False
+ self.where_criteria = coercions.expect(
+ roles.WhereHavingRole, where_criteria
+ )
+
+ self.include_aliases = include_aliases
+ self.propagate_to_loaders = propagate_to_loaders
+
+ @classmethod
+ def _unreduce(
+ cls, entity, where_criteria, include_aliases, propagate_to_loaders
+ ):
+ return LoaderCriteriaOption(
+ entity,
+ where_criteria,
+ include_aliases=include_aliases,
+ propagate_to_loaders=propagate_to_loaders,
+ )
+
+ def __reduce__(self):
+ return (
+ LoaderCriteriaOption._unreduce,
+ (
+ self.entity.class_ if self.entity else self.root_entity,
+ self._where_crit_orig,
+ self.include_aliases,
+ self.propagate_to_loaders,
+ ),
+ )
+
+ def _all_mappers(self):
+
+ if self.entity:
+ for ent in self.entity.mapper.self_and_descendants:
+ yield ent
+ else:
+ stack = list(self.root_entity.__subclasses__())
+ while stack:
+ subclass = stack.pop(0)
+ ent = inspection.inspect(subclass, raiseerr=False)
+ if ent:
+ for mp in ent.mapper.self_and_descendants:
+ yield mp
+ else:
+ stack.extend(subclass.__subclasses__())
+
+ def _should_include(self, compile_state):
+ if (
+ compile_state.select_statement._annotations.get(
+ "for_loader_criteria", None
+ )
+ is self
+ ):
+ return False
+ return True
+
+ def _resolve_where_criteria(self, ext_info):
+ if self.deferred_where_criteria:
+ crit = self.where_criteria._resolve_with_args(ext_info.entity)
+ else:
+ crit = self.where_criteria
+ return sql_util._deep_annotate(
+ crit, {"for_loader_criteria": self}, detect_subquery_cols=True
+ )
+
+ def process_compile_state_replaced_entities(
+ self, compile_state, mapper_entities
+ ):
+ return self.process_compile_state(compile_state)
+
+ def process_compile_state(self, compile_state):
+ """Apply a modification to a given :class:`.CompileState`."""
+
+ # if options to limit the criteria to immediate query only,
+ # use compile_state.attributes instead
+
+ if compile_state.compile_options._with_polymorphic_adapt_map:
+ util.warn(
+ "The with_loader_criteria() function may not work "
+ "correctly with the legacy Query.with_polymorphic() feature. "
+ "Please migrate code to use the with_polymorphic() standalone "
+ "function before using with_loader_criteria()."
+ )
+ self.get_global_criteria(compile_state.global_attributes)
+
+ def get_global_criteria(self, attributes):
+ for mp in self._all_mappers():
+ load_criteria = attributes.setdefault(
+ ("additional_entity_criteria", mp), []
+ )
+
+ load_criteria.append(self)
+
+
+inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
+inspection._inspects(AliasedInsp)(lambda target: target)
+
+
+def aliased(element, alias=None, name=None, flat=False, adapt_on_names=False):
+ """Produce an alias of the given element, usually an :class:`.AliasedClass`
+ instance.
+
+ E.g.::
+
+ my_alias = aliased(MyClass)
+
+ session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id)
+
+ The :func:`.aliased` function is used to create an ad-hoc mapping of a
+ mapped class to a new selectable. By default, a selectable is generated
+ from the normally mapped selectable (typically a :class:`_schema.Table`
+ ) using the
+ :meth:`_expression.FromClause.alias` method. However, :func:`.aliased`
+ can also be
+ used to link the class to a new :func:`_expression.select` statement.
+ Also, the :func:`.with_polymorphic` function is a variant of
+ :func:`.aliased` that is intended to specify a so-called "polymorphic
+ selectable", that corresponds to the union of several joined-inheritance
+ subclasses at once.
+
+ For convenience, the :func:`.aliased` function also accepts plain
+ :class:`_expression.FromClause` constructs, such as a
+ :class:`_schema.Table` or
+ :func:`_expression.select` construct. In those cases, the
+ :meth:`_expression.FromClause.alias`
+ method is called on the object and the new
+ :class:`_expression.Alias` object returned. The returned
+ :class:`_expression.Alias` is not
+ ORM-mapped in this case.
+
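+ For example, a minimal sketch assuming ``user_table`` is an existing plain
+ :class:`_schema.Table` (an illustrative name only)::
+
+     from sqlalchemy.orm import aliased
+
+     # equivalent to user_table.alias(name="user_2"); the result is a
+     # Core Alias object, not an ORM-mapped entity
+     user_alias = aliased(user_table, name="user_2")
+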
+ .. seealso::
+
+ :ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial`
+
+ :ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel`
+
+ :param element: element to be aliased. Is normally a mapped class,
+ but for convenience can also be a :class:`_expression.FromClause`
+ element.
+
+ :param alias: Optional selectable unit to map the element to. This is
+ usually used to link the object to a subquery, and should be an aliased
+ select construct as one would produce from the
+ :meth:`_query.Query.subquery` method or
+ the :meth:`_expression.Select.subquery` or
+ :meth:`_expression.Select.alias` methods of the :func:`_expression.select`
+ construct.
+
+ :param name: optional string name to use for the alias, if not specified
+ by the ``alias`` parameter. The name, among other things, forms the
+ attribute name that will be accessible via tuples returned by a
+ :class:`_query.Query` object. Not supported when creating aliases
+ of :class:`_sql.Join` objects.
+
+ :param flat: Boolean, will be passed through to the
+ :meth:`_expression.FromClause.alias` call so that aliases of
+ :class:`_expression.Join` objects will alias the individual tables
+ inside the join, rather than creating a subquery. This is generally
+ supported by all modern databases with regards to right-nested joins
+ and generally produces more efficient queries.
+
+ :param adapt_on_names: if True, more liberal "matching" will be used when
+ mapping the mapped columns of the ORM entity to those of the
+ given selectable - a name-based match will be performed if the
+ given selectable doesn't otherwise have a column that corresponds
+ to one on the entity. The use case for this is when associating
+ an entity with some derived selectable such as one that uses
+ aggregate functions::
+
+ class UnitPrice(Base):
+ __tablename__ = 'unit_price'
+ ...
+ unit_id = Column(Integer)
+ price = Column(Numeric)
+
+ aggregated_unit_price = Session.query(
+ func.sum(UnitPrice.price).label('price')
+ ).group_by(UnitPrice.unit_id).subquery()
+
+ aggregated_unit_price = aliased(UnitPrice,
+ alias=aggregated_unit_price, adapt_on_names=True)
+
+ Above, functions on ``aggregated_unit_price`` which refer to
+ ``.price`` will return the
+ ``func.sum(UnitPrice.price).label('price')`` column, as it is
+ matched on the name "price". Ordinarily, the "price" function
+ wouldn't have any "column correspondence" to the actual
+ ``UnitPrice.price`` column as it is not a proxy of the original.
+
+ """
+ if isinstance(element, expression.FromClause):
+ if adapt_on_names:
+ raise sa_exc.ArgumentError(
+ "adapt_on_names only applies to ORM elements"
+ )
+ if name:
+ return element.alias(name=name, flat=flat)
+ else:
+ return coercions.expect(
+ roles.AnonymizedFromClauseRole, element, flat=flat
+ )
+ else:
+ return AliasedClass(
+ element,
+ alias=alias,
+ flat=flat,
+ name=name,
+ adapt_on_names=adapt_on_names,
+ )
+
+
+def with_polymorphic(
+ base,
+ classes,
+ selectable=False,
+ flat=False,
+ polymorphic_on=None,
+ aliased=False,
+ adapt_on_names=False,
+ innerjoin=False,
+ _use_mapper_path=False,
+ _existing_alias=None,
+):
+ """Produce an :class:`.AliasedClass` construct which specifies
+ columns for descendant mappers of the given base.
+
+ Using this method will ensure that each descendant mapper's
+ tables are included in the FROM clause, and will allow filter()
+ criterion to be used against those tables. The resulting
+ instances will also have those columns already loaded so that
+ no "post fetch" of those columns will be required.
+
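+ For example, a minimal sketch assuming a hypothetical joined-inheritance
+ hierarchy with an ``Employee`` base class and ``Engineer`` / ``Manager``
+ subclasses (illustrative names only)::
+
+     from sqlalchemy import select
+     from sqlalchemy.orm import with_polymorphic
+
+     # include the engineer and manager tables in the FROM clause
+     emp_poly = with_polymorphic(Employee, [Engineer, Manager])
+
+     # subclass attributes are reachable via the sub-entity namespace
+     stmt = select(emp_poly).where(
+         emp_poly.Engineer.engineer_info == "some info"
+     )
+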
+ .. seealso::
+
+ :ref:`with_polymorphic` - full discussion of
+ :func:`_orm.with_polymorphic`.
+
+ :param base: Base class to be aliased.
+
+ :param classes: a single class or mapper, or list of
+ class/mappers, which inherit from the base class.
+ Alternatively, it may also be the string ``'*'``, in which case
+ all descending mapped classes will be added to the FROM clause.
+
+ :param aliased: when True, the selectable will be aliased. For a
+ JOIN, this means the JOIN will be SELECTed from inside of a subquery
+ unless the :paramref:`_orm.with_polymorphic.flat` flag is set to
+ True, which is recommended for simpler use cases.
+
+ :param flat: Boolean, will be passed through to the
+ :meth:`_expression.FromClause.alias` call so that aliases of
+ :class:`_expression.Join` objects will alias the individual tables
+ inside the join, rather than creating a subquery. This is generally
+ supported by all modern databases with regards to right-nested joins
+ and generally produces more efficient queries. Setting this flag is
+ recommended as long as the resulting SQL is functional.
+
+ :param selectable: a table or subquery that will
+ be used in place of the generated FROM clause. This argument is
+ required if any of the desired classes use concrete table
+ inheritance, since SQLAlchemy currently cannot generate UNIONs
+ among tables automatically. If used, the ``selectable`` argument
+ must represent the full set of tables and columns mapped by every
+ mapped class. Otherwise, the unaccounted mapped columns will
+ result in their table being appended directly to the FROM clause
+ which will usually lead to incorrect results.
+
+ When left at its default value of ``False``, the polymorphic
+ selectable assigned to the base mapper is used for selecting rows.
+ However, it may also be passed as ``None``, which will bypass the
+ configured polymorphic selectable and instead construct an ad-hoc
+ selectable for the target classes given; for joined table inheritance
+ this will be a join that includes all target mappers and their
+ subclasses.
+
+ :param polymorphic_on: a column to be used as the "discriminator"
+ column for the given selectable. If not given, the polymorphic_on
+ attribute of the base class's mapper will be used, if any. This
+ is useful for mappings that don't have polymorphic loading
+ behavior by default.
+
+ :param innerjoin: if True, an INNER JOIN will be used. This should
+ only be specified when querying for one specific subtype.
+
+ :param adapt_on_names: Passes through the
+ :paramref:`_orm.aliased.adapt_on_names`
+ parameter to the aliased object. This may be useful in situations where
+ the given selectable is not directly related to the existing mapped
+ selectable.
+
+ .. versionadded:: 1.4.33
+
+ """
+ primary_mapper = _class_to_mapper(base)
+
+ if selectable not in (None, False) and flat:
+ raise sa_exc.ArgumentError(
+ "the 'flat' and 'selectable' arguments cannot be passed "
+ "simultaneously to with_polymorphic()"
+ )
+
+ if _existing_alias:
+ assert _existing_alias.mapper is primary_mapper
+ classes = util.to_set(classes)
+ new_classes = set(
+ [mp.class_ for mp in _existing_alias.with_polymorphic_mappers]
+ )
+ if classes == new_classes:
+ return _existing_alias
+ else:
+ classes = classes.union(new_classes)
+ mappers, selectable = primary_mapper._with_polymorphic_args(
+ classes, selectable, innerjoin=innerjoin
+ )
+ if aliased or flat:
+ selectable = selectable._anonymous_fromclause(flat=flat)
+ return AliasedClass(
+ base,
+ selectable,
+ adapt_on_names=adapt_on_names,
+ with_polymorphic_mappers=mappers,
+ with_polymorphic_discriminator=polymorphic_on,
+ use_mapper_path=_use_mapper_path,
+ represents_outer_join=not innerjoin,
+ )
+
+
+@inspection._self_inspects
+class Bundle(
+ ORMColumnsClauseRole,
+ SupportsCloneAnnotations,
+ sql_base.MemoizedHasCacheKey,
+ InspectionAttr,
+):
+ """A grouping of SQL expressions that are returned by a :class:`.Query`
+ under one namespace.
+
+ The :class:`.Bundle` essentially allows nesting of the tuple-based
+ results returned by a column-oriented :class:`_query.Query` object.
+ It also
+ is extensible via simple subclassing, where the primary capability
+ to override is that of how the set of expressions should be returned,
+ allowing post-processing as well as custom return types, without
+ involving ORM identity-mapped classes.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :ref:`bundles`
+
+
+ """
+
+ single_entity = False
+ """If True, queries for a single Bundle will be returned as a single
+ entity, rather than an element within a keyed tuple."""
+
+ is_clause_element = False
+
+ is_mapper = False
+
+ is_aliased_class = False
+
+ is_bundle = True
+
+ _propagate_attrs = util.immutabledict()
+
+ def __init__(self, name, *exprs, **kw):
+ r"""Construct a new :class:`.Bundle`.
+
+ e.g.::
+
+ bn = Bundle("mybundle", MyClass.x, MyClass.y)
+
+ for row in session.query(bn).filter(
+ bn.c.x == 5).filter(bn.c.y == 4):
+ print(row.mybundle.x, row.mybundle.y)
+
+ :param name: name of the bundle.
+ :param \*exprs: columns or SQL expressions comprising the bundle.
+ :param single_entity=False: if True, rows for this :class:`.Bundle`
+ can be returned as a "single entity" outside of any enclosing tuple
+ in the same manner as a mapped entity.
+
+ """
+ self.name = self._label = name
+ self.exprs = exprs = [
+ coercions.expect(
+ roles.ColumnsClauseRole, expr, apply_propagate_attrs=self
+ )
+ for expr in exprs
+ ]
+
+ self.c = self.columns = ColumnCollection(
+ (getattr(col, "key", col._label), col)
+ for col in [e._annotations.get("bundle", e) for e in exprs]
+ )
+ self.single_entity = kw.pop("single_entity", self.single_entity)
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ return (self.__class__, self.name, self.single_entity) + tuple(
+ [expr._gen_cache_key(anon_map, bindparams) for expr in self.exprs]
+ )
+
+ @property
+ def mapper(self):
+ return self.exprs[0]._annotations.get("parentmapper", None)
+
+ @property
+ def entity(self):
+ return self.exprs[0]._annotations.get("parententity", None)
+
+ @property
+ def entity_namespace(self):
+ return self.c
+
+ columns = None
+ """A namespace of SQL expressions referred to by this :class:`.Bundle`.
+
+ e.g.::
+
+ bn = Bundle("mybundle", MyClass.x, MyClass.y)
+
+ q = sess.query(bn).filter(bn.c.x == 5)
+
+ Nesting of bundles is also supported::
+
+ b1 = Bundle("b1",
+ Bundle('b2', MyClass.a, MyClass.b),
+ Bundle('b3', MyClass.x, MyClass.y)
+ )
+
+ q = sess.query(b1).filter(
+ b1.c.b2.c.a == 5).filter(b1.c.b3.c.y == 9)
+
+ .. seealso::
+
+ :attr:`.Bundle.c`
+
+ """
+
+ c = None
+ """An alias for :attr:`.Bundle.columns`."""
+
+ def _clone(self):
+ cloned = self.__class__.__new__(self.__class__)
+ cloned.__dict__.update(self.__dict__)
+ return cloned
+
+ def __clause_element__(self):
+ # ensure existing entity_namespace remains
+ annotations = {"bundle": self, "entity_namespace": self}
+ annotations.update(self._annotations)
+
+ plugin_subject = self.exprs[0]._propagate_attrs.get(
+ "plugin_subject", self.entity
+ )
+ return (
+ expression.ClauseList(
+ _literal_as_text_role=roles.ColumnsClauseRole,
+ group=False,
+ *[e._annotations.get("bundle", e) for e in self.exprs]
+ )
+ ._annotate(annotations)
+ ._set_propagate_attrs(
+ # the Bundle *must* use the orm plugin no matter what. the
+ # subject can be None but it's much better if it's not.
+ {
+ "compile_state_plugin": "orm",
+ "plugin_subject": plugin_subject,
+ }
+ )
+ )
+
+ @property
+ def clauses(self):
+ return self.__clause_element__().clauses
+
+ def label(self, name):
+ """Provide a copy of this :class:`.Bundle` passing a new label."""
+
+ cloned = self._clone()
+ cloned.name = name
+ return cloned
+
+ def create_row_processor(self, query, procs, labels):
+ """Produce the "row processing" function for this :class:`.Bundle`.
+
+ May be overridden by subclasses.
+
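+ For example, a sketch of a subclass that returns each row as a plain
+ dictionary rather than a named tuple (an illustrative recipe; the
+ ``DictBundle`` name is hypothetical)::
+
+     class DictBundle(Bundle):
+         def create_row_processor(self, query, procs, labels):
+             def proc(row):
+                 # pair each label with its processed column value
+                 return dict(zip(labels, (p(row) for p in procs)))
+             return proc
+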
+ .. seealso::
+
+ :ref:`bundles` - includes an example of subclassing.
+
+ """
+ keyed_tuple = result_tuple(labels, [() for l in labels])
+
+ def proc(row):
+ return keyed_tuple([proc(row) for proc in procs])
+
+ return proc
+
+
+def _orm_annotate(element, exclude=None):
+ """Deep copy the given ClauseElement, annotating each element with the
+ "_orm_adapt" flag.
+
+ Elements within the exclude collection will be cloned but not annotated.
+
+ """
+ return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude)
+
+
+def _orm_deannotate(element):
+ """Remove annotations that link a column to a particular mapping.
+
+ Note this doesn't affect "remote" and "foreign" annotations
+ passed by the :func:`_orm.foreign` and :func:`_orm.remote`
+ annotators.
+
+ """
+
+ return sql_util._deep_deannotate(
+ element, values=("_orm_adapt", "parententity")
+ )
+
+
+def _orm_full_deannotate(element):
+ return sql_util._deep_deannotate(element)
+
+
+class _ORMJoin(expression.Join):
+ """Extend Join to support ORM constructs as input."""
+
+ __visit_name__ = expression.Join.__visit_name__
+
+ inherit_cache = True
+
+ def __init__(
+ self,
+ left,
+ right,
+ onclause=None,
+ isouter=False,
+ full=False,
+ _left_memo=None,
+ _right_memo=None,
+ _extra_criteria=(),
+ ):
+ left_info = inspection.inspect(left)
+
+ right_info = inspection.inspect(right)
+ adapt_to = right_info.selectable
+
+ # used by joined eager loader
+ self._left_memo = _left_memo
+ self._right_memo = _right_memo
+
+ # legacy, for string attr name ON clause. if that's removed
+ # then the "_joined_from_info" concept can go
+ left_orm_info = getattr(left, "_joined_from_info", left_info)
+ self._joined_from_info = right_info
+ if isinstance(onclause, util.string_types):
+ onclause = getattr(left_orm_info.entity, onclause)
+ # ####
+
+ if isinstance(onclause, attributes.QueryableAttribute):
+ on_selectable = onclause.comparator._source_selectable()
+ prop = onclause.property
+ _extra_criteria += onclause._extra_criteria
+ elif isinstance(onclause, MapperProperty):
+ # used internally by joined eager loader...possibly not ideal
+ prop = onclause
+ on_selectable = prop.parent.selectable
+ else:
+ prop = None
+
+ if prop:
+ left_selectable = left_info.selectable
+
+ if sql_util.clause_is_present(on_selectable, left_selectable):
+ adapt_from = on_selectable
+ else:
+ adapt_from = left_selectable
+
+ (
+ pj,
+ sj,
+ source,
+ dest,
+ secondary,
+ target_adapter,
+ ) = prop._create_joins(
+ source_selectable=adapt_from,
+ dest_selectable=adapt_to,
+ source_polymorphic=True,
+ of_type_entity=right_info,
+ alias_secondary=True,
+ extra_criteria=_extra_criteria,
+ )
+
+ if sj is not None:
+ if isouter:
+ # note this is an inner join from secondary->right
+ right = sql.join(secondary, right, sj)
+ onclause = pj
+ else:
+ left = sql.join(left, secondary, pj, isouter)
+ onclause = sj
+ else:
+ onclause = pj
+
+ self._target_adapter = target_adapter
+
+ augment_onclause = onclause is None and _extra_criteria
+ expression.Join.__init__(self, left, right, onclause, isouter, full)
+
+ if augment_onclause:
+ self.onclause &= sql.and_(*_extra_criteria)
+
+ if (
+ not prop
+ and getattr(right_info, "mapper", None)
+ and right_info.mapper.single
+ ):
+ # if single inheritance target and we are using a manual
+ # or implicit ON clause, augment it the same way we'd augment the
+ # WHERE.
+ single_crit = right_info.mapper._single_table_criterion
+ if single_crit is not None:
+ if right_info.is_aliased_class:
+ single_crit = right_info._adapter.traverse(single_crit)
+ self.onclause = self.onclause & single_crit
+
+ def _splice_into_center(self, other):
+ """Splice a join into the center.
+
+ Given join(a, b) and join(b, c), return join(a, b).join(c)
+
+ """
+ leftmost = other
+ while isinstance(leftmost, sql.Join):
+ leftmost = leftmost.left
+
+ assert self.right is leftmost
+
+ left = _ORMJoin(
+ self.left,
+ other.left,
+ self.onclause,
+ isouter=self.isouter,
+ _left_memo=self._left_memo,
+ _right_memo=other._left_memo,
+ )
+
+ return _ORMJoin(
+ left,
+ other.right,
+ other.onclause,
+ isouter=other.isouter,
+ _right_memo=other._right_memo,
+ )
+
+ def join(
+ self,
+ right,
+ onclause=None,
+ isouter=False,
+ full=False,
+ join_to_left=None,
+ ):
+ return _ORMJoin(self, right, onclause, full=full, isouter=isouter)
+
+ def outerjoin(self, right, onclause=None, full=False, join_to_left=None):
+ return _ORMJoin(self, right, onclause, isouter=True, full=full)
+
+
+def join(
+ left, right, onclause=None, isouter=False, full=False, join_to_left=None
+):
+ r"""Produce an inner join between left and right clauses.
+
+ :func:`_orm.join` is an extension to the core join interface
+ provided by :func:`_expression.join()`, where the
+ left and right selectables may be not only core selectable
+ objects such as :class:`_schema.Table`, but also mapped classes or
+ :class:`.AliasedClass` instances. The "on" clause can
+ be a SQL expression or an ORM mapped attribute
+ referencing a configured :func:`_orm.relationship`.
+
+ .. deprecated:: 1.4 using a string relationship name for the "onclause"
+ is deprecated and will be removed in 2.0; the onclause may be only
+ an ORM-mapped relationship attribute or a SQL expression construct.
+
+ :func:`_orm.join` is not commonly needed in modern usage,
+ as its functionality is encapsulated within that of the
+ :meth:`_sql.Select.join` and :meth:`_query.Query.join`
+ methods, which feature a
+ significant amount of automation beyond :func:`_orm.join`
+ by itself. Explicit use of :func:`_orm.join`
+ with ORM-enabled SELECT statements involves use of the
+ :meth:`_sql.Select.select_from` method, as in::
+
+ from sqlalchemy.orm import join
+ stmt = select(User).\
+ select_from(join(User, Address, User.addresses)).\
+ filter(Address.email_address=='foo@bar.com')
+
+ In modern SQLAlchemy the above join can be written more
+ succinctly as::
+
+ stmt = select(User).\
+ join(User.addresses).\
+ filter(Address.email_address=='foo@bar.com')
+
+ See :ref:`orm_queryguide_joins` for information on modern usage
+ of ORM level joins.
+
+ .. deprecated:: 0.8
+
+ the ``join_to_left`` parameter is deprecated, and will be removed
+ in a future release. The parameter has no effect.
+
+ """
+ return _ORMJoin(left, right, onclause, isouter, full)
+
+
+def outerjoin(left, right, onclause=None, full=False, join_to_left=None):
+ """Produce a left outer join between left and right clauses.
+
+ This is the "outer join" version of the :func:`_orm.join` function,
+ featuring the same behavior except that an OUTER JOIN is generated.
+ See that function's documentation for other usage details.
+
+ """
+ return _ORMJoin(left, right, onclause, True, full)
+
+
+def with_parent(instance, prop, from_entity=None):
+ """Create filtering criterion that relates this query's primary entity
+ to the given related instance, using established
+ :func:`_orm.relationship()`
+ configuration.
+
+ E.g.::
+
+ stmt = select(Address).where(with_parent(some_user, User.addresses))
+
+
+ The SQL rendered is the same as that rendered when a lazy loader
+ would fire off from the given parent on that attribute, meaning
+ that the appropriate state is taken from the parent object in
+ Python without the need to render joins to the parent table
+ in the rendered statement.
+
+ The given property may also make use of :meth:`_orm.PropComparator.of_type`
+ to indicate the left side of the criteria::
+
+
+ a1 = aliased(Address)
+ a2 = aliased(Address)
+ stmt = select(a1, a2).where(
+ with_parent(u1, User.addresses.of_type(a2))
+ )
+
+ The above use is equivalent to using the
+ :paramref:`_orm.with_parent.from_entity` argument::
+
+ a1 = aliased(Address)
+ a2 = aliased(Address)
+ stmt = select(a1, a2).where(
+ with_parent(u1, User.addresses, from_entity=a2)
+ )
+
+ :param instance:
+ An instance which has some :func:`_orm.relationship`.
+
+ :param prop:
+ String property name, or class-bound attribute, which indicates
+ what relationship from the instance should be used to reconcile the
+ parent/child relationship.
+
+ .. deprecated:: 1.4 Using strings is deprecated and will be removed
+ in SQLAlchemy 2.0. Please use the class-bound attribute directly.
+
+ :param from_entity:
+ Entity to consider as the left side. This defaults to the
+ "zero" entity of the :class:`_query.Query` itself.
+
+ .. versionadded:: 1.2
+
+ """
+ if isinstance(prop, util.string_types):
+ util.warn_deprecated_20(
+ "Using strings to indicate relationship names in the ORM "
+ "with_parent() function is deprecated and will be removed "
+ "SQLAlchemy 2.0. Please use the class-bound attribute directly."
+ )
+ mapper = object_mapper(instance)
+ prop = getattr(mapper.class_, prop).property
+ elif isinstance(prop, attributes.QueryableAttribute):
+ if prop._of_type:
+ from_entity = prop._of_type
+ prop = prop.property
+
+ return prop._with_parent(instance, from_entity=from_entity)
+
+
+def has_identity(object_):
+ """Return True if the given object has a database
+ identity.
+
+ This typically corresponds to the object being
+ in either the persistent or detached state.
+
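+ E.g., a brief sketch assuming a hypothetical mapped ``User`` class and an
+ active :class:`.Session` named ``session``::
+
+     from sqlalchemy.orm.util import has_identity
+
+     u1 = User(name="someuser")
+     assert not has_identity(u1)   # transient; no database identity yet
+
+     session.add(u1)
+     session.flush()
+     assert has_identity(u1)       # persistent; primary key has been assigned
+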
+ .. seealso::
+
+ :func:`.was_deleted`
+
+ """
+ state = attributes.instance_state(object_)
+ return state.has_identity
+
+
+def was_deleted(object_):
+ """Return True if the given object was deleted
+ within a session flush.
+
+ This is regardless of whether or not the object is
+ persistent or detached.
+
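+ E.g., a brief sketch assuming ``u1`` is a persistent object in the
+ :class:`.Session` named ``session``::
+
+     from sqlalchemy.orm.util import was_deleted
+
+     session.delete(u1)
+     session.flush()
+     assert was_deleted(u1)   # True once the DELETE has been flushed
+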
+ .. seealso::
+
+ :attr:`.InstanceState.was_deleted`
+
+ """
+
+ state = attributes.instance_state(object_)
+ return state.was_deleted
+
+
+def _entity_corresponds_to(given, entity):
+ """determine if 'given' corresponds to 'entity', in terms
+ of an entity passed to Query that would match the same entity
+ being referred to elsewhere in the query.
+
+ """
+ if entity.is_aliased_class:
+ if given.is_aliased_class:
+ if entity._base_alias() is given._base_alias():
+ return True
+ return False
+ elif given.is_aliased_class:
+ if given._use_mapper_path:
+ return entity in given.with_polymorphic_mappers
+ else:
+ return entity is given
+
+ return entity.common_parent(given)
+
+
+def _entity_corresponds_to_use_path_impl(given, entity):
+ """determine if 'given' corresponds to 'entity', in terms
+ of a path of loader options where a mapped attribute is taken to
+ be a member of a parent entity.
+
+ e.g.::
+
+ someoption(A).someoption(A.b) # -> fn(A, A) -> True
+ someoption(A).someoption(C.d) # -> fn(A, C) -> False
+
+ a1 = aliased(A)
+ someoption(a1).someoption(A.b) # -> fn(a1, A) -> False
+ someoption(a1).someoption(a1.b) # -> fn(a1, a1) -> True
+
+ wp = with_polymorphic(A, [A1, A2])
+ someoption(wp).someoption(A1.foo) # -> fn(wp, A1) -> False
+ someoption(wp).someoption(wp.A1.foo) # -> fn(wp, wp.A1) -> True
+
+
+ """
+ if given.is_aliased_class:
+ return (
+ entity.is_aliased_class
+ and not entity._use_mapper_path
+ and (given is entity or given in entity._with_polymorphic_entities)
+ )
+ elif not entity.is_aliased_class:
+ return given.common_parent(entity.mapper)
+ else:
+ return (
+ entity._use_mapper_path
+ and given in entity.with_polymorphic_mappers
+ )
+
+
+def _entity_isa(given, mapper):
+ """determine if 'given' "is a" mapper, in terms of the given
+ would load rows of type 'mapper'.
+
+ """
+ if given.is_aliased_class:
+ return mapper in given.with_polymorphic_mappers or given.mapper.isa(
+ mapper
+ )
+ elif given.with_polymorphic_mappers:
+ return mapper in given.with_polymorphic_mappers
+ else:
+ return given.isa(mapper)
+
+
+def randomize_unitofwork():
+ """Use random-ordering sets within the unit of work in order
+ to detect unit of work sorting issues.
+
+ This is a utility function that can be used to help reproduce
+ inconsistent unit of work sorting issues. For example,
+ if two kinds of objects A and B are being inserted, and
+ B has a foreign key reference to A - then A must be inserted first.
+ However, if there is no relationship between A and B, the unit of work
+ won't know to perform this sorting, and an operation may or may not
+ fail, depending on how the ordering works out. Since Python sets
+ and dictionaries have non-deterministic ordering, such an issue may
+ occur on some runs and not on others, and in practice it tends to
+ have a great dependence on the state of the interpreter. This leads
+ to so-called "heisenbugs" where changing entirely irrelevant aspects
+ of the test program still causes the failure behavior to change.
+
+ By calling ``randomize_unitofwork()`` when a script first runs, the
+ ordering of a key series of sets within the unit of work implementation
+ is randomized, so that the script can be minimized down to the
+ fundamental mapping and operation that's failing, while still reproducing
+ the issue on at least some runs.
+
+ This utility is also available when running the test suite via the
+ ``--reversetop`` flag.
+
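+ E.g., a minimal sketch of enabling it at the top of a reproduction script::
+
+     from sqlalchemy.orm.util import randomize_unitofwork
+
+     randomize_unitofwork()
+
+     # ... mapping definitions and the failing flush scenario follow
+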
+ """
+ from sqlalchemy.orm import unitofwork, session, mapper, dependency
+ from sqlalchemy.util import topological
+ from sqlalchemy.testing.util import RandomSet
+
+ topological.set = (
+ unitofwork.set
+ ) = session.set = mapper.set = dependency.set = RandomSet
+
+
+def _getitem(iterable_query, item, allow_negative):
+ """calculate __getitem__ in terms of an iterable query object
+ that also has a slice() method.
+
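+ E.g., roughly, for a :class:`_query.Query` object ``q``::
+
+     q[5:10]    # becomes list(q.slice(5, 10))
+     q[10]      # becomes list(q.slice(10, 11))[0]
+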
+ """
+
+ def _no_negative_indexes():
+ if not allow_negative:
+ raise IndexError(
+ "negative indexes are not accepted by SQL "
+ "index / slice operators"
+ )
+ else:
+ util.warn_deprecated_20(
+ "Support for negative indexes for SQL index / slice operators "
+ "will be "
+ "removed in 2.0; these operators fetch the complete result "
+ "and do not work efficiently."
+ )
+
+ if isinstance(item, slice):
+ start, stop, step = util.decode_slice(item)
+
+ if (
+ isinstance(stop, int)
+ and isinstance(start, int)
+ and stop - start <= 0
+ ):
+ return []
+
+ elif (isinstance(start, int) and start < 0) or (
+ isinstance(stop, int) and stop < 0
+ ):
+ _no_negative_indexes()
+ return list(iterable_query)[item]
+
+ res = iterable_query.slice(start, stop)
+ if step is not None:
+ return list(res)[None : None : item.step]
+ else:
+ return list(res)
+ else:
+ if item == -1:
+ _no_negative_indexes()
+ return list(iterable_query)[-1]
+ else:
+ return list(iterable_query[item : item + 1])[0]
diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py
new file mode 100644
index 0000000..6a00ef8
--- /dev/null
+++ b/lib/sqlalchemy/pool/__init__.py
@@ -0,0 +1,56 @@
+# sqlalchemy/pool/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+"""Connection pooling for DB-API connections.
+
+Provides a number of connection pool implementations for a variety of
+usage scenarios and thread behavior requirements imposed by the
+application, DB-API or database itself.
+
+Also provides a DB-API 2.0 connection proxying mechanism allowing
+regular DB-API connect() methods to be transparently managed by a
+SQLAlchemy connection pool.
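+
+ For illustration, a minimal sketch of using a pool directly; most
+ applications configure pooling through ``create_engine()`` instead, and the
+ ``sqlite3`` creator below is only an example::
+
+     import sqlite3
+
+     from sqlalchemy.pool import QueuePool
+
+     def getconn():
+         return sqlite3.connect("example.db")
+
+     pool = QueuePool(getconn, pool_size=5, max_overflow=10)
+
+     conn = pool.connect()    # check out a pooled DBAPI connection proxy
+     cursor = conn.cursor()
+     cursor.execute("select 1")
+     cursor.close()
+     conn.close()             # return the connection to the pool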
+"""
+
+from . import events
+from .base import _ConnectionFairy
+from .base import _ConnectionRecord
+from .base import _finalize_fairy
+from .base import Pool
+from .base import reset_commit
+from .base import reset_none
+from .base import reset_rollback
+from .dbapi_proxy import clear_managers
+from .dbapi_proxy import manage
+from .impl import AssertionPool
+from .impl import AsyncAdaptedQueuePool
+from .impl import FallbackAsyncAdaptedQueuePool
+from .impl import NullPool
+from .impl import QueuePool
+from .impl import SingletonThreadPool
+from .impl import StaticPool
+
+
+__all__ = [
+ "Pool",
+ "reset_commit",
+ "reset_none",
+ "reset_rollback",
+ "clear_managers",
+ "manage",
+ "AssertionPool",
+ "NullPool",
+ "QueuePool",
+ "AsyncAdaptedQueuePool",
+ "FallbackAsyncAdaptedQueuePool",
+ "SingletonThreadPool",
+ "StaticPool",
+]
+
+# as these are likely to be used in various test suites, debugging
+# setups, keep them in the sqlalchemy.pool namespace
diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py
new file mode 100644
index 0000000..cde28c2
--- /dev/null
+++ b/lib/sqlalchemy/pool/base.py
@@ -0,0 +1,1121 @@
+# sqlalchemy/pool.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+"""Base constructs for connection pools.
+
+"""
+
+from collections import deque
+import time
+import weakref
+
+from .. import event
+from .. import exc
+from .. import log
+from .. import util
+
+
+reset_rollback = util.symbol("reset_rollback")
+reset_commit = util.symbol("reset_commit")
+reset_none = util.symbol("reset_none")
+
+
+class _ConnDialect(object):
+ """partial implementation of :class:`.Dialect`
+ which provides DBAPI connection methods.
+
+ When a :class:`_pool.Pool` is combined with an :class:`_engine.Engine`,
+ the :class:`_engine.Engine` replaces this with its own
+ :class:`.Dialect`.
+
+ """
+
+ is_async = False
+
+ def do_rollback(self, dbapi_connection):
+ dbapi_connection.rollback()
+
+ def do_commit(self, dbapi_connection):
+ dbapi_connection.commit()
+
+ def do_close(self, dbapi_connection):
+ dbapi_connection.close()
+
+ def do_ping(self, dbapi_connection):
+ raise NotImplementedError(
+ "The ping feature requires that a dialect is "
+ "passed to the connection pool."
+ )
+
+ def get_driver_connection(self, connection):
+ return connection
+
+
+class _AsyncConnDialect(_ConnDialect):
+ is_async = True
+
+
+class Pool(log.Identified):
+
+ """Abstract base class for connection pools."""
+
+ _dialect = _ConnDialect()
+
+ def __init__(
+ self,
+ creator,
+ recycle=-1,
+ echo=None,
+ logging_name=None,
+ reset_on_return=True,
+ events=None,
+ dialect=None,
+ pre_ping=False,
+ _dispatch=None,
+ ):
+ """
+ Construct a Pool.
+
+ :param creator: a callable function that returns a DB-API
+ connection object. The function is called with no arguments, unless
+ it accepts a single "connection record" positional argument, in which
+ case that record is passed.
+
+ :param recycle: If set to a value other than -1, number of
+ seconds between connection recycling, which means upon
+ checkout, if this timeout is surpassed the connection will be
+ closed and replaced with a newly opened connection. Defaults to -1.
+
+ :param logging_name: String identifier which will be used within
+ the "name" field of logging records generated within the
+ "sqlalchemy.pool" logger. Defaults to a hexstring of the object's
+ id.
+
+ :param echo: if True, the connection pool will log
+ informational output such as when connections are invalidated
+ as well as when connections are recycled to the default log handler,
+ which defaults to ``sys.stdout`` for output. If set to the string
+ ``"debug"``, the logging will include pool checkouts and checkins.
+
+ The :paramref:`_pool.Pool.echo` parameter can also be set from the
+ :func:`_sa.create_engine` call by using the
+ :paramref:`_sa.create_engine.echo_pool` parameter.
+
+ .. seealso::
+
+ :ref:`dbengine_logging` - further detail on how to configure
+ logging.
+
+ :param reset_on_return: Determine steps to take on
+ connections as they are returned to the pool, which were
+ not otherwise handled by a :class:`_engine.Connection`.
+
+ reset_on_return can have any of these values:
+
+ * ``"rollback"`` - call rollback() on the connection,
+ to release locks and transaction resources.
+ This is the default value. The vast majority
+ of use cases should leave this value set.
+ * ``True`` - same as 'rollback', this is here for
+ backwards compatibility.
+ * ``"commit"`` - call commit() on the connection,
+ to release locks and transaction resources.
+ A commit here may be desirable for databases that
+ cache query plans if a commit is emitted,
+ such as Microsoft SQL Server. However, this
+ value is more dangerous than 'rollback' because
+ any data changes present on the transaction
+ are committed unconditionally.
+ * ``None`` - don't do anything on the connection.
+ This setting is only appropriate if the database / DBAPI
+ works in pure "autocommit" mode at all times, or if the
+ application uses the :class:`_engine.Engine` with consistent
+ connectivity patterns. See the section
+ :ref:`pool_reset_on_return` for more details.
+
+ * ``False`` - same as None, this is here for
+ backwards compatibility.
+
+ .. seealso::
+
+ :ref:`pool_reset_on_return`
+
+ :param events: a list of 2-tuples, each of the form
+ ``(callable, target)`` which will be passed to :func:`.event.listen`
+ upon construction. Provided here so that event listeners
+ can be assigned via :func:`_sa.create_engine` before dialect-level
+ listeners are applied.
+
+ :param dialect: a :class:`.Dialect` that will handle the job
+ of calling rollback(), close(), or commit() on DBAPI connections.
+ If omitted, a built-in "stub" dialect is used. Applications that
+ make use of :func:`_sa.create_engine` should not use this parameter
+ as it is handled by the engine creation strategy.
+
+ .. versionadded:: 1.1 - ``dialect`` is now a public parameter
+ to the :class:`_pool.Pool`.
+
+ :param pre_ping: if True, the pool will emit a "ping" (typically
+ "SELECT 1", but is dialect-specific) on the connection
+ upon checkout, to test if the connection is alive or not. If not,
+ the connection is transparently re-connected and upon success, all
+ other pooled connections established prior to that timestamp are
+ invalidated. Requires that a dialect is passed as well to
+ interpret the disconnection error.
+
+ .. versionadded:: 1.2
+
+ """
+ if logging_name:
+ self.logging_name = self._orig_logging_name = logging_name
+ else:
+ self._orig_logging_name = None
+
+ log.instance_logger(self, echoflag=echo)
+ self._creator = creator
+ self._recycle = recycle
+ self._invalidate_time = 0
+ self._pre_ping = pre_ping
+ self._reset_on_return = util.symbol.parse_user_argument(
+ reset_on_return,
+ {
+ reset_rollback: ["rollback", True],
+ reset_none: ["none", None, False],
+ reset_commit: ["commit"],
+ },
+ "reset_on_return",
+ resolve_symbol_names=False,
+ )
+
+ self.echo = echo
+
+ if _dispatch:
+ self.dispatch._update(_dispatch, only_propagate=False)
+ if dialect:
+ self._dialect = dialect
+ if events:
+ for fn, target in events:
+ event.listen(self, target, fn)
+
+ @util.hybridproperty
+ def _is_asyncio(self):
+ return self._dialect.is_async
+
+ @property
+ def _creator(self):
+ return self.__dict__["_creator"]
+
+ @_creator.setter
+ def _creator(self, creator):
+ self.__dict__["_creator"] = creator
+ self._invoke_creator = self._should_wrap_creator(creator)
+
+ def _should_wrap_creator(self, creator):
+ """Detect if creator accepts a single argument, or is sent
+ as a legacy style no-arg function.
+
+ """
+
+ try:
+ argspec = util.get_callable_argspec(self._creator, no_self=True)
+ except TypeError:
+ return lambda crec: creator()
+
+ defaulted = argspec[3] is not None and len(argspec[3]) or 0
+ positionals = len(argspec[0]) - defaulted
+
+ # look for the exact arg signature that DefaultStrategy
+ # sends us
+ if (argspec[0], argspec[3]) == (["connection_record"], (None,)):
+ return creator
+ # or just a single positional
+ elif positionals == 1:
+ return creator
+ # all other cases, just wrap and assume legacy "creator" callable
+ # thing
+ else:
+ return lambda crec: creator()
+
+ def _close_connection(self, connection):
+ self.logger.debug("Closing connection %r", connection)
+
+ try:
+ self._dialect.do_close(connection)
+ except Exception:
+ self.logger.error(
+ "Exception closing connection %r", connection, exc_info=True
+ )
+
+ def _create_connection(self):
+ """Called by subclasses to create a new ConnectionRecord."""
+
+ return _ConnectionRecord(self)
+
+ def _invalidate(self, connection, exception=None, _checkin=True):
+ """Mark all connections established within the generation
+ of the given connection as invalidated.
+
+ If this pool's last invalidate time is before when the given
+ connection was created, update the timestamp to the current time. Otherwise,
+ no action is performed.
+
+ Connections with a start time prior to this pool's invalidation
+ time will be recycled upon next checkout.
+ """
+ rec = getattr(connection, "_connection_record", None)
+ if not rec or self._invalidate_time < rec.starttime:
+ self._invalidate_time = time.time()
+ if _checkin and getattr(connection, "is_valid", False):
+ connection.invalidate(exception)
+
+ def recreate(self):
+ """Return a new :class:`_pool.Pool`, of the same class as this one
+ and configured with identical creation arguments.
+
+ This method is used in conjunction with :meth:`dispose`
+ to close out an entire :class:`_pool.Pool` and create a new one in
+ its place.
+
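+ E.g., a sketch of the typical replace-the-pool pattern, similar to what
+ :meth:`_engine.Engine.dispose` performs::
+
+     pool.dispose()
+     pool = pool.recreate()
+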
+ """
+
+ raise NotImplementedError()
+
+ def dispose(self):
+ """Dispose of this pool.
+
+ This method leaves the possibility of checked-out connections
+ remaining open, as it only affects connections that are
+ idle in the pool.
+
+ .. seealso::
+
+ :meth:`Pool.recreate`
+
+ """
+
+ raise NotImplementedError()
+
+ def connect(self):
+ """Return a DBAPI connection from the pool.
+
+ The connection is instrumented such that when its
+ ``close()`` method is called, the connection will be returned to
+ the pool.
+
+ """
+ return _ConnectionFairy._checkout(self)
+
+ def _return_conn(self, record):
+ """Given a _ConnectionRecord, return it to the :class:`_pool.Pool`.
+
+ This method is called when an instrumented DBAPI connection
+ has its ``close()`` method called.
+
+ """
+ self._do_return_conn(record)
+
+ def _do_get(self):
+ """Implementation for :meth:`get`, supplied by subclasses."""
+
+ raise NotImplementedError()
+
+ def _do_return_conn(self, conn):
+ """Implementation for :meth:`return_conn`, supplied by subclasses."""
+
+ raise NotImplementedError()
+
+ def status(self):
+ raise NotImplementedError()
+
+
+class _ConnectionRecord(object):
+
+ """Internal object which maintains an individual DBAPI connection
+ referenced by a :class:`_pool.Pool`.
+
+ The :class:`._ConnectionRecord` object always exists for any particular
+ DBAPI connection whether or not that DBAPI connection has been
+ "checked out". This is in contrast to the :class:`._ConnectionFairy`
+ which is only a public facade to the DBAPI connection while it is checked
+ out.
+
+ A :class:`._ConnectionRecord` may exist for a span longer than that
+ of a single DBAPI connection. For example, if the
+ :meth:`._ConnectionRecord.invalidate`
+ method is called, the DBAPI connection associated with this
+ :class:`._ConnectionRecord`
+ will be discarded, but the :class:`._ConnectionRecord` may be used again,
+ in which case a new DBAPI connection is produced when the
+ :class:`_pool.Pool`
+ next uses this record.
+
+ The :class:`._ConnectionRecord` is delivered along with connection
+ pool events, including :meth:`_events.PoolEvents.connect` and
+ :meth:`_events.PoolEvents.checkout`, however :class:`._ConnectionRecord`
+ still
+ remains an internal object whose API and internals may change.
+
+ .. seealso::
+
+ :class:`._ConnectionFairy`
+
+ """
+
+ def __init__(self, pool, connect=True):
+ self.__pool = pool
+ if connect:
+ self.__connect()
+ self.finalize_callback = deque()
+
+ fresh = False
+
+ fairy_ref = None
+
+ starttime = None
+
+ dbapi_connection = None
+ """A reference to the actual DBAPI connection being tracked.
+
+ May be ``None`` if this :class:`._ConnectionRecord` has been marked
+ as invalidated; a new DBAPI connection may replace it if the owning
+ pool calls upon this :class:`._ConnectionRecord` to reconnect.
+
+ For adapted drivers, like the Asyncio implementations, this is a
+ :class:`.AdaptedConnection` that adapts the driver connection
+ to the DBAPI protocol.
+ Use :attr:`._ConnectionRecord.driver_connection` to obtain the
+ connection object returned by the driver.
+
+ .. versionadded:: 1.4.24
+
+ """
+
+ @property
+ def driver_connection(self):
+ """The connection object as returned by the driver after a connect.
+
+ For normal sync drivers that support the DBAPI protocol, this object
+ is the same as the one referenced by
+ :attr:`._ConnectionRecord.dbapi_connection`.
+
+ For adapted drivers, like the Asyncio ones, this is the actual object
+ that was returned by the driver ``connect`` call.
+
+ As with :attr:`._ConnectionRecord.dbapi_connection`, it may be ``None``
+ if this :class:`._ConnectionRecord` has been marked as invalidated.
+
+ .. versionadded:: 1.4.24
+
+ """
+
+ if self.dbapi_connection is None:
+ return None
+ else:
+ return self.__pool._dialect.get_driver_connection(
+ self.dbapi_connection
+ )
+
+ @property
+ def connection(self):
+ """An alias to :attr:`._ConnectionRecord.dbapi_connection`.
+
+ This alias is deprecated, please use the new name.
+
+ .. deprecated:: 1.4.24
+
+ """
+ return self.dbapi_connection
+
+ @connection.setter
+ def connection(self, value):
+ self.dbapi_connection = value
+
+ _soft_invalidate_time = 0
+
+ @util.memoized_property
+ def info(self):
+ """The ``.info`` dictionary associated with the DBAPI connection.
+
+ This dictionary is shared among the :attr:`._ConnectionFairy.info`
+ and :attr:`_engine.Connection.info` accessors.
+
+ .. note::
+
+ The lifespan of this dictionary is linked to the
+ DBAPI connection itself, meaning that it is **discarded** each time
+ the DBAPI connection is closed and/or invalidated. The
+ :attr:`._ConnectionRecord.record_info` dictionary remains
+ persistent throughout the lifespan of the
+ :class:`._ConnectionRecord` container.
+
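+ E.g., a sketch of stamping this dictionary from a pool event listener,
+ where ``engine`` is assumed to be an existing :class:`_engine.Engine`::
+
+     import time
+
+     from sqlalchemy import event
+
+     @event.listens_for(engine, "connect")
+     def _on_connect(dbapi_connection, connection_record):
+         connection_record.info["connected_at"] = time.time()
+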
+ """
+ return {}
+
+ @util.memoized_property
+ def record_info(self):
+ """An "info' dictionary associated with the connection record
+ itself.
+
+ Unlike the :attr:`._ConnectionRecord.info` dictionary, which is linked
+ to the lifespan of the DBAPI connection, this dictionary is linked
+ to the lifespan of the :class:`._ConnectionRecord` container itself
+ and will remain persistent throughout the life of the
+ :class:`._ConnectionRecord`.
+
+ .. versionadded:: 1.1
+
+ """
+ return {}
+
+ @classmethod
+ def checkout(cls, pool):
+ rec = pool._do_get()
+ try:
+ dbapi_connection = rec.get_connection()
+ except Exception as err:
+ with util.safe_reraise():
+ rec._checkin_failed(err, _fairy_was_created=False)
+ echo = pool._should_log_debug()
+ fairy = _ConnectionFairy(dbapi_connection, rec, echo)
+
+ rec.fairy_ref = ref = weakref.ref(
+ fairy,
+ lambda ref: _finalize_fairy
+ and _finalize_fairy(None, rec, pool, ref, echo, True),
+ )
+ _strong_ref_connection_records[ref] = rec
+ if echo:
+ pool.logger.debug(
+ "Connection %r checked out from pool", dbapi_connection
+ )
+ return fairy
+
+ def _checkin_failed(self, err, _fairy_was_created=True):
+ self.invalidate(e=err)
+ self.checkin(
+ _fairy_was_created=_fairy_was_created,
+ )
+
+ def checkin(self, _fairy_was_created=True):
+ if self.fairy_ref is None and _fairy_was_created:
+ # _fairy_was_created is False for the initial get connection phase;
+ # meaning there was no _ConnectionFairy and we must unconditionally
+ # do a checkin.
+ #
+ # otherwise, if fairy_was_created==True, if fairy_ref is None here
+ # that means we were checked in already, so this looks like
+ # a double checkin.
+ util.warn("Double checkin attempted on %s" % self)
+ return
+ self.fairy_ref = None
+ connection = self.dbapi_connection
+ pool = self.__pool
+ while self.finalize_callback:
+ finalizer = self.finalize_callback.pop()
+ finalizer(connection)
+ if pool.dispatch.checkin:
+ pool.dispatch.checkin(connection, self)
+
+ pool._return_conn(self)
+
+ @property
+ def in_use(self):
+ return self.fairy_ref is not None
+
+ @property
+ def last_connect_time(self):
+ return self.starttime
+
+ def close(self):
+ if self.dbapi_connection is not None:
+ self.__close()
+
+ def invalidate(self, e=None, soft=False):
+ """Invalidate the DBAPI connection held by this
+ :class:`._ConnectionRecord`.
+
+ This method is called for all connection invalidations, including
+ when the :meth:`._ConnectionFairy.invalidate` or
+ :meth:`_engine.Connection.invalidate` methods are called,
+ as well as when any
+ so-called "automatic invalidation" condition occurs.
+
+ :param e: an exception object indicating a reason for the
+ invalidation.
+
+ :param soft: if True, the connection isn't closed; instead, this
+ connection will be recycled on next checkout.
+
+ .. versionadded:: 1.0.3
+
+ .. seealso::
+
+ :ref:`pool_connection_invalidation`
+
+ """
+ # already invalidated
+ if self.dbapi_connection is None:
+ return
+ if soft:
+ self.__pool.dispatch.soft_invalidate(
+ self.dbapi_connection, self, e
+ )
+ else:
+ self.__pool.dispatch.invalidate(self.dbapi_connection, self, e)
+ if e is not None:
+ self.__pool.logger.info(
+ "%sInvalidate connection %r (reason: %s:%s)",
+ "Soft " if soft else "",
+ self.dbapi_connection,
+ e.__class__.__name__,
+ e,
+ )
+ else:
+ self.__pool.logger.info(
+ "%sInvalidate connection %r",
+ "Soft " if soft else "",
+ self.dbapi_connection,
+ )
+
+ if soft:
+ self._soft_invalidate_time = time.time()
+ else:
+ self.__close()
+ self.dbapi_connection = None
+
+ def get_connection(self):
+ recycle = False
+
+ # NOTE: the various comparisons here are assuming that measurable time
+ # passes between these state changes. however, time.time() is not
+ # guaranteed to have sub-second precision. comparisons of
+ # "invalidation time" to "starttime" should perhaps use >= so that the
+ # state change can take place assuming no measurable time has passed,
+ # however this does not guarantee correct behavior here as if time
+ # continues to not pass, it will try to reconnect repeatedly until
+ # these timestamps diverge, so in that sense using > is safer. Per
+ # https://stackoverflow.com/a/1938096/34549, Windows time.time() may be
+ # within 16 milliseconds accuracy, so unit tests for connection
+ # invalidation need a sleep of at least this long between initial start
+ # time and invalidation for the logic below to work reliably.
+ if self.dbapi_connection is None:
+ self.info.clear()
+ self.__connect()
+ elif (
+ self.__pool._recycle > -1
+ and time.time() - self.starttime > self.__pool._recycle
+ ):
+ self.__pool.logger.info(
+ "Connection %r exceeded timeout; recycling",
+ self.dbapi_connection,
+ )
+ recycle = True
+ elif self.__pool._invalidate_time > self.starttime:
+ self.__pool.logger.info(
+ "Connection %r invalidated due to pool invalidation; "
+ + "recycling",
+ self.dbapi_connection,
+ )
+ recycle = True
+ elif self._soft_invalidate_time > self.starttime:
+ self.__pool.logger.info(
+ "Connection %r invalidated due to local soft invalidation; "
+ + "recycling",
+ self.dbapi_connection,
+ )
+ recycle = True
+
+ if recycle:
+ self.__close()
+ self.info.clear()
+
+ self.__connect()
+ return self.dbapi_connection
+
+ def _is_hard_or_soft_invalidated(self):
+ return (
+ self.dbapi_connection is None
+ or self.__pool._invalidate_time > self.starttime
+ or (self._soft_invalidate_time > self.starttime)
+ )
+
+ def __close(self):
+ self.finalize_callback.clear()
+ if self.__pool.dispatch.close:
+ self.__pool.dispatch.close(self.dbapi_connection, self)
+ self.__pool._close_connection(self.dbapi_connection)
+ self.dbapi_connection = None
+
+ def __connect(self):
+ pool = self.__pool
+
+ # ensure any existing connection is removed, so that if
+ # creator fails, this attribute stays None
+ self.dbapi_connection = None
+ try:
+ self.starttime = time.time()
+ self.dbapi_connection = connection = pool._invoke_creator(self)
+ pool.logger.debug("Created new connection %r", connection)
+ self.fresh = True
+ except Exception as e:
+ with util.safe_reraise():
+ pool.logger.debug("Error on connect(): %s", e)
+ else:
+ # in SQLAlchemy 1.4 the first_connect event is not used by
+ # the engine, so this will usually not be set
+ if pool.dispatch.first_connect:
+ pool.dispatch.first_connect.for_modify(
+ pool.dispatch
+ ).exec_once_unless_exception(self.dbapi_connection, self)
+
+ # init of the dialect now takes place within the connect
+ # event, so ensure a mutex is used on the first run
+ pool.dispatch.connect.for_modify(
+ pool.dispatch
+ )._exec_w_sync_on_first_run(self.dbapi_connection, self)
+
+
+def _finalize_fairy(
+ dbapi_connection,
+ connection_record,
+ pool,
+ ref, # this is None when called directly, not by the gc
+ echo,
+ reset=True,
+ fairy=None,
+):
+ """Cleanup for a :class:`._ConnectionFairy` whether or not it's already
+ been garbage collected.
+
+ When using an async dialect no IO can happen here (without using
+ a dedicated thread), since this is called outside the greenlet
+ context and with an already running loop. In this case function
+ will only log a message and raise a warning.
+ """
+
+ if ref:
+ _strong_ref_connection_records.pop(ref, None)
+ elif fairy:
+ _strong_ref_connection_records.pop(weakref.ref(fairy), None)
+
+ if ref is not None:
+ if connection_record.fairy_ref is not ref:
+ return
+ assert dbapi_connection is None
+ dbapi_connection = connection_record.dbapi_connection
+
+ # null pool is not _is_asyncio but can be used also with async dialects
+ dont_restore_gced = pool._dialect.is_async
+
+ if dont_restore_gced:
+ detach = not connection_record or ref
+ can_manipulate_connection = not ref
+ else:
+ detach = not connection_record
+ can_manipulate_connection = True
+
+ if dbapi_connection is not None:
+ if connection_record and echo:
+ pool.logger.debug(
+ "Connection %r being returned to pool%s",
+ dbapi_connection,
+ ", transaction state was already reset by caller"
+ if not reset
+ else "",
+ )
+
+ try:
+ fairy = fairy or _ConnectionFairy(
+ dbapi_connection,
+ connection_record,
+ echo,
+ )
+ assert fairy.dbapi_connection is dbapi_connection
+ if reset and can_manipulate_connection:
+ fairy._reset(pool)
+
+ if detach:
+ if connection_record:
+ fairy._pool = pool
+ fairy.detach()
+
+ if can_manipulate_connection:
+ if pool.dispatch.close_detached:
+ pool.dispatch.close_detached(dbapi_connection)
+
+ pool._close_connection(dbapi_connection)
+ else:
+ message = (
+ "The garbage collector is trying to clean up "
+ "connection %r. This feature is unsupported on async "
+ "dbapi, since no IO can be performed at this stage to "
+ "reset the connection. Please close out all "
+ "connections when they are no longer used, calling "
+ "``close()`` or using a context manager to "
+ "manage their lifetime."
+ ) % dbapi_connection
+ pool.logger.error(message)
+ util.warn(message)
+
+ except BaseException as e:
+ pool.logger.error(
+ "Exception during reset or similar", exc_info=True
+ )
+ if connection_record:
+ connection_record.invalidate(e=e)
+ if not isinstance(e, Exception):
+ raise
+
+ if connection_record and connection_record.fairy_ref is not None:
+ connection_record.checkin()
+
+
+# a dictionary of the _ConnectionFairy weakrefs to _ConnectionRecord, so that
+# GC under pypy will call ConnectionFairy finalizers. linked directly to the
+# weakref that will empty itself when collected so that it should not create
+# any unmanaged memory references.
+_strong_ref_connection_records = {}
+
+
+class _ConnectionFairy(object):
+
+ """Proxies a DBAPI connection and provides return-on-dereference
+ support.
+
+ This is an internal object used by the :class:`_pool.Pool` implementation
+ to provide context management to a DBAPI connection delivered by
+ that :class:`_pool.Pool`.
+
+ The name "fairy" is inspired by the fact that the
+ :class:`._ConnectionFairy` object's lifespan is transitory, as it lasts
+ only for the length of a specific DBAPI connection being checked out from
+ the pool, and additionally that as a transparent proxy, it is mostly
+ invisible.
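+
+    For illustration (a minimal sketch; the in-memory SQLite URL is only an
+    example), this is the object handed back by
+    :meth:`_engine.Engine.raw_connection`::
+
+        from sqlalchemy import create_engine
+
+        engine = create_engine("sqlite://")
+        fairy = engine.raw_connection()  # a _ConnectionFairy proxy
+        cursor = fairy.cursor()          # proxied to the DBAPI cursor() method
+        fairy.close()                    # returns the connection to the pool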
+
+ .. seealso::
+
+ :class:`._ConnectionRecord`
+
+ """
+
+ def __init__(self, dbapi_connection, connection_record, echo):
+ self.dbapi_connection = dbapi_connection
+ self._connection_record = connection_record
+ self._echo = echo
+
+ dbapi_connection = None
+ """A reference to the actual DBAPI connection being tracked.
+
+ .. versionadded:: 1.4.24
+
+ .. seealso::
+
+ :attr:`._ConnectionFairy.driver_connection`
+
+ :attr:`._ConnectionRecord.dbapi_connection`
+
+ :ref:`faq_dbapi_connection`
+
+ """
+
+ _connection_record = None
+ """A reference to the :class:`._ConnectionRecord` object associated
+ with the DBAPI connection.
+
+ This is currently an internal accessor which is subject to change.
+
+ """
+
+ @property
+ def driver_connection(self):
+ """The connection object as returned by the driver after a connect.
+
+ .. versionadded:: 1.4.24
+
+ .. seealso::
+
+ :attr:`._ConnectionFairy.dbapi_connection`
+
+ :attr:`._ConnectionRecord.driver_connection`
+
+ :ref:`faq_dbapi_connection`
+
+ """
+ return self._connection_record.driver_connection
+
+ @property
+ def connection(self):
+ """An alias to :attr:`._ConnectionFairy.dbapi_connection`.
+
+        This alias is deprecated; please use the new name.
+
+ .. deprecated:: 1.4.24
+
+ """
+ return self.dbapi_connection
+
+ @connection.setter
+ def connection(self, value):
+ self.dbapi_connection = value
+
+ @classmethod
+ def _checkout(cls, pool, threadconns=None, fairy=None):
+ if not fairy:
+ fairy = _ConnectionRecord.checkout(pool)
+
+ fairy._pool = pool
+ fairy._counter = 0
+
+ if threadconns is not None:
+ threadconns.current = weakref.ref(fairy)
+
+ if fairy.dbapi_connection is None:
+ raise exc.InvalidRequestError("This connection is closed")
+ fairy._counter += 1
+ if (
+ not pool.dispatch.checkout and not pool._pre_ping
+ ) or fairy._counter != 1:
+ return fairy
+
+ # Pool listeners can trigger a reconnection on checkout, as well
+ # as the pre-pinger.
+        # Up to two attempts are made here, but note that if the database
+        # is not accessible from a connection standpoint, those won't
+        # proceed here.
+ attempts = 2
+ while attempts > 0:
+ connection_is_fresh = fairy._connection_record.fresh
+ fairy._connection_record.fresh = False
+ try:
+ if pool._pre_ping:
+ if not connection_is_fresh:
+ if fairy._echo:
+ pool.logger.debug(
+ "Pool pre-ping on connection %s",
+ fairy.dbapi_connection,
+ )
+ result = pool._dialect.do_ping(fairy.dbapi_connection)
+ if not result:
+ if fairy._echo:
+ pool.logger.debug(
+ "Pool pre-ping on connection %s failed, "
+ "will invalidate pool",
+ fairy.dbapi_connection,
+ )
+ raise exc.InvalidatePoolError()
+ elif fairy._echo:
+ pool.logger.debug(
+ "Connection %s is fresh, skipping pre-ping",
+ fairy.dbapi_connection,
+ )
+
+ pool.dispatch.checkout(
+ fairy.dbapi_connection, fairy._connection_record, fairy
+ )
+ return fairy
+ except exc.DisconnectionError as e:
+ if e.invalidate_pool:
+ pool.logger.info(
+ "Disconnection detected on checkout, "
+ "invalidating all pooled connections prior to "
+ "current timestamp (reason: %r)",
+ e,
+ )
+ fairy._connection_record.invalidate(e)
+ pool._invalidate(fairy, e, _checkin=False)
+ else:
+ pool.logger.info(
+ "Disconnection detected on checkout, "
+ "invalidating individual connection %s (reason: %r)",
+ fairy.dbapi_connection,
+ e,
+ )
+ fairy._connection_record.invalidate(e)
+ try:
+ fairy.dbapi_connection = (
+ fairy._connection_record.get_connection()
+ )
+ except Exception as err:
+ with util.safe_reraise():
+ fairy._connection_record._checkin_failed(
+ err,
+ _fairy_was_created=True,
+ )
+
+ # prevent _ConnectionFairy from being carried
+ # in the stack trace. Do this after the
+ # connection record has been checked in, so that
+ # if the del triggers a finalize fairy, it won't
+ # try to checkin a second time.
+ del fairy
+
+ attempts -= 1
+
+ pool.logger.info("Reconnection attempts exhausted on checkout")
+ fairy.invalidate()
+ raise exc.InvalidRequestError("This connection is closed")
+
+ def _checkout_existing(self):
+ return _ConnectionFairy._checkout(self._pool, fairy=self)
+
+ def _checkin(self, reset=True):
+ _finalize_fairy(
+ self.dbapi_connection,
+ self._connection_record,
+ self._pool,
+ None,
+ self._echo,
+ reset=reset,
+ fairy=self,
+ )
+ self.dbapi_connection = None
+ self._connection_record = None
+
+ _close = _checkin
+
+ def _reset(self, pool):
+ if pool.dispatch.reset:
+ pool.dispatch.reset(self, self._connection_record)
+ if pool._reset_on_return is reset_rollback:
+ if self._echo:
+ pool.logger.debug(
+ "Connection %s rollback-on-return", self.dbapi_connection
+ )
+ pool._dialect.do_rollback(self)
+ elif pool._reset_on_return is reset_commit:
+ if self._echo:
+ pool.logger.debug(
+ "Connection %s commit-on-return",
+ self.dbapi_connection,
+ )
+ pool._dialect.do_commit(self)
+
+ @property
+ def _logger(self):
+ return self._pool.logger
+
+ @property
+ def is_valid(self):
+ """Return True if this :class:`._ConnectionFairy` still refers
+ to an active DBAPI connection."""
+
+ return self.dbapi_connection is not None
+
+ @util.memoized_property
+ def info(self):
+ """Info dictionary associated with the underlying DBAPI connection
+ referred to by this :class:`.ConnectionFairy`, allowing user-defined
+ data to be associated with the connection.
+
+ The data here will follow along with the DBAPI connection including
+ after it is returned to the connection pool and used again
+ in subsequent instances of :class:`._ConnectionFairy`. It is shared
+ with the :attr:`._ConnectionRecord.info` and
+ :attr:`_engine.Connection.info`
+ accessors.
+
+ The dictionary associated with a particular DBAPI connection is
+ discarded when the connection itself is discarded.
+
+ """
+ return self._connection_record.info
+
+ @property
+ def record_info(self):
+ """Info dictionary associated with the :class:`._ConnectionRecord
+ container referred to by this :class:`.ConnectionFairy`.
+
+ Unlike the :attr:`._ConnectionFairy.info` dictionary, the lifespan
+ of this dictionary is persistent across connections that are
+ disconnected and/or invalidated within the lifespan of a
+ :class:`._ConnectionRecord`.
+
+ .. versionadded:: 1.1
+
+ """
+ if self._connection_record:
+ return self._connection_record.record_info
+ else:
+ return None
+
+ def invalidate(self, e=None, soft=False):
+ """Mark this connection as invalidated.
+
+ This method can be called directly, and is also called as a result
+ of the :meth:`_engine.Connection.invalidate` method. When invoked,
+ the DBAPI connection is immediately closed and discarded from
+ further use by the pool. The invalidation mechanism proceeds
+ via the :meth:`._ConnectionRecord.invalidate` internal method.
+
+ :param e: an exception object indicating a reason for the invalidation.
+
+ :param soft: if True, the connection isn't closed; instead, this
+ connection will be recycled on next checkout.
+
+ .. versionadded:: 1.0.3
+
+ .. seealso::
+
+ :ref:`pool_connection_invalidation`
+
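+        A minimal sketch (``engine`` is assumed to already exist), using a
+        connection checked out via :meth:`_engine.Engine.raw_connection`::
+
+            fairy = engine.raw_connection()
+
+            # closes the DBAPI connection and returns the record to the pool
+            fairy.invalidate()
+
+            # a soft invalidation instead leaves the connection open and
+            # marks it to be recycled on the next checkout:
+            # fairy.invalidate(soft=True)
+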
+ """
+
+ if self.dbapi_connection is None:
+ util.warn("Can't invalidate an already-closed connection.")
+ return
+ if self._connection_record:
+ self._connection_record.invalidate(e=e, soft=soft)
+ if not soft:
+ self.dbapi_connection = None
+ self._checkin()
+
+ def cursor(self, *args, **kwargs):
+ """Return a new DBAPI cursor for the underlying connection.
+
+ This method is a proxy for the ``connection.cursor()`` DBAPI
+ method.
+
+ """
+ return self.dbapi_connection.cursor(*args, **kwargs)
+
+ def __getattr__(self, key):
+ return getattr(self.dbapi_connection, key)
+
+ def detach(self):
+ """Separate this connection from its Pool.
+
+ This means that the connection will no longer be returned to the
+ pool when closed, and will instead be literally closed. The
+ containing ConnectionRecord is separated from the DB-API connection,
+ and will create a new connection when next used.
+
+ Note that any overall connection limiting constraints imposed by a
+ Pool implementation may be violated after a detach, as the detached
+ connection is removed from the pool's knowledge and control.
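+
+        A sketch of keeping a single DBAPI connection past the pool's control
+        (``engine`` is assumed to already exist)::
+
+            fairy = engine.raw_connection()
+            fairy.detach()   # the pool forgets this connection entirely
+            fairy.close()    # now closes the DBAPI connection directly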
+ """
+
+ if self._connection_record is not None:
+ rec = self._connection_record
+ rec.fairy_ref = None
+ rec.dbapi_connection = None
+ # TODO: should this be _return_conn?
+ self._pool._do_return_conn(self._connection_record)
+ self.info = self.info.copy()
+ self._connection_record = None
+
+ if self._pool.dispatch.detach:
+ self._pool.dispatch.detach(self.dbapi_connection, rec)
+
+ def close(self):
+ self._counter -= 1
+ if self._counter == 0:
+ self._checkin()
+
+ def _close_no_reset(self):
+ self._counter -= 1
+ if self._counter == 0:
+ self._checkin(reset=False)
diff --git a/lib/sqlalchemy/pool/dbapi_proxy.py b/lib/sqlalchemy/pool/dbapi_proxy.py
new file mode 100644
index 0000000..b0c40f2
--- /dev/null
+++ b/lib/sqlalchemy/pool/dbapi_proxy.py
@@ -0,0 +1,147 @@
+# sqlalchemy/pool/dbapi_proxy.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+"""DBAPI proxy utility.
+
+Provides transparent connection pooling on top of a Python DBAPI.
+
+This is legacy SQLAlchemy functionality that is not typically used
+today.
+
+"""
+
+from .impl import QueuePool
+from .. import util
+from ..util import threading
+
+proxies = {}
+
+
+@util.deprecated(
+ "1.3",
+ "The :func:`.pool.manage` function is deprecated, and will be "
+ "removed in a future release.",
+)
+def manage(module, **params):
+ r"""Return a proxy for a DB-API module that automatically
+ pools connections.
+
+ Given a DB-API 2.0 module and pool management parameters, returns
+ a proxy for the module that will automatically pool connections,
+ creating new connection pools for each distinct set of connection
+ arguments sent to the decorated module's connect() function.
+
+ :param module: a DB-API 2.0 database module
+
+ :param poolclass: the class used by the pool module to provide
+ pooling. Defaults to :class:`.QueuePool`.
+
+ :param \**params: will be passed through to *poolclass*
+
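+    A minimal sketch of the (deprecated) usage, assuming the stdlib
+    ``sqlite3`` module as the DBAPI::
+
+        import sqlite3
+
+        from sqlalchemy import pool
+
+        pooled_sqlite3 = pool.manage(sqlite3, pool_size=2)
+        conn = pooled_sqlite3.connect(":memory:")  # checked out from a QueuePool
+        conn.close()                               # returned to the pool
+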
+ """
+ try:
+ return proxies[module]
+ except KeyError:
+ return proxies.setdefault(module, _DBProxy(module, **params))
+
+
+def clear_managers():
+ """Remove all current DB-API 2.0 managers.
+
+ All pools and connections are disposed.
+ """
+
+ for manager in proxies.values():
+ manager.close()
+ proxies.clear()
+
+
+class _DBProxy(object):
+
+ """Layers connection pooling behavior on top of a standard DB-API module.
+
+ Proxies a DB-API 2.0 connect() call to a connection pool keyed to the
+ specific connect parameters. Other functions and attributes are delegated
+ to the underlying DB-API module.
+ """
+
+ def __init__(self, module, poolclass=QueuePool, **kw):
+ """Initializes a new proxy.
+
+ module
+ a DB-API 2.0 module
+
+ poolclass
+ a Pool class, defaulting to QueuePool
+
+ Other parameters are sent to the Pool object's constructor.
+
+ """
+
+ self.module = module
+ self.kw = kw
+ self.poolclass = poolclass
+ self.pools = {}
+ self._create_pool_mutex = threading.Lock()
+
+ def close(self):
+ for key in list(self.pools):
+ del self.pools[key]
+
+ def __del__(self):
+ self.close()
+
+ def __getattr__(self, key):
+ return getattr(self.module, key)
+
+ def get_pool(self, *args, **kw):
+ key = self._serialize(*args, **kw)
+ try:
+ return self.pools[key]
+ except KeyError:
+ with self._create_pool_mutex:
+ if key not in self.pools:
+ kw.pop("sa_pool_key", None)
+ pool = self.poolclass(
+ lambda: self.module.connect(*args, **kw), **self.kw
+ )
+ self.pools[key] = pool
+ return pool
+ else:
+ return self.pools[key]
+
+ def connect(self, *args, **kw):
+ """Activate a connection to the database.
+
+ Connect to the database using this DBProxy's module and the given
+ connect arguments. If the arguments match an existing pool, the
+ connection will be returned from the pool's current thread-local
+ connection instance, or if there is no thread-local connection
+ instance it will be checked out from the set of pooled connections.
+
+ If the pool has no available connections and allows new connections
+ to be created, a new database connection will be made.
+
+ """
+
+ return self.get_pool(*args, **kw).connect()
+
+ def dispose(self, *args, **kw):
+ """Dispose the pool referenced by the given connect arguments."""
+
+ key = self._serialize(*args, **kw)
+ try:
+ del self.pools[key]
+ except KeyError:
+ pass
+
+ def _serialize(self, *args, **kw):
+ if "sa_pool_key" in kw:
+ return kw["sa_pool_key"]
+
+ return tuple(list(args) + [(k, kw[k]) for k in sorted(kw)])
diff --git a/lib/sqlalchemy/pool/events.py b/lib/sqlalchemy/pool/events.py
new file mode 100644
index 0000000..2829a58
--- /dev/null
+++ b/lib/sqlalchemy/pool/events.py
@@ -0,0 +1,284 @@
+# sqlalchemy/pool/events.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .base import Pool
+from .. import event
+from ..engine.base import Engine
+
+
+class PoolEvents(event.Events):
+ """Available events for :class:`_pool.Pool`.
+
+ The methods here define the name of an event as well
+ as the names of members that are passed to listener
+ functions.
+
+ e.g.::
+
+ from sqlalchemy import event
+
+ def my_on_checkout(dbapi_conn, connection_rec, connection_proxy):
+ "handle an on checkout event"
+
+ event.listen(Pool, 'checkout', my_on_checkout)
+
+ In addition to accepting the :class:`_pool.Pool` class and
+ :class:`_pool.Pool` instances, :class:`_events.PoolEvents` also accepts
+ :class:`_engine.Engine` objects and the :class:`_engine.Engine` class as
+ targets, which will be resolved to the ``.pool`` attribute of the
+ given engine or the :class:`_pool.Pool` class::
+
+ engine = create_engine("postgresql://scott:tiger@localhost/test")
+
+ # will associate with engine.pool
+ event.listen(engine, 'checkout', my_on_checkout)
+
+ """
+
+ _target_class_doc = "SomeEngineOrPool"
+ _dispatch_target = Pool
+
+ @classmethod
+ def _accept_with(cls, target):
+ if isinstance(target, type):
+ if issubclass(target, Engine):
+ return Pool
+ elif issubclass(target, Pool):
+ return target
+ elif isinstance(target, Engine):
+ return target.pool
+ elif isinstance(target, Pool):
+ return target
+ elif hasattr(target, "dispatch") and hasattr(
+ target.dispatch._events, "_no_async_engine_events"
+ ):
+ target.dispatch._events._no_async_engine_events()
+ else:
+ return None
+
+ @classmethod
+ def _listen(cls, event_key, **kw):
+ target = event_key.dispatch_target
+
+ kw.setdefault("asyncio", target._is_asyncio)
+
+ event_key.base_listen(**kw)
+
+ def connect(self, dbapi_connection, connection_record):
+ """Called at the moment a particular DBAPI connection is first
+ created for a given :class:`_pool.Pool`.
+
+ This event allows one to capture the point directly after which
+ the DBAPI module-level ``.connect()`` method has been used in order
+ to produce a new DBAPI connection.
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
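+        A sketch of a listener using this hook, here turning on a SQLite
+        pragma for every new connection (the URL and PRAGMA are
+        illustrative)::
+
+            from sqlalchemy import create_engine, event
+
+            engine = create_engine("sqlite://")
+
+            @event.listens_for(engine, "connect")
+            def set_sqlite_pragma(dbapi_connection, connection_record):
+                cursor = dbapi_connection.cursor()
+                cursor.execute("PRAGMA foreign_keys=ON")
+                cursor.close()
+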
+ """
+
+ def first_connect(self, dbapi_connection, connection_record):
+ """Called exactly once for the first time a DBAPI connection is
+ checked out from a particular :class:`_pool.Pool`.
+
+ The rationale for :meth:`_events.PoolEvents.first_connect`
+ is to determine
+ information about a particular series of database connections based
+ on the settings used for all connections. Since a particular
+ :class:`_pool.Pool`
+ refers to a single "creator" function (which in terms
+ of a :class:`_engine.Engine`
+ refers to the URL and connection options used),
+ it is typically valid to make observations about a single connection
+ that can be safely assumed to be valid about all subsequent
+ connections, such as the database version, the server and client
+ encoding settings, collation settings, and many others.
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ """
+
+ def checkout(self, dbapi_connection, connection_record, connection_proxy):
+ """Called when a connection is retrieved from the Pool.
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ :param connection_proxy: the :class:`._ConnectionFairy` object which
+ will proxy the public interface of the DBAPI connection for the
+ lifespan of the checkout.
+
+ If you raise a :class:`~sqlalchemy.exc.DisconnectionError`, the current
+ connection will be disposed and a fresh connection retrieved.
+ Processing of all checkout listeners will abort and restart
+ using the new connection.
+
+ .. seealso:: :meth:`_events.ConnectionEvents.engine_connect`
+ - a similar event
+ which occurs upon creation of a new :class:`_engine.Connection`.
+
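+        A sketch of pessimistic disconnect handling using this hook, where a
+        failed ping raises :class:`~sqlalchemy.exc.DisconnectionError` so the
+        pool retries with a new connection (``SELECT 1`` is illustrative)::
+
+            from sqlalchemy import event, exc
+            from sqlalchemy.pool import Pool
+
+            @event.listens_for(Pool, "checkout")
+            def ping_connection(dbapi_conn, connection_rec, connection_proxy):
+                cursor = dbapi_conn.cursor()
+                try:
+                    cursor.execute("SELECT 1")
+                except Exception:
+                    # the pool discards this connection and tries again
+                    raise exc.DisconnectionError()
+                finally:
+                    cursor.close()
+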
+ """
+
+ def checkin(self, dbapi_connection, connection_record):
+ """Called when a connection returns to the pool.
+
+ Note that the connection may be closed, and may be None if the
+ connection has been invalidated. ``checkin`` will not be called
+ for detached connections. (They do not return to the pool.)
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ """
+
+ def reset(self, dbapi_connection, connection_record):
+ """Called before the "reset" action occurs for a pooled connection.
+
+ This event represents
+ when the ``rollback()`` method is called on the DBAPI connection
+ before it is returned to the pool. The behavior of "reset" can
+ be controlled, including disabled, using the ``reset_on_return``
+ pool argument.
+
+
+ The :meth:`_events.PoolEvents.reset` event is usually followed by the
+        :meth:`_events.PoolEvents.checkin` event, except in those
+ cases where the connection is discarded immediately after reset.
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ .. seealso::
+
+ :meth:`_events.ConnectionEvents.rollback`
+
+ :meth:`_events.ConnectionEvents.commit`
+
+ """
+
+ def invalidate(self, dbapi_connection, connection_record, exception):
+ """Called when a DBAPI connection is to be "invalidated".
+
+ This event is called any time the :meth:`._ConnectionRecord.invalidate`
+ method is invoked, either from API usage or via "auto-invalidation",
+ without the ``soft`` flag.
+
+ The event occurs before a final attempt to call ``.close()`` on the
+ connection occurs.
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ :param exception: the exception object corresponding to the reason
+ for this invalidation, if any. May be ``None``.
+
+ .. versionadded:: 0.9.2 Added support for connection invalidation
+ listening.
+
+ .. seealso::
+
+ :ref:`pool_connection_invalidation`
+
+ """
+
+ def soft_invalidate(self, dbapi_connection, connection_record, exception):
+ """Called when a DBAPI connection is to be "soft invalidated".
+
+ This event is called any time the :meth:`._ConnectionRecord.invalidate`
+ method is invoked with the ``soft`` flag.
+
+ Soft invalidation refers to when the connection record that tracks
+ this connection will force a reconnect after the current connection
+ is checked in. It does not actively close the dbapi_connection
+ at the point at which it is called.
+
+ .. versionadded:: 1.0.3
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ :param exception: the exception object corresponding to the reason
+ for this invalidation, if any. May be ``None``.
+
+ """
+
+ def close(self, dbapi_connection, connection_record):
+ """Called when a DBAPI connection is closed.
+
+ The event is emitted before the close occurs.
+
+ The close of a connection can fail; typically this is because
+ the connection is already closed. If the close operation fails,
+ the connection is discarded.
+
+ The :meth:`.close` event corresponds to a connection that's still
+ associated with the pool. To intercept close events for detached
+ connections use :meth:`.close_detached`.
+
+ .. versionadded:: 1.1
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ """
+
+ def detach(self, dbapi_connection, connection_record):
+ """Called when a DBAPI connection is "detached" from a pool.
+
+ This event is emitted after the detach occurs. The connection
+ is no longer associated with the given connection record.
+
+ .. versionadded:: 1.1
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ """
+
+ def close_detached(self, dbapi_connection):
+ """Called when a detached DBAPI connection is closed.
+
+ The event is emitted before the close occurs.
+
+ The close of a connection can fail; typically this is because
+ the connection is already closed. If the close operation fails,
+ the connection is discarded.
+
+ .. versionadded:: 1.1
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ """
diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py
new file mode 100644
index 0000000..91d0290
--- /dev/null
+++ b/lib/sqlalchemy/pool/impl.py
@@ -0,0 +1,514 @@
+# sqlalchemy/pool.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+"""Pool implementation classes.
+
+"""
+
+import traceback
+import weakref
+
+from .base import _AsyncConnDialect
+from .base import _ConnectionFairy
+from .base import _ConnectionRecord
+from .base import Pool
+from .. import exc
+from .. import util
+from ..util import chop_traceback
+from ..util import queue as sqla_queue
+from ..util import threading
+
+
+class QueuePool(Pool):
+
+ """A :class:`_pool.Pool`
+ that imposes a limit on the number of open connections.
+
+ :class:`.QueuePool` is the default pooling implementation used for
+ all :class:`_engine.Engine` objects, unless the SQLite dialect is in use.
+
+ """
+
+ _is_asyncio = False
+ _queue_class = sqla_queue.Queue
+
+ def __init__(
+ self,
+ creator,
+ pool_size=5,
+ max_overflow=10,
+ timeout=30.0,
+ use_lifo=False,
+ **kw
+ ):
+ r"""
+ Construct a QueuePool.
+
+ :param creator: a callable function that returns a DB-API
+ connection object, same as that of :paramref:`_pool.Pool.creator`.
+
+ :param pool_size: The size of the pool to be maintained,
+ defaults to 5. This is the largest number of connections that
+ will be kept persistently in the pool. Note that the pool
+ begins with no connections; once this number of connections
+ is requested, that number of connections will remain.
+ ``pool_size`` can be set to 0 to indicate no size limit; to
+ disable pooling, use a :class:`~sqlalchemy.pool.NullPool`
+ instead.
+
+ :param max_overflow: The maximum overflow size of the
+ pool. When the number of checked-out connections reaches the
+ size set in pool_size, additional connections will be
+ returned up to this limit. When those additional connections
+ are returned to the pool, they are disconnected and
+ discarded. It follows then that the total number of
+ simultaneous connections the pool will allow is pool_size +
+ `max_overflow`, and the total number of "sleeping"
+ connections the pool will allow is pool_size. `max_overflow`
+ can be set to -1 to indicate no overflow limit; no limit
+ will be placed on the total number of concurrent
+ connections. Defaults to 10.
+
+ :param timeout: The number of seconds to wait before giving up
+ on returning a connection. Defaults to 30.0. This can be a float
+ but is subject to the limitations of Python time functions which
+ may not be reliable in the tens of milliseconds.
+
+ :param use_lifo: use LIFO (last-in-first-out) when retrieving
+ connections instead of FIFO (first-in-first-out). Using LIFO, a
+ server-side timeout scheme can reduce the number of connections used
+ during non-peak periods of use. When planning for server-side
+ timeouts, ensure that a recycle or pre-ping strategy is in use to
+ gracefully handle stale connections.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :ref:`pool_use_lifo`
+
+ :ref:`pool_disconnects`
+
+ :param \**kw: Other keyword arguments including
+ :paramref:`_pool.Pool.recycle`, :paramref:`_pool.Pool.echo`,
+ :paramref:`_pool.Pool.reset_on_return` and others are passed to the
+ :class:`_pool.Pool` constructor.
+
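+        A minimal sketch constructing a standalone :class:`.QueuePool` (the
+        creator callable and database file are hypothetical)::
+
+            import sqlite3
+
+            from sqlalchemy.pool import QueuePool
+
+            pool = QueuePool(
+                lambda: sqlite3.connect("file.db"),
+                pool_size=5,
+                max_overflow=10,
+                timeout=30,
+            )
+            conn = pool.connect()  # checks a connection out
+            conn.close()           # returns it to the pool
+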
+ """
+ Pool.__init__(self, creator, **kw)
+ self._pool = self._queue_class(pool_size, use_lifo=use_lifo)
+ self._overflow = 0 - pool_size
+ self._max_overflow = max_overflow
+ self._timeout = timeout
+ self._overflow_lock = threading.Lock()
+
+ def _do_return_conn(self, conn):
+ try:
+ self._pool.put(conn, False)
+ except sqla_queue.Full:
+ try:
+ conn.close()
+ finally:
+ self._dec_overflow()
+
+ def _do_get(self):
+ use_overflow = self._max_overflow > -1
+
+ try:
+ wait = use_overflow and self._overflow >= self._max_overflow
+ return self._pool.get(wait, self._timeout)
+ except sqla_queue.Empty:
+ # don't do things inside of "except Empty", because when we say
+ # we timed out or can't connect and raise, Python 3 tells
+ # people the real error is queue.Empty which it isn't.
+ pass
+ if use_overflow and self._overflow >= self._max_overflow:
+ if not wait:
+ return self._do_get()
+ else:
+ raise exc.TimeoutError(
+ "QueuePool limit of size %d overflow %d reached, "
+ "connection timed out, timeout %0.2f"
+ % (self.size(), self.overflow(), self._timeout),
+ code="3o7r",
+ )
+
+ if self._inc_overflow():
+ try:
+ return self._create_connection()
+ except:
+ with util.safe_reraise():
+ self._dec_overflow()
+ else:
+ return self._do_get()
+
+ def _inc_overflow(self):
+ if self._max_overflow == -1:
+ self._overflow += 1
+ return True
+ with self._overflow_lock:
+ if self._overflow < self._max_overflow:
+ self._overflow += 1
+ return True
+ else:
+ return False
+
+ def _dec_overflow(self):
+ if self._max_overflow == -1:
+ self._overflow -= 1
+ return True
+ with self._overflow_lock:
+ self._overflow -= 1
+ return True
+
+ def recreate(self):
+ self.logger.info("Pool recreating")
+ return self.__class__(
+ self._creator,
+ pool_size=self._pool.maxsize,
+ max_overflow=self._max_overflow,
+ pre_ping=self._pre_ping,
+ use_lifo=self._pool.use_lifo,
+ timeout=self._timeout,
+ recycle=self._recycle,
+ echo=self.echo,
+ logging_name=self._orig_logging_name,
+ reset_on_return=self._reset_on_return,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
+
+ def dispose(self):
+ while True:
+ try:
+ conn = self._pool.get(False)
+ conn.close()
+ except sqla_queue.Empty:
+ break
+
+ self._overflow = 0 - self.size()
+ self.logger.info("Pool disposed. %s", self.status())
+
+ def status(self):
+ return (
+ "Pool size: %d Connections in pool: %d "
+ "Current Overflow: %d Current Checked out "
+ "connections: %d"
+ % (
+ self.size(),
+ self.checkedin(),
+ self.overflow(),
+ self.checkedout(),
+ )
+ )
+
+ def size(self):
+ return self._pool.maxsize
+
+ def timeout(self):
+ return self._timeout
+
+ def checkedin(self):
+ return self._pool.qsize()
+
+ def overflow(self):
+ return self._overflow
+
+ def checkedout(self):
+ return self._pool.maxsize - self._pool.qsize() + self._overflow
+
+
+class AsyncAdaptedQueuePool(QueuePool):
+ _is_asyncio = True
+ _queue_class = sqla_queue.AsyncAdaptedQueue
+ _dialect = _AsyncConnDialect()
+
+
+class FallbackAsyncAdaptedQueuePool(AsyncAdaptedQueuePool):
+ _queue_class = sqla_queue.FallbackAsyncAdaptedQueue
+
+
+class NullPool(Pool):
+
+ """A Pool which does not pool connections.
+
+ Instead it literally opens and closes the underlying DB-API connection
+ per each connection open/close.
+
+ Reconnect-related functions such as ``recycle`` and connection
+ invalidation are not supported by this Pool implementation, since
+ no connections are held persistently.
+
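+    A sketch of selecting this pool when creating an engine (the URL is
+    illustrative)::
+
+        from sqlalchemy import create_engine
+        from sqlalchemy.pool import NullPool
+
+        engine = create_engine("sqlite:///some.db", poolclass=NullPool)
+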
+ """
+
+ def status(self):
+ return "NullPool"
+
+ def _do_return_conn(self, conn):
+ conn.close()
+
+ def _do_get(self):
+ return self._create_connection()
+
+ def recreate(self):
+ self.logger.info("Pool recreating")
+
+ return self.__class__(
+ self._creator,
+ recycle=self._recycle,
+ echo=self.echo,
+ logging_name=self._orig_logging_name,
+ reset_on_return=self._reset_on_return,
+ pre_ping=self._pre_ping,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
+
+ def dispose(self):
+ pass
+
+
+class SingletonThreadPool(Pool):
+
+ """A Pool that maintains one connection per thread.
+
+ Maintains one connection per each thread, never moving a connection to a
+ thread other than the one which it was created in.
+
+ .. warning:: the :class:`.SingletonThreadPool` will call ``.close()``
+ on arbitrary connections that exist beyond the size setting of
+ ``pool_size``, e.g. if more unique **thread identities**
+ than what ``pool_size`` states are used. This cleanup is
+ non-deterministic and not sensitive to whether or not the connections
+ linked to those thread identities are currently in use.
+
+ :class:`.SingletonThreadPool` may be improved in a future release,
+ however in its current status it is generally used only for test
+ scenarios using a SQLite ``:memory:`` database and is not recommended
+ for production use.
+
+
+ Options are the same as those of :class:`_pool.Pool`, as well as:
+
+ :param pool_size: The number of threads in which to maintain connections
+ at once. Defaults to five.
+
+ :class:`.SingletonThreadPool` is used by the SQLite dialect
+ automatically when a memory-based database is used.
+ See :ref:`sqlite_toplevel`.
+
+ """
+
+ _is_asyncio = False
+
+ def __init__(self, creator, pool_size=5, **kw):
+ Pool.__init__(self, creator, **kw)
+ self._conn = threading.local()
+ self._fairy = threading.local()
+ self._all_conns = set()
+ self.size = pool_size
+
+ def recreate(self):
+ self.logger.info("Pool recreating")
+ return self.__class__(
+ self._creator,
+ pool_size=self.size,
+ recycle=self._recycle,
+ echo=self.echo,
+ pre_ping=self._pre_ping,
+ logging_name=self._orig_logging_name,
+ reset_on_return=self._reset_on_return,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
+
+ def dispose(self):
+ """Dispose of this pool."""
+
+ for conn in self._all_conns:
+ try:
+ conn.close()
+ except Exception:
+ # pysqlite won't even let you close a conn from a thread
+ # that didn't create it
+ pass
+
+ self._all_conns.clear()
+
+ def _cleanup(self):
+ while len(self._all_conns) >= self.size:
+ c = self._all_conns.pop()
+ c.close()
+
+ def status(self):
+ return "SingletonThreadPool id:%d size: %d" % (
+ id(self),
+ len(self._all_conns),
+ )
+
+ def _do_return_conn(self, conn):
+ pass
+
+ def _do_get(self):
+ try:
+ c = self._conn.current()
+ if c:
+ return c
+ except AttributeError:
+ pass
+ c = self._create_connection()
+ self._conn.current = weakref.ref(c)
+ if len(self._all_conns) >= self.size:
+ self._cleanup()
+ self._all_conns.add(c)
+ return c
+
+ def connect(self):
+ # vendored from Pool to include the now removed use_threadlocal
+ # behavior
+ try:
+ rec = self._fairy.current()
+ except AttributeError:
+ pass
+ else:
+ if rec is not None:
+ return rec._checkout_existing()
+
+ return _ConnectionFairy._checkout(self, self._fairy)
+
+ def _return_conn(self, record):
+ try:
+ del self._fairy.current
+ except AttributeError:
+ pass
+ self._do_return_conn(record)
+
+
+class StaticPool(Pool):
+
+ """A Pool of exactly one connection, used for all requests.
+
+ Reconnect-related functions such as ``recycle`` and connection
+ invalidation (which is also used to support auto-reconnect) are only
+ partially supported right now and may not yield good results.
+
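+    A sketch of selecting this pool, typically with an in-memory SQLite
+    database that should be shared by the whole engine (the URL is
+    illustrative)::
+
+        from sqlalchemy import create_engine
+        from sqlalchemy.pool import StaticPool
+
+        engine = create_engine("sqlite://", poolclass=StaticPool)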
+
+ """
+
+ @util.memoized_property
+ def connection(self):
+ return _ConnectionRecord(self)
+
+ def status(self):
+ return "StaticPool"
+
+ def dispose(self):
+ if (
+ "connection" in self.__dict__
+ and self.connection.dbapi_connection is not None
+ ):
+ self.connection.close()
+ del self.__dict__["connection"]
+
+ def recreate(self):
+ self.logger.info("Pool recreating")
+ return self.__class__(
+ creator=self._creator,
+ recycle=self._recycle,
+ reset_on_return=self._reset_on_return,
+ pre_ping=self._pre_ping,
+ echo=self.echo,
+ logging_name=self._orig_logging_name,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
+
+ def _transfer_from(self, other_static_pool):
+ # used by the test suite to make a new engine / pool without
+ # losing the state of an existing SQLite :memory: connection
+ self._invoke_creator = (
+ lambda crec: other_static_pool.connection.dbapi_connection
+ )
+
+ def _create_connection(self):
+ raise NotImplementedError()
+
+ def _do_return_conn(self, conn):
+ pass
+
+ def _do_get(self):
+ rec = self.connection
+ if rec._is_hard_or_soft_invalidated():
+ del self.__dict__["connection"]
+ rec = self.connection
+
+ return rec
+
+
+class AssertionPool(Pool):
+
+ """A :class:`_pool.Pool` that allows at most one checked out connection at
+ any given time.
+
+ This will raise an exception if more than one connection is checked out
+ at a time. Useful for debugging code that is using more connections
+ than desired.
+
+ """
+
+ def __init__(self, *args, **kw):
+ self._conn = None
+ self._checked_out = False
+ self._store_traceback = kw.pop("store_traceback", True)
+ self._checkout_traceback = None
+ Pool.__init__(self, *args, **kw)
+
+ def status(self):
+ return "AssertionPool"
+
+ def _do_return_conn(self, conn):
+ if not self._checked_out:
+ raise AssertionError("connection is not checked out")
+ self._checked_out = False
+ assert conn is self._conn
+
+ def dispose(self):
+ self._checked_out = False
+ if self._conn:
+ self._conn.close()
+
+ def recreate(self):
+ self.logger.info("Pool recreating")
+ return self.__class__(
+ self._creator,
+ echo=self.echo,
+ pre_ping=self._pre_ping,
+ recycle=self._recycle,
+ reset_on_return=self._reset_on_return,
+ logging_name=self._orig_logging_name,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
+
+ def _do_get(self):
+ if self._checked_out:
+ if self._checkout_traceback:
+ suffix = " at:\n%s" % "".join(
+ chop_traceback(self._checkout_traceback)
+ )
+ else:
+ suffix = ""
+ raise AssertionError("connection is already checked out" + suffix)
+
+ if not self._conn:
+ self._conn = self._create_connection()
+
+ self._checked_out = True
+ if self._store_traceback:
+ self._checkout_traceback = traceback.format_stack()
+ return self._conn
diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py
new file mode 100644
index 0000000..e7f388f
--- /dev/null
+++ b/lib/sqlalchemy/processors.py
@@ -0,0 +1,176 @@
+# sqlalchemy/processors.py
+# Copyright (C) 2010-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""defines generic type conversion functions, as used in bind and result
+processors.
+
+They all share one common characteristic: None is passed through unchanged.
+
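+For illustration (a sketch; whether the C or the pure-Python implementations
+are in use depends on the optional C extension), the datetime processor
+behaves as::
+
+    from sqlalchemy import processors
+
+    processors.str_to_datetime("2021-03-02 10:30:15.123456")
+    # -> datetime.datetime(2021, 3, 2, 10, 30, 15, 123456)
+    processors.str_to_datetime(None)
+    # -> None
+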
+"""
+
+import codecs
+import datetime
+import re
+
+from . import util
+
+
+def str_to_datetime_processor_factory(regexp, type_):
+ rmatch = regexp.match
+    # Even on Python 2.6, datetime.strptime is both slower than this code
+    # and does not support microseconds.
+ has_named_groups = bool(regexp.groupindex)
+
+ def process(value):
+ if value is None:
+ return None
+ else:
+ try:
+ m = rmatch(value)
+ except TypeError as err:
+ util.raise_(
+ ValueError(
+ "Couldn't parse %s string '%r' "
+ "- value is not a string." % (type_.__name__, value)
+ ),
+ from_=err,
+ )
+ if m is None:
+ raise ValueError(
+ "Couldn't parse %s string: "
+ "'%s'" % (type_.__name__, value)
+ )
+ if has_named_groups:
+ groups = m.groupdict(0)
+ return type_(
+ **dict(
+ list(
+ zip(
+ iter(groups.keys()),
+ list(map(int, iter(groups.values()))),
+ )
+ )
+ )
+ )
+ else:
+ return type_(*list(map(int, m.groups(0))))
+
+ return process
+
+
+def py_fallback():
+ def to_unicode_processor_factory(encoding, errors=None):
+ decoder = codecs.getdecoder(encoding)
+
+ def process(value):
+ if value is None:
+ return None
+ else:
+ # decoder returns a tuple: (value, len). Simply dropping the
+ # len part is safe: it is done that way in the normal
+ # 'xx'.decode(encoding) code path.
+ return decoder(value, errors)[0]
+
+ return process
+
+ def to_conditional_unicode_processor_factory(encoding, errors=None):
+ decoder = codecs.getdecoder(encoding)
+
+ def process(value):
+ if value is None:
+ return None
+ elif isinstance(value, util.text_type):
+ return value
+ else:
+ # decoder returns a tuple: (value, len). Simply dropping the
+ # len part is safe: it is done that way in the normal
+ # 'xx'.decode(encoding) code path.
+ return decoder(value, errors)[0]
+
+ return process
+
+ def to_decimal_processor_factory(target_class, scale):
+ fstring = "%%.%df" % scale
+
+ def process(value):
+ if value is None:
+ return None
+ else:
+ return target_class(fstring % value)
+
+ return process
+
+ def to_float(value): # noqa
+ if value is None:
+ return None
+ else:
+ return float(value)
+
+ def to_str(value): # noqa
+ if value is None:
+ return None
+ else:
+ return str(value)
+
+ def int_to_boolean(value): # noqa
+ if value is None:
+ return None
+ else:
+ return bool(value)
+
+ DATETIME_RE = re.compile(
+ r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?"
+ )
+ TIME_RE = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
+ DATE_RE = re.compile(r"(\d+)-(\d+)-(\d+)")
+
+ str_to_datetime = str_to_datetime_processor_factory( # noqa
+ DATETIME_RE, datetime.datetime
+ )
+ str_to_time = str_to_datetime_processor_factory( # noqa
+ TIME_RE, datetime.time
+ ) # noqa
+ str_to_date = str_to_datetime_processor_factory( # noqa
+ DATE_RE, datetime.date
+ ) # noqa
+ return locals()
+
+
+try:
+ from sqlalchemy.cprocessors import DecimalResultProcessor # noqa
+ from sqlalchemy.cprocessors import int_to_boolean # noqa
+ from sqlalchemy.cprocessors import str_to_date # noqa
+ from sqlalchemy.cprocessors import str_to_datetime # noqa
+ from sqlalchemy.cprocessors import str_to_time # noqa
+ from sqlalchemy.cprocessors import to_float # noqa
+ from sqlalchemy.cprocessors import to_str # noqa
+ from sqlalchemy.cprocessors import UnicodeResultProcessor # noqa
+
+ def to_unicode_processor_factory(encoding, errors=None):
+ if errors is not None:
+ return UnicodeResultProcessor(encoding, errors).process
+ else:
+ return UnicodeResultProcessor(encoding).process
+
+ def to_conditional_unicode_processor_factory(encoding, errors=None):
+ if errors is not None:
+ return UnicodeResultProcessor(encoding, errors).conditional_process
+ else:
+ return UnicodeResultProcessor(encoding).conditional_process
+
+ def to_decimal_processor_factory(target_class, scale):
+ # Note that the scale argument is not taken into account for integer
+ # values in the C implementation while it is in the Python one.
+ # For example, the Python implementation might return
+ # Decimal('5.00000') whereas the C implementation will
+ # return Decimal('5'). These are equivalent of course.
+ return DecimalResultProcessor(target_class, "%%.%df" % scale).process
+
+
+except ImportError:
+ globals().update(py_fallback())
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
new file mode 100644
index 0000000..61f82bb
--- /dev/null
+++ b/lib/sqlalchemy/schema.py
@@ -0,0 +1,59 @@
+# schema.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Compatibility namespace for sqlalchemy.sql.schema and related.
+
+"""
+
+from .sql.base import SchemaVisitor # noqa
+from .sql.ddl import _CreateDropBase # noqa
+from .sql.ddl import _DDLCompiles # noqa
+from .sql.ddl import _DropView # noqa
+from .sql.ddl import AddConstraint # noqa
+from .sql.ddl import CreateColumn # noqa
+from .sql.ddl import CreateIndex # noqa
+from .sql.ddl import CreateSchema # noqa
+from .sql.ddl import CreateSequence # noqa
+from .sql.ddl import CreateTable # noqa
+from .sql.ddl import DDL # noqa
+from .sql.ddl import DDLBase # noqa
+from .sql.ddl import DDLElement # noqa
+from .sql.ddl import DropColumnComment # noqa
+from .sql.ddl import DropConstraint # noqa
+from .sql.ddl import DropIndex # noqa
+from .sql.ddl import DropSchema # noqa
+from .sql.ddl import DropSequence # noqa
+from .sql.ddl import DropTable # noqa
+from .sql.ddl import DropTableComment # noqa
+from .sql.ddl import SetColumnComment # noqa
+from .sql.ddl import SetTableComment # noqa
+from .sql.ddl import sort_tables # noqa
+from .sql.ddl import sort_tables_and_constraints # noqa
+from .sql.naming import conv # noqa
+from .sql.schema import _get_table_key # noqa
+from .sql.schema import BLANK_SCHEMA # noqa
+from .sql.schema import CheckConstraint # noqa
+from .sql.schema import Column # noqa
+from .sql.schema import ColumnCollectionConstraint # noqa
+from .sql.schema import ColumnCollectionMixin # noqa
+from .sql.schema import ColumnDefault # noqa
+from .sql.schema import Computed # noqa
+from .sql.schema import Constraint # noqa
+from .sql.schema import DefaultClause # noqa
+from .sql.schema import DefaultGenerator # noqa
+from .sql.schema import FetchedValue # noqa
+from .sql.schema import ForeignKey # noqa
+from .sql.schema import ForeignKeyConstraint # noqa
+from .sql.schema import Identity # noqa
+from .sql.schema import Index # noqa
+from .sql.schema import MetaData # noqa
+from .sql.schema import PrimaryKeyConstraint # noqa
+from .sql.schema import SchemaItem # noqa
+from .sql.schema import Sequence # noqa
+from .sql.schema import Table # noqa
+from .sql.schema import ThreadLocalMetaData # noqa
+from .sql.schema import UniqueConstraint # noqa
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
new file mode 100644
index 0000000..2677441
--- /dev/null
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -0,0 +1,150 @@
+# sql/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .base import Executable
+from .compiler import COLLECT_CARTESIAN_PRODUCTS
+from .compiler import FROM_LINTING
+from .compiler import NO_LINTING
+from .compiler import WARN_LINTING
+from .expression import Alias
+from .expression import alias
+from .expression import all_
+from .expression import and_
+from .expression import any_
+from .expression import asc
+from .expression import between
+from .expression import bindparam
+from .expression import case
+from .expression import cast
+from .expression import ClauseElement
+from .expression import collate
+from .expression import column
+from .expression import ColumnCollection
+from .expression import ColumnElement
+from .expression import CompoundSelect
+from .expression import cte
+from .expression import Delete
+from .expression import delete
+from .expression import desc
+from .expression import distinct
+from .expression import except_
+from .expression import except_all
+from .expression import exists
+from .expression import extract
+from .expression import false
+from .expression import False_
+from .expression import FromClause
+from .expression import func
+from .expression import funcfilter
+from .expression import Insert
+from .expression import insert
+from .expression import intersect
+from .expression import intersect_all
+from .expression import Join
+from .expression import join
+from .expression import label
+from .expression import LABEL_STYLE_DEFAULT
+from .expression import LABEL_STYLE_DISAMBIGUATE_ONLY
+from .expression import LABEL_STYLE_NONE
+from .expression import LABEL_STYLE_TABLENAME_PLUS_COL
+from .expression import lambda_stmt
+from .expression import LambdaElement
+from .expression import lateral
+from .expression import literal
+from .expression import literal_column
+from .expression import modifier
+from .expression import not_
+from .expression import null
+from .expression import nulls_first
+from .expression import nulls_last
+from .expression import nullsfirst
+from .expression import nullslast
+from .expression import or_
+from .expression import outerjoin
+from .expression import outparam
+from .expression import over
+from .expression import quoted_name
+from .expression import Select
+from .expression import select
+from .expression import Selectable
+from .expression import StatementLambdaElement
+from .expression import Subquery
+from .expression import subquery
+from .expression import table
+from .expression import TableClause
+from .expression import TableSample
+from .expression import tablesample
+from .expression import text
+from .expression import true
+from .expression import True_
+from .expression import tuple_
+from .expression import type_coerce
+from .expression import union
+from .expression import union_all
+from .expression import Update
+from .expression import update
+from .expression import Values
+from .expression import values
+from .expression import within_group
+from .visitors import ClauseVisitor
+
+
+def __go(lcls):
+ global __all__
+ from .. import util as _sa_util
+
+ import inspect as _inspect
+
+ __all__ = sorted(
+ name
+ for name, obj in lcls.items()
+ if not (name.startswith("_") or _inspect.ismodule(obj))
+ )
+
+ from .annotation import _prepare_annotations
+ from .annotation import Annotated
+ from .elements import AnnotatedColumnElement
+ from .elements import ClauseList
+ from .selectable import AnnotatedFromClause
+
+ # from .traversals import _preconfigure_traversals
+
+ from . import base
+ from . import coercions
+ from . import elements
+ from . import events
+ from . import lambdas
+ from . import selectable
+ from . import schema
+ from . import sqltypes
+ from . import traversals
+ from . import type_api
+
+ base.coercions = elements.coercions = coercions
+ base.elements = elements
+ base.type_api = type_api
+ coercions.elements = elements
+ coercions.lambdas = lambdas
+ coercions.schema = schema
+ coercions.selectable = selectable
+ coercions.sqltypes = sqltypes
+ coercions.traversals = traversals
+
+ _prepare_annotations(ColumnElement, AnnotatedColumnElement)
+ _prepare_annotations(FromClause, AnnotatedFromClause)
+ _prepare_annotations(ClauseList, Annotated)
+
+ # this is expensive at import time; elements that are used can create
+ # their traversals on demand
+ # _preconfigure_traversals(ClauseElement)
+
+ _sa_util.preloaded.import_prefix("sqlalchemy.sql")
+
+ from . import naming
+
+
+__go(locals())
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
new file mode 100644
index 0000000..5c000ed
--- /dev/null
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -0,0 +1,364 @@
+# sql/annotation.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""The :class:`.Annotated` class and related routines; creates hash-equivalent
+copies of SQL constructs which contain context-specific markers and
+associations.
+
+"""
+
+from . import operators
+from .base import HasCacheKey
+from .traversals import anon_map
+from .visitors import InternalTraversal
+from .. import util
+
+EMPTY_ANNOTATIONS = util.immutabledict()
+
+
+class SupportsAnnotations(object):
+ _annotations = EMPTY_ANNOTATIONS
+
+ @util.memoized_property
+ def _annotations_cache_key(self):
+ anon_map_ = anon_map()
+ return (
+ "_annotations",
+ tuple(
+ (
+ key,
+ value._gen_cache_key(anon_map_, [])
+ if isinstance(value, HasCacheKey)
+ else value,
+ )
+ for key, value in [
+ (key, self._annotations[key])
+ for key in sorted(self._annotations)
+ ]
+ ),
+ )
+
+
+class SupportsCloneAnnotations(SupportsAnnotations):
+
+ _clone_annotations_traverse_internals = [
+ ("_annotations", InternalTraversal.dp_annotations_key)
+ ]
+
+ def _annotate(self, values):
+ """return a copy of this ClauseElement with annotations
+ updated by the given dictionary.
+
+ """
+ new = self._clone()
+ new._annotations = new._annotations.union(values)
+ new.__dict__.pop("_annotations_cache_key", None)
+ new.__dict__.pop("_generate_cache_key", None)
+ return new
+
+ def _with_annotations(self, values):
+ """return a copy of this ClauseElement with annotations
+ replaced by the given dictionary.
+
+ """
+ new = self._clone()
+ new._annotations = util.immutabledict(values)
+ new.__dict__.pop("_annotations_cache_key", None)
+ new.__dict__.pop("_generate_cache_key", None)
+ return new
+
+ def _deannotate(self, values=None, clone=False):
+ """return a copy of this :class:`_expression.ClauseElement`
+ with annotations
+ removed.
+
+ :param values: optional tuple of individual values
+ to remove.
+
+ """
+ if clone or self._annotations:
+ # clone is used when we are also copying
+ # the expression for a deep deannotation
+ new = self._clone()
+ new._annotations = util.immutabledict()
+ new.__dict__.pop("_annotations_cache_key", None)
+ return new
+ else:
+ return self
+
+
+class SupportsWrappingAnnotations(SupportsAnnotations):
+ def _annotate(self, values):
+ """return a copy of this ClauseElement with annotations
+ updated by the given dictionary.
+
+ """
+ return Annotated(self, values)
+
+ def _with_annotations(self, values):
+ """return a copy of this ClauseElement with annotations
+ replaced by the given dictionary.
+
+ """
+ return Annotated(self, values)
+
+ def _deannotate(self, values=None, clone=False):
+ """return a copy of this :class:`_expression.ClauseElement`
+ with annotations
+ removed.
+
+ :param values: optional tuple of individual values
+ to remove.
+
+ """
+ if clone:
+ s = self._clone()
+ return s
+ else:
+ return self
+
+
+class Annotated(object):
+ """clones a SupportsAnnotated and applies an 'annotations' dictionary.
+
+ Unlike regular clones, this clone also mimics __hash__() and
+ __cmp__() of the original element so that it takes its place
+ in hashed collections.
+
+ A reference to the original element is maintained, for the important
+ reason of keeping its hash value current. When GC'ed, the
+ hash value may be reused, causing conflicts.
+
+ .. note:: The rationale for Annotated producing a brand new class,
+ rather than placing the functionality directly within ClauseElement,
+ is **performance**. The __hash__() method is absent on plain
+ ClauseElement which leads to significantly reduced function call
+ overhead, as the use of sets and dictionaries against ClauseElement
+ objects is prevalent, but most are not "annotated".
+
+ """
+
+ _is_column_operators = False
+
+ def __new__(cls, *args):
+ if not args:
+ # clone constructor
+ return object.__new__(cls)
+ else:
+ element, values = args
+ # pull appropriate subclass from registry of annotated
+ # classes
+ try:
+ cls = annotated_classes[element.__class__]
+ except KeyError:
+ cls = _new_annotation_type(element.__class__, cls)
+ return object.__new__(cls)
+
+ def __init__(self, element, values):
+ self.__dict__ = element.__dict__.copy()
+ self.__dict__.pop("_annotations_cache_key", None)
+ self.__dict__.pop("_generate_cache_key", None)
+ self.__element = element
+ self._annotations = util.immutabledict(values)
+ self._hash = hash(element)
+
+ def _annotate(self, values):
+ _values = self._annotations.union(values)
+ return self._with_annotations(_values)
+
+ def _with_annotations(self, values):
+ clone = self.__class__.__new__(self.__class__)
+ clone.__dict__ = self.__dict__.copy()
+ clone.__dict__.pop("_annotations_cache_key", None)
+ clone.__dict__.pop("_generate_cache_key", None)
+ clone._annotations = values
+ return clone
+
+ def _deannotate(self, values=None, clone=True):
+ if values is None:
+ return self.__element
+ else:
+ return self._with_annotations(
+ util.immutabledict(
+ {
+ key: value
+ for key, value in self._annotations.items()
+ if key not in values
+ }
+ )
+ )
+
+ def _compiler_dispatch(self, visitor, **kw):
+ return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
+
+ @property
+ def _constructor(self):
+ return self.__element._constructor
+
+ def _clone(self, **kw):
+ clone = self.__element._clone(**kw)
+ if clone is self.__element:
+ # detect immutable, don't change anything
+ return self
+ else:
+ # update the clone with any changes that have occurred
+ # to this object's __dict__.
+ clone.__dict__.update(self.__dict__)
+ return self.__class__(clone, self._annotations)
+
+ def __reduce__(self):
+ return self.__class__, (self.__element, self._annotations)
+
+ def __hash__(self):
+ return self._hash
+
+ def __eq__(self, other):
+ if self._is_column_operators:
+ return self.__element.__class__.__eq__(self, other)
+ else:
+ return hash(other) == hash(self)
+
+ @property
+ def entity_namespace(self):
+ if "entity_namespace" in self._annotations:
+ return self._annotations["entity_namespace"].entity_namespace
+ else:
+ return self.__element.entity_namespace
+
+
+# hard-generate Annotated subclasses. this technique
+# is used instead of on-the-fly types (i.e. type.__new__())
+# so that the resulting objects are pickleable; additionally, other
+# decisions can be made up front about the type of object being annotated
+# just once per class rather than per-instance.
+annotated_classes = {}
+
+
+def _deep_annotate(
+ element, annotations, exclude=None, detect_subquery_cols=False
+):
+ """Deep copy the given ClauseElement, annotating each element
+ with the given annotations dictionary.
+
+ Elements within the exclude collection will be cloned but not annotated.
+
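+ A minimal usage sketch (``stmt`` being a hypothetical ClauseElement)::
+
+ annotated_stmt = _deep_annotate(stmt, {"_some_marker": True})
+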
+ """
+
+ # annotated objects hack the __hash__() method so if we want to
+ # uniquely process them we have to use id()
+
+ cloned_ids = {}
+
+ def clone(elem, **kw):
+ kw["detect_subquery_cols"] = detect_subquery_cols
+ id_ = id(elem)
+
+ if id_ in cloned_ids:
+ return cloned_ids[id_]
+
+ if (
+ exclude
+ and hasattr(elem, "proxy_set")
+ and elem.proxy_set.intersection(exclude)
+ ):
+ newelem = elem._clone(clone=clone, **kw)
+ elif annotations != elem._annotations:
+ if detect_subquery_cols and elem._is_immutable:
+ newelem = elem._clone(clone=clone, **kw)._annotate(annotations)
+ else:
+ newelem = elem._annotate(annotations)
+ else:
+ newelem = elem
+ newelem._copy_internals(clone=clone)
+ cloned_ids[id_] = newelem
+ return newelem
+
+ if element is not None:
+ element = clone(element)
+ clone = None # remove gc cycles
+ return element
+
+
+def _deep_deannotate(element, values=None):
+ """Deep copy the given element, removing annotations."""
+
+ cloned = {}
+
+ def clone(elem, **kw):
+ if values:
+ key = id(elem)
+ else:
+ key = elem
+
+ if key not in cloned:
+ newelem = elem._deannotate(values=values, clone=True)
+ newelem._copy_internals(clone=clone)
+ cloned[key] = newelem
+ return newelem
+ else:
+ return cloned[key]
+
+ if element is not None:
+ element = clone(element)
+ clone = None # remove gc cycles
+ return element
+
+
+def _shallow_annotate(element, annotations):
+ """Annotate the given ClauseElement and copy its internals so that
+ internal objects refer to the new annotated object.
+
+ Basically used to apply a "don't traverse" annotation to a
+ selectable, without having to traverse the entire
+ structure and waste time.
+ """
+ element = element._annotate(annotations)
+ element._copy_internals()
+ return element
+
+
+def _new_annotation_type(cls, base_cls):
+ if issubclass(cls, Annotated):
+ return cls
+ elif cls in annotated_classes:
+ return annotated_classes[cls]
+
+ for super_ in cls.__mro__:
+ # check if an Annotated subclass more specific than
+ # the given base_cls is already registered, such
+ # as AnnotatedColumnElement.
+ if super_ in annotated_classes:
+ base_cls = annotated_classes[super_]
+ break
+
+ annotated_classes[cls] = anno_cls = type(
+ "Annotated%s" % cls.__name__, (base_cls, cls), {}
+ )
+ globals()["Annotated%s" % cls.__name__] = anno_cls
+
+ if "_traverse_internals" in cls.__dict__:
+ anno_cls._traverse_internals = list(cls._traverse_internals) + [
+ ("_annotations", InternalTraversal.dp_annotations_key)
+ ]
+ elif cls.__dict__.get("inherit_cache", False):
+ anno_cls._traverse_internals = list(cls._traverse_internals) + [
+ ("_annotations", InternalTraversal.dp_annotations_key)
+ ]
+
+ # some classes include this even if they have traverse_internals
+ # e.g. BindParameter, add it if present.
+ if cls.__dict__.get("inherit_cache", False):
+ anno_cls.inherit_cache = True
+
+ anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators)
+
+ return anno_cls
+
+
+def _prepare_annotations(target_hierarchy, base_cls):
+ for cls in util.walk_subclasses(target_hierarchy):
+ _new_annotation_type(cls, base_cls)
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
new file mode 100644
index 0000000..ec685d1
--- /dev/null
+++ b/lib/sqlalchemy/sql/base.py
@@ -0,0 +1,1702 @@
+# sql/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Foundational utilities common to many sql modules.
+
+"""
+
+
+import itertools
+import operator
+import re
+
+from . import roles
+from . import visitors
+from .traversals import HasCacheKey # noqa
+from .traversals import HasCopyInternals # noqa
+from .traversals import MemoizedHasCacheKey # noqa
+from .visitors import ClauseVisitor
+from .visitors import ExtendedInternalTraversal
+from .visitors import InternalTraversal
+from .. import exc
+from .. import util
+from ..util import HasMemoized
+from ..util import hybridmethod
+
+
+coercions = None
+elements = None
+type_api = None
+
+PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT")
+NO_ARG = util.symbol("NO_ARG")
+
+
+class Immutable(object):
+ """mark a ClauseElement as 'immutable' when expressions are cloned."""
+
+ _is_immutable = True
+
+ def unique_params(self, *optionaldict, **kwargs):
+ raise NotImplementedError("Immutable objects do not support copying")
+
+ def params(self, *optionaldict, **kwargs):
+ raise NotImplementedError("Immutable objects do not support copying")
+
+ def _clone(self, **kw):
+ return self
+
+ def _copy_internals(self, **kw):
+ pass
+
+
+class SingletonConstant(Immutable):
+ """Represent SQL constants like NULL, TRUE, FALSE"""
+
+ _is_singleton_constant = True
+
+ def __new__(cls, *arg, **kw):
+ return cls._singleton
+
+ @classmethod
+ def _create_singleton(cls):
+ obj = object.__new__(cls)
+ obj.__init__()
+
+ # for a long time this was an empty frozenset, meaning
+ # a SingletonConstant would never be a "corresponding column" in
+ # a statement. This referred to #6259. However, in #7154 we see
+ # that we do in fact need "correspondence" to work when matching cols
+ # in result sets, so the non-correspondence was moved to a more
+ # specific level when we are actually adapting expressions for SQL
+ # render only.
+ obj.proxy_set = frozenset([obj])
+ cls._singleton = obj
+
+
+def _from_objects(*elements):
+ return itertools.chain.from_iterable(
+ [element._from_objects for element in elements]
+ )
+
+
+def _select_iterables(elements):
+ """expand tables into individual columns in the
+ given list of column expressions.
+
+ """
+ return itertools.chain.from_iterable(
+ [c._select_iterable for c in elements]
+ )
+
+
+def _generative(fn):
+ """non-caching _generative() decorator.
+
+ This is basically the legacy decorator that copies the object and
+ runs a method on the new copy.
+
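+ A sketch of the intended pattern (class and attribute names are
+ illustrative only)::
+
+ class MyStatement(Generative):
+ _where = None
+
+ @_generative
+ def where(self, criteria):
+ # mutate the copy produced by _generate(); the decorator
+ # returns that copy, leaving the original untouched
+ self._where = criteria
+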
+ """
+
+ @util.decorator
+ def _generative(fn, self, *args, **kw):
+ """Mark a method as generative."""
+
+ self = self._generate()
+ x = fn(self, *args, **kw)
+ assert x is None, "generative methods must have no return value"
+ return self
+
+ decorated = _generative(fn)
+ decorated.non_generative = fn
+ return decorated
+
+
+def _exclusive_against(*names, **kw):
+ msgs = kw.pop("msgs", {})
+
+ defaults = kw.pop("defaults", {})
+
+ getters = [
+ (name, operator.attrgetter(name), defaults.get(name, None))
+ for name in names
+ ]
+
+ @util.decorator
+ def check(fn, *args, **kw):
+ # make pylance happy by not including "self" in the argument
+ # list
+ self = args[0]
+ args = args[1:]
+ for name, getter, default_ in getters:
+ if getter(self) is not default_:
+ msg = msgs.get(
+ name,
+ "Method %s() has already been invoked on this %s construct"
+ % (fn.__name__, self.__class__),
+ )
+ raise exc.InvalidRequestError(msg)
+ return fn(self, *args, **kw)
+
+ return check
+
+
+def _clone(element, **kw):
+ return element._clone(**kw)
+
+
+def _expand_cloned(elements):
+ """expand the given set of ClauseElements to be the set of all 'cloned'
+ predecessors.
+
+ """
+ return itertools.chain(*[x._cloned_set for x in elements])
+
+
+def _cloned_intersection(a, b):
+ """return the intersection of sets a and b, counting
+ any overlap between 'cloned' predecessors.
+
+ The returned set is in terms of the entities present within 'a'.
+
+ """
+ all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
+ return set(
+ elem for elem in a if all_overlap.intersection(elem._cloned_set)
+ )
+
+
+def _cloned_difference(a, b):
+ all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
+ return set(
+ elem for elem in a if not all_overlap.intersection(elem._cloned_set)
+ )
+
+
+class _DialectArgView(util.collections_abc.MutableMapping):
+ """A dictionary view of dialect-level arguments in the form
+ <dialectname>_<argument_name>.
+
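+ E.g., for illustration with a hypothetical MySQL-bound Table::
+
+ table.kwargs["mysql_engine"]
+ # proxies to table.dialect_options["mysql"]["engine"]
+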
+ """
+
+ def __init__(self, obj):
+ self.obj = obj
+
+ def _key(self, key):
+ try:
+ dialect, value_key = key.split("_", 1)
+ except ValueError as err:
+ util.raise_(KeyError(key), replace_context=err)
+ else:
+ return dialect, value_key
+
+ def __getitem__(self, key):
+ dialect, value_key = self._key(key)
+
+ try:
+ opt = self.obj.dialect_options[dialect]
+ except exc.NoSuchModuleError as err:
+ util.raise_(KeyError(key), replace_context=err)
+ else:
+ return opt[value_key]
+
+ def __setitem__(self, key, value):
+ try:
+ dialect, value_key = self._key(key)
+ except KeyError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "Keys must be of the form <dialectname>_<argname>"
+ ),
+ replace_context=err,
+ )
+ else:
+ self.obj.dialect_options[dialect][value_key] = value
+
+ def __delitem__(self, key):
+ dialect, value_key = self._key(key)
+ del self.obj.dialect_options[dialect][value_key]
+
+ def __len__(self):
+ return sum(
+ len(args._non_defaults)
+ for args in self.obj.dialect_options.values()
+ )
+
+ def __iter__(self):
+ return (
+ "%s_%s" % (dialect_name, value_name)
+ for dialect_name in self.obj.dialect_options
+ for value_name in self.obj.dialect_options[
+ dialect_name
+ ]._non_defaults
+ )
+
+
+class _DialectArgDict(util.collections_abc.MutableMapping):
+ """A dictionary view of dialect-level arguments for a specific
+ dialect.
+
+ Maintains a separate collection of user-specified arguments
+ and dialect-specified default arguments.
+
+ """
+
+ def __init__(self):
+ self._non_defaults = {}
+ self._defaults = {}
+
+ def __len__(self):
+ return len(set(self._non_defaults).union(self._defaults))
+
+ def __iter__(self):
+ return iter(set(self._non_defaults).union(self._defaults))
+
+ def __getitem__(self, key):
+ if key in self._non_defaults:
+ return self._non_defaults[key]
+ else:
+ return self._defaults[key]
+
+ def __setitem__(self, key, value):
+ self._non_defaults[key] = value
+
+ def __delitem__(self, key):
+ del self._non_defaults[key]
+
+
+@util.preload_module("sqlalchemy.dialects")
+def _kw_reg_for_dialect(dialect_name):
+ dialect_cls = util.preloaded.dialects.registry.load(dialect_name)
+ if dialect_cls.construct_arguments is None:
+ return None
+ return dict(dialect_cls.construct_arguments)
+
+
+class DialectKWArgs(object):
+ """Establish the ability for a class to have dialect-specific arguments
+ with defaults and constructor validation.
+
+ The :class:`.DialectKWArgs` interacts with the
+ :attr:`.DefaultDialect.construct_arguments` present on a dialect.
+
+ .. seealso::
+
+ :attr:`.DefaultDialect.construct_arguments`
+
+ """
+
+ _dialect_kwargs_traverse_internals = [
+ ("dialect_options", InternalTraversal.dp_dialect_options)
+ ]
+
+ @classmethod
+ def argument_for(cls, dialect_name, argument_name, default):
+ """Add a new kind of dialect-specific keyword argument for this class.
+
+ E.g.::
+
+ Index.argument_for("mydialect", "length", None)
+
+ some_index = Index('a', 'b', mydialect_length=5)
+
+ The :meth:`.DialectKWArgs.argument_for` method is a per-argument
+ way of adding extra arguments to the
+ :attr:`.DefaultDialect.construct_arguments` dictionary. This
+ dictionary provides a list of argument names accepted by various
+ schema-level constructs on behalf of a dialect.
+
+ New dialects should typically specify this dictionary all at once as a
+ data member of the dialect class. The use case for ad-hoc addition of
+ argument names is typically for end-user code that is also using
+ a custom compilation scheme which consumes the additional arguments.
+
+ :param dialect_name: name of a dialect. The dialect must be
+ locatable, else a :class:`.NoSuchModuleError` is raised. The
+ dialect must also include an existing
+ :attr:`.DefaultDialect.construct_arguments` collection, indicating
+ that it participates in the keyword-argument validation and default
+ system, else :class:`.ArgumentError` is raised. If the dialect does
+ not include this collection, then any keyword argument can be
+ specified on behalf of this dialect already. All dialects packaged
+ within SQLAlchemy include this collection, however for third party
+ dialects, support may vary.
+
+ :param argument_name: name of the parameter.
+
+ :param default: default value of the parameter.
+
+ .. versionadded:: 0.9.4
+
+ """
+
+ construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
+ if construct_arg_dictionary is None:
+ raise exc.ArgumentError(
+ "Dialect '%s' does have keyword-argument "
+ "validation and defaults enabled configured" % dialect_name
+ )
+ if cls not in construct_arg_dictionary:
+ construct_arg_dictionary[cls] = {}
+ construct_arg_dictionary[cls][argument_name] = default
+
+ @util.memoized_property
+ def dialect_kwargs(self):
+ """A collection of keyword arguments specified as dialect-specific
+ options to this construct.
+
+ The arguments are present here in their original ``<dialect>_<kwarg>``
+ format. Only arguments that were actually passed are included;
+ unlike the :attr:`.DialectKWArgs.dialect_options` collection, which
+ contains all options known by this dialect including defaults.
+
+ The collection is also writable; keys are accepted of the
+ form ``<dialect>_<kwarg>`` where the value will be assembled
+ into the list of options.
+
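+ E.g., an illustrative sketch (the index and option are hypothetical)::
+
+ idx = Index("ix_a", sometable.c.a, postgresql_using="gin")
+ idx.dialect_kwargs["postgresql_using"]   # "gin"
+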
+ .. versionadded:: 0.9.2
+
+ .. versionchanged:: 0.9.4 The :attr:`.DialectKWArgs.dialect_kwargs`
+ collection is now writable.
+
+ .. seealso::
+
+ :attr:`.DialectKWArgs.dialect_options` - nested dictionary form
+
+ """
+ return _DialectArgView(self)
+
+ @property
+ def kwargs(self):
+ """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`."""
+ return self.dialect_kwargs
+
+ _kw_registry = util.PopulateDict(_kw_reg_for_dialect)
+
+ def _kw_reg_for_dialect_cls(self, dialect_name):
+ construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
+ d = _DialectArgDict()
+
+ if construct_arg_dictionary is None:
+ d._defaults.update({"*": None})
+ else:
+ for cls in reversed(self.__class__.__mro__):
+ if cls in construct_arg_dictionary:
+ d._defaults.update(construct_arg_dictionary[cls])
+ return d
+
+ @util.memoized_property
+ def dialect_options(self):
+ """A collection of keyword arguments specified as dialect-specific
+ options to this construct.
+
+ This is a two-level nested registry, keyed to ``<dialect_name>``
+ and ``<argument_name>``. For example, the ``postgresql_where``
+ argument would be locatable as::
+
+ arg = my_object.dialect_options['postgresql']['where']
+
+ .. versionadded:: 0.9.2
+
+ .. seealso::
+
+ :attr:`.DialectKWArgs.dialect_kwargs` - flat dictionary form
+
+ """
+
+ return util.PopulateDict(
+ util.portable_instancemethod(self._kw_reg_for_dialect_cls)
+ )
+
+ def _validate_dialect_kwargs(self, kwargs):
+ # validate remaining kwargs that they all specify DB prefixes
+
+ if not kwargs:
+ return
+
+ for k in kwargs:
+ m = re.match("^(.+?)_(.+)$", k)
+ if not m:
+ raise TypeError(
+ "Additional arguments should be "
+ "named <dialectname>_<argument>, got '%s'" % k
+ )
+ dialect_name, arg_name = m.group(1, 2)
+
+ try:
+ construct_arg_dictionary = self.dialect_options[dialect_name]
+ except exc.NoSuchModuleError:
+ util.warn(
+ "Can't validate argument %r; can't "
+ "locate any SQLAlchemy dialect named %r"
+ % (k, dialect_name)
+ )
+ self.dialect_options[dialect_name] = d = _DialectArgDict()
+ d._defaults.update({"*": None})
+ d._non_defaults[arg_name] = kwargs[k]
+ else:
+ if (
+ "*" not in construct_arg_dictionary
+ and arg_name not in construct_arg_dictionary
+ ):
+ raise exc.ArgumentError(
+ "Argument %r is not accepted by "
+ "dialect %r on behalf of %r"
+ % (k, dialect_name, self.__class__)
+ )
+ else:
+ construct_arg_dictionary[arg_name] = kwargs[k]
+
+
+class CompileState(object):
+ """Produces additional object state necessary for a statement to be
+ compiled.
+
+ the :class:`.CompileState` class is at the base of classes that assemble
+ state for a particular statement object that is then used by the
+ compiler. This process is essentially an extension of the process that
+ the SQLCompiler.visit_XYZ() method takes, however there is an emphasis
+ on converting raw user intent into more organized structures rather than
+ producing string output. The top-level :class:`.CompileState` for the
+ statement being executed is also accessible when the execution context
+ works with invoking the statement and collecting results.
+
+ The production of :class:`.CompileState` is specific to the compiler, such
+ as within the :meth:`.SQLCompiler.visit_insert`,
+ :meth:`.SQLCompiler.visit_select` etc. methods. These methods are also
+ responsible for associating the :class:`.CompileState` with the
+ :class:`.SQLCompiler` itself, if the statement is the "toplevel" statement,
+ i.e. the outermost SQL statement that's actually being executed.
+ There can be other :class:`.CompileState` objects that are not the
+ toplevel, such as when a SELECT subquery or CTE-nested
+ INSERT/UPDATE/DELETE is generated.
+
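+ For illustration, concrete compile-state classes register themselves
+ roughly as follows (a sketch, not an actual registration)::
+
+ @CompileState.plugin_for("default", "select")
+ class MySelectState(CompileState):
+ ...
+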
+ .. versionadded:: 1.4
+
+ """
+
+ __slots__ = ("statement",)
+
+ plugins = {}
+
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw):
+ # factory construction.
+
+ if statement._propagate_attrs:
+ plugin_name = statement._propagate_attrs.get(
+ "compile_state_plugin", "default"
+ )
+ klass = cls.plugins.get(
+ (plugin_name, statement._effective_plugin_target), None
+ )
+ if klass is None:
+ klass = cls.plugins[
+ ("default", statement._effective_plugin_target)
+ ]
+
+ else:
+ klass = cls.plugins[
+ ("default", statement._effective_plugin_target)
+ ]
+
+ if klass is cls:
+ return cls(statement, compiler, **kw)
+ else:
+ return klass.create_for_statement(statement, compiler, **kw)
+
+ def __init__(self, statement, compiler, **kw):
+ self.statement = statement
+
+ @classmethod
+ def get_plugin_class(cls, statement):
+ plugin_name = statement._propagate_attrs.get(
+ "compile_state_plugin", None
+ )
+
+ if plugin_name:
+ key = (plugin_name, statement._effective_plugin_target)
+ if key in cls.plugins:
+ return cls.plugins[key]
+
+ # there's no case where we call upon get_plugin_class() and want
+ # to get None back, there should always be a default. return that
+ # if there was no plugin-specific class (e.g. Insert with "orm"
+ # plugin)
+ try:
+ return cls.plugins[("default", statement._effective_plugin_target)]
+ except KeyError:
+ return None
+
+ @classmethod
+ def _get_plugin_class_for_plugin(cls, statement, plugin_name):
+ try:
+ return cls.plugins[
+ (plugin_name, statement._effective_plugin_target)
+ ]
+ except KeyError:
+ return None
+
+ @classmethod
+ def plugin_for(cls, plugin_name, visit_name):
+ def decorate(cls_to_decorate):
+ cls.plugins[(plugin_name, visit_name)] = cls_to_decorate
+ return cls_to_decorate
+
+ return decorate
+
+
+class Generative(HasMemoized):
+ """Provide a method-chaining pattern in conjunction with the
+ @_generative decorator."""
+
+ def _generate(self):
+ skip = self._memoized_keys
+ cls = self.__class__
+ s = cls.__new__(cls)
+ if skip:
+ # ensure this iteration remains atomic
+ s.__dict__ = {
+ k: v for k, v in self.__dict__.copy().items() if k not in skip
+ }
+ else:
+ s.__dict__ = self.__dict__.copy()
+ return s
+
+
+class InPlaceGenerative(HasMemoized):
+ """Provide a method-chaining pattern in conjunction with the
+ @_generative decorator that mutates in place."""
+
+ def _generate(self):
+ skip = self._memoized_keys
+ for k in skip:
+ self.__dict__.pop(k, None)
+ return self
+
+
+class HasCompileState(Generative):
+ """A class that has a :class:`.CompileState` associated with it."""
+
+ _compile_state_plugin = None
+
+ _attributes = util.immutabledict()
+
+ _compile_state_factory = CompileState.create_for_statement
+
+
+class _MetaOptions(type):
+ """metaclass for the Options class."""
+
+ def __init__(cls, classname, bases, dict_):
+ cls._cache_attrs = tuple(
+ sorted(
+ d
+ for d in dict_
+ if not d.startswith("__")
+ and d not in ("_cache_key_traversal",)
+ )
+ )
+ type.__init__(cls, classname, bases, dict_)
+
+ def __add__(self, other):
+ o1 = self()
+
+ if set(other).difference(self._cache_attrs):
+ raise TypeError(
+ "dictionary contains attributes not covered by "
+ "Options class %s: %r"
+ % (self, set(other).difference(self._cache_attrs))
+ )
+
+ o1.__dict__.update(other)
+ return o1
+
+
+class Options(util.with_metaclass(_MetaOptions)):
+ """A cacheable option dictionary with defaults."""
+
+ def __init__(self, **kw):
+ self.__dict__.update(kw)
+
+ def __add__(self, other):
+ o1 = self.__class__.__new__(self.__class__)
+ o1.__dict__.update(self.__dict__)
+
+ if set(other).difference(self._cache_attrs):
+ raise TypeError(
+ "dictionary contains attributes not covered by "
+ "Options class %s: %r"
+ % (self, set(other).difference(self._cache_attrs))
+ )
+
+ o1.__dict__.update(other)
+ return o1
+
+ def __eq__(self, other):
+ # TODO: very inefficient. This is used only in test suites
+ # right now.
+ for a, b in util.zip_longest(self._cache_attrs, other._cache_attrs):
+ if getattr(self, a) != getattr(other, b):
+ return False
+ return True
+
+ def __repr__(self):
+ # TODO: fairly inefficient, used only in debugging right now.
+
+ return "%s(%s)" % (
+ self.__class__.__name__,
+ ", ".join(
+ "%s=%r" % (k, self.__dict__[k])
+ for k in self._cache_attrs
+ if k in self.__dict__
+ ),
+ )
+
+ @classmethod
+ def isinstance(cls, klass):
+ return issubclass(cls, klass)
+
+ @hybridmethod
+ def add_to_element(self, name, value):
+ return self + {name: getattr(self, name) + value}
+
+ @hybridmethod
+ def _state_dict(self):
+ return self.__dict__
+
+ _state_dict_const = util.immutabledict()
+
+ @_state_dict.classlevel
+ def _state_dict(cls):
+ return cls._state_dict_const
+
+ @classmethod
+ def safe_merge(cls, other):
+ d = other._state_dict()
+
+ # only support a merge with another object of our class
+ # and which does not have attrs that we don't. otherwise
+ # we risk having state that might not be part of our cache
+ # key strategy
+
+ if (
+ cls is not other.__class__
+ and other._cache_attrs
+ and set(other._cache_attrs).difference(cls._cache_attrs)
+ ):
+ raise TypeError(
+ "other element %r is not empty, is not of type %s, "
+ "and contains attributes not covered here %r"
+ % (
+ other,
+ cls,
+ set(other._cache_attrs).difference(cls._cache_attrs),
+ )
+ )
+ return cls + d
+
+ @classmethod
+ def from_execution_options(
+ cls, key, attrs, exec_options, statement_exec_options
+ ):
+ """process Options argument in terms of execution options.
+
+
+ e.g.::
+
+ (
+ load_options,
+ execution_options,
+ ) = QueryContext.default_load_options.from_execution_options(
+ "_sa_orm_load_options",
+ {
+ "populate_existing",
+ "autoflush",
+ "yield_per"
+ },
+ execution_options,
+ statement._execution_options,
+ )
+
+ This returns the new Options object, and also refreshes the
+ "_sa_orm_load_options" entry in the execution options dictionary
+ with that Options object.
+
+ """
+
+ # common case is that no options we are looking for are
+ # in either dictionary, so check for that case first
+ check_argnames = attrs.intersection(
+ set(exec_options).union(statement_exec_options)
+ )
+
+ existing_options = exec_options.get(key, cls)
+
+ if check_argnames:
+ result = {}
+ for argname in check_argnames:
+ local = "_" + argname
+ if argname in exec_options:
+ result[local] = exec_options[argname]
+ elif argname in statement_exec_options:
+ result[local] = statement_exec_options[argname]
+
+ new_options = existing_options + result
+ exec_options = util.immutabledict().merge_with(
+ exec_options, {key: new_options}
+ )
+ return new_options, exec_options
+
+ else:
+ return existing_options, exec_options
+
+
+class CacheableOptions(Options, HasCacheKey):
+ @hybridmethod
+ def _gen_cache_key(self, anon_map, bindparams):
+ return HasCacheKey._gen_cache_key(self, anon_map, bindparams)
+
+ @_gen_cache_key.classlevel
+ def _gen_cache_key(cls, anon_map, bindparams):
+ return (cls, ())
+
+ @hybridmethod
+ def _generate_cache_key(self):
+ return HasCacheKey._generate_cache_key_for_object(self)
+
+
+class ExecutableOption(HasCopyInternals):
+ _annotations = util.EMPTY_DICT
+
+ __visit_name__ = "executable_option"
+
+ _is_has_cache_key = False
+
+ def _clone(self, **kw):
+ """Create a shallow copy of this ExecutableOption."""
+ c = self.__class__.__new__(self.__class__)
+ c.__dict__ = dict(self.__dict__)
+ return c
+
+
+class Executable(roles.StatementRole, Generative):
+ """Mark a :class:`_expression.ClauseElement` as supporting execution.
+
+ :class:`.Executable` is a superclass for all "statement" types
+ of objects, including :func:`select`, :func:`delete`, :func:`update`,
+ :func:`insert`, :func:`text`.
+
+ """
+
+ supports_execution = True
+ _execution_options = util.immutabledict()
+ _bind = None
+ _with_options = ()
+ _with_context_options = ()
+
+ _executable_traverse_internals = [
+ ("_with_options", InternalTraversal.dp_executable_options),
+ (
+ "_with_context_options",
+ ExtendedInternalTraversal.dp_with_context_options,
+ ),
+ ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs),
+ ]
+
+ is_select = False
+ is_update = False
+ is_insert = False
+ is_text = False
+ is_delete = False
+ is_dml = False
+
+ @property
+ def _effective_plugin_target(self):
+ return self.__visit_name__
+
+ @_generative
+ def options(self, *options):
+ """Apply options to this statement.
+
+ In the general sense, options are any kind of Python object
+ that can be interpreted by the SQL compiler for the statement.
+ These options can be consumed by specific dialects or specific kinds
+ of compilers.
+
+ The most commonly known kind of option are the ORM level options
+ that apply "eager load" and other loading behaviors to an ORM
+ query. However, options can theoretically be used for many other
+ purposes.
+
+ For background on specific kinds of options for specific kinds of
+ statements, refer to the documentation for those option objects.
+
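+ E.g., an illustrative ORM usage (``User`` and its loader option are
+ hypothetical)::
+
+ stmt = select(User).options(selectinload(User.addresses))
+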
+ .. versionchanged:: 1.4 - added :meth:`.Generative.options` to
+ Core statement objects towards the goal of allowing unified
+ Core / ORM querying capabilities.
+
+ .. seealso::
+
+ :ref:`deferred_options` - refers to options specific to the usage
+ of ORM queries
+
+ :ref:`relationship_loader_options` - refers to options specific
+ to the usage of ORM queries
+
+ """
+ self._with_options += tuple(
+ coercions.expect(roles.ExecutableOptionRole, opt)
+ for opt in options
+ )
+
+ @_generative
+ def _set_compile_options(self, compile_options):
+ """Assign the compile options to a new value.
+
+ :param compile_options: appropriate CacheableOptions structure
+
+ """
+
+ self._compile_options = compile_options
+
+ @_generative
+ def _update_compile_options(self, options):
+ """update the _compile_options with new keys."""
+
+ self._compile_options += options
+
+ @_generative
+ def _add_context_option(self, callable_, cache_args):
+ """Add a context option to this statement.
+
+ These are callable functions that will
+ be given the CompileState object upon compilation.
+
+ A second argument cache_args is required, which will be combined with
+ the ``__code__`` identity of the function itself in order to produce a
+ cache key.
+
+ """
+ self._with_context_options += ((callable_, cache_args),)
+
+ @_generative
+ def execution_options(self, **kw):
+ """Set non-SQL options for the statement which take effect during
+ execution.
+
+ Execution options can be set on a per-statement or
+ per :class:`_engine.Connection` basis. Additionally, the
+ :class:`_engine.Engine` and ORM :class:`~.orm.query.Query`
+ objects provide
+ access to execution options which they in turn configure upon
+ connections.
+
+ The :meth:`execution_options` method is generative. A new
+ instance of this statement is returned that contains the options::
+
+ statement = select(table.c.x, table.c.y)
+ statement = statement.execution_options(autocommit=True)
+
+ Note that only a subset of possible execution options can be applied
+ to a statement - these include "autocommit" and "stream_results",
+ but not "isolation_level" or "compiled_cache".
+ See :meth:`_engine.Connection.execution_options` for a full list of
+ possible options.
+
+ .. seealso::
+
+ :meth:`_engine.Connection.execution_options`
+
+ :meth:`_query.Query.execution_options`
+
+ :meth:`.Executable.get_execution_options`
+
+ """
+ if "isolation_level" in kw:
+ raise exc.ArgumentError(
+ "'isolation_level' execution option may only be specified "
+ "on Connection.execution_options(), or "
+ "per-engine using the isolation_level "
+ "argument to create_engine()."
+ )
+ if "compiled_cache" in kw:
+ raise exc.ArgumentError(
+ "'compiled_cache' execution option may only be specified "
+ "on Connection.execution_options(), not per statement."
+ )
+ self._execution_options = self._execution_options.union(kw)
+
+ def get_execution_options(self):
+ """Get the non-SQL options which will take effect during execution.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :meth:`.Executable.execution_options`
+ """
+ return self._execution_options
+
+ @util.deprecated_20(
+ ":meth:`.Executable.execute`",
+ alternative="All statement execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.execute` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.execute` method of "
+ ":class:`.Session`.",
+ )
+ def execute(self, *multiparams, **params):
+ """Compile and execute this :class:`.Executable`."""
+ e = self.bind
+ if e is None:
+ label = (
+ getattr(self, "description", None) or self.__class__.__name__
+ )
+ msg = (
+ "This %s is not directly bound to a Connection or Engine. "
+ "Use the .execute() method of a Connection or Engine "
+ "to execute this construct." % label
+ )
+ raise exc.UnboundExecutionError(msg)
+ return e._execute_clauseelement(
+ self, multiparams, params, util.immutabledict()
+ )
+
+ @util.deprecated_20(
+ ":meth:`.Executable.scalar`",
+ alternative="Scalar execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.scalar` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.scalar` method of "
+ ":class:`.Session`.",
+ )
+ def scalar(self, *multiparams, **params):
+ """Compile and execute this :class:`.Executable`, returning the
+ result's scalar representation.
+
+ """
+ return self.execute(*multiparams, **params).scalar()
+
+ @property
+ @util.deprecated_20(
+ ":attr:`.Executable.bind`",
+ alternative="Bound metadata is being removed as of SQLAlchemy 2.0.",
+ enable_warnings=False,
+ )
+ def bind(self):
+ """Returns the :class:`_engine.Engine` or :class:`_engine.Connection`
+ to
+ which this :class:`.Executable` is bound, or None if none found.
+
+ This is a traversal which checks locally, then
+ checks among the "from" clauses of associated objects
+ until a bound engine or connection is found.
+
+ """
+ if self._bind is not None:
+ return self._bind
+
+ for f in _from_objects(self):
+ if f is self:
+ continue
+ engine = f.bind
+ if engine is not None:
+ return engine
+ else:
+ return None
+
+
+class prefix_anon_map(dict):
+ """A map that creates new keys for missing key access.
+
+ Considers keys of the form "<ident> <name>" to produce
+ new symbols "<name>_<index>", where "index" is an incrementing integer
+ corresponding to <name>.
+
+ Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
+ is otherwise usually used for this type of operation.
+
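+ A short illustration::
+
+ m = prefix_anon_map()
+ m["anon_1 x"]   # "x_1"
+ m["anon_2 x"]   # "x_2"
+ m["anon_1 x"]   # "x_1" again, now memoized
+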
+ """
+
+ def __missing__(self, key):
+ (ident, derived) = key.split(" ", 1)
+ anonymous_counter = self.get(derived, 1)
+ self[derived] = anonymous_counter + 1
+ value = derived + "_" + str(anonymous_counter)
+ self[key] = value
+ return value
+
+
+class SchemaEventTarget(object):
+ """Base class for elements that are the targets of :class:`.DDLEvents`
+ events.
+
+ This includes :class:`.SchemaItem` as well as :class:`.SchemaType`.
+
+ """
+
+ def _set_parent(self, parent, **kw):
+ """Associate with this SchemaEvent's parent object."""
+
+ def _set_parent_with_dispatch(self, parent, **kw):
+ self.dispatch.before_parent_attach(self, parent)
+ self._set_parent(parent, **kw)
+ self.dispatch.after_parent_attach(self, parent)
+
+
+class SchemaVisitor(ClauseVisitor):
+ """Define the visiting for ``SchemaItem`` objects."""
+
+ __traverse_options__ = {"schema_visitor": True}
+
+
+class ColumnCollection(object):
+ """Collection of :class:`_expression.ColumnElement` instances,
+ typically for
+ :class:`_sql.FromClause` objects.
+
+ The :class:`_sql.ColumnCollection` object is most commonly available
+ as the :attr:`_schema.Table.c` or :attr:`_schema.Table.columns` collection
+ on the :class:`_schema.Table` object, introduced at
+ :ref:`metadata_tables_and_columns`.
+
+ The :class:`_expression.ColumnCollection` has both mapping- and sequence-
+ like behaviors. A :class:`_expression.ColumnCollection` usually stores
+ :class:`_schema.Column` objects, which are then accessible both via mapping
+ style access as well as attribute access style.
+
+ To access :class:`_schema.Column` objects using ordinary attribute-style
+ access, specify the name like any other object attribute, such as below
+ a column named ``employee_name`` is accessed::
+
+ >>> employee_table.c.employee_name
+
+ To access columns that have names with special characters or spaces,
+ index-style access is used, such as below which illustrates a column named
+ ``employee ' payment`` is accessed::
+
+ >>> employee_table.c["employee ' payment"]
+
+ As the :class:`_sql.ColumnCollection` object provides a Python dictionary
+ interface, common dictionary method names like
+ :meth:`_sql.ColumnCollection.keys`, :meth:`_sql.ColumnCollection.values`,
+ and :meth:`_sql.ColumnCollection.items` are available, which means that
+ database columns that are keyed under these names also need to use indexed
+ access::
+
+ >>> employee_table.c["values"]
+
+
+ The name for which a :class:`_schema.Column` would be present is normally
+ that of the :paramref:`_schema.Column.key` parameter. In some contexts,
+ such as a :class:`_sql.Select` object that uses a label style set
+ using the :meth:`_sql.Select.set_label_style` method, a column of a certain
+ key may instead be represented under a particular label name such
+ as ``tablename_columnname``::
+
+ >>> from sqlalchemy import select, column, table
+ >>> from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
+ >>> t = table("t", column("c"))
+ >>> stmt = select(t).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ >>> subq = stmt.subquery()
+ >>> subq.c.t_c
+ <sqlalchemy.sql.elements.ColumnClause at 0x7f59dcf04fa0; t_c>
+
+ :class:`.ColumnCollection` also indexes the columns in order and allows
+ them to be accessible by their integer position::
+
+ >>> cc[0]
+ Column('x', Integer(), table=None)
+ >>> cc[1]
+ Column('y', Integer(), table=None)
+
+ .. versionadded:: 1.4 :class:`_expression.ColumnCollection`
+ allows integer-based
+ index access to the collection.
+
+ Iterating the collection yields the column expressions in order::
+
+ >>> list(cc)
+ [Column('x', Integer(), table=None),
+ Column('y', Integer(), table=None)]
+
+ The base :class:`_expression.ColumnCollection` object can store
+ duplicates, which can
+ mean either two columns with the same key, in which case the column
+ returned by key access is **arbitrary**::
+
+ >>> x1, x2 = Column('x', Integer), Column('x', Integer)
+ >>> cc = ColumnCollection(columns=[(x1.name, x1), (x2.name, x2)])
+ >>> list(cc)
+ [Column('x', Integer(), table=None),
+ Column('x', Integer(), table=None)]
+ >>> cc['x'] is x1
+ False
+ >>> cc['x'] is x2
+ True
+
+ Or it can also mean the same column multiple times. These cases are
+ supported as :class:`_expression.ColumnCollection`
+ is used to represent the columns in
+ a SELECT statement which may include duplicates.
+
+ A special subclass :class:`.DedupeColumnCollection` exists which instead
+ maintains SQLAlchemy's older behavior of not allowing duplicates; this
+ collection is used for schema level objects like :class:`_schema.Table`
+ and
+ :class:`.PrimaryKeyConstraint` where this deduping is helpful. The
+ :class:`.DedupeColumnCollection` class also has additional mutation methods
+ as the schema constructs have more use cases that require removal and
+ replacement of columns.
+
+ .. versionchanged:: 1.4 :class:`_expression.ColumnCollection`
+ now stores duplicate
+ column keys as well as the same column in multiple positions. The
+ :class:`.DedupeColumnCollection` class is added to maintain the
+ former behavior in those cases where deduplication as well as
+ additional replace/remove operations are needed.
+
+
+ """
+
+ __slots__ = "_collection", "_index", "_colset"
+
+ def __init__(self, columns=None):
+ object.__setattr__(self, "_colset", set())
+ object.__setattr__(self, "_index", {})
+ object.__setattr__(self, "_collection", [])
+ if columns:
+ self._initial_populate(columns)
+
+ def _initial_populate(self, iter_):
+ self._populate_separate_keys(iter_)
+
+ @property
+ def _all_columns(self):
+ return [col for (k, col) in self._collection]
+
+ def keys(self):
+ """Return a sequence of string key names for all columns in this
+ collection."""
+ return [k for (k, col) in self._collection]
+
+ def values(self):
+ """Return a sequence of :class:`_sql.ColumnClause` or
+ :class:`_schema.Column` objects for all columns in this
+ collection."""
+ return [col for (k, col) in self._collection]
+
+ def items(self):
+ """Return a sequence of (key, column) tuples for all columns in this
+ collection each consisting of a string key name and a
+ :class:`_sql.ColumnClause` or
+ :class:`_schema.Column` object.
+ """
+
+ return list(self._collection)
+
+ def __bool__(self):
+ return bool(self._collection)
+
+ def __len__(self):
+ return len(self._collection)
+
+ def __iter__(self):
+ # materialize as a list first so iteration stays stable if the
+ # collection changes
+ return iter([col for k, col in self._collection])
+
+ def __getitem__(self, key):
+ try:
+ return self._index[key]
+ except KeyError as err:
+ if isinstance(key, util.int_types):
+ util.raise_(IndexError(key), replace_context=err)
+ else:
+ raise
+
+ def __getattr__(self, key):
+ try:
+ return self._index[key]
+ except KeyError as err:
+ util.raise_(AttributeError(key), replace_context=err)
+
+ def __contains__(self, key):
+ if key not in self._index:
+ if not isinstance(key, util.string_types):
+ raise exc.ArgumentError(
+ "__contains__ requires a string argument"
+ )
+ return False
+ else:
+ return True
+
+ def compare(self, other):
+ """Compare this :class:`_expression.ColumnCollection` to another
+ based on the names of the keys"""
+
+ for l, r in util.zip_longest(self, other):
+ if l is not r:
+ return False
+ else:
+ return True
+
+ def __eq__(self, other):
+ return self.compare(other)
+
+ def get(self, key, default=None):
+ """Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object
+ based on a string key name from this
+ :class:`_expression.ColumnCollection`."""
+
+ if key in self._index:
+ return self._index[key]
+ else:
+ return default
+
+ def __str__(self):
+ return "%s(%s)" % (
+ self.__class__.__name__,
+ ", ".join(str(c) for c in self),
+ )
+
+ def __setitem__(self, key, value):
+ raise NotImplementedError()
+
+ def __delitem__(self, key):
+ raise NotImplementedError()
+
+ def __setattr__(self, key, obj):
+ raise NotImplementedError()
+
+ def clear(self):
+ """Dictionary clear() is not implemented for
+ :class:`_sql.ColumnCollection`."""
+ raise NotImplementedError()
+
+ def remove(self, column):
+ """Dictionary remove() is not implemented for
+ :class:`_sql.ColumnCollection`."""
+ raise NotImplementedError()
+
+ def update(self, iter_):
+ """Dictionary update() is not implemented for
+ :class:`_sql.ColumnCollection`."""
+ raise NotImplementedError()
+
+ __hash__ = None
+
+ def _populate_separate_keys(self, iter_):
+ """populate from an iterator of (key, column)"""
+ cols = list(iter_)
+ self._collection[:] = cols
+ self._colset.update(c for k, c in self._collection)
+ self._index.update(
+ (idx, c) for idx, (k, c) in enumerate(self._collection)
+ )
+ self._index.update({k: col for k, col in reversed(self._collection)})
+
+ def add(self, column, key=None):
+ """Add a column to this :class:`_sql.ColumnCollection`.
+
+ .. note::
+
+ This method is **not normally used by user-facing code**, as the
+ :class:`_sql.ColumnCollection` is usually part of an existing
+ object such as a :class:`_schema.Table`. To add a
+ :class:`_schema.Column` to an existing :class:`_schema.Table`
+ object, use the :meth:`_schema.Table.append_column` method.
+
+ """
+ if key is None:
+ key = column.key
+
+ l = len(self._collection)
+ self._collection.append((key, column))
+ self._colset.add(column)
+ self._index[l] = column
+ if key not in self._index:
+ self._index[key] = column
+
+ def __getstate__(self):
+ return {"_collection": self._collection, "_index": self._index}
+
+ def __setstate__(self, state):
+ object.__setattr__(self, "_index", state["_index"])
+ object.__setattr__(self, "_collection", state["_collection"])
+ object.__setattr__(
+ self, "_colset", {col for k, col in self._collection}
+ )
+
+ def contains_column(self, col):
+ """Checks if a column object exists in this collection"""
+ if col not in self._colset:
+ if isinstance(col, util.string_types):
+ raise exc.ArgumentError(
+ "contains_column cannot be used with string arguments. "
+ "Use ``col_name in table.c`` instead."
+ )
+ return False
+ else:
+ return True
+
+ def as_immutable(self):
+ """Return an "immutable" form of this
+ :class:`_sql.ColumnCollection`."""
+
+ return ImmutableColumnCollection(self)
+
+ def corresponding_column(self, column, require_embedded=False):
+ """Given a :class:`_expression.ColumnElement`, return the exported
+ :class:`_expression.ColumnElement` object from this
+ :class:`_expression.ColumnCollection`
+ which corresponds to that original :class:`_expression.ColumnElement`
+ via a common
+ ancestor column.
+
+ :param column: the target :class:`_expression.ColumnElement`
+ to be matched.
+
+ :param require_embedded: only return corresponding columns for
+ the given :class:`_expression.ColumnElement`, if the given
+ :class:`_expression.ColumnElement`
+ is actually present within a sub-element
+ of this :class:`_expression.Selectable`.
+ Normally the column will match if
+ it merely shares a common ancestor with one of the exported
+ columns of this :class:`_expression.Selectable`.
+
+ .. seealso::
+
+ :meth:`_expression.Selectable.corresponding_column`
+ - invokes this method
+ against the collection returned by
+ :attr:`_expression.Selectable.exported_columns`.
+
+ .. versionchanged:: 1.4 the implementation for ``corresponding_column``
+ was moved onto the :class:`_expression.ColumnCollection` itself.
+
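+ For illustration (``user_table`` is hypothetical)::
+
+ subq = select(user_table).subquery()
+ subq.c.corresponding_column(user_table.c.id) is subq.c.id   # True
+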
+ """
+
+ def embedded(expanded_proxy_set, target_set):
+ for t in target_set.difference(expanded_proxy_set):
+ if not set(_expand_cloned([t])).intersection(
+ expanded_proxy_set
+ ):
+ return False
+ return True
+
+ # don't dig around if the column is locally present
+ if column in self._colset:
+ return column
+ col, intersect = None, None
+ target_set = column.proxy_set
+ cols = [c for (k, c) in self._collection]
+ for c in cols:
+ expanded_proxy_set = set(_expand_cloned(c.proxy_set))
+ i = target_set.intersection(expanded_proxy_set)
+ if i and (
+ not require_embedded
+ or embedded(expanded_proxy_set, target_set)
+ ):
+ if col is None:
+
+ # no corresponding column yet, pick this one.
+
+ col, intersect = c, i
+ elif len(i) > len(intersect):
+
+ # 'c' has a larger field of correspondence than
+ # 'col'. i.e. selectable.c.a1_x->a1.c.x->table.c.x
+ # matches a1.c.x->table.c.x better than
+ # selectable.c.x->table.c.x does.
+
+ col, intersect = c, i
+ elif i == intersect:
+ # they have the same field of correspondence. see
+ # which proxy_set has fewer columns in it, which
+ # indicates a closer relationship with the root
+ # column. Also take into account the "weight"
+ # attribute which CompoundSelect() uses to give
+ # higher precedence to columns based on vertical
+ # position in the compound statement, and discard
+ # columns that have no reference to the target
+ # column (also occurs with CompoundSelect)
+
+ col_distance = util.reduce(
+ operator.add,
+ [
+ sc._annotations.get("weight", 1)
+ for sc in col._uncached_proxy_set()
+ if sc.shares_lineage(column)
+ ],
+ )
+ c_distance = util.reduce(
+ operator.add,
+ [
+ sc._annotations.get("weight", 1)
+ for sc in c._uncached_proxy_set()
+ if sc.shares_lineage(column)
+ ],
+ )
+ if c_distance < col_distance:
+ col, intersect = c, i
+ return col
+
+
+class DedupeColumnCollection(ColumnCollection):
+ """A :class:`_expression.ColumnCollection`
+ that maintains deduplicating behavior.
+
+ This is useful by schema level objects such as :class:`_schema.Table` and
+ :class:`.PrimaryKeyConstraint`. The collection includes more
+ sophisticated mutator methods as well to suit schema objects which
+ require mutable column collections.
+
+ .. versionadded:: 1.4
+
+ """
+
+ def add(self, column, key=None):
+
+ if key is not None and column.key != key:
+ raise exc.ArgumentError(
+ "DedupeColumnCollection requires columns be under "
+ "the same key as their .key"
+ )
+ key = column.key
+
+ if key is None:
+ raise exc.ArgumentError(
+ "Can't add unnamed column to column collection"
+ )
+
+ if key in self._index:
+
+ existing = self._index[key]
+
+ if existing is column:
+ return
+
+ self.replace(column)
+
+ # pop out memoized proxy_set as this
+ # operation may very well be occurring
+ # in a _make_proxy operation
+ util.memoized_property.reset(column, "proxy_set")
+ else:
+ l = len(self._collection)
+ self._collection.append((key, column))
+ self._colset.add(column)
+ self._index[l] = column
+ self._index[key] = column
+
+ def _populate_separate_keys(self, iter_):
+ """populate from an iterator of (key, column)"""
+ cols = list(iter_)
+
+ replace_col = []
+ for k, col in cols:
+ if col.key != k:
+ raise exc.ArgumentError(
+ "DedupeColumnCollection requires columns be under "
+ "the same key as their .key"
+ )
+ if col.name in self._index and col.key != col.name:
+ replace_col.append(col)
+ elif col.key in self._index:
+ replace_col.append(col)
+ else:
+ self._index[k] = col
+ self._collection.append((k, col))
+ self._colset.update(c for (k, c) in self._collection)
+ self._index.update(
+ (idx, c) for idx, (k, c) in enumerate(self._collection)
+ )
+ for col in replace_col:
+ self.replace(col)
+
+ def extend(self, iter_):
+ self._populate_separate_keys((col.key, col) for col in iter_)
+
+ def remove(self, column):
+ if column not in self._colset:
+ raise ValueError(
+ "Can't remove column %r; column is not in this collection"
+ % column
+ )
+ del self._index[column.key]
+ self._colset.remove(column)
+ self._collection[:] = [
+ (k, c) for (k, c) in self._collection if c is not column
+ ]
+ self._index.update(
+ {idx: col for idx, (k, col) in enumerate(self._collection)}
+ )
+ # delete higher index
+ del self._index[len(self._collection)]
+
+ def replace(self, column):
+ """add the given column to this collection, removing unaliased
+ versions of this column as well as existing columns with the
+ same key.
+
+ e.g.::
+
+ t = Table('sometable', metadata, Column('col1', Integer))
+ t.columns.replace(Column('col1', Integer, key='columnone'))
+
+ will remove the original 'col1' from the collection, and add
+ the new column under the key 'columnone'.
+
+ Used by schema.Column to override columns during table reflection.
+
+ """
+
+ remove_col = set()
+ # remove up to two columns based on matches of name as well as key
+ if column.name in self._index and column.key != column.name:
+ other = self._index[column.name]
+ if other.name == other.key:
+ remove_col.add(other)
+
+ if column.key in self._index:
+ remove_col.add(self._index[column.key])
+
+ new_cols = []
+ replaced = False
+ for k, col in self._collection:
+ if col in remove_col:
+ if not replaced:
+ replaced = True
+ new_cols.append((column.key, column))
+ else:
+ new_cols.append((k, col))
+
+ if remove_col:
+ self._colset.difference_update(remove_col)
+
+ if not replaced:
+ new_cols.append((column.key, column))
+
+ self._colset.add(column)
+ self._collection[:] = new_cols
+
+ self._index.clear()
+ self._index.update(
+ {idx: col for idx, (k, col) in enumerate(self._collection)}
+ )
+ self._index.update(self._collection)
+
+
+class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection):
+ __slots__ = ("_parent",)
+
+ def __init__(self, collection):
+ object.__setattr__(self, "_parent", collection)
+ object.__setattr__(self, "_colset", collection._colset)
+ object.__setattr__(self, "_index", collection._index)
+ object.__setattr__(self, "_collection", collection._collection)
+
+ def __getstate__(self):
+ return {"_parent": self._parent}
+
+ def __setstate__(self, state):
+ parent = state["_parent"]
+ self.__init__(parent)
+
+ add = extend = remove = util.ImmutableContainer._immutable
+
+
+class ColumnSet(util.ordered_column_set):
+ def contains_column(self, col):
+ return col in self
+
+ def extend(self, cols):
+ for col in cols:
+ self.add(col)
+
+ def __add__(self, other):
+ return list(self) + list(other)
+
+ def __eq__(self, other):
+ l = []
+ for c in other:
+ for local in self:
+ if c.shares_lineage(local):
+ l.append(c == local)
+ return elements.and_(*l)
+
+ def __hash__(self):
+ return hash(tuple(x for x in self))
+
+
+def _bind_or_error(schemaitem, msg=None):
+
+ util.warn_deprecated_20(
+ "The ``bind`` argument for schema methods that invoke SQL "
+ "against an engine or connection will be required in SQLAlchemy 2.0."
+ )
+ bind = schemaitem.bind
+ if not bind:
+ name = schemaitem.__class__.__name__
+ label = getattr(
+ schemaitem, "fullname", getattr(schemaitem, "name", None)
+ )
+ if label:
+ item = "%s object %r" % (name, label)
+ else:
+ item = "%s object" % name
+ if msg is None:
+ msg = (
+ "%s is not bound to an Engine or Connection. "
+ "Execution can not proceed without a database to execute "
+ "against." % item
+ )
+ raise exc.UnboundExecutionError(msg)
+ return bind
+
+
+def _entity_namespace(entity):
+ """Return the nearest .entity_namespace for the given entity.
+
+ If not immediately available, does an iterate to find a sub-element
+ that has one, if any.
+
+ """
+ try:
+ return entity.entity_namespace
+ except AttributeError:
+ for elem in visitors.iterate(entity):
+ if hasattr(elem, "entity_namespace"):
+ return elem.entity_namespace
+ else:
+ raise
+
+
+def _entity_namespace_key(entity, key, default=NO_ARG):
+ """Return an entry from an entity_namespace.
+
+
+ Raises :class:`_exc.InvalidRequestError` rather than attribute error
+ on not found.
+
+ """
+
+ try:
+ ns = _entity_namespace(entity)
+ if default is not NO_ARG:
+ return getattr(ns, key, default)
+ else:
+ return getattr(ns, key)
+ except AttributeError as err:
+ util.raise_(
+ exc.InvalidRequestError(
+ 'Entity namespace for "%s" has no property "%s"'
+ % (entity, key)
+ ),
+ replace_context=err,
+ )
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
new file mode 100644
index 0000000..8cc73cb
--- /dev/null
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -0,0 +1,1096 @@
+# sql/coercions.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import numbers
+import re
+
+from . import operators
+from . import roles
+from . import visitors
+from .base import ExecutableOption
+from .base import Options
+from .traversals import HasCacheKey
+from .visitors import Visitable
+from .. import exc
+from .. import inspection
+from .. import util
+from ..util import collections_abc
+
+
+elements = None
+lambdas = None
+schema = None
+selectable = None
+sqltypes = None
+traversals = None
+
+
+def _is_literal(element):
+ """Return whether or not the element is a "literal" in the context
+ of a SQL expression construct.
+
+ """
+
+ return (
+ not isinstance(
+ element,
+ (Visitable, schema.SchemaEventTarget),
+ )
+ and not hasattr(element, "__clause_element__")
+ )
+
+
+def _deep_is_literal(element):
+ """Return whether or not the element is a "literal" in the context
+ of a SQL expression construct.
+
+ Does a deeper, more esoteric check than _is_literal. It is used
+ for lambda elements, which have to distinguish values that would
+ be bound vs. not bound, without any surrounding context.
+
+ """
+
+ if isinstance(element, collections_abc.Sequence) and not isinstance(
+ element, str
+ ):
+ for elem in element:
+ if not _deep_is_literal(elem):
+ return False
+ else:
+ return True
+
+ return (
+ not isinstance(
+ element,
+ (
+ Visitable,
+ schema.SchemaEventTarget,
+ HasCacheKey,
+ Options,
+ util.langhelpers._symbol,
+ ),
+ )
+ and not hasattr(element, "__clause_element__")
+ and (
+ not isinstance(element, type)
+ or not issubclass(element, HasCacheKey)
+ )
+ )
+
+
+def _document_text_coercion(paramname, meth_rst, param_rst):
+ return util.add_parameter_text(
+ paramname,
+ (
+ ".. warning:: "
+ "The %s argument to %s can be passed as a Python string argument, "
+ "which will be treated "
+ "as **trusted SQL text** and rendered as given. **DO NOT PASS "
+ "UNTRUSTED INPUT TO THIS PARAMETER**."
+ )
+ % (param_rst, meth_rst),
+ )
+
+
+def _expression_collection_was_a_list(attrname, fnname, args):
+ if args and isinstance(args[0], (list, set, dict)) and len(args) == 1:
+ if isinstance(args[0], list):
+ util.warn_deprecated_20(
+ 'The "%s" argument to %s(), when referring to a sequence '
+ "of items, is now passed as a series of positional "
+ "elements, rather than as a list. " % (attrname, fnname)
+ )
+ return args[0]
+ else:
+ return args
+
+
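+# For illustration, callers elsewhere in the library invoke this roughly as
+# ``coercions.expect(roles.ExecutableOptionRole, opt)``; the role class
+# selects the RoleImpl subclass that validates and coerces ``opt``.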
+def expect(
+ role,
+ element,
+ apply_propagate_attrs=None,
+ argname=None,
+ post_inspect=False,
+ **kw
+):
+ if (
+ role.allows_lambda
+ # note callable() will not invoke a __getattr__() method, whereas
+ # hasattr(obj, "__call__") will. by keeping the callable() check here
+ # we prevent most needless calls to hasattr() and therefore
+ # __getattr__(), which is present on ColumnElement.
+ and callable(element)
+ and hasattr(element, "__code__")
+ ):
+ return lambdas.LambdaElement(
+ element,
+ role,
+ lambdas.LambdaOptions(**kw),
+ apply_propagate_attrs=apply_propagate_attrs,
+ )
+
+ # major case is that we are given a ClauseElement already, skip more
+ # elaborate logic up front if possible
+ impl = _impl_lookup[role]
+
+ original_element = element
+
+ if not isinstance(
+ element,
+ (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue),
+ ):
+ resolved = None
+
+ if impl._resolve_literal_only:
+ resolved = impl._literal_coercion(element, **kw)
+ else:
+
+ original_element = element
+
+ is_clause_element = False
+
+ # this is a special performance optimization for ORM
+ # joins used by JoinTargetImpl, so that we don't go through the
+ # work of calling __clause_element__() when we only need the
+ # original QueryableAttribute, as the former would do clause
+ # adaptation and other work that is just thrown away here.
+ if (
+ impl._skip_clauseelement_for_target_match
+ and isinstance(element, role)
+ and hasattr(element, "__clause_element__")
+ ):
+ is_clause_element = True
+ else:
+ while hasattr(element, "__clause_element__"):
+ is_clause_element = True
+
+ if not getattr(element, "is_clause_element", False):
+ element = element.__clause_element__()
+ else:
+ break
+
+ if not is_clause_element:
+ if impl._use_inspection:
+ insp = inspection.inspect(element, raiseerr=False)
+ if insp is not None:
+ if post_inspect:
+ insp._post_inspect
+ try:
+ resolved = insp.__clause_element__()
+ except AttributeError:
+ impl._raise_for_expected(original_element, argname)
+
+ if resolved is None:
+ resolved = impl._literal_coercion(
+ element, argname=argname, **kw
+ )
+ else:
+ resolved = element
+ else:
+ resolved = element
+ if (
+ apply_propagate_attrs is not None
+ and not apply_propagate_attrs._propagate_attrs
+ and resolved._propagate_attrs
+ ):
+ apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs
+
+ if impl._role_class in resolved.__class__.__mro__:
+ if impl._post_coercion:
+ resolved = impl._post_coercion(
+ resolved,
+ argname=argname,
+ original_element=original_element,
+ **kw
+ )
+ return resolved
+ else:
+ return impl._implicit_coercions(
+ original_element, resolved, argname=argname, **kw
+ )
+
+
+def expect_as_key(role, element, **kw):
+ kw["as_key"] = True
+ return expect(role, element, **kw)
+
+
+def expect_col_expression_collection(role, expressions):
+ for expr in expressions:
+ strname = None
+ column = None
+
+ resolved = expect(role, expr)
+ if isinstance(resolved, util.string_types):
+ strname = resolved = expr
+ else:
+ cols = []
+ visitors.traverse(resolved, {}, {"column": cols.append})
+ if cols:
+ column = cols[0]
+ add_element = column if column is not None else strname
+ yield resolved, column, strname, add_element
+
+
+class RoleImpl(object):
+ __slots__ = ("_role_class", "name", "_use_inspection")
+
+ def _literal_coercion(self, element, **kw):
+ raise NotImplementedError()
+
+ _post_coercion = None
+ _resolve_literal_only = False
+ _skip_clauseelement_for_target_match = False
+
+ def __init__(self, role_class):
+ self._role_class = role_class
+ self.name = role_class._role_name
+ self._use_inspection = issubclass(role_class, roles.UsesInspection)
+
+ def _implicit_coercions(self, element, resolved, argname=None, **kw):
+ self._raise_for_expected(element, argname, resolved)
+
+ def _raise_for_expected(
+ self,
+ element,
+ argname=None,
+ resolved=None,
+ advice=None,
+ code=None,
+ err=None,
+ ):
+ if resolved is not None and resolved is not element:
+ got = "%r object resolved from %r object" % (resolved, element)
+ else:
+ got = repr(element)
+
+ if argname:
+ msg = "%s expected for argument %r; got %s." % (
+ self.name,
+ argname,
+ got,
+ )
+ else:
+ msg = "%s expected, got %s." % (self.name, got)
+
+ if advice:
+ msg += " " + advice
+
+ util.raise_(exc.ArgumentError(msg, code=code), replace_context=err)
+
+
+class _Deannotate(object):
+ __slots__ = ()
+
+ def _post_coercion(self, resolved, **kw):
+ from .util import _deep_deannotate
+
+ return _deep_deannotate(resolved)
+
+
+class _StringOnly(object):
+ __slots__ = ()
+
+ _resolve_literal_only = True
+
+
+class _ReturnsStringKey(object):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(original_element, util.string_types):
+ return original_element
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+ def _literal_coercion(self, element, **kw):
+ return element
+
+
+class _ColumnCoercions(object):
+ __slots__ = ()
+
+ def _warn_for_scalar_subquery_coercion(self):
+ util.warn(
+ "implicitly coercing SELECT object to scalar subquery; "
+ "please use the .scalar_subquery() method to produce a scalar "
+ "subquery.",
+ )
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if not getattr(resolved, "is_clause_element", False):
+ self._raise_for_expected(original_element, argname, resolved)
+ elif resolved._is_select_statement:
+ self._warn_for_scalar_subquery_coercion()
+ return resolved.scalar_subquery()
+ elif resolved._is_from_clause and isinstance(
+ resolved, selectable.Subquery
+ ):
+ self._warn_for_scalar_subquery_coercion()
+ return resolved.element.scalar_subquery()
+ elif self._role_class.allows_lambda and resolved._is_lambda_element:
+ return resolved
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
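+# The explicit spelling that avoids the warning above, using 1.4-style
+# select() (``users`` and ``addresses`` are placeholder Tables):
+#
+#     from sqlalchemy import select
+#
+#     subq = (
+#         select(addresses.c.email_address)
+#         .where(addresses.c.user_id == users.c.id)
+#         .scalar_subquery()
+#     )
+#     stmt = select(users.c.name, subq)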
+
+def _no_text_coercion(
+ element, argname=None, exc_cls=exc.ArgumentError, extra=None, err=None
+):
+ util.raise_(
+ exc_cls(
+ "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be "
+ "explicitly declared as text(%(expr)r)"
+ % {
+ "expr": util.ellipses_string(element),
+ "argname": "for argument %s" % (argname,) if argname else "",
+ "extra": "%s " % extra if extra else "",
+ }
+ ),
+ replace_context=err,
+ )
+
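+# Example of the error raised above and the explicit form it asks for
+# (``users`` is a placeholder Table):
+#
+#     from sqlalchemy import select, text
+#
+#     select(users).where("name = 'ed'")        # raises ArgumentError
+#     select(users).where(text("name = 'ed'"))  # accepted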
+
+class _NoTextCoercion(object):
+ __slots__ = ()
+
+ def _literal_coercion(self, element, argname=None, **kw):
+ if isinstance(element, util.string_types) and issubclass(
+ elements.TextClause, self._role_class
+ ):
+ _no_text_coercion(element, argname)
+ else:
+ self._raise_for_expected(element, argname)
+
+
+class _CoerceLiterals(object):
+ __slots__ = ()
+ _coerce_consts = False
+ _coerce_star = False
+ _coerce_numerics = False
+
+ def _text_coercion(self, element, argname=None):
+ return _no_text_coercion(element, argname)
+
+ def _literal_coercion(self, element, argname=None, **kw):
+ if isinstance(element, util.string_types):
+ if self._coerce_star and element == "*":
+ return elements.ColumnClause("*", is_literal=True)
+ else:
+ return self._text_coercion(element, argname, **kw)
+
+ if self._coerce_consts:
+ if element is None:
+ return elements.Null()
+ elif element is False:
+ return elements.False_()
+ elif element is True:
+ return elements.True_()
+
+ if self._coerce_numerics and isinstance(element, numbers.Number):
+ return elements.ColumnClause(str(element), is_literal=True)
+
+ self._raise_for_expected(element, argname)
+
+
+class LiteralValueImpl(RoleImpl):
+ _resolve_literal_only = True
+
+ def _implicit_coercions(
+ self, element, resolved, argname, type_=None, **kw
+ ):
+ if not _is_literal(resolved):
+ self._raise_for_expected(
+ element, resolved=resolved, argname=argname, **kw
+ )
+
+ return elements.BindParameter(None, element, type_=type_, unique=True)
+
+ def _literal_coercion(self, element, argname=None, type_=None, **kw):
+ return element
+
+
+class _SelectIsNotFrom(object):
+ __slots__ = ()
+
+ def _raise_for_expected(self, element, argname=None, resolved=None, **kw):
+ if isinstance(element, roles.SelectStatementRole) or isinstance(
+ resolved, roles.SelectStatementRole
+ ):
+ advice = (
+ "To create a "
+ "FROM clause from a %s object, use the .subquery() method."
+ % (resolved.__class__ if resolved is not None else element,)
+ )
+ code = "89ve"
+ else:
+ advice = code = None
+
+ return super(_SelectIsNotFrom, self)._raise_for_expected(
+ element,
+ argname=argname,
+ resolved=resolved,
+ advice=advice,
+ code=code,
+ **kw
+ )
+
+
+class HasCacheKeyImpl(RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(original_element, traversals.HasCacheKey):
+ return original_element
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+ def _literal_coercion(self, element, **kw):
+ return element
+
+
+class ExecutableOptionImpl(RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(original_element, ExecutableOption):
+ return original_element
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+ def _literal_coercion(self, element, **kw):
+ return element
+
+
+class ExpressionElementImpl(_ColumnCoercions, RoleImpl):
+ __slots__ = ()
+
+ def _literal_coercion(
+ self, element, name=None, type_=None, argname=None, is_crud=False, **kw
+ ):
+ if (
+ element is None
+ and not is_crud
+ and (type_ is None or not type_.should_evaluate_none)
+ ):
+ # TODO: there's no test coverage now for the
+ # "should_evaluate_none" part of this, as outside of "crud" this
+ # codepath is not normally used except in some special cases
+ return elements.Null()
+ else:
+ try:
+ return elements.BindParameter(
+ name, element, type_, unique=True, _is_crud=is_crud
+ )
+ except exc.ArgumentError as err:
+ self._raise_for_expected(element, err=err)
+
+ def _raise_for_expected(self, element, argname=None, resolved=None, **kw):
+ if isinstance(element, roles.AnonymizedFromClauseRole):
+ advice = (
+ "To create a "
+ "column expression from a FROM clause row "
+ "as a whole, use the .table_valued() method."
+ )
+ else:
+ advice = None
+
+ return super(ExpressionElementImpl, self)._raise_for_expected(
+ element, argname=argname, resolved=resolved, advice=advice, **kw
+ )
+
+
+class BinaryElementImpl(ExpressionElementImpl, RoleImpl):
+
+ __slots__ = ()
+
+ def _literal_coercion(
+ self, element, expr, operator, bindparam_type=None, argname=None, **kw
+ ):
+ try:
+ return expr._bind_param(operator, element, type_=bindparam_type)
+ except exc.ArgumentError as err:
+ self._raise_for_expected(element, err=err)
+
+ def _post_coercion(self, resolved, expr, bindparam_type=None, **kw):
+ if resolved.type._isnull and not expr.type._isnull:
+ resolved = resolved._with_binary_element_type(
+ bindparam_type if bindparam_type is not None else expr.type
+ )
+ return resolved
+
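+# Sketch of the effect of the coercion above: a plain Python value on the
+# right side of a comparison becomes a BindParameter whose type is taken from
+# the expression on the left (``users`` is a placeholder Table with an
+# Integer ``id`` column):
+#
+#     expr = users.c.id == 5
+#     expr.right  # BindParameter using the Integer type of users.c.id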
+
+class InElementImpl(RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if resolved._is_from_clause:
+ if (
+ isinstance(resolved, selectable.Alias)
+ and resolved.element._is_select_statement
+ ):
+ self._warn_for_implicit_coercion(resolved)
+ return self._post_coercion(resolved.element, **kw)
+ else:
+ self._warn_for_implicit_coercion(resolved)
+ return self._post_coercion(resolved.select(), **kw)
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+ def _warn_for_implicit_coercion(self, elem):
+ util.warn(
+ "Coercing %s object into a select() for use in IN(); "
+ "please pass a select() construct explicitly"
+ % (elem.__class__.__name__)
+ )
+
+ def _literal_coercion(self, element, expr, operator, **kw):
+ if isinstance(element, collections_abc.Iterable) and not isinstance(
+ element, util.string_types
+ ):
+ non_literal_expressions = {}
+ element = list(element)
+ for o in element:
+ if not _is_literal(o):
+ if not isinstance(o, operators.ColumnOperators):
+ self._raise_for_expected(element, **kw)
+ else:
+ non_literal_expressions[o] = o
+ elif o is None:
+ non_literal_expressions[o] = elements.Null()
+
+ if non_literal_expressions:
+ return elements.ClauseList(
+ *[
+ non_literal_expressions[o]
+ if o in non_literal_expressions
+ else expr._bind_param(operator, o)
+ for o in element
+ ]
+ )
+ else:
+ return expr._bind_param(operator, element, expanding=True)
+
+ else:
+ self._raise_for_expected(element, **kw)
+
+ def _post_coercion(self, element, expr, operator, **kw):
+ if element._is_select_statement:
+ # for IN, we are doing scalar_subquery() coercion without
+ # a warning
+ return element.scalar_subquery()
+ elif isinstance(element, elements.ClauseList):
+ assert len(element.clauses) != 0
+ return element.self_group(against=operator)
+
+ elif isinstance(element, elements.BindParameter):
+ element = element._clone(maintain_key=True)
+ element.expanding = True
+ element.expand_op = operator
+
+ return element
+ else:
+ return element
+
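+# Illustration of the "expanding" IN handling above (``users`` is a
+# placeholder Table): a list value becomes a single expanding bound
+# parameter that is expanded per-value at execution time:
+#
+#     from sqlalchemy import select
+#
+#     stmt = select(users).where(users.c.id.in_([1, 2, 3]))
+#     # renders roughly: ... WHERE users.id IN (__[POSTCOMPILE_id_1])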
+
+class OnClauseImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl):
+ __slots__ = ()
+
+ _coerce_consts = True
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, legacy=False, **kw
+ ):
+ if legacy and isinstance(resolved, str):
+ return resolved
+ else:
+ return super(OnClauseImpl, self)._implicit_coercions(
+ original_element,
+ resolved,
+ argname=argname,
+ legacy=legacy,
+ **kw
+ )
+
+ def _text_coercion(self, element, argname=None, legacy=False):
+ if legacy and isinstance(element, str):
+ util.warn_deprecated_20(
+ "Using strings to indicate relationship names in "
+ "Query.join() is deprecated and will be removed in "
+ "SQLAlchemy 2.0. Please use the class-bound attribute "
+ "directly."
+ )
+ return element
+
+ return super(OnClauseImpl, self)._text_coercion(element, argname)
+
+ def _post_coercion(self, resolved, original_element=None, **kw):
+ # this is a hack right now as we want to use coercion on an
+ # ORM InstrumentedAttribute, but we want to return the object
+ # itself if it is one, not its clause element.
+ # ORM context _join and _legacy_join() would need to be improved
+ # to look for annotations in a clause element form.
+ if isinstance(original_element, roles.JoinTargetRole):
+ return original_element
+ return resolved
+
+
+class WhereHavingImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl):
+ __slots__ = ()
+
+ _coerce_consts = True
+
+ def _text_coercion(self, element, argname=None):
+ return _no_text_coercion(element, argname)
+
+
+class StatementOptionImpl(_CoerceLiterals, RoleImpl):
+ __slots__ = ()
+
+ _coerce_consts = True
+
+ def _text_coercion(self, element, argname=None):
+ return elements.TextClause(element)
+
+
+class ColumnArgumentImpl(_NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+
+class ColumnArgumentOrKeyImpl(_ReturnsStringKey, RoleImpl):
+ __slots__ = ()
+
+
+class StrAsPlainColumnImpl(_CoerceLiterals, RoleImpl):
+ __slots__ = ()
+
+ def _text_coercion(self, element, argname=None):
+ return elements.ColumnClause(element)
+
+
+class ByOfImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl, roles.ByOfRole):
+
+ __slots__ = ()
+
+ _coerce_consts = True
+
+ def _text_coercion(self, element, argname=None):
+ return elements._textual_label_reference(element)
+
+
+class OrderByImpl(ByOfImpl, RoleImpl):
+ __slots__ = ()
+
+ def _post_coercion(self, resolved, **kw):
+ if (
+ isinstance(resolved, self._role_class)
+ and resolved._order_by_label_element is not None
+ ):
+ return elements._label_reference(resolved)
+ else:
+ return resolved
+
+
+class GroupByImpl(ByOfImpl, RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(resolved, roles.StrictFromClauseRole):
+ return elements.ClauseList(*resolved.c)
+ else:
+ return resolved
+
+
+class DMLColumnImpl(_ReturnsStringKey, RoleImpl):
+ __slots__ = ()
+
+ def _post_coercion(self, element, as_key=False, **kw):
+ if as_key:
+ return element.key
+ else:
+ return element
+
+
+class ConstExprImpl(RoleImpl):
+ __slots__ = ()
+
+ def _literal_coercion(self, element, argname=None, **kw):
+ if element is None:
+ return elements.Null()
+ elif element is False:
+ return elements.False_()
+ elif element is True:
+ return elements.True_()
+ else:
+ self._raise_for_expected(element, argname)
+
+
+class TruncatedLabelImpl(_StringOnly, RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(original_element, util.string_types):
+ return resolved
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+ def _literal_coercion(self, element, argname=None, **kw):
+ """coerce the given value to :class:`._truncated_label`.
+
+ Existing :class:`._truncated_label` and
+ :class:`._anonymous_label` objects are passed
+ unchanged.
+ """
+
+ if isinstance(element, elements._truncated_label):
+ return element
+ else:
+ return elements._truncated_label(element)
+
+
+class DDLExpressionImpl(_Deannotate, _CoerceLiterals, RoleImpl):
+
+ __slots__ = ()
+
+ _coerce_consts = True
+
+ def _text_coercion(self, element, argname=None):
+ # see #5754 for why we can't easily deprecate this coercion.
+ # essentially expressions like postgresql_where would have to be
+ # text() as they come back from reflection and we don't want to
+ # have text() elements wired into the inspection dictionaries.
+ return elements.TextClause(element)
+
+
+class DDLConstraintColumnImpl(_Deannotate, _ReturnsStringKey, RoleImpl):
+ __slots__ = ()
+
+
+class DDLReferredColumnImpl(DDLConstraintColumnImpl):
+ __slots__ = ()
+
+
+class LimitOffsetImpl(RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(self, element, resolved, argname=None, **kw):
+ if resolved is None:
+ return None
+ else:
+ self._raise_for_expected(element, argname, resolved)
+
+ def _literal_coercion(self, element, name, type_, **kw):
+ if element is None:
+ return None
+ else:
+ value = util.asint(element)
+ return selectable._OffsetLimitParam(
+ name, value, type_=type_, unique=True
+ )
+
+
+class LabeledColumnExprImpl(ExpressionElementImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(resolved, roles.ExpressionElementRole):
+ return resolved.label(None)
+ else:
+ new = super(LabeledColumnExprImpl, self)._implicit_coercions(
+ original_element, resolved, argname=argname, **kw
+ )
+ if isinstance(new, roles.ExpressionElementRole):
+ return new.label(None)
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+
+class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl):
+ __slots__ = ()
+
+ _coerce_consts = True
+ _coerce_numerics = True
+ _coerce_star = True
+
+ _guess_straight_column = re.compile(r"^\w\S*$", re.I)
+
+ def _text_coercion(self, element, argname=None):
+ element = str(element)
+
+ guess_is_literal = not self._guess_straight_column.match(element)
+ raise exc.ArgumentError(
+ "Textual column expression %(column)r %(argname)sshould be "
+ "explicitly declared with text(%(column)r), "
+ "or use %(literal_column)s(%(column)r) "
+ "for more specificity"
+ % {
+ "column": util.ellipses_string(element),
+ "argname": "for argument %s" % (argname,) if argname else "",
+ "literal_column": "literal_column"
+ if guess_is_literal
+ else "column",
+ }
+ )
+
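+# Example of the error above and the explicit alternatives it suggests:
+#
+#     from sqlalchemy import select, column, literal_column
+#
+#     select("some_col")                  # raises ArgumentError per the coercion above
+#     select(column("some_col"))          # a simple named column
+#     select(literal_column("count(*)"))  # an arbitrary textual column expression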
+
+class ReturnsRowsImpl(RoleImpl):
+ __slots__ = ()
+
+
+class StatementImpl(_CoerceLiterals, RoleImpl):
+ __slots__ = ()
+
+ def _post_coercion(self, resolved, original_element, argname=None, **kw):
+ if resolved is not original_element and not isinstance(
+ original_element, util.string_types
+ ):
+ # use same method as Connection uses; this will later raise
+ # ObjectNotExecutableError
+ try:
+ original_element._execute_on_connection
+ except AttributeError:
+ util.warn_deprecated(
+ "Object %r should not be used directly in a SQL statement "
+ "context, such as passing to methods such as "
+ "session.execute(). This usage will be disallowed in a "
+ "future release. "
+ "Please use Core select() / update() / delete() etc. "
+ "with Session.execute() and other statement execution "
+ "methods." % original_element,
+ "1.4",
+ )
+
+ return resolved
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if resolved._is_lambda_element:
+ return resolved
+ else:
+ return super(StatementImpl, self)._implicit_coercions(
+ original_element, resolved, argname=argname, **kw
+ )
+
+ def _text_coercion(self, element, argname=None):
+ util.warn_deprecated_20(
+ "Using plain strings to indicate SQL statements without using "
+ "the text() construct is "
+ "deprecated and will be removed in version 2.0. Ensure plain "
+ "SQL statements are passed using the text() construct."
+ )
+ return elements.TextClause(element)
+
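+# Illustration of the deprecation path above (``engine`` is a placeholder
+# Engine; the warning is a RemovedIn20Warning, surfaced when 2.0 deprecation
+# mode, e.g. SQLALCHEMY_WARN_20, is enabled):
+#
+#     from sqlalchemy import text
+#
+#     with engine.connect() as conn:
+#         conn.execute("SELECT 1")        # deprecated plain-string form
+#         conn.execute(text("SELECT 1"))  # supported spelling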
+
+class SelectStatementImpl(_NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if resolved._is_text_clause:
+ return resolved.columns()
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+
+class HasCTEImpl(ReturnsRowsImpl):
+ __slots__ = ()
+
+
+class IsCTEImpl(RoleImpl):
+ __slots__ = ()
+
+
+class JoinTargetImpl(RoleImpl):
+ __slots__ = ()
+
+ _skip_clauseelement_for_target_match = True
+
+ def _literal_coercion(self, element, legacy=False, **kw):
+ if isinstance(element, str):
+ return element
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, legacy=False, **kw
+ ):
+ if isinstance(original_element, roles.JoinTargetRole):
+ # note that this codepath no longer occurs as of
+ # #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match
+ # were set to False.
+ return original_element
+ elif legacy and isinstance(resolved, str):
+ util.warn_deprecated_20(
+ "Using strings to indicate relationship names in "
+ "Query.join() is deprecated and will be removed in "
+ "SQLAlchemy 2.0. Please use the class-bound attribute "
+ "directly."
+ )
+ return resolved
+ elif legacy and isinstance(resolved, roles.WhereHavingRole):
+ return resolved
+ elif legacy and resolved._is_select_statement:
+ util.warn_deprecated(
+ "Implicit coercion of SELECT and textual SELECT "
+ "constructs into FROM clauses is deprecated; please call "
+ ".subquery() on any Core select or ORM Query object in "
+ "order to produce a subquery object.",
+ version="1.4",
+ )
+ # TODO: doing _implicit_subquery here causes tests to fail,
+ # how was this working before? probably because the ORM
+ # join logic treated it as a select and the subquery would happen
+ # in _ORMJoin->Join
+ return resolved
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+
+class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self,
+ original_element,
+ resolved,
+ argname=None,
+ explicit_subquery=False,
+ allow_select=True,
+ **kw
+ ):
+ if resolved._is_select_statement:
+ if explicit_subquery:
+ return resolved.subquery()
+ elif allow_select:
+ util.warn_deprecated(
+ "Implicit coercion of SELECT and textual SELECT "
+ "constructs into FROM clauses is deprecated; please call "
+ ".subquery() on any Core select or ORM Query object in "
+ "order to produce a subquery object.",
+ version="1.4",
+ )
+ return resolved._implicit_subquery
+ elif resolved._is_text_clause:
+ return resolved
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+ def _post_coercion(self, element, deannotate=False, **kw):
+ if deannotate:
+ return element._deannotate()
+ else:
+ return element
+
+
+class StrictFromClauseImpl(FromClauseImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self,
+ original_element,
+ resolved,
+ argname=None,
+ allow_select=False,
+ **kw
+ ):
+ if resolved._is_select_statement and allow_select:
+ util.warn_deprecated(
+ "Implicit coercion of SELECT and textual SELECT constructs "
+ "into FROM clauses is deprecated; please call .subquery() "
+ "on any Core select or ORM Query object in order to produce a "
+ "subquery object.",
+ version="1.4",
+ )
+ return resolved._implicit_subquery
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+
+class AnonymizedFromClauseImpl(StrictFromClauseImpl):
+ __slots__ = ()
+
+ def _post_coercion(self, element, flat=False, name=None, **kw):
+ assert name is None
+
+ return element._anonymous_fromclause(flat=flat)
+
+
+class DMLTableImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+ def _post_coercion(self, element, **kw):
+ if "dml_table" in element._annotations:
+ return element._annotations["dml_table"]
+ else:
+ return element
+
+
+class DMLSelectImpl(_NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if resolved._is_from_clause:
+ if (
+ isinstance(resolved, selectable.Alias)
+ and resolved.element._is_select_statement
+ ):
+ return resolved.element
+ else:
+ return resolved.select()
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+
+class CompoundElementImpl(_NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+ def _raise_for_expected(self, element, argname=None, resolved=None, **kw):
+ if isinstance(element, roles.FromClauseRole):
+ if element._is_subquery:
+ advice = (
+ "Use the plain select() object without "
+ "calling .subquery() or .alias()."
+ )
+ else:
+ advice = (
+ "To SELECT from any FROM clause, use the .select() method."
+ )
+ else:
+ advice = None
+ return super(CompoundElementImpl, self)._raise_for_expected(
+ element, argname=argname, resolved=resolved, advice=advice, **kw
+ )
+
+
+_impl_lookup = {}
+
+
+for name in dir(roles):
+ cls = getattr(roles, name)
+ if name.endswith("Role"):
+ name = name.replace("Role", "Impl")
+ if name in globals():
+ impl = globals()[name](cls)
+ _impl_lookup[cls] = impl
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
new file mode 100644
index 0000000..c9b6ba6
--- /dev/null
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -0,0 +1,5525 @@
+# sql/compiler.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Base SQL and DDL compiler implementations.
+
+Classes provided include:
+
+:class:`.compiler.SQLCompiler` - renders SQL
+strings
+
+:class:`.compiler.DDLCompiler` - renders DDL
+(data definition language) strings
+
+:class:`.compiler.GenericTypeCompiler` - renders
+type specification strings.
+
+To generate user-defined SQL strings, see
+:doc:`/ext/compiler`.
+
+"""
+
+import collections
+import contextlib
+import itertools
+import operator
+import re
+
+from . import base
+from . import coercions
+from . import crud
+from . import elements
+from . import functions
+from . import operators
+from . import schema
+from . import selectable
+from . import sqltypes
+from .base import NO_ARG
+from .base import prefix_anon_map
+from .elements import quoted_name
+from .. import exc
+from .. import util
+
+RESERVED_WORDS = set(
+ [
+ "all",
+ "analyse",
+ "analyze",
+ "and",
+ "any",
+ "array",
+ "as",
+ "asc",
+ "asymmetric",
+ "authorization",
+ "between",
+ "binary",
+ "both",
+ "case",
+ "cast",
+ "check",
+ "collate",
+ "column",
+ "constraint",
+ "create",
+ "cross",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "default",
+ "deferrable",
+ "desc",
+ "distinct",
+ "do",
+ "else",
+ "end",
+ "except",
+ "false",
+ "for",
+ "foreign",
+ "freeze",
+ "from",
+ "full",
+ "grant",
+ "group",
+ "having",
+ "ilike",
+ "in",
+ "initially",
+ "inner",
+ "intersect",
+ "into",
+ "is",
+ "isnull",
+ "join",
+ "leading",
+ "left",
+ "like",
+ "limit",
+ "localtime",
+ "localtimestamp",
+ "natural",
+ "new",
+ "not",
+ "notnull",
+ "null",
+ "off",
+ "offset",
+ "old",
+ "on",
+ "only",
+ "or",
+ "order",
+ "outer",
+ "overlaps",
+ "placing",
+ "primary",
+ "references",
+ "right",
+ "select",
+ "session_user",
+ "set",
+ "similar",
+ "some",
+ "symmetric",
+ "table",
+ "then",
+ "to",
+ "trailing",
+ "true",
+ "union",
+ "unique",
+ "user",
+ "using",
+ "verbose",
+ "when",
+ "where",
+ ]
+)
+
+LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I)
+LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I)
+ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"])
+
+FK_ON_DELETE = re.compile(
+ r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I
+)
+FK_ON_UPDATE = re.compile(
+ r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I
+)
+FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I)
+BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
+BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)
+
+BIND_TEMPLATES = {
+ "pyformat": "%%(%(name)s)s",
+ "qmark": "?",
+ "format": "%%s",
+ "numeric": ":[_POSITION]",
+ "named": ":%(name)s",
+}
+
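+# Sketch of how the templates above surface in compiled SQL: the same bound
+# parameter renders according to the dialect's paramstyle (the dialects shown
+# are only illustrative):
+#
+#     from sqlalchemy import bindparam, column, select
+#     from sqlalchemy.dialects import postgresql, sqlite
+#
+#     stmt = select(column("x")).where(column("x") == bindparam("v"))
+#     print(stmt.compile(dialect=sqlite.dialect()))      # qmark:    ... WHERE x = ?
+#     print(stmt.compile(dialect=postgresql.dialect()))  # pyformat: ... WHERE x = %(v)s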
+_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]")
+_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__"))
+
+OPERATORS = {
+ # binary
+ operators.and_: " AND ",
+ operators.or_: " OR ",
+ operators.add: " + ",
+ operators.mul: " * ",
+ operators.sub: " - ",
+ operators.div: " / ",
+ operators.mod: " % ",
+ operators.truediv: " / ",
+ operators.neg: "-",
+ operators.lt: " < ",
+ operators.le: " <= ",
+ operators.ne: " != ",
+ operators.gt: " > ",
+ operators.ge: " >= ",
+ operators.eq: " = ",
+ operators.is_distinct_from: " IS DISTINCT FROM ",
+ operators.is_not_distinct_from: " IS NOT DISTINCT FROM ",
+ operators.concat_op: " || ",
+ operators.match_op: " MATCH ",
+ operators.not_match_op: " NOT MATCH ",
+ operators.in_op: " IN ",
+ operators.not_in_op: " NOT IN ",
+ operators.comma_op: ", ",
+ operators.from_: " FROM ",
+ operators.as_: " AS ",
+ operators.is_: " IS ",
+ operators.is_not: " IS NOT ",
+ operators.collate: " COLLATE ",
+ # unary
+ operators.exists: "EXISTS ",
+ operators.distinct_op: "DISTINCT ",
+ operators.inv: "NOT ",
+ operators.any_op: "ANY ",
+ operators.all_op: "ALL ",
+ # modifiers
+ operators.desc_op: " DESC",
+ operators.asc_op: " ASC",
+ operators.nulls_first_op: " NULLS FIRST",
+ operators.nulls_last_op: " NULLS LAST",
+}
+
+FUNCTIONS = {
+ functions.coalesce: "coalesce",
+ functions.current_date: "CURRENT_DATE",
+ functions.current_time: "CURRENT_TIME",
+ functions.current_timestamp: "CURRENT_TIMESTAMP",
+ functions.current_user: "CURRENT_USER",
+ functions.localtime: "LOCALTIME",
+ functions.localtimestamp: "LOCALTIMESTAMP",
+ functions.random: "random",
+ functions.sysdate: "sysdate",
+ functions.session_user: "SESSION_USER",
+ functions.user: "USER",
+ functions.cube: "CUBE",
+ functions.rollup: "ROLLUP",
+ functions.grouping_sets: "GROUPING SETS",
+}
+
+EXTRACT_MAP = {
+ "month": "month",
+ "day": "day",
+ "year": "year",
+ "second": "second",
+ "hour": "hour",
+ "doy": "doy",
+ "minute": "minute",
+ "quarter": "quarter",
+ "dow": "dow",
+ "week": "week",
+ "epoch": "epoch",
+ "milliseconds": "milliseconds",
+ "microseconds": "microseconds",
+ "timezone_hour": "timezone_hour",
+ "timezone_minute": "timezone_minute",
+}
+
+COMPOUND_KEYWORDS = {
+ selectable.CompoundSelect.UNION: "UNION",
+ selectable.CompoundSelect.UNION_ALL: "UNION ALL",
+ selectable.CompoundSelect.EXCEPT: "EXCEPT",
+ selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL",
+ selectable.CompoundSelect.INTERSECT: "INTERSECT",
+ selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL",
+}
+
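+# The keyword rendered is selected by the compound construct in use, e.g.
+# (``users`` and ``addresses`` are placeholder Tables):
+#
+#     from sqlalchemy import select, union_all
+#
+#     stmt = union_all(select(users.c.id), select(addresses.c.user_id))
+#     # renders: SELECT ... UNION ALL SELECT ...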
+
+RM_RENDERED_NAME = 0
+RM_NAME = 1
+RM_OBJECTS = 2
+RM_TYPE = 3
+
+
+ExpandedState = collections.namedtuple(
+ "ExpandedState",
+ [
+ "statement",
+ "additional_parameters",
+ "processors",
+ "positiontup",
+ "parameter_expansion",
+ ],
+)
+
+
+NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0)
+
+COLLECT_CARTESIAN_PRODUCTS = util.symbol(
+ "COLLECT_CARTESIAN_PRODUCTS",
+ "Collect data on FROMs and cartesian products and gather "
+ "into 'self.from_linter'",
+ canonical=1,
+)
+
+WARN_LINTING = util.symbol(
+ "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2
+)
+
+FROM_LINTING = util.symbol(
+ "FROM_LINTING",
+ "Warn for cartesian products; "
+ "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING",
+ canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING,
+)
+
+
+class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
+ def lint(self, start=None):
+ froms = self.froms
+ if not froms:
+ return None, None
+
+ edges = set(self.edges)
+ the_rest = set(froms)
+
+ if start is not None:
+ start_with = start
+ the_rest.remove(start_with)
+ else:
+ start_with = the_rest.pop()
+
+ stack = collections.deque([start_with])
+
+ while stack and the_rest:
+ node = stack.popleft()
+ the_rest.discard(node)
+
+ # comparison of nodes in edges here is based on hash equality, as
+ # there are "annotated" elements that match the non-annotated ones.
+ # to remove the need for in-python hash() calls, use native
+ # containment routines (e.g. "node in edge", "edge.index(node)")
+ to_remove = {edge for edge in edges if node in edge}
+
+ # appendleft the node in each edge that is not
+ # the one that matched.
+ stack.extendleft(edge[not edge.index(node)] for edge in to_remove)
+ edges.difference_update(to_remove)
+
+ # FROMS left over? boom
+ if the_rest:
+ return the_rest, start_with
+ else:
+ return None, None
+
+ def warn(self):
+ the_rest, start_with = self.lint()
+
+ # FROMS left over? boom
+ if the_rest:
+
+ froms = the_rest
+ if froms:
+ template = (
+ "SELECT statement has a cartesian product between "
+ "FROM element(s) {froms} and "
+ 'FROM element "{start}". Apply join condition(s) '
+ "between each element to resolve."
+ )
+ froms_str = ", ".join(
+ '"{elem}"'.format(elem=self.froms[from_])
+ for from_ in froms
+ )
+ message = template.format(
+ froms=froms_str, start=self.froms[start_with]
+ )
+
+ util.warn(message)
+
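+# Sketch of the condition the linter warns about (``users`` and ``addresses``
+# are placeholder Tables): two FROM elements with no join condition between
+# them trigger the cartesian product warning when the statement is executed
+# with FROM linting enabled:
+#
+#     from sqlalchemy import select
+#
+#     stmt = select(users.c.name, addresses.c.email_address)
+#     # SAWarning: SELECT statement has a cartesian product between
+#     # FROM element(s) ... Apply join condition(s) between each element to resolve.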
+
+class Compiled(object):
+
+ """Represent a compiled SQL or DDL expression.
+
+ The ``__str__`` method of the ``Compiled`` object should produce
+ the actual text of the statement. ``Compiled`` objects are
+ specific to their underlying database dialect, and also may
+ or may not be specific to the columns referenced within a
+ particular set of bind parameters. In no case should the
+ ``Compiled`` object be dependent on the actual values of those
+ bind parameters, even though it may reference those values as
+ defaults.
+ """
+
+ _cached_metadata = None
+
+ _result_columns = None
+
+ schema_translate_map = None
+
+ execution_options = util.EMPTY_DICT
+ """
+ Execution options propagated from the statement. In some cases,
+ sub-elements of the statement can modify these.
+ """
+
+ _annotations = util.EMPTY_DICT
+
+ compile_state = None
+ """Optional :class:`.CompileState` object that maintains additional
+ state used by the compiler.
+
+ Major executable objects such as :class:`_expression.Insert`,
+ :class:`_expression.Update`, :class:`_expression.Delete`,
+ :class:`_expression.Select` will generate this
+ state when compiled in order to calculate additional information about the
+ object. For the top level object that is to be executed, the state can be
+ stored here where it can also have applicability towards result set
+ processing.
+
+ .. versionadded:: 1.4
+
+ """
+
+ dml_compile_state = None
+ """Optional :class:`.CompileState` assigned at the same point that
+ .isinsert, .isupdate, or .isdelete is assigned.
+
+ This will normally be the same object as .compile_state, with the
+ exception of cases like the :class:`.ORMFromStatementCompileState`
+ object.
+
+ .. versionadded:: 1.4.40
+
+ """
+
+ cache_key = None
+ _gen_time = None
+
+ def __init__(
+ self,
+ dialect,
+ statement,
+ schema_translate_map=None,
+ render_schema_translate=False,
+ compile_kwargs=util.immutabledict(),
+ ):
+ """Construct a new :class:`.Compiled` object.
+
+ :param dialect: :class:`.Dialect` to compile against.
+
+ :param statement: :class:`_expression.ClauseElement` to be compiled.
+
+ :param schema_translate_map: dictionary of schema names to be
+ translated when forming the resultant SQL
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`schema_translating`
+
+ :param compile_kwargs: additional kwargs that will be
+ passed to the initial call to :meth:`.Compiled.process`.
+
+
+ """
+
+ self.dialect = dialect
+ self.preparer = self.dialect.identifier_preparer
+ if schema_translate_map:
+ self.schema_translate_map = schema_translate_map
+ self.preparer = self.preparer._with_schema_translate(
+ schema_translate_map
+ )
+
+ if statement is not None:
+ self.statement = statement
+ self.can_execute = statement.supports_execution
+ self._annotations = statement._annotations
+ if self.can_execute:
+ self.execution_options = statement._execution_options
+ self.string = self.process(self.statement, **compile_kwargs)
+
+ if render_schema_translate:
+ self.string = self.preparer._render_schema_translates(
+ self.string, schema_translate_map
+ )
+ self._gen_time = util.perf_counter()
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ if self.can_execute:
+ return connection._execute_compiled(
+ self, multiparams, params, execution_options
+ )
+ else:
+ raise exc.ObjectNotExecutableError(self.statement)
+
+ def visit_unsupported_compilation(self, element, err):
+ util.raise_(
+ exc.UnsupportedCompilationError(self, type(element)),
+ replace_context=err,
+ )
+
+ @property
+ def sql_compiler(self):
+ """Return a Compiled that is capable of processing SQL expressions.
+
+ If this compiler is one, it would likely just return 'self'.
+
+ """
+
+ raise NotImplementedError()
+
+ def process(self, obj, **kwargs):
+ return obj._compiler_dispatch(self, **kwargs)
+
+ def __str__(self):
+ """Return the string text of the generated SQL or DDL."""
+
+ return self.string or ""
+
+ def construct_params(
+ self, params=None, extracted_parameters=None, escape_names=True
+ ):
+ """Return the bind params for this compiled object.
+
+ :param params: a dict of string/object pairs whose values will
+ override bind values compiled in to the
+ statement.
+ """
+
+ raise NotImplementedError()
+
+ @property
+ def params(self):
+ """Return the bind params for this compiled object."""
+ return self.construct_params()
+
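+# Minimal usage sketch for the accessors above (``users`` is a placeholder
+# Table; the default dialect uses the "named" paramstyle):
+#
+#     from sqlalchemy import select
+#
+#     compiled = select(users).where(users.c.id == 7).compile()
+#     str(compiled)    # "SELECT ... WHERE users.id = :id_1"
+#     compiled.params  # {'id_1': 7}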
+
+class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
+ """Produces DDL specification for TypeEngine objects."""
+
+ ensure_kwarg = r"visit_\w+"
+
+ def __init__(self, dialect):
+ self.dialect = dialect
+
+ def process(self, type_, **kw):
+ return type_._compiler_dispatch(self, **kw)
+
+ def visit_unsupported_compilation(self, element, err, **kw):
+ util.raise_(
+ exc.UnsupportedCompilationError(self, element),
+ replace_context=err,
+ )
+
+
+# this was a Visitable, but to allow accurate detection of
+# column elements this is actually a column element
+class _CompileLabel(elements.ColumnElement):
+
+ """lightweight label object which acts as an expression.Label."""
+
+ __visit_name__ = "label"
+ __slots__ = "element", "name"
+
+ def __init__(self, col, name, alt_names=()):
+ self.element = col
+ self.name = name
+ self._alt_names = (col,) + alt_names
+
+ @property
+ def proxy_set(self):
+ return self.element.proxy_set
+
+ @property
+ def type(self):
+ return self.element.type
+
+ def self_group(self, **kw):
+ return self
+
+
+class SQLCompiler(Compiled):
+ """Default implementation of :class:`.Compiled`.
+
+ Compiles :class:`_expression.ClauseElement` objects into SQL strings.
+
+ """
+
+ extract_map = EXTRACT_MAP
+
+ compound_keywords = COMPOUND_KEYWORDS
+
+ isdelete = isinsert = isupdate = False
+ """class-level defaults which can be set at the instance
+ level to indicate whether this Compiled instance represents
+ INSERT/UPDATE/DELETE
+ """
+
+ isplaintext = False
+
+ returning = None
+ """holds the "returning" collection of columns if
+ the statement is CRUD and defines returning columns
+ either implicitly or explicitly
+ """
+
+ returning_precedes_values = False
+ """set to True classwide to generate RETURNING
+ clauses before the VALUES or WHERE clause (e.g. MSSQL)
+ """
+
+ render_table_with_column_in_update_from = False
+ """set to True classwide to indicate the SET clause
+ in a multi-table UPDATE statement should qualify
+ columns with the table name (MySQL only)
+ """
+
+ ansi_bind_rules = False
+ """SQL 92 doesn't allow bind parameters to be used
+ in the columns clause of a SELECT, nor does it allow
+ ambiguous expressions like "? = ?". A compiler
+ subclass can set this flag to True if the target
+ driver/DB enforces this.
+ """
+
+ _textual_ordered_columns = False
+ """tell the result object that the column names as rendered are important,
+ but they are also "ordered" vs. what is in the compiled object here.
+ """
+
+ _ordered_columns = True
+ """
+ if False, we can't be sure the list of entries
+ in _result_columns is actually the rendered order. Usually
+ True unless using an unordered TextualSelect.
+ """
+
+ _loose_column_name_matching = False
+ """tell the result object that the SQL statement is textual, wants to match
+ up to Column objects, and may be using the ._tq_label in the SELECT rather
+ than the base name.
+
+ """
+
+ _numeric_binds = False
+ """
+ True if paramstyle is "numeric". This paramstyle is trickier than
+ all the others.
+
+ """
+
+ _render_postcompile = False
+ """
+ whether to render out POSTCOMPILE params during the compile phase.
+
+ """
+
+ insert_single_values_expr = None
+ """When an INSERT is compiled with a single set of parameters inside
+ a VALUES expression, the string is assigned here, where it can be
+ used for insert batching schemes to rewrite the VALUES expression.
+
+ .. versionadded:: 1.3.8
+
+ """
+
+ literal_execute_params = frozenset()
+ """bindparameter objects that are rendered as literal values at statement
+ execution time.
+
+ """
+
+ post_compile_params = frozenset()
+ """bindparameter objects that are rendered as bound parameter placeholders
+ at statement execution time.
+
+ """
+
+ escaped_bind_names = util.EMPTY_DICT
+ """Late escaping of bound parameter names that has to be converted
+ to the original name when looking in the parameter dictionary.
+
+ """
+
+ has_out_parameters = False
+ """if True, there are bindparam() objects that have the isoutparam
+ flag set."""
+
+ insert_prefetch = update_prefetch = ()
+
+ postfetch_lastrowid = False
+ """if True, and this in insert, use cursor.lastrowid to populate
+ result.inserted_primary_key. """
+
+ _cache_key_bind_match = None
+ """a mapping that will relate the BindParameter object we compile
+ to those that are part of the extracted collection of parameters
+ in the cache key, if we were given a cache key.
+
+ """
+
+ positiontup = None
+ """for a compiled construct that uses a positional paramstyle, will be
+ a sequence of strings, indicating the names of bound parameters in order.
+
+ This is used in order to render bound parameters in their correct order,
+ and is combined with the :attr:`_sql.Compiled.params` dictionary to
+ render parameters.
+
+ .. seealso::
+
+ :ref:`faq_sql_expression_string` - includes a usage example for
+ debugging use cases.
+
+ """
+
+ inline = False
+
+ def __init__(
+ self,
+ dialect,
+ statement,
+ cache_key=None,
+ column_keys=None,
+ for_executemany=False,
+ linting=NO_LINTING,
+ **kwargs
+ ):
+ """Construct a new :class:`.SQLCompiler` object.
+
+ :param dialect: :class:`.Dialect` to be used
+
+ :param statement: :class:`_expression.ClauseElement` to be compiled
+
+ :param column_keys: a list of column names to be compiled into an
+ INSERT or UPDATE statement.
+
+ :param for_executemany: whether INSERT / UPDATE statements should
+ expect that they are to be invoked in an "executemany" style,
+ which may impact how the statement will be expected to return the
+ values of defaults and autoincrement / sequences and similar.
+ Depending on the backend and driver in use, support for retrieving
+ these values may be disabled which means SQL expressions may
+ be rendered inline, RETURNING may not be rendered, etc.
+
+ :param kwargs: additional keyword arguments to be consumed by the
+ superclass.
+
+ """
+ self.column_keys = column_keys
+
+ self.cache_key = cache_key
+
+ if cache_key:
+ self._cache_key_bind_match = ckbm = {
+ b.key: b for b in cache_key[1]
+ }
+ ckbm.update({b: [b] for b in cache_key[1]})
+
+ # compile INSERT/UPDATE defaults/sequences to expect executemany
+ # style execution, which may mean no pre-execute of defaults,
+ # or no RETURNING
+ self.for_executemany = for_executemany
+
+ self.linting = linting
+
+ # a dictionary of bind parameter keys to BindParameter
+ # instances.
+ self.binds = {}
+
+ # a dictionary of BindParameter instances to "compiled" names
+ # that are actually present in the generated SQL
+ self.bind_names = util.column_dict()
+
+ # stack which keeps track of nested SELECT statements
+ self.stack = []
+
+ # relates label names in the final SQL to a tuple of local
+ # column/label name, ColumnElement object (if any) and
+ # TypeEngine. CursorResult uses this for type processing and
+ # column targeting
+ self._result_columns = []
+
+ # true if the paramstyle is positional
+ self.positional = dialect.positional
+ if self.positional:
+ self.positiontup = []
+ self._numeric_binds = dialect.paramstyle == "numeric"
+ self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
+
+ self.ctes = None
+
+ self.label_length = (
+ dialect.label_length or dialect.max_identifier_length
+ )
+
+ # a map which tracks "anonymous" identifiers that are created on
+ # the fly here
+ self.anon_map = prefix_anon_map()
+
+ # a map which tracks "truncated" names based on
+ # dialect.label_length or dialect.max_identifier_length
+ self.truncated_names = {}
+
+ Compiled.__init__(self, dialect, statement, **kwargs)
+
+ if self.isinsert or self.isupdate or self.isdelete:
+ if statement._returning:
+ self.returning = statement._returning
+
+ if self.isinsert or self.isupdate:
+ if statement._inline:
+ self.inline = True
+ elif self.for_executemany and (
+ not self.isinsert
+ or (
+ self.dialect.insert_executemany_returning
+ and statement._return_defaults
+ )
+ ):
+ self.inline = True
+
+ if self.positional and self._numeric_binds:
+ self._apply_numbered_params()
+
+ if self._render_postcompile:
+ self._process_parameters_for_postcompile(_populate_self=True)
+
+ @property
+ def current_executable(self):
+ """Return the current 'executable' that is being compiled.
+
+ This is currently the :class:`_sql.Select`, :class:`_sql.Insert`,
+ :class:`_sql.Update`, :class:`_sql.Delete`,
+ :class:`_sql.CompoundSelect` object that is being compiled.
+ Specifically it's assigned to the ``self.stack`` list of elements.
+
+ When a statement like the above is being compiled, it normally
+ is also assigned to the ``.statement`` attribute of the
+ :class:`_sql.Compiler` object. However, all SQL constructs are
+ ultimately nestable, and this attribute should never be consulted
+ by a ``visit_`` method, as it is not guaranteed to be assigned
+ nor guaranteed to correspond to the current statement being compiled.
+
+ .. versionadded:: 1.3.21
+
+ For compatibility with previous versions, use the following
+ recipe::
+
+ statement = getattr(self, "current_executable", False)
+ if statement is False:
+ statement = self.stack[-1]["selectable"]
+
+ For versions 1.4 and above, ensure only .current_executable
+ is used; the format of "self.stack" may change.
+
+
+ """
+ try:
+ return self.stack[-1]["selectable"]
+ except IndexError as ie:
+ util.raise_(
+ IndexError("Compiler does not have a stack entry"),
+ replace_context=ie,
+ )
+
+ @property
+ def prefetch(self):
+ return list(self.insert_prefetch + self.update_prefetch)
+
+ @util.memoized_property
+ def _global_attributes(self):
+ return {}
+
+ @util.memoized_instancemethod
+ def _init_cte_state(self):
+ """Initialize collections related to CTEs only if
+ a CTE is located, to save on the overhead of
+ these collections otherwise.
+
+ """
+ # collect CTEs to tack on top of a SELECT
+ # To store the query to print - Dict[cte, text_query]
+ self.ctes = util.OrderedDict()
+
+ # Detect same CTE references - Dict[(level, name), cte]
+ # Level is required for supporting nesting
+ self.ctes_by_level_name = {}
+
+ # To retrieve key/level in ctes_by_level_name -
+ # Dict[cte_reference, (level, cte_name)]
+ self.level_name_by_cte = {}
+
+ self.ctes_recursive = False
+ if self.positional:
+ self.cte_positional = {}
+
+ @contextlib.contextmanager
+ def _nested_result(self):
+ """special API to support the use case of 'nested result sets'"""
+ result_columns, ordered_columns = (
+ self._result_columns,
+ self._ordered_columns,
+ )
+ self._result_columns, self._ordered_columns = [], False
+
+ try:
+ if self.stack:
+ entry = self.stack[-1]
+ entry["need_result_map_for_nested"] = True
+ else:
+ entry = None
+ yield self._result_columns, self._ordered_columns
+ finally:
+ if entry:
+ entry.pop("need_result_map_for_nested")
+ self._result_columns, self._ordered_columns = (
+ result_columns,
+ ordered_columns,
+ )
+
+ def _apply_numbered_params(self):
+ poscount = itertools.count(1)
+ self.string = re.sub(
+ r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string
+ )
+
+ @util.memoized_property
+ def _bind_processors(self):
+
+ return dict(
+ (
+ key,
+ value,
+ )
+ for key, value in (
+ (
+ self.bind_names[bindparam],
+ bindparam.type._cached_bind_processor(self.dialect)
+ if not bindparam.type._is_tuple_type
+ else tuple(
+ elem_type._cached_bind_processor(self.dialect)
+ for elem_type in bindparam.type.types
+ ),
+ )
+ for bindparam in self.bind_names
+ )
+ if value is not None
+ )
+
+ def is_subquery(self):
+ return len(self.stack) > 1
+
+ @property
+ def sql_compiler(self):
+ return self
+
+ def construct_params(
+ self,
+ params=None,
+ _group_number=None,
+ _check=True,
+ extracted_parameters=None,
+ escape_names=True,
+ ):
+ """return a dictionary of bind parameter keys and values"""
+
+ has_escaped_names = escape_names and bool(self.escaped_bind_names)
+
+ if extracted_parameters:
+ # related the bound parameters collected in the original cache key
+ # to those collected in the incoming cache key. They will not have
+ # matching names but they will line up positionally in the same
+ # way. The parameters present in self.bind_names may be clones of
+ # these original cache key params in the case of DML but the .key
+ # will be guaranteed to match.
+ try:
+ orig_extracted = self.cache_key[1]
+ except TypeError as err:
+ util.raise_(
+ exc.CompileError(
+ "This compiled object has no original cache key; "
+ "can't pass extracted_parameters to construct_params"
+ ),
+ replace_context=err,
+ )
+
+ ckbm = self._cache_key_bind_match
+ resolved_extracted = {
+ bind: extracted
+ for b, extracted in zip(orig_extracted, extracted_parameters)
+ for bind in ckbm[b]
+ }
+ else:
+ resolved_extracted = None
+
+ if params:
+ pd = {}
+ for bindparam, name in self.bind_names.items():
+ escaped_name = (
+ self.escaped_bind_names.get(name, name)
+ if has_escaped_names
+ else name
+ )
+
+ if bindparam.key in params:
+ pd[escaped_name] = params[bindparam.key]
+ elif name in params:
+ pd[escaped_name] = params[name]
+
+ elif _check and bindparam.required:
+ if _group_number:
+ raise exc.InvalidRequestError(
+ "A value is required for bind parameter %r, "
+ "in parameter group %d"
+ % (bindparam.key, _group_number),
+ code="cd3x",
+ )
+ else:
+ raise exc.InvalidRequestError(
+ "A value is required for bind parameter %r"
+ % bindparam.key,
+ code="cd3x",
+ )
+ else:
+ if resolved_extracted:
+ value_param = resolved_extracted.get(
+ bindparam, bindparam
+ )
+ else:
+ value_param = bindparam
+
+ if bindparam.callable:
+ pd[escaped_name] = value_param.effective_value
+ else:
+ pd[escaped_name] = value_param.value
+ return pd
+ else:
+ pd = {}
+ for bindparam, name in self.bind_names.items():
+ escaped_name = (
+ self.escaped_bind_names.get(name, name)
+ if has_escaped_names
+ else name
+ )
+
+ if _check and bindparam.required:
+ if _group_number:
+ raise exc.InvalidRequestError(
+ "A value is required for bind parameter %r, "
+ "in parameter group %d"
+ % (bindparam.key, _group_number),
+ code="cd3x",
+ )
+ else:
+ raise exc.InvalidRequestError(
+ "A value is required for bind parameter %r"
+ % bindparam.key,
+ code="cd3x",
+ )
+
+ if resolved_extracted:
+ value_param = resolved_extracted.get(bindparam, bindparam)
+ else:
+ value_param = bindparam
+
+ if bindparam.callable:
+ pd[escaped_name] = value_param.effective_value
+ else:
+ pd[escaped_name] = value_param.value
+ return pd
+
+ @util.memoized_instancemethod
+ def _get_set_input_sizes_lookup(
+ self, include_types=None, exclude_types=None
+ ):
+ if not hasattr(self, "bind_names"):
+ return None
+
+ dialect = self.dialect
+ dbapi = self.dialect.dbapi
+
+ # _unwrapped_dialect_impl() is necessary so that we get the
+ # correct dialect type for a custom TypeDecorator, or a Variant,
+ # which is also a TypeDecorator. Special types like Interval,
+ # that use TypeDecorator but also might be mapped directly
+ # for a dialect impl, also subclass Emulated first which overrides
+ # this behavior in those cases to behave like the default.
+
+ if include_types is None and exclude_types is None:
+
+ def _lookup_type(typ):
+ dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi)
+ return dbtype
+
+ else:
+
+ def _lookup_type(typ):
+ # note we get dbtype from the possibly TypeDecorator-wrapped
+ # dialect_impl, but the dialect_impl itself that we use for
+ # include/exclude is the unwrapped version.
+
+ dialect_impl = typ._unwrapped_dialect_impl(dialect)
+
+ dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi)
+
+ if (
+ dbtype is not None
+ and (
+ exclude_types is None
+ or dbtype not in exclude_types
+ and type(dialect_impl) not in exclude_types
+ )
+ and (
+ include_types is None
+ or dbtype in include_types
+ or type(dialect_impl) in include_types
+ )
+ ):
+ return dbtype
+ else:
+ return None
+
+ inputsizes = {}
+ literal_execute_params = self.literal_execute_params
+
+ for bindparam in self.bind_names:
+ if bindparam in literal_execute_params:
+ continue
+
+ if bindparam.type._is_tuple_type:
+ inputsizes[bindparam] = [
+ _lookup_type(typ) for typ in bindparam.type.types
+ ]
+ else:
+ inputsizes[bindparam] = _lookup_type(bindparam.type)
+
+ return inputsizes
+
+ @property
+ def params(self):
+ """Return the bind param dictionary embedded into this
+ compiled object, for those values that are present.
+
+ .. seealso::
+
+ :ref:`faq_sql_expression_string` - includes a usage example for
+ debugging use cases.
+
+ """
+ return self.construct_params(_check=False)
+
+ def _process_parameters_for_postcompile(
+ self, parameters=None, _populate_self=False
+ ):
+ """handle special post compile parameters.
+
+ These include:
+
+ * "expanding" parameters -typically IN tuples that are rendered
+ on a per-parameter basis for an otherwise fixed SQL statement string.
+
+ * literal_binds compiled with the literal_execute flag. Used for
+ things like SQL Server "TOP N" where the driver does not accommodate
+ N as a bound parameter.
+
+ """
+
+ if parameters is None:
+ parameters = self.construct_params(escape_names=False)
+
+ expanded_parameters = {}
+ if self.positional:
+ positiontup = []
+ else:
+ positiontup = None
+
+ processors = self._bind_processors
+
+ new_processors = {}
+
+ if self.positional and self._numeric_binds:
+ # I'm not familiar with any DBAPI that uses 'numeric'.
+ # The strategy would likely be to make use of numbers greater than
+ # the highest number present; then for expanding parameters,
+ # append them to the end of the parameter list. That way
+ # we avoid having to renumber all the existing parameters.
+ raise NotImplementedError(
+ "'post-compile' bind parameters are not supported with "
+ "the 'numeric' paramstyle at this time."
+ )
+
+ replacement_expressions = {}
+ to_update_sets = {}
+
+ # notes:
+ # *unescaped* parameter names in:
+ # self.bind_names, self.binds, self._bind_processors
+ #
+ # *escaped* parameter names in:
+ # construct_params(), replacement_expressions
+
+ for name in (
+ self.positiontup if self.positional else self.bind_names.values()
+ ):
+ escaped_name = (
+ self.escaped_bind_names.get(name, name)
+ if self.escaped_bind_names
+ else name
+ )
+
+ parameter = self.binds[name]
+ if parameter in self.literal_execute_params:
+ if escaped_name not in replacement_expressions:
+ value = parameters.pop(name)
+
+ replacement_expressions[
+ escaped_name
+ ] = self.render_literal_bindparam(
+ parameter, render_literal_value=value
+ )
+ continue
+
+ if parameter in self.post_compile_params:
+ if escaped_name in replacement_expressions:
+ to_update = to_update_sets[escaped_name]
+ else:
+ # we are removing the parameter from parameters
+ # because it is a list value, which is not expected by
+ # TypeEngine objects that would otherwise be asked to
+ # process it. the single name is being replaced with
+ # individual numbered parameters for each value in the
+ # param.
+ #
+ # note we are also inserting *escaped* parameter names
+ # into the given dictionary. default dialect will
+ # use these param names directly as they will not be
+ # in the escaped_bind_names dictionary.
+ values = parameters.pop(name)
+
+ leep = self._literal_execute_expanding_parameter
+ to_update, replacement_expr = leep(
+ escaped_name, parameter, values
+ )
+
+ to_update_sets[escaped_name] = to_update
+ replacement_expressions[escaped_name] = replacement_expr
+
+ if not parameter.literal_execute:
+ parameters.update(to_update)
+ if parameter.type._is_tuple_type:
+ new_processors.update(
+ (
+ "%s_%s_%s" % (name, i, j),
+ processors[name][j - 1],
+ )
+ for i, tuple_element in enumerate(values, 1)
+ for j, value in enumerate(tuple_element, 1)
+ if name in processors
+ and processors[name][j - 1] is not None
+ )
+ else:
+ new_processors.update(
+ (key, processors[name])
+ for key, value in to_update
+ if name in processors
+ )
+ if self.positional:
+ positiontup.extend(name for name, value in to_update)
+ expanded_parameters[name] = [
+ expand_key for expand_key, value in to_update
+ ]
+ elif self.positional:
+ positiontup.append(name)
+
+ def process_expanding(m):
+ key = m.group(1)
+ expr = replacement_expressions[key]
+
+ # if POSTCOMPILE included a bind_expression, render that
+ # around each element
+ if m.group(2):
+ tok = m.group(2).split("~~")
+ be_left, be_right = tok[1], tok[3]
+ expr = ", ".join(
+ "%s%s%s" % (be_left, exp, be_right)
+ for exp in expr.split(", ")
+ )
+ return expr
+
+ statement = re.sub(
+ r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]",
+ process_expanding,
+ self.string,
+ )
+
+ expanded_state = ExpandedState(
+ statement,
+ parameters,
+ new_processors,
+ positiontup,
+ expanded_parameters,
+ )
+
+ if _populate_self:
+ # this is for the "render_postcompile" flag, which is not
+ # otherwise used internally and is for end-user debugging and
+ # special use cases.
+ self.string = expanded_state.statement
+ self._bind_processors.update(expanded_state.processors)
+ self.positiontup = expanded_state.positiontup
+ self.post_compile_params = frozenset()
+ for key in expanded_state.parameter_expansion:
+ bind = self.binds.pop(key)
+ self.bind_names.pop(bind)
+ for value, expanded_key in zip(
+ bind.value, expanded_state.parameter_expansion[key]
+ ):
+ self.binds[expanded_key] = new_param = bind._with_value(
+ value
+ )
+ self.bind_names[new_param] = expanded_key
+
+ return expanded_state
+
+ @util.preload_module("sqlalchemy.engine.cursor")
+ def _create_result_map(self):
+ """utility method used for unit tests only."""
+ cursor = util.preloaded.engine_cursor
+ return cursor.CursorResultMetaData._create_description_match_map(
+ self._result_columns
+ )
+
+ @util.memoized_property
+ def _within_exec_param_key_getter(self):
+ getter = self._key_getters_for_crud_column[2]
+ return getter
+
+ @util.memoized_property
+ @util.preload_module("sqlalchemy.engine.result")
+ def _inserted_primary_key_from_lastrowid_getter(self):
+ result = util.preloaded.engine_result
+
+ param_key_getter = self._within_exec_param_key_getter
+ table = self.statement.table
+
+ getters = [
+ (operator.methodcaller("get", param_key_getter(col), None), col)
+ for col in table.primary_key
+ ]
+
+ autoinc_col = table._autoincrement_column
+ if autoinc_col is not None:
+ # apply type post processors to the lastrowid
+ proc = autoinc_col.type._cached_result_processor(
+ self.dialect, None
+ )
+ else:
+ proc = None
+
+ row_fn = result.result_tuple([col.key for col in table.primary_key])
+
+ def get(lastrowid, parameters):
+ """given cursor.lastrowid value and the parameters used for INSERT,
+ return a "row" that represents the primary key, either by
+ using the "lastrowid" or by extracting values from the parameters
+ that were sent along with the INSERT.
+
+ """
+ if proc is not None:
+ lastrowid = proc(lastrowid)
+
+ if lastrowid is None:
+ return row_fn(getter(parameters) for getter, col in getters)
+ else:
+ return row_fn(
+ lastrowid if col is autoinc_col else getter(parameters)
+ for getter, col in getters
+ )
+
+ return get
+
+ @util.memoized_property
+ @util.preload_module("sqlalchemy.engine.result")
+ def _inserted_primary_key_from_returning_getter(self):
+ result = util.preloaded.engine_result
+
+ param_key_getter = self._within_exec_param_key_getter
+ table = self.statement.table
+
+ ret = {col: idx for idx, col in enumerate(self.returning)}
+
+ getters = [
+ (operator.itemgetter(ret[col]), True)
+ if col in ret
+ else (
+ operator.methodcaller("get", param_key_getter(col), None),
+ False,
+ )
+ for col in table.primary_key
+ ]
+
+ row_fn = result.result_tuple([col.key for col in table.primary_key])
+
+ def get(row, parameters):
+ return row_fn(
+ getter(row) if use_row else getter(parameters)
+ for getter, use_row in getters
+ )
+
+ return get
+
+ def default_from(self):
+ """Called when a SELECT statement has no froms, and no FROM clause is
+ to be appended.
+
+ Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output.
+
+ """
+ return ""
+
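+ # editor's illustrative note (assumption based on the Oracle dialect):
+ # Oracle's compiler overrides default_from() to return " FROM DUAL", so
+ # a from-less statement such as select(literal_column("1")) renders as
+ # "SELECT 1 FROM DUAL" on that backend, while the default here renders
+ # just "SELECT 1".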
+ def visit_grouping(self, grouping, asfrom=False, **kwargs):
+ return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
+
+ def visit_select_statement_grouping(self, grouping, **kwargs):
+ return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
+
+ def visit_label_reference(
+ self, element, within_columns_clause=False, **kwargs
+ ):
+ if self.stack and self.dialect.supports_simple_order_by_label:
+ compile_state = self.stack[-1]["compile_state"]
+
+ (
+ with_cols,
+ only_froms,
+ only_cols,
+ ) = compile_state._label_resolve_dict
+ if within_columns_clause:
+ resolve_dict = only_froms
+ else:
+ resolve_dict = only_cols
+
+ # this can be None in the case that a _label_reference()
+ # was subject to a replacement operation, in which case
+ # the replacement of the Label element may have changed
+ # it to something else such as a ColumnClause expression.
+ order_by_elem = element.element._order_by_label_element
+
+ if (
+ order_by_elem is not None
+ and order_by_elem.name in resolve_dict
+ and order_by_elem.shares_lineage(
+ resolve_dict[order_by_elem.name]
+ )
+ ):
+ kwargs[
+ "render_label_as_label"
+ ] = element.element._order_by_label_element
+ return self.process(
+ element.element,
+ within_columns_clause=within_columns_clause,
+ **kwargs
+ )
+
+ def visit_textual_label_reference(
+ self, element, within_columns_clause=False, **kwargs
+ ):
+ if not self.stack:
+ # compiling the element outside of the context of a SELECT
+ return self.process(element._text_clause)
+
+ compile_state = self.stack[-1]["compile_state"]
+ with_cols, only_froms, only_cols = compile_state._label_resolve_dict
+ try:
+ if within_columns_clause:
+ col = only_froms[element.element]
+ else:
+ col = with_cols[element.element]
+ except KeyError as err:
+ coercions._no_text_coercion(
+ element.element,
+ extra=(
+ "Can't resolve label reference for ORDER BY / "
+ "GROUP BY / DISTINCT etc."
+ ),
+ exc_cls=exc.CompileError,
+ err=err,
+ )
+ else:
+ kwargs["render_label_as_label"] = col
+ return self.process(
+ col, within_columns_clause=within_columns_clause, **kwargs
+ )
+
+ def visit_label(
+ self,
+ label,
+ add_to_result_map=None,
+ within_label_clause=False,
+ within_columns_clause=False,
+ render_label_as_label=None,
+ result_map_targets=(),
+ **kw
+ ):
+ # only render labels within the columns clause
+ # or ORDER BY clause of a select. dialect-specific compilers
+ # can modify this behavior.
+ render_label_with_as = (
+ within_columns_clause and not within_label_clause
+ )
+ render_label_only = render_label_as_label is label
+
+ if render_label_only or render_label_with_as:
+ if isinstance(label.name, elements._truncated_label):
+ labelname = self._truncated_identifier("colident", label.name)
+ else:
+ labelname = label.name
+
+ if render_label_with_as:
+ if add_to_result_map is not None:
+ add_to_result_map(
+ labelname,
+ label.name,
+ (label, labelname) + label._alt_names + result_map_targets,
+ label.type,
+ )
+ return (
+ label.element._compiler_dispatch(
+ self,
+ within_columns_clause=True,
+ within_label_clause=True,
+ **kw
+ )
+ + OPERATORS[operators.as_]
+ + self.preparer.format_label(label, labelname)
+ )
+ elif render_label_only:
+ return self.preparer.format_label(label, labelname)
+ else:
+ return label.element._compiler_dispatch(
+ self, within_columns_clause=False, **kw
+ )
+
+ def _fallback_column_name(self, column):
+ raise exc.CompileError(
+ "Cannot compile Column object until " "its 'name' is assigned."
+ )
+
+ def visit_lambda_element(self, element, **kw):
+ sql_element = element._resolved
+ return self.process(sql_element, **kw)
+
+ def visit_column(
+ self,
+ column,
+ add_to_result_map=None,
+ include_table=True,
+ result_map_targets=(),
+ **kwargs
+ ):
+ name = orig_name = column.name
+ if name is None:
+ name = self._fallback_column_name(column)
+
+ is_literal = column.is_literal
+ if not is_literal and isinstance(name, elements._truncated_label):
+ name = self._truncated_identifier("colident", name)
+
+ if add_to_result_map is not None:
+ targets = (column, name, column.key) + result_map_targets
+ if column._tq_label:
+ targets += (column._tq_label,)
+
+ add_to_result_map(name, orig_name, targets, column.type)
+
+ if is_literal:
+ # note we are not currently accommodating
+ # literal_column(quoted_name('ident', True)) here
+ name = self.escape_literal_column(name)
+ else:
+ name = self.preparer.quote(name)
+ table = column.table
+ if table is None or not include_table or not table.named_with_column:
+ return name
+ else:
+ effective_schema = self.preparer.schema_for_object(table)
+
+ if effective_schema:
+ schema_prefix = (
+ self.preparer.quote_schema(effective_schema) + "."
+ )
+ else:
+ schema_prefix = ""
+ tablename = table.name
+ if isinstance(tablename, elements._truncated_label):
+ tablename = self._truncated_identifier("alias", tablename)
+
+ return schema_prefix + self.preparer.quote(tablename) + "." + name
+
+ def visit_collation(self, element, **kw):
+ return self.preparer.format_collation(element.collation)
+
+ def visit_fromclause(self, fromclause, **kwargs):
+ return fromclause.name
+
+ def visit_index(self, index, **kwargs):
+ return index.name
+
+ def visit_typeclause(self, typeclause, **kw):
+ kw["type_expression"] = typeclause
+ kw["identifier_preparer"] = self.preparer
+ return self.dialect.type_compiler.process(typeclause.type, **kw)
+
+ def post_process_text(self, text):
+ if self.preparer._double_percents:
+ text = text.replace("%", "%%")
+ return text
+
+ def escape_literal_column(self, text):
+ if self.preparer._double_percents:
+ text = text.replace("%", "%%")
+ return text
+
+ def visit_textclause(self, textclause, add_to_result_map=None, **kw):
+ def do_bindparam(m):
+ name = m.group(1)
+ if name in textclause._bindparams:
+ return self.process(textclause._bindparams[name], **kw)
+ else:
+ return self.bindparam_string(name, **kw)
+
+ if not self.stack:
+ self.isplaintext = True
+
+ if add_to_result_map:
+ # text() object is present in the columns clause of a
+ # select(). Add a no-name entry to the result map so that
+ # row[text()] produces a result
+ add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE)
+
+ # un-escape any \:params
+ return BIND_PARAMS_ESC.sub(
+ lambda m: m.group(1),
+ BIND_PARAMS.sub(
+ do_bindparam, self.post_process_text(textclause.text)
+ ),
+ )
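+ # editor's illustrative sketch (not part of the original source): with
+ # text("SELECT * FROM t WHERE name = :nm"), the BIND_PARAMS regex above
+ # turns :nm into a bound parameter, while a backslash-escaped "\:" such
+ # as in text(r"SELECT x FROM t WHERE tag = '\:raw'") is un-escaped back
+ # to a literal colon rather than treated as a parameter.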
+
+ def visit_textual_select(
+ self, taf, compound_index=None, asfrom=False, **kw
+ ):
+
+ toplevel = not self.stack
+ entry = self._default_stack_entry if toplevel else self.stack[-1]
+
+ new_entry = {
+ "correlate_froms": set(),
+ "asfrom_froms": set(),
+ "selectable": taf,
+ }
+ self.stack.append(new_entry)
+
+ if taf._independent_ctes:
+ for cte in taf._independent_ctes:
+ cte._compiler_dispatch(self, **kw)
+
+ populate_result_map = (
+ toplevel
+ or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
+ or entry.get("need_result_map_for_nested", False)
+ )
+
+ if populate_result_map:
+ self._ordered_columns = (
+ self._textual_ordered_columns
+ ) = taf.positional
+
+ # enable looser result column matching when the SQL text links to
+ # Column objects by name only
+ self._loose_column_name_matching = not taf.positional and bool(
+ taf.column_args
+ )
+
+ for c in taf.column_args:
+ self.process(
+ c,
+ within_columns_clause=True,
+ add_to_result_map=self._add_to_result_map,
+ )
+
+ text = self.process(taf.element, **kw)
+ if self.ctes:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
+
+ self.stack.pop(-1)
+
+ return text
+
+ def visit_null(self, expr, **kw):
+ return "NULL"
+
+ def visit_true(self, expr, **kw):
+ if self.dialect.supports_native_boolean:
+ return "true"
+ else:
+ return "1"
+
+ def visit_false(self, expr, **kw):
+ if self.dialect.supports_native_boolean:
+ return "false"
+ else:
+ return "0"
+
+ def _generate_delimited_list(self, elements, separator, **kw):
+ return separator.join(
+ s
+ for s in (c._compiler_dispatch(self, **kw) for c in elements)
+ if s
+ )
+
+ def _generate_delimited_and_list(self, clauses, **kw):
+
+ lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean(
+ operators.and_,
+ elements.True_._singleton,
+ elements.False_._singleton,
+ clauses,
+ )
+ if lcc == 1:
+ return clauses[0]._compiler_dispatch(self, **kw)
+ else:
+ separator = OPERATORS[operators.and_]
+ return separator.join(
+ s
+ for s in (c._compiler_dispatch(self, **kw) for c in clauses)
+ if s
+ )
+
+ def visit_tuple(self, clauselist, **kw):
+ return "(%s)" % self.visit_clauselist(clauselist, **kw)
+
+ def visit_clauselist(self, clauselist, **kw):
+ sep = clauselist.operator
+ if sep is None:
+ sep = " "
+ else:
+ sep = OPERATORS[clauselist.operator]
+
+ return self._generate_delimited_list(clauselist.clauses, sep, **kw)
+
+ def visit_case(self, clause, **kwargs):
+ x = "CASE "
+ if clause.value is not None:
+ x += clause.value._compiler_dispatch(self, **kwargs) + " "
+ for cond, result in clause.whens:
+ x += (
+ "WHEN "
+ + cond._compiler_dispatch(self, **kwargs)
+ + " THEN "
+ + result._compiler_dispatch(self, **kwargs)
+ + " "
+ )
+ if clause.else_ is not None:
+ x += (
+ "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " "
+ )
+ x += "END"
+ return x
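+ # editor's illustrative sketch (approximate rendering): on the default
+ # compiler, case((t.c.x == 5, "five"), else_="other") comes out roughly
+ # as "CASE WHEN (t.x = :x_1) THEN :param_1 ELSE :param_2 END".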
+
+ def visit_type_coerce(self, type_coerce, **kw):
+ return type_coerce.typed_expression._compiler_dispatch(self, **kw)
+
+ def visit_cast(self, cast, **kwargs):
+ return "CAST(%s AS %s)" % (
+ cast.clause._compiler_dispatch(self, **kwargs),
+ cast.typeclause._compiler_dispatch(self, **kwargs),
+ )
+
+ def _format_frame_clause(self, range_, **kw):
+
+ return "%s AND %s" % (
+ "UNBOUNDED PRECEDING"
+ if range_[0] is elements.RANGE_UNBOUNDED
+ else "CURRENT ROW"
+ if range_[0] is elements.RANGE_CURRENT
+ else "%s PRECEDING"
+ % (self.process(elements.literal(abs(range_[0])), **kw),)
+ if range_[0] < 0
+ else "%s FOLLOWING"
+ % (self.process(elements.literal(range_[0]), **kw),),
+ "UNBOUNDED FOLLOWING"
+ if range_[1] is elements.RANGE_UNBOUNDED
+ else "CURRENT ROW"
+ if range_[1] is elements.RANGE_CURRENT
+ else "%s PRECEDING"
+ % (self.process(elements.literal(abs(range_[1])), **kw),)
+ if range_[1] < 0
+ else "%s FOLLOWING"
+ % (self.process(elements.literal(range_[1]), **kw),),
+ )
+
+ def visit_over(self, over, **kwargs):
+ if over.range_:
+ range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
+ over.range_, **kwargs
+ )
+ elif over.rows:
+ range_ = "ROWS BETWEEN %s" % self._format_frame_clause(
+ over.rows, **kwargs
+ )
+ else:
+ range_ = None
+
+ return "%s OVER (%s)" % (
+ over.element._compiler_dispatch(self, **kwargs),
+ " ".join(
+ [
+ "%s BY %s"
+ % (word, clause._compiler_dispatch(self, **kwargs))
+ for word, clause in (
+ ("PARTITION", over.partition_by),
+ ("ORDER", over.order_by),
+ )
+ if clause is not None and len(clause)
+ ]
+ + ([range_] if range_ else [])
+ ),
+ )
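+ # editor's illustrative sketch (approximate rendering): for example,
+ # func.row_number().over(order_by=t.c.x, rows=(None, 0)) renders as
+ # "row_number() OVER (ORDER BY t.x ROWS BETWEEN UNBOUNDED PRECEDING
+ # AND CURRENT ROW)", since None maps to RANGE_UNBOUNDED and 0 to
+ # RANGE_CURRENT in the frame tuple.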
+
+ def visit_withingroup(self, withingroup, **kwargs):
+ return "%s WITHIN GROUP (ORDER BY %s)" % (
+ withingroup.element._compiler_dispatch(self, **kwargs),
+ withingroup.order_by._compiler_dispatch(self, **kwargs),
+ )
+
+ def visit_funcfilter(self, funcfilter, **kwargs):
+ return "%s FILTER (WHERE %s)" % (
+ funcfilter.func._compiler_dispatch(self, **kwargs),
+ funcfilter.criterion._compiler_dispatch(self, **kwargs),
+ )
+
+ def visit_extract(self, extract, **kwargs):
+ field = self.extract_map.get(extract.field, extract.field)
+ return "EXTRACT(%s FROM %s)" % (
+ field,
+ extract.expr._compiler_dispatch(self, **kwargs),
+ )
+
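+ # editor's illustrative sketch: extract("year", t.c.created) renders as
+ # "EXTRACT(year FROM t.created)" on the default compiler; dialects can
+ # remap field names via extract_map.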
+ def visit_scalar_function_column(self, element, **kw):
+ compiled_fn = self.visit_function(element.fn, **kw)
+ compiled_col = self.visit_column(element, **kw)
+ return "(%s).%s" % (compiled_fn, compiled_col)
+
+ def visit_function(self, func, add_to_result_map=None, **kwargs):
+ if add_to_result_map is not None:
+ add_to_result_map(func.name, func.name, (), func.type)
+
+ disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
+ if disp:
+ text = disp(func, **kwargs)
+ else:
+ name = FUNCTIONS.get(func._deannotate().__class__, None)
+ if name:
+ if func._has_args:
+ name += "%(expr)s"
+ else:
+ name = func.name
+ name = (
+ self.preparer.quote(name)
+ if self.preparer._requires_quotes_illegal_chars(name)
+ or isinstance(name, elements.quoted_name)
+ else name
+ )
+ name = name + "%(expr)s"
+ text = ".".join(
+ [
+ (
+ self.preparer.quote(tok)
+ if self.preparer._requires_quotes_illegal_chars(tok)
+ or isinstance(name, elements.quoted_name)
+ else tok
+ )
+ for tok in func.packagenames
+ ]
+ + [name]
+ ) % {"expr": self.function_argspec(func, **kwargs)}
+
+ if func._with_ordinality:
+ text += " WITH ORDINALITY"
+ return text
+
+ def visit_next_value_func(self, next_value, **kw):
+ return self.visit_sequence(next_value.sequence)
+
+ def visit_sequence(self, sequence, **kw):
+ raise NotImplementedError(
+ "Dialect '%s' does not support sequence increments."
+ % self.dialect.name
+ )
+
+ def function_argspec(self, func, **kwargs):
+ return func.clause_expr._compiler_dispatch(self, **kwargs)
+
+ def visit_compound_select(
+ self, cs, asfrom=False, compound_index=None, **kwargs
+ ):
+ toplevel = not self.stack
+
+ compile_state = cs._compile_state_factory(cs, self, **kwargs)
+
+ if toplevel and not self.compile_state:
+ self.compile_state = compile_state
+
+ compound_stmt = compile_state.statement
+
+ entry = self._default_stack_entry if toplevel else self.stack[-1]
+ need_result_map = toplevel or (
+ not compound_index
+ and entry.get("need_result_map_for_compound", False)
+ )
+
+ # indicates there is already a CompoundSelect in play
+ if compound_index == 0:
+ entry["select_0"] = cs
+
+ self.stack.append(
+ {
+ "correlate_froms": entry["correlate_froms"],
+ "asfrom_froms": entry["asfrom_froms"],
+ "selectable": cs,
+ "compile_state": compile_state,
+ "need_result_map_for_compound": need_result_map,
+ }
+ )
+
+ if compound_stmt._independent_ctes:
+ for cte in compound_stmt._independent_ctes:
+ cte._compiler_dispatch(self, **kwargs)
+
+ keyword = self.compound_keywords.get(cs.keyword)
+
+ text = (" " + keyword + " ").join(
+ (
+ c._compiler_dispatch(
+ self, asfrom=asfrom, compound_index=i, **kwargs
+ )
+ for i, c in enumerate(cs.selects)
+ )
+ )
+
+ kwargs["include_table"] = False
+ text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs))
+ text += self.order_by_clause(cs, **kwargs)
+ if cs._has_row_limiting_clause:
+ text += self._row_limit_clause(cs, **kwargs)
+
+ if self.ctes:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level, include_following_stack=True
+ )
+ + text
+ )
+
+ self.stack.pop(-1)
+ return text
+
+ def _row_limit_clause(self, cs, **kwargs):
+ if cs._fetch_clause is not None:
+ return self.fetch_clause(cs, **kwargs)
+ else:
+ return self.limit_clause(cs, **kwargs)
+
+ def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
+ attrname = "visit_%s_%s%s" % (
+ operator_.__name__,
+ qualifier1,
+ "_" + qualifier2 if qualifier2 else "",
+ )
+ return getattr(self, attrname, None)
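+ # editor's illustrative note: for operators.in_op with qualifier1
+ # "binary" and no qualifier2, the attribute looked up is
+ # "visit_in_op_binary"; methods such as visit_is_true_unary_operator
+ # below follow the same naming pattern.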
+
+ def visit_unary(
+ self, unary, add_to_result_map=None, result_map_targets=(), **kw
+ ):
+
+ if add_to_result_map is not None:
+ result_map_targets += (unary,)
+ kw["add_to_result_map"] = add_to_result_map
+ kw["result_map_targets"] = result_map_targets
+
+ if unary.operator:
+ if unary.modifier:
+ raise exc.CompileError(
+ "Unary expression does not support operator "
+ "and modifier simultaneously"
+ )
+ disp = self._get_operator_dispatch(
+ unary.operator, "unary", "operator"
+ )
+ if disp:
+ return disp(unary, unary.operator, **kw)
+ else:
+ return self._generate_generic_unary_operator(
+ unary, OPERATORS[unary.operator], **kw
+ )
+ elif unary.modifier:
+ disp = self._get_operator_dispatch(
+ unary.modifier, "unary", "modifier"
+ )
+ if disp:
+ return disp(unary, unary.modifier, **kw)
+ else:
+ return self._generate_generic_unary_modifier(
+ unary, OPERATORS[unary.modifier], **kw
+ )
+ else:
+ raise exc.CompileError(
+ "Unary expression has no operator or modifier"
+ )
+
+ def visit_is_true_unary_operator(self, element, operator, **kw):
+ if (
+ element._is_implicitly_boolean
+ or self.dialect.supports_native_boolean
+ ):
+ return self.process(element.element, **kw)
+ else:
+ return "%s = 1" % self.process(element.element, **kw)
+
+ def visit_is_false_unary_operator(self, element, operator, **kw):
+ if (
+ element._is_implicitly_boolean
+ or self.dialect.supports_native_boolean
+ ):
+ return "NOT %s" % self.process(element.element, **kw)
+ else:
+ return "%s = 0" % self.process(element.element, **kw)
+
+ def visit_not_match_op_binary(self, binary, operator, **kw):
+ return "NOT %s" % self.visit_binary(
+ binary, override_operator=operators.match_op
+ )
+
+ def visit_not_in_op_binary(self, binary, operator, **kw):
+ # The brackets are required in the NOT IN operation because the empty
+ # case is handled using the form "(col NOT IN (null) OR 1 = 1)".
+ # The presence of the OR makes the brackets required.
+ return "(%s)" % self._generate_generic_binary(
+ binary, OPERATORS[operator], **kw
+ )
+
+ def visit_empty_set_op_expr(self, type_, expand_op):
+ if expand_op is operators.not_in_op:
+ if len(type_) > 1:
+ return "(%s)) OR (1 = 1" % (
+ ", ".join("NULL" for element in type_)
+ )
+ else:
+ return "NULL) OR (1 = 1"
+ elif expand_op is operators.in_op:
+ if len(type_) > 1:
+ return "(%s)) AND (1 != 1" % (
+ ", ".join("NULL" for element in type_)
+ )
+ else:
+ return "NULL) AND (1 != 1"
+ else:
+ return self.visit_empty_set_expr(type_)
+
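+ # editor's illustrative sketch (approximate rendering): for an empty
+ # list, column("q").in_([]) expands at execution time to roughly
+ # "q IN (NULL) AND (1 != 1)", because the replacement text above is
+ # spliced into the parenthesis already rendered around the expanding
+ # parameter.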
+ def visit_empty_set_expr(self, element_types):
+ raise NotImplementedError(
+ "Dialect '%s' does not support empty set expression."
+ % self.dialect.name
+ )
+
+ def _literal_execute_expanding_parameter_literal_binds(
+ self, parameter, values
+ ):
+
+ typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+
+ if not values:
+ if typ_dialect_impl._is_tuple_type:
+ replacement_expression = (
+ "VALUES " if self.dialect.tuple_in_values else ""
+ ) + self.visit_empty_set_op_expr(
+ parameter.type.types, parameter.expand_op
+ )
+
+ else:
+ replacement_expression = self.visit_empty_set_op_expr(
+ [parameter.type], parameter.expand_op
+ )
+
+ elif typ_dialect_impl._is_tuple_type or (
+ typ_dialect_impl._isnull
+ and isinstance(values[0], util.collections_abc.Sequence)
+ and not isinstance(
+ values[0], util.string_types + util.binary_types
+ )
+ ):
+
+ replacement_expression = (
+ "VALUES " if self.dialect.tuple_in_values else ""
+ ) + ", ".join(
+ "(%s)"
+ % (
+ ", ".join(
+ self.render_literal_value(value, param_type)
+ for value, param_type in zip(
+ tuple_element, parameter.type.types
+ )
+ )
+ )
+ for i, tuple_element in enumerate(values)
+ )
+ else:
+ replacement_expression = ", ".join(
+ self.render_literal_value(value, parameter.type)
+ for value in values
+ )
+
+ return (), replacement_expression
+
+ def _literal_execute_expanding_parameter(self, name, parameter, values):
+
+ if parameter.literal_execute:
+ return self._literal_execute_expanding_parameter_literal_binds(
+ parameter, values
+ )
+
+ typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+
+ if not values:
+ to_update = []
+ if typ_dialect_impl._is_tuple_type:
+
+ replacement_expression = self.visit_empty_set_op_expr(
+ parameter.type.types, parameter.expand_op
+ )
+ else:
+ replacement_expression = self.visit_empty_set_op_expr(
+ [parameter.type], parameter.expand_op
+ )
+
+ elif typ_dialect_impl._is_tuple_type or (
+ typ_dialect_impl._isnull
+ and isinstance(values[0], util.collections_abc.Sequence)
+ and not isinstance(
+ values[0], util.string_types + util.binary_types
+ )
+ ):
+ assert not typ_dialect_impl._is_array
+ to_update = [
+ ("%s_%s_%s" % (name, i, j), value)
+ for i, tuple_element in enumerate(values, 1)
+ for j, value in enumerate(tuple_element, 1)
+ ]
+ replacement_expression = (
+ "VALUES " if self.dialect.tuple_in_values else ""
+ ) + ", ".join(
+ "(%s)"
+ % (
+ ", ".join(
+ self.bindtemplate
+ % {"name": to_update[i * len(tuple_element) + j][0]}
+ for j, value in enumerate(tuple_element)
+ )
+ )
+ for i, tuple_element in enumerate(values)
+ )
+ else:
+ to_update = [
+ ("%s_%s" % (name, i), value)
+ for i, value in enumerate(values, 1)
+ ]
+ replacement_expression = ", ".join(
+ self.bindtemplate % {"name": key} for key, value in to_update
+ )
+
+ return to_update, replacement_expression
+
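+ # editor's illustrative note: for a parameter named "q_1" with values
+ # [10, 20] and a "named" paramstyle, the scalar branch above returns
+ # to_update = [("q_1_1", 10), ("q_1_2", 20)] and the replacement
+ # expression ":q_1_1, :q_1_2"; tuple values additionally gain a
+ # per-element suffix, e.g. "q_1_1_1", "q_1_1_2".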
+ def visit_binary(
+ self,
+ binary,
+ override_operator=None,
+ eager_grouping=False,
+ from_linter=None,
+ lateral_from_linter=None,
+ **kw
+ ):
+ if from_linter and operators.is_comparison(binary.operator):
+ if lateral_from_linter is not None:
+ enclosing_lateral = kw["enclosing_lateral"]
+ lateral_from_linter.edges.update(
+ itertools.product(
+ binary.left._from_objects + [enclosing_lateral],
+ binary.right._from_objects + [enclosing_lateral],
+ )
+ )
+ else:
+ from_linter.edges.update(
+ itertools.product(
+ binary.left._from_objects, binary.right._from_objects
+ )
+ )
+
+ # don't allow "? = ?" to render
+ if (
+ self.ansi_bind_rules
+ and isinstance(binary.left, elements.BindParameter)
+ and isinstance(binary.right, elements.BindParameter)
+ ):
+ kw["literal_execute"] = True
+
+ operator_ = override_operator or binary.operator
+ disp = self._get_operator_dispatch(operator_, "binary", None)
+ if disp:
+ return disp(binary, operator_, **kw)
+ else:
+ try:
+ opstring = OPERATORS[operator_]
+ except KeyError as err:
+ util.raise_(
+ exc.UnsupportedCompilationError(self, operator_),
+ replace_context=err,
+ )
+ else:
+ return self._generate_generic_binary(
+ binary,
+ opstring,
+ from_linter=from_linter,
+ lateral_from_linter=lateral_from_linter,
+ **kw
+ )
+
+ def visit_function_as_comparison_op_binary(self, element, operator, **kw):
+ return self.process(element.sql_function, **kw)
+
+ def visit_mod_binary(self, binary, operator, **kw):
+ if self.preparer._double_percents:
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
+ else:
+ return (
+ self.process(binary.left, **kw)
+ + " % "
+ + self.process(binary.right, **kw)
+ )
+
+ def visit_custom_op_binary(self, element, operator, **kw):
+ kw["eager_grouping"] = operator.eager_grouping
+ return self._generate_generic_binary(
+ element,
+ " " + self.escape_literal_column(operator.opstring) + " ",
+ **kw
+ )
+
+ def visit_custom_op_unary_operator(self, element, operator, **kw):
+ return self._generate_generic_unary_operator(
+ element, self.escape_literal_column(operator.opstring) + " ", **kw
+ )
+
+ def visit_custom_op_unary_modifier(self, element, operator, **kw):
+ return self._generate_generic_unary_modifier(
+ element, " " + self.escape_literal_column(operator.opstring), **kw
+ )
+
+ def _generate_generic_binary(
+ self, binary, opstring, eager_grouping=False, **kw
+ ):
+
+ _in_binary = kw.get("_in_binary", False)
+
+ kw["_in_binary"] = True
+ kw["_binary_op"] = binary.operator
+ text = (
+ binary.left._compiler_dispatch(
+ self, eager_grouping=eager_grouping, **kw
+ )
+ + opstring
+ + binary.right._compiler_dispatch(
+ self, eager_grouping=eager_grouping, **kw
+ )
+ )
+
+ if _in_binary and eager_grouping:
+ text = "(%s)" % text
+ return text
+
+ def _generate_generic_unary_operator(self, unary, opstring, **kw):
+ return opstring + unary.element._compiler_dispatch(self, **kw)
+
+ def _generate_generic_unary_modifier(self, unary, opstring, **kw):
+ return unary.element._compiler_dispatch(self, **kw) + opstring
+
+ @util.memoized_property
+ def _like_percent_literal(self):
+ return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE)
+
+ def visit_contains_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.right = percent.concat(binary.right).concat(percent)
+ return self.visit_like_op_binary(binary, operator, **kw)
+
+ def visit_not_contains_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.right = percent.concat(binary.right).concat(percent)
+ return self.visit_not_like_op_binary(binary, operator, **kw)
+
+ def visit_startswith_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.right = percent._rconcat(binary.right)
+ return self.visit_like_op_binary(binary, operator, **kw)
+
+ def visit_not_startswith_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.right = percent._rconcat(binary.right)
+ return self.visit_not_like_op_binary(binary, operator, **kw)
+
+ def visit_endswith_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.right = percent.concat(binary.right)
+ return self.visit_like_op_binary(binary, operator, **kw)
+
+ def visit_not_endswith_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.right = percent.concat(binary.right)
+ return self.visit_not_like_op_binary(binary, operator, **kw)
+
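+ # editor's illustrative sketch (approximate, default concat operator):
+ # col.contains("ab") renders as "col LIKE '%' || :col_1 || '%'",
+ # col.startswith("ab") as "col LIKE :col_1 || '%'", and
+ # col.endswith("ab") as "col LIKE '%' || :col_1".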
+ def visit_like_op_binary(self, binary, operator, **kw):
+ escape = binary.modifiers.get("escape", None)
+
+ # TODO: use a ternary expression here, not "and" / "or"
+ return "%s LIKE %s" % (
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
+
+ def visit_not_like_op_binary(self, binary, operator, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return "%s NOT LIKE %s" % (
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
+
+ def visit_ilike_op_binary(self, binary, operator, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return "lower(%s) LIKE lower(%s)" % (
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
+
+ def visit_not_ilike_op_binary(self, binary, operator, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return "lower(%s) NOT LIKE lower(%s)" % (
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
+
+ def visit_between_op_binary(self, binary, operator, **kw):
+ symmetric = binary.modifiers.get("symmetric", False)
+ return self._generate_generic_binary(
+ binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw
+ )
+
+ def visit_not_between_op_binary(self, binary, operator, **kw):
+ symmetric = binary.modifiers.get("symmetric", False)
+ return self._generate_generic_binary(
+ binary,
+ " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ",
+ **kw
+ )
+
+ def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ raise exc.CompileError(
+ "%s dialect does not support regular expressions"
+ % self.dialect.name
+ )
+
+ def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ raise exc.CompileError(
+ "%s dialect does not support regular expressions"
+ % self.dialect.name
+ )
+
+ def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+ raise exc.CompileError(
+ "%s dialect does not support regular expression replacements"
+ % self.dialect.name
+ )
+
+ def visit_bindparam(
+ self,
+ bindparam,
+ within_columns_clause=False,
+ literal_binds=False,
+ skip_bind_expression=False,
+ literal_execute=False,
+ render_postcompile=False,
+ **kwargs
+ ):
+ if not skip_bind_expression:
+ impl = bindparam.type.dialect_impl(self.dialect)
+ if impl._has_bind_expression:
+ bind_expression = impl.bind_expression(bindparam)
+ wrapped = self.process(
+ bind_expression,
+ skip_bind_expression=True,
+ within_columns_clause=within_columns_clause,
+ literal_binds=literal_binds,
+ literal_execute=literal_execute,
+ render_postcompile=render_postcompile,
+ **kwargs
+ )
+ if bindparam.expanding:
+ # for postcompile w/ expanding, move the "wrapped" part
+ # of this expression inside the placeholder
+ m = re.match(
+ r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
+ )
+ wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
+ m.group(2),
+ m.group(1),
+ m.group(3),
+ )
+ return wrapped
+
+ if not literal_binds:
+ literal_execute = (
+ literal_execute
+ or bindparam.literal_execute
+ or (within_columns_clause and self.ansi_bind_rules)
+ )
+ post_compile = literal_execute or bindparam.expanding
+ else:
+ post_compile = False
+
+ if literal_binds:
+ ret = self.render_literal_bindparam(
+ bindparam, within_columns_clause=True, **kwargs
+ )
+ if bindparam.expanding:
+ ret = "(%s)" % ret
+ return ret
+
+ name = self._truncate_bindparam(bindparam)
+
+ if name in self.binds:
+ existing = self.binds[name]
+ if existing is not bindparam:
+ if (
+ (existing.unique or bindparam.unique)
+ and not existing.proxy_set.intersection(
+ bindparam.proxy_set
+ )
+ and not existing._cloned_set.intersection(
+ bindparam._cloned_set
+ )
+ ):
+ raise exc.CompileError(
+ "Bind parameter '%s' conflicts with "
+ "unique bind parameter of the same name" % name
+ )
+ elif existing.expanding != bindparam.expanding:
+ raise exc.CompileError(
+ "Can't reuse bound parameter name '%s' in both "
+ "'expanding' (e.g. within an IN expression) and "
+ "non-expanding contexts. If this parameter is to "
+ "receive a list/array value, set 'expanding=True' on "
+ "it for expressions that aren't IN, otherwise use "
+ "a different parameter name." % (name,)
+ )
+ elif existing._is_crud or bindparam._is_crud:
+ raise exc.CompileError(
+ "bindparam() name '%s' is reserved "
+ "for automatic usage in the VALUES or SET "
+ "clause of this "
+ "insert/update statement. Please use a "
+ "name other than column name when using bindparam() "
+ "with insert() or update() (for example, 'b_%s')."
+ % (bindparam.key, bindparam.key)
+ )
+
+ self.binds[bindparam.key] = self.binds[name] = bindparam
+
+ # if we are given a cache key that we're going to match against,
+ # relate the bindparam here to one that is most likely present
+ # in the "extracted params" portion of the cache key. this is used
+ # to set up a positional mapping that is used to determine the
+ # correct parameters for a subsequent use of this compiled with
+ # a different set of parameter values. here, we accommodate
+ # parameters that may have been cloned both before and after the
+ # cache key was generated.
+ ckbm = self._cache_key_bind_match
+ if ckbm:
+ for bp in bindparam._cloned_set:
+ if bp.key in ckbm:
+ cb = ckbm[bp.key]
+ ckbm[cb].append(bindparam)
+
+ if bindparam.isoutparam:
+ self.has_out_parameters = True
+
+ if post_compile:
+ if render_postcompile:
+ self._render_postcompile = True
+
+ if literal_execute:
+ self.literal_execute_params |= {bindparam}
+ else:
+ self.post_compile_params |= {bindparam}
+
+ ret = self.bindparam_string(
+ name,
+ post_compile=post_compile,
+ expanding=bindparam.expanding,
+ **kwargs
+ )
+
+ if bindparam.expanding:
+ ret = "(%s)" % ret
+ return ret
+
+ def render_literal_bindparam(
+ self, bindparam, render_literal_value=NO_ARG, **kw
+ ):
+ if render_literal_value is not NO_ARG:
+ value = render_literal_value
+ else:
+ if bindparam.value is None and bindparam.callable is None:
+ op = kw.get("_binary_op", None)
+ if op and op not in (operators.is_, operators.is_not):
+ util.warn_limited(
+ "Bound parameter '%s' rendering literal NULL in a SQL "
+ "expression; comparisons to NULL should not use "
+ "operators outside of 'is' or 'is not'",
+ (bindparam.key,),
+ )
+ return self.process(sqltypes.NULLTYPE, **kw)
+ value = bindparam.effective_value
+
+ if bindparam.expanding:
+ leep = self._literal_execute_expanding_parameter_literal_binds
+ to_update, replacement_expr = leep(bindparam, value)
+ return replacement_expr
+ else:
+ return self.render_literal_value(value, bindparam.type)
+
+ def render_literal_value(self, value, type_):
+ """Render the value of a bind parameter as a quoted literal.
+
+ This is used for statement sections that do not accept bind parameters
+ on the target driver/database.
+
+ This should be implemented by subclasses using the quoting services
+ of the DBAPI.
+
+ """
+
+ processor = type_._cached_literal_processor(self.dialect)
+ if processor:
+ return processor(value)
+ else:
+ raise NotImplementedError(
+ "Don't know how to literal-quote value %r" % value
+ )
+
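+ # editor's illustrative note: this path is exercised by compiling with
+ # literal binds, e.g.
+ #     print(stmt.compile(compile_kwargs={"literal_binds": True}))
+ # which renders bound values inline using each type's literal processor.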
+ def _truncate_bindparam(self, bindparam):
+ if bindparam in self.bind_names:
+ return self.bind_names[bindparam]
+
+ bind_name = bindparam.key
+ if isinstance(bind_name, elements._truncated_label):
+ bind_name = self._truncated_identifier("bindparam", bind_name)
+
+ # add to bind_names for translation
+ self.bind_names[bindparam] = bind_name
+
+ return bind_name
+
+ def _truncated_identifier(self, ident_class, name):
+ if (ident_class, name) in self.truncated_names:
+ return self.truncated_names[(ident_class, name)]
+
+ anonname = name.apply_map(self.anon_map)
+
+ if len(anonname) > self.label_length - 6:
+ counter = self.truncated_names.get(ident_class, 1)
+ truncname = (
+ anonname[0 : max(self.label_length - 6, 0)]
+ + "_"
+ + hex(counter)[2:]
+ )
+ self.truncated_names[ident_class] = counter + 1
+ else:
+ truncname = anonname
+ self.truncated_names[(ident_class, name)] = truncname
+ return truncname
+
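+ # editor's illustrative note: with label_length 30, an anonymized name
+ # longer than 24 characters is cut to its first 24 characters plus a
+ # "_<hex counter>" suffix, keeping generated labels within the
+ # dialect's identifier length limit.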
+ def _anonymize(self, name):
+ return name % self.anon_map
+
+ def bindparam_string(
+ self,
+ name,
+ positional_names=None,
+ post_compile=False,
+ expanding=False,
+ escaped_from=None,
+ **kw
+ ):
+
+ if self.positional:
+ if positional_names is not None:
+ positional_names.append(name)
+ else:
+ self.positiontup.append(name)
+ elif not escaped_from:
+
+ if _BIND_TRANSLATE_RE.search(name):
+ # not quite a plain translate use case, as we also want a
+ # quick boolean indicating whether any unusual characters
+ # were found in the name at all
+ new_name = _BIND_TRANSLATE_RE.sub(
+ lambda m: _BIND_TRANSLATE_CHARS[m.group(0)],
+ name,
+ )
+ escaped_from = name
+ name = new_name
+
+ if escaped_from:
+ if not self.escaped_bind_names:
+ self.escaped_bind_names = {}
+ self.escaped_bind_names[escaped_from] = name
+ if post_compile:
+ return "__[POSTCOMPILE_%s]" % name
+ else:
+ return self.bindtemplate % {"name": name}
+
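+ # editor's illustrative note: for a parameter named "foo", the
+ # bindtemplate renders ":foo" under a "named" paramstyle, "?" under
+ # "qmark", and "%(foo)s" under "pyformat", while post-compile
+ # parameters render the "__[POSTCOMPILE_foo]" placeholder consumed by
+ # the post-compile expansion step above.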
+ def visit_cte(
+ self,
+ cte,
+ asfrom=False,
+ ashint=False,
+ fromhints=None,
+ visiting_cte=None,
+ from_linter=None,
+ **kwargs
+ ):
+ self._init_cte_state()
+
+ kwargs["visiting_cte"] = cte
+
+ cte_name = cte.name
+
+ if isinstance(cte_name, elements._truncated_label):
+ cte_name = self._truncated_identifier("alias", cte_name)
+
+ is_new_cte = True
+ embedded_in_current_named_cte = False
+
+ _reference_cte = cte._get_reference_cte()
+
+ if _reference_cte in self.level_name_by_cte:
+ cte_level, _ = self.level_name_by_cte[_reference_cte]
+ assert _ == cte_name
+ else:
+ cte_level = len(self.stack) if cte.nesting else 1
+
+ cte_level_name = (cte_level, cte_name)
+ if cte_level_name in self.ctes_by_level_name:
+ existing_cte = self.ctes_by_level_name[cte_level_name]
+ embedded_in_current_named_cte = visiting_cte is existing_cte
+
+ # we've generated a same-named CTE that we are enclosed in,
+ # or this is the same CTE. just return the name.
+ if cte is existing_cte._restates or cte is existing_cte:
+ is_new_cte = False
+ elif existing_cte is cte._restates:
+ # we've generated a same-named CTE that is
+ # enclosed in us - we take precedence, so
+ # discard the text for the "inner".
+ del self.ctes[existing_cte]
+
+ existing_cte_reference_cte = existing_cte._get_reference_cte()
+
+ # TODO: determine if these assertions are correct. they
+ # pass for current test cases
+ # assert existing_cte_reference_cte is _reference_cte
+ # assert existing_cte_reference_cte is existing_cte
+
+ del self.level_name_by_cte[existing_cte_reference_cte]
+ else:
+ # if the two CTEs are deep-copy identical, consider them
+ # the same, **if** they are clones, that is, they came from
+ # the ORM or other visit method
+ if (
+ cte._is_clone_of is not None
+ or existing_cte._is_clone_of is not None
+ ) and cte.compare(existing_cte):
+ is_new_cte = False
+ else:
+ raise exc.CompileError(
+ "Multiple, unrelated CTEs found with "
+ "the same name: %r" % cte_name
+ )
+
+ if not asfrom and not is_new_cte:
+ return None
+
+ if cte._cte_alias is not None:
+ pre_alias_cte = cte._cte_alias
+ cte_pre_alias_name = cte._cte_alias.name
+ if isinstance(cte_pre_alias_name, elements._truncated_label):
+ cte_pre_alias_name = self._truncated_identifier(
+ "alias", cte_pre_alias_name
+ )
+ else:
+ pre_alias_cte = cte
+ cte_pre_alias_name = None
+
+ if is_new_cte:
+ self.ctes_by_level_name[cte_level_name] = cte
+ self.level_name_by_cte[_reference_cte] = cte_level_name
+
+ if (
+ "autocommit" in cte.element._execution_options
+ and "autocommit" not in self.execution_options
+ ):
+ self.execution_options = self.execution_options.union(
+ {
+ "autocommit": cte.element._execution_options[
+ "autocommit"
+ ]
+ }
+ )
+
+ if pre_alias_cte not in self.ctes:
+ self.visit_cte(pre_alias_cte, **kwargs)
+
+ if not cte_pre_alias_name and cte not in self.ctes:
+ if cte.recursive:
+ self.ctes_recursive = True
+ text = self.preparer.format_alias(cte, cte_name)
+ if cte.recursive:
+ if isinstance(cte.element, selectable.Select):
+ col_source = cte.element
+ elif isinstance(cte.element, selectable.CompoundSelect):
+ col_source = cte.element.selects[0]
+ else:
+ assert False, "cte should only be against SelectBase"
+
+ # TODO: can we get at the .columns_plus_names collection
+ # that is already (or will be?) generated for the SELECT
+ # rather than calling twice?
+ recur_cols = [
+ # TODO: proxy_name is not technically safe,
+ # see test_cte->
+ # test_with_recursive_no_name_currently_buggy. not
+ # clear what should be done with such a case
+ fallback_label_name or proxy_name
+ for (
+ _,
+ proxy_name,
+ fallback_label_name,
+ c,
+ repeated,
+ ) in (col_source._generate_columns_plus_names(True))
+ if not repeated
+ ]
+
+ text += "(%s)" % (
+ ", ".join(
+ self.preparer.format_label_name(
+ ident, anon_map=self.anon_map
+ )
+ for ident in recur_cols
+ )
+ )
+
+ if self.positional:
+ kwargs["positional_names"] = self.cte_positional[cte] = []
+
+ assert kwargs.get("subquery", False) is False
+
+ if not self.stack:
+ # toplevel, this is a stringify of the
+ # cte directly. just compile the inner
+ # the way alias() does.
+ return cte.element._compiler_dispatch(
+ self, asfrom=asfrom, **kwargs
+ )
+ else:
+ prefixes = self._generate_prefixes(
+ cte, cte._prefixes, **kwargs
+ )
+ inner = cte.element._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ )
+
+ text += " AS %s\n(%s)" % (prefixes, inner)
+
+ if cte._suffixes:
+ text += " " + self._generate_prefixes(
+ cte, cte._suffixes, **kwargs
+ )
+
+ self.ctes[cte] = text
+
+ if asfrom:
+ if from_linter:
+ from_linter.froms[cte] = cte_name
+
+ if not is_new_cte and embedded_in_current_named_cte:
+ return self.preparer.format_alias(cte, cte_name)
+
+ if cte_pre_alias_name:
+ text = self.preparer.format_alias(cte, cte_pre_alias_name)
+ if self.preparer._requires_quotes(cte_name):
+ cte_name = self.preparer.quote(cte_name)
+ text += self.get_render_as_alias_suffix(cte_name)
+ return text
+ else:
+ return self.preparer.format_alias(cte, cte_name)
+
+ def visit_table_valued_alias(self, element, **kw):
+ if element.joins_implicitly:
+ kw["from_linter"] = None
+ if element._is_lateral:
+ return self.visit_lateral(element, **kw)
+ else:
+ return self.visit_alias(element, **kw)
+
+ def visit_table_valued_column(self, element, **kw):
+ return self.visit_column(element, **kw)
+
+ def visit_alias(
+ self,
+ alias,
+ asfrom=False,
+ ashint=False,
+ iscrud=False,
+ fromhints=None,
+ subquery=False,
+ lateral=False,
+ enclosing_alias=None,
+ from_linter=None,
+ **kwargs
+ ):
+
+ if lateral:
+ if "enclosing_lateral" not in kwargs:
+ # if lateral is set and enclosing_lateral is not
+ # present, we assume we are being called directly
+ # from visit_lateral() and we need to set enclosing_lateral.
+ assert alias._is_lateral
+ kwargs["enclosing_lateral"] = alias
+
+ # for lateral objects, we track a second from_linter that is...
+ # lateral! to the level above us.
+ if (
+ from_linter
+ and "lateral_from_linter" not in kwargs
+ and "enclosing_lateral" in kwargs
+ ):
+ kwargs["lateral_from_linter"] = from_linter
+
+ if enclosing_alias is not None and enclosing_alias.element is alias:
+ inner = alias.element._compiler_dispatch(
+ self,
+ asfrom=asfrom,
+ ashint=ashint,
+ iscrud=iscrud,
+ fromhints=fromhints,
+ lateral=lateral,
+ enclosing_alias=alias,
+ **kwargs
+ )
+ if subquery and (asfrom or lateral):
+ inner = "(%s)" % (inner,)
+ return inner
+ else:
+ enclosing_alias = kwargs["enclosing_alias"] = alias
+
+ if asfrom or ashint:
+ if isinstance(alias.name, elements._truncated_label):
+ alias_name = self._truncated_identifier("alias", alias.name)
+ else:
+ alias_name = alias.name
+
+ if ashint:
+ return self.preparer.format_alias(alias, alias_name)
+ elif asfrom:
+ if from_linter:
+ from_linter.froms[alias] = alias_name
+
+ inner = alias.element._compiler_dispatch(
+ self, asfrom=True, lateral=lateral, **kwargs
+ )
+ if subquery:
+ inner = "(%s)" % (inner,)
+
+ ret = inner + self.get_render_as_alias_suffix(
+ self.preparer.format_alias(alias, alias_name)
+ )
+
+ if alias._supports_derived_columns and alias._render_derived:
+ ret += "(%s)" % (
+ ", ".join(
+ "%s%s"
+ % (
+ self.preparer.quote(col.name),
+ " %s"
+ % self.dialect.type_compiler.process(
+ col.type, **kwargs
+ )
+ if alias._render_derived_w_types
+ else "",
+ )
+ for col in alias.c
+ )
+ )
+
+ if fromhints and alias in fromhints:
+ ret = self.format_from_hint_text(
+ ret, alias, fromhints[alias], iscrud
+ )
+
+ return ret
+ else:
+ # note we cancel the "subquery" flag here as well
+ return alias.element._compiler_dispatch(
+ self, lateral=lateral, **kwargs
+ )
+
+ def visit_subquery(self, subquery, **kw):
+ kw["subquery"] = True
+ return self.visit_alias(subquery, **kw)
+
+ def visit_lateral(self, lateral_, **kw):
+ kw["lateral"] = True
+ return "LATERAL %s" % self.visit_alias(lateral_, **kw)
+
+ def visit_tablesample(self, tablesample, asfrom=False, **kw):
+ text = "%s TABLESAMPLE %s" % (
+ self.visit_alias(tablesample, asfrom=True, **kw),
+ tablesample._get_method()._compiler_dispatch(self, **kw),
+ )
+
+ if tablesample.seed is not None:
+ text += " REPEATABLE (%s)" % (
+ tablesample.seed._compiler_dispatch(self, **kw)
+ )
+
+ return text
+
+ def visit_values(self, element, asfrom=False, from_linter=None, **kw):
+ kw.setdefault("literal_binds", element.literal_binds)
+ v = "VALUES %s" % ", ".join(
+ self.process(
+ elements.Tuple(
+ types=element._column_types, *elem
+ ).self_group(),
+ **kw
+ )
+ for chunk in element._data
+ for elem in chunk
+ )
+
+ if isinstance(element.name, elements._truncated_label):
+ name = self._truncated_identifier("values", element.name)
+ else:
+ name = element.name
+
+ if element._is_lateral:
+ lateral = "LATERAL "
+ else:
+ lateral = ""
+
+ if asfrom:
+ if from_linter:
+ from_linter.froms[element] = (
+ name if name is not None else "(unnamed VALUES element)"
+ )
+
+ if name:
+ v = "%s(%s)%s (%s)" % (
+ lateral,
+ v,
+ self.get_render_as_alias_suffix(self.preparer.quote(name)),
+ (
+ ", ".join(
+ c._compiler_dispatch(
+ self, include_table=False, **kw
+ )
+ for c in element.columns
+ )
+ ),
+ )
+ else:
+ v = "%s(%s)" % (lateral, v)
+ return v
+
+ def get_render_as_alias_suffix(self, alias_name_text):
+ return " AS " + alias_name_text
+
+ def _add_to_result_map(self, keyname, name, objects, type_):
+ if keyname is None or keyname == "*":
+ self._ordered_columns = False
+ self._textual_ordered_columns = True
+ if type_._is_tuple_type:
+ raise exc.CompileError(
+ "Most backends don't support SELECTing "
+ "from a tuple() object. If this is an ORM query, "
+ "consider using the Bundle object."
+ )
+ self._result_columns.append((keyname, name, objects, type_))
+
+ def _label_returning_column(self, stmt, column, column_clause_args=None):
+ """Render a column with necessary labels inside of a RETURNING clause.
+
+ This method is provided for individual dialects in place of calling
+ the _label_select_column method directly, so that the two use cases
+ of RETURNING vs. SELECT can be disambiguated going forward.
+
+ .. versionadded:: 1.4.21
+
+ """
+ return self._label_select_column(
+ None,
+ column,
+ True,
+ False,
+ {} if column_clause_args is None else column_clause_args,
+ )
+
+ def _label_select_column(
+ self,
+ select,
+ column,
+ populate_result_map,
+ asfrom,
+ column_clause_args,
+ name=None,
+ proxy_name=None,
+ fallback_label_name=None,
+ within_columns_clause=True,
+ column_is_repeated=False,
+ need_column_expressions=False,
+ ):
+ """produce labeled columns present in a select()."""
+ impl = column.type.dialect_impl(self.dialect)
+
+ if impl._has_column_expression and (
+ need_column_expressions or populate_result_map
+ ):
+ col_expr = impl.column_expression(column)
+ else:
+ col_expr = column
+
+ if populate_result_map:
+ # pass an "add_to_result_map" callable into the compilation
+ # of embedded columns. this collects information about the
+ # column as it will be fetched in the result and is coordinated
+ # with cursor.description when the query is executed.
+ add_to_result_map = self._add_to_result_map
+
+ # if the SELECT statement told us this column is a repeat,
+ # wrap the callable with one that prevents the addition of the
+ # targets
+ if column_is_repeated:
+ _add_to_result_map = add_to_result_map
+
+ def add_to_result_map(keyname, name, objects, type_):
+ _add_to_result_map(keyname, name, (), type_)
+
+ # if we redefined col_expr for type expressions, wrap the
+ # callable with one that adds the original column to the targets
+ elif col_expr is not column:
+ _add_to_result_map = add_to_result_map
+
+ def add_to_result_map(keyname, name, objects, type_):
+ _add_to_result_map(
+ keyname, name, (column,) + objects, type_
+ )
+
+ else:
+ add_to_result_map = None
+
+ # this method is used by some of the dialects for RETURNING,
+ # which has different inputs. _label_returning_column was added
+ # as the better target for this now however for 1.4 we will keep
+ # _label_select_column directly compatible with this use case.
+ # these assertions right now set up the current expected inputs
+ assert within_columns_clause, (
+ "_label_select_column is only relevant within "
+ "the columns clause of a SELECT or RETURNING"
+ )
+
+ if isinstance(column, elements.Label):
+ if col_expr is not column:
+ result_expr = _CompileLabel(
+ col_expr, column.name, alt_names=(column.element,)
+ )
+ else:
+ result_expr = col_expr
+
+ elif name:
+ # here, _columns_plus_names has determined there's an explicit
+ # label name we need to use. this is the default for
+ # tablenames_plus_columnnames as well as when columns are being
+ # deduplicated on name
+
+ assert (
+ proxy_name is not None
+ ), "proxy_name is required if 'name' is passed"
+
+ result_expr = _CompileLabel(
+ col_expr,
+ name,
+ alt_names=(
+ proxy_name,
+ # this is a hack to allow legacy result column lookups
+ # to work as they did before; this goes away in 2.0.
+ # TODO: this only seems to be tested indirectly
+ # via test/orm/test_deprecations.py; there should be a
+ # dedicated resultset test for this
+ column._tq_label,
+ ),
+ )
+ else:
+ # determine here whether this column should be rendered in
+ # a labelled context or not, as we were given no required label
+ # name from the caller. Here we apply heuristics based on the kind
+ # of SQL expression involved.
+
+ if col_expr is not column:
+ # type-specific expression wrapping the given column,
+ # so we render a label
+ render_with_label = True
+ elif isinstance(column, elements.ColumnClause):
+ # table-bound column, we render its name as a label if we are
+ # inside of a subquery only
+ render_with_label = (
+ asfrom
+ and not column.is_literal
+ and column.table is not None
+ )
+ elif isinstance(column, elements.TextClause):
+ render_with_label = False
+ elif isinstance(column, elements.UnaryExpression):
+ render_with_label = column.wraps_column_expression or asfrom
+ elif (
+ # general class of expressions that don't have a SQL-column
+ # addressable name. includes scalar selects, bind parameters,
+ # SQL functions, others
+ not isinstance(column, elements.NamedColumn)
+ # deeper check that indicates there's no natural "name" to
+ # this element, which accommodates custom SQL constructs
+ # that might have a ".name" attribute (but aren't SQL
+ # functions) yet do not implement this more recently added
+ # base class. in theory the "NamedColumn" check should be
+ # enough, however here we seek to maintain legacy behaviors
+ # as well.
+ and column._non_anon_label is None
+ ):
+ render_with_label = True
+ else:
+ render_with_label = False
+
+ if render_with_label:
+ if not fallback_label_name:
+ # used by the RETURNING case right now. we generate it
+ # here as 3rd party dialects may be calling the
+ # _label_select_column method directly instead of the
+ # just-added _label_returning_column method
+ assert not column_is_repeated
+ fallback_label_name = column._anon_name_label
+
+ fallback_label_name = (
+ elements._truncated_label(fallback_label_name)
+ if not isinstance(
+ fallback_label_name, elements._truncated_label
+ )
+ else fallback_label_name
+ )
+
+ result_expr = _CompileLabel(
+ col_expr, fallback_label_name, alt_names=(proxy_name,)
+ )
+ else:
+ result_expr = col_expr
+
+ column_clause_args.update(
+ within_columns_clause=within_columns_clause,
+ add_to_result_map=add_to_result_map,
+ )
+ return result_expr._compiler_dispatch(self, **column_clause_args)
+
+ def format_from_hint_text(self, sqltext, table, hint, iscrud):
+ hinttext = self.get_from_hint_text(table, hint)
+ if hinttext:
+ sqltext += " " + hinttext
+ return sqltext
+
+ def get_select_hint_text(self, byfroms):
+ return None
+
+ def get_from_hint_text(self, table, text):
+ return None
+
+ def get_crud_hint_text(self, table, text):
+ return None
+
+ def get_statement_hint_text(self, hint_texts):
+ return " ".join(hint_texts)
+
+ _default_stack_entry = util.immutabledict(
+ [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
+ )
+
+ def _display_froms_for_select(
+ self, select_stmt, asfrom, lateral=False, **kw
+ ):
+ # utility method to help external dialects
+ # get the correct from list for a select.
+ # specifically the oracle dialect needs this feature
+ # right now.
+ toplevel = not self.stack
+ entry = self._default_stack_entry if toplevel else self.stack[-1]
+
+ compile_state = select_stmt._compile_state_factory(select_stmt, self)
+
+ correlate_froms = entry["correlate_froms"]
+ asfrom_froms = entry["asfrom_froms"]
+
+ if asfrom and not lateral:
+ froms = compile_state._get_display_froms(
+ explicit_correlate_froms=correlate_froms.difference(
+ asfrom_froms
+ ),
+ implicit_correlate_froms=(),
+ )
+ else:
+ froms = compile_state._get_display_froms(
+ explicit_correlate_froms=correlate_froms,
+ implicit_correlate_froms=asfrom_froms,
+ )
+ return froms
+
+ translate_select_structure = None
+ """if not ``None``, should be a callable which accepts ``(select_stmt,
+ **kw)`` and returns a select object. this is used for structural changes
+ mostly to accommodate for LIMIT/OFFSET schemes
+
+ """
+
+ def visit_select(
+ self,
+ select_stmt,
+ asfrom=False,
+ insert_into=False,
+ fromhints=None,
+ compound_index=None,
+ select_wraps_for=None,
+ lateral=False,
+ from_linter=None,
+ **kwargs
+ ):
+ assert select_wraps_for is None, (
+ "SQLAlchemy 1.4 requires use of "
+ "the translate_select_structure hook for structural "
+ "translations of SELECT objects"
+ )
+
+ # initial setup of SELECT. the compile_state_factory may now
+ # be creating a totally different SELECT from the one that was
+ # passed in. for ORM use this will convert from an ORM-state
+ # SELECT to a regular "Core" SELECT. other composed operations
+ # such as computation of joins will be performed.
+
+ kwargs["within_columns_clause"] = False
+
+ compile_state = select_stmt._compile_state_factory(
+ select_stmt, self, **kwargs
+ )
+ select_stmt = compile_state.statement
+
+ toplevel = not self.stack
+
+ if toplevel and not self.compile_state:
+ self.compile_state = compile_state
+
+ is_embedded_select = compound_index is not None or insert_into
+
+ # translate step for Oracle, SQL Server which often need to
+ # restructure the SELECT to allow for LIMIT/OFFSET and possibly
+ # other conditions
+ if self.translate_select_structure:
+ new_select_stmt = self.translate_select_structure(
+ select_stmt, asfrom=asfrom, **kwargs
+ )
+
+ # if SELECT was restructured, maintain a link to the originals
+ # and assemble a new compile state
+ if new_select_stmt is not select_stmt:
+ compile_state_wraps_for = compile_state
+ select_wraps_for = select_stmt
+ select_stmt = new_select_stmt
+
+ compile_state = select_stmt._compile_state_factory(
+ select_stmt, self, **kwargs
+ )
+ select_stmt = compile_state.statement
+
+ entry = self._default_stack_entry if toplevel else self.stack[-1]
+
+ populate_result_map = need_column_expressions = (
+ toplevel
+ or entry.get("need_result_map_for_compound", False)
+ or entry.get("need_result_map_for_nested", False)
+ )
+
+ # indicates there is a CompoundSelect in play and we are not the
+ # first select
+ if compound_index:
+ populate_result_map = False
+
+ # this was first proposed as part of #3372; however, it is not
+ # reached in current tests and could possibly be an assertion
+ # instead.
+ if not populate_result_map and "add_to_result_map" in kwargs:
+ del kwargs["add_to_result_map"]
+
+ froms = self._setup_select_stack(
+ select_stmt, compile_state, entry, asfrom, lateral, compound_index
+ )
+
+ column_clause_args = kwargs.copy()
+ column_clause_args.update(
+ {"within_label_clause": False, "within_columns_clause": False}
+ )
+
+ text = "SELECT " # we're off to a good start !
+
+ if select_stmt._hints:
+ hint_text, byfrom = self._setup_select_hints(select_stmt)
+ if hint_text:
+ text += hint_text + " "
+ else:
+ byfrom = None
+
+ if select_stmt._independent_ctes:
+ for cte in select_stmt._independent_ctes:
+ cte._compiler_dispatch(self, **kwargs)
+
+ if select_stmt._prefixes:
+ text += self._generate_prefixes(
+ select_stmt, select_stmt._prefixes, **kwargs
+ )
+
+ text += self.get_select_precolumns(select_stmt, **kwargs)
+ # the actual list of columns to print in the SELECT column list.
+ inner_columns = [
+ c
+ for c in [
+ self._label_select_column(
+ select_stmt,
+ column,
+ populate_result_map,
+ asfrom,
+ column_clause_args,
+ name=name,
+ proxy_name=proxy_name,
+ fallback_label_name=fallback_label_name,
+ column_is_repeated=repeated,
+ need_column_expressions=need_column_expressions,
+ )
+ for (
+ name,
+ proxy_name,
+ fallback_label_name,
+ column,
+ repeated,
+ ) in compile_state.columns_plus_names
+ ]
+ if c is not None
+ ]
+
+ if populate_result_map and select_wraps_for is not None:
+ # if this select was generated from translate_select,
+ # rewrite the targeted columns in the result map
+
+ translate = dict(
+ zip(
+ [
+ name
+ for (
+ key,
+ proxy_name,
+ fallback_label_name,
+ name,
+ repeated,
+ ) in compile_state.columns_plus_names
+ ],
+ [
+ name
+ for (
+ key,
+ proxy_name,
+ fallback_label_name,
+ name,
+ repeated,
+ ) in compile_state_wraps_for.columns_plus_names
+ ],
+ )
+ )
+
+ self._result_columns = [
+ (key, name, tuple(translate.get(o, o) for o in obj), type_)
+ for key, name, obj, type_ in self._result_columns
+ ]
+
+ text = self._compose_select_body(
+ text,
+ select_stmt,
+ compile_state,
+ inner_columns,
+ froms,
+ byfrom,
+ toplevel,
+ kwargs,
+ )
+
+ if select_stmt._statement_hints:
+ per_dialect = [
+ ht
+ for (dialect_name, ht) in select_stmt._statement_hints
+ if dialect_name in ("*", self.dialect.name)
+ ]
+ if per_dialect:
+ text += " " + self.get_statement_hint_text(per_dialect)
+
+ if self.ctes:
+            # In a compound query, CTEs are shared at the compound level
+ if not is_embedded_select:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(nesting_level=nesting_level) + text
+ )
+
+ if select_stmt._suffixes:
+ text += " " + self._generate_prefixes(
+ select_stmt, select_stmt._suffixes, **kwargs
+ )
+
+ self.stack.pop(-1)
+
+ return text
+
+ def _setup_select_hints(self, select):
+ byfrom = dict(
+ [
+ (
+ from_,
+ hinttext
+ % {"name": from_._compiler_dispatch(self, ashint=True)},
+ )
+ for (from_, dialect), hinttext in select._hints.items()
+ if dialect in ("*", self.dialect.name)
+ ]
+ )
+ hint_text = self.get_select_hint_text(byfrom)
+ return hint_text, byfrom
+
+ def _setup_select_stack(
+ self, select, compile_state, entry, asfrom, lateral, compound_index
+ ):
+ correlate_froms = entry["correlate_froms"]
+ asfrom_froms = entry["asfrom_froms"]
+
+ if compound_index == 0:
+ entry["select_0"] = select
+ elif compound_index:
+ select_0 = entry["select_0"]
+ numcols = len(select_0._all_selected_columns)
+
+ if len(compile_state.columns_plus_names) != numcols:
+ raise exc.CompileError(
+ "All selectables passed to "
+ "CompoundSelect must have identical numbers of "
+ "columns; select #%d has %d columns, select "
+ "#%d has %d"
+ % (
+ 1,
+ numcols,
+ compound_index + 1,
+ len(select._all_selected_columns),
+ )
+ )
+
+ if asfrom and not lateral:
+ froms = compile_state._get_display_froms(
+ explicit_correlate_froms=correlate_froms.difference(
+ asfrom_froms
+ ),
+ implicit_correlate_froms=(),
+ )
+ else:
+ froms = compile_state._get_display_froms(
+ explicit_correlate_froms=correlate_froms,
+ implicit_correlate_froms=asfrom_froms,
+ )
+
+ new_correlate_froms = set(selectable._from_objects(*froms))
+ all_correlate_froms = new_correlate_froms.union(correlate_froms)
+
+ new_entry = {
+ "asfrom_froms": new_correlate_froms,
+ "correlate_froms": all_correlate_froms,
+ "selectable": select,
+ "compile_state": compile_state,
+ }
+ self.stack.append(new_entry)
+
+ return froms
+
+ def _compose_select_body(
+ self,
+ text,
+ select,
+ compile_state,
+ inner_columns,
+ froms,
+ byfrom,
+ toplevel,
+ kwargs,
+ ):
+ text += ", ".join(inner_columns)
+
+ if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+ from_linter = FromLinter({}, set())
+ warn_linting = self.linting & WARN_LINTING
+ if toplevel:
+ self.from_linter = from_linter
+ else:
+ from_linter = None
+ warn_linting = False
+
+ if froms:
+ text += " \nFROM "
+
+ if select._hints:
+ text += ", ".join(
+ [
+ f._compiler_dispatch(
+ self,
+ asfrom=True,
+ fromhints=byfrom,
+ from_linter=from_linter,
+ **kwargs
+ )
+ for f in froms
+ ]
+ )
+ else:
+ text += ", ".join(
+ [
+ f._compiler_dispatch(
+ self,
+ asfrom=True,
+ from_linter=from_linter,
+ **kwargs
+ )
+ for f in froms
+ ]
+ )
+ else:
+ text += self.default_from()
+
+ if select._where_criteria:
+ t = self._generate_delimited_and_list(
+ select._where_criteria, from_linter=from_linter, **kwargs
+ )
+ if t:
+ text += " \nWHERE " + t
+
+ if warn_linting:
+ from_linter.warn()
+
+ if select._group_by_clauses:
+ text += self.group_by_clause(select, **kwargs)
+
+ if select._having_criteria:
+ t = self._generate_delimited_and_list(
+ select._having_criteria, **kwargs
+ )
+ if t:
+ text += " \nHAVING " + t
+
+ if select._order_by_clauses:
+ text += self.order_by_clause(select, **kwargs)
+
+ if select._has_row_limiting_clause:
+ text += self._row_limit_clause(select, **kwargs)
+
+ if select._for_update_arg is not None:
+ text += self.for_update_clause(select, **kwargs)
+
+ return text
+
+ def _generate_prefixes(self, stmt, prefixes, **kw):
+ clause = " ".join(
+ prefix._compiler_dispatch(self, **kw)
+ for prefix, dialect_name in prefixes
+ if dialect_name is None or dialect_name == self.dialect.name
+ )
+ if clause:
+ clause += " "
+ return clause
+
+ def _render_cte_clause(
+ self,
+ nesting_level=None,
+ include_following_stack=False,
+ ):
+ """
+ include_following_stack
+ Also render the nesting CTEs on the next stack. Useful for
+ SQL structures like UNION or INSERT that can wrap SELECT
+ statements containing nesting CTEs.
+ """
+ if not self.ctes:
+ return ""
+
+ if nesting_level and nesting_level > 1:
+ ctes = util.OrderedDict()
+ for cte in list(self.ctes.keys()):
+ cte_level, cte_name = self.level_name_by_cte[
+ cte._get_reference_cte()
+ ]
+ is_rendered_level = cte_level == nesting_level or (
+ include_following_stack and cte_level == nesting_level + 1
+ )
+ if not (cte.nesting and is_rendered_level):
+ continue
+
+ ctes[cte] = self.ctes[cte]
+
+ else:
+ ctes = self.ctes
+
+ if not ctes:
+ return ""
+
+ ctes_recursive = any([cte.recursive for cte in ctes])
+
+ if self.positional:
+ self.positiontup = (
+ sum([self.cte_positional[cte] for cte in ctes], [])
+ + self.positiontup
+ )
+ cte_text = self.get_cte_preamble(ctes_recursive) + " "
+ cte_text += ", \n".join([txt for txt in ctes.values()])
+ cte_text += "\n "
+
+ if nesting_level and nesting_level > 1:
+ for cte in list(ctes.keys()):
+ cte_level, cte_name = self.level_name_by_cte[
+ cte._get_reference_cte()
+ ]
+ del self.ctes[cte]
+ del self.ctes_by_level_name[(cte_level, cte_name)]
+ del self.level_name_by_cte[cte._get_reference_cte()]
+
+ return cte_text
+
+ def get_cte_preamble(self, recursive):
+ if recursive:
+ return "WITH RECURSIVE"
+ else:
+ return "WITH"
+
+ def get_select_precolumns(self, select, **kw):
+        """Called when building a ``SELECT`` statement; the position is just
+        before the column list.
+
+        """
+ if select._distinct_on:
+ util.warn_deprecated(
+ "DISTINCT ON is currently supported only by the PostgreSQL "
+ "dialect. Use of DISTINCT ON for other backends is currently "
+ "silently ignored, however this usage is deprecated, and will "
+ "raise CompileError in a future release for all backends "
+ "that do not support this syntax.",
+ version="1.4",
+ )
+ return "DISTINCT " if select._distinct else ""
+
+ def group_by_clause(self, select, **kw):
+ """allow dialects to customize how GROUP BY is rendered."""
+
+ group_by = self._generate_delimited_list(
+ select._group_by_clauses, OPERATORS[operators.comma_op], **kw
+ )
+ if group_by:
+ return " GROUP BY " + group_by
+ else:
+ return ""
+
+ def order_by_clause(self, select, **kw):
+ """allow dialects to customize how ORDER BY is rendered."""
+
+ order_by = self._generate_delimited_list(
+ select._order_by_clauses, OPERATORS[operators.comma_op], **kw
+ )
+
+ if order_by:
+ return " ORDER BY " + order_by
+ else:
+ return ""
+
+ def for_update_clause(self, select, **kw):
+ return " FOR UPDATE"
+
+ def returning_clause(self, stmt, returning_cols):
+ raise exc.CompileError(
+ "RETURNING is not supported by this "
+ "dialect's statement compiler."
+ )
+
+ def limit_clause(self, select, **kw):
+ text = ""
+ if select._limit_clause is not None:
+ text += "\n LIMIT " + self.process(select._limit_clause, **kw)
+ if select._offset_clause is not None:
+ if select._limit_clause is None:
+ text += "\n LIMIT -1"
+ text += " OFFSET " + self.process(select._offset_clause, **kw)
+ return text
+
+ def fetch_clause(self, select, **kw):
+ text = ""
+ if select._offset_clause is not None:
+ text += "\n OFFSET %s ROWS" % self.process(
+ select._offset_clause, **kw
+ )
+ if select._fetch_clause is not None:
+ text += "\n FETCH FIRST %s%s ROWS %s" % (
+ self.process(select._fetch_clause, **kw),
+ " PERCENT" if select._fetch_clause_options["percent"] else "",
+ "WITH TIES"
+ if select._fetch_clause_options["with_ties"]
+ else "ONLY",
+ )
+ return text
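A rough sketch of what the two row-limiting paths above emit on the generic dialect (illustrative only; exact whitespace and parameter names may differ):

from sqlalchemy import column, select, table

t = table("t", column("x"))

# limit_clause() path:
#   SELECT t.x FROM t  LIMIT :param_1 OFFSET :param_2
print(select(t.c.x).limit(10).offset(5))

# fetch_clause() path, used when .fetch() is present:
#   SELECT t.x FROM t  OFFSET :param_1 ROWS  FETCH FIRST :param_2 ROWS ONLY
print(select(t.c.x).fetch(10).offset(5))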
+
+ def visit_table(
+ self,
+ table,
+ asfrom=False,
+ iscrud=False,
+ ashint=False,
+ fromhints=None,
+ use_schema=True,
+ from_linter=None,
+ **kwargs
+ ):
+ if from_linter:
+ from_linter.froms[table] = table.fullname
+
+ if asfrom or ashint:
+ effective_schema = self.preparer.schema_for_object(table)
+
+ if use_schema and effective_schema:
+ ret = (
+ self.preparer.quote_schema(effective_schema)
+ + "."
+ + self.preparer.quote(table.name)
+ )
+ else:
+ ret = self.preparer.quote(table.name)
+ if fromhints and table in fromhints:
+ ret = self.format_from_hint_text(
+ ret, table, fromhints[table], iscrud
+ )
+ return ret
+ else:
+ return ""
+
+ def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
+ if from_linter:
+ from_linter.edges.update(
+ itertools.product(
+ join.left._from_objects, join.right._from_objects
+ )
+ )
+
+ if join.full:
+ join_type = " FULL OUTER JOIN "
+ elif join.isouter:
+ join_type = " LEFT OUTER JOIN "
+ else:
+ join_type = " JOIN "
+ return (
+ join.left._compiler_dispatch(
+ self, asfrom=True, from_linter=from_linter, **kwargs
+ )
+ + join_type
+ + join.right._compiler_dispatch(
+ self, asfrom=True, from_linter=from_linter, **kwargs
+ )
+ + " ON "
+ # TODO: likely need asfrom=True here?
+ + join.onclause._compiler_dispatch(
+ self, from_linter=from_linter, **kwargs
+ )
+ )
+
+ def _setup_crud_hints(self, stmt, table_text):
+ dialect_hints = dict(
+ [
+ (table, hint_text)
+ for (table, dialect), hint_text in stmt._hints.items()
+ if dialect in ("*", self.dialect.name)
+ ]
+ )
+ if stmt.table in dialect_hints:
+ table_text = self.format_from_hint_text(
+ table_text, stmt.table, dialect_hints[stmt.table], True
+ )
+ return dialect_hints, table_text
+
+ def visit_insert(self, insert_stmt, **kw):
+
+ compile_state = insert_stmt._compile_state_factory(
+ insert_stmt, self, **kw
+ )
+ insert_stmt = compile_state.statement
+
+ toplevel = not self.stack
+
+ if toplevel:
+ self.isinsert = True
+ if not self.dml_compile_state:
+ self.dml_compile_state = compile_state
+ if not self.compile_state:
+ self.compile_state = compile_state
+
+ self.stack.append(
+ {
+ "correlate_froms": set(),
+ "asfrom_froms": set(),
+ "selectable": insert_stmt,
+ }
+ )
+
+ crud_params = crud._get_crud_params(
+ self, insert_stmt, compile_state, **kw
+ )
+
+ if (
+ not crud_params
+ and not self.dialect.supports_default_values
+ and not self.dialect.supports_default_metavalue
+ and not self.dialect.supports_empty_insert
+ ):
+ raise exc.CompileError(
+ "The '%s' dialect with current database "
+ "version settings does not support empty "
+ "inserts." % self.dialect.name
+ )
+
+ if compile_state._has_multi_parameters:
+ if not self.dialect.supports_multivalues_insert:
+ raise exc.CompileError(
+ "The '%s' dialect with current database "
+ "version settings does not support "
+ "in-place multirow inserts." % self.dialect.name
+ )
+ crud_params_single = crud_params[0]
+ else:
+ crud_params_single = crud_params
+
+ preparer = self.preparer
+ supports_default_values = self.dialect.supports_default_values
+
+ text = "INSERT "
+
+ if insert_stmt._prefixes:
+ text += self._generate_prefixes(
+ insert_stmt, insert_stmt._prefixes, **kw
+ )
+
+ text += "INTO "
+ table_text = preparer.format_table(insert_stmt.table)
+
+ if insert_stmt._hints:
+ _, table_text = self._setup_crud_hints(insert_stmt, table_text)
+
+ if insert_stmt._independent_ctes:
+ for cte in insert_stmt._independent_ctes:
+ cte._compiler_dispatch(self, **kw)
+
+ text += table_text
+
+ if crud_params_single or not supports_default_values:
+ text += " (%s)" % ", ".join(
+ [expr for c, expr, value in crud_params_single]
+ )
+
+ if self.returning or insert_stmt._returning:
+ returning_clause = self.returning_clause(
+ insert_stmt, self.returning or insert_stmt._returning
+ )
+
+ if self.returning_precedes_values:
+ text += " " + returning_clause
+ else:
+ returning_clause = None
+
+ if insert_stmt.select is not None:
+ # placed here by crud.py
+ select_text = self.process(
+ self.stack[-1]["insert_from_select"], insert_into=True, **kw
+ )
+
+ if self.ctes and self.dialect.cte_follows_insert:
+ nesting_level = len(self.stack) if not toplevel else None
+ text += " %s%s" % (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ include_following_stack=True,
+ ),
+ select_text,
+ )
+ else:
+ text += " %s" % select_text
+ elif not crud_params and supports_default_values:
+ text += " DEFAULT VALUES"
+ elif compile_state._has_multi_parameters:
+ text += " VALUES %s" % (
+ ", ".join(
+ "(%s)"
+ % (", ".join(value for c, expr, value in crud_param_set))
+ for crud_param_set in crud_params
+ )
+ )
+ else:
+ insert_single_values_expr = ", ".join(
+ [value for c, expr, value in crud_params]
+ )
+ text += " VALUES (%s)" % insert_single_values_expr
+ if toplevel:
+ self.insert_single_values_expr = insert_single_values_expr
+
+ if insert_stmt._post_values_clause is not None:
+ post_values_clause = self.process(
+ insert_stmt._post_values_clause, **kw
+ )
+ if post_values_clause:
+ text += " " + post_values_clause
+
+ if returning_clause and not self.returning_precedes_values:
+ text += " " + returning_clause
+
+ if self.ctes and not self.dialect.cte_follows_insert:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level, include_following_stack=True
+ )
+ + text
+ )
+
+ self.stack.pop(-1)
+
+ return text
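For orientation, a small illustrative sketch (not part of the diff) of the INSERT shapes the method above produces on the generic dialect:

from sqlalchemy import column, insert, table

t = table("t", column("x"), column("y"))

# plain VALUES form:
#   INSERT INTO t (x, y) VALUES (:x, :y)
print(insert(t).values(x=1, y=2))

# INSERT..SELECT form, routed through the "insert_from_select" stack entry:
#   INSERT INTO t (x, y) SELECT t.x, t.y FROM t
print(insert(t).from_select(["x", "y"], t.select()))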
+
+ def update_limit_clause(self, update_stmt):
+        """Provide a hook for MySQL to add LIMIT to the UPDATE."""
+ return None
+
+ def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
+ """Provide a hook to override the initial table clause
+ in an UPDATE statement.
+
+ MySQL overrides this.
+
+ """
+ kw["asfrom"] = True
+ return from_table._compiler_dispatch(self, iscrud=True, **kw)
+
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ """Provide a hook to override the generation of an
+ UPDATE..FROM clause.
+
+ MySQL and MSSQL override this.
+
+ """
+ raise NotImplementedError(
+ "This backend does not support multiple-table "
+ "criteria within UPDATE"
+ )
+
+ def visit_update(self, update_stmt, **kw):
+ compile_state = update_stmt._compile_state_factory(
+ update_stmt, self, **kw
+ )
+ update_stmt = compile_state.statement
+
+ toplevel = not self.stack
+ if toplevel:
+ self.isupdate = True
+ if not self.dml_compile_state:
+ self.dml_compile_state = compile_state
+ if not self.compile_state:
+ self.compile_state = compile_state
+
+ extra_froms = compile_state._extra_froms
+ is_multitable = bool(extra_froms)
+
+ if is_multitable:
+ # main table might be a JOIN
+ main_froms = set(selectable._from_objects(update_stmt.table))
+ render_extra_froms = [
+ f for f in extra_froms if f not in main_froms
+ ]
+ correlate_froms = main_froms.union(extra_froms)
+ else:
+ render_extra_froms = []
+ correlate_froms = {update_stmt.table}
+
+ self.stack.append(
+ {
+ "correlate_froms": correlate_froms,
+ "asfrom_froms": correlate_froms,
+ "selectable": update_stmt,
+ }
+ )
+
+ text = "UPDATE "
+
+ if update_stmt._prefixes:
+ text += self._generate_prefixes(
+ update_stmt, update_stmt._prefixes, **kw
+ )
+
+ table_text = self.update_tables_clause(
+ update_stmt, update_stmt.table, render_extra_froms, **kw
+ )
+ crud_params = crud._get_crud_params(
+ self, update_stmt, compile_state, **kw
+ )
+
+ if update_stmt._hints:
+ dialect_hints, table_text = self._setup_crud_hints(
+ update_stmt, table_text
+ )
+ else:
+ dialect_hints = None
+
+ if update_stmt._independent_ctes:
+ for cte in update_stmt._independent_ctes:
+ cte._compiler_dispatch(self, **kw)
+
+ text += table_text
+
+ text += " SET "
+ text += ", ".join(expr + "=" + value for c, expr, value in crud_params)
+
+ if self.returning or update_stmt._returning:
+ if self.returning_precedes_values:
+ text += " " + self.returning_clause(
+ update_stmt, self.returning or update_stmt._returning
+ )
+
+ if extra_froms:
+ extra_from_text = self.update_from_clause(
+ update_stmt,
+ update_stmt.table,
+ render_extra_froms,
+ dialect_hints,
+ **kw
+ )
+ if extra_from_text:
+ text += " " + extra_from_text
+
+ if update_stmt._where_criteria:
+ t = self._generate_delimited_and_list(
+ update_stmt._where_criteria, **kw
+ )
+ if t:
+ text += " WHERE " + t
+
+ limit_clause = self.update_limit_clause(update_stmt)
+ if limit_clause:
+ text += " " + limit_clause
+
+ if (
+ self.returning or update_stmt._returning
+ ) and not self.returning_precedes_values:
+ text += " " + self.returning_clause(
+ update_stmt, self.returning or update_stmt._returning
+ )
+
+ if self.ctes:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
+
+ self.stack.pop(-1)
+
+ return text
+
+ def delete_extra_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+        """Provide a hook to override the generation of a
+ DELETE..FROM clause.
+
+ This can be used to implement DELETE..USING for example.
+
+ MySQL and MSSQL override this.
+
+ """
+ raise NotImplementedError(
+ "This backend does not support multiple-table "
+ "criteria within DELETE"
+ )
+
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
+ return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
+
+ def visit_delete(self, delete_stmt, **kw):
+ compile_state = delete_stmt._compile_state_factory(
+ delete_stmt, self, **kw
+ )
+ delete_stmt = compile_state.statement
+
+ toplevel = not self.stack
+ if toplevel:
+ self.isdelete = True
+ if not self.dml_compile_state:
+ self.dml_compile_state = compile_state
+ if not self.compile_state:
+ self.compile_state = compile_state
+
+ extra_froms = compile_state._extra_froms
+
+ correlate_froms = {delete_stmt.table}.union(extra_froms)
+ self.stack.append(
+ {
+ "correlate_froms": correlate_froms,
+ "asfrom_froms": correlate_froms,
+ "selectable": delete_stmt,
+ }
+ )
+
+ text = "DELETE "
+
+ if delete_stmt._prefixes:
+ text += self._generate_prefixes(
+ delete_stmt, delete_stmt._prefixes, **kw
+ )
+
+ text += "FROM "
+ table_text = self.delete_table_clause(
+ delete_stmt, delete_stmt.table, extra_froms
+ )
+
+ if delete_stmt._hints:
+ dialect_hints, table_text = self._setup_crud_hints(
+ delete_stmt, table_text
+ )
+ else:
+ dialect_hints = None
+
+ if delete_stmt._independent_ctes:
+ for cte in delete_stmt._independent_ctes:
+ cte._compiler_dispatch(self, **kw)
+
+ text += table_text
+
+ if delete_stmt._returning:
+ if self.returning_precedes_values:
+ text += " " + self.returning_clause(
+ delete_stmt, delete_stmt._returning
+ )
+
+ if extra_froms:
+ extra_from_text = self.delete_extra_from_clause(
+ delete_stmt,
+ delete_stmt.table,
+ extra_froms,
+ dialect_hints,
+ **kw
+ )
+ if extra_from_text:
+ text += " " + extra_from_text
+
+ if delete_stmt._where_criteria:
+ t = self._generate_delimited_and_list(
+ delete_stmt._where_criteria, **kw
+ )
+ if t:
+ text += " WHERE " + t
+
+ if delete_stmt._returning and not self.returning_precedes_values:
+ text += " " + self.returning_clause(
+ delete_stmt, delete_stmt._returning
+ )
+
+ if self.ctes:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
+
+ self.stack.pop(-1)
+
+ return text
+
+ def visit_savepoint(self, savepoint_stmt):
+ return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
+
+ def visit_rollback_to_savepoint(self, savepoint_stmt):
+ return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
+
+ def visit_release_savepoint(self, savepoint_stmt):
+ return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
+
+
+class StrSQLCompiler(SQLCompiler):
+ """A :class:`.SQLCompiler` subclass which allows a small selection
+ of non-standard SQL features to render into a string value.
+
+ The :class:`.StrSQLCompiler` is invoked whenever a Core expression
+ element is directly stringified without calling upon the
+ :meth:`_expression.ClauseElement.compile` method.
+ It can render a limited set
+ of non-standard SQL constructs to assist in basic stringification,
+ however for more substantial custom or dialect-specific SQL constructs,
+ it will be necessary to make use of
+ :meth:`_expression.ClauseElement.compile`
+ directly.
+
+ .. seealso::
+
+ :ref:`faq_sql_expression_string`
+
+ """
+
+ def _fallback_column_name(self, column):
+ return "<name unknown>"
+
+ @util.preload_module("sqlalchemy.engine.url")
+ def visit_unsupported_compilation(self, element, err, **kw):
+ if element.stringify_dialect != "default":
+ url = util.preloaded.engine_url
+ dialect = url.URL.create(element.stringify_dialect).get_dialect()()
+
+ compiler = dialect.statement_compiler(dialect, None)
+ if not isinstance(compiler, StrSQLCompiler):
+ return compiler.process(element)
+
+ return super(StrSQLCompiler, self).visit_unsupported_compilation(
+ element, err
+ )
+
+ def visit_getitem_binary(self, binary, operator, **kw):
+ return "%s[%s]" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_json_getitem_op_binary(self, binary, operator, **kw):
+ return self.visit_getitem_binary(binary, operator, **kw)
+
+ def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+ return self.visit_getitem_binary(binary, operator, **kw)
+
+ def visit_sequence(self, seq, **kw):
+ return "<next sequence value: %s>" % self.preparer.format_sequence(seq)
+
+ def returning_clause(self, stmt, returning_cols):
+ columns = [
+ self._label_select_column(None, c, True, False, {})
+ for c in base._select_iterables(returning_cols)
+ ]
+
+ return "RETURNING " + ", ".join(columns)
+
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ kw["asfrom"] = True
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
+
+ def delete_extra_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ kw["asfrom"] = True
+ return ", " + ", ".join(
+ t._compiler_dispatch(self, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
+
+ def visit_empty_set_expr(self, type_):
+ return "SELECT 1 WHERE 1!=1"
+
+ def get_from_hint_text(self, table, text):
+ return "[%s]" % text
+
+ def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._generate_generic_binary(binary, " <regexp> ", **kw)
+
+ def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._generate_generic_binary(binary, " <not regexp> ", **kw)
+
+ def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+ replacement = binary.modifiers["replacement"]
+ return "<regexp replace>(%s, %s, %s)" % (
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw),
+ replacement._compiler_dispatch(self, **kw),
+ )
+
+
+class DDLCompiler(Compiled):
+ @util.memoized_property
+ def sql_compiler(self):
+ return self.dialect.statement_compiler(
+ self.dialect, None, schema_translate_map=self.schema_translate_map
+ )
+
+ @util.memoized_property
+ def type_compiler(self):
+ return self.dialect.type_compiler
+
+ def construct_params(
+ self, params=None, extracted_parameters=None, escape_names=True
+ ):
+ return None
+
+ def visit_ddl(self, ddl, **kwargs):
+ # table events can substitute table and schema name
+ context = ddl.context
+ if isinstance(ddl.target, schema.Table):
+ context = context.copy()
+
+ preparer = self.preparer
+ path = preparer.format_table_seq(ddl.target)
+ if len(path) == 1:
+ table, sch = path[0], ""
+ else:
+ table, sch = path[-1], path[0]
+
+ context.setdefault("table", table)
+ context.setdefault("schema", sch)
+ context.setdefault("fullname", preparer.format_table(ddl.target))
+
+ return self.sql_compiler.post_process_text(ddl.statement % context)
+
+ def visit_create_schema(self, create, **kw):
+ schema = self.preparer.format_schema(create.element)
+ return "CREATE SCHEMA " + schema
+
+ def visit_drop_schema(self, drop, **kw):
+ schema = self.preparer.format_schema(drop.element)
+ text = "DROP SCHEMA " + schema
+ if drop.cascade:
+ text += " CASCADE"
+ return text
+
+ def visit_create_table(self, create, **kw):
+ table = create.element
+ preparer = self.preparer
+
+ text = "\nCREATE "
+ if table._prefixes:
+ text += " ".join(table._prefixes) + " "
+
+ text += "TABLE "
+ if create.if_not_exists:
+ text += "IF NOT EXISTS "
+
+ text += preparer.format_table(table) + " "
+
+ create_table_suffix = self.create_table_suffix(table)
+ if create_table_suffix:
+ text += create_table_suffix + " "
+
+ text += "("
+
+ separator = "\n"
+
+ # if only one primary key, specify it along with the column
+ first_pk = False
+ for create_column in create.columns:
+ column = create_column.element
+ try:
+ processed = self.process(
+ create_column, first_pk=column.primary_key and not first_pk
+ )
+ if processed is not None:
+ text += separator
+ separator = ", \n"
+ text += "\t" + processed
+ if column.primary_key:
+ first_pk = True
+ except exc.CompileError as ce:
+ util.raise_(
+ exc.CompileError(
+ util.u("(in table '%s', column '%s'): %s")
+ % (table.description, column.name, ce.args[0])
+ ),
+ from_=ce,
+ )
+
+ const = self.create_table_constraints(
+ table,
+ _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa
+ )
+ if const:
+ text += separator + "\t" + const
+
+ text += "\n)%s\n\n" % self.post_create_table(table)
+ return text
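As a hedged illustration of the CREATE TABLE assembly above (column specifications first, then table-level constraints), a simple table renders roughly as shown in the comment; exact whitespace differs:

from sqlalchemy import Column, Integer, MetaData, String, Table
from sqlalchemy.schema import CreateTable

t = Table(
    "user_account",
    MetaData(),
    Column("id", Integer, primary_key=True),
    Column("name", String(50), nullable=False),
)

# CREATE TABLE user_account (
#     id INTEGER NOT NULL,
#     name VARCHAR(50) NOT NULL,
#     PRIMARY KEY (id)
# )
print(CreateTable(t))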
+
+ def visit_create_column(self, create, first_pk=False, **kw):
+ column = create.element
+
+ if column.system:
+ return None
+
+ text = self.get_column_specification(column, first_pk=first_pk)
+ const = " ".join(
+ self.process(constraint) for constraint in column.constraints
+ )
+ if const:
+ text += " " + const
+
+ return text
+
+ def create_table_constraints(
+ self, table, _include_foreign_key_constraints=None, **kw
+ ):
+
+ # On some DB order is significant: visit PK first, then the
+ # other constraints (engine.ReflectionTest.testbasic failed on FB2)
+ constraints = []
+ if table.primary_key:
+ constraints.append(table.primary_key)
+
+ all_fkcs = table.foreign_key_constraints
+ if _include_foreign_key_constraints is not None:
+ omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints)
+ else:
+ omit_fkcs = set()
+
+ constraints.extend(
+ [
+ c
+ for c in table._sorted_constraints
+ if c is not table.primary_key and c not in omit_fkcs
+ ]
+ )
+
+ return ", \n\t".join(
+ p
+ for p in (
+ self.process(constraint)
+ for constraint in constraints
+ if (
+ constraint._create_rule is None
+ or constraint._create_rule(self)
+ )
+ and (
+ not self.dialect.supports_alter
+ or not getattr(constraint, "use_alter", False)
+ )
+ )
+ if p is not None
+ )
+
+ def visit_drop_table(self, drop, **kw):
+ text = "\nDROP TABLE "
+ if drop.if_exists:
+ text += "IF EXISTS "
+ return text + self.preparer.format_table(drop.element)
+
+ def visit_drop_view(self, drop, **kw):
+ return "\nDROP VIEW " + self.preparer.format_table(drop.element)
+
+ def _verify_index_table(self, index):
+ if index.table is None:
+ raise exc.CompileError(
+ "Index '%s' is not associated " "with any table." % index.name
+ )
+
+ def visit_create_index(
+ self, create, include_schema=False, include_table_schema=True, **kw
+ ):
+ index = create.element
+ self._verify_index_table(index)
+ preparer = self.preparer
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+ if index.name is None:
+ raise exc.CompileError(
+ "CREATE INDEX requires that the index have a name"
+ )
+
+ text += "INDEX "
+ if create.if_not_exists:
+ text += "IF NOT EXISTS "
+
+ text += "%s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=include_schema),
+ preparer.format_table(
+ index.table, use_schema=include_table_schema
+ ),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
+ return text
+
+ def visit_drop_index(self, drop, **kw):
+ index = drop.element
+
+ if index.name is None:
+ raise exc.CompileError(
+ "DROP INDEX requires that the index have a name"
+ )
+ text = "\nDROP INDEX "
+ if drop.if_exists:
+ text += "IF EXISTS "
+
+ return text + self._prepared_index_name(index, include_schema=True)
+
+ def _prepared_index_name(self, index, include_schema=False):
+ if index.table is not None:
+ effective_schema = self.preparer.schema_for_object(index.table)
+ else:
+ effective_schema = None
+ if include_schema and effective_schema:
+ schema_name = self.preparer.quote_schema(effective_schema)
+ else:
+ schema_name = None
+
+ index_name = self.preparer.format_index(index)
+
+ if schema_name:
+ index_name = schema_name + "." + index_name
+ return index_name
+
+ def visit_add_constraint(self, create, **kw):
+ return "ALTER TABLE %s ADD %s" % (
+ self.preparer.format_table(create.element.table),
+ self.process(create.element),
+ )
+
+ def visit_set_table_comment(self, create, **kw):
+ return "COMMENT ON TABLE %s IS %s" % (
+ self.preparer.format_table(create.element),
+ self.sql_compiler.render_literal_value(
+ create.element.comment, sqltypes.String()
+ ),
+ )
+
+ def visit_drop_table_comment(self, drop, **kw):
+ return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
+ drop.element
+ )
+
+ def visit_set_column_comment(self, create, **kw):
+ return "COMMENT ON COLUMN %s IS %s" % (
+ self.preparer.format_column(
+ create.element, use_table=True, use_schema=True
+ ),
+ self.sql_compiler.render_literal_value(
+ create.element.comment, sqltypes.String()
+ ),
+ )
+
+ def visit_drop_column_comment(self, drop, **kw):
+ return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
+ drop.element, use_table=True
+ )
+
+ def get_identity_options(self, identity_options):
+ text = []
+ if identity_options.increment is not None:
+ text.append("INCREMENT BY %d" % identity_options.increment)
+ if identity_options.start is not None:
+ text.append("START WITH %d" % identity_options.start)
+ if identity_options.minvalue is not None:
+ text.append("MINVALUE %d" % identity_options.minvalue)
+ if identity_options.maxvalue is not None:
+ text.append("MAXVALUE %d" % identity_options.maxvalue)
+ if identity_options.nominvalue is not None:
+ text.append("NO MINVALUE")
+ if identity_options.nomaxvalue is not None:
+ text.append("NO MAXVALUE")
+ if identity_options.cache is not None:
+ text.append("CACHE %d" % identity_options.cache)
+ if identity_options.order is not None:
+ text.append("ORDER" if identity_options.order else "NO ORDER")
+ if identity_options.cycle is not None:
+ text.append("CYCLE" if identity_options.cycle else "NO CYCLE")
+ return " ".join(text)
+
+ def visit_create_sequence(self, create, prefix=None, **kw):
+ text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(
+ create.element
+ )
+ if prefix:
+ text += prefix
+ if create.element.start is None:
+ create.element.start = self.dialect.default_sequence_base
+ options = self.get_identity_options(create.element)
+ if options:
+ text += " " + options
+ return text
+
+ def visit_drop_sequence(self, drop, **kw):
+ return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+
+ def visit_drop_constraint(self, drop, **kw):
+ constraint = drop.element
+ if constraint.name is not None:
+ formatted_name = self.preparer.format_constraint(constraint)
+ else:
+ formatted_name = None
+
+ if formatted_name is None:
+ raise exc.CompileError(
+ "Can't emit DROP CONSTRAINT for constraint %r; "
+ "it has no name" % drop.element
+ )
+ return "ALTER TABLE %s DROP CONSTRAINT %s%s" % (
+ self.preparer.format_table(drop.element.table),
+ formatted_name,
+ drop.cascade and " CASCADE" or "",
+ )
+
+ def get_column_specification(self, column, **kwargs):
+ colspec = (
+ self.preparer.format_column(column)
+ + " "
+ + self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ )
+ )
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ if column.computed is not None:
+ colspec += " " + self.process(column.computed)
+
+ if (
+ column.identity is not None
+ and self.dialect.supports_identity_columns
+ ):
+ colspec += " " + self.process(column.identity)
+
+ if not column.nullable and (
+ not column.identity or not self.dialect.supports_identity_columns
+ ):
+ colspec += " NOT NULL"
+ return colspec
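A small sketch of the column specification path above, covering the server default branch (illustrative only, not part of the diff):

from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.schema import CreateTable

t = Table(
    "doc",
    MetaData(),
    Column("status", String(20), server_default="draft", nullable=False),
)

# the column specification renders roughly as:
#   status VARCHAR(20) DEFAULT 'draft' NOT NULL
print(CreateTable(t))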
+
+ def create_table_suffix(self, table):
+ return ""
+
+ def post_create_table(self, table):
+ return ""
+
+ def get_column_default_string(self, column):
+ if isinstance(column.server_default, schema.DefaultClause):
+ if isinstance(column.server_default.arg, util.string_types):
+ return self.sql_compiler.render_literal_value(
+ column.server_default.arg, sqltypes.STRINGTYPE
+ )
+ else:
+ return self.sql_compiler.process(
+ column.server_default.arg, literal_binds=True
+ )
+ else:
+ return None
+
+ def visit_table_or_column_check_constraint(self, constraint, **kw):
+ if constraint.is_column_level:
+ return self.visit_column_check_constraint(constraint)
+ else:
+ return self.visit_check_constraint(constraint)
+
+ def visit_check_constraint(self, constraint, **kw):
+ text = ""
+ if constraint.name is not None:
+ formatted_name = self.preparer.format_constraint(constraint)
+ if formatted_name is not None:
+ text += "CONSTRAINT %s " % formatted_name
+ text += "CHECK (%s)" % self.sql_compiler.process(
+ constraint.sqltext, include_table=False, literal_binds=True
+ )
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def visit_column_check_constraint(self, constraint, **kw):
+ text = ""
+ if constraint.name is not None:
+ formatted_name = self.preparer.format_constraint(constraint)
+ if formatted_name is not None:
+ text += "CONSTRAINT %s " % formatted_name
+ text += "CHECK (%s)" % self.sql_compiler.process(
+ constraint.sqltext, include_table=False, literal_binds=True
+ )
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def visit_primary_key_constraint(self, constraint, **kw):
+ if len(constraint) == 0:
+ return ""
+ text = ""
+ if constraint.name is not None:
+ formatted_name = self.preparer.format_constraint(constraint)
+ if formatted_name is not None:
+ text += "CONSTRAINT %s " % formatted_name
+ text += "PRIMARY KEY "
+ text += "(%s)" % ", ".join(
+ self.preparer.quote(c.name)
+ for c in (
+ constraint.columns_autoinc_first
+ if constraint._implicit_generated
+ else constraint.columns
+ )
+ )
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def visit_foreign_key_constraint(self, constraint, **kw):
+ preparer = self.preparer
+ text = ""
+ if constraint.name is not None:
+ formatted_name = self.preparer.format_constraint(constraint)
+ if formatted_name is not None:
+ text += "CONSTRAINT %s " % formatted_name
+ remote_table = list(constraint.elements)[0].column.table
+ text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
+ ", ".join(
+ preparer.quote(f.parent.name) for f in constraint.elements
+ ),
+ self.define_constraint_remote_table(
+ constraint, remote_table, preparer
+ ),
+ ", ".join(
+ preparer.quote(f.column.name) for f in constraint.elements
+ ),
+ )
+ text += self.define_constraint_match(constraint)
+ text += self.define_constraint_cascades(constraint)
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def define_constraint_remote_table(self, constraint, table, preparer):
+ """Format the remote table clause of a CREATE CONSTRAINT clause."""
+
+ return preparer.format_table(table)
+
+ def visit_unique_constraint(self, constraint, **kw):
+ if len(constraint) == 0:
+ return ""
+ text = ""
+ if constraint.name is not None:
+ formatted_name = self.preparer.format_constraint(constraint)
+ if formatted_name is not None:
+ text += "CONSTRAINT %s " % formatted_name
+ text += "UNIQUE (%s)" % (
+ ", ".join(self.preparer.quote(c.name) for c in constraint)
+ )
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def define_constraint_cascades(self, constraint):
+ text = ""
+ if constraint.ondelete is not None:
+ text += " ON DELETE %s" % self.preparer.validate_sql_phrase(
+ constraint.ondelete, FK_ON_DELETE
+ )
+ if constraint.onupdate is not None:
+ text += " ON UPDATE %s" % self.preparer.validate_sql_phrase(
+ constraint.onupdate, FK_ON_UPDATE
+ )
+ return text
+
+ def define_constraint_deferrability(self, constraint):
+ text = ""
+ if constraint.deferrable is not None:
+ if constraint.deferrable:
+ text += " DEFERRABLE"
+ else:
+ text += " NOT DEFERRABLE"
+ if constraint.initially is not None:
+ text += " INITIALLY %s" % self.preparer.validate_sql_phrase(
+ constraint.initially, FK_INITIALLY
+ )
+ return text
+
+ def define_constraint_match(self, constraint):
+ text = ""
+ if constraint.match is not None:
+ text += " MATCH %s" % constraint.match
+ return text
+
+ def visit_computed_column(self, generated, **kw):
+ text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
+ generated.sqltext, include_table=False, literal_binds=True
+ )
+ if generated.persisted is True:
+ text += " STORED"
+ elif generated.persisted is False:
+ text += " VIRTUAL"
+ return text
+
+ def visit_identity_column(self, identity, **kw):
+ text = "GENERATED %s AS IDENTITY" % (
+ "ALWAYS" if identity.always else "BY DEFAULT",
+ )
+ options = self.get_identity_options(identity)
+ if options:
+ text += " (%s)" % options
+ return text
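To illustrate the generated-column clauses above (a hedged sketch; the Identity line assumes a backend whose dialect reports support for identity columns):

from sqlalchemy import Column, Computed, Integer, MetaData, Table
from sqlalchemy.schema import CreateTable

t = Table(
    "metrics",
    MetaData(),
    Column("a", Integer),
    Column("b", Integer),
    Column("total", Integer, Computed("a + b", persisted=True)),
)

# the computed column renders roughly as:
#   total INTEGER GENERATED ALWAYS AS (a + b) STORED
# an Identity(start=1, increment=1) column would add, on supporting dialects:
#   GENERATED BY DEFAULT AS IDENTITY (INCREMENT BY 1 START WITH 1)
print(CreateTable(t))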
+
+
+class GenericTypeCompiler(TypeCompiler):
+ def visit_FLOAT(self, type_, **kw):
+ return "FLOAT"
+
+ def visit_REAL(self, type_, **kw):
+ return "REAL"
+
+ def visit_NUMERIC(self, type_, **kw):
+ if type_.precision is None:
+ return "NUMERIC"
+ elif type_.scale is None:
+ return "NUMERIC(%(precision)s)" % {"precision": type_.precision}
+ else:
+ return "NUMERIC(%(precision)s, %(scale)s)" % {
+ "precision": type_.precision,
+ "scale": type_.scale,
+ }
+
+ def visit_DECIMAL(self, type_, **kw):
+ if type_.precision is None:
+ return "DECIMAL"
+ elif type_.scale is None:
+ return "DECIMAL(%(precision)s)" % {"precision": type_.precision}
+ else:
+ return "DECIMAL(%(precision)s, %(scale)s)" % {
+ "precision": type_.precision,
+ "scale": type_.scale,
+ }
+
+ def visit_INTEGER(self, type_, **kw):
+ return "INTEGER"
+
+ def visit_SMALLINT(self, type_, **kw):
+ return "SMALLINT"
+
+ def visit_BIGINT(self, type_, **kw):
+ return "BIGINT"
+
+ def visit_TIMESTAMP(self, type_, **kw):
+ return "TIMESTAMP"
+
+ def visit_DATETIME(self, type_, **kw):
+ return "DATETIME"
+
+ def visit_DATE(self, type_, **kw):
+ return "DATE"
+
+ def visit_TIME(self, type_, **kw):
+ return "TIME"
+
+ def visit_CLOB(self, type_, **kw):
+ return "CLOB"
+
+ def visit_NCLOB(self, type_, **kw):
+ return "NCLOB"
+
+ def _render_string_type(self, type_, name):
+
+ text = name
+ if type_.length:
+ text += "(%d)" % type_.length
+ if type_.collation:
+ text += ' COLLATE "%s"' % type_.collation
+ return text
+
+ def visit_CHAR(self, type_, **kw):
+ return self._render_string_type(type_, "CHAR")
+
+ def visit_NCHAR(self, type_, **kw):
+ return self._render_string_type(type_, "NCHAR")
+
+ def visit_VARCHAR(self, type_, **kw):
+ return self._render_string_type(type_, "VARCHAR")
+
+ def visit_NVARCHAR(self, type_, **kw):
+ return self._render_string_type(type_, "NVARCHAR")
+
+ def visit_TEXT(self, type_, **kw):
+ return self._render_string_type(type_, "TEXT")
+
+ def visit_BLOB(self, type_, **kw):
+ return "BLOB"
+
+ def visit_BINARY(self, type_, **kw):
+ return "BINARY" + (type_.length and "(%d)" % type_.length or "")
+
+ def visit_VARBINARY(self, type_, **kw):
+ return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
+
+ def visit_BOOLEAN(self, type_, **kw):
+ return "BOOLEAN"
+
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_BLOB(type_, **kw)
+
+ def visit_boolean(self, type_, **kw):
+ return self.visit_BOOLEAN(type_, **kw)
+
+ def visit_time(self, type_, **kw):
+ return self.visit_TIME(type_, **kw)
+
+ def visit_datetime(self, type_, **kw):
+ return self.visit_DATETIME(type_, **kw)
+
+ def visit_date(self, type_, **kw):
+ return self.visit_DATE(type_, **kw)
+
+ def visit_big_integer(self, type_, **kw):
+ return self.visit_BIGINT(type_, **kw)
+
+ def visit_small_integer(self, type_, **kw):
+ return self.visit_SMALLINT(type_, **kw)
+
+ def visit_integer(self, type_, **kw):
+ return self.visit_INTEGER(type_, **kw)
+
+ def visit_real(self, type_, **kw):
+ return self.visit_REAL(type_, **kw)
+
+ def visit_float(self, type_, **kw):
+ return self.visit_FLOAT(type_, **kw)
+
+ def visit_numeric(self, type_, **kw):
+ return self.visit_NUMERIC(type_, **kw)
+
+ def visit_string(self, type_, **kw):
+ return self.visit_VARCHAR(type_, **kw)
+
+ def visit_unicode(self, type_, **kw):
+ return self.visit_VARCHAR(type_, **kw)
+
+ def visit_text(self, type_, **kw):
+ return self.visit_TEXT(type_, **kw)
+
+ def visit_unicode_text(self, type_, **kw):
+ return self.visit_TEXT(type_, **kw)
+
+ def visit_enum(self, type_, **kw):
+ return self.visit_VARCHAR(type_, **kw)
+
+ def visit_null(self, type_, **kw):
+ raise exc.CompileError(
+ "Can't generate DDL for %r; "
+ "did you forget to specify a "
+ "type on this Column?" % type_
+ )
+
+ def visit_type_decorator(self, type_, **kw):
+ return self.process(type_.type_engine(self.dialect), **kw)
+
+ def visit_user_defined(self, type_, **kw):
+ return type_.get_col_spec(**kw)
+
+
+class StrSQLTypeCompiler(GenericTypeCompiler):
+ def process(self, type_, **kw):
+ try:
+ _compiler_dispatch = type_._compiler_dispatch
+ except AttributeError:
+ return self._visit_unknown(type_, **kw)
+ else:
+ return _compiler_dispatch(self, **kw)
+
+ def __getattr__(self, key):
+ if key.startswith("visit_"):
+ return self._visit_unknown
+ else:
+ raise AttributeError(key)
+
+ def _visit_unknown(self, type_, **kw):
+ if type_.__class__.__name__ == type_.__class__.__name__.upper():
+ return type_.__class__.__name__
+ else:
+ return repr(type_)
+
+ def visit_null(self, type_, **kw):
+ return "NULL"
+
+ def visit_user_defined(self, type_, **kw):
+ try:
+ get_col_spec = type_.get_col_spec
+ except AttributeError:
+ return repr(type_)
+ else:
+ return get_col_spec(**kw)
+
+
+class IdentifierPreparer(object):
+
+ """Handle quoting and case-folding of identifiers based on options."""
+
+ reserved_words = RESERVED_WORDS
+
+ legal_characters = LEGAL_CHARACTERS
+
+ illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
+
+ schema_for_object = operator.attrgetter("schema")
+ """Return the .schema attribute for an object.
+
+ For the default IdentifierPreparer, the schema for an object is always
+    the value of the ".schema" attribute.  If the preparer is replaced
+    with one that has a non-empty schema_translate_map, the value of the
+    ".schema" attribute is rendered as a symbol that will be converted to a
+ real schema name from the mapping post-compile.
+
+ """
+
+ def __init__(
+ self,
+ dialect,
+ initial_quote='"',
+ final_quote=None,
+ escape_quote='"',
+ quote_case_sensitive_collations=True,
+ omit_schema=False,
+ ):
+ """Construct a new ``IdentifierPreparer`` object.
+
+ initial_quote
+ Character that begins a delimited identifier.
+
+ final_quote
+ Character that ends a delimited identifier. Defaults to
+ `initial_quote`.
+
+ omit_schema
+            Prevent prepending the schema name. Useful for databases that do
+            not support schemas.
+ """
+
+ self.dialect = dialect
+ self.initial_quote = initial_quote
+ self.final_quote = final_quote or self.initial_quote
+ self.escape_quote = escape_quote
+ self.escape_to_quote = self.escape_quote * 2
+ self.omit_schema = omit_schema
+ self.quote_case_sensitive_collations = quote_case_sensitive_collations
+ self._strings = {}
+ self._double_percents = self.dialect.paramstyle in (
+ "format",
+ "pyformat",
+ )
+
+ def _with_schema_translate(self, schema_translate_map):
+ prep = self.__class__.__new__(self.__class__)
+ prep.__dict__.update(self.__dict__)
+
+ def symbol_getter(obj):
+ name = obj.schema
+ if name in schema_translate_map and obj._use_schema_map:
+ if name is not None and ("[" in name or "]" in name):
+ raise exc.CompileError(
+ "Square bracket characters ([]) not supported "
+ "in schema translate name '%s'" % name
+ )
+ return quoted_name(
+ "__[SCHEMA_%s]" % (name or "_none"), quote=False
+ )
+ else:
+ return obj.schema
+
+ prep.schema_for_object = symbol_getter
+ return prep
+
+ def _render_schema_translates(self, statement, schema_translate_map):
+ d = schema_translate_map
+ if None in d:
+ d["_none"] = d[None]
+
+ def replace(m):
+ name = m.group(2)
+ effective_schema = d[name]
+ if not effective_schema:
+ effective_schema = self.dialect.default_schema_name
+ if not effective_schema:
+ # TODO: no coverage here
+ raise exc.CompileError(
+ "Dialect has no default schema name; can't "
+ "use None as dynamic schema target."
+ )
+ return self.quote_schema(effective_schema)
+
+ return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement)
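The symbol/substitution pair above backs the ``schema_translate_map`` execution option; a minimal sketch of how it is typically enabled (the connection URL and schema name are illustrative only):

from sqlalchemy import create_engine

# Table objects compiled with schema=None carry the __[SCHEMA__none]
# placeholder, which is rewritten to "tenant_1" at execution time
engine = create_engine(
    "postgresql://scott:tiger@localhost/test",
    execution_options={"schema_translate_map": {None: "tenant_1"}},
)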
+
+ def _escape_identifier(self, value):
+ """Escape an identifier.
+
+ Subclasses should override this to provide database-dependent
+ escaping behavior.
+ """
+
+ value = value.replace(self.escape_quote, self.escape_to_quote)
+ if self._double_percents:
+ value = value.replace("%", "%%")
+ return value
+
+ def _unescape_identifier(self, value):
+ """Canonicalize an escaped identifier.
+
+ Subclasses should override this to provide database-dependent
+ unescaping behavior that reverses _escape_identifier.
+ """
+
+ return value.replace(self.escape_to_quote, self.escape_quote)
+
+ def validate_sql_phrase(self, element, reg):
+        """Keyword sequence filter.
+
+        A filter for elements that are intended to represent keyword
+        sequences, such as "INITIALLY", "INITIALLY DEFERRED", etc.; no
+        special characters should be present.
+
+ .. versionadded:: 1.3
+
+ """
+
+ if element is not None and not reg.match(element):
+ raise exc.CompileError(
+ "Unexpected SQL phrase: %r (matching against %r)"
+ % (element, reg.pattern)
+ )
+ return element
+
+ def quote_identifier(self, value):
+ """Quote an identifier.
+
+ Subclasses should override this to provide database-dependent
+ quoting behavior.
+ """
+
+ return (
+ self.initial_quote
+ + self._escape_identifier(value)
+ + self.final_quote
+ )
+
+ def _requires_quotes(self, value):
+ """Return True if the given identifier requires quoting."""
+ lc_value = value.lower()
+ return (
+ lc_value in self.reserved_words
+ or value[0] in self.illegal_initial_characters
+ or not self.legal_characters.match(util.text_type(value))
+ or (lc_value != value)
+ )
+
+ def _requires_quotes_illegal_chars(self, value):
+ """Return True if the given identifier requires quoting, but
+ not taking case convention into account."""
+ return not self.legal_characters.match(util.text_type(value))
+
+ def quote_schema(self, schema, force=None):
+ """Conditionally quote a schema name.
+
+
+ The name is quoted if it is a reserved word, contains quote-necessary
+ characters, or is an instance of :class:`.quoted_name` which includes
+ ``quote`` set to ``True``.
+
+ Subclasses can override this to provide database-dependent
+ quoting behavior for schema names.
+
+ :param schema: string schema name
+ :param force: unused
+
+ .. deprecated:: 0.9
+
+ The :paramref:`.IdentifierPreparer.quote_schema.force`
+ parameter is deprecated and will be removed in a future
+ release. This flag has no effect on the behavior of the
+ :meth:`.IdentifierPreparer.quote` method; please refer to
+ :class:`.quoted_name`.
+
+ """
+ if force is not None:
+ # not using the util.deprecated_params() decorator in this
+ # case because of the additional function call overhead on this
+ # very performance-critical spot.
+ util.warn_deprecated(
+ "The IdentifierPreparer.quote_schema.force parameter is "
+ "deprecated and will be removed in a future release. This "
+ "flag has no effect on the behavior of the "
+ "IdentifierPreparer.quote method; please refer to "
+ "quoted_name().",
+ # deprecated 0.9. warning from 1.3
+ version="0.9",
+ )
+
+ return self.quote(schema)
+
+ def quote(self, ident, force=None):
+ """Conditionally quote an identifier.
+
+ The identifier is quoted if it is a reserved word, contains
+ quote-necessary characters, or is an instance of
+ :class:`.quoted_name` which includes ``quote`` set to ``True``.
+
+ Subclasses can override this to provide database-dependent
+ quoting behavior for identifier names.
+
+ :param ident: string identifier
+ :param force: unused
+
+ .. deprecated:: 0.9
+
+ The :paramref:`.IdentifierPreparer.quote.force`
+ parameter is deprecated and will be removed in a future
+ release. This flag has no effect on the behavior of the
+ :meth:`.IdentifierPreparer.quote` method; please refer to
+ :class:`.quoted_name`.
+
+ """
+ if force is not None:
+ # not using the util.deprecated_params() decorator in this
+ # case because of the additional function call overhead on this
+ # very performance-critical spot.
+ util.warn_deprecated(
+ "The IdentifierPreparer.quote.force parameter is "
+ "deprecated and will be removed in a future release. This "
+ "flag has no effect on the behavior of the "
+ "IdentifierPreparer.quote method; please refer to "
+ "quoted_name().",
+ # deprecated 0.9. warning from 1.3
+ version="0.9",
+ )
+
+ force = getattr(ident, "quote", None)
+
+ if force is None:
+ if ident in self._strings:
+ return self._strings[ident]
+ else:
+ if self._requires_quotes(ident):
+ self._strings[ident] = self.quote_identifier(ident)
+ else:
+ self._strings[ident] = ident
+ return self._strings[ident]
+ elif force:
+ return self.quote_identifier(ident)
+ else:
+ return ident
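A brief sketch of the conditional quoting rules above, using the SQLite preparer as an example:

from sqlalchemy.dialects import sqlite

preparer = sqlite.dialect().identifier_preparer

print(preparer.quote("order"))     # reserved word -> "order"
print(preparer.quote("UserName"))  # mixed case    -> "UserName"
print(preparer.quote("username"))  # plain         -> username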
+
+ def format_collation(self, collation_name):
+ if self.quote_case_sensitive_collations:
+ return self.quote(collation_name)
+ else:
+ return collation_name
+
+ def format_sequence(self, sequence, use_schema=True):
+ name = self.quote(sequence.name)
+
+ effective_schema = self.schema_for_object(sequence)
+
+ if (
+ not self.omit_schema
+ and use_schema
+ and effective_schema is not None
+ ):
+ name = self.quote_schema(effective_schema) + "." + name
+ return name
+
+ def format_label(self, label, name=None):
+ return self.quote(name or label.name)
+
+ def format_alias(self, alias, name=None):
+ return self.quote(name or alias.name)
+
+ def format_savepoint(self, savepoint, name=None):
+ # Running the savepoint name through quoting is unnecessary
+ # for all known dialects. This is here to support potential
+ # third party use cases
+ ident = name or savepoint.ident
+ if self._requires_quotes(ident):
+ ident = self.quote_identifier(ident)
+ return ident
+
+ @util.preload_module("sqlalchemy.sql.naming")
+ def format_constraint(self, constraint, _alembic_quote=True):
+ naming = util.preloaded.sql_naming
+
+ if constraint.name is elements._NONE_NAME:
+ name = naming._constraint_name_for_table(
+ constraint, constraint.table
+ )
+
+ if name is None:
+ return None
+ else:
+ name = constraint.name
+
+ if constraint.__visit_name__ == "index":
+ return self.truncate_and_render_index_name(
+ name, _alembic_quote=_alembic_quote
+ )
+ else:
+ return self.truncate_and_render_constraint_name(
+ name, _alembic_quote=_alembic_quote
+ )
+
+ def truncate_and_render_index_name(self, name, _alembic_quote=True):
+ # calculate these at format time so that ad-hoc changes
+ # to dialect.max_identifier_length etc. can be reflected
+ # as IdentifierPreparer is long lived
+ max_ = (
+ self.dialect.max_index_name_length
+ or self.dialect.max_identifier_length
+ )
+ return self._truncate_and_render_maxlen_name(
+ name, max_, _alembic_quote
+ )
+
+ def truncate_and_render_constraint_name(self, name, _alembic_quote=True):
+ # calculate these at format time so that ad-hoc changes
+ # to dialect.max_identifier_length etc. can be reflected
+ # as IdentifierPreparer is long lived
+ max_ = (
+ self.dialect.max_constraint_name_length
+ or self.dialect.max_identifier_length
+ )
+ return self._truncate_and_render_maxlen_name(
+ name, max_, _alembic_quote
+ )
+
+ def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote):
+ if isinstance(name, elements._truncated_label):
+ if len(name) > max_:
+ name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
+ else:
+ self.dialect.validate_identifier(name)
+
+ if not _alembic_quote:
+ return name
+ else:
+ return self.quote(name)
+
+ def format_index(self, index):
+ return self.format_constraint(index)
+
+ def format_table(self, table, use_schema=True, name=None):
+ """Prepare a quoted table and schema name."""
+
+ if name is None:
+ name = table.name
+
+ result = self.quote(name)
+
+ effective_schema = self.schema_for_object(table)
+
+ if not self.omit_schema and use_schema and effective_schema:
+ result = self.quote_schema(effective_schema) + "." + result
+ return result
+
+ def format_schema(self, name):
+ """Prepare a quoted schema name."""
+
+ return self.quote(name)
+
+ def format_label_name(
+ self,
+ name,
+ anon_map=None,
+ ):
+ """Prepare a quoted column name."""
+
+ if anon_map is not None and isinstance(
+ name, elements._truncated_label
+ ):
+ name = name.apply_map(anon_map)
+
+ return self.quote(name)
+
+ def format_column(
+ self,
+ column,
+ use_table=False,
+ name=None,
+ table_name=None,
+ use_schema=False,
+ anon_map=None,
+ ):
+ """Prepare a quoted column name."""
+
+ if name is None:
+ name = column.name
+
+ if anon_map is not None and isinstance(
+ name, elements._truncated_label
+ ):
+ name = name.apply_map(anon_map)
+
+ if not getattr(column, "is_literal", False):
+ if use_table:
+ return (
+ self.format_table(
+ column.table, use_schema=use_schema, name=table_name
+ )
+ + "."
+ + self.quote(name)
+ )
+ else:
+ return self.quote(name)
+ else:
+ # literal textual elements get stuck into ColumnClause a lot,
+ # which shouldn't get quoted
+
+ if use_table:
+ return (
+ self.format_table(
+ column.table, use_schema=use_schema, name=table_name
+ )
+ + "."
+ + name
+ )
+ else:
+ return name
+
+ def format_table_seq(self, table, use_schema=True):
+ """Format table name and schema as a tuple."""
+
+ # Dialects with more levels in their fully qualified references
+ # ('database', 'owner', etc.) could override this and return
+ # a longer sequence.
+
+ effective_schema = self.schema_for_object(table)
+
+ if not self.omit_schema and use_schema and effective_schema:
+ return (
+ self.quote_schema(effective_schema),
+ self.format_table(table, use_schema=False),
+ )
+ else:
+ return (self.format_table(table, use_schema=False),)
+
+ @util.memoized_property
+ def _r_identifiers(self):
+ initial, final, escaped_final = [
+ re.escape(s)
+ for s in (
+ self.initial_quote,
+ self.final_quote,
+ self._escape_identifier(self.final_quote),
+ )
+ ]
+ r = re.compile(
+ r"(?:"
+ r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s"
+ r"|([^\.]+))(?=\.|$))+"
+ % {"initial": initial, "final": final, "escaped": escaped_final}
+ )
+ return r
+
+ def unformat_identifiers(self, identifiers):
+ """Unpack 'schema.table.column'-like strings into components."""
+
+ r = self._r_identifiers
+ return [
+ self._unescape_identifier(i)
+ for i in [a or b for a, b in r.findall(identifiers)]
+ ]
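A short usage sketch of the identifier unpacking above (quoted parts may themselves contain dots):

from sqlalchemy.dialects import sqlite

preparer = sqlite.dialect().identifier_preparer

# -> ['myschema', 'My.Table', 'id']
print(preparer.unformat_identifiers('myschema."My.Table".id'))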
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
new file mode 100644
index 0000000..920c8b3
--- /dev/null
+++ b/lib/sqlalchemy/sql/crud.py
@@ -0,0 +1,1091 @@
+# sql/crud.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Functions used by compiler.py to determine the parameters rendered
+within INSERT and UPDATE statements.
+
+"""
+import functools
+import operator
+
+from . import coercions
+from . import dml
+from . import elements
+from . import roles
+from .selectable import Select
+from .. import exc
+from .. import util
+
+REQUIRED = util.symbol(
+ "REQUIRED",
+ """
+Placeholder for the value within a :class:`.BindParameter`
+which is required to be present when the statement is passed
+to :meth:`_engine.Connection.execute`.
+
+This symbol is typically used when a :func:`_expression.insert`
+or :func:`_expression.update` statement is compiled without parameter
+values present.
+
+""",
+)
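+
+# A minimal sketch (assuming a Table "t" with columns "id" and "data"):
+# when an insert() or update() is compiled without values, its bound
+# parameters are created as "required", so the values must then be
+# supplied at execution time:
+#
+#   stmt = t.insert()
+#   connection.execute(stmt, {"id": 1, "data": "x"})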
+
+
+def _get_crud_params(compiler, stmt, compile_state, **kw):
+ """create a set of tuples representing column/string pairs for use
+ in an INSERT or UPDATE statement.
+
+ Also generates the Compiled object's postfetch, prefetch, and
+ returning column collections, used for default handling and ultimately
+ populating the CursorResult's prefetch_cols() and postfetch_cols()
+ collections.
+
+ """
+
+ compiler.postfetch = []
+ compiler.insert_prefetch = []
+ compiler.update_prefetch = []
+ compiler.returning = []
+
+ # getters - these are normally just column.key,
+ # but in the case of mysql multi-table update, the rules for
+ # .key must conditionally take tablename into account
+ (
+ _column_as_key,
+ _getattr_col_key,
+ _col_bind_name,
+ ) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state)
+
+ compiler._key_getters_for_crud_column = getters
+
+ # no parameters in the statement, no parameters in the
+ # compiled params - return binds for all columns
+ if compiler.column_keys is None and compile_state._no_parameters:
+ return [
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_bind_param(compiler, c, None, required=True),
+ )
+ for c in stmt.table.columns
+ ]
+
+ if compile_state._has_multi_parameters:
+ spd = compile_state._multi_parameters[0]
+ stmt_parameter_tuples = list(spd.items())
+ elif compile_state._ordered_values:
+ spd = compile_state._dict_parameters
+ stmt_parameter_tuples = compile_state._ordered_values
+ elif compile_state._dict_parameters:
+ spd = compile_state._dict_parameters
+ stmt_parameter_tuples = list(spd.items())
+ else:
+ stmt_parameter_tuples = spd = None
+
+ # if we have statement parameters - set defaults in the
+ # compiled params
+ if compiler.column_keys is None:
+ parameters = {}
+ elif stmt_parameter_tuples:
+ parameters = dict(
+ (_column_as_key(key), REQUIRED)
+ for key in compiler.column_keys
+ if key not in spd
+ )
+ else:
+ parameters = dict(
+ (_column_as_key(key), REQUIRED) for key in compiler.column_keys
+ )
+
+ # create a list of column assignment clauses as tuples
+ values = []
+
+ if stmt_parameter_tuples is not None:
+ _get_stmt_parameter_tuples_params(
+ compiler,
+ compile_state,
+ parameters,
+ stmt_parameter_tuples,
+ _column_as_key,
+ values,
+ kw,
+ )
+
+ check_columns = {}
+
+ # special logic that only occurs for multi-table UPDATE
+ # statements
+ if compile_state.isupdate and compile_state.is_multitable:
+ _get_update_multitable_params(
+ compiler,
+ stmt,
+ compile_state,
+ stmt_parameter_tuples,
+ check_columns,
+ _col_bind_name,
+ _getattr_col_key,
+ values,
+ kw,
+ )
+
+ if compile_state.isinsert and stmt._select_names:
+ _scan_insert_from_select_cols(
+ compiler,
+ stmt,
+ compile_state,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+ )
+ else:
+ _scan_cols(
+ compiler,
+ stmt,
+ compile_state,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+ )
+
+ if parameters and stmt_parameter_tuples:
+ check = (
+ set(parameters)
+ .intersection(_column_as_key(k) for k, v in stmt_parameter_tuples)
+ .difference(check_columns)
+ )
+ if check:
+ raise exc.CompileError(
+ "Unconsumed column names: %s"
+ % (", ".join("%s" % (c,) for c in check))
+ )
+
+ if compile_state._has_multi_parameters:
+ values = _extend_values_for_multiparams(
+ compiler,
+ stmt,
+ compile_state,
+ values,
+ _column_as_key,
+ kw,
+ )
+ elif (
+ not values
+ and compiler.for_executemany
+ and compiler.dialect.supports_default_metavalue
+ ):
+ # convert an "INSERT DEFAULT VALUES"
+ # into INSERT (firstcol) VALUES (DEFAULT) which can be turned
+ # into an in-place multi values. This supports
+ # insert_executemany_returning mode :)
+ values = [
+ (
+ stmt.table.columns[0],
+ compiler.preparer.format_column(stmt.table.columns[0]),
+ "DEFAULT",
+ )
+ ]
+
+ return values
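+
+# The list returned by _get_crud_params() is what compiler.py renders into
+# the VALUES / SET clause; each entry is a tuple of (Column, rendered column
+# name, rendered value or bind).  A rough sketch, assuming a Table "t" with
+# columns "id" and "data" and a named paramstyle:
+#
+#   [
+#       (t.c.id, "id", ":id"),
+#       (t.c.data, "data", ":data"),
+#   ]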
+
+
+def _create_bind_param(
+ compiler, col, value, process=True, required=False, name=None, **kw
+):
+ if name is None:
+ name = col.key
+ bindparam = elements.BindParameter(
+ name, value, type_=col.type, required=required
+ )
+ bindparam._is_crud = True
+ if process:
+ bindparam = bindparam._compiler_dispatch(compiler, **kw)
+ return bindparam
+
+
+def _handle_values_anonymous_param(compiler, col, value, name, **kw):
+ # the insert() and update() constructs as of 1.4 will now produce anonymous
+ # bindparam() objects in the values() collections up front when given plain
+ # literal values. This is so that cache key behaviors, which need to
+ # produce bound parameters in deterministic order without invoking any
+ # compilation here, can be applied to these constructs when they include
+ # values() (but not yet multi-values, which are not included in caching
+ # right now).
+ #
+ # in order to produce the desired "crud" style name for these parameters,
+ # which will also be targetable in engine/default.py through the usual
+ # conventions, apply our desired name to these unique parameters by
+ # populating the compiler truncated names cache with the desired name,
+ # rather than having
+ # compiler.visit_bindparam()->compiler._truncated_identifier make up a
+ # name. Saves on call counts also.
+
+ # for INSERT/UPDATE that's a CTE, we don't need names to match to
+ # external parameters and these would also conflict in the case where
+ # multiple insert/update are combined together using CTEs
+ is_cte = "visiting_cte" in kw
+
+ if (
+ not is_cte
+ and value.unique
+ and isinstance(value.key, elements._truncated_label)
+ ):
+ compiler.truncated_names[("bindparam", value.key)] = name
+
+ if value.type._isnull:
+ # either unique parameter, or other bound parameters that were
+ # passed in directly
+ # set type to that of the column unconditionally
+ value = value._with_binary_element_type(col.type)
+
+ return value._compiler_dispatch(compiler, **kw)
+
+
+def _key_getters_for_crud_column(compiler, stmt, compile_state):
+ if compile_state.isupdate and compile_state._extra_froms:
+ # when extra tables are present, refer to the columns
+ # in those extra tables as table-qualified, including in
+ # dictionaries and when rendering bind param names.
+ # the "main" table of the statement remains unqualified,
+ # allowing the most compatibility with a non-multi-table
+ # statement.
+ _et = set(compile_state._extra_froms)
+
+ c_key_role = functools.partial(
+ coercions.expect_as_key, roles.DMLColumnRole
+ )
+
+ def _column_as_key(key):
+ str_key = c_key_role(key)
+ if hasattr(key, "table") and key.table in _et:
+ return (key.table.name, str_key)
+ else:
+ return str_key
+
+ def _getattr_col_key(col):
+ if col.table in _et:
+ return (col.table.name, col.key)
+ else:
+ return col.key
+
+ def _col_bind_name(col):
+ if col.table in _et:
+ return "%s_%s" % (col.table.name, col.key)
+ else:
+ return col.key
+
+ else:
+ _column_as_key = functools.partial(
+ coercions.expect_as_key, roles.DMLColumnRole
+ )
+ _getattr_col_key = _col_bind_name = operator.attrgetter("key")
+
+ return _column_as_key, _getattr_col_key, _col_bind_name
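+
+# A rough sketch of the multi-table UPDATE naming established above: for an
+# UPDATE whose VALUES also assign to columns of an extra table "other",
+# those columns are keyed and bound with the table name prefixed, e.g.
+#
+#   _column_as_key(other.c.x)    # -> ("other", "x")
+#   _getattr_col_key(other.c.x)  # -> ("other", "x")
+#   _col_bind_name(other.c.x)    # -> "other_x"
+#
+# while columns of the statement's own table keep their plain .key.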
+
+
+def _scan_insert_from_select_cols(
+ compiler,
+ stmt,
+ compile_state,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+):
+
+ (
+ need_pks,
+ implicit_returning,
+ implicit_return_defaults,
+ postfetch_lastrowid,
+ ) = _get_returning_modifiers(compiler, stmt, compile_state)
+
+ cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names]
+
+ assert compiler.stack[-1]["selectable"] is stmt
+
+ compiler.stack[-1]["insert_from_select"] = stmt.select
+
+ add_select_cols = []
+ if stmt.include_insert_from_select_defaults:
+ col_set = set(cols)
+ for col in stmt.table.columns:
+ if col not in col_set and col.default:
+ cols.append(col)
+
+ for c in cols:
+ col_key = _getattr_col_key(c)
+ if col_key in parameters and col_key not in check_columns:
+ parameters.pop(col_key)
+ values.append((c, compiler.preparer.format_column(c), None))
+ else:
+ _append_param_insert_select_hasdefault(
+ compiler, stmt, c, add_select_cols, kw
+ )
+
+ if add_select_cols:
+ values.extend(add_select_cols)
+ ins_from_select = compiler.stack[-1]["insert_from_select"]
+ if not isinstance(ins_from_select, Select):
+ raise exc.CompileError(
+ "Can't extend statement for INSERT..FROM SELECT to include "
+ "additional default-holding column(s) "
+ "%s. Convert the selectable to a subquery() first, or pass "
+ "include_defaults=False to Insert.from_select() to skip these "
+ "columns."
+ % (", ".join(repr(key) for _, key, _ in add_select_cols),)
+ )
+ ins_from_select = ins_from_select._generate()
+ # copy raw_columns
+ ins_from_select._raw_columns = list(ins_from_select._raw_columns) + [
+ expr for col, col_expr, expr in add_select_cols
+ ]
+ compiler.stack[-1]["insert_from_select"] = ins_from_select
+
+
+def _scan_cols(
+ compiler,
+ stmt,
+ compile_state,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+):
+ (
+ need_pks,
+ implicit_returning,
+ implicit_return_defaults,
+ postfetch_lastrowid,
+ ) = _get_returning_modifiers(compiler, stmt, compile_state)
+
+ if compile_state._parameter_ordering:
+ parameter_ordering = [
+ _column_as_key(key) for key in compile_state._parameter_ordering
+ ]
+ ordered_keys = set(parameter_ordering)
+ cols = [
+ stmt.table.c[key]
+ for key in parameter_ordering
+ if isinstance(key, util.string_types) and key in stmt.table.c
+ ] + [c for c in stmt.table.c if c.key not in ordered_keys]
+
+ else:
+ cols = stmt.table.columns
+
+ for c in cols:
+ # scan through every column in the target table
+
+ col_key = _getattr_col_key(c)
+
+ if col_key in parameters and col_key not in check_columns:
+ # parameter is present for the column. use that.
+
+ _append_param_parameter(
+ compiler,
+ stmt,
+ compile_state,
+ c,
+ col_key,
+ parameters,
+ _col_bind_name,
+ implicit_returning,
+ implicit_return_defaults,
+ values,
+ kw,
+ )
+
+ elif compile_state.isinsert:
+ # no parameter is present and it's an insert.
+
+ if c.primary_key and need_pks:
+ # it's a primary key column, it will need to be generated by a
+ # default generator of some kind, and the statement expects
+ # inserted_primary_key to be available.
+
+ if implicit_returning:
+ # we can use RETURNING, find out how to invoke this
+ # column and get the value where RETURNING is an option.
+ # we can inline server-side functions in this case.
+
+ _append_param_insert_pk_returning(
+ compiler, stmt, c, values, kw
+ )
+ else:
+ # otherwise, find out how to invoke this column
+ # and get its value where RETURNING is not an option.
+ # if we have to invoke a server-side function, we need
+ # to pre-execute it. or if this is a straight
+ # autoincrement column and the dialect supports it
+ # we can use cursor.lastrowid.
+
+ _append_param_insert_pk_no_returning(
+ compiler, stmt, c, values, kw
+ )
+
+ elif c.default is not None:
+ # column has a default, but it's not a pk column, or it is but
+ # we don't need to get the pk back.
+ _append_param_insert_hasdefault(
+ compiler, stmt, c, implicit_return_defaults, values, kw
+ )
+
+ elif c.server_default is not None:
+ # column has a DDL-level default, and is either not a pk
+ # column or we don't need the pk.
+ if implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+ elif not c.primary_key:
+ compiler.postfetch.append(c)
+ elif implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+ elif (
+ c.primary_key
+ and c is not stmt.table._autoincrement_column
+ and not c.nullable
+ ):
+ _warn_pk_with_no_anticipated_value(c)
+
+ elif compile_state.isupdate:
+            # no parameter is present and it's an update.
+
+ _append_param_update(
+ compiler,
+ compile_state,
+ stmt,
+ c,
+ implicit_return_defaults,
+ values,
+ kw,
+ )
+
+
+def _append_param_parameter(
+ compiler,
+ stmt,
+ compile_state,
+ c,
+ col_key,
+ parameters,
+ _col_bind_name,
+ implicit_returning,
+ implicit_return_defaults,
+ values,
+ kw,
+):
+ value = parameters.pop(col_key)
+
+ col_value = compiler.preparer.format_column(
+ c, use_table=compile_state.include_table_with_column_exprs
+ )
+
+ if coercions._is_literal(value):
+ value = _create_bind_param(
+ compiler,
+ c,
+ value,
+ required=value is REQUIRED,
+ name=_col_bind_name(c)
+ if not compile_state._has_multi_parameters
+ else "%s_m0" % _col_bind_name(c),
+ **kw
+ )
+ elif value._is_bind_parameter:
+ value = _handle_values_anonymous_param(
+ compiler,
+ c,
+ value,
+ name=_col_bind_name(c)
+ if not compile_state._has_multi_parameters
+ else "%s_m0" % _col_bind_name(c),
+ **kw
+ )
+ else:
+ # value is a SQL expression
+ value = compiler.process(value.self_group(), **kw)
+
+ if compile_state.isupdate:
+ if implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+
+ else:
+ compiler.postfetch.append(c)
+ else:
+ if c.primary_key:
+
+ if implicit_returning:
+ compiler.returning.append(c)
+ elif compiler.dialect.postfetch_lastrowid:
+ compiler.postfetch_lastrowid = True
+
+ elif implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+
+ else:
+ # postfetch specifically means, "we can SELECT the row we just
+ # inserted by primary key to get back the server generated
+ # defaults". so by definition this can't be used to get the
+ # primary key value back, because we need to have it ahead of
+ # time.
+
+ compiler.postfetch.append(c)
+
+ values.append((c, col_value, value))
+
+
+def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
+ """Create a primary key expression in the INSERT statement where
+ we want to populate result.inserted_primary_key and RETURNING
+ is available.
+
+ """
+ if c.default is not None:
+ if c.default.is_sequence:
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional
+ or not compiler.dialect.sequences_optional
+ ):
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ compiler.process(c.default, **kw),
+ )
+ )
+ compiler.returning.append(c)
+ elif c.default.is_clause_element:
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ compiler.process(c.default.arg.self_group(), **kw),
+ )
+ )
+ compiler.returning.append(c)
+ else:
+ # client side default. OK we can't use RETURNING, need to
+ # do a "prefetch", which in fact fetches the default value
+ # on the Python side
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_insert_prefetch_bind_param(compiler, c, **kw),
+ )
+ )
+ elif c is stmt.table._autoincrement_column or c.server_default is not None:
+ compiler.returning.append(c)
+ elif not c.nullable:
+ # no .default, no .server_default, not autoincrement, we have
+ # no indication this primary key column will have any value
+ _warn_pk_with_no_anticipated_value(c)
+
+
+def _append_param_insert_pk_no_returning(compiler, stmt, c, values, kw):
+ """Create a primary key expression in the INSERT statement where
+ we want to populate result.inserted_primary_key and we cannot use
+ RETURNING.
+
+ Depending on the kind of default here we may create a bound parameter
+ in the INSERT statement and pre-execute a default generation function,
+ or we may use cursor.lastrowid if supported by the dialect.
+
+
+ """
+
+ if (
+ # column has a Python-side default
+ c.default is not None
+ and (
+ # and it either is not a sequence, or it is and we support
+ # sequences and want to invoke it
+ not c.default.is_sequence
+ or (
+ compiler.dialect.supports_sequences
+ and (
+ not c.default.optional
+ or not compiler.dialect.sequences_optional
+ )
+ )
+ )
+ ) or (
+ # column is the "autoincrement column"
+ c is stmt.table._autoincrement_column
+ and (
+ # dialect can't use cursor.lastrowid
+ not compiler.dialect.postfetch_lastrowid
+ and (
+ # column has a Sequence and we support those
+ (
+ c.default is not None
+ and c.default.is_sequence
+ and compiler.dialect.supports_sequences
+ )
+ or
+ # column has no default on it, but dialect can run the
+ # "autoincrement" mechanism explicitly, e.g. PostgreSQL
+                # SERIAL, where we know the sequence name
+ (
+ c.default is None
+ and compiler.dialect.preexecute_autoincrement_sequences
+ )
+ )
+ )
+ ):
+ # do a pre-execute of the default
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_insert_prefetch_bind_param(compiler, c, **kw),
+ )
+ )
+ elif (
+ c.default is None
+ and c.server_default is None
+ and not c.nullable
+ and c is not stmt.table._autoincrement_column
+ ):
+ # no .default, no .server_default, not autoincrement, we have
+ # no indication this primary key column will have any value
+ _warn_pk_with_no_anticipated_value(c)
+ elif compiler.dialect.postfetch_lastrowid:
+ # finally, where it seems like there will be a generated primary key
+ # value and we haven't set up any other way to fetch it, and the
+ # dialect supports cursor.lastrowid, switch on the lastrowid flag so
+ # that the DefaultExecutionContext calls upon cursor.lastrowid
+ compiler.postfetch_lastrowid = True
+
+
+def _append_param_insert_hasdefault(
+ compiler, stmt, c, implicit_return_defaults, values, kw
+):
+ if c.default.is_sequence:
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional or not compiler.dialect.sequences_optional
+ ):
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ compiler.process(c.default, **kw),
+ )
+ )
+ if implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+ elif not c.primary_key:
+ compiler.postfetch.append(c)
+ elif c.default.is_clause_element:
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ compiler.process(c.default.arg.self_group(), **kw),
+ )
+ )
+
+ if implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+ elif not c.primary_key:
+ # don't add primary key column to postfetch
+ compiler.postfetch.append(c)
+ else:
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_insert_prefetch_bind_param(compiler, c, **kw),
+ )
+ )
+
+
+def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
+
+ if c.default.is_sequence:
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional or not compiler.dialect.sequences_optional
+ ):
+ values.append(
+ (c, compiler.preparer.format_column(c), c.default.next_value())
+ )
+ elif c.default.is_clause_element:
+ values.append(
+ (c, compiler.preparer.format_column(c), c.default.arg.self_group())
+ )
+ else:
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_insert_prefetch_bind_param(
+ compiler, c, process=False, **kw
+ ),
+ )
+ )
+
+
+def _append_param_update(
+ compiler, compile_state, stmt, c, implicit_return_defaults, values, kw
+):
+
+ include_table = compile_state.include_table_with_column_exprs
+ if c.onupdate is not None and not c.onupdate.is_sequence:
+ if c.onupdate.is_clause_element:
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(
+ c,
+ use_table=include_table,
+ ),
+ compiler.process(c.onupdate.arg.self_group(), **kw),
+ )
+ )
+ if implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+ else:
+ compiler.postfetch.append(c)
+ else:
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(
+ c,
+ use_table=include_table,
+ ),
+ _create_update_prefetch_bind_param(compiler, c, **kw),
+ )
+ )
+ elif c.server_onupdate is not None:
+ if implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+ else:
+ compiler.postfetch.append(c)
+ elif (
+ implicit_return_defaults
+ and (stmt._return_defaults_columns or not stmt._return_defaults)
+ and c in implicit_return_defaults
+ ):
+ compiler.returning.append(c)
+
+
+def _create_insert_prefetch_bind_param(
+ compiler, c, process=True, name=None, **kw
+):
+
+ param = _create_bind_param(
+ compiler, c, None, process=process, name=name, **kw
+ )
+ compiler.insert_prefetch.append(c)
+ return param
+
+
+def _create_update_prefetch_bind_param(
+ compiler, c, process=True, name=None, **kw
+):
+ param = _create_bind_param(
+ compiler, c, None, process=process, name=name, **kw
+ )
+ compiler.update_prefetch.append(c)
+ return param
+
+
+class _multiparam_column(elements.ColumnElement):
+ _is_multiparam_column = True
+
+ def __init__(self, original, index):
+ self.index = index
+ self.key = "%s_m%d" % (original.key, index + 1)
+ self.original = original
+ self.default = original.default
+ self.type = original.type
+
+ def compare(self, other, **kw):
+ raise NotImplementedError()
+
+ def _copy_internals(self, other, **kw):
+ raise NotImplementedError()
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, _multiparam_column)
+ and other.key == self.key
+ and other.original == self.original
+ )
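+
+# A sketch of the bind naming used for multi-VALUES constructs (assuming a
+# column named "data"): the first VALUES row binds as "data_m0" (see
+# _append_param_parameter above); _multiparam_column and
+# _extend_values_for_multiparams below name the remaining rows "data_m1",
+# "data_m2", and so on.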
+
+
+def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
+ if not c.default:
+ raise exc.CompileError(
+ "INSERT value for column %s is explicitly rendered as a bound"
+ "parameter in the VALUES clause; "
+ "a Python-side value or SQL expression is required" % c
+ )
+ elif c.default.is_clause_element:
+ return compiler.process(c.default.arg.self_group(), **kw)
+ elif c.default.is_sequence:
+ # these conditions would have been established
+ # by append_param_insert_(?:hasdefault|pk_returning|pk_no_returning)
+ # in order for us to be here, so these don't need to be
+ # checked
+ # assert compiler.dialect.supports_sequences and (
+ # not c.default.optional
+ # or not compiler.dialect.sequences_optional
+ # )
+ return compiler.process(c.default, **kw)
+ else:
+ col = _multiparam_column(c, index)
+ if isinstance(stmt, dml.Insert):
+ return _create_insert_prefetch_bind_param(compiler, col, **kw)
+ else:
+ return _create_update_prefetch_bind_param(compiler, col, **kw)
+
+
+def _get_update_multitable_params(
+ compiler,
+ stmt,
+ compile_state,
+ stmt_parameter_tuples,
+ check_columns,
+ _col_bind_name,
+ _getattr_col_key,
+ values,
+ kw,
+):
+ normalized_params = dict(
+ (coercions.expect(roles.DMLColumnRole, c), param)
+ for c, param in stmt_parameter_tuples
+ )
+
+ include_table = compile_state.include_table_with_column_exprs
+
+ affected_tables = set()
+ for t in compile_state._extra_froms:
+ for c in t.c:
+ if c in normalized_params:
+ affected_tables.add(t)
+ check_columns[_getattr_col_key(c)] = c
+ value = normalized_params[c]
+
+ col_value = compiler.process(c, include_table=include_table)
+ if coercions._is_literal(value):
+ value = _create_bind_param(
+ compiler,
+ c,
+ value,
+ required=value is REQUIRED,
+ name=_col_bind_name(c),
+ **kw # TODO: no test coverage for literal binds here
+ )
+ elif value._is_bind_parameter:
+ value = _handle_values_anonymous_param(
+ compiler, c, value, name=_col_bind_name(c), **kw
+ )
+ else:
+ compiler.postfetch.append(c)
+ value = compiler.process(value.self_group(), **kw)
+ values.append((c, col_value, value))
+ # determine tables which are actually to be updated - process onupdate
+ # and server_onupdate for these
+ for t in affected_tables:
+ for c in t.c:
+ if c in normalized_params:
+ continue
+ elif c.onupdate is not None and not c.onupdate.is_sequence:
+ if c.onupdate.is_clause_element:
+ values.append(
+ (
+ c,
+ compiler.process(c, include_table=include_table),
+ compiler.process(
+ c.onupdate.arg.self_group(), **kw
+ ),
+ )
+ )
+ compiler.postfetch.append(c)
+ else:
+ values.append(
+ (
+ c,
+ compiler.process(c, include_table=include_table),
+ _create_update_prefetch_bind_param(
+ compiler, c, name=_col_bind_name(c), **kw
+ ),
+ )
+ )
+ elif c.server_onupdate is not None:
+ compiler.postfetch.append(c)
+
+
+def _extend_values_for_multiparams(
+ compiler,
+ stmt,
+ compile_state,
+ values,
+ _column_as_key,
+ kw,
+):
+ values_0 = values
+ values = [values]
+
+ for i, row in enumerate(compile_state._multi_parameters[1:]):
+ extension = []
+
+ row = {_column_as_key(key): v for key, v in row.items()}
+
+ for (col, col_expr, param) in values_0:
+ if col.key in row:
+ key = col.key
+
+ if coercions._is_literal(row[key]):
+ new_param = _create_bind_param(
+ compiler,
+ col,
+ row[key],
+ name="%s_m%d" % (col.key, i + 1),
+ **kw
+ )
+ else:
+ new_param = compiler.process(row[key].self_group(), **kw)
+ else:
+ new_param = _process_multiparam_default_bind(
+ compiler, stmt, col, i, kw
+ )
+
+ extension.append((col, col_expr, new_param))
+
+ values.append(extension)
+
+ return values
+
+
+def _get_stmt_parameter_tuples_params(
+ compiler,
+ compile_state,
+ parameters,
+ stmt_parameter_tuples,
+ _column_as_key,
+ values,
+ kw,
+):
+
+ for k, v in stmt_parameter_tuples:
+ colkey = _column_as_key(k)
+ if colkey is not None:
+ parameters.setdefault(colkey, v)
+ else:
+ # a non-Column expression on the left side;
+ # add it to values() in an "as-is" state,
+ # coercing right side to bound param
+
+ # note one of the main use cases for this is array slice
+ # updates on PostgreSQL, as the left side is also an expression.
+
+ col_expr = compiler.process(
+ k, include_table=compile_state.include_table_with_column_exprs
+ )
+
+ if coercions._is_literal(v):
+ v = compiler.process(
+ elements.BindParameter(None, v, type_=k.type), **kw
+ )
+ else:
+ if v._is_bind_parameter and v.type._isnull:
+ # either unique parameter, or other bound parameters that
+ # were passed in directly
+ # set type to that of the column unconditionally
+ v = v._with_binary_element_type(k.type)
+
+ v = compiler.process(v.self_group(), **kw)
+
+ values.append((k, col_expr, v))
+
+
+def _get_returning_modifiers(compiler, stmt, compile_state):
+
+ need_pks = (
+ compile_state.isinsert
+ and not stmt._inline
+ and (
+ not compiler.for_executemany
+ or (
+ compiler.dialect.insert_executemany_returning
+ and stmt._return_defaults
+ )
+ )
+ and not stmt._returning
+ and not compile_state._has_multi_parameters
+ )
+
+ implicit_returning = (
+ need_pks
+ and compiler.dialect.implicit_returning
+ and stmt.table.implicit_returning
+ )
+
+ if compile_state.isinsert:
+ implicit_return_defaults = implicit_returning and stmt._return_defaults
+ elif compile_state.isupdate:
+ implicit_return_defaults = (
+ compiler.dialect.implicit_returning
+ and stmt.table.implicit_returning
+ and stmt._return_defaults
+ )
+ else:
+ # this line is unused, currently we are always
+ # isinsert or isupdate
+ implicit_return_defaults = False # pragma: no cover
+
+ if implicit_return_defaults:
+ if not stmt._return_defaults_columns:
+ implicit_return_defaults = set(stmt.table.c)
+ else:
+ implicit_return_defaults = set(stmt._return_defaults_columns)
+
+ postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid
+
+ return (
+ need_pks,
+ implicit_returning,
+ implicit_return_defaults,
+ postfetch_lastrowid,
+ )
+
+
+def _warn_pk_with_no_anticipated_value(c):
+ msg = (
+ "Column '%s.%s' is marked as a member of the "
+ "primary key for table '%s', "
+ "but has no Python-side or server-side default generator indicated, "
+ "nor does it indicate 'autoincrement=True' or 'nullable=True', "
+ "and no explicit value is passed. "
+ "Primary key columns typically may not store NULL."
+ % (c.table.fullname, c.name, c.table.fullname)
+ )
+ if len(c.table.primary_key) > 1:
+ msg += (
+ " Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be "
+ "indicated explicitly for composite (e.g. multicolumn) primary "
+ "keys if AUTO_INCREMENT/SERIAL/IDENTITY "
+ "behavior is expected for one of the columns in the primary key. "
+ "CREATE TABLE statements are impacted by this change as well on "
+ "most backends."
+ )
+ util.warn(msg)
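+
+# A sketch of an INSERT that would trigger the warning above (hypothetical
+# table; composite primary keys receive no autoincrement behavior unless it
+# is requested explicitly):
+#
+#   t = Table(
+#       "pairs",
+#       metadata,
+#       Column("a", Integer, primary_key=True),
+#       Column("b", Integer, primary_key=True),
+#   )
+#   connection.execute(t.insert(), {"a": 1})   # no value for "b" -> warning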
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py
new file mode 100644
index 0000000..e608052
--- /dev/null
+++ b/lib/sqlalchemy/sql/ddl.py
@@ -0,0 +1,1341 @@
+# sql/ddl.py
+# Copyright (C) 2009-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""
+Provides the hierarchy of DDL-defining schema items as well as routines
+to invoke them for a create/drop call.
+
+"""
+
+from . import roles
+from .base import _bind_or_error
+from .base import _generative
+from .base import Executable
+from .base import SchemaVisitor
+from .elements import ClauseElement
+from .. import exc
+from .. import util
+from ..util import topological
+
+
+class _DDLCompiles(ClauseElement):
+ _hierarchy_supports_caching = False
+ """disable cache warnings for all _DDLCompiles subclasses. """
+
+ def _compiler(self, dialect, **kw):
+ """Return a compiler appropriate for this ClauseElement, given a
+ Dialect."""
+
+ return dialect.ddl_compiler(dialect, self, **kw)
+
+ def _compile_w_cache(self, *arg, **kw):
+ raise NotImplementedError()
+
+
+class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
+ """Base class for DDL expression constructs.
+
+ This class is the base for the general purpose :class:`.DDL` class,
+ as well as the various create/drop clause constructs such as
+ :class:`.CreateTable`, :class:`.DropTable`, :class:`.AddConstraint`,
+ etc.
+
+ :class:`.DDLElement` integrates closely with SQLAlchemy events,
+ introduced in :ref:`event_toplevel`. An instance of one is
+ itself an event receiving callable::
+
+ event.listen(
+ users,
+ 'after_create',
+ AddConstraint(constraint).execute_if(dialect='postgresql')
+ )
+
+ .. seealso::
+
+ :class:`.DDL`
+
+ :class:`.DDLEvents`
+
+ :ref:`event_toplevel`
+
+ :ref:`schema_ddl_sequences`
+
+ """
+
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": True}
+ )
+
+ target = None
+ on = None
+ dialect = None
+ callable_ = None
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ return connection._execute_ddl(
+ self, multiparams, params, execution_options
+ )
+
+ @util.deprecated_20(
+ ":meth:`.DDLElement.execute`",
+ alternative="All statement execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.execute` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.execute` method of "
+ ":class:`.Session`.",
+ )
+ def execute(self, bind=None, target=None):
+ """Execute this DDL immediately.
+
+        Executes the DDL statement in isolation using the supplied
+        :class:`.Connectable`, or the :class:`.Connectable` assigned to
+        the ``.bind`` property if none is supplied. If the DDL has a
+        conditional ``on`` criteria, it will be invoked with None as the
+        event.
+
+ :param bind:
+ Optional, an ``Engine`` or ``Connection``. If not supplied, a valid
+ :class:`.Connectable` must be present in the
+ ``.bind`` property.
+
+ :param target:
+ Optional, defaults to None. The target :class:`_schema.SchemaItem`
+ for the execute call. This is equivalent to passing the
+ :class:`_schema.SchemaItem` to the :meth:`.DDLElement.against`
+ method and then invoking :meth:`_schema.DDLElement.execute`
+ upon the resulting :class:`_schema.DDLElement` object. See
+ :meth:`.DDLElement.against` for further detail.
+
+ """
+
+ if bind is None:
+ bind = _bind_or_error(self)
+
+ if self._should_execute(target, bind):
+ return bind.execute(self.against(target))
+ else:
+ bind.engine.logger.info("DDL execution skipped, criteria not met.")
+
+ @_generative
+ def against(self, target):
+ """Return a copy of this :class:`_schema.DDLElement` which will include
+ the given target.
+
+ This essentially applies the given item to the ``.target`` attribute
+ of the returned :class:`_schema.DDLElement` object. This target
+ is then usable by event handlers and compilation routines in order to
+ provide services such as tokenization of a DDL string in terms of a
+ particular :class:`_schema.Table`.
+
+ When a :class:`_schema.DDLElement` object is established as an event
+ handler for the :meth:`_events.DDLEvents.before_create` or
+ :meth:`_events.DDLEvents.after_create` events, and the event
+ then occurs for a given target such as a :class:`_schema.Constraint`
+ or :class:`_schema.Table`, that target is established with a copy
+ of the :class:`_schema.DDLElement` object using this method, which
+ then proceeds to the :meth:`_schema.DDLElement.execute` method
+ in order to invoke the actual DDL instruction.
+
+ :param target: a :class:`_schema.SchemaItem` that will be the subject
+ of a DDL operation.
+
+ :return: a copy of this :class:`_schema.DDLElement` with the
+ ``.target`` attribute assigned to the given
+ :class:`_schema.SchemaItem`.
+
+ .. seealso::
+
+ :class:`_schema.DDL` - uses tokenization against the "target" when
+ processing the DDL string.
+
+ """
+
+ self.target = target
+
+ @_generative
+ def execute_if(self, dialect=None, callable_=None, state=None):
+ r"""Return a callable that will execute this
+ :class:`_ddl.DDLElement` conditionally within an event handler.
+
+ Used to provide a wrapper for event listening::
+
+ event.listen(
+ metadata,
+ 'before_create',
+ DDL("my_ddl").execute_if(dialect='postgresql')
+ )
+
+ :param dialect: May be a string or tuple of strings.
+ If a string, it will be compared to the name of the
+ executing database dialect::
+
+ DDL('something').execute_if(dialect='postgresql')
+
+ If a tuple, specifies multiple dialect names::
+
+ DDL('something').execute_if(dialect=('postgresql', 'mysql'))
+
+ :param callable\_: A callable, which will be invoked with
+ four positional arguments as well as optional keyword
+ arguments:
+
+ :ddl:
+ This DDL element.
+
+ :target:
+ The :class:`_schema.Table` or :class:`_schema.MetaData`
+ object which is the
+ target of this event. May be None if the DDL is executed
+ explicitly.
+
+ :bind:
+ The :class:`_engine.Connection` being used for DDL execution
+
+ :tables:
+ Optional keyword argument - a list of Table objects which are to
+ be created/ dropped within a MetaData.create_all() or drop_all()
+ method call.
+
+ :state:
+ Optional keyword argument - will be the ``state`` argument
+ passed to this function.
+
+ :checkfirst:
+ Keyword argument, will be True if the 'checkfirst' flag was
+ set during the call to ``create()``, ``create_all()``,
+ ``drop()``, ``drop_all()``.
+
+ If the callable returns a True value, the DDL statement will be
+ executed.
+
+ :param state: any value which will be passed to the callable\_
+ as the ``state`` keyword argument.
+
+ .. seealso::
+
+ :class:`.DDLEvents`
+
+ :ref:`event_toplevel`
+
+ """
+ self.dialect = dialect
+ self.callable_ = callable_
+ self.state = state
+
+ def _should_execute(self, target, bind, **kw):
+ if isinstance(self.dialect, util.string_types):
+ if self.dialect != bind.engine.name:
+ return False
+ elif isinstance(self.dialect, (tuple, list, set)):
+ if bind.engine.name not in self.dialect:
+ return False
+ if self.callable_ is not None and not self.callable_(
+ self, target, bind, state=self.state, **kw
+ ):
+ return False
+
+ return True
+
+ def __call__(self, target, bind, **kw):
+ """Execute the DDL as a ddl_listener."""
+
+ if self._should_execute(target, bind, **kw):
+ return bind.execute(self.against(target))
+
+ def bind(self):
+ if self._bind:
+ return self._bind
+
+ def _set_bind(self, bind):
+ self._bind = bind
+
+ bind = property(bind, _set_bind)
+
+ def _generate(self):
+ s = self.__class__.__new__(self.__class__)
+ s.__dict__ = self.__dict__.copy()
+ return s
+
+
+class DDL(DDLElement):
+ """A literal DDL statement.
+
+ Specifies literal SQL DDL to be executed by the database. DDL objects
+ function as DDL event listeners, and can be subscribed to those events
+ listed in :class:`.DDLEvents`, using either :class:`_schema.Table` or
+ :class:`_schema.MetaData` objects as targets.
+ Basic templating support allows
+ a single DDL instance to handle repetitive tasks for multiple tables.
+
+ Examples::
+
+ from sqlalchemy import event, DDL
+
+ tbl = Table('users', metadata, Column('uid', Integer))
+ event.listen(tbl, 'before_create', DDL('DROP TRIGGER users_trigger'))
+
+ spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE')
+ event.listen(tbl, 'after_create', spow.execute_if(dialect='somedb'))
+
+ drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE')
+ connection.execute(drop_spow)
+
+ When operating on Table events, the following ``statement``
+ string substitutions are available::
+
+ %(table)s - the Table name, with any required quoting applied
+ %(schema)s - the schema name, with any required quoting applied
+ %(fullname)s - the Table name including schema, quoted if needed
+
+ The DDL's "context", if any, will be combined with the standard
+ substitutions noted above. Keys present in the context will override
+ the standard substitutions.
+
+ """
+
+ __visit_name__ = "ddl"
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_ddl.DDL.bind` argument is deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(self, statement, context=None, bind=None):
+ """Create a DDL statement.
+
+ :param statement:
+ A string or unicode string to be executed. Statements will be
+ processed with Python's string formatting operator using
+ a fixed set of string substitutions, as well as additional
+ substitutions provided by the optional :paramref:`.DDL.context`
+ parameter.
+
+ A literal '%' in a statement must be escaped as '%%'.
+
+ SQL bind parameters are not available in DDL statements.
+
+ :param context:
+ Optional dictionary, defaults to None. These values will be
+ available for use in string substitutions on the DDL statement.
+
+ :param bind:
+ Optional. A :class:`.Connectable`, used by
+ default when ``execute()`` is invoked without a bind argument.
+
+
+ .. seealso::
+
+ :class:`.DDLEvents`
+
+ :ref:`event_toplevel`
+
+ """
+
+ if not isinstance(statement, util.string_types):
+ raise exc.ArgumentError(
+ "Expected a string or unicode SQL statement, got '%r'"
+ % statement
+ )
+
+ self.statement = statement
+ self.context = context or {}
+
+ self._bind = bind
+
+ def __repr__(self):
+ return "<%s@%s; %s>" % (
+ type(self).__name__,
+ id(self),
+ ", ".join(
+ [repr(self.statement)]
+ + [
+ "%s=%r" % (key, getattr(self, key))
+ for key in ("on", "context")
+ if getattr(self, key)
+ ]
+ ),
+ )
+
+
+class _CreateDropBase(DDLElement):
+ """Base class for DDL constructs that represent CREATE and DROP or
+ equivalents.
+
+ The common theme of _CreateDropBase is a single
+ ``element`` attribute which refers to the element
+ to be created or dropped.
+
+ """
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_ddl.DDLElement.bind` argument is "
+ "deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(
+ self,
+ element,
+ bind=None,
+ if_exists=False,
+ if_not_exists=False,
+ _legacy_bind=None,
+ ):
+ self.element = element
+ if bind:
+ self.bind = bind
+ elif _legacy_bind:
+ self.bind = _legacy_bind
+ self.if_exists = if_exists
+ self.if_not_exists = if_not_exists
+
+ @property
+ def stringify_dialect(self):
+ return self.element.create_drop_stringify_dialect
+
+ def _create_rule_disable(self, compiler):
+ """Allow disable of _create_rule using a callable.
+
+ Pass to _create_rule using
+ util.portable_instancemethod(self._create_rule_disable)
+ to retain serializability.
+
+ """
+ return False
+
+
+class CreateSchema(_CreateDropBase):
+ """Represent a CREATE SCHEMA statement.
+
+ The argument here is the string name of the schema.
+
+ """
+
+ __visit_name__ = "create_schema"
+
+ def __init__(self, name, quote=None, **kw):
+ """Create a new :class:`.CreateSchema` construct."""
+
+ self.quote = quote
+ super(CreateSchema, self).__init__(name, **kw)
+
+
+class DropSchema(_CreateDropBase):
+ """Represent a DROP SCHEMA statement.
+
+ The argument here is the string name of the schema.
+
+ """
+
+ __visit_name__ = "drop_schema"
+
+ def __init__(self, name, quote=None, cascade=False, **kw):
+ """Create a new :class:`.DropSchema` construct."""
+
+ self.quote = quote
+ self.cascade = cascade
+ super(DropSchema, self).__init__(name, **kw)
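+
+# A short usage sketch (assuming an Engine "engine" and a schema name that
+# is valid on the target backend):
+#
+#   with engine.begin() as conn:
+#       conn.execute(CreateSchema("archive"))
+#       ...
+#       conn.execute(DropSchema("archive", cascade=True))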
+
+
+class CreateTable(_CreateDropBase):
+ """Represent a CREATE TABLE statement."""
+
+ __visit_name__ = "create_table"
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_ddl.CreateTable.bind` argument is deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(
+ self,
+ element,
+ bind=None,
+ include_foreign_key_constraints=None,
+ if_not_exists=False,
+ ):
+ """Create a :class:`.CreateTable` construct.
+
+ :param element: a :class:`_schema.Table` that's the subject
+ of the CREATE
+ :param on: See the description for 'on' in :class:`.DDL`.
+ :param bind: See the description for 'bind' in :class:`.DDL`.
+ :param include_foreign_key_constraints: optional sequence of
+ :class:`_schema.ForeignKeyConstraint` objects that will be included
+ inline within the CREATE construct; if omitted, all foreign key
+ constraints that do not specify use_alter=True are included.
+
+ .. versionadded:: 1.0.0
+
+ :param if_not_exists: if True, an IF NOT EXISTS operator will be
+ applied to the construct.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ super(CreateTable, self).__init__(
+ element, _legacy_bind=bind, if_not_exists=if_not_exists
+ )
+ self.columns = [CreateColumn(column) for column in element.columns]
+ self.include_foreign_key_constraints = include_foreign_key_constraints
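+
+# A short sketch of the usual "print the DDL" pattern (assuming a Table
+# "my_table" and the stock PostgreSQL dialect):
+#
+#   from sqlalchemy.dialects import postgresql
+#   print(CreateTable(my_table).compile(dialect=postgresql.dialect()))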
+
+
+class _DropView(_CreateDropBase):
+ """Semi-public 'DROP VIEW' construct.
+
+ Used by the test suite for dialect-agnostic drops of views.
+ This object will eventually be part of a public "view" API.
+
+ """
+
+ __visit_name__ = "drop_view"
+
+
+class CreateColumn(_DDLCompiles):
+ """Represent a :class:`_schema.Column`
+ as rendered in a CREATE TABLE statement,
+ via the :class:`.CreateTable` construct.
+
+ This is provided to support custom column DDL within the generation
+ of CREATE TABLE statements, by using the
+ compiler extension documented in :ref:`sqlalchemy.ext.compiler_toplevel`
+ to extend :class:`.CreateColumn`.
+
+ Typical integration is to examine the incoming :class:`_schema.Column`
+ object, and to redirect compilation if a particular flag or condition
+ is found::
+
+ from sqlalchemy import schema
+ from sqlalchemy.ext.compiler import compiles
+
+ @compiles(schema.CreateColumn)
+ def compile(element, compiler, **kw):
+ column = element.element
+
+ if "special" not in column.info:
+ return compiler.visit_create_column(element, **kw)
+
+ text = "%s SPECIAL DIRECTIVE %s" % (
+ column.name,
+ compiler.type_compiler.process(column.type)
+ )
+ default = compiler.get_column_default_string(column)
+ if default is not None:
+ text += " DEFAULT " + default
+
+ if not column.nullable:
+ text += " NOT NULL"
+
+ if column.constraints:
+ text += " ".join(
+ compiler.process(const)
+ for const in column.constraints)
+ return text
+
+ The above construct can be applied to a :class:`_schema.Table`
+ as follows::
+
+        from sqlalchemy import Table, MetaData, Column, Integer, String
+ from sqlalchemy import schema
+
+ metadata = MetaData()
+
+        table = Table('mytable', metadata,
+ Column('x', Integer, info={"special":True}, primary_key=True),
+ Column('y', String(50)),
+ Column('z', String(20), info={"special":True})
+ )
+
+ metadata.create_all(conn)
+
+ Above, the directives we've added to the :attr:`_schema.Column.info`
+ collection
+ will be detected by our custom compilation scheme::
+
+ CREATE TABLE mytable (
+ x SPECIAL DIRECTIVE INTEGER NOT NULL,
+ y VARCHAR(50),
+ z SPECIAL DIRECTIVE VARCHAR(20),
+ PRIMARY KEY (x)
+ )
+
+ The :class:`.CreateColumn` construct can also be used to skip certain
+ columns when producing a ``CREATE TABLE``. This is accomplished by
+ creating a compilation rule that conditionally returns ``None``.
+ This is essentially how to produce the same effect as using the
+ ``system=True`` argument on :class:`_schema.Column`, which marks a column
+ as an implicitly-present "system" column.
+
+ For example, suppose we wish to produce a :class:`_schema.Table`
+ which skips
+ rendering of the PostgreSQL ``xmin`` column against the PostgreSQL
+ backend, but on other backends does render it, in anticipation of a
+ triggered rule. A conditional compilation rule could skip this name only
+ on PostgreSQL::
+
+ from sqlalchemy.schema import CreateColumn
+
+ @compiles(CreateColumn, "postgresql")
+ def skip_xmin(element, compiler, **kw):
+ if element.element.name == 'xmin':
+ return None
+ else:
+ return compiler.visit_create_column(element, **kw)
+
+
+ my_table = Table('mytable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('xmin', Integer)
+ )
+
+ Above, a :class:`.CreateTable` construct will generate a ``CREATE TABLE``
+ which only includes the ``id`` column in the string; the ``xmin`` column
+ will be omitted, but only against the PostgreSQL backend.
+
+ """
+
+ __visit_name__ = "create_column"
+
+ def __init__(self, element):
+ self.element = element
+
+
+class DropTable(_CreateDropBase):
+ """Represent a DROP TABLE statement."""
+
+ __visit_name__ = "drop_table"
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_ddl.DropTable.bind` argument is "
+ "deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(self, element, bind=None, if_exists=False):
+ """Create a :class:`.DropTable` construct.
+
+ :param element: a :class:`_schema.Table` that's the subject
+ of the DROP.
+ :param on: See the description for 'on' in :class:`.DDL`.
+ :param bind: See the description for 'bind' in :class:`.DDL`.
+ :param if_exists: if True, an IF EXISTS operator will be applied to the
+ construct.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ super(DropTable, self).__init__(
+ element, _legacy_bind=bind, if_exists=if_exists
+ )
+
+
+class CreateSequence(_CreateDropBase):
+ """Represent a CREATE SEQUENCE statement."""
+
+ __visit_name__ = "create_sequence"
+
+
+class DropSequence(_CreateDropBase):
+ """Represent a DROP SEQUENCE statement."""
+
+ __visit_name__ = "drop_sequence"
+
+
+class CreateIndex(_CreateDropBase):
+ """Represent a CREATE INDEX statement."""
+
+ __visit_name__ = "create_index"
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_ddl.CreateIndex.bind` argument is "
+ "deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(self, element, bind=None, if_not_exists=False):
+ """Create a :class:`.Createindex` construct.
+
+ :param element: a :class:`_schema.Index` that's the subject
+ of the CREATE.
+ :param on: See the description for 'on' in :class:`.DDL`.
+ :param bind: See the description for 'bind' in :class:`.DDL`.
+ :param if_not_exists: if True, an IF NOT EXISTS operator will be
+ applied to the construct.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ super(CreateIndex, self).__init__(
+ element, _legacy_bind=bind, if_not_exists=if_not_exists
+ )
+
+
+class DropIndex(_CreateDropBase):
+ """Represent a DROP INDEX statement."""
+
+ __visit_name__ = "drop_index"
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_ddl.DropIndex.bind` argument is "
+ "deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(self, element, bind=None, if_exists=False):
+ """Create a :class:`.DropIndex` construct.
+
+ :param element: a :class:`_schema.Index` that's the subject
+ of the DROP.
+ :param on: See the description for 'on' in :class:`.DDL`.
+ :param bind: See the description for 'bind' in :class:`.DDL`.
+ :param if_exists: if True, an IF EXISTS operator will be applied to the
+ construct.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ super(DropIndex, self).__init__(
+ element, _legacy_bind=bind, if_exists=if_exists
+ )
+
+
+class AddConstraint(_CreateDropBase):
+ """Represent an ALTER TABLE ADD CONSTRAINT statement."""
+
+ __visit_name__ = "add_constraint"
+
+ def __init__(self, element, *args, **kw):
+ super(AddConstraint, self).__init__(element, *args, **kw)
+ element._create_rule = util.portable_instancemethod(
+ self._create_rule_disable
+ )
+
+
+class DropConstraint(_CreateDropBase):
+ """Represent an ALTER TABLE DROP CONSTRAINT statement."""
+
+ __visit_name__ = "drop_constraint"
+
+ def __init__(self, element, cascade=False, **kw):
+ self.cascade = cascade
+ super(DropConstraint, self).__init__(element, **kw)
+ element._create_rule = util.portable_instancemethod(
+ self._create_rule_disable
+ )
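+
+# A short usage sketch (assuming a named constraint object attached to an
+# existing Table, and an Engine "engine"):
+#
+#   with engine.begin() as conn:
+#       conn.execute(DropConstraint(constraint))
+#       conn.execute(AddConstraint(constraint))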
+
+
+class SetTableComment(_CreateDropBase):
+ """Represent a COMMENT ON TABLE IS statement."""
+
+ __visit_name__ = "set_table_comment"
+
+
+class DropTableComment(_CreateDropBase):
+ """Represent a COMMENT ON TABLE '' statement.
+
+ Note this varies a lot across database backends.
+
+ """
+
+ __visit_name__ = "drop_table_comment"
+
+
+class SetColumnComment(_CreateDropBase):
+ """Represent a COMMENT ON COLUMN IS statement."""
+
+ __visit_name__ = "set_column_comment"
+
+
+class DropColumnComment(_CreateDropBase):
+ """Represent a COMMENT ON COLUMN IS NULL statement."""
+
+ __visit_name__ = "drop_column_comment"
+
+
+class DDLBase(SchemaVisitor):
+ def __init__(self, connection):
+ self.connection = connection
+
+
+class SchemaGenerator(DDLBase):
+ def __init__(
+ self, dialect, connection, checkfirst=False, tables=None, **kwargs
+ ):
+ super(SchemaGenerator, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+ self.tables = tables
+ self.preparer = dialect.identifier_preparer
+ self.dialect = dialect
+ self.memo = {}
+
+ def _can_create_table(self, table):
+ self.dialect.validate_identifier(table.name)
+ effective_schema = self.connection.schema_for_object(table)
+ if effective_schema:
+ self.dialect.validate_identifier(effective_schema)
+ return not self.checkfirst or not self.dialect.has_table(
+ self.connection, table.name, schema=effective_schema
+ )
+
+ def _can_create_index(self, index):
+ effective_schema = self.connection.schema_for_object(index.table)
+ if effective_schema:
+ self.dialect.validate_identifier(effective_schema)
+ return not self.checkfirst or not self.dialect.has_index(
+ self.connection,
+ index.table.name,
+ index.name,
+ schema=effective_schema,
+ )
+
+ def _can_create_sequence(self, sequence):
+ effective_schema = self.connection.schema_for_object(sequence)
+
+ return self.dialect.supports_sequences and (
+ (not self.dialect.sequences_optional or not sequence.optional)
+ and (
+ not self.checkfirst
+ or not self.dialect.has_sequence(
+ self.connection, sequence.name, schema=effective_schema
+ )
+ )
+ )
+
+ def visit_metadata(self, metadata):
+ if self.tables is not None:
+ tables = self.tables
+ else:
+ tables = list(metadata.tables.values())
+
+ collection = sort_tables_and_constraints(
+ [t for t in tables if self._can_create_table(t)]
+ )
+
+ seq_coll = [
+ s
+ for s in metadata._sequences.values()
+ if s.column is None and self._can_create_sequence(s)
+ ]
+
+ event_collection = [t for (t, fks) in collection if t is not None]
+ metadata.dispatch.before_create(
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
+
+ for seq in seq_coll:
+ self.traverse_single(seq, create_ok=True)
+
+ for table, fkcs in collection:
+ if table is not None:
+ self.traverse_single(
+ table,
+ create_ok=True,
+ include_foreign_key_constraints=fkcs,
+ _is_metadata_operation=True,
+ )
+ else:
+ for fkc in fkcs:
+ self.traverse_single(fkc)
+
+ metadata.dispatch.after_create(
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
+
+ def visit_table(
+ self,
+ table,
+ create_ok=False,
+ include_foreign_key_constraints=None,
+ _is_metadata_operation=False,
+ ):
+ if not create_ok and not self._can_create_table(table):
+ return
+
+ table.dispatch.before_create(
+ table,
+ self.connection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ _is_metadata_operation=_is_metadata_operation,
+ )
+
+ for column in table.columns:
+ if column.default is not None:
+ self.traverse_single(column.default)
+
+ if not self.dialect.supports_alter:
+            # backend can't add constraints via ALTER later, so don't
+            # omit any foreign key constraints here
+ include_foreign_key_constraints = None
+
+ self.connection.execute(
+ # fmt: off
+ CreateTable(
+ table,
+ include_foreign_key_constraints= # noqa
+ include_foreign_key_constraints, # noqa
+ )
+ # fmt: on
+ )
+
+ if hasattr(table, "indexes"):
+ for index in table.indexes:
+ self.traverse_single(index, create_ok=True)
+
+ if self.dialect.supports_comments and not self.dialect.inline_comments:
+ if table.comment is not None:
+ self.connection.execute(SetTableComment(table))
+
+ for column in table.columns:
+ if column.comment is not None:
+ self.connection.execute(SetColumnComment(column))
+
+ table.dispatch.after_create(
+ table,
+ self.connection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ _is_metadata_operation=_is_metadata_operation,
+ )
+
+ def visit_foreign_key_constraint(self, constraint):
+ if not self.dialect.supports_alter:
+ return
+ self.connection.execute(AddConstraint(constraint))
+
+ def visit_sequence(self, sequence, create_ok=False):
+ if not create_ok and not self._can_create_sequence(sequence):
+ return
+ self.connection.execute(CreateSequence(sequence))
+
+ def visit_index(self, index, create_ok=False):
+ if not create_ok and not self._can_create_index(index):
+ return
+ self.connection.execute(CreateIndex(index))
+
+
+class SchemaDropper(DDLBase):
+ def __init__(
+ self, dialect, connection, checkfirst=False, tables=None, **kwargs
+ ):
+ super(SchemaDropper, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+ self.tables = tables
+ self.preparer = dialect.identifier_preparer
+ self.dialect = dialect
+ self.memo = {}
+
+ def visit_metadata(self, metadata):
+ if self.tables is not None:
+ tables = self.tables
+ else:
+ tables = list(metadata.tables.values())
+
+ try:
+ unsorted_tables = [t for t in tables if self._can_drop_table(t)]
+ collection = list(
+ reversed(
+ sort_tables_and_constraints(
+ unsorted_tables,
+ filter_fn=lambda constraint: False
+ if not self.dialect.supports_alter
+ or constraint.name is None
+ else None,
+ )
+ )
+ )
+ except exc.CircularDependencyError as err2:
+ if not self.dialect.supports_alter:
+ util.warn(
+ "Can't sort tables for DROP; an "
+ "unresolvable foreign key "
+ "dependency exists between tables: %s; and backend does "
+ "not support ALTER. To restore at least a partial sort, "
+ "apply use_alter=True to ForeignKey and "
+ "ForeignKeyConstraint "
+ "objects involved in the cycle to mark these as known "
+ "cycles that will be ignored."
+ % (", ".join(sorted([t.fullname for t in err2.cycles])))
+ )
+ collection = [(t, ()) for t in unsorted_tables]
+ else:
+ util.raise_(
+ exc.CircularDependencyError(
+ err2.args[0],
+ err2.cycles,
+ err2.edges,
+ msg="Can't sort tables for DROP; an "
+ "unresolvable foreign key "
+ "dependency exists between tables: %s. Please ensure "
+ "that the ForeignKey and ForeignKeyConstraint objects "
+ "involved in the cycle have "
+ "names so that they can be dropped using "
+ "DROP CONSTRAINT."
+ % (
+ ", ".join(
+ sorted([t.fullname for t in err2.cycles])
+ )
+ ),
+ ),
+ from_=err2,
+ )
+
+ seq_coll = [
+ s
+ for s in metadata._sequences.values()
+ if self._can_drop_sequence(s)
+ ]
+
+ event_collection = [t for (t, fks) in collection if t is not None]
+
+ metadata.dispatch.before_drop(
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
+
+ for table, fkcs in collection:
+ if table is not None:
+ self.traverse_single(
+ table,
+ drop_ok=True,
+ _is_metadata_operation=True,
+ _ignore_sequences=seq_coll,
+ )
+ else:
+ for fkc in fkcs:
+ self.traverse_single(fkc)
+
+ for seq in seq_coll:
+ self.traverse_single(seq, drop_ok=seq.column is None)
+
+ metadata.dispatch.after_drop(
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
+
+ def _can_drop_table(self, table):
+ self.dialect.validate_identifier(table.name)
+ effective_schema = self.connection.schema_for_object(table)
+ if effective_schema:
+ self.dialect.validate_identifier(effective_schema)
+ return not self.checkfirst or self.dialect.has_table(
+ self.connection, table.name, schema=effective_schema
+ )
+
+ def _can_drop_index(self, index):
+ effective_schema = self.connection.schema_for_object(index.table)
+ if effective_schema:
+ self.dialect.validate_identifier(effective_schema)
+ return not self.checkfirst or self.dialect.has_index(
+ self.connection,
+ index.table.name,
+ index.name,
+ schema=effective_schema,
+ )
+
+ def _can_drop_sequence(self, sequence):
+ effective_schema = self.connection.schema_for_object(sequence)
+ return self.dialect.supports_sequences and (
+ (not self.dialect.sequences_optional or not sequence.optional)
+ and (
+ not self.checkfirst
+ or self.dialect.has_sequence(
+ self.connection, sequence.name, schema=effective_schema
+ )
+ )
+ )
+
+ def visit_index(self, index, drop_ok=False):
+ if not drop_ok and not self._can_drop_index(index):
+ return
+
+ self.connection.execute(DropIndex(index))
+
+ def visit_table(
+ self,
+ table,
+ drop_ok=False,
+ _is_metadata_operation=False,
+ _ignore_sequences=(),
+ ):
+ if not drop_ok and not self._can_drop_table(table):
+ return
+
+ table.dispatch.before_drop(
+ table,
+ self.connection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ _is_metadata_operation=_is_metadata_operation,
+ )
+
+ self.connection.execute(DropTable(table))
+
+        # traverse client-side defaults which may refer to server-side
+        # sequences; note that some of these client-side defaults may also be
+        # set up as server-side defaults (see https://docs.sqlalchemy.org/en/
+        # latest/core/defaults.html#associating-a-sequence-as-the-server-side-
+        # default), so they have to be dropped after the table is dropped.
+ for column in table.columns:
+ if (
+ column.default is not None
+ and column.default not in _ignore_sequences
+ ):
+ self.traverse_single(column.default)
+
+ table.dispatch.after_drop(
+ table,
+ self.connection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ _is_metadata_operation=_is_metadata_operation,
+ )
+
+ def visit_foreign_key_constraint(self, constraint):
+ if not self.dialect.supports_alter:
+ return
+ self.connection.execute(DropConstraint(constraint))
+
+ def visit_sequence(self, sequence, drop_ok=False):
+
+ if not drop_ok and not self._can_drop_sequence(sequence):
+ return
+ self.connection.execute(DropSequence(sequence))
+
+
+def sort_tables(
+ tables,
+ skip_fn=None,
+ extra_dependencies=None,
+):
+ """Sort a collection of :class:`_schema.Table` objects based on
+ dependency.
+
+ This is a dependency-ordered sort which will emit :class:`_schema.Table`
+    objects such that each one follows the :class:`_schema.Table` objects it
+    depends on.
+    Tables are dependent on one another based on the presence of
+ :class:`_schema.ForeignKeyConstraint`
+ objects as well as explicit dependencies
+ added by :meth:`_schema.Table.add_is_dependent_on`.
+
+ .. warning::
+
+ The :func:`._schema.sort_tables` function cannot by itself
+ accommodate automatic resolution of dependency cycles between
+ tables, which are usually caused by mutually dependent foreign key
+ constraints. When these cycles are detected, the foreign keys
+ of these tables are omitted from consideration in the sort.
+      A warning is emitted when this condition occurs, which will become
+      an exception in a future release. Tables which are not part
+ of the cycle will still be returned in dependency order.
+
+ To resolve these cycles, the
+ :paramref:`_schema.ForeignKeyConstraint.use_alter` parameter may be
+ applied to those constraints which create a cycle. Alternatively,
+ the :func:`_schema.sort_tables_and_constraints` function will
+ automatically return foreign key constraints in a separate
+ collection when cycles are detected so that they may be applied
+ to a schema separately.
+
+ .. versionchanged:: 1.3.17 - a warning is emitted when
+ :func:`_schema.sort_tables` cannot perform a proper sort due to
+ cyclical dependencies. This will be an exception in a future
+ release. Additionally, the sort will continue to return
+ other tables not involved in the cycle in dependency order
+ which was not the case previously.
+
+ :param tables: a sequence of :class:`_schema.Table` objects.
+
+ :param skip_fn: optional callable which will be passed a
+ :class:`_schema.ForeignKey` object; if it returns True, this
+ constraint will not be considered as a dependency. Note this is
+ **different** from the same parameter in
+ :func:`.sort_tables_and_constraints`, which is
+ instead passed the owning :class:`_schema.ForeignKeyConstraint` object.
+
+ :param extra_dependencies: a sequence of 2-tuples of tables which will
+ also be considered as dependent on each other.
+
+ .. seealso::
+
+ :func:`.sort_tables_and_constraints`
+
+ :attr:`_schema.MetaData.sorted_tables` - uses this function to sort
+
+
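+    For illustration, a minimal sketch assuming two hypothetical tables
+    where ``child`` references ``parent`` via a foreign key::
+
+        from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table
+        from sqlalchemy.schema import sort_tables
+
+        metadata = MetaData()
+        parent = Table(
+            "parent", metadata, Column("id", Integer, primary_key=True)
+        )
+        child = Table(
+            "child",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("parent_id", Integer, ForeignKey("parent.id")),
+        )
+
+        # parent is emitted before its dependent child
+        ordered = sort_tables([child, parent])
+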
+ """
+
+ if skip_fn is not None:
+
+ def _skip_fn(fkc):
+ for fk in fkc.elements:
+ if skip_fn(fk):
+ return True
+ else:
+ return None
+
+ else:
+ _skip_fn = None
+
+ return [
+ t
+ for (t, fkcs) in sort_tables_and_constraints(
+ tables,
+ filter_fn=_skip_fn,
+ extra_dependencies=extra_dependencies,
+ _warn_for_cycles=True,
+ )
+ if t is not None
+ ]
+
+
+def sort_tables_and_constraints(
+ tables, filter_fn=None, extra_dependencies=None, _warn_for_cycles=False
+):
+ """Sort a collection of :class:`_schema.Table` /
+ :class:`_schema.ForeignKeyConstraint`
+ objects.
+
+ This is a dependency-ordered sort which will emit tuples of
+ ``(Table, [ForeignKeyConstraint, ...])`` such that each
+    :class:`_schema.Table` follows the :class:`_schema.Table` objects it
+    depends on.
+ Remaining :class:`_schema.ForeignKeyConstraint`
+ objects that are separate due to
+ dependency rules not satisfied by the sort are emitted afterwards
+ as ``(None, [ForeignKeyConstraint ...])``.
+
+    Tables are dependent on one another based on the presence of
+ :class:`_schema.ForeignKeyConstraint` objects, explicit dependencies
+ added by :meth:`_schema.Table.add_is_dependent_on`,
+ as well as dependencies
+ stated here using the :paramref:`~.sort_tables_and_constraints.skip_fn`
+ and/or :paramref:`~.sort_tables_and_constraints.extra_dependencies`
+ parameters.
+
+ :param tables: a sequence of :class:`_schema.Table` objects.
+
+ :param filter_fn: optional callable which will be passed a
+ :class:`_schema.ForeignKeyConstraint` object,
+ and returns a value based on
+ whether this constraint should definitely be included or excluded as
+ an inline constraint, or neither. If it returns False, the constraint
+ will definitely be included as a dependency that cannot be subject
+ to ALTER; if True, it will **only** be included as an ALTER result at
+ the end. Returning None means the constraint is included in the
+ table-based result unless it is detected as part of a dependency cycle.
+
+ :param extra_dependencies: a sequence of 2-tuples of tables which will
+ also be considered as dependent on each other.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :func:`.sort_tables`
+
+
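+    For illustration, a minimal sketch of consuming the result, assuming a
+    hypothetical ``metadata`` collection whose tables are to be created::
+
+        from sqlalchemy.schema import sort_tables_and_constraints
+
+        for table, fkcs in sort_tables_and_constraints(
+            list(metadata.tables.values())
+        ):
+            if table is not None:
+                # the table plus constraints safe to render inline
+                print(table.name, [fkc.name for fkc in fkcs])
+            else:
+                # constraints left over for ALTER, e.g. due to cycles
+                print("ALTER-only constraints:", list(fkcs))
+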
+ """
+
+ fixed_dependencies = set()
+ mutable_dependencies = set()
+
+ if extra_dependencies is not None:
+ fixed_dependencies.update(extra_dependencies)
+
+ remaining_fkcs = set()
+ for table in tables:
+ for fkc in table.foreign_key_constraints:
+ if fkc.use_alter is True:
+ remaining_fkcs.add(fkc)
+ continue
+
+ if filter_fn:
+ filtered = filter_fn(fkc)
+
+ if filtered is True:
+ remaining_fkcs.add(fkc)
+ continue
+
+ dependent_on = fkc.referred_table
+ if dependent_on is not table:
+ mutable_dependencies.add((dependent_on, table))
+
+ fixed_dependencies.update(
+ (parent, table) for parent in table._extra_dependencies
+ )
+
+ try:
+ candidate_sort = list(
+ topological.sort(
+ fixed_dependencies.union(mutable_dependencies),
+ tables,
+ )
+ )
+ except exc.CircularDependencyError as err:
+ if _warn_for_cycles:
+ util.warn(
+ "Cannot correctly sort tables; there are unresolvable cycles "
+ 'between tables "%s", which is usually caused by mutually '
+ "dependent foreign key constraints. Foreign key constraints "
+ "involving these tables will not be considered; this warning "
+ "may raise an error in a future release."
+ % (", ".join(sorted(t.fullname for t in err.cycles)),)
+ )
+ for edge in err.edges:
+ if edge in mutable_dependencies:
+ table = edge[1]
+ if table not in err.cycles:
+ continue
+ can_remove = [
+ fkc
+ for fkc in table.foreign_key_constraints
+ if filter_fn is None or filter_fn(fkc) is not False
+ ]
+ remaining_fkcs.update(can_remove)
+ for fkc in can_remove:
+ dependent_on = fkc.referred_table
+ if dependent_on is not table:
+ mutable_dependencies.discard((dependent_on, table))
+ candidate_sort = list(
+ topological.sort(
+ fixed_dependencies.union(mutable_dependencies),
+ tables,
+ )
+ )
+
+ return [
+ (table, table.foreign_key_constraints.difference(remaining_fkcs))
+ for table in candidate_sort
+ ] + [(None, list(remaining_fkcs))]
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
new file mode 100644
index 0000000..70586c6
--- /dev/null
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -0,0 +1,360 @@
+# sql/default_comparator.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Default implementation of SQL comparison operations.
+"""
+
+
+from . import coercions
+from . import operators
+from . import roles
+from . import type_api
+from .elements import and_
+from .elements import BinaryExpression
+from .elements import ClauseList
+from .elements import collate
+from .elements import CollectionAggregate
+from .elements import False_
+from .elements import Null
+from .elements import or_
+from .elements import True_
+from .elements import UnaryExpression
+from .. import exc
+from .. import util
+
+
+def _boolean_compare(
+ expr,
+ op,
+ obj,
+ negate=None,
+ reverse=False,
+ _python_is_types=(util.NoneType, bool),
+ _any_all_expr=False,
+ result_type=None,
+ **kwargs
+):
+
+ if result_type is None:
+ result_type = type_api.BOOLEANTYPE
+
+ if isinstance(obj, _python_is_types + (Null, True_, False_)):
+ # allow x ==/!= True/False to be treated as a literal.
+ # this comes out to "== / != true/false" or "1/0" if those
+ # constants aren't supported and works on all platforms
+ if op in (operators.eq, operators.ne) and isinstance(
+ obj, (bool, True_, False_)
+ ):
+ return BinaryExpression(
+ expr,
+ coercions.expect(roles.ConstExprRole, obj),
+ op,
+ type_=result_type,
+ negate=negate,
+ modifiers=kwargs,
+ )
+ elif op in (
+ operators.is_distinct_from,
+ operators.is_not_distinct_from,
+ ):
+ return BinaryExpression(
+ expr,
+ coercions.expect(roles.ConstExprRole, obj),
+ op,
+ type_=result_type,
+ negate=negate,
+ modifiers=kwargs,
+ )
+ elif _any_all_expr:
+ obj = coercions.expect(
+ roles.ConstExprRole, element=obj, operator=op, expr=expr
+ )
+ else:
+ # all other None uses IS, IS NOT
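+            # e.g. (column == None) renders "column IS NULL" and
+            # (column != None) renders "column IS NOT NULL"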
+ if op in (operators.eq, operators.is_):
+ return BinaryExpression(
+ expr,
+ coercions.expect(roles.ConstExprRole, obj),
+ operators.is_,
+ negate=operators.is_not,
+ type_=result_type,
+ )
+ elif op in (operators.ne, operators.is_not):
+ return BinaryExpression(
+ expr,
+ coercions.expect(roles.ConstExprRole, obj),
+ operators.is_not,
+ negate=operators.is_,
+ type_=result_type,
+ )
+ else:
+ raise exc.ArgumentError(
+ "Only '=', '!=', 'is_()', 'is_not()', "
+ "'is_distinct_from()', 'is_not_distinct_from()' "
+ "operators can be used with None/True/False"
+ )
+ else:
+ obj = coercions.expect(
+ roles.BinaryElementRole, element=obj, operator=op, expr=expr
+ )
+
+ if reverse:
+ return BinaryExpression(
+ obj, expr, op, type_=result_type, negate=negate, modifiers=kwargs
+ )
+ else:
+ return BinaryExpression(
+ expr, obj, op, type_=result_type, negate=negate, modifiers=kwargs
+ )
+
+
+def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw):
+ if result_type is None:
+ if op.return_type:
+ result_type = op.return_type
+ elif op.is_comparison:
+ result_type = type_api.BOOLEANTYPE
+
+ return _binary_operate(
+ expr, op, obj, reverse=reverse, result_type=result_type, **kw
+ )
+
+
+def _binary_operate(expr, op, obj, reverse=False, result_type=None, **kw):
+ obj = coercions.expect(
+ roles.BinaryElementRole, obj, expr=expr, operator=op
+ )
+
+ if reverse:
+ left, right = obj, expr
+ else:
+ left, right = expr, obj
+
+ if result_type is None:
+ op, result_type = left.comparator._adapt_expression(
+ op, right.comparator
+ )
+
+ return BinaryExpression(left, right, op, type_=result_type, modifiers=kw)
+
+
+def _conjunction_operate(expr, op, other, **kw):
+ if op is operators.and_:
+ return and_(expr, other)
+ elif op is operators.or_:
+ return or_(expr, other)
+ else:
+ raise NotImplementedError()
+
+
+def _scalar(expr, op, fn, **kw):
+ return fn(expr)
+
+
+def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
+ seq_or_selectable = coercions.expect(
+ roles.InElementRole, seq_or_selectable, expr=expr, operator=op
+ )
+ if "in_ops" in seq_or_selectable._annotations:
+ op, negate_op = seq_or_selectable._annotations["in_ops"]
+
+ return _boolean_compare(
+ expr, op, seq_or_selectable, negate=negate_op, **kw
+ )
+
+
+def _getitem_impl(expr, op, other, **kw):
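+    # expr[index] is only supported for types flagged as INDEXABLE (e.g.
+    # ARRAY and JSON, possibly wrapped in a TypeDecorator); anything else
+    # raises NotImplementedError via _unsupported_impl()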
+ if (
+ isinstance(expr.type, type_api.INDEXABLE)
+ or isinstance(expr.type, type_api.TypeDecorator)
+ and isinstance(expr.type.impl, type_api.INDEXABLE)
+ ):
+ other = coercions.expect(
+ roles.BinaryElementRole, other, expr=expr, operator=op
+ )
+ return _binary_operate(expr, op, other, **kw)
+ else:
+ _unsupported_impl(expr, op, other, **kw)
+
+
+def _unsupported_impl(expr, op, *arg, **kw):
+ raise NotImplementedError(
+ "Operator '%s' is not supported on " "this expression" % op.__name__
+ )
+
+
+def _inv_impl(expr, op, **kw):
+ """See :meth:`.ColumnOperators.__inv__`."""
+
+ # undocumented element currently used by the ORM for
+ # relationship.contains()
+ if hasattr(expr, "negation_clause"):
+ return expr.negation_clause
+ else:
+ return expr._negate()
+
+
+def _neg_impl(expr, op, **kw):
+ """See :meth:`.ColumnOperators.__neg__`."""
+ return UnaryExpression(expr, operator=operators.neg, type_=expr.type)
+
+
+def _match_impl(expr, op, other, **kw):
+ """See :meth:`.ColumnOperators.match`."""
+
+ return _boolean_compare(
+ expr,
+ operators.match_op,
+ coercions.expect(
+ roles.BinaryElementRole,
+ other,
+ expr=expr,
+ operator=operators.match_op,
+ ),
+ result_type=type_api.MATCHTYPE,
+ negate=operators.not_match_op
+ if op is operators.match_op
+ else operators.match_op,
+ **kw
+ )
+
+
+def _distinct_impl(expr, op, **kw):
+ """See :meth:`.ColumnOperators.distinct`."""
+ return UnaryExpression(
+ expr, operator=operators.distinct_op, type_=expr.type
+ )
+
+
+def _between_impl(expr, op, cleft, cright, **kw):
+ """See :meth:`.ColumnOperators.between`."""
+ return BinaryExpression(
+ expr,
+ ClauseList(
+ coercions.expect(
+ roles.BinaryElementRole,
+ cleft,
+ expr=expr,
+ operator=operators.and_,
+ ),
+ coercions.expect(
+ roles.BinaryElementRole,
+ cright,
+ expr=expr,
+ operator=operators.and_,
+ ),
+ operator=operators.and_,
+ group=False,
+ group_contents=False,
+ ),
+ op,
+ negate=operators.not_between_op
+ if op is operators.between_op
+ else operators.between_op,
+ modifiers=kw,
+ )
+
+
+def _collate_impl(expr, op, other, **kw):
+ return collate(expr, other)
+
+
+def _regexp_match_impl(expr, op, pattern, flags, **kw):
+ if flags is not None:
+ flags = coercions.expect(
+ roles.BinaryElementRole,
+ flags,
+ expr=expr,
+ operator=operators.regexp_replace_op,
+ )
+ return _boolean_compare(
+ expr,
+ op,
+ pattern,
+ flags=flags,
+ negate=operators.not_regexp_match_op
+ if op is operators.regexp_match_op
+ else operators.regexp_match_op,
+ **kw
+ )
+
+
+def _regexp_replace_impl(expr, op, pattern, replacement, flags, **kw):
+ replacement = coercions.expect(
+ roles.BinaryElementRole,
+ replacement,
+ expr=expr,
+ operator=operators.regexp_replace_op,
+ )
+ if flags is not None:
+ flags = coercions.expect(
+ roles.BinaryElementRole,
+ flags,
+ expr=expr,
+ operator=operators.regexp_replace_op,
+ )
+ return _binary_operate(
+ expr, op, pattern, replacement=replacement, flags=flags, **kw
+ )
+
+
+# a mapping of operators with the method they use, along with
+# their negated operator for comparison operators
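+#
+# e.g. "eq": (_boolean_compare, operators.ne) dispatches the "eq" operator to
+# _boolean_compare(), with operators.ne supplying its negation; entries using
+# _scalar instead pair it with a factory such as UnaryExpression._create_desc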
+operator_lookup = {
+ "and_": (_conjunction_operate,),
+ "or_": (_conjunction_operate,),
+ "inv": (_inv_impl,),
+ "add": (_binary_operate,),
+ "mul": (_binary_operate,),
+ "sub": (_binary_operate,),
+ "div": (_binary_operate,),
+ "mod": (_binary_operate,),
+ "truediv": (_binary_operate,),
+ "custom_op": (_custom_op_operate,),
+ "json_path_getitem_op": (_binary_operate,),
+ "json_getitem_op": (_binary_operate,),
+ "concat_op": (_binary_operate,),
+ "any_op": (_scalar, CollectionAggregate._create_any),
+ "all_op": (_scalar, CollectionAggregate._create_all),
+ "lt": (_boolean_compare, operators.ge),
+ "le": (_boolean_compare, operators.gt),
+ "ne": (_boolean_compare, operators.eq),
+ "gt": (_boolean_compare, operators.le),
+ "ge": (_boolean_compare, operators.lt),
+ "eq": (_boolean_compare, operators.ne),
+ "is_distinct_from": (_boolean_compare, operators.is_not_distinct_from),
+ "is_not_distinct_from": (_boolean_compare, operators.is_distinct_from),
+ "like_op": (_boolean_compare, operators.not_like_op),
+ "ilike_op": (_boolean_compare, operators.not_ilike_op),
+ "not_like_op": (_boolean_compare, operators.like_op),
+ "not_ilike_op": (_boolean_compare, operators.ilike_op),
+ "contains_op": (_boolean_compare, operators.not_contains_op),
+ "startswith_op": (_boolean_compare, operators.not_startswith_op),
+ "endswith_op": (_boolean_compare, operators.not_endswith_op),
+ "desc_op": (_scalar, UnaryExpression._create_desc),
+ "asc_op": (_scalar, UnaryExpression._create_asc),
+ "nulls_first_op": (_scalar, UnaryExpression._create_nulls_first),
+ "nulls_last_op": (_scalar, UnaryExpression._create_nulls_last),
+ "in_op": (_in_impl, operators.not_in_op),
+ "not_in_op": (_in_impl, operators.in_op),
+ "is_": (_boolean_compare, operators.is_),
+ "is_not": (_boolean_compare, operators.is_not),
+ "collate": (_collate_impl,),
+ "match_op": (_match_impl,),
+ "not_match_op": (_match_impl,),
+ "distinct_op": (_distinct_impl,),
+ "between_op": (_between_impl,),
+ "not_between_op": (_between_impl,),
+ "neg": (_neg_impl,),
+ "getitem": (_getitem_impl,),
+ "lshift": (_unsupported_impl,),
+ "rshift": (_unsupported_impl,),
+ "contains": (_unsupported_impl,),
+ "regexp_match_op": (_regexp_match_impl,),
+ "not_regexp_match_op": (_regexp_match_impl,),
+ "regexp_replace_op": (_regexp_replace_impl,),
+}
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
new file mode 100644
index 0000000..07a4d7b
--- /dev/null
+++ b/lib/sqlalchemy/sql/dml.py
@@ -0,0 +1,1514 @@
+# sql/dml.py
+# Copyright (C) 2009-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""
+Provide :class:`_expression.Insert`, :class:`_expression.Update` and
+:class:`_expression.Delete`.
+
+"""
+from sqlalchemy.types import NullType
+from . import coercions
+from . import roles
+from . import util as sql_util
+from .base import _entity_namespace_key
+from .base import _exclusive_against
+from .base import _from_objects
+from .base import _generative
+from .base import ColumnCollection
+from .base import CompileState
+from .base import DialectKWArgs
+from .base import Executable
+from .base import HasCompileState
+from .elements import BooleanClauseList
+from .elements import ClauseElement
+from .elements import Null
+from .selectable import HasCTE
+from .selectable import HasPrefixes
+from .selectable import ReturnsRows
+from .visitors import InternalTraversal
+from .. import exc
+from .. import util
+from ..util import collections_abc
+
+
+class DMLState(CompileState):
+ _no_parameters = True
+ _dict_parameters = None
+ _multi_parameters = None
+ _ordered_values = None
+ _parameter_ordering = None
+ _has_multi_parameters = False
+ isupdate = False
+ isdelete = False
+ isinsert = False
+
+ def __init__(self, statement, compiler, **kw):
+ raise NotImplementedError()
+
+ @classmethod
+ def get_entity_description(cls, statement):
+ return {"name": statement.table.name, "table": statement.table}
+
+ @classmethod
+ def get_returning_column_descriptions(cls, statement):
+ return [
+ {
+ "name": c.key,
+ "type": c.type,
+ "expr": c,
+ }
+ for c in statement._all_selected_columns
+ ]
+
+ @property
+ def dml_table(self):
+ return self.statement.table
+
+ @classmethod
+ def _get_crud_kv_pairs(cls, statement, kv_iterator):
+ return [
+ (
+ coercions.expect(roles.DMLColumnRole, k),
+ coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=NullType(),
+ is_crud=True,
+ ),
+ )
+ for k, v in kv_iterator
+ ]
+
+ def _make_extra_froms(self, statement):
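+        # collect FROM objects referenced by the WHERE criteria (plus any
+        # secondary tables of a joined DML target) beyond the primary table;
+        # these become the extra FROM list for multi-table UPDATE / DELETE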
+ froms = []
+
+ all_tables = list(sql_util.tables_from_leftmost(statement.table))
+ seen = {all_tables[0]}
+
+ for crit in statement._where_criteria:
+ for item in _from_objects(crit):
+ if not seen.intersection(item._cloned_set):
+ froms.append(item)
+ seen.update(item._cloned_set)
+
+ froms.extend(all_tables[1:])
+ return froms
+
+ def _process_multi_values(self, statement):
+ if not statement._supports_multi_parameters:
+ raise exc.InvalidRequestError(
+ "%s construct does not support "
+ "multiple parameter sets." % statement.__visit_name__.upper()
+ )
+
+ for parameters in statement._multi_values:
+ multi_parameters = [
+ {
+ c.key: value
+ for c, value in zip(statement.table.c, parameter_set)
+ }
+ if isinstance(parameter_set, collections_abc.Sequence)
+ else parameter_set
+ for parameter_set in parameters
+ ]
+
+ if self._no_parameters:
+ self._no_parameters = False
+ self._has_multi_parameters = True
+ self._multi_parameters = multi_parameters
+ self._dict_parameters = self._multi_parameters[0]
+ elif not self._has_multi_parameters:
+ self._cant_mix_formats_error()
+ else:
+ self._multi_parameters.extend(multi_parameters)
+
+ def _process_values(self, statement):
+ if self._no_parameters:
+ self._has_multi_parameters = False
+ self._dict_parameters = statement._values
+ self._no_parameters = False
+ elif self._has_multi_parameters:
+ self._cant_mix_formats_error()
+
+ def _process_ordered_values(self, statement):
+ parameters = statement._ordered_values
+
+ if self._no_parameters:
+ self._no_parameters = False
+ self._dict_parameters = dict(parameters)
+ self._ordered_values = parameters
+ self._parameter_ordering = [key for key, value in parameters]
+ elif self._has_multi_parameters:
+ self._cant_mix_formats_error()
+ else:
+ raise exc.InvalidRequestError(
+ "Can only invoke ordered_values() once, and not mixed "
+ "with any other values() call"
+ )
+
+ def _process_select_values(self, statement):
+ parameters = {
+ coercions.expect(roles.DMLColumnRole, name, as_key=True): Null()
+ for name in statement._select_names
+ }
+
+ if self._no_parameters:
+ self._no_parameters = False
+ self._dict_parameters = parameters
+ else:
+            # this condition is normally not reachable as the Insert
+ # does not allow this construction to occur
+ assert False, "This statement already has parameters"
+
+ def _cant_mix_formats_error(self):
+ raise exc.InvalidRequestError(
+ "Can't mix single and multiple VALUES "
+ "formats in one INSERT statement; one style appends to a "
+ "list while the other replaces values, so the intent is "
+ "ambiguous."
+ )
+
+
+@CompileState.plugin_for("default", "insert")
+class InsertDMLState(DMLState):
+ isinsert = True
+
+ include_table_with_column_exprs = False
+
+ def __init__(self, statement, compiler, **kw):
+ self.statement = statement
+
+ self.isinsert = True
+ if statement._select_names:
+ self._process_select_values(statement)
+ if statement._values is not None:
+ self._process_values(statement)
+ if statement._multi_values:
+ self._process_multi_values(statement)
+
+ @util.memoized_property
+ def _insert_col_keys(self):
+ # this is also done in crud.py -> _key_getters_for_crud_column
+ return [
+ coercions.expect_as_key(roles.DMLColumnRole, col)
+ for col in self._dict_parameters
+ ]
+
+
+@CompileState.plugin_for("default", "update")
+class UpdateDMLState(DMLState):
+ isupdate = True
+
+ include_table_with_column_exprs = False
+
+ def __init__(self, statement, compiler, **kw):
+ self.statement = statement
+ self.isupdate = True
+ self._preserve_parameter_order = statement._preserve_parameter_order
+ if statement._ordered_values is not None:
+ self._process_ordered_values(statement)
+ elif statement._values is not None:
+ self._process_values(statement)
+ elif statement._multi_values:
+ self._process_multi_values(statement)
+ self._extra_froms = ef = self._make_extra_froms(statement)
+ self.is_multitable = mt = ef and self._dict_parameters
+ self.include_table_with_column_exprs = (
+ mt and compiler.render_table_with_column_in_update_from
+ )
+
+
+@CompileState.plugin_for("default", "delete")
+class DeleteDMLState(DMLState):
+ isdelete = True
+
+ def __init__(self, statement, compiler, **kw):
+ self.statement = statement
+
+ self.isdelete = True
+ self._extra_froms = self._make_extra_froms(statement)
+
+
+class UpdateBase(
+ roles.DMLRole,
+ HasCTE,
+ HasCompileState,
+ DialectKWArgs,
+ HasPrefixes,
+ ReturnsRows,
+ Executable,
+ ClauseElement,
+):
+ """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements."""
+
+ __visit_name__ = "update_base"
+
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": True}
+ )
+ _hints = util.immutabledict()
+ named_with_column = False
+
+ _return_defaults = False
+ _return_defaults_columns = None
+ _returning = ()
+
+ is_dml = True
+
+ @classmethod
+ def _constructor_20_deprecations(cls, fn_name, clsname, names):
+
+ param_to_method_lookup = dict(
+ whereclause=(
+ "The :paramref:`%(func)s.whereclause` parameter "
+ "will be removed "
+ "in SQLAlchemy 2.0. Please refer to the "
+ ":meth:`%(classname)s.where` method."
+ ),
+ values=(
+ "The :paramref:`%(func)s.values` parameter will be removed "
+ "in SQLAlchemy 2.0. Please refer to the "
+ ":meth:`%(classname)s.values` method."
+ ),
+ bind=(
+ "The :paramref:`%(func)s.bind` parameter will be removed in "
+ "SQLAlchemy 2.0. Please use explicit connection execution."
+ ),
+ inline=(
+ "The :paramref:`%(func)s.inline` parameter will be "
+ "removed in "
+ "SQLAlchemy 2.0. Please use the "
+ ":meth:`%(classname)s.inline` method."
+ ),
+ prefixes=(
+ "The :paramref:`%(func)s.prefixes parameter will be "
+ "removed in "
+ "SQLAlchemy 2.0. Please use the "
+ ":meth:`%(classname)s.prefix_with` "
+ "method."
+ ),
+ return_defaults=(
+ "The :paramref:`%(func)s.return_defaults` parameter will be "
+ "removed in SQLAlchemy 2.0. Please use the "
+ ":meth:`%(classname)s.return_defaults` method."
+ ),
+ returning=(
+ "The :paramref:`%(func)s.returning` parameter will be "
+ "removed in SQLAlchemy 2.0. Please use the "
+ ":meth:`%(classname)s.returning`` method."
+ ),
+ preserve_parameter_order=(
+ "The :paramref:`%(func)s.preserve_parameter_order` parameter "
+ "will be removed in SQLAlchemy 2.0. Use the "
+ ":meth:`%(classname)s.ordered_values` method with a list "
+ "of tuples. "
+ ),
+ )
+
+ return util.deprecated_params(
+ **{
+ name: (
+ "2.0",
+ param_to_method_lookup[name]
+ % {
+ "func": "_expression.%s" % fn_name,
+ "classname": "_expression.%s" % clsname,
+ },
+ )
+ for name in names
+ }
+ )
+
+ def _generate_fromclause_column_proxies(self, fromclause):
+ fromclause._columns._populate_separate_keys(
+ col._make_proxy(fromclause) for col in self._returning
+ )
+
+ def params(self, *arg, **kw):
+ """Set the parameters for the statement.
+
+ This method raises ``NotImplementedError`` on the base class,
+ and is overridden by :class:`.ValuesBase` to provide the
+ SET/VALUES clause of UPDATE and INSERT.
+
+ """
+ raise NotImplementedError(
+ "params() is not supported for INSERT/UPDATE/DELETE statements."
+ " To set the values for an INSERT or UPDATE statement, use"
+ " stmt.values(**parameters)."
+ )
+
+ @_generative
+ def with_dialect_options(self, **opt):
+ """Add dialect options to this INSERT/UPDATE/DELETE object.
+
+ e.g.::
+
+            upd = table.update().with_dialect_options(mysql_limit=10)
+
+        .. versionadded:: 1.4 - this method supersedes the dialect options
+ associated with the constructor.
+
+
+ """
+ self._validate_dialect_kwargs(opt)
+
+ def _validate_dialect_kwargs_deprecated(self, dialect_kw):
+ util.warn_deprecated_20(
+ "Passing dialect keyword arguments directly to the "
+ "%s constructor is deprecated and will be removed in SQLAlchemy "
+ "2.0. Please use the ``with_dialect_options()`` method."
+ % (self.__class__.__name__)
+ )
+ self._validate_dialect_kwargs(dialect_kw)
+
+ def bind(self):
+ """Return a 'bind' linked to this :class:`.UpdateBase`
+ or a :class:`_schema.Table` associated with it.
+
+ """
+ return self._bind or self.table.bind
+
+ def _set_bind(self, bind):
+ self._bind = bind
+
+ bind = property(bind, _set_bind)
+
+ @_generative
+ def returning(self, *cols):
+ r"""Add a :term:`RETURNING` or equivalent clause to this statement.
+
+ e.g.:
+
+ .. sourcecode:: pycon+sql
+
+ >>> stmt = (
+ ... table.update()
+ ... .where(table.c.data == "value")
+ ... .values(status="X")
+ ... .returning(table.c.server_flag, table.c.updated_timestamp)
+ ... )
+ >>> print(stmt)
+ UPDATE some_table SET status=:status
+ WHERE some_table.data = :data_1
+ RETURNING some_table.server_flag, some_table.updated_timestamp
+
+ The method may be invoked multiple times to add new entries to the
+ list of expressions to be returned.
+
+ .. versionadded:: 1.4.0b2 The method may be invoked multiple times to
+ add new entries to the list of expressions to be returned.
+
+ The given collection of column expressions should be derived from the
+ table that is the target of the INSERT, UPDATE, or DELETE. While
+ :class:`_schema.Column` objects are typical, the elements can also be
+ expressions:
+
+ .. sourcecode:: pycon+sql
+
+ >>> stmt = table.insert().returning(
+ ... (table.c.first_name + " " + table.c.last_name).label("fullname")
+ ... )
+ >>> print(stmt)
+ INSERT INTO some_table (first_name, last_name)
+ VALUES (:first_name, :last_name)
+ RETURNING some_table.first_name || :first_name_1 || some_table.last_name AS fullname
+
+ Upon compilation, a RETURNING clause, or database equivalent,
+ will be rendered within the statement. For INSERT and UPDATE,
+ the values are the newly inserted/updated values. For DELETE,
+ the values are those of the rows which were deleted.
+
+ Upon execution, the values of the columns to be returned are made
+ available via the result set and can be iterated using
+ :meth:`_engine.CursorResult.fetchone` and similar.
+ For DBAPIs which do not
+ natively support returning values (i.e. cx_oracle), SQLAlchemy will
+ approximate this behavior at the result level so that a reasonable
+ amount of behavioral neutrality is provided.
+
+ Note that not all databases/DBAPIs
+ support RETURNING. For those backends with no support,
+ an exception is raised upon compilation and/or execution.
+ For those who do support it, the functionality across backends
+ varies greatly, including restrictions on executemany()
+ and other statements which return multiple rows. Please
+ read the documentation notes for the database in use in
+ order to determine the availability of RETURNING.
+
+ .. seealso::
+
+ :meth:`.ValuesBase.return_defaults` - an alternative method tailored
+ towards efficient fetching of server-side defaults and triggers
+ for single-row INSERTs or UPDATEs.
+
+ :ref:`tutorial_insert_returning` - in the :ref:`unified_tutorial`
+
+ """ # noqa: E501
+ if self._return_defaults:
+ raise exc.InvalidRequestError(
+ "return_defaults() is already configured on this statement"
+ )
+ self._returning += tuple(
+ coercions.expect(roles.ColumnsClauseRole, c) for c in cols
+ )
+
+ @property
+ def _all_selected_columns(self):
+ return self._returning
+
+ @property
+ def exported_columns(self):
+ """Return the RETURNING columns as a column collection for this
+ statement.
+
+ .. versionadded:: 1.4
+
+ """
+ # TODO: no coverage here
+ return ColumnCollection(
+ (c.key, c) for c in self._all_selected_columns
+ ).as_immutable()
+
+ @_generative
+ def with_hint(self, text, selectable=None, dialect_name="*"):
+ """Add a table hint for a single table to this
+ INSERT/UPDATE/DELETE statement.
+
+ .. note::
+
+ :meth:`.UpdateBase.with_hint` currently applies only to
+ Microsoft SQL Server. For MySQL INSERT/UPDATE/DELETE hints, use
+ :meth:`.UpdateBase.prefix_with`.
+
+ The text of the hint is rendered in the appropriate
+ location for the database backend in use, relative
+ to the :class:`_schema.Table` that is the subject of this
+ statement, or optionally to that of the given
+ :class:`_schema.Table` passed as the ``selectable`` argument.
+
+ The ``dialect_name`` option will limit the rendering of a particular
+ hint to a particular backend. Such as, to add a hint
+ that only takes effect for SQL Server::
+
+ mytable.insert().with_hint("WITH (PAGLOCK)", dialect_name="mssql")
+
+ :param text: Text of the hint.
+ :param selectable: optional :class:`_schema.Table` that specifies
+ an element of the FROM clause within an UPDATE or DELETE
+ to be the subject of the hint - applies only to certain backends.
+ :param dialect_name: defaults to ``*``, if specified as the name
+ of a particular dialect, will apply these hints only when
+ that dialect is in use.
+ """
+ if selectable is None:
+ selectable = self.table
+
+ self._hints = self._hints.union({(selectable, dialect_name): text})
+
+ @property
+ def entity_description(self):
+ """Return a :term:`plugin-enabled` description of the table and/or
+ entity which this DML construct is operating against.
+
+ This attribute is generally useful when using the ORM, as an
+ extended structure which includes information about mapped
+ entities is returned. The section :ref:`queryguide_inspection`
+ contains more background.
+
+ For a Core statement, the structure returned by this accessor
+ is derived from the :attr:`.UpdateBase.table` attribute, and
+ refers to the :class:`.Table` being inserted, updated, or deleted::
+
+ >>> stmt = insert(user_table)
+ >>> stmt.entity_description
+ {
+ "name": "user_table",
+ "table": Table("user_table", ...)
+ }
+
+ .. versionadded:: 1.4.33
+
+ .. seealso::
+
+ :attr:`.UpdateBase.returning_column_descriptions`
+
+ :attr:`.Select.column_descriptions` - entity information for
+ a :func:`.select` construct
+
+ :ref:`queryguide_inspection` - ORM background
+
+ """
+ meth = DMLState.get_plugin_class(self).get_entity_description
+ return meth(self)
+
+ @property
+ def returning_column_descriptions(self):
+ """Return a :term:`plugin-enabled` description of the columns
+ which this DML construct is RETURNING against, in other words
+ the expressions established as part of :meth:`.UpdateBase.returning`.
+
+ This attribute is generally useful when using the ORM, as an
+ extended structure which includes information about mapped
+ entities is returned. The section :ref:`queryguide_inspection`
+ contains more background.
+
+ For a Core statement, the structure returned by this accessor is
+ derived from the same objects that are returned by the
+ :attr:`.UpdateBase.exported_columns` accessor::
+
+ >>> stmt = insert(user_table).returning(user_table.c.id, user_table.c.name)
+            >>> stmt.returning_column_descriptions
+ [
+ {
+ "name": "id",
+ "type": Integer,
+ "expr": Column("id", Integer(), table=<user>, ...)
+ },
+ {
+ "name": "name",
+ "type": String(),
+ "expr": Column("name", String(), table=<user>, ...)
+ },
+ ]
+
+ .. versionadded:: 1.4.33
+
+ .. seealso::
+
+ :attr:`.UpdateBase.entity_description`
+
+ :attr:`.Select.column_descriptions` - entity information for
+ a :func:`.select` construct
+
+ :ref:`queryguide_inspection` - ORM background
+
+ """ # noqa: E501
+ meth = DMLState.get_plugin_class(
+ self
+ ).get_returning_column_descriptions
+ return meth(self)
+
+
+class ValuesBase(UpdateBase):
+ """Supplies support for :meth:`.ValuesBase.values` to
+ INSERT and UPDATE constructs."""
+
+ __visit_name__ = "values_base"
+
+ _supports_multi_parameters = False
+ _preserve_parameter_order = False
+ select = None
+ _post_values_clause = None
+
+ _values = None
+ _multi_values = ()
+ _ordered_values = None
+ _select_names = None
+
+ _returning = ()
+
+ def __init__(self, table, values, prefixes):
+ self.table = coercions.expect(
+ roles.DMLTableRole, table, apply_propagate_attrs=self
+ )
+ if values is not None:
+ self.values.non_generative(self, values)
+ if prefixes:
+ self._setup_prefixes(prefixes)
+
+ @_generative
+ @_exclusive_against(
+ "_select_names",
+ "_ordered_values",
+ msgs={
+ "_select_names": "This construct already inserts from a SELECT",
+ "_ordered_values": "This statement already has ordered "
+ "values present",
+ },
+ )
+ def values(self, *args, **kwargs):
+ r"""Specify a fixed VALUES clause for an INSERT statement, or the SET
+ clause for an UPDATE.
+
+ Note that the :class:`_expression.Insert` and
+ :class:`_expression.Update`
+ constructs support
+ per-execution time formatting of the VALUES and/or SET clauses,
+ based on the arguments passed to :meth:`_engine.Connection.execute`.
+ However, the :meth:`.ValuesBase.values` method can be used to "fix" a
+ particular set of parameters into the statement.
+
+ Multiple calls to :meth:`.ValuesBase.values` will produce a new
+ construct, each one with the parameter list modified to include
+ the new parameters sent. In the typical case of a single
+ dictionary of parameters, the newly passed keys will replace
+ the same keys in the previous construct. In the case of a list-based
+ "multiple values" construct, each new list of values is extended
+ onto the existing list of values.
+
+ :param \**kwargs: key value pairs representing the string key
+ of a :class:`_schema.Column`
+ mapped to the value to be rendered into the
+ VALUES or SET clause::
+
+ users.insert().values(name="some name")
+
+ users.update().where(users.c.id==5).values(name="some name")
+
+ :param \*args: As an alternative to passing key/value parameters,
+ a dictionary, tuple, or list of dictionaries or tuples can be passed
+ as a single positional argument in order to form the VALUES or
+ SET clause of the statement. The forms that are accepted vary
+ based on whether this is an :class:`_expression.Insert` or an
+ :class:`_expression.Update` construct.
+
+ For either an :class:`_expression.Insert` or
+ :class:`_expression.Update`
+ construct, a single dictionary can be passed, which works the same as
+ that of the kwargs form::
+
+ users.insert().values({"name": "some name"})
+
+ users.update().values({"name": "some new name"})
+
+ Also for either form but more typically for the
+ :class:`_expression.Insert` construct, a tuple that contains an
+ entry for every column in the table is also accepted::
+
+ users.insert().values((5, "some name"))
+
+ The :class:`_expression.Insert` construct also supports being
+ passed a list of dictionaries or full-table-tuples, which on the
+ server will render the less common SQL syntax of "multiple values" -
+ this syntax is supported on backends such as SQLite, PostgreSQL,
+ MySQL, but not necessarily others::
+
+ users.insert().values([
+ {"name": "some name"},
+ {"name": "some other name"},
+ {"name": "yet another name"},
+ ])
+
+ The above form would render a multiple VALUES statement similar to::
+
+ INSERT INTO users (name) VALUES
+ (:name_1),
+ (:name_2),
+ (:name_3)
+
+ It is essential to note that **passing multiple values is
+ NOT the same as using traditional executemany() form**. The above
+ syntax is a **special** syntax not typically used. To emit an
+ INSERT statement against multiple rows, the normal method is
+ to pass a multiple values list to the
+ :meth:`_engine.Connection.execute`
+ method, which is supported by all database backends and is generally
+ more efficient for a very large number of parameters.
+
+ .. seealso::
+
+ :ref:`tutorial_multiple_parameters` - an introduction to
+ the traditional Core method of multiple parameter set
+ invocation for INSERTs and other statements.
+
+ .. versionchanged:: 1.0.0 an INSERT that uses a multiple-VALUES
+ clause, even a list of length one,
+ implies that the :paramref:`_expression.Insert.inline`
+ flag is set to
+ True, indicating that the statement will not attempt to fetch
+ the "last inserted primary key" or other defaults. The
+ statement deals with an arbitrary number of rows, so the
+ :attr:`_engine.CursorResult.inserted_primary_key`
+ accessor does not
+ apply.
+
+ .. versionchanged:: 1.0.0 A multiple-VALUES INSERT now supports
+ columns with Python side default values and callables in the
+ same way as that of an "executemany" style of invocation; the
+ callable is invoked for each row. See :ref:`bug_3288`
+ for other details.
+
+ The UPDATE construct also supports rendering the SET parameters
+ in a specific order. For this feature refer to the
+ :meth:`_expression.Update.ordered_values` method.
+
+ .. seealso::
+
+ :meth:`_expression.Update.ordered_values`
+
+
+ """
+ if args:
+ # positional case. this is currently expensive. we don't
+ # yet have positional-only args so we have to check the length.
+ # then we need to check multiparams vs. single dictionary.
+ # since the parameter format is needed in order to determine
+ # a cache key, we need to determine this up front.
+ arg = args[0]
+
+ if kwargs:
+ raise exc.ArgumentError(
+ "Can't pass positional and kwargs to values() "
+ "simultaneously"
+ )
+ elif len(args) > 1:
+ raise exc.ArgumentError(
+ "Only a single dictionary/tuple or list of "
+ "dictionaries/tuples is accepted positionally."
+ )
+
+ elif not self._preserve_parameter_order and isinstance(
+ arg, collections_abc.Sequence
+ ):
+
+ if arg and isinstance(arg[0], (list, dict, tuple)):
+ self._multi_values += (arg,)
+ return
+
+ # tuple values
+ arg = {c.key: value for c, value in zip(self.table.c, arg)}
+ elif self._preserve_parameter_order and not isinstance(
+ arg, collections_abc.Sequence
+ ):
+ raise ValueError(
+ "When preserve_parameter_order is True, "
+ "values() only accepts a list of 2-tuples"
+ )
+
+ else:
+ # kwarg path. this is the most common path for non-multi-params
+ # so this is fairly quick.
+ arg = kwargs
+ if args:
+ raise exc.ArgumentError(
+ "Only a single dictionary/tuple or list of "
+ "dictionaries/tuples is accepted positionally."
+ )
+
+ # for top level values(), convert literals to anonymous bound
+ # parameters at statement construction time, so that these values can
+ # participate in the cache key process like any other ClauseElement.
+ # crud.py now intercepts bound parameters with unique=True from here
+ # and ensures they get the "crud"-style name when rendered.
+
+ kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
+
+ if self._preserve_parameter_order:
+ self._ordered_values = kv_generator(self, arg)
+ else:
+ arg = {k: v for k, v in kv_generator(self, arg.items())}
+ if self._values:
+ self._values = self._values.union(arg)
+ else:
+ self._values = util.immutabledict(arg)
+
+ @_generative
+ @_exclusive_against(
+ "_returning",
+ msgs={
+ "_returning": "RETURNING is already configured on this statement"
+ },
+ defaults={"_returning": _returning},
+ )
+ def return_defaults(self, *cols):
+ """Make use of a :term:`RETURNING` clause for the purpose
+ of fetching server-side expressions and defaults.
+
+ E.g.::
+
+ stmt = table.insert().values(data='newdata').return_defaults()
+
+ result = connection.execute(stmt)
+
+ server_created_at = result.returned_defaults['created_at']
+
+ When used against a backend that supports RETURNING, all column
+ values generated by SQL expression or server-side-default will be
+ added to any existing RETURNING clause, provided that
+ :meth:`.UpdateBase.returning` is not used simultaneously. The column
+ values will then be available on the result using the
+ :attr:`_engine.CursorResult.returned_defaults` accessor as
+ a dictionary,
+ referring to values keyed to the :class:`_schema.Column`
+ object as well as
+ its ``.key``.
+
+ This method differs from :meth:`.UpdateBase.returning` in these ways:
+
+ 1. :meth:`.ValuesBase.return_defaults` is only intended for use with an
+ INSERT or an UPDATE statement that matches exactly one row per
+ parameter set. While the RETURNING construct in the general sense
+ supports multiple rows for a multi-row UPDATE or DELETE statement,
+ or for special cases of INSERT that return multiple rows (e.g.
+ INSERT from SELECT, multi-valued VALUES clause),
+ :meth:`.ValuesBase.return_defaults` is intended only for an
+ "ORM-style" single-row INSERT/UPDATE statement. The row
+ returned by the statement is also consumed implicitly when
+ :meth:`.ValuesBase.return_defaults` is used. By contrast,
+ :meth:`.UpdateBase.returning` leaves the RETURNING result-set intact
+ with a collection of any number of rows.
+
+ 2. It is compatible with the existing logic to fetch auto-generated
+ primary key values, also known as "implicit returning". Backends
+ that support RETURNING will automatically make use of RETURNING in
+ order to fetch the value of newly generated primary keys; while the
+ :meth:`.UpdateBase.returning` method circumvents this behavior,
+ :meth:`.ValuesBase.return_defaults` leaves it intact.
+
+ 3. It can be called against any backend. Backends that don't support
+ RETURNING will skip the usage of the feature, rather than raising
+ an exception. The return value of
+ :attr:`_engine.CursorResult.returned_defaults` will be ``None``
+
+ 4. An INSERT statement invoked with executemany() is supported if the
+ backend database driver supports the
+ ``insert_executemany_returning`` feature, currently this includes
+ PostgreSQL with psycopg2. When executemany is used, the
+ :attr:`_engine.CursorResult.returned_defaults_rows` and
+ :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors
+ will return the inserted defaults and primary keys.
+
+ .. versionadded:: 1.4
+
+ :meth:`.ValuesBase.return_defaults` is used by the ORM to provide
+ an efficient implementation for the ``eager_defaults`` feature of
+ :func:`.mapper`.
+
+ :param cols: optional list of column key names or
+ :class:`_schema.Column`
+ objects. If omitted, all column expressions evaluated on the server
+ are added to the returning list.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :meth:`.UpdateBase.returning`
+
+ :attr:`_engine.CursorResult.returned_defaults`
+
+ :attr:`_engine.CursorResult.returned_defaults_rows`
+
+ :attr:`_engine.CursorResult.inserted_primary_key`
+
+ :attr:`_engine.CursorResult.inserted_primary_key_rows`
+
+ """
+ self._return_defaults = True
+ self._return_defaults_columns = cols
+
+
+class Insert(ValuesBase):
+ """Represent an INSERT construct.
+
+ The :class:`_expression.Insert` object is created using the
+ :func:`_expression.insert()` function.
+
+ """
+
+ __visit_name__ = "insert"
+
+ _supports_multi_parameters = True
+
+ select = None
+ include_insert_from_select_defaults = False
+
+ is_insert = True
+
+ _traverse_internals = (
+ [
+ ("table", InternalTraversal.dp_clauseelement),
+ ("_inline", InternalTraversal.dp_boolean),
+ ("_select_names", InternalTraversal.dp_string_list),
+ ("_values", InternalTraversal.dp_dml_values),
+ ("_multi_values", InternalTraversal.dp_dml_multi_values),
+ ("select", InternalTraversal.dp_clauseelement),
+ ("_post_values_clause", InternalTraversal.dp_clauseelement),
+ ("_returning", InternalTraversal.dp_clauseelement_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ("_return_defaults", InternalTraversal.dp_boolean),
+ (
+ "_return_defaults_columns",
+ InternalTraversal.dp_clauseelement_list,
+ ),
+ ]
+ + HasPrefixes._has_prefixes_traverse_internals
+ + DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
+ + HasCTE._has_ctes_traverse_internals
+ )
+
+ @ValuesBase._constructor_20_deprecations(
+ "insert",
+ "Insert",
+ [
+ "values",
+ "inline",
+ "bind",
+ "prefixes",
+ "returning",
+ "return_defaults",
+ ],
+ )
+ def __init__(
+ self,
+ table,
+ values=None,
+ inline=False,
+ bind=None,
+ prefixes=None,
+ returning=None,
+ return_defaults=False,
+ **dialect_kw
+ ):
+ """Construct an :class:`_expression.Insert` object.
+
+ E.g.::
+
+ from sqlalchemy import insert
+
+ stmt = (
+ insert(user_table).
+ values(name='username', fullname='Full Username')
+ )
+
+ Similar functionality is available via the
+ :meth:`_expression.TableClause.insert` method on
+ :class:`_schema.Table`.
+
+ .. seealso::
+
+ :ref:`tutorial_core_insert` - in the :ref:`unified_tutorial`
+
+
+ :param table: :class:`_expression.TableClause`
+ which is the subject of the
+ insert.
+
+ :param values: collection of values to be inserted; see
+ :meth:`_expression.Insert.values`
+ for a description of allowed formats here.
+ Can be omitted entirely; a :class:`_expression.Insert` construct
+ will also dynamically render the VALUES clause at execution time
+ based on the parameters passed to :meth:`_engine.Connection.execute`.
+
+ :param inline: if True, no attempt will be made to retrieve the
+ SQL-generated default values to be provided within the statement;
+ in particular,
+ this allows SQL expressions to be rendered 'inline' within the
+ statement without the need to pre-execute them beforehand; for
+ backends that support "returning", this turns off the "implicit
+ returning" feature for the statement.
+
+ If both :paramref:`_expression.Insert.values` and compile-time bind
+ parameters are present, the compile-time bind parameters override the
+ information specified within :paramref:`_expression.Insert.values` on a
+ per-key basis.
+
+ The keys within :paramref:`_expression.Insert.values` can be either
+ :class:`~sqlalchemy.schema.Column` objects or their string
+ identifiers. Each key may reference one of:
+
+ * a literal data value (i.e. string, number, etc.);
+ * a Column object;
+ * a SELECT statement.
+
+ If a ``SELECT`` statement is specified which references this
+ ``INSERT`` statement's table, the statement will be correlated
+ against the ``INSERT`` statement.
+
+ .. seealso::
+
+ :ref:`tutorial_core_insert` - in the :ref:`unified_tutorial`
+
+ """
+ super(Insert, self).__init__(table, values, prefixes)
+ self._bind = bind
+ self._inline = inline
+ if returning:
+ self._returning = returning
+ if dialect_kw:
+ self._validate_dialect_kwargs_deprecated(dialect_kw)
+
+ if return_defaults:
+ self._return_defaults = True
+ if not isinstance(return_defaults, bool):
+ self._return_defaults_columns = return_defaults
+
+ @_generative
+ def inline(self):
+ """Make this :class:`_expression.Insert` construct "inline" .
+
+ When set, no attempt will be made to retrieve the
+ SQL-generated default values to be provided within the statement;
+ in particular,
+ this allows SQL expressions to be rendered 'inline' within the
+ statement without the need to pre-execute them beforehand; for
+ backends that support "returning", this turns off the "implicit
+ returning" feature for the statement.
+
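+        For illustration, a minimal sketch assuming a hypothetical
+        ``table``::
+
+            stmt = table.insert().values(data="some data").inline()
+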
+
+ .. versionchanged:: 1.4 the :paramref:`_expression.Insert.inline`
+ parameter
+ is now superseded by the :meth:`_expression.Insert.inline` method.
+
+ """
+ self._inline = True
+
+ @_generative
+ def from_select(self, names, select, include_defaults=True):
+ """Return a new :class:`_expression.Insert` construct which represents
+ an ``INSERT...FROM SELECT`` statement.
+
+ e.g.::
+
+ sel = select(table1.c.a, table1.c.b).where(table1.c.c > 5)
+ ins = table2.insert().from_select(['a', 'b'], sel)
+
+ :param names: a sequence of string column names or
+ :class:`_schema.Column`
+ objects representing the target columns.
+ :param select: a :func:`_expression.select` construct,
+ :class:`_expression.FromClause`
+ or other construct which resolves into a
+ :class:`_expression.FromClause`,
+ such as an ORM :class:`_query.Query` object, etc. The order of
+ columns returned from this FROM clause should correspond to the
+ order of columns sent as the ``names`` parameter; while this
+ is not checked before passing along to the database, the database
+ would normally raise an exception if these column lists don't
+ correspond.
+ :param include_defaults: if True, non-server default values and
+ SQL expressions as specified on :class:`_schema.Column` objects
+ (as documented in :ref:`metadata_defaults_toplevel`) not
+ otherwise specified in the list of names will be rendered
+ into the INSERT and SELECT statements, so that these values are also
+ included in the data to be inserted.
+
+ .. note:: A Python-side default that uses a Python callable function
+ will only be invoked **once** for the whole statement, and **not
+ per row**.
+
+ .. versionadded:: 1.0.0 - :meth:`_expression.Insert.from_select`
+ now renders
+ Python-side and SQL expression column defaults into the
+ SELECT statement for columns otherwise not included in the
+ list of column names.
+
+ .. versionchanged:: 1.0.0 an INSERT that uses FROM SELECT
+ implies that the :paramref:`_expression.insert.inline`
+ flag is set to
+ True, indicating that the statement will not attempt to fetch
+ the "last inserted primary key" or other defaults. The statement
+ deals with an arbitrary number of rows, so the
+ :attr:`_engine.CursorResult.inserted_primary_key`
+ accessor does not apply.
+
+ """
+
+ if self._values:
+ raise exc.InvalidRequestError(
+ "This construct already inserts value expressions"
+ )
+
+ self._select_names = names
+ self._inline = True
+ self.include_insert_from_select_defaults = include_defaults
+ self.select = coercions.expect(roles.DMLSelectRole, select)
+
+
+class DMLWhereBase(object):
+ _where_criteria = ()
+
+ @_generative
+ def where(self, *whereclause):
+ """Return a new construct with the given expression(s) added to
+ its WHERE clause, joined to the existing clause via AND, if any.
+
+ Both :meth:`_dml.Update.where` and :meth:`_dml.Delete.where`
+ support multiple-table forms, including database-specific
+ ``UPDATE...FROM`` as well as ``DELETE..USING``. For backends that
+ don't have multiple-table support, a backend agnostic approach
+ to using multiple tables is to make use of correlated subqueries.
+ See the linked tutorial sections below for examples.
+
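+        For illustration, a minimal sketch assuming a hypothetical ``users``
+        table; successive calls are joined using AND::
+
+            stmt = (
+                users.update()
+                .where(users.c.id > 10)
+                .where(users.c.name == "patrick")
+                .values(name="patrick the younger")
+            )
+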
+ .. seealso::
+
+ :ref:`tutorial_correlated_updates`
+
+ :ref:`tutorial_update_from`
+
+ :ref:`tutorial_multi_table_deletes`
+
+ """
+
+ for criterion in whereclause:
+ where_criteria = coercions.expect(roles.WhereHavingRole, criterion)
+ self._where_criteria += (where_criteria,)
+
+ def filter(self, *criteria):
+ """A synonym for the :meth:`_dml.DMLWhereBase.where` method.
+
+ .. versionadded:: 1.4
+
+ """
+
+ return self.where(*criteria)
+
+ def _filter_by_zero(self):
+ return self.table
+
+ def filter_by(self, **kwargs):
+ r"""apply the given filtering criterion as a WHERE clause
+ to this select.
+
+ """
+ from_entity = self._filter_by_zero()
+
+ clauses = [
+ _entity_namespace_key(from_entity, key) == value
+ for key, value in kwargs.items()
+ ]
+ return self.filter(*clauses)
+
+ @property
+ def whereclause(self):
+ """Return the completed WHERE clause for this :class:`.DMLWhereBase`
+ statement.
+
+ This assembles the current collection of WHERE criteria
+ into a single :class:`_expression.BooleanClauseList` construct.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ return BooleanClauseList._construct_for_whereclause(
+ self._where_criteria
+ )
+
+
+class Update(DMLWhereBase, ValuesBase):
+ """Represent an Update construct.
+
+ The :class:`_expression.Update` object is created using the
+ :func:`_expression.update()` function.
+
+ """
+
+ __visit_name__ = "update"
+
+ is_update = True
+
+ _traverse_internals = (
+ [
+ ("table", InternalTraversal.dp_clauseelement),
+ ("_where_criteria", InternalTraversal.dp_clauseelement_list),
+ ("_inline", InternalTraversal.dp_boolean),
+ ("_ordered_values", InternalTraversal.dp_dml_ordered_values),
+ ("_values", InternalTraversal.dp_dml_values),
+ ("_returning", InternalTraversal.dp_clauseelement_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ("_return_defaults", InternalTraversal.dp_boolean),
+ (
+ "_return_defaults_columns",
+ InternalTraversal.dp_clauseelement_list,
+ ),
+ ]
+ + HasPrefixes._has_prefixes_traverse_internals
+ + DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
+ + HasCTE._has_ctes_traverse_internals
+ )
+
+ @ValuesBase._constructor_20_deprecations(
+ "update",
+ "Update",
+ [
+ "whereclause",
+ "values",
+ "inline",
+ "bind",
+ "prefixes",
+ "returning",
+ "return_defaults",
+ "preserve_parameter_order",
+ ],
+ )
+ def __init__(
+ self,
+ table,
+ whereclause=None,
+ values=None,
+ inline=False,
+ bind=None,
+ prefixes=None,
+ returning=None,
+ return_defaults=False,
+ preserve_parameter_order=False,
+ **dialect_kw
+ ):
+ r"""Construct an :class:`_expression.Update` object.
+
+ E.g.::
+
+ from sqlalchemy import update
+
+ stmt = (
+ update(user_table).
+ where(user_table.c.id == 5).
+ values(name='user #5')
+ )
+
+ Similar functionality is available via the
+ :meth:`_expression.TableClause.update` method on
+ :class:`_schema.Table`.
+
+ :param table: A :class:`_schema.Table`
+ object representing the database
+ table to be updated.
+
+ :param whereclause: Optional SQL expression describing the ``WHERE``
+ condition of the ``UPDATE`` statement; is equivalent to using the
+ more modern :meth:`~Update.where()` method to specify the ``WHERE``
+ clause.
+
+ :param values:
+ Optional dictionary which specifies the ``SET`` conditions of the
+ ``UPDATE``. If left as ``None``, the ``SET``
+ conditions are determined from those parameters passed to the
+ statement during the execution and/or compilation of the
+ statement. When compiled standalone without any parameters,
+ the ``SET`` clause is generated for all columns.
+
+ Modern applications may prefer to use the generative
+ :meth:`_expression.Update.values` method to set the values of the
+ UPDATE statement.
+
+ :param inline:
+ if True, SQL defaults present on :class:`_schema.Column` objects via
+ the ``default`` keyword will be compiled 'inline' into the statement
+ and not pre-executed. This means that their values will not
+ be available in the dictionary returned from
+ :meth:`_engine.CursorResult.last_updated_params`.
+
+ :param preserve_parameter_order: if True, the update statement is
+ expected to receive parameters **only** via the
+ :meth:`_expression.Update.values` method,
+ and they must be passed as a Python
+ ``list`` of 2-tuples. The rendered UPDATE statement will emit the SET
+ clause for each referenced column maintaining this order.
+
+ .. versionadded:: 1.0.10
+
+ .. seealso::
+
+ :ref:`updates_order_parameters` - illustrates the
+ :meth:`_expression.Update.ordered_values` method.
+
+ If both ``values`` and compile-time bind parameters are present, the
+ compile-time bind parameters override the information specified
+ within ``values`` on a per-key basis.
+
+ The keys within ``values`` can be either :class:`_schema.Column`
+ objects or their string identifiers (specifically the "key" of the
+ :class:`_schema.Column`, normally but not necessarily equivalent to
+ its "name"). Normally, the
+ :class:`_schema.Column` objects used here are expected to be
+ part of the target :class:`_schema.Table` that is the table
+ to be updated. However when using MySQL, a multiple-table
+ UPDATE statement can refer to columns from any of
+ the tables referred to in the WHERE clause.
+
+ The values referred to in ``values`` are typically:
+
+ * a literal data value (i.e. string, number, etc.)
+ * a SQL expression, such as a related :class:`_schema.Column`,
+ a scalar-returning :func:`_expression.select` construct,
+ etc.
+
+ When combining :func:`_expression.select` constructs within the
+ values clause of an :func:`_expression.update`
+ construct, the subquery represented
+ by the :func:`_expression.select` should be *correlated* to the
+ parent table, that is, providing criteria which link the table inside
+ the subquery to the outer table being updated::
+
+ users.update().values(
+ name=select(addresses.c.email_address).\
+ where(addresses.c.user_id==users.c.id).\
+ scalar_subquery()
+ )
+
+ .. seealso::
+
+ :ref:`inserts_and_updates` - SQL Expression
+ Language Tutorial
+
+
+ """
+ self._preserve_parameter_order = preserve_parameter_order
+ super(Update, self).__init__(table, values, prefixes)
+ self._bind = bind
+ if returning:
+ self._returning = returning
+ if whereclause is not None:
+ self._where_criteria += (
+ coercions.expect(roles.WhereHavingRole, whereclause),
+ )
+ self._inline = inline
+ if dialect_kw:
+ self._validate_dialect_kwargs_deprecated(dialect_kw)
+ self._return_defaults = return_defaults
+
+ @_generative
+ def ordered_values(self, *args):
+ """Specify the VALUES clause of this UPDATE statement with an explicit
+ parameter ordering that will be maintained in the SET clause of the
+ resulting UPDATE statement.
+
+ E.g.::
+
+ stmt = table.update().ordered_values(
+ ("name", "ed"), ("ident": "foo")
+ )
+
+ .. seealso::
+
+ :ref:`tutorial_parameter_ordered_updates` - full example of the
+ :meth:`_expression.Update.ordered_values` method.
+
+ .. versionchanged:: 1.4 The :meth:`_expression.Update.ordered_values`
+ method
+ supersedes the
+ :paramref:`_expression.update.preserve_parameter_order`
+ parameter, which will be removed in SQLAlchemy 2.0.
+
+ """
+ if self._values:
+ raise exc.ArgumentError(
+ "This statement already has values present"
+ )
+ elif self._ordered_values:
+ raise exc.ArgumentError(
+ "This statement already has ordered values present"
+ )
+
+ kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
+ self._ordered_values = kv_generator(self, args)
+
+ @_generative
+ def inline(self):
+ """Make this :class:`_expression.Update` construct "inline" .
+
+ When set, SQL defaults present on :class:`_schema.Column`
+ objects via the
+ ``default`` keyword will be compiled 'inline' into the statement and
+ not pre-executed. This means that their values will not be available
+ in the dictionary returned from
+ :meth:`_engine.CursorResult.last_updated_params`.
+
+ .. versionchanged:: 1.4 the :paramref:`_expression.update.inline`
+ parameter
+ is now superseded by the :meth:`_expression.Update.inline` method.
+
+ """
+ self._inline = True
+
+
+class Delete(DMLWhereBase, UpdateBase):
+ """Represent a DELETE construct.
+
+ The :class:`_expression.Delete` object is created using the
+ :func:`_expression.delete()` function.
+
+ """
+
+ __visit_name__ = "delete"
+
+ is_delete = True
+
+ _traverse_internals = (
+ [
+ ("table", InternalTraversal.dp_clauseelement),
+ ("_where_criteria", InternalTraversal.dp_clauseelement_list),
+ ("_returning", InternalTraversal.dp_clauseelement_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ]
+ + HasPrefixes._has_prefixes_traverse_internals
+ + DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
+ + HasCTE._has_ctes_traverse_internals
+ )
+
+ @ValuesBase._constructor_20_deprecations(
+ "delete",
+ "Delete",
+ ["whereclause", "values", "bind", "prefixes", "returning"],
+ )
+ def __init__(
+ self,
+ table,
+ whereclause=None,
+ bind=None,
+ returning=None,
+ prefixes=None,
+ **dialect_kw
+ ):
+ r"""Construct :class:`_expression.Delete` object.
+
+ E.g.::
+
+ from sqlalchemy import delete
+
+ stmt = (
+ delete(user_table).
+ where(user_table.c.id == 5)
+ )
+
+ Similar functionality is available via the
+ :meth:`_expression.TableClause.delete` method on
+ :class:`_schema.Table`.
+
+ .. seealso::
+
+ :ref:`inserts_and_updates` - in the
+ :ref:`1.x tutorial <sqlexpression_toplevel>`
+
+ :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial`
+
+
+ :param table: The table to delete rows from.
+
+ :param whereclause: Optional SQL expression describing the ``WHERE``
+ condition of the ``DELETE`` statement; is equivalent to using the
+ more modern :meth:`~Delete.where()` method to specify the ``WHERE``
+ clause.
+
+ .. seealso::
+
+ :ref:`deletes` - SQL Expression Tutorial
+
+ """
+ self._bind = bind
+ self.table = coercions.expect(
+ roles.DMLTableRole, table, apply_propagate_attrs=self
+ )
+ if returning:
+ self._returning = returning
+
+ if prefixes:
+ self._setup_prefixes(prefixes)
+
+ if whereclause is not None:
+ self._where_criteria += (
+ coercions.expect(roles.WhereHavingRole, whereclause),
+ )
+
+ if dialect_kw:
+ self._validate_dialect_kwargs_deprecated(dialect_kw)
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
new file mode 100644
index 0000000..268c0d6
--- /dev/null
+++ b/lib/sqlalchemy/sql/elements.py
@@ -0,0 +1,5415 @@
+# sql/elements.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Core SQL expression elements, including :class:`_expression.ClauseElement`,
+:class:`_expression.ColumnElement`, and derived classes.
+
+"""
+
+from __future__ import unicode_literals
+
+import itertools
+import operator
+import re
+
+from . import coercions
+from . import operators
+from . import roles
+from . import traversals
+from . import type_api
+from .annotation import Annotated
+from .annotation import SupportsWrappingAnnotations
+from .base import _clone
+from .base import _generative
+from .base import Executable
+from .base import HasMemoized
+from .base import Immutable
+from .base import NO_ARG
+from .base import PARSE_AUTOCOMMIT
+from .base import SingletonConstant
+from .coercions import _document_text_coercion
+from .traversals import HasCopyInternals
+from .traversals import MemoizedHasCacheKey
+from .traversals import NO_CACHE
+from .visitors import cloned_traverse
+from .visitors import InternalTraversal
+from .visitors import traverse
+from .visitors import Traversible
+from .. import exc
+from .. import inspection
+from .. import util
+
+
+def collate(expression, collation):
+ """Return the clause ``expression COLLATE collation``.
+
+ e.g.::
+
+ collate(mycolumn, 'utf8_bin')
+
+ produces::
+
+ mycolumn COLLATE utf8_bin
+
+ The collation expression is also quoted if it is a case sensitive
+ identifier, e.g. contains uppercase characters.
+
+ .. versionchanged:: 1.2 quoting is automatically applied to COLLATE
+ expressions if they are case sensitive.
+
+ """
+
+ expr = coercions.expect(roles.ExpressionElementRole, expression)
+ return BinaryExpression(
+ expr, CollationClause(collation), operators.collate, type_=expr.type
+ )
+
+
+def between(expr, lower_bound, upper_bound, symmetric=False):
+ """Produce a ``BETWEEN`` predicate clause.
+
+ E.g.::
+
+ from sqlalchemy import between
+ stmt = select(users_table).where(between(users_table.c.id, 5, 7))
+
+ Would produce SQL resembling::
+
+ SELECT id, name FROM user WHERE id BETWEEN :id_1 AND :id_2
+
+ The :func:`.between` function is a standalone version of the
+ :meth:`_expression.ColumnElement.between` method available on all
+ SQL expressions, as in::
+
+ stmt = select(users_table).where(users_table.c.id.between(5, 7))
+
+ All arguments passed to :func:`.between`, including the left side
+ column expression, are coerced from Python scalar values if
+ the value is not a :class:`_expression.ColumnElement` subclass.
+ For example,
+ three fixed values can be compared as in::
+
+ print(between(5, 3, 7))
+
+ Which would produce::
+
+ :param_1 BETWEEN :param_2 AND :param_3
+
+ :param expr: a column expression, typically a
+ :class:`_expression.ColumnElement`
+ instance or alternatively a Python scalar expression to be coerced
+ into a column expression, serving as the left side of the ``BETWEEN``
+ expression.
+
+ :param lower_bound: a column or Python scalar expression serving as the
+ lower bound of the right side of the ``BETWEEN`` expression.
+
+ :param upper_bound: a column or Python scalar expression serving as the
+ upper bound of the right side of the ``BETWEEN`` expression.
+
+ :param symmetric: if True, will render " BETWEEN SYMMETRIC ". Note
+ that not all databases support this syntax.
+
+ .. versionadded:: 0.9.5
+
+ .. seealso::
+
+ :meth:`_expression.ColumnElement.between`
+
+ """
+ expr = coercions.expect(roles.ExpressionElementRole, expr)
+ return expr.between(lower_bound, upper_bound, symmetric=symmetric)
+
+
+def literal(value, type_=None):
+ r"""Return a literal clause, bound to a bind parameter.
+
+ Literal clauses are created automatically when non-
+ :class:`_expression.ClauseElement` objects (such as strings, ints, dates,
+ etc.) are
+ used in a comparison operation with a :class:`_expression.ColumnElement`
+ subclass,
+ such as a :class:`~sqlalchemy.schema.Column` object. Use this function
+ to force the generation of a literal clause, which will be created as a
+ :class:`BindParameter` with a bound value.
+
+ :param value: the value to be bound. Can be any Python object supported by
+ the underlying DB-API, or is translatable via the given type argument.
+
+ :param type\_: an optional :class:`~sqlalchemy.types.TypeEngine` which
+ will provide bind-parameter translation for this literal.
+
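+ E.g., a minimal sketch, forcing a plain Python string to be rendered
+ as a bound literal (the value shown is arbitrary)::
+
+ from sqlalchemy import literal, select
+
+ stmt = select(literal("some value"))
+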
+ """
+ return coercions.expect(roles.LiteralValueRole, value, type_=type_)
+
+
+def outparam(key, type_=None):
+ r"""Create an 'OUT' parameter for usage in functions (stored procedures),
+ for databases which support them.
+
+ The ``outparam`` can be used like a regular function parameter.
+ The "output" value will be available from the
+ :class:`~sqlalchemy.engine.CursorResult` object via its ``out_parameters``
+ attribute, which returns a dictionary containing the values.
+
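+ E.g., a rough sketch, assuming a hypothetical stored procedure
+ ``some_procedure`` with a single OUT parameter on a backend such as
+ Oracle; the exact invocation form depends on the backend and DBAPI::
+
+ from sqlalchemy import func, outparam, Integer
+
+ stmt = func.some_procedure(outparam("result_value", type_=Integer))
+ result = connection.execute(stmt)
+ value = result.out_parameters["result_value"]
+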
+ """
+ return BindParameter(key, None, type_=type_, unique=False, isoutparam=True)
+
+
+def not_(clause):
+ """Return a negation of the given clause, i.e. ``NOT(clause)``.
+
+ The ``~`` operator is also overloaded on all
+ :class:`_expression.ColumnElement` subclasses to produce the
+ same result.
+
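+ E.g., a minimal sketch, assuming a hypothetical ``users_table`` with a
+ ``name`` column::
+
+ from sqlalchemy import not_, select
+
+ stmt = select(users_table).where(not_(users_table.c.name == 'wendy'))
+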
+ """
+ return operators.inv(coercions.expect(roles.ExpressionElementRole, clause))
+
+
+@inspection._self_inspects
+class ClauseElement(
+ roles.SQLRole,
+ SupportsWrappingAnnotations,
+ MemoizedHasCacheKey,
+ HasCopyInternals,
+ Traversible,
+):
+ """Base class for elements of a programmatically constructed SQL
+ expression.
+
+ """
+
+ __visit_name__ = "clause"
+
+ _propagate_attrs = util.immutabledict()
+ """like annotations, however these propagate outwards liberally
+ as SQL constructs are built, and are set up at construction time.
+
+ """
+
+ supports_execution = False
+
+ stringify_dialect = "default"
+
+ _from_objects = []
+ bind = None
+ description = None
+ _is_clone_of = None
+
+ is_clause_element = True
+ is_selectable = False
+
+ _is_textual = False
+ _is_from_clause = False
+ _is_returns_rows = False
+ _is_text_clause = False
+ _is_from_container = False
+ _is_select_container = False
+ _is_select_statement = False
+ _is_bind_parameter = False
+ _is_clause_list = False
+ _is_lambda_element = False
+ _is_singleton_constant = False
+ _is_immutable = False
+ _is_star = False
+
+ _order_by_label_element = None
+
+ _cache_key_traversal = None
+
+ def _set_propagate_attrs(self, values):
+ # usually, self._propagate_attrs is empty here. one case where it's
+ # not is a subquery against ORM select, that is then pulled as a
+ # property of an aliased class. should all be good
+
+ # assert not self._propagate_attrs
+
+ self._propagate_attrs = util.immutabledict(values)
+ return self
+
+ def _clone(self, **kw):
+ """Create a shallow copy of this ClauseElement.
+
+ This method may be used by a generative API. It's also used as
+ part of the "deep" copy afforded by a traversal that combines
+ the _copy_internals() method.
+
+ """
+ skip = self._memoized_keys
+ c = self.__class__.__new__(self.__class__)
+
+ if skip:
+ # ensure this iteration remains atomic
+ c.__dict__ = {
+ k: v for k, v in self.__dict__.copy().items() if k not in skip
+ }
+ else:
+ c.__dict__ = self.__dict__.copy()
+
+ # this is a marker that helps to "equate" clauses to each other
+ # when a Select returns its list of FROM clauses. the cloning
+ # process leaves around a lot of remnants of the previous clause
+ # typically in the form of column expressions still attached to the
+ # old table.
+ cc = self._is_clone_of
+ c._is_clone_of = cc if cc is not None else self
+ return c
+
+ def _negate_in_binary(self, negated_op, original_op):
+ """a hook to allow the right side of a binary expression to respond
+ to a negation of the binary expression.
+
+ Used for the special case of expanding bind parameter with IN.
+
+ """
+ return self
+
+ def _with_binary_element_type(self, type_):
+ """in the context of binary expression, convert the type of this
+ object to the one given.
+
+ applies only to :class:`_expression.ColumnElement` classes.
+
+ """
+ return self
+
+ @property
+ def _constructor(self):
+ """return the 'constructor' for this ClauseElement.
+
+ This is for the purpose of creating a new object of
+ this type. Usually, it's just the element's __class__.
+ However, the "Annotated" version of the object overrides
+ to return the class of its proxied element.
+
+ """
+ return self.__class__
+
+ @HasMemoized.memoized_attribute
+ def _cloned_set(self):
+ """Return the set consisting all cloned ancestors of this
+ ClauseElement.
+
+ Includes this ClauseElement. This accessor tends to be used for
+ FromClause objects to identify 'equivalent' FROM clauses, regardless
+ of transformative operations.
+
+ """
+ s = util.column_set()
+ f = self
+
+ # note this creates a cycle, asserted in test_memusage. however,
+ # turning this into a plain @property adds tens of thousands of method
+ # calls to Core / ORM performance tests, so the small overhead
+ # introduced by the relatively small amount of short term cycles
+ # produced here is preferable
+ while f is not None:
+ s.add(f)
+ f = f._is_clone_of
+ return s
+
+ @property
+ def entity_namespace(self):
+ raise AttributeError(
+ "This SQL expression has no entity namespace "
+ "with which to filter from."
+ )
+
+ def __getstate__(self):
+ d = self.__dict__.copy()
+ d.pop("_is_clone_of", None)
+ d.pop("_generate_cache_key", None)
+ return d
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options, _force=False
+ ):
+ if _force or self.supports_execution:
+ return connection._execute_clauseelement(
+ self, multiparams, params, execution_options
+ )
+ else:
+ raise exc.ObjectNotExecutableError(self)
+
+ def unique_params(self, *optionaldict, **kwargs):
+ """Return a copy with :func:`_expression.bindparam` elements
+ replaced.
+
+ Same functionality as :meth:`_expression.ClauseElement.params`,
+ except adds `unique=True`
+ to affected bind parameters so that multiple statements can be
+ used.
+
+ """
+ return self._replace_params(True, optionaldict, kwargs)
+
+ def params(self, *optionaldict, **kwargs):
+ """Return a copy with :func:`_expression.bindparam` elements
+ replaced.
+
+ Returns a copy of this ClauseElement with
+ :func:`_expression.bindparam`
+ elements replaced with values taken from the given dictionary::
+
+ >>> clause = column('x') + bindparam('foo')
+ >>> print(clause.compile().params)
+ {'foo':None}
+ >>> print(clause.params({'foo':7}).compile().params)
+ {'foo':7}
+
+ """
+ return self._replace_params(False, optionaldict, kwargs)
+
+ def _replace_params(self, unique, optionaldict, kwargs):
+
+ if len(optionaldict) == 1:
+ kwargs.update(optionaldict[0])
+ elif len(optionaldict) > 1:
+ raise exc.ArgumentError(
+ "params() takes zero or one positional dictionary argument"
+ )
+
+ def visit_bindparam(bind):
+ if bind.key in kwargs:
+ bind.value = kwargs[bind.key]
+ bind.required = False
+ if unique:
+ bind._convert_to_unique()
+
+ return cloned_traverse(
+ self,
+ {"maintain_key": True, "detect_subquery_cols": True},
+ {"bindparam": visit_bindparam},
+ )
+
+ def compare(self, other, **kw):
+ r"""Compare this :class:`_expression.ClauseElement` to
+ the given :class:`_expression.ClauseElement`.
+
+ Subclasses should override the default behavior, which is a
+ straight identity comparison.
+
+ \**kw are arguments consumed by subclass ``compare()`` methods and
+ may be used to modify the criteria for comparison
+ (see :class:`_expression.ColumnElement`).
+
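+ E.g., a minimal illustration using two ad-hoc column constructs::
+
+ >>> from sqlalchemy import column
+ >>> column('x').compare(column('x'))
+ True
+ >>> column('x').compare(column('y'))
+ False
+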
+ """
+ return traversals.compare(self, other, **kw)
+
+ def self_group(self, against=None):
+ """Apply a 'grouping' to this :class:`_expression.ClauseElement`.
+
+ This method is overridden by subclasses to return a "grouping"
+ construct, i.e. parenthesis. In particular it's used by "binary"
+ expressions to provide a grouping around themselves when placed into a
+ larger expression, as well as by :func:`_expression.select`
+ constructs when placed into the FROM clause of another
+ :func:`_expression.select`. (Note that subqueries should be
+ normally created using the :meth:`_expression.Select.alias` method,
+ as many
+ platforms require nested SELECT statements to be named).
+
+ As expressions are composed together, the application of
+ :meth:`self_group` is automatic - end-user code should never
+ need to use this method directly. Note that SQLAlchemy's
+ clause constructs take operator precedence into account -
+ so parenthesis might not be needed, for example, in
+ an expression like ``x OR (y AND z)`` - AND takes precedence
+ over OR.
+
+ The base :meth:`self_group` method of
+ :class:`_expression.ClauseElement`
+ just returns self.
+ """
+ return self
+
+ def _ungroup(self):
+ """Return this :class:`_expression.ClauseElement`
+ without any groupings.
+ """
+
+ return self
+
+ @util.preload_module("sqlalchemy.engine.default")
+ @util.preload_module("sqlalchemy.engine.url")
+ def compile(self, bind=None, dialect=None, **kw):
+ """Compile this SQL expression.
+
+ The return value is a :class:`~.Compiled` object.
+ Calling ``str()`` or ``unicode()`` on the returned value will yield a
+ string representation of the result. The
+ :class:`~.Compiled` object also can return a
+ dictionary of bind parameter names and values
+ using the ``params`` accessor.
+
+ :param bind: An ``Engine`` or ``Connection`` from which a
+ ``Compiled`` will be acquired. This argument takes precedence over
+ this :class:`_expression.ClauseElement`'s bound engine, if any.
+
+ :param column_keys: Used for INSERT and UPDATE statements, a list of
+ column names which should be present in the VALUES clause of the
+ compiled statement. If ``None``, all columns from the target table
+ object are rendered.
+
+ :param dialect: A ``Dialect`` instance from which a ``Compiled``
+ will be acquired. This argument takes precedence over the `bind`
+ argument as well as this :class:`_expression.ClauseElement`
+ 's bound engine,
+ if any.
+
+ :param compile_kwargs: optional dictionary of additional parameters
+ that will be passed through to the compiler within all "visit"
+ methods. This allows any custom flag to be passed through to
+ a custom compilation construct, for example. It is also used
+ for the case of passing the ``literal_binds`` flag through::
+
+ from sqlalchemy.sql import table, column, select
+
+ t = table('t', column('x'))
+
+ s = select(t).where(t.c.x == 5)
+
+ print(s.compile(compile_kwargs={"literal_binds": True}))
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :ref:`faq_sql_expression_string`
+
+ """
+
+ if not dialect:
+ if bind:
+ dialect = bind.dialect
+ elif self.bind:
+ dialect = self.bind.dialect
+ else:
+ if self.stringify_dialect == "default":
+ default = util.preloaded.engine_default
+ dialect = default.StrCompileDialect()
+ else:
+ url = util.preloaded.engine_url
+ dialect = url.URL.create(
+ self.stringify_dialect
+ ).get_dialect()()
+
+ return self._compiler(dialect, **kw)
+
+ def _compile_w_cache(
+ self,
+ dialect,
+ compiled_cache=None,
+ column_keys=None,
+ for_executemany=False,
+ schema_translate_map=None,
+ **kw
+ ):
+ if compiled_cache is not None and dialect._supports_statement_cache:
+ elem_cache_key = self._generate_cache_key()
+ else:
+ elem_cache_key = None
+
+ if elem_cache_key:
+ cache_key, extracted_params = elem_cache_key
+ key = (
+ dialect,
+ cache_key,
+ tuple(column_keys),
+ bool(schema_translate_map),
+ for_executemany,
+ )
+ compiled_sql = compiled_cache.get(key)
+
+ if compiled_sql is None:
+ cache_hit = dialect.CACHE_MISS
+ compiled_sql = self._compiler(
+ dialect,
+ cache_key=elem_cache_key,
+ column_keys=column_keys,
+ for_executemany=for_executemany,
+ schema_translate_map=schema_translate_map,
+ **kw
+ )
+ compiled_cache[key] = compiled_sql
+ else:
+ cache_hit = dialect.CACHE_HIT
+ else:
+ extracted_params = None
+ compiled_sql = self._compiler(
+ dialect,
+ cache_key=elem_cache_key,
+ column_keys=column_keys,
+ for_executemany=for_executemany,
+ schema_translate_map=schema_translate_map,
+ **kw
+ )
+
+ if not dialect._supports_statement_cache:
+ cache_hit = dialect.NO_DIALECT_SUPPORT
+ elif compiled_cache is None:
+ cache_hit = dialect.CACHING_DISABLED
+ else:
+ cache_hit = dialect.NO_CACHE_KEY
+
+ return compiled_sql, extracted_params, cache_hit
+
+ def _compiler(self, dialect, **kw):
+ """Return a compiler appropriate for this ClauseElement, given a
+ Dialect."""
+
+ return dialect.statement_compiler(dialect, self, **kw)
+
+ def __str__(self):
+ if util.py3k:
+ return str(self.compile())
+ else:
+ return unicode(self.compile()).encode( # noqa
+ "ascii", "backslashreplace"
+ ) # noqa
+
+ def __invert__(self):
+ # undocumented element currently used by the ORM for
+ # relationship.contains()
+ if hasattr(self, "negation_clause"):
+ return self.negation_clause
+ else:
+ return self._negate()
+
+ def _negate(self):
+ return UnaryExpression(
+ self.self_group(against=operators.inv), operator=operators.inv
+ )
+
+ def __bool__(self):
+ raise TypeError("Boolean value of this clause is not defined")
+
+ __nonzero__ = __bool__
+
+ def __repr__(self):
+ friendly = self.description
+ if friendly is None:
+ return object.__repr__(self)
+ else:
+ return "<%s.%s at 0x%x; %s>" % (
+ self.__module__,
+ self.__class__.__name__,
+ id(self),
+ friendly,
+ )
+
+
+class ColumnElement(
+ roles.ColumnArgumentOrKeyRole,
+ roles.StatementOptionRole,
+ roles.WhereHavingRole,
+ roles.BinaryElementRole,
+ roles.OrderByRole,
+ roles.ColumnsClauseRole,
+ roles.LimitOffsetRole,
+ roles.DMLColumnRole,
+ roles.DDLConstraintColumnRole,
+ roles.DDLExpressionRole,
+ operators.ColumnOperators,
+ ClauseElement,
+):
+ """Represent a column-oriented SQL expression suitable for usage in the
+ "columns" clause, WHERE clause etc. of a statement.
+
+ While the most familiar kind of :class:`_expression.ColumnElement` is the
+ :class:`_schema.Column` object, :class:`_expression.ColumnElement`
+ serves as the basis
+ for any unit that may be present in a SQL expression, including
+ the expressions themselves, SQL functions, bound parameters,
+ literal expressions, keywords such as ``NULL``, etc.
+ :class:`_expression.ColumnElement`
+ is the ultimate base class for all such elements.
+
+ A wide variety of SQLAlchemy Core functions work at the SQL expression
+ level, and are intended to accept instances of
+ :class:`_expression.ColumnElement` as
+ arguments. These functions will typically document that they accept a
+ "SQL expression" as an argument. What this means in terms of SQLAlchemy
+ usually refers to an input which is either already in the form of a
+ :class:`_expression.ColumnElement` object,
+ or a value which can be **coerced** into
+ one. The coercion rules followed by most, but not all, SQLAlchemy Core
+ functions with regards to SQL expressions are as follows:
+
+ * a literal Python value, such as a string, integer or floating
+ point value, boolean, datetime, ``Decimal`` object, or virtually
+ any other Python object, will be coerced into a "literal bound
+ value". This generally means that a :func:`.bindparam` will be
+ produced featuring the given value embedded into the construct; the
+ resulting :class:`.BindParameter` object is an instance of
+ :class:`_expression.ColumnElement`.
+ The Python value will ultimately be sent
+ to the DBAPI at execution time as a parameterized argument to the
+ ``execute()`` or ``executemany()`` methods, after SQLAlchemy
+ type-specific converters (e.g. those provided by any associated
+ :class:`.TypeEngine` objects) are applied to the value.
+
+ * any special object value, typically ORM-level constructs, which
+ feature an accessor called ``__clause_element__()``. The Core
+ expression system looks for this method when an object of otherwise
+ unknown type is passed to a function that is looking to coerce the
+ argument into a :class:`_expression.ColumnElement` and sometimes a
+ :class:`_expression.SelectBase` expression.
+ It is used within the ORM to
+ convert from ORM-specific objects like mapped classes and
+ mapped attributes into Core expression objects.
+
+ * The Python ``None`` value is typically interpreted as ``NULL``,
+ which in SQLAlchemy Core produces an instance of :func:`.null`.
+
+ A :class:`_expression.ColumnElement` provides the ability to generate new
+ :class:`_expression.ColumnElement`
+ objects using Python expressions. This means that Python operators
+ such as ``==``, ``!=`` and ``<`` are overloaded to mimic SQL operations,
+ and allow the instantiation of further :class:`_expression.ColumnElement`
+ instances
+ which are composed from other, more fundamental
+ :class:`_expression.ColumnElement`
+ objects. For example, two :class:`.ColumnClause` objects can be added
+ together with the addition operator ``+`` to produce
+ a :class:`.BinaryExpression`.
+ Both :class:`.ColumnClause` and :class:`.BinaryExpression` are subclasses
+ of :class:`_expression.ColumnElement`::
+
+ >>> from sqlalchemy.sql import column
+ >>> column('a') + column('b')
+ <sqlalchemy.sql.expression.BinaryExpression object at 0x101029dd0>
+ >>> print(column('a') + column('b'))
+ a + b
+
+ .. seealso::
+
+ :class:`_schema.Column`
+
+ :func:`_expression.column`
+
+ """
+
+ __visit_name__ = "column_element"
+ primary_key = False
+ foreign_keys = []
+ _proxies = ()
+
+ _tq_label = None
+ """The named label that can be used to target
+ this column in a result set in a "table qualified" context.
+
+ This label is almost always the label used when
+ rendering <expr> AS <label> in a SELECT statement when using
+ the LABEL_STYLE_TABLENAME_PLUS_COL label style, which is what the legacy
+ ORM ``Query`` object uses as well.
+
+ For a regular Column bound to a Table, this is typically the label
+ <tablename>_<columnname>. For other constructs, different rules
+ may apply, such as anonymized labels and others.
+
+ .. versionchanged:: 1.4.21 renamed from ``._label``
+
+ """
+
+ key = None
+ """The 'key' that in some circumstances refers to this object in a
+ Python namespace.
+
+ This typically refers to the "key" of the column as present in the
+ ``.c`` collection of a selectable, e.g. ``sometable.c["somekey"]`` would
+ return a :class:`_schema.Column` with a ``.key`` of "somekey".
+
+ """
+
+ @HasMemoized.memoized_attribute
+ def _tq_key_label(self):
+ """A label-based version of 'key' that in some circumstances refers
+ to this object in a Python namespace.
+
+
+ _tq_key_label comes into play when a select() statement is constructed
+ with apply_labels(); in this case, all Column objects in the ``.c``
+ collection are rendered as <tablename>_<columnname> in SQL; this is
+ essentially the value of ._label. But to locate those columns in the
+ ``.c`` collection, the name is along the lines of <tablename>_<key>;
+ that's the typical value of .key_label.
+
+ .. versionchanged:: 1.4.21 renamed from ``._key_label``
+
+ """
+ return self._proxy_key
+
+ @property
+ def _key_label(self):
+ """legacy; renamed to _tq_key_label"""
+ return self._tq_key_label
+
+ @property
+ def _label(self):
+ """legacy; renamed to _tq_label"""
+ return self._tq_label
+
+ @property
+ def _non_anon_label(self):
+ """the 'name' that naturally applies this element when rendered in
+ SQL.
+
+ Concretely, this is the "name" of a column or a label in a
+ SELECT statement; ``<columnname>`` and ``<labelname>`` below::
+
+ SELECT <columnname> FROM table
+
+ SELECT column AS <labelname> FROM table
+
+ Above, the two names noted will be what's present in the DBAPI
+ ``cursor.description`` as the names.
+
+ If this attribute returns ``None``, it means that the SQL element as
+ written does not have a 100% fully predictable "name" that would appear
+ in the ``cursor.description``. Examples include SQL functions, CAST
+ functions, etc. While such things do return names in
+ ``cursor.description``, they are only predictable on a
+ database-specific basis; e.g. an expression like ``MAX(table.col)`` may
+ appear as the string ``max`` on one database (like PostgreSQL) or may
+ appear as the whole expression ``max(table.col)`` on SQLite.
+
+ The default implementation looks for a ``.name`` attribute on the
+ object, as has been the precedent established in SQLAlchemy for many
+ years. An exception is made on the ``FunctionElement`` subclass
+ so that the return value is always ``None``.
+
+ .. versionadded:: 1.4.21
+
+
+
+ """
+ return getattr(self, "name", None)
+
+ _render_label_in_columns_clause = True
+ """A flag used by select._columns_plus_names that helps to determine
+ whether we are actually going to render in terms of "SELECT <col> AS <label>".
+ This flag can be returned as False for some Column objects that want
+ to be rendered as simple "SELECT <col>"; typically columns that don't have
+ any parent table and are named the same as what the label would be
+ in any case.
+
+ """
+
+ _allow_label_resolve = True
+ """A flag that can be flipped to prevent a column from being resolvable
+ by string label name.
+
+ The joined eager loader strategy in the ORM uses this, for example.
+
+ """
+
+ _is_implicitly_boolean = False
+
+ _alt_names = ()
+
+ def self_group(self, against=None):
+ if (
+ against in (operators.and_, operators.or_, operators._asbool)
+ and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity
+ ):
+ return AsBoolean(self, operators.is_true, operators.is_false)
+ elif against in (operators.any_op, operators.all_op):
+ return Grouping(self)
+ else:
+ return self
+
+ def _negate(self):
+ if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
+ return AsBoolean(self, operators.is_false, operators.is_true)
+ else:
+ return super(ColumnElement, self)._negate()
+
+ @util.memoized_property
+ def type(self):
+ return type_api.NULLTYPE
+
+ @HasMemoized.memoized_attribute
+ def comparator(self):
+ try:
+ comparator_factory = self.type.comparator_factory
+ except AttributeError as err:
+ util.raise_(
+ TypeError(
+ "Object %r associated with '.type' attribute "
+ "is not a TypeEngine class or object" % self.type
+ ),
+ replace_context=err,
+ )
+ else:
+ return comparator_factory(self)
+
+ def __getattr__(self, key):
+ try:
+ return getattr(self.comparator, key)
+ except AttributeError as err:
+ util.raise_(
+ AttributeError(
+ "Neither %r object nor %r object has an attribute %r"
+ % (
+ type(self).__name__,
+ type(self.comparator).__name__,
+ key,
+ )
+ ),
+ replace_context=err,
+ )
+
+ def operate(self, op, *other, **kwargs):
+ return op(self.comparator, *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ return op(other, self.comparator, **kwargs)
+
+ def _bind_param(self, operator, obj, type_=None, expanding=False):
+ return BindParameter(
+ None,
+ obj,
+ _compared_to_operator=operator,
+ type_=type_,
+ _compared_to_type=self.type,
+ unique=True,
+ expanding=expanding,
+ )
+
+ @property
+ def expression(self):
+ """Return a column expression.
+
+ Part of the inspection interface; returns self.
+
+ """
+ return self
+
+ @property
+ def _select_iterable(self):
+ return (self,)
+
+ @util.memoized_property
+ def base_columns(self):
+ return util.column_set(c for c in self.proxy_set if not c._proxies)
+
+ @util.memoized_property
+ def proxy_set(self):
+ s = util.column_set([self])
+ for c in self._proxies:
+ s.update(c.proxy_set)
+ return s
+
+ def _uncached_proxy_set(self):
+ """An 'uncached' version of proxy set.
+
+ This is so that we can read annotations from the list of columns
+ without breaking the caching of the above proxy_set.
+
+ """
+ s = util.column_set([self])
+ for c in self._proxies:
+ s.update(c._uncached_proxy_set())
+ return s
+
+ def shares_lineage(self, othercolumn):
+ """Return True if the given :class:`_expression.ColumnElement`
+ has a common ancestor to this :class:`_expression.ColumnElement`."""
+
+ return bool(self.proxy_set.intersection(othercolumn.proxy_set))
+
+ def _compare_name_for_result(self, other):
+ """Return True if the given column element compares to this one
+ when targeting within a result row."""
+
+ return (
+ hasattr(other, "name")
+ and hasattr(self, "name")
+ and other.name == self.name
+ )
+
+ @HasMemoized.memoized_attribute
+ def _proxy_key(self):
+ if self._annotations and "proxy_key" in self._annotations:
+ return self._annotations["proxy_key"]
+
+ name = self.key
+ if not name:
+ # there's a bit of a seeming contradiction which is that the
+ # "_non_anon_label" of a column can in fact be an
+ # "_anonymous_label"; this is when it's on a column that is
+ # proxying for an anonymous expression in a subquery.
+ name = self._non_anon_label
+
+ if isinstance(name, _anonymous_label):
+ return None
+ else:
+ return name
+
+ @HasMemoized.memoized_attribute
+ def _expression_label(self):
+ """a suggested label to use in the case that the column has no name,
+ which should be used if possible as the explicit 'AS <label>'
+ where this expression would normally have an anon label.
+
+ this is essentially what _proxy_key does, except that it returns
+ None if the column has a normal name that can be used.
+
+ """
+
+ if getattr(self, "name", None) is not None:
+ return None
+ elif self._annotations and "proxy_key" in self._annotations:
+ return self._annotations["proxy_key"]
+ else:
+ return None
+
+ def _make_proxy(
+ self, selectable, name=None, key=None, name_is_truncatable=False, **kw
+ ):
+ """Create a new :class:`_expression.ColumnElement` representing this
+ :class:`_expression.ColumnElement` as it appears in the select list of
+ a descending selectable.
+
+ """
+ if name is None:
+ name = self._anon_name_label
+ if key is None:
+ key = self._proxy_key
+ else:
+ key = name
+
+ co = ColumnClause(
+ coercions.expect(roles.TruncatedLabelRole, name)
+ if name_is_truncatable
+ else name,
+ type_=getattr(self, "type", None),
+ _selectable=selectable,
+ )
+
+ co._propagate_attrs = selectable._propagate_attrs
+ co._proxies = [self]
+ if selectable._is_clone_of is not None:
+ co._is_clone_of = selectable._is_clone_of.columns.get(key)
+ return key, co
+
+ def cast(self, type_):
+ """Produce a type cast, i.e. ``CAST(<expression> AS <type>)``.
+
+ This is a shortcut to the :func:`_expression.cast` function.
+
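+ E.g., a minimal sketch, assuming a hypothetical ``logs_table`` whose
+ ``value`` column holds numeric strings::
+
+ from sqlalchemy import Integer, select
+
+ stmt = select(logs_table.c.value.cast(Integer))
+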
+ .. seealso::
+
+ :ref:`tutorial_casts`
+
+ :func:`_expression.cast`
+
+ :func:`_expression.type_coerce`
+
+ .. versionadded:: 1.0.7
+
+ """
+ return Cast(self, type_)
+
+ def label(self, name):
+ """Produce a column label, i.e. ``<columnname> AS <name>``.
+
+ This is a shortcut to the :func:`_expression.label` function.
+
+ If 'name' is ``None``, an anonymous label name will be generated.
+
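+ E.g., a minimal sketch, assuming a hypothetical ``users_table``::
+
+ from sqlalchemy import select
+
+ stmt = select(users_table.c.name.label('user_name'))
+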
+ """
+ return Label(name, self, self.type)
+
+ def _anon_label(self, seed, add_hash=None):
+ while self._is_clone_of is not None:
+ self = self._is_clone_of
+
+ # as of 1.4 anonymous label for ColumnElement uses hash(), not id(),
+ # as the identifier, because a column and its annotated version are
+ # the same thing in a SQL statement
+ hash_value = hash(self)
+
+ if add_hash:
+ # this path is used for disambiguating anon labels that would
+ # otherwise be the same name for the same element repeated.
+ # an additional numeric value is factored in for each label.
+
+ # shift hash(self) (which is id(self), typically an 8-byte integer)
+ # 16 bits leftward. fill extra add_hash on right
+ assert add_hash < (2 << 15)
+ assert seed
+ hash_value = (hash_value << 16) | add_hash
+
+ # extra underscore is added for labels with extra hash
+ # values, to isolate the "deduped anon" namespace from the
+ # regular namespace. eliminates chance of these
+ # manufactured hash values overlapping with regular ones for some
+ # undefined python interpreter
+ seed = seed + "_"
+
+ if isinstance(seed, _anonymous_label):
+ return _anonymous_label.safe_construct(
+ hash_value, "", enclosing_label=seed
+ )
+
+ return _anonymous_label.safe_construct(hash_value, seed or "anon")
+
+ @util.memoized_property
+ def _anon_name_label(self):
+ """Provides a constant 'anonymous label' for this ColumnElement.
+
+ This is a label() expression which will be named at compile time.
+ The same label() is returned each time ``anon_label`` is called so
+ that expressions can reference ``anon_label`` multiple times,
+ producing the same label name at compile time.
+
+ The compiler uses this function automatically at compile time
+ for expressions that are known to be 'unnamed' like binary
+ expressions and function calls.
+
+ .. versionchanged:: 1.4.9 - this attribute was not intended to be
+ public and is renamed to _anon_name_label. anon_name exists
+ for backwards compat
+
+ """
+ name = getattr(self, "name", None)
+ return self._anon_label(name)
+
+ @util.memoized_property
+ def _anon_key_label(self):
+ """Provides a constant 'anonymous key label' for this ColumnElement.
+
+ Compare to ``anon_label``, except that the "key" of the column,
+ if available, is used to generate the label.
+
+ This is used when a deduplicating key is placed into the columns
+ collection of a selectable.
+
+ .. versionchanged:: 1.4.9 - this attribute was not intended to be
+ public and is renamed to _anon_key_label. anon_key_label exists
+ for backwards compat
+
+ """
+ return self._anon_label(self._proxy_key)
+
+ @property
+ @util.deprecated(
+ "1.4",
+ "The :attr:`_expression.ColumnElement.anon_label` attribute is now "
+ "private, and the public accessor is deprecated.",
+ )
+ def anon_label(self):
+ return self._anon_name_label
+
+ @property
+ @util.deprecated(
+ "1.4",
+ "The :attr:`_expression.ColumnElement.anon_key_label` attribute is "
+ "now private, and the public accessor is deprecated.",
+ )
+ def anon_key_label(self):
+ return self._anon_key_label
+
+ def _dedupe_anon_label_idx(self, idx):
+ """label to apply to a column that is anon labeled, but repeated
+ in the SELECT, so that we have to make an "extra anon" label that
+ disambiguates it from the previous appearance.
+
+ these labels come out like "foo_bar_id__1" and have double underscores
+ in them.
+
+ """
+ label = getattr(self, "name", None)
+
+ # current convention is that if the element doesn't have a
+ # ".name" (usually because it is not NamedColumn), we try to
+ # use a "table qualified" form for the "dedupe anon" label,
+ # based on the notion that a label like
+ # "CAST(casttest.v1 AS DECIMAL) AS casttest_v1__1" looks better than
+ # "CAST(casttest.v1 AS DECIMAL) AS anon__1"
+
+ if label is None:
+ return self._dedupe_anon_tq_label_idx(idx)
+ else:
+ return self._anon_label(label, add_hash=idx)
+
+ @util.memoized_property
+ def _anon_tq_label(self):
+ return self._anon_label(getattr(self, "_tq_label", None))
+
+ @util.memoized_property
+ def _anon_tq_key_label(self):
+ return self._anon_label(getattr(self, "_tq_key_label", None))
+
+ def _dedupe_anon_tq_label_idx(self, idx):
+ label = getattr(self, "_tq_label", None) or "anon"
+
+ return self._anon_label(label, add_hash=idx)
+
+
+class WrapsColumnExpression(object):
+ """Mixin that defines a :class:`_expression.ColumnElement`
+ as a wrapper with special
+ labeling behavior for an expression that already has a name.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`change_4449`
+
+
+ """
+
+ @property
+ def wrapped_column_expression(self):
+ raise NotImplementedError()
+
+ @property
+ def _tq_label(self):
+ wce = self.wrapped_column_expression
+ if hasattr(wce, "_tq_label"):
+ return wce._tq_label
+ else:
+ return None
+
+ _label = _tq_label
+
+ @property
+ def _non_anon_label(self):
+ return None
+
+ @property
+ def _anon_name_label(self):
+ wce = self.wrapped_column_expression
+
+ # this logic tries to get the WrappedColumnExpression to render
+ # with "<expr> AS <name>", where "<name>" is the natural name
+ # within the expression itself. e.g. "CAST(table.foo) AS foo".
+ if not wce._is_text_clause:
+ nal = wce._non_anon_label
+ if nal:
+ return nal
+ elif hasattr(wce, "_anon_name_label"):
+ return wce._anon_name_label
+ return super(WrapsColumnExpression, self)._anon_name_label
+
+ def _dedupe_anon_label_idx(self, idx):
+ wce = self.wrapped_column_expression
+ nal = wce._non_anon_label
+ if nal:
+ return self._anon_label(nal + "_")
+ else:
+ return self._dedupe_anon_tq_label_idx(idx)
+
+ @property
+ def _proxy_key(self):
+ wce = self.wrapped_column_expression
+
+ if not wce._is_text_clause:
+ return wce._proxy_key
+ return super(WrapsColumnExpression, self)._proxy_key
+
+
+class BindParameter(roles.InElementRole, ColumnElement):
+ r"""Represent a "bound expression".
+
+ :class:`.BindParameter` is invoked explicitly using the
+ :func:`.bindparam` function, as in::
+
+ from sqlalchemy import bindparam
+
+ stmt = select(users_table).\
+ where(users_table.c.name == bindparam('username'))
+
+ Detailed discussion of how :class:`.BindParameter` is used is
+ at :func:`.bindparam`.
+
+ .. seealso::
+
+ :func:`.bindparam`
+
+ """
+
+ __visit_name__ = "bindparam"
+
+ _traverse_internals = [
+ ("key", InternalTraversal.dp_anon_name),
+ ("type", InternalTraversal.dp_type),
+ ("callable", InternalTraversal.dp_plain_dict),
+ ("value", InternalTraversal.dp_plain_obj),
+ ("literal_execute", InternalTraversal.dp_boolean),
+ ]
+
+ _is_crud = False
+ _is_bind_parameter = True
+ _key_is_anon = False
+
+ # bindparam implements its own _gen_cache_key() method however
+ # we check subclasses for this flag, else no cache key is generated
+ inherit_cache = True
+
+ def __init__(
+ self,
+ key,
+ value=NO_ARG,
+ type_=None,
+ unique=False,
+ required=NO_ARG,
+ quote=None,
+ callable_=None,
+ expanding=False,
+ isoutparam=False,
+ literal_execute=False,
+ _compared_to_operator=None,
+ _compared_to_type=None,
+ _is_crud=False,
+ ):
+ r"""Produce a "bound expression".
+
+ The return value is an instance of :class:`.BindParameter`; this
+ is a :class:`_expression.ColumnElement`
+ subclass which represents a so-called
+ "placeholder" value in a SQL expression, the value of which is
+ supplied at the point at which the statement is executed against a
+ database connection.
+
+ In SQLAlchemy, the :func:`.bindparam` construct has
+ the ability to carry along the actual value that will be ultimately
+ used at expression time. In this way, it serves not just as
+ a "placeholder" for eventual population, but also as a means of
+ representing so-called "unsafe" values which should not be rendered
+ directly in a SQL statement, but rather should be passed along
+ to the :term:`DBAPI` as values which need to be correctly escaped
+ and potentially handled for type-safety.
+
+ When using :func:`.bindparam` explicitly, the use case is typically
+ one of traditional deferment of parameters; the :func:`.bindparam`
+ construct accepts a name which can then be referred to at execution
+ time::
+
+ from sqlalchemy import bindparam
+
+ stmt = select(users_table).\
+ where(users_table.c.name == bindparam('username'))
+
+ The above statement, when rendered, will produce SQL similar to::
+
+ SELECT id, name FROM user WHERE name = :username
+
+ In order to populate the value of ``:username`` above, the value
+ would typically be applied at execution time to a method
+ like :meth:`_engine.Connection.execute`::
+
+ result = connection.execute(stmt, username='wendy')
+
+ Explicit use of :func:`.bindparam` is also common when producing
+ UPDATE or DELETE statements that are to be invoked multiple times,
+ where the WHERE criterion of the statement is to change on each
+ invocation, such as::
+
+ stmt = (users_table.update().
+ where(user_table.c.name == bindparam('username')).
+ values(fullname=bindparam('fullname'))
+ )
+
+ connection.execute(
+ stmt, [{"username": "wendy", "fullname": "Wendy Smith"},
+ {"username": "jack", "fullname": "Jack Jones"},
+ ]
+ )
+
+ SQLAlchemy's Core expression system makes wide use of
+ :func:`.bindparam` in an implicit sense. It is typical that Python
+ literal values passed to virtually all SQL expression functions are
+ coerced into fixed :func:`.bindparam` constructs. For example, given
+ a comparison operation such as::
+
+ expr = users_table.c.name == 'Wendy'
+
+ The above expression will produce a :class:`.BinaryExpression`
+ construct, where the left side is the :class:`_schema.Column` object
+ representing the ``name`` column, and the right side is a
+ :class:`.BindParameter` representing the literal value::
+
+ print(repr(expr.right))
+ BindParameter('%(4327771088 name)s', 'Wendy', type_=String())
+
+ The expression above will render SQL such as::
+
+ user.name = :name_1
+
+ Where the ``:name_1`` parameter name is an anonymous name. The
+ actual string ``Wendy`` is not in the rendered string, but is carried
+ along where it is later used within statement execution. If we
+ invoke a statement like the following::
+
+ stmt = select(users_table).where(users_table.c.name == 'Wendy')
+ result = connection.execute(stmt)
+
+ We would see SQL logging output as::
+
+ SELECT "user".id, "user".name
+ FROM "user"
+ WHERE "user".name = %(name_1)s
+ {'name_1': 'Wendy'}
+
+ Above, we see that ``Wendy`` is passed as a parameter to the database,
+ while the placeholder ``:name_1`` is rendered in the appropriate form
+ for the target database, in this case the PostgreSQL database.
+
+ Similarly, :func:`.bindparam` is invoked automatically when working
+ with :term:`CRUD` statements as far as the "VALUES" portion is
+ concerned. The :func:`_expression.insert` construct produces an
+ ``INSERT`` expression which will, at statement execution time, generate
+ bound placeholders based on the arguments passed, as in::
+
+ stmt = users_table.insert()
+ result = connection.execute(stmt, name='Wendy')
+
+ The above will produce SQL output as::
+
+ INSERT INTO "user" (name) VALUES (%(name)s)
+ {'name': 'Wendy'}
+
+ The :class:`_expression.Insert` construct, at
+ compilation/execution time, rendered a single :func:`.bindparam`
+ mirroring the column name ``name`` as a result of the single ``name``
+ parameter we passed to the :meth:`_engine.Connection.execute` method.
+
+ :param key:
+ the key (e.g. the name) for this bind param.
+ Will be used in the generated
+ SQL statement for dialects that use named parameters. This
+ value may be modified when part of a compilation operation,
+ if other :class:`BindParameter` objects exist with the same
+ key, or if its length is too long and truncation is
+ required.
+
+ :param value:
+ Initial value for this bind param. Will be used at statement
+ execution time as the value for this parameter passed to the
+ DBAPI, if no other value is indicated to the statement execution
+ method for this particular parameter name. Defaults to ``None``.
+
+ :param callable\_:
+ A callable function that takes the place of "value". The function
+ will be called at statement execution time to determine the
+ ultimate value. Used for scenarios where the actual bind
+ value cannot be determined at the point at which the clause
+ construct is created, but embedded bind values are still desirable.
+
+ :param type\_:
+ A :class:`.TypeEngine` class or instance representing an optional
+ datatype for this :func:`.bindparam`. If not passed, a type
+ may be determined automatically for the bind, based on the given
+ value; for example, trivial Python types such as ``str``,
+ ``int``, ``bool``
+ may result in the :class:`.String`, :class:`.Integer` or
+ :class:`.Boolean` types being automatically selected.
+
+ The type of a :func:`.bindparam` is significant especially in that
+ the type will apply pre-processing to the value before it is
+ passed to the database. For example, a :func:`.bindparam` which
+ refers to a datetime value, and is specified as holding the
+ :class:`.DateTime` type, may apply conversion needed to the
+ value (such as stringification on SQLite) before passing the value
+ to the database.
+
+ :param unique:
+ if True, the key name of this :class:`.BindParameter` will be
+ modified if another :class:`.BindParameter` of the same name
+ already has been located within the containing
+ expression. This flag is used generally by the internals
+ when producing so-called "anonymous" bound expressions; it
+ isn't generally applicable to explicitly-named :func:`.bindparam`
+ constructs.
+
+ :param required:
+ If ``True``, a value is required at execution time. If not passed,
+ it defaults to ``True`` if neither :paramref:`.bindparam.value`
+ nor :paramref:`.bindparam.callable` were passed. If either of these
+ parameters are present, then :paramref:`.bindparam.required`
+ defaults to ``False``.
+
+ :param quote:
+ True if this parameter name requires quoting and is not
+ currently known as a SQLAlchemy reserved word; this currently
+ only applies to the Oracle backend, where bound names must
+ sometimes be quoted.
+
+ :param isoutparam:
+ if True, the parameter should be treated like a stored procedure
+ "OUT" parameter. This applies to backends such as Oracle which
+ support OUT parameters.
+
+ :param expanding:
+ if True, this parameter will be treated as an "expanding" parameter
+ at execution time; the parameter value is expected to be a sequence,
+ rather than a scalar value, and the string SQL statement will
+ be transformed on a per-execution basis to accommodate the sequence
+ with a variable number of parameter slots passed to the DBAPI.
+ This is to allow statement caching to be used in conjunction with
+ an IN clause.
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.in_`
+
+ :ref:`baked_in` - with baked queries
+
+ .. note:: The "expanding" feature does not support "executemany"-
+ style parameter sets.
+
+ .. versionadded:: 1.2
+
+ .. versionchanged:: 1.3 the "expanding" bound parameter feature now
+ supports empty lists.
+
+ :param literal_execute:
+ if True, the bound parameter will be rendered in the compile phase
+ with a special "POSTCOMPILE" token, and the SQLAlchemy compiler will
+ render the final value of the parameter into the SQL statement at
+ statement execution time, omitting the value from the parameter
+ dictionary / list passed to DBAPI ``cursor.execute()``. This
+ produces a similar effect as that of using the ``literal_binds``
+ compilation flag, however it takes place as the statement is sent to
+ the DBAPI ``cursor.execute()`` method, rather than when the statement
+ is compiled. The primary use of this
+ capability is for rendering LIMIT / OFFSET clauses for database
+ drivers that can't accommodate for bound parameters in these
+ contexts, while allowing SQL constructs to be cacheable at the
+ compilation level.
+
+ .. versionadded:: 1.4 Added "post compile" bound parameters
+
+ .. seealso::
+
+ :ref:`change_4808`.
+
+ .. seealso::
+
+ :ref:`tutorial_sending_parameters` - in the
+ :ref:`unified_tutorial`
+
+ """
+ if required is NO_ARG:
+ required = value is NO_ARG and callable_ is None
+ if value is NO_ARG:
+ value = None
+
+ if quote is not None:
+ key = quoted_name(key, quote)
+
+ if unique:
+ self.key = _anonymous_label.safe_construct(
+ id(self),
+ key
+ if key is not None and not isinstance(key, _anonymous_label)
+ else "param",
+ sanitize_key=True,
+ )
+ self._key_is_anon = True
+ elif key:
+ self.key = key
+ else:
+ self.key = _anonymous_label.safe_construct(id(self), "param")
+ self._key_is_anon = True
+
+ # identifying key that won't change across
+ # clones, used to identify the bind's logical
+ # identity
+ self._identifying_key = self.key
+
+ # key that was passed in the first place, used to
+ # generate new keys
+ self._orig_key = key or "param"
+
+ self.unique = unique
+ self.value = value
+ self.callable = callable_
+ self.isoutparam = isoutparam
+ self.required = required
+
+ # indicate an "expanding" parameter; the compiler sets this
+ # automatically in the compiler _render_in_expr_w_bindparam method
+ # for an IN expression
+ self.expanding = expanding
+
+ # this is another hint to help w/ expanding and is typically
+ # set in the compiler _render_in_expr_w_bindparam method for an
+ # IN expression
+ self.expand_op = None
+
+ self.literal_execute = literal_execute
+ if _is_crud:
+ self._is_crud = True
+
+ if type_ is None:
+ if expanding and value:
+ check_value = value[0]
+ else:
+ check_value = value
+ if _compared_to_type is not None:
+ self.type = _compared_to_type.coerce_compared_value(
+ _compared_to_operator, check_value
+ )
+ else:
+ self.type = type_api._resolve_value_to_type(check_value)
+ elif isinstance(type_, type):
+ self.type = type_()
+ elif type_._is_tuple_type and value:
+ if expanding:
+ check_value = value[0]
+ else:
+ check_value = value
+ self.type = type_._resolve_values_to_types(check_value)
+ else:
+ self.type = type_
+
+ def _with_value(self, value, maintain_key=False, required=NO_ARG):
+ """Return a copy of this :class:`.BindParameter` with the given value
+ set.
+ """
+ cloned = self._clone(maintain_key=maintain_key)
+ cloned.value = value
+ cloned.callable = None
+ cloned.required = required if required is not NO_ARG else self.required
+ if cloned.type is type_api.NULLTYPE:
+ cloned.type = type_api._resolve_value_to_type(value)
+ return cloned
+
+ @property
+ def effective_value(self):
+ """Return the value of this bound parameter,
+ taking into account if the ``callable`` parameter
+ was set.
+
+ The ``callable`` value will be evaluated
+ and returned if present, else ``value``.
+
+ """
+ if self.callable:
+ return self.callable()
+ else:
+ return self.value
+
+ def render_literal_execute(self):
+ """Produce a copy of this bound parameter that will enable the
+ :paramref:`_sql.BindParameter.literal_execute` flag.
+
+ The :paramref:`_sql.BindParameter.literal_execute` flag will
+ have the effect of the parameter rendered in the compiled SQL
+ string using ``[POSTCOMPILE]`` form, which is a special form that
+ is converted to be a rendering of the literal value of the parameter
+ at SQL execution time. The rationale is to support caching
+ of SQL statement strings that can embed per-statement literal values,
+ such as LIMIT and OFFSET parameters, in the final SQL string that
+ is passed to the DBAPI. Dialects in particular may want to use
+ this method within custom compilation schemes.
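+
+ For example, a sketch of how a dialect-level compilation scheme might
+ convert an existing parameter into its "post compile" form (the
+ parameter name is illustrative)::
+
+ param = bindparam("limit", 10)
+
+ # a new copy of the parameter with literal_execute set
+ post_compile_param = param.render_literal_execute()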
+
+ .. versionadded:: 1.4.5
+
+ .. seealso::
+
+ :ref:`engine_thirdparty_caching`
+
+ """
+ return self.__class__(
+ self.key,
+ self.value,
+ type_=self.type,
+ literal_execute=True,
+ )
+
+ def _negate_in_binary(self, negated_op, original_op):
+ if self.expand_op is original_op:
+ bind = self._clone()
+ bind.expand_op = negated_op
+ return bind
+ else:
+ return self
+
+ def _with_binary_element_type(self, type_):
+ c = ClauseElement._clone(self)
+ c.type = type_
+ return c
+
+ def _clone(self, maintain_key=False, **kw):
+ c = ClauseElement._clone(self, **kw)
+ # ensure all the BindParameter objects stay in cloned set.
+ # in #7823, we changed "clone" so that a clone only keeps a reference
+ # to the "original" element, since for column correspondence, that's
+ # all we need. However, for BindParam, _cloned_set is used by
+ # the "cache key bind match" lookup, which means if any of those
+ # interim BindParameter objects became part of a cache key in the
+ # cache, we need it. So here, make sure all clones keep carrying
+ # forward.
+ c._cloned_set.update(self._cloned_set)
+ if not maintain_key and self.unique:
+ c.key = _anonymous_label.safe_construct(
+ id(c), c._orig_key or "param", sanitize_key=True
+ )
+ return c
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ _gen_cache_ok = self.__class__.__dict__.get("inherit_cache", False)
+
+ if not _gen_cache_ok:
+ if anon_map is not None:
+ anon_map[NO_CACHE] = True
+ return None
+
+ idself = id(self)
+ if idself in anon_map:
+ return (anon_map[idself], self.__class__)
+ else:
+ # inline of
+ # id_ = anon_map[idself]
+ anon_map[idself] = id_ = str(anon_map.index)
+ anon_map.index += 1
+
+ if bindparams is not None:
+ bindparams.append(self)
+
+ return (
+ id_,
+ self.__class__,
+ self.type._static_cache_key,
+ self.key % anon_map if self._key_is_anon else self.key,
+ self.literal_execute,
+ )
+
+ def _convert_to_unique(self):
+ if not self.unique:
+ self.unique = True
+ self.key = _anonymous_label.safe_construct(
+ id(self), self._orig_key or "param", sanitize_key=True
+ )
+
+ def __getstate__(self):
+ """execute a deferred value for serialization purposes."""
+
+ d = self.__dict__.copy()
+ v = self.value
+ if self.callable:
+ v = self.callable()
+ d["callable"] = None
+ d["value"] = v
+ return d
+
+ def __setstate__(self, state):
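+ # for a "unique" parameter, regenerate the anonymous key against the
+ # id() of the newly unpickled object, since the original id() embedded
+ # in the key no longer applies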
+ if state.get("unique", False):
+ state["key"] = _anonymous_label.safe_construct(
+ id(self), state.get("_orig_key", "param"), sanitize_key=True
+ )
+ self.__dict__.update(state)
+
+ def __repr__(self):
+ return "%s(%r, %r, type_=%r)" % (
+ self.__class__.__name__,
+ self.key,
+ self.value,
+ self.type,
+ )
+
+
+class TypeClause(ClauseElement):
+ """Handle a type keyword in a SQL statement.
+
+ Used by the ``Cast`` expression.
+
+ """
+
+ __visit_name__ = "typeclause"
+
+ _traverse_internals = [("type", InternalTraversal.dp_type)]
+
+ def __init__(self, type_):
+ self.type = type_
+
+
+class TextClause(
+ roles.DDLConstraintColumnRole,
+ roles.DDLExpressionRole,
+ roles.StatementOptionRole,
+ roles.WhereHavingRole,
+ roles.OrderByRole,
+ roles.FromClauseRole,
+ roles.SelectStatementRole,
+ roles.BinaryElementRole,
+ roles.InElementRole,
+ Executable,
+ ClauseElement,
+):
+ """Represent a literal SQL text fragment.
+
+ E.g.::
+
+ from sqlalchemy import text
+
+ t = text("SELECT * FROM users")
+ result = connection.execute(t)
+
+
+ The :class:`_expression.TextClause` construct is produced using the
+ :func:`_expression.text`
+ function; see that function for full documentation.
+
+ .. seealso::
+
+ :func:`_expression.text`
+
+ """
+
+ __visit_name__ = "textclause"
+
+ _traverse_internals = [
+ ("_bindparams", InternalTraversal.dp_string_clauseelement_dict),
+ ("text", InternalTraversal.dp_string),
+ ]
+
+ _is_text_clause = True
+
+ _is_textual = True
+
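+ # match ":name" style bind parameters: a colon not preceded by another
+ # colon, a word character, or a backslash (\x5c), followed by a name of
+ # word characters that is not itself followed by a colon, so that "::"
+ # casts and escaped "\:" sequences are not treated as parameters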
+ _bind_params_regex = re.compile(r"(?<![:\w\x5c]):(\w+)(?!:)", re.UNICODE)
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": PARSE_AUTOCOMMIT}
+ )
+ _is_implicitly_boolean = False
+
+ _render_label_in_columns_clause = False
+
+ _hide_froms = ()
+
+ def __and__(self, other):
+ # support use in select.where(), query.filter()
+ return and_(self, other)
+
+ @property
+ def _select_iterable(self):
+ return (self,)
+
+ # help in those cases where text() is
+ # interpreted in a column expression situation
+ key = _label = None
+
+ _allow_label_resolve = False
+
+ @property
+ def _is_star(self):
+ return self.text == "*"
+
+ def __init__(self, text, bind=None):
+ self._bind = bind
+ self._bindparams = {}
+
+ def repl(m):
+ self._bindparams[m.group(1)] = BindParameter(m.group(1))
+ return ":%s" % m.group(1)
+
+ # scan the string and search for bind parameter names, add them
+ # to the list of bindparams
+ self.text = self._bind_params_regex.sub(repl, text)
+
+ @classmethod
+ @_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`")
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_sql.text.bind` argument is deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def _create_text(cls, text, bind=None):
+ r"""Construct a new :class:`_expression.TextClause` clause,
+ representing
+ a textual SQL string directly.
+
+ E.g.::
+
+ from sqlalchemy import text
+
+ t = text("SELECT * FROM users")
+ result = connection.execute(t)
+
+ The advantages :func:`_expression.text`
+ provides over a plain string are
+ backend-neutral support for bind parameters, per-statement
+ execution options, as well as
+ bind parameter and result-column typing behavior, allowing
+ SQLAlchemy type constructs to play a role when executing
+ a statement that is specified literally. The construct can also
+ be provided with a ``.c`` collection of column elements, allowing
+ it to be embedded in other SQL expression constructs as a subquery.
+
+ Bind parameters are specified by name, using the format ``:name``.
+ E.g.::
+
+ t = text("SELECT * FROM users WHERE id=:user_id")
+ result = connection.execute(t, user_id=12)
+
+ For SQL statements where a colon is required verbatim, as within
+ an inline string, use a backslash to escape::
+
+ t = text("SELECT * FROM users WHERE name='\:username'")
+
+ The :class:`_expression.TextClause`
+ construct includes methods which can
+ provide information about the bound parameters as well as the column
+ values which would be returned from the textual statement, assuming
+ it's an executable SELECT type of statement. The
+ :meth:`_expression.TextClause.bindparams`
+ method is used to provide bound
+ parameter detail, and :meth:`_expression.TextClause.columns`
+ method allows
+ specification of return columns including names and types::
+
+ t = text("SELECT * FROM users WHERE id=:user_id").\
+ bindparams(user_id=7).\
+ columns(id=Integer, name=String)
+
+ for id, name in connection.execute(t):
+ print(id, name)
+
+ The :func:`_expression.text` construct is used in cases when
+ a literal string SQL fragment is specified as part of a larger query,
+ such as for the WHERE clause of a SELECT statement::
+
+ s = select(users.c.id, users.c.name).where(text("id=:user_id"))
+ result = connection.execute(s, user_id=12)
+
+ :func:`_expression.text` is also used for the construction
+ of a full, standalone statement using plain text.
+ As such, SQLAlchemy refers
+ to it as an :class:`.Executable` object, and it supports
+ the :meth:`Executable.execution_options` method. For example,
+ a :func:`_expression.text`
+ construct that should be subject to "autocommit"
+ can be set explicitly so using the
+ :paramref:`.Connection.execution_options.autocommit` option::
+
+ t = text("EXEC my_procedural_thing()").\
+ execution_options(autocommit=True)
+
+ .. deprecated:: 1.4 The "autocommit" execution option is deprecated
+ and will be removed in SQLAlchemy 2.0. See
+ :ref:`migration_20_autocommit` for discussion.
+
+ :param text:
+ the text of the SQL statement to be created. Use ``:<param>``
+ to specify bind parameters; they will be compiled to their
+ engine-specific format.
+
+ :param bind:
+ an optional connection or engine to be used for this text query.
+
+ .. seealso::
+
+ :ref:`tutorial_select_arbitrary_text`
+
+
+ """
+ return TextClause(text, bind=bind)
+
+ @_generative
+ def bindparams(self, *binds, **names_to_values):
+ """Establish the values and/or types of bound parameters within
+ this :class:`_expression.TextClause` construct.
+
+ Given a text construct such as::
+
+ from sqlalchemy import text
+ stmt = text("SELECT id, name FROM user WHERE name=:name "
+ "AND timestamp=:timestamp")
+
+ the :meth:`_expression.TextClause.bindparams`
+ method can be used to establish
+ the initial value of ``:name`` and ``:timestamp``,
+ using simple keyword arguments::
+
+ stmt = stmt.bindparams(name='jack',
+ timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5))
+
+ Where above, new :class:`.BindParameter` objects
+ will be generated with the names ``name`` and ``timestamp``, and
+ values of ``jack`` and ``datetime.datetime(2012, 10, 8, 15, 12, 5)``,
+ respectively. The types will be
+ inferred from the values given, in this case :class:`.String` and
+ :class:`.DateTime`.
+
+ When specific typing behavior is needed, the positional ``*binds``
+ argument can be used in which to specify :func:`.bindparam` constructs
+ directly. These constructs must include at least the ``key``
+ argument, then an optional value and type::
+
+ from sqlalchemy import bindparam
+ stmt = stmt.bindparams(
+ bindparam('name', value='jack', type_=String),
+ bindparam('timestamp', type_=DateTime)
+ )
+
+ Above, we specified the type of :class:`.DateTime` for the
+ ``timestamp`` bind, and the type of :class:`.String` for the ``name``
+ bind. In the case of ``name`` we also set the default value of
+ ``"jack"``.
+
+ Additional bound parameters can be supplied at statement execution
+ time, e.g.::
+
+ result = connection.execute(stmt,
+ timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5))
+
+ The :meth:`_expression.TextClause.bindparams`
+ method can be called repeatedly,
+ where it will re-use existing :class:`.BindParameter` objects to add
+ new information. For example, we can call
+ :meth:`_expression.TextClause.bindparams`
+ first with typing information, and a
+ second time with value information, and it will be combined::
+
+ stmt = text("SELECT id, name FROM user WHERE name=:name "
+ "AND timestamp=:timestamp")
+ stmt = stmt.bindparams(
+ bindparam('name', type_=String),
+ bindparam('timestamp', type_=DateTime)
+ )
+ stmt = stmt.bindparams(
+ name='jack',
+ timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5)
+ )
+
+ The :meth:`_expression.TextClause.bindparams`
+ method also supports the concept of
+ **unique** bound parameters. These are parameters that are
+ "uniquified" on name at statement compilation time, so that multiple
+ :func:`_expression.text`
+ constructs may be combined together without the names
+ conflicting. To use this feature, specify the
+ :paramref:`.BindParameter.unique` flag on each :func:`.bindparam`
+ object::
+
+ stmt1 = text("select id from table where name=:name").bindparams(
+ bindparam("name", value='name1', unique=True)
+ )
+ stmt2 = text("select id from table where name=:name").bindparams(
+ bindparam("name", value='name2', unique=True)
+ )
+
+ union = union_all(
+ stmt1.columns(column("id")),
+ stmt2.columns(column("id"))
+ )
+
+ The above statement will render as::
+
+ select id from table where name=:name_1
+ UNION ALL select id from table where name=:name_2
+
+ .. versionadded:: 1.3.11 Added support for the
+ :paramref:`.BindParameter.unique` flag to work with
+ :func:`_expression.text`
+ constructs.
+
+ """
+ self._bindparams = new_params = self._bindparams.copy()
+
+ for bind in binds:
+ try:
+ # the regex used for text() currently will not match
+ # a unique/anonymous key in any case, so use the _orig_key
+ # so that a text() construct can support unique parameters
+ existing = new_params[bind._orig_key]
+ except KeyError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "This text() construct doesn't define a "
+ "bound parameter named %r" % bind._orig_key
+ ),
+ replace_context=err,
+ )
+ else:
+ new_params[existing._orig_key] = bind
+
+ for key, value in names_to_values.items():
+ try:
+ existing = new_params[key]
+ except KeyError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "This text() construct doesn't define a "
+ "bound parameter named %r" % key
+ ),
+ replace_context=err,
+ )
+ else:
+ new_params[key] = existing._with_value(value, required=False)
+
+ @util.preload_module("sqlalchemy.sql.selectable")
+ def columns(self, *cols, **types):
+ r"""Turn this :class:`_expression.TextClause` object into a
+ :class:`_expression.TextualSelect`
+ object that serves the same role as a SELECT
+ statement.
+
+ The :class:`_expression.TextualSelect` is part of the
+ :class:`_expression.SelectBase`
+ hierarchy and can be embedded into another statement by using the
+ :meth:`_expression.TextualSelect.subquery` method to produce a
+ :class:`.Subquery`
+ object, which can then be SELECTed from.
+
+ This function essentially bridges the gap between an entirely
+ textual SELECT statement and the SQL expression language concept
+ of a "selectable"::
+
+ from sqlalchemy.sql import column, text
+
+ stmt = text("SELECT id, name FROM some_table")
+ stmt = stmt.columns(column('id'), column('name')).subquery('st')
+
+ stmt = select(mytable).\
+ select_from(
+ mytable.join(stmt, mytable.c.name == stmt.c.name)
+ ).where(stmt.c.id > 5)
+
+ Above, we pass a series of :func:`_expression.column` elements to the
+ :meth:`_expression.TextClause.columns` method positionally. These
+ :func:`_expression.column`
+ elements now become first class elements upon the
+ :attr:`_expression.TextualSelect.selected_columns` column collection,
+ which then
+ become part of the :attr:`.Subquery.c` collection after
+ :meth:`_expression.TextualSelect.subquery` is invoked.
+
+ The column expressions we pass to
+ :meth:`_expression.TextClause.columns` may
+ also be typed; when we do so, these :class:`.TypeEngine` objects become
+ the effective return type of the column, so that SQLAlchemy's
+ result-set-processing systems may be used on the return values.
+ This is often needed for types such as date or boolean types, as well
+ as for unicode processing on some dialect configurations::
+
+ stmt = text("SELECT id, name, timestamp FROM some_table")
+ stmt = stmt.columns(
+ column('id', Integer),
+ column('name', Unicode),
+ column('timestamp', DateTime)
+ )
+
+ for id, name, timestamp in connection.execute(stmt):
+ print(id, name, timestamp)
+
+ As a shortcut to the above syntax, keyword arguments referring to
+ types alone may be used, if only type conversion is needed::
+
+ stmt = text("SELECT id, name, timestamp FROM some_table")
+ stmt = stmt.columns(
+ id=Integer,
+ name=Unicode,
+ timestamp=DateTime
+ )
+
+ for id, name, timestamp in connection.execute(stmt):
+ print(id, name, timestamp)
+
+ The positional form of :meth:`_expression.TextClause.columns`
+ also provides the
+ unique feature of **positional column targeting**, which is
+ particularly useful when using the ORM with complex textual queries. If
+ we specify the columns from our model to
+ :meth:`_expression.TextClause.columns`,
+ the result set will match to those columns positionally, meaning the
+ name or origin of the column in the textual SQL doesn't matter::
+
+ stmt = text("SELECT users.id, addresses.id, users.id, "
+ "users.name, addresses.email_address AS email "
+ "FROM users JOIN addresses ON users.id=addresses.user_id "
+ "WHERE users.id = 1").columns(
+ User.id,
+ Address.id,
+ Address.user_id,
+ User.name,
+ Address.email_address
+ )
+
+ query = session.query(User).from_statement(stmt).options(
+ contains_eager(User.addresses))
+
+ .. versionadded:: 1.1 the :meth:`_expression.TextClause.columns`
+ method now
+ offers positional column targeting in the result set when
+ the column expressions are passed purely positionally.
+
+ The :meth:`_expression.TextClause.columns` method provides a direct
+ route to calling :meth:`_expression.FromClause.subquery` as well as
+ :meth:`_expression.SelectBase.cte`
+ against a textual SELECT statement::
+
+ stmt = stmt.columns(id=Integer, name=String).cte('st')
+
+ stmt = select(sometable).where(sometable.c.id == stmt.c.id)
+
+ :param \*cols: A series of :class:`_expression.ColumnElement` objects,
+ typically
+ :class:`_schema.Column` objects from a :class:`_schema.Table`
+ or ORM level
+ column-mapped attributes, representing a set of columns that this
+ textual string will SELECT from.
+
+ :param \**types: A mapping of string names to :class:`.TypeEngine`
+ type objects indicating the datatypes to use for names that are
+ SELECTed from the textual string. Prefer to use the ``*cols``
+ argument as it also indicates positional ordering.
+
+ """
+ selectable = util.preloaded.sql_selectable
+ positional_input_cols = [
+ ColumnClause(col.key, types.pop(col.key))
+ if col.key in types
+ else col
+ for col in cols
+ ]
+ keyed_input_cols = [
+ ColumnClause(key, type_) for key, type_ in types.items()
+ ]
+
+ return selectable.TextualSelect(
+ self,
+ positional_input_cols + keyed_input_cols,
+ positional=bool(positional_input_cols) and not keyed_input_cols,
+ )
+
+ @property
+ def type(self):
+ return type_api.NULLTYPE
+
+ @property
+ def comparator(self):
+ return self.type.comparator_factory(self)
+
+ def self_group(self, against=None):
+ if against is operators.in_op:
+ return Grouping(self)
+ else:
+ return self
+
+
+class Null(SingletonConstant, roles.ConstExprRole, ColumnElement):
+ """Represent the NULL keyword in a SQL statement.
+
+ :class:`.Null` is accessed as a constant via the
+ :func:`.null` function.
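+
+ E.g., assuming a table ``t``, comparing a column to :func:`.null`
+ typically renders an ``IS NULL`` test::
+
+ from sqlalchemy import null, select
+
+ stmt = select(t.c.x).where(t.c.x == null())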
+
+ """
+
+ __visit_name__ = "null"
+
+ _traverse_internals = []
+
+ @util.memoized_property
+ def type(self):
+ return type_api.NULLTYPE
+
+ @classmethod
+ def _instance(cls):
+ """Return a constant :class:`.Null` construct."""
+
+ return Null()
+
+
+Null._create_singleton()
+
+
+class False_(SingletonConstant, roles.ConstExprRole, ColumnElement):
+ """Represent the ``false`` keyword, or equivalent, in a SQL statement.
+
+ :class:`.False_` is accessed as a constant via the
+ :func:`.false` function.
+
+ """
+
+ __visit_name__ = "false"
+ _traverse_internals = []
+
+ @util.memoized_property
+ def type(self):
+ return type_api.BOOLEANTYPE
+
+ def _negate(self):
+ return True_()
+
+ @classmethod
+ def _instance(cls):
+ """Return a :class:`.False_` construct.
+
+ E.g.::
+
+ >>> from sqlalchemy import false
+ >>> print(select(t.c.x).where(false()))
+ SELECT x FROM t WHERE false
+
+ A backend which does not support true/false constants will render as
+ an expression against 1 or 0::
+
+ >>> print(select(t.c.x).where(false()))
+ SELECT x FROM t WHERE 0 = 1
+
+ The :func:`.true` and :func:`.false` constants also feature
+ "short circuit" operation within an :func:`.and_` or :func:`.or_`
+ conjunction::
+
+ >>> print(select(t.c.x).where(or_(t.c.x > 5, true())))
+ SELECT x FROM t WHERE true
+
+ >>> print(select(t.c.x).where(and_(t.c.x > 5, false())))
+ SELECT x FROM t WHERE false
+
+ .. versionchanged:: 0.9 :func:`.true` and :func:`.false` feature
+ better integrated behavior within conjunctions and on dialects
+ that don't support true/false constants.
+
+ .. seealso::
+
+ :func:`.true`
+
+ """
+
+ return False_()
+
+
+False_._create_singleton()
+
+
+class True_(SingletonConstant, roles.ConstExprRole, ColumnElement):
+ """Represent the ``true`` keyword, or equivalent, in a SQL statement.
+
+ :class:`.True_` is accessed as a constant via the
+ :func:`.true` function.
+
+ """
+
+ __visit_name__ = "true"
+
+ _traverse_internals = []
+
+ @util.memoized_property
+ def type(self):
+ return type_api.BOOLEANTYPE
+
+ def _negate(self):
+ return False_()
+
+ @classmethod
+ def _ifnone(cls, other):
+ if other is None:
+ return cls._instance()
+ else:
+ return other
+
+ @classmethod
+ def _instance(cls):
+ """Return a constant :class:`.True_` construct.
+
+ E.g.::
+
+ >>> from sqlalchemy import true
+ >>> print(select(t.c.x).where(true()))
+ SELECT x FROM t WHERE true
+
+ A backend which does not support true/false constants will render as
+ an expression against 1 or 0::
+
+ >>> print(select(t.c.x).where(true()))
+ SELECT x FROM t WHERE 1 = 1
+
+ The :func:`.true` and :func:`.false` constants also feature
+ "short circuit" operation within an :func:`.and_` or :func:`.or_`
+ conjunction::
+
+ >>> print(select(t.c.x).where(or_(t.c.x > 5, true())))
+ SELECT x FROM t WHERE true
+
+ >>> print(select(t.c.x).where(and_(t.c.x > 5, false())))
+ SELECT x FROM t WHERE false
+
+ .. versionchanged:: 0.9 :func:`.true` and :func:`.false` feature
+ better integrated behavior within conjunctions and on dialects
+ that don't support true/false constants.
+
+ .. seealso::
+
+ :func:`.false`
+
+ """
+
+ return True_()
+
+
+True_._create_singleton()
+
+
+class ClauseList(
+ roles.InElementRole,
+ roles.OrderByRole,
+ roles.ColumnsClauseRole,
+ roles.DMLColumnRole,
+ ClauseElement,
+):
+ """Describe a list of clauses, separated by an operator.
+
+ By default, the list is comma-separated, such as in a column listing.
+
+ """
+
+ __visit_name__ = "clauselist"
+
+ _is_clause_list = True
+
+ _traverse_internals = [
+ ("clauses", InternalTraversal.dp_clauseelement_list),
+ ("operator", InternalTraversal.dp_operator),
+ ]
+
+ def __init__(self, *clauses, **kwargs):
+ self.operator = kwargs.pop("operator", operators.comma_op)
+ self.group = kwargs.pop("group", True)
+ self.group_contents = kwargs.pop("group_contents", True)
+ if kwargs.pop("_flatten_sub_clauses", False):
+ clauses = util.flatten_iterator(clauses)
+ self._text_converter_role = text_converter_role = kwargs.pop(
+ "_literal_as_text_role", roles.WhereHavingRole
+ )
+ if self.group_contents:
+ self.clauses = [
+ coercions.expect(
+ text_converter_role, clause, apply_propagate_attrs=self
+ ).self_group(against=self.operator)
+ for clause in clauses
+ ]
+ else:
+ self.clauses = [
+ coercions.expect(
+ text_converter_role, clause, apply_propagate_attrs=self
+ )
+ for clause in clauses
+ ]
+ self._is_implicitly_boolean = operators.is_boolean(self.operator)
+
+ @classmethod
+ def _construct_raw(cls, operator, clauses=None):
+ self = cls.__new__(cls)
+ self.clauses = clauses if clauses else []
+ self.group = True
+ self.operator = operator
+ self.group_contents = True
+ self._is_implicitly_boolean = False
+ return self
+
+ def __iter__(self):
+ return iter(self.clauses)
+
+ def __len__(self):
+ return len(self.clauses)
+
+ @property
+ def _select_iterable(self):
+ return itertools.chain.from_iterable(
+ [elem._select_iterable for elem in self.clauses]
+ )
+
+ def append(self, clause):
+ if self.group_contents:
+ self.clauses.append(
+ coercions.expect(self._text_converter_role, clause).self_group(
+ against=self.operator
+ )
+ )
+ else:
+ self.clauses.append(
+ coercions.expect(self._text_converter_role, clause)
+ )
+
+ @property
+ def _from_objects(self):
+ return list(itertools.chain(*[c._from_objects for c in self.clauses]))
+
+ def self_group(self, against=None):
+ if self.group and operators.is_precedent(self.operator, against):
+ return Grouping(self)
+ else:
+ return self
+
+
+class BooleanClauseList(ClauseList, ColumnElement):
+ __visit_name__ = "clauselist"
+ inherit_cache = True
+
+ def __init__(self, *arg, **kw):
+ raise NotImplementedError(
+ "BooleanClauseList has a private constructor"
+ )
+
+ @classmethod
+ def _process_clauses_for_boolean(
+ cls, operator, continue_on, skip_on, clauses
+ ):
+ has_continue_on = None
+
+ convert_clauses = []
+
+ against = operators._asbool
+ lcc = 0
+
+ for clause in clauses:
+ if clause is continue_on:
+ # instance of continue_on, like and_(x, y, True, z), store it
+ # if we didn't find one already, we will use it if there
+ # are no other expressions here.
+ has_continue_on = clause
+ elif clause is skip_on:
+ # instance of skip_on, e.g. and_(x, y, False, z), cancels
+ # the rest out
+ convert_clauses = [clause]
+ lcc = 1
+ break
+ else:
+ if not lcc:
+ lcc = 1
+ else:
+ against = operator
+ # technically this would be len(convert_clauses) + 1
+ # however this only needs to indicate "greater than one"
+ lcc = 2
+ convert_clauses.append(clause)
+
+ if not convert_clauses and has_continue_on is not None:
+ convert_clauses = [has_continue_on]
+ lcc = 1
+
+ return lcc, [c.self_group(against=against) for c in convert_clauses]
+
+ @classmethod
+ def _construct(cls, operator, continue_on, skip_on, *clauses, **kw):
+ lcc, convert_clauses = cls._process_clauses_for_boolean(
+ operator,
+ continue_on,
+ skip_on,
+ [
+ coercions.expect(roles.WhereHavingRole, clause)
+ for clause in util.coerce_generator_arg(clauses)
+ ],
+ )
+
+ if lcc > 1:
+ # multiple elements. Return regular BooleanClauseList
+ # which will link elements against the operator.
+ return cls._construct_raw(operator, convert_clauses)
+ elif lcc == 1:
+ # just one element. return it as a single boolean element,
+ # not a list and discard the operator.
+ return convert_clauses[0]
+ else:
+ # no elements period. deprecated use case. return an empty
+ # ClauseList construct that generates nothing unless it has
+ # elements added to it.
+ util.warn_deprecated(
+ "Invoking %(name)s() without arguments is deprecated, and "
+ "will be disallowed in a future release. For an empty "
+ "%(name)s() construct, use %(name)s(%(continue_on)s, *args)."
+ % {
+ "name": operator.__name__,
+ "continue_on": "True"
+ if continue_on is True_._singleton
+ else "False",
+ },
+ version="1.4",
+ )
+ return cls._construct_raw(operator)
+
+ @classmethod
+ def _construct_for_whereclause(cls, clauses):
+ operator, continue_on, skip_on = (
+ operators.and_,
+ True_._singleton,
+ False_._singleton,
+ )
+
+ lcc, convert_clauses = cls._process_clauses_for_boolean(
+ operator,
+ continue_on,
+ skip_on,
+ clauses, # these are assumed to be coerced already
+ )
+
+ if lcc > 1:
+ # multiple elements. Return regular BooleanClauseList
+ # which will link elements against the operator.
+ return cls._construct_raw(operator, convert_clauses)
+ elif lcc == 1:
+ # just one element. return it as a single boolean element,
+ # not a list and discard the operator.
+ return convert_clauses[0]
+ else:
+ return None
+
+ @classmethod
+ def _construct_raw(cls, operator, clauses=None):
+ self = cls.__new__(cls)
+ self.clauses = clauses if clauses else []
+ self.group = True
+ self.operator = operator
+ self.group_contents = True
+ self.type = type_api.BOOLEANTYPE
+ self._is_implicitly_boolean = True
+ return self
+
+ @classmethod
+ def and_(cls, *clauses):
+ r"""Produce a conjunction of expressions joined by ``AND``.
+
+ E.g.::
+
+ from sqlalchemy import and_
+
+ stmt = select(users_table).where(
+ and_(
+ users_table.c.name == 'wendy',
+ users_table.c.enrolled == True
+ )
+ )
+
+ The :func:`.and_` conjunction is also available using the
+ Python ``&`` operator (though note that compound expressions
+ need to be parenthesized in order to function with Python
+ operator precedence behavior)::
+
+ stmt = select(users_table).where(
+ (users_table.c.name == 'wendy') &
+ (users_table.c.enrolled == True)
+ )
+
+ The :func:`.and_` operation is also implicit in some cases;
+ the :meth:`_expression.Select.where`
+ method for example can be invoked multiple
+ times against a statement, which will have the effect of each
+ clause being combined using :func:`.and_`::
+
+ stmt = select(users_table).\
+ where(users_table.c.name == 'wendy').\
+ where(users_table.c.enrolled == True)
+
+ The :func:`.and_` construct must be given at least one positional
+ argument in order to be valid; an :func:`.and_` construct with no
+ arguments is ambiguous. To produce an "empty" or dynamically
+ generated :func:`.and_` expression, from a given list of expressions,
+ a "default" element of ``True`` should be specified::
+
+ criteria = and_(True, *expressions)
+
+ The above expression will compile to SQL as the expression ``true``
+ or ``1 = 1``, depending on backend, if no other expressions are
+ present. If expressions are present, then the ``True`` value is
+ ignored as it does not affect the outcome of an AND expression that
+ has other elements.
+
+ .. deprecated:: 1.4 The :func:`.and_` element now requires that at
+ least one argument is passed; creating the :func:`.and_` construct
+ with no arguments is deprecated, and will emit a deprecation warning
+ while continuing to produce a blank SQL string.
+
+ .. seealso::
+
+ :func:`.or_`
+
+ """
+ return cls._construct(
+ operators.and_, True_._singleton, False_._singleton, *clauses
+ )
+
+ @classmethod
+ def or_(cls, *clauses):
+ """Produce a conjunction of expressions joined by ``OR``.
+
+ E.g.::
+
+ from sqlalchemy import or_
+
+ stmt = select(users_table).where(
+ or_(
+ users_table.c.name == 'wendy',
+ users_table.c.name == 'jack'
+ )
+ )
+
+ The :func:`.or_` conjunction is also available using the
+ Python ``|`` operator (though note that compound expressions
+ need to be parenthesized in order to function with Python
+ operator precedence behavior)::
+
+ stmt = select(users_table).where(
+ (users_table.c.name == 'wendy') |
+ (users_table.c.name == 'jack')
+ )
+
+ The :func:`.or_` construct must be given at least one positional
+ argument in order to be valid; an :func:`.or_` construct with no
+ arguments is ambiguous. To produce an "empty" or dynamically
+ generated :func:`.or_` expression, from a given list of expressions,
+ a "default" element of ``False`` should be specified::
+
+ or_criteria = or_(False, *expressions)
+
+ The above expression will compile to SQL as the expression ``false``
+ or ``0 = 1``, depending on backend, if no other expressions are
+ present. If expressions are present, then the ``False`` value is
+ ignored as it does not affect the outcome of an OR expression which
+ has other elements.
+
+ .. deprecated:: 1.4 The :func:`.or_` element now requires that at
+ least one argument is passed; creating the :func:`.or_` construct
+ with no arguments is deprecated, and will emit a deprecation warning
+ while continuing to produce a blank SQL string.
+
+ .. seealso::
+
+ :func:`.and_`
+
+ """
+ return cls._construct(
+ operators.or_, False_._singleton, True_._singleton, *clauses
+ )
+
+ @property
+ def _select_iterable(self):
+ return (self,)
+
+ def self_group(self, against=None):
+ if not self.clauses:
+ return self
+ else:
+ return super(BooleanClauseList, self).self_group(against=against)
+
+ def _negate(self):
+ return ClauseList._negate(self)
+
+
+and_ = BooleanClauseList.and_
+or_ = BooleanClauseList.or_
+
+
+class Tuple(ClauseList, ColumnElement):
+ """Represent a SQL tuple."""
+
+ __visit_name__ = "tuple"
+
+ _traverse_internals = ClauseList._traverse_internals + []
+
+ @util.preload_module("sqlalchemy.sql.sqltypes")
+ def __init__(self, *clauses, **kw):
+ """Return a :class:`.Tuple`.
+
+ Main usage is to produce a composite IN construct using
+ :meth:`.ColumnOperators.in_` ::
+
+ from sqlalchemy import tuple_
+
+ tuple_(table.c.col1, table.c.col2).in_(
+ [(1, 2), (5, 12), (10, 19)]
+ )
+
+ .. versionchanged:: 1.3.6 Added support for SQLite IN tuples.
+
+ .. warning::
+
+ The composite IN construct is not supported by all backends, and is
+ currently known to work on PostgreSQL, MySQL, and SQLite.
+ Unsupported backends will raise a subclass of
+ :class:`~sqlalchemy.exc.DBAPIError` when such an expression is
+ invoked.
+
+ """
+ sqltypes = util.preloaded.sql_sqltypes
+
+ types = kw.pop("types", None)
+ if types is None:
+ clauses = [
+ coercions.expect(roles.ExpressionElementRole, c)
+ for c in clauses
+ ]
+ else:
+ if len(types) != len(clauses):
+ raise exc.ArgumentError(
+ "Wrong number of elements for %d-tuple: %r "
+ % (len(types), clauses)
+ )
+ clauses = [
+ coercions.expect(
+ roles.ExpressionElementRole,
+ c,
+ type_=typ if not typ._isnull else None,
+ )
+ for typ, c in zip(types, clauses)
+ ]
+
+ self.type = sqltypes.TupleType(*[arg.type for arg in clauses])
+ super(Tuple, self).__init__(*clauses, **kw)
+
+ @property
+ def _select_iterable(self):
+ return (self,)
+
+ def _bind_param(self, operator, obj, type_=None, expanding=False):
+ if expanding:
+ return BindParameter(
+ None,
+ value=obj,
+ _compared_to_operator=operator,
+ unique=True,
+ expanding=True,
+ type_=self.type,
+ )
+ else:
+ return Tuple(
+ *[
+ BindParameter(
+ None,
+ o,
+ _compared_to_operator=operator,
+ _compared_to_type=compared_to_type,
+ unique=True,
+ type_=type_,
+ )
+ for o, compared_to_type in zip(obj, self.type.types)
+ ]
+ )
+
+ def self_group(self, against=None):
+ # Tuple is parenthesized by definition.
+ return self
+
+
+class Case(ColumnElement):
+ """Represent a ``CASE`` expression.
+
+ :class:`.Case` is produced using the :func:`.case` factory function,
+ as in::
+
+ from sqlalchemy import case
+
+ stmt = select(users_table).\
+ where(
+ case(
+ (users_table.c.name == 'wendy', 'W'),
+ (users_table.c.name == 'jack', 'J'),
+ else_='E'
+ )
+ )
+
+ Details on :class:`.Case` usage are at :func:`.case`.
+
+ .. seealso::
+
+ :func:`.case`
+
+ """
+
+ __visit_name__ = "case"
+
+ _traverse_internals = [
+ ("value", InternalTraversal.dp_clauseelement),
+ ("whens", InternalTraversal.dp_clauseelement_tuples),
+ ("else_", InternalTraversal.dp_clauseelement),
+ ]
+
+ # TODO: for Py2k removal, this will be:
+ # def __init__(self, *whens, value=None, else_=None):
+
+ def __init__(self, *whens, **kw):
+ r"""Produce a ``CASE`` expression.
+
+ The ``CASE`` construct in SQL is a conditional object that
+ acts somewhat analogously to an "if/then" construct in other
+ languages. It returns an instance of :class:`.Case`.
+
+ :func:`.case` in its usual form is passed a series of "when"
+ constructs, that is, a list of conditions and results as tuples::
+
+ from sqlalchemy import case
+
+ stmt = select(users_table).\
+ where(
+ case(
+ (users_table.c.name == 'wendy', 'W'),
+ (users_table.c.name == 'jack', 'J'),
+ else_='E'
+ )
+ )
+
+ The above statement will produce SQL resembling::
+
+ SELECT id, name FROM user
+ WHERE CASE
+ WHEN (name = :name_1) THEN :param_1
+ WHEN (name = :name_2) THEN :param_2
+ ELSE :param_3
+ END
+
+ When simple equality expressions of several values against a single
+ parent column are needed, :func:`.case` also has a "shorthand" format
+ used via the
+ :paramref:`.case.value` parameter, which is passed a column
+ expression to be compared. In this form, the :paramref:`.case.whens`
+ parameter is passed as a dictionary containing expressions to be
+ compared against keyed to result expressions. The statement below is
+ equivalent to the preceding statement::
+
+ stmt = select(users_table).\
+ where(
+ case(
+ {"wendy": "W", "jack": "J"},
+ value=users_table.c.name,
+ else_='E'
+ )
+ )
+
+ The values which are accepted as result values in
+ :paramref:`.case.whens` as well as with :paramref:`.case.else_` are
+ coerced from Python literals into :func:`.bindparam` constructs.
+ SQL expressions, e.g. :class:`_expression.ColumnElement` constructs,
+ are accepted
+ as well. To coerce a literal string expression into a constant
+ expression rendered inline, use the :func:`_expression.literal_column`
+ construct,
+ as in::
+
+ from sqlalchemy import case, literal_column
+
+ case(
+ (
+ orderline.c.qty > 100,
+ literal_column("'greaterthan100'")
+ ),
+ (
+ orderline.c.qty > 10,
+ literal_column("'greaterthan10'")
+ ),
+ else_=literal_column("'lessthan10'")
+ )
+
+ The above will render the given constants without using bound
+ parameters for the result values (but still for the comparison
+ values), as in::
+
+ CASE
+ WHEN (orderline.qty > :qty_1) THEN 'greaterthan100'
+ WHEN (orderline.qty > :qty_2) THEN 'greaterthan10'
+ ELSE 'lessthan10'
+ END
+
+ :param \*whens: The criteria to be compared against,
+ :paramref:`.case.whens` accepts two different forms, based on
+ whether or not :paramref:`.case.value` is used.
+
+ .. versionchanged:: 1.4 the :func:`_sql.case`
+ function now accepts the series of WHEN conditions positionally;
+ passing the expressions within a list is deprecated.
+
+ In the first form, it accepts a list of 2-tuples; each 2-tuple
+ consists of ``(<sql expression>, <value>)``, where the SQL
+ expression is a boolean expression and "value" is a resulting value,
+ e.g.::
+
+ case(
+ (users_table.c.name == 'wendy', 'W'),
+ (users_table.c.name == 'jack', 'J')
+ )
+
+ In the second form, it accepts a Python dictionary of comparison
+ values mapped to a resulting value; this form requires
+ :paramref:`.case.value` to be present, and values will be compared
+ using the ``==`` operator, e.g.::
+
+ case(
+ {"wendy": "W", "jack": "J"},
+ value=users_table.c.name
+ )
+
+ :param value: An optional SQL expression which will be used as a
+ fixed "comparison point" for candidate values within a dictionary
+ passed to :paramref:`.case.whens`.
+
+ :param else\_: An optional SQL expression which will be the evaluated
+ result of the ``CASE`` construct if all expressions within
+ :paramref:`.case.whens` evaluate to false. When omitted, most
+ databases will produce a result of NULL if none of the "when"
+ expressions evaluate to true.
+
+
+ """
+
+ if "whens" in kw:
+ util.warn_deprecated_20(
+ 'The "whens" argument to case() is now passed using '
+ "positional style only, not as a keyword argument."
+ )
+ whens = (kw.pop("whens"),)
+
+ whens = coercions._expression_collection_was_a_list(
+ "whens", "case", whens
+ )
+
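+ # the "shorthand" dictionary form of whens: convert a mapping of
+ # comparison values to result values into (comparison, result) pairs;
+ # a plain sequence of 2-tuples is not dict-like, raises TypeError here,
+ # and is used as given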
+ try:
+ whens = util.dictlike_iteritems(whens)
+ except TypeError:
+ pass
+
+ value = kw.pop("value", None)
+
+ whenlist = [
+ (
+ coercions.expect(
+ roles.ExpressionElementRole,
+ c,
+ apply_propagate_attrs=self,
+ ).self_group(),
+ coercions.expect(roles.ExpressionElementRole, r),
+ )
+ for (c, r) in whens
+ ]
+
+ if whenlist:
+ type_ = list(whenlist[-1])[-1].type
+ else:
+ type_ = None
+
+ if value is None:
+ self.value = None
+ else:
+ self.value = coercions.expect(roles.ExpressionElementRole, value)
+
+ self.type = type_
+ self.whens = whenlist
+
+ else_ = kw.pop("else_", None)
+ if else_ is not None:
+ self.else_ = coercions.expect(roles.ExpressionElementRole, else_)
+ else:
+ self.else_ = None
+
+ if kw:
+ raise TypeError("unknown arguments: %s" % (", ".join(sorted(kw))))
+
+ @property
+ def _from_objects(self):
+ return list(
+ itertools.chain(*[x._from_objects for x in self.get_children()])
+ )
+
+
+def literal_column(text, type_=None):
+ r"""Produce a :class:`.ColumnClause` object that has the
+ :paramref:`_expression.column.is_literal` flag set to True.
+
+ :func:`_expression.literal_column` is similar to
+ :func:`_expression.column`, except that
+ it is more often used as a "standalone" column expression that renders
+ exactly as stated; while :func:`_expression.column`
+ stores a string name that
+ will be assumed to be part of a table and may be quoted as such,
+ :func:`_expression.literal_column` can be that,
+ or any other arbitrary column-oriented
+ expression.
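+
+ E.g., a short sketch selecting a quoted string constant alongside a
+ regular column (``some_table`` is assumed to be an existing
+ :class:`_schema.Table`)::
+
+ from sqlalchemy import literal_column, select
+
+ stmt = select(literal_column("'constant value'"), some_table.c.id)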
+
+ :param text: the text of the expression; can be any SQL expression.
+ Quoting rules will not be applied. To specify a column-name expression
+ which should be subject to quoting rules, use the :func:`column`
+ function.
+
+ :param type\_: an optional :class:`~sqlalchemy.types.TypeEngine`
+ object which will
+ provide result-set translation and additional expression semantics for
+ this column. If left as ``None`` the type will be :class:`.NullType`.
+
+ .. seealso::
+
+ :func:`_expression.column`
+
+ :func:`_expression.text`
+
+ :ref:`sqlexpression_literal_column`
+
+ """
+ return ColumnClause(text, type_=type_, is_literal=True)
+
+
+class Cast(WrapsColumnExpression, ColumnElement):
+ """Represent a ``CAST`` expression.
+
+ :class:`.Cast` is produced using the :func:`.cast` factory function,
+ as in::
+
+ from sqlalchemy import cast, Numeric
+
+ stmt = select(cast(product_table.c.unit_price, Numeric(10, 4)))
+
+ Details on :class:`.Cast` usage are at :func:`.cast`.
+
+ .. seealso::
+
+ :ref:`tutorial_casts`
+
+ :func:`.cast`
+
+ :func:`.type_coerce` - an alternative to CAST that coerces the type
+ on the Python side only, which is often sufficient to generate the
+ correct SQL and data coercion.
+
+ """
+
+ __visit_name__ = "cast"
+
+ _traverse_internals = [
+ ("clause", InternalTraversal.dp_clauseelement),
+ ("typeclause", InternalTraversal.dp_clauseelement),
+ ]
+
+ def __init__(self, expression, type_):
+ r"""Produce a ``CAST`` expression.
+
+ :func:`.cast` returns an instance of :class:`.Cast`.
+
+ E.g.::
+
+ from sqlalchemy import cast, Numeric
+
+ stmt = select(cast(product_table.c.unit_price, Numeric(10, 4)))
+
+ The above statement will produce SQL resembling::
+
+ SELECT CAST(unit_price AS NUMERIC(10, 4)) FROM product
+
+ The :func:`.cast` function performs two distinct functions when
+ used. The first is that it renders the ``CAST`` expression within
+ the resulting SQL string. The second is that it associates the given
+ type (e.g. :class:`.TypeEngine` class or instance) with the column
+ expression on the Python side, which means the expression will take
+ on the expression operator behavior associated with that type,
+ as well as the bound-value handling and result-row-handling behavior
+ of the type.
+
+ .. versionchanged:: 0.9.0 :func:`.cast` now applies the given type
+ to the expression such that it takes effect on the bound-value,
+ e.g. the Python-to-database direction, in addition to the
+ result handling, e.g. database-to-Python, direction.
+
+ An alternative to :func:`.cast` is the :func:`.type_coerce` function.
+ This function performs the second task of associating an expression
+ with a specific type, but does not render the ``CAST`` expression
+ in SQL.
+
+ :param expression: A SQL expression, such as a
+ :class:`_expression.ColumnElement`
+ expression or a Python string which will be coerced into a bound
+ literal value.
+
+ :param type\_: A :class:`.TypeEngine` class or instance indicating
+ the type to which the ``CAST`` should apply.
+
+ .. seealso::
+
+ :ref:`tutorial_casts`
+
+ :func:`.type_coerce` - an alternative to CAST that coerces the type
+ on the Python side only, which is often sufficient to generate the
+ correct SQL and data coercion.
+
+
+ """
+ self.type = type_api.to_instance(type_)
+ self.clause = coercions.expect(
+ roles.ExpressionElementRole,
+ expression,
+ type_=self.type,
+ apply_propagate_attrs=self,
+ )
+ self.typeclause = TypeClause(self.type)
+
+ @property
+ def _from_objects(self):
+ return self.clause._from_objects
+
+ @property
+ def wrapped_column_expression(self):
+ return self.clause
+
+
+class TypeCoerce(WrapsColumnExpression, ColumnElement):
+ """Represent a Python-side type-coercion wrapper.
+
+ :class:`.TypeCoerce` supplies the :func:`_expression.type_coerce`
+ function; see that function for usage details.
+
+ .. versionchanged:: 1.1 The :func:`.type_coerce` function now produces
+ a persistent :class:`.TypeCoerce` wrapper object rather than
+ translating the given object in place.
+
+ .. seealso::
+
+ :func:`_expression.type_coerce`
+
+ :func:`.cast`
+
+ """
+
+ __visit_name__ = "type_coerce"
+
+ _traverse_internals = [
+ ("clause", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ]
+
+ def __init__(self, expression, type_):
+ r"""Associate a SQL expression with a particular type, without rendering
+ ``CAST``.
+
+ E.g.::
+
+ from sqlalchemy import type_coerce
+
+ stmt = select(type_coerce(log_table.date_string, StringDateTime()))
+
+ The above construct will produce a :class:`.TypeCoerce` object, which
+ does not modify the rendering in any way on the SQL side, with the
+ possible exception of a generated label if used in a columns clause
+ context::
+
+ SELECT date_string AS date_string FROM log
+
+ When result rows are fetched, the ``StringDateTime`` type processor
+ will be applied to result rows on behalf of the ``date_string`` column.
+
+ .. note:: the :func:`.type_coerce` construct does not render any
+ SQL syntax of its own, including that it does not imply
+ parenthesization. Please use :meth:`.TypeCoerce.self_group`
+ if explicit parenthesization is required.
+
+ In order to provide a named label for the expression, use
+ :meth:`_expression.ColumnElement.label`::
+
+ stmt = select(
+ type_coerce(log_table.date_string, StringDateTime()).label('date')
+ )
+
+
+ A type that features bound-value handling will also have that behavior
+ take effect when literal values or :func:`.bindparam` constructs are
+ passed to :func:`.type_coerce` as targets.
+ For example, if a type implements the
+ :meth:`.TypeEngine.bind_expression`
+ method or :meth:`.TypeEngine.bind_processor` method or equivalent,
+ these functions will take effect at statement compilation/execution
+ time when a literal value is passed, as in::
+
+ # bound-value handling of MyStringType will be applied to the
+ # literal value "some string"
+ stmt = select(type_coerce("some string", MyStringType))
+
+ When using :func:`.type_coerce` with composed expressions, note that
+ **parentheses are not applied**. If :func:`.type_coerce` is being
+ used in an operator context where the parentheses normally present from
+ CAST are necessary, use the :meth:`.TypeCoerce.self_group` method::
+
+ >>> some_integer = column("someint", Integer)
+ >>> some_string = column("somestr", String)
+ >>> expr = type_coerce(some_integer + 5, String) + some_string
+ >>> print(expr)
+ someint + :someint_1 || somestr
+ >>> expr = type_coerce(some_integer + 5, String).self_group() + some_string
+ >>> print(expr)
+ (someint + :someint_1) || somestr
+
+ :param expression: A SQL expression, such as a
+ :class:`_expression.ColumnElement`
+ expression or a Python string which will be coerced into a bound
+ literal value.
+
+ :param type\_: A :class:`.TypeEngine` class or instance indicating
+ the type to which the expression is coerced.
+
+ .. seealso::
+
+ :ref:`tutorial_casts`
+
+ :func:`.cast`
+
+ """ # noqa
+ self.type = type_api.to_instance(type_)
+ self.clause = coercions.expect(
+ roles.ExpressionElementRole,
+ expression,
+ type_=self.type,
+ apply_propagate_attrs=self,
+ )
+
+ @property
+ def _from_objects(self):
+ return self.clause._from_objects
+
+ @HasMemoized.memoized_attribute
+ def typed_expression(self):
+ if isinstance(self.clause, BindParameter):
+ bp = self.clause._clone()
+ bp.type = self.type
+ return bp
+ else:
+ return self.clause
+
+ @property
+ def wrapped_column_expression(self):
+ return self.clause
+
+ def self_group(self, against=None):
+ grouped = self.clause.self_group(against=against)
+ if grouped is not self.clause:
+ return TypeCoerce(grouped, self.type)
+ else:
+ return self
+
+
+class Extract(ColumnElement):
+ """Represent a SQL EXTRACT clause, ``extract(field FROM expr)``."""
+
+ __visit_name__ = "extract"
+
+ _traverse_internals = [
+ ("expr", InternalTraversal.dp_clauseelement),
+ ("field", InternalTraversal.dp_string),
+ ]
+
+ def __init__(self, field, expr, **kwargs):
+ """Return a :class:`.Extract` construct.
+
+ This is typically available as :func:`.extract`
+ as well as ``func.extract`` from the
+ :data:`.func` namespace.
+
+ :param field: The field to extract.
+
+ :param expr: A column or Python scalar expression serving as the
+ right side of the ``EXTRACT`` expression.
+
+ E.g.::
+
+ from sqlalchemy import extract
+ from sqlalchemy import table, column
+
+ logged_table = table("user",
+ column("id"),
+ column("date_created"),
+ )
+
+ stmt = select(logged_table.c.id).where(
+ extract("YEAR", logged_table.c.date_created) == 2021
+ )
+
+ In the above example, the statement is used to select ids from the
+ database where the ``YEAR`` component matches a specific value.
+
+ Similarly, one can also select an extracted component::
+
+ stmt = select(
+ extract("YEAR", logged_table.c.date_created)
+ ).where(logged_table.c.id == 1)
+
+ The implementation of ``EXTRACT`` may vary across database backends.
+ Users are reminded to consult their database documentation.
+ """
+ self.type = type_api.INTEGERTYPE
+ self.field = field
+ self.expr = coercions.expect(roles.ExpressionElementRole, expr)
+
+ @property
+ def _from_objects(self):
+ return self.expr._from_objects
+
+
+class _label_reference(ColumnElement):
+ """Wrap a column expression as it appears in a 'reference' context.
+
+ This expression is any that includes an _order_by_label_element,
+ which is a Label, or a DESC / ASC construct wrapping a Label.
+
+ The production of _label_reference() should occur when an expression
+ is added to this context; this includes the ORDER BY or GROUP BY of a
+ SELECT statement, as well as a few other places, such as the ORDER BY
+ within an OVER clause.
+
+ """
+
+ __visit_name__ = "label_reference"
+
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+
+ def __init__(self, element):
+ self.element = element
+
+ @property
+ def _from_objects(self):
+ return ()
+
+
+class _textual_label_reference(ColumnElement):
+ __visit_name__ = "textual_label_reference"
+
+ _traverse_internals = [("element", InternalTraversal.dp_string)]
+
+ def __init__(self, element):
+ self.element = element
+
+ @util.memoized_property
+ def _text_clause(self):
+ return TextClause._create_text(self.element)
+
+
+class UnaryExpression(ColumnElement):
+ """Define a 'unary' expression.
+
+ A unary expression has a single column expression
+ and an operator. The operator can be placed on the left
+ (where it is called the 'operator') or right (where it is called the
+ 'modifier') of the column expression.
+
+ :class:`.UnaryExpression` is the basis for several unary operators
+ including those used by :func:`.desc`, :func:`.asc`, :func:`.distinct`,
+ :func:`.nulls_first` and :func:`.nulls_last`.
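+
+ E.g., :func:`.desc` produces a :class:`.UnaryExpression` where the
+ ``DESC`` keyword acts as a "modifier" on the right side of the column
+ expression (``some_table`` is assumed to exist)::
+
+ from sqlalchemy import desc
+
+ # renders as "some_table.name DESC" within an ORDER BY
+ expr = desc(some_table.c.name)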
+
+ """
+
+ __visit_name__ = "unary"
+
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("operator", InternalTraversal.dp_operator),
+ ("modifier", InternalTraversal.dp_operator),
+ ]
+
+ def __init__(
+ self,
+ element,
+ operator=None,
+ modifier=None,
+ type_=None,
+ wraps_column_expression=False,
+ ):
+ self.operator = operator
+ self.modifier = modifier
+ self._propagate_attrs = element._propagate_attrs
+ self.element = element.self_group(
+ against=self.operator or self.modifier
+ )
+ self.type = type_api.to_instance(type_)
+ self.wraps_column_expression = wraps_column_expression
+
+ @classmethod
+ def _create_nulls_first(cls, column):
+ """Produce the ``NULLS FIRST`` modifier for an ``ORDER BY`` expression.
+
+ :func:`.nulls_first` is intended to modify the expression produced
+ by :func:`.asc` or :func:`.desc`, and indicates how NULL values
+ should be handled when they are encountered during ordering::
+
+
+ from sqlalchemy import desc, nulls_first
+
+ stmt = select(users_table).order_by(
+ nulls_first(desc(users_table.c.name)))
+
+ The SQL expression from the above would resemble::
+
+ SELECT id, name FROM user ORDER BY name DESC NULLS FIRST
+
+ Like :func:`.asc` and :func:`.desc`, :func:`.nulls_first` is typically
+ invoked from the column expression itself using
+ :meth:`_expression.ColumnElement.nulls_first`,
+ rather than as its standalone
+ function version, as in::
+
+ stmt = select(users_table).order_by(
+ users_table.c.name.desc().nulls_first())
+
+ .. versionchanged:: 1.4 :func:`.nulls_first` is renamed from
+ :func:`.nullsfirst` in previous releases.
+ The previous name remains available for backwards compatibility.
+
+ .. seealso::
+
+ :func:`.asc`
+
+ :func:`.desc`
+
+ :func:`.nulls_last`
+
+ :meth:`_expression.Select.order_by`
+
+ """
+ return UnaryExpression(
+ coercions.expect(roles.ByOfRole, column),
+ modifier=operators.nulls_first_op,
+ wraps_column_expression=False,
+ )
+
+ @classmethod
+ def _create_nulls_last(cls, column):
+ """Produce the ``NULLS LAST`` modifier for an ``ORDER BY`` expression.
+
+ :func:`.nulls_last` is intended to modify the expression produced
+ by :func:`.asc` or :func:`.desc`, and indicates how NULL values
+ should be handled when they are encountered during ordering::
+
+
+ from sqlalchemy import desc, nulls_last
+
+ stmt = select(users_table).order_by(
+ nulls_last(desc(users_table.c.name)))
+
+ The SQL expression from the above would resemble::
+
+ SELECT id, name FROM user ORDER BY name DESC NULLS LAST
+
+ Like :func:`.asc` and :func:`.desc`, :func:`.nulls_last` is typically
+ invoked from the column expression itself using
+ :meth:`_expression.ColumnElement.nulls_last`,
+ rather than as its standalone
+ function version, as in::
+
+ stmt = select(users_table).order_by(
+ users_table.c.name.desc().nulls_last())
+
+ .. versionchanged:: 1.4 :func:`.nulls_last` is renamed from
+ :func:`.nullslast` in previous releases.
+ The previous name remains available for backwards compatibility.
+
+ .. seealso::
+
+ :func:`.asc`
+
+ :func:`.desc`
+
+ :func:`.nulls_first`
+
+ :meth:`_expression.Select.order_by`
+
+ """
+ return UnaryExpression(
+ coercions.expect(roles.ByOfRole, column),
+ modifier=operators.nulls_last_op,
+ wraps_column_expression=False,
+ )
+
+ @classmethod
+ def _create_desc(cls, column):
+ """Produce a descending ``ORDER BY`` clause element.
+
+ e.g.::
+
+ from sqlalchemy import desc
+
+ stmt = select(users_table).order_by(desc(users_table.c.name))
+
+ will produce SQL as::
+
+ SELECT id, name FROM user ORDER BY name DESC
+
+ The :func:`.desc` function is a standalone version of the
+ :meth:`_expression.ColumnElement.desc`
+ method available on all SQL expressions,
+ e.g.::
+
+
+ stmt = select(users_table).order_by(users_table.c.name.desc())
+
+ :param column: A :class:`_expression.ColumnElement` (e.g.
+ scalar SQL expression)
+ with which to apply the :func:`.desc` operation.
+
+ .. seealso::
+
+ :func:`.asc`
+
+ :func:`.nulls_first`
+
+ :func:`.nulls_last`
+
+ :meth:`_expression.Select.order_by`
+
+ """
+ return UnaryExpression(
+ coercions.expect(roles.ByOfRole, column),
+ modifier=operators.desc_op,
+ wraps_column_expression=False,
+ )
+
+ @classmethod
+ def _create_asc(cls, column):
+ """Produce an ascending ``ORDER BY`` clause element.
+
+ e.g.::
+
+ from sqlalchemy import asc
+ stmt = select(users_table).order_by(asc(users_table.c.name))
+
+ will produce SQL as::
+
+ SELECT id, name FROM user ORDER BY name ASC
+
+ The :func:`.asc` function is a standalone version of the
+ :meth:`_expression.ColumnElement.asc`
+ method available on all SQL expressions,
+ e.g.::
+
+
+ stmt = select(users_table).order_by(users_table.c.name.asc())
+
+ :param column: A :class:`_expression.ColumnElement` (e.g.
+ scalar SQL expression)
+ with which to apply the :func:`.asc` operation.
+
+ .. seealso::
+
+ :func:`.desc`
+
+ :func:`.nulls_first`
+
+ :func:`.nulls_last`
+
+ :meth:`_expression.Select.order_by`
+
+ """
+ return UnaryExpression(
+ coercions.expect(roles.ByOfRole, column),
+ modifier=operators.asc_op,
+ wraps_column_expression=False,
+ )
+
+ @classmethod
+ def _create_distinct(cls, expr):
+ """Produce an column-expression-level unary ``DISTINCT`` clause.
+
+ This applies the ``DISTINCT`` keyword to an individual column
+ expression, and is typically contained within an aggregate function,
+ as in::
+
+ from sqlalchemy import distinct, func
+ stmt = select(func.count(distinct(users_table.c.name)))
+
+ The above would produce an expression resembling::
+
+ SELECT COUNT(DISTINCT name) FROM user
+
+ The :func:`.distinct` function is also available as a column-level
+ method, e.g. :meth:`_expression.ColumnElement.distinct`, as in::
+
+ stmt = select(func.count(users_table.c.name.distinct()))
+
+ The :func:`.distinct` operator is different from the
+ :meth:`_expression.Select.distinct` method of
+ :class:`_expression.Select`,
+ which produces a ``SELECT`` statement
+ with ``DISTINCT`` applied to the result set as a whole,
+ e.g. a ``SELECT DISTINCT`` expression. See that method for further
+ information.
+
+ .. seealso::
+
+ :meth:`_expression.ColumnElement.distinct`
+
+ :meth:`_expression.Select.distinct`
+
+ :data:`.func`
+
+ """
+ expr = coercions.expect(roles.ExpressionElementRole, expr)
+ return UnaryExpression(
+ expr,
+ operator=operators.distinct_op,
+ type_=expr.type,
+ wraps_column_expression=False,
+ )
+
+ @property
+ def _order_by_label_element(self):
+ if self.modifier in (operators.desc_op, operators.asc_op):
+ return self.element._order_by_label_element
+ else:
+ return None
+
+ @property
+ def _from_objects(self):
+ return self.element._from_objects
+
+ def _negate(self):
+ if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
+ return UnaryExpression(
+ self.self_group(against=operators.inv),
+ operator=operators.inv,
+ type_=type_api.BOOLEANTYPE,
+ wraps_column_expression=self.wraps_column_expression,
+ )
+ else:
+ return ClauseElement._negate(self)
+
+ def self_group(self, against=None):
+ if self.operator and operators.is_precedent(self.operator, against):
+ return Grouping(self)
+ else:
+ return self
+
+
+class CollectionAggregate(UnaryExpression):
+ """Forms the basis for right-hand collection operator modifiers
+ ANY and ALL.
+
+ The ANY and ALL keywords are available in different ways on different
+ backends. On PostgreSQL, they only work for an ARRAY type. On
+ MySQL, they only work for subqueries.
+
+ """
+
+ inherit_cache = True
+
+ @classmethod
+ def _create_any(cls, expr):
+ """Produce an ANY expression.
+
+ For dialects such as that of PostgreSQL, this operator applies
+ to usage of the :class:`_types.ARRAY` datatype, for that of
+ MySQL, it may apply to a subquery. e.g.::
+
+ # renders on PostgreSQL:
+ # '5 = ANY (somearray)'
+ expr = 5 == any_(mytable.c.somearray)
+
+ # renders on MySQL:
+ # '5 = ANY (SELECT value FROM table)'
+ expr = 5 == any_(select(table.c.value))
+
+ Comparison to NULL may work using ``None`` or :func:`_sql.null`::
+
+ None == any_(mytable.c.somearray)
+
+ The any_() / all_() operators also feature a special "operand flipping"
+ behavior such that if any_() / all_() are used on the left side of a
+ comparison using a standalone operator such as ``==``, ``!=``, etc.
+ (not including operator methods such as
+ :meth:`_sql.ColumnOperators.is_`) the rendered expression is flipped::
+
+            # would render '5 = ANY (column)'
+ any_(mytable.c.column) == 5
+
+        Or with ``None``; note that this will not perform the usual step of
+        rendering ``IS`` as is normally the case for NULL::
+
+ # would render 'NULL = ANY(somearray)'
+ any_(mytable.c.somearray) == None
+
+ .. versionchanged:: 1.4.26 repaired the use of any_() / all_()
+ comparing to NULL on the right side to be flipped to the left.
+
+ The column-level :meth:`_sql.ColumnElement.any_` method (not to be
+ confused with :class:`_types.ARRAY` level
+ :meth:`_types.ARRAY.Comparator.any`) is shorthand for
+ ``any_(col)``::
+
+            5 == mytable.c.somearray.any_()
+
+ .. seealso::
+
+ :meth:`_sql.ColumnOperators.any_`
+
+ :func:`_expression.all_`
+
+ """
+
+ expr = coercions.expect(roles.ExpressionElementRole, expr)
+
+ expr = expr.self_group()
+ return CollectionAggregate(
+ expr,
+ operator=operators.any_op,
+ type_=type_api.NULLTYPE,
+ wraps_column_expression=False,
+ )
+
+ @classmethod
+ def _create_all(cls, expr):
+ """Produce an ALL expression.
+
+ For dialects such as that of PostgreSQL, this operator applies
+ to usage of the :class:`_types.ARRAY` datatype, for that of
+ MySQL, it may apply to a subquery. e.g.::
+
+ # renders on PostgreSQL:
+ # '5 = ALL (somearray)'
+ expr = 5 == all_(mytable.c.somearray)
+
+ # renders on MySQL:
+ # '5 = ALL (SELECT value FROM table)'
+ expr = 5 == all_(select(table.c.value))
+
+ Comparison to NULL may work using ``None``::
+
+ None == all_(mytable.c.somearray)
+
+ The any_() / all_() operators also feature a special "operand flipping"
+ behavior such that if any_() / all_() are used on the left side of a
+ comparison using a standalone operator such as ``==``, ``!=``, etc.
+ (not including operator methods such as
+ :meth:`_sql.ColumnOperators.is_`) the rendered expression is flipped::
+
+            # would render '5 = ALL (column)'
+ all_(mytable.c.column) == 5
+
+        Or with ``None``; note that this will not perform the usual step of
+        rendering ``IS`` as is normally the case for NULL::
+
+ # would render 'NULL = ALL(somearray)'
+ all_(mytable.c.somearray) == None
+
+ .. versionchanged:: 1.4.26 repaired the use of any_() / all_()
+ comparing to NULL on the right side to be flipped to the left.
+
+ The column-level :meth:`_sql.ColumnElement.all_` method (not to be
+ confused with :class:`_types.ARRAY` level
+ :meth:`_types.ARRAY.Comparator.all`) is shorthand for
+ ``all_(col)``::
+
+ 5 == mytable.c.somearray.all_()
+
+ .. seealso::
+
+ :meth:`_sql.ColumnOperators.all_`
+
+ :func:`_expression.any_`
+
+ """
+ expr = coercions.expect(roles.ExpressionElementRole, expr)
+ expr = expr.self_group()
+ return CollectionAggregate(
+ expr,
+ operator=operators.all_op,
+ type_=type_api.NULLTYPE,
+ wraps_column_expression=False,
+ )
+
+ # operate and reverse_operate are hardwired to
+ # dispatch onto the type comparator directly, so that we can
+ # ensure "reversed" behavior.
+ def operate(self, op, *other, **kwargs):
+ if not operators.is_comparison(op):
+ raise exc.ArgumentError(
+ "Only comparison operators may be used with ANY/ALL"
+ )
+ kwargs["reverse"] = kwargs["_any_all_expr"] = True
+ return self.comparator.operate(operators.mirror(op), *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ # comparison operators should never call reverse_operate
+ assert not operators.is_comparison(op)
+ raise exc.ArgumentError(
+ "Only comparison operators may be used with ANY/ALL"
+ )
+
+
+class AsBoolean(WrapsColumnExpression, UnaryExpression):
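+    # Wraps a column expression so that it stands alone as a boolean:
+    # ``operator`` (typically ``operators.istrue`` / ``operators.isfalse``)
+    # is applied when rendering, and ``negate`` is used when the expression
+    # is negated; _negate() below simply swaps the two.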
+ inherit_cache = True
+
+ def __init__(self, element, operator, negate):
+ self.element = element
+ self.type = type_api.BOOLEANTYPE
+ self.operator = operator
+ self.negate = negate
+ self.modifier = None
+ self.wraps_column_expression = True
+ self._is_implicitly_boolean = element._is_implicitly_boolean
+
+ @property
+ def wrapped_column_expression(self):
+ return self.element
+
+ def self_group(self, against=None):
+ return self
+
+ def _negate(self):
+ if isinstance(self.element, (True_, False_)):
+ return self.element._negate()
+ else:
+ return AsBoolean(self.element, self.negate, self.operator)
+
+
+class BinaryExpression(ColumnElement):
+ """Represent an expression that is ``LEFT <operator> RIGHT``.
+
+ A :class:`.BinaryExpression` is generated automatically
+ whenever two column expressions are used in a Python binary expression::
+
+ >>> from sqlalchemy.sql import column
+ >>> column('a') + column('b')
+ <sqlalchemy.sql.expression.BinaryExpression object at 0x101029dd0>
+ >>> print(column('a') + column('b'))
+ a + b
+
+ """
+
+ __visit_name__ = "binary"
+
+ _traverse_internals = [
+ ("left", InternalTraversal.dp_clauseelement),
+ ("right", InternalTraversal.dp_clauseelement),
+ ("operator", InternalTraversal.dp_operator),
+ ("negate", InternalTraversal.dp_operator),
+ ("modifiers", InternalTraversal.dp_plain_dict),
+ (
+ "type",
+ InternalTraversal.dp_type,
+ ), # affects JSON CAST operators
+ ]
+
+ _is_implicitly_boolean = True
+ """Indicates that any database will know this is a boolean expression
+ even if the database does not have an explicit boolean datatype.
+
+ """
+
+ def __init__(
+ self, left, right, operator, type_=None, negate=None, modifiers=None
+ ):
+ # allow compatibility with libraries that
+ # refer to BinaryExpression directly and pass strings
+ if isinstance(operator, util.string_types):
+ operator = operators.custom_op(operator)
+ self._orig = (left.__hash__(), right.__hash__())
+ self._propagate_attrs = left._propagate_attrs or right._propagate_attrs
+ self.left = left.self_group(against=operator)
+ self.right = right.self_group(against=operator)
+ self.operator = operator
+ self.type = type_api.to_instance(type_)
+ self.negate = negate
+ self._is_implicitly_boolean = operators.is_boolean(operator)
+
+ if modifiers is None:
+ self.modifiers = {}
+ else:
+ self.modifiers = modifiers
+
+ def __bool__(self):
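+        # ``col == other`` between ClauseElements produces a BinaryExpression
+        # rather than a plain bool; truth-testing such an expression compares
+        # the hashes of the original operands captured in __init__
+        # (self._orig), so the result reflects operand identity, not SQL
+        # evaluation.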
+ if self.operator in (operator.eq, operator.ne):
+ return self.operator(*self._orig)
+ else:
+ raise TypeError("Boolean value of this clause is not defined")
+
+ __nonzero__ = __bool__
+
+ @property
+ def is_comparison(self):
+ return operators.is_comparison(self.operator)
+
+ @property
+ def _from_objects(self):
+ return self.left._from_objects + self.right._from_objects
+
+ def self_group(self, against=None):
+
+ if operators.is_precedent(self.operator, against):
+ return Grouping(self)
+ else:
+ return self
+
+ def _negate(self):
+ if self.negate is not None:
+ return BinaryExpression(
+ self.left,
+ self.right._negate_in_binary(self.negate, self.operator),
+ self.negate,
+ negate=self.operator,
+ type_=self.type,
+ modifiers=self.modifiers,
+ )
+ else:
+ return super(BinaryExpression, self)._negate()
+
+
+class Slice(ColumnElement):
+ """Represent SQL for a Python array-slice object.
+
+ This is not a specific SQL construct at this level, but
+ may be interpreted by specific dialects, e.g. PostgreSQL.
+
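+    For example, indexing a PostgreSQL :class:`_types.ARRAY` column with a
+    Python slice (an illustrative sketch; ``mytable`` stands in for any
+    table construct)::
+
+        expr = mytable.c.somearray[2:5]
+
+    produces a :class:`.Slice` which the PostgreSQL dialect renders along
+    the lines of ``somearray[2:5]``, with the bounds sent as bound
+    parameters.
+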
+ """
+
+ __visit_name__ = "slice"
+
+ _traverse_internals = [
+ ("start", InternalTraversal.dp_clauseelement),
+ ("stop", InternalTraversal.dp_clauseelement),
+ ("step", InternalTraversal.dp_clauseelement),
+ ]
+
+ def __init__(self, start, stop, step, _name=None):
+ self.start = coercions.expect(
+ roles.ExpressionElementRole,
+ start,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
+ self.stop = coercions.expect(
+ roles.ExpressionElementRole,
+ stop,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
+ self.step = coercions.expect(
+ roles.ExpressionElementRole,
+ step,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
+ self.type = type_api.NULLTYPE
+
+ def self_group(self, against=None):
+ assert against is operator.getitem
+ return self
+
+
+class IndexExpression(BinaryExpression):
+ """Represent the class of expressions that are like an "index"
+ operation."""
+
+ inherit_cache = True
+
+
+class GroupedElement(ClauseElement):
+ """Represent any parenthesized expression"""
+
+ __visit_name__ = "grouping"
+
+ def self_group(self, against=None):
+ return self
+
+ def _ungroup(self):
+ return self.element._ungroup()
+
+
+class Grouping(GroupedElement, ColumnElement):
+ """Represent a grouping within a column expression"""
+
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ]
+
+ def __init__(self, element):
+ self.element = element
+ self.type = getattr(element, "type", type_api.NULLTYPE)
+
+ def _with_binary_element_type(self, type_):
+ return self.__class__(self.element._with_binary_element_type(type_))
+
+ @util.memoized_property
+ def _is_implicitly_boolean(self):
+ return self.element._is_implicitly_boolean
+
+ @property
+ def _tq_label(self):
+ return (
+ getattr(self.element, "_tq_label", None) or self._anon_name_label
+ )
+
+ @property
+ def _proxies(self):
+ if isinstance(self.element, ColumnElement):
+ return [self.element]
+ else:
+ return []
+
+ @property
+ def _from_objects(self):
+ return self.element._from_objects
+
+ def __getattr__(self, attr):
+ return getattr(self.element, attr)
+
+ def __getstate__(self):
+ return {"element": self.element, "type": self.type}
+
+ def __setstate__(self, state):
+ self.element = state["element"]
+ self.type = state["type"]
+
+
+RANGE_UNBOUNDED = util.symbol("RANGE_UNBOUNDED")
+RANGE_CURRENT = util.symbol("RANGE_CURRENT")
+
+
+class Over(ColumnElement):
+ """Represent an OVER clause.
+
+ This is a special operator against a so-called
+ "window" function, as well as any aggregate function,
+ which produces results relative to the result set
+ itself. Most modern SQL backends now support window functions.
+
+ """
+
+ __visit_name__ = "over"
+
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("order_by", InternalTraversal.dp_clauseelement),
+ ("partition_by", InternalTraversal.dp_clauseelement),
+ ("range_", InternalTraversal.dp_plain_obj),
+ ("rows", InternalTraversal.dp_plain_obj),
+ ]
+
+ order_by = None
+ partition_by = None
+
+ element = None
+ """The underlying expression object to which this :class:`.Over`
+ object refers towards."""
+
+ def __init__(
+ self, element, partition_by=None, order_by=None, range_=None, rows=None
+ ):
+ r"""Produce an :class:`.Over` object against a function.
+
+ Used against aggregate or so-called "window" functions,
+ for database backends that support window functions.
+
+ :func:`_expression.over` is usually called using
+ the :meth:`.FunctionElement.over` method, e.g.::
+
+ func.row_number().over(order_by=mytable.c.some_column)
+
+ Would produce::
+
+ ROW_NUMBER() OVER(ORDER BY some_column)
+
+ Ranges are also possible using the :paramref:`.expression.over.range_`
+ and :paramref:`.expression.over.rows` parameters. These
+ mutually-exclusive parameters each accept a 2-tuple, which contains
+ a combination of integers and None::
+
+ func.row_number().over(
+ order_by=my_table.c.some_column, range_=(None, 0))
+
+ The above would produce::
+
+ ROW_NUMBER() OVER(ORDER BY some_column
+ RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
+
+ A value of ``None`` indicates "unbounded", a
+ value of zero indicates "current row", and negative / positive
+ integers indicate "preceding" and "following":
+
+ * RANGE BETWEEN 5 PRECEDING AND 10 FOLLOWING::
+
+ func.row_number().over(order_by='x', range_=(-5, 10))
+
+ * ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW::
+
+ func.row_number().over(order_by='x', rows=(None, 0))
+
+ * RANGE BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING::
+
+ func.row_number().over(order_by='x', range_=(-2, None))
+
+ * RANGE BETWEEN 1 FOLLOWING AND 3 FOLLOWING::
+
+ func.row_number().over(order_by='x', range_=(1, 3))
+
+ .. versionadded:: 1.1 support for RANGE / ROWS within a window
+
+
+ :param element: a :class:`.FunctionElement`, :class:`.WithinGroup`,
+ or other compatible construct.
+ :param partition_by: a column element or string, or a list
+ of such, that will be used as the PARTITION BY clause
+ of the OVER construct.
+ :param order_by: a column element or string, or a list
+ of such, that will be used as the ORDER BY clause
+ of the OVER construct.
+ :param range\_: optional range clause for the window. This is a
+ tuple value which can contain integer values or ``None``,
+ and will render a RANGE BETWEEN PRECEDING / FOLLOWING clause.
+
+ .. versionadded:: 1.1
+
+ :param rows: optional rows clause for the window. This is a tuple
+ value which can contain integer values or None, and will render
+ a ROWS BETWEEN PRECEDING / FOLLOWING clause.
+
+ .. versionadded:: 1.1
+
+ This function is also available from the :data:`~.expression.func`
+ construct itself via the :meth:`.FunctionElement.over` method.
+
+ .. seealso::
+
+ :ref:`tutorial_window_functions` - in the :ref:`unified_tutorial`
+
+ :data:`.expression.func`
+
+ :func:`_expression.within_group`
+
+ """
+ self.element = element
+ if order_by is not None:
+ self.order_by = ClauseList(
+ *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole
+ )
+ if partition_by is not None:
+ self.partition_by = ClauseList(
+ *util.to_list(partition_by),
+ _literal_as_text_role=roles.ByOfRole
+ )
+
+ if range_:
+ self.range_ = self._interpret_range(range_)
+ if rows:
+ raise exc.ArgumentError(
+ "'range_' and 'rows' are mutually exclusive"
+ )
+ else:
+ self.rows = None
+ elif rows:
+ self.rows = self._interpret_range(rows)
+ self.range_ = None
+ else:
+ self.rows = self.range_ = None
+
+ def __reduce__(self):
+ return self.__class__, (
+ self.element,
+ self.partition_by,
+ self.order_by,
+ self.range_,
+ self.rows,
+ )
+
+ def _interpret_range(self, range_):
+ if not isinstance(range_, tuple) or len(range_) != 2:
+ raise exc.ArgumentError("2-tuple expected for range/rows")
+
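+        # normalize each bound of the 2-tuple: None -> RANGE_UNBOUNDED,
+        # 0 -> RANGE_CURRENT; any other integer passes through and is later
+        # rendered as a PRECEDING (negative) / FOLLOWING (positive) offset.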
+ if range_[0] is None:
+ lower = RANGE_UNBOUNDED
+ else:
+ try:
+ lower = int(range_[0])
+ except ValueError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "Integer or None expected for range value"
+ ),
+ replace_context=err,
+ )
+ else:
+ if lower == 0:
+ lower = RANGE_CURRENT
+
+ if range_[1] is None:
+ upper = RANGE_UNBOUNDED
+ else:
+ try:
+ upper = int(range_[1])
+ except ValueError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "Integer or None expected for range value"
+ ),
+ replace_context=err,
+ )
+ else:
+ if upper == 0:
+ upper = RANGE_CURRENT
+
+ return lower, upper
+
+ @util.memoized_property
+ def type(self):
+ return self.element.type
+
+ @property
+ def _from_objects(self):
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.element, self.partition_by, self.order_by)
+ if c is not None
+ ]
+ )
+ )
+
+
+class WithinGroup(ColumnElement):
+ """Represent a WITHIN GROUP (ORDER BY) clause.
+
+ This is a special operator against so-called
+ "ordered set aggregate" and "hypothetical
+ set aggregate" functions, including ``percentile_cont()``,
+ ``rank()``, ``dense_rank()``, etc.
+
+ It's supported only by certain database backends, such as PostgreSQL,
+ Oracle and MS SQL Server.
+
+ The :class:`.WithinGroup` construct extracts its type from the
+ method :meth:`.FunctionElement.within_group_type`. If this returns
+ ``None``, the function's ``.type`` is used.
+
+ """
+
+ __visit_name__ = "withingroup"
+
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("order_by", InternalTraversal.dp_clauseelement),
+ ]
+
+ order_by = None
+
+ def __init__(self, element, *order_by):
+ r"""Produce a :class:`.WithinGroup` object against a function.
+
+ Used against so-called "ordered set aggregate" and "hypothetical
+ set aggregate" functions, including :class:`.percentile_cont`,
+ :class:`.rank`, :class:`.dense_rank`, etc.
+
+ :func:`_expression.within_group` is usually called using
+ the :meth:`.FunctionElement.within_group` method, e.g.::
+
+ from sqlalchemy import within_group
+ stmt = select(
+ department.c.id,
+ func.percentile_cont(0.5).within_group(
+ department.c.salary.desc()
+ )
+ )
+
+ The above statement would produce SQL similar to
+ ``SELECT department.id, percentile_cont(0.5)
+ WITHIN GROUP (ORDER BY department.salary DESC)``.
+
+ :param element: a :class:`.FunctionElement` construct, typically
+ generated by :data:`~.expression.func`.
+ :param \*order_by: one or more column elements that will be used
+ as the ORDER BY clause of the WITHIN GROUP construct.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`tutorial_functions_within_group` - in the
+ :ref:`unified_tutorial`
+
+ :data:`.expression.func`
+
+ :func:`_expression.over`
+
+ """
+ self.element = element
+ if order_by is not None:
+ self.order_by = ClauseList(
+ *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole
+ )
+
+ def __reduce__(self):
+ return self.__class__, (self.element,) + tuple(self.order_by)
+
+ def over(self, partition_by=None, order_by=None, range_=None, rows=None):
+ """Produce an OVER clause against this :class:`.WithinGroup`
+ construct.
+
+ This function has the same signature as that of
+ :meth:`.FunctionElement.over`.
+
+ """
+ return Over(
+ self,
+ partition_by=partition_by,
+ order_by=order_by,
+ range_=range_,
+ rows=rows,
+ )
+
+ @util.memoized_property
+ def type(self):
+ wgt = self.element.within_group_type(self)
+ if wgt is not None:
+ return wgt
+ else:
+ return self.element.type
+
+ @property
+ def _from_objects(self):
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.element, self.order_by)
+ if c is not None
+ ]
+ )
+ )
+
+
+class FunctionFilter(ColumnElement):
+ """Represent a function FILTER clause.
+
+    This is a special operator applied to an aggregate or window function
+    which controls which rows are passed to that function.
+    It's supported only by certain database backends.
+
+ Invocation of :class:`.FunctionFilter` is via
+ :meth:`.FunctionElement.filter`::
+
+ func.count(1).filter(True)
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`.FunctionElement.filter`
+
+ """
+
+ __visit_name__ = "funcfilter"
+
+ _traverse_internals = [
+ ("func", InternalTraversal.dp_clauseelement),
+ ("criterion", InternalTraversal.dp_clauseelement),
+ ]
+
+ criterion = None
+
+ def __init__(self, func, *criterion):
+ """Produce a :class:`.FunctionFilter` object against a function.
+
+ Used against aggregate and window functions,
+ for database backends that support the "FILTER" clause.
+
+ E.g.::
+
+ from sqlalchemy import funcfilter
+ funcfilter(func.count(1), MyClass.name == 'some name')
+
+ Would produce "COUNT(1) FILTER (WHERE myclass.name = 'some name')".
+
+ This function is also available from the :data:`~.expression.func`
+ construct itself via the :meth:`.FunctionElement.filter` method.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :ref:`tutorial_functions_within_group` - in the
+ :ref:`unified_tutorial`
+
+ :meth:`.FunctionElement.filter`
+
+ """
+ self.func = func
+ self.filter(*criterion)
+
+ def filter(self, *criterion):
+ """Produce an additional FILTER against the function.
+
+ This method adds additional criteria to the initial criteria
+ set up by :meth:`.FunctionElement.filter`.
+
+ Multiple criteria are joined together at SQL render time
+ via ``AND``.
+
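+        For example, the following sketch (``mytable`` stands in for any
+        :class:`_schema.Table`) would render SQL along the lines of
+        ``count(1) FILTER (WHERE x > 5 AND y < 10)``::
+
+            from sqlalchemy import func
+
+            fn = func.count(1).filter(mytable.c.x > 5).filter(mytable.c.y < 10)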
+
+ """
+
+ for criterion in list(criterion):
+ criterion = coercions.expect(roles.WhereHavingRole, criterion)
+
+ if self.criterion is not None:
+ self.criterion = self.criterion & criterion
+ else:
+ self.criterion = criterion
+
+ return self
+
+ def over(self, partition_by=None, order_by=None, range_=None, rows=None):
+ """Produce an OVER clause against this filtered function.
+
+ Used against aggregate or so-called "window" functions,
+ for database backends that support window functions.
+
+ The expression::
+
+ func.rank().filter(MyClass.y > 5).over(order_by='x')
+
+ is shorthand for::
+
+ from sqlalchemy import over, funcfilter
+ over(funcfilter(func.rank(), MyClass.y > 5), order_by='x')
+
+ See :func:`_expression.over` for a full description.
+
+ """
+ return Over(
+ self,
+ partition_by=partition_by,
+ order_by=order_by,
+ range_=range_,
+ rows=rows,
+ )
+
+ def self_group(self, against=None):
+ if operators.is_precedent(operators.filter_op, against):
+ return Grouping(self)
+ else:
+ return self
+
+ @util.memoized_property
+ def type(self):
+ return self.func.type
+
+ @property
+ def _from_objects(self):
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.func, self.criterion)
+ if c is not None
+ ]
+ )
+ )
+
+
+class Label(roles.LabeledColumnExprRole, ColumnElement):
+ """Represents a column label (AS).
+
+ Represent a label, as typically applied to any column-level
+ element using the ``AS`` sql keyword.
+
+ """
+
+ __visit_name__ = "label"
+
+ _traverse_internals = [
+ ("name", InternalTraversal.dp_anon_name),
+ ("_type", InternalTraversal.dp_type),
+ ("_element", InternalTraversal.dp_clauseelement),
+ ]
+
+ def __init__(self, name, element, type_=None):
+ """Return a :class:`Label` object for the
+ given :class:`_expression.ColumnElement`.
+
+ A label changes the name of an element in the columns clause of a
+ ``SELECT`` statement, typically via the ``AS`` SQL keyword.
+
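+        E.g., an illustrative sketch using a hypothetical ``user`` construct::
+
+            from sqlalchemy import column, label, select, table
+
+            user = table("user", column("name"))
+            stmt = select(label("user_name", user.c.name))
+
+        The above renders as ``SELECT user.name AS user_name FROM user``.
+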
+ This functionality is more conveniently available via the
+ :meth:`_expression.ColumnElement.label` method on
+ :class:`_expression.ColumnElement`.
+
+ :param name: label name
+
+ :param obj: a :class:`_expression.ColumnElement`.
+
+ """
+
+ orig_element = element
+ element = coercions.expect(
+ roles.ExpressionElementRole,
+ element,
+ apply_propagate_attrs=self,
+ )
+ while isinstance(element, Label):
+ # TODO: this is only covered in test_text.py, but nothing
+ # fails if it's removed. determine rationale
+ element = element.element
+
+ if name:
+ self.name = name
+ else:
+ self.name = _anonymous_label.safe_construct(
+ id(self), getattr(element, "name", "anon")
+ )
+ if isinstance(orig_element, Label):
+ # TODO: no coverage for this block, again would be in
+ # test_text.py where the resolve_label concept is important
+ self._resolve_label = orig_element._label
+
+ self.key = self._tq_label = self._tq_key_label = self.name
+ self._element = element
+ self._type = type_
+ self._proxies = [element]
+
+ def __reduce__(self):
+ return self.__class__, (self.name, self._element, self._type)
+
+ @util.memoized_property
+ def _is_implicitly_boolean(self):
+ return self.element._is_implicitly_boolean
+
+ @HasMemoized.memoized_attribute
+ def _allow_label_resolve(self):
+ return self.element._allow_label_resolve
+
+ @property
+ def _order_by_label_element(self):
+ return self
+
+ @util.memoized_property
+ def type(self):
+ return type_api.to_instance(
+ self._type or getattr(self._element, "type", None)
+ )
+
+ @HasMemoized.memoized_attribute
+ def element(self):
+ return self._element.self_group(against=operators.as_)
+
+ def self_group(self, against=None):
+ return self._apply_to_inner(self._element.self_group, against=against)
+
+ def _negate(self):
+ return self._apply_to_inner(self._element._negate)
+
+ def _apply_to_inner(self, fn, *arg, **kw):
+ sub_element = fn(*arg, **kw)
+ if sub_element is not self._element:
+ return Label(self.name, sub_element, type_=self._type)
+ else:
+ return self
+
+ @property
+ def primary_key(self):
+ return self.element.primary_key
+
+ @property
+ def foreign_keys(self):
+ return self.element.foreign_keys
+
+ def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
+ self._reset_memoizations()
+ self._element = clone(self._element, **kw)
+ if anonymize_labels:
+ self.name = _anonymous_label.safe_construct(
+ id(self), getattr(self.element, "name", "anon")
+ )
+ self.key = self._tq_label = self._tq_key_label = self.name
+
+ @property
+ def _from_objects(self):
+ return self.element._from_objects
+
+ def _make_proxy(self, selectable, name=None, **kw):
+ name = self.name if not name else name
+
+ key, e = self.element._make_proxy(
+ selectable,
+ name=name,
+ disallow_is_literal=True,
+ name_is_truncatable=isinstance(name, _truncated_label),
+ )
+
+ # there was a note here to remove this assertion, which was here
+ # to determine if we later could support a use case where
+ # the key and name of a label are separate. But I don't know what
+ # that case was. For now, this is an unexpected case that occurs
+ # when a label name conflicts with other columns and select()
+ # is attempting to disambiguate an explicit label, which is not what
+ # the user would want. See issue #6090.
+ if key != self.name:
+ raise exc.InvalidRequestError(
+ "Label name %s is being renamed to an anonymous label due "
+ "to disambiguation "
+ "which is not supported right now. Please use unique names "
+ "for explicit labels." % (self.name)
+ )
+
+ e._propagate_attrs = selectable._propagate_attrs
+ e._proxies.append(self)
+ if self._type is not None:
+ e.type = self._type
+
+ return self.key, e
+
+
+class NamedColumn(ColumnElement):
+ is_literal = False
+ table = None
+
+ def _compare_name_for_result(self, other):
+ return (hasattr(other, "name") and self.name == other.name) or (
+ hasattr(other, "_label") and self._label == other._label
+ )
+
+ @util.memoized_property
+ def description(self):
+ if util.py3k:
+ return self.name
+ else:
+ return self.name.encode("ascii", "backslashreplace")
+
+ @HasMemoized.memoized_attribute
+ def _tq_key_label(self):
+ """table qualified label based on column key.
+
+ for table-bound columns this is <tablename>_<column key/proxy key>;
+
+        for all other expressions it resolves to the key / proxy key.
+
+ """
+ proxy_key = self._proxy_key
+ if proxy_key and proxy_key != self.name:
+ return self._gen_tq_label(proxy_key)
+ else:
+ return self._tq_label
+
+ @HasMemoized.memoized_attribute
+ def _tq_label(self):
+ """table qualified label based on column name.
+
+        for table-bound columns this is <tablename>_<columnname>; for all
+        other expressions it resolves to ``.name``.
+
+ """
+ return self._gen_tq_label(self.name)
+
+ @HasMemoized.memoized_attribute
+ def _render_label_in_columns_clause(self):
+ return True
+
+ @HasMemoized.memoized_attribute
+ def _non_anon_label(self):
+ return self.name
+
+ def _gen_tq_label(self, name, dedupe_on_key=True):
+ return name
+
+ def _bind_param(self, operator, obj, type_=None, expanding=False):
+ return BindParameter(
+ self.key,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ type_=type_,
+ unique=True,
+ expanding=expanding,
+ )
+
+ def _make_proxy(
+ self,
+ selectable,
+ name=None,
+ name_is_truncatable=False,
+ disallow_is_literal=False,
+ **kw
+ ):
+ c = ColumnClause(
+ coercions.expect(roles.TruncatedLabelRole, name or self.name)
+ if name_is_truncatable
+ else (name or self.name),
+ type_=self.type,
+ _selectable=selectable,
+ is_literal=False,
+ )
+ c._propagate_attrs = selectable._propagate_attrs
+ if name is None:
+ c.key = self.key
+ c._proxies = [self]
+ if selectable._is_clone_of is not None:
+ c._is_clone_of = selectable._is_clone_of.columns.get(c.key)
+ return c.key, c
+
+
+class ColumnClause(
+ roles.DDLReferredColumnRole,
+ roles.LabeledColumnExprRole,
+ roles.StrAsPlainColumnRole,
+ Immutable,
+ NamedColumn,
+):
+ """Represents a column expression from any textual string.
+
+ The :class:`.ColumnClause`, a lightweight analogue to the
+ :class:`_schema.Column` class, is typically invoked using the
+ :func:`_expression.column` function, as in::
+
+ from sqlalchemy import column
+
+ id, name = column("id"), column("name")
+ stmt = select(id, name).select_from("user")
+
+ The above statement would produce SQL like::
+
+ SELECT id, name FROM user
+
+ :class:`.ColumnClause` is the immediate superclass of the schema-specific
+ :class:`_schema.Column` object. While the :class:`_schema.Column`
+ class has all the
+ same capabilities as :class:`.ColumnClause`, the :class:`.ColumnClause`
+ class is usable by itself in those cases where behavioral requirements
+ are limited to simple SQL expression generation. The object has none of
+ the associations with schema-level metadata or with execution-time
+ behavior that :class:`_schema.Column` does,
+ so in that sense is a "lightweight"
+ version of :class:`_schema.Column`.
+
+ Full details on :class:`.ColumnClause` usage is at
+ :func:`_expression.column`.
+
+ .. seealso::
+
+ :func:`_expression.column`
+
+ :class:`_schema.Column`
+
+ """
+
+ table = None
+ is_literal = False
+
+ __visit_name__ = "column"
+
+ _traverse_internals = [
+ ("name", InternalTraversal.dp_anon_name),
+ ("type", InternalTraversal.dp_type),
+ ("table", InternalTraversal.dp_clauseelement),
+ ("is_literal", InternalTraversal.dp_boolean),
+ ]
+
+ onupdate = default = server_default = server_onupdate = None
+
+ _is_multiparam_column = False
+
+ @property
+ def _is_star(self):
+ return self.is_literal and self.name == "*"
+
+ def __init__(self, text, type_=None, is_literal=False, _selectable=None):
+ """Produce a :class:`.ColumnClause` object.
+
+ The :class:`.ColumnClause` is a lightweight analogue to the
+ :class:`_schema.Column` class. The :func:`_expression.column`
+ function can
+ be invoked with just a name alone, as in::
+
+ from sqlalchemy import column
+
+ id, name = column("id"), column("name")
+ stmt = select(id, name).select_from("user")
+
+ The above statement would produce SQL like::
+
+ SELECT id, name FROM user
+
+ Once constructed, :func:`_expression.column`
+ may be used like any other SQL
+ expression element such as within :func:`_expression.select`
+ constructs::
+
+ from sqlalchemy.sql import column
+
+ id, name = column("id"), column("name")
+ stmt = select(id, name).select_from("user")
+
+ The text handled by :func:`_expression.column`
+ is assumed to be handled
+ like the name of a database column; if the string contains mixed case,
+ special characters, or matches a known reserved word on the target
+ backend, the column expression will render using the quoting
+ behavior determined by the backend. To produce a textual SQL
+ expression that is rendered exactly without any quoting,
+ use :func:`_expression.literal_column` instead,
+ or pass ``True`` as the
+ value of :paramref:`_expression.column.is_literal`. Additionally,
+ full SQL
+ statements are best handled using the :func:`_expression.text`
+ construct.
+
+ :func:`_expression.column` can be used in a table-like
+ fashion by combining it with the :func:`.table` function
+ (which is the lightweight analogue to :class:`_schema.Table`
+ ) to produce
+ a working table construct with minimal boilerplate::
+
+ from sqlalchemy import table, column, select
+
+ user = table("user",
+ column("id"),
+ column("name"),
+ column("description"),
+ )
+
+ stmt = select(user.c.description).where(user.c.name == 'wendy')
+
+ A :func:`_expression.column` / :func:`.table`
+ construct like that illustrated
+ above can be created in an
+ ad-hoc fashion and is not associated with any
+ :class:`_schema.MetaData`, DDL, or events, unlike its
+ :class:`_schema.Table` counterpart.
+
+ .. versionchanged:: 1.0.0 :func:`_expression.column` can now
+ be imported from the plain ``sqlalchemy`` namespace like any
+ other SQL element.
+
+ :param text: the text of the element.
+
+ :param type: :class:`_types.TypeEngine` object which can associate
+ this :class:`.ColumnClause` with a type.
+
+ :param is_literal: if True, the :class:`.ColumnClause` is assumed to
+ be an exact expression that will be delivered to the output with no
+ quoting rules applied regardless of case sensitive settings. the
+ :func:`_expression.literal_column()` function essentially invokes
+ :func:`_expression.column` while passing ``is_literal=True``.
+
+ .. seealso::
+
+ :class:`_schema.Column`
+
+ :func:`_expression.literal_column`
+
+ :func:`.table`
+
+ :func:`_expression.text`
+
+ :ref:`tutorial_select_arbitrary_text`
+
+ """
+ self.key = self.name = text
+ self.table = _selectable
+ self.type = type_api.to_instance(type_)
+ self.is_literal = is_literal
+
+ def get_children(self, column_tables=False, **kw):
+ # override base get_children() to not return the Table
+ # or selectable that is parent to this column. Traversals
+ # expect the columns of tables and subqueries to be leaf nodes.
+ return []
+
+ @property
+ def entity_namespace(self):
+ if self.table is not None:
+ return self.table.entity_namespace
+ else:
+ return super(ColumnClause, self).entity_namespace
+
+ def _clone(self, detect_subquery_cols=False, **kw):
+ if (
+ detect_subquery_cols
+ and self.table is not None
+ and self.table._is_subquery
+ ):
+ clone = kw.pop("clone")
+ table = clone(self.table, **kw)
+ new = table.c.corresponding_column(self)
+ return new
+
+ return super(ColumnClause, self)._clone(**kw)
+
+ @HasMemoized.memoized_attribute
+ def _from_objects(self):
+ t = self.table
+ if t is not None:
+ return [t]
+ else:
+ return []
+
+ @HasMemoized.memoized_attribute
+ def _render_label_in_columns_clause(self):
+ return self.table is not None
+
+ @property
+ def _ddl_label(self):
+ return self._gen_tq_label(self.name, dedupe_on_key=False)
+
+ def _compare_name_for_result(self, other):
+ if (
+ self.is_literal
+ or self.table is None
+ or self.table._is_textual
+ or not hasattr(other, "proxy_set")
+ or (
+ isinstance(other, ColumnClause)
+ and (
+ other.is_literal
+ or other.table is None
+ or other.table._is_textual
+ )
+ )
+ ):
+ return (hasattr(other, "name") and self.name == other.name) or (
+ hasattr(other, "_tq_label")
+ and self._tq_label == other._tq_label
+ )
+ else:
+ return other.proxy_set.intersection(self.proxy_set)
+
+ def _gen_tq_label(self, name, dedupe_on_key=True):
+ """generate table-qualified label
+
+ for a table-bound column this is <tablename>_<columnname>.
+
+ used primarily for LABEL_STYLE_TABLENAME_PLUS_COL
+ as well as the .columns collection on a Join object.
+
+ """
+ t = self.table
+ if self.is_literal:
+ return None
+ elif t is not None and t.named_with_column:
+ if getattr(t, "schema", None):
+ label = t.schema.replace(".", "_") + "_" + t.name + "_" + name
+ else:
+ label = t.name + "_" + name
+
+ # propagate name quoting rules for labels.
+ if getattr(name, "quote", None) is not None:
+ if isinstance(label, quoted_name):
+ label.quote = name.quote
+ else:
+ label = quoted_name(label, name.quote)
+ elif getattr(t.name, "quote", None) is not None:
+ # can't get this situation to occur, so let's
+ # assert false on it for now
+ assert not isinstance(label, quoted_name)
+ label = quoted_name(label, t.name.quote)
+
+ if dedupe_on_key:
+ # ensure the label name doesn't conflict with that of an
+ # existing column. note that this implies that any Column
+ # must **not** set up its _label before its parent table has
+ # all of its other Column objects set up. There are several
+ # tables in the test suite which will fail otherwise; example:
+ # table "owner" has columns "name" and "owner_name". Therefore
+ # column owner.name cannot use the label "owner_name", it has
+ # to be "owner_name_1".
+ if label in t.c:
+ _label = label
+ counter = 1
+ while _label in t.c:
+ _label = label + "_" + str(counter)
+ counter += 1
+ label = _label
+
+ return coercions.expect(roles.TruncatedLabelRole, label)
+
+ else:
+ return name
+
+ def _make_proxy(
+ self,
+ selectable,
+ name=None,
+ name_is_truncatable=False,
+ disallow_is_literal=False,
+ **kw
+ ):
+ # the "is_literal" flag normally should never be propagated; a proxied
+ # column is always a SQL identifier and never the actual expression
+ # being evaluated. however, there is a case where the "is_literal" flag
+ # might be used to allow the given identifier to have a fixed quoting
+ # pattern already, so maintain the flag for the proxy unless a
+ # :class:`.Label` object is creating the proxy. See [ticket:4730].
+ is_literal = (
+ not disallow_is_literal
+ and self.is_literal
+ and (
+ # note this does not accommodate for quoted_name differences
+ # right now
+ name is None
+ or name == self.name
+ )
+ )
+ c = self._constructor(
+ coercions.expect(roles.TruncatedLabelRole, name or self.name)
+ if name_is_truncatable
+ else (name or self.name),
+ type_=self.type,
+ _selectable=selectable,
+ is_literal=is_literal,
+ )
+ c._propagate_attrs = selectable._propagate_attrs
+ if name is None:
+ c.key = self.key
+ c._proxies = [self]
+ if selectable._is_clone_of is not None:
+ c._is_clone_of = selectable._is_clone_of.columns.get(c.key)
+ return c.key, c
+
+
+class TableValuedColumn(NamedColumn):
+ __visit_name__ = "table_valued_column"
+
+ _traverse_internals = [
+ ("name", InternalTraversal.dp_anon_name),
+ ("type", InternalTraversal.dp_type),
+ ("scalar_alias", InternalTraversal.dp_clauseelement),
+ ]
+
+ def __init__(self, scalar_alias, type_):
+ self.scalar_alias = scalar_alias
+ self.key = self.name = scalar_alias.name
+ self.type = type_
+
+ def _copy_internals(self, clone=_clone, **kw):
+ self.scalar_alias = clone(self.scalar_alias, **kw)
+ self.key = self.name = self.scalar_alias.name
+
+ @property
+ def _from_objects(self):
+ return [self.scalar_alias]
+
+
+class CollationClause(ColumnElement):
+ __visit_name__ = "collation"
+
+ _traverse_internals = [("collation", InternalTraversal.dp_string)]
+
+ def __init__(self, collation):
+ self.collation = collation
+
+
+class _IdentifiedClause(Executable, ClauseElement):
+
+ __visit_name__ = "identified"
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": False}
+ )
+
+ def __init__(self, ident):
+ self.ident = ident
+
+
+class SavepointClause(_IdentifiedClause):
+ __visit_name__ = "savepoint"
+ inherit_cache = False
+
+
+class RollbackToSavepointClause(_IdentifiedClause):
+ __visit_name__ = "rollback_to_savepoint"
+ inherit_cache = False
+
+
+class ReleaseSavepointClause(_IdentifiedClause):
+ __visit_name__ = "release_savepoint"
+ inherit_cache = False
+
+
+class quoted_name(util.MemoizedSlots, util.text_type):
+ """Represent a SQL identifier combined with quoting preferences.
+
+ :class:`.quoted_name` is a Python unicode/str subclass which
+ represents a particular identifier name along with a
+ ``quote`` flag. This ``quote`` flag, when set to
+ ``True`` or ``False``, overrides automatic quoting behavior
+ for this identifier in order to either unconditionally quote
+ or to not quote the name. If left at its default of ``None``,
+ quoting behavior is applied to the identifier on a per-backend basis
+ based on an examination of the token itself.
+
+ A :class:`.quoted_name` object with ``quote=True`` is also
+ prevented from being modified in the case of a so-called
+ "name normalize" option. Certain database backends, such as
+ Oracle, Firebird, and DB2 "normalize" case-insensitive names
+ as uppercase. The SQLAlchemy dialects for these backends
+ convert from SQLAlchemy's lower-case-means-insensitive convention
+ to the upper-case-means-insensitive conventions of those backends.
+ The ``quote=True`` flag here will prevent this conversion from occurring
+ to support an identifier that's quoted as all lower case against
+ such a backend.
+
+ The :class:`.quoted_name` object is normally created automatically
+ when specifying the name for key schema constructs such as
+ :class:`_schema.Table`, :class:`_schema.Column`, and others.
+ The class can also be
+    passed explicitly as the name to any function that receives a name which
+    can be quoted, such as using the :meth:`_engine.Engine.has_table`
+    method with an unconditionally quoted name::
+
+ from sqlalchemy import create_engine
+ from sqlalchemy import inspect
+ from sqlalchemy.sql import quoted_name
+
+ engine = create_engine("oracle+cx_oracle://some_dsn")
+ print(inspect(engine).has_table(quoted_name("some_table", True)))
+
+ The above logic will run the "has table" logic against the Oracle backend,
+ passing the name exactly as ``"some_table"`` without converting to
+ upper case.
+
+ .. versionadded:: 0.9.0
+
+ .. versionchanged:: 1.2 The :class:`.quoted_name` construct is now
+ importable from ``sqlalchemy.sql``, in addition to the previous
+ location of ``sqlalchemy.sql.elements``.
+
+ """
+
+ __slots__ = "quote", "lower", "upper"
+
+ def __new__(cls, value, quote):
+ if value is None:
+ return None
+ # experimental - don't bother with quoted_name
+ # if quote flag is None. doesn't seem to make any dent
+ # in performance however
+ # elif not sprcls and quote is None:
+ # return value
+ elif isinstance(value, cls) and (
+ quote is None or value.quote == quote
+ ):
+ return value
+ self = super(quoted_name, cls).__new__(cls, value)
+
+ self.quote = quote
+ return self
+
+ def __reduce__(self):
+ return quoted_name, (util.text_type(self), self.quote)
+
+ def _memoized_method_lower(self):
+ if self.quote:
+ return self
+ else:
+ return util.text_type(self).lower()
+
+ def _memoized_method_upper(self):
+ if self.quote:
+ return self
+ else:
+ return util.text_type(self).upper()
+
+ def __repr__(self):
+ if util.py2k:
+ backslashed = self.encode("ascii", "backslashreplace")
+ return "'%s'" % backslashed
+ else:
+ return str.__repr__(self)
+
+
+def _find_columns(clause):
+ """locate Column objects within the given expression."""
+
+ cols = util.column_set()
+ traverse(clause, {}, {"column": cols.add})
+ return cols
+
+
+def _type_from_args(args):
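+    # return the type of the first argument whose type is not NullType;
+    # the for/else falls through to NULLTYPE when no such argument exists.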
+ for a in args:
+ if not a.type._isnull:
+ return a.type
+ else:
+ return type_api.NULLTYPE
+
+
+def _corresponding_column_or_error(fromclause, column, require_embedded=False):
+ c = fromclause.corresponding_column(
+ column, require_embedded=require_embedded
+ )
+ if c is None:
+ raise exc.InvalidRequestError(
+ "Given column '%s', attached to table '%s', "
+ "failed to locate a corresponding column from table '%s'"
+ % (column, getattr(column, "table", None), fromclause.description)
+ )
+ return c
+
+
+class AnnotatedColumnElement(Annotated):
+ def __init__(self, element, values):
+ Annotated.__init__(self, element, values)
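+        # discard memoized attributes copied from the unannotated element so
+        # they are recomputed for this annotated copy; attributes present
+        # but None ("name", "key", "table") are dropped so the memoized
+        # properties below can pull them from the parent element.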
+ for attr in (
+ "comparator",
+ "_proxy_key",
+ "_tq_key_label",
+ "_tq_label",
+ "_non_anon_label",
+ ):
+ self.__dict__.pop(attr, None)
+ for attr in ("name", "key", "table"):
+ if self.__dict__.get(attr, False) is None:
+ self.__dict__.pop(attr)
+
+ def _with_annotations(self, values):
+ clone = super(AnnotatedColumnElement, self)._with_annotations(values)
+ clone.__dict__.pop("comparator", None)
+ return clone
+
+ @util.memoized_property
+ def name(self):
+ """pull 'name' from parent, if not present"""
+ return self._Annotated__element.name
+
+ @util.memoized_property
+ def table(self):
+ """pull 'table' from parent, if not present"""
+ return self._Annotated__element.table
+
+ @util.memoized_property
+ def key(self):
+ """pull 'key' from parent, if not present"""
+ return self._Annotated__element.key
+
+ @util.memoized_property
+ def info(self):
+ return self._Annotated__element.info
+
+ @util.memoized_property
+ def _anon_name_label(self):
+ return self._Annotated__element._anon_name_label
+
+
+class _truncated_label(quoted_name):
+ """A unicode subclass used to identify symbolic "
+ "names that may require truncation."""
+
+ __slots__ = ()
+
+ def __new__(cls, value, quote=None):
+ quote = getattr(value, "quote", quote)
+ # return super(_truncated_label, cls).__new__(cls, value, quote, True)
+ return super(_truncated_label, cls).__new__(cls, value, quote)
+
+ def __reduce__(self):
+ return self.__class__, (util.text_type(self), self.quote)
+
+ def apply_map(self, map_):
+ return self
+
+
+class conv(_truncated_label):
+ """Mark a string indicating that a name has already been converted
+ by a naming convention.
+
+ This is a string subclass that indicates a name that should not be
+ subject to any further naming conventions.
+
+ E.g. when we create a :class:`.Constraint` using a naming convention
+ as follows::
+
+ m = MetaData(naming_convention={
+ "ck": "ck_%(table_name)s_%(constraint_name)s"
+ })
+ t = Table('t', m, Column('x', Integer),
+ CheckConstraint('x > 5', name='x5'))
+
+ The name of the above constraint will be rendered as ``"ck_t_x5"``.
+ That is, the existing name ``x5`` is used in the naming convention as the
+ ``constraint_name`` token.
+
+ In some situations, such as in migration scripts, we may be rendering
+ the above :class:`.CheckConstraint` with a name that's already been
+ converted. In order to make sure the name isn't double-modified, the
+ new name is applied using the :func:`_schema.conv` marker. We can
+ use this explicitly as follows::
+
+
+ m = MetaData(naming_convention={
+ "ck": "ck_%(table_name)s_%(constraint_name)s"
+ })
+ t = Table('t', m, Column('x', Integer),
+ CheckConstraint('x > 5', name=conv('ck_t_x5')))
+
+ Where above, the :func:`_schema.conv` marker indicates that the constraint
+ name here is final, and the name will render as ``"ck_t_x5"`` and not
+ ``"ck_t_ck_t_x5"``
+
+ .. versionadded:: 0.9.4
+
+ .. seealso::
+
+ :ref:`constraint_naming_conventions`
+
+ """
+
+ __slots__ = ()
+
+
+_NONE_NAME = util.symbol("NONE_NAME")
+"""indicate a 'deferred' name that was ultimately the value None."""
+
+# for backwards compatibility in case
+# someone is re-implementing the
+# _truncated_identifier() sequence in a custom
+# compiler
+_generated_label = _truncated_label
+
+
+class _anonymous_label(_truncated_label):
+ """A unicode subclass used to identify anonymously
+ generated names."""
+
+ __slots__ = ()
+
+ @classmethod
+ def safe_construct(
+ cls, seed, body, enclosing_label=None, sanitize_key=False
+ ):
+
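+        # builds a deferred-interpolation token of the form
+        # "%(<seed> <body>)s", e.g. safe_construct(12, "name") gives
+        # _anonymous_label("%(12 name)s"); percent signs in the body are
+        # doubled so the later "label % map_" substitution in apply_map()
+        # stays safe.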
+ if sanitize_key:
+ body = re.sub(r"[%\(\) \$]+", "_", body).strip("_")
+
+ label = "%%(%d %s)s" % (seed, body.replace("%", "%%"))
+ if enclosing_label:
+ label = "%s%s" % (enclosing_label, label)
+
+ return _anonymous_label(label)
+
+ def __add__(self, other):
+ if "%" in other and not isinstance(other, _anonymous_label):
+ other = util.text_type(other).replace("%", "%%")
+ else:
+ other = util.text_type(other)
+
+ return _anonymous_label(
+ quoted_name(
+ util.text_type.__add__(self, other),
+ self.quote,
+ )
+ )
+
+ def __radd__(self, other):
+ if "%" in other and not isinstance(other, _anonymous_label):
+ other = util.text_type(other).replace("%", "%%")
+ else:
+ other = util.text_type(other)
+
+ return _anonymous_label(
+ quoted_name(
+ util.text_type.__add__(other, self),
+ self.quote,
+ )
+ )
+
+ def apply_map(self, map_):
+ if self.quote is not None:
+ # preserve quoting only if necessary
+ return quoted_name(self % map_, self.quote)
+ else:
+ # else skip the constructor call
+ return self % map_
diff --git a/lib/sqlalchemy/sql/events.py b/lib/sqlalchemy/sql/events.py
new file mode 100644
index 0000000..c425789
--- /dev/null
+++ b/lib/sqlalchemy/sql/events.py
@@ -0,0 +1,331 @@
+# sqlalchemy/sql/events.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .base import SchemaEventTarget
+from .. import event
+
+
+class DDLEvents(event.Events):
+ """
+ Define event listeners for schema objects,
+ that is, :class:`.SchemaItem` and other :class:`.SchemaEventTarget`
+ subclasses, including :class:`_schema.MetaData`, :class:`_schema.Table`,
+ :class:`_schema.Column`.
+
+ :class:`_schema.MetaData` and :class:`_schema.Table` support events
+ specifically regarding when CREATE and DROP
+ DDL is emitted to the database.
+
+ Attachment events are also provided to customize
+ behavior whenever a child schema element is associated
+ with a parent, such as, when a :class:`_schema.Column` is associated
+ with its :class:`_schema.Table`, when a
+ :class:`_schema.ForeignKeyConstraint`
+ is associated with a :class:`_schema.Table`, etc.
+
+ Example using the ``after_create`` event::
+
+ from sqlalchemy import event
+        from sqlalchemy import Table, Column, MetaData, Integer, text
+
+ m = MetaData()
+ some_table = Table('some_table', m, Column('data', Integer))
+
+ def after_create(target, connection, **kw):
+ connection.execute(text(
+ "ALTER TABLE %s SET name=foo_%s" % (target.name, target.name)
+ ))
+
+ event.listen(some_table, "after_create", after_create)
+
+ DDL events integrate closely with the
+ :class:`.DDL` class and the :class:`.DDLElement` hierarchy
+ of DDL clause constructs, which are themselves appropriate
+ as listener callables::
+
+ from sqlalchemy import DDL
+ event.listen(
+ some_table,
+ "after_create",
+ DDL("ALTER TABLE %(table)s SET name=foo_%(table)s")
+ )
+
+ The methods here define the name of an event as well
+ as the names of members that are passed to listener
+ functions.
+
+    For all :class:`.DDLEvents` events, the ``propagate=True`` keyword argument
+ will ensure that a given event handler is propagated to copies of the
+ object, which are made when using the :meth:`_schema.Table.to_metadata`
+ method::
+
+ from sqlalchemy import DDL
+ event.listen(
+ some_table,
+ "after_create",
+ DDL("ALTER TABLE %(table)s SET name=foo_%(table)s"),
+ propagate=True
+ )
+
+ new_table = some_table.to_metadata(new_metadata)
+
+ The above :class:`.DDL` object will also be associated with the
+ :class:`_schema.Table` object represented by ``new_table``.
+
+ .. seealso::
+
+ :ref:`event_toplevel`
+
+ :class:`.DDLElement`
+
+ :class:`.DDL`
+
+ :ref:`schema_ddl_sequences`
+
+ """
+
+ _target_class_doc = "SomeSchemaClassOrObject"
+ _dispatch_target = SchemaEventTarget
+
+ def before_create(self, target, connection, **kw):
+ r"""Called before CREATE statements are emitted.
+
+ :param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
+ object which is the target of the event.
+ :param connection: the :class:`_engine.Connection` where the
+ CREATE statement or statements will be emitted.
+ :param \**kw: additional keyword arguments relevant
+ to the event. The contents of this dictionary
+ may vary across releases, and include the
+ list of tables being generated for a metadata-level
+ event, the checkfirst flag, and other
+ elements used by internal events.
+
+ :func:`.event.listen` accepts the ``propagate=True``
+ modifier for this event; when True, the listener function will
+ be established for any copies made of the target object,
+ i.e. those copies that are generated when
+ :meth:`_schema.Table.to_metadata` is used.
+
+ :func:`.event.listen` accepts the ``insert=True``
+ modifier for this event; when True, the listener function will
+ be prepended to the internal list of events upon discovery, and execute
+ before registered listener functions that do not pass this argument.
+
+ """
+
+ def after_create(self, target, connection, **kw):
+ r"""Called after CREATE statements are emitted.
+
+ :param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
+ object which is the target of the event.
+ :param connection: the :class:`_engine.Connection` where the
+ CREATE statement or statements have been emitted.
+ :param \**kw: additional keyword arguments relevant
+ to the event. The contents of this dictionary
+ may vary across releases, and include the
+ list of tables being generated for a metadata-level
+ event, the checkfirst flag, and other
+ elements used by internal events.
+
+ :func:`.event.listen` also accepts the ``propagate=True``
+ modifier for this event; when True, the listener function will
+ be established for any copies made of the target object,
+ i.e. those copies that are generated when
+ :meth:`_schema.Table.to_metadata` is used.
+
+ """
+
+ def before_drop(self, target, connection, **kw):
+ r"""Called before DROP statements are emitted.
+
+ :param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
+ object which is the target of the event.
+ :param connection: the :class:`_engine.Connection` where the
+ DROP statement or statements will be emitted.
+ :param \**kw: additional keyword arguments relevant
+ to the event. The contents of this dictionary
+ may vary across releases, and include the
+ list of tables being generated for a metadata-level
+ event, the checkfirst flag, and other
+ elements used by internal events.
+
+ :func:`.event.listen` also accepts the ``propagate=True``
+ modifier for this event; when True, the listener function will
+ be established for any copies made of the target object,
+ i.e. those copies that are generated when
+ :meth:`_schema.Table.to_metadata` is used.
+
+ """
+
+ def after_drop(self, target, connection, **kw):
+ r"""Called after DROP statements are emitted.
+
+ :param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
+ object which is the target of the event.
+ :param connection: the :class:`_engine.Connection` where the
+ DROP statement or statements have been emitted.
+ :param \**kw: additional keyword arguments relevant
+ to the event. The contents of this dictionary
+ may vary across releases, and include the
+ list of tables being generated for a metadata-level
+ event, the checkfirst flag, and other
+ elements used by internal events.
+
+ :func:`.event.listen` also accepts the ``propagate=True``
+ modifier for this event; when True, the listener function will
+ be established for any copies made of the target object,
+ i.e. those copies that are generated when
+ :meth:`_schema.Table.to_metadata` is used.
+
+ """
+
+ def before_parent_attach(self, target, parent):
+ """Called before a :class:`.SchemaItem` is associated with
+ a parent :class:`.SchemaItem`.
+
+ :param target: the target object
+ :param parent: the parent to which the target is being attached.
+
+ :func:`.event.listen` also accepts the ``propagate=True``
+ modifier for this event; when True, the listener function will
+ be established for any copies made of the target object,
+ i.e. those copies that are generated when
+ :meth:`_schema.Table.to_metadata` is used.
+
+ """
+
+ def after_parent_attach(self, target, parent):
+ """Called after a :class:`.SchemaItem` is associated with
+ a parent :class:`.SchemaItem`.
+
+ :param target: the target object
+ :param parent: the parent to which the target is being attached.
+
+ :func:`.event.listen` also accepts the ``propagate=True``
+ modifier for this event; when True, the listener function will
+ be established for any copies made of the target object,
+ i.e. those copies that are generated when
+ :meth:`_schema.Table.to_metadata` is used.
+
+ """
+
+ def _sa_event_column_added_to_pk_constraint(self, const, col):
+ """internal event hook used for primary key naming convention
+ updates.
+
+ """
+
+ def column_reflect(self, inspector, table, column_info):
+ """Called for each unit of 'column info' retrieved when
+ a :class:`_schema.Table` is being reflected.
+
+ This event is most easily used by applying it to a specific
+ :class:`_schema.MetaData` instance, where it will take effect for
+ all :class:`_schema.Table` objects within that
+ :class:`_schema.MetaData` that undergo reflection::
+
+ metadata = MetaData()
+
+ @event.listens_for(metadata, 'column_reflect')
+ def receive_column_reflect(inspector, table, column_info):
+ # receives for all Table objects that are reflected
+ # under this MetaData
+
+
+ # will use the above event hook
+ my_table = Table("my_table", metadata, autoload_with=some_engine)
+
+
+ .. versionadded:: 1.4.0b2 The :meth:`_events.DDLEvents.column_reflect`
+ hook may now be applied to a :class:`_schema.MetaData` object as
+ well as the :class:`_schema.MetaData` class itself where it will
+ take place for all :class:`_schema.Table` objects associated with
+ the targeted :class:`_schema.MetaData`.
+
+ It may also be applied to the :class:`_schema.Table` class across
+ the board::
+
+ from sqlalchemy import Table
+
+ @event.listens_for(Table, 'column_reflect')
+ def receive_column_reflect(inspector, table, column_info):
+                # receives for all Table objects that are reflected
+                ...
+
+ It can also be applied to a specific :class:`_schema.Table` at the
+ point that one is being reflected using the
+ :paramref:`_schema.Table.listeners` parameter::
+
+ t1 = Table(
+ "my_table",
+ autoload_with=some_engine,
+ listeners=[
+ ('column_reflect', receive_column_reflect)
+ ]
+ )
+
+        As noted above, it may also be associated with a specific
+        :class:`_schema.MetaData` object.
+
+ The dictionary of column information as returned by the
+ dialect is passed, and can be modified. The dictionary
+ is that returned in each element of the list returned
+ by :meth:`.reflection.Inspector.get_columns`:
+
+ * ``name`` - the column's name, is applied to the
+ :paramref:`_schema.Column.name` parameter
+
+ * ``type`` - the type of this column, which should be an instance
+ of :class:`~sqlalchemy.types.TypeEngine`, is applied to the
+ :paramref:`_schema.Column.type` parameter
+
+ * ``nullable`` - boolean flag if the column is NULL or NOT NULL,
+ is applied to the :paramref:`_schema.Column.nullable` parameter
+
+ * ``default`` - the column's server default value. This is
+ normally specified as a plain string SQL expression, however the
+ event can pass a :class:`.FetchedValue`, :class:`.DefaultClause`,
+ or :func:`_expression.text` object as well. Is applied to the
+ :paramref:`_schema.Column.server_default` parameter
+
+ The event is called before any action is taken against
+ this dictionary, and the contents can be modified; the following
+ additional keys may be added to the dictionary to further modify
+ how the :class:`_schema.Column` is constructed:
+
+
+ * ``key`` - the string key that will be used to access this
+ :class:`_schema.Column` in the ``.c`` collection; will be applied
+ to the :paramref:`_schema.Column.key` parameter. Is also used
+ for ORM mapping. See the section
+          :ref:`mapper_automated_reflection_schemes` for an example, as
+          well as the short sketch below.
+
+ * ``quote`` - force or un-force quoting on the column name;
+ is applied to the :paramref:`_schema.Column.quote` parameter.
+
+ * ``info`` - a dictionary of arbitrary data to follow along with
+ the :class:`_schema.Column`, is applied to the
+ :paramref:`_schema.Column.info` parameter.
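+
+        For example, a short sketch that assigns an ORM-friendly ``key`` to
+        each reflected column (the ``attr_`` prefix is hypothetical)::
+
+            @event.listens_for(Table, "column_reflect")
+            def receive_column_reflect(inspector, table, column_info):
+                column_info["key"] = "attr_%s" % column_info["name"].lower()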
+
+ :func:`.event.listen` also accepts the ``propagate=True``
+ modifier for this event; when True, the listener function will
+ be established for any copies made of the target object,
+ i.e. those copies that are generated when
+ :meth:`_schema.Table.to_metadata` is used.
+
+ .. seealso::
+
+ :ref:`mapper_automated_reflection_schemes` -
+ in the ORM mapping documentation
+
+ :ref:`automap_intercepting_columns` -
+ in the :ref:`automap_toplevel` documentation
+
+ :ref:`metadata_reflection_dbagnostic_types` - in
+ the :ref:`metadata_reflection_toplevel` documentation
+
+ """
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
new file mode 100644
index 0000000..b4aa14e
--- /dev/null
+++ b/lib/sqlalchemy/sql/expression.py
@@ -0,0 +1,278 @@
+# sql/expression.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Defines the public namespace for SQL expression constructs.
+
+Prior to version 0.9, this module contained all of "elements", "dml",
+"default_comparator" and "selectable". The module was broken up
+and most "factory" functions were moved to be grouped with their associated
+class.
+
+"""
+
+__all__ = [
+ "Alias",
+ "AliasedReturnsRows",
+ "any_",
+ "all_",
+ "CacheKey",
+ "ClauseElement",
+ "ColumnCollection",
+ "ColumnElement",
+ "CompoundSelect",
+ "Delete",
+ "FromClause",
+ "Insert",
+ "Join",
+ "Lateral",
+ "LambdaElement",
+ "StatementLambdaElement",
+ "Select",
+ "Selectable",
+ "TableClause",
+ "TableValuedAlias",
+ "Update",
+ "Values",
+ "alias",
+ "and_",
+ "asc",
+ "between",
+ "bindparam",
+ "case",
+ "cast",
+ "column",
+ "custom_op",
+ "cte",
+ "delete",
+ "desc",
+ "distinct",
+ "except_",
+ "except_all",
+ "exists",
+ "extract",
+ "func",
+ "modifier",
+ "collate",
+ "insert",
+ "intersect",
+ "intersect_all",
+ "join",
+ "label",
+ "lateral",
+ "lambda_stmt",
+ "literal",
+ "literal_column",
+ "not_",
+ "null",
+ "nulls_first",
+ "nulls_last",
+ "or_",
+ "outparam",
+ "outerjoin",
+ "over",
+ "select",
+ "table",
+ "text",
+ "tuple_",
+ "type_coerce",
+ "quoted_name",
+ "union",
+ "union_all",
+ "update",
+ "within_group",
+ "Subquery",
+ "TableSample",
+ "tablesample",
+ "values",
+]
+
+
+from .base import _from_objects
+from .base import _select_iterables
+from .base import ColumnCollection
+from .base import Executable
+from .base import PARSE_AUTOCOMMIT
+from .dml import Delete
+from .dml import Insert
+from .dml import Update
+from .dml import UpdateBase
+from .dml import ValuesBase
+from .elements import _truncated_label
+from .elements import between
+from .elements import BinaryExpression
+from .elements import BindParameter
+from .elements import BooleanClauseList
+from .elements import Case
+from .elements import Cast
+from .elements import ClauseElement
+from .elements import ClauseList
+from .elements import collate
+from .elements import CollectionAggregate
+from .elements import ColumnClause
+from .elements import ColumnElement
+from .elements import Extract
+from .elements import False_
+from .elements import FunctionFilter
+from .elements import Grouping
+from .elements import Label
+from .elements import literal
+from .elements import literal_column
+from .elements import not_
+from .elements import Null
+from .elements import outparam
+from .elements import Over
+from .elements import quoted_name
+from .elements import ReleaseSavepointClause
+from .elements import RollbackToSavepointClause
+from .elements import SavepointClause
+from .elements import TextClause
+from .elements import True_
+from .elements import Tuple
+from .elements import TypeClause
+from .elements import TypeCoerce
+from .elements import UnaryExpression
+from .elements import WithinGroup
+from .functions import func
+from .functions import Function
+from .functions import FunctionElement
+from .functions import modifier
+from .lambdas import lambda_stmt
+from .lambdas import LambdaElement
+from .lambdas import StatementLambdaElement
+from .operators import ColumnOperators
+from .operators import custom_op
+from .operators import Operators
+from .selectable import Alias
+from .selectable import AliasedReturnsRows
+from .selectable import CompoundSelect
+from .selectable import CTE
+from .selectable import Exists
+from .selectable import FromClause
+from .selectable import FromGrouping
+from .selectable import GenerativeSelect
+from .selectable import HasCTE
+from .selectable import HasPrefixes
+from .selectable import HasSuffixes
+from .selectable import Join
+from .selectable import LABEL_STYLE_DEFAULT
+from .selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
+from .selectable import LABEL_STYLE_NONE
+from .selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from .selectable import Lateral
+from .selectable import ReturnsRows
+from .selectable import ScalarSelect
+from .selectable import Select
+from .selectable import Selectable
+from .selectable import SelectBase
+from .selectable import Subquery
+from .selectable import subquery
+from .selectable import TableClause
+from .selectable import TableSample
+from .selectable import TableValuedAlias
+from .selectable import TextAsFrom
+from .selectable import TextualSelect
+from .selectable import Values
+from .traversals import CacheKey
+from .visitors import Visitable
+from ..util.langhelpers import public_factory
+
+# factory functions - these pull class-bound constructors and classmethods
+# from SQL elements and selectables into public functions. This allows
+# the functions to be available in the sqlalchemy.sql.* namespace and
+# to be auto-cross-documenting from the function to the class itself.
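+#
+# An illustrative sketch of the resulting public API (the usage below is
+# hypothetical; select, column and bindparam are factories defined here):
+#
+#     from sqlalchemy.sql import expression
+#
+#     stmt = expression.select(expression.column("x")).where(
+#         expression.column("x") == expression.bindparam("val")
+#     )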
+
+all_ = public_factory(CollectionAggregate._create_all, ".sql.expression.all_")
+any_ = public_factory(CollectionAggregate._create_any, ".sql.expression.any_")
+and_ = public_factory(BooleanClauseList.and_, ".sql.expression.and_")
+alias = public_factory(Alias._factory, ".sql.expression.alias")
+tablesample = public_factory(
+ TableSample._factory, ".sql.expression.tablesample"
+)
+lateral = public_factory(Lateral._factory, ".sql.expression.lateral")
+or_ = public_factory(BooleanClauseList.or_, ".sql.expression.or_")
+bindparam = public_factory(BindParameter, ".sql.expression.bindparam")
+select = public_factory(Select._create, ".sql.expression.select")
+text = public_factory(TextClause._create_text, ".sql.expression.text")
+table = public_factory(TableClause, ".sql.expression.table")
+column = public_factory(ColumnClause, ".sql.expression.column")
+over = public_factory(Over, ".sql.expression.over")
+within_group = public_factory(WithinGroup, ".sql.expression.within_group")
+label = public_factory(Label, ".sql.expression.label")
+case = public_factory(Case, ".sql.expression.case")
+cast = public_factory(Cast, ".sql.expression.cast")
+cte = public_factory(CTE._factory, ".sql.expression.cte")
+values = public_factory(Values, ".sql.expression.values")
+extract = public_factory(Extract, ".sql.expression.extract")
+tuple_ = public_factory(Tuple, ".sql.expression.tuple_")
+except_ = public_factory(
+ CompoundSelect._create_except, ".sql.expression.except_"
+)
+except_all = public_factory(
+ CompoundSelect._create_except_all, ".sql.expression.except_all"
+)
+intersect = public_factory(
+ CompoundSelect._create_intersect, ".sql.expression.intersect"
+)
+intersect_all = public_factory(
+ CompoundSelect._create_intersect_all, ".sql.expression.intersect_all"
+)
+union = public_factory(CompoundSelect._create_union, ".sql.expression.union")
+union_all = public_factory(
+ CompoundSelect._create_union_all, ".sql.expression.union_all"
+)
+exists = public_factory(Exists, ".sql.expression.exists")
+nulls_first = public_factory(
+ UnaryExpression._create_nulls_first, ".sql.expression.nulls_first"
+)
+nullsfirst = nulls_first # deprecated 1.4; see #5435
+nulls_last = public_factory(
+ UnaryExpression._create_nulls_last, ".sql.expression.nulls_last"
+)
+nullslast = nulls_last # deprecated 1.4; see #5435
+asc = public_factory(UnaryExpression._create_asc, ".sql.expression.asc")
+desc = public_factory(UnaryExpression._create_desc, ".sql.expression.desc")
+distinct = public_factory(
+ UnaryExpression._create_distinct, ".sql.expression.distinct"
+)
+type_coerce = public_factory(TypeCoerce, ".sql.expression.type_coerce")
+true = public_factory(True_._instance, ".sql.expression.true")
+false = public_factory(False_._instance, ".sql.expression.false")
+null = public_factory(Null._instance, ".sql.expression.null")
+join = public_factory(Join._create_join, ".sql.expression.join")
+outerjoin = public_factory(Join._create_outerjoin, ".sql.expression.outerjoin")
+insert = public_factory(Insert, ".sql.expression.insert")
+update = public_factory(Update, ".sql.expression.update")
+delete = public_factory(Delete, ".sql.expression.delete")
+funcfilter = public_factory(FunctionFilter, ".sql.expression.funcfilter")
+
+
+# internal functions still being called from tests and the ORM,
+# these might be better off in some other namespace
+
+
+# old names for compatibility
+_Executable = Executable
+_BindParamClause = BindParameter
+_Label = Label
+_SelectBase = SelectBase
+_BinaryExpression = BinaryExpression
+_Cast = Cast
+_Null = Null
+_False = False_
+_True = True_
+_TextClause = TextClause
+_UnaryExpression = UnaryExpression
+_Case = Case
+_Tuple = Tuple
+_Over = Over
+_TypeClause = TypeClause
+_Extract = Extract
+_Exists = Exists
+_Grouping = Grouping
+_FromGrouping = FromGrouping
+_ScalarSelect = ScalarSelect
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
new file mode 100644
index 0000000..29f4122
--- /dev/null
+++ b/lib/sqlalchemy/sql/functions.py
@@ -0,0 +1,1575 @@
+# sql/functions.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""SQL function API, factories, and built-in functions.
+
+"""
+from . import annotation
+from . import coercions
+from . import operators
+from . import roles
+from . import schema
+from . import sqltypes
+from . import util as sqlutil
+from .base import _entity_namespace
+from .base import ColumnCollection
+from .base import Executable
+from .base import Generative
+from .base import HasMemoized
+from .elements import _type_from_args
+from .elements import BinaryExpression
+from .elements import BindParameter
+from .elements import Cast
+from .elements import ClauseList
+from .elements import ColumnElement
+from .elements import Extract
+from .elements import FunctionFilter
+from .elements import Grouping
+from .elements import literal_column
+from .elements import NamedColumn
+from .elements import Over
+from .elements import WithinGroup
+from .selectable import FromClause
+from .selectable import Select
+from .selectable import TableValuedAlias
+from .visitors import InternalTraversal
+from .visitors import TraversibleType
+from .. import util
+
+
+_registry = util.defaultdict(dict)
+
+
+def register_function(identifier, fn, package="_default"):
+ """Associate a callable with a particular func. name.
+
+ This is normally called by _GenericMeta, but is also
+ available by itself so that a non-Function construct
+ can be associated with the :data:`.func` accessor (i.e.
+ CAST, EXTRACT).
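+
+    E.g., the built-in ``cast`` and ``extract`` constructs are associated
+    this way (as done at the bottom of this module)::
+
+        register_function("cast", Cast)
+        register_function("extract", Extract)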
+
+ """
+ reg = _registry[package]
+
+ identifier = util.text_type(identifier).lower()
+
+ # Check if a function with the same identifier is registered.
+ if identifier in reg:
+ util.warn(
+ "The GenericFunction '{}' is already registered and "
+ "is going to be overridden.".format(identifier)
+ )
+ reg[identifier] = fn
+
+
+class FunctionElement(Executable, ColumnElement, FromClause, Generative):
+ """Base for SQL function-oriented constructs.
+
+ .. seealso::
+
+ :ref:`tutorial_functions` - in the :ref:`unified_tutorial`
+
+ :class:`.Function` - named SQL function.
+
+ :data:`.func` - namespace which produces registered or ad-hoc
+ :class:`.Function` instances.
+
+ :class:`.GenericFunction` - allows creation of registered function
+ types.
+
+ """
+
+ _traverse_internals = [
+ ("clause_expr", InternalTraversal.dp_clauseelement),
+ ("_with_ordinality", InternalTraversal.dp_boolean),
+ ("_table_value_type", InternalTraversal.dp_has_cache_key),
+ ]
+
+ packagenames = ()
+
+ _has_args = False
+ _with_ordinality = False
+ _table_value_type = None
+
+ def __init__(self, *clauses, **kwargs):
+ r"""Construct a :class:`.FunctionElement`.
+
+ :param \*clauses: list of column expressions that form the arguments
+ of the SQL function call.
+
+ :param \**kwargs: additional kwargs are typically consumed by
+ subclasses.
+
+ .. seealso::
+
+ :data:`.func`
+
+ :class:`.Function`
+
+ """
+ args = [
+ coercions.expect(
+ roles.ExpressionElementRole,
+ c,
+ name=getattr(self, "name", None),
+ apply_propagate_attrs=self,
+ )
+ for c in clauses
+ ]
+ self._has_args = self._has_args or bool(args)
+ self.clause_expr = ClauseList(
+ operator=operators.comma_op, group_contents=True, *args
+ ).self_group()
+
+ _non_anon_label = None
+
+ @property
+ def _proxy_key(self):
+ return super(FunctionElement, self)._proxy_key or getattr(
+ self, "name", None
+ )
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ return connection._execute_function(
+ self, multiparams, params, execution_options
+ )
+
+ def scalar_table_valued(self, name, type_=None):
+ """Return a column expression that's against this
+ :class:`_functions.FunctionElement` as a scalar
+ table-valued expression.
+
+ The returned expression is similar to that returned by a single column
+ accessed off of a :meth:`_functions.FunctionElement.table_valued`
+ construct, except no FROM clause is generated; the function is rendered
+        in a similar way to a scalar subquery.
+
+ E.g.::
+
+ >>> from sqlalchemy import func, select
+ >>> fn = func.jsonb_each("{'k', 'v'}").scalar_table_valued("key")
+ >>> print(select(fn))
+ SELECT (jsonb_each(:jsonb_each_1)).key
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :meth:`_functions.FunctionElement.table_valued`
+
+ :meth:`_functions.FunctionElement.alias`
+
+ :meth:`_functions.FunctionElement.column_valued`
+
+ """ # noqa: E501
+
+ return ScalarFunctionColumn(self, name, type_)
+
+ def table_valued(self, *expr, **kw):
+ r"""Return a :class:`_sql.TableValuedAlias` representation of this
+ :class:`_functions.FunctionElement` with table-valued expressions added.
+
+ e.g.::
+
+ >>> fn = (
+ ... func.generate_series(1, 5).
+ ... table_valued("value", "start", "stop", "step")
+ ... )
+
+ >>> print(select(fn))
+ SELECT anon_1.value, anon_1.start, anon_1.stop, anon_1.step
+ FROM generate_series(:generate_series_1, :generate_series_2) AS anon_1
+
+ >>> print(select(fn.c.value, fn.c.stop).where(fn.c.value > 2))
+ SELECT anon_1.value, anon_1.stop
+ FROM generate_series(:generate_series_1, :generate_series_2) AS anon_1
+ WHERE anon_1.value > :value_1
+
+ A WITH ORDINALITY expression may be generated by passing the keyword
+ argument "with_ordinality"::
+
+ >>> fn = func.generate_series(4, 1, -1).table_valued("gen", with_ordinality="ordinality")
+ >>> print(select(fn))
+ SELECT anon_1.gen, anon_1.ordinality
+ FROM generate_series(:generate_series_1, :generate_series_2, :generate_series_3) WITH ORDINALITY AS anon_1
+
+ :param \*expr: A series of string column names that will be added to the
+ ``.c`` collection of the resulting :class:`_sql.TableValuedAlias`
+ construct as columns. :func:`_sql.column` objects with or without
+ datatypes may also be used.
+
+ :param name: optional name to assign to the alias name that's generated.
+ If omitted, a unique anonymizing name is used.
+
+ :param with_ordinality: string name that when present results in the
+ ``WITH ORDINALITY`` clause being added to the alias, and the given
+ string name will be added as a column to the .c collection
+ of the resulting :class:`_sql.TableValuedAlias`.
+
+ :param joins_implicitly: when True, the table valued function may be
+ used in the FROM clause without any explicit JOIN to other tables
+ in the SQL query, and no "cartesian product" warning will be generated.
+ May be useful for SQL functions such as ``func.json_each()``.
+
+ .. versionadded:: 1.4.33
+
+ .. versionadded:: 1.4.0b2
+
+
+ .. seealso::
+
+ :ref:`tutorial_functions_table_valued` - in the :ref:`unified_tutorial`
+
+ :ref:`postgresql_table_valued` - in the :ref:`postgresql_toplevel` documentation
+
+ :meth:`_functions.FunctionElement.scalar_table_valued` - variant of
+ :meth:`_functions.FunctionElement.table_valued` which delivers the
+ complete table valued expression as a scalar column expression
+
+ :meth:`_functions.FunctionElement.column_valued`
+
+ :meth:`_sql.TableValuedAlias.render_derived` - renders the alias
+ using a derived column clause, e.g. ``AS name(col1, col2, ...)``
+
+        """ # noqa: E501
+
+ new_func = self._generate()
+
+ with_ordinality = kw.pop("with_ordinality", None)
+ joins_implicitly = kw.pop("joins_implicitly", None)
+ name = kw.pop("name", None)
+
+ if with_ordinality:
+ expr += (with_ordinality,)
+ new_func._with_ordinality = True
+
+ new_func.type = new_func._table_value_type = sqltypes.TableValueType(
+ *expr
+ )
+
+ return new_func.alias(name=name, joins_implicitly=joins_implicitly)
+
+ def column_valued(self, name=None):
+ """Return this :class:`_functions.FunctionElement` as a column expression that
+ selects from itself as a FROM clause.
+
+ E.g.::
+
+ >>> from sqlalchemy import select, func
+ >>> gs = func.generate_series(1, 5, -1).column_valued()
+ >>> print(select(gs))
+ SELECT anon_1
+ FROM generate_series(:generate_series_1, :generate_series_2, :generate_series_3) AS anon_1
+
+ This is shorthand for::
+
+ gs = func.generate_series(1, 5, -1).alias().column
+
+
+ .. seealso::
+
+ :ref:`tutorial_functions_column_valued` - in the :ref:`unified_tutorial`
+
+ :ref:`postgresql_column_valued` - in the :ref:`postgresql_toplevel` documentation
+
+ :meth:`_functions.FunctionElement.table_valued`
+
+        """ # noqa: E501
+
+ return self.alias(name=name).column
+
+ @property
+ def columns(self):
+ r"""The set of columns exported by this :class:`.FunctionElement`.
+
+ This is a placeholder collection that allows the function to be
+ placed in the FROM clause of a statement::
+
+ >>> from sqlalchemy import column, select, func
+ >>> stmt = select(column('x'), column('y')).select_from(func.myfunction())
+ >>> print(stmt)
+ SELECT x, y FROM myfunction()
+
+ The above form is a legacy feature that is now superseded by the
+ fully capable :meth:`_functions.FunctionElement.table_valued`
+ method; see that method for details.
+
+ .. seealso::
+
+ :meth:`_functions.FunctionElement.table_valued` - generates table-valued
+ SQL function expressions.
+
+ """ # noqa: E501
+
+ return ColumnCollection(
+ columns=[(col.key, col) for col in self._all_selected_columns]
+ )
+
+ @property
+ def _all_selected_columns(self):
+ if self.type._is_table_value:
+ cols = self.type._elements
+ else:
+ cols = [self.label(None)]
+
+ return cols
+
+ @property
+ def exported_columns(self):
+ return self.columns
+
+ @HasMemoized.memoized_attribute
+ def clauses(self):
+ """Return the underlying :class:`.ClauseList` which contains
+ the arguments for this :class:`.FunctionElement`.
+
+ """
+ return self.clause_expr.element
+
+ def over(self, partition_by=None, order_by=None, rows=None, range_=None):
+ """Produce an OVER clause against this function.
+
+ Used against aggregate or so-called "window" functions,
+ for database backends that support window functions.
+
+ The expression::
+
+ func.row_number().over(order_by='x')
+
+ is shorthand for::
+
+ from sqlalchemy import over
+ over(func.row_number(), order_by='x')
+
+ See :func:`_expression.over` for a full description.
+
+ .. seealso::
+
+ :func:`_expression.over`
+
+ :ref:`tutorial_window_functions` - in the :ref:`unified_tutorial`
+
+ """
+ return Over(
+ self,
+ partition_by=partition_by,
+ order_by=order_by,
+ rows=rows,
+ range_=range_,
+ )
+
+ def within_group(self, *order_by):
+ """Produce a WITHIN GROUP (ORDER BY expr) clause against this function.
+
+ Used against so-called "ordered set aggregate" and "hypothetical
+ set aggregate" functions, including :class:`.percentile_cont`,
+ :class:`.rank`, :class:`.dense_rank`, etc.
+
+ See :func:`_expression.within_group` for a full description.
+
+ .. versionadded:: 1.1
+
+
+ .. seealso::
+
+ :ref:`tutorial_functions_within_group` -
+ in the :ref:`unified_tutorial`
+
+
+ """
+ return WithinGroup(self, *order_by)
+
+ def filter(self, *criterion):
+ """Produce a FILTER clause against this function.
+
+ Used against aggregate and window functions,
+ for database backends that support the "FILTER" clause.
+
+ The expression::
+
+ func.count(1).filter(True)
+
+ is shorthand for::
+
+ from sqlalchemy import funcfilter
+ funcfilter(func.count(1), True)
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :ref:`tutorial_functions_within_group` -
+ in the :ref:`unified_tutorial`
+
+ :class:`.FunctionFilter`
+
+ :func:`.funcfilter`
+
+
+ """
+ if not criterion:
+ return self
+ return FunctionFilter(self, *criterion)
+
+ def as_comparison(self, left_index, right_index):
+ """Interpret this expression as a boolean comparison between two
+ values.
+
+ This method is used for an ORM use case described at
+ :ref:`relationship_custom_operator_sql_function`.
+
+        A hypothetical SQL function "is_equal()" which compares two values
+ for equality would be written in the Core expression language as::
+
+ expr = func.is_equal("a", "b")
+
+ If "is_equal()" above is comparing "a" and "b" for equality, the
+ :meth:`.FunctionElement.as_comparison` method would be invoked as::
+
+ expr = func.is_equal("a", "b").as_comparison(1, 2)
+
+ Where above, the integer value "1" refers to the first argument of the
+ "is_equal()" function and the integer value "2" refers to the second.
+
+ This would create a :class:`.BinaryExpression` that is equivalent to::
+
+ BinaryExpression("a", "b", operator=op.eq)
+
+ However, at the SQL level it would still render as
+ "is_equal('a', 'b')".
+
+ The ORM, when it loads a related object or collection, needs to be able
+ to manipulate the "left" and "right" sides of the ON clause of a JOIN
+ expression. The purpose of this method is to provide a SQL function
+ construct that can also supply this information to the ORM, when used
+ with the :paramref:`_orm.relationship.primaryjoin` parameter. The
+ return value is a containment object called :class:`.FunctionAsBinary`.
+
+ An ORM example is as follows::
+
+ class Venue(Base):
+ __tablename__ = 'venue'
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ descendants = relationship(
+ "Venue",
+ primaryjoin=func.instr(
+ remote(foreign(name)), name + "/"
+ ).as_comparison(1, 2) == 1,
+ viewonly=True,
+ order_by=name
+ )
+
+ Above, the "Venue" class can load descendant "Venue" objects by
+ determining if the name of the parent Venue is contained within the
+ start of the hypothetical descendant value's name, e.g. "parent1" would
+ match up to "parent1/child1", but not to "parent2/child1".
+
+ Possible use cases include the "materialized path" example given above,
+ as well as making use of special SQL functions such as geometric
+ functions to create join conditions.
+
+ :param left_index: the integer 1-based index of the function argument
+ that serves as the "left" side of the expression.
+ :param right_index: the integer 1-based index of the function argument
+ that serves as the "right" side of the expression.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :ref:`relationship_custom_operator_sql_function` -
+ example use within the ORM
+
+ """
+ return FunctionAsBinary(self, left_index, right_index)
+
+ @property
+ def _from_objects(self):
+ return self.clauses._from_objects
+
+ def within_group_type(self, within_group):
+ """For types that define their return type as based on the criteria
+ within a WITHIN GROUP (ORDER BY) expression, called by the
+ :class:`.WithinGroup` construct.
+
+ Returns None by default, in which case the function's normal ``.type``
+ is used.
+
+ """
+
+ return None
+
+ def alias(self, name=None, joins_implicitly=False):
+ r"""Produce a :class:`_expression.Alias` construct against this
+ :class:`.FunctionElement`.
+
+ .. tip::
+
+ The :meth:`_functions.FunctionElement.alias` method is part of the
+ mechanism by which "table valued" SQL functions are created.
+ However, most use cases are covered by higher level methods on
+ :class:`_functions.FunctionElement` including
+ :meth:`_functions.FunctionElement.table_valued`, and
+ :meth:`_functions.FunctionElement.column_valued`.
+
+ This construct wraps the function in a named alias which
+ is suitable for the FROM clause, in the style accepted for example
+ by PostgreSQL. A column expression is also provided using the
+ special ``.column`` attribute, which may
+ be used to refer to the output of the function as a scalar value
+ in the columns or where clause, for a backend such as PostgreSQL.
+
+ For a full table-valued expression, use the
+        :meth:`_functions.FunctionElement.table_valued` method first to
+ establish named columns.
+
+ e.g.::
+
+ >>> from sqlalchemy import func, select, column
+ >>> data_view = func.unnest([1, 2, 3]).alias("data_view")
+ >>> print(select(data_view.column))
+ SELECT data_view
+ FROM unnest(:unnest_1) AS data_view
+
+ The :meth:`_functions.FunctionElement.column_valued` method provides
+ a shortcut for the above pattern::
+
+ >>> data_view = func.unnest([1, 2, 3]).column_valued("data_view")
+ >>> print(select(data_view))
+ SELECT data_view
+ FROM unnest(:unnest_1) AS data_view
+
+ .. versionadded:: 1.4.0b2 Added the ``.column`` accessor
+
+ :param name: alias name, will be rendered as ``AS <name>`` in the
+ FROM clause
+
+ :param joins_implicitly: when True, the table valued function may be
+ used in the FROM clause without any explicit JOIN to other tables
+ in the SQL query, and no "cartesian product" warning will be
+ generated. May be useful for SQL functions such as
+ ``func.json_each()``.
+
+ .. versionadded:: 1.4.33
+
+ .. seealso::
+
+ :ref:`tutorial_functions_table_valued` -
+ in the :ref:`unified_tutorial`
+
+ :meth:`_functions.FunctionElement.table_valued`
+
+ :meth:`_functions.FunctionElement.scalar_table_valued`
+
+ :meth:`_functions.FunctionElement.column_valued`
+
+
+ """
+
+ return TableValuedAlias._construct(
+ self,
+ name,
+ table_value_type=self.type,
+ joins_implicitly=joins_implicitly,
+ )
+
+ def select(self):
+ """Produce a :func:`_expression.select` construct
+ against this :class:`.FunctionElement`.
+
+ This is shorthand for::
+
+ s = select(function_element)
+
+ """
+ s = Select._create_select(self)
+ if self._execution_options:
+ s = s.execution_options(**self._execution_options)
+ return s
+
+ @util.deprecated_20(
+ ":meth:`.FunctionElement.scalar`",
+ alternative="Scalar execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.scalar` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.scalar` method of "
+ ":class:`.Session`.",
+ )
+ def scalar(self):
+ """Execute this :class:`.FunctionElement` against an embedded
+ 'bind' and return a scalar value.
+
+ This first calls :meth:`~.FunctionElement.select` to
+ produce a SELECT construct.
+
+ Note that :class:`.FunctionElement` can be passed to
+ the :meth:`.Connectable.scalar` method of :class:`_engine.Connection`
+ or :class:`_engine.Engine`.
+
+ """
+ return self.select().execute().scalar()
+
+ @util.deprecated_20(
+ ":meth:`.FunctionElement.execute`",
+ alternative="All statement execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.execute` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.execute` method of "
+ ":class:`.Session`.",
+ )
+ def execute(self):
+ """Execute this :class:`.FunctionElement` against an embedded
+ 'bind'.
+
+ This first calls :meth:`~.FunctionElement.select` to
+ produce a SELECT construct.
+
+ Note that :class:`.FunctionElement` can be passed to
+ the :meth:`.Connectable.execute` method of :class:`_engine.Connection`
+ or :class:`_engine.Engine`.
+
+ """
+ return self.select().execute()
+
+ def _bind_param(self, operator, obj, type_=None, **kw):
+ return BindParameter(
+ None,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ unique=True,
+ type_=type_,
+ **kw
+ )
+
+ def self_group(self, against=None):
+ # for the moment, we are parenthesizing all array-returning
+ # expressions against getitem. This may need to be made
+ # more portable if in the future we support other DBs
+ # besides postgresql.
+ if against is operators.getitem and isinstance(
+ self.type, sqltypes.ARRAY
+ ):
+ return Grouping(self)
+ else:
+ return super(FunctionElement, self).self_group(against=against)
+
+ @property
+ def entity_namespace(self):
+ """overrides FromClause.entity_namespace as functions are generally
+ column expressions and not FromClauses.
+
+ """
+ # ideally functions would not be fromclauses but we failed to make
+ # this adjustment in 1.4
+ return _entity_namespace(self.clause_expr)
+
+
+class FunctionAsBinary(BinaryExpression):
+ _traverse_internals = [
+ ("sql_function", InternalTraversal.dp_clauseelement),
+ ("left_index", InternalTraversal.dp_plain_obj),
+ ("right_index", InternalTraversal.dp_plain_obj),
+ ("modifiers", InternalTraversal.dp_plain_dict),
+ ]
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ return ColumnElement._gen_cache_key(self, anon_map, bindparams)
+
+ def __init__(self, fn, left_index, right_index):
+ self.sql_function = fn
+ self.left_index = left_index
+ self.right_index = right_index
+
+ self.operator = operators.function_as_comparison_op
+ self.type = sqltypes.BOOLEANTYPE
+ self.negate = None
+ self._is_implicitly_boolean = True
+ self.modifiers = {}
+
+ @property
+ def left(self):
+ return self.sql_function.clauses.clauses[self.left_index - 1]
+
+ @left.setter
+ def left(self, value):
+ self.sql_function.clauses.clauses[self.left_index - 1] = value
+
+ @property
+ def right(self):
+ return self.sql_function.clauses.clauses[self.right_index - 1]
+
+ @right.setter
+ def right(self, value):
+ self.sql_function.clauses.clauses[self.right_index - 1] = value
+
+
+class ScalarFunctionColumn(NamedColumn):
+ __visit_name__ = "scalar_function_column"
+
+ _traverse_internals = [
+ ("name", InternalTraversal.dp_anon_name),
+ ("type", InternalTraversal.dp_type),
+ ("fn", InternalTraversal.dp_clauseelement),
+ ]
+
+ is_literal = False
+ table = None
+
+ def __init__(self, fn, name, type_=None):
+ self.fn = fn
+ self.name = name
+ self.type = sqltypes.to_instance(type_)
+
+
+class _FunctionGenerator(object):
+ """Generate SQL function expressions.
+
+ :data:`.func` is a special object instance which generates SQL
+ functions based on name-based attributes, e.g.::
+
+ >>> print(func.count(1))
+ count(:param_1)
+
+ The returned object is an instance of :class:`.Function`, and is a
+ column-oriented SQL element like any other, and is used in that way::
+
+ >>> print(select(func.count(table.c.id)))
+ SELECT count(sometable.id) FROM sometable
+
+ Any name can be given to :data:`.func`. If the function name is unknown to
+ SQLAlchemy, it will be rendered exactly as is. For common SQL functions
+ which SQLAlchemy is aware of, the name may be interpreted as a *generic
+ function* which will be compiled appropriately to the target database::
+
+ >>> print(func.current_timestamp())
+ CURRENT_TIMESTAMP
+
+ To call functions which are present in dot-separated packages,
+ specify them in the same manner::
+
+ >>> print(func.stats.yield_curve(5, 10))
+ stats.yield_curve(:yield_curve_1, :yield_curve_2)
+
+ SQLAlchemy can be made aware of the return type of functions to enable
+ type-specific lexical and result-based behavior. For example, to ensure
+ that a string-based function returns a Unicode value and is similarly
+ treated as a string in expressions, specify
+ :class:`~sqlalchemy.types.Unicode` as the type:
+
+ >>> print(func.my_string(u'hi', type_=Unicode) + ' ' +
+ ... func.my_string(u'there', type_=Unicode))
+ my_string(:my_string_1) || :my_string_2 || my_string(:my_string_3)
+
+ The object returned by a :data:`.func` call is usually an instance of
+ :class:`.Function`.
+ This object meets the "column" interface, including comparison and labeling
+    functions. The object can also be passed to the
+    :meth:`~.Connectable.execute` method of a :class:`_engine.Connection`
+    or :class:`_engine.Engine`, where it will be
+ wrapped inside of a SELECT statement first::
+
+ print(connection.execute(func.current_timestamp()).scalar())
+
+    In a few exceptional cases, the :data:`.func` accessor
+ will redirect a name to a built-in expression such as :func:`.cast`
+ or :func:`.extract`, as these names have well-known meaning
+ but are not exactly the same as "functions" from a SQLAlchemy
+ perspective.
+
+ Functions which are interpreted as "generic" functions know how to
+ calculate their return type automatically. For a listing of known generic
+ functions, see :ref:`generic_functions`.
+
+ .. note::
+
+ The :data:`.func` construct has only limited support for calling
+ standalone "stored procedures", especially those with special
+ parameterization concerns.
+
+ See the section :ref:`stored_procedures` for details on how to use
+ the DBAPI-level ``callproc()`` method for fully traditional stored
+ procedures.
+
+ .. seealso::
+
+ :ref:`tutorial_functions` - in the :ref:`unified_tutorial`
+
+ :class:`.Function`
+
+ """
+
+ def __init__(self, **opts):
+ self.__names = []
+ self.opts = opts
+
+ def __getattr__(self, name):
+ # passthru __ attributes; fixes pydoc
+ if name.startswith("__"):
+ try:
+ return self.__dict__[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ elif name.endswith("_"):
+ name = name[0:-1]
+ f = _FunctionGenerator(**self.opts)
+ f.__names = list(self.__names) + [name]
+ return f
+
+ def __call__(self, *c, **kwargs):
+ o = self.opts.copy()
+ o.update(kwargs)
+
+ tokens = len(self.__names)
+
+ if tokens == 2:
+ package, fname = self.__names
+ elif tokens == 1:
+ package, fname = "_default", self.__names[0]
+ else:
+ package = None
+
+ if package is not None:
+ func = _registry[package].get(fname.lower())
+ if func is not None:
+ return func(*c, **o)
+
+ return Function(
+ self.__names[-1], packagenames=tuple(self.__names[0:-1]), *c, **o
+ )
+
+
+func = _FunctionGenerator()
+func.__doc__ = _FunctionGenerator.__doc__
+
+modifier = _FunctionGenerator(group=False)
+
+
+class Function(FunctionElement):
+ r"""Describe a named SQL function.
+
+ The :class:`.Function` object is typically generated from the
+ :data:`.func` generation object.
+
+
+ :param \*clauses: list of column expressions that form the arguments
+ of the SQL function call.
+
+ :param type\_: optional :class:`.TypeEngine` datatype object that will be
+ used as the return value of the column expression generated by this
+ function call.
+
+ :param packagenames: a string which indicates package prefix names
+ to be prepended to the function name when the SQL is generated.
+ The :data:`.func` generator creates these when it is called using
+ dotted format, e.g.::
+
+ func.mypackage.some_function(col1, col2)
+
+ .. seealso::
+
+ :ref:`tutorial_functions` - in the :ref:`unified_tutorial`
+
+ :data:`.func` - namespace which produces registered or ad-hoc
+ :class:`.Function` instances.
+
+ :class:`.GenericFunction` - allows creation of registered function
+ types.
+
+ """
+
+ __visit_name__ = "function"
+
+ _traverse_internals = FunctionElement._traverse_internals + [
+ ("packagenames", InternalTraversal.dp_plain_obj),
+ ("name", InternalTraversal.dp_string),
+ ("type", InternalTraversal.dp_type),
+ ]
+
+ type = sqltypes.NULLTYPE
+ """A :class:`_types.TypeEngine` object which refers to the SQL return
+ type represented by this SQL function.
+
+ This datatype may be configured when generating a
+ :class:`_functions.Function` object by passing the
+ :paramref:`_functions.Function.type_` parameter, e.g.::
+
+ >>> select(func.lower("some VALUE", type_=String))
+
+ The small number of built-in classes of :class:`_functions.Function` come
+ with a built-in datatype that's appropriate to the class of function and
+ its arguments. For functions that aren't known, the type defaults to the
+ "null type".
+
+ """
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+        "The :paramref:`_functions.Function.bind` argument is deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(self, name, *clauses, **kw):
+ """Construct a :class:`.Function`.
+
+ The :data:`.func` construct is normally used to construct
+ new :class:`.Function` instances.
+
+ """
+ self.packagenames = kw.pop("packagenames", None) or ()
+ self.name = name
+
+ self._bind = self._get_bind(kw)
+ self.type = sqltypes.to_instance(kw.get("type_", None))
+
+ FunctionElement.__init__(self, *clauses, **kw)
+
+ def _get_bind(self, kw):
+ if "bind" in kw:
+ util.warn_deprecated_20(
+ "The Function.bind argument is deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ )
+ return kw["bind"]
+
+ def _bind_param(self, operator, obj, type_=None, **kw):
+ return BindParameter(
+ self.name,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ type_=type_,
+ unique=True,
+ **kw
+ )
+
+
+class _GenericMeta(TraversibleType):
+ def __init__(cls, clsname, bases, clsdict):
+ if annotation.Annotated not in cls.__mro__:
+ cls.name = name = clsdict.get("name", clsname)
+ cls.identifier = identifier = clsdict.get("identifier", name)
+ package = clsdict.pop("package", "_default")
+ # legacy
+ if "__return_type__" in clsdict:
+ cls.type = clsdict["__return_type__"]
+
+ # Check _register attribute status
+ cls._register = getattr(cls, "_register", True)
+
+ # Register the function if required
+ if cls._register:
+ register_function(identifier, cls, package)
+ else:
+ # Set _register to True to register child classes by default
+ cls._register = True
+
+ super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
+
+
+class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
+ """Define a 'generic' function.
+
+ A generic function is a pre-established :class:`.Function`
+ class that is instantiated automatically when called
+ by name from the :data:`.func` attribute. Note that
+ calling any name from :data:`.func` has the effect that
+ a new :class:`.Function` instance is created automatically,
+ given that name. The primary use case for defining
+ a :class:`.GenericFunction` class is so that a function
+ of a particular name may be given a fixed return type.
+ It can also include custom argument parsing schemes as well
+ as additional methods.
+
+ Subclasses of :class:`.GenericFunction` are automatically
+ registered under the name of the class. For
+ example, a user-defined function ``as_utc()`` would
+ be available immediately::
+
+        from sqlalchemy import select, func
+        from sqlalchemy.sql.functions import GenericFunction
+        from sqlalchemy.types import DateTime
+
+ class as_utc(GenericFunction):
+ type = DateTime
+ inherit_cache = True
+
+ print(select(func.as_utc()))
+
+ User-defined generic functions can be organized into
+ packages by specifying the "package" attribute when defining
+ :class:`.GenericFunction`. Third party libraries
+ containing many functions may want to use this in order
+ to avoid name conflicts with other systems. For example,
+ if our ``as_utc()`` function were part of a package
+ "time"::
+
+ class as_utc(GenericFunction):
+ type = DateTime
+ package = "time"
+ inherit_cache = True
+
+ The above function would be available from :data:`.func`
+ using the package name ``time``::
+
+ print(select(func.time.as_utc()))
+
+ A final option is to allow the function to be accessed
+ from one name in :data:`.func` but to render as a different name.
+ The ``identifier`` attribute will override the name used to
+ access the function as loaded from :data:`.func`, but will retain
+ the usage of ``name`` as the rendered name::
+
+ class GeoBuffer(GenericFunction):
+ type = Geometry
+ package = "geo"
+ name = "ST_Buffer"
+ identifier = "buffer"
+ inherit_cache = True
+
+ The above function will render as follows::
+
+ >>> print(func.geo.buffer())
+ ST_Buffer()
+
+    The name will be rendered as is; it will not be quoted unless the name
+    contains special characters that require quoting.  To force quoting
+ on or off for the name, use the :class:`.sqlalchemy.sql.quoted_name`
+ construct::
+
+ from sqlalchemy.sql import quoted_name
+
+ class GeoBuffer(GenericFunction):
+ type = Geometry
+ package = "geo"
+ name = quoted_name("ST_Buffer", True)
+ identifier = "buffer"
+ inherit_cache = True
+
+ The above function will render as::
+
+ >>> print(func.geo.buffer())
+ "ST_Buffer"()
+
+ .. versionadded:: 1.3.13 The :class:`.quoted_name` construct is now
+ recognized for quoting when used with the "name" attribute of the
+ object, so that quoting can be forced on or off for the function
+ name.
+
+
+ """
+
+ coerce_arguments = True
+ _register = False
+ inherit_cache = True
+
+ def __init__(self, *args, **kwargs):
+ parsed_args = kwargs.pop("_parsed_args", None)
+ if parsed_args is None:
+ parsed_args = [
+ coercions.expect(
+ roles.ExpressionElementRole,
+ c,
+ name=self.name,
+ apply_propagate_attrs=self,
+ )
+ for c in args
+ ]
+ self._has_args = self._has_args or bool(parsed_args)
+ self.packagenames = ()
+ self._bind = self._get_bind(kwargs)
+ self.clause_expr = ClauseList(
+ operator=operators.comma_op, group_contents=True, *parsed_args
+ ).self_group()
+ self.type = sqltypes.to_instance(
+ kwargs.pop("type_", None) or getattr(self, "type", None)
+ )
+
+
+register_function("cast", Cast)
+register_function("extract", Extract)
+
+
+class next_value(GenericFunction):
+ """Represent the 'next value', given a :class:`.Sequence`
+ as its single argument.
+
+ Compiles into the appropriate function on each backend,
+ or will raise NotImplementedError if used on a backend
+ that does not provide support for sequences.
+
+ """
+
+ type = sqltypes.Integer()
+ name = "next_value"
+
+ _traverse_internals = [
+ ("sequence", InternalTraversal.dp_named_ddl_element)
+ ]
+
+ def __init__(self, seq, **kw):
+ assert isinstance(
+ seq, schema.Sequence
+ ), "next_value() accepts a Sequence object as input."
+ self._bind = self._get_bind(kw)
+ self.sequence = seq
+ self.type = sqltypes.to_instance(
+ seq.data_type or getattr(self, "type", None)
+ )
+
+ def compare(self, other, **kw):
+ return (
+ isinstance(other, next_value)
+ and self.sequence.name == other.sequence.name
+ )
+
+ @property
+ def _from_objects(self):
+ return []
+
+
+class AnsiFunction(GenericFunction):
+    """Define a function in "ansi" format, which doesn't render parentheses."""
+
+ inherit_cache = True
+
+ def __init__(self, *args, **kwargs):
+ GenericFunction.__init__(self, *args, **kwargs)
+
+
+class ReturnTypeFromArgs(GenericFunction):
+ """Define a function whose return type is the same as its arguments."""
+
+ inherit_cache = True
+
+ def __init__(self, *args, **kwargs):
+ args = [
+ coercions.expect(
+ roles.ExpressionElementRole,
+ c,
+ name=self.name,
+ apply_propagate_attrs=self,
+ )
+ for c in args
+ ]
+ kwargs.setdefault("type_", _type_from_args(args))
+ kwargs["_parsed_args"] = args
+ super(ReturnTypeFromArgs, self).__init__(*args, **kwargs)
+
+
+class coalesce(ReturnTypeFromArgs):
+ _has_args = True
+ inherit_cache = True
+
+
+class max(ReturnTypeFromArgs): # noqa: A001
+ """The SQL MAX() aggregate function."""
+
+ inherit_cache = True
+
+
+class min(ReturnTypeFromArgs): # noqa: A001
+ """The SQL MIN() aggregate function."""
+
+ inherit_cache = True
+
+
+class sum(ReturnTypeFromArgs): # noqa: A001
+ """The SQL SUM() aggregate function."""
+
+ inherit_cache = True
+
+
+class now(GenericFunction):
+ """The SQL now() datetime function.
+
+ SQLAlchemy dialects will usually render this particular function
+ in a backend-specific way, such as rendering it as ``CURRENT_TIMESTAMP``.
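+
+    E.g., an illustrative sketch (output shown for the default compiler)::
+
+        >>> from sqlalchemy import select, func
+        >>> print(select(func.now()))
+        SELECT now() AS now_1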
+
+ """
+
+ type = sqltypes.DateTime
+ inherit_cache = True
+
+
+class concat(GenericFunction):
+ """The SQL CONCAT() function, which concatenates strings.
+
+ E.g.::
+
+ >>> print(select(func.concat('a', 'b')))
+ SELECT concat(:concat_2, :concat_3) AS concat_1
+
+ String concatenation in SQLAlchemy is more commonly available using the
+ Python ``+`` operator with string datatypes, which will render a
+ backend-specific concatenation operator, such as ::
+
+ >>> print(select(literal("a") + "b"))
+ SELECT :param_1 || :param_2 AS anon_1
+
+
+ """
+
+ type = sqltypes.String
+ inherit_cache = True
+
+
+class char_length(GenericFunction):
+ """The CHAR_LENGTH() SQL function."""
+
+ type = sqltypes.Integer
+ inherit_cache = True
+
+ def __init__(self, arg, **kwargs):
+ GenericFunction.__init__(self, arg, **kwargs)
+
+
+class random(GenericFunction):
+ """The RANDOM() SQL function."""
+
+ _has_args = True
+ inherit_cache = True
+
+
+class count(GenericFunction):
+ r"""The ANSI COUNT aggregate function. With no arguments,
+ emits COUNT \*.
+
+ E.g.::
+
+ from sqlalchemy import func
+ from sqlalchemy import select
+ from sqlalchemy import table, column
+
+ my_table = table('some_table', column('id'))
+
+ stmt = select(func.count()).select_from(my_table)
+
+ Executing ``stmt`` would emit::
+
+ SELECT count(*) AS count_1
+ FROM some_table
+
+
+ """
+ type = sqltypes.Integer
+ inherit_cache = True
+
+ def __init__(self, expression=None, **kwargs):
+ if expression is None:
+ expression = literal_column("*")
+ super(count, self).__init__(expression, **kwargs)
+
+
+class current_date(AnsiFunction):
+ """The CURRENT_DATE() SQL function."""
+
+ type = sqltypes.Date
+ inherit_cache = True
+
+
+class current_time(AnsiFunction):
+ """The CURRENT_TIME() SQL function."""
+
+ type = sqltypes.Time
+ inherit_cache = True
+
+
+class current_timestamp(AnsiFunction):
+ """The CURRENT_TIMESTAMP() SQL function."""
+
+ type = sqltypes.DateTime
+ inherit_cache = True
+
+
+class current_user(AnsiFunction):
+ """The CURRENT_USER() SQL function."""
+
+ type = sqltypes.String
+ inherit_cache = True
+
+
+class localtime(AnsiFunction):
+ """The localtime() SQL function."""
+
+ type = sqltypes.DateTime
+ inherit_cache = True
+
+
+class localtimestamp(AnsiFunction):
+ """The localtimestamp() SQL function."""
+
+ type = sqltypes.DateTime
+ inherit_cache = True
+
+
+class session_user(AnsiFunction):
+ """The SESSION_USER() SQL function."""
+
+ type = sqltypes.String
+ inherit_cache = True
+
+
+class sysdate(AnsiFunction):
+ """The SYSDATE() SQL function."""
+
+ type = sqltypes.DateTime
+ inherit_cache = True
+
+
+class user(AnsiFunction):
+ """The USER() SQL function."""
+
+ type = sqltypes.String
+ inherit_cache = True
+
+
+class array_agg(GenericFunction):
+ """Support for the ARRAY_AGG function.
+
+ The ``func.array_agg(expr)`` construct returns an expression of
+ type :class:`_types.ARRAY`.
+
+ e.g.::
+
+ stmt = select(func.array_agg(table.c.values)[2:5])
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :func:`_postgresql.array_agg` - PostgreSQL-specific version that
+ returns :class:`_postgresql.ARRAY`, which has PG-specific operators
+ added.
+
+ """
+
+ type = sqltypes.ARRAY
+ inherit_cache = True
+
+ def __init__(self, *args, **kwargs):
+ args = [
+ coercions.expect(
+ roles.ExpressionElementRole, c, apply_propagate_attrs=self
+ )
+ for c in args
+ ]
+
+ default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY)
+ if "type_" not in kwargs:
+
+ type_from_args = _type_from_args(args)
+ if isinstance(type_from_args, sqltypes.ARRAY):
+ kwargs["type_"] = type_from_args
+ else:
+ kwargs["type_"] = default_array_type(type_from_args)
+ kwargs["_parsed_args"] = args
+ super(array_agg, self).__init__(*args, **kwargs)
+
+
+class OrderedSetAgg(GenericFunction):
+ """Define a function where the return type is based on the sort
+ expression type as defined by the expression passed to the
+ :meth:`.FunctionElement.within_group` method."""
+
+ array_for_multi_clause = False
+ inherit_cache = True
+
+ def within_group_type(self, within_group):
+ func_clauses = self.clause_expr.element
+ order_by = sqlutil.unwrap_order_by(within_group.order_by)
+ if self.array_for_multi_clause and len(func_clauses.clauses) > 1:
+ return sqltypes.ARRAY(order_by[0].type)
+ else:
+ return order_by[0].type
+
+
+class mode(OrderedSetAgg):
+ """Implement the ``mode`` ordered-set aggregate function.
+
+ This function must be used with the :meth:`.FunctionElement.within_group`
+ modifier to supply a sort expression to operate upon.
+
+ The return type of this function is the same as the sort expression.
+
+ .. versionadded:: 1.1
+
+ """
+
+ inherit_cache = True
+
+
+class percentile_cont(OrderedSetAgg):
+ """Implement the ``percentile_cont`` ordered-set aggregate function.
+
+ This function must be used with the :meth:`.FunctionElement.within_group`
+ modifier to supply a sort expression to operate upon.
+
+ The return type of this function is the same as the sort expression,
+ or if the arguments are an array, an :class:`_types.ARRAY` of the sort
+ expression's type.
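+
+    An illustrative sketch (``table.c.value`` is a hypothetical column)::
+
+        stmt = select(
+            func.percentile_cont(0.25).within_group(table.c.value)
+        )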
+
+ .. versionadded:: 1.1
+
+ """
+
+ array_for_multi_clause = True
+ inherit_cache = True
+
+
+class percentile_disc(OrderedSetAgg):
+ """Implement the ``percentile_disc`` ordered-set aggregate function.
+
+ This function must be used with the :meth:`.FunctionElement.within_group`
+ modifier to supply a sort expression to operate upon.
+
+ The return type of this function is the same as the sort expression,
+ or if the arguments are an array, an :class:`_types.ARRAY` of the sort
+ expression's type.
+
+ .. versionadded:: 1.1
+
+ """
+
+ array_for_multi_clause = True
+ inherit_cache = True
+
+
+class rank(GenericFunction):
+ """Implement the ``rank`` hypothetical-set aggregate function.
+
+ This function must be used with the :meth:`.FunctionElement.within_group`
+ modifier to supply a sort expression to operate upon.
+
+ The return type of this function is :class:`.Integer`.
+
+ .. versionadded:: 1.1
+
+ """
+
+ type = sqltypes.Integer()
+ inherit_cache = True
+
+
+class dense_rank(GenericFunction):
+ """Implement the ``dense_rank`` hypothetical-set aggregate function.
+
+ This function must be used with the :meth:`.FunctionElement.within_group`
+ modifier to supply a sort expression to operate upon.
+
+ The return type of this function is :class:`.Integer`.
+
+ .. versionadded:: 1.1
+
+ """
+
+ type = sqltypes.Integer()
+ inherit_cache = True
+
+
+class percent_rank(GenericFunction):
+ """Implement the ``percent_rank`` hypothetical-set aggregate function.
+
+ This function must be used with the :meth:`.FunctionElement.within_group`
+ modifier to supply a sort expression to operate upon.
+
+ The return type of this function is :class:`.Numeric`.
+
+ .. versionadded:: 1.1
+
+ """
+
+ type = sqltypes.Numeric()
+ inherit_cache = True
+
+
+class cume_dist(GenericFunction):
+ """Implement the ``cume_dist`` hypothetical-set aggregate function.
+
+ This function must be used with the :meth:`.FunctionElement.within_group`
+ modifier to supply a sort expression to operate upon.
+
+ The return type of this function is :class:`.Numeric`.
+
+ .. versionadded:: 1.1
+
+ """
+
+ type = sqltypes.Numeric()
+ inherit_cache = True
+
+
+class cube(GenericFunction):
+ r"""Implement the ``CUBE`` grouping operation.
+
+ This function is used as part of the GROUP BY of a statement,
+ e.g. :meth:`_expression.Select.group_by`::
+
+ stmt = select(
+ func.sum(table.c.value), table.c.col_1, table.c.col_2
+ ).group_by(func.cube(table.c.col_1, table.c.col_2))
+
+ .. versionadded:: 1.2
+
+ """
+ _has_args = True
+ inherit_cache = True
+
+
+class rollup(GenericFunction):
+ r"""Implement the ``ROLLUP`` grouping operation.
+
+ This function is used as part of the GROUP BY of a statement,
+ e.g. :meth:`_expression.Select.group_by`::
+
+ stmt = select(
+ func.sum(table.c.value), table.c.col_1, table.c.col_2
+ ).group_by(func.rollup(table.c.col_1, table.c.col_2))
+
+ .. versionadded:: 1.2
+
+ """
+ _has_args = True
+ inherit_cache = True
+
+
+class grouping_sets(GenericFunction):
+ r"""Implement the ``GROUPING SETS`` grouping operation.
+
+ This function is used as part of the GROUP BY of a statement,
+ e.g. :meth:`_expression.Select.group_by`::
+
+ stmt = select(
+ func.sum(table.c.value), table.c.col_1, table.c.col_2
+ ).group_by(func.grouping_sets(table.c.col_1, table.c.col_2))
+
+ In order to group by multiple sets, use the :func:`.tuple_` construct::
+
+ from sqlalchemy import tuple_
+
+ stmt = select(
+ func.sum(table.c.value),
+ table.c.col_1, table.c.col_2,
+ table.c.col_3
+ ).group_by(
+ func.grouping_sets(
+ tuple_(table.c.col_1, table.c.col_2),
+ tuple_(table.c.value, table.c.col_3),
+ )
+ )
+
+
+ .. versionadded:: 1.2
+
+ """
+ _has_args = True
+ inherit_cache = True
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
new file mode 100644
index 0000000..584efe4
--- /dev/null
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -0,0 +1,1314 @@
+# sql/lambdas.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import inspect
+import itertools
+import operator
+import sys
+import threading
+import types
+import weakref
+
+from . import coercions
+from . import elements
+from . import roles
+from . import schema
+from . import traversals
+from . import type_api
+from . import visitors
+from .base import _clone
+from .base import Options
+from .operators import ColumnOperators
+from .. import exc
+from .. import inspection
+from .. import util
+from ..util import collections_abc
+from ..util import compat
+
+_closure_per_cache_key = util.LRUCache(1000)
+
+
+class LambdaOptions(Options):
+ enable_tracking = True
+ track_closure_variables = True
+ track_on = None
+ global_track_bound_values = True
+ track_bound_values = True
+ lambda_cache = None
+
+
+def lambda_stmt(
+ lmb,
+ enable_tracking=True,
+ track_closure_variables=True,
+ track_on=None,
+ global_track_bound_values=True,
+ track_bound_values=True,
+ lambda_cache=None,
+):
+ """Produce a SQL statement that is cached as a lambda.
+
+ The Python code object within the lambda is scanned for both Python
+ literals that will become bound parameters as well as closure variables
+ that refer to Core or ORM constructs that may vary. The lambda itself
+ will be invoked only once per particular set of constructs detected.
+
+ E.g.::
+
+ from sqlalchemy import lambda_stmt
+
+ stmt = lambda_stmt(lambda: table.select())
+ stmt += lambda s: s.where(table.c.id == 5)
+
+ result = connection.execute(stmt)
+
+ The object returned is an instance of :class:`_sql.StatementLambdaElement`.
+
+ .. versionadded:: 1.4
+
+ :param lmb: a Python function, typically a lambda, which takes no arguments
+ and returns a SQL expression construct
+ :param enable_tracking: when False, all scanning of the given lambda for
+ changes in closure variables or bound parameters is disabled. Use for
+ a lambda that produces the identical results in all cases with no
+ parameterization.
+ :param track_closure_variables: when False, changes in closure variables
+ within the lambda will not be scanned. Use for a lambda where the
+ state of its closure variables will never change the SQL structure
+ returned by the lambda.
+ :param track_bound_values: when False, bound parameter tracking will
+ be disabled for the given lambda. Use for a lambda that either does
+ not produce any bound values, or where the initial bound values never
+ change.
+ :param global_track_bound_values: when False, bound parameter tracking
+ will be disabled for the entire statement including additional links
+ added via the :meth:`_sql.StatementLambdaElement.add_criteria` method.
+ :param lambda_cache: a dictionary or other mapping-like object where
+ information about the lambda's Python code as well as the tracked closure
+ variables in the lambda itself will be stored. Defaults
+ to a global LRU cache. This cache is independent of the "compiled_cache"
+ used by the :class:`_engine.Connection` object.
+
+ .. seealso::
+
+ :ref:`engine_lambda_caching`
+
+
+ """
+
+ return StatementLambdaElement(
+ lmb,
+ roles.StatementRole,
+ LambdaOptions(
+ enable_tracking=enable_tracking,
+ track_on=track_on,
+ track_closure_variables=track_closure_variables,
+ global_track_bound_values=global_track_bound_values,
+ track_bound_values=track_bound_values,
+ lambda_cache=lambda_cache,
+ ),
+ )
+
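+# A minimal usage sketch (hedged; it assumes a Table named "user_table" with
+# an "id" column and an existing Connection "conn").  The closure literal
+# ``id_value`` is extracted as a bound parameter, so the lambda bodies are
+# analyzed only once and the cached SQL construct is reused on every call:
+#
+#     from sqlalchemy import lambda_stmt, select
+#
+#     def lookup(conn, id_value):
+#         stmt = lambda_stmt(lambda: select(user_table))
+#         stmt += lambda s: s.where(user_table.c.id == id_value)
+#         return conn.execute(stmt).all()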
+
+class LambdaElement(elements.ClauseElement):
+ """A SQL construct where the state is stored as an un-invoked lambda.
+
+ The :class:`_sql.LambdaElement` is produced transparently whenever
+ passing lambda expressions into SQL constructs, such as::
+
+ stmt = select(table).where(lambda: table.c.col == parameter)
+
+ The :class:`_sql.LambdaElement` is the base of the
+ :class:`_sql.StatementLambdaElement` which represents a full statement
+ within a lambda.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`engine_lambda_caching`
+
+ """
+
+ __visit_name__ = "lambda_element"
+
+ _is_lambda_element = True
+
+ _traverse_internals = [
+ ("_resolved", visitors.InternalTraversal.dp_clauseelement)
+ ]
+
+ _transforms = ()
+
+ parent_lambda = None
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)
+
+ def __init__(
+ self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None
+ ):
+ self.fn = fn
+ self.role = role
+ self.tracker_key = (fn.__code__,)
+ self.opts = opts
+
+ if apply_propagate_attrs is None and (role is roles.StatementRole):
+ apply_propagate_attrs = self
+
+ rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)
+
+ if apply_propagate_attrs is not None:
+ propagate_attrs = rec.propagate_attrs
+ if propagate_attrs:
+ apply_propagate_attrs._propagate_attrs = propagate_attrs
+
+ def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
+ lambda_cache = opts.lambda_cache
+ if lambda_cache is None:
+ lambda_cache = _closure_per_cache_key
+
+ tracker_key = self.tracker_key
+
+ fn = self.fn
+ closure = fn.__closure__
+ tracker = AnalyzedCode.get(
+ fn,
+ self,
+ opts,
+ )
+
+ self._resolved_bindparams = bindparams = []
+
+ if self.parent_lambda is not None:
+ parent_closure_cache_key = self.parent_lambda.closure_cache_key
+ else:
+ parent_closure_cache_key = ()
+
+ if parent_closure_cache_key is not traversals.NO_CACHE:
+ anon_map = traversals.anon_map()
+ cache_key = tuple(
+ [
+ getter(closure, opts, anon_map, bindparams)
+ for getter in tracker.closure_trackers
+ ]
+ )
+
+ if traversals.NO_CACHE not in anon_map:
+ cache_key = parent_closure_cache_key + cache_key
+
+ self.closure_cache_key = cache_key
+
+ try:
+ rec = lambda_cache[tracker_key + cache_key]
+ except KeyError:
+ rec = None
+ else:
+ cache_key = traversals.NO_CACHE
+ rec = None
+
+ else:
+ cache_key = traversals.NO_CACHE
+ rec = None
+
+ self.closure_cache_key = cache_key
+
+ if rec is None:
+ if cache_key is not traversals.NO_CACHE:
+
+ with AnalyzedCode._generation_mutex:
+ key = tracker_key + cache_key
+ if key not in lambda_cache:
+ rec = AnalyzedFunction(
+ tracker, self, apply_propagate_attrs, fn
+ )
+ rec.closure_bindparams = bindparams
+ lambda_cache[key] = rec
+ else:
+ rec = lambda_cache[key]
+ else:
+ rec = NonAnalyzedFunction(self._invoke_user_fn(fn))
+
+ else:
+ bindparams[:] = [
+ orig_bind._with_value(new_bind.value, maintain_key=True)
+ for orig_bind, new_bind in zip(
+ rec.closure_bindparams, bindparams
+ )
+ ]
+
+ self._rec = rec
+
+ if cache_key is not traversals.NO_CACHE:
+ if self.parent_lambda is not None:
+ bindparams[:0] = self.parent_lambda._resolved_bindparams
+
+ lambda_element = self
+ while lambda_element is not None:
+ rec = lambda_element._rec
+ if rec.bindparam_trackers:
+ tracker_instrumented_fn = rec.tracker_instrumented_fn
+ for tracker in rec.bindparam_trackers:
+ tracker(
+ lambda_element.fn,
+ tracker_instrumented_fn,
+ bindparams,
+ )
+ lambda_element = lambda_element.parent_lambda
+
+ return rec
+
+ def __getattr__(self, key):
+ return getattr(self._rec.expected_expr, key)
+
+ @property
+ def _is_sequence(self):
+ return self._rec.is_sequence
+
+ @property
+ def _select_iterable(self):
+ if self._is_sequence:
+ return itertools.chain.from_iterable(
+ [element._select_iterable for element in self._resolved]
+ )
+
+ else:
+ return self._resolved._select_iterable
+
+ @property
+ def _from_objects(self):
+ if self._is_sequence:
+ return itertools.chain.from_iterable(
+ [element._from_objects for element in self._resolved]
+ )
+
+ else:
+ return self._resolved._from_objects
+
+ def _param_dict(self):
+ return {b.key: b.value for b in self._resolved_bindparams}
+
+ def _setup_binds_for_tracked_expr(self, expr):
+ bindparam_lookup = {b.key: b for b in self._resolved_bindparams}
+
+ def replace(thing):
+ if isinstance(thing, elements.BindParameter):
+
+ if thing.key in bindparam_lookup:
+ bind = bindparam_lookup[thing.key]
+ if thing.expanding:
+ bind.expanding = True
+ bind.expand_op = thing.expand_op
+ bind.type = thing.type
+ return bind
+
+ if self._rec.is_sequence:
+ expr = [
+ visitors.replacement_traverse(sub_expr, {}, replace)
+ for sub_expr in expr
+ ]
+ elif getattr(expr, "is_clause_element", False):
+ expr = visitors.replacement_traverse(expr, {}, replace)
+
+ return expr
+
+ def _copy_internals(
+ self, clone=_clone, deferred_copy_internals=None, **kw
+ ):
+ # TODO: this needs A LOT of tests
+ self._resolved = clone(
+ self._resolved,
+ deferred_copy_internals=deferred_copy_internals,
+ **kw
+ )
+
+ @util.memoized_property
+ def _resolved(self):
+ expr = self._rec.expected_expr
+
+ if self._resolved_bindparams:
+ expr = self._setup_binds_for_tracked_expr(expr)
+
+ return expr
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ if self.closure_cache_key is traversals.NO_CACHE:
+ anon_map[traversals.NO_CACHE] = True
+ return None
+
+ cache_key = (
+ self.fn.__code__,
+ self.__class__,
+ ) + self.closure_cache_key
+
+ parent = self.parent_lambda
+ while parent is not None:
+ cache_key = (
+ (parent.fn.__code__,) + parent.closure_cache_key + cache_key
+ )
+
+ parent = parent.parent_lambda
+
+ if self._resolved_bindparams:
+ bindparams.extend(self._resolved_bindparams)
+ return cache_key
+
+ def _invoke_user_fn(self, fn, *arg):
+ return fn()
+
+
+class DeferredLambdaElement(LambdaElement):
+ """A LambdaElement where the lambda accepts arguments and is
+ invoked within the compile phase with special context.
+
+ This lambda doesn't normally produce its real SQL expression outside of the
+ compile phase. It is passed a fixed set of initial arguments
+ so that it can generate a sample expression.
+
+ """
+
+ def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()):
+ self.lambda_args = lambda_args
+ super(DeferredLambdaElement, self).__init__(fn, role, opts)
+
+ def _invoke_user_fn(self, fn, *arg):
+ return fn(*self.lambda_args)
+
+ def _resolve_with_args(self, *lambda_args):
+ tracker_fn = self._rec.tracker_instrumented_fn
+ expr = tracker_fn(*lambda_args)
+
+ expr = coercions.expect(self.role, expr)
+
+ expr = self._setup_binds_for_tracked_expr(expr)
+
+ # this validation is getting very close, but not quite, to achieving
+ # #5767. The problem is if the base lambda uses an unnamed column
+ # as is very common with mixins, the parameter name is different
+ # and it produces a false positive; that is, for the documented case
+ # that is exactly what people will be doing, it doesn't work, so
+ # I'm not really sure how to handle this right now.
+ # expected_binds = [
+ # b._orig_key
+ # for b in self._rec.expr._generate_cache_key()[1]
+ # if b.required
+ # ]
+ # got_binds = [
+ # b._orig_key for b in expr._generate_cache_key()[1] if b.required
+ # ]
+ # if expected_binds != got_binds:
+ # raise exc.InvalidRequestError(
+ # "Lambda callable at %s produced a different set of bound "
+ # "parameters than its original run: %s"
+ # % (self.fn.__code__, ", ".join(got_binds))
+ # )
+
+ # TODO: TEST TEST TEST, this is very out there
+ for deferred_copy_internals in self._transforms:
+ expr = deferred_copy_internals(expr)
+
+ return expr
+
+ def _copy_internals(
+ self, clone=_clone, deferred_copy_internals=None, **kw
+ ):
+ super(DeferredLambdaElement, self)._copy_internals(
+ clone=clone,
+ deferred_copy_internals=deferred_copy_internals, # **kw
+ opts=kw,
+ )
+
+ # TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know
+ # our expression yet. so hold onto the replacement
+ if deferred_copy_internals:
+ self._transforms += (deferred_copy_internals,)
+
+
+class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
+ """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.
+
+ The :class:`_sql.StatementLambdaElement` is constructed using the
+ :func:`_sql.lambda_stmt` function::
+
+
+ from sqlalchemy import lambda_stmt
+
+ stmt = lambda_stmt(lambda: select(table))
+
+ Once constructed, additional criteria can be built onto the statement
+ by adding subsequent lambdas, which accept the existing statement
+ object as a single parameter::
+
+ stmt += lambda s: s.where(table.c.col == parameter)
+
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`engine_lambda_caching`
+
+ """
+
+ def __add__(self, other):
+ return self.add_criteria(other)
+
+ def add_criteria(
+ self,
+ other,
+ enable_tracking=True,
+ track_on=None,
+ track_closure_variables=True,
+ track_bound_values=True,
+ ):
+ """Add new criteria to this :class:`_sql.StatementLambdaElement`.
+
+ E.g.::
+
+ >>> def my_stmt(parameter):
+ ... stmt = lambda_stmt(
+ ... lambda: select(table.c.x, table.c.y),
+ ... )
+ ... stmt = stmt.add_criteria(
+ ... lambda: table.c.x > parameter
+ ... )
+ ... return stmt
+
+ The :meth:`_sql.StatementLambdaElement.add_criteria` method is
+ equivalent to using the Python addition operator to add a new
+ lambda, except that additional arguments may be added including
+ ``track_closure_values`` and ``track_on``::
+
+ >>> def my_stmt(self, foo):
+ ... stmt = lambda_stmt(
+ ... lambda: select(func.max(foo.x, foo.y)),
+ ... track_closure_variables=False
+ ... )
+ ... stmt = stmt.add_criteria(
+ ... lambda: self.where_criteria,
+ ... track_on=[self]
+ ... )
+ ... return stmt
+
+ See :func:`_sql.lambda_stmt` for a description of the parameters
+ accepted.
+
+ """
+
+ opts = self.opts + dict(
+ enable_tracking=enable_tracking,
+ track_closure_variables=track_closure_variables,
+ global_track_bound_values=self.opts.global_track_bound_values,
+ track_on=track_on,
+ track_bound_values=track_bound_values,
+ )
+
+ return LinkedLambdaElement(other, parent_lambda=self, opts=opts)
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ if self._rec.expected_expr.supports_execution:
+ return connection._execute_clauseelement(
+ self, multiparams, params, execution_options
+ )
+ else:
+ raise exc.ObjectNotExecutableError(self)
+
+ @property
+ def _with_options(self):
+ return self._rec.expected_expr._with_options
+
+ @property
+ def _effective_plugin_target(self):
+ return self._rec.expected_expr._effective_plugin_target
+
+ @property
+ def _execution_options(self):
+ return self._rec.expected_expr._execution_options
+
+ def spoil(self):
+ """Return a new :class:`.StatementLambdaElement` that will run
+ all lambdas unconditionally each time.
+
+ """
+ return NullLambdaStatement(self.fn())
+
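+# Debugging sketch (hypothetical names): spoiling the statement up front makes
+# every subsequent lambda run unconditionally, which helps determine whether a
+# problem originates in the caching layer itself:
+#
+#     stmt = lambda_stmt(lambda: select(user_table)).spoil()
+#     stmt += lambda s: s.where(user_table.c.id == id_value)
+#     result = conn.execute(stmt)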
+
+class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
+ """Provides the :class:`.StatementLambdaElement` API but does not
+ cache or analyze lambdas.
+
+ The lambdas are instead invoked immediately.
+
+ The intended use is to isolate issues that may arise when using
+ lambda statements.
+
+ """
+
+ __visit_name__ = "lambda_element"
+
+ _is_lambda_element = True
+
+ _traverse_internals = [
+ ("_resolved", visitors.InternalTraversal.dp_clauseelement)
+ ]
+
+ def __init__(self, statement):
+ self._resolved = statement
+ self._propagate_attrs = statement._propagate_attrs
+
+ def __getattr__(self, key):
+ return getattr(self._resolved, key)
+
+ def __add__(self, other):
+ statement = other(self._resolved)
+
+ return NullLambdaStatement(statement)
+
+ def add_criteria(self, other, **kw):
+ statement = other(self._resolved)
+
+ return NullLambdaStatement(statement)
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ if self._resolved.supports_execution:
+ return connection._execute_clauseelement(
+ self, multiparams, params, execution_options
+ )
+ else:
+ raise exc.ObjectNotExecutableError(self)
+
+
+class LinkedLambdaElement(StatementLambdaElement):
+ """Represent subsequent links of a :class:`.StatementLambdaElement`."""
+
+ role = None
+
+ def __init__(self, fn, parent_lambda, opts):
+ self.opts = opts
+ self.fn = fn
+ self.parent_lambda = parent_lambda
+
+ self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)
+ self._retrieve_tracker_rec(fn, self, opts)
+ self._propagate_attrs = parent_lambda._propagate_attrs
+
+ def _invoke_user_fn(self, fn, *arg):
+ return fn(self.parent_lambda._resolved)
+
+
+class AnalyzedCode(object):
+ __slots__ = (
+ "track_closure_variables",
+ "track_bound_values",
+ "bindparam_trackers",
+ "closure_trackers",
+ "build_py_wrappers",
+ )
+ _fns = weakref.WeakKeyDictionary()
+
+ _generation_mutex = threading.RLock()
+
+ @classmethod
+ def get(cls, fn, lambda_element, lambda_kw, **kw):
+ try:
+ # TODO: validate kw haven't changed?
+ return cls._fns[fn.__code__]
+ except KeyError:
+ pass
+
+ with cls._generation_mutex:
+ # check for other thread already created object
+ if fn.__code__ in cls._fns:
+ return cls._fns[fn.__code__]
+
+ cls._fns[fn.__code__] = analyzed = AnalyzedCode(
+ fn, lambda_element, lambda_kw, **kw
+ )
+ return analyzed
+
+ def __init__(self, fn, lambda_element, opts):
+ if inspect.ismethod(fn):
+ raise exc.ArgumentError(
+ "Method %s may not be passed as a SQL expression" % fn
+ )
+ closure = fn.__closure__
+
+ self.track_bound_values = (
+ opts.track_bound_values and opts.global_track_bound_values
+ )
+ enable_tracking = opts.enable_tracking
+ track_on = opts.track_on
+ track_closure_variables = opts.track_closure_variables
+
+ self.track_closure_variables = track_closure_variables and not track_on
+
+ # a list of callables generated from _bound_parameter_getter_*
+ # functions. Each of these uses a PyWrapper object to retrieve
+ # a parameter value
+ self.bindparam_trackers = []
+
+ # a list of callables generated from _cache_key_getter_* functions
+ # these callables work to generate a cache key for the lambda
+ # based on what's inside its closure variables.
+ self.closure_trackers = []
+
+ self.build_py_wrappers = []
+
+ if enable_tracking:
+ if track_on:
+ self._init_track_on(track_on)
+
+ self._init_globals(fn)
+
+ if closure:
+ self._init_closure(fn)
+
+ self._setup_additional_closure_trackers(fn, lambda_element, opts)
+
+ def _init_track_on(self, track_on):
+ self.closure_trackers.extend(
+ self._cache_key_getter_track_on(idx, elem)
+ for idx, elem in enumerate(track_on)
+ )
+
+ def _init_globals(self, fn):
+ build_py_wrappers = self.build_py_wrappers
+ bindparam_trackers = self.bindparam_trackers
+ track_bound_values = self.track_bound_values
+
+ for name in fn.__code__.co_names:
+ if name not in fn.__globals__:
+ continue
+
+ _bound_value = self._roll_down_to_literal(fn.__globals__[name])
+
+ if coercions._deep_is_literal(_bound_value):
+ build_py_wrappers.append((name, None))
+ if track_bound_values:
+ bindparam_trackers.append(
+ self._bound_parameter_getter_func_globals(name)
+ )
+
+ def _init_closure(self, fn):
+ build_py_wrappers = self.build_py_wrappers
+ closure = fn.__closure__
+
+ track_bound_values = self.track_bound_values
+ track_closure_variables = self.track_closure_variables
+ bindparam_trackers = self.bindparam_trackers
+ closure_trackers = self.closure_trackers
+
+ for closure_index, (fv, cell) in enumerate(
+ zip(fn.__code__.co_freevars, closure)
+ ):
+ _bound_value = self._roll_down_to_literal(cell.cell_contents)
+
+ if coercions._deep_is_literal(_bound_value):
+ build_py_wrappers.append((fv, closure_index))
+ if track_bound_values:
+ bindparam_trackers.append(
+ self._bound_parameter_getter_func_closure(
+ fv, closure_index
+ )
+ )
+ else:
+ # for normal cell contents, add them to a list that
+ # we can compare later when we get new lambdas. if
+ # any identities have changed, then we will
+ # recalculate the whole lambda and run it again.
+
+ if track_closure_variables:
+ closure_trackers.append(
+ self._cache_key_getter_closure_variable(
+ fn, fv, closure_index, cell.cell_contents
+ )
+ )
+
+ def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
+ # an additional step is to actually run the function, then
+ # go through the PyWrapper objects that were set up to catch a bound
+ # parameter. if they *didn't* make a param, then they're another
+ # object in the closure that we have to track for our cache key, so
+ # create trackers to catch those.
+
+ analyzed_function = AnalyzedFunction(
+ self,
+ lambda_element,
+ None,
+ fn,
+ )
+
+ closure_trackers = self.closure_trackers
+
+ for pywrapper in analyzed_function.closure_pywrappers:
+ if not pywrapper._sa__has_param:
+ closure_trackers.append(
+ self._cache_key_getter_tracked_literal(fn, pywrapper)
+ )
+
+ @classmethod
+ def _roll_down_to_literal(cls, element):
+ is_clause_element = hasattr(element, "__clause_element__")
+
+ if is_clause_element:
+ while not isinstance(
+ element, (elements.ClauseElement, schema.SchemaItem, type)
+ ):
+ try:
+ element = element.__clause_element__()
+ except AttributeError:
+ break
+
+ if not is_clause_element:
+ insp = inspection.inspect(element, raiseerr=False)
+ if insp is not None:
+ try:
+ return insp.__clause_element__()
+ except AttributeError:
+ return insp
+
+ # TODO: should we coerce consts None/True/False here?
+ return element
+ else:
+ return element
+
+ def _bound_parameter_getter_func_globals(self, name):
+ """Return a getter that will extend a list of bound parameters
+ with new entries from the ``__globals__`` collection of a particular
+ lambda.
+
+ """
+
+ def extract_parameter_value(
+ current_fn, tracker_instrumented_fn, result
+ ):
+ wrapper = tracker_instrumented_fn.__globals__[name]
+ object.__getattribute__(wrapper, "_extract_bound_parameters")(
+ current_fn.__globals__[name], result
+ )
+
+ return extract_parameter_value
+
+ def _bound_parameter_getter_func_closure(self, name, closure_index):
+ """Return a getter that will extend a list of bound parameters
+ with new entries from the ``__closure__`` collection of a particular
+ lambda.
+
+ """
+
+ def extract_parameter_value(
+ current_fn, tracker_instrumented_fn, result
+ ):
+ wrapper = tracker_instrumented_fn.__closure__[
+ closure_index
+ ].cell_contents
+ object.__getattribute__(wrapper, "_extract_bound_parameters")(
+ current_fn.__closure__[closure_index].cell_contents, result
+ )
+
+ return extract_parameter_value
+
+ def _cache_key_getter_track_on(self, idx, elem):
+ """Return a getter that will extend a cache key with new entries
+ from the "track_on" parameter passed to a :class:`.LambdaElement`.
+
+ """
+
+ if isinstance(elem, tuple):
+ # tuple must contain HasCacheKey elements
+ def get(closure, opts, anon_map, bindparams):
+ return tuple(
+ tup_elem._gen_cache_key(anon_map, bindparams)
+ for tup_elem in opts.track_on[idx]
+ )
+
+ elif isinstance(elem, traversals.HasCacheKey):
+
+ def get(closure, opts, anon_map, bindparams):
+ return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)
+
+ else:
+
+ def get(closure, opts, anon_map, bindparams):
+ return opts.track_on[idx]
+
+ return get
+
+ def _cache_key_getter_closure_variable(
+ self,
+ fn,
+ variable_name,
+ idx,
+ cell_contents,
+ use_clause_element=False,
+ use_inspect=False,
+ ):
+ """Return a getter that will extend a cache key with new entries
+ from the ``__closure__`` collection of a particular lambda.
+
+ """
+
+ if isinstance(cell_contents, traversals.HasCacheKey):
+
+ def get(closure, opts, anon_map, bindparams):
+
+ obj = closure[idx].cell_contents
+ if use_inspect:
+ obj = inspection.inspect(obj)
+ elif use_clause_element:
+ while hasattr(obj, "__clause_element__"):
+ if not getattr(obj, "is_clause_element", False):
+ obj = obj.__clause_element__()
+
+ return obj._gen_cache_key(anon_map, bindparams)
+
+ elif isinstance(cell_contents, types.FunctionType):
+
+ def get(closure, opts, anon_map, bindparams):
+ return closure[idx].cell_contents.__code__
+
+ elif isinstance(cell_contents, collections_abc.Sequence):
+
+ def get(closure, opts, anon_map, bindparams):
+ contents = closure[idx].cell_contents
+
+ try:
+ return tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in contents
+ )
+ except AttributeError as ae:
+ self._raise_for_uncacheable_closure_variable(
+ variable_name, fn, from_=ae
+ )
+
+ else:
+ # if the object is a mapped class or aliased class, or some
+ # other object in the ORM realm of things like that, imitate
+ # the logic used in coercions.expect() to roll it down to the
+ # SQL element
+ element = cell_contents
+ is_clause_element = False
+ while hasattr(element, "__clause_element__"):
+ is_clause_element = True
+ if not getattr(element, "is_clause_element", False):
+ element = element.__clause_element__()
+ else:
+ break
+
+ if not is_clause_element:
+ insp = inspection.inspect(element, raiseerr=False)
+ if insp is not None:
+ return self._cache_key_getter_closure_variable(
+ fn, variable_name, idx, insp, use_inspect=True
+ )
+ else:
+ return self._cache_key_getter_closure_variable(
+ fn, variable_name, idx, element, use_clause_element=True
+ )
+
+ self._raise_for_uncacheable_closure_variable(variable_name, fn)
+
+ return get
+
+ def _raise_for_uncacheable_closure_variable(
+ self, variable_name, fn, from_=None
+ ):
+ util.raise_(
+ exc.InvalidRequestError(
+ "Closure variable named '%s' inside of lambda callable %s "
+ "does not refer to a cacheable SQL element, and also does not "
+ "appear to be serving as a SQL literal bound value based on "
+ "the default "
+ "SQL expression returned by the function. This variable "
+ "needs to remain outside the scope of a SQL-generating lambda "
+ "so that a proper cache key may be generated from the "
+ "lambda's state. Evaluate this variable outside of the "
+ "lambda, set track_on=[<elements>] to explicitly select "
+ "closure elements to track, or set "
+ "track_closure_variables=False to exclude "
+ "closure variables from being part of the cache key."
+ % (variable_name, fn.__code__),
+ ),
+ from_=from_,
+ )
+
+ def _cache_key_getter_tracked_literal(self, fn, pytracker):
+ """Return a getter that will extend a cache key with new entries
+ from the ``__closure__`` collection of a particular lambda.
+
+ this getter differs from _cache_key_getter_closure_variable
+ in that these are detected after the function is run, and PyWrapper
+ objects have recorded that a particular literal value is in fact
+ not being interpreted as a bound parameter.
+
+ """
+
+ elem = pytracker._sa__to_evaluate
+ closure_index = pytracker._sa__closure_index
+ variable_name = pytracker._sa__name
+
+ return self._cache_key_getter_closure_variable(
+ fn, variable_name, closure_index, elem
+ )
+
+
+class NonAnalyzedFunction(object):
+ __slots__ = ("expr",)
+
+ closure_bindparams = None
+ bindparam_trackers = None
+
+ def __init__(self, expr):
+ self.expr = expr
+
+ @property
+ def expected_expr(self):
+ return self.expr
+
+
+class AnalyzedFunction(object):
+ __slots__ = (
+ "analyzed_code",
+ "fn",
+ "closure_pywrappers",
+ "tracker_instrumented_fn",
+ "expr",
+ "bindparam_trackers",
+ "expected_expr",
+ "is_sequence",
+ "propagate_attrs",
+ "closure_bindparams",
+ )
+
+ def __init__(
+ self,
+ analyzed_code,
+ lambda_element,
+ apply_propagate_attrs,
+ fn,
+ ):
+ self.analyzed_code = analyzed_code
+ self.fn = fn
+
+ self.bindparam_trackers = analyzed_code.bindparam_trackers
+
+ self._instrument_and_run_function(lambda_element)
+
+ self._coerce_expression(lambda_element, apply_propagate_attrs)
+
+ def _instrument_and_run_function(self, lambda_element):
+ analyzed_code = self.analyzed_code
+
+ fn = self.fn
+ self.closure_pywrappers = closure_pywrappers = []
+
+ build_py_wrappers = analyzed_code.build_py_wrappers
+
+ if not build_py_wrappers:
+ self.tracker_instrumented_fn = tracker_instrumented_fn = fn
+ self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
+ else:
+ track_closure_variables = analyzed_code.track_closure_variables
+ closure = fn.__closure__
+
+ # will form the __closure__ of the function when we rebuild it
+ if closure:
+ new_closure = {
+ fv: cell.cell_contents
+ for fv, cell in zip(fn.__code__.co_freevars, closure)
+ }
+ else:
+ new_closure = {}
+
+ # will form the __globals__ of the function when we rebuild it
+ new_globals = fn.__globals__.copy()
+
+ for name, closure_index in build_py_wrappers:
+ if closure_index is not None:
+ value = closure[closure_index].cell_contents
+ new_closure[name] = bind = PyWrapper(
+ fn,
+ name,
+ value,
+ closure_index=closure_index,
+ track_bound_values=(
+ self.analyzed_code.track_bound_values
+ ),
+ )
+ if track_closure_variables:
+ closure_pywrappers.append(bind)
+ else:
+ value = fn.__globals__[name]
+ new_globals[name] = bind = PyWrapper(fn, name, value)
+
+ # rewrite the original fn. things that look like they will
+ # become bound parameters are wrapped in a PyWrapper.
+ self.tracker_instrumented_fn = (
+ tracker_instrumented_fn
+ ) = self._rewrite_code_obj(
+ fn,
+ [new_closure[name] for name in fn.__code__.co_freevars],
+ new_globals,
+ )
+
+ # now invoke the function. This will give us a new SQL
+ # expression, but all the places that there would be a bound
+ # parameter, the PyWrapper in its place will give us a bind
+ # with a predictable name we can match up later.
+
+ # additionally, each PyWrapper will log that it did in fact
+ # create a parameter, otherwise, it's some kind of Python
+ # object in the closure and we want to track that, to make
+ # sure it doesn't change to something else, or if it does,
+ # that we create a different tracked function with that
+ # variable.
+ self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
+
+ def _coerce_expression(self, lambda_element, apply_propagate_attrs):
+ """Run the tracker-generated expression through coercion rules.
+
+ After the user-defined lambda has been invoked to produce a statement
+ for re-use, run it through coercion rules to both check that it's the
+ correct type of object and also to coerce it to its useful form.
+
+ """
+
+ parent_lambda = lambda_element.parent_lambda
+ expr = self.expr
+
+ if parent_lambda is None:
+ if isinstance(expr, collections_abc.Sequence):
+ self.expected_expr = [
+ coercions.expect(
+ lambda_element.role,
+ sub_expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ )
+ for sub_expr in expr
+ ]
+ self.is_sequence = True
+ else:
+ self.expected_expr = coercions.expect(
+ lambda_element.role,
+ expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ )
+ self.is_sequence = False
+ else:
+ self.expected_expr = expr
+ self.is_sequence = False
+
+ if apply_propagate_attrs is not None:
+ self.propagate_attrs = apply_propagate_attrs._propagate_attrs
+ else:
+ self.propagate_attrs = util.EMPTY_DICT
+
+ def _rewrite_code_obj(self, f, cell_values, globals_):
+ """Return a copy of f, with a new closure and new globals
+
+ yes it works in pypy :P
+
+ """
+
+ argrange = range(len(cell_values))
+
+ code = "def make_cells():\n"
+ if cell_values:
+ code += " (%s) = (%s)\n" % (
+ ", ".join("i%d" % i for i in argrange),
+ ", ".join("o%d" % i for i in argrange),
+ )
+ code += " def closure():\n"
+ code += " return %s\n" % ", ".join("i%d" % i for i in argrange)
+ code += " return closure.__closure__"
+ vars_ = {"o%d" % i: cell_values[i] for i in argrange}
+ compat.exec_(code, vars_, vars_)
+ closure = vars_["make_cells"]()
+
+ func = type(f)(
+ f.__code__, globals_, f.__name__, f.__defaults__, closure
+ )
+ if sys.version_info >= (3,):
+ func.__annotations__ = f.__annotations__
+ func.__kwdefaults__ = f.__kwdefaults__
+ func.__doc__ = f.__doc__
+ func.__module__ = f.__module__
+
+ return func
+
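+# Standalone sketch of the cell-rebuilding trick used by _rewrite_code_obj
+# (hypothetical helper; names are illustrative).  A throwaway inner function
+# whose free variables are the replacement values is compiled, and its
+# ``__closure__`` supplies the cell objects; the implementation above matches
+# each cell to a name in ``fn.__code__.co_freevars``:
+#
+#     import types
+#
+#     def _make_two_cells(v0, v1):
+#         def inner():
+#             return v0, v1          # v0 / v1 become free variables
+#         return inner.__closure__   # -> a tuple of two cell objects
+#
+#     # rebuild "fn" (assumed to have exactly two free variables) with new
+#     # cell values and a copied globals dict:
+#     new_fn = types.FunctionType(
+#         fn.__code__, dict(fn.__globals__), fn.__name__,
+#         fn.__defaults__, _make_two_cells(10, 20),
+#     )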
+
+class PyWrapper(ColumnOperators):
+ """A wrapper object that is injected into the ``__globals__`` and
+ ``__closure__`` of a Python function.
+
+ When the function is instrumented with :class:`.PyWrapper` objects, it is
+ then invoked just once in order to set up the wrappers. We look through
+ all the :class:`.PyWrapper` objects we made to find the ones that generated
+ a :class:`.BindParameter` object, e.g. the expression system interpreted
+ something as a literal. Those positions in the globals/closure are then
+ ones that we will look at, each time a new lambda comes in that refers to
+ the same ``__code__`` object. In this way, we keep a single version of
+ the SQL expression that this lambda produced, without calling upon the
+ Python function that created it more than once, unless its other closure
+ variables have changed. The expression is then transformed to have the
+ new bound values embedded into it.
+
+ """
+
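+ # Observable effect, sketched with hypothetical names: two calls that
+ # differ only in the closure literal reuse the same analyzed lambda, with
+ # the literal re-extracted as a bound parameter value:
+ #
+ #     def make_stmt(value):
+ #         return lambda_stmt(
+ #             lambda: select(user_table).where(user_table.c.id == value)
+ #         )
+ #
+ #     s1 = make_stmt(1)   # lambda is invoked and analyzed here
+ #     s2 = make_stmt(2)   # cached analysis reused; only the bind value differs
+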
+ def __init__(
+ self,
+ fn,
+ name,
+ to_evaluate,
+ closure_index=None,
+ getter=None,
+ track_bound_values=True,
+ ):
+ self.fn = fn
+ self._name = name
+ self._to_evaluate = to_evaluate
+ self._param = None
+ self._has_param = False
+ self._bind_paths = {}
+ self._getter = getter
+ self._closure_index = closure_index
+ self.track_bound_values = track_bound_values
+
+ def __call__(self, *arg, **kw):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ value = elem(*arg, **kw)
+ if (
+ self._sa_track_bound_values
+ and coercions._deep_is_literal(value)
+ and not isinstance(
+ # TODO: coverage where an ORM option or similar is here
+ value,
+ traversals.HasCacheKey,
+ )
+ ):
+ name = object.__getattribute__(self, "_name")
+ raise exc.InvalidRequestError(
+ "Can't invoke Python callable %s() inside of lambda "
+ "expression argument at %s; lambda SQL constructs should "
+ "not invoke functions from closure variables to produce "
+ "literal values since the "
+ "lambda SQL system normally extracts bound values without "
+ "actually "
+ "invoking the lambda or any functions within it. Call the "
+ "function outside of the "
+ "lambda and assign to a local variable that is used in the "
+ "lambda as a closure variable, or set "
+ "track_bound_values=False if the return value of this "
+ "function is used in some other way other than a SQL bound "
+ "value." % (name, self._sa_fn.__code__)
+ )
+ else:
+ return value
+
+ def operate(self, op, *other, **kwargs):
+ elem = object.__getattribute__(self, "__clause_element__")()
+ return op(elem, *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ elem = object.__getattribute__(self, "__clause_element__")()
+ return op(other, elem, **kwargs)
+
+ def _extract_bound_parameters(self, starting_point, result_list):
+ param = object.__getattribute__(self, "_param")
+ if param is not None:
+ param = param._with_value(starting_point, maintain_key=True)
+ result_list.append(param)
+ for pywrapper in object.__getattribute__(self, "_bind_paths").values():
+ getter = object.__getattribute__(pywrapper, "_getter")
+ element = getter(starting_point)
+ pywrapper._sa__extract_bound_parameters(element, result_list)
+
+ def __clause_element__(self):
+ param = object.__getattribute__(self, "_param")
+ to_evaluate = object.__getattribute__(self, "_to_evaluate")
+ if param is None:
+ name = object.__getattribute__(self, "_name")
+ self._param = param = elements.BindParameter(
+ name, required=False, unique=True
+ )
+ self._has_param = True
+ param.type = type_api._resolve_value_to_type(to_evaluate)
+ return param._with_value(to_evaluate, maintain_key=True)
+
+ def __bool__(self):
+ to_evaluate = object.__getattribute__(self, "_to_evaluate")
+ return bool(to_evaluate)
+
+ def __nonzero__(self):
+ to_evaluate = object.__getattribute__(self, "_to_evaluate")
+ return bool(to_evaluate)
+
+ def __getattribute__(self, key):
+ if key.startswith("_sa_"):
+ return object.__getattribute__(self, key[4:])
+ elif key in (
+ "__clause_element__",
+ "operate",
+ "reverse_operate",
+ "__class__",
+ "__dict__",
+ ):
+ return object.__getattribute__(self, key)
+
+ if key.startswith("__"):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ return getattr(elem, key)
+ else:
+ return self._sa__add_getter(key, operator.attrgetter)
+
+ def __iter__(self):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ return iter(elem)
+
+ def __getitem__(self, key):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ if not hasattr(elem, "__getitem__"):
+ raise AttributeError("__getitem__")
+
+ if isinstance(key, PyWrapper):
+ # TODO: coverage
+ raise exc.InvalidRequestError(
+ "Dictionary keys / list indexes inside of a cached "
+ "lambda must be Python literals only"
+ )
+ return self._sa__add_getter(key, operator.itemgetter)
+
+ def _add_getter(self, key, getter_fn):
+
+ bind_paths = object.__getattribute__(self, "_bind_paths")
+
+ bind_path_key = (key, getter_fn)
+ if bind_path_key in bind_paths:
+ return bind_paths[bind_path_key]
+
+ getter = getter_fn(key)
+ elem = object.__getattribute__(self, "_to_evaluate")
+ value = getter(elem)
+
+ rolled_down_value = AnalyzedCode._roll_down_to_literal(value)
+
+ if coercions._deep_is_literal(rolled_down_value):
+ wrapper = PyWrapper(self._sa_fn, key, value, getter=getter)
+ bind_paths[bind_path_key] = wrapper
+ return wrapper
+ else:
+ return value
+
+
+@inspection._inspects(LambdaElement)
+def insp(lmb):
+ return inspection.inspect(lmb._resolved)
diff --git a/lib/sqlalchemy/sql/naming.py b/lib/sqlalchemy/sql/naming.py
new file mode 100644
index 0000000..b7ad221
--- /dev/null
+++ b/lib/sqlalchemy/sql/naming.py
@@ -0,0 +1,210 @@
+# sql/naming.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Establish constraint and index naming conventions.
+
+
+"""
+
+import re
+
+from . import events # noqa
+from .elements import _NONE_NAME
+from .elements import conv
+from .schema import CheckConstraint
+from .schema import Column
+from .schema import Constraint
+from .schema import ForeignKeyConstraint
+from .schema import Index
+from .schema import PrimaryKeyConstraint
+from .schema import Table
+from .schema import UniqueConstraint
+from .. import event
+from .. import exc
+
+
+class ConventionDict(object):
+ def __init__(self, const, table, convention):
+ self.const = const
+ self._is_fk = isinstance(const, ForeignKeyConstraint)
+ self.table = table
+ self.convention = convention
+ self._const_name = const.name
+
+ def _key_table_name(self):
+ return self.table.name
+
+ def _column_X(self, idx, attrname):
+ if self._is_fk:
+ try:
+ fk = self.const.elements[idx]
+ except IndexError:
+ return ""
+ else:
+ return getattr(fk.parent, attrname)
+ else:
+ cols = list(self.const.columns)
+ try:
+ col = cols[idx]
+ except IndexError:
+ return ""
+ else:
+ return getattr(col, attrname)
+
+ def _key_constraint_name(self):
+ if self._const_name in (None, _NONE_NAME):
+ raise exc.InvalidRequestError(
+ "Naming convention including "
+ "%(constraint_name)s token requires that "
+ "constraint is explicitly named."
+ )
+ if not isinstance(self._const_name, conv):
+ self.const.name = None
+ return self._const_name
+
+ def _key_column_X_key(self, idx):
+ # note this method was missing before
+ # [ticket:3989], meaning tokens like ``%(column_0_key)s`` weren't
+ # working even though documented.
+ return self._column_X(idx, "key")
+
+ def _key_column_X_name(self, idx):
+ return self._column_X(idx, "name")
+
+ def _key_column_X_label(self, idx):
+ return self._column_X(idx, "_ddl_label")
+
+ def _key_referred_table_name(self):
+ fk = self.const.elements[0]
+ refs = fk.target_fullname.split(".")
+ if len(refs) == 3:
+ refschema, reftable, refcol = refs
+ else:
+ reftable, refcol = refs
+ return reftable
+
+ def _key_referred_column_X_name(self, idx):
+ fk = self.const.elements[idx]
+ # note that before [ticket:3989], this method was returning
+ # the specification for the :class:`.ForeignKey` itself, which normally
+ # would be using the ``.key`` of the column, not the name.
+ return fk.column.name
+
+ def __getitem__(self, key):
+ if key in self.convention:
+ return self.convention[key](self.const, self.table)
+ elif hasattr(self, "_key_%s" % key):
+ return getattr(self, "_key_%s" % key)()
+ else:
+ col_template = re.match(r".*_?column_(\d+)(_?N)?_.+", key)
+ if col_template:
+ idx = col_template.group(1)
+ multiples = col_template.group(2)
+
+ if multiples:
+ if self._is_fk:
+ elems = self.const.elements
+ else:
+ elems = list(self.const.columns)
+ tokens = []
+ for idx, elem in enumerate(elems):
+ attr = "_key_" + key.replace("0" + multiples, "X")
+ try:
+ tokens.append(getattr(self, attr)(idx))
+ except AttributeError:
+ raise KeyError(key)
+ sep = "_" if multiples.startswith("_") else ""
+ return sep.join(tokens)
+ else:
+ attr = "_key_" + key.replace(idx, "X")
+ idx = int(idx)
+ if hasattr(self, attr):
+ return getattr(self, attr)(idx)
+ raise KeyError(key)
+
+
+_prefix_dict = {
+ Index: "ix",
+ PrimaryKeyConstraint: "pk",
+ CheckConstraint: "ck",
+ UniqueConstraint: "uq",
+ ForeignKeyConstraint: "fk",
+}
+
+
+def _get_convention(dict_, key):
+
+ for super_ in key.__mro__:
+ if super_ in _prefix_dict and _prefix_dict[super_] in dict_:
+ return dict_[_prefix_dict[super_]]
+ elif super_ in dict_:
+ return dict_[super_]
+ else:
+ return None
+
+
+def _constraint_name_for_table(const, table):
+ metadata = table.metadata
+ convention = _get_convention(metadata.naming_convention, type(const))
+
+ if isinstance(const.name, conv):
+ return const.name
+ elif (
+ convention is not None
+ and not isinstance(const.name, conv)
+ and (
+ const.name is None
+ or "constraint_name" in convention
+ or const.name is _NONE_NAME
+ )
+ ):
+ return conv(
+ convention
+ % ConventionDict(const, table, metadata.naming_convention)
+ )
+ elif convention is _NONE_NAME:
+ return None
+
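+# Hedged usage sketch of the convention tokens resolved above (table and
+# column names are illustrative):
+#
+#     from sqlalchemy import Column, MetaData, String, Table, UniqueConstraint
+#
+#     metadata = MetaData(
+#         naming_convention={"uq": "uq_%(table_name)s_%(column_0_name)s"}
+#     )
+#     t = Table(
+#         "user", metadata,
+#         Column("name", String(50)),
+#         UniqueConstraint("name"),
+#     )
+#     # the unnamed UniqueConstraint is assigned the name "uq_user_name"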
+
+@event.listens_for(
+ PrimaryKeyConstraint, "_sa_event_column_added_to_pk_constraint"
+)
+def _column_added_to_pk_constraint(pk_constraint, col):
+ if pk_constraint._implicit_generated:
+ # only operate upon the "implicit" pk constraint for now,
+ # as we have to force the name to None to reset it. the
+ # "implicit" constraint will only have a naming convention name
+ # if at all.
+ table = pk_constraint.table
+ pk_constraint.name = None
+ newname = _constraint_name_for_table(pk_constraint, table)
+ if newname:
+ pk_constraint.name = newname
+
+
+@event.listens_for(Constraint, "after_parent_attach")
+@event.listens_for(Index, "after_parent_attach")
+def _constraint_name(const, table):
+ if isinstance(table, Column):
+ # this path occurs for a CheckConstraint linked to a Column
+
+ # for column-attached constraint, set another event
+ # to link the column attached to the table as this constraint
+ # associated with the table.
+ event.listen(
+ table,
+ "after_parent_attach",
+ lambda col, table: _constraint_name(const, table),
+ )
+
+ elif isinstance(table, Table):
+ if isinstance(const.name, conv) or const.name is _NONE_NAME:
+ return
+
+ newname = _constraint_name_for_table(const, table)
+ if newname:
+ const.name = newname
diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py
new file mode 100644
index 0000000..1da5032
--- /dev/null
+++ b/lib/sqlalchemy/sql/operators.py
@@ -0,0 +1,1688 @@
+# sql/operators.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Defines operators used in SQL expressions."""
+
+from operator import add
+from operator import and_
+from operator import contains
+from operator import eq
+from operator import ge
+from operator import getitem
+from operator import gt
+from operator import inv
+from operator import le
+from operator import lshift
+from operator import lt
+from operator import mod
+from operator import mul
+from operator import ne
+from operator import neg
+from operator import or_
+from operator import rshift
+from operator import sub
+from operator import truediv
+
+from .. import util
+
+
+if util.py2k:
+ from operator import div
+else:
+ div = truediv
+
+
+class Operators(object):
+ """Base of comparison and logical operators.
+
+ Implements base methods
+ :meth:`~sqlalchemy.sql.operators.Operators.operate` and
+ :meth:`~sqlalchemy.sql.operators.Operators.reverse_operate`, as well as
+ :meth:`~sqlalchemy.sql.operators.Operators.__and__`,
+ :meth:`~sqlalchemy.sql.operators.Operators.__or__`,
+ :meth:`~sqlalchemy.sql.operators.Operators.__invert__`.
+
+ It is usually used via its most common subclass
+ :class:`.ColumnOperators`.
+
+ """
+
+ __slots__ = ()
+
+ def __and__(self, other):
+ """Implement the ``&`` operator.
+
+ When used with SQL expressions, results in an
+ AND operation, equivalent to
+ :func:`_expression.and_`, that is::
+
+ a & b
+
+ is equivalent to::
+
+ from sqlalchemy import and_
+ and_(a, b)
+
+ Care should be taken when using ``&`` regarding
+ operator precedence; the ``&`` operator binds more tightly than
+ comparison operators such as ``==``. The operands should be
+ enclosed in parentheses if they contain further sub-expressions::
+
+ (a == 2) & (b == 4)
+
+ """
+ return self.operate(and_, other)
+
+ def __or__(self, other):
+ """Implement the ``|`` operator.
+
+ When used with SQL expressions, results in an
+ OR operation, equivalent to
+ :func:`_expression.or_`, that is::
+
+ a | b
+
+ is equivalent to::
+
+ from sqlalchemy import or_
+ or_(a, b)
+
+ Care should be taken when using ``|`` regarding
+ operator precedence; the ``|`` operator binds more tightly than
+ comparison operators such as ``==``. The operands should be
+ enclosed in parentheses if they contain further sub-expressions::
+
+ (a == 2) | (b == 4)
+
+ """
+ return self.operate(or_, other)
+
+ def __invert__(self):
+ """Implement the ``~`` operator.
+
+ When used with SQL expressions, results in a
+ NOT operation, equivalent to
+ :func:`_expression.not_`, that is::
+
+ ~a
+
+ is equivalent to::
+
+ from sqlalchemy import not_
+ not_(a)
+
+ """
+ return self.operate(inv)
+
+ def op(
+ self, opstring, precedence=0, is_comparison=False, return_type=None
+ ):
+ """Produce a generic operator function.
+
+ e.g.::
+
+ somecolumn.op("*")(5)
+
+ produces::
+
+ somecolumn * 5
+
+ This function can also be used to make bitwise operators explicit. For
+ example::
+
+ somecolumn.op('&')(0xff)
+
+ is a bitwise AND of the value in ``somecolumn``.
+
+ :param opstring: a string which will be output as the infix operator
+ between this element and the expression passed to the
+ generated function.
+
+ :param precedence: precedence to apply to the operator, when
+ parenthesizing expressions. A lower number will cause the expression
+ to be parenthesized when applied against another operator with
+ higher precedence. The default value of ``0`` is lower than all
+ operators except for the comma (``,``) and ``AS`` operators.
+ A value of 100 will be higher or equal to all operators, and -100
+ will be lower than or equal to all operators.
+
+ :param is_comparison: legacy; if True, the operator will be considered
+ as a "comparison" operator, that is which evaluates to a boolean
+ true/false value, like ``==``, ``>``, etc. This flag is provided
+ so that ORM relationships can establish that the operator is a
+ comparison operator when used in a custom join condition.
+
+ Using the ``is_comparison`` parameter is superseded by using the
+ :meth:`.Operators.bool_op` method instead; this more succinct
+ operator sets this parameter automatically. In SQLAlchemy 2.0 it
+ will also provide for improved typing support.
+
+ :param return_type: a :class:`.TypeEngine` class or object that will
+ force the return type of an expression produced by this operator
+ to be of that type. By default, operators that specify
+ :paramref:`.Operators.op.is_comparison` will resolve to
+ :class:`.Boolean`, and those that do not will be of the same
+ type as the left-hand operand.
+
+ .. seealso::
+
+ :meth:`.Operators.bool_op`
+
+ :ref:`types_operators`
+
+ :ref:`relationship_custom_operator`
+
+ """
+ operator = custom_op(opstring, precedence, is_comparison, return_type)
+
+ def against(other):
+ return operator(self, other)
+
+ return against
+
+ def bool_op(self, opstring, precedence=0):
+ """Return a custom boolean operator.
+
+ This method is shorthand for calling
+ :meth:`.Operators.op` and passing the
+ :paramref:`.Operators.op.is_comparison`
+ flag with True. A key advantage to using :meth:`.Operators.bool_op`
+ is that when using column constructs, the "boolean" nature of the
+ returned expression will be present for :pep:`484` purposes.
+
+ .. seealso::
+
+ :meth:`.Operators.op`
+
+ """
+ return self.op(opstring, precedence=precedence, is_comparison=True)
+
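+ # A hedged sketch of ``op()`` vs ``bool_op()`` against a hypothetical
+ # column; the rendered SQL is shown approximately:
+ #
+ #     expr1 = somecolumn.op("&")(0x04)        # "somecolumn & :somecolumn_1"
+ #     expr2 = somecolumn.bool_op("~")("^St")  # PostgreSQL regex match,
+ #                                             # typed as a boolean comparison
+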
+ def operate(self, op, *other, **kwargs):
+ r"""Operate on an argument.
+
+ This is the lowest level of operation, raises
+ :class:`NotImplementedError` by default.
+
+ Overriding this on a subclass can allow common
+ behavior to be applied to all operations.
+ For example, overriding :class:`.ColumnOperators`
+ to apply ``func.lower()`` to the left and right
+ side::
+
+ class MyComparator(ColumnOperators):
+ def operate(self, op, other, **kwargs):
+ return op(func.lower(self), func.lower(other), **kwargs)
+
+ :param op: Operator callable.
+ :param \*other: the 'other' side of the operation. Will
+ be a single scalar for most operations.
+ :param \**kwargs: modifiers. These may be passed by special
+ operators such as :meth:`ColumnOperators.contains`.
+
+
+ """
+ raise NotImplementedError(str(op))
+
+ def reverse_operate(self, op, other, **kwargs):
+ """Reverse operate on an argument.
+
+ Usage is the same as :meth:`operate`.
+
+ """
+ raise NotImplementedError(str(op))
+
+
+class custom_op(object):
+ """Represent a 'custom' operator.
+
+ :class:`.custom_op` is normally instantiated when the
+ :meth:`.Operators.op` or :meth:`.Operators.bool_op` methods
+ are used to create a custom operator callable. The class can also be
+ used directly when programmatically constructing expressions. E.g.
+ to represent the "factorial" operation::
+
+ from sqlalchemy.sql import UnaryExpression
+ from sqlalchemy.sql import operators
+ from sqlalchemy import Numeric
+
+ unary = UnaryExpression(table.c.somecolumn,
+ modifier=operators.custom_op("!"),
+ type_=Numeric)
+
+
+ .. seealso::
+
+ :meth:`.Operators.op`
+
+ :meth:`.Operators.bool_op`
+
+ """
+
+ __name__ = "custom_op"
+
+ def __init__(
+ self,
+ opstring,
+ precedence=0,
+ is_comparison=False,
+ return_type=None,
+ natural_self_precedent=False,
+ eager_grouping=False,
+ ):
+ self.opstring = opstring
+ self.precedence = precedence
+ self.is_comparison = is_comparison
+ self.natural_self_precedent = natural_self_precedent
+ self.eager_grouping = eager_grouping
+ self.return_type = (
+ return_type._to_instance(return_type) if return_type else None
+ )
+
+ def __eq__(self, other):
+ return isinstance(other, custom_op) and other.opstring == self.opstring
+
+ def __hash__(self):
+ return id(self)
+
+ def __call__(self, left, right, **kw):
+ return left.operate(self, right, **kw)
+
+
+class ColumnOperators(Operators):
+ """Defines boolean, comparison, and other operators for
+ :class:`_expression.ColumnElement` expressions.
+
+ By default, all methods call down to
+ :meth:`.operate` or :meth:`.reverse_operate`,
+ passing in the appropriate operator function from the
+ Python builtin ``operator`` module or
+ a SQLAlchemy-specific operator function from
+ :mod:`sqlalchemy.sql.operators`. For example
+ the ``__eq__`` function::
+
+ def __eq__(self, other):
+ return self.operate(operators.eq, other)
+
+ Where ``operators.eq`` is essentially::
+
+ def eq(a, b):
+ return a == b
+
+ The core column expression unit :class:`_expression.ColumnElement`
+ overrides :meth:`.Operators.operate` and others
+ to return further :class:`_expression.ColumnElement` constructs,
+ so that the ``==`` operation above is replaced by a clause
+ construct.
+
+ .. seealso::
+
+ :ref:`types_operators`
+
+ :attr:`.TypeEngine.comparator_factory`
+
+ :class:`.ColumnOperators`
+
+ :class:`.PropComparator`
+
+ """
+
+ __slots__ = ()
+
+ timetuple = None
+ """Hack, allows datetime objects to be compared on the LHS."""
+
+ def __lt__(self, other):
+ """Implement the ``<`` operator.
+
+ In a column context, produces the clause ``a < b``.
+
+ """
+ return self.operate(lt, other)
+
+ def __le__(self, other):
+ """Implement the ``<=`` operator.
+
+ In a column context, produces the clause ``a <= b``.
+
+ """
+ return self.operate(le, other)
+
+ __hash__ = Operators.__hash__
+
+ def __eq__(self, other):
+ """Implement the ``==`` operator.
+
+ In a column context, produces the clause ``a = b``.
+ If the target is ``None``, produces ``a IS NULL``.
+
+ """
+ return self.operate(eq, other)
+
+ def __ne__(self, other):
+ """Implement the ``!=`` operator.
+
+ In a column context, produces the clause ``a != b``.
+ If the target is ``None``, produces ``a IS NOT NULL``.
+
+ """
+ return self.operate(ne, other)
+
+ def is_distinct_from(self, other):
+ """Implement the ``IS DISTINCT FROM`` operator.
+
+ Renders "a IS DISTINCT FROM b" on most platforms;
+ on some, such as SQLite, it may render "a IS NOT b".
+
+ .. versionadded:: 1.1
+
+ """
+ return self.operate(is_distinct_from, other)
+
+ def is_not_distinct_from(self, other):
+ """Implement the ``IS NOT DISTINCT FROM`` operator.
+
+ Renders "a IS NOT DISTINCT FROM b" on most platforms;
+ on some, such as SQLite, it may render "a IS b".
+
+ .. versionchanged:: 1.4 The ``is_not_distinct_from()`` operator is
+ renamed from ``isnot_distinct_from()`` in previous releases.
+ The previous name remains available for backwards compatibility.
+
+ .. versionadded:: 1.1
+
+ """
+ return self.operate(is_not_distinct_from, other)
+
+ # deprecated 1.4; see #5435
+ isnot_distinct_from = is_not_distinct_from
+
+ def __gt__(self, other):
+ """Implement the ``>`` operator.
+
+ In a column context, produces the clause ``a > b``.
+
+ """
+ return self.operate(gt, other)
+
+ def __ge__(self, other):
+ """Implement the ``>=`` operator.
+
+ In a column context, produces the clause ``a >= b``.
+
+ """
+ return self.operate(ge, other)
+
+ def __neg__(self):
+ """Implement the ``-`` operator.
+
+ In a column context, produces the clause ``-a``.
+
+ """
+ return self.operate(neg)
+
+ def __contains__(self, other):
+ return self.operate(contains, other)
+
+ def __getitem__(self, index):
+ """Implement the [] operator.
+
+ This can be used by some database-specific types
+ such as PostgreSQL ARRAY and HSTORE.
+
+ """
+ return self.operate(getitem, index)
+
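+ # e.g. against PostgreSQL ARRAY / HSTORE columns (illustrative names; SQL
+ # shown approximately):
+ #
+ #     table.c.tags[2]          # renders roughly "tags[%(tags_1)s]"
+ #     table.c.attrs["weight"]  # HSTORE / JSON style keyed access
+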
+ def __lshift__(self, other):
+ """implement the << operator.
+
+ Not used by SQLAlchemy core, this is provided
+ for custom operator systems which want to use
+ << as an extension point.
+ """
+ return self.operate(lshift, other)
+
+ def __rshift__(self, other):
+ """implement the >> operator.
+
+ Not used by SQLAlchemy core, this is provided
+ for custom operator systems which want to use
+ >> as an extension point.
+ """
+ return self.operate(rshift, other)
+
+ def concat(self, other):
+ """Implement the 'concat' operator.
+
+ In a column context, produces the clause ``a || b``,
+ or uses the ``concat()`` operator on MySQL.
+
+ """
+ return self.operate(concat_op, other)
+
+ def _rconcat(self, other):
+ """Implement an 'rconcat' operator.
+
+ this is for internal use at the moment
+
+ .. versionadded:: 1.4.40
+
+ """
+ return self.reverse_operate(concat_op, other)
+
+ def like(self, other, escape=None):
+ r"""Implement the ``like`` operator.
+
+ In a column context, produces the expression::
+
+ a LIKE other
+
+ E.g.::
+
+ stmt = select(sometable).\
+ where(sometable.c.column.like("%foobar%"))
+
+ :param other: expression to be compared
+ :param escape: optional escape character, renders the ``ESCAPE``
+ keyword, e.g.::
+
+ somecolumn.like("foo/%bar", escape="/")
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.ilike`
+
+ """
+ return self.operate(like_op, other, escape=escape)
+
+ def ilike(self, other, escape=None):
+ r"""Implement the ``ilike`` operator, e.g. case insensitive LIKE.
+
+ In a column context, produces an expression either of the form::
+
+ lower(a) LIKE lower(other)
+
+ Or on backends that support the ILIKE operator::
+
+ a ILIKE other
+
+ E.g.::
+
+ stmt = select(sometable).\
+ where(sometable.c.column.ilike("%foobar%"))
+
+ :param other: expression to be compared
+ :param escape: optional escape character, renders the ``ESCAPE``
+ keyword, e.g.::
+
+ somecolumn.ilike("foo/%bar", escape="/")
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.like`
+
+ """
+ return self.operate(ilike_op, other, escape=escape)
+
+ def in_(self, other):
+ """Implement the ``in`` operator.
+
+ In a column context, produces the clause ``column IN <other>``.
+
+ The given parameter ``other`` may be:
+
+ * A list of literal values, e.g.::
+
+ stmt.where(column.in_([1, 2, 3]))
+
+ In this calling form, the list of items is converted to a set of
+ bound parameters the same length as the list given::
+
+ WHERE COL IN (?, ?, ?)
+
+ * A list of tuples may be provided if the comparison is against a
+ :func:`.tuple_` containing multiple expressions::
+
+ from sqlalchemy import tuple_
+ stmt.where(tuple_(col1, col2).in_([(1, 10), (2, 20), (3, 30)]))
+
+ * An empty list, e.g.::
+
+ stmt.where(column.in_([]))
+
+ In this calling form, the expression renders an "empty set"
+ expression. These expressions are tailored to individual backends
+ and generally take the form of an empty SELECT statement used as a
+ subquery. On SQLite, for example, the expression is::
+
+ WHERE col IN (SELECT 1 FROM (SELECT 1) WHERE 1!=1)
+
+ .. versionchanged:: 1.4 empty IN expressions now use an
+ execution-time generated SELECT subquery in all cases.
+
+ * A bound parameter, e.g. :func:`.bindparam`, may be used if it
+ includes the :paramref:`.bindparam.expanding` flag::
+
+ stmt.where(column.in_(bindparam('value', expanding=True)))
+
+ In this calling form, the expression renders a special non-SQL
+ placeholder expression that looks like::
+
+ WHERE COL IN ([EXPANDING_value])
+
+ This placeholder expression is intercepted at statement execution
+ time to be converted into the variable number of bound parameter
+ form illustrated earlier. If the statement were executed as::
+
+ connection.execute(stmt, {"value": [1, 2, 3]})
+
+ The database would be passed a bound parameter for each value::
+
+ WHERE COL IN (?, ?, ?)
+
+ .. versionadded:: 1.2 added "expanding" bound parameters
+
+ If an empty list is passed, a special "empty list" expression,
+ which is specific to the database in use, is rendered. On
+ SQLite this would be::
+
+ WHERE COL IN (SELECT 1 FROM (SELECT 1) WHERE 1!=1)
+
+ .. versionadded:: 1.3 "expanding" bound parameters now support
+ empty lists
+
+ * a :func:`_expression.select` construct, which is usually a
+ correlated scalar select::
+
+ stmt.where(
+ column.in_(
+ select(othertable.c.y).
+ where(table.c.x == othertable.c.x)
+ )
+ )
+
+ In this calling form, :meth:`.ColumnOperators.in_` renders as given::
+
+ WHERE COL IN (SELECT othertable.y
+ FROM othertable WHERE othertable.x = table.x)
+
+ :param other: a list of literals, a :func:`_expression.select`
+ construct, or a :func:`.bindparam` construct that includes the
+ :paramref:`.bindparam.expanding` flag set to True.
+
+ """
+ return self.operate(in_op, other)
+
+ def not_in(self, other):
+ """implement the ``NOT IN`` operator.
+
+ This is equivalent to using negation with
+ :meth:`.ColumnOperators.in_`, i.e. ``~x.in_(y)``.
+
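+ E.g., an illustrative sketch::
+
+ stmt = select(sometable).where(
+ sometable.c.column.not_in([1, 2, 3])
+ )
+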
+ In the case that ``other`` is an empty sequence, the compiler
+ produces an "empty not in" expression. This defaults to the
+ expression "1 = 1" to produce true in all cases. The
+ :paramref:`_sa.create_engine.empty_in_strategy` may be used to
+ alter this behavior.
+
+ .. versionchanged:: 1.4 The ``not_in()`` operator is renamed from
+ ``notin_()`` in previous releases. The previous name remains
+ available for backwards compatibility.
+
+ .. versionchanged:: 1.2 The :meth:`.ColumnOperators.in_` and
+ :meth:`.ColumnOperators.not_in` operators
+ now produce a "static" expression for an empty IN sequence
+ by default.
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.in_`
+
+ """
+ return self.operate(not_in_op, other)
+
+ # deprecated 1.4; see #5429
+ notin_ = not_in
+
+ def not_like(self, other, escape=None):
+ """implement the ``NOT LIKE`` operator.
+
+ This is equivalent to using negation with
+ :meth:`.ColumnOperators.like`, i.e. ``~x.like(y)``.
+
+ .. versionchanged:: 1.4 The ``not_like()`` operator is renamed from
+ ``notlike()`` in previous releases. The previous name remains
+ available for backwards compatibility.
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.like`
+
+ """
+ return self.operate(notlike_op, other, escape=escape)
+
+ # deprecated 1.4; see #5435
+ notlike = not_like
+
+ def not_ilike(self, other, escape=None):
+ """implement the ``NOT ILIKE`` operator.
+
+ This is equivalent to using negation with
+ :meth:`.ColumnOperators.ilike`, i.e. ``~x.ilike(y)``.
+
+ .. versionchanged:: 1.4 The ``not_ilike()`` operator is renamed from
+ ``notilike()`` in previous releases. The previous name remains
+ available for backwards compatibility.
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.ilike`
+
+ """
+ return self.operate(notilike_op, other, escape=escape)
+
+ # deprecated 1.4; see #5435
+ notilike = not_ilike
+
+ def is_(self, other):
+ """Implement the ``IS`` operator.
+
+ Normally, ``IS`` is generated automatically when comparing to a
+ value of ``None``, which resolves to ``NULL``. However, explicit
+ usage of ``IS`` may be desirable if comparing to boolean values
+ on certain platforms.
+
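+ E.g., an illustrative sketch comparing against a boolean value
+ (``sometable.c.flag`` is assumed to be a boolean column)::
+
+ stmt = select(sometable).where(sometable.c.flag.is_(True))
+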
+ .. seealso:: :meth:`.ColumnOperators.is_not`
+
+ """
+ return self.operate(is_, other)
+
+ def is_not(self, other):
+ """Implement the ``IS NOT`` operator.
+
+ Normally, ``IS NOT`` is generated automatically when comparing to a
+ value of ``None``, which resolves to ``NULL``. However, explicit
+ usage of ``IS NOT`` may be desirable if comparing to boolean values
+ on certain platforms.
+
+ .. versionchanged:: 1.4 The ``is_not()`` operator is renamed from
+ ``isnot()`` in previous releases. The previous name remains
+ available for backwards compatibility.
+
+ .. seealso:: :meth:`.ColumnOperators.is_`
+
+ """
+ return self.operate(is_not, other)
+
+ # deprecated 1.4; see #5429
+ isnot = is_not
+
+ def startswith(self, other, **kwargs):
+ r"""Implement the ``startswith`` operator.
+
+ Produces a LIKE expression that tests against a match for the start
+ of a string value::
+
+ column LIKE <other> || '%'
+
+ E.g.::
+
+ stmt = select(sometable).\
+ where(sometable.c.column.startswith("foobar"))
+
+ Since the operator uses ``LIKE``, wildcard characters
+ ``"%"`` and ``"_"`` that are present inside the <other> expression
+ will behave like wildcards as well. For literal string
+ values, the :paramref:`.ColumnOperators.startswith.autoescape` flag
+ may be set to ``True`` to apply escaping to occurrences of these
+ characters within the string value so that they match as themselves
+ and not as wildcard characters. Alternatively, the
+ :paramref:`.ColumnOperators.startswith.escape` parameter will establish
+ a given character as an escape character which can be of use when
+ the target expression is not a literal string.
+
+ :param other: expression to be compared. This is usually a plain
+ string value, but can also be an arbitrary SQL expression. LIKE
+ wildcard characters ``%`` and ``_`` are not escaped by default unless
+ the :paramref:`.ColumnOperators.startswith.autoescape` flag is
+ set to True.
+
+ :param autoescape: boolean; when True, establishes an escape character
+ within the LIKE expression, then applies it to all occurrences of
+ ``"%"``, ``"_"`` and the escape character itself within the
+ comparison value, which is assumed to be a literal string and not a
+ SQL expression.
+
+ An expression such as::
+
+ somecolumn.startswith("foo%bar", autoescape=True)
+
+ Will render as::
+
+ somecolumn LIKE :param || '%' ESCAPE '/'
+
+ With the value of ``:param`` as ``"foo/%bar"``.
+
+ :param escape: a character which when given will render with the
+ ``ESCAPE`` keyword to establish that character as the escape
+ character. This character can then be placed preceding occurrences
+ of ``%`` and ``_`` to allow them to act as themselves and not
+ wildcard characters.
+
+ An expression such as::
+
+ somecolumn.startswith("foo/%bar", escape="^")
+
+ Will render as::
+
+ somecolumn LIKE :param || '%' ESCAPE '^'
+
+ The parameter may also be combined with
+ :paramref:`.ColumnOperators.startswith.autoescape`::
+
+ somecolumn.startswith("foo%bar^bat", escape="^", autoescape=True)
+
+ Where above, the given literal parameter will be converted to
+ ``"foo^%bar^^bat"`` before being passed to the database.
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.endswith`
+
+ :meth:`.ColumnOperators.contains`
+
+ :meth:`.ColumnOperators.like`
+
+ """
+ return self.operate(startswith_op, other, **kwargs)
+
+ def endswith(self, other, **kwargs):
+ r"""Implement the 'endswith' operator.
+
+ Produces a LIKE expression that tests against a match for the end
+ of a string value::
+
+ column LIKE '%' || <other>
+
+ E.g.::
+
+ stmt = select(sometable).\
+ where(sometable.c.column.endswith("foobar"))
+
+ Since the operator uses ``LIKE``, wildcard characters
+ ``"%"`` and ``"_"`` that are present inside the <other> expression
+ will behave like wildcards as well. For literal string
+ values, the :paramref:`.ColumnOperators.endswith.autoescape` flag
+ may be set to ``True`` to apply escaping to occurrences of these
+ characters within the string value so that they match as themselves
+ and not as wildcard characters. Alternatively, the
+ :paramref:`.ColumnOperators.endswith.escape` parameter will establish
+ a given character as an escape character which can be of use when
+ the target expression is not a literal string.
+
+ :param other: expression to be compared. This is usually a plain
+ string value, but can also be an arbitrary SQL expression. LIKE
+ wildcard characters ``%`` and ``_`` are not escaped by default unless
+ the :paramref:`.ColumnOperators.endswith.autoescape` flag is
+ set to True.
+
+ :param autoescape: boolean; when True, establishes an escape character
+ within the LIKE expression, then applies it to all occurrences of
+ ``"%"``, ``"_"`` and the escape character itself within the
+ comparison value, which is assumed to be a literal string and not a
+ SQL expression.
+
+ An expression such as::
+
+ somecolumn.endswith("foo%bar", autoescape=True)
+
+ Will render as::
+
+ somecolumn LIKE '%' || :param ESCAPE '/'
+
+ With the value of ``:param`` as ``"foo/%bar"``.
+
+ :param escape: a character which when given will render with the
+ ``ESCAPE`` keyword to establish that character as the escape
+ character. This character can then be placed preceding occurrences
+ of ``%`` and ``_`` to allow them to act as themselves and not
+ wildcard characters.
+
+ An expression such as::
+
+ somecolumn.endswith("foo/%bar", escape="^")
+
+ Will render as::
+
+ somecolumn LIKE '%' || :param ESCAPE '^'
+
+ The parameter may also be combined with
+ :paramref:`.ColumnOperators.endswith.autoescape`::
+
+ somecolumn.endswith("foo%bar^bat", escape="^", autoescape=True)
+
+ Where above, the given literal parameter will be converted to
+ ``"foo^%bar^^bat"`` before being passed to the database.
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.startswith`
+
+ :meth:`.ColumnOperators.contains`
+
+ :meth:`.ColumnOperators.like`
+
+ """
+ return self.operate(endswith_op, other, **kwargs)
+
+ def contains(self, other, **kwargs):
+ r"""Implement the 'contains' operator.
+
+ Produces a LIKE expression that tests against a match for the middle
+ of a string value::
+
+ column LIKE '%' || <other> || '%'
+
+ E.g.::
+
+ stmt = select(sometable).\
+ where(sometable.c.column.contains("foobar"))
+
+ Since the operator uses ``LIKE``, wildcard characters
+ ``"%"`` and ``"_"`` that are present inside the <other> expression
+ will behave like wildcards as well. For literal string
+ values, the :paramref:`.ColumnOperators.contains.autoescape` flag
+ may be set to ``True`` to apply escaping to occurrences of these
+ characters within the string value so that they match as themselves
+ and not as wildcard characters. Alternatively, the
+ :paramref:`.ColumnOperators.contains.escape` parameter will establish
+ a given character as an escape character which can be of use when
+ the target expression is not a literal string.
+
+ :param other: expression to be compared. This is usually a plain
+ string value, but can also be an arbitrary SQL expression. LIKE
+ wildcard characters ``%`` and ``_`` are not escaped by default unless
+ the :paramref:`.ColumnOperators.contains.autoescape` flag is
+ set to True.
+
+ :param autoescape: boolean; when True, establishes an escape character
+ within the LIKE expression, then applies it to all occurrences of
+ ``"%"``, ``"_"`` and the escape character itself within the
+ comparison value, which is assumed to be a literal string and not a
+ SQL expression.
+
+ An expression such as::
+
+ somecolumn.contains("foo%bar", autoescape=True)
+
+ Will render as::
+
+ somecolumn LIKE '%' || :param || '%' ESCAPE '/'
+
+ With the value of ``:param`` as ``"foo/%bar"``.
+
+ :param escape: a character which when given will render with the
+ ``ESCAPE`` keyword to establish that character as the escape
+ character. This character can then be placed preceding occurrences
+ of ``%`` and ``_`` to allow them to act as themselves and not
+ wildcard characters.
+
+ An expression such as::
+
+ somecolumn.contains("foo/%bar", escape="^")
+
+ Will render as::
+
+ somecolumn LIKE '%' || :param || '%' ESCAPE '^'
+
+ The parameter may also be combined with
+ :paramref:`.ColumnOperators.contains.autoescape`::
+
+ somecolumn.contains("foo%bar^bat", escape="^", autoescape=True)
+
+ Where above, the given literal parameter will be converted to
+ ``"foo^%bar^^bat"`` before being passed to the database.
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.startswith`
+
+ :meth:`.ColumnOperators.endswith`
+
+ :meth:`.ColumnOperators.like`
+
+
+ """
+ return self.operate(contains_op, other, **kwargs)
+
+ def match(self, other, **kwargs):
+ """Implements a database-specific 'match' operator.
+
+ :meth:`_sql.ColumnOperators.match` attempts to resolve to
+ a MATCH-like function or operator provided by the backend.
+ Examples include:
+
+ * PostgreSQL - renders ``x @@ to_tsquery(y)``
+ * MySQL - renders ``MATCH (x) AGAINST (y IN BOOLEAN MODE)``
+
+ .. seealso::
+
+ :class:`_mysql.match` - MySQL specific construct with
+ additional features.
+
+ * Oracle - renders ``CONTAINS(x, y)``
+ * other backends may provide special implementations.
+ * Backends without any special implementation will emit
+ the operator as "MATCH". This is compatible with SQLite, for
+ example.
+
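+ E.g., an illustrative sketch::
+
+ stmt = select(sometable).where(
+ sometable.c.column.match("search string")
+ )
+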
+ """
+ return self.operate(match_op, other, **kwargs)
+
+ def regexp_match(self, pattern, flags=None):
+ """Implements a database-specific 'regexp match' operator.
+
+ E.g.::
+
+ stmt = select(table.c.some_column).where(
+ table.c.some_column.regexp_match('^(b|c)')
+ )
+
+ :meth:`_sql.ColumnOperators.regexp_match` attempts to resolve to
+ a REGEXP-like function or operator provided by the backend, however
+ the specific regular expression syntax and flags available are
+ **not backend agnostic**.
+
+ Examples include:
+
+ * PostgreSQL - renders ``x ~ y`` or ``x !~ y`` when negated.
+ * Oracle - renders ``REGEXP_LIKE(x, y)``
+ * SQLite - uses SQLite's ``REGEXP`` placeholder operator and calls into
+ the Python ``re.match()`` builtin.
+ * other backends may provide special implementations.
+ * Backends without any special implementation will emit
+ the operator as "REGEXP" or "NOT REGEXP". This is compatible with
+ SQLite and MySQL, for example.
+
+ Regular expression support is currently implemented for Oracle,
+ PostgreSQL, MySQL and MariaDB. Partial support is available for
+ SQLite. Support among third-party dialects may vary.
+
+ :param pattern: The regular expression pattern string or column
+ clause.
+ :param flags: Any regular expression string flags to apply. Flags
+ tend to be backend specific. It can be a string or a column clause.
+ Some backends, like PostgreSQL and MariaDB, may alternatively
+ specify the flags as part of the pattern.
+ When using the ignore case flag 'i' in PostgreSQL, the ignore case
+ regexp match operator ``~*`` or ``!~*`` will be used.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :meth:`_sql.ColumnOperators.regexp_replace`
+
+
+ """
+ return self.operate(regexp_match_op, pattern, flags=flags)
+
+ def regexp_replace(self, pattern, replacement, flags=None):
+ """Implements a database-specific 'regexp replace' operator.
+
+ E.g.::
+
+ stmt = select(
+ table.c.some_column.regexp_replace(
+ 'b(..)',
+ 'X\1Y',
+ flags='g'
+ )
+ )
+
+ :meth:`_sql.ColumnOperators.regexp_replace` attempts to resolve to
+ a REGEXP_REPLACE-like function provided by the backend, which
+ usually emits the function ``REGEXP_REPLACE()``. However,
+ the specific regular expression syntax and flags available are
+ **not backend agnostic**.
+
+ Regular expression replacement support is currently implemented for
+ Oracle, PostgreSQL, MySQL 8 or greater and MariaDB. Support among
+ third-party dialects may vary.
+
+ :param pattern: The regular expression pattern string or column
+ clause.
+ :param replacement: The replacement string or column clause.
+ :param flags: Any regular expression string flags to apply. Flags
+ tend to be backend specific. It can be a string or a column clause.
+ Some backends, like PostgreSQL and MariaDB, may alternatively
+ specify the flags as part of the pattern.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :meth:`_sql.ColumnOperators.regexp_match`
+
+ """
+ return self.operate(
+ regexp_replace_op, pattern, replacement=replacement, flags=flags
+ )
+
+ def desc(self):
+ """Produce a :func:`_expression.desc` clause against the
+ parent object."""
+ return self.operate(desc_op)
+
+ def asc(self):
+ """Produce a :func:`_expression.asc` clause against the
+ parent object."""
+ return self.operate(asc_op)
+
+ def nulls_first(self):
+ """Produce a :func:`_expression.nulls_first` clause against the
+ parent object.
+
+ .. versionchanged:: 1.4 The ``nulls_first()`` operator is
+ renamed from ``nullsfirst()`` in previous releases.
+ The previous name remains available for backwards compatibility.
+ """
+ return self.operate(nulls_first_op)
+
+ # deprecated 1.4; see #5435
+ nullsfirst = nulls_first
+
+ def nulls_last(self):
+ """Produce a :func:`_expression.nulls_last` clause against the
+ parent object.
+
+ .. versionchanged:: 1.4 The ``nulls_last()`` operator is
+ renamed from ``nullslast()`` in previous releases.
+ The previous name remains available for backwards compatibility.
+ """
+ return self.operate(nulls_last_op)
+
+ # deprecated 1.4; see #5429
+ nullslast = nulls_last
+
+ def collate(self, collation):
+ """Produce a :func:`_expression.collate` clause against
+ the parent object, given the collation string.
+
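+ E.g., an illustrative sketch (the collation name here is hypothetical
+ and backend specific)::
+
+ stmt = select(sometable).order_by(
+ sometable.c.column.collate("NOCASE")
+ )
+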
+ .. seealso::
+
+ :func:`_expression.collate`
+
+ """
+ return self.operate(collate, collation)
+
+ def __radd__(self, other):
+ """Implement the ``+`` operator in reverse.
+
+ See :meth:`.ColumnOperators.__add__`.
+
+ """
+ return self.reverse_operate(add, other)
+
+ def __rsub__(self, other):
+ """Implement the ``-`` operator in reverse.
+
+ See :meth:`.ColumnOperators.__sub__`.
+
+ """
+ return self.reverse_operate(sub, other)
+
+ def __rmul__(self, other):
+ """Implement the ``*`` operator in reverse.
+
+ See :meth:`.ColumnOperators.__mul__`.
+
+ """
+ return self.reverse_operate(mul, other)
+
+ def __rdiv__(self, other):
+ """Implement the ``/`` operator in reverse.
+
+ See :meth:`.ColumnOperators.__div__`.
+
+ """
+ return self.reverse_operate(div, other)
+
+ def __rmod__(self, other):
+ """Implement the ``%`` operator in reverse.
+
+ See :meth:`.ColumnOperators.__mod__`.
+
+ """
+ return self.reverse_operate(mod, other)
+
+ def between(self, cleft, cright, symmetric=False):
+ """Produce a :func:`_expression.between` clause against
+ the parent object, given the lower and upper range.
+
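+ E.g., an illustrative sketch::
+
+ stmt = select(sometable).where(
+ sometable.c.column.between(5, 10)
+ )
+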
+ """
+ return self.operate(between_op, cleft, cright, symmetric=symmetric)
+
+ def distinct(self):
+ """Produce a :func:`_expression.distinct` clause against the
+ parent object.
+
+ """
+ return self.operate(distinct_op)
+
+ def any_(self):
+ """Produce an :func:`_expression.any_` clause against the
+ parent object.
+
+ See the documentation for :func:`_sql.any_` for examples.
+
+ .. note:: be sure to not confuse the newer
+ :meth:`_sql.ColumnOperators.any_` method with its older
+ :class:`_types.ARRAY`-specific counterpart, the
+ :meth:`_types.ARRAY.Comparator.any` method, which has a different
+ calling syntax and usage pattern.
+
+ .. versionadded:: 1.1
+
+ """
+ return self.operate(any_op)
+
+ def all_(self):
+ """Produce an :func:`_expression.all_` clause against the
+ parent object.
+
+ See the documentation for :func:`_sql.all_` for examples.
+
+ .. note:: be sure to not confuse the newer
+ :meth:`_sql.ColumnOperators.all_` method with its older
+ :class:`_types.ARRAY`-specific counterpart, the
+ :meth:`_types.ARRAY.Comparator.all` method, which has a different
+ calling syntax and usage pattern.
+
+
+ .. versionadded:: 1.1
+
+ """
+ return self.operate(all_op)
+
+ def __add__(self, other):
+ """Implement the ``+`` operator.
+
+ In a column context, produces the clause ``a + b``
+ if the parent object has non-string affinity.
+ If the parent object has a string affinity,
+ produces the concatenation operator, ``a || b`` -
+ see :meth:`.ColumnOperators.concat`.
+
+ """
+ return self.operate(add, other)
+
+ def __sub__(self, other):
+ """Implement the ``-`` operator.
+
+ In a column context, produces the clause ``a - b``.
+
+ """
+ return self.operate(sub, other)
+
+ def __mul__(self, other):
+ """Implement the ``*`` operator.
+
+ In a column context, produces the clause ``a * b``.
+
+ """
+ return self.operate(mul, other)
+
+ def __div__(self, other):
+ """Implement the ``/`` operator.
+
+ In a column context, produces the clause ``a / b``.
+
+ """
+ return self.operate(div, other)
+
+ def __mod__(self, other):
+ """Implement the ``%`` operator.
+
+ In a column context, produces the clause ``a % b``.
+
+ """
+ return self.operate(mod, other)
+
+ def __truediv__(self, other):
+ """Implement the ``//`` operator.
+
+ In a column context, produces the clause ``a / b``.
+
+ """
+ return self.operate(truediv, other)
+
+ def __rtruediv__(self, other):
+ """Implement the ``//`` operator in reverse.
+
+ See :meth:`.ColumnOperators.__truediv__`.
+
+ """
+ return self.reverse_operate(truediv, other)
+
+
+_commutative = {eq, ne, add, mul}
+_comparison = {eq, ne, lt, gt, ge, le}
+
+
+def commutative_op(fn):
+ _commutative.add(fn)
+ return fn
+
+
+def comparison_op(fn):
+ _comparison.add(fn)
+ return fn
+
+
+def from_():
+ raise NotImplementedError()
+
+
+@comparison_op
+def function_as_comparison_op():
+ raise NotImplementedError()
+
+
+def as_():
+ raise NotImplementedError()
+
+
+def exists():
+ raise NotImplementedError()
+
+
+def is_true(a):
+ raise NotImplementedError()
+
+
+# 1.4 deprecated; see #5435
+istrue = is_true
+
+
+def is_false(a):
+ raise NotImplementedError()
+
+
+# 1.4 deprecated; see #5435
+isfalse = is_false
+
+
+@comparison_op
+def is_distinct_from(a, b):
+ return a.is_distinct_from(b)
+
+
+@comparison_op
+def is_not_distinct_from(a, b):
+ return a.is_not_distinct_from(b)
+
+
+# deprecated 1.4; see #5435
+isnot_distinct_from = is_not_distinct_from
+
+
+@comparison_op
+def is_(a, b):
+ return a.is_(b)
+
+
+@comparison_op
+def is_not(a, b):
+ return a.is_not(b)
+
+
+# 1.4 deprecated; see #5429
+isnot = is_not
+
+
+def collate(a, b):
+ return a.collate(b)
+
+
+def op(a, opstring, b):
+ return a.op(opstring)(b)
+
+
+@comparison_op
+def like_op(a, b, escape=None):
+ return a.like(b, escape=escape)
+
+
+@comparison_op
+def not_like_op(a, b, escape=None):
+ return a.notlike(b, escape=escape)
+
+
+# 1.4 deprecated; see #5435
+notlike_op = not_like_op
+
+
+@comparison_op
+def ilike_op(a, b, escape=None):
+ return a.ilike(b, escape=escape)
+
+
+@comparison_op
+def not_ilike_op(a, b, escape=None):
+ return a.not_ilike(b, escape=escape)
+
+
+# 1.4 deprecated; see #5435
+notilike_op = not_ilike_op
+
+
+@comparison_op
+def between_op(a, b, c, symmetric=False):
+ return a.between(b, c, symmetric=symmetric)
+
+
+@comparison_op
+def not_between_op(a, b, c, symmetric=False):
+ return ~a.between(b, c, symmetric=symmetric)
+
+
+# 1.4 deprecated; see #5435
+notbetween_op = not_between_op
+
+
+@comparison_op
+def in_op(a, b):
+ return a.in_(b)
+
+
+@comparison_op
+def not_in_op(a, b):
+ return a.not_in(b)
+
+
+# 1.4 deprecated; see #5429
+notin_op = not_in_op
+
+
+def distinct_op(a):
+ return a.distinct()
+
+
+def any_op(a):
+ return a.any_()
+
+
+def all_op(a):
+ return a.all_()
+
+
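+# Shared helper for the startswith / endswith / contains operators below.
+# When autoescape is enabled, it escapes "%", "_" and the escape character
+# itself within the (assumed literal string) value, then delegates to the
+# bound LIKE-style method with the chosen ESCAPE character.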
+def _escaped_like_impl(fn, other, escape, autoescape):
+ if autoescape:
+ if autoescape is not True:
+ util.warn(
+ "The autoescape parameter is now a simple boolean True/False"
+ )
+ if escape is None:
+ escape = "/"
+
+ if not isinstance(other, util.compat.string_types):
+ raise TypeError("String value expected when autoescape=True")
+
+ if escape not in ("%", "_"):
+ other = other.replace(escape, escape + escape)
+
+ other = other.replace("%", escape + "%").replace("_", escape + "_")
+
+ return fn(other, escape=escape)
+
+
+@comparison_op
+def startswith_op(a, b, escape=None, autoescape=False):
+ return _escaped_like_impl(a.startswith, b, escape, autoescape)
+
+
+@comparison_op
+def not_startswith_op(a, b, escape=None, autoescape=False):
+ return ~_escaped_like_impl(a.startswith, b, escape, autoescape)
+
+
+# 1.4 deprecated; see #5435
+notstartswith_op = not_startswith_op
+
+
+@comparison_op
+def endswith_op(a, b, escape=None, autoescape=False):
+ return _escaped_like_impl(a.endswith, b, escape, autoescape)
+
+
+@comparison_op
+def not_endswith_op(a, b, escape=None, autoescape=False):
+ return ~_escaped_like_impl(a.endswith, b, escape, autoescape)
+
+
+# 1.4 deprecated; see #5435
+notendswith_op = not_endswith_op
+
+
+@comparison_op
+def contains_op(a, b, escape=None, autoescape=False):
+ return _escaped_like_impl(a.contains, b, escape, autoescape)
+
+
+@comparison_op
+def not_contains_op(a, b, escape=None, autoescape=False):
+ return ~_escaped_like_impl(a.contains, b, escape, autoescape)
+
+
+# 1.4 deprecated; see #5435
+notcontains_op = not_contains_op
+
+
+@comparison_op
+def match_op(a, b, **kw):
+ return a.match(b, **kw)
+
+
+@comparison_op
+def regexp_match_op(a, b, flags=None):
+ return a.regexp_match(b, flags=flags)
+
+
+@comparison_op
+def not_regexp_match_op(a, b, flags=None):
+ return ~a.regexp_match(b, flags=flags)
+
+
+def regexp_replace_op(a, b, replacement, flags=None):
+ return a.regexp_replace(b, replacement=replacement, flags=flags)
+
+
+@comparison_op
+def not_match_op(a, b, **kw):
+ return ~a.match(b, **kw)
+
+
+# 1.4 deprecated; see #5429
+notmatch_op = not_match_op
+
+
+def comma_op(a, b):
+ raise NotImplementedError()
+
+
+def filter_op(a, b):
+ raise NotImplementedError()
+
+
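+# If the left operand has no .concat() method (e.g. a plain Python string on
+# the left side of the expression), fall back to the right operand's reverse
+# concat hook (``_rconcat``) so operand ordering is preserved.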
+def concat_op(a, b):
+ try:
+ concat = a.concat
+ except AttributeError:
+ return b._rconcat(a)
+ else:
+ return concat(b)
+
+
+def desc_op(a):
+ return a.desc()
+
+
+def asc_op(a):
+ return a.asc()
+
+
+def nulls_first_op(a):
+ return a.nulls_first()
+
+
+# 1.4 deprecated; see #5435
+nullsfirst_op = nulls_first_op
+
+
+def nulls_last_op(a):
+ return a.nulls_last()
+
+
+# 1.4 deprecated; see #5435
+nullslast_op = nulls_last_op
+
+
+def json_getitem_op(a, b):
+ raise NotImplementedError()
+
+
+def json_path_getitem_op(a, b):
+ raise NotImplementedError()
+
+
+def is_comparison(op):
+ return op in _comparison or isinstance(op, custom_op) and op.is_comparison
+
+
+def is_commutative(op):
+ return op in _commutative
+
+
+def is_ordering_modifier(op):
+ return op in (asc_op, desc_op, nulls_first_op, nulls_last_op)
+
+
+def is_natural_self_precedent(op):
+ return (
+ op in _natural_self_precedent
+ or isinstance(op, custom_op)
+ and op.natural_self_precedent
+ )
+
+
+_booleans = (inv, is_true, is_false, and_, or_)
+
+
+def is_boolean(op):
+ return is_comparison(op) or op in _booleans
+
+
+_mirror = {gt: lt, ge: le, lt: gt, le: ge}
+
+
+def mirror(op):
+ """rotate a comparison operator 180 degrees.
+
+ Note this is not the same as negation.
+
+ """
+ return _mirror.get(op, op)
+
+
+_associative = _commutative.union([concat_op, and_, or_]).difference([eq, ne])
+
+
+def is_associative(op):
+ return op in _associative
+
+
+_natural_self_precedent = _associative.union(
+ [getitem, json_getitem_op, json_path_getitem_op]
+)
+"""Operators where if we have (a op b) op c, we don't want to
+parenthesize (a op b).
+
+"""
+
+
+_asbool = util.symbol("_asbool", canonical=-10)
+_smallest = util.symbol("_smallest", canonical=-100)
+_largest = util.symbol("_largest", canonical=100)
+
+_PRECEDENCE = {
+ from_: 15,
+ function_as_comparison_op: 15,
+ any_op: 15,
+ all_op: 15,
+ getitem: 15,
+ json_getitem_op: 15,
+ json_path_getitem_op: 15,
+ mul: 8,
+ truediv: 8,
+ div: 8,
+ mod: 8,
+ neg: 8,
+ add: 7,
+ sub: 7,
+ concat_op: 6,
+ filter_op: 6,
+ match_op: 5,
+ not_match_op: 5,
+ regexp_match_op: 5,
+ not_regexp_match_op: 5,
+ regexp_replace_op: 5,
+ ilike_op: 5,
+ not_ilike_op: 5,
+ like_op: 5,
+ not_like_op: 5,
+ in_op: 5,
+ not_in_op: 5,
+ is_: 5,
+ is_not: 5,
+ eq: 5,
+ ne: 5,
+ is_distinct_from: 5,
+ is_not_distinct_from: 5,
+ gt: 5,
+ lt: 5,
+ ge: 5,
+ le: 5,
+ between_op: 5,
+ not_between_op: 5,
+ distinct_op: 5,
+ inv: 5,
+ is_true: 5,
+ is_false: 5,
+ and_: 3,
+ or_: 2,
+ comma_op: -1,
+ desc_op: 3,
+ asc_op: 3,
+ collate: 4,
+ as_: -1,
+ exists: 0,
+ _asbool: -10,
+ _smallest: _smallest,
+ _largest: _largest,
+}
+
+
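+# Return True when ``operator`` binds no more tightly than ``against``,
+# per the precedence table above (or an explicit .precedence attribute);
+# callers use this to decide whether a nested expression needs to be
+# parenthesized.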
+def is_precedent(operator, against):
+ if operator is against and is_natural_self_precedent(operator):
+ return False
+ else:
+ return _PRECEDENCE.get(
+ operator, getattr(operator, "precedence", _smallest)
+ ) <= _PRECEDENCE.get(against, getattr(against, "precedence", _largest))
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
new file mode 100644
index 0000000..9e146f7
--- /dev/null
+++ b/lib/sqlalchemy/sql/roles.py
@@ -0,0 +1,239 @@
+# sql/roles.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .. import util
+
+
+class SQLRole(object):
+ """Define a "role" within a SQL statement structure.
+
+ Classes within SQL Core participate within SQLRole hierarchies in order
+ to more accurately indicate where they may be used within SQL statements
+ of all types.
+
+ .. versionadded:: 1.4
+
+ """
+
+ allows_lambda = False
+ uses_inspection = False
+
+
+class UsesInspection(object):
+ _post_inspect = None
+ uses_inspection = True
+
+
+class AllowsLambdaRole(object):
+ allows_lambda = True
+
+
+class HasCacheKeyRole(SQLRole):
+ _role_name = "Cacheable Core or ORM object"
+
+
+class ExecutableOptionRole(SQLRole):
+ __slots__ = ()
+ _role_name = "ExecutionOption Core or ORM object"
+
+
+class LiteralValueRole(SQLRole):
+ _role_name = "Literal Python value"
+
+
+class ColumnArgumentRole(SQLRole):
+ _role_name = "Column expression"
+
+
+class ColumnArgumentOrKeyRole(ColumnArgumentRole):
+ _role_name = "Column expression or string key"
+
+
+class StrAsPlainColumnRole(ColumnArgumentRole):
+ _role_name = "Column expression or string key"
+
+
+class ColumnListRole(SQLRole):
+ """Elements suitable for forming comma separated lists of expressions."""
+
+
+class TruncatedLabelRole(SQLRole):
+ _role_name = "String SQL identifier"
+
+
+class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
+ _role_name = "Column expression or FROM clause"
+
+ @property
+ def _select_iterable(self):
+ raise NotImplementedError()
+
+
+class LimitOffsetRole(SQLRole):
+ _role_name = "LIMIT / OFFSET expression"
+
+
+class ByOfRole(ColumnListRole):
+ _role_name = "GROUP BY / OF / etc. expression"
+
+
+class GroupByRole(AllowsLambdaRole, UsesInspection, ByOfRole):
+ # note there's a special case right now where you can pass a whole
+ # ORM entity to group_by() and it splits out. we may not want to keep
+ # this around
+
+ _role_name = "GROUP BY expression"
+
+
+class OrderByRole(AllowsLambdaRole, ByOfRole):
+ _role_name = "ORDER BY expression"
+
+
+class StructuralRole(SQLRole):
+ pass
+
+
+class StatementOptionRole(StructuralRole):
+ _role_name = "statement sub-expression element"
+
+
+class OnClauseRole(AllowsLambdaRole, StructuralRole):
+ _role_name = "SQL expression for ON clause"
+
+
+class WhereHavingRole(OnClauseRole):
+ _role_name = "SQL expression for WHERE/HAVING role"
+
+
+class ExpressionElementRole(SQLRole):
+ _role_name = "SQL expression element"
+
+
+class ConstExprRole(ExpressionElementRole):
+ _role_name = "Constant True/False/None expression"
+
+
+class LabeledColumnExprRole(ExpressionElementRole):
+ pass
+
+
+class BinaryElementRole(ExpressionElementRole):
+ _role_name = "SQL expression element or literal value"
+
+
+class InElementRole(SQLRole):
+ _role_name = (
+ "IN expression list, SELECT construct, or bound parameter object"
+ )
+
+
+class JoinTargetRole(AllowsLambdaRole, UsesInspection, StructuralRole):
+ _role_name = (
+ "Join target, typically a FROM expression, or ORM "
+ "relationship attribute"
+ )
+
+
+class FromClauseRole(ColumnsClauseRole, JoinTargetRole):
+ _role_name = "FROM expression, such as a Table or alias() object"
+
+ _is_subquery = False
+
+ @property
+ def _hide_froms(self):
+ raise NotImplementedError()
+
+
+class StrictFromClauseRole(FromClauseRole):
+ # does not allow text() or select() objects
+
+ @property
+ def description(self):
+ raise NotImplementedError()
+
+
+class AnonymizedFromClauseRole(StrictFromClauseRole):
+ # calls .alias() as a post processor
+
+ def _anonymous_fromclause(self, name=None, flat=False):
+ raise NotImplementedError()
+
+
+class ReturnsRowsRole(SQLRole):
+ _role_name = (
+ "Row returning expression such as a SELECT, a FROM clause, or an "
+ "INSERT/UPDATE/DELETE with RETURNING"
+ )
+
+
+class StatementRole(SQLRole):
+ _role_name = "Executable SQL or text() construct"
+
+ _propagate_attrs = util.immutabledict()
+
+
+class SelectStatementRole(StatementRole, ReturnsRowsRole):
+ _role_name = "SELECT construct or equivalent text() construct"
+
+ def subquery(self):
+ raise NotImplementedError(
+ "All SelectStatementRole objects should implement a "
+ ".subquery() method."
+ )
+
+
+class HasCTERole(ReturnsRowsRole):
+ pass
+
+
+class IsCTERole(SQLRole):
+ _role_name = "CTE object"
+
+
+class CompoundElementRole(AllowsLambdaRole, SQLRole):
+ """SELECT statements inside a CompoundSelect, e.g. UNION, EXTRACT, etc."""
+
+ _role_name = (
+ "SELECT construct for inclusion in a UNION or other set construct"
+ )
+
+
+# TODO: are we using this?
+class DMLRole(StatementRole):
+ pass
+
+
+class DMLTableRole(FromClauseRole):
+ _role_name = "subject table for an INSERT, UPDATE or DELETE"
+
+
+class DMLColumnRole(SQLRole):
+ _role_name = "SET/VALUES column expression or string key"
+
+
+class DMLSelectRole(SQLRole):
+ """A SELECT statement embedded in DML, typically INSERT from SELECT"""
+
+ _role_name = "SELECT statement or equivalent textual object"
+
+
+class DDLRole(StatementRole):
+ pass
+
+
+class DDLExpressionRole(StructuralRole):
+ _role_name = "SQL expression element for DDL constraint"
+
+
+class DDLConstraintColumnRole(SQLRole):
+ _role_name = "String column name or column expression for DDL constraint"
+
+
+class DDLReferredColumnRole(DDLConstraintColumnRole):
+ _role_name = (
+ "String column name or Column object for DDL foreign key constraint"
+ )
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
new file mode 100644
index 0000000..dde665c
--- /dev/null
+++ b/lib/sqlalchemy/sql/schema.py
@@ -0,0 +1,5268 @@
+# sql/schema.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""The schema module provides the building blocks for database metadata.
+
+Each element within this module describes a database entity which can be
+created and dropped, or is otherwise part of such an entity. Examples include
+tables, columns, sequences, and indexes.
+
+All entities are subclasses of :class:`~sqlalchemy.schema.SchemaItem`, and as
+defined in this module they are intended to be agnostic of any vendor-specific
+constructs.
+
+A collection of entities are grouped into a unit called
+:class:`~sqlalchemy.schema.MetaData`. MetaData serves as a logical grouping of
+schema elements, and can also be associated with an actual database connection
+such that operations involving the contained elements can contact the database
+as needed.
+
+Two of the elements here also build upon their "syntactic" counterparts, which
+ are defined in :mod:`~sqlalchemy.sql.expression`, specifically
+:class:`~sqlalchemy.schema.Table` and :class:`~sqlalchemy.schema.Column`.
+Since these objects are part of the SQL expression language, they are usable
+as components in SQL expressions.
+
+"""
+from __future__ import absolute_import
+
+import collections
+
+import sqlalchemy
+from . import coercions
+from . import ddl
+from . import roles
+from . import type_api
+from . import visitors
+from .base import _bind_or_error
+from .base import DedupeColumnCollection
+from .base import DialectKWArgs
+from .base import Executable
+from .base import SchemaEventTarget
+from .coercions import _document_text_coercion
+from .elements import ClauseElement
+from .elements import ColumnClause
+from .elements import ColumnElement
+from .elements import quoted_name
+from .elements import TextClause
+from .selectable import TableClause
+from .type_api import to_instance
+from .visitors import InternalTraversal
+from .. import event
+from .. import exc
+from .. import inspection
+from .. import util
+
+
+RETAIN_SCHEMA = util.symbol(
+ "retain_schema"
+ """Symbol indicating that a :class:`_schema.Table`, :class:`.Sequence`
+ or in some cases a :class:`_schema.ForeignKey` object, in situations
+ where the object is being copied for a :meth:`.Table.to_metadata`
+ operation, should retain the schema name that it already has.
+
+ """
+)
+
+BLANK_SCHEMA = util.symbol(
+ "blank_schema",
+ """Symbol indicating that a :class:`_schema.Table`, :class:`.Sequence`
+ or in some cases a :class:`_schema.ForeignKey` object
+ should have 'None' for its schema, even if the parent
+ :class:`_schema.MetaData` has specified a schema.
+
+ .. versionadded:: 1.0.14
+
+ """,
+)
+
+NULL_UNSPECIFIED = util.symbol(
+ "NULL_UNSPECIFIED",
+ """Symbol indicating the "nullable" keyword was not passed to a Column.
+
+ Normally we would expect None to be acceptable for this but some backends
+ such as that of SQL Server place special significance on a "nullability"
+ value of None.
+
+ """,
+)
+
+
+def _get_table_key(name, schema):
+ if schema is None:
+ return name
+ else:
+ return schema + "." + name
+
+
+# this should really be in sql/util.py but we'd have to
+# break an import cycle
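+# Return a copy of ``expression`` with any Column of ``source_table``
+# replaced by the same-keyed column of ``target_table``.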
+def _copy_expression(expression, source_table, target_table):
+ if source_table is None or target_table is None:
+ return expression
+
+ def replace(col):
+ if (
+ isinstance(col, Column)
+ and col.table is source_table
+ and col.key in source_table.c
+ ):
+ return target_table.c[col.key]
+ else:
+ return None
+
+ return visitors.replacement_traverse(expression, {}, replace)
+
+
+@inspection._self_inspects
+class SchemaItem(SchemaEventTarget, visitors.Visitable):
+ """Base class for items that define a database schema."""
+
+ __visit_name__ = "schema_item"
+
+ create_drop_stringify_dialect = "default"
+
+ def _init_items(self, *args, **kw):
+ """Initialize the list of child items for this SchemaItem."""
+ for item in args:
+ if item is not None:
+ try:
+ spwd = item._set_parent_with_dispatch
+ except AttributeError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "'SchemaItem' object, such as a 'Column' or a "
+ "'Constraint' expected, got %r" % item
+ ),
+ replace_context=err,
+ )
+ else:
+ spwd(self, **kw)
+
+ def __repr__(self):
+ return util.generic_repr(self, omit_kwarg=["info"])
+
+ @util.memoized_property
+ def info(self):
+ """Info dictionary associated with the object, allowing user-defined
+ data to be associated with this :class:`.SchemaItem`.
+
+ The dictionary is automatically generated when first accessed.
+ It can also be specified in the constructor of some objects,
+ such as :class:`_schema.Table` and :class:`_schema.Column`.
+
+ """
+ return {}
+
+ def _schema_item_copy(self, schema_item):
+ if "info" in self.__dict__:
+ schema_item.info = self.info.copy()
+ schema_item.dispatch._update(self.dispatch)
+ return schema_item
+
+ _use_schema_map = True
+
+
+class Table(DialectKWArgs, SchemaItem, TableClause):
+ r"""Represent a table in a database.
+
+ e.g.::
+
+ mytable = Table(
+ "mytable", metadata,
+ Column('mytable_id', Integer, primary_key=True),
+ Column('value', String(50))
+ )
+
+ The :class:`_schema.Table`
+ object constructs a unique instance of itself based
+ on its name and optional schema name within the given
+ :class:`_schema.MetaData` object. Calling the :class:`_schema.Table`
+ constructor with the same name and same :class:`_schema.MetaData` argument
+ a second time will return the *same* :class:`_schema.Table`
+ object - in this way
+ the :class:`_schema.Table` constructor acts as a registry function.
+
+ .. seealso::
+
+ :ref:`metadata_describing` - Introduction to database metadata
+
+ Constructor arguments are as follows:
+
+ :param name: The name of this table as represented in the database.
+
+ The table name, along with the value of the ``schema`` parameter,
+ forms a key which uniquely identifies this :class:`_schema.Table`
+ within
+ the owning :class:`_schema.MetaData` collection.
+ Additional calls to :class:`_schema.Table` with the same name,
+ metadata,
+ and schema name will return the same :class:`_schema.Table` object.
+
+ Names which contain no upper case characters
+ will be treated as case insensitive names, and will not be quoted
+ unless they are a reserved word or contain special characters.
+ A name with any number of upper case characters is considered
+ to be case sensitive, and will be sent as quoted.
+
+ To enable unconditional quoting for the table name, specify the flag
+ ``quote=True`` to the constructor, or use the :class:`.quoted_name`
+ construct to specify the name.
+
+ :param metadata: a :class:`_schema.MetaData`
+ object which will contain this
+ table. The metadata is used as a point of association of this table
+ with other tables which are referenced via foreign key. It also
+ may be used to associate this table with a particular
+ :class:`.Connectable`.
+
+ :param \*args: Additional positional arguments are used primarily
+ to add the list of :class:`_schema.Column`
+ objects contained within this
+ table. Similar to the style of a CREATE TABLE statement, other
+ :class:`.SchemaItem` constructs may be added here, including
+ :class:`.PrimaryKeyConstraint`, and
+ :class:`_schema.ForeignKeyConstraint`.
+
+ :param autoload: Defaults to ``False``, unless
+ :paramref:`_schema.Table.autoload_with`
+ is set, in which case it defaults to ``True``; when ``True``,
+ :class:`_schema.Column` objects
+ for this table are reflected from the database, possibly
+ augmenting objects that were explicitly specified.
+ :class:`_schema.Column` and other objects explicitly set on the
+ table will replace corresponding reflected objects.
+
+ .. deprecated:: 1.4
+
+ The autoload parameter is deprecated and will be removed in
+ version 2.0. Please use the
+ :paramref:`_schema.Table.autoload_with` parameter, passing an
+ engine or connection.
+
+ .. seealso::
+
+ :ref:`metadata_reflection_toplevel`
+
+ :param autoload_replace: Defaults to ``True``; when using
+ :paramref:`_schema.Table.autoload`
+ in conjunction with :paramref:`_schema.Table.extend_existing`,
+ indicates
+ that :class:`_schema.Column` objects present in the already-existing
+ :class:`_schema.Table`
+ object should be replaced with columns of the same
+ name retrieved from the autoload process. When ``False``, columns
+ already present under existing names will be omitted from the
+ reflection process.
+
+ Note that this setting does not impact :class:`_schema.Column` objects
+ specified programmatically within the call to :class:`_schema.Table`
+ that
+ also is autoloading; those :class:`_schema.Column` objects will always
+ replace existing columns of the same name when
+ :paramref:`_schema.Table.extend_existing` is ``True``.
+
+ .. seealso::
+
+ :paramref:`_schema.Table.autoload`
+
+ :paramref:`_schema.Table.extend_existing`
+
+ :param autoload_with: An :class:`_engine.Engine` or
+ :class:`_engine.Connection` object,
+ or a :class:`_reflection.Inspector` object as returned by
+ :func:`_sa.inspect`
+ against one, with which this :class:`_schema.Table`
+ object will be reflected.
+ When set to a non-None value, the autoload process will take place
+ for this table against the given engine or connection.
+
+ :param extend_existing: When ``True``, indicates that if this
+ :class:`_schema.Table` is already present in the given
+ :class:`_schema.MetaData`,
+ apply further arguments within the constructor to the existing
+ :class:`_schema.Table`.
+
+ If :paramref:`_schema.Table.extend_existing` or
+ :paramref:`_schema.Table.keep_existing` are not set,
+ and the given name
+ of the new :class:`_schema.Table` refers to a :class:`_schema.Table`
+ that is
+ already present in the target :class:`_schema.MetaData` collection,
+ and
+ this :class:`_schema.Table`
+ specifies additional columns or other constructs
+ or flags that modify the table's state, an
+ error is raised. The purpose of these two mutually-exclusive flags
+ is to specify what action should be taken when a
+ :class:`_schema.Table`
+ is specified that matches an existing :class:`_schema.Table`,
+ yet specifies
+ additional constructs.
+
+ :paramref:`_schema.Table.extend_existing`
+ will also work in conjunction
+ with :paramref:`_schema.Table.autoload` to run a new reflection
+ operation against the database, even if a :class:`_schema.Table`
+ of the same name is already present in the target
+ :class:`_schema.MetaData`; newly reflected :class:`_schema.Column`
+ objects
+ and other options will be added into the state of the
+ :class:`_schema.Table`, potentially overwriting existing columns
+ and options of the same name.
+
+ As is always the case with :paramref:`_schema.Table.autoload`,
+ :class:`_schema.Column` objects can be specified in the same
+ :class:`_schema.Table`
+ constructor, which will take precedence. Below, the existing
+ table ``mytable`` will be augmented with :class:`_schema.Column`
+ objects
+ both reflected from the database, as well as the given
+ :class:`_schema.Column`
+ named "y"::
+
+ Table("mytable", metadata,
+ Column('y', Integer),
+ extend_existing=True,
+ autoload_with=engine
+ )
+
+ .. seealso::
+
+ :paramref:`_schema.Table.autoload`
+
+ :paramref:`_schema.Table.autoload_replace`
+
+ :paramref:`_schema.Table.keep_existing`
+
+
+ :param implicit_returning: True by default - indicates that
+ RETURNING can be used by default to fetch newly inserted primary key
+ values, for backends which support this. Note that
+ :func:`_sa.create_engine` also provides an ``implicit_returning``
+ flag.
+
+ :param include_columns: A list of strings indicating a subset of
+ columns to be loaded via the ``autoload`` operation; table columns that
+ aren't present in this list will not be represented on the resulting
+ ``Table`` object. Defaults to ``None`` which indicates all columns
+ should be reflected.
+
+ :param resolve_fks: Whether or not to reflect :class:`_schema.Table`
+ objects
+ related to this one via :class:`_schema.ForeignKey` objects, when
+ :paramref:`_schema.Table.autoload` or
+ :paramref:`_schema.Table.autoload_with` is
+ specified. Defaults to True. Set to False to disable reflection of
+ related tables as :class:`_schema.ForeignKey`
+ objects are encountered; may be
+ used either to save on SQL calls or to avoid issues with related tables
+ that can't be accessed. Note that if a related table is already present
+ in the :class:`_schema.MetaData` collection, or becomes present later,
+ a
+ :class:`_schema.ForeignKey` object associated with this
+ :class:`_schema.Table` will
+ resolve to that table normally.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :paramref:`.MetaData.reflect.resolve_fks`
+
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.SchemaItem.info` attribute of this object.
+
+ :param keep_existing: When ``True``, indicates that if this Table
+ is already present in the given :class:`_schema.MetaData`, ignore
+ further arguments within the constructor to the existing
+ :class:`_schema.Table`, and return the :class:`_schema.Table`
+ object as
+ originally created. This is to allow a function that wishes
+ to define a new :class:`_schema.Table` on first call, but on
+ subsequent calls will return the same :class:`_schema.Table`,
+ without any of the declarations (particularly constraints)
+ being applied a second time.
+
+ If :paramref:`_schema.Table.extend_existing` or
+ :paramref:`_schema.Table.keep_existing` are not set,
+ and the given name
+ of the new :class:`_schema.Table` refers to a :class:`_schema.Table`
+ that is
+ already present in the target :class:`_schema.MetaData` collection,
+ and
+ this :class:`_schema.Table`
+ specifies additional columns or other constructs
+ or flags that modify the table's state, an
+ error is raised. The purpose of these two mutually-exclusive flags
+ is to specify what action should be taken when a
+ :class:`_schema.Table`
+ is specified that matches an existing :class:`_schema.Table`,
+ yet specifies
+ additional constructs.
+
+ .. seealso::
+
+ :paramref:`_schema.Table.extend_existing`
+
+ :param listeners: A list of tuples of the form ``(<eventname>, <fn>)``
+ which will be passed to :func:`.event.listen` upon construction.
+ This alternate hook to :func:`.event.listen` allows the establishment
+ of a listener function specific to this :class:`_schema.Table` before
+ the "autoload" process begins. Historically this has been intended
+ for use with the :meth:`.DDLEvents.column_reflect` event, however
+ note that this event hook may now be associated with the
+ :class:`_schema.MetaData` object directly::
+
+ def listen_for_reflect(table, column_info):
+ "handle the column reflection event"
+ # ...
+
+ t = Table(
+ 'sometable',
+ autoload_with=engine,
+ listeners=[
+ ('column_reflect', listen_for_reflect)
+ ])
+
+ .. seealso::
+
+ :meth:`_events.DDLEvents.column_reflect`
+
+ :param must_exist: When ``True``, indicates that this Table must already
+ be present in the given :class:`_schema.MetaData` collection, else
+ an exception is raised.
+
+ :param prefixes:
+ A list of strings to insert after CREATE in the CREATE TABLE
+ statement. They will be separated by spaces.
+
+ :param quote: Force quoting of this table's name on or off, corresponding
+ to ``True`` or ``False``. When left at its default of ``None``,
+ the column identifier will be quoted according to whether the name is
+ case sensitive (identifiers with at least one upper case character are
+ treated as case sensitive), or if it's a reserved word. This flag
+ is only needed to force quoting of a reserved word which is not known
+ by the SQLAlchemy dialect.
+
+ .. note:: setting this flag to ``False`` will not provide
+ case-insensitive behavior for table reflection; table reflection
+ will always search for a mixed-case name in a case sensitive
+ fashion. Case insensitive names are specified in SQLAlchemy only
+ by stating the name with all lower case characters.
+
+ :param quote_schema: same as 'quote' but applies to the schema identifier.
+
+ :param schema: The schema name for this table, which is required if
+ the table resides in a schema other than the default selected schema
+ for the engine's database connection. Defaults to ``None``.
+
+ If the owning :class:`_schema.MetaData` of this :class:`_schema.Table`
+ specifies its
+ own :paramref:`_schema.MetaData.schema` parameter,
+ then that schema name will
+ be applied to this :class:`_schema.Table`
+ if the schema parameter here is set
+ to ``None``. To set a blank schema name on a :class:`_schema.Table`
+ that
+ would otherwise use the schema set on the owning
+ :class:`_schema.MetaData`,
+ specify the special symbol :attr:`.BLANK_SCHEMA`.
+
+ .. versionadded:: 1.0.14 Added the :attr:`.BLANK_SCHEMA` symbol to
+ allow a :class:`_schema.Table`
+ to have a blank schema name even when the
+ parent :class:`_schema.MetaData` specifies
+ :paramref:`_schema.MetaData.schema`.
+
+ The quoting rules for the schema name are the same as those for the
+ ``name`` parameter, in that quoting is applied for reserved words or
+ case-sensitive names; to enable unconditional quoting for the schema
+ name, specify the flag ``quote_schema=True`` to the constructor, or use
+ the :class:`.quoted_name` construct to specify the name.
+
+ :param comment: Optional string that will render an SQL comment on table
+ creation.
+
+ .. versionadded:: 1.2 Added the :paramref:`_schema.Table.comment`
+ parameter
+ to :class:`_schema.Table`.
+
+ :param \**kw: Additional keyword arguments not mentioned above are
+ dialect specific, and passed in the form ``<dialectname>_<argname>``.
+ See the documentation regarding an individual dialect at
+ :ref:`dialect_toplevel` for detail on documented arguments.
+
+ """
+
+ __visit_name__ = "table"
+
+ constraints = None
+ """A collection of all :class:`_schema.Constraint` objects associated with
+ this :class:`_schema.Table`.
+
+ Includes :class:`_schema.PrimaryKeyConstraint`,
+ :class:`_schema.ForeignKeyConstraint`, :class:`_schema.UniqueConstraint`,
+ :class:`_schema.CheckConstraint`. A separate collection
+ :attr:`_schema.Table.foreign_key_constraints` refers to the collection
+ of all :class:`_schema.ForeignKeyConstraint` objects, and the
+ :attr:`_schema.Table.primary_key` attribute refers to the single
+ :class:`_schema.PrimaryKeyConstraint` associated with the
+ :class:`_schema.Table`.
+
+ .. seealso::
+
+ :attr:`_schema.Table.constraints`
+
+ :attr:`_schema.Table.primary_key`
+
+ :attr:`_schema.Table.foreign_key_constraints`
+
+ :attr:`_schema.Table.indexes`
+
+ :class:`_reflection.Inspector`
+
+
+ """
+
+ indexes = None
+ """A collection of all :class:`_schema.Index` objects associated with this
+ :class:`_schema.Table`.
+
+ .. seealso::
+
+ :meth:`_reflection.Inspector.get_indexes`
+
+ """
+
+ _traverse_internals = TableClause._traverse_internals + [
+ ("schema", InternalTraversal.dp_string)
+ ]
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ if self._annotations:
+ return (self,) + self._annotations_cache_key
+ else:
+ return (self,)
+
+ @util.deprecated_params(
+ mustexist=(
+ "1.4",
+ "Deprecated alias of :paramref:`_schema.Table.must_exist`",
+ ),
+ autoload=(
+ "2.0",
+ "The autoload parameter is deprecated and will be removed in "
+ "version 2.0. Please use the "
+ "autoload_with parameter, passing an engine or connection.",
+ ),
+ )
+ def __new__(cls, *args, **kw):
+ if not args and not kw:
+ # python3k pickle seems to call this
+ return object.__new__(cls)
+
+ try:
+ name, metadata, args = args[0], args[1], args[2:]
+ except IndexError:
+ raise TypeError(
+ "Table() takes at least two positional-only "
+ "arguments 'name' and 'metadata'"
+ )
+
+ schema = kw.get("schema", None)
+ if schema is None:
+ schema = metadata.schema
+ elif schema is BLANK_SCHEMA:
+ schema = None
+ keep_existing = kw.get("keep_existing", False)
+ extend_existing = kw.get("extend_existing", False)
+
+ if keep_existing and extend_existing:
+ msg = "keep_existing and extend_existing are mutually exclusive."
+ raise exc.ArgumentError(msg)
+
+ must_exist = kw.pop("must_exist", kw.pop("mustexist", False))
+ key = _get_table_key(name, schema)
+ if key in metadata.tables:
+ if not keep_existing and not extend_existing and bool(args):
+ raise exc.InvalidRequestError(
+ "Table '%s' is already defined for this MetaData "
+ "instance. Specify 'extend_existing=True' "
+ "to redefine "
+ "options and columns on an "
+ "existing Table object." % key
+ )
+ table = metadata.tables[key]
+ if extend_existing:
+ table._init_existing(*args, **kw)
+ return table
+ else:
+ if must_exist:
+ raise exc.InvalidRequestError("Table '%s' not defined" % (key))
+ table = object.__new__(cls)
+ table.dispatch.before_parent_attach(table, metadata)
+ metadata._add_table(name, schema, table)
+ try:
+ table._init(name, metadata, *args, **kw)
+ table.dispatch.after_parent_attach(table, metadata)
+ return table
+ except Exception:
+ with util.safe_reraise():
+ metadata._remove_table(name, schema)
+
+ def __init__(self, *args, **kw):
+ """Constructor for :class:`_schema.Table`.
+
+ This method is a no-op. See the top-level
+ documentation for :class:`_schema.Table`
+ for constructor arguments.
+
+ """
+ # __init__ is overridden to prevent __new__ from
+ # calling the superclass constructor.
+
+ def _init(self, name, metadata, *args, **kwargs):
+ super(Table, self).__init__(
+ quoted_name(name, kwargs.pop("quote", None))
+ )
+ self.metadata = metadata
+
+ self.schema = kwargs.pop("schema", None)
+ if self.schema is None:
+ self.schema = metadata.schema
+ elif self.schema is BLANK_SCHEMA:
+ self.schema = None
+ else:
+ quote_schema = kwargs.pop("quote_schema", None)
+ self.schema = quoted_name(self.schema, quote_schema)
+
+ self.indexes = set()
+ self.constraints = set()
+ PrimaryKeyConstraint(
+ _implicit_generated=True
+ )._set_parent_with_dispatch(self)
+ self.foreign_keys = set()
+ self._extra_dependencies = set()
+ if self.schema is not None:
+ self.fullname = "%s.%s" % (self.schema, self.name)
+ else:
+ self.fullname = self.name
+
+ autoload_with = kwargs.pop("autoload_with", None)
+ autoload = kwargs.pop("autoload", autoload_with is not None)
+ # this argument is only used with _init_existing()
+ kwargs.pop("autoload_replace", True)
+ keep_existing = kwargs.pop("keep_existing", False)
+ extend_existing = kwargs.pop("extend_existing", False)
+ _extend_on = kwargs.pop("_extend_on", None)
+
+ resolve_fks = kwargs.pop("resolve_fks", True)
+ include_columns = kwargs.pop("include_columns", None)
+
+ self.implicit_returning = kwargs.pop("implicit_returning", True)
+
+ self.comment = kwargs.pop("comment", None)
+
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
+ if "listeners" in kwargs:
+ listeners = kwargs.pop("listeners")
+ for evt, fn in listeners:
+ event.listen(self, evt, fn)
+
+ self._prefixes = kwargs.pop("prefixes", None) or []
+
+ self._extra_kwargs(**kwargs)
+
+ # load column definitions from the database if 'autoload' is defined
+ # we do it after the table is in the singleton dictionary to support
+ # circular foreign keys
+ if autoload:
+ self._autoload(
+ metadata,
+ autoload_with,
+ include_columns,
+ _extend_on=_extend_on,
+ resolve_fks=resolve_fks,
+ )
+
+ # initialize all the column, etc. objects. done after reflection to
+ # allow user-overrides
+
+ self._init_items(
+ *args,
+ allow_replacements=extend_existing or keep_existing or autoload
+ )
+
+ def _autoload(
+ self,
+ metadata,
+ autoload_with,
+ include_columns,
+ exclude_columns=(),
+ resolve_fks=True,
+ _extend_on=None,
+ ):
+ if autoload_with is None:
+ autoload_with = _bind_or_error(
+ metadata,
+ msg="No engine is bound to this Table's MetaData. "
+ "Pass an engine to the Table via "
+ "autoload_with=<someengine_or_connection>",
+ )
+
+ insp = inspection.inspect(autoload_with)
+ with insp._inspection_context() as conn_insp:
+ conn_insp.reflect_table(
+ self,
+ include_columns,
+ exclude_columns,
+ resolve_fks,
+ _extend_on=_extend_on,
+ )
+
+ @property
+ def _sorted_constraints(self):
+ """Return the set of constraints as a list, sorted by creation
+ order.
+
+ """
+ return sorted(self.constraints, key=lambda c: c._creation_order)
+
+ @property
+ def foreign_key_constraints(self):
+ """:class:`_schema.ForeignKeyConstraint` objects referred to by this
+ :class:`_schema.Table`.
+
+ This collection is produced from the
+ :class:`_schema.ForeignKey`
+ objects currently associated with this :class:`_schema.Table`.
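+
+ E.g., a minimal sketch collecting the tables referred to by an
+ illustrative ``my_table``::
+
+ referred = {fkc.referred_table for fkc in my_table.foreign_key_constraints}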
+
+
+ .. seealso::
+
+ :attr:`_schema.Table.constraints`
+
+ :attr:`_schema.Table.foreign_keys`
+
+ :attr:`_schema.Table.indexes`
+
+ """
+ return set(fkc.constraint for fkc in self.foreign_keys)
+
+ def _init_existing(self, *args, **kwargs):
+ autoload_with = kwargs.pop("autoload_with", None)
+ autoload = kwargs.pop("autoload", autoload_with is not None)
+ autoload_replace = kwargs.pop("autoload_replace", True)
+ schema = kwargs.pop("schema", None)
+ _extend_on = kwargs.pop("_extend_on", None)
+ # these arguments are only used with _init()
+ kwargs.pop("extend_existing", False)
+ kwargs.pop("keep_existing", False)
+
+ if schema and schema != self.schema:
+ raise exc.ArgumentError(
+ "Can't change schema of existing table from '%s' to '%s'"
+ % (self.schema, schema)
+ )
+
+ include_columns = kwargs.pop("include_columns", None)
+ if include_columns is not None:
+ for c in self.c:
+ if c.name not in include_columns:
+ self._columns.remove(c)
+
+ resolve_fks = kwargs.pop("resolve_fks", True)
+
+ for key in ("quote", "quote_schema"):
+ if key in kwargs:
+ raise exc.ArgumentError(
+ "Can't redefine 'quote' or 'quote_schema' arguments"
+ )
+
+ # update `self` with these kwargs, if provided
+ self.comment = kwargs.pop("comment", self.comment)
+ self.implicit_returning = kwargs.pop(
+ "implicit_returning", self.implicit_returning
+ )
+ self.info = kwargs.pop("info", self.info)
+
+ if autoload:
+ if not autoload_replace:
+ # don't replace columns already present.
+ # we'd like to do this for constraints also however we don't
+ # have simple de-duping for unnamed constraints.
+ exclude_columns = [c.name for c in self.c]
+ else:
+ exclude_columns = ()
+ self._autoload(
+ self.metadata,
+ autoload_with,
+ include_columns,
+ exclude_columns,
+ resolve_fks,
+ _extend_on=_extend_on,
+ )
+
+ self._extra_kwargs(**kwargs)
+ self._init_items(*args)
+
+ def _extra_kwargs(self, **kwargs):
+ self._validate_dialect_kwargs(kwargs)
+
+ def _init_collections(self):
+ pass
+
+ def _reset_exported(self):
+ pass
+
+ @property
+ def _autoincrement_column(self):
+ return self.primary_key._autoincrement_column
+
+ @property
+ def key(self):
+ """Return the 'key' for this :class:`_schema.Table`.
+
+ This value is used as the dictionary key within the
+ :attr:`_schema.MetaData.tables` collection. It is typically the same
+ as that of :attr:`_schema.Table.name` for a table with no
+ :attr:`_schema.Table.schema`
+ set; otherwise it is typically of the form
+ ``schemaname.tablename``.
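+
+ E.g., for an illustrative ``Table("user", metadata, schema="remote")``
+ the key is the string ``"remote.user"``; with no schema set, it is
+ simply ``"user"``.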
+
+ """
+ return _get_table_key(self.name, self.schema)
+
+ def __repr__(self):
+ return "Table(%s)" % ", ".join(
+ [repr(self.name)]
+ + [repr(self.metadata)]
+ + [repr(x) for x in self.columns]
+ + ["%s=%s" % (k, repr(getattr(self, k))) for k in ["schema"]]
+ )
+
+ def __str__(self):
+ return _get_table_key(self.description, self.schema)
+
+ @property
+ def bind(self):
+ """Return the connectable associated with this Table."""
+
+ return self.metadata and self.metadata.bind or None
+
+ def add_is_dependent_on(self, table):
+ """Add a 'dependency' for this Table.
+
+ This is another Table object which must be created
+ first before this one can, or dropped after this one.
+
+ Usually, dependencies between tables are determined via
+ ForeignKey objects. However, for other situations that
+ create dependencies outside of foreign keys (rules, inheriting),
+ this method can manually establish such a link.
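+
+ E.g., a minimal sketch using purely illustrative table names, where
+ ``log_table`` must be created after ``main_table`` and dropped before
+ it::
+
+ log_table.add_is_dependent_on(main_table)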
+
+ """
+ self._extra_dependencies.add(table)
+
+ def append_column(self, column, replace_existing=False):
+ """Append a :class:`_schema.Column` to this :class:`_schema.Table`.
+
+ The "key" of the newly added :class:`_schema.Column`, i.e. the
+ value of its ``.key`` attribute, will then be available
+ in the ``.c`` collection of this :class:`_schema.Table`, and the
+ column definition will be included in any CREATE TABLE, SELECT,
+ UPDATE, etc. statements generated from this :class:`_schema.Table`
+ construct.
+
+ Note that this does **not** change the definition of the table
+ as it exists within any underlying database, assuming that
+ table has already been created in the database. Relational
+ databases support the addition of columns to existing tables
+ using the SQL ALTER command, which would need to be
+ emitted for an already-existing table that doesn't contain
+ the newly added column.
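+
+ E.g., a minimal sketch with illustrative names, assuming the usual
+ ``sqlalchemy`` imports::
+
+ my_table.append_column(Column("data", String(50)))
+
+ # the new Column is now present in the .c collection
+ my_table.c.data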
+
+ :param replace_existing: When ``True``, allows replacing existing
+ columns. When ``False``, the default, a deprecation warning will be
+ raised if a column with the same ``.key`` already exists. A future
+ version of sqlalchemy will instead raise an error.
+
+ .. versionadded:: 1.4.0
+ """
+
+ column._set_parent_with_dispatch(
+ self, allow_replacements=replace_existing
+ )
+
+ def append_constraint(self, constraint):
+ """Append a :class:`_schema.Constraint` to this
+ :class:`_schema.Table`.
+
+ This has the effect of the constraint being included in any
+ future CREATE TABLE statement, assuming specific DDL creation
+ events have not been associated with the given
+ :class:`_schema.Constraint` object.
+
+ Note that this does **not** produce the constraint within the
+ relational database automatically, for a table that already exists
+ in the database. To add a constraint to an
+ existing relational database table, the SQL ALTER command must
+ be used. SQLAlchemy also provides the
+ :class:`.AddConstraint` construct which can produce this SQL when
+ invoked as an executable clause.
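+
+ E.g., a minimal sketch with illustrative column names::
+
+ my_table.append_constraint(UniqueConstraint("col_a", "col_b"))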
+
+ """
+
+ constraint._set_parent_with_dispatch(self)
+
+ def _set_parent(self, metadata, **kw):
+ metadata._add_table(self.name, self.schema, self)
+ self.metadata = metadata
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.Table.exists` method is deprecated and will be "
+ "removed in a future release. Please refer to "
+ ":meth:`_reflection.Inspector.has_table`.",
+ )
+ def exists(self, bind=None):
+ """Return True if this table exists."""
+
+ if bind is None:
+ bind = _bind_or_error(self)
+
+ insp = inspection.inspect(bind)
+ return insp.has_table(self.name, schema=self.schema)
+
+ def create(self, bind=None, checkfirst=False):
+ """Issue a ``CREATE`` statement for this
+ :class:`_schema.Table`, using the given :class:`.Connectable`
+ for connectivity.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
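+
+ E.g., a minimal sketch, where ``some_engine`` is an assumed
+ :class:`_engine.Engine`::
+
+ my_table.create(some_engine, checkfirst=True)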
+
+ .. seealso::
+
+ :meth:`_schema.MetaData.create_all`.
+
+ """
+
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst)
+
+ def drop(self, bind=None, checkfirst=False):
+ """Issue a ``DROP`` statement for this
+ :class:`_schema.Table`, using the given :class:`.Connectable`
+ for connectivity.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ .. seealso::
+
+ :meth:`_schema.MetaData.drop_all`.
+
+ """
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst)
+
+ @util.deprecated(
+ "1.4",
+ ":meth:`_schema.Table.tometadata` is renamed to "
+ ":meth:`_schema.Table.to_metadata`",
+ )
+ def tometadata(
+ self,
+ metadata,
+ schema=RETAIN_SCHEMA,
+ referred_schema_fn=None,
+ name=None,
+ ):
+ """Return a copy of this :class:`_schema.Table`
+ associated with a different
+ :class:`_schema.MetaData`.
+
+ See :meth:`_schema.Table.to_metadata` for a full description.
+
+ """
+ return self.to_metadata(
+ metadata,
+ schema=schema,
+ referred_schema_fn=referred_schema_fn,
+ name=name,
+ )
+
+ def to_metadata(
+ self,
+ metadata,
+ schema=RETAIN_SCHEMA,
+ referred_schema_fn=None,
+ name=None,
+ ):
+ """Return a copy of this :class:`_schema.Table` associated with a
+ different :class:`_schema.MetaData`.
+
+ E.g.::
+
+ m1 = MetaData()
+
+ user = Table('user', m1, Column('id', Integer, primary_key=True))
+
+ m2 = MetaData()
+ user_copy = user.to_metadata(m2)
+
+ .. versionchanged:: 1.4 The :meth:`_schema.Table.to_metadata` function
+ was renamed from :meth:`_schema.Table.tometadata`.
+
+
+ :param metadata: Target :class:`_schema.MetaData` object,
+ into which the
+ new :class:`_schema.Table` object will be created.
+
+ :param schema: optional string name indicating the target schema.
+ Defaults to the special symbol :attr:`.RETAIN_SCHEMA` which indicates
+ that no change to the schema name should be made in the new
+ :class:`_schema.Table`. If set to a string name, the new
+ :class:`_schema.Table`
+ will have this new name as the ``.schema``. If set to ``None``, the
+ schema will be set to that of the schema set on the target
+ :class:`_schema.MetaData`, which is typically ``None`` as well,
+ unless
+ set explicitly::
+
+ m2 = MetaData(schema='newschema')
+
+ # user_copy_one will have "newschema" as the schema name
+ user_copy_one = user.to_metadata(m2, schema=None)
+
+ m3 = MetaData() # schema defaults to None
+
+ # user_copy_two will have None as the schema name
+ user_copy_two = user.to_metadata(m3, schema=None)
+
+ :param referred_schema_fn: optional callable which can be supplied
+ in order to provide for the schema name that should be assigned
+ to the referenced table of a :class:`_schema.ForeignKeyConstraint`.
+ The callable accepts this parent :class:`_schema.Table`, the
+ target schema that we are changing to, the
+ :class:`_schema.ForeignKeyConstraint` object, and the existing
+ "target schema" of that constraint. The function should return the
+ string schema name that should be applied. To reset the schema
+ to "none", return the symbol :data:`.BLANK_SCHEMA`. To effect no
+ change, return ``None`` or :data:`.RETAIN_SCHEMA`.
+
+ .. versionchanged:: 1.4.33 The ``referred_schema_fn`` function
+ may return the :data:`.BLANK_SCHEMA` or :data:`.RETAIN_SCHEMA`
+ symbols.
+
+ E.g.::
+
+ def referred_schema_fn(table, to_schema,
+ constraint, referred_schema):
+ if referred_schema == 'base_tables':
+ return referred_schema
+ else:
+ return to_schema
+
+ new_table = table.to_metadata(m2, schema="alt_schema",
+ referred_schema_fn=referred_schema_fn)
+
+ .. versionadded:: 0.9.2
+
+ :param name: optional string name indicating the target table name.
+ If not specified or None, the table name is retained. This allows
+ a :class:`_schema.Table` to be copied to the same
+ :class:`_schema.MetaData` target
+ with a new name.
+
+ .. versionadded:: 1.0.0
+
+ """
+ if name is None:
+ name = self.name
+ if schema is RETAIN_SCHEMA:
+ schema = self.schema
+ elif schema is None:
+ schema = metadata.schema
+ key = _get_table_key(name, schema)
+ if key in metadata.tables:
+ util.warn(
+ "Table '%s' already exists within the given "
+ "MetaData - not copying." % self.description
+ )
+ return metadata.tables[key]
+
+ args = []
+ for c in self.columns:
+ args.append(c._copy(schema=schema))
+ table = Table(
+ name,
+ metadata,
+ schema=schema,
+ comment=self.comment,
+ *args,
+ **self.kwargs
+ )
+ for c in self.constraints:
+ if isinstance(c, ForeignKeyConstraint):
+ referred_schema = c._referred_schema
+ if referred_schema_fn:
+ fk_constraint_schema = referred_schema_fn(
+ self, schema, c, referred_schema
+ )
+ else:
+ fk_constraint_schema = (
+ schema if referred_schema == self.schema else None
+ )
+ table.append_constraint(
+ c._copy(schema=fk_constraint_schema, target_table=table)
+ )
+ elif not c._type_bound:
+ # skip unique constraints that would be generated
+ # by the 'unique' flag on Column
+ if c._column_flag:
+ continue
+
+ table.append_constraint(
+ c._copy(schema=schema, target_table=table)
+ )
+ for index in self.indexes:
+ # skip indexes that would be generated
+ # by the 'index' flag on Column
+ if index._column_flag:
+ continue
+ Index(
+ index.name,
+ unique=index.unique,
+ *[
+ _copy_expression(expr, self, table)
+ for expr in index.expressions
+ ],
+ _table=table,
+ **index.kwargs
+ )
+ return self._schema_item_copy(table)
+
+
+class Column(DialectKWArgs, SchemaItem, ColumnClause):
+ """Represents a column in a database table."""
+
+ __visit_name__ = "column"
+
+ inherit_cache = True
+
+ def __init__(self, *args, **kwargs):
+ r"""
+ Construct a new ``Column`` object.
+
+ :param name: The name of this column as represented in the database.
+ This argument may be the first positional argument, or specified
+ via keyword.
+
+ Names which contain no upper case characters
+ will be treated as case insensitive names, and will not be quoted
+ unless they are a reserved word. Names with any number of upper
+ case characters will be quoted and sent exactly. Note that this
+ behavior applies even for databases which standardize upper
+ case names as case insensitive such as Oracle.
+
+ The name field may be omitted at construction time and applied
+ later, at any time before the Column is associated with a
+ :class:`_schema.Table`. This is to support convenient
+ usage within the :mod:`~sqlalchemy.ext.declarative` extension.
+
+ :param type\_: The column's type, indicated using an instance which
+ subclasses :class:`~sqlalchemy.types.TypeEngine`. If no arguments
+ are required for the type, the class of the type can be sent
+ as well, e.g.::
+
+ # use a type with arguments
+ Column('data', String(50))
+
+ # use no arguments
+ Column('level', Integer)
+
+ The ``type`` argument may be the second positional argument
+ or specified by keyword.
+
+ If the ``type`` is ``None`` or is omitted, it will first default to
+ the special type :class:`.NullType`. If and when this
+ :class:`_schema.Column` is made to refer to another column using
+ :class:`_schema.ForeignKey` and/or
+ :class:`_schema.ForeignKeyConstraint`, the type
+ of the remote-referenced column will be copied to this column as
+ well, at the moment that the foreign key is resolved against that
+ remote :class:`_schema.Column` object.
+
+ .. versionchanged:: 0.9.0
+ Support for propagation of type to a :class:`_schema.Column`
+ from its
+ :class:`_schema.ForeignKey` object has been improved and should be
+ more reliable and timely.
+
+ :param \*args: Additional positional arguments include various
+ :class:`.SchemaItem` derived constructs which will be applied
+ as options to the column. These include instances of
+ :class:`.Constraint`, :class:`_schema.ForeignKey`,
+ :class:`.ColumnDefault`, :class:`.Sequence`, :class:`.Computed`
+ and :class:`.Identity`. In some cases an
+ equivalent keyword argument is available such as ``server_default``,
+ ``default`` and ``unique``.
+
+ :param autoincrement: Set up "auto increment" semantics for an
+ **integer primary key column with no foreign key dependencies**
+ (see later in this docstring for a more specific definition).
+ This may influence the :term:`DDL` that will be emitted for
+ this column during a table create, as well as how the column
+ will be considered when INSERT statements are compiled and
+ executed.
+
+ The default value is the string ``"auto"``,
+ which indicates that a single-column (i.e. non-composite) primary key
+ that is of an INTEGER type with no other client-side or server-side
+ default constructs indicated should receive auto increment semantics
+ automatically. Other values include ``True`` (force this column to
+ have auto-increment semantics for a :term:`composite primary key` as
+ well), ``False`` (this column should never have auto-increment
+ semantics), and the string ``"ignore_fk"`` (special-case for foreign
+ key columns, see below).
+
+ The term "auto increment semantics" refers both to the kind of DDL
+ that will be emitted for the column within a CREATE TABLE statement,
+ when methods such as :meth:`.MetaData.create_all` and
+ :meth:`.Table.create` are invoked, as well as how the column will be
+ considered when an INSERT statement is compiled and emitted to the
+ database:
+
+ * **DDL rendering** (i.e. :meth:`.MetaData.create_all`,
+ :meth:`.Table.create`): When used on a :class:`.Column` that has
+ no other
+ default-generating construct associated with it (such as a
+ :class:`.Sequence` or :class:`.Identity` construct), the parameter
+ will imply that database-specific keywords such as PostgreSQL
+ ``SERIAL``, MySQL ``AUTO_INCREMENT``, or ``IDENTITY`` on SQL Server
+ should also be rendered. Not every database backend has an
+ "implied" default generator available; for example the Oracle
+ backend always needs an explicit construct such as
+ :class:`.Identity` to be included with a :class:`.Column` in order
+ for the DDL rendered to include auto-generating constructs to also
+ be produced in the database.
+
+ * **INSERT semantics** (i.e. when a :func:`_sql.insert` construct is
+ compiled into a SQL string and is then executed on a database using
+ :meth:`_engine.Connection.execute` or equivalent): A single-row
+ INSERT statement will be known to produce a new integer primary key
+ value automatically for this column, which will be accessible
+ after the statement is invoked via the
+ :attr:`.CursorResult.inserted_primary_key` attribute upon the
+ :class:`_result.Result` object. This also applies towards use of the
+ ORM when ORM-mapped objects are persisted to the database,
+ indicating that a new integer primary key will be available to
+ become part of the :term:`identity key` for that object. This
+ behavior takes place regardless of what DDL constructs are
+ associated with the :class:`_schema.Column` and is independent
+ of the "DDL Rendering" behavior discussed in the previous note
+ above.
+
+ The parameter may be set to ``True`` to indicate that a column which
+ is part of a composite (i.e. multi-column) primary key should
+ have autoincrement semantics, though note that only one column
+ within a primary key may have this setting. It can also
+ be set to ``True`` to indicate autoincrement semantics on a
+ column that has a client-side or server-side default configured,
+ however note that not all dialects can accommodate all styles
+ of default as an "autoincrement". It can also be
+ set to ``False`` on a single-column primary key that has a
+ datatype of INTEGER in order to disable auto increment semantics
+ for that column.
+
+ .. versionchanged:: 1.1 The autoincrement flag now defaults to
+ ``"auto"`` which indicates autoincrement semantics by default
+ for single-column integer primary keys only; for composite
+ (multi-column) primary keys, autoincrement is never implicitly
+ enabled; as always, ``autoincrement=True`` will allow for
+ at most one of those columns to be an "autoincrement" column.
+ ``autoincrement=True`` may also be set on a
+ :class:`_schema.Column`
+ that has an explicit client-side or server-side default,
+ subject to limitations of the backend database and dialect.
+
+ The setting *only* has an effect for columns which are:
+
+ * Integer derived (i.e. INT, SMALLINT, BIGINT).
+
+ * Part of the primary key
+
+ * Not referring to another column via :class:`_schema.ForeignKey`,
+ unless
+ the value is specified as ``'ignore_fk'``::
+
+ # turn on autoincrement for this column despite
+ # the ForeignKey()
+ Column('id', ForeignKey('other.id'),
+ primary_key=True, autoincrement='ignore_fk')
+
+ It is typically not desirable to have "autoincrement" enabled on a
+ column that refers to another via foreign key, as such a column is
+ required to refer to a value that originates from elsewhere.
+
+ The setting has these effects on columns that meet the
+ above criteria:
+
+ * DDL issued for the column, if the column does not already include
+ a default generating construct supported by the backend such as
+ :class:`.Identity`, will include database-specific
+ keywords intended to signify this column as an
+ "autoincrement" column for specific backends. Behavior for
+ primary SQLAlchemy dialects includes:
+
+ * AUTO_INCREMENT on MySQL and MariaDB
+ * SERIAL on PostgreSQL
+ * IDENTITY on MS-SQL - this occurs even without the
+ :class:`.Identity` construct as the
+ :paramref:`.Column.autoincrement` parameter pre-dates this
+ construct.
+ * SQLite - SQLite integer primary key columns are implicitly
+ "auto incrementing" and no additional keywords are rendered;
+ the special SQLite keyword ``AUTOINCREMENT``
+ is not rendered, as it is unnecessary and not recommended
+ by the database vendor. See the section
+ :ref:`sqlite_autoincrement` for more background.
+ * Oracle - The Oracle dialect has no default "autoincrement"
+ feature available at this time, instead the :class:`.Identity`
+ construct is recommended to achieve this (the :class:`.Sequence`
+ construct may also be used).
+ * Third-party dialects - consult those dialects' documentation
+ for details on their specific behaviors.
+
+ * When a single-row :func:`_sql.insert` construct is compiled and
+ executed, which does not set the :meth:`_sql.Insert.inline`
+ modifier, newly generated primary key values for this column
+ will be automatically retrieved upon statement execution
+ using a method specific to the database driver in use:
+
+ * MySQL, SQLite - calling upon ``cursor.lastrowid``
+ (see
+ `https://www.python.org/dev/peps/pep-0249/#lastrowid
+ <https://www.python.org/dev/peps/pep-0249/#lastrowid>`_)
+ * PostgreSQL, SQL Server, Oracle - use RETURNING or an equivalent
+ construct when rendering an INSERT statement, and then retrieving
+ the newly generated primary key values after execution
+ * PostgreSQL, Oracle for :class:`_schema.Table` objects that
+ set :paramref:`_schema.Table.implicit_returning` to False -
+ for a :class:`.Sequence` only, the :class:`.Sequence` is invoked
+ explicitly before the INSERT statement takes place so that the
+ newly generated primary key value is available to the client
+ * SQL Server for :class:`_schema.Table` objects that
+ set :paramref:`_schema.Table.implicit_returning` to False -
+ the ``SELECT scope_identity()`` construct is used after the
+ INSERT statement is invoked to retrieve the newly generated
+ primary key value.
+ * Third-party dialects - consult those dialects' documentation
+ for details on their specific behaviors.
+
+ * For multiple-row :func:`_sql.insert` constructs invoked with
+ a list of parameters (i.e. "executemany" semantics), primary-key
+ retrieving behaviors are generally disabled, however there may
+ be special APIs that may be used to retrieve lists of new
+ primary key values for an "executemany", such as the psycopg2
+ "fast insertmany" feature. Such features are very new and
+ may not yet be well covered in documentation.
+
+ :param default: A scalar, Python callable, or
+ :class:`_expression.ColumnElement` expression representing the
+ *default value* for this column, which will be invoked upon insert
+ if this column is otherwise not specified in the VALUES clause of
+ the insert. This is a shortcut to using :class:`.ColumnDefault` as
+ a positional argument; see that class for full detail on the
+ structure of the argument.
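+
+ E.g., a minimal sketch using a Python callable invoked at INSERT
+ time (assuming ``import datetime``)::
+
+ Column("created_at", DateTime, default=datetime.datetime.utcnow)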
+
+ Contrast this argument to
+ :paramref:`_schema.Column.server_default`
+ which creates a default generator on the database side.
+
+ .. seealso::
+
+ :ref:`metadata_defaults_toplevel`
+
+ :param doc: optional String that can be used by the ORM or similar
+ to document attributes on the Python side. This attribute does
+ **not** render SQL comments; use the
+ :paramref:`_schema.Column.comment`
+ parameter for this purpose.
+
+ :param key: An optional string identifier which will identify this
+ ``Column`` object on the :class:`_schema.Table`.
+ When a key is provided,
+ this is the only identifier referencing the ``Column`` within the
+ application, including ORM attribute mapping; the ``name`` field
+ is used only when rendering SQL.
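+
+ E.g., a minimal sketch with illustrative names::
+
+ # rendered SQL uses "user_name"; Python access uses .c.username
+ Column("user_name", String(50), key="username")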
+
+ :param index: When ``True``, indicates that a :class:`_schema.Index`
+ construct will be automatically generated for this
+ :class:`_schema.Column`, which will result in a "CREATE INDEX"
+ statement being emitted for the :class:`_schema.Table` when the DDL
+ create operation is invoked.
+
+ Using this flag is equivalent to making use of the
+ :class:`_schema.Index` construct explicitly at the level of the
+ :class:`_schema.Table` construct itself::
+
+ Table(
+ "some_table",
+ metadata,
+ Column("x", Integer),
+ Index("ix_some_table_x", "x")
+ )
+
+ To add the :paramref:`_schema.Index.unique` flag to the
+ :class:`_schema.Index`, set both the
+ :paramref:`_schema.Column.unique` and
+ :paramref:`_schema.Column.index` flags to True simultaneously,
+ which will have the effect of rendering the "CREATE UNIQUE INDEX"
+ DDL instruction instead of "CREATE INDEX".
+
+ The name of the index is generated using the
+ :ref:`default naming convention <constraint_default_naming_convention>`
+ which for the :class:`_schema.Index` construct is of the form
+ ``ix_<tablename>_<columnname>``.
+
+ As this flag is intended only as a convenience for the common case
+ of adding a single-column, default configured index to a table
+ definition, explicit use of the :class:`_schema.Index` construct
+ should be preferred for most use cases, including composite indexes
+ that encompass more than one column, indexes with SQL expressions
+ or ordering, backend-specific index configuration options, and
+ indexes that use a specific name.
+
+ .. note:: the :attr:`_schema.Column.index` attribute on
+ :class:`_schema.Column`
+ **does not indicate** if this column is indexed or not, only
+ if this flag was explicitly set here. To view indexes on
+ a column, view the :attr:`_schema.Table.indexes` collection
+ or use :meth:`_reflection.Inspector.get_indexes`.
+
+ .. seealso::
+
+ :ref:`schema_indexes`
+
+ :ref:`constraint_naming_conventions`
+
+ :paramref:`_schema.Column.unique`
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.SchemaItem.info` attribute of this object.
+
+ :param nullable: When set to ``False``, will cause the "NOT NULL"
+ phrase to be added when generating DDL for the column. When
+ ``True``, will normally generate nothing (in SQL this defaults to
+ "NULL"), except in some very specific backend-specific edge cases
+ where "NULL" may render explicitly.
+ Defaults to ``True`` unless :paramref:`_schema.Column.primary_key`
+ is also ``True`` or the column specifies a :class:`_sql.Identity`,
+ in which case it defaults to ``False``.
+ This parameter is only used when issuing CREATE TABLE statements.
+
+ .. note::
+
+ When the column specifies a :class:`_sql.Identity` this
+ parameter is in general ignored by the DDL compiler. The
+ PostgreSQL database allows a nullable identity column by
+ setting this parameter to ``True`` explicitly.
+
+ :param onupdate: A scalar, Python callable, or
+ :class:`~sqlalchemy.sql.expression.ClauseElement` representing a
+ default value to be applied to the column within UPDATE
+ statements, which will be invoked upon update if this column is not
+ present in the SET clause of the update. This is a shortcut to
+ using :class:`.ColumnDefault` as a positional argument with
+ ``for_update=True``.
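+
+ E.g., a minimal sketch using a Python callable invoked at UPDATE
+ time (assuming ``import datetime``)::
+
+ Column("updated_at", DateTime, onupdate=datetime.datetime.utcnow)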
+
+ .. seealso::
+
+ :ref:`metadata_defaults` - complete discussion of onupdate
+
+ :param primary_key: If ``True``, marks this column as a primary key
+ column. Multiple columns can have this flag set to specify
+ composite primary keys. As an alternative, the primary key of a
+ :class:`_schema.Table` can be specified via an explicit
+ :class:`.PrimaryKeyConstraint` object.
+
+ :param server_default: A :class:`.FetchedValue` instance, str, Unicode
+ or :func:`~sqlalchemy.sql.expression.text` construct representing
+ the DDL DEFAULT value for the column.
+
+ String types will be emitted as-is, surrounded by single quotes::
+
+ Column('x', Text, server_default="val")
+
+ x TEXT DEFAULT 'val'
+
+ A :func:`~sqlalchemy.sql.expression.text` expression will be
+ rendered as-is, without quotes::
+
+ Column('y', DateTime, server_default=text('NOW()'))
+
+ y DATETIME DEFAULT NOW()
+
+ Strings and text() will be converted into a
+ :class:`.DefaultClause` object upon initialization.
+
+ This parameter can also accept complex combinations of contextually
+ valid SQLAlchemy expressions or constructs::
+
+ from sqlalchemy import create_engine
+ from sqlalchemy import Table, Column, MetaData, ARRAY, Text
+ from sqlalchemy.dialects.postgresql import array
+
+ engine = create_engine(
+ 'postgresql://scott:tiger@localhost/mydatabase'
+ )
+ metadata_obj = MetaData()
+ tbl = Table(
+ "foo",
+ metadata_obj,
+ Column("bar",
+ ARRAY(Text),
+ server_default=array(["biz", "bang", "bash"])
+ )
+ )
+ metadata_obj.create_all(engine)
+
+ The above results in a table created with the following SQL::
+
+ CREATE TABLE foo (
+ bar TEXT[] DEFAULT ARRAY['biz', 'bang', 'bash']
+ )
+
+ Use :class:`.FetchedValue` to indicate that an already-existing
+ column will generate a default value on the database side which
+ will be available to SQLAlchemy for post-fetch after inserts. This
+ construct does not specify any DDL and the implementation is left
+ to the database, such as via a trigger.
+
+ .. seealso::
+
+ :ref:`server_defaults` - complete discussion of server side
+ defaults
+
+ :param server_onupdate: A :class:`.FetchedValue` instance
+ representing a database-side default generation function,
+ such as a trigger. This
+ indicates to SQLAlchemy that a newly generated value will be
+ available after updates. This construct does not actually
+ implement any kind of generation function within the database,
+ which instead must be specified separately.
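+
+ E.g., a minimal sketch; the trigger or other server-side mechanism
+ that actually produces the value must be created separately::
+
+ Column("last_modified", DateTime, server_onupdate=FetchedValue())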
+
+
+ .. warning:: This directive **does not** currently produce MySQL's
+ "ON UPDATE CURRENT_TIMESTAMP()" clause. See
+ :ref:`mysql_timestamp_onupdate` for background on how to
+ produce this clause.
+
+ .. seealso::
+
+ :ref:`triggered_columns`
+
+ :param quote: Force quoting of this column's name on or off,
+ corresponding to ``True`` or ``False``. When left at its default
+ of ``None``, the column identifier will be quoted according to
+ whether the name is case sensitive (identifiers with at least one
+ upper case character are treated as case sensitive), or if it's a
+ reserved word. This flag is only needed to force quoting of a
+ reserved word which is not known by the SQLAlchemy dialect.
+
+ :param unique: When ``True``, and the :paramref:`_schema.Column.index`
+ parameter is left at its default value of ``False``,
+ indicates that a :class:`_schema.UniqueConstraint`
+ construct will be automatically generated for this
+ :class:`_schema.Column`,
+ which will result in a "UNIQUE CONSTRAINT" clause referring
+ to this column being included
+ in the ``CREATE TABLE`` statement emitted, when the DDL create
+ operation for the :class:`_schema.Table` object is invoked.
+
+ When this flag is ``True`` while the
+ :paramref:`_schema.Column.index` parameter is simultaneously
+ set to ``True``, the effect instead is that a
+ :class:`_schema.Index` construct which includes the
+ :paramref:`_schema.Index.unique` parameter set to ``True``
+ is generated. See the documentation for
+ :paramref:`_schema.Column.index` for additional detail.
+
+ Using this flag is equivalent to making use of the
+ :class:`_schema.UniqueConstraint` construct explicitly at the
+ level of the :class:`_schema.Table` construct itself::
+
+ Table(
+ "some_table",
+ metadata,
+ Column("x", Integer),
+ UniqueConstraint("x")
+ )
+
+ The :paramref:`_schema.UniqueConstraint.name` parameter
+ of the unique constraint object is left at its default value
+ of ``None``; in the absence of a :ref:`naming convention <constraint_naming_conventions>`
+ for the enclosing :class:`_schema.MetaData`, the UNIQUE CONSTRAINT
+ construct will be emitted as unnamed, which typically invokes
+ a database-specific naming convention to take place.
+
+ As this flag is intended only as a convenience for the common case
+ of adding a single-column, default configured unique constraint to a table
+ definition, explicit use of the :class:`_schema.UniqueConstraint` construct
+ should be preferred for most use cases, including composite constraints
+ that encompass more than one column, backend-specific index configuration options, and
+ constraints that use a specific name.
+
+ .. note:: the :attr:`_schema.Column.unique` attribute on
+ :class:`_schema.Column`
+ **does not indicate** if this column has a unique constraint or
+ not, only if this flag was explicitly set here. To view
+ indexes and unique constraints that may involve this column,
+ view the
+ :attr:`_schema.Table.indexes` and/or
+ :attr:`_schema.Table.constraints` collections or use
+ :meth:`_reflection.Inspector.get_indexes` and/or
+ :meth:`_reflection.Inspector.get_unique_constraints`
+
+ .. seealso::
+
+ :ref:`schema_unique_constraint`
+
+ :ref:`constraint_naming_conventions`
+
+ :paramref:`_schema.Column.index`
+
+ :param system: When ``True``, indicates this is a "system" column,
+ that is a column which is automatically made available by the
+ database, and should not be included in the columns list for a
+ ``CREATE TABLE`` statement.
+
+ For more elaborate scenarios where columns should be
+ conditionally rendered differently on different backends,
+ consider custom compilation rules for :class:`.CreateColumn`.
+
+ :param comment: Optional string that will render an SQL comment on
+ table creation.
+
+ .. versionadded:: 1.2 Added the
+ :paramref:`_schema.Column.comment`
+ parameter to :class:`_schema.Column`.
+
+
+ """ # noqa: E501, RST201, RST202
+
+ name = kwargs.pop("name", None)
+ type_ = kwargs.pop("type_", None)
+ args = list(args)
+ if args:
+ if isinstance(args[0], util.string_types):
+ if name is not None:
+ raise exc.ArgumentError(
+ "May not pass name positionally and as a keyword."
+ )
+ name = args.pop(0)
+ if args:
+ coltype = args[0]
+
+ if hasattr(coltype, "_sqla_type"):
+ if type_ is not None:
+ raise exc.ArgumentError(
+ "May not pass type_ positionally and as a keyword."
+ )
+ type_ = args.pop(0)
+
+ if name is not None:
+ name = quoted_name(name, kwargs.pop("quote", None))
+ elif "quote" in kwargs:
+ raise exc.ArgumentError(
+ "Explicit 'name' is required when " "sending 'quote' argument"
+ )
+
+ super(Column, self).__init__(name, type_)
+ self.key = kwargs.pop("key", name)
+ self.primary_key = primary_key = kwargs.pop("primary_key", False)
+
+ self._user_defined_nullable = udn = kwargs.pop(
+ "nullable", NULL_UNSPECIFIED
+ )
+
+ if udn is not NULL_UNSPECIFIED:
+ self.nullable = udn
+ else:
+ self.nullable = not primary_key
+
+ self.default = kwargs.pop("default", None)
+ self.server_default = kwargs.pop("server_default", None)
+ self.server_onupdate = kwargs.pop("server_onupdate", None)
+
+ # these default to None because .index and .unique is *not*
+ # an informational flag about Column - there can still be an
+ # Index or UniqueConstraint referring to this Column.
+ self.index = kwargs.pop("index", None)
+ self.unique = kwargs.pop("unique", None)
+
+ self.system = kwargs.pop("system", False)
+ self.doc = kwargs.pop("doc", None)
+ self.onupdate = kwargs.pop("onupdate", None)
+ self.autoincrement = kwargs.pop("autoincrement", "auto")
+ self.constraints = set()
+ self.foreign_keys = set()
+ self.comment = kwargs.pop("comment", None)
+ self.computed = None
+ self.identity = None
+
+ # check if this Column is proxying another column
+ if "_proxies" in kwargs:
+ self._proxies = kwargs.pop("_proxies")
+ # otherwise, add DDL-related events
+ elif isinstance(self.type, SchemaEventTarget):
+ self.type._set_parent_with_dispatch(self)
+
+ if self.default is not None:
+ if isinstance(self.default, (ColumnDefault, Sequence)):
+ args.append(self.default)
+ else:
+ if getattr(self.type, "_warn_on_bytestring", False):
+ if isinstance(self.default, util.binary_type):
+ util.warn(
+ "Unicode column '%s' has non-unicode "
+ "default value %r specified."
+ % (self.key, self.default)
+ )
+ args.append(ColumnDefault(self.default))
+
+ if self.server_default is not None:
+ if isinstance(self.server_default, FetchedValue):
+ args.append(self.server_default._as_for_update(False))
+ else:
+ args.append(DefaultClause(self.server_default))
+
+ if self.onupdate is not None:
+ if isinstance(self.onupdate, (ColumnDefault, Sequence)):
+ args.append(self.onupdate)
+ else:
+ args.append(ColumnDefault(self.onupdate, for_update=True))
+
+ if self.server_onupdate is not None:
+ if isinstance(self.server_onupdate, FetchedValue):
+ args.append(self.server_onupdate._as_for_update(True))
+ else:
+ args.append(
+ DefaultClause(self.server_onupdate, for_update=True)
+ )
+ self._init_items(*args)
+
+ util.set_creation_order(self)
+
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
+
+ self._extra_kwargs(**kwargs)
+
+ foreign_keys = None
+ """A collection of all :class:`_schema.ForeignKey` marker objects
+ associated with this :class:`_schema.Column`.
+
+ Each object is a member of a :class:`_schema.Table`-wide
+ :class:`_schema.ForeignKeyConstraint`.
+
+ .. seealso::
+
+ :attr:`_schema.Table.foreign_keys`
+
+ """
+
+ index = None
+ """The value of the :paramref:`_schema.Column.index` parameter.
+
+ Does not indicate if this :class:`_schema.Column` is actually indexed
+ or not; use :attr:`_schema.Table.indexes`.
+
+ .. seealso::
+
+ :attr:`_schema.Table.indexes`
+ """
+
+ unique = None
+ """The value of the :paramref:`_schema.Column.unique` parameter.
+
+ Does not indicate if this :class:`_schema.Column` is actually subject to
+ a unique constraint or not; use :attr:`_schema.Table.indexes` and
+ :attr:`_schema.Table.constraints`.
+
+ .. seealso::
+
+ :attr:`_schema.Table.indexes`
+
+ :attr:`_schema.Table.constraints`.
+
+ """
+
+ def _extra_kwargs(self, **kwargs):
+ self._validate_dialect_kwargs(kwargs)
+
+ def __str__(self):
+ if self.name is None:
+ return "(no name)"
+ elif self.table is not None:
+ if self.table.named_with_column:
+ return self.table.description + "." + self.description
+ else:
+ return self.description
+ else:
+ return self.description
+
+ def references(self, column):
+ """Return True if this Column references the given column via foreign
+ key."""
+
+ for fk in self.foreign_keys:
+ if fk.column.proxy_set.intersection(column.proxy_set):
+ return True
+ else:
+ return False
+
+ def append_foreign_key(self, fk):
+ fk._set_parent_with_dispatch(self)
+
+ def __repr__(self):
+ kwarg = []
+ if self.key != self.name:
+ kwarg.append("key")
+ if self.primary_key:
+ kwarg.append("primary_key")
+ if not self.nullable:
+ kwarg.append("nullable")
+ if self.onupdate:
+ kwarg.append("onupdate")
+ if self.default:
+ kwarg.append("default")
+ if self.server_default:
+ kwarg.append("server_default")
+ if self.comment:
+ kwarg.append("comment")
+ return "Column(%s)" % ", ".join(
+ [repr(self.name)]
+ + [repr(self.type)]
+ + [repr(x) for x in self.foreign_keys if x is not None]
+ + [repr(x) for x in self.constraints]
+ + [
+ (
+ self.table is not None
+ and "table=<%s>" % self.table.description
+ or "table=None"
+ )
+ ]
+ + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]
+ )
+
+ def _set_parent(self, table, allow_replacements=True):
+ if not self.name:
+ raise exc.ArgumentError(
+ "Column must be constructed with a non-blank name or "
+ "assign a non-blank .name before adding to a Table."
+ )
+
+ self._reset_memoizations()
+
+ if self.key is None:
+ self.key = self.name
+
+ existing = getattr(self, "table", None)
+ if existing is not None and existing is not table:
+ raise exc.ArgumentError(
+ "Column object '%s' already assigned to Table '%s'"
+ % (self.key, existing.description)
+ )
+
+ if self.key in table._columns:
+ col = table._columns.get(self.key)
+ if col is not self:
+ if not allow_replacements:
+ util.warn_deprecated(
+ "A column with name '%s' is already present "
+ "in table '%s'. Please use method "
+ ":meth:`_schema.Table.append_column` with the "
+ "parameter ``replace_existing=True`` to replace an "
+ "existing column." % (self.key, table.name),
+ "1.4",
+ )
+ for fk in col.foreign_keys:
+ table.foreign_keys.remove(fk)
+ if fk.constraint in table.constraints:
+ # this might have been removed
+ # already, if it's a composite constraint
+ # and more than one col being replaced
+ table.constraints.remove(fk.constraint)
+
+ table._columns.replace(self)
+
+ self.table = table
+
+ if self.primary_key:
+ table.primary_key._replace(self)
+ elif self.key in table.primary_key:
+ raise exc.ArgumentError(
+ "Trying to redefine primary-key column '%s' as a "
+ "non-primary-key column on table '%s'"
+ % (self.key, table.fullname)
+ )
+
+ if self.index:
+ if isinstance(self.index, util.string_types):
+ raise exc.ArgumentError(
+ "The 'index' keyword argument on Column is boolean only. "
+ "To create indexes with a specific name, create an "
+ "explicit Index object external to the Table."
+ )
+ table.append_constraint(
+ Index(
+ None, self.key, unique=bool(self.unique), _column_flag=True
+ )
+ )
+
+ elif self.unique:
+ if isinstance(self.unique, util.string_types):
+ raise exc.ArgumentError(
+ "The 'unique' keyword argument on Column is boolean "
+ "only. To create unique constraints or indexes with a "
+ "specific name, append an explicit UniqueConstraint to "
+ "the Table's list of elements, or create an explicit "
+ "Index object external to the Table."
+ )
+ table.append_constraint(
+ UniqueConstraint(self.key, _column_flag=True)
+ )
+
+ self._setup_on_memoized_fks(lambda fk: fk._set_remote_table(table))
+
+ if self.identity and (
+ isinstance(self.default, Sequence)
+ or isinstance(self.onupdate, Sequence)
+ ):
+ raise exc.ArgumentError(
+ "A column cannot specify both Identity and Sequence."
+ )
+
+ def _setup_on_memoized_fks(self, fn):
+ fk_keys = [
+ ((self.table.key, self.key), False),
+ ((self.table.key, self.name), True),
+ ]
+ for fk_key, link_to_name in fk_keys:
+ if fk_key in self.table.metadata._fk_memos:
+ for fk in self.table.metadata._fk_memos[fk_key]:
+ if fk.link_to_name is link_to_name:
+ fn(fk)
+
+ def _on_table_attach(self, fn):
+ if self.table is not None:
+ fn(self, self.table)
+ else:
+ event.listen(self, "after_parent_attach", fn)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.Column.copy` method is deprecated "
+ "and will be removed in a future release.",
+ )
+ def copy(self, **kw):
+ return self._copy(**kw)
+
+ def _copy(self, **kw):
+ """Create a copy of this ``Column``, uninitialized.
+
+ This is used in :meth:`_schema.Table.to_metadata`.
+
+ """
+
+ # Constraint objects plus non-constraint-bound ForeignKey objects
+ args = [
+ c._copy(**kw) for c in self.constraints if not c._type_bound
+ ] + [c._copy(**kw) for c in self.foreign_keys if not c.constraint]
+
+ # ticket #5276
+ column_kwargs = {}
+ for dialect_name in self.dialect_options:
+ dialect_options = self.dialect_options[dialect_name]._non_defaults
+ for (
+ dialect_option_key,
+ dialect_option_value,
+ ) in dialect_options.items():
+ column_kwargs[
+ dialect_name + "_" + dialect_option_key
+ ] = dialect_option_value
+
+ server_default = self.server_default
+ server_onupdate = self.server_onupdate
+ if isinstance(server_default, (Computed, Identity)):
+ server_default = server_onupdate = None
+ args.append(self.server_default._copy(**kw))
+
+ type_ = self.type
+ if isinstance(type_, SchemaEventTarget):
+ type_ = type_.copy(**kw)
+
+ if self._user_defined_nullable is not NULL_UNSPECIFIED:
+ column_kwargs["nullable"] = self._user_defined_nullable
+
+ c = self._constructor(
+ name=self.name,
+ type_=type_,
+ key=self.key,
+ primary_key=self.primary_key,
+ unique=self.unique,
+ system=self.system,
+ # quote=self.quote, # disabled 2013-08-27 (commit 031ef080)
+ index=self.index,
+ autoincrement=self.autoincrement,
+ default=self.default,
+ server_default=server_default,
+ onupdate=self.onupdate,
+ server_onupdate=server_onupdate,
+ doc=self.doc,
+ comment=self.comment,
+ *args,
+ **column_kwargs
+ )
+ return self._schema_item_copy(c)
+
+ def _make_proxy(
+ self, selectable, name=None, key=None, name_is_truncatable=False, **kw
+ ):
+ """Create a *proxy* for this column.
+
+ This is a copy of this ``Column`` referenced by a different parent
+ (such as an alias or select statement). The column should
+ be used only in select scenarios, as its full DDL/default
+ information is not transferred.
+
+ """
+
+ fk = [
+ ForeignKey(
+ col if col is not None else f._colspec,
+ _unresolvable=col is None,
+ _constraint=f.constraint,
+ )
+ for f, col in [
+ (fk, fk._resolve_column(raiseerr=False))
+ for fk in self.foreign_keys
+ ]
+ ]
+
+ if name is None and self.name is None:
+ raise exc.InvalidRequestError(
+ "Cannot initialize a sub-selectable"
+ " with this Column object until its 'name' has "
+ "been assigned."
+ )
+ try:
+ c = self._constructor(
+ coercions.expect(
+ roles.TruncatedLabelRole, name if name else self.name
+ )
+ if name_is_truncatable
+ else (name or self.name),
+ self.type,
+ # this may actually be ._proxy_key when the key is incoming
+ key=key if key else name if name else self.key,
+ primary_key=self.primary_key,
+ nullable=self.nullable,
+ _proxies=[self],
+ *fk
+ )
+ except TypeError as err:
+ util.raise_(
+ TypeError(
+ "Could not create a copy of this %r object. "
+ "Ensure the class includes a _constructor() "
+ "attribute or method which accepts the "
+ "standard Column constructor arguments, or "
+ "references the Column class itself." % self.__class__
+ ),
+ from_=err,
+ )
+
+ c.table = selectable
+ c._propagate_attrs = selectable._propagate_attrs
+ if selectable._is_clone_of is not None:
+ c._is_clone_of = selectable._is_clone_of.columns.get(c.key)
+ if self.primary_key:
+ selectable.primary_key.add(c)
+ if fk:
+ selectable.foreign_keys.update(fk)
+ return c.key, c
+
+
+class ForeignKey(DialectKWArgs, SchemaItem):
+ """Defines a dependency between two columns.
+
+ ``ForeignKey`` is specified as an argument to a :class:`_schema.Column`
+ object,
+ e.g.::
+
+ t = Table("remote_table", metadata,
+ Column("remote_id", ForeignKey("main_table.id"))
+ )
+
+ Note that ``ForeignKey`` is only a marker object that defines
+ a dependency between two columns. The actual constraint
+ is in all cases represented by the :class:`_schema.ForeignKeyConstraint`
+ object. This object will be generated automatically when
+ a ``ForeignKey`` is associated with a :class:`_schema.Column` which
+ in turn is associated with a :class:`_schema.Table`. Conversely,
+ when :class:`_schema.ForeignKeyConstraint` is applied to a
+ :class:`_schema.Table`,
+ ``ForeignKey`` markers are automatically generated to be
+ present on each associated :class:`_schema.Column`, which are also
+ associated with the constraint object.
+
+ Note that you cannot define a "composite" foreign key constraint,
+ that is a constraint between a grouping of multiple parent/child
+ columns, using ``ForeignKey`` objects. To define this grouping,
+ the :class:`_schema.ForeignKeyConstraint` object must be used, and applied
+ to the :class:`_schema.Table`. The associated ``ForeignKey`` objects
+ are created automatically.
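+
+ E.g., a minimal sketch of a composite constraint, using illustrative
+ table and column names::
+
+ Table("invoice_item", metadata,
+ Column("invoice_id", Integer),
+ Column("ref_num", Integer),
+ ForeignKeyConstraint(
+ ["invoice_id", "ref_num"],
+ ["invoice.invoice_id", "invoice.ref_num"],
+ ),
+ )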
+
+ The ``ForeignKey`` objects associated with an individual
+ :class:`_schema.Column`
+ object are available in the `foreign_keys` collection
+ of that column.
+
+ Further examples of foreign key configuration are in
+ :ref:`metadata_foreignkeys`.
+
+ """
+
+ __visit_name__ = "foreign_key"
+
+ def __init__(
+ self,
+ column,
+ _constraint=None,
+ use_alter=False,
+ name=None,
+ onupdate=None,
+ ondelete=None,
+ deferrable=None,
+ initially=None,
+ link_to_name=False,
+ match=None,
+ info=None,
+ _unresolvable=False,
+ **dialect_kw
+ ):
+ r"""
+ Construct a column-level FOREIGN KEY.
+
+ The :class:`_schema.ForeignKey` object when constructed generates a
+ :class:`_schema.ForeignKeyConstraint`
+ which is associated with the parent
+ :class:`_schema.Table` object's collection of constraints.
+
+ :param column: A single target column for the key relationship. A
+ :class:`_schema.Column` object or a column name as a string:
+ ``tablename.columnkey`` or ``schema.tablename.columnkey``.
+ ``columnkey`` is the ``key`` which has been assigned to the column
+ (defaults to the column name itself), unless ``link_to_name`` is
+ ``True`` in which case the rendered name of the column is used.
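+
+ E.g., minimal sketches of the string forms, with illustrative names::
+
+ ForeignKey("user.id")
+ ForeignKey("remote_schema.user.id")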
+
+ :param name: Optional string. An in-database name for the key if
+ `constraint` is not provided.
+
+ :param onupdate: Optional string. If set, emit ON UPDATE <value> when
+ issuing DDL for this constraint. Typical values include CASCADE,
+ SET NULL and RESTRICT.
+
+ :param ondelete: Optional string. If set, emit ON DELETE <value> when
+ issuing DDL for this constraint. Typical values include CASCADE,
+ SET NULL and RESTRICT.
+
+ :param deferrable: Optional bool. If set, emit DEFERRABLE or NOT
+ DEFERRABLE when issuing DDL for this constraint.
+
+ :param initially: Optional string. If set, emit INITIALLY <value> when
+ issuing DDL for this constraint.
+
+ :param link_to_name: if True, the string name given in ``column`` is
+ the rendered name of the referenced column, not its locally
+ assigned ``key``.
+
+ :param use_alter: passed to the underlying
+ :class:`_schema.ForeignKeyConstraint`
+ to indicate the constraint should
+ be generated/dropped externally from the CREATE TABLE/ DROP TABLE
+ statement. See :paramref:`_schema.ForeignKeyConstraint.use_alter`
+ for further description.
+
+ .. seealso::
+
+ :paramref:`_schema.ForeignKeyConstraint.use_alter`
+
+ :ref:`use_alter`
+
+ :param match: Optional string. If set, emit MATCH <value> when issuing
+ DDL for this constraint. Typical values include SIMPLE, PARTIAL
+ and FULL.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.SchemaItem.info` attribute of this object.
+
+ .. versionadded:: 1.0.0
+
+ :param \**dialect_kw: Additional keyword arguments are dialect
+ specific, and passed in the form ``<dialectname>_<argname>``. The
+ arguments are ultimately handled by a corresponding
+ :class:`_schema.ForeignKeyConstraint`.
+ See the documentation regarding
+ an individual dialect at :ref:`dialect_toplevel` for detail on
+ documented arguments.
+
+ .. versionadded:: 0.9.2
+
+ """
+
+ self._colspec = coercions.expect(roles.DDLReferredColumnRole, column)
+ self._unresolvable = _unresolvable
+
+ if isinstance(self._colspec, util.string_types):
+ self._table_column = None
+ else:
+ self._table_column = self._colspec
+
+ if not isinstance(
+ self._table_column.table, (util.NoneType, TableClause)
+ ):
+ raise exc.ArgumentError(
+ "ForeignKey received Column not bound "
+ "to a Table, got: %r" % self._table_column.table
+ )
+
+ # the linked ForeignKeyConstraint.
+ # ForeignKey will create this when parent Column
+ # is attached to a Table, *or* ForeignKeyConstraint
+ # object passes itself in when creating ForeignKey
+ # markers.
+ self.constraint = _constraint
+ self.parent = None
+ self.use_alter = use_alter
+ self.name = name
+ self.onupdate = onupdate
+ self.ondelete = ondelete
+ self.deferrable = deferrable
+ self.initially = initially
+ self.link_to_name = link_to_name
+ self.match = match
+ if info:
+ self.info = info
+ self._unvalidated_dialect_kw = dialect_kw
+
+ def __repr__(self):
+ return "ForeignKey(%r)" % self._get_colspec()
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.ForeignKey.copy` method is deprecated "
+ "and will be removed in a future release.",
+ )
+ def copy(self, schema=None, **kw):
+ return self._copy(schema=schema, **kw)
+
+ def _copy(self, schema=None, **kw):
+ """Produce a copy of this :class:`_schema.ForeignKey` object.
+
+ The new :class:`_schema.ForeignKey` will not be bound
+ to any :class:`_schema.Column`.
+
+ This method is usually used by the internal
+ copy procedures of :class:`_schema.Column`, :class:`_schema.Table`,
+ and :class:`_schema.MetaData`.
+
+ :param schema: The returned :class:`_schema.ForeignKey` will
+ reference the original table and column name, qualified
+ by the given string schema name.
+
+ """
+
+ fk = ForeignKey(
+ self._get_colspec(schema=schema),
+ use_alter=self.use_alter,
+ name=self.name,
+ onupdate=self.onupdate,
+ ondelete=self.ondelete,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ link_to_name=self.link_to_name,
+ match=self.match,
+ **self._unvalidated_dialect_kw
+ )
+ return self._schema_item_copy(fk)
+
+ def _get_colspec(self, schema=None, table_name=None):
+ """Return a string based 'column specification' for this
+ :class:`_schema.ForeignKey`.
+
+ This is usually the equivalent of the string-based "tablename.colname"
+ argument first passed to the object's constructor.
+
+ """
+ if schema not in (None, RETAIN_SCHEMA):
+ _schema, tname, colname = self._column_tokens
+ if table_name is not None:
+ tname = table_name
+ if schema is BLANK_SCHEMA:
+ return "%s.%s" % (tname, colname)
+ else:
+ return "%s.%s.%s" % (schema, tname, colname)
+ elif table_name:
+ schema, tname, colname = self._column_tokens
+ if schema:
+ return "%s.%s.%s" % (schema, table_name, colname)
+ else:
+ return "%s.%s" % (table_name, colname)
+ elif self._table_column is not None:
+ return "%s.%s" % (
+ self._table_column.table.fullname,
+ self._table_column.key,
+ )
+ else:
+ return self._colspec
+
+ @property
+ def _referred_schema(self):
+ return self._column_tokens[0]
+
+ def _table_key(self):
+ if self._table_column is not None:
+ if self._table_column.table is None:
+ return None
+ else:
+ return self._table_column.table.key
+ else:
+ schema, tname, colname = self._column_tokens
+ return _get_table_key(tname, schema)
+
+ target_fullname = property(_get_colspec)
+
+ def references(self, table):
+ """Return True if the given :class:`_schema.Table`
+ is referenced by this
+ :class:`_schema.ForeignKey`."""
+
+ return table.corresponding_column(self.column) is not None
+
+ def get_referent(self, table):
+ """Return the :class:`_schema.Column` in the given
+ :class:`_schema.Table`
+ referenced by this :class:`_schema.ForeignKey`.
+
+ Returns None if this :class:`_schema.ForeignKey`
+ does not reference the given
+ :class:`_schema.Table`.
+
+ """
+
+ return table.corresponding_column(self.column)
+
+ @util.memoized_property
+ def _column_tokens(self):
+ """parse a string-based _colspec into its component parts."""
+
+ m = self._get_colspec().split(".")
+ if m is None:
+ raise exc.ArgumentError(
+ "Invalid foreign key column specification: %s" % self._colspec
+ )
+ if len(m) == 1:
+ tname = m.pop()
+ colname = None
+ else:
+ colname = m.pop()
+ tname = m.pop()
+
+ # A FK between column 'bar' and table 'foo' can be
+ # specified as 'foo', 'foo.bar', 'dbo.foo.bar',
+ # 'otherdb.dbo.foo.bar'. Once we have the column name and
+ # the table name, treat everything else as the schema
+ # name. Some databases (e.g. Sybase) support
+ # inter-database foreign keys. See tickets#1341 and --
+ # indirectly related -- Ticket #594. This assumes that '.'
+ # will never appear *within* any component of the FK.
+
+ if len(m) > 0:
+ schema = ".".join(m)
+ else:
+ schema = None
+ return schema, tname, colname
+
+ def _resolve_col_tokens(self):
+ if self.parent is None:
+ raise exc.InvalidRequestError(
+ "this ForeignKey object does not yet have a "
+ "parent Column associated with it."
+ )
+
+ elif self.parent.table is None:
+ raise exc.InvalidRequestError(
+ "this ForeignKey's parent column is not yet associated "
+ "with a Table."
+ )
+
+ parenttable = self.parent.table
+
+ if self._unresolvable:
+ schema, tname, colname = self._column_tokens
+ tablekey = _get_table_key(tname, schema)
+ return parenttable, tablekey, colname
+
+ # assertion
+ # basically Column._make_proxy() sends the actual
+ # target Column to the ForeignKey object, so the
+ # string resolution here is never called.
+ for c in self.parent.base_columns:
+ if isinstance(c, Column):
+ assert c.table is parenttable
+ break
+ else:
+ assert False
+ ######################
+
+ schema, tname, colname = self._column_tokens
+
+ if schema is None and parenttable.metadata.schema is not None:
+ schema = parenttable.metadata.schema
+
+ tablekey = _get_table_key(tname, schema)
+ return parenttable, tablekey, colname
+
+ def _link_to_col_by_colstring(self, parenttable, table, colname):
+
+ _column = None
+ if colname is None:
+ # colname is None in the case that ForeignKey argument
+ # was specified as table name only, in which case we
+ # match the column name to the same column on the
+ # parent.
+ # this use case wasn't working in later 1.x series
+ # as it had no test coverage; fixed in 2.0
+ parent = self.parent
+ assert parent is not None
+ key = parent.key
+ _column = table.c.get(key, None)
+ elif self.link_to_name:
+ key = colname
+ for c in table.c:
+ if c.name == colname:
+ _column = c
+ else:
+ key = colname
+ _column = table.c.get(colname, None)
+
+ if _column is None:
+ raise exc.NoReferencedColumnError(
+ "Could not initialize target column "
+ "for ForeignKey '%s' on table '%s': "
+ "table '%s' has no column named '%s'"
+ % (self._colspec, parenttable.name, table.name, key),
+ table.name,
+ key,
+ )
+
+ return _column
+
+ def _set_target_column(self, column):
+ assert self.parent is not None
+
+ # propagate TypeEngine to parent if it didn't have one
+ if self.parent.type._isnull:
+ self.parent.type = column.type
+
+ # super-edgy case, if other FKs point to our column,
+ # they'd get the type propagated out also.
+
+ def set_type(fk):
+ if fk.parent.type._isnull:
+ fk.parent.type = column.type
+
+ self.parent._setup_on_memoized_fks(set_type)
+
+ self.column = column
+
+ @util.memoized_property
+ def column(self):
+ """Return the target :class:`_schema.Column` referenced by this
+ :class:`_schema.ForeignKey`.
+
+ If no target column has been established, an exception
+ is raised.
+
+ .. versionchanged:: 0.9.0
+            Foreign key target column resolution now occurs as soon as
+            the ForeignKey object and the remote Column to which it refers
+            are both associated with the same MetaData object.
+
+ """
+
+ return self._resolve_column()
+
+ def _resolve_column(self, raiseerr=True):
+
+ if isinstance(self._colspec, util.string_types):
+
+ parenttable, tablekey, colname = self._resolve_col_tokens()
+
+ if self._unresolvable or tablekey not in parenttable.metadata:
+ if not raiseerr:
+ return None
+ raise exc.NoReferencedTableError(
+ "Foreign key associated with column '%s' could not find "
+ "table '%s' with which to generate a "
+ "foreign key to target column '%s'"
+ % (self.parent, tablekey, colname),
+ tablekey,
+ )
+ elif parenttable.key not in parenttable.metadata:
+ if not raiseerr:
+ return None
+ raise exc.InvalidRequestError(
+ "Table %s is no longer associated with its "
+ "parent MetaData" % parenttable
+ )
+ else:
+ table = parenttable.metadata.tables[tablekey]
+ return self._link_to_col_by_colstring(
+ parenttable, table, colname
+ )
+
+ elif hasattr(self._colspec, "__clause_element__"):
+ _column = self._colspec.__clause_element__()
+ return _column
+ else:
+ _column = self._colspec
+ return _column
+
+ def _set_parent(self, column, **kw):
+ if self.parent is not None and self.parent is not column:
+ raise exc.InvalidRequestError(
+                "This ForeignKey already has a parent!"
+ )
+ self.parent = column
+ self.parent.foreign_keys.add(self)
+ self.parent._on_table_attach(self._set_table)
+
+ def _set_remote_table(self, table):
+ parenttable, tablekey, colname = self._resolve_col_tokens()
+ _column = self._link_to_col_by_colstring(parenttable, table, colname)
+ self._set_target_column(_column)
+ assert self.constraint is not None
+
+ self.constraint._validate_dest_table(table)
+
+ def _remove_from_metadata(self, metadata):
+ parenttable, table_key, colname = self._resolve_col_tokens()
+ fk_key = (table_key, colname)
+
+ if self in metadata._fk_memos[fk_key]:
+ # TODO: no test coverage for self not in memos
+ metadata._fk_memos[fk_key].remove(self)
+
+ def _set_table(self, column, table):
+ # standalone ForeignKey - create ForeignKeyConstraint
+ # on the hosting Table when attached to the Table.
+ assert isinstance(table, Table)
+ if self.constraint is None:
+ self.constraint = ForeignKeyConstraint(
+ [],
+ [],
+ use_alter=self.use_alter,
+ name=self.name,
+ onupdate=self.onupdate,
+ ondelete=self.ondelete,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ match=self.match,
+ **self._unvalidated_dialect_kw
+ )
+ self.constraint._append_element(column, self)
+ self.constraint._set_parent_with_dispatch(table)
+ table.foreign_keys.add(self)
+ # set up remote ".column" attribute, or a note to pick it
+ # up when the other Table/Column shows up
+ if isinstance(self._colspec, util.string_types):
+ parenttable, table_key, colname = self._resolve_col_tokens()
+ fk_key = (table_key, colname)
+ if table_key in parenttable.metadata.tables:
+ table = parenttable.metadata.tables[table_key]
+ try:
+ _column = self._link_to_col_by_colstring(
+ parenttable, table, colname
+ )
+ except exc.NoReferencedColumnError:
+ # this is OK, we'll try later
+ pass
+ else:
+ self._set_target_column(_column)
+ parenttable.metadata._fk_memos[fk_key].append(self)
+ elif hasattr(self._colspec, "__clause_element__"):
+ _column = self._colspec.__clause_element__()
+ self._set_target_column(_column)
+ else:
+ _column = self._colspec
+ self._set_target_column(_column)
+
+
+class DefaultGenerator(Executable, SchemaItem):
+ """Base class for column *default* values."""
+
+ __visit_name__ = "default_generator"
+
+ is_sequence = False
+ is_server_default = False
+ column = None
+
+ def __init__(self, for_update=False):
+ self.for_update = for_update
+
+ def _set_parent(self, column, **kw):
+ self.column = column
+ if self.for_update:
+ self.column.onupdate = self
+ else:
+ self.column.default = self
+
+ @util.deprecated_20(
+ ":meth:`.DefaultGenerator.execute`",
+ alternative="All statement execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.execute` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.execute` method of "
+ ":class:`.Session`.",
+ )
+ def execute(self, bind=None):
+ if bind is None:
+ bind = _bind_or_error(self)
+ return bind._execute_default(self, (), util.EMPTY_DICT)
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ return connection._execute_default(
+ self, multiparams, params, execution_options
+ )
+
+ @property
+ def bind(self):
+ """Return the connectable associated with this default."""
+ if getattr(self, "column", None) is not None:
+ return self.column.table.bind
+ else:
+ return None
+
+
+class ColumnDefault(DefaultGenerator):
+ """A plain default value on a column.
+
+ This could correspond to a constant, a callable function,
+ or a SQL clause.
+
+ :class:`.ColumnDefault` is generated automatically
+    whenever the ``default`` or ``onupdate`` arguments of
+ :class:`_schema.Column` are used. A :class:`.ColumnDefault`
+ can be passed positionally as well.
+
+ For example, the following::
+
+ Column('foo', Integer, default=50)
+
+ Is equivalent to::
+
+ Column('foo', Integer, ColumnDefault(50))
+
+
+ """
+
+ def __init__(self, arg, **kwargs):
+ """Construct a new :class:`.ColumnDefault`.
+
+
+ :param arg: argument representing the default value.
+ May be one of the following:
+
+ * a plain non-callable Python value, such as a
+ string, integer, boolean, or other simple type.
+ The default value will be used as is each time.
+ * a SQL expression, that is one which derives from
+ :class:`_expression.ColumnElement`. The SQL expression will
+ be rendered into the INSERT or UPDATE statement,
+ or in the case of a primary key column when
+ RETURNING is not used may be
+ pre-executed before an INSERT within a SELECT.
+ * A Python callable. The function will be invoked for each
+ new row subject to an INSERT or UPDATE.
+ The callable must accept exactly
+ zero or one positional arguments. The one-argument form
+ will receive an instance of the :class:`.ExecutionContext`,
+ which provides contextual information as to the current
+ :class:`_engine.Connection` in use as well as the current
+ statement and parameters.
+
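+        For example, a minimal sketch of a context-consuming callable (the
+        column and parameter names here are hypothetical)::
+
+            def mydefault(context):
+                return context.get_current_parameters()["counter"] + 12
+
+            Column("counter_plus_twelve", Integer, default=mydefault)
+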
+ """
+ super(ColumnDefault, self).__init__(**kwargs)
+ if isinstance(arg, FetchedValue):
+ raise exc.ArgumentError(
+ "ColumnDefault may not be a server-side default type."
+ )
+ if callable(arg):
+ arg = self._maybe_wrap_callable(arg)
+ self.arg = arg
+
+ @util.memoized_property
+ def is_callable(self):
+ return callable(self.arg)
+
+ @util.memoized_property
+ def is_clause_element(self):
+ return isinstance(self.arg, ClauseElement)
+
+ @util.memoized_property
+ def is_scalar(self):
+ return (
+ not self.is_callable
+ and not self.is_clause_element
+ and not self.is_sequence
+ )
+
+ @util.memoized_property
+ @util.preload_module("sqlalchemy.sql.sqltypes")
+ def _arg_is_typed(self):
+ sqltypes = util.preloaded.sql_sqltypes
+
+ if self.is_clause_element:
+ return not isinstance(self.arg.type, sqltypes.NullType)
+ else:
+ return False
+
+ def _maybe_wrap_callable(self, fn):
+ """Wrap callables that don't accept a context.
+
+        This allows easy compatibility with default callables
+        that don't accept a context argument.
+
+ """
+ try:
+ argspec = util.get_callable_argspec(fn, no_self=True)
+ except TypeError:
+ return util.wrap_callable(lambda ctx: fn(), fn)
+
+ defaulted = argspec[3] is not None and len(argspec[3]) or 0
+ positionals = len(argspec[0]) - defaulted
+
+ if positionals == 0:
+ return util.wrap_callable(lambda ctx: fn(), fn)
+
+ elif positionals == 1:
+ return fn
+ else:
+ raise exc.ArgumentError(
+ "ColumnDefault Python function takes zero or one "
+ "positional arguments"
+ )
+
+ def __repr__(self):
+ return "ColumnDefault(%r)" % (self.arg,)
+
+
+class IdentityOptions(object):
+ """Defines options for a named database sequence or an identity column.
+
+ .. versionadded:: 1.3.18
+
+ .. seealso::
+
+ :class:`.Sequence`
+
+ """
+
+ def __init__(
+ self,
+ start=None,
+ increment=None,
+ minvalue=None,
+ maxvalue=None,
+ nominvalue=None,
+ nomaxvalue=None,
+ cycle=None,
+ cache=None,
+ order=None,
+ ):
+ """Construct a :class:`.IdentityOptions` object.
+
+ See the :class:`.Sequence` documentation for a complete description
+ of the parameters.
+
+ :param start: the starting index of the sequence.
+ :param increment: the increment value of the sequence.
+ :param minvalue: the minimum value of the sequence.
+ :param maxvalue: the maximum value of the sequence.
+ :param nominvalue: no minimum value of the sequence.
+ :param nomaxvalue: no maximum value of the sequence.
+ :param cycle: allows the sequence to wrap around when the maxvalue
+ or minvalue has been reached.
+ :param cache: optional integer value; number of future values in the
+ sequence which are calculated in advance.
+ :param order: optional boolean value; if ``True``, renders the
+ ORDER keyword.
+
+ """
+ self.start = start
+ self.increment = increment
+ self.minvalue = minvalue
+ self.maxvalue = maxvalue
+ self.nominvalue = nominvalue
+ self.nomaxvalue = nomaxvalue
+ self.cycle = cycle
+ self.cache = cache
+ self.order = order
+
+
+class Sequence(IdentityOptions, DefaultGenerator):
+ """Represents a named database sequence.
+
+ The :class:`.Sequence` object represents the name and configurational
+ parameters of a database sequence. It also represents
+ a construct that can be "executed" by a SQLAlchemy :class:`_engine.Engine`
+ or :class:`_engine.Connection`,
+ rendering the appropriate "next value" function
+ for the target database and returning a result.
+
+ The :class:`.Sequence` is typically associated with a primary key column::
+
+ some_table = Table(
+ 'some_table', metadata,
+ Column('id', Integer, Sequence('some_table_seq'),
+ primary_key=True)
+ )
+
+ When CREATE TABLE is emitted for the above :class:`_schema.Table`, if the
+ target platform supports sequences, a CREATE SEQUENCE statement will
+ be emitted as well. For platforms that don't support sequences,
+ the :class:`.Sequence` construct is ignored.
+
+ .. seealso::
+
+ :ref:`defaults_sequences`
+
+ :class:`.CreateSequence`
+
+ :class:`.DropSequence`
+
+ """
+
+ __visit_name__ = "sequence"
+
+ is_sequence = True
+
+ def __init__(
+ self,
+ name,
+ start=None,
+ increment=None,
+ minvalue=None,
+ maxvalue=None,
+ nominvalue=None,
+ nomaxvalue=None,
+ cycle=None,
+ schema=None,
+ cache=None,
+ order=None,
+ data_type=None,
+ optional=False,
+ quote=None,
+ metadata=None,
+ quote_schema=None,
+ for_update=False,
+ ):
+ """Construct a :class:`.Sequence` object.
+
+ :param name: the name of the sequence.
+
+ :param start: the starting index of the sequence. This value is
+ used when the CREATE SEQUENCE command is emitted to the database
+ as the value of the "START WITH" clause. If ``None``, the
+ clause is omitted, which on most platforms indicates a starting
+ value of 1.
+ :param increment: the increment value of the sequence. This
+ value is used when the CREATE SEQUENCE command is emitted to
+ the database as the value of the "INCREMENT BY" clause. If ``None``,
+ the clause is omitted, which on most platforms indicates an
+ increment of 1.
+ :param minvalue: the minimum value of the sequence. This
+ value is used when the CREATE SEQUENCE command is emitted to
+ the database as the value of the "MINVALUE" clause. If ``None``,
+ the clause is omitted, which on most platforms indicates a
+ minvalue of 1 and -2^63-1 for ascending and descending sequences,
+ respectively.
+
+ .. versionadded:: 1.0.7
+
+ :param maxvalue: the maximum value of the sequence. This
+ value is used when the CREATE SEQUENCE command is emitted to
+ the database as the value of the "MAXVALUE" clause. If ``None``,
+ the clause is omitted, which on most platforms indicates a
+ maxvalue of 2^63-1 and -1 for ascending and descending sequences,
+ respectively.
+
+ .. versionadded:: 1.0.7
+
+ :param nominvalue: no minimum value of the sequence. This
+ value is used when the CREATE SEQUENCE command is emitted to
+ the database as the value of the "NO MINVALUE" clause. If ``None``,
+ the clause is omitted, which on most platforms indicates a
+ minvalue of 1 and -2^63-1 for ascending and descending sequences,
+ respectively.
+
+ .. versionadded:: 1.0.7
+
+ :param nomaxvalue: no maximum value of the sequence. This
+ value is used when the CREATE SEQUENCE command is emitted to
+ the database as the value of the "NO MAXVALUE" clause. If ``None``,
+ the clause is omitted, which on most platforms indicates a
+ maxvalue of 2^63-1 and -1 for ascending and descending sequences,
+ respectively.
+
+ .. versionadded:: 1.0.7
+
+ :param cycle: allows the sequence to wrap around when the maxvalue
+ or minvalue has been reached by an ascending or descending sequence
+ respectively. This value is used when the CREATE SEQUENCE command
+ is emitted to the database as the "CYCLE" clause. If the limit is
+ reached, the next number generated will be the minvalue or maxvalue,
+ respectively. If cycle=False (the default) any calls to nextval
+ after the sequence has reached its maximum value will return an
+ error.
+
+ .. versionadded:: 1.0.7
+
+ :param schema: optional schema name for the sequence, if located
+ in a schema other than the default. The rules for selecting the
+ schema name when a :class:`_schema.MetaData`
+ is also present are the same
+ as that of :paramref:`_schema.Table.schema`.
+
+ :param cache: optional integer value; number of future values in the
+ sequence which are calculated in advance. Renders the CACHE keyword
+ understood by Oracle and PostgreSQL.
+
+ .. versionadded:: 1.1.12
+
+ :param order: optional boolean value; if ``True``, renders the
+ ORDER keyword, understood by Oracle, indicating the sequence is
+ definitively ordered. May be necessary to provide deterministic
+ ordering using Oracle RAC.
+
+ .. versionadded:: 1.1.12
+
+ :param data_type: The type to be returned by the sequence, for
+ dialects that allow us to choose between INTEGER, BIGINT, etc.
+ (e.g., mssql).
+
+ .. versionadded:: 1.4.0
+
+ :param optional: boolean value, when ``True``, indicates that this
+ :class:`.Sequence` object only needs to be explicitly generated
+ on backends that don't provide another way to generate primary
+ key identifiers. Currently, it essentially means, "don't create
+ this sequence on the PostgreSQL backend, where the SERIAL keyword
+ creates a sequence for us automatically".
+ :param quote: boolean value, when ``True`` or ``False``, explicitly
+ forces quoting of the :paramref:`_schema.Sequence.name` on or off.
+ When left at its default of ``None``, normal quoting rules based
+ on casing and reserved words take place.
+ :param quote_schema: Set the quoting preferences for the ``schema``
+ name.
+
+ :param metadata: optional :class:`_schema.MetaData` object which this
+ :class:`.Sequence` will be associated with. A :class:`.Sequence`
+ that is associated with a :class:`_schema.MetaData`
+ gains the following
+ capabilities:
+
+ * The :class:`.Sequence` will inherit the
+ :paramref:`_schema.MetaData.schema`
+ parameter specified to the target :class:`_schema.MetaData`, which
+ affects the production of CREATE / DROP DDL, if any.
+
+ * The :meth:`.Sequence.create` and :meth:`.Sequence.drop` methods
+ automatically use the engine bound to the :class:`_schema.MetaData`
+ object, if any.
+
+ * The :meth:`_schema.MetaData.create_all` and
+ :meth:`_schema.MetaData.drop_all`
+ methods will emit CREATE / DROP for this :class:`.Sequence`,
+ even if the :class:`.Sequence` is not associated with any
+ :class:`_schema.Table` / :class:`_schema.Column`
+ that's a member of this
+ :class:`_schema.MetaData`.
+
+ The above behaviors can only occur if the :class:`.Sequence` is
+ explicitly associated with the :class:`_schema.MetaData`
+ via this parameter.
+
+ .. seealso::
+
+ :ref:`sequence_metadata` - full discussion of the
+ :paramref:`.Sequence.metadata` parameter.
+
+ :param for_update: Indicates this :class:`.Sequence`, when associated
+ with a :class:`_schema.Column`,
+ should be invoked for UPDATE statements
+ on that column's table, rather than for INSERT statements, when
+ no value is otherwise present for that column in the statement.
+
+ """
+ DefaultGenerator.__init__(self, for_update=for_update)
+ IdentityOptions.__init__(
+ self,
+ start=start,
+ increment=increment,
+ minvalue=minvalue,
+ maxvalue=maxvalue,
+ nominvalue=nominvalue,
+ nomaxvalue=nomaxvalue,
+ cycle=cycle,
+ cache=cache,
+ order=order,
+ )
+ self.name = quoted_name(name, quote)
+ self.optional = optional
+ if schema is BLANK_SCHEMA:
+ self.schema = schema = None
+ elif metadata is not None and schema is None and metadata.schema:
+ self.schema = schema = metadata.schema
+ else:
+ self.schema = quoted_name(schema, quote_schema)
+ self.metadata = metadata
+ self._key = _get_table_key(name, schema)
+ if metadata:
+ self._set_metadata(metadata)
+ if data_type is not None:
+ self.data_type = to_instance(data_type)
+ else:
+ self.data_type = None
+
+ @util.memoized_property
+ def is_callable(self):
+ return False
+
+ @util.memoized_property
+ def is_clause_element(self):
+ return False
+
+ @util.preload_module("sqlalchemy.sql.functions")
+ def next_value(self):
+ """Return a :class:`.next_value` function element
+ which will render the appropriate increment function
+ for this :class:`.Sequence` within any SQL expression.
+
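+        E.g., a minimal sketch, assuming a hypothetical sequence name::
+
+            seq = Sequence("some_sequence")
+            stmt = select(seq.next_value())
+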
+ """
+ if self.bind:
+ return util.preloaded.sql_functions.func.next_value(
+ self, bind=self.bind
+ )
+ else:
+ return util.preloaded.sql_functions.func.next_value(self)
+
+ def _set_parent(self, column, **kw):
+ super(Sequence, self)._set_parent(column)
+ column._on_table_attach(self._set_table)
+
+ def _set_table(self, column, table):
+ self._set_metadata(table.metadata)
+
+ def _set_metadata(self, metadata):
+ self.metadata = metadata
+ self.metadata._sequences[self._key] = self
+
+ @property
+ def bind(self):
+ if self.metadata:
+ return self.metadata.bind
+ else:
+ return None
+
+ def create(self, bind=None, checkfirst=True):
+ """Creates this sequence in the database.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ """
+
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst)
+
+ def drop(self, bind=None, checkfirst=True):
+ """Drops this sequence from the database.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ """
+
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst)
+
+ def _not_a_column_expr(self):
+ raise exc.InvalidRequestError(
+ "This %s cannot be used directly "
+ "as a column expression. Use func.next_value(sequence) "
+ "to produce a 'next value' function that's usable "
+ "as a column element." % self.__class__.__name__
+ )
+
+
+@inspection._self_inspects
+class FetchedValue(SchemaEventTarget):
+ """A marker for a transparent database-side default.
+
+ Use :class:`.FetchedValue` when the database is configured
+ to provide some automatic default for a column.
+
+ E.g.::
+
+ Column('foo', Integer, FetchedValue())
+
+ Would indicate that some trigger or default generator
+ will create a new value for the ``foo`` column during an
+ INSERT.
+
+ .. seealso::
+
+ :ref:`triggered_columns`
+
+ """
+
+ is_server_default = True
+ reflected = False
+ has_argument = False
+ is_clause_element = False
+
+ def __init__(self, for_update=False):
+ self.for_update = for_update
+
+ def _as_for_update(self, for_update):
+ if for_update == self.for_update:
+ return self
+ else:
+ return self._clone(for_update)
+
+ def _clone(self, for_update):
+ n = self.__class__.__new__(self.__class__)
+ n.__dict__.update(self.__dict__)
+ n.__dict__.pop("column", None)
+ n.for_update = for_update
+ return n
+
+ def _set_parent(self, column, **kw):
+ self.column = column
+ if self.for_update:
+ self.column.server_onupdate = self
+ else:
+ self.column.server_default = self
+
+ def __repr__(self):
+ return util.generic_repr(self)
+
+
+class DefaultClause(FetchedValue):
+ """A DDL-specified DEFAULT column value.
+
+ :class:`.DefaultClause` is a :class:`.FetchedValue`
+ that also generates a "DEFAULT" clause when
+ "CREATE TABLE" is emitted.
+
+ :class:`.DefaultClause` is generated automatically
+    whenever the ``server_default`` or ``server_onupdate`` arguments of
+ :class:`_schema.Column` are used. A :class:`.DefaultClause`
+ can be passed positionally as well.
+
+ For example, the following::
+
+ Column('foo', Integer, server_default="50")
+
+ Is equivalent to::
+
+ Column('foo', Integer, DefaultClause("50"))
+
+ """
+
+ has_argument = True
+
+ def __init__(self, arg, for_update=False, _reflected=False):
+ util.assert_arg_type(
+ arg, (util.string_types[0], ClauseElement, TextClause), "arg"
+ )
+ super(DefaultClause, self).__init__(for_update)
+ self.arg = arg
+ self.reflected = _reflected
+
+ def __repr__(self):
+ return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update)
+
+
+class Constraint(DialectKWArgs, SchemaItem):
+ """A table-level SQL constraint.
+
+ :class:`_schema.Constraint` serves as the base class for the series of
+ constraint objects that can be associated with :class:`_schema.Table`
+ objects, including :class:`_schema.PrimaryKeyConstraint`,
+    :class:`_schema.ForeignKeyConstraint`,
+ :class:`_schema.UniqueConstraint`, and
+ :class:`_schema.CheckConstraint`.
+
+ """
+
+ __visit_name__ = "constraint"
+
+ def __init__(
+ self,
+ name=None,
+ deferrable=None,
+ initially=None,
+ _create_rule=None,
+ info=None,
+ _type_bound=False,
+ **dialect_kw
+ ):
+ r"""Create a SQL constraint.
+
+ :param name:
+ Optional, the in-database name of this ``Constraint``.
+
+ :param deferrable:
+ Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
+ issuing DDL for this constraint.
+
+ :param initially:
+ Optional string. If set, emit INITIALLY <value> when issuing DDL
+ for this constraint.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.SchemaItem.info` attribute of this object.
+
+ .. versionadded:: 1.0.0
+
+ :param \**dialect_kw: Additional keyword arguments are dialect
+ specific, and passed in the form ``<dialectname>_<argname>``. See
+ the documentation regarding an individual dialect at
+ :ref:`dialect_toplevel` for detail on documented arguments.
+
+ :param _create_rule:
+ used internally by some datatypes that also create constraints.
+
+ :param _type_bound:
+ used internally to indicate that this constraint is associated with
+ a specific datatype.
+
+ """
+
+ self.name = name
+ self.deferrable = deferrable
+ self.initially = initially
+ if info:
+ self.info = info
+ self._create_rule = _create_rule
+ self._type_bound = _type_bound
+ util.set_creation_order(self)
+ self._validate_dialect_kwargs(dialect_kw)
+
+ @property
+ def table(self):
+ try:
+ if isinstance(self.parent, Table):
+ return self.parent
+ except AttributeError:
+ pass
+ raise exc.InvalidRequestError(
+ "This constraint is not bound to a table. Did you "
+ "mean to call table.append_constraint(constraint) ?"
+ )
+
+ def _set_parent(self, parent, **kw):
+ self.parent = parent
+ parent.constraints.add(self)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.Constraint.copy` method is deprecated "
+ "and will be removed in a future release.",
+ )
+ def copy(self, **kw):
+ return self._copy(**kw)
+
+ def _copy(self, **kw):
+ raise NotImplementedError()
+
+
+class ColumnCollectionMixin(object):
+
+ columns = None
+ """A :class:`_expression.ColumnCollection` of :class:`_schema.Column`
+ objects.
+
+ This collection represents the columns which are referred to by
+ this object.
+
+ """
+
+ _allow_multiple_tables = False
+
+ def __init__(self, *columns, **kw):
+ _autoattach = kw.pop("_autoattach", True)
+ self._column_flag = kw.pop("_column_flag", False)
+ self.columns = DedupeColumnCollection()
+
+ processed_expressions = kw.pop("_gather_expressions", None)
+ if processed_expressions is not None:
+ self._pending_colargs = []
+ for (
+ expr,
+ column,
+ strname,
+ add_element,
+ ) in coercions.expect_col_expression_collection(
+ roles.DDLConstraintColumnRole, columns
+ ):
+ self._pending_colargs.append(add_element)
+ processed_expressions.append(expr)
+ else:
+ self._pending_colargs = [
+ coercions.expect(roles.DDLConstraintColumnRole, column)
+ for column in columns
+ ]
+
+ if _autoattach and self._pending_colargs:
+ self._check_attach()
+
+ def _check_attach(self, evt=False):
+ col_objs = [c for c in self._pending_colargs if isinstance(c, Column)]
+
+ cols_w_table = [c for c in col_objs if isinstance(c.table, Table)]
+
+ cols_wo_table = set(col_objs).difference(cols_w_table)
+ if cols_wo_table:
+ # feature #3341 - place event listeners for Column objects
+ # such that when all those cols are attached, we autoattach.
+ assert not evt, "Should not reach here on event call"
+
+ # issue #3411 - don't do the per-column auto-attach if some of the
+ # columns are specified as strings.
+ has_string_cols = set(
+ c for c in self._pending_colargs if c is not None
+ ).difference(col_objs)
+ if not has_string_cols:
+
+ def _col_attached(column, table):
+ # this isinstance() corresponds with the
+ # isinstance() above; only want to count Table-bound
+ # columns
+ if isinstance(table, Table):
+ cols_wo_table.discard(column)
+ if not cols_wo_table:
+ self._check_attach(evt=True)
+
+ self._cols_wo_table = cols_wo_table
+ for col in cols_wo_table:
+ col._on_table_attach(_col_attached)
+ return
+
+ columns = cols_w_table
+
+ tables = {c.table for c in columns}
+ if len(tables) == 1:
+ self._set_parent_with_dispatch(tables.pop())
+ elif len(tables) > 1 and not self._allow_multiple_tables:
+ table = columns[0].table
+ others = [c for c in columns[1:] if c.table is not table]
+ if others:
+ raise exc.ArgumentError(
+ "Column(s) %s are not part of table '%s'."
+ % (
+ ", ".join("'%s'" % c for c in others),
+ table.description,
+ )
+ )
+
+ def _col_expressions(self, table):
+ return [
+ table.c[col] if isinstance(col, util.string_types) else col
+ for col in self._pending_colargs
+ ]
+
+ def _set_parent(self, table, **kw):
+ for col in self._col_expressions(table):
+ if col is not None:
+ self.columns.add(col)
+
+
+class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
+ """A constraint that proxies a ColumnCollection."""
+
+ def __init__(self, *columns, **kw):
+ r"""
+ :param \*columns:
+ A sequence of column names or Column objects.
+
+ :param name:
+ Optional, the in-database name of this constraint.
+
+ :param deferrable:
+ Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
+ issuing DDL for this constraint.
+
+ :param initially:
+ Optional string. If set, emit INITIALLY <value> when issuing DDL
+ for this constraint.
+
+ :param \**kw: other keyword arguments including dialect-specific
+ arguments are propagated to the :class:`.Constraint` superclass.
+
+ """
+ _autoattach = kw.pop("_autoattach", True)
+ _column_flag = kw.pop("_column_flag", False)
+ Constraint.__init__(self, **kw)
+ ColumnCollectionMixin.__init__(
+ self, *columns, _autoattach=_autoattach, _column_flag=_column_flag
+ )
+
+ columns = None
+ """A :class:`_expression.ColumnCollection` representing the set of columns
+ for this constraint.
+
+ """
+
+ def _set_parent(self, table, **kw):
+ Constraint._set_parent(self, table)
+ ColumnCollectionMixin._set_parent(self, table)
+
+ def __contains__(self, x):
+ return x in self.columns
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.ColumnCollectionConstraint.copy` method "
+ "is deprecated and will be removed in a future release.",
+ )
+ def copy(self, target_table=None, **kw):
+ return self._copy(target_table=target_table, **kw)
+
+ def _copy(self, target_table=None, **kw):
+ # ticket #5276
+ constraint_kwargs = {}
+ for dialect_name in self.dialect_options:
+ dialect_options = self.dialect_options[dialect_name]._non_defaults
+ for (
+ dialect_option_key,
+ dialect_option_value,
+ ) in dialect_options.items():
+ constraint_kwargs[
+ dialect_name + "_" + dialect_option_key
+ ] = dialect_option_value
+
+ c = self.__class__(
+ name=self.name,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ *[
+ _copy_expression(expr, self.parent, target_table)
+ for expr in self.columns
+ ],
+ **constraint_kwargs
+ )
+ return self._schema_item_copy(c)
+
+ def contains_column(self, col):
+ """Return True if this constraint contains the given column.
+
+ Note that this object also contains an attribute ``.columns``
+ which is a :class:`_expression.ColumnCollection` of
+ :class:`_schema.Column` objects.
+
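+        E.g., a minimal sketch using hypothetical table and column names::
+
+            t = Table("t", metadata, Column("x", Integer), Column("y", Integer))
+            uq = UniqueConstraint(t.c.x, t.c.y)
+            assert uq.contains_column(t.c.x)
+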
+ """
+
+ return self.columns.contains_column(col)
+
+ def __iter__(self):
+ return iter(self.columns)
+
+ def __len__(self):
+ return len(self.columns)
+
+
+class CheckConstraint(ColumnCollectionConstraint):
+ """A table- or column-level CHECK constraint.
+
+ Can be included in the definition of a Table or Column.
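+
+    E.g., a minimal sketch of a table-level constraint, using hypothetical
+    table and column names::
+
+        Table("user", metadata,
+            Column("username", String(40)),
+            CheckConstraint("length(username) > 3", name="username_len")
+        )
+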
+ """
+
+ _allow_multiple_tables = True
+
+ __visit_name__ = "table_or_column_check_constraint"
+
+ @_document_text_coercion(
+ "sqltext",
+ ":class:`.CheckConstraint`",
+ ":paramref:`.CheckConstraint.sqltext`",
+ )
+ def __init__(
+ self,
+ sqltext,
+ name=None,
+ deferrable=None,
+ initially=None,
+ table=None,
+ info=None,
+ _create_rule=None,
+ _autoattach=True,
+ _type_bound=False,
+ **kw
+ ):
+ r"""Construct a CHECK constraint.
+
+ :param sqltext:
+ A string containing the constraint definition, which will be used
+ verbatim, or a SQL expression construct. If given as a string,
+ the object is converted to a :func:`_expression.text` object.
+ If the textual
+ string includes a colon character, escape this using a backslash::
+
+            CheckConstraint(r"foo ~ E'a(?\:b|c)d'")
+
+ :param name:
+ Optional, the in-database name of the constraint.
+
+ :param deferrable:
+ Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
+ issuing DDL for this constraint.
+
+ :param initially:
+ Optional string. If set, emit INITIALLY <value> when issuing DDL
+ for this constraint.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.SchemaItem.info` attribute of this object.
+
+ .. versionadded:: 1.0.0
+
+ """
+
+ self.sqltext = coercions.expect(roles.DDLExpressionRole, sqltext)
+ columns = []
+ visitors.traverse(self.sqltext, {}, {"column": columns.append})
+
+ super(CheckConstraint, self).__init__(
+ name=name,
+ deferrable=deferrable,
+ initially=initially,
+ _create_rule=_create_rule,
+ info=info,
+ _type_bound=_type_bound,
+ _autoattach=_autoattach,
+ *columns,
+ **kw
+ )
+ if table is not None:
+ self._set_parent_with_dispatch(table)
+
+ @property
+ def is_column_level(self):
+ return not isinstance(self.parent, Table)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.CheckConstraint.copy` method is deprecated "
+ "and will be removed in a future release.",
+ )
+ def copy(self, target_table=None, **kw):
+ return self._copy(target_table=target_table, **kw)
+
+ def _copy(self, target_table=None, **kw):
+ if target_table is not None:
+ # note that target_table is None for the copy process of
+ # a column-bound CheckConstraint, so this path is not reached
+ # in that case.
+ sqltext = _copy_expression(self.sqltext, self.table, target_table)
+ else:
+ sqltext = self.sqltext
+ c = CheckConstraint(
+ sqltext,
+ name=self.name,
+ initially=self.initially,
+ deferrable=self.deferrable,
+ _create_rule=self._create_rule,
+ table=target_table,
+ _autoattach=False,
+ _type_bound=self._type_bound,
+ )
+ return self._schema_item_copy(c)
+
+
+class ForeignKeyConstraint(ColumnCollectionConstraint):
+ """A table-level FOREIGN KEY constraint.
+
+ Defines a single column or composite FOREIGN KEY ... REFERENCES
+ constraint. For a no-frills, single column foreign key, adding a
+ :class:`_schema.ForeignKey` to the definition of a :class:`_schema.Column`
+ is a
+ shorthand equivalent for an unnamed, single column
+ :class:`_schema.ForeignKeyConstraint`.
+
+ Examples of foreign key configuration are in :ref:`metadata_foreignkeys`.
+
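+    E.g., a minimal sketch of a composite constraint, assuming a
+    hypothetical ``invoice`` table with columns ``invoice_id`` and
+    ``ref_num``::
+
+        Table("invoice_item", metadata,
+            Column("item_id", Integer, primary_key=True),
+            Column("invoice_id", Integer),
+            Column("ref_num", Integer),
+            ForeignKeyConstraint(
+                ["invoice_id", "ref_num"],
+                ["invoice.invoice_id", "invoice.ref_num"]
+            )
+        )
+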
+ """
+
+ __visit_name__ = "foreign_key_constraint"
+
+ def __init__(
+ self,
+ columns,
+ refcolumns,
+ name=None,
+ onupdate=None,
+ ondelete=None,
+ deferrable=None,
+ initially=None,
+ use_alter=False,
+ link_to_name=False,
+ match=None,
+ table=None,
+ info=None,
+ **dialect_kw
+ ):
+ r"""Construct a composite-capable FOREIGN KEY.
+
+ :param columns: A sequence of local column names. The named columns
+ must be defined and present in the parent Table. The names should
+ match the ``key`` given to each column (defaults to the name) unless
+ ``link_to_name`` is True.
+
+ :param refcolumns: A sequence of foreign column names or Column
+ objects. The columns must all be located within the same Table.
+
+ :param name: Optional, the in-database name of the key.
+
+ :param onupdate: Optional string. If set, emit ON UPDATE <value> when
+ issuing DDL for this constraint. Typical values include CASCADE,
+          SET NULL and RESTRICT.
+
+ :param ondelete: Optional string. If set, emit ON DELETE <value> when
+ issuing DDL for this constraint. Typical values include CASCADE,
+          SET NULL and RESTRICT.
+
+ :param deferrable: Optional bool. If set, emit DEFERRABLE or NOT
+ DEFERRABLE when issuing DDL for this constraint.
+
+ :param initially: Optional string. If set, emit INITIALLY <value> when
+ issuing DDL for this constraint.
+
+        :param link_to_name: if True, the string name given in ``refcolumns`` is
+ the rendered name of the referenced column, not its locally assigned
+ ``key``.
+
+ :param use_alter: If True, do not emit the DDL for this constraint as
+ part of the CREATE TABLE definition. Instead, generate it via an
+ ALTER TABLE statement issued after the full collection of tables
+ have been created, and drop it via an ALTER TABLE statement before
+ the full collection of tables are dropped.
+
+ The use of :paramref:`_schema.ForeignKeyConstraint.use_alter` is
+ particularly geared towards the case where two or more tables
+ are established within a mutually-dependent foreign key constraint
+ relationship; however, the :meth:`_schema.MetaData.create_all` and
+ :meth:`_schema.MetaData.drop_all`
+ methods will perform this resolution
+ automatically, so the flag is normally not needed.
+
+ .. versionchanged:: 1.0.0 Automatic resolution of foreign key
+ cycles has been added, removing the need to use the
+ :paramref:`_schema.ForeignKeyConstraint.use_alter` in typical use
+ cases.
+
+ .. seealso::
+
+ :ref:`use_alter`
+
+ :param match: Optional string. If set, emit MATCH <value> when issuing
+ DDL for this constraint. Typical values include SIMPLE, PARTIAL
+ and FULL.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.SchemaItem.info` attribute of this object.
+
+ .. versionadded:: 1.0.0
+
+ :param \**dialect_kw: Additional keyword arguments are dialect
+ specific, and passed in the form ``<dialectname>_<argname>``. See
+ the documentation regarding an individual dialect at
+ :ref:`dialect_toplevel` for detail on documented arguments.
+
+ .. versionadded:: 0.9.2
+
+ """
+
+ Constraint.__init__(
+ self,
+ name=name,
+ deferrable=deferrable,
+ initially=initially,
+ info=info,
+ **dialect_kw
+ )
+ self.onupdate = onupdate
+ self.ondelete = ondelete
+ self.link_to_name = link_to_name
+ self.use_alter = use_alter
+ self.match = match
+
+ if len(set(columns)) != len(refcolumns):
+ if len(set(columns)) != len(columns):
+ # e.g. FOREIGN KEY (a, a) REFERENCES r (b, c)
+ raise exc.ArgumentError(
+ "ForeignKeyConstraint with duplicate source column "
+ "references are not supported."
+ )
+ else:
+ # e.g. FOREIGN KEY (a) REFERENCES r (b, c)
+ # paraphrasing
+ # https://www.postgresql.org/docs/current/static/ddl-constraints.html
+ raise exc.ArgumentError(
+ "ForeignKeyConstraint number "
+ "of constrained columns must match the number of "
+ "referenced columns."
+ )
+
+ # standalone ForeignKeyConstraint - create
+ # associated ForeignKey objects which will be applied to hosted
+ # Column objects (in col.foreign_keys), either now or when attached
+ # to the Table for string-specified names
+ self.elements = [
+ ForeignKey(
+ refcol,
+ _constraint=self,
+ name=self.name,
+ onupdate=self.onupdate,
+ ondelete=self.ondelete,
+ use_alter=self.use_alter,
+ link_to_name=self.link_to_name,
+ match=self.match,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ **self.dialect_kwargs
+ )
+ for refcol in refcolumns
+ ]
+
+ ColumnCollectionMixin.__init__(self, *columns)
+ if table is not None:
+ if hasattr(self, "parent"):
+ assert table is self.parent
+ self._set_parent_with_dispatch(table)
+
+ def _append_element(self, column, fk):
+ self.columns.add(column)
+ self.elements.append(fk)
+
+ columns = None
+ """A :class:`_expression.ColumnCollection` representing the set of columns
+ for this constraint.
+
+ """
+
+ elements = None
+ """A sequence of :class:`_schema.ForeignKey` objects.
+
+ Each :class:`_schema.ForeignKey`
+ represents a single referring column/referred
+ column pair.
+
+ This collection is intended to be read-only.
+
+ """
+
+ @property
+ def _elements(self):
+ # legacy - provide a dictionary view of (column_key, fk)
+ return util.OrderedDict(zip(self.column_keys, self.elements))
+
+ @property
+ def _referred_schema(self):
+ for elem in self.elements:
+ return elem._referred_schema
+ else:
+ return None
+
+ @property
+ def referred_table(self):
+ """The :class:`_schema.Table` object to which this
+        :class:`_schema.ForeignKeyConstraint` refers.
+
+ This is a dynamically calculated attribute which may not be available
+ if the constraint and/or parent table is not yet associated with
+ a metadata collection that contains the referred table.
+
+ .. versionadded:: 1.0.0
+
+ """
+ return self.elements[0].column.table
+
+ def _validate_dest_table(self, table):
+ table_keys = set([elem._table_key() for elem in self.elements])
+ if None not in table_keys and len(table_keys) > 1:
+ elem0, elem1 = sorted(table_keys)[0:2]
+ raise exc.ArgumentError(
+ "ForeignKeyConstraint on %s(%s) refers to "
+ "multiple remote tables: %s and %s"
+ % (table.fullname, self._col_description, elem0, elem1)
+ )
+
+ @property
+ def column_keys(self):
+ """Return a list of string keys representing the local
+ columns in this :class:`_schema.ForeignKeyConstraint`.
+
+ This list is either the original string arguments sent
+ to the constructor of the :class:`_schema.ForeignKeyConstraint`,
+ or if the constraint has been initialized with :class:`_schema.Column`
+ objects, is the string ``.key`` of each element.
+
+ .. versionadded:: 1.0.0
+
+ """
+ if hasattr(self, "parent"):
+ return self.columns.keys()
+ else:
+ return [
+ col.key if isinstance(col, ColumnElement) else str(col)
+ for col in self._pending_colargs
+ ]
+
+ @property
+ def _col_description(self):
+ return ", ".join(self.column_keys)
+
+ def _set_parent(self, table, **kw):
+ Constraint._set_parent(self, table)
+
+ try:
+ ColumnCollectionConstraint._set_parent(self, table)
+ except KeyError as ke:
+ util.raise_(
+ exc.ArgumentError(
+ "Can't create ForeignKeyConstraint "
+ "on table '%s': no column "
+ "named '%s' is present." % (table.description, ke.args[0])
+ ),
+ from_=ke,
+ )
+
+ for col, fk in zip(self.columns, self.elements):
+ if not hasattr(fk, "parent") or fk.parent is not col:
+ fk._set_parent_with_dispatch(col)
+
+ self._validate_dest_table(table)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.ForeignKeyConstraint.copy` method is deprecated "
+ "and will be removed in a future release.",
+ )
+ def copy(self, schema=None, target_table=None, **kw):
+ return self._copy(schema=schema, target_table=target_table, **kw)
+
+ def _copy(self, schema=None, target_table=None, **kw):
+ fkc = ForeignKeyConstraint(
+ [x.parent.key for x in self.elements],
+ [
+ x._get_colspec(
+ schema=schema,
+ table_name=target_table.name
+ if target_table is not None
+ and x._table_key() == x.parent.table.key
+ else None,
+ )
+ for x in self.elements
+ ],
+ name=self.name,
+ onupdate=self.onupdate,
+ ondelete=self.ondelete,
+ use_alter=self.use_alter,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ link_to_name=self.link_to_name,
+ match=self.match,
+ )
+ for self_fk, other_fk in zip(self.elements, fkc.elements):
+ self_fk._schema_item_copy(other_fk)
+ return self._schema_item_copy(fkc)
+
+
+class PrimaryKeyConstraint(ColumnCollectionConstraint):
+ """A table-level PRIMARY KEY constraint.
+
+ The :class:`.PrimaryKeyConstraint` object is present automatically
+ on any :class:`_schema.Table` object; it is assigned a set of
+ :class:`_schema.Column` objects corresponding to those marked with
+ the :paramref:`_schema.Column.primary_key` flag::
+
+ >>> my_table = Table('mytable', metadata,
+ ... Column('id', Integer, primary_key=True),
+ ... Column('version_id', Integer, primary_key=True),
+ ... Column('data', String(50))
+ ... )
+ >>> my_table.primary_key
+ PrimaryKeyConstraint(
+ Column('id', Integer(), table=<mytable>,
+ primary_key=True, nullable=False),
+ Column('version_id', Integer(), table=<mytable>,
+ primary_key=True, nullable=False)
+ )
+
+ The primary key of a :class:`_schema.Table` can also be specified by using
+ a :class:`.PrimaryKeyConstraint` object explicitly; in this mode of usage,
+ the "name" of the constraint can also be specified, as well as other
+ options which may be recognized by dialects::
+
+ my_table = Table('mytable', metadata,
+ Column('id', Integer),
+ Column('version_id', Integer),
+ Column('data', String(50)),
+ PrimaryKeyConstraint('id', 'version_id',
+ name='mytable_pk')
+ )
+
+ The two styles of column-specification should generally not be mixed.
+    A warning is emitted if the columns present in the
+ :class:`.PrimaryKeyConstraint`
+ don't match the columns that were marked as ``primary_key=True``, if both
+ are present; in this case, the columns are taken strictly from the
+ :class:`.PrimaryKeyConstraint` declaration, and those columns otherwise
+ marked as ``primary_key=True`` are ignored. This behavior is intended to
+ be backwards compatible with previous behavior.
+
+ .. versionchanged:: 0.9.2 Using a mixture of columns within a
+ :class:`.PrimaryKeyConstraint` in addition to columns marked as
+ ``primary_key=True`` now emits a warning if the lists don't match.
+ The ultimate behavior of ignoring those columns marked with the flag
+ only is currently maintained for backwards compatibility; this warning
+ may raise an exception in a future release.
+
+ For the use case where specific options are to be specified on the
+ :class:`.PrimaryKeyConstraint`, but the usual style of using
+ ``primary_key=True`` flags is still desirable, an empty
+ :class:`.PrimaryKeyConstraint` may be specified, which will take on the
+ primary key column collection from the :class:`_schema.Table` based on the
+ flags::
+
+ my_table = Table('mytable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('version_id', Integer, primary_key=True),
+ Column('data', String(50)),
+ PrimaryKeyConstraint(name='mytable_pk',
+ mssql_clustered=True)
+ )
+
+ .. versionadded:: 0.9.2 an empty :class:`.PrimaryKeyConstraint` may now
+ be specified for the purposes of establishing keyword arguments with
+ the constraint, independently of the specification of "primary key"
+ columns within the :class:`_schema.Table` itself; columns marked as
+ ``primary_key=True`` will be gathered into the empty constraint's
+ column collection.
+
+ """
+
+ __visit_name__ = "primary_key_constraint"
+
+ def __init__(self, *columns, **kw):
+ self._implicit_generated = kw.pop("_implicit_generated", False)
+ super(PrimaryKeyConstraint, self).__init__(*columns, **kw)
+
+ def _set_parent(self, table, **kw):
+ super(PrimaryKeyConstraint, self)._set_parent(table)
+
+ if table.primary_key is not self:
+ table.constraints.discard(table.primary_key)
+ table.primary_key = self
+ table.constraints.add(self)
+
+ table_pks = [c for c in table.c if c.primary_key]
+ if self.columns and table_pks and set(table_pks) != set(self.columns):
+ util.warn(
+ "Table '%s' specifies columns %s as primary_key=True, "
+ "not matching locally specified columns %s; setting the "
+ "current primary key columns to %s. This warning "
+ "may become an exception in a future release"
+ % (
+ table.name,
+ ", ".join("'%s'" % c.name for c in table_pks),
+ ", ".join("'%s'" % c.name for c in self.columns),
+ ", ".join("'%s'" % c.name for c in self.columns),
+ )
+ )
+ table_pks[:] = []
+
+ for c in self.columns:
+ c.primary_key = True
+ if c._user_defined_nullable is NULL_UNSPECIFIED:
+ c.nullable = False
+ if table_pks:
+ self.columns.extend(table_pks)
+
+ def _reload(self, columns):
+ """repopulate this :class:`.PrimaryKeyConstraint` given
+ a set of columns.
+
+ Existing columns in the table that are marked as primary_key=True
+ are maintained.
+
+ Also fires a new event.
+
+ This is basically like putting a whole new
+ :class:`.PrimaryKeyConstraint` object on the parent
+ :class:`_schema.Table` object without actually replacing the object.
+
+ The ordering of the given list of columns is also maintained; these
+ columns will be appended to the list of columns after any which
+ are already present.
+
+ """
+ # set the primary key flag on new columns.
+ # note any existing PK cols on the table also have their
+ # flag still set.
+ for col in columns:
+ col.primary_key = True
+
+ self.columns.extend(columns)
+
+ PrimaryKeyConstraint._autoincrement_column._reset(self)
+ self._set_parent_with_dispatch(self.table)
+
+ def _replace(self, col):
+ PrimaryKeyConstraint._autoincrement_column._reset(self)
+ self.columns.replace(col)
+
+ self.dispatch._sa_event_column_added_to_pk_constraint(self, col)
+
+ @property
+ def columns_autoinc_first(self):
+ autoinc = self._autoincrement_column
+
+ if autoinc is not None:
+ return [autoinc] + [c for c in self.columns if c is not autoinc]
+ else:
+ return list(self.columns)
+
+ @util.memoized_property
+ def _autoincrement_column(self):
+ def _validate_autoinc(col, autoinc_true):
+ if col.type._type_affinity is None or not issubclass(
+ col.type._type_affinity,
+ (
+ type_api.INTEGERTYPE._type_affinity,
+ type_api.NUMERICTYPE._type_affinity,
+ ),
+ ):
+ if autoinc_true:
+ raise exc.ArgumentError(
+ "Column type %s on column '%s' is not "
+ "compatible with autoincrement=True" % (col.type, col)
+ )
+ else:
+ return False
+ elif (
+ not isinstance(col.default, (type(None), Sequence))
+ and not autoinc_true
+ ):
+ return False
+ elif (
+ col.server_default is not None
+ and not isinstance(col.server_default, Identity)
+ and not autoinc_true
+ ):
+ return False
+ elif col.foreign_keys and col.autoincrement not in (
+ True,
+ "ignore_fk",
+ ):
+ return False
+ return True
+
+ if len(self.columns) == 1:
+ col = list(self.columns)[0]
+
+ if col.autoincrement is True:
+ _validate_autoinc(col, True)
+ return col
+ elif (
+ col.autoincrement
+ in (
+ "auto",
+ "ignore_fk",
+ )
+ and _validate_autoinc(col, False)
+ ):
+ return col
+
+ else:
+ autoinc = None
+ for col in self.columns:
+ if col.autoincrement is True:
+ _validate_autoinc(col, True)
+ if autoinc is not None:
+ raise exc.ArgumentError(
+ "Only one Column may be marked "
+ "autoincrement=True, found both %s and %s."
+ % (col.name, autoinc.name)
+ )
+ else:
+ autoinc = col
+
+ return autoinc
+
+
+class UniqueConstraint(ColumnCollectionConstraint):
+ """A table-level UNIQUE constraint.
+
+ Defines a single column or composite UNIQUE constraint. For a no-frills,
+ single column constraint, adding ``unique=True`` to the ``Column``
+ definition is a shorthand equivalent for an unnamed, single column
+ UniqueConstraint.
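+
+    E.g., a minimal sketch using hypothetical table and column names::
+
+        Table("user", metadata,
+            Column("name", String(30)),
+            Column("email", String(60)),
+            UniqueConstraint("name", "email", name="uq_user_name_email")
+        )
+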
+ """
+
+ __visit_name__ = "unique_constraint"
+
+
+class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
+ """A table-level INDEX.
+
+ Defines a composite (one or more column) INDEX.
+
+ E.g.::
+
+ sometable = Table("sometable", metadata,
+ Column("name", String(50)),
+ Column("address", String(100))
+ )
+
+ Index("some_index", sometable.c.name)
+
+ For a no-frills, single column index, adding
+ :class:`_schema.Column` also supports ``index=True``::
+
+ sometable = Table("sometable", metadata,
+ Column("name", String(50), index=True)
+ )
+
+ For a composite index, multiple columns can be specified::
+
+ Index("some_index", sometable.c.name, sometable.c.address)
+
+ Functional indexes are supported as well, typically by using the
+ :data:`.func` construct in conjunction with table-bound
+ :class:`_schema.Column` objects::
+
+ Index("some_index", func.lower(sometable.c.name))
+
+ An :class:`.Index` can also be manually associated with a
+ :class:`_schema.Table`,
+ either through inline declaration or using
+ :meth:`_schema.Table.append_constraint`. When this approach is used,
+ the names
+ of the indexed columns can be specified as strings::
+
+ Table("sometable", metadata,
+ Column("name", String(50)),
+ Column("address", String(100)),
+ Index("some_index", "name", "address")
+ )
+
+ To support functional or expression-based indexes in this form, the
+ :func:`_expression.text` construct may be used::
+
+ from sqlalchemy import text
+
+ Table("sometable", metadata,
+ Column("name", String(50)),
+ Column("address", String(100)),
+ Index("some_index", text("lower(name)"))
+ )
+
+ .. versionadded:: 0.9.5 the :func:`_expression.text`
+ construct may be used to
+ specify :class:`.Index` expressions, provided the :class:`.Index`
+ is explicitly associated with the :class:`_schema.Table`.
+
+
+ .. seealso::
+
+ :ref:`schema_indexes` - General information on :class:`.Index`.
+
+ :ref:`postgresql_indexes` - PostgreSQL-specific options available for
+ the :class:`.Index` construct.
+
+ :ref:`mysql_indexes` - MySQL-specific options available for the
+ :class:`.Index` construct.
+
+ :ref:`mssql_indexes` - MSSQL-specific options available for the
+ :class:`.Index` construct.
+
+ """
+
+ __visit_name__ = "index"
+
+ def __init__(self, name, *expressions, **kw):
+ r"""Construct an index object.
+
+ :param name:
+ The name of the index
+
+ :param \*expressions:
+ Column expressions to include in the index. The expressions
+ are normally instances of :class:`_schema.Column`, but may also
+ be arbitrary SQL expressions which ultimately refer to a
+ :class:`_schema.Column`.
+
+ :param unique=False:
+ Keyword only argument; if True, create a unique index.
+
+ :param quote=None:
+ Keyword only argument; whether to apply quoting to the name of
+ the index. Works in the same manner as that of
+ :paramref:`_schema.Column.quote`.
+
+ :param info=None: Optional data dictionary which will be populated
+ into the :attr:`.SchemaItem.info` attribute of this object.
+
+ .. versionadded:: 1.0.0
+
+ :param \**kw: Additional keyword arguments not mentioned above are
+ dialect specific, and passed in the form
+ ``<dialectname>_<argname>``. See the documentation regarding an
+ individual dialect at :ref:`dialect_toplevel` for detail on
+ documented arguments.
+
+ """
+ self.table = table = None
+
+ self.name = quoted_name(name, kw.pop("quote", None))
+ self.unique = kw.pop("unique", False)
+ _column_flag = kw.pop("_column_flag", False)
+ if "info" in kw:
+ self.info = kw.pop("info")
+
+ # TODO: consider "table" argument being public, but for
+ # the purpose of the fix here, it starts as private.
+ if "_table" in kw:
+ table = kw.pop("_table")
+
+ self._validate_dialect_kwargs(kw)
+
+ self.expressions = []
+ # will call _set_parent() if table-bound column
+ # objects are present
+ ColumnCollectionMixin.__init__(
+ self,
+ *expressions,
+ _column_flag=_column_flag,
+ _gather_expressions=self.expressions
+ )
+
+ if table is not None:
+ self._set_parent(table)
+
+ def _set_parent(self, table, **kw):
+ ColumnCollectionMixin._set_parent(self, table)
+
+ if self.table is not None and table is not self.table:
+ raise exc.ArgumentError(
+ "Index '%s' is against table '%s', and "
+ "cannot be associated with table '%s'."
+ % (self.name, self.table.description, table.description)
+ )
+ self.table = table
+ table.indexes.add(self)
+
+ expressions = self.expressions
+ col_expressions = self._col_expressions(table)
+ assert len(expressions) == len(col_expressions)
+ self.expressions = [
+ expr if isinstance(expr, ClauseElement) else colexpr
+ for expr, colexpr in zip(expressions, col_expressions)
+ ]
+
+ @property
+ def bind(self):
+ """Return the connectable associated with this Index."""
+
+ return self.table.bind
+
+ def create(self, bind=None, checkfirst=False):
+ """Issue a ``CREATE`` statement for this
+ :class:`.Index`, using the given :class:`.Connectable`
+ for connectivity.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ .. seealso::
+
+ :meth:`_schema.MetaData.create_all`.
+
+ """
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst)
+ return self
+
+ def drop(self, bind=None, checkfirst=False):
+ """Issue a ``DROP`` statement for this
+ :class:`.Index`, using the given :class:`.Connectable`
+ for connectivity.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ .. seealso::
+
+ :meth:`_schema.MetaData.drop_all`.
+
+ """
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst)
+
+ def __repr__(self):
+ return "Index(%s)" % (
+ ", ".join(
+ [repr(self.name)]
+ + [repr(e) for e in self.expressions]
+ + (self.unique and ["unique=True"] or [])
+ )
+ )
+
+
+DEFAULT_NAMING_CONVENTION = util.immutabledict({"ix": "ix_%(column_0_label)s"})
+
+
+class MetaData(SchemaItem):
+ """A collection of :class:`_schema.Table`
+ objects and their associated schema
+ constructs.
+
+ Holds a collection of :class:`_schema.Table` objects as well as
+ an optional binding to an :class:`_engine.Engine` or
+ :class:`_engine.Connection`. If bound, the :class:`_schema.Table` objects
+ in the collection and their columns may participate in implicit SQL
+ execution.
+
+ The :class:`_schema.Table` objects themselves are stored in the
+ :attr:`_schema.MetaData.tables` dictionary.
+
+ :class:`_schema.MetaData` is a thread-safe object for read operations.
+ Construction of new tables within a single :class:`_schema.MetaData`
+ object,
+ either explicitly or via reflection, may not be completely thread-safe.
+
+ .. seealso::
+
+ :ref:`metadata_describing` - Introduction to database metadata
+
+ """
+
+ __visit_name__ = "metadata"
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_schema.MetaData.bind` argument is deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(
+ self,
+ bind=None,
+ schema=None,
+ quote_schema=None,
+ naming_convention=None,
+ info=None,
+ ):
+ """Create a new MetaData object.
+
+ :param bind:
+ An Engine or Connection to bind to. May also be a string or URL
+ instance, these are passed to :func:`_sa.create_engine` and
+ this :class:`_schema.MetaData` will
+ be bound to the resulting engine.
+
+ :param schema:
+ The default schema to use for the :class:`_schema.Table`,
+ :class:`.Sequence`, and potentially other objects associated with
+ this :class:`_schema.MetaData`. Defaults to ``None``.
+
+ .. seealso::
+
+ :ref:`schema_metadata_schema_name` - details on how the
+ :paramref:`_schema.MetaData.schema` parameter is used.
+
+ :paramref:`_schema.Table.schema`
+
+ :paramref:`.Sequence.schema`
+
+ :param quote_schema:
+ Sets the ``quote_schema`` flag for those :class:`_schema.Table`,
+ :class:`.Sequence`, and other objects which make usage of the
+ local ``schema`` name.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.SchemaItem.info` attribute of this object.
+
+ .. versionadded:: 1.0.0
+
+ :param naming_convention: a dictionary referring to values which
+ will establish default naming conventions for :class:`.Constraint`
+ and :class:`.Index` objects, for those objects which are not given
+ a name explicitly.
+
+ The keys of this dictionary may be:
+
+ * a constraint or Index class, e.g. the :class:`.UniqueConstraint`,
+ :class:`_schema.ForeignKeyConstraint` class, the :class:`.Index`
+ class
+
+ * a string mnemonic for one of the known constraint classes;
+ ``"fk"``, ``"pk"``, ``"ix"``, ``"ck"``, ``"uq"`` for foreign key,
+ primary key, index, check, and unique constraint, respectively.
+
+ * the string name of a user-defined "token" that can be used
+ to define new naming tokens.
+
+ The values associated with each "constraint class" or "constraint
+ mnemonic" key are string naming templates, such as
+ ``"uq_%(table_name)s_%(column_0_name)s"``,
+ which describe how the name should be composed. The values
+ associated with user-defined "token" keys should be callables of the
+ form ``fn(constraint, table)``, which accepts the constraint/index
+ object and :class:`_schema.Table` as arguments, returning a string
+ result.
+
+ The built-in names are as follows, some of which may only be
+ available for certain types of constraint:
+
+ * ``%(table_name)s`` - the name of the :class:`_schema.Table`
+ object
+ associated with the constraint.
+
+ * ``%(referred_table_name)s`` - the name of the
+ :class:`_schema.Table`
+ object associated with the referencing target of a
+ :class:`_schema.ForeignKeyConstraint`.
+
+ * ``%(column_0_name)s`` - the name of the :class:`_schema.Column`
+ at
+ index position "0" within the constraint.
+
+ * ``%(column_0N_name)s`` - the name of all :class:`_schema.Column`
+ objects in order within the constraint, joined without a
+ separator.
+
+ * ``%(column_0_N_name)s`` - the name of all
+ :class:`_schema.Column`
+ objects in order within the constraint, joined with an
+ underscore as a separator.
+
+ * ``%(column_0_label)s``, ``%(column_0N_label)s``,
+ ``%(column_0_N_label)s`` - the label of either the zeroth
+ :class:`_schema.Column` or all :class:`.Columns`, separated with
+ or without an underscore
+
+ * ``%(column_0_key)s``, ``%(column_0N_key)s``,
+ ``%(column_0_N_key)s`` - the key of either the zeroth
+ :class:`_schema.Column` or all :class:`.Columns`, separated with
+ or without an underscore
+
+ * ``%(referred_column_0_name)s``, ``%(referred_column_0N_name)s``,
+ ``%(referred_column_0_N_name)s``, ``%(referred_column_0_key)s``,
+ ``%(referred_column_0N_key)s``, ... column tokens which
+ render the names/keys/labels of columns that are referenced
+ by a :class:`_schema.ForeignKeyConstraint`.
+
+ * ``%(constraint_name)s`` - a special key that refers to the
+ existing name given to the constraint. When this key is
+ present, the :class:`.Constraint` object's existing name will be
+ replaced with one that is composed from the template string that
+ uses this token. When this token is present, it is required that
+ the :class:`.Constraint` is given an explicit name ahead of time.
+
+ * user-defined: any additional token may be implemented by passing
+ it along with a ``fn(constraint, table)`` callable to the
+ naming_convention dictionary.
+
+ .. versionadded:: 1.3.0 - added new ``%(column_0N_name)s``,
+ ``%(column_0_N_name)s``, and related tokens that produce
+ concatenations of names, keys, or labels for all columns referred
+ to by a given constraint.
+
+ .. seealso::
+
+ :ref:`constraint_naming_conventions` - for detailed usage
+ examples.
+
+ """
+ self.tables = util.FacadeDict()
+ self.schema = quoted_name(schema, quote_schema)
+ self.naming_convention = (
+ naming_convention
+ if naming_convention
+ else DEFAULT_NAMING_CONVENTION
+ )
+ if info:
+ self.info = info
+ self._schemas = set()
+ self._sequences = {}
+ self._fk_memos = collections.defaultdict(list)
+
+ self.bind = bind
+
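+ # Usage sketch: a naming_convention dictionary covering the five built-in
+ # constraint mnemonics; the template strings are illustrative choices,
+ # not requirements.
+ #
+ #     metadata_obj = MetaData(naming_convention={
+ #         "ix": "ix_%(column_0_label)s",
+ #         "uq": "uq_%(table_name)s_%(column_0_name)s",
+ #         "ck": "ck_%(table_name)s_%(constraint_name)s",
+ #         "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+ #         "pk": "pk_%(table_name)s",
+ #     })
+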
+ tables = None
+ """A dictionary of :class:`_schema.Table`
+ objects keyed to their name or "table key".
+
+ The exact key is that determined by the :attr:`_schema.Table.key`
+ attribute;
+ for a table with no :attr:`_schema.Table.schema` attribute,
+ this is the same
+ as :attr:`_schema.Table.name`. For a table with a schema,
+ it is typically of the
+ form ``schemaname.tablename``.
+
+ .. seealso::
+
+ :attr:`_schema.MetaData.sorted_tables`
+
+ """
+
+ def __repr__(self):
+ if self.bind:
+ return "MetaData(bind=%r)" % self.bind
+ else:
+ return "MetaData()"
+
+ def __contains__(self, table_or_key):
+ if not isinstance(table_or_key, util.string_types):
+ table_or_key = table_or_key.key
+ return table_or_key in self.tables
+
+ def _add_table(self, name, schema, table):
+ key = _get_table_key(name, schema)
+ self.tables._insert_item(key, table)
+ if schema:
+ self._schemas.add(schema)
+
+ def _remove_table(self, name, schema):
+ key = _get_table_key(name, schema)
+ removed = dict.pop(self.tables, key, None)
+ if removed is not None:
+ for fk in removed.foreign_keys:
+ fk._remove_from_metadata(self)
+ if self._schemas:
+ self._schemas = set(
+ [
+ t.schema
+ for t in self.tables.values()
+ if t.schema is not None
+ ]
+ )
+
+ def __getstate__(self):
+ return {
+ "tables": self.tables,
+ "schema": self.schema,
+ "schemas": self._schemas,
+ "sequences": self._sequences,
+ "fk_memos": self._fk_memos,
+ "naming_convention": self.naming_convention,
+ }
+
+ def __setstate__(self, state):
+ self.tables = state["tables"]
+ self.schema = state["schema"]
+ self.naming_convention = state["naming_convention"]
+ self._bind = None
+ self._sequences = state["sequences"]
+ self._schemas = state["schemas"]
+ self._fk_memos = state["fk_memos"]
+
+ def is_bound(self):
+ """True if this MetaData is bound to an Engine or Connection."""
+
+ return self._bind is not None
+
+ def bind(self):
+ """An :class:`_engine.Engine` or :class:`_engine.Connection`
+ to which this
+ :class:`_schema.MetaData` is bound.
+
+ Typically, a :class:`_engine.Engine` is assigned to this attribute
+ so that "implicit execution" may be used, or alternatively
+ as a means of providing engine binding information to an
+ ORM :class:`.Session` object::
+
+ engine = create_engine("someurl://")
+ metadata.bind = engine
+
+ .. deprecated:: 1.4
+
+ The metadata.bind attribute, as part of the deprecated system
+ of "implicit execution", is itself deprecated and will be
+ removed in SQLAlchemy 2.0.
+
+ .. seealso::
+
+ :ref:`dbengine_implicit` - background on "bound metadata"
+
+ """
+ return self._bind
+
+ @util.preload_module("sqlalchemy.engine.url")
+ def _bind_to(self, bind):
+ """Bind this MetaData to an Engine, Connection, string or URL."""
+ url = util.preloaded.engine_url
+ if isinstance(bind, util.string_types + (url.URL,)):
+ self._bind = sqlalchemy.create_engine(bind)
+ else:
+ self._bind = bind
+
+ bind = property(bind, _bind_to)
+
+ def clear(self):
+ """Clear all Table objects from this MetaData."""
+
+ dict.clear(self.tables)
+ self._schemas.clear()
+ self._fk_memos.clear()
+
+ def remove(self, table):
+ """Remove the given Table object from this MetaData."""
+
+ self._remove_table(table.name, table.schema)
+
+ @property
+ def sorted_tables(self):
+ """Returns a list of :class:`_schema.Table` objects sorted in order of
+ foreign key dependency.
+
+ The sorting will place :class:`_schema.Table`
+ objects that have dependencies
+ first, before the dependencies themselves, representing the
+ order in which they can be created. To get the order in which
+ the tables would be dropped, use the ``reversed()`` Python built-in.
+
+ .. warning::
+
+ The :attr:`.MetaData.sorted_tables` attribute cannot by itself
+ accommodate automatic resolution of dependency cycles between
+ tables, which are usually caused by mutually dependent foreign key
+ constraints. When these cycles are detected, the foreign keys
+ of these tables are omitted from consideration in the sort.
+ A warning is emitted when this condition occurs; this will become
+ an exception in a future release. Tables which are not part
+ of the cycle will still be returned in dependency order.
+
+ To resolve these cycles, the
+ :paramref:`_schema.ForeignKeyConstraint.use_alter` parameter may be
+ applied to those constraints which create a cycle. Alternatively,
+ the :func:`_schema.sort_tables_and_constraints` function will
+ automatically return foreign key constraints in a separate
+ collection when cycles are detected so that they may be applied
+ to a schema separately.
+
+ .. versionchanged:: 1.3.17 - a warning is emitted when
+ :attr:`.MetaData.sorted_tables` cannot perform a proper sort
+ due to cyclical dependencies. This will be an exception in a
+ future release. Additionally, the sort will continue to return
+ other tables not involved in the cycle in dependency order which
+ was not the case previously.
+
+ .. seealso::
+
+ :func:`_schema.sort_tables`
+
+ :func:`_schema.sort_tables_and_constraints`
+
+ :attr:`_schema.MetaData.tables`
+
+ :meth:`_reflection.Inspector.get_table_names`
+
+ :meth:`_reflection.Inspector.get_sorted_table_and_fkc_names`
+
+
+ """
+ return ddl.sort_tables(
+ sorted(self.tables.values(), key=lambda t: t.key)
+ )
+
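+ # Usage sketch: iterate tables in creation order, and in reverse for
+ # drop order; ``metadata_obj`` is a hypothetical, populated MetaData.
+ #
+ #     for table in metadata_obj.sorted_tables:
+ #         print("create order:", table.name)
+ #     for table in reversed(metadata_obj.sorted_tables):
+ #         print("drop order:", table.name)
+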
+ def reflect(
+ self,
+ bind=None,
+ schema=None,
+ views=False,
+ only=None,
+ extend_existing=False,
+ autoload_replace=True,
+ resolve_fks=True,
+ **dialect_kwargs
+ ):
+ r"""Load all available table definitions from the database.
+
+ Automatically creates ``Table`` entries in this ``MetaData`` for any
+ table available in the database but not yet present in the
+ ``MetaData``. May be called multiple times to pick up tables recently
+ added to the database, however no special action is taken if a table
+ in this ``MetaData`` no longer exists in the database.
+
+ :param bind:
+ A :class:`.Connectable` used to access the database; if None, uses
+ the existing bind on this ``MetaData``, if any.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ :param schema:
+ Optional, query and reflect tables from an alternate schema.
+ If None, the schema associated with this :class:`_schema.MetaData`
+ is used, if any.
+
+ :param views:
+ If True, also reflect views.
+
+ :param only:
+ Optional. Load only a sub-set of available named tables. May be
+ specified as a sequence of names or a callable.
+
+ If a sequence of names is provided, only those tables will be
+ reflected. An error is raised if a table is requested but not
+ available. Named tables already present in this ``MetaData`` are
+ ignored.
+
+ If a callable is provided, it will be used as a boolean predicate to
+ filter the list of potential table names. The callable is called
+ with a table name and this ``MetaData`` instance as positional
+ arguments and should return a true value for any table to reflect.
+
+ :param extend_existing: Passed along to each :class:`_schema.Table` as
+ :paramref:`_schema.Table.extend_existing`.
+
+ .. versionadded:: 0.9.1
+
+ :param autoload_replace: Passed along to each :class:`_schema.Table`
+ as
+ :paramref:`_schema.Table.autoload_replace`.
+
+ .. versionadded:: 0.9.1
+
+ :param resolve_fks: if True, reflect :class:`_schema.Table`
+ objects linked
+ to :class:`_schema.ForeignKey` objects located in each
+ :class:`_schema.Table`.
+ For :meth:`_schema.MetaData.reflect`,
+ this has the effect of reflecting
+ related tables that might otherwise not be in the list of tables
+ being reflected, for example if the referenced table is in a
+ different schema or is omitted via the
+ :paramref:`.MetaData.reflect.only` parameter. When False,
+ :class:`_schema.ForeignKey` objects are not followed to the
+ :class:`_schema.Table`
+ in which they link, however if the related table is also part of the
+ list of tables that would be reflected in any case, the
+ :class:`_schema.ForeignKey` object will still resolve to its related
+ :class:`_schema.Table` after the :meth:`_schema.MetaData.reflect`
+ operation is
+ complete. Defaults to True.
+
+ .. versionadded:: 1.3.0
+
+ .. seealso::
+
+ :paramref:`_schema.Table.resolve_fks`
+
+ :param \**dialect_kwargs: Additional keyword arguments not mentioned
+ above are dialect specific, and passed in the form
+ ``<dialectname>_<argname>``. See the documentation regarding an
+ individual dialect at :ref:`dialect_toplevel` for detail on
+ documented arguments.
+
+ .. versionadded:: 0.9.2 - Added
+ :paramref:`.MetaData.reflect.**dialect_kwargs` to support
+ dialect-level reflection options for all :class:`_schema.Table`
+ objects reflected.
+
+ """
+ if bind is None:
+ bind = _bind_or_error(self)
+
+ with inspection.inspect(bind)._inspection_context() as insp:
+ reflect_opts = {
+ "autoload_with": insp,
+ "extend_existing": extend_existing,
+ "autoload_replace": autoload_replace,
+ "resolve_fks": resolve_fks,
+ "_extend_on": set(),
+ }
+
+ reflect_opts.update(dialect_kwargs)
+
+ if schema is None:
+ schema = self.schema
+
+ if schema is not None:
+ reflect_opts["schema"] = schema
+
+ available = util.OrderedSet(insp.get_table_names(schema))
+ if views:
+ available.update(insp.get_view_names(schema))
+
+ if schema is not None:
+ available_w_schema = util.OrderedSet(
+ ["%s.%s" % (schema, name) for name in available]
+ )
+ else:
+ available_w_schema = available
+
+ current = set(self.tables)
+
+ if only is None:
+ load = [
+ name
+ for name, schname in zip(available, available_w_schema)
+ if extend_existing or schname not in current
+ ]
+ elif callable(only):
+ load = [
+ name
+ for name, schname in zip(available, available_w_schema)
+ if (extend_existing or schname not in current)
+ and only(name, self)
+ ]
+ else:
+ missing = [name for name in only if name not in available]
+ if missing:
+ s = schema and (" schema '%s'" % schema) or ""
+ raise exc.InvalidRequestError(
+ "Could not reflect: requested table(s) not available "
+ "in %r%s: (%s)" % (bind.engine, s, ", ".join(missing))
+ )
+ load = [
+ name
+ for name in only
+ if extend_existing or name not in current
+ ]
+
+ for name in load:
+ try:
+ Table(name, self, **reflect_opts)
+ except exc.UnreflectableTableError as uerr:
+ util.warn("Skipping table %s: %s" % (name, uerr))
+
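+ # Usage sketch: reflect only a subset of tables using a callable "only"
+ # predicate; ``engine`` is a hypothetical Engine.
+ #
+ #     metadata_obj = MetaData()
+ #     metadata_obj.reflect(
+ #         bind=engine,
+ #         views=True,
+ #         only=lambda name, meta: name.startswith("billing_"),
+ #     )
+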
+ def create_all(self, bind=None, tables=None, checkfirst=True):
+ """Create all tables stored in this metadata.
+
+ Conditional by default, will not attempt to recreate tables already
+ present in the target database.
+
+ :param bind:
+ A :class:`.Connectable` used to access the
+ database; if None, uses the existing bind on this ``MetaData``, if
+ any.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ :param tables:
+ Optional list of ``Table`` objects, which is a subset of the total
+ tables in the ``MetaData`` (others are ignored).
+
+ :param checkfirst:
+ Defaults to True, don't issue CREATEs for tables already present
+ in the target database.
+
+ """
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(
+ ddl.SchemaGenerator, self, checkfirst=checkfirst, tables=tables
+ )
+
+ def drop_all(self, bind=None, tables=None, checkfirst=True):
+ """Drop all tables stored in this metadata.
+
+ Conditional by default, will not attempt to drop tables not present in
+ the target database.
+
+ :param bind:
+ A :class:`.Connectable` used to access the
+ database; if None, uses the existing bind on this ``MetaData``, if
+ any.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ :param tables:
+ Optional list of ``Table`` objects, which is a subset of the
+ total tables in the ``MetaData`` (others are ignored).
+
+ :param checkfirst:
+ Defaults to True, only issue DROPs for tables confirmed to be
+ present in the target database.
+
+ """
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(
+ ddl.SchemaDropper, self, checkfirst=checkfirst, tables=tables
+ )
+
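+ # Usage sketch: emitting DDL with an explicit bind, which the notes above
+ # indicate will be required in 2.0; ``engine`` and ``some_table`` are
+ # hypothetical.
+ #
+ #     metadata_obj.create_all(bind=engine, checkfirst=True)
+ #     metadata_obj.drop_all(bind=engine, tables=[some_table])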
+
+@util.deprecated_cls(
+ "1.4",
+ ":class:`.ThreadLocalMetaData` is deprecated and will be removed "
+ "in a future release.",
+ constructor="__init__",
+)
+class ThreadLocalMetaData(MetaData):
+ """A MetaData variant that presents a different ``bind`` in every thread.
+
+ Makes the ``bind`` property of the MetaData a thread-local value, allowing
+ this collection of tables to be bound to different ``Engine``
+ implementations or connections in each thread.
+
+ The ThreadLocalMetaData starts off bound to None in each thread. Binds
+ must be made explicitly by assigning to the ``bind`` property or using
+ ``connect()``. You can also re-bind dynamically multiple times per
+ thread, just like a regular ``MetaData``.
+
+ """
+
+ __visit_name__ = "metadata"
+
+ def __init__(self):
+ """Construct a ThreadLocalMetaData."""
+
+ self.context = util.threading.local()
+ self.__engines = {}
+ super(ThreadLocalMetaData, self).__init__()
+
+ def bind(self):
+ """The bound Engine or Connection for this thread.
+
+ This property may be assigned an Engine or Connection, or assigned a
+ string or URL to automatically create a basic Engine for this bind
+ with ``create_engine()``."""
+
+ return getattr(self.context, "_engine", None)
+
+ @util.preload_module("sqlalchemy.engine.url")
+ def _bind_to(self, bind):
+ """Bind to a Connectable in the caller's thread."""
+ url = util.preloaded.engine_url
+ if isinstance(bind, util.string_types + (url.URL,)):
+ try:
+ self.context._engine = self.__engines[bind]
+ except KeyError:
+ e = sqlalchemy.create_engine(bind)
+ self.__engines[bind] = e
+ self.context._engine = e
+ else:
+ # TODO: this is squirrely. we shouldn't have to hold onto engines
+ # in a case like this
+ if bind not in self.__engines:
+ self.__engines[bind] = bind
+ self.context._engine = bind
+
+ bind = property(bind, _bind_to)
+
+ def is_bound(self):
+ """True if there is a bind for this thread."""
+ return (
+ hasattr(self.context, "_engine")
+ and self.context._engine is not None
+ )
+
+ def dispose(self):
+ """Dispose all bound engines, in all thread contexts."""
+
+ for e in self.__engines.values():
+ if hasattr(e, "dispose"):
+ e.dispose()
+
+
+class Computed(FetchedValue, SchemaItem):
+ """Defines a generated column, i.e. "GENERATED ALWAYS AS" syntax.
+
+ The :class:`.Computed` construct is an inline construct added to the
+ argument list of a :class:`_schema.Column` object::
+
+ from sqlalchemy import Computed
+
+ Table('square', metadata_obj,
+ Column('side', Float, nullable=False),
+ Column('area', Float, Computed('side * side'))
+ )
+
+ See the linked documentation below for complete details.
+
+ .. versionadded:: 1.3.11
+
+ .. seealso::
+
+ :ref:`computed_ddl`
+
+ """
+
+ __visit_name__ = "computed_column"
+
+ @_document_text_coercion(
+ "sqltext", ":class:`.Computed`", ":paramref:`.Computed.sqltext`"
+ )
+ def __init__(self, sqltext, persisted=None):
+ """Construct a GENERATED ALWAYS AS DDL construct to accompany a
+ :class:`_schema.Column`.
+
+ :param sqltext:
+ A string containing the column generation expression, which will be
+ used verbatim, or a SQL expression construct, such as a
+ :func:`_expression.text`
+ object. If given as a string, the object is converted to a
+ :func:`_expression.text` object.
+
+ :param persisted:
+ Optional, controls how this column should be persisted by the
+ database. Possible values are:
+
+ * ``None``, the default, will use the default persistence
+ defined by the database.
+ * ``True``, will render ``GENERATED ALWAYS AS ... STORED``, or the
+ equivalent for the target database if supported.
+ * ``False``, will render ``GENERATED ALWAYS AS ... VIRTUAL``, or
+ the equivalent for the target database if supported.
+
+ Specifying ``True`` or ``False`` may raise an error when the DDL
+ is emitted to the target database if the database does not support
+ that persistence option. Leaving this parameter at its default
+ of ``None`` is guaranteed to succeed for all databases that support
+ ``GENERATED ALWAYS AS``.
+
+ """
+ self.sqltext = coercions.expect(roles.DDLExpressionRole, sqltext)
+ self.persisted = persisted
+ self.column = None
+
+ def _set_parent(self, parent, **kw):
+ if not isinstance(
+ parent.server_default, (type(None), Computed)
+ ) or not isinstance(parent.server_onupdate, (type(None), Computed)):
+ raise exc.ArgumentError(
+ "A generated column cannot specify a server_default or a "
+ "server_onupdate argument"
+ )
+ self.column = parent
+ parent.computed = self
+ self.column.server_onupdate = self
+ self.column.server_default = self
+
+ def _as_for_update(self, for_update):
+ return self
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.Computed.copy` method is deprecated "
+ "and will be removed in a future release.",
+ )
+ def copy(self, target_table=None, **kw):
+ return self._copy(target_table, **kw)
+
+ def _copy(self, target_table=None, **kw):
+ sqltext = _copy_expression(
+ self.sqltext,
+ self.column.table if self.column is not None else None,
+ target_table,
+ )
+ g = Computed(sqltext, persisted=self.persisted)
+
+ return self._schema_item_copy(g)
+
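+ # Usage sketch: requesting a stored generated column via persisted=True;
+ # whether STORED (or an equivalent) is accepted is backend-specific, as
+ # described in the docstring above.
+ #
+ #     Table("square", metadata_obj,
+ #         Column("side", Float, nullable=False),
+ #         Column("area", Float, Computed("side * side", persisted=True)),
+ #     )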
+
+class Identity(IdentityOptions, FetchedValue, SchemaItem):
+ """Defines an identity column, i.e. "GENERATED { ALWAYS | BY DEFAULT }
+ AS IDENTITY" syntax.
+
+ The :class:`.Identity` construct is an inline construct added to the
+ argument list of a :class:`_schema.Column` object::
+
+ from sqlalchemy import Identity
+
+ Table('foo', metadata_obj,
+ Column('id', Integer, Identity()),
+ Column('description', Text),
+ )
+
+ See the linked documentation below for complete details.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`identity_ddl`
+
+ """
+
+ __visit_name__ = "identity_column"
+
+ def __init__(
+ self,
+ always=False,
+ on_null=None,
+ start=None,
+ increment=None,
+ minvalue=None,
+ maxvalue=None,
+ nominvalue=None,
+ nomaxvalue=None,
+ cycle=None,
+ cache=None,
+ order=None,
+ ):
+ """Construct a GENERATED { ALWAYS | BY DEFAULT } AS IDENTITY DDL
+ construct to accompany a :class:`_schema.Column`.
+
+ See the :class:`.Sequence` documentation for a complete description
+ of most parameters.
+
+ .. note::
+ MSSQL supports this construct as the preferred alternative to
+ generate an IDENTITY on a column, but it uses non-standard
+ syntax that only supports :paramref:`_schema.Identity.start`
+ and :paramref:`_schema.Identity.increment`.
+ All other parameters are ignored.
+
+ :param always:
+ A boolean that indicates the type of identity column.
+ If ``False`` is specified, the default, then the user-specified
+ value takes precedence.
+ If ``True`` is specified, a user-specified value is not accepted (
+ on some backends, like PostgreSQL, OVERRIDING SYSTEM VALUE, or
+ similar, may be specified in an INSERT to override the sequence
+ value).
+ Some backends also have a default value for this parameter;
+ ``None`` can be used to omit rendering this part in the DDL. It
+ will be treated as ``False`` if a backend does not have a default
+ value.
+
+ :param on_null:
+ Set to ``True`` to specify ON NULL in conjunction with an
+ ``always=False`` identity column. This option is only supported on
+ some backends, like Oracle.
+
+ :param start: the starting index of the sequence.
+ :param increment: the increment value of the sequence.
+ :param minvalue: the minimum value of the sequence.
+ :param maxvalue: the maximum value of the sequence.
+ :param nominvalue: no minimum value of the sequence.
+ :param nomaxvalue: no maximum value of the sequence.
+ :param cycle: allows the sequence to wrap around when the maxvalue
+ or minvalue has been reached.
+ :param cache: optional integer value; number of future values in the
+ sequence which are calculated in advance.
+ :param order: optional boolean value; if true, renders the
+ ORDER keyword.
+
+ """
+ IdentityOptions.__init__(
+ self,
+ start=start,
+ increment=increment,
+ minvalue=minvalue,
+ maxvalue=maxvalue,
+ nominvalue=nominvalue,
+ nomaxvalue=nomaxvalue,
+ cycle=cycle,
+ cache=cache,
+ order=order,
+ )
+ self.always = always
+ self.on_null = on_null
+ self.column = None
+
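+ # Usage sketch: an always-generated identity column with explicit start
+ # and increment values; ``metadata_obj`` is hypothetical.
+ #
+ #     Table("data", metadata_obj,
+ #         Column("id", Integer, Identity(always=True, start=100, increment=10)),
+ #         Column("payload", Text),
+ #     )
+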
+ def _set_parent(self, parent, **kw):
+ if not isinstance(
+ parent.server_default, (type(None), Identity)
+ ) or not isinstance(parent.server_onupdate, type(None)):
+ raise exc.ArgumentError(
+ "A column with an Identity object cannot specify a "
+ "server_default or a server_onupdate argument"
+ )
+ if parent.autoincrement is False:
+ raise exc.ArgumentError(
+ "A column with an Identity object cannot specify "
+ "autoincrement=False"
+ )
+ self.column = parent
+
+ parent.identity = self
+ if parent._user_defined_nullable is NULL_UNSPECIFIED:
+ parent.nullable = False
+
+ parent.server_default = self
+
+ def _as_for_update(self, for_update):
+ return self
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.Identity.copy` method is deprecated "
+ "and will be removed in a future release.",
+ )
+ def copy(self, **kw):
+ return self._copy(**kw)
+
+ def _copy(self, **kw):
+ i = Identity(
+ always=self.always,
+ on_null=self.on_null,
+ start=self.start,
+ increment=self.increment,
+ minvalue=self.minvalue,
+ maxvalue=self.maxvalue,
+ nominvalue=self.nominvalue,
+ nomaxvalue=self.nomaxvalue,
+ cycle=self.cycle,
+ cache=self.cache,
+ order=self.order,
+ )
+
+ return self._schema_item_copy(i)
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
new file mode 100644
index 0000000..8379e1c
--- /dev/null
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -0,0 +1,6946 @@
+# sql/selectable.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""The :class:`_expression.FromClause` class of SQL expression elements,
+representing
+SQL tables and derived rowsets.
+
+"""
+
+import collections
+import itertools
+from operator import attrgetter
+
+from . import coercions
+from . import operators
+from . import roles
+from . import traversals
+from . import type_api
+from . import visitors
+from .annotation import Annotated
+from .annotation import SupportsCloneAnnotations
+from .base import _clone
+from .base import _cloned_difference
+from .base import _cloned_intersection
+from .base import _entity_namespace_key
+from .base import _expand_cloned
+from .base import _from_objects
+from .base import _generative
+from .base import _select_iterables
+from .base import CacheableOptions
+from .base import ColumnCollection
+from .base import ColumnSet
+from .base import CompileState
+from .base import DedupeColumnCollection
+from .base import Executable
+from .base import Generative
+from .base import HasCompileState
+from .base import HasMemoized
+from .base import Immutable
+from .base import prefix_anon_map
+from .coercions import _document_text_coercion
+from .elements import _anonymous_label
+from .elements import and_
+from .elements import BindParameter
+from .elements import BooleanClauseList
+from .elements import ClauseElement
+from .elements import ClauseList
+from .elements import ColumnClause
+from .elements import GroupedElement
+from .elements import Grouping
+from .elements import literal_column
+from .elements import TableValuedColumn
+from .elements import UnaryExpression
+from .visitors import InternalTraversal
+from .. import exc
+from .. import util
+from ..inspection import inspect
+
+
+class _OffsetLimitParam(BindParameter):
+ inherit_cache = True
+
+ @property
+ def _limit_offset_value(self):
+ return self.effective_value
+
+
+@util.deprecated(
+ "1.4",
+ "The standalone :func:`.subquery` function is deprecated "
+ "and will be removed in a future release. Use select().subquery().",
+)
+def subquery(alias, *args, **kwargs):
+ r"""Return an :class:`.Subquery` object derived
+ from a :class:`_expression.Select`.
+
+ :param alias: the alias name for the subquery
+
+ :param \*args, \**kwargs: all other arguments are passed through to the
+ :func:`_expression.select` function.
+
+ """
+ return Select.create_legacy_select(*args, **kwargs).subquery(alias)
+
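+ # Usage sketch: the non-deprecated spelling mentioned above, using
+ # Select.subquery(); ``user_table`` is a hypothetical Table.
+ #
+ #     subq = select(user_table).where(user_table.c.id > 5).subquery("u")
+ #     stmt = select(subq.c.id)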
+
+class ReturnsRows(roles.ReturnsRowsRole, ClauseElement):
+ """The base-most class for Core constructs that have some concept of
+ columns that can represent rows.
+
+ While the SELECT statement and TABLE are the primary things we think
+ of in this category, DML statements like INSERT, UPDATE and DELETE can
+ also specify RETURNING, which means they can be used in CTEs and other
+ forms, and PostgreSQL has functions that return rows as well.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _is_returns_rows = True
+
+ # sub-elements of returns_rows
+ _is_from_clause = False
+ _is_select_statement = False
+ _is_lateral = False
+
+ @property
+ def selectable(self):
+ return self
+
+ @property
+ def _all_selected_columns(self):
+ """A sequence of column expression objects that represents the
+ "selected" columns of this :class:`_expression.ReturnsRows`.
+
+ This is typically equivalent to .exported_columns except it is
+ delivered in the form of a straight sequence and not keyed
+ :class:`_expression.ColumnCollection`.
+
+ """
+ raise NotImplementedError()
+
+ @property
+ def exported_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ that represents the "exported"
+ columns of this :class:`_expression.ReturnsRows`.
+
+ The "exported" columns represent the collection of
+ :class:`_expression.ColumnElement`
+ expressions that are rendered by this SQL
+ construct. The primary varieties are the "FROM clause columns"
+ of a FROM clause, such as a table, join, or subquery; the
+ "SELECTed columns", which are the columns in the "columns clause"
+ of a SELECT statement; and the RETURNING columns in a DML
+ statement.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_expression.FromClause.exported_columns`
+
+ :attr:`_expression.SelectBase.exported_columns`
+ """
+
+ raise NotImplementedError()
+
+
+class Selectable(ReturnsRows):
+ """Mark a class as being selectable."""
+
+ __visit_name__ = "selectable"
+
+ is_selectable = True
+
+ def _refresh_for_new_column(self, column):
+ raise NotImplementedError()
+
+ def lateral(self, name=None):
+ """Return a LATERAL alias of this :class:`_expression.Selectable`.
+
+ The return value is the :class:`_expression.Lateral` construct also
+ provided by the top-level :func:`_expression.lateral` function.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`tutorial_lateral_correlation` - overview of usage.
+
+ """
+ return Lateral._construct(self, name)
+
+ @util.deprecated(
+ "1.4",
+ message="The :meth:`.Selectable.replace_selectable` method is "
+ "deprecated, and will be removed in a future release. Similar "
+ "functionality is available via the sqlalchemy.sql.visitors module.",
+ )
+ @util.preload_module("sqlalchemy.sql.util")
+ def replace_selectable(self, old, alias):
+ """Replace all occurrences of :class:`_expression.FromClause`
+ 'old' with the given :class:`_expression.Alias`
+ object, returning a copy of this :class:`_expression.FromClause`.
+
+ """
+ return util.preloaded.sql_util.ClauseAdapter(alias).traverse(self)
+
+ def corresponding_column(self, column, require_embedded=False):
+ """Given a :class:`_expression.ColumnElement`, return the exported
+ :class:`_expression.ColumnElement` object from the
+ :attr:`_expression.Selectable.exported_columns`
+ collection of this :class:`_expression.Selectable`
+ which corresponds to that
+ original :class:`_expression.ColumnElement` via a common ancestor
+ column.
+
+ :param column: the target :class:`_expression.ColumnElement`
+ to be matched.
+
+ :param require_embedded: only return corresponding columns for
+ the given :class:`_expression.ColumnElement`, if the given
+ :class:`_expression.ColumnElement`
+ is actually present within a sub-element
+ of this :class:`_expression.Selectable`.
+ Normally the column will match if
+ it merely shares a common ancestor with one of the exported
+ columns of this :class:`_expression.Selectable`.
+
+ .. seealso::
+
+ :attr:`_expression.Selectable.exported_columns` - the
+ :class:`_expression.ColumnCollection`
+ that is used for the operation.
+
+ :meth:`_expression.ColumnCollection.corresponding_column`
+ - implementation
+ method.
+
+ """
+
+ return self.exported_columns.corresponding_column(
+ column, require_embedded
+ )
+
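+ # Usage sketch: corresponding_column() locates the proxied version of a
+ # column within a derived selectable; ``user_table`` is hypothetical.
+ #
+ #     subq = select(user_table).subquery()
+ #     subq.corresponding_column(user_table.c.id) is subq.c.id  # True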
+
+class HasPrefixes(object):
+ _prefixes = ()
+
+ _has_prefixes_traverse_internals = [
+ ("_prefixes", InternalTraversal.dp_prefix_sequence)
+ ]
+
+ @_generative
+ @_document_text_coercion(
+ "expr",
+ ":meth:`_expression.HasPrefixes.prefix_with`",
+ ":paramref:`.HasPrefixes.prefix_with.*expr`",
+ )
+ def prefix_with(self, *expr, **kw):
+ r"""Add one or more expressions following the statement keyword, i.e.
+ SELECT, INSERT, UPDATE, or DELETE. Generative.
+
+ This is used to support backend-specific prefix keywords such as those
+ provided by MySQL.
+
+ E.g.::
+
+ stmt = table.insert().prefix_with("LOW_PRIORITY", dialect="mysql")
+
+ # MySQL 5.7 optimizer hints
+ stmt = select(table).prefix_with(
+ "/*+ BKA(t1) */", dialect="mysql")
+
+ Multiple prefixes can be specified by multiple calls
+ to :meth:`_expression.HasPrefixes.prefix_with`.
+
+ :param \*expr: textual or :class:`_expression.ClauseElement`
+ construct which
+ will be rendered following the INSERT, UPDATE, or DELETE
+ keyword.
+ :param \**kw: A single keyword 'dialect' is accepted. This is an
+ optional string dialect name which will
+ limit rendering of this prefix to only that dialect.
+
+ """
+ dialect = kw.pop("dialect", None)
+ if kw:
+ raise exc.ArgumentError(
+ "Unsupported argument(s): %s" % ",".join(kw)
+ )
+ self._setup_prefixes(expr, dialect)
+
+ def _setup_prefixes(self, prefixes, dialect=None):
+ self._prefixes = self._prefixes + tuple(
+ [
+ (coercions.expect(roles.StatementOptionRole, p), dialect)
+ for p in prefixes
+ ]
+ )
+
+
+class HasSuffixes(object):
+ _suffixes = ()
+
+ _has_suffixes_traverse_internals = [
+ ("_suffixes", InternalTraversal.dp_prefix_sequence)
+ ]
+
+ @_generative
+ @_document_text_coercion(
+ "expr",
+ ":meth:`_expression.HasSuffixes.suffix_with`",
+ ":paramref:`.HasSuffixes.suffix_with.*expr`",
+ )
+ def suffix_with(self, *expr, **kw):
+ r"""Add one or more expressions following the statement as a whole.
+
+ This is used to support backend-specific suffix keywords on
+ certain constructs.
+
+ E.g.::
+
+ stmt = select(col1, col2).cte().suffix_with(
+ "cycle empno set y_cycle to 1 default 0", dialect="oracle")
+
+ Multiple suffixes can be specified by multiple calls
+ to :meth:`_expression.HasSuffixes.suffix_with`.
+
+ :param \*expr: textual or :class:`_expression.ClauseElement`
+ construct which
+ will be rendered following the target clause.
+ :param \**kw: A single keyword 'dialect' is accepted. This is an
+ optional string dialect name which will
+ limit rendering of this suffix to only that dialect.
+
+ """
+ dialect = kw.pop("dialect", None)
+ if kw:
+ raise exc.ArgumentError(
+ "Unsupported argument(s): %s" % ",".join(kw)
+ )
+ self._setup_suffixes(expr, dialect)
+
+ def _setup_suffixes(self, suffixes, dialect=None):
+ self._suffixes = self._suffixes + tuple(
+ [
+ (coercions.expect(roles.StatementOptionRole, p), dialect)
+ for p in suffixes
+ ]
+ )
+
+
+class HasHints(object):
+ _hints = util.immutabledict()
+ _statement_hints = ()
+
+ _has_hints_traverse_internals = [
+ ("_statement_hints", InternalTraversal.dp_statement_hint_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ]
+
+ def with_statement_hint(self, text, dialect_name="*"):
+ """Add a statement hint to this :class:`_expression.Select` or
+ other selectable object.
+
+ This method is similar to :meth:`_expression.Select.with_hint`
+ except that
+ it does not require an individual table, and instead applies to the
+ statement as a whole.
+
+ Hints here are specific to the backend database and may include
+ directives such as isolation levels, file directives, fetch directives,
+ etc.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`_expression.Select.with_hint`
+
+ :meth:`_expression.Select.prefix_with` - generic SELECT prefixing
+ which also can suit some database-specific HINT syntaxes such as
+ MySQL optimizer hints
+
+ """
+ return self.with_hint(None, text, dialect_name)
+
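+ # Usage sketch: a statement-level hint limited to one backend; the hint
+ # text shown is illustrative and its rendering is backend-specific.
+ #
+ #     stmt = select(mytable).with_statement_hint(
+ #         "/*+ PARALLEL(4) */", dialect_name="oracle"
+ #     )
+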
+ @_generative
+ def with_hint(self, selectable, text, dialect_name="*"):
+ r"""Add an indexing or other executional context hint for the given
+ selectable to this :class:`_expression.Select` or other selectable
+ object.
+
+ The text of the hint is rendered in the appropriate
+ location for the database backend in use, relative
+ to the given :class:`_schema.Table` or :class:`_expression.Alias`
+ passed as the
+ ``selectable`` argument. The dialect implementation
+ typically uses Python string substitution syntax
+ with the token ``%(name)s`` to render the name of
+ the table or alias. E.g. when using Oracle, the
+ following::
+
+ select(mytable).\
+ with_hint(mytable, "index(%(name)s ix_mytable)")
+
+ Would render SQL as::
+
+ select /*+ index(mytable ix_mytable) */ ... from mytable
+
+ The ``dialect_name`` option will limit the rendering of a particular
+ hint to a particular backend. Such as, to add hints for both Oracle
+ and Sybase simultaneously::
+
+ select(mytable).\
+ with_hint(mytable, "index(%(name)s ix_mytable)", 'oracle').\
+ with_hint(mytable, "WITH INDEX ix_mytable", 'sybase')
+
+ .. seealso::
+
+ :meth:`_expression.Select.with_statement_hint`
+
+ """
+ if selectable is None:
+ self._statement_hints += ((dialect_name, text),)
+ else:
+ self._hints = self._hints.union(
+ {
+ (
+ coercions.expect(roles.FromClauseRole, selectable),
+ dialect_name,
+ ): text
+ }
+ )
+
+
+class FromClause(roles.AnonymizedFromClauseRole, Selectable):
+ """Represent an element that can be used within the ``FROM``
+ clause of a ``SELECT`` statement.
+
+ The most common forms of :class:`_expression.FromClause` are the
+ :class:`_schema.Table` and the :func:`_expression.select` constructs. Key
+ features common to all :class:`_expression.FromClause` objects include:
+
+ * a :attr:`.c` collection, which provides per-name access to a collection
+ of :class:`_expression.ColumnElement` objects.
+ * a :attr:`.primary_key` attribute, which is a collection of all those
+ :class:`_expression.ColumnElement`
+ objects that indicate the ``primary_key`` flag.
+ * Methods to generate various derivations of a "from" clause, including
+ :meth:`_expression.FromClause.alias`,
+ :meth:`_expression.FromClause.join`,
+ :meth:`_expression.FromClause.select`.
+
+
+ """
+
+ __visit_name__ = "fromclause"
+ named_with_column = False
+ _hide_froms = []
+
+ schema = None
+ """Define the 'schema' attribute for this :class:`_expression.FromClause`.
+
+ This is typically ``None`` for most objects except that of
+ :class:`_schema.Table`, where it is taken as the value of the
+ :paramref:`_schema.Table.schema` argument.
+
+ """
+
+ is_selectable = True
+ _is_from_clause = True
+ _is_join = False
+
+ _use_schema_map = False
+
+ @util.deprecated_params(
+ whereclause=(
+ "2.0",
+ "The :paramref:`_sql.FromClause.select().whereclause` parameter "
+ "is deprecated and will be removed in version 2.0. "
+ "Please make use of "
+ "the :meth:`.Select.where` "
+ "method to add WHERE criteria to the SELECT statement.",
+ ),
+ kwargs=(
+ "2.0",
+ "The :meth:`_sql.FromClause.select` method will no longer accept "
+ "keyword arguments in version 2.0. Please use generative methods "
+ "from the "
+ ":class:`_sql.Select` construct in order to apply additional "
+ "modifications.",
+ ),
+ )
+ def select(self, whereclause=None, **kwargs):
+ r"""Return a SELECT of this :class:`_expression.FromClause`.
+
+
+ e.g.::
+
+ stmt = some_table.select().where(some_table.c.id == 5)
+
+ :param whereclause: a WHERE clause, equivalent to calling the
+ :meth:`_sql.Select.where` method.
+
+ :param \**kwargs: additional keyword arguments are passed to the
+ legacy constructor for :class:`_sql.Select` described at
+ :meth:`_sql.Select.create_legacy_select`.
+
+ .. seealso::
+
+ :func:`_expression.select` - general purpose
+ method which allows for arbitrary column lists.
+
+ """
+ if whereclause is not None:
+ kwargs["whereclause"] = whereclause
+ return Select._create_select_from_fromclause(self, [self], **kwargs)
+
+ def join(self, right, onclause=None, isouter=False, full=False):
+ """Return a :class:`_expression.Join` from this
+ :class:`_expression.FromClause`
+ to another :class:`FromClause`.
+
+ E.g.::
+
+ from sqlalchemy import join
+
+ j = user_table.join(address_table,
+ user_table.c.id == address_table.c.user_id)
+ stmt = select(user_table).select_from(j)
+
+ would emit SQL along the lines of::
+
+ SELECT user.id, user.name FROM user
+ JOIN address ON user.id = address.user_id
+
+ :param right: the right side of the join; this is any
+ :class:`_expression.FromClause` object such as a
+ :class:`_schema.Table` object, and
+ may also be a selectable-compatible object such as an ORM-mapped
+ class.
+
+ :param onclause: a SQL expression representing the ON clause of the
+ join. If left at ``None``, :meth:`_expression.FromClause.join`
+ will attempt to
+ join the two tables based on a foreign key relationship.
+
+ :param isouter: if True, render a LEFT OUTER JOIN, instead of JOIN.
+
+ :param full: if True, render a FULL OUTER JOIN, instead of LEFT OUTER
+ JOIN. Implies :paramref:`.FromClause.join.isouter`.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :func:`_expression.join` - standalone function
+
+ :class:`_expression.Join` - the type of object produced
+
+ """
+
+ return Join(self, right, onclause, isouter, full)
+
+ def outerjoin(self, right, onclause=None, full=False):
+ """Return a :class:`_expression.Join` from this
+ :class:`_expression.FromClause`
+ to another :class:`FromClause`, with the "isouter" flag set to
+ True.
+
+ E.g.::
+
+ from sqlalchemy import outerjoin
+
+ j = user_table.outerjoin(address_table,
+ user_table.c.id == address_table.c.user_id)
+
+ The above is equivalent to::
+
+ j = user_table.join(
+ address_table,
+ user_table.c.id == address_table.c.user_id,
+ isouter=True)
+
+ :param right: the right side of the join; this is any
+ :class:`_expression.FromClause` object such as a
+ :class:`_schema.Table` object, and
+ may also be a selectable-compatible object such as an ORM-mapped
+ class.
+
+ :param onclause: a SQL expression representing the ON clause of the
+ join. If left at ``None``, :meth:`_expression.FromClause.join`
+ will attempt to
+ join the two tables based on a foreign key relationship.
+
+ :param full: if True, render a FULL OUTER JOIN, instead of
+ LEFT OUTER JOIN.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :meth:`_expression.FromClause.join`
+
+ :class:`_expression.Join`
+
+ """
+
+ return Join(self, right, onclause, True, full)
+
+ def alias(self, name=None, flat=False):
+ """Return an alias of this :class:`_expression.FromClause`.
+
+ E.g.::
+
+ a2 = some_table.alias('a2')
+
+ The above code creates an :class:`_expression.Alias`
+ object which can be used
+ as a FROM clause in any SELECT statement.
+
+ .. seealso::
+
+ :ref:`tutorial_using_aliases`
+
+ :func:`_expression.alias`
+
+ """
+
+ return Alias._construct(self, name)
+
+ @util.preload_module("sqlalchemy.sql.sqltypes")
+ def table_valued(self):
+ """Return a :class:`_sql.TableValuedColumn` object for this
+ :class:`_expression.FromClause`.
+
+ A :class:`_sql.TableValuedColumn` is a :class:`_sql.ColumnElement` that
+ represents a complete row in a table. Support for this construct is
+ backend dependent, and is supported in various forms by backends
+ such as PostgreSQL, Oracle and SQL Server.
+
+ E.g.::
+
+ >>> from sqlalchemy import select, column, func, table
+ >>> a = table("a", column("id"), column("x"), column("y"))
+ >>> stmt = select(func.row_to_json(a.table_valued()))
+ >>> print(stmt)
+ SELECT row_to_json(a) AS row_to_json_1
+ FROM a
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :ref:`tutorial_functions` - in the :ref:`unified_tutorial`
+
+ """
+ return TableValuedColumn(self, type_api.TABLEVALUE)
+
+ def tablesample(self, sampling, name=None, seed=None):
+ """Return a TABLESAMPLE alias of this :class:`_expression.FromClause`.
+
+ The return value is the :class:`_expression.TableSample`
+ construct also
+ provided by the top-level :func:`_expression.tablesample` function.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :func:`_expression.tablesample` - usage guidelines and parameters
+
+ """
+ return TableSample._construct(self, sampling, name, seed)
+
+ def is_derived_from(self, fromclause):
+ """Return ``True`` if this :class:`_expression.FromClause` is
+ 'derived' from the given ``FromClause``.
+
+ An example would be an Alias of a Table is derived from that Table.
+
+ """
+ # this is essentially an "identity" check in the base class.
+ # Other constructs override this to traverse through
+ # contained elements.
+ return fromclause in self._cloned_set
+
+ def _is_lexical_equivalent(self, other):
+ """Return ``True`` if this :class:`_expression.FromClause` and
+ the other represent the same lexical identity.
+
+ This tests if either one is a copy of the other, or
+ if they are the same via annotation identity.
+
+ """
+ return self._cloned_set.intersection(other._cloned_set)
+
+ @property
+ def description(self):
+ """A brief description of this :class:`_expression.FromClause`.
+
+ Used primarily for error message formatting.
+
+ """
+ return getattr(self, "name", self.__class__.__name__ + " object")
+
+ def _generate_fromclause_column_proxies(self, fromclause):
+ fromclause._columns._populate_separate_keys(
+ col._make_proxy(fromclause) for col in self.c
+ )
+
+ @property
+ def exported_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ that represents the "exported"
+ columns of this :class:`_expression.Selectable`.
+
+ The "exported" columns for a :class:`_expression.FromClause`
+ object are synonymous
+ with the :attr:`_expression.FromClause.columns` collection.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_expression.Selectable.exported_columns`
+
+ :attr:`_expression.SelectBase.exported_columns`
+
+
+ """
+ return self.columns
+
+ @util.memoized_property
+ def columns(self):
+ """A named-based collection of :class:`_expression.ColumnElement`
+ objects maintained by this :class:`_expression.FromClause`.
+
+ The :attr:`.columns`, or :attr:`.c` collection, is the gateway
+ to the construction of SQL expressions using table-bound or
+ other selectable-bound columns::
+
+ select(mytable).where(mytable.c.somecolumn == 5)
+
+ :return: a :class:`.ColumnCollection` object.
+
+ """
+
+ if "_columns" not in self.__dict__:
+ self._init_collections()
+ self._populate_column_collection()
+ return self._columns.as_immutable()
+
+ @property
+ def entity_namespace(self):
+ """Return a namespace used for name-based access in SQL expressions.
+
+ This is the namespace that is used to resolve "filter_by()" type
+ expressions, such as::
+
+ stmt.filter_by(address='some address')
+
+ It defaults to the ``.c`` collection, however internally it can
+ be overridden using the "entity_namespace" annotation to deliver
+ alternative results.
+
+ """
+ return self.columns
+
+ @util.memoized_property
+ def primary_key(self):
+ """Return the iterable collection of :class:`_schema.Column` objects
+ which comprise the primary key of this :class:`_selectable.FromClause`.
+
+ For a :class:`_schema.Table` object, this collection is represented
+ by the :class:`_schema.PrimaryKeyConstraint` which itself is an
+ iterable collection of :class:`_schema.Column` objects.
+
+ """
+ self._init_collections()
+ self._populate_column_collection()
+ return self.primary_key
+
+ @util.memoized_property
+ def foreign_keys(self):
+ """Return the collection of :class:`_schema.ForeignKey` marker objects
+ which this FromClause references.
+
+ Each :class:`_schema.ForeignKey` is a member of a
+ :class:`_schema.Table`-wide
+ :class:`_schema.ForeignKeyConstraint`.
+
+ .. seealso::
+
+ :attr:`_schema.Table.foreign_key_constraints`
+
+ """
+ self._init_collections()
+ self._populate_column_collection()
+ return self.foreign_keys
+
+ def _reset_column_collection(self):
+ """Reset the attributes linked to the ``FromClause.c`` attribute.
+
+ This collection is separate from all the other memoized things
+ as it has shown to be sensitive to being cleared out in situations
+ where enclosing code, typically in a replacement traversal scenario,
+ has already established strong relationships
+ with the exported columns.
+
+ The collection is cleared for the case where a table is having a
+ column added to it as well as within a Join during copy internals.
+
+ """
+
+ for key in ["_columns", "columns", "primary_key", "foreign_keys"]:
+ self.__dict__.pop(key, None)
+
+ c = property(
+ attrgetter("columns"),
+ doc="""
+ A named-based collection of :class:`_expression.ColumnElement`
+ objects maintained by this :class:`_expression.FromClause`.
+
+ The :attr:`_sql.FromClause.c` attribute is an alias for the
+ :attr:`_sql.FromClause.columns` attribute.
+
+ :return: a :class:`.ColumnCollection`
+
+ """,
+ )
+ _select_iterable = property(attrgetter("columns"))
+
+ def _init_collections(self):
+ assert "_columns" not in self.__dict__
+ assert "primary_key" not in self.__dict__
+ assert "foreign_keys" not in self.__dict__
+
+ self._columns = ColumnCollection()
+ self.primary_key = ColumnSet()
+ self.foreign_keys = set()
+
+ @property
+ def _cols_populated(self):
+ return "_columns" in self.__dict__
+
+ def _populate_column_collection(self):
+ """Called on subclasses to establish the .c collection.
+
+ Each implementation has a different way of establishing
+ this collection.
+
+ """
+
+ def _refresh_for_new_column(self, column):
+ """Given a column added to the .c collection of an underlying
+ selectable, produce the local version of that column, assuming this
+ selectable ultimately should proxy this column.
+
+ this is used to "ping" a derived selectable to add a new column
+ to its .c. collection when a Column has been added to one of the
+ Table objects it ultimately derives from.
+
+ If the given selectable hasn't populated its .c. collection yet,
+ it should at least pass on the message to the contained selectables,
+ but it will return None.
+
+ This method is currently used by Declarative to allow Table
+ columns to be added to a partially constructed inheritance
+ mapping that may have already produced joins. The method
+ isn't public right now, as the full span of implications
+ and/or caveats aren't yet clear.
+
+ It's also possible that this functionality could be invoked by
+ default via an event, which would require that
+ selectables maintain a weak referencing collection of all
+ derivations.
+
+ """
+ self._reset_column_collection()
+
+ def _anonymous_fromclause(self, name=None, flat=False):
+ return self.alias(name=name)
+
+
+LABEL_STYLE_NONE = util.symbol(
+ "LABEL_STYLE_NONE",
+ """Label style indicating no automatic labeling should be applied to the
+ columns clause of a SELECT statement.
+
+ Below, the columns named ``columna`` are both rendered as is, meaning that
+ the name ``columna`` can only refer to the first occurrence of this name
+ within a result set, as well as if the statement were used as a subquery::
+
+ >>> from sqlalchemy import table, column, select, true, LABEL_STYLE_NONE
+ >>> table1 = table("table1", column("columna"), column("columnb"))
+ >>> table2 = table("table2", column("columna"), column("columnc"))
+ >>> print(select(table1, table2).join(table2, true()).set_label_style(LABEL_STYLE_NONE))
+ SELECT table1.columna, table1.columnb, table2.columna, table2.columnc
+ FROM table1 JOIN table2 ON true
+
+ Used with the :meth:`_sql.Select.set_label_style` method.
+
+ .. versionadded:: 1.4
+
+""", # noqa: E501
+)
+
+LABEL_STYLE_TABLENAME_PLUS_COL = util.symbol(
+ "LABEL_STYLE_TABLENAME_PLUS_COL",
+ """Label style indicating all columns should be labeled as
+ ``<tablename>_<columnname>`` when generating the columns clause of a SELECT
+ statement, to disambiguate same-named columns referenced from different
+ tables, aliases, or subqueries.
+
+ Below, all column names are given a label so that the two same-named
+ columns ``columna`` are disambiguated as ``table1_columna`` and
+ ``table2_columna``::
+
+ >>> from sqlalchemy import table, column, select, true, LABEL_STYLE_TABLENAME_PLUS_COL
+ >>> table1 = table("table1", column("columna"), column("columnb"))
+ >>> table2 = table("table2", column("columna"), column("columnc"))
+ >>> print(select(table1, table2).join(table2, true()).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL))
+ SELECT table1.columna AS table1_columna, table1.columnb AS table1_columnb, table2.columna AS table2_columna, table2.columnc AS table2_columnc
+ FROM table1 JOIN table2 ON true
+
+ Used with the :meth:`_sql.GenerativeSelect.set_label_style` method.
+ Equivalent to the legacy method ``Select.apply_labels()``;
+ :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL` is SQLAlchemy's legacy
+ auto-labeling style. :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY` provides a
+ less intrusive approach to disambiguation of same-named column expressions.
+
+
+ .. versionadded:: 1.4
+
+""", # noqa: E501
+)
+
+
+LABEL_STYLE_DISAMBIGUATE_ONLY = util.symbol(
+ "LABEL_STYLE_DISAMBIGUATE_ONLY",
+ """Label style indicating that columns with a name that conflicts with
+ an existing name should be labeled with a semi-anonymizing label
+ when generating the columns clause of a SELECT statement.
+
+ Below, most column names are left unaffected, except for the second
+ occurrence of the name ``columna``, which is labeled using the
+ label ``columna_1`` to disambiguate it from that of ``table1.columna``::
+
+ >>> from sqlalchemy import table, column, select, true, LABEL_STYLE_DISAMBIGUATE_ONLY
+ >>> table1 = table("table1", column("columna"), column("columnb"))
+ >>> table2 = table("table2", column("columna"), column("columnc"))
+ >>> print(select(table1, table2).join(table2, true()).set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY))
+ SELECT table1.columna, table1.columnb, table2.columna AS columna_1, table2.columnc
+ FROM table1 JOIN table2 ON true
+
+ Used with the :meth:`_sql.GenerativeSelect.set_label_style` method,
+ :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY` is the default labeling style
+ for all SELECT statements outside of :term:`1.x style` ORM queries.
+
+ .. versionadded:: 1.4
+
+""", # noqa: E501,
+)
+
+
+LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY
+"""The default label style, refers to
+:data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`.
+
+.. versionadded:: 1.4
+
+"""
+
+
+class Join(roles.DMLTableRole, FromClause):
+ """Represent a ``JOIN`` construct between two
+ :class:`_expression.FromClause`
+ elements.
+
+ The public constructor function for :class:`_expression.Join`
+ is the module-level
+ :func:`_expression.join()` function, as well as the
+ :meth:`_expression.FromClause.join` method
+ of any :class:`_expression.FromClause` (e.g. such as
+ :class:`_schema.Table`).
+
+ .. seealso::
+
+ :func:`_expression.join`
+
+ :meth:`_expression.FromClause.join`
+
+ """
+
+ __visit_name__ = "join"
+
+ _traverse_internals = [
+ ("left", InternalTraversal.dp_clauseelement),
+ ("right", InternalTraversal.dp_clauseelement),
+ ("onclause", InternalTraversal.dp_clauseelement),
+ ("isouter", InternalTraversal.dp_boolean),
+ ("full", InternalTraversal.dp_boolean),
+ ]
+
+ _is_join = True
+
+ def __init__(self, left, right, onclause=None, isouter=False, full=False):
+ """Construct a new :class:`_expression.Join`.
+
+ The usual entrypoint here is the :func:`_expression.join`
+ function or the :meth:`_expression.FromClause.join` method of any
+ :class:`_expression.FromClause` object.
+
+ """
+ self.left = coercions.expect(
+ roles.FromClauseRole, left, deannotate=True
+ )
+ self.right = coercions.expect(
+ roles.FromClauseRole, right, deannotate=True
+ ).self_group()
+
+ if onclause is None:
+ self.onclause = self._match_primaries(self.left, self.right)
+ else:
+ # note: taken from If91f61527236fd4d7ae3cad1f24c38be921c90ba
+ # not merged yet
+ self.onclause = coercions.expect(
+ roles.OnClauseRole, onclause
+ ).self_group(against=operators._asbool)
+
+ self.isouter = isouter
+ self.full = full
+
+ @classmethod
+ def _create_outerjoin(cls, left, right, onclause=None, full=False):
+ """Return an ``OUTER JOIN`` clause element.
+
+ The returned object is an instance of :class:`_expression.Join`.
+
+ Similar functionality is also available via the
+ :meth:`_expression.FromClause.outerjoin` method on any
+ :class:`_expression.FromClause`.
+
+ :param left: The left side of the join.
+
+ :param right: The right side of the join.
+
+ :param onclause: Optional criterion for the ``ON`` clause, is
+ derived from foreign key relationships established between
+ left and right otherwise.
+
+ To chain joins together, use the :meth:`_expression.FromClause.join`
+ or
+ :meth:`_expression.FromClause.outerjoin` methods on the resulting
+ :class:`_expression.Join` object.
+
+ """
+ return cls(left, right, onclause, isouter=True, full=full)
+
+ @classmethod
+ def _create_join(
+ cls, left, right, onclause=None, isouter=False, full=False
+ ):
+ """Produce a :class:`_expression.Join` object, given two
+ :class:`_expression.FromClause`
+ expressions.
+
+ E.g.::
+
+ j = join(user_table, address_table,
+ user_table.c.id == address_table.c.user_id)
+ stmt = select(user_table).select_from(j)
+
+ would emit SQL along the lines of::
+
+ SELECT user.id, user.name FROM user
+ JOIN address ON user.id = address.user_id
+
+ Similar functionality is available given any
+ :class:`_expression.FromClause` object (e.g. such as a
+ :class:`_schema.Table`) using
+ the :meth:`_expression.FromClause.join` method.
+
+ :param left: The left side of the join.
+
+ :param right: the right side of the join; this is any
+ :class:`_expression.FromClause` object such as a
+ :class:`_schema.Table` object, and
+ may also be a selectable-compatible object such as an ORM-mapped
+ class.
+
+ :param onclause: a SQL expression representing the ON clause of the
+ join. If left at ``None``, :meth:`_expression.FromClause.join`
+ will attempt to
+ join the two tables based on a foreign key relationship.
+
+ :param isouter: if True, render a LEFT OUTER JOIN, instead of JOIN.
+
+ :param full: if True, render a FULL OUTER JOIN, instead of JOIN.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :meth:`_expression.FromClause.join` - method form,
+ based on a given left side.
+
+ :class:`_expression.Join` - the type of object produced.
+
+ """
+
+ return cls(left, right, onclause, isouter, full)
+
+ @property
+ def description(self):
+ return "Join object on %s(%d) and %s(%d)" % (
+ self.left.description,
+ id(self.left),
+ self.right.description,
+ id(self.right),
+ )
+
+ def is_derived_from(self, fromclause):
+ return (
+ # use hash() to ensure direct comparison to annotated works
+ # as well
+ hash(fromclause) == hash(self)
+ or self.left.is_derived_from(fromclause)
+ or self.right.is_derived_from(fromclause)
+ )
+
+ def self_group(self, against=None):
+ return FromGrouping(self)
+
+ @util.preload_module("sqlalchemy.sql.util")
+ def _populate_column_collection(self):
+ sqlutil = util.preloaded.sql_util
+ columns = [c for c in self.left.columns] + [
+ c for c in self.right.columns
+ ]
+
+ self.primary_key.extend(
+ sqlutil.reduce_columns(
+ (c for c in columns if c.primary_key), self.onclause
+ )
+ )
+ self._columns._populate_separate_keys(
+ (col._tq_key_label, col) for col in columns
+ )
+ self.foreign_keys.update(
+ itertools.chain(*[col.foreign_keys for col in columns])
+ )
+
+ def _copy_internals(self, clone=_clone, **kw):
+ # see Select._copy_internals() for similar concept
+
+ # here we pre-clone "left" and "right" so that we can
+ # determine the new FROM clauses
+ all_the_froms = set(
+ itertools.chain(
+ _from_objects(self.left),
+ _from_objects(self.right),
+ )
+ )
+
+ # run the clone on those. these will be placed in the
+ # cache used by the clone function
+ new_froms = {f: clone(f, **kw) for f in all_the_froms}
+
+ # set up a special replace function that will replace for
+ # ColumnClause with parent table referring to those
+ # replaced FromClause objects
+ def replace(obj, **kw):
+ if isinstance(obj, ColumnClause) and obj.table in new_froms:
+ newelem = new_froms[obj.table].corresponding_column(obj)
+ return newelem
+
+ kw["replace"] = replace
+
+ # run normal _copy_internals. the clones for
+ # left and right will come from the clone function's
+ # cache
+ super(Join, self)._copy_internals(clone=clone, **kw)
+
+ self._reset_memoizations()
+
+ def _refresh_for_new_column(self, column):
+ super(Join, self)._refresh_for_new_column(column)
+ self.left._refresh_for_new_column(column)
+ self.right._refresh_for_new_column(column)
+
+ def _match_primaries(self, left, right):
+ if isinstance(left, Join):
+ left_right = left.right
+ else:
+ left_right = None
+ return self._join_condition(left, right, a_subset=left_right)
+
+ @classmethod
+ def _join_condition(
+ cls, a, b, a_subset=None, consider_as_foreign_keys=None
+ ):
+ """Create a join condition between two tables or selectables.
+
+ e.g.::
+
+ join_condition(tablea, tableb)
+
+ would produce an expression along the lines of::
+
+ tablea.c.id==tableb.c.tablea_id
+
+ The join is determined based on the foreign key relationships
+ between the two selectables. If there are multiple ways
+ to join, or no way to join, an error is raised.
+
+ :param a_subset: An optional expression that is a sub-component
+ of ``a``. An attempt will be made to join to just this sub-component
+ first before looking at the full ``a`` construct, and if found
+ will be successful even if there are other ways to join to ``a``.
+ This allows the "right side" of a join to be passed thereby
+ providing a "natural join".
+
+ """
+ constraints = cls._joincond_scan_left_right(
+ a, a_subset, b, consider_as_foreign_keys
+ )
+
+ if len(constraints) > 1:
+ cls._joincond_trim_constraints(
+ a, b, constraints, consider_as_foreign_keys
+ )
+
+ if len(constraints) == 0:
+ if isinstance(b, FromGrouping):
+ hint = (
+ " Perhaps you meant to convert the right side to a "
+ "subquery using alias()?"
+ )
+ else:
+ hint = ""
+ raise exc.NoForeignKeysError(
+ "Can't find any foreign key relationships "
+ "between '%s' and '%s'.%s"
+ % (a.description, b.description, hint)
+ )
+
+ crit = [(x == y) for x, y in list(constraints.values())[0]]
+ if len(crit) == 1:
+ return crit[0]
+ else:
+ return and_(*crit)
+
+ @classmethod
+ def _can_join(cls, left, right, consider_as_foreign_keys=None):
+ if isinstance(left, Join):
+ left_right = left.right
+ else:
+ left_right = None
+
+ constraints = cls._joincond_scan_left_right(
+ a=left,
+ b=right,
+ a_subset=left_right,
+ consider_as_foreign_keys=consider_as_foreign_keys,
+ )
+
+ return bool(constraints)
+
+ @classmethod
+ @util.preload_module("sqlalchemy.sql.util")
+ def _joincond_scan_left_right(
+ cls, a, a_subset, b, consider_as_foreign_keys
+ ):
+ sql_util = util.preloaded.sql_util
+
+ a = coercions.expect(roles.FromClauseRole, a)
+ b = coercions.expect(roles.FromClauseRole, b)
+
+ constraints = collections.defaultdict(list)
+
+ for left in (a_subset, a):
+ if left is None:
+ continue
+ for fk in sorted(
+ b.foreign_keys, key=lambda fk: fk.parent._creation_order
+ ):
+ if (
+ consider_as_foreign_keys is not None
+ and fk.parent not in consider_as_foreign_keys
+ ):
+ continue
+ try:
+ col = fk.get_referent(left)
+ except exc.NoReferenceError as nrte:
+ table_names = {t.name for t in sql_util.find_tables(left)}
+ if nrte.table_name in table_names:
+ raise
+ else:
+ continue
+
+ if col is not None:
+ constraints[fk.constraint].append((col, fk.parent))
+ if left is not b:
+ for fk in sorted(
+ left.foreign_keys, key=lambda fk: fk.parent._creation_order
+ ):
+ if (
+ consider_as_foreign_keys is not None
+ and fk.parent not in consider_as_foreign_keys
+ ):
+ continue
+ try:
+ col = fk.get_referent(b)
+ except exc.NoReferenceError as nrte:
+ table_names = {t.name for t in sql_util.find_tables(b)}
+ if nrte.table_name in table_names:
+ raise
+ else:
+ continue
+
+ if col is not None:
+ constraints[fk.constraint].append((col, fk.parent))
+ if constraints:
+ break
+ return constraints
+
+ @classmethod
+ def _joincond_trim_constraints(
+ cls, a, b, constraints, consider_as_foreign_keys
+ ):
+ # more than one constraint matched. narrow down the list
+ # to include just those FKCs that match exactly to
+ # "consider_as_foreign_keys".
+ if consider_as_foreign_keys:
+ for const in list(constraints):
+ if set(f.parent for f in const.elements) != set(
+ consider_as_foreign_keys
+ ):
+ del constraints[const]
+
+ # if still multiple constraints, but
+ # they all refer to the exact same end result, use it.
+ if len(constraints) > 1:
+ dedupe = set(tuple(crit) for crit in constraints.values())
+ if len(dedupe) == 1:
+ key = list(constraints)[0]
+ constraints = {key: constraints[key]}
+
+ if len(constraints) != 1:
+ raise exc.AmbiguousForeignKeysError(
+ "Can't determine join between '%s' and '%s'; "
+ "tables have more than one foreign key "
+ "constraint relationship between them. "
+ "Please specify the 'onclause' of this "
+ "join explicitly." % (a.description, b.description)
+ )
+
+ @util.deprecated_params(
+ whereclause=(
+ "2.0",
+ "The :paramref:`_sql.Join.select().whereclause` parameter "
+ "is deprecated and will be removed in version 2.0. "
+ "Please make use of "
+ "the :meth:`.Select.where` "
+ "method to add WHERE criteria to the SELECT statement.",
+ ),
+ kwargs=(
+ "2.0",
+ "The :meth:`_sql.Join.select` method will no longer accept "
+ "keyword arguments in version 2.0. Please use generative "
+ "methods from the "
+ ":class:`_sql.Select` construct in order to apply additional "
+ "modifications.",
+ ),
+ )
+ def select(self, whereclause=None, **kwargs):
+ r"""Create a :class:`_expression.Select` from this
+ :class:`_expression.Join`.
+
+ E.g.::
+
+ stmt = table_a.join(table_b, table_a.c.id == table_b.c.a_id)
+
+ stmt = stmt.select()
+
+ The above will produce a SQL string resembling::
+
+ SELECT table_a.id, table_a.col, table_b.id, table_b.a_id
+ FROM table_a JOIN table_b ON table_a.id = table_b.a_id
+
+ :param whereclause: WHERE criteria, same as calling
+ :meth:`_sql.Select.where` on the resulting statement
+
+ :param \**kwargs: additional keyword arguments are passed to the
+ legacy constructor for :class:`_sql.Select` described at
+ :meth:`_sql.Select.create_legacy_select`.
+
+ """
+ collist = [self.left, self.right]
+
+ if whereclause is not None:
+ kwargs["whereclause"] = whereclause
+ return Select._create_select_from_fromclause(
+ self, collist, **kwargs
+ ).select_from(self)
+
+ @property
+ @util.deprecated_20(
+ ":attr:`.Executable.bind`",
+ alternative="Bound metadata is being removed as of SQLAlchemy 2.0.",
+ enable_warnings=False,
+ )
+ def bind(self):
+ """Return the bound engine associated with either the left or right
+ side of this :class:`_sql.Join`.
+
+ """
+
+ return self.left.bind or self.right.bind
+
+ @util.preload_module("sqlalchemy.sql.util")
+ def _anonymous_fromclause(self, name=None, flat=False):
+ sqlutil = util.preloaded.sql_util
+ if flat:
+ if name is not None:
+ raise exc.ArgumentError("Can't send name argument with flat")
+ left_a, right_a = (
+ self.left._anonymous_fromclause(flat=True),
+ self.right._anonymous_fromclause(flat=True),
+ )
+ adapter = sqlutil.ClauseAdapter(left_a).chain(
+ sqlutil.ClauseAdapter(right_a)
+ )
+
+ return left_a.join(
+ right_a,
+ adapter.traverse(self.onclause),
+ isouter=self.isouter,
+ full=self.full,
+ )
+ else:
+ return (
+ self.select()
+ .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ .correlate(None)
+ .alias(name)
+ )
+
+ @util.deprecated_20(
+ ":meth:`_sql.Join.alias`",
+ alternative="Create a select + subquery, or alias the "
+ "individual tables inside the join, instead.",
+ )
+ def alias(self, name=None, flat=False):
+ r"""Return an alias of this :class:`_expression.Join`.
+
+ The default behavior here is to first produce a SELECT
+ construct from this :class:`_expression.Join`, then to produce an
+ :class:`_expression.Alias` from that. So given a join of the form::
+
+ j = table_a.join(table_b, table_a.c.id == table_b.c.a_id)
+
+ The JOIN by itself would look like::
+
+ table_a JOIN table_b ON table_a.id = table_b.a_id
+
+ Whereas the alias of the above, ``j.alias()``, would in a
+ SELECT context look like::
+
+ (SELECT table_a.id AS table_a_id, table_b.id AS table_b_id,
+ table_b.a_id AS table_b_a_id
+ FROM table_a
+ JOIN table_b ON table_a.id = table_b.a_id) AS anon_1
+
+ The equivalent long-hand form, given a :class:`_expression.Join`
+ object ``j``, is::
+
+ from sqlalchemy import select, alias
+ j = alias(
+ select(j.left, j.right).\
+ select_from(j).\
+ set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL).\
+ correlate(False),
+ name=name
+ )
+
+ The selectable produced by :meth:`_expression.Join.alias`
+ features the same
+ columns as those of the two individual selectables presented under
+ a single name - the individual columns are "auto-labeled", meaning
+ the ``.c.`` collection of the resulting :class:`_expression.Alias`
+ represents
+ the names of the individual columns using a
+ ``<tablename>_<columnname>`` scheme::
+
+ j.c.table_a_id
+ j.c.table_b_a_id
+
+ :meth:`_expression.Join.alias` also features an alternate
+ option for aliasing joins which produces no enclosing SELECT and
+ does not normally apply labels to the column names. The
+ ``flat=True`` option will call :meth:`_expression.FromClause.alias`
+ against the left and right sides individually.
+ Using this option, no new ``SELECT`` is produced;
+ instead, starting from a construct such as::
+
+ j = table_a.join(table_b, table_a.c.id == table_b.c.a_id)
+ j = j.alias(flat=True)
+
+ we get a result like this::
+
+ table_a AS table_a_1 JOIN table_b AS table_b_1 ON
+ table_a_1.id = table_b_1.a_id
+
+ The ``flat=True`` argument is also propagated to the contained
+ selectables, so that a composite join such as::
+
+ j = table_a.join(
+ table_b.join(table_c,
+ table_b.c.id == table_c.c.b_id),
+ table_b.c.a_id == table_a.c.id
+ ).alias(flat=True)
+
+ Will produce an expression like::
+
+ table_a AS table_a_1 JOIN (
+ table_b AS table_b_1 JOIN table_c AS table_c_1
+ ON table_b_1.id = table_c_1.b_id
+ ) ON table_a_1.id = table_b_1.a_id
+
+ The standalone :func:`_expression.alias` function as well as the
+ base :meth:`_expression.FromClause.alias`
+ method also support the ``flat=True``
+ argument as a no-op, so that the argument can be passed to the
+ ``alias()`` method of any selectable.
+
+ :param name: name given to the alias.
+
+ :param flat: if True, produce an alias of the left and right
+ sides of this :class:`_expression.Join` and return the join of those
+ two selectables. This produces a join expression that does not
+ include an enclosing SELECT.
+
+ .. seealso::
+
+ :ref:`core_tutorial_aliases`
+
+ :func:`_expression.alias`
+
+ """
+ return self._anonymous_fromclause(flat=flat, name=name)
+
+ @property
+ def _hide_froms(self):
+ return itertools.chain(
+ *[_from_objects(x.left, x.right) for x in self._cloned_set]
+ )
+
+ @property
+ def _from_objects(self):
+ return [self] + self.left._from_objects + self.right._from_objects
+
+
+class NoInit(object):
+ def __init__(self, *arg, **kw):
+ raise NotImplementedError(
+ "The %s class is not intended to be constructed "
+ "directly. Please use the %s() standalone "
+ "function or the %s() method available from appropriate "
+ "selectable objects."
+ % (
+ self.__class__.__name__,
+ self.__class__.__name__.lower(),
+ self.__class__.__name__.lower(),
+ )
+ )
+
+
+# FromClause ->
+# AliasedReturnsRows
+# -> Alias only for FromClause
+# -> Subquery only for SelectBase
+# -> CTE only for HasCTE -> SelectBase, DML
+# -> Lateral -> FromClause, but we accept SelectBase
+# w/ non-deprecated coercion
+# -> TableSample -> only for FromClause
+class AliasedReturnsRows(NoInit, FromClause):
+ """Base class of aliases against tables, subqueries, and other
+ selectables."""
+
+ _is_from_container = True
+ named_with_column = True
+
+ _supports_derived_columns = False
+
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("name", InternalTraversal.dp_anon_name),
+ ]
+
+ @classmethod
+ def _construct(cls, *arg, **kw):
+ obj = cls.__new__(cls)
+ obj._init(*arg, **kw)
+ return obj
+
+ @classmethod
+ def _factory(cls, returnsrows, name=None):
+ """Base factory method. Subclasses need to provide this."""
+ raise NotImplementedError()
+
+ def _init(self, selectable, name=None):
+ self.element = coercions.expect(
+ roles.ReturnsRowsRole, selectable, apply_propagate_attrs=self
+ )
+ self.element = selectable
+ self._orig_name = name
+ if name is None:
+ if (
+ isinstance(selectable, FromClause)
+ and selectable.named_with_column
+ ):
+ name = getattr(selectable, "name", None)
+ if isinstance(name, _anonymous_label):
+ name = None
+ name = _anonymous_label.safe_construct(id(self), name or "anon")
+ self.name = name
+
+ def _refresh_for_new_column(self, column):
+ super(AliasedReturnsRows, self)._refresh_for_new_column(column)
+ self.element._refresh_for_new_column(column)
+
+ @property
+ def description(self):
+ name = self.name
+ if isinstance(name, _anonymous_label):
+ name = "anon_1"
+
+ if util.py3k:
+ return name
+ else:
+ return name.encode("ascii", "backslashreplace")
+
+ @property
+ def original(self):
+ """Legacy for dialects that are referring to Alias.original."""
+ return self.element
+
+ def is_derived_from(self, fromclause):
+ if fromclause in self._cloned_set:
+ return True
+ return self.element.is_derived_from(fromclause)
+
+ def _populate_column_collection(self):
+ self.element._generate_fromclause_column_proxies(self)
+
+ def _copy_internals(self, clone=_clone, **kw):
+ existing_element = self.element
+
+ super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw)
+
+ # the element clone is usually against a Table that returns the
+ # same object. don't reset exported .c. collections and other
+ # memoized details if it was not changed. this saves a lot on
+ # performance.
+ if existing_element is not self.element:
+ self._reset_column_collection()
+
+ @property
+ def _from_objects(self):
+ return [self]
+
+ @property
+ def bind(self):
+ return self.element.bind
+
+
+class Alias(roles.DMLTableRole, AliasedReturnsRows):
+ """Represents an table or selectable alias (AS).
+
+ Represents an alias, as typically applied to any table or
+ sub-select within a SQL statement using the ``AS`` keyword (or
+ without the keyword on certain databases such as Oracle).
+
+ This object is constructed from the :func:`_expression.alias` module
+ level function as well as the :meth:`_expression.FromClause.alias`
+ method available
+ on all :class:`_expression.FromClause` subclasses.
+
+ .. seealso::
+
+ :meth:`_expression.FromClause.alias`
+
+ """
+
+ __visit_name__ = "alias"
+
+ inherit_cache = True
+
+ @classmethod
+ def _factory(cls, selectable, name=None, flat=False):
+ """Return an :class:`_expression.Alias` object.
+
+ An :class:`_expression.Alias` represents any
+ :class:`_expression.FromClause`
+ with an alternate name assigned within SQL, typically using the ``AS``
+ clause when generated, e.g. ``SELECT * FROM table AS aliasname``.
+
+ Similar functionality is available via the
+ :meth:`_expression.FromClause.alias`
+ method available on all :class:`_expression.FromClause` subclasses.
+ In terms of
+ a SELECT object as generated from the :func:`_expression.select`
+ function, the :meth:`_expression.SelectBase.alias` method returns an
+ :class:`_expression.Alias` or similar object which represents a named,
+ parenthesized subquery.
+
+ When an :class:`_expression.Alias` is created from a
+ :class:`_schema.Table` object,
+ this has the effect of the table being rendered
+ as ``tablename AS aliasname`` in a SELECT statement.
+
+ For :func:`_expression.select` objects, the effect is that of
+ creating a named subquery, i.e. ``(select ...) AS aliasname``.
+
+ The ``name`` parameter is optional, and provides the name
+ to use in the rendered SQL. If blank, an "anonymous" name
+ will be deterministically generated at compile time.
+ Deterministic means the name is guaranteed to be unique against
+ other constructs used in the same statement, and will also be the
+ same name for each successive compilation of the same statement
+ object.
+
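+ E.g., a brief sketch, where ``user_table`` stands in for any
+ hypothetical :class:`_schema.Table`::
+
+ from sqlalchemy import alias, select
+
+ # user_table is a hypothetical table; "user_2" is the alias name
+ user_alias = alias(user_table, name="user_2")
+ stmt = select(user_table, user_alias).where(
+ user_table.c.id > user_alias.c.id
+ )
+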
+ :param selectable: any :class:`_expression.FromClause` subclass,
+ such as a table, select statement, etc.
+
+ :param name: string name to be assigned as the alias.
+ If ``None``, a name will be deterministically generated
+ at compile time.
+
+ :param flat: Will be passed through to if the given selectable
+ is an instance of :class:`_expression.Join` - see
+ :meth:`_expression.Join.alias`
+ for details.
+
+ """
+ return coercions.expect(
+ roles.FromClauseRole, selectable, allow_select=True
+ ).alias(name=name, flat=flat)
+
+
+class TableValuedAlias(Alias):
+ """An alias against a "table valued" SQL function.
+
+ This construct provides for a SQL function that returns columns
+ to be used in the FROM clause of a SELECT statement. The
+ object is generated using the :meth:`_functions.FunctionElement.table_valued`
+ method, e.g.::
+
+ >>> from sqlalchemy import select, func
+ >>> fn = func.json_array_elements_text('["one", "two", "three"]').table_valued("value")
+ >>> print(select(fn.c.value))
+ SELECT anon_1.value
+ FROM json_array_elements_text(:json_array_elements_text_1) AS anon_1
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :ref:`tutorial_functions_table_valued` - in the :ref:`unified_tutorial`
+
+ """ # noqa: E501
+
+ __visit_name__ = "table_valued_alias"
+
+ _supports_derived_columns = True
+ _render_derived = False
+ _render_derived_w_types = False
+ joins_implicitly = False
+
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("name", InternalTraversal.dp_anon_name),
+ ("_tableval_type", InternalTraversal.dp_type),
+ ("_render_derived", InternalTraversal.dp_boolean),
+ ("_render_derived_w_types", InternalTraversal.dp_boolean),
+ ]
+
+ def _init(
+ self,
+ selectable,
+ name=None,
+ table_value_type=None,
+ joins_implicitly=False,
+ ):
+ super(TableValuedAlias, self)._init(selectable, name=name)
+
+ self.joins_implicitly = joins_implicitly
+ self._tableval_type = (
+ type_api.TABLEVALUE
+ if table_value_type is None
+ else table_value_type
+ )
+
+ @HasMemoized.memoized_attribute
+ def column(self):
+ """Return a column expression representing this
+ :class:`_sql.TableValuedAlias`.
+
+ This accessor is used to implement the
+ :meth:`_functions.FunctionElement.column_valued` method. See that
+ method for further details.
+
+ E.g.::
+
+ >>> print(select(func.some_func().table_valued("value").column))
+ SELECT anon_1 FROM some_func() AS anon_1
+
+ .. seealso::
+
+ :meth:`_functions.FunctionElement.column_valued`
+
+ """
+
+ return TableValuedColumn(self, self._tableval_type)
+
+ def alias(self, name=None):
+ """Return a new alias of this :class:`_sql.TableValuedAlias`.
+
+ This creates a distinct FROM object that will be distinguished
+ from the original one when used in a SQL statement.
+
+ """
+
+ tva = TableValuedAlias._construct(
+ self,
+ name=name,
+ table_value_type=self._tableval_type,
+ joins_implicitly=self.joins_implicitly,
+ )
+
+ if self._render_derived:
+ tva._render_derived = True
+ tva._render_derived_w_types = self._render_derived_w_types
+
+ return tva
+
+ def lateral(self, name=None):
+ """Return a new :class:`_sql.TableValuedAlias` with the lateral flag
+ set, so that it renders as LATERAL.
+
+ .. seealso::
+
+ :func:`_expression.lateral`
+
+ """
+ tva = self.alias(name=name)
+ tva._is_lateral = True
+ return tva
+
+ def render_derived(self, name=None, with_types=False):
+ """Apply "render derived" to this :class:`_sql.TableValuedAlias`.
+
+ This has the effect of the individual column names listed out
+ after the alias name in the "AS" sequence, e.g.::
+
+ >>> print(
+ ... select(
+ ... func.unnest(array(["one", "two", "three"])).
+ table_valued("x", with_ordinality="o").render_derived()
+ ... )
+ ... )
+ SELECT anon_1.x, anon_1.o
+ FROM unnest(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s]) WITH ORDINALITY AS anon_1(x, o)
+
+ The ``with_types`` keyword will render column types inline within
+ the alias expression (this syntax currently applies to the
+ PostgreSQL database)::
+
+ >>> print(
+ ... select(
+ ... func.json_to_recordset(
+ ... '[{"a":1,"b":"foo"},{"a":"2","c":"bar"}]'
+ ... )
+ ... .table_valued(column("a", Integer), column("b", String))
+ ... .render_derived(with_types=True)
+ ... )
+ ... )
+ SELECT anon_1.a, anon_1.b FROM json_to_recordset(:json_to_recordset_1)
+ AS anon_1(a INTEGER, b VARCHAR)
+
+ :param name: optional string name that will be applied to the alias
+ generated. If left as None, a unique anonymizing name will be used.
+
+ :param with_types: if True, the derived columns will include the
+ datatype specification with each column. This is a special syntax
+ currently known to be required by PostgreSQL for some SQL functions.
+
+ """ # noqa: E501
+
+ # note: don't use the @_generative system here, keep a reference
+ # to the original object. otherwise you can have re-use of the
+ # python id() of the original which can cause name conflicts if
+ # a new anon-name grabs the same identifier as the local anon-name
+ # (just saw it happen on CI)
+
+ # construct against original to prevent memory growth
+ # for repeated generations
+ new_alias = TableValuedAlias._construct(
+ self.element,
+ name=name,
+ table_value_type=self._tableval_type,
+ joins_implicitly=self.joins_implicitly,
+ )
+ new_alias._render_derived = True
+ new_alias._render_derived_w_types = with_types
+ return new_alias
+
+
+class Lateral(AliasedReturnsRows):
+ """Represent a LATERAL subquery.
+
+ This object is constructed from the :func:`_expression.lateral` module
+ level function as well as the :meth:`_expression.FromClause.lateral`
+ method available
+ on all :class:`_expression.FromClause` subclasses.
+
+ While LATERAL is part of the SQL standard, currently only more recent
+ PostgreSQL versions provide support for this keyword.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`tutorial_lateral_correlation` - overview of usage.
+
+ """
+
+ __visit_name__ = "lateral"
+ _is_lateral = True
+
+ inherit_cache = True
+
+ @classmethod
+ def _factory(cls, selectable, name=None):
+ """Return a :class:`_expression.Lateral` object.
+
+ :class:`_expression.Lateral` is an :class:`_expression.Alias`
+ subclass that represents
+ a subquery with the LATERAL keyword applied to it.
+
+ The special behavior of a LATERAL subquery is that it appears in the
+ FROM clause of an enclosing SELECT, but may correlate to other
+ FROM clauses of that SELECT. It is a special case of subquery
+ only supported by a small number of backends, currently more recent
+ PostgreSQL versions.
+
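+ E.g., a sketch assuming hypothetical ``people`` and ``books`` tables,
+ where the laterally-correlated subquery may refer to the ``people``
+ table that is also in the enclosing FROM clause::
+
+ from sqlalchemy import select, true
+
+ # people and books are hypothetical table constructs
+ subq = (
+ select(books.c.book_id)
+ .where(books.c.owner_id == people.c.people_id)
+ .lateral("book_subq")
+ )
+ stmt = select(people).select_from(people.join(subq, true()))
+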
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`tutorial_lateral_correlation` - overview of usage.
+
+
+ """
+ return coercions.expect(
+ roles.FromClauseRole, selectable, explicit_subquery=True
+ ).lateral(name=name)
+
+
+class TableSample(AliasedReturnsRows):
+ """Represent a TABLESAMPLE clause.
+
+ This object is constructed from the :func:`_expression.tablesample` module
+ level function as well as the :meth:`_expression.FromClause.tablesample`
+ method
+ available on all :class:`_expression.FromClause` subclasses.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :func:`_expression.tablesample`
+
+ """
+
+ __visit_name__ = "tablesample"
+
+ _traverse_internals = AliasedReturnsRows._traverse_internals + [
+ ("sampling", InternalTraversal.dp_clauseelement),
+ ("seed", InternalTraversal.dp_clauseelement),
+ ]
+
+ @classmethod
+ def _factory(cls, selectable, sampling, name=None, seed=None):
+ """Return a :class:`_expression.TableSample` object.
+
+ :class:`_expression.TableSample` is an :class:`_expression.Alias`
+ subclass that represents
+ a table with the TABLESAMPLE clause applied to it.
+ :func:`_expression.tablesample`
+ is also available from the :class:`_expression.FromClause`
+ class via the
+ :meth:`_expression.FromClause.tablesample` method.
+
+ The TABLESAMPLE clause allows selecting an approximate percentage of
+ rows from a table at random. It supports multiple sampling methods,
+ most commonly BERNOULLI and SYSTEM.
+
+ e.g.::
+
+ from sqlalchemy import func
+
+ selectable = people.tablesample(
+ func.bernoulli(1),
+ name='alias',
+ seed=func.random())
+ stmt = select(selectable.c.people_id)
+
+ Assuming ``people`` with a column ``people_id``, the above
+ statement would render as::
+
+ SELECT alias.people_id FROM
+ people AS alias TABLESAMPLE bernoulli(:bernoulli_1)
+ REPEATABLE (random())
+
+ .. versionadded:: 1.1
+
+ :param sampling: a ``float`` percentage between 0 and 100 or
+ :class:`_functions.Function`.
+
+ :param name: optional alias name
+
+ :param seed: any real-valued SQL expression. When specified, the
+ REPEATABLE sub-clause is also rendered.
+
+ """
+ return coercions.expect(roles.FromClauseRole, selectable).tablesample(
+ sampling, name=name, seed=seed
+ )
+
+ @util.preload_module("sqlalchemy.sql.functions")
+ def _init(self, selectable, sampling, name=None, seed=None):
+ functions = util.preloaded.sql_functions
+ if not isinstance(sampling, functions.Function):
+ sampling = functions.func.system(sampling)
+
+ self.sampling = sampling
+ self.seed = seed
+ super(TableSample, self)._init(selectable, name=name)
+
+ def _get_method(self):
+ return self.sampling
+
+
+class CTE(
+ roles.DMLTableRole,
+ roles.IsCTERole,
+ Generative,
+ HasPrefixes,
+ HasSuffixes,
+ AliasedReturnsRows,
+):
+ """Represent a Common Table Expression.
+
+ The :class:`_expression.CTE` object is obtained using the
+ :meth:`_sql.SelectBase.cte` method from any SELECT statement. A less often
+ available syntax also allows use of the :meth:`_sql.HasCTE.cte` method
+ present on :term:`DML` constructs such as :class:`_sql.Insert`,
+ :class:`_sql.Update` and
+ :class:`_sql.Delete`. See the :meth:`_sql.HasCTE.cte` method for
+ usage details on CTEs.
+
+ .. seealso::
+
+ :ref:`tutorial_subqueries_ctes` - in the 2.0 tutorial
+
+ :meth:`_sql.HasCTE.cte` - examples of calling styles
+
+ """
+
+ __visit_name__ = "cte"
+
+ _traverse_internals = (
+ AliasedReturnsRows._traverse_internals
+ + [
+ ("_cte_alias", InternalTraversal.dp_clauseelement),
+ ("_restates", InternalTraversal.dp_clauseelement),
+ ("recursive", InternalTraversal.dp_boolean),
+ ("nesting", InternalTraversal.dp_boolean),
+ ]
+ + HasPrefixes._has_prefixes_traverse_internals
+ + HasSuffixes._has_suffixes_traverse_internals
+ )
+
+ @classmethod
+ def _factory(cls, selectable, name=None, recursive=False):
+ r"""Return a new :class:`_expression.CTE`,
+ or Common Table Expression instance.
+
+ Please see :meth:`_expression.HasCTE.cte` for detail on CTE usage.
+
+ """
+ return coercions.expect(roles.HasCTERole, selectable).cte(
+ name=name, recursive=recursive
+ )
+
+ def _init(
+ self,
+ selectable,
+ name=None,
+ recursive=False,
+ nesting=False,
+ _cte_alias=None,
+ _restates=None,
+ _prefixes=None,
+ _suffixes=None,
+ ):
+ self.recursive = recursive
+ self.nesting = nesting
+ self._cte_alias = _cte_alias
+ # Keep recursivity reference with union/union_all
+ self._restates = _restates
+ if _prefixes:
+ self._prefixes = _prefixes
+ if _suffixes:
+ self._suffixes = _suffixes
+ super(CTE, self)._init(selectable, name=name)
+
+ def _populate_column_collection(self):
+ if self._cte_alias is not None:
+ self._cte_alias._generate_fromclause_column_proxies(self)
+ else:
+ self.element._generate_fromclause_column_proxies(self)
+
+ def alias(self, name=None, flat=False):
+ """Return an :class:`_expression.Alias` of this
+ :class:`_expression.CTE`.
+
+ This method is a CTE-specific specialization of the
+ :meth:`_expression.FromClause.alias` method.
+
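+ E.g., a minimal sketch using a hypothetical ``orders`` table::
+
+ regional_sales = (
+ select(orders.c.region, orders.c.amount)
+ .cte("regional_sales")
+ )
+ regional_sales_2 = regional_sales.alias("regional_sales_2")
+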
+ .. seealso::
+
+ :ref:`tutorial_using_aliases`
+
+ :func:`_expression.alias`
+
+ """
+ return CTE._construct(
+ self.element,
+ name=name,
+ recursive=self.recursive,
+ nesting=self.nesting,
+ _cte_alias=self,
+ _prefixes=self._prefixes,
+ _suffixes=self._suffixes,
+ )
+
+ def union(self, *other):
+ r"""Return a new :class:`_expression.CTE` with a SQL ``UNION``
+ of the original CTE against the given selectables provided
+ as positional arguments.
+
+ :param \*other: one or more elements with which to create a
+ UNION.
+
+ .. versionchanged:: 1.4.28 multiple elements are now accepted.
+
+ .. seealso::
+
+ :meth:`_sql.HasCTE.cte` - examples of calling styles
+
+ """
+ return CTE._construct(
+ self.element.union(*other),
+ name=self.name,
+ recursive=self.recursive,
+ nesting=self.nesting,
+ _restates=self,
+ _prefixes=self._prefixes,
+ _suffixes=self._suffixes,
+ )
+
+ def union_all(self, *other):
+ r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL``
+ of the original CTE against the given selectables provided
+ as positional arguments.
+
+ :param \*other: one or more elements with which to create a
+ UNION.
+
+ .. versionchanged:: 1.4.28 multiple elements are now accepted.
+
+ .. seealso::
+
+ :meth:`_sql.HasCTE.cte` - examples of calling styles
+
+ """
+ return CTE._construct(
+ self.element.union_all(*other),
+ name=self.name,
+ recursive=self.recursive,
+ nesting=self.nesting,
+ _restates=self,
+ _prefixes=self._prefixes,
+ _suffixes=self._suffixes,
+ )
+
+ def _get_reference_cte(self):
+ """
+ A recursive CTE is updated to attach the recursive part.
+ Updated CTEs should still refer to the original CTE.
+ This function returns this reference identifier.
+ """
+ return self._restates if self._restates is not None else self
+
+
+class HasCTE(roles.HasCTERole):
+ """Mixin that declares a class to include CTE support.
+
+ .. versionadded:: 1.1
+
+ """
+
+ _has_ctes_traverse_internals = [
+ ("_independent_ctes", InternalTraversal.dp_clauseelement_list),
+ ]
+
+ _independent_ctes = ()
+
+ @_generative
+ def add_cte(self, cte):
+ """Add a :class:`_sql.CTE` to this statement object that will be
+ independently rendered even if not referenced in the statement
+ otherwise.
+
+ This feature is useful for the use case of embedding a DML statement
+ such as an INSERT or UPDATE as a CTE inline with a primary statement
+ that may draw from its results indirectly; while PostgreSQL is known
+ to support this usage, it may not be supported by other backends.
+
+ E.g.::
+
+ from sqlalchemy import table, column, select
+ t = table('t', column('c1'), column('c2'))
+
+ ins = t.insert().values({"c1": "x", "c2": "y"}).cte()
+
+ stmt = select(t).add_cte(ins)
+
+ Would render::
+
+ WITH anon_1 AS
+ (INSERT INTO t (c1, c2) VALUES (:param_1, :param_2))
+ SELECT t.c1, t.c2
+ FROM t
+
+ Above, the "anon_1" CTE is not referred towards in the SELECT
+ statement, however still accomplishes the task of running an INSERT
+ statement.
+
+ Similarly in a DML-related context, using the PostgreSQL
+ :class:`_postgresql.Insert` construct to generate an "upsert"::
+
+ from sqlalchemy import table, column
+ from sqlalchemy.dialects.postgresql import insert
+
+ t = table("t", column("c1"), column("c2"))
+
+ delete_statement_cte = (
+ t.delete().where(t.c.c1 < 1).cte("deletions")
+ )
+
+ insert_stmt = insert(t).values({"c1": 1, "c2": 2})
+ update_statement = insert_stmt.on_conflict_do_update(
+ index_elements=[t.c.c1],
+ set_={
+ "c1": insert_stmt.excluded.c1,
+ "c2": insert_stmt.excluded.c2,
+ },
+ ).add_cte(delete_statement_cte)
+
+ print(update_statement)
+
+ The above statement renders as::
+
+ WITH deletions AS
+ (DELETE FROM t WHERE t.c1 < %(c1_1)s)
+ INSERT INTO t (c1, c2) VALUES (%(c1)s, %(c2)s)
+ ON CONFLICT (c1) DO UPDATE SET c1 = excluded.c1, c2 = excluded.c2
+
+ .. versionadded:: 1.4.21
+
+ """
+ cte = coercions.expect(roles.IsCTERole, cte)
+ self._independent_ctes += (cte,)
+
+ def cte(self, name=None, recursive=False, nesting=False):
+ r"""Return a new :class:`_expression.CTE`,
+ or Common Table Expression instance.
+
+ Common table expressions are a SQL standard whereby SELECT
+ statements can draw upon secondary statements specified along
+ with the primary statement, using a clause called "WITH".
+ Special semantics regarding UNION can also be employed to
+ allow "recursive" queries, where a SELECT statement can draw
+ upon the set of rows that have previously been selected.
+
+ CTEs can also be applied to DML constructs UPDATE, INSERT
+ and DELETE on some databases, both as a source of CTE rows
+ when combined with RETURNING, as well as a consumer of
+ CTE rows.
+
+ .. versionchanged:: 1.1 Added support for UPDATE/INSERT/DELETE as
+ CTE, CTEs added to UPDATE/INSERT/DELETE.
+
+ SQLAlchemy detects :class:`_expression.CTE` objects, which are treated
+ similarly to :class:`_expression.Alias` objects, as special elements
+ to be delivered to the FROM clause of the statement as well
+ as to a WITH clause at the top of the statement.
+
+ For special prefixes such as PostgreSQL "MATERIALIZED" and
+ "NOT MATERIALIZED", the :meth:`_expression.CTE.prefix_with`
+ method may be
+ used to establish these.
+
+ .. versionchanged:: 1.3.13 Added support for prefixes.
+ In particular - MATERIALIZED and NOT MATERIALIZED.
+
+ :param name: name given to the common table expression. Like
+ :meth:`_expression.FromClause.alias`, the name can be left as
+ ``None`` in which case an anonymous symbol will be used at query
+ compile time.
+ :param recursive: if ``True``, will render ``WITH RECURSIVE``.
+ A recursive common table expression is intended to be used in
+ conjunction with UNION ALL in order to derive rows
+ from those already selected.
+ :param nesting: if ``True``, will render the CTE locally to the
+ actual statement.
+
+ .. versionadded:: 1.4.24
+
+ The following examples include two from PostgreSQL's documentation at
+ https://www.postgresql.org/docs/current/static/queries-with.html,
+ as well as additional examples.
+
+ Example 1, non recursive::
+
+ from sqlalchemy import (Table, Column, String, Integer,
+ MetaData, select, func)
+
+ metadata = MetaData()
+
+ orders = Table('orders', metadata,
+ Column('region', String),
+ Column('amount', Integer),
+ Column('product', String),
+ Column('quantity', Integer)
+ )
+
+ regional_sales = select(
+ orders.c.region,
+ func.sum(orders.c.amount).label('total_sales')
+ ).group_by(orders.c.region).cte("regional_sales")
+
+
+ top_regions = select(regional_sales.c.region).\
+ where(
+ regional_sales.c.total_sales >
+ select(
+ func.sum(regional_sales.c.total_sales) / 10
+ )
+ ).cte("top_regions")
+
+ statement = select(
+ orders.c.region,
+ orders.c.product,
+ func.sum(orders.c.quantity).label("product_units"),
+ func.sum(orders.c.amount).label("product_sales")
+ ).where(orders.c.region.in_(
+ select(top_regions.c.region)
+ )).group_by(orders.c.region, orders.c.product)
+
+ result = conn.execute(statement).fetchall()
+
+ Example 2, WITH RECURSIVE::
+
+ from sqlalchemy import (Table, Column, String, Integer,
+ MetaData, select, func)
+
+ metadata = MetaData()
+
+ parts = Table('parts', metadata,
+ Column('part', String),
+ Column('sub_part', String),
+ Column('quantity', Integer),
+ )
+
+ included_parts = select(\
+ parts.c.sub_part, parts.c.part, parts.c.quantity\
+ ).\
+ where(parts.c.part=='our part').\
+ cte(recursive=True)
+
+
+ incl_alias = included_parts.alias()
+ parts_alias = parts.alias()
+ included_parts = included_parts.union_all(
+ select(
+ parts_alias.c.sub_part,
+ parts_alias.c.part,
+ parts_alias.c.quantity
+ ).\
+ where(parts_alias.c.part==incl_alias.c.sub_part)
+ )
+
+ statement = select(
+ included_parts.c.sub_part,
+ func.sum(included_parts.c.quantity).
+ label('total_quantity')
+ ).\
+ group_by(included_parts.c.sub_part)
+
+ result = conn.execute(statement).fetchall()
+
+ Example 3, an upsert using UPDATE and INSERT with CTEs::
+
+ from datetime import date
+ from sqlalchemy import (MetaData, Table, Column, Integer,
+ Date, select, literal, and_, exists)
+
+ metadata = MetaData()
+
+ visitors = Table('visitors', metadata,
+ Column('product_id', Integer, primary_key=True),
+ Column('date', Date, primary_key=True),
+ Column('count', Integer),
+ )
+
+ # add 5 visitors for the product_id == 1
+ product_id = 1
+ day = date.today()
+ count = 5
+
+ update_cte = (
+ visitors.update()
+ .where(and_(visitors.c.product_id == product_id,
+ visitors.c.date == day))
+ .values(count=visitors.c.count + count)
+ .returning(literal(1))
+ .cte('update_cte')
+ )
+
+ upsert = visitors.insert().from_select(
+ [visitors.c.product_id, visitors.c.date, visitors.c.count],
+ select(literal(product_id), literal(day), literal(count))
+ .where(~exists(update_cte.select()))
+ )
+
+ connection.execute(upsert)
+
+ Example 4, Nesting CTE (SQLAlchemy 1.4.24 and above)::
+
+ value_a = select(
+ literal("root").label("n")
+ ).cte("value_a")
+
+ # A nested CTE with the same name as the root one
+ value_a_nested = select(
+ literal("nesting").label("n")
+ ).cte("value_a", nesting=True)
+
+ # Nesting CTEs take precedence locally
+ # over the CTEs at a higher level
+ value_b = select(value_a_nested.c.n).cte("value_b")
+
+ value_ab = select(value_a.c.n.label("a"), value_b.c.n.label("b"))
+
+ The above query will render the second CTE nested inside the first,
+ shown with inline parameters below as::
+
+ WITH
+ value_a AS
+ (SELECT 'root' AS n),
+ value_b AS
+ (WITH value_a AS
+ (SELECT 'nesting' AS n)
+ SELECT value_a.n AS n FROM value_a)
+ SELECT value_a.n AS a, value_b.n AS b
+ FROM value_a, value_b
+
+ Example 5, Non-Linear CTE (SQLAlchemy 1.4.28 and above)::
+
+ edge = Table(
+ "edge",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("left", Integer),
+ Column("right", Integer),
+ )
+
+ root_node = select(literal(1).label("node")).cte(
+ "nodes", recursive=True
+ )
+
+ left_edge = select(edge.c.left).join(
+ root_node, edge.c.right == root_node.c.node
+ )
+ right_edge = select(edge.c.right).join(
+ root_node, edge.c.left == root_node.c.node
+ )
+
+ subgraph_cte = root_node.union(left_edge, right_edge)
+
+ subgraph = select(subgraph_cte)
+
+ The above query will render 2 UNIONs inside the recursive CTE::
+
+ WITH RECURSIVE nodes(node) AS (
+ SELECT 1 AS node
+ UNION
+ SELECT edge."left" AS "left"
+ FROM edge JOIN nodes ON edge."right" = nodes.node
+ UNION
+ SELECT edge."right" AS "right"
+ FROM edge JOIN nodes ON edge."left" = nodes.node
+ )
+ SELECT nodes.node FROM nodes
+
+ .. seealso::
+
+ :meth:`_orm.Query.cte` - ORM version of
+ :meth:`_expression.HasCTE.cte`.
+
+ """
+ return CTE._construct(
+ self, name=name, recursive=recursive, nesting=nesting
+ )
+
+
+class Subquery(AliasedReturnsRows):
+ """Represent a subquery of a SELECT.
+
+ A :class:`.Subquery` is created by invoking the
+ :meth:`_expression.SelectBase.subquery` method, or for convenience the
+ :meth:`_expression.SelectBase.alias` method, on any
+ :class:`_expression.SelectBase` subclass
+ which includes :class:`_expression.Select`,
+ :class:`_expression.CompoundSelect`, and
+ :class:`_expression.TextualSelect`. As rendered in a FROM clause,
+ it represents the
+ body of the SELECT statement inside of parentheses, followed by the usual
+ "AS <somename>" that defines all "alias" objects.
+
+ The :class:`.Subquery` object is very similar to the
+ :class:`_expression.Alias`
+ object and can be used in an equivalent way. The difference between
+ :class:`_expression.Alias` and :class:`.Subquery` is that
+ :class:`_expression.Alias` always
+ contains a :class:`_expression.FromClause` object whereas
+ :class:`.Subquery`
+ always contains a :class:`_expression.SelectBase` object.
+
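+ E.g., a minimal sketch, assuming a hypothetical ``user`` table::
+
+ from sqlalchemy import select
+
+ # "user" is a hypothetical table construct
+ subq = select(user.c.id, user.c.name).where(user.c.id > 5).subquery()
+
+ # the subquery's exported columns are available via .c
+ stmt = select(subq.c.name)
+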
+ .. versionadded:: 1.4 The :class:`.Subquery` class was added which now
+ serves the purpose of providing an aliased version of a SELECT
+ statement.
+
+ """
+
+ __visit_name__ = "subquery"
+
+ _is_subquery = True
+
+ inherit_cache = True
+
+ @classmethod
+ def _factory(cls, selectable, name=None):
+ """Return a :class:`.Subquery` object."""
+ return coercions.expect(
+ roles.SelectStatementRole, selectable
+ ).subquery(name=name)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`.Subquery.as_scalar` method, which was previously "
+ "``Alias.as_scalar()`` prior to version 1.4, is deprecated and "
+ "will be removed in a future release; Please use the "
+ ":meth:`_expression.Select.scalar_subquery` method of the "
+ ":func:`_expression.select` "
+ "construct before constructing a subquery object, or with the ORM "
+ "use the :meth:`_query.Query.scalar_subquery` method.",
+ )
+ def as_scalar(self):
+ return self.element.set_label_style(LABEL_STYLE_NONE).scalar_subquery()
+
+ def _execute_on_connection(
+ self,
+ connection,
+ multiparams,
+ params,
+ execution_options,
+ ):
+ util.warn_deprecated(
+ "Executing a subquery object is deprecated and will raise "
+ "ObjectNotExecutableError in an upcoming release. Please "
+ "execute the underlying select() statement directly.",
+ "1.4",
+ )
+ return self.element._execute_on_connection(
+ connection, multiparams, params, execution_options, _force=True
+ )
+
+
+class FromGrouping(GroupedElement, FromClause):
+ """Represent a grouping of a FROM clause"""
+
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+
+ def __init__(self, element):
+ self.element = coercions.expect(roles.FromClauseRole, element)
+
+ def _init_collections(self):
+ pass
+
+ @property
+ def columns(self):
+ return self.element.columns
+
+ @property
+ def primary_key(self):
+ return self.element.primary_key
+
+ @property
+ def foreign_keys(self):
+ return self.element.foreign_keys
+
+ def is_derived_from(self, element):
+ return self.element.is_derived_from(element)
+
+ def alias(self, **kw):
+ return FromGrouping(self.element.alias(**kw))
+
+ def _anonymous_fromclause(self, **kw):
+ return FromGrouping(self.element._anonymous_fromclause(**kw))
+
+ @property
+ def _hide_froms(self):
+ return self.element._hide_froms
+
+ @property
+ def _from_objects(self):
+ return self.element._from_objects
+
+ def __getstate__(self):
+ return {"element": self.element}
+
+ def __setstate__(self, state):
+ self.element = state["element"]
+
+
+class TableClause(roles.DMLTableRole, Immutable, FromClause):
+ """Represents a minimal "table" construct.
+
+ This is a lightweight table object that has only a name, a
+ collection of columns, which are typically produced
+ by the :func:`_expression.column` function, and a schema::
+
+ from sqlalchemy import table, column
+
+ user = table("user",
+ column("id"),
+ column("name"),
+ column("description"),
+ )
+
+ The :class:`_expression.TableClause` construct serves as the base for
+ the more commonly used :class:`_schema.Table` object, providing
+ the usual set of :class:`_expression.FromClause` services including
+ the ``.c.`` collection and statement generation methods.
+
+ It does **not** provide all the additional schema-level services
+ of :class:`_schema.Table`, including constraints, references to other
+ tables, or support for :class:`_schema.MetaData`-level services.
+ It's useful
+ on its own as an ad-hoc construct used to generate quick SQL
+ statements when a more fully fledged :class:`_schema.Table`
+ is not on hand.
+
+ """
+
+ __visit_name__ = "table"
+
+ _traverse_internals = [
+ (
+ "columns",
+ InternalTraversal.dp_fromclause_canonical_column_collection,
+ ),
+ ("name", InternalTraversal.dp_string),
+ ]
+
+ named_with_column = True
+
+ implicit_returning = False
+ """:class:`_expression.TableClause`
+ doesn't support having a primary key or
+ column-level defaults, so implicit returning doesn't apply."""
+
+ _autoincrement_column = None
+ """No PK or default support so no autoincrement column."""
+
+ def __init__(self, name, *columns, **kw):
+ """Produce a new :class:`_expression.TableClause`.
+
+ The object returned is an instance of
+ :class:`_expression.TableClause`, which
+ represents the "syntactical" portion of the schema-level
+ :class:`_schema.Table` object.
+ It may be used to construct lightweight table constructs.
+
+ .. versionchanged:: 1.0.0 :func:`_expression.table` can now
+ be imported from the plain ``sqlalchemy`` namespace like any
+ other SQL element.
+
+
+ :param name: Name of the table.
+
+ :param columns: A collection of :func:`_expression.column` constructs.
+
+ :param schema: The schema name for this table.
+
+ .. versionadded:: 1.3.18 :func:`_expression.table` can now
+ accept a ``schema`` argument.
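+
+ E.g., a minimal sketch of the ``schema`` argument (names are
+ hypothetical)::
+
+ from sqlalchemy import table, column
+
+ remote_user = table(
+ "user",
+ column("id"),
+ column("name"),
+ schema="remote_schema",
+ )
+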
+ """
+
+ super(TableClause, self).__init__()
+ self.name = name
+ self._columns = DedupeColumnCollection()
+ self.primary_key = ColumnSet()
+ self.foreign_keys = set()
+ for c in columns:
+ self.append_column(c)
+
+ schema = kw.pop("schema", None)
+ if schema is not None:
+ self.schema = schema
+ if self.schema is not None:
+ self.fullname = "%s.%s" % (self.schema, self.name)
+ else:
+ self.fullname = self.name
+ if kw:
+ raise exc.ArgumentError("Unsupported argument(s): %s" % list(kw))
+
+ def __str__(self):
+ if self.schema is not None:
+ return self.schema + "." + self.name
+ else:
+ return self.name
+
+ def _refresh_for_new_column(self, column):
+ pass
+
+ def _init_collections(self):
+ pass
+
+ @util.memoized_property
+ def description(self):
+ if util.py3k:
+ return self.name
+ else:
+ return self.name.encode("ascii", "backslashreplace")
+
+ def append_column(self, c, **kw):
+ existing = c.table
+ if existing is not None and existing is not self:
+ raise exc.ArgumentError(
+ "column object '%s' already assigned to table '%s'"
+ % (c.key, existing)
+ )
+
+ self._columns.add(c)
+ c.table = self
+
+ @util.preload_module("sqlalchemy.sql.dml")
+ def insert(self, values=None, inline=False, **kwargs):
+ """Generate an :func:`_expression.insert` construct against this
+ :class:`_expression.TableClause`.
+
+ E.g.::
+
+ table.insert().values(name='foo')
+
+ See :func:`_expression.insert` for argument and usage information.
+
+ """
+ return util.preloaded.sql_dml.Insert(
+ self, values=values, inline=inline, **kwargs
+ )
+
+ @util.preload_module("sqlalchemy.sql.dml")
+ def update(self, whereclause=None, values=None, inline=False, **kwargs):
+ """Generate an :func:`_expression.update` construct against this
+ :class:`_expression.TableClause`.
+
+ E.g.::
+
+ table.update().where(table.c.id==7).values(name='foo')
+
+ See :func:`_expression.update` for argument and usage information.
+
+ """
+ return util.preloaded.sql_dml.Update(
+ self,
+ whereclause=whereclause,
+ values=values,
+ inline=inline,
+ **kwargs
+ )
+
+ @util.preload_module("sqlalchemy.sql.dml")
+ def delete(self, whereclause=None, **kwargs):
+ """Generate a :func:`_expression.delete` construct against this
+ :class:`_expression.TableClause`.
+
+ E.g.::
+
+ table.delete().where(table.c.id==7)
+
+ See :func:`_expression.delete` for argument and usage information.
+
+ """
+ return util.preloaded.sql_dml.Delete(self, whereclause, **kwargs)
+
+ @property
+ def _from_objects(self):
+ return [self]
+
+
+class ForUpdateArg(ClauseElement):
+ _traverse_internals = [
+ ("of", InternalTraversal.dp_clauseelement_list),
+ ("nowait", InternalTraversal.dp_boolean),
+ ("read", InternalTraversal.dp_boolean),
+ ("skip_locked", InternalTraversal.dp_boolean),
+ ]
+
+ @classmethod
+ def _from_argument(cls, with_for_update):
+ if isinstance(with_for_update, ForUpdateArg):
+ return with_for_update
+ elif with_for_update in (None, False):
+ return None
+ elif with_for_update is True:
+ return ForUpdateArg()
+ else:
+ return ForUpdateArg(**with_for_update)
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, ForUpdateArg)
+ and other.nowait == self.nowait
+ and other.read == self.read
+ and other.skip_locked == self.skip_locked
+ and other.key_share == self.key_share
+ and other.of is self.of
+ )
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __hash__(self):
+ return id(self)
+
+ def __init__(
+ self,
+ nowait=False,
+ read=False,
+ of=None,
+ skip_locked=False,
+ key_share=False,
+ ):
+ """Represents arguments specified to
+ :meth:`_expression.Select.for_update`.
+
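+ E.g., a sketch of how these flags are typically supplied, via the
+ :meth:`_expression.GenerativeSelect.with_for_update` method
+ (``user_table`` being a hypothetical table)::
+
+ stmt = select(user_table).with_for_update(nowait=True, of=user_table)
+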
+ """
+
+ self.nowait = nowait
+ self.read = read
+ self.skip_locked = skip_locked
+ self.key_share = key_share
+ if of is not None:
+ self.of = [
+ coercions.expect(roles.ColumnsClauseRole, elem)
+ for elem in util.to_list(of)
+ ]
+ else:
+ self.of = None
+
+
+class Values(Generative, FromClause):
+ """Represent a ``VALUES`` construct that can be used as a FROM element
+ in a statement.
+
+ The :class:`_expression.Values` object is created from the
+ :func:`_expression.values` function.
+
+ .. versionadded:: 1.4
+
+ """
+
+ named_with_column = True
+ __visit_name__ = "values"
+
+ _data = ()
+
+ _traverse_internals = [
+ ("_column_args", InternalTraversal.dp_clauseelement_list),
+ ("_data", InternalTraversal.dp_dml_multi_values),
+ ("name", InternalTraversal.dp_string),
+ ("literal_binds", InternalTraversal.dp_boolean),
+ ]
+
+ def __init__(self, *columns, **kw):
+ r"""Construct a :class:`_expression.Values` construct.
+
+ The column expressions and the actual data for
+ :class:`_expression.Values` are given in two separate steps. The
+ constructor receives the column expressions typically as
+ :func:`_expression.column` constructs,
+ and the data is then passed via the
+ :meth:`_expression.Values.data` method as a list,
+ which can be called multiple
+ times to add more data, e.g.::
+
+ from sqlalchemy import column
+ from sqlalchemy import values
+
+ value_expr = values(
+ column('id', Integer),
+ column('name', String),
+ name="my_values"
+ ).data(
+ [(1, 'name1'), (2, 'name2'), (3, 'name3')]
+ )
+
+ :param \*columns: column expressions, typically composed using
+ :func:`_expression.column` objects.
+
+ :param name: the name for this VALUES construct. If omitted, the
+ VALUES construct will be unnamed in a SQL expression. Different
+ backends may have different requirements here.
+
+ :param literal_binds: Defaults to False. Whether or not to render
+ the data values inline in the SQL output, rather than using bound
+ parameters.
+
+ """
+
+ super(Values, self).__init__()
+ self._column_args = columns
+ self.name = kw.pop("name", None)
+ self.literal_binds = kw.pop("literal_binds", False)
+ self.named_with_column = self.name is not None
+
+ @property
+ def _column_types(self):
+ return [col.type for col in self._column_args]
+
+ @_generative
+ def alias(self, name, **kw):
+ """Return a new :class:`_expression.Values`
+ construct that is a copy of this
+ one with the given name.
+
+ This method is a VALUES-specific specialization of the
+ :meth:`_expression.FromClause.alias` method.
+
+ .. seealso::
+
+ :ref:`tutorial_using_aliases`
+
+ :func:`_expression.alias`
+
+ """
+ self.name = name
+ self.named_with_column = self.name is not None
+
+ @_generative
+ def lateral(self, name=None):
+ """Return a new :class:`_expression.Values` with the lateral flag set,
+ so that
+ it renders as LATERAL.
+
+ .. seealso::
+
+ :func:`_expression.lateral`
+
+ """
+ self._is_lateral = True
+ if name is not None:
+ self.name = name
+
+ @_generative
+ def data(self, values):
+ """Return a new :class:`_expression.Values` construct,
+ adding the given data
+ to the data list.
+
+ E.g.::
+
+ my_values = my_values.data([(1, 'value 1'), (2, 'value2')])
+
+ :param values: a sequence (i.e. list) of tuples that map to the
+ column expressions given in the :class:`_expression.Values`
+ constructor.
+
+ """
+
+ self._data += (values,)
+
+ def _populate_column_collection(self):
+ for c in self._column_args:
+ self._columns.add(c)
+ c.table = self
+
+ @property
+ def _from_objects(self):
+ return [self]
+
+
+class SelectBase(
+ roles.SelectStatementRole,
+ roles.DMLSelectRole,
+ roles.CompoundElementRole,
+ roles.InElementRole,
+ HasCTE,
+ Executable,
+ SupportsCloneAnnotations,
+ Selectable,
+):
+ """Base class for SELECT statements.
+
+
+ This includes :class:`_expression.Select`,
+ :class:`_expression.CompoundSelect` and
+ :class:`_expression.TextualSelect`.
+
+
+ """
+
+ _is_select_statement = True
+ is_select = True
+
+ def _generate_fromclause_column_proxies(self, fromclause):
+ raise NotImplementedError()
+
+ def _refresh_for_new_column(self, column):
+ self._reset_memoizations()
+
+ @property
+ def selected_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ representing the columns that
+ this SELECT statement or similar construct returns in its result set.
+
+ This collection differs from the :attr:`_expression.FromClause.columns`
+ collection of a :class:`_expression.FromClause` in that the columns
+ within this collection cannot be directly nested inside another SELECT
+ statement; a subquery must be applied first which provides for the
+ necessary parenthesization required by SQL.
+
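+ For example, a hypothetical helper that filters a :class:`_sql.Select`
+ statement by one of its own columns (the ``"a"`` key is assumed to
+ exist in the statement being passed in)::
+
+ def filter_on_a(my_select, value):
+ return my_select.where(my_select.selected_columns["a"] == value)
+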
+ .. note::
+
+ The :attr:`_sql.SelectBase.selected_columns` collection does not
+ include expressions established in the columns clause using the
+ :func:`_sql.text` construct; these are silently omitted from the
+ collection. To use plain textual column expressions inside of a
+ :class:`_sql.Select` construct, use the :func:`_sql.literal_column`
+ construct.
+
+ .. seealso::
+
+ :attr:`_sql.Select.selected_columns`
+
+ .. versionadded:: 1.4
+
+ """
+ raise NotImplementedError()
+
+ @property
+ def _all_selected_columns(self):
+ """A sequence of expressions that correspond to what is rendered
+ in the columns clause, including :class:`_sql.TextClause`
+ constructs.
+
+ .. versionadded:: 1.4.12
+
+ .. seealso::
+
+ :attr:`_sql.SelectBase.exported_columns`
+
+ """
+ raise NotImplementedError()
+
+ @property
+ def exported_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ that represents the "exported"
+ columns of this :class:`_expression.Selectable`, not including
+ :class:`_sql.TextClause` constructs.
+
+ The "exported" columns for a :class:`_expression.SelectBase`
+ object are synonymous
+ with the :attr:`_expression.SelectBase.selected_columns` collection.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_expression.Select.exported_columns`
+
+ :attr:`_expression.Selectable.exported_columns`
+
+ :attr:`_expression.FromClause.exported_columns`
+
+
+ """
+ return self.selected_columns
+
+ @property
+ @util.deprecated(
+ "1.4",
+ "The :attr:`_expression.SelectBase.c` and "
+ ":attr:`_expression.SelectBase.columns` attributes "
+ "are deprecated and will be removed in a future release; these "
+ "attributes implicitly create a subquery that should be explicit. "
+ "Please call :meth:`_expression.SelectBase.subquery` "
+ "first in order to create "
+ "a subquery, which then contains this attribute. To access the "
+ "columns that this SELECT object SELECTs "
+ "from, use the :attr:`_expression.SelectBase.selected_columns` "
+ "attribute.",
+ )
+ def c(self):
+ return self._implicit_subquery.columns
+
+ @property
+ def columns(self):
+ return self.c
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.SelectBase.select` method is deprecated "
+ "and will be removed in a future release; this method implicitly "
+ "creates a subquery that should be explicit. "
+ "Please call :meth:`_expression.SelectBase.subquery` "
+ "first in order to create "
+ "a subquery, which then can be selected.",
+ )
+ def select(self, *arg, **kw):
+ return self._implicit_subquery.select(*arg, **kw)
+
+ @HasMemoized.memoized_attribute
+ def _implicit_subquery(self):
+ return self.subquery()
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.SelectBase.as_scalar` "
+ "method is deprecated and will be "
+ "removed in a future release. Please refer to "
+ ":meth:`_expression.SelectBase.scalar_subquery`.",
+ )
+ def as_scalar(self):
+ return self.scalar_subquery()
+
+ def exists(self):
+ """Return an :class:`_sql.Exists` representation of this selectable,
+ which can be used as a column expression.
+
+ The returned object is an instance of :class:`_sql.Exists`.
+
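+ For example, a minimal sketch of a correlated EXISTS, assuming
+ hypothetical ``user_table`` and ``address_table``
+ :class:`_schema.Table` objects::
+
+ # user_table / address_table are hypothetical Table objects
+ subq = select(address_table.c.id).where(address_table.c.user_id == user_table.c.id)
+ stmt = select(user_table.c.name).where(subq.exists())
+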
+ .. seealso::
+
+ :func:`_sql.exists`
+
+ :ref:`tutorial_exists` - in the :term:`2.0 style` tutorial.
+
+ .. versionadded:: 1.4
+
+ """
+ return Exists(self)
+
+ def scalar_subquery(self):
+ """Return a 'scalar' representation of this selectable, which can be
+ used as a column expression.
+
+ The returned object is an instance of :class:`_sql.ScalarSelect`.
+
+ Typically, a select statement which has only one column in its columns
+ clause is eligible to be used as a scalar expression. The scalar
+ subquery can then be used in the WHERE clause or columns clause of
+ an enclosing SELECT.
+
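+ For example, a minimal sketch, assuming a hypothetical ``user_table``
+ :class:`_schema.Table`::
+
+ # user_table is a hypothetical Table with "id" and "name" columns
+ max_id = select(func.max(user_table.c.id)).scalar_subquery()
+ stmt = select(user_table.c.name).where(user_table.c.id == max_id)
+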
+ Note that the scalar subquery differs from the FROM-level
+ subquery that can be produced using the
+ :meth:`_expression.SelectBase.subquery`
+ method.
+
+ .. versionchanged:: 1.4 - the ``.as_scalar()`` method was renamed to
+ :meth:`_expression.SelectBase.scalar_subquery`.
+
+ .. seealso::
+
+ :ref:`tutorial_scalar_subquery` - in the 2.0 tutorial
+
+ """
+ if self._label_style is not LABEL_STYLE_NONE:
+ self = self.set_label_style(LABEL_STYLE_NONE)
+
+ return ScalarSelect(self)
+
+ def label(self, name):
+ """Return a 'scalar' representation of this selectable, embedded as a
+ subquery with a label.
+
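+ E.g., a minimal sketch assuming a hypothetical ``user_table``::
+
+ # user_table is a hypothetical Table
+ user_count = select(func.count(user_table.c.id)).label("user_count")
+ stmt = select(user_table.c.name, user_count)
+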
+ .. seealso::
+
+ :meth:`_expression.SelectBase.scalar_subquery`.
+
+ """
+ return self.scalar_subquery().label(name)
+
+ def lateral(self, name=None):
+ """Return a LATERAL alias of this :class:`_expression.Selectable`.
+
+ The return value is the :class:`_expression.Lateral` construct also
+ provided by the top-level :func:`_expression.lateral` function.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`tutorial_lateral_correlation` - overview of usage.
+
+ """
+ return Lateral._factory(self, name)
+
+ @property
+ def _from_objects(self):
+ return [self]
+
+ def subquery(self, name=None):
+ """Return a subquery of this :class:`_expression.SelectBase`.
+
+ A subquery is from a SQL perspective a parenthesized, named
+ construct that can be placed in the FROM clause of another
+ SELECT statement.
+
+ Given a SELECT statement such as::
+
+ stmt = select(table.c.id, table.c.name)
+
+ The above statement might look like::
+
+ SELECT table.id, table.name FROM table
+
+ The subquery form by itself renders the same way; however, when
+ embedded into the FROM clause of another SELECT statement, it becomes
+ a named sub-element::
+
+ subq = stmt.subquery()
+ new_stmt = select(subq)
+
+ The above renders as::
+
+ SELECT anon_1.id, anon_1.name
+ FROM (SELECT table.id, table.name FROM table) AS anon_1
+
+ Historically, :meth:`_expression.SelectBase.subquery`
+ is equivalent to calling
+ the :meth:`_expression.FromClause.alias`
+ method on a FROM object; however,
+ as a :class:`_expression.SelectBase`
+ object is not directly a FROM object,
+ the :meth:`_expression.SelectBase.subquery`
+ method provides clearer semantics.
+
+ .. versionadded:: 1.4
+
+ """
+
+ return Subquery._construct(self._ensure_disambiguated_names(), name)
+
+ def _ensure_disambiguated_names(self):
+ """Ensure that the names generated by this selectbase will be
+ disambiguated in some way, if possible.
+
+ """
+
+ raise NotImplementedError()
+
+ def alias(self, name=None, flat=False):
+ """Return a named subquery against this
+ :class:`_expression.SelectBase`.
+
+ For a :class:`_expression.SelectBase` (as opposed to a
+ :class:`_expression.FromClause`),
+ this returns a :class:`.Subquery` object which behaves mostly the
+ same as the :class:`_expression.Alias` object that is used with a
+ :class:`_expression.FromClause`.
+
+ .. versionchanged:: 1.4 The :meth:`_expression.SelectBase.alias`
+ method is now
+ a synonym for the :meth:`_expression.SelectBase.subquery` method.
+
+ """
+ return self.subquery(name=name)
+
+
+class SelectStatementGrouping(GroupedElement, SelectBase):
+ """Represent a grouping of a :class:`_expression.SelectBase`.
+
+ This differs from :class:`.Subquery` in that we are still
+ an "inner" SELECT statement, this is strictly for grouping inside of
+ compound selects.
+
+ """
+
+ __visit_name__ = "select_statement_grouping"
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+
+ _is_select_container = True
+
+ def __init__(self, element):
+ self.element = coercions.expect(roles.SelectStatementRole, element)
+
+ def _ensure_disambiguated_names(self):
+ new_element = self.element._ensure_disambiguated_names()
+ if new_element is not self.element:
+ return SelectStatementGrouping(new_element)
+ else:
+ return self
+
+ def get_label_style(self):
+ return self._label_style
+
+ def set_label_style(self, label_style):
+ return SelectStatementGrouping(
+ self.element.set_label_style(label_style)
+ )
+
+ @property
+ def _label_style(self):
+ return self.element._label_style
+
+ @property
+ def select_statement(self):
+ return self.element
+
+ def self_group(self, against=None):
+ return self
+
+ def _generate_columns_plus_names(self, anon_for_dupe_key):
+ return self.element._generate_columns_plus_names(anon_for_dupe_key)
+
+ def _generate_fromclause_column_proxies(self, subquery):
+ self.element._generate_fromclause_column_proxies(subquery)
+
+ def _generate_proxy_for_new_column(self, column, subquery):
+ return self.element._generate_proxy_for_new_column(column, subquery)
+
+ @property
+ def _all_selected_columns(self):
+ return self.element._all_selected_columns
+
+ @property
+ def selected_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ representing the columns that
+ the embedded SELECT statement returns in its result set, not including
+ :class:`_sql.TextClause` constructs.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_sql.Select.selected_columns`
+
+ """
+ return self.element.selected_columns
+
+ @property
+ def _from_objects(self):
+ return self.element._from_objects
+
+
+class DeprecatedSelectBaseGenerations(object):
+ """A collection of methods available on :class:`_sql.Select` and
+ :class:`_sql.CompoundSelect`; these are all **deprecated** methods, as they
+ modify the object in-place.
+
+ """
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.GenerativeSelect.append_order_by` "
+ "method is deprecated "
+ "and will be removed in a future release. Use the generative method "
+ ":meth:`_expression.GenerativeSelect.order_by`.",
+ )
+ def append_order_by(self, *clauses):
+ """Append the given ORDER BY criterion applied to this selectable.
+
+ The criterion will be appended to any pre-existing ORDER BY criterion.
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.GenerativeSelect.order_by` method is preferred,
+ as it
+ provides standard :term:`method chaining`.
+
+ .. seealso::
+
+ :meth:`_expression.GenerativeSelect.order_by`
+
+ """
+ self.order_by.non_generative(self, *clauses)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.GenerativeSelect.append_group_by` "
+ "method is deprecated "
+ "and will be removed in a future release. Use the generative method "
+ ":meth:`_expression.GenerativeSelect.group_by`.",
+ )
+ def append_group_by(self, *clauses):
+ """Append the given GROUP BY criterion applied to this selectable.
+
+ The criterion will be appended to any pre-existing GROUP BY criterion.
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.GenerativeSelect.group_by` method is preferred,
+ as it
+ provides standard :term:`method chaining`.
+
+
+ """
+ self.group_by.non_generative(self, *clauses)
+
+
+class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
+ """Base class for SELECT statements where additional elements can be
+ added.
+
+ This serves as the base for :class:`_expression.Select` and
+ :class:`_expression.CompoundSelect`
+ where elements such as ORDER BY, GROUP BY can be added and column
+ rendering can be controlled. Compare to
+ :class:`_expression.TextualSelect`, which,
+ while it subclasses :class:`_expression.SelectBase`
+ and is also a SELECT construct,
+ represents a fixed textual string which cannot be altered at this level,
+ only wrapped as a subquery.
+
+ """
+
+ _order_by_clauses = ()
+ _group_by_clauses = ()
+ _limit_clause = None
+ _offset_clause = None
+ _fetch_clause = None
+ _fetch_clause_options = None
+ _for_update_arg = None
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_sql.select.bind` argument is deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(
+ self,
+ _label_style=LABEL_STYLE_DEFAULT,
+ use_labels=False,
+ limit=None,
+ offset=None,
+ order_by=None,
+ group_by=None,
+ bind=None,
+ ):
+ if use_labels:
+ if util.SQLALCHEMY_WARN_20:
+ util.warn_deprecated_20(
+ "The use_labels=True keyword argument to GenerativeSelect "
+ "is deprecated and will be removed in version 2.0. Please "
+ "use "
+ "select.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) "
+ "if you need to replicate this legacy behavior.",
+ stacklevel=4,
+ )
+ _label_style = LABEL_STYLE_TABLENAME_PLUS_COL
+
+ self._label_style = _label_style
+
+ if limit is not None:
+ self.limit.non_generative(self, limit)
+ if offset is not None:
+ self.offset.non_generative(self, offset)
+
+ if order_by is not None:
+ self.order_by.non_generative(self, *util.to_list(order_by))
+ if group_by is not None:
+ self.group_by.non_generative(self, *util.to_list(group_by))
+
+ self._bind = bind
+
+ @_generative
+ def with_for_update(
+ self,
+ nowait=False,
+ read=False,
+ of=None,
+ skip_locked=False,
+ key_share=False,
+ ):
+ """Specify a ``FOR UPDATE`` clause for this
+ :class:`_expression.GenerativeSelect`.
+
+ E.g.::
+
+ stmt = select(table).with_for_update(nowait=True)
+
+ On a database like PostgreSQL or Oracle, the above would render a
+ statement like::
+
+ SELECT table.a, table.b FROM table FOR UPDATE NOWAIT
+
+ On other backends, the ``nowait`` option is ignored and the statement
+ instead renders as::
+
+ SELECT table.a, table.b FROM table FOR UPDATE
+
+ When called with no arguments, the statement will render with
+ the suffix ``FOR UPDATE``. Additional arguments can then be
+ provided which allow for common database-specific
+ variants.
+
+ :param nowait: boolean; will render ``FOR UPDATE NOWAIT`` on Oracle
+ and PostgreSQL dialects.
+
+ :param read: boolean; will render ``LOCK IN SHARE MODE`` on MySQL,
+ ``FOR SHARE`` on PostgreSQL. On PostgreSQL, when combined with
+ ``nowait``, will render ``FOR SHARE NOWAIT``.
+
+ :param of: SQL expression or list of SQL expression elements
+ (typically :class:`_schema.Column`
+ objects or a compatible expression) which
+ will render into a ``FOR UPDATE OF`` clause; supported by PostgreSQL
+ and Oracle. May render as a table or as a column depending on
+ backend.
+
+ :param skip_locked: boolean, will render ``FOR UPDATE SKIP LOCKED``
+ on Oracle and PostgreSQL dialects or ``FOR SHARE SKIP LOCKED`` if
+ ``read=True`` is also specified.
+
+ :param key_share: boolean, will render ``FOR NO KEY UPDATE``,
+ or if combined with ``read=True`` will render ``FOR KEY SHARE``,
+ on the PostgreSQL dialect.
+
+ """
+ self._for_update_arg = ForUpdateArg(
+ nowait=nowait,
+ read=read,
+ of=of,
+ skip_locked=skip_locked,
+ key_share=key_share,
+ )
+
+ def get_label_style(self):
+ """
+ Retrieve the current label style.
+
+ .. versionadded:: 1.4
+
+ """
+ return self._label_style
+
+ def set_label_style(self, style):
+ """Return a new selectable with the specified label style.
+
+ There are three "label styles" available,
+ :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`,
+ :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL`, and
+ :data:`_sql.LABEL_STYLE_NONE`. The default style is
+ :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL`.
+
+ In modern SQLAlchemy, there is not generally a need to change the
+ labeling style, as per-expression labels are more effectively used by
+ making use of the :meth:`_sql.ColumnElement.label` method. In past
+ versions, :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL` was used to
+ disambiguate same-named columns from different tables, aliases, or
+ subqueries; the newer :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY` now
+ applies labels only to names that conflict with an existing name so
+ that the impact of this labeling is minimal.
+
+ The rationale for disambiguation is mostly so that all column
+ expressions are available from a given :attr:`_sql.FromClause.c`
+ collection when a subquery is created.
+
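+ For example, a minimal sketch assuming hypothetical ``user_table`` and
+ ``address_table`` :class:`_schema.Table` objects that each have an
+ ``id`` column::
+
+ # both tables are hypothetical; with this style the conflicting "id"
+ # columns become user_table_id / address_table_id in the subquery's .c
+ stmt = select(user_table, address_table).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ subq = stmt.subquery()
+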
+ .. versionadded:: 1.4 - the
+ :meth:`_sql.GenerativeSelect.set_label_style` method replaces the
+ previous combination of ``.apply_labels()``, ``.with_labels()`` and
+ ``use_labels=True`` methods and/or parameters.
+
+ .. seealso::
+
+ :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`
+
+ :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL`
+
+ :data:`_sql.LABEL_STYLE_NONE`
+
+ :data:`_sql.LABEL_STYLE_DEFAULT`
+
+ """
+ if self._label_style is not style:
+ self = self._generate()
+ self._label_style = style
+ return self
+
+ @util.deprecated_20(
+ ":meth:`_sql.GenerativeSelect.apply_labels`",
+ alternative="Use set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) "
+ "instead.",
+ )
+ def apply_labels(self):
+ return self.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+
+ @property
+ def _group_by_clause(self):
+ """ClauseList access to group_by_clauses for legacy dialects"""
+ return ClauseList._construct_raw(
+ operators.comma_op, self._group_by_clauses
+ )
+
+ @property
+ def _order_by_clause(self):
+ """ClauseList access to order_by_clauses for legacy dialects"""
+ return ClauseList._construct_raw(
+ operators.comma_op, self._order_by_clauses
+ )
+
+ def _offset_or_limit_clause(self, element, name=None, type_=None):
+ """Convert the given value to an "offset or limit" clause.
+
+ This handles incoming integers and converts to an expression; if
+ an expression is already given, it is passed through.
+
+ """
+ return coercions.expect(
+ roles.LimitOffsetRole, element, name=name, type_=type_
+ )
+
+ def _offset_or_limit_clause_asint(self, clause, attrname):
+ """Convert the "offset or limit" clause of a select construct to an
+ integer.
+
+ This is only possible if the value is stored as a simple bound
+ parameter. Otherwise, a compilation error is raised.
+
+ """
+ if clause is None:
+ return None
+ try:
+ value = clause._limit_offset_value
+ except AttributeError as err:
+ util.raise_(
+ exc.CompileError(
+ "This SELECT structure does not use a simple "
+ "integer value for %s" % attrname
+ ),
+ replace_context=err,
+ )
+ else:
+ return util.asint(value)
+
+ @property
+ def _limit(self):
+ """Get an integer value for the limit. This should only be used
+ by code that cannot support a limit as a BindParameter or
+ other custom clause as it will throw an exception if the limit
+ isn't currently set to an integer.
+
+ """
+ return self._offset_or_limit_clause_asint(self._limit_clause, "limit")
+
+ def _simple_int_clause(self, clause):
+ """True if the clause is a simple integer, False
+ if it is not present or is a SQL expression.
+ """
+ return isinstance(clause, _OffsetLimitParam)
+
+ @property
+ def _offset(self):
+ """Get an integer value for the offset. This should only be used
+ by code that cannot support an offset as a BindParameter or
+ other custom clause as it will throw an exception if the
+ offset isn't currently set to an integer.
+
+ """
+ return self._offset_or_limit_clause_asint(
+ self._offset_clause, "offset"
+ )
+
+ @property
+ def _has_row_limiting_clause(self):
+ return (
+ self._limit_clause is not None
+ or self._offset_clause is not None
+ or self._fetch_clause is not None
+ )
+
+ @_generative
+ def limit(self, limit):
+ """Return a new selectable with the given LIMIT criterion
+ applied.
+
+ This is a numerical value which usually renders as a ``LIMIT``
+ expression in the resulting select. Backends that don't
+ support ``LIMIT`` will attempt to provide similar
+ functionality.
+
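+ E.g., assuming a hypothetical ``user_table``::
+
+ # user_table is a hypothetical Table
+ stmt = select(user_table).order_by(user_table.c.id).limit(10)
+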
+ .. note::
+
+ The :meth:`_sql.GenerativeSelect.limit` method will replace
+ any clause applied with :meth:`_sql.GenerativeSelect.fetch`.
+
+ .. versionchanged:: 1.0.0 - :meth:`_expression.Select.limit` can now
+ accept arbitrary SQL expressions as well as integer values.
+
+ :param limit: an integer LIMIT parameter, or a SQL expression
+ that provides an integer result. Pass ``None`` to reset it.
+
+ .. seealso::
+
+ :meth:`_sql.GenerativeSelect.fetch`
+
+ :meth:`_sql.GenerativeSelect.offset`
+
+ """
+
+ self._fetch_clause = self._fetch_clause_options = None
+ self._limit_clause = self._offset_or_limit_clause(limit)
+
+ @_generative
+ def fetch(self, count, with_ties=False, percent=False):
+ """Return a new selectable with the given FETCH FIRST criterion
+ applied.
+
+ This is a numeric value which usually renders as a
+ ``FETCH {FIRST | NEXT} [ count ] {ROW | ROWS} {ONLY | WITH TIES}``
+ expression in the resulting select. This functionality is
+ currently implemented for Oracle, PostgreSQL, and MSSQL.
+
+ Use :meth:`_sql.GenerativeSelect.offset` to specify the offset.
+
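+ E.g., assuming a hypothetical ``user_table``::
+
+ # user_table is a hypothetical Table; WITH TIES generally requires
+ # an ORDER BY to be present
+ stmt = select(user_table).order_by(user_table.c.id).fetch(10, with_ties=True)
+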
+ .. note::
+
+ The :meth:`_sql.GenerativeSelect.fetch` method will replace
+ any clause applied with :meth:`_sql.GenerativeSelect.limit`.
+
+ .. versionadded:: 1.4
+
+ :param count: an integer COUNT parameter, or a SQL expression
+ that provides an integer result. When ``percent=True`` this will
+ represent the percentage of rows to return, not the absolute value.
+ Pass ``None`` to reset it.
+
+ :param with_ties: When ``True``, the WITH TIES option is used
+ to return any additional rows that tie for the last place in the
+ result set according to the ``ORDER BY`` clause. The
+ ``ORDER BY`` may be mandatory in this case. Defaults to ``False``.
+
+ :param percent: When ``True``, ``count`` represents the percentage
+ of the total number of selected rows to return. Defaults to ``False``.
+
+ .. seealso::
+
+ :meth:`_sql.GenerativeSelect.limit`
+
+ :meth:`_sql.GenerativeSelect.offset`
+
+ """
+
+ self._limit_clause = None
+ if count is None:
+ self._fetch_clause = self._fetch_clause_options = None
+ else:
+ self._fetch_clause = self._offset_or_limit_clause(count)
+ self._fetch_clause_options = {
+ "with_ties": with_ties,
+ "percent": percent,
+ }
+
+ @_generative
+ def offset(self, offset):
+ """Return a new selectable with the given OFFSET criterion
+ applied.
+
+
+ This is a numeric value which usually renders as an ``OFFSET``
+ expression in the resulting select. Backends that don't
+ support ``OFFSET`` will attempt to provide similar
+ functionality.
+
+
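+ E.g., assuming a hypothetical ``user_table``::
+
+ # user_table is a hypothetical Table
+ stmt = select(user_table).order_by(user_table.c.id).limit(10).offset(20)
+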
+ .. versionchanged:: 1.0.0 - :meth:`_expression.Select.offset` can now
+ accept arbitrary SQL expressions as well as integer values.
+
+ :param offset: an integer OFFSET parameter, or a SQL expression
+ that provides an integer result. Pass ``None`` to reset it.
+
+ .. seealso::
+
+ :meth:`_sql.GenerativeSelect.limit`
+
+ :meth:`_sql.GenerativeSelect.fetch`
+
+ """
+
+ self._offset_clause = self._offset_or_limit_clause(offset)
+
+ @_generative
+ @util.preload_module("sqlalchemy.sql.util")
+ def slice(self, start, stop):
+ """Apply LIMIT / OFFSET to this statement based on a slice.
+
+ The start and stop indices behave like the arguments to Python's
+ built-in :func:`range` function. This method provides an
+ alternative to using ``LIMIT``/``OFFSET`` to get a slice of the
+ query.
+
+ For example, ::
+
+ stmt = select(User).order_by(User.id).slice(1, 3)
+
+ renders as
+
+ .. sourcecode:: sql
+
+ SELECT users.id AS users_id,
+ users.name AS users_name
+ FROM users ORDER BY users.id
+ LIMIT ? OFFSET ?
+ (2, 1)
+
+ .. note::
+
+ The :meth:`_sql.GenerativeSelect.slice` method will replace
+ any clause applied with :meth:`_sql.GenerativeSelect.fetch`.
+
+ .. versionadded:: 1.4 Added the :meth:`_sql.GenerativeSelect.slice`
+ method generalized from the ORM.
+
+ .. seealso::
+
+ :meth:`_sql.GenerativeSelect.limit`
+
+ :meth:`_sql.GenerativeSelect.offset`
+
+ :meth:`_sql.GenerativeSelect.fetch`
+
+ """
+ sql_util = util.preloaded.sql_util
+ self._fetch_clause = self._fetch_clause_options = None
+ self._limit_clause, self._offset_clause = sql_util._make_slice(
+ self._limit_clause, self._offset_clause, start, stop
+ )
+
+ @_generative
+ def order_by(self, *clauses):
+ r"""Return a new selectable with the given list of ORDER BY
+ criteria applied.
+
+ e.g.::
+
+ stmt = select(table).order_by(table.c.id, table.c.name)
+
+ All existing ORDER BY criteria may be cancelled by passing
+ ``None`` by itself. New ORDER BY criteria may then be added by
+ invoking :meth:`_sql.Select.order_by` again, e.g.::
+
+ # will erase all pre-existing ORDER BY and then ORDER BY new_col alone
+ stmt = stmt.order_by(None).order_by(new_col)
+
+ :param \*clauses: a series of :class:`_expression.ColumnElement`
+ constructs
+ which will be used to generate an ORDER BY clause.
+
+ .. seealso::
+
+ :ref:`tutorial_order_by` - in the :ref:`unified_tutorial`
+
+ :ref:`tutorial_order_by_label` - in the :ref:`unified_tutorial`
+
+ """
+
+ if len(clauses) == 1 and clauses[0] is None:
+ self._order_by_clauses = ()
+ else:
+ self._order_by_clauses += tuple(
+ coercions.expect(roles.OrderByRole, clause)
+ for clause in clauses
+ )
+
+ @_generative
+ def group_by(self, *clauses):
+ r"""Return a new selectable with the given list of GROUP BY
+ criteria applied.
+
+ All existing GROUP BY settings can be suppressed by passing ``None``.
+
+ e.g.::
+
+ stmt = select(table.c.name, func.max(table.c.stat)).\
+ group_by(table.c.name)
+
+ :param \*clauses: a series of :class:`_expression.ColumnElement`
+ constructs
+ which will be used to generate a GROUP BY clause.
+
+ .. seealso::
+
+ :ref:`tutorial_group_by_w_aggregates` - in the
+ :ref:`unified_tutorial`
+
+ :ref:`tutorial_order_by_label` - in the :ref:`unified_tutorial`
+
+ """
+
+ if len(clauses) == 1 and clauses[0] is None:
+ self._group_by_clauses = ()
+ else:
+ self._group_by_clauses += tuple(
+ coercions.expect(roles.GroupByRole, clause)
+ for clause in clauses
+ )
+
+
+@CompileState.plugin_for("default", "compound_select")
+class CompoundSelectState(CompileState):
+ @util.memoized_property
+ def _label_resolve_dict(self):
+ # TODO: this is hacky and slow
+ hacky_subquery = self.statement.subquery()
+ hacky_subquery.named_with_column = False
+ d = dict((c.key, c) for c in hacky_subquery.c)
+ return d, d, d
+
+
+class CompoundSelect(HasCompileState, GenerativeSelect):
+ """Forms the basis of ``UNION``, ``UNION ALL``, and other
+ SELECT-based set operations.
+
+
+ .. seealso::
+
+ :func:`_expression.union`
+
+ :func:`_expression.union_all`
+
+ :func:`_expression.intersect`
+
+ :func:`_expression.intersect_all`
+
+ :func:`_expression.except_`
+
+ :func:`_expression.except_all`
+
+ """
+
+ __visit_name__ = "compound_select"
+
+ _traverse_internals = [
+ ("selects", InternalTraversal.dp_clauseelement_list),
+ ("_limit_clause", InternalTraversal.dp_clauseelement),
+ ("_offset_clause", InternalTraversal.dp_clauseelement),
+ ("_fetch_clause", InternalTraversal.dp_clauseelement),
+ ("_fetch_clause_options", InternalTraversal.dp_plain_dict),
+ ("_order_by_clauses", InternalTraversal.dp_clauseelement_list),
+ ("_group_by_clauses", InternalTraversal.dp_clauseelement_list),
+ ("_for_update_arg", InternalTraversal.dp_clauseelement),
+ ("keyword", InternalTraversal.dp_string),
+ ] + SupportsCloneAnnotations._clone_annotations_traverse_internals
+
+ UNION = util.symbol("UNION")
+ UNION_ALL = util.symbol("UNION ALL")
+ EXCEPT = util.symbol("EXCEPT")
+ EXCEPT_ALL = util.symbol("EXCEPT ALL")
+ INTERSECT = util.symbol("INTERSECT")
+ INTERSECT_ALL = util.symbol("INTERSECT ALL")
+
+ _is_from_container = True
+
+ def __init__(self, keyword, *selects, **kwargs):
+ self._auto_correlate = kwargs.pop("correlate", False)
+ self.keyword = keyword
+ self.selects = [
+ coercions.expect(roles.CompoundElementRole, s).self_group(
+ against=self
+ )
+ for s in selects
+ ]
+
+ if kwargs and util.SQLALCHEMY_WARN_20:
+ util.warn_deprecated_20(
+ "Set functions such as union(), union_all(), extract(), etc. "
+ "in SQLAlchemy 2.0 will accept a "
+ "series of SELECT statements only. "
+ "Please use generative methods such as order_by() for "
+ "additional modifications to this CompoundSelect.",
+ stacklevel=4,
+ )
+
+ GenerativeSelect.__init__(self, **kwargs)
+
+ @classmethod
+ def _create_union(cls, *selects, **kwargs):
+ r"""Return a ``UNION`` of multiple selectables.
+
+ The returned object is an instance of
+ :class:`_expression.CompoundSelect`.
+
+ A similar :func:`union()` method is available on all
+ :class:`_expression.FromClause` subclasses.
+
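+ For example, a minimal sketch assuming a hypothetical ``user_table``::
+
+ # user_table is a hypothetical Table
+ stmt1 = select(user_table).where(user_table.c.name == 'x')
+ stmt2 = select(user_table).where(user_table.c.name == 'y')
+ u = union(stmt1, stmt2)
+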
+ :param \*selects:
+ a list of :class:`_expression.Select` instances.
+
+ :param \**kwargs:
+ available keyword arguments are the same as those of
+ :func:`select`.
+
+ """
+ return CompoundSelect(CompoundSelect.UNION, *selects, **kwargs)
+
+ @classmethod
+ def _create_union_all(cls, *selects, **kwargs):
+ r"""Return a ``UNION ALL`` of multiple selectables.
+
+ The returned object is an instance of
+ :class:`_expression.CompoundSelect`.
+
+ A similar :func:`union_all()` method is available on all
+ :class:`_expression.FromClause` subclasses.
+
+ :param \*selects:
+ a list of :class:`_expression.Select` instances.
+
+ :param \**kwargs:
+ available keyword arguments are the same as those of
+ :func:`select`.
+
+ """
+ return CompoundSelect(CompoundSelect.UNION_ALL, *selects, **kwargs)
+
+ @classmethod
+ def _create_except(cls, *selects, **kwargs):
+ r"""Return an ``EXCEPT`` of multiple selectables.
+
+ The returned object is an instance of
+ :class:`_expression.CompoundSelect`.
+
+ :param \*selects:
+ a list of :class:`_expression.Select` instances.
+
+ :param \**kwargs:
+ available keyword arguments are the same as those of
+ :func:`select`.
+
+ """
+ return CompoundSelect(CompoundSelect.EXCEPT, *selects, **kwargs)
+
+ @classmethod
+ def _create_except_all(cls, *selects, **kwargs):
+ r"""Return an ``EXCEPT ALL`` of multiple selectables.
+
+ The returned object is an instance of
+ :class:`_expression.CompoundSelect`.
+
+ :param \*selects:
+ a list of :class:`_expression.Select` instances.
+
+ :param \**kwargs:
+ available keyword arguments are the same as those of
+ :func:`select`.
+
+ """
+ return CompoundSelect(CompoundSelect.EXCEPT_ALL, *selects, **kwargs)
+
+ @classmethod
+ def _create_intersect(cls, *selects, **kwargs):
+ r"""Return an ``INTERSECT`` of multiple selectables.
+
+ The returned object is an instance of
+ :class:`_expression.CompoundSelect`.
+
+ :param \*selects:
+ a list of :class:`_expression.Select` instances.
+
+ :param \**kwargs:
+ available keyword arguments are the same as those of
+ :func:`select`.
+
+ """
+ return CompoundSelect(CompoundSelect.INTERSECT, *selects, **kwargs)
+
+ @classmethod
+ def _create_intersect_all(cls, *selects, **kwargs):
+ r"""Return an ``INTERSECT ALL`` of multiple selectables.
+
+ The returned object is an instance of
+ :class:`_expression.CompoundSelect`.
+
+ :param \*selects:
+ a list of :class:`_expression.Select` instances.
+
+ :param \**kwargs:
+ available keyword arguments are the same as those of
+ :func:`select`.
+
+ """
+ return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects, **kwargs)
+
+ def _scalar_type(self):
+ return self.selects[0]._scalar_type()
+
+ def self_group(self, against=None):
+ return SelectStatementGrouping(self)
+
+ def is_derived_from(self, fromclause):
+ for s in self.selects:
+ if s.is_derived_from(fromclause):
+ return True
+ return False
+
+ def _set_label_style(self, style):
+ if self._label_style is not style:
+ self = self._generate()
+ select_0 = self.selects[0]._set_label_style(style)
+ self.selects = [select_0] + self.selects[1:]
+
+ return self
+
+ def _ensure_disambiguated_names(self):
+ new_select = self.selects[0]._ensure_disambiguated_names()
+ if new_select is not self.selects[0]:
+ self = self._generate()
+ self.selects = [new_select] + self.selects[1:]
+
+ return self
+
+ def _generate_fromclause_column_proxies(self, subquery):
+
+ # this is a slightly hacky thing - the union exports a
+ # column that resembles just that of the *first* selectable.
+ # to get at a "composite" column, particularly foreign keys,
+ # you have to dig through the proxies collection which we
+ # generate below. We may want to improve upon this, such as
+ # perhaps _make_proxy can accept a list of other columns
+ # that are "shared" - schema.column can then copy all the
+ # ForeignKeys in. this would allow the union() to have all
+ # those fks too.
+ select_0 = self.selects[0]
+
+ if self._label_style is not LABEL_STYLE_DEFAULT:
+ select_0 = select_0.set_label_style(self._label_style)
+ select_0._generate_fromclause_column_proxies(subquery)
+
+ # hand-construct the "_proxies" collection to include all
+ # derived columns; place a 'weight' annotation corresponding
+ # to how low in the list of select()s the column occurs, so
+ # that the corresponding_column() operation can resolve
+ # conflicts
+
+ for subq_col, select_cols in zip(
+ subquery.c._all_columns,
+ zip(*[s.selected_columns for s in self.selects]),
+ ):
+ subq_col._proxies = [
+ c._annotate({"weight": i + 1})
+ for (i, c) in enumerate(select_cols)
+ ]
+
+ def _refresh_for_new_column(self, column):
+ super(CompoundSelect, self)._refresh_for_new_column(column)
+ for select in self.selects:
+ select._refresh_for_new_column(column)
+
+ @property
+ def _all_selected_columns(self):
+ return self.selects[0]._all_selected_columns
+
+ @property
+ def selected_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ representing the columns that
+ this SELECT statement or similar construct returns in its result set,
+ not including :class:`_sql.TextClause` constructs.
+
+ For a :class:`_expression.CompoundSelect`, the
+ :attr:`_expression.CompoundSelect.selected_columns`
+ attribute returns the selected
+ columns of the first SELECT statement contained within the series of
+ statements within the set operation.
+
+ .. seealso::
+
+ :attr:`_sql.Select.selected_columns`
+
+ .. versionadded:: 1.4
+
+ """
+ return self.selects[0].selected_columns
+
+ @property
+ @util.deprecated_20(
+ ":attr:`.Executable.bind`",
+ alternative="Bound metadata is being removed as of SQLAlchemy 2.0.",
+ enable_warnings=False,
+ )
+ def bind(self):
+ """Returns the :class:`_engine.Engine` or :class:`_engine.Connection`
+ to which this :class:`.Executable` is bound, or None if none found.
+
+ """
+ if self._bind:
+ return self._bind
+ for s in self.selects:
+ e = s.bind
+ if e:
+ return e
+ else:
+ return None
+
+ @bind.setter
+ def bind(self, bind):
+ self._bind = bind
+
+
+class DeprecatedSelectGenerations(object):
+ """A collection of methods available on :class:`_sql.Select`, these
+ are all **deprecated** methods as they modify the :class:`_sql.Select`
+ object in-place.
+
+ """
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.Select.append_correlation` "
+ "method is deprecated "
+ "and will be removed in a future release. Use the generative "
+ "method :meth:`_expression.Select.correlate`.",
+ )
+ def append_correlation(self, fromclause):
+ """Append the given correlation expression to this select()
+ construct.
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.Select.correlate` method is preferred,
+ as it provides
+ standard :term:`method chaining`.
+
+ """
+
+ self.correlate.non_generative(self, fromclause)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.Select.append_column` method is deprecated "
+ "and will be removed in a future release. Use the generative "
+ "method :meth:`_expression.Select.add_columns`.",
+ )
+ def append_column(self, column):
+ """Append the given column expression to the columns clause of this
+ select() construct.
+
+ E.g.::
+
+ my_select.append_column(some_table.c.new_column)
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.Select.add_columns` method is preferred,
+ as it provides standard
+ :term:`method chaining`.
+
+ """
+ self.add_columns.non_generative(self, column)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.Select.append_prefix` method is deprecated "
+ "and will be removed in a future release. Use the generative "
+ "method :meth:`_expression.Select.prefix_with`.",
+ )
+ def append_prefix(self, clause):
+ """Append the given columns clause prefix expression to this select()
+ construct.
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.Select.prefix_with` method is preferred,
+ as it provides
+ standard :term:`method chaining`.
+
+ """
+ self.prefix_with.non_generative(self, clause)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.Select.append_whereclause` "
+ "method is deprecated "
+ "and will be removed in a future release. Use the generative "
+ "method :meth:`_expression.Select.where`.",
+ )
+ def append_whereclause(self, whereclause):
+ """Append the given expression to this select() construct's WHERE
+ criterion.
+
+ The expression will be joined to existing WHERE criterion via AND.
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.Select.where` method is preferred,
+ as it provides standard
+ :term:`method chaining`.
+
+ """
+ self.where.non_generative(self, whereclause)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.Select.append_having` method is deprecated "
+ "and will be removed in a future release. Use the generative "
+ "method :meth:`_expression.Select.having`.",
+ )
+ def append_having(self, having):
+ """Append the given expression to this select() construct's HAVING
+ criterion.
+
+ The expression will be joined to existing HAVING criterion via AND.
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.Select.having` method is preferred,
+ as it provides standard
+ :term:`method chaining`.
+
+ """
+
+ self.having.non_generative(self, having)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.Select.append_from` method is deprecated "
+ "and will be removed in a future release. Use the generative "
+ "method :meth:`_expression.Select.select_from`.",
+ )
+ def append_from(self, fromclause):
+ """Append the given :class:`_expression.FromClause` expression
+ to this select() construct's FROM clause.
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.Select.select_from` method is preferred,
+ as it provides
+ standard :term:`method chaining`.
+
+ """
+ self.select_from.non_generative(self, fromclause)
+
+
+@CompileState.plugin_for("default", "select")
+class SelectState(util.MemoizedSlots, CompileState):
+ __slots__ = (
+ "from_clauses",
+ "froms",
+ "columns_plus_names",
+ "_label_resolve_dict",
+ )
+
+ class default_select_compile_options(CacheableOptions):
+ _cache_key_traversal = []
+
+ def __init__(self, statement, compiler, **kw):
+ self.statement = statement
+ self.from_clauses = statement._from_obj
+
+ for memoized_entities in statement._memoized_select_entities:
+ self._setup_joins(
+ memoized_entities._setup_joins, memoized_entities._raw_columns
+ )
+
+ if statement._setup_joins:
+ self._setup_joins(statement._setup_joins, statement._raw_columns)
+
+ self.froms = self._get_froms(statement)
+
+ self.columns_plus_names = statement._generate_columns_plus_names(True)
+
+ @classmethod
+ def _plugin_not_implemented(cls):
+ raise NotImplementedError(
+ "The default SELECT construct without plugins does not "
+ "implement this method."
+ )
+
+ @classmethod
+ def get_column_descriptions(cls, statement):
+ return [
+ {
+ "name": name,
+ "type": element.type,
+ "expr": element,
+ }
+ for _, name, _, element, _ in (
+ statement._generate_columns_plus_names(False)
+ )
+ ]
+
+ @classmethod
+ def from_statement(cls, statement, from_statement):
+ cls._plugin_not_implemented()
+
+ @classmethod
+ def get_columns_clause_froms(cls, statement):
+ return cls._normalize_froms(
+ itertools.chain.from_iterable(
+ element._from_objects for element in statement._raw_columns
+ )
+ )
+
+ @classmethod
+ def _column_naming_convention(cls, label_style):
+
+ table_qualified = label_style is LABEL_STYLE_TABLENAME_PLUS_COL
+ dedupe = label_style is not LABEL_STYLE_NONE
+
+ pa = prefix_anon_map()
+ names = set()
+
+ def go(c, col_name=None):
+ if c._is_text_clause:
+ return None
+
+ elif not dedupe:
+ name = c._proxy_key
+ if name is None:
+ name = "_no_label"
+ return name
+
+ name = c._tq_key_label if table_qualified else c._proxy_key
+
+ if name is None:
+ name = "_no_label"
+ if name in names:
+ return c._anon_label(name) % pa
+ else:
+ names.add(name)
+ return name
+
+ elif name in names:
+ return (
+ c._anon_tq_key_label % pa
+ if table_qualified
+ else c._anon_key_label % pa
+ )
+ else:
+ names.add(name)
+ return name
+
+ return go
+
+ def _get_froms(self, statement):
+ return self._normalize_froms(
+ itertools.chain(
+ itertools.chain.from_iterable(
+ [
+ element._from_objects
+ for element in statement._raw_columns
+ ]
+ ),
+ itertools.chain.from_iterable(
+ [
+ element._from_objects
+ for element in statement._where_criteria
+ ]
+ ),
+ self.from_clauses,
+ ),
+ check_statement=statement,
+ )
+
+ @classmethod
+ def _normalize_froms(cls, iterable_of_froms, check_statement=None):
+ """given an iterable of things to select FROM, reduce them to what
+ would actually render in the FROM clause of a SELECT.
+
+ This does the job of checking for JOINs, tables, etc. that are in fact
+ overlapping due to cloning, adaptation, or presence in overlapping joins,
+ etc.
+
+ """
+ seen = set()
+ froms = []
+
+ for item in iterable_of_froms:
+ if item._is_subquery and item.element is check_statement:
+ raise exc.InvalidRequestError(
+ "select() construct refers to itself as a FROM"
+ )
+
+ if not seen.intersection(item._cloned_set):
+ froms.append(item)
+ seen.update(item._cloned_set)
+
+ if froms:
+ toremove = set(
+ itertools.chain.from_iterable(
+ [_expand_cloned(f._hide_froms) for f in froms]
+ )
+ )
+ if toremove:
+ # filter out to FROM clauses not in the list,
+ # using a list to maintain ordering
+ froms = [f for f in froms if f not in toremove]
+
+ return froms
+
+ def _get_display_froms(
+ self, explicit_correlate_froms=None, implicit_correlate_froms=None
+ ):
+ """Return the full list of 'from' clauses to be displayed.
+
+ Takes into account a set of existing froms which may be
+ rendered in the FROM clause of enclosing selects; this Select
+ may want to leave those absent if it is automatically
+ correlating.
+
+ """
+
+ froms = self.froms
+
+ if self.statement._correlate:
+ to_correlate = self.statement._correlate
+ if to_correlate:
+ froms = [
+ f
+ for f in froms
+ if f
+ not in _cloned_intersection(
+ _cloned_intersection(
+ froms, explicit_correlate_froms or ()
+ ),
+ to_correlate,
+ )
+ ]
+
+ if self.statement._correlate_except is not None:
+
+ froms = [
+ f
+ for f in froms
+ if f
+ not in _cloned_difference(
+ _cloned_intersection(
+ froms, explicit_correlate_froms or ()
+ ),
+ self.statement._correlate_except,
+ )
+ ]
+
+ if (
+ self.statement._auto_correlate
+ and implicit_correlate_froms
+ and len(froms) > 1
+ ):
+
+ froms = [
+ f
+ for f in froms
+ if f
+ not in _cloned_intersection(froms, implicit_correlate_froms)
+ ]
+
+ if not len(froms):
+ raise exc.InvalidRequestError(
+ "Select statement '%r"
+ "' returned no FROM clauses "
+ "due to auto-correlation; "
+ "specify correlate(<tables>) "
+ "to control correlation "
+ "manually." % self.statement
+ )
+
+ return froms
+
+ def _memoized_attr__label_resolve_dict(self):
+ with_cols = dict(
+ (c._tq_label or c.key, c)
+ for c in self.statement._all_selected_columns
+ if c._allow_label_resolve
+ )
+ only_froms = dict(
+ (c.key, c)
+ for c in _select_iterables(self.froms)
+ if c._allow_label_resolve
+ )
+ only_cols = with_cols.copy()
+ for key, value in only_froms.items():
+ with_cols.setdefault(key, value)
+
+ return with_cols, only_froms, only_cols
+
+ @classmethod
+ def determine_last_joined_entity(cls, stmt):
+ if stmt._setup_joins:
+ return stmt._setup_joins[-1][0]
+ else:
+ return None
+
+ @classmethod
+ def all_selected_columns(cls, statement):
+ return [c for c in _select_iterables(statement._raw_columns)]
+
+ def _setup_joins(self, args, raw_columns):
+ for (right, onclause, left, flags) in args:
+ isouter = flags["isouter"]
+ full = flags["full"]
+
+ if left is None:
+ (
+ left,
+ replace_from_obj_index,
+ ) = self._join_determine_implicit_left_side(
+ raw_columns, left, right, onclause
+ )
+ else:
+ (replace_from_obj_index) = self._join_place_explicit_left_side(
+ left
+ )
+
+ if replace_from_obj_index is not None:
+ # splice into an existing element in the
+ # self._from_obj list
+ left_clause = self.from_clauses[replace_from_obj_index]
+
+ self.from_clauses = (
+ self.from_clauses[:replace_from_obj_index]
+ + (
+ Join(
+ left_clause,
+ right,
+ onclause,
+ isouter=isouter,
+ full=full,
+ ),
+ )
+ + self.from_clauses[replace_from_obj_index + 1 :]
+ )
+ else:
+
+ self.from_clauses = self.from_clauses + (
+ Join(left, right, onclause, isouter=isouter, full=full),
+ )
+
+ @util.preload_module("sqlalchemy.sql.util")
+ def _join_determine_implicit_left_side(
+ self, raw_columns, left, right, onclause
+ ):
+ """When join conditions don't express the left side explicitly,
+ determine if an existing FROM or entity in this query
+ can serve as the left hand side.
+
+ """
+
+ sql_util = util.preloaded.sql_util
+
+ replace_from_obj_index = None
+
+ from_clauses = self.from_clauses
+
+ if from_clauses:
+
+ indexes = sql_util.find_left_clause_to_join_from(
+ from_clauses, right, onclause
+ )
+
+ if len(indexes) == 1:
+ replace_from_obj_index = indexes[0]
+ left = from_clauses[replace_from_obj_index]
+ else:
+ potential = {}
+ statement = self.statement
+
+ for from_clause in itertools.chain(
+ itertools.chain.from_iterable(
+ [element._from_objects for element in raw_columns]
+ ),
+ itertools.chain.from_iterable(
+ [
+ element._from_objects
+ for element in statement._where_criteria
+ ]
+ ),
+ ):
+
+ potential[from_clause] = ()
+
+ all_clauses = list(potential.keys())
+ indexes = sql_util.find_left_clause_to_join_from(
+ all_clauses, right, onclause
+ )
+
+ if len(indexes) == 1:
+ left = all_clauses[indexes[0]]
+
+ if len(indexes) > 1:
+ raise exc.InvalidRequestError(
+ "Can't determine which FROM clause to join "
+ "from, there are multiple FROMS which can "
+ "join to this entity. Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explicit ON clause if not present already to "
+ "help resolve the ambiguity."
+ )
+ elif not indexes:
+ raise exc.InvalidRequestError(
+ "Don't know how to join to %r. "
+ "Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explicit ON clause if not present already to "
+ "help resolve the ambiguity." % (right,)
+ )
+ return left, replace_from_obj_index
+
+ @util.preload_module("sqlalchemy.sql.util")
+ def _join_place_explicit_left_side(self, left):
+ replace_from_obj_index = None
+
+ sql_util = util.preloaded.sql_util
+
+ from_clauses = list(self.statement._iterate_from_elements())
+
+ if from_clauses:
+ indexes = sql_util.find_left_clause_that_matches_given(
+ self.from_clauses, left
+ )
+ else:
+ indexes = []
+
+ if len(indexes) > 1:
+ raise exc.InvalidRequestError(
+ "Can't identify which entity in which to assign the "
+ "left side of this join. Please use a more specific "
+ "ON clause."
+ )
+
+ # have an index, means the left side is already present in
+ # an existing FROM in the self._from_obj tuple
+ if indexes:
+ replace_from_obj_index = indexes[0]
+
+ # no index, means we need to add a new element to the
+ # self._from_obj tuple
+
+ return replace_from_obj_index
+
+
+class _SelectFromElements(object):
+ def _iterate_from_elements(self):
+ # note this does not include elements
+ # in _setup_joins or _legacy_setup_joins
+
+ seen = set()
+ for element in self._raw_columns:
+ for fr in element._from_objects:
+ if fr in seen:
+ continue
+ seen.add(fr)
+ yield fr
+ for element in self._where_criteria:
+ for fr in element._from_objects:
+ if fr in seen:
+ continue
+ seen.add(fr)
+ yield fr
+ for element in self._from_obj:
+ if element in seen:
+ continue
+ seen.add(element)
+ yield element
+
+
+class _MemoizedSelectEntities(
+ traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible
+):
+ __visit_name__ = "memoized_select_entities"
+
+ _traverse_internals = [
+ ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ ("_setup_joins", InternalTraversal.dp_setup_join_tuple),
+ ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple),
+ ("_with_options", InternalTraversal.dp_executable_options),
+ ]
+
+ _annotations = util.EMPTY_DICT
+
+ def _clone(self, **kw):
+ c = self.__class__.__new__(self.__class__)
+ c.__dict__ = {k: v for k, v in self.__dict__.items()}
+
+ c._is_clone_of = self.__dict__.get("_is_clone_of", self)
+ return c
+
+ @classmethod
+ def _generate_for_statement(cls, select_stmt):
+ if (
+ select_stmt._setup_joins
+ or select_stmt._legacy_setup_joins
+ or select_stmt._with_options
+ ):
+ self = _MemoizedSelectEntities()
+ self._raw_columns = select_stmt._raw_columns
+ self._setup_joins = select_stmt._setup_joins
+ self._legacy_setup_joins = select_stmt._legacy_setup_joins
+ self._with_options = select_stmt._with_options
+
+ select_stmt._memoized_select_entities += (self,)
+ select_stmt._raw_columns = (
+ select_stmt._setup_joins
+ ) = (
+ select_stmt._legacy_setup_joins
+ ) = select_stmt._with_options = ()
+
+
+class Select(
+ HasPrefixes,
+ HasSuffixes,
+ HasHints,
+ HasCompileState,
+ DeprecatedSelectGenerations,
+ _SelectFromElements,
+ GenerativeSelect,
+):
+ """Represents a ``SELECT`` statement.
+
+ The :class:`_sql.Select` object is normally constructed using the
+ :func:`_sql.select` function. See that function for details.
+
+ .. seealso::
+
+ :func:`_sql.select`
+
+ :ref:`tutorial_selecting_data` - in the 2.0 tutorial
+
+ """
+
+ __visit_name__ = "select"
+
+ _setup_joins = ()
+ _legacy_setup_joins = ()
+ _memoized_select_entities = ()
+
+ _distinct = False
+ _distinct_on = ()
+ _correlate = ()
+ _correlate_except = None
+ _where_criteria = ()
+ _having_criteria = ()
+ _from_obj = ()
+ _auto_correlate = True
+
+ _compile_options = SelectState.default_select_compile_options
+
+ _traverse_internals = (
+ [
+ ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ (
+ "_memoized_select_entities",
+ InternalTraversal.dp_memoized_select_entities,
+ ),
+ ("_from_obj", InternalTraversal.dp_clauseelement_list),
+ ("_where_criteria", InternalTraversal.dp_clauseelement_tuple),
+ ("_having_criteria", InternalTraversal.dp_clauseelement_tuple),
+ ("_order_by_clauses", InternalTraversal.dp_clauseelement_tuple),
+ ("_group_by_clauses", InternalTraversal.dp_clauseelement_tuple),
+ ("_setup_joins", InternalTraversal.dp_setup_join_tuple),
+ ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple),
+ ("_correlate", InternalTraversal.dp_clauseelement_tuple),
+ ("_correlate_except", InternalTraversal.dp_clauseelement_tuple),
+ ("_limit_clause", InternalTraversal.dp_clauseelement),
+ ("_offset_clause", InternalTraversal.dp_clauseelement),
+ ("_fetch_clause", InternalTraversal.dp_clauseelement),
+ ("_fetch_clause_options", InternalTraversal.dp_plain_dict),
+ ("_for_update_arg", InternalTraversal.dp_clauseelement),
+ ("_distinct", InternalTraversal.dp_boolean),
+ ("_distinct_on", InternalTraversal.dp_clauseelement_tuple),
+ ("_label_style", InternalTraversal.dp_plain_obj),
+ ]
+ + HasCTE._has_ctes_traverse_internals
+ + HasPrefixes._has_prefixes_traverse_internals
+ + HasSuffixes._has_suffixes_traverse_internals
+ + HasHints._has_hints_traverse_internals
+ + SupportsCloneAnnotations._clone_annotations_traverse_internals
+ + Executable._executable_traverse_internals
+ )
+
+ _cache_key_traversal = _traverse_internals + [
+ ("_compile_options", InternalTraversal.dp_has_cache_key)
+ ]
+
+ @classmethod
+ def _create_select_from_fromclause(cls, target, entities, *arg, **kw):
+ if arg or kw:
+ return Select.create_legacy_select(entities, *arg, **kw)
+ else:
+ return Select._create_select(*entities)
+
+ @classmethod
+ @util.deprecated(
+ "2.0",
+ "The legacy calling style of :func:`_sql.select` is deprecated and "
+ "will be removed in SQLAlchemy 2.0. Please use the new calling "
+ "style described at :func:`_sql.select`.",
+ )
+ def create_legacy_select(
+ cls,
+ columns=None,
+ whereclause=None,
+ from_obj=None,
+ distinct=False,
+ having=None,
+ correlate=True,
+ prefixes=None,
+ suffixes=None,
+ **kwargs
+ ):
+ """Construct a new :class:`_expression.Select` using the 1.x style API.
+
+ This method is called implicitly when the :func:`_expression.select`
+ construct is used and the first argument is a Python list or other
+ plain sequence object, which is taken to refer to the columns
+ collection.
+
+ .. versionchanged:: 1.4 Added the :meth:`.Select.create_legacy_select`
+ constructor which documents the calling style in use when the
+ :func:`.select` construct is invoked using 1.x-style arguments.
+
+ Similar functionality is also available via the
+ :meth:`_expression.FromClause.select` method on any
+ :class:`_expression.FromClause`.
+
+ All arguments which accept :class:`_expression.ClauseElement` arguments
+ also accept string arguments, which will be converted as appropriate
+ into either :func:`_expression.text()` or
+ :func:`_expression.literal_column()` constructs.
+
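+ For example, a minimal sketch of the 1.x calling style, assuming a
+ hypothetical ``user_table``::
+
+ # user_table is a hypothetical Table; columns are passed as a list
+ stmt = select([user_table.c.id, user_table.c.name], user_table.c.name == 'some name')
+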
+ .. seealso::
+
+ :ref:`tutorial_selecting_data` - in the :ref:`unified_tutorial`
+
+ :param columns:
+ A list of :class:`_expression.ColumnElement` or
+ :class:`_expression.FromClause`
+ objects which will form the columns clause of the resulting
+ statement. For those objects that are instances of
+ :class:`_expression.FromClause` (typically :class:`_schema.Table`
+ or :class:`_expression.Alias`
+ objects), the :attr:`_expression.FromClause.c`
+ collection is extracted
+ to form a collection of :class:`_expression.ColumnElement` objects.
+
+ This parameter will also accept :class:`_expression.TextClause`
+ constructs as
+ given, as well as ORM-mapped classes.
+
+ .. note::
+
+ The :paramref:`_expression.select.columns`
+ parameter is not available
+ in the method form of :func:`_expression.select`, e.g.
+ :meth:`_expression.FromClause.select`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.column`
+
+ :meth:`_expression.Select.with_only_columns`
+
+ :param whereclause:
+ A :class:`_expression.ClauseElement`
+ expression which will be used to form the
+ ``WHERE`` clause. It is typically preferable to add WHERE
+ criterion to an existing :class:`_expression.Select`
+ using method chaining
+ with :meth:`_expression.Select.where`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.where`
+
+ :param from_obj:
+ A list of :class:`_expression.ClauseElement`
+ objects which will be added to the
+ ``FROM`` clause of the resulting statement. This is equivalent
+ to calling :meth:`_expression.Select.select_from`
+ using method chaining on
+ an existing :class:`_expression.Select` object.
+
+ .. seealso::
+
+ :meth:`_expression.Select.select_from`
+ - full description of explicit
+ FROM clause specification.
+
+ :param bind=None:
+ an :class:`_engine.Engine` or :class:`_engine.Connection` instance
+ to which the
+ resulting :class:`_expression.Select` object will be bound. The
+ :class:`_expression.Select`
+ object will otherwise automatically bind to
+ whatever :class:`~.base.Connectable` instances can be located within
+ its contained :class:`_expression.ClauseElement` members.
+
+ :param correlate=True:
+ indicates that this :class:`_expression.Select`
+ object should have its
+ contained :class:`_expression.FromClause`
+ elements "correlated" to an enclosing
+ :class:`_expression.Select` object.
+ It is typically preferable to specify
+ correlations on an existing :class:`_expression.Select`
+ construct using
+ :meth:`_expression.Select.correlate`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.correlate`
+ - full description of correlation.
+
+ :param distinct=False:
+ when ``True``, applies a ``DISTINCT`` qualifier to the columns
+ clause of the resulting statement.
+
+ The boolean argument may also be a column expression or list
+ of column expressions - this is a special calling form which
+ is understood by the PostgreSQL dialect to render the
+ ``DISTINCT ON (<columns>)`` syntax.
+
+ ``distinct`` is also available on an existing
+ :class:`_expression.Select`
+ object via the :meth:`_expression.Select.distinct` method.
+
+ .. seealso::
+
+ :meth:`_expression.Select.distinct`
+
+ :param group_by:
+ a list of :class:`_expression.ClauseElement`
+ objects which will comprise the
+ ``GROUP BY`` clause of the resulting select. This parameter
+ is typically specified more naturally using the
+ :meth:`_expression.Select.group_by` method on an existing
+ :class:`_expression.Select`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.group_by`
+
+ :param having:
+ a :class:`_expression.ClauseElement`
+ that will comprise the ``HAVING`` clause
+ of the resulting select when ``GROUP BY`` is used. This parameter
+ is typically specified more naturally using the
+ :meth:`_expression.Select.having` method on an existing
+ :class:`_expression.Select`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.having`
+
+ :param limit=None:
+ a numerical value which usually renders as a ``LIMIT``
+ expression in the resulting select. Backends that don't
+ support ``LIMIT`` will attempt to provide similar
+ functionality. This parameter is typically specified more
+ naturally using the :meth:`_expression.Select.limit`
+ method on an existing
+ :class:`_expression.Select`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.limit`
+
+ :param offset=None:
+ a numeric value which usually renders as an ``OFFSET``
+ expression in the resulting select. Backends that don't
+ support ``OFFSET`` will attempt to provide similar
+ functionality. This parameter is typically specified more naturally
+ using the :meth:`_expression.Select.offset` method on an existing
+ :class:`_expression.Select`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.offset`
+
+ :param order_by:
+ a scalar or list of :class:`_expression.ClauseElement`
+ objects which will
+ comprise the ``ORDER BY`` clause of the resulting select.
+ This parameter is typically specified more naturally using the
+ :meth:`_expression.Select.order_by` method on an existing
+ :class:`_expression.Select`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.order_by`
+
+ :param use_labels=False:
+ when ``True``, the statement will be generated using labels
+ for each column in the columns clause, which qualify each
+ column with its parent table's (or alias's) name so that name
+ conflicts between columns in different tables don't occur.
+ The format of the label is ``<tablename>_<column>``. The "c"
+ collection of a :class:`_expression.Subquery` created
+ against this :class:`_expression.Select`
+ object, as well as the :attr:`_expression.Select.selected_columns`
+ collection of the :class:`_expression.Select` itself, will use these
+ names for targeting column members.
+
+ This parameter can also be specified on an existing
+ :class:`_expression.Select` object using the
+ :meth:`_expression.Select.set_label_style`
+ method.
+
+ .. seealso::
+
+ :meth:`_expression.Select.set_label_style`
+
+ """
+ self = cls.__new__(cls)
+
+ self._auto_correlate = correlate
+
+ if distinct is not False:
+ if distinct is True:
+ self.distinct.non_generative(self)
+ else:
+ self.distinct.non_generative(self, *util.to_list(distinct))
+
+ if from_obj is not None:
+ self.select_from.non_generative(self, *util.to_list(from_obj))
+
+ try:
+ cols_present = bool(columns)
+ except TypeError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "select() construct created in legacy mode, i.e. with "
+ "keyword arguments, must provide the columns argument as "
+ "a Python list or other iterable.",
+ code="c9ae",
+ ),
+ from_=err,
+ )
+
+ if cols_present:
+ self._raw_columns = [
+ coercions.expect(
+ roles.ColumnsClauseRole, c, apply_propagate_attrs=self
+ )
+ for c in columns
+ ]
+ else:
+ self._raw_columns = []
+
+ if whereclause is not None:
+ self.where.non_generative(self, whereclause)
+
+ if having is not None:
+ self.having.non_generative(self, having)
+
+ if prefixes:
+ self._setup_prefixes(prefixes)
+
+ if suffixes:
+ self._setup_suffixes(suffixes)
+
+ GenerativeSelect.__init__(self, **kwargs)
+ return self
+
+ @classmethod
+ def _create_future_select(cls, *entities):
+ r"""Construct a new :class:`_expression.Select` using the 2.
+ x style API.
+
+ .. versionadded:: 1.4 - The :func:`_sql.select` function now accepts
+ column arguments positionally. The top-level :func:`_sql.select`
+ function will automatically use the 1.x or 2.x style API based on
+ the incoming arguments; using :func:`_future.select` from the
+ ``sqlalchemy.future`` module will enforce that only the 2.x style
+ constructor is used.
+
+ Similar functionality is also available via the
+ :meth:`_expression.FromClause.select` method on any
+ :class:`_expression.FromClause`.
+
+ .. seealso::
+
+ :ref:`coretutorial_selecting` - Core Tutorial description of
+ :func:`_expression.select`.
+
+ :param \*entities:
+ Entities to SELECT from. For Core usage, this is typically a series
+ of :class:`_expression.ColumnElement` and / or
+ :class:`_expression.FromClause`
+ objects which will form the columns clause of the resulting
+ statement. For those objects that are instances of
+ :class:`_expression.FromClause` (typically :class:`_schema.Table`
+ or :class:`_expression.Alias`
+ objects), the :attr:`_expression.FromClause.c`
+ collection is extracted
+ to form a collection of :class:`_expression.ColumnElement` objects.
+
+ This parameter will also accept :class:`_expression.TextClause`
+ constructs as
+ given, as well as ORM-mapped classes.
+
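+ E.g., a brief sketch using the ``user_table`` construct assumed in
+ other examples in this module::
+
+ stmt = select(user_table.c.id, user_table.c.name).where(
+ user_table.c.id == 5
+ )
+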
+ """
+
+ self = cls.__new__(cls)
+ self._raw_columns = [
+ coercions.expect(
+ roles.ColumnsClauseRole, ent, apply_propagate_attrs=self
+ )
+ for ent in entities
+ ]
+
+ GenerativeSelect.__init__(self)
+
+ return self
+
+ _create_select = _create_future_select
+
+ @classmethod
+ def _create_raw_select(cls, **kw):
+ """Create a :class:`.Select` using raw ``__new__`` with no coercions.
+
+ Used internally to build up :class:`.Select` constructs with
+ pre-established state.
+
+ """
+
+ stmt = Select.__new__(Select)
+ stmt.__dict__.update(kw)
+ return stmt
+
+ @classmethod
+ def _create(cls, *args, **kw):
+ r"""Create a :class:`.Select` using either the 1.x or 2.0 constructor
+ style.
+
+ For the legacy calling style, see :meth:`.Select.create_legacy_select`;
+ that legacy style is used if the first argument passed is a Python
+ sequence or if keyword arguments are present.
+
+ .. versionadded:: 2.0 - the :func:`_future.select` construct is
+ the same construct as the one returned by
+ :func:`_expression.select`, except that the function only
+ accepts the "columns clause" entities up front; the rest of the
+ state of the SELECT should be built up using generative methods.
+
+ Similar functionality is also available via the
+ :meth:`_expression.FromClause.select` method on any
+ :class:`_expression.FromClause`.
+
+ .. seealso::
+
+ :ref:`coretutorial_selecting` - Core Tutorial description of
+ :func:`_expression.select`.
+
+ :param \*entities:
+ Entities to SELECT from. For Core usage, this is typically a series
+ of :class:`_expression.ColumnElement` and / or
+ :class:`_expression.FromClause`
+ objects which will form the columns clause of the resulting
+ statement. For those objects that are instances of
+ :class:`_expression.FromClause` (typically :class:`_schema.Table`
+ or :class:`_expression.Alias`
+ objects), the :attr:`_expression.FromClause.c`
+ collection is extracted
+ to form a collection of :class:`_expression.ColumnElement` objects.
+
+ This parameter will also accept :class:`_expression.TextClause`
+ constructs as given, as well as ORM-mapped classes.
+
+ """
+ if (
+ args
+ and (
+ isinstance(args[0], list)
+ or (
+ hasattr(args[0], "__iter__")
+ and not isinstance(
+ args[0], util.string_types + (ClauseElement,)
+ )
+ and inspect(args[0], raiseerr=False) is None
+ and not hasattr(args[0], "__clause_element__")
+ )
+ )
+ ) or kw:
+ return cls.create_legacy_select(*args, **kw)
+ else:
+ return cls._create_future_select(*args)
+
+ def __init__(self):
+ raise NotImplementedError()
+
+ def _scalar_type(self):
+ elem = self._raw_columns[0]
+ cols = list(elem._select_iterable)
+ return cols[0].type
+
+ def filter(self, *criteria):
+ """A synonym for the :meth:`_future.Select.where` method."""
+
+ return self.where(*criteria)
+
+ def _filter_by_zero(self):
+ if self._setup_joins:
+ meth = SelectState.get_plugin_class(
+ self
+ ).determine_last_joined_entity
+ _last_joined_entity = meth(self)
+ if _last_joined_entity is not None:
+ return _last_joined_entity
+
+ if self._from_obj:
+ return self._from_obj[0]
+
+ return self._raw_columns[0]
+
+ def filter_by(self, **kwargs):
+ r"""apply the given filtering criterion as a WHERE clause
+ to this select.
+
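+ For example, a minimal sketch, assuming an ORM-mapped ``User`` class
+ with ``name`` and ``fullname`` attributes (hypothetical names used
+ only for illustration)::
+
+ stmt = select(User).filter_by(name="sandy", fullname="Sandy Cheeks")
+
+ which is roughly equivalent to::
+
+ stmt = select(User).where(
+ User.name == "sandy", User.fullname == "Sandy Cheeks"
+ )
+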
+ """
+ from_entity = self._filter_by_zero()
+
+ clauses = [
+ _entity_namespace_key(from_entity, key) == value
+ for key, value in kwargs.items()
+ ]
+ return self.filter(*clauses)
+
+ @property
+ def column_descriptions(self):
+ """Return a :term:`plugin-enabled` 'column descriptions' structure
+ referring to the columns which are SELECTed by this statement.
+
+ This attribute is generally useful when using the ORM, as an
+ extended structure which includes information about mapped
+ entities is returned. The section :ref:`queryguide_inspection`
+ contains more background.
+
+ For a Core-only statement, the structure returned by this accessor
+ is derived from the same objects that are returned by the
+ :attr:`.Select.selected_columns` accessor, formatted as a list of
+ dictionaries which contain the keys ``name``, ``type`` and ``expr``,
+ which indicate the column expressions to be selected::
+
+ >>> stmt = select(user_table)
+ >>> stmt.column_descriptions
+ [
+ {
+ 'name': 'id',
+ 'type': Integer(),
+ 'expr': Column('id', Integer(), ...)},
+ {
+ 'name': 'name',
+ 'type': String(length=30),
+ 'expr': Column('name', String(length=30), ...)}
+ ]
+
+ .. versionchanged:: 1.4.33 The :attr:`.Select.column_descriptions`
+ attribute returns a structure for a Core-only set of entities,
+ not just ORM-only entities.
+
+ .. seealso::
+
+ :attr:`.UpdateBase.entity_description` - entity information for
+ an :func:`.insert`, :func:`.update`, or :func:`.delete`
+
+ :ref:`queryguide_inspection` - ORM background
+
+ """
+ meth = SelectState.get_plugin_class(self).get_column_descriptions
+ return meth(self)
+
+ def from_statement(self, statement):
+ """Apply the columns which this :class:`.Select` would select
+ onto another statement.
+
+ This operation is :term:`plugin-specific` and will raise a not
+ supported exception if this :class:`_sql.Select` does not select from
+ plugin-enabled entities.
+
+
+ The statement is typically either a :func:`_expression.text` or
+ :func:`_expression.select` construct, and should return the set of
+ columns appropriate to the entities represented by this
+ :class:`.Select`.
+
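+ A brief ORM-oriented sketch, assuming a mapped ``User`` class
+ (hypothetical name used for illustration)::
+
+ stmt = select(User).from_statement(
+ text("SELECT * FROM user_account")
+ )
+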
+ .. seealso::
+
+ :ref:`orm_queryguide_selecting_text` - usage examples in the
+ ORM Querying Guide
+
+ """
+ meth = SelectState.get_plugin_class(self).from_statement
+ return meth(self, statement)
+
+ @_generative
+ def join(self, target, onclause=None, isouter=False, full=False):
+ r"""Create a SQL JOIN against this :class:`_expression.Select`
+ object's criterion
+ and apply generatively, returning the newly resulting
+ :class:`_expression.Select`.
+
+ E.g.::
+
+ stmt = select(user_table).join(address_table, user_table.c.id == address_table.c.user_id)
+
+ The above statement generates SQL similar to::
+
+ SELECT user.id, user.name FROM user JOIN address ON user.id = address.user_id
+
+ .. versionchanged:: 1.4 :meth:`_expression.Select.join` now creates
+ a :class:`_sql.Join` object between a :class:`_sql.FromClause`
+ source that is within the FROM clause of the existing SELECT,
+ and a given target :class:`_sql.FromClause`, and then adds
+ this :class:`_sql.Join` to the FROM clause of the newly generated
+ SELECT statement. This is completely reworked from the behavior
+ in 1.3, which would instead create a subquery of the entire
+ :class:`_expression.Select` and then join that subquery to the
+ target.
+
+ This is a **backwards incompatible change** as the previous behavior
+ was mostly useless, producing an unnamed subquery rejected by
+ most databases in any case. The new behavior is modeled after
+ that of the very successful :meth:`_orm.Query.join` method in the
+ ORM, in order to support the functionality of :class:`_orm.Query`
+ being available by using a :class:`_sql.Select` object with an
+ :class:`_orm.Session`.
+
+ See the notes for this change at :ref:`change_select_join`.
+
+
+ :param target: target table to join towards
+
+ :param onclause: ON clause of the join. If omitted, an ON clause
+ is generated automatically based on the :class:`_schema.ForeignKey`
+ linkages between the two tables, if one can be unambiguously
+ determined, otherwise an error is raised.
+
+ :param isouter: if True, generate LEFT OUTER join. Same as
+ :meth:`_expression.Select.outerjoin`.
+
+ :param full: if True, generate FULL OUTER join.
+
+ .. seealso::
+
+ :ref:`tutorial_select_join` - in the :doc:`/tutorial/index`
+
+ :ref:`orm_queryguide_joins` - in the :ref:`queryguide_toplevel`
+
+ :meth:`_expression.Select.join_from`
+
+ :meth:`_expression.Select.outerjoin`
+
+ """ # noqa: E501
+ target = coercions.expect(
+ roles.JoinTargetRole, target, apply_propagate_attrs=self
+ )
+ if onclause is not None:
+ onclause = coercions.expect(roles.OnClauseRole, onclause)
+ self._setup_joins += (
+ (target, onclause, None, {"isouter": isouter, "full": full}),
+ )
+
+ def outerjoin_from(self, from_, target, onclause=None, full=False):
+ r"""Create a SQL LEFT OUTER JOIN against this
+ :class:`_expression.Select` object's criterion and apply generatively,
+ returning the newly resulting :class:`_expression.Select`.
+
+ Usage is the same as that of :meth:`_selectable.Select.join_from`.
+
+ """
+ return self.join_from(
+ from_, target, onclause=onclause, isouter=True, full=full
+ )
+
+ @_generative
+ def join_from(
+ self, from_, target, onclause=None, isouter=False, full=False
+ ):
+ r"""Create a SQL JOIN against this :class:`_expression.Select`
+ object's criterion
+ and apply generatively, returning the newly resulting
+ :class:`_expression.Select`.
+
+ E.g.::
+
+ stmt = select(user_table, address_table).join_from(
+ user_table, address_table, user_table.c.id == address_table.c.user_id
+ )
+
+ The above statement generates SQL similar to::
+
+ SELECT user.id, user.name, address.id, address.email, address.user_id
+ FROM user JOIN address ON user.id = address.user_id
+
+ .. versionadded:: 1.4
+
+ :param from\_: the left side of the join, will be rendered in the
+ FROM clause and is roughly equivalent to using the
+ :meth:`.Select.select_from` method.
+
+ :param target: target table to join towards
+
+ :param onclause: ON clause of the join.
+
+ :param isouter: if True, generate LEFT OUTER join. Same as
+ :meth:`_expression.Select.outerjoin`.
+
+ :param full: if True, generate FULL OUTER join.
+
+ .. seealso::
+
+ :ref:`tutorial_select_join` - in the :doc:`/tutorial/index`
+
+ :ref:`orm_queryguide_joins` - in the :ref:`queryguide_toplevel`
+
+ :meth:`_expression.Select.join`
+
+ """ # noqa: E501
+
+ # note the order of parsing from vs. target is important here, as we
+ # are also deriving the source of the plugin (i.e. the subject mapper
+ # in an ORM query) which should favor the "from_" over the "target"
+
+ from_ = coercions.expect(
+ roles.FromClauseRole, from_, apply_propagate_attrs=self
+ )
+ target = coercions.expect(
+ roles.JoinTargetRole, target, apply_propagate_attrs=self
+ )
+ if onclause is not None:
+ onclause = coercions.expect(roles.OnClauseRole, onclause)
+
+ self._setup_joins += (
+ (target, onclause, from_, {"isouter": isouter, "full": full}),
+ )
+
+ def outerjoin(self, target, onclause=None, full=False):
+ """Create a left outer join.
+
+ Parameters are the same as that of :meth:`_expression.Select.join`.
+
+ .. versionchanged:: 1.4 :meth:`_expression.Select.outerjoin` now
+ creates a :class:`_sql.Join` object between a
+ :class:`_sql.FromClause` source that is within the FROM clause of
+ the existing SELECT, and a given target :class:`_sql.FromClause`,
+ and then adds this :class:`_sql.Join` to the FROM clause of the
+ newly generated SELECT statement. This is completely reworked
+ from the behavior in 1.3, which would instead create a subquery of
+ the entire
+ :class:`_expression.Select` and then join that subquery to the
+ target.
+
+ This is a **backwards incompatible change** as the previous behavior
+ was mostly useless, producing an unnamed subquery rejected by
+ most databases in any case. The new behavior is modeled after
+ that of the very successful :meth:`_orm.Query.join` method in the
+ ORM, in order to support the functionality of :class:`_orm.Query`
+ being available by using a :class:`_sql.Select` object with an
+ :class:`_orm.Session`.
+
+ See the notes for this change at :ref:`change_select_join`.
+
+ .. seealso::
+
+ :ref:`tutorial_select_join` - in the :doc:`/tutorial/index`
+
+ :ref:`orm_queryguide_joins` - in the :ref:`queryguide_toplevel`
+
+ :meth:`_expression.Select.join`
+
+ """
+ return self.join(target, onclause=onclause, isouter=True, full=full)
+
+ def get_final_froms(self):
+ """Compute the final displayed list of :class:`_expression.FromClause`
+ elements.
+
+ This method will run through the full computation required to
+ determine what FROM elements will be displayed in the resulting
+ SELECT statement, including shadowing individual tables with
+ JOIN objects, as well as full computation for ORM use cases including
+ eager loading clauses.
+
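+ An illustrative sketch, assuming ``user_table`` and ``address_table``
+ constructs linked by a foreign key, as in other examples in this
+ module::
+
+ stmt = select(user_table).join(address_table)
+ froms = stmt.get_final_froms()  # typically a single Join object
+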
+ For ORM use, this accessor returns the **post compilation**
+ list of FROM objects; this collection will include elements such as
+ eagerly loaded tables and joins. The objects will **not** be
+ ORM enabled and will not work as a replacement for the
+ :meth:`_sql.Select.select_froms` collection; additionally, the
+ method does not perform well for an ORM-enabled statement, as it
+ will incur the full ORM construction process.
+
+ To retrieve the FROM list that's implied by the "columns" collection
+ passed to the :class:`_sql.Select` originally, use the
+ :attr:`_sql.Select.columns_clause_froms` accessor.
+
+ To select from an alternative set of columns while maintaining the
+ FROM list, use the :meth:`_sql.Select.with_only_columns` method and
+ pass the
+ :paramref:`_sql.Select.with_only_columns.maintain_column_froms`
+ parameter.
+
+ .. versionadded:: 1.4.23 - the :meth:`_sql.Select.get_final_froms`
+ method replaces the previous :attr:`_sql.Select.froms` accessor,
+ which is deprecated.
+
+ .. seealso::
+
+ :attr:`_sql.Select.columns_clause_froms`
+
+ """
+ return self._compile_state_factory(self, None)._get_display_froms()
+
+ @property
+ @util.deprecated(
+ "1.4.23",
+ "The :attr:`_expression.Select.froms` attribute is moved to "
+ "the :meth:`_expression.Select.get_final_froms` method.",
+ )
+ def froms(self):
+ """Return the displayed list of :class:`_expression.FromClause`
+ elements.
+
+
+ """
+ return self.get_final_froms()
+
+ @property
+ def columns_clause_froms(self):
+ """Return the set of :class:`_expression.FromClause` objects implied
+ by the columns clause of this SELECT statement.
+
+ .. versionadded:: 1.4.23
+
+ .. seealso::
+
+ :attr:`_sql.Select.froms` - "final" FROM list taking the full
+ statement into account
+
+ :meth:`_sql.Select.with_only_columns` - makes use of this
+ collection to set up a new FROM list
+
+ """
+
+ return SelectState.get_plugin_class(self).get_columns_clause_froms(
+ self
+ )
+
+ @property
+ def inner_columns(self):
+ """An iterator of all :class:`_expression.ColumnElement`
+ expressions which would
+ be rendered into the columns clause of the resulting SELECT statement.
+
+ This method is legacy as of 1.4 and is superseded by the
+ :attr:`_expression.Select.exported_columns` collection.
+
+ """
+
+ return iter(self._all_selected_columns)
+
+ def is_derived_from(self, fromclause):
+ if self in fromclause._cloned_set:
+ return True
+
+ for f in self._iterate_from_elements():
+ if f.is_derived_from(fromclause):
+ return True
+ return False
+
+ def _copy_internals(self, clone=_clone, **kw):
+ # Select() object has been cloned and probably adapted by the
+ # given clone function. Apply the cloning function to internal
+ # objects
+
+ # 1. keep a dictionary of the froms we've cloned, and what
+ # they've become. This allows us to ensure the same cloned from
+ # is used when other items such as columns are "cloned"
+
+ all_the_froms = set(
+ itertools.chain(
+ _from_objects(*self._raw_columns),
+ _from_objects(*self._where_criteria),
+ _from_objects(*[elem[0] for elem in self._setup_joins]),
+ )
+ )
+
+ # do a clone for the froms we've gathered. what is important here
+ # is if any of the things we are selecting from, like tables,
+ # were converted into Join objects. if so, these need to be
+ # added to _from_obj explicitly, because otherwise they won't be
+ # part of the new state, as they don't associate themselves with
+ # their columns.
+ new_froms = {f: clone(f, **kw) for f in all_the_froms}
+
+ # 2. copy FROM collections, adding in joins that we've created.
+ existing_from_obj = [clone(f, **kw) for f in self._from_obj]
+ add_froms = (
+ set(f for f in new_froms.values() if isinstance(f, Join))
+ .difference(all_the_froms)
+ .difference(existing_from_obj)
+ )
+
+ self._from_obj = tuple(existing_from_obj) + tuple(add_froms)
+
+ # 3. clone everything else, making sure we use columns
+ # corresponding to the froms we just made.
+ def replace(obj, **kw):
+ if isinstance(obj, ColumnClause) and obj.table in new_froms:
+ newelem = new_froms[obj.table].corresponding_column(obj)
+ return newelem
+
+ kw["replace"] = replace
+
+ # copy everything else. for table-ish things like correlate,
+ # correlate_except, setup_joins, these clone normally. For
+ # column-expression oriented things like raw_columns, where_criteria,
+ # order by, we get this from the new froms.
+ super(Select, self)._copy_internals(
+ clone=clone, omit_attrs=("_from_obj",), **kw
+ )
+
+ self._reset_memoizations()
+
+ def get_children(self, **kwargs):
+ return itertools.chain(
+ super(Select, self).get_children(
+ omit_attrs=["_from_obj", "_correlate", "_correlate_except"]
+ ),
+ self._iterate_from_elements(),
+ )
+
+ @_generative
+ def add_columns(self, *columns):
+ """Return a new :func:`_expression.select` construct with
+ the given column expressions added to its columns clause.
+
+ E.g.::
+
+ my_select = my_select.add_columns(table.c.new_column)
+
+ See the documentation for
+ :meth:`_expression.Select.with_only_columns`
+ for guidelines on adding / replacing the columns of a
+ :class:`_expression.Select` object.
+
+ """
+ self._reset_memoizations()
+
+ self._raw_columns = self._raw_columns + [
+ coercions.expect(
+ roles.ColumnsClauseRole, column, apply_propagate_attrs=self
+ )
+ for column in columns
+ ]
+
+ def _set_entities(self, entities):
+ self._raw_columns = [
+ coercions.expect(
+ roles.ColumnsClauseRole, ent, apply_propagate_attrs=self
+ )
+ for ent in util.to_list(entities)
+ ]
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.Select.column` method is deprecated and will "
+ "be removed in a future release. Please use "
+ ":meth:`_expression.Select.add_columns`",
+ )
+ def column(self, column):
+ """Return a new :func:`_expression.select` construct with
+ the given column expression added to its columns clause.
+
+ E.g.::
+
+ my_select = my_select.column(table.c.new_column)
+
+ See the documentation for
+ :meth:`_expression.Select.with_only_columns`
+ for guidelines on adding / replacing the columns of a
+ :class:`_expression.Select` object.
+
+ """
+ return self.add_columns(column)
+
+ @util.preload_module("sqlalchemy.sql.util")
+ def reduce_columns(self, only_synonyms=True):
+ """Return a new :func:`_expression.select` construct with redundantly
+ named, equivalently-valued columns removed from the columns clause.
+
+ "Redundant" here means two columns where one refers to the
+ other either based on foreign key, or via a simple equality
+ comparison in the WHERE clause of the statement. The primary purpose
+ of this method is to automatically construct a select statement
+ with all uniquely-named columns, without the need to use
+ table-qualified labels as
+ :meth:`_expression.Select.set_label_style`
+ does.
+
+ When columns are omitted based on foreign key, the referred-to
+ column is the one that's kept. When columns are omitted based on
+ WHERE equivalence, the first column in the columns clause is the
+ one that's kept.
+
+ :param only_synonyms: when True, limit the removal of columns
+ to those which have the same name as the equivalent. Otherwise,
+ all columns that are equivalent to another are removed.
+
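+ An illustrative sketch, assuming ``user_table`` and ``address_table``
+ constructs where ``address_table.c.user_id`` has a foreign key to
+ ``user_table.c.id``::
+
+ stmt = select(user_table, address_table).where(
+ user_table.c.id == address_table.c.user_id
+ )
+
+ # with only_synonyms=False, address.user_id is removed in favor
+ # of the referred-to user.id column
+ stmt = stmt.reduce_columns(only_synonyms=False)
+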
+ """
+ return self.with_only_columns(
+ *util.preloaded.sql_util.reduce_columns(
+ self._all_selected_columns,
+ only_synonyms=only_synonyms,
+ *(self._where_criteria + self._from_obj)
+ )
+ )
+
+ @_generative
+ def with_only_columns(self, *columns, **kw):
+ r"""Return a new :func:`_expression.select` construct with its columns
+ clause replaced with the given columns.
+
+ By default, this method is exactly equivalent to as if the original
+ :func:`_expression.select` had been called with the given columns
+ clause. E.g. a statement::
+
+ s = select(table1.c.a, table1.c.b)
+ s = s.with_only_columns(table1.c.b)
+
+ should be exactly equivalent to::
+
+ s = select(table1.c.b)
+
+ In this mode of operation, :meth:`_sql.Select.with_only_columns`
+ will also dynamically alter the FROM clause of the
+ statement if it is not explicitly stated.
+ To maintain the existing set of FROMs including those implied by the
+ current columns clause, add the
+ :paramref:`_sql.Select.with_only_columns.maintain_column_froms`
+ parameter::
+
+ s = select(table1.c.a, table2.c.b)
+ s = s.with_only_columns(table1.c.a, maintain_column_froms=True)
+
+ The above parameter performs a transfer of the effective FROMs
+ in the columns collection to the :meth:`_sql.Select.select_from`
+ method, as though the following were invoked::
+
+ s = select(table1.c.a, table2.c.b)
+ s = s.select_from(table1, table2).with_only_columns(table1.c.a)
+
+ The :paramref:`_sql.Select.with_only_columns.maintain_column_froms`
+ parameter makes use of the :attr:`_sql.Select.columns_clause_froms`
+ collection and performs an operation equivalent to the following::
+
+ s = select(table1.c.a, table2.c.b)
+ s = s.select_from(*s.columns_clause_froms).with_only_columns(table1.c.a)
+
+ :param \*columns: column expressions to be used.
+
+ .. versionchanged:: 1.4 the :meth:`_sql.Select.with_only_columns`
+ method accepts the list of column expressions positionally;
+ passing the expressions as a list is deprecated.
+
+ :param maintain_column_froms: boolean parameter that will ensure the
+ FROM list implied from the current columns clause will be transferred
+ to the :meth:`_sql.Select.select_from` method first.
+
+ .. versionadded:: 1.4.23
+
+ """ # noqa: E501
+
+ # memoizations should be cleared here as of
+ # I95c560ffcbfa30b26644999412fb6a385125f663 , asserting this
+ # is the case for now.
+ self._assert_no_memoizations()
+
+ maintain_column_froms = kw.pop("maintain_column_froms", False)
+ if kw:
+ raise TypeError("unknown parameters: %s" % (", ".join(kw),))
+
+ if maintain_column_froms:
+ self.select_from.non_generative(self, *self.columns_clause_froms)
+
+ # then memoize the FROMs etc.
+ _MemoizedSelectEntities._generate_for_statement(self)
+
+ self._raw_columns = [
+ coercions.expect(roles.ColumnsClauseRole, c)
+ for c in coercions._expression_collection_was_a_list(
+ "columns", "Select.with_only_columns", columns
+ )
+ ]
+
+ @property
+ def whereclause(self):
+ """Return the completed WHERE clause for this
+ :class:`_expression.Select` statement.
+
+ This assembles the current collection of WHERE criteria
+ into a single :class:`_expression.BooleanClauseList` construct.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ return BooleanClauseList._construct_for_whereclause(
+ self._where_criteria
+ )
+
+ _whereclause = whereclause
+
+ @_generative
+ def where(self, *whereclause):
+ """Return a new :func:`_expression.select` construct with
+ the given expression added to
+ its WHERE clause, joined to the existing clause via AND, if any.
+
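+ E.g., a brief sketch using the ``user_table`` construct assumed in
+ other examples in this module::
+
+ stmt = select(user_table).where(user_table.c.name == "sandy")
+ stmt = stmt.where(user_table.c.id > 5)  # joined via AND
+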
+ """
+
+ assert isinstance(self._where_criteria, tuple)
+
+ for criterion in whereclause:
+ where_criteria = coercions.expect(roles.WhereHavingRole, criterion)
+ self._where_criteria += (where_criteria,)
+
+ @_generative
+ def having(self, having):
+ """Return a new :func:`_expression.select` construct with
+ the given expression added to
+ its HAVING clause, joined to the existing clause via AND, if any.
+
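+ An illustrative sketch, assuming the ``user_table`` construct used in
+ other examples here along with the ``func`` namespace::
+
+ stmt = (
+ select(user_table.c.name, func.count(user_table.c.id))
+ .group_by(user_table.c.name)
+ .having(func.count(user_table.c.id) > 1)
+ )
+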
+ """
+ self._having_criteria += (
+ coercions.expect(roles.WhereHavingRole, having),
+ )
+
+ @_generative
+ def distinct(self, *expr):
+ r"""Return a new :func:`_expression.select` construct which
+ will apply DISTINCT to its columns clause.
+
+ :param \*expr: optional column expressions. When present,
+ the PostgreSQL dialect will render a ``DISTINCT ON (<expressions>)``
+ construct.
+
+ .. deprecated:: 1.4 Using \*expr in other dialects is deprecated
+ and will raise :class:`_exc.CompileError` in a future version.
+
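+ An illustrative sketch, assuming the ``user_table`` construct used
+ elsewhere in this module::
+
+ stmt = select(user_table.c.name).distinct()
+
+ # PostgreSQL only: renders DISTINCT ON (<name column>)
+ pg_stmt = select(user_table).distinct(user_table.c.name)
+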
+ """
+ if expr:
+ self._distinct = True
+ self._distinct_on = self._distinct_on + tuple(
+ coercions.expect(roles.ByOfRole, e) for e in expr
+ )
+ else:
+ self._distinct = True
+
+ @_generative
+ def select_from(self, *froms):
+ r"""Return a new :func:`_expression.select` construct with the
+ given FROM expression(s)
+ merged into its list of FROM objects.
+
+ E.g.::
+
+ table1 = table('t1', column('a'))
+ table2 = table('t2', column('b'))
+ s = select(table1.c.a).\
+ select_from(
+ table1.join(table2, table1.c.a==table2.c.b)
+ )
+
+ The "from" list is a unique set on the identity of each element,
+ so adding an already present :class:`_schema.Table`
+ or other selectable
+ will have no effect. Passing a :class:`_expression.Join` that refers
+ to an already present :class:`_schema.Table`
+ or other selectable will have
+ the effect of concealing the presence of that selectable as
+ an individual element in the rendered FROM list, instead
+ rendering it into a JOIN clause.
+
+ While the typical purpose of :meth:`_expression.Select.select_from`
+ is to
+ replace the default, derived FROM clause with a join, it can
+ also be called with individual table elements, multiple times
+ if desired, in the case that the FROM clause cannot be fully
+ derived from the columns clause::
+
+ select(func.count('*')).select_from(table1)
+
+ """
+
+ self._from_obj += tuple(
+ coercions.expect(
+ roles.FromClauseRole, fromclause, apply_propagate_attrs=self
+ )
+ for fromclause in froms
+ )
+
+ @_generative
+ def correlate(self, *fromclauses):
+ r"""Return a new :class:`_expression.Select`
+ which will correlate the given FROM
+ clauses to that of an enclosing :class:`_expression.Select`.
+
+ Calling this method turns off the :class:`_expression.Select` object's
+ default behavior of "auto-correlation". Normally, FROM elements
+ which appear in a :class:`_expression.Select`
+ that encloses this one via
+ its :term:`WHERE clause`, ORDER BY, HAVING or
+ :term:`columns clause` will be omitted from this
+ :class:`_expression.Select`
+ object's :term:`FROM clause`.
+ Setting an explicit correlation collection using the
+ :meth:`_expression.Select.correlate`
+ method provides a fixed list of FROM objects
+ that can potentially take place in this process.
+
+ When :meth:`_expression.Select.correlate`
+ is used to apply specific FROM clauses
+ for correlation, the FROM elements become candidates for
+ correlation regardless of how deeply nested this
+ :class:`_expression.Select`
+ object is, relative to an enclosing :class:`_expression.Select`
+ which refers to
+ the same FROM object. This is in contrast to the behavior of
+ "auto-correlation" which only correlates to an immediate enclosing
+ :class:`_expression.Select`.
+ Multi-level correlation ensures that the link
+ between enclosed and enclosing :class:`_expression.Select`
+ is always via
+ at least one WHERE/ORDER BY/HAVING/columns clause in order for
+ correlation to take place.
+
+ If ``None`` is passed, the :class:`_expression.Select`
+ object will correlate
+ none of its FROM entries, and all will render unconditionally
+ in the local FROM clause.
+
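+ As an illustrative sketch, assuming the ``user_table`` and
+ ``address_table`` constructs used in other examples in this module,
+ an explicitly correlated scalar subquery might look like::
+
+ subq = (
+ select(func.count(address_table.c.id))
+ .where(address_table.c.user_id == user_table.c.id)
+ .correlate(user_table)
+ .scalar_subquery()
+ )
+ stmt = select(user_table.c.name, subq)
+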
+ :param \*fromclauses: a list of one or more
+ :class:`_expression.FromClause`
+ constructs, or other compatible constructs (i.e. ORM-mapped
+ classes) to become part of the correlate collection.
+
+ .. seealso::
+
+ :meth:`_expression.Select.correlate_except`
+
+ :ref:`tutorial_scalar_subquery`
+
+ """
+
+ self._auto_correlate = False
+ if fromclauses and fromclauses[0] in {None, False}:
+ self._correlate = ()
+ else:
+ self._correlate = self._correlate + tuple(
+ coercions.expect(roles.FromClauseRole, f) for f in fromclauses
+ )
+
+ @_generative
+ def correlate_except(self, *fromclauses):
+ r"""Return a new :class:`_expression.Select`
+ which will omit the given FROM
+ clauses from the auto-correlation process.
+
+ Calling :meth:`_expression.Select.correlate_except` turns off the
+ :class:`_expression.Select` object's default behavior of
+ "auto-correlation" for the given FROM elements. An element
+ specified here will unconditionally appear in the FROM list, while
+ all other FROM elements remain subject to normal auto-correlation
+ behaviors.
+
+ If ``None`` is passed, the :class:`_expression.Select`
+ object will correlate
+ all of its FROM entries.
+
+ :param \*fromclauses: a list of one or more
+ :class:`_expression.FromClause`
+ constructs, or other compatible constructs (i.e. ORM-mapped
+ classes) to become part of the correlate-exception collection.
+
+ .. seealso::
+
+ :meth:`_expression.Select.correlate`
+
+ :ref:`tutorial_scalar_subquery`
+
+ """
+
+ self._auto_correlate = False
+ if fromclauses and fromclauses[0] in {None, False}:
+ self._correlate_except = ()
+ else:
+ self._correlate_except = (self._correlate_except or ()) + tuple(
+ coercions.expect(roles.FromClauseRole, f) for f in fromclauses
+ )
+
+ @HasMemoized.memoized_attribute
+ def selected_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ representing the columns that
+ this SELECT statement or similar construct returns in its result set,
+ not including :class:`_sql.TextClause` constructs.
+
+ This collection differs from the :attr:`_expression.FromClause.columns`
+ collection of a :class:`_expression.FromClause` in that the columns
+ within this collection cannot be directly nested inside another SELECT
+ statement; a subquery must be applied first which provides for the
+ necessary parenthesization required by SQL.
+
+ For a :func:`_expression.select` construct, the collection here is
+ exactly what would be rendered inside the "SELECT" statement, and the
+ :class:`_expression.ColumnElement` objects are directly present as they
+ were given, e.g.::
+
+ col1 = column('q', Integer)
+ col2 = column('p', Integer)
+ stmt = select(col1, col2)
+
+ Above, ``stmt.selected_columns`` would be a collection that contains
+ the ``col1`` and ``col2`` objects directly. For a statement that is
+ against a :class:`_schema.Table` or other
+ :class:`_expression.FromClause`, the collection will use the
+ :class:`_expression.ColumnElement` objects that are in the
+ :attr:`_expression.FromClause.c` collection of the from element.
+
+ .. note::
+
+ The :attr:`_sql.Select.selected_columns` collection does not
+ include expressions established in the columns clause using the
+ :func:`_sql.text` construct; these are silently omitted from the
+ collection. To use plain textual column expressions inside of a
+ :class:`_sql.Select` construct, use the :func:`_sql.literal_column`
+ construct.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ # compare to SelectState._generate_columns_plus_names, which
+ # generates the actual names used in the SELECT string. that
+ # method is more complex because it also renders columns that are
+ # fully ambiguous, e.g. same column more than once.
+ conv = SelectState._column_naming_convention(self._label_style)
+
+ return ColumnCollection(
+ [
+ (conv(c), c)
+ for c in self._all_selected_columns
+ if not c._is_text_clause
+ ]
+ ).as_immutable()
+
+ @HasMemoized.memoized_attribute
+ def _all_selected_columns(self):
+ meth = SelectState.get_plugin_class(self).all_selected_columns
+ return list(meth(self))
+
+ def _ensure_disambiguated_names(self):
+ if self._label_style is LABEL_STYLE_NONE:
+ self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY)
+ return self
+
+ def _generate_columns_plus_names(self, anon_for_dupe_key):
+ """Generate column names as rendered in a SELECT statement by
+ the compiler.
+
+ This is distinct from the _column_naming_convention generator that's
+ intended for population of .c collections and similar, which has
+ different rules. the collection returned here calls upon the
+ _column_naming_convention as well.
+
+ """
+ cols = self._all_selected_columns
+
+ key_naming_convention = SelectState._column_naming_convention(
+ self._label_style
+ )
+
+ names = {}
+
+ result = []
+ result_append = result.append
+
+ table_qualified = self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL
+ label_style_none = self._label_style is LABEL_STYLE_NONE
+
+ # a counter used for "dedupe" labels, which have double underscores
+ # in them and are never referred to by name; they only act
+ # as positional placeholders. they need only be unique within
+ # the single columns clause they're rendered within (required by
+ # some dbs such as mysql). So their anon identity is tracked against
+ # a fixed counter rather than hash() identity.
+ dedupe_hash = 1
+
+ for c in cols:
+ repeated = False
+
+ if not c._render_label_in_columns_clause:
+ effective_name = (
+ required_label_name
+ ) = fallback_label_name = None
+ elif label_style_none:
+ effective_name = required_label_name = None
+ fallback_label_name = c._non_anon_label or c._anon_name_label
+ else:
+ if table_qualified:
+ required_label_name = (
+ effective_name
+ ) = fallback_label_name = c._tq_label
+ else:
+ effective_name = fallback_label_name = c._non_anon_label
+ required_label_name = None
+
+ if effective_name is None:
+ # it seems like this could be _proxy_key and we would
+ # not need _expression_label but it isn't
+ # giving us a clue when to use anon_label instead
+ expr_label = c._expression_label
+ if expr_label is None:
+ repeated = c._anon_name_label in names
+ names[c._anon_name_label] = c
+ effective_name = required_label_name = None
+
+ if repeated:
+ # here, "required_label_name" is sent as
+ # "None" and "fallback_label_name" is sent.
+ if table_qualified:
+ fallback_label_name = (
+ c._dedupe_anon_tq_label_idx(dedupe_hash)
+ )
+ dedupe_hash += 1
+ else:
+ fallback_label_name = c._dedupe_anon_label_idx(
+ dedupe_hash
+ )
+ dedupe_hash += 1
+ else:
+ fallback_label_name = c._anon_name_label
+ else:
+ required_label_name = (
+ effective_name
+ ) = fallback_label_name = expr_label
+
+ if effective_name is not None:
+ if effective_name in names:
+ # when looking to see if names[name] is the same column as
+ # c, use hash(), so that an annotated version of the column
+ # is seen as the same as the non-annotated
+ if hash(names[effective_name]) != hash(c):
+
+ # different column under the same name. apply
+ # disambiguating label
+ if table_qualified:
+ required_label_name = (
+ fallback_label_name
+ ) = c._anon_tq_label
+ else:
+ required_label_name = (
+ fallback_label_name
+ ) = c._anon_name_label
+
+ if anon_for_dupe_key and required_label_name in names:
+ # here, c._anon_tq_label is definitely unique to
+ # that column identity (or annotated version), so
+ # this should always be true.
+ # this is also an infrequent codepath because
+ # you need two levels of duplication to be here
+ assert hash(names[required_label_name]) == hash(c)
+
+ # the column under the disambiguating label is
+ # already present. apply the "dedupe" label to
+ # subsequent occurrences of the column so that the
+ # original stays non-ambiguous
+ if table_qualified:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_tq_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ else:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ repeated = True
+ else:
+ names[required_label_name] = c
+ elif anon_for_dupe_key:
+ # same column under the same name. apply the "dedupe"
+ # label so that the original stays non-ambiguous
+ if table_qualified:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_tq_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ else:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ repeated = True
+ else:
+ names[effective_name] = c
+
+ result_append(
+ (
+ # string label name, if non-None, must be rendered as a
+ # label, i.e. "AS <name>"
+ required_label_name,
+ # proxy_key that is to be part of the result map for this
+ # col. this is also the key in a fromclause.c or
+ # select.selected_columns collection
+ key_naming_convention(c),
+ # name that can be used to render an "AS <name>" when
+ # we have to render a label even though
+ # required_label_name was not given
+ fallback_label_name,
+ # the ColumnElement itself
+ c,
+ # True if this is a duplicate of a previous column
+ # in the list of columns
+ repeated,
+ )
+ )
+
+ return result
+
+ def _generate_fromclause_column_proxies(self, subquery):
+ """Generate column proxies to place in the exported ``.c``
+ collection of a subquery."""
+
+ prox = [
+ c._make_proxy(
+ subquery,
+ key=proxy_key,
+ name=required_label_name,
+ name_is_truncatable=True,
+ )
+ for (
+ required_label_name,
+ proxy_key,
+ fallback_label_name,
+ c,
+ repeated,
+ ) in (self._generate_columns_plus_names(False))
+ if not c._is_text_clause
+ ]
+
+ subquery._columns._populate_separate_keys(prox)
+
+ def _needs_parens_for_grouping(self):
+ return self._has_row_limiting_clause or bool(
+ self._order_by_clause.clauses
+ )
+
+ def self_group(self, against=None):
+ """Return a 'grouping' construct as per the
+ :class:`_expression.ClauseElement` specification.
+
+ This produces an element that can be embedded in an expression. Note
+ that this method is called automatically as needed when constructing
+ expressions and should not require explicit use.
+
+ """
+ if (
+ isinstance(against, CompoundSelect)
+ and not self._needs_parens_for_grouping()
+ ):
+ return self
+ else:
+ return SelectStatementGrouping(self)
+
+ def union(self, *other, **kwargs):
+ r"""Return a SQL ``UNION`` of this select() construct against
+ the given selectables provided as positional arguments.
+
+ :param \*other: one or more elements with which to create a
+ UNION.
+
+ .. versionchanged:: 1.4.28
+
+ multiple elements are now accepted.
+
+ :param \**kwargs: keyword arguments are forwarded to the constructor
+ for the newly created :class:`_sql.CompoundSelect` object.
+
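+ An illustrative sketch, using the ``user_table`` construct assumed in
+ other examples in this module::
+
+ s1 = select(user_table).where(user_table.c.id < 5)
+ s2 = select(user_table).where(user_table.c.id > 10)
+ union_stmt = s1.union(s2)
+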
+ """
+ return CompoundSelect._create_union(self, *other, **kwargs)
+
+ def union_all(self, *other, **kwargs):
+ r"""Return a SQL ``UNION ALL`` of this select() construct against
+ the given selectables provided as positional arguments.
+
+ :param \*other: one or more elements with which to create a
+ UNION ALL.
+
+ .. versionchanged:: 1.4.28
+
+ multiple elements are now accepted.
+
+ :param \**kwargs: keyword arguments are forwarded to the constructor
+ for the newly created :class:`_sql.CompoundSelect` object.
+
+ """
+ return CompoundSelect._create_union_all(self, *other, **kwargs)
+
+ def except_(self, *other, **kwargs):
+ r"""Return a SQL ``EXCEPT`` of this select() construct against
+ the given selectables provided as positional arguments.
+
+ :param \*other: one or more elements with which to create a
+ EXCEPT.
+
+ .. versionchanged:: 1.4.28
+
+ multiple elements are now accepted.
+
+ :param \**kwargs: keyword arguments are forwarded to the constructor
+ for the newly created :class:`_sql.CompoundSelect` object.
+
+ """
+ return CompoundSelect._create_except(self, *other, **kwargs)
+
+ def except_all(self, *other, **kwargs):
+ r"""Return a SQL ``EXCEPT ALL`` of this select() construct against
+ the given selectables provided as positional arguments.
+
+ :param \*other: one or more elements with which to create a
+ EXCEPT ALL.
+
+ .. versionchanged:: 1.4.28
+
+ multiple elements are now accepted.
+
+ :param \**kwargs: keyword arguments are forwarded to the constructor
+ for the newly created :class:`_sql.CompoundSelect` object.
+
+ """
+ return CompoundSelect._create_except_all(self, *other, **kwargs)
+
+ def intersect(self, *other, **kwargs):
+ r"""Return a SQL ``INTERSECT`` of this select() construct against
+ the given selectables provided as positional arguments.
+
+ :param \*other: one or more elements with which to create a
+ INTERSECT.
+
+ .. versionchanged:: 1.4.28
+
+ multiple elements are now accepted.
+
+ :param \**kwargs: keyword arguments are forwarded to the constructor
+ for the newly created :class:`_sql.CompoundSelect` object.
+
+ """
+ return CompoundSelect._create_intersect(self, *other, **kwargs)
+
+ def intersect_all(self, *other, **kwargs):
+ r"""Return a SQL ``INTERSECT ALL`` of this select() construct
+ against the given selectables provided as positional arguments.
+
+ :param \*other: one or more elements with which to create a
+ INTERSECT ALL.
+
+ .. versionchanged:: 1.4.28
+
+ multiple elements are now accepted.
+
+ :param \**kwargs: keyword arguments are forwarded to the constructor
+ for the newly created :class:`_sql.CompoundSelect` object.
+
+ """
+ return CompoundSelect._create_intersect_all(self, *other, **kwargs)
+
+ @property
+ @util.deprecated_20(
+ ":attr:`.Executable.bind`",
+ alternative="Bound metadata is being removed as of SQLAlchemy 2.0.",
+ enable_warnings=False,
+ )
+ def bind(self):
+ """Returns the :class:`_engine.Engine` or :class:`_engine.Connection`
+ to which this :class:`.Executable` is bound, or None if none found.
+
+ """
+ if self._bind:
+ return self._bind
+
+ for item in self._iterate_from_elements():
+ if item._is_subquery and item.element is self:
+ raise exc.InvalidRequestError(
+ "select() construct refers to itself as a FROM"
+ )
+
+ e = item.bind
+ if e:
+ self._bind = e
+ return e
+ else:
+ break
+
+ for c in self._raw_columns:
+ e = c.bind
+ if e:
+ self._bind = e
+ return e
+
+ @bind.setter
+ def bind(self, bind):
+ self._bind = bind
+
+
+class ScalarSelect(roles.InElementRole, Generative, Grouping):
+ """Represent a scalar subquery.
+
+
+ A :class:`_sql.ScalarSelect` is created by invoking the
+ :meth:`_sql.SelectBase.scalar_subquery` method. The object
+ then participates in other SQL expressions as a SQL column expression
+ within the :class:`_sql.ColumnElement` hierarchy.
+
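+ An illustrative sketch, assuming ``user_table`` and ``address_table``
+ constructs as used in other examples in this module::
+
+ count_subq = (
+ select(func.count(address_table.c.id))
+ .where(address_table.c.user_id == user_table.c.id)
+ .scalar_subquery()
+ )
+ stmt = select(user_table.c.name, count_subq.label("address_count"))
+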
+ .. seealso::
+
+ :meth:`_sql.SelectBase.scalar_subquery`
+
+ :ref:`tutorial_scalar_subquery` - in the 2.0 tutorial
+
+ """
+
+ _from_objects = []
+ _is_from_container = True
+ _is_implicitly_boolean = False
+ inherit_cache = True
+
+ def __init__(self, element):
+ self.element = element
+ self.type = element._scalar_type()
+
+ @property
+ def columns(self):
+ raise exc.InvalidRequestError(
+ "Scalar Select expression has no "
+ "columns; use this object directly "
+ "within a column-level expression."
+ )
+
+ c = columns
+
+ @_generative
+ def where(self, crit):
+ """Apply a WHERE clause to the SELECT statement referred to
+ by this :class:`_expression.ScalarSelect`.
+
+ """
+ self.element = self.element.where(crit)
+
+ def self_group(self, **kwargs):
+ return self
+
+ @_generative
+ def correlate(self, *fromclauses):
+ r"""Return a new :class:`_expression.ScalarSelect`
+ which will correlate the given FROM
+ clauses to that of an enclosing :class:`_expression.Select`.
+
+ This method is mirrored from the :meth:`_sql.Select.correlate` method
+ of the underlying :class:`_sql.Select`. The method applies the
+ :meth:`_sql.Select.correlate` method, then returns a new
+ :class:`_sql.ScalarSelect` against that statement.
+
+ .. versionadded:: 1.4 Previously, the
+ :meth:`_sql.ScalarSelect.correlate`
+ method was only available from :class:`_sql.Select`.
+
+ :param \*fromclauses: a list of one or more
+ :class:`_expression.FromClause`
+ constructs, or other compatible constructs (i.e. ORM-mapped
+ classes) to become part of the correlate collection.
+
+ .. seealso::
+
+ :meth:`_expression.ScalarSelect.correlate_except`
+
+ :ref:`tutorial_scalar_subquery` - in the 2.0 tutorial
+
+
+ """
+ self.element = self.element.correlate(*fromclauses)
+
+ @_generative
+ def correlate_except(self, *fromclauses):
+ r"""Return a new :class:`_expression.ScalarSelect`
+ which will omit the given FROM
+ clauses from the auto-correlation process.
+
+ This method is mirrored from the
+ :meth:`_sql.Select.correlate_except` method of the underlying
+ :class:`_sql.Select`. The method applies the
+ :meth:`_sql.Select.correlate_except` method, then returns a new
+ :class:`_sql.ScalarSelect` against that statement.
+
+ .. versionadded:: 1.4 Previously, the
+ :meth:`_sql.ScalarSelect.correlate_except`
+ method was only available from :class:`_sql.Select`.
+
+ :param \*fromclauses: a list of one or more
+ :class:`_expression.FromClause`
+ constructs, or other compatible constructs (i.e. ORM-mapped
+ classes) to become part of the correlate-exception collection.
+
+ .. seealso::
+
+ :meth:`_expression.ScalarSelect.correlate`
+
+ :ref:`tutorial_scalar_subquery` - in the 2.0 tutorial
+
+
+ """
+
+ self.element = self.element.correlate_except(*fromclauses)
+
+
+class Exists(UnaryExpression):
+ """Represent an ``EXISTS`` clause.
+
+ See :func:`_sql.exists` for a description of usage.
+
+ An ``EXISTS`` clause can also be constructed from a :func:`_sql.select`
+ instance by calling :meth:`_sql.SelectBase.exists`.
+
+ """
+
+ _from_objects = []
+ inherit_cache = True
+
+ def __init__(self, *args, **kwargs):
+ """Construct a new :class:`_expression.Exists` construct.
+
+ The :func:`_sql.exists` can be invoked by itself to produce an
+ :class:`_sql.Exists` construct, which will accept simple WHERE
+ criteria::
+
+ exists_criteria = exists().where(table1.c.col1 == table2.c.col2)
+
+ However, for greater flexibility in constructing the SELECT, an
+ existing :class:`_sql.Select` construct may be converted to an
+ :class:`_sql.Exists`, most conveniently by making use of the
+ :meth:`_sql.SelectBase.exists` method::
+
+ exists_criteria = (
+ select(table2.c.col2).
+ where(table1.c.col1 == table2.c.col2).
+ exists()
+ )
+
+ The EXISTS criteria is then used inside of an enclosing SELECT::
+
+ stmt = select(table1.c.col1).where(exists_criteria)
+
+ The above statement will then be of the form::
+
+ SELECT col1 FROM table1 WHERE EXISTS
+ (SELECT table2.col2 FROM table2 WHERE table2.col2 = table1.col1)
+
+ .. seealso::
+
+ :ref:`tutorial_exists` - in the :term:`2.0 style` tutorial.
+
+ :meth:`_sql.SelectBase.exists` - method to transform a ``SELECT`` to an
+ ``EXISTS`` clause.
+
+ """ # noqa: E501
+ if args and isinstance(args[0], (SelectBase, ScalarSelect)):
+ s = args[0]
+ else:
+ if not args:
+ args = (literal_column("*"),)
+ s = Select._create(*args, **kwargs).scalar_subquery()
+
+ UnaryExpression.__init__(
+ self,
+ s,
+ operator=operators.exists,
+ type_=type_api.BOOLEANTYPE,
+ wraps_column_expression=True,
+ )
+
+ def _regroup(self, fn):
+ element = self.element._ungroup()
+ element = fn(element)
+ return element.self_group(against=operators.exists)
+
+ @util.deprecated_params(
+ whereclause=(
+ "2.0",
+ "The :paramref:`_sql.Exists.select().whereclause` parameter "
+ "is deprecated and will be removed in version 2.0. "
+ "Please make use "
+ "of the :meth:`.Select.where` "
+ "method to add WHERE criteria to the SELECT statement.",
+ ),
+ kwargs=(
+ "2.0",
+ "The :meth:`_sql.Exists.select` method will no longer accept "
+ "keyword arguments in version 2.0. "
+ "Please use generative methods from the "
+ ":class:`_sql.Select` construct in order to apply additional "
+ "modifications.",
+ ),
+ )
+ def select(self, whereclause=None, **kwargs):
+ r"""Return a SELECT of this :class:`_expression.Exists`.
+
+ e.g.::
+
+ stmt = exists(some_table.c.id).where(some_table.c.id == 5).select()
+
+ This will produce a statement resembling::
+
+ SELECT EXISTS (SELECT some_table.id FROM some_table WHERE some_table.id = :param) AS anon_1
+
+ :param whereclause: a WHERE clause, equivalent to calling the
+ :meth:`_sql.Select.where` method.
+
+ :param **kwargs: additional keyword arguments are passed to the
+ legacy constructor for :class:`_sql.Select` described at
+ :meth:`_sql.Select.create_legacy_select`.
+
+ .. seealso::
+
+ :func:`_expression.select` - general purpose
+ method which allows for arbitrary column lists.
+
+ """ # noqa
+
+ if whereclause is not None:
+ kwargs["whereclause"] = whereclause
+ return Select._create_select_from_fromclause(self, [self], **kwargs)
+
+ def correlate(self, *fromclause):
+ """Apply correlation to the subquery noted by this
+ :class:`_sql.Exists`.
+
+ .. seealso::
+
+ :meth:`_sql.ScalarSelect.correlate`
+
+ """
+ e = self._clone()
+ e.element = self._regroup(
+ lambda element: element.correlate(*fromclause)
+ )
+ return e
+
+ def correlate_except(self, *fromclause):
+ """Apply correlation to the subquery noted by this
+ :class:`_sql.Exists`.
+
+ .. seealso::
+
+ :meth:`_sql.ScalarSelect.correlate_except`
+
+ """
+
+ e = self._clone()
+ e.element = self._regroup(
+ lambda element: element.correlate_except(*fromclause)
+ )
+ return e
+
+ def select_from(self, *froms):
+ """Return a new :class:`_expression.Exists` construct,
+ applying the given
+ expression to the :meth:`_expression.Select.select_from`
+ method of the select
+ statement contained.
+
+ .. note:: it is typically preferable to build a :class:`_sql.Select`
+ statement first, including the desired WHERE clause, then use the
+ :meth:`_sql.SelectBase.exists` method to produce an
+ :class:`_sql.Exists` object at once.
+
+ """
+ e = self._clone()
+ e.element = self._regroup(lambda element: element.select_from(*froms))
+ return e
+
+ def where(self, *clause):
+ """Return a new :func:`_expression.exists` construct with the
+ given expression added to
+ its WHERE clause, joined to the existing clause via AND, if any.
+
+
+ .. note:: it is typically preferable to build a :class:`_sql.Select`
+ statement first, including the desired WHERE clause, then use the
+ :meth:`_sql.SelectBase.exists` method to produce an
+ :class:`_sql.Exists` object at once.
+
+ """
+ e = self._clone()
+ e.element = self._regroup(lambda element: element.where(*clause))
+ return e
+
+
+class TextualSelect(SelectBase):
+ """Wrap a :class:`_expression.TextClause` construct within a
+ :class:`_expression.SelectBase`
+ interface.
+
+ This allows the :class:`_expression.TextClause` object to gain a
+ ``.c`` collection
+ and other FROM-like capabilities such as
+ :meth:`_expression.FromClause.alias`,
+ :meth:`_expression.SelectBase.cte`, etc.
+
+ The :class:`_expression.TextualSelect` construct is produced via the
+ :meth:`_expression.TextClause.columns`
+ method - see that method for details.
+
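+ A brief sketch of typical construction (table and column names are
+ assumed for illustration)::
+
+ from sqlalchemy import column, text
+
+ stmt = text("SELECT id, name FROM user_account").columns(
+ column("id"), column("name")
+ )
+ subq = stmt.subquery()
+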
+ .. versionchanged:: 1.4 the :class:`_expression.TextualSelect`
+ class was renamed
+ from ``TextAsFrom``, to more correctly suit its role as a
+ SELECT-oriented object and not a FROM clause.
+
+ .. seealso::
+
+ :func:`_expression.text`
+
+ :meth:`_expression.TextClause.columns` - primary creation interface.
+
+ """
+
+ __visit_name__ = "textual_select"
+
+ _label_style = LABEL_STYLE_NONE
+
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("column_args", InternalTraversal.dp_clauseelement_list),
+ ] + SupportsCloneAnnotations._clone_annotations_traverse_internals
+
+ _is_textual = True
+
+ is_text = True
+ is_select = True
+
+ def __init__(self, text, columns, positional=False):
+ self.element = text
+ # convert for ORM attributes->columns, etc
+ self.column_args = [
+ coercions.expect(roles.ColumnsClauseRole, c) for c in columns
+ ]
+ self.positional = positional
+
+ @HasMemoized.memoized_attribute
+ def selected_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ representing the columns that
+ this SELECT statement or similar construct returns in its result set,
+ not including :class:`_sql.TextClause` constructs.
+
+ This collection differs from the :attr:`_expression.FromClause.columns`
+ collection of a :class:`_expression.FromClause` in that the columns
+ within this collection cannot be directly nested inside another SELECT
+ statement; a subquery must be applied first which provides for the
+ necessary parenthesization required by SQL.
+
+ For a :class:`_expression.TextualSelect` construct, the collection
+ contains the :class:`_expression.ColumnElement` objects that were
+ passed to the constructor, typically via the
+ :meth:`_expression.TextClause.columns` method.
+
+
+ .. versionadded:: 1.4
+
+ """
+ return ColumnCollection(
+ (c.key, c) for c in self.column_args
+ ).as_immutable()
+
+ @property
+ def _all_selected_columns(self):
+ return self.column_args
+
+ def _set_label_style(self, style):
+ return self
+
+ def _ensure_disambiguated_names(self):
+ return self
+
+ @property
+ def _bind(self):
+ return self.element._bind
+
+ @_generative
+ def bindparams(self, *binds, **bind_as_values):
+ self.element = self.element.bindparams(*binds, **bind_as_values)
+
+ def _generate_fromclause_column_proxies(self, fromclause):
+ fromclause._columns._populate_separate_keys(
+ c._make_proxy(fromclause) for c in self.column_args
+ )
+
+ def _scalar_type(self):
+ return self.column_args[0].type
+
+
+TextAsFrom = TextualSelect
+"""Backwards compatibility with the previous name"""
+
+
+class AnnotatedFromClause(Annotated):
+ def __init__(self, element, values):
+ # force FromClause to generate their internal
+ # collections into __dict__
+ element.c
+ Annotated.__init__(self, element, values)
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
new file mode 100644
index 0000000..322bfec
--- /dev/null
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -0,0 +1,3351 @@
+# sql/sqltypes.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""SQL specific types.
+
+"""
+
+import codecs
+import datetime as dt
+import decimal
+import json
+
+from . import coercions
+from . import elements
+from . import operators
+from . import roles
+from . import type_api
+from .base import _bind_or_error
+from .base import NO_ARG
+from .base import SchemaEventTarget
+from .elements import _NONE_NAME
+from .elements import quoted_name
+from .elements import Slice
+from .elements import TypeCoerce as type_coerce # noqa
+from .traversals import HasCacheKey
+from .traversals import InternalTraversal
+from .type_api import Emulated
+from .type_api import NativeForEmulated # noqa
+from .type_api import to_instance
+from .type_api import TypeDecorator
+from .type_api import TypeEngine
+from .type_api import Variant
+from .. import event
+from .. import exc
+from .. import inspection
+from .. import processors
+from .. import util
+from ..util import compat
+from ..util import langhelpers
+from ..util import OrderedDict
+from ..util import pickle
+
+
+class _LookupExpressionAdapter(object):
+
+ """Mixin expression adaptations based on lookup tables.
+
+ These rules are currently used by the numeric, integer and date types
+ which have detailed cross-expression coercion rules.
+
+ """
+
+ @property
+ def _expression_adaptations(self):
+ raise NotImplementedError()
+
+ class Comparator(TypeEngine.Comparator):
+ _blank_dict = util.immutabledict()
+
+ def _adapt_expression(self, op, other_comparator):
+ othertype = other_comparator.type._type_affinity
+ lookup = self.type._expression_adaptations.get(
+ op, self._blank_dict
+ ).get(othertype, self.type)
+ if lookup is othertype:
+ return (op, other_comparator.type)
+ elif lookup is self.type._type_affinity:
+ return (op, self.type)
+ else:
+ return (op, to_instance(lookup))
+
+ comparator_factory = Comparator
+
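+# Illustrative sketch of the effect of these lookup tables (hypothetical
+# column names; the concrete tables live on Integer, Numeric, Date, etc.
+# further below): combining an Integer expression with a Numeric one
+# yields a Numeric-typed expression.
+#
+#     from sqlalchemy import Integer, Numeric, column
+#
+#     qty = column("qty", Integer)
+#     price = column("price", Numeric(10, 2))
+#     total = qty * price
+#     # total.type is the Numeric(10, 2) type of the right-hand operand,
+#     # per Integer._expression_adaptations[operators.mul][Numeric]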
+
+class Concatenable(object):
+
+ """A mixin that marks a type as supporting 'concatenation',
+ typically strings."""
+
+ class Comparator(TypeEngine.Comparator):
+ def _adapt_expression(self, op, other_comparator):
+ if op is operators.add and isinstance(
+ other_comparator,
+ (Concatenable.Comparator, NullType.Comparator),
+ ):
+ return operators.concat_op, self.expr.type
+ else:
+ return super(Concatenable.Comparator, self)._adapt_expression(
+ op, other_comparator
+ )
+
+ comparator_factory = Comparator
+
+
+class Indexable(object):
+ """A mixin that marks a type as supporting indexing operations,
+ such as array or JSON structures.
+
+
+ .. versionadded:: 1.1.0
+
+
+ """
+
+ class Comparator(TypeEngine.Comparator):
+ def _setup_getitem(self, index):
+ raise NotImplementedError()
+
+ def __getitem__(self, index):
+ (
+ adjusted_op,
+ adjusted_right_expr,
+ result_type,
+ ) = self._setup_getitem(index)
+ return self.operate(
+ adjusted_op, adjusted_right_expr, result_type=result_type
+ )
+
+ comparator_factory = Comparator
+
+
+class String(Concatenable, TypeEngine):
+
+ """The base for all string and character types.
+
+ In SQL, corresponds to VARCHAR. Can also take Python unicode objects
+ and encode to the database's encoding in bind params (and the reverse for
+ result sets.)
+
+ The `length` field is usually required when the `String` type is
+ used within a CREATE TABLE statement, as VARCHAR requires a length
+ on most databases.
+
+ """
+
+ __visit_name__ = "string"
+
+ RETURNS_UNICODE = util.symbol(
+ "RETURNS_UNICODE",
+ """Indicates that the DBAPI returns Python Unicode for VARCHAR,
+ NVARCHAR, and other character-based datatypes in all cases.
+
+ This is the default value for
+ :attr:`.DefaultDialect.returns_unicode_strings` under Python 3.
+
+ .. versionadded:: 1.4
+
+ """,
+ )
+
+ RETURNS_BYTES = util.symbol(
+ "RETURNS_BYTES",
+ """Indicates that the DBAPI returns byte objects under Python 3
+ or non-Unicode string objects under Python 2 for VARCHAR, NVARCHAR,
+ and other character-based datatypes in all cases.
+
+ This may be applied to the
+ :attr:`.DefaultDialect.returns_unicode_strings` attribute.
+
+ .. versionadded:: 1.4
+
+ """,
+ )
+
+ RETURNS_CONDITIONAL = util.symbol(
+ "RETURNS_CONDITIONAL",
+ """Indicates that the DBAPI may return Unicode or bytestrings for
+ VARCHAR, NVARCHAR, and other character-based datatypes, and that
+ SQLAlchemy's default String datatype will need to test on a per-row
+ basis for Unicode or bytes.
+
+ This may be applied to the
+ :attr:`.DefaultDialect.returns_unicode_strings` attribute.
+
+ .. versionadded:: 1.4
+
+ """,
+ )
+
+ RETURNS_UNKNOWN = util.symbol(
+ "RETURNS_UNKNOWN",
+ """Indicates that the dialect should test on first connect what the
+ string-returning behavior of character-based datatypes is.
+
+ This is the default value for
+ :attr:`.DefaultDialect.returns_unicode_strings` under
+ Python 2.
+
+ This may be applied to the
+ :attr:`.DefaultDialect.returns_unicode_strings` attribute under
+ Python 2 only. The value is disallowed under Python 3.
+
+ .. versionadded:: 1.4
+
+ .. deprecated:: 1.4 This value will be removed in SQLAlchemy 2.0.
+
+ """,
+ )
+
+ @util.deprecated_params(
+ convert_unicode=(
+ "1.3",
+ "The :paramref:`.String.convert_unicode` parameter is deprecated "
+ "and will be removed in a future release. All modern DBAPIs "
+ "now support Python Unicode directly and this parameter is "
+ "unnecessary.",
+ ),
+ unicode_error=(
+ "1.3",
+ "The :paramref:`.String.unicode_errors` parameter is deprecated "
+ "and will be removed in a future release. This parameter is "
+ "unnecessary for modern Python DBAPIs and degrades performance "
+ "significantly.",
+ ),
+ )
+ def __init__(
+ self,
+ length=None,
+ collation=None,
+ convert_unicode=False,
+ unicode_error=None,
+ _warn_on_bytestring=False,
+ _expect_unicode=False,
+ ):
+ """
+ Create a string-holding type.
+
+ :param length: optional, a length for the column for use in
+ DDL and CAST expressions. May be safely omitted if no ``CREATE
+ TABLE`` will be issued. Certain databases may require a
+ ``length`` for use in DDL, and will raise an exception when
+ the ``CREATE TABLE`` DDL is issued if a ``VARCHAR``
+ with no length is included. Whether the value is
+ interpreted as bytes or characters is database specific.
+
+ :param collation: Optional, a column-level collation for
+ use in DDL and CAST expressions. Renders using the
+ COLLATE keyword supported by SQLite, MySQL, and PostgreSQL.
+ E.g.::
+
+ >>> from sqlalchemy import cast, select, String
+ >>> print(select(cast('some string', String(collation='utf8'))))
+ SELECT CAST(:param_1 AS VARCHAR COLLATE utf8) AS anon_1
+
+ :param convert_unicode: When set to ``True``, the
+ :class:`.String` type will assume that
+ input is to be passed as Python Unicode objects under Python 2,
+ and results returned as Python Unicode objects.
+ In the rare circumstance that the DBAPI does not support
+ Python unicode under Python 2, SQLAlchemy will use its own
+ encoder/decoder functionality on strings, referring to the
+ value of the :paramref:`_sa.create_engine.encoding` parameter
+ passed to :func:`_sa.create_engine` as the encoding.
+
+ For the extremely rare case that Python Unicode
+ is to be encoded/decoded by SQLAlchemy on a backend
+ that *does* natively support Python Unicode,
+ the string value ``"force"`` can be passed here which will
+ cause SQLAlchemy's encode/decode services to be
+ used unconditionally.
+
+ .. note::
+
+ SQLAlchemy's unicode-conversion flags and features only apply
+ to Python 2; in Python 3, all string objects are Unicode objects.
+ For this reason, as well as the fact that virtually all modern
+ DBAPIs now support Unicode natively even under Python 2,
+ the :paramref:`.String.convert_unicode` flag is inherently a
+ legacy feature.
+
+ .. note::
+
+ In the vast majority of cases, the :class:`.Unicode` or
+ :class:`.UnicodeText` datatypes should be used for a
+ :class:`_schema.Column` that expects to store non-ascii data.
+ These
+ datatypes will ensure that the correct types are used on the
+ database side as well as set up the correct Unicode behaviors
+ under Python 2.
+
+ .. seealso::
+
+ :paramref:`_sa.create_engine.convert_unicode` -
+ :class:`_engine.Engine`-wide parameter
+
+ :param unicode_error: Optional, a method to use to handle Unicode
+ conversion errors. Behaves like the ``errors`` keyword argument to
+ the standard library's ``string.decode()`` functions, requires
+ that :paramref:`.String.convert_unicode` is set to
+ ``"force"``
+
+ """
+ if unicode_error is not None and convert_unicode != "force":
+ raise exc.ArgumentError(
+ "convert_unicode must be 'force' " "when unicode_error is set."
+ )
+
+ self.length = length
+ self.collation = collation
+ self._expect_unicode = convert_unicode or _expect_unicode
+ self._expect_unicode_error = unicode_error
+
+ self._warn_on_bytestring = _warn_on_bytestring
+
+ def literal_processor(self, dialect):
+ def process(value):
+ value = value.replace("'", "''")
+
+ if dialect.identifier_preparer._double_percents:
+ value = value.replace("%", "%%")
+
+ return "'%s'" % value
+
+ return process
+
+ def bind_processor(self, dialect):
+ if self._expect_unicode or dialect.convert_unicode:
+ if (
+ dialect.supports_unicode_binds
+ and self._expect_unicode != "force"
+ ):
+ if self._warn_on_bytestring:
+
+ def process(value):
+ if isinstance(value, util.binary_type):
+ util.warn_limited(
+ "Unicode type received non-unicode "
+ "bind param value %r.",
+ (util.ellipses_string(value),),
+ )
+ return value
+
+ return process
+ else:
+ return None
+ else:
+ encoder = codecs.getencoder(dialect.encoding)
+ warn_on_bytestring = self._warn_on_bytestring
+
+ def process(value):
+ if isinstance(value, util.text_type):
+ return encoder(value, self._expect_unicode_error)[0]
+ elif warn_on_bytestring and value is not None:
+ util.warn_limited(
+ "Unicode type received non-unicode bind "
+ "param value %r.",
+ (util.ellipses_string(value),),
+ )
+ return value
+
+ return process
+ else:
+ return None
+
+ def result_processor(self, dialect, coltype):
+ wants_unicode = self._expect_unicode or dialect.convert_unicode
+ needs_convert = wants_unicode and (
+ dialect.returns_unicode_strings is not String.RETURNS_UNICODE
+ or self._expect_unicode in ("force", "force_nocheck")
+ )
+ needs_isinstance = (
+ needs_convert
+ and dialect.returns_unicode_strings
+ in (
+ String.RETURNS_CONDITIONAL,
+ String.RETURNS_UNICODE,
+ )
+ and self._expect_unicode != "force_nocheck"
+ )
+ if needs_convert:
+ if needs_isinstance:
+ return processors.to_conditional_unicode_processor_factory(
+ dialect.encoding, self._expect_unicode_error
+ )
+ else:
+ return processors.to_unicode_processor_factory(
+ dialect.encoding, self._expect_unicode_error
+ )
+ else:
+ return None
+
+ @property
+ def python_type(self):
+ if self._expect_unicode:
+ return util.text_type
+ else:
+ return str
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.STRING
+
+ @classmethod
+ def _warn_deprecated_unicode(cls):
+ util.warn_deprecated(
+ "The convert_unicode on Engine and String as well as the "
+ "unicode_error flag on String are deprecated. All modern "
+ "DBAPIs now support Python Unicode natively under Python 2, and "
+ "under Python 3 all strings are inherently Unicode. These flags "
+ "will be removed in a future release.",
+ version="1.3",
+ )
+
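+# Illustrative sketch (hypothetical value) of what String.literal_processor
+# produces when a statement is compiled with inline literals: embedded
+# single quotes are doubled before the value is wrapped in quotes.
+#
+#     from sqlalchemy import String, literal, select
+#
+#     stmt = select(literal("O'Reilly", String))
+#     print(stmt.compile(compile_kwargs={"literal_binds": True}))
+#     # renders along the lines of: SELECT 'O''Reilly' AS anon_1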
+
+class Text(String):
+
+ """A variably sized string type.
+
+ In SQL, usually corresponds to CLOB or TEXT. Can also take Python
+ unicode objects and encode to the database's encoding in bind
+ params (and the reverse for result sets.) In general, TEXT objects
+ do not have a length; while some databases will accept a length
+ argument here, it will be rejected by others.
+
+ """
+
+ __visit_name__ = "text"
+
+
+class Unicode(String):
+
+ """A variable length Unicode string type.
+
+ The :class:`.Unicode` type is a :class:`.String` subclass that assumes
+ input and output strings that may contain non-ASCII characters, and for
+ some backends implies an underlying column type that explicitly
+ supports non-ASCII data, such as ``NVARCHAR`` on Oracle and SQL
+ Server. This will impact the output of ``CREATE TABLE`` statements and
+ ``CAST`` functions at the dialect level, and also in some cases will
+ indicate different behavior in the DBAPI itself in how it handles bound
+ parameters.
+
+ The character encoding used by the :class:`.Unicode` type that is used to
+ transmit and receive data to the database is usually determined by the
+ DBAPI itself. All modern DBAPIs accommodate non-ASCII strings but may have
+ different methods of managing database encodings; if necessary, this
+ encoding should be configured as detailed in the notes for the target DBAPI
+ in the :ref:`dialect_toplevel` section.
+
+ In modern SQLAlchemy, use of the :class:`.Unicode` datatype does not
+ typically imply any encoding/decoding behavior within SQLAlchemy itself.
+ Historically, when DBAPIs did not support Python ``unicode`` objects under
+ Python 2, SQLAlchemy handled unicode encoding/decoding services itself
+ which would be controlled by the flag :paramref:`.String.convert_unicode`;
+ this flag is deprecated as it is no longer needed for Python 3.
+
+ When using Python 2, data that is passed to columns that use the
+ :class:`.Unicode` datatype must be of type ``unicode``, and not ``str``
+ which in Python 2 is equivalent to ``bytes``. In Python 3, all data
+ passed to columns that use the :class:`.Unicode` datatype should be
+ of type ``str``. See the flag :paramref:`.String.convert_unicode` for
+ more discussion of unicode encode/decode behavior under Python 2.
+
+ .. warning:: Some database backends, particularly SQL Server with pyodbc,
+ are known to have undesirable behaviors regarding data that is noted
+ as being of ``NVARCHAR`` type as opposed to ``VARCHAR``, including
+ datatype mismatch errors and non-use of indexes. See the section
+ on :meth:`.DialectEvents.do_setinputsizes` for background on working
+ around unicode character issues for backends like SQL Server with
+ pyodbc as well as cx_Oracle.
+
+ .. seealso::
+
+ :class:`.UnicodeText` - unlengthed textual counterpart
+ to :class:`.Unicode`.
+
+ :paramref:`.String.convert_unicode`
+
+ :meth:`.DialectEvents.do_setinputsizes`
+
+
+ """
+
+ __visit_name__ = "unicode"
+
+ def __init__(self, length=None, **kwargs):
+ """
+ Create a :class:`.Unicode` object.
+
+ Parameters are the same as that of :class:`.String`,
+ with the exception that ``convert_unicode``
+ defaults to ``True``.
+
+ """
+ kwargs.setdefault("_expect_unicode", True)
+ kwargs.setdefault("_warn_on_bytestring", True)
+ super(Unicode, self).__init__(length=length, **kwargs)
+
+
+class UnicodeText(Text):
+
+ """An unbounded-length Unicode string type.
+
+ See :class:`.Unicode` for details on the unicode
+ behavior of this object.
+
+ Like :class:`.Unicode`, usage of the :class:`.UnicodeText` type implies a
+ unicode-capable type being used on the backend, such as
+ ``NCLOB``, ``NTEXT``.
+
+ """
+
+ __visit_name__ = "unicode_text"
+
+ def __init__(self, length=None, **kwargs):
+ """
+ Create a Unicode-converting Text type.
+
+ Parameters are the same as that of :class:`.Text`,
+ with the exception that ``convert_unicode``
+ defaults to ``True``.
+
+ """
+ kwargs.setdefault("_expect_unicode", True)
+ kwargs.setdefault("_warn_on_bytestring", True)
+ super(UnicodeText, self).__init__(length=length, **kwargs)
+
+ def _warn_deprecated_unicode(self):
+ pass
+
+
+class Integer(_LookupExpressionAdapter, TypeEngine):
+
+ """A type for ``int`` integers."""
+
+ __visit_name__ = "integer"
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
+
+ @property
+ def python_type(self):
+ return int
+
+ def literal_processor(self, dialect):
+ def process(value):
+ return str(int(value))
+
+ return process
+
+ @util.memoized_property
+ def _expression_adaptations(self):
+ # TODO: need a dictionary object that will
+ # handle operators generically here, this is incomplete
+ return {
+ operators.add: {
+ Date: Date,
+ Integer: self.__class__,
+ Numeric: Numeric,
+ },
+ operators.mul: {
+ Interval: Interval,
+ Integer: self.__class__,
+ Numeric: Numeric,
+ },
+ operators.div: {Integer: self.__class__, Numeric: Numeric},
+ operators.truediv: {Integer: self.__class__, Numeric: Numeric},
+ operators.sub: {Integer: self.__class__, Numeric: Numeric},
+ }
+
+
+class SmallInteger(Integer):
+
+ """A type for smaller ``int`` integers.
+
+ Typically generates a ``SMALLINT`` in DDL, and otherwise acts like
+ a normal :class:`.Integer` on the Python side.
+
+ """
+
+ __visit_name__ = "small_integer"
+
+
+class BigInteger(Integer):
+
+ """A type for bigger ``int`` integers.
+
+ Typically generates a ``BIGINT`` in DDL, and otherwise acts like
+ a normal :class:`.Integer` on the Python side.
+
+ """
+
+ __visit_name__ = "big_integer"
+
+
+class Numeric(_LookupExpressionAdapter, TypeEngine):
+
+ """Base for non-integer numeric types, such as
+ ``NUMERIC``, ``FLOAT``, ``DECIMAL``, and other variants.
+
+ The :class:`.Numeric` datatype when used directly will render DDL
+ corresponding to precision numerics if available, such as
+ ``NUMERIC(precision, scale)``. The :class:`.Float` subclass will
+ attempt to render a floating-point datatype such as ``FLOAT(precision)``.
+
+ :class:`.Numeric` returns Python ``decimal.Decimal`` objects by default,
+ based on the default value of ``True`` for the
+ :paramref:`.Numeric.asdecimal` parameter. If this parameter is set to
+ False, returned values are coerced to Python ``float`` objects.
+
+ The :class:`.Float` subtype, being more specific to floating point,
+ defaults the :paramref:`.Float.asdecimal` flag to False so that the
+ default Python datatype is ``float``.
+
+ .. note::
+
+ When using a :class:`.Numeric` datatype against a database type that
+ returns Python floating point values to the driver, the accuracy of the
+ decimal conversion indicated by :paramref:`.Numeric.asdecimal` may be
+ limited. The behavior of specific numeric/floating point datatypes
+ is a product of the SQL datatype in use, the Python :term:`DBAPI`
+ in use, as well as strategies that may be present within
+ the SQLAlchemy dialect in use. Users requiring specific precision/
+ scale are encouraged to experiment with the available datatypes
+ in order to determine the best results.
+
+ """
+
+ __visit_name__ = "numeric"
+
+ _default_decimal_return_scale = 10
+
+ def __init__(
+ self,
+ precision=None,
+ scale=None,
+ decimal_return_scale=None,
+ asdecimal=True,
+ ):
+ """
+ Construct a Numeric.
+
+ :param precision: the numeric precision for use in DDL ``CREATE
+ TABLE``.
+
+ :param scale: the numeric scale for use in DDL ``CREATE TABLE``.
+
+ :param asdecimal: default True. Indicates whether values
+ should be returned as Python Decimal objects, or
+ as floats. Different DBAPIs send one or the other based on
+ datatypes - the Numeric type will ensure that return values
+ are one or the other across DBAPIs consistently.
+
+ :param decimal_return_scale: Default scale to use when converting
+ from floats to Python decimals. Floating point values will typically
+ be much longer due to decimal inaccuracy, and most floating point
+ database types don't have a notion of "scale", so by default the
+ float type looks for the first ten decimal places when converting.
+ Specifying this value will override that length. Types which
+ do include an explicit ".scale" value, such as the base
+ :class:`.Numeric` as well as the MySQL float types, will use the
+ value of ".scale" as the default for decimal_return_scale, if not
+ otherwise specified.
+
+ When using the ``Numeric`` type, care should be taken to ensure
+ that the asdecimal setting is appropriate for the DBAPI in use -
+ when Numeric applies a conversion from Decimal->float or float->
+ Decimal, this conversion incurs an additional performance overhead
+ for all result columns received.
+
+ DBAPIs that return Decimal natively (e.g. psycopg2) will have
+ better accuracy and higher performance with a setting of ``True``,
+ as the native translation to Decimal reduces the amount of floating-
+ point issues at play, and the Numeric type itself doesn't need
+ to apply any further conversions. However, another DBAPI which
+ returns floats natively *will* incur an additional conversion
+ overhead, and is still subject to floating point data loss - in
+ which case ``asdecimal=False`` will at least remove the extra
+ conversion overhead.
+
+ """
+ self.precision = precision
+ self.scale = scale
+ self.decimal_return_scale = decimal_return_scale
+ self.asdecimal = asdecimal
+
+ @property
+ def _effective_decimal_return_scale(self):
+ if self.decimal_return_scale is not None:
+ return self.decimal_return_scale
+ elif getattr(self, "scale", None) is not None:
+ return self.scale
+ else:
+ return self._default_decimal_return_scale
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
+
+ def literal_processor(self, dialect):
+ def process(value):
+ return str(value)
+
+ return process
+
+ @property
+ def python_type(self):
+ if self.asdecimal:
+ return decimal.Decimal
+ else:
+ return float
+
+ def bind_processor(self, dialect):
+ if dialect.supports_native_decimal:
+ return None
+ else:
+ return processors.to_float
+
+ def result_processor(self, dialect, coltype):
+ if self.asdecimal:
+ if dialect.supports_native_decimal:
+ # we're a "numeric", DBAPI will give us Decimal directly
+ return None
+ else:
+ util.warn(
+ "Dialect %s+%s does *not* support Decimal "
+ "objects natively, and SQLAlchemy must "
+ "convert from floating point - rounding "
+ "errors and other issues may occur. Please "
+ "consider storing Decimal numbers as strings "
+ "or integers on this platform for lossless "
+ "storage." % (dialect.name, dialect.driver)
+ )
+
+ # we're a "numeric", DBAPI returns floats, convert.
+ return processors.to_decimal_processor_factory(
+ decimal.Decimal,
+ self.scale
+ if self.scale is not None
+ else self._default_decimal_return_scale,
+ )
+ else:
+ if dialect.supports_native_decimal:
+ return processors.to_float
+ else:
+ return None
+
+ @util.memoized_property
+ def _expression_adaptations(self):
+ return {
+ operators.mul: {
+ Interval: Interval,
+ Numeric: self.__class__,
+ Integer: self.__class__,
+ },
+ operators.div: {Numeric: self.__class__, Integer: self.__class__},
+ operators.truediv: {
+ Numeric: self.__class__,
+ Integer: self.__class__,
+ },
+ operators.add: {Numeric: self.__class__, Integer: self.__class__},
+ operators.sub: {Numeric: self.__class__, Integer: self.__class__},
+ }
+
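+# Illustrative sketch (hypothetical table) of the asdecimal behavior
+# described above: with the default asdecimal=True, result values come
+# back as decimal.Decimal; with asdecimal=False they come back as float.
+#
+#     from sqlalchemy import Column, Integer, MetaData, Numeric, Table
+#
+#     metadata = MetaData()
+#     invoice = Table(
+#         "invoice",
+#         metadata,
+#         Column("id", Integer, primary_key=True),
+#         Column("total", Numeric(10, 2)),                      # -> Decimal('19.99')
+#         Column("total_f", Numeric(10, 2, asdecimal=False)),   # -> 19.99
+#     )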
+
+class Float(Numeric):
+
+ """Type representing floating point types, such as ``FLOAT`` or ``REAL``.
+
+ This type returns Python ``float`` objects by default, unless the
+ :paramref:`.Float.asdecimal` flag is set to True, in which case they
+ are coerced to ``decimal.Decimal`` objects.
+
+
+ """
+
+ __visit_name__ = "float"
+
+ scale = None
+
+ def __init__(
+ self, precision=None, asdecimal=False, decimal_return_scale=None
+ ):
+ r"""
+ Construct a Float.
+
+ :param precision: the numeric precision for use in DDL ``CREATE
+ TABLE``.
+
+ :param asdecimal: the same flag as that of :class:`.Numeric`, but
+ defaults to ``False``. Note that setting this flag to ``True``
+ results in floating point conversion.
+
+ :param decimal_return_scale: Default scale to use when converting
+ from floats to Python decimals. Floating point values will typically
+ be much longer due to decimal inaccuracy, and most floating point
+ database types don't have a notion of "scale", so by default the
+ float type looks for the first ten decimal places when converting.
+ Specifying this value will override that length. Note that the
+ MySQL float types, which do include "scale", will use "scale"
+ as the default for decimal_return_scale, if not otherwise specified.
+
+ .. versionadded:: 0.9.0
+
+ """
+ self.precision = precision
+ self.asdecimal = asdecimal
+ self.decimal_return_scale = decimal_return_scale
+
+ def result_processor(self, dialect, coltype):
+ if self.asdecimal:
+ return processors.to_decimal_processor_factory(
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
+ elif dialect.supports_native_decimal:
+ return processors.to_float
+ else:
+ return None
+
+
+class DateTime(_LookupExpressionAdapter, TypeEngine):
+
+ """A type for ``datetime.datetime()`` objects.
+
+ Date and time types return objects from the Python ``datetime``
+ module. Most DBAPIs have built in support for the datetime
+ module, with the noted exception of SQLite. In the case of
+ SQLite, date and time types are stored as strings which are then
+ converted back to datetime objects when rows are returned.
+
+ For the time representation within the datetime type, some
+ backends include additional options, such as timezone support and
+ fractional seconds support. For fractional seconds, use the
+ dialect-specific datatype, such as :class:`.mysql.TIME`. For
+ timezone support, use at least the :class:`_types.TIMESTAMP` datatype,
+ if not the dialect-specific datatype object.
+
+ """
+
+ __visit_name__ = "datetime"
+
+ def __init__(self, timezone=False):
+ """Construct a new :class:`.DateTime`.
+
+ :param timezone: boolean. Indicates that the datetime type should
+ enable timezone support, if available on the
+ **base date/time-holding type only**. It is recommended
+ to make use of the :class:`_types.TIMESTAMP` datatype directly when
+ using this flag, as some databases include separate generic
+ date/time-holding types distinct from the timezone-capable
+ TIMESTAMP datatype, such as Oracle.
+
+
+ """
+ self.timezone = timezone
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.DATETIME
+
+ def _resolve_for_literal(self, value):
+ with_timezone = value.tzinfo is not None
+ if with_timezone and not self.timezone:
+ return DATETIME_TIMEZONE
+ else:
+ return self
+
+ @property
+ def python_type(self):
+ return dt.datetime
+
+ @util.memoized_property
+ def _expression_adaptations(self):
+
+ # Based on
+ # https://www.postgresql.org/docs/current/static/functions-datetime.html.
+
+ return {
+ operators.add: {Interval: self.__class__},
+ operators.sub: {Interval: self.__class__, DateTime: Interval},
+ }
+
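+# Illustrative sketch (hypothetical table): timezone=True requests a
+# timezone-aware server-side type where the backend provides one (e.g.
+# TIMESTAMP WITH TIME ZONE on PostgreSQL); where no such variant exists,
+# the flag generally has no DDL effect.
+#
+#     from sqlalchemy import Column, DateTime, Integer, MetaData, Table
+#
+#     metadata = MetaData()
+#     event_log = Table(
+#         "event_log",
+#         metadata,
+#         Column("id", Integer, primary_key=True),
+#         Column("created_at", DateTime(timezone=True)),
+#     )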
+
+class Date(_LookupExpressionAdapter, TypeEngine):
+
+ """A type for ``datetime.date()`` objects."""
+
+ __visit_name__ = "date"
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.DATETIME
+
+ @property
+ def python_type(self):
+ return dt.date
+
+ @util.memoized_property
+ def _expression_adaptations(self):
+ # Based on
+ # https://www.postgresql.org/docs/current/static/functions-datetime.html.
+
+ return {
+ operators.add: {
+ Integer: self.__class__,
+ Interval: DateTime,
+ Time: DateTime,
+ },
+ operators.sub: {
+ # date - integer = date
+ Integer: self.__class__,
+ # date - date = integer.
+ Date: Integer,
+ Interval: DateTime,
+ # date - datetime = interval,
+ # this one is not in the PG docs
+ # but works
+ DateTime: Interval,
+ },
+ }
+
+
+class Time(_LookupExpressionAdapter, TypeEngine):
+
+ """A type for ``datetime.time()`` objects."""
+
+ __visit_name__ = "time"
+
+ def __init__(self, timezone=False):
+ self.timezone = timezone
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.DATETIME
+
+ @property
+ def python_type(self):
+ return dt.time
+
+ def _resolve_for_literal(self, value):
+ with_timezone = value.tzinfo is not None
+ if with_timezone and not self.timezone:
+ return TIME_TIMEZONE
+ else:
+ return self
+
+ @util.memoized_property
+ def _expression_adaptations(self):
+ # Based on
+ # https://www.postgresql.org/docs/current/static/functions-datetime.html.
+
+ return {
+ operators.add: {Date: DateTime, Interval: self.__class__},
+ operators.sub: {Time: Interval, Interval: self.__class__},
+ }
+
+
+class _Binary(TypeEngine):
+
+ """Define base behavior for binary types."""
+
+ def __init__(self, length=None):
+ self.length = length
+
+ def literal_processor(self, dialect):
+ def process(value):
+ value = value.decode(dialect.encoding).replace("'", "''")
+ return "'%s'" % value
+
+ return process
+
+ @property
+ def python_type(self):
+ return util.binary_type
+
+ # Python 3 - sqlite3 doesn't need the `Binary` conversion
+ # here, though pg8000 does to indicate "bytea"
+ def bind_processor(self, dialect):
+ if dialect.dbapi is None:
+ return None
+
+ DBAPIBinary = dialect.dbapi.Binary
+
+ def process(value):
+ if value is not None:
+ return DBAPIBinary(value)
+ else:
+ return None
+
+ return process
+
+ # Python 3 has native bytes() type
+ # both sqlite3 and pg8000 seem to return it,
+ # psycopg2 as of 2.5 returns 'memoryview'
+ if util.py2k:
+
+ def result_processor(self, dialect, coltype):
+ return processors.to_str
+
+ else:
+
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ if value is not None:
+ value = bytes(value)
+ return value
+
+ return process
+
+ def coerce_compared_value(self, op, value):
+ """See :meth:`.TypeEngine.coerce_compared_value` for a description."""
+
+ if isinstance(value, util.string_types):
+ return self
+ else:
+ return super(_Binary, self).coerce_compared_value(op, value)
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.BINARY
+
+
+class LargeBinary(_Binary):
+
+ """A type for large binary byte data.
+
+ The :class:`.LargeBinary` type corresponds to a large and/or unlengthed
+ binary type for the target platform, such as BLOB on MySQL and BYTEA for
+ PostgreSQL. It also handles the necessary conversions for the DBAPI.
+
+ """
+
+ __visit_name__ = "large_binary"
+
+ def __init__(self, length=None):
+ """
+ Construct a LargeBinary type.
+
+ :param length: optional, a length for the column for use in
+ DDL statements, for those binary types that accept a length,
+ such as the MySQL BLOB type.
+
+ """
+ _Binary.__init__(self, length=length)
+
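+# Illustrative sketch (hypothetical table): LargeBinary accepts Python
+# bytes on the bind side and returns bytes in results, with the DBAPI's
+# Binary() wrapper applied where the driver requires it (see
+# _Binary.bind_processor above).
+#
+#     from sqlalchemy import Column, Integer, LargeBinary, MetaData, Table
+#
+#     metadata = MetaData()
+#     attachment = Table(
+#         "attachment",
+#         metadata,
+#         Column("id", Integer, primary_key=True),
+#         Column("payload", LargeBinary),
+#     )
+#     # connection.execute(attachment.insert(), {"id": 1, "payload": b"\x00\x01"})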
+
+class SchemaType(SchemaEventTarget):
+
+ """Mark a type as possibly requiring schema-level DDL for usage.
+
+ Supports types that must be explicitly created/dropped (i.e. PG ENUM type)
+ as well as types that are complemented by table or schema level
+ constraints, triggers, and other rules.
+
+ :class:`.SchemaType` classes can also be targets for the
+ :meth:`.DDLEvents.before_parent_attach` and
+ :meth:`.DDLEvents.after_parent_attach` events, where the events fire off
+ surrounding the association of the type object with a parent
+ :class:`_schema.Column`.
+
+ .. seealso::
+
+ :class:`.Enum`
+
+ :class:`.Boolean`
+
+
+ """
+
+ _use_schema_map = True
+
+ def __init__(
+ self,
+ name=None,
+ schema=None,
+ metadata=None,
+ inherit_schema=False,
+ quote=None,
+ _create_events=True,
+ ):
+ if name is not None:
+ self.name = quoted_name(name, quote)
+ else:
+ self.name = None
+ self.schema = schema
+ self.metadata = metadata
+ self.inherit_schema = inherit_schema
+ self._create_events = _create_events
+
+ if _create_events and self.metadata:
+ event.listen(
+ self.metadata,
+ "before_create",
+ util.portable_instancemethod(self._on_metadata_create),
+ )
+ event.listen(
+ self.metadata,
+ "after_drop",
+ util.portable_instancemethod(self._on_metadata_drop),
+ )
+
+ def _set_parent(self, column, **kw):
+ column._on_table_attach(util.portable_instancemethod(self._set_table))
+
+ def _variant_mapping_for_set_table(self, column):
+ if isinstance(column.type, Variant):
+ variant_mapping = column.type.mapping.copy()
+ variant_mapping["_default"] = column.type.impl
+ else:
+ variant_mapping = None
+ return variant_mapping
+
+ def _set_table(self, column, table):
+ if self.inherit_schema:
+ self.schema = table.schema
+ elif self.metadata and self.schema is None and self.metadata.schema:
+ self.schema = self.metadata.schema
+
+ if not self._create_events:
+ return
+
+ variant_mapping = self._variant_mapping_for_set_table(column)
+
+ event.listen(
+ table,
+ "before_create",
+ util.portable_instancemethod(
+ self._on_table_create, {"variant_mapping": variant_mapping}
+ ),
+ )
+ event.listen(
+ table,
+ "after_drop",
+ util.portable_instancemethod(
+ self._on_table_drop, {"variant_mapping": variant_mapping}
+ ),
+ )
+ if self.metadata is None:
+ # TODO: what's the difference between self.metadata
+ # and table.metadata here ?
+ event.listen(
+ table.metadata,
+ "before_create",
+ util.portable_instancemethod(
+ self._on_metadata_create,
+ {"variant_mapping": variant_mapping},
+ ),
+ )
+ event.listen(
+ table.metadata,
+ "after_drop",
+ util.portable_instancemethod(
+ self._on_metadata_drop,
+ {"variant_mapping": variant_mapping},
+ ),
+ )
+
+ def copy(self, **kw):
+ return self.adapt(self.__class__, _create_events=True)
+
+ def adapt(self, impltype, **kw):
+ schema = kw.pop("schema", self.schema)
+ metadata = kw.pop("metadata", self.metadata)
+ _create_events = kw.pop("_create_events", False)
+ return impltype(
+ name=self.name,
+ schema=schema,
+ inherit_schema=self.inherit_schema,
+ metadata=metadata,
+ _create_events=_create_events,
+ **kw
+ )
+
+ @property
+ def bind(self):
+ return self.metadata and self.metadata.bind or None
+
+ def create(self, bind=None, checkfirst=False):
+ """Issue CREATE DDL for this type, if applicable."""
+
+ if bind is None:
+ bind = _bind_or_error(self)
+ t = self.dialect_impl(bind.dialect)
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ t.create(bind=bind, checkfirst=checkfirst)
+
+ def drop(self, bind=None, checkfirst=False):
+ """Issue DROP DDL for this type, if applicable."""
+
+ if bind is None:
+ bind = _bind_or_error(self)
+ t = self.dialect_impl(bind.dialect)
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ t.drop(bind=bind, checkfirst=checkfirst)
+
+ def _on_table_create(self, target, bind, **kw):
+ if not self._is_impl_for_variant(bind.dialect, kw):
+ return
+
+ t = self.dialect_impl(bind.dialect)
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ t._on_table_create(target, bind, **kw)
+
+ def _on_table_drop(self, target, bind, **kw):
+ if not self._is_impl_for_variant(bind.dialect, kw):
+ return
+
+ t = self.dialect_impl(bind.dialect)
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ t._on_table_drop(target, bind, **kw)
+
+ def _on_metadata_create(self, target, bind, **kw):
+ if not self._is_impl_for_variant(bind.dialect, kw):
+ return
+
+ t = self.dialect_impl(bind.dialect)
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ t._on_metadata_create(target, bind, **kw)
+
+ def _on_metadata_drop(self, target, bind, **kw):
+ if not self._is_impl_for_variant(bind.dialect, kw):
+ return
+
+ t = self.dialect_impl(bind.dialect)
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ t._on_metadata_drop(target, bind, **kw)
+
+ def _is_impl_for_variant(self, dialect, kw):
+ variant_mapping = kw.pop("variant_mapping", None)
+ if variant_mapping is None:
+ return True
+
+ # since PostgreSQL is the only DB that has ARRAY this can only
+ # be integration tested by PG-specific tests
+ def _we_are_the_impl(typ):
+ return (
+ typ is self or isinstance(typ, ARRAY) and typ.item_type is self
+ )
+
+ if dialect.name in variant_mapping and _we_are_the_impl(
+ variant_mapping[dialect.name]
+ ):
+ return True
+ elif dialect.name not in variant_mapping:
+ return _we_are_the_impl(variant_mapping["_default"])
+
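+# Illustrative sketch (hypothetical names; assumes a PostgreSQL Engine is
+# available as ``engine``) of the schema-level behavior described above:
+# an Enum bound to a MetaData is created and dropped along with
+# metadata.create_all() / drop_all(), and can also be emitted explicitly
+# via its create() / drop() methods.
+#
+#     from sqlalchemy import Enum, MetaData
+#
+#     metadata = MetaData()
+#     status_type = Enum("new", "open", "closed", name="status", metadata=metadata)
+#
+#     # metadata.create_all(engine)   # emits CREATE TYPE status AS ENUM (...)
+#     # status_type.drop(engine)      # emits DROP TYPE status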
+
+class Enum(Emulated, String, SchemaType):
+ """Generic Enum Type.
+
+ The :class:`.Enum` type provides a set of possible string values
+ which the column is constrained towards.
+
+ The :class:`.Enum` type will make use of the backend's native "ENUM"
+ type if one is available; otherwise, it uses a VARCHAR datatype.
+ An option also exists to automatically produce a CHECK constraint
+ when the VARCHAR (so called "non-native") variant is produced;
+ see the :paramref:`.Enum.create_constraint` flag.
+
+ The :class:`.Enum` type also provides in-Python validation of string
+ values during both read and write operations. When reading a value
+ from the database in a result set, the string value is always checked
+ against the list of possible values and a ``LookupError`` is raised
+ if no match is found. When passing a value to the database as a
+ plain string within a SQL statement, if the
+ :paramref:`.Enum.validate_strings` parameter is
+ set to True, a ``LookupError`` is raised for any string value that's
+ not located in the given list of possible values; note that this
+ impacts usage of LIKE expressions with enumerated values (an unusual
+ use case).
+
+ .. versionchanged:: 1.1 the :class:`.Enum` type now provides in-Python
+ validation of input values as well as on data being returned by
+ the database.
+
+ The source of enumerated values may be a list of string values, or
+ alternatively a PEP-435-compliant enumerated class. For the purposes
+ of the :class:`.Enum` datatype, this class need only provide a
+ ``__members__`` attribute.
+
+ When using an enumerated class, the enumerated objects are used
+ both for input and output, rather than strings as is the case with
+ a plain-string enumerated type::
+
+ import enum
+ from sqlalchemy import Enum
+
+ class MyEnum(enum.Enum):
+ one = 1
+ two = 2
+ three = 3
+
+ t = Table(
+ 'data', MetaData(),
+ Column('value', Enum(MyEnum))
+ )
+
+ connection.execute(t.insert(), {"value": MyEnum.two})
+ assert connection.scalar(t.select()) is MyEnum.two
+
+ Above, the string names of each element, e.g. "one", "two", "three",
+ are persisted to the database; the values of the Python Enum, here
+ indicated as integers, are **not** used; the value of each enum can
+ therefore be any kind of Python object whether or not it is persistable.
+
+ In order to persist the values and not the names, the
+ :paramref:`.Enum.values_callable` parameter may be used. The value of
+ this parameter is a user-supplied callable, which is intended to be used
+ with a PEP-435-compliant enumerated class and returns a list of string
+ values to be persisted. For a simple enumeration that uses string values,
+ a callable such as ``lambda x: [e.value for e in x]`` is sufficient.
+
+ .. versionadded:: 1.1 - support for PEP-435-style enumerated
+ classes.
+
+
+ .. seealso::
+
+ :class:`_postgresql.ENUM` - PostgreSQL-specific type,
+ which has additional functionality.
+
+ :class:`.mysql.ENUM` - MySQL-specific type
+
+ """
+
+ __visit_name__ = "enum"
+
+ @util.deprecated_params(
+ convert_unicode=(
+ "1.3",
+ "The :paramref:`.Enum.convert_unicode` parameter is deprecated "
+ "and will be removed in a future release. All modern DBAPIs "
+ "now support Python Unicode directly and this parameter is "
+ "unnecessary.",
+ )
+ )
+ def __init__(self, *enums, **kw):
+ r"""Construct an enum.
+
+ Keyword arguments which don't apply to a specific backend are ignored
+ by that backend.
+
+ :param \*enums: either exactly one PEP-435 compliant enumerated type
+ or one or more string labels.
+
+ .. versionadded:: 1.1 a PEP-435 style enumerated class may be
+ passed.
+
+ :param convert_unicode: Enable unicode-aware bind parameter and
+ result-set processing for this Enum's data under Python 2 only.
+ Under Python 2, this is set automatically based on the presence of
+ unicode label strings. This flag will be removed in SQLAlchemy 2.0.
+
+ :param create_constraint: defaults to False. When creating a
+ non-native enumerated type, also build a CHECK constraint on the
+ database against the valid values.
+
+ .. note:: it is strongly recommended that the CHECK constraint
+ have an explicit name in order to support schema-management
+ concerns. This can be established either by setting the
+ :paramref:`.Enum.name` parameter or by setting up an
+ appropriate naming convention; see
+ :ref:`constraint_naming_conventions` for background.
+
+ .. versionchanged:: 1.4 - this flag now defaults to False, meaning
+ no CHECK constraint is generated for a non-native enumerated
+ type.
+
+ :param metadata: Associate this type directly with a ``MetaData``
+ object. For types that exist on the target database as an
+ independent schema construct (PostgreSQL), this type will be
+ created and dropped within ``create_all()`` and ``drop_all()``
+ operations. If the type is not associated with any ``MetaData``
+ object, it will associate itself with each ``Table`` in which it is
+ used, and will be created when any of those individual tables are
+ created, after a check is performed for its existence. The type is
+ only dropped when ``drop_all()`` is called for that ``Table``
+ object's metadata, however.
+
+ The value of the :paramref:`_schema.MetaData.schema` parameter of
+ the :class:`_schema.MetaData` object, if set, will be used as the
+ default value of the :paramref:`_types.Enum.schema` on this object
+ if an explicit value is not otherwise supplied.
+
+ .. versionchanged:: 1.4.12 :class:`_types.Enum` inherits the
+ :paramref:`_schema.MetaData.schema` parameter of the
+ :class:`_schema.MetaData` object if present, when passed using
+ the :paramref:`_types.Enum.metadata` parameter.
+
+ :param name: The name of this type. This is required for PostgreSQL
+ and any future supported database which requires an explicitly
+ named type, or an explicitly named constraint in order to generate
+ the type and/or a table that uses it. If a PEP-435 enumerated
+ class was used, its name (converted to lower case) is used by
+ default.
+
+ :param native_enum: Use the database's native ENUM type when
+ available. Defaults to True. When False, uses VARCHAR + check
+ constraint for all backends. When False, the VARCHAR length can be
+ controlled with :paramref:`.Enum.length`; currently "length" is
+ ignored if native_enum=True.
+
+ :param length: Allows specifying a custom length for the VARCHAR
+ when :paramref:`.Enum.native_enum` is False. By default it uses the
+ length of the longest value.
+
+ .. versionadded:: 1.3.16
+
+ :param schema: Schema name of this type. For types that exist on the
+ target database as an independent schema construct (PostgreSQL),
+ this parameter specifies the named schema in which the type is
+ present.
+
+ If not present, the schema name will be taken from the
+ :class:`_schema.MetaData` collection if passed as
+ :paramref:`_types.Enum.metadata`, for a :class:`_schema.MetaData`
+ that includes the :paramref:`_schema.MetaData.schema` parameter.
+
+ .. versionchanged:: 1.4.12 :class:`_types.Enum` inherits the
+ :paramref:`_schema.MetaData.schema` parameter of the
+ :class:`_schema.MetaData` object if present, when passed using
+ the :paramref:`_types.Enum.metadata` parameter.
+
+ Otherwise, if the :paramref:`_types.Enum.inherit_schema` flag is set
+ to ``True``, the schema will be inherited from the associated
+ :class:`_schema.Table` object if any; when
+ :paramref:`_types.Enum.inherit_schema` is at its default of
+ ``False``, the owning table's schema is **not** used.
+
+
+ :param quote: Set explicit quoting preferences for the type's name.
+
+ :param inherit_schema: When ``True``, the "schema" from the owning
+ :class:`_schema.Table`
+ will be copied to the "schema" attribute of this
+ :class:`.Enum`, replacing whatever value was passed for the
+ ``schema`` attribute. This also takes effect when using the
+ :meth:`_schema.Table.to_metadata` operation.
+
+ :param validate_strings: when True, string values that are being
+ passed to the database in a SQL statement will be checked
+ for validity against the list of enumerated values. Unrecognized
+ values will result in a ``LookupError`` being raised.
+
+ .. versionadded:: 1.1.0b2
+
+ :param values_callable: A callable which will be passed the PEP-435
+ compliant enumerated type, which should then return a list of string
+ values to be persisted. This allows for alternate usages such as
+ using the string value of an enum to be persisted to the database
+ instead of its name.
+
+ .. versionadded:: 1.2.3
+
+ :param sort_key_function: a Python callable which may be used as the
+ "key" argument in the Python ``sorted()`` built-in. The SQLAlchemy
+ ORM requires that primary key columns which are mapped must
+ be sortable in some way. When using an unsortable enumeration
+ object such as a Python 3 ``Enum`` object, this parameter may be
+ used to set a default sort key function for the objects. By
+ default, the database value of the enumeration is used as the
+ sorting function.
+
+ .. versionadded:: 1.3.8
+
+ :param omit_aliases: A boolean that when true will remove aliases from
+ pep 435 enums. For backward compatibility it defaults to ``False``.
+ A deprecation warning is raised if the enum has aliases and this
+ flag was not set.
+
+ .. versionadded:: 1.4.5
+
+ .. deprecated:: 1.4 The default will be changed to ``True`` in
+ SQLAlchemy 2.0.
+
+ """
+ self._enum_init(enums, kw)
+
+ @property
+ def _enums_argument(self):
+ if self.enum_class is not None:
+ return [self.enum_class]
+ else:
+ return self.enums
+
+ def _enum_init(self, enums, kw):
+ """internal init for :class:`.Enum` and subclasses.
+
+ friendly init helper used by subclasses to remove
+ all the Enum-specific keyword arguments from kw. Allows all
+ other arguments in kw to pass through.
+
+ """
+ self.native_enum = kw.pop("native_enum", True)
+ self.create_constraint = kw.pop("create_constraint", False)
+ self.values_callable = kw.pop("values_callable", None)
+ self._sort_key_function = kw.pop("sort_key_function", NO_ARG)
+ length_arg = kw.pop("length", NO_ARG)
+ self._omit_aliases = kw.pop("omit_aliases", NO_ARG)
+ _disable_warnings = kw.pop("_disable_warnings", False)
+ values, objects = self._parse_into_values(enums, kw)
+ self._setup_for_values(values, objects, kw)
+
+ convert_unicode = kw.pop("convert_unicode", None)
+ self.validate_strings = kw.pop("validate_strings", False)
+
+ if convert_unicode is None:
+ for e in self.enums:
+ # this is all py2k logic that can go away for py3k only,
+ # "expect unicode" will always be implicitly true
+ if isinstance(e, util.text_type):
+ _expect_unicode = True
+ break
+ else:
+ _expect_unicode = False
+ else:
+ _expect_unicode = convert_unicode
+
+ if self.enums:
+ self._default_length = length = max(len(x) for x in self.enums)
+ else:
+ self._default_length = length = 0
+
+ if length_arg is not NO_ARG:
+ if self.native_enum:
+ if not _disable_warnings:
+ util.warn(
+ "Enum 'length' argument is currently ignored unless "
+ "native_enum is specified as False, including for DDL "
+ "that renders VARCHAR in any case. This may change "
+ "in a future release."
+ )
+ else:
+ if not _disable_warnings and length_arg < length:
+ raise ValueError(
+ "When provided, length must be larger or equal"
+ " than the length of the longest enum value. %s < %s"
+ % (length_arg, length)
+ )
+ length = length_arg
+
+ self._valid_lookup[None] = self._object_lookup[None] = None
+
+ super(Enum, self).__init__(
+ length=length, _expect_unicode=_expect_unicode
+ )
+
+ if self.enum_class:
+ kw.setdefault("name", self.enum_class.__name__.lower())
+ SchemaType.__init__(
+ self,
+ name=kw.pop("name", None),
+ schema=kw.pop("schema", None),
+ metadata=kw.pop("metadata", None),
+ inherit_schema=kw.pop("inherit_schema", False),
+ quote=kw.pop("quote", None),
+ _create_events=kw.pop("_create_events", True),
+ )
+
+ def _parse_into_values(self, enums, kw):
+ if not enums and "_enums" in kw:
+ enums = kw.pop("_enums")
+
+ if len(enums) == 1 and hasattr(enums[0], "__members__"):
+ self.enum_class = enums[0]
+
+ _members = self.enum_class.__members__
+
+ aliases = [n for n, v in _members.items() if v.name != n]
+ if self._omit_aliases is NO_ARG and aliases:
+ util.warn_deprecated_20(
+ "The provided enum %s contains the aliases %s. The "
+ "``omit_aliases`` will default to ``True`` in SQLAlchemy "
+ "2.0. Specify a value to silence this warning."
+ % (self.enum_class.__name__, aliases)
+ )
+ if self._omit_aliases is True:
+ # remove aliases
+ members = OrderedDict(
+ (n, v) for n, v in _members.items() if v.name == n
+ )
+ else:
+ members = _members
+ if self.values_callable:
+ values = self.values_callable(self.enum_class)
+ else:
+ values = list(members)
+ objects = [members[k] for k in members]
+ return values, objects
+ else:
+ self.enum_class = None
+ return enums, enums
+
+ def _setup_for_values(self, values, objects, kw):
+ self.enums = list(values)
+
+ self._valid_lookup = dict(zip(reversed(objects), reversed(values)))
+
+ self._object_lookup = dict(zip(values, objects))
+
+ self._valid_lookup.update(
+ [
+ (value, self._valid_lookup[self._object_lookup[value]])
+ for value in values
+ ]
+ )
+
+ @property
+ def sort_key_function(self):
+ if self._sort_key_function is NO_ARG:
+ return self._db_value_for_elem
+ else:
+ return self._sort_key_function
+
+ @property
+ def native(self):
+ return self.native_enum
+
+ def _db_value_for_elem(self, elem):
+ try:
+ return self._valid_lookup[elem]
+ except KeyError as err:
+ # for unknown string values, we return as is. While we can
+ # validate these if we wanted, that does not allow for lesser-used
+ # end-user use cases, such as using a LIKE comparison with an enum,
+ # or for an application that wishes to apply string tests to an
+ # ENUM (see [ticket:3725]). While we can decide to differentiate
+ # here between an INSERT statement and a criteria used in a SELECT,
+ # for now we're staying conservative w/ behavioral changes (perhaps
+ # someone has a trigger that handles strings on INSERT)
+ if not self.validate_strings and isinstance(
+ elem, compat.string_types
+ ):
+ return elem
+ else:
+ util.raise_(
+ LookupError(
+ "'%s' is not among the defined enum values. "
+ "Enum name: %s. Possible values: %s"
+ % (
+ elem,
+ self.name,
+ langhelpers.repr_tuple_names(self.enums),
+ )
+ ),
+ replace_context=err,
+ )
+
+ class Comparator(String.Comparator):
+ def _adapt_expression(self, op, other_comparator):
+ op, typ = super(Enum.Comparator, self)._adapt_expression(
+ op, other_comparator
+ )
+ if op is operators.concat_op:
+ typ = String(
+ self.type.length, _expect_unicode=self.type._expect_unicode
+ )
+ return op, typ
+
+ comparator_factory = Comparator
+
+ def _object_value_for_elem(self, elem):
+ try:
+ return self._object_lookup[elem]
+ except KeyError as err:
+ util.raise_(
+ LookupError(
+ "'%s' is not among the defined enum values. "
+ "Enum name: %s. Possible values: %s"
+ % (
+ elem,
+ self.name,
+ langhelpers.repr_tuple_names(self.enums),
+ )
+ ),
+ replace_context=err,
+ )
+
+ def __repr__(self):
+ return util.generic_repr(
+ self,
+ additional_kw=[
+ ("native_enum", True),
+ ("create_constraint", False),
+ ("length", self._default_length),
+ ],
+ to_inspect=[Enum, SchemaType],
+ )
+
+ def as_generic(self, allow_nulltype=False):
+ if hasattr(self, "enums"):
+ args = self.enums
+ else:
+ raise NotImplementedError(
+ "TypeEngine.as_generic() heuristic "
+ "is undefined for types that inherit Enum but do not have "
+ "an `enums` attribute."
+ )
+
+ return util.constructor_copy(
+ self, self._generic_type_affinity, *args, _disable_warnings=True
+ )
+
+ def adapt_to_emulated(self, impltype, **kw):
+ kw.setdefault("_expect_unicode", self._expect_unicode)
+ kw.setdefault("validate_strings", self.validate_strings)
+ kw.setdefault("name", self.name)
+ kw["_disable_warnings"] = True
+ kw.setdefault("schema", self.schema)
+ kw.setdefault("inherit_schema", self.inherit_schema)
+ kw.setdefault("metadata", self.metadata)
+ kw.setdefault("_create_events", False)
+ kw.setdefault("native_enum", self.native_enum)
+ kw.setdefault("values_callable", self.values_callable)
+ kw.setdefault("create_constraint", self.create_constraint)
+ kw.setdefault("length", self.length)
+ kw.setdefault("omit_aliases", self._omit_aliases)
+ assert "_enums" in kw
+ return impltype(**kw)
+
+ def adapt(self, impltype, **kw):
+ kw["_enums"] = self._enums_argument
+ kw["_disable_warnings"] = True
+ return super(Enum, self).adapt(impltype, **kw)
+
+ def _should_create_constraint(self, compiler, **kw):
+ if not self._is_impl_for_variant(compiler.dialect, kw):
+ return False
+ return (
+ not self.native_enum or not compiler.dialect.supports_native_enum
+ )
+
+ @util.preload_module("sqlalchemy.sql.schema")
+ def _set_table(self, column, table):
+ schema = util.preloaded.sql_schema
+ SchemaType._set_table(self, column, table)
+
+ if not self.create_constraint:
+ return
+
+ variant_mapping = self._variant_mapping_for_set_table(column)
+
+ e = schema.CheckConstraint(
+ type_coerce(column, String()).in_(self.enums),
+ name=_NONE_NAME if self.name is None else self.name,
+ _create_rule=util.portable_instancemethod(
+ self._should_create_constraint,
+ {"variant_mapping": variant_mapping},
+ ),
+ _type_bound=True,
+ )
+ assert e.table is table
+
+ def literal_processor(self, dialect):
+ parent_processor = super(Enum, self).literal_processor(dialect)
+
+ def process(value):
+ value = self._db_value_for_elem(value)
+ if parent_processor:
+ value = parent_processor(value)
+ return value
+
+ return process
+
+ def bind_processor(self, dialect):
+ parent_processor = super(Enum, self).bind_processor(dialect)
+
+ def process(value):
+ value = self._db_value_for_elem(value)
+ if parent_processor:
+ value = parent_processor(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ parent_processor = super(Enum, self).result_processor(dialect, coltype)
+
+ def process(value):
+ if parent_processor:
+ value = parent_processor(value)
+
+ value = self._object_value_for_elem(value)
+ return value
+
+ return process
+
+ def copy(self, **kw):
+ return SchemaType.copy(self, **kw)
+
+ @property
+ def python_type(self):
+ if self.enum_class:
+ return self.enum_class
+ else:
+ return super(Enum, self).python_type
+
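+# Illustrative sketch (hypothetical enum and table) of the values_callable
+# parameter documented above: persist the .value of each member instead of
+# its name.
+#
+#     import enum
+#     from sqlalchemy import Column, Enum, Integer, MetaData, Table
+#
+#     class Status(enum.Enum):
+#         NEW = "new"
+#         CLOSED = "closed"
+#
+#     metadata = MetaData()
+#     ticket = Table(
+#         "ticket",
+#         metadata,
+#         Column("id", Integer, primary_key=True),
+#         Column(
+#             "status",
+#             Enum(Status, values_callable=lambda x: [e.value for e in x]),
+#         ),
+#     )
+#     # the database stores 'new' / 'closed' rather than 'NEW' / 'CLOSED',
+#     # while Status members are still used on the Python side.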
+
+class PickleType(TypeDecorator):
+ """Holds Python objects, which are serialized using pickle.
+
+ PickleType builds upon the Binary type to apply Python's
+ ``pickle.dumps()`` to incoming objects, and ``pickle.loads()`` on
+ the way out, allowing any pickleable Python object to be stored as
+ a serialized binary field.
+
+ To allow ORM change events to propagate for elements associated
+ with :class:`.PickleType`, see :ref:`mutable_toplevel`.
+
+ """
+
+ impl = LargeBinary
+ cache_ok = True
+
+ def __init__(
+ self,
+ protocol=pickle.HIGHEST_PROTOCOL,
+ pickler=None,
+ comparator=None,
+ impl=None,
+ ):
+ """
+ Construct a PickleType.
+
+ :param protocol: defaults to ``pickle.HIGHEST_PROTOCOL``.
+
+ :param pickler: defaults to the ``cPickle`` module, or the ``pickle``
+ module if ``cPickle`` is not available. May be any object with
+ pickle-compatible ``dumps`` and ``loads`` methods.
+
+ :param comparator: a 2-arg callable predicate used
+ to compare values of this type. If left as ``None``,
+ the Python "equals" operator is used to compare values.
+
+ :param impl: A binary-storing :class:`_types.TypeEngine` class or
+ instance to use in place of the default :class:`_types.LargeBinary`.
+ For example the :class:`_mysql.LONGBLOB` class may be more effective
+ when using MySQL.
+
+ .. versionadded:: 1.4.20
+
+ """
+ self.protocol = protocol
+ self.pickler = pickler or pickle
+ self.comparator = comparator
+ super(PickleType, self).__init__()
+
+ if impl:
+ self.impl = to_instance(impl)
+
+ def __reduce__(self):
+ return PickleType, (self.protocol, None, self.comparator)
+
+ def bind_processor(self, dialect):
+ impl_processor = self.impl.bind_processor(dialect)
+ dumps = self.pickler.dumps
+ protocol = self.protocol
+ if impl_processor:
+
+ def process(value):
+ if value is not None:
+ value = dumps(value, protocol)
+ return impl_processor(value)
+
+ else:
+
+ def process(value):
+ if value is not None:
+ value = dumps(value, protocol)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ impl_processor = self.impl.result_processor(dialect, coltype)
+ loads = self.pickler.loads
+ if impl_processor:
+
+ def process(value):
+ value = impl_processor(value)
+ if value is None:
+ return None
+ return loads(value)
+
+ else:
+
+ def process(value):
+ if value is None:
+ return None
+ return loads(value)
+
+ return process
+
+ def compare_values(self, x, y):
+ if self.comparator:
+ return self.comparator(x, y)
+ else:
+ return x == y
+
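+# Illustrative sketch (hypothetical table): PickleType stores arbitrary
+# pickleable Python objects in a LargeBinary column, applying
+# pickle.dumps() / pickle.loads() transparently as described above.
+#
+#     from sqlalchemy import Column, Integer, MetaData, PickleType, Table
+#
+#     metadata = MetaData()
+#     job = Table(
+#         "job",
+#         metadata,
+#         Column("id", Integer, primary_key=True),
+#         Column("settings", PickleType),
+#     )
+#     # connection.execute(job.insert(), {"id": 1, "settings": {"retries": 3}})
+#     # a later SELECT returns the dict {"retries": 3}, reconstructed via loads()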
+
+class Boolean(Emulated, TypeEngine, SchemaType):
+
+ """A bool datatype.
+
+ :class:`.Boolean` typically uses BOOLEAN or SMALLINT on the DDL side,
+ and on the Python side deals in ``True`` or ``False``.
+
+ The :class:`.Boolean` datatype currently has two levels of assertion
+ that the values persisted are simple true/false values. For all
+ backends, only the Python values ``None``, ``True``, ``False``, ``1``
+ or ``0`` are accepted as parameter values. For those backends that
+ don't support a "native boolean" datatype, an option exists to
+ also create a CHECK constraint on the target column.
+
+ .. versionchanged:: 1.2 the :class:`.Boolean` datatype now asserts that
+ incoming Python values are already in pure boolean form.
+
+
+ """
+
+ __visit_name__ = "boolean"
+ native = True
+
+ def __init__(
+ self, create_constraint=False, name=None, _create_events=True
+ ):
+ """Construct a Boolean.
+
+ :param create_constraint: defaults to False. If the boolean
+ is generated as an int/smallint, also create a CHECK constraint
+ on the table that ensures 1 or 0 as a value.
+
+ .. note:: it is strongly recommended that the CHECK constraint
+ have an explicit name in order to support schema-management
+ concerns. This can be established either by setting the
+ :paramref:`.Boolean.name` parameter or by setting up an
+ appropriate naming convention; see
+ :ref:`constraint_naming_conventions` for background.
+
+ .. versionchanged:: 1.4 - this flag now defaults to False, meaning
+ no CHECK constraint is generated for a non-native boolean
+ type.
+
+ :param name: if a CHECK constraint is generated, specify
+ the name of the constraint.
+
+ """
+ self.create_constraint = create_constraint
+ self.name = name
+ self._create_events = _create_events
+
+ def _should_create_constraint(self, compiler, **kw):
+ if not self._is_impl_for_variant(compiler.dialect, kw):
+ return False
+ return (
+ not compiler.dialect.supports_native_boolean
+ and compiler.dialect.non_native_boolean_check_constraint
+ )
+
+ @util.preload_module("sqlalchemy.sql.schema")
+ def _set_table(self, column, table):
+ schema = util.preloaded.sql_schema
+ if not self.create_constraint:
+ return
+
+ variant_mapping = self._variant_mapping_for_set_table(column)
+
+ e = schema.CheckConstraint(
+ type_coerce(column, self).in_([0, 1]),
+ name=_NONE_NAME if self.name is None else self.name,
+ _create_rule=util.portable_instancemethod(
+ self._should_create_constraint,
+ {"variant_mapping": variant_mapping},
+ ),
+ _type_bound=True,
+ )
+ assert e.table is table
+
+ @property
+ def python_type(self):
+ return bool
+
+ _strict_bools = frozenset([None, True, False])
+
+ def _strict_as_bool(self, value):
+ if value not in self._strict_bools:
+ if not isinstance(value, int):
+ raise TypeError("Not a boolean value: %r" % (value,))
+ else:
+ raise ValueError(
+ "Value %r is not None, True, or False" % (value,)
+ )
+ return value
+
+ def literal_processor(self, dialect):
+ compiler = dialect.statement_compiler(dialect, None)
+ true = compiler.visit_true(None)
+ false = compiler.visit_false(None)
+
+ def process(value):
+ return true if self._strict_as_bool(value) else false
+
+ return process
+
+ def bind_processor(self, dialect):
+ _strict_as_bool = self._strict_as_bool
+ if dialect.supports_native_boolean:
+ _coerce = bool
+ else:
+ _coerce = int
+
+ def process(value):
+ value = _strict_as_bool(value)
+ if value is not None:
+ value = _coerce(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if dialect.supports_native_boolean:
+ return None
+ else:
+ return processors.int_to_boolean
+
+
+class _AbstractInterval(_LookupExpressionAdapter, TypeEngine):
+ @util.memoized_property
+ def _expression_adaptations(self):
+ # Based on
+ # https://www.postgresql.org/docs/current/static/functions-datetime.html.
+
+ return {
+ operators.add: {
+ Date: DateTime,
+ Interval: self.__class__,
+ DateTime: DateTime,
+ Time: Time,
+ },
+ operators.sub: {Interval: self.__class__},
+ operators.mul: {Numeric: self.__class__},
+ operators.truediv: {Numeric: self.__class__},
+ operators.div: {Numeric: self.__class__},
+ }
+
+ @property
+ def _type_affinity(self):
+ return Interval
+
+ def coerce_compared_value(self, op, value):
+ """See :meth:`.TypeEngine.coerce_compared_value` for a description."""
+ return self.impl.coerce_compared_value(op, value)
+
+
+class Interval(Emulated, _AbstractInterval, TypeDecorator):
+
+ """A type for ``datetime.timedelta()`` objects.
+
+ The Interval type deals with ``datetime.timedelta`` objects. In
+ PostgreSQL, the native ``INTERVAL`` type is used; for others, the
+    value is stored as a datetime value relative to the "epoch"
+ (Jan. 1, 1970).
+
+ Note that the ``Interval`` type does not currently provide date arithmetic
+ operations on platforms which do not support interval types natively. Such
+ operations usually require transformation of both sides of the expression
+ (such as, conversion of both sides into integer epoch values first) which
+ currently is a manual procedure (such as via
+ :attr:`~sqlalchemy.sql.expression.func`).
+
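+    E.g., an illustrative table definition (the table and column names are
+    arbitrary examples)::
+
+        from sqlalchemy import Column, Integer, Interval, MetaData, Table
+
+        jobs = Table(
+            "jobs",
+            MetaData(),
+            Column("id", Integer, primary_key=True),
+            Column("elapsed", Interval(second_precision=6)),
+        )
+
+        # bound values for "elapsed" are datetime.timedelta objects,
+        # e.g. datetime.timedelta(hours=1, minutes=30)
+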
+ """
+
+ impl = DateTime
+ epoch = dt.datetime.utcfromtimestamp(0)
+ cache_ok = True
+
+ def __init__(self, native=True, second_precision=None, day_precision=None):
+ """Construct an Interval object.
+
+ :param native: when True, use the actual
+ INTERVAL type provided by the database, if
+ supported (currently PostgreSQL, Oracle).
+ Otherwise, represent the interval data as
+ an epoch value regardless.
+
+ :param second_precision: For native interval types
+ which support a "fractional seconds precision" parameter,
+          i.e. Oracle and PostgreSQL.
+
+ :param day_precision: for native interval types which
+ support a "day precision" parameter, i.e. Oracle.
+
+ """
+ super(Interval, self).__init__()
+ self.native = native
+ self.second_precision = second_precision
+ self.day_precision = day_precision
+
+ @property
+ def python_type(self):
+ return dt.timedelta
+
+ def adapt_to_emulated(self, impltype, **kw):
+ return _AbstractInterval.adapt(self, impltype, **kw)
+
+ def bind_processor(self, dialect):
+ impl_processor = self.impl.bind_processor(dialect)
+ epoch = self.epoch
+ if impl_processor:
+
+ def process(value):
+ if value is not None:
+ value = epoch + value
+ return impl_processor(value)
+
+ else:
+
+ def process(value):
+ if value is not None:
+ value = epoch + value
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ impl_processor = self.impl.result_processor(dialect, coltype)
+ epoch = self.epoch
+ if impl_processor:
+
+ def process(value):
+ value = impl_processor(value)
+ if value is None:
+ return None
+ return value - epoch
+
+ else:
+
+ def process(value):
+ if value is None:
+ return None
+ return value - epoch
+
+ return process
+
+
+class JSON(Indexable, TypeEngine):
+ """Represent a SQL JSON type.
+
+ .. note:: :class:`_types.JSON`
+ is provided as a facade for vendor-specific
+ JSON types. Since it supports JSON SQL operations, it only
+ works on backends that have an actual JSON type, currently:
+
+ * PostgreSQL - see :class:`sqlalchemy.dialects.postgresql.JSON` and
+ :class:`sqlalchemy.dialects.postgresql.JSONB` for backend-specific
+ notes
+
+ * MySQL - see
+ :class:`sqlalchemy.dialects.mysql.JSON` for backend-specific notes
+
+ * SQLite as of version 3.9 - see
+ :class:`sqlalchemy.dialects.sqlite.JSON` for backend-specific notes
+
+ * Microsoft SQL Server 2016 and later - see
+ :class:`sqlalchemy.dialects.mssql.JSON` for backend-specific notes
+
+ :class:`_types.JSON` is part of the Core in support of the growing
+ popularity of native JSON datatypes.
+
+ The :class:`_types.JSON` type stores arbitrary JSON format data, e.g.::
+
+ data_table = Table('data_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', JSON)
+ )
+
+ with engine.connect() as conn:
+ conn.execute(
+ data_table.insert(),
+ {"data": {"key1": "value1", "key2": "value2"}}
+ )
+
+ **JSON-Specific Expression Operators**
+
+ The :class:`_types.JSON`
+ datatype provides these additional SQL operations:
+
+ * Keyed index operations::
+
+ data_table.c.data['some key']
+
+ * Integer index operations::
+
+ data_table.c.data[3]
+
+ * Path index operations::
+
+ data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')]
+
+ * Data casters for specific JSON element types, subsequent to an index
+ or path operation being invoked::
+
+ data_table.c.data["some key"].as_integer()
+
+ .. versionadded:: 1.3.11
+
+ Additional operations may be available from the dialect-specific versions
+ of :class:`_types.JSON`, such as
+ :class:`sqlalchemy.dialects.postgresql.JSON` and
+ :class:`sqlalchemy.dialects.postgresql.JSONB` which both offer additional
+ PostgreSQL-specific operations.
+
+ **Casting JSON Elements to Other Types**
+
+ Index operations, i.e. those invoked by calling upon the expression using
+ the Python bracket operator as in ``some_column['some key']``, return an
+    expression object whose type defaults to :class:`_types.JSON`, so that
+    further JSON-oriented instructions may be called upon the result type.
+ However, it is likely more common that an index operation is expected
+ to return a specific scalar element, such as a string or integer. In
+ order to provide access to these elements in a backend-agnostic way,
+ a series of data casters are provided:
+
+ * :meth:`.JSON.Comparator.as_string` - return the element as a string
+
+ * :meth:`.JSON.Comparator.as_boolean` - return the element as a boolean
+
+ * :meth:`.JSON.Comparator.as_float` - return the element as a float
+
+ * :meth:`.JSON.Comparator.as_integer` - return the element as an integer
+
+ These data casters are implemented by supporting dialects in order to
+ assure that comparisons to the above types will work as expected, such as::
+
+ # integer comparison
+ data_table.c.data["some_integer_key"].as_integer() == 5
+
+ # boolean comparison
+ data_table.c.data["some_boolean"].as_boolean() == True
+
+ .. versionadded:: 1.3.11 Added type-specific casters for the basic JSON
+ data element types.
+
+ .. note::
+
+ The data caster functions are new in version 1.3.11, and supersede
+        the previously documented approach of using CAST; for reference,
+ this looked like::
+
+ from sqlalchemy import cast, type_coerce
+ from sqlalchemy import String, JSON
+ cast(
+ data_table.c.data['some_key'], String
+ ) == type_coerce(55, JSON)
+
+ The above case now works directly as::
+
+ data_table.c.data['some_key'].as_integer() == 5
+
+ For details on the previous comparison approach within the 1.3.x
+ series, see the documentation for SQLAlchemy 1.2 or the included HTML
+ files in the doc/ directory of the version's distribution.
+
+ **Detecting Changes in JSON columns when using the ORM**
+
+ The :class:`_types.JSON` type, when used with the SQLAlchemy ORM, does not
+ detect in-place mutations to the structure. In order to detect these, the
+ :mod:`sqlalchemy.ext.mutable` extension must be used. This extension will
+ allow "in-place" changes to the datastructure to produce events which
+ will be detected by the unit of work. See the example at :class:`.HSTORE`
+ for a simple example involving a dictionary.
+
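+    E.g., an illustrative sketch using the mutable extension (the column
+    shown is purely for demonstration)::
+
+        from sqlalchemy import Column, JSON
+        from sqlalchemy.ext.mutable import MutableDict
+
+        Column("data", MutableDict.as_mutable(JSON))
+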
+ **Support for JSON null vs. SQL NULL**
+
+ When working with NULL values, the :class:`_types.JSON` type recommends the
+ use of two specific constants in order to differentiate between a column
+ that evaluates to SQL NULL, e.g. no value, vs. the JSON-encoded string of
+ ``"null"``. To insert or select against a value that is SQL NULL, use the
+ constant :func:`.null`. This symbol may be passed as a parameter value
+ specifically when using the :class:`_types.JSON` datatype, which contains
+ special logic that interprets this symbol to mean that the column value
+ should be SQL NULL as opposed to JSON ``"null"``::
+
+ from sqlalchemy import null
+ conn.execute(table.insert(), {"json_value": null()})
+
+ To insert or select against a value that is JSON ``"null"``, use the
+ constant :attr:`_types.JSON.NULL`::
+
+ conn.execute(table.insert(), {"json_value": JSON.NULL})
+
+ The :class:`_types.JSON` type supports a flag
+ :paramref:`_types.JSON.none_as_null` which when set to True will result
+ in the Python constant ``None`` evaluating to the value of SQL
+ NULL, and when set to False results in the Python constant
+ ``None`` evaluating to the value of JSON ``"null"``. The Python
+    value ``None`` may be used in conjunction with either
+    :attr:`_types.JSON.NULL` or :func:`.null` in order to indicate NULL
+    values, but care must be taken as to the value of the
+    :paramref:`_types.JSON.none_as_null` flag in these cases.
+
+ **Customizing the JSON Serializer**
+
+ The JSON serializer and deserializer used by :class:`_types.JSON`
+ defaults to
+ Python's ``json.dumps`` and ``json.loads`` functions; in the case of the
+ psycopg2 dialect, psycopg2 may be using its own custom loader function.
+
+ In order to affect the serializer / deserializer, they are currently
+ configurable at the :func:`_sa.create_engine` level via the
+ :paramref:`_sa.create_engine.json_serializer` and
+ :paramref:`_sa.create_engine.json_deserializer` parameters. For example,
+ to turn off ``ensure_ascii``::
+
+ engine = create_engine(
+ "sqlite://",
+ json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False))
+
+ .. versionchanged:: 1.3.7
+
+ SQLite dialect's ``json_serializer`` and ``json_deserializer``
+ parameters renamed from ``_json_serializer`` and
+ ``_json_deserializer``.
+
+ .. seealso::
+
+ :class:`sqlalchemy.dialects.postgresql.JSON`
+
+ :class:`sqlalchemy.dialects.postgresql.JSONB`
+
+ :class:`sqlalchemy.dialects.mysql.JSON`
+
+ :class:`sqlalchemy.dialects.sqlite.JSON`
+
+ .. versionadded:: 1.1
+
+
+ """
+
+ __visit_name__ = "JSON"
+
+ hashable = False
+ NULL = util.symbol("JSON_NULL")
+ """Describe the json value of NULL.
+
+ This value is used to force the JSON value of ``"null"`` to be
+ used as the value. A value of Python ``None`` will be recognized
+ either as SQL NULL or JSON ``"null"``, based on the setting
+ of the :paramref:`_types.JSON.none_as_null` flag; the
+ :attr:`_types.JSON.NULL`
+ constant can be used to always resolve to JSON ``"null"`` regardless
+ of this setting. This is in contrast to the :func:`_expression.null`
+ construct,
+ which always resolves to SQL NULL. E.g.::
+
+ from sqlalchemy import null
+ from sqlalchemy.dialects.postgresql import JSON
+
+ # will *always* insert SQL NULL
+ obj1 = MyObject(json_value=null())
+
+ # will *always* insert JSON string "null"
+ obj2 = MyObject(json_value=JSON.NULL)
+
+ session.add_all([obj1, obj2])
+ session.commit()
+
+ In order to set JSON NULL as a default value for a column, the most
+ transparent method is to use :func:`_expression.text`::
+
+ Table(
+ 'my_table', metadata,
+ Column('json_data', JSON, default=text("'null'"))
+ )
+
+ While it is possible to use :attr:`_types.JSON.NULL` in this context, the
+ :attr:`_types.JSON.NULL` value will be returned as the value of the
+ column,
+ which in the context of the ORM or other repurposing of the default
+ value, may not be desirable. Using a SQL expression means the value
+ will be re-fetched from the database within the context of retrieving
+ generated defaults.
+
+
+ """
+
+ def __init__(self, none_as_null=False):
+ """Construct a :class:`_types.JSON` type.
+
+ :param none_as_null=False: if True, persist the value ``None`` as a
+ SQL NULL value, not the JSON encoding of ``null``. Note that when this
+ flag is False, the :func:`.null` construct can still be used to
+ persist a NULL value, which may be passed directly as a parameter
+ value that is specially interpreted by the :class:`_types.JSON` type
+ as SQL NULL::
+
+ from sqlalchemy import null
+ conn.execute(table.insert(), {"data": null()})
+
+ .. note::
+
+ :paramref:`_types.JSON.none_as_null` does **not** apply to the
+ values passed to :paramref:`_schema.Column.default` and
+ :paramref:`_schema.Column.server_default`; a value of ``None``
+ passed for these parameters means "no default present".
+
+ Additionally, when used in SQL comparison expressions, the
+ Python value ``None`` continues to refer to SQL null, and not
+ JSON NULL. The :paramref:`_types.JSON.none_as_null` flag refers
+ explicitly to the **persistence** of the value within an
+ INSERT or UPDATE statement. The :attr:`_types.JSON.NULL`
+ value should be used for SQL expressions that wish to compare to
+ JSON null.
+
+ .. seealso::
+
+ :attr:`.types.JSON.NULL`
+
+ """
+ self.none_as_null = none_as_null
+
+ class JSONElementType(TypeEngine):
+        """Common base type for index / path elements in a JSON expression."""
+
+ _integer = Integer()
+ _string = String()
+
+ def string_bind_processor(self, dialect):
+ return self._string._cached_bind_processor(dialect)
+
+ def string_literal_processor(self, dialect):
+ return self._string._cached_literal_processor(dialect)
+
+ def bind_processor(self, dialect):
+ int_processor = self._integer._cached_bind_processor(dialect)
+ string_processor = self.string_bind_processor(dialect)
+
+ def process(value):
+ if int_processor and isinstance(value, int):
+ value = int_processor(value)
+ elif string_processor and isinstance(value, util.string_types):
+ value = string_processor(value)
+ return value
+
+ return process
+
+ def literal_processor(self, dialect):
+ int_processor = self._integer._cached_literal_processor(dialect)
+ string_processor = self.string_literal_processor(dialect)
+
+ def process(value):
+ if int_processor and isinstance(value, int):
+ value = int_processor(value)
+ elif string_processor and isinstance(value, util.string_types):
+ value = string_processor(value)
+ return value
+
+ return process
+
+ class JSONIndexType(JSONElementType):
+ """Placeholder for the datatype of a JSON index value.
+
+ This allows execution-time processing of JSON index values
+ for special syntaxes.
+
+ """
+
+ class JSONIntIndexType(JSONIndexType):
+        """Placeholder for the datatype of a JSON integer index value.
+
+ This allows execution-time processing of JSON index values
+ for special syntaxes.
+
+ """
+
+ class JSONStrIndexType(JSONIndexType):
+        """Placeholder for the datatype of a JSON string index value.
+
+ This allows execution-time processing of JSON index values
+ for special syntaxes.
+
+ """
+
+ class JSONPathType(JSONElementType):
+ """Placeholder type for JSON path operations.
+
+ This allows execution-time processing of a path-based
+ index value into a specific SQL syntax.
+
+ """
+
+ class Comparator(Indexable.Comparator, Concatenable.Comparator):
+ """Define comparison operations for :class:`_types.JSON`."""
+
+ def _setup_getitem(self, index):
+ if not isinstance(index, util.string_types) and isinstance(
+ index, compat.collections_abc.Sequence
+ ):
+ index = coercions.expect(
+ roles.BinaryElementRole,
+ index,
+ expr=self.expr,
+ operator=operators.json_path_getitem_op,
+ bindparam_type=JSON.JSONPathType,
+ )
+
+ operator = operators.json_path_getitem_op
+ else:
+ index = coercions.expect(
+ roles.BinaryElementRole,
+ index,
+ expr=self.expr,
+ operator=operators.json_getitem_op,
+ bindparam_type=JSON.JSONIntIndexType
+ if isinstance(index, int)
+ else JSON.JSONStrIndexType,
+ )
+ operator = operators.json_getitem_op
+
+ return operator, index, self.type
+
+ def as_boolean(self):
+ """Cast an indexed value as boolean.
+
+ e.g.::
+
+ stmt = select(
+ mytable.c.json_column['some_data'].as_boolean()
+ ).where(
+ mytable.c.json_column['some_data'].as_boolean() == True
+ )
+
+ .. versionadded:: 1.3.11
+
+ """
+ return self._binary_w_type(Boolean(), "as_boolean")
+
+ def as_string(self):
+ """Cast an indexed value as string.
+
+ e.g.::
+
+ stmt = select(
+ mytable.c.json_column['some_data'].as_string()
+ ).where(
+ mytable.c.json_column['some_data'].as_string() ==
+ 'some string'
+ )
+
+ .. versionadded:: 1.3.11
+
+ """
+ return self._binary_w_type(String(), "as_string")
+
+ def as_integer(self):
+ """Cast an indexed value as integer.
+
+ e.g.::
+
+ stmt = select(
+ mytable.c.json_column['some_data'].as_integer()
+ ).where(
+ mytable.c.json_column['some_data'].as_integer() == 5
+ )
+
+ .. versionadded:: 1.3.11
+
+ """
+ return self._binary_w_type(Integer(), "as_integer")
+
+ def as_float(self):
+ """Cast an indexed value as float.
+
+ e.g.::
+
+ stmt = select(
+ mytable.c.json_column['some_data'].as_float()
+ ).where(
+ mytable.c.json_column['some_data'].as_float() == 29.75
+ )
+
+ .. versionadded:: 1.3.11
+
+ """
+ return self._binary_w_type(Float(), "as_float")
+
+ def as_numeric(self, precision, scale, asdecimal=True):
+ """Cast an indexed value as numeric/decimal.
+
+ e.g.::
+
+ stmt = select(
+ mytable.c.json_column['some_data'].as_numeric(10, 6)
+ ).where(
+                    mytable.c.json_column['some_data'].as_numeric(10, 6)
+                    == 29.75
+ )
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ return self._binary_w_type(
+ Numeric(precision, scale, asdecimal=asdecimal), "as_numeric"
+ )
+
+ def as_json(self):
+ """Cast an indexed value as JSON.
+
+ e.g.::
+
+ stmt = select(mytable.c.json_column['some_data'].as_json())
+
+ This is typically the default behavior of indexed elements in any
+ case.
+
+ Note that comparison of full JSON structures may not be
+ supported by all backends.
+
+ .. versionadded:: 1.3.11
+
+ """
+ return self.expr
+
+ def _binary_w_type(self, typ, method_name):
+ if not isinstance(
+ self.expr, elements.BinaryExpression
+ ) or self.expr.operator not in (
+ operators.json_getitem_op,
+ operators.json_path_getitem_op,
+ ):
+ raise exc.InvalidRequestError(
+ "The JSON cast operator JSON.%s() only works with a JSON "
+ "index expression e.g. col['q'].%s()"
+ % (method_name, method_name)
+ )
+ expr = self.expr._clone()
+ expr.type = typ
+ return expr
+
+ comparator_factory = Comparator
+
+ @property
+ def python_type(self):
+ return dict
+
+ @property
+ def should_evaluate_none(self):
+ """Alias of :attr:`_types.JSON.none_as_null`"""
+ return not self.none_as_null
+
+ @should_evaluate_none.setter
+ def should_evaluate_none(self, value):
+ self.none_as_null = not value
+
+ @util.memoized_property
+ def _str_impl(self):
+ return String(_expect_unicode=True)
+
+ def bind_processor(self, dialect):
+ string_process = self._str_impl.bind_processor(dialect)
+
+ json_serializer = dialect._json_serializer or json.dumps
+
+ def process(value):
+ if value is self.NULL:
+ value = None
+ elif isinstance(value, elements.Null) or (
+ value is None and self.none_as_null
+ ):
+ return None
+
+ serialized = json_serializer(value)
+ if string_process:
+ serialized = string_process(serialized)
+ return serialized
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ string_process = self._str_impl.result_processor(dialect, coltype)
+ json_deserializer = dialect._json_deserializer or json.loads
+
+ def process(value):
+ if value is None:
+ return None
+ if string_process:
+ value = string_process(value)
+ return json_deserializer(value)
+
+ return process
+
+
+class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
+ """Represent a SQL Array type.
+
+ .. note:: This type serves as the basis for all ARRAY operations.
+ However, currently **only the PostgreSQL backend has support for SQL
+ arrays in SQLAlchemy**. It is recommended to use the PostgreSQL-specific
+ :class:`sqlalchemy.dialects.postgresql.ARRAY` type directly when using
+ ARRAY types with PostgreSQL, as it provides additional operators
+ specific to that backend.
+
+ :class:`_types.ARRAY` is part of the Core in support of various SQL
+ standard functions such as :class:`_functions.array_agg`
+ which explicitly involve
+ arrays; however, with the exception of the PostgreSQL backend and possibly
+ some third-party dialects, no other SQLAlchemy built-in dialect has support
+ for this type.
+
+ An :class:`_types.ARRAY` type is constructed given the "type"
+ of element::
+
+ mytable = Table("mytable", metadata,
+ Column("data", ARRAY(Integer))
+ )
+
+ The above type represents an N-dimensional array,
+ meaning a supporting backend such as PostgreSQL will interpret values
+ with any number of dimensions automatically. To produce an INSERT
+ construct that passes in a 1-dimensional array of integers::
+
+ connection.execute(
+ mytable.insert(),
+ {"data": [1,2,3]}
+ )
+
+ The :class:`_types.ARRAY` type can be constructed given a fixed number
+ of dimensions::
+
+ mytable = Table("mytable", metadata,
+ Column("data", ARRAY(Integer, dimensions=2))
+ )
+
+ Sending a number of dimensions is optional, but recommended if the
+ datatype is to represent arrays of more than one dimension. This number
+ is used:
+
+ * When emitting the type declaration itself to the database, e.g.
+ ``INTEGER[][]``
+
+ * When translating Python values to database values, and vice versa, e.g.
+ an ARRAY of :class:`.Unicode` objects uses this number to efficiently
+ access the string values inside of array structures without resorting
+ to per-row type inspection
+
+ * When used with the Python ``getitem`` accessor, the number of dimensions
+ serves to define the kind of type that the ``[]`` operator should
+ return, e.g. for an ARRAY of INTEGER with two dimensions::
+
+ >>> expr = table.c.column[5] # returns ARRAY(Integer, dimensions=1)
+ >>> expr = expr[6] # returns Integer
+
+ For 1-dimensional arrays, an :class:`_types.ARRAY` instance with no
+ dimension parameter will generally assume single-dimensional behaviors.
+
+ SQL expressions of type :class:`_types.ARRAY` have support for "index" and
+ "slice" behavior. The Python ``[]`` operator works normally here, given
+ integer indexes or slices. Arrays default to 1-based indexing.
+ The operator produces binary expression
+ constructs which will produce the appropriate SQL, both for
+ SELECT statements::
+
+ select(mytable.c.data[5], mytable.c.data[2:7])
+
+ as well as UPDATE statements when the :meth:`_expression.Update.values`
+ method
+ is used::
+
+ mytable.update().values({
+ mytable.c.data[5]: 7,
+ mytable.c.data[2:7]: [1, 2, 3]
+ })
+
+ The :class:`_types.ARRAY` type also provides for the operators
+ :meth:`.types.ARRAY.Comparator.any` and
+ :meth:`.types.ARRAY.Comparator.all`. The PostgreSQL-specific version of
+ :class:`_types.ARRAY` also provides additional operators.
+
+ .. versionadded:: 1.1.0
+
+ .. seealso::
+
+ :class:`sqlalchemy.dialects.postgresql.ARRAY`
+
+ """
+
+ __visit_name__ = "ARRAY"
+
+ _is_array = True
+
+ zero_indexes = False
+ """If True, Python zero-based indexes should be interpreted as one-based
+ on the SQL expression side."""
+
+ class Comparator(Indexable.Comparator, Concatenable.Comparator):
+
+ """Define comparison operations for :class:`_types.ARRAY`.
+
+ More operators are available on the dialect-specific form
+ of this type. See :class:`.postgresql.ARRAY.Comparator`.
+
+ """
+
+ def _setup_getitem(self, index):
+ if isinstance(index, slice):
+ return_type = self.type
+ if self.type.zero_indexes:
+ index = slice(index.start + 1, index.stop + 1, index.step)
+ slice_ = Slice(
+ index.start, index.stop, index.step, _name=self.expr.key
+ )
+ return operators.getitem, slice_, return_type
+ else:
+ if self.type.zero_indexes:
+ index += 1
+ if self.type.dimensions is None or self.type.dimensions == 1:
+ return_type = self.type.item_type
+ else:
+ adapt_kw = {"dimensions": self.type.dimensions - 1}
+ return_type = self.type.adapt(
+ self.type.__class__, **adapt_kw
+ )
+
+ return operators.getitem, index, return_type
+
+ def contains(self, *arg, **kw):
+ raise NotImplementedError(
+ "ARRAY.contains() not implemented for the base "
+ "ARRAY type; please use the dialect-specific ARRAY type"
+ )
+
+ @util.preload_module("sqlalchemy.sql.elements")
+ def any(self, other, operator=None):
+ """Return ``other operator ANY (array)`` clause.
+
+ .. note:: This method is an :class:`_types.ARRAY` - specific
+ construct that is now superseded by the :func:`_sql.any_`
+ function, which features a different calling style. The
+ :func:`_sql.any_` function is also mirrored at the method level
+ via the :meth:`_sql.ColumnOperators.any_` method.
+
+ Usage of array-specific :meth:`_types.ARRAY.Comparator.any`
+ is as follows::
+
+ from sqlalchemy.sql import operators
+
+ conn.execute(
+ select(table.c.data).where(
+ table.c.data.any(7, operator=operators.lt)
+ )
+ )
+
+ :param other: expression to be compared
+ :param operator: an operator object from the
+ :mod:`sqlalchemy.sql.operators`
+ package, defaults to :func:`.operators.eq`.
+
+ .. seealso::
+
+ :func:`_expression.any_`
+
+ :meth:`.types.ARRAY.Comparator.all`
+
+ """
+ elements = util.preloaded.sql_elements
+ operator = operator if operator else operators.eq
+
+ arr_type = self.type
+
+ # send plain BinaryExpression so that negate remains at None,
+ # leading to NOT expr for negation.
+ return elements.BinaryExpression(
+ coercions.expect(
+ roles.BinaryElementRole,
+ element=other,
+ operator=operator,
+ expr=self.expr,
+ bindparam_type=arr_type.item_type,
+ ),
+ elements.CollectionAggregate._create_any(self.expr),
+ operator,
+ )
+
+ @util.preload_module("sqlalchemy.sql.elements")
+ def all(self, other, operator=None):
+ """Return ``other operator ALL (array)`` clause.
+
+ .. note:: This method is an :class:`_types.ARRAY` - specific
+                construct that is now superseded by the :func:`_sql.all_`
+                function, which features a different calling style. The
+                :func:`_sql.all_` function is also mirrored at the method level
+                via the :meth:`_sql.ColumnOperators.all_` method.
+
+ Usage of array-specific :meth:`_types.ARRAY.Comparator.all`
+ is as follows::
+
+ from sqlalchemy.sql import operators
+
+ conn.execute(
+ select(table.c.data).where(
+ table.c.data.all(7, operator=operators.lt)
+ )
+ )
+
+ :param other: expression to be compared
+ :param operator: an operator object from the
+ :mod:`sqlalchemy.sql.operators`
+ package, defaults to :func:`.operators.eq`.
+
+ .. seealso::
+
+ :func:`_expression.all_`
+
+ :meth:`.types.ARRAY.Comparator.any`
+
+ """
+ elements = util.preloaded.sql_elements
+ operator = operator if operator else operators.eq
+
+ arr_type = self.type
+
+ # send plain BinaryExpression so that negate remains at None,
+ # leading to NOT expr for negation.
+ return elements.BinaryExpression(
+ coercions.expect(
+ roles.BinaryElementRole,
+ element=other,
+ operator=operator,
+ expr=self.expr,
+ bindparam_type=arr_type.item_type,
+ ),
+ elements.CollectionAggregate._create_all(self.expr),
+ operator,
+ )
+
+ comparator_factory = Comparator
+
+ def __init__(
+ self, item_type, as_tuple=False, dimensions=None, zero_indexes=False
+ ):
+ """Construct an :class:`_types.ARRAY`.
+
+ E.g.::
+
+ Column('myarray', ARRAY(Integer))
+
+ Arguments are:
+
+ :param item_type: The data type of items of this array. Note that
+ dimensionality is irrelevant here, so multi-dimensional arrays like
+ ``INTEGER[][]``, are constructed as ``ARRAY(Integer)``, not as
+ ``ARRAY(ARRAY(Integer))`` or such.
+
+ :param as_tuple=False: Specify whether return results
+ should be converted to tuples from lists. This parameter is
+ not generally needed as a Python list corresponds well
+ to a SQL array.
+
+ :param dimensions: if non-None, the ARRAY will assume a fixed
+ number of dimensions. This impacts how the array is declared
+ on the database, how it goes about interpreting Python and
+ result values, as well as how expression behavior in conjunction
+ with the "getitem" operator works. See the description at
+ :class:`_types.ARRAY` for additional detail.
+
+ :param zero_indexes=False: when True, index values will be converted
+ between Python zero-based and SQL one-based indexes, e.g.
+ a value of one will be added to all index values before passing
+ to the database.
+
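+        E.g., an illustrative two-dimensional declaration with zero-based
+        indexing enabled (the column name is arbitrary)::
+
+            Column("matrix", ARRAY(Integer, dimensions=2, zero_indexes=True))
+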
+ """
+ if isinstance(item_type, ARRAY):
+ raise ValueError(
+ "Do not nest ARRAY types; ARRAY(basetype) "
+ "handles multi-dimensional arrays of basetype"
+ )
+ if isinstance(item_type, type):
+ item_type = item_type()
+ self.item_type = item_type
+ self.as_tuple = as_tuple
+ self.dimensions = dimensions
+ self.zero_indexes = zero_indexes
+
+ @property
+ def hashable(self):
+ return self.as_tuple
+
+ @property
+ def python_type(self):
+ return list
+
+ def compare_values(self, x, y):
+ return x == y
+
+ def _set_parent(self, column, outer=False, **kw):
+ """Support SchemaEventTarget"""
+
+ if not outer and isinstance(self.item_type, SchemaEventTarget):
+ self.item_type._set_parent(column, **kw)
+
+ def _set_parent_with_dispatch(self, parent):
+ """Support SchemaEventTarget"""
+
+ super(ARRAY, self)._set_parent_with_dispatch(parent, outer=True)
+
+ if isinstance(self.item_type, SchemaEventTarget):
+ self.item_type._set_parent_with_dispatch(parent)
+
+
+class TupleType(TypeEngine):
+    """Represent the composite type of a Tuple."""
+
+ _is_tuple_type = True
+
+ def __init__(self, *types):
+ self._fully_typed = NULLTYPE not in types
+ self.types = [
+ item_type() if isinstance(item_type, type) else item_type
+ for item_type in types
+ ]
+
+ def _resolve_values_to_types(self, value):
+ if self._fully_typed:
+ return self
+ else:
+ return TupleType(
+ *[
+ _resolve_value_to_type(elem) if typ is NULLTYPE else typ
+ for typ, elem in zip(self.types, value)
+ ]
+ )
+
+ def result_processor(self, dialect, coltype):
+ raise NotImplementedError(
+ "The tuple type does not support being fetched "
+ "as a column in a result row."
+ )
+
+
+class REAL(Float):
+
+ """The SQL REAL type."""
+
+ __visit_name__ = "REAL"
+
+
+class FLOAT(Float):
+
+ """The SQL FLOAT type."""
+
+ __visit_name__ = "FLOAT"
+
+
+class NUMERIC(Numeric):
+
+ """The SQL NUMERIC type."""
+
+ __visit_name__ = "NUMERIC"
+
+
+class DECIMAL(Numeric):
+
+ """The SQL DECIMAL type."""
+
+ __visit_name__ = "DECIMAL"
+
+
+class INTEGER(Integer):
+
+ """The SQL INT or INTEGER type."""
+
+ __visit_name__ = "INTEGER"
+
+
+INT = INTEGER
+
+
+class SMALLINT(SmallInteger):
+
+ """The SQL SMALLINT type."""
+
+ __visit_name__ = "SMALLINT"
+
+
+class BIGINT(BigInteger):
+
+ """The SQL BIGINT type."""
+
+ __visit_name__ = "BIGINT"
+
+
+class TIMESTAMP(DateTime):
+
+ """The SQL TIMESTAMP type.
+
+ :class:`_types.TIMESTAMP` datatypes have support for timezone
+ storage on some backends, such as PostgreSQL and Oracle. Use the
+ :paramref:`~types.TIMESTAMP.timezone` argument in order to enable
+ "TIMESTAMP WITH TIMEZONE" for these backends.
+
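+    E.g., an illustrative column definition (the column name is arbitrary)::
+
+        from sqlalchemy import Column, TIMESTAMP
+
+        Column("created_at", TIMESTAMP(timezone=True))
+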
+ """
+
+ __visit_name__ = "TIMESTAMP"
+
+ def __init__(self, timezone=False):
+ """Construct a new :class:`_types.TIMESTAMP`.
+
+ :param timezone: boolean. Indicates that the TIMESTAMP type should
+ enable timezone support, if available on the target database.
+          On a per-dialect basis, this is similar to "TIMESTAMP WITH TIMEZONE".
+ If the target database does not support timezones, this flag is
+ ignored.
+
+
+ """
+ super(TIMESTAMP, self).__init__(timezone=timezone)
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.TIMESTAMP
+
+
+class DATETIME(DateTime):
+
+ """The SQL DATETIME type."""
+
+ __visit_name__ = "DATETIME"
+
+
+class DATE(Date):
+
+ """The SQL DATE type."""
+
+ __visit_name__ = "DATE"
+
+
+class TIME(Time):
+
+ """The SQL TIME type."""
+
+ __visit_name__ = "TIME"
+
+
+class TEXT(Text):
+
+ """The SQL TEXT type."""
+
+ __visit_name__ = "TEXT"
+
+
+class CLOB(Text):
+
+ """The CLOB type.
+
+ This type is found in Oracle and Informix.
+ """
+
+ __visit_name__ = "CLOB"
+
+
+class VARCHAR(String):
+
+ """The SQL VARCHAR type."""
+
+ __visit_name__ = "VARCHAR"
+
+
+class NVARCHAR(Unicode):
+
+ """The SQL NVARCHAR type."""
+
+ __visit_name__ = "NVARCHAR"
+
+
+class CHAR(String):
+
+ """The SQL CHAR type."""
+
+ __visit_name__ = "CHAR"
+
+
+class NCHAR(Unicode):
+
+ """The SQL NCHAR type."""
+
+ __visit_name__ = "NCHAR"
+
+
+class BLOB(LargeBinary):
+
+ """The SQL BLOB type."""
+
+ __visit_name__ = "BLOB"
+
+
+class BINARY(_Binary):
+
+ """The SQL BINARY type."""
+
+ __visit_name__ = "BINARY"
+
+
+class VARBINARY(_Binary):
+
+ """The SQL VARBINARY type."""
+
+ __visit_name__ = "VARBINARY"
+
+
+class BOOLEAN(Boolean):
+
+ """The SQL BOOLEAN type."""
+
+ __visit_name__ = "BOOLEAN"
+
+
+class NullType(TypeEngine):
+
+ """An unknown type.
+
+ :class:`.NullType` is used as a default type for those cases where
+ a type cannot be determined, including:
+
+ * During table reflection, when the type of a column is not recognized
+ by the :class:`.Dialect`
+ * When constructing SQL expressions using plain Python objects of
+ unknown types (e.g. ``somecolumn == my_special_object``)
+ * When a new :class:`_schema.Column` is created,
+ and the given type is passed
+ as ``None`` or is not passed at all.
+
+ The :class:`.NullType` can be used within SQL expression invocation
+    without issue; it simply has no behavior either at the expression
+ construction level or at the bind-parameter/result processing level.
+ :class:`.NullType` will result in a :exc:`.CompileError` if the compiler
+ is asked to render the type itself, such as if it is used in a
+ :func:`.cast` operation or within a schema creation operation such as that
+ invoked by :meth:`_schema.MetaData.create_all` or the
+ :class:`.CreateTable`
+ construct.
+
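+    E.g., an illustrative check (purely for demonstration)::
+
+        from sqlalchemy import literal
+
+        expr = literal(object())  # no type mapping exists for plain object()
+        # expr.type is an instance of NullType
+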
+ """
+
+ __visit_name__ = "null"
+
+ _isnull = True
+
+ def literal_processor(self, dialect):
+ def process(value):
+ raise exc.CompileError(
+ "Don't know how to render literal SQL value: %r" % (value,)
+ )
+
+ return process
+
+ class Comparator(TypeEngine.Comparator):
+ def _adapt_expression(self, op, other_comparator):
+ if isinstance(
+ other_comparator, NullType.Comparator
+ ) or not operators.is_commutative(op):
+ return op, self.expr.type
+ else:
+ return other_comparator._adapt_expression(op, self)
+
+ comparator_factory = Comparator
+
+
+class TableValueType(HasCacheKey, TypeEngine):
+ """Refers to a table value type."""
+
+ _is_table_value = True
+
+ _traverse_internals = [
+ ("_elements", InternalTraversal.dp_clauseelement_list),
+ ]
+
+ def __init__(self, *elements):
+ self._elements = [
+ coercions.expect(roles.StrAsPlainColumnRole, elem)
+ for elem in elements
+ ]
+
+
+class MatchType(Boolean):
+ """Refers to the return type of the MATCH operator.
+
+ As the :meth:`.ColumnOperators.match` is probably the most open-ended
+ operator in generic SQLAlchemy Core, we can't assume the return type
+ at SQL evaluation time, as MySQL returns a floating point, not a boolean,
+ and other backends might do something different. So this type
+ acts as a placeholder, currently subclassing :class:`.Boolean`.
+ The type allows dialects to inject result-processing functionality
+ if needed, and on MySQL will return floating-point values.
+
+ .. versionadded:: 1.0.0
+
+ """
+
+
+NULLTYPE = NullType()
+BOOLEANTYPE = Boolean()
+STRINGTYPE = String()
+INTEGERTYPE = Integer()
+NUMERICTYPE = Numeric()
+MATCHTYPE = MatchType()
+TABLEVALUE = TableValueType()
+DATETIME_TIMEZONE = DateTime(timezone=True)
+TIME_TIMEZONE = Time(timezone=True)
+
+_type_map = {
+ int: Integer(),
+ float: Float(),
+ bool: BOOLEANTYPE,
+ decimal.Decimal: Numeric(),
+ dt.date: Date(),
+ dt.datetime: DateTime(),
+ dt.time: Time(),
+ dt.timedelta: Interval(),
+ util.NoneType: NULLTYPE,
+}
+
+if util.py3k:
+ _type_map[bytes] = LargeBinary() # noqa
+ _type_map[str] = Unicode()
+else:
+ _type_map[unicode] = Unicode() # noqa
+ _type_map[str] = String()
+
+
+_type_map_get = _type_map.get
+
+
+def _resolve_value_to_type(value):
+ _result_type = _type_map_get(type(value), False)
+ if _result_type is False:
+ # use inspect() to detect SQLAlchemy built-in
+ # objects.
+ insp = inspection.inspect(value, False)
+ if (
+ insp is not None
+ and
+ # foil mock.Mock() and other impostors by ensuring
+ # the inspection target itself self-inspects
+ insp.__class__ in inspection._registrars
+ ):
+ raise exc.ArgumentError(
+ "Object %r is not legal as a SQL literal value" % (value,)
+ )
+ return NULLTYPE
+ else:
+ return _result_type._resolve_for_literal(value)
+
+
+# back-assign to type_api
+type_api.BOOLEANTYPE = BOOLEANTYPE
+type_api.STRINGTYPE = STRINGTYPE
+type_api.INTEGERTYPE = INTEGERTYPE
+type_api.NULLTYPE = NULLTYPE
+type_api.NUMERICTYPE = NUMERICTYPE
+type_api.MATCHTYPE = MATCHTYPE
+type_api.INDEXABLE = Indexable
+type_api.TABLEVALUE = TABLEVALUE
+type_api._resolve_value_to_type = _resolve_value_to_type
+TypeEngine.Comparator.BOOLEANTYPE = BOOLEANTYPE
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
new file mode 100644
index 0000000..9da61ab
--- /dev/null
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -0,0 +1,1559 @@
+from collections import deque
+from collections import namedtuple
+import itertools
+import operator
+
+from . import operators
+from .visitors import ExtendedInternalTraversal
+from .visitors import InternalTraversal
+from .. import util
+from ..inspection import inspect
+from ..util import collections_abc
+from ..util import HasMemoized
+from ..util import py37
+
+SKIP_TRAVERSE = util.symbol("skip_traverse")
+COMPARE_FAILED = False
+COMPARE_SUCCEEDED = True
+NO_CACHE = util.symbol("no_cache")
+CACHE_IN_PLACE = util.symbol("cache_in_place")
+CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key")
+STATIC_CACHE_KEY = util.symbol("static_cache_key")
+PROPAGATE_ATTRS = util.symbol("propagate_attrs")
+ANON_NAME = util.symbol("anon_name")
+
+
+def compare(obj1, obj2, **kw):
+ if kw.get("use_proxies", False):
+ strategy = ColIdentityComparatorStrategy()
+ else:
+ strategy = TraversalComparatorStrategy()
+
+ return strategy.compare(obj1, obj2, **kw)
+
+
+def _preconfigure_traversals(target_hierarchy):
+ for cls in util.walk_subclasses(target_hierarchy):
+ if hasattr(cls, "_traverse_internals"):
+ cls._generate_cache_attrs()
+ _copy_internals.generate_dispatch(
+ cls,
+ cls._traverse_internals,
+ "_generated_copy_internals_traversal",
+ )
+ _get_children.generate_dispatch(
+ cls,
+ cls._traverse_internals,
+ "_generated_get_children_traversal",
+ )
+
+
+class HasCacheKey(object):
+ """Mixin for objects which can produce a cache key.
+
+ .. seealso::
+
+ :class:`.CacheKey`
+
+ :ref:`sql_caching`
+
+ """
+
+ _cache_key_traversal = NO_CACHE
+
+ _is_has_cache_key = True
+
+ _hierarchy_supports_caching = True
+ """private attribute which may be set to False to prevent the
+ inherit_cache warning from being emitted for a hierarchy of subclasses.
+
+ Currently applies to the DDLElement hierarchy which does not implement
+ caching.
+
+ """
+
+ inherit_cache = None
+ """Indicate if this :class:`.HasCacheKey` instance should make use of the
+ cache key generation scheme used by its immediate superclass.
+
+ The attribute defaults to ``None``, which indicates that a construct has
+    not yet taken into account whether or not it is appropriate for it to
+ participate in caching; this is functionally equivalent to setting the
+ value to ``False``, except that a warning is also emitted.
+
+ This flag can be set to ``True`` on a particular class, if the SQL that
+ corresponds to the object does not change based on attributes which
+ are local to this class, and not its superclass.
+
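+    E.g., an illustrative sketch for a user-defined construct (the class
+    shown is hypothetical)::
+
+        from sqlalchemy.sql.functions import GenericFunction
+
+        class as_utc(GenericFunction):
+            # the SQL produced for this construct does not depend on any
+            # attributes local to this subclass, so the superclass cache
+            # key scheme can be reused safely
+            inherit_cache = True
+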
+ .. seealso::
+
+        :ref:`compilerext_caching` - General guidelines for setting the
+ :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user
+ defined SQL constructs.
+
+ """
+
+ __slots__ = ()
+
+ @classmethod
+ def _generate_cache_attrs(cls):
+ """generate cache key dispatcher for a new class.
+
+        This sets the _generated_cache_key_traversal attribute once called,
+        so it should only be called once per class.
+
+ """
+ inherit_cache = cls.__dict__.get("inherit_cache", None)
+ inherit = bool(inherit_cache)
+
+ if inherit:
+ _cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
+ if _cache_key_traversal is None:
+ try:
+ _cache_key_traversal = cls._traverse_internals
+ except AttributeError:
+ cls._generated_cache_key_traversal = NO_CACHE
+ return NO_CACHE
+
+ # TODO: wouldn't we instead get this from our superclass?
+ # also, our superclass may not have this yet, but in any case,
+ # we'd generate for the superclass that has it. this is a little
+ # more complicated, so for the moment this is a little less
+ # efficient on startup but simpler.
+ return _cache_key_traversal_visitor.generate_dispatch(
+ cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ )
+ else:
+ _cache_key_traversal = cls.__dict__.get(
+ "_cache_key_traversal", None
+ )
+ if _cache_key_traversal is None:
+ _cache_key_traversal = cls.__dict__.get(
+ "_traverse_internals", None
+ )
+ if _cache_key_traversal is None:
+ cls._generated_cache_key_traversal = NO_CACHE
+ if (
+ inherit_cache is None
+ and cls._hierarchy_supports_caching
+ ):
+ util.warn(
+ "Class %s will not make use of SQL compilation "
+ "caching as it does not set the 'inherit_cache' "
+ "attribute to ``True``. This can have "
+ "significant performance implications including "
+ "some performance degradations in comparison to "
+ "prior SQLAlchemy versions. Set this attribute "
+ "to True if this object can make use of the cache "
+ "key generated by the superclass. Alternatively, "
+ "this attribute may be set to False which will "
+ "disable this warning." % (cls.__name__),
+ code="cprf",
+ )
+ return NO_CACHE
+
+ return _cache_key_traversal_visitor.generate_dispatch(
+ cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ )
+
+ @util.preload_module("sqlalchemy.sql.elements")
+ def _gen_cache_key(self, anon_map, bindparams):
+ """return an optional cache key.
+
+ The cache key is a tuple which can contain any series of
+ objects that are hashable and also identifies
+ this object uniquely within the presence of a larger SQL expression
+ or statement, for the purposes of caching the resulting query.
+
+ The cache key should be based on the SQL compiled structure that would
+ ultimately be produced. That is, two structures that are composed in
+ exactly the same way should produce the same cache key; any difference
+ in the structures that would affect the SQL string or the type handlers
+ should result in a different cache key.
+
+ If a structure cannot produce a useful cache key, the NO_CACHE
+ symbol should be added to the anon_map and the method should
+ return None.
+
+ """
+
+ idself = id(self)
+ cls = self.__class__
+
+ if idself in anon_map:
+ return (anon_map[idself], cls)
+ else:
+ # inline of
+ # id_ = anon_map[idself]
+ anon_map[idself] = id_ = str(anon_map.index)
+ anon_map.index += 1
+
+ try:
+ dispatcher = cls.__dict__["_generated_cache_key_traversal"]
+ except KeyError:
+ # most of the dispatchers are generated up front
+ # in sqlalchemy/sql/__init__.py ->
+ # traversals.py-> _preconfigure_traversals().
+ # this block will generate any remaining dispatchers.
+ dispatcher = cls._generate_cache_attrs()
+
+ if dispatcher is NO_CACHE:
+ anon_map[NO_CACHE] = True
+ return None
+
+ result = (id_, cls)
+
+ # inline of _cache_key_traversal_visitor.run_generated_dispatch()
+
+ for attrname, obj, meth in dispatcher(
+ self, _cache_key_traversal_visitor
+ ):
+ if obj is not None:
+ # TODO: see if C code can help here as Python lacks an
+ # efficient switch construct
+
+ if meth is STATIC_CACHE_KEY:
+ sck = obj._static_cache_key
+ if sck is NO_CACHE:
+ anon_map[NO_CACHE] = True
+ return None
+ result += (attrname, sck)
+ elif meth is ANON_NAME:
+ elements = util.preloaded.sql_elements
+ if isinstance(obj, elements._anonymous_label):
+ obj = obj.apply_map(anon_map)
+ result += (attrname, obj)
+ elif meth is CALL_GEN_CACHE_KEY:
+ result += (
+ attrname,
+ obj._gen_cache_key(anon_map, bindparams),
+ )
+
+ # remaining cache functions are against
+ # Python tuples, dicts, lists, etc. so we can skip
+ # if they are empty
+ elif obj:
+ if meth is CACHE_IN_PLACE:
+ result += (attrname, obj)
+ elif meth is PROPAGATE_ATTRS:
+ result += (
+ attrname,
+ obj["compile_state_plugin"],
+ obj["plugin_subject"]._gen_cache_key(
+ anon_map, bindparams
+ )
+ if obj["plugin_subject"]
+ else None,
+ )
+ elif meth is InternalTraversal.dp_annotations_key:
+ # obj is here is the _annotations dict. however, we
+ # want to use the memoized cache key version of it. for
+ # Columns, this should be long lived. For select()
+ # statements, not so much, but they usually won't have
+ # annotations.
+ result += self._annotations_cache_key
+ elif (
+ meth is InternalTraversal.dp_clauseelement_list
+ or meth is InternalTraversal.dp_clauseelement_tuple
+ or meth
+ is InternalTraversal.dp_memoized_select_entities
+ ):
+ result += (
+ attrname,
+ tuple(
+ [
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in obj
+ ]
+ ),
+ )
+ else:
+ result += meth(
+ attrname, obj, self, anon_map, bindparams
+ )
+ return result
+
+ def _generate_cache_key(self):
+ """return a cache key.
+
+ The cache key is a tuple which can contain any series of
+ objects that are hashable and also identifies
+ this object uniquely within the presence of a larger SQL expression
+ or statement, for the purposes of caching the resulting query.
+
+ The cache key should be based on the SQL compiled structure that would
+ ultimately be produced. That is, two structures that are composed in
+ exactly the same way should produce the same cache key; any difference
+ in the structures that would affect the SQL string or the type handlers
+ should result in a different cache key.
+
+ The cache key returned by this method is an instance of
+ :class:`.CacheKey`, which consists of a tuple representing the
+ cache key, as well as a list of :class:`.BindParameter` objects
+ which are extracted from the expression. While two expressions
+ that produce identical cache key tuples will themselves generate
+ identical SQL strings, the list of :class:`.BindParameter` objects
+ indicates the bound values which may have different values in
+ each one; these bound parameters must be consulted in order to
+ execute the statement with the correct parameters.
+
+ a :class:`_expression.ClauseElement` structure that does not implement
+ a :meth:`._gen_cache_key` method and does not implement a
+ :attr:`.traverse_internals` attribute will not be cacheable; when
+ such an element is embedded into a larger structure, this method
+ will return None, indicating no cache key is available.
+
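+        E.g., an illustrative sketch (the statement shown is arbitrary)::
+
+            from sqlalchemy import column, select
+
+            stmt = select(column("q"))
+            cache_key = stmt._generate_cache_key()  # CacheKey, or None
+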
+ """
+
+ bindparams = []
+
+ _anon_map = anon_map()
+ key = self._gen_cache_key(_anon_map, bindparams)
+ if NO_CACHE in _anon_map:
+ return None
+ else:
+ return CacheKey(key, bindparams)
+
+ @classmethod
+ def _generate_cache_key_for_object(cls, obj):
+ bindparams = []
+
+ _anon_map = anon_map()
+ key = obj._gen_cache_key(_anon_map, bindparams)
+ if NO_CACHE in _anon_map:
+ return None
+ else:
+ return CacheKey(key, bindparams)
+
+
+class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
+ @HasMemoized.memoized_instancemethod
+ def _generate_cache_key(self):
+ return HasCacheKey._generate_cache_key(self)
+
+
+class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
+ """The key used to identify a SQL statement construct in the
+ SQL compilation cache.
+
+ .. seealso::
+
+ :ref:`sql_caching`
+
+ """
+
+ def __hash__(self):
+ """CacheKey itself is not hashable - hash the .key portion"""
+
+ return None
+
+ def to_offline_string(self, statement_cache, statement, parameters):
+ """Generate an "offline string" form of this :class:`.CacheKey`
+
+ The "offline string" is basically the string SQL for the
+ statement plus a repr of the bound parameter values in series.
+ Whereas the :class:`.CacheKey` object is dependent on in-memory
+ identities in order to work as a cache key, the "offline" version
+ is suitable for a cache that will work for other processes as well.
+
+ The given ``statement_cache`` is a dictionary-like object where the
+ string form of the statement itself will be cached. This dictionary
+ should be in a longer lived scope in order to reduce the time spent
+ stringifying statements.
+
+
+ """
+ if self.key not in statement_cache:
+ statement_cache[self.key] = sql_str = str(statement)
+ else:
+ sql_str = statement_cache[self.key]
+
+ if not self.bindparams:
+ param_tuple = tuple(parameters[key] for key in sorted(parameters))
+ else:
+ param_tuple = tuple(
+ parameters.get(bindparam.key, bindparam.value)
+ for bindparam in self.bindparams
+ )
+
+ return repr((sql_str, param_tuple))
+
+ def __eq__(self, other):
+ return self.key == other.key
+
+ @classmethod
+ def _diff_tuples(cls, left, right):
+ ck1 = CacheKey(left, [])
+ ck2 = CacheKey(right, [])
+ return ck1._diff(ck2)
+
+ def _whats_different(self, other):
+
+ k1 = self.key
+ k2 = other.key
+
+ stack = []
+ pickup_index = 0
+ while True:
+ s1, s2 = k1, k2
+ for idx in stack:
+ s1 = s1[idx]
+ s2 = s2[idx]
+
+ for idx, (e1, e2) in enumerate(util.zip_longest(s1, s2)):
+ if idx < pickup_index:
+ continue
+ if e1 != e2:
+ if isinstance(e1, tuple) and isinstance(e2, tuple):
+ stack.append(idx)
+ break
+ else:
+ yield "key%s[%d]: %s != %s" % (
+ "".join("[%d]" % id_ for id_ in stack),
+ idx,
+ e1,
+ e2,
+ )
+ else:
+ pickup_index = stack.pop(-1)
+ break
+
+ def _diff(self, other):
+ return ", ".join(self._whats_different(other))
+
+ def __str__(self):
+ stack = [self.key]
+
+ output = []
+ sentinel = object()
+ indent = -1
+ while stack:
+ elem = stack.pop(0)
+ if elem is sentinel:
+ output.append((" " * (indent * 2)) + "),")
+ indent -= 1
+ elif isinstance(elem, tuple):
+ if not elem:
+ output.append((" " * ((indent + 1) * 2)) + "()")
+ else:
+ indent += 1
+ stack = list(elem) + [sentinel] + stack
+ output.append((" " * (indent * 2)) + "(")
+ else:
+ if isinstance(elem, HasCacheKey):
+ repr_ = "<%s object at %s>" % (
+ type(elem).__name__,
+ hex(id(elem)),
+ )
+ else:
+ repr_ = repr(elem)
+ output.append((" " * (indent * 2)) + " " + repr_ + ", ")
+
+ return "CacheKey(key=%s)" % ("\n".join(output),)
+
+ def _generate_param_dict(self):
+ """used for testing"""
+
+ from .compiler import prefix_anon_map
+
+ _anon_map = prefix_anon_map()
+ return {b.key % _anon_map: b.effective_value for b in self.bindparams}
+
+ def _apply_params_to_element(self, original_cache_key, target_element):
+ translate = {
+ k.key: v.value
+ for k, v in zip(original_cache_key.bindparams, self.bindparams)
+ }
+
+ return target_element.params(translate)
+
+
+def _clone(element, **kw):
+ return element._clone()
+
+
+class _CacheKey(ExtendedInternalTraversal):
+ # very common elements are inlined into the main _get_cache_key() method
+ # to produce a dramatic savings in Python function call overhead
+
+ visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY
+ visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
+ visit_annotations_key = InternalTraversal.dp_annotations_key
+ visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
+ visit_memoized_select_entities = (
+ InternalTraversal.dp_memoized_select_entities
+ )
+
+ visit_string = (
+ visit_boolean
+ ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE
+ visit_statement_hint_list = CACHE_IN_PLACE
+ visit_type = STATIC_CACHE_KEY
+ visit_anon_name = ANON_NAME
+
+ visit_propagate_attrs = PROPAGATE_ATTRS
+
+ def visit_with_context_options(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return tuple((fn.__code__, c_key) for fn, c_key in obj)
+
+ def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
+
+ def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
+ return tuple(obj)
+
+ def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
+ return (
+ attrname,
+ obj._gen_cache_key(anon_map, bindparams)
+ if isinstance(obj, HasCacheKey)
+ else obj,
+ )
+
+ def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
+ return (
+ attrname,
+ tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ if isinstance(elem, HasCacheKey)
+ else elem
+ for elem in obj
+ ),
+ )
+
+ def visit_has_cache_key_tuples(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ if not obj:
+ return ()
+ return (
+ attrname,
+ tuple(
+ tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in tup_elem
+ )
+ for tup_elem in obj
+ ),
+ )
+
+ def visit_has_cache_key_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ if not obj:
+ return ()
+ return (
+ attrname,
+ tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+ )
+
+ def visit_executable_options(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ if not obj:
+ return ()
+ return (
+ attrname,
+ tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in obj
+ if elem._is_has_cache_key
+ ),
+ )
+
+ def visit_inspectable_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return self.visit_has_cache_key_list(
+ attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
+ )
+
+ def visit_clauseelement_tuples(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return self.visit_has_cache_key_tuples(
+ attrname, obj, parent, anon_map, bindparams
+ )
+
+ def visit_fromclause_ordered_set(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ if not obj:
+ return ()
+ return (
+ attrname,
+ tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]),
+ )
+
+ def visit_clauseelement_unordered_set(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ if not obj:
+ return ()
+ cache_keys = [
+ elem._gen_cache_key(anon_map, bindparams) for elem in obj
+ ]
+ return (
+ attrname,
+ tuple(
+ sorted(cache_keys)
+ ), # cache keys all start with (id_, class)
+ )
+
+ def visit_named_ddl_element(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (attrname, obj.name)
+
+ def visit_prefix_sequence(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ if not obj:
+ return ()
+
+ return (
+ attrname,
+ tuple(
+ [
+ (clause._gen_cache_key(anon_map, bindparams), strval)
+ for clause, strval in obj
+ ]
+ ),
+ )
+
+ def visit_setup_join_tuple(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ is_legacy = "legacy" in attrname
+
+ return tuple(
+ (
+ target
+ if is_legacy and isinstance(target, str)
+ else target._gen_cache_key(anon_map, bindparams),
+ onclause
+ if is_legacy and isinstance(onclause, str)
+ else onclause._gen_cache_key(anon_map, bindparams)
+ if onclause is not None
+ else None,
+ from_._gen_cache_key(anon_map, bindparams)
+ if from_ is not None
+ else None,
+ tuple([(key, flags[key]) for key in sorted(flags)]),
+ )
+ for (target, onclause, from_, flags) in obj
+ )
+
+ def visit_table_hint_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ if not obj:
+ return ()
+
+ return (
+ attrname,
+ tuple(
+ [
+ (
+ clause._gen_cache_key(anon_map, bindparams),
+ dialect_name,
+ text,
+ )
+ for (clause, dialect_name), text in obj.items()
+ ]
+ ),
+ )
+
+ def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, tuple([(key, obj[key]) for key in sorted(obj)]))
+
+ def visit_dialect_options(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ dialect_name,
+ tuple(
+ [
+ (key, obj[dialect_name][key])
+ for key in sorted(obj[dialect_name])
+ ]
+ ),
+ )
+ for dialect_name in sorted(obj)
+ ),
+ )
+
+ def visit_string_clauseelement_dict(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (key, obj[key]._gen_cache_key(anon_map, bindparams))
+ for key in sorted(obj)
+ ),
+ )
+
+ def visit_string_multi_dict(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ key,
+ value._gen_cache_key(anon_map, bindparams)
+ if isinstance(value, HasCacheKey)
+ else value,
+ )
+ for key, value in [(key, obj[key]) for key in sorted(obj)]
+ ),
+ )
+
+ def visit_fromclause_canonical_column_collection(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ # inlining into the internals of ColumnCollection
+ return (
+ attrname,
+ tuple(
+ col._gen_cache_key(anon_map, bindparams)
+ for k, col in obj._collection
+ ),
+ )
+
+ def visit_unknown_structure(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ anon_map[NO_CACHE] = True
+ return ()
+
+ def visit_dml_ordered_values(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ key._gen_cache_key(anon_map, bindparams)
+ if hasattr(key, "__clause_element__")
+ else key,
+ value._gen_cache_key(anon_map, bindparams),
+ )
+ for key, value in obj
+ ),
+ )
+
+ def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
+ if py37:
+            # in py37 and above, dictionaries preserve insertion order, so
+            # two dictionaries built with the same insert ordering will
+            # produce the same key ordering here
+ return (
+ attrname,
+ tuple(
+ (
+ k._gen_cache_key(anon_map, bindparams)
+ if hasattr(k, "__clause_element__")
+ else k,
+ obj[k]._gen_cache_key(anon_map, bindparams),
+ )
+ for k in obj
+ ),
+ )
+ else:
+ expr_values = {k for k in obj if hasattr(k, "__clause_element__")}
+ if expr_values:
+ # expr values can't be sorted deterministically right now,
+ # so no cache
+ anon_map[NO_CACHE] = True
+ return ()
+
+ str_values = expr_values.symmetric_difference(obj)
+
+ return (
+ attrname,
+ tuple(
+ (k, obj[k]._gen_cache_key(anon_map, bindparams))
+ for k in sorted(str_values)
+ ),
+ )
+
+ def visit_dml_multi_values(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ # multivalues are simply not cacheable right now
+ anon_map[NO_CACHE] = True
+ return ()
+
+
+_cache_key_traversal_visitor = _CacheKey()
+
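+# Illustrative sketch (editorial note, not part of the library source):
+# the visitor above produces the per-attribute tuples that make up a
+# statement's cache key.  Assuming the _generate_cache_key() entrypoint and
+# 1.4-style CacheKey equality, two structurally equivalent constructs compare
+# equal, while their literal values are carried only in the extracted bound
+# parameters::
+#
+#     from sqlalchemy import column, select
+#
+#     k1 = select(column("a")).where(column("b") == 1)._generate_cache_key()
+#     k2 = select(column("a")).where(column("b") == 2)._generate_cache_key()
+#
+#     assert k1 == k2   # same structure -> equal keys; the 1 vs 2 values
+#                       # live in k1.bindparams / k2.bindparams, not the key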
+
+class HasCopyInternals(object):
+ def _clone(self, **kw):
+ raise NotImplementedError()
+
+ def _copy_internals(self, omit_attrs=(), **kw):
+ """Reassign internal elements to be clones of themselves.
+
+ Called during a copy-and-traverse operation on newly
+ shallow-copied elements to create a deep copy.
+
+        The given clone function should be used, which may apply
+        additional transformations to the element (e.g. replacement
+        traversal, cloned traversal, annotations).
+
+ """
+
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ # user-defined classes may not have a _traverse_internals
+ return
+
+ for attrname, obj, meth in _copy_internals.run_generated_dispatch(
+ self, traverse_internals, "_generated_copy_internals_traversal"
+ ):
+ if attrname in omit_attrs:
+ continue
+
+ if obj is not None:
+ result = meth(attrname, self, obj, **kw)
+ if result is not None:
+ setattr(self, attrname, result)
+
+
+class _CopyInternals(InternalTraversal):
+ """Generate a _copy_internals internal traversal dispatch for classes
+ with a _traverse_internals collection."""
+
+ def visit_clauseelement(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return clone(element, **kw)
+
+ def visit_clauseelement_list(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return [clone(clause, **kw) for clause in element]
+
+ def visit_clauseelement_tuple(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return tuple([clone(clause, **kw) for clause in element])
+
+ def visit_executable_options(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return tuple([clone(clause, **kw) for clause in element])
+
+ def visit_clauseelement_unordered_set(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return {clone(clause, **kw) for clause in element}
+
+ def visit_clauseelement_tuples(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return [
+ tuple(clone(tup_elem, **kw) for tup_elem in elem)
+ for elem in element
+ ]
+
+ def visit_string_clauseelement_dict(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return dict(
+ (key, clone(value, **kw)) for key, value in element.items()
+ )
+
+ def visit_setup_join_tuple(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return tuple(
+ (
+ clone(target, **kw) if target is not None else None,
+ clone(onclause, **kw) if onclause is not None else None,
+ clone(from_, **kw) if from_ is not None else None,
+ flags,
+ )
+ for (target, onclause, from_, flags) in element
+ )
+
+ def visit_memoized_select_entities(self, attrname, parent, element, **kw):
+ return self.visit_clauseelement_tuple(attrname, parent, element, **kw)
+
+ def visit_dml_ordered_values(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ # sequence of 2-tuples
+ return [
+ (
+ clone(key, **kw)
+ if hasattr(key, "__clause_element__")
+ else key,
+ clone(value, **kw),
+ )
+ for key, value in element
+ ]
+
+ def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw):
+ return {
+ (
+ clone(key, **kw) if hasattr(key, "__clause_element__") else key
+ ): clone(value, **kw)
+ for key, value in element.items()
+ }
+
+ def visit_dml_multi_values(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ # sequence of sequences, each sequence contains a list/dict/tuple
+
+ def copy(elem):
+ if isinstance(elem, (list, tuple)):
+ return [
+ clone(value, **kw)
+ if hasattr(value, "__clause_element__")
+ else value
+ for value in elem
+ ]
+ elif isinstance(elem, dict):
+ return {
+ (
+ clone(key, **kw)
+ if hasattr(key, "__clause_element__")
+ else key
+ ): (
+ clone(value, **kw)
+ if hasattr(value, "__clause_element__")
+ else value
+ )
+ for key, value in elem.items()
+ }
+ else:
+ # TODO: use abc classes
+ assert False
+
+ return [
+ [copy(sub_element) for sub_element in sequence]
+ for sequence in element
+ ]
+
+ def visit_propagate_attrs(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return element
+
+
+_copy_internals = _CopyInternals()
+
+
+def _flatten_clauseelement(element):
+ while hasattr(element, "__clause_element__") and not getattr(
+ element, "is_clause_element", False
+ ):
+ element = element.__clause_element__()
+
+ return element
+
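+# Illustrative sketch (editorial note, not library source): any object that
+# exposes __clause_element__() without itself being a ClauseElement is
+# unwrapped until a real clause element is reached.  "Wrapper" below is a
+# hypothetical stand-in for ORM-style proxies::
+#
+#     from sqlalchemy import column
+#
+#     class Wrapper(object):
+#         def __init__(self, inner):
+#             self._inner = inner
+#
+#         def __clause_element__(self):
+#             return self._inner
+#
+#     col = column("x")
+#     assert _flatten_clauseelement(Wrapper(col)) is col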
+
+class _GetChildren(InternalTraversal):
+ """Generate a _children_traversal internal traversal dispatch for classes
+ with a _traverse_internals collection."""
+
+ def visit_has_cache_key(self, element, **kw):
+ # the GetChildren traversal refers explicitly to ClauseElement
+ # structures. Within these, a plain HasCacheKey is not a
+ # ClauseElement, so don't include these.
+ return ()
+
+ def visit_clauseelement(self, element, **kw):
+ return (element,)
+
+ def visit_clauseelement_list(self, element, **kw):
+ return element
+
+ def visit_clauseelement_tuple(self, element, **kw):
+ return element
+
+ def visit_clauseelement_tuples(self, element, **kw):
+ return itertools.chain.from_iterable(element)
+
+ def visit_fromclause_canonical_column_collection(self, element, **kw):
+ return ()
+
+ def visit_string_clauseelement_dict(self, element, **kw):
+ return element.values()
+
+ def visit_fromclause_ordered_set(self, element, **kw):
+ return element
+
+ def visit_clauseelement_unordered_set(self, element, **kw):
+ return element
+
+ def visit_setup_join_tuple(self, element, **kw):
+ for (target, onclause, from_, flags) in element:
+ if from_ is not None:
+ yield from_
+
+ if not isinstance(target, str):
+ yield _flatten_clauseelement(target)
+
+ if onclause is not None and not isinstance(onclause, str):
+ yield _flatten_clauseelement(onclause)
+
+ def visit_memoized_select_entities(self, element, **kw):
+ return self.visit_clauseelement_tuple(element, **kw)
+
+ def visit_dml_ordered_values(self, element, **kw):
+ for k, v in element:
+ if hasattr(k, "__clause_element__"):
+ yield k
+ yield v
+
+ def visit_dml_values(self, element, **kw):
+ expr_values = {k for k in element if hasattr(k, "__clause_element__")}
+ str_values = expr_values.symmetric_difference(element)
+
+ for k in sorted(str_values):
+ yield element[k]
+ for k in expr_values:
+ yield k
+ yield element[k]
+
+ def visit_dml_multi_values(self, element, **kw):
+ return ()
+
+ def visit_propagate_attrs(self, element, **kw):
+ return ()
+
+
+_get_children = _GetChildren()
+
+
+@util.preload_module("sqlalchemy.sql.elements")
+def _resolve_name_for_compare(element, name, anon_map, **kw):
+ if isinstance(name, util.preloaded.sql_elements._anonymous_label):
+ name = name.apply_map(anon_map)
+
+ return name
+
+
+class anon_map(dict):
+ """A map that creates new keys for missing key access.
+
+ Produces an incrementing sequence given a series of unique keys.
+
+ This is similar to the compiler prefix_anon_map class although simpler.
+
+ Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
+ is otherwise usually used for this type of operation.
+
+ """
+
+ def __init__(self):
+ self.index = 0
+
+ def __missing__(self, key):
+ self[key] = val = str(self.index)
+ self.index += 1
+ return val
+
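+# Illustrative sketch (editorial note, not library source): anon_map hands
+# out an incrementing string token the first time a key is seen and returns
+# the same token for repeated lookups, which is what normalizes anonymous
+# names when cache keys are generated or compared::
+#
+#     m = anon_map()
+#     assert m["first"] == "0"    # unseen key -> next index, as a string
+#     assert m["second"] == "1"
+#     assert m["first"] == "0"    # repeated key -> the token assigned before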
+
+class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
+ __slots__ = "stack", "cache", "anon_map"
+
+ def __init__(self):
+ self.stack = deque()
+ self.cache = set()
+
+ def _memoized_attr_anon_map(self):
+ return (anon_map(), anon_map())
+
+ def compare(self, obj1, obj2, **kw):
+ stack = self.stack
+ cache = self.cache
+
+ compare_annotations = kw.get("compare_annotations", False)
+
+ stack.append((obj1, obj2))
+
+ while stack:
+ left, right = stack.popleft()
+
+ if left is right:
+ continue
+ elif left is None or right is None:
+ # we know they are different so no match
+ return False
+ elif (left, right) in cache:
+ continue
+ cache.add((left, right))
+
+ visit_name = left.__visit_name__
+ if visit_name != right.__visit_name__:
+ return False
+
+ meth = getattr(self, "compare_%s" % visit_name, None)
+
+ if meth:
+ attributes_compared = meth(left, right, **kw)
+ if attributes_compared is COMPARE_FAILED:
+ return False
+ elif attributes_compared is SKIP_TRAVERSE:
+ continue
+
+ # attributes_compared is returned as a list of attribute
+ # names that were "handled" by the comparison method above.
+ # remaining attribute names in the _traverse_internals
+ # will be compared.
+ else:
+ attributes_compared = ()
+
+ for (
+ (left_attrname, left_visit_sym),
+ (right_attrname, right_visit_sym),
+ ) in util.zip_longest(
+ left._traverse_internals,
+ right._traverse_internals,
+ fillvalue=(None, None),
+ ):
+ if not compare_annotations and (
+ (left_attrname == "_annotations")
+ or (right_attrname == "_annotations")
+ ):
+ continue
+
+ if (
+ left_attrname != right_attrname
+ or left_visit_sym is not right_visit_sym
+ ):
+ return False
+ elif left_attrname in attributes_compared:
+ continue
+
+ dispatch = self.dispatch(left_visit_sym)
+ left_child = operator.attrgetter(left_attrname)(left)
+ right_child = operator.attrgetter(right_attrname)(right)
+ if left_child is None:
+ if right_child is not None:
+ return False
+ else:
+ continue
+
+ comparison = dispatch(
+ left_attrname, left, left_child, right, right_child, **kw
+ )
+ if comparison is COMPARE_FAILED:
+ return False
+
+ return True
+
+ def compare_inner(self, obj1, obj2, **kw):
+ comparator = self.__class__()
+ return comparator.compare(obj1, obj2, **kw)
+
+ def visit_has_cache_key(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
+ self.anon_map[1], []
+ ):
+ return COMPARE_FAILED
+
+ def visit_propagate_attrs(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return self.compare_inner(
+ left.get("plugin_subject", None), right.get("plugin_subject", None)
+ )
+
+ def visit_has_cache_key_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
+ self.anon_map[1], []
+ ):
+ return COMPARE_FAILED
+
+ def visit_executable_options(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ if (
+ l._gen_cache_key(self.anon_map[0], [])
+ if l._is_has_cache_key
+ else l
+ ) != (
+ r._gen_cache_key(self.anon_map[1], [])
+ if r._is_has_cache_key
+ else r
+ ):
+ return COMPARE_FAILED
+
+ def visit_clauseelement(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ self.stack.append((left, right))
+
+ def visit_fromclause_canonical_column_collection(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for lcol, rcol in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((lcol, rcol))
+
+ def visit_fromclause_derived_column_collection(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ pass
+
+ def visit_string_clauseelement_dict(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for lstr, rstr in util.zip_longest(
+ sorted(left), sorted(right), fillvalue=None
+ ):
+ if lstr != rstr:
+ return COMPARE_FAILED
+ self.stack.append((left[lstr], right[rstr]))
+
+ def visit_clauseelement_tuples(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for ltup, rtup in util.zip_longest(left, right, fillvalue=None):
+ if ltup is None or rtup is None:
+ return COMPARE_FAILED
+
+ for l, r in util.zip_longest(ltup, rtup, fillvalue=None):
+ self.stack.append((l, r))
+
+ def visit_clauseelement_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((l, r))
+
+ def visit_clauseelement_tuple(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((l, r))
+
+ def _compare_unordered_sequences(self, seq1, seq2, **kw):
+ if seq1 is None:
+ return seq2 is None
+
+ completed = set()
+ for clause in seq1:
+ for other_clause in set(seq2).difference(completed):
+ if self.compare_inner(clause, other_clause, **kw):
+ completed.add(other_clause)
+ break
+ return len(completed) == len(seq1) == len(seq2)
+
+ def visit_clauseelement_unordered_set(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return self._compare_unordered_sequences(left, right, **kw)
+
+ def visit_fromclause_ordered_set(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((l, r))
+
+ def visit_string(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_string_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_anon_name(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return _resolve_name_for_compare(
+ left_parent, left, self.anon_map[0], **kw
+ ) == _resolve_name_for_compare(
+ right_parent, right, self.anon_map[1], **kw
+ )
+
+ def visit_boolean(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_operator(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left is right
+
+ def visit_type(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left._compare_type_affinity(right)
+
+ def visit_plain_dict(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_dialect_options(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_annotations_key(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ if left and right:
+ return (
+ left_parent._annotations_cache_key
+ == right_parent._annotations_cache_key
+ )
+ else:
+ return left == right
+
+ def visit_with_context_options(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return tuple((fn.__code__, c_key) for fn, c_key in left) == tuple(
+ (fn.__code__, c_key) for fn, c_key in right
+ )
+
+ def visit_plain_obj(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_named_ddl_element(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ if left is None:
+ if right is not None:
+ return COMPARE_FAILED
+
+ return left.name == right.name
+
+ def visit_prefix_sequence(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for (l_clause, l_str), (r_clause, r_str) in util.zip_longest(
+ left, right, fillvalue=(None, None)
+ ):
+ if l_str != r_str:
+ return COMPARE_FAILED
+ else:
+ self.stack.append((l_clause, r_clause))
+
+ def visit_setup_join_tuple(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ # TODO: look at attrname for "legacy_join" and use different structure
+ for (
+ (l_target, l_onclause, l_from, l_flags),
+ (r_target, r_onclause, r_from, r_flags),
+ ) in util.zip_longest(left, right, fillvalue=(None, None, None, None)):
+ if l_flags != r_flags:
+ return COMPARE_FAILED
+ self.stack.append((l_target, r_target))
+ self.stack.append((l_onclause, r_onclause))
+ self.stack.append((l_from, r_from))
+
+ def visit_memoized_select_entities(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return self.visit_clauseelement_tuple(
+ attrname, left_parent, left, right_parent, right, **kw
+ )
+
+ def visit_table_hint_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
+ right_keys = sorted(
+ right, key=lambda elem: (elem[0].fullname, elem[1])
+ )
+ for (ltable, ldialect), (rtable, rdialect) in util.zip_longest(
+ left_keys, right_keys, fillvalue=(None, None)
+ ):
+ if ldialect != rdialect:
+ return COMPARE_FAILED
+ elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
+ return COMPARE_FAILED
+ else:
+ self.stack.append((ltable, rtable))
+
+ def visit_statement_hint_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_unknown_structure(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ raise NotImplementedError()
+
+ def visit_dml_ordered_values(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ # sequence of tuple pairs
+
+ for (lk, lv), (rk, rv) in util.zip_longest(
+ left, right, fillvalue=(None, None)
+ ):
+ if not self._compare_dml_values_or_ce(lk, rk, **kw):
+ return COMPARE_FAILED
+
+ def _compare_dml_values_or_ce(self, lv, rv, **kw):
+ lvce = hasattr(lv, "__clause_element__")
+ rvce = hasattr(rv, "__clause_element__")
+ if lvce != rvce:
+ return False
+ elif lvce and not self.compare_inner(lv, rv, **kw):
+ return False
+ elif not lvce and lv != rv:
+ return False
+ elif not self.compare_inner(lv, rv, **kw):
+ return False
+
+ return True
+
+ def visit_dml_values(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ if left is None or right is None or len(left) != len(right):
+ return COMPARE_FAILED
+
+ if isinstance(left, collections_abc.Sequence):
+ for lv, rv in zip(left, right):
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
+ elif isinstance(right, collections_abc.Sequence):
+ return COMPARE_FAILED
+ elif py37:
+ # dictionaries guaranteed to support insert ordering in
+ # py37 so that we can compare the keys in order. without
+ # this, we can't compare SQL expression keys because we don't
+ # know which key is which
+ for (lk, lv), (rk, rv) in zip(left.items(), right.items()):
+ if not self._compare_dml_values_or_ce(lk, rk, **kw):
+ return COMPARE_FAILED
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
+ else:
+ for lk in left:
+ lv = left[lk]
+
+ if lk not in right:
+ return COMPARE_FAILED
+ rv = right[lk]
+
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
+
+ def visit_dml_multi_values(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for lseq, rseq in util.zip_longest(left, right, fillvalue=None):
+ if lseq is None or rseq is None:
+ return COMPARE_FAILED
+
+ for ld, rd in util.zip_longest(lseq, rseq, fillvalue=None):
+ if (
+ self.visit_dml_values(
+ attrname, left_parent, ld, right_parent, rd, **kw
+ )
+ is COMPARE_FAILED
+ ):
+ return COMPARE_FAILED
+
+ def compare_clauselist(self, left, right, **kw):
+ if left.operator is right.operator:
+ if operators.is_associative(left.operator):
+ if self._compare_unordered_sequences(
+ left.clauses, right.clauses, **kw
+ ):
+ return ["operator", "clauses"]
+ else:
+ return COMPARE_FAILED
+ else:
+ return ["operator"]
+ else:
+ return COMPARE_FAILED
+
+ def compare_binary(self, left, right, **kw):
+ if left.operator == right.operator:
+ if operators.is_commutative(left.operator):
+ if (
+ self.compare_inner(left.left, right.left, **kw)
+ and self.compare_inner(left.right, right.right, **kw)
+ ) or (
+ self.compare_inner(left.left, right.right, **kw)
+ and self.compare_inner(left.right, right.left, **kw)
+ ):
+ return ["operator", "negate", "left", "right"]
+ else:
+ return COMPARE_FAILED
+ else:
+ return ["operator", "negate"]
+ else:
+ return COMPARE_FAILED
+
+ def compare_bindparam(self, left, right, **kw):
+ compare_keys = kw.pop("compare_keys", True)
+ compare_values = kw.pop("compare_values", True)
+
+ if compare_values:
+ omit = []
+ else:
+ # this means, "skip these, we already compared"
+ omit = ["callable", "value"]
+
+ if not compare_keys:
+ omit.append("key")
+
+ return omit
+
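+# Illustrative sketch (editorial note, not library source): this strategy
+# backs ClauseElement.compare(); keyword arguments such as compare_values
+# are forwarded down to compare_bindparam() above::
+#
+#     from sqlalchemy import column
+#
+#     e1 = column("q") == 5
+#     e2 = column("q") == 5
+#     e3 = column("q") == 10
+#
+#     assert e1.compare(e2)                        # same structure and value
+#     assert not e1.compare(e3)                    # bound values differ
+#     assert e1.compare(e3, compare_values=False)  # structure-only comparison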
+
+class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
+ def compare_column_element(
+ self, left, right, use_proxies=True, equivalents=(), **kw
+ ):
+ """Compare ColumnElements using proxies and equivalent collections.
+
+ This is a comparison strategy specific to the ORM.
+ """
+
+ to_compare = (right,)
+ if equivalents and right in equivalents:
+ to_compare = equivalents[right].union(to_compare)
+
+ for oth in to_compare:
+ if use_proxies and left.shares_lineage(oth):
+ return SKIP_TRAVERSE
+ elif hash(left) == hash(right):
+ return SKIP_TRAVERSE
+ else:
+ return COMPARE_FAILED
+
+ def compare_column(self, left, right, **kw):
+ return self.compare_column_element(left, right, **kw)
+
+ def compare_label(self, left, right, **kw):
+ return self.compare_column_element(left, right, **kw)
+
+ def compare_table(self, left, right, **kw):
+ # tables compare on identity, since it's not really feasible to
+ # compare them column by column with the above rules
+ return SKIP_TRAVERSE if left is right else COMPARE_FAILED
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
new file mode 100644
index 0000000..29dc749
--- /dev/null
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -0,0 +1,1974 @@
+# sql/type_api.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Base types API.
+
+"""
+
+
+from . import operators
+from .base import SchemaEventTarget
+from .traversals import NO_CACHE
+from .visitors import Traversible
+from .visitors import TraversibleType
+from .. import exc
+from .. import util
+
+# these are back-assigned by sqltypes.
+BOOLEANTYPE = None
+INTEGERTYPE = None
+NULLTYPE = None
+NUMERICTYPE = None
+STRINGTYPE = None
+MATCHTYPE = None
+INDEXABLE = None
+TABLEVALUE = None
+_resolve_value_to_type = None
+
+
+class TypeEngine(Traversible):
+ """The ultimate base class for all SQL datatypes.
+
+ Common subclasses of :class:`.TypeEngine` include
+ :class:`.String`, :class:`.Integer`, and :class:`.Boolean`.
+
+ For an overview of the SQLAlchemy typing system, see
+ :ref:`types_toplevel`.
+
+ .. seealso::
+
+ :ref:`types_toplevel`
+
+ """
+
+ _sqla_type = True
+ _isnull = False
+ _is_tuple_type = False
+ _is_table_value = False
+ _is_array = False
+ _is_type_decorator = False
+
+ class Comparator(operators.ColumnOperators):
+ """Base class for custom comparison operations defined at the
+ type level. See :attr:`.TypeEngine.comparator_factory`.
+
+
+ """
+
+ __slots__ = "expr", "type"
+
+ default_comparator = None
+
+ def __clause_element__(self):
+ return self.expr
+
+ def __init__(self, expr):
+ self.expr = expr
+ self.type = expr.type
+
+ @util.preload_module("sqlalchemy.sql.default_comparator")
+ def operate(self, op, *other, **kwargs):
+ default_comparator = util.preloaded.sql_default_comparator
+ o = default_comparator.operator_lookup[op.__name__]
+ return o[0](self.expr, op, *(other + o[1:]), **kwargs)
+
+ @util.preload_module("sqlalchemy.sql.default_comparator")
+ def reverse_operate(self, op, other, **kwargs):
+ default_comparator = util.preloaded.sql_default_comparator
+ o = default_comparator.operator_lookup[op.__name__]
+ return o[0](self.expr, op, other, reverse=True, *o[1:], **kwargs)
+
+ def _adapt_expression(self, op, other_comparator):
+ """evaluate the return type of <self> <op> <othertype>,
+ and apply any adaptations to the given operator.
+
+ This method determines the type of a resulting binary expression
+ given two source types and an operator. For example, two
+ :class:`_schema.Column` objects, both of the type
+ :class:`.Integer`, will
+ produce a :class:`.BinaryExpression` that also has the type
+ :class:`.Integer` when compared via the addition (``+``) operator.
+ However, using the addition operator with an :class:`.Integer`
+ and a :class:`.Date` object will produce a :class:`.Date`, assuming
+ "days delta" behavior by the database (in reality, most databases
+ other than PostgreSQL don't accept this particular operation).
+
+ The method returns a tuple of the form <operator>, <type>.
+ The resulting operator and type will be those applied to the
+ resulting :class:`.BinaryExpression` as the final operator and the
+ right-hand side of the expression.
+
+ Note that only a subset of operators make usage of
+ :meth:`._adapt_expression`,
+ including math operators and user-defined operators, but not
+ boolean comparison or special SQL keywords like MATCH or BETWEEN.
+
+ """
+
+ return op, self.type
+
+ def __reduce__(self):
+ return _reconstitute_comparator, (self.expr,)
+
+ hashable = True
+ """Flag, if False, means values from this type aren't hashable.
+
+ Used by the ORM when uniquing result lists.
+
+ """
+
+ comparator_factory = Comparator
+ """A :class:`.TypeEngine.Comparator` class which will apply
+ to operations performed by owning :class:`_expression.ColumnElement`
+ objects.
+
+ The :attr:`.comparator_factory` attribute is a hook consulted by
+ the core expression system when column and SQL expression operations
+ are performed. When a :class:`.TypeEngine.Comparator` class is
+ associated with this attribute, it allows custom re-definition of
+ all existing operators, as well as definition of new operators.
+ Existing operators include those provided by Python operator overloading
+ such as :meth:`.operators.ColumnOperators.__add__` and
+ :meth:`.operators.ColumnOperators.__eq__`,
+ those provided as standard
+ attributes of :class:`.operators.ColumnOperators` such as
+ :meth:`.operators.ColumnOperators.like`
+ and :meth:`.operators.ColumnOperators.in_`.
+
+ Rudimentary usage of this hook is allowed through simple subclassing
+ of existing types, or alternatively by using :class:`.TypeDecorator`.
+ See the documentation section :ref:`types_operators` for examples.
+
+ """
+
+ sort_key_function = None
+ """A sorting function that can be passed as the key to sorted.
+
+ The default value of ``None`` indicates that the values stored by
+ this type are self-sorting.
+
+ .. versionadded:: 1.3.8
+
+ """
+
+ should_evaluate_none = False
+ """If True, the Python constant ``None`` is considered to be handled
+ explicitly by this type.
+
+ The ORM uses this flag to indicate that a positive value of ``None``
+ is passed to the column in an INSERT statement, rather than omitting
+ the column from the INSERT statement which has the effect of firing
+ off column-level defaults. It also allows types which have special
+ behavior for Python None, such as a JSON type, to indicate that
+ they'd like to handle the None value explicitly.
+
+ To set this flag on an existing type, use the
+ :meth:`.TypeEngine.evaluates_none` method.
+
+ .. seealso::
+
+ :meth:`.TypeEngine.evaluates_none`
+
+ .. versionadded:: 1.1
+
+
+ """
+
+ def evaluates_none(self):
+ """Return a copy of this type which has the
+ :attr:`.should_evaluate_none` flag set to True.
+
+ E.g.::
+
+ Table(
+ 'some_table', metadata,
+ Column(
+ String(50).evaluates_none(),
+ nullable=True,
+ server_default='no value')
+ )
+
+ The ORM uses this flag to indicate that a positive value of ``None``
+ is passed to the column in an INSERT statement, rather than omitting
+ the column from the INSERT statement which has the effect of firing
+ off column-level defaults. It also allows for types which have
+ special behavior associated with the Python None value to indicate
+ that the value doesn't necessarily translate into SQL NULL; a
+ prime example of this is a JSON type which may wish to persist the
+ JSON value ``'null'``.
+
+        In all cases, the actual NULL SQL value can always be
+ persisted in any column by using
+ the :obj:`_expression.null` SQL construct in an INSERT statement
+ or associated with an ORM-mapped attribute.
+
+ .. note::
+
+ The "evaluates none" flag does **not** apply to a value
+ of ``None`` passed to :paramref:`_schema.Column.default` or
+ :paramref:`_schema.Column.server_default`; in these cases,
+ ``None``
+ still means "no default".
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_forcing_null` - in the ORM documentation
+
+ :paramref:`.postgresql.JSON.none_as_null` - PostgreSQL JSON
+ interaction with this flag.
+
+ :attr:`.TypeEngine.should_evaluate_none` - class-level flag
+
+ """
+ typ = self.copy()
+ typ.should_evaluate_none = True
+ return typ
+
+ def copy(self, **kw):
+ return self.adapt(self.__class__)
+
+ def compare_against_backend(self, dialect, conn_type):
+ """Compare this type against the given backend type.
+
+ This function is currently not implemented for SQLAlchemy
+ types, and for all built in types will return ``None``. However,
+ it can be implemented by a user-defined type
+ where it can be consumed by schema comparison tools such as
+ Alembic autogenerate.
+
+ A future release of SQLAlchemy will potentially implement this method
+ for builtin types as well.
+
+ The function should return True if this type is equivalent to the
+ given type; the type is typically reflected from the database
+ so should be database specific. The dialect in use is also
+ passed. It can also return False to assert that the type is
+ not equivalent.
+
+ :param dialect: a :class:`.Dialect` that is involved in the comparison.
+
+ :param conn_type: the type object reflected from the backend.
+
+ .. versionadded:: 1.0.3
+
+ """
+ return None
+
+ def copy_value(self, value):
+ return value
+
+ def literal_processor(self, dialect):
+ """Return a conversion function for processing literal values that are
+ to be rendered directly without using binds.
+
+ This function is used when the compiler makes use of the
+ "literal_binds" flag, typically used in DDL generation as well
+ as in certain scenarios where backends don't accept bound parameters.
+
+ Returns a callable which will receive a literal Python value
+ as the sole positional argument and will return a string representation
+ to be rendered in a SQL statement.
+
+ .. note::
+
+ This method is only called relative to a **dialect specific type
+ object**, which is often **private to a dialect in use** and is not
+ the same type object as the public facing one, which means it's not
+ feasible to subclass a :class:`.types.TypeEngine` class in order to
+ provide an alternate :meth:`_types.TypeEngine.literal_processor`
+ method, unless subclassing the :class:`_types.UserDefinedType`
+ class explicitly.
+
+ To provide alternate behavior for
+ :meth:`_types.TypeEngine.literal_processor`, implement a
+ :class:`_types.TypeDecorator` class and provide an implementation
+ of :meth:`_types.TypeDecorator.process_literal_param`.
+
+ .. seealso::
+
+ :ref:`types_typedecorator`
+
+
+ """
+ return None
+
+ def bind_processor(self, dialect):
+ """Return a conversion function for processing bind values.
+
+ Returns a callable which will receive a bind parameter value
+ as the sole positional argument and will return a value to
+ send to the DB-API.
+
+ If processing is not necessary, the method should return ``None``.
+
+ .. note::
+
+ This method is only called relative to a **dialect specific type
+ object**, which is often **private to a dialect in use** and is not
+ the same type object as the public facing one, which means it's not
+ feasible to subclass a :class:`.types.TypeEngine` class in order to
+ provide an alternate :meth:`_types.TypeEngine.bind_processor`
+ method, unless subclassing the :class:`_types.UserDefinedType`
+ class explicitly.
+
+ To provide alternate behavior for
+ :meth:`_types.TypeEngine.bind_processor`, implement a
+ :class:`_types.TypeDecorator` class and provide an implementation
+ of :meth:`_types.TypeDecorator.process_bind_param`.
+
+ .. seealso::
+
+ :ref:`types_typedecorator`
+
+
+ :param dialect: Dialect instance in use.
+
+ """
+ return None
+
+ def result_processor(self, dialect, coltype):
+ """Return a conversion function for processing result row values.
+
+ Returns a callable which will receive a result row column
+ value as the sole positional argument and will return a value
+ to return to the user.
+
+ If processing is not necessary, the method should return ``None``.
+
+ .. note::
+
+ This method is only called relative to a **dialect specific type
+ object**, which is often **private to a dialect in use** and is not
+ the same type object as the public facing one, which means it's not
+ feasible to subclass a :class:`.types.TypeEngine` class in order to
+ provide an alternate :meth:`_types.TypeEngine.result_processor`
+ method, unless subclassing the :class:`_types.UserDefinedType`
+ class explicitly.
+
+ To provide alternate behavior for
+ :meth:`_types.TypeEngine.result_processor`, implement a
+ :class:`_types.TypeDecorator` class and provide an implementation
+ of :meth:`_types.TypeDecorator.process_result_value`.
+
+ .. seealso::
+
+ :ref:`types_typedecorator`
+
+ :param dialect: Dialect instance in use.
+
+ :param coltype: DBAPI coltype argument received in cursor.description.
+
+ """
+ return None
+
+ def column_expression(self, colexpr):
+ """Given a SELECT column expression, return a wrapping SQL expression.
+
+ This is typically a SQL function that wraps a column expression
+ as rendered in the columns clause of a SELECT statement.
+ It is used for special data types that require
+ columns to be wrapped in some special database function in order
+ to coerce the value before being sent back to the application.
+ It is the SQL analogue of the :meth:`.TypeEngine.result_processor`
+ method.
+
+ This method is called during the **SQL compilation** phase of a
+ statement, when rendering a SQL string. It is **not** called
+ against specific values.
+
+ .. note::
+
+ This method is only called relative to a **dialect specific type
+ object**, which is often **private to a dialect in use** and is not
+ the same type object as the public facing one, which means it's not
+ feasible to subclass a :class:`.types.TypeEngine` class in order to
+ provide an alternate :meth:`_types.TypeEngine.column_expression`
+ method, unless subclassing the :class:`_types.UserDefinedType`
+ class explicitly.
+
+ To provide alternate behavior for
+ :meth:`_types.TypeEngine.column_expression`, implement a
+ :class:`_types.TypeDecorator` class and provide an implementation
+ of :meth:`_types.TypeDecorator.column_expression`.
+
+ .. seealso::
+
+ :ref:`types_typedecorator`
+
+
+ .. seealso::
+
+ :ref:`types_sql_value_processing`
+
+ """
+
+ return None
+
+ @util.memoized_property
+ def _has_column_expression(self):
+ """memoized boolean, check if column_expression is implemented.
+
+ Allows the method to be skipped for the vast majority of expression
+ types that don't use this feature.
+
+ """
+
+ return (
+ self.__class__.column_expression.__code__
+ is not TypeEngine.column_expression.__code__
+ )
+
+ def bind_expression(self, bindvalue):
+ """Given a bind value (i.e. a :class:`.BindParameter` instance),
+ return a SQL expression in its place.
+
+ This is typically a SQL function that wraps the existing bound
+ parameter within the statement. It is used for special data types
+ that require literals being wrapped in some special database function
+ in order to coerce an application-level value into a database-specific
+ format. It is the SQL analogue of the
+ :meth:`.TypeEngine.bind_processor` method.
+
+ This method is called during the **SQL compilation** phase of a
+ statement, when rendering a SQL string. It is **not** called
+ against specific values.
+
+ Note that this method, when implemented, should always return
+ the exact same structure, without any conditional logic, as it
+ may be used in an executemany() call against an arbitrary number
+ of bound parameter sets.
+
+ .. note::
+
+ This method is only called relative to a **dialect specific type
+ object**, which is often **private to a dialect in use** and is not
+ the same type object as the public facing one, which means it's not
+ feasible to subclass a :class:`.types.TypeEngine` class in order to
+ provide an alternate :meth:`_types.TypeEngine.bind_expression`
+ method, unless subclassing the :class:`_types.UserDefinedType`
+ class explicitly.
+
+ To provide alternate behavior for
+ :meth:`_types.TypeEngine.bind_expression`, implement a
+ :class:`_types.TypeDecorator` class and provide an implementation
+ of :meth:`_types.TypeDecorator.bind_expression`.
+
+ .. seealso::
+
+ :ref:`types_typedecorator`
+
+ .. seealso::
+
+ :ref:`types_sql_value_processing`
+
+ """
+ return None
+
+ @util.memoized_property
+ def _has_bind_expression(self):
+ """memoized boolean, check if bind_expression is implemented.
+
+ Allows the method to be skipped for the vast majority of expression
+ types that don't use this feature.
+
+ """
+
+ return util.method_is_overridden(self, TypeEngine.bind_expression)
+
+ @staticmethod
+ def _to_instance(cls_or_self):
+ return to_instance(cls_or_self)
+
+ def compare_values(self, x, y):
+ """Compare two values for equality."""
+
+ return x == y
+
+ def get_dbapi_type(self, dbapi):
+ """Return the corresponding type object from the underlying DB-API, if
+ any.
+
+ This can be useful for calling ``setinputsizes()``, for example.
+
+ """
+ return None
+
+ @property
+ def python_type(self):
+ """Return the Python type object expected to be returned
+ by instances of this type, if known.
+
+        Basically, for those types which enforce a return type,
+        or are known across the board to do such for all common
+        DBAPIs (like ``int`` for example), this method will return that
+        type.
+
+ If a return type is not defined, raises
+ ``NotImplementedError``.
+
+ Note that any type also accommodates NULL in SQL which
+ means you can also get back ``None`` from any type
+ in practice.
+
+ """
+ raise NotImplementedError()
+
+ def with_variant(self, type_, dialect_name):
+ r"""Produce a new type object that will utilize the given
+ type when applied to the dialect of the given name.
+
+ e.g.::
+
+ from sqlalchemy.types import String
+ from sqlalchemy.dialects import mysql
+
+ s = String()
+
+ s = s.with_variant(mysql.VARCHAR(collation='foo'), 'mysql')
+
+ The construction of :meth:`.TypeEngine.with_variant` is always
+ from the "fallback" type to that which is dialect specific.
+ The returned type is an instance of :class:`.Variant`, which
+ itself provides a :meth:`.Variant.with_variant`
+ that can be called repeatedly.
+
+ :param type\_: a :class:`.TypeEngine` that will be selected
+ as a variant from the originating type, when a dialect
+ of the given name is in use.
+ :param dialect_name: base name of the dialect which uses
+ this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.)
+
+ """
+ return Variant(self, {dialect_name: to_instance(type_)})
+
+ def _resolve_for_literal(self, value):
+ """adjust this type given a literal Python value that will be
+ stored in a bound parameter.
+
+ Used exclusively by _resolve_value_to_type().
+
+ .. versionadded:: 1.4.30 or 2.0
+
+ """
+ return self
+
+ @util.memoized_property
+ def _type_affinity(self):
+ """Return a rudimental 'affinity' value expressing the general class
+ of type."""
+
+ typ = None
+ for t in self.__class__.__mro__:
+ if t in (TypeEngine, UserDefinedType):
+ return typ
+ elif issubclass(t, (TypeEngine, UserDefinedType)):
+ typ = t
+ else:
+ return self.__class__
+
+ @util.memoized_property
+ def _generic_type_affinity(self):
+ best_camelcase = None
+ best_uppercase = None
+
+ if not isinstance(self, (TypeEngine, UserDefinedType)):
+ return self.__class__
+
+ for t in self.__class__.__mro__:
+ if (
+ t.__module__
+ in (
+ "sqlalchemy.sql.sqltypes",
+ "sqlalchemy.sql.type_api",
+ )
+ and issubclass(t, TypeEngine)
+ and t is not TypeEngine
+ and t.__name__[0] != "_"
+ ):
+ if t.__name__.isupper() and not best_uppercase:
+ best_uppercase = t
+ elif not t.__name__.isupper() and not best_camelcase:
+ best_camelcase = t
+
+ return best_camelcase or best_uppercase or NULLTYPE.__class__
+
+ def as_generic(self, allow_nulltype=False):
+ """
+ Return an instance of the generic type corresponding to this type
+ using heuristic rule. The method may be overridden if this
+ heuristic rule is not sufficient.
+
+ >>> from sqlalchemy.dialects.mysql import INTEGER
+ >>> INTEGER(display_width=4).as_generic()
+ Integer()
+
+ >>> from sqlalchemy.dialects.mysql import NVARCHAR
+ >>> NVARCHAR(length=100).as_generic()
+ Unicode(length=100)
+
+ .. versionadded:: 1.4.0b2
+
+
+ .. seealso::
+
+ :ref:`metadata_reflection_dbagnostic_types` - describes the
+ use of :meth:`_types.TypeEngine.as_generic` in conjunction with
+ the :meth:`_sql.DDLEvents.column_reflect` event, which is its
+ intended use.
+
+ """
+ if (
+ not allow_nulltype
+ and self._generic_type_affinity == NULLTYPE.__class__
+ ):
+ raise NotImplementedError(
+ "Default TypeEngine.as_generic() "
+ "heuristic method was unsuccessful for {}. A custom "
+ "as_generic() method must be implemented for this "
+ "type class.".format(
+ self.__class__.__module__ + "." + self.__class__.__name__
+ )
+ )
+
+ return util.constructor_copy(self, self._generic_type_affinity)
+
+ def dialect_impl(self, dialect):
+ """Return a dialect-specific implementation for this
+ :class:`.TypeEngine`.
+
+ """
+ try:
+ return dialect._type_memos[self]["impl"]
+ except KeyError:
+ pass
+ return self._dialect_info(dialect)["impl"]
+
+ def _unwrapped_dialect_impl(self, dialect):
+ """Return the 'unwrapped' dialect impl for this type.
+
+ For a type that applies wrapping logic (e.g. TypeDecorator), give
+ us the real, actual dialect-level type that is used.
+
+        This is used by TypeDecorator itself, as well as in at least one
+        case where dialects need to check that a particular specific
+        dialect-level type is in use, within the
+        :meth:`.DefaultDialect.set_input_sizes` method.
+
+ """
+ return self.dialect_impl(dialect)
+
+ def _cached_literal_processor(self, dialect):
+ """Return a dialect-specific literal processor for this type."""
+ try:
+ return dialect._type_memos[self]["literal"]
+ except KeyError:
+ pass
+ # avoid KeyError context coming into literal_processor() function
+ # raises
+ d = self._dialect_info(dialect)
+ d["literal"] = lp = d["impl"].literal_processor(dialect)
+ return lp
+
+ def _cached_bind_processor(self, dialect):
+ """Return a dialect-specific bind processor for this type."""
+
+ try:
+ return dialect._type_memos[self]["bind"]
+ except KeyError:
+ pass
+ # avoid KeyError context coming into bind_processor() function
+ # raises
+ d = self._dialect_info(dialect)
+ d["bind"] = bp = d["impl"].bind_processor(dialect)
+ return bp
+
+ def _cached_result_processor(self, dialect, coltype):
+ """Return a dialect-specific result processor for this type."""
+
+ try:
+ return dialect._type_memos[self][coltype]
+ except KeyError:
+ pass
+ # avoid KeyError context coming into result_processor() function
+ # raises
+ d = self._dialect_info(dialect)
+ # key assumption: DBAPI type codes are
+ # constants. Else this dictionary would
+ # grow unbounded.
+ d[coltype] = rp = d["impl"].result_processor(dialect, coltype)
+ return rp
+
+ def _cached_custom_processor(self, dialect, key, fn):
+ try:
+ return dialect._type_memos[self][key]
+ except KeyError:
+ pass
+ # avoid KeyError context coming into fn() function
+ # raises
+ d = self._dialect_info(dialect)
+ impl = d["impl"]
+ d[key] = result = fn(impl)
+ return result
+
+ def _dialect_info(self, dialect):
+ """Return a dialect-specific registry which
+ caches a dialect-specific implementation, bind processing
+ function, and one or more result processing functions."""
+
+ if self in dialect._type_memos:
+ return dialect._type_memos[self]
+ else:
+ impl = self._gen_dialect_impl(dialect)
+ if impl is self:
+ impl = self.adapt(type(self))
+ # this can't be self, else we create a cycle
+ assert impl is not self
+ dialect._type_memos[self] = d = {"impl": impl}
+ return d
+
+ def _gen_dialect_impl(self, dialect):
+ return dialect.type_descriptor(self)
+
+ @util.memoized_property
+ def _static_cache_key(self):
+ names = util.get_cls_kwargs(self.__class__)
+ return (self.__class__,) + tuple(
+ (
+ k,
+ self.__dict__[k]._static_cache_key
+ if isinstance(self.__dict__[k], TypeEngine)
+ else self.__dict__[k],
+ )
+ for k in names
+ if k in self.__dict__ and not k.startswith("_")
+ )
+
+ def adapt(self, cls, **kw):
+ """Produce an "adapted" form of this type, given an "impl" class
+ to work with.
+
+ This method is used internally to associate generic
+ types with "implementation" types that are specific to a particular
+ dialect.
+ """
+ return util.constructor_copy(self, cls, **kw)
+
+ def coerce_compared_value(self, op, value):
+ """Suggest a type for a 'coerced' Python value in an expression.
+
+ Given an operator and value, gives the type a chance
+ to return a type which the value should be coerced into.
+
+ The default behavior here is conservative; if the right-hand
+ side is already coerced into a SQL type based on its
+ Python type, it is usually left alone.
+
+ End-user functionality extension here should generally be via
+ :class:`.TypeDecorator`, which provides more liberal behavior in that
+ it defaults to coercing the other side of the expression into this
+ type, thus applying special Python conversions above and beyond those
+        needed by the DBAPI to both sides. It also provides the public method
+ :meth:`.TypeDecorator.coerce_compared_value` which is intended for
+ end-user customization of this behavior.
+
+ """
+ _coerced_type = _resolve_value_to_type(value)
+ if (
+ _coerced_type is NULLTYPE
+ or _coerced_type._type_affinity is self._type_affinity
+ ):
+ return self
+ else:
+ return _coerced_type
+
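+    # Illustrative sketch (editorial note, not library source): the default
+    # coercion keeps this type when the Python value already matches its
+    # affinity, otherwise the type resolved from the value wins.  Expected
+    # results below are assumptions based on the logic above::
+    #
+    #     from sqlalchemy import Integer
+    #     from sqlalchemy.sql import operators
+    #
+    #     t = Integer()
+    #     assert t.coerce_compared_value(operators.add, 7) is t
+    #     t.coerce_compared_value(operators.add, "seven")   # String-affinity
+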
+ def _compare_type_affinity(self, other):
+ return self._type_affinity is other._type_affinity
+
+ def compile(self, dialect=None):
+ """Produce a string-compiled form of this :class:`.TypeEngine`.
+
+ When called with no arguments, uses a "default" dialect
+ to produce a string result.
+
+ :param dialect: a :class:`.Dialect` instance.
+
+ """
+ # arg, return value is inconsistent with
+ # ClauseElement.compile()....this is a mistake.
+
+ if not dialect:
+ dialect = self._default_dialect()
+
+ return dialect.type_compiler.process(self)
+
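+    # Illustrative sketch (editorial note, not library source): compiling a
+    # type renders its DDL name for the chosen backend; the outputs shown
+    # are what the default and MySQL dialects are expected to produce::
+    #
+    #     from sqlalchemy import Boolean, String
+    #     from sqlalchemy.dialects import mysql
+    #
+    #     String(50).compile()                        # 'VARCHAR(50)'
+    #     Boolean().compile(dialect=mysql.dialect())  # 'BOOL'
+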
+ @util.preload_module("sqlalchemy.engine.default")
+ def _default_dialect(self):
+ default = util.preloaded.engine_default
+ return default.StrCompileDialect()
+
+ def __str__(self):
+ if util.py2k:
+ return unicode(self.compile()).encode( # noqa
+ "ascii", "backslashreplace"
+ ) # noqa
+ else:
+ return str(self.compile())
+
+ def __repr__(self):
+ return util.generic_repr(self)
+
+
+class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType):
+ pass
+
+
+class ExternalType(object):
+ """mixin that defines attributes and behaviors specific to third-party
+ datatypes.
+
+ "Third party" refers to datatypes that are defined outside the scope
+ of SQLAlchemy within either end-user application code or within
+ external extensions to SQLAlchemy.
+
+ Subclasses currently include :class:`.TypeDecorator` and
+ :class:`.UserDefinedType`.
+
+ .. versionadded:: 1.4.28
+
+ """
+
+ cache_ok = None
+ """Indicate if statements using this :class:`.ExternalType` are "safe to
+ cache".
+
+ The default value ``None`` will emit a warning and then not allow caching
+    of a statement which includes this type. Set to ``False`` to disable
+    caching of statements using this type entirely, without emitting a
+    warning.
+ When set to ``True``, the object's class and selected elements from its
+ state will be used as part of the cache key. For example, using a
+ :class:`.TypeDecorator`::
+
+ class MyType(TypeDecorator):
+ impl = String
+
+ cache_ok = True
+
+ def __init__(self, choices):
+ self.choices = tuple(choices)
+ self.internal_only = True
+
+ The cache key for the above type would be equivalent to::
+
+ >>> MyType(["a", "b", "c"])._static_cache_key
+ (<class '__main__.MyType'>, ('choices', ('a', 'b', 'c')))
+
+ The caching scheme will extract attributes from the type that correspond
+ to the names of parameters in the ``__init__()`` method. Above, the
+ "choices" attribute becomes part of the cache key but "internal_only"
+ does not, because there is no parameter named "internal_only".
+
+    The requirements for cacheable elements are that they are hashable
+    and that they indicate the same SQL rendered for expressions using
+    this type every time for a given cache value.
+
+    To accommodate datatypes that refer to unhashable structures such
+ as dictionaries, sets and lists, these objects can be made "cacheable"
+ by assigning hashable structures to the attributes whose names
+ correspond with the names of the arguments. For example, a datatype
+ which accepts a dictionary of lookup values may publish this as a sorted
+ series of tuples. Given a previously un-cacheable type as::
+
+ class LookupType(UserDefinedType):
+ '''a custom type that accepts a dictionary as a parameter.
+
+ this is the non-cacheable version, as "self.lookup" is not
+ hashable.
+
+ '''
+
+ def __init__(self, lookup):
+ self.lookup = lookup
+
+ def get_col_spec(self, **kw):
+ return "VARCHAR(255)"
+
+ def bind_processor(self, dialect):
+ # ... works with "self.lookup" ...
+
+ Where "lookup" is a dictionary. The type will not be able to generate
+ a cache key::
+
+ >>> type_ = LookupType({"a": 10, "b": 20})
+ >>> type_._static_cache_key
+ <stdin>:1: SAWarning: UserDefinedType LookupType({'a': 10, 'b': 20}) will not
+ produce a cache key because the ``cache_ok`` flag is not set to True.
+ Set this flag to True if this type object's state is safe to use
+ in a cache key, or False to disable this warning.
+ symbol('no_cache')
+
+ If we **did** set up such a cache key, it wouldn't be usable. We would
+ get a tuple structure that contains a dictionary inside of it, which
+ cannot itself be used as a key in a "cache dictionary" such as SQLAlchemy's
+ statement cache, since Python dictionaries aren't hashable::
+
+ >>> # set cache_ok = True
+ >>> type_.cache_ok = True
+
+ >>> # this is the cache key it would generate
+ >>> key = type_._static_cache_key
+ >>> key
+ (<class '__main__.LookupType'>, ('lookup', {'a': 10, 'b': 20}))
+
+ >>> # however this key is not hashable, will fail when used with
+ >>> # SQLAlchemy statement cache
+ >>> some_cache = {key: "some sql value"}
+ Traceback (most recent call last): File "<stdin>", line 1,
+ in <module> TypeError: unhashable type: 'dict'
+
+ The type may be made cacheable by assigning a sorted tuple of tuples
+ to the ".lookup" attribute::
+
+ class LookupType(UserDefinedType):
+ '''a custom type that accepts a dictionary as a parameter.
+
+ The dictionary is stored both as itself in a private variable,
+ and published in a public variable as a sorted tuple of tuples,
+ which is hashable and will also return the same value for any
+ two equivalent dictionaries. Note it assumes the keys and
+ values of the dictionary are themselves hashable.
+
+ '''
+
+ cache_ok = True
+
+ def __init__(self, lookup):
+ self._lookup = lookup
+
+ # assume keys/values of "lookup" are hashable; otherwise
+ # they would also need to be converted in some way here
+ self.lookup = tuple(
+ (key, lookup[key]) for key in sorted(lookup)
+ )
+
+ def get_col_spec(self, **kw):
+ return "VARCHAR(255)"
+
+ def bind_processor(self, dialect):
+ # ... works with "self._lookup" ...
+
+ Where above, the cache key for ``LookupType({"a": 10, "b": 20})`` will be::
+
+ >>> LookupType({"a": 10, "b": 20})._static_cache_key
+ (<class '__main__.LookupType'>, ('lookup', (('a', 10), ('b', 20))))
+
+ .. versionadded:: 1.4.14 - added the ``cache_ok`` flag to allow
+ some configurability of caching for :class:`.TypeDecorator` classes.
+
+ .. versionadded:: 1.4.28 - added the :class:`.ExternalType` mixin which
+ generalizes the ``cache_ok`` flag to both the :class:`.TypeDecorator`
+ and :class:`.UserDefinedType` classes.
+
+ .. seealso::
+
+ :ref:`sql_caching`
+
+ """ # noqa: E501
+
+ @property
+ def _static_cache_key(self):
+ cache_ok = self.__class__.__dict__.get("cache_ok", None)
+
+ if cache_ok is None:
+ subtype_idx = self.__class__.__mro__.index(ExternalType)
+ subtype = self.__class__.__mro__[max(subtype_idx - 1, 0)]
+
+ util.warn(
+ "%s %r will not produce a cache key because "
+ "the ``cache_ok`` attribute is not set to True. This can "
+ "have significant performance implications including some "
+ "performance degradations in comparison to prior SQLAlchemy "
+ "versions. Set this attribute to True if this type object's "
+ "state is safe to use in a cache key, or False to "
+ "disable this warning." % (subtype.__name__, self),
+ code="cprf",
+ )
+ elif cache_ok is True:
+ return super(ExternalType, self)._static_cache_key
+
+ return NO_CACHE
+
+
+class UserDefinedType(
+ util.with_metaclass(VisitableCheckKWArg, ExternalType, TypeEngine)
+):
+ """Base for user defined types.
+
+ This should be the base of new types. Note that
+ for most cases, :class:`.TypeDecorator` is probably
+ more appropriate::
+
+ import sqlalchemy.types as types
+
+ class MyType(types.UserDefinedType):
+ cache_ok = True
+
+ def __init__(self, precision = 8):
+ self.precision = precision
+
+ def get_col_spec(self, **kw):
+ return "MYTYPE(%s)" % self.precision
+
+ def bind_processor(self, dialect):
+ def process(value):
+ return value
+ return process
+
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ return value
+ return process
+
+ Once the type is made, it's immediately usable::
+
+ table = Table('foo', metadata_obj,
+ Column('id', Integer, primary_key=True),
+ Column('data', MyType(16))
+ )
+
+ The ``get_col_spec()`` method will in most cases receive a keyword
+ argument ``type_expression`` which refers to the owning expression
+ of the type as being compiled, such as a :class:`_schema.Column` or
+ :func:`.cast` construct. This keyword is only sent if the method
+ accepts keyword arguments (e.g. ``**kw``) in its argument signature;
+ introspection is used to check for this in order to support legacy
+ forms of this function.
+
+ .. versionadded:: 1.0.0 the owning expression is passed to
+ the ``get_col_spec()`` method via the keyword argument
+ ``type_expression``, if it receives ``**kw`` in its signature.
+
+ The :attr:`.UserDefinedType.cache_ok` class-level flag indicates if this
+ custom :class:`.UserDefinedType` is safe to be used as part of a cache key.
+ This flag defaults to ``None`` which will initially generate a warning
+ when the SQL compiler attempts to generate a cache key for a statement
+ that uses this type. If the :class:`.UserDefinedType` is not guaranteed
+ to produce the same bind/result behavior and SQL generation
+ every time, this flag should be set to ``False``; otherwise if the
+ class produces the same behavior each time, it may be set to ``True``.
+ See :attr:`.UserDefinedType.cache_ok` for further notes on how this works.
+
+ .. versionadded:: 1.4.28 Generalized the :attr:`.ExternalType.cache_ok`
+ flag so that it is available for both :class:`.TypeDecorator` as well
+ as :class:`.UserDefinedType`.
+
+ """
+
+ __visit_name__ = "user_defined"
+
+ ensure_kwarg = "get_col_spec"
+
+ def coerce_compared_value(self, op, value):
+ """Suggest a type for a 'coerced' Python value in an expression.
+
+ Default behavior for :class:`.UserDefinedType` is the
+ same as that of :class:`.TypeDecorator`; by default it returns
+ ``self``, assuming the compared value should be coerced into
+ the same type as this one. See
+ :meth:`.TypeDecorator.coerce_compared_value` for more detail.
+
+ """
+
+ return self
+
+
+class Emulated(object):
+ """Mixin for base types that emulate the behavior of a DB-native type.
+
+ An :class:`.Emulated` type will use an available database type
+ in conjunction with Python-side routines and/or database constraints
+ in order to approximate the behavior of a database type that is provided
+ natively by some backends. When a native-providing backend is in
+ use, the native version of the type is used. This native version
+ should include the :class:`.NativeForEmulated` mixin to allow it to be
+ distinguished from :class:`.Emulated`.
+
+ Current examples of :class:`.Emulated` are: :class:`.Interval`,
+ :class:`.Enum`, :class:`.Boolean`.
+
+ .. versionadded:: 1.2.0b3
+
+ """
+
+ def adapt_to_emulated(self, impltype, **kw):
+ """Given an impl class, adapt this type to the impl assuming
+ "emulated".
+
+ The impl should also be an "emulated" version of this type,
+ most likely the same class as this type itself.
+
+ e.g.: sqltypes.Enum adapts to the Enum class.
+
+ """
+ return super(Emulated, self).adapt(impltype, **kw)
+
+ def adapt(self, impltype, **kw):
+ if hasattr(impltype, "adapt_emulated_to_native"):
+ if self.native:
+ # native support requested, dialect gave us a native
+ # implementor, pass control over to it
+ return impltype.adapt_emulated_to_native(self, **kw)
+ else:
+ # non-native support; let the native implementor
+ # decide as well. At the moment this is just to help debugging,
+ # as only the default logic is implemented.
+ return impltype.adapt_native_to_emulated(self, **kw)
+ else:
+ if issubclass(impltype, self.__class__):
+ return self.adapt_to_emulated(impltype, **kw)
+ else:
+ return super(Emulated, self).adapt(impltype, **kw)
+
+
+class NativeForEmulated(object):
+ """Indicates DB-native types supported by an :class:`.Emulated` type.
+
+ .. versionadded:: 1.2.0b3
+
+ """
+
+ @classmethod
+ def adapt_native_to_emulated(cls, impl, **kw):
+ """Given an impl, adapt this type's class to the impl assuming
+ "emulated".
+
+
+ """
+ impltype = impl.__class__
+ return impl.adapt(impltype, **kw)
+
+ @classmethod
+ def adapt_emulated_to_native(cls, impl, **kw):
+ """Given an impl, adapt this type's class to the impl assuming
+ "native".
+
+ The impl will be an :class:`.Emulated` class but not a
+ :class:`.NativeForEmulated`.
+
+ e.g.: postgresql.ENUM produces a type given an Enum instance.
+
+ """
+ return cls(**kw)
+
+
+class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine):
+ """Allows the creation of types which add additional functionality
+ to an existing type.
+
+ This method is preferred to direct subclassing of SQLAlchemy's
+ built-in types as it ensures that all required functionality of
+ the underlying type is kept in place.
+
+ Typical usage::
+
+ import sqlalchemy.types as types
+
+ class MyType(types.TypeDecorator):
+ '''Prefixes Unicode values with "PREFIX:" on the way in and
+ strips it off on the way out.
+ '''
+
+ impl = types.Unicode
+
+ cache_ok = True
+
+ def process_bind_param(self, value, dialect):
+ return "PREFIX:" + value
+
+ def process_result_value(self, value, dialect):
+ return value[7:]
+
+ def copy(self, **kw):
+ return MyType(self.impl.length)
+
+ The class-level ``impl`` attribute is required, and can reference any
+ :class:`.TypeEngine` class. Alternatively, the :meth:`load_dialect_impl`
+ method can be used to provide different type classes based on the dialect
+ given; in this case, the ``impl`` variable can reference
+ ``TypeEngine`` as a placeholder.
+
+ The :attr:`.TypeDecorator.cache_ok` class-level flag indicates if this
+ custom :class:`.TypeDecorator` is safe to be used as part of a cache key.
+ This flag defaults to ``None`` which will initially generate a warning
+ when the SQL compiler attempts to generate a cache key for a statement
+ that uses this type. If the :class:`.TypeDecorator` is not guaranteed
+ to produce the same bind/result behavior and SQL generation
+ every time, this flag should be set to ``False``; otherwise if the
+ class produces the same behavior each time, it may be set to ``True``.
+ See :attr:`.TypeDecorator.cache_ok` for further notes on how this works.
+
+ Types that receive a Python type that isn't similar to the ultimate type
+ used may want to define the :meth:`TypeDecorator.coerce_compared_value`
+ method. This is used to give the expression system a hint when coercing
+ Python objects into bind parameters within expressions. Consider this
+ expression::
+
+ mytable.c.somecol + datetime.date(2009, 5, 15)
+
+ Above, if "somecol" is an ``Integer`` variant, it makes sense that
+ we're doing date arithmetic, where the above expression is usually
+ interpreted by databases as adding a number of days to the given date.
+ The expression system does the right thing by not attempting to
+ coerce the "date()" value into an integer-oriented bind parameter.
+
+ However, in the case of ``TypeDecorator``, we are usually changing an
+ incoming Python type to something new - ``TypeDecorator`` by default will
+ "coerce" the non-typed side to be the same type as itself. Such as below,
+ we define an "epoch" type that stores a date value as an integer::
+
+ class MyEpochType(types.TypeDecorator):
+ impl = types.Integer
+
+ cache_ok = True
+
+ epoch = datetime.date(1970, 1, 1)
+
+ def process_bind_param(self, value, dialect):
+ return (value - self.epoch).days
+
+ def process_result_value(self, value, dialect):
+ return self.epoch + timedelta(days=value)
+
+ Our expression of ``somecol + date`` with the above type will coerce the
+ "date" on the right side to also be treated as ``MyEpochType``.
+
+ This behavior can be overridden via the
+ :meth:`~TypeDecorator.coerce_compared_value` method, which returns a type
+ that should be used for the value of the expression. Below we set it such
+ that an integer value will be treated as an ``Integer``, and any other
+ value is assumed to be a date and will be treated as a ``MyEpochType``::
+
+ def coerce_compared_value(self, op, value):
+ if isinstance(value, int):
+ return Integer()
+ else:
+ return self
+
+ .. warning::
+
+ Note that the **behavior of coerce_compared_value is not inherited
+ by default from that of the base type**.
+ If the :class:`.TypeDecorator` is augmenting a
+ type that requires special logic for certain types of operators,
+ this method **must** be overridden. A key example is when decorating
+ the :class:`_postgresql.JSON` and :class:`_postgresql.JSONB` types;
+ the default rules of :meth:`.TypeEngine.coerce_compared_value` should
+ be used in order to deal with operators like index operations::
+
+ from sqlalchemy import JSON
+ from sqlalchemy import TypeDecorator
+
+ class MyJsonType(TypeDecorator):
+ impl = JSON
+
+ cache_ok = True
+
+ def coerce_compared_value(self, op, value):
+ return self.impl.coerce_compared_value(op, value)
+
+ Without the above step, index operations such as ``mycol['foo']``
+ will cause the index value ``'foo'`` to be JSON encoded.
+
+ Similarly, when working with the :class:`.ARRAY` datatype, the
+ type coercion for index operations (e.g. ``mycol[5]``) is also
+ handled by :meth:`.TypeDecorator.coerce_compared_value`, where
+ again a simple override is sufficient unless special rules are needed
+ for particular operators::
+
+ from sqlalchemy import ARRAY
+ from sqlalchemy import TypeDecorator
+
+ class MyArrayType(TypeDecorator):
+ impl = ARRAY
+
+ cache_ok = True
+
+ def coerce_compared_value(self, op, value):
+ return self.impl.coerce_compared_value(op, value)
+
+
+ """
+
+ __visit_name__ = "type_decorator"
+
+ _is_type_decorator = True
+
+ def __init__(self, *args, **kwargs):
+ """Construct a :class:`.TypeDecorator`.
+
+ Arguments sent here are passed to the constructor
+ of the class assigned to the ``impl`` class level attribute,
+ assuming the ``impl`` is a callable, and the resulting
+ object is assigned to the ``self.impl`` instance attribute
+ (thus overriding the class attribute of the same name).
+
+ If the class level ``impl`` is not a callable (the unusual case),
+ it will be assigned to the same instance attribute 'as-is',
+ ignoring those arguments passed to the constructor.
+
+ Subclasses can override this to customize the generation
+ of ``self.impl`` entirely.
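+
+ For example, a minimal sketch (``MyString`` here is a hypothetical
+ subclass)::
+
+     from sqlalchemy import String
+     from sqlalchemy.types import TypeDecorator
+
+     class MyString(TypeDecorator):
+         impl = String
+         cache_ok = True
+
+     # String(50) is constructed and assigned to self.impl
+     my_type = MyString(50)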
+
+ """
+
+ if not hasattr(self.__class__, "impl"):
+ raise AssertionError(
+ "TypeDecorator implementations "
+ "require a class-level variable "
+ "'impl' which refers to the class of "
+ "type being decorated"
+ )
+ self.impl = to_instance(self.__class__.impl, *args, **kwargs)
+
+ coerce_to_is_types = (util.NoneType,)
+ """Specify those Python types which should be coerced at the expression
+ level to "IS <constant>" when compared using ``==`` (and same for
+ ``IS NOT`` in conjunction with ``!=``).
+
+ For most SQLAlchemy types, this includes ``NoneType``, as well as
+ ``bool``.
+
+ :class:`.TypeDecorator` modifies this list to only include ``NoneType``,
+ as typedecorator implementations that deal with boolean types are common.
+
+ Custom :class:`.TypeDecorator` classes can override this attribute to
+ return an empty tuple, in which case no values will be coerced to
+ constants.
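+
+ For example, a rough sketch of a decorator that disables this coercion
+ entirely (``NoIsCoercion`` is a hypothetical name)::
+
+     from sqlalchemy import Integer
+     from sqlalchemy.types import TypeDecorator
+
+     class NoIsCoercion(TypeDecorator):
+         impl = Integer
+         cache_ok = True
+
+         # compare even None using a bound parameter rather than IS NULL
+         coerce_to_is_types = ()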
+
+ """
+
+ class Comparator(TypeEngine.Comparator):
+ """A :class:`.TypeEngine.Comparator` that is specific to
+ :class:`.TypeDecorator`.
+
+ User-defined :class:`.TypeDecorator` classes should not typically
+ need to modify this.
+
+
+ """
+
+ __slots__ = ()
+
+ def operate(self, op, *other, **kwargs):
+ kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
+ return super(TypeDecorator.Comparator, self).operate(
+ op, *other, **kwargs
+ )
+
+ def reverse_operate(self, op, other, **kwargs):
+ kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
+ return super(TypeDecorator.Comparator, self).reverse_operate(
+ op, other, **kwargs
+ )
+
+ @property
+ def comparator_factory(self):
+ if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__:
+ return self.impl.comparator_factory
+ else:
+ return type(
+ "TDComparator",
+ (TypeDecorator.Comparator, self.impl.comparator_factory),
+ {},
+ )
+
+ def _gen_dialect_impl(self, dialect):
+ """
+ #todo
+ """
+ adapted = dialect.type_descriptor(self)
+ if adapted is not self:
+ return adapted
+
+ # otherwise adapt the impl type, link
+ # to a copy of this TypeDecorator and return
+ # that.
+ typedesc = self.load_dialect_impl(dialect).dialect_impl(dialect)
+ tt = self.copy()
+ if not isinstance(tt, self.__class__):
+ raise AssertionError(
+ "Type object %s does not properly "
+ "implement the copy() method, it must "
+ "return an object of type %s" % (self, self.__class__)
+ )
+ tt.impl = typedesc
+ return tt
+
+ @property
+ def _type_affinity(self):
+ """
+ #todo
+ """
+ return self.impl._type_affinity
+
+ def _set_parent(self, column, outer=False, **kw):
+ """Support SchemaEventTarget"""
+
+ super(TypeDecorator, self)._set_parent(column)
+
+ if not outer and isinstance(self.impl, SchemaEventTarget):
+ self.impl._set_parent(column, outer=False, **kw)
+
+ def _set_parent_with_dispatch(self, parent):
+ """Support SchemaEventTarget"""
+
+ super(TypeDecorator, self)._set_parent_with_dispatch(
+ parent, outer=True
+ )
+
+ if isinstance(self.impl, SchemaEventTarget):
+ self.impl._set_parent_with_dispatch(parent)
+
+ def type_engine(self, dialect):
+ """Return a dialect-specific :class:`.TypeEngine` instance
+ for this :class:`.TypeDecorator`.
+
+ In most cases this returns a dialect-adapted form of
+ the :class:`.TypeEngine` type represented by ``self.impl``.
+ Makes usage of :meth:`dialect_impl`.
+ Behavior can be customized here by overriding
+ :meth:`load_dialect_impl`.
+
+ """
+ adapted = dialect.type_descriptor(self)
+ if not isinstance(adapted, type(self)):
+ return adapted
+ else:
+ return self.load_dialect_impl(dialect)
+
+ def load_dialect_impl(self, dialect):
+ """Return a :class:`.TypeEngine` object corresponding to a dialect.
+
+ This is an end-user override hook that can be used to provide
+ differing types depending on the given dialect. It is used
+ by the :class:`.TypeDecorator` implementation of :meth:`type_engine`
+ to help determine what type should ultimately be returned
+ for a given :class:`.TypeDecorator`.
+
+ By default returns ``self.impl``.
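+
+ E.g., a rough sketch that selects a backend-specific type
+ (``MyJSON`` is a hypothetical name)::
+
+     from sqlalchemy.dialects import postgresql
+     from sqlalchemy.types import JSON, TypeDecorator
+
+     class MyJSON(TypeDecorator):
+         impl = JSON
+         cache_ok = True
+
+         def load_dialect_impl(self, dialect):
+             if dialect.name == "postgresql":
+                 return dialect.type_descriptor(postgresql.JSONB())
+             return dialect.type_descriptor(JSON())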
+
+ """
+ return self.impl
+
+ def _unwrapped_dialect_impl(self, dialect):
+ """Return the 'unwrapped' dialect impl for this type.
+
+ This is used by the :meth:`.DefaultDialect.set_input_sizes`
+ method.
+
+ """
+ # some dialects have a lookup for a TypeDecorator subclass directly.
+ # postgresql.INTERVAL being the main example
+ typ = self.dialect_impl(dialect)
+
+ # if we are still a type decorator, load the per-dialect switch
+ # (such as what Variant uses), then get the dialect impl for that.
+ if isinstance(typ, self.__class__):
+ return typ.load_dialect_impl(dialect).dialect_impl(dialect)
+ else:
+ return typ
+
+ def __getattr__(self, key):
+ """Proxy all other undefined accessors to the underlying
+ implementation."""
+ return getattr(self.impl, key)
+
+ def process_literal_param(self, value, dialect):
+ """Receive a literal parameter value to be rendered inline within
+ a statement.
+
+ .. note::
+
+ This method is called during the **SQL compilation** phase of a
+ statement, when rendering a SQL string. Unlike other SQL
+ compilation methods, it is passed a specific Python value to be
+ rendered as a string. However it should not be confused with the
+ :meth:`_types.TypeDecorator.process_bind_param` method, which is
+ the more typical method that processes the actual value passed to a
+ particular parameter at statement execution time.
+
+ Custom subclasses of :class:`_types.TypeDecorator` should override
+ this method to provide custom behaviors for incoming data values
+ that are in the special case of being rendered as literals.
+
+ The returned string will be rendered into the output string.
+
+ """
+ raise NotImplementedError()
+
+ def process_bind_param(self, value, dialect):
+ """Receive a bound parameter value to be converted.
+
+ Custom subclasses of :class:`_types.TypeDecorator` should override
+ this method to provide custom behaviors for incoming data values.
+ This method is called at **statement execution time** and is passed
+ the literal Python data value which is to be associated with a bound
+ parameter in the statement.
+
+ The operation could be anything desired to perform custom
+ behavior, such as transforming or serializing data.
+ This could also be used as a hook for validating logic.
+
+ :param value: Data to operate upon, of any type expected by
+ this method in the subclass. Can be ``None``.
+ :param dialect: the :class:`.Dialect` in use.
+
+ .. seealso::
+
+ :ref:`types_typedecorator`
+
+ :meth:`_types.TypeDecorator.process_result_value`
+
+ """
+
+ raise NotImplementedError()
+
+ def process_result_value(self, value, dialect):
+ """Receive a result-row column value to be converted.
+
+ Custom subclasses of :class:`_types.TypeDecorator` should override
+ this method to provide custom behaviors for data values
+ being received in result rows coming from the database.
+ This method is called at **result fetching time** and is passed
+ the literal Python data value that's extracted from a database result
+ row.
+
+ The operation could be anything desired to perform custom
+ behavior, such as transforming or deserializing data.
+
+ :param value: Data to operate upon, of any type expected by
+ this method in the subclass. Can be ``None``.
+ :param dialect: the :class:`.Dialect` in use.
+
+ .. seealso::
+
+ :ref:`types_typedecorator`
+
+ :meth:`_types.TypeDecorator.process_bind_param`
+
+
+ """
+
+ raise NotImplementedError()
+
+ @util.memoized_property
+ def _has_bind_processor(self):
+ """memoized boolean, check if process_bind_param is implemented.
+
+ Allows the base process_bind_param to raise
+ NotImplementedError without needing to test an expensive
+ exception throw.
+
+ """
+
+ return util.method_is_overridden(
+ self, TypeDecorator.process_bind_param
+ )
+
+ @util.memoized_property
+ def _has_literal_processor(self):
+ """memoized boolean, check if process_literal_param is implemented."""
+
+ return util.method_is_overridden(
+ self, TypeDecorator.process_literal_param
+ )
+
+ def literal_processor(self, dialect):
+ """Provide a literal processing function for the given
+ :class:`.Dialect`.
+
+ This is the method that fulfills the :class:`.TypeEngine`
+ contract for literal value conversion which normally occurs via
+ the :meth:`_types.TypeEngine.literal_processor` method.
+
+ .. note::
+
+ User-defined subclasses of :class:`_types.TypeDecorator` should
+ **not** implement this method, and should instead implement
+ :meth:`_types.TypeDecorator.process_literal_param` so that the
+ "inner" processing provided by the implementing type is maintained.
+
+ """
+ if self._has_literal_processor:
+ process_param = self.process_literal_param
+ elif self._has_bind_processor:
+ # the bind processor should normally be OK
+ # for TypeDecorator; since it isn't doing DB-level
+ # handling, the handling here won't be different for bound vs.
+ # literal values.
+ process_param = self.process_bind_param
+ else:
+ process_param = None
+
+ if process_param:
+ impl_processor = self.impl.literal_processor(dialect)
+ if impl_processor:
+
+ def process(value):
+ return impl_processor(process_param(value, dialect))
+
+ else:
+
+ def process(value):
+ return process_param(value, dialect)
+
+ return process
+ else:
+ return self.impl.literal_processor(dialect)
+
+ def bind_processor(self, dialect):
+ """Provide a bound value processing function for the
+ given :class:`.Dialect`.
+
+ This is the method that fulfills the :class:`.TypeEngine`
+ contract for bound value conversion which normally occurs via
+ the :meth:`_types.TypeEngine.bind_processor` method.
+
+ .. note::
+
+ User-defined subclasses of :class:`_types.TypeDecorator` should
+ **not** implement this method, and should instead implement
+ :meth:`_types.TypeDecorator.process_bind_param` so that the "inner"
+ processing provided by the implementing type is maintained.
+
+ :param dialect: Dialect instance in use.
+
+ """
+ if self._has_bind_processor:
+ process_param = self.process_bind_param
+ impl_processor = self.impl.bind_processor(dialect)
+ if impl_processor:
+
+ def process(value):
+ return impl_processor(process_param(value, dialect))
+
+ else:
+
+ def process(value):
+ return process_param(value, dialect)
+
+ return process
+ else:
+ return self.impl.bind_processor(dialect)
+
+ @util.memoized_property
+ def _has_result_processor(self):
+ """memoized boolean, check if process_result_value is implemented.
+
+ Allows the base process_result_value to raise
+ NotImplementedError without needing to test an expensive
+ exception throw.
+
+ """
+
+ return util.method_is_overridden(
+ self, TypeDecorator.process_result_value
+ )
+
+ def result_processor(self, dialect, coltype):
+ """Provide a result value processing function for the given
+ :class:`.Dialect`.
+
+ This is the method that fulfills the :class:`.TypeEngine`
+ contract for bound value conversion which normally occurs via
+ the :meth:`_types.TypeEngine.result_processor` method.
+
+ .. note::
+
+ User-defined subclasses of :class:`_types.TypeDecorator` should
+ **not** implement this method, and should instead implement
+ :meth:`_types.TypeDecorator.process_result_value` so that the
+ "inner" processing provided by the implementing type is maintained.
+
+ :param dialect: Dialect instance in use.
+ :param coltype: A SQLAlchemy data type
+
+ """
+ if self._has_result_processor:
+ process_value = self.process_result_value
+ impl_processor = self.impl.result_processor(dialect, coltype)
+ if impl_processor:
+
+ def process(value):
+ return process_value(impl_processor(value), dialect)
+
+ else:
+
+ def process(value):
+ return process_value(value, dialect)
+
+ return process
+ else:
+ return self.impl.result_processor(dialect, coltype)
+
+ @util.memoized_property
+ def _has_bind_expression(self):
+
+ return (
+ util.method_is_overridden(self, TypeDecorator.bind_expression)
+ or self.impl._has_bind_expression
+ )
+
+ def bind_expression(self, bindparam):
+ """Given a bind value (i.e. a :class:`.BindParameter` instance),
+ return a SQL expression which will typically wrap the given parameter.
+
+ .. note::
+
+ This method is called during the **SQL compilation** phase of a
+ statement, when rendering a SQL string. It is **not** necessarily
+ called against specific values, and should not be confused with the
+ :meth:`_types.TypeDecorator.process_bind_param` method, which is
+ the more typical method that processes the actual value passed to a
+ particular parameter at statement execution time.
+
+ Subclasses of :class:`_types.TypeDecorator` can override this method
+ to provide custom bind expression behavior for the type. This
+ implementation will **replace** that of the underlying implementation
+ type.
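+
+ E.g., a rough sketch that wraps bound parameters in a SQL function
+ (``LowerString`` is a hypothetical name; ``func.lower`` is used purely
+ for illustration)::
+
+     from sqlalchemy import func
+     from sqlalchemy.types import String, TypeDecorator
+
+     class LowerString(TypeDecorator):
+         impl = String
+         cache_ok = True
+
+         def bind_expression(self, bindvalue):
+             return func.lower(bindvalue)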
+
+ """
+ return self.impl.bind_expression(bindparam)
+
+ @util.memoized_property
+ def _has_column_expression(self):
+ """memoized boolean, check if column_expression is implemented.
+
+ Allows the method to be skipped for the vast majority of expression
+ types that don't use this feature.
+
+ """
+
+ return (
+ util.method_is_overridden(self, TypeDecorator.column_expression)
+ or self.impl._has_column_expression
+ )
+
+ def column_expression(self, column):
+ """Given a SELECT column expression, return a wrapping SQL expression.
+
+ .. note::
+
+ This method is called during the **SQL compilation** phase of a
+ statement, when rendering a SQL string. It is **not** called
+ against specific values, and should not be confused with the
+ :meth:`_types.TypeDecorator.process_result_value` method, which is
+ the more typical method that processes the actual value returned
+ in a result row subsequent to statement execution time.
+
+ Subclasses of :class:`_types.TypeDecorator` can override this method
+ to provide custom column expression behavior for the type. This
+ implementation will **replace** that of the underlying implementation
+ type.
+
+ See the description of :meth:`_types.TypeEngine.column_expression`
+ for a complete description of the method's use.
+
+ """
+
+ return self.impl.column_expression(column)
+
+ def coerce_compared_value(self, op, value):
+ """Suggest a type for a 'coerced' Python value in an expression.
+
+ By default, returns self. This method is called by
+ the expression system when an object using this type is
+ on the left or right side of an expression against a plain Python
+ object which does not yet have a SQLAlchemy type assigned::
+
+ expr = table.c.somecolumn + 35
+
+ Where above, if ``somecolumn`` uses this type, this method will
+ be called with the operator ``operator.add``
+ and the value ``35``. The return value is whatever SQLAlchemy type should
+ be used for ``35`` for this particular operation.
+
+ """
+ return self
+
+ def copy(self, **kw):
+ """Produce a copy of this :class:`.TypeDecorator` instance.
+
+ This is a shallow copy and is provided to fulfill part of
+ the :class:`.TypeEngine` contract. It usually does not
+ need to be overridden unless the user-defined :class:`.TypeDecorator`
+ has local state that should be deep-copied.
+
+ """
+
+ instance = self.__class__.__new__(self.__class__)
+ instance.__dict__.update(self.__dict__)
+ return instance
+
+ def get_dbapi_type(self, dbapi):
+ """Return the DBAPI type object represented by this
+ :class:`.TypeDecorator`.
+
+ By default this calls upon :meth:`.TypeEngine.get_dbapi_type` of the
+ underlying "impl".
+ """
+ return self.impl.get_dbapi_type(dbapi)
+
+ def compare_values(self, x, y):
+ """Given two values, compare them for equality.
+
+ By default this calls upon :meth:`.TypeEngine.compare_values`
+ of the underlying "impl", which in turn usually
+ uses the Python equals operator ``==``.
+
+ This function is used by the ORM to compare
+ an original-loaded value with an intercepted
+ "changed" value, to determine if a net change
+ has occurred.
+
+ """
+ return self.impl.compare_values(x, y)
+
+ @property
+ def sort_key_function(self):
+ return self.impl.sort_key_function
+
+ def __repr__(self):
+ return util.generic_repr(self, to_inspect=self.impl)
+
+
+class Variant(TypeDecorator):
+ """A wrapping type that selects among a variety of
+ implementations based on dialect in use.
+
+ The :class:`.Variant` type is typically constructed
+ using the :meth:`.TypeEngine.with_variant` method.
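+
+ E.g., a minimal sketch, assuming a hypothetical table ``my_table``::
+
+     from sqlalchemy import Column, Integer, MetaData, String, Table
+
+     metadata_obj = MetaData()
+
+     my_table = Table(
+         "my_table",
+         metadata_obj,
+         Column("id", Integer, primary_key=True),
+         # VARCHAR(255) on MySQL, VARCHAR(50) on all other backends
+         Column("data", String(50).with_variant(String(255), "mysql")),
+     )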
+
+ .. seealso:: :meth:`.TypeEngine.with_variant` for an example of use.
+
+ """
+
+ cache_ok = True
+
+ def __init__(self, base, mapping):
+ """Construct a new :class:`.Variant`.
+
+ :param base: the base 'fallback' type
+ :param mapping: dictionary of string dialect names to
+ :class:`.TypeEngine` instances.
+
+ """
+ self.impl = base
+ self.mapping = mapping
+
+ @util.memoized_property
+ def _static_cache_key(self):
+ # TODO: needs tests in test/sql/test_compare.py
+ return (self.__class__,) + (
+ self.impl._static_cache_key,
+ tuple(
+ (key, self.mapping[key]._static_cache_key)
+ for key in sorted(self.mapping)
+ ),
+ )
+
+ def coerce_compared_value(self, operator, value):
+ result = self.impl.coerce_compared_value(operator, value)
+ if result is self.impl:
+ return self
+ else:
+ return result
+
+ def load_dialect_impl(self, dialect):
+ if dialect.name in self.mapping:
+ return self.mapping[dialect.name]
+ else:
+ return self.impl
+
+ def _set_parent(self, column, outer=False, **kw):
+ """Support SchemaEventTarget"""
+
+ if isinstance(self.impl, SchemaEventTarget):
+ self.impl._set_parent(column, **kw)
+ for impl in self.mapping.values():
+ if isinstance(impl, SchemaEventTarget):
+ impl._set_parent(column, **kw)
+
+ def _set_parent_with_dispatch(self, parent):
+ """Support SchemaEventTarget"""
+
+ if isinstance(self.impl, SchemaEventTarget):
+ self.impl._set_parent_with_dispatch(parent)
+ for impl in self.mapping.values():
+ if isinstance(impl, SchemaEventTarget):
+ impl._set_parent_with_dispatch(parent)
+
+ def with_variant(self, type_, dialect_name):
+ r"""Return a new :class:`.Variant` which adds the given
+ type + dialect name to the mapping, in addition to the
+ mapping present in this :class:`.Variant`.
+
+ :param type\_: a :class:`.TypeEngine` that will be selected
+ as a variant from the originating type, when a dialect
+ of the given name is in use.
+ :param dialect_name: base name of the dialect which uses
+ this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.)
+
+ """
+
+ if dialect_name in self.mapping:
+ raise exc.ArgumentError(
+ "Dialect '%s' is already present in "
+ "the mapping for this Variant" % dialect_name
+ )
+ mapping = self.mapping.copy()
+ mapping[dialect_name] = type_
+ return Variant(self.impl, mapping)
+
+ @property
+ def comparator_factory(self):
+ """express comparison behavior in terms of the base type"""
+ return self.impl.comparator_factory
+
+
+def _reconstitute_comparator(expression):
+ return expression.comparator
+
+
+def to_instance(typeobj, *arg, **kw):
+ if typeobj is None:
+ return NULLTYPE
+
+ if callable(typeobj):
+ return typeobj(*arg, **kw)
+ else:
+ return typeobj
+
+
+def adapt_type(typeobj, colspecs):
+ if isinstance(typeobj, type):
+ typeobj = typeobj()
+ for t in typeobj.__class__.__mro__[0:-1]:
+ try:
+ impltype = colspecs[t]
+ break
+ except KeyError:
+ pass
+ else:
+ # couldn't adapt - so just return the type itself
+ # (it may be a user-defined type)
+ return typeobj
+ # if we adapted the given generic type to a database-specific type,
+ # but it turns out the originally given "generic" type
+ # is actually a subclass of our resulting type, then we were already
+ # given a more specific type than that required; so use that.
+ if issubclass(typeobj.__class__, impltype):
+ return typeobj
+ return typeobj.adapt(impltype)
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
new file mode 100644
index 0000000..019b29e
--- /dev/null
+++ b/lib/sqlalchemy/sql/util.py
@@ -0,0 +1,1120 @@
+# sql/util.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""High level utilities which build upon other modules here.
+
+"""
+
+from collections import deque
+from itertools import chain
+
+from . import coercions
+from . import operators
+from . import roles
+from . import visitors
+from .annotation import _deep_annotate # noqa
+from .annotation import _deep_deannotate # noqa
+from .annotation import _shallow_annotate # noqa
+from .base import _expand_cloned
+from .base import _from_objects
+from .base import ColumnSet
+from .ddl import sort_tables # noqa
+from .elements import _find_columns # noqa
+from .elements import _label_reference
+from .elements import _textual_label_reference
+from .elements import BindParameter
+from .elements import ColumnClause
+from .elements import ColumnElement
+from .elements import Grouping
+from .elements import Label
+from .elements import Null
+from .elements import UnaryExpression
+from .schema import Column
+from .selectable import Alias
+from .selectable import FromClause
+from .selectable import FromGrouping
+from .selectable import Join
+from .selectable import ScalarSelect
+from .selectable import SelectBase
+from .selectable import TableClause
+from .traversals import HasCacheKey # noqa
+from .. import exc
+from .. import util
+
+
+join_condition = util.langhelpers.public_factory(
+ Join._join_condition, ".sql.util.join_condition"
+)
+
+
+def find_join_source(clauses, join_to):
+ """Given a list of FROM clauses and a selectable,
+ return the indexes from the list of
+ clauses which can be joined against the selectable; returns
+ an empty list if no match is found.
+
+ e.g.::
+
+ clause1 = table1.join(table2)
+ clause2 = table4.join(table5)
+
+ join_to = table2.join(table3)
+
+ find_join_source([clause1, clause2], join_to) == [0]
+
+ """
+
+ selectables = list(_from_objects(join_to))
+ idx = []
+ for i, f in enumerate(clauses):
+ for s in selectables:
+ if f.is_derived_from(s):
+ idx.append(i)
+ return idx
+
+
+def find_left_clause_that_matches_given(clauses, join_from):
+ """Given a list of FROM clauses and a selectable,
+ return the indexes from the list of
+ clauses which are derived from the selectable.
+
+ """
+
+ selectables = list(_from_objects(join_from))
+ liberal_idx = []
+ for i, f in enumerate(clauses):
+ for s in selectables:
+ # basic check, if f is derived from s.
+ # this can be joins containing a table, or an aliased table
+ # or select statement matching to a table. This check
+ # will match a table to a selectable that is adapted from
+ # that table. With Query, this suits the case where a join
+ # is being made to an adapted entity
+ if f.is_derived_from(s):
+ liberal_idx.append(i)
+ break
+
+ # in an extremely small set of use cases, a join is being made where
+ # there are multiple FROM clauses where our target table is represented
+ # in more than one, such as embedded or similar. in this case, do
+ # another pass where we try to get a more exact match where we aren't
+ # looking at adaption relationships.
+ if len(liberal_idx) > 1:
+ conservative_idx = []
+ for idx in liberal_idx:
+ f = clauses[idx]
+ for s in selectables:
+ if set(surface_selectables(f)).intersection(
+ surface_selectables(s)
+ ):
+ conservative_idx.append(idx)
+ break
+ if conservative_idx:
+ return conservative_idx
+
+ return liberal_idx
+
+
+def find_left_clause_to_join_from(clauses, join_to, onclause):
+ """Given a list of FROM clauses, a selectable,
+ and optional ON clause, return a list of integer indexes from the
+ clauses list indicating the clauses that can be joined from.
+
+ The presence of an "onclause" indicates that at least one clause can
+ definitely be joined from; if the list of clauses is of length one
+ and the onclause is given, returns that index. If the list of clauses
+ is more than length one, and the onclause is given, attempts to locate
+ which clauses contain the same columns.
+
+ """
+ idx = []
+ selectables = set(_from_objects(join_to))
+
+ # if we are given more than one target clause to join
+ # from, use the onclause to provide a more specific answer.
+ # otherwise, don't try to limit, after all, "ON TRUE" is a valid
+ # on clause
+ if len(clauses) > 1 and onclause is not None:
+ resolve_ambiguity = True
+ cols_in_onclause = _find_columns(onclause)
+ else:
+ resolve_ambiguity = False
+ cols_in_onclause = None
+
+ for i, f in enumerate(clauses):
+ for s in selectables.difference([f]):
+ if resolve_ambiguity:
+ if set(f.c).union(s.c).issuperset(cols_in_onclause):
+ idx.append(i)
+ break
+ elif onclause is not None or Join._can_join(f, s):
+ idx.append(i)
+ break
+
+ if len(idx) > 1:
+ # this is the same "hide froms" logic from
+ # Selectable._get_display_froms
+ toremove = set(
+ chain(*[_expand_cloned(f._hide_froms) for f in clauses])
+ )
+ idx = [i for i in idx if clauses[i] not in toremove]
+
+ # onclause was given and none of them resolved, so assume
+ # all indexes can match
+ if not idx and onclause is not None:
+ return range(len(clauses))
+ else:
+ return idx
+
+
+def visit_binary_product(fn, expr):
+ """Produce a traversal of the given expression, delivering
+ column comparisons to the given function.
+
+ The function is of the form::
+
+ def my_fn(binary, left, right)
+
+ For each binary expression located which has a
+ comparison operator, the product of "left" and
+ "right" will be delivered to that function,
+ in terms of that binary.
+
+ Hence an expression like::
+
+ and_(
+ (a + b) == q + func.sum(e + f),
+ j == r
+ )
+
+ would have the traversal::
+
+ a <eq> q
+ a <eq> e
+ a <eq> f
+ b <eq> q
+ b <eq> e
+ b <eq> f
+ j <eq> r
+
+ That is, every combination of "left" and
+ "right" that doesn't further contain
+ a binary comparison is passed as pairs.
+
+ """
+ stack = []
+
+ def visit(element):
+ if isinstance(element, ScalarSelect):
+ # we don't want to dig into correlated subqueries,
+ # those are just column elements by themselves
+ yield element
+ elif element.__visit_name__ == "binary" and operators.is_comparison(
+ element.operator
+ ):
+ stack.insert(0, element)
+ for l in visit(element.left):
+ for r in visit(element.right):
+ fn(stack[0], l, r)
+ stack.pop(0)
+ for elem in element.get_children():
+ visit(elem)
+ else:
+ if isinstance(element, ColumnClause):
+ yield element
+ for elem in element.get_children():
+ for e in visit(elem):
+ yield e
+
+ list(visit(expr))
+ visit = None # remove gc cycles
+
+
+def find_tables(
+ clause,
+ check_columns=False,
+ include_aliases=False,
+ include_joins=False,
+ include_selects=False,
+ include_crud=False,
+):
+ """locate Table objects within the given expression."""
+
+ tables = []
+ _visitors = {}
+
+ if include_selects:
+ _visitors["select"] = _visitors["compound_select"] = tables.append
+
+ if include_joins:
+ _visitors["join"] = tables.append
+
+ if include_aliases:
+ _visitors["alias"] = _visitors["subquery"] = _visitors[
+ "tablesample"
+ ] = _visitors["lateral"] = tables.append
+
+ if include_crud:
+ _visitors["insert"] = _visitors["update"] = _visitors[
+ "delete"
+ ] = lambda ent: tables.append(ent.table)
+
+ if check_columns:
+
+ def visit_column(column):
+ tables.append(column.table)
+
+ _visitors["column"] = visit_column
+
+ _visitors["table"] = tables.append
+
+ visitors.traverse(clause, {}, _visitors)
+ return tables
+
+
+def unwrap_order_by(clause):
+ """Break up an 'order by' expression into individual column-expressions,
+ without DESC/ASC/NULLS FIRST/NULLS LAST"""
+
+ cols = util.column_set()
+ result = []
+ stack = deque([clause])
+
+ # examples
+ # column -> ASC/DESC == column
+ # column -> ASC/DESC -> label == column
+ # column -> label -> ASC/DESC -> label == column
+ # scalar_select -> label -> ASC/DESC == scalar_select -> label
+
+ while stack:
+ t = stack.popleft()
+ if isinstance(t, ColumnElement) and (
+ not isinstance(t, UnaryExpression)
+ or not operators.is_ordering_modifier(t.modifier)
+ ):
+ if isinstance(t, Label) and not isinstance(
+ t.element, ScalarSelect
+ ):
+ t = t.element
+
+ if isinstance(t, Grouping):
+ t = t.element
+
+ stack.append(t)
+ continue
+ elif isinstance(t, _label_reference):
+ t = t.element
+
+ stack.append(t)
+ continue
+ if isinstance(t, (_textual_label_reference)):
+ continue
+ if t not in cols:
+ cols.add(t)
+ result.append(t)
+
+ else:
+ for c in t.get_children():
+ stack.append(c)
+ return result
+
+
+def unwrap_label_reference(element):
+ def replace(elem):
+ if isinstance(elem, (_label_reference, _textual_label_reference)):
+ return elem.element
+
+ return visitors.replacement_traverse(element, {}, replace)
+
+
+def expand_column_list_from_order_by(collist, order_by):
+ """Given the columns clause and ORDER BY of a selectable,
+ return a list of column expressions that can be added to the collist
+ corresponding to the ORDER BY, without repeating those already
+ in the collist.
+
+ """
+ cols_already_present = set(
+ [
+ col.element if col._order_by_label_element is not None else col
+ for col in collist
+ ]
+ )
+
+ to_look_for = list(chain(*[unwrap_order_by(o) for o in order_by]))
+
+ return [col for col in to_look_for if col not in cols_already_present]
+
+
+def clause_is_present(clause, search):
+ """Given a target clause and a second to search within, return True
+ if the target is plainly present in the search without any
+ subqueries or aliases involved.
+
+ Basically descends through Joins.
+
+ """
+
+ for elem in surface_selectables(search):
+ if clause == elem: # use == here so that Annotated's compare
+ return True
+ else:
+ return False
+
+
+def tables_from_leftmost(clause):
+ if isinstance(clause, Join):
+ for t in tables_from_leftmost(clause.left):
+ yield t
+ for t in tables_from_leftmost(clause.right):
+ yield t
+ elif isinstance(clause, FromGrouping):
+ for t in tables_from_leftmost(clause.element):
+ yield t
+ else:
+ yield clause
+
+
+def surface_selectables(clause):
+ stack = [clause]
+ while stack:
+ elem = stack.pop()
+ yield elem
+ if isinstance(elem, Join):
+ stack.extend((elem.left, elem.right))
+ elif isinstance(elem, FromGrouping):
+ stack.append(elem.element)
+
+
+def surface_selectables_only(clause):
+ stack = [clause]
+ while stack:
+ elem = stack.pop()
+ if isinstance(elem, (TableClause, Alias)):
+ yield elem
+ if isinstance(elem, Join):
+ stack.extend((elem.left, elem.right))
+ elif isinstance(elem, FromGrouping):
+ stack.append(elem.element)
+ elif isinstance(elem, ColumnClause):
+ if elem.table is not None:
+ stack.append(elem.table)
+ else:
+ yield elem
+ elif elem is not None:
+ yield elem
+
+
+def extract_first_column_annotation(column, annotation_name):
+ filter_ = (FromGrouping, SelectBase)
+
+ stack = deque([column])
+ while stack:
+ elem = stack.popleft()
+ if annotation_name in elem._annotations:
+ return elem._annotations[annotation_name]
+ for sub in elem.get_children():
+ if isinstance(sub, filter_):
+ continue
+ stack.append(sub)
+ return None
+
+
+def selectables_overlap(left, right):
+ """Return True if left/right have some overlapping selectable"""
+
+ return bool(
+ set(surface_selectables(left)).intersection(surface_selectables(right))
+ )
+
+
+def bind_values(clause):
+ """Return an ordered list of "bound" values in the given clause.
+
+ E.g.::
+
+ >>> expr = and_(
+ ... table.c.foo==5, table.c.foo==7
+ ... )
+ >>> bind_values(expr)
+ [5, 7]
+ """
+
+ v = []
+
+ def visit_bindparam(bind):
+ v.append(bind.effective_value)
+
+ visitors.traverse(clause, {}, {"bindparam": visit_bindparam})
+ return v
+
+
+def _quote_ddl_expr(element):
+ if isinstance(element, util.string_types):
+ element = element.replace("'", "''")
+ return "'%s'" % element
+ else:
+ return repr(element)
+
+
+class _repr_base(object):
+ _LIST = 0
+ _TUPLE = 1
+ _DICT = 2
+
+ __slots__ = ("max_chars",)
+
+ def trunc(self, value):
+ rep = repr(value)
+ lenrep = len(rep)
+ if lenrep > self.max_chars:
+ segment_length = self.max_chars // 2
+ rep = (
+ rep[0:segment_length]
+ + (
+ " ... (%d characters truncated) ... "
+ % (lenrep - self.max_chars)
+ )
+ + rep[-segment_length:]
+ )
+ return rep
+
+
+class _repr_row(_repr_base):
+ """Provide a string view of a row."""
+
+ __slots__ = ("row",)
+
+ def __init__(self, row, max_chars=300):
+ self.row = row
+ self.max_chars = max_chars
+
+ def __repr__(self):
+ trunc = self.trunc
+ return "(%s%s)" % (
+ ", ".join(trunc(value) for value in self.row),
+ "," if len(self.row) == 1 else "",
+ )
+
+
+class _repr_params(_repr_base):
+ """Provide a string view of bound parameters.
+
+ Truncates display to a given number of 'multi' parameter sets,
+ as well as long values to a given number of characters.
+
+ """
+
+ __slots__ = "params", "batches", "ismulti"
+
+ def __init__(self, params, batches, max_chars=300, ismulti=None):
+ self.params = params
+ self.ismulti = ismulti
+ self.batches = batches
+ self.max_chars = max_chars
+
+ def __repr__(self):
+ if self.ismulti is None:
+ return self.trunc(self.params)
+
+ if isinstance(self.params, list):
+ typ = self._LIST
+
+ elif isinstance(self.params, tuple):
+ typ = self._TUPLE
+ elif isinstance(self.params, dict):
+ typ = self._DICT
+ else:
+ return self.trunc(self.params)
+
+ if self.ismulti and len(self.params) > self.batches:
+ msg = " ... displaying %i of %i total bound parameter sets ... "
+ return " ".join(
+ (
+ self._repr_multi(self.params[: self.batches - 2], typ)[
+ 0:-1
+ ],
+ msg % (self.batches, len(self.params)),
+ self._repr_multi(self.params[-2:], typ)[1:],
+ )
+ )
+ elif self.ismulti:
+ return self._repr_multi(self.params, typ)
+ else:
+ return self._repr_params(self.params, typ)
+
+ def _repr_multi(self, multi_params, typ):
+ if multi_params:
+ if isinstance(multi_params[0], list):
+ elem_type = self._LIST
+ elif isinstance(multi_params[0], tuple):
+ elem_type = self._TUPLE
+ elif isinstance(multi_params[0], dict):
+ elem_type = self._DICT
+ else:
+ assert False, "Unknown parameter type %s" % (
+ type(multi_params[0])
+ )
+
+ elements = ", ".join(
+ self._repr_params(params, elem_type) for params in multi_params
+ )
+ else:
+ elements = ""
+
+ if typ == self._LIST:
+ return "[%s]" % elements
+ else:
+ return "(%s)" % elements
+
+ def _repr_params(self, params, typ):
+ trunc = self.trunc
+ if typ is self._DICT:
+ return "{%s}" % (
+ ", ".join(
+ "%r: %s" % (key, trunc(value))
+ for key, value in params.items()
+ )
+ )
+ elif typ is self._TUPLE:
+ return "(%s%s)" % (
+ ", ".join(trunc(value) for value in params),
+ "," if len(params) == 1 else "",
+ )
+ else:
+ return "[%s]" % (", ".join(trunc(value) for value in params))
+
+
+def adapt_criterion_to_null(crit, nulls):
+ """given criterion containing bind params, convert selected elements
+ to IS NULL.
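+
+ E.g., a rough sketch, assuming a hypothetical table ``my_table``::
+
+     from sqlalchemy import bindparam
+
+     crit = my_table.c.x == bindparam("x")
+
+     # rewrites the comparison as "my_table.x IS NULL"
+     crit = adapt_criterion_to_null(crit, {"x"})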
+
+ """
+
+ def visit_binary(binary):
+ if (
+ isinstance(binary.left, BindParameter)
+ and binary.left._identifying_key in nulls
+ ):
+ # reverse order if the NULL is on the left side
+ binary.left = binary.right
+ binary.right = Null()
+ binary.operator = operators.is_
+ binary.negate = operators.is_not
+ elif (
+ isinstance(binary.right, BindParameter)
+ and binary.right._identifying_key in nulls
+ ):
+ binary.right = Null()
+ binary.operator = operators.is_
+ binary.negate = operators.is_not
+
+ return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
+
+
+def splice_joins(left, right, stop_on=None):
+ if left is None:
+ return right
+
+ stack = [(right, None)]
+
+ adapter = ClauseAdapter(left)
+ ret = None
+ while stack:
+ (right, prevright) = stack.pop()
+ if isinstance(right, Join) and right is not stop_on:
+ right = right._clone()
+ right.onclause = adapter.traverse(right.onclause)
+ stack.append((right.left, right))
+ else:
+ right = adapter.traverse(right)
+ if prevright is not None:
+ prevright.left = right
+ if ret is None:
+ ret = right
+
+ return ret
+
+
+def reduce_columns(columns, *clauses, **kw):
+ r"""given a list of columns, return a 'reduced' set based on natural
+ equivalents.
+
+ the set is reduced to the smallest list of columns which have no natural
+ equivalent present in the list. A "natural equivalent" means that two
+ columns will ultimately represent the same value because they are related
+ by a foreign key.
+
+ \*clauses is an optional list of join clauses which will be traversed
+ to further identify columns that are "equivalent".
+
+ \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
+ whose tables are not yet configured, or columns that aren't yet present.
+
+ This function is primarily used to determine the most minimal "primary
+ key" from a selectable, by reducing the set of primary key columns present
+ in the selectable to just those that are not repeated.
+
+ """
+ ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
+ only_synonyms = kw.pop("only_synonyms", False)
+
+ columns = util.ordered_column_set(columns)
+
+ omit = util.column_set()
+ for col in columns:
+ for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
+ for c in columns:
+ if c is col:
+ continue
+ try:
+ fk_col = fk.column
+ except exc.NoReferencedColumnError:
+ # TODO: add specific coverage here
+ # to test/sql/test_selectable ReduceTest
+ if ignore_nonexistent_tables:
+ continue
+ else:
+ raise
+ except exc.NoReferencedTableError:
+ # TODO: add specific coverage here
+ # to test/sql/test_selectable ReduceTest
+ if ignore_nonexistent_tables:
+ continue
+ else:
+ raise
+ if fk_col.shares_lineage(c) and (
+ not only_synonyms or c.name == col.name
+ ):
+ omit.add(col)
+ break
+
+ if clauses:
+
+ def visit_binary(binary):
+ if binary.operator == operators.eq:
+ cols = util.column_set(
+ chain(*[c.proxy_set for c in columns.difference(omit)])
+ )
+ if binary.left in cols and binary.right in cols:
+ for c in reversed(columns):
+ if c.shares_lineage(binary.right) and (
+ not only_synonyms or c.name == binary.left.name
+ ):
+ omit.add(c)
+ break
+
+ for clause in clauses:
+ if clause is not None:
+ visitors.traverse(clause, {}, {"binary": visit_binary})
+
+ return ColumnSet(columns.difference(omit))
+
+
+def criterion_as_pairs(
+ expression,
+ consider_as_foreign_keys=None,
+ consider_as_referenced_keys=None,
+ any_operator=False,
+):
+ """traverse an expression and locate binary criterion pairs."""
+
+ if consider_as_foreign_keys and consider_as_referenced_keys:
+ raise exc.ArgumentError(
+ "Can only specify one of "
+ "'consider_as_foreign_keys' or "
+ "'consider_as_referenced_keys'"
+ )
+
+ def col_is(a, b):
+ # return a is b
+ return a.compare(b)
+
+ def visit_binary(binary):
+ if not any_operator and binary.operator is not operators.eq:
+ return
+ if not isinstance(binary.left, ColumnElement) or not isinstance(
+ binary.right, ColumnElement
+ ):
+ return
+
+ if consider_as_foreign_keys:
+ if binary.left in consider_as_foreign_keys and (
+ col_is(binary.right, binary.left)
+ or binary.right not in consider_as_foreign_keys
+ ):
+ pairs.append((binary.right, binary.left))
+ elif binary.right in consider_as_foreign_keys and (
+ col_is(binary.left, binary.right)
+ or binary.left not in consider_as_foreign_keys
+ ):
+ pairs.append((binary.left, binary.right))
+ elif consider_as_referenced_keys:
+ if binary.left in consider_as_referenced_keys and (
+ col_is(binary.right, binary.left)
+ or binary.right not in consider_as_referenced_keys
+ ):
+ pairs.append((binary.left, binary.right))
+ elif binary.right in consider_as_referenced_keys and (
+ col_is(binary.left, binary.right)
+ or binary.left not in consider_as_referenced_keys
+ ):
+ pairs.append((binary.right, binary.left))
+ else:
+ if isinstance(binary.left, Column) and isinstance(
+ binary.right, Column
+ ):
+ if binary.left.references(binary.right):
+ pairs.append((binary.right, binary.left))
+ elif binary.right.references(binary.left):
+ pairs.append((binary.left, binary.right))
+
+ pairs = []
+ visitors.traverse(expression, {}, {"binary": visit_binary})
+ return pairs
+
+
+class ClauseAdapter(visitors.ReplacingExternalTraversal):
+ """Clones and modifies clauses based on column correspondence.
+
+ E.g.::
+
+ table1 = Table('sometable', metadata,
+ Column('col1', Integer),
+ Column('col2', Integer)
+ )
+ table2 = Table('someothertable', metadata,
+ Column('col1', Integer),
+ Column('col2', Integer)
+ )
+
+ condition = table1.c.col1 == table2.c.col1
+
+ make an alias of table1::
+
+ s = table1.alias('foo')
+
+ calling ``ClauseAdapter(s).traverse(condition)`` converts
+ condition to read::
+
+ s.c.col1 == table2.c.col1
+
+ """
+
+ def __init__(
+ self,
+ selectable,
+ equivalents=None,
+ include_fn=None,
+ exclude_fn=None,
+ adapt_on_names=False,
+ anonymize_labels=False,
+ adapt_from_selectables=None,
+ ):
+ self.__traverse_options__ = {
+ "stop_on": [selectable],
+ "anonymize_labels": anonymize_labels,
+ }
+ self.selectable = selectable
+ self.include_fn = include_fn
+ self.exclude_fn = exclude_fn
+ self.equivalents = util.column_dict(equivalents or {})
+ self.adapt_on_names = adapt_on_names
+ self.adapt_from_selectables = adapt_from_selectables
+
+ def _corresponding_column(
+ self, col, require_embedded, _seen=util.EMPTY_SET
+ ):
+
+ newcol = self.selectable.corresponding_column(
+ col, require_embedded=require_embedded
+ )
+ if newcol is None and col in self.equivalents and col not in _seen:
+ for equiv in self.equivalents[col]:
+ newcol = self._corresponding_column(
+ equiv,
+ require_embedded=require_embedded,
+ _seen=_seen.union([col]),
+ )
+ if newcol is not None:
+ return newcol
+ if self.adapt_on_names and newcol is None:
+ newcol = self.selectable.exported_columns.get(col.name)
+ return newcol
+
+ @util.preload_module("sqlalchemy.sql.functions")
+ def replace(self, col, _include_singleton_constants=False):
+ functions = util.preloaded.sql_functions
+
+ if isinstance(col, FromClause) and not isinstance(
+ col, functions.FunctionElement
+ ):
+
+ if self.selectable.is_derived_from(col):
+ if self.adapt_from_selectables:
+ for adp in self.adapt_from_selectables:
+ if adp.is_derived_from(col):
+ break
+ else:
+ return None
+ return self.selectable
+ elif isinstance(col, Alias) and isinstance(
+ col.element, TableClause
+ ):
+ # we are a SELECT statement and not derived from an alias of a
+ # table (which nonetheless may be a table our SELECT derives
+ # from), so return the alias to prevent further traversal
+ # or
+ # we are an alias of a table and we are not derived from an
+ # alias of a table (which nonetheless may be the same table
+ # as ours) so, same thing
+ return col
+ else:
+ # other cases where we are a selectable and the element
+ # is another join or selectable that contains a table which our
+ # selectable derives from, that we want to process
+ return None
+
+ elif not isinstance(col, ColumnElement):
+ return None
+ elif not _include_singleton_constants and col._is_singleton_constant:
+ # don't swap out NULL, TRUE, FALSE for a label name
+ # in a SQL statement that's being rewritten,
+ # leave them as the constant. This is first noted in #6259,
+ # however the logic to check this moved here as of #7154 so that
+ # it is made specific to SQL rewriting and not all column
+ # correspondence
+ return None
+
+ if "adapt_column" in col._annotations:
+ col = col._annotations["adapt_column"]
+
+ if self.adapt_from_selectables and col not in self.equivalents:
+ for adp in self.adapt_from_selectables:
+ if adp.c.corresponding_column(col, False) is not None:
+ break
+ else:
+ return None
+
+ if self.include_fn and not self.include_fn(col):
+ return None
+ elif self.exclude_fn and self.exclude_fn(col):
+ return None
+ else:
+ return self._corresponding_column(col, True)
+
+
+class ColumnAdapter(ClauseAdapter):
+ """Extends ClauseAdapter with extra utility functions.
+
+ Key aspects of ColumnAdapter include:
+
+ * Expressions that are adapted are stored in a persistent
+ .columns collection, so that an expression E adapted into
+ an expression E1 will return the same object E1 when adapted
+ a second time. This is important in particular for things like
+ Label objects that are anonymized, so that the ColumnAdapter can
+ be used to present a consistent "adapted" view of things.
+
+ * Exclusion of items from the persistent collection based on
+ include/exclude rules, but also independent of hash identity.
+ This is because "annotated" items all have the same hash identity as
+ their parent.
+
+ * "wrapping" capability is added, so that the replacement of an expression
+ E can proceed through a series of adapters. This differs from the
+ visitor's "chaining" feature in that the resulting object is passed
+ through all replacing functions unconditionally, rather than stopping
+ at the first one that returns non-None.
+
+ * An adapt_required option, used by eager loading to indicate that
+ we don't trust a result row column that is not translated.
+ This is to prevent a column from being interpreted as that
+ of the child row in a self-referential scenario, see
+ inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency
+
+ """
+
+ def __init__(
+ self,
+ selectable,
+ equivalents=None,
+ adapt_required=False,
+ include_fn=None,
+ exclude_fn=None,
+ adapt_on_names=False,
+ allow_label_resolve=True,
+ anonymize_labels=False,
+ adapt_from_selectables=None,
+ ):
+ ClauseAdapter.__init__(
+ self,
+ selectable,
+ equivalents,
+ include_fn=include_fn,
+ exclude_fn=exclude_fn,
+ adapt_on_names=adapt_on_names,
+ anonymize_labels=anonymize_labels,
+ adapt_from_selectables=adapt_from_selectables,
+ )
+
+ self.columns = util.WeakPopulateDict(self._locate_col)
+ if self.include_fn or self.exclude_fn:
+ self.columns = self._IncludeExcludeMapping(self, self.columns)
+ self.adapt_required = adapt_required
+ self.allow_label_resolve = allow_label_resolve
+ self._wrap = None
+
+ class _IncludeExcludeMapping(object):
+ def __init__(self, parent, columns):
+ self.parent = parent
+ self.columns = columns
+
+ def __getitem__(self, key):
+ if (
+ self.parent.include_fn and not self.parent.include_fn(key)
+ ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
+ if self.parent._wrap:
+ return self.parent._wrap.columns[key]
+ else:
+ return key
+ return self.columns[key]
+
+ def wrap(self, adapter):
+ ac = self.__class__.__new__(self.__class__)
+ ac.__dict__.update(self.__dict__)
+ ac._wrap = adapter
+ ac.columns = util.WeakPopulateDict(ac._locate_col)
+ if ac.include_fn or ac.exclude_fn:
+ ac.columns = self._IncludeExcludeMapping(ac, ac.columns)
+
+ return ac
+
+ def traverse(self, obj):
+ return self.columns[obj]
+
+ adapt_clause = traverse
+ adapt_list = ClauseAdapter.copy_and_process
+
+ def adapt_check_present(self, col):
+ newcol = self.columns[col]
+
+ if newcol is col and self._corresponding_column(col, True) is None:
+ return None
+
+ return newcol
+
+ def _locate_col(self, col):
+ # both replace and traverse() are overly complicated for what
+ # we are doing here and we would do better to have an inlined
+ # version that doesn't build up as much overhead. the issue is that
+ # sometimes the lookup does in fact have to adapt the insides of
+ # say a labeled scalar subquery. However, if the object is an
+ # Immutable, i.e. Column objects, we can skip the "clone" /
+ # "copy internals" part since those will be no-ops in any case.
+ # additionally we want to catch singleton objects null/true/false
+ # and make sure they are adapted as well here.
+
+ if col._is_immutable:
+ for vis in self.visitor_iterator:
+ c = vis.replace(col, _include_singleton_constants=True)
+ if c is not None:
+ break
+ else:
+ c = col
+ else:
+ c = ClauseAdapter.traverse(self, col)
+
+ if self._wrap:
+ c2 = self._wrap._locate_col(c)
+ if c2 is not None:
+ c = c2
+
+ if self.adapt_required and c is col:
+ return None
+
+ c._allow_label_resolve = self.allow_label_resolve
+
+ return c
+
+ def __getstate__(self):
+ d = self.__dict__.copy()
+ del d["columns"]
+ return d
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ self.columns = util.WeakPopulateDict(self._locate_col)
+
+
+def _offset_or_limit_clause(element, name=None, type_=None):
+ """Convert the given value to an "offset or limit" clause.
+
+ This handles incoming integers and converts to an expression; if
+ an expression is already given, it is passed through.
+
+ """
+ return coercions.expect(
+ roles.LimitOffsetRole, element, name=name, type_=type_
+ )
+
+
+def _offset_or_limit_clause_asint_if_possible(clause):
+ """Return the offset or limit clause as a simple integer if possible,
+ else return the clause.
+
+ """
+ if clause is None:
+ return None
+ if hasattr(clause, "_limit_offset_value"):
+ value = clause._limit_offset_value
+ return util.asint(value)
+ else:
+ return clause
+
+
+def _make_slice(limit_clause, offset_clause, start, stop):
+ """Compute LIMIT/OFFSET in terms of slice start/end"""
+
+ # for calculated limit/offset, try to do the addition of
+ # values to offset in Python, however if a SQL clause is present
+ # then the addition has to be on the SQL side.
+ if start is not None and stop is not None:
+ offset_clause = _offset_or_limit_clause_asint_if_possible(
+ offset_clause
+ )
+ if offset_clause is None:
+ offset_clause = 0
+
+ if start != 0:
+ offset_clause = offset_clause + start
+
+ if offset_clause == 0:
+ offset_clause = None
+ else:
+ offset_clause = _offset_or_limit_clause(offset_clause)
+
+ limit_clause = _offset_or_limit_clause(stop - start)
+
+ elif start is None and stop is not None:
+ limit_clause = _offset_or_limit_clause(stop)
+ elif start is not None and stop is None:
+ offset_clause = _offset_or_limit_clause_asint_if_possible(
+ offset_clause
+ )
+ if offset_clause is None:
+ offset_clause = 0
+
+ if start != 0:
+ offset_clause = offset_clause + start
+
+ if offset_clause == 0:
+ offset_clause = None
+ else:
+ offset_clause = _offset_or_limit_clause(offset_clause)
+
+ return limit_clause, offset_clause
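+
+
+# A minimal illustrative sketch (not part of the original source) of the
+# slice arithmetic implemented by _make_slice() above: with no pre-existing
+# clauses, a slice [start:stop] becomes LIMIT = stop - start and
+# OFFSET = start, e.g.::
+#
+#     limit_clause, offset_clause = _make_slice(None, None, 10, 20)
+#     # limit_clause wraps the integer 10 (i.e. 20 - 10)
+#     # offset_clause wraps the integer 10
+#
+# If an existing integer offset clause is present, ``start`` is added to it
+# in Python; if the existing offset is a SQL expression, the addition is
+# emitted as SQL instead.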
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
new file mode 100644
index 0000000..f72d83a
--- /dev/null
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -0,0 +1,852 @@
+# sql/visitors.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Visitor/traversal interface and library functions.
+
+SQLAlchemy schema and expression constructs rely on a Python-centric
+version of the classic "visitor" pattern as the primary way in which
+they apply functionality. The most common use of this pattern
+is statement compilation, where individual expression classes match
+up to rendering methods that produce a string result. Beyond this,
+the visitor system is also used to inspect expressions for various
+information and patterns, as well as for the purposes of applying
+transformations to expressions.
+
+Examples of how the visit system is used can be seen in the source code of,
+for example, the ``sqlalchemy.sql.util`` and ``sqlalchemy.sql.compiler``
+modules.  Some background on clause adaptation is also at
+https://techspot.zzzeek.org/2008/01/23/expression-transformations/ .
+
+"""
+
+from collections import deque
+import itertools
+import operator
+
+from .. import exc
+from .. import util
+from ..util import langhelpers
+from ..util import symbol
+
+__all__ = [
+ "iterate",
+ "traverse_using",
+ "traverse",
+ "cloned_traverse",
+ "replacement_traverse",
+ "Traversible",
+ "TraversibleType",
+ "ExternalTraversal",
+ "InternalTraversal",
+]
+
+
+def _generate_compiler_dispatch(cls):
+ """Generate a _compiler_dispatch() external traversal on classes with a
+ __visit_name__ attribute.
+
+ """
+ visit_name = cls.__visit_name__
+
+ if "_compiler_dispatch" in cls.__dict__:
+ # class has a fixed _compiler_dispatch() method.
+ # copy it to "original" so that we can get it back if
+ # sqlalchemy.ext.compiles overrides it.
+ cls._original_compiler_dispatch = cls._compiler_dispatch
+ return
+
+ if not isinstance(visit_name, util.compat.string_types):
+ raise exc.InvalidRequestError(
+ "__visit_name__ on class %s must be a string at the class level"
+ % cls.__name__
+ )
+
+ name = "visit_%s" % visit_name
+ getter = operator.attrgetter(name)
+
+ def _compiler_dispatch(self, visitor, **kw):
+ """Look for an attribute named "visit_<visit_name>" on the
+ visitor, and call it with the same kw params.
+
+ """
+ try:
+ meth = getter(visitor)
+ except AttributeError as err:
+ return visitor.visit_unsupported_compilation(self, err, **kw)
+
+ else:
+ return meth(self, **kw)
+
+ cls._compiler_dispatch = (
+ cls._original_compiler_dispatch
+ ) = _compiler_dispatch
+
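+# Rough sketch (not part of the original source) of what the generated
+# external dispatch amounts to: a class declaring ``__visit_name__`` routes
+# compilation to the correspondingly named method on the visitor, e.g.::
+#
+#     class Widget(Traversible):
+#         __visit_name__ = "widget"
+#
+#     # widget._compiler_dispatch(visitor, **kw) is then equivalent to
+#     # visitor.visit_widget(widget, **kw), falling back to
+#     # visitor.visit_unsupported_compilation() if no such method exists.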
+
+class TraversibleType(type):
+ """Metaclass which assigns dispatch attributes to various kinds of
+ "visitable" classes.
+
+ Attributes include:
+
+ * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``.
+ This is called "external traversal" because the caller of each visit()
+ method is responsible for sub-traversing the inner elements of each
+ object. This is appropriate for string compilers and other traversals
+ that need to call upon the inner elements in a specific pattern.
+
+    * internal traversal collections ``_children_traversal``,
+      ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from
+      an optional ``_traverse_internals`` collection of symbols which comes
+      from the :class:`.InternalTraversal` list of symbols.  This is called
+      "internal traversal" because the object itself drives traversal of its
+      own internal state, providing for operations such as cache key
+      generation, ``_copy_internals()``, and ``get_children()``.
+
+ """
+
+ def __init__(cls, clsname, bases, clsdict):
+ if clsname != "Traversible":
+ if "__visit_name__" in clsdict:
+ _generate_compiler_dispatch(cls)
+
+ super(TraversibleType, cls).__init__(clsname, bases, clsdict)
+
+
+class Traversible(util.with_metaclass(TraversibleType)):
+ """Base class for visitable objects, applies the
+ :class:`.visitors.TraversibleType` metaclass.
+
+ """
+
+ def __class_getitem__(cls, key):
+ # allow generic classes in py3.9+
+ return cls
+
+ @util.preload_module("sqlalchemy.sql.traversals")
+ def get_children(self, omit_attrs=(), **kw):
+ r"""Return immediate child :class:`.visitors.Traversible`
+ elements of this :class:`.visitors.Traversible`.
+
+ This is used for visit traversal.
+
+ \**kw may contain flags that change the collection that is
+ returned, for example to return a subset of items in order to
+ cut down on larger traversals, or to return child items from a
+ different context (such as schema-level collections instead of
+ clause-level).
+
+ """
+
+ traversals = util.preloaded.sql_traversals
+
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ # user-defined classes may not have a _traverse_internals
+ return []
+
+ dispatch = traversals._get_children.run_generated_dispatch
+ return itertools.chain.from_iterable(
+ meth(obj, **kw)
+ for attrname, obj, meth in dispatch(
+ self, traverse_internals, "_generated_get_children_traversal"
+ )
+ if attrname not in omit_attrs and obj is not None
+ )
+
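+    # A minimal usage sketch (not part of the original source): for a Core
+    # statement, get_children() yields the immediate sub-elements only; the
+    # visitors.iterate() function below applies it repeatedly to walk the
+    # whole tree::
+    #
+    #     from sqlalchemy import column, select, table
+    #
+    #     t = table("t", column("x"))
+    #     stmt = select(t.c.x).where(t.c.x > 5)
+    #
+    #     for child in stmt.get_children():
+    #         print(type(child).__name__)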
+
+class _InternalTraversalType(type):
+ def __init__(cls, clsname, bases, clsdict):
+ if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"):
+ lookup = {}
+ for key, sym in clsdict.items():
+ if key.startswith("dp_"):
+ visit_key = key.replace("dp_", "visit_")
+ sym_name = sym.name
+ assert sym_name not in lookup, sym_name
+ lookup[sym] = lookup[sym_name] = visit_key
+ if hasattr(cls, "_dispatch_lookup"):
+ lookup.update(cls._dispatch_lookup)
+ cls._dispatch_lookup = lookup
+
+ super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict)
+
+
+def _generate_dispatcher(visitor, internal_dispatch, method_name):
+ names = []
+ for attrname, visit_sym in internal_dispatch:
+ meth = visitor.dispatch(visit_sym)
+ if meth:
+ visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym]
+ names.append((attrname, visit_name))
+
+ code = (
+ (" return [\n")
+ + (
+ ", \n".join(
+ " (%r, self.%s, visitor.%s)"
+ % (attrname, attrname, visit_name)
+ for attrname, visit_name in names
+ )
+ )
+ + ("\n ]\n")
+ )
+ meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
+ # print(meth_text)
+ return langhelpers._exec_code_in_env(meth_text, {}, method_name)
+
+
+class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
+ r"""Defines visitor symbols used for internal traversal.
+
+ The :class:`.InternalTraversal` class is used in two ways. One is that
+ it can serve as the superclass for an object that implements the
+ various visit methods of the class. The other is that the symbols
+ themselves of :class:`.InternalTraversal` are used within
+ the ``_traverse_internals`` collection. Such as, the :class:`.Case`
+ object defines ``_traverse_internals`` as ::
+
+ _traverse_internals = [
+ ("value", InternalTraversal.dp_clauseelement),
+ ("whens", InternalTraversal.dp_clauseelement_tuples),
+ ("else_", InternalTraversal.dp_clauseelement),
+ ]
+
+ Above, the :class:`.Case` class indicates its internal state as the
+ attributes named ``value``, ``whens``, and ``else_``. They each
+    link to an :class:`.InternalTraversal` symbol which indicates the type
+    of data structure to which each attribute refers.
+
+ Using the ``_traverse_internals`` structure, objects of type
+    :class:`.Traversible` will have the following methods automatically
+ implemented:
+
+ * :meth:`.Traversible.get_children`
+
+ * :meth:`.Traversible._copy_internals`
+
+ * :meth:`.Traversible._gen_cache_key`
+
+ Subclasses can also implement these methods directly, particularly for the
+ :meth:`.Traversible._copy_internals` method, when special steps
+ are needed.
+
+ .. versionadded:: 1.4
+
+ """
+
+ def dispatch(self, visit_symbol):
+ """Given a method from :class:`.InternalTraversal`, return the
+ corresponding method on a subclass.
+
+ """
+ name = self._dispatch_lookup[visit_symbol]
+ return getattr(self, name, None)
+
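+    # For illustration (not part of the original source), dispatch() above
+    # resolves an InternalTraversal ``dp_`` symbol to the correspondingly
+    # named ``visit_`` method on a subclass::
+    #
+    #     class MyVisitor(InternalTraversal):
+    #         def visit_clauseelement(self, *arg, **kw):
+    #             ...
+    #
+    #     MyVisitor().dispatch(InternalTraversal.dp_clauseelement)
+    #     # -> <bound method MyVisitor.visit_clauseelement ...>
+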
+ def run_generated_dispatch(
+ self, target, internal_dispatch, generate_dispatcher_name
+ ):
+ try:
+ dispatcher = target.__class__.__dict__[generate_dispatcher_name]
+ except KeyError:
+ # most of the dispatchers are generated up front
+ # in sqlalchemy/sql/__init__.py ->
+ # traversals.py-> _preconfigure_traversals().
+ # this block will generate any remaining dispatchers.
+ dispatcher = self.generate_dispatch(
+ target.__class__, internal_dispatch, generate_dispatcher_name
+ )
+ return dispatcher(target, self)
+
+ def generate_dispatch(
+ self, target_cls, internal_dispatch, generate_dispatcher_name
+ ):
+ dispatcher = _generate_dispatcher(
+ self, internal_dispatch, generate_dispatcher_name
+ )
+ # assert isinstance(target_cls, type)
+ setattr(target_cls, generate_dispatcher_name, dispatcher)
+ return dispatcher
+
+ dp_has_cache_key = symbol("HC")
+ """Visit a :class:`.HasCacheKey` object."""
+
+ dp_has_cache_key_list = symbol("HL")
+ """Visit a list of :class:`.HasCacheKey` objects."""
+
+ dp_clauseelement = symbol("CE")
+ """Visit a :class:`_expression.ClauseElement` object."""
+
+ dp_fromclause_canonical_column_collection = symbol("FC")
+ """Visit a :class:`_expression.FromClause` object in the context of the
+ ``columns`` attribute.
+
+ The column collection is "canonical", meaning it is the originally
+ defined location of the :class:`.ColumnClause` objects. Right now
+ this means that the object being visited is a
+ :class:`_expression.TableClause`
+ or :class:`_schema.Table` object only.
+
+ """
+
+ dp_clauseelement_tuples = symbol("CTS")
+ """Visit a list of tuples which contain :class:`_expression.ClauseElement`
+ objects.
+
+ """
+
+ dp_clauseelement_list = symbol("CL")
+ """Visit a list of :class:`_expression.ClauseElement` objects.
+
+ """
+
+ dp_clauseelement_tuple = symbol("CT")
+ """Visit a tuple of :class:`_expression.ClauseElement` objects.
+
+ """
+
+ dp_executable_options = symbol("EO")
+
+ dp_with_context_options = symbol("WC")
+
+ dp_fromclause_ordered_set = symbol("CO")
+ """Visit an ordered set of :class:`_expression.FromClause` objects. """
+
+ dp_string = symbol("S")
+ """Visit a plain string value.
+
+ Examples include table and column names, bound parameter keys, special
+ keywords such as "UNION", "UNION ALL".
+
+ The string value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_string_list = symbol("SL")
+ """Visit a list of strings."""
+
+ dp_anon_name = symbol("AN")
+ """Visit a potentially "anonymized" string value.
+
+ The string value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_boolean = symbol("B")
+ """Visit a boolean value.
+
+ The boolean value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_operator = symbol("O")
+ """Visit an operator.
+
+ The operator is a function from the :mod:`sqlalchemy.sql.operators`
+ module.
+
+ The operator value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_type = symbol("T")
+ """Visit a :class:`.TypeEngine` object
+
+ The type object is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_plain_dict = symbol("PD")
+ """Visit a dictionary with string keys.
+
+ The keys of the dictionary should be strings, the values should
+ be immutable and hashable. The dictionary is considered to be
+ significant for cache key generation.
+
+ """
+
+ dp_dialect_options = symbol("DO")
+ """Visit a dialect options structure."""
+
+ dp_string_clauseelement_dict = symbol("CD")
+ """Visit a dictionary of string keys to :class:`_expression.ClauseElement`
+ objects.
+
+ """
+
+ dp_string_multi_dict = symbol("MD")
+ """Visit a dictionary of string keys to values which may either be
+ plain immutable/hashable or :class:`.HasCacheKey` objects.
+
+ """
+
+ dp_annotations_key = symbol("AK")
+ """Visit the _annotations_cache_key element.
+
+ This is a dictionary of additional information about a ClauseElement
+ that modifies its role. It should be included when comparing or caching
+ objects, however generating this key is relatively expensive. Visitors
+ should check the "_annotations" dict for non-None first before creating
+ this key.
+
+ """
+
+ dp_plain_obj = symbol("PO")
+ """Visit a plain python object.
+
+ The value should be immutable and hashable, such as an integer.
+ The value is considered to be significant for cache key generation.
+
+ """
+
+ dp_named_ddl_element = symbol("DD")
+ """Visit a simple named DDL element.
+
+ The current object used by this method is the :class:`.Sequence`.
+
+    Only the name of the object is considered significant for cache key
+    generation; its other aspects are not.
+
+ """
+
+ dp_prefix_sequence = symbol("PS")
+ """Visit the sequence represented by :class:`_expression.HasPrefixes`
+ or :class:`_expression.HasSuffixes`.
+
+ """
+
+ dp_table_hint_list = symbol("TH")
+ """Visit the ``_hints`` collection of a :class:`_expression.Select`
+ object.
+
+ """
+
+ dp_setup_join_tuple = symbol("SJ")
+
+ dp_memoized_select_entities = symbol("ME")
+
+ dp_statement_hint_list = symbol("SH")
+ """Visit the ``_statement_hints`` collection of a
+ :class:`_expression.Select`
+ object.
+
+ """
+
+ dp_unknown_structure = symbol("UK")
+ """Visit an unknown structure.
+
+ """
+
+ dp_dml_ordered_values = symbol("DML_OV")
+ """Visit the values() ordered tuple list of an
+ :class:`_expression.Update` object."""
+
+ dp_dml_values = symbol("DML_V")
+ """Visit the values() dictionary of a :class:`.ValuesBase`
+ (e.g. Insert or Update) object.
+
+ """
+
+ dp_dml_multi_values = symbol("DML_MV")
+ """Visit the values() multi-valued list of dictionaries of an
+ :class:`_expression.Insert` object.
+
+ """
+
+ dp_propagate_attrs = symbol("PA")
+ """Visit the propagate attrs dict. This hardcodes to the particular
+ elements we care about right now."""
+
+
+class ExtendedInternalTraversal(InternalTraversal):
+ """Defines additional symbols that are useful in caching applications.
+
+ Traversals for :class:`_expression.ClauseElement` objects only need to use
+ those symbols present in :class:`.InternalTraversal`. However, for
+ additional caching use cases within the ORM, symbols dealing with the
+ :class:`.HasCacheKey` class are added here.
+
+ """
+
+ dp_ignore = symbol("IG")
+ """Specify an object that should be ignored entirely.
+
+    This currently applies to function call argument caching, where some
+    arguments are not considered to be part of a cache key.
+
+ """
+
+ dp_inspectable = symbol("IS")
+ """Visit an inspectable object where the return value is a
+ :class:`.HasCacheKey` object."""
+
+ dp_multi = symbol("M")
+ """Visit an object that may be a :class:`.HasCacheKey` or may be a
+ plain hashable object."""
+
+ dp_multi_list = symbol("MT")
+ """Visit a tuple containing elements that may be :class:`.HasCacheKey` or
+ may be a plain hashable object."""
+
+ dp_has_cache_key_tuples = symbol("HT")
+ """Visit a list of tuples which contain :class:`.HasCacheKey`
+ objects.
+
+ """
+
+ dp_inspectable_list = symbol("IL")
+ """Visit a list of inspectable objects which upon inspection are
+ HasCacheKey objects."""
+
+
+class ExternalTraversal(object):
+ """Base class for visitor objects which can traverse externally using
+ the :func:`.visitors.traverse` function.
+
+ Direct usage of the :func:`.visitors.traverse` function is usually
+ preferred.
+
+ """
+
+ __traverse_options__ = {}
+
+ def traverse_single(self, obj, **kw):
+ for v in self.visitor_iterator:
+ meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
+ if meth:
+ return meth(obj, **kw)
+
+ def iterate(self, obj):
+ """Traverse the given expression structure, returning an iterator
+ of all elements.
+
+ """
+ return iterate(obj, self.__traverse_options__)
+
+ def traverse(self, obj):
+ """Traverse and visit the given expression structure."""
+
+ return traverse(obj, self.__traverse_options__, self._visitor_dict)
+
+ @util.memoized_property
+ def _visitor_dict(self):
+ visitors = {}
+
+ for name in dir(self):
+ if name.startswith("visit_"):
+ visitors[name[6:]] = getattr(self, name)
+ return visitors
+
+ @property
+ def visitor_iterator(self):
+ """Iterate through this visitor and each 'chained' visitor."""
+
+ v = self
+ while v:
+ yield v
+ v = getattr(v, "_next", None)
+
+ def chain(self, visitor):
+ """'Chain' an additional ClauseVisitor onto this ClauseVisitor.
+
+ The chained visitor will receive all visit events after this one.
+
+ """
+ tail = list(self.visitor_iterator)[-1]
+ tail._next = visitor
+ return self
+
+
+class CloningExternalTraversal(ExternalTraversal):
+ """Base class for visitor objects which can traverse using
+ the :func:`.visitors.cloned_traverse` function.
+
+ Direct usage of the :func:`.visitors.cloned_traverse` function is usually
+ preferred.
+
+
+ """
+
+ def copy_and_process(self, list_):
+ """Apply cloned traversal to the given list of elements, and return
+ the new list.
+
+ """
+ return [self.traverse(x) for x in list_]
+
+ def traverse(self, obj):
+ """Traverse and visit the given expression structure."""
+
+ return cloned_traverse(
+ obj, self.__traverse_options__, self._visitor_dict
+ )
+
+
+class ReplacingExternalTraversal(CloningExternalTraversal):
+ """Base class for visitor objects which can traverse using
+ the :func:`.visitors.replacement_traverse` function.
+
+ Direct usage of the :func:`.visitors.replacement_traverse` function is
+ usually preferred.
+
+ """
+
+ def replace(self, elem):
+ """Receive pre-copied elements during a cloning traversal.
+
+ If the method returns a new element, the element is used
+ instead of creating a simple copy of the element. Traversal
+ will halt on the newly returned element if it is re-encountered.
+ """
+ return None
+
+ def traverse(self, obj):
+ """Traverse and visit the given expression structure."""
+
+ def replace(elem):
+ for v in self.visitor_iterator:
+ e = v.replace(elem)
+ if e is not None:
+ return e
+
+ return replacement_traverse(obj, self.__traverse_options__, replace)
+
+
+# backwards compatibility
+Visitable = Traversible
+VisitableType = TraversibleType
+ClauseVisitor = ExternalTraversal
+CloningVisitor = CloningExternalTraversal
+ReplacingCloningVisitor = ReplacingExternalTraversal
+
+
+def iterate(obj, opts=util.immutabledict()):
+ r"""Traverse the given expression structure, returning an iterator.
+
+ Traversal is configured to be breadth-first.
+
+ The central API feature used by the :func:`.visitors.iterate`
+ function is the
+ :meth:`_expression.ClauseElement.get_children` method of
+ :class:`_expression.ClauseElement` objects. This method should return all
+ the :class:`_expression.ClauseElement` objects which are associated with a
+ particular :class:`_expression.ClauseElement` object. For example, a
+ :class:`.Case` structure will refer to a series of
+ :class:`_expression.ColumnElement` objects within its "whens" and "else\_"
+ member variables.
+
+ :param obj: :class:`_expression.ClauseElement` structure to be traversed
+
+ :param opts: dictionary of iteration options. This dictionary is usually
+ empty in modern usage.
+
+ """
+ yield obj
+ children = obj.get_children(**opts)
+
+ if not children:
+ return
+
+ stack = deque([children])
+ while stack:
+ t_iterator = stack.popleft()
+ for t in t_iterator:
+ yield t
+ stack.append(t.get_children(**opts))
+
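+# A minimal usage sketch (not part of the original source)::
+#
+#     from sqlalchemy import column, select, table
+#     from sqlalchemy.sql import visitors
+#
+#     t = table("t", column("x"))
+#     stmt = select(t.c.x).where(t.c.x > 5)
+#
+#     for element in visitors.iterate(stmt):
+#         print(type(element).__name__)
+#
+# The statement itself is yielded first, followed by its children and
+# grandchildren in breadth-first order via each element's get_children().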
+
+def traverse_using(iterator, obj, visitors):
+ """Visit the given expression structure using the given iterator of
+ objects.
+
+ :func:`.visitors.traverse_using` is usually called internally as the result
+ of the :func:`.visitors.traverse` function.
+
+ :param iterator: an iterable or sequence which will yield
+ :class:`_expression.ClauseElement`
+ structures; the iterator is assumed to be the
+ product of the :func:`.visitors.iterate` function.
+
+ :param obj: the :class:`_expression.ClauseElement`
+ that was used as the target of the
+ :func:`.iterate` function.
+
+ :param visitors: dictionary of visit functions. See :func:`.traverse`
+ for details on this dictionary.
+
+ .. seealso::
+
+ :func:`.traverse`
+
+
+ """
+ for target in iterator:
+ meth = visitors.get(target.__visit_name__, None)
+ if meth:
+ meth(target)
+ return obj
+
+
+def traverse(obj, opts, visitors):
+ """Traverse and visit the given expression structure using the default
+ iterator.
+
+ e.g.::
+
+ from sqlalchemy.sql import visitors
+
+ stmt = select(some_table).where(some_table.c.foo == 'bar')
+
+ def visit_bindparam(bind_param):
+ print("found bound value: %s" % bind_param.value)
+
+ visitors.traverse(stmt, {}, {"bindparam": visit_bindparam})
+
+ The iteration of objects uses the :func:`.visitors.iterate` function,
+ which does a breadth-first traversal using a stack.
+
+ :param obj: :class:`_expression.ClauseElement` structure to be traversed
+
+ :param opts: dictionary of iteration options. This dictionary is usually
+ empty in modern usage.
+
+ :param visitors: dictionary of visit functions. The dictionary should
+ have strings as keys, each of which would correspond to the
+ ``__visit_name__`` of a particular kind of SQL expression object, and
+ callable functions as values, each of which represents a visitor function
+ for that kind of object.
+
+ """
+ return traverse_using(iterate(obj, opts), obj, visitors)
+
+
+def cloned_traverse(obj, opts, visitors):
+ """Clone the given expression structure, allowing modifications by
+ visitors.
+
+ Traversal usage is the same as that of :func:`.visitors.traverse`.
+ The visitor functions present in the ``visitors`` dictionary may also
+ modify the internals of the given structure as the traversal proceeds.
+
+ The central API feature used by the :func:`.visitors.cloned_traverse`
+ and :func:`.visitors.replacement_traverse` functions, in addition to the
+ :meth:`_expression.ClauseElement.get_children`
+ function that is used to achieve
+ the iteration, is the :meth:`_expression.ClauseElement._copy_internals`
+ method.
+ For a :class:`_expression.ClauseElement`
+ structure to support cloning and replacement
+ traversals correctly, it needs to be able to pass a cloning function into
+ its internal members in order to make copies of them.
+
+ .. seealso::
+
+ :func:`.visitors.traverse`
+
+ :func:`.visitors.replacement_traverse`
+
+ """
+
+ cloned = {}
+ stop_on = set(opts.get("stop_on", []))
+
+ def deferred_copy_internals(obj):
+ return cloned_traverse(obj, opts, visitors)
+
+ def clone(elem, **kw):
+ if elem in stop_on:
+ return elem
+ else:
+ if id(elem) not in cloned:
+
+ if "replace" in kw:
+ newelem = kw["replace"](elem)
+ if newelem is not None:
+ cloned[id(elem)] = newelem
+ return newelem
+
+ cloned[id(elem)] = newelem = elem._clone(clone=clone, **kw)
+ newelem._copy_internals(clone=clone, **kw)
+ meth = visitors.get(newelem.__visit_name__, None)
+ if meth:
+ meth(newelem)
+ return cloned[id(elem)]
+
+ if obj is not None:
+ obj = clone(
+ obj, deferred_copy_internals=deferred_copy_internals, **opts
+ )
+ clone = None # remove gc cycles
+ return obj
+
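+# A minimal usage sketch (not part of the original source): visitors may
+# mutate the *copies* in place while the original statement is left intact::
+#
+#     from sqlalchemy import column, select, table
+#     from sqlalchemy.sql import visitors
+#
+#     t = table("t", column("x"))
+#     stmt = select(t.c.x).where(t.c.x == 5)
+#
+#     def visit_bindparam(bindparam):
+#         bindparam.value = 10   # modifies the cloned parameter only
+#
+#     new_stmt = visitors.cloned_traverse(
+#         stmt, {}, {"bindparam": visit_bindparam}
+#     )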
+
+def replacement_traverse(obj, opts, replace):
+ """Clone the given expression structure, allowing element
+ replacement by a given replacement function.
+
+ This function is very similar to the :func:`.visitors.cloned_traverse`
+ function, except instead of being passed a dictionary of visitors, all
+ elements are unconditionally passed into the given replace function.
+ The replace function then has the option to return an entirely new object
+ which will replace the one given. If it returns ``None``, then the object
+ is kept in place.
+
+ The difference in usage between :func:`.visitors.cloned_traverse` and
+ :func:`.visitors.replacement_traverse` is that in the former case, an
+ already-cloned object is passed to the visitor function, and the visitor
+ function can then manipulate the internal state of the object.
+ In the case of the latter, the visitor function should only return an
+ entirely different object, or do nothing.
+
+ The use case for :func:`.visitors.replacement_traverse` is that of
+ replacing a FROM clause inside of a SQL structure with a different one,
+    which is a common need within the ORM.
+
+ """
+
+ cloned = {}
+ stop_on = {id(x) for x in opts.get("stop_on", [])}
+
+ def deferred_copy_internals(obj):
+ return replacement_traverse(obj, opts, replace)
+
+ def clone(elem, **kw):
+ if (
+ id(elem) in stop_on
+ or "no_replacement_traverse" in elem._annotations
+ ):
+ return elem
+ else:
+ newelem = replace(elem)
+ if newelem is not None:
+ stop_on.add(id(newelem))
+ return newelem
+ else:
+ # base "already seen" on id(), not hash, so that we don't
+ # replace an Annotated element with its non-annotated one, and
+ # vice versa
+ id_elem = id(elem)
+ if id_elem not in cloned:
+ if "replace" in kw:
+ newelem = kw["replace"](elem)
+ if newelem is not None:
+ cloned[id_elem] = newelem
+ return newelem
+
+ cloned[id_elem] = newelem = elem._clone(**kw)
+ newelem._copy_internals(clone=clone, **kw)
+ return cloned[id_elem]
+
+ if obj is not None:
+ obj = clone(
+ obj, deferred_copy_internals=deferred_copy_internals, **opts
+ )
+ clone = None # remove gc cycles
+ return obj
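+
+
+# A minimal usage sketch (not part of the original source): swap one column
+# expression for another throughout a statement::
+#
+#     from sqlalchemy import column, select, table
+#     from sqlalchemy.sql import visitors
+#
+#     t = table("t", column("x"), column("y"))
+#     stmt = select(t.c.x).where(t.c.x > 5)
+#
+#     def replace(element):
+#         if element is t.c.x:
+#             return t.c.y
+#         return None      # leave everything else in place
+#
+#     new_stmt = visitors.replacement_traverse(stmt, {}, replace)
+#     # new_stmt now selects and compares t.y; stmt is unchanged.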
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
new file mode 100644
index 0000000..80d344f
--- /dev/null
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -0,0 +1,86 @@
+# testing/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+from . import config
+from . import mock
+from .assertions import assert_raises
+from .assertions import assert_raises_context_ok
+from .assertions import assert_raises_message
+from .assertions import assert_raises_message_context_ok
+from .assertions import assert_warns
+from .assertions import assert_warns_message
+from .assertions import AssertsCompiledSQL
+from .assertions import AssertsExecutionResults
+from .assertions import ComparesTables
+from .assertions import emits_warning
+from .assertions import emits_warning_on
+from .assertions import eq_
+from .assertions import eq_ignore_whitespace
+from .assertions import eq_regex
+from .assertions import expect_deprecated
+from .assertions import expect_deprecated_20
+from .assertions import expect_raises
+from .assertions import expect_raises_message
+from .assertions import expect_warnings
+from .assertions import in_
+from .assertions import is_
+from .assertions import is_false
+from .assertions import is_instance_of
+from .assertions import is_none
+from .assertions import is_not
+from .assertions import is_not_
+from .assertions import is_not_none
+from .assertions import is_true
+from .assertions import le_
+from .assertions import ne_
+from .assertions import not_in
+from .assertions import not_in_
+from .assertions import startswith_
+from .assertions import uses_deprecated
+from .config import async_test
+from .config import combinations
+from .config import combinations_list
+from .config import db
+from .config import fixture
+from .config import requirements as requires
+from .exclusions import _is_excluded
+from .exclusions import _server_version
+from .exclusions import against as _against
+from .exclusions import db_spec
+from .exclusions import exclude
+from .exclusions import fails
+from .exclusions import fails_if
+from .exclusions import fails_on
+from .exclusions import fails_on_everything_except
+from .exclusions import future
+from .exclusions import only_if
+from .exclusions import only_on
+from .exclusions import skip
+from .exclusions import skip_if
+from .schema import eq_clause_element
+from .schema import eq_type_affinity
+from .util import adict
+from .util import fail
+from .util import flag_combinations
+from .util import force_drop_names
+from .util import lambda_combinations
+from .util import metadata_fixture
+from .util import provide_metadata
+from .util import resolve_lambda
+from .util import rowset
+from .util import run_as_contextmanager
+from .util import teardown_events
+from .warnings import assert_warnings
+from .warnings import warn_test_suite
+
+
+def against(*queries):
+ return _against(config._current, *queries)
+
+
+crashes = skip
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
new file mode 100644
index 0000000..9a3c06b
--- /dev/null
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -0,0 +1,845 @@
+# testing/assertions.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import absolute_import
+
+import contextlib
+import re
+import sys
+import warnings
+
+from . import assertsql
+from . import config
+from . import engines
+from . import mock
+from .exclusions import db_spec
+from .util import fail
+from .. import exc as sa_exc
+from .. import schema
+from .. import sql
+from .. import types as sqltypes
+from .. import util
+from ..engine import default
+from ..engine import url
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..util import compat
+from ..util import decorator
+
+
+def expect_warnings(*messages, **kw):
+ """Context manager which expects one or more warnings.
+
+ With no arguments, squelches all SAWarning and RemovedIn20Warning emitted via
+ sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise
+ pass string expressions that will match selected warnings via regex;
+ all non-matching warnings are sent through.
+
+ The expect version **asserts** that the warnings were in fact seen.
+
+ Note that the test suite sets SAWarning warnings to raise exceptions.
+
+ """ # noqa
+ return _expect_warnings(
+ (sa_exc.RemovedIn20Warning, sa_exc.SAWarning), messages, **kw
+ )
+
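+# A minimal usage sketch (not part of the original source);
+# ``run_something_that_warns`` is a placeholder::
+#
+#     with expect_warnings("Dialect .* will not make use of .*"):
+#         run_something_that_warns()
+#
+# Each positional string is treated as a regular expression matched against
+# emitted SAWarning / RemovedIn20Warning messages; on exiting the block, an
+# assertion fails if any expected pattern was never seen.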
+
+@contextlib.contextmanager
+def expect_warnings_on(db, *messages, **kw):
+ """Context manager which expects one or more warnings on specific
+ dialects.
+
+ The expect version **asserts** that the warnings were in fact seen.
+
+ """
+ spec = db_spec(db)
+
+ if isinstance(db, util.string_types) and not spec(config._current):
+ yield
+ else:
+ with expect_warnings(*messages, **kw):
+ yield
+
+
+def emits_warning(*messages):
+ """Decorator form of expect_warnings().
+
+ Note that emits_warning does **not** assert that the warnings
+ were in fact seen.
+
+ """
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ with expect_warnings(assert_=False, *messages):
+ return fn(*args, **kw)
+
+ return decorate
+
+
+def expect_deprecated(*messages, **kw):
+ return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)
+
+
+def expect_deprecated_20(*messages, **kw):
+ return _expect_warnings(sa_exc.Base20DeprecationWarning, messages, **kw)
+
+
+def emits_warning_on(db, *messages):
+ """Mark a test as emitting a warning on a specific dialect.
+
+ With no arguments, squelches all SAWarning failures. Or pass one or more
+ strings; these will be matched to the root of the warning description by
+ warnings.filterwarnings().
+
+ Note that emits_warning_on does **not** assert that the warnings
+ were in fact seen.
+
+ """
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ with expect_warnings_on(db, assert_=False, *messages):
+ return fn(*args, **kw)
+
+ return decorate
+
+
+def uses_deprecated(*messages):
+ """Mark a test as immune from fatal deprecation warnings.
+
+ With no arguments, squelches all SADeprecationWarning failures.
+ Or pass one or more strings; these will be matched to the root
+ of the warning description by warnings.filterwarnings().
+
+ As a special case, you may pass a function name prefixed with //
+ and it will be re-written as needed to match the standard warning
+ verbiage emitted by the sqlalchemy.util.deprecated decorator.
+
+ Note that uses_deprecated does **not** assert that the warnings
+ were in fact seen.
+
+ """
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ with expect_deprecated(*messages, assert_=False):
+ return fn(*args, **kw)
+
+ return decorate
+
+
+_FILTERS = None
+_SEEN = None
+_EXC_CLS = None
+
+
+@contextlib.contextmanager
+def _expect_warnings(
+ exc_cls,
+ messages,
+ regex=True,
+ search_msg=False,
+ assert_=True,
+ py2konly=False,
+ raise_on_any_unexpected=False,
+ squelch_other_warnings=False,
+):
+
+ global _FILTERS, _SEEN, _EXC_CLS
+
+ if regex or search_msg:
+ filters = [re.compile(msg, re.I | re.S) for msg in messages]
+ else:
+ filters = list(messages)
+
+ if _FILTERS is not None:
+ # nested call; update _FILTERS and _SEEN, return. outer
+ # block will assert our messages
+ assert _SEEN is not None
+ assert _EXC_CLS is not None
+ _FILTERS.extend(filters)
+ _SEEN.update(filters)
+ _EXC_CLS += (exc_cls,)
+ yield
+ else:
+ seen = _SEEN = set(filters)
+ _FILTERS = filters
+ _EXC_CLS = (exc_cls,)
+
+ if raise_on_any_unexpected:
+
+ def real_warn(msg, *arg, **kw):
+ raise AssertionError("Got unexpected warning: %r" % msg)
+
+ else:
+ real_warn = warnings.warn
+
+ def our_warn(msg, *arg, **kw):
+
+ if isinstance(msg, _EXC_CLS):
+ exception = type(msg)
+ msg = str(msg)
+ elif arg:
+ exception = arg[0]
+ else:
+ exception = None
+
+ if not exception or not issubclass(exception, _EXC_CLS):
+ if not squelch_other_warnings:
+ return real_warn(msg, *arg, **kw)
+ else:
+ return
+
+ if not filters and not raise_on_any_unexpected:
+ return
+
+ for filter_ in filters:
+ if (
+ (search_msg and filter_.search(msg))
+ or (regex and filter_.match(msg))
+ or (not regex and filter_ == msg)
+ ):
+ seen.discard(filter_)
+ break
+ else:
+ if not squelch_other_warnings:
+ real_warn(msg, *arg, **kw)
+
+ with mock.patch("warnings.warn", our_warn), mock.patch(
+ "sqlalchemy.util.SQLALCHEMY_WARN_20", True
+ ), mock.patch(
+ "sqlalchemy.util.deprecations.SQLALCHEMY_WARN_20", True
+ ), mock.patch(
+ "sqlalchemy.engine.row.LegacyRow._default_key_style", 2
+ ):
+ try:
+ yield
+ finally:
+ _SEEN = _FILTERS = _EXC_CLS = None
+
+ if assert_ and (not py2konly or not compat.py3k):
+ assert not seen, "Warnings were not seen: %s" % ", ".join(
+ "%r" % (s.pattern if regex else s) for s in seen
+ )
+
+
+def global_cleanup_assertions():
+ """Check things that have to be finalized at the end of a test suite.
+
+ Hardcoded at the moment, a modular system can be built here
+ to support things like PG prepared transactions, tables all
+ dropped, etc.
+
+ """
+ _assert_no_stray_pool_connections()
+
+
+def _assert_no_stray_pool_connections():
+ engines.testing_reaper.assert_all_closed()
+
+
+def eq_regex(a, b, msg=None):
+ assert re.match(b, a), msg or "%r !~ %r" % (a, b)
+
+
+def eq_(a, b, msg=None):
+ """Assert a == b, with repr messaging on failure."""
+ assert a == b, msg or "%r != %r" % (a, b)
+
+
+def ne_(a, b, msg=None):
+ """Assert a != b, with repr messaging on failure."""
+ assert a != b, msg or "%r == %r" % (a, b)
+
+
+def le_(a, b, msg=None):
+ """Assert a <= b, with repr messaging on failure."""
+ assert a <= b, msg or "%r != %r" % (a, b)
+
+
+def is_instance_of(a, b, msg=None):
+ assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b)
+
+
+def is_none(a, msg=None):
+ is_(a, None, msg=msg)
+
+
+def is_not_none(a, msg=None):
+ is_not(a, None, msg=msg)
+
+
+def is_true(a, msg=None):
+ is_(bool(a), True, msg=msg)
+
+
+def is_false(a, msg=None):
+ is_(bool(a), False, msg=msg)
+
+
+def is_(a, b, msg=None):
+ """Assert a is b, with repr messaging on failure."""
+ assert a is b, msg or "%r is not %r" % (a, b)
+
+
+def is_not(a, b, msg=None):
+ """Assert a is not b, with repr messaging on failure."""
+ assert a is not b, msg or "%r is %r" % (a, b)
+
+
+# deprecated. See #5429
+is_not_ = is_not
+
+
+def in_(a, b, msg=None):
+ """Assert a in b, with repr messaging on failure."""
+ assert a in b, msg or "%r not in %r" % (a, b)
+
+
+def not_in(a, b, msg=None):
+ """Assert a in not b, with repr messaging on failure."""
+ assert a not in b, msg or "%r is in %r" % (a, b)
+
+
+# deprecated. See #5429
+not_in_ = not_in
+
+
+def startswith_(a, fragment, msg=None):
+ """Assert a.startswith(fragment), with repr messaging on failure."""
+ assert a.startswith(fragment), msg or "%r does not start with %r" % (
+ a,
+ fragment,
+ )
+
+
+def eq_ignore_whitespace(a, b, msg=None):
+ a = re.sub(r"^\s+?|\n", "", a)
+ a = re.sub(r" {2,}", " ", a)
+ b = re.sub(r"^\s+?|\n", "", b)
+ b = re.sub(r" {2,}", " ", b)
+
+ assert a == b, msg or "%r != %r" % (a, b)
+
+
+def _assert_proper_exception_context(exception):
+ """assert that any exception we're catching does not have a __context__
+ without a __cause__, and that __suppress_context__ is never set.
+
+    Python 3 will report nested exceptions as "during the handling of
+    error X, error Y occurred".  That's not what we want to do; we want
+    these exceptions in a cause chain.
+
+ """
+
+ if not util.py3k:
+ return
+
+ if (
+ exception.__context__ is not exception.__cause__
+ and not exception.__suppress_context__
+ ):
+ assert False, (
+ "Exception %r was correctly raised but did not set a cause, "
+ "within context %r as its cause."
+ % (exception, exception.__context__)
+ )
+
+
+def assert_raises(except_cls, callable_, *args, **kw):
+ return _assert_raises(except_cls, callable_, args, kw, check_context=True)
+
+
+def assert_raises_context_ok(except_cls, callable_, *args, **kw):
+ return _assert_raises(except_cls, callable_, args, kw)
+
+
+def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
+ return _assert_raises(
+ except_cls, callable_, args, kwargs, msg=msg, check_context=True
+ )
+
+
+def assert_warns(except_cls, callable_, *args, **kwargs):
+ """legacy adapter function for functions that were previously using
+ assert_raises with SAWarning or similar.
+
+ has some workarounds to accommodate the fact that the callable completes
+ with this approach rather than stopping at the exception raise.
+
+
+ """
+ with _expect_warnings(except_cls, [".*"], squelch_other_warnings=True):
+ return callable_(*args, **kwargs)
+
+
+def assert_warns_message(except_cls, msg, callable_, *args, **kwargs):
+ """legacy adapter function for functions that were previously using
+ assert_raises with SAWarning or similar.
+
+ has some workarounds to accommodate the fact that the callable completes
+ with this approach rather than stopping at the exception raise.
+
+ Also uses regex.search() to match the given message to the error string
+ rather than regex.match().
+
+ """
+ with _expect_warnings(
+ except_cls,
+ [msg],
+ search_msg=True,
+ regex=False,
+ squelch_other_warnings=True,
+ ):
+ return callable_(*args, **kwargs)
+
+
+def assert_raises_message_context_ok(
+ except_cls, msg, callable_, *args, **kwargs
+):
+ return _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
+
+
+def _assert_raises(
+ except_cls, callable_, args, kwargs, msg=None, check_context=False
+):
+
+ with _expect_raises(except_cls, msg, check_context) as ec:
+ callable_(*args, **kwargs)
+ return ec.error
+
+
+class _ErrorContainer(object):
+ error = None
+
+
+@contextlib.contextmanager
+def _expect_raises(except_cls, msg=None, check_context=False):
+ if (
+ isinstance(except_cls, type)
+ and issubclass(except_cls, Warning)
+ or isinstance(except_cls, Warning)
+ ):
+ raise TypeError(
+ "Use expect_warnings for warnings, not "
+ "expect_raises / assert_raises"
+ )
+ ec = _ErrorContainer()
+ if check_context:
+ are_we_already_in_a_traceback = sys.exc_info()[0]
+ try:
+ yield ec
+ success = False
+ except except_cls as err:
+ ec.error = err
+ success = True
+ if msg is not None:
+ assert re.search(
+ msg, util.text_type(err), re.UNICODE
+ ), "%r !~ %s" % (msg, err)
+ if check_context and not are_we_already_in_a_traceback:
+ _assert_proper_exception_context(err)
+ print(util.text_type(err).encode("utf-8"))
+
+ # it's generally a good idea to not carry traceback objects outside
+ # of the except: block, but in this case especially we seem to have
+ # hit some bug in either python 3.10.0b2 or greenlet or both which
+ # this seems to fix:
+ # https://github.com/python-greenlet/greenlet/issues/242
+ del ec
+
+ # assert outside the block so it works for AssertionError too !
+ assert success, "Callable did not raise an exception"
+
+
+def expect_raises(except_cls, check_context=True):
+ return _expect_raises(except_cls, check_context=check_context)
+
+
+def expect_raises_message(except_cls, msg, check_context=True):
+ return _expect_raises(except_cls, msg=msg, check_context=check_context)
+
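+# A minimal usage sketch (not part of the original source);
+# ``do_something_invalid`` is a placeholder::
+#
+#     with expect_raises_message(ValueError, "expected .* message"):
+#         do_something_invalid()
+#
+# The block asserts that a ValueError matching the regular expression was
+# raised; the caught exception remains available on the yielded container
+# as ``.error``.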
+
+class AssertsCompiledSQL(object):
+ def assert_compile(
+ self,
+ clause,
+ result,
+ params=None,
+ checkparams=None,
+ for_executemany=False,
+ check_literal_execute=None,
+ check_post_param=None,
+ dialect=None,
+ checkpositional=None,
+ check_prefetch=None,
+ use_default_dialect=False,
+ allow_dialect_select=False,
+ supports_default_values=True,
+ supports_default_metavalue=True,
+ literal_binds=False,
+ render_postcompile=False,
+ schema_translate_map=None,
+ render_schema_translate=False,
+ default_schema_name=None,
+ from_linting=False,
+ ):
+ if use_default_dialect:
+ dialect = default.DefaultDialect()
+ dialect.supports_default_values = supports_default_values
+ dialect.supports_default_metavalue = supports_default_metavalue
+ elif allow_dialect_select:
+ dialect = None
+ else:
+ if dialect is None:
+ dialect = getattr(self, "__dialect__", None)
+
+ if dialect is None:
+ dialect = config.db.dialect
+ elif dialect == "default":
+ dialect = default.DefaultDialect()
+ dialect.supports_default_values = supports_default_values
+ dialect.supports_default_metavalue = supports_default_metavalue
+ elif dialect == "default_enhanced":
+ dialect = default.StrCompileDialect()
+ elif isinstance(dialect, util.string_types):
+ dialect = url.URL.create(dialect).get_dialect()()
+
+ if default_schema_name:
+ dialect.default_schema_name = default_schema_name
+
+ kw = {}
+ compile_kwargs = {}
+
+ if schema_translate_map:
+ kw["schema_translate_map"] = schema_translate_map
+
+ if params is not None:
+ kw["column_keys"] = list(params)
+
+ if literal_binds:
+ compile_kwargs["literal_binds"] = True
+
+ if render_postcompile:
+ compile_kwargs["render_postcompile"] = True
+
+ if for_executemany:
+ kw["for_executemany"] = True
+
+ if render_schema_translate:
+ kw["render_schema_translate"] = True
+
+ if from_linting or getattr(self, "assert_from_linting", False):
+ kw["linting"] = sql.FROM_LINTING
+
+ from sqlalchemy import orm
+
+ if isinstance(clause, orm.Query):
+ stmt = clause._statement_20()
+ stmt._label_style = LABEL_STYLE_TABLENAME_PLUS_COL
+ clause = stmt
+
+ if compile_kwargs:
+ kw["compile_kwargs"] = compile_kwargs
+
+ class DontAccess(object):
+ def __getattribute__(self, key):
+ raise NotImplementedError(
+ "compiler accessed .statement; use "
+ "compiler.current_executable"
+ )
+
+ class CheckCompilerAccess(object):
+ def __init__(self, test_statement):
+ self.test_statement = test_statement
+ self._annotations = {}
+ self.supports_execution = getattr(
+ test_statement, "supports_execution", False
+ )
+
+ if self.supports_execution:
+ self._execution_options = test_statement._execution_options
+
+ if hasattr(test_statement, "_returning"):
+ self._returning = test_statement._returning
+ if hasattr(test_statement, "_inline"):
+ self._inline = test_statement._inline
+ if hasattr(test_statement, "_return_defaults"):
+ self._return_defaults = test_statement._return_defaults
+
+ def _default_dialect(self):
+ return self.test_statement._default_dialect()
+
+ def compile(self, dialect, **kw):
+ return self.test_statement.compile.__func__(
+ self, dialect=dialect, **kw
+ )
+
+ def _compiler(self, dialect, **kw):
+ return self.test_statement._compiler.__func__(
+ self, dialect, **kw
+ )
+
+ def _compiler_dispatch(self, compiler, **kwargs):
+ if hasattr(compiler, "statement"):
+ with mock.patch.object(
+ compiler, "statement", DontAccess()
+ ):
+ return self.test_statement._compiler_dispatch(
+ compiler, **kwargs
+ )
+ else:
+ return self.test_statement._compiler_dispatch(
+ compiler, **kwargs
+ )
+
+ # no construct can assume it's the "top level" construct in all cases
+ # as anything can be nested. ensure constructs don't assume they
+ # are the "self.statement" element
+ c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw)
+
+ if isinstance(clause, sqltypes.TypeEngine):
+ cache_key_no_warnings = clause._static_cache_key
+ if cache_key_no_warnings:
+ hash(cache_key_no_warnings)
+ else:
+ cache_key_no_warnings = clause._generate_cache_key()
+ if cache_key_no_warnings:
+ hash(cache_key_no_warnings[0])
+
+ param_str = repr(getattr(c, "params", {}))
+ if util.py3k:
+ param_str = param_str.encode("utf-8").decode("ascii", "ignore")
+ print(
+ ("\nSQL String:\n" + util.text_type(c) + param_str).encode(
+ "utf-8"
+ )
+ )
+ else:
+ print(
+ "\nSQL String:\n"
+ + util.text_type(c).encode("utf-8")
+ + param_str
+ )
+
+ cc = re.sub(r"[\n\t]", "", util.text_type(c))
+
+ eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
+
+ if checkparams is not None:
+ eq_(c.construct_params(params), checkparams)
+ if checkpositional is not None:
+ p = c.construct_params(params)
+ eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
+ if check_prefetch is not None:
+ eq_(c.prefetch, check_prefetch)
+ if check_literal_execute is not None:
+ eq_(
+ {
+ c.bind_names[b]: b.effective_value
+ for b in c.literal_execute_params
+ },
+ check_literal_execute,
+ )
+ if check_post_param is not None:
+ eq_(
+ {
+ c.bind_names[b]: b.effective_value
+ for b in c.post_compile_params
+ },
+ check_post_param,
+ )
+
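+# A minimal usage sketch (not part of the original source) for the
+# AssertsCompiledSQL mixin above; ``fixtures.TestBase`` refers to the test
+# base class in ``sqlalchemy.testing.fixtures``::
+#
+#     from sqlalchemy import column, select, table
+#     from sqlalchemy.testing import AssertsCompiledSQL, fixtures
+#
+#     class MyCompileTest(fixtures.TestBase, AssertsCompiledSQL):
+#         __dialect__ = "default"
+#
+#         def test_simple_select(self):
+#             t = table("t", column("x"))
+#             self.assert_compile(
+#                 select(t.c.x).where(t.c.x == 5),
+#                 "SELECT t.x FROM t WHERE t.x = :x_1",
+#                 checkparams={"x_1": 5},
+#             )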
+
+class ComparesTables(object):
+ def assert_tables_equal(self, table, reflected_table, strict_types=False):
+ assert len(table.c) == len(reflected_table.c)
+ for c, reflected_c in zip(table.c, reflected_table.c):
+ eq_(c.name, reflected_c.name)
+ assert reflected_c is reflected_table.c[c.name]
+ eq_(c.primary_key, reflected_c.primary_key)
+ eq_(c.nullable, reflected_c.nullable)
+
+ if strict_types:
+ msg = "Type '%s' doesn't correspond to type '%s'"
+ assert isinstance(reflected_c.type, type(c.type)), msg % (
+ reflected_c.type,
+ c.type,
+ )
+ else:
+ self.assert_types_base(reflected_c, c)
+
+ if isinstance(c.type, sqltypes.String):
+ eq_(c.type.length, reflected_c.type.length)
+
+ eq_(
+ {f.column.name for f in c.foreign_keys},
+ {f.column.name for f in reflected_c.foreign_keys},
+ )
+ if c.server_default:
+ assert isinstance(
+ reflected_c.server_default, schema.FetchedValue
+ )
+
+ assert len(table.primary_key) == len(reflected_table.primary_key)
+ for c in table.primary_key:
+ assert reflected_table.primary_key.columns[c.name] is not None
+
+ def assert_types_base(self, c1, c2):
+ assert c1.type._compare_type_affinity(
+ c2.type
+ ), "On column %r, type '%s' doesn't correspond to type '%s'" % (
+ c1.name,
+ c1.type,
+ c2.type,
+ )
+
+
+class AssertsExecutionResults(object):
+ def assert_result(self, result, class_, *objects):
+ result = list(result)
+ print(repr(result))
+ self.assert_list(result, class_, objects)
+
+ def assert_list(self, result, class_, list_):
+ self.assert_(
+ len(result) == len(list_),
+ "result list is not the same size as test list, "
+ + "for class "
+ + class_.__name__,
+ )
+ for i in range(0, len(list_)):
+ self.assert_row(class_, result[i], list_[i])
+
+ def assert_row(self, class_, rowobj, desc):
+ self.assert_(
+ rowobj.__class__ is class_, "item class is not " + repr(class_)
+ )
+ for key, value in desc.items():
+ if isinstance(value, tuple):
+ if isinstance(value[1], list):
+ self.assert_list(getattr(rowobj, key), value[0], value[1])
+ else:
+ self.assert_row(value[0], getattr(rowobj, key), value[1])
+ else:
+ self.assert_(
+ getattr(rowobj, key) == value,
+ "attribute %s value %s does not match %s"
+ % (key, getattr(rowobj, key), value),
+ )
+
+ def assert_unordered_result(self, result, cls, *expected):
+ """As assert_result, but the order of objects is not considered.
+
+ The algorithm is very expensive but not a big deal for the small
+ numbers of rows that the test suite manipulates.
+ """
+
+ class immutabledict(dict):
+ def __hash__(self):
+ return id(self)
+
+ found = util.IdentitySet(result)
+ expected = {immutabledict(e) for e in expected}
+
+ for wrong in util.itertools_filterfalse(
+ lambda o: isinstance(o, cls), found
+ ):
+ fail(
+ 'Unexpected type "%s", expected "%s"'
+ % (type(wrong).__name__, cls.__name__)
+ )
+
+ if len(found) != len(expected):
+ fail(
+ 'Unexpected object count "%s", expected "%s"'
+ % (len(found), len(expected))
+ )
+
+ NOVALUE = object()
+
+ def _compare_item(obj, spec):
+ for key, value in spec.items():
+ if isinstance(value, tuple):
+ try:
+ self.assert_unordered_result(
+ getattr(obj, key), value[0], *value[1]
+ )
+ except AssertionError:
+ return False
+ else:
+ if getattr(obj, key, NOVALUE) != value:
+ return False
+ return True
+
+ for expected_item in expected:
+ for found_item in found:
+ if _compare_item(found_item, expected_item):
+ found.remove(found_item)
+ break
+ else:
+ fail(
+ "Expected %s instance with attributes %s not found."
+ % (cls.__name__, repr(expected_item))
+ )
+ return True
+
+ def sql_execution_asserter(self, db=None):
+ if db is None:
+ from . import db as db
+
+ return assertsql.assert_engine(db)
+
+ def assert_sql_execution(self, db, callable_, *rules):
+ with self.sql_execution_asserter(db) as asserter:
+ result = callable_()
+ asserter.assert_(*rules)
+ return result
+
+ def assert_sql(self, db, callable_, rules):
+
+ newrules = []
+ for rule in rules:
+ if isinstance(rule, dict):
+ newrule = assertsql.AllOf(
+ *[assertsql.CompiledSQL(k, v) for k, v in rule.items()]
+ )
+ else:
+ newrule = assertsql.CompiledSQL(*rule)
+ newrules.append(newrule)
+
+ return self.assert_sql_execution(db, callable_, *newrules)
+
+ def assert_sql_count(self, db, callable_, count):
+ self.assert_sql_execution(
+ db, callable_, assertsql.CountStatements(count)
+ )
+
+ def assert_multiple_sql_count(self, dbs, callable_, counts):
+ recs = [
+ (self.sql_execution_asserter(db), db, count)
+ for (db, count) in zip(dbs, counts)
+ ]
+ asserters = []
+ for ctx, db, count in recs:
+ asserters.append(ctx.__enter__())
+ try:
+ return callable_()
+ finally:
+ for asserter, (ctx, db, count) in zip(asserters, recs):
+ ctx.__exit__(None, None, None)
+ asserter.assert_(assertsql.CountStatements(count))
+
+ @contextlib.contextmanager
+ def assert_execution(self, db, *rules):
+ with self.sql_execution_asserter(db) as asserter:
+ yield
+ asserter.assert_(*rules)
+
+ def assert_statement_count(self, db, count):
+ return self.assert_execution(db, assertsql.CountStatements(count))
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
new file mode 100644
index 0000000..565b3ed
--- /dev/null
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -0,0 +1,457 @@
+# testing/assertsql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import collections
+import contextlib
+import re
+
+from .. import event
+from .. import util
+from ..engine import url
+from ..engine.default import DefaultDialect
+from ..engine.util import _distill_cursor_params
+from ..schema import _DDLCompiles
+
+
+class AssertRule(object):
+
+ is_consumed = False
+ errormessage = None
+ consume_statement = True
+
+ def process_statement(self, execute_observed):
+ pass
+
+ def no_more_statements(self):
+ assert False, (
+ "All statements are complete, but pending "
+ "assertion rules remain"
+ )
+
+
+class SQLMatchRule(AssertRule):
+ pass
+
+
+class CursorSQL(SQLMatchRule):
+ def __init__(self, statement, params=None, consume_statement=True):
+ self.statement = statement
+ self.params = params
+ self.consume_statement = consume_statement
+
+ def process_statement(self, execute_observed):
+ stmt = execute_observed.statements[0]
+ if self.statement != stmt.statement or (
+ self.params is not None and self.params != stmt.parameters
+ ):
+ self.errormessage = (
+ "Testing for exact SQL %s parameters %s received %s %s"
+ % (
+ self.statement,
+ self.params,
+ stmt.statement,
+ stmt.parameters,
+ )
+ )
+ else:
+ execute_observed.statements.pop(0)
+ self.is_consumed = True
+ if not execute_observed.statements:
+ self.consume_statement = True
+
+
+class CompiledSQL(SQLMatchRule):
+ def __init__(self, statement, params=None, dialect="default"):
+ self.statement = statement
+ self.params = params
+ self.dialect = dialect
+
+ def _compare_sql(self, execute_observed, received_statement):
+ stmt = re.sub(r"[\n\t]", "", self.statement)
+ return received_statement == stmt
+
+ def _compile_dialect(self, execute_observed):
+ if self.dialect == "default":
+ dialect = DefaultDialect()
+ # this is currently what tests are expecting
+ # dialect.supports_default_values = True
+ dialect.supports_default_metavalue = True
+ return dialect
+ else:
+ # ugh
+ if self.dialect == "postgresql":
+ params = {"implicit_returning": True}
+ else:
+ params = {}
+ return url.URL.create(self.dialect).get_dialect()(**params)
+
+ def _received_statement(self, execute_observed):
+ """reconstruct the statement and params in terms
+ of a target dialect, which for CompiledSQL is just DefaultDialect."""
+
+ context = execute_observed.context
+ compare_dialect = self._compile_dialect(execute_observed)
+
+ # received_statement runs a full compile(). we should not need to
+ # consider extracted_parameters; if we do this indicates some state
+ # is being sent from a previous cached query, which some misbehaviors
+ # in the ORM can cause, see #6881
+ cache_key = None # execute_observed.context.compiled.cache_key
+ extracted_parameters = (
+ None # execute_observed.context.extracted_parameters
+ )
+
+ if "schema_translate_map" in context.execution_options:
+ map_ = context.execution_options["schema_translate_map"]
+ else:
+ map_ = None
+
+ if isinstance(execute_observed.clauseelement, _DDLCompiles):
+
+ compiled = execute_observed.clauseelement.compile(
+ dialect=compare_dialect,
+ schema_translate_map=map_,
+ )
+ else:
+ compiled = execute_observed.clauseelement.compile(
+ cache_key=cache_key,
+ dialect=compare_dialect,
+ column_keys=context.compiled.column_keys,
+ for_executemany=context.compiled.for_executemany,
+ schema_translate_map=map_,
+ )
+ _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
+ parameters = execute_observed.parameters
+
+ if not parameters:
+ _received_parameters = [
+ compiled.construct_params(
+ extracted_parameters=extracted_parameters
+ )
+ ]
+ else:
+ _received_parameters = [
+ compiled.construct_params(
+ m, extracted_parameters=extracted_parameters
+ )
+ for m in parameters
+ ]
+
+ return _received_statement, _received_parameters
+
+ def process_statement(self, execute_observed):
+ context = execute_observed.context
+
+ _received_statement, _received_parameters = self._received_statement(
+ execute_observed
+ )
+ params = self._all_params(context)
+
+ equivalent = self._compare_sql(execute_observed, _received_statement)
+
+ if equivalent:
+ if params is not None:
+ all_params = list(params)
+ all_received = list(_received_parameters)
+ while all_params and all_received:
+ param = dict(all_params.pop(0))
+
+ for idx, received in enumerate(list(all_received)):
+ # do a positive compare only
+ for param_key in param:
+ # a key in param did not match current
+ # 'received'
+ if (
+ param_key not in received
+ or received[param_key] != param[param_key]
+ ):
+ break
+ else:
+ # all keys in param matched 'received';
+ # onto next param
+ del all_received[idx]
+ break
+ else:
+ # param did not match any entry
+ # in all_received
+ equivalent = False
+ break
+ if all_params or all_received:
+ equivalent = False
+
+ if equivalent:
+ self.is_consumed = True
+ self.errormessage = None
+ else:
+ self.errormessage = self._failure_message(params) % {
+ "received_statement": _received_statement,
+ "received_parameters": _received_parameters,
+ }
+
+ def _all_params(self, context):
+ if self.params:
+ if callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+ if not isinstance(params, list):
+ params = [params]
+ return params
+ else:
+ return None
+
+ def _failure_message(self, expected_params):
+ return (
+ "Testing for compiled statement\n%r partial params %s, "
+ "received\n%%(received_statement)r with params "
+ "%%(received_parameters)r"
+ % (
+ self.statement.replace("%", "%%"),
+ repr(expected_params).replace("%", "%%"),
+ )
+ )
+
+
+class RegexSQL(CompiledSQL):
+ def __init__(self, regex, params=None, dialect="default"):
+ SQLMatchRule.__init__(self)
+ self.regex = re.compile(regex)
+ self.orig_regex = regex
+ self.params = params
+ self.dialect = dialect
+
+ def _failure_message(self, expected_params):
+ return (
+ "Testing for compiled statement ~%r partial params %s, "
+ "received %%(received_statement)r with params "
+ "%%(received_parameters)r"
+ % (
+ self.orig_regex.replace("%", "%%"),
+ repr(expected_params).replace("%", "%%"),
+ )
+ )
+
+ def _compare_sql(self, execute_observed, received_statement):
+ return bool(self.regex.match(received_statement))
+
+
+class DialectSQL(CompiledSQL):
+ def _compile_dialect(self, execute_observed):
+ return execute_observed.context.dialect
+
+ def _compare_no_space(self, real_stmt, received_stmt):
+ stmt = re.sub(r"[\n\t]", "", real_stmt)
+ return received_stmt == stmt
+
+ def _received_statement(self, execute_observed):
+ received_stmt, received_params = super(
+ DialectSQL, self
+ )._received_statement(execute_observed)
+
+ # TODO: why do we need this part?
+ for real_stmt in execute_observed.statements:
+ if self._compare_no_space(real_stmt.statement, received_stmt):
+ break
+ else:
+ raise AssertionError(
+ "Can't locate compiled statement %r in list of "
+ "statements actually invoked" % received_stmt
+ )
+
+ return received_stmt, execute_observed.context.compiled_parameters
+
+ def _compare_sql(self, execute_observed, received_statement):
+ stmt = re.sub(r"[\n\t]", "", self.statement)
+ # convert our comparison statement to have the
+ # paramstyle of the received
+ paramstyle = execute_observed.context.dialect.paramstyle
+ if paramstyle == "pyformat":
+ stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
+ else:
+ # positional params
+ repl = None
+ if paramstyle == "qmark":
+ repl = "?"
+ elif paramstyle == "format":
+ repl = r"%s"
+ elif paramstyle == "numeric":
+ repl = None
+ stmt = re.sub(r":([\w_]+)", repl, stmt)
+
+ return received_statement == stmt
+
+
+class CountStatements(AssertRule):
+ def __init__(self, count):
+ self.count = count
+ self._statement_count = 0
+
+ def process_statement(self, execute_observed):
+ self._statement_count += 1
+
+ def no_more_statements(self):
+ if self.count != self._statement_count:
+ assert False, "desired statement count %d does not match %d" % (
+ self.count,
+ self._statement_count,
+ )
+
+
+class AllOf(AssertRule):
+    """A rule that is satisfied once all of the given rules have
+    matched, in any order."""
+
+    def __init__(self, *rules):
+ self.rules = set(rules)
+
+ def process_statement(self, execute_observed):
+ for rule in list(self.rules):
+ rule.errormessage = None
+ rule.process_statement(execute_observed)
+ if rule.is_consumed:
+ self.rules.discard(rule)
+ if not self.rules:
+ self.is_consumed = True
+ break
+ elif not rule.errormessage:
+ # rule is not done yet
+ self.errormessage = None
+ break
+ else:
+ self.errormessage = list(self.rules)[0].errormessage
+
+
+class EachOf(AssertRule):
+    """A rule that applies the given rules to observed statements
+    one at a time, in order."""
+
+    def __init__(self, *rules):
+ self.rules = list(rules)
+
+ def process_statement(self, execute_observed):
+ while self.rules:
+ rule = self.rules[0]
+ rule.process_statement(execute_observed)
+ if rule.is_consumed:
+ self.rules.pop(0)
+ elif rule.errormessage:
+ self.errormessage = rule.errormessage
+ if rule.consume_statement:
+ break
+
+ if not self.rules:
+ self.is_consumed = True
+
+ def no_more_statements(self):
+ if self.rules and not self.rules[0].is_consumed:
+ self.rules[0].no_more_statements()
+ elif self.rules:
+ super(EachOf, self).no_more_statements()
+
+
+class Conditional(EachOf):
+ def __init__(self, condition, rules, else_rules):
+ if condition:
+ super(Conditional, self).__init__(*rules)
+ else:
+ super(Conditional, self).__init__(*else_rules)
+
+
+class Or(AllOf):
+    """A rule that is satisfied as soon as any one of the given rules
+    matches."""
+
+    def process_statement(self, execute_observed):
+ for rule in self.rules:
+ rule.process_statement(execute_observed)
+ if rule.is_consumed:
+ self.is_consumed = True
+ break
+ else:
+ self.errormessage = list(self.rules)[0].errormessage
+
+
+class SQLExecuteObserved(object):
+ def __init__(self, context, clauseelement, multiparams, params):
+ self.context = context
+ self.clauseelement = clauseelement
+ self.parameters = _distill_cursor_params(
+ context.connection, tuple(multiparams), params
+ )
+ self.statements = []
+
+ def __repr__(self):
+ return str(self.statements)
+
+
+class SQLCursorExecuteObserved(
+ collections.namedtuple(
+ "SQLCursorExecuteObserved",
+ ["statement", "parameters", "context", "executemany"],
+ )
+):
+ pass
+
+
+class SQLAsserter(object):
+ def __init__(self):
+ self.accumulated = []
+
+ def _close(self):
+ self._final = self.accumulated
+ del self.accumulated
+
+ def assert_(self, *rules):
+ rule = EachOf(*rules)
+
+ observed = list(self._final)
+ while observed:
+ statement = observed.pop(0)
+ rule.process_statement(statement)
+ if rule.is_consumed:
+ break
+ elif rule.errormessage:
+ assert False, rule.errormessage
+ if observed:
+ assert False, "Additional SQL statements remain:\n%s" % observed
+ elif not rule.is_consumed:
+ rule.no_more_statements()
+
+
+@contextlib.contextmanager
+def assert_engine(engine):
+ asserter = SQLAsserter()
+
+ orig = []
+
+ @event.listens_for(engine, "before_execute")
+ def connection_execute(
+ conn, clauseelement, multiparams, params, execution_options
+ ):
+ # grab the original statement + params before any cursor
+ # execution
+ orig[:] = clauseelement, multiparams, params
+
+ @event.listens_for(engine, "after_cursor_execute")
+ def cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
+ if not context:
+ return
+ # then grab real cursor statements and associate them all
+ # around a single context
+ if (
+ asserter.accumulated
+ and asserter.accumulated[-1].context is context
+ ):
+ obs = asserter.accumulated[-1]
+ else:
+ obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
+ asserter.accumulated.append(obs)
+ obs.statements.append(
+ SQLCursorExecuteObserved(
+ statement, parameters, context, executemany
+ )
+ )
+
+ try:
+ yield asserter
+ finally:
+ event.remove(engine, "after_cursor_execute", cursor_execute)
+ event.remove(engine, "before_execute", connection_execute)
+ asserter._close()
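+
+# Illustrative sketch (not part of the original module): assert_engine() and
+# the rule objects above are typically combined like the following, where
+# ``some_engine`` and ``user_table`` are hypothetical::
+#
+#     with assert_engine(some_engine) as asserter:
+#         with some_engine.connect() as conn:
+#             conn.execute(user_table.insert(), {"name": "ed"})
+#
+#     asserter.assert_(
+#         CompiledSQL(
+#             "INSERT INTO users (name) VALUES (:name)", {"name": "ed"}
+#         )
+#     )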
diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py
new file mode 100644
index 0000000..2189060
--- /dev/null
+++ b/lib/sqlalchemy/testing/asyncio.py
@@ -0,0 +1,128 @@
+# testing/asyncio.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+# functions and wrappers to run tests, fixtures, provisioning and
+# setup/teardown in an asyncio event loop, conditionally based on the
+# current DB driver being used for a test.
+
+# note that SQLAlchemy's asyncio integration also supports a method
+# of running individual asyncio functions inside of separate event loops
+# using "async_fallback" mode; however running whole functions in the event
+# loop is a more accurate test for how SQLAlchemy's asyncio features
+# would run in the real world.
+
+
+from functools import wraps
+import inspect
+
+from . import config
+from ..util.concurrency import _util_async_run
+from ..util.concurrency import _util_async_run_coroutine_function
+
+# may be set to False if the
+# --disable-asyncio flag is passed to the test runner.
+ENABLE_ASYNCIO = True
+
+
+def _run_coroutine_function(fn, *args, **kwargs):
+ return _util_async_run_coroutine_function(fn, *args, **kwargs)
+
+
+def _assume_async(fn, *args, **kwargs):
+ """Run a function in an asyncio loop unconditionally.
+
+ This function is used for provisioning features like
+ testing a database connection for server info.
+
+ Note that for blocking IO database drivers, this means they block the
+ event loop.
+
+ """
+
+ if not ENABLE_ASYNCIO:
+ return fn(*args, **kwargs)
+
+ return _util_async_run(fn, *args, **kwargs)
+
+
+def _maybe_async_provisioning(fn, *args, **kwargs):
+ """Run a function in an asyncio loop if any current drivers might need it.
+
+ This function is used for provisioning features that take
+ place outside of a specific database driver being selected, so if the
+ current driver that happens to be used for the provisioning operation
+ is an async driver, it will run in asyncio and not fail.
+
+ Note that for blocking IO database drivers, this means they block the
+ event loop.
+
+ """
+ if not ENABLE_ASYNCIO:
+ return fn(*args, **kwargs)
+
+ if config.any_async:
+ return _util_async_run(fn, *args, **kwargs)
+ else:
+ return fn(*args, **kwargs)
+
+
+def _maybe_async(fn, *args, **kwargs):
+ """Run a function in an asyncio loop if the current selected driver is
+ async.
+
+ This function is used for test setup/teardown and tests themselves
+ where the current DB driver is known.
+
+
+ """
+ if not ENABLE_ASYNCIO:
+
+ return fn(*args, **kwargs)
+
+ is_async = config._current.is_async
+
+ if is_async:
+ return _util_async_run(fn, *args, **kwargs)
+ else:
+ return fn(*args, **kwargs)
+
+
+def _maybe_async_wrapper(fn):
+ """Apply the _maybe_async function to an existing function and return
+ as a wrapped callable, supporting generator functions as well.
+
+ This is currently used for pytest fixtures that support generator use.
+
+ """
+
+ if inspect.isgeneratorfunction(fn):
+ _stop = object()
+
+ def call_next(gen):
+ try:
+ return next(gen)
+ # can't raise StopIteration in an awaitable.
+ except StopIteration:
+ return _stop
+
+ @wraps(fn)
+ def wrap_fixture(*args, **kwargs):
+ gen = fn(*args, **kwargs)
+ while True:
+ value = _maybe_async(call_next, gen)
+ if value is _stop:
+ break
+ yield value
+
+ else:
+
+ @wraps(fn)
+ def wrap_fixture(*args, **kwargs):
+ return _maybe_async(fn, *args, **kwargs)
+
+ return wrap_fixture
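+
+# Illustrative sketch (not part of the original module): the test plugin
+# would apply _maybe_async_wrapper to a generator-style fixture so that each
+# step of the generator runs inside the event loop when the current driver
+# is async; the fixture below is hypothetical::
+#
+#     def some_connection_fixture():
+#         conn = config.db.connect()
+#         yield conn
+#         conn.close()
+#
+#     some_connection_fixture = _maybe_async_wrapper(some_connection_fixture)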
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
new file mode 100644
index 0000000..fc13a16
--- /dev/null
+++ b/lib/sqlalchemy/testing/config.py
@@ -0,0 +1,209 @@
+# testing/config.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import collections
+
+from .. import util
+
+requirements = None
+db = None
+db_url = None
+db_opts = None
+file_config = None
+test_schema = None
+test_schema_2 = None
+any_async = False
+_current = None
+ident = "main"
+
+_fixture_functions = None # installed by plugin_base
+
+
+def combinations(*comb, **kw):
+ r"""Deliver multiple versions of a test based on positional combinations.
+
+ This is a facade over pytest.mark.parametrize.
+
+
+ :param \*comb: argument combinations. These are tuples that will be passed
+ positionally to the decorated function.
+
+ :param argnames: optional list of argument names. These are the names
+ of the arguments in the test function that correspond to the entries
+ in each argument tuple. pytest.mark.parametrize requires this, however
+ the combinations function will derive it automatically if not present
+ by using ``inspect.getfullargspec(fn).args[1:]``. Note this assumes the
+ first argument is "self" which is discarded.
+
+ :param id\_: optional id template. This is a string template that
+ describes how the "id" for each parameter set should be defined, if any.
+ The number of characters in the template should match the number of
+ entries in each argument tuple. Each character describes how the
+ corresponding entry in the argument tuple should be handled, as far as
+ whether or not it is included in the arguments passed to the function, as
+ well as if it is included in the tokens used to create the id of the
+ parameter set.
+
+ If omitted, the argument combinations are passed to parametrize as is. If
+ passed, each argument combination is turned into a pytest.param() object,
+ mapping the elements of the argument tuple to produce an id based on a
+ character value in the same position within the string template using the
+ following scheme::
+
+ i - the given argument is a string that is part of the id only, don't
+ pass it as an argument
+
+ n - the given argument should be passed and it should be added to the
+ id by calling the .__name__ attribute
+
+ r - the given argument should be passed and it should be added to the
+ id by calling repr()
+
+ s - the given argument should be passed and it should be added to the
+ id by calling str()
+
+ a - (argument) the given argument should be passed and it should not
+          be used to generate the id
+
+ e.g.::
+
+ @testing.combinations(
+ (operator.eq, "eq"),
+ (operator.ne, "ne"),
+ (operator.gt, "gt"),
+ (operator.lt, "lt"),
+ id_="na"
+ )
+ def test_operator(self, opfunc, name):
+ pass
+
+ The above combination will call ``.__name__`` on the first member of
+ each tuple and use that as the "id" to pytest.param().
+
+
+ """
+ return _fixture_functions.combinations(*comb, **kw)
+
+
+def combinations_list(arg_iterable, **kw):
+ "As combination, but takes a single iterable"
+ return combinations(*arg_iterable, **kw)
+
+
+def fixture(*arg, **kw):
+ return _fixture_functions.fixture(*arg, **kw)
+
+
+def get_current_test_name():
+ return _fixture_functions.get_current_test_name()
+
+
+def mark_base_test_class():
+ return _fixture_functions.mark_base_test_class()
+
+
+class Config(object):
+ def __init__(self, db, db_opts, options, file_config):
+ self._set_name(db)
+ self.db = db
+ self.db_opts = db_opts
+ self.options = options
+ self.file_config = file_config
+ self.test_schema = "test_schema"
+ self.test_schema_2 = "test_schema_2"
+
+ self.is_async = db.dialect.is_async and not util.asbool(
+ db.url.query.get("async_fallback", False)
+ )
+
+ _stack = collections.deque()
+ _configs = set()
+
+ def _set_name(self, db):
+ if db.dialect.server_version_info:
+ svi = ".".join(str(tok) for tok in db.dialect.server_version_info)
+ self.name = "%s+%s_[%s]" % (db.name, db.driver, svi)
+ else:
+ self.name = "%s+%s" % (db.name, db.driver)
+
+ @classmethod
+ def register(cls, db, db_opts, options, file_config):
+ """add a config as one of the global configs.
+
+ If there are no configs set up yet, this config also
+ gets set as the "_current".
+ """
+ global any_async
+
+ cfg = Config(db, db_opts, options, file_config)
+
+ # if any backends include an async driver, then ensure
+ # all setup/teardown and tests are wrapped in the maybe_async()
+ # decorator that will set up a greenlet context for async drivers.
+ any_async = any_async or cfg.is_async
+
+ cls._configs.add(cfg)
+ return cfg
+
+ @classmethod
+ def set_as_current(cls, config, namespace):
+ global db, _current, db_url, test_schema, test_schema_2, db_opts
+ _current = config
+ db_url = config.db.url
+ db_opts = config.db_opts
+ test_schema = config.test_schema
+ test_schema_2 = config.test_schema_2
+ namespace.db = db = config.db
+
+ @classmethod
+ def push_engine(cls, db, namespace):
+ assert _current, "Can't push without a default Config set up"
+ cls.push(
+ Config(
+ db, _current.db_opts, _current.options, _current.file_config
+ ),
+ namespace,
+ )
+
+ @classmethod
+ def push(cls, config, namespace):
+ cls._stack.append(_current)
+ cls.set_as_current(config, namespace)
+
+ @classmethod
+ def pop(cls, namespace):
+ if cls._stack:
+ # a failed test w/ -x option can call reset() ahead of time
+ _current = cls._stack[-1]
+ del cls._stack[-1]
+ cls.set_as_current(_current, namespace)
+
+ @classmethod
+ def reset(cls, namespace):
+ if cls._stack:
+ cls.set_as_current(cls._stack[0], namespace)
+ cls._stack.clear()
+
+ @classmethod
+ def all_configs(cls):
+ return cls._configs
+
+ @classmethod
+ def all_dbs(cls):
+ for cfg in cls.all_configs():
+ yield cfg.db
+
+ def skip_test(self, msg):
+ skip_test(msg)
+
+
+def skip_test(msg):
+ raise _fixture_functions.skip_test_exception(msg)
+
+
+def async_test(fn):
+ return _fixture_functions.async_test(fn)
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
new file mode 100644
index 0000000..b8be6b9
--- /dev/null
+++ b/lib/sqlalchemy/testing/engines.py
@@ -0,0 +1,465 @@
+# testing/engines.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import absolute_import
+
+import collections
+import re
+import warnings
+import weakref
+
+from . import config
+from .util import decorator
+from .util import gc_collect
+from .. import event
+from .. import pool
+from ..util import await_only
+
+
+class ConnectionKiller(object):
+ def __init__(self):
+ self.proxy_refs = weakref.WeakKeyDictionary()
+ self.testing_engines = collections.defaultdict(set)
+ self.dbapi_connections = set()
+
+ def add_pool(self, pool):
+ event.listen(pool, "checkout", self._add_conn)
+ event.listen(pool, "checkin", self._remove_conn)
+ event.listen(pool, "close", self._remove_conn)
+ event.listen(pool, "close_detached", self._remove_conn)
+ # note we are keeping "invalidated" here, as those are still
+ # opened connections we would like to roll back
+
+ def _add_conn(self, dbapi_con, con_record, con_proxy):
+ self.dbapi_connections.add(dbapi_con)
+ self.proxy_refs[con_proxy] = True
+
+ def _remove_conn(self, dbapi_conn, *arg):
+ self.dbapi_connections.discard(dbapi_conn)
+
+ def add_engine(self, engine, scope):
+ self.add_pool(engine.pool)
+
+ assert scope in ("class", "global", "function", "fixture")
+ self.testing_engines[scope].add(engine)
+
+ def _safe(self, fn):
+ try:
+ fn()
+ except Exception as e:
+ warnings.warn(
+ "testing_reaper couldn't rollback/close connection: %s" % e
+ )
+
+ def rollback_all(self):
+ for rec in list(self.proxy_refs):
+ if rec is not None and rec.is_valid:
+ self._safe(rec.rollback)
+
+ def checkin_all(self):
+ # run pool.checkin() for all ConnectionFairy instances we have
+ # tracked.
+
+ for rec in list(self.proxy_refs):
+ if rec is not None and rec.is_valid:
+ self.dbapi_connections.discard(rec.dbapi_connection)
+ self._safe(rec._checkin)
+
+ # for fairy refs that were GCed and could not close the connection,
+ # such as asyncio, roll back those remaining connections
+ for con in self.dbapi_connections:
+ self._safe(con.rollback)
+ self.dbapi_connections.clear()
+
+ def close_all(self):
+ self.checkin_all()
+
+ def prepare_for_drop_tables(self, connection):
+ # don't do aggressive checks for third party test suites
+ if not config.bootstrapped_as_sqlalchemy:
+ return
+
+ from . import provision
+
+ provision.prepare_for_drop_tables(connection.engine.url, connection)
+
+ def _drop_testing_engines(self, scope):
+ eng = self.testing_engines[scope]
+ for rec in list(eng):
+ for proxy_ref in list(self.proxy_refs):
+ if proxy_ref is not None and proxy_ref.is_valid:
+ if (
+ proxy_ref._pool is not None
+ and proxy_ref._pool is rec.pool
+ ):
+ self._safe(proxy_ref._checkin)
+ if hasattr(rec, "sync_engine"):
+ await_only(rec.dispose())
+ else:
+ rec.dispose()
+ eng.clear()
+
+ def after_test(self):
+ self._drop_testing_engines("function")
+
+ def after_test_outside_fixtures(self, test):
+ # don't do aggressive checks for third party test suites
+ if not config.bootstrapped_as_sqlalchemy:
+ return
+
+ if test.__class__.__leave_connections_for_teardown__:
+ return
+
+ self.checkin_all()
+
+ # on PostgreSQL, this will test for any "idle in transaction"
+ # connections. useful to identify tests with unusual patterns
+ # that can't be cleaned up correctly.
+ from . import provision
+
+ with config.db.connect() as conn:
+ provision.prepare_for_drop_tables(conn.engine.url, conn)
+
+ def stop_test_class_inside_fixtures(self):
+ self.checkin_all()
+ self._drop_testing_engines("function")
+ self._drop_testing_engines("class")
+
+ def stop_test_class_outside_fixtures(self):
+ # ensure no refs to checked out connections at all.
+
+ if pool.base._strong_ref_connection_records:
+ gc_collect()
+
+ if pool.base._strong_ref_connection_records:
+ ln = len(pool.base._strong_ref_connection_records)
+ pool.base._strong_ref_connection_records.clear()
+ assert (
+ False
+ ), "%d connection recs not cleared after test suite" % (ln)
+
+ def final_cleanup(self):
+ self.checkin_all()
+ for scope in self.testing_engines:
+ self._drop_testing_engines(scope)
+
+ def assert_all_closed(self):
+ for rec in self.proxy_refs:
+ if rec.is_valid:
+ assert False
+
+
+testing_reaper = ConnectionKiller()
+
+
+@decorator
+def assert_conns_closed(fn, *args, **kw):
+ try:
+ fn(*args, **kw)
+ finally:
+ testing_reaper.assert_all_closed()
+
+
+@decorator
+def rollback_open_connections(fn, *args, **kw):
+ """Decorator that rolls back all open connections after fn execution."""
+
+ try:
+ fn(*args, **kw)
+ finally:
+ testing_reaper.rollback_all()
+
+
+@decorator
+def close_first(fn, *args, **kw):
+ """Decorator that closes all connections before fn execution."""
+
+ testing_reaper.checkin_all()
+ fn(*args, **kw)
+
+
+@decorator
+def close_open_connections(fn, *args, **kw):
+ """Decorator that closes all connections after fn execution."""
+ try:
+ fn(*args, **kw)
+ finally:
+ testing_reaper.checkin_all()
+
+
+def all_dialects(exclude=None):
+ import sqlalchemy.dialects as d
+
+ for name in d.__all__:
+ # TEMPORARY
+ if exclude and name in exclude:
+ continue
+ mod = getattr(d, name, None)
+ if not mod:
+ mod = getattr(
+ __import__("sqlalchemy.dialects.%s" % name).dialects, name
+ )
+ yield mod.dialect()
+
+
+class ReconnectFixture(object):
+ def __init__(self, dbapi):
+ self.dbapi = dbapi
+ self.connections = []
+ self.is_stopped = False
+
+ def __getattr__(self, key):
+ return getattr(self.dbapi, key)
+
+ def connect(self, *args, **kwargs):
+
+ conn = self.dbapi.connect(*args, **kwargs)
+ if self.is_stopped:
+ self._safe(conn.close)
+ curs = conn.cursor() # should fail on Oracle etc.
+ # should fail for everything that didn't fail
+ # above, connection is closed
+ curs.execute("select 1")
+ assert False, "simulated connect failure didn't work"
+ else:
+ self.connections.append(conn)
+ return conn
+
+ def _safe(self, fn):
+ try:
+ fn()
+ except Exception as e:
+ warnings.warn("ReconnectFixture couldn't close connection: %s" % e)
+
+ def shutdown(self, stop=False):
+ # TODO: this doesn't cover all cases
+ # as nicely as we'd like, namely MySQLdb.
+ # would need to implement R. Brewer's
+ # proxy server idea to get better
+ # coverage.
+ self.is_stopped = stop
+ for c in list(self.connections):
+ self._safe(c.close)
+ self.connections = []
+
+ def restart(self):
+ self.is_stopped = False
+
+
+def reconnecting_engine(url=None, options=None):
+ url = url or config.db.url
+ dbapi = config.db.dialect.dbapi
+ if not options:
+ options = {}
+ options["module"] = ReconnectFixture(dbapi)
+ engine = testing_engine(url, options)
+ _dispose = engine.dispose
+
+ def dispose():
+ engine.dialect.dbapi.shutdown()
+ engine.dialect.dbapi.is_stopped = False
+ _dispose()
+
+ engine.test_shutdown = engine.dialect.dbapi.shutdown
+ engine.test_restart = engine.dialect.dbapi.restart
+ engine.dispose = dispose
+ return engine
+
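+# Illustrative sketch (not part of the original module): a test can use the
+# fixture above to simulate a dropped connection; names are hypothetical::
+#
+#     eng = reconnecting_engine()
+#     conn = eng.connect()
+#     eng.test_shutdown()
+#     # eng.test_shutdown() closes the DBAPI connections currently open, so
+#     # the next statement on ``conn`` should raise a disconnect-style
+#     # error; eng.test_restart() clears the simulated "stopped" state.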
+
+def testing_engine(
+ url=None,
+ options=None,
+ future=None,
+ asyncio=False,
+ transfer_staticpool=False,
+ _sqlite_savepoint=False,
+):
+ """Produce an engine configured by --options with optional overrides."""
+
+ if asyncio:
+ assert not _sqlite_savepoint
+ from sqlalchemy.ext.asyncio import (
+ create_async_engine as create_engine,
+ )
+ elif future or (
+ config.db and config.db._is_future and future is not False
+ ):
+ from sqlalchemy.future import create_engine
+ else:
+ from sqlalchemy import create_engine
+ from sqlalchemy.engine.url import make_url
+
+ if not options:
+ use_reaper = True
+ scope = "function"
+ sqlite_savepoint = False
+ else:
+ use_reaper = options.pop("use_reaper", True)
+ scope = options.pop("scope", "function")
+ sqlite_savepoint = options.pop("sqlite_savepoint", False)
+
+ url = url or config.db.url
+
+ url = make_url(url)
+ if options is None:
+ if config.db is None or url.drivername == config.db.url.drivername:
+ options = config.db_opts
+ else:
+ options = {}
+ elif config.db is not None and url.drivername == config.db.url.drivername:
+ default_opt = config.db_opts.copy()
+ default_opt.update(options)
+
+ engine = create_engine(url, **options)
+
+ if sqlite_savepoint and engine.name == "sqlite":
+ # apply SQLite savepoint workaround
+ @event.listens_for(engine, "connect")
+ def do_connect(dbapi_connection, connection_record):
+ dbapi_connection.isolation_level = None
+
+ @event.listens_for(engine, "begin")
+ def do_begin(conn):
+ conn.exec_driver_sql("BEGIN")
+
+ if transfer_staticpool:
+ from sqlalchemy.pool import StaticPool
+
+ if config.db is not None and isinstance(config.db.pool, StaticPool):
+ use_reaper = False
+ engine.pool._transfer_from(config.db.pool)
+
+ if scope == "global":
+ if asyncio:
+ engine.sync_engine._has_events = True
+ else:
+ engine._has_events = (
+ True # enable event blocks, helps with profiling
+ )
+
+ if isinstance(engine.pool, pool.QueuePool):
+ engine.pool._timeout = 0
+ engine.pool._max_overflow = 0
+ if use_reaper:
+ testing_reaper.add_engine(engine, scope)
+
+ return engine
+
+
+def mock_engine(dialect_name=None):
+ """Provides a mocking engine based on the current testing.db.
+
+ This is normally used to test DDL generation flow as emitted
+ by an Engine.
+
+ It should not be used in other cases, as assert_compile() and
+ assert_sql_execution() are much better choices with fewer
+ moving parts.
+
+ """
+
+ from sqlalchemy import create_mock_engine
+
+ if not dialect_name:
+ dialect_name = config.db.name
+
+ buffer = []
+
+ def executor(sql, *a, **kw):
+ buffer.append(sql)
+
+ def assert_sql(stmts):
+ recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer]
+ assert recv == stmts, recv
+
+ def print_sql():
+ d = engine.dialect
+ return "\n".join(str(s.compile(dialect=d)) for s in engine.mock)
+
+ engine = create_mock_engine(dialect_name + "://", executor)
+ assert not hasattr(engine, "mock")
+ engine.mock = buffer
+ engine.assert_sql = assert_sql
+ engine.print_sql = print_sql
+ return engine
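+
+# Illustrative sketch (not part of the original module): mock_engine() is
+# typically pointed at DDL generation; ``some_metadata`` is a hypothetical
+# MetaData object::
+#
+#     eng = mock_engine("postgresql")
+#     some_metadata.create_all(eng)
+#     # eng.mock now holds the emitted DDL constructs;
+#     # eng.assert_sql([...]) compares their stringified forms.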
+
+
+class DBAPIProxyCursor(object):
+ """Proxy a DBAPI cursor.
+
+ Tests can provide subclasses of this to intercept
+ DBAPI-level cursor operations.
+
+ """
+
+ def __init__(self, engine, conn, *args, **kwargs):
+ self.engine = engine
+ self.connection = conn
+ self.cursor = conn.cursor(*args, **kwargs)
+
+ def execute(self, stmt, parameters=None, **kw):
+ if parameters:
+ return self.cursor.execute(stmt, parameters, **kw)
+ else:
+ return self.cursor.execute(stmt, **kw)
+
+ def executemany(self, stmt, params, **kw):
+ return self.cursor.executemany(stmt, params, **kw)
+
+ def __iter__(self):
+ return iter(self.cursor)
+
+ def __getattr__(self, key):
+ return getattr(self.cursor, key)
+
+
+class DBAPIProxyConnection(object):
+ """Proxy a DBAPI connection.
+
+ Tests can provide subclasses of this to intercept
+ DBAPI-level connection operations.
+
+ """
+
+ def __init__(self, engine, cursor_cls):
+ self.conn = engine.pool._creator()
+ self.engine = engine
+ self.cursor_cls = cursor_cls
+
+ def cursor(self, *args, **kwargs):
+ return self.cursor_cls(self.engine, self.conn, *args, **kwargs)
+
+ def close(self):
+ self.conn.close()
+
+ def __getattr__(self, key):
+ return getattr(self.conn, key)
+
+
+def proxying_engine(
+ conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor
+):
+ """Produce an engine that provides proxy hooks for
+ common methods.
+
+ """
+
+ def mock_conn():
+ return conn_cls(config.db, cursor_cls)
+
+ def _wrap_do_on_connect(do_on_connect):
+ def go(dbapi_conn):
+ return do_on_connect(dbapi_conn.conn)
+
+ return go
+
+ return testing_engine(
+ options={
+ "creator": mock_conn,
+ "_wrap_do_on_connect": _wrap_do_on_connect,
+ }
+ )
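+
+# Illustrative sketch (not part of the original module): a test can supply a
+# DBAPIProxyCursor subclass to observe statements at the DBAPI level; the
+# class name here is hypothetical::
+#
+#     class LoggingCursor(DBAPIProxyCursor):
+#         def execute(self, stmt, parameters=None, **kw):
+#             print("executing:", stmt)
+#             return super(LoggingCursor, self).execute(stmt, parameters, **kw)
+#
+#     eng = proxying_engine(cursor_cls=LoggingCursor)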
diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py
new file mode 100644
index 0000000..8ea65d6
--- /dev/null
+++ b/lib/sqlalchemy/testing/entities.py
@@ -0,0 +1,111 @@
+# testing/entities.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import sqlalchemy as sa
+from .. import exc as sa_exc
+from ..util import compat
+
+_repr_stack = set()
+
+
+class BasicEntity(object):
+ def __init__(self, **kw):
+ for key, value in kw.items():
+ setattr(self, key, value)
+
+ def __repr__(self):
+ if id(self) in _repr_stack:
+ return object.__repr__(self)
+ _repr_stack.add(id(self))
+ try:
+ return "%s(%s)" % (
+ (self.__class__.__name__),
+ ", ".join(
+ [
+ "%s=%r" % (key, getattr(self, key))
+ for key in sorted(self.__dict__.keys())
+ if not key.startswith("_")
+ ]
+ ),
+ )
+ finally:
+ _repr_stack.remove(id(self))
+
+
+_recursion_stack = set()
+
+
+class ComparableMixin(object):
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __eq__(self, other):
+ """'Deep, sparse compare.
+
+ Deeply compare two entities, following the non-None attributes of the
+ non-persisted object, if possible.
+
+ """
+ if other is self:
+ return True
+ elif not self.__class__ == other.__class__:
+ return False
+
+ if id(self) in _recursion_stack:
+ return True
+ _recursion_stack.add(id(self))
+
+ try:
+ # pick the entity that's not SA persisted as the source
+ try:
+ self_key = sa.orm.attributes.instance_state(self).key
+ except sa.orm.exc.NO_STATE:
+ self_key = None
+
+ if other is None:
+ a = self
+ b = other
+ elif self_key is not None:
+ a = other
+ b = self
+ else:
+ a = self
+ b = other
+
+ for attr in list(a.__dict__):
+ if attr.startswith("_"):
+ continue
+ value = getattr(a, attr)
+
+ try:
+ # handle lazy loader errors
+ battr = getattr(b, attr)
+ except (AttributeError, sa_exc.UnboundExecutionError):
+ return False
+
+ if hasattr(value, "__iter__") and not isinstance(
+ value, compat.string_types
+ ):
+ if hasattr(value, "__getitem__") and not hasattr(
+ value, "keys"
+ ):
+ if list(value) != list(battr):
+ return False
+ else:
+ if set(value) != set(battr):
+ return False
+ else:
+ if value is not None and value != battr:
+ return False
+ return True
+ finally:
+ _recursion_stack.remove(id(self))
+
+
+class ComparableEntity(ComparableMixin, BasicEntity):
+ def __hash__(self):
+ return hash(self.__class__)
diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py
new file mode 100644
index 0000000..521a4aa
--- /dev/null
+++ b/lib/sqlalchemy/testing/exclusions.py
@@ -0,0 +1,465 @@
+# testing/exclusions.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+import contextlib
+import operator
+import re
+import sys
+
+from . import config
+from .. import util
+from ..util import decorator
+from ..util.compat import inspect_getfullargspec
+
+
+def skip_if(predicate, reason=None):
+ rule = compound()
+ pred = _as_predicate(predicate, reason)
+ rule.skips.add(pred)
+ return rule
+
+
+def fails_if(predicate, reason=None):
+ rule = compound()
+ pred = _as_predicate(predicate, reason)
+ rule.fails.add(pred)
+ return rule
+
+
+class compound(object):
+ def __init__(self):
+ self.fails = set()
+ self.skips = set()
+ self.tags = set()
+
+ def __add__(self, other):
+ return self.add(other)
+
+ def as_skips(self):
+ rule = compound()
+ rule.skips.update(self.skips)
+ rule.skips.update(self.fails)
+ rule.tags.update(self.tags)
+ return rule
+
+ def add(self, *others):
+ copy = compound()
+ copy.fails.update(self.fails)
+ copy.skips.update(self.skips)
+ copy.tags.update(self.tags)
+ for other in others:
+ copy.fails.update(other.fails)
+ copy.skips.update(other.skips)
+ copy.tags.update(other.tags)
+ return copy
+
+ def not_(self):
+ copy = compound()
+ copy.fails.update(NotPredicate(fail) for fail in self.fails)
+ copy.skips.update(NotPredicate(skip) for skip in self.skips)
+ copy.tags.update(self.tags)
+ return copy
+
+ @property
+ def enabled(self):
+ return self.enabled_for_config(config._current)
+
+ def enabled_for_config(self, config):
+ for predicate in self.skips.union(self.fails):
+ if predicate(config):
+ return False
+ else:
+ return True
+
+ def matching_config_reasons(self, config):
+ return [
+ predicate._as_string(config)
+ for predicate in self.skips.union(self.fails)
+ if predicate(config)
+ ]
+
+ def include_test(self, include_tags, exclude_tags):
+ return bool(
+ not self.tags.intersection(exclude_tags)
+ and (not include_tags or self.tags.intersection(include_tags))
+ )
+
+ def _extend(self, other):
+ self.skips.update(other.skips)
+ self.fails.update(other.fails)
+ self.tags.update(other.tags)
+
+ def __call__(self, fn):
+ if hasattr(fn, "_sa_exclusion_extend"):
+ fn._sa_exclusion_extend._extend(self)
+ return fn
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ return self._do(config._current, fn, *args, **kw)
+
+ decorated = decorate(fn)
+ decorated._sa_exclusion_extend = self
+ return decorated
+
+ @contextlib.contextmanager
+ def fail_if(self):
+ all_fails = compound()
+ all_fails.fails.update(self.skips.union(self.fails))
+
+ try:
+ yield
+ except Exception as ex:
+ all_fails._expect_failure(config._current, ex)
+ else:
+ all_fails._expect_success(config._current)
+
+ def _do(self, cfg, fn, *args, **kw):
+ for skip in self.skips:
+ if skip(cfg):
+ msg = "'%s' : %s" % (
+ config.get_current_test_name(),
+ skip._as_string(cfg),
+ )
+ config.skip_test(msg)
+
+ try:
+ return_value = fn(*args, **kw)
+ except Exception as ex:
+ self._expect_failure(cfg, ex, name=fn.__name__)
+ else:
+ self._expect_success(cfg, name=fn.__name__)
+ return return_value
+
+ def _expect_failure(self, config, ex, name="block"):
+ for fail in self.fails:
+ if fail(config):
+ if util.py2k:
+ str_ex = unicode(ex).encode( # noqa: F821
+ "utf-8", errors="ignore"
+ )
+ else:
+ str_ex = str(ex)
+ print(
+ (
+ "%s failed as expected (%s): %s "
+ % (name, fail._as_string(config), str_ex)
+ )
+ )
+ break
+ else:
+ util.raise_(ex, with_traceback=sys.exc_info()[2])
+
+ def _expect_success(self, config, name="block"):
+ if not self.fails:
+ return
+
+ for fail in self.fails:
+ if fail(config):
+ raise AssertionError(
+ "Unexpected success for '%s' (%s)"
+ % (
+ name,
+ " and ".join(
+ fail._as_string(config) for fail in self.fails
+ ),
+ )
+ )
+
+
+def requires_tag(tagname):
+ return tags([tagname])
+
+
+def tags(tagnames):
+ comp = compound()
+ comp.tags.update(tagnames)
+ return comp
+
+
+def only_if(predicate, reason=None):
+ predicate = _as_predicate(predicate)
+ return skip_if(NotPredicate(predicate), reason)
+
+
+def succeeds_if(predicate, reason=None):
+ predicate = _as_predicate(predicate)
+ return fails_if(NotPredicate(predicate), reason)
+
+
+class Predicate(object):
+ @classmethod
+ def as_predicate(cls, predicate, description=None):
+ if isinstance(predicate, compound):
+ return cls.as_predicate(predicate.enabled_for_config, description)
+ elif isinstance(predicate, Predicate):
+ if description and predicate.description is None:
+ predicate.description = description
+ return predicate
+ elif isinstance(predicate, (list, set)):
+ return OrPredicate(
+ [cls.as_predicate(pred) for pred in predicate], description
+ )
+ elif isinstance(predicate, tuple):
+ return SpecPredicate(*predicate)
+ elif isinstance(predicate, util.string_types):
+ tokens = re.match(
+ r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
+ )
+ if not tokens:
+ raise ValueError(
+ "Couldn't locate DB name in predicate: %r" % predicate
+ )
+ db = tokens.group(1)
+ op = tokens.group(2)
+ spec = (
+ tuple(int(d) for d in tokens.group(3).split("."))
+ if tokens.group(3)
+ else None
+ )
+
+ return SpecPredicate(db, op, spec, description=description)
+ elif callable(predicate):
+ return LambdaPredicate(predicate, description)
+ else:
+ assert False, "unknown predicate type: %s" % predicate
+
+ def _format_description(self, config, negate=False):
+ bool_ = self(config)
+ if negate:
+ bool_ = not negate
+ return self.description % {
+ "driver": config.db.url.get_driver_name()
+ if config
+ else "<no driver>",
+ "database": config.db.url.get_backend_name()
+ if config
+ else "<no database>",
+ "doesnt_support": "doesn't support" if bool_ else "does support",
+ "does_support": "does support" if bool_ else "doesn't support",
+ }
+
+ def _as_string(self, config=None, negate=False):
+ raise NotImplementedError()
+
+
+class BooleanPredicate(Predicate):
+ def __init__(self, value, description=None):
+ self.value = value
+ self.description = description or "boolean %s" % value
+
+ def __call__(self, config):
+ return self.value
+
+ def _as_string(self, config, negate=False):
+ return self._format_description(config, negate=negate)
+
+
+class SpecPredicate(Predicate):
+ def __init__(self, db, op=None, spec=None, description=None):
+ self.db = db
+ self.op = op
+ self.spec = spec
+ self.description = description
+
+ _ops = {
+ "<": operator.lt,
+ ">": operator.gt,
+ "==": operator.eq,
+ "!=": operator.ne,
+ "<=": operator.le,
+ ">=": operator.ge,
+ "in": operator.contains,
+ "between": lambda val, pair: val >= pair[0] and val <= pair[1],
+ }
+
+ def __call__(self, config):
+ if config is None:
+ return False
+
+ engine = config.db
+
+ if "+" in self.db:
+ dialect, driver = self.db.split("+")
+ else:
+ dialect, driver = self.db, None
+
+ if dialect and engine.name != dialect:
+ return False
+ if driver is not None and engine.driver != driver:
+ return False
+
+ if self.op is not None:
+ assert driver is None, "DBAPI version specs not supported yet"
+
+ version = _server_version(engine)
+ oper = (
+ hasattr(self.op, "__call__") and self.op or self._ops[self.op]
+ )
+ return oper(version, self.spec)
+ else:
+ return True
+
+ def _as_string(self, config, negate=False):
+ if self.description is not None:
+ return self._format_description(config)
+ elif self.op is None:
+ if negate:
+ return "not %s" % self.db
+ else:
+ return "%s" % self.db
+ else:
+ if negate:
+ return "not %s %s %s" % (self.db, self.op, self.spec)
+ else:
+ return "%s %s %s" % (self.db, self.op, self.spec)
+
+
+class LambdaPredicate(Predicate):
+ def __init__(self, lambda_, description=None, args=None, kw=None):
+ spec = inspect_getfullargspec(lambda_)
+ if not spec[0]:
+ self.lambda_ = lambda db: lambda_()
+ else:
+ self.lambda_ = lambda_
+ self.args = args or ()
+ self.kw = kw or {}
+ if description:
+ self.description = description
+ elif lambda_.__doc__:
+ self.description = lambda_.__doc__
+ else:
+ self.description = "custom function"
+
+ def __call__(self, config):
+ return self.lambda_(config)
+
+ def _as_string(self, config, negate=False):
+ return self._format_description(config)
+
+
+class NotPredicate(Predicate):
+ def __init__(self, predicate, description=None):
+ self.predicate = predicate
+ self.description = description
+
+ def __call__(self, config):
+ return not self.predicate(config)
+
+ def _as_string(self, config, negate=False):
+ if self.description:
+ return self._format_description(config, not negate)
+ else:
+ return self.predicate._as_string(config, not negate)
+
+
+class OrPredicate(Predicate):
+ def __init__(self, predicates, description=None):
+ self.predicates = predicates
+ self.description = description
+
+ def __call__(self, config):
+ for pred in self.predicates:
+ if pred(config):
+ return True
+ return False
+
+ def _eval_str(self, config, negate=False):
+ if negate:
+ conjunction = " and "
+ else:
+ conjunction = " or "
+ return conjunction.join(
+ p._as_string(config, negate=negate) for p in self.predicates
+ )
+
+ def _negation_str(self, config):
+ if self.description is not None:
+ return "Not " + self._format_description(config)
+ else:
+ return self._eval_str(config, negate=True)
+
+ def _as_string(self, config, negate=False):
+ if negate:
+ return self._negation_str(config)
+ else:
+ if self.description is not None:
+ return self._format_description(config)
+ else:
+ return self._eval_str(config)
+
+
+_as_predicate = Predicate.as_predicate
+
+
+def _is_excluded(db, op, spec):
+ return SpecPredicate(db, op, spec)(config._current)
+
+
+def _server_version(engine):
+ """Return a server_version_info tuple."""
+
+ # force metadata to be retrieved
+ conn = engine.connect()
+ version = getattr(engine.dialect, "server_version_info", None)
+ if version is None:
+ version = ()
+ conn.close()
+ return version
+
+
+def db_spec(*dbs):
+ return OrPredicate([Predicate.as_predicate(db) for db in dbs])
+
+
+def open(): # noqa
+ return skip_if(BooleanPredicate(False, "mark as execute"))
+
+
+def closed():
+ return skip_if(BooleanPredicate(True, "marked as skip"))
+
+
+def fails(reason=None):
+ return fails_if(BooleanPredicate(True, reason or "expected to fail"))
+
+
+@decorator
+def future(fn, *arg):
+ return fails_if(LambdaPredicate(fn), "Future feature")
+
+
+def fails_on(db, reason=None):
+ return fails_if(db, reason)
+
+
+def fails_on_everything_except(*dbs):
+ return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
+
+
+def skip(db, reason=None):
+ return skip_if(db, reason)
+
+
+def only_on(dbs, reason=None):
+ return only_if(
+ OrPredicate(
+ [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
+ )
+ )
+
+
+def exclude(db, op, spec, reason=None):
+ return skip_if(SpecPredicate(db, op, spec), reason)
+
+
+def against(config, *queries):
+ assert queries, "no queries sent!"
+ return OrPredicate([Predicate.as_predicate(query) for query in queries])(
+ config
+ )
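+
+# Illustrative sketch (not part of the original module): the rules above act
+# as decorators on test functions; the backend names, versions and reasons
+# here are examples only::
+#
+#     @exclude("mysql", "<", (5, 6), "requires newer MySQL")
+#     @fails_on("sqlite", "known limitation of this backend")
+#     def test_something(self):
+#         ...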
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
new file mode 100644
index 0000000..0a2d63b
--- /dev/null
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -0,0 +1,870 @@
+# testing/fixtures.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import contextlib
+import re
+import sys
+
+import sqlalchemy as sa
+from . import assertions
+from . import config
+from . import schema
+from .entities import BasicEntity
+from .entities import ComparableEntity
+from .entities import ComparableMixin # noqa
+from .util import adict
+from .util import drop_all_tables_from_metadata
+from .. import event
+from .. import util
+from ..orm import declarative_base
+from ..orm import registry
+from ..orm.decl_api import DeclarativeMeta
+from ..schema import sort_tables_and_constraints
+
+
+@config.mark_base_test_class()
+class TestBase(object):
+ # A sequence of requirement names matching testing.requires decorators
+ __requires__ = ()
+
+ # A sequence of dialect names to exclude from the test class.
+ __unsupported_on__ = ()
+
+ # If present, test class is only runnable for the *single* specified
+ # dialect. If you need multiple, use __unsupported_on__ and invert.
+ __only_on__ = None
+
+ # A sequence of no-arg callables. If any are True, the entire testcase is
+ # skipped.
+ __skip_if__ = None
+
+ # if True, the testing reaper will not attempt to touch connection
+ # state after a test is completed and before the outer teardown
+ # starts
+ __leave_connections_for_teardown__ = False
+
+ def assert_(self, val, msg=None):
+ assert val, msg
+
+ @config.fixture()
+ def nocache(self):
+ _cache = config.db._compiled_cache
+ config.db._compiled_cache = None
+ yield
+ config.db._compiled_cache = _cache
+
+ @config.fixture()
+ def connection_no_trans(self):
+ eng = getattr(self, "bind", None) or config.db
+
+ with eng.connect() as conn:
+ yield conn
+
+ @config.fixture()
+ def connection(self):
+ global _connection_fixture_connection
+
+ eng = getattr(self, "bind", None) or config.db
+
+ conn = eng.connect()
+ trans = conn.begin()
+
+ _connection_fixture_connection = conn
+ yield conn
+
+ _connection_fixture_connection = None
+
+ if trans.is_active:
+ trans.rollback()
+ # trans would not be active here if the test is using
+ # the legacy @provide_metadata decorator still, as it will
+ # run a close all connections.
+ conn.close()
+
+ @config.fixture()
+ def close_result_when_finished(self):
+ to_close = []
+ to_consume = []
+
+ def go(result, consume=False):
+ to_close.append(result)
+ if consume:
+ to_consume.append(result)
+
+ yield go
+ for r in to_consume:
+ try:
+ r.all()
+ except:
+ pass
+ for r in to_close:
+ try:
+ r.close()
+ except:
+ pass
+
+ @config.fixture()
+ def registry(self, metadata):
+ reg = registry(metadata=metadata)
+ yield reg
+ reg.dispose()
+
+ @config.fixture
+ def decl_base(self, registry):
+ return registry.generate_base()
+
+ @config.fixture()
+ def future_connection(self, future_engine, connection):
+ # integrate the future_engine and connection fixtures so
+ # that users of the "connection" fixture will get at the
+ # "future" connection
+ yield connection
+
+ @config.fixture()
+ def future_engine(self):
+ eng = getattr(self, "bind", None) or config.db
+ with _push_future_engine(eng):
+ yield
+
+ @config.fixture()
+ def testing_engine(self):
+ from . import engines
+
+ def gen_testing_engine(
+ url=None,
+ options=None,
+ future=None,
+ asyncio=False,
+ transfer_staticpool=False,
+ ):
+ if options is None:
+ options = {}
+ options["scope"] = "fixture"
+ return engines.testing_engine(
+ url=url,
+ options=options,
+ future=future,
+ asyncio=asyncio,
+ transfer_staticpool=transfer_staticpool,
+ )
+
+ yield gen_testing_engine
+
+ engines.testing_reaper._drop_testing_engines("fixture")
+
+ @config.fixture()
+ def async_testing_engine(self, testing_engine):
+ def go(**kw):
+ kw["asyncio"] = True
+ return testing_engine(**kw)
+
+ return go
+
+ @config.fixture
+ def fixture_session(self):
+ return fixture_session()
+
+ @config.fixture()
+ def metadata(self, request):
+ """Provide bound MetaData for a single test, dropping afterwards."""
+
+ from ..sql import schema
+
+ metadata = schema.MetaData()
+ request.instance.metadata = metadata
+ yield metadata
+ del request.instance.metadata
+
+ if (
+ _connection_fixture_connection
+ and _connection_fixture_connection.in_transaction()
+ ):
+ trans = _connection_fixture_connection.get_transaction()
+ trans.rollback()
+ with _connection_fixture_connection.begin():
+ drop_all_tables_from_metadata(
+ metadata, _connection_fixture_connection
+ )
+ else:
+ drop_all_tables_from_metadata(metadata, config.db)
+
+ @config.fixture(
+ params=[
+ (rollback, second_operation, begin_nested)
+ for rollback in (True, False)
+ for second_operation in ("none", "execute", "begin")
+ for begin_nested in (
+ True,
+ False,
+ )
+ ]
+ )
+ def trans_ctx_manager_fixture(self, request, metadata):
+ rollback, second_operation, begin_nested = request.param
+
+ from sqlalchemy import Table, Column, Integer, func, select
+ from . import eq_
+
+ t = Table("test", metadata, Column("data", Integer))
+ eng = getattr(self, "bind", None) or config.db
+
+ t.create(eng)
+
+ def run_test(subject, trans_on_subject, execute_on_subject):
+ with subject.begin() as trans:
+
+ if begin_nested:
+ if not config.requirements.savepoints.enabled:
+ config.skip_test("savepoints not enabled")
+ if execute_on_subject:
+ nested_trans = subject.begin_nested()
+ else:
+ nested_trans = trans.begin_nested()
+
+ with nested_trans:
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 10})
+ else:
+ trans.execute(t.insert(), {"data": 10})
+
+ # for nested trans, we always commit/rollback on the
+ # "nested trans" object itself.
+ # only Session(future=False) will affect savepoint
+ # transaction for session.commit/rollback
+
+ if rollback:
+ nested_trans.rollback()
+ else:
+ nested_trans.commit()
+
+ if second_operation != "none":
+ with assertions.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ "Can't operate on closed transaction "
+ "inside context "
+ "manager. Please complete the context "
+ "manager "
+ "before emitting further commands.",
+ ):
+ if second_operation == "execute":
+ if execute_on_subject:
+ subject.execute(
+ t.insert(), {"data": 12}
+ )
+ else:
+ trans.execute(t.insert(), {"data": 12})
+ elif second_operation == "begin":
+ if execute_on_subject:
+ subject.begin_nested()
+ else:
+ trans.begin_nested()
+
+ # outside the nested trans block, but still inside the
+ # transaction block, we can run SQL, and it will be
+ # committed
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 14})
+ else:
+ trans.execute(t.insert(), {"data": 14})
+
+ else:
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 10})
+ else:
+ trans.execute(t.insert(), {"data": 10})
+
+ if trans_on_subject:
+ if rollback:
+ subject.rollback()
+ else:
+ subject.commit()
+ else:
+ if rollback:
+ trans.rollback()
+ else:
+ trans.commit()
+
+ if second_operation != "none":
+ with assertions.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ "Can't operate on closed transaction inside "
+ "context "
+ "manager. Please complete the context manager "
+ "before emitting further commands.",
+ ):
+ if second_operation == "execute":
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 12})
+ else:
+ trans.execute(t.insert(), {"data": 12})
+ elif second_operation == "begin":
+ if hasattr(trans, "begin"):
+ trans.begin()
+ else:
+ subject.begin()
+ elif second_operation == "begin_nested":
+ if execute_on_subject:
+ subject.begin_nested()
+ else:
+ trans.begin_nested()
+
+ expected_committed = 0
+ if begin_nested:
+ # begin_nested variant, we inserted a row after the nested
+ # block
+ expected_committed += 1
+ if not rollback:
+ # not rollback variant, our row inserted in the target
+ # block itself would be committed
+ expected_committed += 1
+
+ if execute_on_subject:
+ eq_(
+ subject.scalar(select(func.count()).select_from(t)),
+ expected_committed,
+ )
+ else:
+ with subject.connect() as conn:
+ eq_(
+ conn.scalar(select(func.count()).select_from(t)),
+ expected_committed,
+ )
+
+ return run_test
+
+
+_connection_fixture_connection = None
+
+
+@contextlib.contextmanager
+def _push_future_engine(engine):
+
+ from ..future.engine import Engine
+ from sqlalchemy import testing
+
+ facade = Engine._future_facade(engine)
+ config._current.push_engine(facade, testing)
+
+ yield facade
+
+ config._current.pop(testing)
+
+
+class FutureEngineMixin(object):
+ @config.fixture(autouse=True, scope="class")
+ def _push_future_engine(self):
+ eng = getattr(self, "bind", None) or config.db
+ with _push_future_engine(eng):
+ yield
+
+
+class TablesTest(TestBase):
+
+ # 'once', None
+ run_setup_bind = "once"
+
+ # 'once', 'each', None
+ run_define_tables = "once"
+
+ # 'once', 'each', None
+ run_create_tables = "once"
+
+ # 'once', 'each', None
+ run_inserts = "each"
+
+ # 'each', None
+ run_deletes = "each"
+
+ # 'once', None
+ run_dispose_bind = None
+
+ bind = None
+ _tables_metadata = None
+ tables = None
+ other = None
+ sequences = None
+
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
+ cls._init_class()
+
+ cls._setup_once_tables()
+
+ cls._setup_once_inserts()
+
+ yield
+
+ cls._teardown_once_metadata_bind()
+
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
+ self._setup_each_tables()
+ self._setup_each_inserts()
+
+ yield
+
+ self._teardown_each_tables()
+
+ @property
+ def tables_test_metadata(self):
+ return self._tables_metadata
+
+ @classmethod
+ def _init_class(cls):
+ if cls.run_define_tables == "each":
+ if cls.run_create_tables == "once":
+ cls.run_create_tables = "each"
+ assert cls.run_inserts in ("each", None)
+
+ cls.other = adict()
+ cls.tables = adict()
+ cls.sequences = adict()
+
+ cls.bind = cls.setup_bind()
+ cls._tables_metadata = sa.MetaData()
+
+ @classmethod
+ def _setup_once_inserts(cls):
+ if cls.run_inserts == "once":
+ cls._load_fixtures()
+ with cls.bind.begin() as conn:
+ cls.insert_data(conn)
+
+ @classmethod
+ def _setup_once_tables(cls):
+ if cls.run_define_tables == "once":
+ cls.define_tables(cls._tables_metadata)
+ if cls.run_create_tables == "once":
+ cls._tables_metadata.create_all(cls.bind)
+ cls.tables.update(cls._tables_metadata.tables)
+ cls.sequences.update(cls._tables_metadata._sequences)
+
+ def _setup_each_tables(self):
+ if self.run_define_tables == "each":
+ self.define_tables(self._tables_metadata)
+ if self.run_create_tables == "each":
+ self._tables_metadata.create_all(self.bind)
+ self.tables.update(self._tables_metadata.tables)
+ self.sequences.update(self._tables_metadata._sequences)
+ elif self.run_create_tables == "each":
+ self._tables_metadata.create_all(self.bind)
+
+ def _setup_each_inserts(self):
+ if self.run_inserts == "each":
+ self._load_fixtures()
+ with self.bind.begin() as conn:
+ self.insert_data(conn)
+
+ def _teardown_each_tables(self):
+ if self.run_define_tables == "each":
+ self.tables.clear()
+ if self.run_create_tables == "each":
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
+ self._tables_metadata.clear()
+ elif self.run_create_tables == "each":
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
+
+ savepoints = getattr(config.requirements, "savepoints", False)
+ if savepoints:
+ savepoints = savepoints.enabled
+
+ # no need to run deletes if tables are recreated on setup
+ if (
+ self.run_define_tables != "each"
+ and self.run_create_tables != "each"
+ and self.run_deletes == "each"
+ ):
+ with self.bind.begin() as conn:
+ for table in reversed(
+ [
+ t
+ for (t, fks) in sort_tables_and_constraints(
+ self._tables_metadata.tables.values()
+ )
+ if t is not None
+ ]
+ ):
+ try:
+ if savepoints:
+ with conn.begin_nested():
+ conn.execute(table.delete())
+ else:
+ conn.execute(table.delete())
+ except sa.exc.DBAPIError as ex:
+ util.print_(
+ ("Error emptying table %s: %r" % (table, ex)),
+ file=sys.stderr,
+ )
+
+ @classmethod
+ def _teardown_once_metadata_bind(cls):
+ if cls.run_create_tables:
+ drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
+
+ if cls.run_dispose_bind == "once":
+ cls.dispose_bind(cls.bind)
+
+ cls._tables_metadata.bind = None
+
+ if cls.run_setup_bind is not None:
+ cls.bind = None
+
+ @classmethod
+ def setup_bind(cls):
+ return config.db
+
+ @classmethod
+ def dispose_bind(cls, bind):
+ if hasattr(bind, "dispose"):
+ bind.dispose()
+ elif hasattr(bind, "close"):
+ bind.close()
+
+ @classmethod
+ def define_tables(cls, metadata):
+ pass
+
+ @classmethod
+ def fixtures(cls):
+ return {}
+
+ @classmethod
+ def insert_data(cls, connection):
+ pass
+
+ def sql_count_(self, count, fn):
+ self.assert_sql_count(self.bind, fn, count)
+
+ def sql_eq_(self, callable_, statements):
+ self.assert_sql(self.bind, callable_, statements)
+
+ @classmethod
+ def _load_fixtures(cls):
+ """Insert rows as represented by the fixtures() method."""
+ headers, rows = {}, {}
+ for table, data in cls.fixtures().items():
+ if len(data) < 2:
+ continue
+ if isinstance(table, util.string_types):
+ table = cls.tables[table]
+ headers[table] = data[0]
+ rows[table] = data[1:]
+ for table, fks in sort_tables_and_constraints(
+ cls._tables_metadata.tables.values()
+ ):
+ if table is None:
+ continue
+ if table not in headers:
+ continue
+ with cls.bind.begin() as conn:
+ conn.execute(
+ table.insert(),
+ [
+ dict(zip(headers[table], column_values))
+ for column_values in rows[table]
+ ],
+ )
+
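+# Illustrative sketch (not part of the original module): a typical TablesTest
+# subclass defines its tables and seed rows declaratively; all names below
+# are hypothetical::
+#
+#     class SomeTest(TablesTest):
+#         @classmethod
+#         def define_tables(cls, metadata):
+#             sa.Table(
+#                 "users",
+#                 metadata,
+#                 sa.Column("id", sa.Integer, primary_key=True),
+#                 sa.Column("name", sa.String(30)),
+#             )
+#
+#         @classmethod
+#         def fixtures(cls):
+#             return dict(
+#                 users=(
+#                     ("id", "name"),
+#                     (1, "ed"),
+#                     (2, "wendy"),
+#                 )
+#             )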
+
+class NoCache(object):
+ @config.fixture(autouse=True, scope="function")
+ def _disable_cache(self):
+ _cache = config.db._compiled_cache
+ config.db._compiled_cache = None
+ yield
+ config.db._compiled_cache = _cache
+
+
+class RemovesEvents(object):
+ @util.memoized_property
+ def _event_fns(self):
+ return set()
+
+ def event_listen(self, target, name, fn, **kw):
+ self._event_fns.add((target, name, fn))
+ event.listen(target, name, fn, **kw)
+
+ @config.fixture(autouse=True, scope="function")
+ def _remove_events(self):
+ yield
+ for key in self._event_fns:
+ event.remove(*key)
+
+
+_fixture_sessions = set()
+
+
+def fixture_session(**kw):
+ kw.setdefault("autoflush", True)
+ kw.setdefault("expire_on_commit", True)
+
+ bind = kw.pop("bind", config.db)
+
+ sess = sa.orm.Session(bind, **kw)
+ _fixture_sessions.add(sess)
+ return sess
+
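+# a hypothetical use of fixture_session() in an ORM test; sessions created
+# here are tracked in _fixture_sessions and closed by after_test() /
+# _close_all_sessions():
+#
+#     def test_roundtrip(self):
+#         sess = fixture_session()
+#         sess.add(self.classes.User(name="ed"))
+#         sess.commit()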
+
+def _close_all_sessions():
+ # will close all still-referenced sessions
+ sa.orm.session.close_all_sessions()
+ _fixture_sessions.clear()
+
+
+def stop_test_class_inside_fixtures(cls):
+ _close_all_sessions()
+ sa.orm.clear_mappers()
+
+
+def after_test():
+ if _fixture_sessions:
+ _close_all_sessions()
+
+
+class ORMTest(TestBase):
+ pass
+
+
+class MappedTest(TablesTest, assertions.AssertsExecutionResults):
+ # 'once', 'each', None
+ run_setup_classes = "once"
+
+ # 'once', 'each', None
+ run_setup_mappers = "each"
+
+ classes = None
+
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
+ cls._init_class()
+
+ if cls.classes is None:
+ cls.classes = adict()
+
+ cls._setup_once_tables()
+ cls._setup_once_classes()
+ cls._setup_once_mappers()
+ cls._setup_once_inserts()
+
+ yield
+
+ cls._teardown_once_class()
+ cls._teardown_once_metadata_bind()
+
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
+ self._setup_each_tables()
+ self._setup_each_classes()
+ self._setup_each_mappers()
+ self._setup_each_inserts()
+
+ yield
+
+ sa.orm.session.close_all_sessions()
+ self._teardown_each_mappers()
+ self._teardown_each_classes()
+ self._teardown_each_tables()
+
+ @classmethod
+ def _teardown_once_class(cls):
+ cls.classes.clear()
+
+ @classmethod
+ def _setup_once_classes(cls):
+ if cls.run_setup_classes == "once":
+ cls._with_register_classes(cls.setup_classes)
+
+ @classmethod
+ def _setup_once_mappers(cls):
+ if cls.run_setup_mappers == "once":
+ cls.mapper_registry, cls.mapper = cls._generate_registry()
+ cls._with_register_classes(cls.setup_mappers)
+
+ def _setup_each_mappers(self):
+ if self.run_setup_mappers != "once":
+ (
+ self.__class__.mapper_registry,
+ self.__class__.mapper,
+ ) = self._generate_registry()
+
+ if self.run_setup_mappers == "each":
+ self._with_register_classes(self.setup_mappers)
+
+ def _setup_each_classes(self):
+ if self.run_setup_classes == "each":
+ self._with_register_classes(self.setup_classes)
+
+ @classmethod
+ def _generate_registry(cls):
+ decl = registry(metadata=cls._tables_metadata)
+ return decl, decl.map_imperatively
+
+ @classmethod
+ def _with_register_classes(cls, fn):
+ """Run a setup method, framing the operation with a Base class
+ that will catch new subclasses to be established within
+ the "classes" registry.
+
+ """
+ cls_registry = cls.classes
+
+ assert cls_registry is not None
+
+ class FindFixture(type):
+ def __init__(cls, classname, bases, dict_):
+ cls_registry[classname] = cls
+ type.__init__(cls, classname, bases, dict_)
+
+ class _Base(util.with_metaclass(FindFixture, object)):
+ pass
+
+ class Basic(BasicEntity, _Base):
+ pass
+
+ class Comparable(ComparableEntity, _Base):
+ pass
+
+ cls.Basic = Basic
+ cls.Comparable = Comparable
+ fn()
+
+ def _teardown_each_mappers(self):
+ # some tests create mappers in the test bodies
+ # and will define setup_mappers as None -
+ # clear mappers in any case
+ if self.run_setup_mappers != "once":
+ sa.orm.clear_mappers()
+
+ def _teardown_each_classes(self):
+ if self.run_setup_classes != "once":
+ self.classes.clear()
+
+ @classmethod
+ def setup_classes(cls):
+ pass
+
+ @classmethod
+ def setup_mappers(cls):
+ pass
+
+
+class DeclarativeMappedTest(MappedTest):
+ run_setup_classes = "once"
+ run_setup_mappers = "once"
+
+ @classmethod
+ def _setup_once_tables(cls):
+ pass
+
+ @classmethod
+ def _with_register_classes(cls, fn):
+ cls_registry = cls.classes
+
+ class FindFixtureDeclarative(DeclarativeMeta):
+ def __init__(cls, classname, bases, dict_):
+ cls_registry[classname] = cls
+ DeclarativeMeta.__init__(cls, classname, bases, dict_)
+
+ class DeclarativeBasic(object):
+ __table_cls__ = schema.Table
+
+ _DeclBase = declarative_base(
+ metadata=cls._tables_metadata,
+ metaclass=FindFixtureDeclarative,
+ cls=DeclarativeBasic,
+ )
+
+ cls.DeclarativeBasic = _DeclBase
+
+ # sets up cls.Basic which is helpful for things like composite
+ # classes
+ super(DeclarativeMappedTest, cls)._with_register_classes(fn)
+
+ if cls._tables_metadata.tables and cls.run_create_tables:
+ cls._tables_metadata.create_all(config.db)
+
+
+class ComputedReflectionFixtureTest(TablesTest):
+ run_inserts = run_deletes = None
+
+ __backend__ = True
+ __requires__ = ("computed_columns", "table_reflection")
+
+ regexp = re.compile(r"[\[\]\(\)\s`'\"]*")
+
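+    # normalize() strips brackets, parens, backticks, quotes and whitespace
+    # and lowercases, e.g. (illustrative) '("normal" + 42)' -> "normal+42"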
+ def normalize(self, text):
+ return self.regexp.sub("", text).lower()
+
+ @classmethod
+ def define_tables(cls, metadata):
+ from .. import Integer
+ from .. import testing
+ from ..schema import Column
+ from ..schema import Computed
+ from ..schema import Table
+
+ Table(
+ "computed_default_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_col", Integer, Computed("normal + 42")),
+ Column("with_default", Integer, server_default="42"),
+ )
+
+ t = Table(
+ "computed_column_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_no_flag", Integer, Computed("normal + 42")),
+ )
+
+ if testing.requires.schemas.enabled:
+ t2 = Table(
+ "computed_column_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_no_flag", Integer, Computed("normal / 42")),
+ schema=config.test_schema,
+ )
+
+ if testing.requires.computed_columns_virtual.enabled:
+ t.append_column(
+ Column(
+ "computed_virtual",
+ Integer,
+ Computed("normal + 2", persisted=False),
+ )
+ )
+ if testing.requires.schemas.enabled:
+ t2.append_column(
+ Column(
+ "computed_virtual",
+ Integer,
+ Computed("normal / 2", persisted=False),
+ )
+ )
+ if testing.requires.computed_columns_stored.enabled:
+ t.append_column(
+ Column(
+ "computed_stored",
+ Integer,
+ Computed("normal - 42", persisted=True),
+ )
+ )
+ if testing.requires.schemas.enabled:
+ t2.append_column(
+ Column(
+ "computed_stored",
+ Integer,
+ Computed("normal * 42", persisted=True),
+ )
+ )
diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py
new file mode 100644
index 0000000..e333c70
--- /dev/null
+++ b/lib/sqlalchemy/testing/mock.py
@@ -0,0 +1,32 @@
+# testing/mock.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Import stub for mock library.
+"""
+from __future__ import absolute_import
+
+from ..util import py3k
+
+
+if py3k:
+ from unittest.mock import MagicMock
+ from unittest.mock import Mock
+ from unittest.mock import call
+ from unittest.mock import patch
+ from unittest.mock import ANY
+else:
+ try:
+ from mock import MagicMock # noqa
+ from mock import Mock # noqa
+ from mock import call # noqa
+ from mock import patch # noqa
+ from mock import ANY # noqa
+ except ImportError:
+ raise ImportError(
+ "SQLAlchemy's test suite requires the "
+ "'mock' library as of 0.8.2."
+ )
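+
+# test code then imports these names from here rather than from
+# unittest.mock / mock directly, e.g. (illustrative):
+#
+#     from sqlalchemy.testing.mock import Mock, patch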
diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py
new file mode 100644
index 0000000..f05960c
--- /dev/null
+++ b/lib/sqlalchemy/testing/pickleable.py
@@ -0,0 +1,151 @@
+# testing/pickleable.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Classes used in pickling tests, need to be at the module level for
+unpickling.
+"""
+
+from . import fixtures
+from ..schema import Column
+from ..types import String
+
+
+class User(fixtures.ComparableEntity):
+ pass
+
+
+class Order(fixtures.ComparableEntity):
+ pass
+
+
+class Dingaling(fixtures.ComparableEntity):
+ pass
+
+
+class EmailUser(User):
+ pass
+
+
+class Address(fixtures.ComparableEntity):
+ pass
+
+
+# TODO: these are kind of arbitrary....
+class Child1(fixtures.ComparableEntity):
+ pass
+
+
+class Child2(fixtures.ComparableEntity):
+ pass
+
+
+class Parent(fixtures.ComparableEntity):
+ pass
+
+
+class Screen(object):
+ def __init__(self, obj, parent=None):
+ self.obj = obj
+ self.parent = parent
+
+
+class Mixin(object):
+ email_address = Column(String)
+
+
+class AddressWMixin(Mixin, fixtures.ComparableEntity):
+ pass
+
+
+class Foo(object):
+ def __init__(self, moredata, stuff="im stuff"):
+ self.data = "im data"
+ self.stuff = stuff
+ self.moredata = moredata
+
+ __hash__ = object.__hash__
+
+ def __eq__(self, other):
+ return (
+ other.data == self.data
+ and other.stuff == self.stuff
+ and other.moredata == self.moredata
+ )
+
+
+class Bar(object):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ __hash__ = object.__hash__
+
+ def __eq__(self, other):
+ return (
+ other.__class__ is self.__class__
+ and other.x == self.x
+ and other.y == self.y
+ )
+
+ def __str__(self):
+ return "Bar(%d, %d)" % (self.x, self.y)
+
+
+class OldSchool:
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __eq__(self, other):
+ return (
+ other.__class__ is self.__class__
+ and other.x == self.x
+ and other.y == self.y
+ )
+
+
+class OldSchoolWithoutCompare:
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+
+class BarWithoutCompare(object):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __str__(self):
+ return "Bar(%d, %d)" % (self.x, self.y)
+
+
+class NotComparable(object):
+ def __init__(self, data):
+ self.data = data
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ return NotImplemented
+
+ def __ne__(self, other):
+ return NotImplemented
+
+
+class BrokenComparable(object):
+ def __init__(self, data):
+ self.data = data
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ raise NotImplementedError
+
+ def __ne__(self, other):
+ raise NotImplementedError
diff --git a/lib/sqlalchemy/testing/plugin/__init__.py b/lib/sqlalchemy/testing/plugin/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/__init__.py
diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py
new file mode 100644
index 0000000..6721f48
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/bootstrap.py
@@ -0,0 +1,54 @@
+"""
+Bootstrapper for test framework plugins.
+
+The entire rationale for this system is to get the modules in plugin/
+imported without importing all of the supporting library, so that we can
+set up things for testing before coverage starts.
+
+The rationale for all of plugin/ being *in* the supporting library in the
+first place is so that the testing and plugin suite is available to other
+libraries, mainly external SQLAlchemy and Alembic dialects, to make use
+of the same test environment and standard suites available to
+SQLAlchemy/Alembic themselves without the need to ship/install a separate
+package outside of SQLAlchemy.
+
+
+"""
+
+import os
+import sys
+
+
+bootstrap_file = locals()["bootstrap_file"]
+to_bootstrap = locals()["to_bootstrap"]
+
+
+def load_file_as_module(name):
+ path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
+
+ if sys.version_info >= (3, 5):
+ import importlib.util
+
+ spec = importlib.util.spec_from_file_location(name, path)
+ assert spec is not None
+ assert spec.loader is not None
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod)
+ else:
+ import imp
+
+ mod = imp.load_source(name, path)
+
+ return mod
+
+
+if to_bootstrap == "pytest":
+ sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
+ sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True
+ if sys.version_info < (3, 0):
+ sys.modules["sqla_reinvent_fixtures"] = load_file_as_module(
+ "reinvent_fixtures_py2k"
+ )
+ sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
+else:
+ raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
new file mode 100644
index 0000000..d59564e
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -0,0 +1,789 @@
+# plugin/plugin_base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Testing extensions.
+
+this module is designed to work as a testing-framework-agnostic library,
+created so that multiple test frameworks can be supported at once
+(mostly so that we can migrate to new ones). The current target
+is pytest.
+
+"""
+
+from __future__ import absolute_import
+
+import abc
+import logging
+import re
+import sys
+
+# flag which indicates we are in the SQLAlchemy testing suite,
+# and not that of Alembic or a third party dialect.
+bootstrapped_as_sqlalchemy = False
+
+log = logging.getLogger("sqlalchemy.testing.plugin_base")
+
+
+py3k = sys.version_info >= (3, 0)
+
+if py3k:
+ import configparser
+
+ ABC = abc.ABC
+else:
+ import ConfigParser as configparser
+ import collections as collections_abc # noqa
+
+ class ABC(object):
+ __metaclass__ = abc.ABCMeta
+
+
+# late imports
+fixtures = None
+engines = None
+exclusions = None
+warnings = None
+profiling = None
+provision = None
+assertions = None
+requirements = None
+config = None
+testing = None
+util = None
+file_config = None
+
+logging = None
+include_tags = set()
+exclude_tags = set()
+options = None
+
+
+def setup_options(make_option):
+ make_option(
+ "--log-info",
+ action="callback",
+ type=str,
+ callback=_log,
+ help="turn on info logging for <LOG> (multiple OK)",
+ )
+ make_option(
+ "--log-debug",
+ action="callback",
+ type=str,
+ callback=_log,
+ help="turn on debug logging for <LOG> (multiple OK)",
+ )
+ make_option(
+ "--db",
+ action="append",
+ type=str,
+ dest="db",
+ help="Use prefab database uri. Multiple OK, "
+ "first one is run by default.",
+ )
+ make_option(
+ "--dbs",
+ action="callback",
+ zeroarg_callback=_list_dbs,
+ help="List available prefab dbs",
+ )
+ make_option(
+ "--dburi",
+ action="append",
+ type=str,
+ dest="dburi",
+ help="Database uri. Multiple OK, " "first one is run by default.",
+ )
+ make_option(
+ "--dbdriver",
+ action="append",
+ type=str,
+ dest="dbdriver",
+ help="Additional database drivers to include in tests. "
+ "These are linked to the existing database URLs by the "
+ "provisioning system.",
+ )
+ make_option(
+ "--dropfirst",
+ action="store_true",
+ dest="dropfirst",
+ help="Drop all tables in the target database first",
+ )
+ make_option(
+ "--disable-asyncio",
+ action="store_true",
+ help="disable test / fixtures / provisoning running in asyncio",
+ )
+ make_option(
+ "--backend-only",
+ action="store_true",
+ dest="backend_only",
+ help="Run only tests marked with __backend__ or __sparse_backend__",
+ )
+ make_option(
+ "--nomemory",
+ action="store_true",
+ dest="nomemory",
+ help="Don't run memory profiling tests",
+ )
+ make_option(
+ "--notimingintensive",
+ action="store_true",
+ dest="notimingintensive",
+ help="Don't run timing intensive tests",
+ )
+ make_option(
+ "--profile-sort",
+ type=str,
+ default="cumulative",
+ dest="profilesort",
+ help="Type of sort for profiling standard output",
+ )
+ make_option(
+ "--profile-dump",
+ type=str,
+ dest="profiledump",
+ help="Filename where a single profile run will be dumped",
+ )
+ make_option(
+ "--postgresql-templatedb",
+ type=str,
+ help="name of template database to use for PostgreSQL "
+ "CREATE DATABASE (defaults to current database)",
+ )
+ make_option(
+ "--low-connections",
+ action="store_true",
+ dest="low_connections",
+ help="Use a low number of distinct connections - "
+ "i.e. for Oracle TNS",
+ )
+ make_option(
+ "--write-idents",
+ type=str,
+ dest="write_idents",
+ help="write out generated follower idents to <file>, "
+ "when -n<num> is used",
+ )
+ make_option(
+ "--reversetop",
+ action="store_true",
+ dest="reversetop",
+ default=False,
+ help="Use a random-ordering set implementation in the ORM "
+ "(helps reveal dependency issues)",
+ )
+ make_option(
+ "--requirements",
+ action="callback",
+ type=str,
+ callback=_requirements_opt,
+ help="requirements class for testing, overrides setup.cfg",
+ )
+ make_option(
+ "--with-cdecimal",
+ action="store_true",
+ dest="cdecimal",
+ default=False,
+ help="Monkeypatch the cdecimal library into Python 'decimal' "
+ "for all tests",
+ )
+ make_option(
+ "--include-tag",
+ action="callback",
+ callback=_include_tag,
+ type=str,
+ help="Include tests with tag <tag>",
+ )
+ make_option(
+ "--exclude-tag",
+ action="callback",
+ callback=_exclude_tag,
+ type=str,
+ help="Exclude tests with tag <tag>",
+ )
+ make_option(
+ "--write-profiles",
+ action="store_true",
+ dest="write_profiles",
+ default=False,
+ help="Write/update failing profiling data.",
+ )
+ make_option(
+ "--force-write-profiles",
+ action="store_true",
+ dest="force_write_profiles",
+ default=False,
+ help="Unconditionally write/update profiling data.",
+ )
+ make_option(
+ "--dump-pyannotate",
+ type=str,
+ dest="dump_pyannotate",
+ help="Run pyannotate and dump json info to given file",
+ )
+ make_option(
+ "--mypy-extra-test-path",
+ type=str,
+ action="append",
+ default=[],
+ dest="mypy_extra_test_paths",
+ help="Additional test directories to add to the mypy tests. "
+ "This is used only when running mypy tests. Multiple OK",
+ )
+
+
+def configure_follower(follower_ident):
+ """Configure required state for a follower.
+
+ This invokes in the parent process and typically includes
+ database creation.
+
+ """
+ from sqlalchemy.testing import provision
+
+ provision.FOLLOWER_IDENT = follower_ident
+
+
+def memoize_important_follower_config(dict_):
+ """Store important configuration we will need to send to a follower.
+
+ This invokes in the parent process after normal config is set up.
+
+ This is necessary as pytest seems to not be using forking, so we
+ start with nothing in memory, *but* it isn't running our argparse
+ callables, so we have to just copy all of that over.
+
+ """
+ dict_["memoized_config"] = {
+ "include_tags": include_tags,
+ "exclude_tags": exclude_tags,
+ }
+
+
+def restore_important_follower_config(dict_):
+ """Restore important configuration needed by a follower.
+
+ This invokes in the follower process.
+
+ """
+ global include_tags, exclude_tags
+ include_tags.update(dict_["memoized_config"]["include_tags"])
+ exclude_tags.update(dict_["memoized_config"]["exclude_tags"])
+
+
+def read_config():
+ global file_config
+ file_config = configparser.ConfigParser()
+ file_config.read(["setup.cfg", "test.cfg"])
+
+
+def pre_begin(opt):
+ """things to set up early, before coverage might be setup."""
+ global options
+ options = opt
+ for fn in pre_configure:
+ fn(options, file_config)
+
+
+def set_coverage_flag(value):
+ options.has_coverage = value
+
+
+def post_begin():
+ """things to set up later, once we know coverage is running."""
+ # Lazy setup of other options (post coverage)
+ for fn in post_configure:
+ fn(options, file_config)
+
+ # late imports, has to happen after config.
+ global util, fixtures, engines, exclusions, assertions, provision
+ global warnings, profiling, config, testing
+ from sqlalchemy import testing # noqa
+ from sqlalchemy.testing import fixtures, engines, exclusions # noqa
+ from sqlalchemy.testing import assertions, warnings, profiling # noqa
+ from sqlalchemy.testing import config, provision # noqa
+ from sqlalchemy import util # noqa
+
+ warnings.setup_filters()
+
+
+def _log(opt_str, value, parser):
+ global logging
+ if not logging:
+ import logging
+
+ logging.basicConfig()
+
+ if opt_str.endswith("-info"):
+ logging.getLogger(value).setLevel(logging.INFO)
+ elif opt_str.endswith("-debug"):
+ logging.getLogger(value).setLevel(logging.DEBUG)
+
+
+def _list_dbs(*args):
+ print("Available --db options (use --dburi to override)")
+ for macro in sorted(file_config.options("db")):
+ print("%20s\t%s" % (macro, file_config.get("db", macro)))
+ sys.exit(0)
+
+
+def _requirements_opt(opt_str, value, parser):
+ _setup_requirements(value)
+
+
+def _exclude_tag(opt_str, value, parser):
+ exclude_tags.add(value.replace("-", "_"))
+
+
+def _include_tag(opt_str, value, parser):
+ include_tags.add(value.replace("-", "_"))
+
+
+pre_configure = []
+post_configure = []
+
+
+def pre(fn):
+ pre_configure.append(fn)
+ return fn
+
+
+def post(fn):
+ post_configure.append(fn)
+ return fn
+
+
+@pre
+def _setup_options(opt, file_config):
+ global options
+ options = opt
+
+
+@pre
+def _set_nomemory(opt, file_config):
+ if opt.nomemory:
+ exclude_tags.add("memory_intensive")
+
+
+@pre
+def _set_notimingintensive(opt, file_config):
+ if opt.notimingintensive:
+ exclude_tags.add("timing_intensive")
+
+
+@pre
+def _monkeypatch_cdecimal(options, file_config):
+ if options.cdecimal:
+ import cdecimal
+
+ sys.modules["decimal"] = cdecimal
+
+
+@post
+def _init_symbols(options, file_config):
+ from sqlalchemy.testing import config
+
+ config._fixture_functions = _fixture_fn_class()
+
+
+@post
+def _set_disable_asyncio(opt, file_config):
+ if opt.disable_asyncio or not py3k:
+ from sqlalchemy.testing import asyncio
+
+ asyncio.ENABLE_ASYNCIO = False
+
+
+@post
+def _engine_uri(options, file_config):
+
+ from sqlalchemy import testing
+ from sqlalchemy.testing import config
+ from sqlalchemy.testing import provision
+
+ if options.dburi:
+ db_urls = list(options.dburi)
+ else:
+ db_urls = []
+
+ extra_drivers = options.dbdriver or []
+
+ if options.db:
+ for db_token in options.db:
+ for db in re.split(r"[,\s]+", db_token):
+ if db not in file_config.options("db"):
+ raise RuntimeError(
+ "Unknown URI specifier '%s'. "
+ "Specify --dbs for known uris." % db
+ )
+ else:
+ db_urls.append(file_config.get("db", db))
+
+ if not db_urls:
+ db_urls.append(file_config.get("db", "default"))
+
+ config._current = None
+
+ expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers))
+
+ for db_url in expanded_urls:
+ log.info("Adding database URL: %s", db_url)
+
+ if options.write_idents and provision.FOLLOWER_IDENT:
+ with open(options.write_idents, "a") as file_:
+ file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n")
+
+ cfg = provision.setup_config(
+ db_url, options, file_config, provision.FOLLOWER_IDENT
+ )
+ if not config._current:
+ cfg.set_as_current(cfg, testing)
+
+
+@post
+def _requirements(options, file_config):
+
+ requirement_cls = file_config.get("sqla_testing", "requirement_cls")
+ _setup_requirements(requirement_cls)
+
+
+def _setup_requirements(argument):
+ from sqlalchemy.testing import config
+ from sqlalchemy import testing
+
+ if config.requirements is not None:
+ return
+
+ modname, clsname = argument.split(":")
+
+ # importlib.import_module() only introduced in 2.7, a little
+ # late
+ mod = __import__(modname)
+ for component in modname.split(".")[1:]:
+ mod = getattr(mod, component)
+ req_cls = getattr(mod, clsname)
+
+ config.requirements = testing.requires = req_cls()
+
+ config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy
+
+
+@post
+def _prep_testing_database(options, file_config):
+ from sqlalchemy.testing import config
+
+ if options.dropfirst:
+ from sqlalchemy.testing import provision
+
+ for cfg in config.Config.all_configs():
+ provision.drop_all_schema_objects(cfg, cfg.db)
+
+
+@post
+def _reverse_topological(options, file_config):
+ if options.reversetop:
+ from sqlalchemy.orm.util import randomize_unitofwork
+
+ randomize_unitofwork()
+
+
+@post
+def _post_setup_options(opt, file_config):
+ from sqlalchemy.testing import config
+
+ config.options = options
+ config.file_config = file_config
+
+
+@post
+def _setup_profiling(options, file_config):
+ from sqlalchemy.testing import profiling
+
+ profiling._profile_stats = profiling.ProfileStatsFile(
+ file_config.get("sqla_testing", "profile_file"),
+ sort=options.profilesort,
+ dump=options.profiledump,
+ )
+
+
+def want_class(name, cls):
+ if not issubclass(cls, fixtures.TestBase):
+ return False
+ elif name.startswith("_"):
+ return False
+ elif (
+ config.options.backend_only
+ and not getattr(cls, "__backend__", False)
+ and not getattr(cls, "__sparse_backend__", False)
+ and not getattr(cls, "__only_on__", False)
+ ):
+ return False
+ else:
+ return True
+
+
+def want_method(cls, fn):
+ if not fn.__name__.startswith("test_"):
+ return False
+ elif fn.__module__ is None:
+ return False
+ elif include_tags:
+ return (
+ hasattr(cls, "__tags__")
+ and exclusions.tags(cls.__tags__).include_test(
+ include_tags, exclude_tags
+ )
+ ) or (
+ hasattr(fn, "_sa_exclusion_extend")
+ and fn._sa_exclusion_extend.include_test(
+ include_tags, exclude_tags
+ )
+ )
+ elif exclude_tags and hasattr(cls, "__tags__"):
+ return exclusions.tags(cls.__tags__).include_test(
+ include_tags, exclude_tags
+ )
+ elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"):
+ return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
+ else:
+ return True
+
+
+def generate_sub_tests(cls, module):
+ if getattr(cls, "__backend__", False) or getattr(
+ cls, "__sparse_backend__", False
+ ):
+ sparse = getattr(cls, "__sparse_backend__", False)
+ for cfg in _possible_configs_for_cls(cls, sparse=sparse):
+ orig_name = cls.__name__
+
+ # we can have special chars in these names except for the
+ # pytest junit plugin, which is tripped up by the brackets
+ # and periods, so sanitize
+
+ alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name)
+ alpha_name = re.sub(r"_+$", "", alpha_name)
+ name = "%s_%s" % (cls.__name__, alpha_name)
+ subcls = type(
+ name,
+ (cls,),
+ {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
+ )
+ setattr(module, name, subcls)
+ yield subcls
+ else:
+ yield cls
+
+
+def start_test_class_outside_fixtures(cls):
+ _do_skips(cls)
+ _setup_engine(cls)
+
+
+def stop_test_class(cls):
+ # close sessions, immediate connections, etc.
+ fixtures.stop_test_class_inside_fixtures(cls)
+
+ # close outstanding connection pool connections, dispose of
+ # additional engines
+ engines.testing_reaper.stop_test_class_inside_fixtures()
+
+
+def stop_test_class_outside_fixtures(cls):
+ engines.testing_reaper.stop_test_class_outside_fixtures()
+ provision.stop_test_class_outside_fixtures(config, config.db, cls)
+ try:
+ if not options.low_connections:
+ assertions.global_cleanup_assertions()
+ finally:
+ _restore_engine()
+
+
+def _restore_engine():
+ if config._current:
+ config._current.reset(testing)
+
+
+def final_process_cleanup():
+ engines.testing_reaper.final_cleanup()
+ assertions.global_cleanup_assertions()
+ _restore_engine()
+
+
+def _setup_engine(cls):
+ if getattr(cls, "__engine_options__", None):
+ opts = dict(cls.__engine_options__)
+ opts["scope"] = "class"
+ eng = engines.testing_engine(options=opts)
+ config._current.push_engine(eng, testing)
+
+
+def before_test(test, test_module_name, test_class, test_name):
+
+ # format looks like:
+ # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"
+
+ name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__)
+
+ id_ = "%s.%s.%s" % (test_module_name, name, test_name)
+
+ profiling._start_current_test(id_)
+
+
+def after_test(test):
+ fixtures.after_test()
+ engines.testing_reaper.after_test()
+
+
+def after_test_fixtures(test):
+ engines.testing_reaper.after_test_outside_fixtures(test)
+
+
+def _possible_configs_for_cls(cls, reasons=None, sparse=False):
+ all_configs = set(config.Config.all_configs())
+
+ if cls.__unsupported_on__:
+ spec = exclusions.db_spec(*cls.__unsupported_on__)
+ for config_obj in list(all_configs):
+ if spec(config_obj):
+ all_configs.remove(config_obj)
+
+ if getattr(cls, "__only_on__", None):
+ spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
+ for config_obj in list(all_configs):
+ if not spec(config_obj):
+ all_configs.remove(config_obj)
+
+ if getattr(cls, "__only_on_config__", None):
+ all_configs.intersection_update([cls.__only_on_config__])
+
+ if hasattr(cls, "__requires__"):
+ requirements = config.requirements
+ for config_obj in list(all_configs):
+ for requirement in cls.__requires__:
+ check = getattr(requirements, requirement)
+
+ skip_reasons = check.matching_config_reasons(config_obj)
+ if skip_reasons:
+ all_configs.remove(config_obj)
+ if reasons is not None:
+ reasons.extend(skip_reasons)
+ break
+
+ if hasattr(cls, "__prefer_requires__"):
+ non_preferred = set()
+ requirements = config.requirements
+ for config_obj in list(all_configs):
+ for requirement in cls.__prefer_requires__:
+ check = getattr(requirements, requirement)
+
+ if not check.enabled_for_config(config_obj):
+ non_preferred.add(config_obj)
+ if all_configs.difference(non_preferred):
+ all_configs.difference_update(non_preferred)
+
+ if sparse:
+ # pick only one config from each base dialect
+ # sorted so we get the same backend each time selecting the highest
+ # server version info.
+ per_dialect = {}
+ for cfg in reversed(
+ sorted(
+ all_configs,
+ key=lambda cfg: (
+ cfg.db.name,
+ cfg.db.driver,
+ cfg.db.dialect.server_version_info,
+ ),
+ )
+ ):
+ db = cfg.db.name
+ if db not in per_dialect:
+ per_dialect[db] = cfg
+ return per_dialect.values()
+
+ return all_configs
+
+
+def _do_skips(cls):
+ reasons = []
+ all_configs = _possible_configs_for_cls(cls, reasons)
+
+ if getattr(cls, "__skip_if__", False):
+ for c in getattr(cls, "__skip_if__"):
+ if c():
+ config.skip_test(
+ "'%s' skipped by %s" % (cls.__name__, c.__name__)
+ )
+
+ if not all_configs:
+ msg = "'%s' unsupported on any DB implementation %s%s" % (
+ cls.__name__,
+ ", ".join(
+ "'%s(%s)+%s'"
+ % (
+ config_obj.db.name,
+ ".".join(
+ str(dig)
+ for dig in exclusions._server_version(config_obj.db)
+ ),
+ config_obj.db.driver,
+ )
+ for config_obj in config.Config.all_configs()
+ ),
+ ", ".join(reasons),
+ )
+ config.skip_test(msg)
+ elif hasattr(cls, "__prefer_backends__"):
+ non_preferred = set()
+ spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
+ for config_obj in all_configs:
+ if not spec(config_obj):
+ non_preferred.add(config_obj)
+ if all_configs.difference(non_preferred):
+ all_configs.difference_update(non_preferred)
+
+ if config._current not in all_configs:
+ _setup_config(all_configs.pop(), cls)
+
+
+def _setup_config(config_obj, ctx):
+ config._current.push(config_obj, testing)
+
+
+class FixtureFunctions(ABC):
+ @abc.abstractmethod
+ def skip_test_exception(self, *arg, **kw):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def combinations(self, *args, **kw):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def param_ident(self, *args, **kw):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def fixture(self, *arg, **kw):
+ raise NotImplementedError()
+
+ def get_current_test_name(self):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def mark_base_test_class(self):
+ raise NotImplementedError()
+
+
+_fixture_fn_class = None
+
+
+def set_fixture_functions(fixture_fn_class):
+ global _fixture_fn_class
+ _fixture_fn_class = fixture_fn_class
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
new file mode 100644
index 0000000..5a51582
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -0,0 +1,820 @@
+try:
+ # installed by bootstrap.py
+ import sqla_plugin_base as plugin_base
+except ImportError:
+ # assume we're a package, use traditional import
+ from . import plugin_base
+
+import argparse
+import collections
+from functools import update_wrapper
+import inspect
+import itertools
+import operator
+import os
+import re
+import sys
+import uuid
+
+import pytest
+
+
+py2k = sys.version_info < (3, 0)
+if py2k:
+ try:
+ import sqla_reinvent_fixtures as reinvent_fixtures_py2k
+ except ImportError:
+ from . import reinvent_fixtures_py2k
+
+
+def pytest_addoption(parser):
+ group = parser.getgroup("sqlalchemy")
+
+ def make_option(name, **kw):
+ callback_ = kw.pop("callback", None)
+ if callback_:
+
+ class CallableAction(argparse.Action):
+ def __call__(
+ self, parser, namespace, values, option_string=None
+ ):
+ callback_(option_string, values, parser)
+
+ kw["action"] = CallableAction
+
+ zeroarg_callback = kw.pop("zeroarg_callback", None)
+ if zeroarg_callback:
+
+ class CallableAction(argparse.Action):
+ def __init__(
+ self,
+ option_strings,
+ dest,
+ default=False,
+ required=False,
+ help=None, # noqa
+ ):
+ super(CallableAction, self).__init__(
+ option_strings=option_strings,
+ dest=dest,
+ nargs=0,
+ const=True,
+ default=default,
+ required=required,
+ help=help,
+ )
+
+ def __call__(
+ self, parser, namespace, values, option_string=None
+ ):
+ zeroarg_callback(option_string, values, parser)
+
+ kw["action"] = CallableAction
+
+ group.addoption(name, **kw)
+
+ plugin_base.setup_options(make_option)
+ plugin_base.read_config()
+
+
+def pytest_configure(config):
+ if config.pluginmanager.hasplugin("xdist"):
+ config.pluginmanager.register(XDistHooks())
+
+ if hasattr(config, "workerinput"):
+ plugin_base.restore_important_follower_config(config.workerinput)
+ plugin_base.configure_follower(config.workerinput["follower_ident"])
+ else:
+ if config.option.write_idents and os.path.exists(
+ config.option.write_idents
+ ):
+ os.remove(config.option.write_idents)
+
+ plugin_base.pre_begin(config.option)
+
+ plugin_base.set_coverage_flag(
+ bool(getattr(config.option, "cov_source", False))
+ )
+
+ plugin_base.set_fixture_functions(PytestFixtureFunctions)
+
+ if config.option.dump_pyannotate:
+ global DUMP_PYANNOTATE
+ DUMP_PYANNOTATE = True
+
+
+DUMP_PYANNOTATE = False
+
+
+@pytest.fixture(autouse=True)
+def collect_types_fixture():
+ if DUMP_PYANNOTATE:
+ from pyannotate_runtime import collect_types
+
+ collect_types.start()
+ yield
+ if DUMP_PYANNOTATE:
+ collect_types.stop()
+
+
+def pytest_sessionstart(session):
+ from sqlalchemy.testing import asyncio
+
+ asyncio._assume_async(plugin_base.post_begin)
+
+
+def pytest_sessionfinish(session):
+ from sqlalchemy.testing import asyncio
+
+ asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup)
+
+ if session.config.option.dump_pyannotate:
+ from pyannotate_runtime import collect_types
+
+ collect_types.dump_stats(session.config.option.dump_pyannotate)
+
+
+def pytest_collection_finish(session):
+ if session.config.option.dump_pyannotate:
+ from pyannotate_runtime import collect_types
+
+ lib_sqlalchemy = os.path.abspath("lib/sqlalchemy")
+
+ def _filter(filename):
+ filename = os.path.normpath(os.path.abspath(filename))
+ if "lib/sqlalchemy" not in os.path.commonpath(
+ [filename, lib_sqlalchemy]
+ ):
+ return None
+ if "testing" in filename:
+ return None
+
+ return filename
+
+ collect_types.init_types_collection(filter_filename=_filter)
+
+
+class XDistHooks(object):
+ def pytest_configure_node(self, node):
+ from sqlalchemy.testing import provision
+ from sqlalchemy.testing import asyncio
+
+ # the master for each node fills workerinput dictionary
+ # which pytest-xdist will transfer to the subprocess
+
+ plugin_base.memoize_important_follower_config(node.workerinput)
+
+ node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
+
+ asyncio._maybe_async_provisioning(
+ provision.create_follower_db, node.workerinput["follower_ident"]
+ )
+
+ def pytest_testnodedown(self, node, error):
+ from sqlalchemy.testing import provision
+ from sqlalchemy.testing import asyncio
+
+ asyncio._maybe_async_provisioning(
+ provision.drop_follower_db, node.workerinput["follower_ident"]
+ )
+
+
+def pytest_collection_modifyitems(session, config, items):
+
+ # look for all those classes that specify __backend__ and
+ # expand them out into per-database test cases.
+
+ # this is much easier to do within pytest_pycollect_makeitem, however
+ # pytest is iterating through cls.__dict__ as makeitem is
+ # called which causes a "dictionary changed size" error on py3k.
+ # I'd submit a pullreq for them to turn it into a list first, but
+ # it's to suit the rather odd use case here which is that we are adding
+ # new classes to a module on the fly.
+
+ from sqlalchemy.testing import asyncio
+
+ rebuilt_items = collections.defaultdict(
+ lambda: collections.defaultdict(list)
+ )
+
+ items[:] = [
+ item
+ for item in items
+ if item.getparent(pytest.Class) is not None
+ and not item.getparent(pytest.Class).name.startswith("_")
+ ]
+
+ test_classes = set(item.getparent(pytest.Class) for item in items)
+
+ def collect(element):
+ for inst_or_fn in element.collect():
+ if isinstance(inst_or_fn, pytest.Collector):
+ # no yield from in 2.7
+ for el in collect(inst_or_fn):
+ yield el
+ else:
+ yield inst_or_fn
+
+ def setup_test_classes():
+ for test_class in test_classes:
+ for sub_cls in plugin_base.generate_sub_tests(
+ test_class.cls, test_class.module
+ ):
+ if sub_cls is not test_class.cls:
+ per_cls_dict = rebuilt_items[test_class.cls]
+
+ # support pytest 5.4.0 and above pytest.Class.from_parent
+ ctor = getattr(pytest.Class, "from_parent", pytest.Class)
+ module = test_class.getparent(pytest.Module)
+ for fn in collect(
+ ctor(name=sub_cls.__name__, parent=module)
+ ):
+ per_cls_dict[fn.name].append(fn)
+
+ # class requirements will sometimes need to access the DB to check
+ # capabilities, so need to do this for async
+ asyncio._maybe_async_provisioning(setup_test_classes)
+
+ newitems = []
+ for item in items:
+ cls_ = item.cls
+ if cls_ in rebuilt_items:
+ newitems.extend(rebuilt_items[cls_][item.name])
+ else:
+ newitems.append(item)
+
+ if py2k:
+ for item in newitems:
+ reinvent_fixtures_py2k.scan_for_fixtures_to_use_for_class(item)
+
+ # seems like the functions attached to a test class aren't sorted already?
+ # is that true and why's that? (when using unittest, they're sorted)
+ items[:] = sorted(
+ newitems,
+ key=lambda item: (
+ item.getparent(pytest.Module).name,
+ item.getparent(pytest.Class).name,
+ item.name,
+ ),
+ )
+
+
+def pytest_pycollect_makeitem(collector, name, obj):
+ if inspect.isclass(obj) and plugin_base.want_class(name, obj):
+ from sqlalchemy.testing import config
+
+ if config.any_async:
+ obj = _apply_maybe_async(obj)
+
+ ctor = getattr(pytest.Class, "from_parent", pytest.Class)
+ return [
+ ctor(name=parametrize_cls.__name__, parent=collector)
+ for parametrize_cls in _parametrize_cls(collector.module, obj)
+ ]
+ elif (
+ inspect.isfunction(obj)
+ and collector.cls is not None
+ and plugin_base.want_method(collector.cls, obj)
+ ):
+ # None means, fall back to default logic, which includes
+ # method-level parametrize
+ return None
+ else:
+ # empty list means skip this item
+ return []
+
+
+def _is_wrapped_coroutine_function(fn):
+ while hasattr(fn, "__wrapped__"):
+ fn = fn.__wrapped__
+
+ return inspect.iscoroutinefunction(fn)
+
+
+def _apply_maybe_async(obj, recurse=True):
+ from sqlalchemy.testing import asyncio
+
+ for name, value in vars(obj).items():
+ if (
+ (callable(value) or isinstance(value, classmethod))
+ and not getattr(value, "_maybe_async_applied", False)
+ and (name.startswith("test_"))
+ and not _is_wrapped_coroutine_function(value)
+ ):
+ is_classmethod = False
+ if isinstance(value, classmethod):
+ value = value.__func__
+ is_classmethod = True
+
+ @_pytest_fn_decorator
+ def make_async(fn, *args, **kwargs):
+ return asyncio._maybe_async(fn, *args, **kwargs)
+
+ do_async = make_async(value)
+ if is_classmethod:
+ do_async = classmethod(do_async)
+ do_async._maybe_async_applied = True
+
+ setattr(obj, name, do_async)
+ if recurse:
+ for cls in obj.mro()[1:]:
+ if cls != object:
+ _apply_maybe_async(cls, False)
+ return obj
+
+
+def _parametrize_cls(module, cls):
+ """implement a class-based version of pytest parametrize."""
+
+ if "_sa_parametrize" not in cls.__dict__:
+ return [cls]
+
+ _sa_parametrize = cls._sa_parametrize
+ classes = []
+ for full_param_set in itertools.product(
+ *[params for argname, params in _sa_parametrize]
+ ):
+ cls_variables = {}
+
+ for argname, param in zip(
+ [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
+ ):
+ if not argname:
+ raise TypeError("need argnames for class-based combinations")
+ argname_split = re.split(r",\s*", argname)
+ for arg, val in zip(argname_split, param.values):
+ cls_variables[arg] = val
+ parametrized_name = "_".join(
+ # token is a string, but in py2k pytest is giving us a unicode,
+ # so call str() on it.
+ str(re.sub(r"\W", "", token))
+ for param in full_param_set
+ for token in param.id.split("-")
+ )
+ name = "%s_%s" % (cls.__name__, parametrized_name)
+ newcls = type.__new__(type, name, (cls,), cls_variables)
+ setattr(module, name, newcls)
+ classes.append(newcls)
+ return classes
+
+
+_current_class = None
+
+
+def pytest_runtest_setup(item):
+ from sqlalchemy.testing import asyncio
+
+ # pytest_runtest_setup runs *before* pytest fixtures with scope="class".
+ # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
+    # for the whole class and has to run things that span all current
+    # databases, so we run this outside of the pytest fixture system
+    # altogether and ensure an asyncio greenlet is used if any engines
+    # are async
+
+ global _current_class
+
+ if isinstance(item, pytest.Function) and _current_class is None:
+ asyncio._maybe_async_provisioning(
+ plugin_base.start_test_class_outside_fixtures,
+ item.cls,
+ )
+ _current_class = item.getparent(pytest.Class)
+
+
+@pytest.hookimpl(hookwrapper=True)
+def pytest_runtest_teardown(item, nextitem):
+ # runs inside of pytest function fixture scope
+ # after test function runs
+ from sqlalchemy.testing import asyncio
+ from sqlalchemy.util import string_types
+
+ asyncio._maybe_async(plugin_base.after_test, item)
+
+ yield
+ # this is now after all the fixture teardown have run, the class can be
+ # finalized. Since pytest v7 this finalizer can no longer be added in
+ # pytest_runtest_setup since the class has not yet been setup at that
+ # time.
+ # See https://github.com/pytest-dev/pytest/issues/9343
+ global _current_class, _current_report
+
+ if _current_class is not None and (
+ # last test or a new class
+ nextitem is None
+ or nextitem.getparent(pytest.Class) is not _current_class
+ ):
+ _current_class = None
+
+ try:
+ asyncio._maybe_async_provisioning(
+ plugin_base.stop_test_class_outside_fixtures, item.cls
+ )
+ except Exception as e:
+ # in case of an exception during teardown attach the original
+ # error to the exception message, otherwise it will get lost
+ if _current_report.failed:
+ if not e.args:
+ e.args = (
+ "__Original test failure__:\n"
+ + _current_report.longreprtext,
+ )
+ elif e.args[-1] and isinstance(e.args[-1], string_types):
+ args = list(e.args)
+ args[-1] += (
+ "\n__Original test failure__:\n"
+ + _current_report.longreprtext
+ )
+ e.args = tuple(args)
+ else:
+ e.args += (
+ "__Original test failure__",
+ _current_report.longreprtext,
+ )
+ raise
+ finally:
+ _current_report = None
+
+
+def pytest_runtest_call(item):
+ # runs inside of pytest function fixture scope
+ # before test function runs
+
+ from sqlalchemy.testing import asyncio
+
+ asyncio._maybe_async(
+ plugin_base.before_test,
+ item,
+ item.module.__name__,
+ item.cls,
+ item.name,
+ )
+
+
+_current_report = None
+
+
+def pytest_runtest_logreport(report):
+ global _current_report
+ if report.when == "call":
+ _current_report = report
+
+
+@pytest.fixture(scope="class")
+def setup_class_methods(request):
+ from sqlalchemy.testing import asyncio
+
+ cls = request.cls
+
+ if hasattr(cls, "setup_test_class"):
+ asyncio._maybe_async(cls.setup_test_class)
+
+ if py2k:
+ reinvent_fixtures_py2k.run_class_fixture_setup(request)
+
+ yield
+
+ if py2k:
+ reinvent_fixtures_py2k.run_class_fixture_teardown(request)
+
+ if hasattr(cls, "teardown_test_class"):
+ asyncio._maybe_async(cls.teardown_test_class)
+
+ asyncio._maybe_async(plugin_base.stop_test_class, cls)
+
+
+@pytest.fixture(scope="function")
+def setup_test_methods(request):
+ from sqlalchemy.testing import asyncio
+
+ # called for each test
+
+ self = request.instance
+
+ # before this fixture runs:
+
+ # 1. function level "autouse" fixtures under py3k (examples: TablesTest
+ # define tables / data, MappedTest define tables / mappers / data)
+
+ # 2. run homegrown function level "autouse" fixtures under py2k
+ if py2k:
+ reinvent_fixtures_py2k.run_fn_fixture_setup(request)
+
+ # 3. run outer xdist-style setup
+ if hasattr(self, "setup_test"):
+ asyncio._maybe_async(self.setup_test)
+
+ # alembic test suite is using setUp and tearDown
+ # xdist methods; support these in the test suite
+ # for the near term
+ if hasattr(self, "setUp"):
+ asyncio._maybe_async(self.setUp)
+
+ # inside the yield:
+ # 4. function level fixtures defined on test functions themselves,
+ # e.g. "connection", "metadata" run next
+
+ # 5. pytest hook pytest_runtest_call then runs
+
+ # 6. test itself runs
+
+ yield
+
+ # yield finishes:
+
+ # 7. function level fixtures defined on test functions
+ # themselves, e.g. "connection" rolls back the transaction, "metadata"
+ # emits drop all
+
+ # 8. pytest hook pytest_runtest_teardown hook runs, this is associated
+ # with fixtures close all sessions, provisioning.stop_test_class(),
+ # engines.testing_reaper -> ensure all connection pool connections
+ # are returned, engines created by testing_engine that aren't the
+ # config engine are disposed
+
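+    # 9. plugin_base.after_test_fixtures runs; it currently dispatches to the
+    # testing reaper's after_test_outside_fixtures()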
+ asyncio._maybe_async(plugin_base.after_test_fixtures, self)
+
+ # 10. run xdist-style teardown
+ if hasattr(self, "tearDown"):
+ asyncio._maybe_async(self.tearDown)
+
+ if hasattr(self, "teardown_test"):
+ asyncio._maybe_async(self.teardown_test)
+
+ # 11. run homegrown function-level "autouse" fixtures under py2k
+ if py2k:
+ reinvent_fixtures_py2k.run_fn_fixture_teardown(request)
+
+ # 12. function level "autouse" fixtures under py3k (examples: TablesTest /
+ # MappedTest delete table data, possibly drop tables and clear mappers
+ # depending on the flags defined by the test class)
+
+
+def getargspec(fn):
+ if sys.version_info.major == 3:
+ return inspect.getfullargspec(fn)
+ else:
+ return inspect.getargspec(fn)
+
+
+def _pytest_fn_decorator(target):
+ """Port of langhelpers.decorator with pytest-specific tricks."""
+
+ from sqlalchemy.util.langhelpers import format_argspec_plus
+ from sqlalchemy.util.compat import inspect_getfullargspec
+
+ def _exec_code_in_env(code, env, fn_name):
+ exec(code, env)
+ return env[fn_name]
+
+ def decorate(fn, add_positional_parameters=()):
+
+ spec = inspect_getfullargspec(fn)
+ if add_positional_parameters:
+ spec.args.extend(add_positional_parameters)
+
+ metadata = dict(
+ __target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__
+ )
+ metadata.update(format_argspec_plus(spec, grouped=False))
+ code = (
+ """\
+def %(name)s(%(args)s):
+ return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s)
+"""
+ % metadata
+ )
+ decorated = _exec_code_in_env(
+ code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__
+ )
+ if not add_positional_parameters:
+ decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
+ decorated.__wrapped__ = fn
+ return update_wrapper(decorated, fn)
+ else:
+ # this is the pytest hacky part. don't do a full update wrapper
+ # because pytest is really being sneaky about finding the args
+ # for the wrapped function
+ decorated.__module__ = fn.__module__
+ decorated.__name__ = fn.__name__
+ if hasattr(fn, "pytestmark"):
+ decorated.pytestmark = fn.pytestmark
+ return decorated
+
+ return decorate
+
+
+class PytestFixtureFunctions(plugin_base.FixtureFunctions):
+ def skip_test_exception(self, *arg, **kw):
+ return pytest.skip.Exception(*arg, **kw)
+
+ def mark_base_test_class(self):
+ return pytest.mark.usefixtures(
+ "setup_class_methods", "setup_test_methods"
+ )
+
+ _combination_id_fns = {
+ "i": lambda obj: obj,
+ "r": repr,
+ "s": str,
+ "n": lambda obj: obj.__name__
+ if hasattr(obj, "__name__")
+ else type(obj).__name__,
+ }
+
+ def combinations(self, *arg_sets, **kw):
+ """Facade for pytest.mark.parametrize.
+
+ Automatically derives argument names from the callable which in our
+ case is always a method on a class with positional arguments.
+
+ ids for parameter sets are derived using an optional template.
+
+ """
+ from sqlalchemy.testing import exclusions
+
+ if sys.version_info.major == 3:
+ if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
+ arg_sets = list(arg_sets[0])
+ else:
+ if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"):
+ arg_sets = list(arg_sets[0])
+
+ argnames = kw.pop("argnames", None)
+
+ def _filter_exclusions(args):
+ result = []
+ gathered_exclusions = []
+ for a in args:
+ if isinstance(a, exclusions.compound):
+ gathered_exclusions.append(a)
+ else:
+ result.append(a)
+
+ return result, gathered_exclusions
+
+ id_ = kw.pop("id_", None)
+
+ tobuild_pytest_params = []
+ has_exclusions = False
+ if id_:
+ _combination_id_fns = self._combination_id_fns
+
+ # because itemgetter is not consistent for one argument vs.
+ # multiple, make it multiple in all cases and use a slice
+ # to omit the first argument
+ _arg_getter = operator.itemgetter(
+ 0,
+ *[
+ idx
+ for idx, char in enumerate(id_)
+ if char in ("n", "r", "s", "a")
+ ]
+ )
+ fns = [
+ (operator.itemgetter(idx), _combination_id_fns[char])
+ for idx, char in enumerate(id_)
+ if char in _combination_id_fns
+ ]
+
+ for arg in arg_sets:
+ if not isinstance(arg, tuple):
+ arg = (arg,)
+
+ fn_params, param_exclusions = _filter_exclusions(arg)
+
+ parameters = _arg_getter(fn_params)[1:]
+
+ if param_exclusions:
+ has_exclusions = True
+
+ tobuild_pytest_params.append(
+ (
+ parameters,
+ param_exclusions,
+ "-".join(
+ comb_fn(getter(arg)) for getter, comb_fn in fns
+ ),
+ )
+ )
+
+ else:
+
+ for arg in arg_sets:
+ if not isinstance(arg, tuple):
+ arg = (arg,)
+
+ fn_params, param_exclusions = _filter_exclusions(arg)
+
+ if param_exclusions:
+ has_exclusions = True
+
+ tobuild_pytest_params.append(
+ (fn_params, param_exclusions, None)
+ )
+
+ pytest_params = []
+ for parameters, param_exclusions, id_ in tobuild_pytest_params:
+ if has_exclusions:
+ parameters += (param_exclusions,)
+
+ param = pytest.param(*parameters, id=id_)
+ pytest_params.append(param)
+
+ def decorate(fn):
+ if inspect.isclass(fn):
+ if has_exclusions:
+ raise NotImplementedError(
+ "exclusions not supported for class level combinations"
+ )
+ if "_sa_parametrize" not in fn.__dict__:
+ fn._sa_parametrize = []
+ fn._sa_parametrize.append((argnames, pytest_params))
+ return fn
+ else:
+ if argnames is None:
+ _argnames = getargspec(fn).args[1:]
+ else:
+ _argnames = re.split(r", *", argnames)
+
+ if has_exclusions:
+ _argnames += ["_exclusions"]
+
+ @_pytest_fn_decorator
+ def check_exclusions(fn, *args, **kw):
+ _exclusions = args[-1]
+ if _exclusions:
+ exlu = exclusions.compound().add(*_exclusions)
+ fn = exlu(fn)
+ return fn(*args[0:-1], **kw)
+
+ def process_metadata(spec):
+ spec.args.append("_exclusions")
+
+ fn = check_exclusions(
+ fn, add_positional_parameters=("_exclusions",)
+ )
+
+ return pytest.mark.parametrize(_argnames, pytest_params)(fn)
+
+ return decorate
+
+ def param_ident(self, *parameters):
+ ident = parameters[0]
+ return pytest.param(*parameters[1:], id=ident)
+
+ def fixture(self, *arg, **kw):
+ from sqlalchemy.testing import config
+ from sqlalchemy.testing import asyncio
+
+ # wrapping pytest.fixture function. determine if
+ # decorator was called as @fixture or @fixture().
+ if len(arg) > 0 and callable(arg[0]):
+            # was called as plain @fixture; we already have the function to wrap.
+ fn = arg[0]
+ arg = arg[1:]
+ else:
+            # was called as @fixture(...); we don't have the function yet.
+ fn = None
+
+ # create a pytest.fixture marker. because the fn is not being
+ # passed, this is always a pytest.FixtureFunctionMarker()
+ # object (or whatever pytest is calling it when you read this)
+ # that is waiting for a function.
+ fixture = pytest.fixture(*arg, **kw)
+
+ # now apply wrappers to the function, including fixture itself
+
+ def wrap(fn):
+ if config.any_async:
+ fn = asyncio._maybe_async_wrapper(fn)
+ # other wrappers may be added here
+
+ if py2k and "autouse" in kw:
+ # py2k workaround for too-slow collection of autouse fixtures
+ # in pytest 4.6.11. See notes in reinvent_fixtures_py2k for
+ # rationale.
+
+ # comment this condition out in order to disable the
+ # py2k workaround entirely.
+ reinvent_fixtures_py2k.add_fixture(fn, fixture)
+ else:
+ # now apply FixtureFunctionMarker
+ fn = fixture(fn)
+
+ return fn
+
+ if fn:
+ return wrap(fn)
+ else:
+ return wrap
+
+ def get_current_test_name(self):
+ return os.environ.get("PYTEST_CURRENT_TEST")
+
+ def async_test(self, fn):
+ from sqlalchemy.testing import asyncio
+
+ @_pytest_fn_decorator
+ def decorate(fn, *args, **kwargs):
+ asyncio._run_coroutine_function(fn, *args, **kwargs)
+
+ return decorate(fn)
diff --git a/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py
new file mode 100644
index 0000000..36b6841
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py
@@ -0,0 +1,112 @@
+"""
+a quick reinvention of pytest autouse fixtures, needed because pytest's
+collection is unacceptably slow and its memory use very high in pytest 4.6.11,
+which is the highest version that works in py2k.
+
+by "too-slow" we mean the test suite can't even manage to be collected for a
+single process in less than 70 seconds or so and memory use seems to be very
+high as well. for two or four workers the job just times out after ten
+minutes.
+
+so instead we have invented a very limited form of these fixtures, as our
+current use of "autouse" fixtures are limited to those in fixtures.py.
+
+assumptions for these fixtures:
+
+1. we are only using "function" or "class" scope
+
+2. the functions must be associated with a test class
+
+3. the fixture functions cannot themselves use pytest fixtures
+
+4. the fixture functions must use yield, not return
+
+When py2k support is removed and we can stay on a modern pytest version, this
+can all be removed.
+
+
+"""
+import collections
+
+
+_py2k_fixture_fn_names = collections.defaultdict(set)
+_py2k_class_fixtures = collections.defaultdict(
+ lambda: collections.defaultdict(set)
+)
+_py2k_function_fixtures = collections.defaultdict(
+ lambda: collections.defaultdict(set)
+)
+
+_py2k_cls_fixture_stack = []
+_py2k_fn_fixture_stack = []
+
+
+def add_fixture(fn, fixture):
+ assert fixture.scope in ("class", "function")
+ _py2k_fixture_fn_names[fn.__name__].add((fn, fixture.scope))
+
+
+def scan_for_fixtures_to_use_for_class(item):
+ test_class = item.parent.parent.obj
+
+ for name in _py2k_fixture_fn_names:
+ for fixture_fn, scope in _py2k_fixture_fn_names[name]:
+ meth = getattr(test_class, name, None)
+ if meth and meth.im_func is fixture_fn:
+ for sup in test_class.__mro__:
+ if name in sup.__dict__:
+ if scope == "class":
+ _py2k_class_fixtures[test_class][sup].add(meth)
+ elif scope == "function":
+ _py2k_function_fixtures[test_class][sup].add(meth)
+ break
+ break
+
+
+def run_class_fixture_setup(request):
+
+ cls = request.cls
+ self = cls.__new__(cls)
+
+ fixtures_for_this_class = _py2k_class_fixtures.get(cls)
+
+ if fixtures_for_this_class:
+ for sup_ in cls.__mro__:
+ for fn in fixtures_for_this_class.get(sup_, ()):
+ iter_ = fn(self)
+ next(iter_)
+
+ _py2k_cls_fixture_stack.append(iter_)
+
+
+def run_class_fixture_teardown(request):
+ while _py2k_cls_fixture_stack:
+ iter_ = _py2k_cls_fixture_stack.pop(-1)
+ try:
+ next(iter_)
+ except StopIteration:
+ pass
+
+
+def run_fn_fixture_setup(request):
+ cls = request.cls
+ self = request.instance
+
+ fixtures_for_this_class = _py2k_function_fixtures.get(cls)
+
+ if fixtures_for_this_class:
+ for sup_ in reversed(cls.__mro__):
+ for fn in fixtures_for_this_class.get(sup_, ()):
+ iter_ = fn(self)
+ next(iter_)
+
+ _py2k_fn_fixture_stack.append(iter_)
+
+
+def run_fn_fixture_teardown(request):
+ while _py2k_fn_fixture_stack:
+ iter_ = _py2k_fn_fixture_stack.pop(-1)
+ try:
+ next(iter_)
+ except StopIteration:
+ pass
diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py
new file mode 100644
index 0000000..4132630
--- /dev/null
+++ b/lib/sqlalchemy/testing/profiling.py
@@ -0,0 +1,335 @@
+# testing/profiling.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Profiling support for unit and performance tests.
+
+These are special purpose profiling methods which operate
+in a more fine-grained way than nose's profiling plugin.
+
+"""
+
+import collections
+import contextlib
+import os
+import platform
+import pstats
+import re
+import sys
+
+from . import config
+from .util import gc_collect
+from ..util import has_compiled_ext
+
+
+try:
+ import cProfile
+except ImportError:
+ cProfile = None
+
+_profile_stats = None
+"""global ProfileStatsFileInstance.
+
+plugin_base assigns this at the start of all tests.
+
+"""
+
+
+_current_test = None
+"""String id of current test.
+
+plugin_base assigns this at the start of each test using
+_start_current_test.
+
+"""
+
+
+def _start_current_test(id_):
+ global _current_test
+ _current_test = id_
+
+ if _profile_stats.force_write:
+ _profile_stats.reset_count()
+
+
+class ProfileStatsFile(object):
+ """Store per-platform/fn profiling results in a file.
+
+ There was no json module available when this was written; in any case,
+ the strictly line-oriented format has turned out to be handy for diffs
+ and merges.
+
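+ Each non-comment line in the file has the form
+ ``<test_key> <platform_key> <comma-separated callcounts>``, e.g.
+ (values illustrative)::
+
+     test.aaa_profiling.test_foo.FooTest.test_bar x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_cextensions 405,405
+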
+ """
+
+ def __init__(self, filename, sort="cumulative", dump=None):
+ self.force_write = (
+ config.options is not None and config.options.force_write_profiles
+ )
+ self.write = self.force_write or (
+ config.options is not None and config.options.write_profiles
+ )
+ self.fname = os.path.abspath(filename)
+ self.short_fname = os.path.split(self.fname)[-1]
+ self.data = collections.defaultdict(
+ lambda: collections.defaultdict(dict)
+ )
+ self.dump = dump
+ self.sort = sort
+ self._read()
+ if self.write:
+ # rewrite for the case where features changed,
+ # etc.
+ self._write()
+
+ @property
+ def platform_key(self):
+
+ dbapi_key = config.db.name + "_" + config.db.driver
+
+ if config.db.name == "sqlite" and config.db.dialect._is_url_file_db(
+ config.db.url
+ ):
+ dbapi_key += "_file"
+
+ # keep it at 2.7, 3.1, 3.2, etc. for now.
+ py_version = ".".join([str(v) for v in sys.version_info[0:2]])
+
+ platform_tokens = [
+ platform.machine(),
+ platform.system().lower(),
+ platform.python_implementation().lower(),
+ py_version,
+ dbapi_key,
+ ]
+
+ platform_tokens.append(
+ "nativeunicode"
+ if config.db.dialect.convert_unicode
+ else "dbapiunicode"
+ )
+ _has_cext = has_compiled_ext()
+ platform_tokens.append(_has_cext and "cextensions" or "nocextensions")
+ return "_".join(platform_tokens)
+
+ def has_stats(self):
+ test_key = _current_test
+ return (
+ test_key in self.data and self.platform_key in self.data[test_key]
+ )
+
+ def result(self, callcount):
+ test_key = _current_test
+ per_fn = self.data[test_key]
+ per_platform = per_fn[self.platform_key]
+
+ if "counts" not in per_platform:
+ per_platform["counts"] = counts = []
+ else:
+ counts = per_platform["counts"]
+
+ if "current_count" not in per_platform:
+ per_platform["current_count"] = current_count = 0
+ else:
+ current_count = per_platform["current_count"]
+
+ has_count = len(counts) > current_count
+
+ if not has_count:
+ counts.append(callcount)
+ if self.write:
+ self._write()
+ result = None
+ else:
+ result = per_platform["lineno"], counts[current_count]
+ per_platform["current_count"] += 1
+ return result
+
+ def reset_count(self):
+ test_key = _current_test
+ # since self.data is a defaultdict, don't access a key
+ # if we don't know it's there first.
+ if test_key not in self.data:
+ return
+ per_fn = self.data[test_key]
+ if self.platform_key not in per_fn:
+ return
+ per_platform = per_fn[self.platform_key]
+ if "counts" in per_platform:
+ per_platform["counts"][:] = []
+
+ def replace(self, callcount):
+ test_key = _current_test
+ per_fn = self.data[test_key]
+ per_platform = per_fn[self.platform_key]
+ counts = per_platform["counts"]
+ current_count = per_platform["current_count"]
+ if current_count < len(counts):
+ counts[current_count - 1] = callcount
+ else:
+ counts[-1] = callcount
+ if self.write:
+ self._write()
+
+ def _header(self):
+ return (
+ "# %s\n"
+ "# This file is written out on a per-environment basis.\n"
+ "# For each test in aaa_profiling, the corresponding "
+ "function and \n"
+ "# environment is located within this file. "
+ "If it doesn't exist,\n"
+ "# the test is skipped.\n"
+ "# If a callcount does exist, it is compared "
+ "to what we received. \n"
+ "# assertions are raised if the counts do not match.\n"
+ "# \n"
+ "# To add a new callcount test, apply the function_call_count \n"
+ "# decorator and re-run the tests using the --write-profiles \n"
+ "# option - this file will be rewritten including the new count.\n"
+ "# \n"
+ ) % (self.fname)
+
+ def _read(self):
+ try:
+ profile_f = open(self.fname)
+ except IOError:
+ return
+ for lineno, line in enumerate(profile_f):
+ line = line.strip()
+ if not line or line.startswith("#"):
+ continue
+
+ test_key, platform_key, counts = line.split()
+ per_fn = self.data[test_key]
+ per_platform = per_fn[platform_key]
+ c = [int(count) for count in counts.split(",")]
+ per_platform["counts"] = c
+ per_platform["lineno"] = lineno + 1
+ per_platform["current_count"] = 0
+ profile_f.close()
+
+ def _write(self):
+ print(("Writing profile file %s" % self.fname))
+ profile_f = open(self.fname, "w")
+ profile_f.write(self._header())
+ for test_key in sorted(self.data):
+
+ per_fn = self.data[test_key]
+ profile_f.write("\n# TEST: %s\n\n" % test_key)
+ for platform_key in sorted(per_fn):
+ per_platform = per_fn[platform_key]
+ c = ",".join(str(count) for count in per_platform["counts"])
+ profile_f.write("%s %s %s\n" % (test_key, platform_key, c))
+ profile_f.close()
+
+
+def function_call_count(variance=0.05, times=1, warmup=0):
+ """Assert a target for a test case's function call count.
+
+ The main purpose of this assertion is to detect changes in
+ callcounts for various functions - the actual number is not as important.
+ Callcounts are stored in a file keyed to Python version and OS platform
+ information. This file is generated automatically for new tests,
+ and versioned so that unexpected changes in callcounts will be detected.
+
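+ Illustrative usage within a profiling test (the decorated test body is
+ hypothetical)::
+
+     @profiling.function_call_count(variance=0.10, times=3, warmup=1)
+     def test_some_operation(self):
+         run_the_operation_under_test()
+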
+ """
+
+ # use signature-rewriting decorator function so that pytest fixtures
+ # still work on py27. In Py3, update_wrapper() alone is good enough,
+ # likely due to the introduction of __signature__.
+
+ from sqlalchemy.util import decorator
+ from sqlalchemy.util import deprecations
+ from sqlalchemy.engine import row
+ from sqlalchemy.testing import mock
+
+ @decorator
+ def wrap(fn, *args, **kw):
+
+ with mock.patch.object(
+ deprecations, "SQLALCHEMY_WARN_20", False
+ ), mock.patch.object(
+ row.LegacyRow, "_default_key_style", row.KEY_OBJECTS_NO_WARN
+ ):
+ for warm in range(warmup):
+ fn(*args, **kw)
+
+ timerange = range(times)
+ with count_functions(variance=variance):
+ for time in timerange:
+ rv = fn(*args, **kw)
+ return rv
+
+ return wrap
+
+
+@contextlib.contextmanager
+def count_functions(variance=0.05):
+ if cProfile is None:
+ raise config._skip_test_exception("cProfile is not installed")
+
+ if not _profile_stats.has_stats() and not _profile_stats.write:
+ config.skip_test(
+ "No profiling stats available on this "
+ "platform for this function. Run tests with "
+ "--write-profiles to add statistics to %s for "
+ "this platform." % _profile_stats.short_fname
+ )
+
+ gc_collect()
+
+ pr = cProfile.Profile()
+ pr.enable()
+ # began = time.time()
+ yield
+ # ended = time.time()
+ pr.disable()
+
+ # s = compat.StringIO()
+ stats = pstats.Stats(pr, stream=sys.stdout)
+
+ # timespent = ended - began
+ callcount = stats.total_calls
+
+ expected = _profile_stats.result(callcount)
+
+ if expected is None:
+ expected_count = None
+ else:
+ line_no, expected_count = expected
+
+ print(("Pstats calls: %d Expected %s" % (callcount, expected_count)))
+ stats.sort_stats(*re.split(r"[, ]", _profile_stats.sort))
+ stats.print_stats()
+ if _profile_stats.dump:
+ base, ext = os.path.splitext(_profile_stats.dump)
+ test_name = _current_test.split(".")[-1]
+ dumpfile = "%s_%s%s" % (base, test_name, ext or ".profile")
+ stats.dump_stats(dumpfile)
+ print("Dumped stats to file %s" % dumpfile)
+ # stats.print_callers()
+ if _profile_stats.force_write:
+ _profile_stats.replace(callcount)
+ elif expected_count:
+ deviance = int(callcount * variance)
+ failed = abs(callcount - expected_count) > deviance
+
+ if failed:
+ if _profile_stats.write:
+ _profile_stats.replace(callcount)
+ else:
+ raise AssertionError(
+ "Adjusted function call count %s not within %s%% "
+ "of expected %s, platform %s. Rerun with "
+ "--write-profiles to "
+ "regenerate this callcount."
+ % (
+ callcount,
+ (variance * 100),
+ expected_count,
+ _profile_stats.platform_key,
+ )
+ )
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
new file mode 100644
index 0000000..90c4d93
--- /dev/null
+++ b/lib/sqlalchemy/testing/provision.py
@@ -0,0 +1,416 @@
+import collections
+import logging
+
+from . import config
+from . import engines
+from . import util
+from .. import exc
+from .. import inspect
+from ..engine import url as sa_url
+from ..sql import ddl
+from ..sql import schema
+from ..util import compat
+
+
+log = logging.getLogger(__name__)
+
+FOLLOWER_IDENT = None
+
+
+class register(object):
+ def __init__(self):
+ self.fns = {}
+
+ @classmethod
+ def init(cls, fn):
+ return register().for_db("*")(fn)
+
+ def for_db(self, *dbnames):
+ def decorate(fn):
+ for dbname in dbnames:
+ self.fns[dbname] = fn
+ return self
+
+ return decorate
+
+ def __call__(self, cfg, *arg):
+ if isinstance(cfg, compat.string_types):
+ url = sa_url.make_url(cfg)
+ elif isinstance(cfg, sa_url.URL):
+ url = cfg
+ else:
+ url = cfg.db.url
+ backend = url.get_backend_name()
+ if backend in self.fns:
+ return self.fns[backend](cfg, *arg)
+ else:
+ return self.fns["*"](cfg, *arg)
+
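+# Illustrative registration pattern: a dialect's provision module imports one
+# of the hooks defined below (e.g. ``create_db``) and registers a
+# backend-specific implementation for it (the function name and body here are
+# only a sketch):
+#
+#     @create_db.for_db("postgresql")
+#     def _pg_create_db(cfg, eng, ident):
+#         ...  # issue the backend-specific CREATE DATABASE
+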
+
+def create_follower_db(follower_ident):
+ for cfg in _configs_for_db_operation():
+ log.info("CREATE database %s, URI %r", follower_ident, cfg.db.url)
+ create_db(cfg, cfg.db, follower_ident)
+
+
+def setup_config(db_url, options, file_config, follower_ident):
+ # load the dialect, which should in turn set up its provisioning
+ # hooks
+
+ dialect = sa_url.make_url(db_url).get_dialect()
+ dialect.load_provisioning()
+
+ if follower_ident:
+ db_url = follower_url_from_main(db_url, follower_ident)
+ db_opts = {}
+ update_db_opts(db_url, db_opts)
+ db_opts["scope"] = "global"
+ eng = engines.testing_engine(db_url, db_opts)
+ post_configure_engine(db_url, eng, follower_ident)
+ eng.connect().close()
+
+ cfg = config.Config.register(eng, db_opts, options, file_config)
+
+ # a symbolic name that tests can use if they need to disambiguate
+ # names across databases
+ if follower_ident:
+ config.ident = follower_ident
+
+ if follower_ident:
+ configure_follower(cfg, follower_ident)
+ return cfg
+
+
+def drop_follower_db(follower_ident):
+ for cfg in _configs_for_db_operation():
+ log.info("DROP database %s, URI %r", follower_ident, cfg.db.url)
+ drop_db(cfg, cfg.db, follower_ident)
+
+
+def generate_db_urls(db_urls, extra_drivers):
+ """Generate a set of URLs to test given configured URLs plus additional
+ driver names.
+
+ Given::
+
+ --dburi postgresql://db1 \
+ --dburi postgresql://db2 \
+ --dburi postgresql://db3 \
+ --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true
+
+ Noting that the default postgresql driver is psycopg2, the output
+ would be::
+
+ postgresql+psycopg2://db1
+ postgresql+asyncpg://db1
+ postgresql+psycopg2://db2
+ postgresql+psycopg2://db3
+
+ That is, for the driver in a --dburi, we want to keep that and use that
+ driver for each URL it's part of.  For a driver that is only
+ in --dbdrivers, we want to use it just once for one of the URLs.
+ For a driver that comes from both --dburi and --dbdrivers,
+ we want to keep it in that dburi.
+
+ Driver-specific query options can be specified by adding them to the
+ driver name. For example, to enable the async fallback option for
+ asyncpg::
+
+ --dburi postgresql://db1 \
+ --dbdriver=asyncpg?async_fallback=true
+
+ """
+ urls = set()
+
+ backend_to_driver_we_already_have = collections.defaultdict(set)
+
+ urls_plus_dialects = [
+ (url_obj, url_obj.get_dialect())
+ for url_obj in [sa_url.make_url(db_url) for db_url in db_urls]
+ ]
+
+ for url_obj, dialect in urls_plus_dialects:
+ backend_to_driver_we_already_have[dialect.name].add(dialect.driver)
+
+ backend_to_driver_we_need = {}
+
+ for url_obj, dialect in urls_plus_dialects:
+ backend = dialect.name
+ dialect.load_provisioning()
+
+ if backend not in backend_to_driver_we_need:
+ backend_to_driver_we_need[backend] = extra_per_backend = set(
+ extra_drivers
+ ).difference(backend_to_driver_we_already_have[backend])
+ else:
+ extra_per_backend = backend_to_driver_we_need[backend]
+
+ for driver_url in _generate_driver_urls(url_obj, extra_per_backend):
+ if driver_url in urls:
+ continue
+ urls.add(driver_url)
+ yield driver_url
+
+
+def _generate_driver_urls(url, extra_drivers):
+ main_driver = url.get_driver_name()
+ extra_drivers.discard(main_driver)
+
+ url = generate_driver_url(url, main_driver, "")
+ yield str(url)
+
+ for drv in list(extra_drivers):
+
+ if "?" in drv:
+
+ driver_only, query_str = drv.split("?", 1)
+
+ else:
+ driver_only = drv
+ query_str = None
+
+ new_url = generate_driver_url(url, driver_only, query_str)
+ if new_url:
+ extra_drivers.remove(drv)
+
+ yield str(new_url)
+
+
+@register.init
+def generate_driver_url(url, driver, query_str):
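+ # e.g. (illustrative) a url of "postgresql://db1" combined with driver
+ # "asyncpg" and query_str "async_fallback=true" yields
+ # "postgresql+asyncpg://db1?async_fallback=true"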
+ backend = url.get_backend_name()
+
+ new_url = url.set(
+ drivername="%s+%s" % (backend, driver),
+ )
+ if query_str:
+ new_url = new_url.update_query_string(query_str)
+
+ try:
+ new_url.get_dialect()
+ except exc.NoSuchModuleError:
+ return None
+ else:
+ return new_url
+
+
+def _configs_for_db_operation():
+ hosts = set()
+
+ for cfg in config.Config.all_configs():
+ cfg.db.dispose()
+
+ for cfg in config.Config.all_configs():
+ url = cfg.db.url
+ backend = url.get_backend_name()
+ host_conf = (backend, url.username, url.host, url.database)
+
+ if host_conf not in hosts:
+ yield cfg
+ hosts.add(host_conf)
+
+ for cfg in config.Config.all_configs():
+ cfg.db.dispose()
+
+
+@register.init
+def drop_all_schema_objects_pre_tables(cfg, eng):
+ pass
+
+
+@register.init
+def drop_all_schema_objects_post_tables(cfg, eng):
+ pass
+
+
+def drop_all_schema_objects(cfg, eng):
+
+ drop_all_schema_objects_pre_tables(cfg, eng)
+
+ inspector = inspect(eng)
+ try:
+ view_names = inspector.get_view_names()
+ except NotImplementedError:
+ pass
+ else:
+ with eng.begin() as conn:
+ for vname in view_names:
+ conn.execute(
+ ddl._DropView(schema.Table(vname, schema.MetaData()))
+ )
+
+ if config.requirements.schemas.enabled_for_config(cfg):
+ try:
+ view_names = inspector.get_view_names(schema="test_schema")
+ except NotImplementedError:
+ pass
+ else:
+ with eng.begin() as conn:
+ for vname in view_names:
+ conn.execute(
+ ddl._DropView(
+ schema.Table(
+ vname,
+ schema.MetaData(),
+ schema="test_schema",
+ )
+ )
+ )
+
+ util.drop_all_tables(eng, inspector)
+ if config.requirements.schemas.enabled_for_config(cfg):
+ util.drop_all_tables(eng, inspector, schema=cfg.test_schema)
+ util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)
+
+ drop_all_schema_objects_post_tables(cfg, eng)
+
+ if config.requirements.sequences.enabled_for_config(cfg):
+ with eng.begin() as conn:
+ for seq in inspector.get_sequence_names():
+ conn.execute(ddl.DropSequence(schema.Sequence(seq)))
+ if config.requirements.schemas.enabled_for_config(cfg):
+ for schema_name in [cfg.test_schema, cfg.test_schema_2]:
+ for seq in inspector.get_sequence_names(
+ schema=schema_name
+ ):
+ conn.execute(
+ ddl.DropSequence(
+ schema.Sequence(seq, schema=schema_name)
+ )
+ )
+
+
+@register.init
+def create_db(cfg, eng, ident):
+ """Dynamically create a database for testing.
+
+ Used when a test run will employ multiple processes, e.g., when run
+ via `tox` or `pytest -n4`.
+ """
+ raise NotImplementedError(
+ "no DB creation routine for cfg: %s" % (eng.url,)
+ )
+
+
+@register.init
+def drop_db(cfg, eng, ident):
+ """Drop a database that we dynamically created for testing."""
+ raise NotImplementedError("no DB drop routine for cfg: %s" % (eng.url,))
+
+
+@register.init
+def update_db_opts(db_url, db_opts):
+ """Set database options (db_opts) for a test database that we created."""
+ pass
+
+
+@register.init
+def post_configure_engine(url, engine, follower_ident):
+ """Perform extra steps after configuring an engine for testing.
+
+ (For the internal dialects, currently only used by sqlite, oracle)
+ """
+ pass
+
+
+@register.init
+def follower_url_from_main(url, ident):
+ """Create a connection URL for a dynamically-created test database.
+
+ :param url: the connection URL specified when the test run was invoked
+ :param ident: the pytest-xdist "worker identifier" to be used as the
+ database name
+ """
+ url = sa_url.make_url(url)
+ return url.set(database=ident)
+
+
+@register.init
+def configure_follower(cfg, ident):
+ """Create dialect-specific config settings for a follower database."""
+ pass
+
+
+@register.init
+def run_reap_dbs(url, ident):
+ """Remove databases that were created during the test process, after the
+ process has ended.
+
+ This is an optional step that is invoked for certain backends that do not
+ reliably release locks on the database while the creating process is
+ still running. For the internal dialects, this is currently only
+ necessary for mssql and oracle.
+ """
+ pass
+
+
+def reap_dbs(idents_file):
+ log.info("Reaping databases...")
+
+ urls = collections.defaultdict(set)
+ idents = collections.defaultdict(set)
+ dialects = {}
+
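+ # each line of the idents file is expected to be "<db_name> <db_url>",
+ # e.g. (values illustrative):
+ #   test_gw0 postgresql://scott:tiger@localhost/test_gw0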
+ with open(idents_file) as file_:
+ for line in file_:
+ line = line.strip()
+ db_name, db_url = line.split(" ")
+ url_obj = sa_url.make_url(db_url)
+ if db_name not in dialects:
+ dialects[db_name] = url_obj.get_dialect()
+ dialects[db_name].load_provisioning()
+ url_key = (url_obj.get_backend_name(), url_obj.host)
+ urls[url_key].add(db_url)
+ idents[url_key].add(db_name)
+
+ for url_key in urls:
+ url = list(urls[url_key])[0]
+ ident = idents[url_key]
+ run_reap_dbs(url, ident)
+
+
+@register.init
+def temp_table_keyword_args(cfg, eng):
+ """Specify keyword arguments for creating a temporary Table.
+
+ Dialect-specific implementations of this method will return the
+ kwargs that are passed to the Table method when creating a temporary
+ table for testing, e.g., in the define_temp_tables method of the
+ ComponentReflectionTest class in suite/test_reflection.py
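+
+ For illustration, an Oracle-style implementation might return something
+ along the lines of::
+
+     return {
+         "prefixes": ["GLOBAL TEMPORARY"],
+         "oracle_on_commit": "PRESERVE ROWS",
+     }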
+ """
+ raise NotImplementedError(
+ "no temp table keyword args routine for cfg: %s" % (eng.url,)
+ )
+
+
+@register.init
+def prepare_for_drop_tables(config, connection):
+ pass
+
+
+@register.init
+def stop_test_class_outside_fixtures(config, db, testcls):
+ pass
+
+
+@register.init
+def get_temp_table_name(cfg, eng, base_name):
+ """Specify table name for creating a temporary Table.
+
+ Dialect-specific implementations of this method will return the
+ name to use when creating a temporary table for testing,
+ e.g., in the define_temp_tables method of the
+ ComponentReflectionTest class in suite/test_reflection.py
+
+ Default to just the base name since that's what most dialects will
+ use. The mssql dialect's implementation will need a "#" prepended.
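+
+ An mssql-style override would, for illustration, simply do::
+
+     return "#" + base_name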
+ """
+ return base_name
+
+
+@register.init
+def set_default_schema_on_connection(cfg, dbapi_connection, schema_name):
+ raise NotImplementedError(
+ "backend does not implement a schema name set function: %s"
+ % (cfg.db.url,)
+ )
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
new file mode 100644
index 0000000..857d1fd
--- /dev/null
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -0,0 +1,1518 @@
+# testing/requirements.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Global database feature support policy.
+
+Provides decorators to mark tests requiring specific feature support from the
+target database.
+
+External dialect test suites should subclass SuiteRequirements
+to provide specific inclusion/exclusions.
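+
+A subclass typically overrides individual properties; for illustration::
+
+    class Requirements(SuiteRequirements):
+        @property
+        def window_functions(self):
+            # this backend is assumed to support window functions
+            return exclusions.open()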
+
+"""
+
+import platform
+import sys
+
+from . import exclusions
+from . import only_on
+from .. import util
+from ..pool import QueuePool
+
+
+class Requirements(object):
+ pass
+
+
+class SuiteRequirements(Requirements):
+ @property
+ def create_table(self):
+ """target platform can emit basic CreateTable DDL."""
+
+ return exclusions.open()
+
+ @property
+ def drop_table(self):
+ """target platform can emit basic DropTable DDL."""
+
+ return exclusions.open()
+
+ @property
+ def table_ddl_if_exists(self):
+ """target platform supports IF NOT EXISTS / IF EXISTS for tables."""
+
+ return exclusions.closed()
+
+ @property
+ def index_ddl_if_exists(self):
+ """target platform supports IF NOT EXISTS / IF EXISTS for indexes."""
+
+ return exclusions.closed()
+
+ @property
+ def foreign_keys(self):
+ """Target database must support foreign keys."""
+
+ return exclusions.open()
+
+ @property
+ def table_value_constructor(self):
+ """Database / dialect supports a query like::
+
+ SELECT * FROM VALUES ( (c1, c2), (c1, c2), ...)
+ AS some_table(col1, col2)
+
+ SQLAlchemy generates this with the :func:`_sql.values` function.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def standard_cursor_sql(self):
+ """Target database passes SQL-92 style statements to cursor.execute()
+ when a statement like select() or insert() is run.
+
+ A very small portion of dialect-level tests will ensure that certain
+ conditions are present in SQL strings, and these tests use very basic
+ SQL that will work on any SQL-like platform in order to assert results.
+
+ It's normally a given for any pep-249 DBAPI that a statement like
+ "SELECT id, name FROM table WHERE some_table.id=5" will work.
+ However, there are dialects that don't actually produce SQL Strings
+ and instead may work with symbolic objects instead, or dialects that
+ aren't working with SQL, so for those this requirement can be marked
+ as excluded.
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def on_update_cascade(self):
+ """target database must support ON UPDATE..CASCADE behavior in
+ foreign keys."""
+
+ return exclusions.open()
+
+ @property
+ def non_updating_cascade(self):
+ """target database must *not* support ON UPDATE..CASCADE behavior in
+ foreign keys."""
+ return exclusions.closed()
+
+ @property
+ def deferrable_fks(self):
+ return exclusions.closed()
+
+ @property
+ def on_update_or_deferrable_fks(self):
+ # TODO: exclusions should be composable,
+ # somehow only_if([x, y]) isn't working here, negation/conjunctions
+ # getting confused.
+ return exclusions.only_if(
+ lambda: self.on_update_cascade.enabled
+ or self.deferrable_fks.enabled
+ )
+
+ @property
+ def queue_pool(self):
+ """target database is using QueuePool"""
+
+ def go(config):
+ return isinstance(config.db.pool, QueuePool)
+
+ return exclusions.only_if(go)
+
+ @property
+ def self_referential_foreign_keys(self):
+ """Target database must support self-referential foreign keys."""
+
+ return exclusions.open()
+
+ @property
+ def foreign_key_ddl(self):
+ """Target database must support the DDL phrases for FOREIGN KEY."""
+
+ return exclusions.open()
+
+ @property
+ def named_constraints(self):
+ """target database must support names for constraints."""
+
+ return exclusions.open()
+
+ @property
+ def implicitly_named_constraints(self):
+ """target database must apply names to unnamed constraints."""
+
+ return exclusions.open()
+
+ @property
+ def subqueries(self):
+ """Target database must support subqueries."""
+
+ return exclusions.open()
+
+ @property
+ def offset(self):
+ """target database can render OFFSET, or an equivalent, in a
+ SELECT.
+ """
+
+ return exclusions.open()
+
+ @property
+ def bound_limit_offset(self):
+ """target database can render LIMIT and/or OFFSET using a bound
+ parameter
+ """
+
+ return exclusions.open()
+
+ @property
+ def sql_expression_limit_offset(self):
+ """target database can render LIMIT and/or OFFSET with a complete
+ SQL expression, such as one that uses the addition operator.
+ """
+
+ return exclusions.open()
+
+ @property
+ def parens_in_union_contained_select_w_limit_offset(self):
+ """Target database must support parenthesized SELECT in UNION
+ when LIMIT/OFFSET is specifically present.
+
+ E.g. (SELECT ... LIMIT ..) UNION (SELECT .. OFFSET ..)
+
+ This is known to fail on SQLite.
+
+ """
+ return exclusions.open()
+
+ @property
+ def parens_in_union_contained_select_wo_limit_offset(self):
+ """Target database must support parenthesized SELECT in UNION
+ when OFFSET/LIMIT is specifically not present.
+
+ E.g. (SELECT ...) UNION (SELECT ..)
+
+ This is known to fail on SQLite. It also fails on Oracle
+ because without LIMIT/OFFSET, there is currently no step that
+ creates an additional subquery.
+
+ """
+ return exclusions.open()
+
+ @property
+ def boolean_col_expressions(self):
+ """Target database must support boolean expressions as columns"""
+
+ return exclusions.closed()
+
+ @property
+ def nullable_booleans(self):
+ """Target database allows boolean columns to store NULL."""
+
+ return exclusions.open()
+
+ @property
+ def nullsordering(self):
+ """Target backends that support nulls ordering."""
+
+ return exclusions.closed()
+
+ @property
+ def standalone_binds(self):
+ """target database/driver supports bound parameters as column
+ expressions without being in the context of a typed column.
+ """
+ return exclusions.closed()
+
+ @property
+ def standalone_null_binds_whereclause(self):
+ """target database/driver supports bound parameters with NULL in the
+ WHERE clause, in situations where it has to be typed.
+
+ """
+ return exclusions.open()
+
+ @property
+ def intersect(self):
+ """Target database must support INTERSECT or equivalent."""
+ return exclusions.closed()
+
+ @property
+ def except_(self):
+ """Target database must support EXCEPT or equivalent (i.e. MINUS)."""
+ return exclusions.closed()
+
+ @property
+ def window_functions(self):
+ """Target database must support window functions."""
+ return exclusions.closed()
+
+ @property
+ def ctes(self):
+ """Target database supports CTEs"""
+
+ return exclusions.closed()
+
+ @property
+ def ctes_with_update_delete(self):
+ """target database supports CTES that ride on top of a normal UPDATE
+ or DELETE statement which refers to the CTE in a correlated subquery.
+
+ """
+
+ return exclusions.closed()
+
+ @property
+ def ctes_on_dml(self):
+ """target database supports CTES which consist of INSERT, UPDATE
+ or DELETE *within* the CTE, e.g. WITH x AS (UPDATE....)"""
+
+ return exclusions.closed()
+
+ @property
+ def autoincrement_insert(self):
+ """target platform generates new surrogate integer primary key values
+ when insert() is executed, excluding the pk column."""
+
+ return exclusions.open()
+
+ @property
+ def fetch_rows_post_commit(self):
+ """target platform will allow cursor.fetchone() to proceed after a
+ COMMIT.
+
+ Typically this refers to an INSERT statement with RETURNING which
+ is invoked within "autocommit". If the row can be returned
+ after the autocommit, then this rule can be open.
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def group_by_complex_expression(self):
+ """target platform supports SQL expressions in GROUP BY
+
+ e.g.
+
+ SELECT x + y AS somelabel FROM table GROUP BY x + y
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def sane_rowcount(self):
+ return exclusions.skip_if(
+ lambda config: not config.db.dialect.supports_sane_rowcount,
+ "driver doesn't support 'sane' rowcount",
+ )
+
+ @property
+ def sane_multi_rowcount(self):
+ return exclusions.fails_if(
+ lambda config: not config.db.dialect.supports_sane_multi_rowcount,
+ "driver %(driver)s %(doesnt_support)s 'sane' multi row count",
+ )
+
+ @property
+ def sane_rowcount_w_returning(self):
+ return exclusions.fails_if(
+ lambda config: not (
+ config.db.dialect.supports_sane_rowcount_returning
+ ),
+ "driver doesn't support 'sane' rowcount when returning is on",
+ )
+
+ @property
+ def empty_inserts(self):
+ """target platform supports INSERT with no values, i.e.
+ INSERT DEFAULT VALUES or equivalent."""
+
+ return exclusions.only_if(
+ lambda config: config.db.dialect.supports_empty_insert
+ or config.db.dialect.supports_default_values
+ or config.db.dialect.supports_default_metavalue,
+ "empty inserts not supported",
+ )
+
+ @property
+ def empty_inserts_executemany(self):
+ """target platform supports INSERT with no values, i.e.
+ INSERT DEFAULT VALUES or equivalent, within executemany()"""
+
+ return self.empty_inserts
+
+ @property
+ def insert_from_select(self):
+ """target platform supports INSERT from a SELECT."""
+
+ return exclusions.open()
+
+ @property
+ def full_returning(self):
+ """target platform supports RETURNING completely, including
+ multiple rows returned.
+
+ """
+
+ return exclusions.only_if(
+ lambda config: config.db.dialect.full_returning,
+ "%(database)s %(does_support)s 'RETURNING of multiple rows'",
+ )
+
+ @property
+ def insert_executemany_returning(self):
+ """target platform supports RETURNING when INSERT is used with
+ executemany(), e.g. multiple parameter sets, indicating that
+ as many rows come back as there were parameter sets passed.
+
+ """
+
+ return exclusions.only_if(
+ lambda config: config.db.dialect.insert_executemany_returning,
+ "%(database)s %(does_support)s 'RETURNING of "
+ "multiple rows with INSERT executemany'",
+ )
+
+ @property
+ def returning(self):
+ """target platform supports RETURNING for at least one row.
+
+ .. seealso::
+
+ :attr:`.Requirements.full_returning`
+
+ """
+
+ return exclusions.only_if(
+ lambda config: config.db.dialect.implicit_returning,
+ "%(database)s %(does_support)s 'RETURNING of a single row'",
+ )
+
+ @property
+ def tuple_in(self):
+ """Target platform supports the syntax
+ "(x, y) IN ((x1, y1), (x2, y2), ...)"
+ """
+
+ return exclusions.closed()
+
+ @property
+ def tuple_in_w_empty(self):
+ """Target platform tuple IN w/ empty set"""
+ return self.tuple_in
+
+ @property
+ def duplicate_names_in_cursor_description(self):
+ """target platform supports a SELECT statement that has
+ the same name repeated more than once in the columns list."""
+
+ return exclusions.open()
+
+ @property
+ def denormalized_names(self):
+ """Target database must have 'denormalized', i.e.
+ UPPERCASE as case insensitive names."""
+
+ return exclusions.skip_if(
+ lambda config: not config.db.dialect.requires_name_normalize,
+ "Backend does not require denormalized names.",
+ )
+
+ @property
+ def multivalues_inserts(self):
+ """target database must support multiple VALUES clauses in an
+ INSERT statement."""
+
+ return exclusions.skip_if(
+ lambda config: not config.db.dialect.supports_multivalues_insert,
+ "Backend does not support multirow inserts.",
+ )
+
+ @property
+ def implements_get_lastrowid(self):
+ """target dialect implements the executioncontext.get_lastrowid()
+ method without reliance on RETURNING.
+
+ """
+ return exclusions.open()
+
+ @property
+ def emulated_lastrowid(self):
+ """target dialect retrieves cursor.lastrowid, or fetches
+ from a database-side function after an insert() construct executes,
+ within the get_lastrowid() method.
+
+ Only dialects that "pre-execute", or need RETURNING to get last
+ inserted id, would return closed/fail/skip for this.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def emulated_lastrowid_even_with_sequences(self):
+ """target dialect retrieves cursor.lastrowid or an equivalent
+ after an insert() construct executes, even if the table has a
+ Sequence on it.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def dbapi_lastrowid(self):
+ """target platform includes a 'lastrowid' accessor on the DBAPI
+ cursor object.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def views(self):
+ """Target database must support VIEWs."""
+
+ return exclusions.closed()
+
+ @property
+ def schemas(self):
+ """Target database must support external schemas, and have one
+ named 'test_schema'."""
+
+ return only_on(lambda config: config.db.dialect.supports_schemas)
+
+ @property
+ def cross_schema_fk_reflection(self):
+ """target system must support reflection of inter-schema
+ foreign keys"""
+ return exclusions.closed()
+
+ @property
+ def foreign_key_constraint_name_reflection(self):
+ """Target supports refleciton of FOREIGN KEY constraints and
+ will return the name of the constraint that was used in the
+ "CONSTRAINT <name> FOREIGN KEY" DDL.
+
+ MySQL prior to version 8 and MariaDB prior to version 10.5
+ don't support this.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def implicit_default_schema(self):
+ """target system has a strong concept of 'default' schema that can
+ be referred to implicitly.
+
+ basically, PostgreSQL.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def default_schema_name_switch(self):
+ """target dialect implements provisioning module including
+ set_default_schema_on_connection"""
+
+ return exclusions.closed()
+
+ @property
+ def server_side_cursors(self):
+ """Target dialect must support server side cursors."""
+
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.supports_server_side_cursors],
+ "no server side cursors support",
+ )
+
+ @property
+ def sequences(self):
+ """Target database must support SEQUENCEs."""
+
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.supports_sequences],
+ "no sequence support",
+ )
+
+ @property
+ def no_sequences(self):
+ """the opposite of "sequences", DB does not support sequences at
+ all."""
+
+ return exclusions.NotPredicate(self.sequences)
+
+ @property
+ def sequences_optional(self):
+ """Target database supports sequences, but also optionally
+ as a means of generating new PK values."""
+
+ return exclusions.only_if(
+ [
+ lambda config: config.db.dialect.supports_sequences
+ and config.db.dialect.sequences_optional
+ ],
+ "no sequence support, or sequences not optional",
+ )
+
+ @property
+ def supports_lastrowid(self):
+ """target database / driver supports cursor.lastrowid as a means
+ of retrieving the last inserted primary key value.
+
+ note that if the target DB supports sequences also, this is still
+ assumed to work. This is a new use case brought on by MariaDB 10.3.
+
+ """
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.postfetch_lastrowid]
+ )
+
+ @property
+ def no_lastrowid_support(self):
+ """the opposite of supports_lastrowid"""
+ return exclusions.only_if(
+ [lambda config: not config.db.dialect.postfetch_lastrowid]
+ )
+
+ @property
+ def reflects_pk_names(self):
+ return exclusions.closed()
+
+ @property
+ def table_reflection(self):
+ """target database has general support for table reflection"""
+ return exclusions.open()
+
+ @property
+ def reflect_tables_no_columns(self):
+ """target database supports creation and reflection of tables with no
+ columns, or at least tables that seem to have no columns."""
+
+ return exclusions.closed()
+
+ @property
+ def comment_reflection(self):
+ return exclusions.closed()
+
+ @property
+ def view_column_reflection(self):
+ """target database must support retrieval of the columns in a view,
+ similarly to how a table is inspected.
+
+ This does not include the full CREATE VIEW definition.
+
+ """
+ return self.views
+
+ @property
+ def view_reflection(self):
+ """target database must support inspection of the full CREATE VIEW
+ definition."""
+ return self.views
+
+ @property
+ def schema_reflection(self):
+ return self.schemas
+
+ @property
+ def primary_key_constraint_reflection(self):
+ return exclusions.open()
+
+ @property
+ def foreign_key_constraint_reflection(self):
+ return exclusions.open()
+
+ @property
+ def foreign_key_constraint_option_reflection_ondelete(self):
+ return exclusions.closed()
+
+ @property
+ def fk_constraint_option_reflection_ondelete_restrict(self):
+ return exclusions.closed()
+
+ @property
+ def fk_constraint_option_reflection_ondelete_noaction(self):
+ return exclusions.closed()
+
+ @property
+ def foreign_key_constraint_option_reflection_onupdate(self):
+ return exclusions.closed()
+
+ @property
+ def fk_constraint_option_reflection_onupdate_restrict(self):
+ return exclusions.closed()
+
+ @property
+ def temp_table_reflection(self):
+ return exclusions.open()
+
+ @property
+ def temp_table_reflect_indexes(self):
+ return self.temp_table_reflection
+
+ @property
+ def temp_table_names(self):
+ """target dialect supports listing of temporary table names"""
+ return exclusions.closed()
+
+ @property
+ def temporary_tables(self):
+ """target database supports temporary tables"""
+ return exclusions.open()
+
+ @property
+ def temporary_views(self):
+ """target database supports temporary views"""
+ return exclusions.closed()
+
+ @property
+ def index_reflection(self):
+ return exclusions.open()
+
+ @property
+ def index_reflects_included_columns(self):
+ return exclusions.closed()
+
+ @property
+ def indexes_with_ascdesc(self):
+ """target database supports CREATE INDEX with per-column ASC/DESC."""
+ return exclusions.open()
+
+ @property
+ def indexes_with_expressions(self):
+ """target database supports CREATE INDEX against SQL expressions."""
+ return exclusions.closed()
+
+ @property
+ def unique_constraint_reflection(self):
+ """target dialect supports reflection of unique constraints"""
+ return exclusions.open()
+
+ @property
+ def check_constraint_reflection(self):
+ """target dialect supports reflection of check constraints"""
+ return exclusions.closed()
+
+ @property
+ def duplicate_key_raises_integrity_error(self):
+ """target dialect raises IntegrityError when reporting an INSERT
+ with a primary key violation. (hint: it should)
+
+ """
+ return exclusions.open()
+
+ @property
+ def unbounded_varchar(self):
+ """Target database must support VARCHAR with no length"""
+
+ return exclusions.open()
+
+ @property
+ def unicode_data(self):
+ """Target database/dialect must support Python unicode objects with
+ non-ASCII characters represented, delivered as bound parameters
+ as well as in result rows.
+
+ """
+ return exclusions.open()
+
+ @property
+ def unicode_ddl(self):
+ """Target driver must support some degree of non-ascii symbol
+ names.
+ """
+ return exclusions.closed()
+
+ @property
+ def symbol_names_w_double_quote(self):
+ """Target driver can create tables with a name like 'some " table'"""
+ return exclusions.open()
+
+ @property
+ def datetime_literals(self):
+ """target dialect supports rendering of a date, time, or datetime as a
+ literal string, e.g. via the TypeEngine.literal_processor() method.
+
+ """
+
+ return exclusions.closed()
+
+ @property
+ def datetime(self):
+ """target dialect supports representation of Python
+ datetime.datetime() objects."""
+
+ return exclusions.open()
+
+ @property
+ def datetime_timezone(self):
+ """target dialect supports representation of Python
+ datetime.datetime() with tzinfo with DateTime(timezone=True)."""
+
+ return exclusions.closed()
+
+ @property
+ def time_timezone(self):
+ """target dialect supports representation of Python
+ datetime.time() with tzinfo with Time(timezone=True)."""
+
+ return exclusions.closed()
+
+ @property
+ def datetime_implicit_bound(self):
+ """target dialect when given a datetime object will bind it such
+ that the database server knows the object is a datetime, and not
+ a plain string.
+
+ """
+ return exclusions.open()
+
+ @property
+ def datetime_microseconds(self):
+ """target dialect supports representation of Python
+ datetime.datetime() with microsecond objects."""
+
+ return exclusions.open()
+
+ @property
+ def timestamp_microseconds(self):
+ """target dialect supports representation of Python
+ datetime.datetime() with microsecond objects but only
+ if TIMESTAMP is used."""
+ return exclusions.closed()
+
+ @property
+ def timestamp_microseconds_implicit_bound(self):
+ """target dialect when given a datetime object which also includes
+ a microseconds portion when using the TIMESTAMP data type
+ will bind it such that the database server knows
+ the object is a datetime with microseconds, and not a plain string.
+
+ """
+ return self.timestamp_microseconds
+
+ @property
+ def datetime_historic(self):
+ """target dialect supports representation of Python
+ datetime.datetime() objects with historic (pre 1970) values."""
+
+ return exclusions.closed()
+
+ @property
+ def date(self):
+ """target dialect supports representation of Python
+ datetime.date() objects."""
+
+ return exclusions.open()
+
+ @property
+ def date_coerces_from_datetime(self):
+ """target dialect accepts a datetime object as the target
+ of a date column."""
+
+ return exclusions.open()
+
+ @property
+ def date_historic(self):
+ """target dialect supports representation of Python
+ datetime.date() objects with historic (pre 1970) values."""
+
+ return exclusions.closed()
+
+ @property
+ def time(self):
+ """target dialect supports representation of Python
+ datetime.time() objects."""
+
+ return exclusions.open()
+
+ @property
+ def time_microseconds(self):
+ """target dialect supports representation of Python
+ datetime.time() with microsecond objects."""
+
+ return exclusions.open()
+
+ @property
+ def binary_comparisons(self):
+ """target database/driver can allow BLOB/BINARY fields to be compared
+ against a bound parameter value.
+ """
+
+ return exclusions.open()
+
+ @property
+ def binary_literals(self):
+ """target backend supports simple binary literals, e.g. an
+ expression like::
+
+ SELECT CAST('foo' AS BINARY)
+
+ Where ``BINARY`` is the type emitted from :class:`.LargeBinary`,
+ e.g. it could be ``BLOB`` or similar.
+
+ Basically fails on Oracle.
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def autocommit(self):
+ """target dialect supports 'AUTOCOMMIT' as an isolation_level"""
+ return exclusions.closed()
+
+ @property
+ def isolation_level(self):
+ """target dialect supports general isolation level settings.
+
+ Note that this requirement, when enabled, also requires that
+ the get_isolation_levels() method be implemented.
+
+ """
+ return exclusions.closed()
+
+ def get_isolation_levels(self, config):
+ """Return a structure of supported isolation levels for the current
+ testing dialect.
+
+ The structure indicates to the testing suite what the expected
+ "default" isolation should be, as well as the other values that
+ are accepted. The dictionary has two keys, "default" and "supported".
+ The "supported" key refers to a list of all supported levels and
+ it should include AUTOCOMMIT if the dialect supports it.
+
+ If the :meth:`.DefaultRequirements.isolation_level` requirement is
+ not open, then this method has no return value.
+
+ E.g.::
+
+ >>> testing.requirements.get_isolation_levels()
+ {
+ "default": "READ_COMMITTED",
+ "supported": [
+ "SERIALIZABLE", "READ UNCOMMITTED",
+ "READ COMMITTED", "REPEATABLE READ",
+ "AUTOCOMMIT"
+ ]
+ }
+ """
+
+ @property
+ def json_type(self):
+ """target platform implements a native JSON type."""
+
+ return exclusions.closed()
+
+ @property
+ def json_array_indexes(self):
+ """target platform supports numeric array indexes
+ within a JSON structure"""
+
+ return self.json_type
+
+ @property
+ def json_index_supplementary_unicode_element(self):
+ return exclusions.open()
+
+ @property
+ def legacy_unconditional_json_extract(self):
+ """Backend has a JSON_EXTRACT or similar function that returns a
+ valid JSON string in all cases.
+
+ Used to test a legacy feature and is not needed.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def precision_numerics_general(self):
+ """target backend has general support for moderately high-precision
+ numerics."""
+ return exclusions.open()
+
+ @property
+ def precision_numerics_enotation_small(self):
+ """target backend supports Decimal() objects using E notation
+ to represent very small values."""
+ return exclusions.closed()
+
+ @property
+ def precision_numerics_enotation_large(self):
+ """target backend supports Decimal() objects using E notation
+ to represent very large values."""
+ return exclusions.closed()
+
+ @property
+ def precision_numerics_many_significant_digits(self):
+ """target backend supports values with many digits on both sides,
+ such as 319438950232418390.273596, 87673.594069654243
+
+ """
+ return exclusions.closed()
+
+ @property
+ def cast_precision_numerics_many_significant_digits(self):
+ """same as precision_numerics_many_significant_digits but within the
+ context of a CAST statement (hello MySQL)
+
+ """
+ return self.precision_numerics_many_significant_digits
+
+ @property
+ def implicit_decimal_binds(self):
+ """target backend will return a selected Decimal as a Decimal, not
+ a string.
+
+ e.g.::
+
+ expr = decimal.Decimal("15.7563")
+
+ value = e.scalar(
+ select(literal(expr))
+ )
+
+ assert value == expr
+
+ See :ticket:`4036`
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def nested_aggregates(self):
+ """target database can select an aggregate from a subquery that's
+ also using an aggregate
+
+ """
+ return exclusions.open()
+
+ @property
+ def recursive_fk_cascade(self):
+ """target database must support ON DELETE CASCADE on a self-referential
+ foreign key
+
+ """
+ return exclusions.open()
+
+ @property
+ def precision_numerics_retains_significant_digits(self):
+ """A precision numeric type will return empty significant digits,
+ i.e. a value such as 10.000 will come back in Decimal form with
+ the .000 maintained."""
+
+ return exclusions.closed()
+
+ @property
+ def infinity_floats(self):
+ """The Float type can persist and load float('inf'), float('-inf')."""
+
+ return exclusions.closed()
+
+ @property
+ def precision_generic_float_type(self):
+ """target backend will return native floating point numbers with at
+ least seven decimal places when using the generic Float type.
+
+ """
+ return exclusions.open()
+
+ @property
+ def floats_to_four_decimals(self):
+ """target backend can return a floating-point number with four
+ significant digits (such as 15.7563) accurately
+ (i.e. without FP inaccuracies, such as 15.75629997253418).
+
+ """
+ return exclusions.open()
+
+ @property
+ def fetch_null_from_numeric(self):
+ """target backend doesn't crash when you try to select a NUMERIC
+ value that has a value of NULL.
+
+ Added to support Pyodbc bug #351.
+ """
+
+ return exclusions.open()
+
+ @property
+ def text_type(self):
+ """Target database must support an unbounded Text() "
+ "type such as TEXT or CLOB"""
+
+ return exclusions.open()
+
+ @property
+ def empty_strings_varchar(self):
+ """target database can persist/return an empty string with a
+ varchar.
+
+ """
+ return exclusions.open()
+
+ @property
+ def empty_strings_text(self):
+ """target database can persist/return an empty string with an
+ unbounded text."""
+
+ return exclusions.open()
+
+ @property
+ def expressions_against_unbounded_text(self):
+ """target database supports use of an unbounded textual field in a
+ WHERE clause."""
+
+ return exclusions.open()
+
+ @property
+ def selectone(self):
+ """target driver must support the literal statement 'select 1'"""
+ return exclusions.open()
+
+ @property
+ def savepoints(self):
+ """Target database must support savepoints."""
+
+ return exclusions.closed()
+
+ @property
+ def two_phase_transactions(self):
+ """Target database must support two-phase transactions."""
+
+ return exclusions.closed()
+
+ @property
+ def update_from(self):
+ """Target must support UPDATE..FROM syntax"""
+ return exclusions.closed()
+
+ @property
+ def delete_from(self):
+ """Target must support DELETE FROM..FROM or DELETE..USING syntax"""
+ return exclusions.closed()
+
+ @property
+ def update_where_target_in_subquery(self):
+ """Target must support UPDATE (or DELETE) where the same table is
+ present in a subquery in the WHERE clause.
+
+ This is an ANSI-standard syntax that apparently MySQL can't handle,
+ such as::
+
+ UPDATE documents SET flag=1 WHERE documents.title IN
+ (SELECT max(documents.title) AS title
+ FROM documents GROUP BY documents.user_id
+ )
+
+ """
+ return exclusions.open()
+
+ @property
+ def mod_operator_as_percent_sign(self):
+ """target database must use a plain percent '%' as the 'modulus'
+ operator."""
+ return exclusions.closed()
+
+ @property
+ def percent_schema_names(self):
+ """target backend supports weird identifiers with percent signs
+ in them, e.g. 'some % column'.
+
+ this is a very weird use case but often has problems because of
+ DBAPIs that use python formatting. It's not a critical use
+ case either.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def order_by_col_from_union(self):
+ """target database supports ordering by a column from a SELECT
+ inside of a UNION
+
+ E.g. (SELECT id, ...) UNION (SELECT id, ...) ORDER BY id
+
+ """
+ return exclusions.open()
+
+ @property
+ def order_by_label_with_expression(self):
+ """target backend supports ORDER BY a column label within an
+ expression.
+
+ Basically this::
+
+ select data as foo from test order by foo || 'bar'
+
+ Lots of databases including PostgreSQL don't support this,
+ so this is off by default.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def order_by_collation(self):
+ def check(config):
+ try:
+ self.get_order_by_collation(config)
+ return False
+ except NotImplementedError:
+ return True
+
+ return exclusions.skip_if(check)
+
+ def get_order_by_collation(self, config):
+ raise NotImplementedError()
+
+ @property
+ def unicode_connections(self):
+ """Target driver must support non-ASCII characters being passed at
+ all.
+ """
+ return exclusions.open()
+
+ @property
+ def graceful_disconnects(self):
+ """Target driver must raise a DBAPI-level exception, such as
+ InterfaceError, when the underlying connection has been closed
+ and the execute() method is called.
+ """
+ return exclusions.open()
+
+ @property
+ def independent_connections(self):
+ """
+ Target must support simultaneous, independent database connections.
+ """
+ return exclusions.open()
+
+ @property
+ def skip_mysql_on_windows(self):
+ """Catchall for a large variety of MySQL on Windows failures"""
+ return exclusions.open()
+
+ @property
+ def ad_hoc_engines(self):
+ """Test environment must allow ad-hoc engine/connection creation.
+
+ DBs that scale poorly for many connections, even when closed, i.e.
+ Oracle, may use the "--low-connections" option which flags this
+ requirement as not present.
+
+ """
+ return exclusions.skip_if(
+ lambda config: config.options.low_connections
+ )
+
+ @property
+ def no_windows(self):
+ return exclusions.skip_if(self._running_on_windows())
+
+ def _running_on_windows(self):
+ return exclusions.LambdaPredicate(
+ lambda: platform.system() == "Windows",
+ description="running on Windows",
+ )
+
+ @property
+ def timing_intensive(self):
+ return exclusions.requires_tag("timing_intensive")
+
+ @property
+ def memory_intensive(self):
+ return exclusions.requires_tag("memory_intensive")
+
+ @property
+ def threading_with_mock(self):
+ """Mark tests that use threading and mock at the same time - stability
+ issues have been observed with coverage + python 3.3
+
+ """
+ return exclusions.skip_if(
+ lambda config: util.py3k and config.options.has_coverage,
+ "Stability issues with coverage + py3k",
+ )
+
+ @property
+ def sqlalchemy2_stubs(self):
+ def check(config):
+ try:
+ __import__("sqlalchemy-stubs.ext.mypy")
+ except ImportError:
+ return False
+ else:
+ return True
+
+ return exclusions.only_if(check)
+
+ @property
+ def python2(self):
+ return exclusions.skip_if(
+ lambda: sys.version_info >= (3,),
+ "Python version 2.xx is required.",
+ )
+
+ @property
+ def python3(self):
+ return exclusions.skip_if(
+ lambda: sys.version_info < (3,), "Python version 3.xx is required."
+ )
+
+ @property
+ def pep520(self):
+ return self.python36
+
+ @property
+ def insert_order_dicts(self):
+ return self.python37
+
+ @property
+ def python36(self):
+ return exclusions.skip_if(
+ lambda: sys.version_info < (3, 6),
+ "Python version 3.6 or greater is required.",
+ )
+
+ @property
+ def python37(self):
+ return exclusions.skip_if(
+ lambda: sys.version_info < (3, 7),
+ "Python version 3.7 or greater is required.",
+ )
+
+ @property
+ def dataclasses(self):
+ return self.python37
+
+ @property
+ def python38(self):
+ return exclusions.only_if(
+ lambda: util.py38, "Python 3.8 or above required"
+ )
+
+ @property
+ def cpython(self):
+ return exclusions.only_if(
+ lambda: util.cpython, "cPython interpreter needed"
+ )
+
+ @property
+ def patch_library(self):
+ def check_lib():
+ try:
+ __import__("patch")
+ except ImportError:
+ return False
+ else:
+ return True
+
+ return exclusions.only_if(check_lib, "patch library needed")
+
+ @property
+ def non_broken_pickle(self):
+ from sqlalchemy.util import pickle
+
+ return exclusions.only_if(
+ lambda: util.cpython
+ and pickle.__name__ == "cPickle"
+ or sys.version_info >= (3, 2),
+ "Needs cPickle+cPython or newer Python 3 pickle",
+ )
+
+ @property
+ def predictable_gc(self):
+ """target platform must remove all cycles unconditionally when
+ gc.collect() is called, as well as clean out unreferenced subclasses.
+
+ """
+ return self.cpython
+
+ @property
+ def no_coverage(self):
+ """Test should be skipped if coverage is enabled.
+
+ This is to block tests that exercise libraries that seem to be
+ sensitive to coverage, such as PostgreSQL notice logging.
+
+ """
+ return exclusions.skip_if(
+ lambda config: config.options.has_coverage,
+ "Issues observed when coverage is enabled",
+ )
+
+ def _has_mysql_on_windows(self, config):
+ return False
+
+ def _has_mysql_fully_case_sensitive(self, config):
+ return False
+
+ @property
+ def sqlite(self):
+ return exclusions.skip_if(lambda: not self._has_sqlite())
+
+ @property
+ def cextensions(self):
+ return exclusions.skip_if(
+ lambda: not util.has_compiled_ext(), "C extensions not installed"
+ )
+
+ def _has_sqlite(self):
+ from sqlalchemy import create_engine
+
+ try:
+ create_engine("sqlite://")
+ return True
+ except ImportError:
+ return False
+
+ @property
+ def async_dialect(self):
+ """dialect makes use of await_() to invoke operations on the DBAPI."""
+
+ return exclusions.closed()
+
+ @property
+ def asyncio(self):
+ return self.greenlet
+
+ @property
+ def greenlet(self):
+ def go(config):
+ try:
+ import greenlet # noqa: F401
+ except ImportError:
+ return False
+ else:
+ return True
+
+ return exclusions.only_if(go)
+
+ @property
+ def computed_columns(self):
+ "Supports computed columns"
+ return exclusions.closed()
+
+ @property
+ def computed_columns_stored(self):
+ "Supports computed columns with `persisted=True`"
+ return exclusions.closed()
+
+ @property
+ def computed_columns_virtual(self):
+ "Supports computed columns with `persisted=False`"
+ return exclusions.closed()
+
+ @property
+ def computed_columns_default_persisted(self):
+ """If the default persistence is virtual or stored when `persisted`
+ is omitted"""
+ return exclusions.closed()
+
+ @property
+ def computed_columns_reflect_persisted(self):
+ """If persistence information is returned by the reflection of
+ computed columns"""
+ return exclusions.closed()
+
+ @property
+ def supports_distinct_on(self):
+ """If a backend supports the DISTINCT ON in a select"""
+ return exclusions.closed()
+
+ @property
+ def supports_is_distinct_from(self):
+ """Supports some form of "x IS [NOT] DISTINCT FROM y" construct.
+ Different dialects will implement their own flavour, e.g.,
+ sqlite will emit "x IS NOT y" instead of "x IS DISTINCT FROM y".
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.is_distinct_from`
+
+ """
+ return exclusions.skip_if(
+ lambda config: not config.db.dialect.supports_is_distinct_from,
+ "driver doesn't support an IS DISTINCT FROM construct",
+ )
+
+ @property
+ def identity_columns(self):
+ """If a backend supports GENERATED { ALWAYS | BY DEFAULT }
+ AS IDENTITY"""
+ return exclusions.closed()
+
+ @property
+ def identity_columns_standard(self):
+ """If a backend supports GENERATED { ALWAYS | BY DEFAULT }
+ AS IDENTITY with a standard syntax.
+ This is mainly to exclude MSSql.
+ """
+ return exclusions.closed()
+
+ @property
+ def regexp_match(self):
+ """backend supports the regexp_match operator."""
+ return exclusions.closed()
+
+ @property
+ def regexp_replace(self):
+ """backend supports the regexp_replace operator."""
+ return exclusions.closed()
+
+ @property
+ def fetch_first(self):
+ """backend supports the fetch first clause."""
+ return exclusions.closed()
+
+ @property
+ def fetch_percent(self):
+ """backend supports the fetch first clause with percent."""
+ return exclusions.closed()
+
+ @property
+ def fetch_ties(self):
+ """backend supports the fetch first clause with ties."""
+ return exclusions.closed()
+
+ @property
+ def fetch_no_order_by(self):
+ """backend supports the fetch first without order by"""
+ return exclusions.closed()
+
+ @property
+ def fetch_offset_with_options(self):
+ """backend supports the offset when using fetch first with percent
+ or ties. basically this is "not mssql"
+ """
+ return exclusions.closed()
+
+ @property
+ def fetch_expression(self):
+ """backend supports fetch / offset with expression in them, like
+
+ SELECT * FROM some_table
+ OFFSET 1 + 1 ROWS FETCH FIRST 1 + 1 ROWS ONLY
+ """
+ return exclusions.closed()
+
+ @property
+ def autoincrement_without_sequence(self):
+ """If autoincrement=True on a column does not require an explicit
+ sequence. This should be false only for oracle.
+ """
+ return exclusions.open()
+
+ @property
+ def generic_classes(self):
+ "If X[Y] can be implemented with ``__class_getitem__``. py3.7+"
+ return exclusions.only_if(lambda: util.py37)
diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py
new file mode 100644
index 0000000..bff07a5
--- /dev/null
+++ b/lib/sqlalchemy/testing/schema.py
@@ -0,0 +1,218 @@
+# testing/schema.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import sys
+
+from . import config
+from . import exclusions
+from .. import event
+from .. import schema
+from .. import types as sqltypes
+from ..util import OrderedDict
+
+
+__all__ = ["Table", "Column"]
+
+table_options = {}
+
+
+def Table(*args, **kw):
+ """A schema.Table wrapper/hook for dialect-specific tweaks."""
+
+ test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
+
+ kw.update(table_options)
+
+ if exclusions.against(config._current, "mysql"):
+ if (
+ "mysql_engine" not in kw
+ and "mysql_type" not in kw
+ and "autoload_with" not in kw
+ ):
+ if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
+ kw["mysql_engine"] = "InnoDB"
+ else:
+ kw["mysql_engine"] = "MyISAM"
+ elif exclusions.against(config._current, "mariadb"):
+ if (
+ "mariadb_engine" not in kw
+ and "mariadb_type" not in kw
+ and "autoload_with" not in kw
+ ):
+ if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
+ kw["mariadb_engine"] = "InnoDB"
+ else:
+ kw["mariadb_engine"] = "MyISAM"
+
+ # Apply some default cascading rules for self-referential foreign keys.
+ # MySQL InnoDB has some issues around selecting self-refs too.
+ if exclusions.against(config._current, "firebird"):
+ table_name = args[0]
+ unpack = config.db.dialect.identifier_preparer.unformat_identifiers
+
+ # Only going after ForeignKeys in Columns. May need to
+ # expand to ForeignKeyConstraint too.
+ fks = [
+ fk
+ for col in args
+ if isinstance(col, schema.Column)
+ for fk in col.foreign_keys
+ ]
+
+ for fk in fks:
+ # root around in raw spec
+ ref = fk._colspec
+ if isinstance(ref, schema.Column):
+ name = ref.table.name
+ else:
+ # take just the table name: on FB there cannot be
+ # a schema, so the first element is always the
+ # table name, possibly followed by the field name
+ name = unpack(ref)[0]
+ if name == table_name:
+ if fk.ondelete is None:
+ fk.ondelete = "CASCADE"
+ if fk.onupdate is None:
+ fk.onupdate = "CASCADE"
+
+ return schema.Table(*args, **kw)
+
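+# Illustrative usage of the wrapper (assumption, not part of this module):
+# a suite fixture that needs real foreign key support simply passes the
+# ``test_*`` flag, which this hook translates into the dialect-specific
+# keyword on MySQL / MariaDB:
+#
+#     Table(
+#         "some_table",
+#         metadata,
+#         Column("id", Integer, primary_key=True),
+#         test_needs_fk=True,  # becomes mysql_engine="InnoDB" on MySQL
+#     )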
+
+def Column(*args, **kw):
+ """A schema.Column wrapper/hook for dialect-specific tweaks."""
+
+ test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
+
+ if not config.requirements.foreign_key_ddl.enabled_for_config(config):
+ args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
+
+ col = schema.Column(*args, **kw)
+ if test_opts.get("test_needs_autoincrement", False) and kw.get(
+ "primary_key", False
+ ):
+
+ if col.default is None and col.server_default is None:
+ col.autoincrement = True
+
+ # allow any test suite to pick up on this
+ col.info["test_needs_autoincrement"] = True
+
+ # hardcoded rule for firebird, oracle; this should
+ # be moved out
+ if exclusions.against(config._current, "firebird", "oracle"):
+
+ def add_seq(c, tbl):
+ c._init_items(
+ schema.Sequence(
+ _truncate_name(
+ config.db.dialect, tbl.name + "_" + c.name + "_seq"
+ ),
+ optional=True,
+ )
+ )
+
+ event.listen(col, "after_parent_attach", add_seq, propagate=True)
+ return col
+
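+# Illustrative usage (assumption, not part of this module): a primary key
+# column that must autoincrement on every backend, including those that
+# need an explicit sequence, is declared as
+#
+#     Column("id", Integer, primary_key=True, test_needs_autoincrement=True)
+#
+# which sets ``col.info["test_needs_autoincrement"]`` and, on firebird and
+# oracle, attaches an optional Sequence named "<table>_<column>_seq".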
+
+class eq_type_affinity(object):
+ """Helper to compare types inside of datastructures based on affinity.
+
+ E.g.::
+
+ eq_(
+ inspect(connection).get_columns("foo"),
+ [
+ {
+ "name": "id",
+ "type": testing.eq_type_affinity(sqltypes.INTEGER),
+ "nullable": False,
+ "default": None,
+ "autoincrement": False,
+ },
+ {
+ "name": "data",
+ "type": testing.eq_type_affinity(sqltypes.NullType),
+ "nullable": True,
+ "default": None,
+ "autoincrement": False,
+ },
+ ],
+ )
+
+ """
+
+ def __init__(self, target):
+ self.target = sqltypes.to_instance(target)
+
+ def __eq__(self, other):
+ return self.target._type_affinity is other._type_affinity
+
+ def __ne__(self, other):
+ return self.target._type_affinity is not other._type_affinity
+
+
+class eq_clause_element(object):
+ """Helper to compare SQL structures based on compare()"""
+
+ def __init__(self, target):
+ self.target = target
+
+ def __eq__(self, other):
+ return self.target.compare(other)
+
+ def __ne__(self, other):
+ return not self.target.compare(other)
+
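+# Illustrative use of eq_clause_element (assumption, not part of this
+# module): comparing a captured statement against an expected construct
+# structurally rather than by string:
+#
+#     eq_(captured_stmt, eq_clause_element(select(t).where(t.c.id == 5)))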
+
+def _truncate_name(dialect, name):
+ if len(name) > dialect.max_identifier_length:
+ return (
+ name[0 : max(dialect.max_identifier_length - 6, 0)]
+ + "_"
+ + hex(hash(name) % 64)[2:]
+ )
+ else:
+ return name
+
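+# Worked example for _truncate_name (illustrative): with
+# ``dialect.max_identifier_length == 30``, a 40-character name is cut to
+# its first 24 characters, then "_" plus the one or two hex digits of
+# ``hash(name) % 64`` are appended, keeping the result under the limit.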
+
+def pep435_enum(name):
+ # Implements PEP 435 in the minimal fashion needed by SQLAlchemy
+ __members__ = OrderedDict()
+
+ def __init__(self, name, value, alias=None):
+ self.name = name
+ self.value = value
+ self.__members__[name] = self
+ value_to_member[value] = self
+ setattr(self.__class__, name, self)
+ if alias:
+ self.__members__[alias] = self
+ setattr(self.__class__, alias, self)
+
+ value_to_member = {}
+
+ @classmethod
+ def get(cls, value):
+ return value_to_member[value]
+
+ someenum = type(
+ name,
+ (object,),
+ {"__members__": __members__, "__init__": __init__, "get": get},
+ )
+
+    # set __module__ via the _getframe() trick used by Python's
+    # namedtuple(), so that instances of the generated class can be
+    # pickled
+    module = None
+    try:
+        module = sys._getframe(1).f_globals.get("__name__", "__main__")
+    except (AttributeError, ValueError):
+        pass
+    if module is not None:
+        someenum.__module__ = module
+
+ return someenum
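+
+# Illustrative usage of pep435_enum (assumption, not part of this module):
+#
+#     SomeEnum = pep435_enum("SomeEnum")
+#     one = SomeEnum("one", 1)
+#     two = SomeEnum("two", 2, alias="deux")
+#
+#     SomeEnum.get(2) is two                 # lookup by value
+#     SomeEnum.__members__["deux"] is two    # alias registered as a member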
diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py
new file mode 100644
index 0000000..30817e1
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/__init__.py
@@ -0,0 +1,13 @@
+from .test_cte import * # noqa
+from .test_ddl import * # noqa
+from .test_deprecations import * # noqa
+from .test_dialect import * # noqa
+from .test_insert import * # noqa
+from .test_reflection import * # noqa
+from .test_results import * # noqa
+from .test_rowcount import * # noqa
+from .test_select import * # noqa
+from .test_sequence import * # noqa
+from .test_types import * # noqa
+from .test_unicode_ddl import * # noqa
+from .test_update_delete import * # noqa
diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py
new file mode 100644
index 0000000..a94ee55
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_cte.py
@@ -0,0 +1,204 @@
+from .. import fixtures
+from ..assertions import eq_
+from ..schema import Column
+from ..schema import Table
+from ... import ForeignKey
+from ... import Integer
+from ... import select
+from ... import String
+from ... import testing
+
+
+class CTETest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("ctes",)
+
+ run_inserts = "each"
+ run_deletes = "each"
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("parent_id", ForeignKey("some_table.id")),
+ )
+
+ Table(
+ "some_other_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("parent_id", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "data": "d1", "parent_id": None},
+ {"id": 2, "data": "d2", "parent_id": 1},
+ {"id": 3, "data": "d3", "parent_id": 1},
+ {"id": 4, "data": "d4", "parent_id": 3},
+ {"id": 5, "data": "d5", "parent_id": 3},
+ ],
+ )
+
+ def test_select_nonrecursive_round_trip(self, connection):
+ some_table = self.tables.some_table
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ result = connection.execute(
+ select(cte.c.data).where(cte.c.data.in_(["d4", "d5"]))
+ )
+ eq_(result.fetchall(), [("d4",)])
+
+ def test_select_recursive_round_trip(self, connection):
+ some_table = self.tables.some_table
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte", recursive=True)
+ )
+
+ cte_alias = cte.alias("c1")
+ st1 = some_table.alias()
+ # note that SQL Server requires this to be UNION ALL,
+ # can't be UNION
+ cte = cte.union_all(
+ select(st1).where(st1.c.id == cte_alias.c.parent_id)
+ )
+ result = connection.execute(
+ select(cte.c.data)
+ .where(cte.c.data != "d2")
+ .order_by(cte.c.data.desc())
+ )
+ eq_(
+ result.fetchall(),
+ [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
+ )
+
+ def test_insert_from_select_round_trip(self, connection):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ connection.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"], select(cte)
+ )
+ )
+ eq_(
+ connection.execute(
+ select(some_other_table).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
+ )
+
+ @testing.requires.ctes_with_update_delete
+ @testing.requires.update_from
+ def test_update_from_round_trip(self, connection):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ connection.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"], select(some_table)
+ )
+ )
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ connection.execute(
+ some_other_table.update()
+ .values(parent_id=5)
+ .where(some_other_table.c.data == cte.c.data)
+ )
+ eq_(
+ connection.execute(
+ select(some_other_table).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [
+ (1, "d1", None),
+ (2, "d2", 5),
+ (3, "d3", 5),
+ (4, "d4", 5),
+ (5, "d5", 3),
+ ],
+ )
+
+ @testing.requires.ctes_with_update_delete
+ @testing.requires.delete_from
+ def test_delete_from_round_trip(self, connection):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ connection.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"], select(some_table)
+ )
+ )
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ connection.execute(
+ some_other_table.delete().where(
+ some_other_table.c.data == cte.c.data
+ )
+ )
+ eq_(
+ connection.execute(
+ select(some_other_table).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [(1, "d1", None), (5, "d5", 3)],
+ )
+
+ @testing.requires.ctes_with_update_delete
+ def test_delete_scalar_subq_round_trip(self, connection):
+
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ connection.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"], select(some_table)
+ )
+ )
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ connection.execute(
+ some_other_table.delete().where(
+ some_other_table.c.data
+ == select(cte.c.data)
+ .where(cte.c.id == some_other_table.c.id)
+ .scalar_subquery()
+ )
+ )
+ eq_(
+ connection.execute(
+ select(some_other_table).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [(1, "d1", None), (5, "d5", 3)],
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py
new file mode 100644
index 0000000..b3fee55
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_ddl.py
@@ -0,0 +1,381 @@
+import random
+
+from . import testing
+from .. import config
+from .. import fixtures
+from .. import util
+from ..assertions import eq_
+from ..assertions import is_false
+from ..assertions import is_true
+from ..config import requirements
+from ..schema import Table
+from ... import CheckConstraint
+from ... import Column
+from ... import ForeignKeyConstraint
+from ... import Index
+from ... import inspect
+from ... import Integer
+from ... import schema
+from ... import String
+from ... import UniqueConstraint
+
+
+class TableDDLTest(fixtures.TestBase):
+ __backend__ = True
+
+ def _simple_fixture(self, schema=None):
+ return Table(
+ "test_table",
+ self.metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ schema=schema,
+ )
+
+ def _underscore_fixture(self):
+ return Table(
+ "_test_table",
+ self.metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("_data", String(50)),
+ )
+
+ def _table_index_fixture(self, schema=None):
+ table = self._simple_fixture(schema=schema)
+ idx = Index("test_index", table.c.data)
+ return table, idx
+
+ def _simple_roundtrip(self, table):
+ with config.db.begin() as conn:
+ conn.execute(table.insert().values((1, "some data")))
+ result = conn.execute(table.select())
+ eq_(result.first(), (1, "some data"))
+
+ @requirements.create_table
+ @util.provide_metadata
+ def test_create_table(self):
+ table = self._simple_fixture()
+ table.create(config.db, checkfirst=False)
+ self._simple_roundtrip(table)
+
+ @requirements.create_table
+ @requirements.schemas
+ @util.provide_metadata
+ def test_create_table_schema(self):
+ table = self._simple_fixture(schema=config.test_schema)
+ table.create(config.db, checkfirst=False)
+ self._simple_roundtrip(table)
+
+ @requirements.drop_table
+ @util.provide_metadata
+ def test_drop_table(self):
+ table = self._simple_fixture()
+ table.create(config.db, checkfirst=False)
+ table.drop(config.db, checkfirst=False)
+
+ @requirements.create_table
+ @util.provide_metadata
+ def test_underscore_names(self):
+ table = self._underscore_fixture()
+ table.create(config.db, checkfirst=False)
+ self._simple_roundtrip(table)
+
+ @requirements.comment_reflection
+ @util.provide_metadata
+ def test_add_table_comment(self, connection):
+ table = self._simple_fixture()
+ table.create(connection, checkfirst=False)
+ table.comment = "a comment"
+ connection.execute(schema.SetTableComment(table))
+ eq_(
+ inspect(connection).get_table_comment("test_table"),
+ {"text": "a comment"},
+ )
+
+ @requirements.comment_reflection
+ @util.provide_metadata
+ def test_drop_table_comment(self, connection):
+ table = self._simple_fixture()
+ table.create(connection, checkfirst=False)
+ table.comment = "a comment"
+ connection.execute(schema.SetTableComment(table))
+ connection.execute(schema.DropTableComment(table))
+ eq_(
+ inspect(connection).get_table_comment("test_table"), {"text": None}
+ )
+
+ @requirements.table_ddl_if_exists
+ @util.provide_metadata
+ def test_create_table_if_not_exists(self, connection):
+ table = self._simple_fixture()
+
+ connection.execute(schema.CreateTable(table, if_not_exists=True))
+
+ is_true(inspect(connection).has_table("test_table"))
+ connection.execute(schema.CreateTable(table, if_not_exists=True))
+
+ @requirements.index_ddl_if_exists
+ @util.provide_metadata
+ def test_create_index_if_not_exists(self, connection):
+ table, idx = self._table_index_fixture()
+
+ connection.execute(schema.CreateTable(table, if_not_exists=True))
+ is_true(inspect(connection).has_table("test_table"))
+ is_false(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.CreateIndex(idx, if_not_exists=True))
+
+ is_true(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.CreateIndex(idx, if_not_exists=True))
+
+ @requirements.table_ddl_if_exists
+ @util.provide_metadata
+ def test_drop_table_if_exists(self, connection):
+ table = self._simple_fixture()
+
+ table.create(connection)
+
+ is_true(inspect(connection).has_table("test_table"))
+
+ connection.execute(schema.DropTable(table, if_exists=True))
+
+ is_false(inspect(connection).has_table("test_table"))
+
+ connection.execute(schema.DropTable(table, if_exists=True))
+
+ @requirements.index_ddl_if_exists
+ @util.provide_metadata
+ def test_drop_index_if_exists(self, connection):
+ table, idx = self._table_index_fixture()
+
+ table.create(connection)
+
+ is_true(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.DropIndex(idx, if_exists=True))
+
+ is_false(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.DropIndex(idx, if_exists=True))
+
+
+class FutureTableDDLTest(fixtures.FutureEngineMixin, TableDDLTest):
+ pass
+
+
+class LongNameBlowoutTest(fixtures.TestBase):
+ """test the creation of a variety of DDL structures and ensure
+ label length limits pass on backends
+
+ """
+
+ __backend__ = True
+
+ def fk(self, metadata, connection):
+ convention = {
+ "fk": "foreign_key_%(table_name)s_"
+ "%(column_0_N_name)s_"
+ "%(referred_table_name)s_"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(20))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ test_needs_fk=True,
+ )
+
+ cons = ForeignKeyConstraint(
+ ["aid"], ["a_things_with_stuff.id_long_column_name"]
+ )
+ Table(
+ "b_related_things_of_value",
+ metadata,
+ Column(
+ "aid",
+ ),
+ cons,
+ test_needs_fk=True,
+ )
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+
+ if testing.requires.foreign_key_constraint_name_reflection.enabled:
+ insp = inspect(connection)
+ fks = insp.get_foreign_keys("b_related_things_of_value")
+ reflected_name = fks[0]["name"]
+
+ return actual_name, reflected_name
+ else:
+ return actual_name, None
+
+ def pk(self, metadata, connection):
+ convention = {
+ "pk": "primary_key_%(table_name)s_"
+ "%(column_0_N_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ a = Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("id_another_long_name", Integer, primary_key=True),
+ )
+ cons = a.primary_key
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ pk = insp.get_pk_constraint("a_things_with_stuff")
+ reflected_name = pk["name"]
+ return actual_name, reflected_name
+
+ def ix(self, metadata, connection):
+ convention = {
+ "ix": "index_%(table_name)s_"
+ "%(column_0_N_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ a = Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("id_another_long_name", Integer),
+ )
+ cons = Index(None, a.c.id_long_column_name, a.c.id_another_long_name)
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ ix = insp.get_indexes("a_things_with_stuff")
+ reflected_name = ix[0]["name"]
+ return actual_name, reflected_name
+
+ def uq(self, metadata, connection):
+ convention = {
+ "uq": "unique_constraint_%(table_name)s_"
+ "%(column_0_N_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ cons = UniqueConstraint("id_long_column_name", "id_another_long_name")
+ Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("id_another_long_name", Integer),
+ cons,
+ )
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ uq = insp.get_unique_constraints("a_things_with_stuff")
+ reflected_name = uq[0]["name"]
+ return actual_name, reflected_name
+
+ def ck(self, metadata, connection):
+ convention = {
+ "ck": "check_constraint_%(table_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ cons = CheckConstraint("some_long_column_name > 5")
+ Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("some_long_column_name", Integer),
+ cons,
+ )
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ ck = insp.get_check_constraints("a_things_with_stuff")
+ reflected_name = ck[0]["name"]
+ return actual_name, reflected_name
+
+ @testing.combinations(
+ ("fk",),
+ ("pk",),
+ ("ix",),
+ ("ck", testing.requires.check_constraint_reflection.as_skips()),
+ ("uq", testing.requires.unique_constraint_reflection.as_skips()),
+ argnames="type_",
+ )
+ def test_long_convention_name(self, type_, metadata, connection):
+ actual_name, reflected_name = getattr(self, type_)(
+ metadata, connection
+ )
+
+ assert len(actual_name) > 255
+
+ if reflected_name is not None:
+ overlap = actual_name[0 : len(reflected_name)]
+ if len(overlap) < len(actual_name):
+ eq_(overlap[0:-5], reflected_name[0 : len(overlap) - 5])
+ else:
+ eq_(overlap, reflected_name)
+
+
+__all__ = ("TableDDLTest", "FutureTableDDLTest", "LongNameBlowoutTest")
diff --git a/lib/sqlalchemy/testing/suite/test_deprecations.py b/lib/sqlalchemy/testing/suite/test_deprecations.py
new file mode 100644
index 0000000..b36162f
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_deprecations.py
@@ -0,0 +1,145 @@
+from .. import fixtures
+from ..assertions import eq_
+from ..schema import Column
+from ..schema import Table
+from ... import Integer
+from ... import select
+from ... import testing
+from ... import union
+
+
+class DeprecatedCompoundSelectTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2},
+ {"id": 2, "x": 2, "y": 3},
+ {"id": 3, "x": 3, "y": 4},
+ {"id": 4, "x": 4, "y": 5},
+ ],
+ )
+
+ def _assert_result(self, conn, select, result, params=()):
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ def test_plain_union(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ # note we've had to remove one use case entirely, which is this
+ # one. the Select gets its FROMS from the WHERE clause and the
+ # columns clause, but not the ORDER BY, which means the old ".c" system
+ # allowed you to "order_by(s.c.foo)" to get an unnamed column in the
+ # ORDER BY without adding the SELECT into the FROM and breaking the
+ # query. Users will have to adjust for this use case if they were doing
+ # it before.
+ def _dont_test_select_from_plain_union(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2).alias().select()
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.order_by_col_from_union
+ @testing.requires.parens_in_union_contained_select_w_limit_offset
+ def test_limit_offset_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.parens_in_union_contained_select_wo_limit_offset
+ def test_order_by_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_distinct_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).distinct()
+ s2 = select(table).where(table.c.id == 3).distinct()
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_limit_offset_aliased_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = (
+ select(table)
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+ s2 = (
+ select(table)
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py
new file mode 100644
index 0000000..c2c17d0
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_dialect.py
@@ -0,0 +1,361 @@
+#! coding: utf-8
+
+from . import testing
+from .. import assert_raises
+from .. import config
+from .. import engines
+from .. import eq_
+from .. import fixtures
+from .. import ne_
+from .. import provide_metadata
+from ..config import requirements
+from ..provision import set_default_schema_on_connection
+from ..schema import Column
+from ..schema import Table
+from ... import bindparam
+from ... import event
+from ... import exc
+from ... import Integer
+from ... import literal_column
+from ... import select
+from ... import String
+from ...util import compat
+
+
+class ExceptionTest(fixtures.TablesTest):
+ """Test basic exception wrapping.
+
+    DBAPIs vary a lot in exception behavior, so to actually anticipate
+    specific exceptions from real round trips we need to be conservative.
+
+ """
+
+ run_deletes = "each"
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+
+ @requirements.duplicate_key_raises_integrity_error
+ def test_integrity_error(self):
+
+ with config.db.connect() as conn:
+
+ trans = conn.begin()
+ conn.execute(
+ self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
+ )
+
+ assert_raises(
+ exc.IntegrityError,
+ conn.execute,
+ self.tables.manual_pk.insert(),
+ {"id": 1, "data": "d1"},
+ )
+
+ trans.rollback()
+
+ def test_exception_with_non_ascii(self):
+ with config.db.connect() as conn:
+ try:
+ # try to create an error message that likely has non-ascii
+ # characters in the DBAPI's message string. unfortunately
+ # there's no way to make this happen with some drivers like
+ # mysqlclient, pymysql. this at least does produce a non-
+ # ascii error message for cx_oracle, psycopg2
+ conn.execute(select(literal_column(u"méil")))
+ assert False
+ except exc.DBAPIError as err:
+ err_str = str(err)
+
+                assert str(err.orig) in str(err)
+
+                # the stringified exception should come back as the
+                # native ``str`` type on both Py2k and Py3k
+                if compat.py2k:
+                    assert isinstance(err_str, str)
+                else:
+                    assert isinstance(err_str, str)
+
+
+class IsolationLevelTest(fixtures.TestBase):
+ __backend__ = True
+
+ __requires__ = ("isolation_level",)
+
+ def _get_non_default_isolation_level(self):
+ levels = requirements.get_isolation_levels(config)
+
+ default = levels["default"]
+ supported = levels["supported"]
+
+ s = set(supported).difference(["AUTOCOMMIT", default])
+ if s:
+ return s.pop()
+ else:
+ config.skip_test("no non-default isolation level available")
+
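+    # ``requirements.get_isolation_levels(config)`` is expected to return
+    # a mapping along these lines (values are illustrative only):
+    #
+    #     {
+    #         "default": "READ COMMITTED",
+    #         "supported": [
+    #             "READ UNCOMMITTED", "READ COMMITTED",
+    #             "REPEATABLE READ", "SERIALIZABLE", "AUTOCOMMIT",
+    #         ],
+    #     }
+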
+ def test_default_isolation_level(self):
+ eq_(
+ config.db.dialect.default_isolation_level,
+ requirements.get_isolation_levels(config)["default"],
+ )
+
+ def test_non_default_isolation_level(self):
+ non_default = self._get_non_default_isolation_level()
+
+ with config.db.connect() as conn:
+ existing = conn.get_isolation_level()
+
+ ne_(existing, non_default)
+
+ conn.execution_options(isolation_level=non_default)
+
+ eq_(conn.get_isolation_level(), non_default)
+
+ conn.dialect.reset_isolation_level(conn.connection)
+
+ eq_(conn.get_isolation_level(), existing)
+
+ def test_all_levels(self):
+ levels = requirements.get_isolation_levels(config)
+
+ all_levels = levels["supported"]
+
+ for level in set(all_levels).difference(["AUTOCOMMIT"]):
+ with config.db.connect() as conn:
+ conn.execution_options(isolation_level=level)
+
+ eq_(conn.get_isolation_level(), level)
+
+ trans = conn.begin()
+ trans.rollback()
+
+ eq_(conn.get_isolation_level(), level)
+
+ with config.db.connect() as conn:
+ eq_(
+ conn.get_isolation_level(),
+ levels["default"],
+ )
+
+
+class AutocommitIsolationTest(fixtures.TablesTest):
+
+ run_deletes = "each"
+
+ __requires__ = ("autocommit",)
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ test_needs_acid=True,
+ )
+
+ def _test_conn_autocommits(self, conn, autocommit):
+ trans = conn.begin()
+ conn.execute(
+ self.tables.some_table.insert(), {"id": 1, "data": "some data"}
+ )
+ trans.rollback()
+
+ eq_(
+ conn.scalar(select(self.tables.some_table.c.id)),
+ 1 if autocommit else None,
+ )
+
+ with conn.begin():
+ conn.execute(self.tables.some_table.delete())
+
+ def test_autocommit_on(self, connection_no_trans):
+ conn = connection_no_trans
+ c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
+ self._test_conn_autocommits(c2, True)
+
+ c2.dialect.reset_isolation_level(c2.connection)
+
+ self._test_conn_autocommits(conn, False)
+
+ def test_autocommit_off(self, connection_no_trans):
+ conn = connection_no_trans
+ self._test_conn_autocommits(conn, False)
+
+ def test_turn_autocommit_off_via_default_iso_level(
+ self, connection_no_trans
+ ):
+ conn = connection_no_trans
+ conn = conn.execution_options(isolation_level="AUTOCOMMIT")
+ self._test_conn_autocommits(conn, True)
+
+ conn.execution_options(
+ isolation_level=requirements.get_isolation_levels(config)[
+ "default"
+ ]
+ )
+ self._test_conn_autocommits(conn, False)
+
+
+class EscapingTest(fixtures.TestBase):
+ @provide_metadata
+ def test_percent_sign_round_trip(self):
+ """test that the DBAPI accommodates for escaped / nonescaped
+ percent signs in a way that matches the compiler
+
+ """
+ m = self.metadata
+ t = Table("t", m, Column("data", String(50)))
+ t.create(config.db)
+ with config.db.begin() as conn:
+ conn.execute(t.insert(), dict(data="some % value"))
+ conn.execute(t.insert(), dict(data="some %% other value"))
+
+ eq_(
+ conn.scalar(
+ select(t.c.data).where(
+ t.c.data == literal_column("'some % value'")
+ )
+ ),
+ "some % value",
+ )
+
+ eq_(
+ conn.scalar(
+ select(t.c.data).where(
+ t.c.data == literal_column("'some %% other value'")
+ )
+ ),
+ "some %% other value",
+ )
+
+
+class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase):
+ __backend__ = True
+
+ __requires__ = ("default_schema_name_switch",)
+
+ def test_control_case(self):
+ default_schema_name = config.db.dialect.default_schema_name
+
+ eng = engines.testing_engine()
+ with eng.connect():
+ pass
+
+ eq_(eng.dialect.default_schema_name, default_schema_name)
+
+ def test_wont_work_wo_insert(self):
+ default_schema_name = config.db.dialect.default_schema_name
+
+ eng = engines.testing_engine()
+
+ @event.listens_for(eng, "connect")
+ def on_connect(dbapi_connection, connection_record):
+ set_default_schema_on_connection(
+ config, dbapi_connection, config.test_schema
+ )
+
+ with eng.connect() as conn:
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+
+ eq_(eng.dialect.default_schema_name, default_schema_name)
+
+ def test_schema_change_on_connect(self):
+ eng = engines.testing_engine()
+
+ @event.listens_for(eng, "connect", insert=True)
+ def on_connect(dbapi_connection, connection_record):
+ set_default_schema_on_connection(
+ config, dbapi_connection, config.test_schema
+ )
+
+ with eng.connect() as conn:
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+
+ eq_(eng.dialect.default_schema_name, config.test_schema)
+
+ def test_schema_change_works_w_transactions(self):
+ eng = engines.testing_engine()
+
+ @event.listens_for(eng, "connect", insert=True)
+ def on_connect(dbapi_connection, *arg):
+ set_default_schema_on_connection(
+ config, dbapi_connection, config.test_schema
+ )
+
+ with eng.connect() as conn:
+ trans = conn.begin()
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+ trans.rollback()
+
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+
+ eq_(eng.dialect.default_schema_name, config.test_schema)
+
+
+class FutureWeCanSetDefaultSchemaWEventsTest(
+ fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest
+):
+ pass
+
+
+class DifficultParametersTest(fixtures.TestBase):
+ __backend__ = True
+
+ @testing.combinations(
+ ("boring",),
+ ("per cent",),
+ ("per % cent",),
+ ("%percent",),
+ ("par(ens)",),
+ ("percent%(ens)yah",),
+ ("col:ons",),
+ ("more :: %colons%",),
+ ("/slashes/",),
+ ("more/slashes",),
+ ("q?marks",),
+ ("1param",),
+ ("1col:on",),
+ argnames="name",
+ )
+ def test_round_trip(self, name, connection, metadata):
+ t = Table(
+ "t",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column(name, String(50), nullable=False),
+ )
+
+ # table is created
+ t.create(connection)
+
+ # automatic param generated by insert
+ connection.execute(t.insert().values({"id": 1, name: "some name"}))
+
+ # automatic param generated by criteria, plus selecting the column
+ stmt = select(t.c[name]).where(t.c[name] == "some name")
+
+ eq_(connection.scalar(stmt), "some name")
+
+ # use the name in a param explicitly
+ stmt = select(t.c[name]).where(t.c[name] == bindparam(name))
+
+ row = connection.execute(stmt, {name: "some name"}).first()
+
+ # name works as the key from cursor.description
+ eq_(row._mapping[name], "some name")
diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py
new file mode 100644
index 0000000..3c22f50
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_insert.py
@@ -0,0 +1,367 @@
+from .. import config
+from .. import engines
+from .. import fixtures
+from ..assertions import eq_
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import Integer
+from ... import literal
+from ... import literal_column
+from ... import select
+from ... import String
+
+
+class LastrowidTest(fixtures.TablesTest):
+ run_deletes = "each"
+
+ __backend__ = True
+
+ __requires__ = "implements_get_lastrowid", "autoincrement_insert"
+
+ __engine_options__ = {"implicit_returning": False}
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+
+ def _assert_round_trip(self, table, conn):
+ row = conn.execute(table.select()).first()
+ eq_(
+ row,
+ (
+ conn.dialect.default_sequence_base,
+ "some data",
+ ),
+ )
+
+ def test_autoincrement_on_insert(self, connection):
+
+ connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ self._assert_round_trip(self.tables.autoinc_pk, connection)
+
+ def test_last_inserted_id(self, connection):
+
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+ eq_(r.inserted_primary_key, (pk,))
+
+ @requirements.dbapi_lastrowid
+ def test_native_lastrowid_autoinc(self, connection):
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ lastrowid = r.lastrowid
+ pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+ eq_(lastrowid, pk)
+
+
+class InsertBehaviorTest(fixtures.TablesTest):
+ run_deletes = "each"
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+ Table(
+ "includes_defaults",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ Column("x", Integer, default=5),
+ Column(
+ "y",
+ Integer,
+ default=literal_column("2", type_=Integer) + literal(2),
+ ),
+ )
+
+ @requirements.autoincrement_insert
+ def test_autoclose_on_insert(self):
+ if requirements.returning.enabled:
+ engine = engines.testing_engine(
+ options={"implicit_returning": False}
+ )
+ else:
+ engine = config.db
+
+ with engine.begin() as conn:
+ r = conn.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ assert r._soft_closed
+ assert not r.closed
+ assert r.is_insert
+
+ # new as of I8091919d45421e3f53029b8660427f844fee0228; for the moment
+ # an insert where the PK was taken from a row that the dialect
+ # selected, as is the case for mssql/pyodbc, will still report
+ # returns_rows as true because there's a cursor description. in that
+ # case, the row had to have been consumed at least.
+ assert not r.returns_rows or r.fetchone() is None
+
+ @requirements.returning
+ def test_autoclose_on_insert_implicit_returning(self, connection):
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ assert r._soft_closed
+ assert not r.closed
+ assert r.is_insert
+
+ # note we are experimenting with having this be True
+ # as of I8091919d45421e3f53029b8660427f844fee0228 .
+ # implicit returning has fetched the row, but it still is a
+ # "returns rows"
+ assert r.returns_rows
+
+ # and we should be able to fetchone() on it, we just get no row
+ eq_(r.fetchone(), None)
+
+ # and the keys, etc.
+ eq_(r.keys(), ["id"])
+
+ # but the dialect took in the row already. not really sure
+ # what the best behavior is.
+
+ @requirements.empty_inserts
+ def test_empty_insert(self, connection):
+ r = connection.execute(self.tables.autoinc_pk.insert())
+ assert r._soft_closed
+ assert not r.closed
+
+ r = connection.execute(
+ self.tables.autoinc_pk.select().where(
+ self.tables.autoinc_pk.c.id != None
+ )
+ )
+ eq_(len(r.all()), 1)
+
+ @requirements.empty_inserts_executemany
+ def test_empty_insert_multiple(self, connection):
+ r = connection.execute(self.tables.autoinc_pk.insert(), [{}, {}, {}])
+ assert r._soft_closed
+ assert not r.closed
+
+ r = connection.execute(
+ self.tables.autoinc_pk.select().where(
+ self.tables.autoinc_pk.c.id != None
+ )
+ )
+
+ eq_(len(r.all()), 3)
+
+ @requirements.insert_from_select
+ def test_insert_from_select_autoinc(self, connection):
+ src_table = self.tables.manual_pk
+ dest_table = self.tables.autoinc_pk
+ connection.execute(
+ src_table.insert(),
+ [
+ dict(id=1, data="data1"),
+ dict(id=2, data="data2"),
+ dict(id=3, data="data3"),
+ ],
+ )
+
+ result = connection.execute(
+ dest_table.insert().from_select(
+ ("data",),
+ select(src_table.c.data).where(
+ src_table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+
+ eq_(result.inserted_primary_key, (None,))
+
+ result = connection.execute(
+ select(dest_table.c.data).order_by(dest_table.c.data)
+ )
+ eq_(result.fetchall(), [("data2",), ("data3",)])
+
+ @requirements.insert_from_select
+ def test_insert_from_select_autoinc_no_rows(self, connection):
+ src_table = self.tables.manual_pk
+ dest_table = self.tables.autoinc_pk
+
+ result = connection.execute(
+ dest_table.insert().from_select(
+ ("data",),
+ select(src_table.c.data).where(
+ src_table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+ eq_(result.inserted_primary_key, (None,))
+
+ result = connection.execute(
+ select(dest_table.c.data).order_by(dest_table.c.data)
+ )
+
+ eq_(result.fetchall(), [])
+
+ @requirements.insert_from_select
+ def test_insert_from_select(self, connection):
+ table = self.tables.manual_pk
+ connection.execute(
+ table.insert(),
+ [
+ dict(id=1, data="data1"),
+ dict(id=2, data="data2"),
+ dict(id=3, data="data3"),
+ ],
+ )
+
+ connection.execute(
+ table.insert()
+ .inline()
+ .from_select(
+ ("id", "data"),
+ select(table.c.id + 5, table.c.data).where(
+ table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+
+ eq_(
+ connection.execute(
+ select(table.c.data).order_by(table.c.data)
+ ).fetchall(),
+ [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)],
+ )
+
+ @requirements.insert_from_select
+ def test_insert_from_select_with_defaults(self, connection):
+ table = self.tables.includes_defaults
+ connection.execute(
+ table.insert(),
+ [
+ dict(id=1, data="data1"),
+ dict(id=2, data="data2"),
+ dict(id=3, data="data3"),
+ ],
+ )
+
+ connection.execute(
+ table.insert()
+ .inline()
+ .from_select(
+ ("id", "data"),
+ select(table.c.id + 5, table.c.data).where(
+ table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+
+ eq_(
+ connection.execute(
+ select(table).order_by(table.c.data, table.c.id)
+ ).fetchall(),
+ [
+ (1, "data1", 5, 4),
+ (2, "data2", 5, 4),
+ (7, "data2", 5, 4),
+ (3, "data3", 5, 4),
+ (8, "data3", 5, 4),
+ ],
+ )
+
+
+class ReturningTest(fixtures.TablesTest):
+ run_create_tables = "each"
+ __requires__ = "returning", "autoincrement_insert"
+ __backend__ = True
+
+ __engine_options__ = {"implicit_returning": True}
+
+ def _assert_round_trip(self, table, conn):
+ row = conn.execute(table.select()).first()
+ eq_(
+ row,
+ (
+ conn.dialect.default_sequence_base,
+ "some data",
+ ),
+ )
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+
+ @requirements.fetch_rows_post_commit
+ def test_explicit_returning_pk_autocommit(self, connection):
+ table = self.tables.autoinc_pk
+ r = connection.execute(
+ table.insert().returning(table.c.id), dict(data="some data")
+ )
+ pk = r.first()[0]
+ fetched_pk = connection.scalar(select(table.c.id))
+ eq_(fetched_pk, pk)
+
+ def test_explicit_returning_pk_no_autocommit(self, connection):
+ table = self.tables.autoinc_pk
+ r = connection.execute(
+ table.insert().returning(table.c.id), dict(data="some data")
+ )
+ pk = r.first()[0]
+ fetched_pk = connection.scalar(select(table.c.id))
+ eq_(fetched_pk, pk)
+
+ def test_autoincrement_on_insert_implicit_returning(self, connection):
+
+ connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ self._assert_round_trip(self.tables.autoinc_pk, connection)
+
+ def test_last_inserted_id_implicit_returning(self, connection):
+
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+ eq_(r.inserted_primary_key, (pk,))
+
+
+__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
new file mode 100644
index 0000000..459a4d8
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -0,0 +1,1738 @@
+import operator
+import re
+
+import sqlalchemy as sa
+from .. import config
+from .. import engines
+from .. import eq_
+from .. import expect_warnings
+from .. import fixtures
+from .. import is_
+from ..provision import get_temp_table_name
+from ..provision import temp_table_keyword_args
+from ..schema import Column
+from ..schema import Table
+from ... import event
+from ... import ForeignKey
+from ... import func
+from ... import Identity
+from ... import inspect
+from ... import Integer
+from ... import MetaData
+from ... import String
+from ... import testing
+from ... import types as sql_types
+from ...schema import DDL
+from ...schema import Index
+from ...sql.elements import quoted_name
+from ...sql.schema import BLANK_SCHEMA
+from ...testing import is_false
+from ...testing import is_true
+
+
+metadata, users = None, None
+
+
+class HasTableTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "test_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+ if testing.requires.schemas.enabled:
+ Table(
+ "test_table_s",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ schema=config.test_schema,
+ )
+
+ def test_has_table(self):
+ with config.db.begin() as conn:
+ is_true(config.db.dialect.has_table(conn, "test_table"))
+ is_false(config.db.dialect.has_table(conn, "test_table_s"))
+ is_false(config.db.dialect.has_table(conn, "nonexistent_table"))
+
+ @testing.requires.schemas
+ def test_has_table_schema(self):
+ with config.db.begin() as conn:
+ is_false(
+ config.db.dialect.has_table(
+ conn, "test_table", schema=config.test_schema
+ )
+ )
+ is_true(
+ config.db.dialect.has_table(
+ conn, "test_table_s", schema=config.test_schema
+ )
+ )
+ is_false(
+ config.db.dialect.has_table(
+ conn, "nonexistent_table", schema=config.test_schema
+ )
+ )
+
+
+class HasIndexTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ tt = Table(
+ "test_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+ Index("my_idx", tt.c.data)
+
+ if testing.requires.schemas.enabled:
+ tt = Table(
+ "test_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ schema=config.test_schema,
+ )
+ Index("my_idx_s", tt.c.data)
+
+ def test_has_index(self):
+ with config.db.begin() as conn:
+ assert config.db.dialect.has_index(conn, "test_table", "my_idx")
+ assert not config.db.dialect.has_index(
+ conn, "test_table", "my_idx_s"
+ )
+ assert not config.db.dialect.has_index(
+ conn, "nonexistent_table", "my_idx"
+ )
+ assert not config.db.dialect.has_index(
+ conn, "test_table", "nonexistent_idx"
+ )
+
+ @testing.requires.schemas
+ def test_has_index_schema(self):
+ with config.db.begin() as conn:
+ assert config.db.dialect.has_index(
+ conn, "test_table", "my_idx_s", schema=config.test_schema
+ )
+ assert not config.db.dialect.has_index(
+ conn, "test_table", "my_idx", schema=config.test_schema
+ )
+ assert not config.db.dialect.has_index(
+ conn,
+ "nonexistent_table",
+ "my_idx_s",
+ schema=config.test_schema,
+ )
+ assert not config.db.dialect.has_index(
+ conn,
+ "test_table",
+ "nonexistent_idx_s",
+ schema=config.test_schema,
+ )
+
+
+class QuotedNameArgumentTest(fixtures.TablesTest):
+ run_create_tables = "once"
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "quote ' one",
+ metadata,
+ Column("id", Integer),
+ Column("name", String(50)),
+ Column("data", String(50)),
+ Column("related_id", Integer),
+ sa.PrimaryKeyConstraint("id", name="pk quote ' one"),
+ sa.Index("ix quote ' one", "name"),
+ sa.UniqueConstraint(
+ "data",
+ name="uq quote' one",
+ ),
+ sa.ForeignKeyConstraint(
+ ["id"], ["related.id"], name="fk quote ' one"
+ ),
+ sa.CheckConstraint("name != 'foo'", name="ck quote ' one"),
+ comment=r"""quote ' one comment""",
+ test_needs_fk=True,
+ )
+
+ if testing.requires.symbol_names_w_double_quote.enabled:
+ Table(
+ 'quote " two',
+ metadata,
+ Column("id", Integer),
+ Column("name", String(50)),
+ Column("data", String(50)),
+ Column("related_id", Integer),
+ sa.PrimaryKeyConstraint("id", name='pk quote " two'),
+ sa.Index('ix quote " two', "name"),
+ sa.UniqueConstraint(
+ "data",
+ name='uq quote" two',
+ ),
+ sa.ForeignKeyConstraint(
+ ["id"], ["related.id"], name='fk quote " two'
+ ),
+ sa.CheckConstraint("name != 'foo'", name='ck quote " two '),
+ comment=r"""quote " two comment""",
+ test_needs_fk=True,
+ )
+
+ Table(
+ "related",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("related", Integer),
+ test_needs_fk=True,
+ )
+
+ if testing.requires.view_column_reflection.enabled:
+
+ if testing.requires.symbol_names_w_double_quote.enabled:
+ names = [
+ "quote ' one",
+ 'quote " two',
+ ]
+ else:
+ names = [
+ "quote ' one",
+ ]
+ for name in names:
+ query = "CREATE VIEW %s AS SELECT * FROM %s" % (
+ config.db.dialect.identifier_preparer.quote(
+ "view %s" % name
+ ),
+ config.db.dialect.identifier_preparer.quote(name),
+ )
+
+ event.listen(metadata, "after_create", DDL(query))
+ event.listen(
+ metadata,
+ "before_drop",
+ DDL(
+ "DROP VIEW %s"
+ % config.db.dialect.identifier_preparer.quote(
+ "view %s" % name
+ )
+ ),
+ )
+
+ def quote_fixtures(fn):
+ return testing.combinations(
+ ("quote ' one",),
+ ('quote " two', testing.requires.symbol_names_w_double_quote),
+ )(fn)
+
+ @quote_fixtures
+ def test_get_table_options(self, name):
+ insp = inspect(config.db)
+
+ insp.get_table_options(name)
+
+ @quote_fixtures
+ @testing.requires.view_column_reflection
+ def test_get_view_definition(self, name):
+ insp = inspect(config.db)
+ assert insp.get_view_definition("view %s" % name)
+
+ @quote_fixtures
+ def test_get_columns(self, name):
+ insp = inspect(config.db)
+ assert insp.get_columns(name)
+
+ @quote_fixtures
+ def test_get_pk_constraint(self, name):
+ insp = inspect(config.db)
+ assert insp.get_pk_constraint(name)
+
+ @quote_fixtures
+ def test_get_foreign_keys(self, name):
+ insp = inspect(config.db)
+ assert insp.get_foreign_keys(name)
+
+ @quote_fixtures
+ def test_get_indexes(self, name):
+ insp = inspect(config.db)
+ assert insp.get_indexes(name)
+
+ @quote_fixtures
+ @testing.requires.unique_constraint_reflection
+ def test_get_unique_constraints(self, name):
+ insp = inspect(config.db)
+ assert insp.get_unique_constraints(name)
+
+ @quote_fixtures
+ @testing.requires.comment_reflection
+ def test_get_table_comment(self, name):
+ insp = inspect(config.db)
+ assert insp.get_table_comment(name)
+
+ @quote_fixtures
+ @testing.requires.check_constraint_reflection
+ def test_get_check_constraints(self, name):
+ insp = inspect(config.db)
+ assert insp.get_check_constraints(name)
+
+
+class ComponentReflectionTest(fixtures.TablesTest):
+ run_inserts = run_deletes = None
+
+ __backend__ = True
+
+ @classmethod
+ def setup_bind(cls):
+ if config.requirements.independent_connections.enabled:
+ from sqlalchemy import pool
+
+ return engines.testing_engine(
+ options=dict(poolclass=pool.StaticPool, scope="class"),
+ )
+ else:
+ return config.db
+
+ @classmethod
+ def define_tables(cls, metadata):
+ cls.define_reflected_tables(metadata, None)
+ if testing.requires.schemas.enabled:
+ cls.define_reflected_tables(metadata, testing.config.test_schema)
+
+ @classmethod
+ def define_reflected_tables(cls, metadata, schema):
+ if schema:
+ schema_prefix = schema + "."
+ else:
+ schema_prefix = ""
+
+ if testing.requires.self_referential_foreign_keys.enabled:
+ users = Table(
+ "users",
+ metadata,
+ Column("user_id", sa.INT, primary_key=True),
+ Column("test1", sa.CHAR(5), nullable=False),
+ Column("test2", sa.Float(5), nullable=False),
+ Column(
+ "parent_user_id",
+ sa.Integer,
+ sa.ForeignKey(
+ "%susers.user_id" % schema_prefix, name="user_id_fk"
+ ),
+ ),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ else:
+ users = Table(
+ "users",
+ metadata,
+ Column("user_id", sa.INT, primary_key=True),
+ Column("test1", sa.CHAR(5), nullable=False),
+ Column("test2", sa.Float(5), nullable=False),
+ schema=schema,
+ test_needs_fk=True,
+ )
+
+ Table(
+ "dingalings",
+ metadata,
+ Column("dingaling_id", sa.Integer, primary_key=True),
+ Column(
+ "address_id",
+ sa.Integer,
+ sa.ForeignKey("%semail_addresses.address_id" % schema_prefix),
+ ),
+ Column("data", sa.String(30)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "email_addresses",
+ metadata,
+ Column("address_id", sa.Integer),
+ Column(
+ "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id)
+ ),
+ Column("email_address", sa.String(20)),
+ sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "comment_test",
+ metadata,
+ Column("id", sa.Integer, primary_key=True, comment="id comment"),
+ Column("data", sa.String(20), comment="data % comment"),
+ Column(
+ "d2",
+ sa.String(20),
+ comment=r"""Comment types type speedily ' " \ '' Fun!""",
+ ),
+ schema=schema,
+ comment=r"""the test % ' " \ table comment""",
+ )
+
+ if testing.requires.cross_schema_fk_reflection.enabled:
+ if schema is None:
+ Table(
+ "local_table",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("data", sa.String(20)),
+ Column(
+ "remote_id",
+ ForeignKey(
+ "%s.remote_table_2.id" % testing.config.test_schema
+ ),
+ ),
+ test_needs_fk=True,
+ schema=config.db.dialect.default_schema_name,
+ )
+ else:
+ Table(
+ "remote_table",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column(
+ "local_id",
+ ForeignKey(
+ "%s.local_table.id"
+ % config.db.dialect.default_schema_name
+ ),
+ ),
+ Column("data", sa.String(20)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "remote_table_2",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("data", sa.String(20)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+
+ if testing.requires.index_reflection.enabled:
+ cls.define_index(metadata, users)
+
+ if not schema:
+ # test_needs_fk is at the moment to force MySQL InnoDB
+ noncol_idx_test_nopk = Table(
+ "noncol_idx_test_nopk",
+ metadata,
+ Column("q", sa.String(5)),
+ test_needs_fk=True,
+ )
+
+ noncol_idx_test_pk = Table(
+ "noncol_idx_test_pk",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("q", sa.String(5)),
+ test_needs_fk=True,
+ )
+
+ if testing.requires.indexes_with_ascdesc.enabled:
+ Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc())
+ Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc())
+
+ if testing.requires.view_column_reflection.enabled:
+ cls.define_views(metadata, schema)
+ if not schema and testing.requires.temp_table_reflection.enabled:
+ cls.define_temp_tables(metadata)
+
+ @classmethod
+ def define_temp_tables(cls, metadata):
+ kw = temp_table_keyword_args(config, config.db)
+ table_name = get_temp_table_name(
+ config, config.db, "user_tmp_%s" % config.ident
+ )
+ user_tmp = Table(
+ table_name,
+ metadata,
+ Column("id", sa.INT, primary_key=True),
+ Column("name", sa.VARCHAR(50)),
+ Column("foo", sa.INT),
+ # disambiguate temp table unique constraint names. this is
+ # pretty arbitrary for a generic dialect however we are doing
+ # it to suit SQL Server which will produce name conflicts for
+ # unique constraints created against temp tables in different
+ # databases.
+ # https://www.arbinada.com/en/node/1645
+ sa.UniqueConstraint("name", name="user_tmp_uq_%s" % config.ident),
+ sa.Index("user_tmp_ix", "foo"),
+ **kw
+ )
+ if (
+ testing.requires.view_reflection.enabled
+ and testing.requires.temporary_views.enabled
+ ):
+ event.listen(
+ user_tmp,
+ "after_create",
+ DDL(
+ "create temporary view user_tmp_v as "
+ "select * from user_tmp_%s" % config.ident
+ ),
+ )
+ event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v"))
+
+ @classmethod
+ def define_index(cls, metadata, users):
+ Index("users_t_idx", users.c.test1, users.c.test2)
+ Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1)
+
+ @classmethod
+ def define_views(cls, metadata, schema):
+ for table_name in ("users", "email_addresses"):
+ fullname = table_name
+ if schema:
+ fullname = "%s.%s" % (schema, table_name)
+ view_name = fullname + "_v"
+ query = "CREATE VIEW %s AS SELECT * FROM %s" % (
+ view_name,
+ fullname,
+ )
+
+ event.listen(metadata, "after_create", DDL(query))
+ event.listen(
+ metadata, "before_drop", DDL("DROP VIEW %s" % view_name)
+ )
+
+ @testing.requires.schema_reflection
+ def test_get_schema_names(self):
+ insp = inspect(self.bind)
+
+ self.assert_(testing.config.test_schema in insp.get_schema_names())
+
+ @testing.requires.schema_reflection
+ def test_get_schema_names_w_translate_map(self, connection):
+ """test #7300"""
+
+ connection = connection.execution_options(
+ schema_translate_map={
+ "foo": "bar",
+ BLANK_SCHEMA: testing.config.test_schema,
+ }
+ )
+ insp = inspect(connection)
+
+ self.assert_(testing.config.test_schema in insp.get_schema_names())
+
+ @testing.requires.schema_reflection
+ def test_dialect_initialize(self):
+ engine = engines.testing_engine()
+ inspect(engine)
+ assert hasattr(engine.dialect, "default_schema_name")
+
+ @testing.requires.schema_reflection
+ def test_get_default_schema_name(self):
+ insp = inspect(self.bind)
+ eq_(insp.default_schema_name, self.bind.dialect.default_schema_name)
+
+ @testing.requires.foreign_key_constraint_reflection
+ @testing.combinations(
+ (None, True, False, False),
+ (None, True, False, True, testing.requires.schemas),
+ ("foreign_key", True, False, False),
+ (None, False, True, False),
+ (None, False, True, True, testing.requires.schemas),
+ (None, True, True, False),
+ (None, True, True, True, testing.requires.schemas),
+ argnames="order_by,include_plain,include_views,use_schema",
+ )
+ def test_get_table_names(
+ self, connection, order_by, include_plain, include_views, use_schema
+ ):
+
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
+ _ignore_tables = [
+ "comment_test",
+ "noncol_idx_test_pk",
+ "noncol_idx_test_nopk",
+ "local_table",
+ "remote_table",
+ "remote_table_2",
+ ]
+
+ insp = inspect(connection)
+
+ if include_views:
+ table_names = insp.get_view_names(schema)
+ table_names.sort()
+ answer = ["email_addresses_v", "users_v"]
+ eq_(sorted(table_names), answer)
+
+ if include_plain:
+ if order_by:
+ tables = [
+ rec[0]
+ for rec in insp.get_sorted_table_and_fkc_names(schema)
+ if rec[0]
+ ]
+ else:
+ tables = insp.get_table_names(schema)
+ table_names = [t for t in tables if t not in _ignore_tables]
+
+ if order_by == "foreign_key":
+ answer = ["users", "email_addresses", "dingalings"]
+ eq_(table_names, answer)
+ else:
+ answer = ["dingalings", "email_addresses", "users"]
+ eq_(sorted(table_names), answer)
+
+ @testing.requires.temp_table_names
+ def test_get_temp_table_names(self):
+ insp = inspect(self.bind)
+ temp_table_names = insp.get_temp_table_names()
+ eq_(sorted(temp_table_names), ["user_tmp_%s" % config.ident])
+
+ @testing.requires.view_reflection
+ @testing.requires.temp_table_names
+ @testing.requires.temporary_views
+ def test_get_temp_view_names(self):
+ insp = inspect(self.bind)
+ temp_table_names = insp.get_temp_view_names()
+ eq_(sorted(temp_table_names), ["user_tmp_v"])
+
+ @testing.requires.comment_reflection
+ def test_get_comments(self):
+ self._test_get_comments()
+
+ @testing.requires.comment_reflection
+ @testing.requires.schemas
+ def test_get_comments_with_schema(self):
+ self._test_get_comments(testing.config.test_schema)
+
+ def _test_get_comments(self, schema=None):
+ insp = inspect(self.bind)
+
+ eq_(
+ insp.get_table_comment("comment_test", schema=schema),
+ {"text": r"""the test % ' " \ table comment"""},
+ )
+
+ eq_(insp.get_table_comment("users", schema=schema), {"text": None})
+
+ eq_(
+ [
+ {"name": rec["name"], "comment": rec["comment"]}
+ for rec in insp.get_columns("comment_test", schema=schema)
+ ],
+ [
+ {"comment": "id comment", "name": "id"},
+ {"comment": "data % comment", "name": "data"},
+ {
+ "comment": (
+ r"""Comment types type speedily ' " \ '' Fun!"""
+ ),
+ "name": "d2",
+ },
+ ],
+ )
+
+ @testing.combinations(
+ (False, False),
+ (False, True, testing.requires.schemas),
+ (True, False, testing.requires.view_reflection),
+ (
+ True,
+ True,
+ testing.requires.schemas + testing.requires.view_reflection,
+ ),
+ argnames="use_views,use_schema",
+ )
+ def test_get_columns(self, connection, use_views, use_schema):
+
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
+ users, addresses = (self.tables.users, self.tables.email_addresses)
+ if use_views:
+ table_names = ["users_v", "email_addresses_v"]
+ else:
+ table_names = ["users", "email_addresses"]
+
+ insp = inspect(connection)
+ for table_name, table in zip(table_names, (users, addresses)):
+ schema_name = schema
+ cols = insp.get_columns(table_name, schema=schema_name)
+ self.assert_(len(cols) > 0, len(cols))
+
+ # should be in order
+
+ for i, col in enumerate(table.columns):
+ eq_(col.name, cols[i]["name"])
+ ctype = cols[i]["type"].__class__
+ ctype_def = col.type
+ if isinstance(ctype_def, sa.types.TypeEngine):
+ ctype_def = ctype_def.__class__
+
+ # Oracle returns Date for DateTime.
+
+ if testing.against("oracle") and ctype_def in (
+ sql_types.Date,
+ sql_types.DateTime,
+ ):
+ ctype_def = sql_types.Date
+
+ # assert that the desired type and return type share
+ # a base within one of the generic types.
+
+ self.assert_(
+ len(
+ set(ctype.__mro__)
+ .intersection(ctype_def.__mro__)
+ .intersection(
+ [
+ sql_types.Integer,
+ sql_types.Numeric,
+ sql_types.DateTime,
+ sql_types.Date,
+ sql_types.Time,
+ sql_types.String,
+ sql_types._Binary,
+ ]
+ )
+ )
+ > 0,
+ "%s(%s), %s(%s)"
+ % (col.name, col.type, cols[i]["name"], ctype),
+ )
+
+ if not col.primary_key:
+ assert cols[i]["default"] is None
+
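+    # A minimal sketch of the MRO-intersection idea used above, assuming a
+    # hypothetical backend that reflects a declared String(30) back as
+    # VARCHAR:
+    #
+    #     declared = sa.String(30).__class__      # String
+    #     reflected = sa.VARCHAR(30).__class__    # VARCHAR subclasses String
+    #     shared = set(declared.__mro__) & set(reflected.__mro__)
+    #     assert sql_types.String in shared
+    #
+    # i.e. the reflected and declared types only need to meet at one of the
+    # generic types; they do not need to match exactly.
+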
+ @testing.requires.temp_table_reflection
+ def test_get_temp_table_columns(self):
+ table_name = get_temp_table_name(
+ config, self.bind, "user_tmp_%s" % config.ident
+ )
+ user_tmp = self.tables[table_name]
+ insp = inspect(self.bind)
+ cols = insp.get_columns(table_name)
+ self.assert_(len(cols) > 0, len(cols))
+
+ for i, col in enumerate(user_tmp.columns):
+ eq_(col.name, cols[i]["name"])
+
+ @testing.requires.temp_table_reflection
+ @testing.requires.view_column_reflection
+ @testing.requires.temporary_views
+ def test_get_temp_view_columns(self):
+ insp = inspect(self.bind)
+ cols = insp.get_columns("user_tmp_v")
+ eq_([col["name"] for col in cols], ["id", "name", "foo"])
+
+ @testing.combinations(
+ (False,), (True, testing.requires.schemas), argnames="use_schema"
+ )
+ @testing.requires.primary_key_constraint_reflection
+ def test_get_pk_constraint(self, connection, use_schema):
+ if use_schema:
+ schema = testing.config.test_schema
+ else:
+ schema = None
+
+ users, addresses = self.tables.users, self.tables.email_addresses
+ insp = inspect(connection)
+
+ users_cons = insp.get_pk_constraint(users.name, schema=schema)
+ users_pkeys = users_cons["constrained_columns"]
+ eq_(users_pkeys, ["user_id"])
+
+ addr_cons = insp.get_pk_constraint(addresses.name, schema=schema)
+ addr_pkeys = addr_cons["constrained_columns"]
+ eq_(addr_pkeys, ["address_id"])
+
+ with testing.requires.reflects_pk_names.fail_if():
+ eq_(addr_cons["name"], "email_ad_pk")
+
+ @testing.combinations(
+ (False,), (True, testing.requires.schemas), argnames="use_schema"
+ )
+ @testing.requires.foreign_key_constraint_reflection
+ def test_get_foreign_keys(self, connection, use_schema):
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
+ users, addresses = (self.tables.users, self.tables.email_addresses)
+ insp = inspect(connection)
+ expected_schema = schema
+ # users
+
+ if testing.requires.self_referential_foreign_keys.enabled:
+ users_fkeys = insp.get_foreign_keys(users.name, schema=schema)
+ fkey1 = users_fkeys[0]
+
+ with testing.requires.named_constraints.fail_if():
+ eq_(fkey1["name"], "user_id_fk")
+
+ eq_(fkey1["referred_schema"], expected_schema)
+ eq_(fkey1["referred_table"], users.name)
+ eq_(fkey1["referred_columns"], ["user_id"])
+ if testing.requires.self_referential_foreign_keys.enabled:
+ eq_(fkey1["constrained_columns"], ["parent_user_id"])
+
+ # addresses
+ addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema)
+ fkey1 = addr_fkeys[0]
+
+ with testing.requires.implicitly_named_constraints.fail_if():
+ self.assert_(fkey1["name"] is not None)
+
+ eq_(fkey1["referred_schema"], expected_schema)
+ eq_(fkey1["referred_table"], users.name)
+ eq_(fkey1["referred_columns"], ["user_id"])
+ eq_(fkey1["constrained_columns"], ["remote_user_id"])
+
+ @testing.requires.cross_schema_fk_reflection
+ @testing.requires.schemas
+ def test_get_inter_schema_foreign_keys(self):
+ local_table, remote_table, remote_table_2 = self.tables(
+ "%s.local_table" % self.bind.dialect.default_schema_name,
+ "%s.remote_table" % testing.config.test_schema,
+ "%s.remote_table_2" % testing.config.test_schema,
+ )
+
+ insp = inspect(self.bind)
+
+ local_fkeys = insp.get_foreign_keys(local_table.name)
+ eq_(len(local_fkeys), 1)
+
+ fkey1 = local_fkeys[0]
+ eq_(fkey1["referred_schema"], testing.config.test_schema)
+ eq_(fkey1["referred_table"], remote_table_2.name)
+ eq_(fkey1["referred_columns"], ["id"])
+ eq_(fkey1["constrained_columns"], ["remote_id"])
+
+ remote_fkeys = insp.get_foreign_keys(
+ remote_table.name, schema=testing.config.test_schema
+ )
+ eq_(len(remote_fkeys), 1)
+
+ fkey2 = remote_fkeys[0]
+
+ assert fkey2["referred_schema"] in (
+ None,
+ self.bind.dialect.default_schema_name,
+ )
+ eq_(fkey2["referred_table"], local_table.name)
+ eq_(fkey2["referred_columns"], ["id"])
+ eq_(fkey2["constrained_columns"], ["local_id"])
+
+ def _assert_insp_indexes(self, indexes, expected_indexes):
+ index_names = [d["name"] for d in indexes]
+ for e_index in expected_indexes:
+ assert e_index["name"] in index_names
+ index = indexes[index_names.index(e_index["name"])]
+ for key in e_index:
+ eq_(e_index[key], index[key])
+
+ @testing.combinations(
+ (False,), (True, testing.requires.schemas), argnames="use_schema"
+ )
+ def test_get_indexes(self, connection, use_schema):
+
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
+ # The database may decide to create indexes for foreign keys, etc.
+ # so there may be more indexes than expected.
+ insp = inspect(self.bind)
+ indexes = insp.get_indexes("users", schema=schema)
+ expected_indexes = [
+ {
+ "unique": False,
+ "column_names": ["test1", "test2"],
+ "name": "users_t_idx",
+ },
+ {
+ "unique": False,
+ "column_names": ["user_id", "test2", "test1"],
+ "name": "users_all_idx",
+ },
+ ]
+ self._assert_insp_indexes(indexes, expected_indexes)
+
+ @testing.combinations(
+ ("noncol_idx_test_nopk", "noncol_idx_nopk"),
+ ("noncol_idx_test_pk", "noncol_idx_pk"),
+ argnames="tname,ixname",
+ )
+ @testing.requires.index_reflection
+ @testing.requires.indexes_with_ascdesc
+ def test_get_noncol_index(self, connection, tname, ixname):
+ insp = inspect(connection)
+ indexes = insp.get_indexes(tname)
+
+        # reflecting an index that has "x DESC" in it as the column.
+        # the DB may or may not give us "x", but make sure we get the index
+        # back, that it has a name, and that it's connected to the table.
+ expected_indexes = [{"unique": False, "name": ixname}]
+ self._assert_insp_indexes(indexes, expected_indexes)
+
+ t = Table(tname, MetaData(), autoload_with=connection)
+ eq_(len(t.indexes), 1)
+ is_(list(t.indexes)[0].table, t)
+ eq_(list(t.indexes)[0].name, ixname)
+
+ @testing.requires.temp_table_reflection
+ @testing.requires.unique_constraint_reflection
+ def test_get_temp_table_unique_constraints(self):
+ insp = inspect(self.bind)
+ reflected = insp.get_unique_constraints("user_tmp_%s" % config.ident)
+ for refl in reflected:
+ # Different dialects handle duplicate index and constraints
+ # differently, so ignore this flag
+ refl.pop("duplicates_index", None)
+ eq_(
+ reflected,
+ [
+ {
+ "column_names": ["name"],
+ "name": "user_tmp_uq_%s" % config.ident,
+ }
+ ],
+ )
+
+ @testing.requires.temp_table_reflect_indexes
+ def test_get_temp_table_indexes(self):
+ insp = inspect(self.bind)
+ table_name = get_temp_table_name(
+ config, config.db, "user_tmp_%s" % config.ident
+ )
+ indexes = insp.get_indexes(table_name)
+ for ind in indexes:
+ ind.pop("dialect_options", None)
+ expected = [
+ {"unique": False, "column_names": ["foo"], "name": "user_tmp_ix"}
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ eq_(
+ [idx for idx in indexes if idx["name"] == "user_tmp_ix"],
+ expected,
+ )
+
+ @testing.combinations(
+ (True, testing.requires.schemas), (False,), argnames="use_schema"
+ )
+ @testing.requires.unique_constraint_reflection
+ def test_get_unique_constraints(self, metadata, connection, use_schema):
+        # the SQLite dialect needs to parse the names of the constraints
+        # separately from what it gets from PRAGMA index_list(), and then
+        # match them up, so the same set of column_names appearing in two
+        # constraints will confuse it.  Perhaps we should no longer bother
+        # with index_list() here since we have the whole CREATE TABLE?
+
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+ uniques = sorted(
+ [
+ {"name": "unique_a", "column_names": ["a"]},
+ {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]},
+ {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]},
+ {"name": "unique_asc_key", "column_names": ["asc", "key"]},
+ {"name": "i.have.dots", "column_names": ["b"]},
+ {"name": "i have spaces", "column_names": ["c"]},
+ ],
+ key=operator.itemgetter("name"),
+ )
+ table = Table(
+ "testtbl",
+ metadata,
+ Column("a", sa.String(20)),
+ Column("b", sa.String(30)),
+ Column("c", sa.Integer),
+ # reserved identifiers
+ Column("asc", sa.String(30)),
+ Column("key", sa.String(30)),
+ schema=schema,
+ )
+ for uc in uniques:
+ table.append_constraint(
+ sa.UniqueConstraint(*uc["column_names"], name=uc["name"])
+ )
+ table.create(connection)
+
+ inspector = inspect(connection)
+ reflected = sorted(
+ inspector.get_unique_constraints("testtbl", schema=schema),
+ key=operator.itemgetter("name"),
+ )
+
+ names_that_duplicate_index = set()
+
+ for orig, refl in zip(uniques, reflected):
+ # Different dialects handle duplicate index and constraints
+ # differently, so ignore this flag
+ dupe = refl.pop("duplicates_index", None)
+ if dupe:
+ names_that_duplicate_index.add(dupe)
+ eq_(orig, refl)
+
+ reflected_metadata = MetaData()
+ reflected = Table(
+ "testtbl",
+ reflected_metadata,
+ autoload_with=connection,
+ schema=schema,
+ )
+
+ # test "deduplicates for index" logic. MySQL and Oracle
+ # "unique constraints" are actually unique indexes (with possible
+ # exception of a unique that is a dupe of another one in the case
+ # of Oracle). make sure # they aren't duplicated.
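+        #
+        # e.g. on a MySQL-like backend, reflecting "unique_a" could plausibly
+        # report both of the following (a hedged sketch, not captured output):
+        #
+        #     insp.get_indexes("testtbl")
+        #     # -> [{"name": "unique_a", "column_names": ["a"],
+        #     #      "unique": True}]
+        #     insp.get_unique_constraints("testtbl")
+        #     # -> [{"name": "unique_a", "column_names": ["a"],
+        #     #      "duplicates_index": "unique_a"}]
+        #
+        # and Table reflection should keep only one of the two.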
+ idx_names = set([idx.name for idx in reflected.indexes])
+ uq_names = set(
+ [
+ uq.name
+ for uq in reflected.constraints
+ if isinstance(uq, sa.UniqueConstraint)
+ ]
+ ).difference(["unique_c_a_b"])
+
+ assert not idx_names.intersection(uq_names)
+ if names_that_duplicate_index:
+ eq_(names_that_duplicate_index, idx_names)
+ eq_(uq_names, set())
+
+ @testing.requires.view_reflection
+ @testing.combinations(
+ (False,), (True, testing.requires.schemas), argnames="use_schema"
+ )
+ def test_get_view_definition(self, connection, use_schema):
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+ view_name1 = "users_v"
+ view_name2 = "email_addresses_v"
+ insp = inspect(connection)
+ v1 = insp.get_view_definition(view_name1, schema=schema)
+ self.assert_(v1)
+ v2 = insp.get_view_definition(view_name2, schema=schema)
+ self.assert_(v2)
+
+    # why is this here if it's PG specific?
+ @testing.combinations(
+ ("users", False),
+ ("users", True, testing.requires.schemas),
+ argnames="table_name,use_schema",
+ )
+ @testing.only_on("postgresql", "PG specific feature")
+ def test_get_table_oid(self, connection, table_name, use_schema):
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+ insp = inspect(connection)
+ oid = insp.get_table_oid(table_name, schema)
+ self.assert_(isinstance(oid, int))
+
+ @testing.requires.table_reflection
+ def test_autoincrement_col(self):
+ """test that 'autoincrement' is reflected according to sqla's policy.
+
+        Don't mark this test as unsupported for any backend!
+
+ (technically it fails with MySQL InnoDB since "id" comes before "id2")
+
+ A backend is better off not returning "autoincrement" at all,
+ instead of potentially returning "False" for an auto-incrementing
+ primary key column.
+
+ """
+
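+        # To illustrate the policy (a sketch, not any one backend's exact
+        # output): for an integer primary key such as users.user_id, either
+        #
+        #     {"name": "user_id", ..., "autoincrement": True}
+        #     {"name": "user_id", ...}        # key omitted entirely
+        #
+        # passes, while returning "autoincrement": False for such a column
+        # fails the assertion below.
+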
+ insp = inspect(self.bind)
+
+ for tname, cname in [
+ ("users", "user_id"),
+ ("email_addresses", "address_id"),
+ ("dingalings", "dingaling_id"),
+ ]:
+ cols = insp.get_columns(tname)
+ id_ = {c["name"]: c for c in cols}[cname]
+ assert id_.get("autoincrement", True)
+
+
+class TableNoColumnsTest(fixtures.TestBase):
+ __requires__ = ("reflect_tables_no_columns",)
+ __backend__ = True
+
+ @testing.fixture
+ def table_no_columns(self, connection, metadata):
+ Table("empty", metadata)
+ metadata.create_all(connection)
+
+ @testing.fixture
+ def view_no_columns(self, connection, metadata):
+ Table("empty", metadata)
+ metadata.create_all(connection)
+
+ Table("empty", metadata)
+ event.listen(
+ metadata,
+ "after_create",
+ DDL("CREATE VIEW empty_v AS SELECT * FROM empty"),
+ )
+
+ # for transactional DDL the transaction is rolled back before this
+ # drop statement is invoked
+ event.listen(
+ metadata, "before_drop", DDL("DROP VIEW IF EXISTS empty_v")
+ )
+ metadata.create_all(connection)
+
+ @testing.requires.reflect_tables_no_columns
+ def test_reflect_table_no_columns(self, connection, table_no_columns):
+ t2 = Table("empty", MetaData(), autoload_with=connection)
+ eq_(list(t2.c), [])
+
+ @testing.requires.reflect_tables_no_columns
+ def test_get_columns_table_no_columns(self, connection, table_no_columns):
+ eq_(inspect(connection).get_columns("empty"), [])
+
+ @testing.requires.reflect_tables_no_columns
+ def test_reflect_incl_table_no_columns(self, connection, table_no_columns):
+ m = MetaData()
+ m.reflect(connection)
+ assert set(m.tables).intersection(["empty"])
+
+ @testing.requires.views
+ @testing.requires.reflect_tables_no_columns
+ def test_reflect_view_no_columns(self, connection, view_no_columns):
+ t2 = Table("empty_v", MetaData(), autoload_with=connection)
+ eq_(list(t2.c), [])
+
+ @testing.requires.views
+ @testing.requires.reflect_tables_no_columns
+ def test_get_columns_view_no_columns(self, connection, view_no_columns):
+ eq_(inspect(connection).get_columns("empty_v"), [])
+
+
+class ComponentReflectionTestExtra(fixtures.TestBase):
+
+ __backend__ = True
+
+ @testing.combinations(
+ (True, testing.requires.schemas), (False,), argnames="use_schema"
+ )
+ @testing.requires.check_constraint_reflection
+ def test_get_check_constraints(self, metadata, connection, use_schema):
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
+ Table(
+ "sa_cc",
+ metadata,
+ Column("a", Integer()),
+ sa.CheckConstraint("a > 1 AND a < 5", name="cc1"),
+ sa.CheckConstraint(
+ "a = 1 OR (a > 2 AND a < 5)", name="UsesCasing"
+ ),
+ schema=schema,
+ )
+
+ metadata.create_all(connection)
+
+ inspector = inspect(connection)
+ reflected = sorted(
+ inspector.get_check_constraints("sa_cc", schema=schema),
+ key=operator.itemgetter("name"),
+ )
+
+ # trying to minimize effect of quoting, parenthesis, etc.
+ # may need to add more to this as new dialects get CHECK
+ # constraint reflection support
+ def normalize(sqltext):
+ return " ".join(
+ re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I)
+ )
+
+ reflected = [
+ {"name": item["name"], "sqltext": normalize(item["sqltext"])}
+ for item in reflected
+ ]
+ eq_(
+ reflected,
+ [
+ {"name": "UsesCasing", "sqltext": "a = 1 or a > 2 and a < 5"},
+ {"name": "cc1", "sqltext": "a > 1 and a < 5"},
+ ],
+ )
+
+ @testing.requires.indexes_with_expressions
+ def test_reflect_expression_based_indexes(self, metadata, connection):
+ t = Table(
+ "t",
+ metadata,
+ Column("x", String(30)),
+ Column("y", String(30)),
+ )
+
+ Index("t_idx", func.lower(t.c.x), func.lower(t.c.y))
+
+ Index("t_idx_2", t.c.x)
+
+ metadata.create_all(connection)
+
+ insp = inspect(connection)
+
+ expected = [
+ {"name": "t_idx_2", "column_names": ["x"], "unique": False}
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ expected[0]["dialect_options"] = {
+ "%s_include" % connection.engine.name: []
+ }
+
+ with expect_warnings(
+ "Skipped unsupported reflection of expression-based index t_idx"
+ ):
+ eq_(
+ insp.get_indexes("t"),
+ expected,
+ )
+
+ @testing.requires.index_reflects_included_columns
+ def test_reflect_covering_index(self, metadata, connection):
+ t = Table(
+ "t",
+ metadata,
+ Column("x", String(30)),
+ Column("y", String(30)),
+ )
+ idx = Index("t_idx", t.c.x)
+ idx.dialect_options[connection.engine.name]["include"] = ["y"]
+
+ metadata.create_all(connection)
+
+ insp = inspect(connection)
+
+ eq_(
+ insp.get_indexes("t"),
+ [
+ {
+ "name": "t_idx",
+ "column_names": ["x"],
+ "include_columns": ["y"],
+ "unique": False,
+ "dialect_options": {
+ "%s_include" % connection.engine.name: ["y"]
+ },
+ }
+ ],
+ )
+
+ t2 = Table("t", MetaData(), autoload_with=connection)
+ eq_(
+ list(t2.indexes)[0].dialect_options[connection.engine.name][
+ "include"
+ ],
+ ["y"],
+ )
+
+ def _type_round_trip(self, connection, metadata, *types):
+ t = Table(
+ "t",
+ metadata,
+ *[Column("t%d" % i, type_) for i, type_ in enumerate(types)]
+ )
+ t.create(connection)
+
+ return [c["type"] for c in inspect(connection).get_columns("t")]
+
+ @testing.requires.table_reflection
+ def test_numeric_reflection(self, connection, metadata):
+ for typ in self._type_round_trip(
+ connection, metadata, sql_types.Numeric(18, 5)
+ ):
+ assert isinstance(typ, sql_types.Numeric)
+ eq_(typ.precision, 18)
+ eq_(typ.scale, 5)
+
+ @testing.requires.table_reflection
+ def test_varchar_reflection(self, connection, metadata):
+ typ = self._type_round_trip(
+ connection, metadata, sql_types.String(52)
+ )[0]
+ assert isinstance(typ, sql_types.String)
+ eq_(typ.length, 52)
+
+ @testing.requires.table_reflection
+ def test_nullable_reflection(self, connection, metadata):
+ t = Table(
+ "t",
+ metadata,
+ Column("a", Integer, nullable=True),
+ Column("b", Integer, nullable=False),
+ )
+ t.create(connection)
+ eq_(
+ dict(
+ (col["name"], col["nullable"])
+ for col in inspect(connection).get_columns("t")
+ ),
+ {"a": True, "b": False},
+ )
+
+ @testing.combinations(
+ (
+ None,
+ "CASCADE",
+ None,
+ testing.requires.foreign_key_constraint_option_reflection_ondelete,
+ ),
+ (
+ None,
+ None,
+ "SET NULL",
+ testing.requires.foreign_key_constraint_option_reflection_onupdate,
+ ),
+ (
+ {},
+ None,
+ "NO ACTION",
+ testing.requires.foreign_key_constraint_option_reflection_onupdate,
+ ),
+ (
+ {},
+ "NO ACTION",
+ None,
+ testing.requires.fk_constraint_option_reflection_ondelete_noaction,
+ ),
+ (
+ None,
+ None,
+ "RESTRICT",
+ testing.requires.fk_constraint_option_reflection_onupdate_restrict,
+ ),
+ (
+ None,
+ "RESTRICT",
+ None,
+ testing.requires.fk_constraint_option_reflection_ondelete_restrict,
+ ),
+ argnames="expected,ondelete,onupdate",
+ )
+ def test_get_foreign_key_options(
+ self, connection, metadata, expected, ondelete, onupdate
+ ):
+ options = {}
+ if ondelete:
+ options["ondelete"] = ondelete
+ if onupdate:
+ options["onupdate"] = onupdate
+
+ if expected is None:
+ expected = options
+
+ Table(
+ "x",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ test_needs_fk=True,
+ )
+
+ Table(
+ "table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")),
+ Column("test", String(10)),
+ test_needs_fk=True,
+ )
+
+ Table(
+ "user",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("tid", Integer),
+ sa.ForeignKeyConstraint(
+ ["tid"], ["table.id"], name="myfk", **options
+ ),
+ test_needs_fk=True,
+ )
+
+ metadata.create_all(connection)
+
+ insp = inspect(connection)
+
+ # test 'options' is always present for a backend
+ # that can reflect these, since alembic looks for this
+ opts = insp.get_foreign_keys("table")[0]["options"]
+
+ eq_(dict((k, opts[k]) for k in opts if opts[k]), {})
+
+ opts = insp.get_foreign_keys("user")[0]["options"]
+ eq_(opts, expected)
+ # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected)
+
+
+class NormalizedNameTest(fixtures.TablesTest):
+ __requires__ = ("denormalized_names",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ quoted_name("t1", quote=True),
+ metadata,
+ Column("id", Integer, primary_key=True),
+ )
+ Table(
+ quoted_name("t2", quote=True),
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("t1id", ForeignKey("t1.id")),
+ )
+
+ def test_reflect_lowercase_forced_tables(self):
+
+ m2 = MetaData()
+ t2_ref = Table(
+ quoted_name("t2", quote=True), m2, autoload_with=config.db
+ )
+ t1_ref = m2.tables["t1"]
+ assert t2_ref.c.t1id.references(t1_ref.c.id)
+
+ m3 = MetaData()
+ m3.reflect(
+ config.db, only=lambda name, m: name.lower() in ("t1", "t2")
+ )
+ assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id)
+
+ def test_get_table_names(self):
+ tablenames = [
+ t
+ for t in inspect(config.db).get_table_names()
+ if t.lower() in ("t1", "t2")
+ ]
+
+ eq_(tablenames[0].upper(), tablenames[0].lower())
+ eq_(tablenames[1].upper(), tablenames[1].lower())
+
+
+class ComputedReflectionTest(fixtures.ComputedReflectionFixtureTest):
+ def test_computed_col_default_not_set(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("computed_default_table")
+ col_data = {c["name"]: c for c in cols}
+ is_true("42" in col_data["with_default"]["default"])
+ is_(col_data["normal"]["default"], None)
+ is_(col_data["computed_col"]["default"], None)
+
+ def test_get_column_returns_computed(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("computed_default_table")
+ data = {c["name"]: c for c in cols}
+ for key in ("id", "normal", "with_default"):
+ is_true("computed" not in data[key])
+ compData = data["computed_col"]
+ is_true("computed" in compData)
+ is_true("sqltext" in compData["computed"])
+ eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42")
+ eq_(
+ "persisted" in compData["computed"],
+ testing.requires.computed_columns_reflect_persisted.enabled,
+ )
+ if testing.requires.computed_columns_reflect_persisted.enabled:
+ eq_(
+ compData["computed"]["persisted"],
+ testing.requires.computed_columns_default_persisted.enabled,
+ )
+
+ def check_column(self, data, column, sqltext, persisted):
+ is_true("computed" in data[column])
+ compData = data[column]["computed"]
+ eq_(self.normalize(compData["sqltext"]), sqltext)
+ if testing.requires.computed_columns_reflect_persisted.enabled:
+ is_true("persisted" in compData)
+ is_(compData["persisted"], persisted)
+
+ def test_get_column_returns_persisted(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("computed_column_table")
+ data = {c["name"]: c for c in cols}
+
+ self.check_column(
+ data,
+ "computed_no_flag",
+ "normal+42",
+ testing.requires.computed_columns_default_persisted.enabled,
+ )
+ if testing.requires.computed_columns_virtual.enabled:
+ self.check_column(
+ data,
+ "computed_virtual",
+ "normal+2",
+ False,
+ )
+ if testing.requires.computed_columns_stored.enabled:
+ self.check_column(
+ data,
+ "computed_stored",
+ "normal-42",
+ True,
+ )
+
+ @testing.requires.schemas
+ def test_get_column_returns_persisted_with_schema(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns(
+ "computed_column_table", schema=config.test_schema
+ )
+ data = {c["name"]: c for c in cols}
+
+ self.check_column(
+ data,
+ "computed_no_flag",
+ "normal/42",
+ testing.requires.computed_columns_default_persisted.enabled,
+ )
+ if testing.requires.computed_columns_virtual.enabled:
+ self.check_column(
+ data,
+ "computed_virtual",
+ "normal/2",
+ False,
+ )
+ if testing.requires.computed_columns_stored.enabled:
+ self.check_column(
+ data,
+ "computed_stored",
+ "normal*42",
+ True,
+ )
+
+
+class IdentityReflectionTest(fixtures.TablesTest):
+ run_inserts = run_deletes = None
+
+ __backend__ = True
+ __requires__ = ("identity_columns", "table_reflection")
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "t1",
+ metadata,
+ Column("normal", Integer),
+ Column("id1", Integer, Identity()),
+ )
+ Table(
+ "t2",
+ metadata,
+ Column(
+ "id2",
+ Integer,
+ Identity(
+ always=True,
+ start=2,
+ increment=3,
+ minvalue=-2,
+ maxvalue=42,
+ cycle=True,
+ cache=4,
+ ),
+ ),
+ )
+ if testing.requires.schemas.enabled:
+ Table(
+ "t1",
+ metadata,
+ Column("normal", Integer),
+ Column("id1", Integer, Identity(always=True, start=20)),
+ schema=config.test_schema,
+ )
+
+ def check(self, value, exp, approx):
+ if testing.requires.identity_columns_standard.enabled:
+ common_keys = (
+ "always",
+ "start",
+ "increment",
+ "minvalue",
+ "maxvalue",
+ "cycle",
+ "cache",
+ )
+ for k in list(value):
+ if k not in common_keys:
+ value.pop(k)
+ if approx:
+ eq_(len(value), len(exp))
+ for k in value:
+ if k == "minvalue":
+ is_true(value[k] <= exp[k])
+ elif k in {"maxvalue", "cache"}:
+ is_true(value[k] >= exp[k])
+ else:
+ eq_(value[k], exp[k], k)
+ else:
+ eq_(value, exp)
+ else:
+ eq_(value["start"], exp["start"])
+ eq_(value["increment"], exp["increment"])
+
+ def test_reflect_identity(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("t1") + insp.get_columns("t2")
+ for col in cols:
+ if col["name"] == "normal":
+ is_false("identity" in col)
+ elif col["name"] == "id1":
+ is_true(col["autoincrement"] in (True, "auto"))
+ eq_(col["default"], None)
+ is_true("identity" in col)
+ self.check(
+ col["identity"],
+ dict(
+ always=False,
+ start=1,
+ increment=1,
+ minvalue=1,
+ maxvalue=2147483647,
+ cycle=False,
+ cache=1,
+ ),
+ approx=True,
+ )
+ elif col["name"] == "id2":
+ is_true(col["autoincrement"] in (True, "auto"))
+ eq_(col["default"], None)
+ is_true("identity" in col)
+ self.check(
+ col["identity"],
+ dict(
+ always=True,
+ start=2,
+ increment=3,
+ minvalue=-2,
+ maxvalue=42,
+ cycle=True,
+ cache=4,
+ ),
+ approx=False,
+ )
+
+ @testing.requires.schemas
+ def test_reflect_identity_schema(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("t1", schema=config.test_schema)
+ for col in cols:
+ if col["name"] == "normal":
+ is_false("identity" in col)
+ elif col["name"] == "id1":
+ is_true(col["autoincrement"] in (True, "auto"))
+ eq_(col["default"], None)
+ is_true("identity" in col)
+ self.check(
+ col["identity"],
+ dict(
+ always=True,
+ start=20,
+ increment=1,
+ minvalue=1,
+ maxvalue=2147483647,
+ cycle=False,
+ cache=1,
+ ),
+ approx=True,
+ )
+
+
+class CompositeKeyReflectionTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ tb1 = Table(
+ "tb1",
+ metadata,
+ Column("id", Integer),
+ Column("attr", Integer),
+ Column("name", sql_types.VARCHAR(20)),
+ sa.PrimaryKeyConstraint("name", "id", "attr", name="pk_tb1"),
+ schema=None,
+ test_needs_fk=True,
+ )
+ Table(
+ "tb2",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("pid", Integer),
+ Column("pattr", Integer),
+ Column("pname", sql_types.VARCHAR(20)),
+ sa.ForeignKeyConstraint(
+ ["pname", "pid", "pattr"],
+ [tb1.c.name, tb1.c.id, tb1.c.attr],
+ name="fk_tb1_name_id_attr",
+ ),
+ schema=None,
+ test_needs_fk=True,
+ )
+
+ @testing.requires.primary_key_constraint_reflection
+ def test_pk_column_order(self):
+ # test for issue #5661
+ insp = inspect(self.bind)
+ primary_key = insp.get_pk_constraint(self.tables.tb1.name)
+ eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"])
+
+ @testing.requires.foreign_key_constraint_reflection
+ def test_fk_column_order(self):
+ # test for issue #5661
+ insp = inspect(self.bind)
+ foreign_keys = insp.get_foreign_keys(self.tables.tb2.name)
+ eq_(len(foreign_keys), 1)
+ fkey1 = foreign_keys[0]
+ eq_(fkey1.get("referred_columns"), ["name", "id", "attr"])
+ eq_(fkey1.get("constrained_columns"), ["pname", "pid", "pattr"])
+
+
+__all__ = (
+ "ComponentReflectionTest",
+ "ComponentReflectionTestExtra",
+ "TableNoColumnsTest",
+ "QuotedNameArgumentTest",
+ "HasTableTest",
+ "HasIndexTest",
+ "NormalizedNameTest",
+ "ComputedReflectionTest",
+ "IdentityReflectionTest",
+ "CompositeKeyReflectionTest",
+)
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
new file mode 100644
index 0000000..c41a550
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_results.py
@@ -0,0 +1,426 @@
+import datetime
+
+from .. import engines
+from .. import fixtures
+from ..assertions import eq_
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import DateTime
+from ... import func
+from ... import Integer
+from ... import select
+from ... import sql
+from ... import String
+from ... import testing
+from ... import text
+from ... import util
+
+
+class RowFetchTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "plain_pk",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+ Table(
+ "has_dates",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("today", DateTime),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.plain_pk.insert(),
+ [
+ {"id": 1, "data": "d1"},
+ {"id": 2, "data": "d2"},
+ {"id": 3, "data": "d3"},
+ ],
+ )
+
+ connection.execute(
+ cls.tables.has_dates.insert(),
+ [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}],
+ )
+
+ def test_via_attr(self, connection):
+ row = connection.execute(
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+ ).first()
+
+ eq_(row.id, 1)
+ eq_(row.data, "d1")
+
+ def test_via_string(self, connection):
+ row = connection.execute(
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+ ).first()
+
+ eq_(row._mapping["id"], 1)
+ eq_(row._mapping["data"], "d1")
+
+ def test_via_int(self, connection):
+ row = connection.execute(
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+ ).first()
+
+ eq_(row[0], 1)
+ eq_(row[1], "d1")
+
+ def test_via_col_object(self, connection):
+ row = connection.execute(
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+ ).first()
+
+ eq_(row._mapping[self.tables.plain_pk.c.id], 1)
+ eq_(row._mapping[self.tables.plain_pk.c.data], "d1")
+
+ @requirements.duplicate_names_in_cursor_description
+ def test_row_with_dupe_names(self, connection):
+ result = connection.execute(
+ select(
+ self.tables.plain_pk.c.data,
+ self.tables.plain_pk.c.data.label("data"),
+ ).order_by(self.tables.plain_pk.c.id)
+ )
+ row = result.first()
+ eq_(result.keys(), ["data", "data"])
+ eq_(row, ("d1", "d1"))
+
+ def test_row_w_scalar_select(self, connection):
+ """test that a scalar select as a column is returned as such
+ and that type conversion works OK.
+
+ (this is half a SQLAlchemy Core test and half to catch database
+ backends that may have unusual behavior with scalar selects.)
+
+ """
+ datetable = self.tables.has_dates
+ s = select(datetable.alias("x").c.today).scalar_subquery()
+ s2 = select(datetable.c.id, s.label("somelabel"))
+ row = connection.execute(s2).first()
+
+ eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0))
+
+
+class PercentSchemaNamesTest(fixtures.TablesTest):
+ """tests using percent signs, spaces in table and column names.
+
+ This didn't work for PostgreSQL / MySQL drivers for a long time
+ but is now supported.
+
+ """
+
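+    # Why this was historically hard (a hedged note): these drivers use the
+    # "format" / "pyformat" paramstyles, so a literal "%" inside an
+    # identifier must reach the DBAPI escaped as "%%".  A rough sketch of
+    # the compiled statement for the table below:
+    #
+    #     SELECT "percent%%table"."percent%%" FROM "percent%%table"
+    #
+    # (quoting and escaping details vary per dialect; illustrative only.)
+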
+ __requires__ = ("percent_schema_names",)
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ cls.tables.percent_table = Table(
+ "percent%table",
+ metadata,
+ Column("percent%", Integer),
+ Column("spaces % more spaces", Integer),
+ )
+ cls.tables.lightweight_percent_table = sql.table(
+ "percent%table",
+ sql.column("percent%"),
+ sql.column("spaces % more spaces"),
+ )
+
+ def test_single_roundtrip(self, connection):
+ percent_table = self.tables.percent_table
+ for params in [
+ {"percent%": 5, "spaces % more spaces": 12},
+ {"percent%": 7, "spaces % more spaces": 11},
+ {"percent%": 9, "spaces % more spaces": 10},
+ {"percent%": 11, "spaces % more spaces": 9},
+ ]:
+ connection.execute(percent_table.insert(), params)
+ self._assert_table(connection)
+
+ def test_executemany_roundtrip(self, connection):
+ percent_table = self.tables.percent_table
+ connection.execute(
+ percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
+ )
+ connection.execute(
+ percent_table.insert(),
+ [
+ {"percent%": 7, "spaces % more spaces": 11},
+ {"percent%": 9, "spaces % more spaces": 10},
+ {"percent%": 11, "spaces % more spaces": 9},
+ ],
+ )
+ self._assert_table(connection)
+
+ def _assert_table(self, conn):
+ percent_table = self.tables.percent_table
+ lightweight_percent_table = self.tables.lightweight_percent_table
+
+ for table in (
+ percent_table,
+ percent_table.alias(),
+ lightweight_percent_table,
+ lightweight_percent_table.alias(),
+ ):
+ eq_(
+ list(
+ conn.execute(table.select().order_by(table.c["percent%"]))
+ ),
+ [(5, 12), (7, 11), (9, 10), (11, 9)],
+ )
+
+ eq_(
+ list(
+ conn.execute(
+ table.select()
+ .where(table.c["spaces % more spaces"].in_([9, 10]))
+ .order_by(table.c["percent%"])
+ )
+ ),
+ [(9, 10), (11, 9)],
+ )
+
+ row = conn.execute(
+ table.select().order_by(table.c["percent%"])
+ ).first()
+ eq_(row._mapping["percent%"], 5)
+ eq_(row._mapping["spaces % more spaces"], 12)
+
+ eq_(row._mapping[table.c["percent%"]], 5)
+ eq_(row._mapping[table.c["spaces % more spaces"]], 12)
+
+ conn.execute(
+ percent_table.update().values(
+ {percent_table.c["spaces % more spaces"]: 15}
+ )
+ )
+
+ eq_(
+ list(
+ conn.execute(
+ percent_table.select().order_by(
+ percent_table.c["percent%"]
+ )
+ )
+ ),
+ [(5, 15), (7, 15), (9, 15), (11, 15)],
+ )
+
+
+class ServerSideCursorsTest(
+ fixtures.TestBase, testing.AssertsExecutionResults
+):
+
+ __requires__ = ("server_side_cursors",)
+
+ __backend__ = True
+
+ def _is_server_side(self, cursor):
+ # TODO: this is a huge issue as it prevents these tests from being
+ # usable by third party dialects.
+ if self.engine.dialect.driver == "psycopg2":
+ return bool(cursor.name)
+ elif self.engine.dialect.driver == "pymysql":
+ sscursor = __import__("pymysql.cursors").cursors.SSCursor
+ return isinstance(cursor, sscursor)
+ elif self.engine.dialect.driver in ("aiomysql", "asyncmy"):
+ return cursor.server_side
+ elif self.engine.dialect.driver == "mysqldb":
+ sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
+ return isinstance(cursor, sscursor)
+ elif self.engine.dialect.driver == "mariadbconnector":
+ return not cursor.buffered
+ elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"):
+ return cursor.server_side
+ elif self.engine.dialect.driver == "pg8000":
+ return getattr(cursor, "server_side", False)
+ else:
+ return False
+
+ def _fixture(self, server_side_cursors):
+ if server_side_cursors:
+ with testing.expect_deprecated(
+ "The create_engine.server_side_cursors parameter is "
+ "deprecated and will be removed in a future release. "
+ "Please use the Connection.execution_options.stream_results "
+ "parameter."
+ ):
+ self.engine = engines.testing_engine(
+ options={"server_side_cursors": server_side_cursors}
+ )
+ else:
+ self.engine = engines.testing_engine(
+ options={"server_side_cursors": server_side_cursors}
+ )
+ return self.engine
+
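+    # For reference, the non-deprecated spelling (a sketch using the public
+    # per-statement execution option rather than the engine-wide flag):
+    #
+    #     with engine.connect() as conn:
+    #         result = conn.execution_options(stream_results=True).execute(
+    #             text("select 1")
+    #         )
+    #
+    # which requests a server-side / streaming cursor for that statement only.
+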
+ @testing.combinations(
+ ("global_string", True, "select 1", True),
+ ("global_text", True, text("select 1"), True),
+ ("global_expr", True, select(1), True),
+ ("global_off_explicit", False, text("select 1"), False),
+ (
+ "stmt_option",
+ False,
+ select(1).execution_options(stream_results=True),
+ True,
+ ),
+ (
+ "stmt_option_disabled",
+ True,
+ select(1).execution_options(stream_results=False),
+ False,
+ ),
+ ("for_update_expr", True, select(1).with_for_update(), True),
+        # TODO: need a real requirement for this, or don't use this test
+ (
+ "for_update_string",
+ True,
+ "SELECT 1 FOR UPDATE",
+ True,
+ testing.skip_if("sqlite"),
+ ),
+ ("text_no_ss", False, text("select 42"), False),
+ (
+ "text_ss_option",
+ False,
+ text("select 42").execution_options(stream_results=True),
+ True,
+ ),
+ id_="iaaa",
+ argnames="engine_ss_arg, statement, cursor_ss_status",
+ )
+ def test_ss_cursor_status(
+ self, engine_ss_arg, statement, cursor_ss_status
+ ):
+ engine = self._fixture(engine_ss_arg)
+ with engine.begin() as conn:
+ if isinstance(statement, util.string_types):
+ result = conn.exec_driver_sql(statement)
+ else:
+ result = conn.execute(statement)
+ eq_(self._is_server_side(result.cursor), cursor_ss_status)
+ result.close()
+
+ def test_conn_option(self):
+ engine = self._fixture(False)
+
+ with engine.connect() as conn:
+ # should be enabled for this one
+ result = conn.execution_options(
+ stream_results=True
+ ).exec_driver_sql("select 1")
+ assert self._is_server_side(result.cursor)
+
+ def test_stmt_enabled_conn_option_disabled(self):
+ engine = self._fixture(False)
+
+ s = select(1).execution_options(stream_results=True)
+
+ with engine.connect() as conn:
+ # not this one
+ result = conn.execution_options(stream_results=False).execute(s)
+ assert not self._is_server_side(result.cursor)
+
+ def test_aliases_and_ss(self):
+ engine = self._fixture(False)
+ s1 = (
+ select(sql.literal_column("1").label("x"))
+ .execution_options(stream_results=True)
+ .subquery()
+ )
+
+ # options don't propagate out when subquery is used as a FROM clause
+ with engine.begin() as conn:
+ result = conn.execute(s1.select())
+ assert not self._is_server_side(result.cursor)
+ result.close()
+
+ s2 = select(1).select_from(s1)
+ with engine.begin() as conn:
+ result = conn.execute(s2)
+ assert not self._is_server_side(result.cursor)
+ result.close()
+
+ def test_roundtrip_fetchall(self, metadata):
+ md = self.metadata
+
+ engine = self._fixture(True)
+ test_table = Table(
+ "test_table",
+ md,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ with engine.begin() as connection:
+ test_table.create(connection, checkfirst=True)
+ connection.execute(test_table.insert(), dict(data="data1"))
+ connection.execute(test_table.insert(), dict(data="data2"))
+ eq_(
+ connection.execute(
+ test_table.select().order_by(test_table.c.id)
+ ).fetchall(),
+ [(1, "data1"), (2, "data2")],
+ )
+ connection.execute(
+ test_table.update()
+ .where(test_table.c.id == 2)
+ .values(data=test_table.c.data + " updated")
+ )
+ eq_(
+ connection.execute(
+ test_table.select().order_by(test_table.c.id)
+ ).fetchall(),
+ [(1, "data1"), (2, "data2 updated")],
+ )
+ connection.execute(test_table.delete())
+ eq_(
+ connection.scalar(
+ select(func.count("*")).select_from(test_table)
+ ),
+ 0,
+ )
+
+ def test_roundtrip_fetchmany(self, metadata):
+ md = self.metadata
+
+ engine = self._fixture(True)
+ test_table = Table(
+ "test_table",
+ md,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ with engine.begin() as connection:
+ test_table.create(connection, checkfirst=True)
+ connection.execute(
+ test_table.insert(),
+ [dict(data="data%d" % i) for i in range(1, 20)],
+ )
+
+ result = connection.execute(
+ test_table.select().order_by(test_table.c.id)
+ )
+
+ eq_(
+ result.fetchmany(5),
+ [(i, "data%d" % i) for i in range(1, 6)],
+ )
+ eq_(
+ result.fetchmany(10),
+ [(i, "data%d" % i) for i in range(6, 16)],
+ )
+ eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)])
diff --git a/lib/sqlalchemy/testing/suite/test_rowcount.py b/lib/sqlalchemy/testing/suite/test_rowcount.py
new file mode 100644
index 0000000..82e831f
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_rowcount.py
@@ -0,0 +1,165 @@
+from sqlalchemy import bindparam
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import testing
+from sqlalchemy import text
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+
+
+class RowCountTest(fixtures.TablesTest):
+ """test rowcount functionality"""
+
+ __requires__ = ("sane_rowcount",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "employees",
+ metadata,
+ Column(
+ "employee_id",
+ Integer,
+ autoincrement=False,
+ primary_key=True,
+ ),
+ Column("name", String(50)),
+ Column("department", String(1)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ cls.data = data = [
+ ("Angela", "A"),
+ ("Andrew", "A"),
+ ("Anand", "A"),
+ ("Bob", "B"),
+ ("Bobette", "B"),
+ ("Buffy", "B"),
+ ("Charlie", "C"),
+ ("Cynthia", "C"),
+ ("Chris", "C"),
+ ]
+
+ employees_table = cls.tables.employees
+ connection.execute(
+ employees_table.insert(),
+ [
+ {"employee_id": i, "name": n, "department": d}
+ for i, (n, d) in enumerate(data)
+ ],
+ )
+
+ def test_basic(self, connection):
+ employees_table = self.tables.employees
+ s = select(
+ employees_table.c.name, employees_table.c.department
+ ).order_by(employees_table.c.employee_id)
+ rows = connection.execute(s).fetchall()
+
+ eq_(rows, self.data)
+
+ def test_update_rowcount1(self, connection):
+ employees_table = self.tables.employees
+
+ # WHERE matches 3, 3 rows changed
+ department = employees_table.c.department
+ r = connection.execute(
+ employees_table.update().where(department == "C"),
+ {"department": "Z"},
+ )
+ assert r.rowcount == 3
+
+ def test_update_rowcount2(self, connection):
+ employees_table = self.tables.employees
+
+ # WHERE matches 3, 0 rows changed
+ department = employees_table.c.department
+
+ r = connection.execute(
+ employees_table.update().where(department == "C"),
+ {"department": "C"},
+ )
+ eq_(r.rowcount, 3)
+
+ @testing.requires.sane_rowcount_w_returning
+ def test_update_rowcount_return_defaults(self, connection):
+ employees_table = self.tables.employees
+
+ department = employees_table.c.department
+ stmt = (
+ employees_table.update()
+ .where(department == "C")
+ .values(name=employees_table.c.department + "Z")
+ .return_defaults()
+ )
+
+ r = connection.execute(stmt)
+ eq_(r.rowcount, 3)
+
+ def test_raw_sql_rowcount(self, connection):
+ # test issue #3622, make sure eager rowcount is called for text
+ result = connection.exec_driver_sql(
+ "update employees set department='Z' where department='C'"
+ )
+ eq_(result.rowcount, 3)
+
+ def test_text_rowcount(self, connection):
+ # test issue #3622, make sure eager rowcount is called for text
+ result = connection.execute(
+ text("update employees set department='Z' " "where department='C'")
+ )
+ eq_(result.rowcount, 3)
+
+ def test_delete_rowcount(self, connection):
+ employees_table = self.tables.employees
+
+ # WHERE matches 3, 3 rows deleted
+ department = employees_table.c.department
+ r = connection.execute(
+ employees_table.delete().where(department == "C")
+ )
+ eq_(r.rowcount, 3)
+
+ @testing.requires.sane_multi_rowcount
+ def test_multi_update_rowcount(self, connection):
+ employees_table = self.tables.employees
+ stmt = (
+ employees_table.update()
+ .where(employees_table.c.name == bindparam("emp_name"))
+ .values(department="C")
+ )
+
+ r = connection.execute(
+ stmt,
+ [
+ {"emp_name": "Bob"},
+ {"emp_name": "Cynthia"},
+ {"emp_name": "nonexistent"},
+ ],
+ )
+
+ eq_(r.rowcount, 2)
+
+ @testing.requires.sane_multi_rowcount
+ def test_multi_delete_rowcount(self, connection):
+ employees_table = self.tables.employees
+
+ stmt = employees_table.delete().where(
+ employees_table.c.name == bindparam("emp_name")
+ )
+
+ r = connection.execute(
+ stmt,
+ [
+ {"emp_name": "Bob"},
+ {"emp_name": "Cynthia"},
+ {"emp_name": "nonexistent"},
+ ],
+ )
+
+ eq_(r.rowcount, 2)
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
new file mode 100644
index 0000000..cb78fff
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_select.py
@@ -0,0 +1,1783 @@
+import itertools
+
+from .. import AssertsCompiledSQL
+from .. import AssertsExecutionResults
+from .. import config
+from .. import fixtures
+from ..assertions import assert_raises
+from ..assertions import eq_
+from ..assertions import in_
+from ..assertsql import CursorSQL
+from ..schema import Column
+from ..schema import Table
+from ... import bindparam
+from ... import case
+from ... import column
+from ... import Computed
+from ... import exists
+from ... import false
+from ... import ForeignKey
+from ... import func
+from ... import Identity
+from ... import Integer
+from ... import literal
+from ... import literal_column
+from ... import null
+from ... import select
+from ... import String
+from ... import table
+from ... import testing
+from ... import text
+from ... import true
+from ... import tuple_
+from ... import TupleType
+from ... import union
+from ... import util
+from ... import values
+from ...exc import DatabaseError
+from ...exc import ProgrammingError
+from ...util import collections_abc
+
+
+class CollateTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(100)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "data": "collate data1"},
+ {"id": 2, "data": "collate data2"},
+ ],
+ )
+
+ def _assert_result(self, select, result):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select).fetchall(), result)
+
+ @testing.requires.order_by_collation
+ def test_collate_order_by(self):
+ collation = testing.requires.get_order_by_collation(testing.config)
+
+ self._assert_result(
+ select(self.tables.some_table).order_by(
+ self.tables.some_table.c.data.collate(collation).asc()
+ ),
+ [(1, "collate data1"), (2, "collate data2")],
+ )
+
+
+class OrderByLabelTest(fixtures.TablesTest):
+ """Test the dialect sends appropriate ORDER BY expressions when
+ labels are used.
+
+ This essentially exercises the "supports_simple_order_by_label"
+ setting.
+
+ """
+
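+    # When the setting is enabled, a labeled expression is ordered by its
+    # label rather than by repeating the expression -- roughly (sketch):
+    #
+    #     SELECT x + y AS lx FROM some_table ORDER BY lx
+    #
+    # whereas with it disabled:
+    #
+    #     SELECT x + y AS lx FROM some_table ORDER BY x + y
+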
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("q", String(50)),
+ Column("p", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"},
+ {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"},
+ {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"},
+ ],
+ )
+
+ def _assert_result(self, select, result):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select).fetchall(), result)
+
+ def test_plain(self):
+ table = self.tables.some_table
+ lx = table.c.x.label("lx")
+ self._assert_result(select(lx).order_by(lx), [(1,), (2,), (3,)])
+
+ def test_composed_int(self):
+ table = self.tables.some_table
+ lx = (table.c.x + table.c.y).label("lx")
+ self._assert_result(select(lx).order_by(lx), [(3,), (5,), (7,)])
+
+ def test_composed_multiple(self):
+ table = self.tables.some_table
+ lx = (table.c.x + table.c.y).label("lx")
+ ly = (func.lower(table.c.q) + table.c.p).label("ly")
+ self._assert_result(
+ select(lx, ly).order_by(lx, ly.desc()),
+ [(3, util.u("q1p3")), (5, util.u("q2p2")), (7, util.u("q3p1"))],
+ )
+
+ def test_plain_desc(self):
+ table = self.tables.some_table
+ lx = table.c.x.label("lx")
+ self._assert_result(select(lx).order_by(lx.desc()), [(3,), (2,), (1,)])
+
+ def test_composed_int_desc(self):
+ table = self.tables.some_table
+ lx = (table.c.x + table.c.y).label("lx")
+ self._assert_result(select(lx).order_by(lx.desc()), [(7,), (5,), (3,)])
+
+ @testing.requires.group_by_complex_expression
+ def test_group_by_composed(self):
+ table = self.tables.some_table
+ expr = (table.c.x + table.c.y).label("lx")
+ stmt = (
+ select(func.count(table.c.id), expr).group_by(expr).order_by(expr)
+ )
+ self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)])
+
+
+class ValuesExpressionTest(fixtures.TestBase):
+ __requires__ = ("table_value_constructor",)
+
+ __backend__ = True
+
+ def test_tuples(self, connection):
+ value_expr = values(
+ column("id", Integer), column("name", String), name="my_values"
+ ).data([(1, "name1"), (2, "name2"), (3, "name3")])
+
+ eq_(
+ connection.execute(select(value_expr)).all(),
+ [(1, "name1"), (2, "name2"), (3, "name3")],
+ )
+
+
+class FetchLimitOffsetTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2},
+ {"id": 2, "x": 2, "y": 3},
+ {"id": 3, "x": 3, "y": 4},
+ {"id": 4, "x": 4, "y": 5},
+ {"id": 5, "x": 4, "y": 6},
+ ],
+ )
+
+ def _assert_result(
+ self, connection, select, result, params=(), set_=False
+ ):
+ if set_:
+ query_res = connection.execute(select, params).fetchall()
+ eq_(len(query_res), len(result))
+ eq_(set(query_res), set(result))
+
+ else:
+ eq_(connection.execute(select, params).fetchall(), result)
+
+ def _assert_result_str(self, select, result, params=()):
+ conn = config.db.connect(close_with_result=True)
+ eq_(conn.exec_driver_sql(select, params).fetchall(), result)
+
+ def test_simple_limit(self, connection):
+ table = self.tables.some_table
+ stmt = select(table).order_by(table.c.id)
+ self._assert_result(
+ connection,
+ stmt.limit(2),
+ [(1, 1, 2), (2, 2, 3)],
+ )
+ self._assert_result(
+ connection,
+ stmt.limit(3),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ )
+
+ def test_limit_render_multiple_times(self, connection):
+ table = self.tables.some_table
+ stmt = select(table.c.id).limit(1).scalar_subquery()
+
+ u = union(select(stmt), select(stmt)).subquery().select()
+
+ self._assert_result(
+ connection,
+ u,
+ [
+ (1,),
+ ],
+ )
+
+ @testing.requires.fetch_first
+ def test_simple_fetch(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(2),
+ [(1, 1, 2), (2, 2, 3)],
+ )
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(3),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.offset
+ def test_simple_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(2),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(3),
+ [(4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.combinations(
+ ([(2, 0), (2, 1), (3, 2)]),
+ ([(2, 1), (2, 0), (3, 2)]),
+ ([(3, 1), (2, 1), (3, 1)]),
+ argnames="cases",
+ )
+ @testing.requires.offset
+ def test_simple_limit_offset(self, connection, cases):
+ table = self.tables.some_table
+ connection = connection.execution_options(compiled_cache={})
+
+ assert_data = [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)]
+
+ for limit, offset in cases:
+ expected = assert_data[offset : offset + limit]
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).limit(limit).offset(offset),
+ expected,
+ )
+
+ @testing.requires.fetch_first
+ def test_simple_fetch_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(2).offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(3).offset(2),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.fetch_no_order_by
+ def test_fetch_offset_no_order(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).fetch(10),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ @testing.requires.offset
+ def test_simple_offset_zero(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(0),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(1),
+ [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.offset
+ def test_limit_offset_nobinds(self):
+ """test that 'literal binds' mode works - no bound params."""
+
+ table = self.tables.some_table
+ stmt = select(table).order_by(table.c.id).limit(2).offset(1)
+ sql = stmt.compile(
+ dialect=config.db.dialect, compile_kwargs={"literal_binds": True}
+ )
+ sql = str(sql)
+
+ self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)])
+
+ @testing.requires.fetch_first
+ def test_fetch_offset_nobinds(self):
+ """test that 'literal binds' mode works - no bound params."""
+
+ table = self.tables.some_table
+ stmt = select(table).order_by(table.c.id).fetch(2).offset(1)
+ sql = stmt.compile(
+ dialect=config.db.dialect, compile_kwargs={"literal_binds": True}
+ )
+ sql = str(sql)
+
+ self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)])
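+ # both "nobinds" tests above compile the statement with literal_binds,
+ # so the LIMIT/OFFSET (or FETCH/OFFSET) values are rendered inline in
+ # the SQL string, which is then executed directly via exec_driver_sql
+ # with no bound parameters at all.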
+
+ @testing.requires.bound_limit_offset
+ def test_bound_limit(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).limit(bindparam("l")),
+ [(1, 1, 2), (2, 2, 3)],
+ params={"l": 2},
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).limit(bindparam("l")),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ params={"l": 3},
+ )
+
+ @testing.requires.bound_limit_offset
+ def test_bound_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(bindparam("o")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"o": 2},
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(bindparam("o")),
+ [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"o": 1},
+ )
+
+ @testing.requires.bound_limit_offset
+ def test_bound_limit_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(bindparam("l"))
+ .offset(bindparam("o")),
+ [(2, 2, 3), (3, 3, 4)],
+ params={"l": 2, "o": 1},
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(bindparam("l"))
+ .offset(bindparam("o")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"l": 3, "o": 2},
+ )
+
+ @testing.requires.fetch_first
+ def test_bound_fetch_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(bindparam("f"))
+ .offset(bindparam("o")),
+ [(2, 2, 3), (3, 3, 4)],
+ params={"f": 2, "o": 1},
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(bindparam("f"))
+ .offset(bindparam("o")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"f": 3, "o": 2},
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .offset(literal_column("1") + literal_column("2")),
+ [(4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_limit(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("2")),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_limit_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("1"))
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ @testing.requires.fetch_first
+ @testing.requires.fetch_expression
+ def test_expr_fetch_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(literal_column("1") + literal_column("1"))
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_simple_limit_expr_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(2)
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(3)
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_limit_simple_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("1"))
+ .offset(2),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("1"))
+ .offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.fetch_ties
+ def test_simple_fetch_ties(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.x.desc()).fetch(1, with_ties=True),
+ [(4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.x.desc()).fetch(3, with_ties=True),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
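+ # note on WITH TIES, per the fixture data and assertions above: rows
+ # that tie with the last fetched row on the ORDER BY value are also
+ # returned; x=4 appears twice, so fetch(1, with_ties=True) yields both
+ # tied rows.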
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_ties(self, connection):
+ table = self.tables.some_table
+ fa = connection.execute(
+ select(table)
+ .order_by(table.c.x)
+ .fetch(2, with_ties=True)
+ .offset(2)
+ ).fetchall()
+ eq_(fa[0], (3, 3, 4))
+ eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)]))
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_ties_exact_number(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.x)
+ .fetch(2, with_ties=True)
+ .offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.x)
+ .fetch(3, with_ties=True)
+ .offset(3),
+ [(4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.fetch_percent
+ def test_simple_fetch_percent(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(20, percent=True),
+ [(1, 1, 2)],
+ )
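+ # 20 percent of the five rows in some_table amounts to a single row,
+ # hence only the first row is expected here.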
+
+ @testing.requires.fetch_percent
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_percent(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(40, percent=True)
+ .offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_percent
+ def test_simple_fetch_percent_ties(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.x.desc())
+ .fetch(20, percent=True, with_ties=True),
+ [(4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_percent
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_percent_ties(self, connection):
+ table = self.tables.some_table
+ fa = connection.execute(
+ select(table)
+ .order_by(table.c.x)
+ .fetch(40, percent=True, with_ties=True)
+ .offset(2)
+ ).fetchall()
+ eq_(fa[0], (3, 3, 4))
+ eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)]))
+
+
+class JoinTest(fixtures.TablesTest):
+ __backend__ = True
+
+ def _assert_result(self, select, result, params=()):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table("a", metadata, Column("id", Integer, primary_key=True))
+ Table(
+ "b",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("a_id", ForeignKey("a.id"), nullable=False),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.a.insert(),
+ [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}],
+ )
+
+ connection.execute(
+ cls.tables.b.insert(),
+ [
+ {"id": 1, "a_id": 1},
+ {"id": 2, "a_id": 1},
+ {"id": 4, "a_id": 2},
+ {"id": 5, "a_id": 3},
+ ],
+ )
+
+ def test_inner_join_fk(self):
+ a, b = self.tables("a", "b")
+
+ stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id)
+
+ self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)])
+
+ def test_inner_join_true(self):
+ a, b = self.tables("a", "b")
+
+ stmt = (
+ select(a, b)
+ .select_from(a.join(b, true()))
+ .order_by(a.c.id, b.c.id)
+ )
+
+ self._assert_result(
+ stmt,
+ [
+ (a, b, c)
+ for (a,), (b, c) in itertools.product(
+ [(1,), (2,), (3,), (4,), (5,)],
+ [(1, 1), (2, 1), (4, 2), (5, 3)],
+ )
+ ],
+ )
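+ # joining ON true() matches every pair of rows, so the expected result
+ # is the full cross product of the five rows in "a" with the four rows
+ # in "b" (20 rows), built here with itertools.product.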
+
+ def test_inner_join_false(self):
+ a, b = self.tables("a", "b")
+
+ stmt = (
+ select(a, b)
+ .select_from(a.join(b, false()))
+ .order_by(a.c.id, b.c.id)
+ )
+
+ self._assert_result(stmt, [])
+
+ def test_outer_join_false(self):
+ a, b = self.tables("a", "b")
+
+ stmt = (
+ select(a, b)
+ .select_from(a.outerjoin(b, false()))
+ .order_by(a.c.id, b.c.id)
+ )
+
+ self._assert_result(
+ stmt,
+ [
+ (1, None, None),
+ (2, None, None),
+ (3, None, None),
+ (4, None, None),
+ (5, None, None),
+ ],
+ )
+
+ def test_outer_join_fk(self):
+ a, b = self.tables("a", "b")
+
+ stmt = select(a, b).select_from(a.outerjoin(b)).order_by(a.c.id, b.c.id)
+
+ self._assert_result(
+ stmt,
+ [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3), (4, None, None), (5, None, None)],
+ )
+
+
+class CompoundSelectTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2},
+ {"id": 2, "x": 2, "y": 3},
+ {"id": 3, "x": 3, "y": 4},
+ {"id": 4, "x": 4, "y": 5},
+ ],
+ )
+
+ def _assert_result(self, select, result, params=()):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ def test_plain_union(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_select_from_plain_union(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2).alias().select()
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.order_by_col_from_union
+ @testing.requires.parens_in_union_contained_select_w_limit_offset
+ def test_limit_offset_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.parens_in_union_contained_select_wo_limit_offset
+ def test_order_by_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_distinct_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).distinct()
+ s2 = select(table).where(table.c.id == 3).distinct()
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.parens_in_union_contained_select_w_limit_offset
+ def test_limit_offset_in_unions_from_alias(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
+
+ # this necessarily has double parens
+ u1 = union(s1, s2).alias()
+ self._assert_result(
+ u1.select().limit(2).order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_limit_offset_aliased_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = (
+ select(table)
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+ s2 = (
+ select(table)
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+
+class PostCompileParamsTest(
+ AssertsExecutionResults, AssertsCompiledSQL, fixtures.TablesTest
+):
+ __backend__ = True
+
+ __requires__ = ("standard_cursor_sql",)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2, "z": "z1"},
+ {"id": 2, "x": 2, "y": 3, "z": "z2"},
+ {"id": 3, "x": 3, "y": 4, "z": "z3"},
+ {"id": 4, "x": 4, "y": 5, "z": "z4"},
+ ],
+ )
+
+ def test_compile(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x == bindparam("q", literal_execute=True)
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT some_table.id FROM some_table "
+ "WHERE some_table.x = __[POSTCOMPILE_q]",
+ {},
+ )
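+ # for illustration (not executed): literal_execute=True defers rendering
+ # of the parameter value to execution time, when it is inlined as a
+ # literal instead of being passed to the driver, e.g.
+ #
+ #   compile time:  WHERE some_table.x = __[POSTCOMPILE_q]
+ #   execute time:  WHERE some_table.x = 10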
+
+ def test_compile_literal_binds(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x == bindparam("q", 10, literal_execute=True)
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT some_table.id FROM some_table WHERE some_table.x = 10",
+ {},
+ literal_binds=True,
+ )
+
+ def test_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x == bindparam("q", literal_execute=True)
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=10))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE some_table.x = 10",
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+ def test_execute_expanding_plus_literal_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x.in_(bindparam("q", expanding=True, literal_execute=True))
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=[5, 6, 7]))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE some_table.x IN (5, 6, 7)",
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+ @testing.requires.tuple_in
+ def test_execute_tuple_expanding_plus_literal_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ tuple_(table.c.x, table.c.y).in_(
+ bindparam("q", expanding=True, literal_execute=True)
+ )
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=[(5, 10), (12, 18)]))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE (some_table.x, some_table.y) "
+ "IN (%s(5, 10), (12, 18))"
+ % ("VALUES " if config.db.dialect.tuple_in_values else ""),
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+ @testing.requires.tuple_in
+ def test_execute_tuple_expanding_plus_literal_heterogeneous_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ tuple_(table.c.x, table.c.z).in_(
+ bindparam("q", expanding=True, literal_execute=True)
+ )
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=[(5, "z1"), (12, "z3")]))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE (some_table.x, some_table.z) "
+ "IN (%s(5, 'z1'), (12, 'z3'))"
+ % ("VALUES " if config.db.dialect.tuple_in_values else ""),
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+
+class ExpandingBoundInTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2, "z": "z1"},
+ {"id": 2, "x": 2, "y": 3, "z": "z2"},
+ {"id": 3, "x": 3, "y": 4, "z": "z3"},
+ {"id": 4, "x": 4, "y": 5, "z": "z4"},
+ ],
+ )
+
+ def _assert_result(self, select, result, params=()):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select, params).fetchall(), result)
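+ # a note on the "expanding" parameters exercised throughout this class:
+ # a bindparam given a list is expanded at execution time into one bound
+ # parameter per element, and an empty list is rendered as a special
+ # "empty set" expression rather than invalid SQL such as "IN ()".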
+
+ def test_multiple_empty_sets_bindparam(self):
+ # test that any anonymous aliasing used by the dialect
+ # is fine with duplicates
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_(bindparam("q")))
+ .where(table.c.y.in_(bindparam("p")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": [], "p": []})
+
+ def test_multiple_empty_sets_direct(self):
+ # test that any anonymous aliasing used by the dialect
+ # is fine with duplicates
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([]))
+ .where(table.c.y.in_([]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [])
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_heterogeneous_tuples_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": []})
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_heterogeneous_tuples_direct(self):
+ table = self.tables.some_table
+
+ def go(val, expected):
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(val))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, expected)
+
+ go([], [])
+ go([(2, "z2"), (3, "z3"), (4, "z4")], [(2,), (3,), (4,)])
+ go([], [])
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_homogeneous_tuples_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": []})
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_homogeneous_tuples_direct(self):
+ table = self.tables.some_table
+
+ def go(val, expected):
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_(val))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, expected)
+
+ go([], [])
+ go([(1, 2), (2, 3), (3, 4)], [(1,), (2,), (3,)])
+ go([], [])
+
+ def test_bound_in_scalar_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]})
+
+ def test_bound_in_scalar_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([2, 3, 4]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,), (4,)])
+
+ def test_nonempty_in_plus_empty_notin(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([2, 3]))
+ .where(table.c.id.not_in([]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,)])
+
+ def test_empty_in_plus_notempty_notin(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([]))
+ .where(table.c.id.not_in([2, 3]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [])
+
+ def test_typed_str_in(self):
+ """test related to #7292.
+
+ since a type is given to the bound param, there is no ambiguity
+ as to the type of each element.
+
+ """
+
+ stmt = text(
+ "select id FROM some_table WHERE z IN :q ORDER BY id"
+ ).bindparams(bindparam("q", type_=String, expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": ["z2", "z3", "z4"]},
+ )
+
+ def test_untyped_str_in(self):
+ """test related to #7292.
+
+ for an untyped expression, we look at the types of the elements
+ themselves; a Sequence indicates a tuple IN, but strings and
+ bytes are excluded from that check, as always.
+
+ """
+
+ stmt = text(
+ "select id FROM some_table WHERE z IN :q ORDER BY id"
+ ).bindparams(bindparam("q", expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": ["z2", "z3", "z4"]},
+ )
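+ # an illustrative summary of the two tests above, assuming the
+ # documented bindparam() behavior: with or without an explicit String
+ # type, ["z2", "z3", "z4"] expands to three scalar string parameters;
+ # inspecting element types only matters for deciding that a list of
+ # tuples should render as a tuple IN.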
+
+ @testing.requires.tuple_in
+ def test_bound_in_two_tuple_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(
+ stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]}
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_two_tuple_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_([(2, 3), (3, 4), (4, 5)]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,), (4,)])
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(
+ tuple_(table.c.x, table.c.z).in_(
+ [(2, "z2"), (3, "z3"), (4, "z4")]
+ )
+ )
+ .order_by(table.c.id)
+ )
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_text_bindparam(self):
+ # note this becomes ARRAY if we don't use expanding
+ # explicitly right now
+ stmt = text(
+ "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+ ).bindparams(bindparam("q", expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_typed_bindparam_non_tuple(self):
+ class LikeATuple(collections_abc.Sequence):
+ def __init__(self, *data):
+ self._data = data
+
+ def __iter__(self):
+ return iter(self._data)
+
+ def __getitem__(self, idx):
+ return self._data[idx]
+
+ def __len__(self):
+ return len(self._data)
+
+ stmt = text(
+ "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+ ).bindparams(
+ bindparam(
+ "q", type_=TupleType(Integer(), String()), expanding=True
+ )
+ )
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={
+ "q": [
+ LikeATuple(2, "z2"),
+ LikeATuple(3, "z3"),
+ LikeATuple(4, "z4"),
+ ]
+ },
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_text_bindparam_non_tuple(self):
+ # note this becomes ARRAY if we don't use expanding
+ # explicitly right now
+
+ class LikeATuple(collections_abc.Sequence):
+ def __init__(self, *data):
+ self._data = data
+
+ def __iter__(self):
+ return iter(self._data)
+
+ def __getitem__(self, idx):
+ return self._data[idx]
+
+ def __len__(self):
+ return len(self._data)
+
+ stmt = text(
+ "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+ ).bindparams(bindparam("q", expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={
+ "q": [
+ LikeATuple(2, "z2"),
+ LikeATuple(3, "z3"),
+ LikeATuple(4, "z4"),
+ ]
+ },
+ )
+
+ def test_empty_set_against_integer_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": []})
+
+ def test_empty_set_against_integer_direct(self):
+ table = self.tables.some_table
+ stmt = select(table.c.id).where(table.c.x.in_([])).order_by(table.c.id)
+ self._assert_result(stmt, [])
+
+ def test_empty_set_against_integer_negation_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.not_in(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
+
+ def test_empty_set_against_integer_negation_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id).where(table.c.x.not_in([])).order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)])
+
+ def test_empty_set_against_string_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.z.in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": []})
+
+ def test_empty_set_against_string_direct(self):
+ table = self.tables.some_table
+ stmt = select(table.c.id).where(table.c.z.in_([])).order_by(table.c.id)
+ self._assert_result(stmt, [])
+
+ def test_empty_set_against_string_negation_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.z.not_in(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
+
+ def test_empty_set_against_string_negation_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id).where(table.c.z.not_in([])).order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)])
+
+ def test_null_in_empty_set_is_false_bindparam(self, connection):
+ stmt = select(
+ case(
+ (
+ null().in_(bindparam("foo", value=())),
+ true(),
+ ),
+ else_=false(),
+ )
+ )
+ in_(connection.execute(stmt).fetchone()[0], (False, 0))
+
+ def test_null_in_empty_set_is_false_direct(self, connection):
+ stmt = select(
+ case(
+ (
+ null().in_([]),
+ true(),
+ ),
+ else_=false(),
+ )
+ )
+ in_(connection.execute(stmt).fetchone()[0], (False, 0))
+
+
+class LikeFunctionsTest(fixtures.TablesTest):
+ __backend__ = True
+
+ run_inserts = "once"
+ run_deletes = None
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "data": "abcdefg"},
+ {"id": 2, "data": "ab/cdefg"},
+ {"id": 3, "data": "ab%cdefg"},
+ {"id": 4, "data": "ab_cdefg"},
+ {"id": 5, "data": "abcde/fg"},
+ {"id": 6, "data": "abcde%fg"},
+ {"id": 7, "data": "ab#cdefg"},
+ {"id": 8, "data": "ab9cdefg"},
+ {"id": 9, "data": "abcde#fg"},
+ {"id": 10, "data": "abcd9fg"},
+ {"id": 11, "data": None},
+ ],
+ )
+
+ def _test(self, expr, expected):
+ some_table = self.tables.some_table
+
+ with config.db.connect() as conn:
+ rows = {
+ value
+ for value, in conn.execute(select(some_table.c.id).where(expr))
+ }
+
+ eq_(rows, expected)
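+ # a note on the flags exercised below: autoescape=True escapes any "%",
+ # "_" or escape character occurring in the given value so it matches
+ # literally, while escape="#" names the LIKE ESCAPE character used in
+ # the rendered pattern.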
+
+ def test_startswith_unescaped(self):
+ col = self.tables.some_table.c.data
+ self._test(col.startswith("ab%c"), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
+
+ def test_startswith_autoescape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.startswith("ab%c", autoescape=True), {3})
+
+ def test_startswith_sqlexpr(self):
+ col = self.tables.some_table.c.data
+ self._test(
+ col.startswith(literal_column("'ab%c'")),
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
+ )
+
+ def test_startswith_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.startswith("ab##c", escape="#"), {7})
+
+ def test_startswith_autoescape_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.startswith("ab%c", autoescape=True, escape="#"), {3})
+ self._test(col.startswith("ab#c", autoescape=True, escape="#"), {7})
+
+ def test_endswith_unescaped(self):
+ col = self.tables.some_table.c.data
+ self._test(col.endswith("e%fg"), {1, 2, 3, 4, 5, 6, 7, 8, 9})
+
+ def test_endswith_sqlexpr(self):
+ col = self.tables.some_table.c.data
+ self._test(
+ col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9}
+ )
+
+ def test_endswith_autoescape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.endswith("e%fg", autoescape=True), {6})
+
+ def test_endswith_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.endswith("e##fg", escape="#"), {9})
+
+ def test_endswith_autoescape_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.endswith("e%fg", autoescape=True, escape="#"), {6})
+ self._test(col.endswith("e#fg", autoescape=True, escape="#"), {9})
+
+ def test_contains_unescaped(self):
+ col = self.tables.some_table.c.data
+ self._test(col.contains("b%cde"), {1, 2, 3, 4, 5, 6, 7, 8, 9})
+
+ def test_contains_autoescape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.contains("b%cde", autoescape=True), {3})
+
+ def test_contains_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.contains("b##cde", escape="#"), {7})
+
+ def test_contains_autoescape_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.contains("b%cd", autoescape=True, escape="#"), {3})
+ self._test(col.contains("b#cd", autoescape=True, escape="#"), {7})
+
+ @testing.requires.regexp_match
+ def test_not_regexp_match(self):
+ col = self.tables.some_table.c.data
+ self._test(~col.regexp_match("a.cde"), {2, 3, 4, 7, 8, 10})
+
+ @testing.requires.regexp_replace
+ def test_regexp_replace(self):
+ col = self.tables.some_table.c.data
+ self._test(
+ col.regexp_replace("a.cde", "FOO").contains("FOO"), {1, 5, 6, 9}
+ )
+
+ @testing.requires.regexp_match
+ @testing.combinations(
+ ("a.cde", {1, 5, 6, 9}),
+ ("abc", {1, 5, 6, 9, 10}),
+ ("^abc", {1, 5, 6, 9, 10}),
+ ("9cde", {8}),
+ ("^a", set(range(1, 11))),
+ ("(b|c)", set(range(1, 11))),
+ ("^(b|c)", set()),
+ )
+ def test_regexp_match(self, text, expected):
+ col = self.tables.some_table.c.data
+ self._test(col.regexp_match(text), expected)
+
+
+class ComputedColumnTest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("computed_columns",)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "square",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("side", Integer),
+ Column("area", Integer, Computed("side * side")),
+ Column("perimeter", Integer, Computed("4 * side")),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.square.insert(),
+ [{"id": 1, "side": 10}, {"id": 10, "side": 42}],
+ )
+
+ def test_select_all(self):
+ with config.db.connect() as conn:
+ res = conn.execute(
+ select(text("*"))
+ .select_from(self.tables.square)
+ .order_by(self.tables.square.c.id)
+ ).fetchall()
+ eq_(res, [(1, 10, 100, 40), (10, 42, 1764, 168)])
+
+ def test_select_columns(self):
+ with config.db.connect() as conn:
+ res = conn.execute(
+ select(
+ self.tables.square.c.area, self.tables.square.c.perimeter
+ )
+ .select_from(self.tables.square)
+ .order_by(self.tables.square.c.id)
+ ).fetchall()
+ eq_(res, [(100, 40), (1764, 168)])
+
+
+class IdentityColumnTest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("identity_columns",)
+ run_inserts = "once"
+ run_deletes = "once"
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "tbl_a",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Identity(
+ always=True, start=42, nominvalue=True, nomaxvalue=True
+ ),
+ primary_key=True,
+ ),
+ Column("desc", String(100)),
+ )
+ Table(
+ "tbl_b",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Identity(increment=-5, start=0, minvalue=-1000, maxvalue=0),
+ primary_key=True,
+ ),
+ Column("desc", String(100)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.tbl_a.insert(),
+ [{"desc": "a"}, {"desc": "b"}],
+ )
+ connection.execute(
+ cls.tables.tbl_b.insert(),
+ [{"desc": "a"}, {"desc": "b"}],
+ )
+ connection.execute(
+ cls.tables.tbl_b.insert(),
+ [{"id": 42, "desc": "c"}],
+ )
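+ # expected values below follow from the identity definitions: tbl_a
+ # starts at 42, so its rows get ids 42 and 43; tbl_b starts at 0 and
+ # decrements by 5, giving 0 and -5, while the explicit id=42 is accepted
+ # because that identity is not declared "always".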
+
+ def test_select_all(self, connection):
+ res = connection.execute(
+ select(text("*"))
+ .select_from(self.tables.tbl_a)
+ .order_by(self.tables.tbl_a.c.id)
+ ).fetchall()
+ eq_(res, [(42, "a"), (43, "b")])
+
+ res = connection.execute(
+ select(text("*"))
+ .select_from(self.tables.tbl_b)
+ .order_by(self.tables.tbl_b.c.id)
+ ).fetchall()
+ eq_(res, [(-5, "b"), (0, "a"), (42, "c")])
+
+ def test_select_columns(self, connection):
+
+ res = connection.execute(
+ select(self.tables.tbl_a.c.id).order_by(self.tables.tbl_a.c.id)
+ ).fetchall()
+ eq_(res, [(42,), (43,)])
+
+ @testing.requires.identity_columns_standard
+ def test_insert_always_error(self, connection):
+ def fn():
+ connection.execute(
+ self.tables.tbl_a.insert(),
+ [{"id": 200, "desc": "a"}],
+ )
+
+ assert_raises((DatabaseError, ProgrammingError), fn)
+
+
+class IdentityAutoincrementTest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("autoincrement_without_sequence",)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "tbl",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Identity(),
+ primary_key=True,
+ autoincrement=True,
+ ),
+ Column("desc", String(100)),
+ )
+
+ def test_autoincrement_with_identity(self, connection):
+ res = connection.execute(self.tables.tbl.insert(), {"desc": "row"})
+ res = connection.execute(self.tables.tbl.select()).first()
+ eq_(res, (1, "row"))
+
+
+class ExistsTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "stuff",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.stuff.insert(),
+ [
+ {"id": 1, "data": "some data"},
+ {"id": 2, "data": "some data"},
+ {"id": 3, "data": "some data"},
+ {"id": 4, "data": "some other data"},
+ ],
+ )
+
+ def test_select_exists(self, connection):
+ stuff = self.tables.stuff
+ eq_(
+ connection.execute(
+ select(literal(1)).where(
+ exists().where(stuff.c.data == "some data")
+ )
+ ).fetchall(),
+ [(1,)],
+ )
+
+ def test_select_exists_false(self, connection):
+ stuff = self.tables.stuff
+ eq_(
+ connection.execute(
+ select(literal(1)).where(
+ exists().where(stuff.c.data == "no data")
+ )
+ ).fetchall(),
+ [],
+ )
+
+
+class DistinctOnTest(AssertsCompiledSQL, fixtures.TablesTest):
+ __backend__ = True
+
+ @testing.fails_if(testing.requires.supports_distinct_on)
+ def test_distinct_on(self):
+ stm = select("*").distinct(column("q")).select_from(table("foo"))
+ with testing.expect_deprecated(
+ "DISTINCT ON is currently supported only by the PostgreSQL "
+ ):
+ self.assert_compile(stm, "SELECT DISTINCT * FROM foo")
+
+
+class IsOrIsNotDistinctFromTest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("supports_is_distinct_from",)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "is_distinct_test",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("col_a", Integer, nullable=True),
+ Column("col_b", Integer, nullable=True),
+ )
+
+ @testing.combinations(
+ ("both_int_different", 0, 1, 1),
+ ("both_int_same", 1, 1, 0),
+ ("one_null_first", None, 1, 1),
+ ("one_null_second", 0, None, 1),
+ ("both_null", None, None, 0),
+ id_="iaaa",
+ argnames="col_a_value, col_b_value, expected_row_count_for_is",
+ )
+ def test_is_or_is_not_distinct_from(
+ self, col_a_value, col_b_value, expected_row_count_for_is, connection
+ ):
+ tbl = self.tables.is_distinct_test
+
+ connection.execute(
+ tbl.insert(),
+ [{"id": 1, "col_a": col_a_value, "col_b": col_b_value}],
+ )
+
+ result = connection.execute(
+ tbl.select().where(tbl.c.col_a.is_distinct_from(tbl.c.col_b))
+ ).fetchall()
+ eq_(
+ len(result),
+ expected_row_count_for_is,
+ )
+
+ expected_row_count_for_is_not = (
+ 1 if expected_row_count_for_is == 0 else 0
+ )
+ result = connection.execute(
+ tbl.select().where(tbl.c.col_a.is_not_distinct_from(tbl.c.col_b))
+ ).fetchall()
+ eq_(
+ len(result),
+ expected_row_count_for_is_not,
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py
new file mode 100644
index 0000000..d6747d2
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_sequence.py
@@ -0,0 +1,282 @@
+from .. import config
+from .. import fixtures
+from ..assertions import eq_
+from ..assertions import is_true
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import inspect
+from ... import Integer
+from ... import MetaData
+from ... import Sequence
+from ... import String
+from ... import testing
+
+
+class SequenceTest(fixtures.TablesTest):
+ __requires__ = ("sequences",)
+ __backend__ = True
+
+ run_create_tables = "each"
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "seq_pk",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("tab_id_seq"),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ )
+
+ Table(
+ "seq_opt_pk",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("tab_id_seq", data_type=Integer, optional=True),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ )
+
+ Table(
+ "seq_no_returning",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_id_seq"),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=False,
+ )
+
+ if testing.requires.schemas.enabled:
+ Table(
+ "seq_no_returning_sch",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_sch_id_seq", schema=config.test_schema),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=False,
+ schema=config.test_schema,
+ )
+
+ def test_insert_roundtrip(self, connection):
+ connection.execute(self.tables.seq_pk.insert(), dict(data="some data"))
+ self._assert_round_trip(self.tables.seq_pk, connection)
+
+ def test_insert_lastrowid(self, connection):
+ r = connection.execute(
+ self.tables.seq_pk.insert(), dict(data="some data")
+ )
+ eq_(
+ r.inserted_primary_key, (testing.db.dialect.default_sequence_base,)
+ )
+
+ def test_nextval_direct(self, connection):
+ r = connection.execute(self.tables.seq_pk.c.id.default)
+ eq_(r, testing.db.dialect.default_sequence_base)
+
+ @requirements.sequences_optional
+ def test_optional_seq(self, connection):
+ r = connection.execute(
+ self.tables.seq_opt_pk.insert(), dict(data="some data")
+ )
+ eq_(r.inserted_primary_key, (1,))
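+ # since the sequence is optional=True and the sequences_optional
+ # requirement limits this test to backends that can generate the
+ # primary key without it, the sequence is skipped and the first id is
+ # the backend default of 1.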
+
+ def _assert_round_trip(self, table, conn):
+ row = conn.execute(table.select()).first()
+ eq_(row, (testing.db.dialect.default_sequence_base, "some data"))
+
+ def test_insert_roundtrip_no_implicit_returning(self, connection):
+ connection.execute(
+ self.tables.seq_no_returning.insert(), dict(data="some data")
+ )
+ self._assert_round_trip(self.tables.seq_no_returning, connection)
+
+ @testing.combinations((True,), (False,), argnames="implicit_returning")
+ @testing.requires.schemas
+ def test_insert_roundtrip_translate(self, connection, implicit_returning):
+
+ seq_no_returning = Table(
+ "seq_no_returning_sch",
+ MetaData(),
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_sch_id_seq", schema="alt_schema"),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=implicit_returning,
+ schema="alt_schema",
+ )
+
+ connection = connection.execution_options(
+ schema_translate_map={"alt_schema": config.test_schema}
+ )
+ connection.execute(seq_no_returning.insert(), dict(data="some data"))
+ self._assert_round_trip(seq_no_returning, connection)
+
+ @testing.requires.schemas
+ def test_nextval_direct_schema_translate(self, connection):
+ seq = Sequence("noret_sch_id_seq", schema="alt_schema")
+ connection = connection.execution_options(
+ schema_translate_map={"alt_schema": config.test_schema}
+ )
+
+ r = connection.execute(seq)
+ eq_(r, testing.db.dialect.default_sequence_base)
+
+
+class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
+ __requires__ = ("sequences",)
+ __backend__ = True
+
+ def test_literal_binds_inline_compile(self, connection):
+ table = Table(
+ "x",
+ MetaData(),
+ Column("y", Integer, Sequence("y_seq")),
+ Column("q", Integer),
+ )
+
+ stmt = table.insert().values(q=5)
+
+ seq_nextval = connection.dialect.statement_compiler(
+ statement=None, dialect=connection.dialect
+ ).visit_sequence(Sequence("y_seq"))
+ self.assert_compile(
+ stmt,
+ "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,),
+ literal_binds=True,
+ dialect=connection.dialect,
+ )
+
+
+class HasSequenceTest(fixtures.TablesTest):
+ run_deletes = None
+
+ __requires__ = ("sequences",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Sequence("user_id_seq", metadata=metadata)
+ Sequence(
+ "other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True
+ )
+ if testing.requires.schemas.enabled:
+ Sequence(
+ "user_id_seq", schema=config.test_schema, metadata=metadata
+ )
+ Sequence(
+ "schema_seq", schema=config.test_schema, metadata=metadata
+ )
+ Table(
+ "user_id_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ )
+
+ def test_has_sequence(self, connection):
+ eq_(
+ inspect(connection).has_sequence("user_id_seq"),
+ True,
+ )
+
+ def test_has_sequence_other_object(self, connection):
+ eq_(
+ inspect(connection).has_sequence("user_id_table"),
+ False,
+ )
+
+ @testing.requires.schemas
+ def test_has_sequence_schema(self, connection):
+ eq_(
+ inspect(connection).has_sequence(
+ "user_id_seq", schema=config.test_schema
+ ),
+ True,
+ )
+
+ def test_has_sequence_neg(self, connection):
+ eq_(
+ inspect(connection).has_sequence("some_sequence"),
+ False,
+ )
+
+ @testing.requires.schemas
+ def test_has_sequence_schemas_neg(self, connection):
+ eq_(
+ inspect(connection).has_sequence(
+ "some_sequence", schema=config.test_schema
+ ),
+ False,
+ )
+
+ @testing.requires.schemas
+ def test_has_sequence_default_not_in_remote(self, connection):
+ eq_(
+ inspect(connection).has_sequence(
+ "other_sequence", schema=config.test_schema
+ ),
+ False,
+ )
+
+ @testing.requires.schemas
+ def test_has_sequence_remote_not_in_default(self, connection):
+ eq_(
+ inspect(connection).has_sequence("schema_seq"),
+ False,
+ )
+
+ def test_get_sequence_names(self, connection):
+ exp = {"other_seq", "user_id_seq"}
+
+ res = set(inspect(connection).get_sequence_names())
+ is_true(res.intersection(exp) == exp)
+ is_true("schema_seq" not in res)
+
+ @testing.requires.schemas
+ def test_get_sequence_names_no_sequence_schema(self, connection):
+ eq_(
+ inspect(connection).get_sequence_names(
+ schema=config.test_schema_2
+ ),
+ [],
+ )
+
+ @testing.requires.schemas
+ def test_get_sequence_names_sequences_schema(self, connection):
+ eq_(
+ sorted(
+ inspect(connection).get_sequence_names(
+ schema=config.test_schema
+ )
+ ),
+ ["schema_seq", "user_id_seq"],
+ )
+
+
+class HasSequenceTestEmpty(fixtures.TestBase):
+ __requires__ = ("sequences",)
+ __backend__ = True
+
+ def test_get_sequence_names_no_sequence(self, connection):
+ eq_(
+ inspect(connection).get_sequence_names(),
+ [],
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
new file mode 100644
index 0000000..b96350e
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -0,0 +1,1508 @@
+# coding: utf-8
+
+import datetime
+import decimal
+import json
+import re
+
+from .. import config
+from .. import engines
+from .. import fixtures
+from .. import mock
+from ..assertions import eq_
+from ..assertions import is_
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import and_
+from ... import BigInteger
+from ... import bindparam
+from ... import Boolean
+from ... import case
+from ... import cast
+from ... import Date
+from ... import DateTime
+from ... import Float
+from ... import Integer
+from ... import JSON
+from ... import literal
+from ... import MetaData
+from ... import null
+from ... import Numeric
+from ... import select
+from ... import String
+from ... import testing
+from ... import Text
+from ... import Time
+from ... import TIMESTAMP
+from ... import TypeDecorator
+from ... import Unicode
+from ... import UnicodeText
+from ... import util
+from ...orm import declarative_base
+from ...orm import Session
+from ...sql.sqltypes import LargeBinary
+from ...sql.sqltypes import PickleType
+from ...util import compat
+from ...util import u
+
+
+class _LiteralRoundTripFixture(object):
+ supports_whereclause = True
+
+ @testing.fixture
+ def literal_round_trip(self, metadata, connection):
+ """test literal rendering"""
+
+ # for literals, we test literal rendering in an INSERT
+ # into a typed column. we can then SELECT it back as its
+ # official type; ideally we'd be able to use CAST here,
+ # but MySQL in particular can't CAST fully
+
+ def run(type_, input_, output, filter_=None):
+ t = Table("t", metadata, Column("x", type_))
+ t.create(connection)
+
+ for value in input_:
+ ins = (
+ t.insert()
+ .values(x=literal(value, type_))
+ .compile(
+ dialect=testing.db.dialect,
+ compile_kwargs=dict(literal_binds=True),
+ )
+ )
+ connection.execute(ins)
+
+ if self.supports_whereclause:
+ stmt = t.select().where(t.c.x == literal(value))
+ else:
+ stmt = t.select()
+
+ stmt = stmt.compile(
+ dialect=testing.db.dialect,
+ compile_kwargs=dict(literal_binds=True),
+ )
+ for row in connection.execute(stmt):
+ value = row[0]
+ if filter_ is not None:
+ value = filter_(value)
+ assert value in output
+
+ return run
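+ # typical usage by the type tests below, e.g.:
+ #
+ #   literal_round_trip(Integer, [5], [5])
+ #
+ # each input value is rendered inline (literal_binds) in an INSERT,
+ # then selected back and checked against the expected outputs.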
+
+
+class _UnicodeFixture(_LiteralRoundTripFixture, fixtures.TestBase):
+ __requires__ = ("unicode_data",)
+
+ data = u(
+ "Alors vous imaginez ma 🐍 surprise, au lever du jour, "
+ "quand une drôle de petite 🐍 voix m’a réveillé. Elle "
+ "disait: « S’il vous plaît… dessine-moi 🐍 un mouton! »"
+ )
+
+ @property
+ def supports_whereclause(self):
+ return config.requirements.expressions_against_unbounded_text.enabled
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "unicode_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("unicode_data", cls.datatype),
+ )
+
+ def test_round_trip(self, connection):
+ unicode_table = self.tables.unicode_table
+
+ connection.execute(
+ unicode_table.insert(), {"id": 1, "unicode_data": self.data}
+ )
+
+ row = connection.execute(select(unicode_table.c.unicode_data)).first()
+
+ eq_(row, (self.data,))
+ assert isinstance(row[0], util.text_type)
+
+ def test_round_trip_executemany(self, connection):
+ unicode_table = self.tables.unicode_table
+
+ connection.execute(
+ unicode_table.insert(),
+ [{"id": i, "unicode_data": self.data} for i in range(1, 4)],
+ )
+
+ rows = connection.execute(
+ select(unicode_table.c.unicode_data)
+ ).fetchall()
+ eq_(rows, [(self.data,) for i in range(1, 4)])
+ for row in rows:
+ assert isinstance(row[0], util.text_type)
+
+ def _test_null_strings(self, connection):
+ unicode_table = self.tables.unicode_table
+
+ connection.execute(
+ unicode_table.insert(), {"id": 1, "unicode_data": None}
+ )
+ row = connection.execute(select(unicode_table.c.unicode_data)).first()
+ eq_(row, (None,))
+
+ def _test_empty_strings(self, connection):
+ unicode_table = self.tables.unicode_table
+
+ connection.execute(
+ unicode_table.insert(), {"id": 1, "unicode_data": u("")}
+ )
+ row = connection.execute(select(unicode_table.c.unicode_data)).first()
+ eq_(row, (u(""),))
+
+ def test_literal(self, literal_round_trip):
+ literal_round_trip(self.datatype, [self.data], [self.data])
+
+ def test_literal_non_ascii(self, literal_round_trip):
+ literal_round_trip(
+ self.datatype, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
+ )
+
+
+class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
+ __requires__ = ("unicode_data",)
+ __backend__ = True
+
+ datatype = Unicode(255)
+
+ @requirements.empty_strings_varchar
+ def test_empty_strings_varchar(self, connection):
+ self._test_empty_strings(connection)
+
+ def test_null_strings_varchar(self, connection):
+ self._test_null_strings(connection)
+
+
+class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
+ __requires__ = "unicode_data", "text_type"
+ __backend__ = True
+
+ datatype = UnicodeText()
+
+ @requirements.empty_strings_text
+ def test_empty_strings_text(self, connection):
+ self._test_empty_strings(connection)
+
+ def test_null_strings_text(self, connection):
+ self._test_null_strings(connection)
+
+
+class BinaryTest(_LiteralRoundTripFixture, fixtures.TablesTest):
+ __requires__ = ("binary_literals",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "binary_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("binary_data", LargeBinary),
+ Column("pickle_data", PickleType),
+ )
+
+ def test_binary_roundtrip(self, connection):
+ binary_table = self.tables.binary_table
+
+ connection.execute(
+ binary_table.insert(), {"id": 1, "binary_data": b"this is binary"}
+ )
+ row = connection.execute(select(binary_table.c.binary_data)).first()
+ eq_(row, (b"this is binary",))
+
+ def test_pickle_roundtrip(self, connection):
+ binary_table = self.tables.binary_table
+
+ connection.execute(
+ binary_table.insert(),
+ {"id": 1, "pickle_data": {"foo": [1, 2, 3], "bar": "bat"}},
+ )
+ row = connection.execute(select(binary_table.c.pickle_data)).first()
+ eq_(row, ({"foo": [1, 2, 3], "bar": "bat"},))
+
+
+class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest):
+ __requires__ = ("text_type",)
+ __backend__ = True
+
+ @property
+ def supports_whereclause(self):
+ return config.requirements.expressions_against_unbounded_text.enabled
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "text_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("text_data", Text),
+ )
+
+ def test_text_roundtrip(self, connection):
+ text_table = self.tables.text_table
+
+ connection.execute(
+ text_table.insert(), {"id": 1, "text_data": "some text"}
+ )
+ row = connection.execute(select(text_table.c.text_data)).first()
+ eq_(row, ("some text",))
+
+ @testing.requires.empty_strings_text
+ def test_text_empty_strings(self, connection):
+ text_table = self.tables.text_table
+
+ connection.execute(text_table.insert(), {"id": 1, "text_data": ""})
+ row = connection.execute(select(text_table.c.text_data)).first()
+ eq_(row, ("",))
+
+ def test_text_null_strings(self, connection):
+ text_table = self.tables.text_table
+
+ connection.execute(text_table.insert(), {"id": 1, "text_data": None})
+ row = connection.execute(select(text_table.c.text_data)).first()
+ eq_(row, (None,))
+
+ def test_literal(self, literal_round_trip):
+ literal_round_trip(Text, ["some text"], ["some text"])
+
+ def test_literal_non_ascii(self, literal_round_trip):
+ literal_round_trip(
+ Text, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
+ )
+
+ def test_literal_quoting(self, literal_round_trip):
+ data = """some 'text' hey "hi there" that's text"""
+ literal_round_trip(Text, [data], [data])
+
+ def test_literal_backslashes(self, literal_round_trip):
+ data = r"backslash one \ backslash two \\ end"
+ literal_round_trip(Text, [data], [data])
+
+ def test_literal_percentsigns(self, literal_round_trip):
+ data = r"percent % signs %% percent"
+ literal_round_trip(Text, [data], [data])
+
+
+class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
+ __backend__ = True
+
+ @requirements.unbounded_varchar
+ def test_nolength_string(self):
+ metadata = MetaData()
+ foo = Table("foo", metadata, Column("one", String))
+
+ foo.create(config.db)
+ foo.drop(config.db)
+
+ def test_literal(self, literal_round_trip):
+ # note that in Python 3, this invokes the Unicode
+ # datatype for the literal part because all strings are unicode
+ literal_round_trip(String(40), ["some text"], ["some text"])
+
+ def test_literal_non_ascii(self, literal_round_trip):
+ literal_round_trip(
+ String(40), [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
+ )
+
+ def test_literal_quoting(self, literal_round_trip):
+ data = """some 'text' hey "hi there" that's text"""
+ literal_round_trip(String(40), [data], [data])
+
+ def test_literal_backslashes(self, literal_round_trip):
+ data = r"backslash one \ backslash two \\ end"
+ literal_round_trip(String(40), [data], [data])
+
+
+class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
+ compare = None
+
+ @classmethod
+ def define_tables(cls, metadata):
+ class Decorated(TypeDecorator):
+ impl = cls.datatype
+ cache_ok = True
+
+ Table(
+ "date_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("date_data", cls.datatype),
+ Column("decorated_date_data", Decorated),
+ )
+
+ @testing.requires.datetime_implicit_bound
+ def test_select_direct(self, connection):
+ result = connection.scalar(select(literal(self.data)))
+ eq_(result, self.data)
+
+ def test_round_trip(self, connection):
+ date_table = self.tables.date_table
+
+ connection.execute(
+ date_table.insert(), {"id": 1, "date_data": self.data}
+ )
+
+ row = connection.execute(select(date_table.c.date_data)).first()
+
+ compare = self.compare or self.data
+ eq_(row, (compare,))
+ assert isinstance(row[0], type(compare))
+
+ def test_round_trip_decorated(self, connection):
+ date_table = self.tables.date_table
+
+ connection.execute(
+ date_table.insert(), {"id": 1, "decorated_date_data": self.data}
+ )
+
+ row = connection.execute(
+ select(date_table.c.decorated_date_data)
+ ).first()
+
+ compare = self.compare or self.data
+ eq_(row, (compare,))
+ assert isinstance(row[0], type(compare))
+
+ def test_null(self, connection):
+ date_table = self.tables.date_table
+
+ connection.execute(date_table.insert(), {"id": 1, "date_data": None})
+
+ row = connection.execute(select(date_table.c.date_data)).first()
+ eq_(row, (None,))
+
+ @testing.requires.datetime_literals
+ def test_literal(self, literal_round_trip):
+ compare = self.compare or self.data
+ literal_round_trip(self.datatype, [self.data], [compare])
+
+ @testing.requires.standalone_null_binds_whereclause
+ def test_null_bound_comparison(self):
+ # this test is based on an Oracle issue observed in #4886:
+ # when NULL is passed for an expression that needs to be interpreted
+ # as a certain type, does the DBAPI have the info it needs to do so?
+ date_table = self.tables.date_table
+ with config.db.begin() as conn:
+ result = conn.execute(
+ date_table.insert(), {"id": 1, "date_data": self.data}
+ )
+ id_ = result.inserted_primary_key[0]
+ stmt = select(date_table.c.id).where(
+ case(
+ (
+ bindparam("foo", type_=self.datatype) != None,
+ bindparam("foo", type_=self.datatype),
+ ),
+ else_=date_table.c.date_data,
+ )
+ == date_table.c.date_data
+ )
+
+ row = conn.execute(stmt, {"foo": None}).first()
+ eq_(row[0], id_)
+
+
+class DateTimeTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("datetime",)
+ __backend__ = True
+ datatype = DateTime
+ data = datetime.datetime(2012, 10, 15, 12, 57, 18)
+
+
+class DateTimeTZTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("datetime_timezone",)
+ __backend__ = True
+ datatype = DateTime(timezone=True)
+ data = datetime.datetime(
+ 2012, 10, 15, 12, 57, 18, tzinfo=compat.timezone.utc
+ )
+
+
+class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("datetime_microseconds",)
+ __backend__ = True
+ datatype = DateTime
+ data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
+
+
+class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("timestamp_microseconds",)
+ __backend__ = True
+ datatype = TIMESTAMP
+ data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
+
+ @testing.requires.timestamp_microseconds_implicit_bound
+ def test_select_direct(self, connection):
+ result = connection.scalar(select(literal(self.data)))
+ eq_(result, self.data)
+
+
+class TimeTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("time",)
+ __backend__ = True
+ datatype = Time
+ data = datetime.time(12, 57, 18)
+
+
+class TimeTZTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("time_timezone",)
+ __backend__ = True
+ datatype = Time(timezone=True)
+ data = datetime.time(12, 57, 18, tzinfo=compat.timezone.utc)
+
+
+class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("time_microseconds",)
+ __backend__ = True
+ datatype = Time
+ data = datetime.time(12, 57, 18, 396)
+
+
+class DateTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("date",)
+ __backend__ = True
+ datatype = Date
+ data = datetime.date(2012, 10, 15)
+
+
+class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = "date", "date_coerces_from_datetime"
+ __backend__ = True
+ datatype = Date
+ data = datetime.datetime(2012, 10, 15, 12, 57, 18)
+ compare = datetime.date(2012, 10, 15)
+
+
+class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("datetime_historic",)
+ __backend__ = True
+ datatype = DateTime
+ data = datetime.datetime(1850, 11, 10, 11, 52, 35)
+
+
+class DateHistoricTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("date_historic",)
+ __backend__ = True
+ datatype = Date
+ data = datetime.date(1727, 4, 1)
+
+
+class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase):
+ __backend__ = True
+
+ def test_literal(self, literal_round_trip):
+ literal_round_trip(Integer, [5], [5])
+
+ def test_huge_int(self, integer_round_trip):
+ integer_round_trip(BigInteger, 1376537018368127)
+
+ @testing.fixture
+ def integer_round_trip(self, metadata, connection):
+ def run(datatype, data):
+ int_table = Table(
+ "integer_table",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ primary_key=True,
+ test_needs_autoincrement=True,
+ ),
+ Column("integer_data", datatype),
+ )
+
+ metadata.create_all(config.db)
+
+ connection.execute(
+ int_table.insert(), {"id": 1, "integer_data": data}
+ )
+
+ row = connection.execute(select(int_table.c.integer_data)).first()
+
+ eq_(row, (data,))
+
+ if util.py3k:
+ assert isinstance(row[0], int)
+ else:
+ assert isinstance(row[0], (long, int)) # noqa
+
+ return run
+
+
+class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase):
+ __backend__ = True
+
+ @testing.fixture
+ def string_as_int(self):
+ class StringAsInt(TypeDecorator):
+ impl = String(50)
+ cache_ok = True
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
+
+ def column_expression(self, col):
+ return cast(col, Integer)
+
+ def bind_expression(self, col):
+ return cast(col, String(50))
+
+ return StringAsInt()
+
+ def test_special_type(self, metadata, connection, string_as_int):
+
+ type_ = string_as_int
+
+ t = Table("t", metadata, Column("x", type_))
+ t.create(connection)
+
+ connection.execute(t.insert(), [{"x": x} for x in [1, 2, 3]])
+
+ result = {row[0] for row in connection.execute(t.select())}
+ eq_(result, {1, 2, 3})
+
+ result = {
+ row[0] for row in connection.execute(t.select().where(t.c.x == 2))
+ }
+ eq_(result, {2})
+
+
+class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
+ __backend__ = True
+
+ @testing.fixture
+ def do_numeric_test(self, metadata, connection):
+ @testing.emits_warning(
+ r".*does \*not\* support Decimal objects natively"
+ )
+ def run(type_, input_, output, filter_=None, check_scale=False):
+ t = Table("t", metadata, Column("x", type_))
+ t.create(connection)
+ connection.execute(t.insert(), [{"x": x} for x in input_])
+
+ result = {row[0] for row in connection.execute(t.select())}
+ output = set(output)
+ if filter_:
+ result = set(filter_(x) for x in result)
+ output = set(filter_(x) for x in output)
+ eq_(result, output)
+ if check_scale:
+ eq_([str(x) for x in result], [str(x) for x in output])
+
+ return run
+
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_render_literal_numeric(self, literal_round_trip):
+ literal_round_trip(
+ Numeric(precision=8, scale=4),
+ [15.7563, decimal.Decimal("15.7563")],
+ [decimal.Decimal("15.7563")],
+ )
+
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_render_literal_numeric_asfloat(self, literal_round_trip):
+ literal_round_trip(
+ Numeric(precision=8, scale=4, asdecimal=False),
+ [15.7563, decimal.Decimal("15.7563")],
+ [15.7563],
+ )
+
+ def test_render_literal_float(self, literal_round_trip):
+ literal_round_trip(
+ Float(4),
+ [15.7563, decimal.Decimal("15.7563")],
+ [15.7563],
+ filter_=lambda n: n is not None and round(n, 5) or None,
+ )
+
+ @testing.requires.precision_generic_float_type
+ def test_float_custom_scale(self, do_numeric_test):
+ do_numeric_test(
+ Float(None, decimal_return_scale=7, asdecimal=True),
+ [15.7563827, decimal.Decimal("15.7563827")],
+ [decimal.Decimal("15.7563827")],
+ check_scale=True,
+ )
+
+ def test_numeric_as_decimal(self, do_numeric_test):
+ do_numeric_test(
+ Numeric(precision=8, scale=4),
+ [15.7563, decimal.Decimal("15.7563")],
+ [decimal.Decimal("15.7563")],
+ )
+
+ def test_numeric_as_float(self, do_numeric_test):
+ do_numeric_test(
+ Numeric(precision=8, scale=4, asdecimal=False),
+ [15.7563, decimal.Decimal("15.7563")],
+ [15.7563],
+ )
+
+ @testing.requires.infinity_floats
+ def test_infinity_floats(self, do_numeric_test):
+ """test for #977, #7283"""
+
+ do_numeric_test(
+ Float(None),
+ [float("inf")],
+ [float("inf")],
+ )
+
+ @testing.requires.fetch_null_from_numeric
+ def test_numeric_null_as_decimal(self, do_numeric_test):
+ do_numeric_test(Numeric(precision=8, scale=4), [None], [None])
+
+ @testing.requires.fetch_null_from_numeric
+ def test_numeric_null_as_float(self, do_numeric_test):
+ do_numeric_test(
+ Numeric(precision=8, scale=4, asdecimal=False), [None], [None]
+ )
+
+ @testing.requires.floats_to_four_decimals
+ def test_float_as_decimal(self, do_numeric_test):
+ do_numeric_test(
+ Float(precision=8, asdecimal=True),
+ [15.7563, decimal.Decimal("15.7563"), None],
+ [decimal.Decimal("15.7563"), None],
+ filter_=lambda n: n is not None and round(n, 4) or None,
+ )
+
+ def test_float_as_float(self, do_numeric_test):
+ do_numeric_test(
+ Float(precision=8),
+ [15.7563, decimal.Decimal("15.7563")],
+ [15.7563],
+ filter_=lambda n: n is not None and round(n, 5) or None,
+ )
+
+ def test_float_coerce_round_trip(self, connection):
+ expr = 15.7563
+
+ val = connection.scalar(select(literal(expr)))
+ eq_(val, expr)
+
+        # this does not work in MySQL (see #4036); however we choose not
+        # to render CAST unconditionally since this is an edge case.
+
+ @testing.requires.implicit_decimal_binds
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_decimal_coerce_round_trip(self, connection):
+ expr = decimal.Decimal("15.7563")
+
+ val = connection.scalar(select(literal(expr)))
+ eq_(val, expr)
+
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_decimal_coerce_round_trip_w_cast(self, connection):
+ expr = decimal.Decimal("15.7563")
+
+ val = connection.scalar(select(cast(expr, Numeric(10, 4))))
+ eq_(val, expr)
+
+ @testing.requires.precision_numerics_general
+ def test_precision_decimal(self, do_numeric_test):
+ numbers = set(
+ [
+ decimal.Decimal("54.234246451650"),
+ decimal.Decimal("0.004354"),
+ decimal.Decimal("900.0"),
+ ]
+ )
+
+ do_numeric_test(Numeric(precision=18, scale=12), numbers, numbers)
+
+ @testing.requires.precision_numerics_enotation_large
+ def test_enotation_decimal(self, do_numeric_test):
+ """test exceedingly small decimals.
+
+ Decimal reports values with E notation when the exponent
+ is greater than 6.
+
+ """
+
+ numbers = set(
+ [
+ decimal.Decimal("1E-2"),
+ decimal.Decimal("1E-3"),
+ decimal.Decimal("1E-4"),
+ decimal.Decimal("1E-5"),
+ decimal.Decimal("1E-6"),
+ decimal.Decimal("1E-7"),
+ decimal.Decimal("1E-8"),
+ decimal.Decimal("0.01000005940696"),
+ decimal.Decimal("0.00000005940696"),
+ decimal.Decimal("0.00000000000696"),
+ decimal.Decimal("0.70000000000696"),
+ decimal.Decimal("696E-12"),
+ ]
+ )
+ do_numeric_test(Numeric(precision=18, scale=14), numbers, numbers)
+
+ @testing.requires.precision_numerics_enotation_large
+ def test_enotation_decimal_large(self, do_numeric_test):
+ """test exceedingly large decimals."""
+
+ numbers = set(
+ [
+ decimal.Decimal("4E+8"),
+ decimal.Decimal("5748E+15"),
+ decimal.Decimal("1.521E+15"),
+ decimal.Decimal("00000000000000.1E+12"),
+ ]
+ )
+ do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers)
+
+ @testing.requires.precision_numerics_many_significant_digits
+ def test_many_significant_digits(self, do_numeric_test):
+ numbers = set(
+ [
+ decimal.Decimal("31943874831932418390.01"),
+ decimal.Decimal("319438950232418390.273596"),
+ decimal.Decimal("87673.594069654243"),
+ ]
+ )
+ do_numeric_test(Numeric(precision=38, scale=12), numbers, numbers)
+
+ @testing.requires.precision_numerics_retains_significant_digits
+ def test_numeric_no_decimal(self, do_numeric_test):
+ numbers = set([decimal.Decimal("1.000")])
+ do_numeric_test(
+ Numeric(precision=5, scale=3), numbers, numbers, check_scale=True
+ )
+
+
+class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "boolean_table",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("value", Boolean),
+ Column("unconstrained_value", Boolean(create_constraint=False)),
+ )
+
+ def test_render_literal_bool(self, literal_round_trip):
+ literal_round_trip(Boolean(), [True, False], [True, False])
+
+ def test_round_trip(self, connection):
+ boolean_table = self.tables.boolean_table
+
+ connection.execute(
+ boolean_table.insert(),
+ {"id": 1, "value": True, "unconstrained_value": False},
+ )
+
+ row = connection.execute(
+ select(boolean_table.c.value, boolean_table.c.unconstrained_value)
+ ).first()
+
+ eq_(row, (True, False))
+ assert isinstance(row[0], bool)
+
+ @testing.requires.nullable_booleans
+ def test_null(self, connection):
+ boolean_table = self.tables.boolean_table
+
+ connection.execute(
+ boolean_table.insert(),
+ {"id": 1, "value": None, "unconstrained_value": None},
+ )
+
+ row = connection.execute(
+ select(boolean_table.c.value, boolean_table.c.unconstrained_value)
+ ).first()
+
+ eq_(row, (None, None))
+
+ def test_whereclause(self):
+ # testing "WHERE <column>" renders a compatible expression
+ boolean_table = self.tables.boolean_table
+
+ with config.db.begin() as conn:
+ conn.execute(
+ boolean_table.insert(),
+ [
+ {"id": 1, "value": True, "unconstrained_value": True},
+ {"id": 2, "value": False, "unconstrained_value": False},
+ ],
+ )
+
+ eq_(
+ conn.scalar(
+ select(boolean_table.c.id).where(boolean_table.c.value)
+ ),
+ 1,
+ )
+ eq_(
+ conn.scalar(
+ select(boolean_table.c.id).where(
+ boolean_table.c.unconstrained_value
+ )
+ ),
+ 1,
+ )
+ eq_(
+ conn.scalar(
+ select(boolean_table.c.id).where(~boolean_table.c.value)
+ ),
+ 2,
+ )
+ eq_(
+ conn.scalar(
+ select(boolean_table.c.id).where(
+ ~boolean_table.c.unconstrained_value
+ )
+ ),
+ 2,
+ )
+
+
+class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
+ __requires__ = ("json_type",)
+ __backend__ = True
+
+ datatype = JSON
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "data_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(30), nullable=False),
+ Column("data", cls.datatype, nullable=False),
+ Column("nulldata", cls.datatype(none_as_null=True)),
+ )
+
+ def test_round_trip_data1(self, connection):
+ self._test_round_trip({"key1": "value1", "key2": "value2"}, connection)
+
+ def _test_round_trip(self, data_element, connection):
+ data_table = self.tables.data_table
+
+ connection.execute(
+ data_table.insert(),
+ {"id": 1, "name": "row1", "data": data_element},
+ )
+
+ row = connection.execute(select(data_table.c.data)).first()
+
+ eq_(row, (data_element,))
+
+ def _index_fixtures(include_comparison):
+
+ if include_comparison:
+            # only SQL Server and MariaDB support JSON comparison to a
+            # useful degree; MySQL, PG and SQLite can't, so it is not
+            # worth including JSON elements here.
+ json_elements = []
+ else:
+ json_elements = [
+ ("json", {"foo": "bar"}),
+ ("json", ["one", "two", "three"]),
+ (None, {"foo": "bar"}),
+ (None, ["one", "two", "three"]),
+ ]
+
+ elements = [
+ ("boolean", True),
+ ("boolean", False),
+ ("boolean", None),
+ ("string", "some string"),
+ ("string", None),
+ ("string", util.u("réve illé")),
+ (
+ "string",
+ util.u("réve🐍 illé"),
+ testing.requires.json_index_supplementary_unicode_element,
+ ),
+ ("integer", 15),
+ ("integer", 1),
+ ("integer", 0),
+ ("integer", None),
+ ("float", 28.5),
+ ("float", None),
+ (
+ "float",
+ 1234567.89,
+ ),
+ ("numeric", 1234567.89),
+ # this one "works" because the float value you see here is
+ # lost immediately to floating point stuff
+ ("numeric", 99998969694839.983485848, requirements.python3),
+ ("numeric", 99939.983485848, requirements.python3),
+ ("_decimal", decimal.Decimal("1234567.89")),
+ (
+ "_decimal",
+ decimal.Decimal("99998969694839.983485848"),
+ # fails on SQLite and MySQL (non-mariadb)
+ requirements.cast_precision_numerics_many_significant_digits,
+ ),
+ (
+ "_decimal",
+ decimal.Decimal("99939.983485848"),
+ ),
+ ] + json_elements
+
+ def decorate(fn):
+ fn = testing.combinations(id_="sa", *elements)(fn)
+
+ return fn
+
+ return decorate
+
+ def _json_value_insert(self, connection, datatype, value, data_element):
+ data_table = self.tables.data_table
+ if datatype == "_decimal":
+
+            # Python's builtin json serializer doesn't support Decimal
+            # objects without implicit conversion to float; users can
+            # otherwise use simplejson, which supports precision
+            # decimals
+
+ # https://bugs.python.org/issue16535
+
+ # inserting as strings to avoid a new fixture around the
+ # dialect which would have idiosyncrasies for different
+ # backends.
+
+ class DecimalEncoder(json.JSONEncoder):
+ def default(self, o):
+ if isinstance(o, decimal.Decimal):
+ return str(o)
+ return super(DecimalEncoder, self).default(o)
+
+ json_data = json.dumps(data_element, cls=DecimalEncoder)
+
+ # take the quotes out. yup, there is *literally* no other
+ # way to get Python's json.dumps() to put all the digits in
+ # the string
+ json_data = re.sub(r'"(%s)"' % str(value), str(value), json_data)
+
+ datatype = "numeric"
+
+ connection.execute(
+ data_table.insert().values(
+ name="row1",
+ # to pass the string directly to every backend, including
+ # PostgreSQL which needs the value to be CAST as JSON
+ # both in the SQL as well as at the prepared statement
+ # level for asyncpg, while at the same time MySQL
+ # doesn't even support CAST for JSON, here we are
+ # sending the string embedded in the SQL without using
+ # a parameter.
+ data=bindparam(None, json_data, literal_execute=True),
+ nulldata=bindparam(None, json_data, literal_execute=True),
+ ),
+ )
+ else:
+ connection.execute(
+ data_table.insert(),
+ {
+ "name": "row1",
+ "data": data_element,
+ "nulldata": data_element,
+ },
+ )
+
+ p_s = None
+
+ if datatype:
+ if datatype == "numeric":
+ a, b = str(value).split(".")
+ s = len(b)
+ p = len(a) + s
+
+ if isinstance(value, decimal.Decimal):
+ compare_value = value
+ else:
+ compare_value = decimal.Decimal(str(value))
+
+ p_s = (p, s)
+ else:
+ compare_value = value
+ else:
+ compare_value = value
+
+ return datatype, compare_value, p_s
+
+ @_index_fixtures(False)
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_index_typed_access(self, datatype, value):
+ data_table = self.tables.data_table
+ data_element = {"key1": value}
+
+ with config.db.begin() as conn:
+
+ datatype, compare_value, p_s = self._json_value_insert(
+ conn, datatype, value, data_element
+ )
+
+ expr = data_table.c.data["key1"]
+ if datatype:
+ if datatype == "numeric" and p_s:
+ expr = expr.as_numeric(*p_s)
+ else:
+ expr = getattr(expr, "as_%s" % datatype)()
+
+ roundtrip = conn.scalar(select(expr))
+ eq_(roundtrip, compare_value)
+ if util.py3k: # skip py2k to avoid comparing unicode to str etc.
+ is_(type(roundtrip), type(compare_value))
+
+ @_index_fixtures(True)
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_index_typed_comparison(self, datatype, value):
+ data_table = self.tables.data_table
+ data_element = {"key1": value}
+
+ with config.db.begin() as conn:
+ datatype, compare_value, p_s = self._json_value_insert(
+ conn, datatype, value, data_element
+ )
+
+ expr = data_table.c.data["key1"]
+ if datatype:
+ if datatype == "numeric" and p_s:
+ expr = expr.as_numeric(*p_s)
+ else:
+ expr = getattr(expr, "as_%s" % datatype)()
+
+ row = conn.execute(
+ select(expr).where(expr == compare_value)
+ ).first()
+
+ # make sure we get a row even if value is None
+ eq_(row, (compare_value,))
+
+ @_index_fixtures(True)
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_path_typed_comparison(self, datatype, value):
+ data_table = self.tables.data_table
+ data_element = {"key1": {"subkey1": value}}
+ with config.db.begin() as conn:
+
+ datatype, compare_value, p_s = self._json_value_insert(
+ conn, datatype, value, data_element
+ )
+
+ expr = data_table.c.data[("key1", "subkey1")]
+
+ if datatype:
+ if datatype == "numeric" and p_s:
+ expr = expr.as_numeric(*p_s)
+ else:
+ expr = getattr(expr, "as_%s" % datatype)()
+
+ row = conn.execute(
+ select(expr).where(expr == compare_value)
+ ).first()
+
+ # make sure we get a row even if value is None
+ eq_(row, (compare_value,))
+
+ @testing.combinations(
+ (True,),
+ (False,),
+ (None,),
+ (15,),
+ (0,),
+ (-1,),
+ (-1.0,),
+ (15.052,),
+ ("a string",),
+ (util.u("réve illé"),),
+ (util.u("réve🐍 illé"),),
+ )
+ def test_single_element_round_trip(self, element):
+ data_table = self.tables.data_table
+ data_element = element
+ with config.db.begin() as conn:
+ conn.execute(
+ data_table.insert(),
+ {
+ "name": "row1",
+ "data": data_element,
+ "nulldata": data_element,
+ },
+ )
+
+ row = conn.execute(
+ select(data_table.c.data, data_table.c.nulldata)
+ ).first()
+
+ eq_(row, (data_element, data_element))
+
+ def test_round_trip_custom_json(self):
+ data_table = self.tables.data_table
+ data_element = {"key1": "data1"}
+
+ js = mock.Mock(side_effect=json.dumps)
+ jd = mock.Mock(side_effect=json.loads)
+ engine = engines.testing_engine(
+ options=dict(json_serializer=js, json_deserializer=jd)
+ )
+
+ # support sqlite :memory: database...
+ data_table.create(engine, checkfirst=True)
+ with engine.begin() as conn:
+ conn.execute(
+ data_table.insert(), {"name": "row1", "data": data_element}
+ )
+ row = conn.execute(select(data_table.c.data)).first()
+
+ eq_(row, (data_element,))
+ eq_(js.mock_calls, [mock.call(data_element)])
+ eq_(jd.mock_calls, [mock.call(json.dumps(data_element))])
+
+ @testing.combinations(
+ ("parameters",),
+ ("multiparameters",),
+ ("values",),
+ ("omit",),
+ argnames="insert_type",
+ )
+ def test_round_trip_none_as_sql_null(self, connection, insert_type):
+ col = self.tables.data_table.c["nulldata"]
+
+ conn = connection
+
+ if insert_type == "parameters":
+ stmt, params = self.tables.data_table.insert(), {
+ "name": "r1",
+ "nulldata": None,
+ "data": None,
+ }
+ elif insert_type == "multiparameters":
+ stmt, params = self.tables.data_table.insert(), [
+ {"name": "r1", "nulldata": None, "data": None}
+ ]
+ elif insert_type == "values":
+ stmt, params = (
+ self.tables.data_table.insert().values(
+ name="r1",
+ nulldata=None,
+ data=None,
+ ),
+ {},
+ )
+ elif insert_type == "omit":
+ stmt, params = (
+ self.tables.data_table.insert(),
+ {"name": "r1", "data": None},
+ )
+
+ else:
+ assert False
+
+ conn.execute(stmt, params)
+
+ eq_(
+ conn.scalar(
+ select(self.tables.data_table.c.name).where(col.is_(null()))
+ ),
+ "r1",
+ )
+
+ eq_(conn.scalar(select(col)), None)
+
+ def test_round_trip_json_null_as_json_null(self, connection):
+ col = self.tables.data_table.c["data"]
+
+ conn = connection
+ conn.execute(
+ self.tables.data_table.insert(),
+ {"name": "r1", "data": JSON.NULL},
+ )
+
+ eq_(
+ conn.scalar(
+ select(self.tables.data_table.c.name).where(
+ cast(col, String) == "null"
+ )
+ ),
+ "r1",
+ )
+
+ eq_(conn.scalar(select(col)), None)
+
+ @testing.combinations(
+ ("parameters",),
+ ("multiparameters",),
+ ("values",),
+ argnames="insert_type",
+ )
+ def test_round_trip_none_as_json_null(self, connection, insert_type):
+ col = self.tables.data_table.c["data"]
+
+ if insert_type == "parameters":
+ stmt, params = self.tables.data_table.insert(), {
+ "name": "r1",
+ "data": None,
+ }
+ elif insert_type == "multiparameters":
+ stmt, params = self.tables.data_table.insert(), [
+ {"name": "r1", "data": None}
+ ]
+ elif insert_type == "values":
+ stmt, params = (
+ self.tables.data_table.insert().values(name="r1", data=None),
+ {},
+ )
+ else:
+ assert False
+
+ conn = connection
+ conn.execute(stmt, params)
+
+ eq_(
+ conn.scalar(
+ select(self.tables.data_table.c.name).where(
+ cast(col, String) == "null"
+ )
+ ),
+ "r1",
+ )
+
+ eq_(conn.scalar(select(col)), None)
+
+ def test_unicode_round_trip(self):
+ # note we include Unicode supplementary characters as well
+ with config.db.begin() as conn:
+ conn.execute(
+ self.tables.data_table.insert(),
+ {
+ "name": "r1",
+ "data": {
+ util.u("réve🐍 illé"): util.u("réve🐍 illé"),
+ "data": {"k1": util.u("drôl🐍e")},
+ },
+ },
+ )
+
+ eq_(
+ conn.scalar(select(self.tables.data_table.c.data)),
+ {
+ util.u("réve🐍 illé"): util.u("réve🐍 illé"),
+ "data": {"k1": util.u("drôl🐍e")},
+ },
+ )
+
+ def test_eval_none_flag_orm(self, connection):
+
+ Base = declarative_base()
+
+ class Data(Base):
+ __table__ = self.tables.data_table
+
+ with Session(connection) as s:
+ d1 = Data(name="d1", data=None, nulldata=None)
+ s.add(d1)
+ s.commit()
+
+ s.bulk_insert_mappings(
+ Data, [{"name": "d2", "data": None, "nulldata": None}]
+ )
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String()),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d1")
+ .first(),
+ ("null", None),
+ )
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String()),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d2")
+ .first(),
+ ("null", None),
+ )
+
+
+class JSONLegacyStringCastIndexTest(
+ _LiteralRoundTripFixture, fixtures.TablesTest
+):
+ """test JSON index access with "cast to string", which we have documented
+ for a long time as how to compare JSON values, but is ultimately not
+ reliable in all cases. The "as_XYZ()" comparators should be used
+ instead.
+
+ """
+
+ __requires__ = ("json_type", "legacy_unconditional_json_extract")
+ __backend__ = True
+
+ datatype = JSON
+
+ data1 = {"key1": "value1", "key2": "value2"}
+
+ data2 = {
+ "Key 'One'": "value1",
+ "key two": "value2",
+ "key three": "value ' three '",
+ }
+
+ data3 = {
+ "key1": [1, 2, 3],
+ "key2": ["one", "two", "three"],
+ "key3": [{"four": "five"}, {"six": "seven"}],
+ }
+
+ data4 = ["one", "two", "three"]
+
+ data5 = {
+ "nested": {
+ "elem1": [{"a": "b", "c": "d"}, {"e": "f", "g": "h"}],
+ "elem2": {"elem3": {"elem4": "elem5"}},
+ }
+ }
+
+ data6 = {"a": 5, "b": "some value", "c": {"foo": "bar"}}
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "data_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(30), nullable=False),
+ Column("data", cls.datatype),
+ Column("nulldata", cls.datatype(none_as_null=True)),
+ )
+
+ def _criteria_fixture(self):
+ with config.db.begin() as conn:
+ conn.execute(
+ self.tables.data_table.insert(),
+ [
+ {"name": "r1", "data": self.data1},
+ {"name": "r2", "data": self.data2},
+ {"name": "r3", "data": self.data3},
+ {"name": "r4", "data": self.data4},
+ {"name": "r5", "data": self.data5},
+ {"name": "r6", "data": self.data6},
+ ],
+ )
+
+ def _test_index_criteria(self, crit, expected, test_literal=True):
+ self._criteria_fixture()
+ with config.db.connect() as conn:
+ stmt = select(self.tables.data_table.c.name).where(crit)
+
+ eq_(conn.scalar(stmt), expected)
+
+ if test_literal:
+ literal_sql = str(
+ stmt.compile(
+ config.db, compile_kwargs={"literal_binds": True}
+ )
+ )
+
+ eq_(conn.exec_driver_sql(literal_sql).scalar(), expected)
+
+ def test_string_cast_crit_spaces_in_key(self):
+ name = self.tables.data_table.c.name
+ col = self.tables.data_table.c["data"]
+
+ # limit the rows here to avoid PG error
+ # "cannot extract field from a non-object", which is
+ # fixed in 9.4 but may exist in 9.3
+ self._test_index_criteria(
+ and_(
+ name.in_(["r1", "r2", "r3"]),
+ cast(col["key two"], String) == '"value2"',
+ ),
+ "r2",
+ )
+
+ @config.requirements.json_array_indexes
+ def test_string_cast_crit_simple_int(self):
+ name = self.tables.data_table.c.name
+ col = self.tables.data_table.c["data"]
+
+ # limit the rows here to avoid PG error
+ # "cannot extract array element from a non-array", which is
+ # fixed in 9.4 but may exist in 9.3
+ self._test_index_criteria(
+ and_(
+ name == "r4",
+ cast(col[1], String) == '"two"',
+ ),
+ "r4",
+ )
+
+ def test_string_cast_crit_mixed_path(self):
+ col = self.tables.data_table.c["data"]
+ self._test_index_criteria(
+ cast(col[("key3", 1, "six")], String) == '"seven"',
+ "r3",
+ )
+
+ def test_string_cast_crit_string_path(self):
+ col = self.tables.data_table.c["data"]
+ self._test_index_criteria(
+ cast(col[("nested", "elem2", "elem3", "elem4")], String)
+ == '"elem5"',
+ "r5",
+ )
+
+ def test_string_cast_crit_against_string_basic(self):
+ name = self.tables.data_table.c.name
+ col = self.tables.data_table.c["data"]
+
+ self._test_index_criteria(
+ and_(
+ name == "r6",
+ cast(col["b"], String) == '"some value"',
+ ),
+ "r6",
+ )
+
+
+__all__ = (
+ "BinaryTest",
+ "UnicodeVarcharTest",
+ "UnicodeTextTest",
+ "JSONTest",
+ "JSONLegacyStringCastIndexTest",
+ "DateTest",
+ "DateTimeTest",
+ "DateTimeTZTest",
+ "TextTest",
+ "NumericTest",
+ "IntegerTest",
+ "CastTypeDecoratorTest",
+ "DateTimeHistoricTest",
+ "DateTimeCoercedToDateTimeTest",
+ "TimeMicrosecondsTest",
+ "TimestampMicrosecondsTest",
+ "TimeTest",
+ "TimeTZTest",
+ "DateTimeMicrosecondsTest",
+ "DateHistoricTest",
+ "StringTest",
+ "BooleanTest",
+)
diff --git a/lib/sqlalchemy/testing/suite/test_unicode_ddl.py b/lib/sqlalchemy/testing/suite/test_unicode_ddl.py
new file mode 100644
index 0000000..a4ae334
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_unicode_ddl.py
@@ -0,0 +1,206 @@
+# coding: utf-8
+"""verrrrry basic unicode column name testing"""
+
+from sqlalchemy import desc
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import testing
+from sqlalchemy import util
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import Table
+from sqlalchemy.util import u
+from sqlalchemy.util import ue
+
+
+class UnicodeSchemaTest(fixtures.TablesTest):
+ __requires__ = ("unicode_ddl",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ global t1, t2, t3
+
+ t1 = Table(
+ u("unitable1"),
+ metadata,
+ Column(u("méil"), Integer, primary_key=True),
+ Column(ue("\u6e2c\u8a66"), Integer),
+ test_needs_fk=True,
+ )
+ t2 = Table(
+ u("Unitéble2"),
+ metadata,
+ Column(u("méil"), Integer, primary_key=True, key="a"),
+ Column(
+ ue("\u6e2c\u8a66"),
+ Integer,
+ ForeignKey(u("unitable1.méil")),
+ key="b",
+ ),
+ test_needs_fk=True,
+ )
+
+ # Few DBs support Unicode foreign keys
+ if testing.against("sqlite"):
+ t3 = Table(
+ ue("\u6e2c\u8a66"),
+ metadata,
+ Column(
+ ue("\u6e2c\u8a66_id"),
+ Integer,
+ primary_key=True,
+ autoincrement=False,
+ ),
+ Column(
+ ue("unitable1_\u6e2c\u8a66"),
+ Integer,
+ ForeignKey(ue("unitable1.\u6e2c\u8a66")),
+ ),
+ Column(
+ u("Unitéble2_b"), Integer, ForeignKey(u("Unitéble2.b"))
+ ),
+ Column(
+ ue("\u6e2c\u8a66_self"),
+ Integer,
+ ForeignKey(ue("\u6e2c\u8a66.\u6e2c\u8a66_id")),
+ ),
+ test_needs_fk=True,
+ )
+ else:
+ t3 = Table(
+ ue("\u6e2c\u8a66"),
+ metadata,
+ Column(
+ ue("\u6e2c\u8a66_id"),
+ Integer,
+ primary_key=True,
+ autoincrement=False,
+ ),
+ Column(ue("unitable1_\u6e2c\u8a66"), Integer),
+ Column(u("Unitéble2_b"), Integer),
+ Column(ue("\u6e2c\u8a66_self"), Integer),
+ test_needs_fk=True,
+ )
+
+ def test_insert(self, connection):
+ connection.execute(t1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
+ connection.execute(t2.insert(), {u("a"): 1, u("b"): 1})
+ connection.execute(
+ t3.insert(),
+ {
+ ue("\u6e2c\u8a66_id"): 1,
+ ue("unitable1_\u6e2c\u8a66"): 5,
+ u("Unitéble2_b"): 1,
+ ue("\u6e2c\u8a66_self"): 1,
+ },
+ )
+
+ eq_(connection.execute(t1.select()).fetchall(), [(1, 5)])
+ eq_(connection.execute(t2.select()).fetchall(), [(1, 1)])
+ eq_(connection.execute(t3.select()).fetchall(), [(1, 5, 1, 1)])
+
+ def test_col_targeting(self, connection):
+ connection.execute(t1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
+ connection.execute(t2.insert(), {u("a"): 1, u("b"): 1})
+ connection.execute(
+ t3.insert(),
+ {
+ ue("\u6e2c\u8a66_id"): 1,
+ ue("unitable1_\u6e2c\u8a66"): 5,
+ u("Unitéble2_b"): 1,
+ ue("\u6e2c\u8a66_self"): 1,
+ },
+ )
+
+ row = connection.execute(t1.select()).first()
+ eq_(row._mapping[t1.c[u("méil")]], 1)
+ eq_(row._mapping[t1.c[ue("\u6e2c\u8a66")]], 5)
+
+ row = connection.execute(t2.select()).first()
+ eq_(row._mapping[t2.c[u("a")]], 1)
+ eq_(row._mapping[t2.c[u("b")]], 1)
+
+ row = connection.execute(t3.select()).first()
+ eq_(row._mapping[t3.c[ue("\u6e2c\u8a66_id")]], 1)
+ eq_(row._mapping[t3.c[ue("unitable1_\u6e2c\u8a66")]], 5)
+ eq_(row._mapping[t3.c[u("Unitéble2_b")]], 1)
+ eq_(row._mapping[t3.c[ue("\u6e2c\u8a66_self")]], 1)
+
+ def test_reflect(self, connection):
+ connection.execute(t1.insert(), {u("méil"): 2, ue("\u6e2c\u8a66"): 7})
+ connection.execute(t2.insert(), {u("a"): 2, u("b"): 2})
+ connection.execute(
+ t3.insert(),
+ {
+ ue("\u6e2c\u8a66_id"): 2,
+ ue("unitable1_\u6e2c\u8a66"): 7,
+ u("Unitéble2_b"): 2,
+ ue("\u6e2c\u8a66_self"): 2,
+ },
+ )
+
+ meta = MetaData()
+ tt1 = Table(t1.name, meta, autoload_with=connection)
+ tt2 = Table(t2.name, meta, autoload_with=connection)
+ tt3 = Table(t3.name, meta, autoload_with=connection)
+
+ connection.execute(tt1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
+ connection.execute(tt2.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 1})
+ connection.execute(
+ tt3.insert(),
+ {
+ ue("\u6e2c\u8a66_id"): 1,
+ ue("unitable1_\u6e2c\u8a66"): 5,
+ u("Unitéble2_b"): 1,
+ ue("\u6e2c\u8a66_self"): 1,
+ },
+ )
+
+ eq_(
+ connection.execute(
+ tt1.select().order_by(desc(u("méil")))
+ ).fetchall(),
+ [(2, 7), (1, 5)],
+ )
+ eq_(
+ connection.execute(
+ tt2.select().order_by(desc(u("méil")))
+ ).fetchall(),
+ [(2, 2), (1, 1)],
+ )
+ eq_(
+ connection.execute(
+ tt3.select().order_by(desc(ue("\u6e2c\u8a66_id")))
+ ).fetchall(),
+ [(2, 7, 2, 2), (1, 5, 1, 1)],
+ )
+
+ def test_repr(self):
+ meta = MetaData()
+ t = Table(
+ ue("\u6e2c\u8a66"), meta, Column(ue("\u6e2c\u8a66_id"), Integer)
+ )
+
+ if util.py2k:
+ eq_(
+ repr(t),
+ (
+ "Table('\\u6e2c\\u8a66', MetaData(), "
+ "Column('\\u6e2c\\u8a66_id', Integer(), "
+ "table=<\u6e2c\u8a66>), "
+ "schema=None)"
+ ),
+ )
+ else:
+ eq_(
+ repr(t),
+ (
+ "Table('測試', MetaData(), "
+ "Column('測試_id', Integer(), "
+ "table=<測試>), "
+ "schema=None)"
+ ),
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py
new file mode 100644
index 0000000..f04a9d5
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_update_delete.py
@@ -0,0 +1,60 @@
+from .. import fixtures
+from ..assertions import eq_
+from ..schema import Column
+from ..schema import Table
+from ... import Integer
+from ... import String
+
+
+class SimpleUpdateDeleteTest(fixtures.TablesTest):
+ run_deletes = "each"
+ __requires__ = ("sane_rowcount",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "plain_pk",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.plain_pk.insert(),
+ [
+ {"id": 1, "data": "d1"},
+ {"id": 2, "data": "d2"},
+ {"id": 3, "data": "d3"},
+ ],
+ )
+
+ def test_update(self, connection):
+ t = self.tables.plain_pk
+ r = connection.execute(
+ t.update().where(t.c.id == 2), dict(data="d2_new")
+ )
+ assert not r.is_insert
+ assert not r.returns_rows
+ assert r.rowcount == 1
+
+ eq_(
+ connection.execute(t.select().order_by(t.c.id)).fetchall(),
+ [(1, "d1"), (2, "d2_new"), (3, "d3")],
+ )
+
+ def test_delete(self, connection):
+ t = self.tables.plain_pk
+ r = connection.execute(t.delete().where(t.c.id == 2))
+ assert not r.is_insert
+ assert not r.returns_rows
+ assert r.rowcount == 1
+ eq_(
+ connection.execute(t.select().order_by(t.c.id)).fetchall(),
+ [(1, "d1"), (3, "d3")],
+ )
+
+
+__all__ = ("SimpleUpdateDeleteTest",)
diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py
new file mode 100644
index 0000000..be89bc6
--- /dev/null
+++ b/lib/sqlalchemy/testing/util.py
@@ -0,0 +1,458 @@
+# testing/util.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import decimal
+import gc
+import random
+import sys
+import types
+
+from . import config
+from . import mock
+from .. import inspect
+from ..engine import Connection
+from ..schema import Column
+from ..schema import DropConstraint
+from ..schema import DropTable
+from ..schema import ForeignKeyConstraint
+from ..schema import MetaData
+from ..schema import Table
+from ..sql import schema
+from ..sql.sqltypes import Integer
+from ..util import decorator
+from ..util import defaultdict
+from ..util import has_refcount_gc
+from ..util import inspect_getfullargspec
+from ..util import py2k
+
+
+if not has_refcount_gc:
+
+ def non_refcount_gc_collect(*args):
+ gc.collect()
+ gc.collect()
+
+ gc_collect = lazy_gc = non_refcount_gc_collect
+else:
+ # assume CPython - straight gc.collect, lazy_gc() is a pass
+ gc_collect = gc.collect
+
+ def lazy_gc():
+ pass
+
+
+def picklers():
+ picklers = set()
+ if py2k:
+ try:
+ import cPickle
+
+ picklers.add(cPickle)
+ except ImportError:
+ pass
+
+ import pickle
+
+ picklers.add(pickle)
+
+ # yes, this thing needs this much testing
+ for pickle_ in picklers:
+ for protocol in range(-2, pickle.HIGHEST_PROTOCOL):
+ yield pickle_.loads, lambda d: pickle_.dumps(d, protocol)
+
+
+if py2k:
+
+ def random_choices(population, k=1):
+ pop = list(population)
+ # lame but works :)
+ random.shuffle(pop)
+ return pop[0:k]
+
+
+else:
+
+ def random_choices(population, k=1):
+ return random.choices(population, k=k)
+
+
+def round_decimal(value, prec):
+ if isinstance(value, float):
+ return round(value, prec)
+
+ # can also use shift() here but that is 2.6 only
+ return (value * decimal.Decimal("1" + "0" * prec)).to_integral(
+ decimal.ROUND_FLOOR
+ ) / pow(10, prec)
+
+
+class RandomSet(set):
+ def __iter__(self):
+ l = list(set.__iter__(self))
+ random.shuffle(l)
+ return iter(l)
+
+ def pop(self):
+ index = random.randint(0, len(self) - 1)
+ item = list(set.__iter__(self))[index]
+ self.remove(item)
+ return item
+
+ def union(self, other):
+ return RandomSet(set.union(self, other))
+
+ def difference(self, other):
+ return RandomSet(set.difference(self, other))
+
+ def intersection(self, other):
+ return RandomSet(set.intersection(self, other))
+
+ def copy(self):
+ return RandomSet(self)
+
+
+def conforms_partial_ordering(tuples, sorted_elements):
+ """True if the given sorting conforms to the given partial ordering."""
+
+ deps = defaultdict(set)
+ for parent, child in tuples:
+ deps[parent].add(child)
+ for i, node in enumerate(sorted_elements):
+ for n in sorted_elements[i:]:
+ if node in deps[n]:
+ return False
+ else:
+ return True
+
+
+def all_partial_orderings(tuples, elements):
+ edges = defaultdict(set)
+ for parent, child in tuples:
+ edges[child].add(parent)
+
+ def _all_orderings(elements):
+
+ if len(elements) == 1:
+ yield list(elements)
+ else:
+ for elem in elements:
+ subset = set(elements).difference([elem])
+ if not subset.intersection(edges[elem]):
+ for sub_ordering in _all_orderings(subset):
+ yield [elem] + sub_ordering
+
+ return iter(_all_orderings(elements))
+
+
+def function_named(fn, name):
+ """Return a function with a given __name__.
+
+ Will assign to __name__ and return the original function if possible on
+ the Python implementation, otherwise a new function will be constructed.
+
+ This function should be phased out as much as possible
+ in favor of @decorator. Tests that "generate" many named tests
+ should be modernized.
+
+ """
+ try:
+ fn.__name__ = name
+ except TypeError:
+ fn = types.FunctionType(
+ fn.__code__, fn.__globals__, name, fn.__defaults__, fn.__closure__
+ )
+ return fn
+
+
+def run_as_contextmanager(ctx, fn, *arg, **kw):
+ """Run the given function under the given contextmanager,
+ simulating the behavior of 'with' to support older
+ Python versions.
+
+    This is no longer necessary now that Python 2.6 is the minimum
+    version; however, some tests still use this structure.
+
+ """
+
+ obj = ctx.__enter__()
+ try:
+ result = fn(obj, *arg, **kw)
+ ctx.__exit__(None, None, None)
+ return result
+ except:
+ exc_info = sys.exc_info()
+ raise_ = ctx.__exit__(*exc_info)
+ if not raise_:
+ raise
+ else:
+ return raise_
+
+
+def rowset(results):
+ """Converts the results of sql execution into a plain set of column tuples.
+
+ Useful for asserting the results of an unordered query.
+ """
+
+ return {tuple(row) for row in results}
+
+
+def fail(msg):
+ assert False, msg
+
+
+@decorator
+def provide_metadata(fn, *args, **kw):
+ """Provide bound MetaData for a single test, dropping afterwards.
+
+ Legacy; use the "metadata" pytest fixture.
+
+ """
+
+ from . import fixtures
+
+ metadata = schema.MetaData()
+ self = args[0]
+ prev_meta = getattr(self, "metadata", None)
+ self.metadata = metadata
+ try:
+ return fn(*args, **kw)
+ finally:
+ # close out some things that get in the way of dropping tables.
+ # when using the "metadata" fixture, there is a set ordering
+ # of things that makes sure things are cleaned up in order, however
+ # the simple "decorator" nature of this legacy function means
+ # we have to hardcode some of that cleanup ahead of time.
+
+ # close ORM sessions
+ fixtures._close_all_sessions()
+
+ # integrate with the "connection" fixture as there are many
+ # tests where it is used along with provide_metadata
+ if fixtures._connection_fixture_connection:
+ # TODO: this warning can be used to find all the places
+ # this is used with connection fixture
+ # warn("mixing legacy provide metadata with connection fixture")
+ drop_all_tables_from_metadata(
+ metadata, fixtures._connection_fixture_connection
+ )
+ # as the provide_metadata fixture is often used with "testing.db",
+ # when we do the drop we have to commit the transaction so that
+ # the DB is actually updated as the CREATE would have been
+ # committed
+ fixtures._connection_fixture_connection.get_transaction().commit()
+ else:
+ drop_all_tables_from_metadata(metadata, config.db)
+ self.metadata = prev_meta
+
+
+def flag_combinations(*combinations):
+ """A facade around @testing.combinations() oriented towards boolean
+ keyword-based arguments.
+
+ Basically generates a nice looking identifier based on the keywords
+ and also sets up the argument names.
+
+ E.g.::
+
+ @testing.flag_combinations(
+ dict(lazy=False, passive=False),
+ dict(lazy=True, passive=False),
+ dict(lazy=False, passive=True),
+ dict(lazy=False, passive=True, raiseload=True),
+ )
+
+
+ would result in::
+
+ @testing.combinations(
+ ('', False, False, False),
+ ('lazy', True, False, False),
+            ('passive', False, True, False),
+            ('passive_raiseload', False, True, True),
+ id_='iaaa',
+ argnames='lazy,passive,raiseload'
+ )
+
+ """
+
+ keys = set()
+
+ for d in combinations:
+ keys.update(d)
+
+ keys = sorted(keys)
+
+ return config.combinations(
+ *[
+ ("_".join(k for k in keys if d.get(k, False)),)
+ + tuple(d.get(k, False) for k in keys)
+ for d in combinations
+ ],
+ id_="i" + ("a" * len(keys)),
+ argnames=",".join(keys)
+ )
+
+
+def lambda_combinations(lambda_arg_sets, **kw):
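+    # Build a testing.combinations() decorator whose cases are fixture
+    # callables, each returning one of the argument sets produced by the
+    # given lambda; the lambda is first probed with Mock objects only to
+    # count how many sets it yields.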
+ args = inspect_getfullargspec(lambda_arg_sets)
+
+ arg_sets = lambda_arg_sets(*[mock.Mock() for arg in args[0]])
+
+ def create_fixture(pos):
+ def fixture(**kw):
+ return lambda_arg_sets(**kw)[pos]
+
+ fixture.__name__ = "fixture_%3.3d" % pos
+ return fixture
+
+ return config.combinations(
+ *[(create_fixture(i),) for i in range(len(arg_sets))], **kw
+ )
+
+
+def resolve_lambda(__fn, **kw):
+ """Given a no-arg lambda and a namespace, return a new lambda that
+ has all the values filled in.
+
+ This is used so that we can have module-level fixtures that
+ refer to instance-level variables using lambdas.
+
+ """
+
+ pos_args = inspect_getfullargspec(__fn)[0]
+ pass_pos_args = {arg: kw.pop(arg) for arg in pos_args}
+ glb = dict(__fn.__globals__)
+ glb.update(kw)
+ new_fn = types.FunctionType(__fn.__code__, glb)
+ return new_fn(**pass_pos_args)
+
+
+def metadata_fixture(ddl="function"):
+ """Provide MetaData for a pytest fixture."""
+
+ def decorate(fn):
+ def run_ddl(self):
+
+ metadata = self.metadata = schema.MetaData()
+ try:
+ result = fn(self, metadata)
+ metadata.create_all(config.db)
+ # TODO:
+ # somehow get a per-function dml erase fixture here
+ yield result
+ finally:
+ metadata.drop_all(config.db)
+
+ return config.fixture(scope=ddl)(run_ddl)
+
+ return decorate
+
+
+def force_drop_names(*names):
+ """Force the given table names to be dropped after test complete,
+ isolating for foreign key cycles
+
+ """
+
+ @decorator
+ def go(fn, *args, **kw):
+
+ try:
+ return fn(*args, **kw)
+ finally:
+ drop_all_tables(config.db, inspect(config.db), include_names=names)
+
+ return go
+
+
+class adict(dict):
+ """Dict keys available as attributes. Shadows."""
+
+ def __getattribute__(self, key):
+ try:
+ return self[key]
+ except KeyError:
+ return dict.__getattribute__(self, key)
+
+ def __call__(self, *keys):
+ return tuple([self[key] for key in keys])
+
+ get_all = __call__
+
+
+def drop_all_tables_from_metadata(metadata, engine_or_connection):
+ from . import engines
+
+ def go(connection):
+ engines.testing_reaper.prepare_for_drop_tables(connection)
+
+ if not connection.dialect.supports_alter:
+ from . import assertions
+
+ with assertions.expect_warnings(
+ "Can't sort tables", assert_=False
+ ):
+ metadata.drop_all(connection)
+ else:
+ metadata.drop_all(connection)
+
+ if not isinstance(engine_or_connection, Connection):
+ with engine_or_connection.begin() as connection:
+ go(connection)
+ else:
+ go(engine_or_connection)
+
+
+def drop_all_tables(engine, inspector, schema=None, include_names=None):
+
+ if include_names is not None:
+ include_names = set(include_names)
+
+ with engine.begin() as conn:
+ for tname, fkcs in reversed(
+ inspector.get_sorted_table_and_fkc_names(schema=schema)
+ ):
+ if tname:
+ if include_names is not None and tname not in include_names:
+ continue
+ conn.execute(
+ DropTable(Table(tname, MetaData(), schema=schema))
+ )
+ elif fkcs:
+ if not engine.dialect.supports_alter:
+ continue
+ for tname, fkc in fkcs:
+ if (
+ include_names is not None
+ and tname not in include_names
+ ):
+ continue
+ tb = Table(
+ tname,
+ MetaData(),
+ Column("x", Integer),
+ Column("y", Integer),
+ schema=schema,
+ )
+ conn.execute(
+ DropConstraint(
+ ForeignKeyConstraint([tb.c.x], [tb.c.y], name=fkc)
+ )
+ )
+
+
+def teardown_events(event_cls):
+ @decorator
+ def decorate(fn, *arg, **kw):
+ try:
+ return fn(*arg, **kw)
+ finally:
+ event_cls._clear()
+
+ return decorate
diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py
new file mode 100644
index 0000000..3e78387
--- /dev/null
+++ b/lib/sqlalchemy/testing/warnings.py
@@ -0,0 +1,82 @@
+# testing/warnings.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import absolute_import
+
+import warnings
+
+from . import assertions
+from .. import exc as sa_exc
+from ..util.langhelpers import _warnings_warn
+
+
+class SATestSuiteWarning(Warning):
+ """warning for a condition detected during tests that is non-fatal
+
+ Currently outside of SAWarning so that we can work around tools like
+ Alembic doing the wrong thing with warnings.
+
+ """
+
+
+def warn_test_suite(message):
+ _warnings_warn(message, category=SATestSuiteWarning)
+
+
+def setup_filters():
+ """Set global warning behavior for the test suite."""
+
+ # TODO: at this point we can use the normal pytest warnings plugin,
+ # if we decide the test suite can be linked to pytest only
+
+ origin = r"^(?:test|sqlalchemy)\..*"
+
+ warnings.filterwarnings(
+ "ignore", category=sa_exc.SAPendingDeprecationWarning
+ )
+ warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning)
+ warnings.filterwarnings("error", category=sa_exc.SAWarning)
+
+ warnings.filterwarnings("always", category=SATestSuiteWarning)
+
+ warnings.filterwarnings(
+ "error", category=DeprecationWarning, module=origin
+ )
+
+ # ignore things that are deprecated *as of* 2.0 :)
+ warnings.filterwarnings(
+ "ignore",
+ category=sa_exc.SADeprecationWarning,
+ message=r".*\(deprecated since: 2.0\)$",
+ )
+ warnings.filterwarnings(
+ "ignore",
+ category=sa_exc.SADeprecationWarning,
+ message=r"^The (Sybase|firebird) dialect is deprecated and will be",
+ )
+
+ try:
+ import pytest
+ except ImportError:
+ pass
+ else:
+ warnings.filterwarnings(
+ "once", category=pytest.PytestDeprecationWarning, module=origin
+ )
+
+
+def assert_warnings(fn, warning_msgs, regex=False):
+ """Assert that each of the given warnings are emitted by fn.
+
+ Deprecated. Please use assertions.expect_warnings().
+
+ """
+
+ with assertions._expect_warnings(
+ sa_exc.SAWarning, warning_msgs, regex=regex
+ ):
+ return fn()
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
new file mode 100644
index 0000000..07263c5
--- /dev/null
+++ b/lib/sqlalchemy/types.py
@@ -0,0 +1,119 @@
+# types.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Compatibility namespace for sqlalchemy.sql.types.
+
+"""
+
+__all__ = [
+ "TypeEngine",
+ "TypeDecorator",
+ "UserDefinedType",
+ "ExternalType",
+ "INT",
+ "CHAR",
+ "VARCHAR",
+ "NCHAR",
+ "NVARCHAR",
+ "TEXT",
+ "Text",
+ "FLOAT",
+ "NUMERIC",
+ "REAL",
+ "DECIMAL",
+ "TIMESTAMP",
+ "DATETIME",
+ "CLOB",
+ "BLOB",
+ "BINARY",
+ "VARBINARY",
+ "BOOLEAN",
+ "BIGINT",
+ "SMALLINT",
+ "INTEGER",
+ "DATE",
+ "TIME",
+ "TupleType",
+ "String",
+ "Integer",
+ "SmallInteger",
+ "BigInteger",
+ "Numeric",
+ "Float",
+ "DateTime",
+ "Date",
+ "Time",
+ "LargeBinary",
+ "Boolean",
+ "Unicode",
+ "Concatenable",
+ "UnicodeText",
+ "PickleType",
+ "Interval",
+ "Enum",
+ "Indexable",
+ "ARRAY",
+ "JSON",
+]
+
+from .sql.sqltypes import _Binary
+from .sql.sqltypes import ARRAY
+from .sql.sqltypes import BIGINT
+from .sql.sqltypes import BigInteger
+from .sql.sqltypes import BINARY
+from .sql.sqltypes import BLOB
+from .sql.sqltypes import BOOLEAN
+from .sql.sqltypes import Boolean
+from .sql.sqltypes import CHAR
+from .sql.sqltypes import CLOB
+from .sql.sqltypes import Concatenable
+from .sql.sqltypes import DATE
+from .sql.sqltypes import Date
+from .sql.sqltypes import DATETIME
+from .sql.sqltypes import DateTime
+from .sql.sqltypes import DECIMAL
+from .sql.sqltypes import Enum
+from .sql.sqltypes import FLOAT
+from .sql.sqltypes import Float
+from .sql.sqltypes import Indexable
+from .sql.sqltypes import INT
+from .sql.sqltypes import INTEGER
+from .sql.sqltypes import Integer
+from .sql.sqltypes import Interval
+from .sql.sqltypes import JSON
+from .sql.sqltypes import LargeBinary
+from .sql.sqltypes import MatchType
+from .sql.sqltypes import NCHAR
+from .sql.sqltypes import NULLTYPE
+from .sql.sqltypes import NullType
+from .sql.sqltypes import NUMERIC
+from .sql.sqltypes import Numeric
+from .sql.sqltypes import NVARCHAR
+from .sql.sqltypes import PickleType
+from .sql.sqltypes import REAL
+from .sql.sqltypes import SchemaType
+from .sql.sqltypes import SMALLINT
+from .sql.sqltypes import SmallInteger
+from .sql.sqltypes import String
+from .sql.sqltypes import STRINGTYPE
+from .sql.sqltypes import TEXT
+from .sql.sqltypes import Text
+from .sql.sqltypes import TIME
+from .sql.sqltypes import Time
+from .sql.sqltypes import TIMESTAMP
+from .sql.sqltypes import TupleType
+from .sql.sqltypes import Unicode
+from .sql.sqltypes import UnicodeText
+from .sql.sqltypes import VARBINARY
+from .sql.sqltypes import VARCHAR
+from .sql.type_api import adapt_type
+from .sql.type_api import ExternalType
+from .sql.type_api import to_instance
+from .sql.type_api import TypeDecorator
+from .sql.type_api import TypeEngine
+from .sql.type_api import UserDefinedType
+from .sql.type_api import Variant
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
new file mode 100644
index 0000000..33427e3
--- /dev/null
+++ b/lib/sqlalchemy/util/__init__.py
@@ -0,0 +1,175 @@
+# util/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+from collections import defaultdict
+from contextlib import contextmanager
+from functools import partial
+from functools import update_wrapper
+
+from ._collections import coerce_generator_arg
+from ._collections import coerce_to_immutabledict
+from ._collections import collections_abc
+from ._collections import column_dict
+from ._collections import column_set
+from ._collections import EMPTY_DICT
+from ._collections import EMPTY_SET
+from ._collections import FacadeDict
+from ._collections import flatten_iterator
+from ._collections import has_dupes
+from ._collections import has_intersection
+from ._collections import IdentitySet
+from ._collections import ImmutableContainer
+from ._collections import immutabledict
+from ._collections import ImmutableProperties
+from ._collections import LRUCache
+from ._collections import ordered_column_set
+from ._collections import OrderedDict
+from ._collections import OrderedIdentitySet
+from ._collections import OrderedProperties
+from ._collections import OrderedSet
+from ._collections import PopulateDict
+from ._collections import Properties
+from ._collections import ScopedRegistry
+from ._collections import sort_dictionary
+from ._collections import ThreadLocalRegistry
+from ._collections import to_column_set
+from ._collections import to_list
+from ._collections import to_set
+from ._collections import unique_list
+from ._collections import UniqueAppender
+from ._collections import update_copy
+from ._collections import WeakPopulateDict
+from ._collections import WeakSequence
+from ._preloaded import preload_module
+from ._preloaded import preloaded
+from .compat import ABC
+from .compat import arm
+from .compat import b
+from .compat import b64decode
+from .compat import b64encode
+from .compat import binary_type
+from .compat import binary_types
+from .compat import byte_buffer
+from .compat import callable
+from .compat import cmp
+from .compat import cpython
+from .compat import dataclass_fields
+from .compat import decode_backslashreplace
+from .compat import dottedgetter
+from .compat import has_refcount_gc
+from .compat import inspect_getfullargspec
+from .compat import int_types
+from .compat import iterbytes
+from .compat import itertools_filter
+from .compat import itertools_filterfalse
+from .compat import local_dataclass_fields
+from .compat import namedtuple
+from .compat import next
+from .compat import nullcontext
+from .compat import osx
+from .compat import parse_qsl
+from .compat import perf_counter
+from .compat import pickle
+from .compat import print_
+from .compat import py2k
+from .compat import py311
+from .compat import py37
+from .compat import py38
+from .compat import py39
+from .compat import py3k
+from .compat import pypy
+from .compat import quote_plus
+from .compat import raise_
+from .compat import raise_from_cause
+from .compat import reduce
+from .compat import reraise
+from .compat import string_types
+from .compat import StringIO
+from .compat import text_type
+from .compat import threading
+from .compat import timezone
+from .compat import TYPE_CHECKING
+from .compat import u
+from .compat import ue
+from .compat import unquote
+from .compat import unquote_plus
+from .compat import win32
+from .compat import with_metaclass
+from .compat import zip_longest
+from .concurrency import asyncio
+from .concurrency import await_fallback
+from .concurrency import await_only
+from .concurrency import greenlet_spawn
+from .concurrency import is_exit_exception
+from .deprecations import deprecated
+from .deprecations import deprecated_20
+from .deprecations import deprecated_20_cls
+from .deprecations import deprecated_cls
+from .deprecations import deprecated_params
+from .deprecations import inject_docstring_text
+from .deprecations import moved_20
+from .deprecations import SQLALCHEMY_WARN_20
+from .deprecations import warn_deprecated
+from .deprecations import warn_deprecated_20
+from .langhelpers import add_parameter_text
+from .langhelpers import as_interface
+from .langhelpers import asbool
+from .langhelpers import asint
+from .langhelpers import assert_arg_type
+from .langhelpers import attrsetter
+from .langhelpers import bool_or_str
+from .langhelpers import chop_traceback
+from .langhelpers import class_hierarchy
+from .langhelpers import classproperty
+from .langhelpers import clsname_as_plain_name
+from .langhelpers import coerce_kw_type
+from .langhelpers import constructor_copy
+from .langhelpers import constructor_key
+from .langhelpers import counter
+from .langhelpers import create_proxy_methods
+from .langhelpers import decode_slice
+from .langhelpers import decorator
+from .langhelpers import dictlike_iteritems
+from .langhelpers import duck_type_collection
+from .langhelpers import ellipses_string
+from .langhelpers import EnsureKWArgType
+from .langhelpers import format_argspec_init
+from .langhelpers import format_argspec_plus
+from .langhelpers import generic_repr
+from .langhelpers import get_callable_argspec
+from .langhelpers import get_cls_kwargs
+from .langhelpers import get_func_kwargs
+from .langhelpers import getargspec_init
+from .langhelpers import has_compiled_ext
+from .langhelpers import HasMemoized
+from .langhelpers import hybridmethod
+from .langhelpers import hybridproperty
+from .langhelpers import iterate_attributes
+from .langhelpers import map_bits
+from .langhelpers import md5_hex
+from .langhelpers import memoized_instancemethod
+from .langhelpers import memoized_property
+from .langhelpers import MemoizedSlots
+from .langhelpers import method_is_overridden
+from .langhelpers import methods_equivalent
+from .langhelpers import monkeypatch_proxied_specials
+from .langhelpers import NoneType
+from .langhelpers import only_once
+from .langhelpers import PluginLoader
+from .langhelpers import portable_instancemethod
+from .langhelpers import quoted_token_parser
+from .langhelpers import safe_reraise
+from .langhelpers import set_creation_order
+from .langhelpers import string_or_unprintable
+from .langhelpers import symbol
+from .langhelpers import unbound_method_to_callable
+from .langhelpers import walk_subclasses
+from .langhelpers import warn
+from .langhelpers import warn_exception
+from .langhelpers import warn_limited
+from .langhelpers import wrap_callable
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
new file mode 100644
index 0000000..8e21830
--- /dev/null
+++ b/lib/sqlalchemy/util/_collections.py
@@ -0,0 +1,1089 @@
+# util/_collections.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Collection classes and helpers."""
+
+from __future__ import absolute_import
+
+import operator
+import types
+import weakref
+
+from .compat import binary_types
+from .compat import collections_abc
+from .compat import itertools_filterfalse
+from .compat import py2k
+from .compat import py37
+from .compat import string_types
+from .compat import threading
+
+
+EMPTY_SET = frozenset()
+
+
+class ImmutableContainer(object):
+ def _immutable(self, *arg, **kw):
+ raise TypeError("%s object is immutable" % self.__class__.__name__)
+
+ __delitem__ = __setitem__ = __setattr__ = _immutable
+
+
+def _immutabledict_py_fallback():
+ class immutabledict(ImmutableContainer, dict):
+
+ clear = (
+ pop
+ ) = popitem = setdefault = update = ImmutableContainer._immutable
+
+ def __new__(cls, *args):
+ new = dict.__new__(cls)
+ dict.__init__(new, *args)
+ return new
+
+ def __init__(self, *args):
+ pass
+
+ def __reduce__(self):
+ return _immutabledict_reconstructor, (dict(self),)
+
+ def union(self, __d=None):
+ if not __d:
+ return self
+
+ new = dict.__new__(self.__class__)
+ dict.__init__(new, self)
+ dict.update(new, __d)
+ return new
+
+ def _union_w_kw(self, __d=None, **kw):
+ # not sure if C version works correctly w/ this yet
+ if not __d and not kw:
+ return self
+
+ new = dict.__new__(self.__class__)
+ dict.__init__(new, self)
+ if __d:
+ dict.update(new, __d)
+ dict.update(new, kw)
+ return new
+
+ def merge_with(self, *dicts):
+ new = None
+ for d in dicts:
+ if d:
+ if new is None:
+ new = dict.__new__(self.__class__)
+ dict.__init__(new, self)
+ dict.update(new, d)
+ if new is None:
+ return self
+
+ return new
+
+ def __repr__(self):
+ return "immutabledict(%s)" % dict.__repr__(self)
+
+ return immutabledict
+
+
+try:
+ from sqlalchemy.cimmutabledict import immutabledict
+
+ collections_abc.Mapping.register(immutabledict)
+
+except ImportError:
+ immutabledict = _immutabledict_py_fallback()
+
+ def _immutabledict_reconstructor(*arg):
+ """do the pickle dance"""
+ return immutabledict(*arg)
+
+
+def coerce_to_immutabledict(d):
+ if not d:
+ return EMPTY_DICT
+ elif isinstance(d, immutabledict):
+ return d
+ else:
+ return immutabledict(d)
+
+
+EMPTY_DICT = immutabledict()
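+
+# Illustrative usage of immutabledict (editor-added sketch; all names shown
+# are defined above):
+#
+#     d = immutabledict({"a": 1})
+#     d2 = d.union({"b": 2})   # returns a new immutabledict; d is unchanged
+#     d["c"] = 3               # raises TypeError: immutabledict object is immutable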
+
+
+class FacadeDict(ImmutableContainer, dict):
+ """A dictionary that is not publicly mutable."""
+
+ clear = pop = popitem = setdefault = update = ImmutableContainer._immutable
+
+ def __new__(cls, *args):
+ new = dict.__new__(cls)
+ return new
+
+ def copy(self):
+ raise NotImplementedError(
+ "an immutabledict shouldn't need to be copied. use dict(d) "
+ "if you need a mutable dictionary."
+ )
+
+ def __reduce__(self):
+ return FacadeDict, (dict(self),)
+
+ def _insert_item(self, key, value):
+ """insert an item into the dictionary directly."""
+ dict.__setitem__(self, key, value)
+
+ def __repr__(self):
+ return "FacadeDict(%s)" % dict.__repr__(self)
+
+
+class Properties(object):
+ """Provide a __getattr__/__setattr__ interface over a dict."""
+
+ __slots__ = ("_data",)
+
+ def __init__(self, data):
+ object.__setattr__(self, "_data", data)
+
+ def __len__(self):
+ return len(self._data)
+
+ def __iter__(self):
+ return iter(list(self._data.values()))
+
+ def __dir__(self):
+ return dir(super(Properties, self)) + [
+ str(k) for k in self._data.keys()
+ ]
+
+ def __add__(self, other):
+ return list(self) + list(other)
+
+ def __setitem__(self, key, obj):
+ self._data[key] = obj
+
+ def __getitem__(self, key):
+ return self._data[key]
+
+ def __delitem__(self, key):
+ del self._data[key]
+
+ def __setattr__(self, key, obj):
+ self._data[key] = obj
+
+ def __getstate__(self):
+ return {"_data": self._data}
+
+ def __setstate__(self, state):
+ object.__setattr__(self, "_data", state["_data"])
+
+ def __getattr__(self, key):
+ try:
+ return self._data[key]
+ except KeyError:
+ raise AttributeError(key)
+
+ def __contains__(self, key):
+ return key in self._data
+
+ def as_immutable(self):
+ """Return an immutable proxy for this :class:`.Properties`."""
+
+ return ImmutableProperties(self._data)
+
+ def update(self, value):
+ self._data.update(value)
+
+ def get(self, key, default=None):
+ if key in self:
+ return self[key]
+ else:
+ return default
+
+ def keys(self):
+ return list(self._data)
+
+ def values(self):
+ return list(self._data.values())
+
+ def items(self):
+ return list(self._data.items())
+
+ def has_key(self, key):
+ return key in self._data
+
+ def clear(self):
+ self._data.clear()
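+
+ # Illustrative usage: attribute access and item access read the same
+ # underlying dict.
+ #
+ #     props = Properties({"name": "x"})
+ #     props.name        # -> "x"
+ #     props["name"]     # -> "x"
+ #     props.other = 1   # stored in the underlying dict under "other"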
+
+
+class OrderedProperties(Properties):
+ """Provide a __getattr__/__setattr__ interface with an OrderedDict
+ as backing store."""
+
+ __slots__ = ()
+
+ def __init__(self):
+ Properties.__init__(self, OrderedDict())
+
+
+class ImmutableProperties(ImmutableContainer, Properties):
+ """Provide immutable dict/object attribute to an underlying dictionary."""
+
+ __slots__ = ()
+
+
+def _ordered_dictionary_sort(d, key=None):
+ """Sort an OrderedDict in-place."""
+
+ items = [(k, d[k]) for k in sorted(d, key=key)]
+
+ d.clear()
+
+ d.update(items)
+
+
+if py37:
+ OrderedDict = dict
+ sort_dictionary = _ordered_dictionary_sort
+
+else:
+ # prevent sort_dictionary from being used against a plain dictionary
+ # for Python < 3.7
+
+ def sort_dictionary(d, key=None):
+ """Sort an OrderedDict in place."""
+
+ d._ordered_dictionary_sort(key=key)
+
+ class OrderedDict(dict):
+ """Dictionary that maintains insertion order.
+
+ Superseded by Python dict as of Python 3.7
+
+ """
+
+ __slots__ = ("_list",)
+
+ def _ordered_dictionary_sort(self, key=None):
+ _ordered_dictionary_sort(self, key=key)
+
+ def __reduce__(self):
+ return OrderedDict, (self.items(),)
+
+ def __init__(self, ____sequence=None, **kwargs):
+ self._list = []
+ if ____sequence is None:
+ if kwargs:
+ self.update(**kwargs)
+ else:
+ self.update(____sequence, **kwargs)
+
+ def clear(self):
+ self._list = []
+ dict.clear(self)
+
+ def copy(self):
+ return self.__copy__()
+
+ def __copy__(self):
+ return OrderedDict(self)
+
+ def update(self, ____sequence=None, **kwargs):
+ if ____sequence is not None:
+ if hasattr(____sequence, "keys"):
+ for key in ____sequence.keys():
+ self.__setitem__(key, ____sequence[key])
+ else:
+ for key, value in ____sequence:
+ self[key] = value
+ if kwargs:
+ self.update(kwargs)
+
+ def setdefault(self, key, value):
+ if key not in self:
+ self.__setitem__(key, value)
+ return value
+ else:
+ return self.__getitem__(key)
+
+ def __iter__(self):
+ return iter(self._list)
+
+ def keys(self):
+ return list(self)
+
+ def values(self):
+ return [self[key] for key in self._list]
+
+ def items(self):
+ return [(key, self[key]) for key in self._list]
+
+ if py2k:
+
+ def itervalues(self):
+ return iter(self.values())
+
+ def iterkeys(self):
+ return iter(self)
+
+ def iteritems(self):
+ return iter(self.items())
+
+ def __setitem__(self, key, obj):
+ if key not in self:
+ try:
+ self._list.append(key)
+ except AttributeError:
+ # work around Python pickle loads() with
+ # dict subclass (seems to ignore __setstate__?)
+ self._list = [key]
+ dict.__setitem__(self, key, obj)
+
+ def __delitem__(self, key):
+ dict.__delitem__(self, key)
+ self._list.remove(key)
+
+ def pop(self, key, *default):
+ present = key in self
+ value = dict.pop(self, key, *default)
+ if present:
+ self._list.remove(key)
+ return value
+
+ def popitem(self):
+ item = dict.popitem(self)
+ self._list.remove(item[0])
+ return item
+
+
+class OrderedSet(set):
+ def __init__(self, d=None):
+ set.__init__(self)
+ if d is not None:
+ self._list = unique_list(d)
+ set.update(self, self._list)
+ else:
+ self._list = []
+
+ def add(self, element):
+ if element not in self:
+ self._list.append(element)
+ set.add(self, element)
+
+ def remove(self, element):
+ set.remove(self, element)
+ self._list.remove(element)
+
+ def insert(self, pos, element):
+ if element not in self:
+ self._list.insert(pos, element)
+ set.add(self, element)
+
+ def discard(self, element):
+ if element in self:
+ self._list.remove(element)
+ set.remove(self, element)
+
+ def clear(self):
+ set.clear(self)
+ self._list = []
+
+ def __getitem__(self, key):
+ return self._list[key]
+
+ def __iter__(self):
+ return iter(self._list)
+
+ def __add__(self, other):
+ return self.union(other)
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, self._list)
+
+ __str__ = __repr__
+
+ def update(self, iterable):
+ for e in iterable:
+ if e not in self:
+ self._list.append(e)
+ set.add(self, e)
+ return self
+
+ __ior__ = update
+
+ def union(self, other):
+ result = self.__class__(self)
+ result.update(other)
+ return result
+
+ __or__ = union
+
+ def intersection(self, other):
+ other = set(other)
+ return self.__class__(a for a in self if a in other)
+
+ __and__ = intersection
+
+ def symmetric_difference(self, other):
+ other = set(other)
+ result = self.__class__(a for a in self if a not in other)
+ result.update(a for a in other if a not in self)
+ return result
+
+ __xor__ = symmetric_difference
+
+ def difference(self, other):
+ other = set(other)
+ return self.__class__(a for a in self if a not in other)
+
+ __sub__ = difference
+
+ def intersection_update(self, other):
+ other = set(other)
+ set.intersection_update(self, other)
+ self._list = [a for a in self._list if a in other]
+ return self
+
+ __iand__ = intersection_update
+
+ def symmetric_difference_update(self, other):
+ set.symmetric_difference_update(self, other)
+ self._list = [a for a in self._list if a in self]
+ self._list += [a for a in other._list if a in self]
+ return self
+
+ __ixor__ = symmetric_difference_update
+
+ def difference_update(self, other):
+ set.difference_update(self, other)
+ self._list = [a for a in self._list if a in self]
+ return self
+
+ __isub__ = difference_update
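+
+ # Illustrative usage: insertion order is preserved across set operations.
+ #
+ #     s = OrderedSet([3, 1, 2])
+ #     list(s)          # [3, 1, 2]
+ #     list(s | {4})    # [3, 1, 2, 4]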
+
+
+class IdentitySet(object):
+ """A set that considers only object id() for uniqueness.
+
+ This strategy has edge cases for builtin types; it's possible to have
+ two 'foo' strings in one of these sets, for example. Use sparingly.
+
+ """
+
+ def __init__(self, iterable=None):
+ self._members = dict()
+ if iterable:
+ self.update(iterable)
+
+ def add(self, value):
+ self._members[id(value)] = value
+
+ def __contains__(self, value):
+ return id(value) in self._members
+
+ def remove(self, value):
+ del self._members[id(value)]
+
+ def discard(self, value):
+ try:
+ self.remove(value)
+ except KeyError:
+ pass
+
+ def pop(self):
+ try:
+ pair = self._members.popitem()
+ return pair[1]
+ except KeyError:
+ raise KeyError("pop from an empty set")
+
+ def clear(self):
+ self._members.clear()
+
+ def __cmp__(self, other):
+ raise TypeError("cannot compare sets using cmp()")
+
+ def __eq__(self, other):
+ if isinstance(other, IdentitySet):
+ return self._members == other._members
+ else:
+ return False
+
+ def __ne__(self, other):
+ if isinstance(other, IdentitySet):
+ return self._members != other._members
+ else:
+ return True
+
+ def issubset(self, iterable):
+ if isinstance(iterable, self.__class__):
+ other = iterable
+ else:
+ other = self.__class__(iterable)
+
+ if len(self) > len(other):
+ return False
+ for m in itertools_filterfalse(
+ other._members.__contains__, iter(self._members.keys())
+ ):
+ return False
+ return True
+
+ def __le__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return self.issubset(other)
+
+ def __lt__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return len(self) < len(other) and self.issubset(other)
+
+ def issuperset(self, iterable):
+ if isinstance(iterable, self.__class__):
+ other = iterable
+ else:
+ other = self.__class__(iterable)
+
+ if len(self) < len(other):
+ return False
+
+ for m in itertools_filterfalse(
+ self._members.__contains__, iter(other._members.keys())
+ ):
+ return False
+ return True
+
+ def __ge__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return self.issuperset(other)
+
+ def __gt__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return len(self) > len(other) and self.issuperset(other)
+
+ def union(self, iterable):
+ result = self.__class__()
+ members = self._members
+ result._members.update(members)
+ result._members.update((id(obj), obj) for obj in iterable)
+ return result
+
+ def __or__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return self.union(other)
+
+ def update(self, iterable):
+ self._members.update((id(obj), obj) for obj in iterable)
+
+ def __ior__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ self.update(other)
+ return self
+
+ def difference(self, iterable):
+ result = self.__class__()
+ members = self._members
+ if isinstance(iterable, self.__class__):
+ other = set(iterable._members.keys())
+ else:
+ other = {id(obj) for obj in iterable}
+ result._members.update(
+ ((k, v) for k, v in members.items() if k not in other)
+ )
+ return result
+
+ def __sub__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return self.difference(other)
+
+ def difference_update(self, iterable):
+ self._members = self.difference(iterable)._members
+
+ def __isub__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ self.difference_update(other)
+ return self
+
+ def intersection(self, iterable):
+ result = self.__class__()
+ members = self._members
+ if isinstance(iterable, self.__class__):
+ other = set(iterable._members.keys())
+ else:
+ other = {id(obj) for obj in iterable}
+ result._members.update(
+ (k, v) for k, v in members.items() if k in other
+ )
+ return result
+
+ def __and__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return self.intersection(other)
+
+ def intersection_update(self, iterable):
+ self._members = self.intersection(iterable)._members
+
+ def __iand__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ self.intersection_update(other)
+ return self
+
+ def symmetric_difference(self, iterable):
+ result = self.__class__()
+ members = self._members
+ if isinstance(iterable, self.__class__):
+ other = iterable._members
+ else:
+ other = {id(obj): obj for obj in iterable}
+ result._members.update(
+ ((k, v) for k, v in members.items() if k not in other)
+ )
+ result._members.update(
+ ((k, v) for k, v in other.items() if k not in members)
+ )
+ return result
+
+ def __xor__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return self.symmetric_difference(other)
+
+ def symmetric_difference_update(self, iterable):
+ self._members = self.symmetric_difference(iterable)._members
+
+ def __ixor__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ # use the in-place variant so the operation actually mutates this set
+ self.symmetric_difference_update(other)
+ return self
+
+ def copy(self):
+ return type(self)(iter(self._members.values()))
+
+ __copy__ = copy
+
+ def __len__(self):
+ return len(self._members)
+
+ def __iter__(self):
+ return iter(self._members.values())
+
+ def __hash__(self):
+ raise TypeError("set objects are unhashable")
+
+ def __repr__(self):
+ return "%s(%r)" % (type(self).__name__, list(self._members.values()))
+
+
+class WeakSequence(object):
+ def __init__(self, __elements=()):
+ # adapted from weakref.WeakKeyDictionary, prevent reference
+ # cycles in the collection itself
+ def _remove(item, selfref=weakref.ref(self)):
+ self = selfref()
+ if self is not None:
+ self._storage.remove(item)
+
+ self._remove = _remove
+ self._storage = [
+ weakref.ref(element, _remove) for element in __elements
+ ]
+
+ def append(self, item):
+ self._storage.append(weakref.ref(item, self._remove))
+
+ def __len__(self):
+ return len(self._storage)
+
+ def __iter__(self):
+ return (
+ obj for obj in (ref() for ref in self._storage) if obj is not None
+ )
+
+ def __getitem__(self, index):
+ try:
+ obj = self._storage[index]
+ except IndexError:
+ raise IndexError("Index %s out of range" % index)
+ else:
+ return obj()
+
+
+class OrderedIdentitySet(IdentitySet):
+ def __init__(self, iterable=None):
+ IdentitySet.__init__(self)
+ self._members = OrderedDict()
+ if iterable:
+ for o in iterable:
+ self.add(o)
+
+
+class PopulateDict(dict):
+ """A dict which populates missing values via a creation function.
+
+ Note the creation function takes a key, unlike
+ collections.defaultdict.
+
+ """
+
+ def __init__(self, creator):
+ self.creator = creator
+
+ def __missing__(self, key):
+ self[key] = val = self.creator(key)
+ return val
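+
+ # Illustrative usage: the creator receives the missing key.
+ #
+ #     cache = PopulateDict(str.upper)
+ #     cache["abc"]    # computes "ABC" once via the creator, stores and returns it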
+
+
+class WeakPopulateDict(dict):
+ """Like PopulateDict, but assumes a self + a method and does not create
+ a reference cycle.
+
+ """
+
+ def __init__(self, creator_method):
+ self.creator = creator_method.__func__
+ weakself = creator_method.__self__
+ self.weakself = weakref.ref(weakself)
+
+ def __missing__(self, key):
+ self[key] = val = self.creator(self.weakself(), key)
+ return val
+
+
+# Define collections that are capable of storing
+# ColumnElement objects as hashable keys/elements.
+# At this point, these are mostly historical; things
+# used to be more complicated.
+column_set = set
+column_dict = dict
+ordered_column_set = OrderedSet
+
+
+_getters = PopulateDict(operator.itemgetter)
+
+_property_getters = PopulateDict(
+ lambda idx: property(operator.itemgetter(idx))
+)
+
+
+def unique_list(seq, hashfunc=None):
+ seen = set()
+ seen_add = seen.add
+ if not hashfunc:
+ return [x for x in seq if x not in seen and not seen_add(x)]
+ else:
+ return [
+ x
+ for x in seq
+ if hashfunc(x) not in seen and not seen_add(hashfunc(x))
+ ]
+
+
+class UniqueAppender(object):
+ """Appends items to a collection ensuring uniqueness.
+
+ Additional appends() of the same object are ignored. Membership is
+ determined by identity (``is``), not equality (``==``).
+ """
+
+ def __init__(self, data, via=None):
+ self.data = data
+ self._unique = {}
+ if via:
+ self._data_appender = getattr(data, via)
+ elif hasattr(data, "append"):
+ self._data_appender = data.append
+ elif hasattr(data, "add"):
+ self._data_appender = data.add
+
+ def append(self, item):
+ id_ = id(item)
+ if id_ not in self._unique:
+ self._data_appender(item)
+ self._unique[id_] = True
+
+ def __iter__(self):
+ return iter(self.data)
+
+
+def coerce_generator_arg(arg):
+ if len(arg) == 1 and isinstance(arg[0], types.GeneratorType):
+ return list(arg[0])
+ else:
+ return arg
+
+
+def to_list(x, default=None):
+ if x is None:
+ return default
+ if not isinstance(x, collections_abc.Iterable) or isinstance(
+ x, string_types + binary_types
+ ):
+ return [x]
+ elif isinstance(x, list):
+ return x
+ else:
+ return list(x)
+
+
+def has_intersection(set_, iterable):
+ r"""return True if any items of set\_ are present in iterable.
+
+ Goes through special effort to ensure __hash__ is not called
+ on items in iterable that don't support it.
+
+ """
+ # TODO: optimize, write in C, etc.
+ return bool(set_.intersection([i for i in iterable if i.__hash__]))
+
+
+def to_set(x):
+ if x is None:
+ return set()
+ if not isinstance(x, set):
+ return set(to_list(x))
+ else:
+ return x
+
+
+def to_column_set(x):
+ if x is None:
+ return column_set()
+ if not isinstance(x, column_set):
+ return column_set(to_list(x))
+ else:
+ return x
+
+
+def update_copy(d, _new=None, **kw):
+ """Copy the given dict and update with the given values."""
+
+ d = d.copy()
+ if _new:
+ d.update(_new)
+ d.update(**kw)
+ return d
+
+
+def flatten_iterator(x):
+ """Given an iterator of which further sub-elements may also be
+ iterators, flatten the sub-elements into a single iterator.
+
+ """
+ for elem in x:
+ if not isinstance(elem, str) and hasattr(elem, "__iter__"):
+ for y in flatten_iterator(elem):
+ yield y
+ else:
+ yield elem
+
+
+class LRUCache(dict):
+ """Dictionary with 'squishy' removal of least
+ recently used items.
+
+ Note that either get() or [] should be used here, but
+ generally it's not safe to do an "in" check first as the dictionary
+ can change subsequent to that call.
+
+ """
+
+ __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex"
+
+ def __init__(self, capacity=100, threshold=0.5, size_alert=None):
+ self.capacity = capacity
+ self.threshold = threshold
+ self.size_alert = size_alert
+ self._counter = 0
+ self._mutex = threading.Lock()
+
+ def _inc_counter(self):
+ self._counter += 1
+ return self._counter
+
+ def get(self, key, default=None):
+ item = dict.get(self, key, default)
+ if item is not default:
+ item[2] = self._inc_counter()
+ return item[1]
+ else:
+ return default
+
+ def __getitem__(self, key):
+ item = dict.__getitem__(self, key)
+ item[2] = self._inc_counter()
+ return item[1]
+
+ def values(self):
+ return [i[1] for i in dict.values(self)]
+
+ def setdefault(self, key, value):
+ if key in self:
+ return self[key]
+ else:
+ self[key] = value
+ return value
+
+ def __setitem__(self, key, value):
+ item = dict.get(self, key)
+ if item is None:
+ item = [key, value, self._inc_counter()]
+ dict.__setitem__(self, key, item)
+ else:
+ item[1] = value
+ self._manage_size()
+
+ @property
+ def size_threshold(self):
+ return self.capacity + self.capacity * self.threshold
+
+ def _manage_size(self):
+ if not self._mutex.acquire(False):
+ return
+ try:
+ size_alert = bool(self.size_alert)
+ while len(self) > self.capacity + self.capacity * self.threshold:
+ if size_alert:
+ size_alert = False
+ self.size_alert(self)
+ by_counter = sorted(
+ dict.values(self), key=operator.itemgetter(2), reverse=True
+ )
+ for item in by_counter[self.capacity :]:
+ try:
+ del self[item[0]]
+ except KeyError:
+ # deleted elsewhere; skip
+ continue
+ finally:
+ self._mutex.release()
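+
+ # Illustrative behavior: with capacity=100 and threshold=0.5, no pruning
+ # occurs until the dict grows past 150 entries; it is then trimmed back
+ # toward the 100 most recently used items.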
+
+
+class ScopedRegistry(object):
+ """A Registry that can store one or multiple instances of a single
+ class on the basis of a "scope" function.
+
+ The object implements ``__call__`` as the "getter", so by
+ calling ``myregistry()`` the contained object is returned
+ for the current scope.
+
+ :param createfunc:
+ a callable that returns a new object to be placed in the registry
+
+ :param scopefunc:
+ a callable that will return a key to store/retrieve an object.
+ """
+
+ def __init__(self, createfunc, scopefunc):
+ """Construct a new :class:`.ScopedRegistry`.
+
+ :param createfunc: A creation function that will generate
+ a new value for the current scope, if none is present.
+
+ :param scopefunc: A function that returns a hashable
+ token representing the current scope (such as, current
+ thread identifier).
+
+ """
+ self.createfunc = createfunc
+ self.scopefunc = scopefunc
+ self.registry = {}
+
+ def __call__(self):
+ key = self.scopefunc()
+ try:
+ return self.registry[key]
+ except KeyError:
+ return self.registry.setdefault(key, self.createfunc())
+
+ def has(self):
+ """Return True if an object is present in the current scope."""
+
+ return self.scopefunc() in self.registry
+
+ def set(self, obj):
+ """Set the value for the current scope."""
+
+ self.registry[self.scopefunc()] = obj
+
+ def clear(self):
+ """Clear the current scope, if any."""
+
+ try:
+ del self.registry[self.scopefunc()]
+ except KeyError:
+ pass
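+
+ # Illustrative usage: one instance per "scope", e.g. per thread.
+ #
+ #     import threading
+ #     reg = ScopedRegistry(createfunc=dict, scopefunc=threading.get_ident)
+ #     reg()            # creates (or returns) the dict for the current thread
+ #     reg() is reg()   # True within the same thread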
+
+
+class ThreadLocalRegistry(ScopedRegistry):
+ """A :class:`.ScopedRegistry` that uses a ``threading.local()``
+ variable for storage.
+
+ """
+
+ def __init__(self, createfunc):
+ self.createfunc = createfunc
+ self.registry = threading.local()
+
+ def __call__(self):
+ try:
+ return self.registry.value
+ except AttributeError:
+ val = self.registry.value = self.createfunc()
+ return val
+
+ def has(self):
+ return hasattr(self.registry, "value")
+
+ def set(self, obj):
+ self.registry.value = obj
+
+ def clear(self):
+ try:
+ del self.registry.value
+ except AttributeError:
+ pass
+
+
+def has_dupes(sequence, target):
+ """Given a sequence and search object, return True if there's more
+ than one, False if zero or one of them.
+
+
+ """
+ # compare to .index version below, this version introduces less function
+ # overhead and is usually the same speed. At 15000 items (way bigger than
+ # a relationship-bound collection in memory usually is) it begins to
+ # fall behind the other version only by microseconds.
+ c = 0
+ for item in sequence:
+ if item is target:
+ c += 1
+ if c > 1:
+ return True
+ return False
+
+
+# .index version. the two __contains__ calls as well
+# as .index() and isinstance() slow this down.
+# def has_dupes(sequence, target):
+# if target not in sequence:
+# return False
+# elif not isinstance(sequence, collections_abc.Sequence):
+# return False
+#
+# idx = sequence.index(target)
+# return target in sequence[idx + 1:]
diff --git a/lib/sqlalchemy/util/_compat_py3k.py b/lib/sqlalchemy/util/_compat_py3k.py
new file mode 100644
index 0000000..ce659a4
--- /dev/null
+++ b/lib/sqlalchemy/util/_compat_py3k.py
@@ -0,0 +1,67 @@
+# util/_compat_py3k.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from functools import wraps
+
+# vendored from py3.7
+
+
+class _AsyncGeneratorContextManager:
+ """Helper for @asynccontextmanager."""
+
+ def __init__(self, func, args, kwds):
+ self.gen = func(*args, **kwds)
+ self.func, self.args, self.kwds = func, args, kwds
+ doc = getattr(func, "__doc__", None)
+ if doc is None:
+ doc = type(self).__doc__
+ self.__doc__ = doc
+
+ async def __aenter__(self):
+ try:
+ return await self.gen.__anext__()
+ except StopAsyncIteration:
+ raise RuntimeError("generator didn't yield") from None
+
+ async def __aexit__(self, typ, value, traceback):
+ if typ is None:
+ try:
+ await self.gen.__anext__()
+ except StopAsyncIteration:
+ return
+ else:
+ raise RuntimeError("generator didn't stop")
+ else:
+ if value is None:
+ value = typ()
+ # See _GeneratorContextManager.__exit__ for comments on subtleties
+ # in this implementation
+ try:
+ await self.gen.athrow(typ, value, traceback)
+ raise RuntimeError("generator didn't stop after athrow()")
+ except StopAsyncIteration as exc:
+ return exc is not value
+ except RuntimeError as exc:
+ if exc is value:
+ return False
+ if isinstance(value, (StopIteration, StopAsyncIteration)):
+ if exc.__cause__ is value:
+ return False
+ raise
+ except BaseException as exc:
+ if exc is not value:
+ raise
+
+
+# using the vendored version in all cases at the moment to establish
+# full test coverage
+def asynccontextmanager(func):
+ @wraps(func)
+ def helper(*args, **kwds):
+ return _AsyncGeneratorContextManager(func, args, kwds)
+
+ return helper
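+
+# Illustrative usage, mirroring contextlib.asynccontextmanager (the names
+# "conn", "begin" and "commit" below are placeholders, not part of this module):
+#
+#     @asynccontextmanager
+#     async def transaction(conn):
+#         await conn.begin()
+#         try:
+#             yield conn
+#         finally:
+#             await conn.commit()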
diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py
new file mode 100644
index 0000000..0b12834
--- /dev/null
+++ b/lib/sqlalchemy/util/_concurrency_py3k.py
@@ -0,0 +1,194 @@
+# util/_concurrency_py3k.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import asyncio
+import sys
+from typing import Any
+from typing import Callable
+from typing import Coroutine
+
+import greenlet
+
+from . import compat
+from .langhelpers import memoized_property
+from .. import exc
+
+# If greenlet.gr_context is present in current version of greenlet,
+# it will be set with the current context on creation.
+# Refs: https://github.com/python-greenlet/greenlet/pull/198
+_has_gr_context = hasattr(greenlet.getcurrent(), "gr_context")
+
+
+def is_exit_exception(e):
+ # note asyncio.CancelledError is already BaseException
+ # so was an exit exception in any case
+ return not isinstance(e, Exception) or isinstance(
+ e, (asyncio.TimeoutError, asyncio.CancelledError)
+ )
+
+
+# implementation based on snaury gist at
+# https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef
+# Issue for context: https://github.com/python-greenlet/greenlet/issues/173
+
+
+class _AsyncIoGreenlet(greenlet.greenlet):
+ def __init__(self, fn, driver):
+ greenlet.greenlet.__init__(self, fn, driver)
+ self.driver = driver
+ if _has_gr_context:
+ self.gr_context = driver.gr_context
+
+
+def await_only(awaitable: Coroutine) -> Any:
+ """Awaits an async function in a sync method.
+
+ The sync method must be inside a :func:`greenlet_spawn` context.
+ :func:`await_only` calls cannot be nested.
+
+ :param awaitable: The coroutine to call.
+
+ """
+ # this is called in the context greenlet while running fn
+ current = greenlet.getcurrent()
+ if not isinstance(current, _AsyncIoGreenlet):
+ raise exc.MissingGreenlet(
+ "greenlet_spawn has not been called; can't call await_only() "
+ "here. Was IO attempted in an unexpected place?"
+ )
+
+ # returns the control to the driver greenlet passing it
+ # a coroutine to run. Once the awaitable is done, the driver greenlet
+ # switches back to this greenlet with the result of awaitable that is
+ # then returned to the caller (or raised as error)
+ return current.driver.switch(awaitable)
+
+
+def await_fallback(awaitable: Coroutine) -> Any:
+ """Awaits an async function in a sync method.
+
+ The sync method must be inside a :func:`greenlet_spawn` context.
+ :func:`await_fallback` calls cannot be nested.
+
+ :param awaitable: The coroutine to call.
+
+ """
+ # this is called in the context greenlet while running fn
+ current = greenlet.getcurrent()
+ if not isinstance(current, _AsyncIoGreenlet):
+ loop = get_event_loop()
+ if loop.is_running():
+ raise exc.MissingGreenlet(
+ "greenlet_spawn has not been called and asyncio event "
+ "loop is already running; can't call await_fallback() here. "
+ "Was IO attempted in an unexpected place?"
+ )
+ return loop.run_until_complete(awaitable)
+
+ return current.driver.switch(awaitable)
+
+
+async def greenlet_spawn(
+ fn: Callable, *args, _require_await=False, **kwargs
+) -> Any:
+ """Runs a sync function ``fn`` in a new greenlet.
+
+ The sync function can then use :func:`await_only` to wait for async
+ functions.
+
+ :param fn: The sync callable to call.
+ :param \\*args: Positional arguments to pass to the ``fn`` callable.
+ :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.
+ """
+
+ context = _AsyncIoGreenlet(fn, greenlet.getcurrent())
+ # runs the function synchronously in gl greenlet. If the execution
+ # is interrupted by await_only, context is not dead and result is a
+ # coroutine to wait. If the context is dead the function has
+ # returned, and its result can be returned.
+ switch_occurred = False
+ try:
+ result = context.switch(*args, **kwargs)
+ while not context.dead:
+ switch_occurred = True
+ try:
+ # wait for a coroutine from await_only and then return its
+ # result back to it.
+ value = await result
+ except BaseException:
+ # this allows an exception to be raised within
+ # the moderated greenlet so that it can continue
+ # its expected flow.
+ result = context.throw(*sys.exc_info())
+ else:
+ result = context.switch(value)
+ finally:
+ # clean up to avoid cycle resolution by gc
+ del context.driver
+ if _require_await and not switch_occurred:
+ raise exc.AwaitRequired(
+ "The current operation required an async execution but none was "
+ "detected. This will usually happen when using a non compatible "
+ "DBAPI driver. Please ensure that an async DBAPI is used."
+ )
+ return result
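+
+# Illustrative pattern ("sync_work" and "some_coroutine" are placeholder names):
+# a sync function spawned via greenlet_spawn() may call await_only() to wait
+# for coroutines without itself being async.
+#
+#     def sync_work():
+#         return await_only(some_coroutine())
+#
+#     result = await greenlet_spawn(sync_work)   # from async code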
+
+
+class AsyncAdaptedLock:
+ @memoized_property
+ def mutex(self):
+ # there should not be a race here for coroutines creating the
+ # new lock as we are not using await, so therefore no concurrency
+ return asyncio.Lock()
+
+ def __enter__(self):
+ # await is used to acquire the lock only after the first calling
+ # coroutine has created the mutex.
+ await_fallback(self.mutex.acquire())
+ return self
+
+ def __exit__(self, *arg, **kw):
+ self.mutex.release()
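+
+ # Illustrative usage from sync code (typically already running under
+ # greenlet_spawn()):
+ #
+ #     lock = AsyncAdaptedLock()
+ #     with lock:
+ #         ...  # critical section; the underlying asyncio.Lock is acquired/released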
+
+
+def _util_async_run_coroutine_function(fn, *args, **kwargs):
+ """for test suite/ util only"""
+
+ loop = get_event_loop()
+ if loop.is_running():
+ raise Exception(
+ "for async run coroutine we expect that no greenlet or event "
+ "loop is running when we start out"
+ )
+ return loop.run_until_complete(fn(*args, **kwargs))
+
+
+def _util_async_run(fn, *args, **kwargs):
+ """for test suite/ util only"""
+
+ loop = get_event_loop()
+ if not loop.is_running():
+ return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs))
+ else:
+ # allow for a wrapped test function to call another
+ assert isinstance(greenlet.getcurrent(), _AsyncIoGreenlet)
+ return fn(*args, **kwargs)
+
+
+def get_event_loop():
+ """vendor asyncio.get_event_loop() for python 3.7 and above.
+
+ Python 3.10 deprecates get_event_loop() as a standalone.
+
+ """
+ if compat.py37:
+ try:
+ return asyncio.get_running_loop()
+ except RuntimeError:
+ return asyncio.get_event_loop_policy().get_event_loop()
+ else:
+ return asyncio.get_event_loop()
diff --git a/lib/sqlalchemy/util/_preloaded.py b/lib/sqlalchemy/util/_preloaded.py
new file mode 100644
index 0000000..1803de4
--- /dev/null
+++ b/lib/sqlalchemy/util/_preloaded.py
@@ -0,0 +1,68 @@
+# util/_preloaded.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""supplies the "preloaded" registry to resolve circular module imports at
+runtime.
+
+"""
+
+import sys
+
+from . import compat
+
+
+class _ModuleRegistry:
+ """Registry of modules to load in a package init file.
+
+ To avoid potential thread safety issues for imports that are deferred
+ in a function, like https://bugs.python.org/issue38884, these modules
+ are added to the system module cache by importing them after the package
+ has finished initialization.
+
+ A global instance is provided under the name :attr:`.preloaded`. Use
+ the function :func:`.preload_module` to register modules to load and
+ :meth:`.import_prefix` to load all the modules that start with the
+ given path.
+
+ While the modules are loaded in the global module cache, it's advisable
+ to access them using :attr:`.preloaded` to ensure that it was actually
+ registered. Each registered module is added to the instance ``__dict__``
+ in the form `<package>_<module>`, omitting ``sqlalchemy`` from the package
+ name. Example: ``sqlalchemy.sql.util`` becomes ``preloaded.sql_util``.
+ """
+
+ def __init__(self, prefix="sqlalchemy."):
+ self.module_registry = set()
+ self.prefix = prefix
+
+ def preload_module(self, *deps):
+ """Adds the specified modules to the list to load.
+
+ This method can be used both as a normal function and as a decorator.
+ No change is performed to the decorated object.
+ """
+ self.module_registry.update(deps)
+ return lambda fn: fn
+
+ def import_prefix(self, path):
+ """Resolve all the modules in the registry that start with the
+ specified path.
+ """
+ for module in self.module_registry:
+ if self.prefix:
+ key = module.split(self.prefix)[-1].replace(".", "_")
+ else:
+ key = module
+ if (
+ not path or module.startswith(path)
+ ) and key not in self.__dict__:
+ compat.import_(module, globals(), locals())
+ self.__dict__[key] = sys.modules[module]
+
+
+preloaded = _ModuleRegistry()
+preload_module = preloaded.preload_module
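+
+# Illustrative usage ("sqlalchemy.sql.util" stands in for any registered module):
+#
+#     @preload_module("sqlalchemy.sql.util")
+#     def needs_sql_util():
+#         util = preloaded.sql_util
+#         ...
+#
+#     # during package initialization:
+#     preloaded.import_prefix("sqlalchemy")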
diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py
new file mode 100644
index 0000000..21a9491
--- /dev/null
+++ b/lib/sqlalchemy/util/compat.py
@@ -0,0 +1,632 @@
+# util/compat.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Handle Python version/platform incompatibilities."""
+
+import collections
+import contextlib
+import inspect
+import operator
+import platform
+import sys
+
+py311 = sys.version_info >= (3, 11)
+py39 = sys.version_info >= (3, 9)
+py38 = sys.version_info >= (3, 8)
+py37 = sys.version_info >= (3, 7)
+py3k = sys.version_info >= (3, 0)
+py2k = sys.version_info < (3, 0)
+pypy = platform.python_implementation() == "PyPy"
+
+
+cpython = platform.python_implementation() == "CPython"
+win32 = sys.platform.startswith("win")
+osx = sys.platform.startswith("darwin")
+arm = "aarch" in platform.machine().lower()
+
+has_refcount_gc = bool(cpython)
+
+contextmanager = contextlib.contextmanager
+dottedgetter = operator.attrgetter
+namedtuple = collections.namedtuple
+next = next # noqa
+
+FullArgSpec = collections.namedtuple(
+ "FullArgSpec",
+ [
+ "args",
+ "varargs",
+ "varkw",
+ "defaults",
+ "kwonlyargs",
+ "kwonlydefaults",
+ "annotations",
+ ],
+)
+
+
+class nullcontext(object):
+ """Context manager that does no additional processing.
+
+ Vendored from Python 3.7.
+
+ """
+
+ def __init__(self, enter_result=None):
+ self.enter_result = enter_result
+
+ def __enter__(self):
+ return self.enter_result
+
+ def __exit__(self, *excinfo):
+ pass
+
+
+try:
+ import threading
+except ImportError:
+ import dummy_threading as threading # noqa
+
+
+def inspect_getfullargspec(func):
+ """Fully vendored version of getfullargspec from Python 3.3."""
+
+ if inspect.ismethod(func):
+ func = func.__func__
+ if not inspect.isfunction(func):
+ raise TypeError("{!r} is not a Python function".format(func))
+
+ co = func.__code__
+ if not inspect.iscode(co):
+ raise TypeError("{!r} is not a code object".format(co))
+
+ nargs = co.co_argcount
+ names = co.co_varnames
+ nkwargs = co.co_kwonlyargcount if py3k else 0
+ args = list(names[:nargs])
+ kwonlyargs = list(names[nargs : nargs + nkwargs])
+
+ nargs += nkwargs
+ varargs = None
+ if co.co_flags & inspect.CO_VARARGS:
+ varargs = co.co_varnames[nargs]
+ nargs = nargs + 1
+ varkw = None
+ if co.co_flags & inspect.CO_VARKEYWORDS:
+ varkw = co.co_varnames[nargs]
+
+ return FullArgSpec(
+ args,
+ varargs,
+ varkw,
+ func.__defaults__,
+ kwonlyargs,
+ func.__kwdefaults__ if py3k else None,
+ func.__annotations__ if py3k else {},
+ )
+
+
+if py38:
+ from importlib import metadata as importlib_metadata
+else:
+ import importlib_metadata # noqa
+
+
+def importlib_metadata_get(group):
+ ep = importlib_metadata.entry_points()
+ if hasattr(ep, "select"):
+ return ep.select(group=group)
+ else:
+ return ep.get(group, ())
+
+
+if py3k:
+ import base64
+ import builtins
+ import configparser
+ import itertools
+ import pickle
+
+ from functools import reduce
+ from io import BytesIO as byte_buffer
+ from io import StringIO
+ from itertools import zip_longest
+ from time import perf_counter
+ from urllib.parse import (
+ quote_plus,
+ unquote_plus,
+ parse_qsl,
+ quote,
+ unquote,
+ )
+
+ string_types = (str,)
+ binary_types = (bytes,)
+ binary_type = bytes
+ text_type = str
+ int_types = (int,)
+ iterbytes = iter
+ long_type = int
+
+ itertools_filterfalse = itertools.filterfalse
+ itertools_filter = filter
+ itertools_imap = map
+
+ exec_ = getattr(builtins, "exec")
+ import_ = getattr(builtins, "__import__")
+ print_ = getattr(builtins, "print")
+
+ def b(s):
+ return s.encode("latin-1")
+
+ def b64decode(x):
+ return base64.b64decode(x.encode("ascii"))
+
+ def b64encode(x):
+ return base64.b64encode(x).decode("ascii")
+
+ def decode_backslashreplace(text, encoding):
+ return text.decode(encoding, errors="backslashreplace")
+
+ def cmp(a, b):
+ return (a > b) - (a < b)
+
+ def raise_(
+ exception, with_traceback=None, replace_context=None, from_=False
+ ):
+ r"""implement "raise" with cause support.
+
+ :param exception: exception to raise
+ :param with_traceback: will call exception.with_traceback()
+ :param replace_context: an as-yet-unsupported feature. This is
+ an exception object which we are "replacing", e.g., it's our
+ "cause" but we don't want it printed. Basically just what
+ ``__suppress_context__`` does but we don't want to suppress
+ the enclosing context, if any. So for now we make it the
+ cause.
+ :param from\_: the cause. this actually sets the cause and doesn't
+ hope to hide it someday.
+
+ """
+ if with_traceback is not None:
+ exception = exception.with_traceback(with_traceback)
+
+ if from_ is not False:
+ exception.__cause__ = from_
+ elif replace_context is not None:
+ # no good solution here, we would like to have the exception
+ # have only the context of replace_context.__context__ so that the
+ # intermediary exception does not change, but we can't figure
+ # that out.
+ exception.__cause__ = replace_context
+
+ try:
+ raise exception
+ finally:
+ # credit to
+ # https://cosmicpercolator.com/2016/01/13/exception-leaks-in-python-2-and-3/
+ # as the __traceback__ object creates a cycle
+ del exception, replace_context, from_, with_traceback
+
+ def u(s):
+ return s
+
+ def ue(s):
+ return s
+
+ from typing import TYPE_CHECKING
+
+ # Unused. Kept for backwards compatibility.
+ callable = callable # noqa
+
+ from abc import ABC
+
+ def _qualname(fn):
+ return fn.__qualname__
+
+
+else:
+ import base64
+ import ConfigParser as configparser # noqa
+ import itertools
+
+ from StringIO import StringIO # noqa
+ from cStringIO import StringIO as byte_buffer # noqa
+ from itertools import izip_longest as zip_longest # noqa
+ from time import clock as perf_counter # noqa
+ from urllib import quote # noqa
+ from urllib import quote_plus # noqa
+ from urllib import unquote # noqa
+ from urllib import unquote_plus # noqa
+ from urlparse import parse_qsl # noqa
+
+ from abc import ABCMeta
+
+ class ABC(object):
+ __metaclass__ = ABCMeta
+
+ try:
+ import cPickle as pickle
+ except ImportError:
+ import pickle # noqa
+
+ string_types = (basestring,) # noqa
+ binary_types = (bytes,)
+ binary_type = str
+ text_type = unicode # noqa
+ int_types = int, long # noqa
+ long_type = long # noqa
+
+ callable = callable # noqa
+ cmp = cmp # noqa
+ reduce = reduce # noqa
+
+ b64encode = base64.b64encode
+ b64decode = base64.b64decode
+
+ itertools_filterfalse = itertools.ifilterfalse
+ itertools_filter = itertools.ifilter
+ itertools_imap = itertools.imap
+
+ def b(s):
+ return s
+
+ def exec_(func_text, globals_, lcl=None):
+ if lcl is None:
+ exec("exec func_text in globals_")
+ else:
+ exec("exec func_text in globals_, lcl")
+
+ def iterbytes(buf):
+ return (ord(byte) for byte in buf)
+
+ def import_(*args):
+ if len(args) == 4:
+ args = args[0:3] + ([str(arg) for arg in args[3]],)
+ return __import__(*args)
+
+ def print_(*args, **kwargs):
+ fp = kwargs.pop("file", sys.stdout)
+ if fp is None:
+ return
+ for arg in enumerate(args):
+ if not isinstance(arg, basestring): # noqa
+ arg = str(arg)
+ fp.write(arg)
+
+ def u(s):
+ # this differs from what six does, which doesn't support non-ASCII
+ # strings - we only use u() with
+ # literal source strings, and all our source files with non-ascii
+ # in them (all are tests) are utf-8 encoded.
+ return unicode(s, "utf-8") # noqa
+
+ def ue(s):
+ return unicode(s, "unicode_escape") # noqa
+
+ def decode_backslashreplace(text, encoding):
+ try:
+ return text.decode(encoding)
+ except UnicodeDecodeError:
+ # regular "backslashreplace" for an incompatible encoding raises:
+ # "TypeError: don't know how to handle UnicodeDecodeError in
+ # error callback"
+ return repr(text)[1:-1].decode()
+
+ def safe_bytestring(text):
+ # py2k only
+ if not isinstance(text, string_types):
+ return unicode(text).encode( # noqa: F821
+ "ascii", errors="backslashreplace"
+ )
+ elif isinstance(text, unicode): # noqa: F821
+ return text.encode("ascii", errors="backslashreplace")
+ else:
+ return text
+
+ exec(
+ "def raise_(exception, with_traceback=None, replace_context=None, "
+ "from_=False):\n"
+ " if with_traceback:\n"
+ " raise type(exception), exception, with_traceback\n"
+ " else:\n"
+ " raise exception\n"
+ )
+
+ TYPE_CHECKING = False
+
+ def _qualname(meth):
+ """return __qualname__ equivalent for a method on a class"""
+
+ for cls in meth.im_class.__mro__:
+ if meth.__name__ in cls.__dict__:
+ break
+ else:
+ return meth.__name__
+
+ return "%s.%s" % (cls.__name__, meth.__name__)
+
+
+if py3k:
+
+ def _formatannotation(annotation, base_module=None):
+ """vendored from python 3.7"""
+
+ if getattr(annotation, "__module__", None) == "typing":
+ return repr(annotation).replace("typing.", "")
+ if isinstance(annotation, type):
+ if annotation.__module__ in ("builtins", base_module):
+ return annotation.__qualname__
+ return annotation.__module__ + "." + annotation.__qualname__
+ return repr(annotation)
+
+ def inspect_formatargspec(
+ args,
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=(),
+ kwonlydefaults={},
+ annotations={},
+ formatarg=str,
+ formatvarargs=lambda name: "*" + name,
+ formatvarkw=lambda name: "**" + name,
+ formatvalue=lambda value: "=" + repr(value),
+ formatreturns=lambda text: " -> " + text,
+ formatannotation=_formatannotation,
+ ):
+ """Copy formatargspec from python 3.7 standard library.
+
+ Python 3 has deprecated formatargspec and requested that Signature
+ be used instead, however this requires a full reimplementation
+ of formatargspec() in terms of creating Parameter objects and such.
+ Instead of introducing all the object-creation overhead and having
+ to reinvent from scratch, just copy their compatibility routine.
+
+ Ultimately we would need to rewrite our "decorator" routine completely
+ which is not really worth it right now, until all Python 2.x support
+ is dropped.
+
+ """
+
+ kwonlydefaults = kwonlydefaults or {}
+ annotations = annotations or {}
+
+ def formatargandannotation(arg):
+ result = formatarg(arg)
+ if arg in annotations:
+ result += ": " + formatannotation(annotations[arg])
+ return result
+
+ specs = []
+ if defaults:
+ firstdefault = len(args) - len(defaults)
+ for i, arg in enumerate(args):
+ spec = formatargandannotation(arg)
+ if defaults and i >= firstdefault:
+ spec = spec + formatvalue(defaults[i - firstdefault])
+ specs.append(spec)
+
+ if varargs is not None:
+ specs.append(formatvarargs(formatargandannotation(varargs)))
+ else:
+ if kwonlyargs:
+ specs.append("*")
+
+ if kwonlyargs:
+ for kwonlyarg in kwonlyargs:
+ spec = formatargandannotation(kwonlyarg)
+ if kwonlydefaults and kwonlyarg in kwonlydefaults:
+ spec += formatvalue(kwonlydefaults[kwonlyarg])
+ specs.append(spec)
+
+ if varkw is not None:
+ specs.append(formatvarkw(formatargandannotation(varkw)))
+
+ result = "(" + ", ".join(specs) + ")"
+ if "return" in annotations:
+ result += formatreturns(formatannotation(annotations["return"]))
+ return result
+
+
+else:
+ from inspect import formatargspec as _inspect_formatargspec
+
+ def inspect_formatargspec(*spec, **kw):
+ # convert for a potential FullArgSpec from compat.getfullargspec()
+ return _inspect_formatargspec(*spec[0:4], **kw) # noqa
+
+
+# Fix deprecation of accessing ABCs straight from collections module
+# (which will stop working in 3.8).
+if py3k:
+ import collections.abc as collections_abc
+else:
+ import collections as collections_abc # noqa
+
+
+if py37:
+ import dataclasses
+
+ def dataclass_fields(cls):
+ """Return a sequence of all dataclasses.Field objects associated
+ with a class."""
+
+ if dataclasses.is_dataclass(cls):
+ return dataclasses.fields(cls)
+ else:
+ return []
+
+ def local_dataclass_fields(cls):
+ """Return a sequence of all dataclasses.Field objects associated with
+ a class, excluding those that originate from a superclass."""
+
+ if dataclasses.is_dataclass(cls):
+ super_fields = set()
+ for sup in cls.__bases__:
+ super_fields.update(dataclass_fields(sup))
+ return [
+ f for f in dataclasses.fields(cls) if f not in super_fields
+ ]
+ else:
+ return []
+
+
+else:
+
+ def dataclass_fields(cls):
+ return []
+
+ def local_dataclass_fields(cls):
+ return []
+
+
+def raise_from_cause(exception, exc_info=None):
+ r"""legacy. use raise\_()"""
+
+ if exc_info is None:
+ exc_info = sys.exc_info()
+ exc_type, exc_value, exc_tb = exc_info
+ cause = exc_value if exc_value is not exception else None
+ reraise(type(exception), exception, tb=exc_tb, cause=cause)
+
+
+def reraise(tp, value, tb=None, cause=None):
+ r"""legacy. use raise\_()"""
+
+ raise_(value, with_traceback=tb, from_=cause)
+
+
+def with_metaclass(meta, *bases, **kw):
+ """Create a base class with a metaclass.
+
+ Drops the middle class upon creation.
+
+ Source: https://lucumr.pocoo.org/2013/5/21/porting-to-python-3-redux/
+
+ """
+
+ class metaclass(meta):
+ __call__ = type.__call__
+ __init__ = type.__init__
+
+ def __new__(cls, name, this_bases, d):
+ if this_bases is None:
+ cls = type.__new__(cls, name, (), d)
+ else:
+ cls = meta(name, bases, d)
+
+ if hasattr(cls, "__init_subclass__") and hasattr(
+ cls.__init_subclass__, "__func__"
+ ):
+ cls.__init_subclass__.__func__(cls, **kw)
+ return cls
+
+ return metaclass("temporary_class", None, {})
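+
+# Illustrative usage ("SomeMeta" is a placeholder metaclass):
+#
+#     class MyClass(with_metaclass(SomeMeta, object)):
+#         pass
+#
+# MyClass is created by SomeMeta directly; the temporary intermediate class
+# is discarded.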
+
+
+if py3k:
+ from datetime import timezone
+else:
+ from datetime import datetime
+ from datetime import timedelta
+ from datetime import tzinfo
+
+ class timezone(tzinfo):
+ """Minimal port of python 3 timezone object"""
+
+ __slots__ = "_offset"
+
+ def __init__(self, offset):
+ if not isinstance(offset, timedelta):
+ raise TypeError("offset must be a timedelta")
+ if not self._minoffset <= offset <= self._maxoffset:
+ raise ValueError(
+ "offset must be a timedelta "
+ "strictly between -timedelta(hours=24) and "
+ "timedelta(hours=24)."
+ )
+ self._offset = offset
+
+ def __eq__(self, other):
+ if type(other) != timezone:
+ return False
+ return self._offset == other._offset
+
+ def __hash__(self):
+ return hash(self._offset)
+
+ def __repr__(self):
+ return "sqlalchemy.util.%s(%r)" % (
+ self.__class__.__name__,
+ self._offset,
+ )
+
+ def __str__(self):
+ return self.tzname(None)
+
+ def utcoffset(self, dt):
+ return self._offset
+
+ def tzname(self, dt):
+ return self._name_from_offset(self._offset)
+
+ def dst(self, dt):
+ return None
+
+ def fromutc(self, dt):
+ if isinstance(dt, datetime):
+ if dt.tzinfo is not self:
+ raise ValueError("fromutc: dt.tzinfo " "is not self")
+ return dt + self._offset
+ raise TypeError(
+ "fromutc() argument must be a datetime instance" " or None"
+ )
+
+ @staticmethod
+ def _timedelta_to_microseconds(timedelta):
+ """backport of timedelta._to_microseconds()"""
+ return (
+ timedelta.days * (24 * 3600) + timedelta.seconds
+ ) * 1000000 + timedelta.microseconds
+
+ @staticmethod
+ def _divmod_timedeltas(a, b):
+ """backport of timedelta.__divmod__"""
+
+ q, r = divmod(
+ timezone._timedelta_to_microseconds(a),
+ timezone._timedelta_to_microseconds(b),
+ )
+ return q, timedelta(0, 0, r)
+
+ @staticmethod
+ def _name_from_offset(delta):
+ if not delta:
+ return "UTC"
+ if delta < timedelta(0):
+ sign = "-"
+ delta = -delta
+ else:
+ sign = "+"
+ hours, rest = timezone._divmod_timedeltas(
+ delta, timedelta(hours=1)
+ )
+ minutes, rest = timezone._divmod_timedeltas(
+ rest, timedelta(minutes=1)
+ )
+ result = "UTC%s%02d:%02d" % (sign, hours, minutes)
+ if rest.seconds:
+ result += ":%02d" % (rest.seconds,)
+ if rest.microseconds:
+ result += ".%06d" % (rest.microseconds,)
+ return result
+
+ _maxoffset = timedelta(hours=23, minutes=59)
+ _minoffset = -_maxoffset
+
+ timezone.utc = timezone(timedelta(0))
diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py
new file mode 100644
index 0000000..e900b43
--- /dev/null
+++ b/lib/sqlalchemy/util/concurrency.py
@@ -0,0 +1,73 @@
+# util/concurrency.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import compat
+
+have_greenlet = False
+greenlet_error = None
+
+if compat.py3k:
+ try:
+ import greenlet # noqa: F401
+ except ImportError as e:
+ greenlet_error = str(e)
+ else:
+ have_greenlet = True
+ from ._concurrency_py3k import await_only
+ from ._concurrency_py3k import await_fallback
+ from ._concurrency_py3k import greenlet_spawn
+ from ._concurrency_py3k import is_exit_exception
+ from ._concurrency_py3k import AsyncAdaptedLock
+ from ._concurrency_py3k import _util_async_run # noqa: F401
+ from ._concurrency_py3k import (
+ _util_async_run_coroutine_function,
+ ) # noqa: F401, E501
+ from ._concurrency_py3k import asyncio # noqa: F401
+
+ # does not need greenlet, just Python 3
+ from ._compat_py3k import asynccontextmanager # noqa: F401
+
+if not have_greenlet:
+
+ asyncio = None # noqa: F811
+
+ def _not_implemented():
+ # this conditional is to prevent pylance from considering
+ # greenlet_spawn() etc as "no return" and dimming out code below it
+ if have_greenlet:
+ return None
+
+ if not compat.py3k:
+ raise ValueError("Cannot use this function in py2.")
+ else:
+ raise ValueError(
+ "the greenlet library is required to use this function."
+ " %s" % greenlet_error
+ if greenlet_error
+ else ""
+ )
+
+ def is_exit_exception(e): # noqa: F811
+ return not isinstance(e, Exception)
+
+ def await_only(thing): # noqa: F811
+ _not_implemented()
+
+ def await_fallback(thing): # noqa: F811
+ return thing
+
+ def greenlet_spawn(fn, *args, **kw): # noqa: F811
+ _not_implemented()
+
+ def AsyncAdaptedLock(*args, **kw): # noqa: F811
+ _not_implemented()
+
+ def _util_async_run(fn, *arg, **kw): # noqa: F811
+ return fn(*arg, **kw)
+
+ def _util_async_run_coroutine_function(fn, *arg, **kw): # noqa: F811
+ _not_implemented()
diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py
new file mode 100644
index 0000000..b61516d
--- /dev/null
+++ b/lib/sqlalchemy/util/deprecations.py
@@ -0,0 +1,417 @@
+# util/deprecations.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Helpers related to deprecation of functions, methods, classes, other
+functionality."""
+
+import os
+import re
+
+from . import compat
+from .langhelpers import _hash_limit_string
+from .langhelpers import _warnings_warn
+from .langhelpers import decorator
+from .langhelpers import inject_docstring_text
+from .langhelpers import inject_param_text
+from .. import exc
+
+
+SQLALCHEMY_WARN_20 = False
+
+if os.getenv("SQLALCHEMY_WARN_20", "false").lower() in ("true", "yes", "1"):
+ SQLALCHEMY_WARN_20 = True
+
+
+def _warn_with_version(msg, version, type_, stacklevel, code=None):
+ if (
+ issubclass(type_, exc.Base20DeprecationWarning)
+ and not SQLALCHEMY_WARN_20
+ ):
+ return
+
+ warn = type_(msg, code=code)
+ warn.deprecated_since = version
+
+ _warnings_warn(warn, stacklevel=stacklevel + 1)
+
+
+def warn_deprecated(msg, version, stacklevel=3, code=None):
+ _warn_with_version(
+ msg, version, exc.SADeprecationWarning, stacklevel, code=code
+ )
+
+
+def warn_deprecated_limited(msg, args, version, stacklevel=3, code=None):
+ """Issue a deprecation warning with a parameterized string,
+ limiting the number of registrations.
+
+ """
+ if args:
+ msg = _hash_limit_string(msg, 10, args)
+ _warn_with_version(
+ msg, version, exc.SADeprecationWarning, stacklevel, code=code
+ )
+
+
+def warn_deprecated_20(msg, stacklevel=3, code=None):
+
+ _warn_with_version(
+ msg,
+ exc.RemovedIn20Warning.deprecated_since,
+ exc.RemovedIn20Warning,
+ stacklevel,
+ code=code,
+ )
+
+
+def deprecated_cls(version, message, constructor="__init__"):
+ header = ".. deprecated:: %s %s" % (version, (message or ""))
+
+ def decorate(cls):
+ return _decorate_cls_with_warning(
+ cls,
+ constructor,
+ exc.SADeprecationWarning,
+ message % dict(func=constructor),
+ version,
+ header,
+ )
+
+ return decorate
+
+
+def deprecated_20_cls(
+ clsname, alternative=None, constructor="__init__", becomes_legacy=False
+):
+ message = (
+ ".. deprecated:: 1.4 The %s class is considered legacy as of the "
+ "1.x series of SQLAlchemy and %s in 2.0."
+ % (
+ clsname,
+ "will be removed"
+ if not becomes_legacy
+ else "becomes a legacy construct",
+ )
+ )
+
+ if alternative:
+ message += " " + alternative
+
+ if becomes_legacy:
+ warning_cls = exc.LegacyAPIWarning
+ else:
+ warning_cls = exc.RemovedIn20Warning
+
+ def decorate(cls):
+ return _decorate_cls_with_warning(
+ cls,
+ constructor,
+ warning_cls,
+ message,
+ warning_cls.deprecated_since,
+ message,
+ )
+
+ return decorate
+
+
+def deprecated(
+ version,
+ message=None,
+ add_deprecation_to_docstring=True,
+ warning=None,
+ enable_warnings=True,
+):
+ """Decorates a function and issues a deprecation warning on use.
+
+ :param version:
+ Issue version in the warning.
+
+ :param message:
+ If provided, issue message in the warning. A sensible default
+ is used if not provided.
+
+ :param add_deprecation_to_docstring:
+ Default True. If False, the wrapped function's __doc__ is left
+ as-is. If True, the 'message' is prepended to the docs if
+ provided, or a sensible default if the message is omitted.
+
+ """
+
+ # nothing is deprecated "since" 2.0 at this time. All "removed in 2.0"
+ # should emit the RemovedIn20Warning, but messaging should be expressed
+ # in terms of "deprecated since 1.4".
+
+ if version == "2.0":
+ if warning is None:
+ warning = exc.RemovedIn20Warning
+ version = "1.4"
+ if add_deprecation_to_docstring:
+ header = ".. deprecated:: %s %s" % (
+ version,
+ (message or ""),
+ )
+ else:
+ header = None
+
+ if message is None:
+ message = "Call to deprecated function %(func)s"
+
+ if warning is None:
+ warning = exc.SADeprecationWarning
+
+ if warning is not exc.RemovedIn20Warning:
+ message += " (deprecated since: %s)" % version
+
+ def decorate(fn):
+ return _decorate_with_warning(
+ fn,
+ warning,
+ message % dict(func=fn.__name__),
+ version,
+ header,
+ enable_warnings=enable_warnings,
+ )
+
+ return decorate
+
+
+def moved_20(message, **kw):
+ return deprecated(
+ "2.0", message=message, warning=exc.MovedIn20Warning, **kw
+ )
+
+
+def deprecated_20(api_name, alternative=None, becomes_legacy=False, **kw):
+ type_reg = re.match("^:(attr|func|meth):", api_name)
+ if type_reg:
+ type_ = {"attr": "attribute", "func": "function", "meth": "method"}[
+ type_reg.group(1)
+ ]
+ else:
+ type_ = "construct"
+ message = (
+ "The %s %s is considered legacy as of the "
+ "1.x series of SQLAlchemy and %s in 2.0."
+ % (
+ api_name,
+ type_,
+ "will be removed"
+ if not becomes_legacy
+ else "becomes a legacy construct",
+ )
+ )
+
+ if ":attr:" in api_name:
+ attribute_ok = kw.pop("warn_on_attribute_access", False)
+ if not attribute_ok:
+ assert kw.get("enable_warnings") is False, (
+ "attribute %s will emit a warning on read access. "
+ "If you *really* want this, "
+ "add warn_on_attribute_access=True. Otherwise please add "
+ "enable_warnings=False." % api_name
+ )
+
+ if alternative:
+ message += " " + alternative
+
+ if becomes_legacy:
+ warning_cls = exc.LegacyAPIWarning
+ else:
+ warning_cls = exc.RemovedIn20Warning
+
+ return deprecated("2.0", message=message, warning=warning_cls, **kw)
+
+
+def deprecated_params(**specs):
+ """Decorates a function to warn on use of certain parameters.
+
+ e.g. ::
+
+ @deprecated_params(
+ weak_identity_map=(
+ "0.7",
+                "the :paramref:`.Session.weak_identity_map` parameter "
+                "is deprecated."
+            )
+        )
+
+ """
+
+ messages = {}
+ versions = {}
+ version_warnings = {}
+
+ for param, (version, message) in specs.items():
+ versions[param] = version
+ messages[param] = _sanitize_restructured_text(message)
+ version_warnings[param] = (
+ exc.RemovedIn20Warning
+ if version == "2.0"
+ else exc.SADeprecationWarning
+ )
+
+ def decorate(fn):
+ spec = compat.inspect_getfullargspec(fn)
+
+ if spec.defaults is not None:
+ defaults = dict(
+ zip(
+ spec.args[(len(spec.args) - len(spec.defaults)) :],
+ spec.defaults,
+ )
+ )
+ check_defaults = set(defaults).intersection(messages)
+ check_kw = set(messages).difference(defaults)
+ else:
+ check_defaults = ()
+ check_kw = set(messages)
+
+ check_any_kw = spec.varkw
+
+ @decorator
+ def warned(fn, *args, **kwargs):
+ for m in check_defaults:
+ if (defaults[m] is None and kwargs[m] is not None) or (
+ defaults[m] is not None and kwargs[m] != defaults[m]
+ ):
+ _warn_with_version(
+ messages[m],
+ versions[m],
+ version_warnings[m],
+ stacklevel=3,
+ )
+
+ if check_any_kw in messages and set(kwargs).difference(
+ check_defaults
+ ):
+
+ _warn_with_version(
+ messages[check_any_kw],
+ versions[check_any_kw],
+ version_warnings[check_any_kw],
+ stacklevel=3,
+ )
+
+ for m in check_kw:
+ if m in kwargs:
+ _warn_with_version(
+ messages[m],
+ versions[m],
+ version_warnings[m],
+ stacklevel=3,
+ )
+ return fn(*args, **kwargs)
+
+ doc = fn.__doc__ is not None and fn.__doc__ or ""
+ if doc:
+ doc = inject_param_text(
+ doc,
+ {
+ param: ".. deprecated:: %s %s"
+ % ("1.4" if version == "2.0" else version, (message or ""))
+ for param, (version, message) in specs.items()
+ },
+ )
+ decorated = warned(fn)
+ decorated.__doc__ = doc
+ return decorated
+
+ return decorate
+
+
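+# A minimal usage sketch (``configure`` and ``legacy_flag`` are hypothetical
+# names): each keyword maps a parameter to a (version, message) tuple, and
+# the warning is emitted only when that parameter is passed with a
+# non-default value.
+#
+#     >>> from sqlalchemy.util.deprecations import deprecated_params
+#     >>> @deprecated_params(
+#     ...     legacy_flag=("1.3", "The legacy_flag parameter is deprecated.")
+#     ... )
+#     ... def configure(option=None, legacy_flag=None):
+#     ...     return option
+#     >>> configure(option="x")  # no warning
+#     'x'
+#     >>> _ = configure(legacy_flag=True)  # emits SADeprecationWarning
+
+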
+def _sanitize_restructured_text(text):
+ def repl(m):
+ type_, name = m.group(1, 2)
+ if type_ in ("func", "meth"):
+ name += "()"
+ return name
+
+ text = re.sub(r":ref:`(.+) <.*>`", lambda m: '"%s"' % m.group(1), text)
+ return re.sub(r"\:(\w+)\:`~?(?:_\w+)?\.?(.+?)`", repl, text)
+
+
+def _decorate_cls_with_warning(
+ cls, constructor, wtype, message, version, docstring_header=None
+):
+ doc = cls.__doc__ is not None and cls.__doc__ or ""
+ if docstring_header is not None:
+
+ if constructor is not None:
+ docstring_header %= dict(func=constructor)
+
+ if issubclass(wtype, exc.Base20DeprecationWarning):
+ docstring_header += (
+ " (Background on SQLAlchemy 2.0 at: "
+ ":ref:`migration_20_toplevel`)"
+ )
+ doc = inject_docstring_text(doc, docstring_header, 1)
+
+ if type(cls) is type:
+ clsdict = dict(cls.__dict__)
+ clsdict["__doc__"] = doc
+ clsdict.pop("__dict__", None)
+ clsdict.pop("__weakref__", None)
+ cls = type(cls.__name__, cls.__bases__, clsdict)
+ if constructor is not None:
+ constructor_fn = clsdict[constructor]
+
+ else:
+ cls.__doc__ = doc
+ if constructor is not None:
+ constructor_fn = getattr(cls, constructor)
+
+ if constructor is not None:
+ setattr(
+ cls,
+ constructor,
+ _decorate_with_warning(
+ constructor_fn, wtype, message, version, None
+ ),
+ )
+ return cls
+
+
+def _decorate_with_warning(
+ func, wtype, message, version, docstring_header=None, enable_warnings=True
+):
+ """Wrap a function with a warnings.warn and augmented docstring."""
+
+ message = _sanitize_restructured_text(message)
+
+ if issubclass(wtype, exc.Base20DeprecationWarning):
+ doc_only = (
+ " (Background on SQLAlchemy 2.0 at: "
+ ":ref:`migration_20_toplevel`)"
+ )
+ else:
+ doc_only = ""
+
+ @decorator
+ def warned(fn, *args, **kwargs):
+ skip_warning = not enable_warnings or kwargs.pop(
+ "_sa_skip_warning", False
+ )
+ if not skip_warning:
+ _warn_with_version(message, version, wtype, stacklevel=3)
+ return fn(*args, **kwargs)
+
+ doc = func.__doc__ is not None and func.__doc__ or ""
+ if docstring_header is not None:
+ docstring_header %= dict(func=func.__name__)
+
+ docstring_header += doc_only
+
+ doc = inject_docstring_text(doc, docstring_header, 1)
+
+ decorated = warned(func)
+ decorated.__doc__ = doc
+ decorated._sa_warn = lambda: _warn_with_version(
+ message, version, wtype, stacklevel=3
+ )
+ return decorated
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
new file mode 100644
index 0000000..c3636f0
--- /dev/null
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -0,0 +1,1945 @@
+# util/langhelpers.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Routines to help with the creation, loading and introspection of
+modules, classes, hierarchies, attributes, functions, and methods.
+
+"""
+
+import collections
+from functools import update_wrapper
+import hashlib
+import inspect
+import itertools
+import operator
+import re
+import sys
+import textwrap
+import types
+import warnings
+
+from . import _collections
+from . import compat
+from .. import exc
+
+
+def md5_hex(x):
+ if compat.py3k:
+ x = x.encode("utf-8")
+ m = hashlib.md5()
+ m.update(x)
+ return m.hexdigest()
+
+
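+# A minimal usage sketch: md5_hex() hashes the UTF-8 encoding of a string
+# and returns the hexadecimal digest (the import path assumes this module is
+# importable as sqlalchemy.util.langhelpers).
+#
+#     >>> from sqlalchemy.util.langhelpers import md5_hex
+#     >>> md5_hex("hello")
+#     '5d41402abc4b2a76b9719d911017c592'
+
+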
+class safe_reraise(object):
+ """Reraise an exception after invoking some
+ handler code.
+
+ Stores the existing exception info before
+ invoking so that it is maintained across a potential
+ coroutine context switch.
+
+ e.g.::
+
+ try:
+ sess.commit()
+ except:
+ with safe_reraise():
+ sess.rollback()
+
+ """
+
+ __slots__ = ("warn_only", "_exc_info")
+
+ def __init__(self, warn_only=False):
+ self.warn_only = warn_only
+
+ def __enter__(self):
+ self._exc_info = sys.exc_info()
+
+ def __exit__(self, type_, value, traceback):
+ # see #2703 for notes
+ if type_ is None:
+ exc_type, exc_value, exc_tb = self._exc_info
+ self._exc_info = None # remove potential circular references
+ if not self.warn_only:
+ compat.raise_(
+ exc_value,
+ with_traceback=exc_tb,
+ )
+ else:
+ if not compat.py3k and self._exc_info and self._exc_info[1]:
+ # emulate Py3K's behavior of telling us when an exception
+ # occurs in an exception handler.
+ warn(
+ "An exception has occurred during handling of a "
+ "previous exception. The previous exception "
+ "is:\n %s %s\n" % (self._exc_info[0], self._exc_info[1])
+ )
+ self._exc_info = None # remove potential circular references
+ compat.raise_(value, with_traceback=traceback)
+
+
+def walk_subclasses(cls):
+ seen = set()
+
+ stack = [cls]
+ while stack:
+ cls = stack.pop()
+ if cls in seen:
+ continue
+ else:
+ seen.add(cls)
+ stack.extend(cls.__subclasses__())
+ yield cls
+
+
+def string_or_unprintable(element):
+ if isinstance(element, compat.string_types):
+ return element
+ else:
+ try:
+ return str(element)
+ except Exception:
+ return "unprintable element %r" % element
+
+
+def clsname_as_plain_name(cls):
+ return " ".join(
+ n.lower() for n in re.findall(r"([A-Z][a-z]+)", cls.__name__)
+ )
+
+
+def method_is_overridden(instance_or_cls, against_method):
+ """Return True if the two class methods don't match."""
+
+ if not isinstance(instance_or_cls, type):
+ current_cls = instance_or_cls.__class__
+ else:
+ current_cls = instance_or_cls
+
+ method_name = against_method.__name__
+
+ current_method = getattr(current_cls, method_name)
+
+ return current_method != against_method
+
+
+def decode_slice(slc):
+ """decode a slice object as sent to __getitem__.
+
+ takes into account the 2.5 __index__() method, basically.
+
+ """
+ ret = []
+ for x in slc.start, slc.stop, slc.step:
+ if hasattr(x, "__index__"):
+ x = x.__index__()
+ ret.append(x)
+ return tuple(ret)
+
+
+def _unique_symbols(used, *bases):
+ used = set(used)
+ for base in bases:
+ pool = itertools.chain(
+ (base,),
+ compat.itertools_imap(lambda i: base + str(i), range(1000)),
+ )
+ for sym in pool:
+ if sym not in used:
+ used.add(sym)
+ yield sym
+ break
+ else:
+ raise NameError("exhausted namespace for symbol base %s" % base)
+
+
+def map_bits(fn, n):
+ """Call the given function given each nonzero bit from n."""
+
+ while n:
+ b = n & (~n + 1)
+ yield fn(b)
+ n ^= b
+
+
+def decorator(target):
+ """A signature-matching decorator factory."""
+
+ def decorate(fn):
+ if not inspect.isfunction(fn) and not inspect.ismethod(fn):
+ raise Exception("not a decoratable function")
+
+ spec = compat.inspect_getfullargspec(fn)
+ env = {}
+
+ spec = _update_argspec_defaults_into_env(spec, env)
+
+ names = tuple(spec[0]) + spec[1:3] + (fn.__name__,)
+ targ_name, fn_name = _unique_symbols(names, "target", "fn")
+
+ metadata = dict(target=targ_name, fn=fn_name)
+ metadata.update(format_argspec_plus(spec, grouped=False))
+ metadata["name"] = fn.__name__
+ code = (
+ """\
+def %(name)s(%(args)s):
+ return %(target)s(%(fn)s, %(apply_kw)s)
+"""
+ % metadata
+ )
+ env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__})
+
+ decorated = _exec_code_in_env(code, env, fn.__name__)
+ decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
+ decorated.__wrapped__ = fn
+ return update_wrapper(decorated, fn)
+
+ return update_wrapper(decorate, target)
+
+
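+# A minimal usage sketch (``traced`` and ``add`` are hypothetical): unlike a
+# plain *args/**kwargs wrapper, the generated wrapper reproduces the
+# decorated function's signature, so introspection keeps working.
+#
+#     >>> from sqlalchemy.util.langhelpers import decorator
+#     >>> @decorator
+#     ... def traced(fn, *args, **kw):
+#     ...     print("calling %s" % fn.__name__)
+#     ...     return fn(*args, **kw)
+#     >>> @traced
+#     ... def add(x, y=1):
+#     ...     return x + y
+#     >>> add(2)
+#     calling add
+#     3
+#     >>> import inspect
+#     >>> str(inspect.signature(add))
+#     '(x, y=1)'
+
+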
+def _update_argspec_defaults_into_env(spec, env):
+ """given a FullArgSpec, convert defaults to be symbol names in an env."""
+
+ if spec.defaults:
+ new_defaults = []
+ i = 0
+ for arg in spec.defaults:
+ if type(arg).__module__ not in ("builtins", "__builtin__"):
+ name = "x%d" % i
+ env[name] = arg
+ new_defaults.append(name)
+ i += 1
+ else:
+ new_defaults.append(arg)
+ elem = list(spec)
+ elem[3] = tuple(new_defaults)
+ return compat.FullArgSpec(*elem)
+ else:
+ return spec
+
+
+def _exec_code_in_env(code, env, fn_name):
+ exec(code, env)
+ return env[fn_name]
+
+
+def public_factory(target, location, class_location=None):
+ """Produce a wrapping function for the given cls or classmethod.
+
+ Rationale here is so that the __init__ method of the
+ class can serve as documentation for the function.
+
+ """
+
+ if isinstance(target, type):
+ fn = target.__init__
+ callable_ = target
+ doc = (
+ "Construct a new :class:`%s` object. \n\n"
+ "This constructor is mirrored as a public API function; "
+ "see :func:`sqlalchemy%s` "
+ "for a full usage and argument description."
+ % (
+ class_location if class_location else ".%s" % target.__name__,
+ location,
+ )
+ )
+ else:
+ fn = callable_ = target
+ doc = (
+ "This function is mirrored; see :func:`sqlalchemy%s` "
+ "for a description of arguments." % location
+ )
+
+ location_name = location.split(".")[-1]
+ spec = compat.inspect_getfullargspec(fn)
+ del spec[0][0]
+ metadata = format_argspec_plus(spec, grouped=False)
+ metadata["name"] = location_name
+ code = (
+ """\
+def %(name)s(%(args)s):
+ return cls(%(apply_kw)s)
+"""
+ % metadata
+ )
+ env = {
+ "cls": callable_,
+ "symbol": symbol,
+ "__name__": callable_.__module__,
+ }
+ exec(code, env)
+ decorated = env[location_name]
+
+ if hasattr(fn, "_linked_to"):
+ linked_to, linked_to_location = fn._linked_to
+ linked_to_doc = linked_to.__doc__
+ if class_location is None:
+ class_location = "%s.%s" % (target.__module__, target.__name__)
+
+ linked_to_doc = inject_docstring_text(
+ linked_to_doc,
+ ".. container:: inherited_member\n\n "
+ "This documentation is inherited from :func:`sqlalchemy%s`; "
+ "this constructor, :func:`sqlalchemy%s`, "
+ "creates a :class:`sqlalchemy%s` object. See that class for "
+ "additional details describing this subclass."
+ % (linked_to_location, location, class_location),
+ 1,
+ )
+ decorated.__doc__ = linked_to_doc
+ else:
+ decorated.__doc__ = fn.__doc__
+
+ decorated.__module__ = "sqlalchemy" + location.rsplit(".", 1)[0]
+ if decorated.__module__ not in sys.modules:
+ raise ImportError(
+ "public_factory location %s is not in sys.modules"
+ % (decorated.__module__,)
+ )
+
+ if compat.py2k or hasattr(fn, "__func__"):
+ fn.__func__.__doc__ = doc
+ if not hasattr(fn.__func__, "_linked_to"):
+ fn.__func__._linked_to = (decorated, location)
+ else:
+ fn.__doc__ = doc
+ if not hasattr(fn, "_linked_to"):
+ fn._linked_to = (decorated, location)
+
+ return decorated
+
+
+class PluginLoader(object):
+ def __init__(self, group, auto_fn=None):
+ self.group = group
+ self.impls = {}
+ self.auto_fn = auto_fn
+
+ def clear(self):
+ self.impls.clear()
+
+ def load(self, name):
+ if name in self.impls:
+ return self.impls[name]()
+
+ if self.auto_fn:
+ loader = self.auto_fn(name)
+ if loader:
+ self.impls[name] = loader
+ return loader()
+
+ for impl in compat.importlib_metadata_get(self.group):
+ if impl.name == name:
+ self.impls[name] = impl.load
+ return impl.load()
+
+ raise exc.NoSuchModuleError(
+ "Can't load plugin: %s:%s" % (self.group, name)
+ )
+
+ def register(self, name, modulepath, objname):
+ def load():
+ mod = compat.import_(modulepath)
+ for token in modulepath.split(".")[1:]:
+ mod = getattr(mod, token)
+ return getattr(mod, objname)
+
+ self.impls[name] = load
+
+
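+# A minimal usage sketch (the group name "mylib.plugins" is hypothetical):
+# names resolve through register(), an optional auto_fn hook, or setuptools
+# entry points in the given group.
+#
+#     >>> from sqlalchemy.util.langhelpers import PluginLoader
+#     >>> loader = PluginLoader("mylib.plugins")
+#     >>> loader.register("ordered", "collections", "OrderedDict")
+#     >>> loader.load("ordered")
+#     <class 'collections.OrderedDict'>
+
+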
+def _inspect_func_args(fn):
+ try:
+ co_varkeywords = inspect.CO_VARKEYWORDS
+ except AttributeError:
+ # https://docs.python.org/3/library/inspect.html
+ # The flags are specific to CPython, and may not be defined in other
+ # Python implementations. Furthermore, the flags are an implementation
+ # detail, and can be removed or deprecated in future Python releases.
+ spec = compat.inspect_getfullargspec(fn)
+ return spec[0], bool(spec[2])
+ else:
+ # use fn.__code__ plus flags to reduce method call overhead
+ co = fn.__code__
+ nargs = co.co_argcount
+ return (
+ list(co.co_varnames[:nargs]),
+ bool(co.co_flags & co_varkeywords),
+ )
+
+
+def get_cls_kwargs(cls, _set=None):
+ r"""Return the full set of inherited kwargs for the given `cls`.
+
+ Probes a class's __init__ method, collecting all named arguments. If the
+ __init__ defines a \**kwargs catch-all, then the constructor is presumed
+ to pass along unrecognized keywords to its base classes, and the
+ collection process is repeated recursively on each of the bases.
+
+ Uses a subset of inspect.getfullargspec() to cut down on method overhead,
+ as this is used within the Core typing system to create copies of type
+ objects which is a performance-sensitive operation.
+
+    No anonymous tuple arguments, please!
+
+ """
+ toplevel = _set is None
+ if toplevel:
+ _set = set()
+
+ ctr = cls.__dict__.get("__init__", False)
+
+ has_init = (
+ ctr
+ and isinstance(ctr, types.FunctionType)
+ and isinstance(ctr.__code__, types.CodeType)
+ )
+
+ if has_init:
+ names, has_kw = _inspect_func_args(ctr)
+ _set.update(names)
+
+ if not has_kw and not toplevel:
+ return None
+
+ if not has_init or has_kw:
+ for c in cls.__bases__:
+ if get_cls_kwargs(c, _set) is None:
+ break
+
+ _set.discard("self")
+ return _set
+
+
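+# A minimal usage sketch (``Base`` and ``Widget`` are hypothetical classes):
+# because Widget.__init__ declares **kw, keyword names are also collected
+# from the base class.
+#
+#     >>> from sqlalchemy.util.langhelpers import get_cls_kwargs
+#     >>> class Base(object):
+#     ...     def __init__(self, color=None):
+#     ...         self.color = color
+#     >>> class Widget(Base):
+#     ...     def __init__(self, size=None, **kw):
+#     ...         super(Widget, self).__init__(**kw)
+#     ...         self.size = size
+#     >>> sorted(get_cls_kwargs(Widget))
+#     ['color', 'size']
+
+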
+def get_func_kwargs(func):
+ """Return the set of legal kwargs for the given `func`.
+
+    Uses getfullargspec so it is safe to call for methods, functions,
+ etc.
+
+ """
+
+ return compat.inspect_getfullargspec(func)[0]
+
+
+def get_callable_argspec(fn, no_self=False, _is_init=False):
+ """Return the argument signature for any callable.
+
+ All pure-Python callables are accepted, including
+ functions, methods, classes, objects with __call__;
+ builtins and other edge cases like functools.partial() objects
+ raise a TypeError.
+
+ """
+ if inspect.isbuiltin(fn):
+ raise TypeError("Can't inspect builtin: %s" % fn)
+ elif inspect.isfunction(fn):
+ if _is_init and no_self:
+ spec = compat.inspect_getfullargspec(fn)
+ return compat.FullArgSpec(
+ spec.args[1:],
+ spec.varargs,
+ spec.varkw,
+ spec.defaults,
+ spec.kwonlyargs,
+ spec.kwonlydefaults,
+ spec.annotations,
+ )
+ else:
+ return compat.inspect_getfullargspec(fn)
+ elif inspect.ismethod(fn):
+ if no_self and (_is_init or fn.__self__):
+ spec = compat.inspect_getfullargspec(fn.__func__)
+ return compat.FullArgSpec(
+ spec.args[1:],
+ spec.varargs,
+ spec.varkw,
+ spec.defaults,
+ spec.kwonlyargs,
+ spec.kwonlydefaults,
+ spec.annotations,
+ )
+ else:
+ return compat.inspect_getfullargspec(fn.__func__)
+ elif inspect.isclass(fn):
+ return get_callable_argspec(
+ fn.__init__, no_self=no_self, _is_init=True
+ )
+ elif hasattr(fn, "__func__"):
+ return compat.inspect_getfullargspec(fn.__func__)
+ elif hasattr(fn, "__call__"):
+ if inspect.ismethod(fn.__call__):
+ return get_callable_argspec(fn.__call__, no_self=no_self)
+ else:
+ raise TypeError("Can't inspect callable: %s" % fn)
+ else:
+ raise TypeError("Can't inspect callable: %s" % fn)
+
+
+def format_argspec_plus(fn, grouped=True):
+ """Returns a dictionary of formatted, introspected function arguments.
+
+    An enhanced variant of inspect.formatargspec to support code generation.
+
+ fn
+ An inspectable callable or tuple of inspect getargspec() results.
+ grouped
+ Defaults to True; include (parens, around, argument) lists
+
+ Returns:
+
+ args
+ Full inspect.formatargspec for fn
+ self_arg
+ The name of the first positional argument, varargs[0], or None
+ if the function defines no positional arguments.
+ apply_pos
+ args, re-written in calling rather than receiving syntax. Arguments are
+ passed positionally.
+ apply_kw
+ Like apply_pos, except keyword-ish args are passed as keywords.
+ apply_pos_proxied
+ Like apply_pos but omits the self/cls argument
+
+ Example::
+
+ >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123)
+ {'args': '(self, a, b, c=3, **d)',
+ 'self_arg': 'self',
+ 'apply_kw': '(self, a, b, c=c, **d)',
+ 'apply_pos': '(self, a, b, c, **d)'}
+
+ """
+ if compat.callable(fn):
+ spec = compat.inspect_getfullargspec(fn)
+ else:
+ spec = fn
+
+ args = compat.inspect_formatargspec(*spec)
+
+ apply_pos = compat.inspect_formatargspec(
+ spec[0], spec[1], spec[2], None, spec[4]
+ )
+
+ if spec[0]:
+ self_arg = spec[0][0]
+
+ apply_pos_proxied = compat.inspect_formatargspec(
+ spec[0][1:], spec[1], spec[2], None, spec[4]
+ )
+
+ elif spec[1]:
+ # I'm not sure what this is
+ self_arg = "%s[0]" % spec[1]
+
+ apply_pos_proxied = apply_pos
+ else:
+ self_arg = None
+ apply_pos_proxied = apply_pos
+
+ num_defaults = 0
+ if spec[3]:
+ num_defaults += len(spec[3])
+ if spec[4]:
+ num_defaults += len(spec[4])
+ name_args = spec[0] + spec[4]
+
+ if num_defaults:
+ defaulted_vals = name_args[0 - num_defaults :]
+ else:
+ defaulted_vals = ()
+
+ apply_kw = compat.inspect_formatargspec(
+ name_args,
+ spec[1],
+ spec[2],
+ defaulted_vals,
+ formatvalue=lambda x: "=" + x,
+ )
+
+ if spec[0]:
+ apply_kw_proxied = compat.inspect_formatargspec(
+ name_args[1:],
+ spec[1],
+ spec[2],
+ defaulted_vals,
+ formatvalue=lambda x: "=" + x,
+ )
+ else:
+ apply_kw_proxied = apply_kw
+
+ if grouped:
+ return dict(
+ args=args,
+ self_arg=self_arg,
+ apply_pos=apply_pos,
+ apply_kw=apply_kw,
+ apply_pos_proxied=apply_pos_proxied,
+ apply_kw_proxied=apply_kw_proxied,
+ )
+ else:
+ return dict(
+ args=args[1:-1],
+ self_arg=self_arg,
+ apply_pos=apply_pos[1:-1],
+ apply_kw=apply_kw[1:-1],
+ apply_pos_proxied=apply_pos_proxied[1:-1],
+ apply_kw_proxied=apply_kw_proxied[1:-1],
+ )
+
+
+def format_argspec_init(method, grouped=True):
+ """format_argspec_plus with considerations for typical __init__ methods
+
+ Wraps format_argspec_plus with error handling strategies for typical
+ __init__ cases::
+
+ object.__init__ -> (self)
+ other unreflectable (usually C) -> (self, *args, **kwargs)
+
+ """
+ if method is object.__init__:
+ args = "(self)" if grouped else "self"
+ proxied = "()" if grouped else ""
+ else:
+ try:
+ return format_argspec_plus(method, grouped=grouped)
+ except TypeError:
+ args = (
+ "(self, *args, **kwargs)"
+ if grouped
+ else "self, *args, **kwargs"
+ )
+ proxied = "(*args, **kwargs)" if grouped else "*args, **kwargs"
+ return dict(
+ self_arg="self",
+ args=args,
+ apply_pos=args,
+ apply_kw=args,
+ apply_pos_proxied=proxied,
+ apply_kw_proxied=proxied,
+ )
+
+
+def create_proxy_methods(
+ target_cls,
+ target_cls_sphinx_name,
+ proxy_cls_sphinx_name,
+ classmethods=(),
+ methods=(),
+ attributes=(),
+):
+ """A class decorator that will copy attributes to a proxy class.
+
+ The class to be instrumented must define a single accessor "_proxied".
+
+ """
+
+ def decorate(cls):
+ def instrument(name, clslevel=False):
+ fn = getattr(target_cls, name)
+ spec = compat.inspect_getfullargspec(fn)
+ env = {"__name__": fn.__module__}
+
+ spec = _update_argspec_defaults_into_env(spec, env)
+ caller_argspec = format_argspec_plus(spec, grouped=False)
+
+ metadata = {
+ "name": fn.__name__,
+ "apply_pos_proxied": caller_argspec["apply_pos_proxied"],
+ "apply_kw_proxied": caller_argspec["apply_kw_proxied"],
+ "args": caller_argspec["args"],
+ "self_arg": caller_argspec["self_arg"],
+ }
+
+ if clslevel:
+ code = (
+ "def %(name)s(%(args)s):\n"
+ " return target_cls.%(name)s(%(apply_kw_proxied)s)"
+ % metadata
+ )
+ env["target_cls"] = target_cls
+ else:
+ code = (
+ "def %(name)s(%(args)s):\n"
+ " return %(self_arg)s._proxied.%(name)s(%(apply_kw_proxied)s)" # noqa: E501
+ % metadata
+ )
+
+ proxy_fn = _exec_code_in_env(code, env, fn.__name__)
+ proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__
+ proxy_fn.__doc__ = inject_docstring_text(
+ fn.__doc__,
+ ".. container:: class_bases\n\n "
+ "Proxied for the %s class on behalf of the %s class."
+ % (target_cls_sphinx_name, proxy_cls_sphinx_name),
+ 1,
+ )
+
+ if clslevel:
+ proxy_fn = classmethod(proxy_fn)
+
+ return proxy_fn
+
+ def makeprop(name):
+ attr = target_cls.__dict__.get(name, None)
+
+ if attr is not None:
+ doc = inject_docstring_text(
+ attr.__doc__,
+ ".. container:: class_bases\n\n "
+ "Proxied for the %s class on behalf of the %s class."
+ % (
+ target_cls_sphinx_name,
+ proxy_cls_sphinx_name,
+ ),
+ 1,
+ )
+ else:
+ doc = None
+
+ code = (
+ "def set_(self, attr):\n"
+ " self._proxied.%(name)s = attr\n"
+ "def get(self):\n"
+ " return self._proxied.%(name)s\n"
+ "get.__doc__ = doc\n"
+ "getset = property(get, set_)"
+ ) % {"name": name}
+
+ getset = _exec_code_in_env(code, {"doc": doc}, "getset")
+
+ return getset
+
+ for meth in methods:
+ if hasattr(cls, meth):
+ raise TypeError(
+ "class %s already has a method %s" % (cls, meth)
+ )
+ setattr(cls, meth, instrument(meth))
+
+ for prop in attributes:
+ if hasattr(cls, prop):
+ raise TypeError(
+ "class %s already has a method %s" % (cls, prop)
+ )
+ setattr(cls, prop, makeprop(prop))
+
+ for prop in classmethods:
+ if hasattr(cls, prop):
+ raise TypeError(
+ "class %s already has a method %s" % (cls, prop)
+ )
+ setattr(cls, prop, instrument(prop, clslevel=True))
+
+ return cls
+
+ return decorate
+
+
+def getargspec_init(method):
+ """inspect.getargspec with considerations for typical __init__ methods
+
+ Wraps inspect.getargspec with error handling for typical __init__ cases::
+
+ object.__init__ -> (self)
+ other unreflectable (usually C) -> (self, *args, **kwargs)
+
+ """
+ try:
+ return compat.inspect_getfullargspec(method)
+ except TypeError:
+ if method is object.__init__:
+ return (["self"], None, None, None)
+ else:
+ return (["self"], "args", "kwargs", None)
+
+
+def unbound_method_to_callable(func_or_cls):
+ """Adjust the incoming callable such that a 'self' argument is not
+ required.
+
+ """
+
+ if isinstance(func_or_cls, types.MethodType) and not func_or_cls.__self__:
+ return func_or_cls.__func__
+ else:
+ return func_or_cls
+
+
+def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
+ """Produce a __repr__() based on direct association of the __init__()
+ specification vs. same-named attributes present.
+
+ """
+ if to_inspect is None:
+ to_inspect = [obj]
+ else:
+ to_inspect = _collections.to_list(to_inspect)
+
+ missing = object()
+
+ pos_args = []
+ kw_args = _collections.OrderedDict()
+ vargs = None
+ for i, insp in enumerate(to_inspect):
+ try:
+ spec = compat.inspect_getfullargspec(insp.__init__)
+ except TypeError:
+ continue
+ else:
+ default_len = spec.defaults and len(spec.defaults) or 0
+ if i == 0:
+ if spec.varargs:
+ vargs = spec.varargs
+ if default_len:
+ pos_args.extend(spec.args[1:-default_len])
+ else:
+ pos_args.extend(spec.args[1:])
+ else:
+ kw_args.update(
+ [(arg, missing) for arg in spec.args[1:-default_len]]
+ )
+
+ if default_len:
+ kw_args.update(
+ [
+ (arg, default)
+ for arg, default in zip(
+ spec.args[-default_len:], spec.defaults
+ )
+ ]
+ )
+ output = []
+
+ output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)
+
+ if vargs is not None and hasattr(obj, vargs):
+ output.extend([repr(val) for val in getattr(obj, vargs)])
+
+ for arg, defval in kw_args.items():
+ if arg in omit_kwarg:
+ continue
+ try:
+ val = getattr(obj, arg, missing)
+ if val is not missing and val != defval:
+ output.append("%s=%r" % (arg, val))
+ except Exception:
+ pass
+
+ if additional_kw:
+ for arg, defval in additional_kw:
+ try:
+ val = getattr(obj, arg, missing)
+ if val is not missing and val != defval:
+ output.append("%s=%r" % (arg, val))
+ except Exception:
+ pass
+
+ return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
+
+
+class portable_instancemethod(object):
+ """Turn an instancemethod into a (parent, name) pair
+ to produce a serializable callable.
+
+ """
+
+ __slots__ = "target", "name", "kwargs", "__weakref__"
+
+ def __getstate__(self):
+ return {
+ "target": self.target,
+ "name": self.name,
+ "kwargs": self.kwargs,
+ }
+
+ def __setstate__(self, state):
+ self.target = state["target"]
+ self.name = state["name"]
+ self.kwargs = state.get("kwargs", ())
+
+ def __init__(self, meth, kwargs=()):
+ self.target = meth.__self__
+ self.name = meth.__name__
+ self.kwargs = kwargs
+
+ def __call__(self, *arg, **kw):
+ kw.update(self.kwargs)
+ return getattr(self.target, self.name)(*arg, **kw)
+
+
+def class_hierarchy(cls):
+ """Return an unordered sequence of all classes related to cls.
+
+ Traverses diamond hierarchies.
+
+ Fibs slightly: subclasses of builtin types are not returned. Thus
+ class_hierarchy(class A(object)) returns (A, object), not A plus every
+ class systemwide that derives from object.
+
+ Old-style classes are discarded and hierarchies rooted on them
+ will not be descended.
+
+ """
+ if compat.py2k:
+ if isinstance(cls, types.ClassType):
+ return list()
+
+ hier = {cls}
+ process = list(cls.__mro__)
+ while process:
+ c = process.pop()
+ if compat.py2k:
+ if isinstance(c, types.ClassType):
+ continue
+ bases = (
+ _
+ for _ in c.__bases__
+ if _ not in hier and not isinstance(_, types.ClassType)
+ )
+ else:
+ bases = (_ for _ in c.__bases__ if _ not in hier)
+
+ for b in bases:
+ process.append(b)
+ hier.add(b)
+
+ if compat.py3k:
+ if c.__module__ == "builtins" or not hasattr(c, "__subclasses__"):
+ continue
+ else:
+ if c.__module__ == "__builtin__" or not hasattr(
+ c, "__subclasses__"
+ ):
+ continue
+
+ for s in [_ for _ in c.__subclasses__() if _ not in hier]:
+ process.append(s)
+ hier.add(s)
+ return list(hier)
+
+
+def iterate_attributes(cls):
+ """iterate all the keys and attributes associated
+ with a class, without using getattr().
+
+ Does not use getattr() so that class-sensitive
+ descriptors (i.e. property.__get__()) are not called.
+
+ """
+ keys = dir(cls)
+ for key in keys:
+ for c in cls.__mro__:
+ if key in c.__dict__:
+ yield (key, c.__dict__[key])
+ break
+
+
+def monkeypatch_proxied_specials(
+ into_cls,
+ from_cls,
+ skip=None,
+ only=None,
+ name="self.proxy",
+ from_instance=None,
+):
+ """Automates delegation of __specials__ for a proxying type."""
+
+ if only:
+ dunders = only
+ else:
+ if skip is None:
+ skip = (
+ "__slots__",
+ "__del__",
+ "__getattribute__",
+ "__metaclass__",
+ "__getstate__",
+ "__setstate__",
+ )
+ dunders = [
+ m
+ for m in dir(from_cls)
+ if (
+ m.startswith("__")
+ and m.endswith("__")
+ and not hasattr(into_cls, m)
+ and m not in skip
+ )
+ ]
+
+ for method in dunders:
+ try:
+ fn = getattr(from_cls, method)
+ if not hasattr(fn, "__call__"):
+ continue
+ fn = getattr(fn, "__func__", fn)
+ except AttributeError:
+ continue
+ try:
+ spec = compat.inspect_getfullargspec(fn)
+ fn_args = compat.inspect_formatargspec(spec[0])
+ d_args = compat.inspect_formatargspec(spec[0][1:])
+ except TypeError:
+ fn_args = "(self, *args, **kw)"
+ d_args = "(*args, **kw)"
+
+ py = (
+ "def %(method)s%(fn_args)s: "
+ "return %(name)s.%(method)s%(d_args)s" % locals()
+ )
+
+ env = from_instance is not None and {name: from_instance} or {}
+ compat.exec_(py, env)
+ try:
+ env[method].__defaults__ = fn.__defaults__
+ except AttributeError:
+ pass
+ setattr(into_cls, method, env[method])
+
+
+def methods_equivalent(meth1, meth2):
+ """Return True if the two methods are the same implementation."""
+
+ return getattr(meth1, "__func__", meth1) is getattr(
+ meth2, "__func__", meth2
+ )
+
+
+def as_interface(obj, cls=None, methods=None, required=None):
+ """Ensure basic interface compliance for an instance or dict of callables.
+
+ Checks that ``obj`` implements public methods of ``cls`` or has members
+ listed in ``methods``. If ``required`` is not supplied, implementing at
+ least one interface method is sufficient. Methods present on ``obj`` that
+ are not in the interface are ignored.
+
+    If ``obj`` is a dict and does not itself meet the interface
+ requirements, the keys of the dictionary are inspected. Keys present in
+ ``obj`` that are not in the interface will raise TypeErrors.
+
+ Raises TypeError if ``obj`` does not meet the interface criteria.
+
+ In all passing cases, an object with callable members is returned. In the
+ simple case, ``obj`` is returned as-is; if dict processing kicks in then
+ an anonymous class is returned.
+
+ obj
+ A type, instance, or dictionary of callables.
+ cls
+ Optional, a type. All public methods of cls are considered the
+ interface. An ``obj`` instance of cls will always pass, ignoring
+        ``required``.
+ methods
+ Optional, a sequence of method names to consider as the interface.
+ required
+ Optional, a sequence of mandatory implementations. If omitted, an
+ ``obj`` that provides at least one interface method is considered
+ sufficient. As a convenience, required may be a type, in which case
+ all public methods of the type are required.
+
+ """
+ if not cls and not methods:
+ raise TypeError("a class or collection of method names are required")
+
+ if isinstance(cls, type) and isinstance(obj, cls):
+ return obj
+
+ interface = set(methods or [m for m in dir(cls) if not m.startswith("_")])
+ implemented = set(dir(obj))
+
+ complies = operator.ge
+ if isinstance(required, type):
+ required = interface
+ elif not required:
+ required = set()
+ complies = operator.gt
+ else:
+ required = set(required)
+
+ if complies(implemented.intersection(interface), required):
+ return obj
+
+ # No dict duck typing here.
+ if not isinstance(obj, dict):
+ qualifier = complies is operator.gt and "any of" or "all of"
+ raise TypeError(
+ "%r does not implement %s: %s"
+ % (obj, qualifier, ", ".join(interface))
+ )
+
+ class AnonymousInterface(object):
+ """A callable-holding shell."""
+
+ if cls:
+ AnonymousInterface.__name__ = "Anonymous" + cls.__name__
+ found = set()
+
+ for method, impl in dictlike_iteritems(obj):
+ if method not in interface:
+ raise TypeError("%r: unknown in this interface" % method)
+ if not compat.callable(impl):
+ raise TypeError("%r=%r is not callable" % (method, impl))
+ setattr(AnonymousInterface, method, staticmethod(impl))
+ found.add(method)
+
+ if complies(found, required):
+ return AnonymousInterface
+
+ raise TypeError(
+ "dictionary does not contain required keys %s"
+ % ", ".join(required - found)
+ )
+
+
+class memoized_property(object):
+ """A read-only @property that is only evaluated once."""
+
+ def __init__(self, fget, doc=None):
+ self.fget = fget
+ self.__doc__ = doc or fget.__doc__
+ self.__name__ = fget.__name__
+
+ def __get__(self, obj, cls):
+ if obj is None:
+ return self
+ obj.__dict__[self.__name__] = result = self.fget(obj)
+ return result
+
+ def _reset(self, obj):
+ memoized_property.reset(obj, self.__name__)
+
+ @classmethod
+ def reset(cls, obj, name):
+ obj.__dict__.pop(name, None)
+
+
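+# A minimal usage sketch (``Thing`` is a hypothetical class): the getter runs
+# once per instance; the result is stored in the instance __dict__ and
+# shadows the descriptor on later access.
+#
+#     >>> from sqlalchemy.util.langhelpers import memoized_property
+#     >>> class Thing(object):
+#     ...     @memoized_property
+#     ...     def expensive(self):
+#     ...         print("computing")
+#     ...         return 42
+#     >>> t = Thing()
+#     >>> t.expensive
+#     computing
+#     42
+#     >>> t.expensive
+#     42
+
+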
+def memoized_instancemethod(fn):
+ """Decorate a method memoize its return value.
+
+ Best applied to no-arg methods: memoization is not sensitive to
+ argument values, and will always return the same value even when
+ called with different arguments.
+
+ """
+
+ def oneshot(self, *args, **kw):
+ result = fn(self, *args, **kw)
+
+ def memo(*a, **kw):
+ return result
+
+ memo.__name__ = fn.__name__
+ memo.__doc__ = fn.__doc__
+ self.__dict__[fn.__name__] = memo
+ return result
+
+ return update_wrapper(oneshot, fn)
+
+
+class HasMemoized(object):
+ """A class that maintains the names of memoized elements in a
+ collection for easy cache clearing, generative, etc.
+
+ """
+
+ __slots__ = ()
+
+ _memoized_keys = frozenset()
+
+ def _reset_memoizations(self):
+ for elem in self._memoized_keys:
+ self.__dict__.pop(elem, None)
+
+ def _assert_no_memoizations(self):
+ for elem in self._memoized_keys:
+ assert elem not in self.__dict__
+
+ def _set_memoized_attribute(self, key, value):
+ self.__dict__[key] = value
+ self._memoized_keys |= {key}
+
+ class memoized_attribute(object):
+ """A read-only @property that is only evaluated once.
+
+ :meta private:
+
+ """
+
+ def __init__(self, fget, doc=None):
+ self.fget = fget
+ self.__doc__ = doc or fget.__doc__
+ self.__name__ = fget.__name__
+
+ def __get__(self, obj, cls):
+ if obj is None:
+ return self
+ obj.__dict__[self.__name__] = result = self.fget(obj)
+ obj._memoized_keys |= {self.__name__}
+ return result
+
+ @classmethod
+ def memoized_instancemethod(cls, fn):
+ """Decorate a method memoize its return value."""
+
+ def oneshot(self, *args, **kw):
+ result = fn(self, *args, **kw)
+
+ def memo(*a, **kw):
+ return result
+
+ memo.__name__ = fn.__name__
+ memo.__doc__ = fn.__doc__
+ self.__dict__[fn.__name__] = memo
+ self._memoized_keys |= {fn.__name__}
+ return result
+
+ return update_wrapper(oneshot, fn)
+
+
+class MemoizedSlots(object):
+ """Apply memoized items to an object using a __getattr__ scheme.
+
+ This allows the functionality of memoized_property and
+ memoized_instancemethod to be available to a class using __slots__.
+
+ """
+
+ __slots__ = ()
+
+ def _fallback_getattr(self, key):
+ raise AttributeError(key)
+
+ def __getattr__(self, key):
+ if key.startswith("_memoized"):
+ raise AttributeError(key)
+ elif hasattr(self, "_memoized_attr_%s" % key):
+ value = getattr(self, "_memoized_attr_%s" % key)()
+ setattr(self, key, value)
+ return value
+ elif hasattr(self, "_memoized_method_%s" % key):
+ fn = getattr(self, "_memoized_method_%s" % key)
+
+ def oneshot(*args, **kw):
+ result = fn(*args, **kw)
+
+ def memo(*a, **kw):
+ return result
+
+ memo.__name__ = fn.__name__
+ memo.__doc__ = fn.__doc__
+ setattr(self, key, memo)
+ return result
+
+ oneshot.__doc__ = fn.__doc__
+ return oneshot
+ else:
+ return self._fallback_getattr(key)
+
+
+# from paste.deploy.converters
+def asbool(obj):
+ if isinstance(obj, compat.string_types):
+ obj = obj.strip().lower()
+ if obj in ["true", "yes", "on", "y", "t", "1"]:
+ return True
+ elif obj in ["false", "no", "off", "n", "f", "0"]:
+ return False
+ else:
+ raise ValueError("String is not true/false: %r" % obj)
+ return bool(obj)
+
+
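+# A minimal usage sketch: string values are normalized case-insensitively;
+# anything else is passed through bool().
+#
+#     >>> from sqlalchemy.util.langhelpers import asbool
+#     >>> asbool(" Yes ")
+#     True
+#     >>> asbool("off")
+#     False
+#     >>> asbool(0)
+#     False
+
+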
+def bool_or_str(*text):
+ """Return a callable that will evaluate a string as
+ boolean, or one of a set of "alternate" string values.
+
+ """
+
+ def bool_or_value(obj):
+ if obj in text:
+ return obj
+ else:
+ return asbool(obj)
+
+ return bool_or_value
+
+
+def asint(value):
+ """Coerce to integer."""
+
+ if value is None:
+ return value
+ return int(value)
+
+
+def coerce_kw_type(kw, key, type_, flexi_bool=True, dest=None):
+ r"""If 'key' is present in dict 'kw', coerce its value to type 'type\_' if
+ necessary. If 'flexi_bool' is True, the string '0' is considered false
+ when coercing to boolean.
+ """
+
+ if dest is None:
+ dest = kw
+
+ if (
+ key in kw
+ and (not isinstance(type_, type) or not isinstance(kw[key], type_))
+ and kw[key] is not None
+ ):
+ if type_ is bool and flexi_bool:
+ dest[key] = asbool(kw[key])
+ else:
+ dest[key] = type_(kw[key])
+
+
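+# A minimal usage sketch (the key names are hypothetical): values already of
+# the target type are left alone; others are coerced in place, or into
+# ``dest`` when given.
+#
+#     >>> from sqlalchemy.util.langhelpers import coerce_kw_type
+#     >>> kw = {"pool_size": "5", "echo": "0"}
+#     >>> coerce_kw_type(kw, "pool_size", int)
+#     >>> coerce_kw_type(kw, "echo", bool)
+#     >>> kw
+#     {'pool_size': 5, 'echo': False}
+
+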
+def constructor_key(obj, cls):
+ """Produce a tuple structure that is cacheable using the __dict__ of
+ obj to retrieve values
+
+ """
+ names = get_cls_kwargs(cls)
+ return (cls,) + tuple(
+ (k, obj.__dict__[k]) for k in names if k in obj.__dict__
+ )
+
+
+def constructor_copy(obj, cls, *args, **kw):
+ """Instantiate cls using the __dict__ of obj as constructor arguments.
+
+ Uses inspect to match the named arguments of ``cls``.
+
+ """
+
+ names = get_cls_kwargs(cls)
+ kw.update(
+ (k, obj.__dict__[k]) for k in names.difference(kw) if k in obj.__dict__
+ )
+ return cls(*args, **kw)
+
+
+def counter():
+ """Return a threadsafe counter function."""
+
+ lock = compat.threading.Lock()
+ counter = itertools.count(1)
+
+ # avoid the 2to3 "next" transformation...
+ def _next():
+ with lock:
+ return next(counter)
+
+ return _next
+
+
+def duck_type_collection(specimen, default=None):
+ """Given an instance or class, guess if it is or is acting as one of
+ the basic collection types: list, set and dict. If the __emulates__
+ property is present, return that preferentially.
+ """
+
+ if hasattr(specimen, "__emulates__"):
+ # canonicalize set vs sets.Set to a standard: the builtin set
+ if specimen.__emulates__ is not None and issubclass(
+ specimen.__emulates__, set
+ ):
+ return set
+ else:
+ return specimen.__emulates__
+
+ isa = isinstance(specimen, type) and issubclass or isinstance
+ if isa(specimen, list):
+ return list
+ elif isa(specimen, set):
+ return set
+ elif isa(specimen, dict):
+ return dict
+
+ if hasattr(specimen, "append"):
+ return list
+ elif hasattr(specimen, "add"):
+ return set
+ elif hasattr(specimen, "set"):
+ return dict
+ else:
+ return default
+
+
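+# A minimal usage sketch (``MyBag`` is a hypothetical class): built-in
+# collection types map directly, and unknown objects are guessed from their
+# methods.
+#
+#     >>> from sqlalchemy.util.langhelpers import duck_type_collection
+#     >>> duck_type_collection([])
+#     <class 'list'>
+#     >>> duck_type_collection(set)
+#     <class 'set'>
+#     >>> class MyBag(object):
+#     ...     def append(self, item):
+#     ...         pass
+#     >>> duck_type_collection(MyBag())
+#     <class 'list'>
+
+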
+def assert_arg_type(arg, argtype, name):
+ if isinstance(arg, argtype):
+ return arg
+ else:
+ if isinstance(argtype, tuple):
+ raise exc.ArgumentError(
+ "Argument '%s' is expected to be one of type %s, got '%s'"
+ % (name, " or ".join("'%s'" % a for a in argtype), type(arg))
+ )
+ else:
+ raise exc.ArgumentError(
+ "Argument '%s' is expected to be of type '%s', got '%s'"
+ % (name, argtype, type(arg))
+ )
+
+
+def dictlike_iteritems(dictlike):
+ """Return a (key, value) iterator for almost any dict-like object."""
+
+ if compat.py3k:
+ if hasattr(dictlike, "items"):
+ return list(dictlike.items())
+ else:
+ if hasattr(dictlike, "iteritems"):
+ return dictlike.iteritems()
+ elif hasattr(dictlike, "items"):
+ return iter(dictlike.items())
+
+ getter = getattr(dictlike, "__getitem__", getattr(dictlike, "get", None))
+ if getter is None:
+ raise TypeError("Object '%r' is not dict-like" % dictlike)
+
+ if hasattr(dictlike, "iterkeys"):
+
+ def iterator():
+ for key in dictlike.iterkeys():
+ yield key, getter(key)
+
+ return iterator()
+ elif hasattr(dictlike, "keys"):
+ return iter((key, getter(key)) for key in dictlike.keys())
+ else:
+ raise TypeError("Object '%r' is not dict-like" % dictlike)
+
+
+class classproperty(property):
+ """A decorator that behaves like @property except that operates
+ on classes rather than instances.
+
+ The decorator is currently special when using the declarative
+ module, but note that the
+ :class:`~.sqlalchemy.ext.declarative.declared_attr`
+ decorator should be used for this purpose with declarative.
+
+ """
+
+ def __init__(self, fget, *arg, **kw):
+ super(classproperty, self).__init__(fget, *arg, **kw)
+ self.__doc__ = fget.__doc__
+
+ def __get__(desc, self, cls):
+ return desc.fget(cls)
+
+
+class hybridproperty(object):
+ def __init__(self, func):
+ self.func = func
+ self.clslevel = func
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ clsval = self.clslevel(owner)
+ return clsval
+ else:
+ return self.func(instance)
+
+ def classlevel(self, func):
+ self.clslevel = func
+ return self
+
+
+class hybridmethod(object):
+ """Decorate a function as cls- or instance- level."""
+
+ def __init__(self, func):
+ self.func = self.__func__ = func
+ self.clslevel = func
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self.clslevel.__get__(owner, owner.__class__)
+ else:
+ return self.func.__get__(instance, owner)
+
+ def classlevel(self, func):
+ self.clslevel = func
+ return self
+
+
+class _symbol(int):
+ def __new__(self, name, doc=None, canonical=None):
+ """Construct a new named symbol."""
+ assert isinstance(name, compat.string_types)
+ if canonical is None:
+ canonical = hash(name)
+ v = int.__new__(_symbol, canonical)
+ v.name = name
+ if doc:
+ v.__doc__ = doc
+ return v
+
+ def __reduce__(self):
+ return symbol, (self.name, "x", int(self))
+
+ def __str__(self):
+ return repr(self)
+
+ def __repr__(self):
+ return "symbol(%r)" % self.name
+
+
+_symbol.__name__ = "symbol"
+
+
+class symbol(object):
+ """A constant symbol.
+
+ >>> symbol('foo') is symbol('foo')
+ True
+ >>> symbol('foo')
+ <symbol 'foo>
+
+ A slight refinement of the MAGICCOOKIE=object() pattern. The primary
+ advantage of symbol() is its repr(). They are also singletons.
+
+ Repeated calls of symbol('name') will all return the same instance.
+
+ The optional ``doc`` argument assigns to ``__doc__``. This
+ is strictly so that Sphinx autoattr picks up the docstring we want
+ (it doesn't appear to pick up the in-module docstring if the datamember
+ is in a different module - autoattribute also blows up completely).
+ If Sphinx fixes/improves this then we would no longer need
+ ``doc`` here.
+
+ """
+
+ symbols = {}
+ _lock = compat.threading.Lock()
+
+ def __new__(cls, name, doc=None, canonical=None):
+ with cls._lock:
+ sym = cls.symbols.get(name)
+ if sym is None:
+ cls.symbols[name] = sym = _symbol(name, doc, canonical)
+ return sym
+
+ @classmethod
+ def parse_user_argument(
+ cls, arg, choices, name, resolve_symbol_names=False
+ ):
+ """Given a user parameter, parse the parameter into a chosen symbol.
+
+ The user argument can be a string name that matches the name of a
+ symbol, or the symbol object itself, or any number of alternate choices
+        such as True/False/None, etc.
+
+ :param arg: the user argument.
+ :param choices: dictionary of symbol object to list of possible
+ entries.
+ :param name: name of the argument. Used in an :class:`.ArgumentError`
+ that is raised if the parameter doesn't match any available argument.
+ :param resolve_symbol_names: include the name of each symbol as a valid
+ entry.
+
+ """
+ # note using hash lookup is tricky here because symbol's `__hash__`
+ # is its int value which we don't want included in the lookup
+ # explicitly, so we iterate and compare each.
+ for sym, choice in choices.items():
+ if arg is sym:
+ return sym
+ elif resolve_symbol_names and arg == sym.name:
+ return sym
+ elif arg in choice:
+ return sym
+
+ if arg is None:
+ return None
+
+ raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg))
+
+
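+# A minimal usage sketch (the symbols and choices are hypothetical): user
+# input such as True/False or a symbol name resolves to a canonical symbol.
+#
+#     >>> from sqlalchemy.util.langhelpers import symbol
+#     >>> ON, OFF = symbol("ON"), symbol("OFF")
+#     >>> symbol.parse_user_argument(
+#     ...     True,
+#     ...     {ON: [True, "on"], OFF: [False, "off"]},
+#     ...     "switch",
+#     ...     resolve_symbol_names=True,
+#     ... )
+#     symbol('ON')
+
+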
+_creation_order = 1
+
+
+def set_creation_order(instance):
+ """Assign a '_creation_order' sequence to the given instance.
+
+ This allows multiple instances to be sorted in order of creation
+ (typically within a single thread; the counter is not particularly
+ threadsafe).
+
+ """
+ global _creation_order
+ instance._creation_order = _creation_order
+ _creation_order += 1
+
+
+def warn_exception(func, *args, **kwargs):
+ """executes the given function, catches all exceptions and converts to
+ a warning.
+
+ """
+ try:
+ return func(*args, **kwargs)
+ except Exception:
+ warn("%s('%s') ignored" % sys.exc_info()[0:2])
+
+
+def ellipses_string(value, len_=25):
+ try:
+ if len(value) > len_:
+ return "%s..." % value[0:len_]
+ else:
+ return value
+ except TypeError:
+ return value
+
+
+class _hash_limit_string(compat.text_type):
+ """A string subclass that can only be hashed on a maximum amount
+ of unique values.
+
+ This is used for warnings so that we can send out parameterized warnings
+ without the __warningregistry__ of the module, or the non-overridable
+ "once" registry within warnings.py, overloading memory,
+
+
+ """
+
+ def __new__(cls, value, num, args):
+ interpolated = (value % args) + (
+ " (this warning may be suppressed after %d occurrences)" % num
+ )
+ self = super(_hash_limit_string, cls).__new__(cls, interpolated)
+ self._hash = hash("%s_%d" % (value, hash(interpolated) % num))
+ return self
+
+ def __hash__(self):
+ return self._hash
+
+ def __eq__(self, other):
+ return hash(self) == hash(other)
+
+
+def warn(msg, code=None):
+ """Issue a warning.
+
+ If msg is a string, :class:`.exc.SAWarning` is used as
+ the category.
+
+ """
+ if code:
+ _warnings_warn(exc.SAWarning(msg, code=code))
+ else:
+ _warnings_warn(msg, exc.SAWarning)
+
+
+def warn_limited(msg, args):
+ """Issue a warning with a parameterized string, limiting the number
+ of registrations.
+
+ """
+ if args:
+ msg = _hash_limit_string(msg, 10, args)
+ _warnings_warn(msg, exc.SAWarning)
+
+
+def _warnings_warn(message, category=None, stacklevel=2):
+
+ # adjust the given stacklevel to be outside of SQLAlchemy
+ try:
+ frame = sys._getframe(stacklevel)
+ except ValueError:
+ # being called from less than 3 (or given) stacklevels, weird,
+ # but don't crash
+ stacklevel = 0
+ except:
+        # _getframe() doesn't work on this interpreter; weird,
+        # but don't crash
+ stacklevel = 0
+ else:
+ # using __name__ here requires that we have __name__ in the
+ # __globals__ of the decorated string functions we make also.
+ # we generate this using {"__name__": fn.__module__}
+ while frame is not None and re.match(
+ r"^(?:sqlalchemy\.|alembic\.)", frame.f_globals.get("__name__", "")
+ ):
+ frame = frame.f_back
+ stacklevel += 1
+
+ if category is not None:
+ warnings.warn(message, category, stacklevel=stacklevel + 1)
+ else:
+ warnings.warn(message, stacklevel=stacklevel + 1)
+
+
+def only_once(fn, retry_on_exception):
+ """Decorate the given function to be a no-op after it is called exactly
+ once."""
+
+ once = [fn]
+
+ def go(*arg, **kw):
+ # strong reference fn so that it isn't garbage collected,
+ # which interferes with the event system's expectations
+ strong_fn = fn # noqa
+ if once:
+ once_fn = once.pop()
+ try:
+ return once_fn(*arg, **kw)
+ except:
+ if retry_on_exception:
+ once.insert(0, once_fn)
+ raise
+
+ return go
+
+
+_SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py")
+_UNITTEST_RE = re.compile(r"unit(?:2|test2?/)")
+
+
+def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE):
+ """Chop extraneous lines off beginning and end of a traceback.
+
+ :param tb:
+ a list of traceback lines as returned by ``traceback.format_stack()``
+
+ :param exclude_prefix:
+ a regular expression object matching lines to skip at beginning of
+ ``tb``
+
+ :param exclude_suffix:
+ a regular expression object matching lines to skip at end of ``tb``
+ """
+ start = 0
+ end = len(tb) - 1
+ while start <= end and exclude_prefix.search(tb[start]):
+ start += 1
+ while start <= end and exclude_suffix.search(tb[end]):
+ end -= 1
+ return tb[start : end + 1]
+
+
+NoneType = type(None)
+
+
+def attrsetter(attrname):
+ code = "def set(obj, value):" " obj.%s = value" % attrname
+ env = locals().copy()
+ exec(code, env)
+ return env["set"]
+
+
+class EnsureKWArgType(type):
+ r"""Apply translation of functions to accept \**kw arguments if they
+ don't already.
+
+ """
+
+ def __init__(cls, clsname, bases, clsdict):
+ fn_reg = cls.ensure_kwarg
+ if fn_reg:
+ for key in clsdict:
+ m = re.match(fn_reg, key)
+ if m:
+ fn = clsdict[key]
+ spec = compat.inspect_getfullargspec(fn)
+ if not spec.varkw:
+ clsdict[key] = wrapped = cls._wrap_w_kw(fn)
+ setattr(cls, key, wrapped)
+ super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict)
+
+ def _wrap_w_kw(self, fn):
+ def wrap(*arg, **kw):
+ return fn(*arg)
+
+ return update_wrapper(wrap, fn)
+
+
+def wrap_callable(wrapper, fn):
+ """Augment functools.update_wrapper() to work with objects with
+ a ``__call__()`` method.
+
+ :param fn:
+ object with __call__ method
+
+ """
+ if hasattr(fn, "__name__"):
+ return update_wrapper(wrapper, fn)
+ else:
+ _f = wrapper
+ _f.__name__ = fn.__class__.__name__
+ if hasattr(fn, "__module__"):
+ _f.__module__ = fn.__module__
+
+ if hasattr(fn.__call__, "__doc__") and fn.__call__.__doc__:
+ _f.__doc__ = fn.__call__.__doc__
+ elif fn.__doc__:
+ _f.__doc__ = fn.__doc__
+
+ return _f
+
+
+def quoted_token_parser(value):
+ """Parse a dotted identifier with accommodation for quoted names.
+
+ Includes support for SQL-style double quotes as a literal character.
+
+ E.g.::
+
+ >>> quoted_token_parser("name")
+ ["name"]
+ >>> quoted_token_parser("schema.name")
+ ["schema", "name"]
+ >>> quoted_token_parser('"Schema"."Name"')
+ ['Schema', 'Name']
+ >>> quoted_token_parser('"Schema"."Name""Foo"')
+ ['Schema', 'Name""Foo']
+
+ """
+
+ if '"' not in value:
+ return value.split(".")
+
+ # 0 = outside of quotes
+ # 1 = inside of quotes
+ state = 0
+ result = [[]]
+ idx = 0
+ lv = len(value)
+ while idx < lv:
+ char = value[idx]
+ if char == '"':
+ if state == 1 and idx < lv - 1 and value[idx + 1] == '"':
+ result[-1].append('"')
+ idx += 1
+ else:
+ state ^= 1
+ elif char == "." and state == 0:
+ result.append([])
+ else:
+ result[-1].append(char)
+ idx += 1
+
+ return ["".join(token) for token in result]
+
+
+def add_parameter_text(params, text):
+ params = _collections.to_list(params)
+
+ def decorate(fn):
+ doc = fn.__doc__ is not None and fn.__doc__ or ""
+ if doc:
+ doc = inject_param_text(doc, {param: text for param in params})
+ fn.__doc__ = doc
+ return fn
+
+ return decorate
+
+
+def _dedent_docstring(text):
+ split_text = text.split("\n", 1)
+ if len(split_text) == 1:
+ return text
+ else:
+ firstline, remaining = split_text
+ if not firstline.startswith(" "):
+ return firstline + "\n" + textwrap.dedent(remaining)
+ else:
+ return textwrap.dedent(text)
+
+
+def inject_docstring_text(doctext, injecttext, pos):
+ doctext = _dedent_docstring(doctext or "")
+ lines = doctext.split("\n")
+ if len(lines) == 1:
+ lines.append("")
+ injectlines = textwrap.dedent(injecttext).split("\n")
+ if injectlines[0]:
+ injectlines.insert(0, "")
+
+ blanks = [num for num, line in enumerate(lines) if not line.strip()]
+ blanks.insert(0, 0)
+
+ inject_pos = blanks[min(pos, len(blanks) - 1)]
+
+ lines = lines[0:inject_pos] + injectlines + lines[inject_pos:]
+ return "\n".join(lines)
+
+
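+# A minimal usage sketch: position 1 injects the text after the first
+# paragraph of the (dedented) docstring.
+#
+#     >>> from sqlalchemy.util.langhelpers import inject_docstring_text
+#     >>> inject_docstring_text("Summary.\n\nDetails.", "Injected note.", 1)
+#     'Summary.\n\nInjected note.\n\nDetails.'
+
+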
+_param_reg = re.compile(r"(\s+):param (.+?):")
+
+
+def inject_param_text(doctext, inject_params):
+ doclines = collections.deque(doctext.splitlines())
+ lines = []
+
+ # TODO: this is not working for params like ":param case_sensitive=True:"
+
+ to_inject = None
+ while doclines:
+ line = doclines.popleft()
+
+ m = _param_reg.match(line)
+
+ if to_inject is None:
+ if m:
+ param = m.group(2).lstrip("*")
+ if param in inject_params:
+ # default indent to that of :param: plus one
+ indent = " " * len(m.group(1)) + " "
+
+ # but if the next line has text, use that line's
+ # indentation
+ if doclines:
+ m2 = re.match(r"(\s+)\S", doclines[0])
+ if m2:
+ indent = " " * len(m2.group(1))
+
+ to_inject = indent + inject_params[param]
+ elif m:
+ lines.extend(["\n", to_inject, "\n"])
+ to_inject = None
+ elif not line.rstrip():
+ lines.extend([line, to_inject, "\n"])
+ to_inject = None
+ elif line.endswith("::"):
+        # TODO: this still won't cover if the code example itself has blank
+ # lines in it, need to detect those via indentation.
+ lines.extend([line, doclines.popleft()])
+ continue
+ lines.append(line)
+
+ return "\n".join(lines)
+
+
+def repr_tuple_names(names):
+ """Trims a list of strings from the middle and return a string of up to
+ four elements. Strings greater than 11 characters will be truncated"""
+ if len(names) == 0:
+ return None
+ flag = len(names) <= 4
+ names = names[0:4] if flag else names[0:3] + names[-1:]
+ res = ["%s.." % name[:11] if len(name) > 11 else name for name in names]
+ if flag:
+ return ", ".join(res)
+ else:
+ return "%s, ..., %s" % (", ".join(res[0:3]), res[-1])
+
+
+def has_compiled_ext():
+ try:
+ from sqlalchemy import cimmutabledict # noqa: F401
+ from sqlalchemy import cprocessors # noqa: F401
+ from sqlalchemy import cresultproxy # noqa: F401
+
+ return True
+ except ImportError:
+ return False
diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py
new file mode 100644
index 0000000..67c5219
--- /dev/null
+++ b/lib/sqlalchemy/util/queue.py
@@ -0,0 +1,291 @@
+# util/queue.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""An adaptation of Py2.3/2.4's Queue module which supports reentrant
+behavior, using RLock instead of Lock for its mutex object. The
+Queue object is used exclusively by the sqlalchemy.pool.QueuePool
+class.
+
+This is to support the connection pool's usage of weakref callbacks to return
+connections to the underlying Queue, which can in extremely
+rare cases be invoked within the ``get()`` method of the Queue itself,
+producing a ``put()`` inside the ``get()`` and therefore a reentrant
+condition.
+
+"""
+
+from collections import deque
+from time import time as _time
+
+from . import compat
+from .compat import threading
+from .concurrency import asyncio
+from .concurrency import await_fallback
+from .concurrency import await_only
+from .langhelpers import memoized_property
+
+
+__all__ = ["Empty", "Full", "Queue"]
+
+
+class Empty(Exception):
+ "Exception raised by Queue.get(block=0)/get_nowait()."
+
+ pass
+
+
+class Full(Exception):
+ "Exception raised by Queue.put(block=0)/put_nowait()."
+
+ pass
+
+
+class Queue:
+ def __init__(self, maxsize=0, use_lifo=False):
+ """Initialize a queue object with a given maximum size.
+
+ If `maxsize` is <= 0, the queue size is infinite.
+
+ If `use_lifo` is True, this Queue acts like a Stack (LIFO).
+ """
+
+ self._init(maxsize)
+ # mutex must be held whenever the queue is mutating. All methods
+ # that acquire mutex must release it before returning. mutex
+ # is shared between the two conditions, so acquiring and
+ # releasing the conditions also acquires and releases mutex.
+ self.mutex = threading.RLock()
+ # Notify not_empty whenever an item is added to the queue; a
+ # thread waiting to get is notified then.
+ self.not_empty = threading.Condition(self.mutex)
+ # Notify not_full whenever an item is removed from the queue;
+ # a thread waiting to put is notified then.
+ self.not_full = threading.Condition(self.mutex)
+ # If this queue uses LIFO or FIFO
+ self.use_lifo = use_lifo
+
+ def qsize(self):
+ """Return the approximate size of the queue (not reliable!)."""
+
+ with self.mutex:
+ return self._qsize()
+
+ def empty(self):
+ """Return True if the queue is empty, False otherwise (not
+ reliable!)."""
+
+ with self.mutex:
+ return self._empty()
+
+ def full(self):
+ """Return True if the queue is full, False otherwise (not
+ reliable!)."""
+
+ with self.mutex:
+ return self._full()
+
+ def put(self, item, block=True, timeout=None):
+ """Put an item into the queue.
+
+        If optional arg `block` is True and `timeout` is None (the
+ default), block if necessary until a free slot is
+ available. If `timeout` is a positive number, it blocks at
+ most `timeout` seconds and raises the ``Full`` exception if no
+ free slot was available within that time. Otherwise (`block`
+ is false), put an item on the queue if a free slot is
+ immediately available, else raise the ``Full`` exception
+ (`timeout` is ignored in that case).
+ """
+
+ with self.not_full:
+ if not block:
+ if self._full():
+ raise Full
+ elif timeout is None:
+ while self._full():
+ self.not_full.wait()
+ else:
+ if timeout < 0:
+                    raise ValueError(
+                        "'timeout' must be a non-negative number"
+                    )
+ endtime = _time() + timeout
+ while self._full():
+ remaining = endtime - _time()
+ if remaining <= 0.0:
+ raise Full
+ self.not_full.wait(remaining)
+ self._put(item)
+ self.not_empty.notify()
+
+ def put_nowait(self, item):
+ """Put an item into the queue without blocking.
+
+ Only enqueue the item if a free slot is immediately available.
+ Otherwise raise the ``Full`` exception.
+ """
+ return self.put(item, False)
+
+ def get(self, block=True, timeout=None):
+ """Remove and return an item from the queue.
+
+        If optional arg `block` is True and `timeout` is None (the
+ default), block if necessary until an item is available. If
+ `timeout` is a positive number, it blocks at most `timeout`
+ seconds and raises the ``Empty`` exception if no item was
+ available within that time. Otherwise (`block` is false),
+ return an item if one is immediately available, else raise the
+ ``Empty`` exception (`timeout` is ignored in that case).
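+
+        e.g., an illustrative sketch (``q`` is a ``Queue`` instance)::
+
+            item = q.get()                # block until an item is available
+            item = q.get(timeout=30.0)    # raise Empty after ~30 seconds
+            item = q.get(block=False)     # raise Empty if nothing is queued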
+
+ """
+ with self.not_empty:
+ if not block:
+ if self._empty():
+ raise Empty
+ elif timeout is None:
+ while self._empty():
+ self.not_empty.wait()
+ else:
+ if timeout < 0:
+                    raise ValueError(
+                        "'timeout' must be a non-negative number"
+                    )
+ endtime = _time() + timeout
+ while self._empty():
+ remaining = endtime - _time()
+ if remaining <= 0.0:
+ raise Empty
+ self.not_empty.wait(remaining)
+ item = self._get()
+ self.not_full.notify()
+ return item
+
+ def get_nowait(self):
+ """Remove and return an item from the queue without blocking.
+
+ Only get an item if one is immediately available. Otherwise
+ raise the ``Empty`` exception.
+ """
+
+ return self.get(False)
+
+ # Override these methods to implement other queue organizations
+ # (e.g. stack or priority queue).
+    # These will only be called with appropriate locks held.
+
+ # Initialize the queue representation
+ def _init(self, maxsize):
+ self.maxsize = maxsize
+ self.queue = deque()
+
+ def _qsize(self):
+ return len(self.queue)
+
+ # Check whether the queue is empty
+ def _empty(self):
+ return not self.queue
+
+ # Check whether the queue is full
+ def _full(self):
+ return self.maxsize > 0 and len(self.queue) == self.maxsize
+
+ # Put a new item in the queue
+ def _put(self, item):
+ self.queue.append(item)
+
+ # Get an item from the queue
+ def _get(self):
+ if self.use_lifo:
+ # LIFO
+ return self.queue.pop()
+ else:
+ # FIFO
+ return self.queue.popleft()
+
+
+class AsyncAdaptedQueue:
+ await_ = staticmethod(await_only)
+
+ def __init__(self, maxsize=0, use_lifo=False):
+ self.use_lifo = use_lifo
+ self.maxsize = maxsize
+
+ def empty(self):
+ return self._queue.empty()
+
+ def full(self):
+ return self._queue.full()
+
+ def qsize(self):
+ return self._queue.qsize()
+
+ @memoized_property
+ def _queue(self):
+        # Delay creation of the queue until it is first used, to avoid
+        # binding it to a possibly wrong event loop.
+        # By delaying the creation of the queue we accommodate the common
+        # usage pattern of instantiating the engine at module level, where a
+        # different event loop may be present than the one in which the
+        # application is actually run.
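+        #
+        # Roughly (an illustrative sketch only)::
+        #
+        #   engine = create_async_engine(...)  # import time; no loop running
+        #   asyncio.run(main())                # queue is created lazily on
+        #                                      # first use, inside this loop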
+
+ if self.use_lifo:
+ queue = asyncio.LifoQueue(maxsize=self.maxsize)
+ else:
+ queue = asyncio.Queue(maxsize=self.maxsize)
+ return queue
+
+ def put_nowait(self, item):
+ try:
+ return self._queue.put_nowait(item)
+ except asyncio.QueueFull as err:
+ compat.raise_(
+ Full(),
+ replace_context=err,
+ )
+
+ def put(self, item, block=True, timeout=None):
+ if not block:
+ return self.put_nowait(item)
+
+ try:
+ if timeout is not None:
+ return self.await_(
+ asyncio.wait_for(self._queue.put(item), timeout)
+ )
+ else:
+ return self.await_(self._queue.put(item))
+ except (asyncio.QueueFull, asyncio.TimeoutError) as err:
+ compat.raise_(
+ Full(),
+ replace_context=err,
+ )
+
+ def get_nowait(self):
+ try:
+ return self._queue.get_nowait()
+ except asyncio.QueueEmpty as err:
+ compat.raise_(
+ Empty(),
+ replace_context=err,
+ )
+
+ def get(self, block=True, timeout=None):
+ if not block:
+ return self.get_nowait()
+
+ try:
+ if timeout is not None:
+ return self.await_(
+ asyncio.wait_for(self._queue.get(), timeout)
+ )
+ else:
+ return self.await_(self._queue.get())
+ except (asyncio.QueueEmpty, asyncio.TimeoutError) as err:
+ compat.raise_(
+ Empty(),
+ replace_context=err,
+ )
+
+
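+# Same queue, but ``await_fallback`` will run the awaitable on an event
+# loop directly when no greenlet-based await context is present, whereas
+# ``await_only`` requires that context.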
+class FallbackAsyncAdaptedQueue(AsyncAdaptedQueue):
+ await_ = staticmethod(await_fallback)
diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py
new file mode 100644
index 0000000..bbc819f
--- /dev/null
+++ b/lib/sqlalchemy/util/topological.py
@@ -0,0 +1,100 @@
+# util/topological.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Topological sorting algorithms."""
+
+from .. import util
+from ..exc import CircularDependencyError
+
+__all__ = ["sort", "sort_as_subsets", "find_cycles"]
+
+
+def sort_as_subsets(tuples, allitems):
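+    """Yield lists ("subsets") of items in dependency order.
+
+    Each yielded list contains those remaining items whose parents (per
+    the ``(parent, child)`` tuples) have all appeared in earlier lists;
+    concatenating the lists gives a full topological ordering.
+    """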
+
+ edges = util.defaultdict(set)
+ for parent, child in tuples:
+ edges[child].add(parent)
+
+ todo = list(allitems)
+ todo_set = set(allitems)
+
+ while todo_set:
+ output = []
+ for node in todo:
+ if todo_set.isdisjoint(edges[node]):
+ output.append(node)
+
+ if not output:
+ raise CircularDependencyError(
+ "Circular dependency detected.",
+ find_cycles(tuples, allitems),
+ _gen_edges(edges),
+ )
+
+ todo_set.difference_update(output)
+ todo = [t for t in todo if t in todo_set]
+ yield output
+
+
+def sort(tuples, allitems, deterministic_order=True):
+ """sort the given list of items by dependency.
+
+ 'tuples' is a list of tuples representing a partial ordering.
+
+ deterministic_order is no longer used, the order is now always
+ deterministic given the order of "allitems". the flag is there
+ for backwards compatibility with Alembic.
+
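+    e.g., an illustrative sketch::
+
+        # "a" must precede "b", "b" must precede "c"
+        list(sort([("a", "b"), ("b", "c")], ["c", "b", "a"]))
+        # -> ["a", "b", "c"]
+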
+ """
+
+ for set_ in sort_as_subsets(tuples, allitems):
+ for s in set_:
+ yield s
+
+
+def find_cycles(tuples, allitems):
+ # adapted from:
+ # https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
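+    #
+    # Returns the set of all nodes that participate in at least one
+    # cycle, e.g. (illustrative)::
+    #
+    #   find_cycles([("a", "b"), ("b", "a"), ("b", "c")], ["a", "b", "c"])
+    #   # -> {"a", "b"}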
+
+ edges = util.defaultdict(set)
+ for parent, child in tuples:
+ edges[parent].add(child)
+ nodes_to_test = set(edges)
+
+ output = set()
+
+    # We'd like to find all nodes that are involved in cycles, so we run
+    # a full traversal of the graph for each candidate node.
+
+    # It's enough to test nodes that appear as parents: a node that is
+    # only ever a child (or absent from the edges entirely) has no
+    # outgoing edges and therefore can't be part of a cycle.
+ for node in nodes_to_test:
+ stack = [node]
+ todo = nodes_to_test.difference(stack)
+ while stack:
+ top = stack[-1]
+ for node in edges[top]:
+ if node in stack:
+ cyc = stack[stack.index(node) :]
+ todo.difference_update(cyc)
+ output.update(cyc)
+
+ if node in todo:
+ stack.append(node)
+ todo.remove(node)
+ break
+ else:
+ node = stack.pop()
+ return output
+
+
+def _gen_edges(edges):
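+    # invert the child -> {parents} mapping used internally back into
+    # (parent, child) edge tuples for error reporting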
+    return {(right, left) for left in edges for right in edges[left]}