From fb331ddb4764cdc2973198f54de0048cc26df5b9 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 16 Dec 2024 00:08:37 +0000 Subject: [PATCH] Update NNX state docs in graph.py --- flax/nnx/graph.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index fec21add20..d8c17455f1 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -1443,7 +1443,8 @@ def state( node, *filters: filterlib.Filter, ) -> tp.Union[GraphState, tuple[GraphState, ...]]: - """Similar to :func:`split` but only returns the :class:`State`'s indicated by the filters. + """Similar to :func:`flax.nnx.split` but only returns the :class:`flax.nnx.State`'s indicated + by the NNX ``Filter``'s (``flax.nnx.filterlib``). Example usage:: @@ -1468,9 +1469,9 @@ def state( Args: node: A graph node object. - *filters: One or more :class:`Variable` objects to filter by. + *filters: One or more :class:`flax.nnx.Variable` objects to filter by. Returns: - One or more :class:`State` mappings. + One or more :class:`flax.nnx.State` mappings. """ _, state = flatten(node)