45 commits
06bb3bb
Adding tensor layout for TP autosharding
buildwithsuhana Oct 28, 2025
41f8025
formatting files
buildwithsuhana Oct 28, 2025
e74eab2
Updating the docstring
buildwithsuhana Oct 28, 2025
2cddf39
refactoring the code
buildwithsuhana Oct 28, 2025
fee036e
Merge branch 'tensor_parallel' of https://github.com/buildwithsuhana/…
buildwithsuhana Oct 28, 2025
9bed6e4
Merge branch 'keras-team:master' into tensor_parallel
buildwithsuhana Nov 6, 2025
5365f14
fixing test
buildwithsuhana Nov 6, 2025
bc4d094
fixing test
buildwithsuhana Nov 6, 2025
4d32e49
adding autoconfig and coordinated_optimizer
buildwithsuhana Nov 17, 2025
119ac15
updating docstrings and code format
buildwithsuhana Nov 17, 2025
7851615
refactored autoconfig to not use recursion
buildwithsuhana Nov 17, 2025
4707c2b
updating docstrings
buildwithsuhana Nov 17, 2025
45aa44c
removing redundancies
buildwithsuhana Nov 17, 2025
8bb39f6
added tests for autoconfig and coordinated optimizer
buildwithsuhana Nov 18, 2025
ab444b1
fixing autoconfig
buildwithsuhana Dec 8, 2025
d5612eb
fixing autoconfig
buildwithsuhana Dec 8, 2025
a777178
fixing autoconfig test
buildwithsuhana Dec 8, 2025
7b144d9
fixing tensor layout and core
buildwithsuhana Dec 8, 2025
12b038a
running pre commit
buildwithsuhana Dec 8, 2025
d9eabc8
adding test
buildwithsuhana Dec 8, 2025
74437c9
adding test
buildwithsuhana Dec 8, 2025
6eeb589
Fixing autoconfig
buildwithsuhana Dec 11, 2025
207a4bf
fixing coordinated_optimizer
buildwithsuhana Dec 11, 2025
17cf142
fixing tests
buildwithsuhana Dec 29, 2025
f834eca
fixed all comments and tests
buildwithsuhana Dec 29, 2025
e30b201
Merge branch 'keras-team:master' into tensor_parallel
buildwithsuhana Jan 2, 2026
1ff9e48
fixing tests
buildwithsuhana Jan 2, 2026
ef537ff
Merge branch 'tensor_parallel' of https://github.com/buildwithsuhana/…
buildwithsuhana Jan 2, 2026
89f5e77
reverting changes
buildwithsuhana Jan 2, 2026
0973f69
fixing test
buildwithsuhana Jan 2, 2026
68e8862
removing redundant lines
buildwithsuhana Jan 2, 2026
5d82be8
bringing tests to similar format
buildwithsuhana Jan 2, 2026
82fc41e
fixing dtype issue
buildwithsuhana Jan 2, 2026
2801325
fixing test
buildwithsuhana Jan 2, 2026
88d70c8
fixing test
buildwithsuhana Jan 2, 2026
7e8aded
fixing comments
buildwithsuhana Jan 13, 2026
f3fe7ac
fixing format
buildwithsuhana Jan 13, 2026
7b444ec
fixes
buildwithsuhana Jan 13, 2026
f608974
simplified coordinated_optimizer
buildwithsuhana Jan 13, 2026
99a9885
fixing dtype error
buildwithsuhana Jan 13, 2026
2f425c6
fixing dtype
buildwithsuhana Jan 13, 2026
bad0f22
fixing test
buildwithsuhana Jan 13, 2026
5c0a2f2
fixing test
buildwithsuhana Jan 13, 2026
9ee27c1
fixing test
buildwithsuhana Jan 13, 2026
905a3ee
Removing coordinated_optimizer.py
buildwithsuhana Jan 14, 2026
53 changes: 53 additions & 0 deletions keras/src/backend/jax/distribution_lib.py
@@ -1,6 +1,7 @@
"""Utilities for distribution strategy with JAX backend."""

import jax
import jax.lax as lax
import numpy as np

from keras.src.backend.common import global_state
@@ -212,6 +213,58 @@ def process_id():
    return jax.process_index()


def all_reduce(x, op="sum", axis_name="model"):
    """Reduces a tensor across a device mesh axis using a collective.

    Args:
        x: The tensor to reduce.
        op: The reduction operation, either "sum" or "mean".
        axis_name: The name of the mesh axis to reduce over.

    Returns:
        The reduced tensor.
    """

    def _reduce_fn(y):
        if op == "sum":
            return lax.psum(y, axis_name=axis_name)
        elif op == "mean":
            return lax.pmean(y, axis_name=axis_name)
        else:
            raise ValueError(
                f"Unsupported reduction operation: {op}. "
                "Supported options are 'sum' and 'mean'."
            )

    return jax.pmap(_reduce_fn, axis_name=axis_name)(x)


def all_gather(x, axis, axis_name="model"):
    """Gathers and concatenates tensors from all devices across a mesh axis.

    The input is mapped with `jax.pmap` over the `axis_name` mesh axis: each
    device contributes its local shard `x`, and the shards are gathered along
    the specified tensor `axis` to form a single, larger tensor that is
    replicated on all participating devices.

    Args:
        x: The input JAX array (tensor) shard on the local device.
        axis: The tensor axis along which to concatenate the gathered
            shards.
        axis_name: The name of the mesh axis to gather from. Defaults to
            `"model"`.

    Returns:
        The full, gathered JAX array, which is identical across all devices
        participating in the gather.
    """

    def _gather_fn(y):
        return lax.all_gather(y, axis_name=axis_name, axis=axis, tiled=False)

    return jax.pmap(_gather_fn, axis_name=axis_name)(x)


def _to_backend_device(device_name):
    if isinstance(device_name, jax.Device):
        return device_name
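For reference, a minimal usage sketch (not part of the diff) of the two collectives added above. It assumes a single-process host that exposes several JAX devices and imports the module as `backend_dlib`, as the tests below do:

    import jax
    import numpy as np

    from keras.src.backend.jax import distribution_lib as backend_dlib

    # Assumption: more than one JAX device is available in this process.
    num_devices = len(jax.devices())

    # The pmap inside all_reduce/all_gather maps over the leading axis,
    # so the leading dimension must equal the number of devices.
    x = np.ones((num_devices, 2), dtype="float32")

    # Sums the per-device rows: result has shape (num_devices, 2) and
    # every element equals num_devices.
    summed = backend_dlib.all_reduce(x, op="sum", axis_name="model")

    # Each device receives the stacked shards: result has shape
    # (num_devices, num_devices, 2), with gathered[i] equal to x.
    gathered = backend_dlib.all_gather(x, axis=0, axis_name="model")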
45 changes: 45 additions & 0 deletions keras/src/backend/jax/distribution_lib_test.py
@@ -441,6 +441,51 @@ def test_distribute_data_input(self):
        for shard in result.addressable_shards:
            self.assertEqual(shard.data.shape, (3, 4))

    def test_all_reduce(self):
        devices = jax.devices()
        num_devices = len(devices)
        mesh = jax.sharding.Mesh(np.array(devices), axis_names=("batch",))
        sharding = jax.sharding.NamedSharding(
            mesh, jax.sharding.PartitionSpec("batch")
        )

        input_data = jax.device_put(
            np.ones((num_devices, 2), dtype="float32"), sharding
        )

        result_sum = backend_dlib.all_reduce(
            input_data, op="sum", axis_name="batch"
        )

        expected_sum = np.full((num_devices, 2), num_devices, dtype="float32")
        self.assertAllClose(result_sum, expected_sum)

        result_mean = backend_dlib.all_reduce(
            input_data, op="mean", axis_name="batch"
        )
        self.assertAllClose(result_mean, input_data)

    def test_all_gather(self):
        devices = jax.devices()
        num_devices = len(devices)
        mesh = jax.sharding.Mesh(np.array(devices), axis_names=("batch",))
        sharding = jax.sharding.NamedSharding(
            mesh, jax.sharding.PartitionSpec("batch")
        )

        input_data = jax.device_put(
            np.arange(num_devices, dtype="float32").reshape((num_devices, 1)),
            sharding,
        )

        results = backend_dlib.all_gather(input_data, axis=0, axis_name="batch")

        expected_gathered = np.arange(num_devices, dtype="float32").reshape(
            num_devices, 1
        )
        expected_results = np.stack([expected_gathered] * num_devices)
        self.assertAllClose(results, expected_results)


class ShardingCaptureLayer(layers.Layer):
    def __init__(self, **kwargs):
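As an aside (not stated in the diff), these collective tests only exercise the multi-device paths when more than one JAX device is visible. One common way to try them on a CPU-only machine is to ask XLA for virtual devices before JAX initializes; the device count of 8 below is an arbitrary choice for illustration:

    import os

    # Hypothetical local setup: expose 8 virtual CPU devices to JAX.
    # Must be set before jax is imported/initialized in the process.
    os.environ["XLA_FLAGS"] = (
        os.environ.get("XLA_FLAGS", "")
        + " --xla_force_host_platform_device_count=8"
    )

    import jax  # noqa: E402

    print(jax.devices())  # expect 8 CPU devices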
179 changes: 179 additions & 0 deletions keras/src/distribution/tensor_parallel/autoconfig.py
@@ -0,0 +1,179 @@
import functools

from keras.src import layers
from keras.src.backend import distribution_lib
from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap
from keras.src.distribution.tensor_parallel.tensor_layout import (
    split_tensor_for_parallelism,
)


def analyze_dense_layer(layer):
    """Classifies a Dense layer based on its input/output dimensions.

    This function uses a heuristic to determine if a Dense layer acts as an
    'up_projection' (expansion), a 'down_projection' (contraction), or a
    standard 'dense' layer. This classification is used to determine the
    appropriate sharding strategy (e.g., column-parallel vs row-parallel).

    Args:
        layer: The Keras Dense layer instance to analyze.

    Returns:
        str: One of 'up_projection', 'down_projection', or 'dense'.
    """
    input_dim = None
    output_dim = None

    kernel = getattr(layer, "kernel", getattr(layer, "_kernel", None))
    if kernel is not None:
        if len(kernel.shape) == 2:
            input_dim = kernel.shape[0]
            output_dim = kernel.shape[1]

    if output_dim is None and hasattr(layer, "units"):
        output_dim = layer.units

    if (
        input_dim is None
        and hasattr(layer, "input_shape")
        and layer.input_shape
        and len(layer.input_shape) > 1
    ):
        input_dim = layer.input_shape[-1]

    if input_dim is None or output_dim is None:
        return "dense"

    expansion_threshold = 1.5
    is_expansion = output_dim > input_dim * expansion_threshold
    is_contraction = input_dim > output_dim * expansion_threshold

    if is_expansion:
        return "up_projection"
    elif is_contraction:
        return "down_projection"
    else:
        return "dense"


def _reduce_sum(x):
    """Performs an all-reduce sum operation across the 'model' mesh axis.

    Args:
        x: The input tensor to reduce.

    Returns:
        The reduced tensor, summed across all devices in the model axis.
    """
    return distribution_lib.all_reduce(x, op="sum", axis_name="model")


def _gather(x, axis):
    """Performs an all-gather operation across the 'model' mesh axis.

    Args:
        x: The input tensor shard to gather.
        axis: The axis along which to concatenate the gathered parts.

    Returns:
        The gathered tensor, concatenated along the specified axis.
    """
    return distribution_lib.all_gather(x, axis=axis, axis_name="model")


def _apply_layer_sharding_rules(layer, device_count, state_rules, output_rules):
    """Applies sharding rules to a single layer based on its type.

    This function populates `state_rules` and `output_rules` with strategies
    specific to the layer class (e.g., Dense, EinsumDense, Embedding). It
    determines how weights should be partitioned (state rules) and how outputs
    should be synchronized (output rules).

    Args:
        layer: The Keras layer instance to configure.
        device_count: The number of devices available for tensor parallelism.
        state_rules: A dictionary mapping variable ids to sharding functions.
            Updated in-place.
        output_rules: A dictionary mapping layer paths to output communication
            functions. Updated in-place.
    """

    def split_rule(dim):
        return functools.partial(
            split_tensor_for_parallelism, device_count=device_count, dim=dim
        )

    def gather_rule(axis):
        return functools.partial(_gather, axis=axis)

    layer_path = layer.path

    if isinstance(layer, layers.Dense):
        mlp_type = analyze_dense_layer(layer)

        if mlp_type == "up_projection":
            state_rules[id(layer.kernel)] = split_rule(dim=1)
            if layer.use_bias:
                state_rules[id(layer.bias)] = split_rule(dim=0)
            output_rules[layer_path] = gather_rule(axis=-1)

        elif mlp_type == "down_projection":
            state_rules[id(layer.kernel)] = split_rule(dim=0)
            output_rules[layer_path] = _reduce_sum

        else:
            state_rules[id(layer.kernel)] = split_rule(dim=1)
            if layer.use_bias:
                state_rules[id(layer.bias)] = split_rule(dim=0)
            output_rules[layer_path] = gather_rule(axis=-1)

    elif isinstance(layer, layers.EinsumDense):
        if "attention_output" in layer.name:
            state_rules[id(layer.kernel)] = split_rule(dim=0)
            output_rules[layer_path] = _reduce_sum
        else:
            state_rules[id(layer.kernel)] = split_rule(dim=1)
            if hasattr(layer, "bias") and layer.bias is not None:
                state_rules[id(layer.bias)] = split_rule(dim=0)
            output_rules[layer_path] = gather_rule(axis=-1)

    elif (
        isinstance(layer, (layers.Embedding,))
        or "Embedding" in layer.__class__.__name__
    ):
        embeddings_var = getattr(layer, "embeddings", None)
        if embeddings_var is not None:
            state_rules[id(embeddings_var)] = split_rule(dim=1)
            output_rules[layer_path] = lambda x: x


def get_default_config(model, device_ids):
    """Generates a default tensor parallelism configuration for a model.

    This function traverses the model's layer hierarchy and
    automatically generates a `LayoutMap`. This map contains:
    1. `state_rules`: How to shard the weights of supported layers
       across the specified devices.
    2. `output_rules`: How to synchronize or gather the outputs of
       these layers during the forward pass.

    Args:
        model: The Keras model to configure.
        device_ids: A list of device identifiers to be used
            for distribution.

    Returns:
        LayoutMap: A configuration object containing `state_rules` and
            `output_rules` for tensor parallelism.
    """
    device_count = len(device_ids)
    state_rules = {}
    output_rules = {}

    for layer in model._flatten_layers(recursive=True, include_self=True):
        _apply_layer_sharding_rules(
            layer, device_count, state_rules, output_rules
        )

    return LayoutMap(state_rules=state_rules, output_rules=output_rules)
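A minimal sketch of how this helper could be invoked (not part of the diff). `get_default_config` and `LayoutMap` come from the modules in this PR; the toy model, layer names, and device ids below are made up for illustration:

    import keras

    from keras.src.distribution.tensor_parallel.autoconfig import (
        get_default_config,
    )

    # Hypothetical toy MLP: 128 -> 512 should be classified as
    # "up_projection" and 512 -> 128 as "down_projection" by
    # analyze_dense_layer (ratio > 1.5).
    model = keras.Sequential(
        [
            keras.Input(shape=(128,)),
            keras.layers.Dense(512, activation="relu", name="up_proj"),
            keras.layers.Dense(128, name="down_proj"),
        ]
    )

    device_ids = ["gpu:0", "gpu:1", "gpu:2", "gpu:3"]  # hypothetical
    layout_map = get_default_config(model, device_ids)

    # state_rules is keyed by id(variable), output_rules by layer.path.
    print(len(layout_map.state_rules), len(layout_map.output_rules))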