diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index 97e039663d..1e30918612 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -14,9 +14,9 @@ """Normalization modules for Flax.""" +import functools from typing import (Any, Callable, Iterable, Optional, Tuple, Union) from flax.linen.dtypes import canonicalize_dtype - from flax.linen.module import Module, compact, merge_param # pylint: disable=g-multiple-import from jax import lax from jax.nn import initializers @@ -90,12 +90,13 @@ def _compute_stats(x: Array, axes: Optional[Axes], mean = jnp.zeros(mean2.shape, dtype=dtype) if axis_name is not None: - concatenated_mean = jnp.concatenate([mean, mean2]) - mean, mean2 = jnp.split( - lax.pmean( - concatenated_mean, - axis_name=axis_name, - axis_index_groups=axis_index_groups), 2) + pmean = functools.partial( + lax.pmean, axis_name=axis_name, axis_index_groups=axis_index_groups + ) + if use_mean: + mean, mean2 = jnp.split(pmean(jnp.concatenate([mean, mean2])), 2) + else: + mean2 = pmean(mean2) # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due # to floating point round-off errors. var = jnp.maximum(0., mean2 - _abs_sq(mean))