77 changes: 77 additions & 0 deletions keras/src/backend/jax/distribution_lib.py
@@ -246,3 +246,80 @@ def _to_backend_layout(tensor_layout):
partition_spec = jax.sharding.PartitionSpec(*tensor_layout.axes)
jax_mesh = tensor_layout.device_mesh.backend_mesh
return jax.sharding.NamedSharding(jax_mesh, partition_spec)


def _distribute_initializer(
init_func=None, mean=0.0, stddev=1.0, seed=None, layout=None
):
"""
Distribution-aware token embedding initializer for JAX backend.

This function will create a Jax random array and
distribute it according to the current token embedding layout.

Args:
init_func: A functools.partial-wrapped object that takes the seed
as argument and returns a jax.Array. Must have shape and dtype
already bound via partial.
mean: Mean of distribution (applied to normal/truncated_normal).
stddev: Standard deviation of the distribution.
seed: Random seed for initialization.
layout: TensorLayout for the distributed tensor.

Returns:
A distributed jax array.

Raises:
ValueError: If init_func or seed is None.
If init_func.func is not a supported random function.
TypeError: If init_func is not a functools.partial object.
"""
import warnings
from functools import partial

    # Validate required arguments.
    if seed is None:
        raise ValueError(
            "seed cannot be None. Use keras.random.SeedGenerator."
        )

    if init_func is None:
        raise ValueError(
            "init_func cannot be None. Shape and dtype info are required."
        )

    # init_func must be a functools.partial so that shape/dtype are
    # already bound and only the PRNG key remains to be supplied.
    if not isinstance(init_func, partial):
        raise TypeError(
            "init_func must be a functools.partial object, "
            f"got {type(init_func)}"
        )

    # Shard based on the tensor layout.
    if layout is None:
        warnings.warn(
            "layout is None; sharding will default to a single device."
        )
        sharding = None
    else:
        sharding = _to_backend_layout(layout)

    # Shape and dtype are already baked into init_func by the initializer;
    # out_shardings makes jit produce the array with the requested sharding.
    compiled_init = jax.jit(
        lambda key: init_func(key), out_shardings=sharding
    )

sample = compiled_init(seed)

    # Apply mean/stddev only to distributions where they make sense.
    if init_func.func in (jax.random.normal, jax.random.truncated_normal):
        return sample * stddev + mean
    elif init_func.func == jax.random.uniform:
        # Uniform does not use mean/stddev; warn if non-default values
        # were passed.
        if mean != 0.0 or stddev != 1.0:
            warnings.warn(
                "mean and stddev are ignored for the uniform distribution."
            )
        return sample
    else:
        raise ValueError(
            f"Unsupported initializer: {init_func.func.__name__}. "
            "Supported: normal, truncated_normal, uniform."
        )
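
A minimal usage sketch of the new helper (not part of the diff). It assumes a 1D device mesh over all available JAX devices, a (1000, 128) embedding table whose last dimension divides evenly across those devices, and that the helper is called directly from the backend module; the names `mesh`, `layout`, and `init_func` below are illustrative choices for the example, not definitions from this PR:

    import functools

    import jax
    import numpy as np

    from keras.distribution import DeviceMesh, TensorLayout
    from keras.src.backend.jax import distribution_lib

    # 1D mesh over all local devices; the embedding (last) dimension is
    # sharded across the "model" axis. Device count and shapes are
    # assumptions made for this example.
    devices = np.array(jax.devices())
    mesh = DeviceMesh(
        shape=(len(devices),), axis_names=("model",), devices=devices
    )
    layout = TensorLayout(axes=(None, "model"), device_mesh=mesh)

    # Shape (and optionally dtype) are bound up front, as the helper
    # requires; only the PRNG key is left as an argument.
    init_func = functools.partial(jax.random.normal, shape=(1000, 128))

    embeddings = distribution_lib._distribute_initializer(
        init_func=init_func,
        mean=0.0,
        stddev=0.02,
        seed=jax.random.PRNGKey(0),
        layout=layout,
    )
    print(embeddings.shape, embeddings.sharding)

Passing out_shardings to jax.jit lets the compiler materialize each shard directly on its target device, which appears to be the point of wrapping the initializer in jit here rather than creating the full array and resharding it afterwards.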