From 544e239f40559ff1a86d87c3b44ed1746555259b Mon Sep 17 00:00:00 2001 From: Georgiy Tarasov Date: Fri, 25 Oct 2024 21:18:54 +0200 Subject: [PATCH] fix(product-assistant): correctly establish a connection for streaming (#25826) --- ee/hogai/assistant.py | 3 + frontend/src/scenes/max/Max.stories.tsx | 4 +- .../max/__mocks__/chatResponse.mocks.ts | 3 + frontend/src/scenes/max/maxLogic.ts | 77 +++++++------------ package.json | 1 + pnpm-lock.yaml | 11 ++- posthog/api/query.py | 2 +- 7 files changed, 46 insertions(+), 55 deletions(-) create mode 100644 frontend/src/scenes/max/__mocks__/chatResponse.mocks.ts diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py index d1aa9656257..e47020fdcdf 100644 --- a/ee/hogai/assistant.py +++ b/ee/hogai/assistant.py @@ -79,6 +79,9 @@ class Assistant: chunks = AIMessageChunk(content="") + # Send a chunk to establish the connection avoiding the worker's timeout. + yield "" + for update in generator: if is_value_update(update): _, state_update = update diff --git a/frontend/src/scenes/max/Max.stories.tsx b/frontend/src/scenes/max/Max.stories.tsx index 65106d4ae44..27045963c6e 100644 --- a/frontend/src/scenes/max/Max.stories.tsx +++ b/frontend/src/scenes/max/Max.stories.tsx @@ -4,7 +4,7 @@ import { useEffect } from 'react' import { mswDecorator, useStorybookMocks } from '~/mocks/browser' -import chatResponse from './__mocks__/chatResponse.json' +import { chatResponseChunk } from './__mocks__/chatResponse.mocks' import { MaxInstance } from './Max' import { maxLogic } from './maxLogic' @@ -13,7 +13,7 @@ const meta: Meta = { decorators: [ mswDecorator({ post: { - '/api/environments/:team_id/query/chat/': chatResponse, + '/api/environments/:team_id/query/chat/': (_, res, ctx) => res(ctx.text(chatResponseChunk)), }, }), ], diff --git a/frontend/src/scenes/max/__mocks__/chatResponse.mocks.ts b/frontend/src/scenes/max/__mocks__/chatResponse.mocks.ts new file mode 100644 index 00000000000..18b82c1947d --- /dev/null +++ b/frontend/src/scenes/max/__mocks__/chatResponse.mocks.ts @@ -0,0 +1,3 @@ +import chatResponse from './chatResponse.json' + +export const chatResponseChunk = `data: ${JSON.stringify(chatResponse)}\n\n` diff --git a/frontend/src/scenes/max/maxLogic.ts b/frontend/src/scenes/max/maxLogic.ts index 69d53bf956b..16bc6fa87b2 100644 --- a/frontend/src/scenes/max/maxLogic.ts +++ b/frontend/src/scenes/max/maxLogic.ts @@ -1,4 +1,5 @@ import { shuffle } from 'd3' +import { createParser } from 'eventsource-parser' import { actions, kea, key, listeners, path, props, reducers, selectors } from 'kea' import { loaders } from 'kea-loaders' import api from 'lib/api' @@ -118,21 +119,23 @@ export const maxLogic = kea([ messages: values.thread.map(({ status, ...message }) => message), }) const reader = response.body?.getReader() + + if (!reader) { + return + } + const decoder = new TextDecoder() - if (reader) { - let firstChunk = true + let firstChunk = true - while (true) { - const { done, value } = await reader.read() - if (done) { - actions.setMessageStatus(newIndex, 'completed') - break + const parser = createParser({ + onEvent: (event) => { + const parsedResponse = parseResponse(event.data) + + if (!parsedResponse) { + return } - const text = decoder.decode(value) - const parsedResponse = parseResponse(text) - if (firstChunk) { firstChunk = false @@ -145,6 +148,17 @@ export const maxLogic = kea([ status: 'loading', }) } + }, + }) + + while (true) { + const { done, value } = await reader.read() + + parser.feed(decoder.decode(value)) + + if (done) { + actions.setMessageStatus(newIndex, 'completed') + break } } } catch { @@ -163,50 +177,11 @@ export const maxLogic = kea([ * Parses the generation result from the API. Some generation chunks might be sent in batches. * @param response */ -function parseResponse(response: string, recursive = true): RootAssistantMessage | null { +function parseResponse(response: string): RootAssistantMessage | null | undefined { try { const parsed = JSON.parse(response) - return parsed as RootAssistantMessage + return parsed as RootAssistantMessage | null | undefined } catch { - if (!recursive) { - return null - } - - const results: [number, number][] = [] - let pair: [number, number] = [0, 0] - let seq = 0 - - for (let i = 0; i < response.length; i++) { - const char = response[i] - - if (char === '{') { - if (seq === 0) { - pair[0] = i - } - - seq += 1 - } - - if (char === '}') { - seq -= 1 - if (seq === 0) { - pair[1] = i - } - } - - if (seq === 0) { - results.push(pair) - pair = [0, 0] - } - } - - const lastPair = results.pop() - - if (lastPair) { - const [left, right] = lastPair - return parseResponse(response.slice(left, right + 1), false) - } - return null } } diff --git a/package.json b/package.json index 65dd94c0645..a95fbb96416 100644 --- a/package.json +++ b/package.json @@ -124,6 +124,7 @@ "esbuild-plugin-less": "^1.3.1", "esbuild-plugin-polyfill-node": "^0.3.0", "esbuild-sass-plugin": "^3.0.0", + "eventsource-parser": "^3.0.0", "expr-eval": "^2.0.2", "express": "^4.17.1", "fast-deep-equal": "^3.1.3", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 7ccdb6be2be..213db7def97 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -193,6 +193,9 @@ dependencies: esbuild-sass-plugin: specifier: ^3.0.0 version: 3.0.0(esbuild@0.19.8) + eventsource-parser: + specifier: ^3.0.0 + version: 3.0.0 expr-eval: specifier: ^2.0.2 version: 2.0.2 @@ -386,7 +389,7 @@ dependencies: optionalDependencies: fsevents: specifier: ^2.3.2 - version: 2.3.2 + version: 2.3.3 devDependencies: '@babel/core': @@ -12521,6 +12524,11 @@ packages: engines: {node: '>=0.8.x'} dev: true + /eventsource-parser@3.0.0: + resolution: {integrity: sha512-T1C0XCUimhxVQzW4zFipdx0SficT651NnkR0ZSH3yQwh+mFMdLfgjABVi4YtMTtaL4s168593DaoaRLMqryavA==} + engines: {node: '>=18.0.0'} + dev: false + /execa@4.1.0: resolution: {integrity: sha512-j5W0//W7f8UxAn8hXVnwG8tLwdiUy4FJLcSupCg6maBYZDpyBvTApK7KyuI4bKj8KOh1r2YH+6ucuYtJv1bTZA==} engines: {node: '>=10'} @@ -13126,6 +13134,7 @@ packages: engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} os: [darwin] requiresBuild: true + dev: true optional: true /fsevents@2.3.3: diff --git a/posthog/api/query.py b/posthog/api/query.py index d4d45ce66a2..f2eaccea53a 100644 --- a/posthog/api/query.py +++ b/posthog/api/query.py @@ -185,7 +185,7 @@ class QueryViewSet(TeamAndOrgViewSetMixin, PydanticModelMixin, viewsets.ViewSet) last_message = None for message in assistant.stream(validated_body): last_message = message - yield last_message + yield f"data: {message}\n\n" human_message = validated_body.messages[-1].root if isinstance(human_message, HumanMessage):