0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-24 09:14:46 +01:00

feat(max): Summarize insight results (#26172)

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Michael Matloka 2024-11-18 11:05:54 +01:00 committed by GitHub
parent 04f656bc67
commit 23bd1a010f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 596 additions and 171 deletions

View File

@ -5,6 +5,7 @@ from langchain_core.messages import AIMessageChunk
from langfuse.callback import CallbackHandler
from langgraph.graph.state import StateGraph
from pydantic import BaseModel
from sentry_sdk import capture_exception
from ee import settings
from ee.hogai.funnels.nodes import (
@ -15,6 +16,7 @@ from ee.hogai.funnels.nodes import (
)
from ee.hogai.router.nodes import RouterNode
from ee.hogai.schema_generator.nodes import SchemaGeneratorNode
from ee.hogai.summarizer.nodes import SummarizerNode
from ee.hogai.trends.nodes import (
TrendsGeneratorNode,
TrendsGeneratorToolsNode,
@ -26,6 +28,8 @@ from posthog.models.team.team import Team
from posthog.schema import (
AssistantGenerationStatusEvent,
AssistantGenerationStatusType,
AssistantMessage,
FailureMessage,
VisualizationMessage,
)
@ -123,7 +127,7 @@ class Assistant:
generate_trends_node.router,
path_map={
"tools": AssistantNodeName.TRENDS_GENERATOR_TOOLS,
"next": AssistantNodeName.END,
"next": AssistantNodeName.SUMMARIZER,
},
)
@ -160,10 +164,14 @@ class Assistant:
generate_trends_node.router,
path_map={
"tools": AssistantNodeName.FUNNEL_GENERATOR_TOOLS,
"next": AssistantNodeName.END,
"next": AssistantNodeName.SUMMARIZER,
},
)
summarizer_node = SummarizerNode(self._team)
builder.add_node(AssistantNodeName.SUMMARIZER, summarizer_node.run)
builder.add_edge(AssistantNodeName.SUMMARIZER, AssistantNodeName.END)
return builder.compile()
def stream(self, conversation: Conversation) -> Generator[BaseModel, None, None]:
@ -185,33 +193,47 @@ class Assistant:
# Send a chunk to establish the connection avoiding the worker's timeout.
yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)
for update in generator:
if is_state_update(update):
_, new_state = update
state = new_state
try:
for update in generator:
if is_state_update(update):
_, new_state = update
state = new_state
elif is_value_update(update):
_, state_update = update
elif is_value_update(update):
_, state_update = update
if AssistantNodeName.ROUTER in state_update and "messages" in state_update[AssistantNodeName.ROUTER]:
yield state_update[AssistantNodeName.ROUTER]["messages"][0]
elif intersected_nodes := state_update.keys() & VISUALIZATION_NODES.keys():
# Reset chunks when schema validation fails.
chunks = AIMessageChunk(content="")
if (
AssistantNodeName.ROUTER in state_update
and "messages" in state_update[AssistantNodeName.ROUTER]
):
yield state_update[AssistantNodeName.ROUTER]["messages"][0]
elif intersected_nodes := state_update.keys() & VISUALIZATION_NODES.keys():
# Reset chunks when schema validation fails.
chunks = AIMessageChunk(content="")
node_name = intersected_nodes.pop()
if "messages" in state_update[node_name]:
yield state_update[node_name]["messages"][0]
elif state_update[node_name].get("intermediate_steps", []):
yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR)
elif is_message_update(update):
langchain_message, langgraph_state = update[1]
for node_name, viz_node in VISUALIZATION_NODES.items():
if langgraph_state["langgraph_node"] == node_name and isinstance(langchain_message, AIMessageChunk):
chunks += langchain_message # type: ignore
parsed_message = viz_node.parse_output(chunks.tool_calls[0]["args"])
if parsed_message:
yield VisualizationMessage(
reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer
node_name = intersected_nodes.pop()
if "messages" in state_update[node_name]:
yield state_update[node_name]["messages"][0]
elif state_update[node_name].get("intermediate_steps", []):
yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR)
elif AssistantNodeName.SUMMARIZER in state_update:
chunks = AIMessageChunk(content="")
yield state_update[AssistantNodeName.SUMMARIZER]["messages"][0]
elif is_message_update(update):
langchain_message, langgraph_state = update[1]
if isinstance(langchain_message, AIMessageChunk):
if langgraph_state["langgraph_node"] in VISUALIZATION_NODES.keys():
chunks += langchain_message # type: ignore
parsed_message = VISUALIZATION_NODES[langgraph_state["langgraph_node"]].parse_output(
chunks.tool_calls[0]["args"]
)
if parsed_message:
yield VisualizationMessage(
reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer
)
elif langgraph_state["langgraph_node"] == AssistantNodeName.SUMMARIZER:
chunks += langchain_message # type: ignore
yield AssistantMessage(content=chunks.content)
except Exception as e:
capture_exception(e)
yield FailureMessage() # This is an unhandled error, so we just stop further generation at this point

View File

@ -33,7 +33,9 @@ class TestFunnelsGeneratorNode(ClickhouseTestMixin, APIBaseTest):
self.assertEqual(
new_state,
{
"messages": [VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"])],
"messages": [
VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"], done=True)
],
"intermediate_steps": None,
},
)

View File

@ -101,6 +101,7 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
plan=generated_plan,
reasoning_steps=message.reasoning_steps,
answer=message.answer,
done=True,
)
],
"intermediate_steps": None,

