Skip to content

Commit

Permalink
Merge pull request #4499 from google:nnx-improve-module-docs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718848543
  • Loading branch information
Flax Authors committed Jan 23, 2025
2 parents b5d4ed8 + bdcc33a commit e3bcc44
Show file tree
Hide file tree
Showing 19 changed files with 323 additions and 306 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and highlight the differences between the two libraries.
from jax import random
import optax
import flax.linen as nn
import haiku as hk

Basic Example
-----------------
Expand Down
2 changes: 1 addition & 1 deletion examples/gemma/transformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_logit_softcap(
all_outputs = []
for config in [config_soft_cap, config_no_soft_cap]:
transformer = transformer_lib.Transformer(
config=config, rngs=nnx.Rngs(params=0)
config=config, rngs=nnx.Rngs(params=1)
)
cache = transformer.init_cache(
cache_size=cache_size,
Expand Down
22 changes: 11 additions & 11 deletions flax/linen/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,20 +1068,20 @@ class Embed(Module):
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> variables = layer.init(jax.random.key(0), indices_input)
>>> variables
{'params': {'embedding': Array([[-0.28884724, 0.19018005, -0.414205 ],
[-0.11768015, -0.54618824, -0.3789283 ],
[ 0.30428642, 0.49511626, 0.01706631],
[-0.0982546 , -0.43055868, 0.20654906],
[-0.688412 , -0.46882293, 0.26723292]], dtype=float32)}}
{'params': {'embedding': Array([[ 0.04396089, -0.9328513 , -0.97328115],
[ 0.41147125, 0.66334754, 0.49469155],
[ 0.09719624, 0.49861377, 0.49519277],
[-0.13316602, 0.6697022 , 0.3710195 ],
[-0.5039532 , 0.287319 , 1.4369922 ]], dtype=float32)}}
>>> # get the first three and last three embeddings
>>> layer.apply(variables, indices_input)
Array([[[-0.28884724, 0.19018005, -0.414205 ],
[-0.11768015, -0.54618824, -0.3789283 ],
[ 0.30428642, 0.49511626, 0.01706631]],
Array([[[ 0.04396089, -0.9328513 , -0.97328115],
[ 0.41147125, 0.66334754, 0.49469155],
[ 0.09719624, 0.49861377, 0.49519277]],
<BLANKLINE>
[[-0.688412 , -0.46882293, 0.26723292],
[-0.0982546 , -0.43055868, 0.20654906],
[ 0.30428642, 0.49511626, 0.01706631]]], dtype=float32)
[[-0.5039532 , 0.287319 , 1.4369922 ],
[-0.13316602, 0.6697022 , 0.3710195 ],
[ 0.09719624, 0.49861377, 0.49519277]]], dtype=float32)
Attributes:
num_embeddings: number of embeddings / vocab size.
Expand Down
12 changes: 6 additions & 6 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2684,18 +2684,18 @@ def perturb(
>>> variables = model.init(jax.random.key(0), x)
>>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y)
>>> print(intm_grads['perturbations']['dense3'])
[[-1.456924 -0.44332537 0.02422847]
[-1.456924 -0.44332537 0.02422847]]
[[-0.04684732 0.06573904 -0.3194327 ]
[-0.04684732 0.06573904 -0.3194327 ]]
If perturbations are not passed to ``apply``, ``perturb`` behaves like a no-op
so you can easily disable the behavior when not needed::
>>> model.apply(variables, x) # works as expected
Array([[-1.0980128 , -0.67961735],
[-1.0980128 , -0.67961735]], dtype=float32)
Array([[-0.04579116, 0.50412744],
[-0.04579116, 0.50412744]], dtype=float32)
>>> model.apply({'params': variables['params']}, x) # behaves like a no-op
Array([[-1.0980128 , -0.67961735],
[-1.0980128 , -0.67961735]], dtype=float32)
Array([[-0.04579116, 0.50412744],
[-0.04579116, 0.50412744]], dtype=float32)
>>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y)
>>> 'perturbations' not in intm_grads
True
Expand Down
4 changes: 2 additions & 2 deletions flax/linen/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ class Dropout(Module):
>>> x = jnp.ones((1, 3))
>>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout
>>> model.apply(variables, x, train=False) # don't use dropout
Array([[-0.88686204, -0.5928178 , -0.5184689 , -0.4345976 ]], dtype=float32)
Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32)
>>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout
Array([[ 0. , -1.1856356, -1.0369378, 0. ]], dtype=float32)
Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32)
Attributes:
rate: the dropout probability. (_not_ the keep rate!)
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/nn/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class MultiHeadAttention(Module):
>>> assert (layer(q) == layer(q, q)).all()
>>> assert (layer(q) == layer(q, q, q)).all()
Attributes:
Args:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
in_features: int or tuple with number of input features.
Expand Down
34 changes: 17 additions & 17 deletions flax/nnx/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class LinearGeneral(Module):
>>> y.shape
(16, 4, 5)
Attributes:
Args:
in_features: int or tuple with number of input features.
out_features: int or tuple with number of output features.
axis: int or tuple with axes to apply the transformation on. For instance,
Expand Down Expand Up @@ -301,7 +301,7 @@ class Linear(Module):
)
})
Attributes:
Args:
in_features: the number of input features.
out_features: the number of output features.
use_bias: whether to add a bias to the output (default: True).
Expand Down Expand Up @@ -393,7 +393,7 @@ class Einsum(Module):
>>> y.shape
(16, 11, 8, 4)
Attributes:
Args:
einsum_str: a string to denote the einsum equation. The equation must
have exactly two operands, the lhs being the input passed in, and
the rhs being the learnable kernel. Exactly one of ``einsum_str``
Expand Down Expand Up @@ -572,7 +572,7 @@ class Conv(Module):
... mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
Attributes:
Args:
in_features: int or tuple with number of input features.
out_features: int or tuple with number of output features.
kernel_size: shape of the convolutional kernel. For 1D convolution,
Expand Down Expand Up @@ -823,7 +823,7 @@ class ConvTranspose(Module):
... mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
Attributes:
Args:
in_features: int or tuple with number of input features.
out_features: int or tuple with number of output features.
kernel_size: shape of the convolutional kernel. For 1D convolution,
Expand Down Expand Up @@ -1065,23 +1065,23 @@ class Embed(Module):
State({
'embedding': VariableState( # 15 (60 B)
type=Param,
value=Array([[-0.90411377, -0.3648777 , -1.1083648 ],
[ 0.01070483, 0.27923733, 1.7487359 ],
[ 0.59161806, 0.8660184 , 1.2838588 ],
[-0.748139 , -0.15856352, 0.06061118],
[-0.4769059 , -0.6607095 , 0.46697947]], dtype=float32)
value=Array([[ 0.57966787, -0.523274 , -0.43195742],
[-0.676289 , -0.50300646, 0.33996582],
[ 0.41796115, -0.59212935, 0.95934135],
[-1.0917838 , -0.7441663 , 0.07713798],
[-0.66570747, 0.13815777, 1.007365 ]], dtype=float32)
)
})
>>> # get the first three and last three embeddings
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> layer(indices_input)
Array([[[-0.90411377, -0.3648777 , -1.1083648 ],
[ 0.01070483, 0.27923733, 1.7487359 ],
[ 0.59161806, 0.8660184 , 1.2838588 ]],
Array([[[ 0.57966787, -0.523274 , -0.43195742],
[-0.676289 , -0.50300646, 0.33996582],
[ 0.41796115, -0.59212935, 0.95934135]],
<BLANKLINE>
[[-0.4769059 , -0.6607095 , 0.46697947],
[-0.748139 , -0.15856352, 0.06061118],
[ 0.59161806, 0.8660184 , 1.2838588 ]]], dtype=float32)
[[-0.66570747, 0.13815777, 1.007365 ],
[-1.0917838 , -0.7441663 , 0.07713798],
[ 0.41796115, -0.59212935, 0.95934135]]], dtype=float32)
A parameterized function from integers [0, ``num_embeddings``) to
``features``-dimensional vectors. This ``Module`` will create an ``embedding``
Expand All @@ -1092,7 +1092,7 @@ class Embed(Module):
broadcast the ``embedding`` matrix to input shape with ``features``
dimension appended.
Attributes:
Args:
num_embeddings: number of embeddings / vocab size.
features: number of feature dimensions for each embedding.
dtype: the dtype of the embedding vectors (default: same as embedding).
Expand Down
4 changes: 2 additions & 2 deletions flax/nnx/nn/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class LoRA(Module):
>>> y.shape
(16, 4)
Attributes:
Args:
in_features: the number of input features.
lora_rank: the rank of the LoRA dimension.
out_features: the number of output features.
Expand Down Expand Up @@ -133,7 +133,7 @@ class LoRALinear(Linear):
>>> y.shape
(16, 4)
Attributes:
Args:
in_features: the number of input features.
out_features: the number of output features.
lora_rank: the rank of the LoRA dimension.
Expand Down
8 changes: 4 additions & 4 deletions flax/nnx/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ class BatchNorm(Module):
>>> assert (batch_stats2['mean'].value == batch_stats3['mean'].value).all()
>>> assert (batch_stats2['var'].value == batch_stats3['var'].value).all()
Attributes:
Args:
num_features: the number of input features.
use_running_average: if True, the stored batch statistics will be
used instead of computing the batch statistics on the input.
Expand Down Expand Up @@ -407,7 +407,7 @@ class LayerNorm(Module):
>>> y = layer(x)
Attributes:
Args:
num_features: the number of input features.
epsilon: A small float added to variance to avoid dividing by zero.
dtype: the dtype of the result (default: infer from input and params).
Expand Down Expand Up @@ -539,7 +539,7 @@ class RMSNorm(Module):
>>> y = layer(x)
Attributes:
Args:
num_features: the number of input features.
epsilon: A small float added to variance to avoid dividing by zero.
dtype: the dtype of the result (default: infer from input and params).
Expand Down Expand Up @@ -670,7 +670,7 @@ class GroupNorm(Module):
>>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x)
>>> np.testing.assert_allclose(y, y2)
Attributes:
Args:
num_features: the number of input features/channels.
num_groups: the total number of channel groups. The default value of 32 is
proposed by the original group normalization paper.
Expand Down
Loading

0 comments on commit e3bcc44

Please sign in to comment.