-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Adding Tensor_layout for Tensor parallelism for Autosharding #21792
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
buildwithsuhana
wants to merge
45
commits into
keras-team:master
Choose a base branch
from
buildwithsuhana:tensor_parallel
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+621
−0
Open
Changes from all commits
Commits
Show all changes
45 commits
Select commit
Hold shift + click to select a range
06bb3bb
Adding tensor layout for TP autosharding
buildwithsuhana 41f8025
formatting files
buildwithsuhana e74eab2
Updating the docstring
buildwithsuhana 2cddf39
refactoring the code
buildwithsuhana fee036e
Merge branch 'tensor_parallel' of https://github.com/buildwithsuhana/…
buildwithsuhana 9bed6e4
Merge branch 'keras-team:master' into tensor_parallel
buildwithsuhana 5365f14
fixing test
buildwithsuhana bc4d094
fixing test
buildwithsuhana 4d32e49
adding autoconfig and coordinated_optimizer
buildwithsuhana 119ac15
updating docstrings and code format
buildwithsuhana 7851615
refactored autoconfig to not use recursion
buildwithsuhana 4707c2b
updating docstrings
buildwithsuhana 45aa44c
removing redundancies
buildwithsuhana 8bb39f6
added tests for autoconfig and coordinated optimizer
buildwithsuhana ab444b1
fixing autoconfig
buildwithsuhana d5612eb
fixing autoconfig
buildwithsuhana a777178
ficing autoconfig test
buildwithsuhana 7b144d9
fixing tensor layout and core
buildwithsuhana 12b038a
running pre commit
buildwithsuhana d9eabc8
adding test
buildwithsuhana 74437c9
adding test
buildwithsuhana 6eeb589
Fixing autoconfig
buildwithsuhana 207a4bf
fixing coordinated_optimizer
buildwithsuhana 17cf142
fixing tests
buildwithsuhana f834eca
fixed all comments and tests
buildwithsuhana e30b201
Merge branch 'keras-team:master' into tensor_parallel
buildwithsuhana 1ff9e48
fixing tests
buildwithsuhana ef537ff
Merge branch 'tensor_parallel' of https://github.com/buildwithsuhana/…
buildwithsuhana 89f5e77
reverting changes
buildwithsuhana 0973f69
fixing test
buildwithsuhana 68e8862
removing redundant lines
buildwithsuhana 5d82be8
bringing tests to similar format
buildwithsuhana 82fc41e
fixing dtype issue
buildwithsuhana 2801325
fixing test
buildwithsuhana 88d70c8
fixing test
buildwithsuhana 7e8aded
fixing comments
buildwithsuhana f3fe7ac
fixing format
buildwithsuhana 7b444ec
fixes
buildwithsuhana f608974
simplified coordinated_optimizer
buildwithsuhana 99a9885
fixing dtype error
buildwithsuhana 2f425c6
fixing dtype
buildwithsuhana bad0f22
fixing test
buildwithsuhana 5c0a2f2
fixing test
buildwithsuhana 9ee27c1
fixing test
buildwithsuhana 905a3ee
Removing coordinated_optimizer.py
buildwithsuhana File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,179 @@ | ||
| import functools | ||
|
|
||
| from keras.src import layers | ||
| from keras.src.backend import distribution_lib | ||
| from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap | ||
| from keras.src.distribution.tensor_parallel.tensor_layout import ( | ||
| split_tensor_for_parallelism, | ||
| ) | ||
|
|
||
|
|
||
| def analyze_dense_layer(layer): | ||
| """Classifies a Dense layer based on its input/output dimensions. | ||
|
|
||
| This function uses a heuristic to determine if a Dense layer acts as an | ||
| 'up_projection' (expansion), a 'down_projection' (contraction), or a | ||
| standard 'dense' layer. This classification is used to determine the | ||
| appropriate sharding strategy (e.g., column-parallel vs row-parallel). | ||
|
|
||
| Args: | ||
| layer: The Keras Dense layer instance to analyze. | ||
|
|
||
| Returns: | ||
| str: One of 'up_projection', 'down_projection', or 'dense'. | ||
| """ | ||
| input_dim = None | ||
| output_dim = None | ||
|
|
||
| kernel = getattr(layer, "kernel", getattr(layer, "_kernel", None)) | ||
| if kernel is not None: | ||
| if len(kernel.shape) == 2: | ||
| input_dim = kernel.shape[0] | ||
| output_dim = kernel.shape[1] | ||
|
|
||
| if output_dim is None and hasattr(layer, "units"): | ||
| output_dim = layer.units | ||
|
|
||
| if ( | ||
| input_dim is None | ||
| and hasattr(layer, "input_shape") | ||
| and layer.input_shape | ||
| and len(layer.input_shape) > 1 | ||
| ): | ||
| input_dim = layer.input_shape[-1] | ||
|
|
||
| if input_dim is None or output_dim is None: | ||
| return "dense" | ||
|
|
||
| expansion_threshold = 1.5 | ||
| is_expansion = output_dim > input_dim * expansion_threshold | ||
| is_contraction = input_dim > output_dim * expansion_threshold | ||
|
|
||
| if is_expansion: | ||
| return "up_projection" | ||
| elif is_contraction: | ||
| return "down_projection" | ||
| else: | ||
| return "dense" | ||
|
|
||
|
|
||
| def _reduce_sum(x): | ||
| """Performs an all-reduce sum operation across the 'model' mesh axis. | ||
|
|
||
| Args: | ||
| x: The input tensor to reduce. | ||
|
|
||
| Returns: | ||
| The reduced tensor, summed across all devices in the model axis. | ||
| """ | ||
| return distribution_lib.all_reduce(x, op="sum", axis_name="model") | ||
|
|
||
|
|
||
| def _gather(x, axis): | ||
| """Performs an all-gather operation across the 'model' mesh axis. | ||
|
|
||
| Args: | ||
| x: The input tensor shard to gather. | ||
| axis: The axis along which to concatenate the gathered parts. | ||
|
|
||
| Returns: | ||
| The gathered tensor, concatenated along the specified axis. | ||
| """ | ||
| return distribution_lib.all_gather(x, axis=axis, axis_name="model") | ||
|
|
||
|
|
||
| def _apply_layer_sharding_rules(layer, device_count, state_rules, output_rules): | ||
| """Applies sharding rules to a single layer based on its type. | ||
|
|
||
| This function populates `state_rules` and `output_rules` with strategies | ||
| specific to the layer class (e.g., Dense, EinsumDense, Embedding). It | ||
| determines how weights should be partitioned (state rules) and how outputs | ||
| should be synchronized (output rules). | ||
|
|
||
| Args: | ||
| layer: The Keras layer instance to configure. | ||
| device_count: The number of devices available for tensor parallelism. | ||
| state_rules: A dictionary mapping variable paths to sharding functions. | ||
| Updated in-place. | ||
| output_rules: A dictionary mapping layer paths to output communication | ||
| functions. Updated in-place. | ||
| """ | ||
|
|
||
| def split_rule(dim): | ||
| return functools.partial( | ||
| split_tensor_for_parallelism, device_count=device_count, dim=dim | ||
| ) | ||
|
|
||
| def gather_rule(axis): | ||
| return functools.partial(_gather, axis=axis) | ||
|
|
||
| layer_path = layer.path | ||
|
|
||
| if isinstance(layer, layers.Dense): | ||
| mlp_type = analyze_dense_layer(layer) | ||
|
|
||
| if mlp_type == "up_projection": | ||
| state_rules[id(layer.kernel)] = split_rule(dim=1) | ||
| if layer.use_bias: | ||
| state_rules[id(layer.bias)] = split_rule(dim=0) | ||
| output_rules[layer_path] = gather_rule(axis=-1) | ||
|
|
||
| elif mlp_type == "down_projection": | ||
| state_rules[id(layer.kernel)] = split_rule(dim=0) | ||
| output_rules[layer_path] = _reduce_sum | ||
|
|
||
| else: | ||
| state_rules[id(layer.kernel)] = split_rule(dim=1) | ||
| if layer.use_bias: | ||
| state_rules[id(layer.bias)] = split_rule(dim=0) | ||
| output_rules[layer_path] = gather_rule(axis=-1) | ||
|
|
||
| elif isinstance(layer, layers.EinsumDense): | ||
| if "attention_output" in layer.name: | ||
| state_rules[id(layer.kernel)] = split_rule(dim=0) | ||
| output_rules[layer_path] = _reduce_sum | ||
| else: | ||
| state_rules[id(layer.kernel)] = split_rule(dim=1) | ||
| if hasattr(layer, "bias") and layer.bias is not None: | ||
| state_rules[id(layer.bias)] = split_rule(dim=0) | ||
| output_rules[layer_path] = gather_rule(axis=-1) | ||
|
|
||
| elif ( | ||
| isinstance(layer, (layers.Embedding,)) | ||
| or "Embedding" in layer.__class__.__name__ | ||
| ): | ||
| embeddings_var = getattr(layer, "embeddings", None) | ||
| if embeddings_var is not None: | ||
| state_rules[id(embeddings_var)] = split_rule(dim=1) | ||
| output_rules[layer_path] = lambda x: x | ||
|
|
||
|
|
||
| def get_default_config(model, device_ids): | ||
| """Generates a default tensor parallelism configuration for a model. | ||
|
|
||
| This function traverses the model's layer hierarchy and | ||
| automatically generates a `LayoutMap`. This map contains: | ||
| 1. `state_rules`: How to shard the weights of supported layers | ||
| across the specified devices. | ||
| 2. `output_rules`: How to synchronize or gather the outputs of | ||
| these layers during the forward pass. | ||
|
|
||
| Args: | ||
| model: The Keras model to configure. | ||
| device_ids: A list of device identifiers to be used | ||
| for distribution. | ||
|
|
||
| Returns: | ||
| LayoutMap: A configuration object containing `state_rules` and | ||
| `output_rules` for tensor parallelism. | ||
| """ | ||
| device_count = len(device_ids) | ||
| state_rules = {} | ||
| output_rules = {} | ||
|
|
||
| for layer in model._flatten_layers(recursive=True, include_self=True): | ||
| _apply_layer_sharding_rules( | ||
| layer, device_count, state_rules, output_rules | ||
| ) | ||
|
|
||
| return LayoutMap(state_rules=state_rules, output_rules=output_rules) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.