mirror of
https://github.com/PostHog/posthog.git
synced 2024-11-29 11:12:33 +01:00
fix(product-assistant): trim redundant queries in the schema generator (#26425)
This commit is contained in:
parent
9b819fc8fc
commit
e3ba870704
@ -144,7 +144,10 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
|
||||
human_messages, visualization_messages = filter_visualization_conversation(messages)
|
||||
first_ai_message = True
|
||||
|
||||
for human_message, ai_message in itertools.zip_longest(human_messages, visualization_messages):
|
||||
for idx, (human_message, ai_message) in enumerate(
|
||||
itertools.zip_longest(human_messages, visualization_messages)
|
||||
):
|
||||
# Plans go first
|
||||
if ai_message:
|
||||
conversation.append(
|
||||
HumanMessagePromptTemplate.from_template(
|
||||
@ -161,6 +164,7 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
|
||||
).format(plan=generated_plan)
|
||||
)
|
||||
|
||||
# Then questions
|
||||
if human_message:
|
||||
conversation.append(
|
||||
HumanMessagePromptTemplate.from_template(QUESTION_PROMPT, template_format="mustache").format(
|
||||
@ -168,7 +172,8 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
|
||||
)
|
||||
)
|
||||
|
||||
if ai_message:
|
||||
# Then schemas, but include only last generated schema because it doesn't need more context.
|
||||
if ai_message and idx + 1 == len(visualization_messages):
|
||||
conversation.append(
|
||||
LangchainAssistantMessage(content=ai_message.answer.model_dump_json() if ai_message.answer else "")
|
||||
)
|
||||
|
@ -9,9 +9,11 @@ from langchain_core.runnables import RunnableLambda
|
||||
from ee.hogai.schema_generator.nodes import SchemaGeneratorNode, SchemaGeneratorToolsNode
|
||||
from ee.hogai.schema_generator.utils import SchemaGeneratorOutput
|
||||
from posthog.schema import (
|
||||
AssistantMessage,
|
||||
AssistantTrendsQuery,
|
||||
FailureMessage,
|
||||
HumanMessage,
|
||||
RouterMessage,
|
||||
VisualizationMessage,
|
||||
)
|
||||
from posthog.test.base import APIBaseTest, ClickhouseTestMixin
|
||||
@ -169,6 +171,71 @@ class TestSchemaGeneratorNode(ClickhouseTestMixin, APIBaseTest):
|
||||
self.assertNotIn("{{question}}", history[5].content)
|
||||
self.assertIn("Follow\nUp", history[5].content)
|
||||
|
||||
def test_agent_reconstructs_typical_conversation(self):
|
||||
node = DummyGeneratorNode(self.team)
|
||||
history = node._construct_messages(
|
||||
{
|
||||
"messages": [
|
||||
HumanMessage(content="Question 1"),
|
||||
RouterMessage(content="trends"),
|
||||
VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
|
||||
AssistantMessage(content="Summary 1"),
|
||||
HumanMessage(content="Question 2"),
|
||||
RouterMessage(content="funnel"),
|
||||
VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
|
||||
AssistantMessage(content="Summary 2"),
|
||||
HumanMessage(content="Question 3"),
|
||||
RouterMessage(content="funnel"),
|
||||
],
|
||||
"plan": "Plan 3",
|
||||
}
|
||||
)
|
||||
self.assertEqual(len(history), 8)
|
||||
self.assertEqual(history[0].type, "human")
|
||||
self.assertIn("mapping", history[0].content)
|
||||
self.assertEqual(history[1].type, "human")
|
||||
self.assertIn("Plan 1", history[1].content)
|
||||
self.assertEqual(history[2].type, "human")
|
||||
self.assertIn("Question 1", history[2].content)
|
||||
self.assertEqual(history[3].type, "human")
|
||||
self.assertIn("Plan 2", history[3].content)
|
||||
self.assertEqual(history[4].type, "human")
|
||||
self.assertIn("Question 2", history[4].content)
|
||||
self.assertEqual(history[5].type, "ai")
|
||||
self.assertEqual(history[6].type, "human")
|
||||
self.assertIn("Plan 3", history[6].content)
|
||||
self.assertEqual(history[7].type, "human")
|
||||
self.assertIn("Question 3", history[7].content)
|
||||
|
||||
def test_prompt(self):
|
||||
node = DummyGeneratorNode(self.team)
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="Question 1"),
|
||||
RouterMessage(content="trends"),
|
||||
VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
|
||||
AssistantMessage(content="Summary 1"),
|
||||
HumanMessage(content="Question 2"),
|
||||
RouterMessage(content="funnel"),
|
||||
VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
|
||||
AssistantMessage(content="Summary 2"),
|
||||
HumanMessage(content="Question 3"),
|
||||
RouterMessage(content="funnel"),
|
||||
],
|
||||
"plan": "Plan 3",
|
||||
}
|
||||
with patch.object(DummyGeneratorNode, "_model") as generator_model_mock:
|
||||
|
||||
def assert_prompt(prompt):
|
||||
self.assertEqual(len(prompt), 4)
|
||||
self.assertEqual(prompt[0].type, "system")
|
||||
self.assertEqual(prompt[1].type, "human")
|
||||
self.assertEqual(prompt[2].type, "ai")
|
||||
self.assertEqual(prompt[3].type, "human")
|
||||
|
||||
generator_model_mock.return_value = RunnableLambda(assert_prompt)
|
||||
node.run(state, {})
|
||||
|
||||
def test_failover_with_incorrect_schema(self):
|
||||
node = DummyGeneratorNode(self.team)
|
||||
with patch.object(DummyGeneratorNode, "_model") as generator_model_mock:
|
||||
|
@ -18,6 +18,7 @@ from posthog.schema import (
|
||||
AssistantTrendsQuery,
|
||||
FailureMessage,
|
||||
HumanMessage,
|
||||
RouterMessage,
|
||||
VisualizationMessage,
|
||||
)
|
||||
from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person
|
||||
@ -116,6 +117,36 @@ class TestTaxonomyAgentPlannerNode(ClickhouseTestMixin, APIBaseTest):
|
||||
self.assertIn("Text", history[0].content)
|
||||
self.assertNotIn("{{question}}", history[0].content)
|
||||
|
||||
def test_agent_reconstructs_typical_conversation(self):
|
||||
node = self._get_node()
|
||||
history = node._construct_messages(
|
||||
{
|
||||
"messages": [
|
||||
HumanMessage(content="Question 1"),
|
||||
RouterMessage(content="trends"),
|
||||
VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
|
||||
AssistantMessage(content="Summary 1"),
|
||||
HumanMessage(content="Question 2"),
|
||||
RouterMessage(content="funnel"),
|
||||
VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
|
||||
AssistantMessage(content="Summary 2"),
|
||||
HumanMessage(content="Question 3"),
|
||||
RouterMessage(content="funnel"),
|
||||
]
|
||||
}
|
||||
)
|
||||
self.assertEqual(len(history), 5)
|
||||
self.assertEqual(history[0].type, "human")
|
||||
self.assertIn("Question 1", history[0].content)
|
||||
self.assertEqual(history[1].type, "ai")
|
||||
self.assertEqual(history[1].content, "Plan 1")
|
||||
self.assertEqual(history[2].type, "human")
|
||||
self.assertIn("Question 2", history[2].content)
|
||||
self.assertEqual(history[3].type, "ai")
|
||||
self.assertEqual(history[3].content, "Plan 2")
|
||||
self.assertEqual(history[4].type, "human")
|
||||
self.assertIn("Question 3", history[4].content)
|
||||
|
||||
def test_agent_filters_out_low_count_events(self):
|
||||
_create_person(distinct_ids=["test"], team=self.team)
|
||||
for i in range(26):
|
||||
|
@ -1,7 +1,14 @@
|
||||
from langchain_core.messages import HumanMessage as LangchainHumanMessage
|
||||
|
||||
from ee.hogai.utils import filter_visualization_conversation, merge_human_messages
|
||||
from posthog.schema import AssistantTrendsQuery, FailureMessage, HumanMessage, VisualizationMessage
|
||||
from posthog.schema import (
|
||||
AssistantMessage,
|
||||
AssistantTrendsQuery,
|
||||
FailureMessage,
|
||||
HumanMessage,
|
||||
RouterMessage,
|
||||
VisualizationMessage,
|
||||
)
|
||||
from posthog.test.base import BaseTest
|
||||
|
||||
|
||||
@ -37,3 +44,29 @@ class TestTrendsUtils(BaseTest):
|
||||
self.assertEqual(
|
||||
visualization_messages, [VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="plan")]
|
||||
)
|
||||
|
||||
def test_filters_typical_conversation(self):
|
||||
human_messages, visualization_messages = filter_visualization_conversation(
|
||||
[
|
||||
HumanMessage(content="Question 1"),
|
||||
RouterMessage(content="trends"),
|
||||
VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
|
||||
AssistantMessage(content="Summary 1"),
|
||||
HumanMessage(content="Question 2"),
|
||||
RouterMessage(content="funnel"),
|
||||
VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
|
||||
AssistantMessage(content="Summary 2"),
|
||||
]
|
||||
)
|
||||
self.assertEqual(len(human_messages), 2)
|
||||
self.assertEqual(len(visualization_messages), 2)
|
||||
self.assertEqual(
|
||||
human_messages, [LangchainHumanMessage(content="Question 1"), LangchainHumanMessage(content="Question 2")]
|
||||
)
|
||||
self.assertEqual(
|
||||
visualization_messages,
|
||||
[
|
||||
VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
|
||||
VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
|
||||
],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user