@@ -246,3 +246,80 @@ def _to_backend_layout(tensor_layout):
246246 partition_spec = jax .sharding .PartitionSpec (* tensor_layout .axes )
247247 jax_mesh = tensor_layout .device_mesh .backend_mesh
248248 return jax .sharding .NamedSharding (jax_mesh , partition_spec )
249+
250+
251+ def _distribute_initializer (
252+ init_func = None , mean = 0.0 , stddev = 1.0 , seed = None , layout = None
253+ ):
254+ """
255+ Distribution-aware token embedding initializer for JAX backend.
256+
257+ This function will create a Jax random array and
258+ distribute it according to the current token embedding layout.
259+
260+ Args:
261+ init_func: A functools.partial-wrapped object that takes the seed
262+ as argument and returns a jax.Array. Must have shape and dtype
263+ already bound via partial.
264+ mean: Mean of distribution (applied to normal/truncated_normal).
265+ stddev: Standard deviation of the distribution.
266+ seed: Random seed for initialization.
267+ layout: TensorLayout for the distributed tensor.
268+
269+ Returns:
270+ A distributed jax array.
271+
272+ Raises:
273+ ValueError: If init_func or seed is None.
274+ If init_func.func is not a supported random function.
275+ TypeError: If init_func is not a functools.partial object.
276+ """
277+ import warnings
278+ from functools import partial
279+
280+ # Validate all required arguments
281+ if seed is None :
282+ raise ValueError ("seed cannot be None. Use keras.random.SeedGenerator." )
283+
284+ if init_func is None :
285+ raise ValueError (
286+ "init_func cannot be None. Shape and dtype info are required."
287+ )
288+
289+ # Ensure init_func is a partial
290+ if not isinstance (init_func , partial ):
291+ raise TypeError (
292+ f"init_func must be functools.partial object, got { type (init_func )} "
293+ )
294+
295+ # Shard based on tensor layout
296+ if layout is None :
297+ warnings .warn (
298+ f"The layout is { layout } , sharding will default to single device"
299+ )
300+ sharding = None
301+ else :
302+ sharding = _to_backend_layout (layout )
303+
304+ # The init_func has static arguments baked in as per initializer.
305+ compiled_init = jax .jit (
306+ lambda seed : init_func (seed ), out_shardings = sharding
307+ )
308+
309+ sample = compiled_init (seed )
310+
311+ # Apply mean/stddev only for distributions where it makes sense
312+ if init_func .func in (jax .random .normal , jax .random .truncated_normal ):
313+ return sample * stddev + mean
314+ elif init_func .func == jax .random .uniform :
315+ # Uniform doesn't use mean/stddev - warn
316+ if mean != 0.0 or stddev != 1.0 :
317+ warnings .warn (
318+ "mean and stddev are ignored for uniform distribution"
319+ )
320+ return sample
321+ else :
322+ raise ValueError (
323+ f"Unsupported initializer: { init_func .func .__name__ } . "
324+ f"Supported: normal, truncated_normal, uniform"
325+ )
0 commit comments