Skip to content

Commit

Permalink
Upgrade Module class docs in module.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Dec 16, 2024
1 parent 6bc9858 commit a59852e
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions flax/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@ class ModuleMeta(ObjectMeta):


class Module(Object, metaclass=ModuleMeta):
"""Base class for all neural network modules.
"""A base class for all neural network modules.
Layers and models should subclass this class.
Layers and models should subclass this :class`flax.nnx.Module` class.
``Module``'s can contain submodules, and in this way can be nested in a tree
structure. Submodules can be assigned as regular attributes inside the
``__init__`` method.
An ``nnx.Module`` can contain sub-``Module``'s, allowing them to be nested in a
JAX pytree-like structure. Sub-``Module``'s can be assigned as regular attributes
inside the ``__init__`` method.
You can define arbitrary "forward pass" methods on your ``Module`` subclass.
You can define arbitrary "forward pass" methods on your ``nnx.Module`` subclass.
While no methods are special-cased, ``__call__`` is a popular choice since
you can call the ``Module`` directly::
you can call the ``nnx.Module`` directly::
>>> from flax import nnx
>>> import jax.numpy as jnp
Expand Down

0 comments on commit a59852e

Please sign in to comment.