diff --git a/keras/src/distribution/distribution_lib.py b/keras/src/distribution/distribution_lib.py
index 2daef40a2ed8..bc4936ce1a36 100644
--- a/keras/src/distribution/distribution_lib.py
+++ b/keras/src/distribution/distribution_lib.py
@@ -896,3 +896,183 @@ def set_distribution(value):
         value: a `Distribution` instance.
     """
     global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, value)
+
+
+@keras_export("keras.distribution.AutoTPDistribution")
+class AutoTPDistribution(Distribution):
+    """A distribution strategy for automated tensor and data parallelism.
+
+    This distribution strategy provides a high-level abstraction for
+    combining data parallelism and tensor parallelism. It automatically
+    shards a Keras model's layers across multiple devices (tensor
+    parallelism) while also distributing the input data across those
+    devices (data parallelism).
+
+    It uses a `DeviceMesh` to represent the grid of computational devices.
+    If no mesh is provided, it creates one using all available devices. The
+    mesh must have a 'data' axis for data sharding and a 'model' axis for
+    model sharding.
+
+    Internally, this class wraps the user-provided Keras `Model` with the
+    `TensorParallelKeras` utility to handle the model sharding.
+
+    Args:
+        model: A `keras.Model` instance to be distributed.
+        device_mesh: (Optional) A `keras.distribution.DeviceMesh` instance.
+            If not provided, a `DeviceMesh` will be automatically created
+            using all available devices, arranging them for both data and
+            model parallelism.
+        auto_shard_dataset: (Optional) A boolean indicating whether to
+            automatically shard `tf.data.Dataset` instances across multiple
+            processes. Defaults to `True`.
+
+    Attributes:
+        model: The wrapped, tensor-parallel `keras.Model` instance that is
+            ready for distributed training.
+        device_mesh: The `DeviceMesh` instance used for distribution.
+
+    Raises:
+        RuntimeError: If no computational devices are found and
+            `device_mesh` is not provided.
+        ValueError: If the provided `device_mesh` does not have a 'data'
+            axis.
+
+    Example:
+
+    ```python
+    import numpy as np
+
+    # Create a simple Keras model.
+    inputs = keras.Input(shape=(64,))
+    x = keras.layers.Dense(128, activation="relu")(inputs)
+    outputs = keras.layers.Dense(10)(x)
+    model = keras.Model(inputs=inputs, outputs=outputs)
+
+    # Create the distribution strategy with the model. It will
+    # automatically use all available GPUs/TPUs.
+    distribution = keras.distribution.AutoTPDistribution(model)
+
+    # The distributed model is accessed via the `.model` attribute.
+    distributed_model = distribution.model
+
+    # Compile the model as usual.
+    distributed_model.compile(optimizer="adam", loss="mse")
+
+    # Prepare a dataset.
+    input_data = np.random.rand(32, 64)
+    target_data = np.random.rand(32, 10)
+
+    # Train the model.
+    distributed_model.fit(input_data, target_data)
+    ```
+    """
+
+    def __init__(self, model, device_mesh=None, auto_shard_dataset=True):
+        if device_mesh is None:
+            all_devices = list_devices()
+            if not all_devices:
+                raise RuntimeError("No computational devices found.")
+            device_mesh = DeviceMesh(
+                shape=(1, len(all_devices)),
+                axis_names=("data", "model"),
+                devices=all_devices,
+            )
+
+        if "data" not in device_mesh.axis_names:
+            raise ValueError(
+                "DeviceMesh for AutoTPDistribution must have a 'data' axis."
+            )
+        batch_dim_name = "data"
+
+        super().__init__(device_mesh, batch_dim_name, auto_shard_dataset)
+
+        self._original_model = model
+        self._num_process = distribution_lib.num_processes()
+        self._process_id = distribution_lib.process_id()
+        self._is_multi_process = self._num_process > 1
+        from keras.src.distribution.tensor_parallel.tensor_parallel import (
+            TensorParallelKeras,
+        )
+
+        self.model = TensorParallelKeras(
+            model=self._original_model,
+            device_count=np.prod(self.device_mesh.shape),
+            device_ids=self.device_mesh.devices.flatten().tolist(),
+        )
+
+    def get_data_layout(self, data_shape):
+        data_shard_spec = [None] * len(data_shape)
+        data_shard_spec[0] = self.batch_dim_name
+        return TensorLayout(data_shard_spec, self.device_mesh)
+
+    def get_variable_layout(self, variable):
+        warnings.warn(
+            "Variable layout is determined automatically within "
+            "AutoTPDistribution. This method will return a replicated layout."
+        )
+        return TensorLayout([None] * len(variable.shape), self.device_mesh)
+
+    def get_tensor_layout(self, path):
+        return None
+
+    def distribute_dataset(self, dataset):
+        """Distributes the dataset across processes via the device mesh."""
+        if not self._is_multi_process or not self.auto_shard_dataset:
+            return dataset
+
+        from keras.src.utils.module_utils import tensorflow as tf
+
+        if not tf.available or not isinstance(dataset, tf.data.Dataset):
+            raise ValueError(
+                "Only `tf.data.Dataset` is supported for auto-sharding, "
+                f"got {type(dataset)}"
+            )
+
+        from tensorflow.python.data.experimental.ops import (
+            distribute as tf_data_distribute,
+        )
+
+        global_batch_size = tf_data_distribute.compute_batch_size(dataset)
+        if global_batch_size.numpy() < 0:
+            raise ValueError(
+                "The batch size of the input dataset is unknown. "
+                "Please configure the batch size for the input dataset, "
+                "e.g., via `dataset.batch(batch_size)`."
+            )
+
+        mesh_batch_dim_index = self.device_mesh.axis_names.index(
+            self.batch_dim_name
+        )
+        num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index]
+
+        if num_model_replicas == 1:
+            return dataset.prefetch(tf.data.AUTOTUNE)
+
+        num_model_replicas_per_process = (
+            num_model_replicas / self._num_process
+        )
+        if num_model_replicas_per_process >= 1:
+            if global_batch_size % self._num_process != 0:
+                raise ValueError(
+                    "Global batch size must be divisible by the number of "
+                    f"processes. `global_batch_size`={global_batch_size} and "
+                    f"`num_process`={self._num_process}."
+                )
+            per_process_batch_size = global_batch_size // self._num_process
+            distributed_dataset = dataset.rebatch(per_process_batch_size)
+            distributed_dataset = distributed_dataset.shard(
+                num_shards=self._num_process,
+                index=self._process_id,
+            )
+            return distributed_dataset.prefetch(tf.data.AUTOTUNE)
+        else:
+            if global_batch_size % num_model_replicas != 0:
+                raise ValueError(
+                    "Global batch size must be divisible by the number of "
+                    f"replicas. `global_batch_size`={global_batch_size} and "
+                    f"`num_model_replicas`={num_model_replicas}."
+                )
+            per_replica_batch_size = global_batch_size // num_model_replicas
+            distributed_dataset = dataset.rebatch(per_replica_batch_size)
+
+            processes_per_replica = self._num_process // num_model_replicas
+            data_shard_id = self._process_id // processes_per_replica
+
+            distributed_dataset = distributed_dataset.shard(
+                num_shards=num_model_replicas,
+                index=data_shard_id,
+            )
+            return distributed_dataset.prefetch(tf.data.AUTOTUNE)
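
Reviewer note: the multi-process path of `distribute_dataset` above reduces to a `rebatch` followed by a `shard`. Below is a minimal self-contained sketch of that arithmetic for the replicas-per-process branch; it assumes TF 2.13+, where `tf.data.Dataset.rebatch` is public API, and the variable names are illustrative rather than part of this diff.

```python
import numpy as np
import tensorflow as tf

# Two processes share a global batch of 16: each rebatches to a
# per-process batch of 8 and keeps every second batch, mirroring the
# `rebatch(...)` + `shard(...)` calls in `distribute_dataset`.
num_processes = 2
process_id = 0  # in a real run, each process uses its own id

dataset = tf.data.Dataset.from_tensor_slices(np.arange(64)).batch(16)
per_process_batch_size = 16 // num_processes
local_dataset = dataset.rebatch(per_process_batch_size).shard(
    num_shards=num_processes, index=process_id
)
for batch in local_dataset:
    print(batch.numpy())  # process 0 sees four batches of 8 elements
```
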
diff --git a/keras/src/distribution/distribution_lib_test.py b/keras/src/distribution/distribution_lib_test.py
index 66f996b3fb68..142d5d8a307c 100644
--- a/keras/src/distribution/distribution_lib_test.py
+++ b/keras/src/distribution/distribution_lib_test.py
@@ -11,6 +11,15 @@
 from keras.src import testing
 from keras.src.backend import distribution_lib as backend_dlib
 from keras.src.distribution import distribution_lib
+from keras.src.distribution.distribution_lib import AutoTPDistribution
+from keras.src.distribution.tensor_parallel.tensor_parallel import (
+    TensorParallelKeras,
+)
+from keras.src.layers import Dense
+from keras.src.layers import Input
+from keras.src.losses import SparseCategoricalCrossentropy
+from keras.src.models import Model
+from keras.src.optimizers import Adam
 
 
 @pytest.mark.skipif(
@@ -535,3 +544,118 @@ def test_iter(self):
         #     ValueError, "Cannot create sharding when device mesh is not set"
         # ):
         #     backend_dlib._to_dtensor_layout(layout)
+
+
+class AutoTPDistributionTest(testing.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.devices = distribution_lib.list_devices()
+        if len(self.devices) < 2:
+            self.skipTest("This test requires at least 2 devices.")
+        inputs = Input(shape=(4,), name="input_layer")
+        x = Dense(8, name="dense_1")(inputs)
+        outputs = Dense(2, name="dense_2")(x)
+        self.model = Model(inputs, outputs)
+
+    def test_init_with_explicit_device_mesh(self):
+        """Tests initialization with a user-provided DeviceMesh."""
+        device_mesh = distribution_lib.DeviceMesh(
+            shape=(1, 2),
+            axis_names=["data", "model"],
+            devices=self.devices[:2],
+        )
+        distribution = AutoTPDistribution(self.model, device_mesh=device_mesh)
+
+        self.assertIs(distribution.device_mesh, device_mesh)
+        self.assertEqual(distribution.batch_dim_name, "data")
+        self.assertIsInstance(
+            distribution.model,
+            TensorParallelKeras,
+        )
+        self.assertEqual(distribution.model.device_count, 2)
+
+    @mock.patch.object(
+        distribution_lib,
+        "list_devices",
+        return_value=[f"cpu:{i}" for i in range(2)],
+    )
+    def test_init_without_device_mesh_for_auto_creation(
+        self, mock_list_devices
+    ):
+        """Tests the automatic creation of DeviceMesh when none is provided."""
+        distribution = AutoTPDistribution(self.model, device_mesh=None)
+        mock_list_devices.assert_called_once()
+
+        device_mesh = distribution.device_mesh
+        self.assertEqual(device_mesh.shape, (1, 2))
+        self.assertEqual(device_mesh.axis_names, ("data", "model"))
+        self.assertEqual(distribution.batch_dim_name, "data")
+        self.assertEqual(distribution.model.device_count, 2)
+
+    def test_init_raises_error_on_missing_data_axis(self):
+        """Ensures an error is raised if the DeviceMesh lacks a 'data' axis."""
+        device_mesh = distribution_lib.DeviceMesh(
+            shape=(2,), axis_names=["model"], devices=self.devices[:2]
+        )
+        with self.assertRaisesRegex(ValueError, "must have a 'data' axis"):
+            AutoTPDistribution(self.model, device_mesh=device_mesh)
+
+    def test_get_data_layout(self):
+        """Verifies the layout for input data sharding."""
+        distribution = AutoTPDistribution(self.model)
+        data_shape = (16, 4)
+        layout = distribution.get_data_layout(data_shape)
+
+        self.assertEqual(layout.axes, ("data", None))
+        self.assertIs(layout.device_mesh, distribution.device_mesh)
+
+    def test_get_variable_layout_warns_and_returns_replicated(self):
+        """Verifies that variable layout is handled internally."""
+        distribution = AutoTPDistribution(self.model)
+        dummy_variable = backend.Variable(initializer=np.zeros((8, 2)))
+
+        with self.assertWarns(UserWarning) as w:
+            layout = distribution.get_variable_layout(dummy_variable)
+
+        self.assertIn(
+            "Variable layout is determined automatically",
+            str(w.warnings[0].message),
+        )
+
+        self.assertEqual(layout.axes, (None, None))
+
+    def test_distribute_dataset_in_single_process_mode(self):
+        """Tests dataset distribution in a single-process environment."""
+        distribution = AutoTPDistribution(self.model)
+        dataset = tf.data.Dataset.from_tensor_slices(
+            (np.zeros((16, 4)), np.zeros((16, 1)))
+        )
+
+        distributed_dataset = distribution.distribute_dataset(dataset)
+        self.assertIs(dataset, distributed_dataset)
+
+    def test_full_compile_and_fit_integration(self):
+        """A test to ensure the distributed model can compile and train."""
+        distribution = AutoTPDistribution(self.model)
+
+        x_train = np.random.rand(16, 4).astype("float32")
+        y_train = np.random.randint(0, 2, size=(16, 1))
+
+        dist_model = distribution.model
+
+        with distribution.scope():
+            dist_model.compile(
+                optimizer=Adam(0.01),
+                loss=SparseCategoricalCrossentropy(from_logits=True),
+                metrics=["accuracy"],
+            )
+
+        self.assertEqual(self.model.count_params(), dist_model.count_params())
+
+        history = dist_model.fit(
+            x_train,
+            y_train,
+            epochs=1,
+            batch_size=4,
+            verbose=0,
+        )
+        self.assertIn("loss", history.history)
+        self.assertIn("accuracy", history.history)
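
Reviewer note: the assembly logic in `build_assembled_model` below relies on two identities from tensor parallelism: column-sharded Dense outputs concatenate back to the full output, and row-sharded partial outputs sum back to it, minus duplicated bias copies. A NumPy sketch of both (illustrative only, not part of the diff):

```python
import numpy as np

x = np.random.rand(4, 8)
w = np.random.rand(8, 6)
b = np.random.rand(6)

# Column parallelism: split the kernel along the output axis; the
# concatenated partial outputs equal the full Dense output.
w_col0, w_col1 = np.split(w, 2, axis=1)
assert np.allclose(
    np.concatenate([x @ w_col0, x @ w_col1], axis=-1), x @ w
)

# Row parallelism: split the kernel along the input axis and the input
# along its feature axis, then sum the partial outputs. If each of the
# n shards also adds the full bias, (n - 1) copies must be subtracted,
# which is the correction applied by the Lambda layer in
# `build_assembled_model`.
x0, x1 = np.split(x, 2, axis=1)
w_row0, w_row1 = np.split(w, 2, axis=0)
summed = (x0 @ w_row0 + b) + (x1 @ w_row1 + b)
assert np.allclose(summed - b * (2 - 1), x @ w + b)
```
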
diff --git a/keras/src/distribution/tensor_parallel/tensor_parallel.py b/keras/src/distribution/tensor_parallel/tensor_parallel.py
new file mode 100644
index 000000000000..2251a699c723
--- /dev/null
+++ b/keras/src/distribution/tensor_parallel/tensor_parallel.py
@@ -0,0 +1,349 @@
+"""
+Tensor Parallel implementation for Keras 3.0.
+Port of the PyTorch tensor_parallel library.
+"""
+
+import re
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import numpy as np
+
+from keras.src import ops
+from keras.src.distribution import list_devices
+from keras.src.distribution.tensor_parallel.autoconfig import (
+    get_default_config,
+)
+from keras.src.distribution.tensor_parallel.parameter_sharding import (
+    make_parameter_sharded_model,
+)
+from keras.src.layers import Add
+from keras.src.layers import Input
+from keras.src.layers import Lambda
+from keras.src.models import Model
+
+
+class TensorParallelKeras(Model):
+    def __init__(
+        self,
+        model,
+        device_count=None,
+        device_ids=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self._original_model = model
+
+        if device_count is None:
+            device_count, device_ids = self._auto_detect_parallelism()
+        elif device_ids is None:
+            device_ids = self._auto_configure_devices(device_count)
+
+        self.device_count = device_count
+        self.device_ids = device_ids
+        self.sharding_strategy = "auto"
+
+        self.tensor_parallel_config = None
+        self.distributed = True
+
+        self.sharded_models = [self._original_model]
+
+        accel_devices = list_devices()
+        device_ids = list(self.check_device_ids(device_ids))
+
+        if accel_devices:
+            if len(accel_devices) >= device_count:
+                device_ids = accel_devices[:device_count]
+            else:
+                # Fewer accelerators than requested: shrink to what exists.
+                device_count = len(accel_devices)
+                device_ids = accel_devices[:device_count]
+
+        if not device_ids:
+            device_ids = self._auto_configure_devices(device_count)
+
+        if len(device_ids) != device_count:
+            device_ids = self._adjust_device_list(device_ids, device_count)
+
+        self.devices = device_ids
+        self.device_count = device_count
+
+        if self.device_count <= 1:
+            self.model_shards = [model]
+            self.distributed = False
+            if len(self.devices) == 1:
+                from keras import device
+
+                with device(self.devices[0]):
+                    self.model_shards[0] = model
+
+            self.assembled_model = self._original_model
+
+            if hasattr(self._original_model, "inputs"):
+                self._inputs = self._original_model.inputs
+                self._outputs = self._original_model.outputs
+
+            self.built = True
+            return
+
+        if self.tensor_parallel_config is None:
+            device_names = [str(d) for d in self.devices]
+            self.tensor_parallel_config = get_default_config(
+                model, device_names
+            )
+
+        self._is_multi_layer_model = len(model.layers) > 2
+
+        self.model_shards = []
+        self.modified_parameters_names = set()
+
+        for rank, device_id in enumerate(self.devices):
+            shard, modified_parameters_names = make_parameter_sharded_model(
+                model,
+                self.tensor_parallel_config,
+                rank=rank,
+                device_count=self.device_count,
+                device_id=device_id,
+            )
+            self.model_shards.append(shard)
+            self.modified_parameters_names.update(modified_parameters_names)
+
+        params_per_shard = []
+        for shard in self.model_shards:
+            total_params = sum(np.prod(p.shape) for p in shard.weights)
+            params_per_shard.append(int(total_params))
+
+        self.built = True
+        if self.distributed:
+            self.assembled_model = self.build_assembled_model()
+            self._inputs = self.assembled_model.inputs
+            self._outputs = self.assembled_model.outputs
+        else:
+            self.assembled_model = self._original_model
+
+    @property
+    def variables(self):
+        """Returns a list of all unique variables from all model shards."""
+        unique_vars = {
+            id(var): var
+            for shard in self.model_shards
+            for var in shard.variables
+        }
+        return list(unique_vars.values())
+
+    @property
+    def trainable_variables(self):
+        """Returns list of all unique trainable variables from model shards."""
+        unique_vars = {
+            id(var): var
+            for shard in self.model_shards
+            for var in shard.trainable_variables
+        }
+        return list(unique_vars.values())
+
+    @property
+    def non_trainable_variables(self):
+        """Returns list of unique non-trainable variables from model shards."""
+        unique_vars = {
+            id(var): var
+            for shard in self.model_shards
+            for var in shard.non_trainable_variables
+        }
+        return list(unique_vars.values())
+
+    @property
+    def weights(self):
+        """Returns a list of all unique weights from all model shards."""
+        unique_vars = {
+            id(var): var for shard in self.model_shards for var in shard.weights
+        }
+        return list(unique_vars.values())
+
+    @property
+    def trainable_weights(self):
+        """Returns a list of all unique trainable weights from model shards."""
+        unique_vars = {
+            id(var): var
+            for shard in self.model_shards
+            for var in shard.trainable_weights
+        }
+        return list(unique_vars.values())
+
+    @property
+    def non_trainable_weights(self):
+        """Returns list of unique non-trainable weights from model shards."""
+        unique_vars = {
+            id(var): var
+            for shard in self.model_shards
+            for var in shard.non_trainable_weights
+        }
+        return list(unique_vars.values())
+
+    def _auto_detect_parallelism(self):
+        """Auto-detect `device_count` and `device_ids` efficiently."""
+        from keras.src.distribution import get_best_devices
+
+        available_devices = list_devices()
+        device_count = len(available_devices)
+
+        device_ids = get_best_devices(device_count)
+
+        return device_count, device_ids
+
+    def _adjust_device_list(self, device_ids, target_device_count):
+        """Adjust the device list to match `target_device_count`."""
+        current_size = len(device_ids)
+        if current_size >= target_device_count:
+            return device_ids[:target_device_count]
+
+        # Pad with CPU devices when fewer ids were given than requested.
+        return list(device_ids) + [
+            f"cpu:{i}" for i in range(current_size, target_device_count)
+        ]
+
+    def _auto_configure_devices(self, device_count):
+        """Auto-configure devices; simplified version."""
+        available_devices = list_devices()
+        if available_devices:
+            devices = available_devices[:device_count]
+            return devices
+        else:
+            return ["cpu:0"]
+
+    def check_device_ids(
+        self, device_ids: Optional[Sequence[str]]
+    ) -> Sequence[str]:
+        """Validate and normalize device IDs for Keras."""
+        if device_ids is None:
+            device_ids = self._get_all_device_indices()
+
+        return tuple(self.canonicalize_device(d) for d in device_ids)
+
+    def _get_all_device_indices(self) -> Sequence[str]:
+        """Get all available device indices using the distribution library."""
+        return list_devices()
+
+    def build_assembled_model(self):
+        """
+        Builds a single, JIT-friendly Keras Functional model that
+        encapsulates the entire tensor parallel logic, correctly handling
+        multiple inputs.
+        """
+        if not self.distributed:
+            return self._original_model
+
+        input_layers = {
+            inp.name.split(":")[0]: Input(
+                shape=inp.shape[1:],
+                dtype=inp.dtype,
+                name=inp.name.split(":")[0],
+            )
+            for inp in self._original_model.inputs
+        }
+
+        partial_outputs = []
+        for shard in self.model_shards:
+            shard_inputs = {}
+            input_names = getattr(shard, "input_names", None)
+            if input_names:
+                for name in input_names:
+                    clean_name = name.split(":")[0]
+                    if clean_name in input_layers:
+                        shard_inputs[clean_name] = input_layers[clean_name]
+            else:
+                for inp in getattr(shard, "inputs", []):
+                    clean_name = inp.name.split(":")[0]
+                    if clean_name in input_layers:
+                        shard_inputs[clean_name] = input_layers[clean_name]
+
+            if not shard_inputs:
+                shard_inputs = dict(input_layers)
+
+            partial_outputs.append(shard(shard_inputs))
+
+        final_layer = self._original_model.layers[-1]
+        sharding_type = "unknown"
+        final_kernel_name = f"{final_layer.name}.kernel"
+        if hasattr(self._original_model, "name") and self._original_model.name:
+            final_kernel_name = (
+                f"{self._original_model.name}.{final_kernel_name}"
+            )
+
+        for pattern, action in self.tensor_parallel_config.state_rules.items():
+            if re.search(pattern, final_kernel_name):
+                if hasattr(action, "sharding_type"):
+                    sharding_type = action.sharding_type
+                break
+
+        if sharding_type == "column":
+            final_output = ops.concatenate(partial_outputs, axis=-1)
+            original_output_dim = self._original_model.output_shape[-1]
+            if final_output.shape[-1] != original_output_dim:
+                # Column shards may pad the output axis; trim back to the
+                # original width.
+                final_output = Lambda(lambda x: x[..., :original_output_dim])(
+                    final_output
+                )
+        elif sharding_type == "row":
+            if len(partial_outputs) > 1:
+                summed_output = Add()(partial_outputs)
+            else:
+                summed_output = partial_outputs[0]
+
+            if final_layer.use_bias:
+                # Every shard adds the full bias, so the sum contains
+                # `device_count` copies; subtract the extra ones.
+                bias = final_layer.bias
+                final_output = Lambda(
+                    lambda x: x - bias * (self.device_count - 1)
+                )(summed_output)
+            else:
+                final_output = summed_output
+        else:
+            final_output = partial_outputs[0]
+
+        assembled_model = Model(
+            inputs=list(input_layers.values()), outputs=final_output
+        )
+        return assembled_model
+
+    def canonicalize_device(self, device_spec: Union[str, int]) -> str:
+        """Convert a device specification to canonical form."""
+        if isinstance(device_spec, int):
+            if device_spec == -1:
+                return "cpu"
+            else:
+                return f"gpu:{device_spec}"
+        elif isinstance(device_spec, str):
+            if device_spec == "cpu":
+                return "cpu"
+            elif device_spec.startswith("gpu:"):
+                return device_spec
+            elif device_spec.startswith("cuda:"):
+                return f"gpu:{device_spec.split(':')[1]}"
+            else:
+                return device_spec
+        else:
+            return "cpu"
+
+    def call(self, inputs, training=None, **kwargs):
+        """Forward pass for the tensor-parallel model."""
+        return self.assembled_model(inputs, training=training, **kwargs)
+
+    def compile(
+        self,
+        optimizer=None,
+        loss=None,
+        metrics=None,
+        loss_weights=None,
+        **kwargs,
+    ):
+        """Compile the tensor parallel model."""
+        super().compile(
+            optimizer=optimizer,
+            loss=loss,
+            metrics=metrics,
+            loss_weights=loss_weights,
+            **kwargs,
+        )
+
+    def fit(self, x=None, y=None, **kwargs):
+        """Use standard Keras training, which handles the `train_step`."""
+        return super().fit(x, y, **kwargs)
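
Reviewer note: for quick manual verification, the class above can be exercised directly. A minimal usage sketch, assuming at least two visible CPU devices and the constructor signature defined in this file:

```python
import numpy as np
import keras
from keras.src.distribution.tensor_parallel.tensor_parallel import (
    TensorParallelKeras,
)

# Build a small functional model and shard it across two CPU devices.
inputs = keras.Input(shape=(64,))
x = keras.layers.Dense(128, activation="relu")(inputs)
outputs = keras.layers.Dense(10)(x)
model = keras.Model(inputs, outputs)

tp_model = TensorParallelKeras(
    model, device_count=2, device_ids=["cpu:0", "cpu:1"]
)
tp_model.compile(optimizer="adam", loss="mse")
tp_model.fit(
    np.random.rand(32, 64).astype("float32"),
    np.random.rand(32, 10).astype("float32"),
    epochs=1,
    verbose=0,
)
```
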
diff --git a/keras/src/distribution/tensor_parallel/tensor_parallel_test.py b/keras/src/distribution/tensor_parallel/tensor_parallel_test.py
new file mode 100644
index 000000000000..5af9534bebe8
--- /dev/null
+++ b/keras/src/distribution/tensor_parallel/tensor_parallel_test.py
@@ -0,0 +1,132 @@
+import numpy as np
+import pytest
+
+from keras.src import backend
+from keras.src import layers
+from keras.src.distribution.tensor_parallel.tensor_parallel import (
+    TensorParallelKeras,
+)
+from keras.src.layers import Input
+from keras.src.models import Model
+from keras.src.optimizers import Adam
+from keras.src.testing import TestCase
+
+
+@pytest.mark.skipif(
+    backend.backend() != "jax",
+    reason="This test is for the JAX backend only.",
+)
+class TensorParallelKerasTest(TestCase):
+    """Test suite for TensorParallelKeras running on the JAX backend."""
+
+    def setUp(self):
+        """Set up a reusable model and data for all tests."""
+        super().setUp()
+
+        inputs = Input(shape=(64,), name="input_layer")
+        x = layers.Dense(128, activation="relu", name="dense_column_sharded")(
+            inputs
+        )
+        outputs = layers.Dense(10, name="dense_row_sharded")(x)
+        self.original_model = Model(
+            inputs=inputs, outputs=outputs, name="test_mlp"
+        )
+
+        self.input_data = np.random.rand(32, 64).astype("float32")
+        self.target_data = np.random.rand(32, 10).astype("float32")
+
+        self.device_count = 2
+        self.device_ids = [f"cpu:{i}" for i in range(self.device_count)]
+
+    def test_initialization_and_sharding_verification(self):
+        """Tests that the model initializes and shards its parameters."""
+        tp_model = TensorParallelKeras(
+            self.original_model,
+            device_count=self.device_count,
+            device_ids=self.device_ids,
+        )
+
+        self.assertTrue(tp_model.distributed)
+        self.assertEqual(tp_model.device_count, self.device_count)
+        self.assertEqual(len(tp_model.model_shards), self.device_count)
+
+        original_params = self.original_model.count_params()
+        shard_0_params = tp_model.model_shards[0].count_params()
+
+        self.assertLess(shard_0_params, original_params)
+
+        tp_model_total_params = sum(
+            np.prod(v.shape) for v in tp_model.weights
+        )
+        self.assertEqual(tp_model_total_params, original_params)
+
+    def test_non_distributed_case_device_count_one(self):
+        """Tests that `device_count=1` gracefully degrades to a standard,
+        non-distributed model.
+        """
+        tp_model = TensorParallelKeras(self.original_model, device_count=1)
+
+        self.assertFalse(tp_model.distributed)
+        self.assertEqual(tp_model.device_count, 1)
+        self.assertEqual(len(tp_model.model_shards), 1)
+        self.assertIs(tp_model.assembled_model, self.original_model)
+
+        output = tp_model.predict(self.input_data, verbose=0)
+        self.assertEqual(output.shape, (32, 10))
+
+    def test_forward_pass_correctness(self):
+        """Tests that the sharded model's output is numerically identical
+        to the original model's.
+        """
+        inputs = Input(shape=(64,), name="input_layer")
+        x = layers.Dense(
+            128, activation="relu", kernel_initializer="glorot_uniform"
+        )(inputs)
+        outputs = layers.Dense(10, kernel_initializer="glorot_uniform")(x)
+        original_model = Model(inputs=inputs, outputs=outputs)
+
+        input_data = np.random.rand(32, 64).astype("float32")
+
+        original_output = original_model(input_data, training=False)
+
+        tp_model = TensorParallelKeras(
+            original_model,
+            device_count=self.device_count,
+            device_ids=self.device_ids,
+        )
+
+        tp_output = tp_model(input_data, training=False)
+
+        self.assertAllClose(original_output, tp_output, atol=1e-5, rtol=1e-5)
+
+    def test_distributed_training_workflow(self):
+        """Tests that the model can be compiled and trained for one step."""
+        tp_model = TensorParallelKeras(
+            self.original_model,
+            device_count=self.device_count,
+            device_ids=self.device_ids,
+        )
+
+        tp_model.compile(
+            optimizer=Adam(learning_rate=0.01),
+            loss="mse",
+        )
+
+        self.assertIsNotNone(tp_model.optimizer)
+
+        history = tp_model.fit(
+            self.input_data,
+            self.target_data,
+            epochs=1,
+            batch_size=16,
+            verbose=0,
+        )
+
+        self.assertIn("loss", history.history)
+        self.assertIsNotNone(history.history["loss"][0])
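
Reviewer note: the JAX backend exposes a single CPU device by default, so the two-device tests above will skip on a stock local setup. One common workaround (an assumption about the local environment, not part of this diff) is to force extra XLA host devices before JAX initializes:

```python
import os

# Must be set before jax (or the Keras JAX backend) is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"
os.environ["KERAS_BACKEND"] = "jax"

import jax  # noqa: E402

print(jax.devices())  # expect two CpuDevice entries
```
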