diff --git a/flax/nnx/module.py b/flax/nnx/module.py index 795bb9a088..28ceab870b 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -87,16 +87,16 @@ def sow( reduce_fn: tp.Callable[[B, A], B] = tuple_reduce, init_fn: tp.Callable[[], B] = tuple_init, # type: ignore ) -> None: - """``sow()`` can be used to collect intermediate values without - the overhead of explicitly passing a container through each Module call. - ``sow()`` stores a value in a new ``Module`` attribute, denoted by ``name``. - The value will be wrapped by a :class:`Variable` of type ``variable_type``, - which can be useful to filter for in :func:`split`, :func:`state` and - :func:`pop`. + """``sow()`` can be used to collect intermediate values without the overhead + of explicitly passing a container through each :class:`flax.nnx.Module` call. + ``sow()`` stores a value in a new ``nnx.Module`` attribute, denoted by ``name``. + The value will be wrapped by a :class:`flax.nnx.Variable` of type ``variable_type``, + which can be useful to filter for in :func:`flax.nnx.split`, :func:`flax.nnx.state` + and :func:`flax.nnx.pop`. - By default the values are stored in a tuple and each stored value - is appended at the end. This way all intermediates can be tracked when - the same module is called multiple times. + By default, values are stored in a tuple with each stored value + appended at the end. This allows all intermediates to be tracked when + the same ``nnx.Module`` is called multiple times. Example usage:: @@ -126,7 +126,7 @@ def sow( >>> assert len(model.i.value) == 2 # tuple of length 2 >>> assert (model.i.value[0] + 1 == model.i.value[1]).all() - Alternatively, a custom init/reduce function can be passed:: + Alternatively, a custom ``init_fn``/``reduce_fn`` function can be passed as follows:: >>> class Model(nnx.Module): ... def __init__(self, rngs): @@ -155,16 +155,16 @@ def sow( >>> assert (model.product.value == intermediate**2).all() Args: - variable_type: The :class:`Variable` type for the stored value. - Typically :class:`Intermediate` is used to indicate an + variable_type: The :class:`flax.nnx.Variable` type for the stored value. + Typically, :class:`flax.nnx.Intermediate` is used to indicate an intermediate value. - name: A string denoting the ``Module`` attribute name, where + name: A string denoting the :class:`flax.nnx.Module` attribute name where the sowed value is stored. value: The value to be stored. reduce_fn: The function used to combine the existing value with the new - value. The default is to append the value to a tuple. + value. Defaults to appending the value to a tuple. init_fn: For the first value stored, ``reduce_fn`` will be passed the result - of ``init_fn`` together with the value to be stored. The default is an + of ``init_fn`` together with the value to be stored. Defaults to an empty tuple. """ if hasattr(self, name):