From 55835d6d5eaf794304eeb470a07762e78185bbed Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Tue, 17 Dec 2024 00:24:32 +0000 Subject: [PATCH] Update NNX State class docs in statelib.py --- flax/nnx/statelib.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index 42a2604042..8bdbf47485 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -105,9 +105,9 @@ def _flat_state_pytree_unflatten( class State(MutableMapping[K, V], reprlib.Representable): - """A pytree-like structure that contains a ``Mapping`` from hashable and - comparable keys to leaves. Leaves can be of any type but :class:`VariableState` - and :class:`Variable` are the most common. + """A JAX pytree-like structure that contains a ``Mapping`` from hashable + and comparable keys to pytree leaves. Pytree leaves can be of any type + but :class:`flax.nnx.VariableState` and :class:`flax.nnx.Variable` are the most common. """ def __init__( @@ -492,4 +492,4 @@ def create_path_filters(state: State): if isinstance(value, (variablelib.Variable, variablelib.VariableState)): value = value.value value_paths.setdefault(value, set()).add(path) - return {filterlib.PathIn(*value_paths[value]): value for value in value_paths} \ No newline at end of file + return {filterlib.PathIn(*value_paths[value]): value for value in value_paths}