0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-28 18:26:15 +01:00
posthog/ee/hogai/summarizer/nodes.py
Michael Matloka 23bd1a010f
feat(max): Summarize insight results (#26172)
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
2024-11-18 11:05:54 +01:00

96 lines
4.7 KiB
Python

import json
from time import sleep
from django.conf import settings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from django.core.serializers.json import DjangoJSONEncoder
from rest_framework.exceptions import APIException
from sentry_sdk import capture_exception
from ee.hogai.summarizer.prompts import SUMMARIZER_SYSTEM_PROMPT, SUMMARIZER_INSTRUCTION_PROMPT
from ee.hogai.utils import AssistantNode, AssistantNodeName, AssistantState
from posthog.api.services.query import process_query_dict
from posthog.clickhouse.client.execute_async import get_query_status
from posthog.errors import ExposedCHQueryError
from posthog.hogql.errors import ExposedHogQLError
from posthog.hogql_queries.query_runner import ExecutionMode
from posthog.schema import AssistantMessage, FailureMessage, HumanMessage, VisualizationMessage
class SummarizerNode(AssistantNode):
    """Assistant node that runs the query behind the latest visualization message
    and asks an LLM to summarize the results in natural language."""

    name = AssistantNodeName.SUMMARIZER

    def run(self, state: AssistantState, config: RunnableConfig):
        """Execute the insight query from the last message and return a summary.

        Expects the last message in ``state["messages"]`` to be a
        ``VisualizationMessage`` carrying a query in ``answer``.

        Returns:
            A dict with a single ``AssistantMessage`` containing the summary,
            or a ``FailureMessage`` if the query could not be run.

        Raises:
            ValueError: if the last message is not a visualization message,
                or it carries no query.
        """
        viz_message = state["messages"][-1]
        if not isinstance(viz_message, VisualizationMessage):
            raise ValueError("Can only run summarization with a visualization message as the last one in the state")
        if viz_message.answer is None:
            raise ValueError("Did not find a query in the visualization message")

        try:
            results_response = process_query_dict(  # type: ignore
                self._team,  # TODO: Add user
                viz_message.answer.model_dump(mode="json"),  # mode="json" so the dict is JSON-serializable
                # Celery doesn't run in tests, so there we use force_blocking instead
                # This does mean that the waiting logic is not tested
                execution_mode=ExecutionMode.RECENT_CACHE_CALCULATE_ASYNC_IF_STALE
                if not settings.TEST
                else ExecutionMode.CALCULATE_BLOCKING_ALWAYS,
            ).model_dump(mode="json")
            if results_response.get("query_status") and not results_response["query_status"]["complete"]:
                query_id = results_response["query_status"]["id"]
                for i in range(999):
                    sleep(i / 2)  # We start at 0.5s and every iteration we wait 0.5s more
                    query_status = get_query_status(team_id=self._team.pk, query_id=query_id)
                    if query_status.error:
                        if query_status.error_message:
                            raise APIException(query_status.error_message)
                        else:
                            raise ValueError("Query failed")
                    if query_status.complete:
                        results_response = query_status.results
                        break
                else:
                    # Polling exhausted without the query completing. Fail explicitly here
                    # rather than letting the stale incomplete response (which has no
                    # "results" key) raise a KeyError outside this try block below.
                    raise APIException("Query hasn't completed in time")
        except (APIException, ExposedHogQLError, ExposedCHQueryError) as err:
            err_message = str(err)
            if isinstance(err, APIException):
                # DRF APIException.detail may be a dict or a list — flatten to a string.
                if isinstance(err.detail, dict):
                    err_message = ", ".join(f"{key}: {value}" for key, value in err.detail.items())
                elif isinstance(err.detail, list):
                    err_message = ", ".join(map(str, err.detail))
            return {"messages": [FailureMessage(content=f"There was an error running this query: {err_message}")]}
        except Exception as err:
            capture_exception(err)
            return {"messages": [FailureMessage(content="There was an unknown error running this query.")]}

        summarization_prompt = ChatPromptTemplate(self._construct_messages(state), template_format="mustache")

        chain = summarization_prompt | self._model

        message = chain.invoke(
            {
                "query_kind": viz_message.answer.kind,
                "product_description": self._team.project.product_description,
                "results": json.dumps(results_response["results"], cls=DjangoJSONEncoder),
            },
            config,
        )

        return {"messages": [AssistantMessage(content=str(message.content), done=True)]}

    @property
    def _model(self):
        # Slightly higher temperature than earlier pipeline steps, for a more natural summary.
        return ChatOpenAI(model="gpt-4o", temperature=0.5, streaming=True)

    def _construct_messages(self, state: AssistantState) -> list[tuple[str, str]]:
        """Build the (role, template) conversation for the summarization prompt:
        system prompt, then the human/assistant history, then the instruction prompt."""
        conversation: list[tuple[str, str]] = [("system", SUMMARIZER_SYSTEM_PROMPT)]
        for message in state.get("messages", []):
            if isinstance(message, HumanMessage):
                conversation.append(("human", message.content))
            elif isinstance(message, AssistantMessage):
                conversation.append(("assistant", message.content))
        conversation.append(("human", SUMMARIZER_INSTRUCTION_PROMPT))
        return conversation