From 9576fab1e4720a643ba5e4feec9ba83975c9c6d3 Mon Sep 17 00:00:00 2001
From: Julian Bez
Date: Thu, 25 Apr 2024 08:22:28 +0100
Subject: [PATCH] chore: Add Pyupgrade rules (#21714)

* Add Pyupgrade rules

* Set correct Python version
---
 bin/migrate_kafka_data.py | 3 +-
 ee/api/authentication.py | 6 +-
 ee/api/dashboard_collaborator.py | 6 +-
 ee/api/role.py | 4 +-
 ee/api/sentry_stats.py | 18 +-
 ee/api/subscription.py | 4 +-
 ee/api/test/base.py | 4 +-
 .../fixtures/available_product_features.py | 4 +-
 ee/api/test/test_authentication.py | 7 -
 ee/api/test/test_billing.py | 6 +-
 ee/api/test/test_capture.py | 24 +-
 ee/api/test/test_dashboard.py | 6 +-
 ee/api/test/test_event_definition.py | 4 +-
 ee/api/test/test_insight.py | 10 +-
 ee/api/test/test_integration.py | 2 +-
 ee/api/test/test_property_definition.py | 4 +-
 ee/api/test/test_time_to_see_data.py | 16 +-
 ee/benchmarks/benchmarks.py | 3 +-
 ee/billing/billing_manager.py | 8 +-
 ee/billing/billing_types.py | 14 +-
 ee/billing/quota_limiting.py | 29 +-
 ee/clickhouse/materialized_columns/analyze.py | 25 +-
 ee/clickhouse/materialized_columns/columns.py | 8 +-
 ee/clickhouse/models/test/test_action.py | 3 +-
 ee/clickhouse/models/test/test_property.py | 14 +-
 ee/clickhouse/queries/column_optimizer.py | 12 +-
 .../queries/enterprise_cohort_query.py | 32 +-
 ee/clickhouse/queries/event_query.py | 10 +-
 .../experiments/funnel_experiment_result.py | 22 +-
 .../secondary_experiment_result.py | 6 +-
 .../experiments/test_experiment_result.py | 5 +-
 .../experiments/trend_experiment_result.py | 18 +-
 ee/clickhouse/queries/experiments/utils.py | 4 +-
 .../queries/funnels/funnel_correlation.py | 32 +-
 .../funnels/funnel_correlation_persons.py | 8 +-
 .../queries/funnels/test/breakdown_cases.py | 6 +-
 ee/clickhouse/queries/groups_join_query.py | 6 +-
 ee/clickhouse/queries/paths/paths.py | 12 +-
 ee/clickhouse/queries/related_actors_query.py | 12 +-
 ee/clickhouse/queries/test/test_paths.py | 3 +-
 ee/clickhouse/views/experiments.py | 3 +-
 ee/clickhouse/views/groups.py | 4 +-
 ee/clickhouse/views/insights.py | 4 +-
 ee/clickhouse/views/person.py | 4 +-
 .../test_clickhouse_funnel_correlation.py | 8 +-
 ee/clickhouse/views/test/funnel/util.py | 10 +-
 ...clickhouse_experiment_secondary_results.py | 4 +-
 .../views/test/test_clickhouse_retention.py | 16 +-
 .../views/test/test_clickhouse_trends.py | 14 +-
 ee/migrations/0001_initial.py | 3 +-
 ee/migrations/0012_migrate_tags_v2.py | 4 +-
 ee/models/license.py | 4 +-
 .../ai/embeddings_queries.py | 5 +-
 ee/session_recordings/ai/embeddings_runner.py | 12 +-
 ee/session_recordings/ai/utils.py | 12 +-
 ...sion_recording_list_from_session_replay.py | 3 +-
 .../session_recording_playlist.py | 8 +-
 .../test/test_session_recording_extensions.py | 4 +-
 ee/settings.py | 5 +-
 ee/tasks/auto_rollback_feature_flag.py | 5 +-
 ee/tasks/replay.py | 4 +-
 ee/tasks/slack.py | 4 +-
 ee/tasks/subscriptions/email_subscriptions.py | 4 +-
 ee/tasks/subscriptions/slack_subscriptions.py | 6 +-
 ee/tasks/subscriptions/subscription_utils.py | 4 +-
 .../test/subscriptions/test_subscriptions.py | 5 +-
 .../subscriptions/test_subscriptions_utils.py | 3 +-
 ee/tasks/test/test_slack.py | 3 +-
 ee/urls.py | 4 +-
 gunicorn.config.py | 1 -
 hogvm/python/execute.py | 4 +-
 .../generate_session_recordings_messages.py | 3 +-
 posthog/api/action.py | 6 +-
 posthog/api/activity_log.py | 4 +-
 posthog/api/annotation.py | 6 +-
 posthog/api/authentication.py | 10 +-
 posthog/api/capture.py | 29 +-
 posthog/api/cohort.py | 22 +-
 posthog/api/comments.py | 4 +-
 posthog/api/dashboards/dashboard.py | 18 +-
 .../dashboard_template_json_schema_parser.py | 4 +-
 posthog/api/dashboards/dashboard_templates.py | 5 +-
 .../test/test_dashboard_templates.py | 4 +-
 posthog/api/dead_letter_queue.py | 12 +-
 posthog/api/decide.py | 14 +-
 posthog/api/documentation.py | 4 +-
 posthog/api/early_access_feature.py | 4 +-
 posthog/api/element.py | 6 +-
 posthog/api/event.py | 10 +-
 posthog/api/event_definition.py | 6 +-
 posthog/api/exports.py | 8 +-
 posthog/api/feature_flag.py | 24 +-
 posthog/api/geoip.py | 4 +-
 posthog/api/insight.py | 30 +-
 posthog/api/instance_settings.py | 6 +-
 posthog/api/instance_status.py | 4 +-
 posthog/api/mixins.py | 4 +-
 posthog/api/notebook.py | 14 +-
 posthog/api/organization.py | 10 +-
 posthog/api/organization_domain.py | 6 +-
 posthog/api/organization_feature_flag.py | 5 +-
 posthog/api/organization_invite.py | 4 +-
 posthog/api/person.py | 19 +-
 posthog/api/plugin.py | 32 +-
 posthog/api/property_definition.py | 10 +-
 posthog/api/routing.py | 8 +-
 posthog/api/scheduled_change.py | 4 +-
 posthog/api/sharing.py | 8 +-
 posthog/api/signup.py | 10 +-
 posthog/api/survey.py | 3 +-
 posthog/api/tagged_item.py | 8 +-
 posthog/api/team.py | 26 +-
 posthog/api/test/dashboards/__init__.py | 46 +-
 posthog/api/test/dashboards/test_dashboard.py | 9 +-
 .../dashboards/test_dashboard_duplication.py | 4 +-
 .../dashboards/test_dashboard_text_tiles.py | 10 +-
 posthog/api/test/notebooks/test_notebook.py | 9 +-
 .../test/notebooks/test_notebook_filtering.py | 8 +-
 posthog/api/test/openapi_validation.py | 4 +-
 posthog/api/test/test_activity_log.py | 10 +-
 posthog/api/test/test_capture.py | 20 +-
 posthog/api/test/test_cohort.py | 6 +-
 posthog/api/test/test_decide.py | 22 +-
 posthog/api/test/test_element.py | 5 +-
 posthog/api/test/test_event_definition.py | 10 +-
 posthog/api/test/test_exports.py | 8 +-
 posthog/api/test/test_feature_flag.py | 75 ++-
 posthog/api/test/test_feature_flag_utils.py | 3 +-
 posthog/api/test/test_ingestion_warnings.py | 3 +-
 posthog/api/test/test_insight.py | 12 +-
 posthog/api/test/test_insight_funnels.py | 4 +-
 posthog/api/test/test_insight_query.py | 6 +-
 posthog/api/test/test_kafka_inspector.py | 4 +-
 .../test/test_organization_feature_flag.py | 4 +-
 posthog/api/test/test_person.py | 6 +-
 posthog/api/test/test_plugin.py | 14 +-
 posthog/api/test/test_properties_timeline.py | 4 +-
 posthog/api/test/test_property_definition.py | 6 +-
 posthog/api/test/test_signup.py | 4 +-
 posthog/api/test/test_site_app.py | 4 +-
 posthog/api/test/test_stickiness.py | 12 +-
 posthog/api/test/test_team.py | 16 +-
 posthog/api/test/test_user.py | 4 +-
 posthog/api/uploaded_media.py | 4 +-
 posthog/api/utils.py | 12 +-
 posthog/async_migrations/definition.py | 21 +-
 .../migrations/0001_events_sample_by.py | 4 +-
 .../migrations/0002_events_sample_by.py | 3 +-
 .../0005_person_replacing_by_version.py | 5 +-
 ...6_persons_and_groups_on_events_backfill.py | 4 +-
 ...7_persons_and_groups_on_events_backfill.py | 4 +-
 ...minmax_indexes_for_materialized_columns.py | 4 +-
 posthog/async_migrations/runner.py | 8 +-
 posthog/async_migrations/setup.py | 8 +-
 ...7_persons_and_groups_on_events_backfill.py | 7 +-
 .../test/test_0010_move_old_partitions.py | 2 +-
 posthog/async_migrations/utils.py | 3 +-
 posthog/auth.py | 18 +-
 posthog/caching/calculate_results.py | 18 +-
 posthog/caching/fetch_from_cache.py | 4 +-
 posthog/caching/insight_cache.py | 6 +-
 posthog/caching/insight_caching_state.py | 6 +-
 posthog/caching/insights_api.py | 4 +-
 posthog/caching/test/test_insight_cache.py | 3 +-
 .../test/test_insight_caching_state.py | 8 +-
 posthog/caching/utils.py | 8 +-
 posthog/celery.py | 3 +-
 posthog/clickhouse/client/connection.py | 4 +-
 posthog/clickhouse/client/execute.py | 13 +-
 posthog/clickhouse/client/migration_tools.py | 3 +-
 .../clickhouse/materialized_columns/column.py | 6 +-
 ...ensure_kafa_session_replay_table_exists.py | 4 +-
 posthog/clickhouse/system_status.py | 10 +-
 posthog/conftest.py | 4 +-
 posthog/decorators.py | 5 +-
 posthog/demo/legacy/data_generator.py | 7 +-
 posthog/demo/legacy/web_data_generator.py | 10 +-
 posthog/demo/matrix/manager.py | 14 +-
 posthog/demo/matrix/matrix.py | 30 +-
 posthog/demo/matrix/models.py | 25 +-
 posthog/demo/matrix/randomization.py | 9 +-
 posthog/demo/matrix/taxonomy_inference.py | 12 +-
 posthog/demo/products/hedgebox/models.py | 20 +-
 posthog/email.py | 16 +-
 posthog/errors.py | 4 +-
 posthog/event_usage.py | 14 +-
 posthog/filters.py | 8 +-
 posthog/gzip_middleware.py | 3 +-
 posthog/health.py | 9 +-
 posthog/heatmaps/heatmaps_api.py | 10 +-
 posthog/heatmaps/test/test_heatmaps_api.py | 7 +-
 posthog/helpers/dashboard_templates.py | 15 +-
 posthog/helpers/multi_property_breakdown.py | 18 +-
 .../tests/test_multi_property_breakdown.py | 26 +-
 posthog/hogql/ai.py | 2 +-
 posthog/hogql/ast.py | 66 +--
 posthog/hogql/autocomplete.py | 31 +-
 posthog/hogql/bytecode.py | 6 +-
 posthog/hogql/constants.py | 4 +-
 posthog/hogql/context.py | 10 +-
 posthog/hogql/database/argmax.py | 11 +-
 posthog/hogql/database/database.py | 26 +-
 posthog/hogql/database/models.py | 19 +-
 .../hogql/database/schema/cohort_people.py | 12 +-
 .../hogql/database/schema/event_sessions.py | 16 +-
 posthog/hogql/database/schema/events.py | 6 +-
 posthog/hogql/database/schema/groups.py | 12 +-
 posthog/hogql/database/schema/heatmaps.py | 4 +-
 posthog/hogql/database/schema/log_entries.py | 18 +-
 posthog/hogql/database/schema/numbers.py | 4 +-
 .../schema/person_distinct_id_overrides.py | 11 +-
 .../database/schema/person_distinct_ids.py | 11 +-
 .../hogql/database/schema/person_overrides.py | 14 +-
 posthog/hogql/database/schema/persons.py | 13 +-
 posthog/hogql/database/schema/persons_pdi.py | 9 +-
 .../database/schema/session_replay_events.py | 18 +-
 posthog/hogql/database/schema/sessions.py | 22 +-
 .../database/schema/static_cohort_people.py | 4 +-
 .../schema/test/test_event_sessions.py | 4 +-
 .../test_session_where_clause_extractor.py | 4 +-
 posthog/hogql/escape_sql.py | 6 +-
 posthog/hogql/filters.py | 6 +-
 posthog/hogql/functions/action.py | 4 +-
 posthog/hogql/functions/cohort.py | 4 +-
 posthog/hogql/functions/mapping.py | 18 +-
 posthog/hogql/functions/sparkline.py | 4 +-
 posthog/hogql/hogql.py | 4 +-
 posthog/hogql/parser.py | 19 +-
 posthog/hogql/placeholders.py | 10 +-
 posthog/hogql/printer.py | 24 +-
 posthog/hogql/property.py | 4 +-
 posthog/hogql/query.py | 6 +-
 posthog/hogql/resolver.py | 16 +-
 posthog/hogql/resolver_utils.py | 4 +-
 posthog/hogql/test/_test_parser.py | 6 +-
 posthog/hogql/test/test_filters.py | 6 +-
 posthog/hogql/test/test_mapping.py | 24 +-
 posthog/hogql/test/test_printer.py | 4 +-
 posthog/hogql/test/test_property.py | 6 +-
 posthog/hogql/test/test_query.py | 2 +-
 posthog/hogql/test/test_resolver.py | 4 +-
 posthog/hogql/test/test_timings.py | 28 +-
 posthog/hogql/timings.py | 9 +-
 posthog/hogql/transforms/in_cohort.py | 28 +-
 posthog/hogql/transforms/lazy_tables.py | 56 +--
 posthog/hogql/transforms/property_types.py | 14 +-
 posthog/hogql_queries/actor_strategies.py | 30 +-
 posthog/hogql_queries/actors_query_runner.py | 15 +-
 posthog/hogql_queries/events_query_runner.py | 18 +-
 posthog/hogql_queries/hogql_query_runner.py | 5 +-
 .../hogql_queries/insights/funnels/base.py | 94 ++--
 .../hogql_queries/insights/funnels/funnel.py | 20 +-
 .../funnel_correlation_query_runner.py | 16 +-
 .../insights/funnels/funnel_event_query.py | 18 +-
 .../insights/funnels/funnel_persons.py | 6 +-
 .../insights/funnels/funnel_query_context.py | 10 +-
 .../insights/funnels/funnel_strict.py | 16 +-
 .../insights/funnels/funnel_strict_persons.py | 6 +-
 .../insights/funnels/funnel_trends.py | 24 +-
 .../insights/funnels/funnel_trends_persons.py | 5 +-
 .../insights/funnels/funnel_unordered.py | 28 +-
 .../funnels/funnel_unordered_persons.py | 8 +-
 .../insights/funnels/funnels_query_runner.py | 4 +-
 .../insights/funnels/test/breakdown_cases.py | 19 +-
 .../test_funnel_breakdowns_by_current_url.py | 4 +-
 .../funnels/test/test_funnel_correlation.py | 6 +-
 .../test/test_funnel_correlations_persons.py | 4 +-
 .../funnels/test/test_funnel_persons.py | 8 +-
 .../hogql_queries/insights/funnels/utils.py | 3 +-
 .../insights/lifecycle_query_runner.py | 4 +-
 .../insights/paths_query_runner.py | 4 +-
 .../insights/retention_query_runner.py | 4 +-
 .../insights/stickiness_query_runner.py | 14 +-
 .../test/test_insight_actors_query_runner.py | 4 +-
 .../insights/test/test_paths_query_runner.py | 3 +-
 .../test/test_stickiness_query_runner.py | 36 +-
 .../insights/trends/aggregation_operations.py | 10 +-
 .../insights/trends/breakdown.py | 24 +-
 .../insights/trends/breakdown_values.py | 10 +-
 .../insights/trends/test/test_trends.py | 22 +-
 .../trends/test/test_trends_query_runner.py | 18 +-
 .../insights/trends/trends_query_builder.py | 6 +-
 .../insights/trends/trends_query_runner.py | 40 +-
 .../hogql_queries/insights/trends/utils.py | 4 +-
 .../insights/utils/properties.py | 8 +-
 posthog/hogql_queries/insights/utils/utils.py | 4 +-
 .../legacy_compatibility/filter_to_query.py | 32 +-
 posthog/hogql_queries/query_runner.py | 42 +-
 .../sessions_timeline_query_runner.py | 4 +-
 .../test/test_events_query_runner.py | 4 +-
 .../hogql_queries/test/test_query_runner.py | 4 +-
 posthog/hogql_queries/utils/formula_ast.py | 10 +-
 .../hogql_queries/utils/query_date_range.py | 4 +-
 .../utils/query_previous_period_date_range.py | 8 +-
 .../test/test_web_analytics_query_runner.py | 12 +-
 .../web_analytics_query_runner.py | 6 +-
 posthog/jwt.py | 4 +-
 posthog/kafka_client/client.py | 9 +-
 posthog/kafka_client/helper.py | 8 +-
 .../backfill_distinct_id_overrides.py | 2 +-
 .../create_channel_definitions_file.py | 6 +-
 .../fix_person_distinct_ids_after_delete.py | 4 +-
 posthog/management/commands/makemigrations.py | 2 +-
 posthog/management/commands/partition.py | 2 +-
 .../commands/run_async_migrations.py | 10 +-
 .../management/commands/sync_feature_flags.py | 6 +-
 .../commands/sync_replicated_schema.py | 11 +-
 .../commands/test_migrations_are_safe.py | 4 +-
 posthog/middleware.py | 7 +-
 .../migrations/0027_move_elements_to_group.py | 3 +-
 .../0132_team_test_account_filters.py | 2 +-
 posthog/migrations/0219_migrate_tags_v2.py | 4 +-
 .../0259_backfill_team_recording_domains.py | 3 +-
 posthog/models/action/action.py | 6 +-
 posthog/models/action/util.py | 14 +-
 .../models/activity_logging/activity_log.py | 30 +-
 posthog/models/async_deletion/delete.py | 11 +-
 .../models/async_deletion/delete_cohorts.py | 10 +-
 .../models/async_deletion/delete_events.py | 12 +-
 posthog/models/async_migration.py | 4 +-
 posthog/models/channel_type/sql.py | 4 +-
 posthog/models/cohort/cohort.py | 14 +-
 posthog/models/cohort/util.py | 44 +-
 posthog/models/dashboard.py | 4 +-
 posthog/models/dashboard_tile.py | 6 +-
 posthog/models/element/element.py | 5 +-
 posthog/models/element_group.py | 6 +-
 posthog/models/entity/entity.py | 10 +-
 posthog/models/entity/util.py | 13 +-
 posthog/models/event/event.py | 22 +-
 posthog/models/event/query_event_list.py | 14 +-
 posthog/models/event/util.py | 30 +-
 posthog/models/exported_asset.py | 4 +-
 posthog/models/feature_flag/feature_flag.py | 18 +-
 posthog/models/feature_flag/flag_analytics.py | 4 +-
 posthog/models/feature_flag/flag_matching.py | 70 +--
 posthog/models/filters/base_filter.py | 16 +-
 posthog/models/filters/lifecycle_filter.py | 4 +-
 posthog/models/filters/mixins/base.py | 4 +-
 posthog/models/filters/mixins/common.py | 30 +-
 posthog/models/filters/mixins/funnel.py | 28 +-
 posthog/models/filters/mixins/paths.py | 12 +-
 posthog/models/filters/mixins/property.py | 14 +-
 posthog/models/filters/mixins/retention.py | 4 +-
 .../filters/mixins/session_recordings.py | 6 +-
 posthog/models/filters/mixins/simplify.py | 8 +-
 posthog/models/filters/mixins/stickiness.py | 3 +-
 posthog/models/filters/mixins/utils.py | 3 +-
 posthog/models/filters/path_filter.py | 4 +-
 posthog/models/filters/retention_filter.py | 6 +-
 posthog/models/filters/stickiness_filter.py | 5 +-
 posthog/models/filters/test/test_filter.py | 7 +-
 .../models/filters/test/test_path_filter.py | 4 +-
 posthog/models/group/util.py | 6 +-
 posthog/models/instance_setting.py | 4 +-
 posthog/models/integration.py | 8 +-
 posthog/models/organization.py | 12 +-
 posthog/models/organization_domain.py | 6 +-
 posthog/models/person/person.py | 12 +-
 posthog/models/person/util.py | 12 +-
 posthog/models/personal_api_key.py | 6 +-
 posthog/models/plugin.py | 30 +-
 posthog/models/project.py | 4 +-
 posthog/models/property/property.py | 23 +-
 posthog/models/property/util.py | 56 ++-
 posthog/models/sharing_configuration.py | 4 +-
 posthog/models/subscription.py | 6 +-
 posthog/models/tagged_item.py | 7 +-
 posthog/models/team/team.py | 8 +-
 posthog/models/team/util.py | 6 +-
 .../models/test/test_dashboard_tile_model.py | 5 +-
 posthog/models/uploaded_media.py | 4 +-
 posthog/models/user.py | 15 +-
 posthog/models/utils.py | 11 +-
 posthog/plugins/site.py | 10 +-
 posthog/plugins/utils.py | 24 +-
 posthog/queries/actor_base_query.py | 58 ++-
 .../queries/app_metrics/historical_exports.py | 4 +-
 .../app_metrics/test/test_app_metrics.py | 4 +-
 posthog/queries/base.py | 24 +-
 posthog/queries/breakdown_props.py | 24 +-
 .../column_optimizer/foss_column_optimizer.py | 23 +-
 posthog/queries/event_query/event_query.py | 28 +-
 posthog/queries/foss_cohort_query.py | 66 +--
 posthog/queries/funnels/base.py | 54 +--
 posthog/queries/funnels/funnel.py | 6 +-
 posthog/queries/funnels/funnel_event_query.py | 10 +-
 posthog/queries/funnels/funnel_persons.py | 4 +-
 posthog/queries/funnels/funnel_strict.py | 4 +-
 .../queries/funnels/funnel_strict_persons.py | 4 +-
 posthog/queries/funnels/funnel_trends.py | 6 +-
 posthog/queries/funnels/funnel_unordered.py | 10 +-
 .../funnels/funnel_unordered_persons.py | 4 +-
 .../queries/funnels/test/breakdown_cases.py | 14 +-
 .../test/test_breakdowns_by_current_url.py | 4 +-
 posthog/queries/funnels/utils.py | 4 +-
 .../groups_join_query/groups_join_query.py | 4 +-
 posthog/queries/paths/paths.py | 20 +-
 posthog/queries/paths/paths_actors.py | 4 +-
 posthog/queries/paths/paths_event_query.py | 10 +-
 posthog/queries/person_query.py | 34 +-
 .../properties_timeline.py | 12 +-
 .../properties_timeline_event_query.py | 10 +-
 posthog/queries/property_optimizer.py | 4 +-
 posthog/queries/property_values.py | 4 +-
 posthog/queries/query_date_range.py | 6 +-
 posthog/queries/retention/actors_query.py | 8 +-
 posthog/queries/retention/retention.py | 16 +-
 .../retention/retention_events_query.py | 4 +-
 posthog/queries/retention/types.py | 4 +-
 posthog/queries/stickiness/stickiness.py | 18 +-
 .../queries/stickiness/stickiness_actors.py | 4 +-
 .../stickiness/stickiness_event_query.py | 6 +-
 posthog/queries/test/test_paths.py | 3 +-
 posthog/queries/test/test_trends.py | 10 +-
 posthog/queries/time_to_see_data/hierarchy.py | 5 +-
 posthog/queries/time_to_see_data/sessions.py | 8 +-
 posthog/queries/trends/breakdown.py | 35 +-
 posthog/queries/trends/formula.py | 8 +-
 posthog/queries/trends/lifecycle.py | 13 +-
 posthog/queries/trends/lifecycle_actors.py | 4 +-
 .../queries/trends/test/test_breakdowns.py | 4 +-
 .../test/test_breakdowns_by_current_url.py | 4 +-
 posthog/queries/trends/test/test_formula.py | 4 +-
 .../trends/test/test_paging_breakdowns.py | 4 +-
 posthog/queries/trends/total_volume.py | 23 +-
 posthog/queries/trends/trends.py | 41 +-
 posthog/queries/trends/trends_actors.py | 8 +-
 posthog/queries/trends/trends_event_query.py | 4 +-
 .../queries/trends/trends_event_query_base.py | 10 +-
 posthog/queries/trends/util.py | 18 +-
 posthog/queries/util.py | 10 +-
 posthog/rate_limit.py | 4 +-
 posthog/renderers.py | 4 +-
 posthog/schema.py | 448 +++++++++---------
 posthog/session_recordings/models/metadata.py | 12 +-
 .../models/session_recording.py | 8 +-
 .../queries/session_query.py | 4 +-
 ...sion_recording_list_from_replay_summary.py | 50 +-
 .../queries/session_recording_properties.py | 18 +-
 .../queries/session_replay_events.py | 8 +-
 .../queries/test/session_replay_sql.py | 4 +-
 ...sion_recording_list_from_session_replay.py | 3 +-
 .../session_recordings/realtime_snapshots.py | 4 +-
 .../session_recording_api.py | 20 +-
 .../session_recording_helpers.py | 25 +-
 .../snapshots/convert_legacy_snapshots.py | 3 +-
 .../test/test_lts_session_recordings.py | 9 +-
 .../test/test_session_recording_helpers.py | 6 +-
 .../test/test_session_recordings.py | 7 +-
 posthog/settings/__init__.py | 5 +-
 posthog/settings/data_stores.py | 3 +-
 posthog/settings/logs.py | 3 +-
 posthog/settings/session_replay.py | 4 +-
 posthog/settings/temporal.py | 3 +-
 posthog/settings/utils.py | 7 +-
 posthog/settings/web.py | 3 +-
 posthog/storage/object_storage.py | 26 +-
 posthog/storage/test/test_object_storage.py | 10 +-
 posthog/tasks/calculate_cohort.py | 6 +-
 posthog/tasks/email.py | 4 +-
 posthog/tasks/exports/csv_exporter.py | 11 +-
 posthog/tasks/exports/ordered_csv_renderer.py | 5 +-
 .../tasks/exports/test/test_csv_exporter.py | 6 +-
 ...ync_all_organization_available_features.py | 3 +-
 posthog/tasks/test/test_calculate_cohort.py | 2 +-
 posthog/tasks/test/test_email.py | 3 +-
 posthog/tasks/test/test_usage_report.py | 6 +-
 posthog/tasks/test/utils_email_tests.py | 4 +-
 posthog/tasks/usage_report.py | 78 +--
 posthog/tasks/verify_persons_data_in_sync.py | 8 +-
 posthog/templatetags/posthog_assets.py | 3 +-
 .../temporal/batch_exports/batch_exports.py | 7 +-
 .../batch_exports/postgres_batch_export.py | 2 +-
 .../batch_exports/snowflake_batch_export.py | 2 +-
 posthog/temporal/batch_exports/utils.py | 3 +-
 posthog/temporal/common/clickhouse.py | 13 +-
 posthog/temporal/common/codec.py | 2 +-
 posthog/temporal/common/sentry.py | 4 +-
 posthog/temporal/common/utils.py | 2 +-
 .../data_imports/external_data_job.py | 3 +-
 .../pipelines/hubspot/__init__.py | 15 +-
 .../data_imports/pipelines/hubspot/auth.py | 3 +-
 .../data_imports/pipelines/hubspot/helpers.py | 19 +-
 .../data_imports/pipelines/pipeline.py | 6 +-
 .../pipelines/postgres/__init__.py | 7 +-
 .../pipelines/postgres/helpers.py | 7 +-
 .../data_imports/pipelines/stripe/helpers.py | 7 +-
 .../pipelines/zendesk/api_helpers.py | 6 +-
 .../pipelines/zendesk/credentials.py | 4 +-
 .../data_imports/pipelines/zendesk/helpers.py | 5 +-
 .../pipelines/zendesk/talk_api.py | 7 +-
 .../workflow_activities/create_job_model.py | 3 +-
 .../workflow_activities/import_data.py | 3 +-
 .../test_http_batch_export_workflow.py | 30 +-
 .../test_snowflake_batch_export_workflow.py | 22 +-
 .../external_data/test_external_data_job.py | 113 +++--
 posthog/test/base.py | 29 +-
 posthog/test/db_context_capturing.py | 2 +-
 posthog/test/test_feature_flag.py | 15 +-
 posthog/test/test_feature_flag_analytics.py | 20 +-
 posthog/test/test_health.py | 12 +-
 posthog/test/test_journeys.py | 28 +-
 posthog/test/test_utils.py | 2 +-
 posthog/urls.py | 5 +-
 posthog/user_permissions.py | 24 +-
 posthog/utils.py | 48 +-
 posthog/version_requirement.py | 4 +-
 posthog/views.py | 4 +-
 posthog/warehouse/api/external_data_schema.py | 6 +-
 posthog/warehouse/api/external_data_source.py | 10 +-
 posthog/warehouse/api/saved_query.py | 4 +-
 posthog/warehouse/api/table.py | 8 +-
 .../warehouse/data_load/validate_schema.py | 9 +-
 .../warehouse/external_data_source/source.py | 4 +-
 .../models/datawarehouse_saved_query.py | 3 +-
 .../warehouse/models/external_data_schema.py | 4 +-
 .../models/external_table_definitions.py | 3 +-
 posthog/warehouse/models/join.py | 4 +-
 posthog/warehouse/models/table.py | 6 +-
 posthog/year_in_posthog/calculate_2023.py | 4 +-
 posthog/year_in_posthog/year_in_posthog.py | 6 +-
 pyproject.toml | 6 +
 523 files changed, 3100 insertions(+), 3141 deletions(-)
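[Editorial note: the diffstat shows 523 files touched but only six lines of pyproject.toml changed — the rules are enabled centrally and the rest of the patch is the autofix fallout. The pyproject.toml hunk itself is not included in this excerpt, so the following is only an illustrative sketch of how pyupgrade is typically switched on through Ruff; the section and key names follow Ruff's documented configuration, and the concrete target version is an assumption, not the patch's actual contents.]

```toml
[tool.ruff]
# "Set correct Python version": pyupgrade only emits rewrites that are valid
# on the configured minimum interpreter.
target-version = "py310"  # assumed value, for illustration only

[tool.ruff.lint]
# "UP" is the prefix of Ruff's pyupgrade rule set.
select = ["UP"]
```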
diff --git a/bin/migrate_kafka_data.py b/bin/migrate_kafka_data.py
index 87eaf391657..3da55ed538c 100755
--- a/bin/migrate_kafka_data.py
+++ b/bin/migrate_kafka_data.py
@@ -21,7 +21,6 @@
 import argparse
 import sys
-from typing import List
 
 from kafka import KafkaAdminClient, KafkaConsumer, KafkaProducer
 from kafka.errors import KafkaError
@@ -192,7 +191,7 @@ def handle(**options):
         print("Polling for messages")  # noqa: T201
         messages_by_topic = consumer.poll(timeout_ms=timeout_ms)
 
-        futures: List[FutureRecordMetadata] = []
+        futures: list[FutureRecordMetadata] = []
 
         if not messages_by_topic:
             break
diff --git a/ee/api/authentication.py b/ee/api/authentication.py
index 2dfb6c7b9f0..f2850bcfb5f 100644
--- a/ee/api/authentication.py
+++ b/ee/api/authentication.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Union
+from typing import Any, Union
 
 from django.core.exceptions import ValidationError as DjangoValidationError
 from django.http.response import HttpResponse
@@ -91,8 +91,8 @@ class MultitenantSAMLAuth(SAMLAuth):
 
     def _get_attr(
         self,
-        response_attributes: Dict[str, Any],
-        attribute_names: List[str],
+        response_attributes: dict[str, Any],
+        attribute_names: list[str],
         optional: bool = False,
     ) -> str:
         """
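[Editorial note: these first two diffs show the rewrite that dominates the whole patch. With a Python 3.9+ target, pyupgrade (Ruff rule UP006) replaces the `typing.List`/`typing.Dict`-style aliases with the PEP 585 builtin generics and prunes the imports they leave behind. A minimal before/after sketch with a hypothetical function, not code from the patch:]

```python
# Before: pre-3.9 annotations need typing's capitalized aliases.
from typing import Dict, List

def tally(events: List[str]) -> Dict[str, int]:
    counts: Dict[str, int] = {}
    for event in events:
        counts[event] = counts.get(event, 0) + 1
    return counts

# After: on Python 3.9+ the builtin types are subscriptable directly,
# so the typing import disappears along with the aliases.
def tally_upgraded(events: list[str]) -> dict[str, int]:
    counts: dict[str, int] = {}
    for event in events:
        counts[event] = counts.get(event, 0) + 1
    return counts
```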
"user"] - def validate(self, attrs: Dict[str, Any]) -> Dict[str, Any]: + def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: dashboard: Dashboard = self.context["dashboard"] dashboard_permissions = self.user_permissions.dashboard(dashboard) if dashboard_permissions.effective_restriction_level <= Dashboard.RestrictionLevel.EVERYONE_IN_PROJECT_CAN_EDIT: @@ -96,7 +96,7 @@ class DashboardCollaboratorViewSet( serializer_class = DashboardCollaboratorSerializer filter_rewrite_rules = {"team_id": "dashboard__team_id"} - def get_serializer_context(self) -> Dict[str, Any]: + def get_serializer_context(self) -> dict[str, Any]: context = super().get_serializer_context() try: context["dashboard"] = Dashboard.objects.get(id=context["dashboard_id"]) diff --git a/ee/api/role.py b/ee/api/role.py index 44909f504ee..0c4894c2779 100644 --- a/ee/api/role.py +++ b/ee/api/role.py @@ -1,4 +1,4 @@ -from typing import List, cast +from typing import cast from django.db import IntegrityError from rest_framework import mixins, serializers, viewsets @@ -76,7 +76,7 @@ class RoleSerializer(serializers.ModelSerializer): return RoleMembershipSerializer(members, many=True).data def get_associated_flags(self, role: Role): - associated_flags: List[dict] = [] + associated_flags: list[dict] = [] role_access_objects = FeatureFlagRoleAccess.objects.filter(role=role).values_list("feature_flag_id") flags = FeatureFlag.objects.filter(id__in=role_access_objects) diff --git a/ee/api/sentry_stats.py b/ee/api/sentry_stats.py index 52b16647c2c..06b4e53b1bd 100644 --- a/ee/api/sentry_stats.py +++ b/ee/api/sentry_stats.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import requests from django.http import HttpRequest, JsonResponse @@ -9,8 +9,8 @@ from rest_framework.exceptions import ValidationError from posthog.models.instance_setting import get_instance_settings -def get_sentry_stats(start_time: str, end_time: str) -> Tuple[dict, int]: - sentry_config: Dict[str, str] = get_instance_settings(["SENTRY_AUTH_TOKEN", "SENTRY_ORGANIZATION"]) +def get_sentry_stats(start_time: str, end_time: str) -> tuple[dict, int]: + sentry_config: dict[str, str] = get_instance_settings(["SENTRY_AUTH_TOKEN", "SENTRY_ORGANIZATION"]) org_slug = sentry_config.get("SENTRY_ORGANIZATION") token = sentry_config.get("SENTRY_AUTH_TOKEN") @@ -41,9 +41,9 @@ def get_sentry_stats(start_time: str, end_time: str) -> Tuple[dict, int]: def get_tagged_issues_stats( - start_time: str, end_time: str, tags: Dict[str, str], target_issues: List[str] -) -> Dict[str, Any]: - sentry_config: Dict[str, str] = get_instance_settings(["SENTRY_AUTH_TOKEN", "SENTRY_ORGANIZATION"]) + start_time: str, end_time: str, tags: dict[str, str], target_issues: list[str] +) -> dict[str, Any]: + sentry_config: dict[str, str] = get_instance_settings(["SENTRY_AUTH_TOKEN", "SENTRY_ORGANIZATION"]) org_slug = sentry_config.get("SENTRY_ORGANIZATION") token = sentry_config.get("SENTRY_AUTH_TOKEN") @@ -58,7 +58,7 @@ def get_tagged_issues_stats( for tag, value in tags.items(): query += f" {tag}:{value}" - params: Dict[str, Union[list, str]] = { + params: dict[str, Union[list, str]] = { "start": start_time, "end": end_time, "sort": "freq", @@ -89,8 +89,8 @@ def get_stats_for_timerange( base_end_time: str, target_start_time: str, target_end_time: str, - tags: Optional[Dict[str, str]] = None, -) -> Tuple[int, int]: + tags: Optional[dict[str, str]] = None, +) -> tuple[int, int]: base_counts, 
diff --git a/ee/api/subscription.py b/ee/api/subscription.py
index 412ddc5cfaf..9f8881026fb 100644
--- a/ee/api/subscription.py
+++ b/ee/api/subscription.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any
 
 import jwt
 from django.db.models import QuerySet
@@ -67,7 +67,7 @@ class SubscriptionSerializer(serializers.ModelSerializer):
 
         return attrs
 
-    def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Subscription:
+    def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Subscription:
         request = self.context["request"]
         validated_data["team_id"] = self.context["team_id"]
         validated_data["created_by"] = request.user
diff --git a/ee/api/test/base.py b/ee/api/test/base.py
index 55e7930bfad..066dcc373d6 100644
--- a/ee/api/test/base.py
+++ b/ee/api/test/base.py
@@ -1,5 +1,5 @@
 import datetime
-from typing import Dict, Optional, cast
+from typing import Optional, cast
 
 from zoneinfo import ZoneInfo
 
@@ -20,7 +20,7 @@ class LicensedTestMixin:
     def license_required_response(
         self,
         message: str = "This feature is part of the premium PostHog offering. Self-hosted licenses are no longer available for purchase. Please contact sales@posthog.com to discuss options.",
-    ) -> Dict[str, Optional[str]]:
+    ) -> dict[str, Optional[str]]:
         return {
             "type": "server_error",
             "code": "payment_required",
diff --git a/ee/api/test/fixtures/available_product_features.py b/ee/api/test/fixtures/available_product_features.py
index 5be816a169b..8cc5413754d 100644
--- a/ee/api/test/fixtures/available_product_features.py
+++ b/ee/api/test/fixtures/available_product_features.py
@@ -1,6 +1,6 @@
-from typing import Any, Dict, List
+from typing import Any
 
-AVAILABLE_PRODUCT_FEATURES: List[Dict[str, Any]] = [
+AVAILABLE_PRODUCT_FEATURES: list[dict[str, Any]] = [
     {
         "description": "Create playlists of certain session recordings to easily find and watch them again in the future.",
         "key": "recordings_playlists",
diff --git a/ee/api/test/test_authentication.py b/ee/api/test/test_authentication.py
index 00fca66b291..451efdd3d37 100644
--- a/ee/api/test/test_authentication.py
+++ b/ee/api/test/test_authentication.py
@@ -364,7 +364,6 @@ class TestEESAMLAuthenticationAPI(APILicensedTest):
 
         with open(
             os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
-            "r",
             encoding="utf_8",
         ) as f:
             saml_response = f.read()
@@ -407,7 +406,6 @@ class TestEESAMLAuthenticationAPI(APILicensedTest):
 
         with open(
             os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response_alt_attribute_names"),
-            "r",
             encoding="utf_8",
         ) as f:
             saml_response = f.read()
@@ -474,7 +472,6 @@ YotAcSbU3p5bzd11wpyebYHB"""
 
         with open(
             os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
-            "r",
             encoding="utf_8",
         ) as f:
             saml_response = f.read()
@@ -514,7 +511,6 @@ YotAcSbU3p5bzd11wpyebYHB"""
 
         with open(
             os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
-            "r",
             encoding="utf_8",
         ) as f:
             saml_response = f.read()
@@ -552,7 +548,6 @@ YotAcSbU3p5bzd11wpyebYHB"""
 
         with open(
             os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response_no_first_name"),
-            "r",
             encoding="utf_8",
         ) as f:
             saml_response = f.read()
@@ -594,7 +589,6 @@ YotAcSbU3p5bzd11wpyebYHB"""
 
         with open(
             os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
-            "r",
             encoding="utf_8",
         ) as f:
             saml_response = f.read()
@@ -683,7 +677,6 @@ YotAcSbU3p5bzd11wpyebYHB"""
 
         with open(
             os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"),
-            "r",
             encoding="utf_8",
         ) as f:
             saml_response = f.read()
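[Editorial note: the `"r",` deletions repeated through the SAML tests above are Ruff rule UP015 — `open()` already defaults to reading text, so the explicit mode string is redundant. A small sketch with a hypothetical file path:]

```python
# Before: the "r" adds nothing -- read-text is open()'s default mode.
with open("fixtures/example.txt", "r", encoding="utf_8") as f:
    before = f.read()

# After: identical behavior with one argument fewer.
with open("fixtures/example.txt", encoding="utf_8") as f:
    after = f.read()

assert before == after
```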
"fixtures/saml_login_response"), - "r", encoding="utf_8", ) as f: saml_response = f.read() diff --git a/ee/api/test/test_billing.py b/ee/api/test/test_billing.py index c1698bd1cae..94eed34d29d 100644 --- a/ee/api/test/test_billing.py +++ b/ee/api/test/test_billing.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, List +from typing import Any from unittest.mock import MagicMock, patch from uuid import uuid4 from zoneinfo import ZoneInfo @@ -22,7 +22,7 @@ from posthog.models.team import Team from posthog.test.base import APIBaseTest, _create_event, flush_persons_and_events -def create_billing_response(**kwargs) -> Dict[str, Any]: +def create_billing_response(**kwargs) -> dict[str, Any]: data: Any = {"license": {"type": "cloud"}} data.update(kwargs) return data @@ -106,7 +106,7 @@ def create_billing_customer(**kwargs) -> CustomerInfo: return data -def create_billing_products_response(**kwargs) -> Dict[str, List[CustomerProduct]]: +def create_billing_products_response(**kwargs) -> dict[str, list[CustomerProduct]]: data: Any = { "products": [ CustomerProduct( diff --git a/ee/api/test/test_capture.py b/ee/api/test/test_capture.py index 891a9759a80..4f716d78509 100644 --- a/ee/api/test/test_capture.py +++ b/ee/api/test/test_capture.py @@ -68,26 +68,26 @@ class TestCaptureAPI(APIBaseTest): self.assertEqual(event2_data["properties"]["distinct_id"], "id2") # Make sure we're producing data correctly in the way the plugin server expects - self.assertEquals(type(kafka_produce_call1["data"]["distinct_id"]), str) - self.assertEquals(type(kafka_produce_call2["data"]["distinct_id"]), str) + self.assertEqual(type(kafka_produce_call1["data"]["distinct_id"]), str) + self.assertEqual(type(kafka_produce_call2["data"]["distinct_id"]), str) self.assertIn(type(kafka_produce_call1["data"]["ip"]), [str, type(None)]) self.assertIn(type(kafka_produce_call2["data"]["ip"]), [str, type(None)]) - self.assertEquals(type(kafka_produce_call1["data"]["site_url"]), str) - self.assertEquals(type(kafka_produce_call2["data"]["site_url"]), str) + self.assertEqual(type(kafka_produce_call1["data"]["site_url"]), str) + self.assertEqual(type(kafka_produce_call2["data"]["site_url"]), str) - self.assertEquals(type(kafka_produce_call1["data"]["token"]), str) - self.assertEquals(type(kafka_produce_call2["data"]["token"]), str) + self.assertEqual(type(kafka_produce_call1["data"]["token"]), str) + self.assertEqual(type(kafka_produce_call2["data"]["token"]), str) - self.assertEquals(type(kafka_produce_call1["data"]["sent_at"]), str) - self.assertEquals(type(kafka_produce_call2["data"]["sent_at"]), str) + self.assertEqual(type(kafka_produce_call1["data"]["sent_at"]), str) + self.assertEqual(type(kafka_produce_call2["data"]["sent_at"]), str) - self.assertEquals(type(event1_data["properties"]), dict) - self.assertEquals(type(event2_data["properties"]), dict) + self.assertEqual(type(event1_data["properties"]), dict) + self.assertEqual(type(event2_data["properties"]), dict) - self.assertEquals(type(kafka_produce_call1["data"]["uuid"]), str) - self.assertEquals(type(kafka_produce_call2["data"]["uuid"]), str) + self.assertEqual(type(kafka_produce_call1["data"]["uuid"]), str) + self.assertEqual(type(kafka_produce_call2["data"]["uuid"]), str) @patch("posthog.kafka_client.client._KafkaProducer.produce") def test_capture_event_with_uuid_in_payload(self, kafka_produce): diff --git a/ee/api/test/test_dashboard.py b/ee/api/test/test_dashboard.py index 8c39a17135d..e494dfbce7a 100644 --- a/ee/api/test/test_dashboard.py +++ 
diff --git a/ee/api/test/test_dashboard.py b/ee/api/test/test_dashboard.py
index 8c39a17135d..e494dfbce7a 100644
--- a/ee/api/test/test_dashboard.py
+++ b/ee/api/test/test_dashboard.py
@@ -106,7 +106,7 @@ class TestDashboardEnterpriseAPI(APILicensedTest):
         response_data = response.json()
 
         self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-        self.assertEquals(
+        self.assertEqual(
             response_data,
             self.permission_denied_response(
                 "Only the dashboard owner and project admins have the restriction rights required to change the dashboard's restriction level."
@@ -178,7 +178,7 @@ class TestDashboardEnterpriseAPI(APILicensedTest):
         response_data = response.json()
 
         self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-        self.assertEquals(
+        self.assertEqual(
             response_data,
             self.permission_denied_response("You don't have edit permissions for this dashboard."),
         )
@@ -262,7 +262,7 @@ class TestDashboardEnterpriseAPI(APILicensedTest):
         response_data = response.json()
 
         self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-        self.assertEquals(
+        self.assertEqual(
             response_data,
             self.permission_denied_response("You don't have edit permissions for this dashboard."),
         )
diff --git a/ee/api/test/test_event_definition.py b/ee/api/test/test_event_definition.py
index 6e3cbb8775f..2aa87e63e2e 100644
--- a/ee/api/test/test_event_definition.py
+++ b/ee/api/test/test_event_definition.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import cast, Optional, List, Dict, Any
+from typing import cast, Optional, Any
 
 import dateutil.parser
 from django.utils import timezone
@@ -26,7 +26,7 @@ class TestEventDefinitionEnterpriseAPI(APIBaseTest):
     Ignoring the verified field we'd expect ordering purchase, watched_movie, entered_free_trial, $pageview
     With it we expect watched_movie, entered_free_trial, purchase, $pageview
     """
-    EXPECTED_EVENT_DEFINITIONS: List[Dict[str, Any]] = [
+    EXPECTED_EVENT_DEFINITIONS: list[dict[str, Any]] = [
         {"name": "purchase", "verified": None},
         {"name": "entered_free_trial", "verified": True},
         {"name": "watched_movie", "verified": True},
diff --git a/ee/api/test/test_insight.py b/ee/api/test/test_insight.py
index 00863551500..7db46bf79de 100644
--- a/ee/api/test/test_insight.py
+++ b/ee/api/test/test_insight.py
@@ -1,6 +1,6 @@
 import json
 from datetime import timedelta
-from typing import cast, Optional, List, Dict
+from typing import cast, Optional
 from django.test import override_settings
 from django.utils import timezone
 from freezegun import freeze_time
@@ -305,7 +305,7 @@ class TestInsightEnterpriseAPI(APILicensedTest):
         dashboard.refresh_from_db()
 
         self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-        self.assertEquals(
+        self.assertEqual(
             response_data,
             self.permission_denied_response(
                 "This insight is on a dashboard that can only be edited by its owner, team members invited to editing the dashboard, and project admins."
@@ -547,7 +547,7 @@ class TestInsightEnterpriseAPI(APILicensedTest):
     @override_settings(PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=False)
     @snapshot_postgres_queries
     def test_listing_insights_does_not_nplus1(self) -> None:
-        query_counts: List[int] = []
+        query_counts: list[int] = []
         queries = []
 
         for i in range(5):
@@ -587,10 +587,10 @@ class TestInsightEnterpriseAPI(APILicensedTest):
             f"received query counts\n\n{query_counts}",
         )
 
-    def assert_insight_activity(self, insight_id: Optional[int], expected: List[Dict]):
+    def assert_insight_activity(self, insight_id: Optional[int], expected: list[dict]):
         activity_response = self.dashboard_api.get_insight_activity(insight_id)
 
-        activity: List[Dict] = activity_response["results"]
+        activity: list[dict] = activity_response["results"]
 
         self.maxDiff = None
         assert activity == expected
diff --git a/ee/api/test/test_integration.py b/ee/api/test/test_integration.py
index d675415e4bd..7f30635b5af 100644
--- a/ee/api/test/test_integration.py
+++ b/ee/api/test/test_integration.py
@@ -25,7 +25,7 @@ class TestIntegration(APILicensedTest):
         signature = (
             "v0="
             + hmac.new(
-                "not-so-secret".encode("utf-8"),
+                b"not-so-secret",
                 sig_basestring.encode("utf-8"),
                 digestmod=hashlib.sha256,
             ).hexdigest()
diff --git a/ee/api/test/test_property_definition.py b/ee/api/test/test_property_definition.py
index ef8d4dd9285..effa43a9f4b 100644
--- a/ee/api/test/test_property_definition.py
+++ b/ee/api/test/test_property_definition.py
@@ -1,4 +1,4 @@
-from typing import cast, Optional, List, Dict
+from typing import cast, Optional
 from freezegun import freeze_time
 import pytest
 from django.db.utils import IntegrityError
@@ -450,7 +450,7 @@ class TestPropertyDefinitionEnterpriseAPI(APIBaseTest):
             plan="enterprise", valid_until=timezone.datetime(2500, 1, 19, 3, 14, 7)
         )
 
-        properties: List[Dict] = [
+        properties: list[dict] = [
             {"name": "1_when_verified", "verified": True},
             {"name": "2_when_verified", "verified": True},
             {"name": "3_when_verified", "verified": True},
diff --git a/ee/api/test/test_time_to_see_data.py b/ee/api/test/test_time_to_see_data.py
index 4c5a50d51e5..1ad6b4b0813 100644
--- a/ee/api/test/test_time_to_see_data.py
+++ b/ee/api/test/test_time_to_see_data.py
@@ -1,6 +1,6 @@
 import json
 from dataclasses import asdict, dataclass, field
-from typing import Any, List
+from typing import Any
 from unittest import mock
 
 import pytest
@@ -64,7 +64,7 @@ class TestTimeToSeeDataApi(APIBaseTest):
         )
 
         response = self.client.post("/api/time_to_see_data/sessions").json()
-        self.assertEquals(
+        self.assertEqual(
             response,
             [
                 {
@@ -209,18 +209,18 @@ class QueryLogRow:
     query_time_range_days: int = 1
     has_joins: int = 0
     has_json_operations: int = 0
-    filter_by_type: List[str] = field(default_factory=list)
-    breakdown_by: List[str] = field(default_factory=list)
-    entity_math: List[str] = field(default_factory=list)
+    filter_by_type: list[str] = field(default_factory=list)
+    breakdown_by: list[str] = field(default_factory=list)
+    entity_math: list[str] = field(default_factory=list)
     filter: str = ""
     ProfileEvents: dict = field(default_factory=dict)
-    tables: List[str] = field(default_factory=list)
-    columns: List[str] = field(default_factory=list)
+    tables: list[str] = field(default_factory=list)
+    columns: list[str] = field(default_factory=list)
     query: str = ""
     log_comment = ""
 
 
-def insert(table: str, rows: List):
+def insert(table: str, rows: list):
     columns = asdict(rows[0]).keys()
 
     all_values, params = [], {}
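[Editorial note: the `test_integration` change above is Ruff rule UP012 — calling `.encode("utf-8")` on a plain string literal is equivalent to writing a bytes literal, so the HMAC key becomes `b"not-so-secret"`. A sketch of the equivalence, with a hypothetical payload value:]

```python
import hashlib
import hmac

# Encoding a literal at runtime and writing a bytes literal produce equal objects.
assert "not-so-secret".encode("utf-8") == b"not-so-secret"

sig_basestring = "v0:123:payload"  # hypothetical value
signature = "v0=" + hmac.new(
    b"not-so-secret",
    sig_basestring.encode("utf-8"),  # kept: encoding a *variable* still needs .encode()
    digestmod=hashlib.sha256,
).hexdigest()
print(signature)
```

Note that `sig_basestring.encode("utf-8")` survives the rewrite: the rule only applies to literals, where the bytes are knowable at compile time.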
diff --git a/ee/benchmarks/benchmarks.py b/ee/benchmarks/benchmarks.py
index 83e82df068f..d999467779a 100644
--- a/ee/benchmarks/benchmarks.py
+++ b/ee/benchmarks/benchmarks.py
@@ -2,7 +2,6 @@
 # Needs to be first to set up django environment
 from .helpers import benchmark_clickhouse, no_materialized_columns, now
 from datetime import timedelta
-from typing import List, Tuple
 from ee.clickhouse.materialized_columns.analyze import (
     backfill_materialized_columns,
     get_materialized_columns,
@@ -29,7 +28,7 @@ from posthog.models.filters.filter import Filter
 from posthog.models.property import PropertyName, TableWithProperties
 from posthog.constants import FunnelCorrelationType
 
-MATERIALIZED_PROPERTIES: List[Tuple[TableWithProperties, PropertyName]] = [
+MATERIALIZED_PROPERTIES: list[tuple[TableWithProperties, PropertyName]] = [
    ("events", "$host"),
    ("events", "$current_url"),
    ("events", "$event_type"),
diff --git a/ee/billing/billing_manager.py b/ee/billing/billing_manager.py
index da95c0871f5..c301e80f8c2 100644
--- a/ee/billing/billing_manager.py
+++ b/ee/billing/billing_manager.py
@@ -1,5 +1,5 @@
 from datetime import datetime, timedelta
-from typing import Any, Dict, Optional, cast
+from typing import Any, Optional, cast
 
 import jwt
 import requests
@@ -53,7 +53,7 @@ class BillingManager:
     def __init__(self, license):
         self.license = license or get_cached_instance_license()
 
-    def get_billing(self, organization: Optional[Organization], plan_keys: Optional[str]) -> Dict[str, Any]:
+    def get_billing(self, organization: Optional[Organization], plan_keys: Optional[str]) -> dict[str, Any]:
         if organization and self.license and self.license.is_v2_license:
             billing_service_response = self._get_billing(organization)
 
@@ -63,7 +63,7 @@ class BillingManager:
             if organization and billing_service_response:
                 self.update_org_details(organization, billing_service_response)
 
-        response: Dict[str, Any] = {"available_features": []}
+        response: dict[str, Any] = {"available_features": []}
 
         response["license"] = {"plan": self.license.plan}
 
@@ -102,7 +102,7 @@ class BillingManager:
 
         return response
 
-    def update_billing(self, organization: Organization, data: Dict[str, Any]) -> None:
+    def update_billing(self, organization: Organization, data: dict[str, Any]) -> None:
         res = requests.patch(
             f"{BILLING_SERVICE_URL}/api/billing/",
             headers=self.get_auth_headers(organization),
diff --git a/ee/billing/billing_types.py b/ee/billing/billing_types.py
index 6151ad32880..0761e02e807 100644
--- a/ee/billing/billing_types.py
+++ b/ee/billing/billing_types.py
@@ -1,5 +1,5 @@
 from decimal import Decimal
-from typing import Dict, List, Optional, TypedDict
+from typing import Optional, TypedDict
 
 from posthog.constants import AvailableFeature
 
@@ -18,7 +18,7 @@ class CustomerProduct(TypedDict):
     image_url: Optional[str]
     type: str
     free_allocation: int
-    tiers: List[Tier]
+    tiers: list[Tier]
     tiered: bool
     unit_amount_usd: Optional[Decimal]
     current_amount_usd: Decimal
@@ -51,16 +51,16 @@ class CustomerInfo(TypedDict):
     deactivated: bool
     has_active_subscription: bool
     billing_period: BillingPeriod
-    available_features: List[AvailableFeature]
+    available_features: list[AvailableFeature]
     current_total_amount_usd: Optional[str]
     current_total_amount_usd_after_discount: Optional[str]
-    products: Optional[List[CustomerProduct]]
-    custom_limits_usd: Optional[Dict[str, str]]
-    usage_summary: Optional[Dict[str, Dict[str, Optional[int]]]]
+    products: Optional[list[CustomerProduct]]
+    custom_limits_usd: Optional[dict[str, str]]
+    usage_summary: Optional[dict[str, dict[str, Optional[int]]]]
     free_trial_until: Optional[str]
     discount_percent: Optional[int]
     discount_amount_usd: Optional[str]
-    customer_trust_scores: Dict[str, int]
+    customer_trust_scores: dict[str, int]
 
 
 class BillingStatus(TypedDict):
diff --git a/ee/billing/quota_limiting.py b/ee/billing/quota_limiting.py
index 1c50b69803a..8f5864c3ed5 100644
--- a/ee/billing/quota_limiting.py
+++ b/ee/billing/quota_limiting.py
@@ -1,7 +1,8 @@
 import copy
 from datetime import datetime, timedelta
 from enum import Enum
-from typing import Dict, List, Mapping, Optional, Sequence, Tuple, TypedDict, cast
+from typing import Optional, TypedDict, cast
+from collections.abc import Mapping, Sequence
 
 import dateutil.parser
 import posthoganalytics
@@ -66,13 +67,13 @@ def add_limited_team_tokens(resource: QuotaResource, tokens: Mapping[str, int],
     redis_client.zadd(f"{cache_key}{resource.value}", tokens)  # type: ignore # (zadd takes a Mapping[str, int] but the derived Union type is wrong)
 
 
-def remove_limited_team_tokens(resource: QuotaResource, tokens: List[str], cache_key: QuotaLimitingCaches) -> None:
+def remove_limited_team_tokens(resource: QuotaResource, tokens: list[str], cache_key: QuotaLimitingCaches) -> None:
     redis_client = get_client()
     redis_client.zrem(f"{cache_key}{resource.value}", *tokens)
 
 
 @cache_for(timedelta(seconds=30), background_refresh=True)
-def list_limited_team_attributes(resource: QuotaResource, cache_key: QuotaLimitingCaches) -> List[str]:
+def list_limited_team_attributes(resource: QuotaResource, cache_key: QuotaLimitingCaches) -> list[str]:
     now = timezone.now()
     redis_client = get_client()
     results = redis_client.zrangebyscore(f"{cache_key}{resource.value}", min=now.timestamp(), max="+inf")
@@ -86,7 +87,7 @@ class UsageCounters(TypedDict):
 
 
 def org_quota_limited_until(
-    organization: Organization, resource: QuotaResource, previously_quota_limited_team_tokens: List[str]
+    organization: Organization, resource: QuotaResource, previously_quota_limited_team_tokens: list[str]
 ) -> Optional[OrgQuotaLimitingInformation]:
     if not organization.usage:
         return None
@@ -265,7 +266,7 @@ def sync_org_quota_limits(organization: Organization):
 
 def get_team_attribute_by_quota_resource(organization: Organization, resource: QuotaResource):
     if resource in [QuotaResource.EVENTS, QuotaResource.RECORDINGS]:
-        team_tokens: List[str] = [x for x in list(organization.teams.values_list("api_token", flat=True)) if x]
+        team_tokens: list[str] = [x for x in list(organization.teams.values_list("api_token", flat=True)) if x]
 
         if not team_tokens:
             capture_exception(Exception(f"quota_limiting: No team tokens found for organization: {organization.id}"))
@@ -274,7 +275,7 @@ def get_team_attribute_by_quota_resource(organization: Organization, resource: Q
         return team_tokens
 
     if resource == QuotaResource.ROWS_SYNCED:
-        team_ids: List[str] = [x for x in list(organization.teams.values_list("id", flat=True)) if x]
+        team_ids: list[str] = [x for x in list(organization.teams.values_list("id", flat=True)) if x]
 
         if not team_ids:
             capture_exception(Exception(f"quota_limiting: No team ids found for organization: {organization.id}"))
@@ -322,7 +323,7 @@ def set_org_usage_summary(
 
 def update_all_org_billing_quotas(
     dry_run: bool = False,
-) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
+) -> tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]:
     period = get_current_day()
     period_start, period_end = period
@@ -352,8 +353,8 @@ def update_all_org_billing_quotas(
         )
     )
 
-    todays_usage_report: Dict[str, UsageCounters] = {}
-    orgs_by_id: Dict[str, Organization] = {}
+    todays_usage_report: dict[str, UsageCounters] = {}
+    orgs_by_id: dict[str, Organization] = {}
 
     # we iterate through all teams, and add their usage to the organization they belong to
     for team in teams:
@@ -373,12 +374,12 @@ def update_all_org_billing_quotas(
         for field in team_report:
             org_report[field] += team_report[field]  # type: ignore
 
-    quota_limited_orgs: Dict[str, Dict[str, int]] = {x.value: {} for x in QuotaResource}
-    quota_limiting_suspended_orgs: Dict[str, Dict[str, int]] = {x.value: {} for x in QuotaResource}
+    quota_limited_orgs: dict[str, dict[str, int]] = {x.value: {} for x in QuotaResource}
+    quota_limiting_suspended_orgs: dict[str, dict[str, int]] = {x.value: {} for x in QuotaResource}
 
     # Get the current quota limits so we can track to poshog if it changes
     orgs_with_changes = set()
-    previously_quota_limited_team_tokens: Dict[str, List[str]] = {x.value: [] for x in QuotaResource}
+    previously_quota_limited_team_tokens: dict[str, list[str]] = {x.value: [] for x in QuotaResource}
 
     for field in quota_limited_orgs:
         previously_quota_limited_team_tokens[field] = list_limited_team_attributes(
@@ -405,8 +406,8 @@ def update_all_org_billing_quotas(
         elif quota_limited_until:
             quota_limited_orgs[field][org_id] = quota_limited_until
 
-    quota_limited_teams: Dict[str, Dict[str, int]] = {x.value: {} for x in QuotaResource}
-    quota_limiting_suspended_teams: Dict[str, Dict[str, int]] = {x.value: {} for x in QuotaResource}
+    quota_limited_teams: dict[str, dict[str, int]] = {x.value: {} for x in QuotaResource}
+    quota_limiting_suspended_teams: dict[str, dict[str, int]] = {x.value: {} for x in QuotaResource}
 
     # Convert the org ids to team tokens
     for team in teams:
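[Editorial note: the import split at the top of the quota-limiting diff is Ruff rule UP035 — since Python 3.9, importing `Mapping` and `Sequence` (and `Generator`, seen in the next file) from `typing` is deprecated in favor of `collections.abc`. A sketch of the resulting style with a hypothetical helper; `Optional` stays in `typing`, matching this patch:]

```python
from collections.abc import Mapping, Sequence
from typing import Optional

def first_over_limit(usage: Mapping[str, int], tokens: Sequence[str], limit: int) -> Optional[str]:
    # Mapping/Sequence describe the read-only interfaces the function relies on,
    # so callers may pass a dict plus a list, a tuple, etc.
    for token in tokens:
        if usage.get(token, 0) > limit:
            return token
    return None

print(first_over_limit({"a": 5, "b": 1}, ("a", "b"), limit=3))  # -> "a"
```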
diff --git a/ee/clickhouse/materialized_columns/analyze.py b/ee/clickhouse/materialized_columns/analyze.py
index dac1aa6abc0..e8801fe17f6 100644
--- a/ee/clickhouse/materialized_columns/analyze.py
+++ b/ee/clickhouse/materialized_columns/analyze.py
@@ -1,6 +1,7 @@
 import re
 from datetime import timedelta
-from typing import Dict, Generator, List, Optional, Set, Tuple
+from typing import Optional
+from collections.abc import Generator
 
 import structlog
 
@@ -27,18 +28,18 @@ from posthog.models.property import PropertyName, TableColumn, TableWithProperti
 from posthog.models.property_definition import PropertyDefinition
 from posthog.models.team import Team
 
-Suggestion = Tuple[TableWithProperties, TableColumn, PropertyName]
+Suggestion = tuple[TableWithProperties, TableColumn, PropertyName]
 
 logger = structlog.get_logger(__name__)
 
 
 class TeamManager:
     @instance_memoize
-    def person_properties(self, team_id: str) -> Set[str]:
+    def person_properties(self, team_id: str) -> set[str]:
         return self._get_properties(GET_PERSON_PROPERTIES_COUNT, team_id)
 
     @instance_memoize
-    def event_properties(self, team_id: str) -> Set[str]:
+    def event_properties(self, team_id: str) -> set[str]:
         return set(
             PropertyDefinition.objects.filter(team_id=team_id, type=PropertyDefinition.Type.EVENT).values_list(
                 "name", flat=True
@@ -46,17 +47,17 @@ class TeamManager:
         )
 
     @instance_memoize
-    def person_on_events_properties(self, team_id: str) -> Set[str]:
+    def person_on_events_properties(self, team_id: str) -> set[str]:
         return self._get_properties(GET_EVENT_PROPERTIES_COUNT.format(column_name="person_properties"), team_id)
 
     @instance_memoize
-    def group_on_events_properties(self, group_type_index: int, team_id: str) -> Set[str]:
+    def group_on_events_properties(self, group_type_index: int, team_id: str) -> set[str]:
         return self._get_properties(
             GET_EVENT_PROPERTIES_COUNT.format(column_name=f"group{group_type_index}_properties"),
             team_id,
         )
 
-    def _get_properties(self, query, team_id) -> Set[str]:
+    def _get_properties(self, query, team_id) -> set[str]:
         rows = sync_execute(query, {"team_id": team_id})
         return {name for name, _ in rows}
 
@@ -86,12 +87,12 @@ class Query:
         return matches[0] if matches else None
 
     @cached_property
-    def _all_properties(self) -> List[Tuple[str, PropertyName]]:
+    def _all_properties(self) -> list[tuple[str, PropertyName]]:
         return re.findall(r"JSONExtract\w+\((\S+), '([^']+)'\)", self.query_string)
 
     def properties(
         self, team_manager: TeamManager
-    ) -> Generator[Tuple[TableWithProperties, TableColumn, PropertyName], None, None]:
+    ) -> Generator[tuple[TableWithProperties, TableColumn, PropertyName], None, None]:
         # Reverse-engineer whether a property is an "event" or "person" property by getting their event definitions.
         # :KLUDGE: Note that the same property will be found on both tables if both are used.
         # We try to hone in on the right column by looking at the column from which the property is extracted.
@@ -124,7 +125,7 @@ class Query:
                 yield "events", "group4_properties", property
 
 
-def _analyze(since_hours_ago: int, min_query_time: int) -> List[Suggestion]:
+def _analyze(since_hours_ago: int, min_query_time: int) -> list[Suggestion]:
     "Finds columns that should be materialized"
 
     raw_queries = sync_execute(
@@ -179,7 +180,7 @@ LIMIT 100 -- Make sure we don't add 100s of columns in one run
 
 
 def materialize_properties_task(
-    columns_to_materialize: Optional[List[Suggestion]] = None,
+    columns_to_materialize: Optional[list[Suggestion]] = None,
     time_to_analyze_hours: int = MATERIALIZE_COLUMNS_ANALYSIS_PERIOD_HOURS,
     maximum: int = MATERIALIZE_COLUMNS_MAX_AT_ONCE,
     min_query_time: int = MATERIALIZE_COLUMNS_MINIMUM_QUERY_TIME,
@@ -203,7 +204,7 @@ def materialize_properties_task(
     else:
         logger.info("Found no columns to materialize.")
 
-    properties: Dict[TableWithProperties, List[Tuple[PropertyName, TableColumn]]] = {
+    properties: dict[TableWithProperties, list[tuple[PropertyName, TableColumn]]] = {
         "events": [],
         "person": [],
     }
diff --git a/ee/clickhouse/materialized_columns/columns.py b/ee/clickhouse/materialized_columns/columns.py
index 71bfd5adcc7..1340abde0a6 100644
--- a/ee/clickhouse/materialized_columns/columns.py
+++ b/ee/clickhouse/materialized_columns/columns.py
@@ -1,6 +1,6 @@
 import re
 from datetime import timedelta
-from typing import Dict, List, Literal, Tuple, Union, cast
+from typing import Literal, Union, cast
 
 from clickhouse_driver.errors import ServerException
 from django.utils.timezone import now
@@ -36,7 +36,7 @@ SHORT_TABLE_COLUMN_NAME = {
 @cache_for(timedelta(minutes=15))
 def get_materialized_columns(
     table: TablesWithMaterializedColumns,
-) -> Dict[Tuple[PropertyName, TableColumn], ColumnName]:
+) -> dict[tuple[PropertyName, TableColumn], ColumnName]:
     rows = sync_execute(
         """
         SELECT comment, name
@@ -141,7 +141,7 @@ def add_minmax_index(table: TablesWithMaterializedColumns, column_name: str):
 
 def backfill_materialized_columns(
     table: TableWithProperties,
-    properties: List[Tuple[PropertyName, TableColumn]],
+    properties: list[tuple[PropertyName, TableColumn]],
     backfill_period: timedelta,
     test_settings=None,
 ) -> None:
@@ -215,7 +215,7 @@ def _materialized_column_name(
     return f"{prefix}{property_str}{suffix}"
 
 
-def _extract_property(comment: str) -> Tuple[PropertyName, TableColumn]:
+def _extract_property(comment: str) -> tuple[PropertyName, TableColumn]:
     # Old style comments have the format "column_materializer::property", dealing with the default table column.
     # Otherwise, it's "column_materializer::table_column::property"
     split_column = comment.split("::", 2)
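[Editorial note: as the surviving `from typing import Literal, Union, cast` above illustrates, the upgrade is selective — container generics move to the builtins and ABCs move to `collections.abc`, while names with no builtin equivalent (`Optional`, `Union`, `Literal`, `cast`, `TypedDict`) remain in `typing` under this patch's target version. A sketch of the resulting import mix, using hypothetical names:]

```python
from collections.abc import Iterable
from typing import Literal, Optional, cast

TableKind = Literal["events", "person"]  # hypothetical alias for illustration

def first_column(columns: Iterable[str], default: Optional[str] = None) -> Optional[str]:
    # Builtin generics for containers; typing for everything without a builtin form.
    found: list[str] = list(columns)
    return found[0] if found else cast(Optional[str], default)

print(first_column([], default="properties"))  # -> "properties"
```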
"column_materializer::property", dealing with the default table column. # Otherwise, it's "column_materializer::table_column::property" split_column = comment.split("::", 2) diff --git a/ee/clickhouse/models/test/test_action.py b/ee/clickhouse/models/test/test_action.py index 692844e55c1..4f06b3e871a 100644 --- a/ee/clickhouse/models/test/test_action.py +++ b/ee/clickhouse/models/test/test_action.py @@ -1,5 +1,4 @@ import dataclasses -from typing import List from posthog.client import sync_execute from posthog.hogql.hogql import HogQLContext @@ -22,7 +21,7 @@ class MockEvent: distinct_id: str -def _get_events_for_action(action: Action) -> List[MockEvent]: +def _get_events_for_action(action: Action) -> list[MockEvent]: hogql_context = HogQLContext(team_id=action.team_id) formatted_query, params = format_action_filter( team_id=action.team_id, action=action, prepend="", hogql_context=hogql_context diff --git a/ee/clickhouse/models/test/test_property.py b/ee/clickhouse/models/test/test_property.py index 913058d4ae1..6348697d844 100644 --- a/ee/clickhouse/models/test/test_property.py +++ b/ee/clickhouse/models/test/test_property.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List, Literal, Union, cast +from typing import Literal, Union, cast from uuid import UUID import pytest @@ -43,7 +43,7 @@ from posthog.test.base import ( class TestPropFormat(ClickhouseTestMixin, BaseTest): CLASS_DATA_LEVEL_SETUP = False - def _run_query(self, filter: Filter, **kwargs) -> List: + def _run_query(self, filter: Filter, **kwargs) -> list: query, params = parse_prop_grouped_clauses( property_group=filter.property_groups, allow_denormalized_props=True, @@ -776,7 +776,7 @@ class TestPropFormat(ClickhouseTestMixin, BaseTest): class TestPropDenormalized(ClickhouseTestMixin, BaseTest): CLASS_DATA_LEVEL_SETUP = False - def _run_query(self, filter: Filter, join_person_tables=False) -> List: + def _run_query(self, filter: Filter, join_person_tables=False) -> list: outer_properties = PropertyOptimizer().parse_property_groups(filter.property_groups).outer query, params = parse_prop_grouped_clauses( team_id=self.team.pk, @@ -1232,7 +1232,7 @@ TEST_BREAKDOWN_PROCESSING = [ @pytest.mark.parametrize("breakdown, table, query_alias, column, expected", TEST_BREAKDOWN_PROCESSING) def test_breakdown_query_expression( clean_up_materialised_columns, - breakdown: Union[str, List[str]], + breakdown: Union[str, list[str]], table: TableWithProperties, query_alias: Literal["prop", "value"], column: str, @@ -1281,7 +1281,7 @@ TEST_BREAKDOWN_PROCESSING_MATERIALIZED = [ ) def test_breakdown_query_expression_materialised( clean_up_materialised_columns, - breakdown: Union[str, List[str]], + breakdown: Union[str, list[str]], table: TableWithProperties, query_alias: Literal["prop", "value"], column: str, @@ -1317,7 +1317,7 @@ def test_breakdown_query_expression_materialised( @pytest.fixture -def test_events(db, team) -> List[UUID]: +def test_events(db, team) -> list[UUID]: return [ _create_event( event="$pageview", @@ -1958,7 +1958,7 @@ def test_combine_group_properties(): ], } - combined_group = PropertyGroup(PropertyOperatorType.AND, cast(List[Property], [])).combine_properties( + combined_group = PropertyGroup(PropertyOperatorType.AND, cast(list[Property], [])).combine_properties( PropertyOperatorType.OR, [propertyC, propertyD] ) assert combined_group.to_dict() == { diff --git a/ee/clickhouse/queries/column_optimizer.py b/ee/clickhouse/queries/column_optimizer.py index dd62154dd20..b1bf142aa3d 100644 --- 
a/ee/clickhouse/queries/column_optimizer.py +++ b/ee/clickhouse/queries/column_optimizer.py @@ -1,5 +1,5 @@ -from typing import Counter as TCounter -from typing import Set, cast +from collections import Counter as TCounter +from typing import cast from posthog.clickhouse.materialized_columns.column import ColumnName from posthog.constants import TREND_FILTER_TYPE_ACTIONS, FunnelCorrelationType @@ -20,16 +20,16 @@ from posthog.queries.trends.util import is_series_group_based class EnterpriseColumnOptimizer(FOSSColumnOptimizer): @cached_property - def group_types_to_query(self) -> Set[GroupTypeIndex]: + def group_types_to_query(self) -> set[GroupTypeIndex]: used_properties = self.used_properties_with_type("group") return {cast(GroupTypeIndex, group_type_index) for _, _, group_type_index in used_properties} @cached_property - def group_on_event_columns_to_query(self) -> Set[ColumnName]: + def group_on_event_columns_to_query(self) -> set[ColumnName]: "Returns a list of event table group columns containing materialized properties that this query needs" used_properties = self.used_properties_with_type("group") - columns_to_query: Set[ColumnName] = set() + columns_to_query: set[ColumnName] = set() for group_type_index in range(5): columns_to_query = columns_to_query.union( @@ -120,7 +120,7 @@ class EnterpriseColumnOptimizer(FOSSColumnOptimizer): counter += get_action_tables_and_properties(entity.get_action()) if ( - not isinstance(self.filter, (StickinessFilter, PropertiesTimelineFilter)) + not isinstance(self.filter, StickinessFilter | PropertiesTimelineFilter) and self.filter.correlation_type == FunnelCorrelationType.PROPERTIES and self.filter.correlation_property_names ): diff --git a/ee/clickhouse/queries/enterprise_cohort_query.py b/ee/clickhouse/queries/enterprise_cohort_query.py index a748a64adf0..814b61e9a8b 100644 --- a/ee/clickhouse/queries/enterprise_cohort_query.py +++ b/ee/clickhouse/queries/enterprise_cohort_query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple, cast +from typing import Any, cast from posthog.constants import PropertyOperatorType from posthog.models.cohort.util import get_count_operator @@ -15,18 +15,18 @@ from posthog.queries.util import PersonPropertiesMode from posthog.schema import PersonsOnEventsMode -def check_negation_clause(prop: PropertyGroup) -> Tuple[bool, bool]: +def check_negation_clause(prop: PropertyGroup) -> tuple[bool, bool]: has_negation_clause = False has_primary_clase = False if len(prop.values): if isinstance(prop.values[0], PropertyGroup): - for p in cast(List[PropertyGroup], prop.values): + for p in cast(list[PropertyGroup], prop.values): has_neg, has_primary = check_negation_clause(p) has_negation_clause = has_negation_clause or has_neg has_primary_clase = has_primary_clase or has_primary else: - for property in cast(List[Property], prop.values): + for property in cast(list[Property], prop.values): if property.negation: has_negation_clause = True else: @@ -42,7 +42,7 @@ def check_negation_clause(prop: PropertyGroup) -> Tuple[bool, bool]: class EnterpriseCohortQuery(FOSSCohortQuery): - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: if not self._outer_property_groups: # everything is pushed down, no behavioral stuff to do # thus, use personQuery directly @@ -87,9 +87,9 @@ class EnterpriseCohortQuery(FOSSCohortQuery): return final_query, self.params - def _get_condition_for_property(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def 
_get_condition_for_property(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: res: str = "" - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if prop.type == "behavioral": if prop.value == "performed_event": @@ -117,7 +117,7 @@ class EnterpriseCohortQuery(FOSSCohortQuery): return res, params - def get_stopped_performing_event(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_stopped_performing_event(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: event = (prop.event_type, prop.key) column_name = f"stopped_event_condition_{prepend}_{idx}" @@ -152,7 +152,7 @@ class EnterpriseCohortQuery(FOSSCohortQuery): }, ) - def get_restarted_performing_event(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_restarted_performing_event(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: event = (prop.event_type, prop.key) column_name = f"restarted_event_condition_{prepend}_{idx}" @@ -191,7 +191,7 @@ class EnterpriseCohortQuery(FOSSCohortQuery): }, ) - def get_performed_event_first_time(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_performed_event_first_time(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: event = (prop.event_type, prop.key) entity_query, entity_params = self._get_entity(event, prepend, idx) @@ -212,7 +212,7 @@ class EnterpriseCohortQuery(FOSSCohortQuery): {f"{date_param}": date_value, **entity_params}, ) - def get_performed_event_regularly(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_performed_event_regularly(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: event = (prop.event_type, prop.key) entity_query, entity_params = self._get_entity(event, prepend, idx) @@ -266,7 +266,7 @@ class EnterpriseCohortQuery(FOSSCohortQuery): ) @cached_property - def sequence_filters_to_query(self) -> List[Property]: + def sequence_filters_to_query(self) -> list[Property]: props = [] for prop in self._filter.property_groups.flat: if prop.value == "performed_event_sequence": @@ -274,13 +274,13 @@ class EnterpriseCohortQuery(FOSSCohortQuery): return props @cached_property - def sequence_filters_lookup(self) -> Dict[str, str]: + def sequence_filters_lookup(self) -> dict[str, str]: lookup = {} for idx, prop in enumerate(self.sequence_filters_to_query): lookup[str(prop.to_dict())] = f"{idx}" return lookup - def _get_sequence_query(self) -> Tuple[str, Dict[str, Any], str]: + def _get_sequence_query(self) -> tuple[str, dict[str, Any], str]: params = {} materialized_columns = list(self._column_optimizer.event_columns_to_query) @@ -356,7 +356,7 @@ class EnterpriseCohortQuery(FOSSCohortQuery): self.FUNNEL_QUERY_ALIAS, ) - def _get_sequence_filter(self, prop: Property, idx: int) -> Tuple[List[str], List[str], List[str], Dict[str, Any]]: + def _get_sequence_filter(self, prop: Property, idx: int) -> tuple[list[str], list[str], list[str], dict[str, Any]]: event = validate_entity((prop.event_type, prop.key)) entity_query, entity_params = self._get_entity(event, f"event_sequence_{self._cohort_pk}", idx) seq_event = validate_entity((prop.seq_event_type, prop.seq_event)) @@ -405,7 +405,7 @@ class EnterpriseCohortQuery(FOSSCohortQuery): }, ) - def get_performed_event_sequence(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_performed_event_sequence(self, prop: Property, prepend: str, idx: 
int) -> tuple[str, dict[str, Any]]: return ( f"{self.SEQUENCE_FIELD_ALIAS}_{self.sequence_filters_lookup[str(prop.to_dict())]}", {}, diff --git a/ee/clickhouse/queries/event_query.py b/ee/clickhouse/queries/event_query.py index b1b4dbb695e..0e16abc7800 100644 --- a/ee/clickhouse/queries/event_query.py +++ b/ee/clickhouse/queries/event_query.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from ee.clickhouse.materialized_columns.columns import ColumnName from ee.clickhouse.queries.column_optimizer import EnterpriseColumnOptimizer @@ -33,9 +33,9 @@ class EnterpriseEventQuery(EventQuery): should_join_distinct_ids=False, should_join_persons=False, # Extra events/person table columns to fetch since parent query needs them - extra_fields: Optional[List[ColumnName]] = None, - extra_event_properties: Optional[List[PropertyName]] = None, - extra_person_fields: Optional[List[ColumnName]] = None, + extra_fields: Optional[list[ColumnName]] = None, + extra_event_properties: Optional[list[PropertyName]] = None, + extra_person_fields: Optional[list[ColumnName]] = None, override_aggregate_users_by_distinct_id: Optional[bool] = None, person_on_events_mode: PersonsOnEventsMode = PersonsOnEventsMode.disabled, **kwargs, @@ -62,7 +62,7 @@ class EnterpriseEventQuery(EventQuery): self._column_optimizer = EnterpriseColumnOptimizer(self._filter, self._team_id) - def _get_groups_query(self) -> Tuple[str, Dict]: + def _get_groups_query(self) -> tuple[str, dict]: if isinstance(self._filter, PropertiesTimelineFilter): raise Exception("Properties Timeline never needs groups query") return GroupsJoinQuery( diff --git a/ee/clickhouse/queries/experiments/funnel_experiment_result.py b/ee/clickhouse/queries/experiments/funnel_experiment_result.py index ab117b07c69..845cce75d50 100644 --- a/ee/clickhouse/queries/experiments/funnel_experiment_result.py +++ b/ee/clickhouse/queries/experiments/funnel_experiment_result.py @@ -1,7 +1,7 @@ from dataclasses import asdict, dataclass from datetime import datetime import json -from typing import List, Optional, Tuple, Type +from typing import Optional from zoneinfo import ZoneInfo from numpy.random import default_rng @@ -56,7 +56,7 @@ class ClickhouseFunnelExperimentResult: feature_flag: FeatureFlag, experiment_start_date: datetime, experiment_end_date: Optional[datetime] = None, - funnel_class: Type[ClickhouseFunnel] = ClickhouseFunnel, + funnel_class: type[ClickhouseFunnel] = ClickhouseFunnel, ): breakdown_key = f"$feature/{feature_flag.key}" self.variants = [variant["key"] for variant in feature_flag.variants] @@ -148,9 +148,9 @@ class ClickhouseFunnelExperimentResult: @staticmethod def calculate_results( control_variant: Variant, - test_variants: List[Variant], - priors: Tuple[int, int] = (1, 1), - ) -> List[Probability]: + test_variants: list[Variant], + priors: tuple[int, int] = (1, 1), + ) -> list[Probability]: """ Calculates probability that A is better than B. First variant is control, rest are test variants. 
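# A minimal illustration of the Bayesian comparison this docstring describes:
# sample each variant's conversion rate from its Beta posterior and count how
# often the test variant wins. It mirrors the Beta(1, 1) prior and
# numpy.random.default_rng used in this file, but prob_test_beats_control and
# its argument shapes are illustrative assumptions, not the production
# implementation.
from numpy.random import default_rng

def prob_test_beats_control(control, test, prior=(1, 1), simulations=100_000):
    rng = default_rng()
    prior_success, prior_failure = prior
    # Posterior for a variant's conversion rate is Beta(successes + a, failures + b)
    control_rates = rng.beta(
        control.success_count + prior_success,
        control.failure_count + prior_failure,
        simulations,
    )
    test_rates = rng.beta(
        test.success_count + prior_success,
        test.failure_count + prior_failure,
        simulations,
    )
    # Estimated P(test > control): fraction of draws where the test variant wins
    return float((test_rates > control_rates).mean())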
@@ -186,9 +186,9 @@ class ClickhouseFunnelExperimentResult: @staticmethod def are_results_significant( control_variant: Variant, - test_variants: List[Variant], - probabilities: List[Probability], - ) -> Tuple[ExperimentSignificanceCode, Probability]: + test_variants: list[Variant], + probabilities: list[Probability], + ) -> tuple[ExperimentSignificanceCode, Probability]: def get_conversion_rate(variant: Variant): return variant.success_count / (variant.success_count + variant.failure_count) @@ -226,7 +226,7 @@ class ClickhouseFunnelExperimentResult: return ExperimentSignificanceCode.SIGNIFICANT, expected_loss -def calculate_expected_loss(target_variant: Variant, variants: List[Variant]) -> float: +def calculate_expected_loss(target_variant: Variant, variants: list[Variant]) -> float: """ Calculates expected loss in conversion rate for a given variant. Loss calculation comes from VWO's SmartStats technical paper: @@ -268,7 +268,7 @@ def calculate_expected_loss(target_variant: Variant, variants: List[Variant]) -> return loss / simulations_count -def simulate_winning_variant_for_conversion(target_variant: Variant, variants: List[Variant]) -> Probability: +def simulate_winning_variant_for_conversion(target_variant: Variant, variants: list[Variant]) -> Probability: random_sampler = default_rng() prior_success = 1 prior_failure = 1 @@ -300,7 +300,7 @@ def simulate_winning_variant_for_conversion(target_variant: Variant, variants: L return winnings / simulations_count -def calculate_probability_of_winning_for_each(variants: List[Variant]) -> List[Probability]: +def calculate_probability_of_winning_for_each(variants: list[Variant]) -> list[Probability]: """ Calculates the probability of winning for each variant. """ diff --git a/ee/clickhouse/queries/experiments/secondary_experiment_result.py b/ee/clickhouse/queries/experiments/secondary_experiment_result.py index 4926d2920af..bd485c43622 100644 --- a/ee/clickhouse/queries/experiments/secondary_experiment_result.py +++ b/ee/clickhouse/queries/experiments/secondary_experiment_result.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Dict, Optional +from typing import Optional from rest_framework.exceptions import ValidationError from ee.clickhouse.queries.experiments.funnel_experiment_result import ClickhouseFunnelExperimentResult @@ -55,7 +55,7 @@ class ClickhouseSecondaryExperimentResult: return {"result": variants, **significance_results} - def get_funnel_conversion_rate_for_variants(self, insight_results) -> Dict[str, float]: + def get_funnel_conversion_rate_for_variants(self, insight_results) -> dict[str, float]: variants = {} for result in insight_results: total = result[0]["count"] @@ -67,7 +67,7 @@ class ClickhouseSecondaryExperimentResult: return variants - def get_trend_count_data_for_variants(self, insight_results) -> Dict[str, float]: + def get_trend_count_data_for_variants(self, insight_results) -> dict[str, float]: # this assumes the Trend insight is Cumulative, unless using count per user variants = {} diff --git a/ee/clickhouse/queries/experiments/test_experiment_result.py b/ee/clickhouse/queries/experiments/test_experiment_result.py index 20b737efa17..18eb673bf9a 100644 --- a/ee/clickhouse/queries/experiments/test_experiment_result.py +++ b/ee/clickhouse/queries/experiments/test_experiment_result.py @@ -1,7 +1,6 @@ import unittest from functools import lru_cache from math import exp, lgamma, log -from typing import List from flaky import flaky @@ -31,7 +30,7 @@ def logbeta(x: int, y: int) -> float: # calculation: 
https://www.evanmiller.org/bayesian-ab-testing.html#binary_ab -def calculate_probability_of_winning_for_target(target_variant: Variant, other_variants: List[Variant]) -> Probability: +def calculate_probability_of_winning_for_target(target_variant: Variant, other_variants: list[Variant]) -> Probability: """ Calculates the probability of winning for target variant. """ @@ -455,7 +454,7 @@ class TestFunnelExperimentCalculator(unittest.TestCase): # calculation: https://www.evanmiller.org/bayesian-ab-testing.html#count_ab def calculate_probability_of_winning_for_target_count_data( - target_variant: CountVariant, other_variants: List[CountVariant] + target_variant: CountVariant, other_variants: list[CountVariant] ) -> Probability: """ Calculates the probability of winning for target variant. diff --git a/ee/clickhouse/queries/experiments/trend_experiment_result.py b/ee/clickhouse/queries/experiments/trend_experiment_result.py index 02974d8bd82..0370e0a684a 100644 --- a/ee/clickhouse/queries/experiments/trend_experiment_result.py +++ b/ee/clickhouse/queries/experiments/trend_experiment_result.py @@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass from datetime import datetime from functools import lru_cache from math import exp, lgamma, log -from typing import List, Optional, Tuple, Type +from typing import Optional from zoneinfo import ZoneInfo from numpy.random import default_rng @@ -78,7 +78,7 @@ class ClickhouseTrendExperimentResult: feature_flag: FeatureFlag, experiment_start_date: datetime, experiment_end_date: Optional[datetime] = None, - trend_class: Type[Trends] = Trends, + trend_class: type[Trends] = Trends, custom_exposure_filter: Optional[Filter] = None, ): breakdown_key = f"$feature/{feature_flag.key}" @@ -316,7 +316,7 @@ class ClickhouseTrendExperimentResult: return control_variant, test_variants @staticmethod - def calculate_results(control_variant: Variant, test_variants: List[Variant]) -> List[Probability]: + def calculate_results(control_variant: Variant, test_variants: list[Variant]) -> list[Probability]: """ Calculates probability that A is better than B. First variant is control, rest are test variants. @@ -346,9 +346,9 @@ class ClickhouseTrendExperimentResult: @staticmethod def are_results_significant( control_variant: Variant, - test_variants: List[Variant], - probabilities: List[Probability], - ) -> Tuple[ExperimentSignificanceCode, Probability]: + test_variants: list[Variant], + probabilities: list[Probability], + ) -> tuple[ExperimentSignificanceCode, Probability]: # TODO: Experiment with Expected Loss calculations for trend experiments for variant in test_variants: @@ -375,7 +375,7 @@ class ClickhouseTrendExperimentResult: return ExperimentSignificanceCode.SIGNIFICANT, p_value -def simulate_winning_variant_for_arrival_rates(target_variant: Variant, variants: List[Variant]) -> float: +def simulate_winning_variant_for_arrival_rates(target_variant: Variant, variants: list[Variant]) -> float: random_sampler = default_rng() simulations_count = 100_000 @@ -399,7 +399,7 @@ def simulate_winning_variant_for_arrival_rates(target_variant: Variant, variants return winnings / simulations_count -def calculate_probability_of_winning_for_each(variants: List[Variant]) -> List[Probability]: +def calculate_probability_of_winning_for_each(variants: list[Variant]) -> list[Probability]: """ Calculates the probability of winning for each variant. 
""" @@ -458,7 +458,7 @@ def poisson_p_value(control_count, control_exposure, test_count, test_exposure): return min(1, 2 * min(low_p_value, high_p_value)) -def calculate_p_value(control_variant: Variant, test_variants: List[Variant]) -> Probability: +def calculate_p_value(control_variant: Variant, test_variants: list[Variant]) -> Probability: best_test_variant = max(test_variants, key=lambda variant: variant.count) return poisson_p_value( diff --git a/ee/clickhouse/queries/experiments/utils.py b/ee/clickhouse/queries/experiments/utils.py index 88418e3e354..c0211e4c9de 100644 --- a/ee/clickhouse/queries/experiments/utils.py +++ b/ee/clickhouse/queries/experiments/utils.py @@ -1,4 +1,4 @@ -from typing import Set, Union +from typing import Union from posthog.client import sync_execute from posthog.constants import TREND_FILTER_TYPE_ACTIONS @@ -20,7 +20,7 @@ def requires_flag_warning(filter: Filter, team: Team) -> bool: {parsed_date_to} """ - events: Set[Union[int, str]] = set() + events: set[Union[int, str]] = set() entities_to_use = filter.entities for entity in entities_to_use: diff --git a/ee/clickhouse/queries/funnels/funnel_correlation.py b/ee/clickhouse/queries/funnels/funnel_correlation.py index ed3995968a0..c25763167f2 100644 --- a/ee/clickhouse/queries/funnels/funnel_correlation.py +++ b/ee/clickhouse/queries/funnels/funnel_correlation.py @@ -2,12 +2,8 @@ import dataclasses import urllib.parse from typing import ( Any, - Dict, - List, Literal, Optional, - Set, - Tuple, TypedDict, Union, cast, @@ -40,7 +36,7 @@ from posthog.utils import generate_short_id class EventDefinition(TypedDict): event: str - properties: Dict[str, Any] + properties: dict[str, Any] elements: list @@ -74,7 +70,7 @@ class FunnelCorrelationResponse(TypedDict): queries, but we could use, for example, a dataclass """ - events: List[EventOddsRatioSerialized] + events: list[EventOddsRatioSerialized] skewed: bool @@ -153,7 +149,7 @@ class FunnelCorrelation: ) @property - def properties_to_include(self) -> List[str]: + def properties_to_include(self) -> list[str]: props_to_include = [] if ( self._team.person_on_events_mode != PersonsOnEventsMode.disabled @@ -203,7 +199,7 @@ class FunnelCorrelation: return True return False - def get_contingency_table_query(self) -> Tuple[str, Dict[str, Any]]: + def get_contingency_table_query(self) -> tuple[str, dict[str, Any]]: """ Returns a query string and params, which are used to generate the contingency table. The query returns success and failure count for event / property values, along with total success and failure counts. 
@@ -216,7 +212,7 @@ class FunnelCorrelation: return self.get_event_query() - def get_event_query(self) -> Tuple[str, Dict[str, Any]]: + def get_event_query(self) -> tuple[str, dict[str, Any]]: funnel_persons_query, funnel_persons_params = self.get_funnel_actors_cte() event_join_query = self._get_events_join_query() @@ -279,7 +275,7 @@ class FunnelCorrelation: return query, params - def get_event_property_query(self) -> Tuple[str, Dict[str, Any]]: + def get_event_property_query(self) -> tuple[str, dict[str, Any]]: if not self._filter.correlation_event_names: raise ValidationError("Event Property Correlation expects at least one event name to run correlation on") @@ -359,7 +355,7 @@ class FunnelCorrelation: return query, params - def get_properties_query(self) -> Tuple[str, Dict[str, Any]]: + def get_properties_query(self) -> tuple[str, dict[str, Any]]: if not self._filter.correlation_property_names: raise ValidationError("Property Correlation expects at least one Property to run correlation on") @@ -580,7 +576,7 @@ class FunnelCorrelation: ) def _get_funnel_step_names(self): - events: Set[Union[int, str]] = set() + events: set[Union[int, str]] = set() for entity in self._filter.entities: if entity.type == TREND_FILTER_TYPE_ACTIONS: action = entity.get_action() @@ -590,7 +586,7 @@ class FunnelCorrelation: return sorted(events) - def _run(self) -> Tuple[List[EventOddsRatio], bool]: + def _run(self) -> tuple[list[EventOddsRatio], bool]: """ Run the diagnose query. @@ -834,7 +830,7 @@ class FunnelCorrelation: ).to_params() return f"{self._base_uri}api/person/funnel/correlation?{urllib.parse.urlencode(params)}&cache_invalidation_key={cache_invalidation_key}" - def format_results(self, results: Tuple[List[EventOddsRatio], bool]) -> FunnelCorrelationResponse: + def format_results(self, results: tuple[list[EventOddsRatio], bool]) -> FunnelCorrelationResponse: odds_ratios, skewed_totals = results return { "events": [self.serialize_event_odds_ratio(odds_ratio=odds_ratio) for odds_ratio in odds_ratios], @@ -847,7 +843,7 @@ class FunnelCorrelation: return self.format_results(self._run()) - def get_partial_event_contingency_tables(self) -> Tuple[List[EventContingencyTable], int, int]: + def get_partial_event_contingency_tables(self) -> tuple[list[EventContingencyTable], int, int]: """ For each event a person that started going through the funnel, gets stats for how many of these users are successful and how many are unsuccessful.
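# Plain-Python sketch of the per-event aggregation this docstring describes
# (the real version is a ClickHouse query; the input shapes below, a list of
# (event, person_id) pairs plus a person_id -> converted flag, are assumed):
from collections import Counter

def partial_contingency_tables(event_person_pairs, converted):
    success: Counter = Counter()
    failure: Counter = Counter()
    # Count each person at most once per event
    for event, person_id in set(event_person_pairs):
        if converted[person_id]:
            success[event] += 1
        else:
            failure[event] += 1
    success_total = sum(1 for flag in converted.values() if flag)
    failure_total = len(converted) - success_total
    return success, failure, success_total, failure_total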
@@ -888,7 +884,7 @@ class FunnelCorrelation: failure_total, ) - def get_funnel_actors_cte(self) -> Tuple[str, Dict[str, Any]]: + def get_funnel_actors_cte(self) -> tuple[str, dict[str, Any]]: extra_fields = ["steps", "final_timestamp", "first_timestamp"] for prop in self.properties_to_include: @@ -975,12 +971,12 @@ def get_entity_odds_ratio(event_contingency_table: EventContingencyTable, prior_ ) -def build_selector(elements: List[Dict[str, Any]]) -> str: +def build_selector(elements: list[dict[str, Any]]) -> str: # build a CSS selector given an "elements_chain" # NOTE: my source of what this should be doing is # https://github.com/PostHog/posthog/blob/cc054930a47fb59940531e99a856add49a348ee5/frontend/src/scenes/events/createActionFromEvent.tsx#L36:L36 # - def element_to_selector(element: Dict[str, Any]) -> str: + def element_to_selector(element: dict[str, Any]) -> str: if attr_id := element.get("attr_id"): return f'[id="{attr_id}"]' diff --git a/ee/clickhouse/queries/funnels/funnel_correlation_persons.py b/ee/clickhouse/queries/funnels/funnel_correlation_persons.py index 6a0cfe36551..b02a8b8e9b6 100644 --- a/ee/clickhouse/queries/funnels/funnel_correlation_persons.py +++ b/ee/clickhouse/queries/funnels/funnel_correlation_persons.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import Optional, Union from django.db.models.query import QuerySet from rest_framework.exceptions import ValidationError @@ -52,9 +52,9 @@ class FunnelCorrelationActors(ActorBaseQuery): def get_actors( self, - ) -> Tuple[ + ) -> tuple[ Union[QuerySet[Person], QuerySet[Group]], - Union[List[SerializedGroup], List[SerializedPerson]], + Union[list[SerializedGroup], list[SerializedPerson]], int, ]: if self._filter.correlation_type == FunnelCorrelationType.PROPERTIES: @@ -167,7 +167,7 @@ class _FunnelPropertyCorrelationActors(ActorBaseQuery): def actor_query( self, limit_actors: Optional[bool] = True, - extra_fields: Optional[List[str]] = None, + extra_fields: Optional[list[str]] = None, ): if not self._filter.correlation_property_values: raise ValidationError("Property Correlation expects at least one Property to get persons for") diff --git a/ee/clickhouse/queries/funnels/test/breakdown_cases.py b/ee/clickhouse/queries/funnels/test/breakdown_cases.py index f4fb2689d87..7a1b2076776 100644 --- a/ee/clickhouse/queries/funnels/test/breakdown_cases.py +++ b/ee/clickhouse/queries/funnels/test/breakdown_cases.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, List +from typing import Any from posthog.constants import INSIGHT_FUNNELS from posthog.models.filters import Filter @@ -51,8 +51,8 @@ def funnel_breakdown_group_test_factory(Funnel, FunnelPerson, _create_event, _cr properties={"industry": "random"}, ) - def _assert_funnel_breakdown_result_is_correct(self, result, steps: List[FunnelStepResult]): - def funnel_result(step: FunnelStepResult, order: int) -> Dict[str, Any]: + def _assert_funnel_breakdown_result_is_correct(self, result, steps: list[FunnelStepResult]): + def funnel_result(step: FunnelStepResult, order: int) -> dict[str, Any]: return { "action_id": step.name if step.type == "events" else step.action_id, "name": step.name, diff --git a/ee/clickhouse/queries/groups_join_query.py b/ee/clickhouse/queries/groups_join_query.py index db1d12a3c6c..7a3dc46daf9 100644 --- a/ee/clickhouse/queries/groups_join_query.py +++ b/ee/clickhouse/queries/groups_join_query.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Union
from ee.clickhouse.queries.column_optimizer import EnterpriseColumnOptimizer from posthog.models import Filter @@ -35,7 +35,7 @@ class GroupsJoinQuery: self._join_key = join_key self._person_on_events_mode = person_on_events_mode - def get_join_query(self) -> Tuple[str, Dict]: + def get_join_query(self) -> tuple[str, dict]: join_queries, params = [], {} if self._person_on_events_mode != PersonsOnEventsMode.disabled and groups_on_events_querying_enabled(): @@ -63,7 +63,7 @@ class GroupsJoinQuery: return "\n".join(join_queries), params - def get_filter_query(self, group_type_index: GroupTypeIndex) -> Tuple[str, Dict]: + def get_filter_query(self, group_type_index: GroupTypeIndex) -> tuple[str, dict]: var = f"group_index_{group_type_index}" params = { "team_id": self._team_id, diff --git a/ee/clickhouse/queries/paths/paths.py b/ee/clickhouse/queries/paths/paths.py index a5b9968da58..f20744ee672 100644 --- a/ee/clickhouse/queries/paths/paths.py +++ b/ee/clickhouse/queries/paths/paths.py @@ -1,5 +1,5 @@ from re import escape -from typing import Dict, Literal, Optional, Tuple, Union, cast +from typing import Literal, Optional, Union, cast from jsonschema import ValidationError @@ -34,8 +34,8 @@ class ClickhousePaths(Paths): ): raise ValidationError("Max Edge weight can't be lower than min edge weight") - def get_edge_weight_clause(self) -> Tuple[str, Dict]: - params: Dict[str, int] = {} + def get_edge_weight_clause(self) -> tuple[str, dict]: + params: dict[str, int] = {} conditions = [] @@ -60,8 +60,8 @@ class ClickhousePaths(Paths): else: return "" - def get_target_clause(self) -> Tuple[str, Dict]: - params: Dict[str, Union[str, None]] = { + def get_target_clause(self) -> tuple[str, dict]: + params: dict[str, Union[str, None]] = { "target_point": None, "secondary_target_point": None, } @@ -152,7 +152,7 @@ class ClickhousePaths(Paths): else: return "arraySlice" - def get_filtered_path_ordering(self) -> Tuple[str, ...]: + def get_filtered_path_ordering(self) -> tuple[str, ...]: fields_to_include = ["filtered_path", "filtered_timings"] + [ f"filtered_{field}s" for field in self.extra_event_fields_and_properties ] diff --git a/ee/clickhouse/queries/related_actors_query.py b/ee/clickhouse/queries/related_actors_query.py index 9c031a3b662..e4cd462ace4 100644 --- a/ee/clickhouse/queries/related_actors_query.py +++ b/ee/clickhouse/queries/related_actors_query.py @@ -1,6 +1,6 @@ from datetime import timedelta from functools import cached_property -from typing import List, Optional, Union +from typing import Optional, Union from django.utils.timezone import now @@ -38,8 +38,8 @@ class RelatedActorsQuery: self.group_type_index = validate_group_type_index("group_type_index", group_type_index) self.id = id - def run(self) -> List[SerializedActor]: - results: List[SerializedActor] = [] + def run(self) -> list[SerializedActor]: + results: list[SerializedActor] = [] results.extend(self._query_related_people()) for group_type_mapping in GroupTypeMapping.objects.filter(team_id=self.team.pk): results.extend(self._query_related_groups(group_type_mapping.group_type_index)) @@ -49,7 +49,7 @@ class RelatedActorsQuery: def is_aggregating_by_groups(self) -> bool: return self.group_type_index is not None - def _query_related_people(self) -> List[SerializedPerson]: + def _query_related_people(self) -> list[SerializedPerson]: if not self.is_aggregating_by_groups: return [] @@ -72,7 +72,7 @@ class RelatedActorsQuery: _, serialized_people = get_people(self.team, person_ids) return serialized_people - def 
_query_related_groups(self, group_type_index: GroupTypeIndex) -> List[SerializedGroup]: + def _query_related_groups(self, group_type_index: GroupTypeIndex) -> list[SerializedGroup]: if group_type_index == self.group_type_index: return [] @@ -102,7 +102,7 @@ class RelatedActorsQuery: _, serialize_groups = get_groups(self.team.pk, group_type_index, group_ids) return serialize_groups - def _take_first(self, rows: List) -> List: + def _take_first(self, rows: list) -> list: return [row[0] for row in rows] @property diff --git a/ee/clickhouse/queries/test/test_paths.py b/ee/clickhouse/queries/test/test_paths.py index fdaf25a043a..69f673e4489 100644 --- a/ee/clickhouse/queries/test/test_paths.py +++ b/ee/clickhouse/queries/test/test_paths.py @@ -1,5 +1,4 @@ from datetime import timedelta -from typing import Tuple from unittest.mock import MagicMock from uuid import UUID @@ -2905,7 +2904,7 @@ class TestClickhousePaths(ClickhouseTestMixin, APIBaseTest): @snapshot_clickhouse_queries def test_properties_queried_using_path_filter(self): - def should_query_list(filter) -> Tuple[bool, bool]: + def should_query_list(filter) -> tuple[bool, bool]: path_query = PathEventQuery(filter, self.team) return (path_query._should_query_url(), path_query._should_query_screen()) diff --git a/ee/clickhouse/views/experiments.py b/ee/clickhouse/views/experiments.py index f50b9921c92..b37d4e4d765 100644 --- a/ee/clickhouse/views/experiments.py +++ b/ee/clickhouse/views/experiments.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Optional +from typing import Any, Optional +from collections.abc import Callable from django.utils.timezone import now from rest_framework import serializers, viewsets diff --git a/ee/clickhouse/views/groups.py b/ee/clickhouse/views/groups.py index e539de4673d..4c67072b11d 100644 --- a/ee/clickhouse/views/groups.py +++ b/ee/clickhouse/views/groups.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, List, cast +from typing import cast from django.db.models import Q from drf_spectacular.types import OpenApiTypes @@ -34,7 +34,7 @@ class ClickhouseGroupsTypesView(TeamAndOrgViewSetMixin, mixins.ListModelMixin, v @action(detail=False, methods=["PATCH"], name="Update group types metadata") def update_metadata(self, request: request.Request, *args, **kwargs): - for row in cast(List[Dict], request.data): + for row in cast(list[dict], request.data): instance = GroupTypeMapping.objects.get(team=self.team, group_type_index=row["group_type_index"]) serializer = self.get_serializer(instance, data=row) serializer.is_valid(raise_exception=True) diff --git a/ee/clickhouse/views/insights.py b/ee/clickhouse/views/insights.py index ff772b71aae..e6adf49e7ff 100644 --- a/ee/clickhouse/views/insights.py +++ b/ee/clickhouse/views/insights.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from rest_framework.decorators import action from rest_framework.permissions import SAFE_METHODS, BasePermission @@ -47,7 +47,7 @@ class ClickhouseInsightsViewSet(InsightViewSet): return Response(result) @cached_by_filters - def calculate_funnel_correlation(self, request: Request) -> Dict[str, Any]: + def calculate_funnel_correlation(self, request: Request) -> dict[str, Any]: team = self.team filter = Filter(request=request, team=team) diff --git a/ee/clickhouse/views/person.py b/ee/clickhouse/views/person.py index d01dba65da9..f3f8432ad68 100644 --- a/ee/clickhouse/views/person.py +++ b/ee/clickhouse/views/person.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple 
+from typing import Optional from rest_framework import request, response from rest_framework.decorators import action @@ -28,7 +28,7 @@ class EnterprisePersonViewSet(PersonViewSet): @cached_by_filters def calculate_funnel_correlation_persons( self, request: request.Request - ) -> Dict[str, Tuple[List, Optional[str], Optional[str], int]]: + ) -> dict[str, tuple[list, Optional[str], Optional[str], int]]: filter = Filter(request=request, data={"insight": INSIGHT_FUNNELS}, team=self.team) if not filter.correlation_person_limit: filter = filter.shallow_clone({FUNNEL_CORRELATION_PERSON_LIMIT: 100}) diff --git a/ee/clickhouse/views/test/funnel/test_clickhouse_funnel_correlation.py b/ee/clickhouse/views/test/funnel/test_clickhouse_funnel_correlation.py index 829232d1bd9..f5ff3722008 100644 --- a/ee/clickhouse/views/test/funnel/test_clickhouse_funnel_correlation.py +++ b/ee/clickhouse/views/test/funnel/test_clickhouse_funnel_correlation.py @@ -552,15 +552,15 @@ class FunnelCorrelationTest(BaseTest): ), ) - (browser_correlation,) = [ + (browser_correlation,) = ( correlation for correlation in odds["result"]["events"] if correlation["event"]["event"] == "$browser::1" - ] + ) - (notset_correlation,) = [ + (notset_correlation,) = ( correlation for correlation in odds["result"]["events"] if correlation["event"]["event"] == "$browser::" - ] + ) assert get_people_for_correlation_ok(client=self.client, correlation=browser_correlation) == { "success": ["Person 2"], diff --git a/ee/clickhouse/views/test/funnel/util.py b/ee/clickhouse/views/test/funnel/util.py index 8d2c304cb8b..45984ee41ba 100644 --- a/ee/clickhouse/views/test/funnel/util.py +++ b/ee/clickhouse/views/test/funnel/util.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Dict, Literal, Optional, TypedDict, Union +from typing import Any, Literal, Optional, TypedDict, Union from django.test.client import Client @@ -12,7 +12,7 @@ class EventPattern(TypedDict, total=False): id: str type: Union[Literal["events"], Literal["actions"]] order: int - properties: Dict[str, Any] + properties: dict[str, Any] @dataclasses.dataclass @@ -46,7 +46,7 @@ def get_funnel(client: Client, team_id: int, request: FunnelRequest): ) -def get_funnel_ok(client: Client, team_id: int, request: FunnelRequest) -> Dict[str, Any]: +def get_funnel_ok(client: Client, team_id: int, request: FunnelRequest) -> dict[str, Any]: response = get_funnel(client=client, team_id=team_id, request=request) assert response.status_code == 200, response.content @@ -73,14 +73,14 @@ def get_funnel_correlation(client: Client, team_id: int, request: FunnelCorrelat ) -def get_funnel_correlation_ok(client: Client, team_id: int, request: FunnelCorrelationRequest) -> Dict[str, Any]: +def get_funnel_correlation_ok(client: Client, team_id: int, request: FunnelCorrelationRequest) -> dict[str, Any]: response = get_funnel_correlation(client=client, team_id=team_id, request=request) assert response.status_code == 200, response.content return response.json() -def get_people_for_correlation_ok(client: Client, correlation: EventOddsRatioSerialized) -> Dict[str, Any]: +def get_people_for_correlation_ok(client: Client, correlation: EventOddsRatioSerialized) -> dict[str, Any]: """ Helper for getting people for a correlation. Note we keep checks to just inclusion of name, to make this stable to changes in other people props.
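# Every hunk above follows the same mechanical rewrite: the pyupgrade (UP)
# rules replace typing.List/Dict/Tuple/Set with the PEP 585 builtin generics,
# while Optional and Union imports stay as-is at this Python target (and
# isinstance(x, (A, B)) becomes isinstance(x, A | B), as in column_optimizer
# above). A self-contained sketch of the modernized style (tally is a
# hypothetical helper, not part of this patch):
from typing import Optional

def tally(events: list[str]) -> tuple[dict[str, int], Optional[str]]:
    # Builtin generics need no typing import on the newer Python target
    counts: dict[str, int] = {}
    for event in events:
        counts[event] = counts.get(event, 0) + 1
    most_common: Optional[str] = max(counts, key=lambda e: counts[e]) if counts else None
    return counts, most_common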
diff --git a/ee/clickhouse/views/test/test_clickhouse_experiment_secondary_results.py b/ee/clickhouse/views/test/test_clickhouse_experiment_secondary_results.py index 232312ec644..e7f9ebf7e2c 100644 --- a/ee/clickhouse/views/test/test_clickhouse_experiment_secondary_results.py +++ b/ee/clickhouse/views/test/test_clickhouse_experiment_secondary_results.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any from flaky import flaky @@ -7,7 +7,7 @@ from posthog.models.signals import mute_selected_signals from posthog.test.base import ClickhouseTestMixin, snapshot_clickhouse_queries from posthog.test.test_journeys import journeys_for -DEFAULT_JOURNEYS_FOR_PAYLOAD: Dict[str, List[Dict[str, Any]]] = { +DEFAULT_JOURNEYS_FOR_PAYLOAD: dict[str, list[dict[str, Any]]] = { # For a trend pageview metric "person1": [ { diff --git a/ee/clickhouse/views/test/test_clickhouse_retention.py b/ee/clickhouse/views/test/test_clickhouse_retention.py index 0e5a8ad0faf..5deff716a26 100644 --- a/ee/clickhouse/views/test/test_clickhouse_retention.py +++ b/ee/clickhouse/views/test/test_clickhouse_retention.py @@ -1,5 +1,5 @@ from dataclasses import asdict, dataclass -from typing import List, Literal, Optional, TypedDict, Union +from typing import Literal, Optional, TypedDict, Union from django.test.client import Client @@ -719,10 +719,10 @@ class RetentionRequest: period: Union[Literal["Hour"], Literal["Day"], Literal["Week"], Literal["Month"]] retention_type: Literal["retention_first_time", "retention"] # probably not an exhaustive list - breakdowns: Optional[List[Breakdown]] = None + breakdowns: Optional[list[Breakdown]] = None breakdown_type: Optional[Literal["person", "event"]] = None - properties: Optional[List[PropertyFilter]] = None + properties: Optional[list[PropertyFilter]] = None filter_test_accounts: Optional[str] = None limit: Optional[int] = None @@ -734,26 +734,26 @@ class Value(TypedDict): class Cohort(TypedDict): - values: List[Value] + values: list[Value] date: str label: str class RetentionResponse(TypedDict): - result: List[Cohort] + result: list[Cohort] class Person(TypedDict): - distinct_ids: List[str] + distinct_ids: list[str] class RetentionTableAppearance(TypedDict): person: Person - appearances: List[int] + appearances: list[int] class RetentionTablePeopleResponse(TypedDict): - result: List[RetentionTableAppearance] + result: list[RetentionTableAppearance] def get_retention_ok(client: Client, team_id: int, request: RetentionRequest) -> RetentionResponse: diff --git a/ee/clickhouse/views/test/test_clickhouse_trends.py b/ee/clickhouse/views/test/test_clickhouse_trends.py index 8ce3809263a..4de1f00e534 100644 --- a/ee/clickhouse/views/test/test_clickhouse_trends.py +++ b/ee/clickhouse/views/test/test_clickhouse_trends.py @@ -1,7 +1,7 @@ import json from dataclasses import dataclass, field from datetime import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from unittest.case import skip from unittest.mock import ANY @@ -420,20 +420,20 @@ class TrendsRequest: insight: Optional[str] = None display: Optional[str] = None compare: Optional[bool] = None - events: List[Dict[str, Any]] = field(default_factory=list) - properties: List[Dict[str, Any]] = field(default_factory=list) + events: list[dict[str, Any]] = field(default_factory=list) + properties: list[dict[str, Any]] = field(default_factory=list) smoothing_intervals: Optional[int] = 1 refresh: Optional[bool] = False @dataclass class TrendsRequestBreakdown(TrendsRequest): 
- breakdown: Optional[Union[List[int], str]] = None + breakdown: Optional[Union[list[int], str]] = None breakdown_type: Optional[str] = None def get_trends(client, request: Union[TrendsRequestBreakdown, TrendsRequest], team: Team): - data: Dict[str, Any] = { + data: dict[str, Any] = { "date_from": request.date_from, "date_to": request.date_to, "interval": request.interval, @@ -471,7 +471,7 @@ class NormalizedTrendResult: def get_trends_time_series_ok( client: Client, request: TrendsRequest, team: Team, with_order: bool = False -) -> Dict[str, Dict[str, NormalizedTrendResult]]: +) -> dict[str, dict[str, NormalizedTrendResult]]: data = get_trends_ok(client=client, request=request, team=team) res = {} for item in data["result"]: @@ -491,7 +491,7 @@ def get_trends_time_series_ok( return res -def get_trends_aggregate_ok(client: Client, request: TrendsRequest, team: Team) -> Dict[str, NormalizedTrendResult]: +def get_trends_aggregate_ok(client: Client, request: TrendsRequest, team: Team) -> dict[str, NormalizedTrendResult]: data = get_trends_ok(client=client, request=request, team=team) res = {} for item in data["result"]: diff --git a/ee/migrations/0001_initial.py b/ee/migrations/0001_initial.py index fd3cad38927..5b668bc772b 100644 --- a/ee/migrations/0001_initial.py +++ b/ee/migrations/0001_initial.py @@ -1,6 +1,5 @@ # Generated by Django 3.0.7 on 2020-08-07 09:15 -from typing import List from django.db import migrations, models @@ -8,7 +7,7 @@ from django.db import migrations, models class Migration(migrations.Migration): initial = True - dependencies: List = [] + dependencies: list = [] operations = [ migrations.CreateModel( diff --git a/ee/migrations/0012_migrate_tags_v2.py b/ee/migrations/0012_migrate_tags_v2.py index 9a2cf8e3d39..540cd281338 100644 --- a/ee/migrations/0012_migrate_tags_v2.py +++ b/ee/migrations/0012_migrate_tags_v2.py @@ -1,5 +1,5 @@ # Generated by Django 3.2.5 on 2022-03-02 22:44 -from typing import Any, List, Tuple +from typing import Any from django.core.paginator import Paginator from django.db import migrations @@ -19,7 +19,7 @@ def forwards(apps, schema_editor): EnterpriseEventDefinition = apps.get_model("ee", "EnterpriseEventDefinition") EnterprisePropertyDefinition = apps.get_model("ee", "EnterprisePropertyDefinition") - createables: List[Tuple[Any, Any]] = [] + createables: list[tuple[Any, Any]] = [] batch_size = 1_000 # Collect event definition tags and taggeditems diff --git a/ee/models/license.py b/ee/models/license.py index f0e12d3d2f4..35530b89687 100644 --- a/ee/models/license.py +++ b/ee/models/license.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from django.contrib.auth import get_user_model from django.db import models @@ -85,7 +85,7 @@ class License(models.Model): PLAN_TO_SORTING_VALUE = {SCALE_PLAN: 10, ENTERPRISE_PLAN: 20} @property - def available_features(self) -> List[AvailableFeature]: + def available_features(self) -> list[AvailableFeature]: return self.PLANS.get(self.plan, []) @property diff --git a/ee/session_recordings/ai/embeddings_queries.py b/ee/session_recordings/ai/embeddings_queries.py index 6d657d11109..2034a9f1901 100644 --- a/ee/session_recordings/ai/embeddings_queries.py +++ b/ee/session_recordings/ai/embeddings_queries.py @@ -1,6 +1,5 @@ from django.conf import settings -from typing import List from posthog.models import Team from posthog.clickhouse.client import sync_execute @@ -9,7 +8,7 @@ BATCH_FLUSH_SIZE = settings.REPLAY_EMBEDDINGS_BATCH_SIZE MIN_DURATION_INCLUDE_SECONDS = 
settings.REPLAY_EMBEDDINGS_MIN_DURATION_SECONDS -def fetch_errors_by_session_without_embeddings(team_id: int, offset=0) -> List[str]: +def fetch_errors_by_session_without_embeddings(team_id: int, offset=0) -> list[str]: query = """ WITH embedded_sessions AS ( SELECT @@ -47,7 +46,7 @@ def fetch_errors_by_session_without_embeddings(team_id: int, offset=0) -> List[s ) -def fetch_recordings_without_embeddings(team_id: int, offset=0) -> List[str]: +def fetch_recordings_without_embeddings(team_id: int, offset=0) -> list[str]: team = Team.objects.get(id=team_id) query = """ diff --git a/ee/session_recordings/ai/embeddings_runner.py b/ee/session_recordings/ai/embeddings_runner.py index 101c7175acb..413e9f45368 100644 --- a/ee/session_recordings/ai/embeddings_runner.py +++ b/ee/session_recordings/ai/embeddings_runner.py @@ -3,7 +3,7 @@ import tiktoken import datetime import pytz -from typing import Dict, Any, List, Tuple +from typing import Any from abc import ABC, abstractmethod from prometheus_client import Histogram, Counter @@ -88,7 +88,7 @@ class EmbeddingPreparation(ABC): @staticmethod @abstractmethod - def prepare(item, team) -> Tuple[str, str]: + def prepare(item, team) -> tuple[str, str]: raise NotImplementedError() @@ -100,7 +100,7 @@ class SessionEmbeddingsRunner(ABC): self.team = team self.openai_client = OpenAI() - def run(self, items: List[Any], embeddings_preparation: type[EmbeddingPreparation]) -> None: + def run(self, items: list[Any], embeddings_preparation: type[EmbeddingPreparation]) -> None: source_type = embeddings_preparation.source_type try: @@ -196,7 +196,7 @@ class SessionEmbeddingsRunner(ABC): """Returns the number of tokens in a text string.""" return len(encoding.encode(string)) - def _flush_embeddings_to_clickhouse(self, embeddings: List[Dict[str, Any]], source_type: str) -> None: + def _flush_embeddings_to_clickhouse(self, embeddings: list[dict[str, Any]], source_type: str) -> None: try: sync_execute( "INSERT INTO session_replay_embeddings (session_id, team_id, embeddings, source_type, input) VALUES", @@ -213,7 +213,7 @@ class ErrorEmbeddingsPreparation(EmbeddingPreparation): source_type = "error" @staticmethod - def prepare(item: Tuple[str, str], _): + def prepare(item: tuple[str, str], _): session_id = item[0] error_message = item[1] return session_id, error_message @@ -286,7 +286,7 @@ class SessionEventsEmbeddingsPreparation(EmbeddingPreparation): return session_id, input @staticmethod - def _compact_result(event_name: str, current_url: int, elements_chain: Dict[str, str] | str) -> str: + def _compact_result(event_name: str, current_url: int, elements_chain: dict[str, str] | str) -> str: elements_string = ( elements_chain if isinstance(elements_chain, str) else ", ".join(str(e) for e in elements_chain) ) diff --git a/ee/session_recordings/ai/utils.py b/ee/session_recordings/ai/utils.py index a1d5f31460d..1b7770a1361 100644 --- a/ee/session_recordings/ai/utils.py +++ b/ee/session_recordings/ai/utils.py @@ -1,7 +1,7 @@ import dataclasses from datetime import datetime -from typing import List, Dict, Any +from typing import Any from posthog.models.element import chain_to_elements from hashlib import shake_256 @@ -12,11 +12,11 @@ class SessionSummaryPromptData: # we may allow customisation of columns included in the future, # and we alter the columns present as we process the data # so want to stay as loose as possible here - columns: List[str] = dataclasses.field(default_factory=list) - results: List[List[Any]] = dataclasses.field(default_factory=list) + columns: 
list[str] = dataclasses.field(default_factory=list) + results: list[list[Any]] = dataclasses.field(default_factory=list) # in order to reduce the number of tokens in the prompt # we replace URLs with a placeholder and then pass this mapping of placeholder to URL into the prompt - url_mapping: Dict[str, str] = dataclasses.field(default_factory=dict) + url_mapping: dict[str, str] = dataclasses.field(default_factory=dict) def is_empty(self) -> bool: return not self.columns or not self.results @@ -63,7 +63,7 @@ def simplify_window_id(session_events: SessionSummaryPromptData) -> SessionSumma # find window_id column index window_id_index = session_events.column_index("$window_id") - window_id_mapping: Dict[str, int] = {} + window_id_mapping: dict[str, int] = {} simplified_results = [] for result in session_events.results: if window_id_index is None: @@ -128,7 +128,7 @@ def deduplicate_urls(session_events: SessionSummaryPromptData) -> SessionSummary # find url column index url_index = session_events.column_index("$current_url") - url_mapping: Dict[str, str] = {} + url_mapping: dict[str, str] = {} deduplicated_results = [] for result in session_events.results: if url_index is None: diff --git a/ee/session_recordings/queries/test/test_session_recording_list_from_session_replay.py b/ee/session_recordings/queries/test/test_session_recording_list_from_session_replay.py index 71196ec0eca..797ac453e69 100644 --- a/ee/session_recordings/queries/test/test_session_recording_list_from_session_replay.py +++ b/ee/session_recordings/queries/test/test_session_recording_list_from_session_replay.py @@ -1,5 +1,4 @@ from itertools import product -from typing import Dict from unittest import mock from uuid import uuid4 @@ -131,7 +130,7 @@ class TestClickhouseSessionRecordingsListFromSessionReplay(ClickhouseTestMixin, poe_v2: bool, allow_denormalized_props: bool, expected_poe_mode: PersonsOnEventsMode, - expected_query_params: Dict, + expected_query_params: dict, unmaterialized_person_column_used: bool, materialized_event_column_used: bool, ) -> None: diff --git a/ee/session_recordings/session_recording_playlist.py b/ee/session_recordings/session_recording_playlist.py index a54f8e38a6b..7d2b9fe0b0c 100644 --- a/ee/session_recordings/session_recording_playlist.py +++ b/ee/session_recordings/session_recording_playlist.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Optional +from typing import Any, Optional import structlog from django.db.models import Q, QuerySet @@ -49,7 +49,7 @@ def log_playlist_activity( team_id: int, user: User, was_impersonated: bool, - changes: Optional[List[Change]] = None, + changes: Optional[list[Change]] = None, ) -> None: """ Insight id and short_id are passed separately as some activities (like delete) alter the Insight instance @@ -101,7 +101,7 @@ class SessionRecordingPlaylistSerializer(serializers.ModelSerializer): created_by = UserBasicSerializer(read_only=True) last_modified_by = UserBasicSerializer(read_only=True) - def create(self, validated_data: Dict, *args, **kwargs) -> SessionRecordingPlaylist: + def create(self, validated_data: dict, *args, **kwargs) -> SessionRecordingPlaylist: request = self.context["request"] team = self.context["get_team"]() @@ -128,7 +128,7 @@ class SessionRecordingPlaylistSerializer(serializers.ModelSerializer): return playlist - def update(self, instance: SessionRecordingPlaylist, validated_data: Dict, **kwargs) -> SessionRecordingPlaylist: + def update(self, instance: SessionRecordingPlaylist, validated_data: dict, **kwargs) -> 
SessionRecordingPlaylist: try: before_update = SessionRecordingPlaylist.objects.get(pk=instance.id) except SessionRecordingPlaylist.DoesNotExist: diff --git a/ee/session_recordings/test/test_session_recording_extensions.py b/ee/session_recordings/test/test_session_recording_extensions.py index 35fd5d2bc8b..ad545e5cec3 100644 --- a/ee/session_recordings/test/test_session_recording_extensions.py +++ b/ee/session_recordings/test/test_session_recording_extensions.py @@ -103,7 +103,7 @@ class TestSessionRecordingExtensions(ClickhouseTestMixin, APIBaseTest): for file in ["a", "b", "c"]: blob_path = f"{TEST_BUCKET}/team_id/{self.team.pk}/session_id/{session_id}/data" file_name = f"{blob_path}/{file}" - write(file_name, f"my content-{file}".encode("utf-8")) + write(file_name, f"my content-{file}".encode()) recording: SessionRecording = SessionRecording.objects.create(team=self.team, session_id=session_id) @@ -164,7 +164,7 @@ class TestSessionRecordingExtensions(ClickhouseTestMixin, APIBaseTest): mock_write.assert_called_with( f"{expected_path}/12345000-12346000", - gzip.compress("the new content".encode("utf-8")), + gzip.compress(b"the new content"), extras={ "ContentEncoding": "gzip", "ContentType": "application/json", diff --git a/ee/settings.py b/ee/settings.py index 7342bdf98f9..d9a863c3f81 100644 --- a/ee/settings.py +++ b/ee/settings.py @@ -3,14 +3,13 @@ Django settings for PostHog Enterprise Edition. """ import os -from typing import Dict, List from posthog.settings import AUTHENTICATION_BACKENDS, DEMO, SITE_URL, DEBUG from posthog.settings.utils import get_from_env from posthog.utils import str_to_bool # Zapier REST hooks -HOOK_EVENTS: Dict[str, str] = { +HOOK_EVENTS: dict[str, str] = { # "event_name": "App.Model.Action" (created/updated/deleted) "action_performed": "posthog.Action.performed", } @@ -43,7 +42,7 @@ SOCIAL_AUTH_SAML_SUPPORT_CONTACT = SOCIAL_AUTH_SAML_TECHNICAL_CONTACT SOCIAL_AUTH_GOOGLE_OAUTH2_KEY = os.getenv("SOCIAL_AUTH_GOOGLE_OAUTH2_KEY") SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET = os.getenv("SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET") if "SOCIAL_AUTH_GOOGLE_OAUTH2_WHITELISTED_DOMAINS" in os.environ: - SOCIAL_AUTH_GOOGLE_OAUTH2_WHITELISTED_DOMAINS: List[str] = os.environ[ + SOCIAL_AUTH_GOOGLE_OAUTH2_WHITELISTED_DOMAINS: list[str] = os.environ[ "SOCIAL_AUTH_GOOGLE_OAUTH2_WHITELISTED_DOMAINS" ].split(",") elif DEMO: diff --git a/ee/tasks/auto_rollback_feature_flag.py b/ee/tasks/auto_rollback_feature_flag.py index d1b7e606976..f676f91d0c4 100644 --- a/ee/tasks/auto_rollback_feature_flag.py +++ b/ee/tasks/auto_rollback_feature_flag.py @@ -1,5 +1,4 @@ from datetime import datetime, timedelta -from typing import Dict from zoneinfo import ZoneInfo from celery import shared_task @@ -30,7 +29,7 @@ def check_feature_flag_rollback_conditions(feature_flag_id: int) -> None: flag.save() -def calculate_rolling_average(threshold_metric: Dict, team: Team, timezone: str) -> float: +def calculate_rolling_average(threshold_metric: dict, team: Team, timezone: str) -> float: curr = datetime.now(tz=ZoneInfo(timezone)) rolling_average_days = 7 @@ -54,7 +53,7 @@ def calculate_rolling_average(threshold_metric: Dict, team: Team, timezone: str) return sum(data) / rolling_average_days -def check_condition(rollback_condition: Dict, feature_flag: FeatureFlag) -> bool: +def check_condition(rollback_condition: dict, feature_flag: FeatureFlag) -> bool: if rollback_condition["threshold_type"] == "sentry": created_date = feature_flag.created_at base_start_date = created_date.strftime("%Y-%m-%dT%H:%M:%S") diff --git 
a/ee/tasks/replay.py b/ee/tasks/replay.py index 036925b279a..fcf57196c2d 100644 --- a/ee/tasks/replay.py +++ b/ee/tasks/replay.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any import structlog from celery import shared_task @@ -25,7 +25,7 @@ logger = structlog.get_logger(__name__) # we currently are allowed 500 calls per minute, so let's rate limit each worker # to much less than that @shared_task(ignore_result=False, queue=CeleryQueue.SESSION_REPLAY_EMBEDDINGS.value, rate_limit="75/m") -def embed_batch_of_recordings_task(recordings: List[Any], team_id: int) -> None: +def embed_batch_of_recordings_task(recordings: list[Any], team_id: int) -> None: try: team = Team.objects.get(id=team_id) runner = SessionEmbeddingsRunner(team=team) diff --git a/ee/tasks/slack.py b/ee/tasks/slack.py index 0137089b08b..251e9fd2613 100644 --- a/ee/tasks/slack.py +++ b/ee/tasks/slack.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict +from typing import Any from urllib.parse import urlparse import structlog @@ -16,7 +16,7 @@ logger = structlog.get_logger(__name__) SHARED_LINK_REGEX = r"\/(?:shared_dashboard|shared|embedded)\/(.+)" -def _block_for_asset(asset: ExportedAsset) -> Dict: +def _block_for_asset(asset: ExportedAsset) -> dict: image_url = asset.get_public_content_url() alt_text = None if asset.insight: diff --git a/ee/tasks/subscriptions/email_subscriptions.py b/ee/tasks/subscriptions/email_subscriptions.py index aa62b7d83a4..39e342bcec1 100644 --- a/ee/tasks/subscriptions/email_subscriptions.py +++ b/ee/tasks/subscriptions/email_subscriptions.py @@ -1,5 +1,5 @@ import uuid -from typing import List, Optional +from typing import Optional import structlog @@ -15,7 +15,7 @@ logger = structlog.get_logger(__name__) def send_email_subscription_report( email: str, subscription: Subscription, - assets: List[ExportedAsset], + assets: list[ExportedAsset], invite_message: Optional[str] = None, total_asset_count: Optional[int] = None, ) -> None: diff --git a/ee/tasks/subscriptions/slack_subscriptions.py b/ee/tasks/subscriptions/slack_subscriptions.py index 1d35259a6f3..73643c7a97b 100644 --- a/ee/tasks/subscriptions/slack_subscriptions.py +++ b/ee/tasks/subscriptions/slack_subscriptions.py @@ -1,5 +1,3 @@ -from typing import Dict, List - import structlog from django.conf import settings @@ -12,7 +10,7 @@ logger = structlog.get_logger(__name__) UTM_TAGS_BASE = "utm_source=posthog&utm_campaign=subscription_report" -def _block_for_asset(asset: ExportedAsset) -> Dict: +def _block_for_asset(asset: ExportedAsset) -> dict: image_url = asset.get_public_content_url() alt_text = None if asset.insight: @@ -26,7 +24,7 @@ def _block_for_asset(asset: ExportedAsset) -> Dict: def send_slack_subscription_report( subscription: Subscription, - assets: List[ExportedAsset], + assets: list[ExportedAsset], total_asset_count: int, is_new_subscription: bool = False, ) -> None: diff --git a/ee/tasks/subscriptions/subscription_utils.py b/ee/tasks/subscriptions/subscription_utils.py index d89d73d4a3b..6fa4b63960f 100644 --- a/ee/tasks/subscriptions/subscription_utils.py +++ b/ee/tasks/subscriptions/subscription_utils.py @@ -1,5 +1,5 @@ import datetime -from typing import List, Tuple, Union +from typing import Union from django.conf import settings import structlog from celery import chain @@ -28,7 +28,7 @@ SUBSCRIPTION_ASSET_GENERATION_TIMER = Histogram( def generate_assets( resource: Union[Subscription, SharingConfiguration], max_asset_count: int = DEFAULT_MAX_ASSET_COUNT, -) -> Tuple[List[Insight], 
List[ExportedAsset]]: +) -> tuple[list[Insight], list[ExportedAsset]]: with SUBSCRIPTION_ASSET_GENERATION_TIMER.time(): if resource.dashboard: tiles = get_tiles_ordered_by_position(resource.dashboard) diff --git a/ee/tasks/test/subscriptions/test_subscriptions.py b/ee/tasks/test/subscriptions/test_subscriptions.py index d6afe50b68f..c814b2a4ebc 100644 --- a/ee/tasks/test/subscriptions/test_subscriptions.py +++ b/ee/tasks/test/subscriptions/test_subscriptions.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import List from unittest.mock import MagicMock, call, patch from zoneinfo import ZoneInfo @@ -25,10 +24,10 @@ from posthog.test.base import APIBaseTest @patch("ee.tasks.subscriptions.generate_assets") @freeze_time("2022-02-02T08:55:00.000Z") class TestSubscriptionsTasks(APIBaseTest): - subscriptions: List[Subscription] = None # type: ignore + subscriptions: list[Subscription] = None # type: ignore dashboard: Dashboard insight: Insight - tiles: List[DashboardTile] = None # type: ignore + tiles: list[DashboardTile] = None # type: ignore asset: ExportedAsset def setUp(self) -> None: diff --git a/ee/tasks/test/subscriptions/test_subscriptions_utils.py b/ee/tasks/test/subscriptions/test_subscriptions_utils.py index c8ff89adcea..edab23bbfb9 100644 --- a/ee/tasks/test/subscriptions/test_subscriptions_utils.py +++ b/ee/tasks/test/subscriptions/test_subscriptions_utils.py @@ -1,4 +1,3 @@ -from typing import List from unittest.mock import MagicMock, patch import pytest @@ -21,7 +20,7 @@ class TestSubscriptionsTasksUtils(APIBaseTest): dashboard: Dashboard insight: Insight asset: ExportedAsset - tiles: List[DashboardTile] + tiles: list[DashboardTile] def setUp(self) -> None: self.dashboard = Dashboard.objects.create(team=self.team, name="private dashboard", created_by=self.user) diff --git a/ee/tasks/test/test_slack.py b/ee/tasks/test/test_slack.py index 03b28b8155c..64b227d7d1e 100644 --- a/ee/tasks/test/test_slack.py +++ b/ee/tasks/test/test_slack.py @@ -1,4 +1,3 @@ -from typing import List from unittest.mock import MagicMock, patch from freezegun import freeze_time @@ -14,7 +13,7 @@ from posthog.models.subscription import Subscription from posthog.test.base import APIBaseTest -def create_mock_unfurl_event(team_id: str, links: List[str]): +def create_mock_unfurl_event(team_id: str, links: list[str]): return { "token": "XXYYZZ", "team_id": team_id, diff --git a/ee/urls.py b/ee/urls.py index a3851a28075..2ee3f7d3a8f 100644 --- a/ee/urls.py +++ b/ee/urls.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any from django.conf import settings from django.contrib import admin @@ -92,7 +92,7 @@ admin_urlpatterns = ( ) -urlpatterns: List[Any] = [ +urlpatterns: list[Any] = [ path("api/saml/metadata/", authentication.saml_metadata_view), path("api/sentry_stats/", sentry_stats.sentry_stats), *admin_urlpatterns, diff --git a/gunicorn.config.py b/gunicorn.config.py index 1e561820260..acd7ba3f5f5 100644 --- a/gunicorn.config.py +++ b/gunicorn.config.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- import logging import os diff --git a/hogvm/python/execute.py b/hogvm/python/execute.py index 4e4a61a1af5..a1130c0d54c 100644 --- a/hogvm/python/execute.py +++ b/hogvm/python/execute.py @@ -1,5 +1,5 @@ import re -from typing import List, Any, Dict +from typing import Any from hogvm.python.operation import Operation, HOGQL_BYTECODE_IDENTIFIER @@ -33,7 +33,7 @@ def to_concat_arg(arg) -> str: return str(arg) -def execute_bytecode(bytecode: List[Any], fields: Dict[str, 
Any]) -> Any: +def execute_bytecode(bytecode: list[Any], fields: dict[str, Any]) -> Any: try: stack = [] iterator = iter(bytecode) diff --git a/plugin-server/bin/generate_session_recordings_messages.py b/plugin-server/bin/generate_session_recordings_messages.py index 4b5462bebd3..cfd3d034d19 100755 --- a/plugin-server/bin/generate_session_recordings_messages.py +++ b/plugin-server/bin/generate_session_recordings_messages.py @@ -53,7 +53,6 @@ import argparse import json import uuid from sys import stderr, stdout -from typing import List import numpy from faker import Faker @@ -144,7 +143,7 @@ def get_parser(): def chunked( data: str, chunk_size: int, -) -> List[str]: +) -> list[str]: return [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)] diff --git a/posthog/api/action.py b/posthog/api/action.py index 437f0227c81..38eb33d1074 100644 --- a/posthog/api/action.py +++ b/posthog/api/action.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, cast +from typing import Any, cast from django.db.models import Count, Prefetch from rest_framework import request, serializers, viewsets @@ -123,7 +123,7 @@ class ActionSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedModelSe return instance - def update(self, instance: Any, validated_data: Dict[str, Any]) -> Any: + def update(self, instance: Any, validated_data: dict[str, Any]) -> Any: steps = validated_data.pop("steps", None) # If there's no steps property at all we just ignore it # If there is a step property but it's an empty array [], we'll delete all the steps @@ -182,7 +182,7 @@ class ActionViewSet( def list(self, request: request.Request, *args: Any, **kwargs: Any) -> Response: actions = self.get_queryset() - actions_list: List[Dict[Any, Any]] = self.serializer_class( + actions_list: list[dict[Any, Any]] = self.serializer_class( actions, many=True, context={"request": request} ).data # type: ignore return Response({"results": actions_list}) diff --git a/posthog/api/activity_log.py b/posthog/api/activity_log.py index fefa2554d19..35ff30d5703 100644 --- a/posthog/api/activity_log.py +++ b/posthog/api/activity_log.py @@ -1,5 +1,5 @@ import time -from typing import Any, Optional, Dict +from typing import Any, Optional from django.db.models import Q, QuerySet @@ -49,7 +49,7 @@ class ActivityLogPagination(pagination.CursorPagination): # context manager for gathering a sequence of server timings class ServerTimingsGathered: # Class level dictionary to store timings - timings_dict: Dict[str, float] = {} + timings_dict: dict[str, float] = {} def __call__(self, name): self.name = name diff --git a/posthog/api/annotation.py b/posthog/api/annotation.py index 4806d5a632f..7216efe6cd6 100644 --- a/posthog/api/annotation.py +++ b/posthog/api/annotation.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from django.db.models import Q, QuerySet from django.db.models.signals import post_save @@ -40,11 +40,11 @@ class AnnotationSerializer(serializers.ModelSerializer): "updated_at", ] - def update(self, instance: Annotation, validated_data: Dict[str, Any]) -> Annotation: + def update(self, instance: Annotation, validated_data: dict[str, Any]) -> Annotation: instance.team_id = self.context["team_id"] return super().update(instance, validated_data) - def create(self, validated_data: Dict[str, Any], *args: Any, **kwargs: Any) -> Annotation: + def create(self, validated_data: dict[str, Any], *args: Any, **kwargs: Any) -> Annotation: request = self.context["request"] team = self.context["get_team"]() annotation = 
Annotation.objects.create( diff --git a/posthog/api/authentication.py b/posthog/api/authentication.py index 069acac50c9..d7911059506 100644 --- a/posthog/api/authentication.py +++ b/posthog/api/authentication.py @@ -1,6 +1,6 @@ import datetime import time -from typing import Any, Dict, Optional, cast +from typing import Any, Optional, cast from uuid import uuid4 from django.conf import settings @@ -92,7 +92,7 @@ class LoginSerializer(serializers.Serializer): email = serializers.EmailField() password = serializers.CharField() - def to_representation(self, instance: Any) -> Dict[str, Any]: + def to_representation(self, instance: Any) -> dict[str, Any]: return {"success": True} def _check_if_2fa_required(self, user: User) -> bool: @@ -113,7 +113,7 @@ class LoginSerializer(serializers.Serializer): pass return True - def create(self, validated_data: Dict[str, str]) -> Any: + def create(self, validated_data: dict[str, str]) -> Any: # Check SSO enforcement (which happens at the domain level) sso_enforcement = OrganizationDomain.objects.get_sso_enforcement_for_email_address(validated_data["email"]) if sso_enforcement: @@ -159,10 +159,10 @@ class LoginSerializer(serializers.Serializer): class LoginPrecheckSerializer(serializers.Serializer): email = serializers.EmailField() - def to_representation(self, instance: Dict[str, str]) -> Dict[str, Any]: + def to_representation(self, instance: dict[str, str]) -> dict[str, Any]: return instance - def create(self, validated_data: Dict[str, str]) -> Any: + def create(self, validated_data: dict[str, str]) -> Any: email = validated_data.get("email", "") # TODO: Refactor methods below to remove duplicate queries return { diff --git a/posthog/api/capture.py b/posthog/api/capture.py index 31592e90e79..9c223f8264a 100644 --- a/posthog/api/capture.py +++ b/posthog/api/capture.py @@ -18,7 +18,8 @@ from sentry_sdk import configure_scope from sentry_sdk.api import capture_exception, start_span from statshog.defaults.django import statsd from token_bucket import Limiter, MemoryStorage -from typing import Any, Dict, Iterator, List, Optional, Tuple, Set +from typing import Any, Optional +from collections.abc import Iterator from ee.billing.quota_limiting import QuotaLimitingCaches from posthog.api.utils import get_data, get_token, safe_clickhouse_string @@ -129,12 +130,12 @@ def build_kafka_event_data( distinct_id: str, ip: Optional[str], site_url: str, - data: Dict, + data: dict, now: datetime, sent_at: Optional[datetime], event_uuid: UUIDT, token: str, -) -> Dict: +) -> dict: logger.debug("build_kafka_event_data", token=token) return { "uuid": str(event_uuid), @@ -168,10 +169,10 @@ def _kafka_topic(event_name: str, historical: bool = False, overflowing: bool = def log_event( - data: Dict, + data: dict, event_name: str, partition_key: Optional[str], - headers: Optional[List] = None, + headers: Optional[list] = None, historical: bool = False, overflowing: bool = False, ) -> FutureRecordMetadata: @@ -205,7 +206,7 @@ def _datetime_from_seconds_or_millis(timestamp: str) -> datetime: return datetime.fromtimestamp(timestamp_number, timezone.utc) -def _get_sent_at(data, request) -> Tuple[Optional[datetime], Any]: +def _get_sent_at(data, request) -> tuple[Optional[datetime], Any]: try: if request.GET.get("_"): # posthog-js sent_at = request.GET["_"] @@ -253,7 +254,7 @@ def _check_token_shape(token: Any) -> Optional[str]: return None -def get_distinct_id(data: Dict[str, Any]) -> str: +def get_distinct_id(data: dict[str, Any]) -> str: raw_value: Any = "" try: raw_value = 
data["$distinct_id"] @@ -274,12 +275,12 @@ def get_distinct_id(data: Dict[str, Any]) -> str: return str(raw_value)[0:200] -def drop_performance_events(events: List[Any]) -> List[Any]: +def drop_performance_events(events: list[Any]) -> list[Any]: cleaned_list = [event for event in events if event.get("event") != "$performance_event"] return cleaned_list -def drop_events_over_quota(token: str, events: List[Any]) -> List[Any]: +def drop_events_over_quota(token: str, events: list[Any]) -> list[Any]: if not settings.EE_AVAILABLE: return events @@ -381,7 +382,7 @@ def get_event(request): structlog.contextvars.bind_contextvars(token=token) - replay_events: List[Any] = [] + replay_events: list[Any] = [] historical = token in settings.TOKENS_HISTORICAL_DATA with start_span(op="request.process"): @@ -437,7 +438,7 @@ def get_event(request): generate_exception_response("capture", f"Invalid payload: {e}", code="invalid_payload"), ) - futures: List[FutureRecordMetadata] = [] + futures: list[FutureRecordMetadata] = [] with start_span(op="kafka.produce") as span: span.set_tag("event.count", len(processed_events)) @@ -536,7 +537,7 @@ def get_event(request): return cors_response(request, JsonResponse({"status": 1})) -def preprocess_events(events: List[Dict[str, Any]]) -> Iterator[Tuple[Dict[str, Any], UUIDT, str]]: +def preprocess_events(events: list[dict[str, Any]]) -> Iterator[tuple[dict[str, Any], UUIDT, str]]: for event in events: event_uuid = UUIDT() distinct_id = get_distinct_id(event) @@ -580,7 +581,7 @@ def capture_internal( event_uuid=None, token=None, historical=False, - extra_headers: List[Tuple[str, str]] | None = None, + extra_headers: list[tuple[str, str]] | None = None, ): if event_uuid is None: event_uuid = UUIDT() @@ -680,7 +681,7 @@ def is_randomly_partitioned(candidate_partition_key: str) -> bool: @cache_for(timedelta(seconds=30), background_refresh=True) -def _list_overflowing_keys(input_type: InputType) -> Set[str]: +def _list_overflowing_keys(input_type: InputType) -> set[str]: """Retrieve the active overflows from Redis with caching and pre-fetching cache_for will keep the old value if Redis is temporarily unavailable. 
diff --git a/posthog/api/cohort.py b/posthog/api/cohort.py index 64eb30db9b0..af85769fd0f 100644 --- a/posthog/api/cohort.py +++ b/posthog/api/cohort.py @@ -18,7 +18,7 @@ import posthoganalytics from posthog.metrics import LABEL_TEAM_ID from posthog.renderers import SafeJSONRenderer from datetime import datetime -from typing import Any, Dict, cast, Optional +from typing import Any, cast, Optional from django.conf import settings from django.db.models import QuerySet, Prefetch, prefetch_related_objects, OuterRef, Subquery @@ -133,7 +133,7 @@ class CohortSerializer(serializers.ModelSerializer): "experiment_set", ] - def _handle_static(self, cohort: Cohort, context: Dict, validated_data: Dict) -> None: + def _handle_static(self, cohort: Cohort, context: dict, validated_data: dict) -> None: request = self.context["request"] if request.FILES.get("csv"): self._calculate_static_by_csv(request.FILES["csv"], cohort) @@ -149,7 +149,7 @@ class CohortSerializer(serializers.ModelSerializer): if filter_data: insert_cohort_from_insight_filter.delay(cohort.pk, filter_data) - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Cohort: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Cohort: request = self.context["request"] validated_data["created_by"] = request.user @@ -176,7 +176,7 @@ class CohortSerializer(serializers.ModelSerializer): distinct_ids_and_emails = [row[0] for row in reader if len(row) > 0 and row] calculate_cohort_from_list.delay(cohort.pk, distinct_ids_and_emails) - def validate_query(self, query: Optional[Dict]) -> Optional[Dict]: + def validate_query(self, query: Optional[dict]) -> Optional[dict]: if not query: return None if not isinstance(query, dict): @@ -186,7 +186,7 @@ class CohortSerializer(serializers.ModelSerializer): ActorsQuery.model_validate(query) return query - def validate_filters(self, request_filters: Dict): + def validate_filters(self, request_filters: dict): if isinstance(request_filters, dict) and "properties" in request_filters: if self.context["request"].method == "PATCH": parsed_filter = Filter(data=request_filters) @@ -225,7 +225,7 @@ class CohortSerializer(serializers.ModelSerializer): else: raise ValidationError("Filters must be a dictionary with a 'properties' key.") - def update(self, cohort: Cohort, validated_data: Dict, *args: Any, **kwargs: Any) -> Cohort: # type: ignore + def update(self, cohort: Cohort, validated_data: dict, *args: Any, **kwargs: Any) -> Cohort: # type: ignore request = self.context["request"] user = cast(User, request.user) @@ -498,7 +498,7 @@ def insert_cohort_query_actors_into_ch(cohort: Cohort): insert_actors_into_cohort_by_query(cohort, query, {}, context) -def insert_cohort_actors_into_ch(cohort: Cohort, filter_data: Dict): +def insert_cohort_actors_into_ch(cohort: Cohort, filter_data: dict): from_existing_cohort_id = filter_data.get("from_cohort_id") context: HogQLContext @@ -561,7 +561,7 @@ def insert_cohort_actors_into_ch(cohort: Cohort, filter_data: Dict): insert_actors_into_cohort_by_query(cohort, query, params, context) -def insert_actors_into_cohort_by_query(cohort: Cohort, query: str, params: Dict[str, Any], context: HogQLContext): +def insert_actors_into_cohort_by_query(cohort: Cohort, query: str, params: dict[str, Any], context: HogQLContext): try: sync_execute( INSERT_COHORT_ALL_PEOPLE_THROUGH_PERSON_ID.format(cohort_table=PERSON_STATIC_COHORT_TABLE, query=query), @@ -600,7 +600,7 @@ def get_cohort_actors_for_feature_flag(cohort_id: int, flag: str, team_id: int, cohort = 
Cohort.objects.get(pk=cohort_id, team_id=team_id) matcher_cache = FlagsMatcherCache(team_id) uuids_to_add_to_cohort = [] - cohorts_cache: Dict[int, CohortOrEmpty] = {} + cohorts_cache: dict[int, CohortOrEmpty] = {} if feature_flag.uses_cohorts: # TODO: Consider disabling flags with cohorts for creating static cohorts @@ -709,7 +709,7 @@ def get_cohort_actors_for_feature_flag(cohort_id: int, flag: str, team_id: int, capture_exception(err) -def get_default_person_property(prop: Property, cohorts_cache: Dict[int, CohortOrEmpty]): +def get_default_person_property(prop: Property, cohorts_cache: dict[int, CohortOrEmpty]): default_person_properties = {} if prop.operator not in ("is_set", "is_not_set") and prop.type == "person": @@ -725,7 +725,7 @@ def get_default_person_property(prop: Property, cohorts_cache: Dict[int, CohortO return default_person_properties -def get_default_person_properties_for_cohort(cohort: Cohort, cohorts_cache: Dict[int, CohortOrEmpty]) -> Dict[str, str]: +def get_default_person_properties_for_cohort(cohort: Cohort, cohorts_cache: dict[int, CohortOrEmpty]) -> dict[str, str]: """ Returns a dictionary of default person properties to use when evaluating a feature flag """ diff --git a/posthog/api/comments.py b/posthog/api/comments.py index 8b9a9174dda..63ef5d1d33a 100644 --- a/posthog/api/comments.py +++ b/posthog/api/comments.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, cast +from typing import Any, cast from django.db import transaction from django.db.models import QuerySet @@ -40,7 +40,7 @@ class CommentSerializer(serializers.ModelSerializer): validated_data["team_id"] = self.context["team_id"] return super().create(validated_data) - def update(self, instance: Comment, validated_data: Dict, **kwargs) -> Comment: + def update(self, instance: Comment, validated_data: dict, **kwargs) -> Comment: request = self.context["request"] with transaction.atomic(): diff --git a/posthog/api/dashboards/dashboard.py b/posthog/api/dashboards/dashboard.py index a89d41814d6..850e29b52a4 100644 --- a/posthog/api/dashboards/dashboard.py +++ b/posthog/api/dashboards/dashboard.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Optional, Type, cast +from typing import Any, Optional, cast import structlog from django.db.models import Prefetch, QuerySet @@ -155,13 +155,13 @@ class DashboardSerializer(DashboardBasicSerializer): ] read_only_fields = ["creation_mode", "effective_restriction_level", "is_shared"] - def validate_filters(self, value) -> Dict: + def validate_filters(self, value) -> dict: if not isinstance(value, dict): raise serializers.ValidationError("Filters must be a dictionary") return value - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Dashboard: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Dashboard: request = self.context["request"] validated_data["created_by"] = request.user team_id = self.context["team_id"] @@ -260,7 +260,7 @@ class DashboardSerializer(DashboardBasicSerializer): color=existing_tile.color, ) - def update(self, instance: Dashboard, validated_data: Dict, *args: Any, **kwargs: Any) -> Dashboard: + def update(self, instance: Dashboard, validated_data: dict, *args: Any, **kwargs: Any) -> Dashboard: can_user_restrict = self.user_permissions.dashboard(instance).can_restrict if "restriction_level" in validated_data and not can_user_restrict: raise exceptions.PermissionDenied( @@ -292,11 +292,11 @@ class DashboardSerializer(DashboardBasicSerializer): return instance @staticmethod - def 
_update_tiles(instance: Dashboard, tile_data: Dict, user: User) -> None: + def _update_tiles(instance: Dashboard, tile_data: dict, user: User) -> None: tile_data.pop("is_cached", None) # read only field if tile_data.get("text", None): - text_json: Dict = tile_data.get("text", {}) + text_json: dict = tile_data.get("text", {}) created_by_json = text_json.get("created_by", None) if created_by_json: last_modified_by = user @@ -348,7 +348,7 @@ class DashboardSerializer(DashboardBasicSerializer): insights_to_undelete.append(tile.insight) Insight.objects.bulk_update(insights_to_undelete, ["deleted"]) - def get_tiles(self, dashboard: Dashboard) -> Optional[List[ReturnDict]]: + def get_tiles(self, dashboard: Dashboard) -> Optional[list[ReturnDict]]: if self.context["view"].action == "list": return None @@ -401,7 +401,7 @@ class DashboardsViewSet( queryset = Dashboard.objects_including_soft_deleted.order_by("name") permission_classes = [CanEditDashboard] - def get_serializer_class(self) -> Type[BaseSerializer]: + def get_serializer_class(self) -> type[BaseSerializer]: return DashboardBasicSerializer if self.action == "list" else DashboardSerializer def get_queryset(self) -> QuerySet: @@ -512,7 +512,7 @@ class DashboardsViewSet( class LegacyDashboardsViewSet(DashboardsViewSet): derive_current_team_from_user_only = True - def get_parents_query_dict(self) -> Dict[str, Any]: + def get_parents_query_dict(self) -> dict[str, Any]: if not self.request.user.is_authenticated or "share_token" in self.request.GET: return {} return {"team_id": self.team_id} diff --git a/posthog/api/dashboards/dashboard_template_json_schema_parser.py b/posthog/api/dashboards/dashboard_template_json_schema_parser.py index 3463601514e..8f9149cd84d 100644 --- a/posthog/api/dashboards/dashboard_template_json_schema_parser.py +++ b/posthog/api/dashboards/dashboard_template_json_schema_parser.py @@ -15,9 +15,7 @@ class DashboardTemplateCreationJSONSchemaParser(JSONParser): The template is sent in the "template" key""" def parse(self, stream, media_type=None, parser_context=None): - data = super(DashboardTemplateCreationJSONSchemaParser, self).parse( - stream, media_type or "application/json", parser_context - ) + data = super().parse(stream, media_type or "application/json", parser_context) try: template = data["template"] jsonschema.validate(template, dashboard_template_schema) diff --git a/posthog/api/dashboards/dashboard_templates.py b/posthog/api/dashboards/dashboard_templates.py index 6e8752e0cbd..03740b06ebd 100644 --- a/posthog/api/dashboards/dashboard_templates.py +++ b/posthog/api/dashboards/dashboard_templates.py @@ -1,6 +1,5 @@ import json from pathlib import Path -from typing import Dict import structlog from django.db.models import Q @@ -50,7 +49,7 @@ class DashboardTemplateSerializer(serializers.ModelSerializer): "scope", ] - def create(self, validated_data: Dict, *args, **kwargs) -> DashboardTemplate: + def create(self, validated_data: dict, *args, **kwargs) -> DashboardTemplate: if not validated_data["tiles"]: raise ValidationError(detail="You need to provide tiles for the template.") @@ -61,7 +60,7 @@ class DashboardTemplateSerializer(serializers.ModelSerializer): validated_data["team_id"] = self.context["team_id"] return super().create(validated_data, *args, **kwargs) - def update(self, instance: DashboardTemplate, validated_data: Dict, *args, **kwargs) -> DashboardTemplate: + def update(self, instance: DashboardTemplate, validated_data: dict, *args, **kwargs) -> DashboardTemplate: # if the original request was to 
make the template scope to team only, and the template is none then deny the request if validated_data.get("scope") == "team" and instance.scope == "global" and not instance.team_id: raise ValidationError(detail="The original templates cannot be made private as they would be lost.") diff --git a/posthog/api/dashboards/test/test_dashboard_templates.py b/posthog/api/dashboards/test/test_dashboard_templates.py index f07610ba903..e562b3798d8 100644 --- a/posthog/api/dashboards/test/test_dashboard_templates.py +++ b/posthog/api/dashboards/test/test_dashboard_templates.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, List +from typing import Optional from rest_framework import status @@ -510,7 +510,7 @@ class TestDashboardTemplates(APIBaseTest): assert flag_response.status_code == status.HTTP_200_OK assert [(r["id"], r["scope"]) for r in flag_response.json()["results"]] == [(flag_template_id, "feature_flag")] - def create_template(self, overrides: Dict[str, str | List[str]], team_id: Optional[int] = None) -> str: + def create_template(self, overrides: dict[str, str | list[str]], team_id: Optional[int] = None) -> str: template = {**variable_template, **overrides} response = self.client.post( f"/api/projects/{team_id or self.team.pk}/dashboard_templates", diff --git a/posthog/api/dead_letter_queue.py b/posthog/api/dead_letter_queue.py index 93e2b09370b..2bab6875435 100644 --- a/posthog/api/dead_letter_queue.py +++ b/posthog/api/dead_letter_queue.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from rest_framework import mixins, permissions, serializers, viewsets @@ -65,7 +65,7 @@ class DeadLetterQueueMetric: key: str = "" metric: str = "" value: Union[str, bool, int, None] = None - subrows: Optional[List[Any]] = None + subrows: Optional[list[Any]] = None def __init__(self, **kwargs): for field in ("key", "metric", "value", "subrows"): @@ -138,7 +138,7 @@ def get_dead_letter_queue_events_last_24h() -> int: )[0][0] -def get_dead_letter_queue_events_per_error(offset: Optional[int] = 0) -> List[Union[str, int]]: +def get_dead_letter_queue_events_per_error(offset: Optional[int] = 0) -> list[Union[str, int]]: return sync_execute( f""" SELECT error, count(*) AS c @@ -151,7 +151,7 @@ def get_dead_letter_queue_events_per_error(offset: Optional[int] = 0) -> List[Un ) -def get_dead_letter_queue_events_per_location(offset: Optional[int] = 0) -> List[Union[str, int]]: +def get_dead_letter_queue_events_per_location(offset: Optional[int] = 0) -> list[Union[str, int]]: return sync_execute( f""" SELECT error_location, count(*) AS c @@ -164,7 +164,7 @@ def get_dead_letter_queue_events_per_location(offset: Optional[int] = 0) -> List ) -def get_dead_letter_queue_events_per_day(offset: Optional[int] = 0) -> List[Union[str, int]]: +def get_dead_letter_queue_events_per_day(offset: Optional[int] = 0) -> list[Union[str, int]]: return sync_execute( f""" SELECT toDate(error_timestamp) as day, count(*) AS c @@ -177,7 +177,7 @@ def get_dead_letter_queue_events_per_day(offset: Optional[int] = 0) -> List[Unio ) -def get_dead_letter_queue_events_per_tag(offset: Optional[int] = 0) -> List[Union[str, int]]: +def get_dead_letter_queue_events_per_tag(offset: Optional[int] = 0) -> list[Union[str, int]]: return sync_execute( f""" SELECT arrayJoin(tags) as tag, count(*) as c from events_dead_letter_queue diff --git a/posthog/api/decide.py b/posthog/api/decide.py index 3a6e08bc7a7..827194dea9c 100644 --- a/posthog/api/decide.py +++ 
b/posthog/api/decide.py @@ -1,6 +1,6 @@ import re from random import random -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from urllib.parse import urlparse import structlog @@ -56,7 +56,7 @@ def on_permitted_recording_domain(team: Team, request: HttpRequest) -> bool: return is_authorized_web_client or is_authorized_mobile_client -def hostname_in_allowed_url_list(allowed_url_list: Optional[List[str]], hostname: Optional[str]) -> bool: +def hostname_in_allowed_url_list(allowed_url_list: Optional[list[str]], hostname: Optional[str]) -> bool: if not hostname: return False @@ -182,7 +182,7 @@ def get_decide(request: HttpRequest): if geoip_enabled: property_overrides = get_geoip_properties(get_ip_address(request)) - all_property_overrides: Dict[str, Union[str, int]] = { + all_property_overrides: dict[str, Union[str, int]] = { **property_overrides, **(data.get("person_properties") or {}), } @@ -296,8 +296,8 @@ def get_decide(request: HttpRequest): return cors_response(request, JsonResponse(response)) -def _session_recording_config_response(request: HttpRequest, team: Team) -> bool | Dict: - session_recording_config_response: bool | Dict = False +def _session_recording_config_response(request: HttpRequest, team: Team) -> bool | dict: + session_recording_config_response: bool | dict = False try: if team.session_recording_opt_in and ( @@ -312,7 +312,7 @@ def _session_recording_config_response(request: HttpRequest, team: Team) -> bool linked_flag = None linked_flag_config = team.session_recording_linked_flag or None - if isinstance(linked_flag_config, Dict): + if isinstance(linked_flag_config, dict): linked_flag_key = linked_flag_config.get("key", None) linked_flag_variant = linked_flag_config.get("variant", None) if linked_flag_variant is not None: @@ -330,7 +330,7 @@ def _session_recording_config_response(request: HttpRequest, team: Team) -> bool "networkPayloadCapture": team.session_recording_network_payload_capture_config or None, } - if isinstance(team.session_replay_config, Dict): + if isinstance(team.session_replay_config, dict): record_canvas = team.session_replay_config.get("record_canvas", False) session_recording_config_response.update( { diff --git a/posthog/api/documentation.py b/posthog/api/documentation.py index 47820a9cb22..3cae48fcdb0 100644 --- a/posthog/api/documentation.py +++ b/posthog/api/documentation.py @@ -1,5 +1,5 @@ import re -from typing import Dict, get_args +from typing import get_args from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import ( @@ -215,7 +215,7 @@ def preprocess_exclude_path_format(endpoints, **kwargs): def custom_postprocessing_hook(result, generator, request, public): all_tags = [] - paths: Dict[str, Dict] = {} + paths: dict[str, dict] = {} for path, methods in result["paths"].items(): paths[path] = {} diff --git a/posthog/api/early_access_feature.py b/posthog/api/early_access_feature.py index 911c860a75a..57885666fde 100644 --- a/posthog/api/early_access_feature.py +++ b/posthog/api/early_access_feature.py @@ -1,5 +1,3 @@ -from typing import Type - from django.http import JsonResponse from rest_framework.response import Response from posthog.api.feature_flag import FeatureFlagSerializer, MinimalFeatureFlagSerializer @@ -221,7 +219,7 @@ class EarlyAccessFeatureViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): scope_object = "early_access_feature" queryset = EarlyAccessFeature.objects.select_related("feature_flag").all() - def get_serializer_class(self) -> 
Type[serializers.Serializer]: + def get_serializer_class(self) -> type[serializers.Serializer]: if self.request.method == "POST": return EarlyAccessFeatureSerializerCreateOnly else: diff --git a/posthog/api/element.py b/posthog/api/element.py index d7b721dee81..b617ea8be28 100644 --- a/posthog/api/element.py +++ b/posthog/api/element.py @@ -1,4 +1,4 @@ -from typing import Literal, Tuple +from typing import Literal from rest_framework import request, response, serializers, viewsets from rest_framework.decorators import action @@ -128,8 +128,8 @@ class ElementViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): else: return response.Response(serialized_elements) - def _events_filter(self, request) -> Tuple[Literal["$autocapture", "$rageclick"], ...]: - event_to_filter: Tuple[Literal["$autocapture", "$rageclick"], ...] = () + def _events_filter(self, request) -> tuple[Literal["$autocapture", "$rageclick"], ...]: + event_to_filter: tuple[Literal["$autocapture", "$rageclick"], ...] = () # when multiple includes are sent expects them as separate parameters # e.g. ?include=a&include=b events_to_include = request.query_params.getlist("include", []) diff --git a/posthog/api/event.py b/posthog/api/event.py index 6366ee866f6..5c642a26129 100644 --- a/posthog/api/event.py +++ b/posthog/api/event.py @@ -1,7 +1,7 @@ import json import urllib from datetime import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any, List, Optional, Union # noqa: UP035 from django.db.models.query import Prefetch from drf_spectacular.types import OpenApiTypes @@ -94,7 +94,7 @@ class EventViewSet( self, request: request.Request, last_event_timestamp: datetime, - order_by: List[str], + order_by: list[str], ) -> str: params = request.GET.dict() reverse = "-timestamp" in order_by @@ -175,7 +175,7 @@ class EventViewSet( team = self.team filter = Filter(request=request, team=self.team) - order_by: List[str] = ( + order_by: list[str] = ( list(json.loads(request.GET["orderBy"])) if request.GET.get("orderBy") else ["-timestamp"] ) @@ -217,11 +217,11 @@ class EventViewSet( capture_exception(ex) raise ex - def _get_people(self, query_result: List[Dict], team: Team) -> Dict[str, Any]: + def _get_people(self, query_result: List[dict], team: Team) -> dict[str, Any]: # noqa: UP006 distinct_ids = [event["distinct_id"] for event in query_result] persons = get_persons_by_distinct_ids(team.pk, distinct_ids) persons = persons.prefetch_related(Prefetch("persondistinctid_set", to_attr="distinct_ids_cache")) - distinct_to_person: Dict[str, Person] = {} + distinct_to_person: dict[str, Person] = {} for person in persons: for distinct_id in person.distinct_ids: distinct_to_person[distinct_id] = person diff --git a/posthog/api/event_definition.py b/posthog/api/event_definition.py index 82a9c0617bd..76314578fb9 100644 --- a/posthog/api/event_definition.py +++ b/posthog/api/event_definition.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Tuple, Type, cast +from typing import Any, Literal, cast from django.db.models import Manager, Prefetch from rest_framework import ( @@ -117,7 +117,7 @@ class EventDefinitionViewSet( def _ordering_params_from_request( self, - ) -> Tuple[str, Literal["ASC", "DESC"]]: + ) -> tuple[str, Literal["ASC", "DESC"]]: order_direction: Literal["ASC", "DESC"] ordering = self.request.GET.get("ordering") @@ -154,7 +154,7 @@ class EventDefinitionViewSet( return EventDefinition.objects.get(id=id, team_id=self.team_id) - def get_serializer_class(self) -> Type[serializers.ModelSerializer]: + def 
get_serializer_class(self) -> type[serializers.ModelSerializer]: serializer_class = self.serializer_class if EE_AVAILABLE and self.request.user.organization.is_feature_available( # type: ignore AvailableFeature.INGESTION_TAXONOMY diff --git a/posthog/api/exports.py b/posthog/api/exports.py index 2099b2f169e..9fbaea35df3 100644 --- a/posthog/api/exports.py +++ b/posthog/api/exports.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Any, Dict +from typing import Any import structlog from django.http import HttpResponse @@ -40,7 +40,7 @@ class ExportedAssetSerializer(serializers.ModelSerializer): ] read_only_fields = ["id", "created_at", "has_content", "filename"] - def validate(self, data: Dict) -> Dict: + def validate(self, data: dict) -> dict: if not data.get("export_format"): raise ValidationError("Must provide export format") @@ -61,13 +61,13 @@ class ExportedAssetSerializer(serializers.ModelSerializer): def synthetic_create(self, reason: str, *args: Any, **kwargs: Any) -> ExportedAsset: return self._create_asset(self.validated_data, user=None, reason=reason) - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> ExportedAsset: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> ExportedAsset: request = self.context["request"] return self._create_asset(validated_data, user=request.user, reason=None) def _create_asset( self, - validated_data: Dict, + validated_data: dict, user: User | None, reason: str | None, ) -> ExportedAsset: diff --git a/posthog/api/feature_flag.py b/posthog/api/feature_flag.py index 8bf1dbb5d3c..bd53f029552 100644 --- a/posthog/api/feature_flag.py +++ b/posthog/api/feature_flag.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Optional, cast +from typing import Any, Optional, cast from datetime import datetime from django.db.models import QuerySet, Q, deletion @@ -145,12 +145,12 @@ class FeatureFlagSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedMo and feature_flag.aggregation_group_type_index is None ) - def get_features(self, feature_flag: FeatureFlag) -> Dict: + def get_features(self, feature_flag: FeatureFlag) -> dict: from posthog.api.early_access_feature import MinimalEarlyAccessFeatureSerializer return MinimalEarlyAccessFeatureSerializer(feature_flag.features, many=True).data - def get_surveys(self, feature_flag: FeatureFlag) -> Dict: + def get_surveys(self, feature_flag: FeatureFlag) -> dict: from posthog.api.survey import SurveyAPISerializer return SurveyAPISerializer(feature_flag.surveys_linked_flag, many=True).data @@ -263,7 +263,7 @@ class FeatureFlagSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedMo return filters - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> FeatureFlag: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> FeatureFlag: request = self.context["request"] validated_data["created_by"] = request.user validated_data["team_id"] = self.context["team_id"] @@ -299,7 +299,7 @@ class FeatureFlagSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedMo return instance - def update(self, instance: FeatureFlag, validated_data: Dict, *args: Any, **kwargs: Any) -> FeatureFlag: + def update(self, instance: FeatureFlag, validated_data: dict, *args: Any, **kwargs: Any) -> FeatureFlag: if "deleted" in validated_data and validated_data["deleted"] is True and instance.features.count() > 0: raise exceptions.ValidationError( "Cannot delete a feature flag that is in use with early access features. 
Please delete the early access feature before deleting the flag." @@ -496,13 +496,11 @@ class FeatureFlagViewSet( feature_flags, many=True, context=self.get_serializer_context() ).data return Response( - ( - { - "feature_flag": feature_flag, - "value": matches.get(feature_flag["key"], False), - } - for feature_flag in all_serialized_flags - ) + { + "feature_flag": feature_flag, + "value": matches.get(feature_flag["key"], False), + } + for feature_flag in all_serialized_flags ) @action( @@ -516,7 +514,7 @@ class FeatureFlagViewSet( should_send_cohorts = "send_cohorts" in request.GET cohorts = {} - seen_cohorts_cache: Dict[int, CohortOrEmpty] = {} + seen_cohorts_cache: dict[int, CohortOrEmpty] = {} if should_send_cohorts: seen_cohorts_cache = { diff --git a/posthog/api/geoip.py b/posthog/api/geoip.py index d3d029cdd3f..7a749c0b294 100644 --- a/posthog/api/geoip.py +++ b/posthog/api/geoip.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Optional import structlog from django.contrib.gis.geoip2 import GeoIP2 @@ -27,7 +27,7 @@ VALID_GEOIP_PROPERTIES = [ ] -def get_geoip_properties(ip_address: Optional[str]) -> Dict[str, str]: +def get_geoip_properties(ip_address: Optional[str]) -> dict[str, str]: """ Returns a dictionary of geoip properties for the given ip address. diff --git a/posthog/api/insight.py b/posthog/api/insight.py index 36495a5469b..a2fe0c53edc 100644 --- a/posthog/api/insight.py +++ b/posthog/api/insight.py @@ -1,6 +1,6 @@ import json from functools import lru_cache -from typing import Any, Dict, List, Optional, Type, Union, cast +from typing import Any, Optional, Union, cast import structlog from django.db import transaction @@ -118,7 +118,7 @@ def log_insight_activity( team_id: int, user: User, was_impersonated: bool, - changes: Optional[List[Change]] = None, + changes: Optional[list[Change]] = None, ) -> None: """ Insight id and short_id are passed separately as some activities (like delete) alter the Insight instance @@ -148,7 +148,7 @@ class QuerySchemaParser(JSONParser): """ def parse(self, stream, media_type=None, parser_context=None): - data = super(QuerySchemaParser, self).parse(stream, media_type, parser_context) + data = super().parse(stream, media_type, parser_context) try: query = data.get("query", None) if query: @@ -197,7 +197,7 @@ class InsightBasicSerializer(TaggedItemSerializerMixin, serializers.ModelSeriali ] read_only_fields = ("short_id", "updated_at", "last_refresh", "refreshing") - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Any: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError() def to_representation(self, instance): @@ -306,7 +306,7 @@ class InsightSerializer(InsightBasicSerializer, UserPermissionsSerializerMixin): "is_cached", ) - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Insight: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Insight: request = self.context["request"] tags = validated_data.pop("tags", None) # tags are created separately as global tag relationships team_id = self.context["team_id"] @@ -345,8 +345,8 @@ class InsightSerializer(InsightBasicSerializer, UserPermissionsSerializerMixin): return insight - def update(self, instance: Insight, validated_data: Dict, **kwargs) -> Insight: - dashboards_before_change: List[Union[str, Dict]] = [] + def update(self, instance: Insight, validated_data: dict, **kwargs) -> Insight: + dashboards_before_change: list[Union[str, dict]] = [] try: # since it is 
possible to be undeleting a soft deleted insight # the state captured before the update has to include soft deleted insights @@ -411,7 +411,7 @@ class InsightSerializer(InsightBasicSerializer, UserPermissionsSerializerMixin): changes=changes, ) - def _synthetic_dashboard_changes(self, dashboards_before_change: List[Dict]) -> List[Change]: + def _synthetic_dashboard_changes(self, dashboards_before_change: list[dict]) -> list[Change]: artificial_dashboard_changes = self.context.get("after_dashboard_changes", []) if artificial_dashboard_changes: return [ @@ -426,7 +426,7 @@ class InsightSerializer(InsightBasicSerializer, UserPermissionsSerializerMixin): return [] - def _update_insight_dashboards(self, dashboards: List[Dashboard], instance: Insight) -> None: + def _update_insight_dashboards(self, dashboards: list[Dashboard], instance: Insight) -> None: old_dashboard_ids = [tile.dashboard_id for tile in instance.dashboard_tiles.all()] new_dashboard_ids = [d.id for d in dashboards if not d.deleted] @@ -598,14 +598,14 @@ class InsightViewSet( parser_classes = (QuerySchemaParser,) - def get_serializer_class(self) -> Type[serializers.BaseSerializer]: + def get_serializer_class(self) -> type[serializers.BaseSerializer]: if (self.action == "list" or self.action == "retrieve") and str_to_bool( self.request.query_params.get("basic", "0") ): return InsightBasicSerializer return super().get_serializer_class() - def get_serializer_context(self) -> Dict[str, Any]: + def get_serializer_context(self) -> dict[str, Any]: context = super().get_serializer_context() context["is_shared"] = isinstance(self.request.successful_authenticator, SharingAccessTokenAuthentication) return context @@ -867,7 +867,7 @@ Using the correct cache and enriching the response with dashboard specific confi return Response({**result, "next": next}) @cached_by_filters - def calculate_trends(self, request: request.Request) -> Dict[str, Any]: + def calculate_trends(self, request: request.Request) -> dict[str, Any]: team = self.team filter = Filter(request=request, team=self.team) @@ -919,7 +919,7 @@ Using the correct cache and enriching the response with dashboard specific confi return Response(funnel) @cached_by_filters - def calculate_funnel(self, request: request.Request) -> Dict[str, Any]: + def calculate_funnel(self, request: request.Request) -> dict[str, Any]: team = self.team filter = Filter(request=request, data={"insight": INSIGHT_FUNNELS}, team=self.team) @@ -959,7 +959,7 @@ Using the correct cache and enriching the response with dashboard specific confi return Response(result) @cached_by_filters - def calculate_retention(self, request: request.Request) -> Dict[str, Any]: + def calculate_retention(self, request: request.Request) -> dict[str, Any]: team = self.team data = {} if not request.GET.get("date_from") and not request.data.get("date_from"): @@ -989,7 +989,7 @@ Using the correct cache and enriching the response with dashboard specific confi return Response(result) @cached_by_filters - def calculate_path(self, request: request.Request) -> Dict[str, Any]: + def calculate_path(self, request: request.Request) -> dict[str, Any]: team = self.team filter = PathFilter(request=request, data={"insight": INSIGHT_PATHS}, team=self.team) diff --git a/posthog/api/instance_settings.py b/posthog/api/instance_settings.py index dc0b41e5cb1..13c1461ba56 100644 --- a/posthog/api/instance_settings.py +++ b/posthog/api/instance_settings.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, 
Optional, Union from rest_framework import exceptions, mixins, permissions, serializers, viewsets @@ -50,7 +50,7 @@ class InstanceSettingHelper: setattr(self, field, kwargs.get(field, None)) -def get_instance_setting(key: str, setting_config: Optional[Tuple] = None) -> InstanceSettingHelper: +def get_instance_setting(key: str, setting_config: Optional[tuple] = None) -> InstanceSettingHelper: setting_config = setting_config or CONSTANCE_CONFIG[key] is_secret = key in SECRET_SETTINGS value = get_instance_setting_raw(key) @@ -73,7 +73,7 @@ class InstanceSettingsSerializer(serializers.Serializer): editable = serializers.BooleanField(read_only=True) is_secret = serializers.BooleanField(read_only=True) - def update(self, instance: InstanceSettingHelper, validated_data: Dict[str, Any]) -> InstanceSettingHelper: + def update(self, instance: InstanceSettingHelper, validated_data: dict[str, Any]) -> InstanceSettingHelper: if instance.key not in SETTINGS_ALLOWING_API_OVERRIDE: raise serializers.ValidationError("This setting cannot be updated from the API.", code="no_api_override") diff --git a/posthog/api/instance_status.py b/posthog/api/instance_status.py index c0dff3a3e4a..1e001b74703 100644 --- a/posthog/api/instance_status.py +++ b/posthog/api/instance_status.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Union +from typing import Any, Union from django.conf import settings from django.db import connection @@ -40,7 +40,7 @@ class InstanceStatusViewSet(viewsets.ViewSet): redis_alive = is_redis_alive() postgres_alive = is_postgres_alive() - metrics: List[Dict[str, Union[str, bool, int, float, Dict[str, Any]]]] = [] + metrics: list[dict[str, Union[str, bool, int, float, dict[str, Any]]]] = [] metrics.append( {"key": "posthog_git_sha", "metric": "PostHog Git SHA", "value": get_git_commit_short() or "unknown"} diff --git a/posthog/api/mixins.py b/posthog/api/mixins.py index 69b83d3469e..a326eb3d1d2 100644 --- a/posthog/api/mixins.py +++ b/posthog/api/mixins.py @@ -1,4 +1,4 @@ -from typing import TypeVar, Type +from typing import TypeVar from pydantic import BaseModel, ValidationError @@ -9,7 +9,7 @@ T = TypeVar("T", bound=BaseModel) class PydanticModelMixin: - def get_model(self, data: dict, model: Type[T]) -> T: + def get_model(self, data: dict, model: type[T]) -> T: try: return model.model_validate(data) except ValidationError as exc: diff --git a/posthog/api/notebook.py b/posthog/api/notebook.py index 5910af4948c..4125b79dd65 100644 --- a/posthog/api/notebook.py +++ b/posthog/api/notebook.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Any, Type +from typing import Optional, Any from django.db.models import Q import structlog from django.db import transaction @@ -58,7 +58,7 @@ def log_notebook_activity( team_id: int, user: User, was_impersonated: bool, - changes: Optional[List[Change]] = None, + changes: Optional[list[Change]] = None, ) -> None: short_id = str(notebook.short_id) @@ -118,7 +118,7 @@ class NotebookSerializer(NotebookMinimalSerializer): "last_modified_by", ] - def create(self, validated_data: Dict, *args, **kwargs) -> Notebook: + def create(self, validated_data: dict, *args, **kwargs) -> Notebook: request = self.context["request"] team = self.context["get_team"]() @@ -141,7 +141,7 @@ class NotebookSerializer(NotebookMinimalSerializer): return notebook - def update(self, instance: Notebook, validated_data: Dict, **kwargs) -> Notebook: + def update(self, instance: Notebook, validated_data: dict, **kwargs) -> Notebook: try: before_update = 
Notebook.objects.get(pk=instance.id) except Notebook.DoesNotExist: @@ -240,7 +240,7 @@ class NotebookViewSet(TeamAndOrgViewSetMixin, ForbidDestroyModel, viewsets.Model filterset_fields = ["short_id"] lookup_field = "short_id" - def get_serializer_class(self) -> Type[BaseSerializer]: + def get_serializer_class(self) -> type[BaseSerializer]: return NotebookMinimalSerializer if self.action == "list" else NotebookSerializer def get_queryset(self) -> QuerySet: @@ -298,8 +298,8 @@ class NotebookViewSet(TeamAndOrgViewSetMixin, ForbidDestroyModel, viewsets.Model if target: # the JSONB query requires a specific structure - basic_structure = List[Dict[str, Any]] - nested_structure = basic_structure | List[Dict[str, basic_structure]] + basic_structure = list[dict[str, Any]] + nested_structure = basic_structure | list[dict[str, basic_structure]] presence_match_structure: basic_structure | nested_structure = [{"type": f"ph-{target}"}] diff --git a/posthog/api/organization.py b/posthog/api/organization.py index ea1a9f31615..f528d541319 100644 --- a/posthog/api/organization.py +++ b/posthog/api/organization.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Optional, Union, cast from django.db.models import Model, QuerySet from django.shortcuts import get_object_or_404 @@ -108,7 +108,7 @@ class OrganizationSerializer(serializers.ModelSerializer, UserPermissionsSeriali }, # slug is not required here as it's generated automatically for new organizations } - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Organization: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Organization: serializers.raise_errors_on_nested_writes("create", self, validated_data) user = self.context["request"].user organization, _, _ = Organization.objects.bootstrap(user, **validated_data) @@ -119,11 +119,11 @@ class OrganizationSerializer(serializers.ModelSerializer, UserPermissionsSeriali membership = self.user_permissions.organization_memberships.get(organization.pk) return membership.level if membership is not None else None - def get_teams(self, instance: Organization) -> List[Dict[str, Any]]: + def get_teams(self, instance: Organization) -> list[dict[str, Any]]: visible_teams = instance.teams.filter(id__in=self.user_permissions.team_ids_visible_for_user) return TeamBasicSerializer(visible_teams, context=self.context, many=True).data # type: ignore - def get_metadata(self, instance: Organization) -> Dict[str, Union[str, int, object]]: + def get_metadata(self, instance: Organization) -> dict[str, Union[str, int, object]]: return { "instance_tag": settings.INSTANCE_TAG, } @@ -210,7 +210,7 @@ class OrganizationViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): ignore_conflicts=True, ) - def get_serializer_context(self) -> Dict[str, Any]: + def get_serializer_context(self) -> dict[str, Any]: return { **super().get_serializer_context(), "user_permissions": UserPermissions(cast(User, self.request.user)), diff --git a/posthog/api/organization_domain.py b/posthog/api/organization_domain.py index b3a4ada0b4e..81b8c8efad8 100644 --- a/posthog/api/organization_domain.py +++ b/posthog/api/organization_domain.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict, cast +from typing import Any, cast from rest_framework import exceptions, request, response, serializers from rest_framework.decorators import action @@ -38,7 +38,7 @@ class OrganizationDomainSerializer(serializers.ModelSerializer): 
"has_saml": {"read_only": True}, } - def create(self, validated_data: Dict[str, Any]) -> OrganizationDomain: + def create(self, validated_data: dict[str, Any]) -> OrganizationDomain: validated_data["organization"] = self.context["view"].organization validated_data.pop( "jit_provisioning_enabled", None @@ -56,7 +56,7 @@ class OrganizationDomainSerializer(serializers.ModelSerializer): raise serializers.ValidationError("Please enter a valid domain or subdomain name.") return domain - def validate(self, attrs: Dict[str, Any]) -> Dict[str, Any]: + def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: instance = cast(OrganizationDomain, self.instance) if instance and not instance.verified_at: diff --git a/posthog/api/organization_feature_flag.py b/posthog/api/organization_feature_flag.py index 0ed25ada28e..d2468cb07ce 100644 --- a/posthog/api/organization_feature_flag.py +++ b/posthog/api/organization_feature_flag.py @@ -1,4 +1,3 @@ -from typing import Dict from django.core.exceptions import ObjectDoesNotExist from rest_framework.response import Response from rest_framework.decorators import action @@ -95,13 +94,13 @@ class OrganizationFeatureFlagView( continue # get all linked cohorts, sorted by creation order - seen_cohorts_cache: Dict[int, CohortOrEmpty] = {} + seen_cohorts_cache: dict[int, CohortOrEmpty] = {} sorted_cohort_ids = flag_to_copy.get_cohort_ids( seen_cohorts_cache=seen_cohorts_cache, sort_by_topological_order=True ) # destination cohort id is different from original cohort id - create mapping - name_to_dest_cohort_id: Dict[str, int] = {} + name_to_dest_cohort_id: dict[str, int] = {} # create cohorts in the destination project if len(sorted_cohort_ids): for cohort_id in sorted_cohort_ids: diff --git a/posthog/api/organization_invite.py b/posthog/api/organization_invite.py index 6a8140479a9..961f2cddba2 100644 --- a/posthog/api/organization_invite.py +++ b/posthog/api/organization_invite.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, cast +from typing import Any, cast from rest_framework import ( exceptions, @@ -49,7 +49,7 @@ class OrganizationInviteSerializer(serializers.ModelSerializer): local_part, domain = email.split("@") return f"{local_part}@{domain.lower()}" - def create(self, validated_data: Dict[str, Any], *args: Any, **kwargs: Any) -> OrganizationInvite: + def create(self, validated_data: dict[str, Any], *args: Any, **kwargs: Any) -> OrganizationInvite: if OrganizationMembership.objects.filter( organization_id=self.context["organization_id"], user__email=validated_data["target_email"], diff --git a/posthog/api/person.py b/posthog/api/person.py index 942f07e9a9e..e242c7ffffe 100644 --- a/posthog/api/person.py +++ b/posthog/api/person.py @@ -2,17 +2,14 @@ import json import posthoganalytics from posthog.renderers import SafeJSONRenderer from datetime import datetime -from typing import ( +from typing import ( # noqa: UP035 Any, - Callable, - Dict, List, Optional, - Tuple, - Type, TypeVar, cast, ) +from collections.abc import Callable from django.db.models import Prefetch from django.shortcuts import get_object_or_404 @@ -176,7 +173,7 @@ class PersonSerializer(serializers.HyperlinkedModelSerializer): team = self.context["get_team"]() return get_person_name(team, person) - def to_representation(self, instance: Person) -> Dict[str, Any]: + def to_representation(self, instance: Person) -> dict[str, Any]: representation = super().to_representation(instance) representation["distinct_ids"] = sorted(representation["distinct_ids"], key=is_anonymous_id) return 
representation @@ -192,7 +189,7 @@ class MinimalPersonSerializer(PersonSerializer): def get_funnel_actor_class(filter: Filter) -> Callable: - funnel_actor_class: Type[ActorBaseQuery] + funnel_actor_class: type[ActorBaseQuery] if filter.correlation_person_entity and EE_AVAILABLE: if EE_AVAILABLE: @@ -678,7 +675,7 @@ class PersonViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): ) # PRAGMA: Methods for getting Persons via clickhouse queries - def _respond_with_cached_results(self, results_package: Dict[str, Tuple[List, Optional[str], Optional[str], int]]): + def _respond_with_cached_results(self, results_package: dict[str, tuple[List, Optional[str], Optional[str], int]]): # noqa: UP006 if not results_package: return response.Response(data=[]) @@ -705,7 +702,7 @@ class PersonViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): @cached_by_filters def calculate_funnel_persons( self, request: request.Request - ) -> Dict[str, Tuple[List, Optional[str], Optional[str], int]]: + ) -> dict[str, tuple[List, Optional[str], Optional[str], int]]: # noqa: UP006 filter = Filter(request=request, data={"insight": INSIGHT_FUNNELS}, team=self.team) filter = prepare_actor_query_filter(filter) funnel_actor_class = get_funnel_actor_class(filter) @@ -734,7 +731,7 @@ class PersonViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): @cached_by_filters def calculate_path_persons( self, request: request.Request - ) -> Dict[str, Tuple[List, Optional[str], Optional[str], int]]: + ) -> dict[str, tuple[List, Optional[str], Optional[str], int]]: # noqa: UP006 filter = PathFilter(request=request, data={"insight": INSIGHT_PATHS}, team=self.team) filter = prepare_actor_query_filter(filter) @@ -769,7 +766,7 @@ class PersonViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): @cached_by_filters def calculate_trends_persons( self, request: request.Request - ) -> Dict[str, Tuple[List, Optional[str], Optional[str], int]]: + ) -> dict[str, tuple[List, Optional[str], Optional[str], int]]: # noqa: UP006 filter = Filter(request=request, team=self.team) filter = prepare_actor_query_filter(filter) entity = get_target_entity(filter) diff --git a/posthog/api/plugin.py b/posthog/api/plugin.py index 2a6e00f3254..7a4dea1a8d7 100644 --- a/posthog/api/plugin.py +++ b/posthog/api/plugin.py @@ -2,7 +2,7 @@ import json import os import re import subprocess -from typing import Any, Dict, List, Optional, Set, cast, Literal +from typing import Any, Optional, cast, Literal import requests from dateutil.relativedelta import relativedelta @@ -64,8 +64,8 @@ def _update_plugin_attachments(request: request.Request, plugin_config: PluginCo def get_plugin_config_changes( - old_config: Dict[str, Any], new_config: Dict[str, Any], secret_fields=None -) -> List[Change]: + old_config: dict[str, Any], new_config: dict[str, Any], secret_fields=None +) -> list[Change]: if secret_fields is None: secret_fields = [] config_changes = dict_changes_between("Plugin", old_config, new_config) @@ -103,8 +103,8 @@ def log_enabled_change_activity( def log_config_update_activity( new_plugin_config: PluginConfig, - old_config: Dict[str, Any], - secret_fields: Set[str], + old_config: dict[str, Any], + secret_fields: set[str], old_enabled: bool, user: User, was_impersonated: bool, @@ -280,7 +280,7 @@ class PluginSerializer(serializers.ModelSerializer): def get_organization_name(self, plugin: Plugin) -> str: return plugin.organization.name - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Plugin: + def create(self, validated_data: dict, *args: Any, 
**kwargs: Any) -> Plugin: validated_data["url"] = self.initial_data.get("url", None) validated_data["organization_id"] = self.context["organization_id"] validated_data["updated_at"] = now() @@ -291,7 +291,7 @@ class PluginSerializer(serializers.ModelSerializer): return plugin - def update(self, plugin: Plugin, validated_data: Dict, *args: Any, **kwargs: Any) -> Plugin: # type: ignore + def update(self, plugin: Plugin, validated_data: dict, *args: Any, **kwargs: Any) -> Plugin: # type: ignore context_organization = self.context["get_organization"]() if ( "is_global" in validated_data @@ -387,7 +387,7 @@ class PluginViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): @action(methods=["GET"], detail=True) def source(self, request: request.Request, **kwargs): plugin = self.get_plugin_with_permissions(reason="source editing") - response: Dict[str, str] = {} + response: dict[str, str] = {} for source in PluginSourceFile.objects.filter(plugin=plugin): response[source.filename] = source.source return Response(response) @@ -395,7 +395,7 @@ class PluginViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): @action(methods=["PATCH"], detail=True) def update_source(self, request: request.Request, **kwargs): plugin = self.get_plugin_with_permissions(reason="source editing") - sources: Dict[str, PluginSourceFile] = {} + sources: dict[str, PluginSourceFile] = {} performed_changes = False for plugin_source_file in PluginSourceFile.objects.filter(plugin=plugin): sources[plugin_source_file.filename] = plugin_source_file @@ -438,7 +438,7 @@ class PluginViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): sources[key].error = error sources[key].save() - response: Dict[str, str] = {} + response: dict[str, str] = {} for _, source in sources.items(): response[source.filename] = source.source @@ -476,7 +476,7 @@ class PluginViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): Plugin.PluginType.SOURCE, Plugin.PluginType.LOCAL, ): - validated_data: Dict[str, Any] = {} + validated_data: dict[str, Any] = {} plugin_json = update_validated_data_from_url(validated_data, plugin.url) with transaction.atomic(): serializer.update(plugin, validated_data) @@ -647,7 +647,7 @@ class PluginConfigSerializer(serializers.ModelSerializer): # error details instead. 
return None - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> PluginConfig: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> PluginConfig: if not can_configure_plugins(self.context["get_organization"]()): raise ValidationError("Plugin configuration is not available for the current organization!") validated_data["team_id"] = self.context["team_id"] @@ -682,7 +682,7 @@ class PluginConfigSerializer(serializers.ModelSerializer): def update( # type: ignore self, plugin_config: PluginConfig, - validated_data: Dict, + validated_data: dict, *args: Any, **kwargs: Any, ) -> PluginConfig: @@ -731,7 +731,7 @@ class PluginConfigViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): queryset = queryset.filter(deleted=False) return queryset.order_by("order", "plugin_id") - def get_serializer_context(self) -> Dict[str, Any]: + def get_serializer_context(self) -> dict[str, Any]: context = super().get_serializer_context() if context["view"].action in ("retrieve", "list"): context["delivery_rates_1d"] = TeamPluginsDeliveryRateQuery(self.team).run() @@ -856,7 +856,7 @@ class PluginConfigViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): content = plugin_source.transpiled or "" return HttpResponse(content, content_type="application/javascript; charset=UTF-8") - obj: Dict[str, Any] = {} + obj: dict[str, Any] = {} if not plugin_source: obj = {"no_frontend": True} elif plugin_source.status is None or plugin_source.status == PluginSourceFile.Status.LOCKED: @@ -868,7 +868,7 @@ class PluginConfigViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): return HttpResponse(content, content_type="application/javascript; charset=UTF-8") -def _get_secret_fields_for_plugin(plugin: Plugin) -> Set[str]: +def _get_secret_fields_for_plugin(plugin: Plugin) -> set[str]: # A set of keys for config fields that have secret = true secret_fields = {field["key"] for field in plugin.config_schema if isinstance(field, dict) and field.get("secret")} return secret_fields diff --git a/posthog/api/property_definition.py b/posthog/api/property_definition.py index 584644f902b..6a87fc6f348 100644 --- a/posthog/api/property_definition.py +++ b/posthog/api/property_definition.py @@ -1,6 +1,6 @@ import dataclasses import json -from typing import Any, Dict, List, Optional, Type, cast +from typing import Any, Optional, cast from django.db import connection from django.db.models import Prefetch @@ -125,7 +125,7 @@ class QueryContext: posthog_eventproperty_table_join_alias = "check_for_matching_event_property" - params: Dict = dataclasses.field(default_factory=dict) + params: dict = dataclasses.field(default_factory=dict) def with_properties_to_filter(self, properties_to_filter: Optional[str]) -> "QueryContext": if properties_to_filter: @@ -219,7 +219,7 @@ class QueryContext: params={**self.params, "event_names": list(map(str, event_names or []))}, ) - def with_search(self, search_query: str, search_kwargs: Dict) -> "QueryContext": + def with_search(self, search_query: str, search_kwargs: dict) -> "QueryContext": return dataclasses.replace( self, search_query=search_query, @@ -443,7 +443,7 @@ class NotCountingLimitOffsetPaginator(LimitOffsetPagination): return self.count - def paginate_queryset(self, queryset, request, view=None) -> Optional[List[Any]]: + def paginate_queryset(self, queryset, request, view=None) -> Optional[list[Any]]: """ Assumes the queryset has already had pagination applied """ @@ -570,7 +570,7 @@ class PropertyDefinitionViewSet( return 
queryset.raw(query_context.as_sql(order_by_verified), params=query_context.params) - def get_serializer_class(self) -> Type[serializers.ModelSerializer]: + def get_serializer_class(self) -> type[serializers.ModelSerializer]: serializer_class = self.serializer_class if self.request.user.organization.is_feature_available(AvailableFeature.INGESTION_TAXONOMY): try: diff --git a/posthog/api/routing.py b/posthog/api/routing.py index b768538c05d..02654051a3f 100644 --- a/posthog/api/routing.py +++ b/posthog/api/routing.py @@ -1,5 +1,5 @@ from functools import cached_property, lru_cache -from typing import TYPE_CHECKING, Any, Dict, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from rest_framework.exceptions import AuthenticationFailed, NotFound, ValidationError from rest_framework.permissions import IsAuthenticated @@ -50,7 +50,7 @@ class TeamAndOrgViewSetMixin(_GenericViewSet): # Rewrite filter queries, so that for example foreign keys can be accessed # Example: {"team_id": "foo__team_id"} will make the viewset filtered by obj.foo.team_id instead of obj.team_id - filter_rewrite_rules: Dict[str, str] = {} + filter_rewrite_rules: dict[str, str] = {} authentication_classes = [] permission_classes = [] @@ -170,7 +170,7 @@ class TeamAndOrgViewSetMixin(_GenericViewSet): return queryset @cached_property - def parents_query_dict(self) -> Dict[str, Any]: + def parents_query_dict(self) -> dict[str, Any]: # used to override the last visited project if there's a token in the request team_from_request = self._get_team_from_request() @@ -213,7 +213,7 @@ class TeamAndOrgViewSetMixin(_GenericViewSet): result[query_lookup] = query_value return result - def get_serializer_context(self) -> Dict[str, Any]: + def get_serializer_context(self) -> dict[str, Any]: serializer_context = super().get_serializer_context() if hasattr(super(), "get_serializer_context") else {} serializer_context.update(self.parents_query_dict) # The below are lambdas for lazy evaluation (i.e. 
we only query Postgres for team/org if actually needed) diff --git a/posthog/api/scheduled_change.py b/posthog/api/scheduled_change.py index 5d1878ebfe4..2100f6b7bdc 100644 --- a/posthog/api/scheduled_change.py +++ b/posthog/api/scheduled_change.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from rest_framework import ( serializers, viewsets, @@ -29,7 +29,7 @@ class ScheduledChangeSerializer(serializers.ModelSerializer): ] read_only_fields = ["id", "created_at", "created_by", "updated_at"] - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> ScheduledChange: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> ScheduledChange: request = self.context["request"] validated_data["created_by"] = request.user validated_data["team_id"] = self.context["team_id"] diff --git a/posthog/api/sharing.py b/posthog/api/sharing.py index c7ab40fb0f8..3d4a2d69374 100644 --- a/posthog/api/sharing.py +++ b/posthog/api/sharing.py @@ -1,6 +1,6 @@ import json from datetime import timedelta -from typing import Any, Dict, Optional, cast +from typing import Any, Optional, cast from urllib.parse import urlparse, urlunparse from django.core.serializers.json import DjangoJSONEncoder @@ -87,7 +87,7 @@ class SharingConfigurationViewSet(TeamAndOrgViewSetMixin, mixins.ListModelMixin, def get_serializer_context( self, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: context = super().get_serializer_context() dashboard_id = context.get("dashboard_id") @@ -113,7 +113,7 @@ class SharingConfigurationViewSet(TeamAndOrgViewSetMixin, mixins.ListModelMixin, return context - def _get_sharing_configuration(self, context: Dict[str, Any]): + def _get_sharing_configuration(self, context: dict[str, Any]): """ Gets but does not create a SharingConfiguration. 
Only once enabled do we actually store it """ @@ -247,7 +247,7 @@ class SharingViewerPageViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSe "user_permissions": UserPermissions(cast(User, request.user), resource.team), "is_shared": True, } - exported_data: Dict[str, Any] = {"type": "embed" if embedded else "scene"} + exported_data: dict[str, Any] = {"type": "embed" if embedded else "scene"} if isinstance(resource, SharingConfiguration) and request.path.endswith(f".png"): exported_data["accessToken"] = resource.access_token diff --git a/posthog/api/signup.py b/posthog/api/signup.py index c31f37b891e..8385dc77597 100644 --- a/posthog/api/signup.py +++ b/posthog/api/signup.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Optional, Union, cast from urllib.parse import urlencode import structlog @@ -71,7 +71,7 @@ class SignupSerializer(serializers.Serializer): super().__init__(*args, **kwargs) self.is_social_signup = False - def get_fields(self) -> Dict[str, serializers.Field]: + def get_fields(self) -> dict[str, serializers.Field]: fields = super().get_fields() if settings.DEMO: # There's no password in the demo env @@ -156,7 +156,7 @@ class SignupSerializer(serializers.Serializer): def create_team(self, organization: Organization, user: User) -> Team: return Team.objects.create_with_data(user=user, organization=organization) - def to_representation(self, instance) -> Dict: + def to_representation(self, instance) -> dict: data = UserBasicSerializer(instance=instance).data data["redirect_url"] = get_redirect_url(data["uuid"], data["is_email_verified"]) return data @@ -185,7 +185,7 @@ class InviteSignupSerializer(serializers.Serializer): data["redirect_url"] = get_redirect_url(data["uuid"], data["is_email_verified"]) return data - def validate(self, data: Dict[str, Any]) -> Dict[str, Any]: + def validate(self, data: dict[str, Any]) -> dict[str, Any]: if "request" not in self.context or not self.context["request"].user.is_authenticated: # If there's no authenticated user and we're creating a new one, attributes are required. 
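The next hunk applies PEP 604 union syntax inside isinstance, alongside the PEP 585 builtin-generic rewrites seen throughout these files. A minimal standalone sketch of both, assuming Python 3.10+ (which the X | Y isinstance form requires); first_email is a hypothetical helper, not code from this patch:

from typing import Optional  # Optional/Union imports survive this rewrite

def first_email(details: dict[str, object]) -> Optional[str]:
    # PEP 585 (Python 3.9+): builtin dict is subscriptable, replacing typing.Dict.
    value = details.get("email")
    # PEP 604 (Python 3.10+): `list | tuple` replaces the `(list, tuple)` tuple form.
    if isinstance(value, list | tuple):
        return str(value[0]) if value else None
    return value if isinstance(value, str) else None

assert first_email({"email": ["a@example.com"]}) == "a@example.com"
assert first_email({"email": "b@example.com"}) == "b@example.com"

Both isinstance spellings accept exactly the same values; the union form is simply the one pyupgrade normalizes to on 3.10+.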
@@ -469,7 +469,7 @@ def social_create_user( return {"is_new": False} backend_processor = "social_create_user" - email = details["email"][0] if isinstance(details["email"], (list, tuple)) else details["email"] + email = details["email"][0] if isinstance(details["email"], list | tuple) else details["email"] full_name = ( details.get("fullname") or f"{details.get('first_name') or ''} {details.get('last_name') or ''}".strip() diff --git a/posthog/api/survey.py b/posthog/api/survey.py index cb991a5f95a..3ffce982b89 100644 --- a/posthog/api/survey.py +++ b/posthog/api/survey.py @@ -1,5 +1,4 @@ from contextlib import contextmanager -from typing import Type from django.http import JsonResponse @@ -271,7 +270,7 @@ class SurveyViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): scope_object = "survey" queryset = Survey.objects.select_related("linked_flag", "targeting_flag").all() - def get_serializer_class(self) -> Type[serializers.Serializer]: + def get_serializer_class(self) -> type[serializers.Serializer]: if self.request.method == "POST" or self.request.method == "PATCH": return SurveySerializerCreateUpdateOnly else: diff --git a/posthog/api/tagged_item.py b/posthog/api/tagged_item.py index 85aa08323a0..d73275523b6 100644 --- a/posthog/api/tagged_item.py +++ b/posthog/api/tagged_item.py @@ -50,7 +50,7 @@ class TaggedItemSerializerMixin(serializers.Serializer): obj.prefetched_tags = tagged_item_objects def to_representation(self, obj): - ret = super(TaggedItemSerializerMixin, self).to_representation(obj) + ret = super().to_representation(obj) ret["tags"] = [] if self._is_licensed(): if hasattr(obj, "prefetched_tags"): @@ -61,12 +61,12 @@ class TaggedItemSerializerMixin(serializers.Serializer): def create(self, validated_data): validated_data.pop("tags", None) - instance = super(TaggedItemSerializerMixin, self).create(validated_data) + instance = super().create(validated_data) self._attempt_set_tags(self.initial_data.get("tags"), instance) return instance def update(self, instance, validated_data): - instance = super(TaggedItemSerializerMixin, self).update(instance, validated_data) + instance = super().update(instance, validated_data) self._attempt_set_tags(self.initial_data.get("tags"), instance) return instance @@ -96,7 +96,7 @@ class TaggedItemViewSetMixin(viewsets.GenericViewSet): return queryset def get_queryset(self): - queryset = super(TaggedItemViewSetMixin, self).get_queryset() + queryset = super().get_queryset() return self.prefetch_tagged_items_if_available(queryset) diff --git a/posthog/api/team.py b/posthog/api/team.py index c8b2513b679..39acc8c2a0a 100644 --- a/posthog/api/team.py +++ b/posthog/api/team.py @@ -1,6 +1,6 @@ import json from functools import cached_property -from typing import Any, Dict, List, Optional, Type, cast +from typing import Any, Optional, cast from django.core.cache import cache from django.shortcuts import get_object_or_404 @@ -190,11 +190,11 @@ class TeamSerializer(serializers.ModelSerializer, UserPermissionsSerializerMixin def get_groups_on_events_querying_enabled(self, team: Team) -> bool: return groups_on_events_querying_enabled() - def validate_session_recording_linked_flag(self, value) -> Dict | None: + def validate_session_recording_linked_flag(self, value) -> dict | None: if value is None: return None - if not isinstance(value, Dict): + if not isinstance(value, dict): raise exceptions.ValidationError("Must provide a dictionary or None.") received_keys = value.keys() valid_keys = [ @@ -208,11 +208,11 @@ class 
TeamSerializer(serializers.ModelSerializer, UserPermissionsSerializerMixin return value - def validate_session_recording_network_payload_capture_config(self, value) -> Dict | None: + def validate_session_recording_network_payload_capture_config(self, value) -> dict | None: if value is None: return None - if not isinstance(value, Dict): + if not isinstance(value, dict): raise exceptions.ValidationError("Must provide a dictionary or None.") if not all(key in ["recordHeaders", "recordBody"] for key in value.keys()): @@ -222,11 +222,11 @@ class TeamSerializer(serializers.ModelSerializer, UserPermissionsSerializerMixin return value - def validate_session_replay_config(self, value) -> Dict | None: + def validate_session_replay_config(self, value) -> dict | None: if value is None: return None - if not isinstance(value, Dict): + if not isinstance(value, dict): raise exceptions.ValidationError("Must provide a dictionary or None.") known_keys = ["record_canvas", "ai_config"] @@ -240,9 +240,9 @@ class TeamSerializer(serializers.ModelSerializer, UserPermissionsSerializerMixin return value - def validate_session_replay_ai_summary_config(self, value: Dict | None) -> Dict | None: + def validate_session_replay_ai_summary_config(self, value: dict | None) -> dict | None: if value is not None: - if not isinstance(value, Dict): + if not isinstance(value, dict): raise exceptions.ValidationError("Must provide a dictionary or None.") allowed_keys = [ @@ -294,7 +294,7 @@ class TeamSerializer(serializers.ModelSerializer, UserPermissionsSerializerMixin ) return super().validate(attrs) - def create(self, validated_data: Dict[str, Any], **kwargs) -> Team: + def create(self, validated_data: dict[str, Any], **kwargs) -> Team: serializers.raise_errors_on_nested_writes("create", self, validated_data) request = self.context["request"] organization = self.context["view"].organization # Use the org we used to validate permissions @@ -337,7 +337,7 @@ class TeamSerializer(serializers.ModelSerializer, UserPermissionsSerializerMixin hashes = InsightCachingState.objects.filter(team=team).values_list("cache_key", flat=True) cache.delete_many(hashes) - def update(self, instance: Team, validated_data: Dict[str, Any]) -> Team: + def update(self, instance: Team, validated_data: dict[str, Any]) -> Team: before_update = instance.__dict__.copy() if "timezone" in validated_data and validated_data["timezone"] != instance.timezone: @@ -406,13 +406,13 @@ class TeamViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): visible_teams_ids = UserPermissions(cast(User, self.request.user)).team_ids_visible_for_user return super().get_queryset().filter(id__in=visible_teams_ids) - def get_serializer_class(self) -> Type[serializers.BaseSerializer]: + def get_serializer_class(self) -> type[serializers.BaseSerializer]: if self.action == "list": return TeamBasicSerializer return super().get_serializer_class() # NOTE: Team permissions are somewhat complex so we override the underlying viewset's get_permissions method - def get_permissions(self) -> List: + def get_permissions(self) -> list: """ Special permissions handling for create requests as the organization is inferred from the current user. 
""" diff --git a/posthog/api/test/dashboards/__init__.py b/posthog/api/test/dashboards/__init__.py index 79d1e435e64..ad6505b5a61 100644 --- a/posthog/api/test/dashboards/__init__.py +++ b/posthog/api/test/dashboards/__init__.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Literal, Optional from rest_framework import status @@ -15,7 +15,7 @@ class DashboardAPI: self, model_id: int, model_type: Literal["insights", "dashboards"], - extra_data: Optional[Dict] = None, + extra_data: Optional[dict] = None, expected_get_status: int = status.HTTP_404_NOT_FOUND, ) -> None: if extra_data is None: @@ -33,10 +33,10 @@ class DashboardAPI: def create_dashboard( self, - data: Dict[str, Any], + data: dict[str, Any], team_id: Optional[int] = None, expected_status: int = status.HTTP_201_CREATED, - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: if team_id is None: team_id = self.team.id response = self.client.post(f"/api/projects/{team_id}/dashboards/", data) @@ -49,10 +49,10 @@ class DashboardAPI: def update_dashboard( self, dashboard_id: int, - data: Dict[str, Any], + data: dict[str, Any], team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: if team_id is None: team_id = self.team.id response = self.client.patch(f"/api/projects/{team_id}/dashboards/{dashboard_id}", data) @@ -67,8 +67,8 @@ class DashboardAPI: dashboard_id: int, team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - query_params: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: + query_params: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: if team_id is None: team_id = self.team.id @@ -82,8 +82,8 @@ class DashboardAPI: self, team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - query_params: Optional[Dict] = None, - ) -> Dict: + query_params: Optional[dict] = None, + ) -> dict: if team_id is None: team_id = self.team.id @@ -100,8 +100,8 @@ class DashboardAPI: self, team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - query_params: Optional[Dict] = None, - ) -> Dict: + query_params: Optional[dict] = None, + ) -> dict: if team_id is None: team_id = self.team.id @@ -122,8 +122,8 @@ class DashboardAPI: insight_id: int, team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - query_params: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: + query_params: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: if team_id is None: team_id = self.team.id @@ -138,10 +138,10 @@ class DashboardAPI: def create_insight( self, - data: Dict[str, Any], + data: dict[str, Any], team_id: Optional[int] = None, expected_status: int = status.HTTP_201_CREATED, - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: if team_id is None: team_id = self.team.id @@ -160,10 +160,10 @@ class DashboardAPI: def update_insight( self, insight_id: int, - data: Dict[str, Any], + data: dict[str, Any], team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: if team_id is None: team_id = self.team.id @@ -177,10 +177,10 @@ class DashboardAPI: self, dashboard_id: int, text: str = "I AM TEXT!", - extra_data: Optional[Dict] = None, + extra_data: Optional[dict] = None, team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: if 
team_id is None: team_id = self.team.id @@ -218,10 +218,10 @@ class DashboardAPI: def update_text_tile( self, dashboard_id: int, - tile: Dict, + tile: dict, team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: if team_id is None: team_id = self.team.id @@ -271,7 +271,7 @@ class DashboardAPI: def add_insight_to_dashboard( self, - dashboard_ids: List[int], + dashboard_ids: list[int], insight_id: int, expected_status: int = status.HTTP_200_OK, ): diff --git a/posthog/api/test/dashboards/test_dashboard.py b/posthog/api/test/dashboards/test_dashboard.py index 234123dde16..7e8f3fafc87 100644 --- a/posthog/api/test/dashboards/test_dashboard.py +++ b/posthog/api/test/dashboards/test_dashboard.py @@ -1,5 +1,4 @@ import json -from typing import Dict from unittest import mock from unittest.mock import ANY, MagicMock, patch @@ -21,7 +20,7 @@ from posthog.models.signals import mute_selected_signals from posthog.test.base import APIBaseTest, QueryMatchingTest, snapshot_postgres_queries, FuzzyInt from posthog.utils import generate_cache_key -valid_template: Dict = { +valid_template: dict = { "template_name": "Sign up conversion template with variables", "dashboard_description": "Use this template to see how many users sign up after visiting your pricing page.", "dashboard_filters": {}, @@ -1186,7 +1185,7 @@ class TestDashboard(APIBaseTest, QueryMatchingTest): ) def test_create_from_template_json_must_provide_at_least_one_tile(self) -> None: - template: Dict = {**valid_template, "tiles": []} + template: dict = {**valid_template, "tiles": []} response = self.client.post( f"/api/projects/{self.team.id}/dashboards/create_from_template_json", @@ -1195,7 +1194,7 @@ class TestDashboard(APIBaseTest, QueryMatchingTest): assert response.status_code == 400, response.json() def test_create_from_template_json_can_provide_text_tile(self) -> None: - template: Dict = { + template: dict = { **valid_template, "tiles": [{"type": "TEXT", "body": "hello world", "layouts": {}}], } @@ -1226,7 +1225,7 @@ class TestDashboard(APIBaseTest, QueryMatchingTest): ] def test_create_from_template_json_can_provide_query_tile(self) -> None: - template: Dict = { + template: dict = { **valid_template, # client provides an incorrect "empty" filter alongside a query "tiles": [ diff --git a/posthog/api/test/dashboards/test_dashboard_duplication.py b/posthog/api/test/dashboards/test_dashboard_duplication.py index dbfa572e9c0..f477f9f1e05 100644 --- a/posthog/api/test/dashboards/test_dashboard_duplication.py +++ b/posthog/api/test/dashboards/test_dashboard_duplication.py @@ -1,5 +1,3 @@ -from typing import Dict, List - from posthog.api.test.dashboards import DashboardAPI from posthog.test.base import APIBaseTest, QueryMatchingTest @@ -85,7 +83,7 @@ class TestDashboardDuplication(APIBaseTest, QueryMatchingTest): ] @staticmethod - def _tile_child_ids_from(dashboard_json: Dict) -> List[int]: + def _tile_child_ids_from(dashboard_json: dict) -> list[int]: return [ (tile.get("insight", None) or {}).get("id", None) or (tile.get("text", None) or {}).get("id", None) for tile in dashboard_json["tiles"] diff --git a/posthog/api/test/dashboards/test_dashboard_text_tiles.py b/posthog/api/test/dashboards/test_dashboard_text_tiles.py index 34b9366da5a..d3f899d72d2 100644 --- a/posthog/api/test/dashboards/test_dashboard_text_tiles.py +++ b/posthog/api/test/dashboards/test_dashboard_text_tiles.py @@ -1,5 +1,5 @@ import datetime -from typing import Dict, Optional, Union 
+from typing import Optional, Union from unittest import mock from freezegun import freeze_time @@ -16,7 +16,7 @@ class TestDashboardTiles(APIBaseTest, QueryMatchingTest): self.dashboard_api = DashboardAPI(self.client, self.team, self.assertEqual) @staticmethod - def _serialised_user(user: Optional[User]) -> Optional[Dict[str, Optional[Union[int, str]]]]: + def _serialised_user(user: Optional[User]) -> Optional[dict[str, Optional[Union[int, str]]]]: if user is None: return None @@ -37,7 +37,7 @@ class TestDashboardTiles(APIBaseTest, QueryMatchingTest): last_modified_by: Optional[User] = None, text_id: Optional[int] = None, last_modified_at: str = "2022-04-01T12:45:00Z", - ) -> Dict: + ) -> dict: if not created_by: created_by = self.user @@ -62,7 +62,7 @@ class TestDashboardTiles(APIBaseTest, QueryMatchingTest): text_id: Optional[int] = None, color: Optional[str] = None, last_modified_at: str = "2022-04-01T12:45:00Z", - ) -> Dict: + ) -> dict: if not tile_id: tile_id = mock.ANY return { @@ -82,7 +82,7 @@ class TestDashboardTiles(APIBaseTest, QueryMatchingTest): } @staticmethod - def _tile_layout(lg: Optional[Dict] = None) -> Dict: + def _tile_layout(lg: Optional[dict] = None) -> dict: if lg is None: lg = {"x": "0", "y": "0", "w": "6", "h": "5"} diff --git a/posthog/api/test/notebooks/test_notebook.py b/posthog/api/test/notebooks/test_notebook.py index 2779f1a226c..f01d8fd6bc6 100644 --- a/posthog/api/test/notebooks/test_notebook.py +++ b/posthog/api/test/notebooks/test_notebook.py @@ -1,4 +1,3 @@ -from typing import List, Dict from unittest import mock from freezegun import freeze_time @@ -11,7 +10,7 @@ from posthog.test.base import APIBaseTest, QueryMatchingTest, snapshot_postgres_ class TestNotebooks(APIBaseTest, QueryMatchingTest): - def created_activity(self, item_id: str, short_id: str) -> Dict: + def created_activity(self, item_id: str, short_id: str) -> dict: return { "activity": "created", "created_at": mock.ANY, @@ -30,11 +29,11 @@ class TestNotebooks(APIBaseTest, QueryMatchingTest): }, } - def assert_notebook_activity(self, expected: List[Dict]) -> None: + def assert_notebook_activity(self, expected: list[dict]) -> None: activity_response = self.client.get(f"/api/projects/{self.team.id}/notebooks/activity") assert activity_response.status_code == status.HTTP_200_OK - activity: List[Dict] = activity_response.json()["results"] + activity: list[dict] = activity_response.json()["results"] self.maxDiff = None assert activity == expected @@ -78,7 +77,7 @@ class TestNotebooks(APIBaseTest, QueryMatchingTest): ), ] ) - def test_create_a_notebook(self, _, content: Dict | None, text_content: str | None) -> None: + def test_create_a_notebook(self, _, content: dict | None, text_content: str | None) -> None: response = self.client.post( f"/api/projects/{self.team.id}/notebooks", data={"content": content, "text_content": text_content}, diff --git a/posthog/api/test/notebooks/test_notebook_filtering.py b/posthog/api/test/notebooks/test_notebook_filtering.py index bbe191892d8..06b543deca4 100644 --- a/posthog/api/test/notebooks/test_notebook_filtering.py +++ b/posthog/api/test/notebooks/test_notebook_filtering.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, List +from typing import Any from parameterized import parameterized from rest_framework import status @@ -59,7 +59,7 @@ BASIC_TEXT = lambda text: { class TestNotebooksFiltering(APIBaseTest, QueryMatchingTest): - def _create_notebook_with_content(self, inner_content: List[Dict[str, Any]], title: str = "the title") -> str: + def 
_create_notebook_with_content(self, inner_content: list[dict[str, Any]], title: str = "the title") -> str: response = self.client.post( f"/api/projects/{self.team.id}/notebooks", data={ @@ -83,7 +83,7 @@ class TestNotebooksFiltering(APIBaseTest, QueryMatchingTest): ["random", []], ] ) - def test_filters_based_on_title(self, search_text: str, expected_match_indexes: List[int]) -> None: + def test_filters_based_on_title(self, search_text: str, expected_match_indexes: list[int]) -> None: notebook_ids = [ self._create_notebook_with_content([BASIC_TEXT("my important notes")], title="i ride around on a pony"), self._create_notebook_with_content([BASIC_TEXT("my important notes")], title="my hobby is to fish around"), @@ -108,7 +108,7 @@ class TestNotebooksFiltering(APIBaseTest, QueryMatchingTest): ["neither", []], ] ) - def test_filters_based_on_text_content(self, search_text: str, expected_match_indexes: List[int]) -> None: + def test_filters_based_on_text_content(self, search_text: str, expected_match_indexes: list[int]) -> None: notebook_ids = [ # will match both pony and ponies self._create_notebook_with_content([BASIC_TEXT("you may ride a pony")], title="never matches"), diff --git a/posthog/api/test/openapi_validation.py b/posthog/api/test/openapi_validation.py index e86bf5198bb..20d2fb1e1a6 100644 --- a/posthog/api/test/openapi_validation.py +++ b/posthog/api/test/openapi_validation.py @@ -1,7 +1,7 @@ import gzip import json from io import BytesIO -from typing import Any, Dict, Optional, cast +from typing import Any, Optional, cast from urllib.parse import parse_qs import lzstring @@ -11,7 +11,7 @@ from django.test.client import FakePayload from jsonschema import validate -def validate_response(openapi_spec: Dict[str, Any], response: Any, path_override: Optional[str] = None): +def validate_response(openapi_spec: dict[str, Any], response: Any, path_override: Optional[str] = None): # Validates are response against the OpenAPI spec. If `path_override` is # provided, the path in the response will be overridden with the provided # value. This is useful for validating responses from e.g. 
the /batch diff --git a/posthog/api/test/test_activity_log.py b/posthog/api/test/test_activity_log.py index a7573f10cab..c386d30de6c 100644 --- a/posthog/api/test/test_activity_log.py +++ b/posthog/api/test/test_activity_log.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from freezegun import freeze_time from freezegun.api import FrozenDateTimeFactory, StepTickTimeFactory @@ -9,7 +9,7 @@ from posthog.models import User from posthog.test.base import APIBaseTest, QueryMatchingTest -def _feature_flag_json_payload(key: str) -> Dict: +def _feature_flag_json_payload(key: str) -> dict: return { "key": key, "name": "", @@ -103,7 +103,7 @@ class TestActivityLog(APIBaseTest, QueryMatchingTest): def _edit_them_all( self, - created_insights: List[int], + created_insights: list[int], flag_one: str, flag_two: str, notebook_short_id: str, @@ -269,10 +269,10 @@ class TestActivityLog(APIBaseTest, QueryMatchingTest): def _create_insight( self, - data: Dict[str, Any], + data: dict[str, Any], team_id: Optional[int] = None, expected_status: int = status.HTTP_201_CREATED, - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: if team_id is None: team_id = self.team.id diff --git a/posthog/api/test/test_capture.py b/posthog/api/test/test_capture.py index f771aca99b3..1beb4e9724b 100644 --- a/posthog/api/test/test_capture.py +++ b/posthog/api/test/test_capture.py @@ -25,7 +25,7 @@ from parameterized import parameterized from prance import ResolvingParser from rest_framework import status from token_bucket import Limiter, MemoryStorage -from typing import Any, Dict, List, Union, cast +from typing import Any, Union, cast from unittest.mock import ANY, MagicMock, call, patch from urllib.parse import quote @@ -60,7 +60,7 @@ parser = ResolvingParser( url=str(pathlib.Path(__file__).parent / "../../../openapi/capture.yaml"), strict=True, ) -openapi_spec = cast(Dict[str, Any], parser.specification) +openapi_spec = cast(dict[str, Any], parser.specification) large_data_array = [ {"key": "".join(random.choice(string.ascii_letters) for _ in range(512 * 1024))} @@ -162,7 +162,7 @@ class TestCapture(BaseTest): # it is really important to know that /capture is CSRF exempt. 
Enforce checking in the client self.client = Client(enforce_csrf_checks=True) - def _to_json(self, data: Union[Dict, List]) -> str: + def _to_json(self, data: Union[dict, list]) -> str: return json.dumps(data) def _dict_to_b64(self, data: dict) -> str: @@ -188,7 +188,7 @@ class TestCapture(BaseTest): def _send_original_version_session_recording_event( self, number_of_events: int = 1, - event_data: Dict | None = None, + event_data: dict | None = None, snapshot_source=3, snapshot_type=1, session_id="abc123", @@ -229,7 +229,7 @@ class TestCapture(BaseTest): def _send_august_2023_version_session_recording_event( self, number_of_events: int = 1, - event_data: Dict | List[Dict] | None = None, + event_data: dict | list[dict] | None = None, session_id="abc123", window_id="def456", distinct_id="ghi789", @@ -241,7 +241,7 @@ class TestCapture(BaseTest): # event_data is an array of RRWeb events event_data = [{"type": 3, "data": {"source": 1}}, {"type": 3, "data": {"source": 2}}] - if isinstance(event_data, Dict): + if isinstance(event_data, dict): event_data = [event_data] event = { @@ -260,7 +260,7 @@ class TestCapture(BaseTest): "distinct_id": distinct_id, } - post_data: List[Dict[str, Any]] | Dict[str, Any] + post_data: list[dict[str, Any]] | dict[str, Any] if content_type == "application/json": post_data = [{**event, "api_key": self.team.api_token} for _ in range(number_of_events)] @@ -1254,7 +1254,7 @@ class TestCapture(BaseTest): } self.client.get( - "/e/?_=%s&data=%s" % (int(tomorrow_sent_at.timestamp()), quote(self._to_json(data))), + "/e/?_={}&data={}".format(int(tomorrow_sent_at.timestamp()), quote(self._to_json(data))), content_type="application/json", HTTP_ORIGIN="https://localhost", ) @@ -1283,7 +1283,7 @@ class TestCapture(BaseTest): } self.client.get( - "/e/?_=%s&data=%s" % (int(tomorrow_sent_at.timestamp()), quote(self._to_json(data))), + "/e/?_={}&data={}".format(int(tomorrow_sent_at.timestamp()), quote(self._to_json(data))), content_type="application/json", HTTP_ORIGIN="https://localhost", ) @@ -1526,7 +1526,7 @@ class TestCapture(BaseTest): ), ] ) - def test_cors_allows_tracing_headers(self, _: str, path: str, headers: List[str]) -> None: + def test_cors_allows_tracing_headers(self, _: str, path: str, headers: list[str]) -> None: expected_headers = ",".join(["X-Requested-With", "Content-Type", *headers]) presented_headers = ",".join([*headers, "someotherrandomheader"]) response = self.client.options( diff --git a/posthog/api/test/test_cohort.py b/posthog/api/test/test_cohort.py index 0b1971f8f2c..4e1a3da2d52 100644 --- a/posthog/api/test/test_cohort.py +++ b/posthog/api/test/test_cohort.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timedelta -from typing import Any, Dict, List +from typing import Any from unittest.mock import patch from django.core.files.uploadedfile import SimpleUploadedFile @@ -1493,11 +1493,11 @@ email@example.org, self.assertEqual(async_deletion.delete_verified_at is not None, True) -def create_cohort(client: Client, team_id: int, name: str, groups: List[Dict[str, Any]]): +def create_cohort(client: Client, team_id: int, name: str, groups: list[dict[str, Any]]): return client.post(f"/api/projects/{team_id}/cohorts", {"name": name, "groups": json.dumps(groups)}) -def create_cohort_ok(client: Client, team_id: int, name: str, groups: List[Dict[str, Any]]): +def create_cohort_ok(client: Client, team_id: int, name: str, groups: list[dict[str, Any]]): response = create_cohort(client=client, team_id=team_id, name=name, groups=groups) assert 
response.status_code == 201, response.content return response.json() diff --git a/posthog/api/test/test_decide.py b/posthog/api/test/test_decide.py index e89fb0b3c12..af9b2db88e3 100644 --- a/posthog/api/test/test_decide.py +++ b/posthog/api/test/test_decide.py @@ -3457,9 +3457,11 @@ class TestDatabaseCheckForDecide(BaseTest, QueryMatchingTest): # remove database check cache values postgres_healthcheck.cache_clear() - with connection.execute_wrapper(QueryTimeoutWrapper()), snapshot_postgres_queries_context( - self - ), self.assertNumQueries(1): + with ( + connection.execute_wrapper(QueryTimeoutWrapper()), + snapshot_postgres_queries_context(self), + self.assertNumQueries(1), + ): response = self._post_decide(api_version=3, origin="https://random.example.com").json() response = self._post_decide(api_version=3, origin="https://random.example.com").json() response = self._post_decide(api_version=3, origin="https://random.example.com").json() @@ -3607,8 +3609,10 @@ class TestDecideUsesReadReplica(TransactionTestCase): self.organization, self.team, self.user = org, team, user # this create fills up team cache^ - with freeze_time("2021-01-01T00:00:00Z"), self.assertNumQueries(1, using="replica"), self.assertNumQueries( - 1, using="default" + with ( + freeze_time("2021-01-01T00:00:00Z"), + self.assertNumQueries(1, using="replica"), + self.assertNumQueries(1, using="default"), ): response = self._post_decide() # Replica queries: @@ -4031,9 +4035,11 @@ class TestDecideUsesReadReplica(TransactionTestCase): # now main database is down, but does not affect replica - with connections["default"].execute_wrapper(QueryTimeoutWrapper()), self.assertNumQueries( - 13, using="replica" - ), self.assertNumQueries(0, using="default"): + with ( + connections["default"].execute_wrapper(QueryTimeoutWrapper()), + self.assertNumQueries(13, using="replica"), + self.assertNumQueries(0, using="default"), + ): # Replica queries: # E 1. SET LOCAL statement_timeout = 300 # E 2. 
WITH some CTEs, diff --git a/posthog/api/test/test_element.py b/posthog/api/test/test_element.py index 72a97ea2b9b..25cd01df353 100644 --- a/posthog/api/test/test_element.py +++ b/posthog/api/test/test_element.py @@ -1,6 +1,5 @@ import json from datetime import timedelta -from typing import Dict, List from django.test import override_settings from freezegun import freeze_time @@ -17,7 +16,7 @@ from posthog.test.base import ( snapshot_postgres_queries, ) -expected_autocapture_data_response_results: List[Dict] = [ +expected_autocapture_data_response_results: list[dict] = [ { "count": 3, "hash": None, @@ -78,7 +77,7 @@ expected_autocapture_data_response_results: List[Dict] = [ }, ] -expected_rage_click_data_response_results: List[Dict] = [ +expected_rage_click_data_response_results: list[dict] = [ { "count": 1, "hash": None, diff --git a/posthog/api/test/test_event_definition.py b/posthog/api/test/test_event_definition.py index aa2a2c05a24..c530708886b 100644 --- a/posthog/api/test/test_event_definition.py +++ b/posthog/api/test/test_event_definition.py @@ -1,6 +1,6 @@ import dataclasses from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Optional from unittest.mock import ANY, patch from uuid import uuid4 @@ -20,7 +20,7 @@ from posthog.test.base import APIBaseTest class TestEventDefinitionAPI(APIBaseTest): demo_team: Team = None # type: ignore - EXPECTED_EVENT_DEFINITIONS: List[Dict[str, Any]] = [ + EXPECTED_EVENT_DEFINITIONS: list[dict[str, Any]] = [ {"name": "installed_app"}, {"name": "rated_app"}, {"name": "purchase"}, @@ -54,7 +54,7 @@ class TestEventDefinitionAPI(APIBaseTest): self.assertEqual(len(response.json()["results"]), len(self.EXPECTED_EVENT_DEFINITIONS)) for item in self.EXPECTED_EVENT_DEFINITIONS: - response_item: Dict[str, Any] = next( + response_item: dict[str, Any] = next( (_i for _i in response.json()["results"] if _i["name"] == item["name"]), {}, ) @@ -199,7 +199,7 @@ class EventData: team_id: int distinct_id: str timestamp: datetime - properties: Dict[str, Any] + properties: dict[str, Any] def capture_event(event: EventData): @@ -222,7 +222,7 @@ def capture_event(event: EventData): ) -def create_event_definitions(event_definition: Dict, team_id: int) -> EventDefinition: +def create_event_definitions(event_definition: dict, team_id: int) -> EventDefinition: """ Create event definition for a team. 
""" diff --git a/posthog/api/test/test_exports.py b/posthog/api/test/test_exports.py index eead54e9055..5e804866206 100644 --- a/posthog/api/test/test_exports.py +++ b/posthog/api/test/test_exports.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Optional from unittest.mock import patch from datetime import datetime, timedelta import celery @@ -435,10 +435,10 @@ class TestExports(APIBaseTest): self.assertEqual(activity.status_code, expected_status) return activity.json() - def _assert_logs_the_activity(self, insight_id: int, expected: List[Dict]) -> None: + def _assert_logs_the_activity(self, insight_id: int, expected: list[dict]) -> None: activity_response = self._get_insight_activity(insight_id) - activity: List[Dict] = activity_response["results"] + activity: list[dict] = activity_response["results"] self.maxDiff = None self.assertEqual(activity, expected) @@ -463,7 +463,7 @@ class TestExports(APIBaseTest): class TestExportMixin(APIBaseTest): - def _get_export_output(self, path: str) -> List[str]: + def _get_export_output(self, path: str) -> list[str]: """ Use this function to test the CSV output of exports in other tests """ diff --git a/posthog/api/test/test_feature_flag.py b/posthog/api/test/test_feature_flag.py index 18236c8332f..4c353b98124 100644 --- a/posthog/api/test/test_feature_flag.py +++ b/posthog/api/test/test_feature_flag.py @@ -1,6 +1,6 @@ import datetime import json -from typing import Dict, List, Optional +from typing import Optional from unittest.mock import call, patch from django.core.cache import cache @@ -3657,10 +3657,10 @@ class TestFeatureFlag(APIBaseTest, ClickhouseTestMixin): self.assertEqual(activity.status_code, expected_status) return activity.json() - def assert_feature_flag_activity(self, flag_id: Optional[int], expected: List[Dict]): + def assert_feature_flag_activity(self, flag_id: Optional[int], expected: list[dict]): activity_response = self._get_feature_flag_activity(flag_id) - activity: List[Dict] = activity_response["results"] + activity: list[dict] = activity_response["results"] self.maxDiff = None assert activity == expected @@ -3898,7 +3898,7 @@ class TestFeatureFlag(APIBaseTest, ClickhouseTestMixin): self.assertEqual(response.status_code, status.HTTP_200_OK) response_json = response.json() - self.assertEquals(len(response_json["analytics_dashboards"]), 1) + self.assertEqual(len(response_json["analytics_dashboards"]), 1) # check deleting the dashboard doesn't delete flag, but deletes the relationship dashboard.delete() @@ -3928,7 +3928,7 @@ class TestFeatureFlag(APIBaseTest, ClickhouseTestMixin): self.assertEqual(response.status_code, status.HTTP_200_OK) response_json = response.json() - self.assertEquals(len(response_json["analytics_dashboards"]), 1) + self.assertEqual(len(response_json["analytics_dashboards"]), 1) def test_feature_flag_dashboard_already_exists(self): another_feature_flag = FeatureFlag.objects.create( @@ -3954,7 +3954,7 @@ class TestFeatureFlag(APIBaseTest, ClickhouseTestMixin): self.assertEqual(response.status_code, status.HTTP_200_OK) response_json = response.json() - self.assertEquals(len(response_json["analytics_dashboards"]), 1) + self.assertEqual(len(response_json["analytics_dashboards"]), 1) @freeze_time("2021-01-01") @snapshot_clickhouse_queries @@ -3988,8 +3988,11 @@ class TestFeatureFlag(APIBaseTest, ClickhouseTestMixin): ) flush_persons_and_events() - with snapshot_postgres_queries_context(self), self.settings( - CELERY_TASK_ALWAYS_EAGER=True, PERSON_ON_EVENTS_OVERRIDE=False, 
PERSON_ON_EVENTS_V2_OVERRIDE=False + with ( + snapshot_postgres_queries_context(self), + self.settings( + CELERY_TASK_ALWAYS_EAGER=True, PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=False + ), ): response = self.client.post( f"/api/projects/{self.team.id}/feature_flags/{flag.id}/create_static_cohort_for_flag", @@ -5328,9 +5331,13 @@ class TestResiliency(TransactionTestCase, QueryMatchingTest): self.assertFalse(errors) # now db is slow and times out - with snapshot_postgres_queries_context(self), connection.execute_wrapper(slow_query), patch( - "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", - 500, + with ( + snapshot_postgres_queries_context(self), + connection.execute_wrapper(slow_query), + patch( + "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", + 500, + ), ): mock_postgres_check.return_value = False all_flags, _, _, errors = get_all_feature_flags(team_id, "example_id") @@ -5423,10 +5430,15 @@ class TestResiliency(TransactionTestCase, QueryMatchingTest): self.assertTrue(errors) # db is slow and times out, but shouldn't matter to us - with self.assertNumQueries(0), connection.execute_wrapper(slow_query), patch( - "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", - 500, - ), self.settings(DECIDE_SKIP_POSTGRES_FLAGS=True): + with ( + self.assertNumQueries(0), + connection.execute_wrapper(slow_query), + patch( + "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", + 500, + ), + self.settings(DECIDE_SKIP_POSTGRES_FLAGS=True), + ): mock_postgres_check.return_value = False all_flags, _, _, errors = get_all_feature_flags(team_id, "example_id") @@ -5536,10 +5548,15 @@ class TestResiliency(TransactionTestCase, QueryMatchingTest): self.assertFalse(errors) # now db is slow and times out - with snapshot_postgres_queries_context(self), connection.execute_wrapper(slow_query), patch( - "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", - 500, - ), self.assertNumQueries(4): + with ( + snapshot_postgres_queries_context(self), + connection.execute_wrapper(slow_query), + patch( + "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", + 500, + ), + self.assertNumQueries(4), + ): # no extra queries to get person properties for the second flag after first one failed all_flags, _, _, errors = get_all_feature_flags(team_id, "example_id") @@ -5627,9 +5644,13 @@ class TestResiliency(TransactionTestCase, QueryMatchingTest): self.assertFalse(errors) # now db is slow - with snapshot_postgres_queries_context(self), connection.execute_wrapper(slow_query), patch( - "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", - 500, + with ( + snapshot_postgres_queries_context(self), + connection.execute_wrapper(slow_query), + patch( + "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", + 500, + ), ): with self.assertNumQueries(4): all_flags, _, _, errors = get_all_feature_flags(team_id, "example_id", groups={"organization": "org:1"}) @@ -5737,9 +5758,13 @@ class TestResiliency(TransactionTestCase, QueryMatchingTest): self.assertFalse(errors) # db is slow and times out - with snapshot_postgres_queries_context(self), connection.execute_wrapper(slow_query), patch( - "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", - 500, + with ( + snapshot_postgres_queries_context(self), + connection.execute_wrapper(slow_query), + patch( + 
"posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", + 500, + ), ): all_flags, _, _, errors = get_all_feature_flags(team_id, "example_id", hash_key_override="random") diff --git a/posthog/api/test/test_feature_flag_utils.py b/posthog/api/test/test_feature_flag_utils.py index 53369794dfe..c13bf04b670 100644 --- a/posthog/api/test/test_feature_flag_utils.py +++ b/posthog/api/test/test_feature_flag_utils.py @@ -1,4 +1,3 @@ -from typing import Set from posthog.models.cohort.cohort import CohortOrEmpty from posthog.test.base import ( APIBaseTest, @@ -68,7 +67,7 @@ class TestFeatureFlagUtils(APIBaseTest): self.assertEqual(topologically_sorted_cohort_ids, destination_creation_order) def test_empty_cohorts_set(self): - cohort_ids: Set[int] = set() + cohort_ids: set[int] = set() seen_cohorts_cache: dict[int, CohortOrEmpty] = {} topologically_sorted_cohort_ids = sort_cohorts_topologically(cohort_ids, seen_cohorts_cache) self.assertEqual(topologically_sorted_cohort_ids, []) diff --git a/posthog/api/test/test_ingestion_warnings.py b/posthog/api/test/test_ingestion_warnings.py index bdf39969559..05e893babfa 100644 --- a/posthog/api/test/test_ingestion_warnings.py +++ b/posthog/api/test/test_ingestion_warnings.py @@ -1,5 +1,4 @@ import json -from typing import Dict from freezegun.api import freeze_time from rest_framework import status @@ -13,7 +12,7 @@ from posthog.test.base import APIBaseTest, ClickhouseTestMixin from posthog.utils import cast_timestamp_or_now -def create_ingestion_warning(team_id: int, type: str, details: Dict, timestamp: str, source=""): +def create_ingestion_warning(team_id: int, type: str, details: dict, timestamp: str, source=""): timestamp = cast_timestamp_or_now(timestamp) data = { "team_id": team_id, diff --git a/posthog/api/test/test_insight.py b/posthog/api/test/test_insight.py index b427ce13a12..2184261eccd 100644 --- a/posthog/api/test/test_insight.py +++ b/posthog/api/test/test_insight.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional +from typing import Any, Optional from unittest import mock from unittest.case import skip from unittest.mock import patch @@ -343,7 +343,7 @@ class TestInsight(ClickhouseTestMixin, APIBaseTest, QueryMatchingTest): @override_settings(PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=False) @snapshot_postgres_queries def test_listing_insights_does_not_nplus1(self) -> None: - query_counts: List[int] = [] + query_counts: list[int] = [] queries = [] for i in range(5): @@ -2059,7 +2059,7 @@ class TestInsight(ClickhouseTestMixin, APIBaseTest, QueryMatchingTest): ) self.assertEqual(response.status_code, status.HTTP_200_OK) - def _create_one_person_cohort(self, properties: List[Dict[str, Any]]) -> int: + def _create_one_person_cohort(self, properties: list[dict[str, Any]]) -> int: Person.objects.create(team=self.team, properties=properties) cohort_one_id = self.client.post( f"/api/projects/{self.team.id}/cohorts", @@ -2426,7 +2426,7 @@ class TestInsight(ClickhouseTestMixin, APIBaseTest, QueryMatchingTest): # assert that undeletes end up in the activity log activity_response = self.dashboard_api.get_insight_activity(insight_id) - activity: List[Dict] = activity_response["results"] + activity: list[dict] = activity_response["results"] # we will have three logged activities (in reverse order) undelete, delete, create assert [a["activity"] for a in activity] == ["updated", "updated", "created"] undelete_change_log = activity[0]["detail"]["changes"][0] 
@@ -2478,10 +2478,10 @@ class TestInsight(ClickhouseTestMixin, APIBaseTest, QueryMatchingTest): query_params = f"?events={json.dumps([{'id': '$pageview', }])}&client_query_id={client_query_id}" self.client.get(f"/api/projects/{self.team.id}/insights/trend/{query_params}").json() - def assert_insight_activity(self, insight_id: Optional[int], expected: List[Dict]): + def assert_insight_activity(self, insight_id: Optional[int], expected: list[dict]): activity_response = self.dashboard_api.get_insight_activity(insight_id) - activity: List[Dict] = activity_response["results"] + activity: list[dict] = activity_response["results"] self.maxDiff = None assert activity == expected diff --git a/posthog/api/test/test_insight_funnels.py b/posthog/api/test/test_insight_funnels.py index b02ebfec558..3b4c1403a58 100644 --- a/posthog/api/test/test_insight_funnels.py +++ b/posthog/api/test/test_insight_funnels.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, List, Union +from typing import Any, Union from django.test.client import Client from rest_framework import status @@ -1004,7 +1004,7 @@ class ClickhouseTestFunnelTypes(ClickhouseTestMixin, APIBaseTest): self.assertEqual(["Chrome", "95"], result[1][1]["breakdown_value"]) @staticmethod - def as_result(breakdown_properties: Union[str, List[str]]) -> Dict[str, Any]: + def as_result(breakdown_properties: Union[str, list[str]]) -> dict[str, Any]: return { "action_id": "$pageview", "name": "$pageview", diff --git a/posthog/api/test/test_insight_query.py b/posthog/api/test/test_insight_query.py index 6279999bbef..19044cd937b 100644 --- a/posthog/api/test/test_insight_query.py +++ b/posthog/api/test/test_insight_query.py @@ -1,5 +1,3 @@ -from typing import List - from rest_framework import status from ee.api.test.base import LicensedTestMixin @@ -213,7 +211,7 @@ class TestInsight(ClickhouseTestMixin, LicensedTestMixin, APIBaseTest, QueryMatc }, ) - created_insights: List[Insight] = list(Insight.objects.all()) + created_insights: list[Insight] = list(Insight.objects.all()) assert len(created_insights) == 2 listed_insights = self.dashboard_api.list_insights(query_params={"include_query_insights": False}) @@ -236,7 +234,7 @@ class TestInsight(ClickhouseTestMixin, LicensedTestMixin, APIBaseTest, QueryMatc }, ) - created_insights: List[Insight] = list(Insight.objects.all()) + created_insights: list[Insight] = list(Insight.objects.all()) assert len(created_insights) == 2 listed_insights = self.dashboard_api.list_insights(query_params={"include_query_insights": True}) diff --git a/posthog/api/test/test_kafka_inspector.py b/posthog/api/test/test_kafka_inspector.py index 6a42741a47f..b9a02d0464e 100644 --- a/posthog/api/test/test_kafka_inspector.py +++ b/posthog/api/test/test_kafka_inspector.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Union +from typing import Union from unittest.mock import patch from rest_framework import status @@ -14,7 +14,7 @@ class TestKafkaInspector(APIBaseTest): self.user.is_staff = True self.user.save() - def _to_json(self, data: Union[Dict, List]) -> str: + def _to_json(self, data: Union[dict, list]) -> str: return json.dumps(data) @patch( diff --git a/posthog/api/test/test_organization_feature_flag.py b/posthog/api/test/test_organization_feature_flag.py index 41960032ca8..f1ad4ba26fb 100644 --- a/posthog/api/test/test_organization_feature_flag.py +++ b/posthog/api/test/test_organization_feature_flag.py @@ -11,7 +11,7 @@ from posthog.models.feedback.survey import Survey from 
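The test_decide.py and test_feature_flag.py hunks above also restructure stacked context managers into the parenthesized with-statement form that Python 3.10 made official, replacing the old call-parenthesis line continuations. A standalone sketch with a hypothetical manager (tag is not from this patch):

from contextlib import contextmanager

@contextmanager
def tag(name: str):
    # Trivial manager used only to make both forms runnable.
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")

# Before: one logical line, readable only while the manager list stays short.
with tag("a"), tag("b"), tag("c"):
    pass

# After (Python 3.10+): the manager list itself is parenthesized, so each
# manager sits on its own line with no backslashes or nested call tricks.
with (
    tag("a"),
    tag("b"),
    tag("c"),
):
    pass

The same test files also swap unittest's assertEquals for assertEqual, a deprecated alias that newer Python releases remove outright.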
posthog.models.early_access_feature import EarlyAccessFeature from posthog.api.dashboards.dashboard import Dashboard from posthog.test.base import APIBaseTest, QueryMatchingTest, snapshot_postgres_queries -from typing import Any, Dict +from typing import Any class TestOrganizationFeatureFlagGet(APIBaseTest, QueryMatchingTest): @@ -382,7 +382,7 @@ class TestOrganizationFeatureFlagCopy(APIBaseTest, QueryMatchingTest): def test_copy_feature_flag_missing_fields(self): url = f"/api/organizations/{self.organization.id}/feature_flags/copy_flags" - data: Dict[str, Any] = {} + data: dict[str, Any] = {} response = self.client.post(url, data) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) diff --git a/posthog/api/test/test_person.py b/posthog/api/test/test_person.py index 815f38c4729..a97e9d25de0 100644 --- a/posthog/api/test/test_person.py +++ b/posthog/api/test/test_person.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Optional, cast +from typing import Optional, cast from unittest import mock from unittest.mock import patch, Mock @@ -982,10 +982,10 @@ class TestPerson(ClickhouseTestMixin, APIBaseTest): self.assertEqual(activity.status_code, expected_status) return activity.json() - def _assert_person_activity(self, person_id: Optional[str], expected: List[Dict]): + def _assert_person_activity(self, person_id: Optional[str], expected: list[dict]): activity_response = self._get_person_activity(person_id) - activity: List[Dict] = activity_response["results"] + activity: list[dict] = activity_response["results"] self.maxDiff = None self.assertCountEqual(activity, expected) diff --git a/posthog/api/test/test_plugin.py b/posthog/api/test/test_plugin.py index 16e0fc4c0d1..06642b46098 100644 --- a/posthog/api/test/test_plugin.py +++ b/posthog/api/test/test_plugin.py @@ -1,7 +1,7 @@ import base64 import json from datetime import datetime -from typing import Dict, List, cast +from typing import cast from unittest import mock from unittest.mock import ANY, patch @@ -52,10 +52,10 @@ class TestPluginAPI(APIBaseTest, QueryMatchingTest): self.assertEqual(activity.status_code, expected_status) return activity.json() - def assert_plugin_activity(self, expected: List[Dict]): + def assert_plugin_activity(self, expected: list[dict]): activity_response = self._get_plugin_activity() - activity: List[Dict] = activity_response["results"] + activity: list[dict] = activity_response["results"] self.maxDiff = None self.assertEqual(activity, expected) @@ -586,7 +586,7 @@ class TestPluginAPI(APIBaseTest, QueryMatchingTest): ) self.assertEqual(response.status_code, 400) self.assertEqual( - cast(Dict[str, str], response.json())["detail"], + cast(dict[str, str], response.json())["detail"], f'Currently running PostHog version {FROZEN_POSTHOG_VERSION} does not match this plugin\'s semantic version requirement "{FROZEN_POSTHOG_VERSION.next_minor()}".', ) @@ -608,7 +608,7 @@ class TestPluginAPI(APIBaseTest, QueryMatchingTest): ) self.assertEqual(response.status_code, 400) self.assertEqual( - cast(Dict[str, str], response.json())["detail"], + cast(dict[str, str], response.json())["detail"], f'Currently running PostHog version {FROZEN_POSTHOG_VERSION} does not match this plugin\'s semantic version requirement ">= {FROZEN_POSTHOG_VERSION.next_major()}".', ) @@ -620,7 +620,7 @@ class TestPluginAPI(APIBaseTest, QueryMatchingTest): ) self.assertEqual(response.status_code, 400) self.assertEqual( - cast(Dict[str, str], response.json())["detail"], + cast(dict[str, str], response.json())["detail"], 
f'Currently running PostHog version {FROZEN_POSTHOG_VERSION} does not match this plugin\'s semantic version requirement "< {FROZEN_POSTHOG_VERSION}".', ) @@ -642,7 +642,7 @@ class TestPluginAPI(APIBaseTest, QueryMatchingTest): ) self.assertEqual(response.status_code, 400) self.assertEqual( - cast(Dict[str, str], response.json())["detail"], + cast(dict[str, str], response.json())["detail"], 'Invalid PostHog semantic version requirement "< ..."!', ) diff --git a/posthog/api/test/test_properties_timeline.py b/posthog/api/test/test_properties_timeline.py index 5243151c27e..d8b8a11e909 100644 --- a/posthog/api/test/test_properties_timeline.py +++ b/posthog/api/test/test_properties_timeline.py @@ -1,7 +1,7 @@ import json import random import uuid -from typing import Any, Dict, Literal, Optional +from typing import Any, Literal, Optional from freezegun.api import freeze_time from rest_framework import status @@ -52,7 +52,7 @@ def properties_timeline_test_factory(actor_type: Literal["person", "group"]): return group.group_key def _create_event(self, event: str, timestamp: str, actor_properties: dict): - create_event_kwargs: Dict[str, Any] = {} + create_event_kwargs: dict[str, Any] = {} if actor_type == "person": create_event_kwargs["person_id"] = main_actor_id create_event_kwargs["person_properties"] = actor_properties diff --git a/posthog/api/test/test_property_definition.py b/posthog/api/test/test_property_definition.py index 77dca5e8330..378f66d7884 100644 --- a/posthog/api/test/test_property_definition.py +++ b/posthog/api/test/test_property_definition.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Optional, Union +from typing import Optional, Union from unittest.mock import ANY, patch from rest_framework import status @@ -17,7 +17,7 @@ from posthog.test.base import APIBaseTest, BaseTest class TestPropertyDefinitionAPI(APIBaseTest): - EXPECTED_PROPERTY_DEFINITIONS: List[Dict[str, Union[str, Optional[int], bool]]] = [ + EXPECTED_PROPERTY_DEFINITIONS: list[dict[str, Union[str, Optional[int], bool]]] = [ {"name": "$browser", "is_numerical": False}, {"name": "$current_url", "is_numerical": False}, {"name": "$lib", "is_numerical": False}, @@ -69,7 +69,7 @@ class TestPropertyDefinitionAPI(APIBaseTest): self.assertEqual(len(response.json()["results"]), len(self.EXPECTED_PROPERTY_DEFINITIONS)) for item in self.EXPECTED_PROPERTY_DEFINITIONS: - response_item: Dict = next( + response_item: dict = next( (_i for _i in response.json()["results"] if _i["name"] == item["name"]), {}, ) diff --git a/posthog/api/test/test_signup.py b/posthog/api/test/test_signup.py index 1587c0b365e..532f7b945e7 100644 --- a/posthog/api/test/test_signup.py +++ b/posthog/api/test/test_signup.py @@ -1,6 +1,6 @@ import datetime import uuid -from typing import Dict, Optional, cast +from typing import Optional, cast from unittest import mock from unittest.mock import ANY, patch from zoneinfo import ZoneInfo @@ -294,7 +294,7 @@ class TestSignupAPI(APIBaseTest): required_attributes = ["first_name", "email"] for attribute in required_attributes: - body: Dict[str, Optional[str]] = { + body: dict[str, Optional[str]] = { "first_name": "Jane", "email": "invalid@posthog.com", "password": "notsecure", diff --git a/posthog/api/test/test_site_app.py b/posthog/api/test/test_site_app.py index 82823ac4cf4..9a428774c6e 100644 --- a/posthog/api/test/test_site_app.py +++ b/posthog/api/test/test_site_app.py @@ -1,5 +1,3 @@ -from typing import List - from django.test.client import Client from rest_framework import status @@ -44,7 +42,7 
@@ class TestSiteApp(BaseTest): ) def test_get_site_config_from_schema(self): - schema: List[dict] = [{"key": "in_site", "site": True}, {"key": "not_in_site"}] + schema: list[dict] = [{"key": "in_site", "site": True}, {"key": "not_in_site"}] config = {"in_site": "123", "not_in_site": "12345"} self.assertEqual(get_site_config_from_schema(schema, config), {"in_site": "123"}) self.assertEqual(get_site_config_from_schema(None, None), {}) diff --git a/posthog/api/test/test_stickiness.py b/posthog/api/test/test_stickiness.py index 56d610c205e..b3942414d54 100644 --- a/posthog/api/test/test_stickiness.py +++ b/posthog/api/test/test_stickiness.py @@ -1,7 +1,7 @@ import uuid from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from dateutil.relativedelta import relativedelta from django.test import override_settings @@ -20,26 +20,26 @@ from posthog.test.base import ( from posthog.utils import encode_get_request_params -def get_stickiness(client: Client, team: Team, request: Dict[str, Any]): +def get_stickiness(client: Client, team: Team, request: dict[str, Any]): return client.get(f"/api/projects/{team.pk}/insights/trend/", data=request) -def get_stickiness_ok(client: Client, team: Team, request: Dict[str, Any]): +def get_stickiness_ok(client: Client, team: Team, request: dict[str, Any]): response = get_stickiness(client=client, team=team, request=encode_get_request_params(data=request)) assert response.status_code == 200, response.content return response.json() -def get_stickiness_time_series_ok(client: Client, team: Team, request: Dict[str, Any]): +def get_stickiness_time_series_ok(client: Client, team: Team, request: dict[str, Any]): data = get_stickiness_ok(client=client, request=request, team=team) return get_time_series_ok(data) -def get_stickiness_people(client: Client, team_id: int, request: Dict[str, Any]): +def get_stickiness_people(client: Client, team_id: int, request: dict[str, Any]): return client.get("/api/person/stickiness/", data=request) -def get_stickiness_people_ok(client: Client, team_id: int, request: Dict[str, Any]): +def get_stickiness_people_ok(client: Client, team_id: int, request: dict[str, Any]): response = get_stickiness_people(client=client, team_id=team_id, request=encode_get_request_params(data=request)) assert response.status_code == 200 return response.json() diff --git a/posthog/api/test/test_team.py b/posthog/api/test/test_team.py index d23efe81cf7..0cae63e3b60 100644 --- a/posthog/api/test/test_team.py +++ b/posthog/api/test/test_team.py @@ -1,6 +1,6 @@ import json import uuid -from typing import List, cast, Dict, Optional, Any +from typing import cast, Optional, Any from unittest import mock from unittest.mock import MagicMock, call, patch, ANY @@ -27,7 +27,7 @@ from posthog.test.base import APIBaseTest class TestTeamAPI(APIBaseTest): - def _assert_activity_log(self, expected: List[Dict], team_id: Optional[int] = None) -> None: + def _assert_activity_log(self, expected: list[dict], team_id: Optional[int] = None) -> None: if not team_id: team_id = self.team.pk @@ -35,7 +35,7 @@ class TestTeamAPI(APIBaseTest): assert starting_log_response.status_code == 200 assert starting_log_response.json()["results"] == expected - def _assert_organization_activity_log(self, expected: List[Dict]) -> None: + def _assert_organization_activity_log(self, expected: list[dict]) -> None: starting_log_response = 
self.client.get(f"/api/organizations/{self.organization.pk}/activity") assert starting_log_response.status_code == 200 assert starting_log_response.json()["results"] == expected @@ -95,7 +95,7 @@ class TestTeamAPI(APIBaseTest): @patch("posthog.api.team.get_geoip_properties") def test_ip_location_is_used_for_new_project_week_day_start(self, get_geoip_properties_mock: MagicMock): - self.organization.available_features = cast(List[str], [AvailableFeature.ORGANIZATIONS_PROJECTS]) + self.organization.available_features = cast(list[str], [AvailableFeature.ORGANIZATIONS_PROJECTS]) self.organization.save() self.organization_membership.level = OrganizationMembership.Level.ADMIN self.organization_membership.save() @@ -1039,7 +1039,7 @@ class TestTeamAPI(APIBaseTest): # and the existing second level nesting is not preserved self._assert_replay_config_is({"ai_config": {"opt_in": None, "included_event_properties": ["and another"]}}) - def _assert_replay_config_is(self, expected: Dict[str, Any] | None) -> HttpResponse: + def _assert_replay_config_is(self, expected: dict[str, Any] | None) -> HttpResponse: get_response = self.client.get("/api/projects/@current/") assert get_response.status_code == status.HTTP_200_OK, get_response.json() assert get_response.json()["session_replay_config"] == expected @@ -1047,7 +1047,7 @@ class TestTeamAPI(APIBaseTest): return get_response def _patch_session_replay_config( - self, config: Dict[str, Any] | None, expected_status: int = status.HTTP_200_OK + self, config: dict[str, Any] | None, expected_status: int = status.HTTP_200_OK ) -> HttpResponse: patch_response = self.client.patch( "/api/projects/@current/", @@ -1057,13 +1057,13 @@ class TestTeamAPI(APIBaseTest): return patch_response - def _assert_linked_flag_config(self, expected_config: Dict | None) -> HttpResponse: + def _assert_linked_flag_config(self, expected_config: dict | None) -> HttpResponse: response = self.client.get("/api/projects/@current/") assert response.status_code == status.HTTP_200_OK assert response.json()["session_recording_linked_flag"] == expected_config return response - def _patch_linked_flag_config(self, config: Dict | None, expected_status: int = status.HTTP_200_OK) -> HttpResponse: + def _patch_linked_flag_config(self, config: dict | None, expected_status: int = status.HTTP_200_OK) -> HttpResponse: response = self.client.patch("/api/projects/@current/", {"session_recording_linked_flag": config}) assert response.status_code == expected_status, response.json() return response diff --git a/posthog/api/test/test_user.py b/posthog/api/test/test_user.py index 7113d50e5f7..4b682b4095e 100644 --- a/posthog/api/test/test_user.py +++ b/posthog/api/test/test_user.py @@ -1,6 +1,6 @@ import datetime import uuid -from typing import Dict, List, cast +from typing import cast from unittest import mock from unittest.mock import ANY, Mock, patch from urllib.parse import quote @@ -326,7 +326,7 @@ class TestUserAPI(APIBaseTest): ) def _assert_set_scene_choice( - self, scene: str, dashboard: Dashboard, user: User, expected_choices: List[Dict] + self, scene: str, dashboard: Dashboard, user: User, expected_choices: list[dict] ) -> None: response = self.client.post( "/api/users/@me/scene_personalisation", diff --git a/posthog/api/uploaded_media.py b/posthog/api/uploaded_media.py index d4cea157c69..aba0384caf8 100644 --- a/posthog/api/uploaded_media.py +++ b/posthog/api/uploaded_media.py @@ -1,5 +1,5 @@ from io import BytesIO -from typing import Dict, Optional +from typing import Optional import structlog from 
django.http import HttpResponse @@ -149,7 +149,7 @@ class MediaViewSet(TeamAndOrgViewSetMixin, viewsets.GenericViewSet): detail="Object storage must be available to allow media uploads.", ) - def get_success_headers(self, location: str) -> Dict: + def get_success_headers(self, location: str) -> dict: try: return {"Location": location} except (TypeError, KeyError): diff --git a/posthog/api/utils.py b/posthog/api/utils.py index d34530cda14..ed1a571e6e4 100644 --- a/posthog/api/utils.py +++ b/posthog/api/utils.py @@ -4,7 +4,7 @@ import socket import urllib.parse from enum import Enum, auto from ipaddress import ip_address -from typing import List, Literal, Optional, Union, Tuple +from typing import Literal, Optional, Union from uuid import UUID import structlog @@ -64,7 +64,7 @@ def get_target_entity(filter: Union[Filter, StickinessFilter]) -> Entity: raise ValidationError("An entity must be provided for target entity to be determined") -def entity_from_order(order: Optional[str], entities: List[Entity]) -> Optional[Entity]: +def entity_from_order(order: Optional[str], entities: list[Entity]) -> Optional[Entity]: if not order: return None @@ -78,8 +78,8 @@ def retrieve_entity_from( entity_id: Optional[str], entity_type: Optional[str], entity_math: MathType, - events: List[Entity], - actions: List[Entity], + events: list[Entity], + actions: list[Entity], ) -> Optional[Entity]: """ Retrieves the entity from the events and actions. @@ -251,7 +251,7 @@ def create_event_definitions_sql( event_type: EventDefinitionType, is_enterprise: bool = False, conditions: str = "", - order_expressions: Optional[List[Tuple[str, Literal["ASC", "DESC"]]]] = None, + order_expressions: Optional[list[tuple[str, Literal["ASC", "DESC"]]]] = None, ) -> str: if order_expressions is None: order_expressions = [] @@ -305,7 +305,7 @@ def get_pk_or_uuid(queryset: QuerySet, key: Union[int, str]) -> QuerySet: return queryset.filter(pk=key) -def parse_bool(value: Union[str, List[str]]) -> bool: +def parse_bool(value: Union[str, list[str]]) -> bool: if value == "true": return True return False diff --git a/posthog/async_migrations/definition.py b/posthog/async_migrations/definition.py index 859b8af0881..52a53164bc7 100644 --- a/posthog/async_migrations/definition.py +++ b/posthog/async_migrations/definition.py @@ -1,13 +1,10 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - List, Optional, - Tuple, Union, ) +from collections.abc import Callable from posthog.constants import AnalyticsDBMS from posthog.models.utils import sane_repr @@ -36,9 +33,9 @@ class AsyncMigrationOperationSQL(AsyncMigrationOperation): self, *, sql: str, - sql_settings: Optional[Dict] = None, + sql_settings: Optional[dict] = None, rollback: Optional[str], - rollback_settings: Optional[Dict] = None, + rollback_settings: Optional[dict] = None, database: AnalyticsDBMS = AnalyticsDBMS.CLICKHOUSE, timeout_seconds: int = ASYNC_MIGRATIONS_DEFAULT_TIMEOUT_SECONDS, per_shard: bool = False, @@ -58,7 +55,7 @@ class AsyncMigrationOperationSQL(AsyncMigrationOperation): if self.rollback is not None: self._execute_op(query_id, self.rollback, self.rollback_settings) - def _execute_op(self, query_id: str, sql: str, settings: Optional[Dict]): + def _execute_op(self, query_id: str, sql: str, settings: Optional[dict]): from posthog.async_migrations.utils import ( execute_op_clickhouse, execute_op_postgres, @@ -91,16 +88,16 @@ class AsyncMigrationDefinition: description = "" # list of versions accepted for the services the migration relies on e.g. 
ClickHouse, Postgres - service_version_requirements: List[ServiceVersionRequirement] = [] + service_version_requirements: list[ServiceVersionRequirement] = [] # list of operations the migration will perform _in order_ - operations: List[AsyncMigrationOperation] = [] + operations: list[AsyncMigrationOperation] = [] # name of async migration this migration depends on depends_on: Optional[str] = None # optional parameters for this async migration. Shown in the UI when starting the migration - parameters: Dict[str, Tuple[(Optional[Union[int, str]], str, Callable[[Any], Any])]] = {} + parameters: dict[str, tuple[(Optional[Union[int, str]], str, Callable[[Any], Any])]] = {} def __init__(self, name: str): self.name = name @@ -111,11 +108,11 @@ class AsyncMigrationDefinition: return True # run before starting the migration - def precheck(self) -> Tuple[bool, Optional[str]]: + def precheck(self) -> tuple[bool, Optional[str]]: return (True, None) # run at a regular interval while the migration is being executed - def healthcheck(self) -> Tuple[bool, Optional[str]]: + def healthcheck(self) -> tuple[bool, Optional[str]]: return (True, None) # return an int between 0-100 to specify how far along this migration is diff --git a/posthog/async_migrations/migrations/0001_events_sample_by.py b/posthog/async_migrations/migrations/0001_events_sample_by.py index 4098fd38f32..1d8fced273c 100644 --- a/posthog/async_migrations/migrations/0001_events_sample_by.py +++ b/posthog/async_migrations/migrations/0001_events_sample_by.py @@ -1,5 +1,3 @@ -from typing import List - from posthog.async_migrations.definition import ( AsyncMigrationDefinition, AsyncMigrationOperation, @@ -17,7 +15,7 @@ class Migration(AsyncMigrationDefinition): posthog_max_version = "1.33.9" - operations: List[AsyncMigrationOperation] = [] + operations: list[AsyncMigrationOperation] = [] def is_required(self): return False diff --git a/posthog/async_migrations/migrations/0002_events_sample_by.py b/posthog/async_migrations/migrations/0002_events_sample_by.py index 7038975b2af..2157c380f2d 100644 --- a/posthog/async_migrations/migrations/0002_events_sample_by.py +++ b/posthog/async_migrations/migrations/0002_events_sample_by.py @@ -1,5 +1,4 @@ from functools import cached_property -from typing import List from django.conf import settings @@ -76,7 +75,7 @@ class Migration(AsyncMigrationDefinition): # Note: This _should_ be impossible but hard to ensure. 
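The AsyncMigrationDefinition hunk above shows the core Pyupgrade rewrite (PEP 585, Python 3.9+): typing.Dict/List/Tuple become the subscriptable builtin classes, and Callable moves to collections.abc. A minimal sketch of the equivalence, using illustrative names and values rather than the real migration attributes:

    from collections.abc import Callable
    from typing import Optional, Union

    # Before (Python 3.8 era):
    #   parameters: Dict[str, Tuple[Optional[Union[int, str]], str, Callable[[Any], Any]]]
    # After (Python 3.9+), the builtin classes are subscriptable directly:
    parameters: dict[str, tuple[Optional[Union[int, str]], str, Callable[[int], int]]] = {
        "BATCH_SIZE": (10_000, "Rows per batch", lambda n: max(n, 1)),  # assumed example entry
    }

    def precheck_sketch() -> tuple[bool, Optional[str]]:
        # tuple[...] replaces typing.Tuple[...] with identical runtime meaning.
        return (True, None)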
raise RuntimeError("Cannot run the migration as `events` table is already Distributed engine.") - create_table_op: List[AsyncMigrationOperation] = [ + create_table_op: list[AsyncMigrationOperation] = [ AsyncMigrationOperationSQL( database=AnalyticsDBMS.CLICKHOUSE, sql=f""" diff --git a/posthog/async_migrations/migrations/0005_person_replacing_by_version.py b/posthog/async_migrations/migrations/0005_person_replacing_by_version.py index 276d6c54abe..8740456c5e1 100644 --- a/posthog/async_migrations/migrations/0005_person_replacing_by_version.py +++ b/posthog/async_migrations/migrations/0005_person_replacing_by_version.py @@ -1,6 +1,5 @@ import json from functools import cached_property -from typing import Dict, List, Tuple import structlog from django.conf import settings @@ -238,9 +237,9 @@ class Migration(AsyncMigrationDefinition): ) return True - def _persons_insert_query(self, persons: List[Person]) -> Tuple[str, Dict]: + def _persons_insert_query(self, persons: list[Person]) -> tuple[str, dict]: values = [] - params: Dict = {} + params: dict = {} for i, person in enumerate(persons): created_at = person.created_at.strftime("%Y-%m-%d %H:%M:%S") # :TRICKY: We use a custom _timestamp to identify rows migrated during this migration diff --git a/posthog/async_migrations/migrations/0006_persons_and_groups_on_events_backfill.py b/posthog/async_migrations/migrations/0006_persons_and_groups_on_events_backfill.py index 62f539f3334..75c5510c9ef 100644 --- a/posthog/async_migrations/migrations/0006_persons_and_groups_on_events_backfill.py +++ b/posthog/async_migrations/migrations/0006_persons_and_groups_on_events_backfill.py @@ -1,5 +1,3 @@ -from typing import List - from posthog.async_migrations.definition import ( AsyncMigrationDefinition, AsyncMigrationOperation, @@ -19,7 +17,7 @@ class Migration(AsyncMigrationDefinition): depends_on = "0005_person_replacing_by_version" - operations: List[AsyncMigrationOperation] = [] + operations: list[AsyncMigrationOperation] = [] def is_required(self): return False diff --git a/posthog/async_migrations/migrations/0007_persons_and_groups_on_events_backfill.py b/posthog/async_migrations/migrations/0007_persons_and_groups_on_events_backfill.py index 99216ee936b..f51d171dfe8 100644 --- a/posthog/async_migrations/migrations/0007_persons_and_groups_on_events_backfill.py +++ b/posthog/async_migrations/migrations/0007_persons_and_groups_on_events_backfill.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Dict, Tuple, Union +from typing import Union import structlog from django.conf import settings @@ -289,7 +289,7 @@ class Migration(AsyncMigrationDefinition): self._check_person_data() self._check_groups_data() - def _where_clause(self) -> Tuple[str, Dict[str, Union[str, int]]]: + def _where_clause(self) -> tuple[str, dict[str, Union[str, int]]]: team_id = self.get_parameter("TEAM_ID") team_id_filter = f" AND team_id = %(team_id)s" if team_id else "" where_clause = f"WHERE timestamp > toDateTime(%(timestamp_lower_bound)s) AND timestamp < toDateTime(%(timestamp_upper_bound)s) {team_id_filter}" diff --git a/posthog/async_migrations/migrations/0009_minmax_indexes_for_materialized_columns.py b/posthog/async_migrations/migrations/0009_minmax_indexes_for_materialized_columns.py index 9b4c64c9af8..d679643b8a5 100644 --- a/posthog/async_migrations/migrations/0009_minmax_indexes_for_materialized_columns.py +++ b/posthog/async_migrations/migrations/0009_minmax_indexes_for_materialized_columns.py @@ -1,5 +1,3 @@ -from typing import List - from 
posthog.async_migrations.definition import ( AsyncMigrationDefinition, AsyncMigrationOperation, @@ -16,4 +14,4 @@ class Migration(AsyncMigrationDefinition): def is_required(self): return False - operations: List[AsyncMigrationOperation] = [] + operations: list[AsyncMigrationOperation] = [] diff --git a/posthog/async_migrations/runner.py b/posthog/async_migrations/runner.py index 78f2afcf212..05946cfd3c9 100644 --- a/posthog/async_migrations/runner.py +++ b/posthog/async_migrations/runner.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import Optional import structlog from semantic_version.base import SimpleSpec @@ -281,7 +281,7 @@ def run_next_migration(candidate: str): trigger_migration(migration_instance) -def is_migration_dependency_fulfilled(migration_name: str) -> Tuple[bool, str]: +def is_migration_dependency_fulfilled(migration_name: str) -> tuple[bool, str]: dependency = get_async_migration_dependency(migration_name) dependency_ok: bool = ( @@ -292,8 +292,8 @@ def is_migration_dependency_fulfilled(migration_name: str) -> Tuple[bool, str]: def check_service_version_requirements( - service_version_requirements: List[ServiceVersionRequirement], -) -> Tuple[bool, str]: + service_version_requirements: list[ServiceVersionRequirement], +) -> tuple[bool, str]: for service_version_requirement in service_version_requirements: in_range, version = service_version_requirement.is_service_in_accepted_version() if not in_range: diff --git a/posthog/async_migrations/setup.py b/posthog/async_migrations/setup.py index 4493f137bd2..acc27b49543 100644 --- a/posthog/async_migrations/setup.py +++ b/posthog/async_migrations/setup.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Optional from django.core.exceptions import ImproperlyConfigured from infi.clickhouse_orm.utils import import_submodules @@ -19,12 +19,12 @@ def reload_migration_definitions(): ALL_ASYNC_MIGRATIONS[name] = module.Migration(name) -ALL_ASYNC_MIGRATIONS: Dict[str, AsyncMigrationDefinition] = {} +ALL_ASYNC_MIGRATIONS: dict[str, AsyncMigrationDefinition] = {} -ASYNC_MIGRATION_TO_DEPENDENCY: Dict[str, Optional[str]] = {} +ASYNC_MIGRATION_TO_DEPENDENCY: dict[str, Optional[str]] = {} # inverted mapping of ASYNC_MIGRATION_TO_DEPENDENCY -DEPENDENCY_TO_ASYNC_MIGRATION: Dict[Optional[str], str] = {} +DEPENDENCY_TO_ASYNC_MIGRATION: dict[Optional[str], str] = {} ASYNC_MIGRATIONS_MODULE_PATH = "posthog.async_migrations.migrations" ASYNC_MIGRATIONS_EXAMPLE_MODULE_PATH = "posthog.async_migrations.examples" diff --git a/posthog/async_migrations/test/test_0007_persons_and_groups_on_events_backfill.py b/posthog/async_migrations/test/test_0007_persons_and_groups_on_events_backfill.py index 4e6588ad459..9a35ed05c82 100644 --- a/posthog/async_migrations/test/test_0007_persons_and_groups_on_events_backfill.py +++ b/posthog/async_migrations/test/test_0007_persons_and_groups_on_events_backfill.py @@ -1,5 +1,4 @@ import json -from typing import Dict, List from uuid import uuid4 import pytest @@ -31,7 +30,7 @@ pytestmark = pytest.mark.async_migrations MIGRATION_NAME = "0007_persons_and_groups_on_events_backfill" -uuid1, uuid2, uuid3 = [UUIDT() for _ in range(3)] +uuid1, uuid2, uuid3 = (UUIDT() for _ in range(3)) # Clickhouse leaves behind blank/zero values for non-filled columns, these are checked against these constants ZERO_UUID = UUIDT(uuid_str="00000000-0000-0000-0000-000000000000") ZERO_DATE = "1970-01-01T00:00:00Z" @@ -44,7 +43,7 @@ def run_migration(): return start_async_migration(MIGRATION_NAME, 
ignore_posthog_version=True) -def query_events() -> List[Dict]: +def query_events() -> list[dict]: return query_with_columns( """ SELECT @@ -351,7 +350,7 @@ class Test0007PersonsAndGroupsOnEventsBackfill(AsyncMigrationBaseTest, Clickhous MIGRATION_DEFINITION.operations[-1].fn = old_fn def test_timestamp_boundaries(self): - _uuid1, _uuid2, _uuid3 = [UUIDT() for _ in range(3)] + _uuid1, _uuid2, _uuid3 = (UUIDT() for _ in range(3)) create_event( event_uuid=_uuid1, team=self.team, diff --git a/posthog/async_migrations/test/test_0010_move_old_partitions.py b/posthog/async_migrations/test/test_0010_move_old_partitions.py index d316f5f50e6..e249f17a434 100644 --- a/posthog/async_migrations/test/test_0010_move_old_partitions.py +++ b/posthog/async_migrations/test/test_0010_move_old_partitions.py @@ -14,7 +14,7 @@ pytestmark = pytest.mark.async_migrations MIGRATION_NAME = "0010_move_old_partitions" -uuid1, uuid2, uuid3 = [UUIDT() for _ in range(3)] +uuid1, uuid2, uuid3 = (UUIDT() for _ in range(3)) MIGRATION_DEFINITION = get_async_migration_definition(MIGRATION_NAME) diff --git a/posthog/async_migrations/utils.py b/posthog/async_migrations/utils.py index 20ad64cf7d7..ee7ecdbe4d2 100644 --- a/posthog/async_migrations/utils.py +++ b/posthog/async_migrations/utils.py @@ -1,6 +1,7 @@ import asyncio from datetime import datetime -from typing import Callable, Optional +from typing import Optional +from collections.abc import Callable import posthoganalytics import structlog diff --git a/posthog/auth.py b/posthog/auth.py index 6154ecb1ca0..f536ff30c20 100644 --- a/posthog/auth.py +++ b/posthog/auth.py @@ -1,7 +1,7 @@ import functools import re from datetime import timedelta -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Union from urllib.parse import urlsplit import jwt @@ -57,9 +57,9 @@ class PersonalAPIKeyAuthentication(authentication.BaseAuthentication): def find_key_with_source( cls, request: Union[HttpRequest, Request], - request_data: Optional[Dict[str, Any]] = None, - extra_data: Optional[Dict[str, Any]] = None, - ) -> Optional[Tuple[str, str]]: + request_data: Optional[dict[str, Any]] = None, + extra_data: Optional[dict[str, Any]] = None, + ) -> Optional[tuple[str, str]]: """Try to find personal API key in request and return it along with where it was found.""" if "HTTP_AUTHORIZATION" in request.META: authorization_match = re.match(rf"^{cls.keyword}\s+(\S.+)$", request.META["HTTP_AUTHORIZATION"]) @@ -80,8 +80,8 @@ class PersonalAPIKeyAuthentication(authentication.BaseAuthentication): def find_key( cls, request: Union[HttpRequest, Request], - request_data: Optional[Dict[str, Any]] = None, - extra_data: Optional[Dict[str, Any]] = None, + request_data: Optional[dict[str, Any]] = None, + extra_data: Optional[dict[str, Any]] = None, ) -> Optional[str]: """Try to find personal API key in request and return it.""" key_with_source = cls.find_key_with_source(request, request_data, extra_data) @@ -121,7 +121,7 @@ class PersonalAPIKeyAuthentication(authentication.BaseAuthentication): return personal_api_key_object - def authenticate(self, request: Union[HttpRequest, Request]) -> Optional[Tuple[Any, None]]: + def authenticate(self, request: Union[HttpRequest, Request]) -> Optional[tuple[Any, None]]: personal_api_key_with_source = self.find_key_with_source(request) if not personal_api_key_with_source: return None @@ -190,7 +190,7 @@ class JwtAuthentication(authentication.BaseAuthentication): keyword = "Bearer" @classmethod - def authenticate(cls, request: 
Union[HttpRequest, Request]) -> Optional[Tuple[Any, None]]: + def authenticate(cls, request: Union[HttpRequest, Request]) -> Optional[tuple[Any, None]]: if "HTTP_AUTHORIZATION" in request.META: authorization_match = re.match(rf"^Bearer\s+(\S.+)$", request.META["HTTP_AUTHORIZATION"]) if authorization_match: @@ -222,7 +222,7 @@ class SharingAccessTokenAuthentication(authentication.BaseAuthentication): sharing_configuration: SharingConfiguration - def authenticate(self, request: Union[HttpRequest, Request]) -> Optional[Tuple[Any, Any]]: + def authenticate(self, request: Union[HttpRequest, Request]) -> Optional[tuple[Any, Any]]: if sharing_access_token := request.GET.get("sharing_access_token"): if request.method not in ["GET", "HEAD"]: raise AuthenticationFailed(detail="Sharing access token can only be used for GET requests.") diff --git a/posthog/caching/calculate_results.py b/posthog/caching/calculate_results.py index 2fcf0ff04cc..4089323202e 100644 --- a/posthog/caching/calculate_results.py +++ b/posthog/caching/calculate_results.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Optional, Union import structlog from sentry_sdk import capture_exception @@ -77,7 +77,7 @@ def get_cache_type_for_filter(cacheable: FilterType) -> CacheType: return CacheType.TRENDS -def get_cache_type_for_query(cacheable: Dict) -> CacheType: +def get_cache_type_for_query(cacheable: dict) -> CacheType: cache_type = None if cacheable.get("source"): @@ -92,7 +92,7 @@ def get_cache_type_for_query(cacheable: Dict) -> CacheType: return cache_type -def get_cache_type(cacheable: Optional[FilterType] | Optional[Dict]) -> CacheType: +def get_cache_type(cacheable: Optional[FilterType] | Optional[dict]) -> CacheType: if isinstance(cacheable, dict): return get_cache_type_for_query(cacheable) elif cacheable is not None: @@ -146,7 +146,7 @@ def calculate_for_query_based_insight( def calculate_for_filter_based_insight( insight: Insight, dashboard: Optional[Dashboard] -) -> Tuple[str, str, List | Dict]: +) -> tuple[str, str, list | dict]: filter = get_filter(data=insight.dashboard_filters(dashboard), team=insight.team) cache_key = generate_insight_cache_key(insight, dashboard) cache_type = get_cache_type(filter) @@ -161,7 +161,7 @@ def calculate_for_filter_based_insight( return cache_key, cache_type, calculate_result_by_cache_type(cache_type, filter, insight.team) -def calculate_result_by_cache_type(cache_type: CacheType, filter: Filter, team: Team) -> List[Dict[str, Any]]: +def calculate_result_by_cache_type(cache_type: CacheType, filter: Filter, team: Team) -> list[dict[str, Any]]: if cache_type == CacheType.FUNNEL: return _calculate_funnel(filter, team) else: @@ -169,7 +169,7 @@ def calculate_result_by_cache_type(cache_type: CacheType, filter: Filter, team: @timed("update_cache_item_timer.calculate_by_filter") -def _calculate_by_filter(filter: FilterType, team: Team, cache_type: CacheType) -> List[Dict[str, Any]]: +def _calculate_by_filter(filter: FilterType, team: Team, cache_type: CacheType) -> list[dict[str, Any]]: insight_class = CACHE_TYPE_TO_INSIGHT_CLASS[cache_type] if cache_type == CacheType.PATHS: @@ -180,7 +180,7 @@ def _calculate_by_filter(filter: FilterType, team: Team, cache_type: CacheType) @timed("update_cache_item_timer.calculate_funnel") -def _calculate_funnel(filter: Filter, team: Team) -> List[Dict[str, Any]]: +def _calculate_funnel(filter: Filter, team: Team) -> list[dict[str, Any]]: if filter.funnel_viz_type == 
FunnelVizType.TRENDS: result = ClickhouseFunnelTrends(team=team, filter=filter).run() elif filter.funnel_viz_type == FunnelVizType.TIME_TO_CONVERT: @@ -193,7 +193,7 @@ def _calculate_funnel(filter: Filter, team: Team) -> List[Dict[str, Any]]: def cache_includes_latest_events( - payload: Dict, filter: Union[RetentionFilter, StickinessFilter, PathFilter, Filter] + payload: dict, filter: Union[RetentionFilter, StickinessFilter, PathFilter, Filter] ) -> bool: """ event_definition has last_seen_at timestamp @@ -218,7 +218,7 @@ def cache_includes_latest_events( return False -def _events_from_filter(filter: Union[RetentionFilter, StickinessFilter, PathFilter, Filter]) -> List[str]: +def _events_from_filter(filter: Union[RetentionFilter, StickinessFilter, PathFilter, Filter]) -> list[str]: """ If a filter only represents a set of events then we can use their last_seen_at to determine if the cache is up-to-date diff --git a/posthog/caching/fetch_from_cache.py b/posthog/caching/fetch_from_cache.py index fcbeb0b72e3..fe5d46ace3d 100644 --- a/posthog/caching/fetch_from_cache.py +++ b/posthog/caching/fetch_from_cache.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from django.utils.timezone import now from prometheus_client import Counter @@ -27,7 +27,7 @@ class InsightResult: is_cached: bool timezone: Optional[str] next_allowed_client_refresh: Optional[datetime] = None - timings: Optional[List[QueryTiming]] = None + timings: Optional[list[QueryTiming]] = None @dataclass(frozen=True) diff --git a/posthog/caching/insight_cache.py b/posthog/caching/insight_cache.py index d73486234df..97b5c691e46 100644 --- a/posthog/caching/insight_cache.py +++ b/posthog/caching/insight_cache.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta from time import perf_counter -from typing import Any, List, Optional, Tuple, cast +from typing import Any, Optional, cast from uuid import UUID import structlog @@ -49,7 +49,7 @@ def schedule_cache_updates(): logger.warn("No caches were found to be updated") -def fetch_states_in_need_of_updating(limit: int) -> List[Tuple[int, str, UUID]]: +def fetch_states_in_need_of_updating(limit: int) -> list[tuple[int, str, UUID]]: current_time = now() with connection.cursor() as cursor: cursor.execute( @@ -162,7 +162,7 @@ def update_cached_state( ) -def _extract_insight_dashboard(caching_state: InsightCachingState) -> Tuple[Insight, Optional[Dashboard]]: +def _extract_insight_dashboard(caching_state: InsightCachingState) -> tuple[Insight, Optional[Dashboard]]: if caching_state.dashboard_tile is not None: assert caching_state.dashboard_tile.insight is not None diff --git a/posthog/caching/insight_caching_state.py b/posthog/caching/insight_caching_state.py index a8ae36c14f0..ae3eb269425 100644 --- a/posthog/caching/insight_caching_state.py +++ b/posthog/caching/insight_caching_state.py @@ -1,7 +1,7 @@ from datetime import timedelta from enum import Enum from functools import cached_property -from typing import List, Optional, Union +from typing import Optional, Union import structlog from django.core.paginator import Paginator @@ -232,10 +232,10 @@ def _iterate_large_queryset(queryset, page_size): yield page.object_list -def _execute_insert(states: List[Optional[InsightCachingState]]): +def _execute_insert(states: list[Optional[InsightCachingState]]): from django.db import connection - models: List[InsightCachingState] = list(filter(None, states)) + 
models: list[InsightCachingState] = list(filter(None, states)) if len(models) == 0: return diff --git a/posthog/caching/insights_api.py b/posthog/caching/insights_api.py index 35a75cdf8a0..11760e2dc41 100644 --- a/posthog/caching/insights_api.py +++ b/posthog/caching/insights_api.py @@ -1,7 +1,7 @@ from datetime import datetime, timedelta from math import ceil from time import sleep -from typing import Optional, Tuple, Union +from typing import Optional, Union import zoneinfo from rest_framework import request @@ -37,7 +37,7 @@ def should_refresh_insight( *, request: request.Request, is_shared=False, -) -> Tuple[bool, timedelta]: +) -> tuple[bool, timedelta]: """Return whether the insight should be refreshed now, and what's the minimum wait time between refreshes. If a refresh already is being processed somewhere else, this function will wait for that to finish (or time out). diff --git a/posthog/caching/test/test_insight_cache.py b/posthog/caching/test/test_insight_cache.py index 9de2053f6c2..b86ac56a3de 100644 --- a/posthog/caching/test/test_insight_cache.py +++ b/posthog/caching/test/test_insight_cache.py @@ -1,5 +1,6 @@ from datetime import timedelta -from typing import Callable, Optional +from typing import Optional +from collections.abc import Callable from unittest.mock import call, patch import pytest diff --git a/posthog/caching/test/test_insight_caching_state.py b/posthog/caching/test/test_insight_caching_state.py index 03a36525552..47465786fb1 100644 --- a/posthog/caching/test/test_insight_caching_state.py +++ b/posthog/caching/test/test_insight_caching_state.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Optional, Union, cast from unittest.mock import patch import pytest @@ -42,7 +42,7 @@ def create_insight( is_shared=True, filters=filter_dict, deleted=False, - query: Optional[Dict] = None, + query: Optional[dict] = None, ) -> Insight: if mock_active_teams: mock_active_teams.return_value = {team.pk} if team_should_be_active else set() @@ -77,7 +77,7 @@ def create_tile( dashboard_tile_deleted=False, is_dashboard_shared=True, text_tile=False, - query: Optional[Dict] = None, + query: Optional[dict] = None, ) -> DashboardTile: if mock_active_teams: mock_active_teams.return_value = {team.pk} if team_should_be_active else set() @@ -295,7 +295,7 @@ def test_calculate_target_age( team: Team, user: User, create_item, - create_item_kw: Dict, + create_item_kw: dict, expected_target_age: TargetCacheAge, ): item = cast( diff --git a/posthog/caching/utils.py b/posthog/caching/utils.py index d0c6450cc7d..c56d0f33571 100644 --- a/posthog/caching/utils.py +++ b/posthog/caching/utils.py @@ -1,6 +1,6 @@ from datetime import datetime from dateutil.parser import isoparse -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Optional, Union from zoneinfo import ZoneInfo from dateutil.parser import parser @@ -32,7 +32,7 @@ def ensure_is_date(candidate: Optional[Union[str, datetime]]) -> Optional[dateti return parser().parse(candidate) -def active_teams() -> Set[int]: +def active_teams() -> set[int]: """ Teams are stored in a sorted set. [{team_id: score}, {team_id: score}]. Their "score" is the number of seconds since last event. @@ -43,7 +43,7 @@ def active_teams() -> Set[int]: This assumes that the list of active teams is small enough to reasonably load in one go. 
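The active_teams docstring above describes the Redis layout: a sorted set whose members are team ids and whose scores are seconds since that team's last event. A minimal sketch of reading it with redis-py, matching the withscores=True call and the builtin-generic annotation in the hunk that follows (the key's actual value is not shown in this patch, so the one below is assumed):

    import redis

    RECENTLY_ACCESSED_TEAMS_REDIS_KEY = "posthog:recently_accessed_teams"  # assumed value

    def active_teams_sketch(client: redis.Redis) -> set[int]:
        # zrange(..., withscores=True) yields (member, score) pairs, i.e.
        # list[tuple[bytes, float]] rather than typing.List[Tuple[bytes, float]].
        all_teams: list[tuple[bytes, float]] = client.zrange(
            RECENTLY_ACCESSED_TEAMS_REDIS_KEY, 0, -1, withscores=True
        )
        return {int(team_id) for team_id, _seconds_since_last_event in all_teams}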
""" redis = get_client() - all_teams: List[Tuple[bytes, float]] = redis.zrange(RECENTLY_ACCESSED_TEAMS_REDIS_KEY, 0, -1, withscores=True) + all_teams: list[tuple[bytes, float]] = redis.zrange(RECENTLY_ACCESSED_TEAMS_REDIS_KEY, 0, -1, withscores=True) if not all_teams: teams_by_recency = sync_execute( """ @@ -106,7 +106,7 @@ def is_stale(team: Team, date_to: datetime, interval: str, cached_result: Any) - return False last_refresh = ( - cached_result.get("last_refresh", None) if isinstance(cached_result, Dict) else cached_result.last_refresh + cached_result.get("last_refresh", None) if isinstance(cached_result, dict) else cached_result.last_refresh ) date_to = min([date_to, datetime.now(tz=ZoneInfo("UTC"))]) # can't be later than now diff --git a/posthog/celery.py b/posthog/celery.py index a78a7c94ad8..29c45c9b607 100644 --- a/posthog/celery.py +++ b/posthog/celery.py @@ -1,6 +1,5 @@ import os import time -from typing import Dict from celery import Celery from celery.signals import ( @@ -71,7 +70,7 @@ app.conf.broker_pool_limit = 0 app.steps["worker"].add(DjangoStructLogInitStep) -task_timings: Dict[str, float] = {} +task_timings: dict[str, float] = {} @setup_logging.connect diff --git a/posthog/clickhouse/client/connection.py b/posthog/clickhouse/client/connection.py index 31ae6cd291d..35c72a305fa 100644 --- a/posthog/clickhouse/client/connection.py +++ b/posthog/clickhouse/client/connection.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from enum import Enum -from functools import lru_cache +from functools import cache from clickhouse_driver import Client as SyncClient from clickhouse_pool import ChPool @@ -65,7 +65,7 @@ def default_client(): ) -@lru_cache(maxsize=None) +@cache def make_ch_pool(**overrides) -> ChPool: kwargs = { "host": settings.CLICKHOUSE_HOST, diff --git a/posthog/clickhouse/client/execute.py b/posthog/clickhouse/client/execute.py index b588badfc07..17af5683a6f 100644 --- a/posthog/clickhouse/client/execute.py +++ b/posthog/clickhouse/client/execute.py @@ -4,7 +4,8 @@ import types from contextlib import contextmanager from functools import lru_cache from time import perf_counter -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Optional, Union +from collections.abc import Sequence import sqlparse from clickhouse_driver import Client as SyncClient @@ -19,7 +20,7 @@ from posthog.settings import TEST from posthog.utils import generate_short_id, patchable InsertParams = Union[list, tuple, types.GeneratorType] -NonInsertParams = Dict[str, Any] +NonInsertParams = dict[str, Any] QueryArgs = Optional[Union[InsertParams, NonInsertParams]] thread_local_storage = threading.local() @@ -39,7 +40,7 @@ is_invalid_algorithm = lambda algo: algo not in CLICKHOUSE_SUPPORTED_JOIN_ALGORI @lru_cache(maxsize=1) -def default_settings() -> Dict: +def default_settings() -> dict: return { "join_algorithm": "direct,parallel_hash", "distributed_replica_max_ignored_errors": 1000, @@ -131,11 +132,11 @@ def query_with_columns( query: str, args: Optional[QueryArgs] = None, columns_to_remove: Optional[Sequence[str]] = None, - columns_to_rename: Optional[Dict[str, str]] = None, + columns_to_rename: Optional[dict[str, str]] = None, *, workload: Workload = Workload.DEFAULT, team_id: Optional[int] = None, -) -> List[Dict]: +) -> list[dict]: if columns_to_remove is None: columns_to_remove = [] if columns_to_rename is None: @@ -184,7 +185,7 @@ def _prepare_query( below predicate. 
""" prepared_args: Any = QueryArgs - if isinstance(args, (list, tuple, types.GeneratorType)): + if isinstance(args, list | tuple | types.GeneratorType): # If we get one of these it means we have an insert, let the clickhouse # client handle substitution here. rendered_sql = query diff --git a/posthog/clickhouse/client/migration_tools.py b/posthog/clickhouse/client/migration_tools.py index f71abd489fd..aa3100b548b 100644 --- a/posthog/clickhouse/client/migration_tools.py +++ b/posthog/clickhouse/client/migration_tools.py @@ -1,4 +1,5 @@ -from typing import Callable, Union +from typing import Union +from collections.abc import Callable from infi.clickhouse_orm import migrations diff --git a/posthog/clickhouse/materialized_columns/column.py b/posthog/clickhouse/materialized_columns/column.py index 70aca94511a..a206c051395 100644 --- a/posthog/clickhouse/materialized_columns/column.py +++ b/posthog/clickhouse/materialized_columns/column.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Dict, List, Literal, Tuple, Union +from typing import Literal, Union from posthog.cache_utils import cache_for from posthog.models.property import PropertyName, TableColumn, TableWithProperties @@ -12,7 +12,7 @@ TablesWithMaterializedColumns = Union[TableWithProperties, Literal["session_reco @cache_for(timedelta(minutes=15)) def get_materialized_columns( table: TablesWithMaterializedColumns, -) -> Dict[Tuple[PropertyName, TableColumn], ColumnName]: +) -> dict[tuple[PropertyName, TableColumn], ColumnName]: return {} @@ -28,7 +28,7 @@ def materialize( def backfill_materialized_columns( table: TableWithProperties, - properties: List[Tuple[PropertyName, TableColumn]], + properties: list[tuple[PropertyName, TableColumn]], backfill_period: timedelta, test_settings=None, ) -> None: diff --git a/posthog/clickhouse/migrations/0046_ensure_kafa_session_replay_table_exists.py b/posthog/clickhouse/migrations/0046_ensure_kafa_session_replay_table_exists.py index 877139c155e..85d6664e475 100644 --- a/posthog/clickhouse/migrations/0046_ensure_kafa_session_replay_table_exists.py +++ b/posthog/clickhouse/migrations/0046_ensure_kafa_session_replay_table_exists.py @@ -1,6 +1,4 @@ -from typing import List - -operations: List = [ +operations: list = [ # this migration has been amended to be entirely No-op # it has applied successfully in Prod US where it was a no-op # as all tables/columns it affected already existed diff --git a/posthog/clickhouse/system_status.py b/posthog/clickhouse/system_status.py index e04c6bf7597..eec283f3b5a 100644 --- a/posthog/clickhouse/system_status.py +++ b/posthog/clickhouse/system_status.py @@ -1,6 +1,6 @@ from datetime import timedelta from os.path import abspath, dirname, join -from typing import Dict, Generator, List, Tuple +from collections.abc import Generator from zoneinfo import ZoneInfo from dateutil.relativedelta import relativedelta @@ -27,7 +27,7 @@ SLOW_AFTER = relativedelta(hours=6) CLICKHOUSE_FLAMEGRAPH_EXECUTABLE = abspath(join(dirname(__file__), "bin", "clickhouse-flamegraph")) FLAMEGRAPH_PL = abspath(join(dirname(__file__), "bin", "flamegraph.pl")) -SystemStatusRow = Dict +SystemStatusRow = dict def system_status() -> Generator[SystemStatusRow, None, None]: @@ -179,7 +179,7 @@ def is_alive() -> bool: return False -def dead_letter_queue_ratio() -> Tuple[bool, int]: +def dead_letter_queue_ratio() -> tuple[bool, int]: dead_letter_queue_events_last_day = get_dead_letter_queue_events_last_24h() total_events_ingested_last_day = sync_execute( @@ -199,14 +199,14 @@ def 
dead_letter_queue_ratio_ok_cached() -> bool: return dead_letter_queue_ratio()[0] -def get_clickhouse_running_queries() -> List[Dict]: +def get_clickhouse_running_queries() -> list[dict]: return query_with_columns( "SELECT elapsed as duration, query, * FROM system.processes ORDER BY duration DESC", columns_to_remove=["address", "initial_address", "elapsed"], ) -def get_clickhouse_slow_log() -> List[Dict]: +def get_clickhouse_slow_log() -> list[dict]: return query_with_columns( f""" SELECT query_duration_ms as duration, query, * diff --git a/posthog/conftest.py b/posthog/conftest.py index 7d2895eb380..8f3f233358c 100644 --- a/posthog/conftest.py +++ b/posthog/conftest.py @@ -1,4 +1,4 @@ -from typing import Any, Tuple +from typing import Any import pytest from django.conf import settings @@ -22,7 +22,7 @@ def create_clickhouse_tables(num_tables: int): ) # REMEMBER TO ADD ANY NEW CLICKHOUSE TABLES TO THIS ARRAY! - CREATE_TABLE_QUERIES: Tuple[Any, ...] = CREATE_MERGETREE_TABLE_QUERIES + CREATE_DISTRIBUTED_TABLE_QUERIES + CREATE_TABLE_QUERIES: tuple[Any, ...] = CREATE_MERGETREE_TABLE_QUERIES + CREATE_DISTRIBUTED_TABLE_QUERIES # Check if all the tables have already been created if num_tables == len(CREATE_TABLE_QUERIES): diff --git a/posthog/decorators.py b/posthog/decorators.py index 955bb9d0851..bb012701033 100644 --- a/posthog/decorators.py +++ b/posthog/decorators.py @@ -1,6 +1,7 @@ from enum import Enum from functools import wraps -from typing import Any, Callable, Dict, List, TypeVar, Union, cast +from typing import Any, TypeVar, Union, cast +from collections.abc import Callable from django.urls import resolve from django.utils.timezone import now @@ -25,7 +26,7 @@ class CacheType(str, Enum): PATHS = "Path" -ResultPackage = Union[Dict[str, Any], List[Dict[str, Any]]] +ResultPackage = Union[dict[str, Any], list[dict[str, Any]]] T = TypeVar("T", bound=ResultPackage) U = TypeVar("U", bound=GenericViewSet) diff --git a/posthog/demo/legacy/data_generator.py b/posthog/demo/legacy/data_generator.py index ccc9f163e6c..d507e65c31c 100644 --- a/posthog/demo/legacy/data_generator.py +++ b/posthog/demo/legacy/data_generator.py @@ -1,4 +1,3 @@ -from typing import Dict, List from uuid import uuid4 from posthog.models import Person, PersonDistinctId, Team @@ -13,9 +12,9 @@ class DataGenerator: self.team = team self.n_days = n_days self.n_people = n_people - self.events: List[Dict] = [] - self.snapshots: List[Dict] = [] - self.distinct_ids: List[str] = [] + self.events: list[dict] = [] + self.snapshots: list[dict] = [] + self.distinct_ids: list[str] = [] def create(self, dashboards=True): self.create_missing_events_and_properties() diff --git a/posthog/demo/legacy/web_data_generator.py b/posthog/demo/legacy/web_data_generator.py index aa0836d3db7..81127009225 100644 --- a/posthog/demo/legacy/web_data_generator.py +++ b/posthog/demo/legacy/web_data_generator.py @@ -1,7 +1,7 @@ import json import random from datetime import timedelta -from typing import Any, Dict, List +from typing import Any from dateutil.relativedelta import relativedelta from django.utils.timezone import now @@ -199,11 +199,11 @@ class WebDataGenerator(DataGenerator): return super().make_person(index) @cached_property - def demo_data(self) -> List[Dict[str, Any]]: - with open(get_absolute_path("demo/legacy/demo_people.json"), "r") as demo_data_file: + def demo_data(self) -> list[dict[str, Any]]: + with open(get_absolute_path("demo/legacy/demo_people.json")) as demo_data_file: return json.load(demo_data_file) @cached_property - def 
demo_recording(self) -> Dict[str, Any]: - with open(get_absolute_path("demo/legacy/hogflix_session_recording.json"), "r") as demo_session_file: + def demo_recording(self) -> dict[str, Any]: + with open(get_absolute_path("demo/legacy/hogflix_session_recording.json")) as demo_session_file: return json.load(demo_session_file) diff --git a/posthog/demo/matrix/manager.py b/posthog/demo/matrix/manager.py index 507ea09581d..ce073a6126f 100644 --- a/posthog/demo/matrix/manager.py +++ b/posthog/demo/matrix/manager.py @@ -1,7 +1,7 @@ import datetime as dt import json from time import sleep -from typing import Any, Dict, List, Literal, Optional, Tuple, cast +from typing import Any, Literal, Optional, cast from django.conf import settings from django.core import exceptions @@ -55,13 +55,13 @@ class MatrixManager: password: Optional[str] = None, is_staff: bool = False, disallow_collision: bool = False, - ) -> Tuple[Organization, Team, User]: + ) -> tuple[Organization, Team, User]: """If there's an email collision in signup in the demo environment, we treat it as a login.""" existing_user: Optional[User] = User.objects.filter(email=email).first() if existing_user is None: if self.print_steps: print(f"Creating demo organization, project, and user...") - organization_kwargs: Dict[str, Any] = {"name": organization_name} + organization_kwargs: dict[str, Any] = {"name": organization_name} if settings.DEMO: organization_kwargs["plugins_access_level"] = Organization.PluginsAccessLevel.INSTALL with transaction.atomic(): @@ -241,7 +241,7 @@ class MatrixManager: ["team_id", "is_deleted", "_timestamp", "_offset", "_partition"], {"id": "uuid"}, ) - bulk_persons: Dict[str, Person] = {} + bulk_persons: dict[str, Person] = {} for row in clickhouse_persons: properties = json.loads(row.pop("properties", "{}")) bulk_persons[row["uuid"]] = Person(team_id=target_team_id, properties=properties, **row) @@ -317,7 +317,7 @@ class MatrixManager: self._save_future_sim_events(team, subject.future_events) @staticmethod - def _save_past_sim_events(team: Team, events: List[SimEvent]): + def _save_past_sim_events(team: Team, events: list[SimEvent]): """Past events are saved into ClickHouse right away (via Kafka of course).""" from posthog.models.event.util import create_event @@ -346,7 +346,7 @@ class MatrixManager: ) @staticmethod - def _save_future_sim_events(team: Team, events: List[SimEvent]): + def _save_future_sim_events(team: Team, events: list[SimEvent]): """Future events are not saved immediately, instead they're scheduled for ingestion via event buffer.""" # TODO: This used the plugin server's Graphile Worker-based event buffer, but the event buffer is no more @@ -356,7 +356,7 @@ class MatrixManager: team: Team, type_index: Literal[0, 1, 2, 3, 4], key: str, - properties: Dict[str, Any], + properties: dict[str, Any], timestamp: dt.datetime, ): from posthog.models.group.util import raw_create_group_ch diff --git a/posthog/demo/matrix/matrix.py b/posthog/demo/matrix/matrix.py index c2d3a5f2eb4..382e70d85b7 100644 --- a/posthog/demo/matrix/matrix.py +++ b/posthog/demo/matrix/matrix.py @@ -3,13 +3,7 @@ from abc import ABC, abstractmethod from collections import defaultdict, deque from typing import ( Any, - DefaultDict, - Deque, - Dict, - List, Optional, - Set, - Type, ) import mimesis @@ -38,7 +32,7 @@ class Cluster(ABC): end: timezone.datetime # End of the simulation (might be same as now or later) radius: int - people_matrix: List[List[SimPerson]] # Grid containing all people in the cluster + people_matrix: 
list[list[SimPerson]] # Grid containing all people in the cluster random: mimesis.random.Random properties_provider: PropertiesProvider @@ -52,7 +46,7 @@ class Cluster(ABC): _simulation_time: dt.datetime _reached_now: bool - _scheduled_effects: Deque[Effect] + _scheduled_effects: deque[Effect] def __init__(self, *, index: int, matrix: "Matrix") -> None: self.index = index @@ -98,7 +92,7 @@ class Cluster(ABC): """Return a value between 0 and 1 determining how far into the overall simulation should this cluster be initiated.""" return self.random.random() - def list_neighbors(self, person: SimPerson) -> List[SimPerson]: + def list_neighbors(self, person: SimPerson) -> list[SimPerson]: """Return a list of neighbors of a person at (x, y).""" x, y = person.x, person.y neighbors = [] @@ -141,7 +135,7 @@ class Cluster(ABC): while self._scheduled_effects and self._scheduled_effects[0].timestamp <= until: effect = self._scheduled_effects.popleft() self.simulation_time = effect.timestamp - resolved_targets: List[SimPerson] + resolved_targets: list[SimPerson] if effect.target == Effect.Target.SELF: resolved_targets = [effect.source] elif effect.target == Effect.Target.ALL_NEIGHBORS: @@ -155,7 +149,7 @@ class Cluster(ABC): effect.callback(target) @property - def people(self) -> Set[SimPerson]: + def people(self) -> set[SimPerson]: return {person for row in self.people_matrix for person in row} @property @@ -198,17 +192,17 @@ class Matrix(ABC): """ PRODUCT_NAME: str - CLUSTER_CLASS: Type[Cluster] - PERSON_CLASS: Type[SimPerson] + CLUSTER_CLASS: type[Cluster] + PERSON_CLASS: type[SimPerson] start: dt.datetime now: dt.datetime end: dt.datetime group_type_index_offset: int # A mapping of groups. The first key is the group type, the second key is the group key. - groups: DefaultDict[str, DefaultDict[str, Dict[str, Any]]] - distinct_id_to_person: Dict[str, SimPerson] - clusters: List[Cluster] + groups: defaultdict[str, defaultdict[str, dict[str, Any]]] + distinct_id_to_person: dict[str, SimPerson] + clusters: list[Cluster] is_complete: Optional[bool] server_client: SimServerClient @@ -257,7 +251,7 @@ class Matrix(ABC): self.is_complete = None @property - def people(self) -> List[SimPerson]: + def people(self) -> list[SimPerson]: return [person for cluster in self.clusters for person in cluster.people] @abstractmethod @@ -273,7 +267,7 @@ class Matrix(ABC): cluster.simulate() self.is_complete = True - def _update_group(self, group_type: str, group_key: str, set_properties: Dict[str, Any]): + def _update_group(self, group_type: str, group_key: str, set_properties: dict[str, Any]): if len(self.groups) == GROUP_TYPES_LIMIT and group_type not in self.groups: raise Exception(f"Cannot add group type {group_type} to simulation, limit of {GROUP_TYPES_LIMIT} reached!") self.groups[group_type][group_key].update(set_properties) diff --git a/posthog/demo/matrix/models.py b/posthog/demo/matrix/models.py index e1698d7dd7b..c09fcae8cbb 100644 --- a/posthog/demo/matrix/models.py +++ b/posthog/demo/matrix/models.py @@ -8,17 +8,12 @@ from itertools import chain from typing import ( TYPE_CHECKING, Any, - Callable, - DefaultDict, - Dict, Generic, - Iterable, - List, Literal, Optional, - Set, TypeVar, ) +from collections.abc import Callable, Iterable from urllib.parse import urlparse, parse_qs from uuid import UUID @@ -77,7 +72,7 @@ PROPERTIES_WITH_IMPLICIT_INITIAL_VALUE_TRACKING = { "$referrer", } -Properties = Dict[str, Any] +Properties = dict[str, Any] class SimSessionIntent(Enum): @@ -330,23 +325,23 @@ class SimPerson(ABC): 
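The matrix.py hunk here retires typing.DefaultDict, Deque, Set, and Type in favor of the concrete stdlib classes, which accept subscripting since Python 3.9. A minimal sketch with illustrative stand-ins for the simulation types (HedgeboxCluster below is a hypothetical subclass, not taken from the patch):

    from collections import defaultdict, deque

    class Cluster: ...
    class HedgeboxCluster(Cluster): ...  # illustrative subclass

    # defaultdict[...] replaces typing.DefaultDict[...]; identical at runtime.
    groups: defaultdict[str, defaultdict[str, dict[str, object]]] = defaultdict(
        lambda: defaultdict(dict)
    )
    # deque[...] replaces typing.Deque[...].
    scheduled_effects: deque[tuple[float, str]] = deque()
    # type[X] replaces typing.Type[X] for "a class object, not an instance".
    CLUSTER_CLASS: type[Cluster] = HedgeboxCluster

    groups["company"]["acme"]["plan"] = "scale"
    scheduled_effects.append((0.0, "signup"))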
timezone: str # Exposed state - present - past_events: List[SimEvent] - future_events: List[SimEvent] + past_events: list[SimEvent] + future_events: list[SimEvent] # Exposed state - at `now` - distinct_ids_at_now: Set[str] + distinct_ids_at_now: set[str] properties_at_now: Properties first_seen_at: Optional[dt.datetime] last_seen_at: Optional[dt.datetime] # Internal state active_client: SimBrowserClient # Client being used by person - all_time_pageview_counts: DefaultDict[str, int] # Pageview count per URL across all time - session_pageview_counts: DefaultDict[str, int] # Pageview count per URL across the ongoing session + all_time_pageview_counts: defaultdict[str, int] # Pageview count per URL across all time + session_pageview_counts: defaultdict[str, int] # Pageview count per URL across the ongoing session active_session_intent: Optional[SimSessionIntent] wake_up_by: dt.datetime - _groups: Dict[str, str] - _distinct_ids: Set[str] + _groups: dict[str, str] + _distinct_ids: set[str] _properties: Properties def __init__(self, *, kernel: bool, cluster: "Cluster", x: int, y: int): @@ -397,7 +392,7 @@ class SimPerson(ABC): # Abstract methods - def decide_feature_flags(self) -> Dict[str, Any]: + def decide_feature_flags(self) -> dict[str, Any]: """Determine feature flags in force at present.""" return {} diff --git a/posthog/demo/matrix/randomization.py b/posthog/demo/matrix/randomization.py index ca6bcfd5886..d017c295321 100644 --- a/posthog/demo/matrix/randomization.py +++ b/posthog/demo/matrix/randomization.py @@ -1,10 +1,9 @@ from enum import Enum -from typing import Dict, List, Tuple import mimesis import mimesis.random -WeightedPool = Tuple[List[str], List[int]] +WeightedPool = tuple[list[str], list[int]] class Industry(str, Enum): @@ -27,12 +26,12 @@ class PropertiesProvider(mimesis.BaseProvider): ["Desktop", "Mobile", "Tablet"], [8, 1, 1], ) - OS_WEIGHTED_POOLS: Dict[str, WeightedPool] = { + OS_WEIGHTED_POOLS: dict[str, WeightedPool] = { "Desktop": (["Windows", "Mac OS X", "Linux", "Chrome OS"], [18, 16, 7, 1]), "Mobile": (["iOS", "Android"], [1, 1]), "Tablet": (["iOS", "Android"], [1, 1]), } - BROWSER_WEIGHTED_POOLS: Dict[str, WeightedPool] = { + BROWSER_WEIGHTED_POOLS: dict[str, WeightedPool] = { "Windows": ( ["Chrome", "Firefox", "Opera", "Microsoft Edge", "Internet Explorer"], [12, 4, 2, 1, 1], @@ -65,7 +64,7 @@ class PropertiesProvider(mimesis.BaseProvider): random: mimesis.random.Random - def device_type_os_browser(self) -> Tuple[str, str, str]: + def device_type_os_browser(self) -> tuple[str, str, str]: device_type_pool, device_type_weights = self.DEVICE_TYPE_WEIGHTED_POOL device_type = self.random.choices(device_type_pool, device_type_weights)[0] os_pool, os_weights = self.OS_WEIGHTED_POOLS[device_type] diff --git a/posthog/demo/matrix/taxonomy_inference.py b/posthog/demo/matrix/taxonomy_inference.py index cc5686de96b..e05dc67f333 100644 --- a/posthog/demo/matrix/taxonomy_inference.py +++ b/posthog/demo/matrix/taxonomy_inference.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Optional, Tuple +from typing import Optional from django.utils import timezone @@ -9,7 +9,7 @@ from posthog.models.person.sql import PERSONS_TABLE from posthog.models.property_definition import PropertyType -def infer_taxonomy_for_team(team_id: int) -> Tuple[int, int, int]: +def infer_taxonomy_for_team(team_id: int) -> tuple[int, int, int]: """Infer event and property definitions based on ClickHouse data. 
In production, the plugin server is responsible for this - but in demo data we insert directly to ClickHouse.
@@ -55,13 +55,13 @@ def infer_taxonomy_for_team(team_id: int) -> Tuple[int, int, int]:
     return len(event_definitions), len(property_definitions), len(event_properties)


-def _get_events_last_seen_at(team_id: int) -> Dict[str, timezone.datetime]:
+def _get_events_last_seen_at(team_id: int) -> dict[str, timezone.datetime]:
     from posthog.client import sync_execute

     return dict(sync_execute(_GET_EVENTS_LAST_SEEN_AT, {"team_id": team_id}))


-def _get_property_types(team_id: int) -> Dict[str, Optional[PropertyType]]:
+def _get_property_types(team_id: int) -> dict[str, Optional[PropertyType]]:
     """Determine property types based on ClickHouse data."""
     from posthog.client import sync_execute

@@ -87,14 +87,14 @@ def _infer_property_type(sample_json_value: str) -> Optional[PropertyType]:
     parsed_value = json.loads(sample_json_value)
     if isinstance(parsed_value, bool):
         return PropertyType.Boolean
-    if isinstance(parsed_value, (float, int)):
+    if isinstance(parsed_value, float | int):
         return PropertyType.Numeric
     if isinstance(parsed_value, str):
         return PropertyType.String
     return None


-def _get_event_property_pairs(team_id: int) -> List[Tuple[str, str]]:
+def _get_event_property_pairs(team_id: int) -> list[tuple[str, str]]:
     """Determine which properties have been seen with which events based on ClickHouse data."""
     from posthog.client import sync_execute

diff --git a/posthog/demo/products/hedgebox/models.py b/posthog/demo/products/hedgebox/models.py
index 324dc6b4737..dd694f64aac 100644
--- a/posthog/demo/products/hedgebox/models.py
+++ b/posthog/demo/products/hedgebox/models.py
@@ -5,11 +5,7 @@ from enum import Enum, auto
 from typing import (
     TYPE_CHECKING,
     Any,
-    Dict,
-    List,
     Optional,
-    Set,
-    Tuple,
     cast,
 )
 from urllib.parse import urlencode, urlparse, urlunparse
@@ -114,9 +110,9 @@ class HedgeboxFile:
 class HedgeboxAccount:
     id: str
     created_at: dt.datetime
-    team_members: Set["HedgeboxPerson"]
+    team_members: set["HedgeboxPerson"]
     plan: HedgeboxPlan
-    files: Set[HedgeboxFile] = field(default_factory=set)
+    files: set[HedgeboxFile] = field(default_factory=set)
     was_billing_scheduled: bool = field(default=False)

     @property
@@ -247,7 +243,7 @@ class HedgeboxPerson(SimPerson):

     # Abstract methods

-    def decide_feature_flags(self) -> Dict[str, Any]:
+    def decide_feature_flags(self) -> dict[str, Any]:
         if (
             self.cluster.simulation_time >= self.cluster.matrix.new_signup_page_experiment_start
             and self.cluster.simulation_time < self.cluster.matrix.new_signup_page_experiment_end
@@ -292,7 +288,7 @@ class HedgeboxPerson(SimPerson):
             # Very low affinity users aren't interested
             # Non-kernel business users can't log in or sign up
             return None
-        possible_intents_with_weights: List[Tuple[HedgeboxSessionIntent, float]] = []
+        possible_intents_with_weights: list[tuple[HedgeboxSessionIntent, float]] = []
         if self.invite_to_use_id:
             possible_intents_with_weights.append((HedgeboxSessionIntent.JOIN_TEAM, 1))
         elif self.file_to_view:
@@ -342,8 +338,8 @@ class HedgeboxPerson(SimPerson):
         if possible_intents_with_weights:
             possible_intents, weights = zip(*possible_intents_with_weights)
             return self.cluster.random.choices(
-                cast(Tuple[HedgeboxSessionIntent], possible_intents),
-                cast(Tuple[float], weights),
+                cast(tuple[HedgeboxSessionIntent], possible_intents),
+                cast(tuple[float], weights),
             )[0]
         else:
             return None
@@ -807,10 +803,10 @@ class HedgeboxPerson(SimPerson):
         self.advance_timer(self.cluster.random.uniform(0.1, 0.2))

     @property
- def invitable_neighbors(self) -> List["HedgeboxPerson"]: + def invitable_neighbors(self) -> list["HedgeboxPerson"]: return [ neighbor - for neighbor in cast(List[HedgeboxPerson], self.cluster.list_neighbors(self)) + for neighbor in cast(list[HedgeboxPerson], self.cluster.list_neighbors(self)) if neighbor.is_invitable ] diff --git a/posthog/email.py b/posthog/email.py index 61edb7ae593..3590723f708 100644 --- a/posthog/email.py +++ b/posthog/email.py @@ -1,5 +1,5 @@ import sys -from typing import Dict, List, Optional +from typing import Optional import lxml import toronado @@ -54,9 +54,9 @@ EMAIL_TASK_KWARGS = { @shared_task(**EMAIL_TASK_KWARGS) def _send_email( campaign_key: str, - to: List[Dict[str, str]], + to: list[dict[str, str]], subject: str, - headers: Dict, + headers: dict, txt_body: str = "", html_body: str = "", reply_to: Optional[str] = None, @@ -65,8 +65,8 @@ def _send_email( Sends built email message asynchronously. """ - messages: List = [] - records: List = [] + messages: list = [] + records: list = [] with transaction.atomic(): for dest in to: @@ -135,8 +135,8 @@ class EmailMessage: campaign_key: str, subject: str, template_name: str, - template_context: Optional[Dict] = None, - headers: Optional[Dict] = None, + template_context: Optional[dict] = None, + headers: Optional[dict] = None, reply_to: Optional[str] = None, ): if template_context is None: @@ -153,7 +153,7 @@ class EmailMessage: self.html_body = inline_css(template.render(template_context)) self.txt_body = "" self.headers = headers if headers else {} - self.to: List[Dict[str, str]] = [] + self.to: list[dict[str, str]] = [] self.reply_to = reply_to def add_recipient(self, email: str, name: Optional[str] = None) -> None: diff --git a/posthog/errors.py b/posthog/errors.py index d028522a599..70b3d46dd3c 100644 --- a/posthog/errors.py +++ b/posthog/errors.py @@ -1,6 +1,6 @@ from dataclasses import dataclass import re -from typing import Dict, Optional +from typing import Optional from clickhouse_driver.errors import ServerException @@ -91,7 +91,7 @@ def look_up_error_code_meta(error: ServerException) -> ErrorCodeMeta: # # Remember to add back the `user_safe` args though! CLICKHOUSE_UNKNOWN_EXCEPTION = ErrorCodeMeta("UNKNOWN_EXCEPTION") -CLICKHOUSE_ERROR_CODE_LOOKUP: Dict[int, ErrorCodeMeta] = { +CLICKHOUSE_ERROR_CODE_LOOKUP: dict[int, ErrorCodeMeta] = { 0: ErrorCodeMeta("OK"), 1: ErrorCodeMeta("UNSUPPORTED_METHOD"), 2: ErrorCodeMeta("UNSUPPORTED_PARAMETER"), diff --git a/posthog/event_usage.py b/posthog/event_usage.py index ae8432c6b27..cf74b599363 100644 --- a/posthog/event_usage.py +++ b/posthog/event_usage.py @@ -2,7 +2,7 @@ Module to centralize event reporting on the server-side. """ -from typing import Dict, List, Optional +from typing import Optional import posthoganalytics @@ -107,7 +107,7 @@ def report_user_logged_in( ) -def report_user_updated(user: User, updated_attrs: List[str]) -> None: +def report_user_updated(user: User, updated_attrs: list[str]) -> None: """ Reports a user has been updated. This includes current_team, current_organization & password. 
""" @@ -217,7 +217,7 @@ def report_user_organization_membership_level_changed( ) -def report_user_action(user: User, event: str, properties: Optional[Dict] = None, team: Optional[Team] = None): +def report_user_action(user: User, event: str, properties: Optional[dict] = None, team: Optional[Team] = None): if properties is None: properties = {} posthoganalytics.capture( @@ -254,8 +254,8 @@ def groups(organization: Optional[Organization] = None, team: Optional[Team] = N def report_team_action( team: Team, event: str, - properties: Optional[Dict] = None, - group_properties: Optional[Dict] = None, + properties: Optional[dict] = None, + group_properties: Optional[dict] = None, ): """ For capturing events where it is unclear which user was the core actor we can use the team instead @@ -271,8 +271,8 @@ def report_team_action( def report_organization_action( organization: Organization, event: str, - properties: Optional[Dict] = None, - group_properties: Optional[Dict] = None, + properties: Optional[dict] = None, + group_properties: Optional[dict] = None, ): """ For capturing events where it is unclear which user was the core actor we can use the organization instead diff --git a/posthog/filters.py b/posthog/filters.py index ac098dea92c..911edcf4596 100644 --- a/posthog/filters.py +++ b/posthog/filters.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, TypeVar, Union +from typing import Optional, TypeVar, Union from django.db import models from django.db.models import Q @@ -19,7 +19,7 @@ class TermSearchFilterBackend(filters.BaseFilterBackend): # The URL query parameter used for the search. search_param = settings.api_settings.SEARCH_PARAM - def get_search_fields(self, view: APIView) -> Optional[List[str]]: + def get_search_fields(self, view: APIView) -> Optional[list[str]]: """ Search fields are obtained from the view. """ @@ -59,10 +59,10 @@ class TermSearchFilterBackend(filters.BaseFilterBackend): def term_search_filter_sql( - search_fields: List[str], + search_fields: list[str], search_terms: Optional[str] = "", search_extra: Optional[str] = "", -) -> Tuple[str, dict]: +) -> tuple[str, dict]: if not search_fields or not search_terms: return "", {} diff --git a/posthog/gzip_middleware.py b/posthog/gzip_middleware.py index 701f31b5dbe..cfd57eea005 100644 --- a/posthog/gzip_middleware.py +++ b/posthog/gzip_middleware.py @@ -1,5 +1,4 @@ import re -from typing import List from django.conf import settings from django.middleware.gzip import GZipMiddleware @@ -9,7 +8,7 @@ class InvalidGZipAllowList(Exception): pass -def allowed_path(path: str, allowed_paths: List) -> bool: +def allowed_path(path: str, allowed_paths: list) -> bool: return any(pattern.search(path) for pattern in allowed_paths) diff --git a/posthog/health.py b/posthog/health.py index 1ca35d6fe73..72012928feb 100644 --- a/posthog/health.py +++ b/posthog/health.py @@ -17,7 +17,8 @@ # changes to them are deliberate, as otherwise we could introduce unexpected # behaviour in deployments. 
-from typing import Callable, Dict, List, Literal, cast, get_args
+from typing import Literal, cast, get_args
+from collections.abc import Callable

 from django.core.cache import cache
 from django.db import DEFAULT_DB_ALIAS
@@ -35,7 +36,7 @@ logger = get_logger(__name__)

 ServiceRole = Literal["events", "web", "worker", "decide"]

-service_dependencies: Dict[ServiceRole, List[str]] = {
+service_dependencies: dict[ServiceRole, list[str]] = {
     "events": ["http", "kafka_connected"],
     "web": [
         "http",
@@ -66,7 +67,7 @@ service_dependencies: Dict[ServiceRole, List[str]] = {

 # if at least one of the checks is True, then the service is considered healthy
 # for the given role
-service_conditional_dependencies: Dict[ServiceRole, List[str]] = {
+service_conditional_dependencies: dict[ServiceRole, list[str]] = {
     "decide": ["cache", "postgres_flags"],
 }

@@ -110,7 +111,7 @@ def readyz(request: HttpRequest):
     if role and role not in get_args(ServiceRole):
         return JsonResponse({"error": "InvalidRole"}, status=400)

-    available_checks: Dict[str, Callable] = {
+    available_checks: dict[str, Callable] = {
         "clickhouse": is_clickhouse_connected,
         "postgres": is_postgres_connected,
         "postgres_flags": lambda: is_postgres_connected(DATABASE_FOR_FLAG_MATCHING),
diff --git a/posthog/heatmaps/heatmaps_api.py b/posthog/heatmaps/heatmaps_api.py
index 35a424d4f15..f06899e3c41 100644
--- a/posthog/heatmaps/heatmaps_api.py
+++ b/posthog/heatmaps/heatmaps_api.py
@@ -1,5 +1,5 @@
 from datetime import datetime, date
-from typing import Any, Dict, List
+from typing import Any, List  # noqa: UP035

 from rest_framework import viewsets, request, response, serializers, status

@@ -80,7 +80,7 @@ class HeatmapsRequestSerializer(serializers.Serializer):
         except Exception:
             raise serializers.ValidationError("Error parsing provided date_from: {}".format(value))

-    def validate(self, values) -> Dict:
+    def validate(self, values) -> dict:
         url_exact = values.get("url_exact", None)
         url_pattern = values.get("url_pattern", None)
         if isinstance(url_exact, str) and isinstance(url_pattern, str):
@@ -154,10 +154,10 @@ class HeatmapViewSet(TeamAndOrgViewSetMixin, viewsets.GenericViewSet):
         return aggregation_count

     @staticmethod
-    def _predicate_expressions(placeholders: Dict[str, Expr]) -> List[ast.Expr]:
-        predicate_expressions: List[ast.Expr] = []
+    def _predicate_expressions(placeholders: dict[str, Expr]) -> List[ast.Expr]:  # noqa: UP006
+        predicate_expressions: list[ast.Expr] = []

-        predicate_mapping: Dict[str, str] = {
+        predicate_mapping: dict[str, str] = {
             # should always have values
             "date_from": "timestamp >= {date_from}",
             "type": "`type` = {type}",
diff --git a/posthog/heatmaps/test/test_heatmaps_api.py b/posthog/heatmaps/test/test_heatmaps_api.py
index 4f0896c5ef1..e07343a2760 100644
--- a/posthog/heatmaps/test/test_heatmaps_api.py
+++ b/posthog/heatmaps/test/test_heatmaps_api.py
@@ -1,5 +1,4 @@
 import math
-from typing import Dict

 import freezegun
 from django.http import HttpResponse
@@ -48,21 +47,21 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest)
     CLASS_DATA_LEVEL_SETUP = False

     def _assert_heatmap_no_result_count(
-        self, params: Dict[str, str | int | None] | None, expected_status_code: int = status.HTTP_200_OK
+        self, params: dict[str, str | int | None] | None, expected_status_code: int = status.HTTP_200_OK
     ) -> None:
         response = self._get_heatmap(params, expected_status_code)
         if response.status_code == status.HTTP_200_OK:
             assert len(response.json()["results"]) == 0

     def _assert_heatmap_single_result_count(
-        self,
params: Dict[str, str | int | None] | None, expected_grouped_count: int + self, params: dict[str, str | int | None] | None, expected_grouped_count: int ) -> None: response = self._get_heatmap(params) assert len(response.json()["results"]) == 1 assert response.json()["results"][0]["count"] == expected_grouped_count def _get_heatmap( - self, params: Dict[str, str | int | None] | None, expected_status_code: int = status.HTTP_200_OK + self, params: dict[str, str | int | None] | None, expected_status_code: int = status.HTTP_200_OK ) -> HttpResponse: if params is None: params = {} diff --git a/posthog/helpers/dashboard_templates.py b/posthog/helpers/dashboard_templates.py index cfaa2bac5e1..0e3f8a81f95 100644 --- a/posthog/helpers/dashboard_templates.py +++ b/posthog/helpers/dashboard_templates.py @@ -1,4 +1,5 @@ -from typing import Callable, Dict, List, Optional +from typing import Optional +from collections.abc import Callable import structlog @@ -28,7 +29,7 @@ from posthog.models.dashboard_tile import DashboardTile, Text from posthog.models.insight import Insight from posthog.models.tag import Tag -DASHBOARD_COLORS: List[str] = ["white", "blue", "green", "purple", "black"] +DASHBOARD_COLORS: list[str] = ["white", "blue", "green", "purple", "black"] logger = structlog.get_logger(__name__) @@ -444,7 +445,7 @@ def _create_default_app_items(dashboard: Dashboard) -> None: create_from_template(dashboard, template) -DASHBOARD_TEMPLATES: Dict[str, Callable] = { +DASHBOARD_TEMPLATES: dict[str, Callable] = { "DEFAULT_APP": _create_default_app_items, "WEBSITE_TRAFFIC": _create_website_dashboard, } @@ -491,7 +492,7 @@ def create_from_template(dashboard: Dashboard, template: DashboardTemplate) -> N logger.error("dashboard_templates.creation.unknown_type", template=template) -def _create_tile_for_text(dashboard: Dashboard, body: str, layouts: Dict, color: Optional[str]) -> None: +def _create_tile_for_text(dashboard: Dashboard, body: str, layouts: dict, color: Optional[str]) -> None: text = Text.objects.create( team=dashboard.team, body=body, @@ -507,11 +508,11 @@ def _create_tile_for_text(dashboard: Dashboard, body: str, layouts: Dict, color: def _create_tile_for_insight( dashboard: Dashboard, name: str, - filters: Dict, + filters: dict, description: str, - layouts: Dict, + layouts: dict, color: Optional[str], - query: Optional[Dict] = None, + query: Optional[dict] = None, ) -> None: filter_test_accounts = filters.get("filter_test_accounts", True) insight = Insight.objects.create( diff --git a/posthog/helpers/multi_property_breakdown.py b/posthog/helpers/multi_property_breakdown.py index edc5fe68f1b..94fc538b295 100644 --- a/posthog/helpers/multi_property_breakdown.py +++ b/posthog/helpers/multi_property_breakdown.py @@ -1,12 +1,12 @@ import copy -from typing import Any, Dict, List, Union +from typing import Any, Union -funnel_with_breakdown_type = List[List[Dict[str, Any]]] -possible_funnel_results_types = Union[funnel_with_breakdown_type, List[Dict[str, Any]], Dict[str, Any]] +funnel_with_breakdown_type = list[list[dict[str, Any]]] +possible_funnel_results_types = Union[funnel_with_breakdown_type, list[dict[str, Any]], dict[str, Any]] def protect_old_clients_from_multi_property_default( - request_filter: Dict[str, Any], result: possible_funnel_results_types + request_filter: dict[str, Any], result: possible_funnel_results_types ) -> possible_funnel_results_types: """ Implementing multi property breakdown will default breakdown to a list even if it is received as a string. 
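The isinstance rewrites in the hunks below, like the float | int form in taxonomy_inference.py earlier, are behavior-preserving rather than purely stylistic. A short sketch of the runtime equivalences, assuming a Python 3.10+ interpreter (PEP 604 unions in isinstance require 3.10):

from typing import Dict

# Bare typing aliases are accepted by isinstance, so the Dict -> dict swap changes nothing at runtime
assert isinstance({}, Dict) and isinstance({}, dict)

# A tuple of types and a PEP 604 union behave identically in isinstance
assert isinstance(1.5, (float, int)) == isinstance(1.5, float | int)

# Parameterized generics are rejected in either spelling:
# isinstance({}, dict[str, int])  # TypeError: isinstance() argument 2 cannot be a parameterized generic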
@@ -25,7 +25,7 @@ def protect_old_clients_from_multi_property_default( :return: """ - if isinstance(result, Dict) or (len(result) > 1) and isinstance(result[0], Dict): + if isinstance(result, dict) or (len(result) > 1) and isinstance(result[0], dict): return result is_breakdown_request = ( @@ -34,7 +34,7 @@ def protect_old_clients_from_multi_property_default( and "breakdown_type" in request_filter and request_filter["breakdown_type"] in ["person", "event"] ) - is_breakdown_result = isinstance(result, List) and len(result) > 0 and isinstance(result[0], List) + is_breakdown_result = isinstance(result, list) and len(result) > 0 and isinstance(result[0], list) is_single_property_breakdown = ( is_breakdown_request @@ -49,14 +49,14 @@ def protect_old_clients_from_multi_property_default( for series_index in range(len(result)): copied_series = copied_result[series_index] - if isinstance(copied_series, List): + if isinstance(copied_series, list): for data_index in range(len(copied_series)): copied_item = copied_series[data_index] if is_single_property_breakdown: - if copied_item.get("breakdown") and isinstance(copied_item["breakdown"], List): + if copied_item.get("breakdown") and isinstance(copied_item["breakdown"], list): copied_item["breakdown"] = copied_item["breakdown"][0] - if copied_item.get("breakdown_value") and isinstance(copied_item["breakdown_value"], List): + if copied_item.get("breakdown_value") and isinstance(copied_item["breakdown_value"], list): copied_item["breakdown_value"] = copied_item["breakdown_value"][0] if is_multi_property_breakdown: diff --git a/posthog/helpers/tests/test_multi_property_breakdown.py b/posthog/helpers/tests/test_multi_property_breakdown.py index d22675adf84..417583ae009 100644 --- a/posthog/helpers/tests/test_multi_property_breakdown.py +++ b/posthog/helpers/tests/test_multi_property_breakdown.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any from unittest import TestCase from posthog.helpers.multi_property_breakdown import ( @@ -8,8 +8,8 @@ from posthog.helpers.multi_property_breakdown import ( class TestMultiPropertyBreakdown(TestCase): def test_handles_empty_inputs(self): - data: Dict[str, Any] = {} - result: List = [] + data: dict[str, Any] = {} + result: list = [] try: protect_old_clients_from_multi_property_default(data, result) @@ -17,12 +17,12 @@ class TestMultiPropertyBreakdown(TestCase): raise AssertionError("should not raise any KeyError") def test_handles_empty_breakdowns_array(self): - data: Dict[str, Any] = { + data: dict[str, Any] = { "breakdowns": [], "insight": "FUNNELS", "breakdown_type": "event", } - result: List = [] + result: list = [] try: protect_old_clients_from_multi_property_default(data, result) @@ -30,37 +30,37 @@ class TestMultiPropertyBreakdown(TestCase): raise AssertionError("should not raise any KeyError") def test_keeps_multi_property_breakdown_for_multi_property_requests(self): - data: Dict[str, Any] = { + data: dict[str, Any] = { "breakdowns": ["a", "b"], "insight": "FUNNELS", "breakdown_type": "event", } - result: List[List[Dict[str, Any]]] = [[{"breakdown": ["a1", "b1"], "breakdown_value": ["a1", "b1"]}]] + result: list[list[dict[str, Any]]] = [[{"breakdown": ["a1", "b1"], "breakdown_value": ["a1", "b1"]}]] actual = protect_old_clients_from_multi_property_default(data, result) # to satisfy mypy - assert isinstance(actual, List) + assert isinstance(actual, list) series = actual[0] - assert isinstance(series, List) + assert isinstance(series, list) data = series[0] assert data["breakdowns"] == 
["a1", "b1"] assert "breakdown" not in data def test_flattens_multi_property_breakdown_for_single_property_requests(self): - data: Dict[str, Any] = { + data: dict[str, Any] = { "breakdown": "a", "insight": "FUNNELS", "breakdown_type": "event", } - result: List[List[Dict[str, Any]]] = [[{"breakdown": ["a1"], "breakdown_value": ["a1", "b1"]}]] + result: list[list[dict[str, Any]]] = [[{"breakdown": ["a1"], "breakdown_value": ["a1", "b1"]}]] actual = protect_old_clients_from_multi_property_default(data, result) # to satisfy mypy - assert isinstance(actual, List) + assert isinstance(actual, list) series = actual[0] - assert isinstance(series, List) + assert isinstance(series, list) data = series[0] assert data["breakdown"] == "a1" assert "breakdowns" not in data diff --git a/posthog/hogql/ai.py b/posthog/hogql/ai.py index 15b03e82e50..71a565ec777 100644 --- a/posthog/hogql/ai.py +++ b/posthog/hogql/ai.py @@ -63,7 +63,7 @@ def write_sql_from_prompt(prompt: str, *, current_query: Optional[str] = None, t schema_description = "\n\n".join( ( f"Table {table_name} with fields:\n" - + "\n".join((f'- {field["key"]} ({field["type"]})' for field in table_fields)) + + "\n".join(f'- {field["key"]} ({field["type"]})' for field in table_fields) for table_name, table_fields in serialized_database.items() ) ) diff --git a/posthog/hogql/ast.py b/posthog/hogql/ast.py index ccb3f9f3457..e3fa80b3f3e 100644 --- a/posthog/hogql/ast.py +++ b/posthog/hogql/ast.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from dataclasses import dataclass, field from posthog.hogql.base import Type, Expr, CTE, ConstantType, UnknownType, AST @@ -143,14 +143,14 @@ class SelectQueryType(Type): """Type and new enclosed scope for a select query. 
Contains information about all tables and columns in the query.""" # all aliases a select query has access to in its scope - aliases: Dict[str, FieldAliasType] = field(default_factory=dict) + aliases: dict[str, FieldAliasType] = field(default_factory=dict) # all types a select query exports - columns: Dict[str, Type] = field(default_factory=dict) + columns: dict[str, Type] = field(default_factory=dict) # all from and join, tables and subqueries with aliases - tables: Dict[str, TableOrSelectType] = field(default_factory=dict) - ctes: Dict[str, CTE] = field(default_factory=dict) + tables: dict[str, TableOrSelectType] = field(default_factory=dict) + ctes: dict[str, CTE] = field(default_factory=dict) # all from and join subqueries without aliases - anonymous_tables: List[Union["SelectQueryType", "SelectUnionQueryType"]] = field(default_factory=list) + anonymous_tables: list[Union["SelectQueryType", "SelectUnionQueryType"]] = field(default_factory=list) # the parent select query, if this is a lambda parent: Optional[Union["SelectQueryType", "SelectUnionQueryType"]] = None @@ -173,7 +173,7 @@ class SelectQueryType(Type): @dataclass(kw_only=True) class SelectUnionQueryType(Type): - types: List[SelectQueryType] + types: list[SelectQueryType] def get_alias_for_table_type(self, table_type: TableOrSelectType) -> Optional[str]: return self.types[0].get_alias_for_table_type(table_type) @@ -313,7 +313,7 @@ class ArrayType(ConstantType): @dataclass(kw_only=True) class TupleType(ConstantType): data_type: ConstantDataType = field(default="tuple", init=False) - item_types: List[ConstantType] + item_types: list[ConstantType] def print_type(self) -> str: return "Tuple" @@ -322,8 +322,8 @@ class TupleType(ConstantType): @dataclass(kw_only=True) class CallType(Type): name: str - arg_types: List[ConstantType] - param_types: Optional[List[ConstantType]] = None + arg_types: list[ConstantType] + param_types: Optional[list[ConstantType]] = None return_type: ConstantType def resolve_constant_type(self, context: HogQLContext) -> ConstantType: @@ -337,7 +337,7 @@ class AsteriskType(Type): @dataclass(kw_only=True) class FieldTraverserType(Type): - chain: List[str | int] + chain: list[str | int] table_type: TableOrSelectType @@ -400,7 +400,7 @@ class FieldType(Type): @dataclass(kw_only=True) class PropertyType(Type): - chain: List[str | int] + chain: list[str | int] field_type: FieldType # The property has been moved into a field we query from a joined subquery @@ -449,12 +449,12 @@ class ArithmeticOperation(Expr): @dataclass(kw_only=True) class And(Expr): type: Optional[ConstantType] = None - exprs: List[Expr] + exprs: list[Expr] @dataclass(kw_only=True) class Or(Expr): - exprs: List[Expr] + exprs: list[Expr] type: Optional[ConstantType] = None @@ -509,7 +509,7 @@ class ArrayAccess(Expr): @dataclass(kw_only=True) class Array(Expr): - exprs: List[Expr] + exprs: list[Expr] @dataclass(kw_only=True) @@ -520,12 +520,12 @@ class TupleAccess(Expr): @dataclass(kw_only=True) class Tuple(Expr): - exprs: List[Expr] + exprs: list[Expr] @dataclass(kw_only=True) class Lambda(Expr): - args: List[str] + args: list[str] expr: Expr @@ -536,7 +536,7 @@ class Constant(Expr): @dataclass(kw_only=True) class Field(Expr): - chain: List[str | int] + chain: list[str | int] @dataclass(kw_only=True) @@ -548,8 +548,8 @@ class Placeholder(Expr): class Call(Expr): name: str """Function name""" - args: List[Expr] - params: Optional[List[Expr]] = None + args: list[Expr] + params: Optional[list[Expr]] = None """ Parameters apply to some aggregate 
functions, see ClickHouse docs: https://clickhouse.com/docs/en/sql-reference/aggregate-functions/parametric-functions @@ -569,7 +569,7 @@ class JoinExpr(Expr): join_type: Optional[str] = None table: Optional[Union["SelectQuery", "SelectUnionQuery", Field]] = None - table_args: Optional[List[Expr]] = None + table_args: Optional[list[Expr]] = None alias: Optional[str] = None table_final: Optional[bool] = None constraint: Optional["JoinConstraint"] = None @@ -585,8 +585,8 @@ class WindowFrameExpr(Expr): @dataclass(kw_only=True) class WindowExpr(Expr): - partition_by: Optional[List[Expr]] = None - order_by: Optional[List[OrderExpr]] = None + partition_by: Optional[list[Expr]] = None + order_by: Optional[list[OrderExpr]] = None frame_method: Optional[Literal["ROWS", "RANGE"]] = None frame_start: Optional[WindowFrameExpr] = None frame_end: Optional[WindowFrameExpr] = None @@ -595,7 +595,7 @@ class WindowExpr(Expr): @dataclass(kw_only=True) class WindowFunction(Expr): name: str - args: Optional[List[Expr]] = None + args: Optional[list[Expr]] = None over_expr: Optional[WindowExpr] = None over_identifier: Optional[str] = None @@ -604,20 +604,20 @@ class WindowFunction(Expr): class SelectQuery(Expr): # :TRICKY: When adding new fields, make sure they're handled in visitor.py and resolver.py type: Optional[SelectQueryType] = None - ctes: Optional[Dict[str, CTE]] = None - select: List[Expr] + ctes: Optional[dict[str, CTE]] = None + select: list[Expr] distinct: Optional[bool] = None select_from: Optional[JoinExpr] = None array_join_op: Optional[str] = None - array_join_list: Optional[List[Expr]] = None - window_exprs: Optional[Dict[str, WindowExpr]] = None + array_join_list: Optional[list[Expr]] = None + window_exprs: Optional[dict[str, WindowExpr]] = None where: Optional[Expr] = None prewhere: Optional[Expr] = None having: Optional[Expr] = None - group_by: Optional[List[Expr]] = None - order_by: Optional[List[OrderExpr]] = None + group_by: Optional[list[Expr]] = None + order_by: Optional[list[OrderExpr]] = None limit: Optional[Expr] = None - limit_by: Optional[List[Expr]] = None + limit_by: Optional[list[Expr]] = None limit_with_ties: Optional[bool] = None offset: Optional[Expr] = None settings: Optional[HogQLQuerySettings] = None @@ -627,7 +627,7 @@ class SelectQuery(Expr): @dataclass(kw_only=True) class SelectUnionQuery(Expr): type: Optional[SelectUnionQueryType] = None - select_queries: List[SelectQuery] + select_queries: list[SelectQuery] @dataclass(kw_only=True) @@ -652,7 +652,7 @@ class HogQLXAttribute(AST): @dataclass(kw_only=True) class HogQLXTag(AST): kind: str - attributes: List[HogQLXAttribute] + attributes: list[HogQLXAttribute] def to_dict(self): return { diff --git a/posthog/hogql/autocomplete.py b/posthog/hogql/autocomplete.py index b6d003c1ac8..c0d4cd8b84f 100644 --- a/posthog/hogql/autocomplete.py +++ b/posthog/hogql/autocomplete.py @@ -1,5 +1,6 @@ from copy import copy, deepcopy -from typing import Callable, Dict, List, Optional, cast +from typing import Optional, cast +from collections.abc import Callable from posthog.hogql.context import HogQLContext from posthog.hogql.database.database import Database, create_hogql_database from posthog.hogql.database.models import ( @@ -38,7 +39,7 @@ from posthog.schema import ( class GetNodeAtPositionTraverser(TraversingVisitor): start: int end: int - selects: List[ast.SelectQuery] = [] + selects: list[ast.SelectQuery] = [] node: Optional[AST] = None parent_node: Optional[AST] = None last_node: Optional[AST] = None @@ -100,13 +101,13 @@ def 
convert_field_or_table_to_type_string(field_or_table: FieldOrTable) -> str | return "Object" if isinstance(field_or_table, ast.ExpressionField): return "Expression" - if isinstance(field_or_table, (ast.Table, ast.LazyJoin)): + if isinstance(field_or_table, ast.Table | ast.LazyJoin): return "Table" return None -def get_table(context: HogQLContext, join_expr: ast.JoinExpr, ctes: Optional[Dict[str, CTE]]) -> None | Table: +def get_table(context: HogQLContext, join_expr: ast.JoinExpr, ctes: Optional[dict[str, CTE]]) -> None | Table: assert context.database is not None def resolve_fields_on_table(table: Table | None, table_query: ast.SelectQuery) -> Table | None: @@ -120,7 +121,7 @@ def get_table(context: HogQLContext, join_expr: ast.JoinExpr, ctes: Optional[Dic return None selected_columns = node.type.columns - new_fields: Dict[str, FieldOrTable] = {} + new_fields: dict[str, FieldOrTable] = {} for name, field in selected_columns.items(): if isinstance(field, ast.FieldAliasType): underlying_field_name = field.alias @@ -145,7 +146,7 @@ def get_table(context: HogQLContext, join_expr: ast.JoinExpr, ctes: Optional[Dic # Return a new table with a reduced field set class AnonTable(Table): - fields: Dict[str, FieldOrTable] = new_fields + fields: dict[str, FieldOrTable] = new_fields def to_printed_hogql(self): # Use the base table name for resolving property definitions later @@ -184,8 +185,8 @@ def get_table(context: HogQLContext, join_expr: ast.JoinExpr, ctes: Optional[Dic return None -def get_tables_aliases(query: ast.SelectQuery, context: HogQLContext) -> Dict[str, ast.Table]: - tables: Dict[str, ast.Table] = {} +def get_tables_aliases(query: ast.SelectQuery, context: HogQLContext) -> dict[str, ast.Table]: + tables: dict[str, ast.Table] = {} if query.select_from is not None and query.select_from.alias is not None: table = get_table(context, query.select_from, query.ctes) @@ -207,7 +208,7 @@ def get_tables_aliases(query: ast.SelectQuery, context: HogQLContext) -> Dict[st # Replaces all ast.FieldTraverser with the underlying node def resolve_table_field_traversers(table: Table, context: HogQLContext) -> Table: new_table = deepcopy(table) - new_fields: Dict[str, FieldOrTable] = {} + new_fields: dict[str, FieldOrTable] = {} for key, field in list(new_table.fields.items()): if not isinstance(field, ast.FieldTraverser): new_fields[key] = field @@ -234,9 +235,9 @@ def resolve_table_field_traversers(table: Table, context: HogQLContext) -> Table return new_table -def append_table_field_to_response(table: Table, suggestions: List[AutocompleteCompletionItem]) -> None: - keys: List[str] = [] - details: List[str | None] = [] +def append_table_field_to_response(table: Table, suggestions: list[AutocompleteCompletionItem]) -> None: + keys: list[str] = [] + details: list[str | None] = [] table_fields = list(table.fields.items()) for field_name, field_or_table in table_fields: # Skip over hidden fields @@ -258,11 +259,11 @@ def append_table_field_to_response(table: Table, suggestions: List[AutocompleteC def extend_responses( - keys: List[str], - suggestions: List[AutocompleteCompletionItem], + keys: list[str], + suggestions: list[AutocompleteCompletionItem], kind: Kind = Kind.Variable, insert_text: Optional[Callable[[str], str]] = None, - details: Optional[List[str | None]] = None, + details: Optional[list[str | None]] = None, ) -> None: suggestions.extend( [ diff --git a/posthog/hogql/bytecode.py b/posthog/hogql/bytecode.py index 2be5c206cf3..f1abb9c4be0 100644 --- a/posthog/hogql/bytecode.py +++ 
b/posthog/hogql/bytecode.py
@@ -1,4 +1,4 @@
-from typing import List, Any
+from typing import Any

 from posthog.hogql import ast
 from posthog.hogql.errors import NotImplementedError
@@ -39,13 +39,13 @@ ARITHMETIC_OPERATIONS = {
 }


-def to_bytecode(expr: str) -> List[Any]:
+def to_bytecode(expr: str) -> list[Any]:
     from posthog.hogql.parser import parse_expr

     return create_bytecode(parse_expr(expr))


-def create_bytecode(expr: ast.Expr) -> List[Any]:
+def create_bytecode(expr: ast.Expr) -> list[Any]:
     bytecode = [HOGQL_BYTECODE_IDENTIFIER]
     bytecode.extend(BytecodeBuilder().visit(expr))
     return bytecode
diff --git a/posthog/hogql/constants.py b/posthog/hogql/constants.py
index 3d933bca47e..45c5b1e034c 100644
--- a/posthog/hogql/constants.py
+++ b/posthog/hogql/constants.py
@@ -1,6 +1,6 @@
 from datetime import date, datetime
 from enum import Enum
-from typing import Optional, Literal, TypeAlias, Tuple, List
+from typing import Optional, Literal, TypeAlias
 from uuid import UUID

 from pydantic import ConfigDict, BaseModel
@@ -18,7 +18,7 @@ ConstantDataType: TypeAlias = Literal[
 ]
 ConstantSupportedPrimitive: TypeAlias = int | float | str | bool | date | datetime | UUID | None
 ConstantSupportedData: TypeAlias = (
-    ConstantSupportedPrimitive | List[ConstantSupportedPrimitive] | Tuple[ConstantSupportedPrimitive, ...]
+    ConstantSupportedPrimitive | list[ConstantSupportedPrimitive] | tuple[ConstantSupportedPrimitive, ...]
 )

 # Keywords passed to ClickHouse without transformation
diff --git a/posthog/hogql/context.py b/posthog/hogql/context.py
index 68692323e05..9b5b6092a69 100644
--- a/posthog/hogql/context.py
+++ b/posthog/hogql/context.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Any
+from typing import TYPE_CHECKING, Literal, Optional, Any

 from posthog.hogql.timings import HogQLTimings
 from posthog.schema import HogQLNotice, HogQLQueryModifiers
@@ -11,7 +11,7 @@ if TYPE_CHECKING:

 @dataclass
 class HogQLFieldAccess:
-    input: List[str]
+    input: list[str]
     type: Optional[Literal["event", "event.properties", "person", "person.properties"]]
     field: Optional[str]
     sql: str
@@ -28,7 +28,7 @@ class HogQLContext:
     # Virtual database we're querying, will be populated from team_id if not present
     database: Optional["Database"] = None
     # If set, will save string constants to this dict. Inlines strings into the query if None.
-    values: Dict = field(default_factory=dict)
+    values: dict = field(default_factory=dict)
     # Are we a small part of a non-HogQL query? If so, use custom syntax for accessed person properties.
within_non_hogql_query: bool = False # Enable full SELECT queries and subqueries in ClickHouse @@ -39,9 +39,9 @@ class HogQLContext: max_view_depth: int = 1 # Warnings returned with the metadata query - warnings: List["HogQLNotice"] = field(default_factory=list) + warnings: list["HogQLNotice"] = field(default_factory=list) # Notices returned with the metadata query - notices: List["HogQLNotice"] = field(default_factory=list) + notices: list["HogQLNotice"] = field(default_factory=list) # Timings in seconds for different parts of the HogQL query timings: HogQLTimings = field(default_factory=HogQLTimings) # Modifications requested by the HogQL client diff --git a/posthog/hogql/database/argmax.py b/posthog/hogql/database/argmax.py index 5872dc77d8b..b6c8e3d853b 100644 --- a/posthog/hogql/database/argmax.py +++ b/posthog/hogql/database/argmax.py @@ -1,10 +1,11 @@ -from typing import Callable, List, Optional, Dict +from typing import Optional +from collections.abc import Callable def argmax_select( table_name: str, - select_fields: Dict[str, List[str | int]], - group_fields: List[str], + select_fields: dict[str, list[str | int]], + group_fields: list[str], argmax_field: str, deleted_field: Optional[str] = None, ): @@ -14,8 +15,8 @@ def argmax_select( name="argMax", args=[field, ast.Field(chain=[table_name, argmax_field])] ) - fields_to_group: List[ast.Expr] = [] - fields_to_select: List[ast.Expr] = [] + fields_to_group: list[ast.Expr] = [] + fields_to_select: list[ast.Expr] = [] for name, chain in select_fields.items(): if name not in group_fields: fields_to_select.append( diff --git a/posthog/hogql/database/database.py b/posthog/hogql/database/database.py index 74837d21f49..fc1f665cf39 100644 --- a/posthog/hogql/database/database.py +++ b/posthog/hogql/database/database.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional, TypedDict +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, TypedDict from zoneinfo import ZoneInfo, ZoneInfoNotFoundError from pydantic import ConfigDict, BaseModel from sentry_sdk import capture_exception @@ -96,7 +96,7 @@ class Database(BaseModel): numbers: NumbersTable = NumbersTable() # clunky: keep table names in sync with above - _table_names: ClassVar[List[str]] = [ + _table_names: ClassVar[list[str]] = [ "events", "groups", "persons", @@ -109,7 +109,7 @@ class Database(BaseModel): "sessions", ] - _warehouse_table_names: List[str] = [] + _warehouse_table_names: list[str] = [] _timezone: Optional[str] _week_start_day: Optional[WeekStartDay] @@ -136,7 +136,7 @@ class Database(BaseModel): return getattr(self, table_name) raise QueryError(f'Unknown table "{table_name}".') - def get_all_tables(self) -> List[str]: + def get_all_tables(self) -> list[str]: return self._table_names + self._warehouse_table_names def add_warehouse_tables(self, **field_definitions: Any): @@ -226,7 +226,7 @@ def create_hogql_database( if database.events.fields.get(mapping.group_type) is None: database.events.fields[mapping.group_type] = FieldTraverser(chain=[f"group_{mapping.group_type_index}"]) - tables: Dict[str, Table] = {} + tables: dict[str, Table] = {} for table in DataWarehouseTable.objects.filter(team_id=team.pk).exclude(deleted=True): tables[table.name] = table.hogql_definition() @@ -362,35 +362,35 @@ class _SerializedFieldBase(TypedDict): class SerializedField(_SerializedFieldBase, total=False): - fields: List[str] + fields: list[str] table: str - chain: List[str | int] + chain: list[str | int] -def 
serialize_database(context: HogQLContext) -> Dict[str, List[SerializedField]]: - tables: Dict[str, List[SerializedField]] = {} +def serialize_database(context: HogQLContext) -> dict[str, list[SerializedField]]: + tables: dict[str, list[SerializedField]] = {} if context.database is None: raise ResolutionError("Must provide database to serialize_database") for table_key in context.database.model_fields.keys(): - field_input: Dict[str, Any] = {} + field_input: dict[str, Any] = {} table = getattr(context.database, table_key, None) if isinstance(table, FunctionCallTable): field_input = table.get_asterisk() elif isinstance(table, Table): field_input = table.fields - field_output: List[SerializedField] = serialize_fields(field_input, context) + field_output: list[SerializedField] = serialize_fields(field_input, context) tables[table_key] = field_output return tables -def serialize_fields(field_input, context: HogQLContext) -> List[SerializedField]: +def serialize_fields(field_input, context: HogQLContext) -> list[SerializedField]: from posthog.hogql.database.models import SavedQuery - field_output: List[SerializedField] = [] + field_output: list[SerializedField] = [] for field_key, field in field_input.items(): if field_key == "team_id": pass diff --git a/posthog/hogql/database/models.py b/posthog/hogql/database/models.py index f6e985d92b4..34bec54eca3 100644 --- a/posthog/hogql/database/models.py +++ b/posthog/hogql/database/models.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING +from collections.abc import Callable from pydantic import ConfigDict, BaseModel from posthog.hogql.base import Expr @@ -65,11 +66,11 @@ class ExpressionField(DatabaseField): class FieldTraverser(FieldOrTable): model_config = ConfigDict(extra="forbid") - chain: List[str | int] + chain: list[str | int] class Table(FieldOrTable): - fields: Dict[str, FieldOrTable] + fields: dict[str, FieldOrTable] model_config = ConfigDict(extra="forbid") def has_field(self, name: str | int) -> bool: @@ -87,12 +88,12 @@ class Table(FieldOrTable): def to_printed_hogql(self) -> str: raise NotImplementedError("Table.to_printed_hogql not overridden") - def avoid_asterisk_fields(self) -> List[str]: + def avoid_asterisk_fields(self) -> list[str]: return [] def get_asterisk(self): fields_to_avoid = [*self.avoid_asterisk_fields(), "team_id"] - asterisk: Dict[str, FieldOrTable] = {} + asterisk: dict[str, FieldOrTable] = {} for key, field in self.fields.items(): if key in fields_to_avoid: continue @@ -109,10 +110,10 @@ class Table(FieldOrTable): class LazyJoin(FieldOrTable): model_config = ConfigDict(extra="forbid") - join_function: Callable[[str, str, Dict[str, Any], "HogQLContext", "SelectQuery"], Any] + join_function: Callable[[str, str, dict[str, Any], "HogQLContext", "SelectQuery"], Any] join_table: Table | str - from_field: List[str | int] - to_field: Optional[List[str | int]] = None + from_field: list[str | int] + to_field: Optional[list[str | int]] = None def resolve_table(self, context: "HogQLContext") -> Table: if isinstance(self.join_table, Table): @@ -132,7 +133,7 @@ class LazyTable(Table): model_config = ConfigDict(extra="forbid") def lazy_select( - self, requested_fields: Dict[str, List[str | int]], context: "HogQLContext", node: "SelectQuery" + self, requested_fields: dict[str, list[str | int]], context: "HogQLContext", node: "SelectQuery" ) -> Any: raise NotImplementedError("LazyTable.lazy_select not overridden") diff --git 
a/posthog/hogql/database/schema/cohort_people.py b/posthog/hogql/database/schema/cohort_people.py index c556903d40c..255779aef59 100644 --- a/posthog/hogql/database/schema/cohort_people.py +++ b/posthog/hogql/database/schema/cohort_people.py @@ -1,5 +1,3 @@ -from typing import Dict, List - from posthog.hogql.database.models import ( StringDatabaseField, IntegerDatabaseField, @@ -22,7 +20,7 @@ COHORT_PEOPLE_FIELDS = { } -def select_from_cohort_people_table(requested_fields: Dict[str, List[str | int]], team_id: int): +def select_from_cohort_people_table(requested_fields: dict[str, list[str | int]], team_id: int): from posthog.hogql import ast from posthog.models import Cohort @@ -39,7 +37,7 @@ def select_from_cohort_people_table(requested_fields: Dict[str, List[str | int]] if "cohort_id" not in requested_fields: requested_fields = {**requested_fields, "cohort_id": ["cohort_id"]} - fields: List[ast.Expr] = [ + fields: list[ast.Expr] = [ ast.Alias(alias=name, expr=ast.Field(chain=[table_name, *chain])) for name, chain in requested_fields.items() ] @@ -60,7 +58,7 @@ def select_from_cohort_people_table(requested_fields: Dict[str, List[str | int]] class RawCohortPeople(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { **COHORT_PEOPLE_FIELDS, "sign": IntegerDatabaseField(name="sign"), "version": IntegerDatabaseField(name="version"), @@ -74,9 +72,9 @@ class RawCohortPeople(Table): class CohortPeople(LazyTable): - fields: Dict[str, FieldOrTable] = COHORT_PEOPLE_FIELDS + fields: dict[str, FieldOrTable] = COHORT_PEOPLE_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): + def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node): return select_from_cohort_people_table(requested_fields, context.team_id) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/event_sessions.py b/posthog/hogql/database/schema/event_sessions.py index 31682981ea3..fc03357884a 100644 --- a/posthog/hogql/database/schema/event_sessions.py +++ b/posthog/hogql/database/schema/event_sessions.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, Dict, List, Optional +from typing import Any, Optional from posthog.hogql import ast from posthog.hogql.context import HogQLContext from posthog.hogql.database.models import ( @@ -14,7 +14,7 @@ from posthog.hogql.visitor import CloningVisitor, TraversingVisitor class EventsSessionSubTable(VirtualTable): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { "id": StringDatabaseField(name="$session_id"), "duration": IntegerDatabaseField(name="session_duration"), } @@ -27,7 +27,7 @@ class EventsSessionSubTable(VirtualTable): class GetFieldsTraverser(TraversingVisitor): - fields: List[ast.Field] + fields: list[ast.Field] def __init__(self, expr: ast.Expr): super().__init__() @@ -71,7 +71,7 @@ class ContainsLazyJoinType(TraversingVisitor): class WhereClauseExtractor: - compare_operators: List[ast.Expr] + compare_operators: list[ast.Expr] def __init__( self, @@ -123,10 +123,10 @@ class WhereClauseExtractor: return True - def run(self, expr: ast.Expr) -> List[ast.Expr]: - exprs_to_apply: List[ast.Expr] = [] + def run(self, expr: ast.Expr) -> list[ast.Expr]: + exprs_to_apply: list[ast.Expr] = [] - def should_add(expression: ast.Expr, fields: List[ast.Field]) -> bool: + def should_add(expression: ast.Expr, fields: list[ast.Field]) -> bool: for field in fields: on_table = self._is_field_on_table(field) if not on_table: @@ -168,7 
+168,7 @@ class WhereClauseExtractor: def join_with_events_table_session_duration( from_table: str, to_table: str, - requested_fields: Dict[str, Any], + requested_fields: dict[str, Any], context: HogQLContext, node: ast.SelectQuery, ): diff --git a/posthog/hogql/database/schema/events.py b/posthog/hogql/database/schema/events.py index 88f59a11fd7..34941de0ec9 100644 --- a/posthog/hogql/database/schema/events.py +++ b/posthog/hogql/database/schema/events.py @@ -1,5 +1,3 @@ -from typing import Dict - from posthog.hogql.database.models import ( VirtualTable, StringDatabaseField, @@ -20,7 +18,7 @@ from posthog.hogql.database.schema.sessions import join_events_table_to_sessions class EventsPersonSubTable(VirtualTable): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { "id": StringDatabaseField(name="person_id"), "created_at": DateTimeDatabaseField(name="person_created_at"), "properties": StringJSONDatabaseField(name="person_properties"), @@ -54,7 +52,7 @@ class EventsGroupSubTable(VirtualTable): class EventsTable(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { "uuid": StringDatabaseField(name="uuid"), "event": StringDatabaseField(name="event"), "properties": StringJSONDatabaseField(name="properties"), diff --git a/posthog/hogql/database/schema/groups.py b/posthog/hogql/database/schema/groups.py index ad97ff7eb08..06fc40560b7 100644 --- a/posthog/hogql/database/schema/groups.py +++ b/posthog/hogql/database/schema/groups.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any from posthog.hogql.ast import SelectQuery from posthog.hogql.context import HogQLContext @@ -24,7 +24,7 @@ GROUPS_TABLE_FIELDS = { } -def select_from_groups_table(requested_fields: Dict[str, List[str | int]]): +def select_from_groups_table(requested_fields: dict[str, list[str | int]]): return argmax_select( table_name="raw_groups", select_fields=requested_fields, @@ -37,7 +37,7 @@ def join_with_group_n_table(group_index: int): def join_with_group_table( from_table: str, to_table: str, - requested_fields: Dict[str, Any], + requested_fields: dict[str, Any], context: HogQLContext, node: SelectQuery, ): @@ -70,7 +70,7 @@ def join_with_group_n_table(group_index: int): class RawGroupsTable(Table): - fields: Dict[str, FieldOrTable] = GROUPS_TABLE_FIELDS + fields: dict[str, FieldOrTable] = GROUPS_TABLE_FIELDS def to_printed_clickhouse(self, context): return "groups" @@ -80,9 +80,9 @@ class RawGroupsTable(Table): class GroupsTable(LazyTable): - fields: Dict[str, FieldOrTable] = GROUPS_TABLE_FIELDS + fields: dict[str, FieldOrTable] = GROUPS_TABLE_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): + def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node): return select_from_groups_table(requested_fields) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/heatmaps.py b/posthog/hogql/database/schema/heatmaps.py index 6041926f536..959117baef8 100644 --- a/posthog/hogql/database/schema/heatmaps.py +++ b/posthog/hogql/database/schema/heatmaps.py @@ -1,5 +1,3 @@ -from typing import Dict - from posthog.hogql.database.models import ( StringDatabaseField, DateTimeDatabaseField, @@ -11,7 +9,7 @@ from posthog.hogql.database.models import ( class HeatmapsTable(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { "session_id": StringDatabaseField(name="session_id"), "team_id": IntegerDatabaseField(name="team_id"), "distinct_id": 
StringDatabaseField(name="distinct_id"), diff --git a/posthog/hogql/database/schema/log_entries.py b/posthog/hogql/database/schema/log_entries.py index 14efaff09ce..edd2f761981 100644 --- a/posthog/hogql/database/schema/log_entries.py +++ b/posthog/hogql/database/schema/log_entries.py @@ -1,5 +1,3 @@ -from typing import Dict, List - from posthog.hogql import ast from posthog.hogql.database.models import ( Table, @@ -10,7 +8,7 @@ from posthog.hogql.database.models import ( FieldOrTable, ) -LOG_ENTRIES_FIELDS: Dict[str, FieldOrTable] = { +LOG_ENTRIES_FIELDS: dict[str, FieldOrTable] = { "team_id": IntegerDatabaseField(name="team_id"), "log_source": StringDatabaseField(name="log_source"), "log_source_id": StringDatabaseField(name="log_source_id"), @@ -22,7 +20,7 @@ LOG_ENTRIES_FIELDS: Dict[str, FieldOrTable] = { class LogEntriesTable(Table): - fields: Dict[str, FieldOrTable] = LOG_ENTRIES_FIELDS + fields: dict[str, FieldOrTable] = LOG_ENTRIES_FIELDS def to_printed_clickhouse(self, context): return "log_entries" @@ -32,10 +30,10 @@ class LogEntriesTable(Table): class ReplayConsoleLogsLogEntriesTable(LazyTable): - fields: Dict[str, FieldOrTable] = LOG_ENTRIES_FIELDS + fields: dict[str, FieldOrTable] = LOG_ENTRIES_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): - fields: List[ast.Expr] = [ast.Field(chain=["log_entries", *chain]) for name, chain in requested_fields.items()] + def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node): + fields: list[ast.Expr] = [ast.Field(chain=["log_entries", *chain]) for name, chain in requested_fields.items()] return ast.SelectQuery( select=fields, @@ -55,10 +53,10 @@ class ReplayConsoleLogsLogEntriesTable(LazyTable): class BatchExportLogEntriesTable(LazyTable): - fields: Dict[str, FieldOrTable] = LOG_ENTRIES_FIELDS + fields: dict[str, FieldOrTable] = LOG_ENTRIES_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): - fields: List[ast.Expr] = [ast.Field(chain=["log_entries", *chain]) for name, chain in requested_fields.items()] + def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node): + fields: list[ast.Expr] = [ast.Field(chain=["log_entries", *chain]) for name, chain in requested_fields.items()] return ast.SelectQuery( select=fields, diff --git a/posthog/hogql/database/schema/numbers.py b/posthog/hogql/database/schema/numbers.py index 01c09ac66d7..7590e4041c1 100644 --- a/posthog/hogql/database/schema/numbers.py +++ b/posthog/hogql/database/schema/numbers.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Optional from posthog.hogql.database.models import ( IntegerDatabaseField, @@ -12,7 +12,7 @@ NUMBERS_TABLE_FIELDS = { class NumbersTable(FunctionCallTable): - fields: Dict[str, FieldOrTable] = NUMBERS_TABLE_FIELDS + fields: dict[str, FieldOrTable] = NUMBERS_TABLE_FIELDS name: str = "numbers" min_args: Optional[int] = 1 diff --git a/posthog/hogql/database/schema/person_distinct_id_overrides.py b/posthog/hogql/database/schema/person_distinct_id_overrides.py index 6045e74ff76..209c73c346e 100644 --- a/posthog/hogql/database/schema/person_distinct_id_overrides.py +++ b/posthog/hogql/database/schema/person_distinct_id_overrides.py @@ -1,4 +1,3 @@ -from typing import Dict, List from posthog.hogql.ast import SelectQuery from posthog.hogql.context import HogQLContext @@ -27,7 +26,7 @@ PERSON_DISTINCT_ID_OVERRIDES_FIELDS = { } -def select_from_person_distinct_id_overrides_table(requested_fields: Dict[str, 
List[str | int]]): +def select_from_person_distinct_id_overrides_table(requested_fields: dict[str, list[str | int]]): # Always include "person_id", as it's the key we use to make further joins, and it'd be great if it's available if "person_id" not in requested_fields: requested_fields = {**requested_fields, "person_id": ["person_id"]} @@ -43,7 +42,7 @@ def select_from_person_distinct_id_overrides_table(requested_fields: Dict[str, L def join_with_person_distinct_id_overrides_table( from_table: str, to_table: str, - requested_fields: Dict[str, List[str]], + requested_fields: dict[str, list[str]], context: HogQLContext, node: SelectQuery, ): @@ -65,7 +64,7 @@ def join_with_person_distinct_id_overrides_table( class RawPersonDistinctIdOverridesTable(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { **PERSON_DISTINCT_ID_OVERRIDES_FIELDS, "is_deleted": BooleanDatabaseField(name="is_deleted"), "version": IntegerDatabaseField(name="version"), @@ -79,9 +78,9 @@ class RawPersonDistinctIdOverridesTable(Table): class PersonDistinctIdOverridesTable(LazyTable): - fields: Dict[str, FieldOrTable] = PERSON_DISTINCT_ID_OVERRIDES_FIELDS + fields: dict[str, FieldOrTable] = PERSON_DISTINCT_ID_OVERRIDES_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context: HogQLContext, node: SelectQuery): + def lazy_select(self, requested_fields: dict[str, list[str | int]], context: HogQLContext, node: SelectQuery): return select_from_person_distinct_id_overrides_table(requested_fields) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/person_distinct_ids.py b/posthog/hogql/database/schema/person_distinct_ids.py index dde1f97c279..9fa00c59c29 100644 --- a/posthog/hogql/database/schema/person_distinct_ids.py +++ b/posthog/hogql/database/schema/person_distinct_ids.py @@ -1,4 +1,3 @@ -from typing import Dict, List from posthog.hogql.ast import SelectQuery from posthog.hogql.context import HogQLContext @@ -27,7 +26,7 @@ PERSON_DISTINCT_IDS_FIELDS = { } -def select_from_person_distinct_ids_table(requested_fields: Dict[str, List[str | int]]): +def select_from_person_distinct_ids_table(requested_fields: dict[str, list[str | int]]): # Always include "person_id", as it's the key we use to make further joins, and it'd be great if it's available if "person_id" not in requested_fields: requested_fields = {**requested_fields, "person_id": ["person_id"]} @@ -43,7 +42,7 @@ def select_from_person_distinct_ids_table(requested_fields: Dict[str, List[str | def join_with_person_distinct_ids_table( from_table: str, to_table: str, - requested_fields: Dict[str, List[str]], + requested_fields: dict[str, list[str]], context: HogQLContext, node: SelectQuery, ): @@ -65,7 +64,7 @@ def join_with_person_distinct_ids_table( class RawPersonDistinctIdsTable(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { **PERSON_DISTINCT_IDS_FIELDS, "is_deleted": BooleanDatabaseField(name="is_deleted"), "version": IntegerDatabaseField(name="version"), @@ -79,9 +78,9 @@ class RawPersonDistinctIdsTable(Table): class PersonDistinctIdsTable(LazyTable): - fields: Dict[str, FieldOrTable] = PERSON_DISTINCT_IDS_FIELDS + fields: dict[str, FieldOrTable] = PERSON_DISTINCT_IDS_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): + def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node): return select_from_person_distinct_ids_table(requested_fields) def to_printed_clickhouse(self, 
context): diff --git a/posthog/hogql/database/schema/person_overrides.py b/posthog/hogql/database/schema/person_overrides.py index 559ddd3a801..366321cf65e 100644 --- a/posthog/hogql/database/schema/person_overrides.py +++ b/posthog/hogql/database/schema/person_overrides.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any from posthog.hogql.ast import SelectQuery from posthog.hogql.context import HogQLContext @@ -14,7 +14,7 @@ from posthog.hogql.database.models import ( from posthog.hogql.errors import ResolutionError from posthog.schema import HogQLQueryModifiers -PERSON_OVERRIDES_FIELDS: Dict[str, FieldOrTable] = { +PERSON_OVERRIDES_FIELDS: dict[str, FieldOrTable] = { "team_id": IntegerDatabaseField(name="team_id"), "old_person_id": StringDatabaseField(name="old_person_id"), "override_person_id": StringDatabaseField(name="override_person_id"), @@ -24,7 +24,7 @@ PERSON_OVERRIDES_FIELDS: Dict[str, FieldOrTable] = { } -def select_from_person_overrides_table(requested_fields: Dict[str, List[str | int]]): +def select_from_person_overrides_table(requested_fields: dict[str, list[str | int]]): return argmax_select( table_name="raw_person_overrides", select_fields=requested_fields, @@ -36,7 +36,7 @@ def select_from_person_overrides_table(requested_fields: Dict[str, List[str | in def join_with_person_overrides_table( from_table: str, to_table: str, - requested_fields: Dict[str, Any], + requested_fields: dict[str, Any], context: HogQLContext, node: SelectQuery, ): @@ -59,7 +59,7 @@ def join_with_person_overrides_table( class RawPersonOverridesTable(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { **PERSON_OVERRIDES_FIELDS, "version": IntegerDatabaseField(name="version"), } @@ -72,9 +72,9 @@ class RawPersonOverridesTable(Table): class PersonOverridesTable(Table): - fields: Dict[str, FieldOrTable] = PERSON_OVERRIDES_FIELDS + fields: dict[str, FieldOrTable] = PERSON_OVERRIDES_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers): + def lazy_select(self, requested_fields: dict[str, list[str | int]], modifiers: HogQLQueryModifiers): return select_from_person_overrides_table(requested_fields) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/persons.py b/posthog/hogql/database/schema/persons.py index 189da1faee0..14884a7008f 100644 --- a/posthog/hogql/database/schema/persons.py +++ b/posthog/hogql/database/schema/persons.py @@ -1,4 +1,3 @@ -from typing import Dict, List from posthog.hogql.ast import SelectQuery from posthog.hogql.constants import HogQLQuerySettings @@ -19,7 +18,7 @@ from posthog.hogql.errors import ResolutionError from posthog.hogql.database.schema.persons_pdi import PersonsPDITable, persons_pdi_join from posthog.schema import HogQLQueryModifiers, PersonsArgMaxVersion -PERSONS_FIELDS: Dict[str, FieldOrTable] = { +PERSONS_FIELDS: dict[str, FieldOrTable] = { "id": StringDatabaseField(name="id"), "created_at": DateTimeDatabaseField(name="created_at"), "team_id": IntegerDatabaseField(name="team_id"), @@ -33,7 +32,7 @@ PERSONS_FIELDS: Dict[str, FieldOrTable] = { } -def select_from_persons_table(requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers): +def select_from_persons_table(requested_fields: dict[str, list[str | int]], modifiers: HogQLQueryModifiers): version = modifiers.personsArgMaxVersion if version == PersonsArgMaxVersion.auto: version = PersonsArgMaxVersion.v1 @@ -85,7 +84,7 @@ def 
diff --git a/posthog/hogql/database/schema/persons.py b/posthog/hogql/database/schema/persons.py
index 189da1faee0..14884a7008f 100644
--- a/posthog/hogql/database/schema/persons.py
+++ b/posthog/hogql/database/schema/persons.py
@@ -1,4 +1,3 @@
-from typing import Dict, List
 
 from posthog.hogql.ast import SelectQuery
 from posthog.hogql.constants import HogQLQuerySettings
@@ -19,7 +18,7 @@ from posthog.hogql.errors import ResolutionError
 from posthog.hogql.database.schema.persons_pdi import PersonsPDITable, persons_pdi_join
 from posthog.schema import HogQLQueryModifiers, PersonsArgMaxVersion
 
-PERSONS_FIELDS: Dict[str, FieldOrTable] = {
+PERSONS_FIELDS: dict[str, FieldOrTable] = {
     "id": StringDatabaseField(name="id"),
     "created_at": DateTimeDatabaseField(name="created_at"),
     "team_id": IntegerDatabaseField(name="team_id"),
@@ -33,7 +32,7 @@ PERSONS_FIELDS: Dict[str, FieldOrTable] = {
 }
 
 
-def select_from_persons_table(requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers):
+def select_from_persons_table(requested_fields: dict[str, list[str | int]], modifiers: HogQLQueryModifiers):
     version = modifiers.personsArgMaxVersion
     if version == PersonsArgMaxVersion.auto:
         version = PersonsArgMaxVersion.v1
@@ -85,7 +84,7 @@ def select_from_persons_table(requested_fields: Dict[str, List[str | int]], modi
 def join_with_persons_table(
     from_table: str,
     to_table: str,
-    requested_fields: Dict[str, List[str | int]],
+    requested_fields: dict[str, list[str | int]],
     context: HogQLContext,
     node: SelectQuery,
 ):
@@ -107,7 +106,7 @@ def join_with_persons_table(
 
 
 class RawPersonsTable(Table):
-    fields: Dict[str, FieldOrTable] = {
+    fields: dict[str, FieldOrTable] = {
         **PERSONS_FIELDS,
         "is_deleted": BooleanDatabaseField(name="is_deleted"),
         "version": IntegerDatabaseField(name="version"),
@@ -121,9 +120,9 @@ class RawPersonsTable(Table):
 
 
 class PersonsTable(LazyTable):
-    fields: Dict[str, FieldOrTable] = PERSONS_FIELDS
+    fields: dict[str, FieldOrTable] = PERSONS_FIELDS
 
-    def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node):
+    def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node):
         return select_from_persons_table(requested_fields, context.modifiers)
 
     def to_printed_clickhouse(self, context):
diff --git a/posthog/hogql/database/schema/persons_pdi.py b/posthog/hogql/database/schema/persons_pdi.py
index 30fdadee677..0e30b4e62d2 100644
--- a/posthog/hogql/database/schema/persons_pdi.py
+++ b/posthog/hogql/database/schema/persons_pdi.py
@@ -1,4 +1,3 @@
-from typing import Dict, List
 
 from posthog.hogql.ast import SelectQuery
 from posthog.hogql.context import HogQLContext
@@ -14,7 +13,7 @@ from posthog.hogql.errors import ResolutionError
 
 # :NOTE: We already have person_distinct_ids.py, which most tables link to. This persons_pdi.py is a hack to
 # make "select persons.pdi.distinct_id from persons" work while avoiding circular imports. Don't use directly.
-def persons_pdi_select(requested_fields: Dict[str, List[str | int]]):
+def persons_pdi_select(requested_fields: dict[str, list[str | int]]):
     # Always include "person_id", as it's the key we use to make further joins, and it'd be great if it's available
     if "person_id" not in requested_fields:
         requested_fields = {**requested_fields, "person_id": ["person_id"]}
@@ -32,7 +31,7 @@ def persons_pdi_select(requested_fields: Dict[str, List[str | int]]):
 def persons_pdi_join(
     from_table: str,
     to_table: str,
-    requested_fields: Dict[str, List[str | int]],
+    requested_fields: dict[str, list[str | int]],
     context: HogQLContext,
     node: SelectQuery,
 ):
@@ -56,13 +55,13 @@ def persons_pdi_join(
 # :NOTE: We already have person_distinct_ids.py, which most tables link to. This persons_pdi.py is a hack to
 # make "select persons.pdi.distinct_id from persons" work while avoiding circular imports. Don't use directly.
 class PersonsPDITable(LazyTable):
-    fields: Dict[str, FieldOrTable] = {
+    fields: dict[str, FieldOrTable] = {
         "team_id": IntegerDatabaseField(name="team_id"),
         "distinct_id": StringDatabaseField(name="distinct_id"),
         "person_id": StringDatabaseField(name="person_id"),
     }
 
-    def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node):
+    def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node):
         return persons_pdi_select(requested_fields)
 
     def to_printed_clickhouse(self, context):
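A reviewing note on runtime compatibility (an observation, not something stated in the patch): subscripting builtins such as `dict[str, ...]` only works at runtime on Python 3.9+, and the `str | int` unions these signatures already contain require 3.10+ unless annotation evaluation is postponed, a detail the target Python version configured for these rules has to get right. Sketch:

    # PEP 585 (3.9+): builtin generics are valid at runtime
    fields: dict[str, list[str]] = {}
    # PEP 604 (3.10+): `str | int` in an evaluated annotation; on older
    # interpreters it needs `from __future__ import annotations` at module top.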
diff --git a/posthog/hogql/database/schema/session_replay_events.py b/posthog/hogql/database/schema/session_replay_events.py
index a6f0fbed3bc..81f705af378 100644
--- a/posthog/hogql/database/schema/session_replay_events.py
+++ b/posthog/hogql/database/schema/session_replay_events.py
@@ -1,5 +1,3 @@
-from typing import Dict, List
-
 from posthog.hogql.database.models import (
     Table,
     StringDatabaseField,
@@ -18,7 +16,7 @@ from posthog.hogql.database.schema.person_distinct_ids import (
 
 RAW_ONLY_FIELDS = ["min_first_timestamp", "max_last_timestamp"]
 
-SESSION_REPLAY_EVENTS_COMMON_FIELDS: Dict[str, FieldOrTable] = {
+SESSION_REPLAY_EVENTS_COMMON_FIELDS: dict[str, FieldOrTable] = {
     "session_id": StringDatabaseField(name="session_id"),
     "team_id": IntegerDatabaseField(name="team_id"),
     "distinct_id": StringDatabaseField(name="distinct_id"),
@@ -46,14 +44,14 @@ SESSION_REPLAY_EVENTS_COMMON_FIELDS: Dict[str, FieldOrTable] = {
 
 
 class RawSessionReplayEventsTable(Table):
-    fields: Dict[str, FieldOrTable] = {
+    fields: dict[str, FieldOrTable] = {
         **SESSION_REPLAY_EVENTS_COMMON_FIELDS,
         "min_first_timestamp": DateTimeDatabaseField(name="min_first_timestamp"),
         "max_last_timestamp": DateTimeDatabaseField(name="max_last_timestamp"),
         "first_url": DatabaseField(name="first_url"),
     }
 
-    def avoid_asterisk_fields(self) -> List[str]:
+    def avoid_asterisk_fields(self) -> list[str]:
         return ["first_url"]
 
     def to_printed_clickhouse(self, context):
@@ -63,7 +61,7 @@ class RawSessionReplayEventsTable(Table):
         return "raw_session_replay_events"
 
 
-def select_from_session_replay_events_table(requested_fields: Dict[str, List[str | int]]):
+def select_from_session_replay_events_table(requested_fields: dict[str, list[str | int]]):
     from posthog.hogql import ast
 
     table_name = "raw_session_replay_events"
@@ -85,8 +83,8 @@ def select_from_session_replay_events_table(requested_fields: Dict[str, List[str
         "message_count": ast.Call(name="sum", args=[ast.Field(chain=[table_name, "message_count"])]),
     }
 
-    select_fields: List[ast.Expr] = []
-    group_by_fields: List[ast.Expr] = []
+    select_fields: list[ast.Expr] = []
+    group_by_fields: list[ast.Expr] = []
 
     for name, chain in requested_fields.items():
         if name in RAW_ONLY_FIELDS:
@@ -107,14 +105,14 @@ def select_from_session_replay_events_table(requested_fields: Dict[str, List[str
 
 
 class SessionReplayEventsTable(LazyTable):
-    fields: Dict[str, FieldOrTable] = {
+    fields: dict[str, FieldOrTable] = {
         **{k: v for k, v in SESSION_REPLAY_EVENTS_COMMON_FIELDS.items() if k not in RAW_ONLY_FIELDS},
         "start_time": DateTimeDatabaseField(name="start_time"),
         "end_time": DateTimeDatabaseField(name="end_time"),
         "first_url": StringDatabaseField(name="first_url"),
     }
 
-    def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node):
+    def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node):
         return select_from_session_replay_events_table(requested_fields)
 
     def to_printed_clickhouse(self, context):
diff --git a/posthog/hogql/database/schema/sessions.py b/posthog/hogql/database/schema/sessions.py
index e1fcaf1a75f..0bd6bfef09c 100644
--- a/posthog/hogql/database/schema/sessions.py
+++ b/posthog/hogql/database/schema/sessions.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, cast, Any, TYPE_CHECKING
+from typing import cast, Any, TYPE_CHECKING
 
 from posthog.hogql import ast
 from posthog.hogql.context import HogQLContext
@@ -19,7 +19,7 @@ from posthog.hogql.errors import ResolutionError
 if TYPE_CHECKING:
     pass
 
-RAW_SESSIONS_FIELDS: Dict[str, FieldOrTable] = {
+RAW_SESSIONS_FIELDS: dict[str, FieldOrTable] = {
     "id": StringDatabaseField(name="session_id"),
     # TODO remove this, it's a duplicate of the correct session_id field below to get some trends working on a deadline
     "session_id": StringDatabaseField(name="session_id"),
@@ -44,7 +44,7 @@ RAW_SESSIONS_FIELDS: Dict[str, FieldOrTable] = {
     "autocapture_count": IntegerDatabaseField(name="autocapture_count"),
 }
 
-LAZY_SESSIONS_FIELDS: Dict[str, FieldOrTable] = {
+LAZY_SESSIONS_FIELDS: dict[str, FieldOrTable] = {
     "id": StringDatabaseField(name="session_id"),
     # TODO remove this, it's a duplicate of the correct session_id field below to get some trends working on a deadline
     "session_id": StringDatabaseField(name="session_id"),
@@ -75,7 +75,7 @@ LAZY_SESSIONS_FIELDS: Dict[str, FieldOrTable] = {
 
 
 class RawSessionsTable(Table):
-    fields: Dict[str, FieldOrTable] = RAW_SESSIONS_FIELDS
+    fields: dict[str, FieldOrTable] = RAW_SESSIONS_FIELDS
 
     def to_printed_clickhouse(self, context):
         return "sessions"
@@ -83,7 +83,7 @@ class RawSessionsTable(Table):
     def to_printed_hogql(self):
         return "raw_sessions"
 
-    def avoid_asterisk_fields(self) -> List[str]:
+    def avoid_asterisk_fields(self) -> list[str]:
         # our clickhouse driver can't return aggregate states
         return [
             "entry_url",
@@ -100,7 +100,7 @@ class RawSessionsTable(Table):
 
 
 def select_from_sessions_table(
-    requested_fields: Dict[str, List[str | int]], node: ast.SelectQuery, context: HogQLContext
+    requested_fields: dict[str, list[str | int]], node: ast.SelectQuery, context: HogQLContext
 ):
     from posthog.hogql import ast
 
@@ -166,8 +166,8 @@ def select_from_sessions_table(
     }
     aggregate_fields["duration"] = aggregate_fields["$session_duration"]
 
-    select_fields: List[ast.Expr] = []
-    group_by_fields: List[ast.Expr] = [ast.Field(chain=[table_name, "session_id"])]
+    select_fields: list[ast.Expr] = []
+    group_by_fields: list[ast.Expr] = [ast.Field(chain=[table_name, "session_id"])]
 
     for name, chain in requested_fields.items():
         if name in aggregate_fields:
@@ -189,9 +189,9 @@ def select_from_sessions_table(
 
 
 class SessionsTable(LazyTable):
-    fields: Dict[str, FieldOrTable] = LAZY_SESSIONS_FIELDS
+    fields: dict[str, FieldOrTable] = LAZY_SESSIONS_FIELDS
 
-    def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node: ast.SelectQuery):
+    def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node: ast.SelectQuery):
         return select_from_sessions_table(requested_fields, node, context)
 
     def to_printed_clickhouse(self, context):
@@ -202,7 +202,7 @@ class SessionsTable(LazyTable):
 
 
 def join_events_table_to_sessions_table(
-    from_table: str, to_table: str, requested_fields: Dict[str, Any], context: HogQLContext, node: ast.SelectQuery
+    from_table: str, to_table: str, requested_fields: dict[str, Any], context: HogQLContext, node: ast.SelectQuery
 ) -> ast.JoinExpr:
     from posthog.hogql import ast
diff --git a/posthog/hogql/database/schema/static_cohort_people.py b/posthog/hogql/database/schema/static_cohort_people.py
index 97d90cbd6dc..fafbe9459eb 100644
--- a/posthog/hogql/database/schema/static_cohort_people.py
+++ b/posthog/hogql/database/schema/static_cohort_people.py
@@ -1,5 +1,3 @@
-from typing import Dict
-
 from posthog.hogql.database.models import (
     Table,
     StringDatabaseField,
@@ -11,7 +9,7 @@ from posthog.hogql.database.schema.persons import join_with_persons_table
 
 
 class StaticCohortPeople(Table):
-    fields: Dict[str, FieldOrTable] = {
+    fields: dict[str, FieldOrTable] = {
         "person_id": StringDatabaseField(name="person_id"),
         "cohort_id": IntegerDatabaseField(name="cohort_id"),
         "team_id": IntegerDatabaseField(name="team_id"),
diff --git a/posthog/hogql/database/schema/test/test_event_sessions.py b/posthog/hogql/database/schema/test/test_event_sessions.py
index 1a31bc3f472..914ac471236 100644
--- a/posthog/hogql/database/schema/test/test_event_sessions.py
+++ b/posthog/hogql/database/schema/test/test_event_sessions.py
@@ -1,4 +1,4 @@
-from typing import List, cast
+from typing import cast
 from posthog.hogql import ast
 from posthog.hogql.context import HogQLContext
 from posthog.hogql.database.database import create_hogql_database
@@ -21,7 +21,7 @@ class TestWhereClauseExtractor(BaseTest):
         select_query = cast(ast.SelectQuery, clone_expr(parse_select(query), clear_locations=True))
         return cast(ast.SelectQuery, resolve_types(select_query, self.context, dialect="clickhouse"))
 
-    def _compare_operators(self, query: ast.SelectQuery, table_name: str) -> List[ast.Expr]:
+    def _compare_operators(self, query: ast.SelectQuery, table_name: str) -> list[ast.Expr]:
         assert query.where is not None and query.type is not None
         return WhereClauseExtractor(query.where, table_name, query.type, self.context).compare_operators
 
diff --git a/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py b/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py
index 1e3464c1b9b..ea8c55d054c 100644
--- a/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py
+++ b/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py
@@ -1,4 +1,4 @@
-from typing import Union, Optional, Dict
+from typing import Union, Optional
 
 from posthog.hogql import ast
 from posthog.hogql.context import HogQLContext
@@ -22,7 +22,7 @@ def f(s: Union[str, ast.Expr, None], placeholders: Optional[dict[str, ast.Expr]]
 
 def parse(
     s: str,
-    placeholders: Optional[Dict[str, ast.Expr]] = None,
+    placeholders: Optional[dict[str, ast.Expr]] = None,
 ) -> ast.SelectQuery | ast.SelectUnionQuery:
     parsed = parse_select(s, placeholders=placeholders)
     return parsed
diff --git a/posthog/hogql/escape_sql.py b/posthog/hogql/escape_sql.py
index 35a563061e1..10f4a413fa6 100644
--- a/posthog/hogql/escape_sql.py
+++ b/posthog/hogql/escape_sql.py
@@ -1,6 +1,6 @@
 import re
 from datetime import datetime, date
-from typing import Optional, Any, Literal, List, Tuple
+from typing import Optional, Any, Literal
 from uuid import UUID
 from zoneinfo import ZoneInfo
 
@@ -129,8 +129,8 @@ class SQLValueEscaper:
     def visit_date(self, value: date):
         return f"toDate({self.visit(value.strftime('%Y-%m-%d'))})"
 
-    def visit_list(self, value: List):
+    def visit_list(self, value: list):
         return f"[{', '.join(str(self.visit(x)) for x in value)}]"
 
-    def visit_tuple(self, value: Tuple):
+    def visit_tuple(self, value: tuple):
        return f"({', '.join(str(self.visit(x)) for x in value)})"
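In escape_sql.py the aliases were unsubscripted, so they reduce to the bare builtins; the methods behave identically. Sketch:

    # Before
    def visit_list(self, value: List): ...
    # After: a bare builtin is an equally valid, import-free annotation
    def visit_list(self, value: list): ...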
diff --git a/posthog/hogql/filters.py b/posthog/hogql/filters.py
index 496cadf8da4..06ea36c1cdd 100644
--- a/posthog/hogql/filters.py
+++ b/posthog/hogql/filters.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, TypeVar
+from typing import Optional, TypeVar
 
 from dateutil.parser import isoparse
 
@@ -23,7 +23,7 @@ class ReplaceFilters(CloningVisitor):
         super().__init__()
         self.filters = filters
         self.team = team
-        self.selects: List[ast.SelectQuery] = []
+        self.selects: list[ast.SelectQuery] = []
 
     def visit_select_query(self, node):
         self.selects.append(node)
@@ -51,7 +51,7 @@ class ReplaceFilters(CloningVisitor):
                     "Cannot use 'filters' placeholder in a SELECT clause that does not select from the events table."
                 )
 
-            exprs: List[ast.Expr] = []
+            exprs: list[ast.Expr] = []
             if self.filters.properties is not None:
                 exprs.append(property_to_expr(self.filters.properties, self.team))
 
diff --git a/posthog/hogql/functions/action.py b/posthog/hogql/functions/action.py
index 02888081632..5ed8a156e39 100644
--- a/posthog/hogql/functions/action.py
+++ b/posthog/hogql/functions/action.py
@@ -1,12 +1,10 @@
-from typing import List
-
 from posthog.hogql import ast
 from posthog.hogql.context import HogQLContext
 from posthog.hogql.errors import QueryError
 from posthog.hogql.escape_sql import escape_clickhouse_string
 
 
-def matches_action(node: ast.Expr, args: List[ast.Expr], context: HogQLContext) -> ast.Expr:
+def matches_action(node: ast.Expr, args: list[ast.Expr], context: HogQLContext) -> ast.Expr:
     arg = args[0]
     if not isinstance(arg, ast.Constant):
         raise QueryError("action() takes only constant arguments", node=arg)
diff --git a/posthog/hogql/functions/cohort.py b/posthog/hogql/functions/cohort.py
index fc5077f610a..2b0992c6e7e 100644
--- a/posthog/hogql/functions/cohort.py
+++ b/posthog/hogql/functions/cohort.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
 
 from posthog.hogql import ast
 from posthog.hogql.context import HogQLContext
@@ -23,7 +23,7 @@ def cohort_query_node(node: ast.Expr, context: HogQLContext) -> ast.Expr:
     return cohort(node, [node], context)
 
 
-def cohort(node: ast.Expr, args: List[ast.Expr], context: HogQLContext) -> ast.Expr:
+def cohort(node: ast.Expr, args: list[ast.Expr], context: HogQLContext) -> ast.Expr:
     arg = args[0]
     if not isinstance(arg, ast.Constant):
         raise QueryError("cohort() takes only constant arguments", node=arg)
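Rule UP006 covers `Tuple` and `Type` as well, which is visible in the mapping.py hunks just below, where the `Overload` alias keeps its shape but switches to builtin generics; `type[X]` is the builtin spelling for "the class object of X". The alias as it appears in the patch:

    # Before
    Overload = Tuple[Tuple[Type[ConstantType], ...] | Type[ConstantType], str]
    # After
    Overload = tuple[tuple[type[ConstantType], ...] | type[ConstantType], str]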
diff --git a/posthog/hogql/functions/mapping.py b/posthog/hogql/functions/mapping.py
index 6080face6f6..c4087013c85 100644
--- a/posthog/hogql/functions/mapping.py
+++ b/posthog/hogql/functions/mapping.py
@@ -1,13 +1,13 @@
 from dataclasses import dataclass
 from itertools import chain
-from typing import List, Optional, Dict, Tuple, Type
+from typing import Optional
 from posthog.hogql import ast
 from posthog.hogql.base import ConstantType
 from posthog.hogql.errors import QueryError
 
 
 def validate_function_args(
-    args: List[ast.Expr],
+    args: list[ast.Expr],
     min_args: int,
     max_args: Optional[int],
     function_name: str,
@@ -31,7 +31,7 @@ def validate_function_args(
     )
 
 
-Overload = Tuple[Tuple[Type[ConstantType], ...] | Type[ConstantType], str]
+Overload = tuple[tuple[type[ConstantType], ...] | type[ConstantType], str]
 
 
 @dataclass()
@@ -42,7 +42,7 @@ class HogQLFunctionMeta:
     min_params: int = 0
     max_params: Optional[int] = 0
     aggregate: bool = False
-    overloads: Optional[List[Overload]] = None
+    overloads: Optional[list[Overload]] = None
     """Overloads allow for using a different ClickHouse function depending on the type of the first arg."""
     tz_aware: bool = False
     """Whether the function is timezone-aware.
     This means the project timezone will be appended as the last arg."""
@@ -50,7 +50,7 @@ class HogQLFunctionMeta:
     """Not all ClickHouse functions are case-insensitive. See https://clickhouse.com/docs/en/sql-reference/syntax#keywords."""
 
 
-HOGQL_COMPARISON_MAPPING: Dict[str, ast.CompareOperationOp] = {
+HOGQL_COMPARISON_MAPPING: dict[str, ast.CompareOperationOp] = {
     "equals": ast.CompareOperationOp.Eq,
     "notEquals": ast.CompareOperationOp.NotEq,
     "less": ast.CompareOperationOp.Lt,
@@ -65,7 +65,7 @@ HOGQL_COMPARISON_MAPPING: Dict[str, ast.CompareOperationOp] = {
     "notIn": ast.CompareOperationOp.NotIn,
 }
 
-HOGQL_CLICKHOUSE_FUNCTIONS: Dict[str, HogQLFunctionMeta] = {
+HOGQL_CLICKHOUSE_FUNCTIONS: dict[str, HogQLFunctionMeta] = {
     # arithmetic
     "plus": HogQLFunctionMeta("plus", 2, 2),
     "minus": HogQLFunctionMeta("minus", 2, 2),
@@ -575,7 +575,7 @@ HOGQL_CLICKHOUSE_FUNCTIONS: Dict[str, HogQLFunctionMeta] = {
     "leadInFrame": HogQLFunctionMeta("leadInFrame", 1, 1),
 }
 # Permitted HogQL aggregations
-HOGQL_AGGREGATIONS: Dict[str, HogQLFunctionMeta] = {
+HOGQL_AGGREGATIONS: dict[str, HogQLFunctionMeta] = {
     # Standard aggregate functions
     "count": HogQLFunctionMeta("count", 0, 1, aggregate=True, case_sensitive=False),
     "countIf": HogQLFunctionMeta("countIf", 1, 2, aggregate=True),
@@ -747,7 +747,7 @@ HOGQL_AGGREGATIONS: Dict[str, HogQLFunctionMeta] = {
     "maxIntersectionsPosition": HogQLFunctionMeta("maxIntersectionsPosition", 2, 2, aggregate=True),
     "maxIntersectionsPositionIf": HogQLFunctionMeta("maxIntersectionsPositionIf", 3, 3, aggregate=True),
 }
-HOGQL_POSTHOG_FUNCTIONS: Dict[str, HogQLFunctionMeta] = {
+HOGQL_POSTHOG_FUNCTIONS: dict[str, HogQLFunctionMeta] = {
     "matchesAction": HogQLFunctionMeta("matchesAction", 1, 1),
     "sparkline": HogQLFunctionMeta("sparkline", 1, 1),
     "hogql_lookupDomainType": HogQLFunctionMeta("hogql_lookupDomainType", 1, 1),
@@ -781,7 +781,7 @@ FIRST_ARG_DATETIME_FUNCTIONS = (
 )
 
 
-def _find_function(name: str, functions: Dict[str, HogQLFunctionMeta]) -> Optional[HogQLFunctionMeta]:
+def _find_function(name: str, functions: dict[str, HogQLFunctionMeta]) -> Optional[HogQLFunctionMeta]:
     func = functions.get(name)
     if func is not None:
         return func
diff --git a/posthog/hogql/functions/sparkline.py b/posthog/hogql/functions/sparkline.py
index ddd6c02a7b2..5bbf9004f44 100644
--- a/posthog/hogql/functions/sparkline.py
+++ b/posthog/hogql/functions/sparkline.py
@@ -1,9 +1,7 @@
-from typing import List
-
 from posthog.hogql import ast
 
 
-def sparkline(node: ast.Expr, args: List[ast.Expr]) -> ast.Expr:
+def sparkline(node: ast.Expr, args: list[ast.Expr]) -> ast.Expr:
     return ast.Tuple(
         exprs=[
             ast.Constant(value="__hogql_chart_type"),
diff --git a/posthog/hogql/hogql.py b/posthog/hogql/hogql.py
index d3052f58b01..2a537bfd7a8 100644
--- a/posthog/hogql/hogql.py
+++ b/posthog/hogql/hogql.py
@@ -1,4 +1,4 @@
-from typing import Dict, Literal, cast, Optional
+from typing import Literal, cast, Optional
 
 from posthog.hogql import ast
 from posthog.hogql.context import HogQLContext
@@ -18,7 +18,7 @@ def translate_hogql(
     metadata_source: Optional[ast.SelectQuery] = None,
     *,
     events_table_alias: Optional[str] = None,
-    placeholders: Optional[Dict[str, ast.Expr]] = None,
+    placeholders: Optional[dict[str, ast.Expr]] = None,
 ) -> str:
     """Translate a HogQL expression into a ClickHouse expression."""
     if query == "":
diff --git a/posthog/hogql/parser.py b/posthog/hogql/parser.py
index 0ec619f3389..68637a30a20 100644
--- a/posthog/hogql/parser.py
+++ b/posthog/hogql/parser.py
@@ -1,4 +1,5 @@
-from typing import Dict, List, Literal, Optional, cast, Callable
+from typing import Literal, Optional, cast
+from collections.abc import Callable
 
 from antlr4 import CommonTokenStream, InputStream, ParseTreeVisitor, ParserRuleContext
 from antlr4.error.ErrorListener import ErrorListener
@@ -19,7 +20,7 @@ from hogql_parser import (
     parse_select as _parse_select_cpp,
 )
 
-RULE_TO_PARSE_FUNCTION: Dict[Literal["python", "cpp"], Dict[Literal["expr", "order_expr", "select"], Callable]] = {
+RULE_TO_PARSE_FUNCTION: dict[Literal["python", "cpp"], dict[Literal["expr", "order_expr", "select"], Callable]] = {
     "python": {
         "expr": lambda string, start: HogQLParseTreeConverter(start=start).visit(get_parser(string).expr()),
         "order_expr": lambda string: HogQLParseTreeConverter().visit(get_parser(string).orderExpr()),
@@ -32,7 +33,7 @@ RULE_TO_PARSE_FUNCTION: Dict[Literal["python", "cpp"], Dict[Literal["expr", "ord
     },
 }
 
-RULE_TO_HISTOGRAM: Dict[Literal["expr", "order_expr", "select"], Histogram] = {
+RULE_TO_HISTOGRAM: dict[Literal["expr", "order_expr", "select"], Histogram] = {
     rule: Histogram(
         f"parse_{rule}_seconds",
         f"Time to parse {rule} expression",
@@ -44,7 +45,7 @@ RULE_TO_HISTOGRAM: Dict[Literal["expr", "order_expr", "select"], Histogram] = {
 
 def parse_expr(
     expr: str,
-    placeholders: Optional[Dict[str, ast.Expr]] = None,
+    placeholders: Optional[dict[str, ast.Expr]] = None,
     start: Optional[int] = 0,
     timings: Optional[HogQLTimings] = None,
     *,
@@ -65,7 +66,7 @@ def parse_expr(
 
 def parse_order_expr(
     order_expr: str,
-    placeholders: Optional[Dict[str, ast.Expr]] = None,
+    placeholders: Optional[dict[str, ast.Expr]] = None,
     timings: Optional[HogQLTimings] = None,
     *,
     backend: Optional[Literal["python", "cpp"]] = None,
@@ -85,7 +86,7 @@ def parse_order_expr(
 
 def parse_select(
     statement: str,
-    placeholders: Optional[Dict[str, ast.Expr]] = None,
+    placeholders: Optional[dict[str, ast.Expr]] = None,
     timings: Optional[HogQLTimings] = None,
     *,
     backend: Optional[Literal["python", "cpp"]] = None,
@@ -159,10 +160,10 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
             return self.visit(ctx.selectUnionStmt() or ctx.selectStmt() or ctx.hogqlxTagElement())
 
     def visitSelectUnionStmt(self, ctx: HogQLParser.SelectUnionStmtContext):
-        select_queries: List[ast.SelectQuery | ast.SelectUnionQuery] = [
+        select_queries: list[ast.SelectQuery | ast.SelectUnionQuery] = [
             self.visit(select) for select in ctx.selectStmtWithParens()
         ]
-        flattened_queries: List[ast.SelectQuery] = []
+        flattened_queries: list[ast.SelectQuery] = []
         for query in select_queries:
             if isinstance(query, ast.SelectQuery):
                 flattened_queries.append(query)
@@ -771,7 +772,7 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
         )
 
     def visitWithExprList(self, ctx: HogQLParser.WithExprListContext):
-        ctes: Dict[str, ast.CTE] = {}
+        ctes: dict[str, ast.CTE] = {}
         for expr in ctx.withExpr():
             cte = self.visit(expr)
             ctes[cte.name] = cte
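parser.py also picks up the second half of UP035: PEP 585 deprecated the ABC re-exports in `typing`, so `Callable` now comes from `collections.abc`. Sketch:

    # Before
    from typing import Callable, Optional
    # After: the ABCs live in collections.abc; typing keeps the rest
    from collections.abc import Callable
    from typing import Optional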
diff --git a/posthog/hogql/placeholders.py b/posthog/hogql/placeholders.py
index a09e39fd656..d0e835fb0d8 100644
--- a/posthog/hogql/placeholders.py
+++ b/posthog/hogql/placeholders.py
@@ -1,15 +1,15 @@
-from typing import Dict, Optional, List
+from typing import Optional
 
 from posthog.hogql import ast
 from posthog.hogql.errors import QueryError
 from posthog.hogql.visitor import CloningVisitor, TraversingVisitor
 
 
-def replace_placeholders(node: ast.Expr, placeholders: Optional[Dict[str, ast.Expr]]) -> ast.Expr:
+def replace_placeholders(node: ast.Expr, placeholders: Optional[dict[str, ast.Expr]]) -> ast.Expr:
     return ReplacePlaceholders(placeholders).visit(node)
 
 
-def find_placeholders(node: ast.Expr) -> List[str]:
+def find_placeholders(node: ast.Expr) -> list[str]:
     finder = FindPlaceholders()
     finder.visit(node)
     return list(finder.found)
@@ -28,7 +28,7 @@ class FindPlaceholders(TraversingVisitor):
 
 
 class ReplacePlaceholders(CloningVisitor):
-    def __init__(self, placeholders: Optional[Dict[str, ast.Expr]]):
+    def __init__(self, placeholders: Optional[dict[str, ast.Expr]]):
         super().__init__()
         self.placeholders = placeholders
 
@@ -42,5 +42,5 @@ class ReplacePlaceholders(CloningVisitor):
             return new_node
         raise QueryError(
             f"Placeholder {{{node.field}}} is not available in this context. You can use the following: "
-            + ", ".join((f"{placeholder}" for placeholder in self.placeholders))
+            + ", ".join(f"{placeholder}" for placeholder in self.placeholders)
         )
diff --git a/posthog/hogql/printer.py b/posthog/hogql/printer.py
index ff4766f8607..a829697e900 100644
--- a/posthog/hogql/printer.py
+++ b/posthog/hogql/printer.py
@@ -2,7 +2,7 @@ import re
 from dataclasses import dataclass
 from datetime import datetime, date
 from difflib import get_close_matches
-from typing import List, Literal, Optional, Union, cast
+from typing import Literal, Optional, Union, cast
 from uuid import UUID
 
 from posthog.hogql import ast
@@ -73,7 +73,7 @@ def print_ast(
     node: ast.Expr,
     context: HogQLContext,
     dialect: Literal["hogql", "clickhouse"],
-    stack: Optional[List[ast.SelectQuery]] = None,
+    stack: Optional[list[ast.SelectQuery]] = None,
     settings: Optional[HogQLGlobalSettings] = None,
     pretty: bool = False,
 ) -> str:
@@ -92,7 +92,7 @@ def prepare_ast_for_printing(
     node: ast.Expr,
     context: HogQLContext,
     dialect: Literal["hogql", "clickhouse"],
-    stack: Optional[List[ast.SelectQuery]] = None,
+    stack: Optional[list[ast.SelectQuery]] = None,
     settings: Optional[HogQLGlobalSettings] = None,
 ) -> ast.Expr:
     with context.timings.measure("create_hogql_database"):
@@ -130,7 +130,7 @@ def print_prepared_ast(
     node: ast.Expr,
     context: HogQLContext,
     dialect: Literal["hogql", "clickhouse"],
-    stack: Optional[List[ast.SelectQuery]] = None,
+    stack: Optional[list[ast.SelectQuery]] = None,
     settings: Optional[HogQLGlobalSettings] = None,
     pretty: bool = False,
 ) -> str:
@@ -158,13 +158,13 @@ class _Printer(Visitor):
         self,
         context: HogQLContext,
         dialect: Literal["hogql", "clickhouse"],
-        stack: Optional[List[AST]] = None,
+        stack: Optional[list[AST]] = None,
         settings: Optional[HogQLGlobalSettings] = None,
         pretty: bool = False,
     ):
         self.context = context
         self.dialect = dialect
-        self.stack: List[AST] = stack or []  # Keep track of all traversed nodes.
+        self.stack: list[AST] = stack or []  # Keep track of all traversed nodes.
         self.settings = settings
         self.pretty = pretty
         self._indent = -1
@@ -773,7 +773,7 @@ class _Printer(Visitor):
 
         if self.dialect == "clickhouse":
             if node.name in FIRST_ARG_DATETIME_FUNCTIONS:
-                args: List[str] = []
+                args: list[str] = []
                 for idx, arg in enumerate(node.args):
                     if idx == 0:
                         if isinstance(arg, ast.Call) and arg.name in ADD_OR_NULL_DATETIME_FUNCTIONS:
@@ -783,7 +783,7 @@ class _Printer(Visitor):
                     else:
                         args.append(self.visit(arg))
             elif node.name == "concat":
-                args: List[str] = []
+                args: list[str] = []
                 for arg in node.args:
                     if isinstance(arg, ast.Constant):
                         if arg.value is None:
@@ -1002,7 +1002,7 @@ class _Printer(Visitor):
         while isinstance(table, ast.TableAliasType):
             table = table.table_type
 
-        args: List[str] = []
+        args: list[str] = []
 
         if self.context.modifiers.materializationMode != "disabled":
             # find a materialized property for the first part of the chain
@@ -1094,7 +1094,7 @@ class _Printer(Visitor):
         raise ImpossibleASTError(f"Unknown AST node {type(node).__name__}")
 
     def visit_window_expr(self, node: ast.WindowExpr):
-        strings: List[str] = []
+        strings: list[str] = []
         if node.partition_by is not None:
             if len(node.partition_by) == 0:
                 raise ImpossibleASTError("PARTITION BY must have at least one argument")
@@ -1168,7 +1168,7 @@ class _Printer(Visitor):
             return escape_clickhouse_string(name, timezone=self._get_timezone())
         return escape_hogql_string(name, timezone=self._get_timezone())
 
-    def _unsafe_json_extract_trim_quotes(self, unsafe_field: str, unsafe_args: List[str]) -> str:
+    def _unsafe_json_extract_trim_quotes(self, unsafe_field: str, unsafe_args: list[str]) -> str:
         return f"replaceRegexpAll(nullIf(nullIf(JSONExtractRaw({', '.join([unsafe_field, *unsafe_args])}), ''), 'null'), '^\"|\"$', '')"
 
     def _get_materialized_column(
@@ -1209,7 +1209,7 @@ class _Printer(Visitor):
         for key, value in settings:
             if value is None:
                 continue
-            if not isinstance(value, (int, float, str)):
+            if not isinstance(value, int | float | str):
                 raise QueryError(f"Setting {key} must be a string, int, or float")
             if not re.match(r"^[a-zA-Z0-9_]+$", key):
                 raise QueryError(f"Setting {key} is not supported")
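Alongside the type rewrites, placeholders.py (and later lazy_tables.py and the tests) drops the redundant parentheses around a generator expression passed as a call's sole argument; the two spellings are equivalent. Sketch:

    names = ["a", "b"]
    ", ".join((n for n in names))  # before: extra parens
    ", ".join(n for n in names)    # after: same generator, same result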
diff --git a/posthog/hogql/property.py b/posthog/hogql/property.py
index fb5e6d90a45..824a11bdae9 100644
--- a/posthog/hogql/property.py
+++ b/posthog/hogql/property.py
@@ -1,5 +1,5 @@
 import re
-from typing import List, Optional, Union, cast, Literal
+from typing import Optional, Union, cast, Literal
 
 from pydantic import BaseModel
 
@@ -382,7 +382,7 @@ def action_to_expr(action: Action) -> ast.Expr:
 
     or_queries = []
     for step in steps:
-        exprs: List[ast.Expr] = []
+        exprs: list[ast.Expr] = []
         if step.event:
             exprs.append(parse_expr("event = {event}", {"event": ast.Constant(value=step.event)}))
 
diff --git a/posthog/hogql/query.py b/posthog/hogql/query.py
index 65c0c9d7135..b42a61b7855 100644
--- a/posthog/hogql/query.py
+++ b/posthog/hogql/query.py
@@ -1,5 +1,5 @@
 import dataclasses
-from typing import Dict, Optional, Union, cast
+from typing import Optional, Union, cast
 
 from posthog.clickhouse.client.connection import Workload
 from posthog.errors import ExposedCHQueryError
@@ -32,7 +32,7 @@ def execute_hogql_query(
     *,
     query_type: str = "hogql_query",
     filters: Optional[HogQLFilters] = None,
-    placeholders: Optional[Dict[str, ast.Expr]] = None,
+    placeholders: Optional[dict[str, ast.Expr]] = None,
     workload: Workload = Workload.ONLINE,
     settings: Optional[HogQLGlobalSettings] = None,
     modifiers: Optional[HogQLQueryModifiers] = None,
@@ -175,7 +175,7 @@ def execute_hogql_query(
     except Exception as e:
         if explain:
             results, types = None, None
-            if isinstance(e, (ExposedCHQueryError, ExposedHogQLError)):
+            if isinstance(e, ExposedCHQueryError | ExposedHogQLError):
                 error = str(e)
             else:
                 error = "Unknown error"
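The `isinstance` rewrites in printer.py and query.py are rule UP038: on Python 3.10+, `isinstance` accepts a `|` union of types in place of a tuple. Sketch:

    value = 3.5
    assert isinstance(value, (int, float, str))   # before
    assert isinstance(value, int | float | str)   # after, 3.10+ only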
diff --git a/posthog/hogql/resolver.py b/posthog/hogql/resolver.py
index fce251dc8a0..5921e5a6f2d 100644
--- a/posthog/hogql/resolver.py
+++ b/posthog/hogql/resolver.py
@@ -1,5 +1,5 @@
 from datetime import date, datetime
-from typing import List, Optional, Any, cast, Literal
+from typing import Optional, Any, cast, Literal
 from uuid import UUID
 
 from posthog.hogql import ast
@@ -58,7 +58,7 @@ def resolve_types(
     node: ast.Expr,
     context: HogQLContext,
     dialect: Literal["hogql", "clickhouse"],
-    scopes: Optional[List[ast.SelectQueryType]] = None,
+    scopes: Optional[list[ast.SelectQueryType]] = None,
 ) -> ast.Expr:
     return Resolver(scopes=scopes, context=context, dialect=dialect).visit(node)
 
@@ -66,7 +66,7 @@ def resolve_types(
 class AliasCollector(TraversingVisitor):
     def __init__(self):
         super().__init__()
-        self.aliases: List[str] = []
+        self.aliases: list[str] = []
 
     def visit_alias(self, node: ast.Alias):
         self.aliases.append(node.alias)
@@ -80,11 +80,11 @@ class Resolver(CloningVisitor):
         self,
         context: HogQLContext,
         dialect: Literal["hogql", "clickhouse"] = "clickhouse",
-        scopes: Optional[List[ast.SelectQueryType]] = None,
+        scopes: Optional[list[ast.SelectQueryType]] = None,
     ):
         super().__init__()
         # Each SELECT query creates a new scope (type). Store all of them in a list as we traverse the tree.
-        self.scopes: List[ast.SelectQueryType] = scopes or []
+        self.scopes: list[ast.SelectQueryType] = scopes or []
         self.current_view_depth: int = 0
         self.context = context
         self.dialect = dialect
@@ -214,7 +214,7 @@ class Resolver(CloningVisitor):
 
         return new_node
 
-    def _asterisk_columns(self, asterisk: ast.AsteriskType) -> List[ast.Expr]:
+    def _asterisk_columns(self, asterisk: ast.AsteriskType) -> list[ast.Expr]:
         """Expand an asterisk. Mutates `select_query.select` and `select_query.type.columns` with the new fields"""
         if isinstance(asterisk.table_type, ast.BaseTableType):
             table = asterisk.table_type.resolve_database_table(self.context)
@@ -393,13 +393,13 @@ class Resolver(CloningVisitor):
             return self.visit(matches_action(node=node, args=node.args, context=self.context))
 
         node = super().visit_call(node)
-        arg_types: List[ast.ConstantType] = []
+        arg_types: list[ast.ConstantType] = []
         for arg in node.args:
             if arg.type:
                 arg_types.append(arg.type.resolve_constant_type(self.context) or ast.UnknownType())
             else:
                 arg_types.append(ast.UnknownType())
-        param_types: Optional[List[ast.ConstantType]] = None
+        param_types: Optional[list[ast.ConstantType]] = None
         if node.params is not None:
             param_types = []
             for param in node.params:
diff --git a/posthog/hogql/resolver_utils.py b/posthog/hogql/resolver_utils.py
index 7910a17fdb9..bfede9538ab 100644
--- a/posthog/hogql/resolver_utils.py
+++ b/posthog/hogql/resolver_utils.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
 
 from posthog import schema
 from posthog.hogql import ast
@@ -27,7 +27,7 @@ def lookup_field_by_name(scope: ast.SelectQueryType, name: str, context: HogQLCo
     return None
 
 
-def lookup_cte_by_name(scopes: List[ast.SelectQueryType], name: str) -> Optional[ast.CTE]:
+def lookup_cte_by_name(scopes: list[ast.SelectQueryType], name: str) -> Optional[ast.CTE]:
     for scope in reversed(scopes):
         if scope and scope.ctes and name in scope.ctes:
             return scope.ctes[name]
diff --git a/posthog/hogql/test/_test_parser.py b/posthog/hogql/test/_test_parser.py
index 478958746d6..514914906d0 100644
--- a/posthog/hogql/test/_test_parser.py
+++ b/posthog/hogql/test/_test_parser.py
@@ -1,4 +1,4 @@
-from typing import Literal, cast, Optional, Dict
+from typing import Literal, cast, Optional
 
 import math
 
@@ -20,10 +20,10 @@ def parser_test_factory(backend: Literal["python", "cpp"]):
 
         maxDiff = None
 
-        def _expr(self, expr: str, placeholders: Optional[Dict[str, ast.Expr]] = None) -> ast.Expr:
+        def _expr(self, expr: str, placeholders: Optional[dict[str, ast.Expr]] = None) -> ast.Expr:
             return clear_locations(parse_expr(expr, placeholders=placeholders, backend=backend))
 
-        def _select(self, query: str, placeholders: Optional[Dict[str, ast.Expr]] = None) -> ast.Expr:
+        def _select(self, query: str, placeholders: Optional[dict[str, ast.Expr]] = None) -> ast.Expr:
             return clear_locations(parse_select(query, placeholders=placeholders, backend=backend))
 
         def test_numbers(self):
diff --git a/posthog/hogql/test/test_filters.py b/posthog/hogql/test/test_filters.py
index 5aba11a3b28..05ac11667ae 100644
--- a/posthog/hogql/test/test_filters.py
+++ b/posthog/hogql/test/test_filters.py
@@ -1,4 +1,4 @@
-from typing import Dict, Any, Optional
+from typing import Any, Optional
 
 from posthog.hogql import ast
 from posthog.hogql.context import HogQLContext
@@ -18,10 +18,10 @@ from posthog.test.base import BaseTest
 class TestFilters(BaseTest):
     maxDiff = None
 
-    def _parse_expr(self, expr: str, placeholders: Optional[Dict[str, Any]] = None):
+    def _parse_expr(self, expr: str, placeholders: Optional[dict[str, Any]] = None):
         return clear_locations(parse_expr(expr, placeholders=placeholders))
 
-    def _parse_select(self, select: str, placeholders: Optional[Dict[str, Any]] = None):
+    def _parse_select(self, select: str, placeholders: Optional[dict[str, Any]] = None):
         return clear_locations(parse_select(select, placeholders=placeholders))
 
     def _print_ast(self, node: ast.Expr):
diff --git a/posthog/hogql/test/test_mapping.py b/posthog/hogql/test/test_mapping.py
index b13b2d1c744..9af0d9a60e4 100644
--- a/posthog/hogql/test/test_mapping.py
+++ b/posthog/hogql/test/test_mapping.py
@@ -23,22 +23,22 @@ class TestMappings(BaseTest):
         return self._return_present_function(find_hogql_posthog_function(name))
 
     def test_find_case_sensitive_function(self):
-        self.assertEquals(self._get_hogql_function("toString").clickhouse_name, "toString")
-        self.assertEquals(find_hogql_function("TOString"), None)
-        self.assertEquals(find_hogql_function("PlUs"), None)
+        self.assertEqual(self._get_hogql_function("toString").clickhouse_name, "toString")
+        self.assertEqual(find_hogql_function("TOString"), None)
+        self.assertEqual(find_hogql_function("PlUs"), None)
 
-        self.assertEquals(self._get_hogql_aggregation("countIf").clickhouse_name, "countIf")
-        self.assertEquals(find_hogql_aggregation("COUNTIF"), None)
+        self.assertEqual(self._get_hogql_aggregation("countIf").clickhouse_name, "countIf")
+        self.assertEqual(find_hogql_aggregation("COUNTIF"), None)
 
-        self.assertEquals(self._get_hogql_posthog_function("sparkline").clickhouse_name, "sparkline")
-        self.assertEquals(find_hogql_posthog_function("SPARKLINE"), None)
+        self.assertEqual(self._get_hogql_posthog_function("sparkline").clickhouse_name, "sparkline")
+        self.assertEqual(find_hogql_posthog_function("SPARKLINE"), None)
 
     def test_find_case_insensitive_function(self):
-        self.assertEquals(self._get_hogql_function("CoAlesce").clickhouse_name, "coalesce")
+        self.assertEqual(self._get_hogql_function("CoAlesce").clickhouse_name, "coalesce")
 
-        self.assertEquals(self._get_hogql_aggregation("SuM").clickhouse_name, "sum")
+        self.assertEqual(self._get_hogql_aggregation("SuM").clickhouse_name, "sum")
 
     def test_find_non_existent_function(self):
-        self.assertEquals(find_hogql_function("functionThatDoesntExist"), None)
-        self.assertEquals(find_hogql_aggregation("functionThatDoesntExist"), None)
-        self.assertEquals(find_hogql_posthog_function("functionThatDoesntExist"), None)
+        self.assertEqual(find_hogql_function("functionThatDoesntExist"), None)
+        self.assertEqual(find_hogql_aggregation("functionThatDoesntExist"), None)
+        self.assertEqual(find_hogql_posthog_function("functionThatDoesntExist"), None)
diff --git a/posthog/hogql/test/test_printer.py b/posthog/hogql/test/test_printer.py
index 1a8a2130c52..9c7a1fda936 100644
--- a/posthog/hogql/test/test_printer.py
+++ b/posthog/hogql/test/test_printer.py
@@ -1,4 +1,4 @@
-from typing import Literal, Optional, Dict
+from typing import Literal, Optional
 
 import pytest
 from django.test import override_settings
@@ -35,7 +35,7 @@ class TestPrinter(BaseTest):
         self,
         query: str,
         context: Optional[HogQLContext] = None,
-        placeholders: Optional[Dict[str, ast.Expr]] = None,
+        placeholders: Optional[dict[str, ast.Expr]] = None,
     ) -> str:
         return print_ast(
             parse_select(query, placeholders=placeholders),
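test_mapping.py swaps the camel-case unittest aliases for the canonical names: `assertEquals` has been a deprecated alias of `assertEqual` since Python 3.2, and the deprecated aliases were removed outright in Python 3.12 (the same applies to `assertAlmostEquals` in the timing tests below). Sketch:

    import unittest

    class Example(unittest.TestCase):
        def test_sum(self):
            self.assertEqual(1 + 1, 2)      # canonical name
            # self.assertEquals(1 + 1, 2)   # alias: warns, removed in 3.12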
diff --git a/posthog/hogql/test/test_property.py b/posthog/hogql/test/test_property.py
index 9b07a362bdd..4f6ed2e1150 100644
--- a/posthog/hogql/test/test_property.py
+++ b/posthog/hogql/test/test_property.py
@@ -1,4 +1,4 @@
-from typing import List, Union, cast, Optional, Dict, Any, Literal
+from typing import Union, cast, Optional, Any, Literal
 from unittest.mock import MagicMock, patch
 
 from posthog.constants import PropertyOperatorType, TREND_FILTER_TYPE_ACTIONS, TREND_FILTER_TYPE_EVENTS
@@ -46,7 +46,7 @@ class TestProperty(BaseTest):
     def _selector_to_expr(self, selector: str):
         return clear_locations(selector_to_expr(selector))
 
-    def _parse_expr(self, expr: str, placeholders: Optional[Dict[str, Any]] = None):
+    def _parse_expr(self, expr: str, placeholders: Optional[dict[str, Any]] = None):
         return clear_locations(parse_expr(expr, placeholders=placeholders))
 
     def test_has_aggregation(self):
@@ -416,7 +416,7 @@ class TestProperty(BaseTest):
             PropertyGroup(
                 type=PropertyOperatorType.AND,
                 values=cast(
-                    Union[List[Property], List[PropertyGroup]],
+                    Union[list[Property], list[PropertyGroup]],
                     [
                         Property(type="person", key="a", value="b", operator="exact"),
                         PropertyGroup(
diff --git a/posthog/hogql/test/test_query.py b/posthog/hogql/test/test_query.py
index 7dc29543807..b7f13f9c070 100644
--- a/posthog/hogql/test/test_query.py
+++ b/posthog/hogql/test/test_query.py
@@ -1014,7 +1014,7 @@ class TestQuery(ClickhouseTestMixin, APIBaseTest):
             f"LIMIT 100 "
             f"SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1",
         )
-        self.assertEqual(response.results[0], tuple((random_uuid for x in alternatives)))
+        self.assertEqual(response.results[0], tuple(random_uuid for x in alternatives))
 
     def test_property_access_with_arrays_zero_index_error(self):
         query = f"SELECT properties.something[0] FROM events"
diff --git a/posthog/hogql/test/test_resolver.py b/posthog/hogql/test/test_resolver.py
index a5f3b838c39..7cbd5a60a32 100644
--- a/posthog/hogql/test/test_resolver.py
+++ b/posthog/hogql/test/test_resolver.py
@@ -1,5 +1,5 @@
 from datetime import timezone, datetime, date
-from typing import Optional, Dict, cast
+from typing import Optional, cast
 import pytest
 from django.test import override_settings
 from uuid import UUID
@@ -28,7 +28,7 @@ from posthog.test.base import BaseTest
 class TestResolver(BaseTest):
     maxDiff = None
 
-    def _select(self, query: str, placeholders: Optional[Dict[str, ast.Expr]] = None) -> ast.SelectQuery:
+    def _select(self, query: str, placeholders: Optional[dict[str, ast.Expr]] = None) -> ast.SelectQuery:
         return cast(
             ast.SelectQuery,
             clone_expr(parse_select(query, placeholders=placeholders), clear_locations=True),
patch("posthog.hogql.timings.perf_counter", fake_perf_counter): @@ -73,10 +73,10 @@ class TestHogQLTimings(BaseTest): pass results = timings.to_dict() - self.assertAlmostEquals(results["./a/b/c"], 0.05) - self.assertAlmostEquals(results["./a/b"], 0.15) - self.assertAlmostEquals(results["./a"], 0.25) - self.assertAlmostEquals(results["."], 0.35) + self.assertAlmostEqual(results["./a/b/c"], 0.05) + self.assertAlmostEqual(results["./a/b"], 0.15) + self.assertAlmostEqual(results["./a"], 0.25) + self.assertAlmostEqual(results["."], 0.35) def test_overlapping_keys(self): with patch("posthog.hogql.timings.perf_counter", fake_perf_counter): @@ -88,5 +88,5 @@ class TestHogQLTimings(BaseTest): pass results = timings.to_dict() - self.assertAlmostEquals(results["./a"], 0.1) - self.assertAlmostEquals(results["."], 0.25) + self.assertAlmostEqual(results["./a"], 0.1) + self.assertAlmostEqual(results["."], 0.25) diff --git a/posthog/hogql/timings.py b/posthog/hogql/timings.py index fca643d640b..950d0f5bf23 100644 --- a/posthog/hogql/timings.py +++ b/posthog/hogql/timings.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field from time import perf_counter -from typing import Dict, List from contextlib import contextmanager from sentry_sdk import start_span @@ -11,10 +10,10 @@ from posthog.schema import QueryTiming @dataclass class HogQLTimings: # Completed time in seconds for different parts of the HogQL query - timings: Dict[str, float] = field(default_factory=dict) + timings: dict[str, float] = field(default_factory=dict) # Used for housekeeping - _timing_starts: Dict[str, float] = field(default_factory=dict) + _timing_starts: dict[str, float] = field(default_factory=dict) _timing_pointer: str = "." def __post_init__(self): @@ -37,11 +36,11 @@ class HogQLTimings: if span: span.set_tag("duration_seconds", duration) - def to_dict(self) -> Dict[str, float]: + def to_dict(self) -> dict[str, float]: timings = {**self.timings} for key, start in reversed(self._timing_starts.items()): timings[key] = timings.get(key, 0.0) + (perf_counter() - start) return timings - def to_list(self) -> List[QueryTiming]: + def to_list(self) -> list[QueryTiming]: return [QueryTiming(k=key, t=time) for key, time in self.to_dict().items()] diff --git a/posthog/hogql/transforms/in_cohort.py b/posthog/hogql/transforms/in_cohort.py index d10e393f539..67fdd57a7df 100644 --- a/posthog/hogql/transforms/in_cohort.py +++ b/posthog/hogql/transforms/in_cohort.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, cast, Literal +from typing import Optional, cast, Literal from posthog.hogql import ast @@ -13,7 +13,7 @@ from posthog.hogql.visitor import TraversingVisitor, clone_expr def resolve_in_cohorts( node: ast.Expr, dialect: Literal["hogql", "clickhouse"], - stack: Optional[List[ast.SelectQuery]] = None, + stack: Optional[list[ast.SelectQuery]] = None, context: HogQLContext = None, ): InCohortResolver(stack=stack, dialect=dialect, context=context).visit(node) @@ -23,13 +23,13 @@ def resolve_in_cohorts_conjoined( node: ast.Expr, dialect: Literal["hogql", "clickhouse"], context: HogQLContext, - stack: Optional[List[ast.SelectQuery]] = None, + stack: Optional[list[ast.SelectQuery]] = None, ): MultipleInCohortResolver(stack=stack, dialect=dialect, context=context).visit(node) class CohortCompareOperationTraverser(TraversingVisitor): - ops: List[ast.CompareOperation] = [] + ops: list[ast.CompareOperation] = [] def __init__(self, expr: ast.Expr): self.ops = [] @@ -50,10 +50,10 @@ class MultipleInCohortResolver(TraversingVisitor): 
diff --git a/posthog/hogql/transforms/in_cohort.py b/posthog/hogql/transforms/in_cohort.py
index d10e393f539..67fdd57a7df 100644
--- a/posthog/hogql/transforms/in_cohort.py
+++ b/posthog/hogql/transforms/in_cohort.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, cast, Literal
+from typing import Optional, cast, Literal
 
 
 from posthog.hogql import ast
@@ -13,7 +13,7 @@ from posthog.hogql.visitor import TraversingVisitor, clone_expr
 def resolve_in_cohorts(
     node: ast.Expr,
     dialect: Literal["hogql", "clickhouse"],
    stack: Optional[List[ast.SelectQuery]] = None,
-    stack: Optional[List[ast.SelectQuery]] = None,
+    stack: Optional[list[ast.SelectQuery]] = None,
     context: HogQLContext = None,
 ):
     InCohortResolver(stack=stack, dialect=dialect, context=context).visit(node)
@@ -23,13 +23,13 @@ def resolve_in_cohorts_conjoined(
     node: ast.Expr,
     dialect: Literal["hogql", "clickhouse"],
     context: HogQLContext,
-    stack: Optional[List[ast.SelectQuery]] = None,
+    stack: Optional[list[ast.SelectQuery]] = None,
 ):
     MultipleInCohortResolver(stack=stack, dialect=dialect, context=context).visit(node)
 
 
 class CohortCompareOperationTraverser(TraversingVisitor):
-    ops: List[ast.CompareOperation] = []
+    ops: list[ast.CompareOperation] = []
 
     def __init__(self, expr: ast.Expr):
         self.ops = []
@@ -50,10 +50,10 @@ class MultipleInCohortResolver(TraversingVisitor):
         self,
         dialect: Literal["hogql", "clickhouse"],
         context: HogQLContext,
-        stack: Optional[List[ast.SelectQuery]] = None,
+        stack: Optional[list[ast.SelectQuery]] = None,
     ):
         super().__init__()
-        self.stack: List[ast.SelectQuery] = stack or []
+        self.stack: list[ast.SelectQuery] = stack or []
         self.context = context
         self.dialect = dialect
 
@@ -68,7 +68,7 @@ class MultipleInCohortResolver(TraversingVisitor):
 
         self.stack.pop()
 
-    def _execute(self, node: ast.SelectQuery, compare_operations: List[ast.CompareOperation]):
+    def _execute(self, node: ast.SelectQuery, compare_operations: list[ast.CompareOperation]):
         if len(compare_operations) == 0:
             return
 
@@ -81,11 +81,11 @@ class MultipleInCohortResolver(TraversingVisitor):
             compare_node.right = ast.Constant(value=1)
 
     def _resolve_cohorts(
-        self, compare_operations: List[ast.CompareOperation]
-    ) -> List[Tuple[int, StaticOrDynamic, int]]:
+        self, compare_operations: list[ast.CompareOperation]
+    ) -> list[tuple[int, StaticOrDynamic, int]]:
         from posthog.models import Cohort
 
-        cohorts: List[Tuple[int, StaticOrDynamic, int]] = []
+        cohorts: list[tuple[int, StaticOrDynamic, int]] = []
 
         for node in compare_operations:
             arg = node.right
@@ -132,9 +132,9 @@ class MultipleInCohortResolver(TraversingVisitor):
 
     def _add_join(
         self,
-        cohorts: List[Tuple[int, StaticOrDynamic, int]],
+        cohorts: list[tuple[int, StaticOrDynamic, int]],
         select: ast.SelectQuery,
-        compare_operations: List[ast.CompareOperation],
+        compare_operations: list[ast.CompareOperation],
     ):
         must_add_join = True
         last_join = select.select_from
@@ -264,11 +264,11 @@ class InCohortResolver(TraversingVisitor):
     def __init__(
         self,
         dialect: Literal["hogql", "clickhouse"],
-        stack: Optional[List[ast.SelectQuery]] = None,
+        stack: Optional[list[ast.SelectQuery]] = None,
         context: HogQLContext = None,
     ):
         super().__init__()
-        self.stack: List[ast.SelectQuery] = stack or []
+        self.stack: list[ast.SelectQuery] = stack or []
         self.context = context
         self.dialect = dialect
diff --git a/posthog/hogql/transforms/lazy_tables.py b/posthog/hogql/transforms/lazy_tables.py
index bd3a3550034..c010fc13ce4 100644
--- a/posthog/hogql/transforms/lazy_tables.py
+++ b/posthog/hogql/transforms/lazy_tables.py
@@ -1,5 +1,5 @@
 import dataclasses
-from typing import Dict, List, Optional, cast, Literal
+from typing import Optional, cast, Literal
 
 from posthog.hogql import ast
 from posthog.hogql.context import HogQLContext
@@ -13,7 +13,7 @@ from posthog.hogql.visitor import TraversingVisitor, clone_expr
 def resolve_lazy_tables(
     node: ast.Expr,
     dialect: Literal["hogql", "clickhouse"],
-    stack: Optional[List[ast.SelectQuery]] = None,
+    stack: Optional[list[ast.SelectQuery]] = None,
     context: HogQLContext = None,
 ):
     LazyTableResolver(stack=stack, context=context, dialect=dialect).visit(node)
@@ -21,7 +21,7 @@ def resolve_lazy_tables(
 
 @dataclasses.dataclass
 class JoinToAdd:
-    fields_accessed: Dict[str, List[str | int]]
+    fields_accessed: dict[str, list[str | int]]
     lazy_join: LazyJoin
     from_table: str
     to_table: str
@@ -29,7 +29,7 @@ class JoinToAdd:
 
 @dataclasses.dataclass
 class TableToAdd:
-    fields_accessed: Dict[str, List[str | int]]
+    fields_accessed: dict[str, list[str | int]]
     lazy_table: LazyTable
 
 
@@ -37,13 +37,13 @@ class TableToAdd:
 class ConstraintOverride:
     alias: str
     table_name: str
-    chain_to_replace: List[str | int]
+    chain_to_replace: list[str | int]
 
 
 class FieldChainReplacer(TraversingVisitor):
-    overrides: List[ConstraintOverride] = {}
+    overrides: list[ConstraintOverride] = {}
 
-    def __init__(self, overrides: List[ConstraintOverride]) -> None:
+    def __init__(self, overrides: list[ConstraintOverride]) -> None:
         super().__init__()
         self.overrides = overrides
 
@@ -58,7 +58,7 @@ class LazyFinder(TraversingVisitor):
     max_type_visits: int = 3
 
     def __init__(self) -> None:
-        self.visited_field_type_counts: Dict[int, int] = {}
+        self.visited_field_type_counts: dict[int, int] = {}
 
     def visit_lazy_join_type(self, node: ast.LazyJoinType):
         self.found_lazy = True
@@ -80,11 +80,11 @@ class LazyTableResolver(TraversingVisitor):
     def __init__(
         self,
         dialect: Literal["hogql", "clickhouse"],
-        stack: Optional[List[ast.SelectQuery]] = None,
+        stack: Optional[list[ast.SelectQuery]] = None,
         context: HogQLContext = None,
     ):
         super().__init__()
-        self.stack_of_fields: List[List[ast.FieldType | ast.PropertyType]] = [[]] if stack else []
+        self.stack_of_fields: list[list[ast.FieldType | ast.PropertyType]] = [[]] if stack else []
         self.context = context
         self.dialect: Literal["hogql", "clickhouse"] = dialect
@@ -129,30 +129,30 @@ class LazyTableResolver(TraversingVisitor):
         assert select_type is not None
 
         # Collect each `ast.Field` with `ast.LazyJoinType`
-        field_collector: List[ast.FieldType | ast.PropertyType] = []
+        field_collector: list[ast.FieldType | ast.PropertyType] = []
         self.stack_of_fields.append(field_collector)
 
         # Collect all visited fields on lazy tables into field_collector
         super().visit_select_query(node)
 
         # Collect all the joins we need to add to the select query
-        joins_to_add: Dict[str, JoinToAdd] = {}
-        tables_to_add: Dict[str, TableToAdd] = {}
+        joins_to_add: dict[str, JoinToAdd] = {}
+        tables_to_add: dict[str, TableToAdd] = {}
 
         # First properties, then fields. This way we always get the smallest units to query first.
-        matched_properties: List[ast.PropertyType | ast.FieldType] = [
+        matched_properties: list[ast.PropertyType | ast.FieldType] = [
             property for property in field_collector if isinstance(property, ast.PropertyType)
         ]
-        matched_fields: List[ast.PropertyType | ast.FieldType] = [
+        matched_fields: list[ast.PropertyType | ast.FieldType] = [
             field for field in field_collector if isinstance(field, ast.FieldType)
         ]
-        sorted_properties: List[ast.PropertyType | ast.FieldType] = matched_properties + matched_fields
+        sorted_properties: list[ast.PropertyType | ast.FieldType] = matched_properties + matched_fields
 
         # Look for tables without requested fields to support cases like `select count() from table`
         join = node.select_from
         while join:
             if join.table is not None and isinstance(join.table.type, ast.LazyTableType):
-                fields: List[ast.FieldType | ast.PropertyType] = []
+                fields: list[ast.FieldType | ast.PropertyType] = []
                 for field_or_property in field_collector:
                     if isinstance(field_or_property, ast.FieldType):
                         if isinstance(field_or_property.table_type, ast.TableAliasType):
@@ -186,7 +186,7 @@ class LazyTableResolver(TraversingVisitor):
 
             # Traverse the lazy tables until we reach a real table, collecting them in a list.
            # Usually there's just one or two.
-            table_types: List[ast.LazyJoinType | ast.LazyTableType | ast.TableAliasType] = []
+            table_types: list[ast.LazyJoinType | ast.LazyTableType | ast.TableAliasType] = []
             while (
                 isinstance(table_type, ast.TableAliasType)
                 or isinstance(table_type, ast.LazyJoinType)
@@ -217,12 +217,12 @@ class LazyTableResolver(TraversingVisitor):
                     )
                 new_join = joins_to_add[to_table]
                 if table_type == field.table_type:
-                    chain: List[str | int] = []
+                    chain: list[str | int] = []
                     chain.append(field.name)
                     if property is not None:
                         chain.extend(property.chain)
                         property.joined_subquery_field_name = (
-                            f"{field.name}___{'___'.join((str(x) for x in property.chain))}"
+                            f"{field.name}___{'___'.join(str(x) for x in property.chain)}"
                         )
                         new_join.fields_accessed[property.joined_subquery_field_name] = chain
                     else:
@@ -241,7 +241,7 @@ class LazyTableResolver(TraversingVisitor):
                     if property is not None:
                         chain.extend(property.chain)
                         property.joined_subquery_field_name = (
-                            f"{field.name}___{'___'.join((str(x) for x in property.chain))}"
+                            f"{field.name}___{'___'.join(str(x) for x in property.chain)}"
                         )
                         new_table.fields_accessed[property.joined_subquery_field_name] = chain
                     else:
@@ -259,12 +259,12 @@ class LazyTableResolver(TraversingVisitor):
                     )
                 new_join = joins_to_add[to_table]
                 if table_type == field.table_type:
-                    chain: List[str | int] = []
+                    chain: list[str | int] = []
                     chain.append(field.name)
                     if property is not None:
                         chain.extend(property.chain)
                         property.joined_subquery_field_name = (
-                            f"{field.name}___{'___'.join((str(x) for x in property.chain))}"
+                            f"{field.name}___{'___'.join(str(x) for x in property.chain)}"
                         )
                         new_join.fields_accessed[property.joined_subquery_field_name] = chain
                     else:
@@ -283,7 +283,7 @@ class LazyTableResolver(TraversingVisitor):
                     if property is not None:
                         chain.extend(property.chain)
                         property.joined_subquery_field_name = (
-                            f"{field.name}___{'___'.join((str(x) for x in property.chain))}"
+                            f"{field.name}___{'___'.join(str(x) for x in property.chain)}"
                        )
                         new_table.fields_accessed[property.joined_subquery_field_name] = chain
                     else:
 
         # Make sure we also add fields we will use for the join's "ON" condition into the list of fields accessed.
        # Without this "pdi.person.id" won't work if you did not ALSO select "pdi.person_id" explicitly for the join.
-        join_constraint_overrides: Dict[str, List[ConstraintOverride]] = {}
+        join_constraint_overrides: dict[str, list[ConstraintOverride]] = {}
 
-        def create_override(table_name: str, field_chain: List[str | int]) -> None:
-            alias = f"{table_name}___{'___'.join((str(x) for x in field_chain))}"
+        def create_override(table_name: str, field_chain: list[str | int]) -> None:
+            alias = f"{table_name}___{'___'.join(str(x) for x in field_chain)}"
 
             if table_name in tables_to_add:
                 tables_to_add[table_name].fields_accessed[alias] = field_chain
@@ -387,7 +387,7 @@ class LazyTableResolver(TraversingVisitor):
                     node.select_from = join_to_add
 
             # Collect any fields or properties that may have been added from the join_function with the LazyJoinType
-            join_field_collector: List[ast.FieldType | ast.PropertyType] = []
+            join_field_collector: list[ast.FieldType | ast.PropertyType] = []
             self.stack_of_fields.append(join_field_collector)
             super().visit(join_to_add)
             self.stack_of_fields.pop()
diff --git a/posthog/hogql/transforms/property_types.py b/posthog/hogql/transforms/property_types.py
index cc5451bf6bc..5627980fa0d 100644
--- a/posthog/hogql/transforms/property_types.py
+++ b/posthog/hogql/transforms/property_types.py
@@ -1,4 +1,4 @@
-from typing import Dict, Set, Literal, Optional, cast
+from typing import Literal, Optional, cast
 
 from posthog.hogql import ast
 from posthog.hogql.context import HogQLContext
@@ -81,9 +81,9 @@ class PropertyFinder(TraversingVisitor):
 
     def __init__(self, context: HogQLContext):
         super().__init__()
-        self.person_properties: Set[str] = set()
-        self.event_properties: Set[str] = set()
-        self.group_properties: Dict[int, Set[str]] = {}
+        self.person_properties: set[str] = set()
+        self.event_properties: set[str] = set()
+        self.group_properties: dict[int, set[str]] = {}
         self.found_timestamps = False
         self.context = context
 
@@ -123,9 +123,9 @@ class PropertySwapper(CloningVisitor):
     def __init__(
         self,
         timezone: str,
-        event_properties: Dict[str, str],
-        person_properties: Dict[str, str],
-        group_properties: Dict[str, str],
+        event_properties: dict[str, str],
+        person_properties: dict[str, str],
+        group_properties: dict[str, str],
         context: HogQLContext,
     ):
         super().__init__(clear_types=False)
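The last two files continue the `collections.abc` migration for `Generator`, `Sequence`, and `Iterator`. For reference, the three parameters of the `Generator` annotation used below are the yield, send, and return types. Sketch:

    from collections.abc import Generator

    def rows() -> Generator[list, None, None]:  # yields lists; no send/return value
        yield []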
get_recordings(self, matching_events) -> dict[str, list[dict]]: return RecordingsHelper(self.team).get_recordings(matching_events) - def input_columns(self) -> List[str]: + def input_columns(self) -> list[str]: return ["person", "id", "created_at", "person.$delete"] - def filter_conditions(self) -> List[ast.Expr]: - where_exprs: List[ast.Expr] = [] + def filter_conditions(self) -> list[ast.Expr]: + where_exprs: list[ast.Expr] = [] if self.query.properties: where_exprs.append(property_to_expr(self.query.properties, self.team, scope="person")) @@ -98,7 +98,7 @@ class PersonStrategy(ActorStrategy): ) return where_exprs - def order_by(self) -> Optional[List[ast.OrderExpr]]: + def order_by(self) -> Optional[list[ast.OrderExpr]]: if self.query.orderBy not in [["person"], ["person DESC"], ["person ASC"]]: return None @@ -125,7 +125,7 @@ class GroupStrategy(ActorStrategy): self.group_type_index = group_type_index super().__init__(**kwargs) - def get_actors(self, actor_ids) -> Dict[str, Dict]: + def get_actors(self, actor_ids) -> dict[str, dict]: return { str(p["group_key"]): { "id": p["group_key"], @@ -140,11 +140,11 @@ class GroupStrategy(ActorStrategy): .iterator(chunk_size=self.paginator.limit) } - def input_columns(self) -> List[str]: + def input_columns(self) -> list[str]: return ["group"] - def filter_conditions(self) -> List[ast.Expr]: - where_exprs: List[ast.Expr] = [] + def filter_conditions(self) -> list[ast.Expr]: + where_exprs: list[ast.Expr] = [] if self.query.search is not None and self.query.search != "": where_exprs.append( @@ -166,7 +166,7 @@ class GroupStrategy(ActorStrategy): return where_exprs - def order_by(self) -> Optional[List[ast.OrderExpr]]: + def order_by(self) -> Optional[list[ast.OrderExpr]]: if self.query.orderBy not in [["group"], ["group DESC"], ["group ASC"]]: return None diff --git a/posthog/hogql_queries/actors_query_runner.py b/posthog/hogql_queries/actors_query_runner.py index da2e142bf66..8224067c24d 100644 --- a/posthog/hogql_queries/actors_query_runner.py +++ b/posthog/hogql_queries/actors_query_runner.py @@ -1,6 +1,7 @@ import itertools from datetime import timedelta -from typing import List, Generator, Sequence, Iterator, Optional +from typing import Optional +from collections.abc import Generator, Sequence, Iterator from posthog.hogql import ast from posthog.hogql.parser import parse_expr, parse_order_expr from posthog.hogql.property import has_aggregation @@ -53,7 +54,7 @@ class ActorsQueryRunner(QueryRunner): actors_lookup, recordings_column_index: Optional[int], recordings_lookup: Optional[dict[str, list[dict]]], - ) -> Generator[List, None, None]: + ) -> Generator[list, None, None]: for result in results: new_row = list(result) actor_id = str(result[actor_column_index]) @@ -70,9 +71,7 @@ class ActorsQueryRunner(QueryRunner): return None, None column_index_events = input_columns.index("matched_recordings") - matching_events_list = itertools.chain.from_iterable( - (row[column_index_events] for row in self.paginator.results) - ) + matching_events_list = itertools.chain.from_iterable(row[column_index_events] for row in self.paginator.results) return column_index_events, self.strategy.get_recordings(matching_events_list) def calculate(self) -> ActorsQueryResponse: @@ -85,7 +84,7 @@ class ActorsQueryRunner(QueryRunner): ) input_columns = self.input_columns() missing_actors_count = None - results: Sequence[List] | Iterator[List] = self.paginator.results + results: Sequence[list] | Iterator[list] = self.paginator.results enrich_columns = filter(lambda column: 
column in ("person", "group", "actor"), input_columns) for column_name in enrich_columns: @@ -110,14 +109,14 @@ class ActorsQueryRunner(QueryRunner): **self.paginator.response_params(), ) - def input_columns(self) -> List[str]: + def input_columns(self) -> list[str]: if self.query.select: return self.query.select return self.strategy.input_columns() # TODO: Figure out a more sure way of getting the actor id than using the alias or chain name - def source_id_column(self, source_query: ast.SelectQuery | ast.SelectUnionQuery) -> List[str]: + def source_id_column(self, source_query: ast.SelectQuery | ast.SelectUnionQuery) -> list[str]: # Figure out the id column of the source query, first column that has id in the name if isinstance(source_query, ast.SelectQuery): select = source_query.select diff --git a/posthog/hogql_queries/events_query_runner.py b/posthog/hogql_queries/events_query_runner.py index fe04ed8aa85..9dc329e9e46 100644 --- a/posthog/hogql_queries/events_query_runner.py +++ b/posthog/hogql_queries/events_query_runner.py @@ -1,6 +1,6 @@ import json from datetime import timedelta -from typing import Dict, List, Optional +from typing import Optional from dateutil.parser import isoparse from django.db.models import Prefetch @@ -53,8 +53,8 @@ class EventsQueryRunner(QueryRunner): with self.timings.measure("build_ast"): # columns & group_by with self.timings.measure("columns"): - select_input: List[str] = [] - person_indices: List[int] = [] + select_input: list[str] = [] + person_indices: list[int] = [] for index, col in enumerate(self.select_input_raw()): # Selecting a "*" expands the list of columns, resulting in a table that's not what we asked for. # Instead, ask for a tuple with all the columns we want. Later transform this back into a dict. 
@@ -66,11 +66,11 @@ class EventsQueryRunner(QueryRunner): person_indices.append(index) else: select_input.append(col) - select: List[ast.Expr] = [parse_expr(column, timings=self.timings) for column in select_input] + select: list[ast.Expr] = [parse_expr(column, timings=self.timings) for column in select_input] with self.timings.measure("aggregations"): - group_by: List[ast.Expr] = [column for column in select if not has_aggregation(column)] - aggregations: List[ast.Expr] = [column for column in select if has_aggregation(column)] + group_by: list[ast.Expr] = [column for column in select if not has_aggregation(column)] + aggregations: list[ast.Expr] = [column for column in select if has_aggregation(column)] has_any_aggregation = len(aggregations) > 0 # filters @@ -210,7 +210,7 @@ class EventsQueryRunner(QueryRunner): ).data self.paginator.results[index][star_idx] = new_result - person_indices: List[int] = [] + person_indices: list[int] = [] for index, col in enumerate(self.select_input_raw()): if col.split("--")[0].strip() == "person": person_indices.append(index) @@ -222,7 +222,7 @@ class EventsQueryRunner(QueryRunner): distinct_ids = list({event[person_idx] for event in self.paginator.results}) persons = get_persons_by_distinct_ids(self.team.pk, distinct_ids) persons = persons.prefetch_related(Prefetch("persondistinctid_set", to_attr="distinct_ids_cache")) - distinct_to_person: Dict[str, Person] = {} + distinct_to_person: dict[str, Person] = {} for person in persons: if person: for person_distinct_id in person.distinct_ids: @@ -268,7 +268,7 @@ class EventsQueryRunner(QueryRunner): return new_query - def select_input_raw(self) -> List[str]: + def select_input_raw(self) -> list[str]: return ["*"] if len(self.query.select) == 0 else self.query.select def _is_stale(self, cached_result_package): diff --git a/posthog/hogql_queries/hogql_query_runner.py b/posthog/hogql_queries/hogql_query_runner.py index 46b4c105a43..3a9a0b62efd 100644 --- a/posthog/hogql_queries/hogql_query_runner.py +++ b/posthog/hogql_queries/hogql_query_runner.py @@ -1,5 +1,6 @@ from datetime import timedelta -from typing import Callable, Dict, Optional, cast +from typing import Optional, cast +from collections.abc import Callable from posthog.clickhouse.client.connection import Workload from posthog.hogql import ast @@ -26,7 +27,7 @@ class HogQLQueryRunner(QueryRunner): def to_query(self) -> ast.SelectQuery: if self.timings is None: self.timings = HogQLTimings() - values: Optional[Dict[str, ast.Expr]] = ( + values: Optional[dict[str, ast.Expr]] = ( {key: ast.Constant(value=value) for key, value in self.query.values.items()} if self.query.values else None ) with self.timings.measure("parse_select"): diff --git a/posthog/hogql_queries/insights/funnels/base.py b/posthog/hogql_queries/insights/funnels/base.py index 1dade0de4b0..40614464f13 100644 --- a/posthog/hogql_queries/insights/funnels/base.py +++ b/posthog/hogql_queries/insights/funnels/base.py @@ -1,6 +1,6 @@ from abc import ABC from functools import cached_property -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast import uuid from posthog.clickhouse.materialized_columns.column import ColumnName from posthog.hogql import ast @@ -37,14 +37,14 @@ from rest_framework.exceptions import ValidationError class FunnelBase(ABC): context: FunnelQueryContext - _extra_event_fields: List[ColumnName] - _extra_event_properties: List[PropertyName] + _extra_event_fields: list[ColumnName] + _extra_event_properties: 
list[PropertyName] def __init__(self, context: FunnelQueryContext): self.context = context - self._extra_event_fields: List[ColumnName] = [] - self._extra_event_properties: List[PropertyName] = [] + self._extra_event_fields: list[ColumnName] = [] + self._extra_event_properties: list[PropertyName] = [] if ( hasattr(self.context, "actorsQuery") @@ -86,7 +86,7 @@ class FunnelBase(ABC): raise NotImplementedError() @cached_property - def breakdown_cohorts(self) -> List[Cohort]: + def breakdown_cohorts(self) -> list[Cohort]: team, breakdown = self.context.team, self.context.breakdown if isinstance(breakdown, list): @@ -97,7 +97,7 @@ class FunnelBase(ABC): return list(cohorts) @cached_property - def breakdown_cohorts_ids(self) -> List[int]: + def breakdown_cohorts_ids(self) -> list[int]: breakdown = self.context.breakdown ids = [int(cohort.pk) for cohort in self.breakdown_cohorts] @@ -108,7 +108,7 @@ class FunnelBase(ABC): return ids @cached_property - def breakdown_values(self) -> List[int] | List[str] | List[List[str]]: + def breakdown_values(self) -> list[int] | list[str] | list[list[str]]: # """ # Returns the top N breakdown prop values for event/person breakdown @@ -169,7 +169,7 @@ class FunnelBase(ABC): else: prop_exprs = [] - where_exprs: List[ast.Expr | None] = [ + where_exprs: list[ast.Expr | None] = [ # entity filter entity_expr, # prop filter @@ -209,7 +209,7 @@ class FunnelBase(ABC): raise ValidationError("Apologies, there has been an error computing breakdown values.") return [row[0] for row in results[0:breakdown_limit_or_default]] - def _get_breakdown_select_prop(self) -> List[ast.Expr]: + def _get_breakdown_select_prop(self) -> list[ast.Expr]: breakdown, breakdownAttributionType, funnelsFilter = ( self.context.breakdown, self.context.breakdownAttributionType, @@ -296,7 +296,7 @@ class FunnelBase(ABC): def _format_results( self, results - ) -> Union[FunnelTimeToConvertResults, List[Dict[str, Any]], List[List[Dict[str, Any]]]]: + ) -> Union[FunnelTimeToConvertResults, list[dict[str, Any]], list[list[dict[str, Any]]]]: breakdown = self.context.breakdown if not results or len(results) == 0: @@ -387,9 +387,9 @@ class FunnelBase(ABC): step: ActionsNode | EventsNode | DataWarehouseNode, count: int, index: int, - people: Optional[List[uuid.UUID]] = None, + people: Optional[list[uuid.UUID]] = None, sampling_factor: Optional[float] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: action_id: Optional[str | int] if isinstance(step, EventsNode): name = step.event @@ -419,7 +419,7 @@ class FunnelBase(ABC): def _get_inner_event_query( self, - entities: List[EntityNode] | None = None, + entities: list[EntityNode] | None = None, entity_name="events", skip_entity_filter=False, skip_step_filter=False, @@ -433,7 +433,7 @@ class FunnelBase(ABC): ) entities_to_use = entities or query.series - extra_fields: List[str] = [] + extra_fields: list[str] = [] for prop in self.context.includeProperties: extra_fields.append(prop) @@ -450,7 +450,7 @@ class FunnelBase(ABC): # extra_event_properties=self._extra_event_properties, # ).get_query(entities_to_use, entity_name, skip_entity_filter=skip_entity_filter) - all_step_cols: List[ast.Expr] = [] + all_step_cols: list[ast.Expr] = [] for index, entity in enumerate(entities_to_use): step_cols = self._get_step_col(entity, index, entity_name) all_step_cols.extend(step_cols) @@ -489,7 +489,7 @@ class FunnelBase(ABC): def _get_cohort_breakdown_join(self) -> ast.JoinExpr: breakdown = self.context.breakdown - cohort_queries: List[ast.SelectQuery] = [] + 
cohort_queries: list[ast.SelectQuery] = [] for cohort in self.breakdown_cohorts: query = parse_select( @@ -564,7 +564,7 @@ class FunnelBase(ABC): return query def _get_steps_conditions(self, length: int) -> ast.Expr: - step_conditions: List[ast.Expr] = [] + step_conditions: list[ast.Expr] = [] for index in range(length): step_conditions.append(parse_expr(f"step_{index} = 1")) @@ -580,10 +580,10 @@ class FunnelBase(ABC): index: int, entity_name: str, step_prefix: str = "", - ) -> List[ast.Expr]: + ) -> list[ast.Expr]: # step prefix is used to distinguish actual steps, and exclusion steps # without the prefix, we get the same parameter binding for both, which borks things up - step_cols: List[ast.Expr] = [] + step_cols: list[ast.Expr] = [] condition = self._build_step_query(entity, index, entity_name, step_prefix) step_cols.append( parse_expr(f"if({{condition}}, 1, 0) as {step_prefix}step_{index}", placeholders={"condition": condition}) @@ -626,7 +626,7 @@ class FunnelBase(ABC): else: return event_expr - def _get_timestamp_outer_select(self) -> List[ast.Expr]: + def _get_timestamp_outer_select(self) -> list[ast.Expr]: if self.context.includePrecedingTimestamp: return [ast.Field(chain=["max_timestamp"]), ast.Field(chain=["min_timestamp"])] elif self.context.includeTimestamp: @@ -646,7 +646,7 @@ class FunnelBase(ABC): funnelCustomSteps = actorsQuery.funnelCustomSteps funnelStepBreakdown = actorsQuery.funnelStepBreakdown - conditions: List[ast.Expr] = [] + conditions: list[ast.Expr] = [] if funnelCustomSteps: conditions.append(parse_expr(f"steps IN {funnelCustomSteps}")) @@ -673,7 +673,7 @@ class FunnelBase(ABC): return ast.And(exprs=conditions) - def _get_funnel_person_step_events(self) -> List[ast.Expr]: + def _get_funnel_person_step_events(self) -> list[ast.Expr]: if ( hasattr(self.context, "actorsQuery") and self.context.actorsQuery is not None @@ -694,23 +694,23 @@ class FunnelBase(ABC): return [parse_expr(f"step_{matching_events_step_num}_matching_events as matching_events")] return [] - def _get_count_columns(self, max_steps: int) -> List[ast.Expr]: - exprs: List[ast.Expr] = [] + def _get_count_columns(self, max_steps: int) -> list[ast.Expr]: + exprs: list[ast.Expr] = [] for i in range(max_steps): exprs.append(parse_expr(f"countIf(steps = {i + 1}) step_{i + 1}")) return exprs - def _get_step_time_names(self, max_steps: int) -> List[ast.Expr]: - exprs: List[ast.Expr] = [] + def _get_step_time_names(self, max_steps: int) -> list[ast.Expr]: + exprs: list[ast.Expr] = [] for i in range(1, max_steps): exprs.append(parse_expr(f"step_{i}_conversion_time")) return exprs - def _get_final_matching_event(self, max_steps: int) -> List[ast.Expr]: + def _get_final_matching_event(self, max_steps: int) -> list[ast.Expr]: statement = None for i in range(max_steps - 1, -1, -1): if i == max_steps - 1: @@ -721,7 +721,7 @@ class FunnelBase(ABC): statement = f"if(isNull(latest_{i}),step_{i-1}_matching_event,{statement})" return [parse_expr(f"{statement} as final_matching_event")] if statement else [] - def _get_matching_events(self, max_steps: int) -> List[ast.Expr]: + def _get_matching_events(self, max_steps: int) -> list[ast.Expr]: if ( hasattr(self.context, "actorsQuery") and self.context.actorsQuery is not None @@ -737,8 +737,8 @@ class FunnelBase(ABC): return [*events, *self._get_final_matching_event(max_steps)] return [] - def _get_matching_event_arrays(self, max_steps: int) -> List[ast.Expr]: - exprs: List[ast.Expr] = [] + def _get_matching_event_arrays(self, max_steps: int) -> list[ast.Expr]: + exprs: 
list[ast.Expr] = [] if ( hasattr(self.context, "actorsQuery") and self.context.actorsQuery is not None @@ -749,8 +749,8 @@ class FunnelBase(ABC): exprs.append(parse_expr(f"groupArray(10)(final_matching_event) as final_matching_events")) return exprs - def _get_step_time_avgs(self, max_steps: int, inner_query: bool = False) -> List[ast.Expr]: - exprs: List[ast.Expr] = [] + def _get_step_time_avgs(self, max_steps: int, inner_query: bool = False) -> list[ast.Expr]: + exprs: list[ast.Expr] = [] for i in range(1, max_steps): exprs.append( @@ -761,8 +761,8 @@ class FunnelBase(ABC): return exprs - def _get_step_time_median(self, max_steps: int, inner_query: bool = False) -> List[ast.Expr]: - exprs: List[ast.Expr] = [] + def _get_step_time_median(self, max_steps: int, inner_query: bool = False) -> list[ast.Expr]: + exprs: list[ast.Expr] = [] for i in range(1, max_steps): exprs.append( @@ -773,7 +773,7 @@ class FunnelBase(ABC): return exprs - def _get_timestamp_selects(self) -> Tuple[List[ast.Expr], List[ast.Expr]]: + def _get_timestamp_selects(self) -> tuple[list[ast.Expr], list[ast.Expr]]: """ Returns timestamp selectors for the target step and optionally the preceding step. In the former case, always returns the timestamp for the first and last step as well. @@ -829,11 +829,11 @@ class FunnelBase(ABC): else: return [], [] - def _get_step_times(self, max_steps: int) -> List[ast.Expr]: + def _get_step_times(self, max_steps: int) -> list[ast.Expr]: windowInterval = self.context.funnelWindowInterval windowIntervalUnit = funnel_window_interval_unit_to_sql(self.context.funnelWindowIntervalUnit) - exprs: List[ast.Expr] = [] + exprs: list[ast.Expr] = [] for i in range(1, max_steps): exprs.append( @@ -844,12 +844,12 @@ class FunnelBase(ABC): return exprs - def _get_partition_cols(self, level_index: int, max_steps: int) -> List[ast.Expr]: + def _get_partition_cols(self, level_index: int, max_steps: int) -> list[ast.Expr]: query, funnelsFilter = self.context.query, self.context.funnelsFilter exclusions = funnelsFilter.exclusions series = query.series - exprs: List[ast.Expr] = [] + exprs: list[ast.Expr] = [] for i in range(0, max_steps): exprs.append(ast.Field(chain=[f"step_{i}"])) @@ -894,7 +894,7 @@ class FunnelBase(ABC): return exprs - def _get_breakdown_prop_expr(self, group_remaining=False) -> List[ast.Expr]: + def _get_breakdown_prop_expr(self, group_remaining=False) -> list[ast.Expr]: # SEE BELOW for a string implementation of the following breakdown, breakdownType = self.context.breakdown, self.context.breakdownType @@ -938,7 +938,7 @@ class FunnelBase(ABC): else: return "" - def _get_breakdown_conditions(self) -> Optional[List[int] | List[str] | List[List[str]]]: + def _get_breakdown_conditions(self) -> Optional[list[int] | list[str] | list[list[str]]]: """ For people, pagination sets the offset param, which is common across filters and gives us the wrong breakdown values here, so we override it. 
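The hunks around this point are all the same mechanical rewrite from the pyupgrade rules (PEP 585): the deprecated typing.List/typing.Dict/typing.Tuple aliases become the builtin generics, which are subscriptable at runtime from Python 3.9 on. A minimal before/after sketch of the pattern, using hypothetical names rather than code from this patch:

    # Before: container generics have to be imported from typing.
    from typing import Dict, List, Optional

    def step_times(steps: List[str], default: Optional[float] = None) -> Dict[str, float]:
        return {name: default or 0.0 for name in steps}

    # After (PEP 585, Python >= 3.9): the builtins subscript directly, so the
    # typing import shrinks to names that have no builtin equivalent.
    from typing import Optional

    def step_times(steps: list[str], default: Optional[float] = None) -> dict[str, float]:
        return {name: default or 0.0 for name in steps}

Note that Optional and Union are left untouched throughout the patch; rewriting them to the PEP 604 "X | None" form is a separate rule that was evidently not applied here.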
@@ -957,7 +957,7 @@ class FunnelBase(ABC): breakdown, breakdownType = self.context.breakdown, self.context.breakdownType return not isinstance(breakdown, str) and breakdownType != "cohort" - def _get_exclusion_condition(self) -> List[ast.Expr]: + def _get_exclusion_condition(self) -> list[ast.Expr]: funnelsFilter = self.context.funnelsFilter windowInterval = self.context.funnelWindowInterval windowIntervalUnit = funnel_window_interval_unit_to_sql(self.context.funnelWindowIntervalUnit) @@ -965,7 +965,7 @@ class FunnelBase(ABC): if not funnelsFilter.exclusions: return [] - conditions: List[ast.Expr] = [] + conditions: list[ast.Expr] = [] for exclusion_id, exclusion in enumerate(funnelsFilter.exclusions): from_time = f"latest_{exclusion.funnelFromStep}" @@ -995,7 +995,7 @@ class FunnelBase(ABC): if curr_index == 1: return ast.Constant(value=1) - conditions: List[ast.Expr] = [] + conditions: list[ast.Expr] = [] for i in range(1, curr_index): duplicate_event = is_equal(series[i], series[i - 1]) or is_superset(series[i], series[i - 1]) @@ -1016,8 +1016,8 @@ class FunnelBase(ABC): ], ) - def _get_person_and_group_properties(self, aggregate: bool = False) -> List[ast.Expr]: - exprs: List[ast.Expr] = [] + def _get_person_and_group_properties(self, aggregate: bool = False) -> list[ast.Expr]: + exprs: list[ast.Expr] = [] for prop in self.context.includeProperties: exprs.append(parse_expr(f"any({prop}) as {prop}") if aggregate else parse_expr(prop)) diff --git a/posthog/hogql_queries/insights/funnels/funnel.py b/posthog/hogql_queries/insights/funnels/funnel.py index b5ce2bb7faf..1975645d753 100644 --- a/posthog/hogql_queries/insights/funnels/funnel.py +++ b/posthog/hogql_queries/insights/funnels/funnel.py @@ -1,5 +1,3 @@ -from typing import List - from posthog.hogql import ast from posthog.hogql.parser import parse_expr from posthog.hogql_queries.insights.funnels.base import FunnelBase @@ -35,7 +33,7 @@ class Funnel(FunnelBase): breakdown_exprs = self._get_breakdown_prop_expr() - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ *self._get_count_columns(max_steps), *self._get_step_time_avgs(max_steps), *self._get_step_time_median(max_steps), @@ -54,13 +52,13 @@ class Funnel(FunnelBase): inner_timestamps, outer_timestamps = self._get_timestamp_selects() person_and_group_properties = self._get_person_and_group_properties(aggregate=True) - group_by_columns: List[ast.Expr] = [ + group_by_columns: list[ast.Expr] = [ ast.Field(chain=["aggregation_target"]), ast.Field(chain=["steps"]), *breakdown_exprs, ] - outer_select: List[ast.Expr] = [ + outer_select: list[ast.Expr] = [ *group_by_columns, *self._get_step_time_avgs(max_steps, inner_query=True), *self._get_step_time_median(max_steps, inner_query=True), @@ -74,7 +72,7 @@ class Funnel(FunnelBase): f"max(steps) over (PARTITION BY aggregation_target {self._get_breakdown_prop()}) as max_steps" ) - inner_select: List[ast.Expr] = [ + inner_select: list[ast.Expr] = [ *group_by_columns, max_steps_expr, *self._get_step_time_names(max_steps), @@ -106,7 +104,7 @@ class Funnel(FunnelBase): formatted_query = self._build_step_subquery(2, max_steps) breakdown_exprs = self._get_breakdown_prop_expr() - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Field(chain=["*"]), ast.Alias(alias="steps", expr=self._get_sorting_condition(max_steps, max_steps)), *self._get_exclusion_condition(), @@ -135,7 +133,7 @@ class Funnel(FunnelBase): def _build_step_subquery( self, level_index: int, max_steps: int, event_names_alias: str = "events" ) -> ast.SelectQuery: - 
select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Field(chain=["aggregation_target"]), ast.Field(chain=["timestamp"]), ] @@ -175,12 +173,12 @@ class Funnel(FunnelBase): ), ) - def _get_comparison_cols(self, level_index: int, max_steps: int) -> List[ast.Expr]: + def _get_comparison_cols(self, level_index: int, max_steps: int) -> list[ast.Expr]: """ level_index: The current smallest comparison step. Everything before level index is already at the minimum ordered timestamps. """ - exprs: List[ast.Expr] = [] + exprs: list[ast.Expr] = [] funnelsFilter = self.context.funnelsFilter exclusions = funnelsFilter.exclusions @@ -225,7 +223,7 @@ class Funnel(FunnelBase): return exprs def _get_comparison_at_step(self, index: int, level_index: int) -> ast.Or: - exprs: List[ast.Expr] = [] + exprs: list[ast.Expr] = [] for i in range(level_index, index + 1): exprs.append(parse_expr(f"latest_{i} < latest_{level_index - 1}")) diff --git a/posthog/hogql_queries/insights/funnels/funnel_correlation_query_runner.py b/posthog/hogql_queries/insights/funnels/funnel_correlation_query_runner.py index 04b1115fd38..035339c8e02 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_correlation_query_runner.py +++ b/posthog/hogql_queries/insights/funnels/funnel_correlation_query_runner.py @@ -1,6 +1,6 @@ import dataclasses from datetime import timedelta -from typing import List, Literal, Optional, Any, Dict, Set, TypedDict, cast +from typing import Literal, Optional, Any, TypedDict, cast from posthog.constants import AUTOCAPTURE_EVENT from posthog.hogql.parser import parse_select @@ -95,7 +95,7 @@ class FunnelCorrelationQueryRunner(QueryRunner): def __init__( self, - query: FunnelCorrelationQuery | Dict[str, Any], + query: FunnelCorrelationQuery | dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None, modifiers: Optional[HogQLQueryModifiers] = None, @@ -132,7 +132,7 @@ class FunnelCorrelationQueryRunner(QueryRunner): # Used for generating the funnel persons cte funnel_order_actor_class = get_funnel_actor_class(self.context.funnelsFilter)(context=self.context) assert isinstance( - funnel_order_actor_class, (FunnelActors, FunnelStrictActors, FunnelUnorderedActors) + funnel_order_actor_class, FunnelActors | FunnelStrictActors | FunnelUnorderedActors ) # for typings self._funnel_actors_generator = funnel_order_actor_class @@ -228,7 +228,7 @@ class FunnelCorrelationQueryRunner(QueryRunner): modifiers=self.modifiers, ) - def _calculate(self) -> tuple[List[EventOddsRatio], bool, str, HogQLQueryResponse]: + def _calculate(self) -> tuple[list[EventOddsRatio], bool, str, HogQLQueryResponse]: query = self.to_query() hogql = to_printed_hogql(query, self.team) @@ -823,8 +823,8 @@ class FunnelCorrelationQueryRunner(QueryRunner): props_str = ", ".join(props) return f"arrayJoin(arrayZip({self.query.funnelCorrelationNames}, [{props_str}])) as prop" - def _get_funnel_step_names(self) -> List[str]: - events: Set[str] = set() + def _get_funnel_step_names(self) -> list[str]: + events: set[str] = set() for entity in self.funnels_query.series: if isinstance(entity, ActionsNode): action = Action.objects.get(pk=int(entity.id), team=self.context.team) @@ -838,8 +838,8 @@ class FunnelCorrelationQueryRunner(QueryRunner): return sorted(events) @property - def properties_to_include(self) -> List[str]: - props_to_include: List[str] = [] + def properties_to_include(self) -> list[str]: + props_to_include: list[str] = [] # TODO: implement or remove # if self.query.funnelCorrelationType == FunnelCorrelationResultsType.properties: 
# assert self.query.funnelCorrelationNames is not None diff --git a/posthog/hogql_queries/insights/funnels/funnel_event_query.py b/posthog/hogql_queries/insights/funnels/funnel_event_query.py index b2fd19083ed..8acb0f7dea8 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_event_query.py +++ b/posthog/hogql_queries/insights/funnels/funnel_event_query.py @@ -1,4 +1,4 @@ -from typing import List, Set, Union, Optional +from typing import Union, Optional from posthog.clickhouse.materialized_columns.column import ColumnName from posthog.hogql import ast from posthog.hogql.parser import parse_expr @@ -13,16 +13,16 @@ from rest_framework.exceptions import ValidationError class FunnelEventQuery: context: FunnelQueryContext - _extra_fields: List[ColumnName] - _extra_event_properties: List[PropertyName] + _extra_fields: list[ColumnName] + _extra_event_properties: list[PropertyName] EVENT_TABLE_ALIAS = "e" def __init__( self, context: FunnelQueryContext, - extra_fields: Optional[List[ColumnName]] = None, - extra_event_properties: Optional[List[PropertyName]] = None, + extra_fields: Optional[list[ColumnName]] = None, + extra_event_properties: Optional[list[PropertyName]] = None, ): if extra_event_properties is None: extra_event_properties = [] @@ -38,12 +38,12 @@ class FunnelEventQuery: # entities=None, # TODO: implement passed in entities when needed skip_entity_filter=False, ) -> ast.SelectQuery: - _extra_fields: List[ast.Expr] = [ + _extra_fields: list[ast.Expr] = [ ast.Alias(alias=field, expr=ast.Field(chain=[self.EVENT_TABLE_ALIAS, field])) for field in self._extra_fields ] - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Alias(alias="timestamp", expr=ast.Field(chain=[self.EVENT_TABLE_ALIAS, "timestamp"])), ast.Alias(alias="aggregation_target", expr=self._aggregation_target_expr()), *_extra_fields, @@ -132,7 +132,7 @@ class FunnelEventQuery: if skip_entity_filter is True: return None - events: Set[Union[int, str, None]] = set() + events: set[Union[int, str, None]] = set() for node in [*query.series, *exclusions]: if isinstance(node, EventsNode) or isinstance(node, FunnelExclusionEventsNode): @@ -157,5 +157,5 @@ class FunnelEventQuery: op=ast.CompareOperationOp.In, ) - def _properties_expr(self) -> List[ast.Expr]: + def _properties_expr(self) -> list[ast.Expr]: return Properties(context=self.context).to_exprs() diff --git a/posthog/hogql_queries/insights/funnels/funnel_persons.py b/posthog/hogql_queries/insights/funnels/funnel_persons.py index 68781c6bbd0..5fc06a07a7d 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_persons.py +++ b/posthog/hogql_queries/insights/funnels/funnel_persons.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog.hogql import ast from posthog.hogql_queries.insights.funnels.funnel import Funnel @@ -7,9 +7,9 @@ from posthog.hogql_queries.insights.funnels.funnel import Funnel class FunnelActors(Funnel): def actor_query( self, - extra_fields: Optional[List[str]] = None, + extra_fields: Optional[list[str]] = None, ) -> ast.SelectQuery: - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Alias(alias="actor_id", expr=ast.Field(chain=["aggregation_target"])), *self._get_funnel_person_step_events(), *self._get_timestamp_outer_select(), diff --git a/posthog/hogql_queries/insights/funnels/funnel_query_context.py b/posthog/hogql_queries/insights/funnels/funnel_query_context.py index 3b777e3ff80..499dc3eb9ed 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_query_context.py +++ 
b/posthog/hogql_queries/insights/funnels/funnel_query_context.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Optional, Union from posthog.hogql.constants import LimitContext from posthog.hogql.timings import HogQLTimings from posthog.hogql_queries.insights.query_context import QueryContext @@ -25,7 +25,7 @@ class FunnelQueryContext(QueryContext): interval: IntervalType - breakdown: List[Union[str, int]] | str | int | None + breakdown: list[Union[str, int]] | str | int | None breakdownType: BreakdownType breakdownAttributionType: BreakdownAttributionType @@ -36,7 +36,7 @@ class FunnelQueryContext(QueryContext): includeTimestamp: Optional[bool] includePrecedingTimestamp: Optional[bool] - includeProperties: List[str] + includeProperties: list[str] includeFinalMatchingEvents: Optional[bool] def __init__( @@ -48,7 +48,7 @@ class FunnelQueryContext(QueryContext): limit_context: Optional[LimitContext] = None, include_timestamp: Optional[bool] = None, include_preceding_timestamp: Optional[bool] = None, - include_properties: Optional[List[str]] = None, + include_properties: Optional[list[str]] = None, include_final_matching_events: Optional[bool] = None, ): super().__init__(query=query, team=team, timings=timings, modifiers=modifiers, limit_context=limit_context) @@ -98,7 +98,7 @@ class FunnelQueryContext(QueryContext): "hogql", None, ]: - boxed_breakdown: List[Union[str, int]] = box_value(self.breakdownFilter.breakdown) + boxed_breakdown: list[Union[str, int]] = box_value(self.breakdownFilter.breakdown) self.breakdown = boxed_breakdown else: self.breakdown = self.breakdownFilter.breakdown # type: ignore diff --git a/posthog/hogql_queries/insights/funnels/funnel_strict.py b/posthog/hogql_queries/insights/funnels/funnel_strict.py index 1bea66772a6..1b5bf73ad50 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_strict.py +++ b/posthog/hogql_queries/insights/funnels/funnel_strict.py @@ -1,5 +1,3 @@ -from typing import List - from posthog.hogql import ast from posthog.hogql.parser import parse_expr from posthog.hogql_queries.insights.funnels.base import FunnelBase @@ -11,7 +9,7 @@ class FunnelStrict(FunnelBase): breakdown_exprs = self._get_breakdown_prop_expr() - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ *self._get_count_columns(max_steps), *self._get_step_time_avgs(max_steps), *self._get_step_time_median(max_steps), @@ -30,13 +28,13 @@ class FunnelStrict(FunnelBase): inner_timestamps, outer_timestamps = self._get_timestamp_selects() person_and_group_properties = self._get_person_and_group_properties(aggregate=True) - group_by_columns: List[ast.Expr] = [ + group_by_columns: list[ast.Expr] = [ ast.Field(chain=["aggregation_target"]), ast.Field(chain=["steps"]), *breakdown_exprs, ] - outer_select: List[ast.Expr] = [ + outer_select: list[ast.Expr] = [ *group_by_columns, *self._get_step_time_avgs(max_steps, inner_query=True), *self._get_step_time_median(max_steps, inner_query=True), @@ -50,7 +48,7 @@ class FunnelStrict(FunnelBase): f"max(steps) over (PARTITION BY aggregation_target {self._get_breakdown_prop()}) as max_steps" ) - inner_select: List[ast.Expr] = [ + inner_select: list[ast.Expr] = [ *group_by_columns, max_steps_expr, *self._get_step_time_names(max_steps), @@ -77,7 +75,7 @@ class FunnelStrict(FunnelBase): def get_step_counts_without_aggregation_query(self): max_steps = self.context.max_steps - select_inner: List[ast.Expr] = [ + select_inner: list[ast.Expr] = [ ast.Field(chain=["aggregation_target"]), ast.Field(chain=["timestamp"]), 
*self._get_partition_cols(1, max_steps), @@ -87,7 +85,7 @@ class FunnelStrict(FunnelBase): select_from_inner = self._get_inner_event_query(skip_entity_filter=True, skip_step_filter=True) inner_query = ast.SelectQuery(select=select_inner, select_from=ast.JoinExpr(table=select_from_inner)) - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Field(chain=["*"]), ast.Alias(alias="steps", expr=self._get_sorting_condition(max_steps, max_steps)), *self._get_step_times(max_steps), @@ -101,7 +99,7 @@ class FunnelStrict(FunnelBase): return ast.SelectQuery(select=select, select_from=select_from, where=where) def _get_partition_cols(self, level_index: int, max_steps: int): - exprs: List[ast.Expr] = [] + exprs: list[ast.Expr] = [] for i in range(0, max_steps): exprs.append(ast.Field(chain=[f"step_{i}"])) diff --git a/posthog/hogql_queries/insights/funnels/funnel_strict_persons.py b/posthog/hogql_queries/insights/funnels/funnel_strict_persons.py index f55afbd2182..299bd982b97 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_strict_persons.py +++ b/posthog/hogql_queries/insights/funnels/funnel_strict_persons.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog.hogql import ast from posthog.hogql_queries.insights.funnels.funnel_strict import FunnelStrict @@ -7,9 +7,9 @@ from posthog.hogql_queries.insights.funnels.funnel_strict import FunnelStrict class FunnelStrictActors(FunnelStrict): def actor_query( self, - extra_fields: Optional[List[str]] = None, + extra_fields: Optional[list[str]] = None, ) -> ast.SelectQuery: - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Alias(alias="actor_id", expr=ast.Field(chain=["aggregation_target"])), *self._get_funnel_person_step_events(), *self._get_timestamp_outer_select(), diff --git a/posthog/hogql_queries/insights/funnels/funnel_trends.py b/posthog/hogql_queries/insights/funnels/funnel_trends.py index 9d486f1b061..964f5d05cc6 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_trends.py +++ b/posthog/hogql_queries/insights/funnels/funnel_trends.py @@ -1,6 +1,6 @@ from datetime import datetime from itertools import groupby -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from posthog.hogql import ast from posthog.hogql.parser import parse_expr from posthog.hogql_queries.insights.funnels.base import FunnelBase @@ -58,7 +58,7 @@ class FunnelTrends(FunnelBase): self.just_summarize = just_summarize self.funnel_order = get_funnel_order_class(self.context.funnelsFilter)(context=self.context) - def _format_results(self, results) -> List[Dict[str, Any]]: + def _format_results(self, results) -> list[dict[str, Any]]: query = self.context.query breakdown_clause = self._get_breakdown_prop() @@ -75,7 +75,7 @@ class FunnelTrends(FunnelBase): if breakdown_clause: if isinstance(period_row[-1], str) or ( - isinstance(period_row[-1], List) and all(isinstance(item, str) for item in period_row[-1]) + isinstance(period_row[-1], list) and all(isinstance(item, str) for item in period_row[-1]) ): serialized_result.update({"breakdown_value": (period_row[-1])}) else: @@ -145,7 +145,7 @@ class FunnelTrends(FunnelBase): breakdown_clause = self._get_breakdown_prop_expr() - data_select: List[ast.Expr] = [ + data_select: list[ast.Expr] = [ ast.Field(chain=["entrance_period_start"]), parse_expr(f"countIf({reached_from_step_count_condition}) AS reached_from_step_count"), parse_expr(f"countIf({reached_to_step_count_condition}) AS reached_to_step_count"), @@ -163,10 +163,10 @@ class 
FunnelTrends(FunnelBase): args=[ast.Call(name="toDateTime", args=[(ast.Constant(value=formatted_date_to))])], ) data_select_from = ast.JoinExpr(table=step_counts) - data_group_by: List[ast.Expr] = [ast.Field(chain=["entrance_period_start"]), *breakdown_clause] + data_group_by: list[ast.Expr] = [ast.Field(chain=["entrance_period_start"]), *breakdown_clause] data_query = ast.SelectQuery(select=data_select, select_from=data_select_from, group_by=data_group_by) - fill_select: List[ast.Expr] = [ + fill_select: list[ast.Expr] = [ ast.Alias( alias="entrance_period_start", expr=ast.ArithmeticOperation( @@ -249,7 +249,7 @@ class FunnelTrends(FunnelBase): ), ) - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Field(chain=["fill", "entrance_period_start"]), ast.Field(chain=["reached_from_step_count"]), ast.Field(chain=["reached_to_step_count"]), @@ -263,7 +263,7 @@ class FunnelTrends(FunnelBase): alias="data", next_join=fill_join, ) - order_by: List[ast.OrderExpr] = [ + order_by: list[ast.OrderExpr] = [ ast.OrderExpr(expr=ast.Field(chain=["fill", "entrance_period_start"]), order="ASC") ] @@ -281,7 +281,7 @@ class FunnelTrends(FunnelBase): steps_per_person_query = self.funnel_order.get_step_counts_without_aggregation_query() - event_select_clause: List[ast.Expr] = [] + event_select_clause: list[ast.Expr] = [] if ( hasattr(self.context, "actorsQuery") and self.context.actorsQuery is not None @@ -291,7 +291,7 @@ class FunnelTrends(FunnelBase): breakdown_clause = self._get_breakdown_prop_expr() - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Field(chain=["aggregation_target"]), ast.Alias(alias="entrance_period_start", expr=get_start_of_interval_hogql(interval.value, team=team)), parse_expr("max(steps) AS steps_completed"), @@ -309,7 +309,7 @@ class FunnelTrends(FunnelBase): if specific_entrance_period_start else None ) - group_by: List[ast.Expr] = [ + group_by: list[ast.Expr] = [ ast.Field(chain=["aggregation_target"]), ast.Field(chain=["entrance_period_start"]), *breakdown_clause, @@ -317,7 +317,7 @@ class FunnelTrends(FunnelBase): return ast.SelectQuery(select=select, select_from=select_from, where=where, group_by=group_by) - def get_steps_reached_conditions(self) -> Tuple[str, str, str]: + def get_steps_reached_conditions(self) -> tuple[str, str, str]: funnelsFilter, max_steps = self.context.funnelsFilter, self.context.max_steps # How many steps must have been done to count for the denominator of a funnel trends data point diff --git a/posthog/hogql_queries/insights/funnels/funnel_trends_persons.py b/posthog/hogql_queries/insights/funnels/funnel_trends_persons.py index c90a9ed5762..c124265ba65 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_trends_persons.py +++ b/posthog/hogql_queries/insights/funnels/funnel_trends_persons.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import List from rest_framework.exceptions import ValidationError @@ -39,7 +38,7 @@ class FunnelTrendsActors(FunnelTrends): self.dropOff = actorsQuery.funnelTrendsDropOff self.entrancePeriodStart = entrancePeriodStart - def _get_funnel_person_step_events(self) -> List[ast.Expr]: + def _get_funnel_person_step_events(self) -> list[ast.Expr]: if ( hasattr(self.context, "actorsQuery") and self.context.actorsQuery is not None @@ -71,7 +70,7 @@ class FunnelTrendsActors(FunnelTrends): did_not_reach_to_step_count_condition, ) = self.get_steps_reached_conditions() - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Alias(alias="actor_id", 
expr=ast.Field(chain=["aggregation_target"])), *self._get_funnel_person_step_events(), ] diff --git a/posthog/hogql_queries/insights/funnels/funnel_unordered.py b/posthog/hogql_queries/insights/funnels/funnel_unordered.py index af3ed18d4f8..4ac87866d7f 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_unordered.py +++ b/posthog/hogql_queries/insights/funnels/funnel_unordered.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Optional import uuid from rest_framework.exceptions import ValidationError @@ -45,7 +45,7 @@ class FunnelUnordered(FunnelBase): breakdown_exprs = self._get_breakdown_prop_expr() - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ *self._get_count_columns(max_steps), *self._get_step_time_avgs(max_steps), *self._get_step_time_median(max_steps), @@ -64,13 +64,13 @@ class FunnelUnordered(FunnelBase): inner_timestamps, outer_timestamps = self._get_timestamp_selects() person_and_group_properties = self._get_person_and_group_properties(aggregate=True) - group_by_columns: List[ast.Expr] = [ + group_by_columns: list[ast.Expr] = [ ast.Field(chain=["aggregation_target"]), ast.Field(chain=["steps"]), *breakdown_exprs, ] - outer_select: List[ast.Expr] = [ + outer_select: list[ast.Expr] = [ *group_by_columns, *self._get_step_time_avgs(max_steps, inner_query=True), *self._get_step_time_median(max_steps, inner_query=True), @@ -82,7 +82,7 @@ class FunnelUnordered(FunnelBase): f"max(steps) over (PARTITION BY aggregation_target {self._get_breakdown_prop()}) as max_steps" ) - inner_select: List[ast.Expr] = [ + inner_select: list[ast.Expr] = [ *group_by_columns, max_steps_expr, *self._get_step_time_names(max_steps), @@ -106,7 +106,7 @@ class FunnelUnordered(FunnelBase): def get_step_counts_without_aggregation_query(self): max_steps = self.context.max_steps - union_queries: List[ast.SelectQuery] = [] + union_queries: list[ast.SelectQuery] = [] entities_to_use = list(self.context.query.series) for i in range(max_steps): @@ -153,11 +153,11 @@ class FunnelUnordered(FunnelBase): return ast.SelectUnionQuery(select_queries=union_queries) - def _get_step_times(self, max_steps: int) -> List[ast.Expr]: + def _get_step_times(self, max_steps: int) -> list[ast.Expr]: windowInterval = self.context.funnelWindowInterval windowIntervalUnit = funnel_window_interval_unit_to_sql(self.context.funnelWindowIntervalUnit) - exprs: List[ast.Expr] = [] + exprs: list[ast.Expr] = [] conversion_times_elements = [] for i in range(max_steps): @@ -175,7 +175,7 @@ class FunnelUnordered(FunnelBase): return exprs - def get_sorting_condition(self, max_steps: int) -> List[ast.Expr]: + def get_sorting_condition(self, max_steps: int) -> list[ast.Expr]: windowInterval = self.context.funnelWindowInterval windowIntervalUnit = funnel_window_interval_unit_to_sql(self.context.funnelWindowIntervalUnit) @@ -187,7 +187,7 @@ class FunnelUnordered(FunnelBase): conditions.append(parse_expr(f"arraySort([{','.join(event_times_elements)}]) as event_times")) # replacement of latest_i for whatever query part requires it, just like conversion_times - basic_conditions: List[str] = [] + basic_conditions: list[str] = [] for i in range(1, max_steps): basic_conditions.append( f"if(latest_0 < latest_{i} AND latest_{i} <= toTimeZone(latest_0, 'UTC') + INTERVAL {windowInterval} {windowIntervalUnit}, 1, 0)" @@ -199,7 +199,7 @@ class FunnelUnordered(FunnelBase): else: return [ast.Alias(alias="steps", expr=ast.Constant(value=1))] - def _get_exclusion_condition(self) -> List[ast.Expr]: + def 
_get_exclusion_condition(self) -> list[ast.Expr]: funnelsFilter = self.context.funnelsFilter windowInterval = self.context.funnelWindowInterval windowIntervalUnit = funnel_window_interval_unit_to_sql(self.context.funnelWindowIntervalUnit) @@ -207,7 +207,7 @@ class FunnelUnordered(FunnelBase): if not funnelsFilter.exclusions: return [] - conditions: List[ast.Expr] = [] + conditions: list[ast.Expr] = [] for exclusion_id, exclusion in enumerate(funnelsFilter.exclusions): from_time = f"latest_{exclusion.funnelFromStep}" @@ -233,9 +233,9 @@ class FunnelUnordered(FunnelBase): step: ActionsNode | EventsNode | DataWarehouseNode, count: int, index: int, - people: Optional[List[uuid.UUID]] = None, + people: Optional[list[uuid.UUID]] = None, sampling_factor: Optional[float] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: if isinstance(step, DataWarehouseNode): raise NotImplementedError("Data Warehouse queries are not supported in funnels") diff --git a/posthog/hogql_queries/insights/funnels/funnel_unordered_persons.py b/posthog/hogql_queries/insights/funnels/funnel_unordered_persons.py index a378f044b5d..ad1086bdc33 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_unordered_persons.py +++ b/posthog/hogql_queries/insights/funnels/funnel_unordered_persons.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog.hogql import ast from posthog.hogql.parser import parse_expr @@ -6,7 +6,7 @@ from posthog.hogql_queries.insights.funnels.funnel_unordered import FunnelUnorde class FunnelUnorderedActors(FunnelUnordered): - def _get_funnel_person_step_events(self) -> List[ast.Expr]: + def _get_funnel_person_step_events(self) -> list[ast.Expr]: # Unordered funnels do not support matching events (and thereby recordings), # but it simplifies the logic if we return an empty array for matching events if ( @@ -19,9 +19,9 @@ class FunnelUnorderedActors(FunnelUnordered): def actor_query( self, - extra_fields: Optional[List[str]] = None, + extra_fields: Optional[list[str]] = None, ) -> ast.SelectQuery: - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Alias(alias="actor_id", expr=ast.Field(chain=["aggregation_target"])), *self._get_funnel_person_step_events(), *self._get_timestamp_outer_select(), diff --git a/posthog/hogql_queries/insights/funnels/funnels_query_runner.py b/posthog/hogql_queries/insights/funnels/funnels_query_runner.py index d2ec04e3e84..3e1173b2760 100644 --- a/posthog/hogql_queries/insights/funnels/funnels_query_runner.py +++ b/posthog/hogql_queries/insights/funnels/funnels_query_runner.py @@ -1,6 +1,6 @@ from datetime import timedelta from math import ceil -from typing import Optional, Any, Dict +from typing import Optional, Any from django.utils.timezone import datetime from posthog.caching.insights_api import ( @@ -37,7 +37,7 @@ class FunnelsQueryRunner(QueryRunner): def __init__( self, - query: FunnelsQuery | Dict[str, Any], + query: FunnelsQuery | dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None, modifiers: Optional[HogQLQueryModifiers] = None, diff --git a/posthog/hogql_queries/insights/funnels/test/breakdown_cases.py b/posthog/hogql_queries/insights/funnels/test/breakdown_cases.py index db5e882963e..2b1b08f4445 100644 --- a/posthog/hogql_queries/insights/funnels/test/breakdown_cases.py +++ b/posthog/hogql_queries/insights/funnels/test/breakdown_cases.py @@ -2,7 +2,8 @@ from dataclasses import dataclass from datetime import datetime from string import ascii_lowercase -from typing import Any, Callable, Dict, List, 
Literal, Optional, Union, cast +from typing import Any, Literal, Optional, Union, cast +from collections.abc import Callable from posthog.constants import INSIGHT_FUNNELS, FunnelOrderType from posthog.hogql_queries.insights.funnels.funnels_query_runner import FunnelsQueryRunner @@ -30,7 +31,7 @@ from posthog.test.test_journeys import journeys_for class FunnelStepResult: name: str count: int - breakdown: Union[List[str], str] + breakdown: Union[list[str], str] average_conversion_time: Optional[float] = None median_conversion_time: Optional[float] = None type: Literal["events", "actions"] = "events" @@ -51,8 +52,8 @@ def funnel_breakdown_test_factory( return [val["id"] for val in serialized_result] - def _assert_funnel_breakdown_result_is_correct(self, result, steps: List[FunnelStepResult]): - def funnel_result(step: FunnelStepResult, order: int) -> Dict[str, Any]: + def _assert_funnel_breakdown_result_is_correct(self, result, steps: list[FunnelStepResult]): + def funnel_result(step: FunnelStepResult, order: int) -> dict[str, Any]: return { "action_id": step.name if step.type == "events" else step.action_id, "name": step.name, @@ -2695,8 +2696,8 @@ def funnel_breakdown_group_test_factory(FunnelPerson): properties={"industry": "random"}, ) - def _assert_funnel_breakdown_result_is_correct(self, result, steps: List[FunnelStepResult]): - def funnel_result(step: FunnelStepResult, order: int) -> Dict[str, Any]: + def _assert_funnel_breakdown_result_is_correct(self, result, steps: list[FunnelStepResult]): + def funnel_result(step: FunnelStepResult, order: int) -> dict[str, Any]: return { "action_id": step.name if step.type == "events" else step.action_id, "name": step.name, @@ -3067,11 +3068,11 @@ def funnel_breakdown_group_test_factory(FunnelPerson): return TestFunnelBreakdownGroup -def sort_breakdown_funnel_results(results: List[Dict[int, Any]]): +def sort_breakdown_funnel_results(results: list[dict[int, Any]]): return sorted(results, key=lambda r: r[0]["breakdown_value"]) -def assert_funnel_results_equal(left: List[Dict[str, Any]], right: List[Dict[str, Any]]): +def assert_funnel_results_equal(left: list[dict[str, Any]], right: list[dict[str, Any]]): """ Helper to be able to compare two funnel results, but exclude people urls from the comparison, as these include: @@ -3081,7 +3082,7 @@ def assert_funnel_results_equal(left: List[Dict[str, Any]], right: List[Dict[str 2. 
contain timestamps which are not stable across runs """ - def _filter(steps: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _filter(steps: list[dict[str, Any]]) -> list[dict[str, Any]]: return [{**step, "converted_people_url": None, "dropped_people_url": None} for step in steps] assert len(left) == len(right) diff --git a/posthog/hogql_queries/insights/funnels/test/test_funnel_breakdowns_by_current_url.py b/posthog/hogql_queries/insights/funnels/test/test_funnel_breakdowns_by_current_url.py index 859f3e627aa..aef262ba22e 100644 --- a/posthog/hogql_queries/insights/funnels/test/test_funnel_breakdowns_by_current_url.py +++ b/posthog/hogql_queries/insights/funnels/test/test_funnel_breakdowns_by_current_url.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Dict, cast, Optional +from typing import cast, Optional from posthog.hogql_queries.insights.funnels.funnels_query_runner import FunnelsQueryRunner from posthog.hogql_queries.legacy_compatibility.filter_to_query import filter_to_query @@ -116,7 +116,7 @@ class TestFunnelBreakdownsByCurrentURL(ClickhouseTestMixin, APIBaseTest): journeys_for(journey, team=self.team, create_people=True) - def _run(self, extra: Optional[Dict] = None, events_extra: Optional[Dict] = None): + def _run(self, extra: Optional[dict] = None, events_extra: Optional[dict] = None): if events_extra is None: events_extra = {} if extra is None: diff --git a/posthog/hogql_queries/insights/funnels/test/test_funnel_correlation.py b/posthog/hogql_queries/insights/funnels/test/test_funnel_correlation.py index f69eb3c6977..4db744a6d92 100644 --- a/posthog/hogql_queries/insights/funnels/test/test_funnel_correlation.py +++ b/posthog/hogql_queries/insights/funnels/test/test_funnel_correlation.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, cast +from typing import Any, cast import unittest from rest_framework.exceptions import ValidationError @@ -77,7 +77,7 @@ class TestClickhouseFunnelCorrelation(ClickhouseTestMixin, APIBaseTest): result, skewed_totals, _, _ = FunnelCorrelationQueryRunner(query=correlation_query, team=self.team)._calculate() return result, skewed_totals - def _get_actors_for_event(self, filters: Dict[str, Any], event_name: str, properties=None, success=True): + def _get_actors_for_event(self, filters: dict[str, Any], event_name: str, properties=None, success=True): serialized_actors = get_actors( filters, self.team, @@ -87,7 +87,7 @@ class TestClickhouseFunnelCorrelation(ClickhouseTestMixin, APIBaseTest): return [str(row[0]) for row in serialized_actors] def _get_actors_for_property( - self, filters: Dict[str, Any], property_values: list, success=True, funnelCorrelationNames=None + self, filters: dict[str, Any], property_values: list, success=True, funnelCorrelationNames=None ): funnelCorrelationPropertyValues = [ ( diff --git a/posthog/hogql_queries/insights/funnels/test/test_funnel_correlations_persons.py b/posthog/hogql_queries/insights/funnels/test/test_funnel_correlations_persons.py index f324dcfcf7c..223b24a949b 100644 --- a/posthog/hogql_queries/insights/funnels/test/test_funnel_correlations_persons.py +++ b/posthog/hogql_queries/insights/funnels/test/test_funnel_correlations_persons.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, cast +from typing import Any, Optional, cast from datetime import datetime, timedelta from uuid import UUID @@ -37,7 +37,7 @@ PERSON_ID_COLUMN = 2 def get_actors( - filters: Dict[str, Any], + filters: dict[str, Any], team: Team, funnelCorrelationType: Optional[FunnelCorrelationResultsType] = 
FunnelCorrelationResultsType.events, funnelCorrelationNames=None, diff --git a/posthog/hogql_queries/insights/funnels/test/test_funnel_persons.py b/posthog/hogql_queries/insights/funnels/test/test_funnel_persons.py index dec7bdd933b..37d9b853404 100644 --- a/posthog/hogql_queries/insights/funnels/test/test_funnel_persons.py +++ b/posthog/hogql_queries/insights/funnels/test/test_funnel_persons.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import Dict, List, Optional, cast, Any +from typing import Optional, cast, Any from uuid import UUID from django.utils import timezone @@ -32,11 +32,11 @@ PERSON_ID_COLUMN = 2 def get_actors( - filters: Dict[str, Any], + filters: dict[str, Any], team: Team, funnelStep: Optional[int] = None, - funnelCustomSteps: Optional[List[int]] = None, - funnelStepBreakdown: Optional[str | float | List[str | float]] = None, + funnelCustomSteps: Optional[list[int]] = None, + funnelStepBreakdown: Optional[str | float | list[str | float]] = None, funnelTrendsDropOff: Optional[bool] = None, funnelTrendsEntrancePeriodStart: Optional[str] = None, offset: Optional[int] = None, diff --git a/posthog/hogql_queries/insights/funnels/utils.py b/posthog/hogql_queries/insights/funnels/utils.py index 95374f179e1..7aea066883e 100644 --- a/posthog/hogql_queries/insights/funnels/utils.py +++ b/posthog/hogql_queries/insights/funnels/utils.py @@ -1,4 +1,3 @@ -from typing import List from posthog.constants import FUNNEL_WINDOW_INTERVAL_TYPES from posthog.hogql import ast from posthog.hogql.parser import parse_expr @@ -61,7 +60,7 @@ def funnel_window_interval_unit_to_sql( def get_breakdown_expr( - breakdowns: List[str | int] | str | int, properties_column: str, normalize_url: bool | None = False + breakdowns: list[str | int] | str | int, properties_column: str, normalize_url: bool | None = False ) -> ast.Expr: if isinstance(breakdowns, str) or isinstance(breakdowns, int) or breakdowns is None: return ast.Call( diff --git a/posthog/hogql_queries/insights/lifecycle_query_runner.py b/posthog/hogql_queries/insights/lifecycle_query_runner.py index 42b35d6b4df..5e11dcdcae0 100644 --- a/posthog/hogql_queries/insights/lifecycle_query_runner.py +++ b/posthog/hogql_queries/insights/lifecycle_query_runner.py @@ -1,6 +1,6 @@ from datetime import timedelta from math import ceil -from typing import Optional, List +from typing import Optional from django.utils.timezone import datetime from posthog.caching.insights_api import ( @@ -225,7 +225,7 @@ class LifecycleQueryRunner(QueryRunner): @cached_property def event_filter(self) -> ast.Expr: - event_filters: List[ast.Expr] = [] + event_filters: list[ast.Expr] = [] with self.timings.measure("date_range"): event_filters.append( parse_expr( diff --git a/posthog/hogql_queries/insights/paths_query_runner.py b/posthog/hogql_queries/insights/paths_query_runner.py index ca7890735f8..8c2bc84d821 100644 --- a/posthog/hogql_queries/insights/paths_query_runner.py +++ b/posthog/hogql_queries/insights/paths_query_runner.py @@ -3,7 +3,7 @@ from collections import defaultdict from datetime import datetime, timedelta from math import ceil from re import escape -from typing import Any, Dict, Literal, cast +from typing import Any, Literal, cast from typing import Optional from posthog.caching.insights_api import BASE_MINIMUM_INSIGHT_REFRESH_INTERVAL, REDUCED_MINIMUM_INSIGHT_REFRESH_INTERVAL @@ -47,7 +47,7 @@ class PathsQueryRunner(QueryRunner): def __init__( self, - query: PathsQuery | Dict[str, Any], + query: PathsQuery | dict[str, Any], team: Team, 
timings: Optional[HogQLTimings] = None, modifiers: Optional[HogQLQueryModifiers] = None, diff --git a/posthog/hogql_queries/insights/retention_query_runner.py b/posthog/hogql_queries/insights/retention_query_runner.py index ac15ded6728..f79af288ca6 100644 --- a/posthog/hogql_queries/insights/retention_query_runner.py +++ b/posthog/hogql_queries/insights/retention_query_runner.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta from math import ceil -from typing import Any, Dict +from typing import Any from typing import Optional from posthog.caching.insights_api import BASE_MINIMUM_INSIGHT_REFRESH_INTERVAL, REDUCED_MINIMUM_INSIGHT_REFRESH_INTERVAL @@ -39,7 +39,7 @@ class RetentionQueryRunner(QueryRunner): def __init__( self, - query: RetentionQuery | Dict[str, Any], + query: RetentionQuery | dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None, modifiers: Optional[HogQLQueryModifiers] = None, diff --git a/posthog/hogql_queries/insights/stickiness_query_runner.py b/posthog/hogql_queries/insights/stickiness_query_runner.py index d9096f05853..24bb2504de6 100644 --- a/posthog/hogql_queries/insights/stickiness_query_runner.py +++ b/posthog/hogql_queries/insights/stickiness_query_runner.py @@ -1,6 +1,6 @@ from datetime import timedelta from math import ceil -from typing import List, Optional, Any, Dict, cast +from typing import Optional, Any, cast from django.utils.timezone import datetime from posthog.caching.insights_api import ( @@ -47,11 +47,11 @@ class SeriesWithExtras: class StickinessQueryRunner(QueryRunner): query: StickinessQuery query_type = StickinessQuery - series: List[SeriesWithExtras] + series: list[SeriesWithExtras] def __init__( self, - query: StickinessQuery | Dict[str, Any], + query: StickinessQuery | dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None, modifiers: Optional[HogQLQueryModifiers] = None, @@ -134,7 +134,7 @@ class StickinessQueryRunner(QueryRunner): def to_query(self) -> ast.SelectUnionQuery: return ast.SelectUnionQuery(select_queries=self.to_queries()) - def to_queries(self) -> List[ast.SelectQuery]: + def to_queries(self) -> list[ast.SelectQuery]: queries = [] for series in self.series: @@ -174,7 +174,7 @@ class StickinessQueryRunner(QueryRunner): return queries def to_actors_query(self, interval_num: Optional[int] = None) -> ast.SelectQuery | ast.SelectUnionQuery: - queries: List[ast.SelectQuery] = [] + queries: list[ast.SelectQuery] = [] for series in self.series: events_query = self._events_query(series) @@ -253,7 +253,7 @@ class StickinessQueryRunner(QueryRunner): def where_clause(self, series_with_extra: SeriesWithExtras) -> ast.Expr: date_range = self.date_range(series_with_extra) series = series_with_extra.series - filters: List[ast.Expr] = [] + filters: list[ast.Expr] = [] # Dates filters.extend( @@ -344,7 +344,7 @@ class StickinessQueryRunner(QueryRunner): else: return delta.days - def setup_series(self) -> List[SeriesWithExtras]: + def setup_series(self) -> list[SeriesWithExtras]: series_with_extras = [ SeriesWithExtras( series, diff --git a/posthog/hogql_queries/insights/test/test_insight_actors_query_runner.py b/posthog/hogql_queries/insights/test/test_insight_actors_query_runner.py index bb963cf1f8b..830ecc3982b 100644 --- a/posthog/hogql_queries/insights/test/test_insight_actors_query_runner.py +++ b/posthog/hogql_queries/insights/test/test_insight_actors_query_runner.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Optional +from typing import Any, Optional from freezegun import freeze_time @@ -69,7 +69,7 @@ 
class TestInsightActorsQueryRunner(ClickhouseTestMixin, APIBaseTest): ] ) - def select(self, query: str, placeholders: Optional[Dict[str, Any]] = None): + def select(self, query: str, placeholders: Optional[dict[str, Any]] = None): if placeholders is None: placeholders = {} return execute_hogql_query( diff --git a/posthog/hogql_queries/insights/test/test_paths_query_runner.py b/posthog/hogql_queries/insights/test/test_paths_query_runner.py index b74102ba705..0b82f33ca7e 100644 --- a/posthog/hogql_queries/insights/test/test_paths_query_runner.py +++ b/posthog/hogql_queries/insights/test/test_paths_query_runner.py @@ -1,5 +1,4 @@ import dataclasses -from typing import Dict from dateutil.relativedelta import relativedelta from django.utils.timezone import now @@ -25,7 +24,7 @@ class MockEvent: distinct_id: str team: Team timestamp: str - properties: Dict + properties: dict class TestPaths(ClickhouseTestMixin, APIBaseTest): diff --git a/posthog/hogql_queries/insights/test/test_stickiness_query_runner.py b/posthog/hogql_queries/insights/test/test_stickiness_query_runner.py index 6e25827e6ec..e61f4160276 100644 --- a/posthog/hogql_queries/insights/test/test_stickiness_query_runner.py +++ b/posthog/hogql_queries/insights/test/test_stickiness_query_runner.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Optional, Union from unittest.mock import MagicMock, patch from django.test import override_settings @@ -41,18 +41,18 @@ from posthog.test.base import APIBaseTest, _create_event, _create_person @dataclass class Series: event: str - timestamps: List[str] + timestamps: list[str] @dataclass class SeriesTestData: distinct_id: str - events: List[Series] - properties: Dict[str, str | int] + events: list[Series] + properties: dict[str, str | int] StickinessProperties = Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -74,9 +74,9 @@ class TestStickinessQueryRunner(APIBaseTest): default_date_from = "2020-01-11" default_date_to = "2020-01-20" - def _create_events(self, data: List[SeriesTestData]): + def _create_events(self, data: list[SeriesTestData]): person_result = [] - properties_to_create: Dict[str, str] = {} + properties_to_create: dict[str, str] = {} for person in data: first_timestamp = person.events[0].timestamps[0] @@ -194,7 +194,7 @@ class TestStickinessQueryRunner(APIBaseTest): def _run_query( self, - series: Optional[List[EventsNode | ActionsNode]] = None, + series: Optional[list[EventsNode | ActionsNode]] = None, date_from: Optional[str] = None, date_to: Optional[str] = None, interval: Optional[IntervalType] = None, @@ -203,7 +203,7 @@ class TestStickinessQueryRunner(APIBaseTest): filter_test_accounts: Optional[bool] = False, limit_context: Optional[LimitContext] = None, ): - query_series: List[EventsNode | ActionsNode] = [EventsNode(event="$pageview")] if series is None else series + query_series: list[EventsNode | ActionsNode] = [EventsNode(event="$pageview")] if series is None else series query_date_from = date_from or self.default_date_from query_date_to = None if date_to == "now" else date_to or self.default_date_to query_interval = interval or IntervalType.day @@ -223,8 +223,8 @@ class TestStickinessQueryRunner(APIBaseTest): response = self._run_query() assert isinstance(response, StickinessQueryResponse) - assert isinstance(response.results, List) - assert isinstance(response.results[0], Dict) + assert isinstance(response.results, list) + assert isinstance(response.results[0], dict) 
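The assertions above show both sides of the pyupgrade rules (Ruff UP006/UP035) applied throughout this patch: on Python 3.9+ (PEP 585) the builtin containers are subscriptable, so the `typing.Dict`/`typing.List` aliases can be dropped from annotations and from runtime `isinstance` checks alike. A minimal sketch of the pattern, modelled loosely on the `get_actors` helper earlier in the patch; note that the bare aliases still pass `isinstance`, but only the builtins work once subscripted:

    import typing
    from typing import Any, Optional

    # Before (Python < 3.9): from typing import Dict, List
    # def get_actors(filters: Dict[str, Any], steps: Optional[List[int]] = None) -> List[dict]: ...

    def get_actors(filters: dict[str, Any], steps: Optional[list[int]] = None) -> list[dict]:
        return [{"id": k} for k in filters]

    results = get_actors({"a": 1}, steps=[1, 2])
    assert isinstance(results, list)          # preferred builtin spelling
    assert isinstance(results, typing.List)   # still true, but the alias is deprecated since 3.9
    # isinstance(results, typing.List[dict])  # TypeError: subscripted generics
    #                                         # cannot be used with isinstance()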
@override_settings(PERSON_ON_EVENTS_V2_OVERRIDE=True) def test_stickiness_runs_with_poe(self): @@ -232,8 +232,8 @@ class TestStickinessQueryRunner(APIBaseTest): response = self._run_query() assert isinstance(response, StickinessQueryResponse) - assert isinstance(response.results, List) - assert isinstance(response.results[0], Dict) + assert isinstance(response.results, list) + assert isinstance(response.results[0], dict) def test_days(self): self._create_test_events() @@ -423,7 +423,7 @@ class TestStickinessQueryRunner(APIBaseTest): def test_event_filtering(self): self._create_test_events() - series: List[EventsNode | ActionsNode] = [ + series: list[EventsNode | ActionsNode] = [ EventsNode( event="$pageview", properties=[EventPropertyFilter(key="$browser", operator=PropertyOperator.exact, value="Chrome")], @@ -450,7 +450,7 @@ class TestStickinessQueryRunner(APIBaseTest): def test_any_event(self): self._create_test_events() - series: List[EventsNode | ActionsNode] = [ + series: list[EventsNode | ActionsNode] = [ EventsNode( event=None, ) @@ -484,7 +484,7 @@ class TestStickinessQueryRunner(APIBaseTest): properties=[{"key": "$browser", "type": "event", "value": "Chrome", "operator": "exact"}], ) - series: List[EventsNode | ActionsNode] = [ActionsNode(id=action.pk)] + series: list[EventsNode | ActionsNode] = [ActionsNode(id=action.pk)] response = self._run_query(series=series) @@ -541,7 +541,7 @@ class TestStickinessQueryRunner(APIBaseTest): self._create_test_groups() self._create_test_events() - series: List[EventsNode | ActionsNode] = [ + series: list[EventsNode | ActionsNode] = [ EventsNode(event="$pageview", math="unique_group", math_group_type_index=MathGroupTypeIndex.number_0) ] @@ -565,7 +565,7 @@ class TestStickinessQueryRunner(APIBaseTest): def test_hogql_aggregations(self): self._create_test_events() - series: List[EventsNode | ActionsNode] = [ + series: list[EventsNode | ActionsNode] = [ EventsNode(event="$pageview", math="hogql", math_hogql="e.properties.prop") ] diff --git a/posthog/hogql_queries/insights/trends/aggregation_operations.py b/posthog/hogql_queries/insights/trends/aggregation_operations.py index 1c356277548..2e716b2b1ca 100644 --- a/posthog/hogql_queries/insights/trends/aggregation_operations.py +++ b/posthog/hogql_queries/insights/trends/aggregation_operations.py @@ -1,4 +1,4 @@ -from typing import List, Optional, cast, Union +from typing import Optional, cast, Union from posthog.constants import NON_TIME_SERIES_DISPLAY_TYPES from posthog.hogql import ast from posthog.hogql.parser import parse_expr, parse_select @@ -13,8 +13,8 @@ class QueryAlternator: """Allows query_builder to modify the query without having to expost the whole AST interface""" _query: ast.SelectQuery - _selects: List[ast.Expr] - _group_bys: List[ast.Expr] + _selects: list[ast.Expr] + _group_bys: list[ast.Expr] _select_from: ast.JoinExpr | None def __init__(self, query: ast.SelectQuery | ast.SelectUnionQuery): @@ -143,7 +143,7 @@ class AggregationOperations(DataWarehouseInsightQueryMixin): "p99_count_per_actor", ] - def _math_func(self, method: str, override_chain: Optional[List[str | int]]) -> ast.Call: + def _math_func(self, method: str, override_chain: Optional[list[str | int]]) -> ast.Call: if override_chain is not None: return ast.Call(name=method, args=[ast.Field(chain=override_chain)]) @@ -167,7 +167,7 @@ class AggregationOperations(DataWarehouseInsightQueryMixin): return ast.Call(name=method, args=[ast.Field(chain=chain)]) - def _math_quantile(self, percentile: float, override_chain: 
Optional[List[str | int]]) -> ast.Call: + def _math_quantile(self, percentile: float, override_chain: Optional[list[str | int]]) -> ast.Call: if self.series.math_property == "$session_duration": chain = ["session_duration"] else: diff --git a/posthog/hogql_queries/insights/trends/breakdown.py b/posthog/hogql_queries/insights/trends/breakdown.py index e588ca30353..025d181bf81 100644 --- a/posthog/hogql_queries/insights/trends/breakdown.py +++ b/posthog/hogql_queries/insights/trends/breakdown.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Union, cast +from typing import Optional, Union, cast from posthog.hogql import ast from posthog.hogql.constants import LimitContext from posthog.hogql.parser import parse_expr @@ -30,7 +30,7 @@ class Breakdown: timings: HogQLTimings modifiers: HogQLQueryModifiers events_filter: ast.Expr - breakdown_values_override: Optional[List[str]] + breakdown_values_override: Optional[list[str]] limit_context: LimitContext def __init__( @@ -42,7 +42,7 @@ class Breakdown: timings: HogQLTimings, modifiers: HogQLQueryModifiers, events_filter: ast.Expr, - breakdown_values_override: Optional[List[str]] = None, + breakdown_values_override: Optional[list[str]] = None, limit_context: LimitContext = LimitContext.QUERY, ): self.team = team @@ -71,7 +71,7 @@ class Breakdown: def is_histogram_breakdown(self) -> bool: return self.enabled and self.query.breakdownFilter.breakdown_histogram_bin_count is not None - def placeholders(self) -> Dict[str, ast.Expr]: + def placeholders(self) -> dict[str, ast.Expr]: values = self._breakdown_buckets_ast if self.is_histogram_breakdown else self._breakdown_values_ast return {"cross_join_breakdown_values": ast.Alias(alias="breakdown_value", expr=values)} @@ -106,7 +106,7 @@ class Breakdown: if self.query.breakdownFilter.breakdown == "all": return None - if isinstance(self.query.breakdownFilter.breakdown, List): + if isinstance(self.query.breakdownFilter.breakdown, list): or_clause = ast.Or( exprs=[ ast.CompareOperation( @@ -226,10 +226,10 @@ class Breakdown: return ast.Array(exprs=exprs) @cached_property - def _all_breakdown_values(self) -> List[str | int | None]: + def _all_breakdown_values(self) -> list[str | int | None]: # Used in the actors query if self.breakdown_values_override is not None: - return cast(List[str | int | None], self.breakdown_values_override) + return cast(list[str | int | None], self.breakdown_values_override) if self.query.breakdownFilter is None: return [] @@ -245,18 +245,18 @@ class Breakdown: modifiers=self.modifiers, limit_context=self.limit_context, ) - return cast(List[str | int | None], breakdown.get_breakdown_values()) + return cast(list[str | int | None], breakdown.get_breakdown_values()) @cached_property - def _breakdown_values(self) -> List[str | int]: + def _breakdown_values(self) -> list[str | int]: values = [BREAKDOWN_NULL_STRING_LABEL if v is None else v for v in self._all_breakdown_values] - return cast(List[str | int], values) + return cast(list[str | int], values) @cached_property def has_breakdown_values(self) -> bool: return len(self._breakdown_values) > 0 - def _get_breakdown_histogram_buckets(self) -> List[Tuple[float, float]]: + def _get_breakdown_histogram_buckets(self) -> list[tuple[float, float]]: buckets = [] values = self._breakdown_values @@ -275,7 +275,7 @@ class Breakdown: return buckets def _get_breakdown_histogram_multi_if(self) -> ast.Expr: - multi_if_exprs: List[ast.Expr] = [] + multi_if_exprs: list[ast.Expr] = [] buckets = self._get_breakdown_histogram_buckets() 
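Note what the rule does not touch: `Optional[...]` and `Union[...]` spellings are left alone everywhere in this patch (only the container aliases are rewritten), and `cast()` works unchanged with the lowercase forms, including PEP 604 unions inside the subscript. A small self-contained sketch of the casting pattern from the breakdown hunks above, with invented values (the sentinel's value is assumed for the sketch):

    from typing import Optional, Union, cast

    BREAKDOWN_NULL_STRING_LABEL = "$$_posthog_breakdown_null_$$"  # value assumed

    raw: list[str | int | None] = ["Chrome", 1, None]
    # None becomes a sentinel label, then the narrowed type is asserted via cast()
    values = cast(list[str | int], [BREAKDOWN_NULL_STRING_LABEL if v is None else v for v in raw])

    buckets: list[tuple[float, float]] = [(0.0, 10.0), (10.0, 20.0)]
    bin_count: Optional[int] = None
    field: Union[str, float, list[Union[str, float]]] = ["a", 1.0]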
diff --git a/posthog/hogql_queries/insights/trends/breakdown_values.py b/posthog/hogql_queries/insights/trends/breakdown_values.py index 6a9b9a24a22..b15897b360f 100644 --- a/posthog/hogql_queries/insights/trends/breakdown_values.py +++ b/posthog/hogql_queries/insights/trends/breakdown_values.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union, Any +from typing import Optional, Union, Any from posthog.hogql import ast from posthog.hogql.constants import LimitContext, get_breakdown_limit_for_context, BREAKDOWN_VALUES_LIMIT_FOR_COUNTRIES from posthog.hogql.parser import parse_expr, parse_select @@ -30,7 +30,7 @@ BREAKDOWN_NULL_DISPLAY = "None (i.e. no value)" class BreakdownValues: team: Team series: Union[EventsNode, ActionsNode, DataWarehouseNode] - breakdown_field: Union[str, float, List[Union[str, float]]] + breakdown_field: Union[str, float, list[Union[str, float]]] breakdown_type: BreakdownType events_filter: ast.Expr chart_display_type: ChartDisplayType @@ -76,12 +76,12 @@ class BreakdownValues: self.query_date_range = query_date_range self.modifiers = modifiers - def get_breakdown_values(self) -> List[str | int]: + def get_breakdown_values(self) -> list[str | int]: if self.breakdown_type == "cohort": if self.breakdown_field == "all": return [0] - if isinstance(self.breakdown_field, List): + if isinstance(self.breakdown_field, list): return [value if isinstance(value, str) else int(value) for value in self.breakdown_field] return [self.breakdown_field if isinstance(self.breakdown_field, str) else int(self.breakdown_field)] @@ -186,7 +186,7 @@ class BreakdownValues: ): inner_events_query.order_by[0].order = "ASC" - values: List[Any] + values: list[Any] if self.histogram_bin_count is not None: query = parse_select( """ diff --git a/posthog/hogql_queries/insights/trends/test/test_trends.py b/posthog/hogql_queries/insights/trends/test/test_trends.py index 8ba4aea1b34..f34229e99de 100644 --- a/posthog/hogql_queries/insights/trends/test/test_trends.py +++ b/posthog/hogql_queries/insights/trends/test/test_trends.py @@ -1,7 +1,7 @@ import json import uuid from datetime import datetime -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from unittest.mock import patch from zoneinfo import ZoneInfo @@ -68,8 +68,8 @@ from posthog.test.base import ( from posthog.test.test_journeys import journeys_for -def breakdown_label(entity: Entity, value: Union[str, int]) -> Dict[str, Optional[Union[str, int]]]: - ret_dict: Dict[str, Optional[Union[str, int]]] = {} +def breakdown_label(entity: Entity, value: Union[str, int]) -> dict[str, Optional[Union[str, int]]]: + ret_dict: dict[str, Optional[Union[str, int]]] = {} if not value or not isinstance(value, str) or "cohort_" not in value: label = value if (value or isinstance(value, bool)) and value != "None" and value != "nan" else "Other" ret_dict["label"] = f"{entity.name} - {label}" @@ -103,7 +103,7 @@ def _create_cohort(**kwargs): return cohort -def _props(dict: Dict): +def _props(dict: dict): props = dict.get("properties", None) if not props: return None @@ -125,11 +125,11 @@ def _props(dict: Dict): def convert_filter_to_trends_query(filter: Filter) -> TrendsQuery: filter_as_dict = filter.to_dict() - events: List[EventsNode] = [] - actions: List[ActionsNode] = [] + events: list[EventsNode] = [] + actions: list[ActionsNode] = [] for event in filter.events: - if isinstance(event._data.get("properties", None), List): + if isinstance(event._data.get("properties", None), list): properties = 
clean_entity_properties(event._data.get("properties", None)) elif event._data.get("properties", None) is not None: values = event._data.get("properties", None).get("values", None) @@ -151,7 +151,7 @@ def convert_filter_to_trends_query(filter: Filter) -> TrendsQuery: ) for action in filter.actions: - if isinstance(action._data.get("properties", None), List): + if isinstance(action._data.get("properties", None), list): properties = clean_entity_properties(action._data.get("properties", None)) elif action._data.get("properties", None) is not None: values = action._data.get("properties", None).get("values", None) @@ -172,7 +172,7 @@ def convert_filter_to_trends_query(filter: Filter) -> TrendsQuery: ) ) - series: List[Union[EventsNode, ActionsNode, DataWarehouseNode]] = [*events, *actions] + series: list[Union[EventsNode, ActionsNode, DataWarehouseNode]] = [*events, *actions] tq = TrendsQuery( series=series, @@ -304,7 +304,7 @@ class TestTrends(ClickhouseTestMixin, APIBaseTest): type=PropertyDefinition.Type.GROUP, ) - def _create_events(self, use_time=False) -> Tuple[Action, Person]: + def _create_events(self, use_time=False) -> tuple[Action, Person]: person = self._create_person( team_id=self.team.pk, distinct_ids=["blabla", "anonymous_id"], @@ -2080,7 +2080,7 @@ class TestTrends(ClickhouseTestMixin, APIBaseTest): ], ) - def _test_events_with_dates(self, dates: List[str], result, query_time=None, **filter_params): + def _test_events_with_dates(self, dates: list[str], result, query_time=None, **filter_params): self._create_person(team_id=self.team.pk, distinct_ids=["person_1"], properties={"name": "John"}) for time in dates: with freeze_time(time): diff --git a/posthog/hogql_queries/insights/trends/test/test_trends_query_runner.py b/posthog/hogql_queries/insights/trends/test/test_trends_query_runner.py index 573bbf2c12e..772d7192272 100644 --- a/posthog/hogql_queries/insights/trends/test/test_trends_query_runner.py +++ b/posthog/hogql_queries/insights/trends/test/test_trends_query_runner.py @@ -1,7 +1,7 @@ import zoneinfo from dataclasses import dataclass from datetime import datetime -from typing import Dict, List, Optional +from typing import Optional from unittest.mock import MagicMock, patch from django.test import override_settings from freezegun import freeze_time @@ -49,14 +49,14 @@ from posthog.test.base import ( @dataclass class Series: event: str - timestamps: List[str] + timestamps: list[str] @dataclass class SeriesTestData: distinct_id: str - events: List[Series] - properties: Dict[str, str | int] + events: list[Series] + properties: dict[str, str | int] @override_settings(IN_UNIT_TESTING=True) @@ -64,9 +64,9 @@ class TestTrendsQueryRunner(ClickhouseTestMixin, APIBaseTest): default_date_from = "2020-01-09" default_date_to = "2020-01-19" - def _create_events(self, data: List[SeriesTestData]): + def _create_events(self, data: list[SeriesTestData]): person_result = [] - properties_to_create: Dict[str, str] = {} + properties_to_create: dict[str, str] = {} for person in data: first_timestamp = person.events[0].timestamps[0] @@ -174,7 +174,7 @@ class TestTrendsQueryRunner(ClickhouseTestMixin, APIBaseTest): date_from: str, date_to: Optional[str], interval: IntervalType, - series: Optional[List[EventsNode | ActionsNode]], + series: Optional[list[EventsNode | ActionsNode]], trends_filters: Optional[TrendsFilter] = None, breakdown: Optional[BreakdownFilter] = None, filter_test_accounts: Optional[bool] = None, @@ -182,7 +182,7 @@ class TestTrendsQueryRunner(ClickhouseTestMixin, APIBaseTest): 
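The test fixtures defined near the top of these test files are dataclasses, and the rewrite applies to their field annotations too; PEP 585 generics are evaluated like any other annotation, so nothing else about the dataclass changes. A sketch mirroring the `Series`/`SeriesTestData` fixtures from the hunks above, with invented values:

    from dataclasses import dataclass

    @dataclass
    class Series:
        event: str
        timestamps: list[str]             # was List[str]

    @dataclass
    class SeriesTestData:
        distinct_id: str
        events: list[Series]              # was List[Series]
        properties: dict[str, str | int]  # was Dict[str, str | int]

    data = SeriesTestData(
        distinct_id="user_1",
        events=[Series(event="$pageview", timestamps=["2020-01-11T12:00:00Z"])],
        properties={"$browser": "Chrome", "prop": 30},
    )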
limit_context: Optional[LimitContext] = None, explicit_date: Optional[bool] = None, ) -> TrendsQueryRunner: - query_series: List[EventsNode | ActionsNode] = [EventsNode(event="$pageview")] if series is None else series + query_series: list[EventsNode | ActionsNode] = [EventsNode(event="$pageview")] if series is None else series query = TrendsQuery( dateRange=DateRange(date_from=date_from, date_to=date_to, explicitDate=explicit_date), interval=interval, @@ -198,7 +198,7 @@ class TestTrendsQueryRunner(ClickhouseTestMixin, APIBaseTest): date_from: str, date_to: Optional[str], interval: IntervalType, - series: Optional[List[EventsNode | ActionsNode]], + series: Optional[list[EventsNode | ActionsNode]], trends_filters: Optional[TrendsFilter] = None, breakdown: Optional[BreakdownFilter] = None, *, diff --git a/posthog/hogql_queries/insights/trends/trends_query_builder.py b/posthog/hogql_queries/insights/trends/trends_query_builder.py index 82fbb849ef5..072f371e4c0 100644 --- a/posthog/hogql_queries/insights/trends/trends_query_builder.py +++ b/posthog/hogql_queries/insights/trends/trends_query_builder.py @@ -1,4 +1,4 @@ -from typing import List, Optional, cast +from typing import Optional, cast from posthog.hogql import ast from posthog.hogql.constants import LimitContext from posthog.hogql.parser import parse_expr, parse_select @@ -98,7 +98,7 @@ class TrendsQueryBuilder(DataWarehouseInsightQueryMixin): }, ) - def _get_date_subqueries(self, breakdown: Breakdown, ignore_breakdowns: bool = False) -> List[ast.SelectQuery]: + def _get_date_subqueries(self, breakdown: Breakdown, ignore_breakdowns: bool = False) -> list[ast.SelectQuery]: if not breakdown.enabled or ignore_breakdowns: return [ cast( @@ -473,7 +473,7 @@ class TrendsQueryBuilder(DataWarehouseInsightQueryMixin): actors_query_time_frame: Optional[str] = None, ) -> ast.Expr: series = self.series - filters: List[ast.Expr] = [] + filters: list[ast.Expr] = [] # Dates if is_actors_query and actors_query_time_frame is not None: diff --git a/posthog/hogql_queries/insights/trends/trends_query_runner.py b/posthog/hogql_queries/insights/trends/trends_query_runner.py index 8629d17ec92..6ceb2dd1857 100644 --- a/posthog/hogql_queries/insights/trends/trends_query_runner.py +++ b/posthog/hogql_queries/insights/trends/trends_query_runner.py @@ -7,7 +7,7 @@ from itertools import groupby from math import ceil from operator import itemgetter import threading -from typing import List, Optional, Any, Dict +from typing import Optional, Any from dateutil import parser from dateutil.relativedelta import relativedelta from django.conf import settings @@ -70,11 +70,11 @@ from posthog.utils import format_label_date, multisort class TrendsQueryRunner(QueryRunner): query: TrendsQuery query_type = TrendsQuery - series: List[SeriesWithExtras] + series: list[SeriesWithExtras] def __init__( self, - query: TrendsQuery | Dict[str, Any], + query: TrendsQuery | dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None, modifiers: Optional[HogQLQueryModifiers] = None, @@ -115,7 +115,7 @@ class TrendsQueryRunner(QueryRunner): queries.extend(query.select_queries) return ast.SelectUnionQuery(select_queries=queries) - def to_queries(self) -> List[ast.SelectQuery | ast.SelectUnionQuery]: + def to_queries(self) -> list[ast.SelectQuery | ast.SelectUnionQuery]: queries = [] with self.timings.measure("trends_to_query"): for series in self.series: @@ -184,9 +184,9 @@ class TrendsQueryRunner(QueryRunner): return query def to_actors_query_options(self) -> 
InsightActorsQueryOptionsResponse: - res_breakdown: List[BreakdownItem] | None = None - res_series: List[Series] = [] - res_compare: List[CompareItem] | None = None + res_breakdown: list[BreakdownItem] | None = None + res_series: list[Series] = [] + res_compare: list[CompareItem] | None = None # Days res_days: Optional[list[DayItem]] = ( @@ -239,7 +239,7 @@ class TrendsQueryRunner(QueryRunner): is_boolean_breakdown = self._is_breakdown_field_boolean() is_histogram_breakdown = breakdown.is_histogram_breakdown - breakdown_values: List[str | int] + breakdown_values: list[str | int] res_breakdown = [] if is_histogram_breakdown: @@ -289,9 +289,9 @@ class TrendsQueryRunner(QueryRunner): with self.timings.measure("printing_hogql_for_response"): response_hogql = to_printed_hogql(response_hogql_query, self.team, self.modifiers) - res_matrix: List[List[Any] | Any | None] = [None] * len(queries) - timings_matrix: List[List[QueryTiming] | None] = [None] * len(queries) - errors: List[Exception] = [] + res_matrix: list[list[Any] | Any | None] = [None] * len(queries) + timings_matrix: list[list[QueryTiming] | None] = [None] * len(queries) + errors: list[Exception] = [] def run(index: int, query: ast.SelectQuery | ast.SelectUnionQuery, is_parallel: bool): try: @@ -342,14 +342,14 @@ class TrendsQueryRunner(QueryRunner): # Flatten res and timings res = [] for result in res_matrix: - if isinstance(result, List): + if isinstance(result, list): res.extend(result) else: res.append(result) timings = [] for result in timings_matrix: - if isinstance(result, List): + if isinstance(result, list): timings.extend(result) else: timings.append(result) @@ -555,7 +555,7 @@ class TrendsQueryRunner(QueryRunner): self.modifiers.inCohortVia == InCohortVia.auto and self.query.breakdownFilter is not None and self.query.breakdownFilter.breakdown_type == "cohort" - and isinstance(self.query.breakdownFilter.breakdown, List) + and isinstance(self.query.breakdownFilter.breakdown, list) and len(self.query.breakdownFilter.breakdown) > 1 and not any(value == "all" for value in self.query.breakdownFilter.breakdown) ): @@ -575,7 +575,7 @@ class TrendsQueryRunner(QueryRunner): self.modifiers.dataWarehouseEventsModifiers = datawarehouse_modifiers - def setup_series(self) -> List[SeriesWithExtras]: + def setup_series(self) -> list[SeriesWithExtras]: series_with_extras = [ SeriesWithExtras( series=series, @@ -593,7 +593,7 @@ class TrendsQueryRunner(QueryRunner): and self.query.breakdownFilter.breakdown_type == "cohort" ): updated_series = [] - if isinstance(self.query.breakdownFilter.breakdown, List): + if isinstance(self.query.breakdownFilter.breakdown, list): cohort_ids = self.query.breakdownFilter.breakdown elif self.query.breakdownFilter.breakdown is not None: cohort_ids = [self.query.breakdownFilter.breakdown] @@ -642,7 +642,7 @@ class TrendsQueryRunner(QueryRunner): return series_with_extras - def apply_formula(self, formula: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def apply_formula(self, formula: str, results: list[dict[str, Any]]) -> list[dict[str, Any]]: has_compare = bool(self.query.trendsFilter and self.query.trendsFilter.compare) has_breakdown = bool(self.query.breakdownFilter and self.query.breakdownFilter.breakdown) is_total_value = self._trends_display.should_aggregate_values() @@ -694,8 +694,8 @@ class TrendsQueryRunner(QueryRunner): @staticmethod def apply_formula_to_results_group( - results_group: List[Dict[str, Any]], formula: str, aggregate_values: Optional[bool] = False - ) -> Dict[str, Any]: + 
results_group: list[dict[str, Any]], formula: str, aggregate_values: Optional[bool] = False + ) -> dict[str, Any]: """ Applies the formula to a list of results, resulting in a single, computed result. """ @@ -787,7 +787,7 @@ class TrendsQueryRunner(QueryRunner): return "String" # TODO: Move this to posthog/hogql_queries/legacy_compatibility/query_to_filter.py - def _query_to_filter(self) -> Dict[str, Any]: + def _query_to_filter(self) -> dict[str, Any]: filter_dict = { "insight": "TRENDS", "properties": self.query.properties, diff --git a/posthog/hogql_queries/insights/trends/utils.py b/posthog/hogql_queries/insights/trends/utils.py index 61a4252d499..b8f6c3989f1 100644 --- a/posthog/hogql_queries/insights/trends/utils.py +++ b/posthog/hogql_queries/insights/trends/utils.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Optional, Union from posthog.schema import ActionsNode, DataWarehouseNode, EventsNode, BreakdownType @@ -12,7 +12,7 @@ def get_properties_chain( breakdown_type: BreakdownType | None, breakdown_field: str, group_type_index: Optional[float | int], -) -> List[str | int]: +) -> list[str | int]: if breakdown_type == "person": return ["person", "properties", breakdown_field] diff --git a/posthog/hogql_queries/insights/utils/properties.py b/posthog/hogql_queries/insights/utils/properties.py index ea4770037b7..41826b28535 100644 --- a/posthog/hogql_queries/insights/utils/properties.py +++ b/posthog/hogql_queries/insights/utils/properties.py @@ -1,11 +1,11 @@ -from typing import List, TypeAlias +from typing import TypeAlias from posthog.hogql import ast from posthog.hogql.property import property_to_expr from posthog.hogql_queries.insights.query_context import QueryContext from posthog.schema import PropertyGroupFilter from posthog.types import AnyPropertyFilter -PropertiesType: TypeAlias = List[AnyPropertyFilter] | PropertyGroupFilter | None +PropertiesType: TypeAlias = list[AnyPropertyFilter] | PropertyGroupFilter | None class Properties: @@ -17,8 +17,8 @@ class Properties: ) -> None: self.context = context - def to_exprs(self) -> List[ast.Expr]: - exprs: List[ast.Expr] = [] + def to_exprs(self) -> list[ast.Expr]: + exprs: list[ast.Expr] = [] team, query = self.context.team, self.context.query diff --git a/posthog/hogql_queries/insights/utils/utils.py b/posthog/hogql_queries/insights/utils/utils.py index c3b99c6a3b6..747d7e2b6ca 100644 --- a/posthog/hogql_queries/insights/utils/utils.py +++ b/posthog/hogql_queries/insights/utils/utils.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog.hogql import ast from posthog.models.team.team import Team, WeekStartDay from posthog.queries.util import get_trunc_func_ch @@ -6,7 +6,7 @@ from posthog.queries.util import get_trunc_func_ch def get_start_of_interval_hogql(interval: str, *, team: Team, source: Optional[ast.Expr] = None) -> ast.Expr: trunc_func = get_trunc_func_ch(interval) - trunc_func_args: List[ast.Expr] = [source] if source else [ast.Field(chain=["timestamp"])] + trunc_func_args: list[ast.Expr] = [source] if source else [ast.Field(chain=["timestamp"])] if trunc_func == "toStartOfWeek": trunc_func_args.append(ast.Constant(value=int((WeekStartDay(team.week_start_day or 0)).clickhouse_mode))) return ast.Call(name=trunc_func, args=trunc_func_args) diff --git a/posthog/hogql_queries/legacy_compatibility/filter_to_query.py b/posthog/hogql_queries/legacy_compatibility/filter_to_query.py index 382b37fa56d..fdeb74fc907 100644 --- 
a/posthog/hogql_queries/legacy_compatibility/filter_to_query.py +++ b/posthog/hogql_queries/legacy_compatibility/filter_to_query.py @@ -1,7 +1,7 @@ import copy from enum import Enum import json -from typing import Any, List, Dict, Literal +from typing import Any, Literal from posthog.hogql_queries.legacy_compatibility.clean_properties import clean_entity_properties, clean_global_properties from posthog.models.entity.entity import Entity as LegacyEntity from posthog.schema import ( @@ -118,7 +118,7 @@ def exlusion_entity_to_node(entity) -> FunnelExclusionEventsNode | FunnelExclusi # TODO: remove this method that returns legacy entities -def to_base_entity_dict(entity: Dict): +def to_base_entity_dict(entity: dict): return { "type": entity.get("type"), "id": entity.get("id"), @@ -140,7 +140,7 @@ insight_to_query_type = { INSIGHT_TYPE = Literal["TRENDS", "FUNNELS", "RETENTION", "PATHS", "LIFECYCLE", "STICKINESS"] -def _date_range(filter: Dict): +def _date_range(filter: dict): date_range = DateRange( date_from=filter.get("date_from"), date_to=filter.get("date_to"), @@ -153,7 +153,7 @@ def _date_range(filter: Dict): return {"dateRange": date_range} -def _interval(filter: Dict): +def _interval(filter: dict): if _insight_type(filter) == "RETENTION" or _insight_type(filter) == "PATHS": return {} @@ -163,7 +163,7 @@ def _interval(filter: Dict): return {"interval": filter.get("interval")} -def _series(filter: Dict): +def _series(filter: dict): if _insight_type(filter) == "RETENTION" or _insight_type(filter) == "PATHS": return {} @@ -188,8 +188,8 @@ def _series(filter: Dict): } -def _entities(filter: Dict): - processed_entities: List[LegacyEntity] = [] +def _entities(filter: dict): + processed_entities: list[LegacyEntity] = [] # add actions actions = filter.get("actions", []) @@ -213,7 +213,7 @@ def _entities(filter: Dict): return processed_entities -def _sampling_factor(filter: Dict): +def _sampling_factor(filter: dict): if isinstance(filter.get("sampling_factor"), str): try: return float(filter.get("sampling_factor")) @@ -223,16 +223,16 @@ def _sampling_factor(filter: Dict): return {"samplingFactor": filter.get("sampling_factor")} -def _properties(filter: Dict): +def _properties(filter: dict): raw_properties = filter.get("properties", None) return {"properties": clean_global_properties(raw_properties)} -def _filter_test_accounts(filter: Dict): +def _filter_test_accounts(filter: dict): return {"filterTestAccounts": filter.get("filter_test_accounts")} -def _breakdown_filter(_filter: Dict): +def _breakdown_filter(_filter: dict): if _insight_type(_filter) != "TRENDS" and _insight_type(_filter) != "FUNNELS": return {} @@ -275,13 +275,13 @@ def _breakdown_filter(_filter: Dict): return {"breakdownFilter": BreakdownFilter(**breakdownFilter)} -def _group_aggregation_filter(filter: Dict): +def _group_aggregation_filter(filter: dict): if _insight_type(filter) == "STICKINESS" or _insight_type(filter) == "LIFECYCLE": return {} return {"aggregation_group_type_index": filter.get("aggregation_group_type_index")} -def _insight_filter(filter: Dict): +def _insight_filter(filter: dict): if _insight_type(filter) == "TRENDS": insight_filter = { "trendsFilter": TrendsFilter( @@ -387,7 +387,7 @@ def _insight_filter(filter: Dict): return insight_filter -def filters_to_funnel_paths_query(filter: Dict[str, Any]) -> FunnelPathsFilter | None: +def filters_to_funnel_paths_query(filter: dict[str, Any]) -> FunnelPathsFilter | None: funnel_paths = filter.get("funnel_paths") funnel_filter = filter.get("funnel_filter") @@ -404,13 
+404,13 @@ def filters_to_funnel_paths_query(filter: Dict[str, Any]) -> FunnelPathsFilter | ) -def _insight_type(filter: Dict) -> INSIGHT_TYPE: +def _insight_type(filter: dict) -> INSIGHT_TYPE: if filter.get("insight") == "SESSIONS": return "TRENDS" return filter.get("insight", "TRENDS") -def filter_to_query(filter: Dict) -> InsightQueryNode: +def filter_to_query(filter: dict) -> InsightQueryNode: filter = copy.deepcopy(filter) # duplicate to prevent accidental filter alterations Query = insight_to_query_type[_insight_type(filter)] diff --git a/posthog/hogql_queries/query_runner.py b/posthog/hogql_queries/query_runner.py index 1ddd336dee9..0e49ca84963 100644 --- a/posthog/hogql_queries/query_runner.py +++ b/posthog/hogql_queries/query_runner.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from datetime import datetime from enum import IntEnum -from typing import Any, Generic, List, Optional, Type, Dict, TypeVar, Union, Tuple, cast, TypeGuard +from typing import Any, Generic, Optional, TypeVar, Union, cast, TypeGuard from django.conf import settings from django.core.cache import cache @@ -76,9 +76,9 @@ class QueryResponse(BaseModel, Generic[DataT]): extra="forbid", ) results: DataT - timings: Optional[List[QueryTiming]] = None - types: Optional[List[Union[Tuple[str, str], str]]] = None - columns: Optional[List[str]] = None + timings: Optional[list[QueryTiming]] = None + types: Optional[list[Union[tuple[str, str], str]]] = None + columns: Optional[list[str]] = None hogql: Optional[str] = None hasMore: Optional[bool] = None limit: Optional[int] = None @@ -128,7 +128,7 @@ RunnableQueryNode = Union[ def get_query_runner( - query: Dict[str, Any] | RunnableQueryNode | BaseModel, + query: dict[str, Any] | RunnableQueryNode | BaseModel, team: Team, timings: Optional[HogQLTimings] = None, limit_context: Optional[LimitContext] = None, @@ -146,7 +146,7 @@ def get_query_runner( from .insights.trends.trends_query_runner import TrendsQueryRunner return TrendsQueryRunner( - query=cast(TrendsQuery | Dict[str, Any], query), + query=cast(TrendsQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -156,7 +156,7 @@ def get_query_runner( from .insights.funnels.funnels_query_runner import FunnelsQueryRunner return FunnelsQueryRunner( - query=cast(FunnelsQuery | Dict[str, Any], query), + query=cast(FunnelsQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -166,7 +166,7 @@ def get_query_runner( from .insights.retention_query_runner import RetentionQueryRunner return RetentionQueryRunner( - query=cast(RetentionQuery | Dict[str, Any], query), + query=cast(RetentionQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -176,7 +176,7 @@ def get_query_runner( from .insights.paths_query_runner import PathsQueryRunner return PathsQueryRunner( - query=cast(PathsQuery | Dict[str, Any], query), + query=cast(PathsQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -186,7 +186,7 @@ def get_query_runner( from .insights.stickiness_query_runner import StickinessQueryRunner return StickinessQueryRunner( - query=cast(StickinessQuery | Dict[str, Any], query), + query=cast(StickinessQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -196,7 +196,7 @@ def get_query_runner( from .insights.lifecycle_query_runner import LifecycleQueryRunner return LifecycleQueryRunner( - query=cast(LifecycleQuery | Dict[str, Any], query), + 
query=cast(LifecycleQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -206,7 +206,7 @@ def get_query_runner( from .events_query_runner import EventsQueryRunner return EventsQueryRunner( - query=cast(EventsQuery | Dict[str, Any], query), + query=cast(EventsQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -216,7 +216,7 @@ def get_query_runner( from .actors_query_runner import ActorsQueryRunner return ActorsQueryRunner( - query=cast(ActorsQuery | Dict[str, Any], query), + query=cast(ActorsQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -226,7 +226,7 @@ def get_query_runner( from .insights.insight_actors_query_runner import InsightActorsQueryRunner return InsightActorsQueryRunner( - query=cast(InsightActorsQuery | Dict[str, Any], query), + query=cast(InsightActorsQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -236,7 +236,7 @@ def get_query_runner( from .insights.insight_actors_query_options_runner import InsightActorsQueryOptionsRunner return InsightActorsQueryOptionsRunner( - query=cast(InsightActorsQueryOptions | Dict[str, Any], query), + query=cast(InsightActorsQueryOptions | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -246,7 +246,7 @@ def get_query_runner( from .insights.funnels.funnel_correlation_query_runner import FunnelCorrelationQueryRunner return FunnelCorrelationQueryRunner( - query=cast(FunnelCorrelationQuery | Dict[str, Any], query), + query=cast(FunnelCorrelationQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -256,7 +256,7 @@ def get_query_runner( from .hogql_query_runner import HogQLQueryRunner return HogQLQueryRunner( - query=cast(HogQLQuery | Dict[str, Any], query), + query=cast(HogQLQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -266,7 +266,7 @@ def get_query_runner( from .sessions_timeline_query_runner import SessionsTimelineQueryRunner return SessionsTimelineQueryRunner( - query=cast(SessionsTimelineQuery | Dict[str, Any], query), + query=cast(SessionsTimelineQuery | dict[str, Any], query), team=team, timings=timings, modifiers=modifiers, @@ -292,7 +292,7 @@ Q = TypeVar("Q", bound=RunnableQueryNode) class QueryRunner(ABC, Generic[Q]): query: Q - query_type: Type[Q] + query_type: type[Q] team: Team timings: HogQLTimings modifiers: HogQLQueryModifiers @@ -300,7 +300,7 @@ class QueryRunner(ABC, Generic[Q]): def __init__( self, - query: Q | BaseModel | Dict[str, Any], + query: Q | BaseModel | dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None, modifiers: Optional[HogQLQueryModifiers] = None, @@ -425,7 +425,7 @@ class QueryRunner(ABC, Generic[Q]): # The default logic below applies to all insights and a lot of other queries # Notable exception: `HogQLQuery`, which has `properties` and `dateRange` within `HogQLFilters` if hasattr(self.query, "properties") and hasattr(self.query, "dateRange"): - query_update: Dict[str, Any] = {} + query_update: dict[str, Any] = {} if dashboard_filter.properties: if self.query.properties: query_update["properties"] = PropertyGroupFilter( diff --git a/posthog/hogql_queries/sessions_timeline_query_runner.py b/posthog/hogql_queries/sessions_timeline_query_runner.py index cda9433d63e..306ec02c934 100644 --- a/posthog/hogql_queries/sessions_timeline_query_runner.py +++ b/posthog/hogql_queries/sessions_timeline_query_runner.py @@ -1,6 +1,6 @@ 
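UP006 also covers `typing.Type`, which is why `query_type` in the `QueryRunner` hunk above becomes `type[Q]`; the generic machinery is otherwise untouched. A reduced sketch of that shape, under the assumption that each concrete runner sets `query_type` and parses dict input (a simplification of what the real constructor does with its Pydantic models):

    from typing import Any, Generic, TypeVar

    class TestQuery:
        def __init__(self, **kwargs: Any) -> None:
            self.__dict__.update(kwargs)

    Q = TypeVar("Q", bound=TestQuery)

    class Runner(Generic[Q]):
        query_type: type[Q]  # was Type[Q]

        def __init__(self, query: Q | dict[str, Any]) -> None:
            # dict payloads are parsed into the concrete query model
            self.query = self.query_type(**query) if isinstance(query, dict) else query

    class TestQueryRunner(Runner[TestQuery]):
        query_type = TestQuery

    runner = TestQueryRunner({"some_attr": "hello"})
    assert runner.query.some_attr == "hello"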
from datetime import timedelta import json -from typing import Dict, cast +from typing import cast from posthog.api.element import ElementSerializer @@ -138,7 +138,7 @@ class SessionsTimelineQueryRunner(QueryRunner): limit_context=self.limit_context, ) assert query_result.results is not None - timeline_entries_map: Dict[str, TimelineEntry] = {} + timeline_entries_map: dict[str, TimelineEntry] = {} for ( uuid, timestamp_parsed, diff --git a/posthog/hogql_queries/test/test_events_query_runner.py b/posthog/hogql_queries/test/test_events_query_runner.py index 7c8c62c5fb0..1617919f984 100644 --- a/posthog/hogql_queries/test/test_events_query_runner.py +++ b/posthog/hogql_queries/test/test_events_query_runner.py @@ -1,4 +1,4 @@ -from typing import Tuple, Any, cast +from typing import Any, cast from freezegun import freeze_time @@ -25,7 +25,7 @@ from posthog.test.base import ( class TestEventsQueryRunner(ClickhouseTestMixin, APIBaseTest): maxDiff = None - def _create_events(self, data: list[Tuple[str, str, Any]], event="$pageview"): + def _create_events(self, data: list[tuple[str, str, Any]], event="$pageview"): person_result = [] for distinct_id, timestamp, event_properties in data: with freeze_time(timestamp): diff --git a/posthog/hogql_queries/test/test_query_runner.py b/posthog/hogql_queries/test/test_query_runner.py index a02cf4fb46c..88d6128b005 100644 --- a/posthog/hogql_queries/test/test_query_runner.py +++ b/posthog/hogql_queries/test/test_query_runner.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import Any, List, Literal, Optional +from typing import Any, Literal, Optional from zoneinfo import ZoneInfo from dateutil.parser import isoparse @@ -21,7 +21,7 @@ from posthog.test.base import BaseTest class TestQuery(BaseModel): kind: Literal["TestQuery"] = "TestQuery" some_attr: str - other_attr: Optional[List[Any]] = [] + other_attr: Optional[list[Any]] = [] class TestQueryRunner(BaseTest): diff --git a/posthog/hogql_queries/utils/formula_ast.py b/posthog/hogql_queries/utils/formula_ast.py index 28e705827b7..922e283362a 100644 --- a/posthog/hogql_queries/utils/formula_ast.py +++ b/posthog/hogql_queries/utils/formula_ast.py @@ -1,6 +1,6 @@ import ast import operator -from typing import Any, Dict, List +from typing import Any class FormulaAST: @@ -12,9 +12,9 @@ class FormulaAST: ast.Mod: operator.mod, ast.Pow: operator.pow, } - zipped_data: List[tuple[float]] + zipped_data: list[tuple[float]] - def __init__(self, data: List[List[float]]): + def __init__(self, data: list[list[float]]): self.zipped_data = list(zip(*data)) def call(self, node: str): @@ -27,8 +27,8 @@ class FormulaAST: res.append(result) return res - def _evaluate(self, node, const_map: Dict[str, Any]): - if isinstance(node, (list, tuple)): + def _evaluate(self, node, const_map: dict[str, Any]): + if isinstance(node, list | tuple): return [self._evaluate(sub_node, const_map) for sub_node in node] elif isinstance(node, str): diff --git a/posthog/hogql_queries/utils/query_date_range.py b/posthog/hogql_queries/utils/query_date_range.py index ab1f25fbb37..ac9636c1e1c 100644 --- a/posthog/hogql_queries/utils/query_date_range.py +++ b/posthog/hogql_queries/utils/query_date_range.py @@ -1,7 +1,7 @@ import re from datetime import datetime, timedelta from functools import cached_property -from typing import Literal, Optional, Dict +from typing import Literal, Optional from zoneinfo import ZoneInfo from dateutil.parser import parse @@ -248,7 +248,7 @@ class QueryDateRange: 
args=[self.date_to_start_of_interval_hogql(self.date_to_as_hogql()), self.one_interval_period()], ) - def to_placeholders(self) -> Dict[str, ast.Expr]: + def to_placeholders(self) -> dict[str, ast.Expr]: return { "interval": self.interval_period_string_as_hogql_constant(), "one_interval_period": self.one_interval_period(), diff --git a/posthog/hogql_queries/utils/query_previous_period_date_range.py b/posthog/hogql_queries/utils/query_previous_period_date_range.py index 652a95c835e..c6dca63dc7d 100644 --- a/posthog/hogql_queries/utils/query_previous_period_date_range.py +++ b/posthog/hogql_queries/utils/query_previous_period_date_range.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional, Dict, Tuple +from typing import Optional from posthog.hogql_queries.utils.query_date_range import QueryDateRange from posthog.models.team import Team @@ -28,7 +28,7 @@ class QueryPreviousPeriodDateRange(QueryDateRange): ) -> None: super().__init__(date_range, team, interval, now) - def date_from_delta_mappings(self) -> Dict[str, int] | None: + def date_from_delta_mappings(self) -> dict[str, int] | None: if self._date_range and isinstance(self._date_range.date_from, str) and self._date_range.date_from != "all": date_from = self._date_range.date_from else: @@ -41,7 +41,7 @@ class QueryPreviousPeriodDateRange(QueryDateRange): )[1] return delta_mapping - def date_to_delta_mappings(self) -> Dict[str, int] | None: + def date_to_delta_mappings(self) -> dict[str, int] | None: if self._date_range and self._date_range.date_to: delta_mapping = relative_date_parse_with_delta_mapping( self._date_range.date_to, @@ -52,7 +52,7 @@ class QueryPreviousPeriodDateRange(QueryDateRange): return delta_mapping return None - def dates(self) -> Tuple[datetime, datetime]: + def dates(self) -> tuple[datetime, datetime]: current_period_date_from = super().date_from() current_period_date_to = super().date_to() diff --git a/posthog/hogql_queries/web_analytics/test/test_web_analytics_query_runner.py b/posthog/hogql_queries/web_analytics/test/test_web_analytics_query_runner.py index 7ea8e864a3a..3ea21760652 100644 --- a/posthog/hogql_queries/web_analytics/test/test_web_analytics_query_runner.py +++ b/posthog/hogql_queries/web_analytics/test/test_web_analytics_query_runner.py @@ -1,4 +1,4 @@ -from typing import Union, List +from typing import Union from freezegun import freeze_time @@ -62,7 +62,7 @@ class TestWebStatsTableQueryRunner(ClickhouseTestMixin, APIBaseTest): return WebOverviewQueryRunner(team=self.team, query=query) def test_sample_rate_cache_key_is_same_across_subclasses(self): - properties: List[Union[EventPropertyFilter, PersonPropertyFilter]] = [ + properties: list[Union[EventPropertyFilter, PersonPropertyFilter]] = [ EventPropertyFilter(key="$current_url", value="/a", operator=PropertyOperator.is_not), PersonPropertyFilter(key="$initial_utm_source", value="google", operator=PropertyOperator.is_not), ] @@ -75,10 +75,10 @@ class TestWebStatsTableQueryRunner(ClickhouseTestMixin, APIBaseTest): self.assertEqual(stats_key, overview_key) def test_sample_rate_cache_key_is_same_with_different_properties(self): - properties_a: List[Union[EventPropertyFilter, PersonPropertyFilter]] = [ + properties_a: list[Union[EventPropertyFilter, PersonPropertyFilter]] = [ EventPropertyFilter(key="$current_url", value="/a", operator=PropertyOperator.is_not), ] - properties_b: List[Union[EventPropertyFilter, PersonPropertyFilter]] = [ + properties_b: list[Union[EventPropertyFilter, PersonPropertyFilter]] = [ 
EventPropertyFilter(key="$current_url", value="/b", operator=PropertyOperator.is_not), ] date_from = "2023-12-08" @@ -90,7 +90,7 @@ class TestWebStatsTableQueryRunner(ClickhouseTestMixin, APIBaseTest): self.assertEqual(key_a, key_b) def test_sample_rate_cache_key_changes_with_date_range(self): - properties: List[Union[EventPropertyFilter, PersonPropertyFilter]] = [ + properties: list[Union[EventPropertyFilter, PersonPropertyFilter]] = [ EventPropertyFilter(key="$current_url", value="/a", operator=PropertyOperator.is_not), ] date_from_a = "2023-12-08" @@ -100,7 +100,7 @@ class TestWebStatsTableQueryRunner(ClickhouseTestMixin, APIBaseTest): key_a = self._create_web_stats_table_query(date_from_a, date_to, properties)._sample_rate_cache_key() key_b = self._create_web_stats_table_query(date_from_b, date_to, properties)._sample_rate_cache_key() - self.assertNotEquals(key_a, key_b) + self.assertNotEqual(key_a, key_b) def test_sample_rate_from_count(self): self.assertEqual(SamplingRate(numerator=1), _sample_rate_from_count(0)) diff --git a/posthog/hogql_queries/web_analytics/web_analytics_query_runner.py b/posthog/hogql_queries/web_analytics/web_analytics_query_runner.py index f91f3c1cff4..fb1288ac1bd 100644 --- a/posthog/hogql_queries/web_analytics/web_analytics_query_runner.py +++ b/posthog/hogql_queries/web_analytics/web_analytics_query_runner.py @@ -2,7 +2,7 @@ import typing from abc import ABC from datetime import timedelta from math import ceil -from typing import Optional, List, Union, Type +from typing import Optional, Union from django.conf import settings from django.core.cache import cache @@ -32,7 +32,7 @@ WebQueryNode = Union[WebOverviewQuery, WebTopClicksQuery, WebStatsTableQuery] class WebAnalyticsQueryRunner(QueryRunner, ABC): query: WebQueryNode - query_type: Type[WebQueryNode] + query_type: type[WebQueryNode] @cached_property def query_date_range(self): @@ -51,7 +51,7 @@ class WebAnalyticsQueryRunner(QueryRunner, ABC): return None @cached_property - def property_filters_without_pathname(self) -> List[Union[EventPropertyFilter, PersonPropertyFilter]]: + def property_filters_without_pathname(self) -> list[Union[EventPropertyFilter, PersonPropertyFilter]]: return [p for p in self.query.properties if p.key != "$pathname"] def session_where(self, include_previous_period: Optional[bool] = None): diff --git a/posthog/jwt.py b/posthog/jwt.py index fa458ab2f5e..111a85d51df 100644 --- a/posthog/jwt.py +++ b/posthog/jwt.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta, timezone from enum import Enum -from typing import Any, Dict +from typing import Any import jwt from django.conf import settings @@ -32,7 +32,7 @@ def encode_jwt(payload: dict, expiry_delta: timedelta, audience: PosthogJwtAudie return encoded_jwt -def decode_jwt(token: str, audience: PosthogJwtAudience) -> Dict[str, Any]: +def decode_jwt(token: str, audience: PosthogJwtAudience) -> dict[str, Any]: info = jwt.decode(token, settings.SECRET_KEY, audience=audience.value, algorithms=["HS256"]) return info diff --git a/posthog/kafka_client/client.py b/posthog/kafka_client/client.py index d29d9e9c0ae..3f58e572417 100644 --- a/posthog/kafka_client/client.py +++ b/posthog/kafka_client/client.py @@ -1,6 +1,7 @@ import json from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Optional +from collections.abc import Callable from django.conf import settings from kafka import KafkaConsumer as KC @@ -32,7 +33,7 @@ class KafkaProducerForTests: topic: str, value: Any, key: Any 
= None, - headers: Optional[List[Tuple[str, bytes]]] = None, + headers: Optional[list[tuple[str, bytes]]] = None, ): produce_future = FutureProduceResult(topic_partition=TopicPartition(topic, 1)) future = FutureRecordMetadata( @@ -158,7 +159,7 @@ class _KafkaProducer: data: Any, key: Any = None, value_serializer: Optional[Callable[[Any], Any]] = None, - headers: Optional[List[Tuple[str, str]]] = None, + headers: Optional[list[tuple[str, str]]] = None, ): if not value_serializer: value_serializer = self.json_serializer @@ -258,7 +259,7 @@ class ClickhouseProducer: def __init__(self): self.producer = KafkaProducer() if not settings.TEST else None - def produce(self, sql: str, topic: str, data: Dict[str, Any], sync: bool = True): + def produce(self, sql: str, topic: str, data: dict[str, Any], sync: bool = True): if self.producer is not None: # TODO: this should be not sync and self.producer.produce(topic=topic, data=data) else: diff --git a/posthog/kafka_client/helper.py b/posthog/kafka_client/helper.py index 6084e991a10..39cb9f03856 100644 --- a/posthog/kafka_client/helper.py +++ b/posthog/kafka_client/helper.py @@ -39,9 +39,11 @@ def get_kafka_ssl_context(): # SSLContext inside the with so when it goes out of scope the files are removed which has them # existing for the shortest amount of time. As extra caution password # protect/encrypt the client key - with NamedTemporaryFile(suffix=".crt") as cert_file, NamedTemporaryFile( - suffix=".key" - ) as key_file, NamedTemporaryFile(suffix=".crt") as trust_file: + with ( + NamedTemporaryFile(suffix=".crt") as cert_file, + NamedTemporaryFile(suffix=".key") as key_file, + NamedTemporaryFile(suffix=".crt") as trust_file, + ): cert_file.write(base64.b64decode(os.environ["KAFKA_CLIENT_CERT_B64"].encode("utf-8"))) cert_file.flush() diff --git a/posthog/management/commands/backfill_distinct_id_overrides.py b/posthog/management/commands/backfill_distinct_id_overrides.py index 507e744a93d..4472ec62916 100644 --- a/posthog/management/commands/backfill_distinct_id_overrides.py +++ b/posthog/management/commands/backfill_distinct_id_overrides.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging from dataclasses import dataclass -from typing import Sequence +from collections.abc import Sequence import structlog from django.core.management.base import BaseCommand, CommandError diff --git a/posthog/management/commands/create_channel_definitions_file.py b/posthog/management/commands/create_channel_definitions_file.py index 859bbe3c631..cab70bf31d3 100644 --- a/posthog/management/commands/create_channel_definitions_file.py +++ b/posthog/management/commands/create_channel_definitions_file.py @@ -4,7 +4,7 @@ import subprocess from collections import OrderedDict from dataclasses import dataclass from enum import Enum -from typing import Optional, Tuple +from typing import Optional from django.core.management.base import BaseCommand @@ -40,7 +40,7 @@ class Command(BaseCommand): input_arg = options.get("ga_sources") if not input_arg: raise ValueError("No input file specified") - with open(input_arg, "r", encoding="utf-8-sig") as input_file: + with open(input_arg, encoding="utf-8-sig") as input_file: input_str = input_file.read() split_items = re.findall(r"\S+\s+SOURCE_CATEGORY_\S+", input_str) @@ -59,7 +59,7 @@ class Command(BaseCommand): base_type, type_if_paid, type_if_organic = types[raw_type] return (domain, EntryKind.source), SourceEntry(base_type, type_if_paid, type_if_organic) - entries: OrderedDict[Tuple[str, str], SourceEntry] = 
OrderedDict(map(handle_entry, split_items)) + entries: OrderedDict[tuple[str, str], SourceEntry] = OrderedDict(map(handle_entry, split_items)) # add google domains to this, from https://www.google.com/supported_domains for google_domain in [ diff --git a/posthog/management/commands/fix_person_distinct_ids_after_delete.py b/posthog/management/commands/fix_person_distinct_ids_after_delete.py index 842a4e5353e..4f0853dd001 100644 --- a/posthog/management/commands/fix_person_distinct_ids_after_delete.py +++ b/posthog/management/commands/fix_person_distinct_ids_after_delete.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional +from typing import Optional import structlog from django.core.management.base import BaseCommand @@ -50,7 +50,7 @@ def run(options, sync: bool = False): logger.info("Kafka producer queue flushed.") -def get_distinct_ids_tied_to_deleted_persons(team_id: int) -> List[str]: +def get_distinct_ids_tied_to_deleted_persons(team_id: int) -> list[str]: # find distinct_ids where the person is set to be deleted rows = sync_execute( """ diff --git a/posthog/management/commands/makemigrations.py b/posthog/management/commands/makemigrations.py index 8ff0a37bfaa..a9e0ea4f98e 100644 --- a/posthog/management/commands/makemigrations.py +++ b/posthog/management/commands/makemigrations.py @@ -9,7 +9,7 @@ from django.db.migrations.loader import MigrationLoader class Command(MakeMigrationsCommand): def handle(self, *app_labels, **options): # Generate a migrations manifest with latest migration on each app - super(Command, self).handle(*app_labels, **options) + super().handle(*app_labels, **options) loader = MigrationLoader(None, ignore_no_migrations=True) apps = sorted(loader.migrated_apps) diff --git a/posthog/management/commands/partition.py b/posthog/management/commands/partition.py index b17e958b0c1..4bb17e68b78 100644 --- a/posthog/management/commands/partition.py +++ b/posthog/management/commands/partition.py @@ -6,7 +6,7 @@ from django.db import connection def load_sql(filename): path = os.path.join(os.path.dirname(__file__), "../sql/", filename) - with open(path, "r", encoding="utf_8") as f: + with open(path, encoding="utf_8") as f: return f.read() diff --git a/posthog/management/commands/run_async_migrations.py b/posthog/management/commands/run_async_migrations.py index 611c6038fd4..c8ee72ea352 100644 --- a/posthog/management/commands/run_async_migrations.py +++ b/posthog/management/commands/run_async_migrations.py @@ -1,5 +1,5 @@ import logging -from typing import List, Sequence +from collections.abc import Sequence import structlog from django.core.exceptions import ImproperlyConfigured @@ -31,7 +31,7 @@ logger.setLevel(logging.INFO) def get_necessary_migrations() -> Sequence[AsyncMigration]: - necessary_migrations: List[AsyncMigration] = [] + necessary_migrations: list[AsyncMigration] = [] for migration_name, definition in sorted(ALL_ASYNC_MIGRATIONS.items()): if is_async_migration_complete(migration_name): continue @@ -144,10 +144,8 @@ def handle_plan(necessary_migrations: Sequence[AsyncMigration]): logger.info("Async migrations up to date!") else: logger.warning( - ( - f"Required async migration{' is' if len(necessary_migrations) == 1 else 's are'} not completed:\n" - "\n".join((f"- {migration.get_name_with_requirements()}" for migration in necessary_migrations)) - ) + f"Required async migration{' is' if len(necessary_migrations) == 1 else 's are'} not completed:\n" + "\n".join(f"- {migration.get_name_with_requirements()}" for migration in necessary_migrations) ) 
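Beyond the alias swaps, the hunks in this stretch bundle several other modernizations: ABCs such as `Callable` and `Sequence` are now imported from `collections.abc` rather than `typing` (UP035), `isinstance` type-tuples become PEP 604 unions (UP038, seen earlier in formula_ast.py), redundant `"r"` modes are dropped from `open()` (UP015), `super(Command, self)` becomes zero-argument `super()` (UP008), multi-manager `with` statements use the parenthesized form allowed since Python 3.10, and the deprecated `assertNotEquals` unittest alias becomes `assertNotEqual` (seen in the web-analytics tests above). A compact sketch of these forms together, with invented names and file contents:

    from collections.abc import Callable, Sequence   # was: from typing import Callable, Sequence
    from tempfile import NamedTemporaryFile

    def total(nums: Sequence[int], key: Callable[[int], int] = lambda n: n) -> int:
        return sum(key(n) for n in nums)

    assert isinstance((1, 2), list | tuple)  # was: isinstance(x, (list, tuple))

    class Base:
        def handle(self) -> str:
            return "base"

    class Command(Base):
        def handle(self) -> str:
            return super().handle()  # was: super(Command, self).handle()

    with (
        NamedTemporaryFile(suffix=".crt") as cert_file,
        NamedTemporaryFile(suffix=".key") as key_file,
    ):
        cert_file.write(b"cert")
        key_file.write(b"key")
        # open(path) now relies on the default "r" mode, e.g.:
        # with open(cert_file.name) as f: ...

    assert total([1, 2, 3]) == 6 and Command().handle() == "base"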
diff --git a/posthog/management/commands/sync_feature_flags.py b/posthog/management/commands/sync_feature_flags.py index df2e8d32576..4e260616036 100644 --- a/posthog/management/commands/sync_feature_flags.py +++ b/posthog/management/commands/sync_feature_flags.py @@ -1,4 +1,4 @@ -from typing import Dict, cast +from typing import cast from django.core.management.base import BaseCommand @@ -15,8 +15,8 @@ class Command(BaseCommand): help = "Add and enable all feature flags in frontend/src/lib/constants.tsx for all teams" def handle(self, *args, **options): - flags: Dict[str, str] = {} - with open("frontend/src/lib/constants.tsx", "r", encoding="utf_8") as f: + flags: dict[str, str] = {} + with open("frontend/src/lib/constants.tsx", encoding="utf_8") as f: lines = f.readlines() parsing_flags = False for line in lines: diff --git a/posthog/management/commands/sync_replicated_schema.py b/posthog/management/commands/sync_replicated_schema.py index e2c280bd41b..642eae80d9b 100644 --- a/posthog/management/commands/sync_replicated_schema.py +++ b/posthog/management/commands/sync_replicated_schema.py @@ -1,7 +1,6 @@ import logging import re from collections import defaultdict -from typing import Dict, Set import structlog from django.conf import settings @@ -65,8 +64,8 @@ class Command(BaseCommand): }, ) - host_tables: Dict[HostName, Set[TableName]] = defaultdict(set) - create_table_queries: Dict[TableName, Query] = {} + host_tables: dict[HostName, set[TableName]] = defaultdict(set) + create_table_queries: dict[TableName, Query] = {} for host, table_name, create_table_query in rows: host_tables[host].add(table_name) @@ -74,7 +73,7 @@ class Command(BaseCommand): return host_tables, create_table_queries, self.get_out_of_sync_hosts(host_tables) - def get_out_of_sync_hosts(self, host_tables: Dict[HostName, Set[TableName]]) -> Dict[HostName, Set[TableName]]: + def get_out_of_sync_hosts(self, host_tables: dict[HostName, set[TableName]]) -> dict[HostName, set[TableName]]: table_names = list(map(get_table_name, CREATE_TABLE_QUERIES)) out_of_sync = {} @@ -87,8 +86,8 @@ class Command(BaseCommand): def create_missing_tables( self, - out_of_sync_hosts: Dict[HostName, Set[TableName]], - create_table_queries: Dict[TableName, Query], + out_of_sync_hosts: dict[HostName, set[TableName]], + create_table_queries: dict[TableName, Query], ): missing_tables = {table for tables in out_of_sync_hosts.values() for table in tables} diff --git a/posthog/management/commands/test_migrations_are_safe.py b/posthog/management/commands/test_migrations_are_safe.py index 566533fd9fe..a576b982b50 100644 --- a/posthog/management/commands/test_migrations_are_safe.py +++ b/posthog/management/commands/test_migrations_are_safe.py @@ -1,6 +1,6 @@ import re import sys -from typing import List, Optional +from typing import Optional from django.core.management import call_command from django.core.management.base import BaseCommand, CommandError @@ -20,7 +20,7 @@ def _get_table(search_string: str, operation_sql: str) -> Optional[str]: def validate_migration_sql(sql) -> bool: new_tables = _get_new_tables(sql) operations = sql.split("\n") - tables_created_so_far: List[str] = [] + tables_created_so_far: list[str] = [] for operation_sql in operations: # Extract table name from queries of this format: ALTER TABLE TABLE "posthog_feature" table_being_altered: Optional[str] = ( diff --git a/posthog/middleware.py b/posthog/middleware.py index e43ef3a620f..87ee128a726 100644 --- a/posthog/middleware.py +++ b/posthog/middleware.py @@ -1,6 +1,7 @@ import 
time from ipaddress import ip_address, ip_network -from typing import Any, Callable, List, Optional, cast +from typing import Any, Optional, cast +from collections.abc import Callable from django.shortcuts import redirect import structlog @@ -66,7 +67,7 @@ cookie_api_paths_to_ignore = {"e", "s", "capture", "batch", "decide", "api", "tr class AllowIPMiddleware: - trusted_proxies: List[str] = [] + trusted_proxies: list[str] = [] def __init__(self, get_response): if not settings.ALLOWED_IP_BLOCKS: @@ -411,7 +412,7 @@ class CaptureMiddleware: def __init__(self, get_response): self.get_response = get_response - middlewares: List[Any] = [] + middlewares: list[Any] = [] # based on how we're using these middlewares, only middlewares that # have a process_request and process_response attribute can be valid here. # Or, middlewares that inherit from `middleware.util.deprecation.MiddlewareMixin` which diff --git a/posthog/migrations/0027_move_elements_to_group.py b/posthog/migrations/0027_move_elements_to_group.py index 51a65b1f5da..1bc55cd9853 100644 --- a/posthog/migrations/0027_move_elements_to_group.py +++ b/posthog/migrations/0027_move_elements_to_group.py @@ -1,7 +1,6 @@ # Generated by Django 3.0.3 on 2020-02-27 18:13 import hashlib import json -from typing import List from django.db import migrations, models, transaction from django.forms.models import model_to_dict @@ -21,7 +20,7 @@ def forwards(apps, schema_editor): ElementGroup = apps.get_model("posthog", "ElementGroup") Element = apps.get_model("posthog", "Element") - hashes_seen: List[str] = [] + hashes_seen: list[str] = [] while Event.objects.filter(element__isnull=False, elements_hash__isnull=True, event="$autocapture").exists(): with transaction.atomic(): events = ( diff --git a/posthog/migrations/0132_team_test_account_filters.py b/posthog/migrations/0132_team_test_account_filters.py index 313de9f3355..a1aba896aa2 100644 --- a/posthog/migrations/0132_team_test_account_filters.py +++ b/posthog/migrations/0132_team_test_account_filters.py @@ -22,7 +22,7 @@ class GenericEmails: """ def __init__(self): - with open(get_absolute_path("../helpers/generic_emails.txt"), "r") as f: + with open(get_absolute_path("../helpers/generic_emails.txt")) as f: self.emails = {x.rstrip(): True for x in f} def is_generic(self, email: str) -> bool: diff --git a/posthog/migrations/0219_migrate_tags_v2.py b/posthog/migrations/0219_migrate_tags_v2.py index fef394a5cc0..dcd7375511e 100644 --- a/posthog/migrations/0219_migrate_tags_v2.py +++ b/posthog/migrations/0219_migrate_tags_v2.py @@ -1,5 +1,5 @@ # Generated by Django 3.2.5 on 2022-03-01 23:41 -from typing import Any, List, Tuple +from typing import Any from django.core.paginator import Paginator from django.db import migrations @@ -19,7 +19,7 @@ def forwards(apps, schema_editor): Insight = apps.get_model("posthog", "Insight") Dashboard = apps.get_model("posthog", "Dashboard") - createables: List[Tuple[Any, Any]] = [] + createables: list[tuple[Any, Any]] = [] batch_size = 1_000 # Collect insight tags and taggeditems diff --git a/posthog/migrations/0259_backfill_team_recording_domains.py b/posthog/migrations/0259_backfill_team_recording_domains.py index 1f0dcba4f08..12304cc70fd 100644 --- a/posthog/migrations/0259_backfill_team_recording_domains.py +++ b/posthog/migrations/0259_backfill_team_recording_domains.py @@ -1,4 +1,3 @@ -from typing import Set from urllib.parse import urlparse import structlog @@ -20,7 +19,7 @@ def backfill_recording_domains(apps, _): teams_in_batch = all_teams[i : i + batch_size] for 
team in teams_in_batch: - recording_domains: Set[str] = set() + recording_domains: set[str] = set() for app_url in team.app_urls: # Extract just the domain from the URL parsed_url = urlparse(app_url) diff --git a/posthog/models/action/action.py b/posthog/models/action/action.py index bd016535a88..49aefe15440 100644 --- a/posthog/models/action/action.py +++ b/posthog/models/action/action.py @@ -1,5 +1,5 @@ import json -from typing import List, Any +from typing import Any from django.db import models from django.db.models import Q @@ -51,10 +51,10 @@ class Action(models.Model): "deleted": self.deleted, } - def get_step_events(self) -> List[str]: + def get_step_events(self) -> list[str]: return [action_step.event for action_step in self.steps.all()] - def generate_bytecode(self) -> List[Any]: + def generate_bytecode(self) -> list[Any]: from posthog.hogql.property import action_to_expr from posthog.hogql.bytecode import create_bytecode diff --git a/posthog/models/action/util.py b/posthog/models/action/util.py index 54fda6ef5b9..95cdca9721c 100644 --- a/posthog/models/action/util.py +++ b/posthog/models/action/util.py @@ -1,6 +1,6 @@ from collections import Counter -from typing import Counter as TCounter, Literal, Optional -from typing import Dict, List, Tuple +from typing import Literal, Optional +from collections import Counter as TCounter from posthog.constants import AUTOCAPTURE_EVENT, TREND_FILTER_TYPE_ACTIONS from posthog.hogql.hogql import HogQLContext @@ -15,7 +15,7 @@ from posthog.queries.util import PersonPropertiesMode def format_action_filter_event_only( action: Action, prepend: str = "action", -) -> Tuple[str, Dict]: +) -> tuple[str, dict]: """Return SQL for prefiltering events by action, i.e. down to only the events and without any other filters.""" events = action.get_step_events() if not events: @@ -37,7 +37,7 @@ def format_action_filter( table_name: str = "", person_properties_mode: PersonPropertiesMode = PersonPropertiesMode.USING_SUBQUERY, person_id_joined_alias: str = "person_id", -) -> Tuple[str, Dict]: +) -> tuple[str, dict]: """Return SQL for filtering events by action.""" # get action steps params = {"team_id": action.team.pk} if filter_by_team else {} @@ -48,7 +48,7 @@ def format_action_filter( or_queries = [] for index, step in enumerate(steps): - conditions: List[str] = [] + conditions: list[str] = [] # filter element if step.event == AUTOCAPTURE_EVENT: from posthog.models.property.util import ( @@ -118,7 +118,7 @@ def format_action_filter( def filter_event( step: ActionStep, prepend: str = "event", index: int = 0, table_name: str = "" -) -> Tuple[List[str], Dict]: +) -> tuple[list[str], dict]: from posthog.models.property.util import get_property_string_expr params = {} @@ -156,7 +156,7 @@ def format_entity_filter( person_id_joined_alias: str, prepend: str = "action", filter_by_team=True, -) -> Tuple[str, Dict]: +) -> tuple[str, dict]: if entity.type == TREND_FILTER_TYPE_ACTIONS: action = entity.get_action() entity_filter, params = format_action_filter( diff --git a/posthog/models/activity_logging/activity_log.py b/posthog/models/activity_logging/activity_log.py index 074b53b2dd5..141130ea4f8 100644 --- a/posthog/models/activity_logging/activity_log.py +++ b/posthog/models/activity_logging/activity_log.py @@ -2,7 +2,7 @@ import dataclasses import json from datetime import datetime from decimal import Decimal -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union import structlog from django.core.paginator 
import Paginator @@ -52,7 +52,7 @@ class Change: class Trigger: job_type: str job_id: str - payload: Dict + payload: dict @dataclasses.dataclass(frozen=True) @@ -62,13 +62,13 @@ class Detail: # The short_id if it has one short_id: Optional[str] = None type: Optional[str] = None - changes: Optional[List[Change]] = None + changes: Optional[list[Change]] = None trigger: Optional[Trigger] = None class ActivityDetailEncoder(json.JSONEncoder): def default(self, obj): - if isinstance(obj, (Detail, Change, Trigger)): + if isinstance(obj, Detail | Change | Trigger): return obj.__dict__ if isinstance(obj, datetime): return obj.isoformat() @@ -132,7 +132,7 @@ common_field_exclusions = [ ] -field_exclusions: Dict[ActivityScope, List[str]] = { +field_exclusions: dict[ActivityScope, list[str]] = { "Notebook": [ "text_content", ], @@ -199,7 +199,7 @@ field_exclusions: Dict[ActivityScope, List[str]] = { } -def describe_change(m: Any) -> Union[str, Dict]: +def describe_change(m: Any) -> Union[str, dict]: if isinstance(m, Dashboard): return {"id": m.id, "name": m.name} if isinstance(m, DashboardTile): @@ -213,7 +213,7 @@ def describe_change(m: Any) -> Union[str, Dict]: return str(m) -def _read_through_relation(relation: models.Manager) -> List[Union[Dict, str]]: +def _read_through_relation(relation: models.Manager) -> list[Union[dict, str]]: described_models = [describe_change(r) for r in relation.all()] if all(isinstance(elem, str) for elem in described_models): @@ -227,11 +227,11 @@ def changes_between( model_type: ActivityScope, previous: Optional[models.Model], current: Optional[models.Model], -) -> List[Change]: +) -> list[Change]: """ Identifies changes between two models by comparing fields """ - changes: List[Change] = [] + changes: list[Change] = [] if previous is None and current is None: # there are no changes between two things that don't exist @@ -282,14 +282,14 @@ def changes_between( def dict_changes_between( model_type: ActivityScope, - previous: Dict[Any, Any], - new: Dict[Any, Any], + previous: dict[Any, Any], + new: dict[Any, Any], use_field_exclusions: bool = False, -) -> List[Change]: +) -> list[Change]: """ Identifies changes between two dictionaries by comparing fields """ - changes: List[Change] = [] + changes: list[Change] = [] if previous == new: return changes @@ -395,7 +395,7 @@ class ActivityPage: limit: int has_next: bool has_previous: bool - results: List[ActivityLog] + results: list[ActivityLog] def get_activity_page(activity_query: models.QuerySet, limit: int = 10, page: int = 1) -> ActivityPage: @@ -430,7 +430,7 @@ def load_activity( return get_activity_page(activity_query, limit, page) -def load_all_activity(scope_list: List[ActivityScope], team_id: int, limit: int = 10, page: int = 1): +def load_all_activity(scope_list: list[ActivityScope], team_id: int, limit: int = 10, page: int = 1): activity_query = ( ActivityLog.objects.select_related("user").filter(team_id=team_id, scope__in=scope_list).order_by("-created_at") ) diff --git a/posthog/models/async_deletion/delete.py b/posthog/models/async_deletion/delete.py index 9846842b8e0..1ab75b353e8 100644 --- a/posthog/models/async_deletion/delete.py +++ b/posthog/models/async_deletion/delete.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import Dict, List, Tuple import structlog from django.utils import timezone @@ -13,7 +12,7 @@ logger = structlog.get_logger(__name__) class AsyncDeletionProcess(ABC): CLICKHOUSE_MUTATION_CHUNK_SIZE = 1_000_000 CLICKHOUSE_VERIFY_CHUNK_SIZE 
= 1_000 - DELETION_TYPES: List[DeletionType] = [] + DELETION_TYPES: list[DeletionType] = [] def __init__(self) -> None: super().__init__() @@ -60,14 +59,14 @@ class AsyncDeletionProcess(ABC): return result @abstractmethod - def process(self, deletions: List[AsyncDeletion]): + def process(self, deletions: list[AsyncDeletion]): raise NotImplementedError() @abstractmethod - def _verify_by_group(self, deletion_type: int, async_deletions: List[AsyncDeletion]) -> List[AsyncDeletion]: + def _verify_by_group(self, deletion_type: int, async_deletions: list[AsyncDeletion]) -> list[AsyncDeletion]: raise NotImplementedError() - def _conditions(self, async_deletions: List[AsyncDeletion]) -> Tuple[List[str], Dict]: + def _conditions(self, async_deletions: list[AsyncDeletion]) -> tuple[list[str], dict]: conditions, args = [], {} for i, row in enumerate(async_deletions): condition, arg = self._condition(row, str(i)) @@ -76,5 +75,5 @@ class AsyncDeletionProcess(ABC): return conditions, args @abstractmethod - def _condition(self, async_deletion: AsyncDeletion, suffix: str) -> Tuple[str, Dict]: + def _condition(self, async_deletion: AsyncDeletion, suffix: str) -> tuple[str, dict]: raise NotImplementedError() diff --git a/posthog/models/async_deletion/delete_cohorts.py b/posthog/models/async_deletion/delete_cohorts.py index c2d452628ce..00f10aac6b8 100644 --- a/posthog/models/async_deletion/delete_cohorts.py +++ b/posthog/models/async_deletion/delete_cohorts.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Set, Tuple +from typing import Any from posthog.client import sync_execute from posthog.models.async_deletion import AsyncDeletion, DeletionType @@ -9,7 +9,7 @@ from posthog.clickhouse.client.connection import Workload class AsyncCohortDeletion(AsyncDeletionProcess): DELETION_TYPES = [DeletionType.Cohort_full, DeletionType.Cohort_stale] - def process(self, deletions: List[AsyncDeletion]): + def process(self, deletions: list[AsyncDeletion]): if len(deletions) == 0: logger.warn("No AsyncDeletion for cohorts to perform") return @@ -33,7 +33,7 @@ class AsyncCohortDeletion(AsyncDeletionProcess): workload=Workload.OFFLINE, ) - def _verify_by_group(self, deletion_type: int, async_deletions: List[AsyncDeletion]) -> List[AsyncDeletion]: + def _verify_by_group(self, deletion_type: int, async_deletions: list[AsyncDeletion]) -> list[AsyncDeletion]: if deletion_type == DeletionType.Cohort_stale or deletion_type == DeletionType.Cohort_full: cohort_ids_with_data = self._verify_by_column("team_id, cohort_id", async_deletions) return [ @@ -42,7 +42,7 @@ class AsyncCohortDeletion(AsyncDeletionProcess): else: return [] - def _verify_by_column(self, distinct_columns: str, async_deletions: List[AsyncDeletion]) -> Set[Tuple[Any, ...]]: + def _verify_by_column(self, distinct_columns: str, async_deletions: list[AsyncDeletion]) -> set[tuple[Any, ...]]: conditions, args = self._conditions(async_deletions) clickhouse_result = sync_execute( f""" @@ -62,7 +62,7 @@ class AsyncCohortDeletion(AsyncDeletionProcess): ) return "cohort_id" - def _condition(self, async_deletion: AsyncDeletion, suffix: str) -> Tuple[str, Dict]: + def _condition(self, async_deletion: AsyncDeletion, suffix: str) -> tuple[str, dict]: team_id_param = f"team_id{suffix}" key_param = f"key{suffix}" version_param = f"version{suffix}" diff --git a/posthog/models/async_deletion/delete_events.py b/posthog/models/async_deletion/delete_events.py index 2486043a5b8..988161336cc 100644 --- a/posthog/models/async_deletion/delete_events.py +++ 
b/posthog/models/async_deletion/delete_events.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Set, Tuple +from typing import Any from posthog.client import sync_execute from posthog.models.async_deletion import AsyncDeletion, DeletionType, CLICKHOUSE_ASYNC_DELETION_TABLE @@ -22,7 +22,7 @@ TABLES_TO_DELETE_TEAM_DATA_FROM = [ class AsyncEventDeletion(AsyncDeletionProcess): DELETION_TYPES = [DeletionType.Team, DeletionType.Person] - def process(self, deletions: List[AsyncDeletion]): + def process(self, deletions: list[AsyncDeletion]): if len(deletions) == 0: logger.debug("No AsyncDeletion to perform") return @@ -87,7 +87,7 @@ class AsyncEventDeletion(AsyncDeletionProcess): workload=Workload.OFFLINE, ) - def _fill_table(self, deletions: List[AsyncDeletion], temp_table_name: str): + def _fill_table(self, deletions: list[AsyncDeletion], temp_table_name: str): sync_execute(f"DROP TABLE IF EXISTS {temp_table_name}", workload=Workload.OFFLINE) sync_execute( CLICKHOUSE_ASYNC_DELETION_TABLE.format(table_name=temp_table_name, cluster=CLICKHOUSE_CLUSTER), @@ -111,7 +111,7 @@ class AsyncEventDeletion(AsyncDeletionProcess): workload=Workload.OFFLINE, ) - def _verify_by_group(self, deletion_type: int, async_deletions: List[AsyncDeletion]) -> List[AsyncDeletion]: + def _verify_by_group(self, deletion_type: int, async_deletions: list[AsyncDeletion]) -> list[AsyncDeletion]: if deletion_type == DeletionType.Team: team_ids_with_data = self._verify_by_column("team_id", async_deletions) return [row for row in async_deletions if (row.team_id,) not in team_ids_with_data] @@ -122,7 +122,7 @@ class AsyncEventDeletion(AsyncDeletionProcess): else: return [] - def _verify_by_column(self, distinct_columns: str, async_deletions: List[AsyncDeletion]) -> Set[Tuple[Any, ...]]: + def _verify_by_column(self, distinct_columns: str, async_deletions: list[AsyncDeletion]) -> set[tuple[Any, ...]]: conditions, args = self._conditions(async_deletions) clickhouse_result = sync_execute( f""" @@ -142,7 +142,7 @@ class AsyncEventDeletion(AsyncDeletionProcess): else: return f"$group_{async_deletion.group_type_index}" - def _condition(self, async_deletion: AsyncDeletion, suffix: str) -> Tuple[str, Dict]: + def _condition(self, async_deletion: AsyncDeletion, suffix: str) -> tuple[str, dict]: if async_deletion.deletion_type == DeletionType.Team: return f"team_id = %(team_id{suffix})s", {f"team_id{suffix}": async_deletion.team_id} else: diff --git a/posthog/models/async_migration.py b/posthog/models/async_migration.py index 885f7ce3979..92d61fb5e3f 100644 --- a/posthog/models/async_migration.py +++ b/posthog/models/async_migration.py @@ -1,5 +1,3 @@ -from typing import List - from django.db import models @@ -63,7 +61,7 @@ def get_all_running_async_migrations(): return AsyncMigration.objects.filter(status=MigrationStatus.Running) -def get_async_migrations_by_status(target_statuses: List[int]): +def get_async_migrations_by_status(target_statuses: list[int]): return AsyncMigration.objects.filter(status__in=target_statuses) diff --git a/posthog/models/channel_type/sql.py b/posthog/models/channel_type/sql.py index 15470601c2d..d631c276e55 100644 --- a/posthog/models/channel_type/sql.py +++ b/posthog/models/channel_type/sql.py @@ -37,7 +37,7 @@ TRUNCATE_CHANNEL_DEFINITION_TABLE_SQL = ( f"TRUNCATE TABLE IF EXISTS {CHANNEL_DEFINITION_TABLE_NAME} ON CLUSTER '{CLICKHOUSE_CLUSTER}'" ) -with open(os.path.join(os.path.dirname(__file__), "channel_definitions.json"), "r") as f: +with open(os.path.join(os.path.dirname(__file__), 
"channel_definitions.json")) as f: CHANNEL_DEFINITIONS = json.loads(f.read()) @@ -54,7 +54,7 @@ CHANNEL_DEFINITION_DATA_SQL = f""" INSERT INTO channel_definition (domain, kind, domain_type, type_if_paid, type_if_organic) VALUES { ''', -'''.join((f'({" ,".join(map(format_value, x))})' for x in CHANNEL_DEFINITIONS))}, +'''.join(f'({" ,".join(map(format_value, x))})' for x in CHANNEL_DEFINITIONS)}, ; """ diff --git a/posthog/models/cohort/cohort.py b/posthog/models/cohort/cohort.py index a10be159d57..8f7867127a1 100644 --- a/posthog/models/cohort/cohort.py +++ b/posthog/models/cohort/cohort.py @@ -1,6 +1,6 @@ import time from datetime import datetime -from typing import Any, Dict, List, Literal, Optional, Union, cast +from typing import Any, Literal, Optional, Union, cast import structlog from django.conf import settings @@ -37,7 +37,7 @@ ON CONFLICT DO NOTHING class Group: def __init__( self, - properties: Optional[Dict[str, Any]] = None, + properties: Optional[dict[str, Any]] = None, action_id: Optional[int] = None, event_id: Optional[str] = None, days: Optional[int] = None, @@ -59,7 +59,7 @@ class Group: self.start_date = start_date self.end_date = end_date - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: dup = self.__dict__.copy() dup["start_date"] = self.start_date.isoformat() if self.start_date else self.start_date dup["end_date"] = self.end_date.isoformat() if self.end_date else self.end_date @@ -159,11 +159,11 @@ class Cohort(models.Model): ) else: # invalid state - return PropertyGroup(PropertyOperatorType.AND, cast(List[Property], [])) + return PropertyGroup(PropertyOperatorType.AND, cast(list[Property], [])) return PropertyGroup(PropertyOperatorType.OR, property_groups) - return PropertyGroup(PropertyOperatorType.AND, cast(List[Property], [])) + return PropertyGroup(PropertyOperatorType.AND, cast(list[Property], [])) @property def has_complex_behavioral_filter(self) -> bool: @@ -241,7 +241,7 @@ class Cohort(models.Model): clear_stale_cohort.delay(self.pk, before_version=pending_version) - def insert_users_by_list(self, items: List[str]) -> None: + def insert_users_by_list(self, items: list[str]) -> None: """ Items is a list of distinct_ids """ @@ -303,7 +303,7 @@ class Cohort(models.Model): self.save() capture_exception(err) - def insert_users_list_by_uuid(self, items: List[str], insert_in_clickhouse: bool = False, batchsize=1000) -> None: + def insert_users_list_by_uuid(self, items: list[str], insert_in_clickhouse: bool = False, batchsize=1000) -> None: from posthog.models.cohort.util import get_static_cohort_size, insert_static_cohort try: diff --git a/posthog/models/cohort/util.py b/posthog/models/cohort/util.py index 059bfb3813b..2af2c0c66c1 100644 --- a/posthog/models/cohort/util.py +++ b/posthog/models/cohort/util.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast +from typing import Any, Optional, Union, cast import structlog from dateutil import parser @@ -44,7 +44,7 @@ TEMP_PRECALCULATED_MARKER = parser.parse("2021-06-07T15:00:00+00:00") logger = structlog.get_logger(__name__) -def format_person_query(cohort: Cohort, index: int, hogql_context: HogQLContext) -> Tuple[str, Dict[str, Any]]: +def format_person_query(cohort: Cohort, index: int, hogql_context: HogQLContext) -> tuple[str, dict[str, Any]]: if cohort.is_static: return format_static_cohort_query(cohort, index, prepend="") @@ -72,7 +72,7 @@ def format_person_query(cohort: Cohort, index: int, 
hogql_context: HogQLContext) def print_cohort_hogql_query(cohort: Cohort, hogql_context: HogQLContext) -> str: from posthog.hogql_queries.query_runner import get_query_runner - persons_query = cast(Dict, cohort.query) + persons_query = cast(dict, cohort.query) persons_query["select"] = ["id as actor_id"] query = get_query_runner( persons_query, team=cast(Team, cohort.team), limit_context=LimitContext.COHORT_CALCULATION @@ -81,7 +81,7 @@ def print_cohort_hogql_query(cohort: Cohort, hogql_context: HogQLContext) -> str return print_ast(query, context=hogql_context, dialect="clickhouse") -def format_static_cohort_query(cohort: Cohort, index: int, prepend: str) -> Tuple[str, Dict[str, Any]]: +def format_static_cohort_query(cohort: Cohort, index: int, prepend: str) -> tuple[str, dict[str, Any]]: cohort_id = cohort.pk return ( f"SELECT person_id as id FROM {PERSON_STATIC_COHORT_TABLE} WHERE cohort_id = %({prepend}_cohort_id_{index})s AND team_id = %(team_id)s", @@ -89,7 +89,7 @@ def format_static_cohort_query(cohort: Cohort, index: int, prepend: str) -> Tupl ) -def format_precalculated_cohort_query(cohort: Cohort, index: int, prepend: str = "") -> Tuple[str, Dict[str, Any]]: +def format_precalculated_cohort_query(cohort: Cohort, index: int, prepend: str = "") -> tuple[str, dict[str, Any]]: filter_query = GET_PERSON_ID_BY_PRECALCULATED_COHORT_ID.format(index=index, prepend=prepend) return ( filter_query, @@ -121,7 +121,7 @@ def get_entity_query( team_id: int, group_idx: Union[int, str], hogql_context: HogQLContext, -) -> Tuple[str, Dict[str, str]]: +) -> tuple[str, dict[str, str]]: if event_id: return f"event = %({f'event_{group_idx}'})s", {f"event_{group_idx}": event_id} elif action_id: @@ -139,9 +139,9 @@ def get_entity_query( def get_date_query( days: Optional[str], start_time: Optional[str], end_time: Optional[str] -) -> Tuple[str, Dict[str, str]]: +) -> tuple[str, dict[str, str]]: date_query: str = "" - date_params: Dict[str, str] = {} + date_params: dict[str, str] = {} if days: date_query, date_params = parse_entity_timestamps_in_days(int(days)) elif start_time or end_time: @@ -150,7 +150,7 @@ def get_date_query( return date_query, date_params -def parse_entity_timestamps_in_days(days: int) -> Tuple[str, Dict[str, str]]: +def parse_entity_timestamps_in_days(days: int) -> tuple[str, dict[str, str]]: curr_time = timezone.now() start_time = curr_time - timedelta(days=days) @@ -163,9 +163,9 @@ def parse_entity_timestamps_in_days(days: int) -> Tuple[str, Dict[str, str]]: ) -def parse_cohort_timestamps(start_time: Optional[str], end_time: Optional[str]) -> Tuple[str, Dict[str, str]]: +def parse_cohort_timestamps(start_time: Optional[str], end_time: Optional[str]) -> tuple[str, dict[str, str]]: clause = "AND " - params: Dict[str, str] = {} + params: dict[str, str] = {} if start_time: clause += "timestamp >= %(date_from)s" @@ -199,7 +199,7 @@ def format_filter_query( hogql_context: HogQLContext, id_column: str = "distinct_id", custom_match_field="person_id", -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: person_query, params = format_cohort_subquery(cohort, index, hogql_context, custom_match_field=custom_match_field) person_id_query = CALCULATE_COHORT_PEOPLE_SQL.format( @@ -215,7 +215,7 @@ def format_cohort_subquery( index: int, hogql_context: HogQLContext, custom_match_field="person_id", -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: is_precalculated = is_precalculated_query(cohort) if is_precalculated: query, params = 
format_precalculated_cohort_query(cohort, index) @@ -259,7 +259,7 @@ def get_person_ids_by_cohort_id( return [str(row[0]) for row in results] -def insert_static_cohort(person_uuids: List[Optional[uuid.UUID]], cohort_id: int, team: Team): +def insert_static_cohort(person_uuids: list[Optional[uuid.UUID]], cohort_id: int, team: Team): persons = ( { "id": str(uuid.uuid4()), @@ -442,17 +442,17 @@ def simplified_cohort_filter_properties(cohort: Cohort, team: Team, is_negated=F return cohort.properties -def _get_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> List[int]: +def _get_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> list[int]: res = sync_execute(GET_COHORTS_BY_PERSON_UUID, {"person_id": uuid, "team_id": team_id}) return [row[0] for row in res] -def _get_static_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> List[int]: +def _get_static_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> list[int]: res = sync_execute(GET_STATIC_COHORTPEOPLE_BY_PERSON_UUID, {"person_id": uuid, "team_id": team_id}) return [row[0] for row in res] -def get_all_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> List[int]: +def get_all_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> list[int]: cohort_ids = _get_cohort_ids_by_person_uuid(uuid, team_id) static_cohort_ids = _get_static_cohort_ids_by_person_uuid(uuid, team_id) return [*cohort_ids, *static_cohort_ids] @@ -461,8 +461,8 @@ def get_all_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> List[int]: def get_dependent_cohorts( cohort: Cohort, using_database: str = "default", - seen_cohorts_cache: Optional[Dict[int, CohortOrEmpty]] = None, -) -> List[Cohort]: + seen_cohorts_cache: Optional[dict[int, CohortOrEmpty]] = None, +) -> list[Cohort]: if seen_cohorts_cache is None: seen_cohorts_cache = {} @@ -508,7 +508,7 @@ def get_dependent_cohorts( return cohorts -def sort_cohorts_topologically(cohort_ids: Set[int], seen_cohorts_cache: Dict[int, CohortOrEmpty]) -> List[int]: +def sort_cohorts_topologically(cohort_ids: set[int], seen_cohorts_cache: dict[int, CohortOrEmpty]) -> list[int]: """ Sorts the given cohorts in an order where cohorts with no dependencies are placed first, followed by cohorts that depend on the preceding ones. It ensures that each cohort in the sorted list @@ -518,7 +518,7 @@ def sort_cohorts_topologically(cohort_ids: Set[int], seen_cohorts_cache: Dict[in if not cohort_ids: return [] - dependency_graph: Dict[int, List[int]] = {} + dependency_graph: dict[int, list[int]] = {} seen = set() # build graph (adjacency list) @@ -553,7 +553,7 @@ def sort_cohorts_topologically(cohort_ids: Set[int], seen_cohorts_cache: Dict[in sorted_arr.append(int(node)) seen.add(node) - sorted_cohort_ids: List[int] = [] + sorted_cohort_ids: list[int] = [] seen = set() for cohort_id in cohort_ids: if cohort_id not in seen: diff --git a/posthog/models/dashboard.py b/posthog/models/dashboard.py index 9be7e0de14e..003201722a5 100644 --- a/posthog/models/dashboard.py +++ b/posthog/models/dashboard.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from django.contrib.postgres.fields import ArrayField from django.db import models @@ -93,7 +93,7 @@ class Dashboard(models.Model): def url(self): return absolute_uri(f"/dashboard/{self.id}") - def get_analytics_metadata(self) -> Dict[str, Any]: + def get_analytics_metadata(self) -> dict[str, Any]: """ Returns serialized information about the object for analytics reporting. 
""" diff --git a/posthog/models/dashboard_tile.py b/posthog/models/dashboard_tile.py index 50af2868abf..9d39028e49b 100644 --- a/posthog/models/dashboard_tile.py +++ b/posthog/models/dashboard_tile.py @@ -1,5 +1,3 @@ -from typing import List - from django.core.exceptions import ValidationError from django.db import models from django.db.models import Q, QuerySet, UniqueConstraint @@ -112,7 +110,7 @@ class DashboardTile(models.Model): if "update_fields" in kwargs: kwargs["update_fields"].append("filters_hash") - super(DashboardTile, self).save(*args, **kwargs) + super().save(*args, **kwargs) def copy_to_dashboard(self, dashboard: Dashboard) -> None: DashboardTile.objects.create( @@ -139,7 +137,7 @@ class DashboardTile(models.Model): ) -def get_tiles_ordered_by_position(dashboard: Dashboard, size: str = "xs") -> List[DashboardTile]: +def get_tiles_ordered_by_position(dashboard: Dashboard, size: str = "xs") -> list[DashboardTile]: tiles = list( dashboard.tiles.select_related("insight", "text") .exclude(insight__deleted=True) diff --git a/posthog/models/element/element.py b/posthog/models/element/element.py index c1091932cd4..4beeb540085 100644 --- a/posthog/models/element/element.py +++ b/posthog/models/element/element.py @@ -1,5 +1,4 @@ import re -from typing import List from django.contrib.postgres.fields import ArrayField from django.db import models @@ -34,7 +33,7 @@ def _escape(input: str) -> str: return input.replace('"', r"\"") -def elements_to_string(elements: List[Element]) -> str: +def elements_to_string(elements: list[Element]) -> str: ret = [] for element in elements: el_string = "" @@ -58,7 +57,7 @@ def elements_to_string(elements: List[Element]) -> str: return ";".join(ret) -def chain_to_elements(chain: str) -> List[Element]: +def chain_to_elements(chain: str) -> list[Element]: elements = [] for idx, el_string in enumerate(re.findall(split_chain_regex, chain)): el_string_split = re.findall(split_class_attributes, el_string)[0] diff --git a/posthog/models/element_group.py b/posthog/models/element_group.py index 3d399f25598..0a6a2545da0 100644 --- a/posthog/models/element_group.py +++ b/posthog/models/element_group.py @@ -1,6 +1,6 @@ import hashlib import json -from typing import Any, Dict, List +from typing import Any from django.db import models, transaction from django.forms.models import model_to_dict @@ -9,8 +9,8 @@ from posthog.models.element import Element from posthog.models.team import Team -def hash_elements(elements: List) -> str: - elements_list: List[Dict] = [] +def hash_elements(elements: list) -> str: + elements_list: list[dict] = [] for element in elements: el_dict = model_to_dict(element) [el_dict.pop(key) for key in ["event", "id", "group"]] diff --git a/posthog/models/entity/entity.py b/posthog/models/entity/entity.py index 91865f9fa50..255edb0db4f 100644 --- a/posthog/models/entity/entity.py +++ b/posthog/models/entity/entity.py @@ -1,6 +1,6 @@ import inspect from collections import Counter -from typing import Any, Dict, Literal, Optional +from typing import Any, Literal, Optional from django.conf import settings from rest_framework.exceptions import ValidationError @@ -67,7 +67,7 @@ class Entity(PropertyMixin): id_field: Optional[str] timestamp_field: Optional[str] - def __init__(self, data: Dict[str, Any]) -> None: + def __init__(self, data: dict[str, Any]) -> None: self.id = data.get("id") if data.get("type") not in [ TREND_FILTER_TYPE_ACTIONS, @@ -102,7 +102,7 @@ class Entity(PropertyMixin): if self.type == TREND_FILTER_TYPE_EVENTS and not self.name: 
self.name = "All events" if self.id is None else str(self.id) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "id": self.id, "type": self.type, @@ -180,10 +180,10 @@ class ExclusionEntity(Entity, FunnelFromToStepsMixin): with extra parameters for exclusion semantics. """ - def __init__(self, data: Dict[str, Any]) -> None: + def __init__(self, data: dict[str, Any]) -> None: super().__init__(data) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: ret = super().to_dict() for _, func in inspect.getmembers(self, inspect.ismethod): diff --git a/posthog/models/entity/util.py b/posthog/models/entity/util.py index 06abcda5d01..ffcd8cda671 100644 --- a/posthog/models/entity/util.py +++ b/posthog/models/entity/util.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Sequence, Set, Tuple +from typing import Any +from collections.abc import Sequence from posthog.constants import TREND_FILTER_TYPE_ACTIONS from posthog.hogql.hogql import HogQLContext @@ -16,17 +17,17 @@ def get_entity_filtering_params( person_properties_mode: PersonPropertiesMode = PersonPropertiesMode.USING_PERSON_PROPERTIES_COLUMN, person_id_joined_alias: str = "person_id", deep_filtering: bool = False, -) -> Tuple[Dict, Dict]: +) -> tuple[dict, dict]: """Return SQL condition for filtering events by allowed entities (events/actions). Events matching _at least one_ entity are included. If no entities are provided, _all_ events are included.""" if not allowed_entities: return {}, {} - params: Dict[str, Any] = {} - entity_clauses: List[str] = [] - action_ids_already_included: Set[int] = set() # Avoid duplicating action conditions - events_already_included: Set[str] = set() # Avoid duplicating event conditions + params: dict[str, Any] = {} + entity_clauses: list[str] = [] + action_ids_already_included: set[int] = set() # Avoid duplicating action conditions + events_already_included: set[str] = set() # Avoid duplicating event conditions for entity in allowed_entities: if entity.type == TREND_FILTER_TYPE_ACTIONS: if entity.id in action_ids_already_included or entity.id is None: diff --git a/posthog/models/event/event.py b/posthog/models/event/event.py index 59b2f3c0a03..184fffb18af 100644 --- a/posthog/models/event/event.py +++ b/posthog/models/event/event.py @@ -2,7 +2,7 @@ import copy import datetime import re from collections import defaultdict -from typing import Dict, List, Optional, Union +from typing import Optional, Union from dateutil.relativedelta import relativedelta from django.db import models @@ -13,10 +13,10 @@ from posthog.models.team import Team SELECTOR_ATTRIBUTE_REGEX = r"([a-zA-Z]*)\[(.*)=[\'|\"](.*)[\'|\"]\]" -LAST_UPDATED_TEAM_ACTION: Dict[int, datetime.datetime] = {} -TEAM_EVENT_ACTION_QUERY_CACHE: Dict[int, Dict[str, tuple]] = defaultdict(dict) +LAST_UPDATED_TEAM_ACTION: dict[int, datetime.datetime] = {} +TEAM_EVENT_ACTION_QUERY_CACHE: dict[int, dict[str, tuple]] = defaultdict(dict) # TEAM_EVENT_ACTION_QUERY_CACHE looks like team_id -> event ex('$pageview') -> query -TEAM_ACTION_QUERY_CACHE: Dict[int, str] = {} +TEAM_ACTION_QUERY_CACHE: dict[int, str] = {} DEFAULT_EARLIEST_TIME_DELTA = relativedelta(weeks=1) @@ -26,8 +26,8 @@ class SelectorPart: def __init__(self, tag: str, direct_descendant: bool, escape_slashes: bool): self.direct_descendant = direct_descendant - self.data: Dict[str, Union[str, List]] = {} - self.ch_attributes: Dict[str, Union[str, List]] = {} # attributes for CH + self.data: dict[str, Union[str, list]] = {} + self.ch_attributes: 
dict[str, Union[str, list]] = {} # attributes for CH result = re.search(SELECTOR_ATTRIBUTE_REGEX, tag) if result and "[id=" in tag: @@ -58,9 +58,9 @@ class SelectorPart: self.data["tag_name"] = tag @property - def extra_query(self) -> Dict[str, List[Union[str, List[str]]]]: - where: List[Union[str, List[str]]] = [] - params: List[Union[str, List[str]]] = [] + def extra_query(self) -> dict[str, list[Union[str, list[str]]]]: + where: list[Union[str, list[str]]] = [] + params: list[Union[str, list[str]]] = [] for key, value in self.data.items(): if "attr__" in key: where.append(f"(attributes ->> 'attr__{key.split('attr__')[1]}') = %s") @@ -78,7 +78,7 @@ class SelectorPart: class Selector: - parts: List[SelectorPart] = [] + parts: list[SelectorPart] = [] def __init__(self, selector: str, escape_slashes=True): self.parts = [] @@ -98,7 +98,7 @@ class Selector: def _split(self, selector): in_attribute_selector = False in_quotes: Optional[str] = None - part: List[str] = [] + part: list[str] = [] for char in selector: if char == "[" and in_quotes is None: in_attribute_selector = True diff --git a/posthog/models/event/query_event_list.py b/posthog/models/event/query_event_list.py index ded739c9a81..1ecdbee021a 100644 --- a/posthog/models/event/query_event_list.py +++ b/posthog/models/event/query_event_list.py @@ -1,5 +1,5 @@ from datetime import timedelta, datetime, time -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from zoneinfo import ZoneInfo from dateutil.parser import isoparse @@ -29,10 +29,10 @@ def parse_timestamp(timestamp: str, tzinfo: ZoneInfo) -> datetime: def parse_request_params( - conditions: Dict[str, Union[None, str, List[str]]], team: Team, tzinfo: ZoneInfo -) -> Tuple[str, Dict]: + conditions: dict[str, Union[None, str, list[str]]], team: Team, tzinfo: ZoneInfo +) -> tuple[str, dict]: result = "" - params: Dict[str, Union[str, List[str]]] = {} + params: dict[str, Union[str, list[str]]] = {} for k, v in conditions.items(): if not isinstance(v, str): continue @@ -58,13 +58,13 @@ def parse_request_params( def query_events_list( filter: Filter, team: Team, - request_get_query_dict: Dict, - order_by: List[str], + request_get_query_dict: dict, + order_by: list[str], action_id: Optional[str], unbounded_date_from: bool = False, limit: int = DEFAULT_RETURNED_ROWS, offset: int = 0, -) -> List: +) -> list: # Note: This code is inefficient and problematic, see https://github.com/PostHog/posthog/issues/13485 for details. # To isolate its impact from rest of the queries its queries are run on different nodes as part of "offline" workloads. 
hogql_context = HogQLContext(within_non_hogql_query=True, team_id=team.pk, enable_select_queries=True) diff --git a/posthog/models/event/util.py b/posthog/models/event/util.py index c5509489801..065d47da331 100644 --- a/posthog/models/event/util.py +++ b/posthog/models/event/util.py @@ -1,7 +1,7 @@ import datetime as dt import json import uuid -from typing import Any, Dict, List, Literal, Optional, Set, Union +from typing import Any, Literal, Optional, Union from zoneinfo import ZoneInfo from dateutil.parser import isoparse @@ -31,16 +31,16 @@ def create_event( team: Team, distinct_id: str, timestamp: Optional[Union[timezone.datetime, str]] = None, - properties: Optional[Dict] = None, - elements: Optional[List[Element]] = None, + properties: Optional[dict] = None, + elements: Optional[list[Element]] = None, person_id: Optional[uuid.UUID] = None, - person_properties: Optional[Dict] = None, + person_properties: Optional[dict] = None, person_created_at: Optional[Union[timezone.datetime, str]] = None, - group0_properties: Optional[Dict] = None, - group1_properties: Optional[Dict] = None, - group2_properties: Optional[Dict] = None, - group3_properties: Optional[Dict] = None, - group4_properties: Optional[Dict] = None, + group0_properties: Optional[dict] = None, + group1_properties: Optional[dict] = None, + group2_properties: Optional[dict] = None, + group3_properties: Optional[dict] = None, + group4_properties: Optional[dict] = None, group0_created_at: Optional[Union[timezone.datetime, str]] = None, group1_created_at: Optional[Union[timezone.datetime, str]] = None, group2_created_at: Optional[Union[timezone.datetime, str]] = None, @@ -105,8 +105,8 @@ def format_clickhouse_timestamp( def bulk_create_events( - events: List[Dict[str, Any]], - person_mapping: Optional[Dict[str, Person]] = None, + events: list[dict[str, Any]], + person_mapping: Optional[dict[str, Person]] = None, ) -> None: """ TEST ONLY @@ -121,7 +121,7 @@ def bulk_create_events( if not TEST: raise Exception("This function is only meant for setting up tests") inserts = [] - params: Dict[str, Any] = {} + params: dict[str, Any] = {} for index, event in enumerate(events): datetime64_default_timestamp = timezone.now().astimezone(ZoneInfo("UTC")).strftime("%Y-%m-%d %H:%M:%S") timestamp = event.get("timestamp") or dt.datetime.now() @@ -287,7 +287,7 @@ class ElementSerializer(serializers.ModelSerializer): ] -def parse_properties(properties: str, allow_list: Optional[Set[str]] = None) -> Dict: +def parse_properties(properties: str, allow_list: Optional[set[str]] = None) -> dict: # parse_constants gets called for any NaN, Infinity etc values # we just want those to be returned as None if allow_list is None: @@ -349,7 +349,7 @@ class ClickhouseEventSerializer(serializers.Serializer): return event["elements_chain"] -def get_agg_event_count_for_teams(team_ids: List[Union[str, int]]) -> int: +def get_agg_event_count_for_teams(team_ids: list[Union[str, int]]) -> int: result = sync_execute( """ SELECT count(1) as count @@ -362,7 +362,7 @@ def get_agg_event_count_for_teams(team_ids: List[Union[str, int]]) -> int: def get_agg_events_with_groups_count_for_teams_and_period( - team_ids: List[Union[str, int]], begin: timezone.datetime, end: timezone.datetime + team_ids: list[Union[str, int]], begin: timezone.datetime, end: timezone.datetime ) -> int: result = sync_execute( """ diff --git a/posthog/models/exported_asset.py b/posthog/models/exported_asset.py index ceebb2bc3db..d07009be45b 100644 --- a/posthog/models/exported_asset.py +++ 
b/posthog/models/exported_asset.py @@ -1,6 +1,6 @@ import secrets from datetime import timedelta -from typing import List, Optional +from typing import Optional import structlog from django.conf import settings @@ -178,7 +178,7 @@ def save_content_to_exported_asset(exported_asset: ExportedAsset, content: bytes def save_content_to_object_storage(exported_asset: ExportedAsset, content: bytes) -> None: - path_parts: List[str] = [ + path_parts: list[str] = [ settings.OBJECT_STORAGE_EXPORTS_FOLDER, exported_asset.export_format.split("/")[1], f"team-{exported_asset.team.id}", diff --git a/posthog/models/feature_flag/feature_flag.py b/posthog/models/feature_flag/feature_flag.py index 67432e0b643..0a46a44d53b 100644 --- a/posthog/models/feature_flag/feature_flag.py +++ b/posthog/models/feature_flag/feature_flag.py @@ -1,7 +1,7 @@ import json from django.http import HttpRequest import structlog -from typing import Dict, List, Optional, cast +from typing import Optional, cast from django.core.cache import cache from django.db import models @@ -59,7 +59,7 @@ class FeatureFlag(models.Model): # whether a feature is sending us rich analytics, like views & interactions. has_enriched_analytics: models.BooleanField = models.BooleanField(default=False, null=True, blank=True) - def get_analytics_metadata(self) -> Dict: + def get_analytics_metadata(self) -> dict: filter_count = sum(len(condition.get("properties", [])) for condition in self.conditions) variants_count = len(self.variants) payload_count = len(self._payloads) @@ -135,7 +135,7 @@ class FeatureFlag(models.Model): def transform_cohort_filters_for_easy_evaluation( self, using_database: str = "default", - seen_cohorts_cache: Optional[Dict[int, CohortOrEmpty]] = None, + seen_cohorts_cache: Optional[dict[int, CohortOrEmpty]] = None, ): """ Expands cohort filters into person property filters when possible. 
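The sort_cohorts_topologically helper retyped above (now set[int] and dict[int, CohortOrEmpty]) promises an order in which cohorts with no dependencies come first and every cohort follows the cohorts it depends on. As a reference for that contract only, a minimal self-contained sketch of such an ordering via Kahn's algorithm, with plain ints standing in for cohort IDs and a hypothetical depends_on mapping in place of the real cohort property parsing:

    from collections import defaultdict, deque

    def sort_topologically(ids: set[int], depends_on: dict[int, list[int]]) -> list[int]:
        # Count unresolved dependencies per id and invert the edges.
        indegree = {i: 0 for i in ids}
        dependents: dict[int, list[int]] = defaultdict(list)
        for i in ids:
            for dep in depends_on.get(i, []):
                if dep in indegree:
                    indegree[i] += 1
                    dependents[dep].append(i)
        # Repeatedly emit ids whose dependencies have all been emitted already.
        queue = deque(i for i, degree in indegree.items() if degree == 0)
        ordered: list[int] = []
        while queue:
            node = queue.popleft()
            ordered.append(node)
            for nxt in dependents[node]:
                indegree[nxt] -= 1
                if indegree[nxt] == 0:
                    queue.append(nxt)
        return ordered  # ids on a dependency cycle are simply dropped in this sketch

    # sort_topologically({1, 2, 3}, {3: [2], 2: [1]}) == [1, 2, 3]
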
@@ -243,7 +243,7 @@ class FeatureFlag(models.Model): if target_properties.type == PropertyOperatorType.AND: return self.conditions - for prop_group in cast(List[PropertyGroup], target_properties.values): + for prop_group in cast(list[PropertyGroup], target_properties.values): if ( len(prop_group.values) == 0 or not isinstance(prop_group.values[0], Property) @@ -264,9 +264,9 @@ class FeatureFlag(models.Model): def get_cohort_ids( self, using_database: str = "default", - seen_cohorts_cache: Optional[Dict[int, CohortOrEmpty]] = None, + seen_cohorts_cache: Optional[dict[int, CohortOrEmpty]] = None, sort_by_topological_order=False, - ) -> List[int]: + ) -> list[int]: from posthog.models.cohort.util import get_dependent_cohorts, sort_cohorts_topologically if seen_cohorts_cache is None: @@ -398,9 +398,9 @@ class FeatureFlagOverride(models.Model): def set_feature_flags_for_team_in_cache( team_id: int, - feature_flags: Optional[List[FeatureFlag]] = None, + feature_flags: Optional[list[FeatureFlag]] = None, using_database: str = "default", -) -> List[FeatureFlag]: +) -> list[FeatureFlag]: from posthog.api.feature_flag import MinimalFeatureFlagSerializer if feature_flags is not None: @@ -422,7 +422,7 @@ def set_feature_flags_for_team_in_cache( return all_feature_flags -def get_feature_flags_for_team_in_cache(team_id: int) -> Optional[List[FeatureFlag]]: +def get_feature_flags_for_team_in_cache(team_id: int) -> Optional[list[FeatureFlag]]: try: flag_data = cache.get(f"team_feature_flags_{team_id}") except Exception: diff --git a/posthog/models/feature_flag/flag_analytics.py b/posthog/models/feature_flag/flag_analytics.py index d5f27d804ac..f62ed1934ec 100644 --- a/posthog/models/feature_flag/flag_analytics.py +++ b/posthog/models/feature_flag/flag_analytics.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING from posthog.constants import FlagRequestType from posthog.helpers.dashboard_templates import ( add_enriched_insights_to_feature_flag_dashboard, @@ -45,7 +45,7 @@ def increment_request_count( capture_exception(error) -def _extract_total_count_for_key_from_redis_hash(client: redis.Redis, key: str) -> Tuple[int, int, int]: +def _extract_total_count_for_key_from_redis_hash(client: redis.Redis, key: str) -> tuple[int, int, int]: total_count = 0 existing_values = client.hgetall(key) time_buckets = existing_values.keys() diff --git a/posthog/models/feature_flag/flag_matching.py b/posthog/models/feature_flag/flag_matching.py index 134af65dfda..0b4a6befebc 100644 --- a/posthog/models/feature_flag/flag_matching.py +++ b/posthog/models/feature_flag/flag_matching.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum import time import structlog -from typing import Dict, List, Literal, Optional, Tuple, Union, cast +from typing import Literal, Optional, Union, cast from prometheus_client import Counter from django.conf import settings @@ -110,7 +110,7 @@ class FlagsMatcherCache: self.failed_to_fetch_flags = False @cached_property - def group_types_to_indexes(self) -> Dict[GroupTypeName, GroupTypeIndex]: + def group_types_to_indexes(self) -> dict[GroupTypeName, GroupTypeIndex]: if self.failed_to_fetch_flags: raise DatabaseError("Failed to fetch group type mapping previously, not trying again.") try: @@ -124,7 +124,7 @@ class FlagsMatcherCache: raise err @cached_property - def group_type_index_to_name(self) -> Dict[GroupTypeIndex, GroupTypeName]: + def group_type_index_to_name(self) -> dict[GroupTypeIndex, GroupTypeName]: return {value: key for key, 
value in self.group_types_to_indexes.items()} @@ -133,15 +133,15 @@ class FeatureFlagMatcher: def __init__( self, - feature_flags: List[FeatureFlag], + feature_flags: list[FeatureFlag], distinct_id: str, - groups: Optional[Dict[GroupTypeName, str]] = None, + groups: Optional[dict[GroupTypeName, str]] = None, cache: Optional[FlagsMatcherCache] = None, - hash_key_overrides: Optional[Dict[str, str]] = None, - property_value_overrides: Optional[Dict[str, Union[str, int]]] = None, - group_property_value_overrides: Optional[Dict[str, Dict[str, Union[str, int]]]] = None, + hash_key_overrides: Optional[dict[str, str]] = None, + property_value_overrides: Optional[dict[str, Union[str, int]]] = None, + group_property_value_overrides: Optional[dict[str, dict[str, Union[str, int]]]] = None, skip_database_flags: bool = False, - cohorts_cache: Optional[Dict[int, CohortOrEmpty]] = None, + cohorts_cache: Optional[dict[int, CohortOrEmpty]] = None, ): if group_property_value_overrides is None: group_property_value_overrides = {} @@ -231,7 +231,7 @@ class FeatureFlagMatcher: payload=None, ) - def get_matches(self) -> Tuple[Dict[str, Union[str, bool]], Dict[str, dict], Dict[str, object], bool]: + def get_matches(self) -> tuple[dict[str, Union[str, bool]], dict[str, dict], dict[str, object], bool]: flag_values = {} flag_evaluation_reasons = {} faced_error_computing_flags = False @@ -287,7 +287,7 @@ class FeatureFlagMatcher: else: return None - def is_super_condition_match(self, feature_flag: FeatureFlag) -> Tuple[bool, bool, FeatureFlagMatchReason]: + def is_super_condition_match(self, feature_flag: FeatureFlag) -> tuple[bool, bool, FeatureFlagMatchReason]: # TODO: Right now super conditions with property overrides bork when the database is down, # because we're still going to the database in the line below. Ideally, we should not go to the database. # Don't skip test: test_super_condition_with_override_properties_doesnt_make_database_requests when this is fixed. @@ -320,8 +320,8 @@ class FeatureFlagMatcher: return False, False, FeatureFlagMatchReason.NO_CONDITION_MATCH def is_condition_match( - self, feature_flag: FeatureFlag, condition: Dict, condition_index: int - ) -> Tuple[bool, FeatureFlagMatchReason]: + self, feature_flag: FeatureFlag, condition: dict, condition_index: int + ) -> tuple[bool, FeatureFlagMatchReason]: rollout_percentage = condition.get("rollout_percentage") if len(condition.get("properties", [])) > 0: properties = Filter(data=condition).property_groups.flat @@ -405,12 +405,12 @@ class FeatureFlagMatcher: return lookup_table @cached_property - def query_conditions(self) -> Dict[str, bool]: + def query_conditions(self) -> dict[str, bool]: try: # Some extra wiggle room here for timeouts because this depends on the number of flags as well, # and not just the database query. 
with execute_with_timeout(FLAG_MATCHING_QUERY_TIMEOUT_MS * 2, DATABASE_FOR_FLAG_MATCHING): - all_conditions: Dict = {} + all_conditions: dict = {} team_id = self.feature_flags[0].team_id person_query: QuerySet = Person.objects.using(DATABASE_FOR_FLAG_MATCHING).filter( team_id=team_id, @@ -418,7 +418,7 @@ class FeatureFlagMatcher: persondistinctid__team_id=team_id, ) basic_group_query: QuerySet = Group.objects.using(DATABASE_FOR_FLAG_MATCHING).filter(team_id=team_id) - group_query_per_group_type_mapping: Dict[GroupTypeIndex, Tuple[QuerySet, List[str]]] = {} + group_query_per_group_type_mapping: dict[GroupTypeIndex, tuple[QuerySet, list[str]]] = {} # :TRICKY: Create a queryset for each group type that uniquely identifies a group, based on the groups passed in. # If no groups for a group type are passed in, we can skip querying for that group type, # since the result will always be `false`. @@ -431,7 +431,7 @@ class FeatureFlagMatcher: [], ) - person_fields: List[str] = [] + person_fields: list[str] = [] for existence_condition_key in self.has_pure_is_not_conditions: if existence_condition_key == PERSON_KEY: @@ -637,7 +637,7 @@ class FeatureFlagMatcher: def can_compute_locally( self, - properties: List[Property], + properties: list[Property], group_type_index: Optional[GroupTypeIndex] = None, ) -> bool: target_properties = self.property_value_overrides @@ -682,10 +682,10 @@ class FeatureFlagMatcher: def get_feature_flag_hash_key_overrides( team_id: int, - distinct_ids: List[str], + distinct_ids: list[str], using_database: str = "default", - person_id_to_distinct_id_mapping: Optional[Dict[int, str]] = None, -) -> Dict[str, str]: + person_id_to_distinct_id_mapping: Optional[dict[int, str]] = None, +) -> dict[str, str]: feature_flag_to_key_overrides = {} # Priority to the first distinctID's values, to keep this function deterministic @@ -716,15 +716,15 @@ def get_feature_flag_hash_key_overrides( # Return a Dict with all flags and their values def _get_all_feature_flags( - feature_flags: List[FeatureFlag], + feature_flags: list[FeatureFlag], team_id: int, distinct_id: str, - person_overrides: Optional[Dict[str, str]] = None, - groups: Optional[Dict[GroupTypeName, str]] = None, - property_value_overrides: Optional[Dict[str, Union[str, int]]] = None, - group_property_value_overrides: Optional[Dict[str, Dict[str, Union[str, int]]]] = None, + person_overrides: Optional[dict[str, str]] = None, + groups: Optional[dict[GroupTypeName, str]] = None, + property_value_overrides: Optional[dict[str, Union[str, int]]] = None, + group_property_value_overrides: Optional[dict[str, dict[str, Union[str, int]]]] = None, skip_database_flags: bool = False, -) -> Tuple[Dict[str, Union[str, bool]], Dict[str, dict], Dict[str, object], bool]: +) -> tuple[dict[str, Union[str, bool]], dict[str, dict], dict[str, object], bool]: if group_property_value_overrides is None: group_property_value_overrides = {} if property_value_overrides is None: @@ -752,11 +752,11 @@ def _get_all_feature_flags( def get_all_feature_flags( team_id: int, distinct_id: str, - groups: Optional[Dict[GroupTypeName, str]] = None, + groups: Optional[dict[GroupTypeName, str]] = None, hash_key_override: Optional[str] = None, - property_value_overrides: Optional[Dict[str, Union[str, int]]] = None, - group_property_value_overrides: Optional[Dict[str, Dict[str, Union[str, int]]]] = None, -) -> Tuple[Dict[str, Union[str, bool]], Dict[str, dict], Dict[str, object], bool]: + property_value_overrides: Optional[dict[str, Union[str, int]]] = None, + 
group_property_value_overrides: Optional[dict[str, dict[str, Union[str, int]]]] = None, +) -> tuple[dict[str, Union[str, bool]], dict[str, dict], dict[str, object], bool]: if group_property_value_overrides is None: group_property_value_overrides = {} if property_value_overrides is None: @@ -907,7 +907,7 @@ def get_all_feature_flags( ) -def set_feature_flag_hash_key_overrides(team_id: int, distinct_ids: List[str], hash_key_override: str) -> bool: +def set_feature_flag_hash_key_overrides(team_id: int, distinct_ids: list[str], hash_key_override: str) -> bool: # As a product decision, the first override wins, i.e consistency matters for the first walkthrough. # Thus, we don't need to do upserts here. @@ -1004,7 +1004,7 @@ def parse_exception_for_error_message(err: Exception): return reason -def key_and_field_for_property(property: Property) -> Tuple[str, str]: +def key_and_field_for_property(property: Property) -> tuple[str, str]: column = "group_properties" if property.type == "group" else "properties" key = property.key sanitized_key = sanitize_property_key(key) @@ -1016,8 +1016,8 @@ def key_and_field_for_property(property: Property) -> Tuple[str, str]: def get_all_properties_with_math_operators( - properties: List[Property], cohorts_cache: Dict[int, CohortOrEmpty], team_id: int -) -> List[Tuple[str, str]]: + properties: list[Property], cohorts_cache: dict[int, CohortOrEmpty], team_id: int +) -> list[tuple[str, str]]: all_keys_and_fields = [] for prop in properties: diff --git a/posthog/models/filters/base_filter.py b/posthog/models/filters/base_filter.py index ca2ef9e4c57..f4d46c9acaf 100644 --- a/posthog/models/filters/base_filter.py +++ b/posthog/models/filters/base_filter.py @@ -1,6 +1,6 @@ import inspect import json -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Optional from rest_framework import request @@ -17,14 +17,14 @@ if TYPE_CHECKING: class BaseFilter(BaseParamMixin): - _data: Dict + _data: dict team: Optional["Team"] - kwargs: Dict + kwargs: dict hogql_context: HogQLContext def __init__( self, - data: Optional[Dict[str, Any]] = None, + data: Optional[dict[str, Any]] = None, request: Optional[request.Request] = None, *, team: Optional["Team"] = None, @@ -69,7 +69,7 @@ class BaseFilter(BaseParamMixin): simplified_filter = self.simplify(self.team) self._data = simplified_filter._data - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: ret = {} for _, func in inspect.getmembers(self, inspect.ismethod): @@ -78,20 +78,20 @@ class BaseFilter(BaseParamMixin): return ret - def to_params(self) -> Dict[str, str]: + def to_params(self) -> dict[str, str]: return encode_get_request_params(data=self.to_dict()) def toJSON(self): return json.dumps(self.to_dict(), default=lambda o: o.__dict__, sort_keys=True, indent=4) - def shallow_clone(self, overrides: Dict[str, Any]): + def shallow_clone(self, overrides: dict[str, Any]): "Clone the filter's data while sharing the HogQL context" return type(self)( data={**self._data, **overrides}, **{**self.kwargs, "team": self.team, "hogql_context": self.hogql_context}, ) - def query_tags(self) -> Dict[str, Any]: + def query_tags(self) -> dict[str, Any]: ret = {} for _, func in inspect.getmembers(self, inspect.ismethod): diff --git a/posthog/models/filters/lifecycle_filter.py b/posthog/models/filters/lifecycle_filter.py index 576cf499f30..34775ac98f8 100644 --- a/posthog/models/filters/lifecycle_filter.py +++ b/posthog/models/filters/lifecycle_filter.py @@ -1,5 +1,5 @@ import datetime 
-from typing import Any, Dict, Optional +from typing import Any, Optional from posthog.models import Filter from posthog.utils import relative_date_parse from rest_framework.request import Request @@ -12,7 +12,7 @@ class LifecycleFilter(Filter): def __init__( self, - data: Optional[Dict[str, Any]] = None, + data: Optional[dict[str, Any]] = None, request: Optional[Request] = None, **kwargs, ) -> None: diff --git a/posthog/models/filters/mixins/base.py b/posthog/models/filters/mixins/base.py index b0c79566f72..a4640f0aae1 100644 --- a/posthog/models/filters/mixins/base.py +++ b/posthog/models/filters/mixins/base.py @@ -1,4 +1,4 @@ -from typing import Dict, Literal +from typing import Literal BreakdownType = Literal["event", "person", "cohort", "group", "session", "hogql"] IntervalType = Literal["hour", "day", "week", "month"] @@ -6,4 +6,4 @@ FunnelWindowIntervalType = Literal["second", "minute", "hour", "day", "week", "m class BaseParamMixin: - _data: Dict + _data: dict diff --git a/posthog/models/filters/mixins/common.py b/posthog/models/filters/mixins/common.py index 8ab2c1ac7fc..65be0351403 100644 --- a/posthog/models/filters/mixins/common.py +++ b/posthog/models/filters/mixins/common.py @@ -2,7 +2,7 @@ import datetime import json import re from math import ceil -from typing import Any, Dict, List, Literal, Optional, Union, cast +from typing import Any, Literal, Optional, Union, cast from zoneinfo import ZoneInfo from dateutil.relativedelta import relativedelta @@ -142,7 +142,7 @@ class FormulaMixin(BaseParamMixin): class BreakdownMixin(BaseParamMixin): @cached_property - def breakdown(self) -> Optional[Union[str, List[Union[str, int]]]]: + def breakdown(self) -> Optional[Union[str, list[Union[str, int]]]]: breakdown = self._data.get(BREAKDOWN) if not isinstance(breakdown, str): @@ -171,11 +171,11 @@ class BreakdownMixin(BaseParamMixin): return int(attribution_value) if attribution_value is not None else None @cached_property - def breakdowns(self) -> Optional[List[Dict[str, Any]]]: + def breakdowns(self) -> Optional[list[dict[str, Any]]]: breakdowns = self._data.get(BREAKDOWNS) try: - if isinstance(breakdowns, List): + if isinstance(breakdowns, list): return breakdowns elif isinstance(breakdowns, str): return json.loads(breakdowns) @@ -226,7 +226,7 @@ class BreakdownMixin(BaseParamMixin): @include_dict def breakdown_to_dict(self): - result: Dict = {} + result: dict = {} if self.breakdown: result[BREAKDOWN] = self.breakdown if self.breakdowns: @@ -346,8 +346,8 @@ class CompareMixin(BaseParamMixin): class DateMixin(BaseParamMixin): - date_from_delta_mapping: Optional[Dict[str, int]] - date_to_delta_mapping: Optional[Dict[str, int]] + date_from_delta_mapping: Optional[dict[str, int]] + date_to_delta_mapping: Optional[dict[str, int]] @cached_property def _date_from(self) -> Optional[Union[str, datetime.datetime]]: @@ -417,7 +417,7 @@ class DateMixin(BaseParamMixin): return process_bool(self._data.get(EXPLICIT_DATE)) @include_dict - def date_to_dict(self) -> Dict: + def date_to_dict(self) -> dict: result_dict = {} if self._date_from: result_dict.update( @@ -455,8 +455,8 @@ class DateMixin(BaseParamMixin): class EntitiesMixin(BaseParamMixin): @cached_property - def entities(self) -> List[Entity]: - processed_entities: List[Entity] = [] + def entities(self) -> list[Entity]: + processed_entities: list[Entity] = [] if self._data.get(ACTIONS): actions = self._data.get(ACTIONS, []) if isinstance(actions, str): @@ -487,20 +487,20 @@ class EntitiesMixin(BaseParamMixin): return {"number_of_entities": 
len(self.entities)} @cached_property - def actions(self) -> List[Entity]: + def actions(self) -> list[Entity]: return [entity for entity in self.entities if entity.type == TREND_FILTER_TYPE_ACTIONS] @cached_property - def events(self) -> List[Entity]: + def events(self) -> list[Entity]: return [entity for entity in self.entities if entity.type == TREND_FILTER_TYPE_EVENTS] @cached_property - def data_warehouse_entities(self) -> List[Entity]: + def data_warehouse_entities(self) -> list[Entity]: return [entity for entity in self.entities if entity.type == TREND_FILTER_TYPE_DATA_WAREHOUSE] @cached_property - def exclusions(self) -> List[ExclusionEntity]: - _exclusions: List[ExclusionEntity] = [] + def exclusions(self) -> list[ExclusionEntity]: + _exclusions: list[ExclusionEntity] = [] if self._data.get(EXCLUSIONS): exclusion_list = self._data.get(EXCLUSIONS, []) if isinstance(exclusion_list, str): diff --git a/posthog/models/filters/mixins/funnel.py b/posthog/models/filters/mixins/funnel.py index 91312a50304..3baf5f15b50 100644 --- a/posthog/models/filters/mixins/funnel.py +++ b/posthog/models/filters/mixins/funnel.py @@ -1,6 +1,6 @@ import datetime import json -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union from posthog.models.property import Property @@ -111,7 +111,7 @@ class FunnelWindowMixin(BaseParamMixin): @include_dict def funnel_window_to_dict(self): - dict_part: Dict = {} + dict_part: dict = {} if self.funnel_window_interval is not None: dict_part[FUNNEL_WINDOW_INTERVAL] = self.funnel_window_interval if self.funnel_window_interval_unit is not None: @@ -154,7 +154,7 @@ class FunnelPersonsStepMixin(BaseParamMixin): return int(_step_as_string) @cached_property - def funnel_custom_steps(self) -> List[int]: + def funnel_custom_steps(self) -> list[int]: """ Custom step numbers to get persons for. This overrides FunnelPersonsStepMixin::funnel_step """ @@ -176,7 +176,7 @@ class FunnelPersonsStepMixin(BaseParamMixin): class FunnelPersonsStepBreakdownMixin(BaseParamMixin): @cached_property - def funnel_step_breakdown(self) -> Optional[Union[List[str], int, str]]: + def funnel_step_breakdown(self) -> Optional[Union[list[str], int, str]]: """ The breakdown value for which to get persons for. 
@@ -241,7 +241,7 @@ class FunnelTypeMixin(BaseParamMixin): @include_dict def funnel_type_to_dict(self): - result: Dict[str, str] = {} + result: dict[str, str] = {} if self.funnel_order_type: result[FUNNEL_ORDER_TYPE] = self.funnel_order_type if self.funnel_viz_type: @@ -277,7 +277,7 @@ class FunnelTrendsPersonsMixin(BaseParamMixin): @include_dict def funnel_trends_persons_to_dict(self): - result_dict: Dict = {} + result_dict: dict = {} if self.entrance_period_start: result_dict[ENTRANCE_PERIOD_START] = self.entrance_period_start.isoformat() if self.drop_off is not None: @@ -298,7 +298,7 @@ class FunnelCorrelationMixin(BaseParamMixin): return None @cached_property - def correlation_property_names(self) -> List[str]: + def correlation_property_names(self) -> list[str]: # Person Property names for which to run Person Properties correlation property_names = self._data.get(FUNNEL_CORRELATION_NAMES, []) if isinstance(property_names, str): @@ -306,7 +306,7 @@ class FunnelCorrelationMixin(BaseParamMixin): return property_names @cached_property - def correlation_property_exclude_names(self) -> List[str]: + def correlation_property_exclude_names(self) -> list[str]: # Person Property names to exclude from Person Properties correlation property_names = self._data.get(FUNNEL_CORRELATION_EXCLUDE_NAMES, []) if isinstance(property_names, str): @@ -314,7 +314,7 @@ class FunnelCorrelationMixin(BaseParamMixin): return property_names @cached_property - def correlation_event_names(self) -> List[str]: + def correlation_event_names(self) -> list[str]: # Event names for which to run EventWithProperties correlation event_names = self._data.get(FUNNEL_CORRELATION_EVENT_NAMES, []) if isinstance(event_names, str): @@ -322,7 +322,7 @@ class FunnelCorrelationMixin(BaseParamMixin): return event_names @cached_property - def correlation_event_exclude_names(self) -> List[str]: + def correlation_event_exclude_names(self) -> list[str]: # Exclude event names from Event correlation property_names = self._data.get(FUNNEL_CORRELATION_EXCLUDE_EVENT_NAMES, []) if isinstance(property_names, str): @@ -330,7 +330,7 @@ class FunnelCorrelationMixin(BaseParamMixin): return property_names @cached_property - def correlation_event_exclude_property_names(self) -> List[str]: + def correlation_event_exclude_property_names(self) -> list[str]: # Event Property names to exclude from EventWithProperties correlation property_names = self._data.get(FUNNEL_CORRELATION_EVENT_EXCLUDE_PROPERTY_NAMES, []) if isinstance(property_names, str): @@ -339,7 +339,7 @@ class FunnelCorrelationMixin(BaseParamMixin): @include_dict def funnel_correlation_to_dict(self): - result_dict: Dict = {} + result_dict: dict = {} if self.correlation_type: result_dict[FUNNEL_CORRELATION_TYPE] = self.correlation_type if self.correlation_property_names: @@ -370,7 +370,7 @@ class FunnelCorrelationActorsMixin(BaseParamMixin): return Entity(event) if event else None @cached_property - def correlation_property_values(self) -> Optional[List[Property]]: + def correlation_property_values(self) -> Optional[list[Property]]: # Used for property correlations persons _props = self._data.get(FUNNEL_CORRELATION_PROPERTY_VALUES) @@ -421,7 +421,7 @@ class FunnelCorrelationActorsMixin(BaseParamMixin): @include_dict def funnel_correlation_persons_to_dict(self): - result_dict: Dict = {} + result_dict: dict = {} if self.correlation_person_entity: result_dict[FUNNEL_CORRELATION_PERSON_ENTITY] = self.correlation_person_entity.to_dict() if self.correlation_property_values: diff --git 
a/posthog/models/filters/mixins/paths.py b/posthog/models/filters/mixins/paths.py index 393b1f7140a..8249a7015d1 100644 --- a/posthog/models/filters/mixins/paths.py +++ b/posthog/models/filters/mixins/paths.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Literal, Optional +from typing import Literal, Optional from posthog.constants import ( CUSTOM_EVENT, @@ -84,21 +84,21 @@ class PathsHogQLExpressionMixin(PathTypeMixin): class TargetEventsMixin(BaseParamMixin): @cached_property - def target_events(self) -> List[str]: + def target_events(self) -> list[str]: target_events = self._data.get(PATHS_INCLUDE_EVENT_TYPES, []) if isinstance(target_events, str): return json.loads(target_events) return target_events @cached_property - def custom_events(self) -> List[str]: + def custom_events(self) -> list[str]: custom_events = self._data.get(PATHS_INCLUDE_CUSTOM_EVENTS, []) if isinstance(custom_events, str): return json.loads(custom_events) return custom_events @cached_property - def exclude_events(self) -> List[str]: + def exclude_events(self) -> list[str]: _exclude_events = self._data.get(PATHS_EXCLUDE_EVENTS, []) if isinstance(_exclude_events, str): return json.loads(_exclude_events) @@ -160,7 +160,7 @@ class FunnelPathsMixin(BaseParamMixin): class PathGroupingMixin(BaseParamMixin): @cached_property - def path_groupings(self) -> Optional[List[str]]: + def path_groupings(self) -> Optional[list[str]]: path_groupings = self._data.get(PATH_GROUPINGS, None) if isinstance(path_groupings, str): return json.loads(path_groupings) @@ -193,7 +193,7 @@ class PathReplacementMixin(BaseParamMixin): class LocalPathCleaningFiltersMixin(BaseParamMixin): @cached_property - def local_path_cleaning_filters(self) -> Optional[List[Dict[str, str]]]: + def local_path_cleaning_filters(self) -> Optional[list[dict[str, str]]]: local_path_cleaning_filters = self._data.get(LOCAL_PATH_CLEANING_FILTERS, None) if isinstance(local_path_cleaning_filters, str): return json.loads(local_path_cleaning_filters) diff --git a/posthog/models/filters/mixins/property.py b/posthog/models/filters/mixins/property.py index ff4cb56fee9..2ffc984754b 100644 --- a/posthog/models/filters/mixins/property.py +++ b/posthog/models/filters/mixins/property.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Optional, Union, cast from rest_framework.exceptions import ValidationError @@ -15,7 +15,7 @@ from posthog.models.property import Property, PropertyGroup class PropertyMixin(BaseParamMixin): @cached_property - def old_properties(self) -> List[Property]: + def old_properties(self) -> list[Property]: _props = self._data.get(PROPERTIES) if isinstance(_props, str): @@ -64,7 +64,7 @@ class PropertyMixin(BaseParamMixin): # old properties return PropertyGroup(type=PropertyOperatorType.AND, values=self.old_properties) - def _parse_properties(self, properties: Optional[Any]) -> List[Property]: + def _parse_properties(self, properties: Optional[Any]) -> list[Property]: if isinstance(properties, list): _properties = [] for prop_params in properties: @@ -94,19 +94,19 @@ class PropertyMixin(BaseParamMixin): ) return ret - def _parse_property_group(self, group: Optional[Dict]) -> PropertyGroup: + def _parse_property_group(self, group: Optional[dict]) -> PropertyGroup: if group and "type" in group and "values" in group: return PropertyGroup( PropertyOperatorType(group["type"].upper()), self._parse_property_group_list(group["values"]), ) - return PropertyGroup(PropertyOperatorType.AND, 
cast(List[Property], [])) + return PropertyGroup(PropertyOperatorType.AND, cast(list[Property], [])) - def _parse_property_group_list(self, prop_list: Optional[List]) -> Union[List[Property], List[PropertyGroup]]: + def _parse_property_group_list(self, prop_list: Optional[list]) -> Union[list[Property], list[PropertyGroup]]: if not prop_list: # empty prop list - return cast(List[Property], []) + return cast(list[Property], []) has_property_groups = False has_simple_properties = False diff --git a/posthog/models/filters/mixins/retention.py b/posthog/models/filters/mixins/retention.py index eeec027c4f8..044278f0142 100644 --- a/posthog/models/filters/mixins/retention.py +++ b/posthog/models/filters/mixins/retention.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timedelta -from typing import Literal, Optional, Tuple, Union +from typing import Literal, Optional, Union from dateutil.relativedelta import relativedelta from django.utils import timezone @@ -112,7 +112,7 @@ class RetentionDateDerivedMixin(PeriodMixin, TotalIntervalsMixin, DateMixin, Sel @staticmethod def determine_time_delta( total_intervals: int, period: str - ) -> Tuple[Union[timedelta, relativedelta], Union[timedelta, relativedelta]]: + ) -> tuple[Union[timedelta, relativedelta], Union[timedelta, relativedelta]]: if period == "Hour": return timedelta(hours=total_intervals), timedelta(hours=1) elif period == "Week": diff --git a/posthog/models/filters/mixins/session_recordings.py b/posthog/models/filters/mixins/session_recordings.py index 8779ea92e6b..83d9bb40245 100644 --- a/posthog/models/filters/mixins/session_recordings.py +++ b/posthog/models/filters/mixins/session_recordings.py @@ -1,5 +1,5 @@ import json -from typing import List, Optional, Literal +from typing import Optional, Literal from posthog.constants import PERSON_UUID_FILTER, SESSION_RECORDINGS_FILTER_IDS from posthog.models.filters.mixins.common import BaseParamMixin @@ -19,7 +19,7 @@ class SessionRecordingsMixin(BaseParamMixin): return self._data.get("console_search_query", None) @cached_property - def console_logs_filter(self) -> List[Literal["error", "warn", "info"]]: + def console_logs_filter(self) -> list[Literal["error", "warn", "info"]]: user_value = self._data.get("console_logs", None) or [] if isinstance(user_value, str): user_value = json.loads(user_value) @@ -43,7 +43,7 @@ class SessionRecordingsMixin(BaseParamMixin): return None @cached_property - def session_ids(self) -> Optional[List[str]]: + def session_ids(self) -> Optional[list[str]]: # Can be ['a', 'b'] or "['a', 'b']" or "a,b" session_ids_str = self._data.get(SESSION_RECORDINGS_FILTER_IDS, None) diff --git a/posthog/models/filters/mixins/simplify.py b/posthog/models/filters/mixins/simplify.py index 3b1e0eb426b..72d8d184539 100644 --- a/posthog/models/filters/mixins/simplify.py +++ b/posthog/models/filters/mixins/simplify.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Literal, TypeVar, cast +from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast from posthog.constants import PropertyOperatorType from posthog.models.property import GroupTypeIndex, PropertyGroup @@ -67,9 +67,9 @@ class SimplifyFilterMixin: self, team: "Team", entity_type: Literal["events", "actions", "exclusions"], - entity_params: Dict, + entity_params: dict, **kwargs, - ) -> Dict: + ) -> dict: from posthog.models.entity import Entity, ExclusionEntity EntityClass = ExclusionEntity if entity_type == "exclusions" else Entity @@ -82,7 +82,7 @@ class SimplifyFilterMixin: return 
EntityClass({**entity_params, "properties": properties}).to_dict() - def _simplify_properties(self, team: "Team", properties: List["Property"], **kwargs) -> "PropertyGroup": + def _simplify_properties(self, team: "Team", properties: list["Property"], **kwargs) -> "PropertyGroup": simplified_properties_values = [] for prop in properties: simplified_properties_values.append(self._simplify_property(team, prop, **kwargs)) diff --git a/posthog/models/filters/mixins/stickiness.py b/posthog/models/filters/mixins/stickiness.py index 0dfca1d834c..1b659481b98 100644 --- a/posthog/models/filters/mixins/stickiness.py +++ b/posthog/models/filters/mixins/stickiness.py @@ -1,5 +1,6 @@ from datetime import datetime -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Optional, Union +from collections.abc import Callable from rest_framework.exceptions import ValidationError diff --git a/posthog/models/filters/mixins/utils.py b/posthog/models/filters/mixins/utils.py index a297cdcfa63..5b5fe6d422d 100644 --- a/posthog/models/filters/mixins/utils.py +++ b/posthog/models/filters/mixins/utils.py @@ -1,5 +1,6 @@ from functools import lru_cache -from typing import Callable, Optional, TypeVar, Union +from typing import Optional, TypeVar, Union +from collections.abc import Callable from posthog.utils import str_to_bool diff --git a/posthog/models/filters/path_filter.py b/posthog/models/filters/path_filter.py index 5ef9395d82d..df7a0ca9285 100644 --- a/posthog/models/filters/path_filter.py +++ b/posthog/models/filters/path_filter.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Optional from rest_framework.request import Request @@ -76,7 +76,7 @@ class PathFilter( ): def __init__( self, - data: Optional[Dict[str, Any]] = None, + data: Optional[dict[str, Any]] = None, request: Optional[Request] = None, **kwargs, ) -> None: diff --git a/posthog/models/filters/retention_filter.py b/posthog/models/filters/retention_filter.py index 338d3d87e3e..6f73aeb69d3 100644 --- a/posthog/models/filters/retention_filter.py +++ b/posthog/models/filters/retention_filter.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Union from rest_framework.request import Request @@ -48,7 +48,7 @@ class RetentionFilter( SampleMixin, BaseFilter, ): - def __init__(self, data: Optional[Dict[str, Any]] = None, request: Optional[Request] = None, **kwargs) -> None: + def __init__(self, data: Optional[dict[str, Any]] = None, request: Optional[Request] = None, **kwargs) -> None: if data is None: data = {} if data: @@ -58,7 +58,7 @@ class RetentionFilter( super().__init__(data, request, **kwargs) @cached_property - def breakdown_values(self) -> Optional[Tuple[Union[str, int], ...]]: + def breakdown_values(self) -> Optional[tuple[Union[str, int], ...]]: raw_value = self._data.get("breakdown_values", None) if raw_value is None: return None diff --git a/posthog/models/filters/stickiness_filter.py b/posthog/models/filters/stickiness_filter.py index 4674c4ceeb3..cde6d802092 100644 --- a/posthog/models/filters/stickiness_filter.py +++ b/posthog/models/filters/stickiness_filter.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union +from collections.abc import Callable from django.db.models.functions.datetime import ( TruncDay, @@ -62,7 +63,7 @@ class StickinessFilter( def __init__( self, - data: Optional[Dict[str, Any]] = 
None, + data: Optional[dict[str, Any]] = None, request: Optional[Request] = None, **kwargs, ) -> None: diff --git a/posthog/models/filters/test/test_filter.py b/posthog/models/filters/test/test_filter.py index 63a947bca67..eb99a3ac429 100644 --- a/posthog/models/filters/test/test_filter.py +++ b/posthog/models/filters/test/test_filter.py @@ -1,6 +1,7 @@ import datetime import json -from typing import Any, Callable, Dict, List, Optional, cast +from typing import Any, Optional, cast +from collections.abc import Callable from django.db.models import Q, Func, F, CharField from freezegun import freeze_time @@ -993,8 +994,8 @@ class TestDjangoPropertiesToQ(property_to_Q_test_factory(_filter_persons, _creat def filter_persons_with_property_group( - filter: Filter, team: Team, property_overrides: Optional[Dict[str, Any]] = None -) -> List[str]: + filter: Filter, team: Team, property_overrides: Optional[dict[str, Any]] = None +) -> list[str]: if property_overrides is None: property_overrides = {} flush_persons_and_events() diff --git a/posthog/models/filters/test/test_path_filter.py b/posthog/models/filters/test/test_path_filter.py index df8ffac45aa..3f66e0b9b73 100644 --- a/posthog/models/filters/test/test_path_filter.py +++ b/posthog/models/filters/test/test_path_filter.py @@ -18,7 +18,7 @@ class TestPathFilter(BaseTest): } ) - self.assertEquals( + self.assertEqual( filter.to_dict(), filter.to_dict() | { @@ -51,7 +51,7 @@ class TestPathFilter(BaseTest): } ) - self.assertEquals( + self.assertEqual( filter.to_dict(), filter.to_dict() | { diff --git a/posthog/models/group/util.py b/posthog/models/group/util.py index 427c883a2e9..0b9c0fb9724 100644 --- a/posthog/models/group/util.py +++ b/posthog/models/group/util.py @@ -1,6 +1,6 @@ import datetime import json -from typing import Dict, Optional, Union +from typing import Optional, Union from zoneinfo import ZoneInfo from dateutil.parser import isoparse @@ -17,7 +17,7 @@ def raw_create_group_ch( team_id: int, group_type_index: GroupTypeIndex, group_key: str, - properties: Dict, + properties: dict, created_at: datetime.datetime, timestamp: Optional[datetime.datetime] = None, sync: bool = False, @@ -44,7 +44,7 @@ def create_group( team_id: int, group_type_index: GroupTypeIndex, group_key: str, - properties: Optional[Dict] = None, + properties: Optional[dict] = None, timestamp: Optional[Union[datetime.datetime, str]] = None, sync: bool = False, ) -> Group: diff --git a/posthog/models/instance_setting.py b/posthog/models/instance_setting.py index 749975e5d5e..0ad0ca5bde0 100644 --- a/posthog/models/instance_setting.py +++ b/posthog/models/instance_setting.py @@ -1,6 +1,6 @@ import json from contextlib import contextmanager -from typing import Any, List +from typing import Any from django.db import models @@ -29,7 +29,7 @@ def get_instance_setting(key: str) -> Any: return CONSTANCE_CONFIG[key][0] # Get the default value -def get_instance_settings(keys: List[str]) -> Any: +def get_instance_settings(keys: list[str]) -> Any: for key in keys: assert key in CONSTANCE_CONFIG, f"Unknown dynamic setting: {repr(key)}" diff --git a/posthog/models/integration.py b/posthog/models/integration.py index 8ce1c9d6ef7..6e313ea179f 100644 --- a/posthog/models/integration.py +++ b/posthog/models/integration.py @@ -2,7 +2,7 @@ import hashlib import hmac import time from datetime import timedelta -from typing import Dict, List, Literal +from typing import Literal from django.db import models from rest_framework.request import Request @@ -50,7 +50,7 @@ class SlackIntegration: 
def client(self) -> WebClient: return WebClient(self.integration.sensitive_config["access_token"]) - def list_channels(self) -> List[Dict]: + def list_channels(self) -> list[dict]: # NOTE: Annoyingly the Slack API has no search so we have to load all channels... # We load public and private channels separately as when mixed, the Slack API pagination is buggy public_channels = self._list_channels_by_type("public_channel") @@ -59,7 +59,7 @@ class SlackIntegration: return sorted(channels, key=lambda x: x["name"]) - def _list_channels_by_type(self, type: Literal["public_channel", "private_channel"]) -> List[Dict]: + def _list_channels_by_type(self, type: Literal["public_channel", "private_channel"]) -> list[dict]: max_page = 10 channels = [] cursor = None @@ -76,7 +76,7 @@ class SlackIntegration: return channels @classmethod - def integration_from_slack_response(cls, team_id: str, created_by: User, params: Dict[str, str]) -> Integration: + def integration_from_slack_response(cls, team_id: str, created_by: User, params: dict[str, str]) -> Integration: client = WebClient() slack_config = cls.slack_config() diff --git a/posthog/models/organization.py b/posthog/models/organization.py index 8740a0f34c4..cdb4ee7ccd9 100644 --- a/posthog/models/organization.py +++ b/posthog/models/organization.py @@ -1,6 +1,6 @@ import json import sys -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypedDict, Union +from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union import structlog from django.conf import settings @@ -45,7 +45,7 @@ class OrganizationUsageInfo(TypedDict): events: Optional[OrganizationUsageResource] recordings: Optional[OrganizationUsageResource] rows_synced: Optional[OrganizationUsageResource] - period: Optional[List[str]] + period: Optional[list[str]] class OrganizationManager(models.Manager): @@ -56,9 +56,9 @@ class OrganizationManager(models.Manager): self, user: Optional["User"], *, - team_fields: Optional[Dict[str, Any]] = None, + team_fields: Optional[dict[str, Any]] = None, **kwargs, - ) -> Tuple["Organization", Optional["OrganizationMembership"], "Team"]: + ) -> tuple["Organization", Optional["OrganizationMembership"], "Team"]: """Instead of doing the legwork of creating an organization yourself, delegate the details with bootstrap.""" from .project import Project # Avoiding circular import @@ -157,7 +157,7 @@ class Organization(UUIDModel): __repr__ = sane_repr("name") @property - def _billing_plan_details(self) -> Tuple[Optional[str], Optional[str]]: + def _billing_plan_details(self) -> tuple[Optional[str], Optional[str]]: """ Obtains details on the billing plan for the organization. Returns a tuple with (billing_plan_key, billing_realm) @@ -176,7 +176,7 @@ class Organization(UUIDModel): return (license.plan, "ee") return (None, None) - def update_available_features(self) -> List[Union[AvailableFeature, str]]: + def update_available_features(self) -> list[Union[AvailableFeature, str]]: """Updates field `available_features`. 
Does not `save()`.""" if is_cloud() or self.usage: # Since billing V2 we just use the available features which are updated when the billing service is called diff --git a/posthog/models/organization_domain.py b/posthog/models/organization_domain.py index 416b2d560f3..5d49d8a64ac 100644 --- a/posthog/models/organization_domain.py +++ b/posthog/models/organization_domain.py @@ -1,5 +1,5 @@ import secrets -from typing import Optional, Tuple +from typing import Optional import dns.resolver import structlog @@ -151,13 +151,13 @@ class OrganizationDomain(UUIDModel): """ return bool(self.saml_entity_id) and bool(self.saml_acs_url) and bool(self.saml_x509_cert) - def _complete_verification(self) -> Tuple["OrganizationDomain", bool]: + def _complete_verification(self) -> tuple["OrganizationDomain", bool]: self.last_verification_retry = None self.verified_at = timezone.now() self.save() return (self, True) - def attempt_verification(self) -> Tuple["OrganizationDomain", bool]: + def attempt_verification(self) -> tuple["OrganizationDomain", bool]: """ Performs a DNS verification for a specific domain. """ diff --git a/posthog/models/person/person.py b/posthog/models/person/person.py index a0456542333..20f9dd76754 100644 --- a/posthog/models/person/person.py +++ b/posthog/models/person/person.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import Any, Optional from django.db import models, transaction from django.db.models import F, Q @@ -21,15 +21,15 @@ class PersonManager(models.Manager): return person @staticmethod - def distinct_ids_exist(team_id: int, distinct_ids: List[str]) -> bool: + def distinct_ids_exist(team_id: int, distinct_ids: list[str]) -> bool: return PersonDistinctId.objects.filter(team_id=team_id, distinct_id__in=distinct_ids).exists() class Person(models.Model): - _distinct_ids: Optional[List[str]] + _distinct_ids: Optional[list[str]] @property - def distinct_ids(self) -> List[str]: + def distinct_ids(self) -> list[str]: if hasattr(self, "distinct_ids_cache"): return [id.distinct_id for id in self.distinct_ids_cache] if hasattr(self, "_distinct_ids") and self._distinct_ids: @@ -46,7 +46,7 @@ class Person(models.Model): PersonDistinctId.objects.create(person=self, distinct_id=distinct_id, team_id=self.team_id) # :DEPRECATED: This should happen through the plugin server - def _add_distinct_ids(self, distinct_ids: List[str]) -> None: + def _add_distinct_ids(self, distinct_ids: list[str]) -> None: for distinct_id in distinct_ids: self.add_distinct_id(distinct_id) @@ -274,7 +274,7 @@ class FlatPersonOverride(models.Model): ] -def get_distinct_ids_for_subquery(person: Person | None, team: Team) -> List[str]: +def get_distinct_ids_for_subquery(person: Person | None, team: Team) -> list[str]: """_summary_ Fetching distinct_ids for a person from CH is slow, so we fetch them from PG for certain queries. 
Therefore we need diff --git a/posthog/models/person/util.py b/posthog/models/person/util.py index f6bcc60ebc3..0e1efa7bdb2 100644 --- a/posthog/models/person/util.py +++ b/posthog/models/person/util.py @@ -1,7 +1,7 @@ import datetime import json from contextlib import ExitStack -from typing import Dict, List, Optional, Union +from typing import Optional, Union from uuid import UUID from zoneinfo import ZoneInfo @@ -80,7 +80,7 @@ if TEST: except: pass - def bulk_create_persons(persons_list: List[Dict]): + def bulk_create_persons(persons_list: list[dict]): persons = [] person_mapping = {} for _person in persons_list: @@ -127,7 +127,7 @@ def create_person( team_id: int, version: int, uuid: Optional[str] = None, - properties: Optional[Dict] = None, + properties: Optional[dict] = None, sync: bool = False, is_identified: bool = False, is_deleted: bool = False, @@ -217,7 +217,7 @@ def create_person_override( ) -def get_persons_by_distinct_ids(team_id: int, distinct_ids: List[str]) -> QuerySet: +def get_persons_by_distinct_ids(team_id: int, distinct_ids: list[str]) -> QuerySet: return Person.objects.filter( team_id=team_id, persondistinctid__team_id=team_id, @@ -225,7 +225,7 @@ def get_persons_by_distinct_ids(team_id: int, distinct_ids: List[str]) -> QueryS ) -def get_persons_by_uuids(team: Team, uuids: List[str]) -> QuerySet: +def get_persons_by_uuids(team: Team, uuids: list[str]) -> QuerySet: return Person.objects.filter(team_id=team.pk, uuid__in=uuids) @@ -254,7 +254,7 @@ def _delete_person( ) -def _get_distinct_ids_with_version(person: Person) -> Dict[str, int]: +def _get_distinct_ids_with_version(person: Person) -> dict[str, int]: return { distinct_id: int(version or 0) for distinct_id, version in PersonDistinctId.objects.filter(person=person, team_id=person.team_id) diff --git a/posthog/models/personal_api_key.py b/posthog/models/personal_api_key.py index 047471f4fe8..23bb04e0b42 100644 --- a/posthog/models/personal_api_key.py +++ b/posthog/models/personal_api_key.py @@ -1,4 +1,4 @@ -from typing import Optional, Literal, Tuple, get_args +from typing import Optional, Literal, get_args import hashlib from django.contrib.auth.hashers import PBKDF2PasswordHasher @@ -111,5 +111,5 @@ APIScopeObjectOrNotSupported = Literal[ ] -API_SCOPE_OBJECTS: Tuple[APIScopeObject, ...] = get_args(APIScopeObject) -API_SCOPE_ACTIONS: Tuple[APIScopeActions, ...] = get_args(APIScopeActions) +API_SCOPE_OBJECTS: tuple[APIScopeObject, ...] = get_args(APIScopeObject) +API_SCOPE_ACTIONS: tuple[APIScopeActions, ...] = get_args(APIScopeActions) diff --git a/posthog/models/plugin.py b/posthog/models/plugin.py index 900b1abec77..06971c1ce7c 100644 --- a/posthog/models/plugin.py +++ b/posthog/models/plugin.py @@ -3,7 +3,7 @@ import json import os from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, cast +from typing import Any, Optional, cast from uuid import UUID from django.conf import settings @@ -52,13 +52,13 @@ def raise_if_plugin_installed(url: str, organization_id: str): raise ValidationError(f'Plugin from URL "{url_without_private_key}" already installed!') -def update_validated_data_from_url(validated_data: Dict[str, Any], url: str) -> Dict[str, Any]: +def update_validated_data_from_url(validated_data: dict[str, Any], url: str) -> dict[str, Any]: """If remote plugin, download the archive and get up-to-date validated_data from there.
Returns plugin.json.""" - plugin_json: Optional[Dict[str, Any]] + plugin_json: Optional[dict[str, Any]] if url.startswith("file:"): plugin_path = url[5:] plugin_json_path = os.path.join(plugin_path, "plugin.json") - plugin_json = cast(Optional[Dict[str, Any]], load_json_file(plugin_json_path)) + plugin_json = cast(Optional[dict[str, Any]], load_json_file(plugin_json_path)) if not plugin_json: raise ValidationError(f"Could not load plugin.json from: {plugin_json_path}") validated_data["plugin_type"] = "local" @@ -81,7 +81,7 @@ def update_validated_data_from_url(validated_data: Dict[str, Any], url: str) -> validated_data["latest_tag"] = parsed_url.get("tag", None) validated_data["archive"] = download_plugin_archive(validated_data["url"], validated_data["tag"]) plugin_json = cast( - Optional[Dict[str, Any]], + Optional[dict[str, Any]], get_file_from_archive(validated_data["archive"], "plugin.json"), ) if not plugin_json: @@ -124,7 +124,7 @@ class PluginManager(models.Manager): def install(self, **kwargs) -> "Plugin": if "organization_id" not in kwargs and "organization" in kwargs: kwargs["organization_id"] = kwargs["organization"].id - plugin_json: Optional[Dict[str, Any]] = None + plugin_json: Optional[dict[str, Any]] = None if kwargs.get("plugin_type", None) != Plugin.PluginType.SOURCE: plugin_json = update_validated_data_from_url(kwargs, kwargs["url"]) raise_if_plugin_installed(kwargs["url"], kwargs["organization_id"]) @@ -204,8 +204,8 @@ class Plugin(models.Model): objects: PluginManager = PluginManager() - def get_default_config(self) -> Dict[str, Any]: - config: Dict[str, Any] = {} + def get_default_config(self) -> dict[str, Any]: + config: dict[str, Any] = {} config_schema = self.config_schema if isinstance(config_schema, dict): for key, config_entry in config_schema.items(): @@ -296,8 +296,8 @@ class PluginLogEntryType(str, Enum): class PluginSourceFileManager(models.Manager): def sync_from_plugin_archive( - self, plugin: Plugin, plugin_json_parsed: Optional[Dict[str, Any]] = None - ) -> Tuple[ + self, plugin: Plugin, plugin_json_parsed: Optional[dict[str, Any]] = None + ) -> tuple[ "PluginSourceFile", Optional["PluginSourceFile"], Optional["PluginSourceFile"], @@ -426,12 +426,12 @@ def fetch_plugin_log_entries( before: Optional[timezone.datetime] = None, search: Optional[str] = None, limit: Optional[int] = None, - type_filter: Optional[List[PluginLogEntryType]] = None, -) -> List[PluginLogEntry]: + type_filter: Optional[list[PluginLogEntryType]] = None, +) -> list[PluginLogEntry]: if type_filter is None: type_filter = [] - clickhouse_where_parts: List[str] = [] - clickhouse_kwargs: Dict[str, Any] = {} + clickhouse_where_parts: list[str] = [] + clickhouse_kwargs: dict[str, Any] = {} if team_id is not None: clickhouse_where_parts.append("team_id = %(team_id)s") clickhouse_kwargs["team_id"] = team_id @@ -457,7 +457,7 @@ def fetch_plugin_log_entries( return [PluginLogEntry(*result) for result in cast(list, sync_execute(clickhouse_query, clickhouse_kwargs))] -def validate_plugin_job_payload(plugin: Plugin, job_type: str, payload: Dict[str, Any], *, is_staff: bool): +def validate_plugin_job_payload(plugin: Plugin, job_type: str, payload: dict[str, Any], *, is_staff: bool): if not plugin.public_jobs: raise ValidationError("Plugin has no public jobs") if job_type not in plugin.public_jobs: diff --git a/posthog/models/project.py b/posthog/models/project.py index c4ead260fb7..030bd4669a6 100644 --- a/posthog/models/project.py +++ b/posthog/models/project.py @@ -1,4 +1,4 @@ -from typing import 
TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Optional from django.db import models from django.db import transaction from django.core.validators import MinLengthValidator @@ -8,7 +8,7 @@ if TYPE_CHECKING: class ProjectManager(models.Manager): - def create_with_team(self, team_fields: Optional[dict] = None, **kwargs) -> Tuple["Project", "Team"]: + def create_with_team(self, team_fields: Optional[dict] = None, **kwargs) -> tuple["Project", "Team"]: from .team import Team with transaction.atomic(): diff --git a/posthog/models/property/property.py b/posthog/models/property/property.py index defd098cd7e..74ef611e257 100644 --- a/posthog/models/property/property.py +++ b/posthog/models/property/property.py @@ -2,11 +2,8 @@ import json from enum import Enum from typing import ( Any, - Dict, - List, Literal, Optional, - Tuple, Union, cast, ) @@ -27,7 +24,7 @@ class BehavioralPropertyType(str, Enum): RESTARTED_PERFORMING_EVENT = "restarted_performing_event" -ValueT = Union[str, int, List[str]] +ValueT = Union[str, int, list[str]] PropertyType = Literal[ "event", "feature", @@ -78,7 +75,7 @@ OperatorType = Literal[ OperatorInterval = Literal["day", "week", "month", "year"] GroupTypeName = str -PropertyIdentifier = Tuple[PropertyName, PropertyType, Optional[GroupTypeIndex]] +PropertyIdentifier = tuple[PropertyName, PropertyType, Optional[GroupTypeIndex]] NEGATED_OPERATORS = ["is_not", "not_icontains", "not_regex", "is_not_set"] CLICKHOUSE_ONLY_PROPERTY_TYPES = [ @@ -187,7 +184,7 @@ class Property: # Type of `key` event_type: Optional[Literal["events", "actions"]] # Any extra filters on the event - event_filters: Optional[List["Property"]] + event_filters: Optional[list["Property"]] # Query people who did event '$pageview' 20 times in the last 30 days # translates into: # key = '$pageview', value = 'performed_event_multiple' @@ -216,7 +213,7 @@ class Property: total_periods: Optional[int] min_periods: Optional[int] negation: Optional[bool] = False - _data: Dict + _data: dict def __init__( self, @@ -239,7 +236,7 @@ class Property: seq_time_value: Optional[int] = None, seq_time_interval: Optional[OperatorInterval] = None, negation: Optional[bool] = None, - event_filters: Optional[List["Property"]] = None, + event_filters: Optional[list["Property"]] = None, **kwargs, ) -> None: self.key = key @@ -298,7 +295,7 @@ class Property: params_repr = ", ".join(f"{key}={repr(value)}" for key, value in self.to_dict().items()) return f"Property({params_repr})" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return {key: value for key, value in vars(self).items() if value is not None} @staticmethod @@ -331,17 +328,17 @@ class Property: class PropertyGroup: type: PropertyOperatorType - values: Union[List[Property], List["PropertyGroup"]] + values: Union[list[Property], list["PropertyGroup"]] def __init__( self, type: PropertyOperatorType, - values: Union[List[Property], List["PropertyGroup"]], + values: Union[list[Property], list["PropertyGroup"]], ) -> None: self.type = type self.values = values - def combine_properties(self, operator: PropertyOperatorType, properties: List[Property]) -> "PropertyGroup": + def combine_properties(self, operator: PropertyOperatorType, properties: list[Property]) -> "PropertyGroup": if not properties: return self @@ -375,7 +372,7 @@ class PropertyGroup: return f"PropertyGroup(type={self.type}-{params_repr})" @cached_property - def flat(self) -> List[Property]: + def flat(self) -> list[Property]: return 
list(self._property_groups_flat(self)) def _property_groups_flat(self, prop_group: "PropertyGroup"): diff --git a/posthog/models/property/util.py b/posthog/models/property/util.py index cae1be3340e..b1ce9b6087a 100644 --- a/posthog/models/property/util.py +++ b/posthog/models/property/util.py @@ -1,17 +1,15 @@ import re from collections import Counter -from typing import Any, Callable -from typing import Counter as TCounter +from typing import Any +from collections.abc import Callable +from collections import Counter as TCounter from typing import ( - Dict, - Iterable, - List, Literal, Optional, - Tuple, Union, cast, ) +from collections.abc import Iterable from rest_framework import exceptions @@ -88,7 +86,7 @@ def parse_prop_grouped_clauses( person_id_joined_alias: str = "person_id", group_properties_joined: bool = True, _top_level: bool = True, -) -> Tuple[str, Dict]: +) -> tuple[str, dict]: """Translate the given property filter group into an SQL condition clause (+ SQL params).""" if not property_group or len(property_group.values) == 0: return "", {} @@ -119,7 +117,7 @@ def parse_prop_grouped_clauses( _final = f"{property_group.type} ".join(group_clauses) else: _final, final_params = parse_prop_clauses( - filters=cast(List[Property], property_group.values), + filters=cast(list[Property], property_group.values), prepend=f"{prepend}", table_name=table_name, allow_denormalized_props=allow_denormalized_props, @@ -151,7 +149,7 @@ def is_property_group(group: Union[Property, "PropertyGroup"]): def parse_prop_clauses( team_id: int, - filters: List[Property], + filters: list[Property], *, hogql_context: Optional[HogQLContext], prepend: str = "global", @@ -162,10 +160,10 @@ def parse_prop_clauses( person_id_joined_alias: str = "person_id", group_properties_joined: bool = True, property_operator: PropertyOperatorType = PropertyOperatorType.AND, -) -> Tuple[str, Dict]: +) -> tuple[str, dict]: """Translate the given property filter into an SQL condition clause (+ SQL params).""" final = [] - params: Dict[str, Any] = {} + params: dict[str, Any] = {} table_formatted = table_name if table_formatted != "": @@ -411,7 +409,7 @@ def prop_filter_json_extract( property_operator: str = PropertyOperatorType.AND, table_name: Optional[str] = None, use_event_column: Optional[str] = None, -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: # TODO: Once all queries are migrated over we can get rid of allow_denormalized_props if transform_expression is not None: prop_var = transform_expression(prop_var) @@ -433,7 +431,7 @@ def prop_filter_json_extract( if prop.negation: operator = negate_operator(operator or "exact") - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if operator == "is_not": params = { @@ -649,7 +647,7 @@ def get_single_or_multi_property_string_expr( allow_denormalized_props=True, materialised_table_column: str = "properties", normalize_url: bool = False, -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: """ When querying for breakdown properties: * If the breakdown provided is a string, we extract the JSON from the properties object stored in the DB @@ -663,7 +661,7 @@ def get_single_or_multi_property_string_expr( no alias will be appended. 
""" - breakdown_params: Dict[str, Any] = {} + breakdown_params: dict[str, Any] = {} if isinstance(breakdown, str) or isinstance(breakdown, int): breakdown_key = f"breakdown_param_{len(breakdown_params) + 1}" breakdown_key = f"breakdown_param_{len(breakdown_params) + 1}" @@ -719,7 +717,7 @@ def get_property_string_expr( allow_denormalized_props: bool = True, table_alias: Optional[str] = None, materialised_table_column: str = "properties", -) -> Tuple[str, bool]: +) -> tuple[str, bool]: """ :param table: @@ -752,8 +750,8 @@ def get_property_string_expr( return trim_quotes_expr(f"JSONExtractRaw({table_string}{column}, {var})"), False -def box_value(value: Any, remove_spaces=False) -> List[Any]: - if not isinstance(value, List): +def box_value(value: Any, remove_spaces=False) -> list[Any]: + if not isinstance(value, list): value = [value] return [str(value).replace(" ", "") if remove_spaces else str(value) for value in value] @@ -764,19 +762,19 @@ def filter_element( *, operator: Optional[OperatorType] = None, prepend: str = "", -) -> Tuple[str, Dict]: +) -> tuple[str, dict]: if operator is None: operator = "exact" params = {} - combination_conditions: List[str] = [] + combination_conditions: list[str] = [] if key == "selector": if operator not in ("exact", "is_not"): raise exceptions.ValidationError( 'Filtering by element selector only supports operators "equals" and "doesn\'t equal" currently.' ) - selectors = cast(List[str | int], value) if isinstance(value, list) else [value] + selectors = cast(list[str | int], value) if isinstance(value, list) else [value] for idx, query in enumerate(selectors): if not query: # Skip empty selectors continue @@ -792,7 +790,7 @@ def filter_element( raise exceptions.ValidationError( 'Filtering by element tag only supports operators "equals" and "doesn\'t equal" currently.' ) - tag_names = cast(List[str | int], value) if isinstance(value, list) else [value] + tag_names = cast(list[str | int], value) if isinstance(value, list) else [value] for idx, tag_name in enumerate(tag_names): if not tag_name: # Skip empty tags continue @@ -824,12 +822,12 @@ def filter_element( return "0 = 191" if operator not in NEGATED_OPERATORS else "", {} -def process_ok_values(ok_values: Any, operator: OperatorType) -> List[str]: +def process_ok_values(ok_values: Any, operator: OperatorType) -> list[str]: if operator.endswith("_set"): return [r'[^"]+'] else: # Make sure ok_values is a list - ok_values = cast(List[str], [str(val) for val in ok_values]) if isinstance(ok_values, list) else [ok_values] + ok_values = cast(list[str], [str(val) for val in ok_values]) if isinstance(ok_values, list) else [ok_values] # Escape double quote characters, since e.g. 
text 'foo="bar"' is represented as text="foo=\"bar\"" # in the elements chain ok_values = [text.replace('"', r"\"") for text in ok_values] @@ -869,8 +867,8 @@ def build_selector_regex(selector: Selector) -> str: class HogQLPropertyChecker(TraversingVisitor): def __init__(self): - self.event_properties: List[str] = [] - self.person_properties: List[str] = [] + self.event_properties: list[str] = [] + self.person_properties: list[str] = [] def visit_field(self, node: ast.Field): if len(node.chain) > 1 and node.chain[0] == "properties": @@ -888,8 +886,8 @@ class HogQLPropertyChecker(TraversingVisitor): self.person_properties.append(node.chain[3]) -def extract_tables_and_properties(props: List[Property]) -> TCounter[PropertyIdentifier]: - counters: List[tuple] = [] +def extract_tables_and_properties(props: list[Property]) -> TCounter[PropertyIdentifier]: + counters: list[tuple] = [] for prop in props: if prop.type == "hogql": counters.extend(count_hogql_properties(prop.key)) @@ -917,7 +915,7 @@ def count_hogql_properties( return counter -def get_session_property_filter_statement(prop: Property, idx: int, prepend: str = "") -> Tuple[str, Dict[str, Any]]: +def get_session_property_filter_statement(prop: Property, idx: int, prepend: str = "") -> tuple[str, dict[str, Any]]: if prop.key == "$session_duration": try: duration = float(prop.value) # type: ignore diff --git a/posthog/models/sharing_configuration.py b/posthog/models/sharing_configuration.py index 48ea711f02a..7bbacc45355 100644 --- a/posthog/models/sharing_configuration.py +++ b/posthog/models/sharing_configuration.py @@ -1,5 +1,5 @@ import secrets -from typing import List, cast +from typing import cast from django.db import models @@ -48,7 +48,7 @@ class SharingConfiguration(models.Model): return False - def get_connected_insight_ids(self) -> List[int]: + def get_connected_insight_ids(self) -> list[int]: if self.insight: if self.insight.deleted: return [] diff --git a/posthog/models/subscription.py b/posthog/models/subscription.py index f7b8a90a7e4..a0aa65ed9f6 100644 --- a/posthog/models/subscription.py +++ b/posthog/models/subscription.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from datetime import timedelta -from typing import Any, Dict, Optional +from typing import Any, Optional from dateutil.rrule import ( FR, @@ -134,7 +134,7 @@ class Subscription(models.Model): self.set_next_delivery_date() if "update_fields" in kwargs: kwargs["update_fields"].append("next_delivery_date") - super(Subscription, self).save(*args, **kwargs) + super().save(*args, **kwargs) @property def url(self): @@ -187,7 +187,7 @@ class Subscription(models.Model): capture_exception(e) return "sent on a schedule" - def get_analytics_metadata(self) -> Dict[str, Any]: + def get_analytics_metadata(self) -> dict[str, Any]: """ Returns serialized information about the object for analytics reporting. 
""" diff --git a/posthog/models/tagged_item.py b/posthog/models/tagged_item.py index 612f2f39399..302adcdb24f 100644 --- a/posthog/models/tagged_item.py +++ b/posthog/models/tagged_item.py @@ -1,4 +1,5 @@ -from typing import Iterable, List, Union +from typing import Union +from collections.abc import Iterable from django.core.exceptions import ValidationError from django.db import models @@ -18,7 +19,7 @@ RELATED_OBJECTS = ( # Checks that exactly one object field is populated def build_check(related_objects: Iterable[str]): - built_check_list: List[Union[Q, Q]] = [] + built_check_list: list[Union[Q, Q]] = [] for field in related_objects: built_check_list.append( Q( @@ -117,7 +118,7 @@ class TaggedItem(UUIDModel): def save(self, *args, **kwargs): self.full_clean() - return super(TaggedItem, self).save(*args, **kwargs) + return super().save(*args, **kwargs) def __str__(self) -> str: return str(self.tag) diff --git a/posthog/models/team/team.py b/posthog/models/team/team.py index 6f5f927fe00..868b8555964 100644 --- a/posthog/models/team/team.py +++ b/posthog/models/team/team.py @@ -1,7 +1,7 @@ import re from decimal import Decimal from functools import lru_cache -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Optional import posthoganalytics import pydantic @@ -64,7 +64,7 @@ class TeamManager(models.Manager): def get_queryset(self): return super().get_queryset().defer(*DEPRECATED_ATTRS) - def set_test_account_filters(self, organization: Optional[Any]) -> List: + def set_test_account_filters(self, organization: Optional[Any]) -> list: filters = [ { "key": "$host", @@ -150,7 +150,7 @@ class TeamManager(models.Manager): return result[0] -def get_default_data_attributes() -> List[str]: +def get_default_data_attributes() -> list[str]: return ["data-attr"] @@ -477,7 +477,7 @@ def groups_on_events_querying_enabled(): def check_is_feature_available_for_team(team_id: int, feature_key: str, current_usage: Optional[int] = None): - available_product_features: Optional[List[Dict[str, str]]] = ( + available_product_features: Optional[list[dict[str, str]]] = ( Team.objects.select_related("organization") .values_list("organization__available_product_features", flat=True) .get(id=team_id) diff --git a/posthog/models/team/util.py b/posthog/models/team/util.py index a21b75ab803..5756c8da211 100644 --- a/posthog/models/team/util.py +++ b/posthog/models/team/util.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Any, List +from typing import Any from posthog.temporal.common.client import sync_connect from posthog.batch_exports.service import batch_export_delete_schedule @@ -7,7 +7,7 @@ from posthog.cache_utils import cache_for from posthog.models.async_migration import is_async_migration_complete -def delete_bulky_postgres_data(team_ids: List[int]): +def delete_bulky_postgres_data(team_ids: list[int]): "Efficiently delete large tables for teams from postgres. Using normal CASCADE delete here can time out" from posthog.models.cohort import CohortPeople @@ -29,7 +29,7 @@ def _raw_delete(queryset: Any): queryset._raw_delete(queryset.db) -def delete_batch_exports(team_ids: List[int]): +def delete_batch_exports(team_ids: list[int]): """Delete BatchExports for deleted teams. Using normal CASCADE doesn't trigger a delete from Temporal. 
diff --git a/posthog/models/test/test_dashboard_tile_model.py b/posthog/models/test/test_dashboard_tile_model.py index be13ba06975..79f4a085a24 100644 --- a/posthog/models/test/test_dashboard_tile_model.py +++ b/posthog/models/test/test_dashboard_tile_model.py @@ -1,5 +1,4 @@ import datetime -from typing import Dict, List from django.core.exceptions import ValidationError from django.db.utils import IntegrityError @@ -19,7 +18,7 @@ from posthog.test.db_context_capturing import capture_db_queries class TestDashboardTileModel(APIBaseTest): dashboard: Dashboard asset: ExportedAsset - tiles: List[DashboardTile] + tiles: list[DashboardTile] def setUp(self) -> None: self.dashboard = Dashboard.objects.create(team=self.team, name="private dashboard", created_by=self.user) @@ -64,7 +63,7 @@ class TestDashboardTileModel(APIBaseTest): DashboardTile.objects.create(dashboard=self.dashboard, insight=insight, text=text) def test_cannot_set_caching_data_for_text_tiles(self) -> None: - tile_fields: List[Dict] = [ + tile_fields: list[dict] = [ {"filters_hash": "123"}, {"refreshing": True}, {"refresh_attempt": 2}, diff --git a/posthog/models/uploaded_media.py b/posthog/models/uploaded_media.py index 0161b71beb4..2b31f348263 100644 --- a/posthog/models/uploaded_media.py +++ b/posthog/models/uploaded_media.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional import structlog from django.conf import settings @@ -72,7 +72,7 @@ class UploadedMedia(UUIDModel): def save_content_to_object_storage(uploaded_media: UploadedMedia, content: bytes) -> None: - path_parts: List[str] = [ + path_parts: list[str] = [ settings.OBJECT_STORAGE_MEDIA_UPLOADS_FOLDER, f"team-{uploaded_media.team.pk}", f"media-{uploaded_media.pk}", diff --git a/posthog/models/user.py b/posthog/models/user.py index cb4b1063cc9..c2d5b0f8d55 100644 --- a/posthog/models/user.py +++ b/posthog/models/user.py @@ -1,5 +1,6 @@ from functools import cached_property -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypedDict +from typing import Any, Optional, TypedDict +from collections.abc import Callable from django.contrib.auth.models import AbstractUser, BaseUserManager from django.db import models, transaction @@ -36,7 +37,7 @@ class UserManager(BaseUserManager): def get_queryset(self): return super().get_queryset().defer(*DEFERED_ATTRS) - model: Type["User"] + model: type["User"] use_in_migrations = True @@ -58,12 +59,12 @@ class UserManager(BaseUserManager): email: str, password: Optional[str], first_name: str = "", - organization_fields: Optional[Dict[str, Any]] = None, - team_fields: Optional[Dict[str, Any]] = None, + organization_fields: Optional[dict[str, Any]] = None, + team_fields: Optional[dict[str, Any]] = None, create_team: Optional[Callable[["Organization", "User"], "Team"]] = None, is_staff: bool = False, **user_fields, - ) -> Tuple["Organization", "Team", "User"]: + ) -> tuple["Organization", "Team", "User"]: """Instead of doing the legwork of creating a user from scratch, delegate the details with bootstrap.""" with transaction.atomic(): organization_fields = organization_fields or {} @@ -112,7 +113,7 @@ class UserManager(BaseUserManager): return personal_api_key.user -def events_column_config_default() -> Dict[str, Any]: +def events_column_config_default() -> dict[str, Any]: return {"active": "DEFAULT"} @@ -124,7 +125,7 @@ class ThemeMode(models.TextChoices): class User(AbstractUser, UUIDClassicModel): USERNAME_FIELD = "email" - REQUIRED_FIELDS: List[str] = [] + REQUIRED_FIELDS: list[str] = 
[] DISABLED = "disabled" TOOLBAR = "toolbar" diff --git a/posthog/models/utils.py b/posthog/models/utils.py index a093cf1e4eb..c832cc8f044 100644 --- a/posthog/models/utils.py +++ b/posthog/models/utils.py @@ -5,7 +5,8 @@ from collections import defaultdict, namedtuple from contextlib import contextmanager from random import Random, choice from time import time -from typing import Any, Callable, Dict, Iterator, Optional, Set, Type, TypeVar +from typing import Any, Optional, TypeVar +from collections.abc import Callable, Iterator from django.db import IntegrityError, connections, models, transaction from django.db.backends.utils import CursorWrapper @@ -40,7 +41,7 @@ class UUIDT(uuid.UUID): (https://blog.twitter.com/engineering/en_us/a/2010/announcing-snowflake.html). """ - current_series_per_ms: Dict[int, int] = defaultdict(int) + current_series_per_ms: dict[int, int] = defaultdict(int) def __init__( self, @@ -205,10 +206,10 @@ def create_with_slug(create_func: Callable[..., T], default_slug: str = "", *arg def get_deferred_field_set_for_model( - model: Type[models.Model], - fields_not_deferred: Optional[Set[str]] = None, + model: type[models.Model], + fields_not_deferred: Optional[set[str]] = None, field_prefix: str = "", -) -> Set[str]: +) -> set[str]: """Return a set of field names to be deferred for a given model. Used with `.defer()` after `select_related` Why? `select_related` fetches the entire related objects - not allowing you to specify which fields diff --git a/posthog/plugins/site.py b/posthog/plugins/site.py index 9cb2b3023f8..0f5feda2df2 100644 --- a/posthog/plugins/site.py +++ b/posthog/plugins/site.py @@ -1,6 +1,6 @@ from dataclasses import asdict, dataclass from hashlib import md5 -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from posthog.models import Team @@ -11,7 +11,7 @@ class WebJsSource: id: int source: str token: str - config_schema: List[dict] + config_schema: list[dict] config: dict @@ -48,7 +48,7 @@ def get_transpiled_site_source(id: int, token: str) -> Optional[WebJsSource]: return WebJsSource(*(list(response))) # type: ignore -def get_decide_site_apps(team: "Team", using_database: str = "default") -> List[dict]: +def get_decide_site_apps(team: "Team", using_database: str = "default") -> list[dict]: from posthog.models import PluginConfig, PluginSourceFile sources = ( @@ -70,13 +70,13 @@ def get_decide_site_apps(team: "Team", using_database: str = "default") -> List[ ) def site_app_url(source: tuple) -> str: - hash = md5(f"{source[2]}-{source[3]}-{source[4]}".encode("utf-8")).hexdigest() + hash = md5(f"{source[2]}-{source[3]}-{source[4]}".encode()).hexdigest() return f"/site_app/{source[0]}/{source[1]}/{hash}/" return [asdict(WebJsUrl(source[0], site_app_url(source))) for source in sources] -def get_site_config_from_schema(config_schema: Optional[List[dict]], config: Optional[dict]): +def get_site_config_from_schema(config_schema: Optional[list[dict]], config: Optional[dict]): if not config or not config_schema: return {} return { diff --git a/posthog/plugins/utils.py b/posthog/plugins/utils.py index 2610d8b2eb1..602f775447b 100644 --- a/posthog/plugins/utils.py +++ b/posthog/plugins/utils.py @@ -4,7 +4,7 @@ import os import re import tarfile from tarfile import ReadError -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional from urllib.parse import parse_qs, quote from zipfile import ZIP_DEFLATED, BadZipFile, Path, ZipFile @@ -12,7 +12,7 @@ import requests from 
django.conf import settings -def parse_github_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, Optional[str]]]: +def parse_github_url(url: str, get_latest_if_none=False) -> Optional[dict[str, Optional[str]]]: url, private_token = split_url_and_private_token(url) match = re.search( r"^https?://(?:www\.)?github\.com/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+)(/(commit|tree|releases/tag)/([A-Za-z0-9_.\-]+)/?([A-Za-z0-9_.\-/]+)?)?$", @@ -27,7 +27,7 @@ def parse_github_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, O if not match: return None - parsed: Dict[str, Optional[str]] = { + parsed: dict[str, Optional[str]] = { "type": "github", "root_url": f"https://github.com/{match.group(1)}/{match.group(2)}", "user": match.group(1), @@ -76,13 +76,13 @@ def parse_github_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, O return parsed -def parse_gitlab_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, Optional[str]]]: +def parse_gitlab_url(url: str, get_latest_if_none=False) -> Optional[dict[str, Optional[str]]]: url, private_token = split_url_and_private_token(url) match = re.search(r"^https?://(?:www\.)?gitlab\.com/([A-Za-z0-9_.\-/]+)$", url) if not match: return None - parsed: Dict[str, Optional[str]] = { + parsed: dict[str, Optional[str]] = { "type": "gitlab", "project": match.group(1), "tag": None, @@ -127,7 +127,7 @@ def parse_gitlab_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, O return parsed -def parse_npm_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, Optional[str]]]: +def parse_npm_url(url: str, get_latest_if_none=False) -> Optional[dict[str, Optional[str]]]: url, private_token = split_url_and_private_token(url) match = re.search( r"^https?://(?:www\.)?npmjs\.com/package/([@a-z0-9_-]+(/[a-z0-9_-]+)?)?/?(v/([A-Za-z0-9_.-]+)/?|)$", @@ -135,7 +135,7 @@ def parse_npm_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, Opti ) if not match: return None - parsed: Dict[str, Optional[str]] = { + parsed: dict[str, Optional[str]] = { "type": "npm", "pkg": match.group(1), "tag": match.group(4), @@ -166,7 +166,7 @@ def parse_npm_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, Opti return parsed -def parse_url(url: str, get_latest_if_none=False) -> Dict[str, Optional[str]]: +def parse_url(url: str, get_latest_if_none=False) -> dict[str, Optional[str]]: parsed_url = parse_github_url(url, get_latest_if_none) if parsed_url: return parsed_url @@ -179,7 +179,7 @@ def parse_url(url: str, get_latest_if_none=False) -> Dict[str, Optional[str]]: raise Exception("Must be a GitHub/GitLab repository or npm package URL!") -def split_url_and_private_token(url: str) -> Tuple[str, Optional[str]]: +def split_url_and_private_token(url: str) -> tuple[str, Optional[str]]: private_token = None if "?" 
in url: url, query = url.split("?") @@ -242,7 +242,7 @@ def download_plugin_archive(url: str, tag: Optional[str] = None) -> bytes: def load_json_file(filename: str): try: - with open(filename, "r", encoding="utf_8") as reader: + with open(filename, encoding="utf_8") as reader: return json.loads(reader.read()) except FileNotFoundError: return None @@ -313,8 +313,8 @@ def find_index_ts_in_archive(archive: bytes, main_filename: Optional[str] = None def extract_plugin_code( - archive: bytes, plugin_json_parsed: Optional[Dict[str, Any]] = None -) -> Tuple[str, Optional[str], Optional[str], Optional[str]]: + archive: bytes, plugin_json_parsed: Optional[dict[str, Any]] = None +) -> tuple[str, Optional[str], Optional[str], Optional[str]]: """Extract plugin.json, index.ts (which can be aliased) and frontend.tsx out of an archive. If plugin.json has already been parsed before this is called, its value can be passed in as an optimization.""" diff --git a/posthog/queries/actor_base_query.py b/posthog/queries/actor_base_query.py index 66c476cd814..f23b4c4ff05 100644 --- a/posthog/queries/actor_base_query.py +++ b/posthog/queries/actor_base_query.py @@ -2,12 +2,8 @@ import uuid from datetime import datetime, timedelta from typing import ( Any, - Dict, - List, Literal, Optional, - Set, - Tuple, TypedDict, Union, cast, @@ -34,14 +30,14 @@ class EventInfoForRecording(TypedDict): class MatchedRecording(TypedDict): session_id: str - events: List[EventInfoForRecording] + events: list[EventInfoForRecording] class CommonActor(TypedDict): id: Union[uuid.UUID, str] created_at: Optional[str] - properties: Dict[str, Any] - matched_recordings: List[MatchedRecording] + properties: dict[str, Any] + matched_recordings: list[MatchedRecording] value_at_data_point: Optional[float] @@ -50,7 +46,7 @@ class SerializedPerson(CommonActor): uuid: Union[uuid.UUID, str] is_identified: Optional[bool] name: str - distinct_ids: List[str] + distinct_ids: list[str] class SerializedGroup(CommonActor): @@ -81,7 +77,7 @@ class ActorBaseQuery: self.entity = entity self._filter = filter - def actor_query(self, limit_actors: Optional[bool] = True) -> Tuple[str, Dict]: + def actor_query(self, limit_actors: Optional[bool] = True) -> tuple[str, dict]: """Implemented by subclasses. Must provide query and params. The query must return list of uuids. Can be group uuids (group_key) or person uuids""" raise NotImplementedError() @@ -96,9 +92,9 @@ class ActorBaseQuery: def get_actors( self, - ) -> Tuple[ + ) -> tuple[ Union[QuerySet[Person], QuerySet[Group]], - Union[List[SerializedGroup], List[SerializedPerson]], + Union[list[SerializedGroup], list[SerializedPerson]], int, ]: """Get actors in data model and dict formats. 
Builds query and executes""" @@ -124,10 +120,10 @@ class ActorBaseQuery: def query_for_session_ids_with_recordings( self, - session_ids: Set[str], + session_ids: set[str], date_from: datetime | None, date_to: datetime | None, - ) -> Set[str]: + ) -> set[str]: """Filters a list of session_ids to those that actually have recordings""" query = """ SELECT DISTINCT session_id @@ -166,9 +162,9 @@ class ActorBaseQuery: def add_matched_recordings_to_serialized_actors( self, - serialized_actors: Union[List[SerializedGroup], List[SerializedPerson]], + serialized_actors: Union[list[SerializedGroup], list[SerializedPerson]], raw_result, - ) -> Union[List[SerializedGroup], List[SerializedPerson]]: + ) -> Union[list[SerializedGroup], list[SerializedPerson]]: all_session_ids = set() session_events_column_index = 2 if self.ACTOR_VALUES_INCLUDED else 1 @@ -192,9 +188,9 @@ class ActorBaseQuery: ) session_ids_with_recordings = session_ids_with_all_recordings.difference(session_ids_with_deleted_recordings) - matched_recordings_by_actor_id: Dict[Union[uuid.UUID, str], List[MatchedRecording]] = {} + matched_recordings_by_actor_id: dict[Union[uuid.UUID, str], list[MatchedRecording]] = {} for row in raw_result: - recording_events_by_session_id: Dict[str, List[EventInfoForRecording]] = {} + recording_events_by_session_id: dict[str, list[EventInfoForRecording]] = {} if len(row) > session_events_column_index - 1: for event in row[session_events_column_index]: event_session_id = event[2] @@ -211,7 +207,7 @@ class ActorBaseQuery: # Casting Union[SerializedActor, SerializedGroup] as SerializedPerson because mypy yells # when you do an indexed assignment on a Union even if all items in the Union support it - serialized_actors = cast(List[SerializedPerson], serialized_actors) + serialized_actors = cast(list[SerializedPerson], serialized_actors) serialized_actors_with_recordings = [] for actor in serialized_actors: actor["matched_recordings"] = matched_recordings_by_actor_id[actor["id"]] @@ -221,12 +217,12 @@ class ActorBaseQuery: def get_actors_from_result( self, raw_result - ) -> Tuple[ + ) -> tuple[ Union[QuerySet[Person], QuerySet[Group]], - Union[List[SerializedGroup], List[SerializedPerson]], + Union[list[SerializedGroup], list[SerializedPerson]], ]: actors: Union[QuerySet[Person], QuerySet[Group]] - serialized_actors: Union[List[SerializedGroup], List[SerializedPerson]] + serialized_actors: Union[list[SerializedGroup], list[SerializedPerson]] actor_ids = [row[0] for row in raw_result] value_per_actor_id = {str(row[0]): row[1] for row in raw_result} if self.ACTOR_VALUES_INCLUDED else None @@ -255,9 +251,9 @@ class ActorBaseQuery: def get_groups( team_id: int, group_type_index: int, - group_ids: List[Any], - value_per_actor_id: Optional[Dict[str, float]] = None, -) -> Tuple[QuerySet[Group], List[SerializedGroup]]: + group_ids: list[Any], + value_per_actor_id: Optional[dict[str, float]] = None, +) -> tuple[QuerySet[Group], list[SerializedGroup]]: """Get groups from raw SQL results in data model and dict formats""" groups: QuerySet[Group] = Group.objects.filter( team_id=team_id, group_type_index=group_type_index, group_key__in=group_ids @@ -267,10 +263,10 @@ def get_groups( def get_people( team: Team, - people_ids: List[Any], - value_per_actor_id: Optional[Dict[str, float]] = None, + people_ids: list[Any], + value_per_actor_id: Optional[dict[str, float]] = None, distinct_id_limit=1000, -) -> Tuple[QuerySet[Person], List[SerializedPerson]]: +) -> tuple[QuerySet[Person], list[SerializedPerson]]: """Get people from raw 
SQL results in data model and dict formats""" distinct_id_subquery = Subquery( PersonDistinctId.objects.filter(person_id=OuterRef("person_id")).values_list("id", flat=True)[ @@ -294,9 +290,9 @@ def get_people( def serialize_people( team: Team, - data: Union[QuerySet[Person], List[Person]], - value_per_actor_id: Optional[Dict[str, float]] = None, -) -> List[SerializedPerson]: + data: Union[QuerySet[Person], list[Person]], + value_per_actor_id: Optional[dict[str, float]] = None, +) -> list[SerializedPerson]: from posthog.api.person import get_person_name return [ @@ -316,7 +312,7 @@ def serialize_people( ] -def serialize_groups(data: QuerySet[Group], value_per_actor_id: Optional[Dict[str, float]]) -> List[SerializedGroup]: +def serialize_groups(data: QuerySet[Group], value_per_actor_id: Optional[dict[str, float]]) -> list[SerializedGroup]: return [ SerializedGroup( id=group.group_key, diff --git a/posthog/queries/app_metrics/historical_exports.py b/posthog/queries/app_metrics/historical_exports.py index cbf22d48015..5fd32a06ec2 100644 --- a/posthog/queries/app_metrics/historical_exports.py +++ b/posthog/queries/app_metrics/historical_exports.py @@ -1,6 +1,6 @@ import json from datetime import timedelta -from typing import Dict, Optional +from typing import Optional from zoneinfo import ZoneInfo @@ -26,7 +26,7 @@ def historical_exports_activity(team_id: int, plugin_config_id: int, job_id: Opt **({"detail__trigger__job_id": job_id} if job_id is not None else {}), ) - by_category: Dict = {"job_triggered": {}, "export_success": {}, "export_fail": {}} + by_category: dict = {"job_triggered": {}, "export_success": {}, "export_fail": {}} for entry in entries: by_category[entry.activity][entry.detail["trigger"]["job_id"]] = entry diff --git a/posthog/queries/app_metrics/test/test_app_metrics.py b/posthog/queries/app_metrics/test/test_app_metrics.py index e6c50b08ae5..2368961b507 100644 --- a/posthog/queries/app_metrics/test/test_app_metrics.py +++ b/posthog/queries/app_metrics/test/test_app_metrics.py @@ -1,6 +1,6 @@ import json from datetime import datetime -from typing import Dict, Optional +from typing import Optional from freezegun.api import freeze_time @@ -34,7 +34,7 @@ def create_app_metric( failures=0, error_uuid: Optional[str] = None, error_type: Optional[str] = None, - error_details: Optional[Dict] = None, + error_details: Optional[dict] = None, ): timestamp = cast_timestamp_or_now(timestamp) data = { diff --git a/posthog/queries/base.py b/posthog/queries/base.py index 7dff88f6020..e5cf6e71744 100644 --- a/posthog/queries/base.py +++ b/posthog/queries/base.py @@ -3,14 +3,12 @@ import hashlib import re from typing import ( Any, - Callable, - Dict, - List, Optional, TypeVar, Union, cast, ) +from collections.abc import Callable from zoneinfo import ZoneInfo from dateutil.relativedelta import relativedelta from dateutil import parser @@ -47,7 +45,7 @@ def determine_compared_filter(filter: FilterType) -> FilterType: return filter.shallow_clone({"date_from": date_from.isoformat(), "date_to": date_to.isoformat()}) -def convert_to_comparison(trend_entities: List[Dict[str, Any]], filter, label: str) -> List[Dict[str, Any]]: +def convert_to_comparison(trend_entities: list[dict[str, Any]], filter, label: str) -> list[dict[str, Any]]: for entity in trend_entities: labels = [ "{} {}".format(filter.interval if filter.interval is not None else "day", i) @@ -72,7 +70,7 @@ def convert_to_comparison(trend_entities: List[Dict[str, Any]], filter, label: s """ -def handle_compare(filter, func: Callable, team: 
Team, **kwargs) -> List: +def handle_compare(filter, func: Callable, team: Team, **kwargs) -> list: all_entities = [] base_entitites = func(filter=filter, team=team, **kwargs) if filter.compare: @@ -88,7 +86,7 @@ def handle_compare(filter, func: Callable, team: Team, **kwargs) -> List: return all_entities -def match_property(property: Property, override_property_values: Dict[str, Any]) -> bool: +def match_property(property: Property, override_property_values: dict[str, Any]) -> bool: # only looks for matches where key exists in override_property_values # doesn't support operator is_not_set @@ -276,8 +274,8 @@ def lookup_q(key: str, value: Any) -> Q: def property_to_Q( team_id: int, property: Property, - override_property_values: Optional[Dict[str, Any]] = None, - cohorts_cache: Optional[Dict[int, CohortOrEmpty]] = None, + override_property_values: Optional[dict[str, Any]] = None, + cohorts_cache: Optional[dict[int, CohortOrEmpty]] = None, using_database: str = "default", ) -> Q: if override_property_values is None: @@ -382,8 +380,8 @@ def property_to_Q( def property_group_to_Q( team_id: int, property_group: PropertyGroup, - override_property_values: Optional[Dict[str, Any]] = None, - cohorts_cache: Optional[Dict[int, CohortOrEmpty]] = None, + override_property_values: Optional[dict[str, Any]] = None, + cohorts_cache: Optional[dict[int, CohortOrEmpty]] = None, using_database: str = "default", ) -> Q: if override_property_values is None: @@ -426,9 +424,9 @@ def property_group_to_Q( def properties_to_Q( team_id: int, - properties: List[Property], - override_property_values: Optional[Dict[str, Any]] = None, - cohorts_cache: Optional[Dict[int, CohortOrEmpty]] = None, + properties: list[Property], + override_property_values: Optional[dict[str, Any]] = None, + cohorts_cache: Optional[dict[int, CohortOrEmpty]] = None, using_database: str = "default", ) -> Q: """ diff --git a/posthog/queries/breakdown_props.py b/posthog/queries/breakdown_props.py index fffb0aef0f2..96cf6afa959 100644 --- a/posthog/queries/breakdown_props.py +++ b/posthog/queries/breakdown_props.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast from django.forms import ValidationError @@ -50,7 +50,7 @@ def get_breakdown_prop_values( column_optimizer: Optional[ColumnOptimizer] = None, person_properties_mode: PersonPropertiesMode = PersonPropertiesMode.USING_PERSON_PROPERTIES_COLUMN, use_all_funnel_entities: bool = False, -) -> Tuple[List[Any], bool]: +) -> tuple[list[Any], bool]: """ Returns the top N breakdown prop values for event/person breakdown @@ -77,13 +77,13 @@ def get_breakdown_prop_values( props_to_filter = filter.property_groups person_join_clauses = "" - person_join_params: Dict = {} + person_join_params: dict = {} groups_join_clause = "" - groups_join_params: Dict = {} + groups_join_params: dict = {} sessions_join_clause = "" - sessions_join_params: Dict = {} + sessions_join_params: dict = {} null_person_filter = ( f"AND notEmpty(e.person_id)" if team.person_on_events_mode != PersonsOnEventsMode.disabled else "" @@ -248,14 +248,14 @@ def get_breakdown_prop_values( def _to_value_expression( breakdown_type: Optional[BREAKDOWN_TYPES], - breakdown: Union[str, List[Union[str, int]], None], + breakdown: Union[str, list[Union[str, int]], None], breakdown_group_type_index: Optional[GroupTypeIndex], hogql_context: HogQLContext, breakdown_normalize_url: bool = False, direct_on_events: bool = False, cast_as_float: bool = False, -) -> Tuple[str, Dict]: 
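# --- Illustrative aside (editor's sketch, not part of the patch): the hunks above
# and below all apply the same mechanical pyupgrade rewrite to PEP 585 builtin
# generics: Dict/List/Tuple/Set become dict/list/tuple/set, Callable/Iterator/
# Generator move to collections.abc, and names like Any/Optional/Union stay in
# typing. On a hypothetical signature:
#
#   # Before (typing generics):
#   from typing import Callable, Dict, List, Optional, Tuple
#   def top_values(fetch: Callable[[], List[str]], limit: Optional[int] = None) -> Tuple[List[str], Dict[str, int]]: ...
#
#   # After (builtin generics, Python 3.9+):
#   from collections.abc import Callable
#   from typing import Optional
#   def top_values(fetch: Callable[[], list[str]], limit: Optional[int] = None) -> tuple[list[str], dict[str, int]]: ...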
- params: Dict[str, Any] = {} +) -> tuple[str, dict]: + params: dict[str, Any] = {} if breakdown_type == "session": if breakdown == "$session_duration": # Return the session duration expression right away because it's already an number, @@ -321,7 +321,7 @@ def _to_bucketing_expression(bin_count: int) -> str: return f"arrayCompact(arrayMap(x -> floor(x, 2), {qunatile_expression}))" -def _format_all_query(team: Team, filter: Filter, **kwargs) -> Tuple[str, Dict]: +def _format_all_query(team: Team, filter: Filter, **kwargs) -> tuple[str, dict]: entity = kwargs.pop("entity", None) date_params = {} @@ -354,7 +354,7 @@ def _format_all_query(team: Team, filter: Filter, **kwargs) -> Tuple[str, Dict]: return query, {**date_params, **prop_filter_params} -def format_breakdown_cohort_join_query(team: Team, filter: Filter, **kwargs) -> Tuple[str, List, Dict]: +def format_breakdown_cohort_join_query(team: Team, filter: Filter, **kwargs) -> tuple[str, list, dict]: entity = kwargs.pop("entity", None) cohorts = ( Cohort.objects.filter(team_id=team.pk, pk__in=[b for b in filter.breakdown if b != "all"]) @@ -371,9 +371,9 @@ def format_breakdown_cohort_join_query(team: Team, filter: Filter, **kwargs) -> return " UNION ALL ".join(cohort_queries), ids, params -def _parse_breakdown_cohorts(cohorts: List[Cohort], hogql_context: HogQLContext) -> Tuple[List[str], Dict]: +def _parse_breakdown_cohorts(cohorts: list[Cohort], hogql_context: HogQLContext) -> tuple[list[str], dict]: queries = [] - params: Dict[str, Any] = {} + params: dict[str, Any] = {} for idx, cohort in enumerate(cohorts): person_id_query, cohort_filter_params = format_filter_query(cohort, idx, hogql_context) diff --git a/posthog/queries/column_optimizer/foss_column_optimizer.py b/posthog/queries/column_optimizer/foss_column_optimizer.py index 98dfb1b54c4..b3e73d3178c 100644 --- a/posthog/queries/column_optimizer/foss_column_optimizer.py +++ b/posthog/queries/column_optimizer/foss_column_optimizer.py @@ -1,6 +1,7 @@ from collections import Counter -from typing import Counter as TCounter -from typing import Generator, List, Set, Union, cast +from collections import Counter as TCounter +from typing import Union, cast +from collections.abc import Generator from posthog.clickhouse.materialized_columns import ColumnName, get_materialized_columns from posthog.constants import TREND_FILTER_TYPE_ACTIONS, FunnelCorrelationType @@ -48,19 +49,19 @@ class FOSSColumnOptimizer: self.property_optimizer = PropertyOptimizer() @cached_property - def event_columns_to_query(self) -> Set[ColumnName]: + def event_columns_to_query(self) -> set[ColumnName]: "Returns a list of event table columns containing materialized properties that this query needs" return self.columns_to_query("events", set(self.used_properties_with_type("event"))) @cached_property - def person_on_event_columns_to_query(self) -> Set[ColumnName]: + def person_on_event_columns_to_query(self) -> set[ColumnName]: "Returns a list of event table person columns containing materialized properties that this query needs" return self.columns_to_query("events", set(self.used_properties_with_type("person")), "person_properties") @cached_property - def person_columns_to_query(self) -> Set[ColumnName]: + def person_columns_to_query(self) -> set[ColumnName]: "Returns a list of person table columns containing materialized properties that this query needs" return self.columns_to_query("person", set(self.used_properties_with_type("person"))) @@ -68,9 +69,9 @@ class FOSSColumnOptimizer: def columns_to_query( self, table: 
TableWithProperties, - used_properties: Set[PropertyIdentifier], + used_properties: set[PropertyIdentifier], table_column: str = "properties", - ) -> Set[ColumnName]: + ) -> set[ColumnName]: "Transforms a list of property names to what columns are needed for that query" materialized_columns = get_materialized_columns(table) @@ -92,11 +93,11 @@ class FOSSColumnOptimizer: ) @cached_property - def group_types_to_query(self) -> Set[GroupTypeIndex]: + def group_types_to_query(self) -> set[GroupTypeIndex]: return set() @cached_property - def group_on_event_columns_to_query(self) -> Set[ColumnName]: + def group_on_event_columns_to_query(self) -> set[ColumnName]: return set() @cached_property @@ -171,7 +172,7 @@ class FOSSColumnOptimizer: counter += get_action_tables_and_properties(entity.get_action()) if ( - not isinstance(self.filter, (StickinessFilter, PropertiesTimelineFilter)) + not isinstance(self.filter, StickinessFilter | PropertiesTimelineFilter) and self.filter.correlation_type == FunnelCorrelationType.PROPERTIES and self.filter.correlation_property_names ): @@ -195,7 +196,7 @@ class FOSSColumnOptimizer: def entities_used_in_filter(self) -> Generator[Entity, None, None]: yield from self.filter.entities - yield from cast(List[Entity], self.filter.exclusions) + yield from cast(list[Entity], self.filter.exclusions) if isinstance(self.filter, RetentionFilter): yield self.filter.target_entity diff --git a/posthog/queries/event_query/event_query.py b/posthog/queries/event_query/event_query.py index 8737876d001..49d4565a943 100644 --- a/posthog/queries/event_query/event_query.py +++ b/posthog/queries/event_query/event_query.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union from posthog.clickhouse.materialized_columns import ColumnName from posthog.models import Cohort, Filter, Property @@ -38,9 +38,9 @@ class EventQuery(metaclass=ABCMeta): _should_join_persons = False _should_join_sessions = False _should_round_interval = False - _extra_fields: List[ColumnName] - _extra_event_properties: List[PropertyName] - _extra_person_fields: List[ColumnName] + _extra_fields: list[ColumnName] + _extra_event_properties: list[PropertyName] + _extra_person_fields: list[ColumnName] _person_id_alias: str _session_id_alias: Optional[str] @@ -60,9 +60,9 @@ class EventQuery(metaclass=ABCMeta): should_join_persons=False, should_join_sessions=False, # Extra events/person table columns to fetch since parent query needs them - extra_fields: Optional[List[ColumnName]] = None, - extra_event_properties: Optional[List[PropertyName]] = None, - extra_person_fields: Optional[List[ColumnName]] = None, + extra_fields: Optional[list[ColumnName]] = None, + extra_event_properties: Optional[list[PropertyName]] = None, + extra_person_fields: Optional[list[ColumnName]] = None, override_aggregate_users_by_distinct_id: Optional[bool] = None, person_on_events_mode: PersonsOnEventsMode = PersonsOnEventsMode.disabled, **kwargs, @@ -79,7 +79,7 @@ class EventQuery(metaclass=ABCMeta): self._extra_event_properties = extra_event_properties self._column_optimizer = ColumnOptimizer(self._filter, self._team_id) self._extra_person_fields = extra_person_fields - self.params: Dict[str, Any] = { + self.params: dict[str, Any] = { "team_id": self._team_id, "timezone": team.timezone, } @@ -118,7 +118,7 @@ class EventQuery(metaclass=ABCMeta): self._person_id_alias = self._get_person_id_alias(person_on_events_mode) @abstractmethod - def 
get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: pass @abstractmethod @@ -206,7 +206,7 @@ class EventQuery(metaclass=ABCMeta): extra_fields=self._extra_person_fields, ) - def _get_person_query(self) -> Tuple[str, Dict]: + def _get_person_query(self) -> tuple[str, dict]: if self._should_join_persons: person_query, params = self._person_query.get_query() return ( @@ -219,7 +219,7 @@ class EventQuery(metaclass=ABCMeta): else: return "", {} - def _get_groups_query(self) -> Tuple[str, Dict]: + def _get_groups_query(self) -> tuple[str, dict]: return "", {} @cached_property @@ -232,7 +232,7 @@ class EventQuery(metaclass=ABCMeta): session_id_alias=self._session_id_alias, ) - def _get_sessions_query(self) -> Tuple[str, Dict]: + def _get_sessions_query(self) -> tuple[str, dict]: if self._should_join_sessions: session_query, session_params = self._sessions_query.get_query() @@ -246,7 +246,7 @@ class EventQuery(metaclass=ABCMeta): ) return "", {} - def _get_date_filter(self) -> Tuple[str, Dict]: + def _get_date_filter(self) -> tuple[str, dict]: date_params = {} query_date_range = QueryDateRange( filter=self._filter, team=self._team, should_round=self._should_round_interval @@ -270,7 +270,7 @@ class EventQuery(metaclass=ABCMeta): person_id_joined_alias="person_id", prepend="global", allow_denormalized_props=True, - ) -> Tuple[str, Dict]: + ) -> tuple[str, dict]: if not prop_group: return "", {} diff --git a/posthog/queries/foss_cohort_query.py b/posthog/queries/foss_cohort_query.py index 352fc19ee13..847f6737c9f 100644 --- a/posthog/queries/foss_cohort_query.py +++ b/posthog/queries/foss_cohort_query.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast from zoneinfo import ZoneInfo from posthog.clickhouse.materialized_columns import ColumnName @@ -26,8 +26,8 @@ from posthog.queries.util import PersonPropertiesMode from posthog.schema import PersonsOnEventsMode from posthog.utils import relative_date_parse -Relative_Date = Tuple[int, OperatorInterval] -Event = Tuple[str, Union[str, int]] +Relative_Date = tuple[int, OperatorInterval] +Event = tuple[str, Union[str, int]] INTERVAL_TO_SECONDS = { @@ -40,7 +40,7 @@ INTERVAL_TO_SECONDS = { } -def relative_date_to_seconds(date: Tuple[Optional[int], Union[OperatorInterval, None]]): +def relative_date_to_seconds(date: tuple[Optional[int], Union[OperatorInterval, None]]): if date[0] is None or date[1] is None: raise ValueError("Time value and time interval must be specified") @@ -66,7 +66,7 @@ def parse_and_validate_positive_integer(value: Optional[int], value_name: str) - return parsed_value -def validate_entity(possible_event: Tuple[Optional[str], Optional[Union[int, str]]]) -> Event: +def validate_entity(possible_event: tuple[Optional[str], Optional[Union[int, str]]]) -> Event: event_type = possible_event[0] event_val = possible_event[1] if event_type is None or event_val is None: @@ -83,7 +83,7 @@ def relative_date_is_greater(date_1: Relative_Date, date_2: Relative_Date) -> bo return relative_date_to_seconds(date_1) > relative_date_to_seconds(date_2) -def convert_to_entity_params(events: List[Event]) -> Tuple[List, List]: +def convert_to_entity_params(events: list[Event]) -> tuple[list, list]: res_events = [] res_actions = [] @@ -124,8 +124,8 @@ class FOSSCohortQuery(EventQuery): BEHAVIOR_QUERY_ALIAS = "behavior_query" FUNNEL_QUERY_ALIAS = "funnel_query" SEQUENCE_FIELD_ALIAS = "steps" - _fields: 
List[str] - _events: List[str] + _fields: list[str] + _events: list[str] _earliest_time_for_event_query: Optional[Relative_Date] _restrict_event_query_by_time: bool @@ -139,9 +139,9 @@ class FOSSCohortQuery(EventQuery): should_join_distinct_ids=False, should_join_persons=False, # Extra events/person table columns to fetch since parent query needs them - extra_fields: Optional[List[ColumnName]] = None, - extra_event_properties: Optional[List[PropertyName]] = None, - extra_person_fields: Optional[List[ColumnName]] = None, + extra_fields: Optional[list[ColumnName]] = None, + extra_event_properties: Optional[list[PropertyName]] = None, + extra_person_fields: Optional[list[ColumnName]] = None, override_aggregate_users_by_distinct_id: Optional[bool] = None, **kwargs, ) -> None: @@ -187,14 +187,14 @@ class FOSSCohortQuery(EventQuery): if not negate_group: return PropertyGroup( type=property_group.type, - values=[_unwrap(v) for v in cast(List[PropertyGroup], property_group.values)], + values=[_unwrap(v) for v in cast(list[PropertyGroup], property_group.values)], ) else: return PropertyGroup( type=PropertyOperatorType.AND if property_group.type == PropertyOperatorType.OR else PropertyOperatorType.OR, - values=[_unwrap(v, True) for v in cast(List[PropertyGroup], property_group.values)], + values=[_unwrap(v, True) for v in cast(list[PropertyGroup], property_group.values)], ) elif isinstance(property_group.values[0], Property): @@ -202,7 +202,7 @@ class FOSSCohortQuery(EventQuery): # if any single one is a cohort property, unwrap it into a property group # which implies converting everything else in the list into a property group too - new_property_group_list: List[PropertyGroup] = [] + new_property_group_list: list[PropertyGroup] = [] for prop in property_group.values: prop = cast(Property, prop) current_negation = prop.negation or False @@ -258,7 +258,7 @@ class FOSSCohortQuery(EventQuery): return filter.shallow_clone({"properties": new_props.to_dict()}) # Implemented in /ee - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: if not self._outer_property_groups: # everything is pushed down, no behavioral stuff to do # thus, use personQuery directly @@ -294,7 +294,7 @@ class FOSSCohortQuery(EventQuery): return final_query, self.params - def _build_sources(self, subq: List[Tuple[str, str]]) -> Tuple[str, str]: + def _build_sources(self, subq: list[tuple[str, str]]) -> tuple[str, str]: q = "" filtered_queries = [(q, alias) for (q, alias) in subq if q and len(q)] @@ -325,7 +325,7 @@ class FOSSCohortQuery(EventQuery): return q, fields - def _get_behavior_subquery(self) -> Tuple[str, Dict[str, Any], str]: + def _get_behavior_subquery(self) -> tuple[str, dict[str, Any], str]: # # Get the subquery for the cohort query. 
# @@ -371,7 +371,7 @@ class FOSSCohortQuery(EventQuery): return query, params, self.BEHAVIOR_QUERY_ALIAS - def _get_persons_query(self, prepend: str = "") -> Tuple[str, Dict[str, Any], str]: + def _get_persons_query(self, prepend: str = "") -> tuple[str, dict[str, Any], str]: query, params = "", {} if self._should_join_persons: person_query, person_params = self._person_query.get_query(prepend=prepend) @@ -387,9 +387,9 @@ class FOSSCohortQuery(EventQuery): prop.type for prop in getattr(self._outer_property_groups, "flat", []) ] and "static-cohort" not in [prop.type for prop in getattr(self._outer_property_groups, "flat", [])] - def _get_date_condition(self) -> Tuple[str, Dict[str, Any]]: + def _get_date_condition(self) -> tuple[str, dict[str, Any]]: date_query = "" - date_params: Dict[str, Any] = {} + date_params: dict[str, Any] = {} earliest_time_param = f"earliest_time_{self._cohort_pk}" if self._earliest_time_for_event_query and self._restrict_event_query_by_time: @@ -404,7 +404,7 @@ class FOSSCohortQuery(EventQuery): elif relative_date_is_greater(relative_date, self._earliest_time_for_event_query): self._earliest_time_for_event_query = relative_date - def _get_conditions(self) -> Tuple[str, Dict[str, Any]]: + def _get_conditions(self) -> tuple[str, dict[str, Any]]: def build_conditions(prop: Optional[Union[PropertyGroup, Property]], prepend="level", num=0): if not prop: return "", {} @@ -426,9 +426,9 @@ class FOSSCohortQuery(EventQuery): return f"AND ({conditions})" if conditions else "", params # Implemented in /ee - def _get_condition_for_property(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def _get_condition_for_property(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: res: str = "" - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if prop.type == "behavioral": if prop.value == "performed_event": @@ -446,7 +446,7 @@ class FOSSCohortQuery(EventQuery): return res, params - def get_person_condition(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_person_condition(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: if self._outer_property_groups and len(self._outer_property_groups.flat): return prop_filter_json_extract( prop, @@ -459,7 +459,7 @@ class FOSSCohortQuery(EventQuery): else: return "", {} - def get_static_cohort_condition(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_static_cohort_condition(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: # If we reach this stage, it means there are no cyclic dependencies # They should've been caught by API update validation # and if not there, `simplifyFilter` would've failed @@ -467,8 +467,8 @@ class FOSSCohortQuery(EventQuery): query, params = format_static_cohort_query(cohort, idx, prepend) return f"id {'NOT' if prop.negation else ''} IN ({query})", params - def _get_entity_event_filters(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: - params: Dict[str, Any] = {} + def _get_entity_event_filters(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: + params: dict[str, Any] = {} if prop.event_filters: prop_query, prop_params = parse_prop_grouped_clauses( @@ -491,7 +491,7 @@ class FOSSCohortQuery(EventQuery): # one extra day for any partial days return (delta.days + 1, "day") - def _get_entity_datetime_filters(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + 
def _get_entity_datetime_filters(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: if prop.explicit_datetime: # Explicit datetime filter, can be a relative or absolute date, follows same convention # as all analytics datetime filters @@ -512,7 +512,7 @@ class FOSSCohortQuery(EventQuery): return f"timestamp > now() - INTERVAL %({date_param})s {date_interval}", {f"{date_param}": date_value} - def get_performed_event_condition(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_performed_event_condition(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: event = (prop.event_type, prop.key) column_name = f"performed_event_condition_{prepend}_{idx}" @@ -530,7 +530,7 @@ class FOSSCohortQuery(EventQuery): **entity_filters_params, } - def get_performed_event_multiple(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_performed_event_multiple(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: event = (prop.event_type, prop.key) column_name = f"performed_event_multiple_condition_{prepend}_{idx}" @@ -591,12 +591,12 @@ class FOSSCohortQuery(EventQuery): def _get_entity( self, - event: Tuple[Optional[str], Optional[Union[int, str]]], + event: tuple[Optional[str], Optional[Union[int, str]]], prepend: str, idx: int, - ) -> Tuple[str, Dict[str, Any]]: + ) -> tuple[str, dict[str, Any]]: res: str = "" - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if event[0] is None or event[1] is None: raise ValueError("Event type and key must be specified") diff --git a/posthog/queries/funnels/base.py b/posthog/queries/funnels/base.py index c4258c6f6eb..a6de14b050c 100644 --- a/posthog/queries/funnels/base.py +++ b/posthog/queries/funnels/base.py @@ -1,7 +1,7 @@ import urllib.parse import uuid from abc import ABC -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast from rest_framework.exceptions import ValidationError @@ -44,9 +44,9 @@ class ClickhouseFunnelBase(ABC): _team: Team _include_timestamp: Optional[bool] _include_preceding_timestamp: Optional[bool] - _extra_event_fields: List[ColumnName] - _extra_event_properties: List[PropertyName] - _include_properties: List[str] + _extra_event_fields: list[ColumnName] + _extra_event_properties: list[PropertyName] + _include_properties: list[str] def __init__( self, @@ -55,7 +55,7 @@ class ClickhouseFunnelBase(ABC): include_timestamp: Optional[bool] = None, include_preceding_timestamp: Optional[bool] = None, base_uri: str = "/", - include_properties: Optional[List[str]] = None, + include_properties: Optional[list[str]] = None, ) -> None: self._filter = filter self._team = team @@ -92,8 +92,8 @@ class ClickhouseFunnelBase(ABC): self.params.update({OFFSET: self._filter.offset}) - self._extra_event_fields: List[ColumnName] = [] - self._extra_event_properties: List[PropertyName] = [] + self._extra_event_fields: list[ColumnName] = [] + self._extra_event_properties: list[PropertyName] = [] if self._filter.include_recordings: self._extra_event_fields = ["uuid"] self._extra_event_properties = ["$session_id", "$window_id"] @@ -111,9 +111,9 @@ class ClickhouseFunnelBase(ABC): self, step: Entity, count: int, - people: Optional[List[uuid.UUID]] = None, + people: Optional[list[uuid.UUID]] = None, sampling_factor: Optional[float] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: if step.type == TREND_FILTER_TYPE_ACTIONS: name = step.get_action().name else: @@ -135,7 
+135,7 @@ class ClickhouseFunnelBase(ABC): def _update_filters(self): # format default dates - data: Dict[str, Any] = {} + data: dict[str, Any] = {} if not self._filter._date_from: data.update({"date_from": relative_date_parse("-7d", self._team.timezone_info)}) @@ -153,7 +153,7 @@ class ClickhouseFunnelBase(ABC): # # Once multi property breakdown is implemented in Trends this becomes unnecessary - if isinstance(self._filter.breakdowns, List) and self._filter.breakdown_type in [ + if isinstance(self._filter.breakdowns, list) and self._filter.breakdown_type in [ "person", "event", "hogql", @@ -167,7 +167,7 @@ class ClickhouseFunnelBase(ABC): "hogql", None, ]: - boxed_breakdown: List[Union[str, int]] = box_value(self._filter.breakdown) + boxed_breakdown: list[Union[str, int]] = box_value(self._filter.breakdown) data.update({"breakdown": boxed_breakdown}) for exclusion in self._filter.exclusions: @@ -270,7 +270,7 @@ class ClickhouseFunnelBase(ABC): else: return self._format_single_funnel(results[0]) - def _exec_query(self) -> List[Tuple]: + def _exec_query(self) -> list[tuple]: self._filter.team = self._team query = self.get_query() return insight_sync_execute( @@ -289,7 +289,7 @@ class ClickhouseFunnelBase(ABC): else: return "" - def _get_timestamp_selects(self) -> Tuple[str, str]: + def _get_timestamp_selects(self) -> tuple[str, str]: """ Returns timestamp selectors for the target step and optionally the preceding step. In the former case, always returns the timestamp for the first and last step as well. @@ -328,7 +328,7 @@ class ClickhouseFunnelBase(ABC): return "", "" def _get_step_times(self, max_steps: int): - conditions: List[str] = [] + conditions: list[str] = [] for i in range(1, max_steps): conditions.append( f"if(isNotNull(latest_{i}) AND latest_{i} <= latest_{i-1} + INTERVAL {self._filter.funnel_window_interval} {self._filter.funnel_window_interval_unit_ch()}, " @@ -339,7 +339,7 @@ class ClickhouseFunnelBase(ABC): return f", {formatted}" if formatted else "" def _get_partition_cols(self, level_index: int, max_steps: int): - cols: List[str] = [] + cols: list[str] = [] for i in range(0, max_steps): cols.append(f"step_{i}") if i < level_index: @@ -397,7 +397,7 @@ class ClickhouseFunnelBase(ABC): if curr_index == 1: return "1" - conditions: List[str] = [] + conditions: list[str] = [] for i in range(1, curr_index): duplicate_event = ( True @@ -444,7 +444,7 @@ class ClickhouseFunnelBase(ABC): else: steps_conditions = self._get_steps_conditions(length=len(entities_to_use)) - all_step_cols: List[str] = [] + all_step_cols: list[str] = [] for index, entity in enumerate(entities_to_use): step_cols = self._get_step_col(entity, index, entity_name) all_step_cols.extend(step_cols) @@ -521,7 +521,7 @@ class ClickhouseFunnelBase(ABC): """ def _get_steps_conditions(self, length: int) -> str: - step_conditions: List[str] = [] + step_conditions: list[str] = [] for index in range(length): step_conditions.append(f"step_{index} = 1") @@ -531,10 +531,10 @@ class ClickhouseFunnelBase(ABC): return " OR ".join(step_conditions) - def _get_step_col(self, entity: Entity, index: int, entity_name: str, step_prefix: str = "") -> List[str]: + def _get_step_col(self, entity: Entity, index: int, entity_name: str, step_prefix: str = "") -> list[str]: # step prefix is used to distinguish actual steps, and exclusion steps # without the prefix, we get the same parameter binding for both, which borks things up - step_cols: List[str] = [] + step_cols: list[str] = [] condition = self._build_step_query(entity, index, 
entity_name, step_prefix) step_cols.append(f"if({condition}, 1, 0) as {step_prefix}step_{index}") step_cols.append(f"if({step_prefix}step_{index} = 1, timestamp, null) as {step_prefix}latest_{index}") @@ -637,7 +637,7 @@ class ClickhouseFunnelBase(ABC): return "" def _get_count_columns(self, max_steps: int): - cols: List[str] = [] + cols: list[str] = [] for i in range(max_steps): cols.append(f"countIf(steps = {i + 1}) step_{i + 1}") @@ -680,7 +680,7 @@ class ClickhouseFunnelBase(ABC): return "" def _get_step_time_avgs(self, max_steps: int, inner_query: bool = False): - conditions: List[str] = [] + conditions: list[str] = [] for i in range(1, max_steps): conditions.append( f"avg(step_{i}_conversion_time) step_{i}_average_conversion_time_inner" @@ -692,7 +692,7 @@ class ClickhouseFunnelBase(ABC): return f", {formatted}" if formatted else "" def _get_step_time_median(self, max_steps: int, inner_query: bool = False): - conditions: List[str] = [] + conditions: list[str] = [] for i in range(1, max_steps): conditions.append( f"median(step_{i}_conversion_time) step_{i}_median_conversion_time_inner" @@ -720,9 +720,9 @@ class ClickhouseFunnelBase(ABC): def get_step_counts_without_aggregation_query(self) -> str: raise NotImplementedError() - def _get_breakdown_select_prop(self) -> Tuple[str, Dict[str, Any]]: + def _get_breakdown_select_prop(self) -> tuple[str, dict[str, Any]]: basic_prop_selector = "" - basic_prop_params: Dict[str, Any] = {} + basic_prop_params: dict[str, Any] = {} if not self._filter.breakdown: return basic_prop_selector, basic_prop_params @@ -837,7 +837,7 @@ class ClickhouseFunnelBase(ABC): ON events.distinct_id = cohort_join.distinct_id """ - def _get_breakdown_conditions(self) -> Optional[List[str]]: + def _get_breakdown_conditions(self) -> Optional[list[str]]: """ For people, pagination sets the offset param, which is common across filters and gives us the wrong breakdown values here, so we override it. diff --git a/posthog/queries/funnels/funnel.py b/posthog/queries/funnels/funnel.py index e1ac23f00d6..c72a7f1608e 100644 --- a/posthog/queries/funnels/funnel.py +++ b/posthog/queries/funnels/funnel.py @@ -1,4 +1,4 @@ -from typing import List, cast +from typing import cast from posthog.queries.funnels.base import ClickhouseFunnelBase @@ -74,7 +74,7 @@ class ClickhouseFunnel(ClickhouseFunnelBase): """ def _get_comparison_at_step(self, index: int, level_index: int): - or_statements: List[str] = [] + or_statements: list[str] = [] for i in range(level_index, index + 1): or_statements.append(f"latest_{i} < latest_{level_index - 1}") @@ -86,7 +86,7 @@ class ClickhouseFunnel(ClickhouseFunnelBase): level_index: The current smallest comparison step. Everything before level index is already at the minimum ordered timestamps. 
""" - cols: List[str] = [] + cols: list[str] = [] for i in range(0, max_steps): cols.append(f"step_{i}") if i < level_index: diff --git a/posthog/queries/funnels/funnel_event_query.py b/posthog/queries/funnels/funnel_event_query.py index 2c8ad72524f..9f0ad134257 100644 --- a/posthog/queries/funnels/funnel_event_query.py +++ b/posthog/queries/funnels/funnel_event_query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Set, Tuple, Union +from typing import Any, Union from posthog.constants import TREND_FILTER_TYPE_ACTIONS from posthog.hogql.hogql import translate_hogql @@ -17,7 +17,7 @@ class FunnelEventQuery(EventQuery): entities=None, entity_name="events", skip_entity_filter=False, - ) -> Tuple[str, Dict[str, Any]]: + ) -> tuple[str, dict[str, Any]]: # Aggregating by group if self._filter.aggregation_group_type_index is not None: aggregation_target = get_aggregation_target_field( @@ -81,7 +81,7 @@ class FunnelEventQuery(EventQuery): if skip_entity_filter: entity_query = "" - entity_params: Dict[str, Any] = {} + entity_params: dict[str, Any] = {} else: entity_query, entity_params = self._get_entity_query(entities, entity_name) @@ -145,8 +145,8 @@ class FunnelEventQuery(EventQuery): if self._person_on_events_mode != PersonsOnEventsMode.disabled: self._should_join_persons = False - def _get_entity_query(self, entities=None, entity_name="events") -> Tuple[str, Dict[str, Any]]: - events: Set[Union[int, str, None]] = set() + def _get_entity_query(self, entities=None, entity_name="events") -> tuple[str, dict[str, Any]]: + events: set[Union[int, str, None]] = set() entities_to_use = entities or self._filter.entities for entity in entities_to_use: diff --git a/posthog/queries/funnels/funnel_persons.py b/posthog/queries/funnels/funnel_persons.py index 5cebef5fb7d..c221727866e 100644 --- a/posthog/queries/funnels/funnel_persons.py +++ b/posthog/queries/funnels/funnel_persons.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog.models.filters.filter import Filter from posthog.models.filters.mixins.utils import cached_property @@ -18,7 +18,7 @@ class ClickhouseFunnelActors(ClickhouseFunnel, ActorBaseQuery): def actor_query( self, limit_actors: Optional[bool] = True, - extra_fields: Optional[List[str]] = None, + extra_fields: Optional[list[str]] = None, ): extra_fields_string = ", ".join([self._get_timestamp_outer_select()] + (extra_fields or [])) return ( diff --git a/posthog/queries/funnels/funnel_strict.py b/posthog/queries/funnels/funnel_strict.py index 38b5d3a4c6a..cb9f97d1918 100644 --- a/posthog/queries/funnels/funnel_strict.py +++ b/posthog/queries/funnels/funnel_strict.py @@ -1,5 +1,3 @@ -from typing import List - from posthog.queries.funnels.base import ClickhouseFunnelBase @@ -57,7 +55,7 @@ class ClickhouseFunnelStrict(ClickhouseFunnelBase): return formatted_query def _get_partition_cols(self, level_index: int, max_steps: int): - cols: List[str] = [] + cols: list[str] = [] for i in range(0, max_steps): cols.append(f"step_{i}") if i < level_index: diff --git a/posthog/queries/funnels/funnel_strict_persons.py b/posthog/queries/funnels/funnel_strict_persons.py index cca6f8e598d..2ad13822f54 100644 --- a/posthog/queries/funnels/funnel_strict_persons.py +++ b/posthog/queries/funnels/funnel_strict_persons.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog.models.filters.filter import Filter from posthog.models.filters.mixins.utils import cached_property @@ -18,7 +18,7 @@ class 
ClickhouseFunnelStrictActors(ClickhouseFunnelStrict, ActorBaseQuery): def actor_query( self, limit_actors: Optional[bool] = True, - extra_fields: Optional[List[str]] = None, + extra_fields: Optional[list[str]] = None, ): extra_fields_string = ", ".join([self._get_timestamp_outer_select()] + (extra_fields or [])) return ( diff --git a/posthog/queries/funnels/funnel_trends.py b/posthog/queries/funnels/funnel_trends.py index d67b24ae78b..cb8ecbe7c82 100644 --- a/posthog/queries/funnels/funnel_trends.py +++ b/posthog/queries/funnels/funnel_trends.py @@ -1,6 +1,6 @@ from datetime import datetime from itertools import groupby -from typing import List, Optional, Tuple +from typing import Optional from posthog.models.cohort import Cohort from posthog.models.filters.filter import Filter @@ -147,7 +147,7 @@ class ClickhouseFunnelTrends(ClickhouseFunnelBase): return query - def get_steps_reached_conditions(self) -> Tuple[str, str, str]: + def get_steps_reached_conditions(self) -> tuple[str, str, str]: # How many steps must have been done to count for the denominator of a funnel trends data point from_step = self._filter.funnel_from_step or 0 # How many steps must have been done to count for the numerator of a funnel trends data point @@ -180,7 +180,7 @@ class ClickhouseFunnelTrends(ClickhouseFunnelBase): if breakdown_clause: if isinstance(period_row[-1], str) or ( - isinstance(period_row[-1], List) and all(isinstance(item, str) for item in period_row[-1]) + isinstance(period_row[-1], list) and all(isinstance(item, str) for item in period_row[-1]) ): serialized_result.update({"breakdown_value": (period_row[-1])}) else: diff --git a/posthog/queries/funnels/funnel_unordered.py b/posthog/queries/funnels/funnel_unordered.py index ac3a6d939b0..ee984b9462a 100644 --- a/posthog/queries/funnels/funnel_unordered.py +++ b/posthog/queries/funnels/funnel_unordered.py @@ -1,5 +1,5 @@ import uuid -from typing import Any, Dict, List, Optional, cast +from typing import Any, Optional, cast from rest_framework.exceptions import ValidationError @@ -40,9 +40,9 @@ class ClickhouseFunnelUnordered(ClickhouseFunnelBase): self, step: Entity, count: int, - people: Optional[List[uuid.UUID]] = None, + people: Optional[list[uuid.UUID]] = None, sampling_factor: Optional[float] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return { "action_id": None, "name": f"Completed {step.index+1} step{'s' if step.index != 0 else ''}", @@ -119,7 +119,7 @@ class ClickhouseFunnelUnordered(ClickhouseFunnelBase): return " UNION ALL ".join(union_queries) def _get_step_times(self, max_steps: int): - conditions: List[str] = [] + conditions: list[str] = [] conversion_times_elements = [] for i in range(max_steps): @@ -146,7 +146,7 @@ class ClickhouseFunnelUnordered(ClickhouseFunnelBase): conditions.append(f"arraySort([{','.join(event_times_elements)}]) as event_times") # replacement of latest_i for whatever query part requires it, just like conversion_times - basic_conditions: List[str] = [] + basic_conditions: list[str] = [] for i in range(1, max_steps): basic_conditions.append( f"if(latest_0 < latest_{i} AND latest_{i} <= latest_0 + INTERVAL {self._filter.funnel_window_interval} {self._filter.funnel_window_interval_unit_ch()}, 1, 0)" diff --git a/posthog/queries/funnels/funnel_unordered_persons.py b/posthog/queries/funnels/funnel_unordered_persons.py index 334798c9902..fc1e953bfb5 100644 --- a/posthog/queries/funnels/funnel_unordered_persons.py +++ b/posthog/queries/funnels/funnel_unordered_persons.py @@ -1,4 +1,4 @@ -from typing import List, 
Optional +from typing import Optional from posthog.models.filters.filter import Filter from posthog.models.filters.mixins.utils import cached_property @@ -25,7 +25,7 @@ class ClickhouseFunnelUnorderedActors(ClickhouseFunnelUnordered, ActorBaseQuery) def actor_query( self, limit_actors: Optional[bool] = True, - extra_fields: Optional[List[str]] = None, + extra_fields: Optional[list[str]] = None, ): extra_fields_string = ", ".join([self._get_timestamp_outer_select()] + (extra_fields or [])) return ( diff --git a/posthog/queries/funnels/test/breakdown_cases.py b/posthog/queries/funnels/test/breakdown_cases.py index b38384c745e..2bc977c974a 100644 --- a/posthog/queries/funnels/test/breakdown_cases.py +++ b/posthog/queries/funnels/test/breakdown_cases.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from datetime import datetime from string import ascii_lowercase -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from posthog.constants import INSIGHT_FUNNELS from posthog.models.cohort import Cohort @@ -20,7 +20,7 @@ from posthog.test.test_journeys import journeys_for class FunnelStepResult: name: str count: int - breakdown: Union[List[str], str] + breakdown: Union[list[str], str] average_conversion_time: Optional[float] = None median_conversion_time: Optional[float] = None type: Literal["events", "actions"] = "events" @@ -35,8 +35,8 @@ def funnel_breakdown_test_factory(Funnel, FunnelPerson, _create_event, _create_a return [val["id"] for val in serialized_result] - def _assert_funnel_breakdown_result_is_correct(self, result, steps: List[FunnelStepResult]): - def funnel_result(step: FunnelStepResult, order: int) -> Dict[str, Any]: + def _assert_funnel_breakdown_result_is_correct(self, result, steps: list[FunnelStepResult]): + def funnel_result(step: FunnelStepResult, order: int) -> dict[str, Any]: return { "action_id": step.name if step.type == "events" else step.action_id, "name": step.name, @@ -2646,11 +2646,11 @@ def funnel_breakdown_test_factory(Funnel, FunnelPerson, _create_event, _create_a return TestFunnelBreakdown -def sort_breakdown_funnel_results(results: List[Dict[int, Any]]): +def sort_breakdown_funnel_results(results: list[dict[int, Any]]): return sorted(results, key=lambda r: r[0]["breakdown_value"]) -def assert_funnel_results_equal(left: List[Dict[str, Any]], right: List[Dict[str, Any]]): +def assert_funnel_results_equal(left: list[dict[str, Any]], right: list[dict[str, Any]]): """ Helper to be able to compare two funnel results, but exclude people urls from the comparison, as these include: @@ -2660,7 +2660,7 @@ def assert_funnel_results_equal(left: List[Dict[str, Any]], right: List[Dict[str 2. 
contain timestamps which are not stable across runs """ - def _filter(steps: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _filter(steps: list[dict[str, Any]]) -> list[dict[str, Any]]: return [{**step, "converted_people_url": None, "dropped_people_url": None} for step in steps] assert len(left) == len(right) diff --git a/posthog/queries/funnels/test/test_breakdowns_by_current_url.py b/posthog/queries/funnels/test/test_breakdowns_by_current_url.py index 7994b195fca..800cd9f46dc 100644 --- a/posthog/queries/funnels/test/test_breakdowns_by_current_url.py +++ b/posthog/queries/funnels/test/test_breakdowns_by_current_url.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Dict, Optional +from typing import Optional from posthog.models import Filter from posthog.queries.funnels import ClickhouseFunnel @@ -115,7 +115,7 @@ class TestBreakdownsByCurrentURL(ClickhouseTestMixin, APIBaseTest): journeys_for(journey, team=self.team, create_people=True) - def _run(self, extra: Optional[Dict] = None, events_extra: Optional[Dict] = None): + def _run(self, extra: Optional[dict] = None, events_extra: Optional[dict] = None): if events_extra is None: events_extra = {} if extra is None: diff --git a/posthog/queries/funnels/utils.py b/posthog/queries/funnels/utils.py index 68f93c2d454..b2c0df300ce 100644 --- a/posthog/queries/funnels/utils.py +++ b/posthog/queries/funnels/utils.py @@ -1,11 +1,9 @@ -from typing import Type - from posthog.constants import FunnelOrderType from posthog.models.filters import Filter from posthog.queries.funnels import ClickhouseFunnelBase -def get_funnel_order_class(filter: Filter) -> Type[ClickhouseFunnelBase]: +def get_funnel_order_class(filter: Filter) -> type[ClickhouseFunnelBase]: from posthog.queries.funnels import ( ClickhouseFunnel, ClickhouseFunnelStrict, diff --git a/posthog/queries/groups_join_query/groups_join_query.py b/posthog/queries/groups_join_query/groups_join_query.py index 2cc62849cac..6499d39ce1e 100644 --- a/posthog/queries/groups_join_query/groups_join_query.py +++ b/posthog/queries/groups_join_query/groups_join_query.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Union from posthog.models import Filter from posthog.models.filters.path_filter import PathFilter @@ -31,5 +31,5 @@ class GroupsJoinQuery: self._join_key = join_key self._person_on_events_mode = person_on_events_mode - def get_join_query(self) -> Tuple[str, Dict]: + def get_join_query(self) -> tuple[str, dict]: return "", {} diff --git a/posthog/queries/paths/paths.py b/posthog/queries/paths/paths.py index 6a98857e392..21438ee6ea7 100644 --- a/posthog/queries/paths/paths.py +++ b/posthog/queries/paths/paths.py @@ -1,6 +1,6 @@ import dataclasses from collections import defaultdict -from typing import Dict, List, Literal, Optional, Tuple, Union, cast +from typing import Literal, Optional, Union, cast from rest_framework.exceptions import ValidationError @@ -35,8 +35,8 @@ class Paths: _filter: PathFilter _funnel_filter: Optional[Filter] _team: Team - _extra_event_fields: List[ColumnName] - _extra_event_properties: List[PropertyName] + _extra_event_fields: list[ColumnName] + _extra_event_properties: list[PropertyName] def __init__(self, filter: PathFilter, team: Team, funnel_filter: Optional[Filter] = None) -> None: self._filter = filter @@ -50,8 +50,8 @@ class Paths: } self._funnel_filter = funnel_filter - self._extra_event_fields: List[ColumnName] = [] - self._extra_event_properties: List[PropertyName] = [] + 
self._extra_event_fields: list[ColumnName] = [] + self._extra_event_properties: list[PropertyName] = [] if self._filter.include_recordings: self._extra_event_fields = ["uuid", "timestamp"] self._extra_event_properties = ["$session_id", "$window_id"] @@ -93,7 +93,7 @@ class Paths: ) return resp - def _exec_query(self) -> List[Tuple]: + def _exec_query(self) -> list[tuple]: query = self.get_query() return insight_sync_execute( query, @@ -225,7 +225,7 @@ class Paths: return "", {} # Implemented in /ee - def get_edge_weight_clause(self) -> Tuple[str, Dict]: + def get_edge_weight_clause(self) -> tuple[str, dict]: return "", {} # Implemented in /ee @@ -240,8 +240,8 @@ class Paths: return "arraySplit(x -> if(x.3 < %(session_time_threshold)s, 0, 1), paths_tuple)" # Implemented in /ee - def get_target_clause(self) -> Tuple[str, Dict]: - params: Dict[str, Union[str, None]] = { + def get_target_clause(self) -> tuple[str, dict]: + params: dict[str, Union[str, None]] = { "target_point": None, "secondary_target_point": None, } @@ -276,7 +276,7 @@ class Paths: return "arraySlice" # Implemented in /ee - def get_filtered_path_ordering(self) -> Tuple[str, ...]: + def get_filtered_path_ordering(self) -> tuple[str, ...]: fields_to_include = ["filtered_path", "filtered_timings"] + [ f"filtered_{field}s" for field in self.extra_event_fields_and_properties ] diff --git a/posthog/queries/paths/paths_actors.py b/posthog/queries/paths/paths_actors.py index e39a01dfee3..ec739271795 100644 --- a/posthog/queries/paths/paths_actors.py +++ b/posthog/queries/paths/paths_actors.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple, cast +from typing import Optional, cast from posthog.models.filters.filter import Filter from posthog.queries.actor_base_query import ActorBaseQuery @@ -27,7 +27,7 @@ class PathsActors(Paths, ActorBaseQuery): # type: ignore QUERY_TYPE = "paths" - def actor_query(self, limit_actors: Optional[bool] = True) -> Tuple[str, Dict]: + def actor_query(self, limit_actors: Optional[bool] = True) -> tuple[str, dict]: paths_per_person_query = self.get_paths_per_person_query() person_path_filter = self.get_person_path_filter() paths_funnel_cte = "" diff --git a/posthog/queries/paths/paths_event_query.py b/posthog/queries/paths/paths_event_query.py index 61b032aa663..31241cea649 100644 --- a/posthog/queries/paths/paths_event_query.py +++ b/posthog/queries/paths/paths_event_query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple +from typing import Any from posthog.constants import ( FUNNEL_PATH_AFTER_STEP, @@ -21,7 +21,7 @@ class PathEventQuery(EventQuery): FUNNEL_PERSONS_ALIAS = "funnel_actors" _filter: PathFilter - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: funnel_paths_timestamp = "" funnel_paths_join = "" funnel_paths_filter = "" @@ -151,7 +151,7 @@ class PathEventQuery(EventQuery): if self._person_on_events_mode != PersonsOnEventsMode.disabled: self._should_join_persons = False - def _get_grouping_fields(self) -> Tuple[List[str], Dict[str, Any]]: + def _get_grouping_fields(self) -> tuple[list[str], dict[str, Any]]: _fields = [] params = {} @@ -188,8 +188,8 @@ class PathEventQuery(EventQuery): return _fields, params - def _get_event_query(self, deep_filtering: bool) -> Tuple[str, Dict[str, Any]]: - params: Dict[str, Any] = {} + def _get_event_query(self, deep_filtering: bool) -> tuple[str, dict[str, Any]]: + params: dict[str, Any] = {} conditions = [] or_conditions = [] diff --git a/posthog/queries/person_query.py 
b/posthog/queries/person_query.py index cffcce890c8..b785fcb7442 100644 --- a/posthog/queries/person_query.py +++ b/posthog/queries/person_query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Optional, Union from uuid import UUID from posthog.clickhouse.materialized_columns import ColumnName @@ -45,7 +45,7 @@ class PersonQuery: _filter: Union[Filter, PathFilter, RetentionFilter, StickinessFilter] _team_id: int _column_optimizer: ColumnOptimizer - _extra_fields: Set[ColumnName] + _extra_fields: set[ColumnName] _inner_person_properties: Optional[PropertyGroup] _cohort: Optional[Cohort] _include_distinct_ids: Optional[bool] = False @@ -58,10 +58,10 @@ class PersonQuery: cohort: Optional[Cohort] = None, *, entity: Optional[Entity] = None, - extra_fields: Optional[List[ColumnName]] = None, + extra_fields: Optional[list[ColumnName]] = None, # A sub-optimal version of the `cohort` parameter above, the difference being that # this supports multiple cohort filters, but is not as performant as the above. - cohort_filters: Optional[List[Property]] = None, + cohort_filters: Optional[list[Property]] = None, include_distinct_ids: Optional[bool] = False, ) -> None: self._filter = filter @@ -90,7 +90,7 @@ class PersonQuery: prepend: Optional[Union[str, int]] = None, paginate: bool = False, filter_future_persons: bool = False, - ) -> Tuple[str, Dict]: + ) -> tuple[str, dict]: prepend = str(prepend) if prepend is not None else "" fields = "id" + " ".join( @@ -175,7 +175,7 @@ class PersonQuery: ) @property - def fields(self) -> List[ColumnName]: + def fields(self) -> list[ColumnName]: "Returns person table fields this query exposes" return [alias for column_name, alias in self._get_fields()] @@ -194,7 +194,7 @@ class PersonQuery: def _uses_person_id(self, prop: Property) -> bool: return prop.type in ("person", "static-cohort", "precalculated-cohort") - def _get_fields(self) -> List[Tuple[str, str]]: + def _get_fields(self) -> list[tuple[str, str]]: # :TRICKY: Figure out what fields we want to expose - minimizing this set is good for performance. # We use the result from column_optimizer to figure out counts of all properties to be filtered and queried. # Here, we remove the ones only to be used for filtering. 
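The `_get_fields` comment above is the performance-sensitive idea: only columns that are actually read back get exposed, while columns used purely for filtering are dropped. A minimal sketch of that idea — not PostHog's implementation; `property_counts` and `filter_only` are illustrative names — written with the builtin generics this patch standardizes on:

from collections import Counter

def fields_to_expose(property_counts: Counter[str], filter_only: set[str]) -> list[tuple[str, str]]:
    # Keep only columns that are queried, not merely filtered on, and
    # return (column_name, alias) pairs, sorted so the generated SQL is stable.
    columns = set(property_counts) - filter_only
    return [(name, name) for name in sorted(columns)]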
@@ -207,7 +207,7 @@ class PersonQuery: return [(column_name, self.ALIASES.get(column_name, column_name)) for column_name in sorted(columns)] - def _get_person_filter_clauses(self, prepend: str = "") -> Tuple[str, str, Dict]: + def _get_person_filter_clauses(self, prepend: str = "") -> tuple[str, str, dict]: finalization_conditions, params = parse_prop_grouped_clauses( self._team_id, self._inner_person_properties, @@ -231,7 +231,7 @@ class PersonQuery: params.update(prefiltering_params) return prefiltering_conditions, finalization_conditions, params - def _get_fast_single_cohort_clause(self) -> Tuple[str, Dict]: + def _get_fast_single_cohort_clause(self) -> tuple[str, dict]: if self._cohort: cohort_table = ( GET_STATIC_COHORTPEOPLE_BY_COHORT_ID if self._cohort.is_static else GET_COHORTPEOPLE_BY_COHORT_ID @@ -252,10 +252,10 @@ class PersonQuery: else: return "", {} - def _get_multiple_cohorts_clause(self, prepend: str = "") -> Tuple[str, Dict]: + def _get_multiple_cohorts_clause(self, prepend: str = "") -> tuple[str, dict]: if self._cohort_filters: query = [] - params: Dict[str, Any] = {} + params: dict[str, Any] = {} # TODO: doesn't support non-calculated cohorts for index, property in enumerate(self._cohort_filters): @@ -274,7 +274,7 @@ class PersonQuery: else: return "", {} - def _get_limit_offset_clause(self) -> Tuple[str, Dict]: + def _get_limit_offset_clause(self) -> tuple[str, dict]: if not isinstance(self._filter, Filter): return "", {} @@ -295,7 +295,7 @@ class PersonQuery: return clause, params - def _get_search_clauses(self, prepend: str = "") -> Tuple[str, str, Dict]: + def _get_search_clauses(self, prepend: str = "") -> tuple[str, str, dict]: """ Return - respectively - the prefiltering search clause (not aggregated by is_deleted or version, which is great for memory usage), the final search clause (aggregated for true results, more expensive), and new params.
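The docstring above describes a two-phase search: a cheap prefiltering clause evaluated against raw person rows (before deduplicating by `is_deleted` and version), and an exact finalization clause applied after aggregation. A hedged sketch of that shape — the clause text and parameter naming here are invented for illustration, not the real SQL:

def get_search_clauses(search: str, prepend: str = "") -> tuple[str, str, dict]:
    # The same parameter feeds both clauses; only where each is evaluated differs.
    param = f"search_{prepend}"
    prefiltering = f"AND properties ILIKE %({param})s"                   # cheap, pre-aggregation
    finalization = f"AND argMax(properties, version) ILIKE %({param})s"  # exact, post-aggregation
    return prefiltering, finalization, {param: f"%{search}%"}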
@@ -365,7 +365,7 @@ class PersonQuery: return "", "", {} - def _get_distinct_id_clause(self) -> Tuple[str, Dict]: + def _get_distinct_id_clause(self) -> tuple[str, dict]: if not isinstance(self._filter, Filter): return "", {} @@ -378,7 +378,7 @@ class PersonQuery: return distinct_id_clause, {"distinct_id_filter": self._filter.distinct_id} return "", {} - def _add_distinct_id_join_if_needed(self, query: str, params: Dict[Any, Any]) -> Tuple[str, Dict[Any, Any]]: + def _add_distinct_id_join_if_needed(self, query: str, params: dict[Any, Any]) -> tuple[str, dict[Any, Any]]: if not self._include_distinct_ids: return query, params return ( @@ -395,7 +395,7 @@ class PersonQuery: params, ) - def _get_email_clause(self) -> Tuple[str, Dict]: + def _get_email_clause(self) -> tuple[str, dict]: if not isinstance(self._filter, Filter): return "", {} @@ -407,7 +407,7 @@ class PersonQuery: ) return "", {} - def _get_updated_after_clause(self) -> Tuple[str, Dict]: + def _get_updated_after_clause(self) -> tuple[str, dict]: if not isinstance(self._filter, Filter): return "", {} diff --git a/posthog/queries/properties_timeline/properties_timeline.py b/posthog/queries/properties_timeline/properties_timeline.py index 34c39235309..fe9191e5e1c 100644 --- a/posthog/queries/properties_timeline/properties_timeline.py +++ b/posthog/queries/properties_timeline/properties_timeline.py @@ -1,6 +1,6 @@ import datetime import json -from typing import Any, Dict, List, Set, TypedDict, Union, cast +from typing import Any, TypedDict, Union, cast from posthog.models.filters.properties_timeline_filter import PropertiesTimelineFilter from posthog.models.group.group import Group @@ -18,13 +18,13 @@ from .properties_timeline_event_query import PropertiesTimelineEventQuery class PropertiesTimelinePoint(TypedDict): timestamp: str - properties: Dict[str, Any] + properties: dict[str, Any] relevant_event_count: int class PropertiesTimelineResult(TypedDict): - points: List[PropertiesTimelinePoint] - crucial_property_keys: List[str] + points: list[PropertiesTimelinePoint] + crucial_property_keys: list[str] effective_date_from: str effective_date_to: str @@ -56,7 +56,7 @@ WHERE timestamp IS NOT NULL /* Remove sentinel row */ class PropertiesTimeline: - def extract_crucial_property_keys(self, filter: PropertiesTimelineFilter) -> Set[str]: + def extract_crucial_property_keys(self, filter: PropertiesTimelineFilter) -> set[str]: is_filter_relevant = lambda property_type, property_group_type_index: ( (property_type == "person") if filter.aggregation_group_type_index is None @@ -76,7 +76,7 @@ class PropertiesTimeline: if filter.breakdown and filter.breakdown_type == "person": if isinstance(filter.breakdown, list): - crucial_property_keys.update(cast(List[str], filter.breakdown)) + crucial_property_keys.update(cast(list[str], filter.breakdown)) else: crucial_property_keys.add(filter.breakdown) diff --git a/posthog/queries/properties_timeline/properties_timeline_event_query.py b/posthog/queries/properties_timeline/properties_timeline_event_query.py index d3ca17eb700..b5e9a87d07c 100644 --- a/posthog/queries/properties_timeline/properties_timeline_event_query.py +++ b/posthog/queries/properties_timeline/properties_timeline_event_query.py @@ -1,5 +1,5 @@ import datetime as dt -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional from zoneinfo import ZoneInfo from posthog.models.entity.util import get_entity_filtering_params @@ -20,7 +20,7 @@ class PropertiesTimelineEventQuery(EventQuery): super().__init__(filter, *args, 
**kwargs) self._group_type_index = filter.aggregation_group_type_index - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: real_fields = [f"{self.EVENT_TABLE_ALIAS}.timestamp AS timestamp"] sentinel_fields = ["NULL AS timestamp"] @@ -72,8 +72,8 @@ class PropertiesTimelineEventQuery(EventQuery): def _determine_should_join_sessions(self) -> None: self._should_join_sessions = False - def _get_date_filter(self) -> Tuple[str, Dict]: - query_params: Dict[str, Any] = {} + def _get_date_filter(self) -> tuple[str, dict]: + query_params: dict[str, Any] = {} query_date_range = QueryDateRange(self._filter, self._team) effective_timezone = ZoneInfo(self._team.timezone) # Get effective date range from QueryDateRange @@ -92,7 +92,7 @@ class PropertiesTimelineEventQuery(EventQuery): return date_filter, query_params - def _get_entity_query(self) -> Tuple[str, Dict]: + def _get_entity_query(self) -> tuple[str, dict]: entity_params, entity_format_params = get_entity_filtering_params( allowed_entities=self._filter.entities, team_id=self._team_id, diff --git a/posthog/queries/property_optimizer.py b/posthog/queries/property_optimizer.py index d69cadfe5e8..b11be666fd6 100644 --- a/posthog/queries/property_optimizer.py +++ b/posthog/queries/property_optimizer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional, cast +from typing import Optional, cast from rest_framework.exceptions import ValidationError @@ -94,7 +94,7 @@ class PropertyOptimizer: elif isinstance(property_group.values[0], PropertyGroup): return all( PropertyOptimizer.using_only_person_properties(group) - for group in cast(List[PropertyGroup], property_group.values) + for group in cast(list[PropertyGroup], property_group.values) ) else: diff --git a/posthog/queries/property_values.py b/posthog/queries/property_values.py index a8b943f25d1..0e79d15146a 100644 --- a/posthog/queries/property_values.py +++ b/posthog/queries/property_values.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from django.utils import timezone @@ -16,7 +16,7 @@ from posthog.utils import relative_date_parse def get_property_values_for_key( key: str, team: Team, - event_names: Optional[List[str]] = None, + event_names: Optional[list[str]] = None, value: Optional[str] = None, ): property_field, mat_column_exists = get_property_string_expr("events", key, "%(key)s", "properties") diff --git a/posthog/queries/query_date_range.py b/posthog/queries/query_date_range.py index 2825e4e0360..578f2ccf041 100644 --- a/posthog/queries/query_date_range.py +++ b/posthog/queries/query_date_range.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta from functools import cached_property -from typing import Dict, Literal, Optional, Tuple +from typing import Literal, Optional from zoneinfo import ZoneInfo from dateutil.relativedelta import relativedelta @@ -117,7 +117,7 @@ class QueryDateRange: return self._get_timezone_aware_date_condition("date_from") @cached_property - def date_to(self) -> Tuple[str, Dict]: + def date_to(self) -> tuple[str, dict]: date_to_query = self.date_to_clause date_to = self.date_to_param @@ -129,7 +129,7 @@ class QueryDateRange: return date_to_query, date_to_param @cached_property - def date_from(self) -> Tuple[str, Dict]: + def date_from(self) -> tuple[str, dict]: date_from_query = self.date_from_clause date_from = self.date_from_param diff --git a/posthog/queries/retention/actors_query.py b/posthog/queries/retention/actors_query.py index 
5a49c510a32..e087b88e44f 100644 --- a/posthog/queries/retention/actors_query.py +++ b/posthog/queries/retention/actors_query.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from posthog.models.filters.retention_filter import RetentionFilter from posthog.models.team import Team @@ -19,7 +19,7 @@ class AppearanceRow: actor_id: str appearance_count: int # This is actually the number of days from first event to the current event. - appearances: List[float] + appearances: list[float] # Note: This class does not respect the entire flow from ActorBaseQuery because the result shape differs from other actor queries @@ -98,7 +98,7 @@ def build_actor_activity_query( selected_interval: Optional[int] = None, aggregate_users_by_distinct_id: Optional[bool] = None, retention_events_query=RetentionEventsQuery, -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: from posthog.queries.retention import ( build_returning_event_query, build_target_event_query, @@ -150,7 +150,7 @@ def _build_actor_query( filter_by_breakdown: Optional[BreakdownValues] = None, selected_interval: Optional[int] = None, retention_events_query=RetentionEventsQuery, -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: actor_activity_query, actor_activity_query_params = build_actor_activity_query( filter=filter, team=team, diff --git a/posthog/queries/retention/retention.py b/posthog/queries/retention/retention.py index 8f8b0d89254..d3b9f43ca5c 100644 --- a/posthog/queries/retention/retention.py +++ b/posthog/queries/retention/retention.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from urllib.parse import urlencode from zoneinfo import ZoneInfo @@ -24,7 +24,7 @@ class Retention: def __init__(self, base_uri="/"): self._base_uri = base_uri - def run(self, filter: RetentionFilter, team: Team, *args, **kwargs) -> List[Dict[str, Any]]: + def run(self, filter: RetentionFilter, team: Team, *args, **kwargs) -> list[dict[str, Any]]: filter.team = team retention_by_breakdown = self._get_retention_by_breakdown_values(filter, team) if filter.breakdowns: @@ -34,7 +34,7 @@ class Retention: def _get_retention_by_breakdown_values( self, filter: RetentionFilter, team: Team - ) -> Dict[CohortKey, Dict[str, Any]]: + ) -> dict[CohortKey, dict[str, Any]]: actor_query, actor_query_params = build_actor_activity_query( filter=filter, team=team, retention_events_query=self.event_query ) @@ -77,7 +77,7 @@ class Retention: ).to_params() return f"{self._base_uri}api/person/retention/?{urlencode(params)}" - def process_breakdown_table_result(self, resultset: Dict[CohortKey, Dict[str, Any]], filter: RetentionFilter): + def process_breakdown_table_result(self, resultset: dict[CohortKey, dict[str, Any]], filter: RetentionFilter): result = [ { "values": [ @@ -101,7 +101,7 @@ class Retention: def process_table_result( self, - resultset: Dict[CohortKey, Dict[str, Any]], + resultset: dict[CohortKey, dict[str, Any]], filter: RetentionFilter, team: Team, ): @@ -140,7 +140,7 @@ class Retention: return result - def actors_in_period(self, filter: RetentionFilter, team: Team) -> Tuple[list, int]: + def actors_in_period(self, filter: RetentionFilter, team: Team) -> tuple[list, int]: """ Creates a response of the form @@ -168,7 +168,7 @@ def build_returning_event_query( aggregate_users_by_distinct_id: Optional[bool] = None, person_on_events_mode: PersonsOnEventsMode = PersonsOnEventsMode.disabled,
retention_events_query=RetentionEventsQuery, -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: returning_event_query_templated, returning_event_params = retention_events_query( filter=filter.shallow_clone({"breakdowns": []}), # Avoid pulling in breakdown values from returning event query team=team, @@ -186,7 +186,7 @@ def build_target_event_query( aggregate_users_by_distinct_id: Optional[bool] = None, person_on_events_mode: PersonsOnEventsMode = PersonsOnEventsMode.disabled, retention_events_query=RetentionEventsQuery, -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: target_event_query_templated, target_event_params = retention_events_query( filter=filter, team=team, diff --git a/posthog/queries/retention/retention_events_query.py b/posthog/queries/retention/retention_events_query.py index e84e4bc1e91..9e64b758be6 100644 --- a/posthog/queries/retention/retention_events_query.py +++ b/posthog/queries/retention/retention_events_query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Literal, Optional, Tuple, Union, cast +from typing import Any, Literal, Optional, Union, cast from posthog.constants import ( PAGEVIEW_EVENT, @@ -37,7 +37,7 @@ class RetentionEventsQuery(EventQuery): person_on_events_mode=person_on_events_mode, ) - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: _fields = [ self.get_timestamp_field(), self.target_field(), diff --git a/posthog/queries/retention/types.py b/posthog/queries/retention/types.py index d3f77fab7f5..0a9e630da6a 100644 --- a/posthog/queries/retention/types.py +++ b/posthog/queries/retention/types.py @@ -1,4 +1,4 @@ -from typing import NamedTuple, Tuple, Union +from typing import NamedTuple, Union -BreakdownValues = Tuple[Union[str, int], ...] +BreakdownValues = tuple[Union[str, int], ...] 
CohortKey = NamedTuple("CohortKey", (("breakdown_values", BreakdownValues), ("period", int))) diff --git a/posthog/queries/stickiness/stickiness.py b/posthog/queries/stickiness/stickiness.py index 50c2ff81ad9..26204b8e996 100644 --- a/posthog/queries/stickiness/stickiness.py +++ b/posthog/queries/stickiness/stickiness.py @@ -1,6 +1,6 @@ import copy import urllib.parse -from typing import Any, Dict, List +from typing import Any from posthog.constants import TREND_FILTER_TYPE_ACTIONS from posthog.models.action import Action @@ -19,7 +19,7 @@ class Stickiness: event_query_class = StickinessEventsQuery actor_query_class = StickinessActors - def run(self, filter: StickinessFilter, team: Team, *args, **kwargs) -> List[Dict[str, Any]]: + def run(self, filter: StickinessFilter, team: Team, *args, **kwargs) -> list[dict[str, Any]]: response = [] for entity in filter.entities: if entity.type == TREND_FILTER_TYPE_ACTIONS and entity.id is not None: @@ -29,7 +29,7 @@ class Stickiness: response.extend(entity_resp) return response - def stickiness(self, entity: Entity, filter: StickinessFilter, team: Team) -> Dict[str, Any]: + def stickiness(self, entity: Entity, filter: StickinessFilter, team: Team) -> dict[str, Any]: events_query, event_params = self.event_query_class( entity, filter, team, person_on_events_mode=team.person_on_events_mode ).get_query() @@ -66,8 +66,8 @@ class Stickiness: _, serialized_actors, _ = self.actor_query_class(entity=target_entity, filter=filter, team=team).get_actors() return serialized_actors - def process_result(self, counts: List, filter: StickinessFilter, entity: Entity) -> Dict[str, Any]: - response: Dict[int, int] = {} + def process_result(self, counts: list, filter: StickinessFilter, entity: Entity) -> dict[str, Any]: + response: dict[int, int] = {} for result in counts: response[result[1]] = result[0] @@ -92,8 +92,8 @@ class Stickiness: "persons_urls": self._get_persons_url(filter, entity), } - def _serialize_entity(self, entity: Entity, filter: StickinessFilter, team: Team) -> List[Dict[str, Any]]: - serialized: Dict[str, Any] = { + def _serialize_entity(self, entity: Entity, filter: StickinessFilter, team: Team) -> list[dict[str, Any]]: + serialized: dict[str, Any] = { "action": entity.to_dict(), "label": entity.name, "count": 0, @@ -107,7 +107,7 @@ class Stickiness: response.append(new_dict) return response - def _get_persons_url(self, filter: StickinessFilter, entity: Entity) -> List[Dict[str, Any]]: + def _get_persons_url(self, filter: StickinessFilter, entity: Entity) -> list[dict[str, Any]]: persons_url = [] cache_invalidation_key = generate_short_id() for interval_idx in range(1, filter.total_intervals): @@ -119,7 +119,7 @@ class Stickiness: "entity_math": entity.math, "entity_order": entity.order, } - parsed_params: Dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) + parsed_params: dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) persons_url.append( { "filter": extra_params, diff --git a/posthog/queries/stickiness/stickiness_actors.py b/posthog/queries/stickiness/stickiness_actors.py index 625d3852ce5..c6c20301f2b 100644 --- a/posthog/queries/stickiness/stickiness_actors.py +++ b/posthog/queries/stickiness/stickiness_actors.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple +from typing import Optional from posthog.models.entity import Entity from posthog.models.filters.mixins.utils import cached_property @@ -22,7 +22,7 @@ class StickinessActors(ActorBaseQuery): def 
aggregation_group_type_index(self): return None - def actor_query(self, limit_actors: Optional[bool] = True) -> Tuple[str, Dict]: + def actor_query(self, limit_actors: Optional[bool] = True) -> tuple[str, dict]: events_query, event_params = self.event_query_class( entity=self.entity, filter=self._filter, diff --git a/posthog/queries/stickiness/stickiness_event_query.py b/posthog/queries/stickiness/stickiness_event_query.py index 25d68b1d6bf..7c8c92222ef 100644 --- a/posthog/queries/stickiness/stickiness_event_query.py +++ b/posthog/queries/stickiness/stickiness_event_query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Tuple +from typing import Any from posthog.constants import TREND_FILTER_TYPE_ACTIONS, PropertyOperatorType from posthog.models import Entity @@ -20,7 +20,7 @@ class StickinessEventsQuery(EventQuery): super().__init__(*args, **kwargs) self._should_round_interval = True - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: prop_query, prop_params = self._get_prop_groups( self._filter.property_groups.combine_property_group(PropertyOperatorType.AND, self._entity.property_groups), person_properties_mode=get_person_properties_mode(self._team), @@ -95,7 +95,7 @@ class StickinessEventsQuery(EventQuery): def aggregation_target(self): return self._person_id_alias - def get_entity_query(self) -> Tuple[str, Dict[str, Any]]: + def get_entity_query(self) -> tuple[str, dict[str, Any]]: if self._entity.type == TREND_FILTER_TYPE_ACTIONS: condition, params = format_action_filter( team_id=self._team_id, diff --git a/posthog/queries/test/test_paths.py b/posthog/queries/test/test_paths.py index 45f09a9ca57..4be8e978981 100644 --- a/posthog/queries/test/test_paths.py +++ b/posthog/queries/test/test_paths.py @@ -1,5 +1,4 @@ import dataclasses -from typing import Dict from dateutil.relativedelta import relativedelta from django.utils.timezone import now @@ -26,7 +25,7 @@ class MockEvent: distinct_id: str team: Team timestamp: str - properties: Dict + properties: dict class TestPaths(ClickhouseTestMixin, APIBaseTest): diff --git a/posthog/queries/test/test_trends.py b/posthog/queries/test/test_trends.py index abb32426dd6..333babb6ccf 100644 --- a/posthog/queries/test/test_trends.py +++ b/posthog/queries/test/test_trends.py @@ -1,7 +1,7 @@ import json import uuid from datetime import datetime -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from unittest.mock import patch, ANY from urllib.parse import parse_qsl, urlparse @@ -56,8 +56,8 @@ from posthog.test.test_journeys import journeys_for from posthog.utils import generate_cache_key -def breakdown_label(entity: Entity, value: Union[str, int]) -> Dict[str, Optional[Union[str, int]]]: - ret_dict: Dict[str, Optional[Union[str, int]]] = {} +def breakdown_label(entity: Entity, value: Union[str, int]) -> dict[str, Optional[Union[str, int]]]: + ret_dict: dict[str, Optional[Union[str, int]]] = {} if not value or not isinstance(value, str) or "cohort_" not in value: label = ( value @@ -112,7 +112,7 @@ class TestTrends(ClickhouseTestMixin, APIBaseTest): ).json() return response["results"][0]["people"] - def _create_events(self, use_time=False) -> Tuple[Action, Person]: + def _create_events(self, use_time=False) -> tuple[Action, Person]: person = _create_person( team_id=self.team.pk, distinct_ids=["blabla", "anonymous_id"], @@ -1788,7 +1788,7 @@ class TestTrends(ClickhouseTestMixin, APIBaseTest): ], ) - def _test_events_with_dates(self, dates: List[str], result, 
query_time=None, **filter_params): + def _test_events_with_dates(self, dates: list[str], result, query_time=None, **filter_params): _create_person(team_id=self.team.pk, distinct_ids=["person_1"], properties={"name": "John"}) for time in dates: with freeze_time(time): diff --git a/posthog/queries/time_to_see_data/hierarchy.py b/posthog/queries/time_to_see_data/hierarchy.py index b4b686b6124..260a1fad0ef 100644 --- a/posthog/queries/time_to_see_data/hierarchy.py +++ b/posthog/queries/time_to_see_data/hierarchy.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field from enum import Enum -from typing import List class NodeType(Enum): @@ -24,7 +23,7 @@ NODE_TYPE_WEIGHTS = { class Node: type: NodeType data: dict - children: List["Node"] = field(default_factory=list) + children: list["Node"] = field(default_factory=list) def to_dict(self): return { @@ -39,7 +38,7 @@ def construct_hierarchy(session, interactions_and_events, queries) -> dict: Constructs a tree-like hierarchy for session based on interactions and queries, to expose triggered-by relationships. """ - nodes: List[Node] = [] + nodes: list[Node] = [] nodes.extend(make_empty_node(interaction_type, data) for data in interactions_and_events) nodes.extend(make_empty_node(query_type, data) for data in queries) diff --git a/posthog/queries/time_to_see_data/sessions.py b/posthog/queries/time_to_see_data/sessions.py index 8ebeeb8db36..709d253d5b7 100644 --- a/posthog/queries/time_to_see_data/sessions.py +++ b/posthog/queries/time_to_see_data/sessions.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import Optional from posthog.client import query_with_columns from posthog.queries.time_to_see_data.hierarchy import construct_hierarchy @@ -58,7 +58,7 @@ def get_sessions(query: SessionsQuerySerializer) -> SessionResponseSerializer: return response_serializer -def get_session_events(query: SessionEventsQuerySerializer) -> Optional[Dict]: +def get_session_events(query: SessionEventsQuerySerializer) -> Optional[dict]: params = { "team_id": query.validated_data["team_id"], "session_id": query.validated_data["session_id"], @@ -82,12 +82,12 @@ def get_session_events(query: SessionEventsQuerySerializer) -> Optional[Dict]: return construct_hierarchy(sessions[0], events, queries) -def _fetch_sessions(query: SessionsQuerySerializer) -> List[Dict]: +def _fetch_sessions(query: SessionsQuerySerializer) -> list[dict]: condition, params = _sessions_condition(query) return query_with_columns(GET_SESSIONS.format(condition=condition), params) -def _sessions_condition(query: SessionsQuerySerializer) -> Tuple[str, Dict]: +def _sessions_condition(query: SessionsQuerySerializer) -> tuple[str, dict]: conditions = [] if "team_id" in query.validated_data: diff --git a/posthog/queries/trends/breakdown.py b/posthog/queries/trends/breakdown.py index 444f045384a..0f06984bac0 100644 --- a/posthog/queries/trends/breakdown.py +++ b/posthog/queries/trends/breakdown.py @@ -2,7 +2,8 @@ import json import re import urllib.parse from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union +from collections.abc import Callable from zoneinfo import ZoneInfo from django.forms import ValidationError @@ -104,7 +105,7 @@ class TrendsBreakdown: self.filter = filter self.team = team self.team_id = team.pk - self.params: Dict[str, Any] = {"team_id": team.pk} + self.params: dict[str, Any] = {"team_id": team.pk} self.column_optimizer = column_optimizer or 
ColumnOptimizer(self.filter, self.team_id) self.add_person_urls = add_person_urls self.person_on_events_mode = person_on_events_mode @@ -122,7 +123,7 @@ class TrendsBreakdown: return self._person_id_alias @cached_property - def _props_to_filter(self) -> Tuple[str, Dict]: + def _props_to_filter(self) -> tuple[str, dict]: props_to_filter = self.filter.property_groups.combine_property_group( PropertyOperatorType.AND, self.entity.property_groups ) @@ -140,7 +141,7 @@ class TrendsBreakdown: hogql_context=self.filter.hogql_context, ) - def get_query(self) -> Tuple[str, Dict, Callable]: + def get_query(self) -> tuple[str, dict, Callable]: date_params = {} query_date_range = QueryDateRange(filter=self.filter, team=self.team) @@ -165,7 +166,7 @@ class TrendsBreakdown: ) action_query = "" - action_params: Dict = {} + action_params: dict = {} if self.entity.type == TREND_FILTER_TYPE_ACTIONS: action = self.entity.get_action() action_query, action_params = format_action_filter( @@ -439,7 +440,7 @@ class TrendsBreakdown: return params, breakdown_filter, breakdown_filter_params, "value" - def _breakdown_prop_params(self, aggregate_operation: str, math_params: Dict): + def _breakdown_prop_params(self, aggregate_operation: str, math_params: dict): values_arr, has_more_values = get_breakdown_prop_values( self.filter, self.entity, @@ -564,7 +565,7 @@ class TrendsBreakdown: return breakdown_value - def _get_histogram_breakdown_values(self, raw_breakdown_value: str, buckets: List[int]): + def _get_histogram_breakdown_values(self, raw_breakdown_value: str, buckets: list[int]): multi_if_conditionals = [] values_arr = [] @@ -607,9 +608,9 @@ class TrendsBreakdown: return count_or_aggregated_value * -1, value.get("label") # reverse it def _parse_single_aggregate_result( - self, filter: Filter, entity: Entity, additional_values: Dict[str, Any] + self, filter: Filter, entity: Entity, additional_values: dict[str, Any] ) -> Callable: - def _parse(result: List) -> List: + def _parse(result: list) -> list: parsed_results = [] cache_invalidation_key = generate_short_id() for stats in result: @@ -623,7 +624,7 @@ class TrendsBreakdown: "breakdown_value": result_descriptors["breakdown_value"], "breakdown_type": filter.breakdown_type or "event", } - parsed_params: Dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) + parsed_params: dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) parsed_result = { "aggregated_value": float( correct_result_for_sampling(aggregated_value, filter.sampling_factor, entity.math) @@ -647,7 +648,7 @@ class TrendsBreakdown: return _parse def _parse_trend_result(self, filter: Filter, entity: Entity) -> Callable: - def _parse(result: List) -> List: + def _parse(result: list) -> list: parsed_results = [] for stats in result: result_descriptors = self._breakdown_result_descriptors(stats[2], filter, entity) @@ -679,9 +680,9 @@ class TrendsBreakdown: filter: Filter, entity: Entity, team: Team, - point_dates: List[datetime], + point_dates: list[datetime], breakdown_value: Union[str, int], - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: persons_url = [] cache_invalidation_key = generate_short_id() for point_date in point_dates: @@ -705,7 +706,7 @@ class TrendsBreakdown: "breakdown_value": breakdown_value, "breakdown_type": filter.breakdown_type or "event", } - parsed_params: Dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) + parsed_params: dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) 
persons_url.append( { "filter": extra_params, @@ -744,7 +745,7 @@ class TrendsBreakdown: else: return str(value) or BREAKDOWN_NULL_DISPLAY - def _person_join_condition(self) -> Tuple[str, Dict]: + def _person_join_condition(self) -> tuple[str, dict]: if self.person_on_events_mode == PersonsOnEventsMode.person_id_no_override_properties_on_events: return "", {} @@ -780,7 +781,7 @@ class TrendsBreakdown: else: return "", {} - def _groups_join_condition(self) -> Tuple[str, Dict]: + def _groups_join_condition(self) -> tuple[str, dict]: return GroupsJoinQuery( self.filter, self.team_id, @@ -788,7 +789,7 @@ class TrendsBreakdown: person_on_events_mode=self.person_on_events_mode, ).get_join_query() - def _sessions_join_condition(self) -> Tuple[str, Dict]: + def _sessions_join_condition(self) -> tuple[str, dict]: session_query = SessionQuery(filter=self.filter, team=self.team) if session_query.is_used: query, session_params = session_query.get_query() diff --git a/posthog/queries/trends/formula.py b/posthog/queries/trends/formula.py index 4f59e5b0cd7..b2fd1bcd806 100644 --- a/posthog/queries/trends/formula.py +++ b/posthog/queries/trends/formula.py @@ -2,7 +2,7 @@ import math from itertools import accumulate import re from string import ascii_uppercase -from typing import Any, Dict, List +from typing import Any from sentry_sdk import push_scope @@ -22,7 +22,7 @@ class TrendsFormula: def _run_formula_query(self, filter: Filter, team: Team): letters = [ascii_uppercase[i] for i in range(0, len(filter.entities))] queries = [] - params: Dict[str, Any] = {} + params: dict[str, Any] = {} for idx, entity in enumerate(filter.entities): _, sql, entity_params, _ = self._get_sql_for_entity(filter, team, entity) # type: ignore sql = PARAM_DISAMBIGUATION_REGEX.sub(f"%({idx}_", sql) @@ -96,7 +96,7 @@ class TrendsFormula: ) response = [] for item in result: - additional_values: Dict[str, Any] = {"label": self._label(filter, item)} + additional_values: dict[str, Any] = {"label": self._label(filter, item)} if filter.breakdown: additional_values["breakdown_value"] = additional_values["label"] @@ -113,7 +113,7 @@ class TrendsFormula: response.append(parse_response(item, filter, additional_values=additional_values)) return response - def _label(self, filter: Filter, item: List) -> str: + def _label(self, filter: Filter, item: list) -> str: if filter.breakdown: if filter.breakdown_type == "cohort": return get_breakdown_cohort_name(item[2]) diff --git a/posthog/queries/trends/lifecycle.py b/posthog/queries/trends/lifecycle.py index 2629672879e..199e3c57973 100644 --- a/posthog/queries/trends/lifecycle.py +++ b/posthog/queries/trends/lifecycle.py @@ -1,5 +1,6 @@ import urllib -from typing import Any, Callable, Dict, List, Tuple +from typing import Any +from collections.abc import Callable from posthog.models.entity import Entity from posthog.models.entity.util import get_entity_filtering_params @@ -28,7 +29,7 @@ from posthog.utils import encode_get_request_params, generate_short_id class Lifecycle: - def _format_lifecycle_query(self, entity: Entity, filter: Filter, team: Team) -> Tuple[str, Dict, Callable]: + def _format_lifecycle_query(self, entity: Entity, filter: Filter, team: Team) -> tuple[str, dict, Callable]: event_query, event_params = LifecycleEventQuery( team=team, filter=filter, person_on_events_mode=team.person_on_events_mode ).get_query() @@ -40,7 +41,7 @@ class Lifecycle: ) def _parse_result(self, filter: Filter, entity: Entity, team: Team) -> Callable: - def _parse(result: List) -> List: + def 
_parse(result: list) -> list: res = [] for val in result: label = "{} - {}".format(entity.name, val[2]) @@ -61,7 +62,7 @@ class Lifecycle: _, serialized_actors, _ = LifecycleActors(filter=filter, team=team, limit_actors=True).get_actors() return serialized_actors - def _get_persons_urls(self, filter: Filter, entity: Entity, times: List[str], status) -> List[Dict[str, Any]]: + def _get_persons_urls(self, filter: Filter, entity: Entity, times: list[str], status) -> list[dict[str, Any]]: persons_url = [] cache_invalidation_key = generate_short_id() for target_date in times: @@ -75,7 +76,7 @@ class Lifecycle: "lifecycle_type": status, } - parsed_params: Dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) + parsed_params: dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) persons_url.append( { "filter": extra_params, @@ -167,7 +168,7 @@ class LifecycleEventQuery(EventQuery): ) def _get_date_filter(self): - date_params: Dict[str, Any] = {} + date_params: dict[str, Any] = {} query_date_range = QueryDateRange(self._filter, self._team, should_round=False) _, date_from_params = query_date_range.date_from _, date_to_params = query_date_range.date_to diff --git a/posthog/queries/trends/lifecycle_actors.py b/posthog/queries/trends/lifecycle_actors.py index 2b83dbb364d..0e4b7446cda 100644 --- a/posthog/queries/trends/lifecycle_actors.py +++ b/posthog/queries/trends/lifecycle_actors.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple +from typing import Optional from posthog.queries.actor_base_query import ActorBaseQuery from posthog.queries.trends.lifecycle import LifecycleEventQuery @@ -13,7 +13,7 @@ class LifecycleActors(ActorBaseQuery): QUERY_TYPE = "lifecycle" - def actor_query(self, limit_actors: Optional[bool] = True) -> Tuple[str, Dict]: + def actor_query(self, limit_actors: Optional[bool] = True) -> tuple[str, dict]: events_query, event_params = self.event_query_class( filter=self._filter, team=self._team, diff --git a/posthog/queries/trends/test/test_breakdowns.py b/posthog/queries/trends/test/test_breakdowns.py index 78b5a01e45a..3b8651d5415 100644 --- a/posthog/queries/trends/test/test_breakdowns.py +++ b/posthog/queries/trends/test/test_breakdowns.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Dict, Optional +from typing import Optional from posthog.constants import TRENDS_TABLE from posthog.models import Filter @@ -104,7 +104,7 @@ class TestBreakdowns(ClickhouseTestMixin, APIBaseTest): journeys_for(journey, team=self.team, create_people=True) - def _run(self, extra: Optional[Dict] = None, events_extra: Optional[Dict] = None): + def _run(self, extra: Optional[dict] = None, events_extra: Optional[dict] = None): if events_extra is None: events_extra = {} if extra is None: diff --git a/posthog/queries/trends/test/test_breakdowns_by_current_url.py b/posthog/queries/trends/test/test_breakdowns_by_current_url.py index 26e0c40ae64..8474d7a27bb 100644 --- a/posthog/queries/trends/test/test_breakdowns_by_current_url.py +++ b/posthog/queries/trends/test/test_breakdowns_by_current_url.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Dict, Optional +from typing import Optional from posthog.models import Filter from posthog.queries.trends.trends import Trends @@ -99,7 +99,7 @@ class TestBreakdownsByCurrentURL(ClickhouseTestMixin, APIBaseTest): journeys_for(journey, team=self.team, create_people=True) - def _run(self, extra: Optional[Dict] = None, events_extra: Optional[Dict] = None): + def _run(self, 
extra: Optional[dict] = None, events_extra: Optional[dict] = None): if events_extra is None: events_extra = {} if extra is None: diff --git a/posthog/queries/trends/test/test_formula.py b/posthog/queries/trends/test/test_formula.py index 01e838336e5..d711bbff6f8 100644 --- a/posthog/queries/trends/test/test_formula.py +++ b/posthog/queries/trends/test/test_formula.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Optional from freezegun.api import freeze_time @@ -129,7 +129,7 @@ class TestFormula(ClickhouseTestMixin, APIBaseTest): }, ) - def _run(self, extra: Optional[Dict] = None, run_at: Optional[str] = None): + def _run(self, extra: Optional[dict] = None, run_at: Optional[str] = None): if extra is None: extra = {} with freeze_time(run_at or "2020-01-04T13:01:01Z"): diff --git a/posthog/queries/trends/test/test_paging_breakdowns.py b/posthog/queries/trends/test/test_paging_breakdowns.py index b4040fee618..47ea447005c 100644 --- a/posthog/queries/trends/test/test_paging_breakdowns.py +++ b/posthog/queries/trends/test/test_paging_breakdowns.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Optional from freezegun import freeze_time @@ -38,7 +38,7 @@ class TestPagingBreakdowns(APIBaseTest): create_people=True, ) - def _run(self, extra: Optional[Dict] = None, run_at: Optional[str] = None): + def _run(self, extra: Optional[dict] = None, run_at: Optional[str] = None): if extra is None: extra = {} with freeze_time(run_at or "2020-01-04T13:01:01Z"): diff --git a/posthog/queries/trends/total_volume.py b/posthog/queries/trends/total_volume.py index e36f6d2de73..5e91d9272cf 100644 --- a/posthog/queries/trends/total_volume.py +++ b/posthog/queries/trends/total_volume.py @@ -1,6 +1,7 @@ import urllib.parse from datetime import date, datetime, timedelta -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Union +from collections.abc import Callable from posthog.clickhouse.query_tagging import tag_queries from posthog.constants import ( @@ -48,7 +49,7 @@ class TrendsTotalVolume: EVENT_TABLE_ALIAS = EventQuery.EVENT_TABLE_ALIAS PERSON_ID_OVERRIDES_TABLE_ALIAS = EventQuery.PERSON_ID_OVERRIDES_TABLE_ALIAS - def _total_volume_query(self, entity: Entity, filter: Filter, team: Team) -> Tuple[str, Dict, Callable]: + def _total_volume_query(self, entity: Entity, filter: Filter, team: Team) -> tuple[str, dict, Callable]: interval_func = get_interval_func_ch(filter.interval) person_id_alias = f"{self.DISTINCT_ID_TABLE_ALIAS}.person_id" @@ -82,7 +83,7 @@ class TrendsTotalVolume: "timestamp": "e.timestamp", "interval_func": interval_func, } - params: Dict = {"team_id": team.id, "timezone": team.timezone} + params: dict = {"team_id": team.id, "timezone": team.timezone} params = {**params, **math_params, **event_query_params} if filter.display in NON_TIME_SERIES_DISPLAY_TYPES: @@ -219,14 +220,14 @@ class TrendsTotalVolume: return final_query, params, self._parse_total_volume_result(filter, entity, team) def _parse_total_volume_result(self, filter: Filter, entity: Entity, team: Team) -> Callable: - def _parse(result: List) -> List: + def _parse(result: list) -> list: parsed_results = [] if result is not None: for stats in result: parsed_result = parse_response(stats, filter, entity=entity) - point_dates: List[Union[datetime, date]] = stats[0] + point_dates: list[Union[datetime, date]] = stats[0] # Ensure we have datetimes for all points - point_datetimes: List[datetime] = [ + point_datetimes: list[datetime] = [ datetime.combine(d, 
datetime.min.time()) if not isinstance(d, datetime) else d for d in point_dates ] @@ -238,7 +239,7 @@ class TrendsTotalVolume: return _parse def _parse_aggregate_volume_result(self, filter: Filter, entity: Entity, team_id: int) -> Callable: - def _parse(result: List) -> List: + def _parse(result: list) -> list: aggregated_value = result[0][0] if result else 0 seconds_in_interval = TIME_IN_SECONDS[filter.interval] time_range = enumerate_time_range(filter, seconds_in_interval) @@ -249,7 +250,7 @@ class TrendsTotalVolume: "entity_math": entity.math, "entity_order": entity.order, } - parsed_params: Dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) + parsed_params: dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) cache_invalidation_key = generate_short_id() return [ @@ -286,8 +287,8 @@ class TrendsTotalVolume: filter: Filter, entity: Entity, team: Team, - point_datetimes: List[datetime], - ) -> List[Dict[str, Any]]: + point_datetimes: list[datetime], + ) -> list[dict[str, Any]]: persons_url = [] cache_invalidation_key = generate_short_id() for point_datetime in point_datetimes: @@ -301,7 +302,7 @@ class TrendsTotalVolume: "entity_order": entity.order, } - parsed_params: Dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) + parsed_params: dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) persons_url.append( { "filter": extra_params, diff --git a/posthog/queries/trends/trends.py b/posthog/queries/trends/trends.py index 81e35336138..da8e0ff80e1 100644 --- a/posthog/queries/trends/trends.py +++ b/posthog/queries/trends/trends.py @@ -2,7 +2,8 @@ import copy import threading from datetime import datetime, timedelta from itertools import accumulate -from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from typing import Any, Optional, cast +from collections.abc import Callable from zoneinfo import ZoneInfo from dateutil import parser @@ -33,7 +34,7 @@ from posthog.utils import generate_cache_key, get_safe_cache class Trends(TrendsTotalVolume, Lifecycle, TrendsFormula): - def _get_sql_for_entity(self, filter: Filter, team: Team, entity: Entity) -> Tuple[str, str, Dict, Callable]: + def _get_sql_for_entity(self, filter: Filter, team: Team, entity: Entity) -> tuple[str, str, dict, Callable]: if filter.breakdown and filter.display not in NON_BREAKDOWN_DISPLAY_TYPES: query_type = "trends_breakdown" sql, params, parse_function = TrendsBreakdown( @@ -53,7 +54,7 @@ class Trends(TrendsTotalVolume, Lifecycle, TrendsFormula): return query_type, sql, params, parse_function # Use cached result even on refresh if team has strict caching enabled - def get_cached_result(self, filter: Filter, team: Team) -> Optional[List[Dict[str, Any]]]: + def get_cached_result(self, filter: Filter, team: Team) -> Optional[list[dict[str, Any]]]: if not team.strict_caching_enabled or filter.breakdown or filter.display != TRENDS_LINEAR: return None @@ -73,7 +74,7 @@ class Trends(TrendsTotalVolume, Lifecycle, TrendsFormula): return cached_result if _is_present else None # Determine if the current timerange is present in the cache - def is_present_timerange(self, cached_result: List[Dict[str, Any]], filter: Filter, team: Team) -> bool: + def is_present_timerange(self, cached_result: list[dict[str, Any]], filter: Filter, team: Team) -> bool: if ( len(cached_result) > 0 and cached_result[0].get("days") @@ -92,7 +93,7 @@ class Trends(TrendsTotalVolume, Lifecycle, TrendsFormula): return _is_present # Use a condensed 
filter if a cached result exists in the current timerange - def adjusted_filter(self, filter: Filter, team: Team) -> Tuple[Filter, Optional[Dict[str, Any]]]: + def adjusted_filter(self, filter: Filter, team: Team) -> tuple[Filter, Optional[dict[str, Any]]]: cached_result = self.get_cached_result(filter, team) new_filter = filter.shallow_clone({"date_from": interval_unit(filter.interval)}) if cached_result else filter @@ -107,7 +108,7 @@ class Trends(TrendsTotalVolume, Lifecycle, TrendsFormula): def merge_results( self, result, - cached_result: Optional[Dict[str, Any]], + cached_result: Optional[dict[str, Any]], entity_order: int, filter: Filter, team: Team, @@ -129,7 +130,7 @@ class Trends(TrendsTotalVolume, Lifecycle, TrendsFormula): else: return result, {} - def _run_query(self, filter: Filter, team: Team, entity: Entity) -> List[Dict[str, Any]]: + def _run_query(self, filter: Filter, team: Team, entity: Entity) -> list[dict[str, Any]]: adjusted_filter, cached_result = self.adjusted_filter(filter, team) with push_scope() as scope: query_type, sql, params, parse_function = self._get_sql_for_entity(adjusted_filter, team, entity) @@ -163,12 +164,12 @@ class Trends(TrendsTotalVolume, Lifecycle, TrendsFormula): def _run_query_for_threading( self, - result: List, + result: list, index: int, query_type, sql, params, - query_tags: Dict, + query_tags: dict, filter: Filter, team_id: int, ): @@ -177,10 +178,10 @@ class Trends(TrendsTotalVolume, Lifecycle, TrendsFormula): scope.set_context("query", {"sql": sql, "params": params}) result[index] = insight_sync_execute(sql, params, query_type=query_type, filter=filter, team_id=team_id) - def _run_parallel(self, filter: Filter, team: Team) -> List[Dict[str, Any]]: - result: List[Optional[List[Dict[str, Any]]]] = [None] * len(filter.entities) - parse_functions: List[Optional[Callable]] = [None] * len(filter.entities) - sql_statements_with_params: List[Tuple[Optional[str], Dict]] = [(None, {})] * len(filter.entities) + def _run_parallel(self, filter: Filter, team: Team) -> list[dict[str, Any]]: + result: list[Optional[list[dict[str, Any]]]] = [None] * len(filter.entities) + parse_functions: list[Optional[Callable]] = [None] * len(filter.entities) + sql_statements_with_params: list[tuple[Optional[str], dict]] = [(None, {})] * len(filter.entities) cached_result = None jobs = [] @@ -225,7 +226,7 @@ class Trends(TrendsTotalVolume, Lifecycle, TrendsFormula): "params": sql_statements_with_params[i][1], }, ) - serialized_data = cast(List[Callable], parse_functions)[entity.index](result[entity.index]) + serialized_data = cast(list[Callable], parse_functions)[entity.index](result[entity.index]) serialized_data = self._format_serialized(entity, serialized_data) merged_results, cached_result = self.merge_results( serialized_data, @@ -237,9 +238,9 @@ class Trends(TrendsTotalVolume, Lifecycle, TrendsFormula): result[entity.index] = merged_results # flatten results - flat_results: List[Dict[str, Any]] = [] + flat_results: list[dict[str, Any]] = [] for item in result: - for flat in cast(List[Dict[str, Any]], item): + for flat in cast(list[dict[str, Any]], item): flat_results.append(flat) if cached_result: @@ -248,7 +249,7 @@ class Trends(TrendsTotalVolume, Lifecycle, TrendsFormula): return flat_results - def run(self, filter: Filter, team: Team, is_csv_export: bool = False, *args, **kwargs) -> List[Dict[str, Any]]: + def run(self, filter: Filter, team: Team, is_csv_export: bool = False, *args, **kwargs) -> list[dict[str, Any]]: self.is_csv_export = is_csv_export actions 
= Action.objects.filter(team_id=team.pk).order_by("-id") if len(filter.actions) > 0: @@ -274,10 +275,10 @@ class Trends(TrendsTotalVolume, Lifecycle, TrendsFormula): return result - def _format_serialized(self, entity: Entity, result: List[Dict[str, Any]]): + def _format_serialized(self, entity: Entity, result: list[dict[str, Any]]): serialized_data = [] - serialized: Dict[str, Any] = { + serialized: dict[str, Any] = { "action": entity.to_dict(), "label": entity.name, "count": 0, @@ -293,7 +294,7 @@ class Trends(TrendsTotalVolume, Lifecycle, TrendsFormula): return serialized_data - def _handle_cumulative(self, entity_metrics: List) -> List[Dict[str, Any]]: + def _handle_cumulative(self, entity_metrics: list) -> list[dict[str, Any]]: for metrics in entity_metrics: metrics.update(data=list(accumulate(metrics["data"]))) return entity_metrics diff --git a/posthog/queries/trends/trends_actors.py b/posthog/queries/trends/trends_actors.py index 9c4afa89c41..f7db8b36d8a 100644 --- a/posthog/queries/trends/trends_actors.py +++ b/posthog/queries/trends/trends_actors.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from posthog.constants import PropertyOperatorType from posthog.models.cohort import Cohort @@ -37,7 +37,7 @@ class TrendsActors(ActorBaseQuery): return self.entity.math_group_type_index return None - def actor_query(self, limit_actors: Optional[bool] = True) -> Tuple[str, Dict]: + def actor_query(self, limit_actors: Optional[bool] = True) -> tuple[str, dict]: if self._filter.breakdown_type == "cohort" and self._filter.breakdown_value != "all": cohort = Cohort.objects.get(pk=self._filter.breakdown_value, team_id=self._team.pk) self._filter = self._filter.shallow_clone( @@ -95,7 +95,7 @@ class TrendsActors(ActorBaseQuery): } ) - extra_fields: List[str] = ["distinct_id", "team_id"] if not self.is_aggregating_by_groups else [] + extra_fields: list[str] = ["distinct_id", "team_id"] if not self.is_aggregating_by_groups else [] if self._filter.include_recordings: extra_fields += ["uuid"] @@ -147,7 +147,7 @@ class TrendsActors(ActorBaseQuery): return "person_id" @cached_property - def _aggregation_actor_value_expression_with_params(self) -> Tuple[str, Dict[str, Any]]: + def _aggregation_actor_value_expression_with_params(self) -> tuple[str, dict[str, Any]]: if self.entity.math in PROPERTY_MATH_FUNCTIONS: math_aggregate_operation, _, math_params = process_math( self.entity, self._team, filter=self._filter, event_table_alias="e" diff --git a/posthog/queries/trends/trends_event_query.py b/posthog/queries/trends/trends_event_query.py index bc9e9b979bd..b856cb6a035 100644 --- a/posthog/queries/trends/trends_event_query.py +++ b/posthog/queries/trends/trends_event_query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Tuple +from typing import Any from posthog.models.property.util import get_property_string_expr from posthog.queries.trends.trends_event_query_base import TrendsEventQueryBase @@ -6,7 +6,7 @@ from posthog.schema import PersonsOnEventsMode class TrendsEventQuery(TrendsEventQueryBase): - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: person_id_field = "" if self._should_join_distinct_ids: person_id_field = f", {self._person_id_alias} as person_id" diff --git a/posthog/queries/trends/trends_event_query_base.py b/posthog/queries/trends/trends_event_query_base.py index dbeb9f17cdc..8fb17d3579e 100644 --- a/posthog/queries/trends/trends_event_query_base.py +++ 
b/posthog/queries/trends/trends_event_query_base.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Tuple +from typing import Any from posthog.constants import ( MONTHLY_ACTIVE, @@ -29,7 +29,7 @@ class TrendsEventQueryBase(EventQuery): self._entity = entity super().__init__(*args, **kwargs) - def get_query_base(self) -> Tuple[str, Dict[str, Any]]: + def get_query_base(self) -> tuple[str, dict[str, Any]]: """ Returns part of the event query with only FROM, JOINs and WHERE clauses. """ @@ -114,9 +114,9 @@ class TrendsEventQueryBase(EventQuery): # If aggregating by group, exclude events that aren't associated with a group return f"""AND "$group_{self._entity.math_group_type_index}" != ''""" - def _get_date_filter(self) -> Tuple[str, Dict]: + def _get_date_filter(self) -> tuple[str, dict]: date_query = "" - date_params: Dict[str, Any] = {} + date_params: dict[str, Any] = {} query_date_range = QueryDateRange(self._filter, self._team) parsed_date_from, date_from_params = query_date_range.date_from parsed_date_to, date_to_params = query_date_range.date_to @@ -145,7 +145,7 @@ class TrendsEventQueryBase(EventQuery): return date_query, date_params - def _get_entity_query(self, *, deep_filtering: bool) -> Tuple[str, Dict]: + def _get_entity_query(self, *, deep_filtering: bool) -> tuple[str, dict]: entity_params, entity_format_params = get_entity_filtering_params( allowed_entities=[self._entity], team_id=self._team_id, diff --git a/posthog/queries/trends/util.py b/posthog/queries/trends/util.py index e002145de99..3558640602e 100644 --- a/posthog/queries/trends/util.py +++ b/posthog/queries/trends/util.py @@ -1,6 +1,6 @@ import datetime from datetime import timedelta -from typing import Any, Dict, List, Optional, Tuple, TypeVar +from typing import Any, Optional, TypeVar from zoneinfo import ZoneInfo import structlog @@ -60,10 +60,10 @@ def process_math( filter: Filter, event_table_alias: Optional[str] = None, person_id_alias: str = "person_id", -) -> Tuple[str, str, Dict[str, Any]]: +) -> tuple[str, str, dict[str, Any]]: aggregate_operation = "count(*)" join_condition = "" - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if entity.math in (UNIQUE_USERS, WEEKLY_ACTIVE, MONTHLY_ACTIVE): if team.aggregate_users_by_distinct_id: @@ -100,11 +100,11 @@ def process_math( def parse_response( - stats: Dict, + stats: dict, filter: Filter, - additional_values: Optional[Dict] = None, + additional_values: Optional[dict] = None, entity: Optional[Entity] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: if additional_values is None: additional_values = {} counts = stats[1] @@ -122,7 +122,7 @@ def parse_response( } -def get_active_user_params(filter: Filter, entity: Entity, team_id: int) -> Tuple[Dict[str, Any], Dict[str, Any]]: +def get_active_user_params(filter: Filter, entity: Entity, team_id: int) -> tuple[dict[str, Any], dict[str, Any]]: diff = timedelta(days=7 if entity.math == WEEKLY_ACTIVE else 30) date_from: datetime.datetime @@ -155,11 +155,11 @@ def get_active_user_params(filter: Filter, entity: Entity, team_id: int) -> Tupl return format_params, query_params -def enumerate_time_range(filter: Filter, seconds_in_interval: int) -> List[str]: +def enumerate_time_range(filter: Filter, seconds_in_interval: int) -> list[str]: date_from = filter.date_from date_to = filter.date_to delta = timedelta(seconds=seconds_in_interval) - time_range: List[str] = [] + time_range: list[str] = [] if not date_from or not date_to: return time_range diff --git a/posthog/queries/util.py b/posthog/queries/util.py index 
e366fb1cc78..e0d2cb9896f 100644 --- a/posthog/queries/util.py +++ b/posthog/queries/util.py @@ -1,7 +1,7 @@ import json from datetime import datetime, timedelta from enum import Enum, auto -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from zoneinfo import ZoneInfo from django.utils import timezone @@ -46,21 +46,21 @@ GET_EARLIEST_TIMESTAMP_SQL = """ SELECT timestamp from events WHERE team_id = %(team_id)s AND timestamp > %(earliest_timestamp)s order by timestamp limit 1 """ -TIME_IN_SECONDS: Dict[str, Any] = { +TIME_IN_SECONDS: dict[str, Any] = { "hour": 3600, "day": 3600 * 24, "week": 3600 * 24 * 7, "month": 3600 * 24 * 30, # TODO: Let's get rid of this lie! Months are not all 30 days long } -PERIOD_TO_TRUNC_FUNC: Dict[str, str] = { +PERIOD_TO_TRUNC_FUNC: dict[str, str] = { "hour": "toStartOfHour", "week": "toStartOfWeek", "day": "toStartOfDay", "month": "toStartOfMonth", } -PERIOD_TO_INTERVAL_FUNC: Dict[str, str] = { +PERIOD_TO_INTERVAL_FUNC: dict[str, str] = { "hour": "toIntervalHour", "week": "toIntervalWeek", "day": "toIntervalDay", @@ -141,7 +141,7 @@ def get_time_in_seconds_for_period(period: Optional[str]) -> str: return seconds_in_period -def deep_dump_object(params: Dict[str, Any]) -> Dict[str, Any]: +def deep_dump_object(params: dict[str, Any]) -> dict[str, Any]: for key in params: if isinstance(params[key], dict) or isinstance(params[key], list): params[key] = json.dumps(params[key]) diff --git a/posthog/rate_limit.py b/posthog/rate_limit.py index 856d1b6cceb..d85238c3d49 100644 --- a/posthog/rate_limit.py +++ b/posthog/rate_limit.py @@ -2,7 +2,7 @@ import hashlib import re import time from functools import lru_cache -from typing import List, Optional +from typing import Optional from prometheus_client import Counter from rest_framework.throttling import SimpleRateThrottle, BaseThrottle, UserRateThrottle @@ -36,7 +36,7 @@ DECIDE_RATE_LIMIT_EXCEEDED_COUNTER = Counter( @lru_cache(maxsize=1) -def get_team_allow_list(_ttl: int) -> List[str]: +def get_team_allow_list(_ttl: int) -> list[str]: """ The "allow list" will change way less frequently than it will be called _ttl is passed an infrequently changing value to ensure the cache is invalidated after some delay diff --git a/posthog/renderers.py b/posthog/renderers.py index fa2d532fdce..2c7853497ea 100644 --- a/posthog/renderers.py +++ b/posthog/renderers.py @@ -1,10 +1,8 @@ -from typing import Dict - import orjson from rest_framework.renderers import JSONRenderer from rest_framework.utils.encoders import JSONEncoder -CleaningMarker = bool | Dict[int, "CleaningMarker"] +CleaningMarker = bool | dict[int, "CleaningMarker"] class SafeJSONRenderer(JSONRenderer): diff --git a/posthog/schema.py b/posthog/schema.py index 5673db2a3bf..46ad0beb9a1 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -4,7 +4,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, RootModel @@ -165,8 +165,8 @@ class DatabaseSchemaQueryResponseField(BaseModel): model_config = ConfigDict( extra="forbid", ) - chain: Optional[List[str]] = None - fields: Optional[List[str]] = None + chain: Optional[list[str]] = None + fields: Optional[list[str]] = None key: str table: Optional[str] = None type: str @@ -203,9 +203,9 @@ class ElementType(BaseModel): model_config = ConfigDict( extra="forbid", ) - attr_class: Optional[List[str]] = None + 
attr_class: Optional[list[str]] = None attr_id: Optional[str] = None - attributes: Dict[str, str] + attributes: dict[str, str] href: Optional[str] = None nth_child: Optional[float] = None nth_of_type: Optional[float] = None @@ -232,9 +232,9 @@ class EventDefinition(BaseModel): model_config = ConfigDict( extra="forbid", ) - elements: List + elements: list event: str - properties: Dict[str, Any] + properties: dict[str, Any] class CorrelationType(str, Enum): @@ -257,9 +257,9 @@ class Person(BaseModel): model_config = ConfigDict( extra="forbid", ) - distinct_ids: List[str] + distinct_ids: list[str] is_identified: Optional[bool] = None - properties: Dict[str, Any] + properties: dict[str, Any] class EventType(BaseModel): @@ -267,12 +267,12 @@ class EventType(BaseModel): extra="forbid", ) distinct_id: str - elements: List[ElementType] + elements: list[ElementType] elements_chain: Optional[str] = None event: str id: str person: Optional[Person] = None - properties: Dict[str, Any] + properties: dict[str, Any] timestamp: str uuid: Optional[str] = None @@ -282,7 +282,7 @@ class Response(BaseModel): extra="forbid", ) next: Optional[str] = None - results: List[EventType] + results: list[EventType] class Properties(BaseModel): @@ -321,7 +321,7 @@ class FunnelCorrelationResult(BaseModel): model_config = ConfigDict( extra="forbid", ) - events: List[EventOddsRatioSerialized] + events: list[EventOddsRatioSerialized] skewed: bool @@ -374,7 +374,7 @@ class FunnelTimeToConvertResults(BaseModel): extra="forbid", ) average_conversion_time: Optional[float] = None - bins: List[List[int]] + bins: list[list[int]] class FunnelVizType(str, Enum): @@ -432,7 +432,7 @@ class HogQLQueryModifiers(BaseModel): model_config = ConfigDict( extra="forbid", ) - dataWarehouseEventsModifiers: Optional[List[DataWarehouseEventsModifier]] = None + dataWarehouseEventsModifiers: Optional[list[DataWarehouseEventsModifier]] = None inCohortVia: Optional[InCohortVia] = None materializationMode: Optional[MaterializationMode] = None personsArgMaxVersion: Optional[PersonsArgMaxVersion] = None @@ -496,12 +496,12 @@ class InsightActorsQueryOptionsResponse(BaseModel): model_config = ConfigDict( extra="forbid", ) - breakdown: Optional[List[BreakdownItem]] = None - compare: Optional[List[CompareItem]] = None - day: Optional[List[DayItem]] = None - interval: Optional[List[IntervalItem]] = None - series: Optional[List[Series]] = None - status: Optional[List[StatusItem]] = None + breakdown: Optional[list[BreakdownItem]] = None + compare: Optional[list[CompareItem]] = None + day: Optional[list[DayItem]] = None + interval: Optional[list[IntervalItem]] = None + series: Optional[list[Series]] = None + status: Optional[list[StatusItem]] = None class InsightFilterProperty(str, Enum): @@ -604,14 +604,14 @@ class PathsFilter(BaseModel): ) edgeLimit: Optional[int] = None endPoint: Optional[str] = None - excludeEvents: Optional[List[str]] = None - includeEventTypes: Optional[List[PathType]] = None - localPathCleaningFilters: Optional[List[PathCleaningFilter]] = None + excludeEvents: Optional[list[str]] = None + includeEventTypes: Optional[list[PathType]] = None + localPathCleaningFilters: Optional[list[PathCleaningFilter]] = None maxEdgeWeight: Optional[int] = None minEdgeWeight: Optional[int] = None pathDropoffKey: Optional[str] = Field(default=None, description="Relevant only within actors query") pathEndKey: Optional[str] = Field(default=None, description="Relevant only within actors query") - pathGroupings: Optional[List[str]] = None + pathGroupings: 
Optional[list[str]] = None pathReplacements: Optional[bool] = None pathStartKey: Optional[str] = Field(default=None, description="Relevant only within actors query") pathsHogQLExpression: Optional[str] = None @@ -625,14 +625,14 @@ class PathsFilterLegacy(BaseModel): ) edge_limit: Optional[int] = None end_point: Optional[str] = None - exclude_events: Optional[List[str]] = None - funnel_filter: Optional[Dict[str, Any]] = None + exclude_events: Optional[list[str]] = None + funnel_filter: Optional[dict[str, Any]] = None funnel_paths: Optional[FunnelPathType] = None - include_event_types: Optional[List[PathType]] = None - local_path_cleaning_filters: Optional[List[PathCleaningFilter]] = None + include_event_types: Optional[list[PathType]] = None + local_path_cleaning_filters: Optional[list[PathCleaningFilter]] = None max_edge_weight: Optional[int] = None min_edge_weight: Optional[int] = None - path_groupings: Optional[List[str]] = None + path_groupings: Optional[list[str]] = None path_replacements: Optional[bool] = None path_type: Optional[PathType] = None paths_hogql_expression: Optional[str] = None @@ -693,39 +693,39 @@ class QueryResponseAlternative1(BaseModel): extra="forbid", ) next: Optional[str] = None - results: List[EventType] + results: list[EventType] class QueryResponseAlternative2(BaseModel): model_config = ConfigDict( extra="forbid", ) - results: List[Dict[str, Any]] + results: list[dict[str, Any]] class QueryResponseAlternative5(BaseModel): model_config = ConfigDict( extra="forbid", ) - breakdown: Optional[List[BreakdownItem]] = None - compare: Optional[List[CompareItem]] = None - day: Optional[List[DayItem]] = None - interval: Optional[List[IntervalItem]] = None - series: Optional[List[Series]] = None - status: Optional[List[StatusItem]] = None + breakdown: Optional[list[BreakdownItem]] = None + compare: Optional[list[CompareItem]] = None + day: Optional[list[DayItem]] = None + interval: Optional[list[IntervalItem]] = None + series: Optional[list[Series]] = None + status: Optional[list[StatusItem]] = None class QueryResponseAlternative8(BaseModel): model_config = ConfigDict( extra="forbid", ) - errors: List[HogQLNotice] + errors: list[HogQLNotice] inputExpr: Optional[str] = None inputSelect: Optional[str] = None isValid: Optional[bool] = None isValidView: Optional[bool] = None - notices: List[HogQLNotice] - warnings: List[HogQLNotice] + notices: list[HogQLNotice] + warnings: list[HogQLNotice] class QueryStatus(BaseModel): @@ -822,7 +822,7 @@ class SessionPropertyFilter(BaseModel): label: Optional[str] = None operator: PropertyOperator type: Literal["session"] = "session" - value: Optional[Union[str, float, List[Union[str, float]]]] = None + value: Optional[Union[str, float, list[Union[str, float]]]] = None class StepOrderValue(str, Enum): @@ -837,7 +837,7 @@ class StickinessFilter(BaseModel): ) compare: Optional[bool] = None display: Optional[ChartDisplayType] = None - hidden_legend_indexes: Optional[List[float]] = None + hidden_legend_indexes: Optional[list[float]] = None showLegend: Optional[bool] = None showValuesOnSeries: Optional[bool] = None @@ -848,7 +848,7 @@ class StickinessFilterLegacy(BaseModel): ) compare: Optional[bool] = None display: Optional[ChartDisplayType] = None - hidden_legend_indexes: Optional[List[float]] = None + hidden_legend_indexes: Optional[list[float]] = None show_legend: Optional[bool] = None show_values_on_series: Optional[bool] = None @@ -862,8 +862,8 @@ class StickinessQueryResponse(BaseModel): last_refresh: Optional[str] = None modifiers: 
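Note that the patch upgrades only the container aliases; `Optional` and `Union` are left alone, since rewriting them to `X | None` is a separate pyupgrade rule (UP007) that this patch does not apply. A minimal Pydantic sketch with hypothetical field names, mirroring the generated models above:

from typing import Any, Optional
from pydantic import BaseModel, ConfigDict

class ExampleFilter(BaseModel):  # hypothetical model, same shape as those above
    model_config = ConfigDict(extra="forbid")
    properties: dict[str, Any]               # was Dict[str, Any]
    breakdowns: Optional[list[str]] = None   # was Optional[List[str]]
    results: list[dict[str, Any]]            # was List[Dict[str, Any]]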
Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List[Dict[str, Any]] - timings: Optional[List[QueryTiming]] = None + results: list[dict[str, Any]] + timings: Optional[list[QueryTiming]] = None class TimeToSeeDataQuery(BaseModel): @@ -874,7 +874,7 @@ class TimeToSeeDataQuery(BaseModel): modifiers: Optional[HogQLQueryModifiers] = Field( default=None, description="Modifiers used when performing the query" ) - response: Optional[Dict[str, Any]] = Field(default=None, description="Cached query response") + response: Optional[dict[str, Any]] = Field(default=None, description="Cached query response") sessionEnd: Optional[str] = None sessionId: Optional[str] = Field(default=None, description="Project to filter on. Defaults to current session") sessionStart: Optional[str] = Field( @@ -887,7 +887,7 @@ class TimeToSeeDataSessionsQueryResponse(BaseModel): model_config = ConfigDict( extra="forbid", ) - results: List[Dict[str, Any]] + results: list[dict[str, Any]] class TimeToSeeDataWaterfallNode(BaseModel): @@ -902,7 +902,7 @@ class TimelineEntry(BaseModel): model_config = ConfigDict( extra="forbid", ) - events: List[EventType] + events: list[EventType] recording_duration_s: Optional[float] = Field(default=None, description="Duration of the recording in seconds.") sessionId: Optional[str] = Field(default=None, description="Session ID. None means out-of-session events") @@ -919,7 +919,7 @@ class TrendsFilter(BaseModel): decimalPlaces: Optional[float] = None display: Optional[ChartDisplayType] = None formula: Optional[str] = None - hidden_legend_indexes: Optional[List[float]] = None + hidden_legend_indexes: Optional[list[float]] = None showLabelsOnSeries: Optional[bool] = None showLegend: Optional[bool] = None showPercentStackView: Optional[bool] = None @@ -939,7 +939,7 @@ class TrendsFilterLegacy(BaseModel): decimal_places: Optional[float] = None display: Optional[ChartDisplayType] = None formula: Optional[str] = None - hidden_legend_indexes: Optional[List[float]] = None + hidden_legend_indexes: Optional[list[float]] = None show_labels_on_series: Optional[bool] = None show_legend: Optional[bool] = None show_percent_stack_view: Optional[bool] = None @@ -956,8 +956,8 @@ class TrendsQueryResponse(BaseModel): last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List[Dict[str, Any]] - timings: Optional[List[QueryTiming]] = None + results: list[dict[str, Any]] + timings: Optional[list[QueryTiming]] = None class ActionsPie(BaseModel): @@ -1020,9 +1020,9 @@ class WebOverviewQueryResponse(BaseModel): last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List[WebOverviewItem] + results: list[WebOverviewItem] samplingRate: Optional[SamplingRate] = None - timings: Optional[List[QueryTiming]] = None + timings: Optional[list[QueryTiming]] = None class WebStatsBreakdown(str, Enum): @@ -1047,7 +1047,7 @@ class WebStatsTableQueryResponse(BaseModel): model_config = ConfigDict( extra="forbid", ) - columns: Optional[List] = None + columns: Optional[list] = None hasMore: Optional[bool] = None hogql: Optional[str] = None is_cached: Optional[bool] = None @@ -1056,42 +1056,42 @@ class WebStatsTableQueryResponse(BaseModel): modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None offset: Optional[int] = None - results: List + results: list samplingRate: 
Optional[SamplingRate] = None - timings: Optional[List[QueryTiming]] = None - types: Optional[List] = None + timings: Optional[list[QueryTiming]] = None + types: Optional[list] = None class WebTopClicksQueryResponse(BaseModel): model_config = ConfigDict( extra="forbid", ) - columns: Optional[List] = None + columns: Optional[list] = None hogql: Optional[str] = None is_cached: Optional[bool] = None last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List + results: list samplingRate: Optional[SamplingRate] = None - timings: Optional[List[QueryTiming]] = None - types: Optional[List] = None + timings: Optional[list[QueryTiming]] = None + types: Optional[list] = None class ActorsQueryResponse(BaseModel): model_config = ConfigDict( extra="forbid", ) - columns: List + columns: list hasMore: Optional[bool] = None hogql: str limit: int missing_actors_count: Optional[int] = None modifiers: Optional[HogQLQueryModifiers] = None offset: int - results: List[List] - timings: Optional[List[QueryTiming]] = None - types: List[str] + results: list[list] + timings: Optional[list[QueryTiming]] = None + types: list[str] class AnyResponseType1(BaseModel): @@ -1099,7 +1099,7 @@ class AnyResponseType1(BaseModel): extra="forbid", ) next: Optional[str] = None - results: List[EventType] + results: list[EventType] class Breakdown(BaseModel): @@ -1115,14 +1115,14 @@ class BreakdownFilter(BaseModel): model_config = ConfigDict( extra="forbid", ) - breakdown: Optional[Union[str, float, List[Union[str, float]]]] = None + breakdown: Optional[Union[str, float, list[Union[str, float]]]] = None breakdown_group_type_index: Optional[int] = None breakdown_hide_other_aggregation: Optional[bool] = None breakdown_histogram_bin_count: Optional[int] = None breakdown_limit: Optional[int] = None breakdown_normalize_url: Optional[bool] = None breakdown_type: Optional[BreakdownType] = None - breakdowns: Optional[List[Breakdown]] = None + breakdowns: Optional[list[Breakdown]] = None class DataNode(BaseModel): @@ -1133,16 +1133,16 @@ class DataNode(BaseModel): modifiers: Optional[HogQLQueryModifiers] = Field( default=None, description="Modifiers used when performing the query" ) - response: Optional[Dict[str, Any]] = Field(default=None, description="Cached query response") + response: Optional[dict[str, Any]] = Field(default=None, description="Cached query response") class ChartSettings(BaseModel): model_config = ConfigDict( extra="forbid", ) - goalLines: Optional[List[GoalLine]] = None + goalLines: Optional[list[GoalLine]] = None xAxis: Optional[ChartAxis] = None - yAxis: Optional[List[ChartAxis]] = None + yAxis: Optional[list[ChartAxis]] = None class DataWarehousePersonPropertyFilter(BaseModel): @@ -1153,7 +1153,7 @@ class DataWarehousePersonPropertyFilter(BaseModel): label: Optional[str] = None operator: PropertyOperator type: Literal["data_warehouse_person_property"] = "data_warehouse_person_property" - value: Optional[Union[str, float, List[Union[str, float]]]] = None + value: Optional[Union[str, float, list[Union[str, float]]]] = None class DataWarehousePropertyFilter(BaseModel): @@ -1164,7 +1164,7 @@ class DataWarehousePropertyFilter(BaseModel): label: Optional[str] = None operator: PropertyOperator type: Literal["data_warehouse"] = "data_warehouse" - value: Optional[Union[str, float, List[Union[str, float]]]] = None + value: Optional[Union[str, float, list[Union[str, float]]]] = None class ElementPropertyFilter(BaseModel): @@ -1175,7 +1175,7 
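The unparameterized forms convert too: bare `List` becomes bare `list`, and nested generics compose directly. A hypothetical sketch in the shape of the response models above:

from dataclasses import dataclass
from typing import Optional

@dataclass
class QueryRows:  # hypothetical; mirrors ActorsQueryResponse-style fields
    columns: list                    # was bare List
    results: list[list]              # was List[List]
    types: list[str]                 # was List[str]
    timings: Optional[list[float]] = None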
@@ class ElementPropertyFilter(BaseModel): label: Optional[str] = None operator: PropertyOperator type: Literal["element"] = "element" - value: Optional[Union[str, float, List[Union[str, float]]]] = None + value: Optional[Union[str, float, list[Union[str, float]]]] = None class EventPropertyFilter(BaseModel): @@ -1186,22 +1186,22 @@ class EventPropertyFilter(BaseModel): label: Optional[str] = None operator: Optional[PropertyOperator] = PropertyOperator("exact") type: Literal["event"] = Field(default="event", description="Event properties") - value: Optional[Union[str, float, List[Union[str, float]]]] = None + value: Optional[Union[str, float, list[Union[str, float]]]] = None class EventsQueryResponse(BaseModel): model_config = ConfigDict( extra="forbid", ) - columns: List + columns: list hasMore: Optional[bool] = None hogql: str limit: Optional[int] = None modifiers: Optional[HogQLQueryModifiers] = None offset: Optional[int] = None - results: List[List] - timings: Optional[List[QueryTiming]] = None - types: List[str] + results: list[list] + timings: Optional[list[QueryTiming]] = None + types: list[str] class FeaturePropertyFilter(BaseModel): @@ -1212,22 +1212,22 @@ class FeaturePropertyFilter(BaseModel): label: Optional[str] = None operator: PropertyOperator type: Literal["feature"] = Field(default="feature", description='Event property with "$feature/" prepended') - value: Optional[Union[str, float, List[Union[str, float]]]] = None + value: Optional[Union[str, float, list[Union[str, float]]]] = None class FunnelCorrelationResponse(BaseModel): model_config = ConfigDict( extra="forbid", ) - columns: Optional[List] = None + columns: Optional[list] = None hasMore: Optional[bool] = None hogql: Optional[str] = None limit: Optional[int] = None modifiers: Optional[HogQLQueryModifiers] = None offset: Optional[int] = None results: FunnelCorrelationResult - timings: Optional[List[QueryTiming]] = None - types: Optional[List] = None + timings: Optional[list[QueryTiming]] = None + types: Optional[list] = None class FunnelsFilterLegacy(BaseModel): @@ -1237,7 +1237,7 @@ class FunnelsFilterLegacy(BaseModel): bin_count: Optional[Union[float, str]] = None breakdown_attribution_type: Optional[BreakdownAttributionType] = None breakdown_attribution_value: Optional[float] = None - exclusions: Optional[List[FunnelExclusionLegacy]] = None + exclusions: Optional[list[FunnelExclusionLegacy]] = None funnel_aggregate_by_hogql: Optional[str] = None funnel_from_step: Optional[float] = None funnel_order_type: Optional[StepOrderValue] = None @@ -1246,7 +1246,7 @@ class FunnelsFilterLegacy(BaseModel): funnel_viz_type: Optional[FunnelVizType] = None funnel_window_interval: Optional[float] = None funnel_window_interval_unit: Optional[FunnelConversionWindowTimeUnit] = None - hidden_legend_breakdowns: Optional[List[str]] = None + hidden_legend_breakdowns: Optional[list[str]] = None layout: Optional[FunnelLayout] = None @@ -1259,8 +1259,8 @@ class FunnelsQueryResponse(BaseModel): last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: Union[FunnelTimeToConvertResults, List[Dict[str, Any]], List[List[Dict[str, Any]]]] - timings: Optional[List[QueryTiming]] = None + results: Union[FunnelTimeToConvertResults, list[dict[str, Any]], list[list[dict[str, Any]]]] + timings: Optional[list[QueryTiming]] = None class GroupPropertyFilter(BaseModel): @@ -1272,7 +1272,7 @@ class GroupPropertyFilter(BaseModel): label: Optional[str] = None operator: 
PropertyOperator type: Literal["group"] = "group" - value: Optional[Union[str, float, List[Union[str, float]]]] = None + value: Optional[Union[str, float, list[Union[str, float]]]] = None class HogQLAutocompleteResponse(BaseModel): @@ -1280,8 +1280,8 @@ class HogQLAutocompleteResponse(BaseModel): extra="forbid", ) incomplete_list: bool = Field(..., description="Whether or not the suggestions returned are complete") - suggestions: List[AutocompleteCompletionItem] - timings: Optional[List[QueryTiming]] = Field( + suggestions: list[AutocompleteCompletionItem] + timings: Optional[list[QueryTiming]] = Field( default=None, description="Measured timings for different parts of the query generation process" ) @@ -1290,13 +1290,13 @@ class HogQLMetadataResponse(BaseModel): model_config = ConfigDict( extra="forbid", ) - errors: List[HogQLNotice] + errors: list[HogQLNotice] inputExpr: Optional[str] = None inputSelect: Optional[str] = None isValid: Optional[bool] = None isValidView: Optional[bool] = None - notices: List[HogQLNotice] - warnings: List[HogQLNotice] + notices: list[HogQLNotice] + warnings: list[HogQLNotice] class HogQLPropertyFilter(BaseModel): @@ -1306,7 +1306,7 @@ class HogQLPropertyFilter(BaseModel): key: str label: Optional[str] = None type: Literal["hogql"] = "hogql" - value: Optional[Union[str, float, List[Union[str, float]]]] = None + value: Optional[Union[str, float, list[Union[str, float]]]] = None class HogQLQueryResponse(BaseModel): @@ -1314,11 +1314,11 @@ class HogQLQueryResponse(BaseModel): extra="forbid", ) clickhouse: Optional[str] = Field(default=None, description="Executed ClickHouse query") - columns: Optional[List] = Field(default=None, description="Returned columns") + columns: Optional[list] = Field(default=None, description="Returned columns") error: Optional[str] = Field( default=None, description="Query error. Returned only if 'explain' is true. Throws an error otherwise." 
) - explain: Optional[List[str]] = Field(default=None, description="Query explanation output") + explain: Optional[list[str]] = Field(default=None, description="Query explanation output") hasMore: Optional[bool] = None hogql: Optional[str] = Field(default=None, description="Generated HogQL query") limit: Optional[int] = None @@ -1328,11 +1328,11 @@ class HogQLQueryResponse(BaseModel): ) offset: Optional[int] = None query: Optional[str] = Field(default=None, description="Input query string") - results: Optional[List] = Field(default=None, description="Query results") - timings: Optional[List[QueryTiming]] = Field( + results: Optional[list] = Field(default=None, description="Query results") + timings: Optional[list[QueryTiming]] = Field( default=None, description="Measured timings for different parts of the query generation process" ) - types: Optional[List] = Field(default=None, description="Types of returned columns") + types: Optional[list] = Field(default=None, description="Types of returned columns") class InsightActorsQueryBase(BaseModel): @@ -1349,7 +1349,7 @@ class LifecycleFilter(BaseModel): extra="forbid", ) showValuesOnSeries: Optional[bool] = None - toggledLifecycles: Optional[List[LifecycleToggle]] = None + toggledLifecycles: Optional[list[LifecycleToggle]] = None class LifecycleFilterLegacy(BaseModel): @@ -1357,7 +1357,7 @@ class LifecycleFilterLegacy(BaseModel): extra="forbid", ) show_values_on_series: Optional[bool] = None - toggledLifecycles: Optional[List[LifecycleToggle]] = None + toggledLifecycles: Optional[list[LifecycleToggle]] = None class LifecycleQueryResponse(BaseModel): @@ -1369,8 +1369,8 @@ class LifecycleQueryResponse(BaseModel): last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List[Dict[str, Any]] - timings: Optional[List[QueryTiming]] = None + results: list[dict[str, Any]] + timings: Optional[list[QueryTiming]] = None class Node(BaseModel): @@ -1389,8 +1389,8 @@ class PathsQueryResponse(BaseModel): last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List[Dict[str, Any]] - timings: Optional[List[QueryTiming]] = None + results: list[dict[str, Any]] + timings: Optional[list[QueryTiming]] = None class PersonPropertyFilter(BaseModel): @@ -1401,7 +1401,7 @@ class PersonPropertyFilter(BaseModel): label: Optional[str] = None operator: PropertyOperator type: Literal["person"] = Field(default="person", description="Person properties") - value: Optional[Union[str, float, List[Union[str, float]]]] = None + value: Optional[Union[str, float, list[Union[str, float]]]] = None class QueryResponse(BaseModel): @@ -1414,38 +1414,38 @@ class QueryResponse(BaseModel): modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None results: Any - timings: Optional[List[QueryTiming]] = None + timings: Optional[list[QueryTiming]] = None class QueryResponseAlternative3(BaseModel): model_config = ConfigDict( extra="forbid", ) - columns: List + columns: list hasMore: Optional[bool] = None hogql: str limit: Optional[int] = None modifiers: Optional[HogQLQueryModifiers] = None offset: Optional[int] = None - results: List[List] - timings: Optional[List[QueryTiming]] = None - types: List[str] + results: list[list] + timings: Optional[list[QueryTiming]] = None + types: list[str] class QueryResponseAlternative4(BaseModel): model_config = ConfigDict( extra="forbid", ) - columns: 
List + columns: list hasMore: Optional[bool] = None hogql: str limit: int missing_actors_count: Optional[int] = None modifiers: Optional[HogQLQueryModifiers] = None offset: int - results: List[List] - timings: Optional[List[QueryTiming]] = None - types: List[str] + results: list[list] + timings: Optional[list[QueryTiming]] = None + types: list[str] class QueryResponseAlternative6(BaseModel): @@ -1454,8 +1454,8 @@ class QueryResponseAlternative6(BaseModel): ) hasMore: Optional[bool] = None hogql: Optional[str] = None - results: List[TimelineEntry] - timings: Optional[List[QueryTiming]] = None + results: list[TimelineEntry] + timings: Optional[list[QueryTiming]] = None class QueryResponseAlternative7(BaseModel): @@ -1463,11 +1463,11 @@ class QueryResponseAlternative7(BaseModel): extra="forbid", ) clickhouse: Optional[str] = Field(default=None, description="Executed ClickHouse query") - columns: Optional[List] = Field(default=None, description="Returned columns") + columns: Optional[list] = Field(default=None, description="Returned columns") error: Optional[str] = Field( default=None, description="Query error. Returned only if 'explain' is true. Throws an error otherwise." ) - explain: Optional[List[str]] = Field(default=None, description="Query explanation output") + explain: Optional[list[str]] = Field(default=None, description="Query explanation output") hasMore: Optional[bool] = None hogql: Optional[str] = Field(default=None, description="Generated HogQL query") limit: Optional[int] = None @@ -1477,11 +1477,11 @@ class QueryResponseAlternative7(BaseModel): ) offset: Optional[int] = None query: Optional[str] = Field(default=None, description="Input query string") - results: Optional[List] = Field(default=None, description="Query results") - timings: Optional[List[QueryTiming]] = Field( + results: Optional[list] = Field(default=None, description="Query results") + timings: Optional[list[QueryTiming]] = Field( default=None, description="Measured timings for different parts of the query generation process" ) - types: Optional[List] = Field(default=None, description="Types of returned columns") + types: Optional[list] = Field(default=None, description="Types of returned columns") class QueryResponseAlternative9(BaseModel): @@ -1489,8 +1489,8 @@ class QueryResponseAlternative9(BaseModel): extra="forbid", ) incomplete_list: bool = Field(..., description="Whether or not the suggestions returned are complete") - suggestions: List[AutocompleteCompletionItem] - timings: Optional[List[QueryTiming]] = Field( + suggestions: list[AutocompleteCompletionItem] + timings: Optional[list[QueryTiming]] = Field( default=None, description="Measured timings for different parts of the query generation process" ) @@ -1504,16 +1504,16 @@ class QueryResponseAlternative10(BaseModel): last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List[WebOverviewItem] + results: list[WebOverviewItem] samplingRate: Optional[SamplingRate] = None - timings: Optional[List[QueryTiming]] = None + timings: Optional[list[QueryTiming]] = None class QueryResponseAlternative11(BaseModel): model_config = ConfigDict( extra="forbid", ) - columns: Optional[List] = None + columns: Optional[list] = None hasMore: Optional[bool] = None hogql: Optional[str] = None is_cached: Optional[bool] = None @@ -1522,26 +1522,26 @@ class QueryResponseAlternative11(BaseModel): modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = 
None offset: Optional[int] = None - results: List + results: list samplingRate: Optional[SamplingRate] = None - timings: Optional[List[QueryTiming]] = None - types: Optional[List] = None + timings: Optional[list[QueryTiming]] = None + types: Optional[list] = None class QueryResponseAlternative12(BaseModel): model_config = ConfigDict( extra="forbid", ) - columns: Optional[List] = None + columns: Optional[list] = None hogql: Optional[str] = None is_cached: Optional[bool] = None last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List + results: list samplingRate: Optional[SamplingRate] = None - timings: Optional[List[QueryTiming]] = None - types: Optional[List] = None + timings: Optional[list[QueryTiming]] = None + types: Optional[list] = None class QueryResponseAlternative13(BaseModel): @@ -1553,23 +1553,23 @@ class QueryResponseAlternative13(BaseModel): last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List[Dict[str, Any]] - timings: Optional[List[QueryTiming]] = None + results: list[dict[str, Any]] + timings: Optional[list[QueryTiming]] = None class QueryResponseAlternative17(BaseModel): model_config = ConfigDict( extra="forbid", ) - columns: Optional[List] = None + columns: Optional[list] = None hasMore: Optional[bool] = None hogql: Optional[str] = None limit: Optional[int] = None modifiers: Optional[HogQLQueryModifiers] = None offset: Optional[int] = None results: FunnelCorrelationResult - timings: Optional[List[QueryTiming]] = None - types: Optional[List] = None + timings: Optional[list[QueryTiming]] = None + types: Optional[list] = None class RetentionFilter(BaseModel): @@ -1602,7 +1602,7 @@ class RetentionResult(BaseModel): ) date: AwareDatetime label: str - values: List[RetentionValue] + values: list[RetentionValue] class SavedInsightNode(BaseModel): @@ -1664,8 +1664,8 @@ class SessionsTimelineQueryResponse(BaseModel): ) hasMore: Optional[bool] = None hogql: Optional[str] = None - results: List[TimelineEntry] - timings: Optional[List[QueryTiming]] = None + results: list[TimelineEntry] + timings: Optional[list[QueryTiming]] = None class TimeToSeeDataJSONNode(BaseModel): @@ -1699,7 +1699,7 @@ class WebAnalyticsQueryBase(BaseModel): ) dateRange: Optional[DateRange] = None modifiers: Optional[HogQLQueryModifiers] = None - properties: List[Union[EventPropertyFilter, PersonPropertyFilter]] + properties: list[Union[EventPropertyFilter, PersonPropertyFilter]] sampling: Optional[Sampling] = None useSessionsTable: Optional[bool] = None @@ -1712,7 +1712,7 @@ class WebOverviewQuery(BaseModel): dateRange: Optional[DateRange] = None kind: Literal["WebOverviewQuery"] = "WebOverviewQuery" modifiers: Optional[HogQLQueryModifiers] = None - properties: List[Union[EventPropertyFilter, PersonPropertyFilter]] + properties: list[Union[EventPropertyFilter, PersonPropertyFilter]] response: Optional[WebOverviewQueryResponse] = None sampling: Optional[Sampling] = None useSessionsTable: Optional[bool] = None @@ -1730,7 +1730,7 @@ class WebStatsTableQuery(BaseModel): kind: Literal["WebStatsTableQuery"] = "WebStatsTableQuery" limit: Optional[int] = None modifiers: Optional[HogQLQueryModifiers] = None - properties: List[Union[EventPropertyFilter, PersonPropertyFilter]] + properties: list[Union[EventPropertyFilter, PersonPropertyFilter]] response: Optional[WebStatsTableQueryResponse] = None sampling: Optional[Sampling] = None 
useSessionsTable: Optional[bool] = None @@ -1743,7 +1743,7 @@ class WebTopClicksQuery(BaseModel): dateRange: Optional[DateRange] = None kind: Literal["WebTopClicksQuery"] = "WebTopClicksQuery" modifiers: Optional[HogQLQueryModifiers] = None - properties: List[Union[EventPropertyFilter, PersonPropertyFilter]] + properties: list[Union[EventPropertyFilter, PersonPropertyFilter]] response: Optional[WebTopClicksQueryResponse] = None sampling: Optional[Sampling] = None useSessionsTable: Optional[bool] = None @@ -1752,7 +1752,7 @@ class WebTopClicksQuery(BaseModel): class AnyResponseType( RootModel[ Union[ - Dict[str, Any], + dict[str, Any], HogQLQueryResponse, HogQLMetadataResponse, HogQLAutocompleteResponse, @@ -1762,7 +1762,7 @@ class AnyResponseType( ] ): root: Union[ - Dict[str, Any], + dict[str, Any], HogQLQueryResponse, HogQLMetadataResponse, HogQLAutocompleteResponse, @@ -1778,7 +1778,7 @@ class DashboardFilter(BaseModel): date_from: Optional[str] = None date_to: Optional[str] = None properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -1804,7 +1804,7 @@ class DataWarehouseNode(BaseModel): custom_name: Optional[str] = None distinct_id_field: str fixedProperties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -1838,7 +1838,7 @@ class DataWarehouseNode(BaseModel): ) name: Optional[str] = None properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -1855,7 +1855,7 @@ class DataWarehouseNode(BaseModel): ] ] ] = Field(default=None, description="Properties configurable in the interface") - response: Optional[Dict[str, Any]] = Field(default=None, description="Cached query response") + response: Optional[dict[str, Any]] = Field(default=None, description="Cached query response") table_name: str timestamp_field: str @@ -1868,7 +1868,7 @@ class DatabaseSchemaQuery(BaseModel): modifiers: Optional[HogQLQueryModifiers] = Field( default=None, description="Modifiers used when performing the query" ) - response: Optional[Dict[str, List[DatabaseSchemaQueryResponseField]]] = Field( + response: Optional[dict[str, list[DatabaseSchemaQueryResponseField]]] = Field( default=None, description="Cached query response" ) @@ -1879,7 +1879,7 @@ class EntityNode(BaseModel): ) custom_name: Optional[str] = None fixedProperties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -1911,7 +1911,7 @@ class EntityNode(BaseModel): ) name: Optional[str] = None properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -1928,7 +1928,7 @@ class EntityNode(BaseModel): ] ] ] = Field(default=None, description="Properties configurable in the interface") - response: Optional[Dict[str, Any]] = Field(default=None, description="Cached query response") + response: Optional[dict[str, Any]] = Field(default=None, description="Cached query response") class EventsNode(BaseModel): @@ -1938,7 +1938,7 @@ class EventsNode(BaseModel): custom_name: Optional[str] = None event: Optional[str] = Field(default=None, description="The event or `null` for all events.") fixedProperties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -1970,9 +1970,9 @@ class EventsNode(BaseModel): default=None, description="Modifiers used when performing the query" ) name: Optional[str] = None - orderBy: Optional[List[str]] = Field(default=None, description="Columns to order by") + orderBy: Optional[list[str]] = Field(default=None, description="Columns to order by") properties: 
Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2002,7 +2002,7 @@ class EventsQuery(BaseModel): event: Optional[str] = Field(default=None, description="Limit to events matching this string") filterTestAccounts: Optional[bool] = Field(default=None, description="Filter test accounts") fixedProperties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2028,10 +2028,10 @@ class EventsQuery(BaseModel): default=None, description="Modifiers used when performing the query" ) offset: Optional[int] = Field(default=None, description="Number of rows to skip before returning rows") - orderBy: Optional[List[str]] = Field(default=None, description="Columns to order by") + orderBy: Optional[list[str]] = Field(default=None, description="Columns to order by") personId: Optional[str] = Field(default=None, description="Show events for a given person") properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2049,8 +2049,8 @@ class EventsQuery(BaseModel): ] ] = Field(default=None, description="Properties configurable in the interface") response: Optional[EventsQueryResponse] = Field(default=None, description="Cached query response") - select: List[str] = Field(..., description="Return a limited set of data. Required.") - where: Optional[List[str]] = Field(default=None, description="HogQL filters to apply on returned data") + select: list[str] = Field(..., description="Return a limited set of data. Required.") + where: Optional[list[str]] = Field(default=None, description="HogQL filters to apply on returned data") class FunnelExclusionActionsNode(BaseModel): @@ -2059,7 +2059,7 @@ class FunnelExclusionActionsNode(BaseModel): ) custom_name: Optional[str] = None fixedProperties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2094,7 +2094,7 @@ class FunnelExclusionActionsNode(BaseModel): ) name: Optional[str] = None properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2111,7 +2111,7 @@ class FunnelExclusionActionsNode(BaseModel): ] ] ] = Field(default=None, description="Properties configurable in the interface") - response: Optional[Dict[str, Any]] = Field(default=None, description="Cached query response") + response: Optional[dict[str, Any]] = Field(default=None, description="Cached query response") class FunnelExclusionEventsNode(BaseModel): @@ -2121,7 +2121,7 @@ class FunnelExclusionEventsNode(BaseModel): custom_name: Optional[str] = None event: Optional[str] = Field(default=None, description="The event or `null` for all events.") fixedProperties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2155,9 +2155,9 @@ class FunnelExclusionEventsNode(BaseModel): default=None, description="Modifiers used when performing the query" ) name: Optional[str] = None - orderBy: Optional[List[str]] = Field(default=None, description="Columns to order by") + orderBy: Optional[list[str]] = Field(default=None, description="Columns to order by") properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2184,7 +2184,7 @@ class HogQLFilters(BaseModel): dateRange: Optional[DateRange] = None filterTestAccounts: Optional[bool] = None properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2215,7 +2215,7 @@ class HogQLQuery(BaseModel): ) query: str response: Optional[HogQLQueryResponse] = Field(default=None, description="Cached query response") - values: Optional[Dict[str, 
Any]] = Field( + values: Optional[dict[str, Any]] = Field( default=None, description="Constant values that can be referenced with the {placeholder} syntax in the query" ) @@ -2227,7 +2227,7 @@ class PersonsNode(BaseModel): cohort: Optional[int] = None distinctId: Optional[str] = None fixedProperties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2254,7 +2254,7 @@ class PersonsNode(BaseModel): ) offset: Optional[int] = None properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2271,7 +2271,7 @@ class PersonsNode(BaseModel): ] ] ] = Field(default=None, description="Properties configurable in the interface") - response: Optional[Dict[str, Any]] = Field(default=None, description="Cached query response") + response: Optional[dict[str, Any]] = Field(default=None, description="Cached query response") search: Optional[str] = None @@ -2280,7 +2280,7 @@ class PropertyGroupFilterValue(BaseModel): extra="forbid", ) type: FilterLogicalOperator - values: List[ + values: list[ Union[ PropertyGroupFilterValue, Union[ @@ -2310,15 +2310,15 @@ class QueryResponseAlternative14(BaseModel): last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List[RetentionResult] - timings: Optional[List[QueryTiming]] = None + results: list[RetentionResult] + timings: Optional[list[QueryTiming]] = None class QueryResponseAlternative( RootModel[ Union[ QueryResponseAlternative1, - Dict[str, Any], + dict[str, Any], QueryResponseAlternative2, QueryResponseAlternative3, QueryResponseAlternative4, @@ -2333,13 +2333,13 @@ class QueryResponseAlternative( QueryResponseAlternative13, QueryResponseAlternative14, QueryResponseAlternative17, - Dict[str, List[DatabaseSchemaQueryResponseField]], + dict[str, list[DatabaseSchemaQueryResponseField]], ] ] ): root: Union[ QueryResponseAlternative1, - Dict[str, Any], + dict[str, Any], QueryResponseAlternative2, QueryResponseAlternative3, QueryResponseAlternative4, @@ -2354,7 +2354,7 @@ class QueryResponseAlternative( QueryResponseAlternative13, QueryResponseAlternative14, QueryResponseAlternative17, - Dict[str, List[DatabaseSchemaQueryResponseField]], + dict[str, list[DatabaseSchemaQueryResponseField]], ] @@ -2367,8 +2367,8 @@ class RetentionQueryResponse(BaseModel): last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List[RetentionResult] - timings: Optional[List[QueryTiming]] = None + results: list[RetentionResult] + timings: Optional[list[QueryTiming]] = None class SessionsTimelineQuery(BaseModel): @@ -2395,7 +2395,7 @@ class ActionsNode(BaseModel): ) custom_name: Optional[str] = None fixedProperties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2428,7 +2428,7 @@ class ActionsNode(BaseModel): ) name: Optional[str] = None properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2445,7 +2445,7 @@ class ActionsNode(BaseModel): ] ] ] = Field(default=None, description="Properties configurable in the interface") - response: Optional[Dict[str, Any]] = Field(default=None, description="Cached query response") + response: Optional[dict[str, Any]] = Field(default=None, description="Cached query response") class DataVisualizationNode(BaseModel): @@ -2465,7 +2465,7 @@ class FunnelsFilter(BaseModel): binCount: Optional[int] = None breakdownAttributionType: 
Optional[BreakdownAttributionType] = None breakdownAttributionValue: Optional[int] = None - exclusions: Optional[List[Union[FunnelExclusionEventsNode, FunnelExclusionActionsNode]]] = None + exclusions: Optional[list[Union[FunnelExclusionEventsNode, FunnelExclusionActionsNode]]] = None funnelAggregateByHogQL: Optional[str] = None funnelFromStep: Optional[int] = None funnelOrderType: Optional[StepOrderValue] = None @@ -2474,7 +2474,7 @@ class FunnelsFilter(BaseModel): funnelVizType: Optional[FunnelVizType] = None funnelWindowInterval: Optional[int] = None funnelWindowIntervalUnit: Optional[FunnelConversionWindowTimeUnit] = None - hidden_legend_breakdowns: Optional[List[str]] = None + hidden_legend_breakdowns: Optional[list[str]] = None layout: Optional[FunnelLayout] = None @@ -2508,7 +2508,7 @@ class PropertyGroupFilter(BaseModel): extra="forbid", ) type: FilterLogicalOperator - values: List[PropertyGroupFilterValue] + values: list[PropertyGroupFilterValue] class RetentionQuery(BaseModel): @@ -2526,7 +2526,7 @@ class RetentionQuery(BaseModel): ) properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2567,7 +2567,7 @@ class StickinessQuery(BaseModel): ) properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2587,7 +2587,7 @@ class StickinessQuery(BaseModel): ] ] = Field(default=None, description="Property filters for all series") samplingFactor: Optional[float] = Field(default=None, description="Sampling rate") - series: List[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( + series: list[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( ..., description="Events and actions to include" ) stickinessFilter: Optional[StickinessFilter] = Field( @@ -2614,7 +2614,7 @@ class TrendsQuery(BaseModel): ) properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2635,7 +2635,7 @@ class TrendsQuery(BaseModel): ] = Field(default=None, description="Property filters for all series") response: Optional[TrendsQueryResponse] = None samplingFactor: Optional[float] = Field(default=None, description="Sampling rate") - series: List[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( + series: list[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( ..., description="Events and actions to include" ) trendsFilter: Optional[TrendsFilter] = Field(default=None, description="Properties specific to the trends insight") @@ -2645,22 +2645,22 @@ class FilterType(BaseModel): model_config = ConfigDict( extra="forbid", ) - actions: Optional[List[Dict[str, Any]]] = None + actions: Optional[list[dict[str, Any]]] = None aggregation_group_type_index: Optional[float] = None - breakdown: Optional[Union[str, float, List[Union[str, float]]]] = None + breakdown: Optional[Union[str, float, list[Union[str, float]]]] = None breakdown_group_type_index: Optional[float] = None breakdown_hide_other_aggregation: Optional[bool] = None breakdown_limit: Optional[int] = None breakdown_normalize_url: Optional[bool] = None breakdown_type: Optional[BreakdownType] = None - breakdowns: Optional[List[Breakdown]] = None - data_warehouse: Optional[List[Dict[str, Any]]] = None + breakdowns: Optional[list[Breakdown]] = None + data_warehouse: Optional[list[dict[str, Any]]] = None date_from: Optional[str] = None date_to: Optional[str] = None entity_id: Optional[Union[str, float]] = None entity_math: Optional[str] = None entity_type: Optional[EntityType] = None - events: Optional[List[Dict[str, 
Any]]] = None + events: Optional[list[dict[str, Any]]] = None explicit_date: Optional[Union[bool, str]] = Field( default=None, description='Whether the `date_from` and `date_to` should be used verbatim. Disables rounding to the start and end of period. Strings are cast to bools, e.g. "true" -> true.', @@ -2669,10 +2669,10 @@ class FilterType(BaseModel): from_dashboard: Optional[Union[bool, float]] = None insight: Optional[InsightType] = None interval: Optional[IntervalType] = None - new_entity: Optional[List[Dict[str, Any]]] = None + new_entity: Optional[list[dict[str, Any]]] = None properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2716,7 +2716,7 @@ class FunnelsQuery(BaseModel): ) properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2736,7 +2736,7 @@ class FunnelsQuery(BaseModel): ] ] = Field(default=None, description="Property filters for all series") samplingFactor: Optional[float] = Field(default=None, description="Sampling rate") - series: List[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( + series: list[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( ..., description="Events and actions to include" ) @@ -2756,7 +2756,7 @@ class InsightsQueryBase(BaseModel): ) properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2798,7 +2798,7 @@ class LifecycleQuery(BaseModel): ) properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2819,7 +2819,7 @@ class LifecycleQuery(BaseModel): ] = Field(default=None, description="Property filters for all series") response: Optional[LifecycleQueryResponse] = None samplingFactor: Optional[float] = Field(default=None, description="Sampling rate") - series: List[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( + series: list[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( ..., description="Events and actions to include" ) @@ -2844,7 +2844,7 @@ class FunnelsActorsQuery(BaseModel): model_config = ConfigDict( extra="forbid", ) - funnelCustomSteps: Optional[List[int]] = Field( + funnelCustomSteps: Optional[list[int]] = Field( default=None, description="Custom step numbers to get persons for. This overrides `funnelStep`. Primarily for correlation use.", ) @@ -2852,7 +2852,7 @@ class FunnelsActorsQuery(BaseModel): default=None, description="Index of the step for which we want to get the timestamp for, per person. Positive for converted persons, negative for dropped of persons.", ) - funnelStepBreakdown: Optional[Union[str, float, List[Union[str, float]]]] = Field( + funnelStepBreakdown: Optional[Union[str, float, list[Union[str, float]]]] = Field( default=None, description="The breakdown value for which to get persons for. 
This is an array for person and event properties, a string for groups and an integer for cohorts.", ) @@ -2887,7 +2887,7 @@ class PathsQuery(BaseModel): pathsFilter: PathsFilter = Field(..., description="Properties specific to the paths insight") properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2914,11 +2914,11 @@ class FunnelCorrelationQuery(BaseModel): model_config = ConfigDict( extra="forbid", ) - funnelCorrelationEventExcludePropertyNames: Optional[List[str]] = None - funnelCorrelationEventNames: Optional[List[str]] = None - funnelCorrelationExcludeEventNames: Optional[List[str]] = None - funnelCorrelationExcludeNames: Optional[List[str]] = None - funnelCorrelationNames: Optional[List[str]] = None + funnelCorrelationEventExcludePropertyNames: Optional[list[str]] = None + funnelCorrelationEventNames: Optional[list[str]] = None + funnelCorrelationExcludeEventNames: Optional[list[str]] = None + funnelCorrelationExcludeNames: Optional[list[str]] = None + funnelCorrelationNames: Optional[list[str]] = None funnelCorrelationType: FunnelCorrelationResultsType kind: Literal["FunnelCorrelationQuery"] = "FunnelCorrelationQuery" response: Optional[FunnelCorrelationResponse] = None @@ -2956,7 +2956,7 @@ class FunnelCorrelationActorsQuery(BaseModel): funnelCorrelationPersonConverted: Optional[bool] = None funnelCorrelationPersonEntity: Optional[Union[EventsNode, ActionsNode, DataWarehouseNode]] = None funnelCorrelationPropertyValues: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -3015,7 +3015,7 @@ class ActorsQuery(BaseModel): extra="forbid", ) fixedProperties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -3038,9 +3038,9 @@ class ActorsQuery(BaseModel): default=None, description="Modifiers used when performing the query" ) offset: Optional[int] = None - orderBy: Optional[List[str]] = None + orderBy: Optional[list[str]] = None properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -3059,7 +3059,7 @@ class ActorsQuery(BaseModel): ] = None response: Optional[ActorsQueryResponse] = Field(default=None, description="Cached query response") search: Optional[str] = None - select: Optional[List[str]] = None + select: Optional[list[str]] = None source: Optional[Union[InsightActorsQuery, FunnelsActorsQuery, FunnelCorrelationActorsQuery, HogQLQuery]] = None @@ -3070,7 +3070,7 @@ class DataTableNode(BaseModel): allowSorting: Optional[bool] = Field( default=None, description="Can the user click on column headers to sort the table? (default: true)" ) - columns: Optional[List[str]] = Field( + columns: Optional[list[str]] = Field( default=None, description="Columns shown in the table, unless the `source` provides them." ) embedded: Optional[bool] = Field(default=None, description="Uses the embedded version of LemonTable") @@ -3078,7 +3078,7 @@ class DataTableNode(BaseModel): default=None, description="Can expand row to show raw event data (default: true)" ) full: Optional[bool] = Field(default=None, description="Show with most visual options enabled. 
Used in scenes.") - hiddenColumns: Optional[List[str]] = Field( + hiddenColumns: Optional[list[str]] = Field( default=None, description="Columns that aren't shown in the table, even if in columns or returned data" ) kind: Literal["DataTableNode"] = "DataTableNode" diff --git a/posthog/session_recordings/models/metadata.py b/posthog/session_recordings/models/metadata.py index 4d75c70dae4..dd26fde6a3b 100644 --- a/posthog/session_recordings/models/metadata.py +++ b/posthog/session_recordings/models/metadata.py @@ -1,7 +1,7 @@ from datetime import datetime -from typing import Dict, List, Optional, TypedDict, Union, Literal +from typing import Optional, TypedDict, Union, Literal -SnapshotData = Dict +SnapshotData = dict WindowId = Optional[str] @@ -22,7 +22,7 @@ class SessionRecordingEventSummary(TypedDict): timestamp: int type: int # keys of this object should be any of EVENT_SUMMARY_DATA_INCLUSIONS - data: Dict[str, Union[int, str]] + data: dict[str, Union[int, str]] # NOTE: MatchingSessionRecordingEvent is a minimal version of full events that is used to display events matching a filter on the frontend @@ -35,7 +35,7 @@ class MatchingSessionRecordingEvent(TypedDict): class DecompressedRecordingData(TypedDict): has_next: bool - snapshot_data_by_window_id: Dict[WindowId, List[Union[SnapshotData, SessionRecordingEventSummary]]] + snapshot_data_by_window_id: dict[WindowId, list[Union[SnapshotData, SessionRecordingEventSummary]]] class RecordingMetadata(TypedDict): @@ -55,10 +55,10 @@ class RecordingMetadata(TypedDict): class RecordingMatchingEvents(TypedDict): - events: List[MatchingSessionRecordingEvent] + events: list[MatchingSessionRecordingEvent] class PersistedRecordingV1(TypedDict): version: str # "2022-12-22" - snapshot_data_by_window_id: Dict[WindowId, List[Union[SnapshotData, SessionRecordingEventSummary]]] + snapshot_data_by_window_id: dict[WindowId, list[Union[SnapshotData, SessionRecordingEventSummary]]] distinct_id: str diff --git a/posthog/session_recordings/models/session_recording.py b/posthog/session_recordings/models/session_recording.py index c217d41cef8..d5ac114f8a2 100644 --- a/posthog/session_recordings/models/session_recording.py +++ b/posthog/session_recordings/models/session_recording.py @@ -1,4 +1,4 @@ -from typing import Any, List, Literal, Optional +from typing import Any, Literal, Optional from django.conf import settings from django.db import models @@ -136,7 +136,7 @@ class SessionRecording(UUIDModel): def build_object_storage_path(self, version: Literal["2023-08-01", "2022-12-22"]) -> str: if version == "2022-12-22": - path_parts: List[str] = [ + path_parts: list[str] = [ settings.OBJECT_STORAGE_SESSION_RECORDING_LTS_FOLDER, f"team-{self.team_id}", f"session-{self.session_id}", @@ -161,7 +161,7 @@ class SessionRecording(UUIDModel): return SessionRecording(session_id=session_id, team=team) @staticmethod - def get_or_build_from_clickhouse(team: Team, ch_recordings: List[dict]) -> "List[SessionRecording]": + def get_or_build_from_clickhouse(team: Team, ch_recordings: list[dict]) -> "list[SessionRecording]": session_ids = sorted([recording["session_id"] for recording in ch_recordings]) recordings_by_id = { @@ -193,7 +193,7 @@ class SessionRecording(UUIDModel): return recordings - def set_start_url_from_urls(self, urls: Optional[List[str]] = None, first_url: Optional[str] = None): + def set_start_url_from_urls(self, urls: Optional[list[str]] = None, first_url: Optional[str] = None): if first_url: self.start_url = first_url[:512] return diff --git 
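The same rewrite works in TypedDicts and plain type aliases; a minimal sketch with hypothetical names, in the style of the session-recording metadata module above:

from typing import Optional, TypedDict, Union

SnapshotData = dict  # bare alias; was typing.Dict
WindowId = Optional[str]

class EventSummary(TypedDict):  # hypothetical, mirrors SessionRecordingEventSummary
    timestamp: int
    data: dict[str, Union[int, str]]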
a/posthog/session_recordings/queries/session_query.py b/posthog/session_recordings/queries/session_query.py index d0ff7b32afb..eb856194806 100644 --- a/posthog/session_recordings/queries/session_query.py +++ b/posthog/session_recordings/queries/session_query.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Union from posthog.models import Filter from posthog.models.filters.path_filter import PathFilter @@ -29,7 +29,7 @@ class SessionQuery: self._team = team self._session_id_alias = session_id_alias - def get_query(self) -> Tuple[str, Dict]: + def get_query(self) -> tuple[str, dict]: params = {"team_id": self._team.pk} query_date_range = QueryDateRange(filter=self._filter, team=self._team, should_round=False) diff --git a/posthog/session_recordings/queries/session_recording_list_from_replay_summary.py b/posthog/session_recordings/queries/session_recording_list_from_replay_summary.py index 4f64fff7f8a..b9458c597c9 100644 --- a/posthog/session_recordings/queries/session_recording_list_from_replay_summary.py +++ b/posthog/session_recordings/queries/session_recording_list_from_replay_summary.py @@ -1,7 +1,7 @@ import dataclasses import re from datetime import datetime, timedelta -from typing import Any, Dict, List, Literal, NamedTuple, Tuple, Union +from typing import Any, Literal, NamedTuple, Union from django.conf import settings from sentry_sdk import capture_exception @@ -25,15 +25,15 @@ class SummaryEventFiltersSQL: having_conditions: str having_select: str where_conditions: str - params: Dict[str, Any] + params: dict[str, Any] class SessionRecordingQueryResult(NamedTuple): - results: List + results: list has_more_recording: bool -def _get_recording_start_time_clause(recording_filters: SessionRecordingsFilter) -> Tuple[str, Dict[str, Any]]: +def _get_recording_start_time_clause(recording_filters: SessionRecordingsFilter) -> tuple[str, dict[str, Any]]: start_time_clause = "" start_time_params = {} if recording_filters.date_from: @@ -52,7 +52,7 @@ def _get_order_by_clause(filter_order: str | None) -> str: def _get_filter_by_log_text_session_ids_clause( team: Team, recording_filters: SessionRecordingsFilter, column_name="session_id" -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: if not recording_filters.console_search_query: return "", {} @@ -66,7 +66,7 @@ def _get_filter_by_log_text_session_ids_clause( def _get_filter_by_provided_session_ids_clause( recording_filters: SessionRecordingsFilter, column_name="session_id" -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: if recording_filters.session_ids is None: return "", {} @@ -111,7 +111,7 @@ class LogQuery: # a recording spans the time boundaries # TODO This is just copied from below @cached_property - def _get_events_timestamp_clause(self) -> Tuple[str, Dict[str, Any]]: + def _get_events_timestamp_clause(self) -> tuple[str, dict[str, Any]]: timestamp_clause = "" timestamp_params = {} if self._filter.date_from: @@ -124,8 +124,8 @@ class LogQuery: @staticmethod def _get_console_log_clause( - console_logs_filter: List[Literal["error", "warn", "info"]], - ) -> Tuple[str, Dict[str, Any]]: + console_logs_filter: list[Literal["error", "warn", "info"]], + ) -> tuple[str, dict[str, Any]]: return ( ( f"AND level in %(console_logs_levels)s", @@ -135,7 +135,7 @@ class LogQuery: else ("", {}) ) - def get_query(self) -> Tuple[str, Dict]: + def get_query(self) -> tuple[str, dict]: if not self._filter.console_search_query: return "", {} @@ -177,7 +177,7 @@ class 
ActorsQuery(EventQuery): pass # we have to implement this from EventQuery but don't need it - def _data_to_return(self, results: List[Any]) -> List[Dict[str, Any]]: + def _data_to_return(self, results: list[Any]) -> list[dict[str, Any]]: pass _raw_persons_query = """ @@ -195,7 +195,7 @@ class ActorsQuery(EventQuery): {filter_by_person_uuid_condition} """ - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: # we don't support PoE V1 - hopefully that's ok if self._person_on_events_mode == PersonsOnEventsMode.person_id_override_properties_on_events: return "", {} @@ -280,7 +280,7 @@ class SessionIdEventsQuery(EventQuery): pass # we have to implement this from EventQuery but don't need it - def _data_to_return(self, results: List[Any]) -> List[Dict[str, Any]]: + def _data_to_return(self, results: list[Any]) -> list[dict[str, Any]]: pass def _determine_should_join_events(self): @@ -354,7 +354,7 @@ class SessionIdEventsQuery(EventQuery): HAVING 1=1 {event_filter_having_events_condition} """ - def format_event_filter(self, entity: Entity, prepend: str, team_id: int) -> Tuple[str, Dict[str, Any]]: + def format_event_filter(self, entity: Entity, prepend: str, team_id: int) -> tuple[str, dict[str, Any]]: filter_sql, params = format_entity_filter( team_id=team_id, entity=entity, @@ -382,8 +382,8 @@ class SessionIdEventsQuery(EventQuery): @cached_property def build_event_filters(self) -> SummaryEventFiltersSQL: - event_names_to_filter: List[Union[int, str]] = [] - params: Dict = {} + event_names_to_filter: list[Union[int, str]] = [] + params: dict = {} condition_sql = "" for index, entity in enumerate(self._filter.entities): @@ -432,7 +432,7 @@ class SessionIdEventsQuery(EventQuery): params=params, ) - def _get_groups_query(self) -> Tuple[str, Dict]: + def _get_groups_query(self) -> tuple[str, dict]: try: from ee.clickhouse.queries.groups_join_query import GroupsJoinQuery except ImportError: @@ -449,7 +449,7 @@ class SessionIdEventsQuery(EventQuery): # We want to select events beyond the range of the recording to handle the case where # a recording spans the time boundaries @cached_property - def _get_events_timestamp_clause(self) -> Tuple[str, Dict[str, Any]]: + def _get_events_timestamp_clause(self) -> tuple[str, dict[str, Any]]: timestamp_clause = "" timestamp_params = {} if self._filter.date_from: @@ -460,7 +460,7 @@ class SessionIdEventsQuery(EventQuery): timestamp_params["event_end_time"] = self._filter.date_to + timedelta(hours=12) return timestamp_clause, timestamp_params - def get_query(self, select_event_ids: bool = False) -> Tuple[str, Dict[str, Any]]: + def get_query(self, select_event_ids: bool = False) -> tuple[str, dict[str, Any]]: if not self._determine_should_join_events(): return "", {} @@ -564,7 +564,7 @@ class SessionIdEventsQuery(EventQuery): return persons_join, persons_select_params, persons_sub_query @cached_property - def _get_person_id_clause(self) -> Tuple[str, Dict[str, Any]]: + def _get_person_id_clause(self) -> tuple[str, dict[str, Any]]: person_id_clause = "" person_id_params = {} if self._filter.person_uuid: @@ -572,7 +572,7 @@ class SessionIdEventsQuery(EventQuery): person_id_params = {"person_uuid": self._filter.person_uuid} return person_id_clause, person_id_params - def matching_events(self) -> List[str]: + def matching_events(self) -> list[str]: self._filter.hogql_context.modifiers.personsOnEventsMode = self._person_on_events_mode query, query_params = self.get_query(select_event_ids=True) query_results = 
sync_execute(query, {**query_params, **self._filter.hogql_context.values}) @@ -644,7 +644,7 @@ class SessionRecordingListFromReplaySummary(EventQuery): """ @staticmethod - def _data_to_return(results: List[Any]) -> List[Dict[str, Any]]: + def _data_to_return(results: list[Any]) -> list[dict[str, Any]]: default_columns = [ "session_id", "team_id", @@ -694,7 +694,7 @@ class SessionRecordingListFromReplaySummary(EventQuery): def limit(self): return self._filter.limit or self.SESSION_RECORDINGS_DEFAULT_LIMIT - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: offset = self._filter.offset or 0 base_params = { @@ -758,7 +758,7 @@ class SessionRecordingListFromReplaySummary(EventQuery): def duration_clause( self, duration_filter_type: Literal["duration", "active_seconds", "inactive_seconds"], - ) -> Tuple[str, Dict[str, Any]]: + ) -> tuple[str, dict[str, Any]]: duration_clause = "" duration_params = {} if self._filter.recording_duration_filter: @@ -775,7 +775,7 @@ class SessionRecordingListFromReplaySummary(EventQuery): return duration_clause, duration_params @staticmethod - def _get_console_log_clause(console_logs_filter: List[Literal["error", "warn", "info"]]) -> str: + def _get_console_log_clause(console_logs_filter: list[Literal["error", "warn", "info"]]) -> str: # to avoid a CH migration we map from info to log when constructing the query here filters = [f"console_{'log' if log == 'info' else log}_count > 0" for log in console_logs_filter] return f"AND ({' OR '.join(filters)})" if filters else "" diff --git a/posthog/session_recordings/queries/session_recording_properties.py b/posthog/session_recordings/queries/session_recording_properties.py index e7c5544f14f..2d2ef187c04 100644 --- a/posthog/session_recordings/queries/session_recording_properties.py +++ b/posthog/session_recordings/queries/session_recording_properties.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Tuple +from typing import TYPE_CHECKING, Any, NamedTuple from posthog.client import sync_execute from posthog.models.event.util import parse_properties @@ -14,12 +14,12 @@ class EventFiltersSQL(NamedTuple): aggregate_select_clause: str aggregate_having_clause: str where_conditions: str - params: Dict[str, Any] + params: dict[str, Any] class SessionRecordingProperties(EventQuery): _filter: SessionRecordingsFilter - _session_ids: List[str] + _session_ids: list[str] SESSION_RECORDING_PROPERTIES_ALLOWLIST = { "$os", @@ -47,7 +47,7 @@ class SessionRecordingProperties(EventQuery): GROUP BY session_id """ - def __init__(self, team: "Team", session_ids: List[str], filter: SessionRecordingsFilter): + def __init__(self, team: "Team", session_ids: list[str], filter: SessionRecordingsFilter): super().__init__(team=team, filter=filter) self._session_ids = sorted(session_ids) # Sort for stable queries @@ -56,7 +56,7 @@ class SessionRecordingProperties(EventQuery): # We want to select events beyond the range of the recording to handle the case where # a recording spans the time boundaries - def _get_events_timestamp_clause(self) -> Tuple[str, Dict[str, Any]]: + def _get_events_timestamp_clause(self) -> tuple[str, dict[str, Any]]: timestamp_clause = "" timestamp_params = {} if self._filter.date_from: @@ -67,11 +67,11 @@ class SessionRecordingProperties(EventQuery): timestamp_params["event_end_time"] = self._filter.date_to + timedelta(hours=12) return timestamp_clause, timestamp_params - def 
format_session_recording_id_filters(self) -> Tuple[str, Dict]: + def format_session_recording_id_filters(self) -> tuple[str, dict]: where_conditions = "AND session_id IN %(session_ids)s" return where_conditions, {"session_ids": self._session_ids} - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: base_params = {"team_id": self._team_id} ( events_timestamp_clause, @@ -90,7 +90,7 @@ class SessionRecordingProperties(EventQuery): {**base_params, **events_timestamp_params, **session_ids_params}, ) - def _data_to_return(self, results: List[Any]) -> List[Dict[str, Any]]: + def _data_to_return(self, results: list[Any]) -> list[dict[str, Any]]: return [ { "session_id": row[0], @@ -99,7 +99,7 @@ class SessionRecordingProperties(EventQuery): for row in results ] - def run(self) -> List: + def run(self) -> list: query, query_params = self.get_query() query_results = sync_execute(query, query_params) session_recording_properties = self._data_to_return(query_results) diff --git a/posthog/session_recordings/queries/session_replay_events.py b/posthog/session_recordings/queries/session_replay_events.py index fbb3577bf03..226d27154fd 100644 --- a/posthog/session_recordings/queries/session_replay_events.py +++ b/posthog/session_recordings/queries/session_replay_events.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import Optional, Tuple, List +from typing import Optional from django.conf import settings @@ -75,7 +75,7 @@ class SessionReplayEvents: ) ) - replay_response: List[Tuple] = sync_execute( + replay_response: list[tuple] = sync_execute( query, { "team_id": team.pk, @@ -107,8 +107,8 @@ class SessionReplayEvents: ) def get_events( - self, session_id: str, team: Team, metadata: RecordingMetadata, events_to_ignore: List[str] | None - ) -> Tuple[List | None, List | None]: + self, session_id: str, team: Team, metadata: RecordingMetadata, events_to_ignore: list[str] | None + ) -> tuple[list | None, list | None]: from posthog.schema import HogQLQuery, HogQLQueryResponse from posthog.hogql_queries.hogql_query_runner import HogQLQueryRunner diff --git a/posthog/session_recordings/queries/test/session_replay_sql.py b/posthog/session_recordings/queries/test/session_replay_sql.py index b72c64dbc0f..fbec2ea0650 100644 --- a/posthog/session_recordings/queries/test/session_replay_sql.py +++ b/posthog/session_recordings/queries/test/session_replay_sql.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional, List, Dict +from typing import Optional from uuid import uuid4 from dateutil.parser import parse @@ -113,7 +113,7 @@ def produce_replay_summary( console_log_count: Optional[int] = None, console_warn_count: Optional[int] = None, console_error_count: Optional[int] = None, - log_messages: Dict[str, List[str]] | None = None, + log_messages: dict[str, list[str]] | None = None, snapshot_source: str | None = None, ): if log_messages is None: diff --git a/posthog/session_recordings/queries/test/test_session_recording_list_from_session_replay.py b/posthog/session_recordings/queries/test/test_session_recording_list_from_session_replay.py index 1af1554415d..5abfc3727fe 100644 --- a/posthog/session_recordings/queries/test/test_session_recording_list_from_session_replay.py +++ b/posthog/session_recordings/queries/test/test_session_recording_list_from_session_replay.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Dict from uuid import uuid4 from dateutil.relativedelta import relativedelta @@ -76,7 +75,7 @@ 
class TestClickhouseSessionRecordingsListFromSessionReplay(ClickhouseTestMixin, properties=properties, ) - def _filter_recordings_by(self, recordings_filter: Dict) -> SessionRecordingQueryResult: + def _filter_recordings_by(self, recordings_filter: dict) -> SessionRecordingQueryResult: the_filter = SessionRecordingsFilter(team=self.team, data=recordings_filter) session_recording_list_instance = SessionRecordingListFromReplaySummary(filter=the_filter, team=self.team) return session_recording_list_instance.run() diff --git a/posthog/session_recordings/realtime_snapshots.py b/posthog/session_recordings/realtime_snapshots.py index d6890c63517..8e943db34a6 100644 --- a/posthog/session_recordings/realtime_snapshots.py +++ b/posthog/session_recordings/realtime_snapshots.py @@ -1,6 +1,6 @@ import json from time import sleep -from typing import Dict, List, Optional +from typing import Optional import structlog from prometheus_client import Counter @@ -54,7 +54,7 @@ def publish_subscription(team_id: str, session_id: str) -> None: raise e -def get_realtime_snapshots(team_id: str, session_id: str, attempt_count=0) -> Optional[List[Dict]]: +def get_realtime_snapshots(team_id: str, session_id: str, attempt_count=0) -> Optional[list[dict]]: try: redis = get_client(settings.SESSION_RECORDING_REDIS_URL) key = get_key(team_id, session_id) diff --git a/posthog/session_recordings/session_recording_api.py b/posthog/session_recordings/session_recording_api.py index e7f4ac77696..d9a9fc303d1 100644 --- a/posthog/session_recordings/session_recording_api.py +++ b/posthog/session_recordings/session_recording_api.py @@ -3,7 +3,7 @@ import time from datetime import datetime, timedelta, timezone import json -from typing import Any, List, Type, cast, Dict, Tuple +from typing import Any, cast from django.conf import settings @@ -191,7 +191,7 @@ class SessionRecordingSnapshotsSerializer(serializers.Serializer): def list_recordings_response( - filter: SessionRecordingsFilter, request: request.Request, serializer_context: Dict[str, Any] + filter: SessionRecordingsFilter, request: request.Request, serializer_context: dict[str, Any] ) -> Response: (recordings, timings) = list_recordings(filter, request, context=serializer_context) response = Response(recordings) @@ -211,7 +211,7 @@ class SessionRecordingViewSet(TeamAndOrgViewSetMixin, viewsets.GenericViewSet): sharing_enabled_actions = ["retrieve", "snapshots", "snapshot_file"] - def get_serializer_class(self) -> Type[serializers.Serializer]: + def get_serializer_class(self) -> type[serializers.Serializer]: if isinstance(self.request.successful_authenticator, SharingAccessTokenAuthentication): return SessionRecordingSharedSerializer else: @@ -252,7 +252,7 @@ class SessionRecordingViewSet(TeamAndOrgViewSetMixin, viewsets.GenericViewSet): "Must specify at least one event or action filter", ) - matching_events: List[str] = SessionIdEventsQuery(filter=filter, team=self.team).matching_events() + matching_events: list[str] = SessionIdEventsQuery(filter=filter, team=self.team).matching_events() return JsonResponse(data={"results": matching_events}) # Returns metadata about the recording @@ -342,9 +342,9 @@ class SessionRecordingViewSet(TeamAndOrgViewSetMixin, viewsets.GenericViewSet): SNAPSHOT_SOURCE_REQUESTED.labels(source=source).inc() if not source: - sources: List[dict] = [] + sources: list[dict] = [] - blob_keys: List[str] | None = None + blob_keys: list[str] | None = None if recording.object_storage_path: if recording.storage_version == "2023-08-01": blob_prefix = 
recording.object_storage_path @@ -603,8 +603,8 @@ class SessionRecordingViewSet(TeamAndOrgViewSetMixin, viewsets.GenericViewSet): def list_recordings( - filter: SessionRecordingsFilter, request: request.Request, context: Dict[str, Any] -) -> Tuple[Dict, Dict]: + filter: SessionRecordingsFilter, request: request.Request, context: dict[str, Any] +) -> tuple[dict, dict]: """ As we can store recordings in S3 or in Clickhouse we need to do a few things here @@ -617,7 +617,7 @@ def list_recordings( all_session_ids = filter.session_ids - recordings: List[SessionRecording] = [] + recordings: list[SessionRecording] = [] more_recordings_available = False team = context["get_team"]() @@ -655,7 +655,7 @@ def list_recordings( if all_session_ids: recordings = sorted( recordings, - key=lambda x: cast(List[str], all_session_ids).index(x.session_id), + key=lambda x: cast(list[str], all_session_ids).index(x.session_id), ) if not request.user.is_authenticated: # for mypy diff --git a/posthog/session_recordings/session_recording_helpers.py b/posthog/session_recordings/session_recording_helpers.py index 1eccc2be26e..8dfb1c0ad23 100644 --- a/posthog/session_recordings/session_recording_helpers.py +++ b/posthog/session_recordings/session_recording_helpers.py @@ -3,7 +3,8 @@ import gzip import json from collections import defaultdict from datetime import datetime, timezone -from typing import Any, Callable, Dict, Generator, List, Tuple +from typing import Any +from collections.abc import Callable, Generator from dateutil.parser import parse from prometheus_client import Counter @@ -89,10 +90,10 @@ EVENT_SUMMARY_DATA_INCLUSIONS = [ ] -Event = Dict[str, Any] +Event = dict[str, Any] -def split_replay_events(events: List[Event]) -> Tuple[List[Event], List[Event]]: +def split_replay_events(events: list[Event]) -> tuple[list[Event], list[Event]]: replay, other = [], [] for event in events: @@ -102,12 +103,12 @@ def split_replay_events(events: List[Event]) -> Tuple[List[Event], List[Event]]: # TODO is this covered by enough tests post-blob ingester rollout -def preprocess_replay_events_for_blob_ingestion(events: List[Event], max_size_bytes=1024 * 1024) -> List[Event]: +def preprocess_replay_events_for_blob_ingestion(events: list[Event], max_size_bytes=1024 * 1024) -> list[Event]: return _process_windowed_events(events, lambda x: preprocess_replay_events(x, max_size_bytes=max_size_bytes)) def preprocess_replay_events( - _events: List[Event] | Generator[Event, None, None], max_size_bytes=1024 * 1024 + _events: list[Event] | Generator[Event, None, None], max_size_bytes=1024 * 1024 ) -> Generator[Event, None, None]: """ The events going to blob ingestion are uncompressed (the compression happens in the Kafka producer) @@ -135,7 +136,7 @@ def preprocess_replay_events( window_id = events[0]["properties"].get("$window_id") snapshot_source = events[0]["properties"].get("$snapshot_source", "web") - def new_event(items: List[dict] | None = None) -> Event: + def new_event(items: list[dict] | None = None) -> Event: return { **events[0], "event": "$snapshot_items", # New event name to avoid confusion with the old $snapshot event @@ -151,7 +152,7 @@ def preprocess_replay_events( # 1. 
Group by $snapshot_bytes if any of the events have it if events[0]["properties"].get("$snapshot_bytes"): - current_event: Dict | None = None + current_event: dict | None = None current_event_size = 0 for event in events: @@ -208,13 +209,13 @@ def preprocess_replay_events( def _process_windowed_events( - events: List[Event], fn: Callable[[List[Any]], Generator[Event, None, None]] -) -> List[Event]: + events: list[Event], fn: Callable[[list[Any]], Generator[Event, None, None]] +) -> list[Event]: """ Helper method to simplify grouping events by window_id and session_id, processing them with the given function, and then returning the flattened list """ - result: List[Event] = [] + result: list[Event] = [] snapshots_by_session_and_window_id = defaultdict(list) for event in events: @@ -228,7 +229,7 @@ def _process_windowed_events( return result -def is_unprocessed_snapshot_event(event: Dict) -> bool: +def is_unprocessed_snapshot_event(event: dict) -> bool: try: is_snapshot = event["event"] == "$snapshot" except KeyError: @@ -274,5 +275,5 @@ def convert_to_timestamp(source: str) -> int: return int(parse(source).timestamp() * 1000) -def byte_size_dict(x: Dict | List) -> int: +def byte_size_dict(x: dict | list) -> int: return len(json.dumps(x)) diff --git a/posthog/session_recordings/snapshots/convert_legacy_snapshots.py b/posthog/session_recordings/snapshots/convert_legacy_snapshots.py index 963016d0e86..d2d4ba2c4b4 100644 --- a/posthog/session_recordings/snapshots/convert_legacy_snapshots.py +++ b/posthog/session_recordings/snapshots/convert_legacy_snapshots.py @@ -1,5 +1,4 @@ import json -from typing import Dict import structlog from prometheus_client import Histogram @@ -67,7 +66,7 @@ def _prepare_legacy_content(content: str) -> str: return _convert_legacy_format_from_lts_storage(json_content) -def _convert_legacy_format_from_lts_storage(lts_formatted_data: Dict) -> str: +def _convert_legacy_format_from_lts_storage(lts_formatted_data: dict) -> str: """ The latest version is JSONL formatted data. Each line is json containing a window_id and a data array. 
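Note: the session recording helper hunks above are representative of the mechanical rewrites this patch applies across the repo: PEP 585 builtin generics (dict[...], list[...], tuple[...]) replace their typing equivalents, and abstract container types move from typing to collections.abc. A minimal runnable sketch of the target idiom on Python 3.9+ follows; the function names are illustrative, not from the patch.

from collections.abc import Callable, Generator
from typing import Any

Event = dict[str, Any]  # was: Event = Dict[str, Any]

def split_events(events: list[Event]) -> tuple[list[Event], list[Event]]:
    # PEP 585: builtin list/tuple/dict are subscriptable on Python 3.9+,
    # so the typing.List/Tuple/Dict imports can be dropped entirely.
    snapshots = [e for e in events if e.get("event") == "$snapshot"]
    other = [e for e in events if e.get("event") != "$snapshot"]
    return snapshots, other

def apply_windowed(
    events: list[Event],
    fn: Callable[[list[Event]], Generator[Event, None, None]],
) -> list[Event]:
    # Callable and Generator now come from collections.abc rather than typing,
    # matching the import rewrites in session_recording_helpers.py above.
    return list(fn(events))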
diff --git a/posthog/session_recordings/test/test_lts_session_recordings.py b/posthog/session_recordings/test/test_lts_session_recordings.py index 7d60d07defb..bd6dfc39d24 100644 --- a/posthog/session_recordings/test/test_lts_session_recordings.py +++ b/posthog/session_recordings/test/test_lts_session_recordings.py @@ -1,5 +1,4 @@ import uuid -from typing import List from unittest.mock import patch, MagicMock, call, Mock from rest_framework import status @@ -32,7 +31,7 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest) session_id = str(uuid.uuid4()) lts_storage_path = "purposefully/not/what/we/would/calculate/to/prove/this/is/used" - def list_objects_func(path: str) -> List[str]: + def list_objects_func(path: str) -> list[str]: # this mock simulates a recording whose blob storage has been deleted by TTL # but which has been stored in LTS blob storage if path == lts_storage_path: @@ -88,7 +87,7 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest) session_id = str(uuid.uuid4()) lts_storage_path = "1234-5678" - def list_objects_func(_path: str) -> List[str]: + def list_objects_func(_path: str) -> list[str]: return [] mock_list_objects.side_effect = list_objects_func @@ -138,7 +137,7 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest) session_id = str(uuid.uuid4()) lts_storage_path = "purposefully/not/what/we/would/calculate/to/prove/this/is/used" - def list_objects_func(path: str) -> List[str]: + def list_objects_func(path: str) -> list[str]: # this mock simulates a recording whose blob storage has been deleted by TTL # but which has been stored in LTS blob storage if path == lts_storage_path: @@ -208,7 +207,7 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest) session_id = str(uuid.uuid4()) lts_storage_path = "1234-5678" - def list_objects_func(path: str) -> List[str]: + def list_objects_func(path: str) -> list[str]: return [] mock_list_objects.side_effect = list_objects_func diff --git a/posthog/session_recordings/test/test_session_recording_helpers.py b/posthog/session_recordings/test/test_session_recording_helpers.py index b6b83e02c28..a13b131fb31 100644 --- a/posthog/session_recordings/test/test_session_recording_helpers.py +++ b/posthog/session_recordings/test/test_session_recording_helpers.py @@ -3,7 +3,7 @@ import math import random import string from datetime import datetime -from typing import Any, List, Tuple +from typing import Any import pytest from pytest_mock import MockerFixture @@ -27,7 +27,7 @@ def create_activity_data(timestamp: datetime, is_active: bool): ) -def mock_capture_flow(events: List[dict], max_size_bytes=512 * 1024) -> Tuple[List[dict], List[dict]]: +def mock_capture_flow(events: list[dict], max_size_bytes=512 * 1024) -> tuple[list[dict], list[dict]]: """ Returns the legacy events and the new flow ones """ @@ -422,7 +422,7 @@ def test_new_ingestion_groups_using_snapshot_bytes_if_possible(raw_snapshot_even "something": "small", } - events: List[Any] = [ + events: list[Any] = [ { "event": "$snapshot", "properties": { diff --git a/posthog/session_recordings/test/test_session_recordings.py b/posthog/session_recordings/test/test_session_recordings.py index 78f92c24a0a..6b0721d6730 100644 --- a/posthog/session_recordings/test/test_session_recordings.py +++ b/posthog/session_recordings/test/test_session_recordings.py @@ -1,7 +1,6 @@ import time import uuid from datetime import datetime, timedelta, timezone -from typing import List from 
unittest.mock import ANY, patch, MagicMock, call from urllib.parse import urlencode @@ -65,7 +64,7 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest) # because we use `now()` in the CH queries which don't know about any frozen time # @snapshot_clickhouse_queries def test_get_session_recordings(self): - twelve_distinct_ids: List[str] = [f"user_one_{i}" for i in range(12)] + twelve_distinct_ids: list[str] = [f"user_one_{i}" for i in range(12)] user = Person.objects.create( team=self.team, @@ -132,7 +131,7 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest) # almost duplicate of test_get_session_recordings above # but if we have multiple distinct ids on a recording the snapshot # varies which makes the snapshot useless - twelve_distinct_ids: List[str] = [f"user_one_{i}" for i in range(12)] + twelve_distinct_ids: list[str] = [f"user_one_{i}" for i in range(12)] Person.objects.create( team=self.team, @@ -577,7 +576,7 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest) object_storage_path="an lts stored object path", ) - def list_objects_func(path: str) -> List[str]: + def list_objects_func(path: str) -> list[str]: # this mock simulates a recording whose blob storage has been deleted by TTL # but which has been stored in LTS blob storage if path == "an lts stored object path": diff --git a/posthog/settings/__init__.py b/posthog/settings/__init__.py index 455b7e8dc34..faf2e466764 100644 --- a/posthog/settings/__init__.py +++ b/posthog/settings/__init__.py @@ -13,7 +13,6 @@ https://docs.djangoproject.com/en/2.2/ref/settings/ # isort: skip_file import os -from typing import Dict, List # :TRICKY: Imported before anything else to support overloads from posthog.settings.overrides import * @@ -68,7 +67,7 @@ else: DISABLE_MMDB = get_from_env( "DISABLE_MMDB", TEST, type_cast=str_to_bool ) # plugin server setting disabling GeoIP feature -PLUGINS_PREINSTALLED_URLS: List[str] = ( +PLUGINS_PREINSTALLED_URLS: list[str] = ( os.getenv( "PLUGINS_PREINSTALLED_URLS", "https://www.npmjs.com/package/@posthog/geoip-plugin", @@ -100,7 +99,7 @@ PERSON_ON_EVENTS_V2_OVERRIDE = get_from_env("PERSON_ON_EVENTS_V2_OVERRIDE", opti # Whether to use insight queries converted to HogQL. HOGQL_INSIGHTS_OVERRIDE = get_from_env("HOGQL_INSIGHTS_OVERRIDE", optional=True, type_cast=str_to_bool) -HOOK_EVENTS: Dict[str, str] = {} +HOOK_EVENTS: dict[str, str] = {} # Support creating multiple organizations in a single instance. Requires a premium license.
MULTI_ORG_ENABLED = get_from_env("MULTI_ORG_ENABLED", False, type_cast=str_to_bool) diff --git a/posthog/settings/data_stores.py b/posthog/settings/data_stores.py index f3402a74811..d175f04f07c 100644 --- a/posthog/settings/data_stores.py +++ b/posthog/settings/data_stores.py @@ -1,6 +1,5 @@ import json import os -from typing import List from urllib.parse import urlparse import dj_database_url @@ -173,7 +172,7 @@ READONLY_CLICKHOUSE_USER = os.getenv("READONLY_CLICKHOUSE_USER", None) READONLY_CLICKHOUSE_PASSWORD = os.getenv("READONLY_CLICKHOUSE_PASSWORD", None) -def _parse_kafka_hosts(hosts_string: str) -> List[str]: +def _parse_kafka_hosts(hosts_string: str) -> list[str]: hosts = [] for host in hosts_string.split(","): if "://" in host: diff --git a/posthog/settings/logs.py b/posthog/settings/logs.py index 8f41f3e6c21..f8f21294e37 100644 --- a/posthog/settings/logs.py +++ b/posthog/settings/logs.py @@ -1,7 +1,6 @@ import logging import os import threading -from typing import List import structlog @@ -27,7 +26,7 @@ def add_pid_and_tid( # To enable standard library logs to be formatted via structlog, we add this # `foreign_pre_chain` to both formatters. -foreign_pre_chain: List[structlog.types.Processor] = [ +foreign_pre_chain: list[structlog.types.Processor] = [ structlog.contextvars.merge_contextvars, structlog.processors.TimeStamper(fmt="iso"), structlog.stdlib.add_logger_name, diff --git a/posthog/settings/session_replay.py b/posthog/settings/session_replay.py index 4cd8a429aa0..429f3207dcc 100644 --- a/posthog/settings/session_replay.py +++ b/posthog/settings/session_replay.py @@ -1,5 +1,3 @@ -from typing import List - from posthog.settings import get_from_env, get_list from posthog.utils import str_to_bool @@ -18,7 +16,7 @@ REALTIME_SNAPSHOTS_FROM_REDIS_ATTEMPT_TIMEOUT_SECONDS = get_from_env( "REALTIME_SNAPSHOTS_FROM_REDIS_ATTEMPT_TIMEOUT_SECONDS", 0.2, type_cast=float ) -REPLAY_EMBEDDINGS_ALLOWED_TEAMS: List[str] = get_list(get_from_env("REPLAY_EMBEDDINGS_ALLOWED_TEAM", "", type_cast=str)) +REPLAY_EMBEDDINGS_ALLOWED_TEAMS: list[str] = get_list(get_from_env("REPLAY_EMBEDDINGS_ALLOWED_TEAM", "", type_cast=str)) REPLAY_EMBEDDINGS_BATCH_SIZE = get_from_env("REPLAY_EMBEDDINGS_BATCH_SIZE", 10, type_cast=int) REPLAY_EMBEDDINGS_MIN_DURATION_SECONDS = get_from_env("REPLAY_EMBEDDINGS_MIN_DURATION_SECONDS", 30, type_cast=int) REPLAY_EMBEDDINGS_CALCULATION_CELERY_INTERVAL_SECONDS = get_from_env( diff --git a/posthog/settings/temporal.py b/posthog/settings/temporal.py index ccb5fbfb0db..ce0e72172ea 100644 --- a/posthog/settings/temporal.py +++ b/posthog/settings/temporal.py @@ -1,5 +1,4 @@ import os -from typing import Dict from posthog.settings.utils import get_list, get_from_env @@ -24,6 +23,6 @@ UNCONSTRAINED_TIMESTAMP_TEAM_IDS = get_list(os.getenv("UNCONSTRAINED_TIMESTAMP_T CLICKHOUSE_MAX_EXECUTION_TIME = get_from_env("CLICKHOUSE_MAX_EXECUTION_TIME", 0, type_cast=int) CLICKHOUSE_MAX_BLOCK_SIZE_DEFAULT = get_from_env("CLICKHOUSE_MAX_BLOCK_SIZE_DEFAULT", 10000, type_cast=int) # Comma separated list of overrides in the format "team_id:block_size" -CLICKHOUSE_MAX_BLOCK_SIZE_OVERRIDES: Dict[int, int] = dict( +CLICKHOUSE_MAX_BLOCK_SIZE_OVERRIDES: dict[int, int] = dict( [map(int, o.split(":")) for o in os.getenv("CLICKHOUSE_MAX_BLOCK_SIZE_OVERRIDES", "").split(",") if o] # type: ignore ) diff --git a/posthog/settings/utils.py b/posthog/settings/utils.py index 6dd22dbf97c..eead270c7bd 100644 --- a/posthog/settings/utils.py +++ b/posthog/settings/utils.py @@ -1,5 +1,6 @@ import os -from typing import 
Any, Callable, List, Optional, Set +from typing import Any, Optional +from collections.abc import Callable from django.core.exceptions import ImproperlyConfigured @@ -28,13 +29,13 @@ def get_from_env( return value -def get_list(text: str) -> List[str]: +def get_list(text: str) -> list[str]: if not text: return [] return [item.strip() for item in text.split(",")] -def get_set(text: str) -> Set[str]: +def get_set(text: str) -> set[str]: if not text: return set() return {item.strip() for item in text.split(",")} diff --git a/posthog/settings/web.py b/posthog/settings/web.py index ee6961de70e..b80c1baab02 100644 --- a/posthog/settings/web.py +++ b/posthog/settings/web.py @@ -1,7 +1,6 @@ # Web app specific settings/middleware/apps setup import os from datetime import timedelta -from typing import List from corsheaders.defaults import default_headers @@ -160,7 +159,7 @@ SOCIAL_AUTH_JSONFIELD_ENABLED = True SOCIAL_AUTH_USER_MODEL = "posthog.User" SOCIAL_AUTH_REDIRECT_IS_HTTPS = get_from_env("SOCIAL_AUTH_REDIRECT_IS_HTTPS", not DEBUG, type_cast=str_to_bool) -AUTHENTICATION_BACKENDS: List[str] = [ +AUTHENTICATION_BACKENDS: list[str] = [ "axes.backends.AxesBackend", "social_core.backends.github.GithubOAuth2", "social_core.backends.gitlab.GitLabOAuth2", diff --git a/posthog/storage/object_storage.py b/posthog/storage/object_storage.py index a1ff639b1c2..147b02436fa 100644 --- a/posthog/storage/object_storage.py +++ b/posthog/storage/object_storage.py @@ -1,5 +1,5 @@ import abc -from typing import Optional, Union, List, Dict +from typing import Optional, Union import structlog from boto3 import client @@ -26,7 +26,7 @@ class ObjectStorageClient(metaclass=abc.ABCMeta): pass @abc.abstractmethod - def list_objects(self, bucket: str, prefix: str) -> Optional[List[str]]: + def list_objects(self, bucket: str, prefix: str) -> Optional[list[str]]: pass @abc.abstractmethod @@ -38,11 +38,11 @@ class ObjectStorageClient(metaclass=abc.ABCMeta): pass @abc.abstractmethod - def tag(self, bucket: str, key: str, tags: Dict[str, str]) -> None: + def tag(self, bucket: str, key: str, tags: dict[str, str]) -> None: pass @abc.abstractmethod - def write(self, bucket: str, key: str, content: Union[str, bytes], extras: Dict | None) -> None: + def write(self, bucket: str, key: str, content: Union[str, bytes], extras: dict | None) -> None: pass @abc.abstractmethod @@ -60,7 +60,7 @@ class UnavailableStorage(ObjectStorageClient): def get_presigned_url(self, bucket: str, file_key: str, expiration: int = 3600) -> Optional[str]: pass - def list_objects(self, bucket: str, prefix: str) -> Optional[List[str]]: + def list_objects(self, bucket: str, prefix: str) -> Optional[list[str]]: pass def read(self, bucket: str, key: str) -> Optional[str]: @@ -69,10 +69,10 @@ class UnavailableStorage(ObjectStorageClient): def read_bytes(self, bucket: str, key: str) -> Optional[bytes]: pass - def tag(self, bucket: str, key: str, tags: Dict[str, str]) -> None: + def tag(self, bucket: str, key: str, tags: dict[str, str]) -> None: pass - def write(self, bucket: str, key: str, content: Union[str, bytes], extras: Dict | None) -> None: + def write(self, bucket: str, key: str, content: Union[str, bytes], extras: dict | None) -> None: pass def copy_objects(self, bucket: str, source_prefix: str, target_prefix: str) -> int | None: @@ -103,7 +103,7 @@ class ObjectStorage(ObjectStorageClient): capture_exception(e) return None - def list_objects(self, bucket: str, prefix: str) -> Optional[List[str]]: + def list_objects(self, bucket: str, prefix: str) -> 
Optional[list[str]]: try: s3_response = self.aws_client.list_objects_v2(Bucket=bucket, Prefix=prefix) if s3_response.get("Contents"): @@ -143,7 +143,7 @@ class ObjectStorage(ObjectStorageClient): capture_exception(e) raise ObjectStorageError("read failed") from e - def tag(self, bucket: str, key: str, tags: Dict[str, str]) -> None: + def tag(self, bucket: str, key: str, tags: dict[str, str]) -> None: try: self.aws_client.put_object_tagging( Bucket=bucket, @@ -155,7 +155,7 @@ class ObjectStorage(ObjectStorageClient): capture_exception(e) raise ObjectStorageError("tag failed") from e - def write(self, bucket: str, key: str, content: Union[str, bytes], extras: Dict | None) -> None: + def write(self, bucket: str, key: str, content: Union[str, bytes], extras: dict | None) -> None: s3_response = {} try: s3_response = self.aws_client.put_object(Bucket=bucket, Body=content, Key=key, **(extras or {})) @@ -218,7 +218,7 @@ def object_storage_client() -> ObjectStorageClient: return _client -def write(file_name: str, content: Union[str, bytes], extras: Dict | None = None) -> None: +def write(file_name: str, content: Union[str, bytes], extras: dict | None = None) -> None: return object_storage_client().write( bucket=settings.OBJECT_STORAGE_BUCKET, key=file_name, @@ -227,7 +227,7 @@ def write(file_name: str, content: Union[str, bytes], extras: Dict | None = None ) -def tag(file_name: str, tags: Dict[str, str]) -> None: +def tag(file_name: str, tags: dict[str, str]) -> None: return object_storage_client().tag(bucket=settings.OBJECT_STORAGE_BUCKET, key=file_name, tags=tags) @@ -239,7 +239,7 @@ def read_bytes(file_name: str) -> Optional[bytes]: return object_storage_client().read_bytes(bucket=settings.OBJECT_STORAGE_BUCKET, key=file_name) -def list_objects(prefix: str) -> Optional[List[str]]: +def list_objects(prefix: str) -> Optional[list[str]]: return object_storage_client().list_objects(bucket=settings.OBJECT_STORAGE_BUCKET, prefix=prefix) diff --git a/posthog/storage/test/test_object_storage.py b/posthog/storage/test/test_object_storage.py index f24114911ba..3737ca155ee 100644 --- a/posthog/storage/test/test_object_storage.py +++ b/posthog/storage/test/test_object_storage.py @@ -57,7 +57,7 @@ class TestStorage(APIBaseTest): chunk_id = uuid.uuid4() name = f"{session_id}/{0}-{chunk_id}" file_name = f"{TEST_BUCKET}/test_write_and_read_works_with_known_content/{name}" - write(file_name, "my content".encode("utf-8")) + write(file_name, b"my content") self.assertEqual(read(file_name), "my content") def test_can_generate_presigned_url_for_existing_file(self) -> None: @@ -66,7 +66,7 @@ class TestStorage(APIBaseTest): chunk_id = uuid.uuid4() name = f"{session_id}/{0}-{chunk_id}" file_name = f"{TEST_BUCKET}/test_can_generate_presigned_url_for_existing_file/{name}" - write(file_name, "my content".encode("utf-8")) + write(file_name, b"my content") presigned_url = get_presigned_url(file_name) assert presigned_url is not None @@ -93,7 +93,7 @@ class TestStorage(APIBaseTest): for file in ["a", "b", "c"]: file_name = f"{TEST_BUCKET}/{shared_prefix}/{file}" - write(file_name, "my content".encode("utf-8")) + write(file_name, b"my content") listing = list_objects(prefix=f"{TEST_BUCKET}/{shared_prefix}") @@ -117,7 +117,7 @@ class TestStorage(APIBaseTest): for file in ["a", "b", "c"]: file_name = f"{TEST_BUCKET}/{shared_prefix}/{file}" - write(file_name, "my content".encode("utf-8")) + write(file_name, b"my content") copied_count = copy_objects( source_prefix=f"{TEST_BUCKET}/{shared_prefix}", @@ -142,7 +142,7 @@ class 
TestStorage(APIBaseTest): for file in ["a", "b", "c"]: file_name = f"{TEST_BUCKET}/{shared_prefix}/{file}" - write(file_name, "my content".encode("utf-8")) + write(file_name, b"my content") copied_count = copy_objects( source_prefix=f"nothing_here", diff --git a/posthog/tasks/calculate_cohort.py b/posthog/tasks/calculate_cohort.py index 7dba512d6c8..35ccc8fe9ec 100644 --- a/posthog/tasks/calculate_cohort.py +++ b/posthog/tasks/calculate_cohort.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict, List, Optional +from typing import Any, Optional import structlog from celery import shared_task @@ -53,7 +53,7 @@ def calculate_cohort_ch(cohort_id: int, pending_version: int, initiating_user_id @shared_task(ignore_result=True, max_retries=1) -def calculate_cohort_from_list(cohort_id: int, items: List[str]) -> None: +def calculate_cohort_from_list(cohort_id: int, items: list[str]) -> None: start_time = time.time() cohort = Cohort.objects.get(pk=cohort_id) @@ -62,7 +62,7 @@ def calculate_cohort_from_list(cohort_id: int, items: List[str]) -> None: @shared_task(ignore_result=True, max_retries=1) -def insert_cohort_from_insight_filter(cohort_id: int, filter_data: Dict[str, Any]) -> None: +def insert_cohort_from_insight_filter(cohort_id: int, filter_data: dict[str, Any]) -> None: from posthog.api.cohort import ( insert_cohort_actors_into_ch, insert_cohort_people_into_pg, diff --git a/posthog/tasks/email.py b/posthog/tasks/email.py index d06d15ee12a..2d7198dc2d8 100644 --- a/posthog/tasks/email.py +++ b/posthog/tasks/email.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime -from typing import List, Optional +from typing import Optional import posthoganalytics import structlog @@ -281,7 +281,7 @@ def send_async_migration_errored_email(migration_key: str, time: str, error: str send_message_to_all_staff_users(message) -def get_users_for_orgs_with_no_ingested_events(org_created_from: datetime, org_created_to: datetime) -> List[User]: +def get_users_for_orgs_with_no_ingested_events(org_created_from: datetime, org_created_to: datetime) -> list[User]: # Get all users for organization that haven't ingested any events users = [] recently_created_organizations = Organization.objects.filter( diff --git a/posthog/tasks/exports/csv_exporter.py b/posthog/tasks/exports/csv_exporter.py index 22cf1004ac0..489bf64e740 100644 --- a/posthog/tasks/exports/csv_exporter.py +++ b/posthog/tasks/exports/csv_exporter.py @@ -1,6 +1,7 @@ import datetime import io -from typing import Any, Dict, List, Optional, Tuple, Generator +from typing import Any, Optional +from collections.abc import Generator from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse import requests @@ -53,14 +54,14 @@ logger = structlog.get_logger(__name__) # 5. 
We save the final blob output and update the ExportedAsset -def add_query_params(url: str, params: Dict[str, str]) -> str: +def add_query_params(url: str, params: dict[str, str]) -> str: """ Uses parse_qsl because parse_qs turns all values into lists but doesn't unbox them when re-encoded """ parsed = urlparse(url) query_params = parse_qsl(parsed.query, keep_blank_values=True) - update_params: List[Tuple[str, Any]] = [] + update_params: list[tuple[str, Any]] = [] for param, value in query_params: if param in params: update_params.append((param, params.pop(param))) @@ -265,7 +266,7 @@ def get_from_hogql_query(exported_asset: ExportedAsset, limit: int, resource: di def _export_to_dict(exported_asset: ExportedAsset, limit: int) -> Any: resource = exported_asset.export_context - columns: List[str] = resource.get("columns", []) + columns: list[str] = resource.get("columns", []) returned_rows: Generator[Any, None, None] if resource.get("source"): @@ -310,7 +311,7 @@ def _export_to_excel(exported_asset: ExportedAsset, limit: int) -> None: for row_num, row_data in enumerate(renderer.tablize(all_csv_rows, header=render_context.get("header"))): for col_num, value in enumerate(row_data): - if value is not None and not isinstance(value, (str, int, float, bool)): + if value is not None and not isinstance(value, str | int | float | bool): value = str(value) worksheet.cell(row=row_num + 1, column=col_num + 1, value=value) diff --git a/posthog/tasks/exports/ordered_csv_renderer.py b/posthog/tasks/exports/ordered_csv_renderer.py index d183ee874b2..5b70e9bed91 100644 --- a/posthog/tasks/exports/ordered_csv_renderer.py +++ b/posthog/tasks/exports/ordered_csv_renderer.py @@ -1,6 +1,7 @@ import itertools from collections import OrderedDict -from typing import Any, Dict, Generator +from typing import Any +from collections.abc import Generator from more_itertools import unique_everseen from rest_framework_csv.renderers import CSVRenderer @@ -28,7 +29,7 @@ class OrderedCsvRenderer( # Get the set of all unique headers, and sort them. 
unique_fields = list(unique_everseen(itertools.chain(*(item.keys() for item in data)))) - ordered_fields: Dict[str, Any] = OrderedDict() + ordered_fields: dict[str, Any] = OrderedDict() for item in unique_fields: field = item.split(".") field = field[0] diff --git a/posthog/tasks/exports/test/test_csv_exporter.py b/posthog/tasks/exports/test/test_csv_exporter.py index 87d731dd6a1..d1c03ea5a3e 100644 --- a/posthog/tasks/exports/test/test_csv_exporter.py +++ b/posthog/tasks/exports/test/test_csv_exporter.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from unittest import mock from unittest.mock import MagicMock, Mock, patch, ANY @@ -97,7 +97,7 @@ class TestCSVExporter(APIBaseTest): patched_request.return_value = mock_response yield patched_request - def _create_asset(self, extra_context: Optional[Dict] = None) -> ExportedAsset: + def _create_asset(self, extra_context: Optional[dict] = None) -> ExportedAsset: if extra_context is None: extra_context = {} @@ -588,7 +588,7 @@ class TestCSVExporter(APIBaseTest): self.assertEqual(lines[0], "error") self.assertEqual(lines[1], "No data available or unable to format for export.") - def _split_to_dict(self, url: str) -> Dict[str, Any]: + def _split_to_dict(self, url: str) -> dict[str, Any]: first_split_parts = url.split("?") assert len(first_split_parts) == 2 return {bits[0]: bits[1] for bits in [param.split("=") for param in first_split_parts[1].split("&")]} diff --git a/posthog/tasks/sync_all_organization_available_features.py b/posthog/tasks/sync_all_organization_available_features.py index 87e425fa5ca..ec16a0e0a5a 100644 --- a/posthog/tasks/sync_all_organization_available_features.py +++ b/posthog/tasks/sync_all_organization_available_features.py @@ -1,4 +1,5 @@ -from typing import Sequence, cast +from typing import cast +from collections.abc import Sequence from posthog.models.organization import Organization diff --git a/posthog/tasks/test/test_calculate_cohort.py b/posthog/tasks/test/test_calculate_cohort.py index 0c81076c8fa..ff2c534a910 100644 --- a/posthog/tasks/test/test_calculate_cohort.py +++ b/posthog/tasks/test/test_calculate_cohort.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable from unittest.mock import MagicMock, patch from freezegun import freeze_time diff --git a/posthog/tasks/test/test_email.py b/posthog/tasks/test/test_email.py index 447d0d442bf..b89127b48b7 100644 --- a/posthog/tasks/test/test_email.py +++ b/posthog/tasks/test/test_email.py @@ -1,5 +1,4 @@ import datetime as dt -from typing import Tuple from unittest.mock import MagicMock, patch import pytest @@ -28,7 +27,7 @@ from posthog.tasks.test.utils_email_tests import mock_email_messages from posthog.test.base import APIBaseTest, ClickhouseTestMixin -def create_org_team_and_user(creation_date: str, email: str, ingested_event: bool = False) -> Tuple[Organization, User]: +def create_org_team_and_user(creation_date: str, email: str, ingested_event: bool = False) -> tuple[Organization, User]: with freeze_time(creation_date): org = Organization.objects.create(name="too_late_org") Team.objects.create(organization=org, name="Default Project", ingested_event=ingested_event) diff --git a/posthog/tasks/test/test_usage_report.py b/posthog/tasks/test/test_usage_report.py index d977f27560b..286e1a623f8 100644 --- a/posthog/tasks/test/test_usage_report.py +++ b/posthog/tasks/test/test_usage_report.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, 
Dict, List +from typing import Any from unittest.mock import ANY, MagicMock, Mock, call, patch from uuid import uuid4 @@ -324,14 +324,14 @@ class UsageReport(APIBaseTest, ClickhouseTestMixin, ClickhouseDestroyTablesMixin flush_persons_and_events() - def _select_report_by_org_id(self, org_id: str, reports: List[Dict]) -> Dict: + def _select_report_by_org_id(self, org_id: str, reports: list[dict]) -> dict: return next(report for report in reports if report["organization_id"] == org_id) def _create_plugin(self, name: str, enabled: bool) -> None: plugin = Plugin.objects.create(organization_id=self.team.organization.pk, name=name) PluginConfig.objects.create(plugin=plugin, enabled=enabled, order=1) - def _test_usage_report(self) -> List[dict]: + def _test_usage_report(self) -> list[dict]: with self.settings(SITE_URL="http://test.posthog.com"): self._create_sample_usage_data() self._create_plugin("Installed but not enabled", False) diff --git a/posthog/tasks/test/utils_email_tests.py b/posthog/tasks/test/utils_email_tests.py index d9be8cdd3bc..ccb998b3dc3 100644 --- a/posthog/tasks/test/utils_email_tests.py +++ b/posthog/tasks/test/utils_email_tests.py @@ -1,12 +1,12 @@ import os -from typing import Any, List +from typing import Any from unittest.mock import MagicMock from posthog.email import EmailMessage from posthog.utils import get_absolute_path -def mock_email_messages(MockEmailMessage: MagicMock, path: str = "tasks/test/__emails__/") -> List[Any]: +def mock_email_messages(MockEmailMessage: MagicMock, path: str = "tasks/test/__emails__/") -> list[Any]: """ Takes a mocked EmailMessage class and returns a list of all subsequently created EmailMessage instances The "send" method is spied on to write the generated email to a file diff --git a/posthog/tasks/usage_report.py b/posthog/tasks/usage_report.py index 958601d1ec3..727a93f1e0c 100644 --- a/posthog/tasks/usage_report.py +++ b/posthog/tasks/usage_report.py @@ -4,16 +4,13 @@ from collections import Counter from datetime import datetime from typing import ( Any, - Dict, - List, Literal, Optional, - Sequence, - Tuple, TypedDict, Union, cast, ) +from collections.abc import Sequence import requests import structlog @@ -52,8 +49,15 @@ from posthog.utils import ( logger = structlog.get_logger(__name__) -Period = TypedDict("Period", {"start_inclusive": str, "end_inclusive": str}) -TableSizes = TypedDict("TableSizes", {"posthog_event": int, "posthog_sessionrecordingevent": int}) + +class Period(TypedDict): + start_inclusive: str + end_inclusive: str + + +class TableSizes(TypedDict): + posthog_event: int + posthog_sessionrecordingevent: int CH_BILLING_SETTINGS = { @@ -133,13 +137,13 @@ class InstanceMetadata: product: str helm: Optional[dict] clickhouse_version: Optional[str] - users_who_logged_in: Optional[List[Dict[str, Union[str, int]]]] + users_who_logged_in: Optional[list[dict[str, Union[str, int]]]] users_who_logged_in_count: Optional[int] - users_who_signed_up: Optional[List[Dict[str, Union[str, int]]]] + users_who_signed_up: Optional[list[dict[str, Union[str, int]]]] users_who_signed_up_count: Optional[int] table_sizes: Optional[TableSizes] - plugins_installed: Optional[Dict] - plugins_enabled: Optional[Dict] + plugins_installed: Optional[dict] + plugins_enabled: Optional[dict] instance_tag: str @@ -151,7 +155,7 @@ class OrgReport(UsageReportCounters): organization_created_at: str organization_user_count: int team_count: int - teams: Dict[str, UsageReportCounters] + teams: dict[str, UsageReportCounters] @dataclasses.dataclass @@ -163,7 
+167,7 @@ def fetch_table_size(table_name: str) -> int: return fetch_sql("SELECT pg_total_relation_size(%s) as size", (table_name,))[0].size -def fetch_sql(sql_: str, params: Tuple[Any, ...]) -> List[Any]: +def fetch_sql(sql_: str, params: tuple[Any, ...]) -> list[Any]: with connection.cursor() as cursor: cursor.execute(sql.SQL(sql_), params) return namedtuplefetchall(cursor) @@ -178,7 +182,7 @@ def get_product_name(realm: str, has_license: bool) -> str: return "unknown" -def get_instance_metadata(period: Tuple[datetime, datetime]) -> InstanceMetadata: +def get_instance_metadata(period: tuple[datetime, datetime]) -> InstanceMetadata: has_license = False if settings.EE_AVAILABLE: @@ -288,7 +292,7 @@ def get_org_owner_or_first_user(organization_id: str) -> Optional[User]: @shared_task(**USAGE_REPORT_TASK_KWARGS, max_retries=3) -def send_report_to_billing_service(org_id: str, report: Dict[str, Any]) -> None: +def send_report_to_billing_service(org_id: str, report: dict[str, Any]) -> None: if not settings.EE_AVAILABLE: return @@ -340,7 +344,7 @@ def capture_event( pha_client: Client, name: str, organization_id: str, - properties: Dict[str, Any], + properties: dict[str, Any], timestamp: Optional[Union[datetime, str]] = None, ) -> None: if timestamp and isinstance(timestamp, str): @@ -373,7 +377,7 @@ def capture_event( @timed_log() @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) -def get_teams_with_event_count_lifetime() -> List[Tuple[int, int]]: +def get_teams_with_event_count_lifetime() -> list[tuple[int, int]]: result = sync_execute( """ SELECT team_id, count(1) as count @@ -390,7 +394,7 @@ def get_teams_with_event_count_lifetime() -> List[Tuple[int, int]]: @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) def get_teams_with_billable_event_count_in_period( begin: datetime, end: datetime, count_distinct: bool = False -) -> List[Tuple[int, int]]: +) -> list[tuple[int, int]]: # count only unique events # Duplicate events will be eventually removed by ClickHouse and likely came from our library or pipeline. # We shouldn't bill for these. However counting unique events is more expensive, and likely to fail on longer time ranges. @@ -420,7 +424,7 @@ def get_teams_with_billable_event_count_in_period( @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) def get_teams_with_billable_enhanced_persons_event_count_in_period( begin: datetime, end: datetime, count_distinct: bool = False -) -> List[Tuple[int, int]]: +) -> list[tuple[int, int]]: # count only unique events # Duplicate events will be eventually removed by ClickHouse and likely came from our library or pipeline. # We shouldn't bill for these. However counting unique events is more expensive, and likely to fail on longer time ranges. 
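Note: one non-trivial rewrite in the usage_report.py hunks above is the switch from the functional TypedDict form to the class form. Both declare the same type at runtime; the class form is the one type checkers and editors handle best, and it leaves room for per-field comments. A minimal sketch follows; the example values and the PeriodFunctional name are illustrative, only Period mirrors the patch.

from typing import TypedDict

# Functional form (pre-patch); mainly needed when keys aren't valid identifiers:
PeriodFunctional = TypedDict("PeriodFunctional", {"start_inclusive": str, "end_inclusive": str})

# Class form (post-patch), equivalent at runtime:
class Period(TypedDict):
    start_inclusive: str  # period bounds carried as strings
    end_inclusive: str

report_period: Period = {
    "start_inclusive": "2024-04-01",
    "end_inclusive": "2024-04-30",
}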
@@ -448,7 +452,7 @@ def get_teams_with_billable_enhanced_persons_event_count_in_period( @timed_log() @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) -def get_teams_with_event_count_with_groups_in_period(begin: datetime, end: datetime) -> List[Tuple[int, int]]: +def get_teams_with_event_count_with_groups_in_period(begin: datetime, end: datetime) -> list[tuple[int, int]]: result = sync_execute( """ SELECT team_id, count(1) as count @@ -466,7 +470,7 @@ def get_teams_with_event_count_with_groups_in_period(begin: datetime, end: datet @timed_log() @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) -def get_teams_with_event_count_by_lib(begin: datetime, end: datetime) -> List[Tuple[int, str, int]]: +def get_teams_with_event_count_by_lib(begin: datetime, end: datetime) -> list[tuple[int, str, int]]: results = sync_execute( """ SELECT team_id, JSONExtractString(properties, '$lib') as lib, COUNT(1) as count @@ -483,7 +487,7 @@ def get_teams_with_event_count_by_lib(begin: datetime, end: datetime) -> List[Tu @timed_log() @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) -def get_teams_with_event_count_by_name(begin: datetime, end: datetime) -> List[Tuple[int, str, int]]: +def get_teams_with_event_count_by_name(begin: datetime, end: datetime) -> list[tuple[int, str, int]]: results = sync_execute( """ SELECT team_id, event, COUNT(1) as count @@ -500,7 +504,7 @@ def get_teams_with_event_count_by_name(begin: datetime, end: datetime) -> List[T @timed_log() @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) -def get_teams_with_recording_count_in_period(begin: datetime, end: datetime) -> List[Tuple[int, int]]: +def get_teams_with_recording_count_in_period(begin: datetime, end: datetime) -> list[tuple[int, int]]: previous_begin = begin - (end - begin) result = sync_execute( @@ -531,7 +535,7 @@ def get_teams_with_recording_count_in_period(begin: datetime, end: datetime) -> @timed_log() @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) -def get_teams_with_recording_count_total() -> List[Tuple[int, int]]: +def get_teams_with_recording_count_total() -> list[tuple[int, int]]: result = sync_execute( """ SELECT team_id, count(distinct session_id) as count @@ -549,10 +553,10 @@ def get_teams_with_recording_count_total() -> List[Tuple[int, int]]: def get_teams_with_hogql_metric( begin: datetime, end: datetime, - query_types: List[str], + query_types: list[str], access_method: str = "", metric: Literal["read_bytes", "read_rows", "query_duration_ms"] = "read_bytes", -) -> List[Tuple[int, int]]: +) -> list[tuple[int, int]]: if metric not in ["read_bytes", "read_rows", "query_duration_ms"]: # :TRICKY: Inlined into the query below. 
raise ValueError(f"Invalid metric {metric}") @@ -586,7 +590,7 @@ def get_teams_with_hogql_metric( @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) def get_teams_with_feature_flag_requests_count_in_period( begin: datetime, end: datetime, request_type: FlagRequestType -) -> List[Tuple[int, int]]: +) -> list[tuple[int, int]]: # depending on the region, events are stored in different teams team_to_query = 1 if get_instance_region() == "EU" else 2 validity_token = settings.DECIDE_BILLING_ANALYTICS_TOKEN @@ -620,7 +624,7 @@ def get_teams_with_feature_flag_requests_count_in_period( def get_teams_with_survey_responses_count_in_period( begin: datetime, end: datetime, -) -> List[Tuple[int, int]]: +) -> list[tuple[int, int]]: results = sync_execute( """ SELECT team_id, COUNT() as count @@ -638,7 +642,7 @@ def get_teams_with_survey_responses_count_in_period( @timed_log() @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) -def get_teams_with_rows_synced_in_period(begin: datetime, end: datetime) -> List[Tuple[int, int]]: +def get_teams_with_rows_synced_in_period(begin: datetime, end: datetime) -> list[tuple[int, int]]: team_to_query = 1 if get_instance_region() == "EU" else 2 # dedup by job id in case there were duplicates sent @@ -668,7 +672,7 @@ def capture_report( capture_event_name: str, org_id: str, - full_report_dict: Dict[str, Any], + full_report_dict: dict[str, Any], at_date: Optional[datetime] = None, ) -> None: pha_client = Client("sTMFPsFhdP1Ssg") @@ -695,7 +699,7 @@ def has_non_zero_usage(report: FullUsageReport) -> bool: ) -def convert_team_usage_rows_to_dict(rows: List[Union[dict, Tuple[int, int]]]) -> Dict[int, int]: +def convert_team_usage_rows_to_dict(rows: list[Union[dict, tuple[int, int]]]) -> dict[int, int]: team_id_map = {} for row in rows: if isinstance(row, dict) and "team_id" in row: @@ -708,7 +712,7 @@ def convert_team_usage_rows_to_dict(rows: List[Union[dict, Tuple[int, int]]]) -> return team_id_map -def _get_all_usage_data(period_start: datetime, period_end: datetime) -> Dict[str, Any]: +def _get_all_usage_data(period_start: datetime, period_end: datetime) -> dict[str, Any]: """ Gets all usage data for the specified period. Clickhouse is good at counting things so we count across all teams rather than doing it one by one @@ -867,7 +871,7 @@ def _get_all_usage_data(period_start: datetime, period_end: datetime) -> Dict[st } -def _get_all_usage_data_as_team_rows(period_start: datetime, period_end: datetime) -> Dict[str, Any]: +def _get_all_usage_data_as_team_rows(period_start: datetime, period_end: datetime) -> dict[str, Any]: """ Gets all usage data for the specified period as a map of team_id -> value. This makes it faster to access the data than looping over all_data to find what we want. 
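Note: a few earlier hunks (the isinstance call in csv_exporter.py and the dict | list annotation in session_recording_helpers.py) rely on PEP 604 unions, which exist at runtime from Python 3.10: `X | Y` on types produces a types.UnionType usable both in annotations and as an isinstance argument. A minimal sketch under that assumption; the function names are illustrative, not from the patch.

import json

def normalize_cell(value: object) -> object:
    # On 3.10+ a union can be passed straight to isinstance, replacing the
    # old tuple form isinstance(value, (str, int, float, bool)).
    if value is not None and not isinstance(value, str | int | float | bool):
        value = str(value)  # anything non-primitive is stringified, as the exporter does
    return value

def byte_size(x: dict | list) -> int:
    # The same | syntax works in annotations on builtin types at runtime.
    return len(json.dumps(x))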
@@ -887,7 +891,7 @@ def _get_teams_for_usage_reports() -> Sequence[Team]: ) -def _get_team_report(all_data: Dict[str, Any], team: Team) -> UsageReportCounters: +def _get_team_report(all_data: dict[str, Any], team: Team) -> UsageReportCounters: decide_requests_count_in_month = all_data["teams_with_decide_requests_count_in_month"].get(team.id, 0) decide_requests_count_in_period = all_data["teams_with_decide_requests_count_in_period"].get(team.id, 0) local_evaluation_requests_count_in_period = all_data["teams_with_local_evaluation_requests_count_in_period"].get( @@ -942,7 +946,7 @@ def _get_team_report(all_data: Dict[str, Any], team: Team) -> UsageReportCounter def _add_team_report_to_org_reports( - org_reports: Dict[str, OrgReport], + org_reports: dict[str, OrgReport], team: Team, team_report: UsageReportCounters, period_start: datetime, @@ -975,12 +979,12 @@ def _add_team_report_to_org_reports( ) -def _get_all_org_reports(period_start: datetime, period_end: datetime) -> Dict[str, OrgReport]: +def _get_all_org_reports(period_start: datetime, period_end: datetime) -> dict[str, OrgReport]: all_data = _get_all_usage_data_as_team_rows(period_start, period_end) teams = _get_teams_for_usage_reports() - org_reports: Dict[str, OrgReport] = {} + org_reports: dict[str, OrgReport] = {} print("Generating reports for teams...") # noqa T201 time_now = datetime.now() @@ -1000,7 +1004,7 @@ def _get_full_org_usage_report(org_report: OrgReport, instance_metadata: Instanc ) -def _get_full_org_usage_report_as_dict(full_report: FullUsageReport) -> Dict[str, Any]: +def _get_full_org_usage_report_as_dict(full_report: FullUsageReport) -> dict[str, Any]: return dataclasses.asdict(full_report) diff --git a/posthog/tasks/verify_persons_data_in_sync.py b/posthog/tasks/verify_persons_data_in_sync.py index 02a53b0176c..5ed2a3ec074 100644 --- a/posthog/tasks/verify_persons_data_in_sync.py +++ b/posthog/tasks/verify_persons_data_in_sync.py @@ -1,7 +1,7 @@ import json from collections import Counter, defaultdict from datetime import timedelta -from typing import Any, Dict, List +from typing import Any import structlog from celery import shared_task @@ -80,7 +80,7 @@ def verify_persons_data_in_sync( return results -def _team_integrity_statistics(person_data: List[Any]) -> Counter: +def _team_integrity_statistics(person_data: list[Any]) -> Counter: person_ids = [id for id, _, _ in person_data] person_uuids = [uuid for _, uuid, _ in person_data] team_ids = list({team_id for _, _, team_id in person_data}) @@ -159,8 +159,8 @@ def _emit_metrics(integrity_results: Counter) -> None: statsd.gauge(f"posthog_person_integrity_{key}", value) -def _index_by(collection: List[Any], key_fn: Any, flat: bool = True) -> Dict: - result: Dict = {} if flat else defaultdict(list) +def _index_by(collection: list[Any], key_fn: Any, flat: bool = True) -> dict: + result: dict = {} if flat else defaultdict(list) for item in collection: if flat: result[key_fn(item)] = item diff --git a/posthog/templatetags/posthog_assets.py b/posthog/templatetags/posthog_assets.py index 422bd687d9a..dd8a1c1bb19 100644 --- a/posthog/templatetags/posthog_assets.py +++ b/posthog/templatetags/posthog_assets.py @@ -1,5 +1,4 @@ import re -from typing import List from django.conf import settings from django.template import Library @@ -26,7 +25,7 @@ def absolute_asset_url(path: str) -> str: @register.simple_tag -def human_social_providers(providers: List[str]) -> str: +def human_social_providers(providers: list[str]) -> str: """ Returns a human-friendly name for a social login 
provider. Example: diff --git a/posthog/temporal/batch_exports/batch_exports.py b/posthog/temporal/batch_exports/batch_exports.py index c522a75bce2..68ea47c19c6 100644 --- a/posthog/temporal/batch_exports/batch_exports.py +++ b/posthog/temporal/batch_exports/batch_exports.py @@ -188,14 +188,14 @@ def iter_records( timestamp_predicates = "" if fields is None: - query_fields = ",".join((f"{field['expression']} AS {field['alias']}" for field in default_fields())) + query_fields = ",".join(f"{field['expression']} AS {field['alias']}" for field in default_fields()) else: if "_inserted_at" not in [field["alias"] for field in fields]: control_fields = [BatchExportField(expression="COALESCE(inserted_at, _timestamp)", alias="_inserted_at")] else: control_fields = [] - query_fields = ",".join((f"{field['expression']} AS {field['alias']}" for field in fields + control_fields)) + query_fields = ",".join(f"{field['expression']} AS {field['alias']}" for field in fields + control_fields) query = SELECT_QUERY_TEMPLATE.substitute( fields=query_fields, @@ -219,8 +219,7 @@ def iter_records( else: query_parameters = base_query_parameters - for record_batch in client.stream_query_as_arrow(query, query_parameters=query_parameters): - yield record_batch + yield from client.stream_query_as_arrow(query, query_parameters=query_parameters) def get_data_interval(interval: str, data_interval_end: str | None) -> tuple[dt.datetime, dt.datetime]: diff --git a/posthog/temporal/batch_exports/postgres_batch_export.py b/posthog/temporal/batch_exports/postgres_batch_export.py index 6ebede565bc..a4c1712a12e 100644 --- a/posthog/temporal/batch_exports/postgres_batch_export.py +++ b/posthog/temporal/batch_exports/postgres_batch_export.py @@ -98,7 +98,7 @@ async def copy_tsv_to_postgres( # TODO: Switch to binary encoding as CSV has a million edge cases. sql.SQL("COPY {table_name} ({fields}) FROM STDIN WITH (FORMAT CSV, DELIMITER '\t')").format( table_name=sql.Identifier(table_name), - fields=sql.SQL(",").join((sql.Identifier(column) for column in schema_columns)), + fields=sql.SQL(",").join(sql.Identifier(column) for column in schema_columns), ) ) as copy: while data := tsv_file.read(): diff --git a/posthog/temporal/batch_exports/snowflake_batch_export.py b/posthog/temporal/batch_exports/snowflake_batch_export.py index c769862af96..373312303be 100644 --- a/posthog/temporal/batch_exports/snowflake_batch_export.py +++ b/posthog/temporal/batch_exports/snowflake_batch_export.py @@ -283,7 +283,7 @@ async def create_table_in_snowflake( table_name: fields: An iterable of (name, type) tuples representing the fields of the table. 
""" - field_ddl = ", ".join((f'"{field[0]}" {field[1]}' for field in fields)) + field_ddl = ", ".join(f'"{field[0]}" {field[1]}' for field in fields) await execute_async_query( connection, diff --git a/posthog/temporal/batch_exports/utils.py b/posthog/temporal/batch_exports/utils.py index f165ae070a8..c10ede32d77 100644 --- a/posthog/temporal/batch_exports/utils.py +++ b/posthog/temporal/batch_exports/utils.py @@ -24,8 +24,7 @@ def peek_first_and_rewind( def rewind_gen() -> collections.abc.Generator[T, None, None]: """Yield the item we popped to rewind the generator.""" yield first - for i in gen: - yield i + yield from gen return (first, rewind_gen()) diff --git a/posthog/temporal/common/clickhouse.py b/posthog/temporal/common/clickhouse.py index d548d3871d8..2640bf95c1f 100644 --- a/posthog/temporal/common/clickhouse.py +++ b/posthog/temporal/common/clickhouse.py @@ -24,7 +24,7 @@ def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes: return b"NULL" case uuid.UUID(): - return f"{quote_char}{data}{quote_char}".encode("utf-8") + return f"{quote_char}{data}{quote_char}".encode() case int() | float(): return b"%d" % data @@ -35,8 +35,8 @@ def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes: timezone_arg = f", '{data:%Z}'" if data.microsecond == 0: - return f"toDateTime('{data:%Y-%m-%d %H:%M:%S}'{timezone_arg})".encode("utf-8") - return f"toDateTime64('{data:%Y-%m-%d %H:%M:%S.%f}', 6{timezone_arg})".encode("utf-8") + return f"toDateTime('{data:%Y-%m-%d %H:%M:%S}'{timezone_arg})".encode() + return f"toDateTime64('{data:%Y-%m-%d %H:%M:%S.%f}', 6{timezone_arg})".encode() case list(): encoded_data = [encode_clickhouse_data(value) for value in data] @@ -62,7 +62,7 @@ def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes: value = str(value) encoded_data.append( - f'"{str(key)}"'.encode("utf-8") + b":" + encode_clickhouse_data(value, quote_char=quote_char) + f'"{str(key)}"'.encode() + b":" + encode_clickhouse_data(value, quote_char=quote_char) ) result = b"{" + b",".join(encoded_data) + b"}" @@ -71,7 +71,7 @@ def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes: case _: str_data = str(data) str_data = str_data.replace("\\", "\\\\").replace("'", "\\'") - return f"{quote_char}{str_data}{quote_char}".encode("utf-8") + return f"{quote_char}{str_data}{quote_char}".encode() class ClickHouseError(Exception): @@ -355,8 +355,7 @@ class ClickHouseClient: """ with self.post_query(query, *data, query_parameters=query_parameters, query_id=query_id) as response: with pa.ipc.open_stream(pa.PythonFile(response.raw)) as reader: - for batch in reader: - yield batch + yield from reader async def __aenter__(self): """Enter method part of the AsyncContextManager protocol.""" diff --git a/posthog/temporal/common/codec.py b/posthog/temporal/common/codec.py index faf91c31173..42e775a24ba 100644 --- a/posthog/temporal/common/codec.py +++ b/posthog/temporal/common/codec.py @@ -1,5 +1,5 @@ import base64 -from typing import Iterable +from collections.abc import Iterable from cryptography.fernet import Fernet from temporalio.api.common.v1 import Payload diff --git a/posthog/temporal/common/sentry.py b/posthog/temporal/common/sentry.py index 290cc0182d2..81af9367914 100644 --- a/posthog/temporal/common/sentry.py +++ b/posthog/temporal/common/sentry.py @@ -1,5 +1,5 @@ from dataclasses import is_dataclass -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union from temporalio import activity, workflow from temporalio.worker import ( 
@@ -83,5 +83,5 @@ class SentryInterceptor(Interceptor): def workflow_interceptor_class( self, input: WorkflowInterceptorClassInput - ) -> Optional[Type[WorkflowInboundInterceptor]]: + ) -> Optional[type[WorkflowInboundInterceptor]]: return _SentryWorkflowInterceptor diff --git a/posthog/temporal/common/utils.py b/posthog/temporal/common/utils.py index 022c8270d77..e8e03332c1a 100644 --- a/posthog/temporal/common/utils.py +++ b/posthog/temporal/common/utils.py @@ -103,7 +103,7 @@ HeartbeatType = typing.TypeVar("HeartbeatType", bound=HeartbeatDetails) async def should_resume_from_activity_heartbeat( - activity, heartbeat_type: typing.Type[HeartbeatType], logger + activity, heartbeat_type: type[HeartbeatType], logger ) -> tuple[bool, HeartbeatType | None]: """Check if a batch export should resume from an activity's heartbeat details. diff --git a/posthog/temporal/data_imports/external_data_job.py b/posthog/temporal/data_imports/external_data_job.py index 9c9245e003d..dc111fb4b83 100644 --- a/posthog/temporal/data_imports/external_data_job.py +++ b/posthog/temporal/data_imports/external_data_job.py @@ -34,7 +34,6 @@ from posthog.warehouse.models import ( ExternalDataSource, ) from posthog.temporal.common.logger import bind_temporal_worker_logger -from typing import Dict @dataclasses.dataclass @@ -67,7 +66,7 @@ class ValidateSchemaInputs: team_id: int schema_id: uuid.UUID table_schema: TSchemaTables - table_row_counts: Dict[str, int] + table_row_counts: dict[str, int] @activity.defn diff --git a/posthog/temporal/data_imports/pipelines/hubspot/__init__.py b/posthog/temporal/data_imports/pipelines/hubspot/__init__.py index 3ffa3c8ffa1..49d84aa41f2 100644 --- a/posthog/temporal/data_imports/pipelines/hubspot/__init__.py +++ b/posthog/temporal/data_imports/pipelines/hubspot/__init__.py @@ -23,7 +23,8 @@ python >>> resources = hubspot(api_key="hubspot_access_code") """ -from typing import Literal, Sequence, Iterator, Iterable +from typing import Literal +from collections.abc import Sequence, Iterator, Iterable import dlt from dlt.common.typing import TDataItems @@ -114,13 +115,11 @@ def crm_objects( if len(props) > 10000: raise ValueError( - ( - "Your request to Hubspot is too long to process. " - "Maximum allowed query length is 10000 symbols, while " - f"your list of properties `{props[:200]}`... is {len(props)} " - "symbols long. Use the `props` argument of the resource to " - "set the list of properties to extract from the endpoint." - ) + "Your request to Hubspot is too long to process. " + "Maximum allowed query length is 10000 symbols, while " + f"your list of properties `{props[:200]}`... is {len(props)} " + "symbols long. Use the `props` argument of the resource to " + "set the list of properties to extract from the endpoint." 
) params = {"properties": props, "limit": 100} diff --git a/posthog/temporal/data_imports/pipelines/hubspot/auth.py b/posthog/temporal/data_imports/pipelines/hubspot/auth.py index 490552cfe23..b88aa731499 100644 --- a/posthog/temporal/data_imports/pipelines/hubspot/auth.py +++ b/posthog/temporal/data_imports/pipelines/hubspot/auth.py @@ -1,6 +1,5 @@ import requests from django.conf import settings -from typing import Tuple def refresh_access_token(refresh_token: str) -> str: @@ -21,7 +20,7 @@ def refresh_access_token(refresh_token: str) -> str: return res.json()["access_token"] -def get_access_token_from_code(code: str, redirect_uri: str) -> Tuple[str, str]: +def get_access_token_from_code(code: str, redirect_uri: str) -> tuple[str, str]: res = requests.post( "https://api.hubapi.com/oauth/v1/token", data={ diff --git a/posthog/temporal/data_imports/pipelines/hubspot/helpers.py b/posthog/temporal/data_imports/pipelines/hubspot/helpers.py index 0ef03b6db23..d47616f251a 100644 --- a/posthog/temporal/data_imports/pipelines/hubspot/helpers.py +++ b/posthog/temporal/data_imports/pipelines/hubspot/helpers.py @@ -1,7 +1,8 @@ """HubSpot source helpers""" import urllib.parse -from typing import Iterator, Dict, Any, List, Optional +from typing import Any, Optional +from collections.abc import Iterator from dlt.sources.helpers import requests import requests as http_requests @@ -16,7 +17,7 @@ def get_url(endpoint: str) -> str: return urllib.parse.urljoin(BASE_URL, endpoint) -def _get_headers(api_key: str) -> Dict[str, str]: +def _get_headers(api_key: str) -> dict[str, str]: """ Return a dictionary of HTTP headers to use for API requests, including the specified API key. @@ -32,7 +33,7 @@ def _get_headers(api_key: str) -> Dict[str, str]: return {"authorization": f"Bearer {api_key}"} -def extract_property_history(objects: List[Dict[str, Any]]) -> Iterator[Dict[str, Any]]: +def extract_property_history(objects: list[dict[str, Any]]) -> Iterator[dict[str, Any]]: for item in objects: history = item.get("propertiesWithHistory") if not history: @@ -49,8 +50,8 @@ def fetch_property_history( endpoint: str, api_key: str, props: str, - params: Optional[Dict[str, Any]] = None, -) -> Iterator[List[Dict[str, Any]]]: + params: Optional[dict[str, Any]] = None, +) -> Iterator[list[dict[str, Any]]]: """Fetch property history from the given CRM endpoint. Args: @@ -91,8 +92,8 @@ def fetch_property_history( def fetch_data( - endpoint: str, api_key: str, refresh_token: str, params: Optional[Dict[str, Any]] = None -) -> Iterator[List[Dict[str, Any]]]: + endpoint: str, api_key: str, refresh_token: str, params: Optional[dict[str, Any]] = None +) -> Iterator[list[dict[str, Any]]]: """ Fetch data from the HubSpot endpoint using a specified API key and yield the properties of each result. For paginated endpoints this function yields items from all pages. @@ -141,7 +142,7 @@ def fetch_data( # Yield the properties of each result in the API response while _data is not None: if "results" in _data: - _objects: List[Dict[str, Any]] = [] + _objects: list[dict[str, Any]] = [] for _result in _data["results"]: _obj = _result.get("properties", _result) if "id" not in _obj and "id" in _result: @@ -176,7 +177,7 @@ def fetch_data( _data = None -def _get_property_names(api_key: str, refresh_token: str, object_type: str) -> List[str]: """ Retrieve property names for a given entity from the HubSpot API.
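The hunks above are representative of the most common mechanical rewrite in this patch: PEP 585 builtin generics replace the deprecated typing aliases, and abstract types move from typing to collections.abc. A minimal standalone before/after sketch, not part of the patch itself; fetch_data_old and fetch_data_new are hypothetical names:

# Before (how these files read prior to the patch): deprecated typing aliases,
# flagged by Ruff's pyupgrade rules (UP006 for annotations, UP035 for imports,
# the two codes the patch itself suppresses elsewhere).
from typing import Any, Dict, Iterator, List, Optional

def fetch_data_old(params: Optional[Dict[str, Any]] = None) -> Iterator[List[Dict[str, Any]]]:
    yield [dict(params or {})]

# After: builtin generics (PEP 585), with the abstract Iterator imported from
# collections.abc instead. Optional is left alone, matching what the patch does.
from collections.abc import Iterator  # shadows the typing import above; fine in a sketch

def fetch_data_new(params: Optional[dict[str, Any]] = None) -> Iterator[list[dict[str, Any]]]:
    yield [dict(params or {})]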
diff --git a/posthog/temporal/data_imports/pipelines/pipeline.py b/posthog/temporal/data_imports/pipelines/pipeline.py index 0b3f7c448f1..0ac469c214d 100644 --- a/posthog/temporal/data_imports/pipelines/pipeline.py +++ b/posthog/temporal/data_imports/pipelines/pipeline.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, Literal +from typing import Literal from uuid import UUID import dlt @@ -95,7 +95,7 @@ class DataImportPipeline: dataset_name=self.inputs.dataset_name, ) - def _run(self) -> Dict[str, int]: + def _run(self) -> dict[str, int]: pipeline = self._create_pipeline() total_counts: Counter = Counter({}) @@ -121,7 +121,7 @@ class DataImportPipeline: return dict(total_counts) - async def run(self) -> Dict[str, int]: + async def run(self) -> dict[str, int]: try: return await asyncio.to_thread(self._run) except PipelineStepFailed: diff --git a/posthog/temporal/data_imports/pipelines/postgres/__init__.py b/posthog/temporal/data_imports/pipelines/postgres/__init__.py index 438b25fbe9d..07a368ed572 100644 --- a/posthog/temporal/data_imports/pipelines/postgres/__init__.py +++ b/posthog/temporal/data_imports/pipelines/postgres/__init__.py @@ -1,7 +1,8 @@ """Source that loads tables from any SQLAlchemy-supported database, supports batching requests and incremental loads.""" -from typing import List, Optional, Union, Iterable, Any -from sqlalchemy import MetaData, Table, text +from typing import Optional, Union, List # noqa: UP035 +from collections.abc import Iterable +from sqlalchemy import MetaData, Table from sqlalchemy.engine import Engine import dlt @@ -35,7 +36,7 @@ def sql_database( credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value, schema: Optional[str] = dlt.config.value, metadata: Optional[MetaData] = None, - table_names: Optional[List[str]] = dlt.config.value, + table_names: Optional[List[str]] = dlt.config.value, # noqa: UP006 ) -> Iterable[DltResource]: """ A DLT source which loads data from an SQL database using SQLAlchemy.
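One notable exception sits in the hunk directly above: the vendored dlt sql_database source keeps typing.List for its table_names parameter, opted out with noqa comments. The patch does not say why; presumably dlt resolves these annotations at runtime when building its config specs, so rewriting them is unsafe, but that is an assumption. A hypothetical sketch of the suppression pattern:

# Stand-in showing the opt-out used above: UP035 flags the typing import and
# UP006 flags the annotation, and the noqa comments keep Ruff from autofixing
# this one signature while the rest of the codebase is rewritten.
from typing import List, Optional  # noqa: UP035

def sql_database_stub(table_names: Optional[List[str]] = None):  # noqa: UP006
    return list(table_names or [])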
diff --git a/posthog/temporal/data_imports/pipelines/postgres/helpers.py b/posthog/temporal/data_imports/pipelines/postgres/helpers.py index a288205063f..5805e789918 100644 --- a/posthog/temporal/data_imports/pipelines/postgres/helpers.py +++ b/posthog/temporal/data_imports/pipelines/postgres/helpers.py @@ -2,11 +2,10 @@ from typing import ( Any, - List, Optional, - Iterator, Union, ) +from collections.abc import Iterator import operator import dlt @@ -63,7 +62,7 @@ class TableLoader: return query return query.where(filter_op(self.cursor_column, self.last_value)) # type: ignore - def load_rows(self) -> Iterator[List[TDataItem]]: + def load_rows(self) -> Iterator[list[TDataItem]]: query = self.make_query() with self.engine.connect() as conn: result = conn.execution_options(yield_per=self.chunk_size).execute(query) @@ -104,7 +103,7 @@ def engine_from_credentials(credentials: Union[ConnectionStringCredentials, Engi return create_engine(credentials) -def get_primary_key(table: Table) -> List[str]: +def get_primary_key(table: Table) -> list[str]: return [c.name for c in table.primary_key] diff --git a/posthog/temporal/data_imports/pipelines/stripe/helpers.py b/posthog/temporal/data_imports/pipelines/stripe/helpers.py index 7e2e02017b2..56494d3d47c 100644 --- a/posthog/temporal/data_imports/pipelines/stripe/helpers.py +++ b/posthog/temporal/data_imports/pipelines/stripe/helpers.py @@ -1,6 +1,7 @@ """Stripe analytics source helpers""" -from typing import Any, Dict, Optional, Union, Iterable, Tuple +from typing import Any, Optional, Union +from collections.abc import Iterable import stripe import dlt @@ -32,7 +33,7 @@ async def stripe_get_data( start_date: Optional[Any] = None, end_date: Optional[Any] = None, **kwargs: Any, -) -> Dict[Any, Any]: +) -> dict[Any, Any]: if start_date: start_date = transform_date(start_date) if end_date: @@ -148,7 +149,7 @@ async def stripe_pagination( def stripe_source( api_key: str, account_id: str, - endpoints: Tuple[str, ...], + endpoints: tuple[str, ...], team_id, job_id, schema_id, diff --git a/posthog/temporal/data_imports/pipelines/zendesk/api_helpers.py b/posthog/temporal/data_imports/pipelines/zendesk/api_helpers.py index a1747d96c78..c478060940d 100644 --- a/posthog/temporal/data_imports/pipelines/zendesk/api_helpers.py +++ b/posthog/temporal/data_imports/pipelines/zendesk/api_helpers.py @@ -1,4 +1,4 @@ -from typing import Optional, TypedDict, Dict +from typing import Optional, TypedDict from dlt.common import pendulum from dlt.common.time import ensure_pendulum_datetime @@ -18,7 +18,7 @@ def _parse_date_or_none(value: Optional[str]) -> Optional[pendulum.DateTime]: def process_ticket( ticket: DictStrAny, - custom_fields: Dict[str, TCustomFieldInfo], + custom_fields: dict[str, TCustomFieldInfo], pivot_custom_fields: bool = True, ) -> DictStrAny: """ @@ -78,7 +78,7 @@ def process_ticket( return ticket -def process_ticket_field(field: DictStrAny, custom_fields_state: Dict[str, TCustomFieldInfo]) -> TDataItem: +def process_ticket_field(field: DictStrAny, custom_fields_state: dict[str, TCustomFieldInfo]) -> TDataItem: """Update custom field mapping in dlt state for the given field.""" # grab id and update state dict # if the id is new, add a new key to indicate that this is the initial value for title diff --git a/posthog/temporal/data_imports/pipelines/zendesk/credentials.py b/posthog/temporal/data_imports/pipelines/zendesk/credentials.py index aa0463bb441..d0565280595 100644 --- a/posthog/temporal/data_imports/pipelines/zendesk/credentials.py +++ 
b/posthog/temporal/data_imports/pipelines/zendesk/credentials.py @@ -2,7 +2,7 @@ This module handles how credentials are read in dlt sources """ -from typing import ClassVar, List, Union +from typing import ClassVar, Union import dlt from dlt.common.configuration import configspec from dlt.common.configuration.specs import CredentialsConfiguration @@ -16,7 +16,7 @@ class ZendeskCredentialsBase(CredentialsConfiguration): """ subdomain: str - __config_gen_annotations__: ClassVar[List[str]] = [] + __config_gen_annotations__: ClassVar[list[str]] = [] @configspec diff --git a/posthog/temporal/data_imports/pipelines/zendesk/helpers.py b/posthog/temporal/data_imports/pipelines/zendesk/helpers.py index 8c0e0427c3f..c29f41279a0 100644 --- a/posthog/temporal/data_imports/pipelines/zendesk/helpers.py +++ b/posthog/temporal/data_imports/pipelines/zendesk/helpers.py @@ -1,4 +1,5 @@ -from typing import Iterator, Optional, Iterable, Tuple +from typing import Optional +from collections.abc import Iterator, Iterable from itertools import chain import dlt @@ -211,7 +212,7 @@ def chats_table_resource( def zendesk_support( team_id: int, credentials: TZendeskCredentials = dlt.secrets.value, - endpoints: Tuple[str, ...] = (), + endpoints: tuple[str, ...] = (), pivot_ticket_fields: bool = True, start_date: Optional[TAnyDateTime] = DEFAULT_START_DATE, end_date: Optional[TAnyDateTime] = None, diff --git a/posthog/temporal/data_imports/pipelines/zendesk/talk_api.py b/posthog/temporal/data_imports/pipelines/zendesk/talk_api.py index 5db9a28eafc..4ebf375bf70 100644 --- a/posthog/temporal/data_imports/pipelines/zendesk/talk_api.py +++ b/posthog/temporal/data_imports/pipelines/zendesk/talk_api.py @@ -1,5 +1,6 @@ from enum import Enum -from typing import Dict, Iterator, Optional, Tuple, Any +from typing import Optional, Any +from collections.abc import Iterator from dlt.common.typing import DictStrStr, TDataItems, TSecretValue from dlt.sources.helpers.requests import client @@ -27,7 +28,7 @@ class ZendeskAPIClient: subdomain: str = "" url: str = "" headers: Optional[DictStrStr] - auth: Optional[Tuple[str, TSecretValue]] + auth: Optional[tuple[str, TSecretValue]] def __init__(self, credentials: TZendeskCredentials, url_prefix: Optional[str] = None) -> None: """ @@ -64,7 +65,7 @@ class ZendeskAPIClient: endpoint: str, data_point_name: str, pagination: PaginationType, - params: Optional[Dict[str, Any]] = None, + params: Optional[dict[str, Any]] = None, ) -> Iterator[TDataItems]: """ Makes a request to a paginated endpoint and returns a generator of data items per page. 
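Throughout these source files the patch rewrites Dict/List/Tuple to their builtin forms but deliberately leaves Optional and Union untouched. The further PEP 604 rewrite (Ruff's UP007; whether that rule is enabled for this repo, and at which target version, is an assumption) would look like this hypothetical sketch:

from typing import Optional, Union

# What the patch produces: builtin generics, with Optional/Union kept.
def get_auth(token: Optional[str] = None) -> Union[tuple[str, str], None]:
    return ("basic", token) if token else None

# What a later PEP 604 pass could produce on Python 3.10+: | unions, with no
# typing import needed at all.
def get_auth_604(token: str | None = None) -> tuple[str, str] | None:
    return ("basic", token) if token else None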
diff --git a/posthog/temporal/data_imports/workflow_activities/create_job_model.py b/posthog/temporal/data_imports/workflow_activities/create_job_model.py index a8389611374..e6407e9f785 100644 --- a/posthog/temporal/data_imports/workflow_activities/create_job_model.py +++ b/posthog/temporal/data_imports/workflow_activities/create_job_model.py @@ -3,7 +3,6 @@ import uuid from asgiref.sync import sync_to_async from temporalio import activity -from typing import Tuple # TODO: remove dependency from posthog.temporal.data_imports.pipelines.schemas import PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING @@ -24,7 +23,7 @@ class CreateExternalDataJobModelActivityInputs: @activity.defn -async def create_external_data_job_model_activity(inputs: CreateExternalDataJobModelActivityInputs) -> Tuple[str, bool]: +async def create_external_data_job_model_activity(inputs: CreateExternalDataJobModelActivityInputs) -> tuple[str, bool]: run = await sync_to_async(create_external_data_job)( team_id=inputs.team_id, external_data_source_id=inputs.source_id, diff --git a/posthog/temporal/data_imports/workflow_activities/import_data.py b/posthog/temporal/data_imports/workflow_activities/import_data.py index b1730221f8b..b6806071e72 100644 --- a/posthog/temporal/data_imports/workflow_activities/import_data.py +++ b/posthog/temporal/data_imports/workflow_activities/import_data.py @@ -17,7 +17,6 @@ from posthog.warehouse.models import ( get_external_data_job, ) from posthog.temporal.common.logger import bind_temporal_worker_logger -from typing import Dict, Tuple import asyncio from django.conf import settings from django.utils import timezone @@ -34,7 +33,7 @@ class ImportDataActivityInputs: @activity.defn -async def import_data_activity(inputs: ImportDataActivityInputs) -> Tuple[TSchemaTables, Dict[str, int]]: # noqa: F821 +async def import_data_activity(inputs: ImportDataActivityInputs) -> tuple[TSchemaTables, dict[str, int]]: # noqa: F821 model: ExternalDataJob = await get_external_data_job( job_id=inputs.run_id, ) diff --git a/posthog/temporal/tests/batch_exports/test_http_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_http_batch_export_workflow.py index b0163b8fee7..7b7e2b56674 100644 --- a/posthog/temporal/tests/batch_exports/test_http_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_http_batch_export_workflow.py @@ -190,8 +190,9 @@ async def test_insert_into_http_activity_inserts_data_into_http_endpoint( ) mock_server = MockServer() - with aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, override_settings( - BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 + with ( + aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, + override_settings(BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2), ): m.post(TEST_URL, status=200, callback=mock_server.post, repeat=True) await activity_environment.run(insert_into_http_activity, insert_inputs) @@ -239,22 +240,25 @@ async def test_insert_into_http_activity_throws_on_bad_http_status( **http_config, ) - with aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, override_settings( - BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 + with ( + aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, + override_settings(BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2), ): m.post(TEST_URL, status=400, repeat=True) with pytest.raises(NonRetryableResponseError): await activity_environment.run(insert_into_http_activity, insert_inputs) - with 
aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, override_settings( - BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 + with ( + aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, + override_settings(BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2), ): m.post(TEST_URL, status=429, repeat=True) with pytest.raises(RetryableResponseError): await activity_environment.run(insert_into_http_activity, insert_inputs) - with aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, override_settings( - BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 + with ( + aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, + override_settings(BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2), ): m.post(TEST_URL, status=500, repeat=True) with pytest.raises(RetryableResponseError): @@ -352,8 +356,9 @@ async def test_http_export_workflow( ], workflow_runner=UnsandboxedWorkflowRunner(), ): - with aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, override_settings( - BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 + with ( + aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, + override_settings(BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2), ): m.post(TEST_URL, status=200, callback=mock_server.post, repeat=True) @@ -589,8 +594,9 @@ async def test_insert_into_http_activity_heartbeats( ) mock_server = MockServer() - with aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, override_settings( - BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 + with ( + aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, + override_settings(BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2), ): m.post(TEST_URL, status=200, callback=mock_server.post, repeat=True) await activity_environment.run(insert_into_http_activity, insert_inputs) diff --git a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py index 459dff8dc3c..6652ac224b2 100644 --- a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py @@ -175,7 +175,7 @@ def add_mock_snowflake_api(rsps: responses.RequestsMock, fail: bool | str = Fals # contents as a string in `staged_files`. 
if match := re.match(r"^PUT file://(?P<file_path>.*) @%(?P<table_name>.*)$", sql_text): file_path = match.group("file_path") - with open(file_path, "r") as f: + with open(file_path) as f: staged_files.append(f.read()) if fail == "put": @@ -414,9 +414,12 @@ async def test_snowflake_export_workflow_exports_events( ], workflow_runner=UnsandboxedWorkflowRunner(), ): - with unittest.mock.patch( - "posthog.temporal.batch_exports.snowflake_batch_export.snowflake.connector.connect", - ) as mock, override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1): + with ( + unittest.mock.patch( + "posthog.temporal.batch_exports.snowflake_batch_export.snowflake.connector.connect", + ) as mock, + override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1), + ): fake_conn = FakeSnowflakeConnection() mock.return_value = fake_conn @@ -482,10 +485,13 @@ async def test_snowflake_export_workflow_without_events(ateam, snowflake_batch_e ], workflow_runner=UnsandboxedWorkflowRunner(), ): - with responses.RequestsMock( - target="snowflake.connector.vendored.requests.adapters.HTTPAdapter.send", - assert_all_requests_are_fired=False, - ) as rsps, override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1**2): + with ( + responses.RequestsMock( + target="snowflake.connector.vendored.requests.adapters.HTTPAdapter.send", + assert_all_requests_are_fired=False, + ) as rsps, + override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1**2), + ): queries, staged_files = add_mock_snowflake_api(rsps) await activity_environment.client.execute_workflow( SnowflakeBatchExportWorkflow.run, diff --git a/posthog/temporal/tests/external_data/test_external_data_job.py b/posthog/temporal/tests/external_data/test_external_data_job.py index 80d5aa7b2cb..37431206fe3 100644 --- a/posthog/temporal/tests/external_data/test_external_data_job.py +++ b/posthog/temporal/tests/external_data/test_external_data_job.py @@ -330,12 +330,14 @@ async def test_run_stripe_job(activity_environment, team, minio_client, **kwargs job_1, job_1_inputs = await setup_job_1() job_2, job_2_inputs = await setup_job_2() - with mock.patch("stripe.Customer.list") as mock_customer_list, mock.patch( - "stripe.Charge.list" - ) as mock_charge_list, override_settings( - BUCKET_URL=f"s3://{BUCKET_NAME}", - AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, - AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + with ( + mock.patch("stripe.Customer.list") as mock_customer_list, + mock.patch("stripe.Charge.list") as mock_charge_list, + override_settings( + BUCKET_URL=f"s3://{BUCKET_NAME}", + AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, + AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + ), ): mock_customer_list.return_value = { "data": [ @@ -410,10 +412,13 @@ async def test_run_stripe_job_cancelled(activity_environment, team, minio_client job_1, job_1_inputs = await setup_job_1() - with mock.patch("stripe.Customer.list") as mock_customer_list, override_settings( - BUCKET_URL=f"s3://{BUCKET_NAME}", - AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, - AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + with ( + mock.patch("stripe.Customer.list") as mock_customer_list, + override_settings( + BUCKET_URL=f"s3://{BUCKET_NAME}", + AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, + AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + ), ): mock_customer_list.return_value = { "data": [ @@ -475,12 +480,14 @@ async def test_run_stripe_job_row_count_update(activity_environment, team,
minio job_1, job_1_inputs = await setup_job_1() - with mock.patch("stripe.Customer.list") as mock_customer_list, mock.patch( - "posthog.temporal.data_imports.pipelines.helpers.CHUNK_SIZE", 0 - ), override_settings( - BUCKET_URL=f"s3://{BUCKET_NAME}", - AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, - AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + with ( + mock.patch("stripe.Customer.list") as mock_customer_list, + mock.patch("posthog.temporal.data_imports.pipelines.helpers.CHUNK_SIZE", 0), + override_settings( + BUCKET_URL=f"s3://{BUCKET_NAME}", + AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, + AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + ), ): mock_customer_list.return_value = { "data": [ @@ -527,9 +534,10 @@ async def test_validate_schema_and_update_table_activity(activity_environment, t test_1_schema = await _create_schema("test-1", new_source, team) - with mock.patch( - "posthog.warehouse.models.table.DataWarehouseTable.get_columns" - ) as mock_get_columns, override_settings(**AWS_BUCKET_MOCK_SETTINGS): + with ( + mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns") as mock_get_columns, + override_settings(**AWS_BUCKET_MOCK_SETTINGS), + ): mock_get_columns.return_value = {"id": "string"} await activity_environment.run( validate_schema_activity, @@ -597,9 +605,10 @@ async def test_validate_schema_and_update_table_activity_with_existing(activity_ test_1_schema = await _create_schema("test-1", new_source, team, table_id=existing_table.id) - with mock.patch( - "posthog.warehouse.models.table.DataWarehouseTable.get_columns" - ) as mock_get_columns, override_settings(**AWS_BUCKET_MOCK_SETTINGS): + with ( + mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns") as mock_get_columns, + override_settings(**AWS_BUCKET_MOCK_SETTINGS), + ): mock_get_columns.return_value = {"id": "string"} await activity_environment.run( validate_schema_activity, @@ -640,9 +649,13 @@ async def test_validate_schema_and_update_table_activity_half_run(activity_envir rows_synced=0, ) - with mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns") as mock_get_columns, mock.patch( - "posthog.warehouse.data_load.validate_schema.validate_schema", - ) as mock_validate, override_settings(**AWS_BUCKET_MOCK_SETTINGS): + with ( + mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns") as mock_get_columns, + mock.patch( + "posthog.warehouse.data_load.validate_schema.validate_schema", + ) as mock_validate, + override_settings(**AWS_BUCKET_MOCK_SETTINGS), + ): mock_get_columns.return_value = {"id": "string"} credential = await sync_to_async(DataWarehouseCredential.objects.create)( team=team, @@ -708,9 +721,10 @@ async def test_create_schema_activity(activity_environment, team, **kwargs): test_1_schema = await _create_schema("test-1", new_source, team) - with mock.patch( - "posthog.warehouse.models.table.DataWarehouseTable.get_columns" - ) as mock_get_columns, override_settings(**AWS_BUCKET_MOCK_SETTINGS): + with ( + mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns") as mock_get_columns, + override_settings(**AWS_BUCKET_MOCK_SETTINGS), + ): mock_get_columns.return_value = {"id": "string"} await activity_environment.run( validate_schema_activity, @@ -763,9 +777,10 @@ async def test_external_data_job_workflow_with_schema(team, **kwargs): async def mock_async_func(inputs): return {} - with mock.patch( - 
"posthog.warehouse.models.table.DataWarehouseTable.get_columns", return_value={"id": "string"} - ), mock.patch.object(DataImportPipeline, "run", mock_async_func): + with ( + mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns", return_value={"id": "string"}), + mock.patch.object(DataImportPipeline, "run", mock_async_func), + ): with override_settings(AIRBYTE_BUCKET_KEY="test-key", AIRBYTE_BUCKET_SECRET="test-secret"): async with await WorkflowEnvironment.start_time_skipping() as activity_environment: async with Worker( @@ -910,13 +925,17 @@ async def test_check_schedule_activity_with_missing_schema_id_but_with_schedule( should_sync=True, ) - with mock.patch( - "posthog.temporal.data_imports.external_data_job.a_external_data_workflow_exists", return_value=True - ), mock.patch( - "posthog.temporal.data_imports.external_data_job.a_delete_external_data_schedule", return_value=True - ), mock.patch( - "posthog.temporal.data_imports.external_data_job.a_trigger_external_data_workflow" - ) as mock_a_trigger_external_data_workflow: + with ( + mock.patch( + "posthog.temporal.data_imports.external_data_job.a_external_data_workflow_exists", return_value=True + ), + mock.patch( + "posthog.temporal.data_imports.external_data_job.a_delete_external_data_schedule", return_value=True + ), + mock.patch( + "posthog.temporal.data_imports.external_data_job.a_trigger_external_data_workflow" + ) as mock_a_trigger_external_data_workflow, + ): should_exit = await activity_environment.run( check_schedule_activity, ExternalDataWorkflowInputs( @@ -950,13 +969,17 @@ async def test_check_schedule_activity_with_missing_schema_id_and_no_schedule(ac should_sync=True, ) - with mock.patch( - "posthog.temporal.data_imports.external_data_job.a_external_data_workflow_exists", return_value=False - ), mock.patch( - "posthog.temporal.data_imports.external_data_job.a_delete_external_data_schedule", return_value=True - ), mock.patch( - "posthog.temporal.data_imports.external_data_job.a_sync_external_data_job_workflow" - ) as mock_a_sync_external_data_job_workflow: + with ( + mock.patch( + "posthog.temporal.data_imports.external_data_job.a_external_data_workflow_exists", return_value=False + ), + mock.patch( + "posthog.temporal.data_imports.external_data_job.a_delete_external_data_schedule", return_value=True + ), + mock.patch( + "posthog.temporal.data_imports.external_data_job.a_sync_external_data_job_workflow" + ) as mock_a_sync_external_data_job_workflow, + ): should_exit = await activity_environment.run( check_schedule_activity, ExternalDataWorkflowInputs( diff --git a/posthog/test/base.py b/posthog/test/base.py index c96738aafa1..2ebfa6178e2 100644 --- a/posthog/test/base.py +++ b/posthog/test/base.py @@ -7,7 +7,8 @@ import time import uuid from contextlib import contextmanager from functools import wraps -from typing import Any, Dict, List, Optional, Tuple, Union, Generator +from typing import Any, Optional, Union +from collections.abc import Generator from unittest.mock import patch import freezegun @@ -86,8 +87,8 @@ from posthog.test.assert_faster_than import assert_faster_than freezegun.configure(extend_ignore_list=["posthog.test.assert_faster_than"]) # type: ignore -persons_cache_tests: List[Dict[str, Any]] = [] -events_cache_tests: List[Dict[str, Any]] = [] +persons_cache_tests: list[dict[str, Any]] = [] +events_cache_tests: list[dict[str, Any]] = [] persons_ordering_int: int = 1 @@ -124,7 +125,7 @@ class FuzzyInt(int): highest: int def __new__(cls, lowest, highest): - obj = super(FuzzyInt, 
cls).__new__(cls, highest) + obj = super().__new__(cls, highest) obj.lowest = lowest obj.highest = highest return obj @@ -144,7 +145,7 @@ class ErrorResponsesMixin: "attr": None, } - def not_found_response(self, message: str = "Not found.") -> Dict[str, Optional[str]]: + def not_found_response(self, message: str = "Not found.") -> dict[str, Optional[str]]: return { "type": "invalid_request", "code": "not_found", @@ -154,7 +155,7 @@ class ErrorResponsesMixin: def permission_denied_response( self, message: str = "You do not have permission to perform this action." - ) -> Dict[str, Optional[str]]: + ) -> dict[str, Optional[str]]: return { "type": "authentication_error", "code": "permission_denied", @@ -162,7 +163,7 @@ class ErrorResponsesMixin: "attr": None, } - def method_not_allowed_response(self, method: str) -> Dict[str, Optional[str]]: + def method_not_allowed_response(self, method: str) -> dict[str, Optional[str]]: return { "type": "invalid_request", "code": "method_not_allowed", @@ -174,7 +175,7 @@ class ErrorResponsesMixin: self, message: str = "Authentication credentials were not provided.", code: str = "not_authenticated", - ) -> Dict[str, Optional[str]]: + ) -> dict[str, Optional[str]]: return { "type": "authentication_error", "code": code, @@ -187,7 +188,7 @@ class ErrorResponsesMixin: message: str = "Malformed request", code: str = "invalid_input", attr: Optional[str] = None, - ) -> Dict[str, Optional[str]]: + ) -> dict[str, Optional[str]]: return { "type": "validation_error", "code": code, @@ -820,7 +821,7 @@ class ClickhouseTestMixin(QueryMatchingTest): return self.capture_queries(("SELECT", "WITH", "select", "with")) @contextmanager - def capture_queries(self, query_prefixes: Union[str, Tuple[str, ...]]): + def capture_queries(self, query_prefixes: Union[str, tuple[str, ...]]): queries = [] original_get_client = ch_pool.get_client @@ -863,7 +864,7 @@ def failhard_threadhook_context(): threading.excepthook = old_hook -def run_clickhouse_statement_in_parallel(statements: List[str]): +def run_clickhouse_statement_in_parallel(statements: list[str]): jobs = [] with failhard_threadhook_context(): for item in statements: @@ -1063,8 +1064,8 @@ def also_test_with_person_on_events_v2(fn): def _create_insight( - team: Team, insight_filters: Dict[str, Any], dashboard_filters: Dict[str, Any] -) -> Tuple[Insight, Dashboard, DashboardTile]: + team: Team, insight_filters: dict[str, Any], dashboard_filters: dict[str, Any] +) -> tuple[Insight, Dashboard, DashboardTile]: dashboard = Dashboard.objects.create(team=team, filters=dashboard_filters) insight = Insight.objects.create(team=team, filters=insight_filters) dashboard_tile = DashboardTile.objects.create(dashboard=dashboard, insight=insight) @@ -1088,7 +1089,7 @@ def create_person_id_override_by_distinct_id( """ ) - person_id_from, person_id_to = [row[1] for row in person_ids_result] + person_id_from, person_id_to = (row[1] for row in person_ids_result) sync_execute( f""" diff --git a/posthog/test/db_context_capturing.py b/posthog/test/db_context_capturing.py index 60600235456..44c1b05d23c 100644 --- a/posthog/test/db_context_capturing.py +++ b/posthog/test/db_context_capturing.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Generator +from collections.abc import Generator from django.db import DEFAULT_DB_ALIAS, connections from django.test.utils import CaptureQueriesContext diff --git a/posthog/test/test_feature_flag.py b/posthog/test/test_feature_flag.py index 38afbe7dbbc..91db555b31b 100644 --- 
a/posthog/test/test_feature_flag.py +++ b/posthog/test/test_feature_flag.py @@ -2784,8 +2784,9 @@ class TestFeatureFlagMatcher(BaseTest, QueryMatchingTest): key="variant", ) - with self.assertNumQueries(10), snapshot_postgres_queries_context( - self + with ( + self.assertNumQueries(10), + snapshot_postgres_queries_context(self), ): # 1 to fill group cache, 2 to match feature flags with group properties (of each type), 1 to match feature flags with person properties matches, reasons, payloads, _ = FeatureFlagMatcher( [ @@ -2859,8 +2860,9 @@ class TestFeatureFlagMatcher(BaseTest, QueryMatchingTest): self.assertEqual(payloads, {"variant": {"color": "blue"}}) - with self.assertNumQueries(9), snapshot_postgres_queries_context( - self + with ( + self.assertNumQueries(9), + snapshot_postgres_queries_context(self), ): # 1 to fill group cache, 1 to match feature flags with group properties (only 1 group provided), 1 to match feature flags with person properties matches, reasons, payloads, _ = FeatureFlagMatcher( [ @@ -6016,8 +6018,9 @@ class TestHashKeyOverridesRaceConditions(TransactionTestCase, QueryMatchingTest) properties={"email": "tim@posthog.com", "team": "posthog"}, ) - with snapshot_postgres_queries_context(self, capture_all_queries=True), connection.execute_wrapper( - InsertFailOnce() + with ( + snapshot_postgres_queries_context(self, capture_all_queries=True), + connection.execute_wrapper(InsertFailOnce()), ): flags, reasons, payloads, errors = get_all_feature_flags( team.pk, "other_id", {}, hash_key_override="example_id" diff --git a/posthog/test/test_feature_flag_analytics.py b/posthog/test/test_feature_flag_analytics.py index f5a5f37e0ac..ed8228ff211 100644 --- a/posthog/test/test_feature_flag_analytics.py +++ b/posthog/test/test_feature_flag_analytics.py @@ -77,8 +77,9 @@ class TestFeatureFlagAnalytics(BaseTest, QueryMatchingTest): team_uuid = "team-uuid" other_team_uuid = "other-team-uuid" - with freeze_time("2022-05-07 12:23:07") as frozen_datetime, self.settings( - DECIDE_BILLING_ANALYTICS_TOKEN="token" + with ( + freeze_time("2022-05-07 12:23:07") as frozen_datetime, + self.settings(DECIDE_BILLING_ANALYTICS_TOKEN="token"), ): for _ in range(10): # 10 requests in first bucket @@ -299,8 +300,9 @@ class TestFeatureFlagAnalytics(BaseTest, QueryMatchingTest): other_team_id = 1243 team_uuid = "team-uuid" - with freeze_time("2022-05-07 12:23:07") as frozen_datetime, self.settings( - DECIDE_BILLING_ANALYTICS_TOKEN="token" + with ( + freeze_time("2022-05-07 12:23:07") as frozen_datetime, + self.settings(DECIDE_BILLING_ANALYTICS_TOKEN="token"), ): for _ in range(10): # 10 requests in first bucket @@ -400,8 +402,9 @@ class TestFeatureFlagAnalytics(BaseTest, QueryMatchingTest): team_uuid = "team-uuid" other_team_uuid = "other-team-uuid" - with freeze_time("2022-05-07 12:23:07") as frozen_datetime, self.settings( - DECIDE_BILLING_ANALYTICS_TOKEN="token" + with ( + freeze_time("2022-05-07 12:23:07") as frozen_datetime, + self.settings(DECIDE_BILLING_ANALYTICS_TOKEN="token"), ): for _ in range(10): # 10 requests in first bucket @@ -489,8 +492,9 @@ class TestFeatureFlagAnalytics(BaseTest, QueryMatchingTest): other_team_id = 1243 team_uuid = "team-uuid" - with freeze_time("2022-05-07 12:23:07") as frozen_datetime, self.settings( - DECIDE_BILLING_ANALYTICS_TOKEN="token" + with ( + freeze_time("2022-05-07 12:23:07") as frozen_datetime, + self.settings(DECIDE_BILLING_ANALYTICS_TOKEN="token"), ): for _ in range(10): # 10 requests in first bucket diff --git a/posthog/test/test_health.py 
b/posthog/test/test_health.py index 89611fb11ee..2ce4e464e8c 100644 --- a/posthog/test/test_health.py +++ b/posthog/test/test_health.py @@ -1,7 +1,7 @@ import logging from contextlib import contextmanager import random -from typing import List, Optional +from typing import Optional from unittest import mock from unittest.mock import patch @@ -70,7 +70,13 @@ def test_livez_returns_200_and_doesnt_require_any_dependencies(client: Client): just be an indicator that the python process hasn't hung. """ - with simulate_postgres_error(), simulate_kafka_cannot_connect(), simulate_clickhouse_cannot_connect(), simulate_celery_cannot_connect(), simulate_cache_cannot_connect(): + with ( + simulate_postgres_error(), + simulate_kafka_cannot_connect(), + simulate_clickhouse_cannot_connect(), + simulate_celery_cannot_connect(), + simulate_cache_cannot_connect(), + ): resp = get_livez(client) assert resp.status_code == 200, resp.content @@ -263,7 +269,7 @@ def test_readyz_complains_if_role_does_not_exist(client: Client): assert data["error"] == "InvalidRole" -def get_readyz(client: Client, exclude: Optional[List[str]] = None, role: Optional[str] = None) -> HttpResponse: +def get_readyz(client: Client, exclude: Optional[list[str]] = None, role: Optional[str] = None) -> HttpResponse: return client.get("/_readyz", data={"exclude": exclude or [], "role": role or ""}) diff --git a/posthog/test/test_journeys.py b/posthog/test/test_journeys.py index 0e535437076..69bb2050d8f 100644 --- a/posthog/test/test_journeys.py +++ b/posthog/test/test_journeys.py @@ -3,7 +3,7 @@ from hashlib import md5 import json from datetime import datetime import os -from typing import Any, Dict, List +from typing import Any from uuid import UUID, uuid4 from django.utils import timezone @@ -15,10 +15,10 @@ from posthog.test.base import _create_event, flush_persons_and_events def journeys_for( - events_by_person: Dict[str, List[Dict[str, Any]]], + events_by_person: dict[str, list[dict[str, Any]]], team: Team, create_people: bool = True, -) -> Dict[str, Person]: +) -> dict[str, Person]: """ Helper for creating specific events for a team. 
@@ -115,11 +115,11 @@ def journeys_for( return people -def _create_all_events_raw(all_events: List[Dict]): +def _create_all_events_raw(all_events: list[dict]): parsed = "" for event in all_events: timestamp = timezone.now() - data: Dict[str, Any] = { + data: dict[str, Any] = { "properties": {}, "timestamp": timestamp.strftime("%Y-%m-%d %H:%M:%S.%f"), "person_id": str(uuid4()), @@ -162,7 +162,7 @@ def _create_all_events_raw(all_events: List[Dict]): ) -def create_all_events(all_events: List[dict]): +def create_all_events(all_events: list[dict]): for event in all_events: _create_event(**event) @@ -175,15 +175,15 @@ class InMemoryEvent: distinct_id: str team: Team timestamp: str - properties: Dict + properties: dict person_id: str person_created_at: datetime - person_properties: Dict - group0_properties: Dict - group1_properties: Dict - group2_properties: Dict - group3_properties: Dict - group4_properties: Dict + person_properties: dict + group0_properties: dict + group1_properties: dict + group2_properties: dict + group3_properties: dict + group4_properties: dict group0_created_at: datetime group1_created_at: datetime group2_created_at: datetime @@ -191,7 +191,7 @@ class InMemoryEvent: group4_created_at: datetime -def update_or_create_person(distinct_ids: List[str], team_id: int, **kwargs): +def update_or_create_person(distinct_ids: list[str], team_id: int, **kwargs): (person, _) = Person.objects.update_or_create( persondistinctid__distinct_id__in=distinct_ids, persondistinctid__team_id=team_id, diff --git a/posthog/test/test_utils.py b/posthog/test/test_utils.py index 827c5dd1de8..dab6a4d1e0e 100644 --- a/posthog/test/test_utils.py +++ b/posthog/test/test_utils.py @@ -434,7 +434,7 @@ class TestShouldRefresh(TestCase): def test_refresh_requested_by_client_with_data_true(self): drf_request = Request(HttpRequest()) drf_request._full_data = {"refresh": True} # type: ignore - self.assertTrue(refresh_requested_by_client((drf_request))) + self.assertTrue(refresh_requested_by_client(drf_request)) def test_should_not_refresh_with_data_false(self): drf_request = Request(HttpRequest()) diff --git a/posthog/urls.py b/posthog/urls.py index b047f897307..3681f4a1ca4 100644 --- a/posthog/urls.py +++ b/posthog/urls.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, List, Optional, cast +from typing import Any, Optional, cast +from collections.abc import Callable from posthog.models.instance_setting import get_instance_setting from urllib.parse import urlparse @@ -60,7 +61,7 @@ import structlog logger = structlog.get_logger(__name__) -ee_urlpatterns: List[Any] = [] +ee_urlpatterns: list[Any] = [] try: from ee.urls import extend_api_router from ee.urls import urlpatterns as ee_urlpatterns diff --git a/posthog/user_permissions.py b/posthog/user_permissions.py index 30a6bfca298..7b4d9b07728 100644 --- a/posthog/user_permissions.py +++ b/posthog/user_permissions.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Any, Dict, List, Optional, cast +from typing import Any, Optional, cast from uuid import UUID from posthog.constants import AvailableFeature @@ -32,10 +32,10 @@ class UserPermissions: self.user = user self._current_team = team - self._tiles: Optional[List[DashboardTile]] = None - self._team_permissions: Dict[int, UserTeamPermissions] = {} - self._dashboard_permissions: Dict[int, UserDashboardPermissions] = {} - self._insight_permissions: Dict[int, UserInsightPermissions] = {} + self._tiles: Optional[list[DashboardTile]] = None + self._team_permissions: dict[int, 
UserTeamPermissions] = {} + self._dashboard_permissions: dict[int, UserDashboardPermissions] = {} + self._insight_permissions: dict[int, UserInsightPermissions] = {} @cached_property def current_team(self) -> "UserTeamPermissions": @@ -68,7 +68,7 @@ class UserPermissions: return self._insight_permissions[insight.pk] @cached_property - def team_ids_visible_for_user(self) -> List[int]: + def team_ids_visible_for_user(self) -> list[int]: candidate_teams = Team.objects.filter(organization_id__in=self.organizations.keys()).only( "pk", "organization_id", "access_control" ) @@ -86,16 +86,16 @@ class UserPermissions: return self.organizations.get(organization_id) @cached_property - def organizations(self) -> Dict[UUID, Organization]: + def organizations(self) -> dict[UUID, Organization]: return {member.organization_id: member.organization for member in self.organization_memberships.values()} @cached_property - def organization_memberships(self) -> Dict[UUID, OrganizationMembership]: + def organization_memberships(self) -> dict[UUID, OrganizationMembership]: memberships = OrganizationMembership.objects.filter(user=self.user).select_related("organization") return {membership.organization_id: membership for membership in memberships} @cached_property - def explicit_team_memberships(self) -> Dict[int, Any]: + def explicit_team_memberships(self) -> dict[int, Any]: try: from ee.models import ExplicitTeamMembership except ImportError: @@ -107,7 +107,7 @@ class UserPermissions: return {membership.team_id: membership.level for membership in memberships} @cached_property - def dashboard_privileges(self) -> Dict[int, Dashboard.PrivilegeLevel]: + def dashboard_privileges(self) -> dict[int, Dashboard.PrivilegeLevel]: try: from ee.models import DashboardPrivilege @@ -116,14 +116,14 @@ class UserPermissions: except ImportError: return {} - def set_preloaded_dashboard_tiles(self, tiles: List[DashboardTile]): + def set_preloaded_dashboard_tiles(self, tiles: list[DashboardTile]): """ Allows for speeding up insight-related permissions code """ self._tiles = tiles @cached_property - def preloaded_insight_dashboards(self) -> Optional[List[Dashboard]]: + def preloaded_insight_dashboards(self) -> Optional[list[Dashboard]]: if self._tiles is None: return None diff --git a/posthog/utils.py b/posthog/utils.py index f7c32736b25..cdc0a4ed48f 100644 --- a/posthog/utils.py +++ b/posthog/utils.py @@ -19,15 +19,11 @@ from functools import lru_cache, wraps from typing import ( TYPE_CHECKING, Any, - Dict, - Generator, - List, - Mapping, Optional, - Tuple, Union, cast, ) +from collections.abc import Generator, Mapping from urllib.parse import urljoin, urlparse from zoneinfo import ZoneInfo @@ -125,7 +121,7 @@ def absolute_uri(url: Optional[str] = None) -> str: return urljoin(settings.SITE_URL.rstrip("/") + "/", url.lstrip("/")) -def get_previous_day(at: Optional[datetime.datetime] = None) -> Tuple[datetime.datetime, datetime.datetime]: +def get_previous_day(at: Optional[datetime.datetime] = None) -> tuple[datetime.datetime, datetime.datetime]: """ Returns a pair of datetimes, representing the start and end of the preceding day. `at` is the datetime to use as a reference point. 
@@ -149,7 +145,7 @@ def get_previous_day(at: Optional[datetime.d return (period_start, period_end) -def get_current_day(at: Optional[datetime.datetime] = None) -> Tuple[datetime.datetime, datetime.datetime]: +def get_current_day(at: Optional[datetime.datetime] = None) -> tuple[datetime.datetime, datetime.datetime]: """ Returns a pair of datetimes, representing the start and end of the current day. `at` is the datetime to use as a reference point. @@ -179,7 +175,7 @@ def relative_date_parse_with_delta_mapping( *, always_truncate: bool = False, now: Optional[datetime.datetime] = None, -) -> Tuple[datetime.datetime, Optional[Dict[str, int]], str | None]: +) -> tuple[datetime.datetime, Optional[dict[str, int]], str | None]: """Returns the parsed datetime, along with the period mapping - if the input was a relative datetime string.""" try: try: @@ -202,7 +198,7 @@ def relative_date_parse_with_delta_mapping( regex = r"\-?(?P<number>[0-9]+)?(?P<type>[a-z])(?P<position>Start|End)?" match = re.search(regex, input) parsed_dt = (now or dt.datetime.now()).astimezone(timezone_info) - delta_mapping: Dict[str, int] = {} + delta_mapping: dict[str, int] = {} if not match: return parsed_dt, delta_mapping, None if match.group("type") == "h": @@ -276,7 +272,7 @@ def get_js_url(request: HttpRequest) -> str: def render_template( template_name: str, request: HttpRequest, - context: Optional[Dict] = None, + context: Optional[dict] = None, *, team_for_public_context: Optional["Team"] = None, ) -> HttpResponse: @@ -331,13 +327,13 @@ def render_template( except: year_in_hog_url = None - posthog_app_context: Dict[str, Any] = { + posthog_app_context: dict[str, Any] = { "persisted_feature_flags": settings.PERSISTED_FEATURE_FLAGS, "anonymous": not request.user or not request.user.is_authenticated, "year_in_hog_url": year_in_hog_url, } - posthog_bootstrap: Dict[str, Any] = {} + posthog_bootstrap: dict[str, Any] = {} posthog_distinct_id: Optional[str] = None # Set the frontend app context @@ -453,7 +449,7 @@ def get_default_event_name(team: "Team"): return "$pageview" -def get_frontend_apps(team_id: int) -> Dict[int, Dict[str, Any]]: +def get_frontend_apps(team_id: int) -> dict[int, dict[str, Any]]: from posthog.models import Plugin, PluginSourceFile plugin_configs = ( @@ -541,10 +537,10 @@ def convert_property_value(input: Union[str, bool, dict, list, int, Optional[str def get_compare_period_dates( date_from: datetime.datetime, date_to: datetime.datetime, - date_from_delta_mapping: Optional[Dict[str, int]], - date_to_delta_mapping: Optional[Dict[str, int]], + date_from_delta_mapping: Optional[dict[str, int]], + date_to_delta_mapping: Optional[dict[str, int]], interval: str, -) -> Tuple[datetime.datetime, datetime.datetime]: +) -> tuple[datetime.datetime, datetime.datetime]: diff = date_to - date_from new_date_from = date_from - diff if interval == "hour": @@ -783,7 +779,7 @@ def get_plugin_server_version() -> Optional[str]: return None -def get_plugin_server_job_queues() -> Optional[List[str]]: +def get_plugin_server_job_queues() -> Optional[list[str]]: cache_key_value = get_client().get("@posthog-plugin-server/enabled-job-queues") if cache_key_value: qs = cache_key_value.decode("utf-8").replace('"', "") @@ -861,13 +857,13 @@ def get_can_create_org(user: Union["AbstractBaseUser", "AnonymousUser"]) -> bool return False -def get_instance_available_sso_providers() -> Dict[str, bool]: """ Returns a dictionary containing final determination of which
SSO providers are available. SAML is not included in this method as it can only be configured domain-based and not instance-based (see `OrganizationDomain` for details) Validates configuration settings and license validity (if applicable). """ - output: Dict[str, bool] = { + output: dict[str, bool] = { "github": bool(settings.SOCIAL_AUTH_GITHUB_KEY and settings.SOCIAL_AUTH_GITHUB_SECRET), "gitlab": bool(settings.SOCIAL_AUTH_GITLAB_KEY and settings.SOCIAL_AUTH_GITLAB_SECRET), "google-oauth2": False, @@ -897,7 +893,7 @@ def get_instance_available_sso_providers() -> Dict[str, bool]: return output -def flatten(i: Union[List, Tuple], max_depth=10) -> Generator: +def flatten(i: Union[list, tuple], max_depth=10) -> Generator: for el in i: if isinstance(el, list) and max_depth > 0: yield from flatten(el, max_depth=max_depth - 1) @@ -909,7 +905,7 @@ def get_daterange( start_date: Optional[datetime.datetime], end_date: Optional[datetime.datetime], frequency: str, -) -> List[Any]: +) -> list[Any]: """ Returns list of a fixed frequency Datetime objects between given bounds. @@ -981,7 +977,7 @@ class GenericEmails: """ def __init__(self): - with open(get_absolute_path("helpers/generic_emails.txt"), "r") as f: + with open(get_absolute_path("helpers/generic_emails.txt")) as f: self.emails = {x.rstrip(): True for x in f} def is_generic(self, email: str) -> bool: @@ -992,7 +988,7 @@ class GenericEmails: @lru_cache(maxsize=1) -def get_available_timezones_with_offsets() -> Dict[str, float]: +def get_available_timezones_with_offsets() -> dict[str, float]: now = dt.datetime.now() result = {} for tz in pytz.common_timezones: @@ -1066,7 +1062,7 @@ def get_milliseconds_between_dates(d1: dt.datetime, d2: dt.datetime) -> int: return abs(int((d1 - d2).total_seconds() * 1000)) -def encode_get_request_params(data: Dict[str, Any]) -> Dict[str, str]: +def encode_get_request_params(data: dict[str, Any]) -> dict[str, str]: return { key: encode_value_as_param(value=value) for key, value in data.items() @@ -1083,7 +1079,7 @@ class DataclassJSONEncoder(json.JSONEncoder): def encode_value_as_param(value: Union[str, list, dict, datetime.datetime]) -> str: - if isinstance(value, (list, dict, tuple)): + if isinstance(value, list | dict | tuple): return json.dumps(value, cls=DataclassJSONEncoder) elif isinstance(value, Enum): return value.value @@ -1311,7 +1307,7 @@ def patchable(fn): def label_for_team_id_to_track(team_id: int) -> str: - team_id_filter: List[str] = settings.DECIDE_TRACK_TEAM_IDS + team_id_filter: list[str] = settings.DECIDE_TRACK_TEAM_IDS team_id_as_string = str(team_id) diff --git a/posthog/version_requirement.py b/posthog/version_requirement.py index 0f60d553e76..ad0979abc3b 100644 --- a/posthog/version_requirement.py +++ b/posthog/version_requirement.py @@ -1,5 +1,3 @@ -from typing import Tuple - from semantic_version.base import SimpleSpec, Version from posthog import redis @@ -24,7 +22,7 @@ class ServiceVersionRequirement: f"The provided supported_version for service {service} is invalid. 
diff --git a/posthog/views.py b/posthog/views.py
index b9cae80fde3..6797b3ab7f8 100644
--- a/posthog/views.py
+++ b/posthog/views.py
@@ -1,6 +1,6 @@
 import os
 from functools import wraps
-from typing import Dict, Union
+from typing import Union

 import sentry_sdk
 from django.conf import settings
@@ -70,7 +70,7 @@ def health(request):


 def stats(request):
-    stats_response: Dict[str, Union[int, str]] = {}
+    stats_response: dict[str, Union[int, str]] = {}
     stats_response["worker_heartbeat"] = get_celery_heartbeat()

     return JsonResponse(stats_response)
diff --git a/posthog/warehouse/api/external_data_schema.py b/posthog/warehouse/api/external_data_schema.py
index e7abb808ce5..c02f6c146f7 100644
--- a/posthog/warehouse/api/external_data_schema.py
+++ b/posthog/warehouse/api/external_data_schema.py
@@ -2,7 +2,7 @@ from rest_framework import serializers
 import structlog
 import temporalio
 from posthog.warehouse.models import ExternalDataSchema, ExternalDataJob
-from typing import Optional, Dict, Any
+from typing import Optional, Any
 from posthog.api.routing import TeamAndOrgViewSetMixin
 from rest_framework import viewsets, filters, status
 from rest_framework.decorators import action
@@ -47,7 +47,7 @@ class ExternalDataSchemaSerializer(serializers.ModelSerializer):
         return SimpleTableSerializer(schema.table, context={"database": hogql_context}).data or None

-    def update(self, instance: ExternalDataSchema, validated_data: Dict[str, Any]) -> ExternalDataSchema:
+    def update(self, instance: ExternalDataSchema, validated_data: dict[str, Any]) -> ExternalDataSchema:
         should_sync = validated_data.get("should_sync", None)
         schedule_exists = external_data_workflow_exists(str(instance.id))
@@ -77,7 +77,7 @@ class ExternalDataSchemaViewset(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
     search_fields = ["name"]
     ordering = "-created_at"

-    def get_serializer_context(self) -> Dict[str, Any]:
+    def get_serializer_context(self) -> dict[str, Any]:
         context = super().get_serializer_context()
         context["database"] = create_hogql_database(team_id=self.team_id)
         return context
diff --git a/posthog/warehouse/api/external_data_source.py b/posthog/warehouse/api/external_data_source.py
index 36142a80593..c8e70315409 100644
--- a/posthog/warehouse/api/external_data_source.py
+++ b/posthog/warehouse/api/external_data_source.py
@@ -1,5 +1,5 @@
 import uuid
-from typing import Any, List, Tuple, Dict
+from typing import Any

 import structlog
 from rest_framework import filters, serializers, status, viewsets
@@ -71,7 +71,7 @@ class ExternalDataSourceSerializers(serializers.ModelSerializer):
         return latest_completed_run.created_at if latest_completed_run else None

     def get_status(self, instance: ExternalDataSource) -> str:
-        active_schemas: List[ExternalDataSchema] = list(instance.schemas.filter(should_sync=True).all())
+        active_schemas: list[ExternalDataSchema] = list(instance.schemas.filter(should_sync=True).all())
         any_failures = any(schema.status == ExternalDataSchema.Status.ERROR for schema in active_schemas)
         any_cancelled = any(schema.status == ExternalDataSchema.Status.CANCELLED for schema in active_schemas)
         any_paused = any(schema.status == ExternalDataSchema.Status.PAUSED for schema in active_schemas)
@@ -122,7 +122,7 @@ class ExternalDataSourceViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
     search_fields = ["source_id"]
     ordering = "-created_at"

-    def get_serializer_context(self) -> Dict[str, Any]:
+    def get_serializer_context(self) -> dict[str, Any]:
         context = super().get_serializer_context()
         context["database"] = create_hogql_database(team_id=self.team_id)
         return context
@@ -193,7 +193,7 @@ class ExternalDataSourceViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):

         disabled_schemas = [schema for schema in default_schemas if schema not in enabled_schemas]

-        active_schemas: List[ExternalDataSchema] = []
+        active_schemas: list[ExternalDataSchema] = []

         for schema in enabled_schemas:
             active_schemas.append(
@@ -289,7 +289,7 @@ class ExternalDataSourceViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
     def _handle_postgres_source(
         self, request: Request, *args: Any, **kwargs: Any
-    ) -> Tuple[ExternalDataSource, List[Any]]:
+    ) -> tuple[ExternalDataSource, list[Any]]:
         payload = request.data["payload"]
         prefix = request.data.get("prefix", None)
         source_type = request.data["source_type"]
diff --git a/posthog/warehouse/api/saved_query.py b/posthog/warehouse/api/saved_query.py
index f341b5779d0..581593377f2 100644
--- a/posthog/warehouse/api/saved_query.py
+++ b/posthog/warehouse/api/saved_query.py
@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any

 from django.conf import settings
 from rest_framework import exceptions, filters, serializers, viewsets
@@ -33,7 +33,7 @@ class DataWarehouseSavedQuerySerializer(serializers.ModelSerializer):
         ]
         read_only_fields = ["id", "created_by", "created_at", "columns"]

-    def get_columns(self, view: DataWarehouseSavedQuery) -> List[SerializedField]:
+    def get_columns(self, view: DataWarehouseSavedQuery) -> list[SerializedField]:
         team_id = self.context["team_id"]
         context = HogQLContext(team_id=team_id, database=create_hogql_database(team_id=team_id))
diff --git a/posthog/warehouse/api/table.py b/posthog/warehouse/api/table.py
index fcfdd7eee88..7e149b0faba 100644
--- a/posthog/warehouse/api/table.py
+++ b/posthog/warehouse/api/table.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Dict
+from typing import Any

 from rest_framework import filters, request, response, serializers, status, viewsets
 from rest_framework.exceptions import NotAuthenticated
@@ -53,7 +53,7 @@ class TableSerializer(serializers.ModelSerializer):
         ]
         read_only_fields = ["id", "created_by", "created_at", "columns", "external_data_source", "external_schema"]

-    def get_columns(self, table: DataWarehouseTable) -> List[SerializedField]:
+    def get_columns(self, table: DataWarehouseTable) -> list[SerializedField]:
         hogql_context = self.context.get("database", None)
         if not hogql_context:
             hogql_context = create_hogql_database(team_id=self.context["team_id"])
@@ -91,7 +91,7 @@ class SimpleTableSerializer(serializers.ModelSerializer):
         fields = ["id", "name", "columns", "row_count"]
         read_only_fields = ["id", "name", "columns", "row_count"]

-    def get_columns(self, table: DataWarehouseTable) -> List[SerializedField]:
+    def get_columns(self, table: DataWarehouseTable) -> list[SerializedField]:
         hogql_context = self.context.get("database", None)
         if not hogql_context:
             hogql_context = create_hogql_database(team_id=self.context["team_id"])
@@ -111,7 +111,7 @@ class TableViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
     search_fields = ["name"]
     ordering = "-created_at"

-    def get_serializer_context(self) -> Dict[str, Any]:
+    def get_serializer_context(self) -> dict[str, Any]:
         context = super().get_serializer_context()
         context["database"] = create_hogql_database(team_id=self.team_id)
         return context
diff --git a/posthog/warehouse/data_load/validate_schema.py b/posthog/warehouse/data_load/validate_schema.py
index 6a7e2512583..f3755442d3c 100644
--- a/posthog/warehouse/data_load/validate_schema.py
+++ b/posthog/warehouse/data_load/validate_schema.py
@@ -29,13 +29,12 @@ from posthog.warehouse.models.external_data_job import ExternalDataJob
 from posthog.temporal.common.logger import bind_temporal_worker_logger
 from clickhouse_driver.errors import ServerException
 from asgiref.sync import sync_to_async
-from typing import Dict, Type
 from posthog.utils import camel_to_snake_case
 from posthog.warehouse.models.external_data_schema import ExternalDataSchema


 def dlt_to_hogql_type(dlt_type: TDataType | None) -> str:
-    hogql_type: Type[DatabaseField] = DatabaseField
+    hogql_type: type[DatabaseField] = DatabaseField

     if dlt_type is None:
         hogql_type = StringDatabaseField
@@ -69,7 +68,7 @@ def dlt_to_hogql_type(dlt_type: TDataType | None) -> str:

 async def validate_schema(
     credential: DataWarehouseCredential, table_name: str, new_url_pattern: str, team_id: int, row_count: int
-) -> Dict:
+) -> dict:
     params = {
         "credential": credential,
         "name": table_name,
@@ -97,7 +96,7 @@ async def validate_schema_and_update_table(
     team_id: int,
     schema_id: uuid.UUID,
     table_schema: TSchemaTables,
-    table_row_counts: Dict[str, int],
+    table_row_counts: dict[str, int],
 ) -> None:
     """
@@ -167,7 +166,7 @@ async def validate_schema_and_update_table(
     for schema in table_schema.values():
         if schema.get("resource") == _schema_name:
             schema_columns = schema.get("columns") or {}
-            db_columns: Dict[str, str] = await sync_to_async(table_created.get_columns)()
+            db_columns: dict[str, str] = await sync_to_async(table_created.get_columns)()

             columns = {}
             for column_name, db_column_type in db_columns.items():
diff --git a/posthog/warehouse/external_data_source/source.py b/posthog/warehouse/external_data_source/source.py
index f722bae1f33..99e49a39a1d 100644
--- a/posthog/warehouse/external_data_source/source.py
+++ b/posthog/warehouse/external_data_source/source.py
@@ -1,5 +1,5 @@
 import datetime as dt
-from typing import Dict, Optional
+from typing import Optional

 from pydantic import BaseModel, field_validator
@@ -71,7 +71,7 @@ def create_stripe_source(payload: StripeSourcePayload, workspace_id: str) -> Ext
     return _create_source(payload)


-def _create_source(payload: Dict) -> ExternalDataSource:
+def _create_source(payload: dict) -> ExternalDataSource:
     response = send_request(AIRBYTE_SOURCE_URL, method="POST", payload=payload)
     return ExternalDataSource(
         source_id=response["sourceId"],
diff --git a/posthog/warehouse/models/datawarehouse_saved_query.py b/posthog/warehouse/models/datawarehouse_saved_query.py
index ffa890ba45b..0513cc3b7d1 100644
--- a/posthog/warehouse/models/datawarehouse_saved_query.py
+++ b/posthog/warehouse/models/datawarehouse_saved_query.py
@@ -1,5 +1,4 @@
 import re
-from typing import Dict
 from sentry_sdk import capture_exception
 from django.core.exceptions import ValidationError
 from django.db import models
@@ -47,7 +46,7 @@ class DataWarehouseSavedQuery(CreatedMetaFields, UUIDModel, DeletedMetaFields):
             )
         ]

-    def get_columns(self) -> Dict[str, str]:
+    def get_columns(self) -> dict[str, str]:
         from posthog.api.services.query import process_query

         # TODO: catch and raise error
diff --git a/posthog/warehouse/models/external_data_schema.py b/posthog/warehouse/models/external_data_schema.py
index ed883f6d623..045a4e10d8a 100644
--- a/posthog/warehouse/models/external_data_schema.py
+++ b/posthog/warehouse/models/external_data_schema.py
@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any

 from django.db import models
 from posthog.models.team import Team
@@ -80,7 +80,7 @@ def sync_old_schemas_with_new_schemas(new_schemas: list, source_id: uuid.UUID, t
         ExternalDataSchema.objects.create(name=schema, team_id=team_id, source_id=source_id, should_sync=False)


-def get_postgres_schemas(host: str, port: str, database: str, user: str, password: str, schema: str) -> List[Any]:
+def get_postgres_schemas(host: str, port: str, database: str, user: str, password: str, schema: str) -> list[Any]:
     connection = psycopg.Connection.connect(
         host=host,
         port=int(port),
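Note on the validate_schema.py hunks above: Type[DatabaseField] becomes the builtin type[DatabaseField] (the same UP006 rule covers typing.Type). A type[...] annotation marks a variable that holds a class rather than an instance, as dlt_to_hogql_type does; a minimal sketch with stand-in classes (not the real posthog.hogql models):

    class DatabaseField:
        pass

    class StringDatabaseField(DatabaseField):
        pass

    def pick_field_class(is_text: bool) -> type[DatabaseField]:
        # The variable is assigned classes, not instances, mirroring dlt_to_hogql_type.
        hogql_type: type[DatabaseField] = DatabaseField
        if is_text:
            hogql_type = StringDatabaseField
        return hogql_type

    field = pick_field_class(True)()  # instantiate whichever class was selected
    assert isinstance(field, DatabaseField)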
diff --git a/posthog/warehouse/models/external_table_definitions.py b/posthog/warehouse/models/external_table_definitions.py
index 405ffa150e6..6a684d96eca 100644
--- a/posthog/warehouse/models/external_table_definitions.py
+++ b/posthog/warehouse/models/external_table_definitions.py
@@ -1,4 +1,3 @@
-from typing import Dict
 from posthog.hogql import ast
 from posthog.hogql.database.models import (
     BooleanDatabaseField,
@@ -10,7 +9,7 @@ from posthog.hogql.database.models import (
 )


-external_tables: Dict[str, Dict[str, FieldOrTable]] = {
+external_tables: dict[str, dict[str, FieldOrTable]] = {
     "*": {
         "__dlt_id": StringDatabaseField(name="_dlt_id", hidden=True),
         "__dlt_load_id": StringDatabaseField(name="_dlt_load_id", hidden=True),
diff --git a/posthog/warehouse/models/join.py b/posthog/warehouse/models/join.py
index 5a3e46658fd..d3edfb864c4 100644
--- a/posthog/warehouse/models/join.py
+++ b/posthog/warehouse/models/join.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any
 from warnings import warn

 from django.db import models
@@ -45,7 +45,7 @@ class DataWarehouseJoin(CreatedMetaFields, UUIDModel, DeletedMetaFields):
     def _join_function(
         from_table: str,
         to_table: str,
-        requested_fields: Dict[str, Any],
+        requested_fields: dict[str, Any],
         context: HogQLContext,
         node: SelectQuery,
     ):
diff --git a/posthog/warehouse/models/table.py b/posthog/warehouse/models/table.py
index 2b4609e79a6..229c81168a8 100644
--- a/posthog/warehouse/models/table.py
+++ b/posthog/warehouse/models/table.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Optional

 from django.db import models
 from posthog.client import sync_execute
@@ -111,7 +111,7 @@ class DataWarehouseTable(CreatedMetaFields, UUIDModel, DeletedMetaFields):
             prefix = ""
         return self.name[len(prefix) :]

-    def get_columns(self, safe_expose_ch_error=True) -> Dict[str, str]:
+    def get_columns(self, safe_expose_ch_error=True) -> dict[str, str]:
         try:
             result = sync_execute(
                 """DESCRIBE TABLE (
@@ -160,7 +160,7 @@ class DataWarehouseTable(CreatedMetaFields, UUIDModel, DeletedMetaFields):
         if not self.columns:
             raise Exception("Columns must be fetched and saved to use in HogQL.")

-        fields: Dict[str, FieldOrTable] = {}
+        fields: dict[str, FieldOrTable] = {}
         structure = []
         for column, type in self.columns.items():
             # Support for 'old' style columns
diff --git a/posthog/year_in_posthog/calculate_2023.py b/posthog/year_in_posthog/calculate_2023.py
index 29477cfd150..03428d2711d 100644
--- a/posthog/year_in_posthog/calculate_2023.py
+++ b/posthog/year_in_posthog/calculate_2023.py
@@ -1,5 +1,5 @@
 from datetime import timedelta
-from typing import Dict, Optional
+from typing import Optional

 from django.conf import settings
 from django.db import connection
@@ -147,7 +147,7 @@ def dictfetchall(cursor):


 @cache_for(timedelta(seconds=0 if settings.DEBUG else 30))
-def calculate_year_in_posthog_2023(user_uuid: str) -> Optional[Dict]:
+def calculate_year_in_posthog_2023(user_uuid: str) -> Optional[dict]:
     with connection.cursor() as cursor:
         cursor.execute(query, {"user_uuid": user_uuid})
         rows = dictfetchall(cursor)
diff --git a/posthog/year_in_posthog/year_in_posthog.py b/posthog/year_in_posthog/year_in_posthog.py
index 3bf05d821c2..a6ac65fa2fd 100644
--- a/posthog/year_in_posthog/year_in_posthog.py
+++ b/posthog/year_in_posthog/year_in_posthog.py
@@ -2,7 +2,7 @@ from django.http import HttpResponse
 from django.template.loader import get_template
 from django.views.decorators.cache import cache_control
 import os
-from typing import Dict, List, Union
+from typing import Union

 import structlog
@@ -58,7 +58,7 @@ explanation = {
 }


-def stats_for_user(data: Dict) -> List[Dict[str, Union[int, str]]]:
+def stats_for_user(data: dict) -> list[dict[str, Union[int, str]]]:
     stats = data["stats"]

     return [
@@ -75,7 +75,7 @@ def stats_for_user(data: Dict) -> List[Dict[str, Union[int, str]]]:
     ]


-def sort_list_based_on_preference(badges: List[str]) -> str:
+def sort_list_based_on_preference(badges: list[str]) -> str:
     """sort a list based on its order in badge_preferences and then choose the last one"""
     if len(badges) >= 3:
         return "champion"
diff --git a/pyproject.toml b/pyproject.toml
index 2701b5a74d6..cb19ccadb81 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,6 @@
+[project]
+requires-python = ">=3.10"
+
 [tool.black]
 line-length = 120
 target-version = ['py310']
@@ -28,6 +31,8 @@ ignore = [
     "F403",
     "F541",
     "F601",
+    "UP007",
+    "UP032",
 ]
 select = [
     "B",
@@ -40,6 +45,7 @@ select = [
     "RUF015",
     "RUF019",
     "T2",
+    "UP",
     "W",
 ]
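Note on the pyproject.toml hunk above: requires-python = ">=3.10" tells Ruff which target syntax is safe to apply, and "UP" enables the whole pyupgrade family, while two rules are explicitly ignored — UP007, which would rewrite Optional[X] and Union[X, Y] annotations to X | None and X | Y, and UP032, which would convert str.format() calls to f-strings. That is why Optional and Union survive unchanged throughout this patch. A small sketch of what the ignored rules would otherwise flag (illustrative code, not from this repository):

    from typing import Optional, Union

    maybe_id: Optional[int] = None   # UP007 would suggest: maybe_id: int | None
    key: Union[int, str] = 0         # UP007 would suggest: key: int | str

    name = "hog"
    message = "hello {}".format(name)  # UP032 would suggest: message = f"hello {name}"

The enabled rules can typically be applied in bulk with `ruff check --select UP --fix`.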