diff options
author | xiubuzhe <xiubuzhe@sina.com> | 2023-10-08 20:59:00 +0800 |
---|---|---|
committer | xiubuzhe <xiubuzhe@sina.com> | 2023-10-08 20:59:00 +0800 |
commit | 1dac2263372df2b85db5d029a45721fa158a5c9d (patch) | |
tree | 0365f9c57df04178a726d7584ca6a6b955a7ce6a /lib/sqlalchemy/sql/util.py | |
parent | b494be364bb39e1de128ada7dc576a729d99907e (diff) | |
download | sunhpc-1dac2263372df2b85db5d029a45721fa158a5c9d.tar.gz sunhpc-1dac2263372df2b85db5d029a45721fa158a5c9d.tar.bz2 sunhpc-1dac2263372df2b85db5d029a45721fa158a5c9d.zip |
first add files
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 1120 |
1 files changed, 1120 insertions, 0 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py new file mode 100644 index 0000000..019b29e --- /dev/null +++ b/lib/sqlalchemy/sql/util.py @@ -0,0 +1,1120 @@ +# sql/util.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""High level utilities which build upon other modules here. + +""" + +from collections import deque +from itertools import chain + +from . import coercions +from . import operators +from . import roles +from . import visitors +from .annotation import _deep_annotate # noqa +from .annotation import _deep_deannotate # noqa +from .annotation import _shallow_annotate # noqa +from .base import _expand_cloned +from .base import _from_objects +from .base import ColumnSet +from .ddl import sort_tables # noqa +from .elements import _find_columns # noqa +from .elements import _label_reference +from .elements import _textual_label_reference +from .elements import BindParameter +from .elements import ColumnClause +from .elements import ColumnElement +from .elements import Grouping +from .elements import Label +from .elements import Null +from .elements import UnaryExpression +from .schema import Column +from .selectable import Alias +from .selectable import FromClause +from .selectable import FromGrouping +from .selectable import Join +from .selectable import ScalarSelect +from .selectable import SelectBase +from .selectable import TableClause +from .traversals import HasCacheKey # noqa +from .. import exc +from .. import util + + +join_condition = util.langhelpers.public_factory( + Join._join_condition, ".sql.util.join_condition" +) + + +def find_join_source(clauses, join_to): + """Given a list of FROM clauses and a selectable, + return the first index and element from the list of + clauses which can be joined against the selectable. returns + None, None if no match is found. + + e.g.:: + + clause1 = table1.join(table2) + clause2 = table4.join(table5) + + join_to = table2.join(table3) + + find_join_source([clause1, clause2], join_to) == clause1 + + """ + + selectables = list(_from_objects(join_to)) + idx = [] + for i, f in enumerate(clauses): + for s in selectables: + if f.is_derived_from(s): + idx.append(i) + return idx + + +def find_left_clause_that_matches_given(clauses, join_from): + """Given a list of FROM clauses and a selectable, + return the indexes from the list of + clauses which is derived from the selectable. + + """ + + selectables = list(_from_objects(join_from)) + liberal_idx = [] + for i, f in enumerate(clauses): + for s in selectables: + # basic check, if f is derived from s. + # this can be joins containing a table, or an aliased table + # or select statement matching to a table. This check + # will match a table to a selectable that is adapted from + # that table. With Query, this suits the case where a join + # is being made to an adapted entity + if f.is_derived_from(s): + liberal_idx.append(i) + break + + # in an extremely small set of use cases, a join is being made where + # there are multiple FROM clauses where our target table is represented + # in more than one, such as embedded or similar. in this case, do + # another pass where we try to get a more exact match where we aren't + # looking at adaption relationships. + if len(liberal_idx) > 1: + conservative_idx = [] + for idx in liberal_idx: + f = clauses[idx] + for s in selectables: + if set(surface_selectables(f)).intersection( + surface_selectables(s) + ): + conservative_idx.append(idx) + break + if conservative_idx: + return conservative_idx + + return liberal_idx + + +def find_left_clause_to_join_from(clauses, join_to, onclause): + """Given a list of FROM clauses, a selectable, + and optional ON clause, return a list of integer indexes from the + clauses list indicating the clauses that can be joined from. + + The presence of an "onclause" indicates that at least one clause can + definitely be joined from; if the list of clauses is of length one + and the onclause is given, returns that index. If the list of clauses + is more than length one, and the onclause is given, attempts to locate + which clauses contain the same columns. + + """ + idx = [] + selectables = set(_from_objects(join_to)) + + # if we are given more than one target clause to join + # from, use the onclause to provide a more specific answer. + # otherwise, don't try to limit, after all, "ON TRUE" is a valid + # on clause + if len(clauses) > 1 and onclause is not None: + resolve_ambiguity = True + cols_in_onclause = _find_columns(onclause) + else: + resolve_ambiguity = False + cols_in_onclause = None + + for i, f in enumerate(clauses): + for s in selectables.difference([f]): + if resolve_ambiguity: + if set(f.c).union(s.c).issuperset(cols_in_onclause): + idx.append(i) + break + elif onclause is not None or Join._can_join(f, s): + idx.append(i) + break + + if len(idx) > 1: + # this is the same "hide froms" logic from + # Selectable._get_display_froms + toremove = set( + chain(*[_expand_cloned(f._hide_froms) for f in clauses]) + ) + idx = [i for i in idx if clauses[i] not in toremove] + + # onclause was given and none of them resolved, so assume + # all indexes can match + if not idx and onclause is not None: + return range(len(clauses)) + else: + return idx + + +def visit_binary_product(fn, expr): + """Produce a traversal of the given expression, delivering + column comparisons to the given function. + + The function is of the form:: + + def my_fn(binary, left, right) + + For each binary expression located which has a + comparison operator, the product of "left" and + "right" will be delivered to that function, + in terms of that binary. + + Hence an expression like:: + + and_( + (a + b) == q + func.sum(e + f), + j == r + ) + + would have the traversal:: + + a <eq> q + a <eq> e + a <eq> f + b <eq> q + b <eq> e + b <eq> f + j <eq> r + + That is, every combination of "left" and + "right" that doesn't further contain + a binary comparison is passed as pairs. + + """ + stack = [] + + def visit(element): + if isinstance(element, ScalarSelect): + # we don't want to dig into correlated subqueries, + # those are just column elements by themselves + yield element + elif element.__visit_name__ == "binary" and operators.is_comparison( + element.operator + ): + stack.insert(0, element) + for l in visit(element.left): + for r in visit(element.right): + fn(stack[0], l, r) + stack.pop(0) + for elem in element.get_children(): + visit(elem) + else: + if isinstance(element, ColumnClause): + yield element + for elem in element.get_children(): + for e in visit(elem): + yield e + + list(visit(expr)) + visit = None # remove gc cycles + + +def find_tables( + clause, + check_columns=False, + include_aliases=False, + include_joins=False, + include_selects=False, + include_crud=False, +): + """locate Table objects within the given expression.""" + + tables = [] + _visitors = {} + + if include_selects: + _visitors["select"] = _visitors["compound_select"] = tables.append + + if include_joins: + _visitors["join"] = tables.append + + if include_aliases: + _visitors["alias"] = _visitors["subquery"] = _visitors[ + "tablesample" + ] = _visitors["lateral"] = tables.append + + if include_crud: + _visitors["insert"] = _visitors["update"] = _visitors[ + "delete" + ] = lambda ent: tables.append(ent.table) + + if check_columns: + + def visit_column(column): + tables.append(column.table) + + _visitors["column"] = visit_column + + _visitors["table"] = tables.append + + visitors.traverse(clause, {}, _visitors) + return tables + + +def unwrap_order_by(clause): + """Break up an 'order by' expression into individual column-expressions, + without DESC/ASC/NULLS FIRST/NULLS LAST""" + + cols = util.column_set() + result = [] + stack = deque([clause]) + + # examples + # column -> ASC/DESC == column + # column -> ASC/DESC -> label == column + # column -> label -> ASC/DESC -> label == column + # scalar_select -> label -> ASC/DESC == scalar_select -> label + + while stack: + t = stack.popleft() + if isinstance(t, ColumnElement) and ( + not isinstance(t, UnaryExpression) + or not operators.is_ordering_modifier(t.modifier) + ): + if isinstance(t, Label) and not isinstance( + t.element, ScalarSelect + ): + t = t.element + + if isinstance(t, Grouping): + t = t.element + + stack.append(t) + continue + elif isinstance(t, _label_reference): + t = t.element + + stack.append(t) + continue + if isinstance(t, (_textual_label_reference)): + continue + if t not in cols: + cols.add(t) + result.append(t) + + else: + for c in t.get_children(): + stack.append(c) + return result + + +def unwrap_label_reference(element): + def replace(elem): + if isinstance(elem, (_label_reference, _textual_label_reference)): + return elem.element + + return visitors.replacement_traverse(element, {}, replace) + + +def expand_column_list_from_order_by(collist, order_by): + """Given the columns clause and ORDER BY of a selectable, + return a list of column expressions that can be added to the collist + corresponding to the ORDER BY, without repeating those already + in the collist. + + """ + cols_already_present = set( + [ + col.element if col._order_by_label_element is not None else col + for col in collist + ] + ) + + to_look_for = list(chain(*[unwrap_order_by(o) for o in order_by])) + + return [col for col in to_look_for if col not in cols_already_present] + + +def clause_is_present(clause, search): + """Given a target clause and a second to search within, return True + if the target is plainly present in the search without any + subqueries or aliases involved. + + Basically descends through Joins. + + """ + + for elem in surface_selectables(search): + if clause == elem: # use == here so that Annotated's compare + return True + else: + return False + + +def tables_from_leftmost(clause): + if isinstance(clause, Join): + for t in tables_from_leftmost(clause.left): + yield t + for t in tables_from_leftmost(clause.right): + yield t + elif isinstance(clause, FromGrouping): + for t in tables_from_leftmost(clause.element): + yield t + else: + yield clause + + +def surface_selectables(clause): + stack = [clause] + while stack: + elem = stack.pop() + yield elem + if isinstance(elem, Join): + stack.extend((elem.left, elem.right)) + elif isinstance(elem, FromGrouping): + stack.append(elem.element) + + +def surface_selectables_only(clause): + stack = [clause] + while stack: + elem = stack.pop() + if isinstance(elem, (TableClause, Alias)): + yield elem + if isinstance(elem, Join): + stack.extend((elem.left, elem.right)) + elif isinstance(elem, FromGrouping): + stack.append(elem.element) + elif isinstance(elem, ColumnClause): + if elem.table is not None: + stack.append(elem.table) + else: + yield elem + elif elem is not None: + yield elem + + +def extract_first_column_annotation(column, annotation_name): + filter_ = (FromGrouping, SelectBase) + + stack = deque([column]) + while stack: + elem = stack.popleft() + if annotation_name in elem._annotations: + return elem._annotations[annotation_name] + for sub in elem.get_children(): + if isinstance(sub, filter_): + continue + stack.append(sub) + return None + + +def selectables_overlap(left, right): + """Return True if left/right have some overlapping selectable""" + + return bool( + set(surface_selectables(left)).intersection(surface_selectables(right)) + ) + + +def bind_values(clause): + """Return an ordered list of "bound" values in the given clause. + + E.g.:: + + >>> expr = and_( + ... table.c.foo==5, table.c.foo==7 + ... ) + >>> bind_values(expr) + [5, 7] + """ + + v = [] + + def visit_bindparam(bind): + v.append(bind.effective_value) + + visitors.traverse(clause, {}, {"bindparam": visit_bindparam}) + return v + + +def _quote_ddl_expr(element): + if isinstance(element, util.string_types): + element = element.replace("'", "''") + return "'%s'" % element + else: + return repr(element) + + +class _repr_base(object): + _LIST = 0 + _TUPLE = 1 + _DICT = 2 + + __slots__ = ("max_chars",) + + def trunc(self, value): + rep = repr(value) + lenrep = len(rep) + if lenrep > self.max_chars: + segment_length = self.max_chars // 2 + rep = ( + rep[0:segment_length] + + ( + " ... (%d characters truncated) ... " + % (lenrep - self.max_chars) + ) + + rep[-segment_length:] + ) + return rep + + +class _repr_row(_repr_base): + """Provide a string view of a row.""" + + __slots__ = ("row",) + + def __init__(self, row, max_chars=300): + self.row = row + self.max_chars = max_chars + + def __repr__(self): + trunc = self.trunc + return "(%s%s)" % ( + ", ".join(trunc(value) for value in self.row), + "," if len(self.row) == 1 else "", + ) + + +class _repr_params(_repr_base): + """Provide a string view of bound parameters. + + Truncates display to a given number of 'multi' parameter sets, + as well as long values to a given number of characters. + + """ + + __slots__ = "params", "batches", "ismulti" + + def __init__(self, params, batches, max_chars=300, ismulti=None): + self.params = params + self.ismulti = ismulti + self.batches = batches + self.max_chars = max_chars + + def __repr__(self): + if self.ismulti is None: + return self.trunc(self.params) + + if isinstance(self.params, list): + typ = self._LIST + + elif isinstance(self.params, tuple): + typ = self._TUPLE + elif isinstance(self.params, dict): + typ = self._DICT + else: + return self.trunc(self.params) + + if self.ismulti and len(self.params) > self.batches: + msg = " ... displaying %i of %i total bound parameter sets ... " + return " ".join( + ( + self._repr_multi(self.params[: self.batches - 2], typ)[ + 0:-1 + ], + msg % (self.batches, len(self.params)), + self._repr_multi(self.params[-2:], typ)[1:], + ) + ) + elif self.ismulti: + return self._repr_multi(self.params, typ) + else: + return self._repr_params(self.params, typ) + + def _repr_multi(self, multi_params, typ): + if multi_params: + if isinstance(multi_params[0], list): + elem_type = self._LIST + elif isinstance(multi_params[0], tuple): + elem_type = self._TUPLE + elif isinstance(multi_params[0], dict): + elem_type = self._DICT + else: + assert False, "Unknown parameter type %s" % ( + type(multi_params[0]) + ) + + elements = ", ".join( + self._repr_params(params, elem_type) for params in multi_params + ) + else: + elements = "" + + if typ == self._LIST: + return "[%s]" % elements + else: + return "(%s)" % elements + + def _repr_params(self, params, typ): + trunc = self.trunc + if typ is self._DICT: + return "{%s}" % ( + ", ".join( + "%r: %s" % (key, trunc(value)) + for key, value in params.items() + ) + ) + elif typ is self._TUPLE: + return "(%s%s)" % ( + ", ".join(trunc(value) for value in params), + "," if len(params) == 1 else "", + ) + else: + return "[%s]" % (", ".join(trunc(value) for value in params)) + + +def adapt_criterion_to_null(crit, nulls): + """given criterion containing bind params, convert selected elements + to IS NULL. + + """ + + def visit_binary(binary): + if ( + isinstance(binary.left, BindParameter) + and binary.left._identifying_key in nulls + ): + # reverse order if the NULL is on the left side + binary.left = binary.right + binary.right = Null() + binary.operator = operators.is_ + binary.negate = operators.is_not + elif ( + isinstance(binary.right, BindParameter) + and binary.right._identifying_key in nulls + ): + binary.right = Null() + binary.operator = operators.is_ + binary.negate = operators.is_not + + return visitors.cloned_traverse(crit, {}, {"binary": visit_binary}) + + +def splice_joins(left, right, stop_on=None): + if left is None: + return right + + stack = [(right, None)] + + adapter = ClauseAdapter(left) + ret = None + while stack: + (right, prevright) = stack.pop() + if isinstance(right, Join) and right is not stop_on: + right = right._clone() + right.onclause = adapter.traverse(right.onclause) + stack.append((right.left, right)) + else: + right = adapter.traverse(right) + if prevright is not None: + prevright.left = right + if ret is None: + ret = right + + return ret + + +def reduce_columns(columns, *clauses, **kw): + r"""given a list of columns, return a 'reduced' set based on natural + equivalents. + + the set is reduced to the smallest list of columns which have no natural + equivalent present in the list. A "natural equivalent" means that two + columns will ultimately represent the same value because they are related + by a foreign key. + + \*clauses is an optional list of join clauses which will be traversed + to further identify columns that are "equivalent". + + \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys + whose tables are not yet configured, or columns that aren't yet present. + + This function is primarily used to determine the most minimal "primary + key" from a selectable, by reducing the set of primary key columns present + in the selectable to just those that are not repeated. + + """ + ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False) + only_synonyms = kw.pop("only_synonyms", False) + + columns = util.ordered_column_set(columns) + + omit = util.column_set() + for col in columns: + for fk in chain(*[c.foreign_keys for c in col.proxy_set]): + for c in columns: + if c is col: + continue + try: + fk_col = fk.column + except exc.NoReferencedColumnError: + # TODO: add specific coverage here + # to test/sql/test_selectable ReduceTest + if ignore_nonexistent_tables: + continue + else: + raise + except exc.NoReferencedTableError: + # TODO: add specific coverage here + # to test/sql/test_selectable ReduceTest + if ignore_nonexistent_tables: + continue + else: + raise + if fk_col.shares_lineage(c) and ( + not only_synonyms or c.name == col.name + ): + omit.add(col) + break + + if clauses: + + def visit_binary(binary): + if binary.operator == operators.eq: + cols = util.column_set( + chain(*[c.proxy_set for c in columns.difference(omit)]) + ) + if binary.left in cols and binary.right in cols: + for c in reversed(columns): + if c.shares_lineage(binary.right) and ( + not only_synonyms or c.name == binary.left.name + ): + omit.add(c) + break + + for clause in clauses: + if clause is not None: + visitors.traverse(clause, {}, {"binary": visit_binary}) + + return ColumnSet(columns.difference(omit)) + + +def criterion_as_pairs( + expression, + consider_as_foreign_keys=None, + consider_as_referenced_keys=None, + any_operator=False, +): + """traverse an expression and locate binary criterion pairs.""" + + if consider_as_foreign_keys and consider_as_referenced_keys: + raise exc.ArgumentError( + "Can only specify one of " + "'consider_as_foreign_keys' or " + "'consider_as_referenced_keys'" + ) + + def col_is(a, b): + # return a is b + return a.compare(b) + + def visit_binary(binary): + if not any_operator and binary.operator is not operators.eq: + return + if not isinstance(binary.left, ColumnElement) or not isinstance( + binary.right, ColumnElement + ): + return + + if consider_as_foreign_keys: + if binary.left in consider_as_foreign_keys and ( + col_is(binary.right, binary.left) + or binary.right not in consider_as_foreign_keys + ): + pairs.append((binary.right, binary.left)) + elif binary.right in consider_as_foreign_keys and ( + col_is(binary.left, binary.right) + or binary.left not in consider_as_foreign_keys + ): + pairs.append((binary.left, binary.right)) + elif consider_as_referenced_keys: + if binary.left in consider_as_referenced_keys and ( + col_is(binary.right, binary.left) + or binary.right not in consider_as_referenced_keys + ): + pairs.append((binary.left, binary.right)) + elif binary.right in consider_as_referenced_keys and ( + col_is(binary.left, binary.right) + or binary.left not in consider_as_referenced_keys + ): + pairs.append((binary.right, binary.left)) + else: + if isinstance(binary.left, Column) and isinstance( + binary.right, Column + ): + if binary.left.references(binary.right): + pairs.append((binary.right, binary.left)) + elif binary.right.references(binary.left): + pairs.append((binary.left, binary.right)) + + pairs = [] + visitors.traverse(expression, {}, {"binary": visit_binary}) + return pairs + + +class ClauseAdapter(visitors.ReplacingExternalTraversal): + """Clones and modifies clauses based on column correspondence. + + E.g.:: + + table1 = Table('sometable', metadata, + Column('col1', Integer), + Column('col2', Integer) + ) + table2 = Table('someothertable', metadata, + Column('col1', Integer), + Column('col2', Integer) + ) + + condition = table1.c.col1 == table2.c.col1 + + make an alias of table1:: + + s = table1.alias('foo') + + calling ``ClauseAdapter(s).traverse(condition)`` converts + condition to read:: + + s.c.col1 == table2.c.col1 + + """ + + def __init__( + self, + selectable, + equivalents=None, + include_fn=None, + exclude_fn=None, + adapt_on_names=False, + anonymize_labels=False, + adapt_from_selectables=None, + ): + self.__traverse_options__ = { + "stop_on": [selectable], + "anonymize_labels": anonymize_labels, + } + self.selectable = selectable + self.include_fn = include_fn + self.exclude_fn = exclude_fn + self.equivalents = util.column_dict(equivalents or {}) + self.adapt_on_names = adapt_on_names + self.adapt_from_selectables = adapt_from_selectables + + def _corresponding_column( + self, col, require_embedded, _seen=util.EMPTY_SET + ): + + newcol = self.selectable.corresponding_column( + col, require_embedded=require_embedded + ) + if newcol is None and col in self.equivalents and col not in _seen: + for equiv in self.equivalents[col]: + newcol = self._corresponding_column( + equiv, + require_embedded=require_embedded, + _seen=_seen.union([col]), + ) + if newcol is not None: + return newcol + if self.adapt_on_names and newcol is None: + newcol = self.selectable.exported_columns.get(col.name) + return newcol + + @util.preload_module("sqlalchemy.sql.functions") + def replace(self, col, _include_singleton_constants=False): + functions = util.preloaded.sql_functions + + if isinstance(col, FromClause) and not isinstance( + col, functions.FunctionElement + ): + + if self.selectable.is_derived_from(col): + if self.adapt_from_selectables: + for adp in self.adapt_from_selectables: + if adp.is_derived_from(col): + break + else: + return None + return self.selectable + elif isinstance(col, Alias) and isinstance( + col.element, TableClause + ): + # we are a SELECT statement and not derived from an alias of a + # table (which nonetheless may be a table our SELECT derives + # from), so return the alias to prevent further traversal + # or + # we are an alias of a table and we are not derived from an + # alias of a table (which nonetheless may be the same table + # as ours) so, same thing + return col + else: + # other cases where we are a selectable and the element + # is another join or selectable that contains a table which our + # selectable derives from, that we want to process + return None + + elif not isinstance(col, ColumnElement): + return None + elif not _include_singleton_constants and col._is_singleton_constant: + # dont swap out NULL, TRUE, FALSE for a label name + # in a SQL statement that's being rewritten, + # leave them as the constant. This is first noted in #6259, + # however the logic to check this moved here as of #7154 so that + # it is made specific to SQL rewriting and not all column + # correspondence + return None + + if "adapt_column" in col._annotations: + col = col._annotations["adapt_column"] + + if self.adapt_from_selectables and col not in self.equivalents: + for adp in self.adapt_from_selectables: + if adp.c.corresponding_column(col, False) is not None: + break + else: + return None + + if self.include_fn and not self.include_fn(col): + return None + elif self.exclude_fn and self.exclude_fn(col): + return None + else: + return self._corresponding_column(col, True) + + +class ColumnAdapter(ClauseAdapter): + """Extends ClauseAdapter with extra utility functions. + + Key aspects of ColumnAdapter include: + + * Expressions that are adapted are stored in a persistent + .columns collection; so that an expression E adapted into + an expression E1, will return the same object E1 when adapted + a second time. This is important in particular for things like + Label objects that are anonymized, so that the ColumnAdapter can + be used to present a consistent "adapted" view of things. + + * Exclusion of items from the persistent collection based on + include/exclude rules, but also independent of hash identity. + This because "annotated" items all have the same hash identity as their + parent. + + * "wrapping" capability is added, so that the replacement of an expression + E can proceed through a series of adapters. This differs from the + visitor's "chaining" feature in that the resulting object is passed + through all replacing functions unconditionally, rather than stopping + at the first one that returns non-None. + + * An adapt_required option, used by eager loading to indicate that + We don't trust a result row column that is not translated. + This is to prevent a column from being interpreted as that + of the child row in a self-referential scenario, see + inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency + + """ + + def __init__( + self, + selectable, + equivalents=None, + adapt_required=False, + include_fn=None, + exclude_fn=None, + adapt_on_names=False, + allow_label_resolve=True, + anonymize_labels=False, + adapt_from_selectables=None, + ): + ClauseAdapter.__init__( + self, + selectable, + equivalents, + include_fn=include_fn, + exclude_fn=exclude_fn, + adapt_on_names=adapt_on_names, + anonymize_labels=anonymize_labels, + adapt_from_selectables=adapt_from_selectables, + ) + + self.columns = util.WeakPopulateDict(self._locate_col) + if self.include_fn or self.exclude_fn: + self.columns = self._IncludeExcludeMapping(self, self.columns) + self.adapt_required = adapt_required + self.allow_label_resolve = allow_label_resolve + self._wrap = None + + class _IncludeExcludeMapping(object): + def __init__(self, parent, columns): + self.parent = parent + self.columns = columns + + def __getitem__(self, key): + if ( + self.parent.include_fn and not self.parent.include_fn(key) + ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)): + if self.parent._wrap: + return self.parent._wrap.columns[key] + else: + return key + return self.columns[key] + + def wrap(self, adapter): + ac = self.__class__.__new__(self.__class__) + ac.__dict__.update(self.__dict__) + ac._wrap = adapter + ac.columns = util.WeakPopulateDict(ac._locate_col) + if ac.include_fn or ac.exclude_fn: + ac.columns = self._IncludeExcludeMapping(ac, ac.columns) + + return ac + + def traverse(self, obj): + return self.columns[obj] + + adapt_clause = traverse + adapt_list = ClauseAdapter.copy_and_process + + def adapt_check_present(self, col): + newcol = self.columns[col] + + if newcol is col and self._corresponding_column(col, True) is None: + return None + + return newcol + + def _locate_col(self, col): + # both replace and traverse() are overly complicated for what + # we are doing here and we would do better to have an inlined + # version that doesn't build up as much overhead. the issue is that + # sometimes the lookup does in fact have to adapt the insides of + # say a labeled scalar subquery. However, if the object is an + # Immutable, i.e. Column objects, we can skip the "clone" / + # "copy internals" part since those will be no-ops in any case. + # additionally we want to catch singleton objects null/true/false + # and make sure they are adapted as well here. + + if col._is_immutable: + for vis in self.visitor_iterator: + c = vis.replace(col, _include_singleton_constants=True) + if c is not None: + break + else: + c = col + else: + c = ClauseAdapter.traverse(self, col) + + if self._wrap: + c2 = self._wrap._locate_col(c) + if c2 is not None: + c = c2 + + if self.adapt_required and c is col: + return None + + c._allow_label_resolve = self.allow_label_resolve + + return c + + def __getstate__(self): + d = self.__dict__.copy() + del d["columns"] + return d + + def __setstate__(self, state): + self.__dict__.update(state) + self.columns = util.WeakPopulateDict(self._locate_col) + + +def _offset_or_limit_clause(element, name=None, type_=None): + """Convert the given value to an "offset or limit" clause. + + This handles incoming integers and converts to an expression; if + an expression is already given, it is passed through. + + """ + return coercions.expect( + roles.LimitOffsetRole, element, name=name, type_=type_ + ) + + +def _offset_or_limit_clause_asint_if_possible(clause): + """Return the offset or limit clause as a simple integer if possible, + else return the clause. + + """ + if clause is None: + return None + if hasattr(clause, "_limit_offset_value"): + value = clause._limit_offset_value + return util.asint(value) + else: + return clause + + +def _make_slice(limit_clause, offset_clause, start, stop): + """Compute LIMIT/OFFSET in terms of slice start/end""" + + # for calculated limit/offset, try to do the addition of + # values to offset in Python, however if a SQL clause is present + # then the addition has to be on the SQL side. + if start is not None and stop is not None: + offset_clause = _offset_or_limit_clause_asint_if_possible( + offset_clause + ) + if offset_clause is None: + offset_clause = 0 + + if start != 0: + offset_clause = offset_clause + start + + if offset_clause == 0: + offset_clause = None + else: + offset_clause = _offset_or_limit_clause(offset_clause) + + limit_clause = _offset_or_limit_clause(stop - start) + + elif start is None and stop is not None: + limit_clause = _offset_or_limit_clause(stop) + elif start is not None and stop is None: + offset_clause = _offset_or_limit_clause_asint_if_possible( + offset_clause + ) + if offset_clause is None: + offset_clause = 0 + + if start != 0: + offset_clause = offset_clause + start + + if offset_clause == 0: + offset_clause = None + else: + offset_clause = _offset_or_limit_clause(offset_clause) + + return limit_clause, offset_clause |