diff --git a/wagtail/models/__init__.py b/wagtail/models/__init__.py index 42a3905f47..99412ef6ea 100644 --- a/wagtail/models/__init__.py +++ b/wagtail/models/__init__.py @@ -905,6 +905,21 @@ class WorkflowMixin: """Returns the active workflow assigned to the object.""" return self.get_default_workflow() + @property + def workflow_states(self): + """ + Returns workflow states that belong to the object. + + To allow filtering ``WorkflowState`` queries by the object, + subclasses should define a + :class:`~django.contrib.contenttypes.fields.GenericRelation` to + :class:`~wagtail.models.WorkflowState` with the desired + ``related_query_name``. This property can be replaced with the + ``GenericRelation`` or overridden to allow custom logic, which can be + useful if the model has inheritance. + """ + return WorkflowState.objects.for_instance(self) + @property def workflow_in_progress(self): """Returns True if a workflow is in progress on the current object, otherwise False.""" @@ -919,11 +934,9 @@ class WorkflowMixin: return True return False - return ( - WorkflowState.objects.for_instance(self) - .filter(status=WorkflowState.STATUS_IN_PROGRESS) - .exists() - ) + return self.workflow_states.filter( + status=WorkflowState.STATUS_IN_PROGRESS + ).exists() @property def current_workflow_state(self): @@ -940,8 +953,7 @@ class WorkflowMixin: return return ( - WorkflowState.objects.for_instance(self) - .active() + self.workflow_states.active() .select_related("current_task_state__task") .first() ) @@ -1074,7 +1086,11 @@ class Page(AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase): _revisions = GenericRelation("wagtailcore.Revision", related_query_name="page") - workflow_states = GenericRelation( + # Add GenericRelation to allow WorkflowState.objects.filter(page=...) queries. + # There is no need to override the workflow_states property, as the default + # implementation in WorkflowMixin already ensures that the queryset uses the + # base Page content type. + _workflow_states = GenericRelation( "wagtailcore.WorkflowState", content_type_field="base_content_type", object_id_field="object_id", diff --git a/wagtail/query.py b/wagtail/query.py index d3ecff938c..bcce4aa2f0 100644 --- a/wagtail/query.py +++ b/wagtail/query.py @@ -455,7 +455,7 @@ class PageQuerySet(SearchableQuerySetMixin, TreeQuerySet): return self.prefetch_related( Prefetch( - "workflow_states", + "_workflow_states", queryset=workflow_states, to_attr="_current_workflow_states", )