2024-10-25 14:19:45 +02:00
|
|
|
import operator
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from collections.abc import Sequence
|
|
|
|
from enum import StrEnum
|
|
|
|
from typing import Annotated, Optional, TypedDict, Union
|
|
|
|
|
|
|
|
from langchain_core.agents import AgentAction
|
|
|
|
from langchain_core.runnables import RunnableConfig
|
|
|
|
from langgraph.graph import END, START
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
from posthog.models.team.team import Team
|
2024-10-30 14:40:56 +01:00
|
|
|
from posthog.schema import AssistantMessage, FailureMessage, HumanMessage, RootAssistantMessage, VisualizationMessage
|
2024-10-25 14:19:45 +02:00
|
|
|
|
2024-10-30 14:40:56 +01:00
|
|
|
# Every concrete message type that may appear in the ``messages`` sequence of
# ``AssistantState``.
AssistantMessageUnion = Union[
    AssistantMessage,
    HumanMessage,
    VisualizationMessage,
    FailureMessage,
]
|
2024-10-25 14:19:45 +02:00
|
|
|
|
|
|
|
|
|
|
|
class Conversation(BaseModel):
    """A conversation's message history together with its session identifier.

    Pydantic validation rejects a ``messages`` list that is empty or that
    holds more than 20 entries.
    """

    # Message history; both bounds are inclusive (1..20 entries). No default
    # is given, so the field is required.
    messages: Annotated[list[RootAssistantMessage], Field(min_length=1, max_length=20)]
    # Identifier of the session this conversation belongs to (required).
    session_id: str
|
|
|
|
|
|
|
|
|
|
|
|
class AssistantState(TypedDict):
    """State mapping handed to ``AssistantNode.run`` for each graph step."""

    # Message history. The ``operator.add`` reducer tells LangGraph to
    # concatenate messages returned by a node onto the existing sequence
    # instead of replacing it.
    messages: Annotated[Sequence[AssistantMessageUnion], operator.add]
    # Pairs of an ``AgentAction`` and an optional string (presumably the
    # tool's output for that action -- confirm against the tool nodes).
    intermediate_steps: list[tuple[AgentAction, str | None]] | None
    # Optional free-text plan; presumably written by the plan-creation node
    # -- confirm against its implementation.
    plan: str | None
|
|
|
|
|
|
|
|
|
|
|
|
class AssistantNodeName(StrEnum):
    """Names of the nodes that make up the assistant graph.

    Being a ``StrEnum``, every member is itself a ``str`` equal to its value.
    ``START`` and ``END`` therefore compare equal to LangGraph's own sentinel
    node names.
    """

    # LangGraph's entry/exit sentinels. On the right-hand side these names
    # resolve to the module-level imports, because the enum members do not
    # exist yet when each assignment is evaluated.
    START = START
    END = END

    # Trends planning: the planner node and its tool-execution companion.
    CREATE_TRENDS_PLAN = "create_trends_plan"
    CREATE_TRENDS_PLAN_TOOLS = "create_trends_plan_tools"

    # Trends generation: the generator node and its tool-execution companion.
    # NOTE(review): member name says GENERATE_TRENDS but the value is
    # "generate_trends_schema"; left as-is since the value is observable.
    GENERATE_TRENDS = "generate_trends_schema"
    GENERATE_TRENDS_TOOLS = "generate_trends_tools"
|
|
|
|
|
|
|
|
|
|
|
|
class AssistantNode(ABC):
    """Abstract base for a single node of the assistant graph.

    Concrete subclasses are expected to assign ``name`` and implement ``run``.
    """

    # Identifier of this node. Only declared here; each subclass is expected
    # to assign it (presumably the name the node is registered under -- confirm).
    name: AssistantNodeName
    # Team the node operates for; set once in ``__init__``.
    _team: Team

    def __init__(self, team: Team):
        """Bind the node to a team.

        Args:
            team: The team this node works on behalf of; stored as ``_team``.
        """
        self._team = team

    @abstractmethod
    def run(self, state: AssistantState, config: RunnableConfig):
        """Execute this node against the current assistant state.

        This is an instance method. The receiver was previously misnamed
        ``cls`` even though no ``@classmethod`` decorator is applied.

        Args:
            state: Current assistant state mapping.
            config: LangChain runnable configuration for this invocation.

        Returns:
            Defined by subclasses. Presumably a (partial) ``AssistantState``
            update for LangGraph to merge -- confirm against implementations.

        Raises:
            NotImplementedError: Always, in this abstract base implementation.
        """
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
def remove_line_breaks(line: str) -> str:
    """Return ``line`` with every line break replaced by a single space.

    Windows (CRLF), classic Mac (CR) and Unix (LF) line breaks are all
    recognised. Each break becomes exactly one space, so consecutive breaks
    yield consecutive spaces and a trailing break yields a trailing space.

    Args:
        line: Text that may contain line breaks.

    Returns:
        The text with no CR or LF characters left.
    """
    # Handle the two-character CRLF sequence first so it collapses to one
    # space rather than two; then handle any remaining lone CR or LF.
    return line.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")
|