diff --git a/flax/linen/linear.py b/flax/linen/linear.py
index babe809af0..f6b013eb88 100644
--- a/flax/linen/linear.py
+++ b/flax/linen/linear.py
@@ -145,7 +145,15 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32):
       flat_shape = jax.tree_util.tree_map(int, flat_shape)
       kernel = self.kernel_init(rng, flat_shape, dtype)
       if isinstance(kernel, meta.AxisMetadata):
-        return meta.replace_boxed(kernel, jnp.reshape(kernel.unbox(), shape))
+        if isinstance(kernel, meta.Partitioned):
+          # The user may have specified an n-d sharding, but the value is 2-d
+          # because of the flattened initialization above. Unbox and rebox
+          # without applying the constraint to the flattened array.
+          return kernel.replace_boxed(
+            jnp.reshape(kernel.unbox(apply_constraint=False), shape)
+          )
+        else:
+          return meta.replace_boxed(kernel, jnp.reshape(kernel.unbox(), shape))
       return jnp.reshape(kernel, shape)
 
     batch_shape = tuple(inputs.shape[ax] for ax in batch_dims)
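For context, a minimal sketch of the mismatch this patch guards against (the axis names, shapes, and `with_partitioning` setup below are illustrative, not taken from the patch): `DenseGeneral` initializes its kernel as a flattened 2-d array, so a `Partitioned` kernel momentarily carries an n-d sharding spec over a 2-d value. Passing `apply_constraint=False` to `unbox` skips `jax.lax.with_sharding_constraint`, so the mismatched spec is never applied before the reshape restores the expected rank:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core import meta

# Illustrative only: a 3-d logical sharding for a DenseGeneral kernel of
# final shape (4, 4, 8), which is first initialized as a flat (16, 8) array.
init_fn = nn.with_partitioning(
  nn.initializers.lecun_normal(), ('data', None, 'model')
)

kernel = init_fn(jax.random.PRNGKey(0), (16, 8), jnp.float32)
assert isinstance(kernel, meta.Partitioned)
print(kernel.names)         # ('data', None, 'model') -- 3-d spec
print(kernel.value.shape)   # (16, 8)                 -- 2-d flattened value

# unbox(apply_constraint=False) returns the raw value without calling
# jax.lax.with_sharding_constraint, so the 3-d spec is never applied to the
# 2-d array; replace_boxed then reboxes the reshaped kernel under the same
# metadata.
unboxed = kernel.unbox(apply_constraint=False)
reboxed = kernel.replace_boxed(jnp.reshape(unboxed, (4, 4, 8)))
print(reboxed.value.shape)  # (4, 4, 8)
```

Without the `apply_constraint=False` escape hatch, unboxing under an active mesh would attempt to constrain the flat (16, 8) value with the 3-d spec, which fails on the rank mismatch; the old `meta.replace_boxed` path remains for other `AxisMetadata` subclasses.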