0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-24 00:47:50 +01:00

chore: Add Pyupgrade rules (#21714)

* Add Pyupgrade rules
* Set correct Python version
This commit is contained in:
Julian Bez 2024-04-25 08:22:28 +01:00 committed by GitHub
parent db0aef207c
commit 9576fab1e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
523 changed files with 3100 additions and 3141 deletions

View File

@ -21,7 +21,6 @@
import argparse
import sys
from typing import List
from kafka import KafkaAdminClient, KafkaConsumer, KafkaProducer
from kafka.errors import KafkaError
@ -192,7 +191,7 @@ def handle(**options):
print("Polling for messages") # noqa: T201
messages_by_topic = consumer.poll(timeout_ms=timeout_ms)
futures: List[FutureRecordMetadata] = []
futures: list[FutureRecordMetadata] = []
if not messages_by_topic:
break

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Union
from typing import Any, Union
from django.core.exceptions import ValidationError as DjangoValidationError
from django.http.response import HttpResponse
@ -91,8 +91,8 @@ class MultitenantSAMLAuth(SAMLAuth):
def _get_attr(
self,
response_attributes: Dict[str, Any],
attribute_names: List[str],
response_attributes: dict[str, Any],
attribute_names: list[str],
optional: bool = False,
) -> str:
"""

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, cast
from typing import Any, cast
from django.db import IntegrityError
from rest_framework import exceptions, mixins, serializers, viewsets
@ -45,7 +45,7 @@ class DashboardCollaboratorSerializer(serializers.ModelSerializer, UserPermissio
]
read_only_fields = ["id", "dashboard_id", "user", "user"]
def validate(self, attrs: Dict[str, Any]) -> Dict[str, Any]:
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
dashboard: Dashboard = self.context["dashboard"]
dashboard_permissions = self.user_permissions.dashboard(dashboard)
if dashboard_permissions.effective_restriction_level <= Dashboard.RestrictionLevel.EVERYONE_IN_PROJECT_CAN_EDIT:
@ -96,7 +96,7 @@ class DashboardCollaboratorViewSet(
serializer_class = DashboardCollaboratorSerializer
filter_rewrite_rules = {"team_id": "dashboard__team_id"}
def get_serializer_context(self) -> Dict[str, Any]:
def get_serializer_context(self) -> dict[str, Any]:
context = super().get_serializer_context()
try:
context["dashboard"] = Dashboard.objects.get(id=context["dashboard_id"])

View File

@ -1,4 +1,4 @@
from typing import List, cast
from typing import cast
from django.db import IntegrityError
from rest_framework import mixins, serializers, viewsets
@ -76,7 +76,7 @@ class RoleSerializer(serializers.ModelSerializer):
return RoleMembershipSerializer(members, many=True).data
def get_associated_flags(self, role: Role):
associated_flags: List[dict] = []
associated_flags: list[dict] = []
role_access_objects = FeatureFlagRoleAccess.objects.filter(role=role).values_list("feature_flag_id")
flags = FeatureFlag.objects.filter(id__in=role_access_objects)

View File

@ -1,5 +1,5 @@
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import requests
from django.http import HttpRequest, JsonResponse
@ -9,8 +9,8 @@ from rest_framework.exceptions import ValidationError
from posthog.models.instance_setting import get_instance_settings
def get_sentry_stats(start_time: str, end_time: str) -> Tuple[dict, int]:
sentry_config: Dict[str, str] = get_instance_settings(["SENTRY_AUTH_TOKEN", "SENTRY_ORGANIZATION"])
def get_sentry_stats(start_time: str, end_time: str) -> tuple[dict, int]:
sentry_config: dict[str, str] = get_instance_settings(["SENTRY_AUTH_TOKEN", "SENTRY_ORGANIZATION"])
org_slug = sentry_config.get("SENTRY_ORGANIZATION")
token = sentry_config.get("SENTRY_AUTH_TOKEN")
@ -41,9 +41,9 @@ def get_sentry_stats(start_time: str, end_time: str) -> Tuple[dict, int]:
def get_tagged_issues_stats(
start_time: str, end_time: str, tags: Dict[str, str], target_issues: List[str]
) -> Dict[str, Any]:
sentry_config: Dict[str, str] = get_instance_settings(["SENTRY_AUTH_TOKEN", "SENTRY_ORGANIZATION"])
start_time: str, end_time: str, tags: dict[str, str], target_issues: list[str]
) -> dict[str, Any]:
sentry_config: dict[str, str] = get_instance_settings(["SENTRY_AUTH_TOKEN", "SENTRY_ORGANIZATION"])
org_slug = sentry_config.get("SENTRY_ORGANIZATION")
token = sentry_config.get("SENTRY_AUTH_TOKEN")
@ -58,7 +58,7 @@ def get_tagged_issues_stats(
for tag, value in tags.items():
query += f" {tag}:{value}"
params: Dict[str, Union[list, str]] = {
params: dict[str, Union[list, str]] = {
"start": start_time,
"end": end_time,
"sort": "freq",
@ -89,8 +89,8 @@ def get_stats_for_timerange(
base_end_time: str,
target_start_time: str,
target_end_time: str,
tags: Optional[Dict[str, str]] = None,
) -> Tuple[int, int]:
tags: Optional[dict[str, str]] = None,
) -> tuple[int, int]:
base_counts, base_total_count = get_sentry_stats(base_start_time, base_end_time)
target_counts, target_total_count = get_sentry_stats(target_start_time, target_end_time)

View File

@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any
import jwt
from django.db.models import QuerySet
@ -67,7 +67,7 @@ class SubscriptionSerializer(serializers.ModelSerializer):
return attrs
def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Subscription:
def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Subscription:
request = self.context["request"]
validated_data["team_id"] = self.context["team_id"]
validated_data["created_by"] = request.user

View File

@ -1,5 +1,5 @@
import datetime
from typing import Dict, Optional, cast
from typing import Optional, cast
from zoneinfo import ZoneInfo
@ -20,7 +20,7 @@ class LicensedTestMixin:
def license_required_response(
self,
message: str = "This feature is part of the premium PostHog offering. Self-hosted licenses are no longer available for purchase. Please contact sales@posthog.com to discuss options.",
) -> Dict[str, Optional[str]]:
) -> dict[str, Optional[str]]:
return {
"type": "server_error",
"code": "payment_required",

View File

@ -1,6 +1,6 @@
from typing import Any, Dict, List
from typing import Any
AVAILABLE_PRODUCT_FEATURES: List[Dict[str, Any]] = [
AVAILABLE_PRODUCT_FEATURES: list[dict[str, Any]] = [
{
"description": "Create playlists of certain session recordings to easily find and watch them again in the future.",
"key": "recordings_playlists",

View File

@ -364,7 +364,6 @@ class TestEESAMLAuthenticationAPI(APILicensedTest):
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
"r",
encoding="utf_8",
) as f:
saml_response = f.read()
@ -407,7 +406,6 @@ class TestEESAMLAuthenticationAPI(APILicensedTest):
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response_alt_attribute_names"),
"r",
encoding="utf_8",
) as f:
saml_response = f.read()
@ -474,7 +472,6 @@ YotAcSbU3p5bzd11wpyebYHB"""
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
"r",
encoding="utf_8",
) as f:
saml_response = f.read()
@ -514,7 +511,6 @@ YotAcSbU3p5bzd11wpyebYHB"""
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
"r",
encoding="utf_8",
) as f:
saml_response = f.read()
@ -552,7 +548,6 @@ YotAcSbU3p5bzd11wpyebYHB"""
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response_no_first_name"),
"r",
encoding="utf_8",
) as f:
saml_response = f.read()
@ -594,7 +589,6 @@ YotAcSbU3p5bzd11wpyebYHB"""
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
"r",
encoding="utf_8",
) as f:
saml_response = f.read()
@ -683,7 +677,6 @@ YotAcSbU3p5bzd11wpyebYHB"""
with open(
os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
"r",
encoding="utf_8",
) as f:
saml_response = f.read()

View File

