first add files
This commit is contained in:
11
lib/sqlalchemy/ext/__init__.py
Normal file
11
lib/sqlalchemy/ext/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# ext/__init__.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from .. import util as _sa_util
|
||||
|
||||
|
||||
_sa_util.preloaded.import_prefix("sqlalchemy.ext")
|
||||
1627
lib/sqlalchemy/ext/associationproxy.py
Normal file
1627
lib/sqlalchemy/ext/associationproxy.py
Normal file
File diff suppressed because it is too large
Load Diff
22
lib/sqlalchemy/ext/asyncio/__init__.py
Normal file
22
lib/sqlalchemy/ext/asyncio/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# ext/asyncio/__init__.py
|
||||
# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from .engine import async_engine_from_config
|
||||
from .engine import AsyncConnection
|
||||
from .engine import AsyncEngine
|
||||
from .engine import AsyncTransaction
|
||||
from .engine import create_async_engine
|
||||
from .events import AsyncConnectionEvents
|
||||
from .events import AsyncSessionEvents
|
||||
from .result import AsyncMappingResult
|
||||
from .result import AsyncResult
|
||||
from .result import AsyncScalarResult
|
||||
from .scoping import async_scoped_session
|
||||
from .session import async_object_session
|
||||
from .session import async_session
|
||||
from .session import AsyncSession
|
||||
from .session import AsyncSessionTransaction
|
||||
89
lib/sqlalchemy/ext/asyncio/base.py
Normal file
89
lib/sqlalchemy/ext/asyncio/base.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import abc
|
||||
import functools
|
||||
import weakref
|
||||
|
||||
from . import exc as async_exc
|
||||
|
||||
|
||||
class ReversibleProxy:
    """Mixin maintaining a two-way association between an async proxy
    object and the sync object it wraps.

    A class-level registry maps ``weakref.ref(sync proxied object)`` ->
    ``weakref.ref(async proxy object)`` so that, given a sync object, the
    proxy fronting it can be recovered (or regenerated) later.  Entries are
    removed automatically when either side is garbage collected.
    """

    # weakref.ref(sync proxied object) -> weakref.ref(async proxy object)
    # (note: keys are the *targets*, values are the *proxies*)
    _proxy_objects = {}
    __slots__ = ("__weakref__",)

    def _assign_proxied(self, target):
        """Register ``self`` as the proxy for ``target``; return ``target``.

        ``None`` targets are passed through without registration.
        """
        if target is not None:
            target_ref = weakref.ref(target, ReversibleProxy._target_gced)
            proxy_ref = weakref.ref(
                self,
                functools.partial(ReversibleProxy._target_gced, target_ref),
            )
            ReversibleProxy._proxy_objects[target_ref] = proxy_ref

        return target

    @classmethod
    def _target_gced(cls, ref, proxy_ref=None):
        # weakref finalizer; fires when either the target or the proxy
        # is collected, removing the registry entry keyed by the target ref
        cls._proxy_objects.pop(ref, None)

    @classmethod
    def _regenerate_proxy_for_target(cls, target):
        # subclasses construct a brand-new proxy given only the sync target
        raise NotImplementedError()

    @classmethod
    def _retrieve_proxy_for_target(cls, target, regenerate=True):
        """Return the live proxy registered for ``target``.

        If no live proxy exists and ``regenerate`` is True, build a fresh
        one via :meth:`._regenerate_proxy_for_target`; otherwise return
        ``None``.
        """
        try:
            proxy_ref = cls._proxy_objects[weakref.ref(target)]
        except KeyError:
            pass
        else:
            proxy = proxy_ref()
            if proxy is not None:
                return proxy

        if regenerate:
            return cls._regenerate_proxy_for_target(target)
        else:
            return None
|
||||
|
||||
class StartableContext(abc.ABC):
    """Base for objects usable both as awaitables and as async context
    managers.

    ``await obj`` and ``async with obj:`` both funnel into the abstract
    :meth:`.start` hook, which subclasses implement to establish their
    underlying state.
    """

    __slots__ = ()

    @abc.abstractmethod
    async def start(self, is_ctxmanager=False):
        """Establish this object's state; return the started object."""
        pass

    def __await__(self):
        # ``await obj`` is equivalent to ``await obj.start()``
        return self.start().__await__()

    async def __aenter__(self):
        return await self.start(is_ctxmanager=True)

    @abc.abstractmethod
    async def __aexit__(self, type_, value, traceback):
        pass

    def _raise_for_not_started(self):
        # used by subclasses when an operation requires start() to have run
        raise async_exc.AsyncContextNotStarted(
            "%s context has not been started and object has not been awaited."
            % (self.__class__.__name__)
        )
|
||||
|
||||
class ProxyComparable(ReversibleProxy):
    """Mixin giving async proxy objects identity-based hashing and
    equality defined in terms of the sync object they wrap
    (``self._proxied``)."""

    __slots__ = ()

    def __hash__(self):
        # identity hash; equality below is intentionally not hash-consistent
        # across distinct proxies of the same target
        return id(self)

    def __eq__(self, other):
        return (
            isinstance(other, self.__class__)
            and self._proxied == other._proxied
        )

    def __ne__(self, other):
        return (
            not isinstance(other, self.__class__)
            or self._proxied != other._proxied
        )
828
lib/sqlalchemy/ext/asyncio/engine.py
Normal file
828
lib/sqlalchemy/ext/asyncio/engine.py
Normal file
@@ -0,0 +1,828 @@
|
||||
# ext/asyncio/engine.py
|
||||
# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
import asyncio
|
||||
|
||||
from . import exc as async_exc
|
||||
from .base import ProxyComparable
|
||||
from .base import StartableContext
|
||||
from .result import _ensure_sync_result
|
||||
from .result import AsyncResult
|
||||
from ... import exc
|
||||
from ... import inspection
|
||||
from ... import util
|
||||
from ...engine import create_engine as _create_engine
|
||||
from ...engine.base import NestedTransaction
|
||||
from ...future import Connection
|
||||
from ...future import Engine
|
||||
from ...util.concurrency import greenlet_spawn
|
||||
|
||||
|
||||
def create_async_engine(*arg, **kw):
    """Create a new async engine instance.

    Arguments passed to :func:`_asyncio.create_async_engine` are mostly
    identical to those passed to the :func:`_sa.create_engine` function.
    The specified dialect must be an asyncio-compatible dialect
    such as :ref:`dialect-postgresql-asyncpg`.

    :raises AsyncMethodRequired: if ``server_side_cursors`` is passed;
      use :meth:`_asyncio.AsyncConnection.stream` per statement instead.

    .. versionadded:: 1.4

    """

    if kw.get("server_side_cursors", False):
        raise async_exc.AsyncMethodRequired(
            "Can't set server_side_cursors for async engine globally; "
            "use the connection.stream() method for an async "
            "streaming result set"
        )
    # the async facade requires the 2.0-style ("future") Engine/Connection
    kw["future"] = True
    sync_engine = _create_engine(*arg, **kw)
    return AsyncEngine(sync_engine)
|
||||
|
||||
def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
    """Create a new AsyncEngine instance using a configuration dictionary.

    This function is analogous to the :func:`_sa.engine_from_config` function
    in SQLAlchemy Core, except that the requested dialect must be an
    asyncio-compatible dialect such as :ref:`dialect-postgresql-asyncpg`.
    The argument signature of the function is identical to that
    of :func:`_sa.engine_from_config`.

    :param configuration: dict-like of string keys; keys starting with
      ``prefix`` are used, with the prefix stripped.
    :param prefix: key prefix to match, ``"sqlalchemy."`` by default.
    :param kwargs: overrides merged on top of the configuration values.
    :raises KeyError: if no ``<prefix>url`` key is present.

    .. versionadded:: 1.4.29

    """
    # keep only the prefixed keys, with the prefix stripped off
    options = {
        key[len(prefix) :]: value
        for key, value in configuration.items()
        if key.startswith(prefix)
    }
    # values arrive as strings from config files; ask create_engine to coerce
    options["_coerce_config"] = True
    options.update(kwargs)
    url = options.pop("url")
    return create_async_engine(url, **options)
|
||||
|
||||
class AsyncConnectable:
    """Common base for the async engine/connection facade; serves as the
    event-dispatch target class for the asyncio extension."""

    # _slots_dispatch: per-instance event dispatch storage;
    # __weakref__: required so instances can be weak-referenced
    __slots__ = "_slots_dispatch", "__weakref__"
|
||||
|
||||
@util.create_proxy_methods(
    Connection,
    ":class:`_future.Connection`",
    ":class:`_asyncio.AsyncConnection`",
    classmethods=[],
    methods=[],
    attributes=[
        "closed",
        "invalidated",
        "dialect",
        "default_isolation_level",
    ],
)
class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
    """An asyncio proxy for a :class:`_engine.Connection`.

    :class:`_asyncio.AsyncConnection` is acquired using the
    :meth:`_asyncio.AsyncEngine.connect`
    method of :class:`_asyncio.AsyncEngine`::

        from sqlalchemy.ext.asyncio import create_async_engine
        engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")

        async with engine.connect() as conn:
            result = await conn.execute(select(table))

    .. versionadded:: 1.4

    """  # noqa

    # AsyncConnection is a thin proxy; no state should be added here
    # that is not retrievable from the "sync" engine / connection, e.g.
    # current transaction, info, etc. It should be possible to
    # create a new AsyncConnection that matches this one given only the
    # "sync" elements.
    __slots__ = (
        "engine",
        "sync_engine",
        "sync_connection",
    )

    def __init__(self, async_engine, sync_connection=None):
        self.engine = async_engine
        self.sync_engine = async_engine.sync_engine
        # may be None here; start() establishes it lazily
        self.sync_connection = self._assign_proxied(sync_connection)

    sync_connection: Connection
    """Reference to the sync-style :class:`_engine.Connection` this
    :class:`_asyncio.AsyncConnection` proxies requests towards.

    This instance can be used as an event target.

    .. seealso::

        :ref:`asyncio_events`
    """

    sync_engine: Engine
    """Reference to the sync-style :class:`_engine.Engine` this
    :class:`_asyncio.AsyncConnection` is associated with via its underlying
    :class:`_engine.Connection`.

    This instance can be used as an event target.

    .. seealso::

        :ref:`asyncio_events`
    """

    @classmethod
    def _regenerate_proxy_for_target(cls, target):
        # rebuild an AsyncConnection wrapper given only the sync Connection
        return AsyncConnection(
            AsyncEngine._retrieve_proxy_for_target(target.engine), target
        )

    async def start(self, is_ctxmanager=False):
        """Start this :class:`_asyncio.AsyncConnection` object's context
        outside of using a Python ``with:`` block.

        """
        if self.sync_connection:
            raise exc.InvalidRequestError("connection is already started")
        self.sync_connection = self._assign_proxied(
            await (greenlet_spawn(self.sync_engine.connect))
        )
        return self

    @property
    def connection(self):
        """Not implemented for async; call
        :meth:`_asyncio.AsyncConnection.get_raw_connection`.
        """
        raise exc.InvalidRequestError(
            "AsyncConnection.connection accessor is not implemented as the "
            "attribute may need to reconnect on an invalidated connection. "
            "Use the get_raw_connection() method."
        )

    async def get_raw_connection(self):
        """Return the pooled DBAPI-level connection in use by this
        :class:`_asyncio.AsyncConnection`.

        This is a SQLAlchemy connection-pool proxied connection
        which then has the attribute
        :attr:`_pool._ConnectionFairy.driver_connection` that refers to the
        actual driver connection. Its
        :attr:`_pool._ConnectionFairy.dbapi_connection` refers instead
        to an :class:`_engine.AdaptedConnection` instance that
        adapts the driver connection to the DBAPI protocol.

        """
        conn = self._sync_connection()

        # attribute access may trigger a reconnect; run it on the greenlet
        return await greenlet_spawn(getattr, conn, "connection")

    @property
    def _proxied(self):
        return self.sync_connection

    @property
    def info(self):
        """Return the :attr:`_engine.Connection.info` dictionary of the
        underlying :class:`_engine.Connection`.

        This dictionary is freely writable for user-defined state to be
        associated with the database connection.

        This attribute is only available if the :class:`.AsyncConnection` is
        currently connected. If the :attr:`.AsyncConnection.closed` attribute
        is ``True``, then accessing this attribute will raise
        :class:`.ResourceClosedError`.

        .. versionadded:: 1.4.0b2

        """
        return self.sync_connection.info

    def _sync_connection(self):
        # guard used by all proxying methods: raises AsyncContextNotStarted
        # if start() / "async with" has not yet established the connection
        if not self.sync_connection:
            self._raise_for_not_started()
        return self.sync_connection

    def begin(self):
        """Begin a transaction prior to autobegin occurring."""
        self._sync_connection()
        return AsyncTransaction(self)

    def begin_nested(self):
        """Begin a nested transaction and return a transaction handle."""
        self._sync_connection()
        return AsyncTransaction(self, nested=True)

    async def invalidate(self, exception=None):
        """Invalidate the underlying DBAPI connection associated with
        this :class:`_engine.Connection`.

        See the method :meth:`_engine.Connection.invalidate` for full
        detail on this method.

        """

        conn = self._sync_connection()
        return await greenlet_spawn(conn.invalidate, exception=exception)

    async def get_isolation_level(self):
        conn = self._sync_connection()
        return await greenlet_spawn(conn.get_isolation_level)

    async def set_isolation_level(self):
        # NOTE(review): this calls conn.get_isolation_level and takes no
        # level argument -- looks like a copy-paste of get_isolation_level.
        # Confirm intended behavior (isolation level is normally set via
        # execution_options(isolation_level=...)).
        conn = self._sync_connection()
        return await greenlet_spawn(conn.get_isolation_level)

    def in_transaction(self):
        """Return True if a transaction is in progress.

        .. versionadded:: 1.4.0b2

        """

        conn = self._sync_connection()

        return conn.in_transaction()

    def in_nested_transaction(self):
        """Return True if a transaction is in progress.

        .. versionadded:: 1.4.0b2

        """
        conn = self._sync_connection()

        return conn.in_nested_transaction()

    def get_transaction(self):
        """Return an :class:`.AsyncTransaction` representing the current
        transaction, if any.

        This makes use of the underlying synchronous connection's
        :meth:`_engine.Connection.get_transaction` method to get the current
        :class:`_engine.Transaction`, which is then proxied in a new
        :class:`.AsyncTransaction` object.

        .. versionadded:: 1.4.0b2

        """
        conn = self._sync_connection()

        trans = conn.get_transaction()
        if trans is not None:
            return AsyncTransaction._retrieve_proxy_for_target(trans)
        else:
            return None

    def get_nested_transaction(self):
        """Return an :class:`.AsyncTransaction` representing the current
        nested (savepoint) transaction, if any.

        This makes use of the underlying synchronous connection's
        :meth:`_engine.Connection.get_nested_transaction` method to get the
        current :class:`_engine.Transaction`, which is then proxied in a new
        :class:`.AsyncTransaction` object.

        .. versionadded:: 1.4.0b2

        """
        conn = self._sync_connection()

        trans = conn.get_nested_transaction()
        if trans is not None:
            return AsyncTransaction._retrieve_proxy_for_target(trans)
        else:
            return None

    async def execution_options(self, **opt):
        r"""Set non-SQL options for the connection which take effect
        during execution.

        This returns this :class:`_asyncio.AsyncConnection` object with
        the new options added.

        See :meth:`_future.Connection.execution_options` for full details
        on this method.

        """

        conn = self._sync_connection()
        c2 = await greenlet_spawn(conn.execution_options, **opt)
        # the 2.0-style Connection mutates in place and returns itself
        assert c2 is conn
        return self

    async def commit(self):
        """Commit the transaction that is currently in progress.

        This method commits the current transaction if one has been started.
        If no transaction was started, the method has no effect, assuming
        the connection is in a non-invalidated state.

        A transaction is begun on a :class:`_future.Connection` automatically
        whenever a statement is first executed, or when the
        :meth:`_future.Connection.begin` method is called.

        """
        conn = self._sync_connection()
        await greenlet_spawn(conn.commit)

    async def rollback(self):
        """Roll back the transaction that is currently in progress.

        This method rolls back the current transaction if one has been started.
        If no transaction was started, the method has no effect. If a
        transaction was started and the connection is in an invalidated state,
        the transaction is cleared using this method.

        A transaction is begun on a :class:`_future.Connection` automatically
        whenever a statement is first executed, or when the
        :meth:`_future.Connection.begin` method is called.


        """
        conn = self._sync_connection()
        await greenlet_spawn(conn.rollback)

    async def close(self):
        """Close this :class:`_asyncio.AsyncConnection`.

        This has the effect of also rolling back the transaction if one
        is in place.

        """
        conn = self._sync_connection()
        await greenlet_spawn(conn.close)

    async def exec_driver_sql(
        self,
        statement,
        parameters=None,
        execution_options=util.EMPTY_DICT,
    ):
        r"""Executes a driver-level SQL string and return buffered
        :class:`_engine.Result`.

        """

        conn = self._sync_connection()

        result = await greenlet_spawn(
            conn.exec_driver_sql,
            statement,
            parameters,
            execution_options,
            _require_await=True,
        )

        # buffer / adapt server-side cursors so the result is safely
        # consumable outside the greenlet
        return await _ensure_sync_result(result, self.exec_driver_sql)

    async def stream(
        self,
        statement,
        parameters=None,
        execution_options=util.EMPTY_DICT,
    ):
        """Execute a statement and return a streaming
        :class:`_asyncio.AsyncResult` object."""

        conn = self._sync_connection()

        result = await greenlet_spawn(
            conn._execute_20,
            statement,
            parameters,
            # force a server-side cursor so rows can be fetched incrementally
            util.EMPTY_DICT.merge_with(
                execution_options, {"stream_results": True}
            ),
            _require_await=True,
        )
        if not result.context._is_server_side:
            # TODO: real exception here
            assert False, "server side result expected"
        return AsyncResult(result)

    async def execute(
        self,
        statement,
        parameters=None,
        execution_options=util.EMPTY_DICT,
    ):
        r"""Executes a SQL statement construct and return a buffered
        :class:`_engine.Result`.

        :param object: The statement to be executed. This is always
         an object that is in both the :class:`_expression.ClauseElement` and
         :class:`_expression.Executable` hierarchies, including:

         * :class:`_expression.Select`
         * :class:`_expression.Insert`, :class:`_expression.Update`,
           :class:`_expression.Delete`
         * :class:`_expression.TextClause` and
           :class:`_expression.TextualSelect`
         * :class:`_schema.DDL` and objects which inherit from
           :class:`_schema.DDLElement`

        :param parameters: parameters which will be bound into the statement.
         This may be either a dictionary of parameter names to values,
         or a mutable sequence (e.g. a list) of dictionaries. When a
         list of dictionaries is passed, the underlying statement execution
         will make use of the DBAPI ``cursor.executemany()`` method.
         When a single dictionary is passed, the DBAPI ``cursor.execute()``
         method will be used.

        :param execution_options: optional dictionary of execution options,
         which will be associated with the statement execution. This
         dictionary can provide a subset of the options that are accepted
         by :meth:`_future.Connection.execution_options`.

        :return: a :class:`_engine.Result` object.

        """
        conn = self._sync_connection()

        result = await greenlet_spawn(
            conn._execute_20,
            statement,
            parameters,
            execution_options,
            _require_await=True,
        )
        return await _ensure_sync_result(result, self.execute)

    async def scalar(
        self,
        statement,
        parameters=None,
        execution_options=util.EMPTY_DICT,
    ):
        r"""Executes a SQL statement construct and returns a scalar object.

        This method is shorthand for invoking the
        :meth:`_engine.Result.scalar` method after invoking the
        :meth:`_future.Connection.execute` method. Parameters are equivalent.

        :return: a scalar Python value representing the first column of the
         first row returned.

        """
        result = await self.execute(statement, parameters, execution_options)
        return result.scalar()

    async def scalars(
        self,
        statement,
        parameters=None,
        execution_options=util.EMPTY_DICT,
    ):
        r"""Executes a SQL statement construct and returns a scalar objects.

        This method is shorthand for invoking the
        :meth:`_engine.Result.scalars` method after invoking the
        :meth:`_future.Connection.execute` method. Parameters are equivalent.

        :return: a :class:`_engine.ScalarResult` object.

        .. versionadded:: 1.4.24

        """
        result = await self.execute(statement, parameters, execution_options)
        return result.scalars()

    async def stream_scalars(
        self,
        statement,
        parameters=None,
        execution_options=util.EMPTY_DICT,
    ):
        r"""Executes a SQL statement and returns a streaming scalar result
        object.

        This method is shorthand for invoking the
        :meth:`_engine.AsyncResult.scalars` method after invoking the
        :meth:`_future.Connection.stream` method. Parameters are equivalent.

        :return: an :class:`_asyncio.AsyncScalarResult` object.

        .. versionadded:: 1.4.24

        """
        result = await self.stream(statement, parameters, execution_options)
        return result.scalars()

    async def run_sync(self, fn, *arg, **kw):
        """Invoke the given sync callable passing self as the first argument.

        This method maintains the asyncio event loop all the way through
        to the database connection by running the given callable in a
        specially instrumented greenlet.

        E.g.::

            with async_engine.begin() as conn:
                await conn.run_sync(metadata.create_all)

        .. note::

            The provided callable is invoked inline within the asyncio event
            loop, and will block on traditional IO calls. IO within this
            callable should only call into SQLAlchemy's asyncio database
            APIs which will be properly adapted to the greenlet context.

        .. seealso::

            :ref:`session_run_sync`
        """

        conn = self._sync_connection()

        return await greenlet_spawn(fn, conn, *arg, **kw)

    def __await__(self):
        return self.start().__await__()

    async def __aexit__(self, type_, value, traceback):
        # shield so that task cancellation cannot leak the sync connection
        await asyncio.shield(self.close())
|
||||
|
||||
@util.create_proxy_methods(
    Engine,
    ":class:`_future.Engine`",
    ":class:`_asyncio.AsyncEngine`",
    classmethods=[],
    methods=[
        "clear_compiled_cache",
        "update_execution_options",
        "get_execution_options",
    ],
    attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
)
class AsyncEngine(ProxyComparable, AsyncConnectable):
    """An asyncio proxy for a :class:`_engine.Engine`.

    :class:`_asyncio.AsyncEngine` is acquired using the
    :func:`_asyncio.create_async_engine` function::

        from sqlalchemy.ext.asyncio import create_async_engine
        engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")

    .. versionadded:: 1.4

    """  # noqa

    # AsyncEngine is a thin proxy; no state should be added here
    # that is not retrievable from the "sync" engine / connection, e.g.
    # current transaction, info, etc. It should be possible to
    # create a new AsyncEngine that matches this one given only the
    # "sync" elements.
    __slots__ = ("sync_engine", "_proxied")

    _connection_cls = AsyncConnection

    _option_cls: type

    class _trans_ctx(StartableContext):
        # context manager returned by AsyncEngine.begin(): connects, begins
        # a transaction, and finalizes both on exit
        def __init__(self, conn):
            self.conn = conn

        async def start(self, is_ctxmanager=False):
            await self.conn.start(is_ctxmanager=is_ctxmanager)
            self.transaction = self.conn.begin()
            await self.transaction.__aenter__()

            return self.conn

        async def __aexit__(self, type_, value, traceback):
            async def go():
                await self.transaction.__aexit__(type_, value, traceback)
                await self.conn.close()

            # shield so task cancellation cannot skip the cleanup
            await asyncio.shield(go())

    def __init__(self, sync_engine):
        if not sync_engine.dialect.is_async:
            raise exc.InvalidRequestError(
                "The asyncio extension requires an async driver to be used. "
                f"The loaded {sync_engine.dialect.driver!r} is not async."
            )
        self.sync_engine = self._proxied = self._assign_proxied(sync_engine)

    sync_engine: Engine
    """Reference to the sync-style :class:`_engine.Engine` this
    :class:`_asyncio.AsyncEngine` proxies requests towards.

    This instance can be used as an event target.

    .. seealso::

        :ref:`asyncio_events`
    """

    @classmethod
    def _regenerate_proxy_for_target(cls, target):
        # rebuild an AsyncEngine wrapper given only the sync Engine
        return AsyncEngine(target)

    def begin(self):
        """Return a context manager which when entered will deliver an
        :class:`_asyncio.AsyncConnection` with an
        :class:`_asyncio.AsyncTransaction` established.

        E.g.::

            async with async_engine.begin() as conn:
                await conn.execute(
                    text("insert into table (x, y, z) values (1, 2, 3)")
                )
                await conn.execute(text("my_special_procedure(5)"))


        """
        conn = self.connect()
        return self._trans_ctx(conn)

    def connect(self):
        """Return an :class:`_asyncio.AsyncConnection` object.

        The :class:`_asyncio.AsyncConnection` will procure a database
        connection from the underlying connection pool when it is entered
        as an async context manager::

            async with async_engine.connect() as conn:
                result = await conn.execute(select(user_table))

        The :class:`_asyncio.AsyncConnection` may also be started outside of a
        context manager by invoking its :meth:`_asyncio.AsyncConnection.start`
        method.

        """

        return self._connection_cls(self)

    async def raw_connection(self):
        """Return a "raw" DBAPI connection from the connection pool.

        .. seealso::

            :ref:`dbapi_connections`

        """
        return await greenlet_spawn(self.sync_engine.raw_connection)

    def execution_options(self, **opt):
        """Return a new :class:`_asyncio.AsyncEngine` that will provide
        :class:`_asyncio.AsyncConnection` objects with the given execution
        options.

        Proxied from :meth:`_future.Engine.execution_options`. See that
        method for details.

        """

        return AsyncEngine(self.sync_engine.execution_options(**opt))

    async def dispose(self):
        """Dispose of the connection pool used by this
        :class:`_asyncio.AsyncEngine`.

        This will close all connection pool connections that are
        **currently checked in**. See the documentation for the underlying
        :meth:`_future.Engine.dispose` method for further notes.

        .. seealso::

            :meth:`_future.Engine.dispose`

        """

        await greenlet_spawn(self.sync_engine.dispose)
|
||||
|
||||
class AsyncTransaction(ProxyComparable, StartableContext):
    """An asyncio proxy for a :class:`_engine.Transaction`."""

    __slots__ = ("connection", "sync_transaction", "nested")

    def __init__(self, connection, nested=False):
        self.connection = connection  # AsyncConnection
        self.sync_transaction = None  # sqlalchemy.engine.Transaction
        self.nested = nested

    @classmethod
    def _regenerate_proxy_for_target(cls, target):
        # rebuild an AsyncTransaction wrapper given only the sync Transaction
        sync_connection = target.connection
        sync_transaction = target
        nested = isinstance(target, NestedTransaction)

        async_connection = AsyncConnection._retrieve_proxy_for_target(
            sync_connection
        )
        assert async_connection is not None

        # bypass __init__ so the already-begun sync transaction can be
        # attached directly rather than started via start()
        obj = cls.__new__(cls)
        obj.connection = async_connection
        obj.sync_transaction = obj._assign_proxied(sync_transaction)
        obj.nested = nested
        return obj

    def _sync_transaction(self):
        # guard: raises AsyncContextNotStarted until start() has run
        if not self.sync_transaction:
            self._raise_for_not_started()
        return self.sync_transaction

    @property
    def _proxied(self):
        return self.sync_transaction

    @property
    def is_valid(self):
        return self._sync_transaction().is_valid

    @property
    def is_active(self):
        return self._sync_transaction().is_active

    async def close(self):
        """Close this :class:`.Transaction`.

        If this transaction is the base transaction in a begin/commit
        nesting, the transaction will rollback(). Otherwise, the
        method returns.

        This is used to cancel a Transaction without affecting the scope of
        an enclosing transaction.

        """
        await greenlet_spawn(self._sync_transaction().close)

    async def rollback(self):
        """Roll back this :class:`.Transaction`."""
        await greenlet_spawn(self._sync_transaction().rollback)

    async def commit(self):
        """Commit this :class:`.Transaction`."""

        await greenlet_spawn(self._sync_transaction().commit)

    async def start(self, is_ctxmanager=False):
        """Start this :class:`_asyncio.AsyncTransaction` object's context
        outside of using a Python ``with:`` block.

        """

        self.sync_transaction = self._assign_proxied(
            await greenlet_spawn(
                self.connection._sync_connection().begin_nested
                if self.nested
                else self.connection._sync_connection().begin
            )
        )
        if is_ctxmanager:
            # enter the sync transaction's own context so its __exit__
            # commit/rollback logic runs from __aexit__ below
            self.sync_transaction.__enter__()
        return self

    async def __aexit__(self, type_, value, traceback):
        await greenlet_spawn(
            self._sync_transaction().__exit__, type_, value, traceback
        )
|
||||
|
||||
def _get_sync_engine_or_connection(async_engine):
    """Return the sync-style Engine or Connection wrapped by an async proxy.

    :param async_engine: an :class:`.AsyncConnection` or
      :class:`.AsyncEngine` (or anything exposing ``sync_engine``).
    :raises sqlalchemy.exc.ArgumentError: if the object is neither.
    """
    if isinstance(async_engine, AsyncConnection):
        return async_engine.sync_connection

    try:
        return async_engine.sync_engine
    except AttributeError as e:
        raise exc.ArgumentError(
            "AsyncEngine expected, got %r" % async_engine
        ) from e
|
||||
|
||||
@inspection._inspects(AsyncConnection)
def _no_insp_for_async_conn_yet(subject):
    """Registered inspection hook: make ``inspect(AsyncConnection)`` fail
    with guidance toward ``run_sync`` instead of failing silently."""
    raise exc.NoInspectionAvailable(
        "Inspection on an AsyncConnection is currently not supported. "
        "Please use ``run_sync`` to pass a callable where it's possible "
        "to call ``inspect`` on the passed connection.",
        code="xd3s",
    )
|
||||
|
||||
@inspection._inspects(AsyncEngine)
def _no_insp_for_async_engine_xyet(subject):
    """Registered inspection hook: make ``inspect(AsyncEngine)`` fail
    with guidance toward obtaining a connection and using ``run_sync``.
    """
    # NOTE(review): "_xyet" looks like a typo for "_yet"; the name is
    # private and unreferenced, so it is preserved as-is.
    raise exc.NoInspectionAvailable(
        "Inspection on an AsyncEngine is currently not supported. "
        "Please obtain a connection then use ``conn.run_sync`` to pass a "
        "callable where it's possible to call ``inspect`` on the "
        "passed connection.",
        code="xd3s",
    )
|
||||
44
lib/sqlalchemy/ext/asyncio/events.py
Normal file
44
lib/sqlalchemy/ext/asyncio/events.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# ext/asyncio/events.py
|
||||
# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from .engine import AsyncConnectable
|
||||
from .session import AsyncSession
|
||||
from ...engine import events as engine_event
|
||||
from ...orm import events as orm_event
|
||||
|
||||
|
||||
class AsyncConnectionEvents(engine_event.ConnectionEvents):
    """Placeholder event target for async connectables.

    Async engine/connection events are not implemented; any attempt to
    listen on the async facade raises, directing users to the
    ``sync_engine`` / ``sync_connection`` attributes instead.
    """

    _target_class_doc = "SomeEngine"
    _dispatch_target = AsyncConnectable

    @classmethod
    def _no_async_engine_events(cls):
        # Single place for the "not supported" error message.
        raise NotImplementedError(
            "asynchronous events are not implemented at this time. Apply "
            "synchronous listeners to the AsyncEngine.sync_engine or "
            "AsyncConnection.sync_connection attributes."
        )

    @classmethod
    def _listen(cls, event_key, retval=False):
        # Reject every listener registration against the async facade.
        cls._no_async_engine_events()
|
||||
|
||||
|
||||
class AsyncSessionEvents(orm_event.SessionEvents):
    """Placeholder event target for :class:`.AsyncSession`.

    Async session events are not implemented; any attempt to listen on
    the async facade raises, directing users to ``sync_session``.
    """

    _target_class_doc = "SomeSession"
    _dispatch_target = AsyncSession

    @classmethod
    def _no_async_engine_events(cls):
        # Single place for the "not supported" error message.
        raise NotImplementedError(
            "asynchronous events are not implemented at this time. Apply "
            "synchronous listeners to the AsyncSession.sync_session."
        )

    @classmethod
    def _listen(cls, event_key, retval=False):
        # Reject every listener registration against the async facade.
        cls._no_async_engine_events()
|
||||
21
lib/sqlalchemy/ext/asyncio/exc.py
Normal file
21
lib/sqlalchemy/ext/asyncio/exc.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# ext/asyncio/exc.py
|
||||
# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from ... import exc
|
||||
|
||||
|
||||
class AsyncMethodRequired(exc.InvalidRequestError):
    """An API can't be used because its result would not be
    compatible with async."""
|
||||
|
||||
|
||||
class AsyncContextNotStarted(exc.InvalidRequestError):
    """A startable context manager has not been started."""
|
||||
|
||||
|
||||
class AsyncContextAlreadyStarted(exc.InvalidRequestError):
    """A startable context manager is already started."""
|
||||
671
lib/sqlalchemy/ext/asyncio/result.py
Normal file
671
lib/sqlalchemy/ext/asyncio/result.py
Normal file
@@ -0,0 +1,671 @@
|
||||
# ext/asyncio/result.py
|
||||
# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
import operator
|
||||
|
||||
from . import exc as async_exc
|
||||
from ...engine.result import _NO_ROW
|
||||
from ...engine.result import FilterResult
|
||||
from ...engine.result import FrozenResult
|
||||
from ...engine.result import MergedResult
|
||||
from ...sql.base import _generative
|
||||
from ...util.concurrency import greenlet_spawn
|
||||
|
||||
|
||||
class AsyncCommon(FilterResult):
    """Common base for the async result wrappers in this module."""

    async def close(self):
        """Close this result."""

        # Closing the underlying Result may touch the DBAPI cursor, so
        # run it in a greenlet to keep the event loop unblocked.
        await greenlet_spawn(self._real_result.close)
