0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-22 08:40:03 +01:00
posthog/ee/hogai/trends/nodes.py

51 lines
2.0 KiB
Python

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from ee.hogai.schema_generator.nodes import SchemaGeneratorNode, SchemaGeneratorToolsNode
from ee.hogai.schema_generator.utils import SchemaGeneratorOutput
from ee.hogai.taxonomy_agent.nodes import TaxonomyAgentPlannerNode, TaxonomyAgentPlannerToolsNode
from ee.hogai.trends.prompts import REACT_SYSTEM_PROMPT, TRENDS_SYSTEM_PROMPT
from ee.hogai.trends.toolkit import TRENDS_SCHEMA, TrendsTaxonomyAgentToolkit
from ee.hogai.utils import AssistantState
from posthog.schema import AssistantTrendsQuery
class TrendsPlannerNode(TaxonomyAgentPlannerNode):
def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
toolkit = TrendsTaxonomyAgentToolkit(self._team)
prompt = ChatPromptTemplate.from_messages(
[
("system", REACT_SYSTEM_PROMPT),
],
template_format="mustache",
)
return super()._run_with_prompt_and_toolkit(state, prompt, toolkit, config=config)
class TrendsPlannerToolsNode(TaxonomyAgentPlannerToolsNode):
def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
toolkit = TrendsTaxonomyAgentToolkit(self._team)
return super()._run_with_toolkit(state, toolkit, config=config)
TrendsSchemaGeneratorOutput = SchemaGeneratorOutput[AssistantTrendsQuery]
class TrendsGeneratorNode(SchemaGeneratorNode[AssistantTrendsQuery]):
INSIGHT_NAME = "Trends"
OUTPUT_MODEL = TrendsSchemaGeneratorOutput
OUTPUT_SCHEMA = TRENDS_SCHEMA
def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
prompt = ChatPromptTemplate.from_messages(
[
("system", TRENDS_SYSTEM_PROMPT),
],
template_format="mustache",
)
return super()._run_with_prompt(state, prompt, config=config)
class TrendsGeneratorToolsNode(SchemaGeneratorToolsNode):
pass