0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-29 11:12:33 +01:00

fix(product-assistant): trim redundant queries in the schema generator (#26425)

This commit is contained in:
Georgiy Tarasov 2024-11-26 17:29:56 +01:00 committed by GitHub
parent 9b819fc8fc
commit e3ba870704
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 139 additions and 3 deletions

View File

@ -144,7 +144,10 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
human_messages, visualization_messages = filter_visualization_conversation(messages)
first_ai_message = True
for human_message, ai_message in itertools.zip_longest(human_messages, visualization_messages):
for idx, (human_message, ai_message) in enumerate(
itertools.zip_longest(human_messages, visualization_messages)
):
# Plans go first
if ai_message:
conversation.append(
HumanMessagePromptTemplate.from_template(
@ -161,6 +164,7 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
).format(plan=generated_plan)
)
# Then questions
if human_message:
conversation.append(
HumanMessagePromptTemplate.from_template(QUESTION_PROMPT, template_format="mustache").format(
@ -168,7 +172,8 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
)
)
if ai_message:
# Then schemas, but include only last generated schema because it doesn't need more context.
if ai_message and idx + 1 == len(visualization_messages):
conversation.append(
LangchainAssistantMessage(content=ai_message.answer.model_dump_json() if ai_message.answer else "")
)

View File

@ -9,9 +9,11 @@ from langchain_core.runnables import RunnableLambda
from ee.hogai.schema_generator.nodes import SchemaGeneratorNode, SchemaGeneratorToolsNode
from ee.hogai.schema_generator.utils import SchemaGeneratorOutput
from posthog.schema import (
AssistantMessage,
AssistantTrendsQuery,
FailureMessage,
HumanMessage,
RouterMessage,
VisualizationMessage,
)
from posthog.test.base import APIBaseTest, ClickhouseTestMixin
@ -169,6 +171,71 @@ class TestSchemaGeneratorNode(ClickhouseTestMixin, APIBaseTest):
self.assertNotIn("{{question}}", history[5].content)
self.assertIn("Follow\nUp", history[5].content)
def test_agent_reconstructs_typical_conversation(self):
    """A multi-turn conversation is rebuilt with every plan and question, but only one trailing AI schema message."""
    conversation = [
        HumanMessage(content="Question 1"),
        RouterMessage(content="trends"),
        VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
        AssistantMessage(content="Summary 1"),
        HumanMessage(content="Question 2"),
        RouterMessage(content="funnel"),
        VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
        AssistantMessage(content="Summary 2"),
        HumanMessage(content="Question 3"),
        RouterMessage(content="funnel"),
    ]
    history = DummyGeneratorNode(self.team)._construct_messages({"messages": conversation, "plan": "Plan 3"})

    # Each entry is (message type, substring expected in content); None skips the content check.
    expected = [
        ("human", "mapping"),
        ("human", "Plan 1"),
        ("human", "Question 1"),
        ("human", "Plan 2"),
        ("human", "Question 2"),
        ("ai", None),
        ("human", "Plan 3"),
        ("human", "Question 3"),
    ]
    self.assertEqual(len(history), len(expected))
    for message, (expected_type, expected_substring) in zip(history, expected):
        self.assertEqual(message.type, expected_type)
        if expected_substring is not None:
            self.assertIn(expected_substring, message.content)
def test_prompt(self):
    """run() should hand the model a trimmed prompt: system, human, a single AI schema, then the new human question."""
    conversation = [
        HumanMessage(content="Question 1"),
        RouterMessage(content="trends"),
        VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
        AssistantMessage(content="Summary 1"),
        HumanMessage(content="Question 2"),
        RouterMessage(content="funnel"),
        VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
        AssistantMessage(content="Summary 2"),
        HumanMessage(content="Question 3"),
        RouterMessage(content="funnel"),
    ]
    node = DummyGeneratorNode(self.team)

    def check_prompt(prompt):
        # Redundant schemas are trimmed, so only four messages should reach the model.
        self.assertEqual(len(prompt), 4)
        for index, expected_type in enumerate(["system", "human", "ai", "human"]):
            self.assertEqual(prompt[index].type, expected_type)

    with patch.object(DummyGeneratorNode, "_model") as model_mock:
        # Substitute the LLM with a runnable that asserts on the prompt it receives.
        model_mock.return_value = RunnableLambda(check_prompt)
        node.run({"messages": conversation, "plan": "Plan 3"}, {})
def test_failover_with_incorrect_schema(self):
node = DummyGeneratorNode(self.team)
with patch.object(DummyGeneratorNode, "_model") as generator_model_mock:

View File

@ -18,6 +18,7 @@ from posthog.schema import (
AssistantTrendsQuery,
FailureMessage,
HumanMessage,
RouterMessage,
VisualizationMessage,
)
from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person
@ -116,6 +117,36 @@ class TestTaxonomyAgentPlannerNode(ClickhouseTestMixin, APIBaseTest):
self.assertIn("Text", history[0].content)
self.assertNotIn("{{question}}", history[0].content)
def test_agent_reconstructs_typical_conversation(self):
    """A multi-turn conversation collapses to alternating question (human) and plan (ai) messages."""
    history = self._get_node()._construct_messages(
        {
            "messages": [
                HumanMessage(content="Question 1"),
                RouterMessage(content="trends"),
                VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
                AssistantMessage(content="Summary 1"),
                HumanMessage(content="Question 2"),
                RouterMessage(content="funnel"),
                VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
                AssistantMessage(content="Summary 2"),
                HumanMessage(content="Question 3"),
                RouterMessage(content="funnel"),
            ]
        }
    )
    self.assertEqual(len(history), 5)
    self.assertEqual([message.type for message in history], ["human", "ai", "human", "ai", "human"])
    # Questions are embedded in a templated prompt (substring check); plans pass through verbatim.
    self.assertIn("Question 1", history[0].content)
    self.assertEqual("Plan 1", history[1].content)
    self.assertIn("Question 2", history[2].content)
    self.assertEqual("Plan 2", history[3].content)
    self.assertIn("Question 3", history[4].content)
def test_agent_filters_out_low_count_events(self):
_create_person(distinct_ids=["test"], team=self.team)
for i in range(26):

View File

@ -1,7 +1,14 @@
from langchain_core.messages import HumanMessage as LangchainHumanMessage
from ee.hogai.utils import filter_visualization_conversation, merge_human_messages
from posthog.schema import AssistantTrendsQuery, FailureMessage, HumanMessage, VisualizationMessage
from posthog.schema import (
AssistantMessage,
AssistantTrendsQuery,
FailureMessage,
HumanMessage,
RouterMessage,
VisualizationMessage,
)
from posthog.test.base import BaseTest
@ -37,3 +44,29 @@ class TestTrendsUtils(BaseTest):
self.assertEqual(
visualization_messages, [VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="plan")]
)
def test_filters_typical_conversation(self):
    """Router and assistant summary messages are dropped; human and visualization messages are kept in order."""
    conversation = [
        HumanMessage(content="Question 1"),
        RouterMessage(content="trends"),
        VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
        AssistantMessage(content="Summary 1"),
        HumanMessage(content="Question 2"),
        RouterMessage(content="funnel"),
        VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
        AssistantMessage(content="Summary 2"),
    ]
    human_messages, visualization_messages = filter_visualization_conversation(conversation)

    self.assertEqual(len(human_messages), 2)
    self.assertEqual(len(visualization_messages), 2)
    # Human messages are converted to their Langchain equivalents.
    self.assertEqual(
        human_messages,
        [LangchainHumanMessage(content="Question 1"), LangchainHumanMessage(content="Question 2")],
    )
    self.assertEqual(
        visualization_messages,
        [
            VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 1"),
            VisualizationMessage(answer=AssistantTrendsQuery(series=[]), plan="Plan 2"),
        ],
    )