0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-28 09:16:49 +01:00

Part 1: Make everything taggable Backend (starting with Actions) (#8528)

This commit is contained in:
Alex Gyujin Kim 2022-02-13 22:19:53 -08:00 committed by GitHub
parent 4f630a31f4
commit 459d304e95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 374 additions and 29 deletions

View File

@ -0,0 +1,71 @@
from typing import cast
import pytest
from django.utils import timezone
from rest_framework import status
from posthog.models import Action, Tag
from posthog.test.base import APIBaseTest
# Testing enterprise properties of actions here (i.e., tagging).
@pytest.mark.ee
class TestActionApi(APIBaseTest):
def test_create_action_update_delete_tags(self):
from ee.models.license import License, LicenseManager
super(LicenseManager, cast(LicenseManager, License.objects)).create(
key="key_123", plan="enterprise", valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7), max_users=3,
)
response = self.client.post(f"/api/projects/{self.team.id}/actions/", data={"name": "user signed up",},)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.json()["tags"], [])
response = self.client.patch(
f"/api/projects/{self.team.id}/actions/{response.json()['id']}",
data={"name": "user signed up", "tags": ["hello", "random"]},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(set(response.json()["tags"]), {"hello", "random"})
response = self.client.patch(
f"/api/projects/{self.team.id}/actions/{response.json()['id']}", data={"name": "user signed up", "tags": []}
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()["tags"], [])
def test_create_action_with_tags(self):
from ee.models.license import License, LicenseManager
super(LicenseManager, cast(LicenseManager, License.objects)).create(
key="key_123", plan="enterprise", valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7), max_users=3,
)
response = self.client.post(
f"/api/projects/{self.team.id}/actions/",
data={"name": "user signed up", "tags": ["nightly", "is", "a", "good", "girl"]},
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(set(response.json()["tags"]), {"nightly", "is", "a", "good", "girl"})
def test_actions_does_not_nplus1(self):
from ee.models.license import License, LicenseManager
super(LicenseManager, cast(LicenseManager, License.objects)).create(
key="key_123", plan="enterprise", valid_until=timezone.datetime(2038, 1, 19, 3, 14, 7), max_users=3,
)
tag = Tag.objects.create(name="tag", team=self.team)
for i in range(20):
action = Action.objects.create(team=self.team, name=f"action_{i}")
action.tagged_items.create(tag=tag)
# django_session + user + team + organizationmembership + organization + action + taggeditem + actionstep
with self.assertNumQueries(8):
response = self.client.get(f"/api/projects/{self.team.id}/actions")
self.assertEqual(response.json()["results"][0]["tags"][0], "tag")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.json()["results"]), 20)

View File

@ -64,6 +64,7 @@ class License(models.Model):
AvailableFeature.GROUP_ANALYTICS,
AvailableFeature.MULTIVARIATE_FLAGS,
AvailableFeature.EXPERIMENTATION,
AvailableFeature.TAGGING,
]
ENTERPRISE_PLAN = "enterprise"

View File

