diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py
index 663b9a8ef6..b86823c527 100644
--- a/flax/nnx/transforms/autodiff.py
+++ b/flax/nnx/transforms/autodiff.py
@@ -19,6 +19,7 @@
 from flax import struct
+from flax.core.frozen_dict import FrozenDict
 from flax.nnx import (
   extract,
   filterlib,
@@ -362,6 +363,22 @@ def value_and_grad(
     return_value=True,
   )
 
+# -----------------------------------------------
+# custom_vjp
+# -----------------------------------------------
+# custom_vjp is one of the most complicated transforms as it requires
+# handling 4 different functions:
+# 1. CustomVjp: the main object that runs the outer logic, converts input graph nodes
+# to pytrees and output pytrees to graph nodes.
+# 2. CustomVjpFnWrapper: function that wraps the user's function; it converts
+# its input pytrees to graph nodes and output graph nodes to pytrees.
+# 3. FwdFn: wraps the user's fwd function; it converts its input pytrees to graph nodes
+# and output graph nodes to pytrees. Since it might run by itself in a separate context,
+# it needs to be aware of whether the update_context is active or not in order to update
+# the outer references.
+# 4. BwdFn: wraps the user's bwd function; it converts its input pytrees to graph nodes
+# and output graph nodes to pytrees. It doesn't need to be aware of the outer context
+# since it will never update the outer references as it runs during the backward pass.
 
 def _custom_vjp_merge_fn(
   ctx: graph.MergeContext,
@@ -381,16 +398,15 @@ def _custom_vjp_split_fn(
   prefix: bool | DiffState,
   value,
   *,
-  nondiff_states: deque[extract.GraphDefState],
+  nondiff_states: list[extract.GraphDefState],
 ):
+  broadcast: graph.GraphState
   if prefix is False:
-    # pure non-differentiable arg, we pass all the state through
-    # but we return TreeNode.from_split with a graphdef to we can call from_tree
-    # on the nondiff args during the backward pass
-    graphdef, passed = ctx.split(value)
-    broadcast = State({})  # type: ignore[var-annotated]
-    nondiff_states.append(extract.GraphDefState(graphdef, broadcast))
-    return extract.NodeStates.from_split(graphdef, passed)
+    # pure non-differentiable arg, not supported
+    raise TypeError(
+      'Passing integers to nondiff_argnums for graph node arguments in custom_vjp is not supported. '
+      f'Got {prefix} at path {jax.tree_util.keystr(path)} for value {value}'
+    )
   elif prefix is True:
     # pure differentiable arg, we pass all the state through
     # but we return a TreeNode.from_states which doesn't have a graphdef
@@ -409,23 +425,28 @@ def _custom_vjp_split_fn(
     return extract.NodeStates.from_states(passed)
 
 
-class CustomVjpMetadata(struct.PyTreeNode):
+  nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False)
   tangent_tree_node_args: tuple[tp.Any, ...] = struct.field(pytree_node=False)
 
 
+def _extract_index_mappings(x, *, index_mappings: deque[FrozenDict]):
+  if isinstance(x, graph.NodeDef):
+    assert x.index_mapping is not None
+    index_mappings.append(x.index_mapping)
+    return dataclasses.replace(x, index_mapping=None)
+  return x
+
 
 @dataclasses.dataclass(eq=False)
 class CustomVjpFnWrapper:
   f: tp.Callable[..., tp.Any]
+  jax_nondiff_argnums: tuple[int, ...]
   ctxtag: str
+  nondiff_states: list[extract.GraphDefState]
 
   def __post_init__(self):
     functools.update_wrapper(self, self.f)
 
   def __call__(self, *pure_args):
-    broadcast: tuple[CustomVjpMetadata, deque[extract.GraphDefState]] = (
-      extract.get_broadcast_state(self.ctxtag)
-    )
-    metadata, nondiff_states = broadcast
+    nondiff_states = deque(self.nondiff_states)
     args = extract.from_tree(
       pure_args,
       merge_fn=functools.partial(
@@ -436,10 +457,22 @@ def __call__(self, *pure_args):
 
     out = self.f(*args)
 
-    args_out = extract.clear_non_graph_nodes(args)
+    # remove the nondiff args from args_out
+    args_out = tuple(
+      x for i, x in enumerate(args) if i not in self.jax_nondiff_argnums
+    )
+    args_out = extract.clear_non_graph_nodes(args_out)
     pure_args_out, pure_out = extract.to_tree(
       (args_out, out), ctxtag=self.ctxtag
     )
+    # remove index_mapping from NodeDef's but store them in the global context
+    index_mappings: deque[FrozenDict] = extract.get_broadcast_state(self.ctxtag)
+
+    pure_args_out, pure_out = jax.tree.map(
+      functools.partial(_extract_index_mappings, index_mappings=index_mappings),
+      (pure_args_out, pure_out),
+      is_leaf=lambda x: isinstance(x, graph.NodeDef),
+    )
 
     return pure_args_out, pure_out
 
@@ -447,67 +480,90 @@ def __call__(self, *pure_args):
 
 @dataclasses.dataclass(eq=False)
 class FwdFn:
   fwd: tp.Callable[..., tp.Any]
+  nondiff_argnums: tuple[int, ...]
   ctxtag: str
+  nondiff_states: list[extract.GraphDefState]
 
   def __post_init__(self):
     functools.update_wrapper(self, self.fwd)
 
   def __call__(self, *pure_args):
-    broadcast: tuple[CustomVjpMetadata, deque[extract.GraphDefState]] = (
-      extract.get_broadcast_state(self.ctxtag)
+    # here we need to be aware of whether the update_context is active or not:
+    # when it's not active, index_mappings will be None;
+    # when it's active, we will remove the index_mappings from the NodeDef's
+    # and store them in the index_mappings deque created by CustomVjp
+    update_context_active = (
+      self.ctxtag in graph.GRAPH_CONTEXT.update_context_stacks
     )
-    metadata, nondiff_states = broadcast
+    nondiff_states = deque(self.nondiff_states)
     args = extract.from_tree(
       pure_args,
       merge_fn=functools.partial(
         _custom_vjp_merge_fn, nondiff_states=nondiff_states
       ),
-      ctxtag=self.ctxtag,
+      ctxtag=self.ctxtag if update_context_active else None,
    )
 
     out, residual = self.fwd(*args)
 
-    args_out = extract.clear_non_graph_nodes(args)
+    # remove the nondiff args from args_out
+    args_out = tuple(
+      x for i, x in enumerate(args) if i not in self.nondiff_argnums
+    )
+    args_out = extract.clear_non_graph_nodes(args_out)
     pure_args_out, pure_out = extract.to_tree(
-      (args_out, out), ctxtag=self.ctxtag
+      (args_out, out),
+      ctxtag=self.ctxtag if update_context_active else None,
     )
     pure_residual = extract.to_tree(residual)
 
-    return (pure_args_out, pure_out), (metadata, pure_residual)
+    if update_context_active:
+      # remove index_mapping from NodeDef's but store them in the global context
+      index_mappings: deque[FrozenDict] = extract.get_broadcast_state(
+        self.ctxtag
+      )
+      pure_args_out, pure_out = jax.tree.map(
+        functools.partial(
+          _extract_index_mappings, index_mappings=index_mappings
+        ),
+        (pure_args_out, pure_out),
+        is_leaf=lambda x: isinstance(x, graph.NodeDef),
+      )
+
+    return (pure_args_out, pure_out), pure_residual
 
 
 @dataclasses.dataclass(eq=False)
 class BwdFn:
   bwd: tp.Callable[..., tp.Any]
+  tree_node_args: tuple[tp.Any, ...]
 
   def __post_init__(self):
     functools.update_wrapper(self, self.bwd)
 
   def __call__(self, *args):
-    res: tuple[CustomVjpMetadata, tp.Any]
-    pure_g: tuple[tp.Any, tp.Any]
-    *nondiff, res, pure_g = args
-    metadata, pure_residual = res
-    nondiff = extract.from_tree(nondiff)
+    *nondiff, pure_residual, (pure_args_out_g, pure_out_g) = args
     residual = extract.from_tree(pure_residual)
-    pure_g = jax.tree.map(
+    (pure_args_out_g, pure_out_g) = jax.tree.map(
       lambda x: x.state if isinstance(x, extract.NodeStates) else x,
-      pure_g,
+      (pure_args_out_g, pure_out_g),
       is_leaf=lambda x: isinstance(x, extract.NodeStates),
     )
 
-    tangent = self.bwd(*nondiff, residual, pure_g)
+    tangent = self.bwd(*nondiff, residual, (pure_args_out_g, pure_out_g))
 
-    def state_to_tree_node(is_tree_node: bool, x):
-      if is_tree_node:
-        if not isinstance(x, State):
+    def state_to_node_states(is_differentiable: bool, x):
+      if is_differentiable:
+        if isinstance(x, jax.Array):
+          return x
+        elif not isinstance(x, State):
          raise ValueError(f'Expected State, got {type(x)}')
        return extract.NodeStates.from_states(x)
      return x
 
     pure_tangent = jax.tree.map(
-      state_to_tree_node,
-      metadata.tangent_tree_node_args,
+      state_to_node_states,
+      self.tree_node_args,
       tangent,
       is_leaf=lambda x: isinstance(x, State),
     )
@@ -521,14 +577,15 @@ def __init__(
     nondiff_argnums: tuple[int | DiffState, ...],
   ):
     functools.update_wrapper(self, fun)
-    jax_nondiff_argnums = tuple(
-      x.argnum if isinstance(x, DiffState) else x for x in nondiff_argnums
+    # first argument is metadata
+    self.jax_nondiff_argnums = tuple(
+      x for x in nondiff_argnums if isinstance(x, int)
     )
     self.ctxtag = f'custom_vjp_{fun.__name__}_{id(fun)}'
-    self.custom_vjp_fn = jax.custom_vjp(
-      CustomVjpFnWrapper(fun, self.ctxtag),
-      nondiff_argnums=jax_nondiff_argnums,
-    )
+    self.fun = fun
+    self.fwd: tp.Callable | None = None
+    self.bwd: tp.Callable | None = None
+    self.symbolic_zeros: bool | None = None
     self.nondiff_argnums = nondiff_argnums
     self.diff_filter: dict[int, tp.Literal[False] | DiffState] = {}
     for argnum in self.nondiff_argnums:
@@ -541,16 +598,18 @@ def __init__(
         else False
       )
 
-  def __getattr__(self, name: str) -> tp.Any:
-    return getattr(self.custom_vjp_fn, name)
+  # def __getattr__(self, name: str) -> tp.Any:
+  #   if not hasattr(self.custom_vjp_fn, name):
+  #     raise AttributeError(f'{type(self).__name__} has no attribute {name}')
+  #   return getattr(self.custom_vjp_fn, name)
 
   def __call__(
     self, *args: tp.Any, **kwargs: tp.Any
   ) -> A:  # pytype: disable=invalid-annotation
     with graph.update_context(self.ctxtag):
-      args = resolve_kwargs(self.custom_vjp_fn, args, kwargs)
+      args = resolve_kwargs(self.fun, args, kwargs)
       del kwargs
-      nondiff_states: deque[extract.GraphDefState] = deque()
+      nondiff_states: list[extract.GraphDefState] = []
       arg_filters = tuple(
         self.diff_filter.get(i, True) for i in range(len(args))
       )
@@ -562,24 +621,57 @@ def __call__(
         ),
         ctxtag=self.ctxtag,
       )
-      tangent_args = tp.cast(
-        tuple[tp.Literal[True] | DiffState, ...],
-        tuple(x for x in arg_filters if x is not False),
-      )
       tree_node_args = jax.tree.map(
         lambda x: isinstance(x, extract.NodeStates),
         pure_args,
         is_leaf=lambda x: isinstance(x, extract.NodeStates),
      )
-      tangent_tree_node_args = tuple(
-        arg
-        for arg, is_tree_node in zip(args, tree_node_args)
-        if is_tree_node is not False
+      tree_node_args = tuple(
+        x
+        for i, x in enumerate(tree_node_args)
+        if i not in self.jax_nondiff_argnums
+      )
+      index_mappings: deque[FrozenDict] = deque()
+      with extract.broadcast_state(self.ctxtag, index_mappings):
+        if self.fwd is None or self.bwd is None or self.symbolic_zeros is None:
+          raise ValueError()
+
+        custom_vjp_fn = jax.custom_vjp(
+          fun=CustomVjpFnWrapper(
+            f=self.fun,
+            jax_nondiff_argnums=self.jax_nondiff_argnums,
+            ctxtag=self.ctxtag,
+            nondiff_states=nondiff_states,
+          ),
+          nondiff_argnums=self.jax_nondiff_argnums,
+        )
+        custom_vjp_fn.defvjp(
+          fwd=FwdFn(
+            fwd=self.fwd,
+            nondiff_argnums=self.jax_nondiff_argnums,
+            ctxtag=self.ctxtag,
+            nondiff_states=nondiff_states,
+          ),
+          bwd=BwdFn(
+            bwd=self.bwd,
+            tree_node_args=tree_node_args,
+          ),
+          symbolic_zeros=self.symbolic_zeros,
+        )
+        pure_args_out, pure_out = custom_vjp_fn(*pure_args)
+
+      # insert index_mappings
+      def _insert_index_mappings(x):
+        if isinstance(x, graph.NodeDef):
+          index_mapping: FrozenDict = index_mappings.popleft()
+          return dataclasses.replace(x, index_mapping=index_mapping)
+        return x
+
+      pure_args_out, pure_out = jax.tree_util.tree_map(
+        _insert_index_mappings,
+        (pure_args_out, pure_out),
+        is_leaf=lambda x: isinstance(x, graph.NodeDef),
       )
-      metadata = CustomVjpMetadata(tangent_args)
-
-      with extract.broadcast_state(self.ctxtag, (metadata, nondiff_states)):
-        pure_args_out, pure_out = self.custom_vjp_fn(*pure_args)
 
       args_out, out = extract.from_tree(
         (pure_args_out, pure_out), ctxtag=self.ctxtag
@@ -593,86 +685,9 @@ def defvjp(
     bwd: tp.Callable[..., tuple[tp.Any, ...]],
     symbolic_zeros: bool = False,
   ) -> None:
-    """Define a custom VJP rule for the function represented by this instance.
-
-    Args:
-      fwd: a Python callable representing the forward pass of the custom VJP
-        rule. When there are no ``nondiff_argnums``, the ``fwd`` function has
-        the same input signature as the underlying primal function. It should
-        return as output a pair, where the first element represents the primal
-        output and the second element represents any "residual" values to store
-        from the forward pass for use on the backward pass by the function
-        ``bwd``. Input arguments and elements of the output pair may be arrays
-        or nested tuples/lists/dicts thereof.
-      bwd: a Python callable representing the backward pass of the custom VJP
-        rule. When there are no ``nondiff_argnums``, the ``bwd`` function takes
-        two arguments, where the first is the "residual" values produced on the
-        forward pass by ``fwd``, and the second is the output cotangent with the
-        same structure as the primal function output. The output of ``bwd`` must
-        be a tuple of length equal to the number of arguments of the primal
-        function, and the tuple elements may be arrays or nested
-        tuples/lists/dicts thereof so as to match the structure of the primal
-        input arguments.
-      symbolic_zeros: boolean, determining whether to indicate symbolic zeros
-        to the ``fwd`` and ``bwd`` rules. Enabling this option allows custom
-        derivative rules to detect when certain inputs, and when certain
-        output cotangents, are not involved in differentiation. If ``True``:
-
-        * ``fwd`` must accept, in place of each leaf value ``x`` in
-          the pytree comprising an argument to the original function,
-          an object (of type
-          ``jax.custom_derivatives.CustomVJPPrimal``) with two
-          attributes instead: ``value`` and ``perturbed``. The
-          ``value`` field is the original primal argument, and
-          ``perturbed`` is a boolean. The ``perturbed`` bit indicates
-          whether the argument is involved in differentiation (i.e.,
-          if it is ``False``, then the corresponding Jacobian "column"
-          is zero).
-
-        * ``bwd`` will be passed objects representing static symbolic zeros in
-          its cotangent argument in correspondence with unperturbed values;
-          otherwise, only standard JAX types (e.g. array-likes) are passed.
-
-        Setting this option to ``True`` allows these rules to detect whether
-        certain inputs and outputs are not involved in differentiation, but at
-        the cost of special handling. For instance:
-
-        * The signature of ``fwd`` changes, and the objects it is passed cannot
-          be output from the rule directly.
-
-        * The ``bwd`` rule is passed objects that are not entirely array-like,
-          and that cannot be passed to most ``jax.numpy`` functions.
-
-        * Any custom pytree nodes involved in the primal function's arguments
-          must accept, in their unflattening functions, the two-field record
-          objects that are given as input leaves to the ``fwd`` rule.
-
-        Default ``False``.
-
-    Returns:
-      None.
-
-    Examples:
-
-      @jax.custom_vjp
-      def f(x, y):
-        return jnp.sin(x) * y
-
-      def f_fwd(x, y):
-        return f(x, y), (jnp.cos(x), jnp.sin(x), y)
-
-      def f_bwd(res, g):
-        cos_x, sin_x, y = res
-        return (cos_x * g * y, sin_x * g)
-
-      f.defvjp(f_fwd, f_bwd)
-    """
-
-    self.custom_vjp_fn.defvjp(
-      FwdFn(fwd, self.ctxtag),
-      BwdFn(bwd),
-      symbolic_zeros=symbolic_zeros,
-    )
+    self.fwd = fwd
+    self.bwd = bwd
+    self.symbolic_zeros = symbolic_zeros
 
 
 @tp.overload
@@ -694,6 +709,14 @@ def custom_vjp(
   """Reference aware version of `jax.custom_vjp
   <https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_vjp.html>`__.
 
+  ``nnx.custom_vjp`` accepts Modules and other Flax NNX objects as arguments. The main difference
+  with the JAX version is that, because Modules follow reference semantics, they propagate State
+  updates for the inputs as auxiliary outputs. This means that the incoming gradients in the ``bwd``
+  function have the form ``(input_updates_g, out_g)``, where ``input_updates_g`` holds the gradients
+  with respect to the propagated State updates of the inputs. Each Module term in the inputs has an
+  associated ``State`` term in ``input_updates_g``, while non-Module terms appear as None. The
+  tangent is expected to have the same structure as the input, with ``State`` terms in place of
+  the corresponding Module terms.
+
   Example::
 
     >>> import jax
@@ -713,10 +736,14 @@ def custom_vjp(
     ...   return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)
     ...
     >>> def f_bwd(res, g):
-    ...   inputs_g, out_g = g
+    ...   input_updates_g, out_g = g
     ...   cos_x, sin_x, m = res
-    ...   tangent_m = nnx.State(dict(x=cos_x * out_g * m.y, y=sin_x * out_g))
-    ...   return (tangent_m,)
+    ...   (m_updates_g,) = input_updates_g
+    ...   m_g = jax.tree.map(lambda x: x, m_updates_g)  # create copy
+    ...
+    ...   m_g['x'].value = cos_x * out_g * m.y
+    ...   m_g['y'].value = sin_x * out_g
+    ...   return (m_g,)
     ...
     >>> f.defvjp(f_fwd, f_bwd)
     ...
@@ -735,6 +762,63 @@ def custom_vjp(
     )
   })
 
+  Note that the State objects that represent Module terms in ``input_updates_g`` have the
+  same shape as the State objects expected in the output tangent. This means that you can
+  usually just copy them from ``input_updates_g`` and update them with their corresponding
+  gradient values.
+
+  You can select which substates are differentiable (have a tangent) for Modules and other
+  graph nodes by passing a ``DiffState`` to ``nondiff_argnums``. For example, if you want to
+  differentiate only the ``x`` attribute of the ``Foo`` class, you can do the following::
+
+    >>> x_attribute = nnx.PathContains('x')
+    >>> diff_state = nnx.DiffState(0, x_attribute)
+    ...
+    >>> @nnx.custom_vjp(nondiff_argnums=(diff_state,))
+    ... def f(m: Foo):
+    ...   return jnp.sin(m.x) * m.y  # type: ignore
+
+    >>> def f_fwd(m: Foo):
+    ...   y = f(m)
+    ...   res = (jnp.cos(m.x), m)  # type: ignore
+    ...   return y, res
+    ...
+    >>> def f_bwd(res, g):
+    ...   input_updates_g, out_g = g
+    ...   cos_x, m = res
+    ...   (m_updates_g,) = input_updates_g
+    ...   m_g = jax.tree.map(lambda x: x, m_updates_g)  # create copy
+    ...
+    ...   m_g.x.value = cos_x * out_g * m.y
+    ...   del m_g['y']  # y is not differentiable
+    ...   return (m_g,)
+
+    >>> f.defvjp(f_fwd, f_bwd)
+    ...
+    >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
+    >>> grad = nnx.grad(f, argnums=nnx.DiffState(0, x_attribute))(m)
+    ...
+    >>> jax.tree.map(jnp.shape, grad)
+    State({
+      'x': VariableState(
+        type=Param,
+        value=()
+      )
+    })
+
+  Note that ``grad`` cannot calculate gradients for states that don't have a tangent
+  defined by ``custom_vjp``; in the example above we reuse the same ``x_attribute``
+  filter to keep ``custom_vjp`` and ``grad`` in sync.
+
+  Args:
+    fun: Callable base function.
+    nondiff_argnums: Tuple of integers or DiffState objects specifying the
+      argument indices that are not differentiated. By default all arguments are
+      differentiated. Integers cannot be used to mark graph nodes such as Modules
+      as non-differentiable; in this case use a DiffState object. DiffState objects
+      define the set of differentiable substates; contrary to what the name of this
+      argument suggests, this is done for compatibility with ``grad``.
+
   """
   if isinstance(fun, Missing):
     return functools.partial(custom_vjp, nondiff_argnums=nondiff_argnums)
diff --git a/pyproject.toml b/pyproject.toml
index baab1da052..6c67a21cc2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -182,6 +182,8 @@ filterwarnings = [
     "ignore:.*invalid value encountered in cast.*:RuntimeWarning",
     # RuntimeWarning: divide by zero encountered in equal/not_equal
     "ignore:.*divide by zero encountered in.*:RuntimeWarning",
+    # DeprecationWarning: numpy.core is deprecated
+    "ignore:.*numpy.core is deprecated.*:DeprecationWarning",
 ]
 
 [tool.coverage.report]
diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py
index 5f478c4328..875456b110 100644
--- a/tests/nnx/transforms_test.py
+++ b/tests/nnx/transforms_test.py
@@ -598,7 +598,7 @@ def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]):
     self.assertIn('bias', grads_m2[0])
 
 
-class TestCustomVJP(absltest.TestCase):
+class TestCustomVJP(parameterized.TestCase):
 
   def test_basic_call(self):
     m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0))
@@ -644,16 +644,16 @@ def f_fwd(m: Foo):
       return y, res
 
     def f_bwd(res, g):
-      inputs_g, out_g = g
+      (m_g,), out_g = g
       cos_x, sin_x, m = res
 
-      self.assertIsInstance(inputs_g, tuple)
-      self.assertLen(inputs_g, 1)
-      self.assertIsInstance(inputs_g[0], nnx.State)
+      self.assertIsInstance(m_g, nnx.State)
       self.assertEqual(out_g.shape, ())
       self.assertIsInstance(m, Foo)
 
-      m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g})
+      # m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g})
+      m_g.x.value = cos_x * out_g * m.y
+      m_g.y.value = sin_x * out_g
       return (m_g,)
 
     f.defvjp(f_fwd, f_bwd)
@@ -666,6 +666,92 @@ def f_bwd(res, g):
     np.testing.assert_allclose(grad['y'].value, jnp.sin(1.0))  # type: ignore
     self.assertEqual(m.z, 1)
 
+  def test_diff_state(self):
+    @dataclasses.dataclass
+    class Foo(nnx.Module):
+      x: nnx.Param[jax.Array]
+      y: nnx.Param[jax.Array]
+      z: int
+
+    x_in_path = nnx.PathContains('x')
+    diff_state = nnx.DiffState(0, x_in_path)
+
+    @nnx.custom_vjp(nondiff_argnums=(diff_state,))
+    def f(m: Foo):
+      m.z += 1
+      return jnp.sin(m.x) * m.y  # type: ignore
+
+    def f_fwd(m: Foo):
+      y = f(m)
+      res = (jnp.cos(m.x), m)  # type: ignore
+      return y, res
+
+    def f_bwd(res, g):
+      (m_g,), out_g = g
+      cos_x, m = res
+
+      self.assertIsInstance(m_g, nnx.State)
+      self.assertEqual(out_g.shape, ())
+      self.assertIsInstance(m, Foo)
+
+      m_g.x.value = cos_x * out_g * m.y
+      del m_g['y']
+      return (m_g,)
+
+    f.defvjp(f_fwd, f_bwd)
+
+    m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0)
+
+    grad: nnx.State = nnx.grad(f, argnums=nnx.DiffState(0, x_in_path))(m)
+
+    np.testing.assert_allclose(grad['x'].value, jnp.cos(1.0) * 2.0)  # type: ignore
+    self.assertEqual(m.z, 1)
+
+  def test_jax_example_with_remat(self):
+    @dataclasses.dataclass
+    class Foo(nnx.Module):
+      x: nnx.Param[jax.Array]
+      y: nnx.Param[jax.Array]
+      z: int
+
+    @nnx.custom_vjp
+    @nnx.remat
+    def f(m: Foo):
+      m.z += 1
+      return jnp.sin(m.x.value) * m.y  # type: ignore
+
+    def f_fwd(m: Foo):
+      y = f(m)
+      res = (jnp.cos(m.x.value), jnp.sin(m.x.value), m)  # type: ignore
+      return y, res
+
+    def f_bwd(res, g):
+      (m_g,), out_g = g
+      cos_x, sin_x, m = res
+
+      self.assertIsInstance(m_g, nnx.State)
+      self.assertEqual(out_g.shape, ())
+      self.assertIsInstance(m, Foo)
+
+      # m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g})
+      m_g.x.value = cos_x * out_g * m.y
+      m_g.y.value = sin_x * out_g
+      return (m_g,)
+
+    f.defvjp(f_fwd, f_bwd)
+
+    m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0)
+
+    @nnx.jit
+    def loss_fn(m):
+      return f(m)
+
+    grad: nnx.State = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m)
+
+    np.testing.assert_allclose(grad['x'].value, jnp.cos(1.0) * 2.0)  # type: ignore
+    np.testing.assert_allclose(grad['y'].value, jnp.sin(1.0))  # type: ignore
+    self.assertEqual(m.z, 1)
+
   def test_two_args(self):
     @dataclasses.dataclass
     class Foo(nnx.Module):
