Don't create params in normalization layers instead of creating None-value params.

This makes these layers align better with the behavior of Linen layers, and also reduces confusion.

PiperOrigin-RevId: 719036051
IvyZX authored and Flax Authors committed Jan 24, 2025
1 parent d28f03f commit 752c7a9
Showing 1 changed file with 14 additions and 7 deletions.
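
To illustrate the resulting behavior: when use_scale/use_bias are disabled, the layer's scale/bias attribute is falsy, and __call__ now guards on it and forwards None to the normalization helper instead of unconditionally reading .value. A minimal sketch, assuming the flax.nnx API of this release (the guard mirrors the diff below; variable names are illustrative):

    from flax import nnx

    # Scale and bias are disabled, so no usable parameters are created.
    layer = nnx.LayerNorm(
      num_features=8, use_scale=False, use_bias=False, rngs=nnx.Rngs(0)
    )

    # The same truthiness guard the updated __call__ uses: a disabled
    # scale/bias is falsy, so None is passed to the normalization helper.
    scale = layer.scale.value if layer.scale else None  # -> None
    bias = layer.bias.value if layer.bias else None     # -> None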
flax/nnx/nn/normalization.py
@@ -287,12 +287,14 @@ def __init__(
     self.mean = nnx.BatchStat(jnp.zeros(feature_shape, jnp.float32))
     self.var = nnx.BatchStat(jnp.ones(feature_shape, jnp.float32))
 
+    self.scale: nnx.Param[jax.Array] | None
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
     else:
       self.scale = nnx.Param(None)
 
+    self.bias: nnx.Param[jax.Array] | None
     if use_bias:
       key = rngs.params()
       self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
@@ -368,8 +370,8 @@ def __call__(
       x,
       mean,
       var,
-      self.scale.value,
-      self.bias.value,
+      self.scale.value if self.scale else None,
+      self.bias.value if self.bias else None,
       reduction_axes,
       feature_axes,
       self.dtype,
@@ -454,12 +456,14 @@ def __init__(
   ):
     feature_shape = (num_features,)
 
+    self.scale: nnx.Param[jax.Array] | None
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
     else:
       self.scale = nnx.Param(None)
 
+    self.bias: nnx.Param[jax.Array] | None
     if use_bias:
       key = rngs.params()
       self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
@@ -503,8 +507,8 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
       x,
       mean,
       var,
-      self.scale.value,
-      self.bias.value,
+      self.scale.value if self.scale else None,
+      self.bias.value if self.bias else None,
       self.reduction_axes,
       self.feature_axes,
       self.dtype,
@@ -582,6 +586,7 @@ def __init__(
   ):
     feature_shape = (num_features,)
 
+    self.scale: nnx.Param[jax.Array] | None
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
@@ -624,7 +629,7 @@ def __call__(self, x, mask: tp.Optional[jax.Array] = None):
       x,
       mean,
       var,
-      self.scale.value,
+      self.scale.value if self.scale else None,
       None,
       self.reduction_axes,
       self.feature_axes,
@@ -757,12 +762,14 @@ def __init__(
     self.group_size = num_features // num_groups
 
     feature_shape = (num_features,)
+    self.scale: nnx.Param[jax.Array] | None
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
     else:
       self.scale = nnx.Param(None)
 
+    self.bias: nnx.Param[jax.Array] | None
     if use_bias:
       key = rngs.params()
       self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
@@ -822,8 +829,8 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
       x,
       mean,
       var,
-      self.scale.value,
-      self.bias.value,
+      self.scale.value if self.scale else None,
+      self.bias.value if self.bias else None,
       reduction_axes[:-1],
       (self.feature_axis,),
       self.dtype,
