From 2b463a5844b4b9b57b304dbad3c908fe75906762 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Thu, 29 Feb 2024 01:37:15 -0500 Subject: [PATCH] UnresolvedLookup WIP. --- django/db/models/lookups.py | 96 +++++++++++++++++++++++++- django/db/models/sql/query.py | 125 +++++++++++----------------------- 2 files changed, 134 insertions(+), 87 deletions(-) diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index 72b087ed6d..bd3f4615a9 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -1,10 +1,21 @@ import itertools import math import warnings +from collections.abc import Iterator -from django.core.exceptions import EmptyResultSet, FullResultSet +from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet from django.db.backends.base.operations import BaseDatabaseOperations -from django.db.models.expressions import Case, Expression, Func, Value, When +from django.db.models.constants import LOOKUP_SEP +from django.db.models.expressions import ( + BaseExpression, + Case, + ColPairs, + Expression, + F, + Func, + Value, + When, +) from django.db.models.fields import ( BooleanField, CharField, @@ -20,6 +31,87 @@ from django.utils.functional import cached_property from django.utils.hashable import make_hashable +class UnresolvedLookup(BaseExpression): + output_field = BooleanField() + + @classmethod + def from_lookup(cls, lookup, value): + name, *parts = lookup.split(LOOKUP_SEP) + return cls(F(name), parts, value) + + def __init__(self, origin, parts, value): + self.origin = origin + self.parts = parts + self.value = value + super().__init__() + + def get_source_expressions(self): + return [self.origin, self.value] + + def set_source_expressions(self, exprs): + self.origin, self.value = exprs + + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): + lookup = LOOKUP_SEP.join([self.origin.name, *self.parts]) + if not lookup: + raise FieldError("Cannot parse keyword query %r" % self.field.name) + + lookups, parts, reffed_expression = query.solve_lookup_type(lookup, summarize) + if not allow_joins and len(parts) > 1: + raise FieldError("Joined field references are not permitted in this query") + + pre_joins = query.alias_refcount.copy() + value = query.resolve_lookup_value(self.value, reuse, allow_joins, summarize) + used_joins = { + k for k, v in query.alias_refcount.items() if v > pre_joins.get(k, 0) + } + + if reffed_expression: + return query.build_lookup(lookups, reffed_expression, value) + + opts = query.get_meta() + alias = query.get_initial_alias() + join_info = query.setup_joins( + parts, + opts, + alias, + can_reuse=reuse, + # XXX: allow_many is problematic and should probably be passed to resolve_expression. + allow_many=True, + ) + + # Update used_joins before trimming since they are reused to determine + # which joins could be later promoted to INNER. + used_joins.update(join_info.joins) + targets, alias, join_list = query.trim_joins( + join_info.targets, join_info.joins, join_info.path + ) + if reuse is not None: + reuse.update(join_list) + + if join_info.final_field.is_relation: + if len(targets) == 1: + col = query._get_col(targets[0], join_info.final_field, alias) + else: + col = ColPairs( + alias, targets, join_info.targets, join_info.final_field + ) + else: + col = query._get_col(targets[0], join_info.final_field, alias) + + # Prevent iterator from being consumed by check_related_objects() + if isinstance(value, Iterator): + value = list(value) + query.check_related_objects(join_info.final_field, value, join_info.opts) + + lookup = query.build_lookup(lookups, col, value) + lookup._used_joins = used_joins + lookup._lookup_joins = join_info.joins + return lookup + + class Lookup(Expression): lookup_name = None prepare_rhs = True diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 0bb3414b72..e3ac33a996 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -33,7 +33,7 @@ from django.db.models.expressions import ( Value, ) from django.db.models.fields import Field -from django.db.models.lookups import Lookup +from django.db.models.lookups import Lookup, UnresolvedLookup from django.db.models.query_utils import ( Q, check_rel_lookup_compatibility, @@ -1498,95 +1498,50 @@ class Query(BaseExpression): summarize=summarize, update_join_types=update_join_types, ) - if hasattr(filter_expr, "resolve_expression"): - if not getattr(filter_expr, "conditional", False): - raise TypeError("Cannot filter against a non-conditional expression.") - pre_joins = self.alias_refcount.copy() - condition = filter_expr.resolve_expression( + if isinstance(filter_expr, tuple): + return self.build_filter( + UnresolvedLookup.from_lookup(*filter_expr), + branch_negated=branch_negated, + current_negated=current_negated, + can_reuse=can_reuse, + allow_joins=allow_joins, + split_subq=split_subq, + check_filterable=check_filterable, + summarize=summarize, + update_join_types=update_join_types, + ) + if ( + resolve_expression := getattr(filter_expr, "resolve_expression", None) + ) is None: + raise TypeError(f"Cannot filter against {filter_expr!r}") + if not getattr(filter_expr, "conditional", False): + raise TypeError("Cannot filter against a non-conditional expression.") + + pre_joins = self.alias_refcount.copy() + try: + resolved = resolve_expression( self, allow_joins=allow_joins, reuse=can_reuse, summarize=summarize ) + # XXX: Hack to avoid modifying the resolve_expression return signature. + # See UnresolvedLookup. + used_joins = getattr(resolved, "_used_joins", None) + lookup_joins = getattr(resolved, "_lookup_joins", set()) + # split_exclude() needs to know which joins were generated for the + # lookup parts + self._lookup_joins = lookup_joins + except MultiJoin as e: + return self.split_exclude(filter_expr, can_reuse, e.names_with_path) + if check_filterable: + self.check_filterable(resolved) + if used_joins is None: used_joins = { k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0) } - if not isinstance(condition, Lookup): - condition = self.build_lookup(["exact"], condition, True) - clause, require_outer = self._build_lookup_clause( - condition, current_negated, {} - ) - return clause, used_joins if not require_outer else () - arg, value = filter_expr - if not arg: - raise FieldError("Cannot parse keyword query %r" % arg) - lookups, parts, reffed_expression = self.solve_lookup_type(arg, summarize) - - if check_filterable: - self.check_filterable(reffed_expression) - - if not allow_joins and len(parts) > 1: - raise FieldError("Joined field references are not permitted in this query") - - pre_joins = self.alias_refcount.copy() - value = self.resolve_lookup_value(value, can_reuse, allow_joins, summarize) - used_joins = { - k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0) - } - - if check_filterable: - self.check_filterable(value) - - if reffed_expression: - condition = self.build_lookup(lookups, reffed_expression, value) - return WhereNode([condition], connector=AND), [] - - opts = self.get_meta() - alias = self.get_initial_alias() - allow_many = not branch_negated or not split_subq - - try: - join_info = self.setup_joins( - parts, - opts, - alias, - can_reuse=can_reuse, - allow_many=allow_many, - ) - - # Prevent iterator from being consumed by check_related_objects() - if isinstance(value, Iterator): - value = list(value) - self.check_related_objects(join_info.final_field, value, join_info.opts) - - # split_exclude() needs to know which joins were generated for the - # lookup parts - self._lookup_joins = join_info.joins - except MultiJoin as e: - return self.split_exclude(filter_expr, can_reuse, e.names_with_path) - - # Update used_joins before trimming since they are reused to determine - # which joins could be later promoted to INNER. - used_joins.update(join_info.joins) - targets, alias, join_list = self.trim_joins( - join_info.targets, join_info.joins, join_info.path - ) - if can_reuse is not None: - can_reuse.update(join_list) - - if join_info.final_field.is_relation: - if len(targets) == 1: - col = self._get_col(targets[0], join_info.final_field, alias) - else: - col = ColPairs(alias, targets, join_info.targets, join_info.final_field) - else: - col = self._get_col(targets[0], join_info.final_field, alias) - - lookup = self.build_lookup(lookups, col, value) - if self.alias_map[join_list[-1]].join_type == LOUTER: - nullable_aliases = {alias} - else: - nullable_aliases = {} - clause, require_outer = self._build_lookup_clause( - lookup, current_negated, nullable_aliases - ) + # XXX: Not clear if this is needed given this is really just a fallback. + used_joins.update(lookup_joins) + if not isinstance(resolved, Lookup): + resolved = self.build_lookup(["exact"], resolved, True) + clause, require_outer = self._build_lookup_clause(resolved, current_negated, {}) return clause, used_joins if not require_outer else () def add_filter(self, filter_lhs, filter_rhs):