From 943d625b42b515ae43b04ea47e7962c261ffdc10 Mon Sep 17 00:00:00 2001
From: Tsung-Hsien Lee
Date: Mon, 23 Dec 2024 11:42:35 -0800
Subject: [PATCH] Simplify the input argument of
 `_construct_global_block_info_list`

Summary:
Instead of passing the entire `buffer_size_ranks` tuple, the
`group_source_ranks` tuple is extracted from it and passed as an argument.
This simplifies the code and makes it more readable. Furthermore, this might
help the consolidation of `_construct_global_block_info_list` and
`_construct_local_block_info_list` in the future.

Differential Revision: D67606282
---
 .../utils/shampoo_ddp_distributor.py          | 19 ++++++++++-----
 .../utils/shampoo_hsdp_distributor.py         | 23 +++++++++++++++----
 .../utils/shampoo_hybrid_shard_distributor.py | 23 +++++++++++++++----
 3 files changed, 49 insertions(+), 16 deletions(-)

diff --git a/distributed_shampoo/utils/shampoo_ddp_distributor.py b/distributed_shampoo/utils/shampoo_ddp_distributor.py
index eb76ff8..c8f4613 100644
--- a/distributed_shampoo/utils/shampoo_ddp_distributor.py
+++ b/distributed_shampoo/utils/shampoo_ddp_distributor.py
@@ -107,7 +107,9 @@ def __init__(
         )
 
         global_block_info_list = self._construct_global_block_info_list(
-            buffer_size_ranks
+            group_source_ranks=tuple(
+                group_source_rank for _, group_source_rank in buffer_size_ranks
+            )
         )
         # Initialize selectors and local blocked (masked) parameters.
         self._distributor_selector: tuple[bool, ...] = tuple(
@@ -263,14 +265,19 @@ def _distribute_buffer_sizes(
 
     @torch.no_grad()
     def _construct_global_block_info_list(
-        self, buffer_size_ranks: tuple[tuple[int, int], ...]
+        self, group_source_ranks: tuple[int, ...]
     ) -> tuple[DDPBlockInfo, ...]:
         """Construct the global block info list.
 
+        This method creates a list of DDPBlockInfo objects, which contain information
+        about each parameter block, including its composable block IDs, a function to
+        allocate zero tensors, a method to retrieve tensors, and the group source rank.
+
         Args:
-            buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the buffer size
-                and an assigned rank for each block.
+            group_source_ranks (tuple[int, ...]): A list of assigned ranks for each block.
 
+        Returns:
+            tuple[DDPBlockInfo, ...]: A tuple of DDPBlockInfo objects for each parameter block.
         """
         return tuple(
             DDPBlockInfo(
@@ -298,9 +305,9 @@ def _construct_global_block_info_list(
                 generate_pairwise_indices(self._global_num_blocks_per_param),
                 strict=True,
             )
-            for block_index, (_, group_source_rank) in enumerate(
+            for block_index, group_source_rank in enumerate(
                 islice(
-                    buffer_size_ranks, buffer_size_ranks_start, buffer_size_ranks_end
+                    group_source_ranks, buffer_size_ranks_start, buffer_size_ranks_end
                 )
             )
         )
diff --git a/distributed_shampoo/utils/shampoo_hsdp_distributor.py b/distributed_shampoo/utils/shampoo_hsdp_distributor.py
index e8042ce..5e4e010 100644
--- a/distributed_shampoo/utils/shampoo_hsdp_distributor.py
+++ b/distributed_shampoo/utils/shampoo_hsdp_distributor.py
@@ -201,7 +201,9 @@ def __init__(
         )
 
         global_block_info_list = self._construct_global_block_info_list(
-            buffer_size_ranks
+            group_source_ranks=tuple(
+                group_source_rank for _, group_source_rank in buffer_size_ranks
+            )
         )
         # Initialize selectors and local blocked (masked) parameters.
         self._distributor_selector: tuple[bool, ...] = tuple(
@@ -375,9 +377,20 @@ def _construct_composable_block_ids(
 
     @torch.no_grad()
     def _construct_global_block_info_list(
-        self, buffer_size_ranks: tuple[tuple[int, int], ...]
+        self, group_source_ranks: tuple[int, ...]
     ) -> tuple[DDPBlockInfo, ...]:
-        """Construct global block info list from param_group and num_blocks_within_param."""
+        """Construct the global block info list.
+
+        This method creates a list of DDPBlockInfo objects, which contain information
+        about each parameter block, including its composable block IDs, a function to
+        allocate zero tensors, a method to retrieve tensors, and the group source rank.
+
+        Args:
+            group_source_ranks (tuple[int, ...]): A list of assigned ranks for each block.
+
+        Returns:
+            tuple[DDPBlockInfo, ...]: A tuple of DDPBlockInfo objects for each parameter block.
+        """
         # Note that for HSDP, we want to get the rank within each sharded group for the block id.
         # When using a device mesh, 0 corresponds to the replicated group and 1 corresponds to the sharded group.
         sharded_group_rank = self._hsdp_device_mesh.get_local_rank(1)
@@ -408,9 +421,9 @@ def _construct_global_block_info_list(
                 generate_pairwise_indices(self._global_num_blocks_per_param),
                 strict=True,
             )
-            for block_index, (_, group_source_rank) in enumerate(
+            for block_index, group_source_rank in enumerate(
                 islice(
-                    buffer_size_ranks, buffer_size_ranks_start, buffer_size_ranks_end
+                    group_source_ranks, buffer_size_ranks_start, buffer_size_ranks_end
                 )
             )
         )
diff --git a/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py b/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py
index 795c36d..42f3422 100644
--- a/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py
+++ b/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py
@@ -187,7 +187,9 @@ def __init__(
         )
 
         global_block_info_list = self._construct_global_block_info_list(
-            buffer_size_ranks
+            group_source_ranks=tuple(
+                group_source_rank for _, group_source_rank in buffer_size_ranks
+            )
         )
 
         # Initialize selectors and local blocked (masked) parameters.
@@ -379,9 +381,20 @@ def _construct_composable_block_ids(
 
     @torch.no_grad()
     def _construct_global_block_info_list(
-        self, buffer_size_ranks: tuple[tuple[int, int], ...]
+        self, group_source_ranks: tuple[int, ...]
     ) -> tuple[DDPBlockInfo, ...]:
-        """Construct global block info list from param_group and num_blocks_within_param."""
+        """Construct the global block info list.
+
+        This method creates a list of DDPBlockInfo objects, which contain information
+        about each parameter block, including its composable block IDs, a function to
+        allocate zero tensors, a method to retrieve tensors, and the group source rank.
+
+        Args:
+            group_source_ranks (tuple[int, ...]): A list of assigned ranks for each block.
+
+        Returns:
+            tuple[DDPBlockInfo, ...]: A tuple of DDPBlockInfo objects for each parameter block.
+        """
         # Call `super()` instead of `self` as a performance optimization.
         # This leads to O(1) instead of O(N) complexity to retrieve the parameters.
         non_empty_params: Iterable[DTensor] = filter(
@@ -419,9 +432,9 @@ def _construct_global_block_info_list(
                 generate_pairwise_indices(self._global_num_blocks_per_param),
                 strict=True,
             )
-            for block_index, (_, group_source_rank) in enumerate(
+            for block_index, group_source_rank in enumerate(
                 islice(
-                    buffer_size_ranks, buffer_size_ranks_start, buffer_size_ranks_end
+                    group_source_ranks, buffer_size_ranks_start, buffer_size_ranks_end
                 )
            )
        )
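
Note (illustrative sketch, not part of the diff): the refactored call sites extract `group_source_ranks` from `buffer_size_ranks`, and `_construct_global_block_info_list` then slices that flat tuple per parameter with `islice`. A minimal, self-contained Python sketch of that pattern, with made-up example values (`buffer_size_ranks`, `global_num_blocks_per_param`) standing in for the real distributor state:

from itertools import islice

# Hypothetical example data: one (buffer_size, group_source_rank) pair per block,
# shaped like the output of _distribute_buffer_sizes (values are illustrative only).
buffer_size_ranks = ((256, 0), (256, 1), (512, 0), (128, 1))

# Extraction used at the refactored call sites: keep only the assigned ranks.
group_source_ranks = tuple(
    group_source_rank for _, group_source_rank in buffer_size_ranks
)
assert group_source_ranks == (0, 1, 0, 1)

# Inside _construct_global_block_info_list, each parameter's blocks are picked out
# of this flat tuple by (start, end) index ranges; islice does the slicing lazily.
global_num_blocks_per_param = (1, 3)  # illustrative: first param has 1 block, second has 3
start = 0
for num_blocks in global_num_blocks_per_param:
    end = start + num_blocks
    print(tuple(islice(group_source_ranks, start, end)))  # prints (0,) then (1, 0, 1)
    start = end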