-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Enable Automatic Tensor Parallelism #21726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
buildwithsuhana
wants to merge
8
commits into
keras-team:master
Choose a base branch
from
buildwithsuhana:Tensor_parallel_keras_4
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+785
−0
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
ee43a75
added tensor_parallel and autoTPDistribution API
buildwithsuhana 46cb777
added tests and docstring
buildwithsuhana 8734a1f
refactor
buildwithsuhana d4f2dcc
fixing parameter sharding
buildwithsuhana 2652f26
fixing tensor parallel keras
buildwithsuhana ccd215c
fixing tensor parallel keras
buildwithsuhana cd28f06
fixing tensor parallel keras
buildwithsuhana 3c003a5
fixing tensor parallel keras
buildwithsuhana File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -896,3 +896,183 @@ def set_distribution(value): | |
| value: a `Distribution` instance. | ||
| """ | ||
| global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, value) | ||
|
|
||
|
|
||
| @keras_export("keras.distribution.AutoTPDistribution") | ||
| class AutoTPDistribution(Distribution): | ||
| """A distribution strategy for automated tensor and data parallelism. | ||
|
|
||
| This distribution strategy provides a high-level abstraction for combining | ||
| both data parallelism and tensor parallelism. It automatically shards Keras | ||
| model's layers across multiple devices (tensor parallelism) while also | ||
| distributing the input data across those devices (data parallelism). | ||
|
|
||
| It uses a `DeviceMesh` to represent the grid of computational devices. If no | ||
| mesh is provided, it creates one using all available devices. The mesh must | ||
| have a 'data' axis for data sharding and a 'model' axis for model sharding. | ||
|
|
||
| Internally, this class wraps the user-provided Keras `Model` with the | ||
| `TensorParallelKeras` utility to handle the model sharding. | ||
|
|
||
| Args: | ||
| model: A `keras.Model` instance to be distributed. | ||
| device_mesh: (Optional) A `keras.distribution.DeviceMesh` instance. | ||
| If not provided, a `DeviceMesh` will be automatically created using | ||
| all available devices, arranging them for both data and model | ||
| parallelism. | ||
| auto_shard_dataset: (Optional) A boolean indicating whether to | ||
| automatically shard `tf.data.Dataset` instances across multiple | ||
| processes. Defaults to `True`. | ||
|
|
||
| Attributes: | ||
| model: The wrapped, tensor-parallel `keras.Model` instance that is | ||
| ready for distributed training. | ||
| device_mesh: The `DeviceMesh` instance used for distribution. | ||
|
|
||
| Raises: | ||
| RuntimeError: If no computational devices are found and `device_mesh` | ||
| is not provided. | ||
| ValueError: If the provided `device_mesh` does not have a 'data' axis. | ||
|
|
||
| Example: | ||
|
|
||
| ```python | ||
| # Create a simple Keras model | ||
| inputs = keras.Input(shape=(64,)) | ||
| x = keras.layers.Dense(128, activation="relu")(inputs) | ||
| outputs = keras.layers.Dense(10)(x) | ||
| model = keras.Model(inputs=inputs, outputs=outputs) | ||
|
|
||
| # Create the distribution strategy with the model | ||
| # It will automatically use all available GPUs/TPUs | ||
| distribution = keras.distribution.AutoTPDistribution(model) | ||
|
|
||
| # The distributed model is accessed via the .model attribute | ||
| distributed_model = distribution.model | ||
|
|
||
| # Compile the model as usual | ||
| distributed_model.compile(optimizer="adam", loss="mse") | ||
|
|
||
| # Prepare a dataset | ||
| input_data = np.random.rand(32, 64) | ||
| target_data = np.random.rand(32, 10) | ||
|
|
||
| # Train the model | ||
| distributed_model.fit(input_data, target_data) | ||
| ``` | ||
| """ | ||
|
|
||
| def __init__(self, model, device_mesh=None, auto_shard_dataset=True): | ||
| if device_mesh is None: | ||
| all_devices = list_devices() | ||
| if not all_devices: | ||
| raise RuntimeError("No computational devices found.") | ||
| device_mesh = DeviceMesh( | ||
| shape=(1, len(all_devices)), | ||
| axis_names=("data", "model"), | ||
| devices=all_devices, | ||
| ) | ||
|
|
||
| if "data" not in device_mesh.axis_names: | ||
| raise ValueError( | ||
| "DeviceMesh for AutoTPDistribution must have a 'data' axis." | ||
| ) | ||
| batch_dim_name = "data" | ||
|
|
||
| super().__init__(device_mesh, batch_dim_name, auto_shard_dataset) | ||
|
|
||
| self._original_model = model | ||
| self._num_process = distribution_lib.num_processes() | ||
| self._process_id = distribution_lib.process_id() | ||
| self._is_multi_process = self._num_process > 1 | ||
| from keras.src.distribution.tensor_parallel.tensor_parallel import ( | ||
| TensorParallelKeras, | ||
| ) | ||
|
|
||
| self.model = TensorParallelKeras( | ||
| model=self._original_model, | ||
| world_size=np.prod(self.device_mesh.shape), | ||
| device_ids=self.device_mesh.devices.flatten().tolist(), | ||
| ) | ||
|
|
||
| def get_data_layout(self, data_shape): | ||
| data_shard_spec = [None] * len(data_shape) | ||
| data_shard_spec[0] = self.batch_dim_name | ||
| return TensorLayout(data_shard_spec, self.device_mesh) | ||
|
|
||
| def get_variable_layout(self, variable): | ||
| warnings.warn( | ||
| "Variable layout is determined automatically within " | ||
| "AutoTPDistribution. This method will return a replicated layout." | ||
| ) | ||
| return TensorLayout([None] * len(variable.shape), self.device_mesh) | ||
|
|
||
| def get_tensor_layout(self, path): | ||
| return None | ||
|
Comment on lines
+1010
to
+1011
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is your injection point. This is where you should return the output layout from your |
||
|
|
||
| def distribute_dataset(self, dataset): | ||
| """Distributes the dataset across processes based on the device mesh.""" | ||
| if not self._is_multi_process or not self.auto_shard_dataset: | ||
| return dataset | ||
|
|
||
| from keras.src.utils.module_utils import tensorflow as tf | ||
|
|
||
| if not tf.available or not isinstance(dataset, tf.data.Dataset): | ||
| raise ValueError( | ||
| "Only `tf.data.Dataset` is supported for auto-sharding, " | ||
| f"got {type(dataset)}" | ||
| ) | ||
|
|
||
| from tensorflow.python.data.experimental.ops import ( | ||
| distribute as tf_data_distribute, | ||
| ) | ||
|
|
||
| global_batch_size = tf_data_distribute.compute_batch_size(dataset) | ||
| if global_batch_size.numpy() < 0: | ||
| raise ValueError( | ||
| "The batch size of the input dataset is unknown. " | ||
| "Please configure the batch size for the input dataset, " | ||
| "e.g., via `dataset.batch(batch_size)`" | ||
| ) | ||
|
|
||
| mesh_batch_dim_index = self.device_mesh.axis_names.index( | ||
| self.batch_dim_name | ||
| ) | ||
| num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index] | ||
|
|
||
| if num_model_replicas == 1: | ||
| return dataset.prefetch(tf.data.AUTOTUNE) | ||
|
|
||
| num_model_replicas_per_process = num_model_replicas / self._num_process | ||
| if num_model_replicas_per_process >= 1: | ||
| if global_batch_size % self._num_process != 0: | ||
| raise ValueError( | ||
| "Global batch size must be divisible by the number of " | ||
| f"processes. `global_batch_size`={global_batch_size} and " | ||
| f"`num_process`={self._num_process}" | ||
| ) | ||
| per_process_batch_size = global_batch_size // self._num_process | ||
| distributed_dataset = dataset.rebatch(per_process_batch_size) | ||
| distributed_dataset = distributed_dataset.shard( | ||
| num_shards=self._num_process, | ||
| index=self._process_id, | ||
| ) | ||
| return distributed_dataset.prefetch(tf.data.AUTOTUNE) | ||
| else: | ||
| if global_batch_size % num_model_replicas != 0: | ||
| raise ValueError( | ||
| "Global batch size must be divisible by the number of " | ||
| f"replicas. `global_batch_size`={global_batch_size} and " | ||
| f"`num_model_replicas`={num_model_replicas}" | ||
| ) | ||
| per_replica_batch_size = global_batch_size // num_model_replicas | ||
| distributed_dataset = dataset.rebatch(per_replica_batch_size) | ||
|
|
||
| processes_per_replica = self._num_process // num_model_replicas | ||
| data_shard_id = self._process_id // processes_per_replica | ||
|
|
||
| distributed_dataset = distributed_dataset.shard( | ||
| num_shards=num_model_replicas, | ||
| index=data_shard_id, | ||
| ) | ||
| return distributed_dataset.prefetch(tf.data.AUTOTUNE) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is your injection point. This is where you should return the layout for each variable from your LayoutMap.