mirror of
https://github.com/PostHog/posthog.git
synced 2024-11-25 11:17:50 +01:00
111 lines
4.0 KiB
Python
111 lines
4.0 KiB
Python
from typing import Any
|
|
from rest_framework import serializers, viewsets
|
|
from rest_framework.exceptions import ValidationError
|
|
from rest_framework.request import Request
|
|
from rest_framework.response import Response
|
|
from django.db import transaction
|
|
|
|
|
|
from posthog.api.feature_flag import FeatureFlagSerializer
|
|
from posthog.api.routing import TeamAndOrgViewSetMixin
|
|
from posthog.api.shared import UserBasicSerializer
|
|
from posthog.models.experiment import ExperimentHoldout
|
|
|
|
|
|
class ExperimentHoldoutSerializer(serializers.ModelSerializer):
    """Serialize experiment holdout groups and keep dependent feature flags in sync.

    A holdout's ``filters`` are stored with a ``variant`` key of the form
    ``holdout-<id>`` added to every entry (see ``_get_filters_with_holdout_id``).
    When those filters change, the ``holdout_groups`` of the feature flag of
    every experiment in this holdout (``instance.experiment_set``) is rewritten
    to match.
    """

    created_by = UserBasicSerializer(read_only=True)

    class Meta:
        model = ExperimentHoldout
        fields = [
            "id",
            "name",
            "description",
            "filters",
            "created_by",
            "created_at",
            "updated_at",
        ]
        read_only_fields = [
            "id",
            "created_by",
            "created_at",
            "updated_at",
        ]

    def _get_filters_with_holdout_id(self, id: int, filters: list) -> list:
        """Return a copy of ``filters`` with each entry's ``variant`` set to ``holdout-<id>``.

        Each entry is shallow-copied into a new dict, so neither the input list
        nor its entries are mutated. Any existing ``variant`` key is overwritten.
        """
        variant_key = f"holdout-{id}"
        return [{**group, "variant": variant_key} for group in filters]

    def _update_experiment_flags(self, instance: ExperimentHoldout, holdout_groups: list) -> None:
        """Set ``holdout_groups`` on the feature flag of every experiment in ``instance``.

        Goes through ``FeatureFlagSerializer`` (partial update) so the flag's own
        validation and save logic run. Raises DRF ``ValidationError`` if a flag
        rejects the new filters.
        """
        # select_related avoids one extra query per experiment for its flag
        for experiment in instance.experiment_set.select_related("feature_flag"):
            flag = experiment.feature_flag
            flag_serializer = FeatureFlagSerializer(
                flag,
                data={
                    "filters": {**flag.filters, "holdout_groups": holdout_groups},
                },
                partial=True,
                context=self.context,
            )
            flag_serializer.is_valid(raise_exception=True)
            flag_serializer.save()

    def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> ExperimentHoldout:
        """Create a holdout owned by the requesting user in the current team.

        The holdout id is only known after the first insert, so the variant-tagged
        filters are written in a second save. Both writes share one transaction,
        so a failure cannot leave behind a holdout whose filters lack the variant.

        Raises:
            ValidationError: if ``filters`` is missing or empty.
        """
        request = self.context["request"]
        validated_data["created_by"] = request.user
        validated_data["team_id"] = self.context["team_id"]

        if not validated_data.get("filters"):
            raise ValidationError("Filters are required to create an holdout group")

        with transaction.atomic():
            instance = super().create(validated_data)
            instance.filters = self._get_filters_with_holdout_id(instance.id, instance.filters)
            instance.save()
        return instance

    def update(self, instance: ExperimentHoldout, validated_data):
        """Update a holdout and propagate changed filters to its experiments' flags.

        Incoming ``filters`` are tagged with this holdout's variant key before
        being compared with the stored (already tagged) filters. Resubmitting
        unchanged filters, with or without the ``variant`` key, therefore does
        not rewrite every flag. The flag updates and the holdout update run in
        one transaction, so a failure in either rolls back both.
        """
        filters = validated_data.get("filters")

        with transaction.atomic():
            if filters:
                new_filters = self._get_filters_with_holdout_id(instance.id, filters)
                validated_data["filters"] = new_filters
                if new_filters != instance.filters:
                    # update flags on all experiments in this holdout group
                    self._update_experiment_flags(instance, new_filters)
            return super().update(instance, validated_data)
|
|
|
|
|
|
class ExperimentHoldoutViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
    """CRUD endpoints for experiment holdout groups.

    Deleting a holdout first clears ``holdout_groups`` on the feature flag of
    every experiment attached to it, then deletes the holdout.
    """

    scope_object = "experiment"
    # created_by is a single related user, so a JOIN (select_related) is
    # cheaper than the separate query prefetch_related would issue.
    queryset = ExperimentHoldout.objects.select_related("created_by").all()
    serializer_class = ExperimentHoldoutSerializer
    ordering = "-created_at"

    def destroy(self, request: Request, *args: Any, **kwargs: Any) -> Response:
        """Detach the holdout from its experiments' flags, then delete it.

        Both steps share one transaction. If the delete fails, the flags keep
        their holdout groups rather than being detached from a holdout that
        still exists.

        Raises:
            ValidationError: if a flag rejects the cleared ``holdout_groups``.
        """
        instance = self.get_object()
        flag_context = {"request": request, "team": self.team, "team_id": self.team_id}

        with transaction.atomic():
            # select_related avoids one extra query per experiment for its flag
            for experiment in instance.experiment_set.select_related("feature_flag"):
                flag = experiment.feature_flag
                flag_serializer = FeatureFlagSerializer(
                    flag,
                    data={
                        "filters": {
                            **flag.filters,
                            "holdout_groups": None,
                        }
                    },
                    partial=True,
                    context=flag_context,
                )
                flag_serializer.is_valid(raise_exception=True)
                flag_serializer.save()

            return super().destroy(request, *args, **kwargs)
|