mirror of
https://github.com/PostHog/posthog.git
synced 2024-11-28 09:16:49 +01:00
23bd1a010f
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
240 lines
9.5 KiB
Python
from collections.abc import Generator
|
|
from typing import Any, Literal, TypedDict, TypeGuard, Union
|
|
|
|
from langchain_core.messages import AIMessageChunk
|
|
from langfuse.callback import CallbackHandler
|
|
from langgraph.graph.state import StateGraph
|
|
from pydantic import BaseModel
|
|
from sentry_sdk import capture_exception
|
|
|
|
from ee import settings
|
|
from ee.hogai.funnels.nodes import (
|
|
FunnelGeneratorNode,
|
|
FunnelGeneratorToolsNode,
|
|
FunnelPlannerNode,
|
|
FunnelPlannerToolsNode,
|
|
)
|
|
from ee.hogai.router.nodes import RouterNode
|
|
from ee.hogai.schema_generator.nodes import SchemaGeneratorNode
|
|
from ee.hogai.summarizer.nodes import SummarizerNode
|
|
from ee.hogai.trends.nodes import (
|
|
TrendsGeneratorNode,
|
|
TrendsGeneratorToolsNode,
|
|
TrendsPlannerNode,
|
|
TrendsPlannerToolsNode,
|
|
)
|
|
from ee.hogai.utils import AssistantNodeName, AssistantState, Conversation
|
|
from posthog.models.team.team import Team
|
|
from posthog.schema import (
|
|
AssistantGenerationStatusEvent,
|
|
AssistantGenerationStatusType,
|
|
AssistantMessage,
|
|
FailureMessage,
|
|
VisualizationMessage,
|
|
)
|
|
|
|
# Langfuse tracing callback. It is only created when a public key is configured;
# otherwise it stays None and no tracing callback is passed to the graph.
langfuse_handler = (
    CallbackHandler(
        public_key=settings.LANGFUSE_PUBLIC_KEY,
        secret_key=settings.LANGFUSE_SECRET_KEY,
        host=settings.LANGFUSE_HOST,
    )
    if settings.LANGFUSE_PUBLIC_KEY
    else None
)
|
|
|
|
|
|
def is_value_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], dict[AssistantNodeName, AssistantState]]]:
    """
    Transition between nodes.

    Return True for ``(mode, payload)`` pairs emitted by the graph's ``"updates"`` stream mode.
    The payload maps each node that just ran to the state update it returned.

    Note: despite the function name, this matches the ``"updates"`` mode, not ``"values"``.
    The ``TypeGuard`` literal reflects the mode actually checked.
    """
    return len(update) == 2 and update[0] == "updates"
|
|
|
|
|
|
class LangGraphState(TypedDict):
    """Metadata paired with each chunk of a ``"messages"`` stream update."""

    # Name of the graph node that emitted the streamed chunk.
    langgraph_node: AssistantNodeName
|
|
|
|
|
|
def is_message_update(
    update: list[Any],
) -> TypeGuard[tuple[Literal["messages"], tuple[Union[AIMessageChunk, Any], LangGraphState]]]:
    """
    Streaming of messages.

    Return True for ``(mode, payload)`` pairs emitted by the ``"messages"`` stream mode.
    The payload is a ``(message, metadata)`` tuple.
    """
    if len(update) != 2:
        return False
    mode = update[0]
    return mode == "messages"
|
|
|
|
|
|
def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["values"], AssistantState]]:
    """
    Update of the state.

    Return True for ``(mode, payload)`` pairs emitted by the graph's ``"values"`` stream mode.
    The payload is the full graph state.

    Note: this matches the ``"values"`` mode. The ``TypeGuard`` literal previously said
    ``"updates"``, which contradicted the runtime check.
    """
    return len(update) == 2 and update[0] == "values"
