diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py
index df200ae869c..3e34a5e0acc 100644
--- a/ee/hogai/assistant.py
+++ b/ee/hogai/assistant.py
@@ -5,7 +5,6 @@
 from langchain_core.messages import AIMessageChunk
 from langfuse.callback import CallbackHandler
 from langgraph.graph.state import StateGraph
 from pydantic import BaseModel
-from sentry_sdk import capture_exception

 from ee import settings
 from ee.hogai.funnels.nodes import (
@@ -30,6 +29,7 @@ from posthog.schema import (
     AssistantGenerationStatusType,
     AssistantMessage,
     FailureMessage,
+    ReasoningMessage,
     VisualizationMessage,
 )
@@ -68,12 +68,30 @@ def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], AssistantState]]:
     return len(update) == 2 and update[0] == "values"


+def is_task_started_update(
+    update: list[Any],
+) -> TypeGuard[tuple[Literal["messages"], tuple[Union[AIMessageChunk, Any], LangGraphState]]]:
+    """
+    Streaming of messages. Returns a partial state.
+    """
+    return len(update) == 2 and update[0] == "debug" and update[1]["type"] == "task"
+
+
 VISUALIZATION_NODES: dict[AssistantNodeName, type[SchemaGeneratorNode]] = {
     AssistantNodeName.TRENDS_GENERATOR: TrendsGeneratorNode,
     AssistantNodeName.FUNNEL_GENERATOR: FunnelGeneratorNode,
 }


+NODE_TO_REASONING_MESSAGE: dict[AssistantNodeName, str] = {
+    AssistantNodeName.ROUTER: "Identifying type of analysis",
+    AssistantNodeName.TRENDS_PLANNER: "Picking relevant events and properties",
+    AssistantNodeName.FUNNEL_PLANNER: "Picking relevant events and properties",
+    AssistantNodeName.TRENDS_GENERATOR: "Creating trends query",
+    AssistantNodeName.FUNNEL_GENERATOR: "Creating funnel query",
+}
+
+
 class Assistant:
     _team: Team
     _graph: StateGraph
@@ -185,7 +203,7 @@
         generator = assistant_graph.stream(
             state,
             config={"recursion_limit": 24, "callbacks": callbacks},
-            stream_mode=["messages", "values", "updates"],
+            stream_mode=["messages", "values", "updates", "debug"],
         )

         chunks = AIMessageChunk(content="")
@@ -228,12 +246,15 @@
                             chunks.tool_calls[0]["args"]
                         )
                         if parsed_message:
-                            yield VisualizationMessage(
-                                reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer
-                            )
+                            yield VisualizationMessage(answer=parsed_message.query)
                     elif langgraph_state["langgraph_node"] == AssistantNodeName.SUMMARIZER:
                         chunks += langchain_message  # type: ignore
                         yield AssistantMessage(content=chunks.content)
-        except Exception as e:
-            capture_exception(e)
+                elif is_task_started_update(update):
+                    _, task_update = update
+                    node_name = task_update["payload"]["name"]
+                    if reasoning_message := NODE_TO_REASONING_MESSAGE.get(node_name):
+                        yield ReasoningMessage(content=reasoning_message)
+        except:
             yield FailureMessage()  # This is an unhandled error, so we just stop further generation at this point
+            raise  # Re-raise, so that this is printed or goes into Sentry
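Note on the hunk above: the new `"debug"` stream mode is what makes LangGraph emit task-start events, which `is_task_started_update` detects and `NODE_TO_REASONING_MESSAGE` turns into user-facing status lines. A minimal sketch of that mapping, with plain strings standing in for the `AssistantNodeName` members and the update payload shape assumed from the code above:

```python
# Sketch only: mirrors is_task_started_update() plus the new elif branch.
# Node names here are illustrative stand-ins for AssistantNodeName members.
from typing import Any, Optional

from posthog.schema import ReasoningMessage

NODE_TO_REASONING_MESSAGE: dict[str, str] = {
    "router": "Identifying type of analysis",
    "trends_generator": "Creating trends query",
}


def reasoning_from_update(update: list[Any]) -> Optional[ReasoningMessage]:
    # A two-element "debug" update with type "task" marks a node starting work
    if len(update) == 2 and update[0] == "debug" and update[1]["type"] == "task":
        node_name = update[1]["payload"]["name"]
        if content := NODE_TO_REASONING_MESSAGE.get(node_name):
            return ReasoningMessage(content=content)
    return None


# reasoning_from_update(["debug", {"type": "task", "payload": {"name": "router"}}])
# -> ReasoningMessage(content="Identifying type of analysis", done=True, type="ai/reasoning")
```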
diff --git a/ee/hogai/funnels/test/test_nodes.py b/ee/hogai/funnels/test/test_nodes.py
index 59ba48ff6fa..5c65b141105 100644
--- a/ee/hogai/funnels/test/test_nodes.py
+++ b/ee/hogai/funnels/test/test_nodes.py
@@ -21,7 +21,7 @@ class TestFunnelsGeneratorNode(ClickhouseTestMixin, APIBaseTest):
         node = FunnelGeneratorNode(self.team)
         with patch.object(FunnelGeneratorNode, "_model") as generator_model_mock:
             generator_model_mock.return_value = RunnableLambda(
-                lambda _: FunnelsSchemaGeneratorOutput(reasoning_steps=["step"], answer=self.schema).model_dump()
+                lambda _: FunnelsSchemaGeneratorOutput(query=self.schema).model_dump()
             )
             new_state = node.run(
                 {
@@ -33,9 +33,7 @@ class TestFunnelsGeneratorNode(ClickhouseTestMixin, APIBaseTest):
             self.assertEqual(
                 new_state,
                 {
-                    "messages": [
-                        VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"], done=True)
-                    ],
+                    "messages": [VisualizationMessage(answer=self.schema, plan="Plan", done=True)],
                     "intermediate_steps": None,
                 },
             )
diff --git a/ee/hogai/funnels/toolkit.py b/ee/hogai/funnels/toolkit.py
index 75550864cd8..8d6407027aa 100644
--- a/ee/hogai/funnels/toolkit.py
+++ b/ee/hogai/funnels/toolkit.py
@@ -62,15 +62,10 @@ def generate_funnel_schema() -> dict:
         "parameters": {
             "type": "object",
             "properties": {
-                "reasoning_steps": {
-                    "type": "array",
-                    "items": {"type": "string"},
-                    "description": "The reasoning steps leading to the final conclusion that will be shown to the user. Use 'you' if you want to refer to the user.",
-                },
-                "answer": dereference_schema(schema),
+                "query": dereference_schema(schema),
             },
             "additionalProperties": False,
-            "required": ["reasoning_steps", "answer"],
+            "required": ["query"],
         },
     }
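Note: with `reasoning_steps` dropped, the function-call schema handed to the model shrinks to a single required `query` property; `ee/hogai/trends/toolkit.py` below gets the identical treatment. Roughly, the generated `parameters` object now looks like this sketch, with the dereferenced insight schema elided:

```python
# Sketch of the post-change tool parameters; in the real generate_funnel_schema()
# the "query" property carries the full dereference_schema(schema) output.
FUNNEL_TOOL_PARAMETERS: dict = {
    "type": "object",
    "properties": {
        "query": {"type": "object"},  # placeholder for the dereferenced funnel schema
    },
    "additionalProperties": False,
    "required": ["query"],
}
```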
diff --git a/ee/hogai/schema_generator/nodes.py b/ee/hogai/schema_generator/nodes.py
index 6470c52c4fe..9ad475563c9 100644
--- a/ee/hogai/schema_generator/nodes.py
+++ b/ee/hogai/schema_generator/nodes.py
@@ -99,8 +99,7 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
             "messages": [
                 VisualizationMessage(
                     plan=generated_plan,
-                    reasoning_steps=message.reasoning_steps,
-                    answer=message.answer,
+                    answer=message.query,
                     done=True,
                 )
             ],
diff --git a/ee/hogai/schema_generator/test/test_nodes.py b/ee/hogai/schema_generator/test/test_nodes.py
index 25f82e43d44..af662349787 100644
--- a/ee/hogai/schema_generator/test/test_nodes.py
+++ b/ee/hogai/schema_generator/test/test_nodes.py
@@ -41,9 +41,7 @@ class TestSchemaGeneratorNode(ClickhouseTestMixin, APIBaseTest):
     def test_node_runs(self):
         node = DummyGeneratorNode(self.team)
         with patch.object(DummyGeneratorNode, "_model") as generator_model_mock:
-            generator_model_mock.return_value = RunnableLambda(
-                lambda _: TestSchema(reasoning_steps=["step"], answer=self.schema).model_dump()
-            )
+            generator_model_mock.return_value = RunnableLambda(lambda _: TestSchema(query=self.schema).model_dump())
             new_state = node.run(
                 {
                     "messages": [HumanMessage(content="Text")],
@@ -54,9 +52,7 @@ class TestSchemaGeneratorNode(ClickhouseTestMixin, APIBaseTest):
             self.assertEqual(
                 new_state,
                 {
-                    "messages": [
-                        VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"], done=True)
-                    ],
+                    "messages": [VisualizationMessage(answer=self.schema, plan="Plan", done=True)],
                     "intermediate_steps": None,
                 },
             )
@@ -176,9 +172,9 @@ class TestSchemaGeneratorNode(ClickhouseTestMixin, APIBaseTest):
     def test_failover_with_incorrect_schema(self):
         node = DummyGeneratorNode(self.team)
         with patch.object(DummyGeneratorNode, "_model") as generator_model_mock:
-            schema = TestSchema(reasoning_steps=[], answer=None).model_dump()
+            schema = TestSchema(query=None).model_dump()
             # Emulate an incorrect JSON. It should be an object.
-            schema["answer"] = []
+            schema["query"] = []
             generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema))

             new_state = node.run({"messages": [HumanMessage(content="Text")]}, {})
@@ -200,7 +196,7 @@ class TestSchemaGeneratorNode(ClickhouseTestMixin, APIBaseTest):
             with patch.object(
                 DummyGeneratorNode,
                 "_model",
-                return_value=RunnableLambda(lambda _: TestSchema(reasoning_steps=[], answer=self.schema).model_dump()),
+                return_value=RunnableLambda(lambda _: TestSchema(query=self.schema).model_dump()),
             ):
                 new_state = node.run(
                     {
@@ -226,9 +222,9 @@ class TestSchemaGeneratorNode(ClickhouseTestMixin, APIBaseTest):
     def test_node_leaves_failover_after_second_unsuccessful_attempt(self):
         node = DummyGeneratorNode(self.team)
         with patch.object(DummyGeneratorNode, "_model") as generator_model_mock:
-            schema = TestSchema(reasoning_steps=[], answer=None).model_dump()
+            schema = TestSchema(query=None).model_dump()
             # Emulate an incorrect JSON. It should be an object.
-            schema["answer"] = []
+            schema["query"] = []
             generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema))

             new_state = node.run(
diff --git a/ee/hogai/schema_generator/utils.py b/ee/hogai/schema_generator/utils.py
index 54b610f05f3..8d0f8db4de9 100644
--- a/ee/hogai/schema_generator/utils.py
+++ b/ee/hogai/schema_generator/utils.py
@@ -6,5 +6,4 @@ T = TypeVar("T", bound=BaseModel)


 class SchemaGeneratorOutput(BaseModel, Generic[T]):
-    reasoning_steps: Optional[list[str]] = None
-    answer: Optional[T] = None
+    query: Optional[T] = None
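Note: `SchemaGeneratorOutput` is now a one-field generic wrapper. Assuming the concrete aliases such as `TrendsSchemaGeneratorOutput` are plain parametrizations of the generic (their definitions are not part of this diff), the new shape behaves like this:

```python
# Self-contained sketch; AssistantTrendsQuery is a stand-in for the real model.
from typing import Generic, Optional, TypeVar

from pydantic import BaseModel

T = TypeVar("T", bound=BaseModel)


class SchemaGeneratorOutput(BaseModel, Generic[T]):
    query: Optional[T] = None


class AssistantTrendsQuery(BaseModel):
    series: list = []


# Assumed to match how the trends/funnels aliases parametrize the generic
TrendsSchemaGeneratorOutput = SchemaGeneratorOutput[AssistantTrendsQuery]

output = TrendsSchemaGeneratorOutput(query=AssistantTrendsQuery())
assert output.model_dump() == {"query": {"series": []}}
```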
diff --git a/ee/hogai/summarizer/test/test_nodes.py b/ee/hogai/summarizer/test/test_nodes.py
index b0e8cdcd37f..b38d88275aa 100644
--- a/ee/hogai/summarizer/test/test_nodes.py
+++ b/ee/hogai/summarizer/test/test_nodes.py
@@ -38,7 +38,6 @@ class TestSummarizerNode(ClickhouseTestMixin, APIBaseTest):
                     VisualizationMessage(
                         answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]),
                         plan="Plan",
-                        reasoning_steps=["step"],
                         done=True,
                     ),
                 ],
@@ -73,7 +72,6 @@
                     VisualizationMessage(
                         answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]),
                         plan="Plan",
-                        reasoning_steps=["step"],
                         done=True,
                     ),
                 ],
@@ -110,7 +108,6 @@
                     VisualizationMessage(
                         answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]),
                         plan="Plan",
-                        reasoning_steps=["step"],
                         done=True,
                     ),
                 ],
@@ -159,7 +156,6 @@
                     VisualizationMessage(
                         answer=None,
                         plan="Plan",
-                        reasoning_steps=["step"],
                         done=True,
                     ),
                 ],
@@ -180,7 +176,6 @@
                     VisualizationMessage(
                         answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]),
                         plan="Plan",
-                        reasoning_steps=["step"],
                         done=True,
                     ),
                 ]
diff --git a/ee/hogai/trends/test/test_nodes.py b/ee/hogai/trends/test/test_nodes.py
index 03c2ac85ea7..44973b31953 100644
--- a/ee/hogai/trends/test/test_nodes.py
+++ b/ee/hogai/trends/test/test_nodes.py
@@ -23,7 +23,7 @@ class TestTrendsGeneratorNode(ClickhouseTestMixin, APIBaseTest):
         node = TrendsGeneratorNode(self.team)
         with patch.object(TrendsGeneratorNode, "_model") as generator_model_mock:
             generator_model_mock.return_value = RunnableLambda(
-                lambda _: TrendsSchemaGeneratorOutput(reasoning_steps=["step"], answer=self.schema).model_dump()
+                lambda _: TrendsSchemaGeneratorOutput(query=self.schema).model_dump()
             )
             new_state = node.run(
                 {
@@ -35,9 +35,7 @@ class TestTrendsGeneratorNode(ClickhouseTestMixin, APIBaseTest):
             self.assertEqual(
                 new_state,
                 {
-                    "messages": [
-                        VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"], done=True)
-                    ],
+                    "messages": [VisualizationMessage(answer=self.schema, plan="Plan", done=True)],
                     "intermediate_steps": None,
                 },
             )
diff --git a/ee/hogai/trends/toolkit.py b/ee/hogai/trends/toolkit.py
index f4cc96e706f..672f065a1f2 100644
--- a/ee/hogai/trends/toolkit.py
+++ b/ee/hogai/trends/toolkit.py
@@ -65,15 +65,10 @@ def generate_trends_schema() -> dict:
         "parameters": {
             "type": "object",
             "properties": {
-                "reasoning_steps": {
-                    "type": "array",
-                    "items": {"type": "string"},
-                    "description": "The reasoning steps leading to the final conclusion that will be shown to the user. Use 'you' if you want to refer to the user.",
-                },
-                "answer": dereference_schema(schema),
+                "query": dereference_schema(schema),
             },
             "additionalProperties": False,
-            "required": ["reasoning_steps", "answer"],
+            "required": ["query"],
         },
     }
diff --git a/frontend/src/lib/components/Cards/InsightCard/InsightDetails.tsx b/frontend/src/lib/components/Cards/InsightCard/InsightDetails.tsx
index 347fd3227ce..a9b08bf114a 100644
--- a/frontend/src/lib/components/Cards/InsightCard/InsightDetails.tsx
+++ b/frontend/src/lib/components/Cards/InsightCard/InsightDetails.tsx
@@ -27,12 +27,13 @@ import {
     FunnelsQuery,
     InsightQueryNode,
     LifecycleQuery,
-    NodeKind,
     PathsQuery,
     StickinessQuery,
     TrendsQuery,
 } from '~/queries/schema'
 import {
+    isActionsNode,
+    isEventsNode,
     isFunnelsQuery,
     isInsightQueryWithBreakdown,
     isInsightQueryWithSeries,
@@ -149,15 +150,15 @@ function SeriesDisplay({
     seriesIndex: number
 }): JSX.Element {
     const { mathDefinitions } = useValues(mathsLogic)

-    const filter = query.series[seriesIndex]
+    const series = query.series[seriesIndex]
     const hasBreakdown = isInsightQueryWithBreakdown(query) && isValidBreakdown(query.breakdownFilter)

     const mathDefinition = mathDefinitions[
         isLifecycleQuery(query)
             ? 'dau'
-            : filter.math
-            ? apiValueToMathType(filter.math, filter.math_group_type_index)
+            : series.math
+            ? apiValueToMathType(series.math, series.math_group_type_index)
             : 'total'
     ] as MathDefinition | undefined
@@ -167,12 +168,12 @@
         className="SeriesDisplay"
         icon={}
         extendedContent={
-            filter.properties &&
-            filter.properties.length > 0 && (
+            series.properties &&
+            series.properties.length > 0 && (
@@ -181,34 +182,36 @@
         >
             {isFunnelsQuery(query) ? 'Performed' : 'Showing'}
-            {filter.custom_name && "{filter.custom_name}"}
-            {filter.kind === NodeKind.ActionsNode && filter.id ? (
+            {series.custom_name && "{series.custom_name}"}
+            {isActionsNode(series) ? (
-                    {filter.name}
+                    {series.name}
-            ) : (
+            ) : isEventsNode(series) ? (
-
+
+            ) : (
+                {series.kind /* TODO: Support DataWarehouseNode */}
             )}
             {!isFunnelsQuery(query) && (
                     counted by{' '}
                     {mathDefinition?.category === MathCategory.HogQLExpression ? (
-                        {filter.math_hogql}
+                        {series.math_hogql}
                     ) : (
                         <>
-                            {mathDefinition?.category === MathCategory.PropertyValue && filter.math_property && (
+                            {mathDefinition?.category === MathCategory.PropertyValue && series.math_property && (
                                 <>
                                     {' '}
                                     event's
diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json
index 1da587c9fcd..7908409a2d2 100644
--- a/frontend/src/queries/schema.json
+++ b/frontend/src/queries/schema.json
@@ -1102,7 +1102,7 @@
             "type": "object"
         },
         "AssistantMessageType": {
-            "enum": ["human", "ai", "ai/viz", "ai/failure", "ai/router"],
+            "enum": ["human", "ai", "ai/reasoning", "ai/viz", "ai/failure", "ai/router"],
             "type": "string"
         },
         "AssistantMultipleBreakdownFilter": {
@@ -11275,6 +11275,24 @@
             "required": ["k", "t"],
             "type": "object"
         },
+        "ReasoningMessage": {
+            "additionalProperties": false,
+            "properties": {
+                "content": {
+                    "type": "string"
+                },
+                "done": {
+                    "const": true,
+                    "type": "boolean"
+                },
+                "type": {
+                    "const": "ai/reasoning",
+                    "type": "string"
+                }
+            },
+            "required": ["type", "content", "done"],
+            "type": "object"
+        },
         "RecordingOrder": {
             "enum": [
                 "duration",
@@ -11686,6 +11704,9 @@
             {
                 "$ref": "#/definitions/VisualizationMessage"
             },
+            {
+                "$ref": "#/definitions/ReasoningMessage"
+            },
             {
                 "$ref": "#/definitions/AssistantMessage"
             },
@@ -12867,19 +12888,6 @@
             "plan": {
                 "type": "string"
             },
-            "reasoning_steps": {
-                "anyOf": [
-                    {
-                        "items": {
-                            "type": "string"
-                        },
-                        "type": "array"
-                    },
-                    {
-                        "type": "null"
-                    }
-                ]
-            },
             "type": {
                 "const": "ai/viz",
                 "type": "string"
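Note: a payload that satisfies the `ReasoningMessage` definition added above, framed the way the assistant streams messages as server-sent events. The `event:`/`data:` framing mirrors the `generateChunk` mock further down; the exact chunk separator is an assumption.

```python
import json

# Conforms to the ReasoningMessage JSON schema: type and done are constants
reasoning_event = {
    "type": "ai/reasoning",
    "content": "Picking relevant events and properties",
    "done": True,
}

# One server-sent event carrying the message (separator assumed)
sse_chunk = f"event: message\ndata: {json.dumps(reasoning_event)}\n\n"
```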
diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts
index a6a78a93cce..d582b706e9c 100644
--- a/frontend/src/queries/schema.ts
+++ b/frontend/src/queries/schema.ts
@@ -2472,6 +2472,7 @@ export type CachedActorsPropertyTaxonomyQueryResponse = CachedQueryResponse
diff --git a/frontend/src/scenes/max/Thread.tsx b/frontend/src/scenes/max/Thread.tsx
                 } else if (isVisualizationMessage(message)) {
                     return
+                } else if (isReasoningMessage(message)) {
+                    return
+                        {message.content}
                 }
                 return null // We currently skip other types of messages
             })}
@@ -130,7 +132,7 @@ function VisualizationAnswer({
 }: {
     message: VisualizationMessage
     status?: MessageStatus
-}): JSX.Element {
+}): JSX.Element | null {
     const query = useMemo(() => {
         if (message.answer) {
             return {
@@ -143,54 +145,33 @@
         return null
     }, [message])

-    return (
-        <>
-            {message.reasoning_steps && (
-                } status="warning" size="small">
-                    Max is generating this answer one more time because the previous attempt has failed.
-
-            )
-            }
-            className={status === 'error' ? 'border-warning' : undefined}
-            >
-                {message.reasoning_steps.map((step, index) => (
-                    {step}
-                ))}
-
-            )}
-            {status === 'completed' && query && (
-                <>
-
-
-                        }
-                        size="xsmall"
-                        targetBlank
-                        className="absolute right-0 -top-px"
-                    >
-                        Open as new insight
-
-                    } />
-
-                </>
-            )}
-        </>
-    )
+    return status !== 'completed'
+        ? null
+        : query && (
+              <>
+
+
+                      }
+                      size="xsmall"
+                      targetBlank
+                      className="absolute right-0 -top-px"
+                  >
+                      Open as new insight
+
+                  } />
+
+              </>
+          )
 }

 function RetriableAnswerActions(): JSX.Element {
diff --git a/frontend/src/scenes/max/__mocks__/chatResponse.mocks.ts b/frontend/src/scenes/max/__mocks__/chatResponse.mocks.ts
index dffb3cfa056..3b1fa6b196a 100644
--- a/frontend/src/scenes/max/__mocks__/chatResponse.mocks.ts
+++ b/frontend/src/scenes/max/__mocks__/chatResponse.mocks.ts
@@ -1,6 +1,7 @@
 import { AssistantGenerationStatusEvent, AssistantGenerationStatusType } from '~/queries/schema'

 import failureMessage from './failureMessage.json'
+import reasoningMessage from './reasoningMessage.json'
 import summaryMessage from './summaryMessage.json'
 import visualizationMessage from './visualizationMessage.json'

@@ -9,12 +10,16 @@ function generateChunk(events: string[]): string {
 }

 export const chatResponseChunk = generateChunk([
+    'event: message',
+    `data: ${JSON.stringify(reasoningMessage)}`,
     'event: message',
     `data: ${JSON.stringify(visualizationMessage)}`,
     'event: message',
     `data: ${JSON.stringify(summaryMessage)}`,
 ])

+export const chatMidwayResponseChunk = generateChunk(['event: message', `data: ${JSON.stringify(reasoningMessage)}`])
+
 const generationFailure: AssistantGenerationStatusEvent = { type: AssistantGenerationStatusType.GenerationError }

 const responseWithReasoningStepsOnly = {
     ...visualizationMessage,
diff --git a/frontend/src/scenes/max/__mocks__/reasoningMessage.json b/frontend/src/scenes/max/__mocks__/reasoningMessage.json
new file mode 100644
index 00000000000..d6146628b1b
--- /dev/null
+++ b/frontend/src/scenes/max/__mocks__/reasoningMessage.json
@@ -0,0 +1,5 @@
+{
+    "type": "ai/reasoning",
+    "content": "Considering available events and properties",
+    "done": true
+}
diff --git a/frontend/src/scenes/max/__mocks__/visualizationMessage.json b/frontend/src/scenes/max/__mocks__/visualizationMessage.json
index cabfe93ca1c..63f21b458ef 100644
--- a/frontend/src/scenes/max/__mocks__/visualizationMessage.json
+++ b/frontend/src/scenes/max/__mocks__/visualizationMessage.json
@@ -1,16 +1,7 @@
 {
     "type": "ai/viz",
     "plan": "Test plan",
-    "reasoning_steps": [
-        "The user's query is to identify the most popular pages.",
-        "To determine the most popular pages, we should analyze the '$pageview' event as it tracks when a user loads or reloads a page.",
-        "We need to use the '$current_url' property from the event properties to identify different pages.",
-        "Since the user didn't specify a date range, a reasonable default would be the last 30 days to get recent insights.",
-        "We should use a breakdown on the '$current_url' to see the popularity of each page URL.",
-        "A bar chart would be suitable for visualizing the most popular pages as it's categorical data.",
-        "Filter out internal and test users by default unless specified otherwise."
-    ],
-    "answer": {
+    "query": {
         "aggregation_group_type_index": null,
         "breakdownFilter": {
             "breakdown_hide_other_aggregation": null,
diff --git a/frontend/src/scenes/max/utils.ts b/frontend/src/scenes/max/utils.ts
index 0bfa5757863..22e03040886 100644
--- a/frontend/src/scenes/max/utils.ts
+++ b/frontend/src/scenes/max/utils.ts
@@ -6,6 +6,7 @@ import {
     FailureMessage,
     FunnelsQuery,
     HumanMessage,
+    ReasoningMessage,
     RootAssistantMessage,
     RouterMessage,
     TrendsQuery,
@@ -13,6 +14,10 @@ import {
 } from '~/queries/schema'
 import { isTrendsQuery } from '~/queries/utils'

+export function isReasoningMessage(message: RootAssistantMessage | undefined | null): message is ReasoningMessage {
+    return message?.type === AssistantMessageType.Reasoning
+}
+
 export function isVisualizationMessage(
     message: RootAssistantMessage | undefined | null
 ): message is VisualizationMessage {
diff --git a/posthog/schema.py b/posthog/schema.py
index fee008aadf0..eda32e51a4a 100644
--- a/posthog/schema.py
+++ b/posthog/schema.py
@@ -261,6 +261,7 @@ class AssistantMessage(BaseModel):
 class AssistantMessageType(StrEnum):
     HUMAN = "human"
     AI = "ai"
+    AI_REASONING = "ai/reasoning"
     AI_VIZ = "ai/viz"
     AI_FAILURE = "ai/failure"
     AI_ROUTER = "ai/router"
@@ -1395,6 +1396,15 @@ class QueryTiming(BaseModel):
     t: float = Field(..., description="Time in seconds. Shortened to 't' to save on data.")


+class ReasoningMessage(BaseModel):
+    model_config = ConfigDict(
+        extra="forbid",
+    )
+    content: str
+    done: Literal[True] = True
+    type: Literal["ai/reasoning"] = "ai/reasoning"
+
+
 class RecordingOrder(StrEnum):
     DURATION = "duration"
     RECORDING_DURATION = "recording_duration"
@@ -5549,7 +5559,6 @@ class VisualizationMessage(BaseModel):
     answer: Optional[Union[AssistantTrendsQuery, AssistantFunnelsQuery]] = None
     done: Optional[bool] = None
     plan: Optional[str] = None
-    reasoning_steps: Optional[list[str]] = None
     type: Literal["ai/viz"] = "ai/viz"


@@ -5936,9 +5945,11 @@ class RetentionQuery(BaseModel):
 class RootAssistantMessage(
-    RootModel[Union[VisualizationMessage, AssistantMessage, HumanMessage, FailureMessage, RouterMessage]]
+    RootModel[
+        Union[VisualizationMessage, ReasoningMessage, AssistantMessage, HumanMessage, FailureMessage, RouterMessage]
+    ]
 ):
-    root: Union[VisualizationMessage, AssistantMessage, HumanMessage, FailureMessage, RouterMessage]
+    root: Union[VisualizationMessage, ReasoningMessage, AssistantMessage, HumanMessage, FailureMessage, RouterMessage]


 class StickinessQuery(BaseModel):
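Closing note: in the regenerated Python model, `done` and `type` are pinned `Literal` fields with defaults, so producers only ever supply `content`. A quick sanity check of the wire shape:

```python
from posthog.schema import ReasoningMessage

message = ReasoningMessage(content="Creating funnel query")
assert message.model_dump() == {
    "content": "Creating funnel query",
    "done": True,
    "type": "ai/reasoning",
}
```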