|
||||
|
||||
|
||||
class AsyncResult(AsyncCommon):
|
||||
"""An asyncio wrapper around a :class:`_result.Result` object.
|
||||
|
||||
The :class:`_asyncio.AsyncResult` only applies to statement executions that
|
||||
use a server-side cursor. It is returned only from the
|
||||
:meth:`_asyncio.AsyncConnection.stream` and
|
||||
:meth:`_asyncio.AsyncSession.stream` methods.
|
||||
|
||||
.. note:: As is the case with :class:`_engine.Result`, this object is
|
||||
used for ORM results returned by :meth:`_asyncio.AsyncSession.execute`,
|
||||
which can yield instances of ORM mapped objects either individually or
|
||||
within tuple-like rows. Note that these result objects do not
|
||||
deduplicate instances or rows automatically as is the case with the
|
||||
legacy :class:`_orm.Query` object. For in-Python de-duplication of
|
||||
instances or rows, use the :meth:`_asyncio.AsyncResult.unique` modifier
|
||||
method.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, real_result):
    # Wrap a synchronous Result; mirror its metadata and unique-filter
    # state so filtering methods on this facade behave identically.
    self._real_result = real_result

    self._metadata = real_result._metadata
    self._unique_filter_state = real_result._unique_filter_state

    # BaseCursorResult pre-generates the "_row_getter".  Use that
    # if available rather than building a second one.
    if "_row_getter" in real_result.__dict__:
        self._set_memoized_attribute(
            "_row_getter", real_result.__dict__["_row_getter"]
        )
|
||||
|
||||
def keys(self):
|
||||
"""Return the :meth:`_engine.Result.keys` collection from the
|
||||
underlying :class:`_engine.Result`.
|
||||
|
||||
"""
|
||||
return self._metadata.keys
|
||||
|
||||
@_generative
|
||||
def unique(self, strategy=None):
|
||||
"""Apply unique filtering to the objects returned by this
|
||||
:class:`_asyncio.AsyncResult`.
|
||||
|
||||
Refer to :meth:`_engine.Result.unique` in the synchronous
|
||||
SQLAlchemy API for a complete behavioral description.
|
||||
|
||||
|
||||
"""
|
||||
self._unique_filter_state = (set(), strategy)
|
||||
|
||||
def columns(self, *col_expressions):
|
||||
r"""Establish the columns that should be returned in each row.
|
||||
|
||||
Refer to :meth:`_engine.Result.columns` in the synchronous
|
||||
SQLAlchemy API for a complete behavioral description.
|
||||
|
||||
|
||||
"""
|
||||
return self._column_slices(col_expressions)
|
||||
|
||||
async def partitions(self, size=None):
    """Iterate through sub-lists of rows of the size given.

    An async iterator is returned::

        async def scroll_results(connection):
            result = await connection.stream(select(users_table))

            async for partition in result.partitions(100):
                print("list of rows: %s" % partition)

    :param size: number of rows per partition; when None the result's
     default fetch size applies.

    .. seealso::

        :meth:`_engine.Result.partitions`

    """

    getter = self._manyrow_getter

    # Keep yielding batches until an empty one signals exhaustion;
    # each blocking fetch runs in a greenlet off the event loop.
    while True:
        partition = await greenlet_spawn(getter, self, size)
        if partition:
            yield partition
        else:
            break
|
||||
|
||||
async def fetchone(self):
    """Fetch one row.

    When all rows are exhausted, returns None.

    Provided for backwards compatibility with SQLAlchemy 1.x.x; to
    fetch only the first row use :meth:`_engine.Result.first`, or
    iterate the result directly for all rows.

    :return: a :class:`.Row` object if no filters are applied, or None
     if no rows remain.

    """
    fetched = await greenlet_spawn(self._onerow_getter, self)
    return None if fetched is _NO_ROW else fetched
|
||||
|
||||
async def fetchmany(self, size=None):
|
||||
"""Fetch many rows.
|
||||
|
||||
When all rows are exhausted, returns an empty list.
|
||||
|
||||
This method is provided for backwards compatibility with
|
||||
SQLAlchemy 1.x.x.
|
||||
|
||||
To fetch rows in groups, use the
|
||||
:meth:`._asyncio.AsyncResult.partitions` method.
|
||||
|
||||
:return: a list of :class:`.Row` objects.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_asyncio.AsyncResult.partitions`
|
||||
|
||||
"""
|
||||
|
||||
return await greenlet_spawn(self._manyrow_getter, self, size)
|
||||
|
||||
async def all(self):
|
||||
"""Return all rows in a list.
|
||||
|
||||
Closes the result set after invocation. Subsequent invocations
|
||||
will return an empty list.
|
||||
|
||||
:return: a list of :class:`.Row` objects.
|
||||
|
||||
"""
|
||||
|
||||
return await greenlet_spawn(self._allrows)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
row = await greenlet_spawn(self._onerow_getter, self)
|
||||
if row is _NO_ROW:
|
||||
raise StopAsyncIteration()
|
||||
else:
|
||||
return row
|
||||
|
||||
async def first(self):
|
||||
"""Fetch the first row or None if no row is present.
|
||||
|
||||
Closes the result set and discards remaining rows.
|
||||
|
||||
.. note:: This method returns one **row**, e.g. tuple, by default. To
|
||||
return exactly one single scalar value, that is, the first column of
|
||||
the first row, use the :meth:`_asyncio.AsyncResult.scalar` method,
|
||||
or combine :meth:`_asyncio.AsyncResult.scalars` and
|
||||
:meth:`_asyncio.AsyncResult.first`.
|
||||
|
||||
:return: a :class:`.Row` object, or None
|
||||
if no rows remain.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_asyncio.AsyncResult.scalar`
|
||||
|
||||
:meth:`_asyncio.AsyncResult.one`
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, False, False, False)
|
||||
|
||||
async def one_or_none(self):
|
||||
"""Return at most one result or raise an exception.
|
||||
|
||||
Returns ``None`` if the result has no rows.
|
||||
Raises :class:`.MultipleResultsFound`
|
||||
if multiple rows are returned.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
:return: The first :class:`.Row` or None if no row is available.
|
||||
|
||||
:raises: :class:`.MultipleResultsFound`
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_asyncio.AsyncResult.first`
|
||||
|
||||
:meth:`_asyncio.AsyncResult.one`
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, False, False)
|
||||
|
||||
async def scalar_one(self):
|
||||
"""Return exactly one scalar result or raise an exception.
|
||||
|
||||
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
|
||||
then :meth:`_asyncio.AsyncResult.one`.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_asyncio.AsyncResult.one`
|
||||
|
||||
:meth:`_asyncio.AsyncResult.scalars`
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, True, True)
|
||||
|
||||
async def scalar_one_or_none(self):
|
||||
"""Return exactly one or no scalar result.
|
||||
|
||||
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
|
||||
then :meth:`_asyncio.AsyncResult.one_or_none`.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_asyncio.AsyncResult.one_or_none`
|
||||
|
||||
:meth:`_asyncio.AsyncResult.scalars`
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, False, True)
|
||||
|
||||
async def one(self):
|
||||
"""Return exactly one row or raise an exception.
|
||||
|
||||
Raises :class:`.NoResultFound` if the result returns no
|
||||
rows, or :class:`.MultipleResultsFound` if multiple rows
|
||||
would be returned.
|
||||
|
||||
.. note:: This method returns one **row**, e.g. tuple, by default.
|
||||
To return exactly one single scalar value, that is, the first
|
||||
column of the first row, use the
|
||||
:meth:`_asyncio.AsyncResult.scalar_one` method, or combine
|
||||
:meth:`_asyncio.AsyncResult.scalars` and
|
||||
:meth:`_asyncio.AsyncResult.one`.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
:return: The first :class:`.Row`.
|
||||
|
||||
:raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound`
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_asyncio.AsyncResult.first`
|
||||
|
||||
:meth:`_asyncio.AsyncResult.one_or_none`
|
||||
|
||||
:meth:`_asyncio.AsyncResult.scalar_one`
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, True, False)
|
||||
|
||||
async def scalar(self):
|
||||
"""Fetch the first column of the first row, and close the result set.
|
||||
|
||||
Returns None if there are no rows to fetch.
|
||||
|
||||
No validation is performed to test if additional rows remain.
|
||||
|
||||
After calling this method, the object is fully closed,
|
||||
e.g. the :meth:`_engine.CursorResult.close`
|
||||
method will have been called.
|
||||
|
||||
:return: a Python scalar value , or None if no rows remain.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, False, False, True)
|
||||
|
||||
async def freeze(self):
|
||||
"""Return a callable object that will produce copies of this
|
||||
:class:`_asyncio.AsyncResult` when invoked.
|
||||
|
||||
The callable object returned is an instance of
|
||||
:class:`_engine.FrozenResult`.
|
||||
|
||||
This is used for result set caching. The method must be called
|
||||
on the result when it has been unconsumed, and calling the method
|
||||
will consume the result fully. When the :class:`_engine.FrozenResult`
|
||||
is retrieved from a cache, it can be called any number of times where
|
||||
it will produce a new :class:`_engine.Result` object each time
|
||||
against its stored set of rows.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`do_orm_execute_re_executing` - example usage within the
|
||||
ORM to implement a result-set cache.
|
||||
|
||||
"""
|
||||
|
||||
return await greenlet_spawn(FrozenResult, self)
|
||||
|
||||
def merge(self, *others):
|
||||
"""Merge this :class:`_asyncio.AsyncResult` with other compatible
|
||||
result objects.
|
||||
|
||||
The object returned is an instance of :class:`_engine.MergedResult`,
|
||||
which will be composed of iterators from the given result
|
||||
objects.
|
||||
|
||||
The new result will use the metadata from this result object.
|
||||
The subsequent result objects must be against an identical
|
||||
set of result / cursor metadata, otherwise the behavior is
|
||||
undefined.
|
||||
|
||||
"""
|
||||
return MergedResult(self._metadata, (self,) + others)
|
||||
|
||||
def scalars(self, index=0):
|
||||
"""Return an :class:`_asyncio.AsyncScalarResult` filtering object which
|
||||
will return single elements rather than :class:`_row.Row` objects.
|
||||
|
||||
Refer to :meth:`_result.Result.scalars` in the synchronous
|
||||
SQLAlchemy API for a complete behavioral description.
|
||||
|
||||
:param index: integer or row key indicating the column to be fetched
|
||||
from each row, defaults to ``0`` indicating the first column.
|
||||
|
||||
:return: a new :class:`_asyncio.AsyncScalarResult` filtering object
|
||||
referring to this :class:`_asyncio.AsyncResult` object.
|
||||
|
||||
"""
|
||||
return AsyncScalarResult(self._real_result, index)
|
||||
|
||||
def mappings(self):
|
||||
"""Apply a mappings filter to returned rows, returning an instance of
|
||||
:class:`_asyncio.AsyncMappingResult`.
|
||||
|
||||
When this filter is applied, fetching rows will return
|
||||
:class:`.RowMapping` objects instead of :class:`.Row` objects.
|
||||
|
||||
Refer to :meth:`_result.Result.mappings` in the synchronous
|
||||
SQLAlchemy API for a complete behavioral description.
|
||||
|
||||
:return: a new :class:`_asyncio.AsyncMappingResult` filtering object
|
||||
referring to the underlying :class:`_result.Result` object.
|
||||
|
||||
"""
|
||||
|
||||
return AsyncMappingResult(self._real_result)
|
||||
|
||||
|
||||
class AsyncScalarResult(AsyncCommon):
|
||||
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
|
||||
rather than :class:`_row.Row` values.
|
||||
|
||||
The :class:`_asyncio.AsyncScalarResult` object is acquired by calling the
|
||||
:meth:`_asyncio.AsyncResult.scalars` method.
|
||||
|
||||
Refer to the :class:`_result.ScalarResult` object in the synchronous
|
||||
SQLAlchemy API for a complete behavioral description.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
"""
|
||||
|
||||
_generate_rows = False
|
||||
|
||||
def __init__(self, real_result, index):
|
||||
self._real_result = real_result
|
||||
|
||||
if real_result._source_supports_scalars:
|
||||
self._metadata = real_result._metadata
|
||||
self._post_creational_filter = None
|
||||
else:
|
||||
self._metadata = real_result._metadata._reduce([index])
|
||||
self._post_creational_filter = operator.itemgetter(0)
|
||||
|
||||
self._unique_filter_state = real_result._unique_filter_state
|
||||
|
||||
def unique(self, strategy=None):
|
||||
"""Apply unique filtering to the objects returned by this
|
||||
:class:`_asyncio.AsyncScalarResult`.
|
||||
|
||||
See :meth:`_asyncio.AsyncResult.unique` for usage details.
|
||||
|
||||
"""
|
||||
self._unique_filter_state = (set(), strategy)
|
||||
return self
|
||||
|
||||
async def partitions(self, size=None):
|
||||
"""Iterate through sub-lists of elements of the size given.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
|
||||
scalar values, rather than :class:`_result.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
|
||||
getter = self._manyrow_getter
|
||||
|
||||
while True:
|
||||
partition = await greenlet_spawn(getter, self, size)
|
||||
if partition:
|
||||
yield partition
|
||||
else:
|
||||
break
|
||||
|
||||
async def fetchall(self):
|
||||
"""A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method."""
|
||||
|
||||
return await greenlet_spawn(self._allrows)
|
||||
|
||||
async def fetchmany(self, size=None):
|
||||
"""Fetch many objects.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
|
||||
scalar values, rather than :class:`_result.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._manyrow_getter, self, size)
|
||||
|
||||
async def all(self):
|
||||
"""Return all scalar values in a list.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
|
||||
scalar values, rather than :class:`_result.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._allrows)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
row = await greenlet_spawn(self._onerow_getter, self)
|
||||
if row is _NO_ROW:
|
||||
raise StopAsyncIteration()
|
||||
else:
|
||||
return row
|
||||
|
||||
async def first(self):
|
||||
"""Fetch the first object or None if no object is present.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.first` except that
|
||||
scalar values, rather than :class:`_result.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, False, False, False)
|
||||
|
||||
async def one_or_none(self):
|
||||
"""Return at most one object or raise an exception.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
|
||||
scalar values, rather than :class:`_result.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, False, False)
|
||||
|
||||
async def one(self):
|
||||
"""Return exactly one object or raise an exception.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.one` except that
|
||||
scalar values, rather than :class:`_result.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, True, False)
|
||||
|
||||
|
||||
class AsyncMappingResult(AsyncCommon):
|
||||
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary
|
||||
values rather than :class:`_engine.Row` values.
|
||||
|
||||
The :class:`_asyncio.AsyncMappingResult` object is acquired by calling the
|
||||
:meth:`_asyncio.AsyncResult.mappings` method.
|
||||
|
||||
Refer to the :class:`_result.MappingResult` object in the synchronous
|
||||
SQLAlchemy API for a complete behavioral description.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
"""
|
||||
|
||||
_generate_rows = True
|
||||
|
||||
_post_creational_filter = operator.attrgetter("_mapping")
|
||||
|
||||
def __init__(self, result):
|
||||
self._real_result = result
|
||||
self._unique_filter_state = result._unique_filter_state
|
||||
self._metadata = result._metadata
|
||||
if result._source_supports_scalars:
|
||||
self._metadata = self._metadata._reduce([0])
|
||||
|
||||
def keys(self):
|
||||
"""Return an iterable view which yields the string keys that would
|
||||
be represented by each :class:`.Row`.
|
||||
|
||||
The view also can be tested for key containment using the Python
|
||||
``in`` operator, which will test both for the string keys represented
|
||||
in the view, as well as for alternate keys such as column objects.
|
||||
|
||||
.. versionchanged:: 1.4 a key view object is returned rather than a
|
||||
plain list.
|
||||
|
||||
|
||||
"""
|
||||
return self._metadata.keys
|
||||
|
||||
def unique(self, strategy=None):
|
||||
"""Apply unique filtering to the objects returned by this
|
||||
:class:`_asyncio.AsyncMappingResult`.
|
||||
|
||||
See :meth:`_asyncio.AsyncResult.unique` for usage details.
|
||||
|
||||
"""
|
||||
self._unique_filter_state = (set(), strategy)
|
||||
return self
|
||||
|
||||
def columns(self, *col_expressions):
|
||||
r"""Establish the columns that should be returned in each row."""
|
||||
return self._column_slices(col_expressions)
|
||||
|
||||
async def partitions(self, size=None):
|
||||
"""Iterate through sub-lists of elements of the size given.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
|
||||
mapping values, rather than :class:`_result.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
|
||||
getter = self._manyrow_getter
|
||||
|
||||
while True:
|
||||
partition = await greenlet_spawn(getter, self, size)
|
||||
if partition:
|
||||
yield partition
|
||||
else:
|
||||
break
|
||||
|
||||
async def fetchall(self):
|
||||
"""A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
|
||||
|
||||
return await greenlet_spawn(self._allrows)
|
||||
|
||||
async def fetchone(self):
|
||||
"""Fetch one object.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that
|
||||
mapping values, rather than :class:`_result.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
|
||||
row = await greenlet_spawn(self._onerow_getter, self)
|
||||
if row is _NO_ROW:
|
||||
return None
|
||||
else:
|
||||
return row
|
||||
|
||||
async def fetchmany(self, size=None):
|
||||
"""Fetch many objects.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
|
||||
mapping values, rather than :class:`_result.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
|
||||
return await greenlet_spawn(self._manyrow_getter, self, size)
|
||||
|
||||
async def all(self):
|
||||
"""Return all scalar values in a list.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
|
||||
mapping values, rather than :class:`_result.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
|
||||
return await greenlet_spawn(self._allrows)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
row = await greenlet_spawn(self._onerow_getter, self)
|
||||
if row is _NO_ROW:
|
||||
raise StopAsyncIteration()
|
||||
else:
|
||||
return row
|
||||
|
||||
async def first(self):
|
||||
"""Fetch the first object or None if no object is present.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.first` except that
|
||||
mapping values, rather than :class:`_result.Row` objects,
|
||||
are returned.
|
||||
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, False, False, False)
|
||||
|
||||
async def one_or_none(self):
|
||||
"""Return at most one object or raise an exception.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
|
||||
mapping values, rather than :class:`_result.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, False, False)
|
||||
|
||||
async def one(self):
|
||||
"""Return exactly one object or raise an exception.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.one` except that
|
||||
mapping values, rather than :class:`_result.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, True, False)
|
||||
|
||||
|
||||
async def _ensure_sync_result(result, calling_method):
|
||||
if not result._is_cursor:
|
||||
cursor_result = getattr(result, "raw", None)
|
||||
else:
|
||||
cursor_result = result
|
||||
if cursor_result and cursor_result.context._is_server_side:
|
||||
await greenlet_spawn(cursor_result.close)
|
||||
raise async_exc.AsyncMethodRequired(
|
||||
"Can't use the %s.%s() method with a "
|
||||
"server-side cursor. "
|
||||
"Use the %s.stream() method for an async "
|
||||
"streaming result set."
|
||||
% (
|
||||
calling_method.__self__.__class__.__name__,
|
||||
calling_method.__name__,
|
||||
calling_method.__self__.__class__.__name__,
|
||||
)
|
||||
)
|
||||
return result
|
||||
107
lib/sqlalchemy/ext/asyncio/scoping.py
Normal file
107
lib/sqlalchemy/ext/asyncio/scoping.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# ext/asyncio/scoping.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from .session import AsyncSession
|
||||
from ...orm.scoping import ScopedSessionMixin
|
||||
from ...util import create_proxy_methods
|
||||
from ...util import ScopedRegistry
|
||||
|
||||
|
||||
@create_proxy_methods(
|
||||
AsyncSession,
|
||||
":class:`_asyncio.AsyncSession`",
|
||||
":class:`_asyncio.scoping.async_scoped_session`",
|
||||
classmethods=["close_all", "object_session", "identity_key"],
|
||||
methods=[
|
||||
"__contains__",
|
||||
"__iter__",
|
||||
"add",
|
||||
"add_all",
|
||||
"begin",
|
||||
"begin_nested",
|
||||
"close",
|
||||
"commit",
|
||||
"connection",
|
||||
"delete",
|
||||
"execute",
|
||||
"expire",
|
||||
"expire_all",
|
||||
"expunge",
|
||||
"expunge_all",
|
||||
"flush",
|
||||
"get",
|
||||
"get_bind",
|
||||
"is_modified",
|
||||
"invalidate",
|
||||
"merge",
|
||||
"refresh",
|
||||
"rollback",
|
||||
"scalar",
|
||||
"scalars",
|
||||
"stream",
|
||||
"stream_scalars",
|
||||
],
|
||||
attributes=[
|
||||
"bind",
|
||||
"dirty",
|
||||
"deleted",
|
||||
"new",
|
||||
"identity_map",
|
||||
"is_active",
|
||||
"autoflush",
|
||||
"no_autoflush",
|
||||
"info",
|
||||
],
|
||||
)
|
||||
class async_scoped_session(ScopedSessionMixin):
|
||||
"""Provides scoped management of :class:`.AsyncSession` objects.
|
||||
|
||||
See the section :ref:`asyncio_scoped_session` for usage details.
|
||||
|
||||
.. versionadded:: 1.4.19
|
||||
|
||||
|
||||
"""
|
||||
|
||||
_support_async = True
|
||||
|
||||
def __init__(self, session_factory, scopefunc):
    """Construct a new :class:`_asyncio.async_scoped_session`.

    :param session_factory: a factory to create new :class:`_asyncio.AsyncSession`
     instances. This is usually, but not necessarily, an instance
     of :class:`_orm.sessionmaker` which itself was passed the
     :class:`_asyncio.AsyncSession` to its :paramref:`_orm.sessionmaker.class_`
     parameter::

        async_session_factory = sessionmaker(some_async_engine, class_= AsyncSession)
        AsyncSession = async_scoped_session(async_session_factory, scopefunc=current_task)

    :param scopefunc: function which defines
     the current scope. A function such as ``asyncio.current_task``
     may be useful here.

    """  # noqa: E501

    self.session_factory = session_factory
    # The registry holds at most one AsyncSession per scope key, as
    # computed by scopefunc at lookup time.
    self.registry = ScopedRegistry(session_factory, scopefunc)
|
||||
|
||||
@property
def _proxied(self):
    # The scope-local AsyncSession that proxy methods delegate to;
    # calling the registry creates one for the current scope if absent.
    return self.registry()
|
||||
|
||||
async def remove(self):
    """Dispose of the current :class:`.AsyncSession`, if present.

    Unlike ``scoped_session.remove()``, this coroutine awaits the
    ``close()`` of the scope's :class:`.AsyncSession` before clearing
    the registry.

    """

    registry = self.registry
    if registry.has():
        # close the existing session for this scope before discarding
        await registry().close()
    registry.clear()
|
||||
759
lib/sqlalchemy/ext/asyncio/session.py
Normal file
759
lib/sqlalchemy/ext/asyncio/session.py
Normal file
@@ -0,0 +1,759 @@
|
||||
# ext/asyncio/session.py
|
||||
# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
import asyncio
|
||||
|
||||
from . import engine
|
||||
from . import result as _result
|
||||
from .base import ReversibleProxy
|
||||
from .base import StartableContext
|
||||
from .result import _ensure_sync_result
|
||||
from ... import util
|
||||
from ...orm import object_session
|
||||
from ...orm import Session
|
||||
from ...orm import state as _instance_state
|
||||
from ...util.concurrency import greenlet_spawn
|
||||
|
||||
_EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True})
|
||||
_STREAM_OPTIONS = util.immutabledict({"stream_results": True})
|
||||
|
||||
|
||||
@util.create_proxy_methods(
|
||||
Session,
|
||||
":class:`_orm.Session`",
|
||||
":class:`_asyncio.AsyncSession`",
|
||||
classmethods=["object_session", "identity_key"],
|
||||
methods=[
|
||||
"__contains__",
|
||||
"__iter__",
|
||||
"add",
|
||||
"add_all",
|
||||
"expire",
|
||||
"expire_all",
|
||||
"expunge",
|
||||
"expunge_all",
|
||||
"is_modified",
|
||||
"in_transaction",
|
||||
"in_nested_transaction",
|
||||
],
|
||||
attributes=[
|
||||
"dirty",
|
||||
"deleted",
|
||||
"new",
|
||||
"identity_map",
|
||||
"is_active",
|
||||
"autoflush",
|
||||
"no_autoflush",
|
||||
"info",
|
||||
],
|
||||
)
|
||||
class AsyncSession(ReversibleProxy):
|
||||
"""Asyncio version of :class:`_orm.Session`.
|
||||
|
||||
The :class:`_asyncio.AsyncSession` is a proxy for a traditional
|
||||
:class:`_orm.Session` instance.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
To use an :class:`_asyncio.AsyncSession` with custom :class:`_orm.Session`
|
||||
implementations, see the
|
||||
:paramref:`_asyncio.AsyncSession.sync_session_class` parameter.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
_is_asyncio = True
|
||||
|
||||
dispatch = None
|
||||
|
||||
def __init__(self, bind=None, binds=None, sync_session_class=None, **kw):
    r"""Construct a new :class:`_asyncio.AsyncSession`.

    All parameters other than ``sync_session_class`` are passed to the
    ``sync_session_class`` callable directly to instantiate a new
    :class:`_orm.Session`. Refer to :meth:`_orm.Session.__init__` for
    parameter documentation.

    :param sync_session_class:
      A :class:`_orm.Session` subclass or other callable which will be used
      to construct the :class:`_orm.Session` which will be proxied. This
      parameter may be used to provide custom :class:`_orm.Session`
      subclasses. Defaults to the
      :attr:`_asyncio.AsyncSession.sync_session_class` class-level
      attribute.

      .. versionadded:: 1.4.24

    """
    # the proxied Session always runs in 2.0 ("future") mode
    kw["future"] = True
    if bind:
        self.bind = bind
        # unwrap the async facade; the sync Session binds to the real
        # Engine/Connection
        bind = engine._get_sync_engine_or_connection(bind)

    if binds:
        self.binds = binds
        # unwrap each per-entity bind the same way
        binds = {
            key: engine._get_sync_engine_or_connection(b)
            for key, b in binds.items()
        }

    if sync_session_class:
        self.sync_session_class = sync_session_class

    # construct the proxied sync Session and register it so proxy
    # lookups resolve back to this AsyncSession
    self.sync_session = self._proxied = self._assign_proxied(
        self.sync_session_class(bind=bind, binds=binds, **kw)
    )
|
||||
|
||||
sync_session_class = Session
|
||||
"""The class or callable that provides the
|
||||
underlying :class:`_orm.Session` instance for a particular
|
||||
:class:`_asyncio.AsyncSession`.
|
||||
|
||||
At the class level, this attribute is the default value for the
|
||||
:paramref:`_asyncio.AsyncSession.sync_session_class` parameter. Custom
|
||||
subclasses of :class:`_asyncio.AsyncSession` can override this.
|
||||
|
||||
At the instance level, this attribute indicates the current class or
|
||||
callable that was used to provide the :class:`_orm.Session` instance for
|
||||
this :class:`_asyncio.AsyncSession` instance.
|
||||
|
||||
.. versionadded:: 1.4.24
|
||||
|
||||
"""
|
||||
|
||||
sync_session: Session
|
||||
"""Reference to the underlying :class:`_orm.Session` this
|
||||
:class:`_asyncio.AsyncSession` proxies requests towards.
|
||||
|
||||
This instance can be used as an event target.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`asyncio_events`
|
||||
|
||||
"""
|
||||
|
||||
async def refresh(
    self, instance, attribute_names=None, with_for_update=None
):
    """Expire and refresh the attributes on the given instance.

    A query will be issued to the database and all attributes will be
    refreshed with their current database value.

    This is the async version of the :meth:`_orm.Session.refresh` method.
    See that method for a complete description of all options.

    :param instance: the ORM mapped instance to refresh.
    :param attribute_names: optional iterable restricting which
     attributes are refreshed.
    :param with_for_update: optional FOR UPDATE specification, forwarded
     unchanged to the sync method.

    .. seealso::

        :meth:`_orm.Session.refresh` - main documentation for refresh

    """
    # Pure delegation: run the sync Session.refresh inside a greenlet so
    # its blocking database IO is adapted to the asyncio event loop.
    return await greenlet_spawn(
        self.sync_session.refresh,
        instance,
        attribute_names=attribute_names,
        with_for_update=with_for_update,
    )
|
||||
|
||||
async def run_sync(self, fn, *arg, **kw):
    """Invoke the given sync callable passing sync self as the first
    argument.

    This method maintains the asyncio event loop all the way through
    to the database connection by running the given callable in a
    specially instrumented greenlet.

    E.g.::

        with AsyncSession(async_engine) as session:
            await session.run_sync(some_business_method)

    .. note::

        The provided callable is invoked inline within the asyncio event
        loop, and will block on traditional IO calls.  IO within this
        callable should only call into SQLAlchemy's asyncio database
        APIs which will be properly adapted to the greenlet context.

    .. seealso::

        :ref:`session_run_sync`
    """

    # ``fn`` receives the *sync* Session (not this AsyncSession) as its
    # first argument; remaining args/kwargs are forwarded verbatim.
    return await greenlet_spawn(fn, self.sync_session, *arg, **kw)
|
||||
|
||||
async def execute(
    self,
    statement,
    params=None,
    execution_options=util.EMPTY_DICT,
    bind_arguments=None,
    **kw
):
    """Execute a statement and return a buffered
    :class:`_engine.Result` object.

    .. seealso::

        :meth:`_orm.Session.execute` - main documentation for execute

    """

    # Layer the module-level _EXECUTE_OPTIONS defaults underneath any
    # caller-supplied options; skip the merge entirely when none given.
    merged_options = (
        util.immutabledict(execution_options).union(_EXECUTE_OPTIONS)
        if execution_options
        else _EXECUTE_OPTIONS
    )

    raw_result = await greenlet_spawn(
        self.sync_session.execute,
        statement,
        params=params,
        execution_options=merged_options,
        bind_arguments=bind_arguments,
        **kw
    )
    # Convert / validate into a buffered result suitable for async use.
    return await _ensure_sync_result(raw_result, self.execute)
|
||||
|
||||
async def scalar(
    self,
    statement,
    params=None,
    execution_options=util.EMPTY_DICT,
    bind_arguments=None,
    **kw
):
    """Execute a statement and return a scalar result.

    Equivalent to calling :meth:`_asyncio.AsyncSession.execute` and then
    ``Result.scalar()`` on the buffered result.

    .. seealso::

        :meth:`_orm.Session.scalar` - main documentation for scalar

    """

    buffered = await self.execute(
        statement,
        params=params,
        execution_options=execution_options,
        bind_arguments=bind_arguments,
        **kw
    )
    return buffered.scalar()
|
||||
|
||||
async def scalars(
    self,
    statement,
    params=None,
    execution_options=util.EMPTY_DICT,
    bind_arguments=None,
    **kw
):
    """Execute a statement and return scalar results.

    :return: a :class:`_result.ScalarResult` object

    .. versionadded:: 1.4.24

    .. seealso::

        :meth:`_orm.Session.scalars` - main documentation for scalars

        :meth:`_asyncio.AsyncSession.stream_scalars` - streaming version

    """

    buffered = await self.execute(
        statement,
        params=params,
        execution_options=execution_options,
        bind_arguments=bind_arguments,
        **kw
    )
    # Narrow the buffered Result down to its scalar (first-column) view.
    return buffered.scalars()
|
||||
|
||||
async def get(
    self,
    entity,
    ident,
    options=None,
    populate_existing=False,
    with_for_update=None,
    identity_token=None,
):
    """Return an instance based on the given primary key identifier,
    or ``None`` if not found.

    :param entity: the mapped class (or equivalent) to load.
    :param ident: primary key value (scalar or tuple for composite keys).

    All remaining options are forwarded unchanged to the sync
    :meth:`_orm.Session.get`.

    .. seealso::

        :meth:`_orm.Session.get` - main documentation for get


    """
    # Pure delegation, run in a greenlet to adapt blocking database IO.
    return await greenlet_spawn(
        self.sync_session.get,
        entity,
        ident,
        options=options,
        populate_existing=populate_existing,
        with_for_update=with_for_update,
        identity_token=identity_token,
    )
|
||||
|
||||
async def stream(
    self,
    statement,
    params=None,
    execution_options=util.EMPTY_DICT,
    bind_arguments=None,
    **kw
):
    """Execute a statement and return a streaming
    :class:`_asyncio.AsyncResult` object.

    """

    # Layer the module-level _STREAM_OPTIONS defaults underneath any
    # caller-supplied execution options.
    merged_options = (
        util.immutabledict(execution_options).union(_STREAM_OPTIONS)
        if execution_options
        else _STREAM_OPTIONS
    )

    raw_result = await greenlet_spawn(
        self.sync_session.execute,
        statement,
        params=params,
        execution_options=merged_options,
        bind_arguments=bind_arguments,
        **kw
    )
    # Wrap the sync result in the async streaming facade.
    return _result.AsyncResult(raw_result)
|
||||
|
||||
async def stream_scalars(
    self,
    statement,
    params=None,
    execution_options=util.EMPTY_DICT,
    bind_arguments=None,
    **kw
):
    """Execute a statement and return a stream of scalar results.

    :return: an :class:`_asyncio.AsyncScalarResult` object

    .. versionadded:: 1.4.24

    .. seealso::

        :meth:`_orm.Session.scalars` - main documentation for scalars

        :meth:`_asyncio.AsyncSession.scalars` - non streaming version

    """

    streaming = await self.stream(
        statement,
        params=params,
        execution_options=execution_options,
        bind_arguments=bind_arguments,
        **kw
    )
    # Narrow the streaming result to its scalar (first-column) view.
    return streaming.scalars()
|
||||
|
||||
async def delete(self, instance):
    """Mark an instance as deleted.

    The actual DELETE statement is emitted at flush time.  This method is
    a coroutine because the delete cascade may need to traverse unloaded
    relationships, issuing queries to do so.

    .. seealso::

        :meth:`_orm.Session.delete` - main documentation for delete

    """
    return await greenlet_spawn(self.sync_session.delete, instance)
|
||||
|
||||
async def merge(self, instance, load=True, options=None):
    """Copy the state of a given instance into a corresponding instance
    within this :class:`_asyncio.AsyncSession`.

    :param load: forwarded to :meth:`_orm.Session.merge`.
    :param options: forwarded to :meth:`_orm.Session.merge`.

    .. seealso::

        :meth:`_orm.Session.merge` - main documentation for merge

    """
    merged = await greenlet_spawn(
        self.sync_session.merge, instance, load=load, options=options
    )
    return merged
|
||||
|
||||
async def flush(self, objects=None):
    """Flush all the object changes to the database.

    :param objects: optional collection restricting the flush to the
     given objects; forwarded to :meth:`_orm.Session.flush`.

    .. seealso::

        :meth:`_orm.Session.flush` - main documentation for flush

    """
    # Session.flush returns None; nothing to propagate.
    await greenlet_spawn(self.sync_session.flush, objects=objects)
|
||||
|
||||
def get_transaction(self):
    """Return the current root transaction in progress, if any.

    :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
     ``None``.

    .. versionadded:: 1.4.18

    """
    sync_trans = self.sync_session.get_transaction()
    if sync_trans is None:
        return None
    # Map the sync transaction back to its async proxy object.
    return AsyncSessionTransaction._retrieve_proxy_for_target(sync_trans)
|
||||
|
||||
def get_nested_transaction(self):
    """Return the current nested transaction in progress, if any.

    :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
     ``None``.

    .. versionadded:: 1.4.18

    """
    sync_trans = self.sync_session.get_nested_transaction()
    if sync_trans is None:
        return None
    # Map the sync SAVEPOINT transaction back to its async proxy.
    return AsyncSessionTransaction._retrieve_proxy_for_target(sync_trans)
|
||||
|
||||
def get_bind(self, mapper=None, clause=None, bind=None, **kw):
    """Return a "bind" to which the synchronous proxied
    :class:`_orm.Session` is bound.

    This proxies directly to :meth:`_orm.Session.get_bind` and, unlike
    that method, is currently **not** consulted by this
    :class:`.AsyncSession` in order to resolve engines for requests, so
    it is not useful as an override target here.

    To customize bind resolution (e.g. the vertical-partitioning scheme
    of :ref:`session_custom_partitioning`), subclass :class:`_orm.Session`
    instead, override its :meth:`_orm.Session.get_bind` to return sync
    :class:`_engine.Engine` instances (obtainable from an async engine via
    :attr:`_asyncio.AsyncEngine.sync_engine`), and supply that class to
    :class:`.AsyncSession` using
    :paramref:`.AsyncSession.sync_session_class`.  The override is invoked
    in an implicitly non-blocking greenlet context, in the same manner as
    ORM event hooks and :meth:`.AsyncSession.run_sync` callables, so
    blocking-style SQL within it is translated to async IO at the driver
    level.

    .. seealso::

        :meth:`_orm.Session.get_bind`

        :ref:`session_custom_partitioning`

    """
    return self.sync_session.get_bind(
        mapper=mapper, clause=clause, bind=bind, **kw
    )
|
||||
|
||||
async def connection(self, **kw):
    r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
    this :class:`.Session` object's transactional state.

    This method may also be used to establish execution options for the
    database connection used by the current transaction.

    .. versionadded:: 1.4.24 Added \**kw arguments which are passed
       through to the underlying :meth:`_orm.Session.connection` method.

    .. seealso::

        :meth:`_orm.Session.connection` - main documentation for
        "connection"

    """

    sync_conn = await greenlet_spawn(self.sync_session.connection, **kw)
    # Wrap (or re-use) the async proxy for the sync connection.
    return engine.AsyncConnection._retrieve_proxy_for_target(sync_conn)
|
||||
|
||||
def begin(self, **kw):
    """Return an :class:`_asyncio.AsyncSessionTransaction` object.

    The underlying :class:`_orm.Session` performs the "begin" action when
    the returned object is entered::

        async with async_session.begin():
            # .. ORM transaction is begun

    Database IO will not normally occur at begin time, as database
    transactions begin on an on-demand basis; the block is nonetheless
    async so that an IO-performing
    :meth:`_orm.SessionEvents.after_transaction_create` event hook can
    run.

    For a general description of ORM begin, see
    :meth:`_orm.Session.begin`.

    """
    # ``**kw`` is accepted but not currently forwarded anywhere.
    return AsyncSessionTransaction(self)
|
||||
|
||||
def begin_nested(self, **kw):
    """Return an :class:`_asyncio.AsyncSessionTransaction` object
    which will begin a "nested" transaction, e.g. SAVEPOINT.

    Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`.

    For a general description of ORM begin nested, see
    :meth:`_orm.Session.begin_nested`.

    """
    # ``**kw`` is accepted but not currently forwarded anywhere.
    return AsyncSessionTransaction(self, nested=True)
|
||||
|
||||
async def rollback(self):
    """Roll back the transaction currently in progress."""
    return await greenlet_spawn(self.sync_session.rollback)
|
||||
|
||||
async def commit(self):
    """Commit the transaction currently in progress."""
    return await greenlet_spawn(self.sync_session.commit)
|
||||
|
||||
async def close(self):
    """Close out the transactional resources and ORM objects used by this
    :class:`_asyncio.AsyncSession`.

    This expunges all ORM objects associated with this session, ends any
    transaction in progress, and :term:`releases` any
    :class:`_asyncio.AsyncConnection` objects which this session has
    checked out from associated :class:`_asyncio.AsyncEngine` objects.

    .. tip::

        Closing **does not prevent the session from being used again**;
        there is no distinct "closed" state.  The session merely releases
        all database connections and ORM objects and may then be reused.

    .. seealso::

        :ref:`session_closing` - detail on the semantics of
        :meth:`_asyncio.AsyncSession.close`

    """
    # Session.close returns None; nothing to propagate.
    await greenlet_spawn(self.sync_session.close)
|
||||
|
||||
async def invalidate(self):
    """Close this Session, using connection invalidation.

    For a complete description, see :meth:`_orm.Session.invalidate`.
    """
    return await greenlet_spawn(self.sync_session.invalidate)
|
||||
|
||||
@classmethod
async def close_all(cls):
    """Close all :class:`_asyncio.AsyncSession` sessions.

    Delegates to the synchronous :meth:`_orm.Session.close_all`
    classmethod, run in a greenlet to adapt any blocking IO.

    .. note::

        The previous implementation referenced ``self.sync_session``
        from within a classmethod; the first argument of a classmethod
        is the class itself, and ``sync_session`` exists only on
        instances (at class level it is a bare annotation), so the call
        raised ``AttributeError``.  Delegating to the sync ``Session``
        class fixes this while keeping the call signature identical for
        callers.
    """
    return await greenlet_spawn(Session.close_all)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, type_, value, traceback):
    # shield() keeps close() running to completion even if the enclosing
    # task is cancelled, so connection resources are not leaked.
    await asyncio.shield(self.close())
|
||||
|
||||
def _maker_context_manager(self):
    # contextlib.asynccontextmanager is unavailable before Python 3.7,
    # hence this small hand-rolled async context manager object.
    return _AsyncSessionContextManager(self)
|
||||
|
||||
|
||||
class _AsyncSessionContextManager:
|
||||
def __init__(self, async_session):
|
||||
self.async_session = async_session
|
||||
|
||||
async def __aenter__(self):
|
||||
self.trans = self.async_session.begin()
|
||||
await self.trans.__aenter__()
|
||||
return self.async_session
|
||||
|
||||
async def __aexit__(self, type_, value, traceback):
|
||||
async def go():
|
||||
await self.trans.__aexit__(type_, value, traceback)
|
||||
await self.async_session.__aexit__(type_, value, traceback)
|
||||
|
||||
await asyncio.shield(go())
|
||||
|
||||
|
||||
class AsyncSessionTransaction(ReversibleProxy, StartableContext):
    """A wrapper for the ORM :class:`_orm.SessionTransaction` object.

    This object is provided so that a transaction-holding object
    for the :meth:`_asyncio.AsyncSession.begin` may be returned.

    The object supports both explicit calls to
    :meth:`_asyncio.AsyncSessionTransaction.commit` and
    :meth:`_asyncio.AsyncSessionTransaction.rollback`, as well as use as an
    async context manager.


    .. versionadded:: 1.4

    """

    __slots__ = ("session", "sync_transaction", "nested")

    def __init__(self, session, nested=False):
        # The sync transaction is not created here; start() (invoked via
        # the StartableContext protocol / ``async with``) performs the
        # actual begin lazily.
        self.session = session
        self.nested = nested
        self.sync_transaction = None

    @property
    def is_active(self):
        # True while a started underlying sync transaction reports itself
        # active.  NOTE(review): _sync_transaction() raises if not yet
        # started (see below), so the ``is not None`` test presumably only
        # matters when the base class suppresses that — confirm against
        # StartableContext.
        return (
            self._sync_transaction() is not None
            and self._sync_transaction().is_active
        )

    def _sync_transaction(self):
        # Guard against use before start(); _raise_for_not_started is
        # provided elsewhere (presumably by StartableContext).
        if not self.sync_transaction:
            self._raise_for_not_started()
        return self.sync_transaction

    async def rollback(self):
        """Roll back this :class:`_asyncio.AsyncTransaction`."""
        await greenlet_spawn(self._sync_transaction().rollback)

    async def commit(self):
        """Commit this :class:`_asyncio.AsyncTransaction`."""

        await greenlet_spawn(self._sync_transaction().commit)

    async def start(self, is_ctxmanager=False):
        # Begin (or begin_nested, for SAVEPOINT) on the sync Session in a
        # greenlet, then register this object as the async proxy for the
        # resulting sync SessionTransaction via _assign_proxied.
        self.sync_transaction = self._assign_proxied(
            await greenlet_spawn(
                self.session.sync_session.begin_nested
                if self.nested
                else self.session.sync_session.begin
            )
        )
        if is_ctxmanager:
            # When driven as ``async with``, also enter the sync
            # transaction's context-manager protocol so that __aexit__
            # below can delegate commit/rollback to its __exit__.
            self.sync_transaction.__enter__()
        return self

    async def __aexit__(self, type_, value, traceback):
        # The sync transaction's __exit__ decides commit vs rollback from
        # the exception info; run it in a greenlet for database IO.
        await greenlet_spawn(
            self._sync_transaction().__exit__, type_, value, traceback
        )