@@ -726,45 +812,49 @@ class Foo(nnx.Module):
       y: nnx.Param[jax.Array]
       z: int
 
-    @nnx.custom_vjp(nondiff_argnums=(1, 2))
-    def f(m1: Foo, m2: Foo, m3):
-      m1.z += 1
-      y = jnp.sin(m1.x) * m1.y  # type: ignore
-      return y, m2
+    @nnx.custom_vjp(nondiff_argnums=(0, 2))
+    def f(a, m: Foo, b):
+      self.assertEqual(a, 1)
+      self.assertEqual(b, 2)
+      m.z += 1
+      return jnp.sin(m.x) * m.y  # type: ignore
 
-    def f_fwd(m1: Foo, m2: Foo, m3):
-      y, m2 = f(m1, m2, m3)
-      res = (jnp.cos(m1.x), jnp.sin(m1.x), m1)  # type: ignore
-      return (y, m2), res
+    def f_fwd(a, m: Foo, b):
+      self.assertEqual(a, 1)
+      self.assertEqual(b, 2)
+      y = f(a, m, b)
+      res = (jnp.cos(m.x), jnp.sin(m.x), m)  # type: ignore
+      return y, res
 
-    def f_bwd(m2, m3, res, g):
-      (m1_g, m2_g, m3_g), (y_g, _) = g
+    def f_bwd(a, b, res, g):
+      (m_g,), out_g = g
       cos_x, sin_x, m = res
 
