0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-28 00:46:45 +01:00
posthog/ee/api/billing.py
2024-10-10 12:31:55 -04:00

322 lines
12 KiB
Python

from typing import Any, Optional
from urllib.parse import urlencode

import posthoganalytics
import requests
import structlog
from django.conf import settings
from django.contrib.auth.models import AbstractUser
from django.http import HttpResponse
from django.shortcuts import redirect
from rest_framework import serializers, status, viewsets
from rest_framework.exceptions import NotFound, PermissionDenied, ValidationError
from rest_framework.request import Request
from rest_framework.response import Response

from posthog.api.utils import action
from ee.billing.billing_manager import BillingManager, build_billing_token
from ee.models import License
from ee.settings import BILLING_SERVICE_URL
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.cloud_utils import get_cached_instance_license
from posthog.event_usage import groups
from posthog.models import Organization
# Module-level structlog logger, named after this module.
logger = structlog.get_logger(__name__)
# NOTE(review): not referenced in the visible part of this file; presumably the JWT audience
# ("aud") value used for billing-service tokens — confirm against the token-building code.
BILLING_SERVICE_JWT_AUD = "posthog:license-key"
class BillingSerializer(serializers.Serializer):
    """Serializer declaring the ``plan`` and ``billing_limit`` fields.

    It is set as ``serializer_class`` on ``BillingViewset``.
    """

    # Plan identifier, limited to 100 characters.
    plan = serializers.CharField(max_length=100)
    # Billing limit as an integer (units are not stated in this file).
    billing_limit = serializers.IntegerField()
class LicenseKeySerializer(serializers.Serializer):
    """Serializer for a request body carrying a single ``license`` string.

    Used by the ``license`` action of ``BillingViewset`` to validate the submitted key.
    """

    # The license key as submitted by the client.
    license = serializers.CharField()