@ -38,6 +38,7 @@ export enum AvailableFeature {
GROUP_ANALYTICS = 'group_analytics',
MULTIVARIATE_FLAGS = 'multivariate_flags',
EXPERIMENTATION = 'experimentation',
TAGGING = 'tagging',
}
export enum Realm {

View File

@ -4,7 +4,7 @@ axes: 0006_remove_accesslog_trusted
contenttypes: 0002_remove_content_type_name
database: 0002_auto_20190129_2304
ee: 0007_dashboard_permissions
posthog: 0205_auto_20220204_1748
posthog: 0206_global_tags_setup
rest_hooks: 0002_swappable_hook_model
sessions: 0001_initial
social_django: 0010_uid_db_index

View File

@ -1,5 +1,5 @@
import json
from typing import Any, Dict, List, Optional, Type, Union, cast
from typing import Any, Dict, List, Optional, Union, cast
from dateutil.relativedelta import relativedelta
from django.core.cache import cache
@ -24,16 +24,9 @@ from posthog.api.shared import UserBasicSerializer
from posthog.api.utils import get_target_entity
from posthog.auth import PersonalAPIKeyAuthentication, TemporaryTokenAuthentication
from posthog.celery import update_cache_item_task
from posthog.constants import (
INSIGHT_STICKINESS,
TREND_FILTER_TYPE_ACTIONS,
TREND_FILTER_TYPE_EVENTS,
TRENDS_STICKINESS,
AvailableFeature,
)
from posthog.constants import INSIGHT_STICKINESS, TREND_FILTER_TYPE_ACTIONS, TREND_FILTER_TYPE_EVENTS, TRENDS_STICKINESS
from posthog.decorators import CacheType, cached_function
from posthog.event_usage import report_user_action
from posthog.filters import term_search_filter_sql
from posthog.models import (
Action,
ActionStep,
@ -54,6 +47,7 @@ from posthog.queries import base, retention, stickiness, trends
from posthog.utils import generate_cache_key, get_safe_cache, should_refresh
from .person import get_person_name
from .tagged_item import TaggedItemSerializerMixin, TaggedItemViewSetMixin
class ActionStepSerializer(serializers.HyperlinkedModelSerializer):
@ -82,7 +76,7 @@ class ActionStepSerializer(serializers.HyperlinkedModelSerializer):
}
class ActionSerializer(serializers.HyperlinkedModelSerializer):
class ActionSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedModelSerializer):
steps = ActionStepSerializer(many=True, required=False)
created_by = UserBasicSerializer(read_only=True)
is_calculating = serializers.SerializerMethodField()
@ -93,6 +87,7 @@ class ActionSerializer(serializers.HyperlinkedModelSerializer):
"id",
"name",
"description",
"tags",
"post_to_slack",
"slack_message_format",
"steps",
@ -179,7 +174,7 @@ class ActionSerializer(serializers.HyperlinkedModelSerializer):
return instance
class ActionViewSet(StructuredViewSetMixin, viewsets.ModelViewSet):
class ActionViewSet(TaggedItemViewSetMixin, StructuredViewSetMixin, viewsets.ModelViewSet):
renderer_classes = tuple(api_settings.DEFAULT_RENDERER_CLASSES) + (csvrenderers.PaginatedCSVRenderer,)
queryset = Action.objects.all()
serializer_class = ActionSerializer

View File

@ -0,0 +1,86 @@
from django.db.models import Prefetch, Q
from rest_framework import serializers, viewsets
from posthog.constants import AvailableFeature
from posthog.exceptions import EnterpriseFeatureException
from posthog.models import Tag, TaggedItem
class TaggedItemSerializerMixin(serializers.Serializer):
"""
Serializer mixin that resolves appropriate response for tags depending on license.
"""
tags = serializers.ListField(required=False)
def _is_licensed(self):
return (
"request" in self.context
and not self.context["request"].user.is_anonymous
and self.context["request"].user.organization.is_feature_available(AvailableFeature.TAGGING)
)
def _attempt_set_tags(self, tags, obj):
if not self._is_licensed() and tags is not None:
raise EnterpriseFeatureException()
if not obj or tags is None:
# If the object hasn't been created yet, this method will be called again on the create method.
return
# Normalize and dedupe tags
deduped_tags = list(set([t.strip().lower() for t in tags]))
tagged_item_objects = []
# Create tags
for tag in deduped_tags:
tag_instance, _ = Tag.objects.get_or_create(name=tag, team_id=obj.team_id)
tagged_item_instance, _ = obj.tagged_items.get_or_create(tag_id=tag_instance.id)
tagged_item_objects.append(tagged_item_instance)
# Delete tags that are missing
obj.tagged_items.exclude(tag__name__in=deduped_tags).delete()
# Cleanup tags that aren't used by team
Tag.objects.filter(Q(team_id=obj.team_id) & Q(tagged_items__isnull=True)).delete()
obj.prefetched_tags = tagged_item_objects
def to_representation(self, obj):
ret = super(TaggedItemSerializerMixin, self).to_representation(obj)
ret["tags"] = []
if self._is_licensed():
if hasattr(obj, "prefetched_tags"):
ret["tags"] = [p.tag.name for p in obj.prefetched_tags]
else:
ret["tags"] = list(obj.tagged_items.values_list("tag__name", flat=True)) if obj.tagged_items else []
return ret
def create(self, validated_data):
validated_data.pop("tags", None)
instance = super(TaggedItemSerializerMixin, self).create(validated_data)
self._attempt_set_tags(self.initial_data.get("tags"), instance)
return instance
def update(self, instance, validated_data):
instance = super(TaggedItemSerializerMixin, self).update(instance, validated_data)
self._attempt_set_tags(self.initial_data.get("tags"), instance)
return instance
class TaggedItemViewSetMixin(viewsets.GenericViewSet):
def is_licensed(self):
return (
not self.request.user.is_anonymous
# The below triggers an extra query to resolve user's organization.
and self.request.user.organization.is_feature_available(AvailableFeature.TAGGING) # type: ignore
)
def get_queryset(self):
queryset = super(TaggedItemViewSetMixin, self).get_queryset()
if self.is_licensed():
return queryset.prefetch_related(
Prefetch("tagged_items", queryset=TaggedItem.objects.select_related("tag"), to_attr="prefetched_tags")
)
return queryset