|
||||
|
||||
|
||||
def async_object_session(instance):
    """Return the :class:`_asyncio.AsyncSession` to which the given instance
    belongs.

    This function makes use of the sync-API function
    :class:`_orm.object_session` to retrieve the :class:`_orm.Session` which
    refers to the given instance, and from there links it to the original
    :class:`_asyncio.AsyncSession`.

    If the :class:`_asyncio.AsyncSession` has been garbage collected, the
    return value is ``None``.

    This functionality is also available from the
    :attr:`_orm.InstanceState.async_session` accessor.

    :param instance: an ORM mapped instance
    :return: an :class:`_asyncio.AsyncSession` object, or ``None``.

    .. versionadded:: 1.4.18

    """

    sync_sess = object_session(instance)
    return async_session(sync_sess) if sync_sess is not None else None
|
||||
|
||||
|
||||
def async_session(session):
    """Return the :class:`_asyncio.AsyncSession` which is proxying the given
    :class:`_orm.Session` object, if any.

    :param session: a :class:`_orm.Session` instance.
    :return: a :class:`_asyncio.AsyncSession` instance, or ``None``.

    .. versionadded:: 1.4.18

    """
    # regenerate=False: presumably returns only an already-associated
    # proxy rather than creating one — confirm against ReversibleProxy.
    return AsyncSession._retrieve_proxy_for_target(session, regenerate=False)
