Don't create params in normalization layers instead of creating None-valued params.

This makes these layers align better with the behavior of Linen layers, and also reduces confusion.

PiperOrigin-RevId: 719036051
IvyZX authored and Flax Authors committed Jan 23, 2025
1 parent d28f03f commit 1246924
Showing 1 changed file with 7 additions and 21 deletions.
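In effect, when use_scale or use_bias is False, the corresponding attribute is now not created at all, rather than being created as nnx.Param(None). A minimal sketch of the observable difference (assuming the current flax.nnx constructor signatures; this snippet is illustrative and not part of the commit):

    from flax import nnx

    # With use_bias=False, the layer no longer carries a bias attribute at all,
    # so no None-valued Param appears in the layer's state.
    layer = nnx.BatchNorm(num_features=8, use_bias=False, rngs=nnx.Rngs(0))

    hasattr(layer, 'bias')   # False after this change; previously True, with .value == None
    hasattr(layer, 'scale')  # True: use_scale defaults to True, so the param exists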
flax/nnx/nn/normalization.py (28 changes: 7 additions & 21 deletions)
@@ -290,14 +290,10 @@ def __init__(
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
-    else:
-      self.scale = nnx.Param(None)

     if use_bias:
       key = rngs.params()
       self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
-    else:
-      self.bias = nnx.Param(None)

     self.num_features = num_features
     self.use_running_average = use_running_average
@@ -368,8 +364,8 @@ def __call__(
       x,
       mean,
       var,
-      self.scale.value,
-      self.bias.value,
+      self.scale.value if self.use_scale else None,
+      self.bias.value if self.use_bias else None,
       reduction_axes,
       feature_axes,
       self.dtype,
@@ -457,14 +453,10 @@ def __init__(
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
-    else:
-      self.scale = nnx.Param(None)

     if use_bias:
       key = rngs.params()
       self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
-    else:
-      self.bias = nnx.Param(None)

     self.num_features = num_features
     self.epsilon = epsilon
@@ -503,8 +495,8 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
       x,
       mean,
       var,
-      self.scale.value,
-      self.bias.value,
+      self.scale.value if self.use_scale else None,
+      self.bias.value if self.use_bias else None,
       self.reduction_axes,
       self.feature_axes,
       self.dtype,
@@ -585,8 +577,6 @@ def __init__(
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
-    else:
-      self.scale = nnx.Param(None)

     self.num_features = num_features
     self.epsilon = epsilon
@@ -624,7 +614,7 @@ def __call__(self, x, mask: tp.Optional[jax.Array] = None):
       x,
       mean,
       var,
-      self.scale.value,
+      self.scale.value if self.use_scale else None,
       None,
       self.reduction_axes,
       self.feature_axes,
@@ -760,14 +750,10 @@ def __init__(
     if use_scale:
       key = rngs.params()
       self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
-    else:
-      self.scale = nnx.Param(None)

     if use_bias:
       key = rngs.params()
       self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
-    else:
-      self.bias = nnx.Param(None)

     self.epsilon = epsilon
     self.dtype = dtype
@@ -822,8 +808,8 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
       x,
       mean,
       var,
-      self.scale.value,
-      self.bias.value,
+      self.scale.value if self.use_scale else None,
+      self.bias.value if self.use_bias else None,
       reduction_axes[:-1],
       (self.feature_axis,),
       self.dtype,
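One implication for downstream code: anything that previously read layer.scale.value or layer.bias.value unconditionally should now guard on the use_* flags, mirroring the guards added to __call__ above. A hedged sketch (get_norm_affine is a hypothetical helper, not part of Flax):

    from flax import nnx

    def get_norm_affine(layer: nnx.LayerNorm):
      # The scale/bias attributes only exist when the corresponding
      # use_* flag is True, so check the flag before touching them.
      scale = layer.scale.value if layer.use_scale else None
      bias = layer.bias.value if layer.use_bias else None
      return scale, bias

    layer = nnx.LayerNorm(num_features=4, use_bias=False, rngs=nnx.Rngs(0))
    scale, bias = get_norm_affine(layer)  # scale is an array of ones; bias is None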
