Mirror of https://github.com/PostHog/posthog.git (synced 2024-11-24 00:47:50 +01:00)
feat(queries): Add timings to async queries and simplify code (#18963)
parent 1bf890b080
commit ef1a10b41e
@@ -39,7 +39,6 @@ from posthog.rate_limit import (
     AISustainedRateThrottle,
     TeamRateThrottle,
 )
-from posthog.schema import QueryStatus
 from posthog.utils import refresh_requested_by_client
 
 
@@ -127,13 +126,13 @@ class QueryViewSet(StructuredViewSetMixin, viewsets.ViewSet):
         self._tag_client_query_id(client_query_id)
 
         if query_async:
-            query_id = enqueue_process_query_task(
+            query_status = enqueue_process_query_task(
                 team_id=self.team.pk,
                 query_json=query_json,
                 query_id=client_query_id,
                 refresh_requested=refresh_requested,
             )
-            return JsonResponse(QueryStatus(id=query_id, team_id=self.team.pk).model_dump(), safe=False)
+            return JsonResponse(query_status.model_dump(), safe=False)
 
         try:
             result = process_query(self.team, query_json, refresh_requested=refresh_requested)
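
The view no longer builds a bare QueryStatus from a returned id: enqueue_process_query_task now hands back the full status object, which is serialized as-is. As a hedged illustration of what JsonResponse(query_status.model_dump(), safe=False) emits, here is a minimal pydantic stand-in limited to the fields visible in this diff (the real model lives in posthog.schema and may carry more fields):

    # QueryStatusSketch is a hypothetical stand-in for posthog.schema.QueryStatus.
    import datetime
    from typing import Any, Optional

    from pydantic import BaseModel

    class QueryStatusSketch(BaseModel):
        id: str
        team_id: int
        task_id: Optional[str] = None
        complete: bool = False
        error: bool = False
        error_message: Optional[str] = None
        results: Optional[Any] = None
        start_time: Optional[datetime.datetime] = None
        end_time: Optional[datetime.datetime] = None
        expiration_time: Optional[datetime.datetime] = None

    status = QueryStatusSketch(id="abc123", team_id=1, start_time=datetime.datetime.utcnow())
    print(status.model_dump_json())  # roughly the JSON the async endpoint now returns
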
@@ -415,7 +415,6 @@ def process_query_task(self, team_id, query_id, query_json, in_export_context=Fa
         query_json=query_json,
         in_export_context=in_export_context,
         refresh_requested=refresh_requested,
         task_id=self.request.id,
     )
 
-
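
For context on task_id=self.request.id above: with bind=True, a Celery task receives itself as self, and self.request.id is the id the broker assigned to this run. Persisting it on the query status is what later allows cancel_query to revoke the task. A minimal, self-contained illustration (the app name and task are made up):

    from celery import Celery

    app = Celery("demo")  # illustrative app, not PostHog's configuration

    @app.task(bind=True)
    def demo_task(self):
        # When executed by a worker, self.request.id holds the task id that
        # can later be passed to app.control.revoke(task_id, terminate=True).
        return self.request.id
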
@@ -3,6 +3,7 @@ import json
 import uuid
 
 import structlog
+from prometheus_client import Histogram
 from rest_framework.exceptions import NotFound
 
 from posthog import celery, redis
@@ -12,8 +13,8 @@ from posthog.schema import QueryStatus
 
 logger = structlog.get_logger(__name__)
 
-REDIS_STATUS_TTL_SECONDS = 600  # 10 minutes
-REDIS_KEY_PREFIX_ASYNC_RESULTS = "query_async"
+QUERY_WAIT_TIME = Histogram("query_wait_time_seconds", "Time from query creation to pick-up")
+QUERY_PROCESS_TIME = Histogram("query_process_time_seconds", "Time from query pick-up to result")
 
 
 class QueryNotFoundError(NotFound):
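
The two Histograms split end-to-end latency into time spent queued (creation to pick-up) and time spent executing (pick-up to result). A prometheus_client Histogram buckets observations so rates and quantiles can be computed server-side. A small runnable sketch of the same observe() pattern, with a made-up metric name:

    import time

    from prometheus_client import Histogram, generate_latest

    DEMO_WAIT_TIME = Histogram("demo_wait_time_seconds", "Time spent waiting (demo)")

    start = time.monotonic()
    time.sleep(0.05)  # stand-in for time spent in the queue
    DEMO_WAIT_TIME.observe(time.monotonic() - start)  # same call the new code makes

    print(generate_latest().decode())  # text exposition format that Prometheus scrapes
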
@@ -24,8 +25,44 @@ class QueryRetrievalError(Exception):
     pass
 
 
-def generate_redis_results_key(query_id: str, team_id: int) -> str:
-    return f"{REDIS_KEY_PREFIX_ASYNC_RESULTS}:{team_id}:{query_id}"
+class QueryStatusManager:
+    STATUS_TTL_SECONDS = 600  # 10 minutes
+    KEY_PREFIX_ASYNC_RESULTS = "query_async"
+
+    def __init__(self, query_id: str, team_id: int):
+        self.redis_client = redis.get_client()
+        self.query_id = query_id
+        self.team_id = team_id
+
+    @property
+    def results_key(self) -> str:
+        return f"{self.KEY_PREFIX_ASYNC_RESULTS}:{self.team_id}:{self.query_id}"
+
+    def store_query_status(self, query_status: QueryStatus):
+        self.redis_client.set(self.results_key, query_status.model_dump_json(), ex=self.STATUS_TTL_SECONDS)
+
+    def _get_results(self):
+        try:
+            byte_results = self.redis_client.get(self.results_key)
+        except Exception as e:
+            raise QueryRetrievalError(f"Error retrieving query {self.query_id} for team {self.team_id}") from e
+
+        return byte_results
+
+    def has_results(self):
+        return self._get_results() is not None
+
+    def get_query_status(self) -> QueryStatus:
+        byte_results = self._get_results()
+
+        if not byte_results:
+            raise QueryNotFoundError(f"Query {self.query_id} not found for team {self.team_id}")
+
+        return QueryStatus(**json.loads(byte_results))
+
+    def delete_query_status(self):
+        logger.info("Deleting redis query key %s", self.results_key)
+        self.redis_client.delete(self.results_key)
 
 
 def execute_process_query(
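
QueryStatusManager consolidates the Redis key construction, TTL handling, and QueryStatus (de)serialization that were previously spread across module-level helpers. A hedged usage sketch, assuming this module's namespace (QueryStatusManager, QueryStatus) and a reachable Redis; the values are illustrative:

    import datetime

    manager = QueryStatusManager(query_id="abc123", team_id=1)

    status = QueryStatus(id="abc123", team_id=1, start_time=datetime.datetime.utcnow())
    manager.store_query_status(status)  # SET query_async:1:abc123 <json> EX 600

    if manager.has_results():
        print(manager.get_query_status())  # parsed back into a QueryStatus

    manager.delete_query_status()  # DEL query_async:1:abc123
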
@@ -34,25 +71,21 @@ def execute_process_query(
     query_json,
     in_export_context,
     refresh_requested,
     task_id=None,
 ):
-    key = generate_redis_results_key(query_id, team_id)
-    redis_client = redis.get_client()
+    manager = QueryStatusManager(query_id, team_id)
 
     from posthog.models import Team
     from posthog.api.services.query import process_query
 
     team = Team.objects.get(pk=team_id)
 
-    query_status = QueryStatus(
-        id=query_id,
-        team_id=team_id,
-        task_id=task_id,
-        complete=False,
-        error=True,  # Assume error in case nothing below ends up working
-        start_time=datetime.datetime.utcnow(),
-    )
-    value = query_status.model_dump_json()
+    query_status = manager.get_query_status()
+    query_status.error = True  # Assume error in case nothing below ends up working
+    pickup_time = datetime.datetime.utcnow()
+    if query_status.start_time:
+        pickup_duration = (pickup_time - query_status.start_time).total_seconds()
+        QUERY_WAIT_TIME.observe(pickup_duration)
 
     try:
         tag_queries(client_query_id=query_id, team_id=team_id)
@@ -63,17 +96,17 @@ def execute_process_query(
         query_status.complete = True
         query_status.error = False
         query_status.results = results
-        query_status.expiration_time = datetime.datetime.utcnow() + datetime.timedelta(seconds=REDIS_STATUS_TTL_SECONDS)
         query_status.end_time = datetime.datetime.utcnow()
-        value = query_status.model_dump_json()
+        query_status.expiration_time = query_status.end_time + datetime.timedelta(seconds=manager.STATUS_TTL_SECONDS)
+        process_duration = (query_status.end_time - pickup_time).total_seconds()
+        QUERY_PROCESS_TIME.observe(process_duration)
     except Exception as err:
         query_status.results = None  # Clear results in case they are faulty
         query_status.error_message = str(err)
         logger.error("Error processing query for team %s query %s: %s", team_id, query_id, err)
-        value = query_status.model_dump_json()
         raise err
     finally:
-        redis_client.set(key, value, ex=REDIS_STATUS_TTL_SECONDS)
+        manager.store_query_status(query_status)
 
 
 def enqueue_process_query_task(
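
Net effect of the timing changes: the status written at enqueue time carries start_time, the worker observes pickup minus start as queue wait, and end minus pickup as processing time. A small worked example of the same arithmetic, with made-up timestamps:

    import datetime

    start_time = datetime.datetime(2023, 11, 27, 12, 0, 0)   # stored at enqueue
    pickup_time = datetime.datetime(2023, 11, 27, 12, 0, 2)  # worker picks the task up
    end_time = datetime.datetime(2023, 11, 27, 12, 0, 7)     # results are ready

    wait = (pickup_time - start_time).total_seconds()     # 2.0 -> QUERY_WAIT_TIME
    process = (end_time - pickup_time).total_seconds()    # 5.0 -> QUERY_PROCESS_TIME
    assert (wait, process) == (2.0, 5.0)
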
@@ -83,34 +116,22 @@ def enqueue_process_query_task(
     refresh_requested=False,
     bypass_celery=False,
     force=False,
-):
+) -> QueryStatus:
     if not query_id:
         query_id = uuid.uuid4().hex
 
-    key = generate_redis_results_key(query_id, team_id)
-    redis_client = redis.get_client()
+    manager = QueryStatusManager(query_id, team_id)
 
     if force:
-        # If we want to force rerun of this query we need to
-        # 1) Get the current status from redis
-        task_str = redis_client.get(key)
-        if task_str:
-            # if the status exists in redis we need to tell celery to kill the job
-            task_str = task_str.decode("utf-8")
-            query_task = QueryStatus(**json.loads(task_str))
-            # Instruct celery to revoke task and terminate if running
-            celery.app.control.revoke(query_task.task_id, terminate=True)
-            # Then we need to make redis forget about this job entirely
-            # and continue as normal. As if we never saw this query before
-            redis_client.delete(key)
+        cancel_query(team_id, query_id)
 
-    if redis_client.get(key):
-        # If we've seen this query before return the query_id and don't resubmit it.
-        return query_id
+    if manager.has_results() and not refresh_requested:
+        # If we've seen this query before return and don't resubmit it.
+        return manager.get_query_status()
 
     # Immediately set status, so we don't have race with celery
-    query_status = QueryStatus(id=query_id, team_id=team_id)
-    redis_client.set(key, query_status.model_dump_json(), ex=REDIS_STATUS_TTL_SECONDS)
+    query_status = QueryStatus(id=query_id, team_id=team_id, start_time=datetime.datetime.utcnow())
+    manager.store_query_status(query_status)
 
     if bypass_celery:
         # Call directly ( for testing )
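
Because the function now returns a QueryStatus instead of a bare id, callers can branch on status fields directly (the tests below chain .id where they still need the identifier). A hedged caller sketch; the query payload shape is illustrative only:

    query_status = enqueue_process_query_task(
        team_id=1,
        query_json={"kind": "EventsQuery", "select": ["*"]},  # illustrative payload
        bypass_celery=True,  # run inline, as the tests do
    )
    if query_status.complete and not query_status.error:
        print(query_status.results)
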
@@ -120,29 +141,24 @@ def enqueue_process_query_task(
             team_id, query_id, query_json, in_export_context=True, refresh_requested=refresh_requested
         )
         query_status.task_id = task.id
-        redis_client.set(key, query_status.model_dump_json(), ex=REDIS_STATUS_TTL_SECONDS)
+        manager.store_query_status(query_status)
 
-    return query_id
+    return query_status
 
 
 def get_query_status(team_id, query_id):
-    redis_client = redis.get_client()
-    key = generate_redis_results_key(query_id, team_id)
-
-    try:
-        byte_results = redis_client.get(key)
-    except Exception as e:
-        raise QueryRetrievalError(f"Error retrieving query {query_id} for team {team_id}") from e
-
-    if not byte_results:
-        raise QueryNotFoundError(f"Query {query_id} not found for team {team_id}")
-
-    return QueryStatus(**json.loads(byte_results))
+    """
+    Abstracts away the manager for any caller and returns a QueryStatus object
+    """
+    manager = QueryStatusManager(query_id, team_id)
+    return manager.get_query_status()
 
 
 def cancel_query(team_id, query_id):
+    manager = QueryStatusManager(query_id, team_id)
+
     try:
-        query_status = get_query_status(team_id, query_id)
+        query_status = manager.get_query_status()
 
         if query_status.task_id:
             logger.info("Got task id %s, attempting to revoke", query_status.task_id)
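
get_query_status is now a thin facade over the manager, and the error cases are unchanged: Redis failures surface as QueryRetrievalError, while a missing or expired key (the status lives for 10 minutes) raises QueryNotFoundError. A hedged caller sketch using this module's names:

    try:
        status = get_query_status(team_id=1, query_id="abc123")
    except QueryNotFoundError:
        status = None  # expired, never enqueued, or stored under another team's key
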
@@ -157,9 +173,6 @@ def cancel_query(team_id, query_id):
 
     cancel_query_on_cluster(team_id, query_id)
 
-    redis_client = redis.get_client()
-    key = generate_redis_results_key(query_id, team_id)
-    logger.info("Deleting redis query key %s", key)
-    redis_client.delete(key)
+    manager.delete_query_status()
 
     return True
@@ -27,7 +27,7 @@ class ClickhouseClientTestCase(TestCase, ClickhouseTestMixin):
     def test_async_query_client(self):
         query = build_query("SELECT 1+1")
         team_id = self.team_id
-        query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True)
+        query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True).id
         result = client.get_query_status(team_id, query_id)
         self.assertFalse(result.error, result.error_message)
         self.assertTrue(result.complete)
@@ -53,7 +53,7 @@ class ClickhouseClientTestCase(TestCase, ClickhouseTestMixin):
     def test_async_query_client_uuid(self):
         query = build_query("SELECT toUUID('00000000-0000-0000-0000-000000000000')")
         team_id = self.team_id
-        query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True)
+        query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True).id
         result = client.get_query_status(team_id, query_id)
         self.assertFalse(result.error, result.error_message)
         self.assertTrue(result.complete)
@@ -63,7 +63,7 @@ class ClickhouseClientTestCase(TestCase, ClickhouseTestMixin):
         query = build_query("SELECT 1+1")
         team_id = self.team_id
         wrong_team = 5
-        query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True)
+        query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True).id
 
         try:
             client.get_query_status(wrong_team, query_id)
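
These tests pass bypass_celery=True, so the query runs inline and the status is already final when get_query_status is called. A real async caller would poll instead; a hedged sketch of such a helper (the helper name and timeout are made up):

    import time

    def wait_for_query_status(team_id, query_id, timeout=30.0, interval=0.5):
        # Poll until the worker marks the status complete or errored.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            status = client.get_query_status(team_id, query_id)
            if status.complete or status.error:
                return status
            time.sleep(interval)
        raise TimeoutError(f"query {query_id} did not finish within {timeout}s")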