mirror of
https://github.com/PostHog/posthog.git
synced 2024-11-22 08:40:03 +01:00
51 lines
2.0 KiB
Python
51 lines
2.0 KiB
Python
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_core.runnables import RunnableConfig
|
|
|
|
from ee.hogai.schema_generator.nodes import SchemaGeneratorNode, SchemaGeneratorToolsNode
|
|
from ee.hogai.schema_generator.utils import SchemaGeneratorOutput
|
|
from ee.hogai.taxonomy_agent.nodes import TaxonomyAgentPlannerNode, TaxonomyAgentPlannerToolsNode
|
|
from ee.hogai.trends.prompts import REACT_SYSTEM_PROMPT, TRENDS_SYSTEM_PROMPT
|
|
from ee.hogai.trends.toolkit import TRENDS_SCHEMA, TrendsTaxonomyAgentToolkit
|
|
from ee.hogai.utils import AssistantState
|
|
from posthog.schema import AssistantTrendsQuery
|
|
|
|
|
|
class TrendsPlannerNode(TaxonomyAgentPlannerNode):
|
|
def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
|
|
toolkit = TrendsTaxonomyAgentToolkit(self._team)
|
|
prompt = ChatPromptTemplate.from_messages(
|
|
[
|
|
("system", REACT_SYSTEM_PROMPT),
|
|
],
|
|
template_format="mustache",
|
|
)
|
|
return super()._run_with_prompt_and_toolkit(state, prompt, toolkit, config=config)
|
|
|
|
|
|
class TrendsPlannerToolsNode(TaxonomyAgentPlannerToolsNode):
|
|
def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
|
|
toolkit = TrendsTaxonomyAgentToolkit(self._team)
|
|
return super()._run_with_toolkit(state, toolkit, config=config)
|
|
|
|
|
|
TrendsSchemaGeneratorOutput = SchemaGeneratorOutput[AssistantTrendsQuery]
|
|
|
|
|
|
class TrendsGeneratorNode(SchemaGeneratorNode[AssistantTrendsQuery]):
|
|
INSIGHT_NAME = "Trends"
|
|
OUTPUT_MODEL = TrendsSchemaGeneratorOutput
|
|
OUTPUT_SCHEMA = TRENDS_SCHEMA
|
|
|
|
def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
|
|
prompt = ChatPromptTemplate.from_messages(
|
|
[
|
|
("system", TRENDS_SYSTEM_PROMPT),
|
|
],
|
|
template_format="mustache",
|
|
)
|
|
return super()._run_with_prompt(state, prompt, config=config)
|
|
|
|
|
|
class TrendsGeneratorToolsNode(SchemaGeneratorToolsNode):
|
|
pass
|