-      self.assertIsInstance(m1_g, nnx.State)
-      self.assertIsInstance(m2_g, nnx.State)
-      self.assertEqual(y_g.shape, ())
+      self.assertEqual(a, 1)
+      self.assertEqual(b, 2)
+      self.assertIsInstance(m_g, nnx.State)
+      self.assertEqual(out_g.shape, ())
       self.assertIsInstance(m, Foo)
 
-      m1_g = nnx.State(dict(x=cos_x * y_g * m.y, y=sin_x * y_g))
-
-      return (m1_g,)
+      # m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g})
+      m_g.x.value = cos_x * out_g * m.y
+      m_g.y.value = sin_x * out_g
+      return (m_g,)
 
     f.defvjp(f_fwd, f_bwd)
 
-    m1 = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0)
-    m2 = Foo(nnx.Param(jnp.array(3.0)), nnx.Param(jnp.array(4.0)), 0)
+    m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0)
 
-    def loss_fn(m1, m2, m3):
-      y, m2 = f(m1, m2, m3)
-      return y + m2.x * m2.y
+    def loss_fn(m):
+      a = 1
+      b = 2
+      return f(a, m, b)
 
-    m1_grad: nnx.State
-    m1_grad = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m1, m2, m2)
+    grad: nnx.State = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m)
 