View File

@ -6,7 +6,7 @@ from rest_framework import status
from ee.clickhouse.models.event import create_event
from ee.clickhouse.util import ClickhouseTestMixin
from posthog.models import Action, ActionStep, Element, Event, Organization
from posthog.models import Action, ActionStep, Element, Event, Organization, Tag
from posthog.test.base import APIBaseTest
@ -166,8 +166,8 @@ class TestActionApi(ClickhouseTestMixin, APIBaseTest):
)
# test queries
with self.assertNumQueries(6):
# Django session, PostHog user, PostHog team, PostHog org membership,
with self.assertNumQueries(7):
# Django session, PostHog user, PostHog team, PostHog org membership, PostHog org
# PostHog action, PostHog action step
self.client.get(f"/api/projects/{self.team.id}/actions/")
@ -280,3 +280,63 @@ class TestActionApi(ClickhouseTestMixin, APIBaseTest):
action.calculate_events()
response = self.client.get(f"/api/projects/{self.team.id}/actions/{action.id}/count").json()
self.assertEqual(response, {"count": 1})
def test_get_tags_on_non_ee_returns_empty_list(self):
action = Action.objects.create(team=self.team, name="bla")
tag = Tag.objects.create(name="random", team_id=self.team.id)
action.tagged_items.create(tag_id=tag.id)
response = self.client.get(f"/api/projects/{self.team.id}/actions/{action.id}")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()["tags"], [])
self.assertEqual(Action.objects.all().count(), 1)
def test_create_tags_on_non_ee_not_allowed(self):
response = self.client.post(
f"/api/projects/{self.team.id}/actions/", {"name": "Default", "tags": ["random", "hello"]},
)
self.assertEqual(response.status_code, status.HTTP_402_PAYMENT_REQUIRED)
self.assertEqual(Tag.objects.all().count(), 0)
def test_update_tags_on_non_ee_not_allowed(self):
action = Action.objects.create(team_id=self.team.id, name="private dashboard")
tag = Tag.objects.create(name="random", team_id=self.team.id)
action.tagged_items.create(tag_id=tag.id)
response = self.client.patch(
f"/api/projects/{self.team.id}/actions/{action.id}",
{"name": "action new name", "tags": ["random", "hello"], "description": "Internal system metrics.",},
)
self.assertEqual(response.status_code, status.HTTP_402_PAYMENT_REQUIRED)
def test_undefined_tags_allows_other_props_to_update(self):
action = Action.objects.create(team_id=self.team.id, name="private action")
tag = Tag.objects.create(name="random", team_id=self.team.id)
action.tagged_items.create(tag_id=tag.id)
response = self.client.patch(
f"/api/projects/{self.team.id}/actions/{action.id}",
{"name": "action new name", "description": "Internal system metrics.",},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()["name"], "action new name")
self.assertEqual(response.json()["description"], "Internal system metrics.")
def test_empty_tags_does_not_delete_tags(self):
action = Action.objects.create(team_id=self.team.id, name="private dashboard")
tag = Tag.objects.create(name="random", team_id=self.team.id)
action.tagged_items.create(tag_id=tag.id)
self.assertEqual(Action.objects.all().count(), 1)
response = self.client.patch(
f"/api/projects/{self.team.id}/actions/{action.id}",
{"name": "action new name", "description": "Internal system metrics.", "tags": []},
)
self.assertEqual(response.status_code, status.HTTP_402_PAYMENT_REQUIRED)
self.assertEqual(Action.objects.all().count(), 1)

View File

@ -1,27 +1,14 @@
import json
from dataclasses import dataclass
from enum import Enum, auto
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
cast,
)
from typing import Any, List, Optional, Tuple, Union, cast
from rest_framework import request, status
from sentry_sdk import capture_exception
from statshog.defaults.django import statsd
from posthog.constants import ENTITY_ID, ENTITY_MATH, ENTITY_TYPE
from posthog.exceptions import RequestParsingError, generate_exception_response
from posthog.models import Entity
from posthog.models.action import Action
from posthog.models.entity import MATH_TYPE
from posthog.models.event import Event
from posthog.models.filters.filter import Filter
from posthog.models.filters.stickiness_filter import StickinessFilter
from posthog.models.team import Team

View File

@ -18,6 +18,7 @@ class AvailableFeature(str, Enum):
GROUP_ANALYTICS = "group_analytics"
MULTIVARIATE_FLAGS = "multivariate_flags"
EXPERIMENTATION = "experimentation"
TAGGING = "tagging"
TREND_FILTER_TYPE_ACTIONS = "actions"

View File

