0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-25 02:31:30 +01:00
posthog/ee/hogai/utils.py

52 lines
1.5 KiB
Python
Raw Normal View History

import operator
from abc import ABC, abstractmethod
from collections.abc import Sequence
from enum import StrEnum
from typing import Annotated, Optional, TypedDict, Union
from langchain_core.agents import AgentAction
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START
from pydantic import BaseModel, Field
from posthog.models.team.team import Team
from posthog.schema import AssistantMessage, FailureMessage, HumanMessage, RootAssistantMessage, VisualizationMessage
# Union of the message types accepted in the ``messages`` sequence of
# ``AssistantState``.
AssistantMessageUnion = Union[
    AssistantMessage,
    HumanMessage,
    VisualizationMessage,
    FailureMessage,
]
class Conversation(BaseModel):
    """Validated conversation payload: a message list plus a session id."""

    # Required field; validation rejects lists with fewer than 1 or more
    # than 20 entries.
    messages: list[RootAssistantMessage] = Field(
        ...,
        min_length=1,
        max_length=20,
    )
    # Required string identifying the session.
    session_id: str
class AssistantState(TypedDict):
    """Typed dictionary describing the state passed to assistant nodes."""

    # Message sequence annotated with ``operator.add``.
    # NOTE(review): presumably the graph framework uses this annotation as a
    # reducer that concatenates new messages onto the existing ones - confirm.
    messages: Annotated[Sequence[AssistantMessageUnion], operator.add]
    # Optional list of (agent action, optional string) pairs.
    intermediate_steps: Optional[list[tuple[AgentAction, Optional[str]]]]
    # Optional plan text.
    plan: Optional[str]
class AssistantNodeName(StrEnum):
    """String identifiers for the nodes of the assistant graph."""

    # Entry and exit markers reuse the START / END constants imported at
    # module level from ``langgraph.graph``.
    START = START
    END = END

    # Trends planning nodes.
    CREATE_TRENDS_PLAN = "create_trends_plan"
    CREATE_TRENDS_PLAN_TOOLS = "create_trends_plan_tools"

    # Trends generation nodes.
    GENERATE_TRENDS = "generate_trends_schema"
    GENERATE_TRENDS_TOOLS = "generate_trends_tools"
class AssistantNode(ABC):
    """Abstract base class for assistant graph nodes.

    Concrete subclasses must implement :meth:`run`; the class cannot be
    instantiated until they do.
    """

    # Identifier of the node; declared here and expected to be set by
    # subclasses (no value is assigned in this base class).
    name: AssistantNodeName
    # Team the node was constructed with.
    _team: Team

    def __init__(self, team: Team) -> None:
        """Store the team on the instance.

        Args:
            team: The team this node is bound to.
        """
        self._team = team

    @abstractmethod
    def run(self, state: AssistantState, config: RunnableConfig):
        """Execute the node against the given state.

        This is an instance method, so the receiver is named ``self``
        (it was previously misnamed ``cls`` although no ``@classmethod``
        decorator is applied).

        Args:
            state: Current assistant state.
            config: Runnable configuration supplied for this invocation.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
def remove_line_breaks(line: str) -> str:
    """Return *line* with every newline character replaced by one space.

    Only ``"\\n"`` is replaced; other characters, including ``"\\r"``, are
    left untouched.
    """
    # Splitting on "\n" and re-joining with " " substitutes each newline
    # one-for-one, so consecutive newlines become consecutive spaces.
    segments = line.split("\n")
    return " ".join(segments)