mirror of
https://github.com/PostHog/posthog.git
synced 2024-11-21 13:39:22 +01:00
feat(product-assistant): better failover for the ReAct agent (#25903)
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
124d166a5b
commit
1fbc799021
@ -1,12 +1,10 @@
|
||||
import itertools
|
||||
import xml.etree.ElementTree as ET
|
||||
from functools import cached_property
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from langchain.agents.format_scratchpad import format_log_to_str
|
||||
from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.agents import AgentAction
|
||||
from langchain_core.messages import AIMessage as LangchainAssistantMessage
|
||||
from langchain_core.messages import BaseMessage, merge_message_runs
|
||||
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||
@ -15,10 +13,20 @@ from langchain_openai import ChatOpenAI
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ee.hogai.hardcoded_definitions import hardcoded_prop_defs
|
||||
from ee.hogai.trends.parsers import PydanticOutputParserException, parse_generated_trends_output
|
||||
from ee.hogai.trends.parsers import (
|
||||
PydanticOutputParserException,
|
||||
ReActParserException,
|
||||
ReActParserMissingActionException,
|
||||
parse_generated_trends_output,
|
||||
parse_react_agent_output,
|
||||
)
|
||||
from ee.hogai.trends.prompts import (
|
||||
react_definitions_prompt,
|
||||
react_follow_up_prompt,
|
||||
react_malformed_json_prompt,
|
||||
react_missing_action_correction_prompt,
|
||||
react_missing_action_prompt,
|
||||
react_pydantic_validation_exception_prompt,
|
||||
react_scratchpad_prompt,
|
||||
react_system_prompt,
|
||||
react_user_prompt,
|
||||
@ -80,40 +88,42 @@ class CreateTrendsPlanNode(AssistantNode):
|
||||
)
|
||||
|
||||
toolkit = TrendsAgentToolkit(self._team)
|
||||
output_parser = ReActJsonSingleInputOutputParser()
|
||||
merger = merge_message_runs()
|
||||
|
||||
agent = prompt | merger | self._model | output_parser
|
||||
agent = prompt | merger | self._model | parse_react_agent_output
|
||||
|
||||
try:
|
||||
result = cast(
|
||||
Union[AgentAction, AgentFinish],
|
||||
AgentAction,
|
||||
agent.invoke(
|
||||
{
|
||||
"tools": toolkit.render_text_description(),
|
||||
"tool_names": ", ".join([t["name"] for t in toolkit.tools]),
|
||||
"agent_scratchpad": format_log_to_str(
|
||||
[(action, output) for action, output in intermediate_steps if output is not None]
|
||||
),
|
||||
"agent_scratchpad": self._get_agent_scratchpad(intermediate_steps),
|
||||
},
|
||||
config,
|
||||
),
|
||||
)
|
||||
except OutputParserException as e:
|
||||
text = str(e)
|
||||
if e.send_to_llm:
|
||||
observation = str(e.observation)
|
||||
text = str(e.llm_output)
|
||||
except ReActParserException as e:
|
||||
if isinstance(e, ReActParserMissingActionException):
|
||||
# When the agent doesn't output the "Action:" block, we need to correct the log and append the action block,
|
||||
# so that it has a higher chance to recover.
|
||||
corrected_log = str(
|
||||
ChatPromptTemplate.from_template(react_missing_action_correction_prompt, template_format="mustache")
|
||||
.format_messages(output=e.llm_output)[0]
|
||||
.content
|
||||
)
|
||||
result = AgentAction(
|
||||
"handle_incorrect_response",
|
||||
react_missing_action_prompt,
|
||||
corrected_log,
|
||||
)
|
||||
else:
|
||||
observation = "Invalid or incomplete response. You must use the provided tools and output JSON to answer the user's question."
|
||||
result = AgentAction("handle_incorrect_response", observation, text)
|
||||
|
||||
if isinstance(result, AgentFinish):
|
||||
# Exceptional case
|
||||
return {
|
||||
"plan": result.log,
|
||||
"intermediate_steps": None,
|
||||
}
|
||||
result = AgentAction(
|
||||
"handle_incorrect_response",
|
||||
react_malformed_json_prompt,
|
||||
e.llm_output,
|
||||
)
|
||||
|
||||
return {
|
||||
"intermediate_steps": [*intermediate_steps, (result, None)],
|
||||
@ -205,6 +215,14 @@ class CreateTrendsPlanNode(AssistantNode):
|
||||
|
||||
return conversation
|
||||
|
||||
def _get_agent_scratchpad(self, scratchpad: list[tuple[AgentAction, str | None]]) -> str:
|
||||
actions = []
|
||||
for action, observation in scratchpad:
|
||||
if observation is None:
|
||||
continue
|
||||
actions.append((action, observation))
|
||||
return format_log_to_str(actions)
|
||||
|
||||
|
||||
class CreateTrendsPlanToolsNode(AssistantNode):
|
||||
name = AssistantNodeName.CREATE_TRENDS_PLAN_TOOLS
|
||||
@ -217,8 +235,12 @@ class CreateTrendsPlanToolsNode(AssistantNode):
|
||||
try:
|
||||
input = TrendsAgentToolModel.model_validate({"name": action.tool, "arguments": action.tool_input}).root
|
||||
except ValidationError as e:
|
||||
feedback = f"Invalid tool call. Pydantic exception: {e.errors(include_url=False)}"
|
||||
return {"intermediate_steps": [*intermediate_steps, (action, feedback)]}
|
||||
observation = (
|
||||
ChatPromptTemplate.from_template(react_pydantic_validation_exception_prompt, template_format="mustache")
|
||||
.format_messages(exception=e.errors(include_url=False))[0]
|
||||
.content
|
||||
)
|
||||
return {"intermediate_steps": [*intermediate_steps[:-1], (action, observation)]}
|
||||
|
||||
# The plan has been found. Move to the generation.
|
||||
if input.name == "final_answer":
|
||||
|
@ -1,10 +1,66 @@
|
||||
import json
|
||||
import re
|
||||
|
||||
from langchain_core.agents import AgentAction
|
||||
from langchain_core.messages import AIMessage as LangchainAIMessage
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ee.hogai.trends.utils import GenerateTrendOutputModel
|
||||
|
||||
|
||||
class ReActParserException(ValueError):
|
||||
llm_output: str
|
||||
|
||||
def __init__(self, llm_output: str):
|
||||
super().__init__(llm_output)
|
||||
self.llm_output = llm_output
|
||||
|
||||
|
||||
class ReActParserMalformedJsonException(ReActParserException):
|
||||
pass
|
||||
|
||||
|
||||
class ReActParserMissingActionException(ReActParserException):
|
||||
"""
|
||||
The ReAct agent didn't output the "Action:" block.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
ACTION_LOG_PREFIX = "Action:"
|
||||
|
||||
|
||||
def parse_react_agent_output(message: LangchainAIMessage) -> AgentAction:
|
||||
"""
|
||||
A ReAct agent must output in this format:
|
||||
|
||||
Some thoughts...
|
||||
Action:
|
||||
```json
|
||||
{"action": "action_name", "action_input": "action_input"}
|
||||
```
|
||||
"""
|
||||
text = str(message.content)
|
||||
if ACTION_LOG_PREFIX not in text:
|
||||
raise ReActParserMissingActionException(text)
|
||||
found = re.compile(r"^.*?`{3}(?:json)?\n?(.*?)`{3}.*?$", re.DOTALL).search(text)
|
||||
if not found:
|
||||
# JSON not found.
|
||||
raise ReActParserMalformedJsonException(text)
|
||||
try:
|
||||
action = found.group(1).strip()
|
||||
response = json.loads(action)
|
||||
is_complete = "action" in response and "action_input" in response
|
||||
except Exception:
|
||||
# JSON is malformed or has a wrong type.
|
||||
raise ReActParserMalformedJsonException(text)
|
||||
if not is_complete:
|
||||
# JSON does not contain an action.
|
||||
raise ReActParserMalformedJsonException(text)
|
||||
return AgentAction(response["action"], response.get("action_input", {}), text)
|
||||
|
||||
|
||||
class PydanticOutputParserException(ValueError):
|
||||
llm_output: str
|
||||
"""Serialized LLM output."""
|
||||
|
@ -178,6 +178,29 @@ react_follow_up_prompt = """
|
||||
Improve the previously generated plan based on the feedback: {{feedback}}
|
||||
"""
|
||||
|
||||
react_missing_action_prompt = """
|
||||
Your previous answer didn't output the `Action:` block. You must always follow the format described in the system prompt.
|
||||
"""
|
||||
|
||||
react_missing_action_correction_prompt = """
|
||||
{{output}}
|
||||
Action: I didn't output the `Action:` block.
|
||||
"""
|
||||
|
||||
react_malformed_json_prompt = """
|
||||
Your previous answer had a malformed JSON. You must return a correct JSON response containing the `action` and `action_input` fields.
|
||||
"""
|
||||
|
||||
react_pydantic_validation_exception_prompt = """
|
||||
The action input you previously provided didn't pass the validation and raised a Pydantic validation exception.
|
||||
|
||||
<pydantic_exception>
|
||||
{{exception}}
|
||||
</pydantic_exception>
|
||||
|
||||
You must fix the exception and try again.
|
||||
"""
|
||||
|
||||
trends_system_prompt = """
|
||||
You're a recognized head of product growth with the skills of a top-tier data engineer. Your task is to implement queries of trends insights for customers using a JSON schema. You will be given a plan describing series and breakdowns. Answer the user's questions as best you can.
|
||||
|
||||
|
@ -3,9 +3,15 @@ from unittest.mock import patch
|
||||
|
||||
from django.test import override_settings
|
||||
from langchain_core.agents import AgentAction
|
||||
from langchain_core.messages import AIMessage as LangchainAIMessage
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode, GenerateTrendsToolsNode
|
||||
from ee.hogai.trends.nodes import (
|
||||
CreateTrendsPlanNode,
|
||||
CreateTrendsPlanToolsNode,
|
||||
GenerateTrendsNode,
|
||||
GenerateTrendsToolsNode,
|
||||
)
|
||||
from ee.hogai.trends.utils import GenerateTrendOutputModel
|
||||
from ee.hogai.utils import AssistantNodeName
|
||||
from posthog.schema import (
|
||||
@ -115,6 +121,74 @@ class TestPlanAgentNode(ClickhouseTestMixin, APIBaseTest):
|
||||
self.assertIn("distinctevent", node._events_prompt)
|
||||
self.assertIn("all events", node._events_prompt)
|
||||
|
||||
def test_agent_scratchpad(self):
|
||||
node = CreateTrendsPlanNode(self.team)
|
||||
scratchpad = [
|
||||
(AgentAction(tool="test1", tool_input="input1", log="log1"), "test"),
|
||||
(AgentAction(tool="test2", tool_input="input2", log="log2"), None),
|
||||
(AgentAction(tool="test3", tool_input="input3", log="log3"), ""),
|
||||
]
|
||||
prompt = node._get_agent_scratchpad(scratchpad)
|
||||
self.assertIn("log1", prompt)
|
||||
self.assertIn("log3", prompt)
|
||||
|
||||
def test_agent_handles_output_without_action_block(self):
|
||||
with patch(
|
||||
"ee.hogai.trends.nodes.CreateTrendsPlanNode._model",
|
||||
return_value=RunnableLambda(lambda _: LangchainAIMessage(content="I don't want to output an action.")),
|
||||
):
|
||||
node = CreateTrendsPlanNode(self.team)
|
||||
state_update = node.run({"messages": [HumanMessage(content="Question")]}, {})
|
||||
self.assertEqual(len(state_update["intermediate_steps"]), 1)
|
||||
action, obs = state_update["intermediate_steps"][0]
|
||||
self.assertIsNone(obs)
|
||||
self.assertIn("I don't want to output an action.", action.log)
|
||||
self.assertIn("Action:", action.log)
|
||||
self.assertIn("Action:", action.tool_input)
|
||||
|
||||
def test_agent_handles_output_with_malformed_json(self):
|
||||
with patch(
|
||||
"ee.hogai.trends.nodes.CreateTrendsPlanNode._model",
|
||||
return_value=RunnableLambda(lambda _: LangchainAIMessage(content="Thought.\nAction: abc")),
|
||||
):
|
||||
node = CreateTrendsPlanNode(self.team)
|
||||
state_update = node.run({"messages": [HumanMessage(content="Question")]}, {})
|
||||
self.assertEqual(len(state_update["intermediate_steps"]), 1)
|
||||
action, obs = state_update["intermediate_steps"][0]
|
||||
self.assertIsNone(obs)
|
||||
self.assertIn("Thought.\nAction: abc", action.log)
|
||||
self.assertIn("action", action.tool_input)
|
||||
self.assertIn("action_input", action.tool_input)
|
||||
|
||||
|
||||
@override_settings(IN_UNIT_TESTING=True)
|
||||
class TestCreateTrendsPlanToolsNode(ClickhouseTestMixin, APIBaseTest):
|
||||
def test_node_handles_action_name_validation_error(self):
|
||||
state = {
|
||||
"intermediate_steps": [(AgentAction(tool="does not exist", tool_input="input", log="log"), "test")],
|
||||
"messages": [],
|
||||
}
|
||||
node = CreateTrendsPlanToolsNode(self.team)
|
||||
state_update = node.run(state, {})
|
||||
self.assertEqual(len(state_update["intermediate_steps"]), 1)
|
||||
action, observation = state_update["intermediate_steps"][0]
|
||||
self.assertIsNotNone(observation)
|
||||
self.assertIn("<pydantic_exception>", observation)
|
||||
|
||||
def test_node_handles_action_input_validation_error(self):
|
||||
state = {
|
||||
"intermediate_steps": [
|
||||
(AgentAction(tool="retrieve_entity_property_values", tool_input="input", log="log"), "test")
|
||||
],
|
||||
"messages": [],
|
||||
}
|
||||
node = CreateTrendsPlanToolsNode(self.team)
|
||||
state_update = node.run(state, {})
|
||||
self.assertEqual(len(state_update["intermediate_steps"]), 1)
|
||||
action, observation = state_update["intermediate_steps"][0]
|
||||
self.assertIsNotNone(observation)
|
||||
self.assertIn("<pydantic_exception>", observation)
|
||||
|
||||
|
||||
@override_settings(IN_UNIT_TESTING=True)
|
||||
class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest):
|
||||
|
78
ee/hogai/trends/test/test_parsers.py
Normal file
78
ee/hogai/trends/test/test_parsers.py
Normal file
@ -0,0 +1,78 @@
|
||||
from langchain_core.messages import AIMessage as LangchainAIMessage
|
||||
|
||||
from ee.hogai.trends.parsers import (
|
||||
ReActParserMalformedJsonException,
|
||||
ReActParserMissingActionException,
|
||||
parse_react_agent_output,
|
||||
)
|
||||
from posthog.test.base import BaseTest
|
||||
|
||||
|
||||
class TestParsers(BaseTest):
|
||||
def test_parse_react_agent_output(self):
|
||||
res = parse_react_agent_output(
|
||||
LangchainAIMessage(
|
||||
content="""
|
||||
Some thoughts...
|
||||
Action:
|
||||
```json
|
||||
{"action": "action_name", "action_input": "action_input"}
|
||||
```
|
||||
"""
|
||||
)
|
||||
)
|
||||
self.assertEqual(res.tool, "action_name")
|
||||
self.assertEqual(res.tool_input, "action_input")
|
||||
|
||||
res = parse_react_agent_output(
|
||||
LangchainAIMessage(
|
||||
content="""
|
||||
Some thoughts...
|
||||
Action:
|
||||
```
|
||||
{"action": "tool", "action_input": {"key": "value"}}
|
||||
```
|
||||
"""
|
||||
)
|
||||
)
|
||||
self.assertEqual(res.tool, "tool")
|
||||
self.assertEqual(res.tool_input, {"key": "value"})
|
||||
|
||||
self.assertRaises(
|
||||
ReActParserMissingActionException, parse_react_agent_output, LangchainAIMessage(content="Some thoughts...")
|
||||
)
|
||||
self.assertRaises(
|
||||
ReActParserMalformedJsonException,
|
||||
parse_react_agent_output,
|
||||
LangchainAIMessage(content="Some thoughts...\nAction: abc"),
|
||||
)
|
||||
self.assertRaises(
|
||||
ReActParserMalformedJsonException,
|
||||
parse_react_agent_output,
|
||||
LangchainAIMessage(content="Some thoughts...\nAction:"),
|
||||
)
|
||||
self.assertRaises(
|
||||
ReActParserMalformedJsonException,
|
||||
parse_react_agent_output,
|
||||
LangchainAIMessage(content="Some thoughts...\nAction: {}"),
|
||||
)
|
||||
self.assertRaises(
|
||||
ReActParserMalformedJsonException,
|
||||
parse_react_agent_output,
|
||||
LangchainAIMessage(content="Some thoughts...\nAction:\n```\n{}\n```"),
|
||||
)
|
||||
self.assertRaises(
|
||||
ReActParserMalformedJsonException,
|
||||
parse_react_agent_output,
|
||||
LangchainAIMessage(content="Some thoughts...\nAction:\n```\n{not a json}\n```"),
|
||||
)
|
||||
self.assertRaises(
|
||||
ReActParserMalformedJsonException,
|
||||
parse_react_agent_output,
|
||||
LangchainAIMessage(content='Some thoughts...\nAction:\n```\n{"action":"tool"}\n```'),
|
||||
)
|
||||
self.assertRaises(
|
||||
ReActParserMalformedJsonException,
|
||||
parse_react_agent_output,
|
||||
LangchainAIMessage(content='Some thoughts...\nAction:\n```\n{"action_input":"input"}\n```'),
|
||||
)
|
Loading…
Reference in New Issue
Block a user