0
0
mirror of https://github.com/honojs/hono.git synced 2024-11-21 18:18:57 +01:00

feat(ws): Make WebSocket adapter more changeable (#3531)

* feat(ws): Make WebSocket adapter more changeable

* lint

* format code

* test: add tests

* format code

* lint:fix

* format code

* test: add more tests
This commit is contained in:
Shotaro Nakamura 2024-10-25 16:09:36 +09:00 committed by GitHub
parent 0a99bd3e74
commit 234b083777
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 403 additions and 97 deletions

View File

@ -0,0 +1,123 @@
import { Context } from '../../context'
import type { BunWebSocketData, BunServerWebSocket } from './websocket'
import { createWSContext, createBunWebSocket } from './websocket'
describe('createWSContext()', () => {
it('Should send() and close() works', () => {
const send = vi.fn()
const close = vi.fn()
const ws = createWSContext({
send(data) {
send(data)
},
close(code, reason) {
close(code, reason)
},
data: {},
} as BunServerWebSocket<BunWebSocketData>)
ws.send('message')
expect(send).toBeCalled()
ws.close()
expect(close).toBeCalled()
})
})
describe('upgradeWebSocket()', () => {
it('Should throw error when server is null', async () => {
const { upgradeWebSocket } = createBunWebSocket()
const run = async () =>
await upgradeWebSocket(() => ({}))(
new Context(new Request('http://localhost'), {
env: {
server: null,
},
}),
() => Promise.resolve()
)
expect(run).rejects.toThrowError(/env has/)
})
it('Should response null when upgraded', async () => {
const { upgradeWebSocket } = createBunWebSocket()
const upgraded = await upgradeWebSocket(() => ({}))(
new Context(new Request('http://localhost'), {
env: {
upgrade: () => true,
},
}),
() => Promise.resolve()
)
expect(upgraded).toBeTruthy()
})
it('Should response undefined when upgrade failed', async () => {
const { upgradeWebSocket } = createBunWebSocket()
const upgraded = await upgradeWebSocket(() => ({}))(
new Context(new Request('http://localhost'), {
env: {
upgrade: () => undefined,
},
}),
() => Promise.resolve()
)
expect(upgraded).toBeFalsy()
})
})
describe('createBunWebSocket()', () => {
beforeAll(() => {
// @ts-expect-error patch global
globalThis.CloseEvent = Event
})
afterAll(() => {
// @ts-expect-error patch global
delete globalThis.CloseEvent
})
it('Should events are called', async () => {
const { websocket, upgradeWebSocket } = createBunWebSocket()
const ws = {
data: {
connId: 0,
},
} as BunServerWebSocket<BunWebSocketData>
const open = vi.fn()
const message = vi.fn()
const close = vi.fn()
let receivedArrayBuffer: ArrayBuffer | undefined = undefined
await upgradeWebSocket(() => ({
onOpen() {
open()
},
onMessage(evt) {
message()
if (evt.data instanceof ArrayBuffer) {
receivedArrayBuffer = evt.data
}
},
onClose() {
close()
},
}))(
new Context(new Request('http://localhost'), {
env: {
upgrade() {
return true
},
},
}),
() => Promise.resolve()
)
websocket.open(ws)
expect(open).toBeCalled()
websocket.message(ws, 'message')
expect(message).toBeCalled()
websocket.message(ws, new Uint8Array(16))
expect(receivedArrayBuffer).toBeInstanceOf(ArrayBuffer)
expect(receivedArrayBuffer!.byteLength).toBe(16)
websocket.close(ws)
expect(close).toBeCalled()
})
})

View File

@ -1,13 +1,11 @@
import { createWSMessageEvent } from '../../helper/websocket'
import type {
UpgradeWebSocket,
WSContext,
WSEvents,
WSMessageReceive,
} from '../../helper/websocket'
import type { UpgradeWebSocket, WSEvents, WSMessageReceive } from '../../helper/websocket'
import { createWSMessageEvent, defineWebSocketHelper, WSContext } from '../../helper/websocket'
import { getBunServer } from './server'
interface BunServerWebSocket<T> {
/**
* @internal
*/
export interface BunServerWebSocket<T> {
send(data: string | ArrayBufferLike, compress?: boolean): void
close(code?: number, reason?: string): void
data: T
@ -29,48 +27,46 @@ export interface BunWebSocketData {
protocol: string
}
const createWSContext = (ws: BunServerWebSocket<BunWebSocketData>): WSContext => {
return {
/**
* @internal
*/
export const createWSContext = (ws: BunServerWebSocket<BunWebSocketData>): WSContext => {
return new WSContext({
send: (source, options) => {
const sendingData =
typeof source === 'string' ? source : source instanceof Uint8Array ? source.buffer : source
ws.send(sendingData, options?.compress)
ws.send(source, options?.compress)
},
raw: ws,
binaryType: 'arraybuffer',
readyState: ws.readyState,
url: ws.data.url,
protocol: ws.data.protocol,
close(code, reason) {
ws.close(code, reason)
},
}
})
}
export const createBunWebSocket = <T>(): CreateWebSocket<T> => {
const websocketConns: WSEvents[] = []
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const upgradeWebSocket: UpgradeWebSocket<any> = (createEvents) => {
return async (c, next) => {
const server = getBunServer(c)
if (!server) {
throw new TypeError('env has to include the 2nd argument of fetch.')
}
const connId = websocketConns.push(await createEvents(c)) - 1
const upgradeResult = server.upgrade<BunWebSocketData>(c.req.raw, {
data: {
connId,
url: new URL(c.req.url),
protocol: c.req.url,
},
})
if (upgradeResult) {
return new Response(null)
}
await next() // Failed
const upgradeWebSocket: UpgradeWebSocket<any> = defineWebSocketHelper((c, events) => {
const server = getBunServer(c)
if (!server) {
throw new TypeError('env has to include the 2nd argument of fetch.')
}
}
const connId = websocketConns.push(events) - 1
const upgradeResult = server.upgrade<BunWebSocketData>(c.req.raw, {
data: {
connId,
url: new URL(c.req.url),
protocol: c.req.url,
},
})
if (upgradeResult) {
return new Response(null)
}
return // failed
})
const websocket: BunWebSocketHandler<BunWebSocketData> = {
open(ws) {
const websocketListeners = websocketConns[ws.data.connId]

View File

@ -1,4 +1,5 @@
import { Hono } from '../..'
import { Context } from '../../context'
import { upgradeWebSocket } from '.'
describe('upgradeWebSocket middleware', () => {
@ -42,4 +43,18 @@ describe('upgradeWebSocket middleware', () => {
expect(sendingData).toBe(await wsPromise)
})
it('Should call next() when header does not have upgrade', async () => {
const next = vi.fn()
await upgradeWebSocket(() => ({}))(
new Context(
new Request('http://localhost', {
headers: {
Upgrade: 'example',
},
})
),
next
)
expect(next).toBeCalled()
})
})

View File

@ -1,50 +1,51 @@
import type { UpgradeWebSocket, WSContext, WSReadyState } from '../../helper/websocket'
import { WSContext, defineWebSocketHelper } from '../../helper/websocket'
import type { UpgradeWebSocket, WSReadyState } from '../../helper/websocket'
// Based on https://github.com/honojs/hono/issues/1153#issuecomment-1767321332
export const upgradeWebSocket: UpgradeWebSocket<WebSocket> = (createEvents) => async (c, next) => {
const events = await createEvents(c)
export const upgradeWebSocket: UpgradeWebSocket<WebSocket> = defineWebSocketHelper(
async (c, events) => {
const upgradeHeader = c.req.header('Upgrade')
if (upgradeHeader !== 'websocket') {
return
}
const upgradeHeader = c.req.header('Upgrade')
if (upgradeHeader !== 'websocket') {
return await next()
}
// @ts-expect-error WebSocketPair is not typed
const webSocketPair = new WebSocketPair()
const client: WebSocket = webSocketPair[0]
const server: WebSocket = webSocketPair[1]
// @ts-expect-error WebSocketPair is not typed
const webSocketPair = new WebSocketPair()
const client: WebSocket = webSocketPair[0]
const server: WebSocket = webSocketPair[1]
const wsContext = new WSContext<WebSocket>({
close: (code, reason) => server.close(code, reason),
get protocol() {
return server.protocol
},
raw: server,
get readyState() {
return server.readyState as WSReadyState
},
url: server.url ? new URL(server.url) : null,
send: (source) => server.send(source),
})
const wsContext: WSContext<WebSocket> = {
binaryType: 'arraybuffer',
close: (code, reason) => server.close(code, reason),
get protocol() {
return server.protocol
},
raw: server,
get readyState() {
return server.readyState as WSReadyState
},
url: server.url ? new URL(server.url) : null,
send: (source) => server.send(source),
}
if (events.onOpen) {
server.addEventListener('open', (evt: Event) => events.onOpen?.(evt, wsContext))
}
if (events.onClose) {
server.addEventListener('close', (evt: CloseEvent) => events.onClose?.(evt, wsContext))
}
if (events.onMessage) {
server.addEventListener('message', (evt: MessageEvent) => events.onMessage?.(evt, wsContext))
}
if (events.onError) {
server.addEventListener('error', (evt: Event) => events.onError?.(evt, wsContext))
}
if (events.onOpen) {
server.addEventListener('open', (evt: Event) => events.onOpen?.(evt, wsContext))
}
if (events.onClose) {
server.addEventListener('close', (evt: CloseEvent) => events.onClose?.(evt, wsContext))
}
if (events.onMessage) {
server.addEventListener('message', (evt: MessageEvent) => events.onMessage?.(evt, wsContext))
}
if (events.onError) {
server.addEventListener('error', (evt: Event) => events.onError?.(evt, wsContext))
}
// @ts-expect-error - server.accept is not typed
server.accept?.()
return new Response(null, {
status: 101,
// @ts-expect-error - webSocket is not typed
webSocket: client,
})
}
// @ts-expect-error - server.accept is not typed
server.accept?.()
return new Response(null, {
status: 101,
// @ts-expect-error - webSocket is not typed
webSocket: client,
})
}
)

View File

@ -1,4 +1,5 @@
import { Hono } from '../..'
import { Context } from '../../context'
import { upgradeWebSocket } from './websocket'
globalThis.Deno = {} as typeof Deno
@ -75,4 +76,18 @@ describe('WebSockets', () => {
)
expect(await messagePromise).toBe(data)
})
it('Should call next() when header does not have upgrade', async () => {
const next = vi.fn()
await upgradeWebSocket(() => ({}))(
new Context(
new Request('http://localhost', {
headers: {
Upgrade: 'example',
},
})
),
next
)
expect(next).toBeCalled()
})
})

View File

@ -1,4 +1,5 @@
import type { UpgradeWebSocket, WSContext, WSReadyState } from '../../helper/websocket'
import type { UpgradeWebSocket, WSReadyState } from '../../helper/websocket'
import { WSContext, defineWebSocketHelper } from '../../helper/websocket'
export interface UpgradeWebSocketOptions {
/**
@ -21,16 +22,14 @@ export interface UpgradeWebSocketOptions {
}
export const upgradeWebSocket: UpgradeWebSocket<WebSocket, UpgradeWebSocketOptions> =
(createEvents, options) => async (c, next) => {
defineWebSocketHelper(async (c, events, options) => {
if (c.req.header('upgrade') !== 'websocket') {
return await next()
return
}
const events = await createEvents(c)
const { response, socket } = Deno.upgradeWebSocket(c.req.raw, options || {})
const { response, socket } = Deno.upgradeWebSocket(c.req.raw, options ?? {})
const wsContext: WSContext<WebSocket> = {
binaryType: 'arraybuffer',
const wsContext: WSContext<WebSocket> = new WSContext({
close: (code, reason) => socket.close(code, reason),
get protocol() {
return socket.protocol
@ -41,11 +40,11 @@ export const upgradeWebSocket: UpgradeWebSocket<WebSocket, UpgradeWebSocketOptio
},
url: socket.url ? new URL(socket.url) : null,
send: (source) => socket.send(source),
}
})
socket.onopen = (evt) => events.onOpen?.(evt, wsContext)
socket.onmessage = (evt) => events.onMessage?.(evt, wsContext)
socket.onclose = (evt) => events.onClose?.(evt, wsContext)
socket.onerror = (evt) => events.onError?.(evt, wsContext)
return response
}
})

View File

@ -1,4 +1,6 @@
import { createWSMessageEvent } from '.'
import { Context } from '../../context'
import type { WSContestInit } from '.'
import { WSContext, createWSMessageEvent, defineWebSocketHelper } from '.'
describe('`createWSMessageEvent`', () => {
it('Should `createWSMessageEvent` is working for string', () => {
@ -12,3 +14,96 @@ describe('`createWSMessageEvent`', () => {
expect(event.type).toBe('message')
})
})
describe('defineWebSocketHelper', () => {
it('defineWebSocketHelper should work', async () => {
const upgradeWebSocket = defineWebSocketHelper(() => {
return new Response('Hello World', {
status: 200,
})
})
const response = await upgradeWebSocket(() => ({}))(
new Context(new Request('http://localhost')),
() => Promise.resolve()
)
expect(response).toBeTruthy()
expect((response as Response).status).toBe(200)
})
it('When response is undefined, should call next()', async () => {
const upgradeWebSocket = defineWebSocketHelper(() => {
return
})
const next = vi.fn()
await upgradeWebSocket(() => ({}))(new Context(new Request('http://localhost')), next)
expect(next).toBeCalled()
})
})
describe('WSContext', () => {
it('Should close() works', async () => {
type Result = [number | undefined, string | undefined]
let ws!: WSContext
const promise = new Promise<Result>((resolve) => {
ws = new WSContext({
close(code, reason) {
resolve([code, reason])
},
} as WSContestInit)
})
ws.close(0, 'reason')
const [code, reason] = await promise
expect(code).toBe(0)
expect(reason).toBe('reason')
})
it('Should send() works', async () => {
let ws!: WSContext
const promise = new Promise<string | ArrayBuffer>((resolve) => {
ws = new WSContext({
// eslint-disable-next-line @typescript-eslint/no-unused-vars
send(data, _options) {
resolve(data)
},
} as WSContestInit)
})
ws.send('Hello')
expect(await promise).toBe('Hello')
})
it('Should readyState works', () => {
const ws = new WSContext({
readyState: 0,
} as WSContestInit)
expect(ws.readyState).toBe(0)
})
it('Should normalize URL', () => {
const stringURLWS = new WSContext({
url: 'http://localhost',
} as WSContestInit)
expect(stringURLWS.url).toBeInstanceOf(URL)
const urlURLWS = new WSContext({
url: new URL('http://localhost'),
} as WSContestInit)
expect(urlURLWS.url).toBeInstanceOf(URL)
const nullURLWS = new WSContext({
url: undefined,
} as WSContestInit)
expect(nullURLWS.url).toBeNull()
})
it('Should normalize message in send()', () => {
let data: string | ArrayBuffer | null = null
const wsContext = new WSContext({
// eslint-disable-next-line @typescript-eslint/no-unused-vars
send(received, _options) {
data = received
},
} as WSContestInit)
wsContext.send('string')
expect(data).toBe('string')
wsContext.send(new ArrayBuffer(16))
expect(data).toBeInstanceOf(ArrayBuffer)
wsContext.send(new Uint8Array(16))
expect(data).toBeInstanceOf(ArrayBuffer)
})
})

View File

@ -31,21 +31,58 @@ export type UpgradeWebSocket<T = unknown, U = any> = (
}
>
/**
* ReadyState for WebSocket
*/
export type WSReadyState = 0 | 1 | 2 | 3
export type WSContext<T = unknown> = {
send(
source: string | ArrayBuffer | Uint8Array,
options?: {
compress: boolean
}
): void
/**
* An argument for WSContext class
*/
export interface WSContestInit<T = unknown> {
send(data: string | ArrayBuffer, options: SendOptions): void
close(code?: number, reason?: string): void
raw?: T
binaryType: BinaryType
readyState: WSReadyState
url?: string | URL | null
protocol?: string | null
}
/**
* Options for sending message
*/
export interface SendOptions {
compress?: boolean
}
/**
* A context for controlling WebSockets
*/
export class WSContext<T = unknown> {
#init: WSContestInit<T>
constructor(init: WSContestInit<T>) {
this.#init = init
this.raw = init.raw
this.url = init.url ? new URL(init.url) : null
this.protocol = init.protocol ?? null
}
send(source: string | ArrayBuffer | Uint8Array, options?: SendOptions): void {
this.#init.send(
typeof source === 'string' ? source : source instanceof Uint8Array ? source.buffer : source,
options ?? {}
)
}
raw?: T
binaryType: BinaryType = 'arraybuffer'
get readyState(): WSReadyState {
return this.#init.readyState
}
url: URL | null
protocol: string | null
close(code?: number, reason?: string): void
close(code?: number, reason?: string) {
this.#init.close(code, reason)
}
}
export type WSMessageReceive = string | Blob | ArrayBufferLike
@ -55,3 +92,28 @@ export const createWSMessageEvent = (source: WSMessageReceive): MessageEvent<WSM
data: source,
})
}
export interface WebSocketHelperDefineContext {}
export type WebSocketHelperDefineHandler<T, U> = (
c: Context,
events: WSEvents<T>,
options?: U
) => Promise<Response | void> | Response | void
/**
* Create a WebSocket adapter/helper
*/
export const defineWebSocketHelper = <T = unknown, U = any>(
handler: WebSocketHelperDefineHandler<T, U>
): UpgradeWebSocket<T, U> => {
return (createEvents, options) => {
return async function UpgradeWebSocket(c, next) {
const events = await createEvents(c)
const result = await handler(c, events, options)
if (result) {
return result
}
await next()
}
}
}