View File

@ -54,7 +54,9 @@ class TestSchemaGeneratorNode(ClickhouseTestMixin, APIBaseTest):
self.assertEqual(
new_state,
{
"messages": [VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"])],
"messages": [
VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"], done=True)
],
"intermediate_steps": None,
},
)

View File

View File

@ -0,0 +1,95 @@
import json
from time import sleep
from django.conf import settings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from django.core.serializers.json import DjangoJSONEncoder
from rest_framework.exceptions import APIException
from sentry_sdk import capture_exception
from ee.hogai.summarizer.prompts import SUMMARIZER_SYSTEM_PROMPT, SUMMARIZER_INSTRUCTION_PROMPT
from ee.hogai.utils import AssistantNode, AssistantNodeName, AssistantState
from posthog.api.services.query import process_query_dict
from posthog.clickhouse.client.execute_async import get_query_status
from posthog.errors import ExposedCHQueryError
from posthog.hogql.errors import ExposedHogQLError
from posthog.hogql_queries.query_runner import ExecutionMode
from posthog.schema import AssistantMessage, FailureMessage, HumanMessage, VisualizationMessage
class SummarizerNode(AssistantNode):
name = AssistantNodeName.SUMMARIZER
def run(self, state: AssistantState, config: RunnableConfig):
viz_message = state["messages"][-1]
if not isinstance(viz_message, VisualizationMessage):
raise ValueError("Can only run summarization with a visualization message as the last one in the state")
if viz_message.answer is None:
raise ValueError("Did not found query in the visualization message")
try:
results_response = process_query_dict( # type: ignore
self._team, # TODO: Add user
viz_message.answer.model_dump(mode="json"), # We need mode="json" so that
# Celery doesn't run in tests, so there we use force_blocking instead
# This does mean that the waiting logic is not tested
execution_mode=ExecutionMode.RECENT_CACHE_CALCULATE_ASYNC_IF_STALE
if not settings.TEST
else ExecutionMode.CALCULATE_BLOCKING_ALWAYS,
).model_dump(mode="json")
if results_response.get("query_status") and not results_response["query_status"]["complete"]:
query_id = results_response["query_status"]["id"]
for i in range(0, 999):
sleep(i / 2) # We start at 0.5s and every iteration we wait 0.5s more
query_status = get_query_status(team_id=self._team.pk, query_id=query_id)
if query_status.error:
if query_status.error_message:
raise APIException(query_status.error_message)
else:
raise ValueError("Query failed")
if query_status.complete:
results_response = query_status.results
break
except (APIException, ExposedHogQLError, ExposedCHQueryError) as err:
err_message = str(err)
if isinstance(err, APIException):
if isinstance(err.detail, dict):
err_message = ", ".join(f"{key}: {value}" for key, value in err.detail.items())
elif isinstance(err.detail, list):
err_message = ", ".join(map(str, err.detail))
return {"messages": [FailureMessage(content=f"There was an error running this query: {err_message}")]}
except Exception as err:
capture_exception(err)
return {"messages": [FailureMessage(content="There was an unknown error running this query.")]}
summarization_prompt = ChatPromptTemplate(self._construct_messages(state), template_format="mustache")
chain = summarization_prompt | self._model
message = chain.invoke(
{
"query_kind": viz_message.answer.kind,
"product_description": self._team.project.product_description,
"results": json.dumps(results_response["results"], cls=DjangoJSONEncoder),
},
config,
)
return {"messages": [AssistantMessage(content=str(message.content), done=True)]}
@property
def _model(self):
return ChatOpenAI(model="gpt-4o", temperature=0.5, streaming=True) # Slightly higher temp than earlier steps
def _construct_messages(self, state: AssistantState) -> list[tuple[str, str]]:
conversation: list[tuple[str, str]] = [("system", SUMMARIZER_SYSTEM_PROMPT)]
for message in state.get("messages", []):
if isinstance(message, HumanMessage):
conversation.append(("human", message.content))
elif isinstance(message, AssistantMessage):
conversation.append(("assistant", message.content))
conversation.append(("human", SUMMARIZER_INSTRUCTION_PROMPT))
return conversation

View File

@ -0,0 +1,17 @@
SUMMARIZER_SYSTEM_PROMPT = """
Act as an expert product manager. Your task is to summarize query results in a a concise way.
Offer actionable feedback if possible. Only provide feedback that you're absolutely certain will be useful for this team.
The product being analyzed is described as follows:
{{product_description}}"""
SUMMARIZER_INSTRUCTION_PROMPT = """
Here are the {{query_kind}} results for this question:
```json
{{results}}
```
Answer my earlier question using the results above. Point out interesting trends or anomalies.
Take into account what you know about my product. If possible, offer actionable feedback, but avoid generic advice.
Limit yourself to a few sentences. The answer needs to be high-impact and relevant for me as a Silicon Valley engineer.
"""

View File

View File

@ -0,0 +1,196 @@
from unittest.mock import patch
from django.test import override_settings
from langchain_core.runnables import RunnableLambda
from langchain_core.messages import (
HumanMessage as LangchainHumanMessage,
)
from ee.hogai.summarizer.nodes import SummarizerNode
from ee.hogai.summarizer.prompts import SUMMARIZER_INSTRUCTION_PROMPT, SUMMARIZER_SYSTEM_PROMPT
from posthog.schema import (
AssistantMessage,
AssistantTrendsEventsNode,
AssistantTrendsQuery,
FailureMessage,
HumanMessage,
VisualizationMessage,
)
from rest_framework.exceptions import ValidationError
from posthog.test.base import APIBaseTest, ClickhouseTestMixin
from posthog.api.services.query import process_query_dict
@override_settings(IN_UNIT_TESTING=True)
class TestSummarizerNode(ClickhouseTestMixin, APIBaseTest):
maxDiff = None
@patch("ee.hogai.summarizer.nodes.process_query_dict", side_effect=process_query_dict)
def test_node_runs(self, mock_process_query_dict):
node = SummarizerNode(self.team)
with patch.object(SummarizerNode, "_model") as generator_model_mock:
generator_model_mock.return_value = RunnableLambda(
lambda _: LangchainHumanMessage(content="The results indicate foobar.")
)
new_state = node.run(
{
"messages": [
HumanMessage(content="Text"),
VisualizationMessage(
answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]),
plan="Plan",
reasoning_steps=["step"],
done=True,
),
],
"plan": "Plan",
},
{},
)
mock_process_query_dict.assert_called_once() # Query processing started
self.assertEqual(
new_state,
{
"messages": [
AssistantMessage(content="The results indicate foobar.", done=True),
],
},
)
@patch(
"ee.hogai.summarizer.nodes.process_query_dict",
side_effect=ValueError("You have not glibbled the glorp before running this."),
)
def test_node_handles_internal_error(self, mock_process_query_dict):
node = SummarizerNode(self.team)
with patch.object(SummarizerNode, "_model") as generator_model_mock:
generator_model_mock.return_value = RunnableLambda(
lambda _: LangchainHumanMessage(content="The results indicate foobar.")
)
new_state = node.run(
{
"messages": [
HumanMessage(content="Text"),
VisualizationMessage(
answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]),
plan="Plan",
reasoning_steps=["step"],
done=True,
),
],
"plan": "Plan",
},
{},
)
mock_process_query_dict.assert_called_once() # Query processing started
self.assertEqual(
new_state,
{
"messages": [
FailureMessage(content="There was an unknown error running this query."),
],
},
)
@patch(
"ee.hogai.summarizer.nodes.process_query_dict",
side_effect=ValidationError(
"This query exceeds the capabilities of our picolator. Try de-brolling its flim-flam."
),
)
def test_node_handles_exposed_error(self, mock_process_query_dict):
node = SummarizerNode(self.team)
with patch.object(SummarizerNode, "_model") as generator_model_mock:
generator_model_mock.return_value = RunnableLambda(
lambda _: LangchainHumanMessage(content="The results indicate foobar.")
)
new_state = node.run(
{
"messages": [
HumanMessage(content="Text"),
VisualizationMessage(
answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]),
plan="Plan",
reasoning_steps=["step"],
done=True,
),
],
"plan": "Plan",
},
{},
)
mock_process_query_dict.assert_called_once() # Query processing started
self.assertEqual(
new_state,
{
"messages": [
FailureMessage(
content=(
"There was an error running this query: This query exceeds the capabilities of our picolator. "
"Try de-brolling its flim-flam."
)
),
],
},
)
def test_node_requires_a_viz_message_in_state(self):
node = SummarizerNode(self.team)
with self.assertRaisesMessage(
ValueError, "Can only run summarization with a visualization message as the last one in the state"
):
node.run(
{
"messages": [
HumanMessage(content="Text"),
],
"plan": "Plan",
},
{},
)
def test_node_requires_viz_message_in_state_to_have_query(self):
node = SummarizerNode(self.team)
with self.assertRaisesMessage(ValueError, "Did not found query in the visualization message"):
node.run(
{
"messages": [
VisualizationMessage(
answer=None,
plan="Plan",
reasoning_steps=["step"],
done=True,
),
],
"plan": "Plan",
},
{},
)
def test_agent_reconstructs_conversation(self):
self.project.product_description = "Dating app for lonely hedgehogs."
self.project.save()
node = SummarizerNode(self.team)
history = node._construct_messages(
{
"messages": [
HumanMessage(content="What's the trends in signups?"),
VisualizationMessage(
answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]),
plan="Plan",
reasoning_steps=["step"],
done=True,
),
]
}
)
self.assertEqual(
history,
[
("system", SUMMARIZER_SYSTEM_PROMPT),
("human", "What's the trends in signups?"),
("human", SUMMARIZER_INSTRUCTION_PROMPT),
],
)