|
||||
|
||||
|
||||
# Register the async lookup with the ORM instance-state module so that
# the InstanceState.async_session accessor (see async_object_session's
# docstring) can resolve the proxying AsyncSession.
_instance_state._async_provider = async_session
|
||||
1234
lib/sqlalchemy/ext/automap.py
Normal file
1234
lib/sqlalchemy/ext/automap.py
Normal file
File diff suppressed because it is too large
Load Diff
648
lib/sqlalchemy/ext/baked.py
Normal file
648
lib/sqlalchemy/ext/baked.py
Normal file
@@ -0,0 +1,648 @@
|
||||
# sqlalchemy/ext/baked.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
"""Baked query extension.
|
||||
|
||||
Provides a creational pattern for the :class:`.query.Query` object which
|
||||
allows the fully constructed object, Core select statement, and string
|
||||
compiled result to be fully cached.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from .. import exc as sa_exc
|
||||
from .. import util
|
||||
from ..orm import exc as orm_exc
|
||||
from ..orm import strategy_options
|
||||
from ..orm.query import Query
|
||||
from ..orm.session import Session
|
||||
from ..sql import func
|
||||
from ..sql import literal_column
|
||||
from ..sql import util as sql_util
|
||||
from ..util import collections_abc
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Bakery(object):
    """Callable which returns a :class:`.BakedQuery`.

    Returned by the class method :meth:`.BakedQuery.bakery`; it exists as
    a distinct object so that the "cache" it holds can be easily
    inspected.

    .. versionadded:: 1.2

    """

    __slots__ = "cls", "cache"

    def __init__(self, cls_, cache):
        # cls_: the BakedQuery (sub)class to construct; cache: the shared
        # query cache handed to every constructed instance.
        self.cls, self.cache = cls_, cache

    def __call__(self, initial_fn, *args):
        return self.cls(self.cache, initial_fn, args)
|
||||
|
||||
|
||||
class BakedQuery(object):
    """A builder object for :class:`.query.Query` objects."""

    __slots__ = "steps", "_bakery", "_cache_key", "_spoiled"

    def __init__(self, bakery, initial_fn, args=()):
        # _cache_key accumulates (code object, *args) entries for each
        # creational step and identifies this query within the bakery.
        self._cache_key = ()
        self._update_cache_key(initial_fn, args)
        # steps: ordered creational callables; the first receives a
        # Session, the rest receive and return a Query.
        self.steps = [initial_fn]
        self._spoiled = False
        self._bakery = bakery

    @classmethod
    def bakery(cls, size=200, _size_alert=None):
        """Construct a new bakery.

        :param size: maximum number of entries in the LRU cache.
        :param _size_alert: internal callback forwarded to the LRU cache.
        :return: an instance of :class:`.Bakery`

        """

        return Bakery(cls, util.LRUCache(size, size_alert=_size_alert))

    def _clone(self):
        # Shallow copy; steps list is copied so the clone can diverge.
        b1 = BakedQuery.__new__(BakedQuery)
        b1._cache_key = self._cache_key
        b1.steps = list(self.steps)
        b1._bakery = self._bakery
        b1._spoiled = self._spoiled
        return b1

    def _update_cache_key(self, fn, args=()):
        # Keyed on the function's code object plus the supplied args.
        self._cache_key += (fn.__code__,) + args

    def __iadd__(self, other):
        # ``bq += fn`` or ``bq += (fn, arg, ...)`` modifies in place.
        if isinstance(other, tuple):
            self.add_criteria(*other)
        else:
            self.add_criteria(other)
        return self

    def __add__(self, other):
        # ``bq + fn`` returns a modified clone, leaving self untouched.
        if isinstance(other, tuple):
            return self.with_criteria(*other)
        else:
            return self.with_criteria(other)

    def add_criteria(self, fn, *args):
        """Add a criteria function to this :class:`.BakedQuery`.

        This is equivalent to using the ``+=`` operator to
        modify a :class:`.BakedQuery` in-place.

        """
        self._update_cache_key(fn, args)
        self.steps.append(fn)
        return self

    def with_criteria(self, fn, *args):
        """Add a criteria function to a :class:`.BakedQuery` cloned from this
        one.

        This is equivalent to using the ``+`` operator to
        produce a new :class:`.BakedQuery` with modifications.

        """
        return self._clone().add_criteria(fn, *args)

    def for_session(self, session):
        """Return a :class:`_baked.Result` object for this
        :class:`.BakedQuery`.

        This is equivalent to calling the :class:`.BakedQuery` as a
        Python callable, e.g. ``result = my_baked_query(session)``.

        """
        return Result(self, session)

    def __call__(self, session):
        return self.for_session(session)

    def spoil(self, full=False):
        """Cancel any query caching that will occur on this BakedQuery object.

        The BakedQuery can continue to be used normally, however additional
        creational functions will not be cached; they will be called
        on every invocation.

        This is to support the case where a particular step in constructing
        a baked query disqualifies the query from being cacheable, such
        as a variant that relies upon some uncacheable value.

        :param full: if False, only functions added to this
         :class:`.BakedQuery` object subsequent to the spoil step will be
         non-cached; the state of the :class:`.BakedQuery` up until
         this point will be pulled from the cache.   If True, then the
         entire :class:`_query.Query` object is built from scratch each
         time, with all creational functions being called on each
         invocation.

        """
        if not full and not self._spoiled:
            # Partial spoil: collapse the cached prefix into a single
            # retrieval step keyed with a "_query_only" marker.
            _spoil_point = self._clone()
            _spoil_point._cache_key += ("_query_only",)
            self.steps = [_spoil_point._retrieve_baked_query]
        self._spoiled = True
        return self

    def _effective_key(self, session):
        """Return the key that actually goes into the cache dictionary for
        this :class:`.BakedQuery`, taking into account the given
        :class:`.Session`.

        This basically means we also will include the session's query_class,
        as the actual :class:`_query.Query` object is part of what's cached
        and needs to match the type of :class:`_query.Query` that a later
        session will want to use.

        """
        return self._cache_key + (session._query_cls,)

    def _with_lazyload_options(self, options, effective_path, cache_path=None):
        """Cloning version of _add_lazyload_options."""
        q = self._clone()
        q._add_lazyload_options(options, effective_path, cache_path=cache_path)
        return q

    def _add_lazyload_options(self, options, effective_path, cache_path=None):
        """Used by per-state lazy loaders to add options to the
        "lazy load" query from a parent query.

        Creates a cache key based on given load path and query options;
        if a repeatable cache key cannot be generated, the query is
        "spoiled" so that it won't use caching.

        """

        key = ()

        if not cache_path:
            cache_path = effective_path

        for opt in options:
            if opt._is_legacy_option or opt._is_compile_state:
                ck = opt._generate_cache_key()
                if ck is None:
                    # Option isn't cacheable; disable caching entirely.
                    self.spoil(full=True)
                else:
                    assert not ck[1], (
                        "loader options with variable bound parameters "
                        "not supported with baked queries.  Please "
                        "use new-style select() statements for cached "
                        "ORM queries."
                    )
                    key += ck[0]

        self.add_criteria(
            lambda q: q._with_current_path(effective_path).options(*options),
            cache_path.path,
            key,
        )

    def _retrieve_baked_query(self, session):
        # Fetch (or build and cache) the Query, cached detached from any
        # session, then re-associate it with the current session.
        query = self._bakery.get(self._effective_key(session), None)
        if query is None:
            query = self._as_query(session)
            self._bakery[self._effective_key(session)] = query.with_session(
                None
            )
        return query.with_session(session)

    def _bake(self, session):
        query = self._as_query(session)
        query.session = None

        # in 1.4, this is where before_compile() event is
        # invoked
        statement = query._statement_20()

        # if the query is not safe to cache, we still do everything as though
        # we did cache it, since the receiver of _bake() assumes subqueryload
        # context was set up, etc.
        #
        # note also we want to cache the statement itself because this
        # allows the statement itself to hold onto its cache key that is
        # used by the Connection, which in itself is more expensive to
        # generate than what BakedQuery was able to provide in 1.3 and prior

        if statement._compile_options._bake_ok:
            self._bakery[self._effective_key(session)] = (
                query,
                statement,
            )

        return query, statement

    def to_query(self, query_or_session):
        """Return the :class:`_query.Query` object for use as a subquery.

        This method should be used within the lambda callable being used
        to generate a step of an enclosing :class:`.BakedQuery`.   The
        parameter should normally be the :class:`_query.Query` object that
        is passed to the lambda::

            sub_bq = self.bakery(lambda s: s.query(User.name))
            sub_bq += lambda q: q.filter(
                User.id == Address.user_id).correlate(Address)

            main_bq = self.bakery(lambda s: s.query(Address))
            main_bq += lambda q: q.filter(
                sub_bq.to_query(q).exists())

        In the case where the subquery is used in the first callable against
        a :class:`.Session`, the :class:`.Session` is also accepted::

            sub_bq = self.bakery(lambda s: s.query(User.name))
            sub_bq += lambda q: q.filter(
                User.id == Address.user_id).correlate(Address)

            main_bq = self.bakery(
                lambda s: s.query(
                Address.id, sub_bq.to_query(q).scalar_subquery())
            )

        :param query_or_session: a :class:`_query.Query` object or a class
         :class:`.Session` object, that is assumed to be within the context
         of an enclosing :class:`.BakedQuery` callable.


        .. versionadded:: 1.3


        """

        if isinstance(query_or_session, Session):
            session = query_or_session
        elif isinstance(query_or_session, Query):
            session = query_or_session.session
            if session is None:
                raise sa_exc.ArgumentError(
                    "Given Query needs to be associated with a Session"
                )
        else:
            raise TypeError(
                "Query or Session object expected, got %r."
                % type(query_or_session)
            )
        return self._as_query(session)

    def _as_query(self, session):
        # Run the creational steps: the first builds a Query from the
        # session, each subsequent one transforms the Query.
        query = self.steps[0](session)

        for step in self.steps[1:]:
            query = step(query)

        return query
|
||||
|
||||
|
||||
class Result(object):
    """Invokes a :class:`.BakedQuery` against a :class:`.Session`.

    The :class:`_baked.Result` object is where the actual :class:`.query.Query`
    object gets created, or retrieved from the cache,
    against a target :class:`.Session`, and is then invoked for results.

    """

    # bq: the owning BakedQuery; session: the target Session;
    # _params: bound parameter values applied after cache retrieval;
    # _post_criteria: callables applied to the Query after cache retrieval.
    __slots__ = "bq", "session", "_params", "_post_criteria"

    def __init__(self, bq, session):
        self.bq = bq
        self.session = session
        self._params = {}
        self._post_criteria = []

    def params(self, *args, **kw):
        """Specify parameters to be replaced into the string SQL statement."""

        if len(args) == 1:
            # a single positional argument is taken as a dictionary
            # of parameters, merged with any keyword arguments
            kw.update(args[0])
        elif len(args) > 0:
            raise sa_exc.ArgumentError(
                "params() takes zero or one positional argument, "
                "which is a dictionary."
            )
        self._params.update(kw)
        return self

    def _using_post_criteria(self, fns):
        # accumulate post-cache criteria functions; returns self for chaining
        if fns:
            self._post_criteria.extend(fns)
        return self

    def with_post_criteria(self, fn):
        """Add a criteria function that will be applied post-cache.

        This adds a function that will be run against the
        :class:`_query.Query` object after it is retrieved from the
        cache.    This currently includes **only** the
        :meth:`_query.Query.params` and :meth:`_query.Query.execution_options`
        methods.

        .. warning::  :meth:`_baked.Result.with_post_criteria`
           functions are applied
           to the :class:`_query.Query`
           object **after** the query's SQL statement
           object has been retrieved from the cache.   Only
           :meth:`_query.Query.params` and
           :meth:`_query.Query.execution_options`
           methods should be used.

        .. versionadded:: 1.2

        """
        return self._using_post_criteria([fn])

    def _as_query(self):
        # build the non-cached Query form, applying params and
        # post-criteria; used when baking is disabled or spoiled
        q = self.bq._as_query(self.session).params(self._params)
        for fn in self._post_criteria:
            q = fn(q)
        return q

    def __str__(self):
        return str(self._as_query())

    def __iter__(self):
        return self._iter().__iter__()

    def _iter(self):
        """Execute the query and return a result object.

        Uses the cached statement if baking is enabled and the query
        has not been spoiled; otherwise falls back to building the
        Query from scratch.
        """
        bq = self.bq

        # bypass the cache entirely when disabled on the Session or
        # when this BakedQuery was spoiled
        if not self.session.enable_baked_queries or bq._spoiled:
            return self._as_query()._iter()

        query, statement = bq._bakery.get(
            bq._effective_key(self.session), (None, None)
        )
        if query is None:
            # cache miss: build the query/statement and store them
            query, statement = bq._bake(self.session)

        if self._params:
            q = query.params(self._params)
        else:
            q = query
        # post-criteria run against the cached Query object
        for fn in self._post_criteria:
            q = fn(q)

        params = q._params
        execution_options = dict(q._execution_options)
        execution_options.update(
            {
                "_sa_orm_load_options": q.load_options,
                "compiled_cache": bq._bakery,
            }
        )

        result = self.session.execute(
            statement, params, execution_options=execution_options
        )
        # adapt the Core result to ORM conventions where flagged
        if result._attributes.get("is_single_entity", False):
            result = result.scalars()

        if result._attributes.get("filtered", False):
            result = result.unique()

        return result

    def count(self):
        """return the 'count'.

        Equivalent to :meth:`_query.Query.count`.

        Note this uses a subquery to ensure an accurate count regardless
        of the structure of the original statement.

        .. versionadded:: 1.1.6

        """

        col = func.count(literal_column("*"))
        # wrap the original query in a subquery so that LIMIT/DISTINCT
        # etc. do not skew the count
        bq = self.bq.with_criteria(lambda q: q._from_self(col))
        return bq.for_session(self.session).params(self._params).scalar()

    def scalar(self):
        """Return the first element of the first result or None
        if no rows present.  If multiple rows are returned,
        raises MultipleResultsFound.

        Equivalent to :meth:`_query.Query.scalar`.

        .. versionadded:: 1.1.6

        """
        try:
            ret = self.one()
            if not isinstance(ret, collections_abc.Sequence):
                # single-entity result: return the object itself
                return ret
            return ret[0]
        except orm_exc.NoResultFound:
            return None

    def first(self):
        """Return the first row.

        Equivalent to :meth:`_query.Query.first`.

        """

        # apply LIMIT 1 as an additional cached criteria step
        bq = self.bq.with_criteria(lambda q: q.slice(0, 1))
        return (
            bq.for_session(self.session)
            .params(self._params)
            ._using_post_criteria(self._post_criteria)
            ._iter()
            .first()
        )

    def one(self):
        """Return exactly one result or raise an exception.

        Equivalent to :meth:`_query.Query.one`.

        """
        return self._iter().one()

    def one_or_none(self):
        """Return one or zero results, or raise an exception for multiple
        rows.

        Equivalent to :meth:`_query.Query.one_or_none`.

        .. versionadded:: 1.0.9

        """
        return self._iter().one_or_none()

    def all(self):
        """Return all rows.

        Equivalent to :meth:`_query.Query.all`.

        """
        return self._iter().all()

    def get(self, ident):
        """Retrieve an object based on identity.

        Equivalent to :meth:`_query.Query.get`.

        """

        # the first step produces the base Query used for the identity
        # lookup; _get_impl consults the identity map first
        query = self.bq.steps[0](self.session)
        return query._get_impl(ident, self._load_on_pk_identity)

    def _load_on_pk_identity(self, session, query, primary_key_identity, **kw):
        """Load the given primary key identity from the database."""

        mapper = query._raw_columns[0]._annotations["parententity"]

        _get_clause, _get_params = mapper._get_clause

        def setup(query):
            _lcl_get_clause = _get_clause
            q = query._clone()
            q._get_condition()
            q._order_by = None

            # None present in ident - turn those comparisons
            # into "IS NULL"
            if None in primary_key_identity:
                nones = set(
                    [
                        _get_params[col].key
                        for col, value in zip(
                            mapper.primary_key, primary_key_identity
                        )
                        if value is None
                    ]
                )
                _lcl_get_clause = sql_util.adapt_criterion_to_null(
                    _lcl_get_clause, nones
                )

            # TODO: can mapper._get_clause be pre-adapted?
            q._where_criteria = (
                sql_util._deep_annotate(_lcl_get_clause, {"_orm_adapt": True}),
            )

            for fn in self._post_criteria:
                q = fn(q)
            return q

        # cache the query against a key that includes
        # which positions in the primary key are NULL
        # (remember, we can map to an OUTER JOIN)
        bq = self.bq

        # add the clause we got from mapper._get_clause to the cache
        # key so that if a race causes multiple calls to _get_clause,
        # we've cached on ours
        bq = bq._clone()
        bq._cache_key += (_get_clause,)

        bq = bq.with_criteria(
            setup, tuple(elem is None for elem in primary_key_identity)
        )

        params = dict(
            [
                (_get_params[primary_key].key, id_val)
                for id_val, primary_key in zip(
                    primary_key_identity, mapper.primary_key
                )
            ]
        )

        result = list(bq.for_session(self.session).params(**params))
        l = len(result)
        if l > 1:
            raise orm_exc.MultipleResultsFound()
        elif l:
            return result[0]
        else:
            return None
|
||||
|
||||
|
||||
@util.deprecated(
    "1.2", "Baked lazy loading is now the default implementation."
)
def bake_lazy_loaders():
    """Enable the use of baked queries for all lazyloaders systemwide.

    The "baked" implementation of lazy loading is now the sole implementation
    for the base lazy loader; this method has no effect except for a warning.

    """
    # retained only for backwards compatibility; the decorator emits a
    # deprecation warning and the function itself does nothing
    pass
|
||||
|
||||
|
||||
@util.deprecated(
    "1.2", "Baked lazy loading is now the default implementation."
)
def unbake_lazy_loaders():
    """Disable the use of baked queries for all lazyloaders systemwide.

    This method now raises NotImplementedError() as the "baked" implementation
    is the only lazy load implementation.  The
    :paramref:`_orm.relationship.bake_queries` flag may be used to disable
    the caching of queries on a per-relationship basis.

    """
    # unlike bake_lazy_loaders(), this one cannot be a no-op: a caller
    # expecting non-baked lazy loading would get different behavior, so
    # fail loudly instead
    raise NotImplementedError(
        "Baked lazy loading is now the default implementation"
    )
|
||||
|
||||
|
||||
@strategy_options.loader_option()
def baked_lazyload(loadopt, attr):
    """Indicate that the given attribute should be loaded using "lazy"
    loading with a "baked" query used in the load.

    """
    # "baked_select" selects the baked-query variant of the lazy
    # relationship loader strategy
    return loadopt.set_relationship_strategy(attr, {"lazy": "baked_select"})
|
||||
|
||||
|
||||
@baked_lazyload._add_unbound_fn
@util.deprecated(
    "1.2",
    "Baked lazy loading is now the default "
    "implementation for lazy loading.",
)
def baked_lazyload(*keys):
    # unbound (module-level) form of baked_lazyload, registered on the
    # bound loader option above via _add_unbound_fn; applies to the
    # single attribute path given by ``keys`` (all_=False)
    return strategy_options._UnboundLoad._from_keys(
        strategy_options._UnboundLoad.baked_lazyload, keys, False, {}
    )
|
||||
|
||||
|
||||
@baked_lazyload._add_unbound_all_fn
@util.deprecated(
    "1.2",
    "Baked lazy loading is now the default "
    "implementation for lazy loading.",
)
def baked_lazyload_all(*keys):
    # unbound "_all" form: same as baked_lazyload but applies to all
    # attributes along the path (all_=True)
    return strategy_options._UnboundLoad._from_keys(
        strategy_options._UnboundLoad.baked_lazyload, keys, True, {}
    )
|
||||
|
||||
|
||||
# rebind the module-level names to the public unbound forms generated by
# the loader_option machinery, replacing the decorated originals
baked_lazyload = baked_lazyload._unbound_fn
baked_lazyload_all = baked_lazyload_all._unbound_all_fn

# public shortcut for constructing a bakery (cache of baked queries)
bakery = BakedQuery.bakery
|
||||
613
lib/sqlalchemy/ext/compiler.py
Normal file
613
lib/sqlalchemy/ext/compiler.py
Normal file
@@ -0,0 +1,613 @@
|
||||
# ext/compiler.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
r"""Provides an API for creation of custom ClauseElements and compilers.
|
||||
|
||||
Synopsis
|
||||
========
|
||||
|
||||
Usage involves the creation of one or more
|
||||
:class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or
|
||||
more callables defining its compilation::
|
||||
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.sql.expression import ColumnClause
|
||||
|
||||
class MyColumn(ColumnClause):
|
||||
inherit_cache = True
|
||||
|
||||
@compiles(MyColumn)
|
||||
def compile_mycolumn(element, compiler, **kw):
|
||||
return "[%s]" % element.name
|
||||
|
||||
Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`,
|
||||
the base expression element for named column objects. The ``compiles``
|
||||
decorator registers itself with the ``MyColumn`` class so that it is invoked
|
||||
when the object is compiled to a string::
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
s = select(MyColumn('x'), MyColumn('y'))
|
||||
print(str(s))
|
||||
|
||||
Produces::
|
||||
|
||||
SELECT [x], [y]
|
||||
|
||||
Dialect-specific compilation rules
|
||||
==================================
|
||||
|
||||
Compilers can also be made dialect-specific. The appropriate compiler will be
|
||||
invoked for the dialect in use::
|
||||
|
||||
from sqlalchemy.schema import DDLElement
|
||||
|
||||
class AlterColumn(DDLElement):
|
||||
inherit_cache = False
|
||||
|
||||
def __init__(self, column, cmd):
|
||||
self.column = column
|
||||
self.cmd = cmd
|
||||
|
||||
@compiles(AlterColumn)
|
||||
def visit_alter_column(element, compiler, **kw):
|
||||
return "ALTER COLUMN %s ..." % element.column.name
|
||||
|
||||
@compiles(AlterColumn, 'postgresql')
|
||||
def visit_alter_column(element, compiler, **kw):
|
||||
return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name,
|
||||
element.column.name)
|
||||
|
||||
The second ``visit_alter_column`` will be invoked when any ``postgresql``
dialect is used.
|
||||
|
||||
.. _compilerext_compiling_subelements:
|
||||
|
||||
Compiling sub-elements of a custom expression construct
|
||||
=======================================================
|
||||
|
||||
The ``compiler`` argument is the
|
||||
:class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object
|
||||
can be inspected for any information about the in-progress compilation,
|
||||
including ``compiler.dialect``, ``compiler.statement`` etc. The
|
||||
:class:`~sqlalchemy.sql.compiler.SQLCompiler` and
|
||||
:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()``
|
||||
method which can be used for compilation of embedded attributes::
|
||||
|
||||
from sqlalchemy.sql.expression import Executable, ClauseElement
|
||||
|
||||
class InsertFromSelect(Executable, ClauseElement):
|
||||
inherit_cache = False
|
||||
|
||||
def __init__(self, table, select):
|
||||
self.table = table
|
||||
self.select = select
|
||||
|
||||
@compiles(InsertFromSelect)
|
||||
def visit_insert_from_select(element, compiler, **kw):
|
||||
return "INSERT INTO %s (%s)" % (
|
||||
compiler.process(element.table, asfrom=True, **kw),
|
||||
compiler.process(element.select, **kw)
|
||||
)
|
||||
|
||||
insert = InsertFromSelect(t1, select(t1).where(t1.c.x>5))
|
||||
print(insert)
|
||||
|
||||
Produces::
|
||||
|
||||
"INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z
|
||||
FROM mytable WHERE mytable.x > :x_1)"
|
||||
|
||||
.. note::
|
||||
|
||||
The above ``InsertFromSelect`` construct is only an example, this actual
|
||||
functionality is already available using the
|
||||
:meth:`_expression.Insert.from_select` method.
|
||||
|
||||
.. note::
|
||||
|
||||
The above ``InsertFromSelect`` construct probably wants to have "autocommit"
|
||||
enabled. See :ref:`enabling_compiled_autocommit` for this step.
|
||||
|
||||
Cross Compiling between SQL and DDL compilers
|
||||
---------------------------------------------
|
||||
|
||||
SQL and DDL constructs are each compiled using different base compilers -
|
||||
``SQLCompiler`` and ``DDLCompiler``. A common need is to access the
|
||||
compilation rules of SQL expressions from within a DDL expression. The
|
||||
``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as
|
||||
below where we generate a CHECK constraint that embeds a SQL expression::
|
||||
|
||||
@compiles(MyConstraint)
|
||||
def compile_my_constraint(constraint, ddlcompiler, **kw):
|
||||
kw['literal_binds'] = True
|
||||
return "CONSTRAINT %s CHECK (%s)" % (
|
||||
constraint.name,
|
||||
ddlcompiler.sql_compiler.process(
|
||||
constraint.expression, **kw)
|
||||
)
|
||||
|
||||
Above, we add an additional flag to the process step as called by
|
||||
:meth:`.SQLCompiler.process`, which is the ``literal_binds`` flag. This
|
||||
indicates that any SQL expression which refers to a :class:`.BindParameter`
|
||||
object or other "literal" object such as those which refer to strings or
|
||||
integers should be rendered **in-place**, rather than being referred to as
|
||||
a bound parameter; when emitting DDL, bound parameters are typically not
|
||||
supported.
|
||||
|
||||
|
||||
.. _enabling_compiled_autocommit:
|
||||
|
||||
Enabling Autocommit on a Construct
|
||||
==================================
|
||||
|
||||
Recall from the section :ref:`autocommit` that the :class:`_engine.Engine`,
|
||||
when
|
||||
asked to execute a construct in the absence of a user-defined transaction,
|
||||
detects if the given construct represents DML or DDL, that is, a data
|
||||
modification or data definition statement, which requires (or may require,
|
||||
in the case of DDL) that the transaction generated by the DBAPI be committed
|
||||
(recall that DBAPI always has a transaction going on regardless of what
|
||||
SQLAlchemy does). Checking for this is actually accomplished by checking for
|
||||
the "autocommit" execution option on the construct. When building a
|
||||
construct like an INSERT derivation, a new DDL type, or perhaps a stored
|
||||
procedure that alters data, the "autocommit" option needs to be set in order
|
||||
for the statement to function with "connectionless" execution
|
||||
(as described in :ref:`dbengine_implicit`).
|
||||
|
||||
Currently a quick way to do this is to subclass :class:`.Executable`, then
|
||||
add the "autocommit" flag to the ``_execution_options`` dictionary (note this
|
||||
is a "frozen" dictionary which supplies a generative ``union()`` method)::
|
||||
|
||||
from sqlalchemy.sql.expression import Executable, ClauseElement
|
||||
|
||||
class MyInsertThing(Executable, ClauseElement):
|
||||
_execution_options = \
|
||||
Executable._execution_options.union({'autocommit': True})
|
||||
|
||||
More succinctly, if the construct is truly similar to an INSERT, UPDATE, or
|
||||
DELETE, :class:`.UpdateBase` can be used, which already is a subclass
|
||||
of :class:`.Executable`, :class:`_expression.ClauseElement` and includes the
|
||||
``autocommit`` flag::
|
||||
|
||||
from sqlalchemy.sql.expression import UpdateBase
|
||||
|
||||
class MyInsertThing(UpdateBase):
|
||||
def __init__(self, ...):
|
||||
...
|
||||
|
||||
|
||||
|
||||
|
||||
DDL elements that subclass :class:`.DDLElement` already have the
|
||||
"autocommit" flag turned on.
|
||||
|
||||
|
||||
|
||||
|
||||
Changing the default compilation of existing constructs
|
||||
=======================================================
|
||||
|
||||
The compiler extension applies just as well to the existing constructs. When
|
||||
overriding the compilation of a built in SQL construct, the @compiles
|
||||
decorator is invoked upon the appropriate class (be sure to use the class,
|
||||
i.e. ``Insert`` or ``Select``, instead of the creation function such
|
||||
as ``insert()`` or ``select()``).
|
||||
|
||||
Within the new compilation function, to get at the "original" compilation
|
||||
routine, use the appropriate visit_XXX method - this
|
||||
because compiler.process() will call upon the overriding routine and cause
|
||||
an endless loop. Such as, to add "prefix" to all insert statements::
|
||||
|
||||
from sqlalchemy.sql.expression import Insert
|
||||
|
||||
@compiles(Insert)
|
||||
def prefix_inserts(insert, compiler, **kw):
|
||||
return compiler.visit_insert(insert.prefix_with("some prefix"), **kw)
|
||||
|
||||
The above compiler will prefix all INSERT statements with "some prefix" when
|
||||
compiled.
|
||||
|
||||
.. _type_compilation_extension:
|
||||
|
||||
Changing Compilation of Types
|
||||
=============================
|
||||
|
||||
``compiler`` works for types, too, such as below where we implement the
|
||||
MS-SQL specific 'max' keyword for ``String``/``VARCHAR``::
|
||||
|
||||
@compiles(String, 'mssql')
|
||||
@compiles(VARCHAR, 'mssql')
|
||||
def compile_varchar(element, compiler, **kw):
|
||||
if element.length == 'max':
|
||||
return "VARCHAR('max')"
|
||||
else:
|
||||
return compiler.visit_VARCHAR(element, **kw)
|
||||
|
||||
foo = Table('foo', metadata,
|
||||
Column('data', VARCHAR('max'))
|
||||
)
|
||||
|
||||
Subclassing Guidelines
|
||||
======================
|
||||
|
||||
A big part of using the compiler extension is subclassing SQLAlchemy
|
||||
expression constructs. To make this easier, the expression and
|
||||
schema packages feature a set of "bases" intended for common tasks.
|
||||
A synopsis is as follows:
|
||||
|
||||
* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root
|
||||
expression class. Any SQL expression can be derived from this base, and is
|
||||
probably the best choice for longer constructs such as specialized INSERT
|
||||
statements.
|
||||
|
||||
* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all
|
||||
"column-like" elements. Anything that you'd place in the "columns" clause of
|
||||
a SELECT statement (as well as order by and group by) can derive from this -
|
||||
the object will automatically have Python "comparison" behavior.
|
||||
|
||||
:class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a
|
||||
``type`` member which is expression's return type. This can be established
|
||||
at the instance level in the constructor, or at the class level if it's
|
||||
generally constant::
|
||||
|
||||
class timestamp(ColumnElement):
|
||||
type = TIMESTAMP()
|
||||
inherit_cache = True
|
||||
|
||||
* :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a
|
||||
``ColumnElement`` and a "from clause" like object, and represents a SQL
|
||||
function or stored procedure type of call. Since most databases support
|
||||
statements along the line of "SELECT FROM <some function>"
|
||||
``FunctionElement`` adds in the ability to be used in the FROM clause of a
|
||||
``select()`` construct::
|
||||
|
||||
from sqlalchemy.sql.expression import FunctionElement
|
||||
|
||||
class coalesce(FunctionElement):
|
||||
name = 'coalesce'
|
||||
inherit_cache = True
|
||||
|
||||
@compiles(coalesce)
|
||||
def compile(element, compiler, **kw):
|
||||
return "coalesce(%s)" % compiler.process(element.clauses, **kw)
|
||||
|
||||
@compiles(coalesce, 'oracle')
|
||||
def compile(element, compiler, **kw):
|
||||
if len(element.clauses) > 2:
|
||||
raise TypeError("coalesce only supports two arguments on Oracle")
|
||||
return "nvl(%s)" % compiler.process(element.clauses, **kw)
|
||||
|
||||
* :class:`.DDLElement` - The root of all DDL expressions,
|
||||
like CREATE TABLE, ALTER TABLE, etc. Compilation of :class:`.DDLElement`
|
||||
subclasses is issued by a :class:`.DDLCompiler` instead of a
|
||||
:class:`.SQLCompiler`. :class:`.DDLElement` can also be used as an event hook
|
||||
in conjunction with event hooks like :meth:`.DDLEvents.before_create` and
|
||||
:meth:`.DDLEvents.after_create`, allowing the construct to be invoked
|
||||
automatically during CREATE TABLE and DROP TABLE sequences.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`metadata_ddl_toplevel` - contains examples of associating
|
||||
:class:`.DDL` objects (which are themselves :class:`.DDLElement`
|
||||
instances) with :class:`.DDLEvents` event hooks.
|
||||
|
||||
* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which
|
||||
should be used with any expression class that represents a "standalone"
|
||||
SQL statement that can be passed directly to an ``execute()`` method. It
|
||||
is already implicit within ``DDLElement`` and ``FunctionElement``.
|
||||
|
||||
Most of the above constructs also respond to SQL statement caching. A
|
||||
subclassed construct will want to define the caching behavior for the object,
|
||||
which usually means setting the flag ``inherit_cache`` to the value of
|
||||
``False`` or ``True``. See the next section :ref:`compilerext_caching`
|
||||
for background.
|
||||
|
||||
|
||||
.. _compilerext_caching:
|
||||
|
||||
Enabling Caching Support for Custom Constructs
|
||||
==============================================
|
||||
|
||||
SQLAlchemy as of version 1.4 includes a
|
||||
:ref:`SQL compilation caching facility <sql_caching>` which will allow
|
||||
equivalent SQL constructs to cache their stringified form, along with other
|
||||
structural information used to fetch results from the statement.
|
||||
|
||||
For reasons discussed at :ref:`caching_caveats`, the implementation of this
|
||||
caching system takes a conservative approach towards including custom SQL
|
||||
constructs and/or subclasses within the caching system. This includes that
|
||||
any user-defined SQL constructs, including all the examples for this
|
||||
extension, will not participate in caching by default unless they positively
|
||||
assert that they are able to do so. The :attr:`.HasCacheKey.inherit_cache`
|
||||
attribute when set to ``True`` at the class level of a specific subclass
|
||||
will indicate that instances of this class may be safely cached, using the
|
||||
cache key generation scheme of the immediate superclass. This applies
|
||||
for example to the "synopsis" example indicated previously::
|
||||
|
||||
class MyColumn(ColumnClause):
|
||||
inherit_cache = True
|
||||
|
||||
@compiles(MyColumn)
|
||||
def compile_mycolumn(element, compiler, **kw):
|
||||
return "[%s]" % element.name
|
||||
|
||||
Above, the ``MyColumn`` class does not include any new state that
|
||||
affects its SQL compilation; the cache key of ``MyColumn`` instances will
|
||||
make use of that of the ``ColumnClause`` superclass, meaning it will take
|
||||
into account the class of the object (``MyColumn``), the string name and
|
||||
datatype of the object::
|
||||
|
||||
>>> MyColumn("some_name", String())._generate_cache_key()
|
||||
CacheKey(
|
||||
key=('0', <class '__main__.MyColumn'>,
|
||||
'name', 'some_name',
|
||||
'type', (<class 'sqlalchemy.sql.sqltypes.String'>,
|
||||
('length', None), ('collation', None))
|
||||
), bindparams=[])
|
||||
|
||||
For objects that are likely to be **used liberally as components within many
|
||||
larger statements**, such as :class:`_schema.Column` subclasses and custom SQL
|
||||
datatypes, it's important that **caching be enabled as much as possible**, as
|
||||
this may otherwise negatively affect performance.
|
||||
|
||||
An example of an object that **does** contain state which affects its SQL
|
||||
compilation is the one illustrated at :ref:`compilerext_compiling_subelements`;
|
||||
this is an "INSERT FROM SELECT" construct that combines together a
|
||||
:class:`_schema.Table` as well as a :class:`_sql.Select` construct, each of
|
||||
which independently affect the SQL string generation of the construct. For
|
||||
this class, the example illustrates that it simply does not participate in
|
||||
caching::
|
||||
|
||||
class InsertFromSelect(Executable, ClauseElement):
|
||||
inherit_cache = False
|
||||
|
||||
def __init__(self, table, select):
|
||||
self.table = table
|
||||
self.select = select
|
||||
|
||||
@compiles(InsertFromSelect)
|
||||
def visit_insert_from_select(element, compiler, **kw):
|
||||
return "INSERT INTO %s (%s)" % (
|
||||
compiler.process(element.table, asfrom=True, **kw),
|
||||
compiler.process(element.select, **kw)
|
||||
)
|
||||
|
||||
While it is also possible that the above ``InsertFromSelect`` could be made to
|
||||
produce a cache key that is composed of that of the :class:`_schema.Table` and
|
||||
:class:`_sql.Select` components together, the API for this is not at the moment
|
||||
fully public. However, for an "INSERT FROM SELECT" construct, which is only
|
||||
used by itself for specific operations, caching is not as critical as in the
|
||||
previous example.
|
||||
|
||||
For objects that are **used in relative isolation and are generally
|
||||
standalone**, such as custom :term:`DML` constructs like an "INSERT FROM
|
||||
SELECT", **caching is generally less critical** as the lack of caching for such
|
||||
a construct will have only localized implications for that specific operation.
|
||||
|
||||
|
||||
Further Examples
|
||||
================
|
||||
|
||||
"UTC timestamp" function
|
||||
-------------------------
|
||||
|
||||
A function that works like "CURRENT_TIMESTAMP" except applies the
|
||||
appropriate conversions so that the time is in UTC time. Timestamps are best
|
||||
stored in relational databases as UTC, without time zones. UTC so that your
|
||||
database doesn't think time has gone backwards in the hour when daylight
|
||||
savings ends, without timezones because timezones are like character
|
||||
encodings - they're best applied only at the endpoints of an application
|
||||
(i.e. convert to UTC upon user input, re-apply desired timezone upon display).
|
||||
|
||||
For PostgreSQL and Microsoft SQL Server::
|
||||
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.types import DateTime
|
||||
|
||||
class utcnow(expression.FunctionElement):
|
||||
type = DateTime()
|
||||
inherit_cache = True
|
||||
|
||||
@compiles(utcnow, 'postgresql')
|
||||
def pg_utcnow(element, compiler, **kw):
|
||||
return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
|
||||
|
||||
@compiles(utcnow, 'mssql')
|
||||
def ms_utcnow(element, compiler, **kw):
|
||||
return "GETUTCDATE()"
|
||||
|
||||
Example usage::
|
||||
|
||||
from sqlalchemy import (
|
||||
Table, Column, Integer, String, DateTime, MetaData
|
||||
)
|
||||
metadata = MetaData()
|
||||
event = Table("event", metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("description", String(50), nullable=False),
|
||||
Column("timestamp", DateTime, server_default=utcnow())
|
||||
)
|
||||
|
||||
"GREATEST" function
|
||||
-------------------
|
||||
|
||||
The "GREATEST" function is given any number of arguments and returns the one
|
||||
that is of the highest value - it's equivalent to Python's ``max``
|
||||
function. A SQL standard version versus a CASE based version which only
|
||||
accommodates two arguments::
|
||||
|
||||
from sqlalchemy.sql import expression, case
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.types import Numeric
|
||||
|
||||
class greatest(expression.FunctionElement):
|
||||
type = Numeric()
|
||||
name = 'greatest'
|
||||
inherit_cache = True
|
||||
|
||||
@compiles(greatest)
|
||||
def default_greatest(element, compiler, **kw):
|
||||
return compiler.visit_function(element)
|
||||
|
||||
@compiles(greatest, 'sqlite')
|
||||
@compiles(greatest, 'mssql')
|
||||
@compiles(greatest, 'oracle')
|
||||
def case_greatest(element, compiler, **kw):
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return compiler.process(case([(arg1 > arg2, arg1)], else_=arg2), **kw)
|
||||
|
||||
Example usage::
|
||||
|
||||
Session.query(Account).\
|
||||
filter(
|
||||
greatest(
|
||||
Account.checking_balance,
|
||||
Account.savings_balance) > 10000
|
||||
)
|
||||
|
||||
"false" expression
|
||||
------------------
|
||||
|
||||
Render a "false" constant expression, rendering as "0" on platforms that
|
||||
don't have a "false" constant::
|
||||
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
|
||||
class sql_false(expression.ColumnElement):
|
||||
inherit_cache = True
|
||||
|
||||
@compiles(sql_false)
|
||||
def default_false(element, compiler, **kw):
|
||||
return "false"
|
||||
|
||||
@compiles(sql_false, 'mssql')
|
||||
@compiles(sql_false, 'mysql')
|
||||
@compiles(sql_false, 'oracle')
|
||||
def int_false(element, compiler, **kw):
|
||||
return "0"
|
||||
|
||||
Example usage::
|
||||
|
||||
from sqlalchemy import select, union_all
|
||||
|
||||
exp = union_all(
|
||||
select(users.c.name, sql_false().label("enrolled")),
|
||||
select(customers.c.name, customers.c.enrolled)
|
||||
)
|
||||
|
||||
"""
|
||||
from .. import exc
|
||||
from .. import util
|
||||
from ..sql import sqltypes
|
||||
|
||||
|
||||
def compiles(class_, *specs):
    """Register a function as a compiler for a
    given :class:`_expression.ClauseElement` type."""

    def decorate(fn):
        # dispatcher previously installed on this exact class by an
        # earlier @compiles call, if any (deliberately not inherited)
        dispatcher = class_.__dict__.get("_compiler_dispatcher", None)

        # the original compilation handler.  All ClauseElement classes
        # have one of these, but some TypeEngine classes will not.
        original_dispatch = getattr(class_, "_compiler_dispatch", None)

        if dispatcher is None:
            dispatcher = _dispatcher()

            if original_dispatch:

                def _wrap_existing_dispatch(element, compiler, **kw):
                    # delegate to the pre-existing handler, re-raising an
                    # unsupported-compilation error with a clearer message
                    try:
                        return original_dispatch(element, compiler, **kw)
                    except exc.UnsupportedCompilationError as uce:
                        util.raise_(
                            exc.UnsupportedCompilationError(
                                compiler,
                                type(element),
                                message="%s construct has no default "
                                "compilation handler." % type(element),
                            ),
                            from_=uce,
                        )

                dispatcher.specs["default"] = _wrap_existing_dispatch

            # TODO: why is the lambda needed ?
            setattr(
                class_,
                "_compiler_dispatch",
                lambda *arg, **kw: dispatcher(*arg, **kw),
            )
            setattr(class_, "_compiler_dispatcher", dispatcher)

        # register fn under each named dialect, or as the default
        # handler when no dialects were given
        for spec in specs if specs else ("default",):
            dispatcher.specs[spec] = fn
        return fn

    return decorate
|
||||
|
||||
|
||||
def deregister(class_):
    """Remove all custom compilers associated with a given
    :class:`_expression.ClauseElement` type.

    """

    if not hasattr(class_, "_compiler_dispatcher"):
        # nothing was ever registered via @compiles; no-op
        return
    # restore the dispatch handler that was in place before @compiles
    # was applied, and drop the dispatcher itself
    class_._compiler_dispatch = class_._original_compiler_dispatch
    del class_._compiler_dispatcher
|
||||
|
||||
|
||||
class _dispatcher(object):
|
||||
def __init__(self):
|
||||
self.specs = {}
|
||||
|
||||
def __call__(self, element, compiler, **kw):
|
||||
# TODO: yes, this could also switch off of DBAPI in use.
|
||||
fn = self.specs.get(compiler.dialect.name, None)
|
||||
if not fn:
|
||||
try:
|
||||
fn = self.specs["default"]
|
||||
except KeyError as ke:
|
||||
util.raise_(
|
||||
exc.UnsupportedCompilationError(
|
||||
compiler,
|
||||
type(element),
|
||||
message="%s construct has no default "
|
||||
"compilation handler." % type(element),
|
||||
),
|
||||
replace_context=ke,
|
||||
)
|
||||
|
||||
# if compilation includes add_to_result_map, collect add_to_result_map
|
||||
# arguments from the user-defined callable, which are probably none
|
||||
# because this is not public API. if it wasn't called, then call it
|
||||
# ourselves.
|
||||
arm = kw.get("add_to_result_map", None)
|
||||
if arm:
|
||||
arm_collection = []
|
||||
kw["add_to_result_map"] = lambda *args: arm_collection.append(args)
|
||||
|
||||
expr = fn(element, compiler, **kw)
|
||||
|
||||
if arm:
|
||||
if not arm_collection:
|
||||
arm_collection.append(
|
||||
(None, None, (element,), sqltypes.NULLTYPE)
|
||||
)
|
||||
for tup in arm_collection:
|
||||
arm(*tup)
|
||||
return expr
|
||||
64
lib/sqlalchemy/ext/declarative/__init__.py
Normal file
64
lib/sqlalchemy/ext/declarative/__init__.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# ext/declarative/__init__.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from .extensions import AbstractConcreteBase
|
||||
from .extensions import ConcreteBase
|
||||
from .extensions import DeferredReflection
|
||||
from .extensions import instrument_declarative
|
||||
from ... import util
|
||||
from ...orm.decl_api import as_declarative as _as_declarative
|
||||
from ...orm.decl_api import declarative_base as _declarative_base
|
||||
from ...orm.decl_api import DeclarativeMeta
|
||||
from ...orm.decl_api import declared_attr
|
||||
from ...orm.decl_api import has_inherited_table as _has_inherited_table
|
||||
from ...orm.decl_api import synonym_for as _synonym_for
|
||||
|
||||
|
||||
@util.moved_20(
    "The ``declarative_base()`` function is now available as "
    ":func:`sqlalchemy.orm.declarative_base`."
)
def declarative_base(*arg, **kw):
    """Legacy-location shim forwarding to
    :func:`sqlalchemy.orm.declarative_base`."""
    return _declarative_base(*arg, **kw)
|
||||
|
||||
|
||||
@util.moved_20(
    "The ``as_declarative()`` function is now available as "
    ":func:`sqlalchemy.orm.as_declarative`"
)
def as_declarative(*arg, **kw):
    """Legacy-location shim forwarding to
    :func:`sqlalchemy.orm.as_declarative`."""
    return _as_declarative(*arg, **kw)
|
||||
|
||||
|
||||
@util.moved_20(
    "The ``has_inherited_table()`` function is now available as "
    ":func:`sqlalchemy.orm.has_inherited_table`."
)
def has_inherited_table(*arg, **kw):
    """Legacy-location shim forwarding to
    :func:`sqlalchemy.orm.has_inherited_table`."""
    return _has_inherited_table(*arg, **kw)
|
||||
|
||||
|
||||
@util.moved_20(
    "The ``synonym_for()`` function is now available as "
    ":func:`sqlalchemy.orm.synonym_for`"
)
def synonym_for(*arg, **kw):
    """Legacy-location shim forwarding to
    :func:`sqlalchemy.orm.synonym_for`."""
    return _synonym_for(*arg, **kw)
|
||||
|
||||
|
||||
# Public names re-exported from this legacy module location.
__all__ = [
    "declarative_base", "synonym_for", "has_inherited_table",
    "instrument_declarative", "declared_attr", "as_declarative",
    "ConcreteBase", "AbstractConcreteBase", "DeclarativeMeta",
    "DeferredReflection",
]
|
||||
463
lib/sqlalchemy/ext/declarative/extensions.py
Normal file
463
lib/sqlalchemy/ext/declarative/extensions.py
Normal file
@@ -0,0 +1,463 @@
|
||||
# ext/declarative/extensions.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
"""Public API functions and helpers for declarative."""
|
||||
|
||||
|
||||
from ... import inspection
|
||||
from ... import util
|
||||
from ...orm import exc as orm_exc
|
||||
from ...orm import registry
|
||||
from ...orm import relationships
|
||||
from ...orm.base import _mapper_or_none
|
||||
from ...orm.clsregistry import _resolver
|
||||
from ...orm.decl_base import _DeferredMapperConfig
|
||||
from ...orm.util import polymorphic_union
|
||||
from ...schema import Table
|
||||
from ...util import OrderedDict
|
||||
|
||||
|
||||
@util.deprecated(
    "2.0",
    # NOTE: fixed typo "SQLAlhcemy" and unterminated :meth: role in this
    # deprecation message
    "the instrument_declarative function is deprecated "
    "and will be removed in SQLAlchemy 2.0. Please use "
    ":meth:`_orm.registry.map_declaratively`",
)
def instrument_declarative(cls, cls_registry, metadata):
    """Given a class, configure the class declaratively,
    using the given registry, which can be any dictionary, and
    MetaData object.

    """
    # construct a throwaway registry bound to the caller's class registry
    # and MetaData, and map the class through it
    registry(metadata=metadata, class_registry=cls_registry).map_declaratively(
        cls
    )
||||
|
||||
|
||||
class ConcreteBase(object):
    """A helper class for 'concrete' declarative mappings.

    :class:`.ConcreteBase` invokes :func:`.polymorphic_union` automatically
    against all tables mapped as subclasses of this class, via the
    ``__declare_first__()`` declarative hook (essentially the
    :meth:`.after_configured` event).  The base class itself is mapped to
    its own table; compare to :class:`.AbstractConcreteBase`, which is not.

    Example::

        from sqlalchemy.ext.declarative import ConcreteBase

        class Employee(ConcreteBase, Base):
            __tablename__ = 'employee'
            employee_id = Column(Integer, primary_key=True)
            name = Column(String(50))
            __mapper_args__ = {
                'polymorphic_identity':'employee',
                'concrete':True}

    The discriminator column produced by :func:`.polymorphic_union` is
    named ``type`` by default.  If a mapped column already uses that name,
    set ``_concrete_discriminator_name`` on the basemost class to choose a
    different one.

    .. versionadded:: 1.3.19 Added the ``_concrete_discriminator_name``
       attribute to :class:`_declarative.ConcreteBase` so that the
       virtual discriminator column name can be customized.

    .. versionchanged:: 1.4.2 The ``_concrete_discriminator_name`` attribute
       need only be placed on the basemost class to take correct effect for
       all subclasses.  An explicit error message is now raised if the
       mapped column names conflict with the discriminator name.

    .. seealso::

        :class:`.AbstractConcreteBase`

        :ref:`concrete_inheritance`

    """

    @classmethod
    def _create_polymorphic_union(cls, mappers, discriminator_name):
        # build a UNION of each concrete mapper's table, keyed by its
        # polymorphic identity
        table_map = OrderedDict(
            (mp.polymorphic_identity, mp.local_table) for mp in mappers
        )
        return polymorphic_union(table_map, discriminator_name, "pjoin")

    @classmethod
    def __declare_first__(cls):
        mapper = cls.__mapper__
        if mapper.with_polymorphic:
            # already configured with a polymorphic selectable; done
            return

        discriminator_name = (
            getattr(cls, "_concrete_discriminator_name", None) or "type"
        )

        pjoin = cls._create_polymorphic_union(
            list(mapper.self_and_descendants), discriminator_name
        )
        mapper._set_with_polymorphic(("*", pjoin))
        mapper._set_polymorphic_on(pjoin.c[discriminator_name])
|
||||
|
||||
class AbstractConcreteBase(ConcreteBase):
    """A helper class for 'concrete' declarative mappings.

    Like :class:`.ConcreteBase`, this applies :func:`.polymorphic_union`
    automatically against all tables mapped as subclasses, via the
    ``__declare_last__()`` hook.  Unlike :class:`.ConcreteBase`, the base
    class itself is not persisted to any table; it is mapped directly to
    the "polymorphic" selectable and is used only for selecting.

    .. note::

        Mapping of the base class is deferred until all subclasses have
        been configured, since the polymorphic selectable must include
        every subclass table.  If the first operation is to be a query
        against this base class, invoke :func:`_orm.configure_mappers`
        explicitly once all subclasses are defined::

            from sqlalchemy.orm import configure_mappers

            configure_mappers()

    Columns and properties declared on the abstract base (including
    ``@declared_attr`` relationships) take effect on the mapped
    subclasses, in the way described at :ref:`declarative_mixins`::

        class Employee(AbstractConcreteBase, Base):
            employee_id = Column(Integer, primary_key=True)

            @declared_attr
            def company(cls):
                return relationship("Company")

    .. versionchanged:: 1.0.0 - The mechanics of :class:`.AbstractConcreteBase`
       have been reworked to support relationships established directly
       on the abstract base, without any special configurational steps.

    .. seealso::

        :class:`.ConcreteBase`

        :ref:`concrete_inheritance`

    """

    # tells declarative not to generate a table for this class
    __no_table__ = True

    @classmethod
    def __declare_first__(cls):
        cls._sa_decl_prepare_nocascade()

    @classmethod
    def _sa_decl_prepare_nocascade(cls):
        if getattr(cls, "__mapper__", None):
            # already mapped
            return

        to_map = _DeferredMapperConfig.config_for_cls(cls)

        # can't rely on 'self_and_descendants' here
        # since technically an immediate subclass
        # might not be mapped, but a subclass
        # may be.
        mappers = []
        stack = list(cls.__subclasses__())
        while stack:
            klass = stack.pop()
            stack.extend(klass.__subclasses__())
            found = _mapper_or_none(klass)
            if found is not None:
                mappers.append(found)

        discriminator_name = (
            getattr(cls, "_concrete_discriminator_name", None) or "type"
        )
        pjoin = cls._create_polymorphic_union(mappers, discriminator_name)

        # For columns that were declared on the class, these
        # are normally ignored with the "__no_table__" mapping,
        # unless they have a different attribute key vs. col name
        # and are in the properties argument.
        # In that case, ensure we update the properties entry
        # to the correct column from the pjoin target table.
        declared_cols = set(to_map.declared_columns)
        for key, value in list(to_map.properties.items()):
            if value in declared_cols:
                to_map.properties[key] = pjoin.c[value.key]

        to_map.local_table = pjoin

        # wrap any user-supplied mapper_args so that polymorphic_on
        # points at the pjoin discriminator column
        original_args_fn = to_map.mapper_args_fn or dict

        def mapper_args():
            args = original_args_fn()
            args["polymorphic_on"] = pjoin.c[discriminator_name]
            return args

        to_map.mapper_args_fn = mapper_args

        mapped = to_map.map()

        # link each directly-inheriting concrete subclass mapper
        # to the newly created base mapper
        for scls in cls.__subclasses__():
            sm = _mapper_or_none(scls)
            if sm and sm.concrete and cls in scls.__bases__:
                sm._set_concrete_base(mapped)

    @classmethod
    def _sa_raise_deferred_config(cls):
        raise orm_exc.UnmappedClassError(
            cls,
            msg="Class %s is a subclass of AbstractConcreteBase and "
            "has a mapping pending until all subclasses are defined. "
            "Call the sqlalchemy.orm.configure_mappers() function after "
            "all subclasses have been defined to "
            "complete the mapping of this class."
            % orm_exc._safe_cls_name(cls),
        )
||||
|
||||
|
||||
class DeferredReflection(object):
    """A helper class for construction of mappings based on
    a deferred reflection step.

    Normally, using reflection with declarative requires a fully reflected
    :class:`_schema.Table` (or at least one with a primary key) as
    ``__table__`` at class declaration time, meaning the
    :class:`_engine.Engine` must be available then.  This mixin instead
    records the classes and constructs their mappers later, when
    :meth:`.prepare` is called with an engine; at that point all pending
    :class:`_schema.Table` objects are reflected at once::

        from sqlalchemy.ext.declarative import declarative_base
        from sqlalchemy.ext.declarative import DeferredReflection
        Base = declarative_base()

        class MyClass(DeferredReflection, Base):
            __tablename__ = 'mytable'

        engine = create_engine("someengine://...")
        DeferredReflection.prepare(engine)

    The mixin can be applied to individual classes, used as the base for
    the declarative base itself, or used in custom abstract classes.
    Using abstract bases allows separate groups of classes to be prepared
    against different engines::

        class ReflectedOne(DeferredReflection, Base):
            __abstract__ = True

        class ReflectedTwo(DeferredReflection, Base):
            __abstract__ = True

        ReflectedOne.prepare(engine_one)
        ReflectedTwo.prepare(engine_two)

    .. seealso::

        :ref:`orm_declarative_reflected_deferred_reflection` - in the
        :ref:`orm_declarative_table_config_toplevel` section.

    """

    @classmethod
    def prepare(cls, engine):
        """Reflect all :class:`_schema.Table` objects for all current
        :class:`.DeferredReflection` subclasses"""

        pending = _DeferredMapperConfig.classes_for_base(cls)

        with inspection.inspect(engine)._inspection_context() as insp:
            for config in pending:
                cls._sa_decl_prepare(config.local_table, insp)
                config.map()
                mapper = config.cls.__mapper__
                metadata = mapper.class_.metadata
                for rel in mapper._props.values():
                    # only many-to-many relationships with a "secondary"
                    # table need additional reflection handling
                    if not isinstance(
                        rel, relationships.RelationshipProperty
                    ):
                        continue
                    if rel.secondary is None:
                        continue

                    if isinstance(rel.secondary, Table):
                        cls._reflect_table(rel.secondary, insp)
                    elif isinstance(rel.secondary, str):
                        _, resolve_arg = _resolver(rel.parent.class_, rel)

                        rel.secondary = resolve_arg(rel.secondary)
                        rel.secondary._resolvers += (
                            cls._sa_deferred_table_resolver(
                                insp, metadata
                            ),
                        )

                        # controversy! do we resolve it here? or leave
                        # it deferred? I think doing it here is necessary
                        # so the connection does not leak.
                        rel.secondary = rel.secondary()

    @classmethod
    def _sa_deferred_table_resolver(cls, inspector, metadata):
        # return a callable that resolves a table name to a
        # freshly reflected Table in the given metadata
        def _resolve(key):
            table = Table(key, metadata)
            cls._reflect_table(table, inspector)
            return table

        return _resolve

    @classmethod
    def _sa_decl_prepare(cls, local_table, inspector):
        # autoload Table, which is already
        # present in the metadata. This
        # will fill in db-loaded columns
        # into the existing Table object.
        if local_table is not None:
            cls._reflect_table(local_table, inspector)

    @classmethod
    def _sa_raise_deferred_config(cls):
        raise orm_exc.UnmappedClassError(
            cls,
            msg="Class %s is a subclass of DeferredReflection. "
            "Mappings are not produced until the .prepare() "
            "method is called on the class hierarchy."
            % orm_exc._safe_cls_name(cls),
        )

    @classmethod
    def _reflect_table(cls, table, inspector):
        # merge db-reflected columns into the existing Table without
        # replacing columns that were declared explicitly
        Table(
            table.name,
            table.metadata,
            extend_existing=True,
            autoload_replace=False,
            autoload_with=inspector,
            schema=table.schema,
        )
|
||||
256
lib/sqlalchemy/ext/horizontal_shard.py
Normal file
256
lib/sqlalchemy/ext/horizontal_shard.py
Normal file
@@ -0,0 +1,256 @@
|
||||
# ext/horizontal_shard.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Horizontal sharding support.
|
||||
|
||||
Defines a rudimental 'horizontal sharding' system which allows a Session to
|
||||
distribute queries and persistence operations across multiple databases.
|
||||
|
||||
For a usage example, see the :ref:`examples_sharding` example included in
|
||||
the source distribution.
|
||||
|
||||
"""
|
||||
|
||||
from .. import event
|
||||
from .. import exc
|
||||
from .. import inspect
|
||||
from .. import util
|
||||
from ..orm.query import Query
|
||||
from ..orm.session import Session
|
||||
|
||||
__all__ = ["ShardedSession", "ShardedQuery"]
|
||||
|
||||
|
||||
class ShardedQuery(Query):
    """:class:`_query.Query` subclass used by :class:`.ShardedSession`."""

    def __init__(self, *args, **kwargs):
        super(ShardedQuery, self).__init__(*args, **kwargs)
        # mirror the chooser callables from the owning session
        sess = self.session
        self.id_chooser = sess.id_chooser
        self.query_chooser = sess.query_chooser
        self.execute_chooser = sess.execute_chooser
        self._shard_id = None

    def set_shard(self, shard_id):
        """Return a new query, limited to a single shard ID.

        All subsequent operations with the returned query will
        be against the single shard regardless of other state.

        The shard_id can be passed for a 2.0 style execution to the
        bind_arguments dictionary of :meth:`.Session.execute`::

            results = session.execute(
                stmt,
                bind_arguments={"shard_id": "my_shard"}
            )

        """
        # recorded as an execution option; consumed by
        # execute_and_instances at execution time
        return self.execution_options(_sa_shard_id=shard_id)
||||
|
||||
|
||||
class ShardedSession(Session):
    """A :class:`.Session` that distributes queries and persistence
    operations across multiple shard databases."""

    def __init__(
        self,
        shard_chooser,
        id_chooser,
        execute_chooser=None,
        shards=None,
        query_cls=ShardedQuery,
        **kwargs
    ):
        """Construct a ShardedSession.

        :param shard_chooser: A callable which, passed a Mapper, a mapped
          instance, and possibly a SQL clause, returns a shard ID. This id
          may be based off of the attributes present within the object, or on
          some round-robin scheme. If the scheme is based on a selection, it
          should set whatever state on the instance to mark it in the future as
          participating in that shard.

        :param id_chooser: A callable, passed a query and a tuple of identity
          values, which should return a list of shard ids where the ID might
          reside. The databases will be queried in the order of this listing.

        :param execute_chooser: For a given :class:`.ORMExecuteState`,
          returns the list of shard_ids
          where the query should be issued. Results from all shards returned
          will be combined together into a single listing.

          .. versionchanged:: 1.4  The ``execute_chooser`` parameter
             supersedes the ``query_chooser`` parameter.

        :param shards: A dictionary of string shard names
          to :class:`~sqlalchemy.engine.Engine` objects.

        """
        # legacy "query_chooser" parameter is accepted via **kwargs only
        query_chooser = kwargs.pop("query_chooser", None)
        super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs)

        # route all ORM executions through the sharding hook
        event.listen(
            self, "do_orm_execute", execute_and_instances, retval=True
        )
        self.shard_chooser = shard_chooser
        self.id_chooser = id_chooser

        if query_chooser:
            # fixed typo in deprecation message: "query_choser"
            util.warn_deprecated(
                "The ``query_chooser`` parameter is deprecated; "
                "please use ``execute_chooser``.",
                "1.4",
            )
            if execute_chooser:
                raise exc.ArgumentError(
                    "Can't pass query_chooser and execute_chooser "
                    "at the same time."
                )

            # adapt the legacy statement-based callable to the
            # ORMExecuteState-based hook
            def execute_chooser(orm_context):
                return query_chooser(orm_context.statement)

            self.execute_chooser = execute_chooser
        else:
            self.execute_chooser = execute_chooser
        self.query_chooser = query_chooser
        self.__binds = {}
        if shards is not None:
            for k in shards:
                self.bind_shard(k, shards[k])

    def _identity_lookup(
        self,
        mapper,
        primary_key_identity,
        identity_token=None,
        lazy_loaded_from=None,
        **kw
    ):
        """override the default :meth:`.Session._identity_lookup` method so
        that we search for a given non-token primary key identity across all
        possible identity tokens (e.g. shard ids).

        .. versionchanged:: 1.4  Moved :meth:`.Session._identity_lookup` from
           the :class:`_query.Query` object to the :class:`.Session`.

        """

        if identity_token is not None:
            return super(ShardedSession, self)._identity_lookup(
                mapper,
                primary_key_identity,
                identity_token=identity_token,
                **kw
            )
        else:
            q = self.query(mapper)
            if lazy_loaded_from:
                q = q._set_lazyload_from(lazy_loaded_from)
            # probe each candidate shard in order until a hit is found
            for shard_id in self.id_chooser(q, primary_key_identity):
                obj = super(ShardedSession, self)._identity_lookup(
                    mapper,
                    primary_key_identity,
                    identity_token=shard_id,
                    lazy_loaded_from=lazy_loaded_from,
                    **kw
                )
                if obj is not None:
                    return obj

            return None

    def _choose_shard_and_assign(self, mapper, instance, **kw):
        """Determine the shard for *instance*, preferring a shard id
        already recorded in its identity state, and record the choice
        on the instance."""
        if instance is not None:
            state = inspect(instance)
            if state.key:
                # persistent object: shard id lives in the identity key
                token = state.key[2]
                assert token is not None
                return token
            elif state.identity_token:
                return state.identity_token

        shard_id = self.shard_chooser(mapper, instance, **kw)
        if instance is not None:
            state.identity_token = shard_id
        return shard_id

    def connection_callable(
        self, mapper=None, instance=None, shard_id=None, **kwargs
    ):
        """Provide a :class:`_engine.Connection` to use in the unit of work
        flush process.

        """

        if shard_id is None:
            shard_id = self._choose_shard_and_assign(mapper, instance)

        if self.in_transaction():
            # reuse the connection of the current transaction
            return self.get_transaction().connection(mapper, shard_id=shard_id)
        else:
            return self.get_bind(
                mapper, shard_id=shard_id, instance=instance
            ).connect(**kwargs)

    def get_bind(
        self, mapper=None, shard_id=None, instance=None, clause=None, **kw
    ):
        """Return the engine bound to the given shard, choosing the
        shard via :meth:`._choose_shard_and_assign` when not supplied."""
        if shard_id is None:
            shard_id = self._choose_shard_and_assign(
                mapper, instance, clause=clause
            )
        return self.__binds[shard_id]

    def bind_shard(self, shard_id, bind):
        """Associate *bind* (an Engine) with the given shard id."""
        self.__binds[shard_id] = bind
||||
|
||||
|
||||
def execute_and_instances(orm_context):
    """``do_orm_execute`` event hook: route the statement to one shard
    (when explicitly selected) or fan it out across all shards chosen by
    the session's ``execute_chooser`` and merge the results."""

    load_options = update_options = active_options = None
    if orm_context.is_select:
        load_options = active_options = orm_context.load_options
    elif orm_context.is_update or orm_context.is_delete:
        update_options = active_options = orm_context.update_delete_options

    session = orm_context.session

    def invoke_on_shard(shard_id, load_options, update_options):
        # copy execution state so each shard gets its own identity token
        execution_options = dict(orm_context.local_execution_options)

        bind_arguments = dict(orm_context.bind_arguments)
        bind_arguments["shard_id"] = shard_id

        if orm_context.is_select:
            load_options += {"_refresh_identity_token": shard_id}
            execution_options["_sa_orm_load_options"] = load_options
        elif orm_context.is_update or orm_context.is_delete:
            update_options += {"_refresh_identity_token": shard_id}
            execution_options["_sa_orm_update_options"] = update_options

        return orm_context.invoke_statement(
            bind_arguments=bind_arguments, execution_options=execution_options
        )

    # was a single shard requested explicitly?
    if active_options and active_options._refresh_identity_token is not None:
        shard_id = active_options._refresh_identity_token
    elif "_sa_shard_id" in orm_context.execution_options:
        shard_id = orm_context.execution_options["_sa_shard_id"]
    elif "shard_id" in orm_context.bind_arguments:
        shard_id = orm_context.bind_arguments["shard_id"]
    else:
        shard_id = None

    if shard_id is not None:
        return invoke_on_shard(shard_id, load_options, update_options)

    results = [
        invoke_on_shard(chosen_id, load_options, update_options)
        for chosen_id in session.execute_chooser(orm_context)
    ]
    return results[0].merge(*results[1:])
||||
1206
lib/sqlalchemy/ext/hybrid.py
Normal file
1206
lib/sqlalchemy/ext/hybrid.py
Normal file
File diff suppressed because it is too large
Load Diff
352
lib/sqlalchemy/ext/indexable.py
Normal file
352
lib/sqlalchemy/ext/indexable.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# ext/indexable.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Define attributes on ORM-mapped classes that have "index" attributes for
|
||||
columns with :class:`_types.Indexable` types.
|
||||
|
||||
"index" means the attribute is associated with an element of an
|
||||
:class:`_types.Indexable` column with the predefined index to access it.
|
||||
The :class:`_types.Indexable` types include types such as
|
||||
:class:`_types.ARRAY`, :class:`_types.JSON` and
|
||||
:class:`_postgresql.HSTORE`.
|
||||
|
||||
|
||||
|
||||
The :mod:`~sqlalchemy.ext.indexable` extension provides
|
||||
:class:`_schema.Column`-like interface for any element of an
|
||||
:class:`_types.Indexable` typed column. In simple cases, it can be
|
||||
treated as a :class:`_schema.Column` - mapped attribute.
|
||||
|
||||
|
||||
.. versionadded:: 1.1
|
||||
|
||||
Synopsis
|
||||
========
|
||||
|
||||
Given ``Person`` as a model with a primary key and JSON data field.
|
||||
While this field may have any number of elements encoded within it,
|
||||
we would like to refer to the element called ``name`` individually
|
||||
as a dedicated attribute which behaves like a standalone column::
|
||||
|
||||
from sqlalchemy import Column, JSON, Integer
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.ext.indexable import index_property
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class Person(Base):
|
||||
__tablename__ = 'person'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
data = Column(JSON)
|
||||
|
||||
name = index_property('data', 'name')
|
||||
|
||||
|
||||
Above, the ``name`` attribute now behaves like a mapped column. We
|
||||
can compose a new ``Person`` and set the value of ``name``::
|
||||
|
||||
>>> person = Person(name='Alchemist')
|
||||
|
||||
The value is now accessible::
|
||||
|
||||
>>> person.name
|
||||
'Alchemist'
|
||||
|
||||
Behind the scenes, the JSON field was initialized to a new blank dictionary
|
||||
and the field was set::
|
||||
|
||||
>>> person.data
|
||||
{'name': 'Alchemist'}
|
||||
|
||||
The field is mutable in place::
|
||||
|
||||
>>> person.name = 'Renamed'
|
||||
>>> person.name
|
||||
'Renamed'
|
||||
>>> person.data
|
||||
{'name': 'Renamed'}
|
||||
|
||||
When using :class:`.index_property`, the change that we make to the indexable
|
||||
structure is also automatically tracked as history; we no longer need
|
||||
to use :class:`~.mutable.MutableDict` in order to track this change
|
||||
for the unit of work.
|
||||
|
||||
Deletions work normally as well::
|
||||
|
||||
>>> del person.name
|
||||
>>> person.data
|
||||
{}
|
||||
|
||||
Above, deletion of ``person.name`` deletes the value from the dictionary,
|
||||
but not the dictionary itself.
|
||||
|
||||
A missing key will produce ``AttributeError``::
|
||||
|
||||
>>> person = Person()
|
||||
>>> person.name
|
||||
...
|
||||
AttributeError: 'name'
|
||||
|
||||
Unless you set a default value::
|
||||
|
||||
>>> class Person(Base):
|
||||
>>> __tablename__ = 'person'
|
||||
>>>
|
||||
>>> id = Column(Integer, primary_key=True)
|
||||
>>> data = Column(JSON)
|
||||
>>>
|
||||
>>> name = index_property('data', 'name', default=None) # See default
|
||||
|
||||
>>> person = Person()
|
||||
>>> print(person.name)
|
||||
None
|
||||
|
||||
|
||||
The attributes are also accessible at the class level.
|
||||
Below, we illustrate ``Person.name`` used to generate
|
||||
an indexed SQL criteria::
|
||||
|
||||
>>> from sqlalchemy.orm import Session
|
||||
>>> session = Session()
|
||||
>>> query = session.query(Person).filter(Person.name == 'Alchemist')
|
||||
|
||||
The above query is equivalent to::
|
||||
|
||||
>>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist')
|
||||
|
||||
Multiple :class:`.index_property` objects can be chained to produce
|
||||
multiple levels of indexing::
|
||||
|
||||
from sqlalchemy import Column, JSON, Integer
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.ext.indexable import index_property
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class Person(Base):
|
||||
__tablename__ = 'person'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
data = Column(JSON)
|
||||
|
||||
birthday = index_property('data', 'birthday')
|
||||
year = index_property('birthday', 'year')
|
||||
month = index_property('birthday', 'month')
|
||||
day = index_property('birthday', 'day')
|
||||
|
||||
Above, a query such as::
|
||||
|
||||
q = session.query(Person).filter(Person.year == '1980')
|
||||
|
||||
On a PostgreSQL backend, the above query will render as::
|
||||
|
||||
SELECT person.id, person.data
|
||||
FROM person
|
||||
WHERE person.data -> %(data_1)s -> %(param_1)s = %(param_2)s
|
||||
|
||||
Default Values
|
||||
==============
|
||||
|
||||
:class:`.index_property` includes special behaviors for when the indexed
|
||||
data structure does not exist, and a set operation is called:
|
||||
|
||||
* For an :class:`.index_property` that is given an integer index value,
|
||||
the default data structure will be a Python list of ``None`` values,
|
||||
at least as long as the index value; the value is then set at its
|
||||
place in the list. This means for an index value of zero, the list
|
||||
will be initialized to ``[None]`` before setting the given value,
|
||||
and for an index value of five, the list will be initialized to
|
||||
``[None, None, None, None, None]`` before setting the fifth element
|
||||
to the given value. Note that an existing list is **not** extended
|
||||
in place to receive a value.
|
||||
|
||||
* for an :class:`.index_property` that is given any other kind of index
|
||||
value (e.g. strings usually), a Python dictionary is used as the
|
||||
default data structure.
|
||||
|
||||
* The default data structure can be set to any Python callable using the
|
||||
:paramref:`.index_property.datatype` parameter, overriding the previous
|
||||
rules.
|
||||
|
||||
|
||||
Subclassing
|
||||
===========
|
||||
|
||||
:class:`.index_property` can be subclassed, in particular for the common
|
||||
use case of providing coercion of values or SQL expressions as they are
|
||||
accessed. Below is a common recipe for use with a PostgreSQL JSON type,
|
||||
where we want to also include automatic casting plus ``astext()``::
|
||||
|
||||
class pg_json_property(index_property):
|
||||
def __init__(self, attr_name, index, cast_type):
|
||||
super(pg_json_property, self).__init__(attr_name, index)
|
||||
self.cast_type = cast_type
|
||||
|
||||
def expr(self, model):
|
||||
expr = super(pg_json_property, self).expr(model)
|
||||
return expr.astext.cast(self.cast_type)
|
||||
|
||||
The above subclass can be used with the PostgreSQL-specific
|
||||
version of :class:`_postgresql.JSON`::
|
||||
|
||||
from sqlalchemy import Column, Integer
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.dialects.postgresql import JSON
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class Person(Base):
|
||||
__tablename__ = 'person'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
data = Column(JSON)
|
||||
|
||||
age = pg_json_property('data', 'age', Integer)
|
||||
|
||||
The ``age`` attribute at the instance level works as before; however
|
||||
when rendering SQL, PostgreSQL's ``->>`` operator will be used
|
||||
for indexed access, instead of the usual index operator of ``->``::
|
||||
|
||||
>>> query = session.query(Person).filter(Person.age < 20)
|
||||
|
||||
The above query will render::
|
||||
|
||||
SELECT person.id, person.data
|
||||
FROM person
|
||||
WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s
|
||||
|
||||
""" # noqa
|
||||
from __future__ import absolute_import
|
||||
|
||||
from .. import inspect
|
||||
from .. import util
|
||||
from ..ext.hybrid import hybrid_property
|
||||
from ..orm.attributes import flag_modified
|
||||
|
||||
|
||||
__all__ = ["index_property"]
|
||||
|
||||
|
||||
class index_property(hybrid_property):  # noqa
    """A property generator. The generated property describes an object
    attribute that corresponds to an :class:`_types.Indexable`
    column.

    .. versionadded:: 1.1

    .. seealso::

        :mod:`sqlalchemy.ext.indexable`

    """

    # Sentinel marking "no default was supplied".  Compared with ``is``
    # (identity), never ``==``: a user-provided default object with an
    # unusual ``__eq__`` (e.g. a numpy array) must not break the check.
    _NO_DEFAULT_ARGUMENT = object()

    def __init__(
        self,
        attr_name,
        index,
        default=_NO_DEFAULT_ARGUMENT,
        datatype=None,
        mutable=True,
        onebased=True,
    ):
        """Create a new :class:`.index_property`.

        :param attr_name:
            An attribute name of an `Indexable` typed column, or other
            attribute that returns an indexable structure.
        :param index:
            The index to be used for getting and setting this value.  This
            should be the Python-side index value for integers.
        :param default:
            A value which will be returned instead of `AttributeError`
            when there is not a value at given index.
        :param datatype: default datatype to use when the field is empty.
            By default, this is derived from the type of index used; a
            Python list for an integer index, or a Python dictionary for
            any other style of index.   For a list, the list will be
            initialized to a list of None values that is at least
            ``index`` elements long.
        :param mutable: if False, writes and deletes to the attribute will
            be disallowed.
        :param onebased: assume the SQL representation of this value is
            one-based; that is, the first index in SQL is 1, not zero.
        """

        # When immutable, register only the getter/expression so that set
        # and delete raise AttributeError at the descriptor level.
        if mutable:
            super(index_property, self).__init__(
                self.fget, self.fset, self.fdel, self.expr
            )
        else:
            super(index_property, self).__init__(
                self.fget, None, None, self.expr
            )
        self.attr_name = attr_name
        self.index = index
        self.default = default
        is_numeric = isinstance(index, int)
        # "one-based" SQL indexing only makes sense for integer indexes.
        onebased = is_numeric and onebased

        if datatype is not None:
            self.datatype = datatype
        else:
            if is_numeric:
                # Default list container, sized so that ``index`` is a
                # valid position immediately after construction.
                self.datatype = lambda: [None for x in range(index + 1)]
            else:
                self.datatype = dict
        self.onebased = onebased

    def _fget_default(self, err=None):
        """Return the configured default, or raise ``AttributeError``
        (chained from *err*) when no default was supplied."""
        # Fixed: identity check against the sentinel; ``==`` could invoke
        # a user default's ``__eq__`` and misbehave or raise.
        if self.default is self._NO_DEFAULT_ARGUMENT:
            util.raise_(AttributeError(self.attr_name), replace_context=err)
        else:
            return self.default

    def fget(self, instance):
        """Instance-level get: return the element at ``self.index`` of the
        mapped container, or the default when missing."""
        attr_name = self.attr_name
        column_value = getattr(instance, attr_name)
        if column_value is None:
            return self._fget_default()
        try:
            value = column_value[self.index]
        except (KeyError, IndexError) as err:
            return self._fget_default(err)
        else:
            return value

    def fset(self, instance, value):
        """Instance-level set: write ``value`` at ``self.index``, creating
        the default container first when the column is None."""
        attr_name = self.attr_name
        column_value = getattr(instance, attr_name, None)
        if column_value is None:
            column_value = self.datatype()
            setattr(instance, attr_name, column_value)
        column_value[self.index] = value
        # Re-assign so ORM attribute events fire even for in-place mutation.
        setattr(instance, attr_name, column_value)
        if attr_name in inspect(instance).mapper.attrs:
            # Flag the mapped attribute dirty; in-place container changes
            # are otherwise invisible to the unit of work.
            flag_modified(instance, attr_name)

    def fdel(self, instance):
        """Instance-level delete: remove the element at ``self.index``,
        leaving the container itself in place."""
        attr_name = self.attr_name
        column_value = getattr(instance, attr_name)
        if column_value is None:
            raise AttributeError(self.attr_name)
        try:
            del column_value[self.index]
        except (KeyError, IndexError) as err:
            # Fixed: also catch IndexError so an out-of-range delete on a
            # list container raises AttributeError, consistent with fget
            # and with the documented contract of this extension.
            util.raise_(AttributeError(self.attr_name), replace_context=err)
        else:
            setattr(instance, attr_name, column_value)
            flag_modified(instance, attr_name)

    def expr(self, model):
        """Class-level expression: produce a SQL index expression against
        the mapped column."""
        column = getattr(model, self.attr_name)
        index = self.index
        if self.onebased:
            # SQL collections (e.g. PostgreSQL ARRAY) index from 1.
            index += 1
        return column[index]
|
||||
416
lib/sqlalchemy/ext/instrumentation.py
Normal file
416
lib/sqlalchemy/ext/instrumentation.py
Normal file
@@ -0,0 +1,416 @@
|
||||
"""Extensible class instrumentation.
|
||||
|
||||
The :mod:`sqlalchemy.ext.instrumentation` package provides for alternate
|
||||
systems of class instrumentation within the ORM. Class instrumentation
|
||||
refers to how the ORM places attributes on the class which maintain
|
||||
data and track changes to that data, as well as event hooks installed
|
||||
on the class.
|
||||
|
||||
.. note::
|
||||
The extension package is provided for the benefit of integration
|
||||
with other object management packages, which already perform
|
||||
their own instrumentation. It is not intended for general use.
|
||||
|
||||
For examples of how the instrumentation extension is used,
|
||||
see the example :ref:`examples_instrumentation`.
|
||||
|
||||
"""
|
||||
import weakref
|
||||
|
||||
from .. import util
|
||||
from ..orm import attributes
|
||||
from ..orm import base as orm_base
|
||||
from ..orm import collections
|
||||
from ..orm import exc as orm_exc
|
||||
from ..orm import instrumentation as orm_instrumentation
|
||||
from ..orm.instrumentation import _default_dict_getter
|
||||
from ..orm.instrumentation import _default_manager_getter
|
||||
from ..orm.instrumentation import _default_state_getter
|
||||
from ..orm.instrumentation import ClassManager
|
||||
from ..orm.instrumentation import InstrumentationFactory
|
||||
|
||||
|
||||
INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__"
"""Attribute, elects custom instrumentation when present on a mapped class.

Allows a class to specify a slightly or wildly different technique for
tracking changes made to mapped attributes and collections.

Only one instrumentation implementation is allowed in a given object
inheritance hierarchy.

The value of this attribute must be a callable and will be passed a class
object. The callable must return one of:

- An instance of an :class:`.InstrumentationManager` or subclass
- An object implementing all or some of InstrumentationManager (TODO)
- A dictionary of callables, implementing all or some of the above (TODO)
- An instance of a :class:`.ClassManager` or subclass

This attribute is consulted by SQLAlchemy instrumentation
resolution, once the :mod:`sqlalchemy.ext.instrumentation` module
has been imported. If custom finders are installed in the global
instrumentation_finders list, they may or may not choose to honor this
attribute.

"""


def find_native_user_instrumentation_hook(cls):
    """Return the user-specified instrumentation manager for *cls*.

    Looks up the ``__sa_instrumentation_manager__`` attribute; ``None``
    is returned when the class declares no custom instrumentation.
    """
    hook = getattr(cls, INSTRUMENTATION_MANAGER, None)
    return hook
|
||||
|
||||
|
||||
# Default finder chain: consult only the class's own
# ``__sa_instrumentation_manager__`` attribute.
instrumentation_finders = [find_native_user_instrumentation_hook]
"""An extensible sequence of callables which return instrumentation
implementations

When a class is registered, each callable will be passed a class object.
If None is returned, the
next finder in the sequence is consulted. Otherwise the return must be an
instrumentation factory that follows the same guidelines as
sqlalchemy.ext.instrumentation.INSTRUMENTATION_MANAGER.

By default, the only finder is find_native_user_instrumentation_hook, which
searches for INSTRUMENTATION_MANAGER. If all finders return None, standard
ClassManager instrumentation is used.

"""
|
||||
|
||||
|
||||
class ExtendedInstrumentationRegistry(InstrumentationFactory):
    """Extends :class:`.InstrumentationFactory` with additional
    bookkeeping, to accommodate multiple types of
    class managers.

    """

    # Per-class lookup functions, keyed weakly so registry entries do not
    # keep mapped classes alive.
    _manager_finders = weakref.WeakKeyDictionary()
    _state_finders = weakref.WeakKeyDictionary()
    _dict_finders = weakref.WeakKeyDictionary()
    # True once any non-default ClassManager has been installed and the
    # module-level lookup functions were swapped for the extended ones.
    _extended = False

    def _locate_extended_factory(self, class_):
        # Consult each registered finder in order; first non-None factory
        # wins.  The for/else returns (None, None) when no finder matched.
        for finder in instrumentation_finders:
            factory = finder(class_)
            if factory is not None:
                manager = self._extended_class_manager(class_, factory)
                return manager, factory
        else:
            return None, None

    def _check_conflicts(self, class_, factory):
        # Reject a class whose inheritance hierarchy already uses a
        # different instrumentation factory than the one given.
        existing_factories = self._collect_management_factories_for(
            class_
        ).difference([factory])
        if existing_factories:
            raise TypeError(
                "multiple instrumentation implementations specified "
                "in %s inheritance hierarchy: %r"
                % (class_.__name__, list(existing_factories))
            )

    def _extended_class_manager(self, class_, factory):
        manager = factory(class_)
        if not isinstance(manager, ClassManager):
            # Wrap a user-style InstrumentationManager in the adapter so it
            # presents the ClassManager interface.
            manager = _ClassInstrumentationAdapter(class_, manager)

        if factory != ClassManager and not self._extended:
            # somebody invoked a custom ClassManager.
            # reinstall global "getter" functions with the more
            # expensive ones.
            self._extended = True
            _install_instrumented_lookups()

        # Record per-class lookup callables for later state/dict/manager
        # resolution.
        self._manager_finders[class_] = manager.manager_getter()
        self._state_finders[class_] = manager.state_getter()
        self._dict_finders[class_] = manager.dict_getter()
        return manager

    def _collect_management_factories_for(self, cls):
        """Return a collection of factories in play or specified for a
        hierarchy.

        Traverses the entire inheritance graph of a cls and returns a
        collection of instrumentation factories for those classes. Factories
        are extracted from active ClassManagers, if available, otherwise
        instrumentation_finders is consulted.

        """
        hierarchy = util.class_hierarchy(cls)
        factories = set()
        for member in hierarchy:
            manager = self.manager_of_class(member)
            if manager is not None:
                factories.add(manager.factory)
            else:
                # No active manager: ask the finders what factory *would*
                # be used for this member (for/else leaves factory None).
                for finder in instrumentation_finders:
                    factory = finder(member)
                    if factory is not None:
                        break
                else:
                    factory = None
                factories.add(factory)
        factories.discard(None)
        return factories

    def unregister(self, class_):
        # Drop the per-class lookup entries along with the base
        # registration.
        super(ExtendedInstrumentationRegistry, self).unregister(class_)
        if class_ in self._manager_finders:
            del self._manager_finders[class_]
            del self._state_finders[class_]
            del self._dict_finders[class_]

    def manager_of_class(self, cls):
        if cls is None:
            return None
        try:
            finder = self._manager_finders.get(cls, _default_manager_getter)
        except TypeError:
            # due to weakref lookup on invalid object
            return None
        else:
            return finder(cls)

    def state_of(self, instance):
        # Resolve the InstanceState via the class-specific getter,
        # defaulting to the standard one.
        if instance is None:
            raise AttributeError("None has no persistent state.")
        return self._state_finders.get(
            instance.__class__, _default_state_getter
        )(instance)

    def dict_of(self, instance):
        # Resolve the instance dict via the class-specific getter,
        # defaulting to the standard one.
        if instance is None:
            raise AttributeError("None has no persistent state.")
        return self._dict_finders.get(
            instance.__class__, _default_dict_getter
        )(instance)
|
||||
|
||||
|
||||
# Importing this module globally replaces the ORM's default
# instrumentation factory and finder list with the extended versions,
# module-wide.
orm_instrumentation._instrumentation_factory = (
    _instrumentation_factory
) = ExtendedInstrumentationRegistry()
orm_instrumentation.instrumentation_finders = instrumentation_finders
|
||||
|
||||
|
||||
class InstrumentationManager(object):
    """User-defined class instrumentation extension.

    :class:`.InstrumentationManager` can be subclassed in order
    to change
    how class instrumentation proceeds. This class exists for
    the purposes of integration with other object management
    frameworks which would like to entirely modify the
    instrumentation methodology of the ORM, and is not intended
    for regular usage.  For interception of class instrumentation
    events, see :class:`.InstrumentationEvents`.

    The API for this class should be considered as semi-stable,
    and may change slightly with new releases.

    """

    # r4361 added a mandatory (cls) constructor to this interface.
    # given that, perhaps class_ should be dropped from all of these
    # signatures.

    def __init__(self, class_):
        # The target class is not retained; subclasses may override.
        pass

    def manage(self, class_, manager):
        # Store the ClassManager on the class under a well-known name.
        setattr(class_, "_default_class_manager", manager)

    def unregister(self, class_, manager):
        # Remove the stored ClassManager from the class.
        delattr(class_, "_default_class_manager")

    def manager_getter(self, class_):
        # Return a callable resolving the stored ClassManager for a class.
        return lambda cls: cls._default_class_manager

    def instrument_attribute(self, class_, key, inst):
        # No-op hook; subclasses may intercept attribute instrumentation.
        pass

    def post_configure_attribute(self, class_, key, inst):
        # No-op hook; subclasses may intercept post-configuration.
        pass

    def install_descriptor(self, class_, key, inst):
        # Place the instrumented descriptor on the class.
        setattr(class_, key, inst)

    def uninstall_descriptor(self, class_, key):
        # Remove the instrumented descriptor from the class.
        delattr(class_, key)

    def install_member(self, class_, key, implementation):
        # Place an arbitrary member (e.g. __init__ wrapper) on the class.
        setattr(class_, key, implementation)

    def uninstall_member(self, class_, key):
        # Remove a previously installed member from the class.
        delattr(class_, key)

    def instrument_collection_class(self, class_, key, collection_class):
        # Delegate to the ORM's standard collection instrumentation.
        return collections.prepare_instrumentation(collection_class)

    def get_instance_dict(self, class_, instance):
        # Default storage for instance attributes is the plain __dict__.
        return instance.__dict__

    def initialize_instance_dict(self, class_, instance):
        # No-op hook; subclasses with alternate storage may prepare it here.
        pass

    def install_state(self, class_, instance, state):
        # Attach the InstanceState to the instance under a well-known name.
        setattr(instance, "_default_state", state)

    def remove_state(self, class_, instance):
        # Detach the InstanceState from the instance.
        delattr(instance, "_default_state")

    def state_getter(self, class_):
        # Return a callable resolving the InstanceState from an instance.
        def get(instance):
            return instance._default_state

        return get

    def dict_getter(self, class_):
        # Return a callable resolving the instance dict via
        # get_instance_dict(), so subclass overrides are honored.
        def get(inst):
            return self.get_instance_dict(class_, inst)

        return get
|
||||
|
||||
|
||||
class _ClassInstrumentationAdapter(ClassManager):
    """Adapts a user-defined InstrumentationManager to a ClassManager."""

    def __init__(self, class_, override):
        # ``override`` is the user's InstrumentationManager; cache its
        # state/dict getters before base-class setup runs.
        self._adapted = override
        self._get_state = self._adapted.state_getter(class_)
        self._get_dict = self._adapted.dict_getter(class_)

        ClassManager.__init__(self, class_)

    def manage(self):
        # Delegate registration of this manager to the adapted object.
        self._adapted.manage(self.class_, self)

    def unregister(self):
        # Delegate de-registration to the adapted object.
        self._adapted.unregister(self.class_, self)

    def manager_getter(self):
        return self._adapted.manager_getter(self.class_)

    def instrument_attribute(self, key, inst, propagated=False):
        ClassManager.instrument_attribute(self, key, inst, propagated)
        # Only notify the adapted manager for directly-declared
        # attributes, not ones propagated from superclasses.
        if not propagated:
            self._adapted.instrument_attribute(self.class_, key, inst)

    def post_configure_attribute(self, key):
        super(_ClassInstrumentationAdapter, self).post_configure_attribute(key)
        self._adapted.post_configure_attribute(self.class_, key, self[key])

    def install_descriptor(self, key, inst):
        self._adapted.install_descriptor(self.class_, key, inst)

    def uninstall_descriptor(self, key):
        self._adapted.uninstall_descriptor(self.class_, key)

    def install_member(self, key, implementation):
        self._adapted.install_member(self.class_, key, implementation)

    def uninstall_member(self, key):
        self._adapted.uninstall_member(self.class_, key)

    def instrument_collection_class(self, key, collection_class):
        return self._adapted.instrument_collection_class(
            self.class_, key, collection_class
        )

    def initialize_collection(self, key, state, factory):
        # Use the adapted manager's hook when provided, otherwise fall
        # back to the standard ClassManager behavior.
        delegate = getattr(self._adapted, "initialize_collection", None)
        if delegate:
            return delegate(key, state, factory)
        else:
            return ClassManager.initialize_collection(
                self, key, state, factory
            )

    def new_instance(self, state=None):
        # Bypass __init__ so no user constructor side effects occur.
        instance = self.class_.__new__(self.class_)
        self.setup_instance(instance, state)
        return instance

    def _new_state_if_none(self, instance):
        """Install a default InstanceState if none is present.

        A private convenience method used by the __init__ decorator.
        """
        if self.has_state(instance):
            return False
        else:
            return self.setup_instance(instance)

    def setup_instance(self, instance, state=None):
        # Let the adapted manager prepare its attribute storage first.
        self._adapted.initialize_instance_dict(self.class_, instance)

        if state is None:
            state = self._state_constructor(instance, self)

        # the given instance is assumed to have no state
        self._adapted.install_state(self.class_, instance, state)
        return state

    def teardown_instance(self, instance):
        self._adapted.remove_state(self.class_, instance)

    def has_state(self, instance):
        # Probe via the cached state getter; absence of state raises one
        # of the NO_STATE exception types.
        try:
            self._get_state(instance)
        except orm_exc.NO_STATE:
            return False
        else:
            return True

    def state_getter(self):
        return self._get_state

    def dict_getter(self):
        return self._get_dict
|
||||
|
||||
|
||||
def _install_instrumented_lookups():
    """Replace global class/object management functions
    with ExtendedInstrumentationRegistry implementations, which
    allow multiple types of class managers to be present,
    at the cost of performance.

    This function is called only by ExtendedInstrumentationRegistry
    and unit tests specific to this behavior.

    The _reinstall_default_lookups() function can be called
    after this one to re-establish the default functions.

    """
    # Route every lookup through the extended registry.
    extended_lookups = {
        "instance_state": _instrumentation_factory.state_of,
        "instance_dict": _instrumentation_factory.dict_of,
        "manager_of_class": _instrumentation_factory.manager_of_class,
    }
    _install_lookups(extended_lookups)
|
||||
|
||||
|
||||
def _reinstall_default_lookups():
    """Restore simplified lookups."""
    # Put the fast, default getter functions back in place.
    default_lookups = {
        "instance_state": _default_state_getter,
        "instance_dict": _default_dict_getter,
        "manager_of_class": _default_manager_getter,
    }
    _install_lookups(default_lookups)
    # Allow the extended lookups to be re-installed later if needed.
    _instrumentation_factory._extended = False
|
||||
|
||||
|
||||
def _install_lookups(lookups):
    """Publish the given lookup callables to this module and to the
    ``orm.base``, ``orm.attributes`` and ``orm.instrumentation``
    modules, which alias these functions at module level."""
    global instance_state, instance_dict, manager_of_class
    instance_state = lookups["instance_state"]
    instance_dict = lookups["instance_dict"]
    manager_of_class = lookups["manager_of_class"]
    # Mirror each lookup onto every ORM module that exposes it.
    for module in (orm_base, attributes, orm_instrumentation):
        module.instance_state = instance_state
        module.instance_dict = instance_dict
        module.manager_of_class = manager_of_class
|
||||
958
lib/sqlalchemy/ext/mutable.py
Normal file
958
lib/sqlalchemy/ext/mutable.py
Normal file
@@ -0,0 +1,958 @@
|
||||
# ext/mutable.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
r"""Provide support for tracking of in-place changes to scalar values,
|
||||
which are propagated into ORM change events on owning parent objects.
|
||||
|
||||
.. _mutable_scalars:
|
||||
|
||||
Establishing Mutability on Scalar Column Values
|
||||
===============================================
|
||||
|
||||
A typical example of a "mutable" structure is a Python dictionary.
|
||||
Following the example introduced in :ref:`types_toplevel`, we
|
||||
begin with a custom type that marshals Python dictionaries into
|
||||
JSON strings before being persisted::
|
||||
|
||||
from sqlalchemy.types import TypeDecorator, VARCHAR
|
||||
import json
|
||||
|
||||
class JSONEncodedDict(TypeDecorator):
|
||||
"Represents an immutable structure as a json-encoded string."
|
||||
|
||||
impl = VARCHAR
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is not None:
|
||||
value = json.dumps(value)
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value is not None:
|
||||
value = json.loads(value)
|
||||
return value
|
||||
|
||||
The usage of ``json`` is only for the purposes of example. The
|
||||
:mod:`sqlalchemy.ext.mutable` extension can be used
|
||||
with any type whose target Python type may be mutable, including
|
||||
:class:`.PickleType`, :class:`_postgresql.ARRAY`, etc.
|
||||
|
||||
When using the :mod:`sqlalchemy.ext.mutable` extension, the value itself
|
||||
tracks all parents which reference it. Below, we illustrate a simple
|
||||
version of the :class:`.MutableDict` dictionary object, which applies
|
||||
the :class:`.Mutable` mixin to a plain Python dictionary::
|
||||
|
||||
from sqlalchemy.ext.mutable import Mutable
|
||||
|
||||
class MutableDict(Mutable, dict):
|
||||
@classmethod
|
||||
def coerce(cls, key, value):
|
||||
"Convert plain dictionaries to MutableDict."
|
||||
|
||||
if not isinstance(value, MutableDict):
|
||||
if isinstance(value, dict):
|
||||
return MutableDict(value)
|
||||
|
||||
# this call will raise ValueError
|
||||
return Mutable.coerce(key, value)
|
||||
else:
|
||||
return value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
"Detect dictionary set events and emit change events."
|
||||
|
||||
dict.__setitem__(self, key, value)
|
||||
self.changed()
|
||||
|
||||
def __delitem__(self, key):
|
||||
"Detect dictionary del events and emit change events."
|
||||
|
||||
dict.__delitem__(self, key)
|
||||
self.changed()
|
||||
|
||||
The above dictionary class takes the approach of subclassing the Python
|
||||
built-in ``dict`` to produce a dict
|
||||
subclass which routes all mutation events through ``__setitem__``. There are
|
||||
variants on this approach, such as subclassing ``UserDict.UserDict`` or
|
||||
``collections.MutableMapping``; the part that's important to this example is
|
||||
that the :meth:`.Mutable.changed` method is called whenever an in-place
|
||||
change to the datastructure takes place.
|
||||
|
||||
We also redefine the :meth:`.Mutable.coerce` method which will be used to
|
||||
convert any values that are not instances of ``MutableDict``, such
|
||||
as the plain dictionaries returned by the ``json`` module, into the
|
||||
appropriate type. Defining this method is optional; we could just as well
|
||||
created our ``JSONEncodedDict`` such that it always returns an instance
|
||||
of ``MutableDict``, and additionally ensured that all calling code
|
||||
uses ``MutableDict`` explicitly. When :meth:`.Mutable.coerce` is not
|
||||
overridden, any values applied to a parent object which are not instances
|
||||
of the mutable type will raise a ``ValueError``.
|
||||
|
||||
Our new ``MutableDict`` type offers a class method
|
||||
:meth:`~.Mutable.as_mutable` which we can use within column metadata
|
||||
to associate with types. This method grabs the given type object or
|
||||
class and associates a listener that will detect all future mappings
|
||||
of this type, applying event listening instrumentation to the mapped
|
||||
attribute. Such as, with classical table metadata::
|
||||
|
||||
from sqlalchemy import Table, Column, Integer
|
||||
|
||||
my_data = Table('my_data', metadata,
|
||||
Column('id', Integer, primary_key=True),
|
||||
Column('data', MutableDict.as_mutable(JSONEncodedDict))
|
||||
)
|
||||
|
||||
Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict``
|
||||
(if the type object was not an instance already), which will intercept any
|
||||
attributes which are mapped against this type. Below we establish a simple
|
||||
mapping against the ``my_data`` table::
|
||||
|
||||
from sqlalchemy import mapper
|
||||
|
||||
class MyDataClass(object):
|
||||
pass
|
||||
|
||||
# associates mutation listeners with MyDataClass.data
|
||||
mapper(MyDataClass, my_data)
|
||||
|
||||
The ``MyDataClass.data`` member will now be notified of in place changes
|
||||
to its value.
|
||||
|
||||
There's no difference in usage when using declarative::
|
||||
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class MyDataClass(Base):
|
||||
__tablename__ = 'my_data'
|
||||
id = Column(Integer, primary_key=True)
|
||||
data = Column(MutableDict.as_mutable(JSONEncodedDict))
|
||||
|
||||
Any in-place changes to the ``MyDataClass.data`` member
|
||||
will flag the attribute as "dirty" on the parent object::
|
||||
|
||||
>>> from sqlalchemy.orm import Session
|
||||
|
||||
>>> sess = Session()
|
||||
>>> m1 = MyDataClass(data={'value1':'foo'})
|
||||
>>> sess.add(m1)
|
||||
>>> sess.commit()
|
||||
|
||||
>>> m1.data['value1'] = 'bar'
|
||||
>>> assert m1 in sess.dirty
|
||||
True
|
||||
|
||||
The ``MutableDict`` can be associated with all future instances
|
||||
of ``JSONEncodedDict`` in one step, using
|
||||
:meth:`~.Mutable.associate_with`. This is similar to
|
||||
:meth:`~.Mutable.as_mutable` except it will intercept all occurrences
|
||||
of ``MutableDict`` in all mappings unconditionally, without
|
||||
the need to declare it individually::
|
||||
|
||||
MutableDict.associate_with(JSONEncodedDict)
|
||||
|
||||
class MyDataClass(Base):
|
||||
__tablename__ = 'my_data'
|
||||
id = Column(Integer, primary_key=True)
|
||||
data = Column(JSONEncodedDict)
|
||||
|
||||
|
||||
Supporting Pickling
|
||||
--------------------
|
||||
|
||||
The key to the :mod:`sqlalchemy.ext.mutable` extension relies upon the
|
||||
placement of a ``weakref.WeakKeyDictionary`` upon the value object, which
|
||||
stores a mapping of parent mapped objects keyed to the attribute name under
|
||||
which they are associated with this value. ``WeakKeyDictionary`` objects are
|
||||
not picklable, due to the fact that they contain weakrefs and function
|
||||
callbacks. In our case, this is a good thing, since if this dictionary were
|
||||
picklable, it could lead to an excessively large pickle size for our value
|
||||
objects that are pickled by themselves outside of the context of the parent.
|
||||
The developer responsibility here is only to provide a ``__getstate__`` method
|
||||
that excludes the :meth:`~MutableBase._parents` collection from the pickle
|
||||
stream::
|
||||
|
||||
class MyMutableType(Mutable):
|
||||
def __getstate__(self):
|
||||
d = self.__dict__.copy()
|
||||
d.pop('_parents', None)
|
||||
return d
|
||||
|
||||
With our dictionary example, we need to return the contents of the dict itself
|
||||
(and also restore them on __setstate__)::
|
||||
|
||||
class MutableDict(Mutable, dict):
|
||||
# ....
|
||||
|
||||
def __getstate__(self):
|
||||
return dict(self)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.update(state)
|
||||
|
||||
In the case that our mutable value object is pickled as it is attached to one
|
||||
or more parent objects that are also part of the pickle, the :class:`.Mutable`
|
||||
mixin will re-establish the :attr:`.Mutable._parents` collection on each value
|
||||
object as the owning parents themselves are unpickled.
|
||||
|
||||
Receiving Events
|
||||
----------------
|
||||
|
||||
The :meth:`.AttributeEvents.modified` event handler may be used to receive
|
||||
an event when a mutable scalar emits a change event. This event handler
|
||||
is called when the :func:`.attributes.flag_modified` function is called
|
||||
from within the mutable extension::
|
||||
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy import event
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class MyDataClass(Base):
|
||||
__tablename__ = 'my_data'
|
||||
id = Column(Integer, primary_key=True)
|
||||
data = Column(MutableDict.as_mutable(JSONEncodedDict))
|
||||
|
||||
@event.listens_for(MyDataClass.data, "modified")
|
||||
def modified_json(instance):
|
||||
print("json value modified:", instance.data)
|
||||
|
||||
.. _mutable_composites:
|
||||
|
||||
Establishing Mutability on Composites
|
||||
=====================================
|
||||
|
||||
Composites are a special ORM feature which allow a single scalar attribute to
|
||||
be assigned an object value which represents information "composed" from one
|
||||
or more columns from the underlying mapped table. The usual example is that of
|
||||
a geometric "point", and is introduced in :ref:`mapper_composite`.
|
||||
|
||||
As is the case with :class:`.Mutable`, the user-defined composite class
|
||||
subclasses :class:`.MutableComposite` as a mixin, and detects and delivers
|
||||
change events to its parents via the :meth:`.MutableComposite.changed` method.
|
||||
In the case of a composite class, the detection is usually via the usage of
|
||||
Python descriptors (i.e. ``@property``), or alternatively via the special
|
||||
Python method ``__setattr__()``. Below we expand upon the ``Point`` class
|
||||
introduced in :ref:`mapper_composite` to subclass :class:`.MutableComposite`
|
||||
and to also route attribute set events via ``__setattr__`` to the
|
||||
:meth:`.MutableComposite.changed` method::
|
||||
|
||||
from sqlalchemy.ext.mutable import MutableComposite
|
||||
|
||||
class Point(MutableComposite):
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
"Intercept set events"
|
||||
|
||||
# set the attribute
|
||||
object.__setattr__(self, key, value)
|
||||
|
||||
# alert all parents to the change
|
||||
self.changed()
|
||||
|
||||
def __composite_values__(self):
|
||||
return self.x, self.y
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, Point) and \
|
||||
other.x == self.x and \
|
||||
other.y == self.y
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
The :class:`.MutableComposite` class uses a Python metaclass to automatically
|
||||
establish listeners for any usage of :func:`_orm.composite` that specifies our
|
||||
``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` class,
|
||||
listeners are established which will route change events from ``Point``
|
||||
objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes::
|
||||
|
||||
from sqlalchemy.orm import composite, mapper
|
||||
from sqlalchemy import Table, Column
|
||||
|
||||
vertices = Table('vertices', metadata,
|
||||
Column('id', Integer, primary_key=True),
|
||||
Column('x1', Integer),
|
||||
Column('y1', Integer),
|
||||
Column('x2', Integer),
|
||||
Column('y2', Integer),
|
||||
)
|
||||
|
||||
class Vertex(object):
|
||||
pass
|
||||
|
||||
mapper(Vertex, vertices, properties={
|
||||
'start': composite(Point, vertices.c.x1, vertices.c.y1),
|
||||
'end': composite(Point, vertices.c.x2, vertices.c.y2)
|
||||
})
|
||||
|
||||
Any in-place changes to the ``Vertex.start`` or ``Vertex.end`` members
|
||||
will flag the attribute as "dirty" on the parent object::
|
||||
|
||||
>>> from sqlalchemy.orm import Session
|
||||
|
||||
>>> sess = Session()
|
||||
>>> v1 = Vertex(start=Point(3, 4), end=Point(12, 15))
|
||||
>>> sess.add(v1)
|
||||
>>> sess.commit()
|
||||
|
||||
>>> v1.end.x = 8
|
||||
>>> assert v1 in sess.dirty
|
||||
True
|
||||
|
||||
Coercing Mutable Composites
|
||||
---------------------------
|
||||
|
||||
The :meth:`.MutableBase.coerce` method is also supported on composite types.
|
||||
In the case of :class:`.MutableComposite`, the :meth:`.MutableBase.coerce`
|
||||
method is only called for attribute set operations, not load operations.
|
||||
Overriding the :meth:`.MutableBase.coerce` method is essentially equivalent
|
||||
to using a :func:`.validates` validation routine for all attributes which
|
||||
make use of the custom composite type::
|
||||
|
||||
class Point(MutableComposite):
|
||||
# other Point methods
|
||||
# ...
|
||||
|
||||
def coerce(cls, key, value):
|
||||
if isinstance(value, tuple):
|
||||
value = Point(*value)
|
||||
elif not isinstance(value, Point):
|
||||
raise ValueError("tuple or Point expected")
|
||||
return value
|
||||
|
||||
Supporting Pickling
|
||||
--------------------
|
||||
|
||||
As is the case with :class:`.Mutable`, the :class:`.MutableComposite` helper
|
||||
class uses a ``weakref.WeakKeyDictionary`` available via the
|
||||
:meth:`MutableBase._parents` attribute which isn't picklable. If we need to
|
||||
pickle instances of ``Point`` or its owning class ``Vertex``, we at least need
|
||||
to define a ``__getstate__`` that doesn't include the ``_parents`` dictionary.
|
||||
Below we define both a ``__getstate__`` and a ``__setstate__`` that package up
|
||||
the minimal form of our ``Point`` class::
|
||||
|
||||
class Point(MutableComposite):
|
||||
# ...
|
||||
|
||||
def __getstate__(self):
|
||||
return self.x, self.y
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.x, self.y = state
|
||||
|
||||
As with :class:`.Mutable`, the :class:`.MutableComposite` augments the
|
||||
pickling process of the parent's object-relational state so that the
|
||||
:meth:`MutableBase._parents` collection is restored to all ``Point`` objects.
|
||||
|
||||
"""
|
||||
from collections import defaultdict
|
||||
import weakref
|
||||
|
||||
from .. import event
|
||||
from .. import inspect
|
||||
from .. import types
|
||||
from ..orm import Mapper
|
||||
from ..orm import mapper
|
||||
from ..orm.attributes import flag_modified
|
||||
from ..sql.base import SchemaEventTarget
|
||||
from ..util import memoized_property
|
||||
|
||||
|
||||
class MutableBase(object):
    """Common base class to :class:`.Mutable`
    and :class:`.MutableComposite`.

    Provides the shared ``_parents`` bookkeeping collection, the default
    :meth:`.coerce` hook, and the event-listener installation machinery
    used by both scalar and composite mutable types.

    """

    @memoized_property
    def _parents(self):
        """Dictionary of parent object's :class:`.InstanceState`->attribute
        name on the parent.

        This attribute is a so-called "memoized" property.  It initializes
        itself with a new ``weakref.WeakKeyDictionary`` the first time
        it is accessed, returning the same object upon subsequent access.

        .. versionchanged:: 1.4 the :class:`.InstanceState` is now used
           as the key in the weak dictionary rather than the instance
           itself.

        """
        # Weak keys allow parent states to be garbage collected without
        # this value object keeping them alive, and without requiring any
        # explicit de-registration step.
        return weakref.WeakKeyDictionary()

    @classmethod
    def coerce(cls, key, value):
        """Given a value, coerce it into the target type.

        Can be overridden by custom subclasses to coerce incoming
        data into a particular type.

        By default, raises ``ValueError``.

        This method is called in different scenarios depending on if
        the parent class is of type :class:`.Mutable` or of type
        :class:`.MutableComposite`.  In the case of the former, it is called
        for both attribute-set operations as well as during ORM loading
        operations.  For the latter, it is only called during attribute-set
        operations; the mechanics of the :func:`.composite` construct
        handle coercion during load operations.

        :param key: string name of the ORM-mapped attribute being set.
        :param value: the incoming value.
        :return: the method should return the coerced value, or raise
         ``ValueError`` if the coercion cannot be completed.

        """
        # None is always passed through unchanged; only non-None values
        # of the wrong type are rejected.
        if value is None:
            return None
        msg = "Attribute '%s' does not accept objects of type %s"
        raise ValueError(msg % (key, type(value)))

    @classmethod
    def _get_listen_keys(cls, attribute):
        """Given a descriptor attribute, return a ``set()`` of the attribute
        keys which indicate a change in the state of this attribute.

        This is normally just ``set([attribute.key])``, but can be overridden
        to provide for additional keys.  E.g. a :class:`.MutableComposite`
        augments this set with the attribute keys associated with the columns
        that comprise the composite value.

        This collection is consulted in the case of intercepting the
        :meth:`.InstanceEvents.refresh` and
        :meth:`.InstanceEvents.refresh_flush` events, which pass along a list
        of attribute names that have been refreshed; the list is compared
        against this set to determine if action needs to be taken.

        .. versionadded:: 1.0.5

        """
        return {attribute.key}

    @classmethod
    def _listen_on_attribute(cls, attribute, coerce, parent_cls):
        """Establish this type as a mutation listener for the given
        mapped descriptor.

        :param attribute: the mapped descriptor (InstrumentedAttribute).
        :param coerce: when True, values seen during load are passed
         through :meth:`.coerce` before being tracked.
        :param parent_cls: the class this listener installation targets;
         subclasses of ``attribute.class_`` are skipped and rely on
         event propagation instead (see below).

        """
        key = attribute.key
        if parent_cls is not attribute.class_:
            return

        # rely on "propagate" here
        parent_cls = attribute.class_

        listen_keys = cls._get_listen_keys(attribute)

        def load(state, *args):
            """Listen for objects loaded or refreshed.

            Wrap the target data member's value with
            ``Mutable``.

            """
            val = state.dict.get(key, None)
            if val is not None:
                if coerce:
                    # replace the loaded value in-place with its
                    # coerced (tracked) form
                    val = cls.coerce(key, val)
                    state.dict[key] = val
                # register this parent state so changed() can reach it
                val._parents[state] = key

        def load_attrs(state, ctx, attrs):
            # refresh/refresh_flush supply the refreshed attribute names;
            # an empty/None ``attrs`` means a full refresh.
            if not attrs or listen_keys.intersection(attrs):
                load(state)

        def set_(target, value, oldvalue, initiator):
            """Listen for set/replace events on the target
            data member.

            Establish a weak reference to the parent object
            on the incoming value, remove it for the one
            outgoing.

            """
            if value is oldvalue:
                return value

            if not isinstance(value, cls):
                value = cls.coerce(key, value)
            if value is not None:
                value._parents[target] = key
            if isinstance(oldvalue, cls):
                # the replaced value no longer tracks this parent
                oldvalue._parents.pop(inspect(target), None)
            # retval=True below means this return value becomes the
            # attribute's new value.
            return value

        def pickle(state, state_dict):
            # stash tracked values in the pickled state dict, grouped by
            # attribute key, so unpickle() can re-link them; the
            # WeakKeyDictionary itself is not picklable.
            val = state.dict.get(key, None)
            if val is not None:
                if "ext.mutable.values" not in state_dict:
                    state_dict["ext.mutable.values"] = defaultdict(list)
                state_dict["ext.mutable.values"][key].append(val)

        def unpickle(state, state_dict):
            if "ext.mutable.values" in state_dict:
                collection = state_dict["ext.mutable.values"]
                if isinstance(collection, list):
                    # legacy format
                    for val in collection:
                        val._parents[state] = key
                else:
                    for val in state_dict["ext.mutable.values"][key]:
                        val._parents[state] = key

        # raw=True delivers InstanceState objects rather than instances;
        # propagate=True extends the listeners to subclasses.
        event.listen(parent_cls, "load", load, raw=True, propagate=True)
        event.listen(
            parent_cls, "refresh", load_attrs, raw=True, propagate=True
        )
        event.listen(
            parent_cls, "refresh_flush", load_attrs, raw=True, propagate=True
        )
        event.listen(
            attribute, "set", set_, raw=True, retval=True, propagate=True
        )
        event.listen(parent_cls, "pickle", pickle, raw=True, propagate=True)
        event.listen(
            parent_cls, "unpickle", unpickle, raw=True, propagate=True
        )
|
||||
|
||||
|
||||
class Mutable(MutableBase):
    """Mixin that defines transparent propagation of change
    events to a parent object.

    See the example in :ref:`mutable_scalars` for usage information.

    """

    def changed(self):
        """Subclasses should call this method whenever change events occur."""
        # _parents maps InstanceState -> attribute key; flag each owning
        # attribute as modified so the unit of work picks it up.
        for parent, key in self._parents.items():
            flag_modified(parent.obj(), key)

    @classmethod
    def associate_with_attribute(cls, attribute):
        """Establish this type as a mutation listener for the given
        mapped descriptor.

        """
        # coerce=True: values are coerced on load as well as on set.
        cls._listen_on_attribute(attribute, True, attribute.class_)

    @classmethod
    def associate_with(cls, sqltype):
        """Associate this wrapper with all future mapped columns
        of the given type.

        This is a convenience method that calls
        ``associate_with_attribute`` automatically.

        .. warning::

           The listeners established by this method are *global*
           to all mappers, and are *not* garbage collected.  Only use
           :meth:`.associate_with` for types that are permanent to an
           application, not with ad-hoc types else this will cause unbounded
           growth in memory usage.

        """

        def listen_for_type(mapper, class_):
            if mapper.non_primary:
                return
            # match by isinstance: any column whose type is an instance of
            # ``sqltype`` (a type class here) becomes tracked.
            for prop in mapper.column_attrs:
                if isinstance(prop.columns[0].type, sqltype):
                    cls.associate_with_attribute(getattr(class_, prop.key))

        # ``mapper`` below is the module-level function imported at the top
        # of this module (the listener parameter shadows it only inside
        # listen_for_type); listening on it targets all mappers, per the
        # warning above.
        event.listen(mapper, "mapper_configured", listen_for_type)

    @classmethod
    def as_mutable(cls, sqltype):
        """Associate a SQL type with this mutable Python type.

        This establishes listeners that will detect ORM mappings against
        the given type, adding mutation event trackers to those mappings.

        The type is returned, unconditionally as an instance, so that
        :meth:`.as_mutable` can be used inline::

            Table('mytable', metadata,
                Column('id', Integer, primary_key=True),
                Column('data', MyMutableType.as_mutable(PickleType))
            )

        Note that the returned type is always an instance, even if a class
        is given, and that only columns which are declared specifically with
        that type instance receive additional instrumentation.

        To associate a particular mutable type with all occurrences of a
        particular type, use the :meth:`.Mutable.associate_with` classmethod
        of the particular :class:`.Mutable` subclass to establish a global
        association.

        .. warning::

           The listeners established by this method are *global*
           to all mappers, and are *not* garbage collected.  Only use
           :meth:`.as_mutable` for types that are permanent to an application,
           not with ad-hoc types else this will cause unbounded growth
           in memory usage.

        """
        sqltype = types.to_instance(sqltype)

        # a SchemaType will be copied when the Column is copied,
        # and we'll lose our ability to link that type back to the original.
        # so track our original type w/ columns
        if isinstance(sqltype, SchemaEventTarget):

            @event.listens_for(sqltype, "before_parent_attach")
            def _add_column_memo(sqltyp, parent):
                parent.info["_ext_mutable_orig_type"] = sqltyp

            schema_event_check = True
        else:
            schema_event_check = False

        def listen_for_type(mapper, class_):
            if mapper.non_primary:
                return
            for prop in mapper.column_attrs:
                # match either the column-info memo recorded above (for
                # SchemaType copies) or the exact type instance.
                if (
                    schema_event_check
                    and hasattr(prop.expression, "info")
                    and prop.expression.info.get("_ext_mutable_orig_type")
                    is sqltype
                ) or (prop.columns[0].type is sqltype):
                    cls.associate_with_attribute(getattr(class_, prop.key))

        # as in associate_with(): global listener on all mappers.
        event.listen(mapper, "mapper_configured", listen_for_type)

        return sqltype
|
||||
|
||||
|
||||
class MutableComposite(MutableBase):
    """Mixin that defines transparent propagation of change
    events on a SQLAlchemy "composite" object to its
    owning parent or parents.

    See the example in :ref:`mutable_composites` for usage information.

    """

    @classmethod
    def _get_listen_keys(cls, attribute):
        # in addition to the composite attribute itself, a refresh of any
        # of the individual columns comprising the composite indicates a
        # state change.
        return {attribute.key}.union(attribute.property._attribute_keys)

    def changed(self):
        """Subclasses should call this method whenever change events occur."""
        # For composites, propagation works by writing the current
        # composite values back onto the parent's individual column
        # attributes, which triggers the ORM's normal change tracking.
        for parent, key in self._parents.items():

            prop = parent.mapper.get_property(key)
            for value, attr_name in zip(
                self.__composite_values__(), prop._attribute_keys
            ):
                setattr(parent.obj(), attr_name, value)
|
||||
|
||||
|
||||
def _setup_composite_listener():
    """Install a single global ``mapper_configured`` hook that wires up
    mutation listeners for any composite property whose composite class
    subclasses :class:`.MutableComposite`."""

    def _listen_for_type(mapper, class_):
        for prop in mapper.iterate_properties:
            # composite properties expose a ``composite_class``
            # attribute; everything else is skipped.
            composite_cls = getattr(prop, "composite_class", None)
            if not isinstance(composite_cls, type):
                continue
            if not issubclass(composite_cls, MutableComposite):
                continue
            # coerce=False: composite load-time coercion is handled by
            # the composite() construct itself.
            composite_cls._listen_on_attribute(
                getattr(class_, prop.key), False, class_
            )

    # guard against double registration on repeated module setup
    if not event.contains(Mapper, "mapper_configured", _listen_for_type):
        event.listen(Mapper, "mapper_configured", _listen_for_type)


_setup_composite_listener()
|
||||
|
||||
|
||||
class MutableDict(Mutable, dict):
    """A dictionary type that implements :class:`.Mutable`.

    ``MutableDict`` emits change events to the underlying mapping
    whenever its contents are altered in place -- when entries are
    added, replaced, or removed.

    Mutability tracking covers the dictionary itself only, **not** the
    values stored inside it.  It is therefore not sufficient on its own
    for tracking deep changes in a *recursive* structure such as JSON;
    for that, subclass :class:`.MutableDict` and coerce stored values so
    that they too are "mutable" and propagate events to their parent.

    .. seealso::

        :class:`.MutableList`

        :class:`.MutableSet`

    """

    @classmethod
    def coerce(cls, key, value):
        """Convert plain dictionary to instance of this class."""
        if isinstance(value, cls):
            return value
        if isinstance(value, dict):
            return cls(value)
        # fall through to the default rejection behavior
        return Mutable.coerce(key, value)

    def __setitem__(self, key, value):
        """Detect dictionary set events and emit change events."""
        dict.__setitem__(self, key, value)
        self.changed()

    def __delitem__(self, key):
        """Detect dictionary del events and emit change events."""
        dict.__delitem__(self, key)
        self.changed()

    def setdefault(self, key, value):
        # note: a change event is emitted even when ``key`` was already
        # present, matching the conservative upstream behavior.
        present = dict.setdefault(self, key, value)
        self.changed()
        return present

    def update(self, *a, **kw):
        dict.update(self, *a, **kw)
        self.changed()

    def pop(self, *arg):
        popped = dict.pop(self, *arg)
        self.changed()
        return popped

    def popitem(self):
        pair = dict.popitem(self)
        self.changed()
        return pair

    def clear(self):
        dict.clear(self)
        self.changed()

    def __getstate__(self):
        # pickle only the plain mapping contents; the unpicklable
        # _parents WeakKeyDictionary is deliberately excluded.
        return dict(self)

    def __setstate__(self, state):
        self.update(state)
|
||||
|
||||
|
||||
class MutableList(Mutable, list):
    """A list type that implements :class:`.Mutable`.

    ``MutableList`` emits change events to the underlying mapping
    whenever its contents are altered in place -- element assignment,
    insertion, removal, sorting, and reversal all notify the parent.

    Mutability tracking covers the list itself only, **not** the
    elements stored inside it.  It is therefore not sufficient on its
    own for tracking deep changes in a *recursive* structure such as
    JSON; for that, subclass :class:`.MutableList` and coerce stored
    values so that they too are "mutable" and propagate events to
    their parent.

    .. versionadded:: 1.1

    .. seealso::

        :class:`.MutableDict`

        :class:`.MutableSet`

    """

    @classmethod
    def coerce(cls, index, value):
        """Convert plain list to instance of this class."""
        if isinstance(value, cls):
            return value
        if isinstance(value, list):
            return cls(value)
        # fall through to the default rejection behavior
        return Mutable.coerce(index, value)

    def __reduce_ex__(self, proto):
        # pickle via the constructor with a plain-list copy of the
        # contents; the unpicklable _parents collection is excluded.
        return (self.__class__, (list(self),))

    # needed for backwards compatibility with
    # older pickles
    def __setstate__(self, state):
        self[:] = state

    def __setitem__(self, index, value):
        """Detect list set events and emit change events."""
        list.__setitem__(self, index, value)
        self.changed()

    def __setslice__(self, start, end, value):
        """Detect list set events and emit change events (Python 2 only)."""
        list.__setslice__(self, start, end, value)
        self.changed()

    def __delitem__(self, index):
        """Detect list del events and emit change events."""
        list.__delitem__(self, index)
        self.changed()

    def __delslice__(self, start, end):
        """Detect list del events and emit change events (Python 2 only)."""
        list.__delslice__(self, start, end)
        self.changed()

    def __iadd__(self, x):
        # ``+=`` delegates to extend(), which emits the change event.
        self.extend(x)
        return self

    def append(self, x):
        list.append(self, x)
        self.changed()

    def extend(self, x):
        list.extend(self, x)
        self.changed()

    def insert(self, i, x):
        list.insert(self, i, x)
        self.changed()

    def pop(self, *arg):
        popped = list.pop(self, *arg)
        self.changed()
        return popped

    def remove(self, i):
        list.remove(self, i)
        self.changed()

    def clear(self):
        list.clear(self)
        self.changed()

    def sort(self, **kw):
        list.sort(self, **kw)
        self.changed()

    def reverse(self):
        list.reverse(self)
        self.changed()
|
||||
|
||||
|
||||
class MutableSet(Mutable, set):
    """A set type that implements :class:`.Mutable`.

    ``MutableSet`` emits change events to the underlying mapping
    whenever its contents are altered in place -- when members are
    added or removed, including via the in-place set operators.

    Mutability tracking covers the set itself only, **not** the
    elements stored inside it.  It is therefore not sufficient on its
    own for tracking deep changes in a *recursive* mutable structure;
    for that, subclass :class:`.MutableSet` and coerce stored values
    so that they too are "mutable" and propagate events to their
    parent.

    .. versionadded:: 1.1

    .. seealso::

        :class:`.MutableDict`

        :class:`.MutableList`

    """

    @classmethod
    def coerce(cls, index, value):
        """Convert plain set to instance of this class."""
        if isinstance(value, cls):
            return value
        if isinstance(value, set):
            return cls(value)
        # fall through to the default rejection behavior
        return Mutable.coerce(index, value)

    def add(self, elem):
        set.add(self, elem)
        self.changed()

    def remove(self, elem):
        set.remove(self, elem)
        self.changed()

    def discard(self, elem):
        set.discard(self, elem)
        self.changed()

    def pop(self, *arg):
        popped = set.pop(self, *arg)
        self.changed()
        return popped

    def clear(self):
        set.clear(self)
        self.changed()

    def update(self, *arg):
        set.update(self, *arg)
        self.changed()

    def intersection_update(self, *arg):
        set.intersection_update(self, *arg)
        self.changed()

    def difference_update(self, *arg):
        set.difference_update(self, *arg)
        self.changed()

    def symmetric_difference_update(self, *arg):
        set.symmetric_difference_update(self, *arg)
        self.changed()

    # the in-place operators delegate to the *_update methods above,
    # which emit the change event.
    def __ior__(self, other):
        self.update(other)
        return self

    def __iand__(self, other):
        self.intersection_update(other)
        return self

    def __ixor__(self, other):
        self.symmetric_difference_update(other)
        return self

    def __isub__(self, other):
        self.difference_update(other)
        return self

    def __getstate__(self):
        # pickle only the plain set contents; the unpicklable _parents
        # collection is excluded.
        return set(self)

    def __setstate__(self, state):
        self.update(state)

    def __reduce_ex__(self, proto):
        return (self.__class__, (list(self),))
|
||||
0
lib/sqlalchemy/ext/mypy/__init__.py
Normal file
0
lib/sqlalchemy/ext/mypy/__init__.py
Normal file
299
lib/sqlalchemy/ext/mypy/apply.py
Normal file
299
lib/sqlalchemy/ext/mypy/apply.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# ext/mypy/apply.py
|
||||
# Copyright (C) 2021 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from mypy.nodes import ARG_NAMED_OPT
|
||||
from mypy.nodes import Argument
|
||||
from mypy.nodes import AssignmentStmt
|
||||
from mypy.nodes import CallExpr
|
||||
from mypy.nodes import ClassDef
|
||||
from mypy.nodes import MDEF
|
||||
from mypy.nodes import MemberExpr
|
||||
from mypy.nodes import NameExpr
|
||||
from mypy.nodes import RefExpr
|
||||
from mypy.nodes import StrExpr
|
||||
from mypy.nodes import SymbolTableNode
|
||||
from mypy.nodes import TempNode
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.nodes import Var
|
||||
from mypy.plugin import SemanticAnalyzerPluginInterface
|
||||
from mypy.plugins.common import add_method_to_class
|
||||
from mypy.types import AnyType
|
||||
from mypy.types import get_proper_type
|
||||
from mypy.types import Instance
|
||||
from mypy.types import NoneTyp
|
||||
from mypy.types import ProperType
|
||||
from mypy.types import TypeOfAny
|
||||
from mypy.types import UnboundType
|
||||
from mypy.types import UnionType
|
||||
|
||||
from . import infer
|
||||
from . import util
|
||||
from .names import NAMED_TYPE_SQLA_MAPPED
|
||||
|
||||
|
||||
def apply_mypy_mapped_attr(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    item: Union[NameExpr, StrExpr],
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """Apply Mapped[] typing to a class attribute named in
    ``_mypy_mapped_attrs``.

    :param cls: the class definition being processed.
    :param api: mypy semantic-analyzer plugin interface, used for
     error reporting.
    :param item: the element of the ``_mypy_mapped_attrs`` list naming
     the attribute, either a name reference or a string literal.
    :param attributes: accumulator of attributes discovered for this
     class; the new attribute is appended here.

    """
    # resolve the attribute name from either expression form; any other
    # node type is silently skipped.
    if isinstance(item, NameExpr):
        name = item.name
    elif isinstance(item, StrExpr):
        name = item.value
    else:
        return None

    # locate the assignment statement in the class body whose left-hand
    # side matches the requested name.
    for stmt in cls.defs.body:
        if (
            isinstance(stmt, AssignmentStmt)
            and isinstance(stmt.lvalues[0], NameExpr)
            and stmt.lvalues[0].name == name
        ):
            break
    else:
        # no matching assignment found: report and bail out
        util.fail(api, "Can't find mapped attribute {}".format(name), cls)
        return None

    # the statement must carry an explicit annotation to be usable
    if stmt.type is None:
        util.fail(
            api,
            "Statement linked from _mypy_mapped_attrs has no "
            "typing information",
            stmt,
        )
        return None

    left_hand_explicit_type = get_proper_type(stmt.type)
    assert isinstance(
        left_hand_explicit_type, (Instance, UnionType, UnboundType)
    )

    # record the attribute for later class-metadata bookkeeping
    attributes.append(
        util.SQLAlchemyAttribute(
            name=name,
            line=item.line,
            column=item.column,
            typ=left_hand_explicit_type,
            info=cls.info,
        )
    )

    # rewrite the statement's type as Mapped[<explicit type>]
    # (apply_type_to_mapped_statement is defined later in this module)
    apply_type_to_mapped_statement(
        api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
    )
|
||||
|
||||
|
||||
def re_apply_declarative_assignments(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """For multiple class passes, re-apply our left-hand side types as mypy
    seems to reset them in place.

    ``attributes`` holds the previously scanned attribute types for this
    class; each matching class-body assignment has its ``Mapped[]`` type
    re-installed.  When a previously recorded ``UnboundType`` can now be
    resolved to something more specific, the right-hand side is re-scanned
    and both the attribute record and the stored class metadata are
    refreshed.
    """
    mapped_attr_lookup = {attr.name: attr for attr in attributes}
    update_cls_metadata = False

    for stmt in cls.defs.body:
        # for a re-apply, all of our statements are AssignmentStmt;
        # @declared_attr calls will have been converted and this
        # currently seems to be preserved by mypy (but who knows if this
        # will change).
        if (
            isinstance(stmt, AssignmentStmt)
            and isinstance(stmt.lvalues[0], NameExpr)
            and stmt.lvalues[0].name in mapped_attr_lookup
            and isinstance(stmt.lvalues[0].node, Var)
        ):

            left_node = stmt.lvalues[0].node
            python_type_for_type = mapped_attr_lookup[
                stmt.lvalues[0].name
            ].type

            left_node_proper_type = get_proper_type(left_node.type)

            # if we have scanned an UnboundType and now there's a more
            # specific type than UnboundType, call the re-scan so we
            # can get that set up correctly
            if (
                isinstance(python_type_for_type, UnboundType)
                and not isinstance(left_node_proper_type, UnboundType)
                and (
                    # only re-scan our own rewritten form:
                    # Mapped._empty_constructor(<original CallExpr>)
                    isinstance(stmt.rvalue, CallExpr)
                    and isinstance(stmt.rvalue.callee, MemberExpr)
                    and isinstance(stmt.rvalue.callee.expr, NameExpr)
                    and stmt.rvalue.callee.expr.node is not None
                    and stmt.rvalue.callee.expr.node.fullname
                    == NAMED_TYPE_SQLA_MAPPED
                    and stmt.rvalue.callee.name == "_empty_constructor"
                    and isinstance(stmt.rvalue.args[0], CallExpr)
                    and isinstance(stmt.rvalue.args[0].callee, RefExpr)
                )
            ):

                python_type_for_type = (
                    infer.infer_type_from_right_hand_nameexpr(
                        api,
                        stmt,
                        left_node,
                        left_node_proper_type,
                        stmt.rvalue.args[0].callee,
                    )
                )

                # nothing better was resolved; leave this attribute alone
                if python_type_for_type is None or isinstance(
                    python_type_for_type, UnboundType
                ):
                    continue

                # update the SQLAlchemyAttribute with the better information
                mapped_attr_lookup[
                    stmt.lvalues[0].name
                ].type = python_type_for_type

                update_cls_metadata = True

            # re-install Mapped[<type>] on the left-hand Var
            if python_type_for_type is not None:
                left_node.type = api.named_type(
                    NAMED_TYPE_SQLA_MAPPED, [python_type_for_type]
                )

    # persist the improved attribute types back into the class metadata
    if update_cls_metadata:
        util.set_mapped_attributes(cls.info, attributes)
|
||||
|
||||
|
||||
def apply_type_to_mapped_statement(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    lvalue: NameExpr,
    left_hand_explicit_type: Optional[ProperType],
    python_type_for_type: Optional[ProperType],
) -> None:
    """Apply the Mapped[<type>] annotation and right hand object to a
    declarative assignment statement.

    This converts a Python declarative class statement such as::

        class User(Base):
            # ...

            attrname = Column(Integer)

    To one that describes the final Python behavior to Mypy::

        class User(Base):
            # ...

            attrname : Mapped[Optional[int]] = <meaningless temp node>

    ``left_hand_explicit_type`` is an annotation the user wrote; it takes
    precedence.  ``python_type_for_type`` is the type inferred from the
    right-hand expression and is used only when no annotation exists.
    """
    left_node = lvalue.node
    assert isinstance(left_node, Var)

    if left_hand_explicit_type is not None:
        # the user's annotation wins; wrap it in Mapped[]
        left_node.type = api.named_type(
            NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
        )
    else:
        # no annotation: mark the def as not inferred and install the
        # inferred type (or a bare Mapped when none was inferred)
        lvalue.is_inferred_def = False
        left_node.type = api.named_type(
            NAMED_TYPE_SQLA_MAPPED,
            [] if python_type_for_type is None else [python_type_for_type],
        )

    # so to have it skip the right side totally, we can do this:
    # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))

    # however, if we instead manufacture a new node that uses the old
    # one, then we can still get type checking for the call itself,
    # e.g. the Column, relationship() call, etc.

    # rewrite the node as:
    # <attr> : Mapped[<typ>] =
    # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
    # the original right-hand side is maintained so it gets type checked
    # internally
    stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue)
|
||||
|
||||
|
||||
def add_additional_orm_attributes(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """Apply __init__, __table__ and other attributes to the mapped class.

    A keyword-only ``__init__`` is synthesized from the collected mapped
    attributes (plus those of mapped superclasses) unless the class is a
    declarative base or already defines one.  ``__table__`` and
    ``__mapper__`` placeholders are added where applicable.
    """

    info = util.info_for_cls(cls, api)

    if info is None:
        return

    is_base = util.get_is_base(info)

    if "__init__" not in info.names and not is_base:
        mapped_attr_names = {attr.name: attr.type for attr in attributes}

        for base in info.mro[1:-1]:
            # NOTE(review): this tests ``info.metadata`` (the class being
            # processed), not ``base.metadata``, so it is invariant across
            # the loop — possibly intended to test the base; harmless as
            # written since get_mapped_attributes() below returns None for
            # non-SQLAlchemy bases.  Confirm against upstream before changing.
            if "sqlalchemy" not in info.metadata:
                continue

            base_cls_attributes = util.get_mapped_attributes(base, api)
            if base_cls_attributes is None:
                continue

            for attr in base_cls_attributes:
                # setdefault: attributes on the subclass shadow same-named
                # attributes from superclasses
                mapped_attr_names.setdefault(attr.name, attr.type)

        arguments = []
        for name, typ in mapped_attr_names.items():
            if typ is None:
                typ = AnyType(TypeOfAny.special_form)
            arguments.append(
                Argument(
                    variable=Var(name, typ),
                    type_annotation=typ,
                    initializer=TempNode(typ),
                    # every constructor argument is optional keyword-only
                    kind=ARG_NAMED_OPT,
                )
            )

        add_method_to_class(api, cls, "__init__", arguments, NoneTyp())

    if "__table__" not in info.names and util.get_has_table(info):
        _apply_placeholder_attr_to_class(
            api, cls, "sqlalchemy.sql.schema.Table", "__table__"
        )
    if not is_base:
        _apply_placeholder_attr_to_class(
            api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
        )
|
||||
|
||||
|
||||
def _apply_placeholder_attr_to_class(
    api: SemanticAnalyzerPluginInterface,
    cls: ClassDef,
    qualified_name: str,
    attrname: str,
) -> None:
    """Install *attrname* on *cls* as a plain attribute symbol.

    The attribute is typed as an instance of *qualified_name* when that
    name resolves to a class; otherwise it falls back to ``Any``.
    """
    resolved = api.lookup_fully_qualified_or_none(qualified_name)
    if not resolved:
        # target type not importable in this build; degrade to Any
        attr_type: ProperType = AnyType(TypeOfAny.special_form)
    else:
        assert isinstance(resolved.node, TypeInfo)
        attr_type = Instance(resolved.node, [])

    placeholder = Var(attrname)
    placeholder._fullname = cls.fullname + "." + attrname
    placeholder.info = cls.info
    placeholder.type = attr_type
    cls.info.names[attrname] = SymbolTableNode(MDEF, placeholder)
|
||||
516
lib/sqlalchemy/ext/mypy/decl_class.py
Normal file
516
lib/sqlalchemy/ext/mypy/decl_class.py
Normal file
@@ -0,0 +1,516 @@
|
||||
# ext/mypy/decl_class.py
|
||||
# Copyright (C) 2021 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from mypy.nodes import AssignmentStmt
|
||||
from mypy.nodes import CallExpr
|
||||
from mypy.nodes import ClassDef
|
||||
from mypy.nodes import Decorator
|
||||
from mypy.nodes import LambdaExpr
|
||||
from mypy.nodes import ListExpr
|
||||
from mypy.nodes import MemberExpr
|
||||
from mypy.nodes import NameExpr
|
||||
from mypy.nodes import PlaceholderNode
|
||||
from mypy.nodes import RefExpr
|
||||
from mypy.nodes import StrExpr
|
||||
from mypy.nodes import SymbolNode
|
||||
from mypy.nodes import SymbolTableNode
|
||||
from mypy.nodes import TempNode
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.nodes import Var
|
||||
from mypy.plugin import SemanticAnalyzerPluginInterface
|
||||
from mypy.types import AnyType
|
||||
from mypy.types import CallableType
|
||||
from mypy.types import get_proper_type
|
||||
from mypy.types import Instance
|
||||
from mypy.types import NoneType
|
||||
from mypy.types import ProperType
|
||||
from mypy.types import Type
|
||||
from mypy.types import TypeOfAny
|
||||
from mypy.types import UnboundType
|
||||
from mypy.types import UnionType
|
||||
|
||||
from . import apply
|
||||
from . import infer
|
||||
from . import names
|
||||
from . import util
|
||||
|
||||
|
||||
def scan_declarative_assignments_and_apply_types(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    is_mixin_scan: bool = False,
) -> Optional[List[util.SQLAlchemyAttribute]]:
    """Scan a declarative class body and rewrite mapped attributes.

    Returns the list of :class:`util.SQLAlchemyAttribute` collected for the
    class, or ``None`` when the class cannot be processed (no TypeInfo yet,
    or a builtin).  If the class was already scanned, previously stored
    attributes are returned after being re-applied.  ``is_mixin_scan``
    suppresses the re-apply and the synthesis of ``__init__`` etc. for
    unmapped mixin superclasses.
    """

    info = util.info_for_cls(cls, api)

    if info is None:
        # this can occur during cached passes
        return None
    elif cls.fullname.startswith("builtins"):
        return None

    mapped_attributes: Optional[
        List[util.SQLAlchemyAttribute]
    ] = util.get_mapped_attributes(info, api)

    # used by assign.add_additional_orm_attributes among others
    util.establish_as_sqlalchemy(info)

    if mapped_attributes is not None:
        # ensure that a class that's mapped is always picked up by
        # its mapped() decorator or declarative metaclass before
        # it would be detected as an unmapped mixin class

        if not is_mixin_scan:
            # mypy can call us more than once. it then *may* have reset the
            # left hand side of everything, but not the right that we removed,
            # removing our ability to re-scan. but we have the types
            # here, so lets re-apply them, or if we have an UnboundType,
            # we can re-scan

            apply.re_apply_declarative_assignments(cls, api, mapped_attributes)

        return mapped_attributes

    mapped_attributes = []

    if not cls.defs.body:
        # when we get a mixin class from another file, the body is
        # empty (!) but the names are in the symbol table. so use that.

        for sym_name, sym in info.names.items():
            _scan_symbol_table_entry(
                cls, api, sym_name, sym, mapped_attributes
            )
    else:
        # normal case: walk the statements of the class body
        for stmt in util.flatten_typechecking(cls.defs.body):
            if isinstance(stmt, AssignmentStmt):
                _scan_declarative_assignment_stmt(
                    cls, api, stmt, mapped_attributes
                )
            elif isinstance(stmt, Decorator):
                # @declared_attr methods
                _scan_declarative_decorator_stmt(
                    cls, api, stmt, mapped_attributes
                )
    _scan_for_mapped_bases(cls, api)

    if not is_mixin_scan:
        apply.add_additional_orm_attributes(cls, api, mapped_attributes)

    # record the collected attributes so later passes / subclasses see them
    util.set_mapped_attributes(info, mapped_attributes)

    return mapped_attributes
|
||||
|
||||
|
||||
def _scan_symbol_table_entry(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    name: str,
    value: SymbolTableNode,
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """Extract mapping information from a SymbolTableNode that's in the
    type.names dictionary.

    Used when a mixin class arrives from another file with an empty body;
    only the symbol table carries its attribute types.  A recognized
    SQLAlchemy construct type adds an entry to ``attributes``.
    """
    value_type = get_proper_type(value.type)
    if not isinstance(value_type, Instance):
        return

    left_hand_explicit_type: Optional[ProperType] = None
    type_id = names.type_id_for_named_node(value_type.type)
    # type_id = names._type_id_for_unbound_type(value.type.type, cls, api)

    err = False

    # TODO: this is nearly the same logic as that of
    # _scan_declarative_decorator_stmt, likely can be merged
    if type_id in {
        names.MAPPED,
        names.RELATIONSHIP,
        names.COMPOSITE_PROPERTY,
        names.MAPPER_PROPERTY,
        names.SYNONYM_PROPERTY,
        names.COLUMN_PROPERTY,
    }:
        # generic SQLAlchemy wrapper: the python type is the type argument
        if value_type.args:
            left_hand_explicit_type = get_proper_type(value_type.args[0])
        else:
            err = True
    elif type_id is names.COLUMN:
        if not value_type.args:
            err = True
        else:
            # Column[SomeTypeEngine]: resolve the TypeEngine subclass and
            # derive the python type from it, as Optional
            typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type(
                value_type.args[0]
            )
            if isinstance(typeengine_arg, Instance):
                typeengine_arg = typeengine_arg.type

            if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
                sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
                if sym is not None and isinstance(sym.node, TypeInfo):
                    if names.has_base_type_id(sym.node, names.TYPEENGINE):

                        # Column values are nullable by default, hence
                        # the union with None
                        left_hand_explicit_type = UnionType(
                            [
                                infer.extract_python_type_from_typeengine(
                                    api, sym.node, []
                                ),
                                NoneType(),
                            ]
                        )
                    else:
                        util.fail(
                            api,
                            "Column type should be a TypeEngine "
                            "subclass not '{}'".format(sym.node.fullname),
                            value_type,
                        )

    if err:
        msg = (
            "Can't infer type from attribute {} on class {}. "
            "please specify a return type from this function that is "
            "one of: Mapped[<python type>], relationship[<target class>], "
            "Column[<TypeEngine>], MapperProperty[<python type>]"
        )
        util.fail(api, msg.format(name, cls.name), cls)

        # still record the attribute, typed as Any
        left_hand_explicit_type = AnyType(TypeOfAny.special_form)

    if left_hand_explicit_type is not None:
        assert value.node is not None
        attributes.append(
            util.SQLAlchemyAttribute(
                name=name,
                line=value.node.line,
                column=value.node.column,
                typ=left_hand_explicit_type,
                info=cls.info,
            )
        )
|
||||
|
||||
|
||||
def _scan_declarative_decorator_stmt(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    stmt: Decorator,
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """Extract mapping information from a @declared_attr in a declarative
    class.

    E.g.::

        @reg.mapped
        class MyClass:
            # ...

            @declared_attr
            def updated_at(cls) -> Column[DateTime]:
                return Column(DateTime)

    Will resolve in mypy as::

        @reg.mapped
        class MyClass:
            # ...

            updated_at: Mapped[Optional[datetime.datetime]]

    The decorated function's statement in the class body is replaced
    in-place with an equivalent ``AssignmentStmt`` and an entry is added
    to ``attributes`` (dunder names are rewritten but not recorded).
    """
    # only act on functions actually decorated with @declared_attr
    for dec in stmt.decorators:
        if (
            isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
            and names.type_id_for_named_node(dec) is names.DECLARED_ATTR
        ):
            break
    else:
        return

    dec_index = cls.defs.body.index(stmt)

    left_hand_explicit_type: Optional[ProperType] = None

    if util.name_is_dunder(stmt.name):
        # for dunder names like __table_args__, __tablename__,
        # __mapper_args__ etc., rewrite these as simple assignment
        # statements; otherwise mypy doesn't like if the decorated
        # function has an annotation like ``cls: Type[Foo]`` because
        # it isn't @classmethod
        any_ = AnyType(TypeOfAny.special_form)
        left_node = NameExpr(stmt.var.name)
        left_node.node = stmt.var
        new_stmt = AssignmentStmt([left_node], TempNode(any_))
        new_stmt.type = left_node.node.type
        cls.defs.body[dec_index] = new_stmt
        return
    elif isinstance(stmt.func.type, CallableType):
        # derive the mapped type from the function's return annotation
        func_type = stmt.func.type.ret_type
        if isinstance(func_type, UnboundType):
            type_id = names.type_id_for_unbound_type(func_type, cls, api)
        else:
            # this does not seem to occur unless the type argument is
            # incorrect
            return

        if (
            type_id
            in {
                names.MAPPED,
                names.RELATIONSHIP,
                names.COMPOSITE_PROPERTY,
                names.MAPPER_PROPERTY,
                names.SYNONYM_PROPERTY,
                names.COLUMN_PROPERTY,
            }
            and func_type.args
        ):
            # generic SQLAlchemy wrapper: take its type argument directly
            left_hand_explicit_type = get_proper_type(func_type.args[0])
        elif type_id is names.COLUMN and func_type.args:
            # Column[SomeTypeEngine]: resolve the python type from the
            # TypeEngine subclass, as Optional (columns are nullable by
            # default)
            typeengine_arg = func_type.args[0]
            if isinstance(typeengine_arg, UnboundType):
                sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
                if sym is not None and isinstance(sym.node, TypeInfo):
                    if names.has_base_type_id(sym.node, names.TYPEENGINE):
                        left_hand_explicit_type = UnionType(
                            [
                                infer.extract_python_type_from_typeengine(
                                    api, sym.node, []
                                ),
                                NoneType(),
                            ]
                        )
                    else:
                        util.fail(
                            api,
                            "Column type should be a TypeEngine "
                            "subclass not '{}'".format(sym.node.fullname),
                            func_type,
                        )

    if left_hand_explicit_type is None:
        # no type on the decorated function. our option here is to
        # dig into the function body and get the return type, but they
        # should just have an annotation.
        msg = (
            "Can't infer type from @declared_attr on function '{}'; "
            "please specify a return type from this function that is "
            "one of: Mapped[<python type>], relationship[<target class>], "
            "Column[<TypeEngine>], MapperProperty[<python type>]"
        )
        util.fail(api, msg.format(stmt.var.name), stmt)

        left_hand_explicit_type = AnyType(TypeOfAny.special_form)

    left_node = NameExpr(stmt.var.name)
    left_node.node = stmt.var

    # totally feeling around in the dark here as I don't totally understand
    # the significance of UnboundType. It seems to be something that is
    # not going to do what's expected when it is applied as the type of
    # an AssignmentStatement. So do a feeling-around-in-the-dark version
    # of converting it to the regular Instance/TypeInfo/UnionType structures
    # we see everywhere else.
    if isinstance(left_hand_explicit_type, UnboundType):
        left_hand_explicit_type = get_proper_type(
            util.unbound_to_instance(api, left_hand_explicit_type)
        )

    left_node.node.type = api.named_type(
        names.NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
    )

    # this will ignore the rvalue entirely
    # rvalue = TempNode(AnyType(TypeOfAny.special_form))

    # rewrite the node as:
    # <attr> : Mapped[<typ>] =
    # _sa_Mapped._empty_constructor(lambda: <function body>)
    # the function body is maintained so it gets type checked internally
    rvalue = util.expr_to_mapped_constructor(
        LambdaExpr(stmt.func.arguments, stmt.func.body)
    )

    new_stmt = AssignmentStmt([left_node], rvalue)
    new_stmt.type = left_node.node.type

    attributes.append(
        util.SQLAlchemyAttribute(
            name=left_node.name,
            line=stmt.line,
            column=stmt.column,
            typ=left_hand_explicit_type,
            info=cls.info,
        )
    )
    cls.defs.body[dec_index] = new_stmt
|
||||
|
||||
|
||||
def _scan_declarative_assignment_stmt(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    attributes: List[util.SQLAlchemyAttribute],
) -> None:
    """Extract mapping information from an assignment statement in a
    declarative class.

    Handles the special names ``__abstract__``, ``__tablename__`` and
    ``_mypy_mapped_attrs``; for ordinary attributes, the mapped type is
    inferred from the left-hand annotation and/or the right-hand
    expression, recorded in ``attributes``, and the statement rewritten
    via :func:`apply.apply_type_to_mapped_statement`.
    """
    lvalue = stmt.lvalues[0]
    if not isinstance(lvalue, NameExpr):
        return

    sym = cls.info.names.get(lvalue.name)

    # this establishes that semantic analysis has taken place, which
    # means the nodes are populated and we are called from an appropriate
    # hook.
    assert sym is not None
    node = sym.node

    if isinstance(node, PlaceholderNode):
        return

    assert node is lvalue.node
    assert isinstance(node, Var)

    if node.name == "__abstract__":
        # __abstract__ = True marks the class as a declarative base
        if api.parse_bool(stmt.rvalue) is True:
            util.set_is_base(cls.info)
        return
    elif node.name == "__tablename__":
        util.set_has_table(cls.info)
    elif node.name.startswith("__"):
        # other dunders are not mapped attributes
        return
    elif node.name == "_mypy_mapped_attrs":
        if not isinstance(stmt.rvalue, ListExpr):
            util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt)
        else:
            # each listed name/string is applied as a mapped attribute
            for item in stmt.rvalue.items:
                if isinstance(item, (NameExpr, StrExpr)):
                    apply.apply_mypy_mapped_attr(cls, api, item, attributes)

    left_hand_mapped_type: Optional[Type] = None
    left_hand_explicit_type: Optional[ProperType] = None

    if node.is_inferred or node.type is None:
        if isinstance(stmt.type, UnboundType):
            # look for an explicit Mapped[] type annotation on the left
            # side with nothing on the right

            # print(stmt.type)
            # Mapped?[Optional?[A?]]

            left_hand_explicit_type = stmt.type

            if stmt.type.name == "Mapped":
                mapped_sym = api.lookup_qualified("Mapped", cls)
                if (
                    mapped_sym is not None
                    and mapped_sym.node is not None
                    and names.type_id_for_named_node(mapped_sym.node)
                    is names.MAPPED
                ):
                    left_hand_explicit_type = get_proper_type(
                        stmt.type.args[0]
                    )
                    left_hand_mapped_type = stmt.type

            # TODO: do we need to convert from unbound for this case?
            # left_hand_explicit_type = util._unbound_to_instance(
            #     api, left_hand_explicit_type
            # )
    else:
        node_type = get_proper_type(node.type)
        if (
            isinstance(node_type, Instance)
            and names.type_id_for_named_node(node_type.type) is names.MAPPED
        ):
            # print(node.type)
            # sqlalchemy.orm.attributes.Mapped[<python type>]
            left_hand_explicit_type = get_proper_type(node_type.args[0])
            left_hand_mapped_type = node_type
        else:
            # print(node.type)
            # <python type>
            left_hand_explicit_type = node_type
            left_hand_mapped_type = None

    if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None:
        # annotation without assignment and Mapped is present
        # as type annotation
        # equivalent to using _infer_type_from_left_hand_type_only.

        python_type_for_type = left_hand_explicit_type
    elif isinstance(stmt.rvalue, CallExpr) and isinstance(
        stmt.rvalue.callee, RefExpr
    ):

        # right-hand side is a call such as Column(...), relationship(...);
        # infer the python type from the construct
        python_type_for_type = infer.infer_type_from_right_hand_nameexpr(
            api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
        )

        if python_type_for_type is None:
            return

    else:
        # not a recognized mapped-attribute form
        return

    assert python_type_for_type is not None

    attributes.append(
        util.SQLAlchemyAttribute(
            name=node.name,
            line=stmt.line,
            column=stmt.column,
            typ=python_type_for_type,
            info=cls.info,
        )
    )

    apply.apply_type_to_mapped_statement(
        api,
        stmt,
        lvalue,
        left_hand_explicit_type,
        python_type_for_type,
    )
|
||||
|
||||
|
||||
def _scan_for_mapped_bases(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
) -> None:
    """Walk the MRO of *cls* looking for ORM-significant superclasses.

    Non-mapped mixin classes encountered along the way are scanned so that
    the mapped attributes they contribute can be applied to subclasses.
    """
    info = util.info_for_cls(cls, api)
    if info is None:
        return

    non_builtin_bases = (
        base
        for base in info.mro[1:-1]
        if not base.fullname.startswith("builtins")
    )
    for base in non_builtin_bases:
        # a base not yet scanned (but with full type info) is an unmapped
        # mixin; scanning it collects its mapped attributes
        scan_declarative_assignments_and_apply_types(
            base.defn, api, is_mixin_scan=True
        )
|
||||
556
lib/sqlalchemy/ext/mypy/infer.py
Normal file
556
lib/sqlalchemy/ext/mypy/infer.py
Normal file
@@ -0,0 +1,556 @@
|
||||
# ext/mypy/infer.py
|
||||
# Copyright (C) 2021 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
|
||||
from mypy.maptype import map_instance_to_supertype
|
||||
from mypy.messages import format_type
|
||||
from mypy.nodes import AssignmentStmt
|
||||
from mypy.nodes import CallExpr
|
||||
from mypy.nodes import Expression
|
||||
from mypy.nodes import FuncDef
|
||||
from mypy.nodes import LambdaExpr
|
||||
from mypy.nodes import MemberExpr
|
||||
from mypy.nodes import NameExpr
|
||||
from mypy.nodes import RefExpr
|
||||
from mypy.nodes import StrExpr
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.nodes import Var
|
||||
from mypy.plugin import SemanticAnalyzerPluginInterface
|
||||
from mypy.subtypes import is_subtype
|
||||
from mypy.types import AnyType
|
||||
from mypy.types import CallableType
|
||||
from mypy.types import get_proper_type
|
||||
from mypy.types import Instance
|
||||
from mypy.types import NoneType
|
||||
from mypy.types import ProperType
|
||||
from mypy.types import TypeOfAny
|
||||
from mypy.types import UnionType
|
||||
|
||||
from . import names
|
||||
from . import util
|
||||
|
||||
|
||||
def infer_type_from_right_hand_nameexpr(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[ProperType],
    infer_from_right_side: RefExpr,
) -> Optional[ProperType]:
    """Dispatch mapped-type inference based on the right-hand callee.

    Returns the inferred Python type for the construct being called, or
    ``None`` when the callee is not a recognized SQLAlchemy construct.
    """
    type_id = names.type_id_for_callee(infer_from_right_side)
    if type_id is None:
        return None

    if type_id is names.COLUMN:
        return _infer_type_from_decl_column(
            api, stmt, node, left_hand_explicit_type
        )
    if type_id is names.RELATIONSHIP:
        return _infer_type_from_relationship(
            api, stmt, node, left_hand_explicit_type
        )
    if type_id is names.COLUMN_PROPERTY:
        return _infer_type_from_decl_column_property(
            api, stmt, node, left_hand_explicit_type
        )
    if type_id is names.SYNONYM_PROPERTY:
        # synonyms carry no type of their own; only the annotation counts
        return infer_type_from_left_hand_type_only(
            api, node, left_hand_explicit_type
        )
    if type_id is names.COMPOSITE_PROPERTY:
        return _infer_type_from_decl_composite_property(
            api, stmt, node, left_hand_explicit_type
        )
    return None
|
||||
|
||||
|
||||
def _infer_type_from_relationship(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
    """Infer the type of mapping from a relationship.

    E.g.::

        @reg.mapped
        class MyClass:
            # ...

            addresses = relationship(Address, uselist=True)

            order: Mapped["Order"] = relationship("Order")

    Will resolve in mypy as::

        @reg.mapped
        class MyClass:
            # ...

            addresses: Mapped[List[Address]]

            order: Mapped["Order"]

    Collection vs. scalar is decided from the ``uselist`` and
    ``collection_class`` keyword arguments; when neither is present and no
    left-hand annotation exists, an error is emitted.
    """

    assert isinstance(stmt.rvalue, CallExpr)
    target_cls_arg = stmt.rvalue.args[0]
    python_type_for_type: Optional[ProperType] = None

    if isinstance(target_cls_arg, NameExpr) and isinstance(
        target_cls_arg.node, TypeInfo
    ):
        # type
        related_object_type = target_cls_arg.node
        python_type_for_type = Instance(related_object_type, [])

    # other cases not covered - an error message directs the user
    # to set an explicit type annotation
    #
    # node.type == str, it's a string
    # if isinstance(target_cls_arg, NameExpr) and isinstance(
    #     target_cls_arg.node, Var
    # )
    # points to a type
    # isinstance(target_cls_arg, NameExpr) and isinstance(
    #     target_cls_arg.node, TypeAlias
    # )
    # string expression
    # isinstance(target_cls_arg, StrExpr)

    uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist")
    collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg(
        stmt.rvalue, "collection_class"
    )
    type_is_a_collection = False

    # this can be used to determine Optional for a many-to-one
    # in the same way nullable=False could be used, if we start supporting
    # that.
    # innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin")

    if (
        uselist_arg is not None
        and api.parse_bool(uselist_arg) is True
        and collection_cls_arg is None
    ):
        # uselist=True with no collection_class: default list collection
        type_is_a_collection = True
        if python_type_for_type is not None:
            python_type_for_type = api.named_type(
                names.NAMED_TYPE_BUILTINS_LIST, [python_type_for_type]
            )
    elif (
        uselist_arg is None or api.parse_bool(uselist_arg) is True
    ) and collection_cls_arg is not None:
        # an explicit collection_class was given
        type_is_a_collection = True
        if isinstance(collection_cls_arg, CallExpr):
            collection_cls_arg = collection_cls_arg.callee

        if isinstance(collection_cls_arg, NameExpr) and isinstance(
            collection_cls_arg.node, TypeInfo
        ):
            if python_type_for_type is not None:
                # this can still be overridden by the left hand side
                # within _infer_Type_from_left_and_inferred_right
                python_type_for_type = Instance(
                    collection_cls_arg.node, [python_type_for_type]
                )
        elif (
            isinstance(collection_cls_arg, NameExpr)
            and isinstance(collection_cls_arg.node, FuncDef)
            and collection_cls_arg.node.type is not None
        ):
            if python_type_for_type is not None:
                # this can still be overridden by the left hand side
                # within _infer_Type_from_left_and_inferred_right

                # TODO: handle mypy.types.Overloaded
                if isinstance(collection_cls_arg.node.type, CallableType):
                    rt = get_proper_type(collection_cls_arg.node.type.ret_type)

                    if isinstance(rt, CallableType):
                        callable_ret_type = get_proper_type(rt.ret_type)
                        if isinstance(callable_ret_type, Instance):
                            python_type_for_type = Instance(
                                callable_ret_type.type,
                                [python_type_for_type],
                            )
        else:
            util.fail(
                api,
                "Expected Python collection type for "
                "collection_class parameter",
                stmt.rvalue,
            )
            python_type_for_type = None
    elif uselist_arg is not None and api.parse_bool(uselist_arg) is False:
        # scalar relationship (many-to-one): value may be None
        if collection_cls_arg is not None:
            util.fail(
                api,
                "Sending uselist=False and collection_class at the same time "
                "does not make sense",
                stmt.rvalue,
            )
        if python_type_for_type is not None:
            python_type_for_type = UnionType(
                [python_type_for_type, NoneType()]
            )

    else:
        # neither uselist nor collection_class given; an annotation is
        # required to distinguish scalar from collection
        if left_hand_explicit_type is None:
            msg = (
                "Can't infer scalar or collection for ORM mapped expression "
                "assigned to attribute '{}' if both 'uselist' and "
                "'collection_class' arguments are absent from the "
                "relationship(); please specify a "
                "type annotation on the left hand side."
            )
            util.fail(api, msg.format(node.name), node)

    if python_type_for_type is None:
        return infer_type_from_left_hand_type_only(
            api, node, left_hand_explicit_type
        )
    elif left_hand_explicit_type is not None:
        # reconcile the annotation with the type inferred from the call
        if type_is_a_collection:
            assert isinstance(left_hand_explicit_type, Instance)
            assert isinstance(python_type_for_type, Instance)
            return _infer_collection_type_from_left_and_inferred_right(
                api, node, left_hand_explicit_type, python_type_for_type
            )
        else:
            return _infer_type_from_left_and_inferred_right(
                api,
                node,
                left_hand_explicit_type,
                python_type_for_type,
            )
    else:
        return python_type_for_type
|
||||
|
||||
|
||||
def _infer_type_from_decl_composite_property(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
    """Infer the type of mapping from a CompositeProperty.

    The composite class is taken from the first positional argument of
    the ``composite()`` call; when it cannot be resolved to a class,
    the explicit left-hand annotation is the only source of type info.
    """

    assert isinstance(stmt.rvalue, CallExpr)
    composite_cls_arg = stmt.rvalue.args[0]

    python_type_for_type = None
    if isinstance(composite_cls_arg, NameExpr) and isinstance(
        composite_cls_arg.node, TypeInfo
    ):
        # composite(SomeClass, ...) - the attribute is typed as SomeClass
        python_type_for_type = Instance(composite_cls_arg.node, [])

    if python_type_for_type is None:
        return infer_type_from_left_hand_type_only(
            api, node, left_hand_explicit_type
        )
    if left_hand_explicit_type is not None:
        return _infer_type_from_left_and_inferred_right(
            api, node, left_hand_explicit_type, python_type_for_type
        )
    return python_type_for_type
|
||||
|
||||
|
||||
def _infer_type_from_decl_column_property(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
    """Infer the type of mapping from a ColumnProperty.

    This includes mappings against ``column_property()`` as well as the
    ``deferred()`` function.

    """
    rvalue = stmt.rvalue
    assert isinstance(rvalue, CallExpr)

    # column_property(Column(...)) / deferred(Column(...)): when the first
    # positional argument is itself a Column() call, infer from that Column
    if rvalue.args:
        inner_expr = rvalue.args[0]
        if (
            isinstance(inner_expr, CallExpr)
            and names.type_id_for_callee(inner_expr.callee) is names.COLUMN
        ):
            return _infer_type_from_decl_column(
                api,
                stmt,
                node,
                left_hand_explicit_type,
                right_hand_expression=inner_expr,
            )

    # query_expression() takes no Column; this is probably not strictly
    # necessary as the left hand type is used for query expression in any
    # case.  any other no-arg column prop objects would go here also.
    if isinstance(rvalue, CallExpr) and (
        names.type_id_for_callee(rvalue.callee) is names.QUERY_EXPRESSION
    ):
        return _infer_type_from_decl_column(
            api,
            stmt,
            node,
            left_hand_explicit_type,
        )

    return infer_type_from_left_hand_type_only(
        api, node, left_hand_explicit_type
    )
|
||||
|
||||
|
||||
def _infer_type_from_decl_column(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    node: Var,
    left_hand_explicit_type: Optional[ProperType],
    right_hand_expression: Optional[CallExpr] = None,
) -> Optional[ProperType]:
    """Infer the type of mapping from a Column.

    E.g.::

        @reg.mapped
        class MyClass:
            # ...

            a = Column(Integer)

            b = Column("b", String)

            c: Mapped[int] = Column(Integer)

            d: bool = Column(Boolean)

    Will resolve in MyPy as::

        @reg.mapped
        class MyClass:
            # ...

            a : Mapped[int]

            b : Mapped[str]

            c: Mapped[int]

            d: Mapped[bool]

    :param right_hand_expression: an explicit Column() call to inspect;
     when None, ``stmt.rvalue`` is used (it must then be a call).

    """
    assert isinstance(node, Var)

    # the expression that names the SQL type, e.g. ``String(50)`` or
    # ``String``; stays None if only names/strings/lambdas are seen
    callee = None

    if right_hand_expression is None:
        if not isinstance(stmt.rvalue, CallExpr):
            return None

        right_hand_expression = stmt.rvalue

    # the SQL type is conventionally the first or second positional
    # argument of Column(); only those two are inspected
    for column_arg in right_hand_expression.args[0:2]:
        if isinstance(column_arg, CallExpr):
            if isinstance(column_arg.callee, RefExpr):
                # x = Column(String(50))
                callee = column_arg.callee
                type_args: Sequence[Expression] = column_arg.args
                break
        elif isinstance(column_arg, (NameExpr, MemberExpr)):
            if isinstance(column_arg.node, TypeInfo):
                # x = Column(String)
                callee = column_arg
                type_args = ()
                break
            else:
                # x = Column(some_name, String), go to next argument
                continue
        elif isinstance(column_arg, (StrExpr,)):
            # x = Column("name", String), go to next argument
            continue
        elif isinstance(column_arg, (LambdaExpr,)):
            # x = Column("name", String, default=lambda: uuid.uuid4())
            # go to next argument
            continue
        else:
            # no other expression kinds are expected in these positions
            assert False

    if callee is None:
        return None

    if isinstance(callee.node, TypeInfo) and names.mro_has_id(
        callee.node.mro, names.TYPEENGINE
    ):
        # a TypeEngine subclass: extract its Python-side type
        python_type_for_type = extract_python_type_from_typeengine(
            api, callee.node, type_args
        )

        if left_hand_explicit_type is not None:
            # explicit annotation present: validate it against the
            # inferred type and prefer it
            return _infer_type_from_left_and_inferred_right(
                api, node, left_hand_explicit_type, python_type_for_type
            )

        else:
            # no annotation: columns are nullable by default, hence
            # Optional[<python type>]
            return UnionType([python_type_for_type, NoneType()])
    else:
        # it's not TypeEngine, it's typically implicitly typed
        # like ForeignKey.  we can't infer from the right side.
        return infer_type_from_left_hand_type_only(
            api, node, left_hand_explicit_type
        )
|
||||
|
||||
|
||||
def _infer_type_from_left_and_inferred_right(
    api: SemanticAnalyzerPluginInterface,
    node: Var,
    left_hand_explicit_type: ProperType,
    python_type_for_type: ProperType,
    orig_left_hand_type: Optional[ProperType] = None,
    orig_python_type_for_type: Optional[ProperType] = None,
) -> Optional[ProperType]:
    """Validate type when a left hand annotation is present and we also
    could infer the right hand side::

        attrname: SomeType = Column(SomeDBType)

    The explicit annotation always wins; an error is emitted when it is
    not a subtype of what the right-hand side implies.  The ``orig_*``
    parameters carry the full (un-unwrapped) types for error display
    when only element types are being compared.
    """

    if orig_left_hand_type is None:
        orig_left_hand_type = left_hand_explicit_type
    if orig_python_type_for_type is None:
        orig_python_type_for_type = python_type_for_type

    if not is_subtype(left_hand_explicit_type, python_type_for_type):
        # show the user the Mapped[...] form the ORM would have produced
        mapped_display_type = api.named_type(
            names.NAMED_TYPE_SQLA_MAPPED, [orig_python_type_for_type]
        )

        util.fail(
            api,
            (
                "Left hand assignment '{}: {}' not compatible "
                "with ORM mapped expression of type {}"
            ).format(
                node.name,
                format_type(orig_left_hand_type),
                format_type(mapped_display_type),
            ),
            node,
        )

    return orig_left_hand_type
|
||||
|
||||
|
||||
def _infer_collection_type_from_left_and_inferred_right(
    api: SemanticAnalyzerPluginInterface,
    node: Var,
    left_hand_explicit_type: Instance,
    python_type_for_type: Instance,
) -> Optional[ProperType]:
    """Validate an annotated collection against the inferred one by
    comparing their element types, passing the full collection types
    through for error display."""
    if left_hand_explicit_type.args:
        # generic collections: compare the first type argument
        declared_element = get_proper_type(left_hand_explicit_type.args[0])
        inferred_element = get_proper_type(python_type_for_type.args[0])
    else:
        # non-generic annotation: compare the collection types themselves
        declared_element = left_hand_explicit_type
        inferred_element = python_type_for_type

    assert isinstance(declared_element, (Instance, UnionType))
    assert isinstance(inferred_element, (Instance, UnionType))

    return _infer_type_from_left_and_inferred_right(
        api,
        node,
        declared_element,
        inferred_element,
        orig_left_hand_type=left_hand_explicit_type,
        orig_python_type_for_type=python_type_for_type,
    )
|
||||
|
||||
|
||||
def infer_type_from_left_hand_type_only(
    api: SemanticAnalyzerPluginInterface,
    node: Var,
    left_hand_explicit_type: Optional[ProperType],
) -> Optional[ProperType]:
    """Determine the type based on explicit annotation only.

    if no annotation were present, note that we need one there to know
    the type.

    """
    if left_hand_explicit_type is not None:
        # use type from the left hand side
        return left_hand_explicit_type

    # no annotation at all: report it, then degrade to Mapped[Any]
    util.fail(
        api,
        (
            "Can't infer type from ORM mapped expression "
            "assigned to attribute '{}'; please specify a "
            "Python type or "
            "Mapped[<python type>] on the left hand side."
        ).format(node.name),
        node,
    )

    return api.named_type(
        names.NAMED_TYPE_SQLA_MAPPED, [AnyType(TypeOfAny.special_form)]
    )
|
||||
|
||||
|
||||
def extract_python_type_from_typeengine(
    api: SemanticAnalyzerPluginInterface,
    node: TypeInfo,
    type_args: Sequence[Expression],
) -> ProperType:
    """Return the Python-side type for a ``TypeEngine`` subclass.

    ``Enum`` gets special treatment: ``Enum(SomeEnumClass)`` maps to
    that class, while other ``Enum`` forms map to ``str``.  Everything
    else is resolved through the generic parameter of ``TypeEngine``.
    """
    if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
        first_arg = type_args[0]
        if isinstance(first_arg, RefExpr) and isinstance(
            first_arg.node, TypeInfo
        ):
            for base_ in first_arg.node.mro:
                if base_.fullname == "enum.Enum":
                    return Instance(first_arg.node, [])
            # TODO: support other pep-435 types here
            # NOTE: a class argument that is not an enum.Enum falls
            # through to the generic TypeEngine handling below
        else:
            # Enum("a", "b", ...) and similar non-class forms: str
            return api.named_type(names.NAMED_TYPE_BUILTINS_STR, [])

    assert node.has_base("sqlalchemy.sql.type_api.TypeEngine"), (
        "could not extract Python type from node: %s" % node
    )

    type_engine_sym = api.lookup_fully_qualified_or_none(
        "sqlalchemy.sql.type_api.TypeEngine"
    )

    assert type_engine_sym is not None and isinstance(
        type_engine_sym.node, TypeInfo
    )
    # map e.g. Integer() up to TypeEngine[int] and take the last
    # type argument as the Python type
    type_engine = map_instance_to_supertype(
        Instance(node, []),
        type_engine_sym.node,
    )
    return get_proper_type(type_engine.args[-1])
|
||||
253
lib/sqlalchemy/ext/mypy/names.py
Normal file
253
lib/sqlalchemy/ext/mypy/names.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# ext/mypy/names.py
|
||||
# Copyright (C) 2021 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Set
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from mypy.nodes import ClassDef
|
||||
from mypy.nodes import Expression
|
||||
from mypy.nodes import FuncDef
|
||||
from mypy.nodes import MemberExpr
|
||||
from mypy.nodes import NameExpr
|
||||
from mypy.nodes import SymbolNode
|
||||
from mypy.nodes import TypeAlias
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.plugin import SemanticAnalyzerPluginInterface
|
||||
from mypy.types import CallableType
|
||||
from mypy.types import get_proper_type
|
||||
from mypy.types import Instance
|
||||
from mypy.types import UnboundType
|
||||
|
||||
from ... import util
|
||||
|
||||
# Symbolic ids for the SQLAlchemy constructs the plugin recognizes.
# util.symbol() interns by name, so each constant is a process-wide
# singleton compared by identity throughout the plugin.
COLUMN: int = util.symbol("COLUMN")  # type: ignore
RELATIONSHIP: int = util.symbol("RELATIONSHIP")  # type: ignore
REGISTRY: int = util.symbol("REGISTRY")  # type: ignore
# note: fixed the misspelled symbol name "TYPEENGNE"; the name is only
# used as the symbol's identifier/repr, nothing else referenced it
TYPEENGINE: int = util.symbol("TYPEENGINE")  # type: ignore
MAPPED: int = util.symbol("MAPPED")  # type: ignore
DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE")  # type: ignore
DECLARATIVE_META: int = util.symbol("DECLARATIVE_META")  # type: ignore
MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR")  # type: ignore
# COLUMN_PROPERTY was previously assigned twice with the same value;
# the redundant duplicate has been removed
COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY")  # type: ignore
SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY")  # type: ignore
COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY")  # type: ignore
DECLARED_ATTR: int = util.symbol("DECLARED_ATTR")  # type: ignore
MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY")  # type: ignore
AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE")  # type: ignore
AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE")  # type: ignore
DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN")  # type: ignore
QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION")  # type: ignore

# names that must succeed with mypy.api.named_type
NAMED_TYPE_BUILTINS_OBJECT = "builtins.object"
NAMED_TYPE_BUILTINS_STR = "builtins.str"
NAMED_TYPE_BUILTINS_LIST = "builtins.list"
NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.attributes.Mapped"
|
||||
|
||||
# Maps a construct's short class/function name to its symbolic type id
# plus the set of fully-qualified names it may legitimately resolve to;
# the fullname check guards against unrelated objects sharing a short name.
_lookup: Dict[str, Tuple[int, Set[str]]] = {
    "Column": (
        COLUMN,
        {
            "sqlalchemy.sql.schema.Column",
            "sqlalchemy.sql.Column",
        },
    ),
    "RelationshipProperty": (
        RELATIONSHIP,
        {
            "sqlalchemy.orm.relationships.RelationshipProperty",
            "sqlalchemy.orm.RelationshipProperty",
        },
    ),
    "registry": (
        REGISTRY,
        {
            "sqlalchemy.orm.decl_api.registry",
            "sqlalchemy.orm.registry",
        },
    ),
    "ColumnProperty": (
        COLUMN_PROPERTY,
        {
            "sqlalchemy.orm.properties.ColumnProperty",
            "sqlalchemy.orm.ColumnProperty",
        },
    ),
    "SynonymProperty": (
        SYNONYM_PROPERTY,
        {
            "sqlalchemy.orm.descriptor_props.SynonymProperty",
            "sqlalchemy.orm.SynonymProperty",
        },
    ),
    "CompositeProperty": (
        COMPOSITE_PROPERTY,
        {
            "sqlalchemy.orm.descriptor_props.CompositeProperty",
            "sqlalchemy.orm.CompositeProperty",
        },
    ),
    "MapperProperty": (
        MAPPER_PROPERTY,
        {
            "sqlalchemy.orm.interfaces.MapperProperty",
            "sqlalchemy.orm.MapperProperty",
        },
    ),
    "TypeEngine": (TYPEENGINE, {"sqlalchemy.sql.type_api.TypeEngine"}),
    "Mapped": (MAPPED, {"sqlalchemy.orm.attributes.Mapped"}),
    "declarative_base": (
        DECLARATIVE_BASE,
        {
            "sqlalchemy.ext.declarative.declarative_base",
            "sqlalchemy.orm.declarative_base",
            "sqlalchemy.orm.decl_api.declarative_base",
        },
    ),
    "DeclarativeMeta": (
        DECLARATIVE_META,
        {
            "sqlalchemy.ext.declarative.DeclarativeMeta",
            "sqlalchemy.orm.DeclarativeMeta",
            "sqlalchemy.orm.decl_api.DeclarativeMeta",
        },
    ),
    "mapped": (
        MAPPED_DECORATOR,
        {
            "sqlalchemy.orm.decl_api.registry.mapped",
            "sqlalchemy.orm.registry.mapped",
        },
    ),
    "as_declarative": (
        AS_DECLARATIVE,
        {
            "sqlalchemy.ext.declarative.as_declarative",
            "sqlalchemy.orm.decl_api.as_declarative",
            "sqlalchemy.orm.as_declarative",
        },
    ),
    "as_declarative_base": (
        AS_DECLARATIVE_BASE,
        {
            "sqlalchemy.orm.decl_api.registry.as_declarative_base",
            "sqlalchemy.orm.registry.as_declarative_base",
        },
    ),
    "declared_attr": (
        DECLARED_ATTR,
        {
            "sqlalchemy.orm.decl_api.declared_attr",
            "sqlalchemy.orm.declared_attr",
        },
    ),
    "declarative_mixin": (
        DECLARATIVE_MIXIN,
        {
            "sqlalchemy.orm.decl_api.declarative_mixin",
            "sqlalchemy.orm.declarative_mixin",
        },
    ),
    "query_expression": (
        QUERY_EXPRESSION,
        {"sqlalchemy.orm.query_expression"},
    ),
}
|
||||
|
||||
|
||||
def has_base_type_id(info: TypeInfo, type_id: int) -> bool:
    """Return True if any class in ``info``'s MRO is the SQLAlchemy
    construct identified by *type_id*.

    The body previously duplicated :func:`mro_has_id` line for line;
    it now delegates to it so the matching rules live in one place.
    """
    return mro_has_id(info.mro, type_id)
|
||||
|
||||
|
||||
def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
    """Return True if any class in *mro* is the SQLAlchemy construct
    identified by *type_id*, confirmed by its fully-qualified name."""
    for candidate in mro:
        matched_id, fullnames = _lookup.get(candidate.name, (None, None))
        if matched_id == type_id:
            # first short-name match decides; confirm the real module path
            return fullnames is not None and candidate.fullname in fullnames
    return False
|
||||
|
||||
|
||||
def type_id_for_unbound_type(
    type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface
) -> Optional[int]:
    """Classify a not-yet-analyzed annotation by resolving its name in
    the current scope; follows type aliases to their target."""
    sym = api.lookup_qualified(type_.name, type_)
    if sym is None:
        return None

    resolved = sym.node
    if isinstance(resolved, TypeInfo):
        return type_id_for_named_node(resolved)
    if isinstance(resolved, TypeAlias):
        aliased = get_proper_type(resolved.target)
        if isinstance(aliased, Instance):
            return type_id_for_named_node(aliased.type)
    return None
|
||||
|
||||
|
||||
def type_id_for_callee(callee: Expression) -> Optional[int]:
    """Classify the construct a call expression produces.

    Functions are classified by their declared return type, type
    aliases by their target, and classes by themselves; anything else
    yields None.
    """
    if isinstance(callee, (MemberExpr, NameExpr)):
        if isinstance(callee.node, FuncDef):
            if callee.node.type and isinstance(callee.node.type, CallableType):
                ret_type = get_proper_type(callee.node.type.ret_type)

                if isinstance(ret_type, Instance):
                    return type_id_for_fullname(ret_type.type.fullname)

            # a function without a usable CallableType (or a non-Instance
            # return) cannot be classified; do not fall through to the
            # alias/class branches
            return None
        elif isinstance(callee.node, TypeAlias):
            target_type = get_proper_type(callee.node.target)
            if isinstance(target_type, Instance):
                return type_id_for_fullname(target_type.type.fullname)
        elif isinstance(callee.node, TypeInfo):
            # note: classified via the expression, not callee.node, so the
            # expression's own name/fullname are what get matched
            return type_id_for_named_node(callee)
    return None
|
||||
|
||||
|
||||
def type_id_for_named_node(
    node: Union[NameExpr, MemberExpr, SymbolNode]
) -> Optional[int]:
    """Classify a named node via the lookup table, confirming its
    fully-qualified name against the known locations."""
    candidate_id, fullnames = _lookup.get(node.name, (None, None))

    if candidate_id is not None and fullnames is not None:
        if node.fullname in fullnames:
            return candidate_id
    return None
|
||||
|
||||
|
||||
def type_id_for_fullname(fullname: str) -> Optional[int]:
    """Classify a dotted path by its last component, confirming the
    whole path against the known locations for that name."""
    short_name = fullname.rsplit(".", 1)[-1]

    candidate_id, fullnames = _lookup.get(short_name, (None, None))

    if candidate_id is None or fullnames is None:
        return None
    return candidate_id if fullname in fullnames else None
|
||||
284
lib/sqlalchemy/ext/mypy/plugin.py
Normal file
284
lib/sqlalchemy/ext/mypy/plugin.py
Normal file
@@ -0,0 +1,284 @@
|
||||
# ext/mypy/plugin.py
|
||||
# Copyright (C) 2021 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
Mypy plugin for SQLAlchemy ORM.
|
||||
|
||||
"""
|
||||
from typing import Callable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Type as TypingType
|
||||
from typing import Union
|
||||
|
||||
from mypy import nodes
|
||||
from mypy.mro import calculate_mro
|
||||
from mypy.mro import MroError
|
||||
from mypy.nodes import Block
|
||||
from mypy.nodes import ClassDef
|
||||
from mypy.nodes import GDEF
|
||||
from mypy.nodes import MypyFile
|
||||
from mypy.nodes import NameExpr
|
||||
from mypy.nodes import SymbolTable
|
||||
from mypy.nodes import SymbolTableNode
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.plugin import AttributeContext
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.plugin import DynamicClassDefContext
|
||||
from mypy.plugin import Plugin
|
||||
from mypy.plugin import SemanticAnalyzerPluginInterface
|
||||
from mypy.types import get_proper_type
|
||||
from mypy.types import Instance
|
||||
from mypy.types import Type
|
||||
|
||||
from . import decl_class
|
||||
from . import names
|
||||
from . import util
|
||||
|
||||
|
||||
class SQLAlchemyPlugin(Plugin):
    """Mypy plugin entry points for SQLAlchemy declarative mappings.

    Each ``get_*`` hook inspects the fully-qualified name mypy is
    currently processing and, when it corresponds to a SQLAlchemy
    construct (as classified by the :mod:`.names` lookup table),
    returns the callback that applies the ORM-specific transformation.
    """

    def get_dynamic_class_hook(
        self, fullname: str
    ) -> Optional[Callable[[DynamicClassDefContext], None]]:
        # ``Base = declarative_base()`` creates a class dynamically
        if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
            return _dynamic_class_hook
        return None

    def get_customize_class_mro_hook(
        self, fullname: str
    ) -> Optional[Callable[[ClassDefContext], None]]:
        # unconditional: every class is visited so that MemberExpr
        # decorators such as ``@reg.mapped`` get a fullname assigned
        return _fill_in_decorators

    def get_class_decorator_hook(
        self, fullname: str
    ) -> Optional[Callable[[ClassDefContext], None]]:

        sym = self.lookup_fully_qualified(fullname)

        if sym is not None and sym.node is not None:
            type_id = names.type_id_for_named_node(sym.node)
            if type_id is names.MAPPED_DECORATOR:
                # @registry.mapped
                return _cls_decorator_hook
            elif type_id in (
                names.AS_DECLARATIVE,
                names.AS_DECLARATIVE_BASE,
            ):
                # @as_declarative() / @registry.as_declarative_base()
                return _base_cls_decorator_hook
            elif type_id is names.DECLARATIVE_MIXIN:
                # @declarative_mixin
                return _declarative_mixin_hook

        return None

    def get_metaclass_hook(
        self, fullname: str
    ) -> Optional[Callable[[ClassDefContext], None]]:
        if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META:
            # Set any classes that explicitly have metaclass=DeclarativeMeta
            # as declarative so the check in `get_base_class_hook()` works
            return _metaclass_cls_hook

        return None

    def get_base_class_hook(
        self, fullname: str
    ) -> Optional[Callable[[ClassDefContext], None]]:
        sym = self.lookup_fully_qualified(fullname)

        # any subclass of a class previously marked as declarative base
        # (see util.set_is_base) gets the declarative scan applied
        if (
            sym
            and isinstance(sym.node, TypeInfo)
            and util.has_declarative_base(sym.node)
        ):
            return _base_cls_hook

        return None

    def get_attribute_hook(
        self, fullname: str
    ) -> Optional[Callable[[AttributeContext], Type]]:
        if fullname.startswith(
            "sqlalchemy.orm.attributes.QueryableAttribute."
        ):
            return _queryable_getattr_hook

        return None

    def get_additional_deps(
        self, file: MypyFile
    ) -> List[Tuple[int, str, int]]:
        # every analyzed module implicitly depends on these two modules;
        # tuples are (priority, module name, line number)
        return [
            (10, "sqlalchemy.orm.attributes", -1),
            (10, "sqlalchemy.orm.decl_api", -1),
        ]