-    np.testing.assert_allclose(m1_grad['x'].value, jnp.cos(1.0) * 2.0)  # type: ignore
-    np.testing.assert_allclose(m1_grad['y'].value, jnp.sin(1.0))  # type: ignore
-    self.assertEqual(m1.z, 1)
+    np.testing.assert_allclose(grad['x'].value, jnp.cos(1.0) * 2.0)  # type: ignore
+    np.testing.assert_allclose(grad['y'].value, jnp.sin(1.0))  # type: ignore
+    self.assertEqual(m.z, 1)
 
   def test_docs_example(self):
     import jax.numpy as jnp
@@ -794,6 +884,60 @@ def f_bwd(res, g):
     m = Foo(x=jnp.array(1.0), y=jnp.array(2.0))
     grads = nnx.grad(f)(m)
 
+  @parameterized.parameters(
+    {'use_custom_vjp': False},
+    {'use_custom_vjp': True},
+  )
+  def test_issue(self, use_custom_vjp: bool):
+    class MyLinear(nnx.Module):
+      def __init__(
+        self, in_features: int, out_features: int, *, rngs: nnx.Rngs
+      ):
+        kernel_init = nnx.initializers.normal(in_features**-0.5)
+        self.kernel = nnx.Param(
+          kernel_init(rngs.params(), (in_features, out_features), jnp.float32)
+        )
+        self.bias = nnx.Param(jnp.zeros((out_features,), jnp.float32))
+        self.n = nnx.BatchStat(jnp.array(0, jnp.uint32))
+
+    def linear(m: MyLinear, x: jax.Array) -> jax.Array:
+      m.n.value += 1
+      y = x @ m.kernel + m.bias
+      return y
+
+    def linear_fwd(m: MyLinear, x: jax.Array):
+      return linear(m, x), (m, x)
+
+    def linear_bwd(res, g):
+      m, x = res
+      (m_g, _x_grad), outputs_g = g
+      kernel_grad = outputs_g[None, :] * x[:, None]
+      bias_grad = outputs_g
+      x_grad = m.kernel @ outputs_g
+      assert x_grad.shape == x.shape, 'Shape mismatch for x'
+      assert (
+        m.kernel.value.shape == kernel_grad.shape
+      ), 'Shape mismatch for kernel'
+      assert m.bias.value.shape == bias_grad.shape, 'Shape mismatch for bias'
+      return (m_g, x_grad)
+
+    if use_custom_vjp:
+      linear = nnx.custom_vjp(linear)
+      linear.defvjp(linear_fwd, linear_bwd)
+
+    @nnx.jit
+    def loss_fn(x, mod):
+      y = linear(mod, x)
+      return y.mean()
+
+    mod = MyLinear(10, 5, rngs=nnx.Rngs(0))
+    self.assertEqual(mod.n.value, 0)
+    x = jax.random.normal(jax.random.key(0), (10,))
+    loss, grad = nnx.value_and_grad(loss_fn)(x, mod)
+    self.assertEqual(loss.shape, ())
+    self.assertEqual(grad.shape, (10,))
+    self.assertEqual(mod.n.value, 1)
+
 
 class TestScan(absltest.TestCase):
 
   def test_basic(self):
