From 206782a364e1565d8e62b9deafe27a471dee522e Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Tue, 17 Dec 2024 00:20:07 +0000 Subject: [PATCH] Update NNX BatchStat, Cache, Intermediate class docs in variablelib.py --- flax/nnx/variablelib.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 4752a9b7b..f6c0fe888 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -628,7 +628,7 @@ class Param(Variable[A]): class BatchStat(Variable[A]): """The mean and variance batch statistics stored in - the :class:`BatchNorm` layer. Note, these are not the + the :class:`flax.nnx.BatchNorm` layer. Note that these are not the learnable scale and bias parameters, but rather the running average statistics that are typically used during post-training inference:: @@ -662,7 +662,7 @@ class BatchStat(Variable[A]): class Cache(Variable[A]): - """Autoregressive cache in :class:`MultiHeadAttention`:: + """Autoregressive cache in :class:`flax.nnx.MultiHeadAttention`:: >>> from flax import nnx >>> import jax, jax.numpy as jnp @@ -698,8 +698,8 @@ class Cache(Variable[A]): class Intermediate(Variable[A]): - """:class:`Variable` type that is typically used for - :func:`Module.sow`:: + """A :class:`flax.nnx.Variable` type that is typically used for + :func:`flax.nnx.Module.sow`:: >>> from flax import nnx >>> import jax, jax.numpy as jnp