diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index 921a030f25..81f68240b1 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -287,12 +287,14 @@ def __init__( self.mean = nnx.BatchStat(jnp.zeros(feature_shape, jnp.float32)) self.var = nnx.BatchStat(jnp.ones(feature_shape, jnp.float32)) + self.scale: nnx.Param[jax.Array] | None if use_scale: key = rngs.params() self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype)) else: self.scale = nnx.Param(None) + self.bias: nnx.Param[jax.Array] | None if use_bias: key = rngs.params() self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype)) @@ -368,8 +370,8 @@ def __call__( x, mean, var, - self.scale.value, - self.bias.value, + self.scale.value if self.scale else None, + self.bias.value if self.bias else None, reduction_axes, feature_axes, self.dtype, @@ -454,12 +456,14 @@ def __init__( ): feature_shape = (num_features,) + self.scale: nnx.Param[jax.Array] | None if use_scale: key = rngs.params() self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype)) else: self.scale = nnx.Param(None) + self.bias: nnx.Param[jax.Array] | None if use_bias: key = rngs.params() self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype)) @@ -503,8 +507,8 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None): x, mean, var, - self.scale.value, - self.bias.value, + self.scale.value if self.scale else None, + self.bias.value if self.bias else None, self.reduction_axes, self.feature_axes, self.dtype, @@ -582,6 +586,7 @@ def __init__( ): feature_shape = (num_features,) + self.scale: nnx.Param[jax.Array] | None if use_scale: key = rngs.params() self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype)) @@ -624,7 +629,7 @@ def __call__(self, x, mask: tp.Optional[jax.Array] = None): x, mean, var, - self.scale.value, + self.scale.value if self.scale else None, None, self.reduction_axes, self.feature_axes, @@ -757,12 +762,14 @@ def __init__( self.group_size = num_features // num_groups feature_shape = (num_features,) + self.scale: nnx.Param[jax.Array] | None if use_scale: key = rngs.params() self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype)) else: self.scale = nnx.Param(None) + self.bias: nnx.Param[jax.Array] | None if use_bias: key = rngs.params() self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype)) @@ -822,8 +829,8 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None): x, mean, var, - self.scale.value, - self.bias.value, + self.scale.value if self.scale else None, + self.bias.value if self.bias else None, reduction_axes[:-1], (self.feature_axis,), self.dtype,