0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-24 18:07:17 +01:00
posthog/ee/hogai/utils.py
Georgiy Tarasov 8de5762cd1
feat(product-assistant): trends generation failover (#25769)
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
2024-10-30 13:40:56 +00:00

52 lines
1.5 KiB
Python

import operator
from abc import ABC, abstractmethod
from collections.abc import Sequence
from enum import StrEnum
from typing import Annotated, Optional, TypedDict, Union
from langchain_core.agents import AgentAction
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START
from pydantic import BaseModel, Field
from posthog.models.team.team import Team
from posthog.schema import AssistantMessage, FailureMessage, HumanMessage, RootAssistantMessage, VisualizationMessage
# Union of the concrete message schema types that may appear in the assistant's
# message history (used as the element type of AssistantState.messages).
# NOTE(review): member order may influence pydantic union validation; keep it
# unchanged unless that has been verified.
AssistantMessageUnion = Union[AssistantMessage, HumanMessage, VisualizationMessage, FailureMessage]
class Conversation(BaseModel):
    # Pydantic model holding a non-empty message history plus a session identifier.
    # presumably the request payload for the assistant endpoint — verify against callers.
    # (Comments rather than a docstring: a docstring would become the model's
    # JSON-schema description.)

    # Required; pydantic validation rejects lists with fewer than 1 or more than 20 items.
    messages: list[RootAssistantMessage] = Field(..., min_length=1, max_length=20)
    # Required session identifier; no format constraint is enforced here.
    session_id: str
class AssistantState(TypedDict):
    # State mapping handed to AssistantNode.run.

    # Message history. The `operator.add` metadata is a reducer hint: presumably
    # langgraph concatenates node updates onto the existing sequence instead of
    # replacing it — confirm against langgraph's state-reducer documentation.
    messages: Annotated[Sequence[AssistantMessageUnion], operator.add]
    # Agent actions, each paired with an optional string (presumably the tool
    # result / observation for that action — verify against the nodes that set it).
    intermediate_steps: Optional[list[tuple[AgentAction, Optional[str]]]]
    # Optional free-text plan.
    plan: Optional[str]
class AssistantNodeName(StrEnum):
    # Names of the nodes in the assistant graph. As a StrEnum, each member is a str
    # equal to its value, so members can be passed wherever a plain node name is expected.

    # START/END reuse langgraph's own entry/exit identifiers. Inside the class body
    # the right-hand names are not yet enum members, so they resolve to the
    # module-level `START`/`END` imported from langgraph.graph.
    START = START
    END = END
    CREATE_TRENDS_PLAN = "create_trends_plan"
    CREATE_TRENDS_PLAN_TOOLS = "create_trends_plan_tools"
    # NOTE(review): the value is "generate_trends_schema", not "generate_trends" like
    # the member name suggests; left unchanged because the string may be referenced
    # elsewhere (graph wiring, persisted state).
    GENERATE_TRENDS = "generate_trends_schema"
    GENERATE_TRENDS_TOOLS = "generate_trends_tools"
class AssistantNode(ABC):
    """Abstract base class for a node of the assistant graph.

    Concrete subclasses set the ``name`` class attribute and implement :meth:`run`.
    The team passed to the constructor is stored on ``_team`` for use by subclasses.
    """

    # Identifier of this node; only annotated here, so each concrete subclass is
    # expected to assign it.
    name: AssistantNodeName
    _team: Team

    def __init__(self, team: Team):
        self._team = team

    @abstractmethod
    def run(self, state: AssistantState, config: RunnableConfig):
        """Execute this node against the current assistant state.

        This is an instance method, so the first parameter is ``self``. It was
        previously misnamed ``cls``, although the method is not a classmethod.

        Args:
            state: The assistant state at the time the node runs.
            config: The langchain runnable configuration for the current invocation.

        Returns:
            Defined by subclasses. Presumably a partial ``AssistantState`` update;
            confirm against concrete implementations.

        Raises:
            NotImplementedError: If a subclass calls this base implementation.
        """
        raise NotImplementedError
def remove_line_breaks(line: str) -> str:
    """Return *line* with every ``"\\n"`` character turned into a single space.

    Only LF is handled: a ``"\\r\\n"`` pair keeps its ``"\\r"``, and consecutive
    newlines become consecutive spaces (no collapsing or stripping).
    """
    pieces = line.split("\n")
    return " ".join(pieces)