2024-10-30 14:40:56 +01:00
|
|
|
from unittest.mock import patch
|
|
|
|
|
2024-10-25 14:19:45 +02:00
|
|
|
from django.test import override_settings
|
2024-10-30 14:40:56 +01:00
|
|
|
from langchain_core.runnables import RunnableLambda
|
2024-10-25 14:19:45 +02:00
|
|
|
|
2024-11-12 14:39:27 +01:00
|
|
|
from ee.hogai.trends.nodes import TrendsGeneratorNode, TrendsSchemaGeneratorOutput
|
2024-10-30 14:40:56 +01:00
|
|
|
from posthog.schema import (
|
2024-11-12 14:39:27 +01:00
|
|
|
AssistantTrendsQuery,
|
2024-10-30 14:40:56 +01:00
|
|
|
HumanMessage,
|
|
|
|
VisualizationMessage,
|
|
|
|
)
|
2024-11-12 14:39:27 +01:00
|
|
|
from posthog.test.base import APIBaseTest, ClickhouseTestMixin
|
2024-10-25 14:19:45 +02:00
|
|
|
|
|
|
|
|
|
|
|
@override_settings(IN_UNIT_TESTING=True)
class TestTrendsGeneratorNode(ClickhouseTestMixin, APIBaseTest):
    """Exercise ``TrendsGeneratorNode.run`` with the node's ``_model`` patched out."""

    def setUp(self):
        """Create the empty trends query that the fake model returns as its answer."""
        # Chain to the base classes so any per-test fixture setup they define
        # is not silently skipped by this override.
        super().setUp()
        self.schema = AssistantTrendsQuery(series=[])

    def test_node_runs(self):
        """The node turns the (mocked) model output into a ``VisualizationMessage``."""
        # NOTE(review): ``self.team`` is not set in this class; presumably it is
        # provided by the test base classes - confirm.
        node = TrendsGeneratorNode(self.team)
        with patch.object(TrendsGeneratorNode, "_model") as generator_model_mock:
            # Replace the model with a runnable that ignores its input and
            # returns a fixed, already-serialised generator output.
            generator_model_mock.return_value = RunnableLambda(
                lambda _: TrendsSchemaGeneratorOutput(reasoning_steps=["step"], answer=self.schema).model_dump()
            )
            # The node must run while the patch is active so the fake model is used.
            new_state = node.run(
                {
                    "messages": [HumanMessage(content="Text")],
                    "plan": "Plan",
                },
                {},
            )
            self.assertEqual(
                new_state,
                {
                    "messages": [VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"])],
                    "intermediate_steps": None,
                },
            )
|