From 8ca155ec9b7780d9bb685b2d90881706d315fd61 Mon Sep 17 00:00:00 2001 From: Taku Amano Date: Sun, 8 Sep 2024 15:57:45 +0900 Subject: [PATCH] feat(helper/streaming): Support Promise or (async) JSX.Element in streamSSE (#3344) * feat(helper/streaming): Support Promise or (async) JSX.Element in streamSSE * refactor(context): enable to pass Promise (includes async JSX.Element) to resolveCallback --- src/context.ts | 17 ++-- .../streaming/{sse.test.ts => sse.test.tsx} | 88 +++++++++++++++++++ src/helper/streaming/sse.ts | 8 +- src/utils/html.ts | 13 ++- 4 files changed, 109 insertions(+), 17 deletions(-) rename src/helper/streaming/{sse.test.ts => sse.test.tsx} (62%) diff --git a/src/context.ts b/src/context.ts index e3dbc3c5..7e5ce67e 100644 --- a/src/context.ts +++ b/src/context.ts @@ -844,18 +844,11 @@ export class Context< this.#preparedHeaders['content-type'] = 'text/html; charset=UTF-8' if (typeof html === 'object') { - if (!(html instanceof Promise)) { - html = (html as string).toString() // HtmlEscapedString object to string - } - if ((html as string | Promise) instanceof Promise) { - return (html as unknown as Promise) - .then((html) => resolveCallback(html, HtmlEscapedCallbackPhase.Stringify, false, {})) - .then((html) => { - return typeof arg === 'number' - ? this.newResponse(html, arg, headers) - : this.newResponse(html, arg) - }) - } + return resolveCallback(html, HtmlEscapedCallbackPhase.Stringify, false, {}).then((html) => { + return typeof arg === 'number' + ? this.newResponse(html, arg, headers) + : this.newResponse(html, arg) + }) } return typeof arg === 'number' diff --git a/src/helper/streaming/sse.test.ts b/src/helper/streaming/sse.test.tsx similarity index 62% rename from src/helper/streaming/sse.test.ts rename to src/helper/streaming/sse.test.tsx index eb7bbb89..df77bb3c 100644 --- a/src/helper/streaming/sse.test.ts +++ b/src/helper/streaming/sse.test.tsx @@ -1,3 +1,5 @@ +/** @jsxImportSource ../../jsx */ +import { ErrorBoundary } from '../../jsx' import { Context } from '../../context' import { streamSSE } from '.' @@ -145,4 +147,90 @@ describe('SSE Streaming helper', () => { expect(onError).toBeCalledTimes(1) expect(onError).toBeCalledWith(new Error('Test error'), expect.anything()) // 2nd argument is StreamingApi instance }) + + it('Check streamSSE Response via Promise', async () => { + const res = streamSSE(c, async (stream) => { + await stream.writeSSE({ data: Promise.resolve('Async Message') }) + }) + + expect(res).not.toBeNull() + expect(res.status).toBe(200) + + if (!res.body) { + throw new Error('Body is null') + } + const reader = res.body.getReader() + const decoder = new TextDecoder() + const { value } = await reader.read() + const decodedValue = decoder.decode(value) + expect(decodedValue).toBe('data: Async Message\n\n') + }) + + it('Check streamSSE Response via JSX.Element', async () => { + const res = streamSSE(c, async (stream) => { + await stream.writeSSE({ data:
Hello
}) + }) + + expect(res).not.toBeNull() + expect(res.status).toBe(200) + + if (!res.body) { + throw new Error('Body is null') + } + const reader = res.body.getReader() + const decoder = new TextDecoder() + const { value } = await reader.read() + const decodedValue = decoder.decode(value) + expect(decodedValue).toBe('data:
Hello
\n\n') + }) + + it('Check streamSSE Response via ErrorBoundary in success case', async () => { + const AsyncComponent = async () => Promise.resolve(
Async Hello
) + const res = streamSSE(c, async (stream) => { + await stream.writeSSE({ + data: ( + Error}> + + + ), + }) + }) + + expect(res).not.toBeNull() + expect(res.status).toBe(200) + + if (!res.body) { + throw new Error('Body is null') + } + const reader = res.body.getReader() + const decoder = new TextDecoder() + const { value } = await reader.read() + const decodedValue = decoder.decode(value) + expect(decodedValue).toBe('data:
Async Hello
\n\n') + }) + + it('Check streamSSE Response via ErrorBoundary in error case', async () => { + const AsyncComponent = async () => Promise.reject() + const res = streamSSE(c, async (stream) => { + await stream.writeSSE({ + data: ( + Error}> + + + ), + }) + }) + + expect(res).not.toBeNull() + expect(res.status).toBe(200) + + if (!res.body) { + throw new Error('Body is null') + } + const reader = res.body.getReader() + const decoder = new TextDecoder() + const { value } = await reader.read() + const decodedValue = decoder.decode(value) + expect(decodedValue).toBe('data:
Error
\n\n') + }) }) diff --git a/src/helper/streaming/sse.ts b/src/helper/streaming/sse.ts index 1ed96e13..fb38f3d4 100644 --- a/src/helper/streaming/sse.ts +++ b/src/helper/streaming/sse.ts @@ -1,8 +1,9 @@ import type { Context } from '../../context' import { StreamingApi } from '../../utils/stream' +import { HtmlEscapedCallbackPhase, resolveCallback } from '../../utils/html' export interface SSEMessage { - data: string + data: string | Promise event?: string id?: string retry?: number @@ -14,7 +15,8 @@ export class SSEStreamingApi extends StreamingApi { } async writeSSE(message: SSEMessage) { - const data = message.data + const data = await resolveCallback(message.data, HtmlEscapedCallbackPhase.Stringify, false, {}) + const dataLines = (data as string) .split('\n') .map((line) => { return `data: ${line}` @@ -24,7 +26,7 @@ export class SSEStreamingApi extends StreamingApi { const sseData = [ message.event && `event: ${message.event}`, - data, + dataLines, message.id && `id: ${message.id}`, message.retry && `retry: ${message.retry}`, ] diff --git a/src/utils/html.ts b/src/utils/html.ts index d3557263..7731e565 100644 --- a/src/utils/html.ts +++ b/src/utils/html.ts @@ -140,12 +140,21 @@ export const resolveCallbackSync = (str: string | HtmlEscapedString): string => } export const resolveCallback = async ( - str: string | HtmlEscapedString, + str: string | HtmlEscapedString | Promise, phase: (typeof HtmlEscapedCallbackPhase)[keyof typeof HtmlEscapedCallbackPhase], preserveCallbacks: boolean, context: object, buffer?: [string] ): Promise => { + if (typeof str === 'object' && !(str instanceof String)) { + if (!((str as unknown) instanceof Promise)) { + str = (str as unknown as string).toString() // HtmlEscapedString object to string + } + if ((str as string | Promise) instanceof Promise) { + str = await (str as unknown as Promise) + } + } + const callbacks = (str as HtmlEscapedString).callbacks as HtmlEscapedCallback[] if (!callbacks?.length) { return Promise.resolve(str) @@ -153,7 +162,7 @@ export const resolveCallback = async ( if (buffer) { buffer[0] += str } else { - buffer = [str] + buffer = [str as string] } const resStr = Promise.all(callbacks.map((c) => c({ phase, buffer, context }))).then((res) =>