mirror of
https://github.com/PostHog/posthog.git
synced 2024-11-28 09:16:49 +01:00
Part 1: Make everything taggable Backend (starting with Actions) (#8528)
This commit is contained in:
parent
4f630a31f4
commit
459d304e95
71
ee/api/test/test_action.py
Normal file
71
ee/api/test/test_action.py
Normal file
@ -0,0 +1,71 @@
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from django.utils import timezone
|
||||
from rest_framework import status
|
||||
|
||||
from posthog.models import Action, Tag
|
||||
from posthog.test.base import APIBaseTest
|
||||
|
||||
# Testing enterprise properties of actions here (i.e., tagging).
|
||||
|
||||
|
||||
@pytest.mark.ee
|
||||
class TestActionApi(APIBaseTest):
|
||||
def test_create_action_update_delete_tags(self):
|
||||
from ee.models.license import License, LicenseManager
|
||||
|
||||
super(LicenseManager, cast(LicenseManager, License.objects)).create(
|
||||
key="key_123", plan="enterprise", valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7), max_users=3,
|
||||
)
|
||||
|
||||
response = self.client.post(f"/api/projects/{self.team.id}/actions/", data={"name": "user signed up",},)
|
||||
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
|
||||
self.assertEqual(response.json()["tags"], [])
|
||||
|
||||
response = self.client.patch(
|
||||
f"/api/projects/{self.team.id}/actions/{response.json()['id']}",
|
||||
data={"name": "user signed up", "tags": ["hello", "random"]},
|
||||
)
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(set(response.json()["tags"]), {"hello", "random"})
|
||||
|
||||
response = self.client.patch(
|
||||
f"/api/projects/{self.team.id}/actions/{response.json()['id']}", data={"name": "user signed up", "tags": []}
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.json()["tags"], [])
|
||||
|
||||
def test_create_action_with_tags(self):
|
||||
from ee.models.license import License, LicenseManager
|
||||
|
||||
super(LicenseManager, cast(LicenseManager, License.objects)).create(
|
||||
key="key_123", plan="enterprise", valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7), max_users=3,
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
f"/api/projects/{self.team.id}/actions/",
|
||||
data={"name": "user signed up", "tags": ["nightly", "is", "a", "good", "girl"]},
|
||||
)
|
||||
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
|
||||
self.assertEqual(set(response.json()["tags"]), {"nightly", "is", "a", "good", "girl"})
|
||||
|
||||
def test_actions_does_not_nplus1(self):
|
||||
from ee.models.license import License, LicenseManager
|
||||
|
||||
super(LicenseManager, cast(LicenseManager, License.objects)).create(
|
||||
key="key_123", plan="enterprise", valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7), max_users=3,
|
||||
)
|
||||
|
||||
tag = Tag.objects.create(name="tag", team=self.team)
|
||||
for i in range(20):
|
||||
action = Action.objects.create(team=self.team, name=f"action_{i}")
|
||||
action.tagged_items.create(tag=tag)
|
||||
|
||||
# django_session + user + team + organizationmembership + organization + action + taggeditem + actionstep
|
||||
with self.assertNumQueries(8):
|
||||
response = self.client.get(f"/api/projects/{self.team.id}/actions")
|
||||
self.assertEqual(response.json()["results"][0]["tags"][0], "tag")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(len(response.json()["results"]), 20)
|
@ -64,6 +64,7 @@ class License(models.Model):
|
||||
AvailableFeature.GROUP_ANALYTICS,
|
||||
AvailableFeature.MULTIVARIATE_FLAGS,
|
||||
AvailableFeature.EXPERIMENTATION,
|
||||
AvailableFeature.TAGGING,
|
||||
]
|
||||
|
||||
ENTERPRISE_PLAN = "enterprise"
|
||||
|
@ -38,6 +38,7 @@ export enum AvailableFeature {
|
||||
GROUP_ANALYTICS = 'group_analytics',
|
||||
MULTIVARIATE_FLAGS = 'multivariate_flags',
|
||||
EXPERIMENTATION = 'experimentation',
|
||||
TAGGING = 'tagging',
|
||||
}
|
||||
|
||||
export enum Realm {
|
||||
|
@ -4,7 +4,7 @@ axes: 0006_remove_accesslog_trusted
|
||||
contenttypes: 0002_remove_content_type_name
|
||||
database: 0002_auto_20190129_2304
|
||||
ee: 0007_dashboard_permissions
|
||||
posthog: 0205_auto_20220204_1748
|
||||
posthog: 0206_global_tags_setup
|
||||
rest_hooks: 0002_swappable_hook_model
|
||||
sessions: 0001_initial
|
||||
social_django: 0010_uid_db_index
|
||||
|
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Type, Union, cast
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from django.core.cache import cache
|
||||
@ -24,16 +24,9 @@ from posthog.api.shared import UserBasicSerializer
|
||||
from posthog.api.utils import get_target_entity
|
||||
from posthog.auth import PersonalAPIKeyAuthentication, TemporaryTokenAuthentication
|
||||
from posthog.celery import update_cache_item_task
|
||||
from posthog.constants import (
|
||||
INSIGHT_STICKINESS,
|
||||
TREND_FILTER_TYPE_ACTIONS,
|
||||
TREND_FILTER_TYPE_EVENTS,
|
||||
TRENDS_STICKINESS,
|
||||
AvailableFeature,
|
||||
)
|
||||
from posthog.constants import INSIGHT_STICKINESS, TREND_FILTER_TYPE_ACTIONS, TREND_FILTER_TYPE_EVENTS, TRENDS_STICKINESS
|
||||
from posthog.decorators import CacheType, cached_function
|
||||
from posthog.event_usage import report_user_action
|
||||
from posthog.filters import term_search_filter_sql
|
||||
from posthog.models import (
|
||||
Action,
|
||||
ActionStep,
|
||||
@ -54,6 +47,7 @@ from posthog.queries import base, retention, stickiness, trends
|
||||
from posthog.utils import generate_cache_key, get_safe_cache, should_refresh
|
||||
|
||||
from .person import get_person_name
|
||||
from .tagged_item import TaggedItemSerializerMixin, TaggedItemViewSetMixin
|
||||
|
||||
|
||||
class ActionStepSerializer(serializers.HyperlinkedModelSerializer):
|
||||
@ -82,7 +76,7 @@ class ActionStepSerializer(serializers.HyperlinkedModelSerializer):
|
||||
}
|
||||
|
||||
|
||||
class ActionSerializer(serializers.HyperlinkedModelSerializer):
|
||||
class ActionSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedModelSerializer):
|
||||
steps = ActionStepSerializer(many=True, required=False)
|
||||
created_by = UserBasicSerializer(read_only=True)
|
||||
is_calculating = serializers.SerializerMethodField()
|
||||
@ -93,6 +87,7 @@ class ActionSerializer(serializers.HyperlinkedModelSerializer):
|
||||
"id",
|
||||
"name",
|
||||
"description",
|
||||
"tags",
|
||||
"post_to_slack",
|
||||
"slack_message_format",
|
||||
"steps",
|
||||
@ -179,7 +174,7 @@ class ActionSerializer(serializers.HyperlinkedModelSerializer):
|
||||
return instance
|
||||
|
||||
|
||||
class ActionViewSet(StructuredViewSetMixin, viewsets.ModelViewSet):
|
||||
class ActionViewSet(TaggedItemViewSetMixin, StructuredViewSetMixin, viewsets.ModelViewSet):
|
||||
renderer_classes = tuple(api_settings.DEFAULT_RENDERER_CLASSES) + (csvrenderers.PaginatedCSVRenderer,)
|
||||
queryset = Action.objects.all()
|
||||
serializer_class = ActionSerializer
|
||||
|
86
posthog/api/tagged_item.py
Normal file
86
posthog/api/tagged_item.py
Normal file
@ -0,0 +1,86 @@
|
||||
from django.db.models import Prefetch, Q
|
||||
from rest_framework import serializers, viewsets
|
||||
|
||||
from posthog.constants import AvailableFeature
|
||||
from posthog.exceptions import EnterpriseFeatureException
|
||||
from posthog.models import Tag, TaggedItem
|
||||
|
||||
|
||||
class TaggedItemSerializerMixin(serializers.Serializer):
|
||||
"""
|
||||
Serializer mixin that resolves appropriate response for tags depending on license.
|
||||
"""
|
||||
|
||||
tags = serializers.ListField(required=False)
|
||||
|
||||
def _is_licensed(self):
|
||||
return (
|
||||
"request" in self.context
|
||||
and not self.context["request"].user.is_anonymous
|
||||
and self.context["request"].user.organization.is_feature_available(AvailableFeature.TAGGING)
|
||||
)
|
||||
|
||||
def _attempt_set_tags(self, tags, obj):
|
||||
|
||||
if not self._is_licensed() and tags is not None:
|
||||
raise EnterpriseFeatureException()
|
||||
|
||||
if not obj or tags is None:
|
||||
# If the object hasn't been created yet, this method will be called again on the create method.
|
||||
return
|
||||
|
||||
# Normalize and dedupe tags
|
||||
deduped_tags = list(set([t.strip().lower() for t in tags]))
|
||||
tagged_item_objects = []
|
||||
|
||||
# Create tags
|
||||
for tag in deduped_tags:
|
||||
tag_instance, _ = Tag.objects.get_or_create(name=tag, team_id=obj.team_id)
|
||||
tagged_item_instance, _ = obj.tagged_items.get_or_create(tag_id=tag_instance.id)
|
||||
tagged_item_objects.append(tagged_item_instance)
|
||||
|
||||
# Delete tags that are missing
|
||||
obj.tagged_items.exclude(tag__name__in=deduped_tags).delete()
|
||||
|
||||
# Cleanup tags that aren't used by team
|
||||
Tag.objects.filter(Q(team_id=obj.team_id) & Q(tagged_items__isnull=True)).delete()
|
||||
|
||||
obj.prefetched_tags = tagged_item_objects
|
||||
|
||||
def to_representation(self, obj):
|
||||
ret = super(TaggedItemSerializerMixin, self).to_representation(obj)
|
||||
ret["tags"] = []
|
||||
if self._is_licensed():
|
||||
if hasattr(obj, "prefetched_tags"):
|
||||
ret["tags"] = [p.tag.name for p in obj.prefetched_tags]
|
||||
else:
|
||||
ret["tags"] = list(obj.tagged_items.values_list("tag__name", flat=True)) if obj.tagged_items else []
|
||||
return ret
|
||||
|
||||
def create(self, validated_data):
|
||||
validated_data.pop("tags", None)
|
||||
instance = super(TaggedItemSerializerMixin, self).create(validated_data)
|
||||
self._attempt_set_tags(self.initial_data.get("tags"), instance)
|
||||
return instance
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
instance = super(TaggedItemSerializerMixin, self).update(instance, validated_data)
|
||||
self._attempt_set_tags(self.initial_data.get("tags"), instance)
|
||||
return instance
|
||||
|
||||
|
||||
class TaggedItemViewSetMixin(viewsets.GenericViewSet):
|
||||
def is_licensed(self):
|
||||
return (
|
||||
not self.request.user.is_anonymous
|
||||
# The below triggers an extra query to resolve user's organization.
|
||||
and self.request.user.organization.is_feature_available(AvailableFeature.TAGGING) # type: ignore
|
||||
)
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = super(TaggedItemViewSetMixin, self).get_queryset()
|
||||
if self.is_licensed():
|
||||
return queryset.prefetch_related(
|
||||
Prefetch("tagged_items", queryset=TaggedItem.objects.select_related("tag"), to_attr="prefetched_tags")
|
||||
)
|
||||
return queryset
|
@ -6,7 +6,7 @@ from rest_framework import status
|
||||
|
||||
from ee.clickhouse.models.event import create_event
|
||||
from ee.clickhouse.util import ClickhouseTestMixin
|
||||
from posthog.models import Action, ActionStep, Element, Event, Organization
|
||||
from posthog.models import Action, ActionStep, Element, Event, Organization, Tag
|
||||
from posthog.test.base import APIBaseTest
|
||||
|
||||
|
||||
@ -166,8 +166,8 @@ class TestActionApi(ClickhouseTestMixin, APIBaseTest):
|
||||
)
|
||||
|
||||
# test queries
|
||||
with self.assertNumQueries(6):
|
||||
# Django session, PostHog user, PostHog team, PostHog org membership,
|
||||
with self.assertNumQueries(7):
|
||||
# Django session, PostHog user, PostHog team, PostHog org membership, PostHog org
|
||||
# PostHog action, PostHog action step
|
||||
self.client.get(f"/api/projects/{self.team.id}/actions/")
|
||||
|
||||
@ -280,3 +280,63 @@ class TestActionApi(ClickhouseTestMixin, APIBaseTest):
|
||||
action.calculate_events()
|
||||
response = self.client.get(f"/api/projects/{self.team.id}/actions/{action.id}/count").json()
|
||||
self.assertEqual(response, {"count": 1})
|
||||
|
||||
def test_get_tags_on_non_ee_returns_empty_list(self):
|
||||
action = Action.objects.create(team=self.team, name="bla")
|
||||
tag = Tag.objects.create(name="random", team_id=self.team.id)
|
||||
action.tagged_items.create(tag_id=tag.id)
|
||||
|
||||
response = self.client.get(f"/api/projects/{self.team.id}/actions/{action.id}")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.json()["tags"], [])
|
||||
self.assertEqual(Action.objects.all().count(), 1)
|
||||
|
||||
def test_create_tags_on_non_ee_not_allowed(self):
|
||||
response = self.client.post(
|
||||
f"/api/projects/{self.team.id}/actions/", {"name": "Default", "tags": ["random", "hello"]},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_402_PAYMENT_REQUIRED)
|
||||
self.assertEqual(Tag.objects.all().count(), 0)
|
||||
|
||||
def test_update_tags_on_non_ee_not_allowed(self):
|
||||
action = Action.objects.create(team_id=self.team.id, name="private dashboard")
|
||||
tag = Tag.objects.create(name="random", team_id=self.team.id)
|
||||
action.tagged_items.create(tag_id=tag.id)
|
||||
|
||||
response = self.client.patch(
|
||||
f"/api/projects/{self.team.id}/actions/{action.id}",
|
||||
{"name": "action new name", "tags": ["random", "hello"], "description": "Internal system metrics.",},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_402_PAYMENT_REQUIRED)
|
||||
|
||||
def test_undefined_tags_allows_other_props_to_update(self):
|
||||
action = Action.objects.create(team_id=self.team.id, name="private action")
|
||||
tag = Tag.objects.create(name="random", team_id=self.team.id)
|
||||
action.tagged_items.create(tag_id=tag.id)
|
||||
|
||||
response = self.client.patch(
|
||||
f"/api/projects/{self.team.id}/actions/{action.id}",
|
||||
{"name": "action new name", "description": "Internal system metrics.",},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.json()["name"], "action new name")
|
||||
self.assertEqual(response.json()["description"], "Internal system metrics.")
|
||||
|
||||
def test_empty_tags_does_not_delete_tags(self):
|
||||
action = Action.objects.create(team_id=self.team.id, name="private dashboard")
|
||||
tag = Tag.objects.create(name="random", team_id=self.team.id)
|
||||
action.tagged_items.create(tag_id=tag.id)
|
||||
|
||||
self.assertEqual(Action.objects.all().count(), 1)
|
||||
|
||||
response = self.client.patch(
|
||||
f"/api/projects/{self.team.id}/actions/{action.id}",
|
||||
{"name": "action new name", "description": "Internal system metrics.", "tags": []},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_402_PAYMENT_REQUIRED)
|
||||
self.assertEqual(Action.objects.all().count(), 1)
|
||||
|
@ -1,27 +1,14 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from typing import Any, List, Optional, Tuple, Union, cast
|
||||
|
||||
from rest_framework import request, status
|
||||
from sentry_sdk import capture_exception
|
||||
from statshog.defaults.django import statsd
|
||||
|
||||
from posthog.constants import ENTITY_ID, ENTITY_MATH, ENTITY_TYPE
|
||||
from posthog.exceptions import RequestParsingError, generate_exception_response
|
||||
from posthog.models import Entity
|
||||
from posthog.models.action import Action
|
||||
from posthog.models.entity import MATH_TYPE
|
||||
from posthog.models.event import Event
|
||||
from posthog.models.filters.filter import Filter
|
||||
from posthog.models.filters.stickiness_filter import StickinessFilter
|
||||
from posthog.models.team import Team
|
||||
|
@ -18,6 +18,7 @@ class AvailableFeature(str, Enum):
|
||||
GROUP_ANALYTICS = "group_analytics"
|
||||
MULTIVARIATE_FLAGS = "multivariate_flags"
|
||||
EXPERIMENTATION = "experimentation"
|
||||
TAGGING = "tagging"
|
||||
|
||||
|
||||
TREND_FILTER_TYPE_ACTIONS = "actions"
|
||||
|
64
posthog/migrations/0206_global_tags_setup.py
Normal file
64
posthog/migrations/0206_global_tags_setup.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Generated by Django 3.2.5 on 2022-02-11 23:56
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import posthog.models.utils
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("posthog", "0205_auto_20220204_1748"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="Tag",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=posthog.models.utils.UUIDT, editable=False, primary_key=True, serialize=False
|
||||
),
|
||||
),
|
||||
("name", models.CharField(max_length=255)),
|
||||
("team", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="posthog.team")),
|
||||
],
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="TaggedItem",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=posthog.models.utils.UUIDT, editable=False, primary_key=True, serialize=False
|
||||
),
|
||||
),
|
||||
(
|
||||
"action",
|
||||
models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name="tagged_items",
|
||||
to="posthog.action",
|
||||
),
|
||||
),
|
||||
(
|
||||
"tag",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, related_name="tagged_items", to="posthog.tag"
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
migrations.AddConstraint(
|
||||
model_name="taggeditem",
|
||||
constraint=models.CheckConstraint(
|
||||
check=models.Q(models.Q(("action__isnull", False)), _connector="OR"), name="exactly_one_related_object"
|
||||
),
|
||||
),
|
||||
migrations.AlterUniqueTogether(name="taggeditem", unique_together={("tag", "action")},),
|
||||
migrations.AlterUniqueTogether(name="tag", unique_together={("name", "team")},),
|
||||
]
|
@ -23,6 +23,8 @@ from .plugin import Plugin, PluginAttachment, PluginConfig, PluginLogEntry
|
||||
from .property import Property
|
||||
from .property_definition import PropertyDefinition
|
||||
from .session_recording_event import SessionRecordingEvent
|
||||
from .tag import Tag
|
||||
from .tagged_item import TaggedItem
|
||||
from .team import Team
|
||||
from .user import User, UserManager
|
||||
|
||||
@ -59,6 +61,8 @@ __all__ = [
|
||||
"Property",
|
||||
"PropertyDefinition",
|
||||
"SessionRecordingEvent",
|
||||
"Tag",
|
||||
"TaggedItem",
|
||||
"Team",
|
||||
"User",
|
||||
"UserManager",
|
||||
|
14
posthog/models/tag.py
Normal file
14
posthog/models/tag.py
Normal file
@ -0,0 +1,14 @@
|
||||
from django.db import models
|
||||
|
||||
from posthog.models.utils import UUIDModel
|
||||
|
||||
|
||||
class Tag(UUIDModel):
|
||||
name: models.CharField = models.CharField(max_length=255)
|
||||
team: models.ForeignKey = models.ForeignKey("Team", on_delete=models.CASCADE)
|
||||
|
||||
class Meta:
|
||||
unique_together = ("name", "team")
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
61
posthog/models/tagged_item.py
Normal file
61
posthog/models/tagged_item.py
Normal file
@ -0,0 +1,61 @@
|
||||
from typing import List, Union
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
from django.db.models import Q
|
||||
|
||||
from posthog.models.utils import UUIDModel
|
||||
|
||||
RELATED_OBJECTS = ("action",)
|
||||
|
||||
|
||||
def build_check():
|
||||
built_check_list: List[Union[Q, Q]] = []
|
||||
# Only one object field can be populated
|
||||
for o_field in RELATED_OBJECTS:
|
||||
built_check_list.append(
|
||||
Q(*[(f"{_o_field}__isnull", _o_field != o_field) for _o_field in RELATED_OBJECTS], _connector="AND")
|
||||
)
|
||||
return Q(*built_check_list, _connector="OR")
|
||||
|
||||
|
||||
class TaggedItem(UUIDModel):
|
||||
"""
|
||||
Taggable describes global tag-object relationships.
|
||||
Note: This is an EE only feature, however the model exists in posthog so that it is backwards accessible from all
|
||||
models. Whether we should be able to interact with this table is determined in the `TaggedItemSerializer` which
|
||||
imports `EnterpriseTaggedItemSerializer` if the feature is available.
|
||||
|
||||
Today, tags exist at the model-level making it impossible to aggregate, filter, and query objects appwide by tags.
|
||||
We want to deprecate model-specific tags and refactor tag relationships into a separate table that keeps track of
|
||||
tag-object relationships.
|
||||
|
||||
Models that are taggable throughout the app are listed as separate fields below.
|
||||
https://docs.djangoproject.com/en/4.0/ref/contrib/contenttypes/#generic-relations
|
||||
"""
|
||||
|
||||
tag: models.ForeignKey = models.ForeignKey("Tag", on_delete=models.CASCADE, related_name="tagged_items")
|
||||
|
||||
# When adding a new taggeditem-model relationship, make sure to add the foreign key field and append field name to
|
||||
# the `RELATED_OBJECTS` tuple above.
|
||||
action: models.ForeignKey = models.ForeignKey(
|
||||
"Action", on_delete=models.CASCADE, null=True, blank=True, related_name="tagged_items"
|
||||
)
|
||||
|
||||
class Meta:
|
||||
# Make sure to add new key to uniqueness constraint when extending tag functionality to new model
|
||||
unique_together = ("tag",) + RELATED_OBJECTS
|
||||
constraints = [models.CheckConstraint(check=build_check(), name="exactly_one_related_object",)]
|
||||
|
||||
def clean(self):
|
||||
super().clean()
|
||||
"""Ensure that exactly one of object columns can be set."""
|
||||
if sum(map(bool, [getattr(self, o_field) for o_field in RELATED_OBJECTS])) != 1:
|
||||
raise ValidationError("Exactly one object field must be set.")
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
self.full_clean()
|
||||
return super(TaggedItem, self).save(*args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return self.tag
|
Loading…
Reference in New Issue
Block a user