|
|
|
|
|
|
# Generator nodes whose streamed output is a visualization schema, mapped to the node
# class whose `parse_output` turns accumulated tool-call arguments into a message.
VISUALIZATION_NODES: dict[AssistantNodeName, type[SchemaGeneratorNode]] = dict(
    (
        (AssistantNodeName.TRENDS_GENERATOR, TrendsGeneratorNode),
        (AssistantNodeName.FUNNEL_GENERATOR, FunnelGeneratorNode),
    )
)
|
|
|
|
|
|
class Assistant:
    """
    Build the assistant's LangGraph state machine and stream its output to the client.

    The router sends a conversation either to the trends pipeline or to the funnel pipeline.
    Each pipeline runs planner -> planner tools -> generator -> generator tools, and both
    pipelines finish in the summarizer.
    """

    _team: Team
    _graph: StateGraph

    def __init__(self, team: Team):
        self._team = team
        self._graph = StateGraph(AssistantState)

    def _compile_graph(self):
        """
        Wire every assistant node and edge into a fresh ``StateGraph`` and compile it.

        A new builder is created on each call. LangGraph's ``StateGraph.add_node`` rejects
        node names that are already registered, so reusing one builder would make a second
        ``stream()`` call on the same instance fail.
        """
        builder = StateGraph(AssistantState)
        self._graph = builder

        router_node = RouterNode(self._team)
        builder.add_node(AssistantNodeName.ROUTER, router_node.run)
        builder.add_edge(AssistantNodeName.START, AssistantNodeName.ROUTER)
        builder.add_conditional_edges(
            AssistantNodeName.ROUTER,
            router_node.router,
            path_map={"trends": AssistantNodeName.TRENDS_PLANNER, "funnel": AssistantNodeName.FUNNEL_PLANNER},
        )

        # Trends pipeline.
        create_trends_plan_node = TrendsPlannerNode(self._team)
        builder.add_node(AssistantNodeName.TRENDS_PLANNER, create_trends_plan_node.run)
        builder.add_conditional_edges(
            AssistantNodeName.TRENDS_PLANNER,
            create_trends_plan_node.router,
            path_map={
                "tools": AssistantNodeName.TRENDS_PLANNER_TOOLS,
            },
        )

        create_trends_plan_tools_node = TrendsPlannerToolsNode(self._team)
        builder.add_node(AssistantNodeName.TRENDS_PLANNER_TOOLS, create_trends_plan_tools_node.run)
        builder.add_conditional_edges(
            AssistantNodeName.TRENDS_PLANNER_TOOLS,
            create_trends_plan_tools_node.router,
            path_map={
                "continue": AssistantNodeName.TRENDS_PLANNER,
                "plan_found": AssistantNodeName.TRENDS_GENERATOR,
            },
        )

        generate_trends_node = TrendsGeneratorNode(self._team)
        builder.add_node(AssistantNodeName.TRENDS_GENERATOR, generate_trends_node.run)

        generate_trends_tools_node = TrendsGeneratorToolsNode(self._team)
        builder.add_node(AssistantNodeName.TRENDS_GENERATOR_TOOLS, generate_trends_tools_node.run)

        builder.add_edge(AssistantNodeName.TRENDS_GENERATOR_TOOLS, AssistantNodeName.TRENDS_GENERATOR)
        builder.add_conditional_edges(
            AssistantNodeName.TRENDS_GENERATOR,
            generate_trends_node.router,
            path_map={
                "tools": AssistantNodeName.TRENDS_GENERATOR_TOOLS,
                "next": AssistantNodeName.SUMMARIZER,
            },
        )

        # Funnel pipeline.
        funnel_planner = FunnelPlannerNode(self._team)
        builder.add_node(AssistantNodeName.FUNNEL_PLANNER, funnel_planner.run)
        builder.add_conditional_edges(
            AssistantNodeName.FUNNEL_PLANNER,
            funnel_planner.router,
            path_map={
                "tools": AssistantNodeName.FUNNEL_PLANNER_TOOLS,
            },
        )

        funnel_planner_tools = FunnelPlannerToolsNode(self._team)
        builder.add_node(AssistantNodeName.FUNNEL_PLANNER_TOOLS, funnel_planner_tools.run)
        builder.add_conditional_edges(
            AssistantNodeName.FUNNEL_PLANNER_TOOLS,
            funnel_planner_tools.router,
            path_map={
                "continue": AssistantNodeName.FUNNEL_PLANNER,
                "plan_found": AssistantNodeName.FUNNEL_GENERATOR,
            },
        )

        funnel_generator = FunnelGeneratorNode(self._team)
        builder.add_node(AssistantNodeName.FUNNEL_GENERATOR, funnel_generator.run)

        funnel_generator_tools_node = FunnelGeneratorToolsNode(self._team)
        builder.add_node(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, funnel_generator_tools_node.run)

        builder.add_edge(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, AssistantNodeName.FUNNEL_GENERATOR)
        builder.add_conditional_edges(
            AssistantNodeName.FUNNEL_GENERATOR,
            # Route with the funnel generator's own router (previously the trends node's router).
            funnel_generator.router,
            path_map={
                "tools": AssistantNodeName.FUNNEL_GENERATOR_TOOLS,
                "next": AssistantNodeName.SUMMARIZER,
            },
        )

        summarizer_node = SummarizerNode(self._team)
        builder.add_node(AssistantNodeName.SUMMARIZER, summarizer_node.run)
        builder.add_edge(AssistantNodeName.SUMMARIZER, AssistantNodeName.END)

        return builder.compile()

    def stream(self, conversation: Conversation) -> Generator[BaseModel, None, None]:
        """
        Run the assistant graph over ``conversation`` and yield messages for the client.

        Yields:
            - an ACK ``AssistantGenerationStatusEvent`` first, before any graph output;
            - the first message produced by the router, a visualization node, or the summarizer;
            - partial ``VisualizationMessage`` / ``AssistantMessage`` objects built from streamed chunks;
            - a ``GENERATION_ERROR`` status event when a visualization node returns intermediate
              steps instead of messages (schema validation failed);
            - a single ``FailureMessage`` if any exception escapes, after which streaming stops.
        """
        assistant_graph = self._compile_graph()
        callbacks = [langfuse_handler] if langfuse_handler else []
        messages = [message.root for message in conversation.messages]

        state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None}

        generator = assistant_graph.stream(
            state,
            config={"recursion_limit": 24, "callbacks": callbacks},
            stream_mode=["messages", "values", "updates"],
        )

        # Accumulates streamed message chunks of the node currently producing output.
        chunks = AIMessageChunk(content="")

        # Send a chunk to establish the connection avoiding the worker's timeout.
        yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)

        try:
            for update in generator:
                if is_state_update(update):
                    # Full-state snapshot ("values" mode): tracked, not forwarded to the client.
                    _, new_state = update
                    state = new_state

                elif is_value_update(update):
                    # Per-node update ("updates" mode), keyed by the node that just ran.
                    _, state_update = update

                    if (
                        AssistantNodeName.ROUTER in state_update
                        and "messages" in state_update[AssistantNodeName.ROUTER]
                    ):
                        yield state_update[AssistantNodeName.ROUTER]["messages"][0]
                    elif intersected_nodes := state_update.keys() & VISUALIZATION_NODES.keys():
                        # Reset chunks when schema validation fails.
                        chunks = AIMessageChunk(content="")

                        node_name = intersected_nodes.pop()
                        if "messages" in state_update[node_name]:
                            yield state_update[node_name]["messages"][0]
                        elif state_update[node_name].get("intermediate_steps", []):
                            yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR)
                    elif AssistantNodeName.SUMMARIZER in state_update:
                        chunks = AIMessageChunk(content="")
                        yield state_update[AssistantNodeName.SUMMARIZER]["messages"][0]

                elif is_message_update(update):
                    # Token-level stream ("messages" mode).
                    langchain_message, langgraph_state = update[1]
                    if not isinstance(langchain_message, AIMessageChunk):
                        continue

                    emitting_node = langgraph_state["langgraph_node"]
                    if emitting_node in VISUALIZATION_NODES:
                        chunks += langchain_message  # type: ignore
                        # The accumulated chunk may not contain a parsed tool call yet. Indexing
                        # an empty list would raise and abort the whole stream with a FailureMessage.
                        if not chunks.tool_calls:
                            continue
                        parsed_message = VISUALIZATION_NODES[emitting_node].parse_output(
                            chunks.tool_calls[0]["args"]
                        )
                        if parsed_message:
                            yield VisualizationMessage(
                                reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer
                            )
                    elif emitting_node == AssistantNodeName.SUMMARIZER:
                        chunks += langchain_message  # type: ignore
                        yield AssistantMessage(content=chunks.content)
        except Exception as e:
            capture_exception(e)
            yield FailureMessage()  # This is an unhandled error, so we just stop further generation at this point
|