From 519c0c332dd8adc2654b0566c926a9439231247a Mon Sep 17 00:00:00 2001 From: Karl Hobley Date: Mon, 14 Sep 2020 20:50:44 +0100 Subject: [PATCH] Simplify Page.copy() (#6277) * Use Django modelcluster's copy_all_child_relations method * page.specific.__class__ => page.specific_class * Use child_object_map as returned by modelcluster for revision rewriting * Use modelcluster to commit child relations * Use a callback instead of a method for _save_copy_instance * Make CopyMixin work on non-MTI models * Make gathering exclude_fields the job of the callee ._copy() no longer depends on any custom attributes in the base class! * Converted CopyMixin into some utility methods (and renamed some stuff) * Don't commit the new page in _copy * Refactor _copy_m2m_relations to be more standalone * Merge _make_copy into _copy Not really useful outside _copy * Give unused variable a name * Version-bump django-modelcluster to 5.1 * Address review feedback Co-authored-by: Matt Westcott --- setup.py | 2 +- .../migrations/0047_add_workflow_models.py | 2 - wagtail/core/models.py | 207 +++++++----------- 3 files changed, 83 insertions(+), 128 deletions(-) diff --git a/setup.py b/setup.py index 74b84d5b0e..d0c89fe78a 100755 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ except ImportError: install_requires = [ "Django>=2.2,<3.2", - "django-modelcluster>=5.0.2,<6.0", + "django-modelcluster>=5.1,<6.0", "django-taggit>=1.0,<2.0", "django-treebeard>=4.2.0,<5.0", "djangorestframework>=3.11.1,<4.0", diff --git a/wagtail/core/migrations/0047_add_workflow_models.py b/wagtail/core/migrations/0047_add_workflow_models.py index 1048d65ead..eda6b6c81f 100755 --- a/wagtail/core/migrations/0047_add_workflow_models.py +++ b/wagtail/core/migrations/0047_add_workflow_models.py @@ -4,7 +4,6 @@ from django.conf import settings from django.db import migrations, models import django.db.models.deletion import modelcluster.fields -import wagtail.core.models class Migration(migrations.Migration): @@ -45,7 +44,6 @@ class Migration(migrations.Migration): 'verbose_name': 'Task state', 'verbose_name_plural': 'Task states', }, - bases=(wagtail.core.models.MultiTableCopyMixin, models.Model), ), migrations.CreateModel( name='Workflow', diff --git a/wagtail/core/models.py b/wagtail/core/models.py index c97bb90cca..a02b65adcd 100644 --- a/wagtail/core/models.py +++ b/wagtail/core/models.py @@ -1,6 +1,5 @@ import json import logging -from collections import defaultdict from io import StringIO from urllib.parse import urlparse @@ -51,131 +50,82 @@ logger = logging.getLogger('wagtail.core') PAGE_TEMPLATE_VAR = 'page' -class MultiTableCopyMixin: - default_exclude_fields_in_copy = ['id'] +def _extract_field_data(source, exclude_fields=None): + """ + Get dictionaries representing the model's field data. - def _get_field_dictionaries(self, exclude_fields=None): - """Get dictionaries representing the model: one with all non m2m fields, and one containing the m2m fields""" - specific_self = self.specific - exclude_fields = exclude_fields or [] - specific_dict = {} - specific_m2m_dict = {} + This excludes many to many fields (which are handled by _copy_m2m_relations)' + """ + exclude_fields = exclude_fields or [] + data_dict = {} - for field in specific_self._meta.get_fields(): - # Ignore explicitly excluded fields - if field.name in exclude_fields: - continue + for field in source._meta.get_fields(): + # Ignore explicitly excluded fields + if field.name in exclude_fields: + continue - # Ignore reverse relations - if field.auto_created: - continue + # Ignore reverse relations + if field.auto_created: + continue - # Copy parental m2m relations - # Otherwise add them to the m2m dict to be set after saving - if field.many_to_many: - if isinstance(field, ParentalManyToManyField): - parental_field = getattr(specific_self, field.name) - if hasattr(parental_field, 'all'): - values = parental_field.all() - if values: - specific_dict[field.name] = values - else: - try: - # Do not copy m2m links with a through model that has a ParentalKey to the model being copied - these will be copied as child objects - through_model_parental_links = [field for field in field.through._meta.get_fields() if isinstance(field, ParentalKey) and (field.related_model == specific_self.__class__ or field.related_model in specific_self._meta.parents)] - if through_model_parental_links: - continue - except AttributeError: - pass - specific_m2m_dict[field.name] = getattr(specific_self, field.name).all() - continue + # Copy parental m2m relations + if field.many_to_many: + if isinstance(field, ParentalManyToManyField): + parental_field = getattr(source, field.name) + if hasattr(parental_field, 'all'): + values = parental_field.all() + if values: + data_dict[field.name] = values + continue - # Ignore parent links (page_ptr) - if isinstance(field, models.OneToOneField) and field.remote_field.parent_link: - continue + # Ignore parent links (page_ptr) + if isinstance(field, models.OneToOneField) and field.remote_field.parent_link: + continue - specific_dict[field.name] = getattr(specific_self, field.name) + data_dict[field.name] = getattr(source, field.name) - return specific_dict, specific_m2m_dict + return data_dict - def _get_copy_instance(self, specific_dict, specific_m2m_dict, update_attrs=None): - """Create a copy instance (without saving) from dictionaries of the model's fields, and update any attributes in update_attrs""" - if not update_attrs: - update_attrs = {} +def _copy_m2m_relations(source, target, exclude_fields=None, update_attrs=None): + """ + Copies non-ParentalManyToMany m2m relations + """ + update_attrs = update_attrs or {} + exclude_fields = exclude_fields or [] - specific_class = self.specific.__class__ - - copy_instance = specific_class(**specific_dict) - - if update_attrs: - for field, value in update_attrs.items(): - if field in specific_m2m_dict: + for field in source._meta.get_fields(): + # Copy m2m relations. Ignore explicitly excluded fields, reverse relations, and Parental m2m fields. + if field.many_to_many and field.name not in exclude_fields and not field.auto_created and not isinstance(field, ParentalManyToManyField): + try: + # Do not copy m2m links with a through model that has a ParentalKey to the model being copied - these will be copied as child objects + through_model_parental_links = [field for field in field.through._meta.get_fields() if isinstance(field, ParentalKey) and (field.related_model == source.__class__ or field.related_model in source._meta.parents)] + if through_model_parental_links: continue - setattr(copy_instance, field, value) + except AttributeError: + pass - return copy_instance + if field.name in update_attrs: + value = update_attrs[field.name] - def _save_copy_instance(self, instance, **kwargs): - raise NotImplementedError + else: + value = getattr(source, field.name).all() - def _set_m2m_relations(self, instance, specific_m2m_dict, update_attrs=None): - """Set non-ParentalManyToMany m2m relations""" - if not update_attrs: - update_attrs = {} - for field_name, value in specific_m2m_dict.items(): - value = update_attrs.get(field_name, value) - getattr(instance, field_name).set(value) + getattr(target, field.name).set(value) - return instance - def _copy_child_objects_to_instance(self, instance, exclude_fields=None, process_child_object=None): - """Copy objects linked to the model by a ParentalKey, and set this to the new revision""" +def _copy(source, exclude_fields=None, update_attrs=None): + data_dict = _extract_field_data(source, exclude_fields=exclude_fields) + target = source.__class__(**data_dict) - # A dict that maps child objects to their new ids - # Used to remap child object ids in revisions - child_object_id_map = defaultdict(dict) - exclude_fields = exclude_fields or [] - specific_self = self.specific - for child_relation in get_all_child_relations(specific_self): - accessor_name = child_relation.get_accessor_name() - - if accessor_name in exclude_fields: + if update_attrs: + for field, value in update_attrs.items(): + if field not in data_dict: continue + setattr(target, field, value) - parental_key_name = child_relation.field.attname - child_objects = getattr(specific_self, accessor_name, None) - - if child_objects: - for child_object in child_objects.all(): - old_pk = child_object.pk - child_object.pk = None - setattr(child_object, parental_key_name, instance.id) - - if process_child_object is not None: - process_child_object(specific_self, instance, child_relation, child_object) - - child_object.save() - - # Add mapping to new primary key (so we can apply this change to revisions) - child_object_id_map[accessor_name][old_pk] = child_object.pk - - return child_object_id_map - - def _copy(self, exclude_fields=None, update_attrs=None, process_child_object=None, **kwargs): - exclude_fields = self.default_exclude_fields_in_copy + self.specific.exclude_fields_in_copy + (exclude_fields or []) - - specific_dict, specific_m2m_dict = self._get_field_dictionaries(exclude_fields=exclude_fields) - - copy_instance = self._get_copy_instance(specific_dict, specific_m2m_dict, update_attrs=update_attrs) - - copy_instance = self._save_copy_instance(copy_instance, **kwargs) - - copy_instance = self._set_m2m_relations(copy_instance, specific_m2m_dict, update_attrs) - - child_object_id_map = self._copy_child_objects_to_instance(copy_instance, exclude_fields=exclude_fields, process_child_object=process_child_object) - - return copy_instance, child_object_id_map + child_object_map = source.copy_all_child_relations(target, exclude=exclude_fields) + return target, child_object_map class SiteManager(models.Manager): @@ -388,7 +338,7 @@ class AbstractPage(TreebeardPathFixMixin, MP_Node): abstract = True -class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase): +class Page(AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase): title = models.CharField( verbose_name=_('title'), max_length=255, @@ -1428,7 +1378,7 @@ class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, m :param log_action flag for logging the action. Pass None to skip logging. Can be passed an action string. Defaults to 'wagtail.copy' """ - + exclude_fields = self.default_exclude_fields_in_copy + self.exclude_fields_in_copy + (exclude_fields or []) specific_self = self.specific if keep_live: base_update_attrs = {} @@ -1447,7 +1397,22 @@ class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, m if update_attrs: base_update_attrs.update(update_attrs) - page_copy, child_object_id_map = self._copy(exclude_fields=exclude_fields, update_attrs=base_update_attrs, to=to, recursive=recursive, process_child_object=process_child_object) + page_copy, child_object_map = _copy(specific_self, exclude_fields=exclude_fields, update_attrs=base_update_attrs) + + # Save copied child objects and run process_child_object on them if we need to + for (child_relation, old_pk), child_object in child_object_map.items(): + if process_child_object: + process_child_object(specific_self, page_copy, child_relation, child_object) + + # Save the new page + if to: + if recursive and (to == self or to.is_descendant_of(self)): + raise Exception("You cannot copy a tree branch recursively into itself") + page_copy = to.add_child(instance=page_copy) + else: + page_copy = self.add_sibling(instance=page_copy) + + _copy_m2m_relations(specific_self, page_copy, exclude_fields=exclude_fields, update_attrs=base_update_attrs) # Copy revisions if copy_revisions: @@ -1476,7 +1441,8 @@ class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, m # Remap primary key to copied versions # If the primary key is not recognised (eg, the child object has been deleted from the database) # set the primary key to None - child_object['pk'] = child_object_id_map[accessor_name].get(child_object['pk'], None) + copied_child_object = child_object_map.get((child_relation, child_object['pk'])) + child_object['pk'] = copied_child_object.pk if copied_child_object else None revision.content_json = json.dumps(revision_content) @@ -1542,15 +1508,6 @@ class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, m return page_copy - def _save_copy_instance(self, instance, to=None, recursive=False, **kwargs): - if to: - if recursive and (to == self or to.is_descendant_of(self)): - raise Exception("You cannot copy a tree branch recursively into itself") - instance = to.add_child(instance=instance) - else: - instance = self.add_sibling(instance=instance) - return instance - copy.alters_data = True def permissions_for_user(self, user): @@ -3489,7 +3446,7 @@ class TaskStateManager(models.Manager): return states -class TaskState(MultiTableCopyMixin, models.Model): +class TaskState(models.Model): """Tracks the status of a given Task for a particular page revision.""" STATUS_IN_PROGRESS = 'in_progress' STATUS_APPROVED = 'approved' @@ -3526,6 +3483,7 @@ class TaskState(MultiTableCopyMixin, models.Model): on_delete=models.CASCADE ) exclude_fields_in_copy = [] + default_exclude_fields_in_copy = ['id'] objects = TaskStateManager() @@ -3630,11 +3588,10 @@ class TaskState(MultiTableCopyMixin, models.Model): def copy(self, update_attrs=None, exclude_fields=None): """Copy this task state, excluding the attributes in the ``exclude_fields`` list and updating any attributes to values specified in the ``update_attrs`` dictionary of ``attribute``: ``new value`` pairs""" - copy_instance, _ = self._copy(exclude_fields, update_attrs) - return copy_instance - - def _save_copy_instance(self, instance, **kwargs): + exclude_fields = self.default_exclude_fields_in_copy + self.exclude_fields_in_copy + (exclude_fields or []) + instance, child_object_map = _copy(self.specific, exclude_fields, update_attrs) instance.save() + _copy_m2m_relations(self, instance, exclude_fields=exclude_fields) return instance def get_comment(self):