Skip to content

Commit

Permalink
Update NNX State.state.filter method docs in statelib.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Dec 17, 2024
1 parent fc38f21 commit e9b64a1
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions flax/nnx/statelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,11 @@ def filter(
/,
*filters: filterlib.Filter,
) -> tp.Union[State[K, V], tuple[State[K, V], ...]]:
"""Filter a ``State`` into one or more ``State``'s. The
user must pass at least one ``Filter`` (i.e. :class:`Variable`).
This method is similar to :meth:`split() <flax.nnx.State.state.split>`,
except the filters can be non-exhaustive.
"""Filters a :class:`flax.nnx.State` into one or more ``nnx.State``'s.
You must pass at least one NNX ``Filter`` (``flax.nnx.filterlib``)
(i.e. :class:`flax.nnx. Variable`).
This method is similar to :func:`flax.nnx.State.state.split`,
except the ``Filter``'s can be non-exhaustive.
Example usage::
Expand All @@ -351,10 +352,11 @@ def filter(
>>> param, batch_stats = state.filter(nnx.Param, nnx.BatchStat)
Arguments:
first: The first filter
*filters: The optional, additional filters to group the state into mutually exclusive substates.
first: The first NNX ``Filter``.
*filters: The optional, additional NNX ``Filter``'s to group the :class:`flax.nnx.State`
into mutually exclusive sub-``State``'s.
Returns:
One or more ``States`` equal to the number of filters passed.
One or more ``nnx.State``'s equal to the number of NNX ``Filter``'s passed.
"""
*states_, _rest = _split_state(self.flat_state(), first, *filters)

Expand Down Expand Up @@ -492,4 +494,4 @@ def create_path_filters(state: State):
if isinstance(value, (variablelib.Variable, variablelib.VariableState)):
value = value.value
value_paths.setdefault(value, set()).add(path)
return {filterlib.PathIn(*value_paths[value]): value for value in value_paths}
return {filterlib.PathIn(*value_paths[value]): value for value in value_paths}

0 comments on commit e9b64a1

Please sign in to comment.