diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 21f05169..e1523a10 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -77,7 +77,7 @@ jobs: - uses: denoland/setup-deno@v1 with: deno-version: v1.x - - run: env NAME=Deno deno test --coverage=coverage/raw/deno-runtime --allow-read --allow-env --allow-write -c runtime_tests/deno/deno.json runtime_tests/deno + - run: env NAME=Deno deno test --coverage=coverage/raw/deno-runtime --allow-read --allow-env --allow-write --allow-net -c runtime_tests/deno/deno.json runtime_tests/deno - run: deno test -c runtime_tests/deno-jsx/deno.precompile.json --coverage=coverage/raw/deno-precompile-jsx runtime_tests/deno-jsx - run: deno test -c runtime_tests/deno-jsx/deno.react-jsx.json --coverage=coverage/raw/deno-react-jsx runtime_tests/deno-jsx - uses: actions/upload-artifact@v4 diff --git a/package.json b/package.json index c62e89ae..4971def7 100644 --- a/package.json +++ b/package.json @@ -12,7 +12,7 @@ "scripts": { "test": "tsc --noEmit && vitest --run && vitest -c .vitest.config/jsx-runtime-default.ts --run && vitest -c .vitest.config/jsx-runtime-dom.ts --run", "test:watch": "vitest --watch", - "test:deno": "deno test --allow-read --allow-env --allow-write -c runtime_tests/deno/deno.json runtime_tests/deno && deno test --no-lock -c runtime_tests/deno-jsx/deno.precompile.json runtime_tests/deno-jsx && deno test --no-lock -c runtime_tests/deno-jsx/deno.react-jsx.json runtime_tests/deno-jsx", + "test:deno": "deno test --allow-read --allow-env --allow-write --allow-net -c runtime_tests/deno/deno.json runtime_tests/deno && deno test --no-lock -c runtime_tests/deno-jsx/deno.precompile.json runtime_tests/deno-jsx && deno test --no-lock -c runtime_tests/deno-jsx/deno.react-jsx.json runtime_tests/deno-jsx", "test:bun": "bun test --jsx-import-source ../../src/jsx runtime_tests/bun/index.test.tsx", "test:fastly": "vitest --run --config ./runtime_tests/fastly/vitest.config.ts", "test:node": "vitest --run --config ./runtime_tests/node/vitest.config.ts", diff --git a/runtime_tests/bun/index.test.tsx b/runtime_tests/bun/index.test.tsx index 3904c9ed..8caaf186 100644 --- a/runtime_tests/bun/index.test.tsx +++ b/runtime_tests/bun/index.test.tsx @@ -1,4 +1,4 @@ -import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest' +import { afterAll, afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { serveStatic, toSSG } from '../../src/adapter/bun' import { createBunWebSocket } from '../../src/adapter/bun/websocket' import type { BunWebSocketData } from '../../src/adapter/bun/websocket' @@ -11,6 +11,7 @@ import { jsx } from '../../src/jsx' import { basicAuth } from '../../src/middleware/basic-auth' import { jwt } from '../../src/middleware/jwt' import { HonoRequest } from '../../src/request' +import { stream, streamSSE } from '../..//src/helper/streaming' // Test just only minimal patterns. // Because others are tested well in Cloudflare Workers environment already. @@ -316,3 +317,74 @@ async function deleteDirectory(dirPath) { await fs.unlink(dirPath) } } + +describe('streaming', () => { + const app = new Hono() + let server: ReturnType + let aborted = false + + app.get('/stream', (c) => { + return stream(c, async (stream) => { + stream.onAbort(() => { + aborted = true + }) + return new Promise((resolve) => { + stream.onAbort(resolve) + }) + }) + }) + app.get('/streamSSE', (c) => { + return streamSSE(c, async (stream) => { + stream.onAbort(() => { + aborted = true + }) + return new Promise((resolve) => { + stream.onAbort(resolve) + }) + }) + }) + + beforeEach(() => { + aborted = false + server = Bun.serve({ port: 0, fetch: app.fetch }) + }) + + afterEach(() => { + server.stop() + }) + + describe('stream', () => { + it('Should call onAbort', async () => { + const ac = new AbortController() + const req = new Request(`http://localhost:${server.port}/stream`, { + signal: ac.signal, + }) + expect(aborted).toBe(false) + const res = fetch(req).catch(() => {}) + await new Promise((resolve) => setTimeout(resolve, 10)) + ac.abort() + await res + while (!aborted) { + await new Promise((resolve) => setTimeout(resolve)) + } + expect(aborted).toBe(true) + }) + }) + + describe('streamSSE', () => { + it('Should call onAbort', async () => { + const ac = new AbortController() + const req = new Request(`http://localhost:${server.port}/streamSSE`, { + signal: ac.signal, + }) + const res = fetch(req).catch(() => {}) + await new Promise((resolve) => setTimeout(resolve, 10)) + ac.abort() + await res + while (!aborted) { + await new Promise((resolve) => setTimeout(resolve)) + } + expect(aborted).toBe(true) + }) + }) +}) diff --git a/runtime_tests/deno/stream.test.ts b/runtime_tests/deno/stream.test.ts new file mode 100644 index 00000000..8e48f51b --- /dev/null +++ b/runtime_tests/deno/stream.test.ts @@ -0,0 +1,69 @@ +import { Hono } from '../../src/hono.ts' +import { assertEquals } from './deps.ts' +import { stream, streamSSE } from '../../src/helper/streaming/index.ts' + +Deno.test('Shuld call onAbort via stream', async () => { + const app = new Hono() + let aborted = false + app.get('/stream', (c) => { + return stream(c, async (stream) => { + stream.onAbort(() => { + aborted = true + }) + return new Promise((resolve) => { + stream.onAbort(resolve) + }) + }) + }) + + const server = Deno.serve({ port: 0 }, app.fetch) + const ac = new AbortController() + const req = new Request(`http://localhost:${server.addr.port}/stream`, { + signal: ac.signal, + }) + assertEquals + const res = fetch(req).catch(() => {}) + assertEquals(aborted, false) + await new Promise((resolve) => setTimeout(resolve, 10)) + ac.abort() + await res + while (!aborted) { + await new Promise((resolve) => setTimeout(resolve)) + } + assertEquals(aborted, true) + + await server.shutdown() +}) + +Deno.test('Shuld call onAbort via streamSSE', async () => { + const app = new Hono() + let aborted = false + app.get('/stream', (c) => { + return streamSSE(c, async (stream) => { + stream.onAbort(() => { + aborted = true + }) + return new Promise((resolve) => { + stream.onAbort(resolve) + }) + }) + }) + + const server = Deno.serve({ port: 0 }, app.fetch) + const ac = new AbortController() + const req = new Request(`http://localhost:${server.addr.port}/stream`, { + signal: ac.signal, + }) + assertEquals + const res = fetch(req).catch(() => {}) + assertEquals(aborted, false) + await new Promise((resolve) => setTimeout(resolve, 10)) + ac.abort() + await res + while (!aborted) { + await new Promise((resolve) => setTimeout(resolve)) + } + assertEquals(aborted, true) + + await server.shutdown() +}) diff --git a/runtime_tests/node/index.test.ts b/runtime_tests/node/index.test.ts index 436a4dcd..3b891da1 100644 --- a/runtime_tests/node/index.test.ts +++ b/runtime_tests/node/index.test.ts @@ -6,6 +6,7 @@ import { env, getRuntimeKey } from '../../src/helper/adapter' import { basicAuth } from '../../src/middleware/basic-auth' import { jwt } from '../../src/middleware/jwt' import { HonoRequest } from '../../src/request' +import { stream, streamSSE } from '../../src/helper/streaming' // Test only minimal patterns. // See for more tests and information. @@ -96,3 +97,69 @@ describe('JWT Auth Middleware', () => { expect(res.text).toBe('auth') }) }) + +describe('stream', () => { + const app = new Hono() + + let aborted = false + + app.get('/stream', (c) => { + return stream(c, async (stream) => { + stream.onAbort(() => { + aborted = true + }) + return new Promise((resolve) => { + stream.onAbort(resolve) + }) + }) + }) + + const server = createAdaptorServer(app) + + it('Should call onAbort', async () => { + const req = request(server) + .get('/stream') + .end(() => {}) + + expect(aborted).toBe(false) + await new Promise((resolve) => setTimeout(resolve, 10)) + req.abort() + while (!aborted) { + await new Promise((resolve) => setTimeout(resolve)) + } + expect(aborted).toBe(true) + }) +}) + +describe('streamSSE', () => { + const app = new Hono() + + let aborted = false + + app.get('/stream', (c) => { + return streamSSE(c, async (stream) => { + stream.onAbort(() => { + aborted = true + }) + return new Promise((resolve) => { + stream.onAbort(resolve) + }) + }) + }) + + const server = createAdaptorServer(app) + + it('Should call onAbort', async () => { + const req = request(server) + .get('/stream') + .end(() => {}) + + expect(aborted).toBe(false) + await new Promise((resolve) => setTimeout(resolve, 10)) + req.abort() + while (!aborted) { + await new Promise((resolve) => setTimeout(resolve)) + } + expect(aborted).toBe(true) + }) +}) diff --git a/src/helper/streaming/sse.test.ts b/src/helper/streaming/sse.test.ts index 48e9ac74..eb7bbb89 100644 --- a/src/helper/streaming/sse.test.ts +++ b/src/helper/streaming/sse.test.ts @@ -73,6 +73,33 @@ describe('SSE Streaming helper', () => { expect(aborted).toBeTruthy() }) + it('Check streamSSE Response if aborted by abort signal', async () => { + const ac = new AbortController() + const req = new Request('http://localhost/', { signal: ac.signal }) + const c = new Context(req) + + let aborted = false + const res = streamSSE(c, async (stream) => { + stream.onAbort(() => { + aborted = true + }) + for (let i = 0; i < 3; i++) { + await stream.writeSSE({ + data: `Message ${i}`, + }) + await stream.sleep(1) + } + }) + if (!res.body) { + throw new Error('Body is null') + } + const reader = res.body.getReader() + const { value } = await reader.read() + expect(value).toEqual(new TextEncoder().encode('data: Message 0\n\n')) + ac.abort() + expect(aborted).toBeTruthy() + }) + it('Should include retry in the SSE message', async () => { const retryTime = 3000 // 3 seconds const res = streamSSE(c, async (stream) => { diff --git a/src/helper/streaming/sse.ts b/src/helper/streaming/sse.ts index 6498648d..9bc18a61 100644 --- a/src/helper/streaming/sse.ts +++ b/src/helper/streaming/sse.ts @@ -58,6 +58,7 @@ const run = async ( } } +const contextStash = new WeakMap() export const streamSSE = ( c: Context, cb: (stream: SSEStreamingApi) => Promise, @@ -66,6 +67,13 @@ export const streamSSE = ( const { readable, writable } = new TransformStream() const stream = new SSEStreamingApi(writable, readable) + // bun does not cancel response stream when request is canceled, so detect abort by signal + c.req.raw.signal.addEventListener('abort', () => { + stream.abort() + }) + // in bun, `c` is destroyed when the request is returned, so hold it until the end of streaming + contextStash.set(stream.responseReadable, c) + c.header('Transfer-Encoding', 'chunked') c.header('Content-Type', 'text/event-stream') c.header('Cache-Control', 'no-cache') diff --git a/src/helper/streaming/stream.test.ts b/src/helper/streaming/stream.test.ts index 34de07a3..820579de 100644 --- a/src/helper/streaming/stream.test.ts +++ b/src/helper/streaming/stream.test.ts @@ -46,6 +46,31 @@ describe('Basic Streaming Helper', () => { expect(aborted).toBeTruthy() }) + it('Check stream Response if aborted by abort signal', async () => { + const ac = new AbortController() + const req = new Request('http://localhost/', { signal: ac.signal }) + const c = new Context(req) + + let aborted = false + const res = stream(c, async (stream) => { + stream.onAbort(() => { + aborted = true + }) + for (let i = 0; i < 3; i++) { + await stream.write(new Uint8Array([i])) + await stream.sleep(1) + } + }) + if (!res.body) { + throw new Error('Body is null') + } + const reader = res.body.getReader() + const { value } = await reader.read() + expect(value).toEqual(new Uint8Array([0])) + ac.abort() + expect(aborted).toBeTruthy() + }) + it('Check stream Response if error occurred', async () => { const onError = vi.fn() const res = stream( diff --git a/src/helper/streaming/stream.ts b/src/helper/streaming/stream.ts index ed739336..f1264eff 100644 --- a/src/helper/streaming/stream.ts +++ b/src/helper/streaming/stream.ts @@ -1,6 +1,7 @@ import type { Context } from '../../context' import { StreamingApi } from '../../utils/stream' +const contextStash = new WeakMap() export const stream = ( c: Context, cb: (stream: StreamingApi) => Promise, @@ -8,6 +9,13 @@ export const stream = ( ): Response => { const { readable, writable } = new TransformStream() const stream = new StreamingApi(writable, readable) + + // bun does not cancel response stream when request is canceled, so detect abort by signal + c.req.raw.signal.addEventListener('abort', () => { + stream.abort() + }) + // in bun, `c` is destroyed when the request is returned, so hold it until the end of streaming + contextStash.set(stream.responseReadable, c) ;(async () => { try { await cb(stream) @@ -21,5 +29,6 @@ export const stream = ( stream.close() } })() + return c.newResponse(stream.responseReadable) } diff --git a/src/utils/stream.test.ts b/src/utils/stream.test.ts index 5ce92ef1..f2b9b6d3 100644 --- a/src/utils/stream.test.ts +++ b/src/utils/stream.test.ts @@ -96,4 +96,26 @@ describe('StreamingApi', () => { expect(handleAbort1).toBeCalled() expect(handleAbort2).toBeCalled() }) + + it('abort()', async () => { + const { readable, writable } = new TransformStream() + const handleAbort1 = vi.fn() + const handleAbort2 = vi.fn() + const api = new StreamingApi(writable, readable) + api.onAbort(handleAbort1) + api.onAbort(handleAbort2) + expect(handleAbort1).not.toBeCalled() + expect(handleAbort2).not.toBeCalled() + expect(api.aborted).toBe(false) + + api.abort() + expect(handleAbort1).toHaveBeenCalledOnce() + expect(handleAbort2).toHaveBeenCalledOnce() + expect(api.aborted).toBe(true) + + api.abort() + expect(handleAbort1).toHaveBeenCalledOnce() + expect(handleAbort2).toHaveBeenCalledOnce() + expect(api.aborted).toBe(true) + }) }) diff --git a/src/utils/stream.ts b/src/utils/stream.ts index f3d434e5..2bed9bad 100644 --- a/src/utils/stream.ts +++ b/src/utils/stream.ts @@ -9,6 +9,10 @@ export class StreamingApi { private writable: WritableStream private abortSubscribers: (() => void | Promise)[] = [] responseReadable: ReadableStream + /** + * Whether the stream has been aborted. + */ + aborted: boolean = false constructor(writable: WritableStream, _readable: ReadableStream) { this.writable = writable @@ -30,7 +34,7 @@ export class StreamingApi { done ? controller.close() : controller.enqueue(value) }, cancel: () => { - this.abortSubscribers.forEach((subscriber) => subscriber()) + this.abort() }, }) } @@ -73,4 +77,15 @@ export class StreamingApi { onAbort(listener: () => void | Promise) { this.abortSubscribers.push(listener) } + + /** + * Abort the stream. + * You can call this method when stream is aborted by external event. + */ + abort() { + if (!this.aborted) { + this.aborted = true + this.abortSubscribers.forEach((subscriber) => subscriber()) + } + } }