From 509763c79952cde02d9f5b584af4278bdbed77b2 Mon Sep 17 00:00:00 2001 From: David Sanders Date: Mon, 5 Aug 2024 08:22:29 +0200 Subject: [PATCH] Fixed #35638 -- Updated validate_constraints to consider db_default. --- django/db/models/expressions.py | 34 ++++++++++- django/db/models/fields/__init__.py | 12 ++-- docs/releases/5.0.8.txt | 4 ++ tests/constraints/models.py | 7 +++ tests/constraints/tests.py | 57 ++++++++++++++++++- .../migrations/0002_create_test_models.py | 2 +- tests/postgres_tests/models.py | 2 +- tests/postgres_tests/test_constraints.py | 9 +++ tests/validation/models.py | 2 +- tests/validation/test_unique.py | 14 +++++ 10 files changed, 130 insertions(+), 13 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 4a242012ee..dc09e43fda 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1250,9 +1250,41 @@ class Star(Expression): class DatabaseDefault(Expression): - """Placeholder expression for the database default in an insert query.""" + """ + Expression to use DEFAULT keyword during insert otherwise the underlying expression. + """ + + def __init__(self, expression, output_field=None): + super().__init__(output_field) + self.expression = expression + + def get_source_expressions(self): + return [self.expression] + + def set_source_expressions(self, exprs): + (self.expression,) = exprs + + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): + resolved_expression = self.expression.resolve_expression( + query=query, + allow_joins=allow_joins, + reuse=reuse, + summarize=summarize, + for_save=for_save, + ) + # Defaults used outside an INSERT context should resolve to their + # underlying expression. + if not for_save: + return resolved_expression + return DatabaseDefault( + resolved_expression, output_field=self._output_field_or_none + ) def as_sql(self, compiler, connection): + if not connection.features.supports_default_keyword_in_insert: + return compiler.compile(self.expression) return "DEFAULT", [] diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index 796c4d23c4..d1f31f0211 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -983,13 +983,7 @@ class Field(RegisterLookupMixin): def pre_save(self, model_instance, add): """Return field's value just before saving.""" - value = getattr(model_instance, self.attname) - if not connection.features.supports_default_keyword_in_insert: - from django.db.models.expressions import DatabaseDefault - - if isinstance(value, DatabaseDefault): - return self._db_default_expression - return value + return getattr(model_instance, self.attname) def get_prep_value(self, value): """Perform preliminary non-db specific value checks and conversions.""" @@ -1031,7 +1025,9 @@ class Field(RegisterLookupMixin): if self.db_default is not NOT_PROVIDED: from django.db.models.expressions import DatabaseDefault - return DatabaseDefault + return lambda: DatabaseDefault( + self._db_default_expression, output_field=self + ) if ( not self.empty_strings_allowed diff --git a/docs/releases/5.0.8.txt b/docs/releases/5.0.8.txt index 31de9985c4..5cc3faec98 100644 --- a/docs/releases/5.0.8.txt +++ b/docs/releases/5.0.8.txt @@ -28,3 +28,7 @@ Bugfixes * Fixed a bug in Django 5.0 that caused a system check crash when ``ModelAdmin.date_hierarchy`` was a ``GeneratedField`` with an ``output_field`` of ``DateField`` or ``DateTimeField`` (:ticket:`35628`). + +* Fixed a bug in Django 5.0 which caused constraint validation to either crash + or incorrectly raise validation errors for constraints referring to fields + using ``Field.db_default`` (:ticket:`35638`). diff --git a/tests/constraints/models.py b/tests/constraints/models.py index 87b97b2a85..983d550502 100644 --- a/tests/constraints/models.py +++ b/tests/constraints/models.py @@ -128,3 +128,10 @@ class JSONFieldModel(models.Model): class Meta: required_db_features = {"supports_json_field"} + + +class ModelWithDatabaseDefault(models.Model): + field = models.CharField(max_length=255) + field_with_db_default = models.CharField( + max_length=255, db_default=models.Value("field_with_db_default") + ) diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py index 00c3d958e3..350f05f2b8 100644 --- a/tests/constraints/tests.py +++ b/tests/constraints/tests.py @@ -4,7 +4,7 @@ from django.core.exceptions import ValidationError from django.db import IntegrityError, connection, models from django.db.models import F from django.db.models.constraints import BaseConstraint, UniqueConstraint -from django.db.models.functions import Abs, Lower +from django.db.models.functions import Abs, Lower, Upper from django.db.transaction import atomic from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature from django.test.utils import ignore_warnings @@ -14,6 +14,7 @@ from .models import ( ChildModel, ChildUniqueConstraintProduct, JSONFieldModel, + ModelWithDatabaseDefault, Product, UniqueConstraintConditionProduct, UniqueConstraintDeferrable, @@ -396,6 +397,33 @@ class CheckConstraintTests(TestCase): with self.assertWarnsRegex(RemovedInDjango60Warning, msg): self.assertIs(constraint.check, other_condition) + def test_database_default(self): + models.CheckConstraint( + condition=models.Q(field_with_db_default="field_with_db_default"), + name="check_field_with_db_default", + ).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault()) + + # Ensure that a check also does not silently pass with either + # FieldError or DatabaseError when checking with a db_default. + with self.assertRaises(ValidationError): + models.CheckConstraint( + condition=models.Q( + field_with_db_default="field_with_db_default", field="field" + ), + name="check_field_with_db_default_2", + ).validate( + ModelWithDatabaseDefault, ModelWithDatabaseDefault(field="not-field") + ) + + with self.assertRaises(ValidationError): + models.CheckConstraint( + condition=models.Q(field_with_db_default="field_with_db_default"), + name="check_field_with_db_default", + ).validate( + ModelWithDatabaseDefault, + ModelWithDatabaseDefault(field_with_db_default="other value"), + ) + class UniqueConstraintTests(TestCase): @classmethod @@ -1265,3 +1293,30 @@ class UniqueConstraintTests(TestCase): msg = "A unique constraint must be named." with self.assertRaisesMessage(ValueError, msg): models.UniqueConstraint(fields=["field"]) + + def test_database_default(self): + models.UniqueConstraint( + fields=["field_with_db_default"], name="unique_field_with_db_default" + ).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault()) + models.UniqueConstraint( + Upper("field_with_db_default"), + name="unique_field_with_db_default_expression", + ).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault()) + + ModelWithDatabaseDefault.objects.create() + + msg = ( + "Model with database default with this Field with db default already " + "exists." + ) + with self.assertRaisesMessage(ValidationError, msg): + models.UniqueConstraint( + fields=["field_with_db_default"], name="unique_field_with_db_default" + ).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault()) + + msg = "Constraint “unique_field_with_db_default_expression” is violated." + with self.assertRaisesMessage(ValidationError, msg): + models.UniqueConstraint( + Upper("field_with_db_default"), + name="unique_field_with_db_default_expression", + ).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault()) diff --git a/tests/postgres_tests/migrations/0002_create_test_models.py b/tests/postgres_tests/migrations/0002_create_test_models.py index 5538b436ad..188f79607d 100644 --- a/tests/postgres_tests/migrations/0002_create_test_models.py +++ b/tests/postgres_tests/migrations/0002_create_test_models.py @@ -434,7 +434,7 @@ class Migration(migrations.Migration): primary_key=True, ), ), - ("ints", IntegerRangeField(null=True, blank=True)), + ("ints", IntegerRangeField(null=True, blank=True, db_default=(5, 10))), ("bigints", BigIntegerRangeField(null=True, blank=True)), ("decimals", DecimalRangeField(null=True, blank=True)), ("timestamps", DateTimeRangeField(null=True, blank=True)), diff --git a/tests/postgres_tests/models.py b/tests/postgres_tests/models.py index a97894e327..e3118bc590 100644 --- a/tests/postgres_tests/models.py +++ b/tests/postgres_tests/models.py @@ -130,7 +130,7 @@ class LineSavedSearch(PostgreSQLModel): class RangesModel(PostgreSQLModel): - ints = IntegerRangeField(blank=True, null=True) + ints = IntegerRangeField(blank=True, null=True, db_default=(5, 10)) bigints = BigIntegerRangeField(blank=True, null=True) decimals = DecimalRangeField(blank=True, null=True) timestamps = DateTimeRangeField(blank=True, null=True) diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py index 770d4b1702..f571a96f35 100644 --- a/tests/postgres_tests/test_constraints.py +++ b/tests/postgres_tests/test_constraints.py @@ -1213,3 +1213,12 @@ class ExclusionConstraintTests(PostgreSQLTestCase): constraint_name, self.get_constraints(ModelWithExclusionConstraint._meta.db_table), ) + + def test_database_default(self): + constraint = ExclusionConstraint( + name="ints_equal", expressions=[("ints", RangeOperators.EQUAL)] + ) + RangesModel.objects.create() + msg = "Constraint “ints_equal” is violated." + with self.assertRaisesMessage(ValidationError, msg): + constraint.validate(RangesModel, RangesModel()) diff --git a/tests/validation/models.py b/tests/validation/models.py index 653be4a239..ed88750364 100644 --- a/tests/validation/models.py +++ b/tests/validation/models.py @@ -48,7 +48,7 @@ class ModelToValidate(models.Model): class UniqueFieldsModel(models.Model): unique_charfield = models.CharField(max_length=100, unique=True) - unique_integerfield = models.IntegerField(unique=True) + unique_integerfield = models.IntegerField(unique=True, db_default=42) non_unique_field = models.IntegerField() diff --git a/tests/validation/test_unique.py b/tests/validation/test_unique.py index 4a8b3894f0..36ee6e9da0 100644 --- a/tests/validation/test_unique.py +++ b/tests/validation/test_unique.py @@ -146,6 +146,20 @@ class PerformUniqueChecksTest(TestCase): mtv = ModelToValidate(number=10, name="Some Name") mtv.full_clean() + def test_unique_db_default(self): + UniqueFieldsModel.objects.create(unique_charfield="foo", non_unique_field=42) + um = UniqueFieldsModel(unique_charfield="bar", non_unique_field=42) + with self.assertRaises(ValidationError) as cm: + um.full_clean() + self.assertEqual( + cm.exception.message_dict, + { + "unique_integerfield": [ + "Unique fields model with this Unique integerfield already exists." + ] + }, + ) + def test_unique_for_date(self): Post.objects.create( title="Django 1.0 is released",