@ -0,0 +1,64 @@
# Generated by Django 3.2.5 on 2022-02-11 23:56
import django.db.models.deletion
from django.db import migrations, models
import posthog.models.utils
class Migration(migrations.Migration):
dependencies = [
("posthog", "0205_auto_20220204_1748"),
]
operations = [
migrations.CreateModel(
name="Tag",
fields=[
(
"id",
models.UUIDField(
default=posthog.models.utils.UUIDT, editable=False, primary_key=True, serialize=False
),
),
("name", models.CharField(max_length=255)),
("team", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="posthog.team")),
],
),
migrations.CreateModel(
name="TaggedItem",
fields=[
(
"id",
models.UUIDField(
default=posthog.models.utils.UUIDT, editable=False, primary_key=True, serialize=False
),
),
(
"action",
models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="tagged_items",
to="posthog.action",
),
),
(
"tag",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, related_name="tagged_items", to="posthog.tag"
),
),
],
),
migrations.AddConstraint(
model_name="taggeditem",
constraint=models.CheckConstraint(
check=models.Q(models.Q(("action__isnull", False)), _connector="OR"), name="exactly_one_related_object"
),
),
migrations.AlterUniqueTogether(name="taggeditem", unique_together={("tag", "action")},),
migrations.AlterUniqueTogether(name="tag", unique_together={("name", "team")},),
]

View File

@ -23,6 +23,8 @@ from .plugin import Plugin, PluginAttachment, PluginConfig, PluginLogEntry
from .property import Property
from .property_definition import PropertyDefinition
from .session_recording_event import SessionRecordingEvent
from .tag import Tag
from .tagged_item import TaggedItem
from .team import Team
from .user import User, UserManager
@ -59,6 +61,8 @@ __all__ = [
"Property",
"PropertyDefinition",
"SessionRecordingEvent",
"Tag",
"TaggedItem",
"Team",
"User",
"UserManager",

14
posthog/models/tag.py Normal file
View File

@ -0,0 +1,14 @@
from django.db import models
from posthog.models.utils import UUIDModel
class Tag(UUIDModel):
name: models.CharField = models.CharField(max_length=255)
team: models.ForeignKey = models.ForeignKey("Team", on_delete=models.CASCADE)
class Meta:
unique_together = ("name", "team")
def __str__(self):
return self.name

View File

@ -0,0 +1,61 @@
from typing import List, Union
from django.core.exceptions import ValidationError
from django.db import models
from django.db.models import Q
from posthog.models.utils import UUIDModel
RELATED_OBJECTS = ("action",)
def build_check():
built_check_list: List[Union[Q, Q]] = []
# Only one object field can be populated
for o_field in RELATED_OBJECTS:
built_check_list.append(
Q(*[(f"{_o_field}__isnull", _o_field != o_field) for _o_field in RELATED_OBJECTS], _connector="AND")
)
return Q(*built_check_list, _connector="OR")
class TaggedItem(UUIDModel):
"""
Taggable describes global tag-object relationships.
Note: This is an EE only feature, however the model exists in posthog so that it is backwards accessible from all
models. Whether we should be able to interact with this table is determined in the `TaggedItemSerializer` which
imports `EnterpriseTaggedItemSerializer` if the feature is available.
Today, tags exist at the model-level making it impossible to aggregate, filter, and query objects appwide by tags.
We want to deprecate model-specific tags and refactor tag relationships into a separate table that keeps track of
tag-object relationships.
Models that are taggable throughout the app are listed as separate fields below.
https://docs.djangoproject.com/en/4.0/ref/contrib/contenttypes/#generic-relations
"""
tag: models.ForeignKey = models.ForeignKey("Tag", on_delete=models.CASCADE, related_name="tagged_items")
# When adding a new taggeditem-model relationship, make sure to add the foreign key field and append field name to
# the `RELATED_OBJECTS` tuple above.
action: models.ForeignKey = models.ForeignKey(
"Action", on_delete=models.CASCADE, null=True, blank=True, related_name="tagged_items"
)
class Meta:
# Make sure to add new key to uniqueness constraint when extending tag functionality to new model
unique_together = ("tag",) + RELATED_OBJECTS
constraints = [models.CheckConstraint(check=build_check(), name="exactly_one_related_object",)]
def clean(self):
super().clean()
"""Ensure that exactly one of object columns can be set."""
if sum(map(bool, [getattr(self, o_field) for o_field in RELATED_OBJECTS])) != 1:
raise ValidationError("Exactly one object field must be set.")
def save(self, *args, **kwargs):
self.full_clean()
return super(TaggedItem, self).save(*args, **kwargs)
def __str__(self):
return self.tag