0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-21 13:39:22 +01:00

feat(max): Replace reasoning points with generation progress

This commit is contained in:
Michael Matloka 2024-11-19 13:35:33 +01:00
parent a58ee30f95
commit ce91decac0
No known key found for this signature in database
18 changed files with 156 additions and 144 deletions

View File

@ -5,7 +5,6 @@ from langchain_core.messages import AIMessageChunk
from langfuse.callback import CallbackHandler
from langgraph.graph.state import StateGraph
from pydantic import BaseModel
from sentry_sdk import capture_exception
from ee import settings
from ee.hogai.funnels.nodes import (
@ -30,6 +29,7 @@ from posthog.schema import (
AssistantGenerationStatusType,
AssistantMessage,
FailureMessage,
ReasoningMessage,
VisualizationMessage,
)
@ -68,12 +68,30 @@ def is_state_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], As
return len(update) == 2 and update[0] == "values"
def is_task_started_update(
    update: list[Any],
) -> TypeGuard[tuple[Literal["debug"], dict[str, Any]]]:
    """
    Return whether this stream update is a "debug" event signalling that a graph node (task) has started.

    Used to emit a ReasoningMessage progress indicator as soon as a node begins executing.
    NOTE(review): the previous annotation said Literal["messages"], but the body checks for "debug";
    the payload is a dict with a "type" key (see the subscript access below) — exact shape comes from
    LangGraph's debug stream events, confirm against the langgraph docs.
    """
    return len(update) == 2 and update[0] == "debug" and update[1]["type"] == "task"
# Generator nodes whose output is surfaced to the user as a VisualizationMessage
# (presumably consumed elsewhere in this module to map node -> schema generator — confirm against callers).
VISUALIZATION_NODES: dict[AssistantNodeName, type[SchemaGeneratorNode]] = {
    AssistantNodeName.TRENDS_GENERATOR: TrendsGeneratorNode,
    AssistantNodeName.FUNNEL_GENERATOR: FunnelGeneratorNode,
}

# Human-readable progress text yielded as a ReasoningMessage when the given node starts
# (see the task-started handling in the streaming loop). Nodes absent from this map emit no progress message.
NODE_TO_REASONING_MESSAGE: dict[AssistantNodeName, str] = {
    AssistantNodeName.ROUTER: "Identifying type of analysis",
    AssistantNodeName.TRENDS_PLANNER: "Picking relevant events and properties",
    AssistantNodeName.FUNNEL_PLANNER: "Picking relevant events and properties",
    AssistantNodeName.TRENDS_GENERATOR: "Creating trends query",
    AssistantNodeName.FUNNEL_GENERATOR: "Creating funnel query",
}
class Assistant:
_team: Team
_graph: StateGraph
@ -185,7 +203,7 @@ class Assistant:
generator = assistant_graph.stream(
state,
config={"recursion_limit": 24, "callbacks": callbacks},
stream_mode=["messages", "values", "updates"],
stream_mode=["messages", "values", "updates", "debug"],
)
chunks = AIMessageChunk(content="")
@ -228,12 +246,15 @@ class Assistant:
chunks.tool_calls[0]["args"]
)
if parsed_message:
yield VisualizationMessage(
reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer
)
yield VisualizationMessage(answer=parsed_message.query)
elif langgraph_state["langgraph_node"] == AssistantNodeName.SUMMARIZER:
chunks += langchain_message # type: ignore
yield AssistantMessage(content=chunks.content)
except Exception as e:
capture_exception(e)
elif is_task_started_update(update):
_, task_update = update
node_name = task_update["payload"]["name"]
if reasoning_message := NODE_TO_REASONING_MESSAGE.get(node_name):
yield ReasoningMessage(content=reasoning_message)
except:
yield FailureMessage() # This is an unhandled error, so we just stop further generation at this point
raise # Re-raise, so that this is printed or goes into Sentry

View File

