mirror of
https://github.com/PostHog/posthog.git
synced 2024-11-24 09:14:46 +01:00
a5375dfa0d
* Send updated admin emails to billing * add tests * sync when a user is removed from an org * fix the admin email filter * sync admin emails on level change * add more tests
413 lines
16 KiB
Python
413 lines
16 KiB
Python
from datetime import datetime, timedelta
|
|
from decimal import Decimal
|
|
from enum import Enum
|
|
from typing import Any, Optional, Union, cast
|
|
|
|
import jwt
|
|
import requests
|
|
import structlog
|
|
from django.utils import timezone
|
|
from requests import JSONDecodeError # type: ignore[attr-defined]
|
|
from rest_framework.exceptions import NotAuthenticated
|
|
from sentry_sdk import capture_exception
|
|
|
|
from ee.billing.billing_types import BillingStatus, Tier
|
|
from ee.billing.quota_limiting import set_org_usage_summary, sync_org_quota_limits
|
|
from ee.models import License
|
|
from ee.settings import BILLING_SERVICE_URL
|
|
from posthog.cloud_utils import get_cached_instance_license
|
|
from posthog.models import Organization
|
|
from posthog.models.organization import OrganizationMembership, OrganizationUsageInfo
|
|
|
|
# Module-level structlog logger; handle_billing_service_error uses it to record
# unexpected billing-service responses before raising.
logger = structlog.get_logger(__name__)
|
|
|
|
|
|
class BillingAPIErrorCodes(Enum):
    """Symbolic names for billing error-code strings.

    NOTE(review): presumably these are matched against error codes returned by
    the billing service -- confirm against the callers of this enum.
    """

    OPEN_INVOICES_ERROR = "open_invoices_error"
|
|
|
|
|
|
def build_billing_token(license: License, organization: Organization):
    """Build the short-lived JWT used to authenticate against the billing service.

    The license key is split on ``"::"``: the first part is sent as the token's
    ``id`` claim and the second part is used as the HS256 signing secret.
    The token expires 15 minutes after creation and carries the organization's
    id and name plus the ``posthog:license-key`` audience.

    Raises:
        NotAuthenticated: if either ``license`` or ``organization`` is falsy.
        IndexError: if the license key does not contain a ``"::"`` separator.
    """
    if not (organization and license):
        raise NotAuthenticated()

    key_parts = license.key.split("::")
    license_id, license_secret = key_parts[0], key_parts[1]

    claims = {
        "exp": datetime.now(tz=timezone.utc) + timedelta(minutes=15),
        "id": license_id,
        "organization_id": str(organization.id),
        "organization_name": organization.name,
        "aud": "posthog:license-key",
    }
    return jwt.encode(claims, license_secret, algorithm="HS256")
|
|
|
|
|
|
def handle_billing_service_error(res: requests.Response, valid_codes=(200, 404, 401)) -> None:
    """Raise if the billing service answered with an unexpected status code.

    Args:
        res: The response received from the billing service.
        valid_codes: Status codes that are accepted without raising.

    Raises:
        Exception: when ``res.status_code`` is not in ``valid_codes``. The
            exception args carry the status message, the literal ``"body:"``
            and the decoded JSON body (or the raw text when the body is not
            valid JSON).
    """
    if res.status_code in valid_codes:
        return

    logger.error(f"Billing service returned bad status code: {res.status_code}, body: {res.text}")

    # Only the JSON decoding can legitimately fail here, so it is the only
    # statement guarded by the try block.
    try:
        decoded_body = res.json()
    except JSONDecodeError:
        raise Exception(f"Billing service returned bad status code: {res.status_code}", "body:", res.text)
    raise Exception(f"Billing service returned bad status code: {res.status_code}", "body:", decoded_body)
