0
0
mirror of https://github.com/django/django.git synced 2024-11-24 02:47:35 +01:00

UnresolvedLookup WIP.

This commit is contained in:
Simon Charette 2024-02-29 01:37:15 -05:00
parent 9afbb91c35
commit 2b463a5844
No known key found for this signature in database
2 changed files with 134 additions and 87 deletions

View File

@ -1,10 +1,21 @@
import itertools
import math
import warnings
from collections.abc import Iterator
from django.core.exceptions import EmptyResultSet, FullResultSet
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models.expressions import Case, Expression, Func, Value, When
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import (
BaseExpression,
Case,
ColPairs,
Expression,
F,
Func,
Value,
When,
)
from django.db.models.fields import (
BooleanField,
CharField,
@ -20,6 +31,87 @@ from django.utils.functional import cached_property
from django.utils.hashable import make_hashable
class UnresolvedLookup(BaseExpression):
output_field = BooleanField()
@classmethod
def from_lookup(cls, lookup, value):
name, *parts = lookup.split(LOOKUP_SEP)
return cls(F(name), parts, value)
def __init__(self, origin, parts, value):
self.origin = origin
self.parts = parts
self.value = value
super().__init__()
def get_source_expressions(self):
return [self.origin, self.value]
def set_source_expressions(self, exprs):
self.origin, self.value = exprs
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
lookup = LOOKUP_SEP.join([self.origin.name, *self.parts])
if not lookup:
raise FieldError("Cannot parse keyword query %r" % self.field.name)
lookups, parts, reffed_expression = query.solve_lookup_type(lookup, summarize)
if not allow_joins and len(parts) > 1:
raise FieldError("Joined field references are not permitted in this query")
pre_joins = query.alias_refcount.copy()
value = query.resolve_lookup_value(self.value, reuse, allow_joins, summarize)
used_joins = {
k for k, v in query.alias_refcount.items() if v > pre_joins.get(k, 0)
}
if reffed_expression:
return query.build_lookup(lookups, reffed_expression, value)
opts = query.get_meta()
alias = query.get_initial_alias()
join_info = query.setup_joins(
parts,
opts,
alias,
can_reuse=reuse,
# XXX: allow_many is problematic and should probably be passed to resolve_expression.
allow_many=True,
)
# Update used_joins before trimming since they are reused to determine
# which joins could be later promoted to INNER.
used_joins.update(join_info.joins)
targets, alias, join_list = query.trim_joins(
join_info.targets, join_info.joins, join_info.path
)
if reuse is not None:
reuse.update(join_list)
if join_info.final_field.is_relation:
if len(targets) == 1:
col = query._get_col(targets[0], join_info.final_field, alias)
else:
col = ColPairs(
alias, targets, join_info.targets, join_info.final_field
)
else:
col = query._get_col(targets[0], join_info.final_field, alias)
# Prevent iterator from being consumed by check_related_objects()
if isinstance(value, Iterator):
value = list(value)
query.check_related_objects(join_info.final_field, value, join_info.opts)
lookup = query.build_lookup(lookups, col, value)
lookup._used_joins = used_joins
lookup._lookup_joins = join_info.joins
return lookup
class Lookup(Expression):
lookup_name = None
prepare_rhs = True

View File

@ -33,7 +33,7 @@ from django.db.models.expressions import (
Value,
)
from django.db.models.fields import Field
from django.db.models.lookups import Lookup
from django.db.models.lookups import Lookup, UnresolvedLookup
from django.db.models.query_utils import (
Q,
check_rel_lookup_compatibility,
@ -1498,95 +1498,50 @@ class Query(BaseExpression):
summarize=summarize,
update_join_types=update_join_types,
)
if hasattr(filter_expr, "resolve_expression"):
if not getattr(filter_expr, "conditional", False):
raise TypeError("Cannot filter against a non-conditional expression.")
pre_joins = self.alias_refcount.copy()
condition = filter_expr.resolve_expression(
if isinstance(filter_expr, tuple):
return self.build_filter(
UnresolvedLookup.from_lookup(*filter_expr),
branch_negated=branch_negated,
current_negated=current_negated,
can_reuse=can_reuse,
allow_joins=allow_joins,
split_subq=split_subq,
check_filterable=check_filterable,
summarize=summarize,
update_join_types=update_join_types,
)
if (
resolve_expression := getattr(filter_expr, "resolve_expression", None)
) is None:
raise TypeError(f"Cannot filter against {filter_expr!r}")
if not getattr(filter_expr, "conditional", False):
raise TypeError("Cannot filter against a non-conditional expression.")
pre_joins = self.alias_refcount.copy()
try:
resolved = resolve_expression(
self, allow_joins=allow_joins, reuse=can_reuse, summarize=summarize
)
# XXX: Hack to avoid modifying the resolve_expression return signature.
# See UnresolvedLookup.
used_joins = getattr(resolved, "_used_joins", None)
lookup_joins = getattr(resolved, "_lookup_joins", set())
# split_exclude() needs to know which joins were generated for the
# lookup parts
self._lookup_joins = lookup_joins
except MultiJoin as e:
return self.split_exclude(filter_expr, can_reuse, e.names_with_path)
if check_filterable:
self.check_filterable(resolved)
if used_joins is None:
used_joins = {
k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)
}
if not isinstance(condition, Lookup):
condition = self.build_lookup(["exact"], condition, True)
clause, require_outer = self._build_lookup_clause(
condition, current_negated, {}
)
return clause, used_joins if not require_outer else ()
arg, value = filter_expr
if not arg:
raise FieldError("Cannot parse keyword query %r" % arg)
lookups, parts, reffed_expression = self.solve_lookup_type(arg, summarize)
if check_filterable:
self.check_filterable(reffed_expression)
if not allow_joins and len(parts) > 1:
raise FieldError("Joined field references are not permitted in this query")
pre_joins = self.alias_refcount.copy()
value = self.resolve_lookup_value(value, can_reuse, allow_joins, summarize)
used_joins = {
k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)
}
if check_filterable:
self.check_filterable(value)
if reffed_expression:
condition = self.build_lookup(lookups, reffed_expression, value)
return WhereNode([condition], connector=AND), []
opts = self.get_meta()
alias = self.get_initial_alias()
allow_many = not branch_negated or not split_subq
try:
join_info = self.setup_joins(
parts,
opts,
alias,
can_reuse=can_reuse,
allow_many=allow_many,
)
# Prevent iterator from being consumed by check_related_objects()
if isinstance(value, Iterator):
value = list(value)
self.check_related_objects(join_info.final_field, value, join_info.opts)
# split_exclude() needs to know which joins were generated for the
# lookup parts
self._lookup_joins = join_info.joins
except MultiJoin as e:
return self.split_exclude(filter_expr, can_reuse, e.names_with_path)
# Update used_joins before trimming since they are reused to determine
# which joins could be later promoted to INNER.
used_joins.update(join_info.joins)
targets, alias, join_list = self.trim_joins(
join_info.targets, join_info.joins, join_info.path
)
if can_reuse is not None:
can_reuse.update(join_list)
if join_info.final_field.is_relation:
if len(targets) == 1:
col = self._get_col(targets[0], join_info.final_field, alias)
else:
col = ColPairs(alias, targets, join_info.targets, join_info.final_field)
else:
col = self._get_col(targets[0], join_info.final_field, alias)
lookup = self.build_lookup(lookups, col, value)
if self.alias_map[join_list[-1]].join_type == LOUTER:
nullable_aliases = {alias}
else:
nullable_aliases = {}
clause, require_outer = self._build_lookup_clause(
lookup, current_negated, nullable_aliases
)
# XXX: Not clear if this is needed given this is really just a fallback.
used_joins.update(lookup_joins)
if not isinstance(resolved, Lookup):
resolved = self.build_lookup(["exact"], resolved, True)
clause, require_outer = self._build_lookup_clause(resolved, current_negated, {})
return clause, used_joins if not require_outer else ()
def add_filter(self, filter_lhs, filter_rhs):