mirror of
https://github.com/wagtail/wagtail.git
synced 2024-11-30 01:46:24 +01:00
Refactor common copy method mixin from Page and TaskState
This commit is contained in:
parent
0d5c2d6a30
commit
02a706581e
@ -25,7 +25,7 @@ from django.utils.functional import cached_property
|
||||
from django.utils.module_loading import import_string
|
||||
from django.utils.text import capfirst, slugify
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from modelcluster.fields import ParentalKey
|
||||
from modelcluster.fields import ParentalKey, ParentalManyToManyField
|
||||
from modelcluster.models import (
|
||||
ClusterableModel, get_all_child_m2m_relations, get_all_child_relations)
|
||||
from treebeard.mp_tree import MP_Node
|
||||
@ -44,6 +44,120 @@ logger = logging.getLogger('wagtail.core')
|
||||
PAGE_TEMPLATE_VAR = 'page'
|
||||
|
||||
|
||||
class MultiTableCopyMixin:
|
||||
default_exclude_fields_in_copy = ['id']
|
||||
|
||||
def _get_field_dictionaries(self, exclude_fields=None, **kwargs):
|
||||
specific_self = self.specific
|
||||
exclude_fields = exclude_fields or []
|
||||
specific_dict = {}
|
||||
specific_m2m_dict = {}
|
||||
|
||||
for field in specific_self._meta.get_fields():
|
||||
# Ignore explicitly excluded fields
|
||||
if field.name in exclude_fields:
|
||||
continue
|
||||
|
||||
# Ignore reverse relations
|
||||
if field.auto_created:
|
||||
continue
|
||||
|
||||
# Copy child m2m relations
|
||||
# Otherwise add them to the m2m dict
|
||||
if field.many_to_many:
|
||||
if isinstance(field, ParentalManyToManyField):
|
||||
parental_field = getattr(specific_self, field.name)
|
||||
if hasattr(parental_field, 'all'):
|
||||
values = parental_field.all()
|
||||
if values:
|
||||
specific_dict[field.name] = values
|
||||
else:
|
||||
specific_m2m_dict[field.name] = getattr(specific_self, field.name).all()
|
||||
continue
|
||||
|
||||
# Ignore parent links (page_ptr)
|
||||
if isinstance(field, models.OneToOneField) and field.remote_field.parent_link:
|
||||
continue
|
||||
|
||||
specific_dict[field.name] = getattr(specific_self, field.name)
|
||||
|
||||
return specific_dict, specific_m2m_dict
|
||||
|
||||
def _get_copy_instance(self, specific_dict, specific_m2m_dict, update_attrs=None, **kwargs):
|
||||
if not update_attrs:
|
||||
update_attrs = {}
|
||||
|
||||
specific_class = self.specific.__class__
|
||||
|
||||
copy_instance = specific_class(**specific_dict)
|
||||
|
||||
if update_attrs:
|
||||
for field, value in update_attrs.items():
|
||||
if field in specific_m2m_dict:
|
||||
continue
|
||||
setattr(copy_instance, field, value)
|
||||
|
||||
return copy_instance
|
||||
|
||||
def _save_copy_instance(self, instance, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def _set_m2m_relations(self, instance, specific_m2m_dict, update_attrs=None, **kwargs):
|
||||
if not update_attrs:
|
||||
update_attrs = {}
|
||||
for field_name, value in specific_m2m_dict.items():
|
||||
value = update_attrs.get(field_name, value)
|
||||
getattr(instance, field_name).set(value)
|
||||
|
||||
return self._save_copy_instance(instance)
|
||||
|
||||
def _copy_child_objects_to_instance(self, instance, exclude_fields=None, process_child_object=None, **kwargs):
|
||||
# A dict that maps child objects to their new ids
|
||||
# Used to remap child object ids in revisions
|
||||
child_object_id_map = defaultdict(dict)
|
||||
exclude_fields = exclude_fields or []
|
||||
specific_self = self.specific
|
||||
for child_relation in get_all_child_relations(specific_self):
|
||||
accessor_name = child_relation.get_accessor_name()
|
||||
|
||||
if accessor_name in exclude_fields:
|
||||
continue
|
||||
|
||||
parental_key_name = child_relation.field.attname
|
||||
child_objects = getattr(specific_self, accessor_name, None)
|
||||
|
||||
if child_objects:
|
||||
for child_object in child_objects.all():
|
||||
old_pk = child_object.pk
|
||||
child_object.pk = None
|
||||
setattr(child_object, parental_key_name, instance.id)
|
||||
|
||||
if process_child_object is not None:
|
||||
process_child_object(specific_self, instance, child_relation, child_object)
|
||||
|
||||
child_object.save()
|
||||
|
||||
# Add mapping to new primary key (so we can apply this change to revisions)
|
||||
child_object_id_map[accessor_name][old_pk] = child_object.pk
|
||||
|
||||
return child_object_id_map
|
||||
|
||||
def _copy(self, exclude_fields=None, update_attrs=None, process_child_object=None, **kwargs):
|
||||
exclude_fields = self.default_exclude_fields_in_copy + self.specific.exclude_fields_in_copy + (exclude_fields or [])
|
||||
|
||||
specific_dict, specific_m2m_dict = self._get_field_dictionaries(exclude_fields=exclude_fields, **kwargs)
|
||||
|
||||
copy_instance = self._get_copy_instance(specific_dict, specific_m2m_dict, update_attrs=update_attrs, **kwargs)
|
||||
|
||||
copy_instance = self._save_copy_instance(copy_instance, **kwargs)
|
||||
|
||||
copy_instance = self._set_m2m_relations(copy_instance, specific_m2m_dict, update_attrs, **kwargs)
|
||||
|
||||
child_object_id_map = self._copy_child_objects_to_instance(copy_instance, exclude_fields=exclude_fields, process_child_object=process_child_object, **kwargs)
|
||||
|
||||
return copy_instance, child_object_id_map
|
||||
|
||||
|
||||
class SiteManager(models.Manager):
|
||||
def get_queryset(self):
|
||||
return super(SiteManager, self).get_queryset().order_by(Lower("hostname"))
|
||||
@ -254,7 +368,7 @@ class AbstractPage(MP_Node):
|
||||
abstract = True
|
||||
|
||||
|
||||
class Page(AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase):
|
||||
class Page(MultiTableCopyMixin, AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase):
|
||||
title = models.CharField(
|
||||
verbose_name=_('title'),
|
||||
max_length=255,
|
||||
@ -387,6 +501,7 @@ class Page(AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase):
|
||||
|
||||
# An array of additional field names that will not be included when a Page is copied.
|
||||
exclude_fields_in_copy = []
|
||||
default_exclude_fields_in_copy = ['id', 'path', 'depth', 'numchild', 'url_path', 'path', 'index_entries']
|
||||
|
||||
# Define these attributes early to avoid masking errors. (Issue #3078)
|
||||
# The canonical definition is in wagtailadmin.edit_handlers.
|
||||
@ -1198,99 +1313,26 @@ class Page(AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase):
|
||||
|
||||
def copy(self, recursive=False, to=None, update_attrs=None, copy_revisions=True, keep_live=True, user=None,
|
||||
process_child_object=None, exclude_fields=None):
|
||||
# Fill dict with self.specific values
|
||||
|
||||
specific_self = self.specific
|
||||
default_exclude_fields = ['id', 'path', 'depth', 'numchild', 'url_path', 'path', 'index_entries']
|
||||
exclude_fields = default_exclude_fields + specific_self.exclude_fields_in_copy + (exclude_fields or [])
|
||||
specific_dict = {}
|
||||
|
||||
for field in specific_self._meta.get_fields():
|
||||
# Ignore explicitly excluded fields
|
||||
if field.name in exclude_fields:
|
||||
continue
|
||||
|
||||
# Ignore reverse relations
|
||||
if field.auto_created:
|
||||
continue
|
||||
|
||||
# Ignore m2m relations - they will be copied as child objects
|
||||
# if modelcluster supports them at all (as it does for tags)
|
||||
if field.many_to_many:
|
||||
continue
|
||||
|
||||
# Ignore parent links (page_ptr)
|
||||
if isinstance(field, models.OneToOneField) and field.remote_field.parent_link:
|
||||
continue
|
||||
|
||||
specific_dict[field.name] = getattr(specific_self, field.name)
|
||||
|
||||
# copy child m2m relations
|
||||
for related_field in get_all_child_m2m_relations(specific_self):
|
||||
# Ignore explicitly excluded fields
|
||||
if related_field.name in exclude_fields:
|
||||
continue
|
||||
|
||||
field = getattr(specific_self, related_field.name)
|
||||
if field and hasattr(field, 'all'):
|
||||
values = field.all()
|
||||
if values:
|
||||
specific_dict[related_field.name] = values
|
||||
|
||||
# New instance from prepared dict values, in case the instance class implements multiple levels inheritance
|
||||
page_copy = self.specific_class(**specific_dict)
|
||||
|
||||
if not keep_live:
|
||||
page_copy.live = False
|
||||
page_copy.has_unpublished_changes = True
|
||||
page_copy.live_revision = None
|
||||
page_copy.first_published_at = None
|
||||
page_copy.last_published_at = None
|
||||
if keep_live:
|
||||
base_update_attrs = {}
|
||||
else:
|
||||
base_update_attrs = {
|
||||
'live': False,
|
||||
'has_unpublished_changes': True,
|
||||
'live_revision': None,
|
||||
'first_published_at': None,
|
||||
'last_published_at': None
|
||||
}
|
||||
|
||||
if user:
|
||||
page_copy.owner = user
|
||||
base_update_attrs['owner'] = user
|
||||
|
||||
if update_attrs:
|
||||
for field, value in update_attrs.items():
|
||||
setattr(page_copy, field, value)
|
||||
base_update_attrs.update(update_attrs)
|
||||
|
||||
if to:
|
||||
if recursive and (to == self or to.is_descendant_of(self)):
|
||||
raise Exception("You cannot copy a tree branch recursively into itself")
|
||||
page_copy = to.add_child(instance=page_copy)
|
||||
else:
|
||||
page_copy = self.add_sibling(instance=page_copy)
|
||||
|
||||
# A dict that maps child objects to their new ids
|
||||
# Used to remap child object ids in revisions
|
||||
child_object_id_map = defaultdict(dict)
|
||||
|
||||
# Copy child objects
|
||||
specific_self = self.specific
|
||||
for child_relation in get_all_child_relations(specific_self):
|
||||
accessor_name = child_relation.get_accessor_name()
|
||||
|
||||
if accessor_name in exclude_fields:
|
||||
continue
|
||||
|
||||
parental_key_name = child_relation.field.attname
|
||||
child_objects = getattr(specific_self, accessor_name, None)
|
||||
# Ignore explicitly excluded fields
|
||||
if accessor_name in exclude_fields:
|
||||
continue
|
||||
|
||||
if child_objects:
|
||||
for child_object in child_objects.all():
|
||||
old_pk = child_object.pk
|
||||
child_object.pk = None
|
||||
setattr(child_object, parental_key_name, page_copy.id)
|
||||
|
||||
if process_child_object is not None:
|
||||
process_child_object(specific_self, page_copy, child_relation, child_object)
|
||||
|
||||
child_object.save()
|
||||
|
||||
# Add mapping to new primary key (so we can apply this change to revisions)
|
||||
child_object_id_map[accessor_name][old_pk] = child_object.pk
|
||||
page_copy, child_object_id_map = self._copy(exclude_fields=exclude_fields, update_attrs=base_update_attrs, to=to, recursive=recursive, process_child_object=process_child_object)
|
||||
|
||||
# Copy revisions
|
||||
if copy_revisions:
|
||||
@ -1361,6 +1403,18 @@ class Page(AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase):
|
||||
|
||||
return page_copy
|
||||
|
||||
def _save_copy_instance(self, instance, to=None, recursive=False, **kwargs):
|
||||
if not instance.id:
|
||||
if to:
|
||||
if recursive and (to == self or to.is_descendant_of(self)):
|
||||
raise Exception("You cannot copy a tree branch recursively into itself")
|
||||
instance = to.add_child(instance=instance)
|
||||
else:
|
||||
instance = self.add_sibling(instance=instance)
|
||||
else:
|
||||
instance.save()
|
||||
return instance
|
||||
|
||||
copy.alters_data = True
|
||||
|
||||
def permissions_for_user(self, user):
|
||||
@ -2785,7 +2839,7 @@ class WorkflowState(models.Model):
|
||||
]
|
||||
|
||||
|
||||
class TaskState(models.Model):
|
||||
class TaskState(MultiTableCopyMixin, models.Model):
|
||||
"""Tracks the status of a given Task for a particular page revision."""
|
||||
STATUS_IN_PROGRESS = 'in_progress'
|
||||
STATUS_APPROVED = 'approved'
|
||||
@ -2898,54 +2952,12 @@ class TaskState(models.Model):
|
||||
return self
|
||||
|
||||
def copy(self, update_attrs=None, exclude_fields=None):
|
||||
if not update_attrs:
|
||||
update_attrs = {}
|
||||
# Fill dict with self.specific values
|
||||
specific_self = self.specific
|
||||
default_exclude_fields = ['id']
|
||||
exclude_fields = default_exclude_fields + specific_self.exclude_fields_in_copy + (exclude_fields or [])
|
||||
specific_dict = {}
|
||||
specific_many_to_many_dict = {}
|
||||
copy_instance, _ = self._copy(exclude_fields, update_attrs)
|
||||
return copy_instance
|
||||
|
||||
for field in specific_self._meta.get_fields():
|
||||
# Ignore explicitly excluded fields
|
||||
if field.name in exclude_fields:
|
||||
continue
|
||||
|
||||
# Ignore reverse relations
|
||||
if field.auto_created:
|
||||
continue
|
||||
|
||||
if field.many_to_many:
|
||||
specific_many_to_many_dict[field.name] = getattr(specific_self, field.name).all()
|
||||
continue
|
||||
|
||||
# Ignore parent links
|
||||
if isinstance(field, models.OneToOneField) and field.remote_field.parent_link:
|
||||
continue
|
||||
|
||||
specific_dict[field.name] = getattr(specific_self, field.name)
|
||||
|
||||
# New instance from prepared dict values, in case the instance class implements multiple levels inheritance
|
||||
task_state_copy = specific_self.__class__(**specific_dict)
|
||||
|
||||
if update_attrs:
|
||||
for field, value in update_attrs.items():
|
||||
if field in specific_many_to_many_dict:
|
||||
continue
|
||||
setattr(task_state_copy, field, value)
|
||||
|
||||
task_state_copy.save()
|
||||
|
||||
# Set many to many fields
|
||||
|
||||
for field_name, value in specific_many_to_many_dict.items():
|
||||
value = update_attrs.get(field_name, value)
|
||||
getattr(task_state_copy, field_name).set(value)
|
||||
|
||||
task_state_copy.save()
|
||||
|
||||
return task_state_copy
|
||||
def _save_copy_instance(self, instance, **kwargs):
|
||||
instance.save()
|
||||
return instance
|
||||
|
||||
class Meta:
|
||||
verbose_name = _('Task state')
|
||||
|
Loading…
Reference in New Issue
Block a user