|
|
|
|
|
|
def compute_usage_per_tier(current_usage: int, projected_usage: int, tiers):
    """Distribute usage across graduated pricing tiers and price each tier.

    Tiers are consumed in the order given. A tier's capacity is its ``up_to``
    minus the previous tier's ``up_to``; a tier without a truthy ``up_to`` is
    treated as unbounded. Each tier's amount is
    ``unit_amount_usd * usage_in_tier + flat_amount_usd`` quantized to cents
    (the flat amount is added even when the tier receives no usage).

    Args:
        current_usage: Usage to distribute across the tiers.
        projected_usage: Projected usage to distribute the same way. When
            falsy, the per-tier ``projected_usage`` is ``None`` and
            ``projected_amount_usd`` is the string ``"None"`` (the result of
            ``str(None)``).
        tiers: Iterable of tier mappings with optional ``up_to``,
            ``flat_amount_usd`` and ``unit_amount_usd`` keys.

    Returns:
        A list with one ``Tier`` per input tier, with amounts rendered as strings.
    """
    cent = Decimal("0.01")
    usage_left = current_usage
    projected_left = projected_usage or 0
    # "up_to" of the tier handled in the previous iteration (0 if none / unbounded).
    lower_bound = 0

    computed: list[Tier] = []
    for tier in tiers:
        upper_bound = tier.get("up_to")
        capacity = tier["up_to"] - lower_bound if upper_bound else float("inf")

        flat_amount_usd = Decimal(tier.get("flat_amount_usd") or 0)
        unit_amount_usd = Decimal(tier.get("unit_amount_usd") or 0)

        usage_this_tier = int(min(usage_left, capacity))
        usage_left -= usage_this_tier
        current_amount_usd = Decimal(unit_amount_usd * usage_this_tier + flat_amount_usd).quantize(cent)

        projected_usage_this_tier = None
        projected_amount_usd = None
        if projected_usage:
            projected_usage_this_tier = int(min(projected_left, capacity))
            projected_left -= projected_usage_this_tier
            projected_amount_usd = Decimal(
                unit_amount_usd * projected_usage_this_tier + flat_amount_usd
            ).quantize(cent)

        computed.append(
            Tier(
                flat_amount_usd=str(flat_amount_usd),
                unit_amount_usd=str(unit_amount_usd),
                up_to=tier.get("up_to"),
                current_amount_usd=str(current_amount_usd),
                current_usage=usage_this_tier,
                projected_usage=projected_usage_this_tier,
                projected_amount_usd=str(projected_amount_usd),
            )
        )
        lower_bound = upper_bound or 0

    return computed
|
|
|
|
|
|
def sum_total_across_tiers(tiers):
    """Return the sum of every tier's ``current_amount_usd`` as a ``Decimal``.

    Each tier must expose ``current_amount_usd`` as a value ``Decimal`` can
    parse; an empty iterable yields ``Decimal(0)``.
    """
    return sum((Decimal(tier["current_amount_usd"]) for tier in tiers), Decimal(0))
