diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py
index 206b74173a1..d1819b49b70 100644
--- a/ee/hogai/trends/nodes.py
+++ b/ee/hogai/trends/nodes.py
@@ -1,12 +1,10 @@
import itertools
import xml.etree.ElementTree as ET
from functools import cached_property
-from typing import Optional, Union, cast
+from typing import Optional, cast
from langchain.agents.format_scratchpad import format_log_to_str
-from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
-from langchain_core.agents import AgentAction, AgentFinish
-from langchain_core.exceptions import OutputParserException
+from langchain_core.agents import AgentAction
from langchain_core.messages import AIMessage as LangchainAssistantMessage
from langchain_core.messages import BaseMessage, merge_message_runs
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
@@ -15,10 +13,20 @@ from langchain_openai import ChatOpenAI
from pydantic import ValidationError
from ee.hogai.hardcoded_definitions import hardcoded_prop_defs
-from ee.hogai.trends.parsers import PydanticOutputParserException, parse_generated_trends_output
+from ee.hogai.trends.parsers import (
+ PydanticOutputParserException,
+ ReActParserException,
+ ReActParserMissingActionException,
+ parse_generated_trends_output,
+ parse_react_agent_output,
+)
from ee.hogai.trends.prompts import (
react_definitions_prompt,
react_follow_up_prompt,
+ react_malformed_json_prompt,
+ react_missing_action_correction_prompt,
+ react_missing_action_prompt,
+ react_pydantic_validation_exception_prompt,
react_scratchpad_prompt,
react_system_prompt,
react_user_prompt,
@@ -80,40 +88,42 @@ class CreateTrendsPlanNode(AssistantNode):
)
toolkit = TrendsAgentToolkit(self._team)
- output_parser = ReActJsonSingleInputOutputParser()
merger = merge_message_runs()
- agent = prompt | merger | self._model | output_parser
+ agent = prompt | merger | self._model | parse_react_agent_output
try:
result = cast(
- Union[AgentAction, AgentFinish],
+ AgentAction,
agent.invoke(
{
"tools": toolkit.render_text_description(),
"tool_names": ", ".join([t["name"] for t in toolkit.tools]),
- "agent_scratchpad": format_log_to_str(
- [(action, output) for action, output in intermediate_steps if output is not None]
- ),
+ "agent_scratchpad": self._get_agent_scratchpad(intermediate_steps),
},
config,
),
)
- except OutputParserException as e:
- text = str(e)
- if e.send_to_llm:
- observation = str(e.observation)
- text = str(e.llm_output)
+ except ReActParserException as e:
+ if isinstance(e, ReActParserMissingActionException):
+ # When the agent doesn't output the "Action:" block, we need to correct the log and append the action block,
+ # so that it has a higher chance to recover.
+ corrected_log = str(
+ ChatPromptTemplate.from_template(react_missing_action_correction_prompt, template_format="mustache")
+ .format_messages(output=e.llm_output)[0]
+ .content
+ )
+ result = AgentAction(
+ "handle_incorrect_response",
+ react_missing_action_prompt,
+ corrected_log,
+ )
else:
- observation = "Invalid or incomplete response. You must use the provided tools and output JSON to answer the user's question."
- result = AgentAction("handle_incorrect_response", observation, text)
-
- if isinstance(result, AgentFinish):
- # Exceptional case
- return {
- "plan": result.log,
- "intermediate_steps": None,
- }
+ result = AgentAction(
+ "handle_incorrect_response",
+ react_malformed_json_prompt,
+ e.llm_output,
+ )
return {
"intermediate_steps": [*intermediate_steps, (result, None)],
@@ -205,6 +215,14 @@ class CreateTrendsPlanNode(AssistantNode):
return conversation
+ def _get_agent_scratchpad(self, scratchpad: list[tuple[AgentAction, str | None]]) -> str:
+ actions = []
+ for action, observation in scratchpad:
+ if observation is None:
+ continue
+ actions.append((action, observation))
+ return format_log_to_str(actions)
+
class CreateTrendsPlanToolsNode(AssistantNode):
name = AssistantNodeName.CREATE_TRENDS_PLAN_TOOLS
@@ -217,8 +235,12 @@ class CreateTrendsPlanToolsNode(AssistantNode):
try:
input = TrendsAgentToolModel.model_validate({"name": action.tool, "arguments": action.tool_input}).root
except ValidationError as e:
- feedback = f"Invalid tool call. Pydantic exception: {e.errors(include_url=False)}"
- return {"intermediate_steps": [*intermediate_steps, (action, feedback)]}
+ observation = (
+ ChatPromptTemplate.from_template(react_pydantic_validation_exception_prompt, template_format="mustache")
+ .format_messages(exception=e.errors(include_url=False))[0]
+ .content
+ )
+ return {"intermediate_steps": [*intermediate_steps[:-1], (action, observation)]}
# The plan has been found. Move to the generation.
if input.name == "final_answer":
diff --git a/ee/hogai/trends/parsers.py b/ee/hogai/trends/parsers.py
index a461c692d88..e66f9745769 100644
--- a/ee/hogai/trends/parsers.py
+++ b/ee/hogai/trends/parsers.py
@@ -1,10 +1,66 @@
import json
+import re
+from langchain_core.agents import AgentAction
+from langchain_core.messages import AIMessage as LangchainAIMessage
from pydantic import ValidationError
from ee.hogai.trends.utils import GenerateTrendOutputModel
+class ReActParserException(ValueError):
+ llm_output: str
+
+ def __init__(self, llm_output: str):
+ super().__init__(llm_output)
+ self.llm_output = llm_output
+
+
+class ReActParserMalformedJsonException(ReActParserException):
+ pass
+
+
+class ReActParserMissingActionException(ReActParserException):
+ """
+ The ReAct agent didn't output the "Action:" block.
+ """
+
+ pass
+
+
+ACTION_LOG_PREFIX = "Action:"
+
+
+def parse_react_agent_output(message: LangchainAIMessage) -> AgentAction:
+ """
+ A ReAct agent must output in this format:
+
+ Some thoughts...
+ Action:
+ ```json
+ {"action": "action_name", "action_input": "action_input"}
+ ```
+ """
+ text = str(message.content)
+ if ACTION_LOG_PREFIX not in text:
+ raise ReActParserMissingActionException(text)
+ found = re.compile(r"^.*?`{3}(?:json)?\n?(.*?)`{3}.*?$", re.DOTALL).search(text)
+ if not found:
+ # JSON not found.
+ raise ReActParserMalformedJsonException(text)
+ try:
+ action = found.group(1).strip()
+ response = json.loads(action)
+ is_complete = "action" in response and "action_input" in response
+ except Exception:
+ # JSON is malformed or has a wrong type.
+ raise ReActParserMalformedJsonException(text)
+ if not is_complete:
+ # JSON does not contain an action.
+ raise ReActParserMalformedJsonException(text)
+ return AgentAction(response["action"], response.get("action_input", {}), text)
+
+
class PydanticOutputParserException(ValueError):
llm_output: str
"""Serialized LLM output."""
diff --git a/ee/hogai/trends/prompts.py b/ee/hogai/trends/prompts.py
index 84c2bcedb54..2543b1efc26 100644
--- a/ee/hogai/trends/prompts.py
+++ b/ee/hogai/trends/prompts.py
@@ -178,6 +178,29 @@ react_follow_up_prompt = """
Improve the previously generated plan based on the feedback: {{feedback}}
"""
+react_missing_action_prompt = """
+Your previous answer didn't output the `Action:` block. You must always follow the format described in the system prompt.
+"""
+
+react_missing_action_correction_prompt = """
+{{output}}
+Action: I didn't output the `Action:` block.
+"""
+
+react_malformed_json_prompt = """
+Your previous answer had a malformed JSON. You must return a correct JSON response containing the `action` and `action_input` fields.
+"""
+
+react_pydantic_validation_exception_prompt = """
+The action input you previously provided didn't pass the validation and raised a Pydantic validation exception.
+
+
+{{exception}}
+
+
+You must fix the exception and try again.
+"""
+
trends_system_prompt = """
You're a recognized head of product growth with the skills of a top-tier data engineer. Your task is to implement queries of trends insights for customers using a JSON schema. You will be given a plan describing series and breakdowns. Answer the user's questions as best you can.
diff --git a/ee/hogai/trends/test/test_nodes.py b/ee/hogai/trends/test/test_nodes.py
index 990f83ddca7..1e89c45458a 100644
--- a/ee/hogai/trends/test/test_nodes.py
+++ b/ee/hogai/trends/test/test_nodes.py
@@ -3,9 +3,15 @@ from unittest.mock import patch
from django.test import override_settings
from langchain_core.agents import AgentAction
+from langchain_core.messages import AIMessage as LangchainAIMessage
from langchain_core.runnables import RunnableLambda
-from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode, GenerateTrendsToolsNode
+from ee.hogai.trends.nodes import (
+ CreateTrendsPlanNode,
+ CreateTrendsPlanToolsNode,
+ GenerateTrendsNode,
+ GenerateTrendsToolsNode,
+)
from ee.hogai.trends.utils import GenerateTrendOutputModel
from ee.hogai.utils import AssistantNodeName
from posthog.schema import (
@@ -115,6 +121,74 @@ class TestPlanAgentNode(ClickhouseTestMixin, APIBaseTest):
self.assertIn("distinctevent", node._events_prompt)
self.assertIn("all events", node._events_prompt)
+ def test_agent_scratchpad(self):
+ node = CreateTrendsPlanNode(self.team)
+ scratchpad = [
+ (AgentAction(tool="test1", tool_input="input1", log="log1"), "test"),
+ (AgentAction(tool="test2", tool_input="input2", log="log2"), None),
+ (AgentAction(tool="test3", tool_input="input3", log="log3"), ""),
+ ]
+ prompt = node._get_agent_scratchpad(scratchpad)
+ self.assertIn("log1", prompt)
+ self.assertIn("log3", prompt)
+
+ def test_agent_handles_output_without_action_block(self):
+ with patch(
+ "ee.hogai.trends.nodes.CreateTrendsPlanNode._model",
+ return_value=RunnableLambda(lambda _: LangchainAIMessage(content="I don't want to output an action.")),
+ ):
+ node = CreateTrendsPlanNode(self.team)
+ state_update = node.run({"messages": [HumanMessage(content="Question")]}, {})
+ self.assertEqual(len(state_update["intermediate_steps"]), 1)
+ action, obs = state_update["intermediate_steps"][0]
+ self.assertIsNone(obs)
+ self.assertIn("I don't want to output an action.", action.log)
+ self.assertIn("Action:", action.log)
+ self.assertIn("Action:", action.tool_input)
+
+ def test_agent_handles_output_with_malformed_json(self):
+ with patch(
+ "ee.hogai.trends.nodes.CreateTrendsPlanNode._model",
+ return_value=RunnableLambda(lambda _: LangchainAIMessage(content="Thought.\nAction: abc")),
+ ):
+ node = CreateTrendsPlanNode(self.team)
+ state_update = node.run({"messages": [HumanMessage(content="Question")]}, {})
+ self.assertEqual(len(state_update["intermediate_steps"]), 1)
+ action, obs = state_update["intermediate_steps"][0]
+ self.assertIsNone(obs)
+ self.assertIn("Thought.\nAction: abc", action.log)
+ self.assertIn("action", action.tool_input)
+ self.assertIn("action_input", action.tool_input)
+
+
+@override_settings(IN_UNIT_TESTING=True)
+class TestCreateTrendsPlanToolsNode(ClickhouseTestMixin, APIBaseTest):
+ def test_node_handles_action_name_validation_error(self):
+ state = {
+ "intermediate_steps": [(AgentAction(tool="does not exist", tool_input="input", log="log"), "test")],
+ "messages": [],
+ }
+ node = CreateTrendsPlanToolsNode(self.team)
+ state_update = node.run(state, {})
+ self.assertEqual(len(state_update["intermediate_steps"]), 1)
+ action, observation = state_update["intermediate_steps"][0]
+ self.assertIsNotNone(observation)
+ self.assertIn("", observation)
+
+ def test_node_handles_action_input_validation_error(self):
+ state = {
+ "intermediate_steps": [
+ (AgentAction(tool="retrieve_entity_property_values", tool_input="input", log="log"), "test")
+ ],
+ "messages": [],
+ }
+ node = CreateTrendsPlanToolsNode(self.team)
+ state_update = node.run(state, {})
+ self.assertEqual(len(state_update["intermediate_steps"]), 1)
+ action, observation = state_update["intermediate_steps"][0]
+ self.assertIsNotNone(observation)
+ self.assertIn("", observation)
+
@override_settings(IN_UNIT_TESTING=True)
class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest):
diff --git a/ee/hogai/trends/test/test_parsers.py b/ee/hogai/trends/test/test_parsers.py
new file mode 100644
index 00000000000..c32ff7f146b
--- /dev/null
+++ b/ee/hogai/trends/test/test_parsers.py
@@ -0,0 +1,78 @@
+from langchain_core.messages import AIMessage as LangchainAIMessage
+
+from ee.hogai.trends.parsers import (
+ ReActParserMalformedJsonException,
+ ReActParserMissingActionException,
+ parse_react_agent_output,
+)
+from posthog.test.base import BaseTest
+
+
+class TestParsers(BaseTest):
+ def test_parse_react_agent_output(self):
+ res = parse_react_agent_output(
+ LangchainAIMessage(
+ content="""
+ Some thoughts...
+ Action:
+ ```json
+ {"action": "action_name", "action_input": "action_input"}
+ ```
+ """
+ )
+ )
+ self.assertEqual(res.tool, "action_name")
+ self.assertEqual(res.tool_input, "action_input")
+
+ res = parse_react_agent_output(
+ LangchainAIMessage(
+ content="""
+ Some thoughts...
+ Action:
+ ```
+ {"action": "tool", "action_input": {"key": "value"}}
+ ```
+ """
+ )
+ )
+ self.assertEqual(res.tool, "tool")
+ self.assertEqual(res.tool_input, {"key": "value"})
+
+ self.assertRaises(
+ ReActParserMissingActionException, parse_react_agent_output, LangchainAIMessage(content="Some thoughts...")
+ )
+ self.assertRaises(
+ ReActParserMalformedJsonException,
+ parse_react_agent_output,
+ LangchainAIMessage(content="Some thoughts...\nAction: abc"),
+ )
+ self.assertRaises(
+ ReActParserMalformedJsonException,
+ parse_react_agent_output,
+ LangchainAIMessage(content="Some thoughts...\nAction:"),
+ )
+ self.assertRaises(
+ ReActParserMalformedJsonException,
+ parse_react_agent_output,
+ LangchainAIMessage(content="Some thoughts...\nAction: {}"),
+ )
+ self.assertRaises(
+ ReActParserMalformedJsonException,
+ parse_react_agent_output,
+ LangchainAIMessage(content="Some thoughts...\nAction:\n```\n{}\n```"),
+ )
+ self.assertRaises(
+ ReActParserMalformedJsonException,
+ parse_react_agent_output,
+ LangchainAIMessage(content="Some thoughts...\nAction:\n```\n{not a json}\n```"),
+ )
+ self.assertRaises(
+ ReActParserMalformedJsonException,
+ parse_react_agent_output,
+ LangchainAIMessage(content='Some thoughts...\nAction:\n```\n{"action":"tool"}\n```'),
+ )
+ self.assertRaises(
+ ReActParserMalformedJsonException,
+ parse_react_agent_output,
+ LangchainAIMessage(content='Some thoughts...\nAction:\n```\n{"action_input":"input"}\n```'),
+ )