Skip to content

Commit

Permalink
Upgrade NNX Module sow docs in module.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Dec 16, 2024
1 parent 6bc9858 commit 028cdca
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions flax/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,16 @@ def sow(
reduce_fn: tp.Callable[[B, A], B] = tuple_reduce,
init_fn: tp.Callable[[], B] = tuple_init, # type: ignore
) -> None:
"""``sow()`` can be used to collect intermediate values without
the overhead of explicitly passing a container through each Module call.
``sow()`` stores a value in a new ``Module`` attribute, denoted by ``name``.
The value will be wrapped by a :class:`Variable` of type ``variable_type``,
which can be useful to filter for in :func:`split`, :func:`state` and
:func:`pop`.
"""``sow()`` can be used to collect intermediate values without the overhead
of explicitly passing a container through each :class:`flax.nnx.Module` call.
``sow()`` stores a value in a new ``nnx.Module`` attribute, denoted by ``name``.
The value will be wrapped by a :class:`flax.nnx.Variable` of type ``variable_type``,
which can be useful to filter for in :func:`flax.nnx.split`, :func:`flax.nnx.state`
and :func:`flax.nnx.pop`.
By default the values are stored in a tuple and each stored value
is appended at the end. This way all intermediates can be tracked when
the same module is called multiple times.
By default, values are stored in a tuple with each stored value
appended at the end. This allows all intermediates to be tracked when
the same ``nnx.Module`` is called multiple times.
Example usage::
Expand Down Expand Up @@ -126,7 +126,7 @@ def sow(
>>> assert len(model.i.value) == 2 # tuple of length 2
>>> assert (model.i.value[0] + 1 == model.i.value[1]).all()
Alternatively, a custom init/reduce function can be passed::
Alternatively, a custom ``init_fn``/``reduce_fn`` function can be passed as follows::
>>> class Model(nnx.Module):
... def __init__(self, rngs):
Expand Down Expand Up @@ -155,16 +155,16 @@ def sow(
>>> assert (model.product.value == intermediate**2).all()
Args:
variable_type: The :class:`Variable` type for the stored value.
Typically :class:`Intermediate` is used to indicate an
variable_type: The :class:`flax.nnx.Variable` type for the stored value.
Typically, :class:`flax.nnx.Intermediate` is used to indicate an
intermediate value.
name: A string denoting the ``Module`` attribute name, where
name: A string denoting the :class:`flax.nnx.Module` attribute name where
the sowed value is stored.
value: The value to be stored.
reduce_fn: The function used to combine the existing value with the new
value. The default is to append the value to a tuple.
value. Defaults to appending the value to a tuple.
init_fn: For the first value stored, ``reduce_fn`` will be passed the result
of ``init_fn`` together with the value to be stored. The default is an
of ``init_fn`` together with the value to be stored. Defaults to an
empty tuple.
"""
if hasattr(self, name):
Expand Down

0 comments on commit 028cdca

Please sign in to comment.