|
|
|
|
|
|
class BillingManager:
    """Talks to the billing service on behalf of an organization.

    Every request is authenticated with a JWT derived from ``self.license``
    (see ``get_auth_headers``). Besides fetching billing data, the manager
    keeps the local ``License`` and ``Organization`` rows in sync with what
    the billing service reports.
    """

    # License used to sign billing-service requests; may be None when neither an
    # explicit license nor a cached instance license is available.
    license: Optional[License]

    def __init__(self, license):
        # Fall back to the cached instance-wide license when none is passed in.
        self.license = license or get_cached_instance_license()

    def get_billing(self, organization: Optional[Organization], plan_keys: Optional[str]) -> dict[str, Any]:
        """Return the billing payload for ``organization``.

        With an organization and a v2 license, the billing service is queried,
        the local license/organization are synced, and every product (and its
        subscribed addons) is extended with current usage and per-tier amounts.
        Otherwise only the default product list is returned.

        Args:
            organization: Organization to fetch billing for; may be None.
            plan_keys: Accepted for interface compatibility; not used by this method.

        Returns:
            A dict that always has ``available_features`` and ``products``; on
            the billing-service path it also has ``license``,
            ``stripe_portal_url`` and every key of the service's ``customer``
            object.
        """
        if not (organization and self.license and self.license.is_v2_license):
            return {
                "available_features": [],
                "products": self.get_default_products(organization)["products"],
            }

        billing_service_response = self._get_billing(organization)

        # Ensure the license and org are updated with the latest info
        if billing_service_response.get("license"):
            self.update_license_details(billing_service_response)

        if billing_service_response:
            self.update_org_details(organization, billing_service_response)

        response: dict[str, Any] = {"available_features": []}
        response["license"] = {"plan": self.license.plan}

        if billing_service_response.get("customer"):
            response.update(billing_service_response["customer"])

        if not billing_service_response["customer"].get("products"):
            response["products"] = self.get_default_products(organization)["products"]

        response["stripe_portal_url"] = self._get_stripe_portal_url(organization)

        # Extend the products with accurate usage_limit info
        for product in response["products"]:
            self._extend_product_with_usage(response, organization, product)

        return response

    @staticmethod
    def _resolve_usage(response: dict[str, Any], organization: Optional[Organization], usage_key: str):
        """Return ``(current_usage, usage_limit)`` for ``usage_key``.

        Current usage is the billing service's summary usage plus the
        organization's locally tracked ``todays_usage`` for the same key.
        """
        usage = (response.get("usage_summary") or {}).get(usage_key) or {}
        usage_limit = usage.get("limit")
        current_usage = usage.get("usage") or 0

        if organization and organization.usage:
            todays_usage = (organization.usage.get(usage_key) or {}).get("todays_usage", None)
            if todays_usage:
                current_usage = current_usage + todays_usage

        return current_usage, usage_limit

    def _extend_product_with_usage(
        self, response: dict[str, Any], organization: Optional[Organization], product: dict[str, Any]
    ) -> None:
        """Fill in usage, percentage and tier amounts on ``product`` and its subscribed addons (in place)."""
        usage_key = product.get("usage_key", None)
        if not usage_key:
            return

        current_usage, usage_limit = self._resolve_usage(response, organization, usage_key)
        product["current_usage"] = current_usage
        product["percentage_usage"] = current_usage / usage_limit if usage_limit else 0

        # Also update the tiers
        if product.get("tiers"):
            product["tiers"] = compute_usage_per_tier(current_usage, product["projected_usage"], product["tiers"])
            product["current_amount_usd"] = sum_total_across_tiers(product["tiers"])

        # Update the add on tiers
        # TODO: enhanced_persons: make sure this updates properly for addons with different usage keys
        # A product without an "addons" entry is treated as having none.
        for addon in product.get("addons") or []:
            if not addon.get("subscribed"):
                continue
            addon_usage_key = addon.get("usage_key")
            if not addon_usage_key:
                # Mirrors the product-level check: nothing to measure without a usage key.
                continue

            # Keep the addon's usage in its own variable so one addon's usage never
            # leaks into the next addon (or back into the product's figure).
            if addon_usage_key == usage_key:
                addon_usage = current_usage
            else:
                addon_usage, _ = self._resolve_usage(response, organization, addon_usage_key)

            addon["current_usage"] = addon_usage
            addon["tiers"] = compute_usage_per_tier(addon_usage, addon["projected_usage"], addon["tiers"])
            addon["current_amount_usd"] = sum_total_across_tiers(addon["tiers"])

    def update_billing(self, organization: Organization, data: dict[str, Any]) -> None:
        """PATCH ``data`` to the billing service for ``organization``; raises on an unexpected status code."""
        res = requests.patch(
            f"{BILLING_SERVICE_URL}/api/billing/",
            headers=self.get_auth_headers(organization),
            json=data,
        )

        handle_billing_service_error(res)

    def update_billing_distinct_ids(self, organization: Organization) -> None:
        """Send the distinct ids of all organization members to the billing service."""
        distinct_ids = list(organization.members.values_list("distinct_id", flat=True))
        self.update_billing(organization, {"distinct_ids": distinct_ids})

    def update_billing_customer_email(self, organization: Organization) -> None:
        """Send the email of the level-15 member as the customer email.

        Best-effort: any failure is reported via ``capture_exception`` and swallowed.
        """
        try:
            # NOTE(review): 15 is presumably the owner membership level -- confirm
            # against OrganizationMembership.Level.
            owner_membership = OrganizationMembership.objects.get(organization=organization, level=15)
            user = owner_membership.user
            self.update_billing(organization, {"org_customer_email": user.email})
        except Exception as e:
            capture_exception(e)

    def update_billing_admin_emails(self, organization: Organization) -> None:
        """Send the emails of all members at ADMIN level or above to the billing service.

        Best-effort: any failure is reported via ``capture_exception`` and swallowed.
        """
        try:
            admin_emails = list(
                organization.members.filter(
                    organization_membership__level__gte=OrganizationMembership.Level.ADMIN
                ).values_list("email", flat=True)
            )
            self.update_billing(organization, {"org_admin_emails": admin_emails})
        except Exception as e:
            capture_exception(e)

    def deactivate_products(self, organization: Organization, products: str) -> None:
        """Call the billing service's deactivate endpoint for ``products``; raises on an unexpected status code."""
        res = requests.get(
            f"{BILLING_SERVICE_URL}/api/billing/deactivate?products={products}",
            headers=self.get_auth_headers(organization),
        )

        handle_billing_service_error(res)

    def get_default_products(self, organization: Optional[Organization]) -> dict:
        """Return ``{"products": [...]}`` using the billing service's default product list."""
        # If we don't have products from the billing service then get the default ones with our local usage calculation
        return {"products": self._get_products(organization)}

    def update_license_details(self, billing_status: BillingStatus) -> License:
        """
        Ensure the license details are up-to-date locally

        Extends ``valid_until`` to 30 days from now when it is unset or less
        than 29 days away, and copies the plan from
        ``billing_status["license"]["type"]``. The license is saved only when
        something changed.

        Raises:
            Exception: if the manager has no license.
        """
        if not self.license:  # mypy
            raise Exception("No license found")

        license_modified = False

        data = billing_status["license"]

        if not self.license.valid_until or self.license.valid_until < timezone.now() + timedelta(days=29):
            # NOTE: License validity is a legacy concept. For now we always extend the license validity by 30 days.
            self.license.valid_until = timezone.now() + timedelta(days=30)
            license_modified = True

        if self.license.plan != data["type"]:
            self.license.plan = data["type"]
            license_modified = True

        if license_modified:
            self.license.save()

        return self.license

    def _get_billing(self, organization: Organization) -> BillingStatus:
        """
        Retrieves billing info from the billing service and returns the decoded JSON body.

        Raises:
            Exception: if the manager has no license, or the service answers
                with an unexpected status code.
        """
        if not self.license:  # mypy
            raise Exception("No license found")

        res = requests.get(
            f"{BILLING_SERVICE_URL}/api/billing",
            headers=self.get_auth_headers(organization),
        )
        handle_billing_service_error(res)

        return res.json()

    def _get_stripe_portal_url(self, organization: Organization) -> str:
        """
        Retrieves stripe portal url

        Raises:
            Exception: if the manager has no license, or the service answers
                with an unexpected status code.
        """
        if not self.license:  # mypy
            raise Exception("No license found")

        res = requests.get(
            f"{BILLING_SERVICE_URL}/api/billing/portal",
            headers=self.get_auth_headers(organization),
        )

        handle_billing_service_error(res)

        data = res.json()

        return data["url"]

    def _get_products(self, organization: Optional[Organization]):
        """Fetch the ``standard`` plan's product list, authenticated only when a license and organization exist."""
        headers = {}
        params = {"plan": "standard"}

        if self.license and organization:
            headers = self.get_auth_headers(organization)

        res = requests.get(
            f"{BILLING_SERVICE_URL}/api/products-v2",
            params=params,
            headers=headers,
        )

        handle_billing_service_error(res)

        return res.json().get("products", [])

    def update_org_details(self, organization: Organization, billing_status: BillingStatus) -> Organization:
        """
        Ensure the relevant organization details are up-to-date locally

        Syncs customer id, usage summary (and quota limits), available
        features, ``never_drop_data`` and customer trust scores from
        ``billing_status["customer"]``. The organization is saved only when
        something changed.
        """
        org_modified = False

        data = billing_status["customer"]

        if data.get("customer_id") and organization.customer_id != data["customer_id"]:
            organization.customer_id = data["customer_id"]
            org_modified = True

        usage_summary = cast(dict, data.get("usage_summary"))
        if usage_summary:
            usage_info = OrganizationUsageInfo(
                events=usage_summary["events"],
                recordings=usage_summary["recordings"],
                rows_synced=usage_summary.get("rows_synced", None),
                period=[
                    data["billing_period"]["current_period_start"],
                    data["billing_period"]["current_period_end"],
                ],
            )

            if set_org_usage_summary(organization, new_usage=usage_info):
                org_modified = True
                sync_org_quota_limits(organization)

        available_features = data.get("available_features", None)
        if available_features and available_features != organization.available_features:
            organization.available_features = data["available_features"]
            org_modified = True

        available_product_features = data.get("available_product_features", None)
        if available_product_features and available_product_features != organization.available_product_features:
            organization.available_product_features = data["available_product_features"]
            org_modified = True

        never_drop_data = data.get("never_drop_data", None)
        if never_drop_data != organization.never_drop_data:
            organization.never_drop_data = never_drop_data
            org_modified = True

        customer_trust_scores = data.get("customer_trust_scores", {})

        # The billing service keys trust scores by product type; locally they are
        # stored by usage key, so translate through the product list.
        product_key_to_usage_key = {
            product["type"]: product["usage_key"]
            for product in (
                billing_status["customer"].get("products") or self.get_default_products(organization)["products"]
            )
        }
        org_customer_trust_scores = {
            product_key_to_usage_key[product_key]: score
            for product_key, score in customer_trust_scores.items()
            if product_key in product_key_to_usage_key
        }

        if org_customer_trust_scores != organization.customer_trust_scores:
            # Store the translated mapping -- the same one that was just compared.
            organization.customer_trust_scores = org_customer_trust_scores
            org_modified = True

        if org_modified:
            organization.save()

        return organization

    def get_auth_headers(self, organization: Organization):
        """Return the ``Authorization: Bearer <jwt>`` header for billing-service requests.

        Raises:
            Exception: if the manager has no license.
        """
        if not self.license:  # mypy
            raise Exception("No license found")
        billing_service_token = build_billing_token(self.license, organization)
        return {"Authorization": f"Bearer {billing_service_token}"}

    def get_invoices(self, organization: Organization, status: Optional[str]):
        """Fetch invoices from the billing service, passing ``status`` as a query parameter.

        Returns the decoded JSON body; raises on an unexpected status code.
        """
        res = requests.get(
            f"{BILLING_SERVICE_URL}/api/billing/get_invoices",
            params={"status": status},
            headers=self.get_auth_headers(organization),
        )

        handle_billing_service_error(res)

        return res.json()
|