class BillingViewset(TeamAndOrgViewSetMixin, viewsets.GenericViewSet):
    """Billing endpoints for the requesting user's organization.

    The actions look up the instance license with ``get_cached_instance_license()`` and delegate
    the actual work to a ``BillingManager`` (see ``get_billing_manager``).
    """

    serializer_class = BillingSerializer
    param_derived_from_user_current_team = "team_id"
    scope_object = "INTERNAL"

    def get_billing_manager(self) -> BillingManager:
        """Build a ``BillingManager`` for the cached instance license and the requesting user.

        The user is only forwarded when it is an ``AbstractUser`` with a truthy ``distinct_id``;
        otherwise ``None`` is passed in its place.
        """
        license = get_cached_instance_license()
        user = (
            self.request.user if isinstance(self.request.user, AbstractUser) and self.request.user.distinct_id else None
        )
        return BillingManager(license, user)

    def list(self, request: Request, *args: Any, **kwargs: Any) -> Response:
        """Return the billing data for the requesting user's organization.

        The optional ``plan_keys`` query parameter is forwarded to ``BillingManager.get_billing``.

        Raises:
            NotFound: if a license exists but is not a v2 license, or if the organization has a
                ``billing`` attribute whose ``stripe_subscription_id`` is set.
        """
        license = get_cached_instance_license()
        if license and not license.is_v2_license:
            raise NotFound("Billing is not supported for this license type")

        org = self._get_org()

        # If on Cloud and we have the property billing - return 404 as we always use legacy billing if it exists
        if hasattr(org, "billing"):
            if org.billing.stripe_subscription_id:  # type: ignore
                raise NotFound("Billing V1 is active for this organization")

        plan_keys = request.query_params.get("plan_keys", None)
        billing_manager = self.get_billing_manager()
        response = billing_manager.get_billing(org, plan_keys)

        return Response(response)

    @action(methods=["PATCH"], detail=False, url_path="/")
    def patch(self, request: Request, *args: Any, **kwargs: Any) -> Response:
        """Update the organization's custom billing limits, then return the ``list`` response.

        Reads ``custom_limits_usd`` from the request body. When it is truthy it is sent to
        ``BillingManager.update_billing`` and, if the user has a ``distinct_id``, an analytics
        event and an organization group-identify call are emitted.

        Raises:
            Exception: if no license is configured, or (via ``_get_org_required``) if the user has
                no organization.
        """
        distinct_id = None if self.request.user.is_anonymous else self.request.user.distinct_id
        license = get_cached_instance_license()
        if not license:
            raise Exception("There is no license configured for this instance yet.")

        org = self._get_org_required()
        if license and org:  # for mypy
            custom_limits_usd = request.data.get("custom_limits_usd")
            if custom_limits_usd:
                billing_manager = self.get_billing_manager()
                billing_manager.update_billing(org, {"custom_limits_usd": custom_limits_usd})

                # NOTE(review): both analytics calls are placed under `if distinct_id`; the nesting was
                # reconstructed from the data flow — confirm against the upstream file.
                if distinct_id:
                    posthoganalytics.capture(
                        distinct_id,
                        "billing limits updated",
                        properties={**custom_limits_usd},
                        groups=(
                            groups(org, self.request.user.team) if hasattr(self.request.user, "team") else groups(org)
                        ),
                    )
                    posthoganalytics.group_identify(
                        "organization",
                        str(org.id),
                        properties={f"billing_limits_{key}": value for key, value in custom_limits_usd.items()},
                    )

        return self.list(request, *args, **kwargs)

    class ActivateSerializer(serializers.Serializer):
        """Query-parameter serializer for the ``activate`` / ``activation`` actions."""

        plan = serializers.CharField(required=False)
        products = serializers.CharField(
            required=False
        )  # This is required but in order to support an error for the legacy 'plan' param we need to set required=False
        redirect_path = serializers.CharField(required=False)
        intent_product = serializers.CharField(required=False)

        def validate(self, data):
            """Reject the legacy ``plan`` parameter and require ``products``.

            Raises:
                ValidationError: keyed on ``plan`` when only ``plan`` was supplied, or keyed on
                    ``products`` when ``products`` is missing or empty.
            """
            plan = data.get("plan")
            products = data.get("products")

            if plan and not products:
                raise ValidationError(
                    {
                        "plan": "The 'plan' parameter is no longer supported. Please use the 'products' parameter instead."
                    }
                )
            if not products:
                raise ValidationError({"products": "The 'products' parameter is required."})

            return data

    # This is deprecated and should be removed in the future in favor of 'activate'
    @action(methods=["GET"], detail=False)
    def activation(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
        """Deprecated alias of ``activate``; delegates to ``handle_activate``."""
        return self.handle_activate(request, *args, **kwargs)

    @action(methods=["GET"], detail=False)
    def activate(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
        """Redirect to the billing service activation page; delegates to ``handle_activate``."""
        return self.handle_activate(request, *args, **kwargs)

    # A viewset action cannot call another action directly so this is in place until
    # the 'activation' endpoint is removed. Once removed, this method can move to the 'activate' action
    def handle_activate(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
        """Validate the query parameters and redirect to ``{BILLING_SERVICE_URL}/activate``.

        The redirect URL carries ``redirect_uri``, ``organization_name`` and ``products``, plus
        ``intent_product`` when supplied and ``token`` when an instance license exists.

        Raises:
            ValidationError: if the query parameters fail ``ActivateSerializer`` validation.
            Exception: (via ``_get_org_required``) if the user has no organization.
        """
        license = get_cached_instance_license()
        organization = self._get_org_required()

        serializer = self.ActivateSerializer(data=request.GET)
        serializer.is_valid(raise_exception=True)

        redirect_path = serializer.validated_data.get("redirect_path", "organization/billing")
        if redirect_path.startswith("/"):
            redirect_path = redirect_path[1:]

        redirect_uri = f"{settings.SITE_URL or request.headers.get('Host')}/{redirect_path}"

        # Build the query with urlencode instead of raw interpolation: the organization name and the
        # request-supplied values may contain '&', '=', '#' or spaces, which would otherwise corrupt
        # the query string or smuggle extra parameters into the billing-service request.
        query_params = {
            "redirect_uri": redirect_uri,
            "organization_name": organization.name,
            "products": serializer.validated_data.get("products"),
        }

        intent_product = serializer.validated_data.get("intent_product")
        if intent_product:
            query_params["intent_product"] = intent_product

        if license:
            query_params["token"] = build_billing_token(license, organization)

        return redirect(f"{BILLING_SERVICE_URL}/activate?{urlencode(query_params)}")

    class DeactivateSerializer(serializers.Serializer):
        """Query-parameter serializer for the ``deactivate`` action."""

        products = serializers.CharField()

    @action(methods=["GET"], detail=False)
    def deactivate(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
        """Deactivate the given ``products`` and return the ``list`` response.

        If ``BillingManager.deactivate_products`` raises an exception with more than two ``args``
        and the third one is a dict, a 400 response is built from it (``statusText``, ``detail``,
        ``link``, ``code``). Any other exception is re-raised unchanged.

        Raises:
            ValidationError: if ``products`` is missing from the query parameters.
            Exception: (via ``_get_org_required``) if the user has no organization.
        """
        organization = self._get_org_required()

        serializer = self.DeactivateSerializer(data=request.GET)
        serializer.is_valid(raise_exception=True)

        products = serializer.validated_data.get("products")

        try:
            billing_manager = self.get_billing_manager()
            billing_manager.deactivate_products(organization, products)
        except Exception as e:
            if len(e.args) > 2:
                detail_object = e.args[2]
                # Same guard as in get_invoices: only a dict payload can be unpacked below; for
                # anything else surface the original exception instead of an AttributeError.
                if not isinstance(detail_object, dict):
                    raise
                return Response(
                    {
                        "statusText": e.args[0],
                        "detail": detail_object.get("error_message", detail_object),
                        "link": detail_object.get("link", None),
                        "code": detail_object.get("code"),
                    },
                    status=status.HTTP_400_BAD_REQUEST,
                )
            else:
                raise

        return self.list(request, *args, **kwargs)

    @action(methods=["GET"], detail=False)
    def portal(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
        """Redirect to the URL returned by ``BillingManager._get_stripe_portal_url``.

        Without an instance license this returns ``{"success": True}`` with status 200 instead.
        """
        license = get_cached_instance_license()
        if not license:
            return Response(
                {"success": True},
                status=status.HTTP_200_OK,
            )

        organization = self._get_org_required()

        billing_manager = self.get_billing_manager()
        res = billing_manager._get_stripe_portal_url(organization)
        return redirect(res)

    @action(methods=["GET"], detail=False)
    def get_invoices(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
        """Return ``link`` and ``count`` taken from ``BillingManager.get_invoices``.

        The optional ``status`` query parameter is forwarded as the invoice status filter.
        Without an instance license this returns ``{"success": True}`` with status 200.

        If the manager raises an exception with more than two ``args`` and the third one is a
        dict, a 400 response is built from it; any other exception is re-raised unchanged.
        """
        license = get_cached_instance_license()
        if not license:
            return Response(
                {"success": True},
                status=status.HTTP_200_OK,
            )

        organization = self._get_org_required()

        invoice_status = request.GET.get("status")

        try:
            billing_manager = self.get_billing_manager()
            res = billing_manager.get_invoices(organization, status=invoice_status)
        except Exception as e:
            if len(e.args) > 2:
                detail_object = e.args[2]
                if not isinstance(detail_object, dict):
                    raise
                return Response(
                    {
                        "statusText": e.args[0],
                        "detail": detail_object.get("error_message", detail_object),
                        "code": detail_object.get("code"),
                    },
                    status=status.HTTP_400_BAD_REQUEST,
                )
            else:
                raise

        return Response(
            {
                "link": res.get("portal_url"),
                "count": res.get("count"),
            },
            status=status.HTTP_200_OK,
        )

    @action(methods=["GET"], detail=False, url_path="credits/overview")
    def credits_overview(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
        """Return the result of ``BillingManager.credits_overview`` for the organization.

        Without an instance license this returns ``{"success": True}`` with status 200.
        """
        license = get_cached_instance_license()
        if not license:
            return Response(
                {"success": True},
                status=status.HTTP_200_OK,
            )

        organization = self._get_org_required()

        billing_manager = self.get_billing_manager()
        res = billing_manager.credits_overview(organization)
        return Response(res, status=status.HTTP_200_OK)

    @action(methods=["POST"], detail=False, url_path="credits/purchase")
    def purchase_credits(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
        """Forward the request body to ``BillingManager.purchase_credits`` and return its result.

        Without an instance license this returns ``{"success": True}`` with status 200.
        """
        license = get_cached_instance_license()
        if not license:
            return Response(
                {"success": True},
                status=status.HTTP_200_OK,
            )

        organization = self._get_org_required()

        billing_manager = self.get_billing_manager()
        res = billing_manager.purchase_credits(organization, request.data)
        return Response(res, status=status.HTTP_200_OK)

    @action(methods=["PATCH"], detail=False)
    def license(self, request: Request, *args: Any, **kwargs: Any) -> HttpResponse:
        """Activate a license key when the instance does not have one yet.

        The submitted key is checked by calling ``{BILLING_SERVICE_URL}/api/billing`` with auth
        headers built from it; on a 200 response the returned JSON is passed to
        ``BillingManager.update_license_details``.

        Raises:
            PermissionDenied: if an instance license already exists.
            ValidationError: if the billing service responds with a non-200 status.
            Exception: (via ``_get_org_required``) if the user has no organization.
        """
        license = get_cached_instance_license()

        if license:
            raise PermissionDenied(
                "A valid license key already exists. This must be removed before a new one can be added."
            )

        organization = self._get_org_required()

        serializer = LicenseKeySerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        license = License(key=serializer.validated_data["license"])

        # NOTE(review): no timeout is passed, so this call can block for as long as the billing
        # service keeps the connection open — consider adding one.
        res = requests.get(
            f"{BILLING_SERVICE_URL}/api/billing",
            headers=BillingManager(license).get_auth_headers(organization),
        )

        if res.status_code != 200:
            raise ValidationError(
                {
                    "license": f"License could not be activated. Please contact support. (BillingService status {res.status_code})",
                }
            )
        data = res.json()
        BillingManager(license).update_license_details(data)
        return Response({"success": True})

    def _get_org(self) -> Optional[Organization]:
        """Return the requesting user's organization, or ``None`` for an anonymous user."""
        org = None if self.request.user.is_anonymous else self.request.user.organization

        return org

    def _get_org_required(self) -> Organization:
        """Return the requesting user's organization.

        Raises:
            Exception: if ``_get_org`` returns a falsy value.
        """
        org = self._get_org()

        if not org:
            raise Exception("You cannot interact with the billing service without an organization configured.")

        return org