0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-25 11:17:50 +01:00
posthog/ee/clickhouse/views/experiment_holdouts.py
2024-10-24 13:11:41 +00:00

111 lines
4.0 KiB
Python

from typing import Any
from rest_framework import serializers, viewsets
from rest_framework.exceptions import ValidationError
from rest_framework.request import Request
from rest_framework.response import Response
from django.db import transaction
from posthog.api.feature_flag import FeatureFlagSerializer
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.api.shared import UserBasicSerializer
from posthog.models.experiment import ExperimentHoldout
class ExperimentHoldoutSerializer(serializers.ModelSerializer):
created_by = UserBasicSerializer(read_only=True)
class Meta:
model = ExperimentHoldout
fields = [
"id",
"name",
"description",
"filters",
"created_by",
"created_at",
"updated_at",
]
read_only_fields = [
"id",
"created_by",
"created_at",
"updated_at",
]
def _get_filters_with_holdout_id(self, id: int, filters: list) -> list:
variant_key = f"holdout-{id}"
updated_filters = []
for filter in filters:
updated_filters.append(
{
**filter,
"variant": variant_key,
}
)
return updated_filters
def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> ExperimentHoldout:
request = self.context["request"]
validated_data["created_by"] = request.user
validated_data["team_id"] = self.context["team_id"]
if not validated_data.get("filters"):
raise ValidationError("Filters are required to create an holdout group")
instance = super().create(validated_data)
instance.filters = self._get_filters_with_holdout_id(instance.id, instance.filters)
instance.save()
return instance
def update(self, instance: ExperimentHoldout, validated_data):
filters = validated_data.get("filters")
if filters and instance.filters != filters:
# update flags on all experiments in this holdout group
new_filters = self._get_filters_with_holdout_id(instance.id, filters)
validated_data["filters"] = new_filters
with transaction.atomic():
for experiment in instance.experiment_set.all():
flag = experiment.feature_flag
existing_flag_serializer = FeatureFlagSerializer(
flag,
data={
"filters": {**flag.filters, "holdout_groups": validated_data["filters"]},
},
partial=True,
context=self.context,
)
existing_flag_serializer.is_valid(raise_exception=True)
existing_flag_serializer.save()
return super().update(instance, validated_data)
class ExperimentHoldoutViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
scope_object = "experiment"
queryset = ExperimentHoldout.objects.prefetch_related("created_by").all()
serializer_class = ExperimentHoldoutSerializer
ordering = "-created_at"
def destroy(self, request: Request, *args: Any, **kwargs: Any) -> Response:
instance = self.get_object()
with transaction.atomic():
for experiment in instance.experiment_set.all():
flag = experiment.feature_flag
existing_flag_serializer = FeatureFlagSerializer(
flag,
data={
"filters": {
**flag.filters,
"holdout_groups": None,
}
},
partial=True,
context={"request": request, "team": self.team, "team_id": self.team_id},
)
existing_flag_serializer.is_valid(raise_exception=True)
existing_flag_serializer.save()
return super().destroy(request, *args, **kwargs)