View File

@ -14,6 +14,8 @@ from posthog.test.base import APIBaseTest, ClickhouseTestMixin
@override_settings(IN_UNIT_TESTING=True)
class TestTrendsGeneratorNode(ClickhouseTestMixin, APIBaseTest):
maxDiff = None
def setUp(self):
self.schema = AssistantTrendsQuery(series=[])
@ -33,7 +35,9 @@ class TestTrendsGeneratorNode(ClickhouseTestMixin, APIBaseTest):
self.assertEqual(
new_state,
{
"messages": [VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"])],
"messages": [
VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"], done=True)
],
"intermediate_steps": None,
},
)

View File

@ -50,6 +50,7 @@ class AssistantNodeName(StrEnum):
FUNNEL_PLANNER_TOOLS = "funnel_planner_tools"
FUNNEL_GENERATOR = "funnel_generator"
FUNNEL_GENERATOR_TOOLS = "funnel_generator_tools"
SUMMARIZER = "summarizer"
class AssistantNode(ABC):

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 43 KiB

After

Width:  |  Height:  |  Size: 43 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 42 KiB

After

Width:  |  Height:  |  Size: 42 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 67 KiB

After

Width:  |  Height:  |  Size: 71 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 64 KiB

After

Width:  |  Height:  |  Size: 69 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

After

Width:  |  Height:  |  Size: 16 KiB

View File

@ -52,9 +52,18 @@ export const OBJECTS = {
'IconGearFilled',
'IconStack',
'IconSparkles',
'IconPlug',
'IconPuzzle',
],
People: ['IconPeople', 'IconPeopleFilled', 'IconPerson', 'IconProfile', 'IconUser', 'IconGroups'],
People: [
'IconPeople',
'IconPeopleFilled',
'IconPerson',
'IconProfile',
'IconUser',
'IconGroups',
'IconShieldPeople',
],
'Business & Finance': ['IconStore', 'IconCart', 'IconReceipt', 'IconPiggyBank', 'IconHandMoney'],
Time: ['IconHourglass', 'IconCalendar', 'IconClock'],
Nature: ['IconDay', 'IconNight', 'IconGlobe', 'IconCloud', 'IconBug'],
@ -183,6 +192,7 @@ export const TEAMS_AND_COMPANIES = {
'IconPageChart',
'IconSampling',
'IconLive',
'IconRefresh',
'IconBadge',
],
Replay: [

View File

@ -1089,6 +1089,10 @@
"content": {
"type": "string"
},
"done": {
"description": "We only need this \"done\" value to tell when the particular message is finished during its streaming. It won't be necessary when we optimize streaming to NOT send the entire message every time a character is added.",
"type": "boolean"
},
"type": {
"const": "ai",
"type": "string"
@ -6296,12 +6300,16 @@
"content": {
"type": "string"
},
"done": {
"const": true,
"type": "boolean"
},
"type": {
"const": "ai/failure",
"type": "string"
}
},
"required": ["type"],
"required": ["type", "done"],
"type": "object"
},
"FeaturePropertyFilter": {
@ -7618,12 +7626,17 @@
"content": {
"type": "string"
},
"done": {
"const": true,
"description": "Human messages are only appended when done.",
"type": "boolean"
},
"type": {
"const": "human",
"type": "string"
}
},
"required": ["type", "content"],
"required": ["type", "content", "done"],
"type": "object"
},
"InsightActorsQuery": {
@ -11687,12 +11700,17 @@
"content": {
"type": "string"
},
"done": {
"const": true,
"description": "Router messages are not streamed, so they can only be done.",
"type": "boolean"
},
"type": {
"const": "ai/router",
"type": "string"
}
},
"required": ["type", "content"],
"required": ["type", "content", "done"],
"type": "object"
},
"SamplingRate": {
@ -12837,6 +12855,9 @@
}
]
},
"done": {
"type": "boolean"
},
"plan": {
"type": "string"
},

