180 changes: 180 additions & 0 deletions keras/src/distribution/distribution_lib.py
@@ -896,3 +896,183 @@ def set_distribution(value):
value: a `Distribution` instance.
"""
global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, value)


@keras_export("keras.distribution.AutoTPDistribution")
class AutoTPDistribution(Distribution):
"""A distribution strategy for automated tensor and data parallelism.

This distribution strategy provides a high-level abstraction for combining
both data parallelism and tensor parallelism. It automatically shards a Keras
model's layers across multiple devices (tensor parallelism) while also
distributing the input data across those devices (data parallelism).

It uses a `DeviceMesh` to represent the grid of computational devices. If no
mesh is provided, it creates one using all available devices. The mesh must
have a 'data' axis for data sharding and a 'model' axis for model sharding.

Internally, this class wraps the user-provided Keras `Model` with the
`TensorParallelKeras` utility to handle the model sharding.

Args:
model: A `keras.Model` instance to be distributed.
device_mesh: (Optional) A `keras.distribution.DeviceMesh` instance.
If not provided, a `DeviceMesh` will be automatically created using
all available devices, arranging them for both data and model
parallelism.
auto_shard_dataset: (Optional) A boolean indicating whether to
automatically shard `tf.data.Dataset` instances across multiple
processes. Defaults to `True`.

Attributes:
model: The wrapped, tensor-parallel `keras.Model` instance that is
ready for distributed training.
device_mesh: The `DeviceMesh` instance used for distribution.

Raises:
RuntimeError: If no computational devices are found and `device_mesh`
is not provided.
ValueError: If the provided `device_mesh` does not have a 'data' axis.

Example:

```python
import numpy as np
import keras

# Create a simple Keras model
inputs = keras.Input(shape=(64,))
x = keras.layers.Dense(128, activation="relu")(inputs)
outputs = keras.layers.Dense(10)(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Create the distribution strategy with the model
# It will automatically use all available GPUs/TPUs
distribution = keras.distribution.AutoTPDistribution(model)

# The distributed model is accessed via the .model attribute
distributed_model = distribution.model

# Compile the model as usual
distributed_model.compile(optimizer="adam", loss="mse")

# Prepare a dataset
input_data = np.random.rand(32, 64)
target_data = np.random.rand(32, 10)

# Train the model
distributed_model.fit(input_data, target_data)
```
"""

def __init__(self, model, device_mesh=None, auto_shard_dataset=True):
if device_mesh is None:
all_devices = list_devices()
if not all_devices:
raise RuntimeError("No computational devices found.")
device_mesh = DeviceMesh(
shape=(1, len(all_devices)),
axis_names=("data", "model"),
devices=all_devices,
)

if "data" not in device_mesh.axis_names:
raise ValueError(
"DeviceMesh for AutoTPDistribution must have a 'data' axis."
)
batch_dim_name = "data"

super().__init__(device_mesh, batch_dim_name, auto_shard_dataset)

self._original_model = model
self._num_process = distribution_lib.num_processes()
self._process_id = distribution_lib.process_id()
self._is_multi_process = self._num_process > 1
from keras.src.distribution.tensor_parallel.tensor_parallel import (
TensorParallelKeras,
)

self.model = TensorParallelKeras(
model=self._original_model,
world_size=np.prod(self.device_mesh.shape),
device_ids=self.device_mesh.devices.flatten().tolist(),
)

def get_data_layout(self, data_shape):
data_shard_spec = [None] * len(data_shape)
data_shard_spec[0] = self.batch_dim_name
return TensorLayout(data_shard_spec, self.device_mesh)

def get_variable_layout(self, variable):
warnings.warn(
"Variable layout is determined automatically within "
"AutoTPDistribution. This method will return a replicated layout."
)
return TensorLayout([None] * len(variable.shape), self.device_mesh)
Comment on lines +1004 to +1008
Collaborator

This is your injection point. This is where you should return the layout for each variable from your LayoutMap.

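A minimal sketch of that injection, assuming the distribution keeps a `keras.distribution.LayoutMap` built from the tensor-parallel sharding plan (the `self._layout_map` attribute here is hypothetical and not part of this PR):

```python
def get_variable_layout(self, variable):
    # Hypothetical attribute: a `LayoutMap` keyed by variable-path regexes,
    # populated when the model is wrapped by `TensorParallelKeras`.
    layout = self._layout_map[variable.path]
    if layout is not None:
        return layout
    # Variables with no matching rule fall back to full replication.
    return TensorLayout([None] * len(variable.shape), self.device_mesh)
```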

def get_tensor_layout(self, path):
return None
Comment on lines +1010 to +1011
Collaborator

This is your injection point. This is where you should return the output layout from your LayoutMap.state_rules.

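Similarly, a minimal sketch for the intermediate-tensor lookup, assuming the sharding plan exposes a dict of output rules keyed by tensor path (`self._output_rules` is a placeholder name, not part of this PR):

```python
def get_tensor_layout(self, path):
    # Hypothetical attribute: a dict mapping intermediate-tensor paths to
    # `TensorLayout`s derived from the tensor-parallel output rules.
    return self._output_rules.get(path, None)
```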

def distribute_dataset(self, dataset):
"""Distributes the dataset across processes based on the device mesh."""
if not self._is_multi_process or not self.auto_shard_dataset:
return dataset

from keras.src.utils.module_utils import tensorflow as tf

if not tf.available or not isinstance(dataset, tf.data.Dataset):
raise ValueError(
"Only `tf.data.Dataset` is supported for auto-sharding, "
f"got {type(dataset)}"
)

from tensorflow.python.data.experimental.ops import (
distribute as tf_data_distribute,
)

global_batch_size = tf_data_distribute.compute_batch_size(dataset)
if global_batch_size.numpy() < 0:
raise ValueError(
"The batch size of the input dataset is unknown. "
"Please configure the batch size for the input dataset, "
"e.g., via `dataset.batch(batch_size)`"
)

mesh_batch_dim_index = self.device_mesh.axis_names.index(
self.batch_dim_name
)
num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index]

if num_model_replicas == 1:
return dataset.prefetch(tf.data.AUTOTUNE)

num_model_replicas_per_process = num_model_replicas / self._num_process
if num_model_replicas_per_process >= 1:
if global_batch_size % self._num_process != 0:
raise ValueError(
"Global batch size must be divisible by the number of "
f"processes. `global_batch_size`={global_batch_size} and "
f"`num_process`={self._num_process}"
)
per_process_batch_size = global_batch_size // self._num_process
distributed_dataset = dataset.rebatch(per_process_batch_size)
distributed_dataset = distributed_dataset.shard(
num_shards=self._num_process,
index=self._process_id,
)
return distributed_dataset.prefetch(tf.data.AUTOTUNE)
else:
if global_batch_size % num_model_replicas != 0:
raise ValueError(
"Global batch size must be divisible by the number of "
f"replicas. `global_batch_size`={global_batch_size} and "
f"`num_model_replicas`={num_model_replicas}"
)
per_replica_batch_size = global_batch_size // num_model_replicas
distributed_dataset = dataset.rebatch(per_replica_batch_size)

processes_per_replica = self._num_process // num_model_replicas
data_shard_id = self._process_id // processes_per_replica

distributed_dataset = distributed_dataset.shard(
num_shards=num_model_replicas,
index=data_shard_id,
)
return distributed_dataset.prefetch(tf.data.AUTOTUNE)
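For reference, a quick worked example of the sharding arithmetic in the multi-process branch above, using hypothetical numbers (a 'data' axis of size 2 spread across 8 processes, one device per process):

```python
# Hypothetical setup: 2 model replicas spread across 8 processes.
num_model_replicas = 2
num_processes = 8
global_batch_size = 32

# 2 / 8 = 0.25 < 1, so each replica spans several processes.
per_replica_batch_size = global_batch_size // num_model_replicas  # 16
processes_per_replica = num_processes // num_model_replicas       # 4

for process_id in range(num_processes):
    data_shard_id = process_id // processes_per_replica
    # Processes 0-3 read data shard 0, processes 4-7 read data shard 1,
    # so every process hosting one replica sees that replica's 16-sample batch.
    print(f"process {process_id} -> shard {data_shard_id}")
```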
124 changes: 124 additions & 0 deletions keras/src/distribution/distribution_lib_test.py
@@ -11,6 +11,15 @@
from keras.src import testing
from keras.src.backend import distribution_lib as backend_dlib
from keras.src.distribution import distribution_lib
from keras.src.distribution.distribution_lib import AutoTPDistribution
from keras.src.distribution.tensor_parallel.tensor_parallel import (
TensorParallelKeras,
)
from keras.src.layers import Dense
from keras.src.layers import Input
from keras.src.losses import SparseCategoricalCrossentropy
from keras.src.models import Model
from keras.src.optimizers import Adam


@pytest.mark.skipif(
@@ -535,3 +544,118 @@ def test_iter(self):
# ValueError, "Cannot create sharding when device mesh is not set"
# ):
# backend_dlib._to_dtensor_layout(layout)


class AutoTPDistributionTest(testing.TestCase):
def setUp(self):
super().setUp()
self.devices = distribution_lib.list_devices()
if len(self.devices) < 2:
self.skipTest("This test requires at least 2 devices.")
inputs = Input(shape=(4,), name="input_layer")
x = Dense(8, name="dense_1")(inputs)
outputs = Dense(2, name="dense_2")(x)
self.model = Model(inputs, outputs)

def test_init_with_explicit_device_mesh(self):
"""Tests initialization with a user-provided DeviceMesh."""
device_mesh = distribution_lib.DeviceMesh(
shape=(1, 2), axis_names=["data", "model"], devices=self.devices
)
distribution = AutoTPDistribution(self.model, device_mesh=device_mesh)

self.assertIs(distribution.device_mesh, device_mesh)
self.assertEqual(distribution.batch_dim_name, "data")
self.assertIsInstance(
distribution.model,
TensorParallelKeras,
)
self.assertEqual(distribution.model.world_size, 2)

@mock.patch.object(
distribution_lib,
"list_devices",
return_value=[f"cpu:{i}" for i in range(2)],
)
def test_init_without_device_mesh_for_auto_creation(
self, mock_list_devices
):
"""Tests the automatic creation of DeviceMesh when none is provided."""
distribution = AutoTPDistribution(self.model, device_mesh=None)
mock_list_devices.assert_called_once()

device_mesh = distribution.device_mesh
self.assertEqual(device_mesh.shape, (1, 2))
self.assertEqual(device_mesh.axis_names, ("data", "model"))
self.assertEqual(distribution.batch_dim_name, "data")
self.assertEqual(distribution.model.world_size, 2)

def test_init_raises_error_on_missing_data_axis(self):
"""Ensures an error is raised if the DeviceMesh lacks a 'data' axis."""
device_mesh = distribution_lib.DeviceMesh(
shape=(2,), axis_names=["model"], devices=self.devices
)
with self.assertRaisesRegex(ValueError, "must have a 'data' axis"):
AutoTPDistribution(self.model, device_mesh=device_mesh)

def test_get_data_layout(self):
"""Verifies the layout for input data sharding."""
distribution = AutoTPDistribution(self.model)
data_shape = (16, 4)
layout = distribution.get_data_layout(data_shape)

self.assertEqual(layout.axes, ("data", None))
self.assertIs(layout.device_mesh, distribution.device_mesh)

def test_get_variable_layout_warns_and_returns_replicated(self):
"""Verifies that variable layout is handled internally."""
distribution = AutoTPDistribution(self.model)
dummy_variable = backend.Variable(initializer=np.zeros((8, 2)))

with self.assertWarns(UserWarning) as w:
layout = distribution.get_variable_layout(dummy_variable)

self.assertIn(
"Variable layout is determined automatically",
str(w.warnings[0].message),
)

self.assertEqual(layout.axes, (None, None))

def test_distribute_dataset_in_single_process_mode(self):
"""Tests dataset distribution in a single-process environment."""
distribution = AutoTPDistribution(self.model)
dataset = tf.data.Dataset.from_tensor_slices(
(np.zeros((16, 4)), np.zeros((16, 1)))
)

distributed_dataset = distribution.distribute_dataset(dataset)
self.assertIs(dataset, distributed_dataset)

def test_full_compile_and_fit_integration(self):
"""A test to ensure the distributed model can compile and train."""
distribution = AutoTPDistribution(self.model)

x_train = np.random.rand(16, 4).astype("float32")
y_train = np.random.randint(0, 2, size=(16, 1))

dist_model = distribution.model

with distribution.scope():
dist_model.compile(
optimizer=Adam(0.01),
loss=SparseCategoricalCrossentropy(from_logits=True),
metrics=["accuracy"],
)

self.assertEqual(self.model.count_params(), dist_model.count_params())

history = dist_model.fit(
x_train,
y_train,
epochs=1,
batch_size=4,
verbose=0,
)
self.assertIn("loss", history.history)
self.assertIn("accuracy", history.history)