mirror of
https://github.com/PostHog/posthog.git
synced 2024-11-24 09:14:46 +01:00
04f656bc67
Co-authored-by: Michael Matloka <michael@posthog.com>
227 lines
9.5 KiB
Python
227 lines
9.5 KiB
Python
from unittest.mock import patch
|
|
|
|
from django.test import override_settings
|
|
from langchain_core.agents import AgentAction
|
|
from langchain_core.messages import AIMessage as LangchainAIMessage
|
|
from langchain_core.runnables import RunnableConfig, RunnableLambda
|
|
|
|
from ee.hogai.taxonomy_agent.nodes import (
|
|
ChatPromptTemplate,
|
|
TaxonomyAgentPlannerNode,
|
|
TaxonomyAgentPlannerToolsNode,
|
|
)
|
|
from ee.hogai.taxonomy_agent.toolkit import TaxonomyAgentToolkit, ToolkitTool
|
|
from ee.hogai.utils import AssistantState
|
|
from posthog.schema import (
|
|
AssistantMessage,
|
|
AssistantTrendsQuery,
|
|
FailureMessage,
|
|
HumanMessage,
|
|
VisualizationMessage,
|
|
)
|
|
from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person
|
|
|
|
|
|
class DummyToolkit(TaxonomyAgentToolkit):
    """Toolkit stub used by the tests in this module.

    It exposes exactly the tools stored in ``_default_tools`` and adds none of its own.
    """

    def _get_tools(self) -> list[ToolkitTool]:
        """Return the toolkit's default tools unchanged."""
        default_tools = self._default_tools
        return default_tools
|
|
|
|
|
|
@override_settings(IN_UNIT_TESTING=True)
class TestTaxonomyAgentPlannerNode(ClickhouseTestMixin, APIBaseTest):
    """Tests for ``TaxonomyAgentPlannerNode``.

    Covers conversation reconstruction, the events prompt, the agent scratchpad,
    handling of malformed model output and the ReAct format prompt.
    """

    def setUp(self):
        super().setUp()
        # Empty trends query reused as the ``answer`` of visualization messages.
        self.schema = AssistantTrendsQuery(series=[])

    def _get_node(self):
        """Build a planner node whose ``run`` uses a fixed prompt and ``DummyToolkit``."""

        class Node(TaxonomyAgentPlannerNode):
            def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
                prompt: ChatPromptTemplate = ChatPromptTemplate.from_messages([("user", "test")])
                toolkit = DummyToolkit(self._team)
                return super()._run_with_prompt_and_toolkit(state, prompt, toolkit, config=config)

        return Node(self.team)

    def _assert_human_message(self, message, text="Text"):
        """Assert ``message`` is a human message containing ``text``.

        Also checks that the literal ``{{question}}`` placeholder does not appear
        in the message content.
        """
        self.assertEqual(message.type, "human")
        self.assertIn(text, message.content)
        # Plain string on purpose: an f-string would collapse the doubled braces
        # and check for a different literal.
        self.assertNotIn("{{question}}", message.content)

    def _assert_plan_message(self, message, plan):
        """Assert ``message`` is an AI message whose content equals ``plan``."""
        self.assertEqual(message.type, "ai")
        self.assertEqual(message.content, plan)

    def _run_with_model_output(self, content):
        """Run the node on a single question with ``_model`` patched to answer ``content``.

        Returns the state update produced by ``Node.run``.
        """
        with patch(
            "ee.hogai.taxonomy_agent.nodes.TaxonomyAgentPlannerNode._model",
            return_value=RunnableLambda(lambda _: LangchainAIMessage(content=content)),
        ):
            node = self._get_node()
            return node.run({"messages": [HumanMessage(content="Question")]}, {})

    def test_agent_reconstructs_conversation(self):
        node = self._get_node()

        # A single human message.
        history = node._construct_messages({"messages": [HumanMessage(content="Text")]})
        self.assertEqual(len(history), 1)
        self._assert_human_message(history[0])

        # Human message followed by a visualization carrying a plan.
        history = node._construct_messages(
            {
                "messages": [
                    HumanMessage(content="Text"),
                    VisualizationMessage(answer=self.schema, plan="randomplan"),
                ]
            }
        )
        self.assertEqual(len(history), 2)
        self._assert_human_message(history[0])
        self._assert_plan_message(history[1], "randomplan")

        # A follow-up human message after the visualization.
        history = node._construct_messages(
            {
                "messages": [
                    HumanMessage(content="Text"),
                    VisualizationMessage(answer=self.schema, plan="randomplan"),
                    HumanMessage(content="Text"),
                ]
            }
        )
        self.assertEqual(len(history), 3)
        self._assert_human_message(history[0])
        self._assert_plan_message(history[1], "randomplan")
        self._assert_human_message(history[2])

    def test_agent_reconstructs_conversation_and_omits_unknown_messages(self):
        node = self._get_node()
        history = node._construct_messages(
            {
                "messages": [
                    HumanMessage(content="Text"),
                    AssistantMessage(content="test"),
                ]
            }
        )
        self.assertEqual(len(history), 1)
        self._assert_human_message(history[0])

    def test_agent_reconstructs_conversation_with_failures(self):
        node = self._get_node()
        history = node._construct_messages(
            {
                "messages": [
                    HumanMessage(content="Text"),
                    FailureMessage(content="Error"),
                    HumanMessage(content="Text"),
                ]
            }
        )
        self.assertEqual(len(history), 1)
        self._assert_human_message(history[0])

    def test_agent_filters_out_low_count_events(self):
        _create_person(distinct_ids=["test"], team=self.team)
        for i in range(26):
            _create_event(event=f"event{i}", distinct_id="test", team=self.team)
            # Emitted on every iteration so that it is the only event expected
            # to remain in the prompt asserted below.
            _create_event(event="distinctevent", distinct_id="test", team=self.team)
        node = self._get_node()
        self.assertEqual(
            node._events_prompt,
            "<defined_events><event><name>All Events</name><description>All events. This is a wildcard that matches all events.</description></event><event><name>distinctevent</name></event></defined_events>",
        )

    def test_agent_preserves_low_count_events_for_smaller_teams(self):
        _create_person(distinct_ids=["test"], team=self.team)
        _create_event(event="distinctevent", distinct_id="test", team=self.team)
        node = self._get_node()
        events_prompt = node._events_prompt
        self.assertIn("distinctevent", events_prompt)
        self.assertIn("all events", events_prompt)

    def test_agent_scratchpad(self):
        node = self._get_node()
        # Three steps with a non-empty, a missing and an empty observation.
        scratchpad = [
            (AgentAction(tool="test1", tool_input="input1", log="log1"), "test"),
            (AgentAction(tool="test2", tool_input="input2", log="log2"), None),
            (AgentAction(tool="test3", tool_input="input3", log="log3"), ""),
        ]
        prompt = node._get_agent_scratchpad(scratchpad)
        self.assertIn("log1", prompt)
        self.assertIn("log3", prompt)

    def test_agent_handles_output_without_action_block(self):
        state_update = self._run_with_model_output("I don't want to output an action.")
        self.assertEqual(len(state_update["intermediate_steps"]), 1)
        action, obs = state_update["intermediate_steps"][0]
        self.assertIsNone(obs)
        self.assertIn("I don't want to output an action.", action.log)
        self.assertIn("Action:", action.log)
        self.assertIn("Action:", action.tool_input)

    def test_agent_handles_output_with_malformed_json(self):
        state_update = self._run_with_model_output("Thought.\nAction: abc")
        self.assertEqual(len(state_update["intermediate_steps"]), 1)
        action, obs = state_update["intermediate_steps"][0]
        self.assertIsNone(obs)
        self.assertIn("Thought.\nAction: abc", action.log)
        self.assertIn("action", action.tool_input)
        self.assertIn("action_input", action.tool_input)

    def test_node_outputs_all_events_prompt(self):
        node = self._get_node()
        events_prompt = node._events_prompt
        self.assertIn("All Events", events_prompt)
        self.assertIn(
            "<event><name>All Events</name><description>All events. This is a wildcard that matches all events.</description></event>",
            events_prompt,
        )

    def test_format_prompt(self):
        node = self._get_node()
        # Build the prompt once and run every assertion against the same text.
        format_prompt = node._get_react_format_prompt(DummyToolkit(self.team))
        self.assertNotIn("Human:", format_prompt)
        self.assertIn("retrieve_event_properties,", format_prompt)
        self.assertIn("retrieve_event_properties(event_name: str)", format_prompt)