|
||||
|
||||
|
||||
def plugin(version: str) -> TypingType[SQLAlchemyPlugin]:
    """Mypy plugin entry point.

    :param version: the running mypy version string (unused; the same
     plugin class is returned for all versions).
    """
    return SQLAlchemyPlugin
|
||||
|
||||
|
||||
def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
    """Generate a declarative Base class when the declarative_base() function
    is encountered.

    Synthesizes a ClassDef/TypeInfo named after the assignment target
    (``Base = declarative_base()`` produces class ``Base``), installs
    DeclarativeMeta as its metaclass, and registers it in the module's
    symbol table marked as a declarative base.
    """

    _add_globals(ctx)

    cls = ClassDef(ctx.name, Block([]))
    cls.fullname = ctx.api.qualified_name(ctx.name)

    info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
    cls.info = info
    _set_declarative_metaclass(ctx.api, cls)

    # declarative_base(cls=SomeClass): the given class becomes the base,
    # and is itself scanned as a mixin
    cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
    if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
        util.set_is_base(cls_arg.node)
        decl_class.scan_declarative_assignments_and_apply_types(
            cls_arg.node.defn, ctx.api, is_mixin_scan=True
        )
        info.bases = [Instance(cls_arg.node, [])]
    else:
        obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)

        info.bases = [obj]

    try:
        calculate_mro(info)
    except MroError:
        # degrade to an object-based, Any-fallback class rather than
        # aborting the whole type check
        util.fail(
            ctx.api, "Not able to calculate MRO for declarative base", ctx.call
        )
        obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
        info.bases = [obj]
        info.fallback_to_any = True

    ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
    util.set_is_base(info)
