0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-24 09:14:46 +01:00
posthog/ee/hogai/taxonomy_agent/test/test_nodes.py
Georgiy Tarasov 04f656bc67
fix(product-assistant): planner nodes didn't output tools (#26225)
Co-authored-by: Michael Matloka <michael@posthog.com>
2024-11-18 09:46:17 +00:00

227 lines
9.5 KiB
Python

from unittest.mock import patch
from django.test import override_settings
from langchain_core.agents import AgentAction
from langchain_core.messages import AIMessage as LangchainAIMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda
from ee.hogai.taxonomy_agent.nodes import (
ChatPromptTemplate,
TaxonomyAgentPlannerNode,
TaxonomyAgentPlannerToolsNode,
)
from ee.hogai.taxonomy_agent.toolkit import TaxonomyAgentToolkit, ToolkitTool
from ee.hogai.utils import AssistantState
from posthog.schema import (
AssistantMessage,
AssistantTrendsQuery,
FailureMessage,
HumanMessage,
VisualizationMessage,
)
from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person
class DummyToolkit(TaxonomyAgentToolkit):
    """Toolkit used by the tests below; it exposes only the base class's default tools."""

    def _get_tools(self) -> list[ToolkitTool]:
        """Return the toolkit's ``_default_tools`` unchanged."""
        default_tools = self._default_tools
        return default_tools
@override_settings(IN_UNIT_TESTING=True)
class TestTaxonomyAgentPlannerNode(ClickhouseTestMixin, APIBaseTest):
    """Tests for ``TaxonomyAgentPlannerNode``.

    Covers message-history reconstruction, the events prompt, the agent
    scratchpad, handling of malformed model output, and the ReAct format prompt.
    """

    def setUp(self):
        super().setUp()
        # Empty trends query reused as the ``answer`` of visualization messages.
        self.schema = AssistantTrendsQuery(series=[])

    def _get_node(self):
        """Build a concrete planner node that runs with a fixed prompt and ``DummyToolkit``."""

        class Node(TaxonomyAgentPlannerNode):
            def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
                prompt: ChatPromptTemplate = ChatPromptTemplate.from_messages([("user", "test")])
                toolkit = DummyToolkit(self._team)
                return super()._run_with_prompt_and_toolkit(state, prompt, toolkit, config=config)

        return Node(self.team)

    def _assert_human_text(self, message):
        """Assert that ``message`` is a human message carrying "Text" with no leftover placeholder."""
        self.assertEqual(message.type, "human")
        self.assertIn("Text", message.content)
        # Plain literal on purpose: the check is for the raw ``{{question}}`` placeholder text.
        self.assertNotIn("{{question}}", message.content)

    def test_agent_reconstructs_conversation(self):
        node = self._get_node()

        # A single human message.
        history = node._construct_messages({"messages": [HumanMessage(content="Text")]})
        self.assertEqual(len(history), 1)
        self._assert_human_text(history[0])

        # Human message followed by a visualization: the plan becomes the AI message content.
        history = node._construct_messages(
            {
                "messages": [
                    HumanMessage(content="Text"),
                    VisualizationMessage(answer=self.schema, plan="randomplan"),
                ]
            }
        )
        self.assertEqual(len(history), 2)
        self._assert_human_text(history[0])
        self.assertEqual(history[1].type, "ai")
        self.assertEqual(history[1].content, "randomplan")

        # A follow-up human message is appended after the AI message.
        history = node._construct_messages(
            {
                "messages": [
                    HumanMessage(content="Text"),
                    VisualizationMessage(answer=self.schema, plan="randomplan"),
                    HumanMessage(content="Text"),
                ]
            }
        )
        self.assertEqual(len(history), 3)
        self._assert_human_text(history[0])
        self.assertEqual(history[1].type, "ai")
        self.assertEqual(history[1].content, "randomplan")
        self._assert_human_text(history[2])

    def test_agent_reconstructs_conversation_and_omits_unknown_messages(self):
        node = self._get_node()
        history = node._construct_messages(
            {
                "messages": [
                    HumanMessage(content="Text"),
                    AssistantMessage(content="test"),
                ]
            }
        )
        # The plain assistant message is expected to be dropped from the history.
        self.assertEqual(len(history), 1)
        self._assert_human_text(history[0])

    def test_agent_reconstructs_conversation_with_failures(self):
        node = self._get_node()
        history = node._construct_messages(
            {
                "messages": [
                    HumanMessage(content="Text"),
                    FailureMessage(content="Error"),
                    HumanMessage(content="Text"),
                ]
            }
        )
        # Only one human message is expected to remain once the failure is accounted for.
        self.assertEqual(len(history), 1)
        self._assert_human_text(history[0])

    def test_agent_filters_out_low_count_events(self):
        _create_person(distinct_ids=["test"], team=self.team)
        for i in range(26):
            # One occurrence of each ``event{i}`` versus one ``distinctevent`` per iteration,
            # so ``distinctevent`` is the only event with a high count.
            _create_event(event=f"event{i}", distinct_id="test", team=self.team)
            _create_event(event="distinctevent", distinct_id="test", team=self.team)
        node = self._get_node()
        self.assertEqual(
            node._events_prompt,
            "<defined_events><event><name>All Events</name><description>All events. This is a wildcard that matches all events.</description></event><event><name>distinctevent</name></event></defined_events>",
        )

    def test_agent_preserves_low_count_events_for_smaller_teams(self):
        _create_person(distinct_ids=["test"], team=self.team)
        _create_event(event="distinctevent", distinct_id="test", team=self.team)
        node = self._get_node()
        self.assertIn("distinctevent", node._events_prompt)
        self.assertIn("all events", node._events_prompt)

    def test_agent_scratchpad(self):
        node = self._get_node()
        # Observations cover a non-empty string, ``None`` and an empty string.
        scratchpad = [
            (AgentAction(tool="test1", tool_input="input1", log="log1"), "test"),
            (AgentAction(tool="test2", tool_input="input2", log="log2"), None),
            (AgentAction(tool="test3", tool_input="input3", log="log3"), ""),
        ]
        prompt = node._get_agent_scratchpad(scratchpad)
        self.assertIn("log1", prompt)
        self.assertIn("log3", prompt)

    def test_agent_handles_output_without_action_block(self):
        # Replace the model with a runnable that answers without any action block.
        with patch(
            "ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model",
            return_value=RunnableLambda(lambda _: LangchainAIMessage(content="I don't want to output an action.")),
        ):
            node = self._get_node()
            state_update = node.run({"messages": [HumanMessage(content="Question")]}, {})
            self.assertEqual(len(state_update["intermediate_steps"]), 1)
            action, obs = state_update["intermediate_steps"][0]
            self.assertIsNone(obs)
            self.assertIn("I don't want to output an action.", action.log)
            self.assertIn("Action:", action.log)
            self.assertIn("Action:", action.tool_input)

    def test_agent_handles_output_with_malformed_json(self):
        # Replace the model with a runnable whose action payload is not valid JSON.
        with patch(
            "ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model",
            return_value=RunnableLambda(lambda _: LangchainAIMessage(content="Thought.\nAction: abc")),
        ):
            node = self._get_node()
            state_update = node.run({"messages": [HumanMessage(content="Question")]}, {})
            self.assertEqual(len(state_update["intermediate_steps"]), 1)
            action, obs = state_update["intermediate_steps"][0]
            self.assertIsNone(obs)
            self.assertIn("Thought.\nAction: abc", action.log)
            self.assertIn("action", action.tool_input)
            self.assertIn("action_input", action.tool_input)

    def test_node_outputs_all_events_prompt(self):
        node = self._get_node()
        self.assertIn("All Events", node._events_prompt)
        self.assertIn(
            "<event><name>All Events</name><description>All events. This is a wildcard that matches all events.</description></event>",
            node._events_prompt,
        )

    def test_format_prompt(self):
        node = self._get_node()
        # Build the prompt once and run every assertion against the same string.
        format_prompt = node._get_react_format_prompt(DummyToolkit(self.team))
        self.assertNotIn("Human:", format_prompt)
        self.assertIn("retrieve_event_properties,", format_prompt)
        self.assertIn("retrieve_event_properties(event_name: str)", format_prompt)