|
|
|
|
|
|
@override_settings(IN_UNIT_TESTING=True)
class TestTaxonomyAgentPlannerToolsNode(ClickhouseTestMixin, APIBaseTest):
    """Tests for how ``TaxonomyAgentPlannerToolsNode`` reports invalid agent actions."""

    def _get_node(self):
        """Build a tools node whose ``run`` delegates to ``_run_with_toolkit`` with ``DummyToolkit``."""

        class Node(TaxonomyAgentPlannerToolsNode):
            def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
                dummy_toolkit = DummyToolkit(self._team)
                return super()._run_with_toolkit(state, dummy_toolkit, config=config)

        return Node(self.team)

    def _run_single_step(self, tool_name):
        """Run the node on a state holding one step for ``tool_name`` and return the state update."""
        pending_step = (AgentAction(tool=tool_name, tool_input="input", log="log"), "test")
        initial_state = {"intermediate_steps": [pending_step], "messages": []}
        return self._get_node().run(initial_state, {})

    def test_node_handles_action_name_validation_error(self):
        steps = self._run_single_step("does not exist")["intermediate_steps"]
        self.assertEqual(len(steps), 1)
        _, observation = steps[0]
        self.assertIsNotNone(observation)
        self.assertIn("<pydantic_exception>", observation)

    def test_node_handles_action_input_validation_error(self):
        steps = self._run_single_step("retrieve_entity_property_values")["intermediate_steps"]
        self.assertEqual(len(steps), 1)
        _, observation = steps[0]
        self.assertIsNotNone(observation)
        self.assertIn("<pydantic_exception>", observation)
|