Skip to content

Commit

Permalink
Update State.state.merge method docs in statelib.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Dec 17, 2024
1 parent fc38f21 commit eae49e3
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions flax/nnx/statelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,10 @@ def filter(
def merge(
state: tp.Mapping[K, V], /, *states: tp.Mapping[K, V]
) -> State[K, V]:
"""The inverse of :meth:`split() <flax.nnx.State.state.split>`.
"""The inverse of :func:`flax.nnx.State.state.split`.
``merge`` takes one or more ``State``'s and creates
a new ``State``.
``nnx.State.state.merge`` takes one or more :class:`flax.nnx.State`'s
and creates a new ``nnx.State``.
Example usage::
Expand All @@ -398,10 +398,10 @@ def merge(
>>> assert (model.linear.bias.value == jnp.array([1, 1, 1])).all()
Args:
state: A ``State`` object.
*states: Additional ``State`` objects.
state: A :class:`flax.nnx.State` object.
*states: Additional ``nnx.State`` objects.
Returns:
The merged ``State``.
The merged ``nnx.State``.
"""
if not states:
if isinstance(state, State):
Expand Down Expand Up @@ -492,4 +492,4 @@ def create_path_filters(state: State):
if isinstance(value, (variablelib.Variable, variablelib.VariableState)):
value = value.value
value_paths.setdefault(value, set()).add(path)
return {filterlib.PathIn(*value_paths[value]): value for value in value_paths}
return {filterlib.PathIn(*value_paths[value]): value for value in value_paths}

0 comments on commit eae49e3

Please sign in to comment.