0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-25 11:17:50 +01:00
posthog/ee/hogai/utils.py
Michael Matloka 23bd1a010f
feat(max): Summarize insight results (#26172)
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
2024-11-18 11:05:54 +01:00

115 lines
3.5 KiB
Python

import operator
from abc import ABC, abstractmethod
from collections.abc import Sequence
from enum import StrEnum
from typing import Annotated, Optional, TypedDict, Union
from jsonref import replace_refs
from langchain_core.agents import AgentAction
from langchain_core.messages import (
HumanMessage as LangchainHumanMessage,
merge_message_runs,
)
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START
from pydantic import BaseModel, Field
from posthog.models.team.team import Team
from posthog.schema import (
AssistantMessage,
FailureMessage,
HumanMessage,
RootAssistantMessage,
RouterMessage,
VisualizationMessage,
)
# Union of the message models that can be stored in the assistant's message history.
AssistantMessageUnion = Union[
    AssistantMessage,
    HumanMessage,
    VisualizationMessage,
    FailureMessage,
    RouterMessage,
]
class Conversation(BaseModel):
    """Validated conversation payload: a bounded list of messages plus a session id."""

    # Required field; validation rejects an empty list and anything above 20 messages.
    messages: list[RootAssistantMessage] = Field(..., min_length=1, max_length=20)

    # Identifier of the session the messages belong to.
    session_id: str
class AssistantState(TypedDict, total=False):
    """
    State dictionary passed between assistant nodes.

    Declared with ``total=False``, so every key may be absent.
    """

    # Message history; annotated with `operator.add`
    # (presumably consumed by langgraph to concatenate updates - confirm).
    messages: Annotated[Sequence[AssistantMessageUnion], operator.add]

    # Pairs of (agent action, optional string) collected so far, or None.
    intermediate_steps: Optional[list[tuple[AgentAction, Optional[str]]]]

    # Optional plan text.
    plan: Optional[str]
class AssistantNodeName(StrEnum):
    """String identifiers for the nodes of the assistant graph."""

    # Entry and exit markers reuse the constants imported from `langgraph.graph`.
    START = START
    END = END

    ROUTER = "router"

    # Trends nodes.
    TRENDS_PLANNER = "trends_planner"
    TRENDS_PLANNER_TOOLS = "trends_planner_tools"
    TRENDS_GENERATOR = "trends_generator"
    TRENDS_GENERATOR_TOOLS = "trends_generator_tools"

    # Funnel nodes.
    FUNNEL_PLANNER = "funnel_planner"
    FUNNEL_PLANNER_TOOLS = "funnel_planner_tools"
    FUNNEL_GENERATOR = "funnel_generator"
    FUNNEL_GENERATOR_TOOLS = "funnel_generator_tools"

    SUMMARIZER = "summarizer"
class AssistantNode(ABC):
    """
    Abstract base class for assistant nodes.

    Each node is constructed with the team it operates on and must implement `run`.
    """

    # Team supplied at construction time; available to subclasses.
    _team: Team

    def __init__(self, team: Team):
        """Remember the team this node was created for."""
        self._team = team

    @abstractmethod
    def run(self, state: AssistantState, config: RunnableConfig) -> AssistantState:
        """
        Execute the node.

        This is an instance method, so the receiver is named `self` (it was
        previously misnamed `cls` although the method is not a classmethod).

        Args:
            state: Current assistant state.
            config: Runnable configuration passed for this invocation.

        Returns:
            An `AssistantState` produced by the node.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
def remove_line_breaks(line: str) -> str:
    """Return `line` with every newline character replaced by a single space."""
    segments = line.split("\n")
    return " ".join(segments)
def merge_human_messages(messages: list[LangchainHumanMessage]) -> list[LangchainHumanMessage]:
    """
    Drops human messages whose content was already seen (keeping the first occurrence,
    in order) and passes the remaining ones through `merge_message_runs`.
    """
    # Dicts keep insertion order, so this retains the first message per distinct content.
    first_by_content: dict = {}
    for candidate in messages:
        first_by_content.setdefault(candidate.content, candidate)
    return merge_message_runs(list(first_by_content.values()))
def filter_visualization_conversation(
    messages: Sequence[AssistantMessageUnion],
) -> tuple[list[LangchainHumanMessage], list[VisualizationMessage]]:
    """
    Walks the history and separates it into human messages and answered visualization messages.

    Human messages are buffered and merged (via `merge_human_messages`) each time an answered
    visualization message is reached, and once more at the end for any leftovers. Messages of
    other types, and visualization messages without an answer, are ignored.
    """
    pending: list[LangchainHumanMessage] = []
    merged_humans: list[LangchainHumanMessage] = []
    answered_visualizations: list[VisualizationMessage] = []

    for item in messages:
        if isinstance(item, HumanMessage):
            pending.append(LangchainHumanMessage(content=item.content))
            continue

        if not (isinstance(item, VisualizationMessage) and item.answer):
            continue

        # An answered visualization closes the current run of human messages.
        if pending:
            merged_humans.extend(merge_human_messages(pending))
            pending = []
        answered_visualizations.append(item)

    # Flush human messages that came after the last visualization.
    if pending:
        merged_humans.extend(merge_human_messages(pending))

    return merged_humans, answered_visualizations
def dereference_schema(schema: dict) -> dict:
    """
    Resolve the schema's references with `replace_refs` (eagerly, without proxy objects)
    and return the result with the top-level "$defs" entry removed, if present.
    """
    resolved: dict = replace_refs(schema, proxies=False, lazy_load=False)
    # The definitions section is discarded from the returned schema.
    resolved.pop("$defs", None)
    return resolved