@ -21,7 +21,7 @@ class TestFunnelsGeneratorNode(ClickhouseTestMixin, APIBaseTest):
node = FunnelGeneratorNode(self.team)
with patch.object(FunnelGeneratorNode, "_model") as generator_model_mock:
generator_model_mock.return_value = RunnableLambda(
lambda _: FunnelsSchemaGeneratorOutput(reasoning_steps=["step"], answer=self.schema).model_dump()
lambda _: FunnelsSchemaGeneratorOutput(query=self.schema).model_dump()
)
new_state = node.run(
{
@ -33,9 +33,7 @@ class TestFunnelsGeneratorNode(ClickhouseTestMixin, APIBaseTest):
self.assertEqual(
new_state,
{
"messages": [
VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"], done=True)
],
"messages": [VisualizationMessage(answer=self.schema, plan="Plan", done=True)],
"intermediate_steps": None,
},
)

View File

@ -62,15 +62,10 @@ def generate_funnel_schema() -> dict:
"parameters": {
"type": "object",
"properties": {
"reasoning_steps": {
"type": "array",
"items": {"type": "string"},
"description": "The reasoning steps leading to the final conclusion that will be shown to the user. Use 'you' if you want to refer to the user.",
},
"answer": dereference_schema(schema),
"query": dereference_schema(schema),
},
"additionalProperties": False,
"required": ["reasoning_steps", "answer"],
"required": ["query"],
},
}

View File

@ -99,8 +99,7 @@ class SchemaGeneratorNode(AssistantNode, Generic[Q]):
"messages": [
VisualizationMessage(
plan=generated_plan,
reasoning_steps=message.reasoning_steps,
answer=message.answer,
answer=message.query,
done=True,
)
],

View File

@ -41,9 +41,7 @@ class TestSchemaGeneratorNode(ClickhouseTestMixin, APIBaseTest):
def test_node_runs(self):
node = DummyGeneratorNode(self.team)
with patch.object(DummyGeneratorNode, "_model") as generator_model_mock:
generator_model_mock.return_value = RunnableLambda(
lambda _: TestSchema(reasoning_steps=["step"], answer=self.schema).model_dump()
)
generator_model_mock.return_value = RunnableLambda(lambda _: TestSchema(query=self.schema).model_dump())
new_state = node.run(
{
"messages": [HumanMessage(content="Text")],
@ -54,9 +52,7 @@ class TestSchemaGeneratorNode(ClickhouseTestMixin, APIBaseTest):
self.assertEqual(
new_state,
{
"messages": [
VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"], done=True)
],
"messages": [VisualizationMessage(answer=self.schema, plan="Plan", done=True)],
"intermediate_steps": None,
},
)
@ -176,9 +172,9 @@ class TestSchemaGeneratorNode(ClickhouseTestMixin, APIBaseTest):
def test_failover_with_incorrect_schema(self):
node = DummyGeneratorNode(self.team)
with patch.object(DummyGeneratorNode, "_model") as generator_model_mock:
schema = TestSchema(reasoning_steps=[], answer=None).model_dump()
schema = TestSchema(query=None).model_dump()
# Emulate an incorrect JSON. It should be an object.
schema["answer"] = []
schema["query"] = []
generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema))
new_state = node.run({"messages": [HumanMessage(content="Text")]}, {})
@ -200,7 +196,7 @@ class TestSchemaGeneratorNode(ClickhouseTestMixin, APIBaseTest):
with patch.object(
DummyGeneratorNode,
"_model",
return_value=RunnableLambda(lambda _: TestSchema(reasoning_steps=[], answer=self.schema).model_dump()),
return_value=RunnableLambda(lambda _: TestSchema(query=self.schema).model_dump()),
):
new_state = node.run(
{
@ -226,9 +222,9 @@ class TestSchemaGeneratorNode(ClickhouseTestMixin, APIBaseTest):
def test_node_leaves_failover_after_second_unsuccessful_attempt(self):
node = DummyGeneratorNode(self.team)
with patch.object(DummyGeneratorNode, "_model") as generator_model_mock:
schema = TestSchema(reasoning_steps=[], answer=None).model_dump()
schema = TestSchema(query=None).model_dump()
# Emulate an incorrect JSON. It should be an object.
schema["answer"] = []
schema["query"] = []
generator_model_mock.return_value = RunnableLambda(lambda _: json.dumps(schema))
new_state = node.run(

View File

@ -6,5 +6,4 @@ T = TypeVar("T", bound=BaseModel)
class SchemaGeneratorOutput(BaseModel, Generic[T]):
reasoning_steps: Optional[list[str]] = None
answer: Optional[T] = None
query: Optional[T] = None

View File

@ -38,7 +38,6 @@ class TestSummarizerNode(ClickhouseTestMixin, APIBaseTest):
VisualizationMessage(
answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]),
plan="Plan",
reasoning_steps=["step"],
done=True,
),
],
@ -73,7 +72,6 @@ class TestSummarizerNode(ClickhouseTestMixin, APIBaseTest):
VisualizationMessage(
answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]),
plan="Plan",
reasoning_steps=["step"],
done=True,
),
],
@ -110,7 +108,6 @@ class TestSummarizerNode(ClickhouseTestMixin, APIBaseTest):
VisualizationMessage(
answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]),
plan="Plan",
reasoning_steps=["step"],
done=True,
),
],
@ -159,7 +156,6 @@ class TestSummarizerNode(ClickhouseTestMixin, APIBaseTest):
VisualizationMessage(
answer=None,
plan="Plan",
reasoning_steps=["step"],
done=True,
),
],
@ -180,7 +176,6 @@ class TestSummarizerNode(ClickhouseTestMixin, APIBaseTest):
VisualizationMessage(
answer=AssistantTrendsQuery(series=[AssistantTrendsEventsNode()]),
plan="Plan",
reasoning_steps=["step"],
done=True,
),
]

