Skip to content

Commit

Permalink
Bypasses WSC application in flax Dense layer's unboxing before reshape.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705931656
  • Loading branch information
jkr26 authored and Flax Authors committed Dec 13, 2024
1 parent fc38f21 commit c9891c8
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion flax/linen/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,15 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32):
flat_shape = jax.tree_util.tree_map(int, flat_shape)
kernel = self.kernel_init(rng, flat_shape, dtype)
if isinstance(kernel, meta.AxisMetadata):
return meta.replace_boxed(kernel, jnp.reshape(kernel.unbox(), shape))
if isinstance(kernel, meta.Partitioned):
# User may have written an nd sharding, but the value is 2D due to
# initialization above. Unbox and rebox without applying the
# constraint to the flattened array.
return kernel.replace_boxed(
jnp.reshape(kernel.unbox(apply_constraint=False), shape)
)
else:
return meta.replace_boxed(kernel, jnp.reshape(kernel.unbox(), shape))
return jnp.reshape(kernel, shape)

batch_shape = tuple(inputs.shape[ax] for ax in batch_dims)
Expand Down

0 comments on commit c9891c8

Please sign in to comment.