From 1fbc7990218a4814048e2e9e01e5b3df6ce6b8f7 Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Wed, 30 Oct 2024 16:01:13 +0100 Subject: [PATCH] feat(product-assistant): better failover for the ReAct agent (#25903) Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com> --- ee/hogai/trends/nodes.py | 76 +++++++++++++++++---------- ee/hogai/trends/parsers.py | 56 ++++++++++++++++++++ ee/hogai/trends/prompts.py | 23 ++++++++ ee/hogai/trends/test/test_nodes.py | 76 ++++++++++++++++++++++++++- ee/hogai/trends/test/test_parsers.py | 78 ++++++++++++++++++++++++++++ 5 files changed, 281 insertions(+), 28 deletions(-) create mode 100644 ee/hogai/trends/test/test_parsers.py diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py index 206b74173a1..d1819b49b70 100644 --- a/ee/hogai/trends/nodes.py +++ b/ee/hogai/trends/nodes.py @@ -1,12 +1,10 @@ import itertools import xml.etree.ElementTree as ET from functools import cached_property -from typing import Optional, Union, cast +from typing import Optional, cast from langchain.agents.format_scratchpad import format_log_to_str -from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser -from langchain_core.agents import AgentAction, AgentFinish -from langchain_core.exceptions import OutputParserException +from langchain_core.agents import AgentAction from langchain_core.messages import AIMessage as LangchainAssistantMessage from langchain_core.messages import BaseMessage, merge_message_runs from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate @@ -15,10 +13,20 @@ from langchain_openai import ChatOpenAI from pydantic import ValidationError from ee.hogai.hardcoded_definitions import hardcoded_prop_defs -from ee.hogai.trends.parsers import PydanticOutputParserException, parse_generated_trends_output +from ee.hogai.trends.parsers import ( + PydanticOutputParserException, + ReActParserException, + ReActParserMissingActionException, + parse_generated_trends_output, + parse_react_agent_output, +) from ee.hogai.trends.prompts import ( react_definitions_prompt, react_follow_up_prompt, + react_malformed_json_prompt, + react_missing_action_correction_prompt, + react_missing_action_prompt, + react_pydantic_validation_exception_prompt, react_scratchpad_prompt, react_system_prompt, react_user_prompt, @@ -80,40 +88,42 @@ class CreateTrendsPlanNode(AssistantNode): ) toolkit = TrendsAgentToolkit(self._team) - output_parser = ReActJsonSingleInputOutputParser() merger = merge_message_runs() - agent = prompt | merger | self._model | output_parser + agent = prompt | merger | self._model | parse_react_agent_output try: result = cast( - Union[AgentAction, AgentFinish], + AgentAction, agent.invoke( { "tools": toolkit.render_text_description(), "tool_names": ", ".join([t["name"] for t in toolkit.tools]), - "agent_scratchpad": format_log_to_str( - [(action, output) for action, output in intermediate_steps if output is not None] - ), + "agent_scratchpad": self._get_agent_scratchpad(intermediate_steps), }, config, ), ) - except OutputParserException as e: - text = str(e) - if e.send_to_llm: - observation = str(e.observation) - text = str(e.llm_output) + except ReActParserException as e: + if isinstance(e, ReActParserMissingActionException): + # When the agent doesn't output the "Action:" block, we need to correct the log and append the action block, + # so that it has a higher chance to recover. + corrected_log = str( + ChatPromptTemplate.from_template(react_missing_action_correction_prompt, template_format="mustache") + .format_messages(output=e.llm_output)[0] + .content + ) + result = AgentAction( + "handle_incorrect_response", + react_missing_action_prompt, + corrected_log, + ) else: - observation = "Invalid or incomplete response. You must use the provided tools and output JSON to answer the user's question." - result = AgentAction("handle_incorrect_response", observation, text) - - if isinstance(result, AgentFinish): - # Exceptional case - return { - "plan": result.log, - "intermediate_steps": None, - } + result = AgentAction( + "handle_incorrect_response", + react_malformed_json_prompt, + e.llm_output, + ) return { "intermediate_steps": [*intermediate_steps, (result, None)], @@ -205,6 +215,14 @@ class CreateTrendsPlanNode(AssistantNode): return conversation + def _get_agent_scratchpad(self, scratchpad: list[tuple[AgentAction, str | None]]) -> str: + actions = [] + for action, observation in scratchpad: + if observation is None: + continue + actions.append((action, observation)) + return format_log_to_str(actions) + class CreateTrendsPlanToolsNode(AssistantNode): name = AssistantNodeName.CREATE_TRENDS_PLAN_TOOLS @@ -217,8 +235,12 @@ class CreateTrendsPlanToolsNode(AssistantNode): try: input = TrendsAgentToolModel.model_validate({"name": action.tool, "arguments": action.tool_input}).root except ValidationError as e: - feedback = f"Invalid tool call. Pydantic exception: {e.errors(include_url=False)}" - return {"intermediate_steps": [*intermediate_steps, (action, feedback)]} + observation = ( + ChatPromptTemplate.from_template(react_pydantic_validation_exception_prompt, template_format="mustache") + .format_messages(exception=e.errors(include_url=False))[0] + .content + ) + return {"intermediate_steps": [*intermediate_steps[:-1], (action, observation)]} # The plan has been found. Move to the generation. if input.name == "final_answer": diff --git a/ee/hogai/trends/parsers.py b/ee/hogai/trends/parsers.py index a461c692d88..e66f9745769 100644 --- a/ee/hogai/trends/parsers.py +++ b/ee/hogai/trends/parsers.py @@ -1,10 +1,66 @@ import json +import re +from langchain_core.agents import AgentAction +from langchain_core.messages import AIMessage as LangchainAIMessage from pydantic import ValidationError from ee.hogai.trends.utils import GenerateTrendOutputModel +class ReActParserException(ValueError): + llm_output: str + + def __init__(self, llm_output: str): + super().__init__(llm_output) + self.llm_output = llm_output + + +class ReActParserMalformedJsonException(ReActParserException): + pass + + +class ReActParserMissingActionException(ReActParserException): + """ + The ReAct agent didn't output the "Action:" block. + """ + + pass + + +ACTION_LOG_PREFIX = "Action:" + + +def parse_react_agent_output(message: LangchainAIMessage) -> AgentAction: + """ + A ReAct agent must output in this format: + + Some thoughts... + Action: + ```json + {"action": "action_name", "action_input": "action_input"} + ``` + """ + text = str(message.content) + if ACTION_LOG_PREFIX not in text: + raise ReActParserMissingActionException(text) + found = re.compile(r"^.*?`{3}(?:json)?\n?(.*?)`{3}.*?$", re.DOTALL).search(text) + if not found: + # JSON not found. + raise ReActParserMalformedJsonException(text) + try: + action = found.group(1).strip() + response = json.loads(action) + is_complete = "action" in response and "action_input" in response + except Exception: + # JSON is malformed or has a wrong type. + raise ReActParserMalformedJsonException(text) + if not is_complete: + # JSON does not contain an action. + raise ReActParserMalformedJsonException(text) + return AgentAction(response["action"], response.get("action_input", {}), text) + + class PydanticOutputParserException(ValueError): llm_output: str """Serialized LLM output.""" diff --git a/ee/hogai/trends/prompts.py b/ee/hogai/trends/prompts.py index 84c2bcedb54..2543b1efc26 100644 --- a/ee/hogai/trends/prompts.py +++ b/ee/hogai/trends/prompts.py @@ -178,6 +178,29 @@ react_follow_up_prompt = """ Improve the previously generated plan based on the feedback: {{feedback}} """ +react_missing_action_prompt = """ +Your previous answer didn't output the `Action:` block. You must always follow the format described in the system prompt. +""" + +react_missing_action_correction_prompt = """ +{{output}} +Action: I didn't output the `Action:` block. +""" + +react_malformed_json_prompt = """ +Your previous answer had a malformed JSON. You must return a correct JSON response containing the `action` and `action_input` fields. +""" + +react_pydantic_validation_exception_prompt = """ +The action input you previously provided didn't pass the validation and raised a Pydantic validation exception. + + +{{exception}} + + +You must fix the exception and try again. +""" + trends_system_prompt = """ You're a recognized head of product growth with the skills of a top-tier data engineer. Your task is to implement queries of trends insights for customers using a JSON schema. You will be given a plan describing series and breakdowns. Answer the user's questions as best you can. diff --git a/ee/hogai/trends/test/test_nodes.py b/ee/hogai/trends/test/test_nodes.py index 990f83ddca7..1e89c45458a 100644 --- a/ee/hogai/trends/test/test_nodes.py +++ b/ee/hogai/trends/test/test_nodes.py @@ -3,9 +3,15 @@ from unittest.mock import patch from django.test import override_settings from langchain_core.agents import AgentAction +from langchain_core.messages import AIMessage as LangchainAIMessage from langchain_core.runnables import RunnableLambda -from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode, GenerateTrendsToolsNode +from ee.hogai.trends.nodes import ( + CreateTrendsPlanNode, + CreateTrendsPlanToolsNode, + GenerateTrendsNode, + GenerateTrendsToolsNode, +) from ee.hogai.trends.utils import GenerateTrendOutputModel from ee.hogai.utils import AssistantNodeName from posthog.schema import ( @@ -115,6 +121,74 @@ class TestPlanAgentNode(ClickhouseTestMixin, APIBaseTest): self.assertIn("distinctevent", node._events_prompt) self.assertIn("all events", node._events_prompt) + def test_agent_scratchpad(self): + node = CreateTrendsPlanNode(self.team) + scratchpad = [ + (AgentAction(tool="test1", tool_input="input1", log="log1"), "test"), + (AgentAction(tool="test2", tool_input="input2", log="log2"), None), + (AgentAction(tool="test3", tool_input="input3", log="log3"), ""), + ] + prompt = node._get_agent_scratchpad(scratchpad) + self.assertIn("log1", prompt) + self.assertIn("log3", prompt) + + def test_agent_handles_output_without_action_block(self): + with patch( + "ee.hogai.trends.nodes.CreateTrendsPlanNode._model", + return_value=RunnableLambda(lambda _: LangchainAIMessage(content="I don't want to output an action.")), + ): + node = CreateTrendsPlanNode(self.team) + state_update = node.run({"messages": [HumanMessage(content="Question")]}, {}) + self.assertEqual(len(state_update["intermediate_steps"]), 1) + action, obs = state_update["intermediate_steps"][0] + self.assertIsNone(obs) + self.assertIn("I don't want to output an action.", action.log) + self.assertIn("Action:", action.log) + self.assertIn("Action:", action.tool_input) + + def test_agent_handles_output_with_malformed_json(self): + with patch( + "ee.hogai.trends.nodes.CreateTrendsPlanNode._model", + return_value=RunnableLambda(lambda _: LangchainAIMessage(content="Thought.\nAction: abc")), + ): + node = CreateTrendsPlanNode(self.team) + state_update = node.run({"messages": [HumanMessage(content="Question")]}, {}) + self.assertEqual(len(state_update["intermediate_steps"]), 1) + action, obs = state_update["intermediate_steps"][0] + self.assertIsNone(obs) + self.assertIn("Thought.\nAction: abc", action.log) + self.assertIn("action", action.tool_input) + self.assertIn("action_input", action.tool_input) + + +@override_settings(IN_UNIT_TESTING=True) +class TestCreateTrendsPlanToolsNode(ClickhouseTestMixin, APIBaseTest): + def test_node_handles_action_name_validation_error(self): + state = { + "intermediate_steps": [(AgentAction(tool="does not exist", tool_input="input", log="log"), "test")], + "messages": [], + } + node = CreateTrendsPlanToolsNode(self.team) + state_update = node.run(state, {}) + self.assertEqual(len(state_update["intermediate_steps"]), 1) + action, observation = state_update["intermediate_steps"][0] + self.assertIsNotNone(observation) + self.assertIn("", observation) + + def test_node_handles_action_input_validation_error(self): + state = { + "intermediate_steps": [ + (AgentAction(tool="retrieve_entity_property_values", tool_input="input", log="log"), "test") + ], + "messages": [], + } + node = CreateTrendsPlanToolsNode(self.team) + state_update = node.run(state, {}) + self.assertEqual(len(state_update["intermediate_steps"]), 1) + action, observation = state_update["intermediate_steps"][0] + self.assertIsNotNone(observation) + self.assertIn("", observation) + @override_settings(IN_UNIT_TESTING=True) class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest): diff --git a/ee/hogai/trends/test/test_parsers.py b/ee/hogai/trends/test/test_parsers.py new file mode 100644 index 00000000000..c32ff7f146b --- /dev/null +++ b/ee/hogai/trends/test/test_parsers.py @@ -0,0 +1,78 @@ +from langchain_core.messages import AIMessage as LangchainAIMessage + +from ee.hogai.trends.parsers import ( + ReActParserMalformedJsonException, + ReActParserMissingActionException, + parse_react_agent_output, +) +from posthog.test.base import BaseTest + + +class TestParsers(BaseTest): + def test_parse_react_agent_output(self): + res = parse_react_agent_output( + LangchainAIMessage( + content=""" + Some thoughts... + Action: + ```json + {"action": "action_name", "action_input": "action_input"} + ``` + """ + ) + ) + self.assertEqual(res.tool, "action_name") + self.assertEqual(res.tool_input, "action_input") + + res = parse_react_agent_output( + LangchainAIMessage( + content=""" + Some thoughts... + Action: + ``` + {"action": "tool", "action_input": {"key": "value"}} + ``` + """ + ) + ) + self.assertEqual(res.tool, "tool") + self.assertEqual(res.tool_input, {"key": "value"}) + + self.assertRaises( + ReActParserMissingActionException, parse_react_agent_output, LangchainAIMessage(content="Some thoughts...") + ) + self.assertRaises( + ReActParserMalformedJsonException, + parse_react_agent_output, + LangchainAIMessage(content="Some thoughts...\nAction: abc"), + ) + self.assertRaises( + ReActParserMalformedJsonException, + parse_react_agent_output, + LangchainAIMessage(content="Some thoughts...\nAction:"), + ) + self.assertRaises( + ReActParserMalformedJsonException, + parse_react_agent_output, + LangchainAIMessage(content="Some thoughts...\nAction: {}"), + ) + self.assertRaises( + ReActParserMalformedJsonException, + parse_react_agent_output, + LangchainAIMessage(content="Some thoughts...\nAction:\n```\n{}\n```"), + ) + self.assertRaises( + ReActParserMalformedJsonException, + parse_react_agent_output, + LangchainAIMessage(content="Some thoughts...\nAction:\n```\n{not a json}\n```"), + ) + self.assertRaises( + ReActParserMalformedJsonException, + parse_react_agent_output, + LangchainAIMessage(content='Some thoughts...\nAction:\n```\n{"action":"tool"}\n```'), + ) + self.assertRaises( + ReActParserMalformedJsonException, + parse_react_agent_output, + LangchainAIMessage(content='Some thoughts...\nAction:\n```\n{"action_input":"input"}\n```'), + )