0
0
mirror of https://github.com/wagtail/wagtail.git synced 2024-11-30 01:46:24 +01:00

Refactor common copy method mixin from Page and TaskState

This commit is contained in:
jacobtm 2020-02-10 09:18:07 +00:00 committed by Matt Westcott
parent 0d5c2d6a30
commit 02a706581e

View File

@ -25,7 +25,7 @@ from django.utils.functional import cached_property
from django.utils.module_loading import import_string
from django.utils.text import capfirst, slugify
from django.utils.translation import gettext_lazy as _
from modelcluster.fields import ParentalKey
from modelcluster.fields import ParentalKey, ParentalManyToManyField
from modelcluster.models import (
ClusterableModel, get_all_child_m2m_relations, get_all_child_relations)
from treebeard.mp_tree import MP_Node
@ -44,6 +44,120 @@ logger = logging.getLogger('wagtail.core')
PAGE_TEMPLATE_VAR = 'page'
class MultiTableCopyMixin:
default_exclude_fields_in_copy = ['id']
def _get_field_dictionaries(self, exclude_fields=None, **kwargs):
specific_self = self.specific
exclude_fields = exclude_fields or []
specific_dict = {}
specific_m2m_dict = {}
for field in specific_self._meta.get_fields():
# Ignore explicitly excluded fields
if field.name in exclude_fields:
continue
# Ignore reverse relations
if field.auto_created:
continue
# Copy child m2m relations
# Otherwise add them to the m2m dict
if field.many_to_many:
if isinstance(field, ParentalManyToManyField):
parental_field = getattr(specific_self, field.name)
if hasattr(parental_field, 'all'):
values = parental_field.all()
if values:
specific_dict[field.name] = values
else:
specific_m2m_dict[field.name] = getattr(specific_self, field.name).all()
continue
# Ignore parent links (page_ptr)
if isinstance(field, models.OneToOneField) and field.remote_field.parent_link:
continue
specific_dict[field.name] = getattr(specific_self, field.name)
return specific_dict, specific_m2m_dict
def _get_copy_instance(self, specific_dict, specific_m2m_dict, update_attrs=None, **kwargs):
if not update_attrs:
update_attrs = {}
specific_class = self.specific.__class__
copy_instance = specific_class(**specific_dict)
if update_attrs:
for field, value in update_attrs.items():
if field in specific_m2m_dict:
continue
setattr(copy_instance, field, value)
return copy_instance
def _save_copy_instance(self, instance, **kwargs):
raise NotImplementedError
def _set_m2m_relations(self, instance, specific_m2m_dict, update_attrs=None, **kwargs):
if not update_attrs:
update_attrs = {}
for field_name, value in specific_m2m_dict.items():
value = update_attrs.get(field_name, value)
getattr(instance, field_name).set(value)
return self._save_copy_instance(instance)
def _copy_child_objects_to_instance(self, instance, exclude_fields=None, process_child_object=None, **kwargs):
# A dict that maps child objects to their new ids
# Used to remap child object ids in revisions
child_object_id_map = defaultdict(dict)
exclude_fields = exclude_fields or []
specific_self = self.specific
for child_relation in get_all_child_relations(specific_self):
accessor_name = child_relation.get_accessor_name()
if accessor_name in exclude_fields:
continue
parental_key_name = child_relation.field.attname
child_objects = getattr(specific_self, accessor_name, None)
if child_objects:
for child_object in child_objects.all():
old_pk = child_object.pk
child_object.pk = None
setattr(child_object, parental_key_name, instance.id)
if process_child_object is not None:
process_child_object(specific_self, instance, child_relation, child_object)
child_object.save()
# Add mapping to new primary key (so we can apply this change to revisions)
child_object_id_map[accessor_name][old_pk] = child_object.pk
return child_object_id_map
def _copy(self, exclude_fields=None, update_attrs=None, process_child_object=None, **kwargs):
exclude_fields = self.default_exclude_fields_in_copy + self.specific.exclude_fields_in_copy + (exclude_fields or [])
specific_dict, specific_m2m_dict = self._get_field_dictionaries(exclude_fields=exclude_fields, **kwargs)
copy_instance = self._get_copy_instance(specific_dict, specific_m2m_dict, update_attrs=update_attrs, **kwargs)
copy_instance = self._save_copy_instance(copy_instance, **kwargs)
copy_instance = self._set_m2m_relations(copy_instance, specific_m2m_dict, update_attrs, **kwargs)
child_object_id_map = self._copy_child_objects_to_instance(copy_instance, exclude_fields=exclude_fields, process_child_object=process_child_object, **kwargs)
return copy_instance, child_object_id_map
class SiteManager(models.Manager):
def get_queryset(self):
return super(SiteManager, self).get_queryset().order_by(Lower("hostname"))
@ -254,7 +368,7 @@ class AbstractPage(MP_Node):
abstract = True
class Page(AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase):
class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase):
title = models.CharField(
verbose_name=_('title'),
max_length=255,
@ -387,6 +501,7 @@ class Page(AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase):
# An array of additional field names that will not be included when a Page is copied.
exclude_fields_in_copy = []
default_exclude_fields_in_copy = ['id', 'path', 'depth', 'numchild', 'url_path', 'path', 'index_entries']
# Define these attributes early to avoid masking errors. (Issue #3078)
# The canonical definition is in wagtailadmin.edit_handlers.
@ -1198,99 +1313,26 @@ class Page(AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase):
def copy(self, recursive=False, to=None, update_attrs=None, copy_revisions=True, keep_live=True, user=None,
process_child_object=None, exclude_fields=None):
# Fill dict with self.specific values
specific_self = self.specific
default_exclude_fields = ['id', 'path', 'depth', 'numchild', 'url_path', 'path', 'index_entries']
exclude_fields = default_exclude_fields + specific_self.exclude_fields_in_copy + (exclude_fields or [])
specific_dict = {}
for field in specific_self._meta.get_fields():
# Ignore explicitly excluded fields
if field.name in exclude_fields:
continue
# Ignore reverse relations
if field.auto_created:
continue
# Ignore m2m relations - they will be copied as child objects
# if modelcluster supports them at all (as it does for tags)
if field.many_to_many:
continue
# Ignore parent links (page_ptr)
if isinstance(field, models.OneToOneField) and field.remote_field.parent_link:
continue
specific_dict[field.name] = getattr(specific_self, field.name)
# copy child m2m relations
for related_field in get_all_child_m2m_relations(specific_self):
# Ignore explicitly excluded fields
if related_field.name in exclude_fields:
continue
field = getattr(specific_self, related_field.name)
if field and hasattr(field, 'all'):
values = field.all()
if values:
specific_dict[related_field.name] = values
# New instance from prepared dict values, in case the instance class implements multiple levels inheritance
page_copy = self.specific_class(**specific_dict)
if not keep_live:
page_copy.live = False
page_copy.has_unpublished_changes = True
page_copy.live_revision = None
page_copy.first_published_at = None
page_copy.last_published_at = None
if keep_live:
base_update_attrs = {}
else:
base_update_attrs = {
'live': False,
'has_unpublished_changes': True,
'live_revision': None,
'first_published_at': None,
'last_published_at': None
}
if user:
page_copy.owner = user
base_update_attrs['owner'] = user
if update_attrs:
for field, value in update_attrs.items():
setattr(page_copy, field, value)
base_update_attrs.update(update_attrs)
if to:
if recursive and (to == self or to.is_descendant_of(self)):
raise Exception("You cannot copy a tree branch recursively into itself")
page_copy = to.add_child(instance=page_copy)
else:
page_copy = self.add_sibling(instance=page_copy)
# A dict that maps child objects to their new ids
# Used to remap child object ids in revisions
child_object_id_map = defaultdict(dict)
# Copy child objects
specific_self = self.specific
for child_relation in get_all_child_relations(specific_self):
accessor_name = child_relation.get_accessor_name()
if accessor_name in exclude_fields:
continue
parental_key_name = child_relation.field.attname
child_objects = getattr(specific_self, accessor_name, None)
# Ignore explicitly excluded fields
if accessor_name in exclude_fields:
continue
if child_objects:
for child_object in child_objects.all():
old_pk = child_object.pk
child_object.pk = None
setattr(child_object, parental_key_name, page_copy.id)
if process_child_object is not None:
process_child_object(specific_self, page_copy, child_relation, child_object)
child_object.save()
# Add mapping to new primary key (so we can apply this change to revisions)
child_object_id_map[accessor_name][old_pk] = child_object.pk
page_copy, child_object_id_map = self._copy(exclude_fields=exclude_fields, update_attrs=base_update_attrs, to=to, recursive=recursive, process_child_object=process_child_object)
# Copy revisions
if copy_revisions:
@ -1361,6 +1403,18 @@ class Page(AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase):
return page_copy
def _save_copy_instance(self, instance, to=None, recursive=False, **kwargs):
if not instance.id:
if to:
if recursive and (to == self or to.is_descendant_of(self)):
raise Exception("You cannot copy a tree branch recursively into itself")
instance = to.add_child(instance=instance)
else:
instance = self.add_sibling(instance=instance)
else:
instance.save()
return instance
copy.alters_data = True
def permissions_for_user(self, user):
@ -2785,7 +2839,7 @@ class WorkflowState(models.Model):
]
class TaskState(models.Model):
class TaskState(MultiTableCopyMixin, models.Model):
"""Tracks the status of a given Task for a particular page revision."""
STATUS_IN_PROGRESS = 'in_progress'
STATUS_APPROVED = 'approved'
@ -2898,54 +2952,12 @@ class TaskState(models.Model):
return self
def copy(self, update_attrs=None, exclude_fields=None):
if not update_attrs:
update_attrs = {}
# Fill dict with self.specific values
specific_self = self.specific
default_exclude_fields = ['id']
exclude_fields = default_exclude_fields + specific_self.exclude_fields_in_copy + (exclude_fields or [])
specific_dict = {}
specific_many_to_many_dict = {}
copy_instance, _ = self._copy(exclude_fields, update_attrs)
return copy_instance
for field in specific_self._meta.get_fields():
# Ignore explicitly excluded fields
if field.name in exclude_fields:
continue
# Ignore reverse relations
if field.auto_created:
continue
if field.many_to_many:
specific_many_to_many_dict[field.name] = getattr(specific_self, field.name).all()
continue
# Ignore parent links
if isinstance(field, models.OneToOneField) and field.remote_field.parent_link:
continue
specific_dict[field.name] = getattr(specific_self, field.name)
# New instance from prepared dict values, in case the instance class implements multiple levels inheritance
task_state_copy = specific_self.__class__(**specific_dict)
if update_attrs:
for field, value in update_attrs.items():
if field in specific_many_to_many_dict:
continue
setattr(task_state_copy, field, value)
task_state_copy.save()
# Set many to many fields
for field_name, value in specific_many_to_many_dict.items():
value = update_attrs.get(field_name, value)
getattr(task_state_copy, field_name).set(value)
task_state_copy.save()
return task_state_copy
def _save_copy_instance(self, instance, **kwargs):
instance.save()
return instance
class Meta:
verbose_name = _('Task state')