@override_settings(IN_UNIT_TESTING=True)
class TestTaxonomyAgentPlannerToolsNode(ClickhouseTestMixin, APIBaseTest):
    """Tests for how ``TaxonomyAgentPlannerToolsNode`` reports invalid agent actions."""

    def _get_node(self):
        """Build a concrete tools node that always runs with ``DummyToolkit``."""

        class Node(TaxonomyAgentPlannerToolsNode):
            def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
                return super()._run_with_toolkit(state, DummyToolkit(self._team), config=config)

        return Node(self.team)

    def test_node_handles_action_name_validation_error(self):
        # The tool name does not match any tool.
        bad_action = AgentAction(tool="does not exist", tool_input="input", log="log")
        state = {"intermediate_steps": [(bad_action, "test")], "messages": []}

        result = self._get_node().run(state, {})

        steps = result["intermediate_steps"]
        self.assertEqual(len(steps), 1)
        _, observation = steps[0]
        self.assertIsNotNone(observation)
        self.assertIn("<pydantic_exception>", observation)

    def test_node_handles_action_input_validation_error(self):
        # The tool name is a real one, but the input is a bare string.
        bad_action = AgentAction(tool="retrieve_entity_property_values", tool_input="input", log="log")
        state = {"intermediate_steps": [(bad_action, "test")], "messages": []}

        result = self._get_node().run(state, {})

        steps = result["intermediate_steps"]
        self.assertEqual(len(steps), 1)
        _, observation = steps[0]
        self.assertIsNotNone(observation)
        self.assertIn("<pydantic_exception>", observation)