diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 676625df6f..7541817ba0 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1253,21 +1253,21 @@ class SQLCompiler: if restricted: related_fields = [ - (o.field, o.related_model) + (o, o.field, o.related_model) for o in opts.related_objects if o.field.unique and not o.many_to_many ] - for related_field, model in related_fields: - related_select_mask = select_mask.get(related_field) or {} + for related_object, related_field, model in related_fields: if not select_related_descend( related_field, restricted, requested, - related_select_mask, + select_mask, reverse=True, ): continue + related_select_mask = select_mask.get(related_object) or {} related_field_name = related_field.related_query_name() fields_found.add(related_field_name) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index b3f130c0b4..c5c58b1788 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -815,19 +815,17 @@ class Query(BaseExpression): if filtered_relation := self._filtered_relations.get(field_name): relation = opts.get_field(filtered_relation.relation_name) field_select_mask = select_mask.setdefault((field_name, relation), {}) - field = relation.field else: - reverse_rel = opts.get_field(field_name) + relation = opts.get_field(field_name) # While virtual fields such as many-to-many and generic foreign # keys cannot be effectively deferred we've historically # allowed them to be passed to QuerySet.defer(). Ignore such # field references until a layer of validation at mask # alteration time will be implemented eventually. - if not hasattr(reverse_rel, "field"): + if not hasattr(relation, "field"): continue - field = reverse_rel.field - field_select_mask = select_mask.setdefault(field, {}) - related_model = field.model._meta.concrete_model + field_select_mask = select_mask.setdefault(relation, {}) + related_model = relation.related_model._meta.concrete_model self._get_defer_select_mask( related_model._meta, field_mask, field_select_mask ) @@ -840,13 +838,7 @@ class Query(BaseExpression): # Only include fields mentioned in the mask. for field_name, field_mask in mask.items(): field = opts.get_field(field_name) - # Retrieve the actual field associated with reverse relationships - # as that's what is expected in the select mask. - if field in opts.related_objects: - field_key = field.field - else: - field_key = field - field_select_mask = select_mask.setdefault(field_key, {}) + field_select_mask = select_mask.setdefault(field, {}) if field_mask: if not field.is_relation: raise FieldError(next(iter(field_mask))) diff --git a/tests/defer_regress/models.py b/tests/defer_regress/models.py index dd492993b7..38ba4a622f 100644 --- a/tests/defer_regress/models.py +++ b/tests/defer_regress/models.py @@ -10,6 +10,12 @@ class Item(models.Model): text = models.TextField(default="xyzzy") value = models.IntegerField() other_value = models.IntegerField(default=0) + source = models.OneToOneField( + "self", + related_name="destination", + on_delete=models.CASCADE, + null=True, + ) class RelatedItem(models.Model): diff --git a/tests/defer_regress/tests.py b/tests/defer_regress/tests.py index 10100e348d..1209325f21 100644 --- a/tests/defer_regress/tests.py +++ b/tests/defer_regress/tests.py @@ -309,6 +309,27 @@ class DeferRegressionTest(TestCase): with self.assertNumQueries(1): self.assertEqual(Item.objects.only("request").get(), item) + def test_self_referential_one_to_one(self): + first = Item.objects.create(name="first", value=1) + second = Item.objects.create(name="second", value=2, source=first) + with self.assertNumQueries(1): + deferred_first, deferred_second = ( + Item.objects.select_related("source", "destination") + .only("name", "source__name", "destination__value") + .order_by("pk") + ) + with self.assertNumQueries(0): + self.assertEqual(deferred_first.name, first.name) + self.assertEqual(deferred_second.name, second.name) + self.assertEqual(deferred_second.source.name, first.name) + self.assertEqual(deferred_first.destination.value, second.value) + with self.assertNumQueries(1): + self.assertEqual(deferred_first.value, first.value) + with self.assertNumQueries(1): + self.assertEqual(deferred_second.source.value, first.value) + with self.assertNumQueries(1): + self.assertEqual(deferred_first.destination.name, second.name) + class DeferDeletionSignalsTests(TestCase): senders = [Item, Proxy]