Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade NNX Module sow docs in module.py #4443

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions flax/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,16 @@ def sow(
reduce_fn: tp.Callable[[B, A], B] = tuple_reduce,
init_fn: tp.Callable[[], B] = tuple_init, # type: ignore
) -> None:
"""``sow()`` can be used to collect intermediate values without
the overhead of explicitly passing a container through each Module call.
``sow()`` stores a value in a new ``Module`` attribute, denoted by ``name``.
The value will be wrapped by a :class:`Variable` of type ``variable_type``,
which can be useful to filter for in :func:`split`, :func:`state` and
:func:`pop`.
"""``sow()`` can be used to collect intermediate values without the overhead
of explicitly passing a container through each :class:`flax.nnx.Module` call.
``sow()`` stores a value in a new ``nnx.Module`` attribute, denoted by ``name``.
The value will be wrapped by a :class:`flax.nnx.Variable` of type ``variable_type``,
which can be useful to filter for in :func:`flax.nnx.split`, :func:`flax.nnx.state`
and :func:`flax.nnx.pop`.

By default the values are stored in a tuple and each stored value
is appended at the end. This way all intermediates can be tracked when
the same module is called multiple times.
By default, values are stored in a tuple with each stored value
appended at the end. This allows all intermediates to be tracked when
the same ``nnx.Module`` is called multiple times.

Example usage::

Expand Down Expand Up @@ -126,7 +126,7 @@ def sow(
>>> assert len(model.i.value) == 2 # tuple of length 2
>>> assert (model.i.value[0] + 1 == model.i.value[1]).all()

Alternatively, a custom init/reduce function can be passed::
Alternatively, a custom ``init_fn``/``reduce_fn`` function can be passed as follows::

>>> class Model(nnx.Module):
... def __init__(self, rngs):
Expand Down Expand Up @@ -155,16 +155,16 @@ def sow(
>>> assert (model.product.value == intermediate**2).all()

Args:
variable_type: The :class:`Variable` type for the stored value.
Typically :class:`Intermediate` is used to indicate an
variable_type: The :class:`flax.nnx.Variable` type for the stored value.
Typically, :class:`flax.nnx.Intermediate` is used to indicate an
intermediate value.
name: A string denoting the ``Module`` attribute name, where
name: A string denoting the :class:`flax.nnx.Module` attribute name where
the sowed value is stored.
value: The value to be stored.
reduce_fn: The function used to combine the existing value with the new
value. The default is to append the value to a tuple.
value. Defaults to appending the value to a tuple.
init_fn: For the first value stored, ``reduce_fn`` will be passed the result
of ``init_fn`` together with the value to be stored. The default is an
of ``init_fn`` together with the value to be stored. Defaults to an
empty tuple.
"""
if hasattr(self, name):
Expand Down
Loading