|
||||
|
||||
|
||||
def _fill_in_decorators(ctx: ClassDefContext) -> None:
    """Assign a ``fullname`` to ``registry.mapped`` /
    ``registry.as_declarative_base()`` class decorators so mypy routes
    them to ``get_class_decorator_hook``."""
    for decorator in ctx.cls.decorators:
        # set the ".fullname" attribute of a class decorator
        # that is a MemberExpr.  This causes the logic in
        # semanal.py->apply_class_plugin_hooks to invoke the
        # get_class_decorator_hook for our "registry.map_class()"
        # and "registry.as_declarative_base()" methods.
        # this seems like a bug in mypy that these decorators are otherwise
        # skipped.

        if (
            isinstance(decorator, nodes.CallExpr)
            and isinstance(decorator.callee, nodes.MemberExpr)
            and decorator.callee.name == "as_declarative_base"
        ):
            # @some_registry.as_declarative_base()
            target = decorator.callee
        elif (
            isinstance(decorator, nodes.MemberExpr)
            and decorator.name == "mapped"
        ):
            # @some_registry.mapped
            target = decorator
        else:
            continue

        assert isinstance(target.expr, NameExpr)
        sym = ctx.api.lookup_qualified(
            target.expr.name, target, suppress_errors=True
        )
        if sym and sym.node:
            sym_type = get_proper_type(sym.type)
            if isinstance(sym_type, Instance):
                # e.g. "sqlalchemy.orm.decl_api.registry.mapped"
                target.fullname = f"{sym_type.type.fullname}.{target.name}"
            else:
                # if the registry is in the same file as where the
                # decorator is used, it might not have semantic
                # symbols applied and we can't get a fully qualified
                # name or an inferred type, so we are actually going to
                # flag an error in this case that they need to annotate
                # it.  The "registry" is declared just
                # once (or few times), so they have to just not use
                # type inference for its assignment in this one case.
                util.fail(
                    ctx.api,
                    "Class decorator called %s(), but we can't "
                    "tell if it's from an ORM registry.  Please "
                    "annotate the registry assignment, e.g. "
                    "my_registry: registry = registry()" % target.name,
                    sym.node,
                )
