From 234b083777d34fb1efadbe8f504148e2e533428f Mon Sep 17 00:00:00 2001 From: Shotaro Nakamura <79000684+nakasyou@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:09:36 +0900 Subject: [PATCH] feat(ws): Make WebSocket adapter more changeable (#3531) * feat(ws): Make WebSocket adapter more changeable * lint * format code * test: add tests * format code * lint:fix * format code * test: add more tests --- src/adapter/bun/websocket.test.ts | 123 ++++++++++++++++++ src/adapter/bun/websocket.ts | 64 +++++---- .../cloudflare-workers/websocket.test.ts | 15 +++ src/adapter/cloudflare-workers/websocket.ts | 89 ++++++------- src/adapter/deno/websocket.test.ts | 15 +++ src/adapter/deno/websocket.ts | 17 ++- src/helper/websocket/index.test.ts | 97 +++++++++++++- src/helper/websocket/index.ts | 80 ++++++++++-- 8 files changed, 403 insertions(+), 97 deletions(-) create mode 100644 src/adapter/bun/websocket.test.ts diff --git a/src/adapter/bun/websocket.test.ts b/src/adapter/bun/websocket.test.ts new file mode 100644 index 00000000..e7606342 --- /dev/null +++ b/src/adapter/bun/websocket.test.ts @@ -0,0 +1,123 @@ +import { Context } from '../../context' +import type { BunWebSocketData, BunServerWebSocket } from './websocket' +import { createWSContext, createBunWebSocket } from './websocket' + +describe('createWSContext()', () => { + it('Should send() and close() works', () => { + const send = vi.fn() + const close = vi.fn() + const ws = createWSContext({ + send(data) { + send(data) + }, + close(code, reason) { + close(code, reason) + }, + data: {}, + } as BunServerWebSocket) + ws.send('message') + expect(send).toBeCalled() + ws.close() + expect(close).toBeCalled() + }) +}) +describe('upgradeWebSocket()', () => { + it('Should throw error when server is null', async () => { + const { upgradeWebSocket } = createBunWebSocket() + const run = async () => + await upgradeWebSocket(() => ({}))( + new Context(new Request('http://localhost'), { + env: { + server: null, + }, + }), + () => Promise.resolve() + ) + + expect(run).rejects.toThrowError(/env has/) + }) + it('Should response null when upgraded', async () => { + const { upgradeWebSocket } = createBunWebSocket() + const upgraded = await upgradeWebSocket(() => ({}))( + new Context(new Request('http://localhost'), { + env: { + upgrade: () => true, + }, + }), + () => Promise.resolve() + ) + expect(upgraded).toBeTruthy() + }) + it('Should response undefined when upgrade failed', async () => { + const { upgradeWebSocket } = createBunWebSocket() + const upgraded = await upgradeWebSocket(() => ({}))( + new Context(new Request('http://localhost'), { + env: { + upgrade: () => undefined, + }, + }), + () => Promise.resolve() + ) + expect(upgraded).toBeFalsy() + }) +}) +describe('createBunWebSocket()', () => { + beforeAll(() => { + // @ts-expect-error patch global + globalThis.CloseEvent = Event + }) + afterAll(() => { + // @ts-expect-error patch global + delete globalThis.CloseEvent + }) + it('Should events are called', async () => { + const { websocket, upgradeWebSocket } = createBunWebSocket() + const ws = { + data: { + connId: 0, + }, + } as BunServerWebSocket + + const open = vi.fn() + const message = vi.fn() + const close = vi.fn() + + let receivedArrayBuffer: ArrayBuffer | undefined = undefined + await upgradeWebSocket(() => ({ + onOpen() { + open() + }, + onMessage(evt) { + message() + if (evt.data instanceof ArrayBuffer) { + receivedArrayBuffer = evt.data + } + }, + onClose() { + close() + }, + }))( + new Context(new Request('http://localhost'), { + env: { + upgrade() { + return true + }, + }, + }), + () => Promise.resolve() + ) + + websocket.open(ws) + expect(open).toBeCalled() + + websocket.message(ws, 'message') + expect(message).toBeCalled() + + websocket.message(ws, new Uint8Array(16)) + expect(receivedArrayBuffer).toBeInstanceOf(ArrayBuffer) + expect(receivedArrayBuffer!.byteLength).toBe(16) + + websocket.close(ws) + expect(close).toBeCalled() + }) +}) diff --git a/src/adapter/bun/websocket.ts b/src/adapter/bun/websocket.ts index daae7647..34b5b1ef 100644 --- a/src/adapter/bun/websocket.ts +++ b/src/adapter/bun/websocket.ts @@ -1,13 +1,11 @@ -import { createWSMessageEvent } from '../../helper/websocket' -import type { - UpgradeWebSocket, - WSContext, - WSEvents, - WSMessageReceive, -} from '../../helper/websocket' +import type { UpgradeWebSocket, WSEvents, WSMessageReceive } from '../../helper/websocket' +import { createWSMessageEvent, defineWebSocketHelper, WSContext } from '../../helper/websocket' import { getBunServer } from './server' -interface BunServerWebSocket { +/** + * @internal + */ +export interface BunServerWebSocket { send(data: string | ArrayBufferLike, compress?: boolean): void close(code?: number, reason?: string): void data: T @@ -29,48 +27,46 @@ export interface BunWebSocketData { protocol: string } -const createWSContext = (ws: BunServerWebSocket): WSContext => { - return { +/** + * @internal + */ +export const createWSContext = (ws: BunServerWebSocket): WSContext => { + return new WSContext({ send: (source, options) => { - const sendingData = - typeof source === 'string' ? source : source instanceof Uint8Array ? source.buffer : source - ws.send(sendingData, options?.compress) + ws.send(source, options?.compress) }, raw: ws, - binaryType: 'arraybuffer', readyState: ws.readyState, url: ws.data.url, protocol: ws.data.protocol, close(code, reason) { ws.close(code, reason) }, - } + }) } export const createBunWebSocket = (): CreateWebSocket => { const websocketConns: WSEvents[] = [] // eslint-disable-next-line @typescript-eslint/no-explicit-any - const upgradeWebSocket: UpgradeWebSocket = (createEvents) => { - return async (c, next) => { - const server = getBunServer(c) - if (!server) { - throw new TypeError('env has to include the 2nd argument of fetch.') - } - const connId = websocketConns.push(await createEvents(c)) - 1 - const upgradeResult = server.upgrade(c.req.raw, { - data: { - connId, - url: new URL(c.req.url), - protocol: c.req.url, - }, - }) - if (upgradeResult) { - return new Response(null) - } - await next() // Failed + const upgradeWebSocket: UpgradeWebSocket = defineWebSocketHelper((c, events) => { + const server = getBunServer(c) + if (!server) { + throw new TypeError('env has to include the 2nd argument of fetch.') } - } + const connId = websocketConns.push(events) - 1 + const upgradeResult = server.upgrade(c.req.raw, { + data: { + connId, + url: new URL(c.req.url), + protocol: c.req.url, + }, + }) + if (upgradeResult) { + return new Response(null) + } + return // failed + }) const websocket: BunWebSocketHandler = { open(ws) { const websocketListeners = websocketConns[ws.data.connId] diff --git a/src/adapter/cloudflare-workers/websocket.test.ts b/src/adapter/cloudflare-workers/websocket.test.ts index 47aa0cc9..a065d4ab 100644 --- a/src/adapter/cloudflare-workers/websocket.test.ts +++ b/src/adapter/cloudflare-workers/websocket.test.ts @@ -1,4 +1,5 @@ import { Hono } from '../..' +import { Context } from '../../context' import { upgradeWebSocket } from '.' describe('upgradeWebSocket middleware', () => { @@ -42,4 +43,18 @@ describe('upgradeWebSocket middleware', () => { expect(sendingData).toBe(await wsPromise) }) + it('Should call next() when header does not have upgrade', async () => { + const next = vi.fn() + await upgradeWebSocket(() => ({}))( + new Context( + new Request('http://localhost', { + headers: { + Upgrade: 'example', + }, + }) + ), + next + ) + expect(next).toBeCalled() + }) }) diff --git a/src/adapter/cloudflare-workers/websocket.ts b/src/adapter/cloudflare-workers/websocket.ts index 312e3478..cbf468c7 100644 --- a/src/adapter/cloudflare-workers/websocket.ts +++ b/src/adapter/cloudflare-workers/websocket.ts @@ -1,50 +1,51 @@ -import type { UpgradeWebSocket, WSContext, WSReadyState } from '../../helper/websocket' +import { WSContext, defineWebSocketHelper } from '../../helper/websocket' +import type { UpgradeWebSocket, WSReadyState } from '../../helper/websocket' // Based on https://github.com/honojs/hono/issues/1153#issuecomment-1767321332 -export const upgradeWebSocket: UpgradeWebSocket = (createEvents) => async (c, next) => { - const events = await createEvents(c) +export const upgradeWebSocket: UpgradeWebSocket = defineWebSocketHelper( + async (c, events) => { + const upgradeHeader = c.req.header('Upgrade') + if (upgradeHeader !== 'websocket') { + return + } - const upgradeHeader = c.req.header('Upgrade') - if (upgradeHeader !== 'websocket') { - return await next() - } + // @ts-expect-error WebSocketPair is not typed + const webSocketPair = new WebSocketPair() + const client: WebSocket = webSocketPair[0] + const server: WebSocket = webSocketPair[1] - // @ts-expect-error WebSocketPair is not typed - const webSocketPair = new WebSocketPair() - const client: WebSocket = webSocketPair[0] - const server: WebSocket = webSocketPair[1] + const wsContext = new WSContext({ + close: (code, reason) => server.close(code, reason), + get protocol() { + return server.protocol + }, + raw: server, + get readyState() { + return server.readyState as WSReadyState + }, + url: server.url ? new URL(server.url) : null, + send: (source) => server.send(source), + }) - const wsContext: WSContext = { - binaryType: 'arraybuffer', - close: (code, reason) => server.close(code, reason), - get protocol() { - return server.protocol - }, - raw: server, - get readyState() { - return server.readyState as WSReadyState - }, - url: server.url ? new URL(server.url) : null, - send: (source) => server.send(source), - } - if (events.onOpen) { - server.addEventListener('open', (evt: Event) => events.onOpen?.(evt, wsContext)) - } - if (events.onClose) { - server.addEventListener('close', (evt: CloseEvent) => events.onClose?.(evt, wsContext)) - } - if (events.onMessage) { - server.addEventListener('message', (evt: MessageEvent) => events.onMessage?.(evt, wsContext)) - } - if (events.onError) { - server.addEventListener('error', (evt: Event) => events.onError?.(evt, wsContext)) - } + if (events.onOpen) { + server.addEventListener('open', (evt: Event) => events.onOpen?.(evt, wsContext)) + } + if (events.onClose) { + server.addEventListener('close', (evt: CloseEvent) => events.onClose?.(evt, wsContext)) + } + if (events.onMessage) { + server.addEventListener('message', (evt: MessageEvent) => events.onMessage?.(evt, wsContext)) + } + if (events.onError) { + server.addEventListener('error', (evt: Event) => events.onError?.(evt, wsContext)) + } - // @ts-expect-error - server.accept is not typed - server.accept?.() - return new Response(null, { - status: 101, - // @ts-expect-error - webSocket is not typed - webSocket: client, - }) -} + // @ts-expect-error - server.accept is not typed + server.accept?.() + return new Response(null, { + status: 101, + // @ts-expect-error - webSocket is not typed + webSocket: client, + }) + } +) diff --git a/src/adapter/deno/websocket.test.ts b/src/adapter/deno/websocket.test.ts index b4d0d71d..57d8e59d 100644 --- a/src/adapter/deno/websocket.test.ts +++ b/src/adapter/deno/websocket.test.ts @@ -1,4 +1,5 @@ import { Hono } from '../..' +import { Context } from '../../context' import { upgradeWebSocket } from './websocket' globalThis.Deno = {} as typeof Deno @@ -75,4 +76,18 @@ describe('WebSockets', () => { ) expect(await messagePromise).toBe(data) }) + it('Should call next() when header does not have upgrade', async () => { + const next = vi.fn() + await upgradeWebSocket(() => ({}))( + new Context( + new Request('http://localhost', { + headers: { + Upgrade: 'example', + }, + }) + ), + next + ) + expect(next).toBeCalled() + }) }) diff --git a/src/adapter/deno/websocket.ts b/src/adapter/deno/websocket.ts index 1eda3e15..a99bf8dd 100644 --- a/src/adapter/deno/websocket.ts +++ b/src/adapter/deno/websocket.ts @@ -1,4 +1,5 @@ -import type { UpgradeWebSocket, WSContext, WSReadyState } from '../../helper/websocket' +import type { UpgradeWebSocket, WSReadyState } from '../../helper/websocket' +import { WSContext, defineWebSocketHelper } from '../../helper/websocket' export interface UpgradeWebSocketOptions { /** @@ -21,16 +22,14 @@ export interface UpgradeWebSocketOptions { } export const upgradeWebSocket: UpgradeWebSocket = - (createEvents, options) => async (c, next) => { + defineWebSocketHelper(async (c, events, options) => { if (c.req.header('upgrade') !== 'websocket') { - return await next() + return } - const events = await createEvents(c) - const { response, socket } = Deno.upgradeWebSocket(c.req.raw, options || {}) + const { response, socket } = Deno.upgradeWebSocket(c.req.raw, options ?? {}) - const wsContext: WSContext = { - binaryType: 'arraybuffer', + const wsContext: WSContext = new WSContext({ close: (code, reason) => socket.close(code, reason), get protocol() { return socket.protocol @@ -41,11 +40,11 @@ export const upgradeWebSocket: UpgradeWebSocket socket.send(source), - } + }) socket.onopen = (evt) => events.onOpen?.(evt, wsContext) socket.onmessage = (evt) => events.onMessage?.(evt, wsContext) socket.onclose = (evt) => events.onClose?.(evt, wsContext) socket.onerror = (evt) => events.onError?.(evt, wsContext) return response - } + }) diff --git a/src/helper/websocket/index.test.ts b/src/helper/websocket/index.test.ts index 42c68cf8..e1efe4fb 100644 --- a/src/helper/websocket/index.test.ts +++ b/src/helper/websocket/index.test.ts @@ -1,4 +1,6 @@ -import { createWSMessageEvent } from '.' +import { Context } from '../../context' +import type { WSContestInit } from '.' +import { WSContext, createWSMessageEvent, defineWebSocketHelper } from '.' describe('`createWSMessageEvent`', () => { it('Should `createWSMessageEvent` is working for string', () => { @@ -12,3 +14,96 @@ describe('`createWSMessageEvent`', () => { expect(event.type).toBe('message') }) }) +describe('defineWebSocketHelper', () => { + it('defineWebSocketHelper should work', async () => { + const upgradeWebSocket = defineWebSocketHelper(() => { + return new Response('Hello World', { + status: 200, + }) + }) + const response = await upgradeWebSocket(() => ({}))( + new Context(new Request('http://localhost')), + () => Promise.resolve() + ) + expect(response).toBeTruthy() + expect((response as Response).status).toBe(200) + }) + it('When response is undefined, should call next()', async () => { + const upgradeWebSocket = defineWebSocketHelper(() => { + return + }) + const next = vi.fn() + await upgradeWebSocket(() => ({}))(new Context(new Request('http://localhost')), next) + expect(next).toBeCalled() + }) +}) +describe('WSContext', () => { + it('Should close() works', async () => { + type Result = [number | undefined, string | undefined] + let ws!: WSContext + const promise = new Promise((resolve) => { + ws = new WSContext({ + close(code, reason) { + resolve([code, reason]) + }, + } as WSContestInit) + }) + ws.close(0, 'reason') + const [code, reason] = await promise + expect(code).toBe(0) + expect(reason).toBe('reason') + }) + it('Should send() works', async () => { + let ws!: WSContext + const promise = new Promise((resolve) => { + ws = new WSContext({ + // eslint-disable-next-line @typescript-eslint/no-unused-vars + send(data, _options) { + resolve(data) + }, + } as WSContestInit) + }) + ws.send('Hello') + expect(await promise).toBe('Hello') + }) + it('Should readyState works', () => { + const ws = new WSContext({ + readyState: 0, + } as WSContestInit) + expect(ws.readyState).toBe(0) + }) + it('Should normalize URL', () => { + const stringURLWS = new WSContext({ + url: 'http://localhost', + } as WSContestInit) + expect(stringURLWS.url).toBeInstanceOf(URL) + + const urlURLWS = new WSContext({ + url: new URL('http://localhost'), + } as WSContestInit) + expect(urlURLWS.url).toBeInstanceOf(URL) + + const nullURLWS = new WSContext({ + url: undefined, + } as WSContestInit) + expect(nullURLWS.url).toBeNull() + }) + it('Should normalize message in send()', () => { + let data: string | ArrayBuffer | null = null + const wsContext = new WSContext({ + // eslint-disable-next-line @typescript-eslint/no-unused-vars + send(received, _options) { + data = received + }, + } as WSContestInit) + + wsContext.send('string') + expect(data).toBe('string') + + wsContext.send(new ArrayBuffer(16)) + expect(data).toBeInstanceOf(ArrayBuffer) + + wsContext.send(new Uint8Array(16)) + expect(data).toBeInstanceOf(ArrayBuffer) + }) +}) diff --git a/src/helper/websocket/index.ts b/src/helper/websocket/index.ts index 2c27b63d..c6136dce 100644 --- a/src/helper/websocket/index.ts +++ b/src/helper/websocket/index.ts @@ -31,21 +31,58 @@ export type UpgradeWebSocket = ( } > +/** + * ReadyState for WebSocket + */ export type WSReadyState = 0 | 1 | 2 | 3 -export type WSContext = { - send( - source: string | ArrayBuffer | Uint8Array, - options?: { - compress: boolean - } - ): void +/** + * An argument for WSContext class + */ +export interface WSContestInit { + send(data: string | ArrayBuffer, options: SendOptions): void + close(code?: number, reason?: string): void + raw?: T - binaryType: BinaryType readyState: WSReadyState + url?: string | URL | null + protocol?: string | null +} + +/** + * Options for sending message + */ +export interface SendOptions { + compress?: boolean +} + +/** + * A context for controlling WebSockets + */ +export class WSContext { + #init: WSContestInit + constructor(init: WSContestInit) { + this.#init = init + this.raw = init.raw + this.url = init.url ? new URL(init.url) : null + this.protocol = init.protocol ?? null + } + send(source: string | ArrayBuffer | Uint8Array, options?: SendOptions): void { + this.#init.send( + typeof source === 'string' ? source : source instanceof Uint8Array ? source.buffer : source, + options ?? {} + ) + } + raw?: T + binaryType: BinaryType = 'arraybuffer' + get readyState(): WSReadyState { + return this.#init.readyState + } url: URL | null protocol: string | null - close(code?: number, reason?: string): void + close(code?: number, reason?: string) { + this.#init.close(code, reason) + } } export type WSMessageReceive = string | Blob | ArrayBufferLike @@ -55,3 +92,28 @@ export const createWSMessageEvent = (source: WSMessageReceive): MessageEvent = ( + c: Context, + events: WSEvents, + options?: U +) => Promise | Response | void + +/** + * Create a WebSocket adapter/helper + */ +export const defineWebSocketHelper = ( + handler: WebSocketHelperDefineHandler +): UpgradeWebSocket => { + return (createEvents, options) => { + return async function UpgradeWebSocket(c, next) { + const events = await createEvents(c) + const result = await handler(c, events, options) + if (result) { + return result + } + await next() + } + } +}