From ea3d799cdf199781c8a38496003aa32bad1d255e Mon Sep 17 00:00:00 2001 From: Taku Amano Date: Wed, 30 Oct 2024 11:04:05 +0900 Subject: [PATCH] fix(jsx/dom): fix `memo` for DOM renderer (#3568) Fixes #3473 Fixes #3567 * fix(jsx/dom): fix memoization mechanism in dom renderer * fix(jsx/dom): fix `memo` for DOM renderer * feat(jsx/dom): implement light weight `memo` function for DOM renderer * test(jsx/dom): add tests for memoization --- src/jsx/base.ts | 19 ++++- src/jsx/constants.ts | 1 + src/jsx/dom/index.test.tsx | 165 +++++++++++++++++++++++++++++-------- src/jsx/dom/index.ts | 14 +++- src/jsx/dom/render.ts | 27 +++--- 5 files changed, 176 insertions(+), 50 deletions(-) diff --git a/src/jsx/base.ts b/src/jsx/base.ts index d3b4d1c2..c2ddfa29 100644 --- a/src/jsx/base.ts +++ b/src/jsx/base.ts @@ -1,7 +1,7 @@ import { raw } from '../helper/html' import { escapeToBuffer, resolveCallbackSync, stringBufferToString } from '../utils/html' import type { HtmlEscaped, HtmlEscapedString, StringBufferWithCallbacks } from '../utils/html' -import { DOM_RENDERER } from './constants' +import { DOM_RENDERER, DOM_MEMO } from './constants' import type { Context } from './context' import { createContext, globalContexts, useContext } from './context' import { domRenderers } from './intrinsic-element/common' @@ -346,7 +346,7 @@ export const jsxFn = ( } } -const shallowEqual = (a: Props, b: Props): boolean => { +export const shallowEqual = (a: Props, b: Props): boolean => { if (a === b) { return true } @@ -373,19 +373,30 @@ const shallowEqual = (a: Props, b: Props): boolean => { return true } +export type MemorableFC = FC & { + [DOM_MEMO]: (prevProps: Readonly, nextProps: Readonly) => boolean +} export const memo = ( component: FC, propsAreEqual: (prevProps: Readonly, nextProps: Readonly) => boolean = shallowEqual ): FC => { let computed: ReturnType> = null let prevProps: T | undefined = undefined - return ((props) => { + const wrapper: MemorableFC = ((props: T) => { if (prevProps && !propsAreEqual(prevProps, props)) { computed = null } prevProps = props return (computed ||= component(props)) - }) as FC + }) as MemorableFC + + // This function is for toString(), but it can also be used for DOM renderer. + // So, set DOM_MEMO and DOM_RENDERER for DOM renderer. + wrapper[DOM_MEMO] = propsAreEqual + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ;(wrapper as any)[DOM_RENDERER] = component + + return wrapper as FC } export const Fragment = ({ diff --git a/src/jsx/constants.ts b/src/jsx/constants.ts index 7d1b0aae..e4644ca8 100644 --- a/src/jsx/constants.ts +++ b/src/jsx/constants.ts @@ -2,4 +2,5 @@ export const DOM_RENDERER = Symbol('RENDERER') export const DOM_ERROR_HANDLER = Symbol('ERROR_HANDLER') export const DOM_STASH = Symbol('STASH') export const DOM_INTERNAL_TAG = Symbol('INTERNAL') +export const DOM_MEMO = Symbol('MEMO') export const PERMALINK = Symbol('PERMALINK') diff --git a/src/jsx/dom/index.test.tsx b/src/jsx/dom/index.test.tsx index 71166ec3..a1876f43 100644 --- a/src/jsx/dom/index.test.tsx +++ b/src/jsx/dom/index.test.tsx @@ -282,7 +282,7 @@ describe('DOM', () => { }) }) - describe('skip build child', () => { + describe('child component', () => { it('simple', async () => { const Child = vi.fn(({ count }: { count: number }) =>
{count}
) const App = () => { @@ -301,11 +301,11 @@ describe('DOM', () => { root.querySelector('button')?.click() await Promise.resolve() expect(root.innerHTML).toBe('
1
0
') - expect(Child).toBeCalledTimes(1) + expect(Child).toBeCalledTimes(2) root.querySelector('button')?.click() await Promise.resolve() expect(root.innerHTML).toBe('
2
1
') - expect(Child).toBeCalledTimes(2) + expect(Child).toBeCalledTimes(3) }) }) @@ -1321,38 +1321,137 @@ describe('DOM', () => { }) }) - it('memo', async () => { - let renderCount = 0 - const Counter = ({ count }: { count: number }) => { - renderCount++ - return ( -
-

Count: {count}

-
+ describe('memo', () => { + it('simple', async () => { + let renderCount = 0 + const Counter = ({ count }: { count: number }) => { + renderCount++ + return ( +
+

Count: {count}

+
+ ) + } + const MemoCounter = memo(Counter) + const App = () => { + const [count, setCount] = useState(0) + return ( +
+ + +
+ ) + } + const app = + render(app, root) + expect(root.innerHTML).toBe('

Count: 0

') + expect(renderCount).toBe(1) + root.querySelector('button')?.click() + await Promise.resolve() + expect(root.innerHTML).toBe('

Count: 1

') + expect(renderCount).toBe(2) + root.querySelector('button')?.click() + await Promise.resolve() + expect(root.innerHTML).toBe('

Count: 1

') + expect(renderCount).toBe(2) + }) + + it('useState', async () => { + const Child = vi.fn(({ count }: { count: number }) => { + const [count2, setCount2] = useState(0) + return ( + <> +
+ {count} : {count2} +
+ + + ) + }) + const MemoChild = memo(Child) + const App = () => { + const [count, setCount] = useState(0) + return ( + <> + + + + ) + } + render(, root) + expect(root.innerHTML).toBe( + '
0 : 0
' ) - } - const MemoCounter = memo(Counter) - const App = () => { - const [count, setCount] = useState(0) - return ( -
- - -
+ root.querySelector('button#app-button')?.click() + await Promise.resolve() + expect(Child).toBeCalledTimes(1) + expect(root.innerHTML).toBe( + '
0 : 0
' ) - } - const app = - render(app, root) - expect(root.innerHTML).toBe('

Count: 0

') - expect(renderCount).toBe(1) - root.querySelector('button')?.click() - await Promise.resolve() - expect(root.innerHTML).toBe('

Count: 1

') - expect(renderCount).toBe(2) - root.querySelector('button')?.click() - await Promise.resolve() - expect(root.innerHTML).toBe('

Count: 1

') - expect(renderCount).toBe(2) + root.querySelector('button#app-button')?.click() + await Promise.resolve() + expect(Child).toBeCalledTimes(2) + expect(root.innerHTML).toBe( + '
1 : 0
' + ) + root.querySelector('button#child-button')?.click() + await Promise.resolve() + expect(Child).toBeCalledTimes(3) + expect(root.innerHTML).toBe( + '
1 : 1
' + ) + }) + + // The react compiler generates code like the following for memoization. + it('react compiler', async () => { + let renderCount = 0 + const Counter = ({ count }: { count: number }) => { + renderCount++ + return ( +
+

Count: {count}

+
+ ) + } + + const App = () => { + const [cache] = useState(() => []) + const [count, setCount] = useState(0) + const countForDisplay = Math.floor(count / 2) + + let localCounter + if (cache[0] !== countForDisplay) { + localCounter = + cache[0] = countForDisplay + cache[1] = localCounter + } else { + localCounter = cache[1] + } + + return ( +
+ {localCounter} + +
+ ) + } + const app = + render(app, root) + expect(root.innerHTML).toBe('

Count: 0

') + expect(renderCount).toBe(1) + root.querySelector('button')?.click() + await Promise.resolve() + expect(root.innerHTML).toBe('

Count: 0

') + expect(renderCount).toBe(1) + root.querySelector('button')?.click() + await Promise.resolve() + expect(root.innerHTML).toBe('

Count: 1

') + expect(renderCount).toBe(2) + }) }) describe('useRef', async () => { diff --git a/src/jsx/dom/index.ts b/src/jsx/dom/index.ts index 1949d2da..6d147c1c 100644 --- a/src/jsx/dom/index.ts +++ b/src/jsx/dom/index.ts @@ -3,9 +3,10 @@ * This module provides APIs for `hono/jsx/dom`. */ -import { isValidElement, memo, reactAPICompatVersion } from '../base' -import type { Child, DOMAttributes, JSX, JSXNode, Props } from '../base' +import { isValidElement, reactAPICompatVersion, shallowEqual } from '../base' +import type { Child, DOMAttributes, JSX, JSXNode, Props, FC, MemorableFC } from '../base' import { Children } from '../children' +import { DOM_MEMO } from '../constants' import { useContext } from '../context' import { createRef, @@ -72,6 +73,15 @@ const cloneElement = ( ) as T } +const memo = ( + component: FC, + propsAreEqual: (prevProps: Readonly, nextProps: Readonly) => boolean = shallowEqual +): FC => { + const wrapper = ((props: T) => component(props)) as MemorableFC + wrapper[DOM_MEMO] = propsAreEqual + return wrapper as FC +} + export { reactAPICompatVersion as version, createElement as jsx, diff --git a/src/jsx/dom/render.ts b/src/jsx/dom/render.ts index e23b1f55..877493f8 100644 --- a/src/jsx/dom/render.ts +++ b/src/jsx/dom/render.ts @@ -1,6 +1,12 @@ -import type { Child, FC, JSXNode, Props } from '../base' +import type { Child, FC, JSXNode, Props, MemorableFC } from '../base' import { toArray } from '../children' -import { DOM_ERROR_HANDLER, DOM_INTERNAL_TAG, DOM_RENDERER, DOM_STASH } from '../constants' +import { + DOM_ERROR_HANDLER, + DOM_INTERNAL_TAG, + DOM_MEMO, + DOM_RENDERER, + DOM_STASH, +} from '../constants' import type { Context as JSXContext } from '../context' import { globalContexts as globalJSXContexts, useContext } from '../context' import type { EffectData } from '../hooks' @@ -45,6 +51,7 @@ export type NodeObject = { e: SupportedElement | Text | undefined // rendered element p?: PreserveNodeType // preserve HTMLElement if it will be unmounted a?: boolean // cancel apply() if true + o?: NodeObject // original node [DOM_STASH]: | [ number, // current hook index @@ -516,15 +523,12 @@ export const build = (context: Context, node: NodeObject, children?: Child[]): v oldChild[DOM_STASH][2] = child[DOM_STASH][2] || [] oldChild[DOM_STASH][3] = child[DOM_STASH][3] - if (!oldChild.f) { - const prevPropsKeys = Object.keys(pP) - const currentProps = oldChild.props - if ( - prevPropsKeys.length === Object.keys(currentProps).length && - prevPropsKeys.every((k) => k in currentProps && currentProps[k] === pP[k]) - ) { - oldChild.s = true - } + if ( + !oldChild.f && + ((oldChild.o || oldChild) === child.o || // The code generated by the react compiler is memoized under this condition. + (oldChild.tag as MemorableFC)[DOM_MEMO]?.(pP, oldChild.props)) // The `memo` function is memoized under this condition. + ) { + oldChild.s = true } } child = oldChild @@ -626,6 +630,7 @@ export const buildNode = (node: Child): Node | undefined => { f: (node as NodeObject).f, type: (node as NodeObject).tag, ref: (node as NodeObject).props.ref, + o: (node as NodeObject).o || node, // eslint-disable-next-line @typescript-eslint/no-explicit-any } as any }