@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, Dict, List
from typing import Any
from unittest.mock import MagicMock, patch
from uuid import uuid4
from zoneinfo import ZoneInfo
@ -22,7 +22,7 @@ from posthog.models.team import Team
from posthog.test.base import APIBaseTest, _create_event, flush_persons_and_events
def create_billing_response(**kwargs) -> Dict[str, Any]:
def create_billing_response(**kwargs) -> dict[str, Any]:
data: Any = {"license": {"type": "cloud"}}
data.update(kwargs)
return data
@ -106,7 +106,7 @@ def create_billing_customer(**kwargs) -> CustomerInfo:
return data
def create_billing_products_response(**kwargs) -> Dict[str, List[CustomerProduct]]:
def create_billing_products_response(**kwargs) -> dict[str, list[CustomerProduct]]:
data: Any = {
"products": [
CustomerProduct(

View File

@ -68,26 +68,26 @@ class TestCaptureAPI(APIBaseTest):
self.assertEqual(event2_data["properties"]["distinct_id"], "id2")
# Make sure we're producing data correctly in the way the plugin server expects
self.assertEquals(type(kafka_produce_call1["data"]["distinct_id"]), str)
self.assertEquals(type(kafka_produce_call2["data"]["distinct_id"]), str)
self.assertEqual(type(kafka_produce_call1["data"]["distinct_id"]), str)
self.assertEqual(type(kafka_produce_call2["data"]["distinct_id"]), str)
self.assertIn(type(kafka_produce_call1["data"]["ip"]), [str, type(None)])
self.assertIn(type(kafka_produce_call2["data"]["ip"]), [str, type(None)])
self.assertEquals(type(kafka_produce_call1["data"]["site_url"]), str)
self.assertEquals(type(kafka_produce_call2["data"]["site_url"]), str)
self.assertEqual(type(kafka_produce_call1["data"]["site_url"]), str)
self.assertEqual(type(kafka_produce_call2["data"]["site_url"]), str)
self.assertEquals(type(kafka_produce_call1["data"]["token"]), str)
self.assertEquals(type(kafka_produce_call2["data"]["token"]), str)
self.assertEqual(type(kafka_produce_call1["data"]["token"]), str)
self.assertEqual(type(kafka_produce_call2["data"]["token"]), str)
self.assertEquals(type(kafka_produce_call1["data"]["sent_at"]), str)
self.assertEquals(type(kafka_produce_call2["data"]["sent_at"]), str)
self.assertEqual(type(kafka_produce_call1["data"]["sent_at"]), str)
self.assertEqual(type(kafka_produce_call2["data"]["sent_at"]), str)
self.assertEquals(type(event1_data["properties"]), dict)
self.assertEquals(type(event2_data["properties"]), dict)
self.assertEqual(type(event1_data["properties"]), dict)
self.assertEqual(type(event2_data["properties"]), dict)
self.assertEquals(type(kafka_produce_call1["data"]["uuid"]), str)
self.assertEquals(type(kafka_produce_call2["data"]["uuid"]), str)
self.assertEqual(type(kafka_produce_call1["data"]["uuid"]), str)
self.assertEqual(type(kafka_produce_call2["data"]["uuid"]), str)
@patch("posthog.kafka_client.client._KafkaProducer.produce")
def test_capture_event_with_uuid_in_payload(self, kafka_produce):

View File

@ -106,7 +106,7 @@ class TestDashboardEnterpriseAPI(APILicensedTest):
response_data = response.json()
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEquals(
self.assertEqual(
response_data,
self.permission_denied_response(
"Only the dashboard owner and project admins have the restriction rights required to change the dashboard's restriction level."
@ -178,7 +178,7 @@ class TestDashboardEnterpriseAPI(APILicensedTest):
response_data = response.json()
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEquals(
self.assertEqual(
response_data,
self.permission_denied_response("You don't have edit permissions for this dashboard."),
)
@ -262,7 +262,7 @@ class TestDashboardEnterpriseAPI(APILicensedTest):
response_data = response.json()
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEquals(
self.assertEqual(
response_data,
self.permission_denied_response("You don't have edit permissions for this dashboard."),
)

View File

@ -1,5 +1,5 @@
from datetime import datetime
from typing import cast, Optional, List, Dict, Any
from typing import cast, Optional, Any
import dateutil.parser
from django.utils import timezone
@ -26,7 +26,7 @@ class TestEventDefinitionEnterpriseAPI(APIBaseTest):
Ignoring the verified field we'd expect ordering purchase, watched_movie, entered_free_trial, $pageview
With it we expect watched_movie, entered_free_trial, purchase, $pageview
"""
EXPECTED_EVENT_DEFINITIONS: List[Dict[str, Any]] = [
EXPECTED_EVENT_DEFINITIONS: list[dict[str, Any]] = [
{"name": "purchase", "verified": None},
{"name": "entered_free_trial", "verified": True},
{"name": "watched_movie", "verified": True},

View File

@ -1,6 +1,6 @@
import json
from datetime import timedelta
from typing import cast, Optional, List, Dict
from typing import cast, Optional
from django.test import override_settings
from django.utils import timezone
from freezegun import freeze_time
@ -305,7 +305,7 @@ class TestInsightEnterpriseAPI(APILicensedTest):
dashboard.refresh_from_db()
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEquals(
self.assertEqual(
response_data,
self.permission_denied_response(
"This insight is on a dashboard that can only be edited by its owner, team members invited to editing the dashboard, and project admins."
@ -547,7 +547,7 @@ class TestInsightEnterpriseAPI(APILicensedTest):
@override_settings(PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=False)
@snapshot_postgres_queries
def test_listing_insights_does_not_nplus1(self) -> None:
query_counts: List[int] = []
query_counts: list[int] = []
queries = []
for i in range(5):
@ -587,10 +587,10 @@ class TestInsightEnterpriseAPI(APILicensedTest):
f"received query counts\n\n{query_counts}",
)
def assert_insight_activity(self, insight_id: Optional[int], expected: List[Dict]):
def assert_insight_activity(self, insight_id: Optional[int], expected: list[dict]):
activity_response = self.dashboard_api.get_insight_activity(insight_id)
activity: List[Dict] = activity_response["results"]
activity: list[dict] = activity_response["results"]
self.maxDiff = None
assert activity == expected

View File

@ -25,7 +25,7 @@ class TestIntegration(APILicensedTest):
signature = (
"v0="
+ hmac.new(
"not-so-secret".encode("utf-8"),
b"not-so-secret",
sig_basestring.encode("utf-8"),
digestmod=hashlib.sha256,
).hexdigest()

View File

@ -1,4 +1,4 @@
from typing import cast, Optional, List, Dict
from typing import cast, Optional
from freezegun import freeze_time
import pytest
from django.db.utils import IntegrityError
@ -450,7 +450,7 @@ class TestPropertyDefinitionEnterpriseAPI(APIBaseTest):
plan="enterprise", valid_until=timezone.datetime(2500, 1, 19, 3, 14, 7)
)
properties: List[Dict] = [
properties: list[dict] = [
{"name": "1_when_verified", "verified": True},
{"name": "2_when_verified", "verified": True},
{"name": "3_when_verified", "verified": True},

View File

@ -1,6 +1,6 @@
import json
from dataclasses import asdict, dataclass, field
from typing import Any, List
from typing import Any
from unittest import mock
import pytest
@ -64,7 +64,7 @@ class TestTimeToSeeDataApi(APIBaseTest):
)
response = self.client.post("/api/time_to_see_data/sessions").json()
self.assertEquals(
self.assertEqual(
response,
[
{
@ -209,18 +209,18 @@ class QueryLogRow:
query_time_range_days: int = 1
has_joins: int = 0
has_json_operations: int = 0
filter_by_type: List[str] = field(default_factory=list)
breakdown_by: List[str] = field(default_factory=list)
entity_math: List[str] = field(default_factory=list)
filter_by_type: list[str] = field(default_factory=list)
breakdown_by: list[str] = field(default_factory=list)
entity_math: list[str] = field(default_factory=list)
filter: str = ""
ProfileEvents: dict = field(default_factory=dict)
tables: List[str] = field(default_factory=list)
columns: List[str] = field(default_factory=list)
tables: list[str] = field(default_factory=list)
columns: list[str] = field(default_factory=list)
query: str = ""
log_comment = ""
def insert(table: str, rows: List):
def insert(table: str, rows: list):
columns = asdict(rows[0]).keys()
all_values, params = [], {}

View File

@ -2,7 +2,6 @@
# Needs to be first to set up django environment
from .helpers import benchmark_clickhouse, no_materialized_columns, now
from datetime import timedelta
from typing import List, Tuple
from ee.clickhouse.materialized_columns.analyze import (
backfill_materialized_columns,
get_materialized_columns,
@ -29,7 +28,7 @@ from posthog.models.filters.filter import Filter
from posthog.models.property import PropertyName, TableWithProperties
from posthog.constants import FunnelCorrelationType
MATERIALIZED_PROPERTIES: List[Tuple[TableWithProperties, PropertyName]] = [
MATERIALIZED_PROPERTIES: list[tuple[TableWithProperties, PropertyName]] = [
("events", "$host"),
("events", "$current_url"),
("events", "$event_type"),

View File

@ -1,5 +1,5 @@
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, cast
from typing import Any, Optional, cast
import jwt
import requests
@ -53,7 +53,7 @@ class BillingManager:
def __init__(self, license):
self.license = license or get_cached_instance_license()
def get_billing(self, organization: Optional[Organization], plan_keys: Optional[str]) -> Dict[str, Any]:
def get_billing(self, organization: Optional[Organization], plan_keys: Optional[str]) -> dict[str, Any]:
if organization and self.license and self.license.is_v2_license:
billing_service_response = self._get_billing(organization)
@ -63,7 +63,7 @@ class BillingManager:
if organization and billing_service_response:
self.update_org_details(organization, billing_service_response)
response: Dict[str, Any] = {"available_features": []}
response: dict[str, Any] = {"available_features": []}
response["license"] = {"plan": self.license.plan}
@ -102,7 +102,7 @@ class BillingManager:
return response
def update_billing(self, organization: Organization, data: Dict[str, Any]) -> None:
def update_billing(self, organization: Organization, data: dict[str, Any]) -> None:
res = requests.patch(
f"{BILLING_SERVICE_URL}/api/billing/",
headers=self.get_auth_headers(organization),

View File

@ -1,5 +1,5 @@
from decimal import Decimal
from typing import Dict, List, Optional, TypedDict
from typing import Optional, TypedDict
from posthog.constants import AvailableFeature
@ -18,7 +18,7 @@ class CustomerProduct(TypedDict):
image_url: Optional[str]
type: str
free_allocation: int
tiers: List[Tier]
tiers: list[Tier]
tiered: bool
unit_amount_usd: Optional[Decimal]
current_amount_usd: Decimal
@ -51,16 +51,16 @@ class CustomerInfo(TypedDict):
deactivated: bool
has_active_subscription: bool
billing_period: BillingPeriod
available_features: List[AvailableFeature]
available_features: list[AvailableFeature]
current_total_amount_usd: Optional[str]
current_total_amount_usd_after_discount: Optional[str]
products: Optional[List[CustomerProduct]]
custom_limits_usd: Optional[Dict[str, str]]
usage_summary: Optional[Dict[str, Dict[str, Optional[int]]]]
products: Optional[list[CustomerProduct]]
custom_limits_usd: Optional[dict[str, str]]
usage_summary: Optional[dict[str, dict[str, Optional[int]]]]
free_trial_until: Optional[str]
discount_percent: Optional[int]
discount_amount_usd: Optional[str]
customer_trust_scores: Dict[str, int]
customer_trust_scores: dict[str, int]
class BillingStatus(TypedDict):

View File

@ -1,7 +1,8 @@
import copy
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, List, Mapping, Optional, Sequence, Tuple, TypedDict, cast
from typing import Optional, TypedDict, cast
from collections.abc import Mapping, Sequence
import dateutil.parser
import posthoganalytics
@ -66,13 +67,13 @@ def add_limited_team_tokens(resource: QuotaResource, tokens: Mapping[str, int],
redis_client.zadd(f"{cache_key}{resource.value}", tokens) # type: ignore # (zadd takes a Mapping[str, int] but the derived Union type is wrong)
def remove_limited_team_tokens(resource: QuotaResource, tokens: List[str], cache_key: QuotaLimitingCaches) -> None:
def remove_limited_team_tokens(resource: QuotaResource, tokens: list[str], cache_key: QuotaLimitingCaches) -> None:
redis_client = get_client()
redis_client.zrem(f"{cache_key}{resource.value}", *tokens)
@cache_for(timedelta(seconds=30), background_refresh=True)
def list_limited_team_attributes(resource: QuotaResource, cache_key: QuotaLimitingCaches) -> List[str]:
def list_limited_team_attributes(resource: QuotaResource, cache_key: QuotaLimitingCaches) -> list[str]:
now = timezone.now()
redis_client = get_client()
results = redis_client.zrangebyscore(f"{cache_key}{resource.value}", min=now.timestamp(), max="+inf")
@ -86,7 +87,7 @@ class UsageCounters(TypedDict):
def org_quota_limited_until(
organization: Organization, resource: QuotaResource, previously_quota_limited_team_tokens: List[str]
organization: Organization, resource: QuotaResource, previously_quota_limited_team_tokens: list[str]
) -> Optional[OrgQuotaLimitingInformation]:
if not organization.usage:
return None
@ -265,7 +266,7 @@ def sync_org_quota_limits(organization: Organization):
def get_team_attribute_by_quota_resource(organization: Organization, resource: QuotaResource):
if resource in [QuotaResource.EVENTS, QuotaResource.RECORDINGS]:
team_tokens: List[str] = [x for x in list(organization.teams.values_list("api_token", flat=True)) if x]
team_tokens: list[str] = [x for x in list(organization.teams.values_list("api_token", flat=True)) if x]
if not team_tokens:
capture_exception(Exception(f"quota_limiting: No team tokens found for organization: {organization.id}"))
@ -274,7 +275,7 @@ def get_team_attribute_by_quota_resource(organization: Organization, resource: Q
return team_tokens
if resource == QuotaResource.ROWS_SYNCED:
team_ids: List[str] = [x for x in list(organization.teams.values_list("id", flat=True)) if x]
team_ids: list[str] = [x for x in list(organization.teams.values_list("id", flat=True)) if x]
if not team_ids:
capture_exception(Exception(f"quota_limiting: No team ids found for organization: {organization.id}"))
@ -322,7 +323,7 @@ def set_org_usage_summary(
def update_all_org_billing_quotas(
dry_run: bool = False,
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
) -> tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]:
period = get_current_day()
period_start, period_end = period
@ -352,8 +353,8 @@ def update_all_org_billing_quotas(
)
)
todays_usage_report: Dict[str, UsageCounters] = {}
orgs_by_id: Dict[str, Organization] = {}
todays_usage_report: dict[str, UsageCounters] = {}
orgs_by_id: dict[str, Organization] = {}
# we iterate through all teams, and add their usage to the organization they belong to
for team in teams:
@ -373,12 +374,12 @@ def update_all_org_billing_quotas(
for field in team_report:
org_report[field] += team_report[field] # type: ignore
quota_limited_orgs: Dict[str, Dict[str, int]] = {x.value: {} for x in QuotaResource}
quota_limiting_suspended_orgs: Dict[str, Dict[str, int]] = {x.value: {} for x in QuotaResource}
quota_limited_orgs: dict[str, dict[str, int]] = {x.value: {} for x in QuotaResource}
quota_limiting_suspended_orgs: dict[str, dict[str, int]] = {x.value: {} for x in QuotaResource}
# Get the current quota limits so we can track to poshog if it changes
orgs_with_changes = set()
previously_quota_limited_team_tokens: Dict[str, List[str]] = {x.value: [] for x in QuotaResource}
previously_quota_limited_team_tokens: dict[str, list[str]] = {x.value: [] for x in QuotaResource}
for field in quota_limited_orgs:
previously_quota_limited_team_tokens[field] = list_limited_team_attributes(
@ -405,8 +406,8 @@ def update_all_org_billing_quotas(
elif quota_limited_until:
quota_limited_orgs[field][org_id] = quota_limited_until
quota_limited_teams: Dict[str, Dict[str, int]] = {x.value: {} for x in QuotaResource}
quota_limiting_suspended_teams: Dict[str, Dict[str, int]] = {x.value: {} for x in QuotaResource}
quota_limited_teams: dict[str, dict[str, int]] = {x.value: {} for x in QuotaResource}
quota_limiting_suspended_teams: dict[str, dict[str, int]] = {x.value: {} for x in QuotaResource}
# Convert the org ids to team tokens
for team in teams:

View File

@ -1,6 +1,7 @@
import re
from datetime import timedelta
from typing import Dict, Generator, List, Optional, Set, Tuple
from typing import Optional
from collections.abc import Generator
import structlog
@ -27,18 +28,18 @@ from posthog.models.property import PropertyName, TableColumn, TableWithProperti
from posthog.models.property_definition import PropertyDefinition
from posthog.models.team import Team
Suggestion = Tuple[TableWithProperties, TableColumn, PropertyName]
Suggestion = tuple[TableWithProperties, TableColumn, PropertyName]
logger = structlog.get_logger(__name__)
class TeamManager:
@instance_memoize
def person_properties(self, team_id: str) -> Set[str]:
def person_properties(self, team_id: str) -> set[str]:
return self._get_properties(GET_PERSON_PROPERTIES_COUNT, team_id)
@instance_memoize
def event_properties(self, team_id: str) -> Set[str]:
def event_properties(self, team_id: str) -> set[str]:
return set(
PropertyDefinition.objects.filter(team_id=team_id, type=PropertyDefinition.Type.EVENT).values_list(
"name", flat=True
@ -46,17 +47,17 @@ class TeamManager:
)
@instance_memoize
def person_on_events_properties(self, team_id: str) -> Set[str]:
def person_on_events_properties(self, team_id: str) -> set[str]:
return self._get_properties(GET_EVENT_PROPERTIES_COUNT.format(column_name="person_properties"), team_id)
@instance_memoize
def group_on_events_properties(self, group_type_index: int, team_id: str) -> Set[str]:
def group_on_events_properties(self, group_type_index: int, team_id: str) -> set[str]:
return self._get_properties(
GET_EVENT_PROPERTIES_COUNT.format(column_name=f"group{group_type_index}_properties"),
team_id,
)
def _get_properties(self, query, team_id) -> Set[str]:
def _get_properties(self, query, team_id) -> set[str]:
rows = sync_execute(query, {"team_id": team_id})
return {name for name, _ in rows}
@ -86,12 +87,12 @@ class Query:
return matches[0] if matches else None
@cached_property
def _all_properties(self) -> List[Tuple[str, PropertyName]]:
def _all_properties(self) -> list[tuple[str, PropertyName]]:
return re.findall(r"JSONExtract\w+\((\S+), '([^']+)'\)", self.query_string)
def properties(
self, team_manager: TeamManager
) -> Generator[Tuple[TableWithProperties, TableColumn, PropertyName], None, None]:
) -> Generator[tuple[TableWithProperties, TableColumn, PropertyName], None, None]:
# Reverse-engineer whether a property is an "event" or "person" property by getting their event definitions.
# :KLUDGE: Note that the same property will be found on both tables if both are used.
# We try to hone in on the right column by looking at the column from which the property is extracted.
@ -124,7 +125,7 @@ class Query:
yield "events", "group4_properties", property
def _analyze(since_hours_ago: int, min_query_time: int) -> List[Suggestion]:
def _analyze(since_hours_ago: int, min_query_time: int) -> list[Suggestion]:
"Finds columns that should be materialized"
raw_queries = sync_execute(
@ -179,7 +180,7 @@ LIMIT 100 -- Make sure we don't add 100s of columns in one run
def materialize_properties_task(
columns_to_materialize: Optional[List[Suggestion]] = None,
columns_to_materialize: Optional[list[Suggestion]] = None,
time_to_analyze_hours: int = MATERIALIZE_COLUMNS_ANALYSIS_PERIOD_HOURS,
maximum: int = MATERIALIZE_COLUMNS_MAX_AT_ONCE,
min_query_time: int = MATERIALIZE_COLUMNS_MINIMUM_QUERY_TIME,
@ -203,7 +204,7 @@ def materialize_properties_task(
else:
logger.info("Found no columns to materialize.")
properties: Dict[TableWithProperties, List[Tuple[PropertyName, TableColumn]]] = {
properties: dict[TableWithProperties, list[tuple[PropertyName, TableColumn]]] = {
"events": [],
"person": [],
}

View File

@ -1,6 +1,6 @@
import re
from datetime import timedelta
from typing import Dict, List, Literal, Tuple, Union, cast
from typing import Literal, Union, cast
from clickhouse_driver.errors import ServerException
from django.utils.timezone import now
@ -36,7 +36,7 @@ SHORT_TABLE_COLUMN_NAME = {
@cache_for(timedelta(minutes=15))
def get_materialized_columns(
table: TablesWithMaterializedColumns,
) -> Dict[Tuple[PropertyName, TableColumn], ColumnName]:
) -> dict[tuple[PropertyName, TableColumn], ColumnName]:
rows = sync_execute(
"""
SELECT comment, name
@ -141,7 +141,7 @@ def add_minmax_index(table: TablesWithMaterializedColumns, column_name: str):
def backfill_materialized_columns(
table: TableWithProperties,
properties: List[Tuple[PropertyName, TableColumn]],
properties: list[tuple[PropertyName, TableColumn]],
backfill_period: timedelta,
test_settings=None,
) -> None:
@ -215,7 +215,7 @@ def _materialized_column_name(
return f"{prefix}{property_str}{suffix}"
def _extract_property(comment: str) -> Tuple[PropertyName, TableColumn]:
def _extract_property(comment: str) -> tuple[PropertyName, TableColumn]:
# Old style comments have the format "column_materializer::property", dealing with the default table column.
# Otherwise, it's "column_materializer::table_column::property"
split_column = comment.split("::", 2)

View File

@ -1,5 +1,4 @@
import dataclasses
from typing import List
from posthog.client import sync_execute
from posthog.hogql.hogql import HogQLContext
@ -22,7 +21,7 @@ class MockEvent:
distinct_id: str
def _get_events_for_action(action: Action) -> List[MockEvent]:
def _get_events_for_action(action: Action) -> list[MockEvent]:
hogql_context = HogQLContext(team_id=action.team_id)
formatted_query, params = format_action_filter(
team_id=action.team_id, action=action, prepend="", hogql_context=hogql_context

View File

@ -1,5 +1,5 @@
from datetime import datetime
from typing import List, Literal, Union, cast
from typing import Literal, Union, cast
from uuid import UUID
import pytest
@ -43,7 +43,7 @@ from posthog.test.base import (
class TestPropFormat(ClickhouseTestMixin, BaseTest):
CLASS_DATA_LEVEL_SETUP = False
def _run_query(self, filter: Filter, **kwargs) -> List:
def _run_query(self, filter: Filter, **kwargs) -> list:
query, params = parse_prop_grouped_clauses(
property_group=filter.property_groups,
allow_denormalized_props=True,
@ -776,7 +776,7 @@ class TestPropFormat(ClickhouseTestMixin, BaseTest):
class TestPropDenormalized(ClickhouseTestMixin, BaseTest):
CLASS_DATA_LEVEL_SETUP = False
def _run_query(self, filter: Filter, join_person_tables=False) -> List:
def _run_query(self, filter: Filter, join_person_tables=False) -> list:
outer_properties = PropertyOptimizer().parse_property_groups(filter.property_groups).outer
query, params = parse_prop_grouped_clauses(
team_id=self.team.pk,
@ -1232,7 +1232,7 @@ TEST_BREAKDOWN_PROCESSING = [
@pytest.mark.parametrize("breakdown, table, query_alias, column, expected", TEST_BREAKDOWN_PROCESSING)
def test_breakdown_query_expression(
clean_up_materialised_columns,
breakdown: Union[str, List[str]],
breakdown: Union[str, list[str]],
table: TableWithProperties,
query_alias: Literal["prop", "value"],
column: str,
@ -1281,7 +1281,7 @@ TEST_BREAKDOWN_PROCESSING_MATERIALIZED = [
)
def test_breakdown_query_expression_materialised(
clean_up_materialised_columns,
breakdown: Union[str, List[str]],
breakdown: Union[str, list[str]],
table: TableWithProperties,
query_alias: Literal["prop", "value"],
column: str,
@ -1317,7 +1317,7 @@ def test_breakdown_query_expression_materialised(
@pytest.fixture
def test_events(db, team) -> List[UUID]:
def test_events(db, team) -> list[UUID]:
return [
_create_event(
event="$pageview",
@ -1958,7 +1958,7 @@ def test_combine_group_properties():
],
}
combined_group = PropertyGroup(PropertyOperatorType.AND, cast(List[Property], [])).combine_properties(
combined_group = PropertyGroup(PropertyOperatorType.AND, cast(list[Property], [])).combine_properties(
PropertyOperatorType.OR, [propertyC, propertyD]
)
assert combined_group.to_dict() == {

View File

@ -1,5 +1,5 @@
from typing import Counter as TCounter
from typing import Set, cast
from collections import Counter as TCounter
from typing import cast
from posthog.clickhouse.materialized_columns.column import ColumnName
from posthog.constants import TREND_FILTER_TYPE_ACTIONS, FunnelCorrelationType
@ -20,16 +20,16 @@ from posthog.queries.trends.util import is_series_group_based
class EnterpriseColumnOptimizer(FOSSColumnOptimizer):
@cached_property
def group_types_to_query(self) -> Set[GroupTypeIndex]:
def group_types_to_query(self) -> set[GroupTypeIndex]:
used_properties = self.used_properties_with_type("group")
return {cast(GroupTypeIndex, group_type_index) for _, _, group_type_index in used_properties}
@cached_property
def group_on_event_columns_to_query(self) -> Set[ColumnName]:
def group_on_event_columns_to_query(self) -> set[ColumnName]:
"Returns a list of event table group columns containing materialized properties that this query needs"
used_properties = self.used_properties_with_type("group")
columns_to_query: Set[ColumnName] = set()
columns_to_query: set[ColumnName] = set()
for group_type_index in range(5):
columns_to_query = columns_to_query.union(
@ -120,7 +120,7 @@ class EnterpriseColumnOptimizer(FOSSColumnOptimizer):
counter += get_action_tables_and_properties(entity.get_action())
if (
not isinstance(self.filter, (StickinessFilter, PropertiesTimelineFilter))
not isinstance(self.filter, StickinessFilter | PropertiesTimelineFilter)
and self.filter.correlation_type == FunnelCorrelationType.PROPERTIES
and self.filter.correlation_property_names
):

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple, cast
from typing import Any, cast
from posthog.constants import PropertyOperatorType
from posthog.models.cohort.util import get_count_operator
@ -15,18 +15,18 @@ from posthog.queries.util import PersonPropertiesMode
from posthog.schema import PersonsOnEventsMode
def check_negation_clause(prop: PropertyGroup) -> Tuple[bool, bool]:
def check_negation_clause(prop: PropertyGroup) -> tuple[bool, bool]:
has_negation_clause = False
has_primary_clase = False
if len(prop.values):
if isinstance(prop.values[0], PropertyGroup):
for p in cast(List[PropertyGroup], prop.values):
for p in cast(list[PropertyGroup], prop.values):
has_neg, has_primary = check_negation_clause(p)
has_negation_clause = has_negation_clause or has_neg
has_primary_clase = has_primary_clase or has_primary
else:
for property in cast(List[Property], prop.values):
for property in cast(list[Property], prop.values):
if property.negation:
has_negation_clause = True
else:
@ -42,7 +42,7 @@ def check_negation_clause(prop: PropertyGroup) -> Tuple[bool, bool]:
class EnterpriseCohortQuery(FOSSCohortQuery):
def get_query(self) -> Tuple[str, Dict[str, Any]]:
def get_query(self) -> tuple[str, dict[str, Any]]:
if not self._outer_property_groups:
# everything is pushed down, no behavioral stuff to do
# thus, use personQuery directly
@ -87,9 +87,9 @@ class EnterpriseCohortQuery(FOSSCohortQuery):
return final_query, self.params
def _get_condition_for_property(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]:
def _get_condition_for_property(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]:
res: str = ""
params: Dict[str, Any] = {}
params: dict[str, Any] = {}
if prop.type == "behavioral":
if prop.value == "performed_event":
@ -117,7 +117,7 @@ class EnterpriseCohortQuery(FOSSCohortQuery):
return res, params
def get_stopped_performing_event(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]:
def get_stopped_performing_event(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]:
event = (prop.event_type, prop.key)
column_name = f"stopped_event_condition_{prepend}_{idx}"
@ -152,7 +152,7 @@ class EnterpriseCohortQuery(FOSSCohortQuery):
},
)
def get_restarted_performing_event(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]:
def get_restarted_performing_event(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]:
event = (prop.event_type, prop.key)
column_name = f"restarted_event_condition_{prepend}_{idx}"
@ -191,7 +191,7 @@ class EnterpriseCohortQuery(FOSSCohortQuery):
},
)
def get_performed_event_first_time(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]:
def get_performed_event_first_time(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]:
event = (prop.event_type, prop.key)
entity_query, entity_params = self._get_entity(event, prepend, idx)
@ -212,7 +212,7 @@ class EnterpriseCohortQuery(FOSSCohortQuery):
{f"{date_param}": date_value, **entity_params},
)
def get_performed_event_regularly(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]:
def get_performed_event_regularly(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]:
event = (prop.event_type, prop.key)
entity_query, entity_params = self._get_entity(event, prepend, idx)
@ -266,7 +266,7 @@ class EnterpriseCohortQuery(FOSSCohortQuery):
)
@cached_property
def sequence_filters_to_query(self) -> List[Property]:
def sequence_filters_to_query(self) -> list[Property]:
props = []
for prop in self._filter.property_groups.flat:
if prop.value == "performed_event_sequence":
@ -274,13 +274,13 @@ class EnterpriseCohortQuery(FOSSCohortQuery):
return props
@cached_property
def sequence_filters_lookup(self) -> Dict[str, str]:
def sequence_filters_lookup(self) -> dict[str, str]:
lookup = {}
for idx, prop in enumerate(self.sequence_filters_to_query):
lookup[str(prop.to_dict())] = f"{idx}"
return lookup
def _get_sequence_query(self) -> Tuple[str, Dict[str, Any], str]:
def _get_sequence_query(self) -> tuple[str, dict[str, Any], str]:
params = {}
materialized_columns = list(self._column_optimizer.event_columns_to_query)
@ -356,7 +356,7 @@ class EnterpriseCohortQuery(FOSSCohortQuery):
self.FUNNEL_QUERY_ALIAS,
)
def _get_sequence_filter(self, prop: Property, idx: int) -> Tuple[List[str], List[str], List[str], Dict[str, Any]]:
def _get_sequence_filter(self, prop: Property, idx: int) -> tuple[list[str], list[str], list[str], dict[str, Any]]:
event = validate_entity((prop.event_type, prop.key))
entity_query, entity_params = self._get_entity(event, f"event_sequence_{self._cohort_pk}", idx)
seq_event = validate_entity((prop.seq_event_type, prop.seq_event))
@ -405,7 +405,7 @@ class EnterpriseCohortQuery(FOSSCohortQuery):
},
)
def get_performed_event_sequence(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]:
def get_performed_event_sequence(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]:
return (
f"{self.SEQUENCE_FIELD_ALIAS}_{self.sequence_filters_lookup[str(prop.to_dict())]}",
{},

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple, Union
from typing import Optional, Union
from ee.clickhouse.materialized_columns.columns import ColumnName
from ee.clickhouse.queries.column_optimizer import EnterpriseColumnOptimizer
@ -33,9 +33,9 @@ class EnterpriseEventQuery(EventQuery):
should_join_distinct_ids=False,
should_join_persons=False,
# Extra events/person table columns to fetch since parent query needs them
extra_fields: Optional[List[ColumnName]] = None,
extra_event_properties: Optional[List[PropertyName]] = None,
extra_person_fields: Optional[List[ColumnName]] = None,
extra_fields: Optional[list[ColumnName]] = None,
extra_event_properties: Optional[list[PropertyName]] = None,
extra_person_fields: Optional[list[ColumnName]] = None,
override_aggregate_users_by_distinct_id: Optional[bool] = None,
person_on_events_mode: PersonsOnEventsMode = PersonsOnEventsMode.disabled,
**kwargs,
@ -62,7 +62,7 @@ class EnterpriseEventQuery(EventQuery):
self._column_optimizer = EnterpriseColumnOptimizer(self._filter, self._team_id)
def _get_groups_query(self) -> Tuple[str, Dict]:
def _get_groups_query(self) -> tuple[str, dict]:
if isinstance(self._filter, PropertiesTimelineFilter):
raise Exception("Properties Timeline never needs groups query")
return GroupsJoinQuery(

View File

@ -1,7 +1,7 @@
from dataclasses import asdict, dataclass
from datetime import datetime
import json
from typing import List, Optional, Tuple, Type
from typing import Optional
from zoneinfo import ZoneInfo
from numpy.random import default_rng
@ -56,7 +56,7 @@ class ClickhouseFunnelExperimentResult:
feature_flag: FeatureFlag,
experiment_start_date: datetime,
experiment_end_date: Optional[datetime] = None,
funnel_class: Type[ClickhouseFunnel] = ClickhouseFunnel,
funnel_class: type[ClickhouseFunnel] = ClickhouseFunnel,
):
breakdown_key = f"$feature/{feature_flag.key}"
self.variants = [variant["key"] for variant in feature_flag.variants]
@ -148,9 +148,9 @@ class ClickhouseFunnelExperimentResult:
@staticmethod
def calculate_results(
control_variant: Variant,
test_variants: List[Variant],
priors: Tuple[int, int] = (1, 1),
) -> List[Probability]:
test_variants: list[Variant],
priors: tuple[int, int] = (1, 1),
) -> list[Probability]:
"""
Calculates probability that A is better than B. First variant is control, rest are test variants.
@ -186,9 +186,9 @@ class ClickhouseFunnelExperimentResult:
@staticmethod
def are_results_significant(
control_variant: Variant,
test_variants: List[Variant],
probabilities: List[Probability],
) -> Tuple[ExperimentSignificanceCode, Probability]:
test_variants: list[Variant],
probabilities: list[Probability],
) -> tuple[ExperimentSignificanceCode, Probability]:
def get_conversion_rate(variant: Variant):
return variant.success_count / (variant.success_count + variant.failure_count)
@ -226,7 +226,7 @@ class ClickhouseFunnelExperimentResult:
return ExperimentSignificanceCode.SIGNIFICANT, expected_loss
def calculate_expected_loss(target_variant: Variant, variants: List[Variant]) -> float:
def calculate_expected_loss(target_variant: Variant, variants: list[Variant]) -> float:
"""
Calculates expected loss in conversion rate for a given variant.
Loss calculation comes from VWO's SmartStats technical paper:
@ -268,7 +268,7 @@ def calculate_expected_loss(target_variant: Variant, variants: List[Variant]) ->
return loss / simulations_count
def simulate_winning_variant_for_conversion(target_variant: Variant, variants: List[Variant]) -> Probability:
def simulate_winning_variant_for_conversion(target_variant: Variant, variants: list[Variant]) -> Probability:
random_sampler = default_rng()
prior_success = 1
prior_failure = 1
@ -300,7 +300,7 @@ def simulate_winning_variant_for_conversion(target_variant: Variant, variants: L
return winnings / simulations_count
def calculate_probability_of_winning_for_each(variants: List[Variant]) -> List[Probability]:
def calculate_probability_of_winning_for_each(variants: list[Variant]) -> list[Probability]:
"""
Calculates the probability of winning for each variant.
"""

View File

@ -1,5 +1,5 @@
from datetime import datetime
from typing import Dict, Optional
from typing import Optional
from rest_framework.exceptions import ValidationError
from ee.clickhouse.queries.experiments.funnel_experiment_result import ClickhouseFunnelExperimentResult
@ -55,7 +55,7 @@ class ClickhouseSecondaryExperimentResult:
return {"result": variants, **significance_results}
def get_funnel_conversion_rate_for_variants(self, insight_results) -> Dict[str, float]:
def get_funnel_conversion_rate_for_variants(self, insight_results) -> dict[str, float]:
variants = {}
for result in insight_results:
total = result[0]["count"]
@ -67,7 +67,7 @@ class ClickhouseSecondaryExperimentResult:
return variants
def get_trend_count_data_for_variants(self, insight_results) -> Dict[str, float]:
def get_trend_count_data_for_variants(self, insight_results) -> dict[str, float]:
# this assumes the Trend insight is Cumulative, unless using count per user
variants = {}

View File

@ -1,7 +1,6 @@
import unittest
from functools import lru_cache
from math import exp, lgamma, log
from typing import List
from flaky import flaky
@ -31,7 +30,7 @@ def logbeta(x: int, y: int) -> float:
# calculation: https://www.evanmiller.org/bayesian-ab-testing.html#binary_ab
def calculate_probability_of_winning_for_target(target_variant: Variant, other_variants: List[Variant]) -> Probability:
def calculate_probability_of_winning_for_target(target_variant: Variant, other_variants: list[Variant]) -> Probability:
"""
Calculates the probability of winning for target variant.
"""
@ -455,7 +454,7 @@ class TestFunnelExperimentCalculator(unittest.TestCase):
# calculation: https://www.evanmiller.org/bayesian-ab-testing.html#count_ab
def calculate_probability_of_winning_for_target_count_data(
target_variant: CountVariant, other_variants: List[CountVariant]
target_variant: CountVariant, other_variants: list[CountVariant]
) -> Probability:
"""
Calculates the probability of winning for target variant.

View File

@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass
from datetime import datetime
from functools import lru_cache
from math import exp, lgamma, log
from typing import List, Optional, Tuple, Type
from typing import Optional
from zoneinfo import ZoneInfo
from numpy.random import default_rng
@ -78,7 +78,7 @@ class ClickhouseTrendExperimentResult:
feature_flag: FeatureFlag,
experiment_start_date: datetime,
experiment_end_date: Optional[datetime] = None,
trend_class: Type[Trends] = Trends,
trend_class: type[Trends] = Trends,
custom_exposure_filter: Optional[Filter] = None,
):
breakdown_key = f"$feature/{feature_flag.key}"
@ -316,7 +316,7 @@ class ClickhouseTrendExperimentResult:
return control_variant, test_variants
@staticmethod
def calculate_results(control_variant: Variant, test_variants: List[Variant]) -> List[Probability]:
def calculate_results(control_variant: Variant, test_variants: list[Variant]) -> list[Probability]:
"""
Calculates probability that A is better than B. First variant is control, rest are test variants.
@ -346,9 +346,9 @@ class ClickhouseTrendExperimentResult:
@staticmethod
def are_results_significant(
control_variant: Variant,
test_variants: List[Variant],
probabilities: List[Probability],
) -> Tuple[ExperimentSignificanceCode, Probability]:
test_variants: list[Variant],
probabilities: list[Probability],
) -> tuple[ExperimentSignificanceCode, Probability]:
# TODO: Experiment with Expected Loss calculations for trend experiments
for variant in test_variants:
@ -375,7 +375,7 @@ class ClickhouseTrendExperimentResult:
return ExperimentSignificanceCode.SIGNIFICANT, p_value
def simulate_winning_variant_for_arrival_rates(target_variant: Variant, variants: List[Variant]) -> float:
def simulate_winning_variant_for_arrival_rates(target_variant: Variant, variants: list[Variant]) -> float:
random_sampler = default_rng()
simulations_count = 100_000
@ -399,7 +399,7 @@ def simulate_winning_variant_for_arrival_rates(target_variant: Variant, variants
return winnings / simulations_count
def calculate_probability_of_winning_for_each(variants: List[Variant]) -> List[Probability]:
def calculate_probability_of_winning_for_each(variants: list[Variant]) -> list[Probability]:
"""
Calculates the probability of winning for each variant.
"""
@ -458,7 +458,7 @@ def poisson_p_value(control_count, control_exposure, test_count, test_exposure):
return min(1, 2 * min(low_p_value, high_p_value))
def calculate_p_value(control_variant: Variant, test_variants: List[Variant]) -> Probability:
def calculate_p_value(control_variant: Variant, test_variants: list[Variant]) -> Probability:
best_test_variant = max(test_variants, key=lambda variant: variant.count)
return poisson_p_value(

View File

@ -1,4 +1,4 @@
from typing import Set, Union
from typing import Union
from posthog.client import sync_execute
from posthog.constants import TREND_FILTER_TYPE_ACTIONS
@ -20,7 +20,7 @@ def requires_flag_warning(filter: Filter, team: Team) -> bool:
{parsed_date_to}
"""
events: Set[Union[int, str]] = set()
events: set[Union[int, str]] = set()
entities_to_use = filter.entities
for entity in entities_to_use:

View File

@ -2,12 +2,8 @@ import dataclasses
import urllib.parse
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Set,
Tuple,
TypedDict,
Union,
cast,
@ -40,7 +36,7 @@ from posthog.utils import generate_short_id
class EventDefinition(TypedDict):
event: str
properties: Dict[str, Any]
properties: dict[str, Any]
elements: list
@ -74,7 +70,7 @@ class FunnelCorrelationResponse(TypedDict):
queries, but we could use, for example, a dataclass
"""
events: List[EventOddsRatioSerialized]
events: list[EventOddsRatioSerialized]
skewed: bool
@ -153,7 +149,7 @@ class FunnelCorrelation:
)
@property
def properties_to_include(self) -> List[str]:
def properties_to_include(self) -> list[str]:
props_to_include = []
if (
self._team.person_on_events_mode != PersonsOnEventsMode.disabled
@ -203,7 +199,7 @@ class FunnelCorrelation:
return True
return False
def get_contingency_table_query(self) -> Tuple[str, Dict[str, Any]]:
def get_contingency_table_query(self) -> tuple[str, dict[str, Any]]:
"""
Returns a query string and params, which are used to generate the contingency table.
The query returns success and failure count for event / property values, along with total success and failure counts.
@ -216,7 +212,7 @@ class FunnelCorrelation:
return self.get_event_query()
def get_event_query(self) -> Tuple[str, Dict[str, Any]]:
def get_event_query(self) -> tuple[str, dict[str, Any]]:
funnel_persons_query, funnel_persons_params = self.get_funnel_actors_cte()
event_join_query = self._get_events_join_query()
@ -279,7 +275,7 @@ class FunnelCorrelation:
return query, params
def get_event_property_query(self) -> Tuple[str, Dict[str, Any]]:
def get_event_property_query(self) -> tuple[str, dict[str, Any]]:
if not self._filter.correlation_event_names:
raise ValidationError("Event Property Correlation expects atleast one event name to run correlation on")
@ -359,7 +355,7 @@ class FunnelCorrelation:
return query, params
def get_properties_query(self) -> Tuple[str, Dict[str, Any]]:
def get_properties_query(self) -> tuple[str, dict[str, Any]]:
if not self._filter.correlation_property_names:
raise ValidationError("Property Correlation expects atleast one Property to run correlation on")
@ -580,7 +576,7 @@ class FunnelCorrelation:
)
def _get_funnel_step_names(self):
events: Set[Union[int, str]] = set()
events: set[Union[int, str]] = set()
for entity in self._filter.entities:
if entity.type == TREND_FILTER_TYPE_ACTIONS:
action = entity.get_action()
@ -590,7 +586,7 @@ class FunnelCorrelation:
return sorted(events)
def _run(self) -> Tuple[List[EventOddsRatio], bool]:
def _run(self) -> tuple[list[EventOddsRatio], bool]:
"""
Run the diagnose query.
@ -834,7 +830,7 @@ class FunnelCorrelation:
).to_params()
return f"{self._base_uri}api/person/funnel/correlation?{urllib.parse.urlencode(params)}&cache_invalidation_key={cache_invalidation_key}"
def format_results(self, results: Tuple[List[EventOddsRatio], bool]) -> FunnelCorrelationResponse:
def format_results(self, results: tuple[list[EventOddsRatio], bool]) -> FunnelCorrelationResponse:
odds_ratios, skewed_totals = results
return {
"events": [self.serialize_event_odds_ratio(odds_ratio=odds_ratio) for odds_ratio in odds_ratios],
@ -847,7 +843,7 @@ class FunnelCorrelation:
return self.format_results(self._run())
def get_partial_event_contingency_tables(self) -> Tuple[List[EventContingencyTable], int, int]:
def get_partial_event_contingency_tables(self) -> tuple[list[EventContingencyTable], int, int]:
"""
For each event a person that started going through the funnel, gets stats
for how many of these users are sucessful and how many are unsuccessful.
@ -888,7 +884,7 @@ class FunnelCorrelation:
failure_total,
)
def get_funnel_actors_cte(self) -> Tuple[str, Dict[str, Any]]:
def get_funnel_actors_cte(self) -> tuple[str, dict[str, Any]]:
extra_fields = ["steps", "final_timestamp", "first_timestamp"]
for prop in self.properties_to_include:
@ -975,12 +971,12 @@ def get_entity_odds_ratio(event_contingency_table: EventContingencyTable, prior_
)
def build_selector(elements: List[Dict[str, Any]]) -> str:
def build_selector(elements: list[dict[str, Any]]) -> str:
# build a CSS select given an "elements_chain"
# NOTE: my source of what this should be doing is
# https://github.com/PostHog/posthog/blob/cc054930a47fb59940531e99a856add49a348ee5/frontend/src/scenes/events/createActionFromEvent.tsx#L36:L36
#
def element_to_selector(element: Dict[str, Any]) -> str:
def element_to_selector(element: dict[str, Any]) -> str:
if attr_id := element.get("attr_id"):
return f'[id="{attr_id}"]'

View File

@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Union
from typing import Optional, Union
from django.db.models.query import QuerySet
from rest_framework.exceptions import ValidationError
@ -52,9 +52,9 @@ class FunnelCorrelationActors(ActorBaseQuery):
def get_actors(
self,
) -> Tuple[
) -> tuple[
Union[QuerySet[Person], QuerySet[Group]],
Union[List[SerializedGroup], List[SerializedPerson]],
Union[list[SerializedGroup], list[SerializedPerson]],
int,
]:
if self._filter.correlation_type == FunnelCorrelationType.PROPERTIES:
@ -167,7 +167,7 @@ class _FunnelPropertyCorrelationActors(ActorBaseQuery):
def actor_query(
self,
limit_actors: Optional[bool] = True,
extra_fields: Optional[List[str]] = None,
extra_fields: Optional[list[str]] = None,
):
if not self._filter.correlation_property_values:
raise ValidationError("Property Correlation expects atleast one Property to get persons for")

View File

@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, Dict, List
from typing import Any
from posthog.constants import INSIGHT_FUNNELS
from posthog.models.filters import Filter
@ -51,8 +51,8 @@ def funnel_breakdown_group_test_factory(Funnel, FunnelPerson, _create_event, _cr
properties={"industry": "random"},
)
def _assert_funnel_breakdown_result_is_correct(self, result, steps: List[FunnelStepResult]):
def funnel_result(step: FunnelStepResult, order: int) -> Dict[str, Any]:
def _assert_funnel_breakdown_result_is_correct(self, result, steps: list[FunnelStepResult]):
def funnel_result(step: FunnelStepResult, order: int) -> dict[str, Any]:
return {
"action_id": step.name if step.type == "events" else step.action_id,
"name": step.name,

View File

@ -1,4 +1,4 @@
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Union
from ee.clickhouse.queries.column_optimizer import EnterpriseColumnOptimizer
from posthog.models import Filter
@ -35,7 +35,7 @@ class GroupsJoinQuery:
self._join_key = join_key
self._person_on_events_mode = person_on_events_mode
def get_join_query(self) -> Tuple[str, Dict]:
def get_join_query(self) -> tuple[str, dict]:
join_queries, params = [], {}
if self._person_on_events_mode != PersonsOnEventsMode.disabled and groups_on_events_querying_enabled():
@ -63,7 +63,7 @@ class GroupsJoinQuery:
return "\n".join(join_queries), params
def get_filter_query(self, group_type_index: GroupTypeIndex) -> Tuple[str, Dict]:
def get_filter_query(self, group_type_index: GroupTypeIndex) -> tuple[str, dict]:
var = f"group_index_{group_type_index}"
params = {
"team_id": self._team_id,

View File

@ -1,5 +1,5 @@
from re import escape
from typing import Dict, Literal, Optional, Tuple, Union, cast
from typing import Literal, Optional, Union, cast
from jsonschema import ValidationError
@ -34,8 +34,8 @@ class ClickhousePaths(Paths):
):
raise ValidationError("Max Edge weight can't be lower than min edge weight")
def get_edge_weight_clause(self) -> Tuple[str, Dict]:
params: Dict[str, int] = {}
def get_edge_weight_clause(self) -> tuple[str, dict]:
params: dict[str, int] = {}
conditions = []
@ -60,8 +60,8 @@ class ClickhousePaths(Paths):
else:
return ""
def get_target_clause(self) -> Tuple[str, Dict]:
params: Dict[str, Union[str, None]] = {
def get_target_clause(self) -> tuple[str, dict]:
params: dict[str, Union[str, None]] = {
"target_point": None,
"secondary_target_point": None,
}
@ -152,7 +152,7 @@ class ClickhousePaths(Paths):
else:
return "arraySlice"
def get_filtered_path_ordering(self) -> Tuple[str, ...]:
def get_filtered_path_ordering(self) -> tuple[str, ...]:
fields_to_include = ["filtered_path", "filtered_timings"] + [
f"filtered_{field}s" for field in self.extra_event_fields_and_properties
]

View File

@ -1,6 +1,6 @@
from datetime import timedelta
from functools import cached_property
from typing import List, Optional, Union
from typing import Optional, Union
from django.utils.timezone import now
@ -38,8 +38,8 @@ class RelatedActorsQuery:
self.group_type_index = validate_group_type_index("group_type_index", group_type_index)
self.id = id
def run(self) -> List[SerializedActor]:
results: List[SerializedActor] = []
def run(self) -> list[SerializedActor]:
results: list[SerializedActor] = []
results.extend(self._query_related_people())
for group_type_mapping in GroupTypeMapping.objects.filter(team_id=self.team.pk):
results.extend(self._query_related_groups(group_type_mapping.group_type_index))
@ -49,7 +49,7 @@ class RelatedActorsQuery:
def is_aggregating_by_groups(self) -> bool:
return self.group_type_index is not None
def _query_related_people(self) -> List[SerializedPerson]:
def _query_related_people(self) -> list[SerializedPerson]:
if not self.is_aggregating_by_groups:
return []
@ -72,7 +72,7 @@ class RelatedActorsQuery:
_, serialized_people = get_people(self.team, person_ids)
return serialized_people
def _query_related_groups(self, group_type_index: GroupTypeIndex) -> List[SerializedGroup]:
def _query_related_groups(self, group_type_index: GroupTypeIndex) -> list[SerializedGroup]:
if group_type_index == self.group_type_index:
return []
@ -102,7 +102,7 @@ class RelatedActorsQuery:
_, serialize_groups = get_groups(self.team.pk, group_type_index, group_ids)
return serialize_groups
def _take_first(self, rows: List) -> List:
def _take_first(self, rows: list) -> list:
return [row[0] for row in rows]
@property

View File

@ -1,5 +1,4 @@
from datetime import timedelta
from typing import Tuple
from unittest.mock import MagicMock
from uuid import UUID
@ -2905,7 +2904,7 @@ class TestClickhousePaths(ClickhouseTestMixin, APIBaseTest):
@snapshot_clickhouse_queries
def test_properties_queried_using_path_filter(self):
def should_query_list(filter) -> Tuple[bool, bool]:
def should_query_list(filter) -> tuple[bool, bool]:
path_query = PathEventQuery(filter, self.team)
return (path_query._should_query_url(), path_query._should_query_screen())

View File

@ -1,4 +1,5 @@
from typing import Any, Callable, Optional
from typing import Any, Optional
from collections.abc import Callable
from django.utils.timezone import now
from rest_framework import serializers, viewsets

View File

@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Dict, List, cast
from typing import cast
from django.db.models import Q
from drf_spectacular.types import OpenApiTypes
@ -34,7 +34,7 @@ class ClickhouseGroupsTypesView(TeamAndOrgViewSetMixin, mixins.ListModelMixin, v
@action(detail=False, methods=["PATCH"], name="Update group types metadata")
def update_metadata(self, request: request.Request, *args, **kwargs):
for row in cast(List[Dict], request.data):
for row in cast(list[dict], request.data):
instance = GroupTypeMapping.objects.get(team=self.team, group_type_index=row["group_type_index"])
serializer = self.get_serializer(instance, data=row)
serializer.is_valid(raise_exception=True)

View File

@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any
from rest_framework.decorators import action
from rest_framework.permissions import SAFE_METHODS, BasePermission
@ -47,7 +47,7 @@ class ClickhouseInsightsViewSet(InsightViewSet):
return Response(result)
@cached_by_filters
def calculate_funnel_correlation(self, request: Request) -> Dict[str, Any]:
def calculate_funnel_correlation(self, request: Request) -> dict[str, Any]:
team = self.team
filter = Filter(request=request, team=team)

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple
from typing import Optional
from rest_framework import request, response
from rest_framework.decorators import action
@ -28,7 +28,7 @@ class EnterprisePersonViewSet(PersonViewSet):
@cached_by_filters
def calculate_funnel_correlation_persons(
self, request: request.Request
) -> Dict[str, Tuple[List, Optional[str], Optional[str], int]]:
) -> dict[str, tuple[list, Optional[str], Optional[str], int]]:
filter = Filter(request=request, data={"insight": INSIGHT_FUNNELS}, team=self.team)
if not filter.correlation_person_limit:
filter = filter.shallow_clone({FUNNEL_CORRELATION_PERSON_LIMIT: 100})

View File

@ -552,15 +552,15 @@ class FunnelCorrelationTest(BaseTest):
),
)
(browser_correlation,) = [
(browser_correlation,) = (
correlation
for correlation in odds["result"]["events"]
if correlation["event"]["event"] == "$browser::1"
]
)
(notset_correlation,) = [
(notset_correlation,) = (
correlation for correlation in odds["result"]["events"] if correlation["event"]["event"] == "$browser::"
]
)
assert get_people_for_correlation_ok(client=self.client, correlation=browser_correlation) == {
"success": ["Person 2"],

View File

@ -1,5 +1,5 @@
import dataclasses
from typing import Any, Dict, Literal, Optional, TypedDict, Union
from typing import Any, Literal, Optional, TypedDict, Union
from django.test.client import Client
@ -12,7 +12,7 @@ class EventPattern(TypedDict, total=False):
id: str
type: Union[Literal["events"], Literal["actions"]]
order: int
properties: Dict[str, Any]
properties: dict[str, Any]
@dataclasses.dataclass
@ -46,7 +46,7 @@ def get_funnel(client: Client, team_id: int, request: FunnelRequest):
)
def get_funnel_ok(client: Client, team_id: int, request: FunnelRequest) -> Dict[str, Any]:
def get_funnel_ok(client: Client, team_id: int, request: FunnelRequest) -> dict[str, Any]:
response = get_funnel(client=client, team_id=team_id, request=request)
assert response.status_code == 200, response.content
@ -73,14 +73,14 @@ def get_funnel_correlation(client: Client, team_id: int, request: FunnelCorrelat
)
def get_funnel_correlation_ok(client: Client, team_id: int, request: FunnelCorrelationRequest) -> Dict[str, Any]:
def get_funnel_correlation_ok(client: Client, team_id: int, request: FunnelCorrelationRequest) -> dict[str, Any]:
response = get_funnel_correlation(client=client, team_id=team_id, request=request)
assert response.status_code == 200, response.content
return response.json()
def get_people_for_correlation_ok(client: Client, correlation: EventOddsRatioSerialized) -> Dict[str, Any]:
def get_people_for_correlation_ok(client: Client, correlation: EventOddsRatioSerialized) -> dict[str, Any]:
"""
Helper for getting people for a correlation. Note we keep checking to just
inclusion of name, to make the stable to changes in other people props.

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any
from flaky import flaky
@ -7,7 +7,7 @@ from posthog.models.signals import mute_selected_signals
from posthog.test.base import ClickhouseTestMixin, snapshot_clickhouse_queries
from posthog.test.test_journeys import journeys_for
DEFAULT_JOURNEYS_FOR_PAYLOAD: Dict[str, List[Dict[str, Any]]] = {
DEFAULT_JOURNEYS_FOR_PAYLOAD: dict[str, list[dict[str, Any]]] = {
# For a trend pageview metric
"person1": [
{

View File

@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass
from typing import List, Literal, Optional, TypedDict, Union
from typing import Literal, Optional, TypedDict, Union
from django.test.client import Client
@ -719,10 +719,10 @@ class RetentionRequest:
period: Union[Literal["Hour"], Literal["Day"], Literal["Week"], Literal["Month"]]
retention_type: Literal["retention_first_time", "retention"] # probably not an exhaustive list
breakdowns: Optional[List[Breakdown]] = None
breakdowns: Optional[list[Breakdown]] = None
breakdown_type: Optional[Literal["person", "event"]] = None
properties: Optional[List[PropertyFilter]] = None
properties: Optional[list[PropertyFilter]] = None
filter_test_accounts: Optional[str] = None
limit: Optional[int] = None
@ -734,26 +734,26 @@ class Value(TypedDict):
class Cohort(TypedDict):
values: List[Value]
values: list[Value]
date: str
label: str
class RetentionResponse(TypedDict):
result: List[Cohort]
result: list[Cohort]
class Person(TypedDict):
distinct_ids: List[str]
distinct_ids: list[str]
class RetentionTableAppearance(TypedDict):
person: Person
appearances: List[int]
appearances: list[int]
class RetentionTablePeopleResponse(TypedDict):
result: List[RetentionTableAppearance]
result: list[RetentionTableAppearance]
def get_retention_ok(client: Client, team_id: int, request: RetentionRequest) -> RetentionResponse:

View File

@ -1,7 +1,7 @@
import json
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from unittest.case import skip
from unittest.mock import ANY
@ -420,20 +420,20 @@ class TrendsRequest:
insight: Optional[str] = None
display: Optional[str] = None
compare: Optional[bool] = None
events: List[Dict[str, Any]] = field(default_factory=list)
properties: List[Dict[str, Any]] = field(default_factory=list)
events: list[dict[str, Any]] = field(default_factory=list)
properties: list[dict[str, Any]] = field(default_factory=list)
smoothing_intervals: Optional[int] = 1
refresh: Optional[bool] = False
@dataclass
class TrendsRequestBreakdown(TrendsRequest):
breakdown: Optional[Union[List[int], str]] = None
breakdown: Optional[Union[list[int], str]] = None
breakdown_type: Optional[str] = None
def get_trends(client, request: Union[TrendsRequestBreakdown, TrendsRequest], team: Team):
data: Dict[str, Any] = {
data: dict[str, Any] = {
"date_from": request.date_from,
"date_to": request.date_to,
"interval": request.interval,
@ -471,7 +471,7 @@ class NormalizedTrendResult:
def get_trends_time_series_ok(
client: Client, request: TrendsRequest, team: Team, with_order: bool = False
) -> Dict[str, Dict[str, NormalizedTrendResult]]:
) -> dict[str, dict[str, NormalizedTrendResult]]:
data = get_trends_ok(client=client, request=request, team=team)
res = {}
for item in data["result"]:
@ -491,7 +491,7 @@ def get_trends_time_series_ok(
return res
def get_trends_aggregate_ok(client: Client, request: TrendsRequest, team: Team) -> Dict[str, NormalizedTrendResult]:
def get_trends_aggregate_ok(client: Client, request: TrendsRequest, team: Team) -> dict[str, NormalizedTrendResult]:
data = get_trends_ok(client=client, request=request, team=team)
res = {}
for item in data["result"]:

View File

@ -1,6 +1,5 @@
# Generated by Django 3.0.7 on 2020-08-07 09:15
from typing import List
from django.db import migrations, models
@ -8,7 +7,7 @@ from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies: List = []
dependencies: list = []
operations = [
migrations.CreateModel(

View File

@ -1,5 +1,5 @@
# Generated by Django 3.2.5 on 2022-03-02 22:44
from typing import Any, List, Tuple
from typing import Any
from django.core.paginator import Paginator
from django.db import migrations
@ -19,7 +19,7 @@ def forwards(apps, schema_editor):
EnterpriseEventDefinition = apps.get_model("ee", "EnterpriseEventDefinition")
EnterprisePropertyDefinition = apps.get_model("ee", "EnterprisePropertyDefinition")
createables: List[Tuple[Any, Any]] = []
createables: list[tuple[Any, Any]] = []
batch_size = 1_000
# Collect event definition tags and taggeditems

View File

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional
from django.contrib.auth import get_user_model
from django.db import models
@ -85,7 +85,7 @@ class License(models.Model):
PLAN_TO_SORTING_VALUE = {SCALE_PLAN: 10, ENTERPRISE_PLAN: 20}
@property
def available_features(self) -> List[AvailableFeature]:
def available_features(self) -> list[AvailableFeature]:
return self.PLANS.get(self.plan, [])
@property

View File

@ -1,6 +1,5 @@
from django.conf import settings
from typing import List
from posthog.models import Team
from posthog.clickhouse.client import sync_execute
@ -9,7 +8,7 @@ BATCH_FLUSH_SIZE = settings.REPLAY_EMBEDDINGS_BATCH_SIZE
MIN_DURATION_INCLUDE_SECONDS = settings.REPLAY_EMBEDDINGS_MIN_DURATION_SECONDS
def fetch_errors_by_session_without_embeddings(team_id: int, offset=0) -> List[str]:
def fetch_errors_by_session_without_embeddings(team_id: int, offset=0) -> list[str]:
query = """
WITH embedded_sessions AS (
SELECT
@ -47,7 +46,7 @@ def fetch_errors_by_session_without_embeddings(team_id: int, offset=0) -> List[s
)
def fetch_recordings_without_embeddings(team_id: int, offset=0) -> List[str]:
def fetch_recordings_without_embeddings(team_id: int, offset=0) -> list[str]:
team = Team.objects.get(id=team_id)
query = """

View File

@ -3,7 +3,7 @@ import tiktoken
import datetime
import pytz
from typing import Dict, Any, List, Tuple
from typing import Any
from abc import ABC, abstractmethod
from prometheus_client import Histogram, Counter
@ -88,7 +88,7 @@ class EmbeddingPreparation(ABC):
@staticmethod
@abstractmethod
def prepare(item, team) -> Tuple[str, str]:
def prepare(item, team) -> tuple[str, str]:
raise NotImplementedError()
@ -100,7 +100,7 @@ class SessionEmbeddingsRunner(ABC):
self.team = team
self.openai_client = OpenAI()
def run(self, items: List[Any], embeddings_preparation: type[EmbeddingPreparation]) -> None:
def run(self, items: list[Any], embeddings_preparation: type[EmbeddingPreparation]) -> None:
source_type = embeddings_preparation.source_type
try:
@ -196,7 +196,7 @@ class SessionEmbeddingsRunner(ABC):
"""Returns the number of tokens in a text string."""
return len(encoding.encode(string))
def _flush_embeddings_to_clickhouse(self, embeddings: List[Dict[str, Any]], source_type: str) -> None:
def _flush_embeddings_to_clickhouse(self, embeddings: list[dict[str, Any]], source_type: str) -> None:
try:
sync_execute(
"INSERT INTO session_replay_embeddings (session_id, team_id, embeddings, source_type, input) VALUES",
@ -213,7 +213,7 @@ class ErrorEmbeddingsPreparation(EmbeddingPreparation):
source_type = "error"
@staticmethod
def prepare(item: Tuple[str, str], _):
def prepare(item: tuple[str, str], _):
session_id = item[0]
error_message = item[1]
return session_id, error_message
@ -286,7 +286,7 @@ class SessionEventsEmbeddingsPreparation(EmbeddingPreparation):
return session_id, input
@staticmethod
def _compact_result(event_name: str, current_url: int, elements_chain: Dict[str, str] | str) -> str:
def _compact_result(event_name: str, current_url: int, elements_chain: dict[str, str] | str) -> str:
elements_string = (
elements_chain if isinstance(elements_chain, str) else ", ".join(str(e) for e in elements_chain)
)

View File

@ -1,7 +1,7 @@
import dataclasses
from datetime import datetime
from typing import List, Dict, Any
from typing import Any
from posthog.models.element import chain_to_elements
from hashlib import shake_256
@ -12,11 +12,11 @@ class SessionSummaryPromptData:
# we may allow customisation of columns included in the future,
# and we alter the columns present as we process the data
# so want to stay as loose as possible here
columns: List[str] = dataclasses.field(default_factory=list)
results: List[List[Any]] = dataclasses.field(default_factory=list)
columns: list[str] = dataclasses.field(default_factory=list)
results: list[list[Any]] = dataclasses.field(default_factory=list)
# in order to reduce the number of tokens in the prompt
# we replace URLs with a placeholder and then pass this mapping of placeholder to URL into the prompt
url_mapping: Dict[str, str] = dataclasses.field(default_factory=dict)
url_mapping: dict[str, str] = dataclasses.field(default_factory=dict)
def is_empty(self) -> bool:
return not self.columns or not self.results
@ -63,7 +63,7 @@ def simplify_window_id(session_events: SessionSummaryPromptData) -> SessionSumma
# find window_id column index
window_id_index = session_events.column_index("$window_id")
window_id_mapping: Dict[str, int] = {}
window_id_mapping: dict[str, int] = {}
simplified_results = []
for result in session_events.results:
if window_id_index is None:
@ -128,7 +128,7 @@ def deduplicate_urls(session_events: SessionSummaryPromptData) -> SessionSummary
# find url column index
url_index = session_events.column_index("$current_url")
url_mapping: Dict[str, str] = {}
url_mapping: dict[str, str] = {}
deduplicated_results = []
for result in session_events.results:
if url_index is None:

View File

@ -1,5 +1,4 @@
from itertools import product
from typing import Dict
from unittest import mock
from uuid import uuid4
@ -131,7 +130,7 @@ class TestClickhouseSessionRecordingsListFromSessionReplay(ClickhouseTestMixin,
poe_v2: bool,
allow_denormalized_props: bool,
expected_poe_mode: PersonsOnEventsMode,
expected_query_params: Dict,
expected_query_params: dict,
unmaterialized_person_column_used: bool,
materialized_event_column_used: bool,
) -> None:

View File

@ -1,5 +1,5 @@
import json
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import structlog
from django.db.models import Q, QuerySet
@ -49,7 +49,7 @@ def log_playlist_activity(
team_id: int,
user: User,
was_impersonated: bool,
changes: Optional[List[Change]] = None,
changes: Optional[list[Change]] = None,
) -> None:
"""
Insight id and short_id are passed separately as some activities (like delete) alter the Insight instance
@ -101,7 +101,7 @@ class SessionRecordingPlaylistSerializer(serializers.ModelSerializer):
created_by = UserBasicSerializer(read_only=True)
last_modified_by = UserBasicSerializer(read_only=True)
def create(self, validated_data: Dict, *args, **kwargs) -> SessionRecordingPlaylist:
def create(self, validated_data: dict, *args, **kwargs) -> SessionRecordingPlaylist:
request = self.context["request"]
team = self.context["get_team"]()
@ -128,7 +128,7 @@ class SessionRecordingPlaylistSerializer(serializers.ModelSerializer):
return playlist
def update(self, instance: SessionRecordingPlaylist, validated_data: Dict, **kwargs) -> SessionRecordingPlaylist:
def update(self, instance: SessionRecordingPlaylist, validated_data: dict, **kwargs) -> SessionRecordingPlaylist:
try:
before_update = SessionRecordingPlaylist.objects.get(pk=instance.id)
except SessionRecordingPlaylist.DoesNotExist:

View File

@ -103,7 +103,7 @@ class TestSessionRecordingExtensions(ClickhouseTestMixin, APIBaseTest):
for file in ["a", "b", "c"]:
blob_path = f"{TEST_BUCKET}/team_id/{self.team.pk}/session_id/{session_id}/data"
file_name = f"{blob_path}/{file}"
write(file_name, f"my content-{file}".encode("utf-8"))
write(file_name, f"my content-{file}".encode())
recording: SessionRecording = SessionRecording.objects.create(team=self.team, session_id=session_id)
@ -164,7 +164,7 @@ class TestSessionRecordingExtensions(ClickhouseTestMixin, APIBaseTest):
mock_write.assert_called_with(
f"{expected_path}/12345000-12346000",
gzip.compress("the new content".encode("utf-8")),
gzip.compress(b"the new content"),
extras={
"ContentEncoding": "gzip",
"ContentType": "application/json",

View File

@ -3,14 +3,13 @@ Django settings for PostHog Enterprise Edition.
"""
import os
from typing import Dict, List
from posthog.settings import AUTHENTICATION_BACKENDS, DEMO, SITE_URL, DEBUG
from posthog.settings.utils import get_from_env
from posthog.utils import str_to_bool
# Zapier REST hooks
HOOK_EVENTS: Dict[str, str] = {
HOOK_EVENTS: dict[str, str] = {
# "event_name": "App.Model.Action" (created/updated/deleted)
"action_performed": "posthog.Action.performed",
}
@ -43,7 +42,7 @@ SOCIAL_AUTH_SAML_SUPPORT_CONTACT = SOCIAL_AUTH_SAML_TECHNICAL_CONTACT
SOCIAL_AUTH_GOOGLE_OAUTH2_KEY = os.getenv("SOCIAL_AUTH_GOOGLE_OAUTH2_KEY")
SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET = os.getenv("SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET")
if "SOCIAL_AUTH_GOOGLE_OAUTH2_WHITELISTED_DOMAINS" in os.environ:
SOCIAL_AUTH_GOOGLE_OAUTH2_WHITELISTED_DOMAINS: List[str] = os.environ[
SOCIAL_AUTH_GOOGLE_OAUTH2_WHITELISTED_DOMAINS: list[str] = os.environ[
"SOCIAL_AUTH_GOOGLE_OAUTH2_WHITELISTED_DOMAINS"
].split(",")
elif DEMO:

View File

@ -1,5 +1,4 @@
from datetime import datetime, timedelta
from typing import Dict
from zoneinfo import ZoneInfo
from celery import shared_task
@ -30,7 +29,7 @@ def check_feature_flag_rollback_conditions(feature_flag_id: int) -> None:
flag.save()
def calculate_rolling_average(threshold_metric: Dict, team: Team, timezone: str) -> float:
def calculate_rolling_average(threshold_metric: dict, team: Team, timezone: str) -> float:
curr = datetime.now(tz=ZoneInfo(timezone))
rolling_average_days = 7
@ -54,7 +53,7 @@ def calculate_rolling_average(threshold_metric: Dict, team: Team, timezone: str)
return sum(data) / rolling_average_days
def check_condition(rollback_condition: Dict, feature_flag: FeatureFlag) -> bool:
def check_condition(rollback_condition: dict, feature_flag: FeatureFlag) -> bool:
if rollback_condition["threshold_type"] == "sentry":
created_date = feature_flag.created_at
base_start_date = created_date.strftime("%Y-%m-%dT%H:%M:%S")

View File

@ -1,4 +1,4 @@
from typing import Any, List
from typing import Any
import structlog
from celery import shared_task
@ -25,7 +25,7 @@ logger = structlog.get_logger(__name__)
# we currently are allowed 500 calls per minute, so let's rate limit each worker
# to much less than that
@shared_task(ignore_result=False, queue=CeleryQueue.SESSION_REPLAY_EMBEDDINGS.value, rate_limit="75/m")
def embed_batch_of_recordings_task(recordings: List[Any], team_id: int) -> None:
def embed_batch_of_recordings_task(recordings: list[Any], team_id: int) -> None:
try:
team = Team.objects.get(id=team_id)
runner = SessionEmbeddingsRunner(team=team)

View File

@ -1,5 +1,5 @@
import re
from typing import Any, Dict
from typing import Any
from urllib.parse import urlparse
import structlog
@ -16,7 +16,7 @@ logger = structlog.get_logger(__name__)
SHARED_LINK_REGEX = r"\/(?:shared_dashboard|shared|embedded)\/(.+)"
def _block_for_asset(asset: ExportedAsset) -> Dict:
def _block_for_asset(asset: ExportedAsset) -> dict:
image_url = asset.get_public_content_url()
alt_text = None
if asset.insight:

View File

@ -1,5 +1,5 @@
import uuid
from typing import List, Optional
from typing import Optional
import structlog
@ -15,7 +15,7 @@ logger = structlog.get_logger(__name__)
def send_email_subscription_report(
email: str,
subscription: Subscription,
assets: List[ExportedAsset],
assets: list[ExportedAsset],
invite_message: Optional[str] = None,
total_asset_count: Optional[int] = None,
) -> None:

View File

@ -1,5 +1,3 @@
from typing import Dict, List
import structlog
from django.conf import settings
@ -12,7 +10,7 @@ logger = structlog.get_logger(__name__)
UTM_TAGS_BASE = "utm_source=posthog&utm_campaign=subscription_report"
def _block_for_asset(asset: ExportedAsset) -> Dict:
def _block_for_asset(asset: ExportedAsset) -> dict:
image_url = asset.get_public_content_url()
alt_text = None
if asset.insight:
@ -26,7 +24,7 @@ def _block_for_asset(asset: ExportedAsset) -> Dict:
def send_slack_subscription_report(
subscription: Subscription,
assets: List[ExportedAsset],
assets: list[ExportedAsset],
total_asset_count: int,
is_new_subscription: bool = False,
) -> None:

View File

@ -1,5 +1,5 @@
import datetime
from typing import List, Tuple, Union
from typing import Union
from django.conf import settings
import structlog
from celery import chain
@ -28,7 +28,7 @@ SUBSCRIPTION_ASSET_GENERATION_TIMER = Histogram(
def generate_assets(
resource: Union[Subscription, SharingConfiguration],
max_asset_count: int = DEFAULT_MAX_ASSET_COUNT,
) -> Tuple[List[Insight], List[ExportedAsset]]:
) -> tuple[list[Insight], list[ExportedAsset]]:
with SUBSCRIPTION_ASSET_GENERATION_TIMER.time():
if resource.dashboard:
tiles = get_tiles_ordered_by_position(resource.dashboard)

View File

@ -1,5 +1,4 @@
from datetime import datetime
from typing import List
from unittest.mock import MagicMock, call, patch
from zoneinfo import ZoneInfo
@ -25,10 +24,10 @@ from posthog.test.base import APIBaseTest
@patch("ee.tasks.subscriptions.generate_assets")
@freeze_time("2022-02-02T08:55:00.000Z")
class TestSubscriptionsTasks(APIBaseTest):
subscriptions: List[Subscription] = None # type: ignore
subscriptions: list[Subscription] = None # type: ignore
dashboard: Dashboard
insight: Insight
tiles: List[DashboardTile] = None # type: ignore
tiles: list[DashboardTile] = None # type: ignore
asset: ExportedAsset
def setUp(self) -> None:

View File

@ -1,4 +1,3 @@
from typing import List
from unittest.mock import MagicMock, patch
import pytest
@ -21,7 +20,7 @@ class TestSubscriptionsTasksUtils(APIBaseTest):
dashboard: Dashboard
insight: Insight
asset: ExportedAsset
tiles: List[DashboardTile]
tiles: list[DashboardTile]
def setUp(self) -> None:
self.dashboard = Dashboard.objects.create(team=self.team, name="private dashboard", created_by=self.user)

View File

@ -1,4 +1,3 @@
from typing import List
from unittest.mock import MagicMock, patch
from freezegun import freeze_time
@ -14,7 +13,7 @@ from posthog.models.subscription import Subscription
from posthog.test.base import APIBaseTest
def create_mock_unfurl_event(team_id: str, links: List[str]):
def create_mock_unfurl_event(team_id: str, links: list[str]):
return {
"token": "XXYYZZ",
"team_id": team_id,

View File

@ -1,4 +1,4 @@
from typing import Any, List
from typing import Any
from django.conf import settings
from django.contrib import admin
@ -92,7 +92,7 @@ admin_urlpatterns = (
)
urlpatterns: List[Any] = [
urlpatterns: list[Any] = [
path("api/saml/metadata/", authentication.saml_metadata_view),
path("api/sentry_stats/", sentry_stats.sentry_stats),
*admin_urlpatterns,

View File

@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import logging
import os

View File

@ -1,5 +1,5 @@
import re
from typing import List, Any, Dict
from typing import Any
from hogvm.python.operation import Operation, HOGQL_BYTECODE_IDENTIFIER
@ -33,7 +33,7 @@ def to_concat_arg(arg) -> str:
return str(arg)
def execute_bytecode(bytecode: List[Any], fields: Dict[str, Any]) -> Any:
def execute_bytecode(bytecode: list[Any], fields: dict[str, Any]) -> Any:
try:
stack = []
iterator = iter(bytecode)

View File

@ -53,7 +53,6 @@ import argparse
import json
import uuid
from sys import stderr, stdout
from typing import List
import numpy
from faker import Faker
@ -144,7 +143,7 @@ def get_parser():
def chunked(
data: str,
chunk_size: int,
) -> List[str]:
) -> list[str]:
return [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)]

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, cast
from typing import Any, cast
from django.db.models import Count, Prefetch
from rest_framework import request, serializers, viewsets
@ -123,7 +123,7 @@ class ActionSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedModelSe
return instance
def update(self, instance: Any, validated_data: Dict[str, Any]) -> Any:
def update(self, instance: Any, validated_data: dict[str, Any]) -> Any:
steps = validated_data.pop("steps", None)
# If there's no steps property at all we just ignore it
# If there is a step property but it's an empty array [], we'll delete all the steps
@ -182,7 +182,7 @@ class ActionViewSet(
def list(self, request: request.Request, *args: Any, **kwargs: Any) -> Response:
actions = self.get_queryset()
actions_list: List[Dict[Any, Any]] = self.serializer_class(
actions_list: list[dict[Any, Any]] = self.serializer_class(
actions, many=True, context={"request": request}
).data # type: ignore
return Response({"results": actions_list})

View File

@ -1,5 +1,5 @@
import time
from typing import Any, Optional, Dict
from typing import Any, Optional
from django.db.models import Q, QuerySet
@ -49,7 +49,7 @@ class ActivityLogPagination(pagination.CursorPagination):
# context manager for gathering a sequence of server timings
class ServerTimingsGathered:
# Class level dictionary to store timings
timings_dict: Dict[str, float] = {}
timings_dict: dict[str, float] = {}
def __call__(self, name):
self.name = name

View File

@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any
from django.db.models import Q, QuerySet
from django.db.models.signals import post_save
@ -40,11 +40,11 @@ class AnnotationSerializer(serializers.ModelSerializer):
"updated_at",
]
def update(self, instance: Annotation, validated_data: Dict[str, Any]) -> Annotation:
def update(self, instance: Annotation, validated_data: dict[str, Any]) -> Annotation:
instance.team_id = self.context["team_id"]
return super().update(instance, validated_data)
def create(self, validated_data: Dict[str, Any], *args: Any, **kwargs: Any) -> Annotation:
def create(self, validated_data: dict[str, Any], *args: Any, **kwargs: Any) -> Annotation:
request = self.context["request"]
team = self.context["get_team"]()
annotation = Annotation.objects.create(

View File

@ -1,6 +1,6 @@
import datetime
import time
from typing import Any, Dict, Optional, cast
from typing import Any, Optional, cast
from uuid import uuid4
from django.conf import settings
@ -92,7 +92,7 @@ class LoginSerializer(serializers.Serializer):
email = serializers.EmailField()
password = serializers.CharField()
def to_representation(self, instance: Any) -> Dict[str, Any]:
def to_representation(self, instance: Any) -> dict[str, Any]:
return {"success": True}
def _check_if_2fa_required(self, user: User) -> bool:
@ -113,7 +113,7 @@ class LoginSerializer(serializers.Serializer):
pass
return True
def create(self, validated_data: Dict[str, str]) -> Any:
def create(self, validated_data: dict[str, str]) -> Any:
# Check SSO enforcement (which happens at the domain level)
sso_enforcement = OrganizationDomain.objects.get_sso_enforcement_for_email_address(validated_data["email"])
if sso_enforcement:
@ -159,10 +159,10 @@ class LoginSerializer(serializers.Serializer):
class LoginPrecheckSerializer(serializers.Serializer):
email = serializers.EmailField()
def to_representation(self, instance: Dict[str, str]) -> Dict[str, Any]:
def to_representation(self, instance: dict[str, str]) -> dict[str, Any]:
return instance
def create(self, validated_data: Dict[str, str]) -> Any:
def create(self, validated_data: dict[str, str]) -> Any:
email = validated_data.get("email", "")
# TODO: Refactor methods below to remove duplicate queries
return {

View File

@ -18,7 +18,8 @@ from sentry_sdk import configure_scope
from sentry_sdk.api import capture_exception, start_span
from statshog.defaults.django import statsd
from token_bucket import Limiter, MemoryStorage
from typing import Any, Dict, Iterator, List, Optional, Tuple, Set
from typing import Any, Optional
from collections.abc import Iterator
from ee.billing.quota_limiting import QuotaLimitingCaches
from posthog.api.utils import get_data, get_token, safe_clickhouse_string
@ -129,12 +130,12 @@ def build_kafka_event_data(
distinct_id: str,
ip: Optional[str],
site_url: str,
data: Dict,
data: dict,
now: datetime,
sent_at: Optional[datetime],
event_uuid: UUIDT,
token: str,
) -> Dict:
) -> dict:
logger.debug("build_kafka_event_data", token=token)
return {
"uuid": str(event_uuid),
@ -168,10 +169,10 @@ def _kafka_topic(event_name: str, historical: bool = False, overflowing: bool =
def log_event(
data: Dict,
data: dict,
event_name: str,
partition_key: Optional[str],
headers: Optional[List] = None,
headers: Optional[list] = None,
historical: bool = False,
overflowing: bool = False,
) -> FutureRecordMetadata:
@ -205,7 +206,7 @@ def _datetime_from_seconds_or_millis(timestamp: str) -> datetime:
return datetime.fromtimestamp(timestamp_number, timezone.utc)
def _get_sent_at(data, request) -> Tuple[Optional[datetime], Any]:
def _get_sent_at(data, request) -> tuple[Optional[datetime], Any]:
try:
if request.GET.get("_"): # posthog-js
sent_at = request.GET["_"]
@ -253,7 +254,7 @@ def _check_token_shape(token: Any) -> Optional[str]:
return None
def get_distinct_id(data: Dict[str, Any]) -> str:
def get_distinct_id(data: dict[str, Any]) -> str:
raw_value: Any = ""
try:
raw_value = data["$distinct_id"]
@ -274,12 +275,12 @@ def get_distinct_id(data: Dict[str, Any]) -> str:
return str(raw_value)[0:200]
def drop_performance_events(events: List[Any]) -> List[Any]:
def drop_performance_events(events: list[Any]) -> list[Any]:
cleaned_list = [event for event in events if event.get("event") != "$performance_event"]
return cleaned_list
def drop_events_over_quota(token: str, events: List[Any]) -> List[Any]:
def drop_events_over_quota(token: str, events: list[Any]) -> list[Any]:
if not settings.EE_AVAILABLE:
return events
@ -381,7 +382,7 @@ def get_event(request):
structlog.contextvars.bind_contextvars(token=token)
replay_events: List[Any] = []
replay_events: list[Any] = []
historical = token in settings.TOKENS_HISTORICAL_DATA
with start_span(op="request.process"):
@ -437,7 +438,7 @@ def get_event(request):
generate_exception_response("capture", f"Invalid payload: {e}", code="invalid_payload"),
)
futures: List[FutureRecordMetadata] = []
futures: list[FutureRecordMetadata] = []
with start_span(op="kafka.produce") as span:
span.set_tag("event.count", len(processed_events))
@ -536,7 +537,7 @@ def get_event(request):
return cors_response(request, JsonResponse({"status": 1}))
def preprocess_events(events: List[Dict[str, Any]]) -> Iterator[Tuple[Dict[str, Any], UUIDT, str]]:
def preprocess_events(events: list[dict[str, Any]]) -> Iterator[tuple[dict[str, Any], UUIDT, str]]:
for event in events:
event_uuid = UUIDT()
distinct_id = get_distinct_id(event)
@ -580,7 +581,7 @@ def capture_internal(
event_uuid=None,
token=None,
historical=False,
extra_headers: List[Tuple[str, str]] | None = None,
extra_headers: list[tuple[str, str]] | None = None,
):
if event_uuid is None:
event_uuid = UUIDT()
@ -680,7 +681,7 @@ def is_randomly_partitioned(candidate_partition_key: str) -> bool:
@cache_for(timedelta(seconds=30), background_refresh=True)
def _list_overflowing_keys(input_type: InputType) -> Set[str]:
def _list_overflowing_keys(input_type: InputType) -> set[str]:
"""Retrieve the active overflows from Redis with caching and pre-fetching
cache_for will keep the old value if Redis is temporarily unavailable.

View File

@ -18,7 +18,7 @@ import posthoganalytics
from posthog.metrics import LABEL_TEAM_ID
from posthog.renderers import SafeJSONRenderer
from datetime import datetime
from typing import Any, Dict, cast, Optional
from typing import Any, cast, Optional
from django.conf import settings
from django.db.models import QuerySet, Prefetch, prefetch_related_objects, OuterRef, Subquery
@ -133,7 +133,7 @@ class CohortSerializer(serializers.ModelSerializer):
"experiment_set",
]
def _handle_static(self, cohort: Cohort, context: Dict, validated_data: Dict) -> None:
def _handle_static(self, cohort: Cohort, context: dict, validated_data: dict) -> None:
request = self.context["request"]
if request.FILES.get("csv"):
self._calculate_static_by_csv(request.FILES["csv"], cohort)
@ -149,7 +149,7 @@ class CohortSerializer(serializers.ModelSerializer):
if filter_data:
insert_cohort_from_insight_filter.delay(cohort.pk, filter_data)
def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Cohort:
def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Cohort:
request = self.context["request"]
validated_data["created_by"] = request.user
@ -176,7 +176,7 @@ class CohortSerializer(serializers.ModelSerializer):
distinct_ids_and_emails = [row[0] for row in reader if len(row) > 0 and row]
calculate_cohort_from_list.delay(cohort.pk, distinct_ids_and_emails)
def validate_query(self, query: Optional[Dict]) -> Optional[Dict]:
def validate_query(self, query: Optional[dict]) -> Optional[dict]:
if not query:
return None
if not isinstance(query, dict):
@ -186,7 +186,7 @@ class CohortSerializer(serializers.ModelSerializer):
ActorsQuery.model_validate(query)
return query
def validate_filters(self, request_filters: Dict):
def validate_filters(self, request_filters: dict):
if isinstance(request_filters, dict) and "properties" in request_filters:
if self.context["request"].method == "PATCH":
parsed_filter = Filter(data=request_filters)
@ -225,7 +225,7 @@ class CohortSerializer(serializers.ModelSerializer):
else:
raise ValidationError("Filters must be a dictionary with a 'properties' key.")
def update(self, cohort: Cohort, validated_data: Dict, *args: Any, **kwargs: Any) -> Cohort: # type: ignore
def update(self, cohort: Cohort, validated_data: dict, *args: Any, **kwargs: Any) -> Cohort: # type: ignore
request = self.context["request"]
user = cast(User, request.user)
@ -498,7 +498,7 @@ def insert_cohort_query_actors_into_ch(cohort: Cohort):
insert_actors_into_cohort_by_query(cohort, query, {}, context)
def insert_cohort_actors_into_ch(cohort: Cohort, filter_data: Dict):
def insert_cohort_actors_into_ch(cohort: Cohort, filter_data: dict):
from_existing_cohort_id = filter_data.get("from_cohort_id")
context: HogQLContext
@ -561,7 +561,7 @@ def insert_cohort_actors_into_ch(cohort: Cohort, filter_data: Dict):
insert_actors_into_cohort_by_query(cohort, query, params, context)
def insert_actors_into_cohort_by_query(cohort: Cohort, query: str, params: Dict[str, Any], context: HogQLContext):
def insert_actors_into_cohort_by_query(cohort: Cohort, query: str, params: dict[str, Any], context: HogQLContext):
try:
sync_execute(
INSERT_COHORT_ALL_PEOPLE_THROUGH_PERSON_ID.format(cohort_table=PERSON_STATIC_COHORT_TABLE, query=query),
@ -600,7 +600,7 @@ def get_cohort_actors_for_feature_flag(cohort_id: int, flag: str, team_id: int,
cohort = Cohort.objects.get(pk=cohort_id, team_id=team_id)
matcher_cache = FlagsMatcherCache(team_id)
uuids_to_add_to_cohort = []
cohorts_cache: Dict[int, CohortOrEmpty] = {}
cohorts_cache: dict[int, CohortOrEmpty] = {}
if feature_flag.uses_cohorts:
# TODO: Consider disabling flags with cohorts for creating static cohorts
@ -709,7 +709,7 @@ def get_cohort_actors_for_feature_flag(cohort_id: int, flag: str, team_id: int,
capture_exception(err)
def get_default_person_property(prop: Property, cohorts_cache: Dict[int, CohortOrEmpty]):
def get_default_person_property(prop: Property, cohorts_cache: dict[int, CohortOrEmpty]):
default_person_properties = {}
if prop.operator not in ("is_set", "is_not_set") and prop.type == "person":
@ -725,7 +725,7 @@ def get_default_person_property(prop: Property, cohorts_cache: Dict[int, CohortO
return default_person_properties
def get_default_person_properties_for_cohort(cohort: Cohort, cohorts_cache: Dict[int, CohortOrEmpty]) -> Dict[str, str]:
def get_default_person_properties_for_cohort(cohort: Cohort, cohorts_cache: dict[int, CohortOrEmpty]) -> dict[str, str]:
"""
Returns a dictionary of default person properties to use when evaluating a feature flag
"""

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, cast
from typing import Any, cast
from django.db import transaction
from django.db.models import QuerySet
@ -40,7 +40,7 @@ class CommentSerializer(serializers.ModelSerializer):
validated_data["team_id"] = self.context["team_id"]
return super().create(validated_data)
def update(self, instance: Comment, validated_data: Dict, **kwargs) -> Comment:
def update(self, instance: Comment, validated_data: dict, **kwargs) -> Comment:
request = self.context["request"]
with transaction.atomic():

View File

@ -1,5 +1,5 @@
import json
from typing import Any, Dict, List, Optional, Type, cast
from typing import Any, Optional, cast
import structlog
from django.db.models import Prefetch, QuerySet
@ -155,13 +155,13 @@ class DashboardSerializer(DashboardBasicSerializer):
]
read_only_fields = ["creation_mode", "effective_restriction_level", "is_shared"]
def validate_filters(self, value) -> Dict:
def validate_filters(self, value) -> dict:
if not isinstance(value, dict):
raise serializers.ValidationError("Filters must be a dictionary")
return value
def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Dashboard:
def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Dashboard:
request = self.context["request"]
validated_data["created_by"] = request.user
team_id = self.context["team_id"]
@ -260,7 +260,7 @@ class DashboardSerializer(DashboardBasicSerializer):
color=existing_tile.color,
)
def update(self, instance: Dashboard, validated_data: Dict, *args: Any, **kwargs: Any) -> Dashboard:
def update(self, instance: Dashboard, validated_data: dict, *args: Any, **kwargs: Any) -> Dashboard:
can_user_restrict = self.user_permissions.dashboard(instance).can_restrict
if "restriction_level" in validated_data and not can_user_restrict:
raise exceptions.PermissionDenied(
@ -292,11 +292,11 @@ class DashboardSerializer(DashboardBasicSerializer):
return instance
@staticmethod
def _update_tiles(instance: Dashboard, tile_data: Dict, user: User) -> None:
def _update_tiles(instance: Dashboard, tile_data: dict, user: User) -> None:
tile_data.pop("is_cached", None) # read only field
if tile_data.get("text", None):
text_json: Dict = tile_data.get("text", {})
text_json: dict = tile_data.get("text", {})
created_by_json = text_json.get("created_by", None)
if created_by_json:
last_modified_by = user
@ -348,7 +348,7 @@ class DashboardSerializer(DashboardBasicSerializer):
insights_to_undelete.append(tile.insight)
Insight.objects.bulk_update(insights_to_undelete, ["deleted"])
def get_tiles(self, dashboard: Dashboard) -> Optional[List[ReturnDict]]:
def get_tiles(self, dashboard: Dashboard) -> Optional[list[ReturnDict]]:
if self.context["view"].action == "list":
return None
@ -401,7 +401,7 @@ class DashboardsViewSet(
queryset = Dashboard.objects_including_soft_deleted.order_by("name")
permission_classes = [CanEditDashboard]
def get_serializer_class(self) -> Type[BaseSerializer]:
def get_serializer_class(self) -> type[BaseSerializer]:
return DashboardBasicSerializer if self.action == "list" else DashboardSerializer
def get_queryset(self) -> QuerySet:
@ -512,7 +512,7 @@ class DashboardsViewSet(
class LegacyDashboardsViewSet(DashboardsViewSet):
derive_current_team_from_user_only = True
def get_parents_query_dict(self) -> Dict[str, Any]:
def get_parents_query_dict(self) -> dict[str, Any]:
if not self.request.user.is_authenticated or "share_token" in self.request.GET:
return {}
return {"team_id": self.team_id}

View File

@ -15,9 +15,7 @@ class DashboardTemplateCreationJSONSchemaParser(JSONParser):
The template is sent in the "template" key"""
def parse(self, stream, media_type=None, parser_context=None):
data = super(DashboardTemplateCreationJSONSchemaParser, self).parse(
stream, media_type or "application/json", parser_context
)
data = super().parse(stream, media_type or "application/json", parser_context)
try:
template = data["template"]
jsonschema.validate(template, dashboard_template_schema)

View File

@ -1,6 +1,5 @@
import json
from pathlib import Path
from typing import Dict
import structlog
from django.db.models import Q
@ -50,7 +49,7 @@ class DashboardTemplateSerializer(serializers.ModelSerializer):
"scope",
]
def create(self, validated_data: Dict, *args, **kwargs) -> DashboardTemplate:
def create(self, validated_data: dict, *args, **kwargs) -> DashboardTemplate:
if not validated_data["tiles"]:
raise ValidationError(detail="You need to provide tiles for the template.")
@ -61,7 +60,7 @@ class DashboardTemplateSerializer(serializers.ModelSerializer):
validated_data["team_id"] = self.context["team_id"]
return super().create(validated_data, *args, **kwargs)
def update(self, instance: DashboardTemplate, validated_data: Dict, *args, **kwargs) -> DashboardTemplate:
def update(self, instance: DashboardTemplate, validated_data: dict, *args, **kwargs) -> DashboardTemplate:
# if the original request was to make the template scope to team only, and the template is none then deny the request
if validated_data.get("scope") == "team" and instance.scope == "global" and not instance.team_id:
raise ValidationError(detail="The original templates cannot be made private as they would be lost.")

View File

@ -1,4 +1,4 @@
from typing import Dict, Optional, List
from typing import Optional
from rest_framework import status
@ -510,7 +510,7 @@ class TestDashboardTemplates(APIBaseTest):
assert flag_response.status_code == status.HTTP_200_OK
assert [(r["id"], r["scope"]) for r in flag_response.json()["results"]] == [(flag_template_id, "feature_flag")]
def create_template(self, overrides: Dict[str, str | List[str]], team_id: Optional[int] = None) -> str:
def create_template(self, overrides: dict[str, str | list[str]], team_id: Optional[int] = None) -> str:
template = {**variable_template, **overrides}
response = self.client.post(
f"/api/projects/{team_id or self.team.pk}/dashboard_templates",

View File

@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, List, Optional, Union
from typing import Any, Optional, Union
from rest_framework import mixins, permissions, serializers, viewsets
@ -65,7 +65,7 @@ class DeadLetterQueueMetric:
key: str = ""
metric: str = ""
value: Union[str, bool, int, None] = None
subrows: Optional[List[Any]] = None
subrows: Optional[list[Any]] = None
def __init__(self, **kwargs):
for field in ("key", "metric", "value", "subrows"):
@ -138,7 +138,7 @@ def get_dead_letter_queue_events_last_24h() -> int:
)[0][0]
def get_dead_letter_queue_events_per_error(offset: Optional[int] = 0) -> List[Union[str, int]]:
def get_dead_letter_queue_events_per_error(offset: Optional[int] = 0) -> list[Union[str, int]]:
return sync_execute(
f"""
SELECT error, count(*) AS c
@ -151,7 +151,7 @@ def get_dead_letter_queue_events_per_error(offset: Optional[int] = 0) -> List[Un
)
def get_dead_letter_queue_events_per_location(offset: Optional[int] = 0) -> List[Union[str, int]]:
def get_dead_letter_queue_events_per_location(offset: Optional[int] = 0) -> list[Union[str, int]]:
return sync_execute(
f"""
SELECT error_location, count(*) AS c
@ -164,7 +164,7 @@ def get_dead_letter_queue_events_per_location(offset: Optional[int] = 0) -> List
)
def get_dead_letter_queue_events_per_day(offset: Optional[int] = 0) -> List[Union[str, int]]:
def get_dead_letter_queue_events_per_day(offset: Optional[int] = 0) -> list[Union[str, int]]:
return sync_execute(
f"""
SELECT toDate(error_timestamp) as day, count(*) AS c
@ -177,7 +177,7 @@ def get_dead_letter_queue_events_per_day(offset: Optional[int] = 0) -> List[Unio
)
def get_dead_letter_queue_events_per_tag(offset: Optional[int] = 0) -> List[Union[str, int]]:
def get_dead_letter_queue_events_per_tag(offset: Optional[int] = 0) -> list[Union[str, int]]:
return sync_execute(
f"""
SELECT arrayJoin(tags) as tag, count(*) as c from events_dead_letter_queue

View File

@ -1,6 +1,6 @@
import re
from random import random
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from urllib.parse import urlparse
import structlog
@ -56,7 +56,7 @@ def on_permitted_recording_domain(team: Team, request: HttpRequest) -> bool:
return is_authorized_web_client or is_authorized_mobile_client
def hostname_in_allowed_url_list(allowed_url_list: Optional[List[str]], hostname: Optional[str]) -> bool:
def hostname_in_allowed_url_list(allowed_url_list: Optional[list[str]], hostname: Optional[str]) -> bool:
if not hostname:
return False
@ -182,7 +182,7 @@ def get_decide(request: HttpRequest):
if geoip_enabled:
property_overrides = get_geoip_properties(get_ip_address(request))
all_property_overrides: Dict[str, Union[str, int]] = {
all_property_overrides: dict[str, Union[str, int]] = {
**property_overrides,
**(data.get("person_properties") or {}),
}
@ -296,8 +296,8 @@ def get_decide(request: HttpRequest):
return cors_response(request, JsonResponse(response))
def _session_recording_config_response(request: HttpRequest, team: Team) -> bool | Dict:
session_recording_config_response: bool | Dict = False
def _session_recording_config_response(request: HttpRequest, team: Team) -> bool | dict:
session_recording_config_response: bool | dict = False
try:
if team.session_recording_opt_in and (
@ -312,7 +312,7 @@ def _session_recording_config_response(request: HttpRequest, team: Team) -> bool
linked_flag = None
linked_flag_config = team.session_recording_linked_flag or None
if isinstance(linked_flag_config, Dict):
if isinstance(linked_flag_config, dict):
linked_flag_key = linked_flag_config.get("key", None)
linked_flag_variant = linked_flag_config.get("variant", None)
if linked_flag_variant is not None:
@ -330,7 +330,7 @@ def _session_recording_config_response(request: HttpRequest, team: Team) -> bool
"networkPayloadCapture": team.session_recording_network_payload_capture_config or None,
}
if isinstance(team.session_replay_config, Dict):
if isinstance(team.session_replay_config, dict):
record_canvas = team.session_replay_config.get("record_canvas", False)
session_recording_config_response.update(
{

View File

@ -1,5 +1,5 @@
import re
from typing import Dict, get_args
from typing import get_args
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import (
@ -215,7 +215,7 @@ def preprocess_exclude_path_format(endpoints, **kwargs):
def custom_postprocessing_hook(result, generator, request, public):
all_tags = []
paths: Dict[str, Dict] = {}
paths: dict[str, dict] = {}
for path, methods in result["paths"].items():
paths[path] = {}

View File

@ -1,5 +1,3 @@
from typing import Type
from django.http import JsonResponse
from rest_framework.response import Response
from posthog.api.feature_flag import FeatureFlagSerializer, MinimalFeatureFlagSerializer
@ -221,7 +219,7 @@ class EarlyAccessFeatureViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
scope_object = "early_access_feature"
queryset = EarlyAccessFeature.objects.select_related("feature_flag").all()
def get_serializer_class(self) -> Type[serializers.Serializer]:
def get_serializer_class(self) -> type[serializers.Serializer]:
if self.request.method == "POST":
return EarlyAccessFeatureSerializerCreateOnly
else:

View File

@ -1,4 +1,4 @@
from typing import Literal, Tuple
from typing import Literal
from rest_framework import request, response, serializers, viewsets
from rest_framework.decorators import action
@ -128,8 +128,8 @@ class ElementViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
else:
return response.Response(serialized_elements)
def _events_filter(self, request) -> Tuple[Literal["$autocapture", "$rageclick"], ...]:
event_to_filter: Tuple[Literal["$autocapture", "$rageclick"], ...] = ()
def _events_filter(self, request) -> tuple[Literal["$autocapture", "$rageclick"], ...]:
event_to_filter: tuple[Literal["$autocapture", "$rageclick"], ...] = ()
# when multiple includes are sent expects them as separate parameters
# e.g. ?include=a&include=b
events_to_include = request.query_params.getlist("include", [])

View File

@ -1,7 +1,7 @@
import json
import urllib
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, List, Optional, Union # noqa: UP035
from django.db.models.query import Prefetch
from drf_spectacular.types import OpenApiTypes
@ -94,7 +94,7 @@ class EventViewSet(
self,
request: request.Request,
last_event_timestamp: datetime,
order_by: List[str],
order_by: list[str],
) -> str:
params = request.GET.dict()
reverse = "-timestamp" in order_by
@ -175,7 +175,7 @@ class EventViewSet(
team = self.team
filter = Filter(request=request, team=self.team)
order_by: List[str] = (
order_by: list[str] = (
list(json.loads(request.GET["orderBy"])) if request.GET.get("orderBy") else ["-timestamp"]
)
@ -217,11 +217,11 @@ class EventViewSet(
capture_exception(ex)
raise ex
def _get_people(self, query_result: List[Dict], team: Team) -> Dict[str, Any]:
def _get_people(self, query_result: List[dict], team: Team) -> dict[str, Any]: # noqa: UP006
distinct_ids = [event["distinct_id"] for event in query_result]
persons = get_persons_by_distinct_ids(team.pk, distinct_ids)
persons = persons.prefetch_related(Prefetch("persondistinctid_set", to_attr="distinct_ids_cache"))
distinct_to_person: Dict[str, Person] = {}
distinct_to_person: dict[str, Person] = {}
for person in persons:
for distinct_id in person.distinct_ids:
distinct_to_person[distinct_id] = person

View File

@ -1,4 +1,4 @@
from typing import Any, Literal, Tuple, Type, cast
from typing import Any, Literal, cast
from django.db.models import Manager, Prefetch
from rest_framework import (
@ -117,7 +117,7 @@ class EventDefinitionViewSet(
def _ordering_params_from_request(
self,
) -> Tuple[str, Literal["ASC", "DESC"]]:
) -> tuple[str, Literal["ASC", "DESC"]]:
order_direction: Literal["ASC", "DESC"]
ordering = self.request.GET.get("ordering")
@ -154,7 +154,7 @@ class EventDefinitionViewSet(
return EventDefinition.objects.get(id=id, team_id=self.team_id)
def get_serializer_class(self) -> Type[serializers.ModelSerializer]:
def get_serializer_class(self) -> type[serializers.ModelSerializer]:
serializer_class = self.serializer_class
if EE_AVAILABLE and self.request.user.organization.is_feature_available( # type: ignore
AvailableFeature.INGESTION_TAXONOMY

View File

@ -1,5 +1,5 @@
from datetime import timedelta
from typing import Any, Dict
from typing import Any
import structlog
from django.http import HttpResponse
@ -40,7 +40,7 @@ class ExportedAssetSerializer(serializers.ModelSerializer):
]
read_only_fields = ["id", "created_at", "has_content", "filename"]
def validate(self, data: Dict) -> Dict:
def validate(self, data: dict) -> dict:
if not data.get("export_format"):
raise ValidationError("Must provide export format")
@ -61,13 +61,13 @@ class ExportedAssetSerializer(serializers.ModelSerializer):
def synthetic_create(self, reason: str, *args: Any, **kwargs: Any) -> ExportedAsset:
return self._create_asset(self.validated_data, user=None, reason=reason)
def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> ExportedAsset:
def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> ExportedAsset:
request = self.context["request"]
return self._create_asset(validated_data, user=request.user, reason=None)
def _create_asset(
self,
validated_data: Dict,
validated_data: dict,
user: User | None,
reason: str | None,
) -> ExportedAsset:

View File

@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional, cast
from typing import Any, Optional, cast
from datetime import datetime
from django.db.models import QuerySet, Q, deletion
@ -145,12 +145,12 @@ class FeatureFlagSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedMo
and feature_flag.aggregation_group_type_index is None
)
def get_features(self, feature_flag: FeatureFlag) -> Dict:
def get_features(self, feature_flag: FeatureFlag) -> dict:
from posthog.api.early_access_feature import MinimalEarlyAccessFeatureSerializer
return MinimalEarlyAccessFeatureSerializer(feature_flag.features, many=True).data
def get_surveys(self, feature_flag: FeatureFlag) -> Dict:
def get_surveys(self, feature_flag: FeatureFlag) -> dict:
from posthog.api.survey import SurveyAPISerializer
return SurveyAPISerializer(feature_flag.surveys_linked_flag, many=True).data
@ -263,7 +263,7 @@ class FeatureFlagSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedMo
return filters
def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> FeatureFlag:
def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> FeatureFlag:
request = self.context["request"]
validated_data["created_by"] = request.user
validated_data["team_id"] = self.context["team_id"]
@ -299,7 +299,7 @@ class FeatureFlagSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedMo
return instance
def update(self, instance: FeatureFlag, validated_data: Dict, *args: Any, **kwargs: Any) -> FeatureFlag:
def update(self, instance: FeatureFlag, validated_data: dict, *args: Any, **kwargs: Any) -> FeatureFlag:
if "deleted" in validated_data and validated_data["deleted"] is True and instance.features.count() > 0:
raise exceptions.ValidationError(
"Cannot delete a feature flag that is in use with early access features. Please delete the early access feature before deleting the flag."
@ -496,13 +496,11 @@ class FeatureFlagViewSet(
feature_flags, many=True, context=self.get_serializer_context()
).data
return Response(
(
{
"feature_flag": feature_flag,
"value": matches.get(feature_flag["key"], False),
}
for feature_flag in all_serialized_flags
)
{
"feature_flag": feature_flag,
"value": matches.get(feature_flag["key"], False),
}
for feature_flag in all_serialized_flags
)
@action(
@ -516,7 +514,7 @@ class FeatureFlagViewSet(
should_send_cohorts = "send_cohorts" in request.GET
cohorts = {}
seen_cohorts_cache: Dict[int, CohortOrEmpty] = {}
seen_cohorts_cache: dict[int, CohortOrEmpty] = {}
if should_send_cohorts:
seen_cohorts_cache = {

View File

@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Optional
import structlog
from django.contrib.gis.geoip2 import GeoIP2
@ -27,7 +27,7 @@ VALID_GEOIP_PROPERTIES = [
]
def get_geoip_properties(ip_address: Optional[str]) -> Dict[str, str]:
def get_geoip_properties(ip_address: Optional[str]) -> dict[str, str]:
"""
Returns a dictionary of geoip properties for the given ip address.

View File

@ -1,6 +1,6 @@
import json
from functools import lru_cache
from typing import Any, Dict, List, Optional, Type, Union, cast
from typing import Any, Optional, Union, cast
import structlog
from django.db import transaction
@ -118,7 +118,7 @@ def log_insight_activity(
team_id: int,
user: User,
was_impersonated: bool,
changes: Optional[List[Change]] = None,
changes: Optional[list[Change]] = None,
) -> None:
"""
Insight id and short_id are passed separately as some activities (like delete) alter the Insight instance
@ -148,7 +148,7 @@ class QuerySchemaParser(JSONParser):
"""
def parse(self, stream, media_type=None, parser_context=None):
data = super(QuerySchemaParser, self).parse(stream, media_type, parser_context)
data = super().parse(stream, media_type, parser_context)
try:
query = data.get("query", None)
if query:
@ -197,7 +197,7 @@ class InsightBasicSerializer(TaggedItemSerializerMixin, serializers.ModelSeriali
]
read_only_fields = ("short_id", "updated_at", "last_refresh", "refreshing")
def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Any:
def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError()
def to_representation(self, instance):
@ -306,7 +306,7 @@ class InsightSerializer(InsightBasicSerializer, UserPermissionsSerializerMixin):
"is_cached",
)
def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Insight:
def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Insight:
request = self.context["request"]
tags = validated_data.pop("tags", None) # tags are created separately as global tag relationships
team_id = self.context["team_id"]
@ -345,8 +345,8 @@ class InsightSerializer(InsightBasicSerializer, UserPermissionsSerializerMixin):
return insight
def update(self, instance: Insight, validated_data: Dict, **kwargs) -> Insight:
dashboards_before_change: List[Union[str, Dict]] = []
def update(self, instance: Insight, validated_data: dict, **kwargs) -> Insight:
dashboards_before_change: list[Union[str, dict]] = []
try:
# since it is possible to be undeleting a soft deleted insight
# the state captured before the update has to include soft deleted insights
@ -411,7 +411,7 @@ class InsightSerializer(InsightBasicSerializer, UserPermissionsSerializerMixin):
changes=changes,
)
def _synthetic_dashboard_changes(self, dashboards_before_change: List[Dict]) -> List[Change]:
def _synthetic_dashboard_changes(self, dashboards_before_change: list[dict]) -> list[Change]:
artificial_dashboard_changes = self.context.get("after_dashboard_changes", [])
if artificial_dashboard_changes:
return [
@ -426,7 +426,7 @@ class InsightSerializer(InsightBasicSerializer, UserPermissionsSerializerMixin):
return []
def _update_insight_dashboards(self, dashboards: List[Dashboard], instance: Insight) -> None:
def _update_insight_dashboards(self, dashboards: list[Dashboard], instance: Insight) -> None:
old_dashboard_ids = [tile.dashboard_id for tile in instance.dashboard_tiles.all()]
new_dashboard_ids = [d.id for d in dashboards if not d.deleted]
@ -598,14 +598,14 @@ class InsightViewSet(
parser_classes = (QuerySchemaParser,)
def get_serializer_class(self) -> Type[serializers.BaseSerializer]:
def get_serializer_class(self) -> type[serializers.BaseSerializer]:
if (self.action == "list" or self.action == "retrieve") and str_to_bool(
self.request.query_params.get("basic", "0")
):
return InsightBasicSerializer
return super().get_serializer_class()
def get_serializer_context(self) -> Dict[str, Any]:
def get_serializer_context(self) -> dict[str, Any]:
context = super().get_serializer_context()
context["is_shared"] = isinstance(self.request.successful_authenticator, SharingAccessTokenAuthentication)
return context
@ -867,7 +867,7 @@ Using the correct cache and enriching the response with dashboard specific confi
return Response({**result, "next": next})
@cached_by_filters
def calculate_trends(self, request: request.Request) -> Dict[str, Any]:
def calculate_trends(self, request: request.Request) -> dict[str, Any]:
team = self.team
filter = Filter(request=request, team=self.team)
@ -919,7 +919,7 @@ Using the correct cache and enriching the response with dashboard specific confi
return Response(funnel)
@cached_by_filters
def calculate_funnel(self, request: request.Request) -> Dict[str, Any]:
def calculate_funnel(self, request: request.Request) -> dict[str, Any]:
team = self.team
filter = Filter(request=request, data={"insight": INSIGHT_FUNNELS}, team=self.team)
@ -959,7 +959,7 @@ Using the correct cache and enriching the response with dashboard specific confi
return Response(result)
@cached_by_filters
def calculate_retention(self, request: request.Request) -> Dict[str, Any]:
def calculate_retention(self, request: request.Request) -> dict[str, Any]:
team = self.team
data = {}
if not request.GET.get("date_from") and not request.data.get("date_from"):
@ -989,7 +989,7 @@ Using the correct cache and enriching the response with dashboard specific confi
return Response(result)
@cached_by_filters
def calculate_path(self, request: request.Request) -> Dict[str, Any]:
def calculate_path(self, request: request.Request) -> dict[str, Any]:
team = self.team
filter = PathFilter(request=request, data={"insight": INSIGHT_PATHS}, team=self.team)

View File

@ -1,5 +1,5 @@
import re
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Optional, Union
from rest_framework import exceptions, mixins, permissions, serializers, viewsets
@ -50,7 +50,7 @@ class InstanceSettingHelper:
setattr(self, field, kwargs.get(field, None))
def get_instance_setting(key: str, setting_config: Optional[Tuple] = None) -> InstanceSettingHelper:
def get_instance_setting(key: str, setting_config: Optional[tuple] = None) -> InstanceSettingHelper:
setting_config = setting_config or CONSTANCE_CONFIG[key]
is_secret = key in SECRET_SETTINGS
value = get_instance_setting_raw(key)
@ -73,7 +73,7 @@ class InstanceSettingsSerializer(serializers.Serializer):
editable = serializers.BooleanField(read_only=True)
is_secret = serializers.BooleanField(read_only=True)
def update(self, instance: InstanceSettingHelper, validated_data: Dict[str, Any]) -> InstanceSettingHelper:
def update(self, instance: InstanceSettingHelper, validated_data: dict[str, Any]) -> InstanceSettingHelper:
if instance.key not in SETTINGS_ALLOWING_API_OVERRIDE:
raise serializers.ValidationError("This setting cannot be updated from the API.", code="no_api_override")

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Union
from typing import Any, Union
from django.conf import settings
from django.db import connection
@ -40,7 +40,7 @@ class InstanceStatusViewSet(viewsets.ViewSet):
redis_alive = is_redis_alive()
postgres_alive = is_postgres_alive()
metrics: List[Dict[str, Union[str, bool, int, float, Dict[str, Any]]]] = []
metrics: list[dict[str, Union[str, bool, int, float, dict[str, Any]]]] = []
metrics.append(
{"key": "posthog_git_sha", "metric": "PostHog Git SHA", "value": get_git_commit_short() or "unknown"}

View File

@ -1,4 +1,4 @@
from typing import TypeVar, Type
from typing import TypeVar
from pydantic import BaseModel, ValidationError
@ -9,7 +9,7 @@ T = TypeVar("T", bound=BaseModel)
class PydanticModelMixin:
def get_model(self, data: dict, model: Type[T]) -> T:
def get_model(self, data: dict, model: type[T]) -> T:
try:
return model.model_validate(data)
except ValidationError as exc:

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Any, Type
from typing import Optional, Any
from django.db.models import Q
import structlog
from django.db import transaction
@ -58,7 +58,7 @@ def log_notebook_activity(
team_id: int,
user: User,
was_impersonated: bool,
changes: Optional[List[Change]] = None,
changes: Optional[list[Change]] = None,
) -> None:
short_id = str(notebook.short_id)
@ -118,7 +118,7 @@ class NotebookSerializer(NotebookMinimalSerializer):
"last_modified_by",
]
def create(self, validated_data: Dict, *args, **kwargs) -> Notebook:
def create(self, validated_data: dict, *args, **kwargs) -> Notebook:
request = self.context["request"]
team = self.context["get_team"]()
@ -141,7 +141,7 @@ class NotebookSerializer(NotebookMinimalSerializer):
return notebook
def update(self, instance: Notebook, validated_data: Dict, **kwargs) -> Notebook:
def update(self, instance: Notebook, validated_data: dict, **kwargs) -> Notebook:
try:
before_update = Notebook.objects.get(pk=instance.id)
except Notebook.DoesNotExist:
@ -240,7 +240,7 @@ class NotebookViewSet(TeamAndOrgViewSetMixin, ForbidDestroyModel, viewsets.Model
filterset_fields = ["short_id"]
lookup_field = "short_id"
def get_serializer_class(self) -> Type[BaseSerializer]:
def get_serializer_class(self) -> type[BaseSerializer]:
return NotebookMinimalSerializer if self.action == "list" else NotebookSerializer
def get_queryset(self) -> QuerySet:
@ -298,8 +298,8 @@ class NotebookViewSet(TeamAndOrgViewSetMixin, ForbidDestroyModel, viewsets.Model
if target:
# the JSONB query requires a specific structure
basic_structure = List[Dict[str, Any]]
nested_structure = basic_structure | List[Dict[str, basic_structure]]
basic_structure = list[dict[str, Any]]
nested_structure = basic_structure | list[dict[str, basic_structure]]
presence_match_structure: basic_structure | nested_structure = [{"type": f"ph-{target}"}]

View File

@ -1,5 +1,5 @@
from functools import cached_property
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Optional, Union, cast
from django.db.models import Model, QuerySet
from django.shortcuts import get_object_or_404
@ -108,7 +108,7 @@ class OrganizationSerializer(serializers.ModelSerializer, UserPermissionsSeriali
}, # slug is not required here as it's generated automatically for new organizations
}
def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Organization:
def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Organization:
serializers.raise_errors_on_nested_writes("create", self, validated_data)
user = self.context["request"].user
organization, _, _ = Organization.objects.bootstrap(user, **validated_data)
@ -119,11 +119,11 @@ class OrganizationSerializer(serializers.ModelSerializer, UserPermissionsSeriali
membership = self.user_permissions.organization_memberships.get(organization.pk)
return membership.level if membership is not None else None
def get_teams(self, instance: Organization) -> List[Dict[str, Any]]:
def get_teams(self, instance: Organization) -> list[dict[str, Any]]:
visible_teams = instance.teams.filter(id__in=self.user_permissions.team_ids_visible_for_user)
return TeamBasicSerializer(visible_teams, context=self.context, many=True).data # type: ignore
def get_metadata(self, instance: Organization) -> Dict[str, Union[str, int, object]]:
def get_metadata(self, instance: Organization) -> dict[str, Union[str, int, object]]:
return {
"instance_tag": settings.INSTANCE_TAG,
}
@ -210,7 +210,7 @@ class OrganizationViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
ignore_conflicts=True,
)
def get_serializer_context(self) -> Dict[str, Any]:
def get_serializer_context(self) -> dict[str, Any]:
return {
**super().get_serializer_context(),
"user_permissions": UserPermissions(cast(User, self.request.user)),

View File

@ -1,5 +1,5 @@
import re
from typing import Any, Dict, cast
from typing import Any, cast
from rest_framework import exceptions, request, response, serializers
from rest_framework.decorators import action
@ -38,7 +38,7 @@ class OrganizationDomainSerializer(serializers.ModelSerializer):
"has_saml": {"read_only": True},
}
def create(self, validated_data: Dict[str, Any]) -> OrganizationDomain:
def create(self, validated_data: dict[str, Any]) -> OrganizationDomain:
validated_data["organization"] = self.context["view"].organization
validated_data.pop(
"jit_provisioning_enabled", None
@ -56,7 +56,7 @@ class OrganizationDomainSerializer(serializers.ModelSerializer):
raise serializers.ValidationError("Please enter a valid domain or subdomain name.")
return domain
def validate(self, attrs: Dict[str, Any]) -> Dict[str, Any]:
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
instance = cast(OrganizationDomain, self.instance)
if instance and not instance.verified_at:

Some files were not shown because too many files have changed in this diff Show More