View File

@ -2478,11 +2478,18 @@ export enum AssistantMessageType {
export interface HumanMessage {
type: AssistantMessageType.Human
content: string
/** Human messages are only appended when done. */
done: true
}
export interface AssistantMessage {
type: AssistantMessageType.Assistant
content: string
/**
* We only need this "done" value to tell when the particular message is finished during its streaming.
* It won't be necessary when we optimize streaming to NOT send the entire message every time a character is added.
*/
done?: boolean
}
export interface VisualizationMessage {
@ -2490,16 +2497,20 @@ export interface VisualizationMessage {
plan?: string
reasoning_steps?: string[] | null
answer?: AssistantTrendsQuery | AssistantFunnelsQuery
done?: boolean
}
export interface FailureMessage {
type: AssistantMessageType.Failure
content?: string
done: true
}
export interface RouterMessage {
type: AssistantMessageType.Router
content: string
/** Router messages are not streamed, so they can only be done. */
done: true
}
export type RootAssistantMessage =

View File

@ -25,7 +25,7 @@ export function QuestionInput(): JSX.Element {
className={clsx(
!isFloating
? 'w-[min(44rem,100%)] relative'
: 'w-full max-w-200 sticky z-10 self-center p-1 mx-3 mb-3 bottom-3 border border-[var(--glass-border-3000)] rounded-[0.625rem] backdrop-blur bg-[var(--glass-bg-3000)]'
: 'w-full max-w-192 sticky z-10 self-center p-1 mx-4 mb-3 bottom-3 border border-[var(--glass-border-3000)] rounded-[0.625rem] backdrop-blur bg-[var(--glass-bg-3000)]'
)}
>
<LemonTextArea

View File

@ -54,6 +54,8 @@ export function QuestionSuggestions(): JSX.Element {
size="xsmall"
type="secondary"
sideIcon={<IconArrowUpRight />}
center
className="shrink"
>
{suggestion}
</LemonButton>

View File

@ -1,4 +1,5 @@
import {
IconRefresh,
IconThumbsDown,
IconThumbsDownFilled,
IconThumbsUp,
@ -11,91 +12,71 @@ import clsx from 'clsx'
import { useActions, useValues } from 'kea'
import { BreakdownSummary, PropertiesSummary, SeriesSummary } from 'lib/components/Cards/InsightCard/InsightDetails'
import { TopHeading } from 'lib/components/Cards/InsightCard/TopHeading'
import { IconRefresh } from 'lib/lemon-ui/icons'
import { IconOpenInNew } from 'lib/lemon-ui/icons'
import { LemonMarkdown } from 'lib/lemon-ui/LemonMarkdown'
import posthog from 'posthog-js'
import React, { useMemo, useRef, useState } from 'react'
import { urls } from 'scenes/urls'
import { Query } from '~/queries/Query/Query'
import { AssistantMessageType, HumanMessage, InsightVizNode, NodeKind, VisualizationMessage } from '~/queries/schema'
import {
AssistantMessage,
AssistantMessageType,
FailureMessage,
HumanMessage,
InsightVizNode,
NodeKind,
VisualizationMessage,
} from '~/queries/schema'
import { maxLogic, MessageStatus, ThreadMessage } from './maxLogic'
import { castAssistantQuery, isFailureMessage, isHumanMessage, isVisualizationMessage } from './utils'
import {
castAssistantQuery,
isAssistantMessage,
isFailureMessage,
isHumanMessage,
isVisualizationMessage,
} from './utils'
export function Thread(): JSX.Element | null {
const { thread, threadLoading } = useValues(maxLogic)
const { retryLastMessage } = useActions(maxLogic)
return (
<div className="flex flex-col items-stretch w-full max-w-200 self-center gap-2 grow p-4">
{thread.map((message, index) => {
if (isHumanMessage(message)) {
return (
<Message
<MessageTemplate
key={index}
type="human"
className={message.status === 'error' ? 'border-danger' : undefined}
>
{message.content || <i>No text</i>}
</Message>
<LemonMarkdown>{message.content || '*No text.*'}</LemonMarkdown>
</MessageTemplate>
)
} else if (isAssistantMessage(message) || isFailureMessage(message)) {
return <TextAnswer key={index} message={message} index={index} />
} else if (isVisualizationMessage(message)) {
return <VisualizationAnswer key={index} message={message} status={message.status} />
}
if (isVisualizationMessage(message)) {
return (
<Answer
key={index}
message={message}
status={message.status}
previousMessage={thread[index - 1]}
/>
)
}
if (isFailureMessage(message)) {
return (
<Message
key={index}
type="ai"
className="border-danger"
action={
index === thread.length - 1 && (
<LemonButton
icon={<IconRefresh />}
size="small"
className="mt-2"
type="secondary"
onClick={() => retryLastMessage()}
>
Try again
</LemonButton>
)
}
>
{message.content || <i>Max has failed to generate an answer. Please try again.</i>}
</Message>
)
}
return null
return null // We currently skip other types of messages
})}
{threadLoading && (
<Message type="ai" className="w-fit select-none">
<MessageTemplate type="ai" className="w-fit select-none">
<div className="flex items-center gap-2">
Let me think
<Spinner className="text-xl" />
</div>
</Message>
</MessageTemplate>
)}
</div>
)
}
const Message = React.forwardRef<
const MessageTemplate = React.forwardRef<
HTMLDivElement,
React.PropsWithChildren<{ type: 'human' | 'ai'; className?: string; action?: React.ReactNode }>
>(function Message({ type, children, className, action }, ref): JSX.Element {
{ type: 'human' | 'ai'; className?: string; action?: React.ReactNode; children: React.ReactNode }
>(function MessageTemplate({ type, children, className, action }, ref) {
if (type === AssistantMessageType.Human) {
return (
<div className={clsx('mt-1 mb-3 text-2xl font-medium', className)} ref={ref}>
@ -105,7 +86,7 @@ const Message = React.forwardRef<
}
return (
<div>
<div className="space-y-2">
<div className={clsx('border p-2 rounded bg-bg-light', className)} ref={ref}>
{children}
</div>
@ -114,14 +95,41 @@ const Message = React.forwardRef<
)
})
function Answer({
const TextAnswer = React.forwardRef<
HTMLDivElement,
{ message: (AssistantMessage | FailureMessage) & ThreadMessage; index: number }
>(function TextAnswer({ message, index }, ref) {
const { thread } = useValues(maxLogic)
return (
<MessageTemplate
type="ai"
className={message.status === 'error' || message.type === 'ai/failure' ? 'border-danger' : undefined}
ref={ref}
action={
message.type === 'ai/failure' && index === thread.length - 1 ? (
<RetriableAnswerActions />
) : message.type === 'ai' &&
message.status === 'completed' &&
(thread[index + 1] === undefined || thread[index + 1].type === 'human') ? (
// Show answer actions if the assistant's response is complete at this point
<SuccessfulAnswerActions messageIndex={index} />
) : null
}
>
<LemonMarkdown>
{message.content || '*Max has failed to generate an answer. Please try again.*'}
</LemonMarkdown>
</MessageTemplate>
)
})
function VisualizationAnswer({
message,
status,
previousMessage,
}: {
message: VisualizationMessage
status?: MessageStatus
previousMessage: ThreadMessage
}): JSX.Element {
const query = useMemo<InsightVizNode | null>(() => {
if (message.answer) {
@ -138,7 +146,7 @@ function Answer({
return (
<>
{message.reasoning_steps && (
<Message
<MessageTemplate
type="ai"
action={
status === 'error' && (
@ -154,11 +162,11 @@ function Answer({
<li key={index}>{step}</li>
))}
</ul>
</Message>
</MessageTemplate>
)}
{status === 'completed' && query && (
<>
<Message type="ai">
<MessageTemplate type="ai">
<div className="h-96 flex">
<Query query={query} readOnly embedded />
</div>
@ -178,36 +186,55 @@ function Answer({
<BreakdownSummary query={query.source} />
</div>
</div>
</Message>
{isHumanMessage(previousMessage) && (
<AnswerActions message={message} previousMessage={previousMessage} />
)}
</MessageTemplate>
</>
)}
</>
)
}
function AnswerActions({
message,
previousMessage,
}: {
message: VisualizationMessage
previousMessage: HumanMessage
}): JSX.Element {
function RetriableAnswerActions(): JSX.Element {
const { retryLastMessage } = useActions(maxLogic)
return (
<LemonButton
icon={<IconRefresh />}
type="secondary"
size="small"
tooltip="Try again"
onClick={() => retryLastMessage()}
>
Try again
</LemonButton>
)
}
function SuccessfulAnswerActions({ messageIndex }: { messageIndex: number }): JSX.Element {
const { thread } = useValues(maxLogic)
const { retryLastMessage } = useActions(maxLogic)
const [rating, setRating] = useState<'good' | 'bad' | null>(null)
const [feedback, setFeedback] = useState<string>('')
const [feedbackInputStatus, setFeedbackInputStatus] = useState<'hidden' | 'pending' | 'submitted'>('hidden')
const hasScrolledFeedbackInputIntoView = useRef<boolean>(false)
const [relevantHumanMessage, relevantVisualizationMessage] = useMemo(() => {
// We need to find the relevant visualization message (which might be a message earlier if the most recent one
// is a results summary message), and the human message that triggered it.
const relevantMessages = thread.slice(0, messageIndex + 1).reverse()
const visualizationMessage = relevantMessages.find(isVisualizationMessage) as VisualizationMessage
const humanMessage = relevantMessages.find(isHumanMessage) as HumanMessage
return [humanMessage, visualizationMessage]
}, [thread, messageIndex])
function submitRating(newRating: 'good' | 'bad'): void {
if (rating) {
return // Already rated
}
setRating(newRating)
posthog.capture('chat rating', {
question: previousMessage.content,
answer: JSON.stringify(message.answer),
question: relevantHumanMessage.content,
answer: JSON.stringify(relevantVisualizationMessage.answer),
answer_rating: rating,
})
if (newRating === 'bad') {
@ -220,8 +247,8 @@ function AnswerActions({
return // Input is empty
}
posthog.capture('chat feedback', {
question: previousMessage.content,
answer: JSON.stringify(message.answer),
question: relevantHumanMessage.content,
answer: JSON.stringify(relevantVisualizationMessage.answer),
feedback,
})
setFeedbackInputStatus('submitted')
@ -248,9 +275,18 @@ function AnswerActions({
onClick={() => submitRating('bad')}
/>
)}
{messageIndex === thread.length - 1 && (
<LemonButton
icon={<IconRefresh />}
type="tertiary"
size="small"
tooltip="Try again"
onClick={() => retryLastMessage()}
/>
)}
</div>
{feedbackInputStatus !== 'hidden' && (
<Message
<MessageTemplate
type="ai"
ref={(el) => {
if (el && !hasScrolledFeedbackInputIntoView.current) {
@ -292,7 +328,7 @@ function AnswerActions({
</LemonButton>
</div>
)}
</Message>
</MessageTemplate>
)}
</>
)

View File

@ -1,17 +1,23 @@
import { AssistantGenerationStatusEvent, AssistantGenerationStatusType } from '~/queries/schema'
import chatResponse from './chatResponse.json'
import failureResponse from './failureResponse.json'
import failureMessage from './failureMessage.json'
import summaryMessage from './summaryMessage.json'
import visualizationMessage from './visualizationMessage.json'
function generateChunk(events: string[]): string {
return events.map((event) => (event.startsWith('event:') ? `${event}\n` : `${event}\n\n`)).join('')
}
export const chatResponseChunk = generateChunk(['event: message', `data: ${JSON.stringify(chatResponse)}`])
export const chatResponseChunk = generateChunk([
'event: message',
`data: ${JSON.stringify(visualizationMessage)}`,
'event: message',
`data: ${JSON.stringify(summaryMessage)}`,
])
const generationFailure: AssistantGenerationStatusEvent = { type: AssistantGenerationStatusType.GenerationError }
const responseWithReasoningStepsOnly = {
...chatResponse,
...visualizationMessage,
answer: null,
}
@ -22,4 +28,4 @@ export const generationFailureChunk = generateChunk([
`data: ${JSON.stringify(generationFailure)}`,
])
export const failureChunk = generateChunk(['event: message', `data: ${JSON.stringify(failureResponse)}`])
export const failureChunk = generateChunk(['event: message', `data: ${JSON.stringify(failureMessage)}`])

View File

@ -0,0 +1,5 @@
{
"type": "ai",
"content": "Looks like no pageviews have occured. Get some damn users.",
"done": true
}

View File

@ -64,5 +64,6 @@
"smoothingIntervals": 1,
"yAxisScaleType": null
}
}
},
"done": true
}

View File

@ -4,7 +4,7 @@ import { createParser } from 'eventsource-parser'
import { actions, afterMount, connect, kea, key, listeners, path, props, reducers, selectors } from 'kea'
import { loaders } from 'kea-loaders'
import api from 'lib/api'
import { isHumanMessage, isRouterMessage, isVisualizationMessage } from 'scenes/max/utils'
import { isHumanMessage } from 'scenes/max/utils'
import { projectLogic } from 'scenes/projectLogic'
import {
@ -13,6 +13,7 @@ import {
AssistantGenerationStatusType,
AssistantMessageType,
FailureMessage,
HumanMessage,
NodeKind,
RefreshType,
RootAssistantMessage,
@ -28,12 +29,14 @@ export interface MaxLogicProps {
export type MessageStatus = 'loading' | 'completed' | 'error'
export type ThreadMessage = RootAssistantMessage & {
status?: MessageStatus
status: MessageStatus
}
const FAILURE_MESSAGE: FailureMessage = {
const FAILURE_MESSAGE: FailureMessage & ThreadMessage = {
type: AssistantMessageType.Failure,
content: 'Oops! It looks like Im having trouble generating this trends insight. Could you please try again?',
status: 'error',
done: true,
}
export const maxLogic = kea<maxLogicType>([
@ -48,7 +51,7 @@ export const maxLogic = kea<maxLogicType>([
setThreadLoaded: (testOnlyOverride = false) => ({ testOnlyOverride }),
addMessage: (message: ThreadMessage) => ({ message }),
replaceMessage: (index: number, message: ThreadMessage) => ({ index, message }),
setMessageStatus: (index: number, status: ThreadMessage['status']) => ({ index, status }),
setMessageStatus: (index: number, status: MessageStatus) => ({ index, status }),
setQuestion: (question: string) => ({ question }),
setVisibleSuggestions: (suggestions: string[]) => ({ suggestions }),
shuffleVisibleSuggestions: true,
@ -149,9 +152,7 @@ export const maxLogic = kea<maxLogicType>([
)
},
askMax: async ({ prompt }) => {
actions.addMessage({ type: AssistantMessageType.Human, content: prompt })
let generatingMessageIndex: number = -1
actions.addMessage({ type: AssistantMessageType.Human, content: prompt, done: true, status: 'completed' })
try {
const response = await api.chat({
session_id: props.sessionId,
@ -173,21 +174,15 @@ export const maxLogic = kea<maxLogicType>([
return
}
if (isRouterMessage(parsedResponse)) {
if (values.thread[values.thread.length - 1].status === 'completed') {
actions.addMessage({
...parsedResponse,
status: 'completed',
status: !parsedResponse.done ? 'loading' : 'completed',
})
} else if (generatingMessageIndex === -1) {
generatingMessageIndex = values.thread.length
if (parsedResponse) {
actions.addMessage({ ...parsedResponse, status: 'loading' })
}
} else if (parsedResponse) {
actions.replaceMessage(generatingMessageIndex, {
actions.replaceMessage(values.thread.length - 1, {
...parsedResponse,
status: values.thread[generatingMessageIndex].status,
status: !parsedResponse.done ? 'loading' : 'completed',
})
}
} else if (event === AssistantEventType.Status) {
@ -197,7 +192,7 @@ export const maxLogic = kea<maxLogicType>([
}
if (parsedResponse.type === AssistantGenerationStatusType.GenerationError) {
actions.setMessageStatus(generatingMessageIndex, 'error')
actions.setMessageStatus(values.thread.length - 1, 'error')
}
}
},
@ -205,47 +200,28 @@ export const maxLogic = kea<maxLogicType>([
while (true) {
const { done, value } = await reader.read()
parser.feed(decoder.decode(value))
if (done) {
if (generatingMessageIndex === -1) {
break
}
const generatedMessage = values.thread[generatingMessageIndex]
if (generatedMessage && isVisualizationMessage(generatedMessage) && generatedMessage.plan) {
actions.setMessageStatus(generatingMessageIndex, 'completed')
} else if (generatedMessage) {
actions.replaceMessage(generatingMessageIndex, FAILURE_MESSAGE)
} else {
actions.addMessage({
...FAILURE_MESSAGE,
status: 'completed',
})
}
break
}
}
} catch (e) {
captureException(e)
if (generatingMessageIndex !== -1) {
if (values.thread[generatingMessageIndex]) {
actions.replaceMessage(generatingMessageIndex, FAILURE_MESSAGE)
} else {
actions.addMessage({
...FAILURE_MESSAGE,
status: 'completed',
})
}
if (values.thread[values.thread.length - 1]?.status === 'loading') {
actions.replaceMessage(values.thread.length - 1, FAILURE_MESSAGE)
} else if (values.thread[values.thread.length - 1]?.status !== 'error') {
actions.addMessage({
...FAILURE_MESSAGE,
status: 'completed',
})
}
}
actions.setThreadLoaded()
},
retryLastMessage: () => {
const lastMessage = values.thread.filter(isHumanMessage).pop()
const lastMessage = values.thread.filter(isHumanMessage).pop() as HumanMessage | undefined
if (lastMessage) {
actions.askMax(lastMessage.content)
}

View File

@ -1,5 +1,6 @@
import {
AssistantFunnelsQuery,
AssistantMessage,
AssistantMessageType,
AssistantTrendsQuery,
FailureMessage,
@ -22,6 +23,10 @@ export function isHumanMessage(message: RootAssistantMessage | undefined | null)
return message?.type === AssistantMessageType.Human
}
export function isAssistantMessage(message: RootAssistantMessage | undefined | null): message is AssistantMessage {
return message?.type === AssistantMessageType.Assistant
}
export function isFailureMessage(message: RootAssistantMessage | undefined | null): message is FailureMessage {
return message?.type === AssistantMessageType.Failure
}

View File

@ -54,8 +54,8 @@
"typegen:check": "kea-typegen check",
"typegen:watch": "kea-typegen watch --delete --show-ts-errors",
"typegen:clean": "find frontend/src -type f -name '*Type.ts' -delete",
"storybook": "storybook dev -p 6006",
"build-storybook": "storybook build",
"storybook": "DEBUG=0 storybook dev -p 6006",
"build-storybook": "DEBUG=0 storybook build",
"dev:migrate:postgres": "export DEBUG=1 && source env/bin/activate && python manage.py migrate",
"dev:migrate:clickhouse": "export DEBUG=1 && source env/bin/activate && python manage.py migrate_clickhouse",
"prepare": "husky install",
@ -77,7 +77,7 @@
"@microlink/react-json-view": "^1.21.3",
"@monaco-editor/react": "4.6.0",
"@posthog/hogvm": "^1.0.58",
"@posthog/icons": "0.8.5",
"@posthog/icons": "0.9.1",
"@posthog/plugin-scaffold": "^1.4.4",
"@react-hook/size": "^2.1.2",
"@rrweb/types": "2.0.0-alpha.13",

View File

@ -53,8 +53,8 @@ dependencies:
specifier: ^1.0.58
version: 1.0.58(luxon@3.5.0)
'@posthog/icons':
specifier: 0.8.5
version: 0.8.5(react-dom@18.2.0)(react@18.2.0)
specifier: 0.9.1
version: 0.9.1(react-dom@18.2.0)(react@18.2.0)
'@posthog/plugin-scaffold':
specifier: ^1.4.4
version: 1.4.4
@ -5426,8 +5426,8 @@ packages:
luxon: 3.5.0
dev: false
/@posthog/icons@0.8.5(react-dom@18.2.0)(react@18.2.0):
resolution: {integrity: sha512-bFPMgnR3ZaNnMQ81OznYFQRd7KaCqXcI8xS3qS49UBkSZpKeJgH86JbWXBXI2q2GZWX00gc+gZxEo5EBkY7KcQ==}
/@posthog/icons@0.9.1(react-dom@18.2.0)(react@18.2.0):
resolution: {integrity: sha512-9zlU1H7MZm2gSh1JsDzM25km6VDc/Y7HdNf6RyP5sUiHCHVMKhQQ8TA2IMq55v/uTFRc5Yen6BagOUvunD2kqQ==}
peerDependencies:
react: '>=16.14.0'
react-dom: '>=16.14.0'

View File

@ -137,6 +137,7 @@ class MatrixManager:
)
for cohort in Cohort.objects.filter(team=team):
cohort.calculate_people_ch(pending_version=0)
team.project.save()
team.save()
def _save_analytics_data(self, data_team: Team):

View File

@ -98,6 +98,8 @@ class HedgeboxMatrix(Matrix):
def set_project_up(self, team, user):
super().set_project_up(team, user)
team.project.product_description = "Dropbox for hedgehogs. We're a file sharing and collaboration platform. Free for limited personal use, with paid plans available."
team.autocapture_web_vitals_opt_in = True
# Actions
interacted_with_file_action = Action.objects.create(
@ -882,6 +884,3 @@ class HedgeboxMatrix(Matrix):
)
except IntegrityError:
pass # This can happen if demo data generation is re-run for the same project
# autocapture
team.autocapture_web_vitals_opt_in = True

View File

@ -247,6 +247,14 @@ class AssistantMessage(BaseModel):
extra="forbid",
)
content: str
done: Optional[bool] = Field(
default=None,
description=(
'We only need this "done" value to tell when the particular message is finished during its streaming. It'
" won't be necessary when we optimize streaming to NOT send the entire message every time a character is"
" added."
),
)
type: Literal["ai"] = "ai"
@ -850,6 +858,7 @@ class FailureMessage(BaseModel):
extra="forbid",
)
content: Optional[str] = None
done: Literal[True] = True
type: Literal["ai/failure"] = "ai/failure"
@ -1047,6 +1056,7 @@ class HumanMessage(BaseModel):
extra="forbid",
)
content: str
done: Literal[True] = Field(default=True, description="Human messages are only appended when done.")
type: Literal["human"] = "human"
@ -1443,6 +1453,7 @@ class RouterMessage(BaseModel):
extra="forbid",
)
content: str
done: Literal[True] = Field(default=True, description="Router messages are not streamed, so they can only be done.")
type: Literal["ai/router"] = "ai/router"
@ -5536,6 +5547,7 @@ class VisualizationMessage(BaseModel):
extra="forbid",
)
answer: Optional[Union[AssistantTrendsQuery, AssistantFunnelsQuery]] = None
done: Optional[bool] = None
plan: Optional[str] = None
reasoning_steps: Optional[list[str]] = None
type: Literal["ai/viz"] = "ai/viz"