From 3a0b1f4dc7dd7b2c136f3575077ee88a14fee55a Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 23 Dec 2024 16:31:02 -0500 Subject: [PATCH] [nnx] cache flatten --- flax/nnx/graph.py | 319 +++++++++++++++++++++++++++-- flax/nnx/nn/stochastic.py | 3 + flax/nnx/rnglib.py | 6 +- flax/nnx/transforms/compilation.py | 34 ++- tests/nnx/graph_utils_test.py | 41 +++- tests/nnx/transforms_test.py | 6 + 6 files changed, 379 insertions(+), 30 deletions(-) diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 44750d630b..f8c4d30875 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -20,6 +20,7 @@ import functools import threading import typing as tp +from weakref import WeakKeyDictionary import jax import numpy as np @@ -74,6 +75,9 @@ def __init__( self._mapping: dict[int, tuple[A, B]] = {} self.update(mapping) + def copy(self) -> RefMap[A, B]: + return RefMap(self) + def __getitem__(self, key: A) -> B: return self._mapping[id(key)][1] @@ -368,6 +372,7 @@ def flatten( *, ref_index: RefMap[tp.Any, Index] | None = None, with_paths: tp.Literal[True] = True, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[GraphDef[Node], FlatState[VariableState[tp.Any]]]: ... @tp.overload def flatten( @@ -376,6 +381,7 @@ def flatten( *, ref_index: RefMap[tp.Any, Index] | None = None, with_paths: tp.Literal[False], + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[GraphDef[Node], list[tp.Any]]: ... def flatten( node: Node, @@ -383,6 +389,7 @@ def flatten( *, ref_index: RefMap[tp.Any, Index] | None = None, with_paths: bool = True, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[GraphDef[Node], FlatState[VariableState[tp.Any]] | list[tp.Any]]: """Flattens a graph node into a (graphdef, state) pair. @@ -396,10 +403,79 @@ def flatten( """ if ref_index is None: ref_index = RefMap() - leaves: list[StateLeaf] = [] - path: list[Key] | None = [] if with_paths else None - paths: list[PathParts] | None = [] if with_paths else None - graphdef = _graph_flatten(node, path, ref_index, leaves, paths) + + if node in ref_index: + graphdef = NodeRef(type(node), ref_index[node]) + if with_paths: + return graphdef, FlatState.from_sorted_keys_values([], []) + else: + return graphdef, [] + + # main flatten function + def do_flatten(*, with_paths: bool, return_variables: bool): + assert ref_index is not None + leaves: list[StateLeaf | Variable[tp.Any]] = [] + path: list[Key] | None = [] if with_paths else None + paths: list[PathParts] | None = [] if with_paths else None + graphdef = _graph_flatten( + node, path, ref_index, leaves, paths, return_variables + ) + return graphdef, paths, leaves + + # cache logic + if cache_context is None: + graphdef, paths, leaves = do_flatten( + with_paths=with_paths, return_variables=False + ) + elif node in cache_context: + node_cache = cache_context[node] + cache_fp = node_cache.fingerprint + new_ref_index = RefMap() + node_fp = fingerprint( + node, ref_index=ref_index, new_ref_index=new_ref_index + ) + if cache_fp == node_fp: + graphdef = node_cache.graphdef + + if with_paths: + paths = node_cache.paths + leaves = [variable.to_state() for variable in node_cache.variables] + else: + paths = None + leaves = [variable.raw_value for variable in node_cache.variables] + + # add the new references to the ref_index + ref_index.update(new_ref_index) + else: + graphdef, paths, variables = do_flatten( + with_paths=True, return_variables=True + ) + variables = tp.cast(list[Variable[tp.Any]], variables) + assert paths is not None + if with_paths: + leaves = [variable.to_state() for variable in variables] + else: + leaves = [variable.raw_value for variable in variables] + node_cache = CacheContext.create( + node_fp, graphdef, paths, variables, new_ref_index + ) + else: # node not in cache_context + new_ref_index = RefMap() + node_fp = fingerprint( + node, ref_index=ref_index, new_ref_index=new_ref_index + ) + graphdef, paths, variables = do_flatten( + with_paths=True, return_variables=True + ) + variables = tp.cast(list[Variable[tp.Any]], variables) + assert paths is not None + if with_paths: + leaves = [variable.to_state() for variable in variables] + else: + leaves = [variable.raw_value for variable in variables] + node_cache = CacheContext.create( + node_fp, graphdef, paths, variables, new_ref_index + ) if paths is not None: return graphdef, FlatState.from_sorted_keys_values(paths, leaves) @@ -411,8 +487,9 @@ def _graph_flatten( node: Node, path: list[Key] | None, ref_index: RefMap[tp.Any, Index], - leaves: list[StateLeaf], + leaves: list[StateLeaf | Variable[tp.Any]], paths: list[PathParts] | None, + return_variables: bool, ) -> NodeDef[Node] | NodeRef: if not is_node(node): raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') @@ -438,13 +515,17 @@ def _graph_flatten( if path is not None: path.append(key) if is_node(value): - nodedef = _graph_flatten(value, path, ref_index, leaves, paths) + nodedef = _graph_flatten( + value, path, ref_index, leaves, paths, return_variables + ) attributes.append((key, nodedef)) elif isinstance(value, Variable): if value in ref_index: attributes.append((key, NodeRef(type(value), ref_index[value]))) else: - if path is None: + if return_variables: + leaf = value + elif path is None: leaf = value.raw_value else: leaf = value.to_state() @@ -481,6 +562,84 @@ def _graph_flatten( ) return nodedef +def fingerprint( + node, + /, + *, + ref_index: RefMap[tp.Any, Index] | None = None, + new_ref_index: RefMap[tp.Any, Index] | None = None, +) -> tuple[tp.Any, ...]: + """ """ + if ref_index is None: + ref_index = RefMap() + + if new_ref_index is None: + new_ref_index = RefMap() + fp = _graph_fingerprint(node, ref_index, new_ref_index) + return fp + + +def _graph_fingerprint( + node, + ref_index: RefMap[tp.Any, Index], + new_ref_index: RefMap[tp.Any, Index], +) -> tuple[tp.Any, ...]: + if not is_node(node): + raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') + + if node in ref_index: + return (id(node), type(node), ref_index[node]) + elif node in new_ref_index: + return (id(node), type(node), new_ref_index[node]) + + node_impl = get_node_impl(node) + + # only cache graph nodes + if isinstance(node_impl, GraphNodeImpl): + index = len(ref_index) + len(new_ref_index) + new_ref_index[node] = index + else: + index = -1 + + attributes: list[tuple[tp.Any, ...]] = [] + + values, metadata = node_impl.flatten(node) + for key, value in values: + if is_node(value): + node_fp = _graph_fingerprint(value, ref_index, new_ref_index) + attributes.append((key, node_fp)) + elif isinstance(value, Variable): + if value in ref_index: + attributes.append((key, id(value), type(value), ref_index[value])) + elif value in new_ref_index: + attributes.append((key, id(value), type(value), new_ref_index[value])) + else: + variable_index = new_ref_index[value] = len(ref_index) + # the fingerprint must be sensitive to Variable identity + attributes.append( + ( + key, + id(value), + type(value), + variable_index, + tuple(value._var_metadata.items()), + ) + ) + else: + if isinstance(value, (jax.Array, np.ndarray)): + raise ValueError(f'Arrays leaves are not supported: {value}') + attributes.append((key, value)) + + node_fp = ( + id(node), + node_impl.type, + index, + tuple(attributes), + metadata, + ) + return node_fp + + def _get_sorted_leaves( xs: tp.Mapping[tp.Any, tp.Any], ) -> deque[tp.Any]: @@ -507,6 +666,7 @@ def unflatten( *, index_ref: dict[Index, tp.Any] | None = None, index_ref_cache: dict[Index, tp.Any] | None = None, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> Node: """Unflattens a graphdef into a node with the given state. @@ -533,14 +693,55 @@ def unflatten( raise ValueError(f'Unsupported state type: {type(state)}') if index_ref is None: index_ref = {} - assert isinstance(graphdef, (NodeDef, NodeRef)) - node = _graph_unflatten(graphdef, leaves, index_ref, index_ref_cache) - if leaves: + + if isinstance(graphdef, NodeRef): + return index_ref[graphdef.index] + + assert isinstance(graphdef, NodeDef) + + def do_unflatten(): + node = _graph_unflatten(graphdef, leaves, index_ref, index_ref_cache) + if leaves: + raise ValueError( + f'Incorrect number of leaves: got an extra {len(leaves)} leaves in the state' + ) + return node + + if cache_context is None: + node = do_unflatten() + elif index_ref_cache is None: raise ValueError( - f'Incorrect number of leaves: got an extra {len(leaves)} leaves in the state' + 'index_ref_cache must be provided when cache_context is used.' ) + elif graphdef.index in index_ref_cache: + node = index_ref_cache[graphdef.index] + if node in cache_context: + # node is in cache_context, retrieve its cache + cache = cache_context[node] + assert graphdef.index_mapping is not None + + # check if the graphdef is the same and index_mapping maps to the same references + graphdef_fp = dataclasses.replace(graphdef, index_mapping=None) + if cache.graphdef == graphdef_fp and all( + a == b for a, b in graphdef.index_mapping.items() + ): + # graphdefs match, update variables from state + for variable, leaf in zip(cache.variables, leaves): + variable.raw_value = leaf + index_ref.update(cache.new_ref_index) + else: # cache.graphdef != graphdef_fp + # graph changed, re-create the node + node = do_unflatten() + else: # node not in cache_context + # all nodes in index_ref_cache must be in cache_context + raise RuntimeError(f'Node not found in cache_context, node: {node}') + else: + # its a new node, create it + node = do_unflatten() + return node + def _graph_unflatten( nodedef: NodeDef[Node] | NodeRef[Node], leaves: deque[tp.Any], @@ -755,6 +956,28 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]): # UpdateContext # -------------------------------------------------------- +class CacheContext(tp.NamedTuple): + fingerprint: tuple[tp.Any, ...] + graphdef: GraphDef[tp.Any] + paths: list[PathParts] + variables: list[Variable[tp.Any]] + new_ref_index: RefMap[tp.Any, Index] + new_index_ref: dict[Index, tp.Any] + + @staticmethod + def create( + fingerprint: tuple[tp.Any, ...], + graphdef: GraphDef[tp.Any], + paths: list[PathParts], + variables: list[Variable[tp.Any]], + new_ref_index: RefMap[tp.Any, Index], + ) -> CacheContext: + new_index_ref = {index: obj for obj, index in new_ref_index.items()} + return CacheContext( + fingerprint, graphdef, paths, variables, new_ref_index, new_index_ref + ) + + @dataclasses.dataclass class GraphContext(threading.local): update_context_stacks: dict[str, list[UpdateContext]] = dataclasses.field( @@ -762,6 +985,9 @@ class GraphContext(threading.local): ) ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list) index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list) + cache_context: WeakKeyDictionary[ + tp.Callable, WeakKeyDictionary[tp.Any, CacheContext] + ] = dataclasses.field(default_factory=WeakKeyDictionary) GRAPH_CONTEXT = GraphContext() @@ -773,10 +999,21 @@ class SplitContext: ref_index: RefMap[tp.Any, Index] @tp.overload - def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... + def split( + self, + graph_node: A, + /, + *, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, + ) -> tuple[GraphDef[A], GraphState]: ... @tp.overload def split( - self, graph_node: A, first: filterlib.Filter, / + self, + graph_node: A, + first: filterlib.Filter, + /, + *, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[GraphDef[A], GraphState]: ... @tp.overload def split( @@ -786,14 +1023,20 @@ def split( second: filterlib.Filter, /, *filters: filterlib.Filter, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ... def split( - self, node: A, *filters: filterlib.Filter + self, + node: A, + *filters: filterlib.Filter, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[GraphDef[A], tpe.Unpack[tuple[GraphState, ...]]]: ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) - graphdef, flat_state = flatten(node, ref_index=self.ref_index) + graphdef, flat_state = flatten( + node, ref_index=self.ref_index, cache_context=cache_context + ) flat_states = _split_state(flat_state, filters) states = tuple( State.from_flat_path(flat_state) for flat_state in flat_states @@ -809,15 +1052,29 @@ def split( @tp.overload def flatten( - self, graph_node: A, /, *, with_paths: tp.Literal[False] + self, + graph_node: A, + /, + *, + with_paths: tp.Literal[False], + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[GraphDef[A], list[tp.Any]]: ... @tp.overload def flatten( - self, graph_node: A, / + self, + graph_node: A, + /, + *, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ... @tp.overload def flatten( - self, graph_node: A, first: filterlib.Filter, / + self, + graph_node: A, + first: filterlib.Filter, + /, + *, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ... @tp.overload def flatten( @@ -827,13 +1084,18 @@ def flatten( second: filterlib.Filter, /, *filters: filterlib.Filter, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[ GraphDef[A], FlatState[VariableState[tp.Any]], tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]], ]: ... def flatten( - self, node: A, *filters: filterlib.Filter, with_paths: bool = True + self, + node: A, + *filters: filterlib.Filter, + with_paths: bool = True, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> tuple[ GraphDef[A], FlatState[VariableState[tp.Any]] | list[tp.Any], @@ -844,14 +1106,20 @@ def flatten( ) if with_paths: graphdef, flat_state = flatten( - node, ref_index=self.ref_index, with_paths=True + node, + ref_index=self.ref_index, + with_paths=True, + cache_context=cache_context, ) flat_states = _split_state(flat_state, filters) else: if filters: raise ValueError('Cannot use filters with with_paths=False') graphdef, flat_state = flatten( - node, ref_index=self.ref_index, with_paths=False + node, + ref_index=self.ref_index, + with_paths=False, + cache_context=cache_context, ) flat_states = (flat_state,) @@ -888,7 +1156,12 @@ class MergeContext: index_ref: dict[Index, tp.Any] def merge( - self, graphdef: GraphDef[A], state: GraphState, /, *states: GraphState + self, + graphdef: GraphDef[A], + state: GraphState, + /, + *states: GraphState, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> A: ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None @@ -913,6 +1186,7 @@ def merge( state, index_ref=self.index_ref, index_ref_cache=index_ref_cache, + cache_context=cache_context, ) return node @@ -922,6 +1196,7 @@ def unflatten( flat_state: GraphFlatState | list[tp.Any], /, *flat_states: GraphFlatState, + cache_context: WeakKeyDictionary[tp.Any, CacheContext] | None = None, ) -> A: ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None diff --git a/flax/nnx/nn/stochastic.py b/flax/nnx/nn/stochastic.py index 2a495826a4..737c6e3102 100644 --- a/flax/nnx/nn/stochastic.py +++ b/flax/nnx/nn/stochastic.py @@ -125,3 +125,6 @@ def __call__( mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) mask = jnp.broadcast_to(mask, inputs.shape) return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) + + def __hash__(self): + return id(self) diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index 17bbaf37c8..bc4b551972 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -80,7 +80,7 @@ def __call__(self) -> jax.Array: ] -class Rngs(Object, tp.Mapping[str, tp.Callable[[], jax.Array]]): +class Rngs(Object): """NNX rng container class. To instantiate the ``Rngs``, pass in an integer, specifying the starting seed. ``Rngs`` can have different "streams", allowing the user to generate different @@ -237,6 +237,10 @@ def __getstate__(self): def __setstate__(self, state): vars(self).update(state) + def items(self): + for name in self: + yield name, self[name] + class ForkStates(tp.NamedTuple): split_keys: State diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index 4763ea8491..655a1be2ab 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -16,6 +16,7 @@ import dataclasses import functools import typing as tp +from weakref import WeakKeyDictionary from flax.nnx import ( extract, @@ -88,7 +89,7 @@ def __hash__(self): return hash((self.filters, self.shardings)) -def _jit_split_fn(ctx: graph.SplitContext, path, prefix, x): +def _inner_jit_split_fn(ctx: graph.SplitContext, path, prefix, x): if isinstance(prefix, StateSharding): return extract.NodeStates.from_split( *ctx.flatten(x, *prefix.filters), metadata=prefix @@ -96,7 +97,7 @@ def _jit_split_fn(ctx: graph.SplitContext, path, prefix, x): return extract.NodeStates.from_split(*ctx.flatten(x, with_paths=False)) -def _jit_merge_fn(ctx: graph.MergeContext, path, prefix, leaf) -> tp.Any: +def _inner_jit_merge_fn(ctx: graph.MergeContext, path, prefix, leaf) -> tp.Any: if not isinstance(leaf, extract.NodeStates): raise ValueError(f'Expected TreeNode, got {type(leaf)} at path {path}') return ctx.unflatten(leaf.graphdef, *leaf.states) @@ -114,7 +115,7 @@ def __post_init__(self): def __call__(self, *pure_args, **pure_kwargs): args, kwargs = extract.from_tree( - (pure_args, pure_kwargs), merge_fn=_jit_merge_fn, ctxtag='jit' + (pure_args, pure_kwargs), merge_fn=_inner_jit_merge_fn, ctxtag='jit' ) out = self.f(*args, **kwargs) @@ -124,7 +125,7 @@ def __call__(self, *pure_args, **pure_kwargs): (args_out, kwargs_out, out), prefix=(self.in_shardings, self.kwarg_shardings, self.out_shardings), ctxtag='jit', - split_fn=_jit_split_fn, + split_fn=_inner_jit_split_fn, ) return pure_args_out, pure_kwargs_out, pure_out @@ -343,10 +344,31 @@ def jit( @functools.wraps(fun) @graph.update_context('jit') def jit_wrapper(*args, **kwargs): + if jit_wrapper not in graph.GRAPH_CONTEXT.cache_context: + graph.GRAPH_CONTEXT.cache_context[jit_wrapper] = WeakKeyDictionary() + jit_cache = graph.GRAPH_CONTEXT.cache_context[jit_wrapper] + + def _outer_jit_split_fn(ctx: graph.SplitContext, path, prefix, x): + if isinstance(prefix, StateSharding): + return extract.NodeStates.from_split( + *ctx.flatten(x, *prefix.filters, cache_context=jit_cache), + metadata=prefix, + ) + return extract.NodeStates.from_split( + *ctx.flatten(x, cache_context=jit_cache, with_paths=False) + ) + + def _outer_jit_merge_fn( + ctx: graph.MergeContext, path, prefix, leaf + ) -> tp.Any: + if not isinstance(leaf, extract.NodeStates): + raise ValueError(f'Expected TreeNode, got {type(leaf)} at path {path}') + return ctx.unflatten(leaf.graphdef, *leaf.states, cache_context=jit_cache) + pure_args, pure_kwargs = extract.to_tree( (args, kwargs), prefix=(in_shardings, kwarg_shardings), - split_fn=_jit_split_fn, + split_fn=_outer_jit_split_fn, check_aliasing=in_shardings is not None, ctxtag='jit', ) @@ -355,7 +377,7 @@ def jit_wrapper(*args, **kwargs): ) _args_out, _kwargs_out, out = extract.from_tree( (pure_args_out, pure_kwargs_out, pure_out), - merge_fn=_jit_merge_fn, + merge_fn=_outer_jit_merge_fn, ctxtag='jit', ) return out diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index daad0873a1..c0330d7764 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -321,7 +321,7 @@ def __init__(self): assert 'tree' in state assert 'a' in state.tree - assert graphdef.attributes[0].value.type is nnx.graph.GenericPytree + assert graphdef.attributes[0][1].type is nnx.graph.GenericPytree m2 = nnx.merge(graphdef, state) @@ -817,6 +817,45 @@ def f(*pure_args): self.assertIs(m1, args_out[2]['b']) self.assertIs(m2, args_out[1]) + def test_fingerprint_basic(self): + m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + fp1 = nnx.graph.fingerprint(m) + m1_hash = hash(fp1) + self.assertIsInstance(m1_hash, int) + + fp2 = nnx.graph.fingerprint(m) + m2_hash = hash(fp2) + + self.assertEqual(fp1, fp2) + self.assertEqual(m1_hash, m2_hash) + + def test_fingerprint_variable_id_sensitive(self): + m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + fp1 = nnx.graph.fingerprint(m1) + m1_hash = hash(fp1) + + m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + fp2 = nnx.graph.fingerprint(m2) + m2_hash = hash(fp2) + + self.assertNotEqual(fp1, fp2) + self.assertNotEqual(m1_hash, m2_hash) + + def test_fingerprint_module_id_insensitive(self): + m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + + m1.kernel = m2.kernel + m1.bias = m2.bias + + fp1 = nnx.graph.fingerprint(m1) + m1_hash = hash(fp1) + fp2 = nnx.graph.fingerprint(m2) + m2_hash = hash(fp2) + + self.assertNotEqual(fp1, fp2) + self.assertNotEqual(m1_hash, m2_hash) + class SimpleModule(nnx.Module): pass diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 736da9acf0..b775ccdde4 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -715,6 +715,9 @@ class Foo(nnx.Module): y: nnx.Param[jax.Array] z: int + def __hash__(self): + return id(self) + @nnx.custom_vjp @nnx.remat def f(m: Foo): @@ -3081,6 +3084,9 @@ def test_basic(self): class Foo(nnx.Module): a: nnx.Param + def __hash__(self): + return id(self) + @nnx.jit def f(m): y = jnp.sin(m.a.value) # error