From 8afe0657fe57e0864aa64908fe9cc75bc88cf410 Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Thu, 23 Jan 2025 15:28:28 -0800 Subject: [PATCH] Don't create params in normalization layers instead of creating None-valued params. This makes these layers align better with the behavior of Linen layers, and also reduces confusion. PiperOrigin-RevId: 719036051 --- flax/nnx/nn/normalization.py | 28 +++++++--------------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index 921a030f25..07dcf8dc93 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -290,14 +290,10 @@ def __init__( if use_scale: key = rngs.params() self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype)) - else: - self.scale = nnx.Param(None) if use_bias: key = rngs.params() self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype)) - else: - self.bias = nnx.Param(None) self.num_features = num_features self.use_running_average = use_running_average @@ -368,8 +364,8 @@ def __call__( x, mean, var, - self.scale.value, - self.bias.value, + self.scale.value if self.use_scale else None, + self.bias.value if self.use_bias else None, reduction_axes, feature_axes, self.dtype, @@ -457,14 +453,10 @@ def __init__( if use_scale: key = rngs.params() self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype)) - else: - self.scale = nnx.Param(None) if use_bias: key = rngs.params() self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype)) - else: - self.bias = nnx.Param(None) self.num_features = num_features self.epsilon = epsilon @@ -503,8 +495,8 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None): x, mean, var, - self.scale.value, - self.bias.value, + self.scale.value if self.use_scale else None, + self.bias.value if self.use_bias else None, self.reduction_axes, self.feature_axes, self.dtype, @@ -585,8 +577,6 @@ def __init__( if use_scale: key = rngs.params() self.scale = 
nnx.Param(scale_init(key, feature_shape, param_dtype)) - else: - self.scale = nnx.Param(None) self.num_features = num_features self.epsilon = epsilon @@ -624,7 +614,7 @@ def __call__(self, x, mask: tp.Optional[jax.Array] = None): x, mean, var, - self.scale.value, + self.scale.value if self.use_scale else None, None, self.reduction_axes, self.feature_axes, @@ -760,14 +750,10 @@ def __init__( if use_scale: key = rngs.params() self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype)) - else: - self.scale = nnx.Param(None) if use_bias: key = rngs.params() self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype)) - else: - self.bias = nnx.Param(None) self.epsilon = epsilon self.dtype = dtype @@ -822,8 +808,8 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None): x, mean, var, - self.scale.value, - self.bias.value, + self.scale.value if self.use_scale else None, + self.bias.value if self.use_bias else None, reduction_axes[:-1], (self.feature_axis,), self.dtype,