diff --git a/jsr.json b/jsr.json index 46e57a7b..ac3177da 100644 --- a/jsr.json +++ b/jsr.json @@ -21,6 +21,7 @@ "./basic-auth": "./src/middleware/basic-auth/index.ts", "./bearer-auth": "./src/middleware/bearer-auth/index.ts", "./body-limit": "./src/middleware/body-limit/index.ts", + "./ip-restriction": "./src/middleware/ip-restriction/index.ts", "./cache": "./src/middleware/cache/index.ts", "./cookie": "./src/helper/cookie/index.ts", "./accepts": "./src/helper/accepts/index.ts", @@ -93,7 +94,8 @@ "./utils/mime": "./src/utils/mime.ts", "./utils/stream": "./src/utils/stream.ts", "./utils/types": "./src/utils/types.ts", - "./utils/url": "./src/utils/url.ts" + "./utils/url": "./src/utils/url.ts", + "./utils/ipaddr": "./src/utils/ipaddr.ts" }, "publish": { "include": [ diff --git a/package.json b/package.json index 598e8837..6f657e54 100644 --- a/package.json +++ b/package.json @@ -78,6 +78,11 @@ "import": "./dist/middleware/body-limit/index.js", "require": "./dist/cjs/middleware/body-limit/index.js" }, + "./ip-restriction": { + "types": "./dist/types/middleware/ip-restriction/index.d.ts", + "import": "./dist/middleware/ip-restriction/index.js", + "require": "./dist/cjs/middleware/ip-restriction/index.js" + }, "./cache": { "types": "./dist/types/middleware/cache/index.d.ts", "import": "./dist/middleware/cache/index.js", @@ -385,6 +390,9 @@ "body-limit": [ "./dist/types/middleware/body-limit" ], + "ip-restriction": [ + "./dist/types/middleware/ip-restriction" + ], "cache": [ "./dist/types/middleware/cache" ], diff --git a/src/middleware/ip-restriction/index.test.ts b/src/middleware/ip-restriction/index.test.ts new file mode 100644 index 00000000..86904f03 --- /dev/null +++ b/src/middleware/ip-restriction/index.test.ts @@ -0,0 +1,114 @@ +import { Hono } from '../../hono' +import { Context } from '../../context' +import type { AddressType, GetConnInfo } from '../../helper/conninfo' +import { ipRestriction } from '.' +import type { IPRestrictionRule } from '.' + +describe('ipRestriction middleware', () => { + it('Should restrict', async () => { + const getConnInfo: GetConnInfo = (c) => { + return { + remote: { + address: c.env.ip, + }, + } + } + const app = new Hono<{ + Bindings: { + ip: string + } + }>() + app.use( + '/basic', + ipRestriction(getConnInfo, { + allowList: ['192.168.1.0', '192.168.2.0/24'], + denyList: ['192.168.2.10'], + }) + ) + app.get('/basic', (c) => c.text('Hello World!')) + + app.use( + '/allow-empty', + ipRestriction(getConnInfo, { + denyList: ['192.168.1.0'], + }) + ) + app.get('/allow-empty', (c) => c.text('Hello World!')) + + expect((await app.request('/basic', {}, { ip: '0.0.0.0' })).status).toBe(403) + + expect((await app.request('/basic', {}, { ip: '192.168.1.0' })).status).toBe(200) + + expect((await app.request('/basic', {}, { ip: '192.168.2.5' })).status).toBe(200) + expect((await app.request('/basic', {}, { ip: '192.168.2.10' })).status).toBe(403) + + expect((await app.request('/allow-empty', {}, { ip: '0.0.0.0' })).status).toBe(200) + + expect((await app.request('/allow-empty', {}, { ip: '192.168.1.0' })).status).toBe(403) + + expect((await app.request('/allow-empty', {}, { ip: '192.168.2.5' })).status).toBe(200) + expect((await app.request('/allow-empty', {}, { ip: '192.168.2.10' })).status).toBe(200) + }) + it('Custom onerror', async () => { + const res = await ipRestriction( + () => '0.0.0.0', + { denyList: ['0.0.0.0'] }, + () => new Response('error') + )(new Context(new Request('http://localhost/')), async () => void 0) + expect(res).toBeTruthy() + if (res) { + expect(await res.text()).toBe('error') + } + }) +}) + +describe('isMatchForRule', () => { + const isMatch = async (info: { addr: string; type: AddressType }, rule: IPRestrictionRule) => { + const middleware = ipRestriction( + () => ({ + remote: { + address: info.addr, + addressType: info.type, + }, + }), + { + allowList: [rule], + } + ) + try { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await middleware(undefined as any, () => Promise.resolve()) + } catch (e) { + return false + } + return true + } + + it('star', async () => { + expect(await isMatch({ addr: '192.168.2.0', type: 'IPv4' }, '*')).toBeTruthy() + expect(await isMatch({ addr: '192.168.2.1', type: 'IPv4' }, '*')).toBeTruthy() + expect(await isMatch({ addr: '::0', type: 'IPv6' }, '*')).toBeTruthy() + }) + it('CIDR Notation', async () => { + expect(await isMatch({ addr: '192.168.2.0', type: 'IPv4' }, '192.168.2.0/24')).toBeTruthy() + expect(await isMatch({ addr: '192.168.2.1', type: 'IPv4' }, '192.168.2.0/24')).toBeTruthy() + expect(await isMatch({ addr: '192.168.2.1', type: 'IPv4' }, '192.168.2.1/32')).toBeTruthy() + expect(await isMatch({ addr: '192.168.2.1', type: 'IPv4' }, '192.168.2.2/32')).toBeFalsy() + + expect(await isMatch({ addr: '::0', type: 'IPv6' }, '::0/1')).toBeTruthy() + }) + it('Static Rules', async () => { + expect(await isMatch({ addr: '192.168.2.1', type: 'IPv4' }, '192.168.2.1')).toBeTruthy() + expect(await isMatch({ addr: '1234::5678', type: 'IPv6' }, '1234::5678')).toBeTruthy() + }) + it('Function Rules', async () => { + expect(await isMatch({ addr: '0.0.0.0', type: 'IPv4' }, () => true)).toBeTruthy() + expect(await isMatch({ addr: '0.0.0.0', type: 'IPv4' }, () => false)).toBeFalsy() + + const ipaddr = '93.184.216.34' + await isMatch({ addr: ipaddr, type: 'IPv4' }, (ip) => { + expect(ipaddr).toBe(ip.addr) + return false + }) + }) +}) diff --git a/src/middleware/ip-restriction/index.ts b/src/middleware/ip-restriction/index.ts new file mode 100644 index 00000000..2c0db47c --- /dev/null +++ b/src/middleware/ip-restriction/index.ts @@ -0,0 +1,178 @@ +/** + * Middleware for restrict IP Address + * @module + */ + +import type { Context, MiddlewareHandler } from '../..' +import type { AddressType, GetConnInfo } from '../../helper/conninfo' +import { HTTPException } from '../../http-exception' +import { + convertIPv4ToBinary, + convertIPv6BinaryToString, + convertIPv6ToBinary, + distinctRemoteAddr, +} from '../../utils/ipaddr' + +/** + * Function to get IP Address + */ +type GetIPAddr = GetConnInfo | ((c: Context) => string) + +/** + * ### IPv4 and IPv6 + * - `*` match all + * + * ### IPv4 + * - `192.168.2.0` static + * - `192.168.2.0/24` CIDR Notation + * + * ### IPv6 + * - `::1` static + * - `::1/10` CIDR Notation + */ +type IPRestrictionRuleFunction = (addr: { addr: string; type: AddressType }) => boolean +export type IPRestrictionRule = string | ((addr: { addr: string; type: AddressType }) => boolean) + +const IS_CIDR_NOTATION_REGEX = /\/[0-9]{0,3}$/ +const buildMatcher = ( + rules: IPRestrictionRule[] +): ((addr: { addr: string; type: AddressType; isIPv4: boolean }) => boolean) => { + const functionRules: IPRestrictionRuleFunction[] = [] + const staticRules: Set = new Set() + const cidrRules: [boolean, bigint, bigint][] = [] + + for (let rule of rules) { + if (rule === '*') { + return () => true + } else if (typeof rule === 'function') { + functionRules.push(rule) + } else { + if (IS_CIDR_NOTATION_REGEX.test(rule)) { + const splittedRule = rule.split('/') + + const addrStr = splittedRule[0] + const type = distinctRemoteAddr(addrStr) + if (type === undefined) { + throw new TypeError(`Invalid rule: ${rule}`) + } + + const isIPv4 = type === 'IPv4' + const prefix = parseInt(splittedRule[1]) + + if (isIPv4 ? prefix === 32 : prefix === 128) { + // this rule is a static rule + rule = addrStr + } else { + const addr = (isIPv4 ? convertIPv4ToBinary : convertIPv6ToBinary)(addrStr) + const mask = ((1n << BigInt(prefix)) - 1n) << BigInt((isIPv4 ? 32 : 128) - prefix) + + cidrRules.push([isIPv4, addr & mask, mask] as [boolean, bigint, bigint]) + continue + } + } + + const type = distinctRemoteAddr(rule) + if (type === undefined) { + throw new TypeError(`Invalid rule: ${rule}`) + } + staticRules.add( + type === 'IPv4' + ? rule // IPv4 address is already normalized, so it is registered as is. + : convertIPv6BinaryToString(convertIPv6ToBinary(rule)) // normalize IPv6 address (e.g. 0000:0000:0000:0000:0000:0000:0000:0001 => ::1) + ) + } + } + + return (remote: { + addr: string + type: AddressType + isIPv4: boolean + binaryAddr?: bigint + }): boolean => { + if (staticRules.has(remote.addr)) { + return true + } + for (const [isIPv4, addr, mask] of cidrRules) { + if (isIPv4 !== remote.isIPv4) { + continue + } + const remoteAddr = (remote.binaryAddr ||= ( + isIPv4 ? convertIPv4ToBinary : convertIPv6ToBinary + )(remote.addr)) + if ((remoteAddr & mask) === addr) { + return true + } + } + for (const rule of functionRules) { + if (rule({ addr: remote.addr, type: remote.type })) { + return true + } + } + return false + } +} + +/** + * Rules for IP Limit Middleware + */ +export interface IPRestrictionRules { + denyList?: IPRestrictionRule[] + allowList?: IPRestrictionRule[] +} + +/** + * IP Limit Middleware + * + * @param getIP function to get IP Address + */ +export const ipRestriction = ( + getIP: GetIPAddr, + { denyList = [], allowList = [] }: IPRestrictionRules, + onError?: ( + remote: { addr: string; type: AddressType }, + c: Context + ) => Response | Promise +): MiddlewareHandler => { + const allowLength = allowList.length + + const denyMatcher = buildMatcher(denyList) + const allowMatcher = buildMatcher(allowList) + + const blockError = (c: Context): HTTPException => + new HTTPException(403, { + res: c.text('Forbidden', { + status: 403, + }), + }) + + return async function (c, next) { + const connInfo = getIP(c) + const addr = typeof connInfo === 'string' ? connInfo : connInfo.remote.address + if (!addr) { + throw blockError(c) + } + const type = + (typeof connInfo !== 'string' && connInfo.remote.addressType) || distinctRemoteAddr(addr) + + const remoteData = { addr, type, isIPv4: type === 'IPv4' } + + if (denyMatcher(remoteData)) { + if (onError) { + return onError({ addr, type }, c) + } + throw blockError(c) + } + if (allowMatcher(remoteData)) { + return await next() + } + + if (allowLength === 0) { + return await next() + } else { + if (onError) { + return await onError({ addr, type }, c) + } + throw blockError(c) + } + } +} diff --git a/src/utils/ipaddr.test.ts b/src/utils/ipaddr.test.ts new file mode 100644 index 00000000..d4cff8e1 --- /dev/null +++ b/src/utils/ipaddr.test.ts @@ -0,0 +1,61 @@ +import { + convertIPv4ToBinary, + convertIPv6BinaryToString, + convertIPv6ToBinary, + distinctRemoteAddr, + expandIPv6, +} from './ipaddr' + +describe('expandIPv6', () => { + it('Should result be valid', () => { + expect(expandIPv6('1::1')).toBe('0001:0000:0000:0000:0000:0000:0000:0001') + expect(expandIPv6('::1')).toBe('0000:0000:0000:0000:0000:0000:0000:0001') + expect(expandIPv6('2001:2::')).toBe('2001:0002:0000:0000:0000:0000:0000:0000') + expect(expandIPv6('2001:2::')).toBe('2001:0002:0000:0000:0000:0000:0000:0000') + expect(expandIPv6('2001:0:0:db8::1')).toBe('2001:0000:0000:0db8:0000:0000:0000:0001') + }) +}) +describe('distinctRemoteAddr', () => { + it('Should result be valud', () => { + expect(distinctRemoteAddr('1::1')).toBe('IPv6') + expect(distinctRemoteAddr('::1')).toBe('IPv6') + + expect(distinctRemoteAddr('192.168.2.0')).toBe('IPv4') + expect(distinctRemoteAddr('192.168.2.0')).toBe('IPv4') + + expect(distinctRemoteAddr('example.com')).toBeUndefined() + }) +}) + +describe('convertIPv4ToBinary', () => { + it('Should result is valid', () => { + expect(convertIPv4ToBinary('0.0.0.0')).toBe(0n) + expect(convertIPv4ToBinary('0.0.0.1')).toBe(1n) + + expect(convertIPv4ToBinary('0.0.1.0')).toBe(1n << 8n) + }) +}) +describe('convertIPv6ToBinary', () => { + it('Should result is valid', () => { + expect(convertIPv6ToBinary('::0')).toBe(0n) + expect(convertIPv6ToBinary('::1')).toBe(1n) + + expect(convertIPv6ToBinary('::f')).toBe(15n) + expect(convertIPv6ToBinary('1234:::5678')).toBe(24196103360772296748952112894165669496n) + }) +}) + +describe('convertIPv6ToString', () => { + // add tons of test cases here + test.each` + input | expected + ${'::1'} | ${'::1'} + ${'1::'} | ${'1::'} + ${'1234:::5678'} | ${'1234::5678'} + ${'2001:2::'} | ${'2001:2::'} + ${'2001::db8:0:0:0:0:1'} | ${'2001:0:db8::1'} + ${'1234:5678:9abc:def0:1234:5678:9abc:def0'} | ${'1234:5678:9abc:def0:1234:5678:9abc:def0'} + `('convertIPv6ToString($input) === $expected', ({ input, expected }) => { + expect(convertIPv6BinaryToString(convertIPv6ToBinary(input))).toBe(expected) + }) +}) diff --git a/src/utils/ipaddr.ts b/src/utils/ipaddr.ts new file mode 100644 index 00000000..925a98f5 --- /dev/null +++ b/src/utils/ipaddr.ts @@ -0,0 +1,113 @@ +/** + * Utils for IP Addresses + * @module + */ + +import type { AddressType } from '../helper/conninfo' + +/** + * Expand IPv6 Address + * @param ipV6 Shorten IPv6 Address + * @return expanded IPv6 Address + */ +export const expandIPv6 = (ipV6: string): string => { + const sections = ipV6.split(':') + for (let i = 0; i < sections.length; i++) { + const node = sections[i] + if (node !== '') { + sections[i] = node.padStart(4, '0') + } else { + sections[i + 1] === '' && sections.splice(i + 1, 1) + sections[i] = new Array(8 - sections.length + 1).fill('0000').join(':') + } + } + return sections.join(':') +} + +const IPV4_REGEX = /^[0-9]{0,3}\.[0-9]{0,3}\.[0-9]{0,3}\.[0-9]{0,3}$/ + +/** + * Distinct Remote Addr + * @param remoteAddr Remote Addr + */ +export const distinctRemoteAddr = (remoteAddr: string): AddressType => { + if (IPV4_REGEX.test(remoteAddr)) { + return 'IPv4' + } + if (remoteAddr.includes(':')) { + // Domain can't include `:` + return 'IPv6' + } +} + +/** + * Convert IPv4 to Uint8Array + * @param ipv4 IPv4 Address + * @returns BigInt + */ +export const convertIPv4ToBinary = (ipv4: string): bigint => { + const parts = ipv4.split('.') + let result = 0n + for (let i = 0; i < 4; i++) { + result <<= 8n + result += BigInt(parts[i]) + } + return result +} + +/** + * Convert IPv6 to Uint8Array + * @param ipv6 IPv6 Address + * @returns BigInt + */ +export const convertIPv6ToBinary = (ipv6: string): bigint => { + const sections = expandIPv6(ipv6).split(':') + let result = 0n + for (let i = 0; i < 8; i++) { + result <<= 16n + result += BigInt(parseInt(sections[i], 16)) + } + return result +} + +/** + * Convert a binary representation of an IPv6 address to a string. + * @param ipV6 binary IPv6 Address + * @return normalized IPv6 Address in string + */ +export const convertIPv6BinaryToString = (ipV6: bigint): string => { + const sections = [] + for (let i = 0; i < 8; i++) { + sections.push(((ipV6 >> BigInt(16 * (7 - i))) & 0xffffn).toString(16)) + } + + let currentZeroStart = -1 + let maxZeroStart = -1 + let maxZeroEnd = -1 + for (let i = 0; i < 8; i++) { + if (sections[i] === '0') { + if (currentZeroStart === -1) { + currentZeroStart = i + } + } else { + if (currentZeroStart > -1) { + if (i - currentZeroStart > maxZeroEnd - maxZeroStart) { + maxZeroStart = currentZeroStart + maxZeroEnd = i + } + currentZeroStart = -1 + } + } + } + if (currentZeroStart > -1) { + if (8 - currentZeroStart > maxZeroEnd - maxZeroStart) { + maxZeroStart = currentZeroStart + maxZeroEnd = 8 + } + } + if (maxZeroStart !== -1) { + sections.splice(maxZeroStart, maxZeroEnd - maxZeroStart, ':') + } + + return sections.join(':').replace(/:{2,}/g, '::') +}