0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-21 13:39:22 +01:00

feat(product-assistant): better failover for the ReAct agent (#25903)

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Georgiy Tarasov 2024-10-30 16:01:13 +01:00 committed by GitHub
parent 124d166a5b
commit 1fbc799021
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 281 additions and 28 deletions

View File

@ -1,12 +1,10 @@
import itertools
import xml.etree.ElementTree as ET
from functools import cached_property
from typing import Optional, Union, cast
from typing import Optional, cast
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
from langchain_core.agents import AgentAction
from langchain_core.messages import AIMessage as LangchainAssistantMessage
from langchain_core.messages import BaseMessage, merge_message_runs
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
@ -15,10 +13,20 @@ from langchain_openai import ChatOpenAI
from pydantic import ValidationError
from ee.hogai.hardcoded_definitions import hardcoded_prop_defs
from ee.hogai.trends.parsers import PydanticOutputParserException, parse_generated_trends_output
from ee.hogai.trends.parsers import (
PydanticOutputParserException,
ReActParserException,
ReActParserMissingActionException,
parse_generated_trends_output,
parse_react_agent_output,
)
from ee.hogai.trends.prompts import (
react_definitions_prompt,
react_follow_up_prompt,
react_malformed_json_prompt,
react_missing_action_correction_prompt,
react_missing_action_prompt,
react_pydantic_validation_exception_prompt,
react_scratchpad_prompt,
react_system_prompt,
react_user_prompt,
@ -80,40 +88,42 @@ class CreateTrendsPlanNode(AssistantNode):
)
toolkit = TrendsAgentToolkit(self._team)
output_parser = ReActJsonSingleInputOutputParser()
merger = merge_message_runs()
agent = prompt | merger | self._model | output_parser
agent = prompt | merger | self._model | parse_react_agent_output
try:
result = cast(
Union[AgentAction, AgentFinish],
AgentAction,
agent.invoke(
{
"tools": toolkit.render_text_description(),
"tool_names": ", ".join([t["name"] for t in toolkit.tools]),
"agent_scratchpad": format_log_to_str(
[(action, output) for action, output in intermediate_steps if output is not None]
),
"agent_scratchpad": self._get_agent_scratchpad(intermediate_steps),
},
config,
),
)
except OutputParserException as e:
text = str(e)
if e.send_to_llm:
observation = str(e.observation)
text = str(e.llm_output)
except ReActParserException as e:
if isinstance(e, ReActParserMissingActionException):
# When the agent doesn't output the "Action:" block, we need to correct the log and append the action block,
# so that it has a higher chance to recover.
corrected_log = str(
ChatPromptTemplate.from_template(react_missing_action_correction_prompt, template_format="mustache")
.format_messages(output=e.llm_output)[0]
.content
)
result = AgentAction(
"handle_incorrect_response",
react_missing_action_prompt,
corrected_log,
)
else:
observation = "Invalid or incomplete response. You must use the provided tools and output JSON to answer the user's question."
result = AgentAction("handle_incorrect_response", observation, text)
if isinstance(result, AgentFinish):
# Exceptional case
return {
"plan": result.log,
"intermediate_steps": None,
}
result = AgentAction(
"handle_incorrect_response",
react_malformed_json_prompt,
e.llm_output,
)
return {
"intermediate_steps": [*intermediate_steps, (result, None)],
@ -205,6 +215,14 @@ class CreateTrendsPlanNode(AssistantNode):
return conversation
def _get_agent_scratchpad(self, scratchpad: list[tuple[AgentAction, str | None]]) -> str:
    """
    Render the completed ReAct steps as scratchpad text.

    Steps whose observation is `None` are left out. Every other step, including one
    with an empty-string observation, is handed to `format_log_to_str` in its original order.
    """
    completed_steps = [(step, result) for step, result in scratchpad if result is not None]
    return format_log_to_str(completed_steps)
class CreateTrendsPlanToolsNode(AssistantNode):
name = AssistantNodeName.CREATE_TRENDS_PLAN_TOOLS
@ -217,8 +235,12 @@ class CreateTrendsPlanToolsNode(AssistantNode):
try:
input = TrendsAgentToolModel.model_validate({"name": action.tool, "arguments": action.tool_input}).root
except ValidationError as e:
feedback = f"Invalid tool call. Pydantic exception: {e.errors(include_url=False)}"
return {"intermediate_steps": [*intermediate_steps, (action, feedback)]}
observation = (
ChatPromptTemplate.from_template(react_pydantic_validation_exception_prompt, template_format="mustache")
.format_messages(exception=e.errors(include_url=False))[0]
.content
)
return {"intermediate_steps": [*intermediate_steps[:-1], (action, observation)]}
# The plan has been found. Move to the generation.
if input.name == "final_answer":

View File

@ -1,10 +1,66 @@
import json
import re
from langchain_core.agents import AgentAction
from langchain_core.messages import AIMessage as LangchainAIMessage
from pydantic import ValidationError
from ee.hogai.trends.utils import GenerateTrendOutputModel
class ReActParserException(ValueError):
    """
    Base error for ReAct agent output that could not be parsed.

    The raw model output is used as the exception message and is also kept on `llm_output`.
    """

    llm_output: str

    def __init__(self, llm_output: str):
        self.llm_output = llm_output
        super().__init__(llm_output)
class ReActParserMalformedJsonException(ReActParserException):
    """
    The ReAct agent's JSON payload is missing, cannot be decoded, or lacks the required keys.
    """
class ReActParserMissingActionException(ReActParserException):
    """
    Raised when the output of the ReAct agent does not contain the "Action:" block.
    """
# Marker the agent's output must contain; when it is absent the parser raises ReActParserMissingActionException.
ACTION_LOG_PREFIX = "Action:"
# Finds the first triple-backtick fence (optionally tagged `json`) and captures its contents.
_REACT_JSON_BLOCK_PATTERN = re.compile(r"^.*?`{3}(?:json)?\n?(.*?)`{3}.*?$", re.DOTALL)


def parse_react_agent_output(message: LangchainAIMessage) -> AgentAction:
    """
    Parse the output of a ReAct agent into an `AgentAction`.

    A ReAct agent must output in this format:

    Some thoughts...

    Action:
    ```json
    {"action": "action_name", "action_input": "action_input"}
    ```

    The returned action uses `action` as the tool, `action_input` as the tool input,
    and the full message text as the log.

    Raises:
        ReActParserMissingActionException: the output has no "Action:" marker.
        ReActParserMalformedJsonException: no fenced block is found, its content is not
            valid JSON, the JSON is not an object, or the `action` or `action_input` key is missing.
    """
    text = str(message.content)
    if ACTION_LOG_PREFIX not in text:
        raise ReActParserMissingActionException(text)

    found = _REACT_JSON_BLOCK_PATTERN.search(text)
    if not found:
        # JSON not found.
        raise ReActParserMalformedJsonException(text)

    try:
        response = json.loads(found.group(1).strip())
    except (ValueError, RecursionError) as err:
        # JSON is malformed (`json.JSONDecodeError` is a `ValueError`) or nested too deeply to decode.
        raise ReActParserMalformedJsonException(text) from err

    # Valid JSON is not necessarily an object. A string such as "action_input" or a list such as
    # ["action", "action_input"] would pass the membership checks below and then fail with an
    # uncaught `TypeError` on key access, so anything but an object is rejected up front.
    if not isinstance(response, dict):
        raise ReActParserMalformedJsonException(text)

    if "action" not in response or "action_input" not in response:
        # JSON does not contain an action.
        raise ReActParserMalformedJsonException(text)

    return AgentAction(response["action"], response["action_input"], text)
class PydanticOutputParserException(ValueError):
llm_output: str
"""Serialized LLM output."""

View File

@ -178,6 +178,29 @@ react_follow_up_prompt = """
Improve the previously generated plan based on the feedback: {{feedback}}
"""
# Feedback telling the agent that its previous answer had no `Action:` block.
react_missing_action_prompt = """
Your previous answer didn't output the `Action:` block. You must always follow the format described in the system prompt.
"""
# Template with an `{{output}}` placeholder followed by a synthetic `Action:` line.
react_missing_action_correction_prompt = """
{{output}}
Action: I didn't output the `Action:` block.
"""
# Feedback telling the agent that its JSON was malformed and which fields are required.
react_malformed_json_prompt = """
Your previous answer had a malformed JSON. You must return a correct JSON response containing the `action` and `action_input` fields.
"""
# Template with an `{{exception}}` placeholder wrapped in <pydantic_exception> tags.
react_pydantic_validation_exception_prompt = """
The action input you previously provided didn't pass the validation and raised a Pydantic validation exception.
<pydantic_exception>
{{exception}}
</pydantic_exception>
You must fix the exception and try again.
"""
trends_system_prompt = """
You're a recognized head of product growth with the skills of a top-tier data engineer. Your task is to implement queries of trends insights for customers using a JSON schema. You will be given a plan describing series and breakdowns. Answer the user's questions as best you can.

View File

@ -3,9 +3,15 @@ from unittest.mock import patch
from django.test import override_settings
from langchain_core.agents import AgentAction
from langchain_core.messages import AIMessage as LangchainAIMessage
from langchain_core.runnables import RunnableLambda
from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode, GenerateTrendsToolsNode
from ee.hogai.trends.nodes import (
CreateTrendsPlanNode,
CreateTrendsPlanToolsNode,
GenerateTrendsNode,
GenerateTrendsToolsNode,
)
from ee.hogai.trends.utils import GenerateTrendOutputModel
from ee.hogai.utils import AssistantNodeName
from posthog.schema import (
@ -115,6 +121,74 @@ class TestPlanAgentNode(ClickhouseTestMixin, APIBaseTest):
self.assertIn("distinctevent", node._events_prompt)
self.assertIn("all events", node._events_prompt)
def test_agent_scratchpad(self):
    """
    The scratchpad keeps steps that have an observation (even an empty string)
    and drops the step whose observation is `None`.
    """
    node = CreateTrendsPlanNode(self.team)
    scratchpad = [
        (AgentAction(tool="test1", tool_input="input1", log="log1"), "test"),
        (AgentAction(tool="test2", tool_input="input2", log="log2"), None),
        (AgentAction(tool="test3", tool_input="input3", log="log3"), ""),
    ]
    prompt = node._get_agent_scratchpad(scratchpad)
    self.assertIn("log1", prompt)
    # The step without an observation must not leak into the prompt.
    self.assertNotIn("log2", prompt)
    self.assertIn("log3", prompt)
def test_agent_handles_output_without_action_block(self):
    """
    A model answer without the `Action:` block yields one pending step whose log
    keeps the answer and contains an `Action:` line.
    """
    fake_model = RunnableLambda(lambda _: LangchainAIMessage(content="I don't want to output an action."))
    with patch("ee.hogai.trends.nodes.CreateTrendsPlanNode._model", return_value=fake_model):
        planner = CreateTrendsPlanNode(self.team)
        update = planner.run({"messages": [HumanMessage(content="Question")]}, {})

    steps = update["intermediate_steps"]
    self.assertEqual(len(steps), 1)
    step, observation = steps[0]
    self.assertIsNone(observation)
    self.assertIn("I don't want to output an action.", step.log)
    self.assertIn("Action:", step.log)
    self.assertIn("Action:", step.tool_input)
def test_agent_handles_output_with_malformed_json(self):
    """
    A model answer with an `Action:` marker but no parsable JSON yields one pending step
    whose log is the raw answer and whose input names the required JSON fields.
    """
    fake_model = RunnableLambda(lambda _: LangchainAIMessage(content="Thought.\nAction: abc"))
    with patch("ee.hogai.trends.nodes.CreateTrendsPlanNode._model", return_value=fake_model):
        planner = CreateTrendsPlanNode(self.team)
        update = planner.run({"messages": [HumanMessage(content="Question")]}, {})

    steps = update["intermediate_steps"]
    self.assertEqual(len(steps), 1)
    step, observation = steps[0]
    self.assertIsNone(observation)
    self.assertIn("Thought.\nAction: abc", step.log)
    self.assertIn("action", step.tool_input)
    self.assertIn("action_input", step.tool_input)
@override_settings(IN_UNIT_TESTING=True)
class TestCreateTrendsPlanToolsNode(ClickhouseTestMixin, APIBaseTest):
    """
    Covers the validation-error path of `CreateTrendsPlanToolsNode.run`.
    """

    def _run_with_last_action(self, tool: str) -> dict:
        """
        Run the node on a state whose only step calls `tool` with the input "input".
        """
        state = {
            "intermediate_steps": [(AgentAction(tool=tool, tool_input="input", log="log"), "test")],
            "messages": [],
        }
        node = CreateTrendsPlanToolsNode(self.team)
        return node.run(state, {})

    def _assert_single_step_with_pydantic_feedback(self, state_update: dict):
        """
        Check that exactly one step is returned and its observation carries the Pydantic feedback.
        """
        steps = state_update["intermediate_steps"]
        self.assertEqual(len(steps), 1)
        _, feedback = steps[0]
        self.assertIsNotNone(feedback)
        self.assertIn("<pydantic_exception>", feedback)

    def test_node_handles_action_name_validation_error(self):
        state_update = self._run_with_last_action("does not exist")
        self._assert_single_step_with_pydantic_feedback(state_update)

    def test_node_handles_action_input_validation_error(self):
        state_update = self._run_with_last_action("retrieve_entity_property_values")
        self._assert_single_step_with_pydantic_feedback(state_update)
@override_settings(IN_UNIT_TESTING=True)
class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest):

View File

@ -0,0 +1,78 @@
from langchain_core.messages import AIMessage as LangchainAIMessage
from ee.hogai.trends.parsers import (
ReActParserMalformedJsonException,
ReActParserMissingActionException,
parse_react_agent_output,
)
from posthog.test.base import BaseTest
class TestParsers(BaseTest):
    """
    Tests for `parse_react_agent_output`.
    """

    def test_parse_react_agent_output(self):
        """
        Well-formed outputs are parsed into a tool and tool input; outputs without the
        action marker or without a complete JSON payload raise the matching parser exception.
        """
        tagged_fence = LangchainAIMessage(
            content="""
Some thoughts...
Action:
```json
{"action": "action_name", "action_input": "action_input"}
```
"""
        )
        parsed = parse_react_agent_output(tagged_fence)
        self.assertEqual(parsed.tool, "action_name")
        self.assertEqual(parsed.tool_input, "action_input")

        untagged_fence = LangchainAIMessage(
            content="""
Some thoughts...
Action:
```
{"action": "tool", "action_input": {"key": "value"}}
```
"""
        )
        parsed = parse_react_agent_output(untagged_fence)
        self.assertEqual(parsed.tool, "tool")
        self.assertEqual(parsed.tool_input, {"key": "value"})

        with self.assertRaises(ReActParserMissingActionException):
            parse_react_agent_output(LangchainAIMessage(content="Some thoughts..."))

        # Each output contains the action marker but no complete JSON payload.
        malformed_outputs = (
            "Some thoughts...\nAction: abc",
            "Some thoughts...\nAction:",
            "Some thoughts...\nAction: {}",
            "Some thoughts...\nAction:\n```\n{}\n```",
            "Some thoughts...\nAction:\n```\n{not a json}\n```",
            'Some thoughts...\nAction:\n```\n{"action":"tool"}\n```',
            'Some thoughts...\nAction:\n```\n{"action_input":"input"}\n```',
        )
        for output in malformed_outputs:
            with self.assertRaises(ReActParserMalformedJsonException):
                parse_react_agent_output(LangchainAIMessage(content=output))