From d6a87ceaa1e2ae97524858e1b8fdbf29cb659824 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Sun, 15 Dec 2024 16:51:02 -0800 Subject: [PATCH] Update NNX Module train docs in module.py --- flax/nnx/module.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/flax/nnx/module.py b/flax/nnx/module.py index 795bb9a088..826c4b1b81 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -322,12 +322,13 @@ def set_attributes( ) def train(self, **attributes): - """Sets the Module to training mode. + """Sets the :class:`flax.nnx.Module` to training mode. - ``train`` uses ``set_attributes`` to recursively set attributes ``deterministic=False`` - and ``use_running_average=False`` of all nested Modules that have these attributes. - Its primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm`` - Modules. + ``nnx.Module.train`` uses :func:`flax.nnx.Module.set_attributes`` to recursively set + attributes ``deterministic=False`` and ``use_running_average=False`` of all nested + ``nnx.Module``'s that have these attributes. It is primarily used to control the + runtime behavior of the :class:`flax.nnx.Dropout` and :class:`flax.nnx.BatchNorm` + ``nnx.Module``'s. Example:: @@ -348,7 +349,7 @@ def train(self, **attributes): (False, False) Args: - **attributes: additional attributes passed to ``set_attributes``. + **attributes: Additional attributes passed to ``set_attributes``. """ return self.set_attributes( deterministic=False,