Skip to content

Commit

Permalink
Merge pull request #4346 from google:update-state-docstrings
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 691845679
  • Loading branch information
Flax Authors committed Oct 31, 2024
2 parents b8bdafb + 13b4077 commit 8292d9c
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 19 deletions.
1 change: 1 addition & 0 deletions docs_nnx/api_reference/flax.nnx/graph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ graph
.. autofunction:: update
.. autofunction:: pop
.. autofunction:: state
.. autofunction:: variables
.. autofunction:: graph
.. autofunction:: graphdef
.. autofunction:: iter_graph
Expand Down
19 changes: 19 additions & 0 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1351,6 +1351,25 @@ def variables(
node,
*filters: filterlib.Filter,
) -> tp.Union[State[Key, Variable], tuple[State[Key, Variable], ...]]:
"""Similar to :func:`state` but returns the current :class:`Variable` objects instead
of new :class:`VariableState` instances.
Example::
>>> from flax import nnx
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> params = nnx.variables(model, nnx.Param)
...
>>> assert params['kernel'] is model.kernel
>>> assert params['bias'] is model.bias
Args:
node: A graph node object.
*filters: One or more :class:`Variable` objects to filter by.
Returns:
One or more :class:`State` mappings containing the :class:`Variable` objects.
"""
num_filters = len(filters)
if num_filters == 0:
filters = (..., ...)
Expand Down
9 changes: 4 additions & 5 deletions flax/nnx/statelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,10 @@ def __treescope_repr__(self, path, subtree_renderer):


class State(MutableMapping[K, V], reprlib.Representable):
"""A pytree-like structure that contains a ``Mapping`` from strings or
integers to leaves. A valid leaf type is either :class:`Variable`,
``jax.Array``, ``numpy.ndarray`` or nested ``State``'s. A ``State``
can be generated by either calling :func:`split` or :func:`state` on
the :class:`Module`."""
"""A pytree-like structure that contains a ``Mapping`` from hashable and
comparable keys to leaves. Leaves can be of any type but :class:`VariableState`
and :class:`Variable` are the most common.
"""

def __init__(
self,
Expand Down
15 changes: 1 addition & 14 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 8292d9c

Please sign in to comment.