|
||||
|
||||
|
||||
def _cls_decorator_hook(ctx: ClassDefContext) -> None:
    """Apply declarative mapping to a class decorated with
    ``@registry.mapped``."""
    _add_globals(ctx)
    # we are here because of a ``@<something>.mapped`` MemberExpr
    assert isinstance(ctx.reason, nodes.MemberExpr)
    expr = ctx.reason.expr

    assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var)

    node_type = get_proper_type(expr.node.type)

    # the dereferenced object must be an ORM registry instance
    assert (
        isinstance(node_type, Instance)
        and names.type_id_for_named_node(node_type.type) is names.REGISTRY
    )

    decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
|
||||
|
||||
|
||||
def _base_cls_decorator_hook(ctx: ClassDefContext) -> None:
    """Handle ``@as_declarative()`` / ``@registry.as_declarative_base()``:
    the decorated class becomes a declarative base."""
    _add_globals(ctx)

    cls = ctx.cls

    _set_declarative_metaclass(ctx.api, cls)

    # mark as base first, then scan its own attributes as a mixin scan
    util.set_is_base(ctx.cls.info)
    decl_class.scan_declarative_assignments_and_apply_types(
        cls, ctx.api, is_mixin_scan=True
    )
|
||||
|
||||
|
||||
def _declarative_mixin_hook(ctx: ClassDefContext) -> None:
    """Handle ``@declarative_mixin``: scan the class as a mixin and mark
    it so subclasses get declarative treatment."""
    _add_globals(ctx)
    util.set_is_base(ctx.cls.info)
    decl_class.scan_declarative_assignments_and_apply_types(
        ctx.cls, ctx.api, is_mixin_scan=True
    )
