import json from unittest.mock import patch from django.test import override_settings from langchain_core.agents import AgentAction from langchain_core.messages import AIMessage as LangchainAIMessage from langchain_core.runnables import RunnableLambda from ee.hogai.trends.nodes import ( CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode, GenerateTrendsToolsNode, ) from ee.hogai.trends.utils import GenerateTrendOutputModel from ee.hogai.utils import AssistantNodeName from posthog.schema import ( AssistantMessage, ExperimentalAITrendsQuery, FailureMessage, HumanMessage, VisualizationMessage, ) from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person @override_settings(IN_UNIT_TESTING=True) class TestPlanAgentNode(ClickhouseTestMixin, APIBaseTest): def setUp(self): super().setUp() self.schema = ExperimentalAITrendsQuery(series=[]) def test_agent_reconstructs_conversation(self): node = CreateTrendsPlanNode(self.team) history = node._reconstruct_conversation({"messages": [HumanMessage(content="Text")]}) self.assertEqual(len(history), 1) self.assertEqual(history[0].type, "human") self.assertIn("Text", history[0].content) self.assertNotIn(f"{{question}}", history[0].content) history = node._reconstruct_conversation( { "messages": [ HumanMessage(content="Text"), VisualizationMessage(answer=self.schema, plan="randomplan"), ] } ) self.assertEqual(len(history), 2) self.assertEqual(history[0].type, "human") self.assertIn("Text", history[0].content) self.assertNotIn("{{question}}", history[0].content) self.assertEqual(history[1].type, "ai") self.assertEqual(history[1].content, "randomplan") history = node._reconstruct_conversation( { "messages": [ HumanMessage(content="Text"), VisualizationMessage(answer=self.schema, plan="randomplan"), HumanMessage(content="Text"), ] } ) self.assertEqual(len(history), 3) self.assertEqual(history[0].type, "human") self.assertIn("Text", history[0].content) self.assertNotIn("{{question}}", history[0].content) self.assertEqual(history[1].type, "ai") self.assertEqual(history[1].content, "randomplan") self.assertEqual(history[2].type, "human") self.assertIn("Text", history[2].content) self.assertNotIn("{{question}}", history[2].content) def test_agent_reconstructs_conversation_and_omits_unknown_messages(self): node = CreateTrendsPlanNode(self.team) history = node._reconstruct_conversation( { "messages": [ HumanMessage(content="Text"), AssistantMessage(content="test"), ] } ) self.assertEqual(len(history), 1) self.assertEqual(history[0].type, "human") self.assertIn("Text", history[0].content) self.assertNotIn("{{question}}", history[0].content) def test_agent_reconstructs_conversation_with_failures(self): node = CreateTrendsPlanNode(self.team) history = node._reconstruct_conversation( { "messages": [ HumanMessage(content="Text"), FailureMessage(content="Error"), HumanMessage(content="Text"), ] } ) self.assertEqual(len(history), 1) self.assertEqual(history[0].type, "human") self.assertIn("Text", history[0].content) self.assertNotIn("{{question}}", history[0].content) def test_agent_filters_out_low_count_events(self): _create_person(distinct_ids=["test"], team=self.team) for i in range(26): _create_event(event=f"event{i}", distinct_id="test", team=self.team) _create_event(event="distinctevent", distinct_id="test", team=self.team) node = CreateTrendsPlanNode(self.team) self.assertEqual( node._events_prompt, "\nall events\ndistinctevent\n", ) def test_agent_preserves_low_count_events_for_smaller_teams(self): _create_person(distinct_ids=["test"], team=self.team) _create_event(event="distinctevent", distinct_id="test", team=self.team) node = CreateTrendsPlanNode(self.team) self.assertIn("distinctevent", node._events_prompt) self.assertIn("all events", node._events_prompt) def test_agent_scratchpad(self): node = CreateTrendsPlanNode(self.team) scratchpad = [ (AgentAction(tool="test1", tool_input="input1", log="log1"), "test"), (AgentAction(tool="test2", tool_input="input2", log="log2"), None), (AgentAction(tool="test3", tool_input="input3", log="log3"), ""), ] prompt = node._get_agent_scratchpad(scratchpad) self.assertIn("log1", prompt) self.assertIn("log3", prompt) def test_agent_handles_output_without_action_block(self): with patch( "ee.hogai.trends.nodes.CreateTrendsPlanNode._model", return_value=RunnableLambda(lambda _: LangchainAIMessage(content="I don't want to output an action.")), ): node = CreateTrendsPlanNode(self.team) state_update = node.run({"messages": [HumanMessage(content="Question")]}, {}) self.assertEqual(len(state_update["intermediate_steps"]), 1) action, obs = state_update["intermediate_steps"][0] self.assertIsNone(obs) self.assertIn("I don't want to output an action.", action.log) self.assertIn("Action:", action.log) self.assertIn("Action:", action.tool_input) def test_agent_handles_output_with_malformed_json(self): with patch( "ee.hogai.trends.nodes.CreateTrendsPlanNode._model", return_value=RunnableLambda(lambda _: LangchainAIMessage(content="Thought.\nAction: abc")), ): node = CreateTrendsPlanNode(self.team) state_update = node.run({"messages": [HumanMessage(content="Question")]}, {}) self.assertEqual(len(state_update["intermediate_steps"]), 1) action, obs = state_update["intermediate_steps"][0] self.assertIsNone(obs) self.assertIn("Thought.\nAction: abc", action.log) self.assertIn("action", action.tool_input) self.assertIn("action_input", action.tool_input) @override_settings(IN_UNIT_TESTING=True) class TestCreateTrendsPlanToolsNode(ClickhouseTestMixin, APIBaseTest): def test_node_handles_action_name_validation_error(self): state = { "intermediate_steps": [(AgentAction(tool="does not exist", tool_input="input", log="log"), "test")], "messages": [], } node = CreateTrendsPlanToolsNode(self.team) state_update = node.run(state, {}) self.assertEqual(len(state_update["intermediate_steps"]), 1) action, observation = state_update["intermediate_steps"][0] self.assertIsNotNone(observation) self.assertIn("", observation) def test_node_handles_action_input_validation_error(self): state = { "intermediate_steps": [ (AgentAction(tool="retrieve_entity_property_values", tool_input="input", log="log"), "test") ], "messages": [], } node = CreateTrendsPlanToolsNode(self.team) state_update = node.run(state, {}) self.assertEqual(len(state_update["intermediate_steps"]), 1) action, observation = state_update["intermediate_steps"][0] self.assertIsNotNone(observation) self.assertIn("", observation) @override_settings(IN_UNIT_TESTING=True) class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest): def setUp(self): self.schema = ExperimentalAITrendsQuery(series=[]) def test_node_runs(self): node = GenerateTrendsNode(self.team) with patch("ee.hogai.trends.nodes.GenerateTrendsNode._model") as generator_model_mock: generator_model_mock.return_value = RunnableLambda( lambda _: GenerateTrendOutputModel(reasoning_steps=["step"], answer=self.schema).model_dump() ) new_state = node.run( { "messages": [HumanMessage(content="Text")], "plan": "Plan", }, {}, ) self.assertEqual( new_state, { "messages": [VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"])], "intermediate_steps": None, }, ) def test_agent_reconstructs_conversation(self): node = GenerateTrendsNode(self.team) history = node._reconstruct_conversation({"messages": [HumanMessage(content="Text")]}) self.assertEqual(len(history), 2) self.assertEqual(history[0].type, "human") self.assertIn("mapping", history[0].content) self.assertEqual(history[1].type, "human") self.assertIn("Answer to this question:", history[1].content) self.assertNotIn("{{question}}", history[1].content) history = node._reconstruct_conversation({"messages": [HumanMessage(content="Text")], "plan": "randomplan"}) self.assertEqual(len(history), 3) self.assertEqual(history[0].type, "human") self.assertIn("mapping", history[0].content) self.assertEqual(history[1].type, "human") self.assertIn("the plan", history[1].content) self.assertNotIn("{{plan}}", history[1].content) self.assertIn("randomplan", history[1].content) self.assertEqual(history[2].type, "human") self.assertIn("Answer to this question:", history[2].content) self.assertNotIn("{{question}}", history[2].content) self.assertIn("Text", history[2].content) node = GenerateTrendsNode(self.team) history = node._reconstruct_conversation( { "messages": [ HumanMessage(content="Text"), VisualizationMessage(answer=self.schema, plan="randomplan"), HumanMessage(content="Follow Up"), ], "plan": "newrandomplan", } ) self.assertEqual(len(history), 6) self.assertEqual(history[0].type, "human") self.assertIn("mapping", history[0].content) self.assertEqual(history[1].type, "human") self.assertIn("the plan", history[1].content) self.assertNotIn("{{plan}}", history[1].content) self.assertIn("randomplan", history[1].content) self.assertEqual(history[2].type, "human") self.assertIn("Answer to this question:", history[2].content) self.assertNotIn("{{question}}", history[2].content) self.assertIn("Text", history[2].content) self.assertEqual(history[3].type, "ai") self.assertEqual(history[3].content, self.schema.model_dump_json()) self.assertEqual(history[4].type, "human") self.assertIn("the new plan", history[4].content) self.assertNotIn("{{plan}}", history[4].content) self.assertIn("newrandomplan", history[4].content) self.assertEqual(history[5].type, "human") self.assertIn("Answer to this question:", history[5].content) self.assertNotIn("{{question}}", history[5].content) self.assertIn("Follow Up", history[5].content) def test_agent_reconstructs_conversation_and_merges_messages(self): node = GenerateTrendsNode(self.team) history = node._reconstruct_conversation( { "messages": [HumanMessage(content="Te"), HumanMessage(content="xt")], "plan": "randomplan", } ) self.assertEqual(len(history), 3) self.assertEqual(history[0].type, "human") self.assertIn("mapping", history[0].content) self.assertEqual(history[1].type, "human") self.assertIn("the plan", history[1].content) self.assertNotIn("{{plan}}", history[1].content) self.assertIn("randomplan", history[1].content) self.assertEqual(history[2].type, "human") self.assertIn("Answer to this question:", history[2].content) self.assertNotIn("{{question}}", history[2].content) self.assertIn("Te\nxt", history[2].content) node = GenerateTrendsNode(self.team) history = node._reconstruct_conversation( { "messages": [ HumanMessage(content="Text"), VisualizationMessage(answer=self.schema, plan="randomplan"), HumanMessage(content="Follow"), HumanMessage(content="Up"), ], "plan": "newrandomplan", } ) self.assertEqual(len(history), 6) self.assertEqual(history[0].type, "human") self.assertIn("mapping", history[0].content) self.assertEqual(history[1].type, "human") self.assertIn("the plan", history[1].content) self.assertNotIn("{{plan}}", history[1].content) self.assertIn("randomplan", history[1].content) self.assertEqual(history[2].type, "human") self.assertIn("Answer to this question:", history[2].content) self.assertNotIn("{{question}}", history[2].content) self.assertIn("Text", history[2].content) self.assertEqual(history[3].type, "ai") self.assertEqual(history[3].content, self.schema.model_dump_json()) self.assertEqual(history[4].type, "human") self.assertIn("the new plan", history[4].content) self.assertNotIn("{{plan}}", history[4].content) self.assertIn("newrandomplan", history[4].content) self.assertEqual(history[5].type, "human") self.assertIn("Answer to this question:", history[5].content) self.assertNotIn("{{question}}", history[5].content) self.assertIn("Follow\nUp", history[5].content) def test_failover_with_incorrect_schema(self): node = GenerateTrendsNode(self.team) with patch("ee.hogai.trends.nodes.GenerateTrendsNode._model") as generator_model_mock: schema = GenerateTrendOutputModel(reasoning_steps=[], answer=None).model_dump() # Emulate an incorrect JSON. It should be an object. schema["answer"] = [] generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema)) new_state = node.run({"messages": [HumanMessage(content="Text")]}, {}) self.assertIn("intermediate_steps", new_state) self.assertEqual(len(new_state["intermediate_steps"]), 1) new_state = node.run( { "messages": [HumanMessage(content="Text")], "intermediate_steps": [(AgentAction(tool="", tool_input="", log="exception"), "exception")], }, {}, ) self.assertIn("intermediate_steps", new_state) self.assertEqual(len(new_state["intermediate_steps"]), 2) def test_node_leaves_failover(self): node = GenerateTrendsNode(self.team) with patch( "ee.hogai.trends.nodes.GenerateTrendsNode._model", return_value=RunnableLambda( lambda _: GenerateTrendOutputModel(reasoning_steps=[], answer=self.schema).model_dump() ), ): new_state = node.run( { "messages": [HumanMessage(content="Text")], "intermediate_steps": [(AgentAction(tool="", tool_input="", log="exception"), "exception")], }, {}, ) self.assertIsNone(new_state["intermediate_steps"]) new_state = node.run( { "messages": [HumanMessage(content="Text")], "intermediate_steps": [ (AgentAction(tool="", tool_input="", log="exception"), "exception"), (AgentAction(tool="", tool_input="", log="exception"), "exception"), ], }, {}, ) self.assertIsNone(new_state["intermediate_steps"]) def test_node_leaves_failover_after_second_unsuccessful_attempt(self): node = GenerateTrendsNode(self.team) with patch("ee.hogai.trends.nodes.GenerateTrendsNode._model") as generator_model_mock: schema = GenerateTrendOutputModel(reasoning_steps=[], answer=None).model_dump() # Emulate an incorrect JSON. It should be an object. schema["answer"] = [] generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema)) new_state = node.run( { "messages": [HumanMessage(content="Text")], "intermediate_steps": [ (AgentAction(tool="", tool_input="", log="exception"), "exception"), (AgentAction(tool="", tool_input="", log="exception"), "exception"), ], }, {}, ) self.assertIsNone(new_state["intermediate_steps"]) self.assertEqual(len(new_state["messages"]), 1) self.assertIsInstance(new_state["messages"][0], FailureMessage) def test_agent_reconstructs_conversation_with_failover(self): action = AgentAction(tool="fix", tool_input="validation error", log="exception") node = GenerateTrendsNode(self.team) history = node._reconstruct_conversation( { "messages": [HumanMessage(content="Text")], "plan": "randomplan", "intermediate_steps": [(action, "uniqexception")], }, "uniqexception", ) self.assertEqual(len(history), 4) self.assertEqual(history[0].type, "human") self.assertIn("mapping", history[0].content) self.assertEqual(history[1].type, "human") self.assertIn("the plan", history[1].content) self.assertNotIn("{{plan}}", history[1].content) self.assertIn("randomplan", history[1].content) self.assertEqual(history[2].type, "human") self.assertIn("Answer to this question:", history[2].content) self.assertNotIn("{{question}}", history[2].content) self.assertIn("Text", history[2].content) self.assertEqual(history[3].type, "human") self.assertIn("Pydantic", history[3].content) self.assertIn("uniqexception", history[3].content) def test_agent_reconstructs_conversation_with_failed_messages(self): node = GenerateTrendsNode(self.team) history = node._reconstruct_conversation( { "messages": [ HumanMessage(content="Text"), FailureMessage(content="Error"), HumanMessage(content="Text"), ], "plan": "randomplan", }, ) self.assertEqual(len(history), 3) self.assertEqual(history[0].type, "human") self.assertIn("mapping", history[0].content) self.assertEqual(history[1].type, "human") self.assertIn("the plan", history[1].content) self.assertNotIn("{{plan}}", history[1].content) self.assertIn("randomplan", history[1].content) self.assertEqual(history[2].type, "human") self.assertIn("Answer to this question:", history[2].content) self.assertNotIn("{{question}}", history[2].content) self.assertIn("Text", history[2].content) def test_router(self): node = GenerateTrendsNode(self.team) state = node.router({"messages": [], "intermediate_steps": None}) self.assertEqual(state, AssistantNodeName.END) state = node.router( {"messages": [], "intermediate_steps": [(AgentAction(tool="", tool_input="", log=""), None)]} ) self.assertEqual(state, AssistantNodeName.GENERATE_TRENDS_TOOLS) class TestGenerateTrendsToolsNode(ClickhouseTestMixin, APIBaseTest): def test_tools_node(self): node = GenerateTrendsToolsNode(self.team) action = AgentAction(tool="fix", tool_input="validationerror", log="pydanticexception") state = node.run({"messages": [], "intermediate_steps": [(action, None)]}, {}) self.assertIsNotNone("validationerror", state["intermediate_steps"][0][1]) self.assertIn("validationerror", state["intermediate_steps"][0][1]) self.assertIn("pydanticexception", state["intermediate_steps"][0][1])