Skip to content

Commit

Permalink
[nnx] improve Module docs
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jan 22, 2025
1 parent e4418e2 commit 9b3ea6b
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 241 deletions.
2 changes: 1 addition & 1 deletion flax/nnx/nn/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class MultiHeadAttention(Module):
>>> assert (layer(q) == layer(q, q)).all()
>>> assert (layer(q) == layer(q, q, q)).all()
Attributes:
Args:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
in_features: int or tuple with number of input features.
Expand Down
12 changes: 6 additions & 6 deletions flax/nnx/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class LinearGeneral(Module):
>>> y.shape
(16, 4, 5)
Attributes:
Args:
in_features: int or tuple with number of input features.
out_features: int or tuple with number of output features.
axis: int or tuple with axes to apply the transformation on. For instance,
Expand Down Expand Up @@ -301,7 +301,7 @@ class Linear(Module):
)
})
Attributes:
Args:
in_features: the number of input features.
out_features: the number of output features.
use_bias: whether to add a bias to the output (default: True).
Expand Down Expand Up @@ -393,7 +393,7 @@ class Einsum(Module):
>>> y.shape
(16, 11, 8, 4)
Attributes:
Args:
einsum_str: a string to denote the einsum equation. The equation must
have exactly two operands, the lhs being the input passed in, and
the rhs being the learnable kernel. Exactly one of ``einsum_str``
Expand Down Expand Up @@ -572,7 +572,7 @@ class Conv(Module):
... mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
Attributes:
Args:
in_features: int or tuple with number of input features.
out_features: int or tuple with number of output features.
kernel_size: shape of the convolutional kernel. For 1D convolution,
Expand Down Expand Up @@ -823,7 +823,7 @@ class ConvTranspose(Module):
... mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
Attributes:
Args:
in_features: int or tuple with number of input features.
out_features: int or tuple with number of output features.
kernel_size: shape of the convolutional kernel. For 1D convolution,
Expand Down Expand Up @@ -1092,7 +1092,7 @@ class Embed(Module):
broadcast the ``embedding`` matrix to input shape with ``features``
dimension appended.
Attributes:
Args:
num_embeddings: number of embeddings / vocab size.
features: number of feature dimensions for each embedding.
dtype: the dtype of the embedding vectors (default: same as embedding).
Expand Down
4 changes: 2 additions & 2 deletions flax/nnx/nn/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class LoRA(Module):
>>> y.shape
(16, 4)
Attributes:
Args:
in_features: the number of input features.
lora_rank: the rank of the LoRA dimension.
out_features: the number of output features.
Expand Down Expand Up @@ -133,7 +133,7 @@ class LoRALinear(Linear):
>>> y.shape
(16, 4)
Attributes:
Args:
in_features: the number of input features.
out_features: the number of output features.
lora_rank: the rank of the LoRA dimension.
Expand Down
8 changes: 4 additions & 4 deletions flax/nnx/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ class BatchNorm(Module):
>>> assert (batch_stats2['mean'].value == batch_stats3['mean'].value).all()
>>> assert (batch_stats2['var'].value == batch_stats3['var'].value).all()
Attributes:
Args:
num_features: the number of input features.
use_running_average: if True, the stored batch statistics will be
used instead of computing the batch statistics on the input.
Expand Down Expand Up @@ -407,7 +407,7 @@ class LayerNorm(Module):
>>> y = layer(x)
Attributes:
Args:
num_features: the number of input features.
epsilon: A small float added to variance to avoid dividing by zero.
dtype: the dtype of the result (default: infer from input and params).
Expand Down Expand Up @@ -539,7 +539,7 @@ class RMSNorm(Module):
>>> y = layer(x)
Attributes:
Args:
num_features: the number of input features.
epsilon: A small float added to variance to avoid dividing by zero.
dtype: the dtype of the result (default: infer from input and params).
Expand Down Expand Up @@ -670,7 +670,7 @@ class GroupNorm(Module):
>>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x)
>>> np.testing.assert_allclose(y, y2)
Attributes:
Args:
num_features: the number of input features/channels.
num_groups: the total number of channel groups. The default value of 32 is
proposed by the original group normalization paper.
Expand Down
Loading

0 comments on commit 9b3ea6b

Please sign in to comment.