0
0
mirror of https://github.com/wagtail/wagtail.git synced 2024-11-24 01:57:32 +01:00

Simplify Page.copy() (#6277)

* Use Django modelcluster's copy_all_child_relations method

* page.specific.__class__ => page.specific_class

* Use child_object_map as returned by modelcluster for revision rewriting

* Use modelcluster to commit child relations

* Use a callback instead of a method for _save_copy_instance

* Make CopyMixin work on non-MTI models

* Make gathering exclude_fields the job of the callee

._copy() no longer depends on any custom attributes in the base class!

* Converted CopyMixin into some utility methods (and renamed some stuff)

* Don't commit the new page in _copy

* Refactor _copy_m2m_relations to be more standalone

* Merge _make_copy into _copy

Not really useful outside _copy

* Give unused variable a name

* Version-bump django-modelcluster to 5.1

* Address review feedback

Co-authored-by: Matt Westcott <matt@west.co.tt>
This commit is contained in:
Karl Hobley 2020-09-14 20:50:44 +01:00 committed by GitHub
parent 5e6a674686
commit 519c0c332d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 83 additions and 128 deletions

View File

@ -22,7 +22,7 @@ except ImportError:
install_requires = [
"Django>=2.2,<3.2",
"django-modelcluster>=5.0.2,<6.0",
"django-modelcluster>=5.1,<6.0",
"django-taggit>=1.0,<2.0",
"django-treebeard>=4.2.0,<5.0",
"djangorestframework>=3.11.1,<4.0",

View File

@ -4,7 +4,6 @@ from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import modelcluster.fields
import wagtail.core.models
class Migration(migrations.Migration):
@ -45,7 +44,6 @@ class Migration(migrations.Migration):
'verbose_name': 'Task state',
'verbose_name_plural': 'Task states',
},
bases=(wagtail.core.models.MultiTableCopyMixin, models.Model),
),
migrations.CreateModel(
name='Workflow',

View File

@ -1,6 +1,5 @@
import json
import logging
from collections import defaultdict
from io import StringIO
from urllib.parse import urlparse
@ -51,131 +50,82 @@ logger = logging.getLogger('wagtail.core')
PAGE_TEMPLATE_VAR = 'page'
class MultiTableCopyMixin:
default_exclude_fields_in_copy = ['id']
def _extract_field_data(source, exclude_fields=None):
"""
Get dictionaries representing the model's field data.
def _get_field_dictionaries(self, exclude_fields=None):
"""Get dictionaries representing the model: one with all non m2m fields, and one containing the m2m fields"""
specific_self = self.specific
exclude_fields = exclude_fields or []
specific_dict = {}
specific_m2m_dict = {}
This excludes many to many fields (which are handled by _copy_m2m_relations)'
"""
exclude_fields = exclude_fields or []
data_dict = {}
for field in specific_self._meta.get_fields():
# Ignore explicitly excluded fields
if field.name in exclude_fields:
continue
for field in source._meta.get_fields():
# Ignore explicitly excluded fields
if field.name in exclude_fields:
continue
# Ignore reverse relations
if field.auto_created:
continue
# Ignore reverse relations
if field.auto_created:
continue
# Copy parental m2m relations
# Otherwise add them to the m2m dict to be set after saving
if field.many_to_many:
if isinstance(field, ParentalManyToManyField):
parental_field = getattr(specific_self, field.name)
if hasattr(parental_field, 'all'):
values = parental_field.all()
if values:
specific_dict[field.name] = values
else:
try:
# Do not copy m2m links with a through model that has a ParentalKey to the model being copied - these will be copied as child objects
through_model_parental_links = [field for field in field.through._meta.get_fields() if isinstance(field, ParentalKey) and (field.related_model == specific_self.__class__ or field.related_model in specific_self._meta.parents)]
if through_model_parental_links:
continue
except AttributeError:
pass
specific_m2m_dict[field.name] = getattr(specific_self, field.name).all()
continue
# Copy parental m2m relations
if field.many_to_many:
if isinstance(field, ParentalManyToManyField):
parental_field = getattr(source, field.name)
if hasattr(parental_field, 'all'):
values = parental_field.all()
if values:
data_dict[field.name] = values
continue
# Ignore parent links (page_ptr)
if isinstance(field, models.OneToOneField) and field.remote_field.parent_link:
continue
# Ignore parent links (page_ptr)
if isinstance(field, models.OneToOneField) and field.remote_field.parent_link:
continue
specific_dict[field.name] = getattr(specific_self, field.name)
data_dict[field.name] = getattr(source, field.name)
return specific_dict, specific_m2m_dict
return data_dict
def _get_copy_instance(self, specific_dict, specific_m2m_dict, update_attrs=None):
"""Create a copy instance (without saving) from dictionaries of the model's fields, and update any attributes in update_attrs"""
if not update_attrs:
update_attrs = {}
def _copy_m2m_relations(source, target, exclude_fields=None, update_attrs=None):
"""
Copies non-ParentalManyToMany m2m relations
"""
update_attrs = update_attrs or {}
exclude_fields = exclude_fields or []
specific_class = self.specific.__class__
copy_instance = specific_class(**specific_dict)
if update_attrs:
for field, value in update_attrs.items():
if field in specific_m2m_dict:
for field in source._meta.get_fields():
# Copy m2m relations. Ignore explicitly excluded fields, reverse relations, and Parental m2m fields.
if field.many_to_many and field.name not in exclude_fields and not field.auto_created and not isinstance(field, ParentalManyToManyField):
try:
# Do not copy m2m links with a through model that has a ParentalKey to the model being copied - these will be copied as child objects
through_model_parental_links = [field for field in field.through._meta.get_fields() if isinstance(field, ParentalKey) and (field.related_model == source.__class__ or field.related_model in source._meta.parents)]
if through_model_parental_links:
continue
setattr(copy_instance, field, value)
except AttributeError:
pass
return copy_instance
if field.name in update_attrs:
value = update_attrs[field.name]
def _save_copy_instance(self, instance, **kwargs):
raise NotImplementedError
else:
value = getattr(source, field.name).all()
def _set_m2m_relations(self, instance, specific_m2m_dict, update_attrs=None):
"""Set non-ParentalManyToMany m2m relations"""
if not update_attrs:
update_attrs = {}
for field_name, value in specific_m2m_dict.items():
value = update_attrs.get(field_name, value)
getattr(instance, field_name).set(value)
getattr(target, field.name).set(value)
return instance
def _copy_child_objects_to_instance(self, instance, exclude_fields=None, process_child_object=None):
"""Copy objects linked to the model by a ParentalKey, and set this to the new revision"""
def _copy(source, exclude_fields=None, update_attrs=None):
data_dict = _extract_field_data(source, exclude_fields=exclude_fields)
target = source.__class__(**data_dict)
# A dict that maps child objects to their new ids
# Used to remap child object ids in revisions
child_object_id_map = defaultdict(dict)
exclude_fields = exclude_fields or []
specific_self = self.specific
for child_relation in get_all_child_relations(specific_self):
accessor_name = child_relation.get_accessor_name()
if accessor_name in exclude_fields:
if update_attrs:
for field, value in update_attrs.items():
if field not in data_dict:
continue
setattr(target, field, value)
parental_key_name = child_relation.field.attname
child_objects = getattr(specific_self, accessor_name, None)
if child_objects:
for child_object in child_objects.all():
old_pk = child_object.pk
child_object.pk = None
setattr(child_object, parental_key_name, instance.id)
if process_child_object is not None:
process_child_object(specific_self, instance, child_relation, child_object)
child_object.save()
# Add mapping to new primary key (so we can apply this change to revisions)
child_object_id_map[accessor_name][old_pk] = child_object.pk
return child_object_id_map
def _copy(self, exclude_fields=None, update_attrs=None, process_child_object=None, **kwargs):
exclude_fields = self.default_exclude_fields_in_copy + self.specific.exclude_fields_in_copy + (exclude_fields or [])
specific_dict, specific_m2m_dict = self._get_field_dictionaries(exclude_fields=exclude_fields)
copy_instance = self._get_copy_instance(specific_dict, specific_m2m_dict, update_attrs=update_attrs)
copy_instance = self._save_copy_instance(copy_instance, **kwargs)
copy_instance = self._set_m2m_relations(copy_instance, specific_m2m_dict, update_attrs)
child_object_id_map = self._copy_child_objects_to_instance(copy_instance, exclude_fields=exclude_fields, process_child_object=process_child_object)
return copy_instance, child_object_id_map
child_object_map = source.copy_all_child_relations(target, exclude=exclude_fields)
return target, child_object_map
class SiteManager(models.Manager):
@ -388,7 +338,7 @@ class AbstractPage(TreebeardPathFixMixin, MP_Node):
abstract = True
class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase):
class Page(AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase):
title = models.CharField(
verbose_name=_('title'),
max_length=255,
@ -1428,7 +1378,7 @@ class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, m
:param log_action flag for logging the action. Pass None to skip logging.
Can be passed an action string. Defaults to 'wagtail.copy'
"""
exclude_fields = self.default_exclude_fields_in_copy + self.exclude_fields_in_copy + (exclude_fields or [])
specific_self = self.specific
if keep_live:
base_update_attrs = {}
@ -1447,7 +1397,22 @@ class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, m
if update_attrs:
base_update_attrs.update(update_attrs)
page_copy, child_object_id_map = self._copy(exclude_fields=exclude_fields, update_attrs=base_update_attrs, to=to, recursive=recursive, process_child_object=process_child_object)
page_copy, child_object_map = _copy(specific_self, exclude_fields=exclude_fields, update_attrs=base_update_attrs)
# Save copied child objects and run process_child_object on them if we need to
for (child_relation, old_pk), child_object in child_object_map.items():
if process_child_object:
process_child_object(specific_self, page_copy, child_relation, child_object)
# Save the new page
if to:
if recursive and (to == self or to.is_descendant_of(self)):
raise Exception("You cannot copy a tree branch recursively into itself")
page_copy = to.add_child(instance=page_copy)
else:
page_copy = self.add_sibling(instance=page_copy)
_copy_m2m_relations(specific_self, page_copy, exclude_fields=exclude_fields, update_attrs=base_update_attrs)
# Copy revisions
if copy_revisions:
@ -1476,7 +1441,8 @@ class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, m
# Remap primary key to copied versions
# If the primary key is not recognised (eg, the child object has been deleted from the database)
# set the primary key to None
child_object['pk'] = child_object_id_map[accessor_name].get(child_object['pk'], None)
copied_child_object = child_object_map.get((child_relation, child_object['pk']))
child_object['pk'] = copied_child_object.pk if copied_child_object else None
revision.content_json = json.dumps(revision_content)
@ -1542,15 +1508,6 @@ class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, m
return page_copy
def _save_copy_instance(self, instance, to=None, recursive=False, **kwargs):
if to:
if recursive and (to == self or to.is_descendant_of(self)):
raise Exception("You cannot copy a tree branch recursively into itself")
instance = to.add_child(instance=instance)
else:
instance = self.add_sibling(instance=instance)
return instance
copy.alters_data = True
def permissions_for_user(self, user):
@ -3489,7 +3446,7 @@ class TaskStateManager(models.Manager):
return states
class TaskState(MultiTableCopyMixin, models.Model):
class TaskState(models.Model):
"""Tracks the status of a given Task for a particular page revision."""
STATUS_IN_PROGRESS = 'in_progress'
STATUS_APPROVED = 'approved'
@ -3526,6 +3483,7 @@ class TaskState(MultiTableCopyMixin, models.Model):
on_delete=models.CASCADE
)
exclude_fields_in_copy = []
default_exclude_fields_in_copy = ['id']
objects = TaskStateManager()
@ -3630,11 +3588,10 @@ class TaskState(MultiTableCopyMixin, models.Model):
def copy(self, update_attrs=None, exclude_fields=None):
"""Copy this task state, excluding the attributes in the ``exclude_fields`` list and updating any attributes to values
specified in the ``update_attrs`` dictionary of ``attribute``: ``new value`` pairs"""
copy_instance, _ = self._copy(exclude_fields, update_attrs)
return copy_instance
def _save_copy_instance(self, instance, **kwargs):
exclude_fields = self.default_exclude_fields_in_copy + self.exclude_fields_in_copy + (exclude_fields or [])
instance, child_object_map = _copy(self.specific, exclude_fields, update_attrs)
instance.save()
_copy_m2m_relations(self, instance, exclude_fields=exclude_fields)
return instance
def get_comment(self):