Skip to content

Commit

Permalink
Update NNX Module eval docs in module.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Dec 16, 2024
1 parent fc38f21 commit 7d3faf4
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions flax/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,13 @@ def train(self, **attributes):
)

def eval(self, **attributes):
"""Sets the Module to evaluation mode.
"""Sets the :class:`flax.nnx.Module` to evaluation mode.
``eval`` uses ``set_attributes`` to recursively set attributes ``deterministic=True``
and ``use_running_average=True`` of all nested Modules that have these attributes.
Its primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm``
Modules.
``nnx.Module.eval`` uses :func:`flax.nnx.Module.set_attributes` to recursively set
attributes ``deterministic=True`` and ``use_running_average=True`` of all nested
``nnx.Module``'s that have these attributes.
It is primarily used to control the runtime behavior of the :class:`flax.nnx.Dropout`
and :class:`flax.nnx.BatchNorm` ``nnx.Module``'s.
Example::
Expand All @@ -383,7 +384,7 @@ def eval(self, **attributes):
(True, True)
Args:
**attributes: additional attributes passed to ``set_attributes``.
**attributes: Additional attributes passed to ``set_attributes``.
"""
return self.set_attributes(
deterministic=True,
Expand Down

0 comments on commit 7d3faf4

Please sign in to comment.