diff --git a/runtime_tests/bun/index.test.tsx b/runtime_tests/bun/index.test.tsx index 8caaf186..9aaaf26d 100644 --- a/runtime_tests/bun/index.test.tsx +++ b/runtime_tests/bun/index.test.tsx @@ -333,6 +333,14 @@ describe('streaming', () => { }) }) }) + app.get('/streamHello', (c) => { + return stream(c, async (stream) => { + stream.onAbort(() => { + aborted = true + }) + await stream.write('Hello') + }) + }) app.get('/streamSSE', (c) => { return streamSSE(c, async (stream) => { stream.onAbort(() => { @@ -343,6 +351,14 @@ describe('streaming', () => { }) }) }) + app.get('/streamSSEHello', (c) => { + return streamSSE(c, async (stream) => { + stream.onAbort(() => { + aborted = true + }) + await stream.write('Hello') + }) + }) beforeEach(() => { aborted = false @@ -369,6 +385,13 @@ describe('streaming', () => { } expect(aborted).toBe(true) }) + + it('Should not be called onAbort if already closed', async () => { + expect(aborted).toBe(false) + const res = await fetch(`http://localhost:${server.port}/streamHello`) + expect(await res.text()).toBe('Hello') + expect(aborted).toBe(false) + }) }) describe('streamSSE', () => { @@ -386,5 +409,12 @@ describe('streaming', () => { } expect(aborted).toBe(true) }) + + it('Should not be called onAbort if already closed', async () => { + expect(aborted).toBe(false) + const res = await fetch(`http://localhost:${server.port}/streamSSEHello`) + expect(await res.text()).toBe('Hello') + expect(aborted).toBe(false) + }) }) }) diff --git a/runtime_tests/deno/stream.test.ts b/runtime_tests/deno/stream.test.ts index 8e48f51b..67eb1a4c 100644 --- a/runtime_tests/deno/stream.test.ts +++ b/runtime_tests/deno/stream.test.ts @@ -35,6 +35,26 @@ Deno.test('Shuld call onAbort via stream', async () => { await server.shutdown() }) +Deno.test('Shuld not call onAbort via stream if already closed', async () => { + const app = new Hono() + let aborted = false + app.get('/stream', (c) => { + return stream(c, async (stream) => { + stream.onAbort(() => { + aborted = true + }) + await stream.write('Hello') + }) + }) + + const server = Deno.serve({ port: 0 }, app.fetch) + assertEquals(aborted, false) + const res = await fetch(`http://localhost:${server.addr.port}/stream`) + assertEquals(await res.text(), 'Hello') + assertEquals(aborted, false) + await server.shutdown() +}) + Deno.test('Shuld call onAbort via streamSSE', async () => { const app = new Hono() let aborted = false @@ -67,3 +87,23 @@ Deno.test('Shuld call onAbort via streamSSE', async () => { await server.shutdown() }) + +Deno.test('Shuld not call onAbort via streamSSE if already closed', async () => { + const app = new Hono() + let aborted = false + app.get('/stream', (c) => { + return streamSSE(c, async (stream) => { + stream.onAbort(() => { + aborted = true + }) + await stream.write('Hello') + }) + }) + + const server = Deno.serve({ port: 0 }, app.fetch) + assertEquals(aborted, false) + const res = await fetch(`http://localhost:${server.addr.port}/stream`) + assertEquals(await res.text(), 'Hello') + assertEquals(aborted, false) + await server.shutdown() +}) diff --git a/runtime_tests/node/index.test.ts b/runtime_tests/node/index.test.ts index 3b891da1..92892ef8 100644 --- a/runtime_tests/node/index.test.ts +++ b/runtime_tests/node/index.test.ts @@ -113,9 +113,21 @@ describe('stream', () => { }) }) }) + app.get('/streamHello', (c) => { + return stream(c, async (stream) => { + stream.onAbort(() => { + aborted = true + }) + await stream.write('Hello') + }) + }) const server = createAdaptorServer(app) + beforeEach(() => { + aborted = false + }) + it('Should call onAbort', async () => { const req = request(server) .get('/stream') @@ -129,6 +141,14 @@ describe('stream', () => { } expect(aborted).toBe(true) }) + + it('Should not be called onAbort if already closed', async () => { + expect(aborted).toBe(false) + const res = await request(server).get('/streamHello') + expect(res.status).toBe(200) + expect(res.text).toBe('Hello') + expect(aborted).toBe(false) + }) }) describe('streamSSE', () => { @@ -146,9 +166,21 @@ describe('streamSSE', () => { }) }) }) + app.get('/streamHello', (c) => { + return streamSSE(c, async (stream) => { + stream.onAbort(() => { + aborted = true + }) + await stream.write('Hello') + }) + }) const server = createAdaptorServer(app) + beforeEach(() => { + aborted = false + }) + it('Should call onAbort', async () => { const req = request(server) .get('/stream') @@ -162,4 +194,12 @@ describe('streamSSE', () => { } expect(aborted).toBe(true) }) + + it('Should not be called onAbort if already closed', async () => { + expect(aborted).toBe(false) + const res = await request(server).get('/streamHello') + expect(res.status).toBe(200) + expect(res.text).toBe('Hello') + expect(aborted).toBe(false) + }) }) diff --git a/src/helper/streaming/sse.ts b/src/helper/streaming/sse.ts index 9bc18a61..5571f2da 100644 --- a/src/helper/streaming/sse.ts +++ b/src/helper/streaming/sse.ts @@ -69,7 +69,9 @@ export const streamSSE = ( // bun does not cancel response stream when request is canceled, so detect abort by signal c.req.raw.signal.addEventListener('abort', () => { - stream.abort() + if (!stream.closed) { + stream.abort() + } }) // in bun, `c` is destroyed when the request is returned, so hold it until the end of streaming contextStash.set(stream.responseReadable, c) diff --git a/src/helper/streaming/stream.ts b/src/helper/streaming/stream.ts index f1264eff..98b5a48a 100644 --- a/src/helper/streaming/stream.ts +++ b/src/helper/streaming/stream.ts @@ -12,7 +12,9 @@ export const stream = ( // bun does not cancel response stream when request is canceled, so detect abort by signal c.req.raw.signal.addEventListener('abort', () => { - stream.abort() + if (!stream.closed) { + stream.abort() + } }) // in bun, `c` is destroyed when the request is returned, so hold it until the end of streaming contextStash.set(stream.responseReadable, c) diff --git a/src/utils/stream.ts b/src/utils/stream.ts index 2bed9bad..90538ec1 100644 --- a/src/utils/stream.ts +++ b/src/utils/stream.ts @@ -13,6 +13,10 @@ export class StreamingApi { * Whether the stream has been aborted. */ aborted: boolean = false + /** + * Whether the stream has been closed normally. + */ + closed: boolean = false constructor(writable: WritableStream, _readable: ReadableStream) { this.writable = writable @@ -66,6 +70,7 @@ export class StreamingApi { } catch (e) { // Do nothing. If you want to handle errors, create a stream by yourself. } + this.closed = true } async pipe(body: ReadableStream) {