diff --git a/django/db/models/fields/tuple_lookups.py b/django/db/models/fields/tuple_lookups.py index a94582db95..6342937cd6 100644 --- a/django/db/models/fields/tuple_lookups.py +++ b/django/db/models/fields/tuple_lookups.py @@ -12,6 +12,7 @@ from django.db.models.lookups import ( LessThan, LessThanOrEqual, ) +from django.db.models.sql import Query from django.db.models.sql.where import AND, OR, WhereNode @@ -211,9 +212,14 @@ class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual): class TupleIn(TupleLookupMixin, In): def get_prep_lookup(self): - self.check_rhs_is_tuple_or_list() - self.check_rhs_is_collection_of_tuples_or_lists() - self.check_rhs_elements_length_equals_lhs_length() + if self.rhs_is_direct_value(): + self.check_rhs_is_tuple_or_list() + self.check_rhs_is_collection_of_tuples_or_lists() + self.check_rhs_elements_length_equals_lhs_length() + else: + self.check_rhs_is_query() + self.check_rhs_select_length_equals_lhs_length() + return self.rhs # skip checks from mixin def check_rhs_is_collection_of_tuples_or_lists(self): @@ -233,6 +239,25 @@ class TupleIn(TupleLookupMixin, In): f"must have {len_lhs} elements each" ) + def check_rhs_is_query(self): + if not isinstance(self.rhs, Query): + lhs_str = self.get_lhs_str() + rhs_cls = self.rhs.__class__.__name__ + raise ValueError( + f"{self.lookup_name!r} subquery lookup of {lhs_str} " + f"must be a Query object (received {rhs_cls!r})" + ) + + def check_rhs_select_length_equals_lhs_length(self): + len_rhs = len(self.rhs.select) + len_lhs = len(self.lhs) + if len_rhs != len_lhs: + lhs_str = self.get_lhs_str() + raise ValueError( + f"{self.lookup_name!r} subquery lookup of {lhs_str} " + f"must have {len_lhs} fields (received {len_rhs})" + ) + def process_rhs(self, compiler, connection): rhs = self.rhs if not rhs: @@ -255,10 +280,17 @@ class TupleIn(TupleLookupMixin, In): return Tuple(*result).as_sql(compiler, connection) + def as_sql(self, compiler, connection): + if not self.rhs_is_direct_value(): + return self.as_subquery(compiler, connection) + return super().as_sql(compiler, connection) + def as_sqlite(self, compiler, connection): rhs = self.rhs if not rhs: raise EmptyResultSet + if not self.rhs_is_direct_value(): + return self.as_subquery(compiler, connection) # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL: # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2) @@ -271,6 +303,9 @@ class TupleIn(TupleLookupMixin, In): return root.as_sql(compiler, connection) + def as_subquery(self, compiler, connection): + return compiler.compile(In(self.lhs, self.rhs)) + tuple_lookups = { "exact": TupleExact, diff --git a/tests/foreign_object/test_tuple_lookups.py b/tests/foreign_object/test_tuple_lookups.py index 499329e7ca..797fea1c8a 100644 --- a/tests/foreign_object/test_tuple_lookups.py +++ b/tests/foreign_object/test_tuple_lookups.py @@ -11,6 +11,7 @@ from django.db.models.fields.tuple_lookups import ( TupleLessThan, TupleLessThanOrEqual, ) +from django.db.models.lookups import In from django.test import TestCase, skipUnlessDBFeature from .models import Contact, Customer @@ -126,6 +127,46 @@ class TupleLookupsTests(TestCase): (self.contact_1, self.contact_2, self.contact_5), ) + def test_tuple_in_subquery_must_be_query(self): + lhs = (F("customer_code"), F("company_code")) + # If rhs is any non-Query object with an as_sql() function. + rhs = In(F("customer_code"), [1, 2, 3]) + with self.assertRaisesMessage( + ValueError, + "'in' subquery lookup of ('customer_code', 'company_code') " + "must be a Query object (received 'In')", + ): + TupleIn(lhs, rhs) + + def test_tuple_in_subquery_must_have_2_fields(self): + lhs = (F("customer_code"), F("company_code")) + rhs = Customer.objects.values_list("customer_id").query + with self.assertRaisesMessage( + ValueError, + "'in' subquery lookup of ('customer_code', 'company_code') " + "must have 2 fields (received 1)", + ): + TupleIn(lhs, rhs) + + def test_tuple_in_subquery(self): + customers = Customer.objects.values_list("customer_id", "company") + test_cases = ( + (self.customer_1, (self.contact_1, self.contact_2, self.contact_5)), + (self.customer_2, (self.contact_3,)), + (self.customer_3, (self.contact_4,)), + (self.customer_4, ()), + (self.customer_5, (self.contact_6,)), + ) + + for customer, contacts in test_cases: + lhs = (F("customer_code"), F("company_code")) + rhs = customers.filter(id=customer.id).query + lookup = TupleIn(lhs, rhs) + qs = Contact.objects.filter(lookup).order_by("id") + + with self.subTest(customer=customer.id, query=str(qs.query)): + self.assertSequenceEqual(qs, contacts) + def test_tuple_in_rhs_must_be_collection_of_tuples_or_lists(self): test_cases = ( (1, 2, 3),