From 488aa11da48981753e9e78d4e4c943f0b6be63fd Mon Sep 17 00:00:00 2001 From: sharonwoo <29156885+sharonwoo@users.noreply.github.com> Date: Fri, 27 Sep 2024 09:17:07 +0800 Subject: [PATCH] Fixed #35235 -- ArrayAgg() not returning default when filter contains __in=[] --- django/db/models/expressions.py | 17 ++++++++++------- tests/expressions/tests.py | 19 +++++++++++++++++++ tests/postgres_tests/test_aggregates.py | 15 +++++++++++++++ 3 files changed, 44 insertions(+), 7 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 667e9f93c6..c4a6d9bdd6 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -167,13 +167,16 @@ class Combinable: return NegatedExpression(self) +class OutputFieldIsNoneError(FieldError): + pass + + class BaseExpression: """Base class for all query expressions.""" empty_result_set_value = NotImplemented # aggregate specific fields is_summary = False - _output_field_resolved_to_none = False # Can the expression be used in a WHERE clause? filterable = True # Can the expression be used as a source expression in Window? @@ -315,11 +318,12 @@ class BaseExpression: """Return the output type of this expressions.""" output_field = self._resolve_output_field() if output_field is None: - self._output_field_resolved_to_none = True - raise FieldError("Cannot resolve expression type, unknown output_field") + raise OutputFieldIsNoneError( + "Cannot resolve expression type, unknown output_field" + ) return output_field - @cached_property + @property def _output_field_or_none(self): """ Return the output field of this expression, or None if @@ -327,9 +331,8 @@ class BaseExpression: """ try: return self.output_field - except FieldError: - if not self._output_field_resolved_to_none: - raise + except OutputFieldIsNoneError: + return def _resolve_output_field(self): """ diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index af4cf01fca..145c93f875 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -51,6 +51,7 @@ from django.db.models.expressions import ( Combinable, CombinedExpression, NegatedExpression, + OutputFieldIsNoneError, RawSQL, Ref, ) @@ -2329,6 +2330,24 @@ class ValueTests(TestCase): time = Time.objects.annotate(one=Value(1, output_field=DecimalField())).first() self.assertEqual(time.one, 1) + def test_output_field_is_none_error(self): + """Tests the new OutputFieldIsNoneError error introduced in #35235""" + with self.assertRaises(OutputFieldIsNoneError): + Employee.objects.annotate(custom_expression=Value(None)).values_list( + "custom_expression", flat=True + ).first() + + def test_output_field_or_none_property(self): + """ + The _output_field_or_none property was changed from @cached_property + to @property in #35235. Rather than write subtests, this test explicitly + checks if the property is re-evaluated when the output field is reset + """ + expression = Value(None, output_field=None) + self.assertIsNone(expression._output_field_or_none) + expression.output_field = BooleanField() + self.assertIsInstance(expression._output_field_or_none, BooleanField) + def test_resolve_output_field(self): value_types = [ ("str", CharField), diff --git a/tests/postgres_tests/test_aggregates.py b/tests/postgres_tests/test_aggregates.py index b72310bdf1..b8ad9f0eff 100644 --- a/tests/postgres_tests/test_aggregates.py +++ b/tests/postgres_tests/test_aggregates.py @@ -309,6 +309,21 @@ class TestGeneralAggregate(PostgreSQLTestCase): ) self.assertCountEqual(qs, [[], [5]]) + def test_array_agg_with_empty_filter_and_default_values(self): + for filter_value in ([-1], []): + for default_value in ([], Value([])): + with self.subTest(filter=filter_value, default=default_value): + queryset = AggregateTestModel.objects.annotate( + test_array_agg=ArrayAgg( + "stattestmodel__int1", + filter=Q(pk__in=filter_value), + default=default_value, + ) + ) + for obj in queryset: + with self.subTest(object_id=obj.pk): + self.assertSequenceEqual(obj.test_array_agg, []) + def test_bit_and_general(self): values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate( bitand=BitAnd("integer_field")