diff --git a/django/contrib/postgres/search.py b/django/contrib/postgres/search.py index f652c1d346..d43163a40b 100644 --- a/django/contrib/postgres/search.py +++ b/django/contrib/postgres/search.py @@ -10,7 +10,7 @@ from django.db.models import ( TextField, Value, ) -from django.db.models.expressions import CombinedExpression +from django.db.models.expressions import CombinedExpression, register_combinable_fields from django.db.models.functions import Cast, Coalesce @@ -79,6 +79,11 @@ class SearchVectorCombinable: return CombinedSearchVector(self, connector, other, self.config) +register_combinable_fields( + SearchVectorField, SearchVectorCombinable.ADD, SearchVectorField, SearchVectorField +) + + class SearchVector(SearchVectorCombinable, Func): function = "to_tsvector" arg_joiner = " || ' ' || " diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 32982777ef..add7b20ba3 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -310,18 +310,18 @@ class BaseExpression: def _resolve_output_field(self): """ - Attempt to infer the output type of the expression. If the output - fields of all source fields match then, simply infer the same type - here. This isn't always correct, but it makes sense most of the time. + Attempt to infer the output type of the expression. - Consider the difference between `2 + 2` and `2 / 3`. Inferring - the type here is a convenience for the common case. The user should - supply their own output_field with more complex computations. + As a guess, if the output fields of all source fields match then simply + infer the same type here. If a source's output field resolves to None, exclude it from this check. If all sources are None, then an error is raised higher up the stack in the output_field property. """ + # This guess is mostly a bad idea, but there is quite a lot of code + # (especially 3rd party Func subclasses) that depend on it, we'd need a + # deprecation path to fix it. sources_iter = ( source for source in self.get_source_fields() if source is not None ) @@ -467,6 +467,13 @@ class Expression(BaseExpression, Combinable): # Type inference for CombinedExpression.output_field. +# Missing items will result in FieldError, by design. +# +# The current approach for NULL is based on lowest common denominator behavior +# i.e. if one of the supported databases is raising an error (rather than +# return NULL) for `val NULL`, then Django raises FieldError. +NoneType = type(None) + _connector_combinations = [ # Numeric operations - operands of same type. { @@ -482,6 +489,8 @@ _connector_combinations = [ # Behavior for DIV with integer arguments follows Postgres/SQLite, # not MySQL/Oracle. Combinable.DIV, + Combinable.MOD, + Combinable.POW, ) }, # Numeric operations - operands of different type. @@ -499,6 +508,66 @@ _connector_combinations = [ Combinable.DIV, ) }, + # Bitwise operators. + { + connector: [ + (fields.IntegerField, fields.IntegerField, fields.IntegerField), + ] + for connector in ( + Combinable.BITAND, + Combinable.BITOR, + Combinable.BITLEFTSHIFT, + Combinable.BITRIGHTSHIFT, + Combinable.BITXOR, + ) + }, + # Numeric with NULL. + { + connector: [ + (field_type, NoneType, field_type), + (NoneType, field_type, field_type), + ] + for connector in ( + Combinable.ADD, + Combinable.SUB, + Combinable.MUL, + Combinable.DIV, + Combinable.MOD, + Combinable.POW, + ) + for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField) + }, + # Date/DateTimeField/DurationField/TimeField. + { + Combinable.ADD: [ + # Date/DateTimeField. + (fields.DateField, fields.DurationField, fields.DateTimeField), + (fields.DateTimeField, fields.DurationField, fields.DateTimeField), + (fields.DurationField, fields.DateField, fields.DateTimeField), + (fields.DurationField, fields.DateTimeField, fields.DateTimeField), + # DurationField. + (fields.DurationField, fields.DurationField, fields.DurationField), + # TimeField. + (fields.TimeField, fields.DurationField, fields.TimeField), + (fields.DurationField, fields.TimeField, fields.TimeField), + ], + }, + { + Combinable.SUB: [ + # Date/DateTimeField. + (fields.DateField, fields.DurationField, fields.DateTimeField), + (fields.DateTimeField, fields.DurationField, fields.DateTimeField), + (fields.DateField, fields.DateField, fields.DurationField), + (fields.DateField, fields.DateTimeField, fields.DurationField), + (fields.DateTimeField, fields.DateField, fields.DurationField), + (fields.DateTimeField, fields.DateTimeField, fields.DurationField), + # DurationField. + (fields.DurationField, fields.DurationField, fields.DurationField), + # TimeField. + (fields.TimeField, fields.DurationField, fields.TimeField), + (fields.TimeField, fields.TimeField, fields.DurationField), + ], + }, ] _connector_combinators = defaultdict(list) @@ -552,17 +621,21 @@ class CombinedExpression(SQLiteNumericMixin, Expression): self.lhs, self.rhs = exprs def _resolve_output_field(self): - try: - return super()._resolve_output_field() - except FieldError: - combined_type = _resolve_combined_type( - self.connector, - type(self.lhs.output_field), - type(self.rhs.output_field), + # We avoid using super() here for reasons given in + # Expression._resolve_output_field() + combined_type = _resolve_combined_type( + self.connector, + type(self.lhs._output_field_or_none), + type(self.rhs._output_field_or_none), + ) + if combined_type is None: + raise FieldError( + f"Cannot infer type of {self.connector!r} expression involving these " + f"types: {self.lhs.output_field.__class__.__name__}, " + f"{self.rhs.output_field.__class__.__name__}. You must set " + f"output_field." ) - if combined_type is None: - raise - return combined_type() + return combined_type() def as_sql(self, compiler, connection): expressions = [] diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index b779fe551c..3922478bf3 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -1126,8 +1126,8 @@ class AggregateTestCase(TestCase): def test_combine_different_types(self): msg = ( - "Expression contains mixed types: FloatField, DecimalField. " - "You must set output_field." + "Cannot infer type of '+' expression involving these types: FloatField, " + "DecimalField. You must set output_field." ) qs = Book.objects.annotate(sums=Sum("rating") + Sum("pages") + Sum("price")) with self.assertRaisesMessage(FieldError, msg): diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index a3d8d9b187..72e6020fa0 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -1736,6 +1736,17 @@ class FTimeDeltaTests(TestCase): ], ) + def test_datetime_and_duration_field_addition_with_annotate_and_no_output_field( + self, + ): + test_set = Experiment.objects.annotate( + estimated_end=F("start") + F("estimated_time") + ) + self.assertEqual( + [e.estimated_end for e in test_set], + [e.start + e.estimated_time for e in test_set], + ) + @skipUnlessDBFeature("supports_temporal_subtraction") def test_datetime_subtraction_with_annotate_and_no_output_field(self): test_set = Experiment.objects.annotate( @@ -2438,8 +2449,10 @@ class CombinedExpressionTests(SimpleTestCase): (null, Combinable.ADD, DateTimeField), (DateField, Combinable.SUB, null), ] - msg = "Expression contains mixed types: " for lhs, connector, rhs in tests: + msg = ( + f"Cannot infer type of {connector!r} expression involving these types: " + ) with self.subTest(lhs=lhs, connector=connector, rhs=rhs): expr = CombinedExpression( Expression(lhs()), @@ -2452,16 +2465,34 @@ class CombinedExpressionTests(SimpleTestCase): def test_resolve_output_field_dates(self): tests = [ # Add - same type. + (DateField, Combinable.ADD, DateField, FieldError), + (DateTimeField, Combinable.ADD, DateTimeField, FieldError), + (TimeField, Combinable.ADD, TimeField, FieldError), (DurationField, Combinable.ADD, DurationField, DurationField), + # Add - different type. + (DateField, Combinable.ADD, DurationField, DateTimeField), + (DateTimeField, Combinable.ADD, DurationField, DateTimeField), + (TimeField, Combinable.ADD, DurationField, TimeField), + (DurationField, Combinable.ADD, DateField, DateTimeField), + (DurationField, Combinable.ADD, DateTimeField, DateTimeField), + (DurationField, Combinable.ADD, TimeField, TimeField), # Subtract - same type. + (DateField, Combinable.SUB, DateField, DurationField), + (DateTimeField, Combinable.SUB, DateTimeField, DurationField), + (TimeField, Combinable.SUB, TimeField, DurationField), (DurationField, Combinable.SUB, DurationField, DurationField), # Subtract - different type. + (DateField, Combinable.SUB, DurationField, DateTimeField), + (DateTimeField, Combinable.SUB, DurationField, DateTimeField), + (TimeField, Combinable.SUB, DurationField, TimeField), (DurationField, Combinable.SUB, DateField, FieldError), (DurationField, Combinable.SUB, DateTimeField, FieldError), (DurationField, Combinable.SUB, DateTimeField, FieldError), ] - msg = "Expression contains mixed types: " for lhs, connector, rhs, combined in tests: + msg = ( + f"Cannot infer type of {connector!r} expression involving these types: " + ) with self.subTest(lhs=lhs, connector=connector, rhs=rhs, combined=combined): expr = CombinedExpression( Expression(lhs()), @@ -2477,8 +2508,8 @@ class CombinedExpressionTests(SimpleTestCase): def test_mixed_char_date_with_annotate(self): queryset = Experiment.objects.annotate(nonsense=F("name") + F("assigned")) msg = ( - "Expression contains mixed types: CharField, DateField. You must set " - "output_field." + "Cannot infer type of '+' expression involving these types: CharField, " + "DateField. You must set output_field." ) with self.assertRaisesMessage(FieldError, msg): list(queryset)