Skip to content

Commit

Permalink
Update NNX BatchStat, Cache, Intermediate class docs in variablelib.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Dec 17, 2024
1 parent fc38f21 commit 206782a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions flax/nnx/variablelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ class Param(Variable[A]):

class BatchStat(Variable[A]):
"""The mean and variance batch statistics stored in
the :class:`BatchNorm` layer. Note, these are not the
the :class:`flax.nnx.BatchNorm` layer. Note that these are not the
learnable scale and bias parameters, but rather the
running average statistics that are typically used
during post-training inference::
Expand Down Expand Up @@ -662,7 +662,7 @@ class BatchStat(Variable[A]):


class Cache(Variable[A]):
"""Autoregressive cache in :class:`MultiHeadAttention`::
"""Autoregressive cache in :class:`flax.nnx.MultiHeadAttention`::
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
Expand Down Expand Up @@ -698,8 +698,8 @@ class Cache(Variable[A]):


class Intermediate(Variable[A]):
""":class:`Variable` type that is typically used for
:func:`Module.sow`::
"""A :class:`flax.nnx.Variable` type that is typically used for
:func:`flax.nnx.Module.sow`::
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
Expand Down

0 comments on commit 206782a

Please sign in to comment.