From ef1a10b41e28493afb57d137bd65a680234f3405 Mon Sep 17 00:00:00 2001
From: Julian Bez
Date: Thu, 30 Nov 2023 09:26:26 +0000
Subject: [PATCH] feat(queries): Add timings to async queries and simplify code (#18963)

---
 posthog/api/query.py                       |   5 +-
 posthog/celery.py                          |   1 -
 posthog/clickhouse/client/execute_async.py | 131 ++++++++++--------
 .../client/test/test_execute_async.py      |   6 +-
 4 files changed, 77 insertions(+), 66 deletions(-)

diff --git a/posthog/api/query.py b/posthog/api/query.py
index 021139911cb..a41bc8a5c8a 100644
--- a/posthog/api/query.py
+++ b/posthog/api/query.py
@@ -39,7 +39,6 @@ from posthog.rate_limit import (
     AISustainedRateThrottle,
     TeamRateThrottle,
 )
-from posthog.schema import QueryStatus
 from posthog.utils import refresh_requested_by_client
 
 
@@ -127,13 +126,13 @@ class QueryViewSet(StructuredViewSetMixin, viewsets.ViewSet):
             self._tag_client_query_id(client_query_id)
 
             if query_async:
-                query_id = enqueue_process_query_task(
+                query_status = enqueue_process_query_task(
                     team_id=self.team.pk,
                     query_json=query_json,
                     query_id=client_query_id,
                     refresh_requested=refresh_requested,
                 )
-                return JsonResponse(QueryStatus(id=query_id, team_id=self.team.pk).model_dump(), safe=False)
+                return JsonResponse(query_status.model_dump(), safe=False)
 
             try:
                 result = process_query(self.team, query_json, refresh_requested=refresh_requested)
diff --git a/posthog/celery.py b/posthog/celery.py
index de7c7a2c34b..356a77c385f 100644
--- a/posthog/celery.py
+++ b/posthog/celery.py
@@ -415,7 +415,6 @@ def process_query_task(self, team_id, query_id, query_json, in_export_context=Fa
         query_json=query_json,
         in_export_context=in_export_context,
         refresh_requested=refresh_requested,
-        task_id=self.request.id,
     )
 
 
diff --git a/posthog/clickhouse/client/execute_async.py b/posthog/clickhouse/client/execute_async.py
index d44c20b7072..34a2a8f25b4 100644
--- a/posthog/clickhouse/client/execute_async.py
+++ b/posthog/clickhouse/client/execute_async.py
@@ -3,6 +3,7 @@ import json
 import uuid
 
 import structlog
+from prometheus_client import Histogram
 from rest_framework.exceptions import NotFound
 
 from posthog import celery, redis
@@ -12,8 +13,8 @@ from posthog.schema import QueryStatus
 
 logger = structlog.get_logger(__name__)
 
-REDIS_STATUS_TTL_SECONDS = 600  # 10 minutes
-REDIS_KEY_PREFIX_ASYNC_RESULTS = "query_async"
+QUERY_WAIT_TIME = Histogram("query_wait_time_seconds", "Time from query creation to pick-up")
+QUERY_PROCESS_TIME = Histogram("query_process_time_seconds", "Time from query pick-up to result")
 
 
 class QueryNotFoundError(NotFound):
     pass
@@ -24,8 +25,44 @@ class QueryRetrievalError(Exception):
     pass
 
 
-def generate_redis_results_key(query_id: str, team_id: int) -> str:
-    return f"{REDIS_KEY_PREFIX_ASYNC_RESULTS}:{team_id}:{query_id}"
+class QueryStatusManager:
+    STATUS_TTL_SECONDS = 600  # 10 minutes
+    KEY_PREFIX_ASYNC_RESULTS = "query_async"
+
+    def __init__(self, query_id: str, team_id: int):
+        self.redis_client = redis.get_client()
+        self.query_id = query_id
+        self.team_id = team_id
+
+    @property
+    def results_key(self) -> str:
+        return f"{self.KEY_PREFIX_ASYNC_RESULTS}:{self.team_id}:{self.query_id}"
+
+    def store_query_status(self, query_status: QueryStatus):
+        self.redis_client.set(self.results_key, query_status.model_dump_json(), ex=self.STATUS_TTL_SECONDS)
+
+    def _get_results(self):
+        try:
+            byte_results = self.redis_client.get(self.results_key)
+        except Exception as e:
+            raise QueryRetrievalError(f"Error retrieving query {self.query_id} for team {self.team_id}") from e
+
+        return byte_results
+
+    def has_results(self):
+        return self._get_results() is not None
+
+    def get_query_status(self) -> QueryStatus:
+        byte_results = self._get_results()
+
+        if not byte_results:
+            raise QueryNotFoundError(f"Query {self.query_id} not found for team {self.team_id}")
+
+        return QueryStatus(**json.loads(byte_results))
+
+    def delete_query_status(self):
+        logger.info("Deleting redis query key %s", self.results_key)
+        self.redis_client.delete(self.results_key)
 
 
 def execute_process_query(
     team_id,
     query_id,
     query_json,
     in_export_context,
     refresh_requested,
-    task_id=None,
 ):
-    key = generate_redis_results_key(query_id, team_id)
-    redis_client = redis.get_client()
+    manager = QueryStatusManager(query_id, team_id)
 
     from posthog.models import Team
     from posthog.api.services.query import process_query
 
     team = Team.objects.get(pk=team_id)
 
-    query_status = QueryStatus(
-        id=query_id,
-        team_id=team_id,
-        task_id=task_id,
-        complete=False,
-        error=True,  # Assume error in case nothing below ends up working
-        start_time=datetime.datetime.utcnow(),
-    )
-    value = query_status.model_dump_json()
+    query_status = manager.get_query_status()
+    query_status.error = True  # Assume error in case nothing below ends up working
+
+    pickup_time = datetime.datetime.utcnow()
+    if query_status.start_time:
+        pickup_duration = (pickup_time - query_status.start_time).total_seconds()
+        QUERY_WAIT_TIME.observe(pickup_duration)
 
     try:
         tag_queries(client_query_id=query_id, team_id=team_id)
@@ -63,17 +96,17 @@
         query_status.complete = True
         query_status.error = False
         query_status.results = results
-        query_status.expiration_time = datetime.datetime.utcnow() + datetime.timedelta(seconds=REDIS_STATUS_TTL_SECONDS)
         query_status.end_time = datetime.datetime.utcnow()
-        value = query_status.model_dump_json()
+        query_status.expiration_time = query_status.end_time + datetime.timedelta(seconds=manager.STATUS_TTL_SECONDS)
+        process_duration = (query_status.end_time - pickup_time).total_seconds()
+        QUERY_PROCESS_TIME.observe(process_duration)
     except Exception as err:
         query_status.results = None  # Clear results in case they are faulty
         query_status.error_message = str(err)
         logger.error("Error processing query for team %s query %s: %s", team_id, query_id, err)
-        value = query_status.model_dump_json()
         raise err
     finally:
-        redis_client.set(key, value, ex=REDIS_STATUS_TTL_SECONDS)
+        manager.store_query_status(query_status)
 
 
 def enqueue_process_query_task(
@@ -83,34 +116,22 @@
     team_id,
     query_json,
     query_id=None,
     refresh_requested=False,
     bypass_celery=False,
     force=False,
-):
+) -> QueryStatus:
     if not query_id:
         query_id = uuid.uuid4().hex
 
-    key = generate_redis_results_key(query_id, team_id)
-    redis_client = redis.get_client()
+    manager = QueryStatusManager(query_id, team_id)
 
     if force:
-        # If we want to force rerun of this query we need to
-        # 1) Get the current status from redis
-        task_str = redis_client.get(key)
-        if task_str:
-            # if the status exists in redis we need to tell celery to kill the job
-            task_str = task_str.decode("utf-8")
-            query_task = QueryStatus(**json.loads(task_str))
-            # Instruct celery to revoke task and terminate if running
-            celery.app.control.revoke(query_task.task_id, terminate=True)
-            # Then we need to make redis forget about this job entirely
-            # and continue as normal. As if we never saw this query before
-            redis_client.delete(key)
+        cancel_query(team_id, query_id)
 
-    if redis_client.get(key):
-        # If we've seen this query before return the query_id and don't resubmit it.
-        return query_id
+    if manager.has_results() and not refresh_requested:
+        # If we've seen this query before return and don't resubmit it.
+        return manager.get_query_status()
 
     # Immediately set status, so we don't have race with celery
-    query_status = QueryStatus(id=query_id, team_id=team_id)
-    redis_client.set(key, query_status.model_dump_json(), ex=REDIS_STATUS_TTL_SECONDS)
+    query_status = QueryStatus(id=query_id, team_id=team_id, start_time=datetime.datetime.utcnow())
+    manager.store_query_status(query_status)
 
     if bypass_celery:
         # Call directly ( for testing )
@@ -120,29 +141,24 @@
             team_id, query_id, query_json, in_export_context=True, refresh_requested=refresh_requested
         )
         query_status.task_id = task.id
-        redis_client.set(key, query_status.model_dump_json(), ex=REDIS_STATUS_TTL_SECONDS)
+        manager.store_query_status(query_status)
 
-    return query_id
+    return query_status
 
 
 def get_query_status(team_id, query_id):
-    redis_client = redis.get_client()
-    key = generate_redis_results_key(query_id, team_id)
-
-    try:
-        byte_results = redis_client.get(key)
-    except Exception as e:
-        raise QueryRetrievalError(f"Error retrieving query {query_id} for team {team_id}") from e
-
-    if not byte_results:
-        raise QueryNotFoundError(f"Query {query_id} not found for team {team_id}")
-
-    return QueryStatus(**json.loads(byte_results))
+    """
+    Abstracts away the manager for any caller and returns a QueryStatus object
+    """
+    manager = QueryStatusManager(query_id, team_id)
+    return manager.get_query_status()
 
 
 def cancel_query(team_id, query_id):
+    manager = QueryStatusManager(query_id, team_id)
+
     try:
-        query_status = get_query_status(team_id, query_id)
+        query_status = manager.get_query_status()
 
         if query_status.task_id:
             logger.info("Got task id %s, attempting to revoke", query_status.task_id)
@@ -157,9 +173,6 @@
 
     cancel_query_on_cluster(team_id, query_id)
 
-    redis_client = redis.get_client()
-    key = generate_redis_results_key(query_id, team_id)
-    logger.info("Deleting redis query key %s", key)
-    redis_client.delete(key)
+    manager.delete_query_status()
 
     return True
diff --git a/posthog/clickhouse/client/test/test_execute_async.py b/posthog/clickhouse/client/test/test_execute_async.py
index 4958c23b3f0..0bfc926cfb9 100644
--- a/posthog/clickhouse/client/test/test_execute_async.py
+++ b/posthog/clickhouse/client/test/test_execute_async.py
@@ -27,7 +27,7 @@ class ClickhouseClientTestCase(TestCase, ClickhouseTestMixin):
     def test_async_query_client(self):
         query = build_query("SELECT 1+1")
         team_id = self.team_id
-        query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True)
+        query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True).id
         result = client.get_query_status(team_id, query_id)
         self.assertFalse(result.error, result.error_message)
         self.assertTrue(result.complete)
@@ -53,7 +53,7 @@ class ClickhouseClientTestCase(TestCase, ClickhouseTestMixin):
     def test_async_query_client_uuid(self):
         query = build_query("SELECT toUUID('00000000-0000-0000-0000-000000000000')")
         team_id = self.team_id
-        query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True)
+        query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True).id
         result = client.get_query_status(team_id, query_id)
         self.assertFalse(result.error, result.error_message)
         self.assertTrue(result.complete)
@@ -63,7 +63,7 @@ class ClickhouseClientTestCase(TestCase, ClickhouseTestMixin):
         query = build_query("SELECT 1+1")
         team_id = self.team_id
         wrong_team = 5
-        query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True)
+        query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True).id
 
         try:
             client.get_query_status(wrong_team, query_id)
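
Usage sketch (an illustration of the API after this patch, not part of the commit): enqueue_process_query_task now returns the full QueryStatus instead of a bare query id, so callers can serialize it directly (as posthog/api/query.py does above) or poll until the worker finishes. The helper below is hypothetical: the import path is assumed from the file layout, the run_query_and_wait name, poll interval, and timeout are arbitrary, it assumes QueryStatus.error defaults to false until the worker stores a result, and the non-bypass path needs a running Celery worker that executes process_query_task.

    import time

    # Assumed import path, mirroring posthog/clickhouse/client/execute_async.py above.
    from posthog.clickhouse.client import execute_async


    def run_query_and_wait(team_id, query_json, timeout=60.0):
        # enqueue_process_query_task stores an initial QueryStatus (including
        # start_time) in Redis before the Celery task is dispatched, so the
        # first poll below cannot race the worker and hit QueryNotFoundError.
        query_status = execute_async.enqueue_process_query_task(
            team_id=team_id,
            query_json=query_json,
        )

        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            status = execute_async.get_query_status(team_id, query_status.id)
            if status.complete:
                # On success the worker sets complete=True and error=False.
                if status.error:
                    raise RuntimeError(status.error_message or "query failed")
                return status.results
            if status.error and status.error_message:
                # On failure the worker leaves complete=False but persists
                # error=True and error_message in its finally block.
                raise RuntimeError(status.error_message)
            time.sleep(0.5)  # arbitrary poll interval

        # cancel_query revokes the Celery task, cancels the query on the
        # ClickHouse cluster, and deletes the Redis status key.
        execute_async.cancel_query(team_id, query_status.id)
        raise TimeoutError(f"Query {query_status.id} timed out after {timeout}s")

With this flow, QUERY_WAIT_TIME observes the enqueue-to-pickup gap (pickup_time minus the stored start_time) and QUERY_PROCESS_TIME the pickup-to-result gap; both are recorded inside execute_process_query, so the histograms cover every async query regardless of the caller.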