From 89c7454dbdae3e0df6d96aa6132205d05e4a9b3d Mon Sep 17 00:00:00 2001 From: Thomas Chaumeny Date: Fri, 7 Jul 2023 13:08:17 +0200 Subject: [PATCH] Fixed #34698 -- Made QuerySet.bulk_create() retrieve primary keys when updating conflicts. --- django/db/models/query.py | 7 ++++++- docs/ref/models/querysets.txt | 10 +++++++--- docs/releases/5.0.txt | 4 ++++ tests/bulk_create/tests.py | 37 ++++++++++++++++++++++++++++------- 4 files changed, 47 insertions(+), 11 deletions(-) diff --git a/django/db/models/query.py b/django/db/models/query.py index 5ac2407ea3..395ba6e404 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1837,12 +1837,17 @@ class QuerySet(AltersData): inserted_rows = [] bulk_return = connection.features.can_return_rows_from_bulk_insert for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]: - if bulk_return and on_conflict is None: + if bulk_return and ( + on_conflict is None or on_conflict == OnConflict.UPDATE + ): inserted_rows.extend( self._insert( item, fields=fields, using=self.db, + on_conflict=on_conflict, + update_fields=update_fields, + unique_fields=unique_fields, returning_fields=self.model._meta.db_returning_fields, ) ) diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index a754953264..fd6bb39ff8 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -2411,9 +2411,13 @@ On databases that support it (all except Oracle and SQLite < 3.24), setting the SQLite, in addition to ``update_fields``, a list of ``unique_fields`` that may be in conflict must be provided. -Enabling the ``ignore_conflicts`` or ``update_conflicts`` parameter disable -setting the primary key on each model instance (if the database normally -support it). +Enabling the ``ignore_conflicts`` parameter disables setting the primary key on +each model instance (if the database normally supports it). + +.. versionchanged:: 5.0 + + In older versions, enabling the ``update_conflicts`` parameter prevented + setting the primary key on each model instance. .. warning:: diff --git a/docs/releases/5.0.txt b/docs/releases/5.0.txt index 0ad0835c29..e4c1eac1d9 100644 --- a/docs/releases/5.0.txt +++ b/docs/releases/5.0.txt @@ -357,6 +357,10 @@ Models :meth:`.Model.save` now allows specifying a tuple of parent classes that must be forced to be inserted. +* :meth:`.QuerySet.bulk_create` and :meth:`.QuerySet.abulk_create` methods now + set the primary key on each model instance when the ``update_conflicts`` + parameter is enabled (if the database supports it). + Pagination ~~~~~~~~~~ diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py index aee0cd9996..7b86a2def5 100644 --- a/tests/bulk_create/tests.py +++ b/tests/bulk_create/tests.py @@ -582,12 +582,16 @@ class BulkCreateTests(TestCase): TwoFields(f1=1, f2=1, name="c"), TwoFields(f1=2, f2=2, name="d"), ] - TwoFields.objects.bulk_create( + results = TwoFields.objects.bulk_create( conflicting_objects, update_conflicts=True, unique_fields=unique_fields, update_fields=["name"], ) + self.assertEqual(len(results), len(conflicting_objects)) + if connection.features.can_return_rows_from_bulk_insert: + for instance in results: + self.assertIsNotNone(instance.pk) self.assertEqual(TwoFields.objects.count(), 2) self.assertCountEqual( TwoFields.objects.values("f1", "f2", "name"), @@ -619,7 +623,6 @@ class BulkCreateTests(TestCase): TwoFields(f1=2, f2=2, name="b"), ] ) - self.assertEqual(TwoFields.objects.count(), 2) obj1 = TwoFields.objects.get(f1=1) obj2 = TwoFields.objects.get(f1=2) @@ -627,12 +630,16 @@ class BulkCreateTests(TestCase): TwoFields(pk=obj1.pk, f1=3, f2=3, name="c"), TwoFields(pk=obj2.pk, f1=4, f2=4, name="d"), ] - TwoFields.objects.bulk_create( + results = TwoFields.objects.bulk_create( conflicting_objects, update_conflicts=True, unique_fields=["pk"], update_fields=["name"], ) + self.assertEqual(len(results), len(conflicting_objects)) + if connection.features.can_return_rows_from_bulk_insert: + for instance in results: + self.assertIsNotNone(instance.pk) self.assertEqual(TwoFields.objects.count(), 2) self.assertCountEqual( TwoFields.objects.values("f1", "f2", "name"), @@ -680,12 +687,16 @@ class BulkCreateTests(TestCase): description=("Japan is an island country in East Asia."), ), ] - Country.objects.bulk_create( + results = Country.objects.bulk_create( new_data, update_conflicts=True, update_fields=["description"], unique_fields=unique_fields, ) + self.assertEqual(len(results), len(new_data)) + if connection.features.can_return_rows_from_bulk_insert: + for instance in results: + self.assertIsNotNone(instance.pk) self.assertEqual(Country.objects.count(), 6) self.assertCountEqual( Country.objects.values("iso_two_letter", "description"), @@ -743,12 +754,16 @@ class BulkCreateTests(TestCase): UpsertConflict(number=2, rank=2, name="Olivia"), UpsertConflict(number=3, rank=1, name="Hannah"), ] - UpsertConflict.objects.bulk_create( + results = UpsertConflict.objects.bulk_create( conflicting_objects, update_conflicts=True, update_fields=["name", "rank"], unique_fields=unique_fields, ) + self.assertEqual(len(results), len(conflicting_objects)) + if connection.features.can_return_rows_from_bulk_insert: + for instance in results: + self.assertIsNotNone(instance.pk) self.assertEqual(UpsertConflict.objects.count(), 3) self.assertCountEqual( UpsertConflict.objects.values("number", "rank", "name"), @@ -759,12 +774,16 @@ class BulkCreateTests(TestCase): ], ) - UpsertConflict.objects.bulk_create( + results = UpsertConflict.objects.bulk_create( conflicting_objects + [UpsertConflict(number=4, rank=4, name="Mark")], update_conflicts=True, update_fields=["name", "rank"], unique_fields=unique_fields, ) + self.assertEqual(len(results), 4) + if connection.features.can_return_rows_from_bulk_insert: + for instance in results: + self.assertIsNotNone(instance.pk) self.assertEqual(UpsertConflict.objects.count(), 4) self.assertCountEqual( UpsertConflict.objects.values("number", "rank", "name"), @@ -803,12 +822,16 @@ class BulkCreateTests(TestCase): FieldsWithDbColumns(rank=1, name="c"), FieldsWithDbColumns(rank=2, name="d"), ] - FieldsWithDbColumns.objects.bulk_create( + results = FieldsWithDbColumns.objects.bulk_create( conflicting_objects, update_conflicts=True, unique_fields=["rank"], update_fields=["name"], ) + self.assertEqual(len(results), len(conflicting_objects)) + if connection.features.can_return_rows_from_bulk_insert: + for instance in results: + self.assertIsNotNone(instance.pk) self.assertEqual(FieldsWithDbColumns.objects.count(), 2) self.assertCountEqual( FieldsWithDbColumns.objects.values("rank", "name"),