diff --git a/uv.lock b/uv.lock
index 6c9e67edfa..dd56faa629 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2266,7 +2266,7 @@ wheels = [
 
 [[package]]
 name = "orbax-checkpoint"
-version = "0.7.0"
+version = "0.8.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "absl-py" },
@@ -2283,9 +2283,9 @@ dependencies = [
     { name = "tensorstore" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/b7/5a/e07d3b2a9dacc6fe882a255080d4af3ac180bc190fd8ce22ab64cf0bfe26/orbax_checkpoint-0.7.0.tar.gz", hash = "sha256:f5a59babbf86fdafacddcfd2fb1c6d45b4fa0685b38a87a4598a5702bb70a657", size = 201557 }
+sdist = { url = "https://files.pythonhosted.org/packages/66/48/54339d92c2b37f2ddea72501653f1b85a85ca2f19f4102b4b966260c2700/orbax_checkpoint-0.8.0.tar.gz", hash = "sha256:0754ecc2e5fc858e62bbcf610606502d8e1c9ada7295d9bb49cc172f884b0b1e", size = 206396 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/3a/63/45b63b51b320d104f21cb7f2a5d0ae2b37558e24296c02d33521a291ad87/orbax_checkpoint-0.7.0-py3-none-any.whl", hash = "sha256:0469030dd70729f7416981712a9ea8a82bd02c65ca82c933675c9e3ed4763f9b", size = 279660 },
"https://files.pythonhosted.org/packages/28/35/1a3ec885f192884867c1325920171d67ca2fa9122837ea96af284a2a2f05/orbax_checkpoint-0.8.0-py3-none-any.whl", hash = "sha256:df8e353feb7f4eeba9f5b16f704699df54c3c44c5c6ec4d4d117c40bf27830cc", size = 286357 }, ] [[package]]