Skip to content

Commit ac871f1

Browse files
Fix OOM Issue
1 parent e2be4de commit ac871f1

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

keras/src/backend/jax/distribution_lib.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,80 @@ def _to_backend_layout(tensor_layout):
246246
partition_spec = jax.sharding.PartitionSpec(*tensor_layout.axes)
247247
jax_mesh = tensor_layout.device_mesh.backend_mesh
248248
return jax.sharding.NamedSharding(jax_mesh, partition_spec)
249+
250+
251+
def _distribute_initializer(
    init_func=None, mean=0.0, stddev=1.0, seed=None, layout=None
):
    """Create a random JAX array with `init_func`, sharded per `layout`.

    The sampler is executed inside `jax.jit` with `out_shardings` set from
    `layout`, so the result is produced directly with the requested
    sharding. For `jax.random.normal` and `jax.random.truncated_normal`
    the sample is rescaled as `sample * stddev + mean` within the same
    jitted computation.

    Args:
        init_func: A `functools.partial` object wrapping
            `jax.random.normal`, `jax.random.truncated_normal` or
            `jax.random.uniform`. It is called with the seed as its only
            argument, so shape and dtype must already be bound.
        mean: Mean of the distribution (normal/truncated_normal only).
        stddev: Standard deviation of the distribution
            (normal/truncated_normal only).
        seed: Value forwarded to `init_func` as its single argument.
        layout: `TensorLayout` describing the sharding. If `None`, a
            warning is emitted and no output sharding is requested.

    Returns:
        The jax array produced by `init_func` (rescaled where applicable).

    Raises:
        ValueError: If `init_func` or `seed` is `None`, or if
            `init_func.func` is not one of the supported random functions.
        TypeError: If `init_func` is not a `functools.partial` object.
    """
    import warnings
    from functools import partial

    # Validate all required arguments.
    if seed is None:
        raise ValueError("seed cannot be None. Use keras.random.SeedGenerator.")

    if init_func is None:
        raise ValueError(
            "init_func cannot be None. Shape and dtype info are required."
        )

    if not isinstance(init_func, partial):
        raise TypeError(
            f"init_func must be functools.partial object, got {type(init_func)}"
        )

    # Reject unsupported samplers *before* tracing/compiling anything, so an
    # invalid call never runs the sampler or allocates its output.
    random_fn = init_func.func
    if random_fn in (jax.random.normal, jax.random.truncated_normal):
        rescale = True
    elif random_fn == jax.random.uniform:
        rescale = False
        # Uniform doesn't use mean/stddev - warn if the caller set them.
        if mean != 0.0 or stddev != 1.0:
            warnings.warn(
                "mean and stddev are ignored for uniform distribution"
            )
    else:
        # Not every callable has `__name__` (e.g. a nested partial).
        fn_name = getattr(random_fn, "__name__", repr(random_fn))
        raise ValueError(
            f"Unsupported initializer: {fn_name}. "
            f"Supported: normal, truncated_normal, uniform"
        )

    # Shard based on tensor layout.
    if layout is None:
        warnings.warn(
            f"The layout is {layout}, sharding will default to single device"
        )
        sharding = None
    else:
        sharding = _to_backend_layout(layout)

    def _sample(key):
        # `init_func` has its static arguments (shape, dtype, ...) baked in.
        values = init_func(key)
        if rescale:
            # Done inside the jitted function so no extra full-size
            # temporaries are materialised outside the compiled computation.
            values = values * stddev + mean
        return values

    return jax.jit(_sample, out_shardings=sharding)(seed)

0 commit comments

Comments
 (0)