View File

@ -23,7 +23,7 @@ class TestTrendsGeneratorNode(ClickhouseTestMixin, APIBaseTest):
node = TrendsGeneratorNode(self.team)
with patch.object(TrendsGeneratorNode, "_model") as generator_model_mock:
generator_model_mock.return_value = RunnableLambda(
lambda _: TrendsSchemaGeneratorOutput(reasoning_steps=["step"], answer=self.schema).model_dump()
lambda _: TrendsSchemaGeneratorOutput(query=self.schema).model_dump()
)
new_state = node.run(
{
@ -35,9 +35,7 @@ class TestTrendsGeneratorNode(ClickhouseTestMixin, APIBaseTest):
self.assertEqual(
new_state,
{
"messages": [
VisualizationMessage(answer=self.schema, plan="Plan", reasoning_steps=["step"], done=True)
],
"messages": [VisualizationMessage(answer=self.schema, plan="Plan", done=True)],
"intermediate_steps": None,
},
)

View File

@ -65,15 +65,10 @@ def generate_trends_schema() -> dict:
"parameters": {
"type": "object",
"properties": {
"reasoning_steps": {
"type": "array",
"items": {"type": "string"},
"description": "The reasoning steps leading to the final conclusion that will be shown to the user. Use 'you' if you want to refer to the user.",
},
"answer": dereference_schema(schema),
"query": dereference_schema(schema),
},
"additionalProperties": False,
"required": ["reasoning_steps", "answer"],
"required": ["query"],
},
}

View File

@ -27,12 +27,13 @@ import {
FunnelsQuery,
InsightQueryNode,
LifecycleQuery,
NodeKind,
PathsQuery,
StickinessQuery,
TrendsQuery,
} from '~/queries/schema'
import {
isActionsNode,
isEventsNode,
isFunnelsQuery,
isInsightQueryWithBreakdown,
isInsightQueryWithSeries,
@ -149,15 +150,15 @@ function SeriesDisplay({
seriesIndex: number
}): JSX.Element {
const { mathDefinitions } = useValues(mathsLogic)
const filter = query.series[seriesIndex]
const series = query.series[seriesIndex]
const hasBreakdown = isInsightQueryWithBreakdown(query) && isValidBreakdown(query.breakdownFilter)
const mathDefinition = mathDefinitions[
isLifecycleQuery(query)
? 'dau'
: filter.math
? apiValueToMathType(filter.math, filter.math_group_type_index)
: series.math
? apiValueToMathType(series.math, series.math_group_type_index)
: 'total'
] as MathDefinition | undefined
@ -167,12 +168,12 @@ function SeriesDisplay({
className="SeriesDisplay"
icon={<SeriesLetter seriesIndex={seriesIndex} hasBreakdown={hasBreakdown} />}
extendedContent={
filter.properties &&
filter.properties.length > 0 && (
series.properties &&
series.properties.length > 0 && (
<CompactPropertyFiltersDisplay
groupFilter={{
type: FilterLogicalOperator.And,
values: [{ type: FilterLogicalOperator.And, values: filter.properties }],
values: [{ type: FilterLogicalOperator.And, values: series.properties }],
}}
embedded
/>
@ -181,34 +182,36 @@ function SeriesDisplay({
>
<span>
{isFunnelsQuery(query) ? 'Performed' : 'Showing'}
{filter.custom_name && <b> "{filter.custom_name}"</b>}
{filter.kind === NodeKind.ActionsNode && filter.id ? (
{series.custom_name && <b> "{series.custom_name}"</b>}
{isActionsNode(series) ? (
<Link
to={urls.action(filter.id)}
to={urls.action(series.id)}
className="SeriesDisplay__raw-name SeriesDisplay__raw-name--action"
title="Action series"
>
{filter.name}
{series.name}
</Link>
) : (
) : isEventsNode(series) ? (
<span className="SeriesDisplay__raw-name SeriesDisplay__raw-name--event" title="Event series">
<PropertyKeyInfo value={filter.name || '$pageview'} type={TaxonomicFilterGroupType.Events} />
<PropertyKeyInfo value={series.event || '$pageview'} type={TaxonomicFilterGroupType.Events} />
</span>
) : (
<i>{series.kind /* TODO: Support DataWarehouseNode */}</i>
)}
{!isFunnelsQuery(query) && (
<span className="leading-none">
counted by{' '}
{mathDefinition?.category === MathCategory.HogQLExpression ? (
<code>{filter.math_hogql}</code>
<code>{series.math_hogql}</code>
) : (
<>
{mathDefinition?.category === MathCategory.PropertyValue && filter.math_property && (
{mathDefinition?.category === MathCategory.PropertyValue && series.math_property && (
<>
{' '}
event's
<span className="SeriesDisplay__raw-name">
<PropertyKeyInfo
value={filter.math_property}
value={series.math_property}
type={TaxonomicFilterGroupType.EventProperties}
/>
</span>

View File

@ -1102,7 +1102,7 @@
"type": "object"
},
"AssistantMessageType": {
"enum": ["human", "ai", "ai/viz", "ai/failure", "ai/router"],
"enum": ["human", "ai", "ai/reasoning", "ai/viz", "ai/failure", "ai/router"],
"type": "string"
},
"AssistantMultipleBreakdownFilter": {
@ -11275,6 +11275,24 @@
"required": ["k", "t"],
"type": "object"
},
"ReasoningMessage": {
"additionalProperties": false,
"properties": {
"content": {
"type": "string"
},
"done": {
"const": true,
"type": "boolean"
},
"type": {
"const": "ai/reasoning",
"type": "string"
}
},
"required": ["type", "content", "done"],
"type": "object"
},
"RecordingOrder": {
"enum": [
"duration",
@ -11686,6 +11704,9 @@
{
"$ref": "#/definitions/VisualizationMessage"
},
{
"$ref": "#/definitions/ReasoningMessage"
},
{
"$ref": "#/definitions/AssistantMessage"
},
@ -12867,19 +12888,6 @@
"plan": {
"type": "string"
},
"reasoning_steps": {
"anyOf": [
{
"items": {
"type": "string"
},
"type": "array"
},
{
"type": "null"
}
]
},
"type": {
"const": "ai/viz",
"type": "string"

View File

@ -2472,6 +2472,7 @@ export type CachedActorsPropertyTaxonomyQueryResponse = CachedQueryResponse<Acto
export enum AssistantMessageType {
Human = 'human',
Assistant = 'ai',
Reasoning = 'ai/reasoning',
Visualization = 'ai/viz',
Failure = 'ai/failure',
Router = 'ai/router',
@ -2494,10 +2495,15 @@ export interface AssistantMessage {
done?: boolean
}
/**
 * Transient progress note streamed while the assistant is generating an answer
 * (e.g. "Creating trends query"). Always arrives complete — `done` is fixed to `true`.
 */
export interface ReasoningMessage {
    type: AssistantMessageType.Reasoning
    content: string
    done: true
}
export interface VisualizationMessage {
type: AssistantMessageType.Visualization
plan?: string
reasoning_steps?: string[] | null
answer?: AssistantTrendsQuery | AssistantFunnelsQuery
done?: boolean
}
@ -2517,6 +2523,7 @@ export interface RouterMessage {
export type RootAssistantMessage =
| VisualizationMessage
| ReasoningMessage
| AssistantMessage
| HumanMessage
| FailureMessage

View File

@ -4,10 +4,9 @@ import {
IconThumbsDownFilled,
IconThumbsUp,
IconThumbsUpFilled,
IconWarning,
IconX,
} from '@posthog/icons'
import { LemonButton, LemonInput, LemonRow, Spinner } from '@posthog/lemon-ui'
import { LemonButton, LemonInput, Spinner } from '@posthog/lemon-ui'
import clsx from 'clsx'
import { useActions, useValues } from 'kea'
import { BreakdownSummary, PropertiesSummary, SeriesSummary } from 'lib/components/Cards/InsightCard/InsightDetails'
@ -35,6 +34,7 @@ import {
isAssistantMessage,
isFailureMessage,
isHumanMessage,
isReasoningMessage,
isVisualizationMessage,
} from './utils'
@ -58,6 +58,8 @@ export function Thread(): JSX.Element | null {
return <TextAnswer key={index} message={message} index={index} />
} else if (isVisualizationMessage(message)) {
return <VisualizationAnswer key={index} message={message} status={message.status} />
} else if (isReasoningMessage(message)) {
return <div key={index}>{message.content}</div>
}
return null // We currently skip other types of messages
})}
@ -130,7 +132,7 @@ function VisualizationAnswer({
}: {
message: VisualizationMessage
status?: MessageStatus
}): JSX.Element {
}): JSX.Element | null {
const query = useMemo<InsightVizNode | null>(() => {
if (message.answer) {
return {
@ -143,54 +145,33 @@ function VisualizationAnswer({
return null
}, [message])
return (
<>
{message.reasoning_steps && (
<MessageTemplate
type="ai"
action={
status === 'error' && (
<LemonRow icon={<IconWarning />} status="warning" size="small">
Max is generating this answer one more time because the previous attempt has failed.
</LemonRow>
)
}
className={status === 'error' ? 'border-warning' : undefined}
>
<ul className="list-disc ml-4">
{message.reasoning_steps.map((step, index) => (
<li key={index}>{step}</li>
))}
</ul>
</MessageTemplate>
)}
{status === 'completed' && query && (
<>
<MessageTemplate type="ai">
<div className="h-96 flex">
<Query query={query} readOnly embedded />
</div>
<div className="relative mb-1">
<LemonButton
to={urls.insightNew(undefined, undefined, query)}
sideIcon={<IconOpenInNew />}
size="xsmall"
targetBlank
className="absolute right-0 -top-px"
>
Open as new insight
</LemonButton>
<SeriesSummary query={query.source} heading={<TopHeading query={query} />} />
<div className="flex flex-wrap gap-4 mt-1 *:grow">
<PropertiesSummary properties={query.source.properties} />
<BreakdownSummary query={query.source} />
</div>
</div>
</MessageTemplate>
</>
)}
</>
)
return status !== 'completed'
? null
: query && (
<>
<MessageTemplate type="ai">
<div className="h-96 flex">
<Query query={query} readOnly embedded />
</div>
<div className="relative mb-1">
<LemonButton
to={urls.insightNew(undefined, undefined, query)}
sideIcon={<IconOpenInNew />}
size="xsmall"
targetBlank
className="absolute right-0 -top-px"
>
Open as new insight
</LemonButton>
<SeriesSummary query={query.source} heading={<TopHeading query={query} />} />
<div className="flex flex-wrap gap-4 mt-1 *:grow">
<PropertiesSummary properties={query.source.properties} />
<BreakdownSummary query={query.source} />
</div>
</div>
</MessageTemplate>
</>
)
}
function RetriableAnswerActions(): JSX.Element {

View File

@ -1,6 +1,7 @@
import { AssistantGenerationStatusEvent, AssistantGenerationStatusType } from '~/queries/schema'
import failureMessage from './failureMessage.json'
import reasoningMessage from './reasoningMessage.json'
import summaryMessage from './summaryMessage.json'
import visualizationMessage from './visualizationMessage.json'
@ -9,12 +10,16 @@ function generateChunk(events: string[]): string {
}
export const chatResponseChunk = generateChunk([
'event: message',
`data: ${JSON.stringify(reasoningMessage)}`,
'event: message',
`data: ${JSON.stringify(visualizationMessage)}`,
'event: message',
`data: ${JSON.stringify(summaryMessage)}`,
])
export const chatMidwayResponseChunk = generateChunk(['event: message', `data: ${JSON.stringify(reasoningMessage)}`])
const generationFailure: AssistantGenerationStatusEvent = { type: AssistantGenerationStatusType.GenerationError }
const responseWithReasoningStepsOnly = {
...visualizationMessage,

View File

@ -0,0 +1,5 @@
{
"type": "ai/reasoning",
"content": "Considering available events and properties",
"done": true
}

View File

@ -1,16 +1,7 @@
{
"type": "ai/viz",
"plan": "Test plan",
"reasoning_steps": [
"The user's query is to identify the most popular pages.",
"To determine the most popular pages, we should analyze the '$pageview' event as it tracks when a user loads or reloads a page.",
"We need to use the '$current_url' property from the event properties to identify different pages.",
"Since the user didn't specify a date range, a reasonable default would be the last 30 days to get recent insights.",
"We should use a breakdown on the '$current_url' to see the popularity of each page URL.",
"A bar chart would be suitable for visualizing the most popular pages as it's categorical data.",
"Filter out internal and test users by default unless specified otherwise."
],
"answer": {
"query": {
"aggregation_group_type_index": null,
"breakdownFilter": {
"breakdown_hide_other_aggregation": null,

View File

@ -6,6 +6,7 @@ import {
FailureMessage,
FunnelsQuery,
HumanMessage,
ReasoningMessage,
RootAssistantMessage,
RouterMessage,
TrendsQuery,
@ -13,6 +14,10 @@ import {
} from '~/queries/schema'
import { isTrendsQuery } from '~/queries/utils'
/** Type guard narrowing a root assistant message to a ReasoningMessage (progress indicator). */
export function isReasoningMessage(message: RootAssistantMessage | undefined | null): message is ReasoningMessage {
    if (!message) {
        return false
    }
    return message.type === AssistantMessageType.Reasoning
}
export function isVisualizationMessage(
message: RootAssistantMessage | undefined | null
): message is VisualizationMessage {

View File

@ -261,6 +261,7 @@ class AssistantMessage(BaseModel):
class AssistantMessageType(StrEnum):
HUMAN = "human"
AI = "ai"
AI_REASONING = "ai/reasoning"
AI_VIZ = "ai/viz"
AI_FAILURE = "ai/failure"
AI_ROUTER = "ai/router"
@ -1395,6 +1396,15 @@ class QueryTiming(BaseModel):
t: float = Field(..., description="Time in seconds. Shortened to 't' to save on data.")
class ReasoningMessage(BaseModel):
    """Assistant progress update streamed while an answer is being generated (type "ai/reasoning")."""

    model_config = ConfigDict(
        extra="forbid",  # reject unknown fields — mirrors "additionalProperties": false in the JSON schema
    )
    content: str
    # Fixed to True: per the JSON schema this message kind is always emitted complete, never chunked.
    done: Literal[True] = True
    type: Literal["ai/reasoning"] = "ai/reasoning"
class RecordingOrder(StrEnum):
DURATION = "duration"
RECORDING_DURATION = "recording_duration"
@ -5549,7 +5559,6 @@ class VisualizationMessage(BaseModel):
answer: Optional[Union[AssistantTrendsQuery, AssistantFunnelsQuery]] = None
done: Optional[bool] = None
plan: Optional[str] = None
reasoning_steps: Optional[list[str]] = None
type: Literal["ai/viz"] = "ai/viz"
@ -5936,9 +5945,11 @@ class RetentionQuery(BaseModel):
class RootAssistantMessage(
RootModel[Union[VisualizationMessage, AssistantMessage, HumanMessage, FailureMessage, RouterMessage]]
RootModel[
Union[VisualizationMessage, ReasoningMessage, AssistantMessage, HumanMessage, FailureMessage, RouterMessage]
]
):
root: Union[VisualizationMessage, AssistantMessage, HumanMessage, FailureMessage, RouterMessage]
root: Union[VisualizationMessage, ReasoningMessage, AssistantMessage, HumanMessage, FailureMessage, RouterMessage]
class StickinessQuery(BaseModel):