From 084b4f10c28cb5e1c57ce7ff9e5934154a611204 Mon Sep 17 00:00:00 2001 From: John Parton Date: Sun, 18 Aug 2024 23:12:14 -0500 Subject: [PATCH] Fixed #35690 -- Add support for calling QuerySet.in_bulk() after QuerySet.values() --- django/db/models/query.py | 25 ++++++++++++++++++++++++- tests/lookup/tests.py | 17 +++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index a4277d05fc..a0a0572728 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1119,6 +1119,29 @@ class QuerySet(AltersData): "in_bulk()'s field_name must be a unique field but %r isn't." % field_name ) + + if issubclass(self._iterable_class, ModelIterable): + + def _get_key(row): + # Will raise an unhelpful AttributeError if field_name is deferred + # We could catch this earlier and raise a TypeError + return getattr(row, field_name) + + elif issubclass(self._iterable_class, ValuesIterable): + # If field_name wasn't explicitly included, include it anyways + if field_name not in self.query.values_select: + return self.values(field_name, *self.query.values_select).in_bulk( + id_list=id_list, field_name=field_name + ) + + def _get_key(row): + return row[field_name] + + else: + # NamedValuesListIterable should work in theory, + # but in practice it's very clunky. + raise TypeError("in_bulk() cannot be used with %r." % self._iterable_class) + if id_list is not None: if not id_list: return {} @@ -1136,7 +1159,7 @@ class QuerySet(AltersData): qs = self.filter(**{filter_key: id_list}) else: qs = self._chain() - return {getattr(obj, field_name): obj for obj in qs} + return {_get_key(obj): obj for obj in qs} async def ain_bulk(self, id_list=None, *, field_name="pk"): return await sync_to_async(self.in_bulk)( diff --git a/tests/lookup/tests.py b/tests/lookup/tests.py index df96546d04..8f04537ec1 100644 --- a/tests/lookup/tests.py +++ b/tests/lookup/tests.py @@ -205,6 +205,23 @@ class LookupTests(TestCase): with self.assertRaises(TypeError): Article.objects.in_bulk(headline__startswith="Blah") + def test_in_bulk_values(self): + bulk = Article.objects.values().in_bulk([self.a1.id]) + self.assertIsInstance(bulk[self.a1.id], dict) + + def test_in_bulk_values_without_pk(self): + bulk = Article.objects.values("headline").in_bulk([self.a1.id]) + self.assertIn("pk", bulk[self.a1.id]) + + def test_in_bulk_values_list(self): + with self.assertRaises(TypeError): + Article.objects.values_list().in_bulk([self.a1.id]) + + def test_in_bulk_values_list_named(self) -> None: + # At some point this could return tuples + with self.assertRaises(TypeError): + Article.objects.values_list(named=True).in_bulk([self.a1.id]) + def test_in_bulk_lots_of_ids(self): test_range = 2000 max_query_params = connection.features.max_query_params