|
||||
|
||||
|
||||
def _metaclass_cls_hook(ctx: ClassDefContext) -> None:
    """Mark classes declared with ``metaclass=DeclarativeMeta`` as
    declarative bases (see ``get_metaclass_hook``)."""
    util.set_is_base(ctx.cls.info)
|
||||
|
||||
|
||||
def _base_cls_hook(ctx: ClassDefContext) -> None:
    """Scan a class that subclasses a known declarative base."""
    _add_globals(ctx)
    decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
|
||||
|
||||
|
||||
def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
    """Attribute hook for ``QueryableAttribute.*``; currently a
    pass-through that returns mypy's default attribute type."""
    # how do I....tell it it has no attribute of a certain name?
    # can't find any Type that seems to match that
    return ctx.default_attr_type
|
||||
|
||||
|
||||
def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
    """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space
    for all class defs

    """
    # NOTE(review): despite the docstring mentioning __sa_DeclarativeMeta,
    # only __sa_Mapped is installed here - confirm whether that is stale

    util.add_global(ctx, "sqlalchemy.orm.attributes", "Mapped", "__sa_Mapped")
|
||||
|
||||
|
||||
def _set_declarative_metaclass(
    api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
) -> None:
    """Install ``DeclarativeMeta`` as both the declared and effective
    metaclass of *target_cls*."""
    sym = api.lookup_fully_qualified_or_none(
        "sqlalchemy.orm.decl_api.DeclarativeMeta"
    )
    assert sym is not None and isinstance(sym.node, TypeInfo)

    meta_instance = Instance(sym.node, [])
    info = target_cls.info
    info.declared_metaclass = meta_instance
    info.metaclass_type = meta_instance
|
||||
305
lib/sqlalchemy/ext/mypy/util.py
Normal file
305
lib/sqlalchemy/ext/mypy/util.py
Normal file
@@ -0,0 +1,305 @@
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Iterable
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import overload
|
||||
from typing import Tuple
|
||||
from typing import Type as TypingType
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
|
||||
from mypy.nodes import ARG_POS
|
||||
from mypy.nodes import CallExpr
|
||||
from mypy.nodes import ClassDef
|
||||
from mypy.nodes import CLASSDEF_NO_INFO
|
||||
from mypy.nodes import Context
|
||||
from mypy.nodes import Expression
|
||||
from mypy.nodes import IfStmt
|
||||
from mypy.nodes import JsonDict
|
||||
from mypy.nodes import MemberExpr
|
||||
from mypy.nodes import NameExpr
|
||||
from mypy.nodes import Statement
|
||||
from mypy.nodes import SymbolTableNode
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.plugin import DynamicClassDefContext
|
||||
from mypy.plugin import SemanticAnalyzerPluginInterface
|
||||
from mypy.plugins.common import deserialize_and_fixup_type
|
||||
from mypy.typeops import map_type_from_supertype
|
||||
from mypy.types import Instance
|
||||
from mypy.types import NoneType
|
||||
from mypy.types import Type
|
||||
from mypy.types import TypeVarType
|
||||
from mypy.types import UnboundType
|
||||
from mypy.types import UnionType
|
||||
|
||||
|
||||
_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
|
||||
|
||||
|
||||
class SQLAlchemyAttribute:
    """A single mapped attribute discovered on a declarative class, in a
    form that can be serialized into mypy's cross-run metadata cache and
    restored on a later run."""

    def __init__(
        self,
        name: str,
        line: int,
        column: int,
        typ: Optional[Type],
        info: TypeInfo,
    ) -> None:
        # attribute name as it appears on the mapped class
        self.name = name
        # source position of the defining assignment
        self.line = line
        self.column = column
        # the mapped (Python-side) type, if one could be determined
        self.type = typ
        # TypeInfo of the class the attribute was declared on
        self.info = info

    def serialize(self) -> JsonDict:
        """Serialize to a JSON-compatible dict for mypy's metadata cache.

        Requires ``self.type`` to be set.
        """
        assert self.type
        return {
            "name": self.name,
            "line": self.line,
            "column": self.column,
            "type": self.type.serialize(),
        }

    def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
        """Expands type vars in the context of a subtype when an attribute is
        inherited from a generic super type.
        """
        # only TypeVarType needs mapping; concrete types are unchanged
        if not isinstance(self.type, TypeVarType):
            return

        self.type = map_type_from_supertype(self.type, sub_type, self.info)

    @classmethod
    def deserialize(
        cls,
        info: TypeInfo,
        data: JsonDict,
        api: SemanticAnalyzerPluginInterface,
    ) -> "SQLAlchemyAttribute":
        """Reconstruct an attribute from a dict produced by :meth:`serialize`.

        :param info: the TypeInfo the restored attribute belongs to.
        """
        # copy so popping "type" does not mutate the cached dict
        data = data.copy()
        typ = deserialize_and_fixup_type(data.pop("type"), api)
        return cls(typ=typ, info=info, **data)
|
||||
|
||||
|
||||
def name_is_dunder(name: str) -> bool:
    """Return True if *name* is a "dunder" name, e.g. ``__tablename__``.

    The name must start and end with double underscores with at least
    one character in between, so ``"____"`` does not qualify.

    Annotations added for consistency with the rest of this module,
    which is fully annotated.
    """
    return bool(re.match(r"^__.+?__$", name))
|
||||
|
||||
|
||||
def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None:
|
||||
info.metadata.setdefault("sqlalchemy", {})[key] = data
|
||||
|
||||
|
||||
def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]:
|
||||
return info.metadata.get("sqlalchemy", {}).get(key, None)
|
||||
|
||||
|
||||
def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]:
    """Fetch *key* from the metadata of the first class in the MRO that
    has it; None when no ancestor provides a value (or the MRO is not
    yet computed)."""
    for ancestor in info.mro or ():
        value = _get_info_metadata(ancestor, key)
        if value is not None:
            return value
    return None
|
||||
|
||||
|
||||
def establish_as_sqlalchemy(info: TypeInfo) -> None:
    """Ensure the class carries a (possibly empty) "sqlalchemy" metadata
    namespace, preserving any existing contents."""
    if "sqlalchemy" not in info.metadata:
        info.metadata["sqlalchemy"] = {}
|
||||
|
||||
|
||||
def set_is_base(info: TypeInfo) -> None:
    """Mark *info* as a declarative base (or mixin) in its metadata."""
    _set_info_metadata(info, "is_base", True)
|
||||
|
||||
|
||||
def get_is_base(info: TypeInfo) -> bool:
    """Return True only if ``is_base`` was explicitly set to True."""
    return _get_info_metadata(info, "is_base") is True
|
||||
|
||||
|
||||
def has_declarative_base(info: TypeInfo) -> bool:
    """Return True if any class along the MRO carries ``is_base`` = True."""
    return _get_info_mro_metadata(info, "is_base") is True
|
||||
|
||||
|
||||
def set_has_table(info: TypeInfo) -> None:
    """Record the ``has_table`` flag in the class's SQLAlchemy metadata."""
    _set_info_metadata(info, "has_table", True)
|
||||
|
||||
|
||||
def get_has_table(info: TypeInfo) -> bool:
    """Return True only if ``has_table`` was explicitly set to True."""
    return _get_info_metadata(info, "has_table") is True
|
||||
|
||||
|
||||
def get_mapped_attributes(
    info: TypeInfo, api: SemanticAnalyzerPluginInterface
) -> Optional[List[SQLAlchemyAttribute]]:
    """Deserialize the mapped attributes recorded on ``info``.

    Returns None when no ``mapped_attributes`` metadata is present.
    Each attribute has its type variables expanded in the context of
    ``info`` itself before being returned.
    """
    serialized: Optional[List[JsonDict]] = _get_info_metadata(
        info, "mapped_attributes"
    )
    if serialized is None:
        return None

    result: List[SQLAlchemyAttribute] = []
    for item in serialized:
        attribute = SQLAlchemyAttribute.deserialize(info, item, api)
        attribute.expand_typevar_from_subtype(info)
        result.append(attribute)
    return result
|
||||
|
||||
|
||||
def set_mapped_attributes(
    info: TypeInfo, attributes: List[SQLAlchemyAttribute]
) -> None:
    """Serialize ``attributes`` and store them on ``info``'s metadata."""
    serialized = [attribute.serialize() for attribute in attributes]
    _set_info_metadata(info, "mapped_attributes", serialized)
|
||||
|
||||
|
||||
def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
    """Report a plugin error, prefixed so users can identify its origin."""
    return api.fail("[SQLAlchemy Mypy plugin] %s" % msg, ctx)
|
||||
|
||||
|
||||
def add_global(
    ctx: Union[ClassDefContext, DynamicClassDefContext],
    module: str,
    symbol_name: str,
    asname: str,
) -> None:
    """Alias ``module.symbol_name`` as ``asname`` in the current module.

    Does nothing when a symbol named ``asname`` already exists in the
    module currently being analyzed.
    """
    current_module_names = ctx.api.modules[ctx.api.cur_mod_id].names
    if asname in current_module_names:
        return

    lookup_sym: SymbolTableNode = ctx.api.modules[module].names[symbol_name]
    current_module_names[asname] = lookup_sym
|
||||
|
||||
|
||||
@overload
def get_callexpr_kwarg(
    callexpr: CallExpr, name: str, *, expr_types: None = ...
) -> Optional[Union[CallExpr, NameExpr]]:
    ...


@overload
def get_callexpr_kwarg(
    callexpr: CallExpr,
    name: str,
    *,
    expr_types: Tuple[TypingType[_TArgType], ...]
) -> Optional[_TArgType]:
    ...


def get_callexpr_kwarg(
    callexpr: CallExpr,
    name: str,
    *,
    expr_types: Optional[Tuple[TypingType[Any], ...]] = None
) -> Optional[Any]:
    """Return the keyword argument ``name`` from a call expression.

    The argument is returned only when it is an instance of one of
    ``expr_types`` (defaulting to ``NameExpr`` / ``CallExpr``);
    otherwise — or when the keyword is absent — None is returned.
    """
    if name not in callexpr.arg_names:
        return None

    kwarg = callexpr.args[callexpr.arg_names.index(name)]
    accepted = expr_types if expr_types is not None else (NameExpr, CallExpr)
    if isinstance(kwarg, accepted):
        return kwarg
    return None
|
||||
|
||||
|
||||
def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
    """Iterate statements, descending into ``if TYPE_CHECKING:`` guards.

    Statements inside the body of an ``if typing.TYPE_CHECKING:`` block
    are yielded individually; every other statement is yielded as-is.
    """
    for stmt in stmts:
        is_type_checking_guard = (
            isinstance(stmt, IfStmt)
            and isinstance(stmt.expr[0], NameExpr)
            and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
        )
        if is_type_checking_guard:
            yield from stmt.body[0].body
        else:
            yield stmt
|
||||
|
||||
|
||||
def unbound_to_instance(
    api: SemanticAnalyzerPluginInterface, typ: Type
) -> Type:
    """Take the UnboundType that we seem to get as the ret_type from a FuncDef

    and convert it into an Instance/TypeInfo kind of structure that seems
    to work as the left-hand type of an AssignmentStatement.

    """

    # Anything that is already bound can be used as-is.
    if not isinstance(typ, UnboundType):
        return typ

    # TODO: figure out a more robust way to check this.  The node is some
    # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
    # but I cant figure out how to get them to match up
    if typ.name == "Optional":
        # convert from "Optional?" to the more familiar
        # UnionType[..., NoneType()]; the recursion binds the inner args
        # and then the resulting UnionType falls through the isinstance
        # guard above on the recursive call.
        return unbound_to_instance(
            api,
            UnionType(
                [unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
                + [NoneType()]
            ),
        )

    # Resolve the unbound name through the semantic analyzer's namespace.
    node = api.lookup_qualified(typ.name, typ)

    if (
        node is not None
        and isinstance(node, SymbolTableNode)
        and isinstance(node.node, TypeInfo)
    ):
        bound_type = node.node

        # Build an Instance of the resolved class, recursively binding
        # any type arguments that are themselves still unbound.
        return Instance(
            bound_type,
            [
                unbound_to_instance(api, arg)
                if isinstance(arg, UnboundType)
                else arg
                for arg in typ.args
            ],
        )
    else:
        # Name did not resolve to a class; return the original type
        # unchanged rather than guessing.
        return typ
|
||||
|
||||
|
||||
def info_for_cls(
    cls: ClassDef, api: SemanticAnalyzerPluginInterface
) -> Optional[TypeInfo]:
    """Return the TypeInfo for a class definition.

    Falls back to a qualified-name lookup when the ClassDef has no info
    attached yet; returns None if that lookup fails.
    """
    if cls.info is not CLASSDEF_NO_INFO:
        return cls.info

    sym = api.lookup_qualified(cls.name, cls)
    if sym is None:
        return None
    assert sym and isinstance(sym.node, TypeInfo)
    return sym.node
|
||||
|
||||
|
||||
def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
    """Wrap ``expr`` in a synthesized ``Mapped._empty_constructor(expr)`` call."""
    mapped_name = NameExpr("__sa_Mapped")
    mapped_name.fullname = "sqlalchemy.orm.attributes.Mapped"
    constructor = MemberExpr(mapped_name, "_empty_constructor")
    return CallExpr(constructor, [expr], [ARG_POS], ["arg1"])
|
||||
388
lib/sqlalchemy/ext/orderinglist.py
Normal file
388
lib/sqlalchemy/ext/orderinglist.py
Normal file
@@ -0,0 +1,388 @@
|
||||
# ext/orderinglist.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""A custom list that manages index/position information for contained
|
||||
elements.
|
||||
|
||||
:author: Jason Kirtland
|
||||
|
||||
``orderinglist`` is a helper for mutable ordered relationships. It will
|
||||
intercept list operations performed on a :func:`_orm.relationship`-managed
|
||||
collection and
|
||||
automatically synchronize changes in list position onto a target scalar
|
||||
attribute.
|
||||
|
||||
Example: A ``slide`` table, where each row refers to zero or more entries
|
||||
in a related ``bullet`` table. The bullets within a slide are
|
||||
displayed in order based on the value of the ``position`` column in the
|
||||
``bullet`` table. As entries are reordered in memory, the value of the
|
||||
``position`` attribute should be updated to reflect the new sort order::
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class Slide(Base):
|
||||
__tablename__ = 'slide'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String)
|
||||
|
||||
bullets = relationship("Bullet", order_by="Bullet.position")
|
||||
|
||||
class Bullet(Base):
|
||||
__tablename__ = 'bullet'
|
||||
id = Column(Integer, primary_key=True)
|
||||
slide_id = Column(Integer, ForeignKey('slide.id'))
|
||||
position = Column(Integer)
|
||||
text = Column(String)
|
||||
|
||||
The standard relationship mapping will produce a list-like attribute on each
|
||||
``Slide`` containing all related ``Bullet`` objects,
|
||||
but coping with changes in ordering is not handled automatically.
|
||||
When appending a ``Bullet`` into ``Slide.bullets``, the ``Bullet.position``
|
||||
attribute will remain unset until manually assigned. When the ``Bullet``
|
||||
is inserted into the middle of the list, the following ``Bullet`` objects
|
||||
will also need to be renumbered.
|
||||
|
||||
The :class:`.OrderingList` object automates this task, managing the
|
||||
``position`` attribute on all ``Bullet`` objects in the collection. It is
|
||||
constructed using the :func:`.ordering_list` factory::
|
||||
|
||||
from sqlalchemy.ext.orderinglist import ordering_list
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class Slide(Base):
|
||||
__tablename__ = 'slide'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String)
|
||||
|
||||
bullets = relationship("Bullet", order_by="Bullet.position",
|
||||
collection_class=ordering_list('position'))
|
||||
|
||||
class Bullet(Base):
|
||||
__tablename__ = 'bullet'
|
||||
id = Column(Integer, primary_key=True)
|
||||
slide_id = Column(Integer, ForeignKey('slide.id'))
|
||||
position = Column(Integer)
|
||||
text = Column(String)
|
||||
|
||||
With the above mapping the ``Bullet.position`` attribute is managed::
|
||||
|
||||
s = Slide()
|
||||
s.bullets.append(Bullet())
|
||||
s.bullets.append(Bullet())
|
||||
s.bullets[1].position
|
||||
>>> 1
|
||||
s.bullets.insert(1, Bullet())
|
||||
s.bullets[2].position
|
||||
>>> 2
|
||||
|
||||
The :class:`.OrderingList` construct only works with **changes** to a
|
||||
collection, and not the initial load from the database, and requires that the
|
||||
list be sorted when loaded. Therefore, be sure to specify ``order_by`` on the
|
||||
:func:`_orm.relationship` against the target ordering attribute, so that the
|
||||
ordering is correct when first loaded.
|
||||
|
||||
.. warning::
|
||||
|
||||
:class:`.OrderingList` only provides limited functionality when a primary
|
||||
key column or unique column is the target of the sort. Operations
|
||||
that are unsupported or are problematic include:
|
||||
|
||||
* two entries must trade values. This is not supported directly in the
|
||||
case of a primary key or unique constraint because it means at least
|
||||
one row would need to be temporarily removed first, or changed to
|
||||
a third, neutral value while the switch occurs.
|
||||
|
||||
* an entry must be deleted in order to make room for a new entry.
|
||||
SQLAlchemy's unit of work performs all INSERTs before DELETEs within a
|
||||
single flush. In the case of a primary key, it will trade
|
||||
an INSERT/DELETE of the same primary key for an UPDATE statement in order
|
||||
to lessen the impact of this limitation, however this does not take place
|
||||
for a UNIQUE column.
|
||||
A future feature will allow the "DELETE before INSERT" behavior to be
|
||||
possible, alleviating this limitation, though this feature will require
|
||||
explicit configuration at the mapper level for sets of columns that
|
||||
are to be handled in this way.
|
||||
|
||||
:func:`.ordering_list` takes the name of the related object's ordering
|
||||
attribute as an argument. By default, the zero-based integer index of the
|
||||
object's position in the :func:`.ordering_list` is synchronized with the
|
||||
ordering attribute: index 0 will get position 0, index 1 position 1, etc. To
|
||||
start numbering at 1 or some other integer, provide ``count_from=1``.
|
||||
|
||||
|
||||
"""
|
||||
from ..orm.collections import collection
|
||||
from ..orm.collections import collection_adapter
|
||||
|
||||
|
||||
__all__ = ["ordering_list"]
|
||||
|
||||
|
||||
def ordering_list(attr, count_from=None, **kw):
    """Prepare an :class:`OrderingList` factory for use in mapper definitions.

    Returns a callable suitable for use as the ``collection_class``
    option of a :func:`_orm.relationship`. e.g.::

        from sqlalchemy.ext.orderinglist import ordering_list

        class Slide(Base):
            __tablename__ = 'slide'

            id = Column(Integer, primary_key=True)
            name = Column(String)

            bullets = relationship("Bullet", order_by="Bullet.position",
                                    collection_class=ordering_list('position'))

    :param attr:
        Name of the mapped attribute to use for storage and retrieval of
        ordering information.

    :param count_from:
        Set up an integer-based ordering, starting at ``count_from``.
        For example, ``ordering_list('pos', count_from=1)`` would create
        a 1-based list in SQL, storing the value in the 'pos' column.
        Ignored if ``ordering_func`` is supplied.

    Additional arguments are passed to the :class:`.OrderingList`
    constructor.
    """
    options = _unsugar_count_from(count_from=count_from, **kw)

    def factory():
        return OrderingList(attr, **options)

    return factory
|
||||
|
||||
|
||||
# Ordering utility functions
|
||||
|
||||
|
||||
def count_from_0(index, collection):
    """Numbering function: consecutive integers starting at 0.

    The ``collection`` argument is accepted for interface compatibility
    and is unused.
    """
    return index
|
||||
|
||||
|
||||
def count_from_1(index, collection):
    """Numbering function: consecutive integers starting at 1.

    The ``collection`` argument is accepted for interface compatibility
    and is unused.
    """
    return 1 + index
|
||||
|
||||
|
||||
def count_from_n_factory(start):
    """Numbering function factory: consecutive integers from ``start``."""

    def numbering(index, collection):
        return index + start

    # Give the function a descriptive name where ``start`` allows it;
    # non-integer-formattable starts simply keep the generic name.
    try:
        numbering.__name__ = "count_from_%i" % start
    except TypeError:
        pass
    return numbering
|
||||
|
||||
|
||||
def _unsugar_count_from(**kw):
|
||||
"""Builds counting functions from keyword arguments.
|
||||
|
||||
Keyword argument filter, prepares a simple ``ordering_func`` from a
|
||||
``count_from`` argument, otherwise passes ``ordering_func`` on unchanged.
|
||||
"""
|
||||
|
||||
count_from = kw.pop("count_from", None)
|
||||
if kw.get("ordering_func", None) is None and count_from is not None:
|
||||
if count_from == 0:
|
||||
kw["ordering_func"] = count_from_0
|
||||
elif count_from == 1:
|
||||
kw["ordering_func"] = count_from_1
|
||||
else:
|
||||
kw["ordering_func"] = count_from_n_factory(count_from)
|
||||
return kw
|
||||
|
||||
|
||||
class OrderingList(list):
    """A custom list that manages position information for its children.

    The :class:`.OrderingList` object is normally set up using the
    :func:`.ordering_list` factory function, used in conjunction with
    the :func:`_orm.relationship` function.

    """

    def __init__(
        self, ordering_attr=None, ordering_func=None, reorder_on_append=False
    ):
        """A custom list that manages position information for its children.

        ``OrderingList`` is a ``collection_class`` list implementation that
        syncs position in a Python list with a position attribute on the
        mapped objects.

        This implementation relies on the list starting in the proper order,
        so be **sure** to put an ``order_by`` on your relationship.

        :param ordering_attr:
          Name of the attribute that stores the object's order in the
          relationship.

        :param ordering_func: Optional.  A function that maps the position in
          the Python list to a value to store in the
          ``ordering_attr``.  Values returned are usually (but need not be!)
          integers.

          An ``ordering_func`` is called with two positional parameters: the
          index of the element in the list, and the list itself.

          If omitted, Python list indexes are used for the attribute values.
          Two basic pre-built numbering functions are provided in this module:
          ``count_from_0`` and ``count_from_1``.  For more exotic examples
          like stepped numbering, alphabetical and Fibonacci numbering, see
          the unit tests.

        :param reorder_on_append:
          Default False.  When appending an object with an existing (non-None)
          ordering value, that value will be left untouched unless
          ``reorder_on_append`` is true.  This is an optimization to avoid a
          variety of dangerous unexpected database writes.

          SQLAlchemy will add instances to the list via append() when your
          object loads.  If for some reason the result set from the database
          skips a step in the ordering (say, row '1' is missing but you get
          '2', '3', and '4'), reorder_on_append=True would immediately
          renumber the items to '1', '2', '3'.  If you have multiple sessions
          making changes, any of whom happen to load this collection even in
          passing, all of the sessions would try to "clean up" the numbering
          in their commits, possibly causing all but one to fail with a
          concurrent modification error.

          Recommend leaving this with the default of False, and just call
          ``reorder()`` if you're doing ``append()`` operations with
          previously ordered instances or when doing some housekeeping after
          manual sql operations.

        """
        self.ordering_attr = ordering_attr
        if ordering_func is None:
            # default numbering: plain zero-based list index
            ordering_func = count_from_0
        self.ordering_func = ordering_func
        self.reorder_on_append = reorder_on_append

    # More complex serialization schemes (multi column, e.g.) are possible by
    # subclassing and reimplementing these two methods.
    def _get_order_value(self, entity):
        # Read the ordering attribute from a child object.
        return getattr(entity, self.ordering_attr)

    def _set_order_value(self, entity, value):
        # Write the ordering attribute on a child object.
        setattr(entity, self.ordering_attr, value)

    def reorder(self):
        """Synchronize ordering for the entire collection.

        Sweeps through the list and ensures that each object has accurate
        ordering information set.

        """
        for index, entity in enumerate(self):
            self._order_entity(index, entity, True)

    # As of 0.5, _reorder is no longer semi-private
    _reorder = reorder

    def _order_entity(self, index, entity, reorder=True):
        # Assign ``entity`` the ordering value for position ``index``.
        have = self._get_order_value(entity)

        # Don't disturb existing ordering if reorder is False
        if have is not None and not reorder:
            return

        # Only write the attribute when the value actually changes, to
        # avoid needless dirtying of the object.
        should_be = self.ordering_func(index, self)
        if have != should_be:
            self._set_order_value(entity, should_be)

    def append(self, entity):
        # Append, then number the new tail element (respecting
        # reorder_on_append for already-numbered entities).
        super(OrderingList, self).append(entity)
        self._order_entity(len(self) - 1, entity, self.reorder_on_append)

    def _raw_append(self, entity):
        """Append without any ordering behavior."""

        super(OrderingList, self).append(entity)

    # Register _raw_append as the collection appender (used by the ORM
    # when loading rows, bypassing renumbering).
    _raw_append = collection.adds(1)(_raw_append)

    def insert(self, index, entity):
        # Insertion shifts everything after ``index``; renumber all.
        super(OrderingList, self).insert(index, entity)
        self._reorder()

    def remove(self, entity):
        super(OrderingList, self).remove(entity)

        # Only renumber when the collection is still attached to its
        # owning object's attribute.
        adapter = collection_adapter(self)
        if adapter and adapter._referenced_by_owner:
            self._reorder()

    def pop(self, index=-1):
        entity = super(OrderingList, self).pop(index)
        self._reorder()
        return entity

    def __setitem__(self, index, entity):
        if isinstance(index, slice):
            # Normalize the slice bounds and assign element-wise so each
            # assignment renumbers its position.
            # NOTE(review): ``or``-based defaults treat explicit 0 stops
            # as "unset" — presumably fine for the slices seen in
            # practice; confirm before relying on stop=0 slices.
            step = index.step or 1
            start = index.start or 0
            if start < 0:
                start += len(self)
            stop = index.stop or len(self)
            if stop < 0:
                stop += len(self)

            for i in range(start, stop, step):
                self.__setitem__(i, entity[i])
        else:
            self._order_entity(index, entity, True)
            super(OrderingList, self).__setitem__(index, entity)

    def __delitem__(self, index):
        super(OrderingList, self).__delitem__(index)
        self._reorder()

    # __setslice__/__delslice__ are the legacy Python 2 slice protocol;
    # Python 3 routes slicing through __setitem__/__delitem__ instead.
    def __setslice__(self, start, end, values):
        super(OrderingList, self).__setslice__(start, end, values)
        self._reorder()

    def __delslice__(self, start, end):
        super(OrderingList, self).__delslice__(start, end)
        self._reorder()

    def __reduce__(self):
        # Pickle via the module-level _reconstitute() helper, preserving
        # both instance state and list contents.
        return _reconstitute, (self.__class__, self.__dict__, list(self))

    # Copy list's docstrings onto any undocumented overriding methods
    # defined above; runs at class-definition time over the class body.
    for func_name, func in list(locals().items()):
        if (
            callable(func)
            and func.__name__ == func_name
            and not func.__doc__
            and hasattr(list, func_name)
        ):
            func.__doc__ = getattr(list, func_name).__doc__
    del func_name, func
|
||||
|
||||
|
||||
def _reconstitute(cls, dict_, items):
|
||||
"""Reconstitute an :class:`.OrderingList`.
|
||||
|
||||
This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for
|
||||
unpickling :class:`.OrderingList` objects.
|
||||
|
||||
"""
|
||||
obj = cls.__new__(cls)
|
||||
obj.__dict__.update(dict_)
|
||||
list.extend(obj, items)
|
||||
return obj
|
||||
177
lib/sqlalchemy/ext/serializer.py
Normal file
177
lib/sqlalchemy/ext/serializer.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# ext/serializer.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Serializer/Deserializer objects for usage with SQLAlchemy query structures,
|
||||
allowing "contextual" deserialization.
|
||||
|
||||
Any SQLAlchemy query structure, either based on sqlalchemy.sql.*
|
||||
or sqlalchemy.orm.* can be used. The mappers, Tables, Columns, Session
|
||||
etc. which are referenced by the structure are not persisted in serialized
|
||||
form, but are instead re-associated with the query structure
|
||||
when it is deserialized.
|
||||
|
||||
Usage is nearly the same as that of the standard Python pickle module::
|
||||
|
||||
from sqlalchemy.ext.serializer import loads, dumps
|
||||
metadata = MetaData(bind=some_engine)
|
||||
Session = scoped_session(sessionmaker())
|
||||
|
||||
# ... define mappers
|
||||
|
||||
query = Session.query(MyClass).
|
||||
filter(MyClass.somedata=='foo').order_by(MyClass.sortkey)
|
||||
|
||||
# pickle the query
|
||||
serialized = dumps(query)
|
||||
|
||||
# unpickle. Pass in metadata + scoped_session
|
||||
query2 = loads(serialized, metadata, Session)
|
||||
|
||||
print query2.all()
|
||||
|
||||
Similar restrictions as when using raw pickle apply; mapped classes must be
|
||||
themselves be pickleable, meaning they are importable from a module-level
|
||||
namespace.
|
||||
|
||||
The serializer module is only appropriate for query structures. It is not
|
||||
needed for:
|
||||
|
||||
* instances of user-defined classes. These contain no references to engines,
|
||||
sessions or expression constructs in the typical case and can be serialized
|
||||
directly.
|
||||
|
||||
* Table metadata that is to be loaded entirely from the serialized structure
|
||||
(i.e. is not already declared in the application). Regular
|
||||
pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object,
|
||||
typically one which was reflected from an existing database at some previous
|
||||
point in time. The serializer module is specifically for the opposite case,
|
||||
where the Table metadata is already present in memory.
|
||||
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from .. import Column
|
||||
from .. import Table
|
||||
from ..engine import Engine
|
||||
from ..orm import class_mapper
|
||||
from ..orm.interfaces import MapperProperty
|
||||
from ..orm.mapper import Mapper
|
||||
from ..orm.session import Session
|
||||
from ..util import b64decode
|
||||
from ..util import b64encode
|
||||
from ..util import byte_buffer
|
||||
from ..util import pickle
|
||||
from ..util import text_type
|
||||
|
||||
|
||||
__all__ = ["Serializer", "Deserializer", "dumps", "loads"]
|
||||
|
||||
|
||||
def Serializer(*args, **kw):
    """Return a Pickler whose ``persistent_id`` externalizes SQLAlchemy
    structures.

    Mappers, mapper properties, tables, columns, sessions and engines
    are replaced by string tokens of the form ``"<type>:<payload>"``;
    :func:`Deserializer` re-associates them on load.  All arguments are
    passed through to ``pickle.Pickler``.
    """
    pickler = pickle.Pickler(*args, **kw)

    def persistent_id(obj):
        # Return a token for SQLAlchemy constructs; None lets pickle
        # serialize the object normally.
        # print "serializing:", repr(obj)
        if isinstance(obj, Mapper) and not obj.non_primary:
            # Primary mappers are identified by their pickled class.
            id_ = "mapper:" + b64encode(pickle.dumps(obj.class_))
        elif isinstance(obj, MapperProperty) and not obj.parent.non_primary:
            # Properties are identified by parent class plus property key.
            id_ = (
                "mapperprop:"
                + b64encode(pickle.dumps(obj.parent.class_))
                + ":"
                + obj.key
            )
        elif isinstance(obj, Table):
            if "parententity" in obj._annotations:
                # A mapped selectable; restored via its mapped class.
                id_ = "mapper_selectable:" + b64encode(
                    pickle.dumps(obj._annotations["parententity"].class_)
                )
            else:
                id_ = "table:" + text_type(obj.key)
        elif isinstance(obj, Column) and isinstance(obj.table, Table):
            # Columns are addressed by table key + column key.
            id_ = (
                "column:" + text_type(obj.table.key) + ":" + text_type(obj.key)
            )
        elif isinstance(obj, Session):
            # Payload-less tokens: resolved purely from the
            # deserialization context.
            id_ = "session:"
        elif isinstance(obj, Engine):
            id_ = "engine:"
        else:
            return None
        return id_

    pickler.persistent_id = persistent_id
    return pickler
|
||||
|
||||
|
||||
# Matches the persistent-id tokens handled by Deserializer: group 1 is
# the token type, group 2 the payload.  (The "attribute" form is
# accepted on load but is not emitted by Serializer in this module —
# NOTE(review): presumably kept for compatibility; confirm.)
our_ids = re.compile(
    r"(mapperprop|mapper|mapper_selectable|table|column|"
    r"session|attribute|engine):(.*)"
)
|
||||
|
||||
|
||||
def Deserializer(file, metadata=None, scoped_session=None, engine=None):
    """Return an Unpickler that resolves tokens written by :func:`Serializer`.

    :param file: file-like object containing the pickled stream.
    :param metadata: optional ``MetaData`` used to resolve table and
        column tokens.
    :param scoped_session: optional scoped session used to resolve
        session tokens (and, as a fallback, the engine).
    :param engine: optional engine used to resolve engine tokens.
    """
    unpickler = pickle.Unpickler(file)

    def get_engine():
        # Resolution priority: explicit engine, then the scoped
        # session's bind, then the metadata's bind, else None.
        if engine:
            return engine
        elif scoped_session and scoped_session().bind:
            return scoped_session().bind
        elif metadata and metadata.bind:
            return metadata.bind
        else:
            return None

    def persistent_load(id_):
        # Dispatch on the token type encoded in the persistent id;
        # non-matching ids yield None (plain pickled data).
        m = our_ids.match(text_type(id_))
        if not m:
            return None
        else:
            type_, args = m.group(1, 2)
            if type_ == "attribute":
                # NOTE(review): accepted here but not produced by
                # Serializer in this module.
                key, clsarg = args.split(":")
                cls = pickle.loads(b64decode(clsarg))
                return getattr(cls, key)
            elif type_ == "mapper":
                cls = pickle.loads(b64decode(args))
                return class_mapper(cls)
            elif type_ == "mapper_selectable":
                cls = pickle.loads(b64decode(args))
                return class_mapper(cls).__clause_element__()
            elif type_ == "mapperprop":
                mapper, keyname = args.split(":")
                cls = pickle.loads(b64decode(mapper))
                return class_mapper(cls).attrs[keyname]
            elif type_ == "table":
                # Requires ``metadata``; raises AttributeError/KeyError
                # if absent or unknown.
                return metadata.tables[args]
            elif type_ == "column":
                table, colname = args.split(":")
                return metadata.tables[table].c[colname]
            elif type_ == "session":
                return scoped_session()
            elif type_ == "engine":
                return get_engine()
            else:
                raise Exception("Unknown token: %s" % type_)

    unpickler.persistent_load = persistent_load
    return unpickler
|
||||
|
||||
|
||||
def dumps(obj, protocol=pickle.HIGHEST_PROTOCOL):
    """Serialize ``obj`` with the contextual Serializer; return the bytes."""
    buf = byte_buffer()
    Serializer(buf, protocol).dump(obj)
    return buf.getvalue()
|
||||
|
||||
|
||||
def loads(data, metadata=None, scoped_session=None, engine=None):
    """Deserialize ``data``, re-associating it with the given context."""
    buf = byte_buffer(data)
    return Deserializer(buf, metadata, scoped_session, engine).load()
|
||||
Reference in New Issue
Block a user