diff --git a/flax/nnx/module.py b/flax/nnx/module.py index 795bb9a088..2f1e23a8ec 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -49,17 +49,17 @@ class ModuleMeta(ObjectMeta): class Module(Object, metaclass=ModuleMeta): - """Base class for all neural network modules. + """A base class for all neural network modules. - Layers and models should subclass this class. + Layers and models should subclass this :class`flax.nnx.Module` class. - ``Module``'s can contain submodules, and in this way can be nested in a tree - structure. Submodules can be assigned as regular attributes inside the - ``__init__`` method. + An ``nnx.Module`` can contain sub-``Module``'s, allowing them to be nested in a + JAX pytree-like structure. Sub-``Module``'s can be assigned as regular attributes + inside the ``__init__`` method. - You can define arbitrary "forward pass" methods on your ``Module`` subclass. + You can define arbitrary "forward pass" methods on your ``nnx.Module`` subclass. While no methods are special-cased, ``__call__`` is a popular choice since - you can call the ``Module`` directly:: + you can call the ``nnx.Module`` directly:: >>> from flax import nnx >>> import jax.numpy as jnp