diff --git a/distributed_shampoo/__init__.py b/distributed_shampoo/__init__.py index 97a898a..1e255a0 100644 --- a/distributed_shampoo/__init__.py +++ b/distributed_shampoo/__init__.py @@ -22,7 +22,6 @@ FullyShardShampooConfig, GraftingConfig, HSDPShampooConfig, - PrecisionConfig, PreconditionerConfig, RMSpropGraftingConfig, SGDGraftingConfig, @@ -58,7 +57,6 @@ "FullyShardShampooConfig", "HSDPShampooConfig", # `precision_config`. - "PrecisionConfig", # `preconditioner_config` options. "PreconditionerConfig", # Abstract base class. "ShampooPreconditionerConfig", # Based on `PreconditionerConfig`. diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index bdade04..7f7c9bc 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -7,7 +7,6 @@ """ -import contextlib import dataclasses import logging from collections.abc import Callable, Iterator, Sequence @@ -37,6 +36,7 @@ GRAFTING_PRECONDITIONER_LIST, GraftingConfig, HSDPShampooConfig, + HybridShardShampooConfig, INV_ROOT_OVERRIDE, LR, MASKED_BLOCKED_GRADS, @@ -47,10 +47,9 @@ MOMENTUM, MOMENTUM_LIST, PARAMS, - PRECISION_CONFIG, - PrecisionConfig, PRECONDITION_FREQUENCY, PRECONDITIONER_CONFIG, + PRECONDITIONER_DTYPE, PreconditionerConfig, PREVIOUS_GRAD_SELECTOR, RMSpropGraftingConfig, @@ -80,19 +79,16 @@ FullyShardDistributor, ) from distributed_shampoo.utils.shampoo_hsdp_distributor import HSDPDistributor +from distributed_shampoo.utils.shampoo_hybrid_shard_distributor import ( + HybridShardDistributor, +) from distributed_shampoo.utils.shampoo_preconditioner_list import ( AdagradPreconditionerList, - DequantizePreconditionersContext, EigenvalueCorrectedShampooPreconditionerList, SGDPreconditionerList, ShampooPreconditionerList, ) -from distributed_shampoo.utils.shampoo_quantization import ( - DequantizeQuantizedTensorListContext, - QuantizedTensor, - QuantizedTensorList, -) from distributed_shampoo.utils.shampoo_utils import compress_list from matrix_functions_types import EigenConfig, RootInvConfig @@ -283,20 +279,12 @@ class DistributedShampoo(torch.optim.Optimizer): grafting_config (GraftingConfig | None): Configuration for grafting method. If None, ignores grafting. (Default: None) use_merge_dims (bool): Merge dimensions if possible while respecting max_preconditioner_dim. (Default: True) - use_pytorch_compile (bool | None): Use PyTorch 2.0 compiler feature to speed up training. Deprecating, please use - shampoo_pt2_compile_config instead; when this field is None, the use of PyTorch 2.0 compiler is decided by - shampoo_pt2_compile_config. (Default: None) shampoo_pt2_compile_config (ShampooPT2CompileConfig | None): Configuration for Shampoo PT2 compilation. If None, ignores compilation, and Shampoo will run in eager mode. (Default: None) distributed_config (DistributedConfig | None): Configuration for applying Shampoo to different distributed training frameworks, such as distributed-data parallel (DDP) training. Based on the configuration, determines which version of Shampoo to use. (Default: None) - preconditioner_dtype (torch.dtype | None): **DEPRECATING** Data type for preconditioner. (Default: None) - precision_config (PrecisionConfig): Data types for optimizer states. (Default: all fields torch.float) - use_protected_eigh (bool): **DEPRECATED** Flag for using two guards to prevent failures of torch.linalg.eigh. (Default: True) - 1. Attempts to compute root inverse in preconditioner_dtype precision. - 2. 
Attempts to recompute the eigendecomposition in higher precision if using lower-precision fails. - 3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail. + preconditioner_dtype (torch.dtype): Data type for preconditioner. (Default: torch.float) track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes. (Default: False) preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation. @@ -326,12 +314,9 @@ def __init__( use_decoupled_weight_decay: bool = True, grafting_config: GraftingConfig | None = None, use_merge_dims: bool = True, - use_pytorch_compile: bool | None = None, shampoo_pt2_compile_config: ShampooPT2CompileConfig | None = None, distributed_config: DistributedConfig | None = None, - preconditioner_dtype: torch.dtype | None = None, - precision_config: PrecisionConfig | None = None, - use_protected_eigh: bool = True, + preconditioner_dtype: torch.dtype = torch.float, track_root_inv_residuals: bool = False, preconditioner_config: PreconditionerConfig = DefaultShampooConfig, ) -> None: @@ -410,43 +395,12 @@ def __init__( "Continuing without using momentum or Nesterov acceleration..." ) - # Deprecation warning for use_pytorch_compile - if use_pytorch_compile is not None: - if use_pytorch_compile and shampoo_pt2_compile_config is None: - shampoo_pt2_compile_config = ShampooPT2CompileConfig() - logger.warning( - "use_pytorch_compile is deprecating. Please use shampoo_pt2_compile_config instead." - ) - elif use_pytorch_compile and shampoo_pt2_compile_config is not None: - raise ValueError( - "Both use_pytorch_compile and shampoo_pt2_compile_config are provided. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating." - ) - elif not use_pytorch_compile and shampoo_pt2_compile_config is not None: - raise ValueError( - "use_pytorch_compile=False conflicts with non-None shampoo_pt2_compile_config arg. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating." - ) - # Provide error for system Pytorch compile availability if shampoo_pt2_compile_config is not None and not torch.cuda.is_available(): raise ValueError( - "Backend does NOT support Pytorch 2.0 compile. Switch to use_pytorch_compile in (False, None) and shampoo_pt2_compile_config=None." + "Backend does NOT support Pytorch 2.0 compile. Switch to shampoo_pt2_compile_config=None." ) - # Deprecation warning for preconditioner_dtype - if preconditioner_dtype is not None: - if precision_config is None: - precision_config = PrecisionConfig( - factor_matrix_dtype=preconditioner_dtype, - factor_matrix_computation_dtype=preconditioner_dtype, - ) - logger.warning( - "preconditioner_dtype is deprecated. Please use precision_config instead." - ) - else: - raise ValueError( - "Both preconditioner_dtype and precision_config are provided. Please use only precision_config as preconditioner_dtype is deprecated." - ) - amortized_computation_config = ( preconditioner_config.amortized_computation_config ) @@ -457,10 +411,6 @@ def __init__( f"{track_root_inv_residuals=} has to be set to False when {amortized_computation_config=} is not an instance of RootInvConfig." ) - # Create default precision config if it is not provided. - if precision_config is None: - precision_config = PrecisionConfig() - # Set exponent multiplier if this is not provided.
if ( isinstance(amortized_computation_config, EigenConfig) @@ -493,7 +443,7 @@ def __init__( USE_DECOUPLED_WEIGHT_DECAY: use_decoupled_weight_decay, GRAFTING_CONFIG: grafting_config, USE_MERGE_DIMS: use_merge_dims, - PRECISION_CONFIG: precision_config, + PRECONDITIONER_DTYPE: preconditioner_dtype, PRECONDITIONER_CONFIG: preconditioner_config, }, ) @@ -508,7 +458,7 @@ def __init__( # Block parameters and instantiate optimizer states. self._instantiate_distributor(distributed_config) - self._instantiate_shampoo_preconditioner_list(use_protected_eigh) + self._instantiate_shampoo_preconditioner_list() self._instantiate_grafting() self._instantiate_steps() self._instantiate_momentum() @@ -535,6 +485,11 @@ def _instantiate_distributor( HSDPDistributor, distributed_config=distributed_config, ) # type: ignore[assignment] + elif type(distributed_config) is HybridShardShampooConfig: + distributor = partial( + HybridShardDistributor, + distributed_config=distributed_config, + ) # type: ignore[assignment] else: raise NotImplementedError(f"{distributed_config=} not supported!") @@ -559,9 +514,7 @@ def _instantiate_distributor( state_lists[PREVIOUS_GRAD_SELECTOR] = None @torch.no_grad() - def _instantiate_shampoo_preconditioner_list( - self, use_protected_eigh: bool - ) -> None: + def _instantiate_shampoo_preconditioner_list(self) -> None: for state_lists, group in zip( self._per_group_state_lists, self.param_groups, strict=True ): @@ -583,12 +536,11 @@ def _instantiate_shampoo_preconditioner_list( block_info_list=state_lists[DISTRIBUTOR].global_block_info_list, distributor_selector=state_lists[DISTRIBUTOR].distributor_selector, preconditioner_config=group[PRECONDITIONER_CONFIG], - precision_config=group[PRECISION_CONFIG], beta2=group[BETAS][1], epsilon=group[EPSILON], inv_root_override=group[INV_ROOT_OVERRIDE], use_bias_correction=group[USE_BIAS_CORRECTION], - use_protected_eigh=use_protected_eigh, + factor_matrix_dtype=group[PRECONDITIONER_DTYPE], ) @torch.no_grad() @@ -612,7 +564,6 @@ def _instantiate_grafting(self) -> None: state=self.state, block_info_list=state_lists[DISTRIBUTOR].global_block_info_list, distributor_selector=state_lists[DISTRIBUTOR].distributor_selector, - precision_config=group[PRECISION_CONFIG], beta2=( 1.0 if type(group[GRAFTING_CONFIG]) is AdaGradGraftingConfig @@ -667,24 +618,19 @@ def _instantiate_momentum(self) -> None: "for the correctness of block_index." block_state = self.state[block_info.param][block_index] - block_state[MOMENTUM] = QuantizedTensor( - block_info.allocate_zeros_tensor( - shape=block.size(), - dtype=group[PRECISION_CONFIG].momentum_dtype, - device=block.device, - ), - block_info, + block_state[MOMENTUM] = block_info.allocate_zeros_tensor( + shape=block.size(), + dtype=block.dtype, + device=block.device, + ) + global_momentum_list.append( + block_info.get_tensor(block_state[MOMENTUM]) ) - global_momentum_list.append(block_state[MOMENTUM]) # We compress the momentum list to only the locally-owned parameter states. - state_lists[MOMENTUM_LIST] = QuantizedTensorList( - compress_list( - global_momentum_list, - state_lists[DISTRIBUTOR].distributor_selector, - ), - group[PRECISION_CONFIG].momentum_dtype, - group[PRECISION_CONFIG].computation_dtype, + state_lists[MOMENTUM_LIST] = compress_list( + global_momentum_list, + state_lists[DISTRIBUTOR].distributor_selector, ) # Here, we set masked momentum list to momentum list because we assume # all parameters are active. 
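The constructor changes above collapse the removed PrecisionConfig, the deprecated preconditioner_dtype/precision_config pair, use_pytorch_compile, and use_protected_eigh into a single preconditioner_dtype argument. A minimal caller-side sketch against the new signature, reusing import paths and hyperparameter values that appear elsewhere in this patch (the toy model and the specific values are illustrative, not prescribed):

import torch
from torch import nn

from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import AdaGradGraftingConfig

# Illustrative toy model.
model = nn.Sequential(nn.Linear(5, 10, bias=False))

optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.01,
    betas=(0.9, 1.0),
    epsilon=1e-12,
    momentum=0.99,
    weight_decay=0.0,
    max_preconditioner_dim=5,
    precondition_frequency=1,
    start_preconditioning_step=1,
    preconditioner_dtype=torch.float32,  # single dtype knob; was precision_config=PrecisionConfig(...)
    grafting_config=AdaGradGraftingConfig(epsilon=0.001),
    distributed_config=None,
)

With quantization removed, optimizer states are stored as plain tensors in the block dtype, which is also why the ', "quantized_values"' suffix disappears from the distributed state-dict keys later in this patch.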
@@ -715,24 +661,19 @@ def _instantiate_filtered_grads(self) -> None: "Distributor for the correctness of block_index." block_state = self.state[block_info.param][block_index] - block_state[FILTERED_GRAD] = QuantizedTensor( - block_info.allocate_zeros_tensor( - shape=block.size(), - dtype=group[PRECISION_CONFIG].filtered_grad_dtype, - device=block.device, - ), - block_info, + block_state[FILTERED_GRAD] = block_info.allocate_zeros_tensor( + shape=block.size(), + dtype=block.dtype, + device=block.device, + ) + global_filtered_grad_list.append( + block_info.get_tensor(block_state[FILTERED_GRAD]) ) - global_filtered_grad_list.append(block_state[FILTERED_GRAD]) # We compress the momentum list to only the locally-owned parameter states. - state_lists[FILTERED_GRAD_LIST] = QuantizedTensorList( - compress_list( - global_filtered_grad_list, - state_lists[DISTRIBUTOR].distributor_selector, - ), - group[PRECISION_CONFIG].filtered_grad_dtype, - group[PRECISION_CONFIG].computation_dtype, + state_lists[FILTERED_GRAD_LIST] = compress_list( + global_filtered_grad_list, + state_lists[DISTRIBUTOR].distributor_selector, ) # Here, we set masked filtered grad list to filtered grad list because we assume # all parameters are active. @@ -789,10 +730,29 @@ def _mask_state_lists(state_lists: dict[str, Any], group: dict[str, Any]) -> Non ): return - if state_lists[STEP].item() >= 1: + # Warning for potential PT2 recompile due to gradient selector change. + # This warning is expected in either training from scratch or reloading from a checkpoint, as state_lists[PREVIOUS_GRAD_SELECTOR] is initialized to `None`, triggering this warning. + if state_lists[PREVIOUS_GRAD_SELECTOR] is not None: + grad_selector_different = [ + a ^ b + for a, b in zip( + state_lists[DISTRIBUTOR].local_grad_selector, + state_lists[PREVIOUS_GRAD_SELECTOR], + strict=True, + ) + ] + mismatch_grad_selector_indices = [ + i + for i, is_grad_selector_different in enumerate(grad_selector_different) + if is_grad_selector_different + ] logger.warning( - "PT2 will recompile because the gradient selction of model parameters have changed from the previous step. Possible reasons include some gradients are None. If this is not intended, please check the data and/or model." + f"""PT2 will recompile because the gradient selction of model parameters have changed from the previous step. Possible reasons include some gradients are None. If this is not intended, please check the data and/or model. + Details: + - Current step: {state_lists[STEP].item()} + - Changed gradient selector indices: {mismatch_grad_selector_indices}""" ) + # Updates masked state lists if previous block selector disagrees with current selector. # State list compression is necessary in order to avoid handling gradients with None. 
state_lists[PREVIOUS_GRAD_SELECTOR] = state_lists[ @@ -809,13 +769,13 @@ def _mask_state_lists(state_lists: dict[str, Any], group: dict[str, Any]) -> Non local_grad_selector=state_lists[DISTRIBUTOR].local_grad_selector, ) if group[BETAS][0] != 0.0: - state_lists[MASKED_FILTERED_GRAD_LIST] = state_lists[ - FILTERED_GRAD_LIST - ].compress( + state_lists[MASKED_FILTERED_GRAD_LIST] = compress_list( + state_lists[FILTERED_GRAD_LIST], state_lists[DISTRIBUTOR].local_grad_selector, ) if group[MOMENTUM] != 0.0: - state_lists[MASKED_MOMENTUM_LIST] = state_lists[MOMENTUM_LIST].compress( + state_lists[MASKED_MOMENTUM_LIST] = compress_list( + state_lists[MOMENTUM_LIST], state_lists[DISTRIBUTOR].local_grad_selector, ) @@ -833,11 +793,9 @@ def _compute_and_log_root_inverse_residuals( for (group_index, group), state_lists in zip( enumerate(self.param_groups), self._per_group_state_lists, strict=True ): - # TODO: update values depending on both factor_matrix_dtype and inv_factor_matrix_dtype - # Get expected relative errors/residuals for debugging purposes - if group[PRECISION_CONFIG].inv_factor_matrix_dtype == torch.float64: + if group[PRECONDITIONER_DTYPE] == torch.float64: expected_relative_error = 1e-7 - elif group[PRECISION_CONFIG].inv_factor_matrix_dtype == torch.float32: + elif group[PRECONDITIONER_DTYPE] == torch.float32: expected_relative_error = 1e-3 else: logger.warning( @@ -960,34 +918,31 @@ def _compute_filtered_grad_list( use_bias_correction: bool, ) -> tuple[torch.Tensor, ...]: if beta1 != 0.0: - with DequantizeQuantizedTensorListContext( - quantized_tensor_list=state_lists[MASKED_FILTERED_GRAD_LIST] - ): - # Computes filtered gradient or EMA of the gradients with respect to beta3 if beta3 != beta1. - masked_filtered_grad_list = ( - torch._foreach_lerp( - state_lists[MASKED_FILTERED_GRAD_LIST].dequantized_value, - state_lists[MASKED_BLOCKED_GRADS], - weight=1 - beta3, - ) - if beta3 != beta1 - else state_lists[MASKED_FILTERED_GRAD_LIST].dequantized_value - ) - - # Update EMA of the gradients (with respect to beta1). - torch._foreach_lerp_( - state_lists[MASKED_FILTERED_GRAD_LIST].dequantized_value, + # Computes filtered gradient or EMA of the gradients with respect to beta3 if beta3 != beta1. + masked_filtered_grad_list = ( + torch._foreach_lerp( + state_lists[MASKED_FILTERED_GRAD_LIST], state_lists[MASKED_BLOCKED_GRADS], - weight=1 - beta1, + weight=1 - beta3, ) + if beta3 != beta1 + else state_lists[MASKED_FILTERED_GRAD_LIST] + ) - # Apply bias correction if necessary. - if use_bias_correction: - bias_correction1 = 1.0 - beta3 * beta1 ** (step - 1) - masked_filtered_grad_list = torch._foreach_div( - masked_filtered_grad_list, - bias_correction1, - ) + # Update EMA of the gradients (with respect to beta1). + torch._foreach_lerp_( + state_lists[MASKED_FILTERED_GRAD_LIST], + state_lists[MASKED_BLOCKED_GRADS], + weight=1 - beta1, + ) + + # Apply bias correction if necessary. + if use_bias_correction: + bias_correction1 = 1.0 - beta3 * beta1 ** (step - 1) + masked_filtered_grad_list = torch._foreach_div( + masked_filtered_grad_list, + bias_correction1, + ) else: masked_filtered_grad_list = state_lists[MASKED_BLOCKED_GRADS] @@ -1020,34 +975,29 @@ def _update_momentum( ) -> None: # Update momentum optimizer state and use momentum / Nesterov if enabled. 
if momentum_param != 0.0: - with DequantizeQuantizedTensorListContext( - quantized_tensor_list=state_lists[MASKED_MOMENTUM_LIST] - ): + torch._foreach_mul_(state_lists[MASKED_MOMENTUM_LIST], momentum_param) + torch._foreach_add_( + state_lists[MASKED_MOMENTUM_LIST], + masked_blocked_search_directions, + alpha=1 - dampening, + ) + + # Incorporates Nesterov momentum. + if use_nesterov: torch._foreach_mul_( - state_lists[MASKED_MOMENTUM_LIST].dequantized_value, momentum_param + masked_blocked_search_directions, + 1 - dampening, ) torch._foreach_add_( - state_lists[MASKED_MOMENTUM_LIST].dequantized_value, masked_blocked_search_directions, - alpha=1 - dampening, + state_lists[MASKED_MOMENTUM_LIST], + alpha=momentum_param, + ) + else: + torch._foreach_copy_( + masked_blocked_search_directions, + state_lists[MASKED_MOMENTUM_LIST], ) - - # Incorporates Nesterov momentum. - if use_nesterov: - torch._foreach_mul_( - masked_blocked_search_directions, - 1 - dampening, - ) - torch._foreach_add_( - masked_blocked_search_directions, - state_lists[MASKED_MOMENTUM_LIST].dequantized_value, - alpha=momentum_param, - ) - else: - torch._foreach_copy_( - masked_blocked_search_directions, - state_lists[MASKED_MOMENTUM_LIST].dequantized_value, - ) @torch.no_grad() def _per_group_step_impl( @@ -1075,62 +1025,50 @@ def _per_group_step_impl( use_decoupled_weight_decay, ) - with ( - DequantizePreconditionersContext( - preconditioner_list=state_lists[SHAMPOO_PRECONDITIONER_LIST] - ), - ( - DequantizePreconditionersContext( - preconditioner_list=state_lists[GRAFTING_PRECONDITIONER_LIST] - ) - if grafting_config_not_none - else contextlib.nullcontext() - ), - ): - # Update Shampoo and grafting preconditioners. - # Example for AdaGrad accumulation: - # 1. Update factor matrices/grafting preconditioners. - # L <- L + G * G^T - # R <- R + G^T * G - # V <- V + G^2 (element-wise) - # (and similar) - # 2. Compute root inverse if necessary. - # L_inv <- L ** (-1/4) - # R_inv <- R ** (-1/4) - # (and similar); - self._update_preconditioners( - state_lists=state_lists, - step=step, - perform_amortized_computation=perform_amortized_computation, - grafting_config_not_none=grafting_config_not_none, - ) + # Update Shampoo and grafting preconditioners. + # Example for AdaGrad accumulation: + # 1. Update factor matrices/grafting preconditioners. + # L <- L + G * G^T + # R <- R + G^T * G + # V <- V + G^2 (element-wise) + # (and similar) + # 2. Compute root inverse if necessary. + # L_inv <- L ** (-1/4) + # R_inv <- R ** (-1/4) + # (and similar); + self._update_preconditioners( + state_lists=state_lists, + step=step, + perform_amortized_computation=perform_amortized_computation, + grafting_config_not_none=grafting_config_not_none, + ) - # Compute filtered gradient or EMA of the gradients if beta1 > 0 and beta3 > 0. - # Note that we use two beta factors here akin to Lion. - # G_bar <- beta3 * G_tilde + (1 - beta3) * G - # G_tilde <- beta1 * G_tilde + (1 - beta1) * G - masked_filtered_grad_list = self._compute_filtered_grad_list( - state_lists, - step, - beta1, - beta3, - use_bias_correction, - ) + # Compute filtered gradient or EMA of the gradients if beta1 > 0 and beta3 > 0. + # Note that we use two beta factors here akin to Lion. + # G_bar <- beta3 * G_tilde + (1 - beta3) * G + # G_tilde <- beta1 * G_tilde + (1 - beta1) * G + masked_filtered_grad_list = self._compute_filtered_grad_list( + state_lists, + step, + beta1, + beta3, + use_bias_correction, + ) - # Precondition and graft filtered gradients. 
- # PT2 compile is currently disabled for preconditioning and grafting. - # NOTE: Preconditioning and grafting is not compatible with PT2 compile. - # - # P_shampoo <- L_inv * G_bar * R_inv (and similar) - # P_grafting <- G_bar / (sqrt(V) + epsilon) - # P <- P_grafting if step < start_preconditioning_step - # P <- ||P_grafting|| / ||P_shampoo|| * P_shampoo otherwise - masked_blocked_search_directions = self._precondition_and_grafting( - state_lists, - masked_filtered_grad_list, - use_grafting_method, - grafting_config_not_none, - ) + # Precondition and graft filtered gradients. + # PT2 compile is currently disabled for preconditioning and grafting. + # NOTE: Preconditioning and grafting is not compatible with PT2 compile. + # + # P_shampoo <- L_inv * G_bar * R_inv (and similar) + # P_grafting <- G_bar / (sqrt(V) + epsilon) + # P <- P_grafting if step < start_preconditioning_step + # P <- ||P_grafting|| / ||P_shampoo|| * P_shampoo otherwise + masked_blocked_search_directions = self._precondition_and_grafting( + state_lists, + masked_filtered_grad_list, + use_grafting_method, + grafting_config_not_none, + ) # Incorporate decoupled weight decay into search direction if enabled. # P <- P + weight_decay * W @@ -1189,19 +1127,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: # t if not state_lists[MASKED_BLOCKED_GRADS]: continue - # Convert the gradient dtype to the computation dtype set in the precision_config if - # necessary. - # - # This conversion is needed because the blocked gradient list has float32 dtype, and we - # need to convert it to the desired precision before precondition computation. - if ( - computation_dtype := group[PRECISION_CONFIG].computation_dtype - ) != state_lists[MASKED_BLOCKED_GRADS][0].dtype: - state_lists[MASKED_BLOCKED_GRADS] = tuple( - tensor.to(dtype=computation_dtype) - for tensor in state_lists[MASKED_BLOCKED_GRADS] - ) - # Iterate group step counter and define Python scalar step. 
step = state_lists[STEP].add_(1) # NOTE: Wrap scalar of group[LR] into a 0D tensor to avoid PT2 recompilation; diff --git a/distributed_shampoo/examples/ddp_cifar10_example.py b/distributed_shampoo/examples/ddp_cifar10_example.py index f0430bc..65360aa 100644 --- a/distributed_shampoo/examples/ddp_cifar10_example.py +++ b/distributed_shampoo/examples/ddp_cifar10_example.py @@ -13,7 +13,7 @@ import torch.distributed as dist import torch.distributed.checkpoint as dist_checkpoint -from distributed_shampoo import DDPShampooConfig, DistributedShampoo, PrecisionConfig +from distributed_shampoo import DDPShampooConfig, DistributedShampoo from distributed_shampoo.examples.trainer_utils import ( get_data_loader_and_sampler, get_model_and_loss_fn, @@ -117,23 +117,12 @@ grafting_beta2=args.grafting_beta2, grafting_epsilon=args.grafting_epsilon, use_merge_dims=args.use_merge_dims, - use_pytorch_compile=args.use_pytorch_compile, distributed_config=DDPShampooConfig( communication_dtype=args.communication_dtype, num_trainers_per_group=args.num_trainers_per_group, communicate_params=args.communicate_params, ), - precision_config=PrecisionConfig( - computation_dtype=args.computation_dtype.value, - factor_matrix_dtype=args.factor_matrix_dtype.value, - inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value, - corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value, - factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value, - filtered_grad_dtype=args.filtered_grad_dtype.value, - momentum_dtype=args.momentum_dtype.value, - grafting_state_dtype=args.grafting_state_dtype.value, - ), - use_protected_eigh=args.use_protected_eigh, + preconditioner_dtype=args.preconditioner_dtype, track_root_inv_residuals=args.track_root_inv_residuals, preconditioner_computation_type=args.preconditioner_computation_type, ) diff --git a/distributed_shampoo/examples/default_cifar10_example.py b/distributed_shampoo/examples/default_cifar10_example.py index 53d3767..8fcfbc6 100644 --- a/distributed_shampoo/examples/default_cifar10_example.py +++ b/distributed_shampoo/examples/default_cifar10_example.py @@ -11,7 +11,6 @@ import os import torch -from distributed_shampoo import PrecisionConfig from distributed_shampoo.examples.trainer_utils import ( get_data_loader_and_sampler, @@ -132,19 +131,8 @@ def train_default_model( grafting_epsilon=args.grafting_epsilon, grafting_beta2=args.grafting_beta2, use_merge_dims=args.use_merge_dims, - use_pytorch_compile=args.use_pytorch_compile, distributed_config=None, - precision_config=PrecisionConfig( - computation_dtype=args.computation_dtype.value, - factor_matrix_dtype=args.factor_matrix_dtype.value, - inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value, - corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value, - factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value, - filtered_grad_dtype=args.filtered_grad_dtype.value, - momentum_dtype=args.momentum_dtype.value, - grafting_state_dtype=args.grafting_state_dtype.value, - ), - use_protected_eigh=args.use_protected_eigh, + preconditioner_dtype=args.preconditioner_dtype, track_root_inv_residuals=args.track_root_inv_residuals, preconditioner_computation_type=args.preconditioner_computation_type, ) diff --git a/distributed_shampoo/examples/fsdp_cifar10_example.py b/distributed_shampoo/examples/fsdp_cifar10_example.py index d94c1a7..cfedf46 100644 --- a/distributed_shampoo/examples/fsdp_cifar10_example.py +++ b/distributed_shampoo/examples/fsdp_cifar10_example.py @@ -12,11 +12,7 
@@ import torch.distributed as dist -from distributed_shampoo import ( - compile_fsdp_parameter_metadata, - FSDPShampooConfig, - PrecisionConfig, -) +from distributed_shampoo import compile_fsdp_parameter_metadata, FSDPShampooConfig from distributed_shampoo.examples.trainer_utils import ( get_data_loader_and_sampler, get_model_and_loss_fn, @@ -115,21 +111,10 @@ grafting_epsilon=args.grafting_epsilon, grafting_beta2=args.grafting_beta2, use_merge_dims=args.use_merge_dims, - use_pytorch_compile=args.use_pytorch_compile, distributed_config=FSDPShampooConfig( param_to_metadata=compile_fsdp_parameter_metadata(model), ), - precision_config=PrecisionConfig( - computation_dtype=args.computation_dtype.value, - factor_matrix_dtype=args.factor_matrix_dtype.value, - inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value, - corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value, - factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value, - filtered_grad_dtype=args.filtered_grad_dtype.value, - momentum_dtype=args.momentum_dtype.value, - grafting_state_dtype=args.grafting_state_dtype.value, - ), - use_protected_eigh=args.use_protected_eigh, + preconditioner_dtype=args.preconditioner_dtype, track_root_inv_residuals=args.track_root_inv_residuals, preconditioner_computation_type=args.preconditioner_computation_type, ) diff --git a/distributed_shampoo/examples/fully_shard_cifar10_example.py b/distributed_shampoo/examples/fully_shard_cifar10_example.py index 7f1554a..7cae037 100644 --- a/distributed_shampoo/examples/fully_shard_cifar10_example.py +++ b/distributed_shampoo/examples/fully_shard_cifar10_example.py @@ -14,11 +14,7 @@ import torch.distributed as dist import torch.distributed.checkpoint as dist_checkpoint -from distributed_shampoo import ( - DistributedShampoo, - FullyShardShampooConfig, - PrecisionConfig, -) +from distributed_shampoo import DistributedShampoo, FullyShardShampooConfig from distributed_shampoo.examples.trainer_utils import ( get_data_loader_and_sampler, get_model_and_loss_fn, @@ -137,19 +133,8 @@ def create_model_and_optimizer_and_loss_fn(args, device): grafting_epsilon=args.grafting_epsilon, grafting_beta2=args.grafting_beta2, use_merge_dims=args.use_merge_dims, - use_pytorch_compile=args.use_pytorch_compile, distributed_config=FullyShardShampooConfig(), - precision_config=PrecisionConfig( - computation_dtype=args.computation_dtype.value, - factor_matrix_dtype=args.factor_matrix_dtype.value, - inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value, - corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value, - factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value, - filtered_grad_dtype=args.filtered_grad_dtype.value, - momentum_dtype=args.momentum_dtype.value, - grafting_state_dtype=args.grafting_state_dtype.value, - ), - use_protected_eigh=args.use_protected_eigh, + preconditioner_dtype=args.preconditioner_dtype, track_root_inv_residuals=args.track_root_inv_residuals, preconditioner_computation_type=args.preconditioner_computation_type, ) diff --git a/distributed_shampoo/examples/hsdp_cifar10_example.py b/distributed_shampoo/examples/hsdp_cifar10_example.py index 32cea95..f2403bb 100644 --- a/distributed_shampoo/examples/hsdp_cifar10_example.py +++ b/distributed_shampoo/examples/hsdp_cifar10_example.py @@ -12,11 +12,7 @@ import torch.distributed as dist -from distributed_shampoo import ( - compile_fsdp_parameter_metadata, - HSDPShampooConfig, - PrecisionConfig, -) +from distributed_shampoo import 
compile_fsdp_parameter_metadata, HSDPShampooConfig from distributed_shampoo.examples.trainer_utils import ( get_data_loader_and_sampler, @@ -128,23 +124,12 @@ grafting_epsilon=args.grafting_epsilon, grafting_beta2=args.grafting_beta2, use_merge_dims=args.use_merge_dims, - use_pytorch_compile=args.use_pytorch_compile, distributed_config=HSDPShampooConfig( param_to_metadata=compile_fsdp_parameter_metadata(model), device_mesh=device_mesh, num_trainers_per_group=args.num_trainers_per_group, ), - precision_config=PrecisionConfig( - computation_dtype=args.computation_dtype.value, - factor_matrix_dtype=args.factor_matrix_dtype.value, - inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value, - corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value, - factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value, - filtered_grad_dtype=args.filtered_grad_dtype.value, - momentum_dtype=args.momentum_dtype.value, - grafting_state_dtype=args.grafting_state_dtype.value, - ), - use_protected_eigh=args.use_protected_eigh, + preconditioner_dtype=args.preconditioner_dtype, track_root_inv_residuals=args.track_root_inv_residuals, preconditioner_computation_type=args.preconditioner_computation_type, ) diff --git a/distributed_shampoo/examples/trainer_utils.py b/distributed_shampoo/examples/trainer_utils.py index 313d980..eb84fd6 100644 --- a/distributed_shampoo/examples/trainer_utils.py +++ b/distributed_shampoo/examples/trainer_utils.py @@ -30,7 +30,6 @@ DistributedConfig, DistributedShampoo, GraftingConfig, - PrecisionConfig, PreconditionerConfig, RMSpropGraftingConfig, SGDGraftingConfig, @@ -195,16 +194,6 @@ def get_args(): action="store_true", help="Use merge dims for Shampoo.", ) - parser.add_argument( - "--use-pytorch-compile", - action="store_true", - help="Use PyTorch compile for Shampoo.", - ) - parser.add_argument( - "--use-protected-eigh", - action="store_true", - help="Uses protected eigendecomposition.", - ) parser.add_argument( "--track-root-inv-residuals", action="store_true", @@ -239,52 +228,10 @@ def get_args(): # Arguments for mixed-precision. 
parser.add_argument( - "--computation-dtype", - type=lambda t: enum_type_parse(t, DType), - default=DType.FP32, - help="Data type for all computation in Shampoo.", - ) - parser.add_argument( - "--factor-matrix-dtype", - type=lambda t: enum_type_parse(t, DType), - default=DType.FP32, - help="Data type for storing Shampoo factor matrices.", - ) - parser.add_argument( - "--inv-factor-matrix-dtype", - type=lambda t: enum_type_parse(t, DType), - default=DType.FP32, - help="Data type for storing Shampoo inverse factor matrices.", - ) - parser.add_argument( - "--corrected-eigenvalues-dtype", - type=lambda t: enum_type_parse(t, DType), - default=DType.FP32, - help="Data type for storing corrected eigenvalues of Shampoo preconditioner.", - ) - parser.add_argument( - "--factor-matrix-eigenvectors-dtype", - type=lambda t: enum_type_parse(t, DType), - default=DType.FP32, - help="Data type for storing Shampoo factor matrices eigenvectors.", - ) - parser.add_argument( - "--filtered-grad-dtype", - type=lambda t: enum_type_parse(t, DType), - default=DType.FP32, - help="Data type for storing filtered gradients.", - ) - parser.add_argument( - "--momentum-dtype", - type=lambda t: enum_type_parse(t, DType), - default=DType.FP32, - help="Data type for storing momentum states.", - ) - parser.add_argument( - "--grafting-state-dtype", + "--preconditioner-dtype", type=lambda t: enum_type_parse(t, DType), default=DType.FP32, - help="Data type for storing grafting preconditioners.", + help="Preconditioner dtype for Shampoo.", ) # Arguments for DDP Shampoo. @@ -438,10 +385,8 @@ def instantiate_optimizer( grafting_beta2: float, grafting_epsilon: float, use_merge_dims: bool, - use_pytorch_compile: bool, distributed_config: DistributedConfig | None, - precision_config: PrecisionConfig | None, - use_protected_eigh: bool, + preconditioner_dtype: DType, track_root_inv_residuals: bool, preconditioner_computation_type: PreconditionerComputationType, ) -> torch.optim.Optimizer: @@ -493,10 +438,8 @@ def instantiate_optimizer( grafting_type, grafting_beta2, grafting_epsilon ), use_merge_dims=use_merge_dims, - use_pytorch_compile=use_pytorch_compile, distributed_config=distributed_config, - precision_config=precision_config, - use_protected_eigh=use_protected_eigh, + preconditioner_dtype=preconditioner_dtype.value, track_root_inv_residuals=track_root_inv_residuals, preconditioner_config=instantiate_preconditioner_config( preconditioner_computation_type diff --git a/distributed_shampoo/gpu_tests/shampoo_pt2_test.py b/distributed_shampoo/gpu_tests/shampoo_pt2_test.py index 1f2e29d..debbf3f 100644 --- a/distributed_shampoo/gpu_tests/shampoo_pt2_test.py +++ b/distributed_shampoo/gpu_tests/shampoo_pt2_test.py @@ -42,7 +42,8 @@ def _shampoo_optim_factory( parameters, lr=0.01, betas=betas, - beta3=betas[0] * betas[0], + # TODO: comment out beta3 to unblock quantization changes; need to fix PT2 FMA changes for this test + # beta3=betas[0] * betas[0], epsilon=1e-10, momentum=0.9, dampening=0.9, diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 98fa544..a3f7100 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -41,7 +41,6 @@ LR = "lr" MAX_PRECONDITIONER_DIM = "max_preconditioner_dim" PARAMS = "params" # While this is stored in groups by default, we do not checkpoint this quantity. 
-PRECISION_CONFIG = "precision_config" PRECONDITION_FREQUENCY = "precondition_frequency" PRECONDITIONER_DTYPE = "preconditioner_dtype" PRECONDITIONER_CONFIG = "preconditioner_config" @@ -89,7 +88,7 @@ class PreconditionerConfig(AbstractDataclass): """ - amortized_computation_config: MatrixFunctionConfig + amortized_computation_config: MatrixFunctionConfig # type: ignore @dataclass(kw_only=True) @@ -154,42 +153,6 @@ class FSDPParameterMetadata: sharding_strategy: ShardingStrategy -@dataclass -class PrecisionConfig: - """Configuration for precision of each optimizer state. - - TODO: allow more specific computation dtypes that only apply to some computations - - Args: - computation_dtype (torch.dtype): Data type that all computation is performed in, except factor matrices (see factor_matrix_computation_dtype). (Default: torch.float32) - factor_matrix_dtype (torch.dtype): Data type for storing Shampoo factor matrices. (Default: torch.float32) - inv_factor_matrix_dtype (torch.dtype): Data type for storing Shampoo inverse factor matrices. (Default: torch.float32) - factor_matrix_computation_dtype (torch.dtype): Data type for accumulating factor matrices and computing their inverses. (Default: torch.float32) - corrected_eigenvalues_dtype (torch.dtype): Data type for storing the corrected eigenvalues of Shampoo preconditioner (EMA). (Default: torch.float32) - factor_matrix_eigenvectors_dtype (torch.dtype): Data type for storing the eigenvectors of Shampoo factor matrices. (Default: torch.float32) - filtered_grad_dtype (torch.dtype): Data type for storing filtered gradients (EMA). (Default: torch.float32) - momentum_dtype (torch.dtype): Data type for storing momentum states. (Default: torch.float32) - grafting_state_dtype (torch.dtype): Data type for storing grafting preconditioners, if applicable. (Default: torch.float32) - Current applicable grafting configs: - - AdaGradGraftingConfig - - RMSpropGraftingConfig - - AdamGraftingConfig - NOT applicable configs: - - SGDGraftingConfig - - None (i.e. no grafting) - """ - - computation_dtype: torch.dtype = torch.float32 - factor_matrix_dtype: torch.dtype = torch.float32 - inv_factor_matrix_dtype: torch.dtype = torch.float32 - corrected_eigenvalues_dtype: torch.dtype = torch.float32 - factor_matrix_eigenvectors_dtype: torch.dtype = torch.float32 - factor_matrix_computation_dtype: torch.dtype = torch.float32 - filtered_grad_dtype: torch.dtype = torch.float32 - momentum_dtype: torch.dtype = torch.float32 - grafting_state_dtype: torch.dtype = torch.float32 - - @dataclass(init=False) class DistributedConfig(AbstractDataclass): """Abstract dataclass for distributed configs in Shampoo.""" @@ -229,6 +192,29 @@ class FSDPShampooConfig(DistributedConfig): param_to_metadata: dict[Parameter, FSDPParameterMetadata] +@dataclass +class HSDPShampooConfig(FSDPShampooConfig, DDPShampooConfig): + """Configuration for HSDP Shampoo. + + Enables distributed computation and optimizer states (like ZeRO-1) via DTensor for Shampoo across ranks with shared + parameters between different HSDP process groups. + + Args: + device_mesh (torch.distributed.device_mesh.DeviceMesh): A 2D device mesh that specifies the layout of the numbers of + shard and replicate dimensions. + param_to_metadata (dict[Parameter, FSDPParameterMetadata]): Dictionary mapping parameter to its metadata from HSDP. + communication_dtype (CommunicationDType): Data type for communication between ranks. 
(Default: DEFAULT) + num_trainers_per_group (int): Number of GPUs per distributed process group for distributed computation/memory. + If num_trainers_per_group = -1 is used, then defaults to using the number of workers in each replicated HSDP + group. (Default: -1) + communicate_params (bool): Flag for all-gathering updated params across multiple workers. + If False, all-gathers parameter updates across multiple workers. (Default: False) + + """ + + device_mesh: DeviceMesh + + @dataclass(kw_only=True) class FullyShardShampooConfig(DistributedConfig): """Configuration for FullyShard (per-parameter FSDP) Shampoo. @@ -238,16 +224,14 @@ class FullyShardShampooConfig(DistributedConfig): @dataclass -class HSDPShampooConfig(FSDPShampooConfig, DDPShampooConfig): - """Configuration for HSDP Shampoo. +class HybridShardShampooConfig(FullyShardShampooConfig, DDPShampooConfig): + """Configuration for HybridShard (per-parameter FSDP) Shampoo. Enables distributed computation and optimizer states (like ZeRO-1) via DTensor for Shampoo across ranks with shared - parameters between different HSDP process groups. + parameters between different Hybrid Shard process groups. Args: - device_mesh (torch.distributed.device_mesh.DeviceMesh): A 2D device mesh that specifies the layout of the numbers of - shard and replicate dimensions. - param_to_metadata (dict[Parameter, FSDPParameterMetadata]): Dictionary mapping parameter to its metadata from HSDP. + device_mesh (torch.distributed.device_mesh.DeviceMesh): Device mesh for Hybrid Shard. communication_dtype (CommunicationDType): Data type for communication between ranks. (Default: DEFAULT) num_trainers_per_group (int): Number of GPUs per distributed process group for distributed computation/memory. If num_trainers_per_group = -1 is used, then defaults to using the number of workers in each replicated HSDP diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index 92eb0aa..b861a61 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -25,18 +25,12 @@ DefaultEigenvalueCorrectedShampooConfig, DefaultShampooConfig, DistributedConfig, - GRAFTING_PRECONDITIONER_LIST, GraftingConfig, - MASKED_FILTERED_GRAD_LIST, - MASKED_MOMENTUM_LIST, - PrecisionConfig, PreconditionerConfig, SGDGraftingConfig, - SHAMPOO_PRECONDITIONER_LIST, ShampooPreconditionerConfig, ShampooPT2CompileConfig, ) -from distributed_shampoo.utils.shampoo_quantization import QuantizedTensorList from torch import nn @@ -162,63 +156,13 @@ def test_invalid_with_incorrect_hyperparameter_setting(self) -> None: **incorrect_hyperparameter_setting, ) - def test_invalid_pytorch_compile_setting(self) -> None: - with ( - mock.patch.object(torch.cuda, "is_available", return_value=False), - self.assertRaisesRegex( - ValueError, - re.escape( - "Both use_pytorch_compile and shampoo_pt2_compile_config are provided. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating." - ), - ), - ): - DistributedShampoo( - self._model.parameters(), - use_pytorch_compile=True, - shampoo_pt2_compile_config=ShampooPT2CompileConfig(), - ) - - with ( - mock.patch.object(torch.cuda, "is_available", return_value=False), - self.assertRaisesRegex( - ValueError, - re.escape( - "use_pytorch_compile=False conflicts with non-None shampoo_pt2_compile_config arg. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating." 
- ), - ), - ): - DistributedShampoo( - self._model.parameters(), - use_pytorch_compile=False, - shampoo_pt2_compile_config=ShampooPT2CompileConfig(), - ) - - def test_warning_pytorch_compile_setting(self) -> None: - with ( - mock.patch.object(torch.cuda, "is_available", return_value=True), - self.assertLogs( - level="WARNING", - ) as cm, - ): - DistributedShampoo( - self._model.parameters(), - lr=0.01, - use_pytorch_compile=True, - shampoo_pt2_compile_config=None, - ) - - self.assertIn( - "use_pytorch_compile is deprecating. Please use shampoo_pt2_compile_config instead.", - [r.msg for r in cm.records], - ) - def test_invalid_cuda_pytorch_compile_setting(self) -> None: with ( mock.patch.object(torch.cuda, "is_available", return_value=False), self.assertRaisesRegex( ValueError, re.escape( - "Backend does NOT support Pytorch 2.0 compile. Switch to use_pytorch_compile in (False, None) and shampoo_pt2_compile_config=None." + "Backend does NOT support Pytorch 2.0 compile. Switch to shampoo_pt2_compile_config=None." ), ), ): @@ -373,7 +317,7 @@ def setUp(self) -> None: "state": { "0.weight": { '["step"]': torch.tensor(0), - '["block_0", "shampoo", "factor_matrices", 0, "quantized_values"]': torch.tensor( + '["block_0", "shampoo", "factor_matrices", 0]': torch.tensor( [ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], @@ -382,7 +326,7 @@ def setUp(self) -> None: [0.0, 0.0, 0.0, 0.0, 0.0], ] ), - '["block_0", "shampoo", "factor_matrices", 1, "quantized_values"]': torch.tensor( + '["block_0", "shampoo", "factor_matrices", 1]': torch.tensor( [ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], @@ -391,7 +335,7 @@ def setUp(self) -> None: [0.0, 0.0, 0.0, 0.0, 0.0], ] ), - '["block_0", "shampoo", "inv_factor_matrices", 0, "quantized_values"]': torch.tensor( + '["block_0", "shampoo", "inv_factor_matrices", 0]': torch.tensor( [ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], @@ -400,7 +344,7 @@ def setUp(self) -> None: [0.0, 0.0, 0.0, 0.0, 0.0], ] ), - '["block_0", "shampoo", "inv_factor_matrices", 1, "quantized_values"]': torch.tensor( + '["block_0", "shampoo", "inv_factor_matrices", 1]': torch.tensor( [ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], @@ -415,7 +359,7 @@ def setUp(self) -> None: '["block_0", "shampoo", "is_factor_matrices_diagonal", 1]': torch.tensor( True ), - '["block_1", "shampoo", "factor_matrices", 0, "quantized_values"]': torch.tensor( + '["block_1", "shampoo", "factor_matrices", 0]': torch.tensor( [ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], @@ -424,7 +368,7 @@ def setUp(self) -> None: [0.0, 0.0, 0.0, 0.0, 0.0], ] ), - '["block_1", "shampoo", "factor_matrices", 1, "quantized_values"]': torch.tensor( + '["block_1", "shampoo", "factor_matrices", 1]': torch.tensor( [ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], @@ -433,7 +377,7 @@ def setUp(self) -> None: [0.0, 0.0, 0.0, 0.0, 0.0], ] ), - '["block_1", "shampoo", "inv_factor_matrices", 0, "quantized_values"]': torch.tensor( + '["block_1", "shampoo", "inv_factor_matrices", 0]': torch.tensor( [ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], @@ -442,7 +386,7 @@ def setUp(self) -> None: [0.0, 0.0, 0.0, 0.0, 0.0], ] ), - '["block_1", "shampoo", "inv_factor_matrices", 1, "quantized_values"]': torch.tensor( + '["block_1", "shampoo", "inv_factor_matrices", 1]': torch.tensor( [ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], @@ -457,7 +401,7 @@ def setUp(self) -> None: '["block_1", "shampoo", "is_factor_matrices_diagonal", 1]': torch.tensor( True ), - '["block_0", "adagrad", 
"quantized_values"]': torch.tensor( + '["block_0", "adagrad"]': torch.tensor( [ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], @@ -466,7 +410,7 @@ def setUp(self) -> None: [0.0, 0.0, 0.0, 0.0, 0.0], ] ), - '["block_1", "adagrad", "quantized_values"]': torch.tensor( + '["block_1", "adagrad"]': torch.tensor( [ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], @@ -475,7 +419,7 @@ def setUp(self) -> None: [0.0, 0.0, 0.0, 0.0, 0.0], ] ), - '["block_0", "filtered_grad", "quantized_values"]': torch.tensor( + '["block_0", "filtered_grad"]': torch.tensor( [ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], @@ -484,7 +428,7 @@ def setUp(self) -> None: [0.0, 0.0, 0.0, 0.0, 0.0], ] ), - '["block_1", "filtered_grad", "quantized_values"]': torch.tensor( + '["block_1", "filtered_grad"]': torch.tensor( [ [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], @@ -515,7 +459,7 @@ def setUp(self) -> None: epsilon=0.001, ), "use_merge_dims": True, - "precision_config": PrecisionConfig(), + "preconditioner_dtype": torch.float32, "preconditioner_config": DefaultShampooConfig, } }, @@ -695,15 +639,7 @@ def _get_track_root_inverse_residuals_output(self, dtype: torch.dtype) -> list[s params=model.parameters(), precondition_frequency=2, start_preconditioning_step=2, - precision_config=PrecisionConfig( - computation_dtype=dtype, - factor_matrix_dtype=dtype, - inv_factor_matrix_dtype=dtype, - factor_matrix_computation_dtype=dtype, - filtered_grad_dtype=dtype, - momentum_dtype=dtype, - grafting_state_dtype=dtype, - ), + preconditioner_dtype=dtype, track_root_inv_residuals=True, ) @@ -740,259 +676,6 @@ def test_compute_and_log_root_inverse_residuals(self) -> None: ) -class DistributedShampooPrecisionTest(unittest.TestCase): - def setUp(self) -> None: - self._model = nn.Sequential( - nn.Linear(5, 10, bias=False), - ) - self._x = torch.randn(5) - self._y = torch.randn(10) - - def _instantiate_optimizer( - self, precision_config: PrecisionConfig - ) -> DistributedShampoo: - return DistributedShampoo( - self._model.parameters(), - lr=0.01, - betas=(0.9, 1.0), - epsilon=1e-12, - momentum=0.99, - weight_decay=0.0, - max_preconditioner_dim=5, - precondition_frequency=1, - start_preconditioning_step=1, - distributed_config=None, - grafting_config=AdaGradGraftingConfig( - epsilon=0.001, - ), - precision_config=precision_config, - ) - - def _assert_equal_state_dtype( - self, - quantized_tensor_list: QuantizedTensorList, - computation_dtype: torch.dtype, - quantized_dtype: torch.dtype, - ) -> None: - self.assertEqual(quantized_tensor_list.computation_dtype, computation_dtype) - self.assertEqual(quantized_tensor_list.quantized_dtype, quantized_dtype) - self.assertIsNone(quantized_tensor_list.dequantized_value_list) - - def _assert_state_list_dtype( - self, state_list: dict[str, Any], precision_config: PrecisionConfig - ) -> None: - # TODO: is it possible to avoid accessing private field _masked_kronecker_factors_list? 
- for kronecker_factor in state_list[ - SHAMPOO_PRECONDITIONER_LIST - ]._masked_kronecker_factors_list: - self._assert_equal_state_dtype( - kronecker_factor.factor_matrices, - precision_config.factor_matrix_computation_dtype, - precision_config.factor_matrix_dtype, - ) - self._assert_equal_state_dtype( - kronecker_factor.inv_factor_matrices, - precision_config.computation_dtype, - precision_config.inv_factor_matrix_dtype, - ) - self._assert_equal_state_dtype( - state_list[GRAFTING_PRECONDITIONER_LIST]._masked_preconditioner_list, - precision_config.computation_dtype, - precision_config.grafting_state_dtype, - ) - self._assert_equal_state_dtype( - state_list[MASKED_FILTERED_GRAD_LIST], - precision_config.computation_dtype, - precision_config.filtered_grad_dtype, - ) - self._assert_equal_state_dtype( - state_list[MASKED_MOMENTUM_LIST], - precision_config.computation_dtype, - precision_config.momentum_dtype, - ) - - def test_precision_configs(self) -> None: - precision_configs = [ - PrecisionConfig(computation_dtype=torch.float16), - PrecisionConfig(factor_matrix_dtype=torch.float16), - PrecisionConfig(inv_factor_matrix_dtype=torch.float16), - PrecisionConfig(filtered_grad_dtype=torch.float16), - PrecisionConfig(momentum_dtype=torch.float16), - PrecisionConfig(grafting_state_dtype=torch.float16), - PrecisionConfig( - factor_matrix_dtype=torch.float16, inv_factor_matrix_dtype=torch.float16 - ), - PrecisionConfig( - factor_matrix_dtype=torch.float16, - inv_factor_matrix_dtype=torch.float16, - grafting_state_dtype=torch.float16, - filtered_grad_dtype=torch.float16, - momentum_dtype=torch.float16, - ), - PrecisionConfig( - factor_matrix_dtype=torch.float64, - inv_factor_matrix_dtype=torch.float64, - filtered_grad_dtype=torch.float64, - computation_dtype=torch.float64, - ), - PrecisionConfig(factor_matrix_computation_dtype=torch.float64), - PrecisionConfig( - factor_matrix_dtype=torch.float64, - inv_factor_matrix_dtype=torch.float64, - factor_matrix_computation_dtype=torch.float64, - ), - ] - - for precision_config in precision_configs: - with self.subTest(precision_config=precision_config): - optimizer = self._instantiate_optimizer( - precision_config=precision_config - ) - for state_list in optimizer._per_group_state_lists: - self._assert_state_list_dtype(state_list, precision_config) - - for _ in range(2): - optimizer.zero_grad() - y_hat = self._model(self._x) - loss = torch.nn.CrossEntropyLoss()(y_hat, self._y) - loss.backward() - optimizer.step() - for state_list in optimizer._per_group_state_lists: - self._assert_state_list_dtype(state_list, precision_config) - - def test_setting_both_preconditioner_dtype_and_precision_config(self) -> None: - with self.assertRaisesRegex( - ValueError, - re.escape( - "Both preconditioner_dtype and precision_config are provided. Please use only precision_config as preconditioner_dtype is deprecated." - ), - ): - DistributedShampoo( - self._model.parameters(), - lr=0.01, - preconditioner_dtype=torch.float16, - precision_config=PrecisionConfig(), - ) - - def test_setting_preconditioner_dtype_only(self) -> None: - with self.assertLogs( - level="WARNING", - ) as cm: - DistributedShampoo( - self._model.parameters(), - lr=0.01, - preconditioner_dtype=torch.float16, - precision_config=None, - ) - - self.assertIn( - "preconditioner_dtype is deprecated. 
Please use precision_config instead.", - [r.msg for r in cm.records], - ) - - -class EigenvalueCorrectedDistributedShampooPrecisionTest( - DistributedShampooPrecisionTest -): - def _instantiate_optimizer( - self, precision_config: PrecisionConfig - ) -> DistributedShampoo: - return DistributedShampoo( - self._model.parameters(), - lr=0.01, - betas=(0.9, 1.0), - epsilon=1e-12, - momentum=0.99, - weight_decay=0.0, - max_preconditioner_dim=5, - precondition_frequency=1, - start_preconditioning_step=1, - distributed_config=None, - grafting_config=None, - precision_config=precision_config, - preconditioner_config=DefaultEigenvalueCorrectedShampooConfig, - ) - - def _assert_state_list_dtype( - self, state_list: dict[str, Any], precision_config: PrecisionConfig - ) -> None: - # TODO: is it possible to avoid accessing private field _masked_kronecker_factors_list? - for kronecker_factor in state_list[ - SHAMPOO_PRECONDITIONER_LIST - ]._masked_kronecker_factors_list: - self._assert_equal_state_dtype( - kronecker_factor.factor_matrices, - precision_config.factor_matrix_computation_dtype, - precision_config.factor_matrix_dtype, - ) - self._assert_equal_state_dtype( - kronecker_factor.factor_matrices_eigenvectors, - precision_config.computation_dtype, - precision_config.factor_matrix_eigenvectors_dtype, - ) - self._assert_equal_state_dtype( - kronecker_factor.corrected_eigenvalues, - precision_config.computation_dtype, - precision_config.corrected_eigenvalues_dtype, - ) - self._assert_equal_state_dtype( - state_list[MASKED_FILTERED_GRAD_LIST], - precision_config.computation_dtype, - precision_config.filtered_grad_dtype, - ) - self._assert_equal_state_dtype( - state_list[MASKED_MOMENTUM_LIST], - precision_config.computation_dtype, - precision_config.momentum_dtype, - ) - - def test_precision_configs(self) -> None: - precision_configs = [ - PrecisionConfig(computation_dtype=torch.float16), - PrecisionConfig(factor_matrix_dtype=torch.float16), - PrecisionConfig(factor_matrix_eigenvectors_dtype=torch.float16), - PrecisionConfig(corrected_eigenvalues_dtype=torch.float16), - PrecisionConfig(filtered_grad_dtype=torch.float16), - PrecisionConfig(momentum_dtype=torch.float16), - PrecisionConfig( - factor_matrix_dtype=torch.float16, - factor_matrix_eigenvectors_dtype=torch.float16, - ), - PrecisionConfig( - factor_matrix_dtype=torch.float16, - factor_matrix_eigenvectors_dtype=torch.float16, - corrected_eigenvalues_dtype=torch.float16, - ), - PrecisionConfig( - factor_matrix_dtype=torch.float16, - factor_matrix_eigenvectors_dtype=torch.float16, - corrected_eigenvalues_dtype=torch.float16, - filtered_grad_dtype=torch.float16, - momentum_dtype=torch.float16, - ), - PrecisionConfig(factor_matrix_computation_dtype=torch.float64), - PrecisionConfig( - factor_matrix_dtype=torch.float64, - factor_matrix_eigenvectors_dtype=torch.float16, - corrected_eigenvalues_dtype=torch.float64, - factor_matrix_computation_dtype=torch.float64, - ), - ] - - for precision_config in precision_configs: - with self.subTest(precision_config=precision_config): - optimizer = self._instantiate_optimizer( - precision_config=precision_config - ) - for state_list in optimizer._per_group_state_lists: - self._assert_state_list_dtype(state_list, precision_config) - - for _ in range(2): - optimizer.step() - for state_list in optimizer._per_group_state_lists: - self._assert_state_list_dtype(state_list, precision_config) - - class DistributedShampooNoneGradTest(unittest.TestCase): def setUp(self) -> None: self._model = nn.Sequential( @@ -1015,23 
+698,26 @@ def setUp(self) -> None: def test_step_with_consistent_grads(self) -> None: with self.assertNoLogs(level="WARNING"): + self._optimizer.zero_grad() self._model[0].weight.grad = torch.rand(10, 5) self._optimizer.step() + + self._optimizer.zero_grad() self._model[0].weight.grad = torch.rand(10, 5) self._optimizer.step() def test_step_with_none_grads(self) -> None: - expected_msgs = [ - "PT2 will recompile because the gradient selction of model parameters have changed from the previous step. Possible reasons include some gradients are None. If this is not intended, please check the data and/or model.", - ] + expected_msg = "PT2 will recompile because the gradient selction of model parameters have changed from the previous step. Possible reasons include some gradients are None. If this is not intended, please check the data and/or model." + ending_msg = "Changed gradient selector indices: [0, 1]" with self.assertLogs(level="WARNING") as cm: + self._optimizer.zero_grad() self._model[0].weight.grad = torch.rand(10, 5) self._optimizer.step() - self._model[0].weight.grad = None # set grad=None in second step + + self._optimizer.zero_grad() # Implicitly set grad=None in second step self._optimizer.step() msgs = [r.msg for r in cm.records] - self.assertEqual( - msgs, - expected_msgs, - ) + self.assertEqual(len(msgs), 1) + self.assertIn(expected_msg, msgs[0]) + self.assertIn(ending_msg, msgs[0]) diff --git a/distributed_shampoo/tests/shampoo_types_test.py b/distributed_shampoo/tests/shampoo_types_test.py index 2ec773f..49d8a52 100644 --- a/distributed_shampoo/tests/shampoo_types_test.py +++ b/distributed_shampoo/tests/shampoo_types_test.py @@ -22,12 +22,11 @@ class AdaGradGraftingConfigTest(unittest.TestCase): def test_illegal_epsilon(self) -> None: epsilon = 0.0 grafting_config_type = self._get_grafting_config_type() - with ( - self.subTest(grafting_config_type=grafting_config_type), - self.assertRaisesRegex( - ValueError, - re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."), - ), + with self.subTest( + grafting_config_type=grafting_config_type + ), self.assertRaisesRegex( + ValueError, + re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."), ): grafting_config_type(epsilon=epsilon) @@ -47,13 +46,12 @@ def test_illegal_beta2( ) -> None: grafting_config_type = self._get_grafting_config_type() for beta2 in (-1.0, 0.0, 1.3): - with ( - self.subTest(grafting_config_type=grafting_config_type, beta2=beta2), - self.assertRaisesRegex( - ValueError, - re.escape( - f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]." - ), + with self.subTest( + grafting_config_type=grafting_config_type, beta2=beta2 + ), self.assertRaisesRegex( + ValueError, + re.escape( + f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]." 
), ): grafting_config_type(beta2=beta2) diff --git a/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py b/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py index fa58208..f9fd65b 100644 --- a/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py +++ b/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py @@ -313,12 +313,11 @@ def test_dist_is_initialized(self) -> None: device_mesh=mesh_2d, ) - with ( - mock.patch.object(torch.distributed, "is_initialized", return_value=False), - self.assertRaisesRegex( - RuntimeError, - re.escape("HSDPDistributor needs torch.distributed to be initialized!"), - ), + with mock.patch.object( + torch.distributed, "is_initialized", return_value=False + ), self.assertRaisesRegex( + RuntimeError, + re.escape("HSDPDistributor needs torch.distributed to be initialized!"), ): ShampooHSDPDistributorTest._train_model( optim_factory=ShampooHSDPDistributorTest._shampoo_optim_factory( @@ -340,15 +339,12 @@ def test_incompatible_replicated_group_size_and_num_trainers_per_group( ) # Hijack the DeviceMesh.size() method to return 4 instead of 2 to bypass the check of num_trainers_per_group. - with ( - mock.patch.object( - torch.distributed.device_mesh.DeviceMesh, "size", return_value=4 - ), - self.assertRaisesRegex( - ValueError, - re.escape( - "distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!" - ), + with mock.patch.object( + torch.distributed.device_mesh.DeviceMesh, "size", return_value=4 + ), self.assertRaisesRegex( + ValueError, + re.escape( + "distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!" ), ): ShampooHSDPDistributorTest._train_model( diff --git a/distributed_shampoo/utils/gpu_tests/shampoo_hybrid_shard_distributor_test.py b/distributed_shampoo/utils/gpu_tests/shampoo_hybrid_shard_distributor_test.py new file mode 100644 index 0000000..9e3e002 --- /dev/null +++ b/distributed_shampoo/utils/gpu_tests/shampoo_hybrid_shard_distributor_test.py @@ -0,0 +1,407 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. 
+ +""" + +#!/usr/bin/env python3 + +import re +import unittest +from collections.abc import Callable +from functools import partial +from itertools import product +from unittest import mock + +import torch +from distributed_shampoo.distributed_shampoo import DistributedShampoo +from distributed_shampoo.shampoo_types import ( + AdaGradGraftingConfig, + CommunicationDType, + HybridShardShampooConfig, +) +from distributed_shampoo.tests.shampoo_test_utils import construct_training_problem +from distributed_shampoo.utils.shampoo_preconditioner_list import SHAMPOO + +from torch import nn +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.checkpoint._nested_dict import flatten_state_dict +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor import DTensor +from torch.optim.optimizer import ParamsT +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + + +@unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available") +class ShampooHybridShardDistributorTest(DTensorTestBase): + @property + def world_size(self) -> int: + return 4 + + @property + def backend(self) -> str: + return "cpu:gloo,cuda:nccl" + + @staticmethod + def _construct_model( + device: torch.device, + distributed_config: HybridShardShampooConfig | None, + ) -> tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor, bool]: + IN_DIM = 16 + data = torch.arange(IN_DIM, dtype=torch.float, device=device) + data /= torch.norm(data) + # NOTE: We construct the model here specifically in order to ensure that + # FullyShard Shampoo and default Shampoo produce equivalent results. + # This requires us to construct a model such that FullyShard will split the + # parameters such that the preconditioners created between the FullyShard + # and default Shampoo are equivalent. + # +----------------+ + # | [4, 16] | + # | GPU0 | + # -------------------- +------+ + # | [4, 16] | |[4, 4]| + # | GPU1 | | | + # +----------------+ +------+ + # For the first linear layer, each GPU has a [4, 16] parameter. The blocked + # parameters are of size [4, 4] and each GPU has four local blocks (eight + # blocks in total). In comparison, with default shampoo, the eight blocks + # are replicated on two GPUs. + # Similarly, the second linear layer has a [1, 8] parameter and is split + # into two [4] chunks. + + model_linear_layers_dims = (IN_DIM, 8, 1) + # model dead layers won't parpicipate in the training and thus don't have grads. + model_dead_layer_dims = (4, 1) + model, loss, data, target = construct_training_problem( + model_linear_layers_dims=model_linear_layers_dims, + model_dead_layer_dims=model_dead_layer_dims, + device=device, + fill=0.1, + ) + + if uses_hybrid_shard := isinstance( + distributed_config, HybridShardShampooConfig + ): + # Need this to get pass type-checking test. 
+ assert distributed_config is not None + model = fully_shard( + model, + mesh=distributed_config.device_mesh, + ) + return model, loss, data, target, uses_hybrid_shard + + @staticmethod + def _train_model( + optim_factory: Callable[ + [ParamsT], + torch.optim.Optimizer, + ], + model_factory: Callable[ + [torch.device], + tuple[ + nn.Module, + nn.Module, + torch.Tensor, + torch.Tensor, + bool, + ], + ], + device: torch.device, + ) -> tuple[list[torch.Tensor], torch.Tensor]: + model, loss, data, target, uses_hybrid_shard = model_factory(device) + params = model.parameters() + optimizer = optim_factory(params) + for _ in range(5): + optimizer.zero_grad() + objective = loss(model(data), target) + objective.backward() + optimizer.step() + + linear_layers = model.get_submodule("linear_layers") + # Need this assertion to get pass type-checking test. + assert linear_layers is not None + if uses_hybrid_shard: + # When Hybrid Shard is used, model parameters are DTensors. We obtain the full value of + # parameters from DTensors. + params_list = [] + # We only care model_linear_layers_dim params, not model_dead_layer params. + for param in linear_layers.parameters(): + # Need this assertion to get pass type-checking test. + assert isinstance(param, DTensor) + params_list.append(param.full_tensor().view(-1).detach().cpu()) + else: + params_list = [ + param.view(-1).detach().cpu() for param in linear_layers.parameters() + ] + + return params_list, objective.detach().cpu() + + @staticmethod + def _test_two_configs( + optim_factory1: Callable[ + [ParamsT], + torch.optim.Optimizer, + ], + model_factory1: Callable[ + [torch.device], + tuple[ + nn.Module, + nn.Module, + torch.Tensor, + torch.Tensor, + bool, + ], + ], + optim_factory2: Callable[ + [ParamsT], + torch.optim.Optimizer, + ], + model_factory2: Callable[ + [torch.device], + tuple[ + nn.Module, + nn.Module, + torch.Tensor, + torch.Tensor, + bool, + ], + ], + device: torch.device, + ) -> None: + params1, loss1 = ShampooHybridShardDistributorTest._train_model( + optim_factory1, + model_factory1, + device=device, + ) + params2, loss2 = ShampooHybridShardDistributorTest._train_model( + optim_factory2, + model_factory2, + device=device, + ) + torch.testing.assert_close(loss1, loss2, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(params1, params2) + + @staticmethod + def _shampoo_optim_factory( + distributed_config: HybridShardShampooConfig | None, + ) -> Callable[ + [ParamsT], + torch.optim.Optimizer, + ]: + return lambda parameters: ( + lambda distributed_config: DistributedShampoo( + parameters, + lr=0.001, + betas=(0.9, 1.0), + epsilon=1e-8, + momentum=0.0, + weight_decay=0.0, + max_preconditioner_dim=4, + precondition_frequency=1, + start_preconditioning_step=2, + use_decoupled_weight_decay=True, + grafting_config=AdaGradGraftingConfig( + epsilon=1e-8, + ), + distributed_config=distributed_config, + ) + )( + distributed_config, + ) + + @staticmethod + def _model_factory( + distributed_config: HybridShardShampooConfig | None, + ) -> Callable[ + [torch.device], + tuple[ + nn.Module, + nn.Module, + torch.Tensor, + torch.Tensor, + bool, + ], + ]: + return partial( + ShampooHybridShardDistributorTest._construct_model, + distributed_config=distributed_config, + ) + + @with_comms + @skip_if_lt_x_gpu(4) + def test_hybrid_shard_shampoo_against_default_shampoo(self) -> None: + mesh_2d = init_device_mesh( + "cuda", (2, 2), mesh_dim_names=("replicate", "shard") + ) + for num_trainers_per_group, ( + communication_dtype, + communicate_params, + ) in product( 
+ (-1, 1, 2), + ( + (CommunicationDType.DEFAULT, False), + (CommunicationDType.DEFAULT, True), + (CommunicationDType.FP16, False), + (CommunicationDType.BF16, False), + ), + ): + hybrid_shard_config = HybridShardShampooConfig( + device_mesh=mesh_2d, + communication_dtype=communication_dtype, + num_trainers_per_group=num_trainers_per_group, + communicate_params=communicate_params, + ) + + with self.subTest( + communication_dtype=communication_dtype, + num_trainers_per_group=num_trainers_per_group, + communicate_params=communicate_params, + ): + ShampooHybridShardDistributorTest._test_two_configs( + ShampooHybridShardDistributorTest._shampoo_optim_factory( + None, + ), + ShampooHybridShardDistributorTest._model_factory( + None, + ), + ShampooHybridShardDistributorTest._shampoo_optim_factory( + distributed_config=hybrid_shard_config, + ), + ShampooHybridShardDistributorTest._model_factory( + hybrid_shard_config, + ), + device=torch.device("cuda"), + ) + + @with_comms + @skip_if_lt_x_gpu(4) + def test_hybrid_shard_shampoo_block_index(self) -> None: + mesh_2d = init_device_mesh( + "cuda", (2, 2), mesh_dim_names=("replicate", "shard") + ) + hybrid_shard_config = HybridShardShampooConfig( + device_mesh=mesh_2d, + ) + model_factory = ShampooHybridShardDistributorTest._model_factory( + hybrid_shard_config + ) + optim_factory = ShampooHybridShardDistributorTest._shampoo_optim_factory( + hybrid_shard_config + ) + model, loss, data, target, _ = model_factory(torch.device("cuda")) + params = model.parameters() + optimizer = optim_factory(params) + assert isinstance(optimizer, DistributedShampoo) + state_dict = optimizer.distributed_state_dict( + key_to_param=model.named_parameters() + ) + flattened_state_dict = flatten_state_dict(state_dict)[0] + + # Note that we get the local rank corresponding to the second mesh dimension + # because the first mesh dimension corresponds to replication and the second + # mesh dimension corresponds to the sharding dimension. + # + # We expect that the rank should correspond to the rank in the shard dimension + # in order to avoid having the same key. + rank = mesh_2d.get_local_rank(mesh_dim=1) + matches = 0 + for key in flattened_state_dict.keys(): + if SHAMPOO in key: + with self.subTest(key=key): + self.assertIn(f"rank_{rank}-block_", key) + matches += 1 + self.assertGreater(matches, 0) + + @with_comms + @skip_if_lt_x_gpu(4) + def test_number_of_trainers_per_group_out_of_range(self) -> None: + mesh_2d = init_device_mesh( + "cuda", (2, 2), mesh_dim_names=("replicate", "shard") + ) + hybrid_shard_config = HybridShardShampooConfig( + device_mesh=mesh_2d, + num_trainers_per_group=3, + ) + + with self.assertRaisesRegex( + ValueError, + re.escape( + "Invalid number of trainers per group: 3. Must be between [1, 2] or set to -1." 
+ ), + ): + ShampooHybridShardDistributorTest._train_model( + optim_factory=ShampooHybridShardDistributorTest._shampoo_optim_factory( + distributed_config=hybrid_shard_config, + ), + model_factory=ShampooHybridShardDistributorTest._model_factory( + hybrid_shard_config + ), + device=torch.device("cuda"), + ) + + @with_comms + @skip_if_lt_x_gpu(4) + def test_dist_is_initialized(self) -> None: + mesh_2d = init_device_mesh( + "cuda", (2, 2), mesh_dim_names=("replicate", "shard") + ) + hybrid_shard_config = HybridShardShampooConfig( + device_mesh=mesh_2d, + ) + + with mock.patch.object( + torch.distributed, "is_initialized", return_value=False + ), self.assertRaisesRegex( + RuntimeError, + re.escape( + "HybridShardDistributor needs torch.distributed to be initialized!" + ), + ): + ShampooHybridShardDistributorTest._train_model( + optim_factory=ShampooHybridShardDistributorTest._shampoo_optim_factory( + distributed_config=hybrid_shard_config, + ), + model_factory=ShampooHybridShardDistributorTest._model_factory( + hybrid_shard_config + ), + device=torch.device("cuda"), + ) + + @with_comms + @skip_if_lt_x_gpu(4) + def test_incompatible_replicated_group_size_and_num_trainers_per_group( + self, + ) -> None: + mesh_2d = init_device_mesh( + "cuda", (2, 2), mesh_dim_names=("replicate", "shard") + ) + hybrid_shard_config = HybridShardShampooConfig( + device_mesh=mesh_2d, + num_trainers_per_group=3, + ) + + # Hijack the DeviceMesh.size() method to return 4 instead of 2 to bypass the check of num_trainers_per_group. + with mock.patch.object( + torch.distributed.device_mesh.DeviceMesh, "size", return_value=4 + ), self.assertRaisesRegex( + ValueError, + re.escape( + "distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!" + ), + ): + ShampooHybridShardDistributorTest._train_model( + optim_factory=ShampooHybridShardDistributorTest._shampoo_optim_factory( + distributed_config=hybrid_shard_config, + ), + model_factory=ShampooHybridShardDistributorTest._model_factory( + hybrid_shard_config + ), + device=torch.device("cuda"), + ) diff --git a/distributed_shampoo/utils/shampoo_ddp_distributor.py b/distributed_shampoo/utils/shampoo_ddp_distributor.py index ca6bea3..e2d949f 100644 --- a/distributed_shampoo/utils/shampoo_ddp_distributor.py +++ b/distributed_shampoo/utils/shampoo_ddp_distributor.py @@ -9,7 +9,9 @@ import heapq import logging +import operator from functools import partial +from itertools import islice from typing import Any import torch @@ -87,7 +89,6 @@ def __init__( CommunicationDType.DEFAULT, ] communication_dtype = torch.float32 - self._communication_dtype: torch.dtype = communication_dtype # Initialize _dist_group and _group_rank. self._dist_group: dist.ProcessGroup | None = ( @@ -95,12 +96,12 @@ def __init__( if self._group_size == self._global_size else dist.new_subgroups(group_size=self._group_size)[0] ) - self._group_rank: int = dist.get_rank(group=self._dist_group) + group_rank: int = dist.get_rank(group=self._dist_group) # Assign ranks to blocks with their respective buffer size. buffer_size_ranks = self._distribute_buffer_sizes( buffer_sizes=tuple( - blocked_param.numel() * get_dtype_size(self._communication_dtype) + blocked_param.numel() * get_dtype_size(communication_dtype) for blocked_param in self._global_blocked_params ) ) @@ -109,7 +110,7 @@ def __init__( # Initialize selectors and local blocked (masked) parameters. self._distributor_selector: tuple[bool, ...] 
= tuple( - block_info.group_source_rank == self._group_rank + block_info.group_source_rank == group_rank for block_info in self._global_block_info_list ) self._local_blocked_params: tuple[Tensor, ...] = compress_list( @@ -122,7 +123,11 @@ def __init__( self._local_blocked_params ) - self._construct_distributed_buffers(buffer_size_ranks) + self._construct_distributed_buffers( + buffer_size_ranks=buffer_size_ranks, + communication_dtype=communication_dtype, + group_rank=group_rank, + ) # NOTE: Remove this function once PT2 supports all_gather with functional collective @torch.no_grad() @@ -228,7 +233,7 @@ def _distribute_buffer_sizes( for index, aligned_buffer_size in sorted( enumerate(aligned_buffer_sizes), - key=lambda t: t[1], + key=operator.itemgetter(1), reverse=True, ): # Greedily find the group with the least allocated buffer size and its group index @@ -266,7 +271,9 @@ def _construct_global_block_info_list( self._global_block_info_list: tuple[DDPBlockInfo, ...] = tuple( DDPBlockInfo( param=param, - composable_block_ids=(param_index, f"block_{block_index}"), + composable_block_ids=self._construct_composable_block_ids( + param_index=param_index, block_index=block_index + ), # Curry a function to capture a local variable "group_source_rank". allocate_zeros_tensor=partial( self._allocate_zeros_distributed_tensor, @@ -281,18 +288,16 @@ def _construct_global_block_info_list( ) for ( (param_index, param), - num_blocks_within_param, (buffer_size_ranks_start, buffer_size_ranks_end), ) in zip( enumerate(self._param_group[PARAMS]), - self._global_num_blocks_per_param, generate_pairwise_indices(self._global_num_blocks_per_param), strict=True, ) - for block_index, (_, group_source_rank) in zip( - range(num_blocks_within_param), - buffer_size_ranks[buffer_size_ranks_start:buffer_size_ranks_end], - strict=True, + for block_index, (_, group_source_rank) in enumerate( + islice( + buffer_size_ranks, buffer_size_ranks_start, buffer_size_ranks_end + ) ) ) @@ -356,7 +361,10 @@ def _split_local_dist_buffers( return tuple(splitted_local_dist_buffers) def _construct_distributed_buffers( - self, buffer_size_ranks: tuple[tuple[int, int], ...] + self, + buffer_size_ranks: tuple[tuple[int, int], ...], + communication_dtype: torch.dtype, + group_rank: int, ) -> None: """Construct the distributed buffers for AllGather communications. @@ -367,6 +375,8 @@ def _construct_distributed_buffers( Args: buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the buffer size and an assigned rank for each block. + communication_dtype (torch.dtype): Data type used for communication. + group_rank (int): Rank of the current process group. """ @@ -390,16 +400,14 @@ def _construct_distributed_buffers( ) # Get local buffer for specific group rank. - self._local_dist_buffer = local_dist_buffers[self._group_rank] + self._local_dist_buffer = local_dist_buffers[group_rank] # Obtain the list of buffers corresponding to each block (ignoring padding). # Note that each buffer is reshaped into the block's shape and viewed in terms # of the communication data type. 
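# For example (hypothetical sizes): a [4, 4] block communicated in torch.float16
# occupies 4 * 4 * 2 = 32 bytes, so the first 32 bytes of its raw byte buffer
# slice are taken, reinterpreted as float16, and reshaped to (4, 4); any
# remaining bytes in the slice are alignment padding and are ignored.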
self._global_dist_blocked_buffers = tuple( - buffer.split( - blocked_param.numel() * get_dtype_size(self._communication_dtype) - )[0] - .view(self._communication_dtype) + buffer.split(blocked_param.numel() * get_dtype_size(communication_dtype))[0] + .view(communication_dtype) .view(blocked_param.shape) for buffer, blocked_param in zip( splitted_local_dist_buffers, self._global_blocked_params, strict=True diff --git a/distributed_shampoo/utils/shampoo_distributor.py b/distributed_shampoo/utils/shampoo_distributor.py index 8762b84..3a007f9 100644 --- a/distributed_shampoo/utils/shampoo_distributor.py +++ b/distributed_shampoo/utils/shampoo_distributor.py @@ -97,6 +97,26 @@ def local_masked_blocked_params(self) -> tuple[Tensor, ...]: def global_block_info_list(self) -> tuple[BlockInfo, ...]: return self._global_block_info_list + def _construct_composable_block_ids( + self, + param_index: int, + block_index: int, + rank: int | None = None, + ) -> tuple[int, str]: + """Construct composable block ids. + + Args: + param_index (int): Index of the parameter in self._param_group[PARAMS]. + block_index (int): Index of the tensor block within a given parameter. + rank (int | None): Rank of this process group; used in FSDP/HSDP. (Default: None) + + Returns: + tuple[int, str]: Composable block id tuple containing global block index and local block name. + The latter will be used to identify blocks in the masked tensor. + + """ + return (param_index, f"block_{block_index}") + def _get_params_or_grads(self, get_grad: bool = False) -> Iterable[Tensor | None]: """Helper function that gets params or grads from the parameter group. @@ -275,7 +295,9 @@ def _construct_global_block_info_list( self._global_block_info_list = tuple( BlockInfo( param=param, - composable_block_ids=(param_index, f"block_{block_index}"), + composable_block_ids=self._construct_composable_block_ids( + param_index=param_index, block_index=block_index + ), ) # Block index that is accumulated across all parameters within a parameter group. for ((param_index, param), num_blocks_within_param) in zip( diff --git a/distributed_shampoo/utils/shampoo_fsdp_distributor.py b/distributed_shampoo/utils/shampoo_fsdp_distributor.py index 58fb57e..852f598 100644 --- a/distributed_shampoo/utils/shampoo_fsdp_distributor.py +++ b/distributed_shampoo/utils/shampoo_fsdp_distributor.py @@ -83,6 +83,25 @@ def update_params( masked_blocked_search_directions, ) + def _construct_composable_block_ids( + self, + param_index: int, + block_index: int, + rank: int | None = None, + ) -> tuple[int, str]: + """Construct composable block ids for each parameter. + + Args: + param_index (int): Index of the current parameter within self._param_group[PARAMS]. + block_index (int): Block index that is accumulated across all parameters within a parameter group. + rank (int | None): Rank of this process group; should be non None in FSDP/HSDP setting. (Default: None) + + Returns: + tuple[int, str]: Composable block id tuple containing global block index and local block name. + The latter will be used to identify blocks in the masked tensor. 
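For example (illustrative indices), param_index=0, block_index=2 on rank=1 maps to (0, "rank_1-block_2").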
+ """ + return (param_index, f"rank_{rank}-block_{block_index}") + def _construct_global_block_info_list( self, ) -> None: @@ -91,7 +110,9 @@ def _construct_global_block_info_list( self._global_block_info_list = tuple( BlockInfo( param=param, - composable_block_ids=(param_index, f"rank_{rank}-block_{block_index}"), + composable_block_ids=self._construct_composable_block_ids( + param_index=param_index, block_index=block_index, rank=rank + ), ) # Block index that is accumulated across all parameters within a parameter group. for ((param_index, param), num_blocks_within_param) in zip( diff --git a/distributed_shampoo/utils/shampoo_fully_shard_distributor.py b/distributed_shampoo/utils/shampoo_fully_shard_distributor.py index a715490..e63b359 100644 --- a/distributed_shampoo/utils/shampoo_fully_shard_distributor.py +++ b/distributed_shampoo/utils/shampoo_fully_shard_distributor.py @@ -45,6 +45,25 @@ def _get_params_or_grads(self, get_grad: bool = False) -> Iterable[Tensor | None if (local_p := p.to_local()).numel() > 0 ) + def _construct_composable_block_ids( + self, + param_index: int, + block_index: int, + rank: int | None = None, + ) -> tuple[int, str]: + """Construct composable block ids for each parameter. + + Args: + param_index (int): Index of the current parameter within self._param_group[PARAMS]. + block_index (int): Block index that is accumulated across all parameters within a parameter group. + rank (int | None): Rank of this process group; should be non None in FSDP/HSDP setting. (Default: None) + + Returns: + tuple[int, str]: Composable block id tuple containing global block index and local block name. + The latter will be used to identify blocks in the masked tensor. + """ + return (param_index, f"rank_{rank}-block_{block_index}") + @torch.no_grad() def _construct_global_block_info_list( self, @@ -61,7 +80,9 @@ def _construct_global_block_info_list( self._global_block_info_list = tuple( BlockInfo( param=param, - composable_block_ids=(param_index, f"rank_{rank}-block_{block_index}"), + composable_block_ids=self._construct_composable_block_ids( + param_index=param_index, block_index=block_index, rank=rank + ), ) # Block index that is accumulated across all parameters within a parameter group. for ((param_index, param), num_blocks_within_param) in zip( diff --git a/distributed_shampoo/utils/shampoo_hsdp_distributor.py b/distributed_shampoo/utils/shampoo_hsdp_distributor.py index 433dcfc..68bc8b8 100644 --- a/distributed_shampoo/utils/shampoo_hsdp_distributor.py +++ b/distributed_shampoo/utils/shampoo_hsdp_distributor.py @@ -9,7 +9,9 @@ import heapq import logging +import operator from functools import partial +from itertools import islice from math import prod from typing import Any @@ -161,7 +163,6 @@ def __init__( CommunicationDType.DEFAULT, ] communication_dtype = torch.float32 - self._communication_dtype: torch.dtype = communication_dtype # Initialize _dist_group and _group_rank. # Note that this requires initializing all process groups. @@ -189,12 +190,12 @@ def __init__( "shard" ) - self._comms_group_rank: int = dist.get_rank(self._comms_dist_group) + comms_group_rank: int = dist.get_rank(self._comms_dist_group) # Assign ranks to blocks with their respective buffer size. 
buffer_size_ranks = self._distribute_buffer_sizes( buffer_sizes=tuple( - blocked_param.numel() * get_dtype_size(self._communication_dtype) + blocked_param.numel() * get_dtype_size(communication_dtype) for blocked_param in self._global_blocked_params ) ) @@ -203,7 +204,7 @@ def __init__( # Initialize selectors and local blocked (masked) parameters. self._distributor_selector: tuple[bool, ...] = tuple( - block_info.group_source_rank == self._comms_group_rank + block_info.group_source_rank == comms_group_rank for block_info in self._global_block_info_list ) self._local_blocked_params: tuple[Tensor, ...] = compress_list( @@ -216,7 +217,11 @@ def __init__( self._local_blocked_params ) - self._construct_distributed_buffers(buffer_size_ranks) + self._construct_distributed_buffers( + buffer_size_ranks=buffer_size_ranks, + communication_dtype=communication_dtype, + comms_group_rank=comms_group_rank, + ) # NOTE: Remove this function once PT2 supports all_gather with functional collective @torch.no_grad() @@ -321,7 +326,7 @@ def _distribute_buffer_sizes( for index, aligned_buffer_size in sorted( enumerate(aligned_buffer_sizes), - key=lambda t: t[1], + key=operator.itemgetter(1), reverse=True, ): # Greedily find the group with the least allocated buffer size and its group index @@ -345,6 +350,25 @@ def _distribute_buffer_sizes( return tuple(buffer_size_ranks) + def _construct_composable_block_ids( + self, + param_index: int, + block_index: int, + rank: int | None = None, + ) -> tuple[int, str]: + """Construct composable block ids for each parameter. + + Args: + param_index (int): Index of the current parameter within self._param_group[PARAMS]. + block_index (int): Block index that is accumulated across all parameters within a parameter group. + rank (int | None): Rank of this process group; should be non None in FSDP/HSDP setting. (Default: None) + + Returns: + tuple[int, str]: Composable block id tuple containing global block index and local block name. + The latter will be used to identify blocks in the masked tensor. + """ + return (param_index, f"rank_{rank}-block_{block_index}") + def _construct_global_block_info_list( self, buffer_size_ranks: tuple[tuple[int, int], ...] ) -> None: @@ -355,9 +379,10 @@ def _construct_global_block_info_list( self._global_block_info_list: tuple[DDPBlockInfo, ...] = tuple( DDPBlockInfo( param=param, - composable_block_ids=( - param_index, - f"rank_{sharded_group_rank}-block_{block_index}", + composable_block_ids=self._construct_composable_block_ids( + param_index=param_index, + block_index=block_index, + rank=sharded_group_rank, ), allocate_zeros_tensor=partial( self._allocate_zeros_distributed_tensor, @@ -372,18 +397,16 @@ def _construct_global_block_info_list( ) for ( (param_index, param), - num_blocks_within_param, (buffer_size_ranks_start, buffer_size_ranks_end), ) in zip( enumerate(self._param_group[PARAMS]), - self._global_num_blocks_per_param, generate_pairwise_indices(self._global_num_blocks_per_param), strict=True, ) - for block_index, (_, group_source_rank) in zip( - range(num_blocks_within_param), - buffer_size_ranks[buffer_size_ranks_start:buffer_size_ranks_end], - strict=True, + for block_index, (_, group_source_rank) in enumerate( + islice( + buffer_size_ranks, buffer_size_ranks_start, buffer_size_ranks_end + ) ) ) @@ -521,7 +544,10 @@ def _split_local_dist_buffers( return tuple(splitted_local_dist_buffers) def _construct_distributed_buffers( - self, buffer_size_ranks: tuple[tuple[int, int], ...] 
+ self, + buffer_size_ranks: tuple[tuple[int, int], ...], + communication_dtype: torch.dtype, + comms_group_rank: int, ) -> None: """Construct the distributed buffers for AllGather communications. @@ -532,6 +558,8 @@ def _construct_distributed_buffers( Args: buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the buffer size and an assigned rank for each block. + communication_dtype (torch.dtype): The data type used for communication. + comms_group_rank (int): The rank of the current group within the comms group. """ @@ -555,16 +583,14 @@ def _construct_distributed_buffers( ) # Get local buffer for specific group rank. - self._local_dist_buffer = local_dist_buffers[self._comms_group_rank] + self._local_dist_buffer = local_dist_buffers[comms_group_rank] # Obtain the list of buffers corresponding to each block (ignoring padding). # Note that each buffer is reshaped into the block's shape and viewed in terms # of the communication data type. self._global_dist_blocked_buffers = tuple( - buffer.split( - blocked_param.numel() * get_dtype_size(self._communication_dtype) - )[0] - .view(self._communication_dtype) + buffer.split(blocked_param.numel() * get_dtype_size(communication_dtype))[0] + .view(communication_dtype) .view(blocked_param.shape) for buffer, blocked_param in zip( splitted_local_dist_buffers, self._global_blocked_params, strict=True diff --git a/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py b/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py new file mode 100644 index 0000000..a25a683 --- /dev/null +++ b/distributed_shampoo/utils/shampoo_hybrid_shard_distributor.py @@ -0,0 +1,632 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. + +""" + +import heapq +import logging +import operator +from functools import partial +from itertools import islice +from typing import Any, Iterable + +import torch +from distributed_shampoo.shampoo_types import ( + CommunicationDType, + HybridShardShampooConfig, + PARAMS, +) +from distributed_shampoo.utils.shampoo_block_info import DDPBlockInfo +from distributed_shampoo.utils.shampoo_dist_utils import get_device_mesh +from distributed_shampoo.utils.shampoo_distributor import DistributorInterface +from distributed_shampoo.utils.shampoo_utils import ( + compress_list, + generate_pairwise_indices, + get_dtype_size, +) +from torch import distributed as dist, Tensor +from torch.distributed import tensor as dtensor +from torch.distributed.device_mesh import _mesh_resources +from torch.distributed.tensor import DTensor, zeros as dtensor_zeros + +logger: logging.Logger = logging.getLogger(__name__) + + +class HybridShardDistributor(DistributorInterface): + """HybridShard Distributor class. + + The constructor internally sets up `DeviceMesh` objects as necessary for distributing memory + and computation, so torch.distributed must be initialized in advance. + + Unlike FullyShardDistributor, HybridShardDistributor requires the user to pass in a device mesh used for + Hybrid Shard. For example, suppose we have 48 GPUs and the Hybrid Shard group size is 8. 
Then: + + Hybrid Shard Device Mesh with (Replicate, Shard) = (6, 8): + + device_mesh = [[ 0, 1, 2, 3, 4, 5, 6, 7] + [ 8, 9, 10, 11, 12, 13, 14, 15] + [16, 17, 18, 19, 20, 21, 22, 23] + [24, 25, 26, 27, 28, 29, 30, 31] + [32, 33, 34, 35, 36, 37, 38, 39] + [40, 41, 42, 43, 44, 45, 46, 47]] + + For example, if my device is rank 11, then: + device_mesh["replicate"] = [3, 11, 19, 27, 35, 43] + device_mesh["shard"] = [8, 9, 10, 11, 12, 13, 14, 15] + + Since the parameters are sharded along the "shard" dimension, we would normally replicate the + computation along the "replicate" dimension. With Hybrid Shard Shampoo, we instead want to + distribute the computation and memory requirements across the "replicate" dimension of the original + Hybrid Shard device mesh. + + For example, suppose that the num_trainers_per_group = 3. We want to form a (2, 3)-submesh on + the ranks [3, 11, 19, 27, 35, 43] (and similar). + + HybridShardDistributor 2D Sub-Mesh Example with (Replicate, Shard) = (2, 3): + + submesh = [[ 3, 11, 19] + [27, 35, 43]] + + In this case, optimizer states will live on different "replicate" meshes: {[3, 27], [11, 35], + [19, 43]}. In order to synchronize the optimizer step, we will communicate along the "shard" + mesh {[3, 11, 19], [27, 35, 43]}. + + Args: + param_group (dict[str, Any]): Parameter group containing parameters. + distributed_config (HybridShardShampooConfig): Configuration for HybridShard Shampoo. + + """ + + def __init__( + self, + param_group: dict[str, Any], + distributed_config: HybridShardShampooConfig, + ) -> None: + self._hybrid_shard_device_mesh: torch.distributed.device_mesh.DeviceMesh = ( + distributed_config.device_mesh + ) + self._global_num_blocks_per_param: tuple[int, ...] = () + + super().__init__(param_group) + if not dist.is_initialized(): + raise RuntimeError( + "HybridShardDistributor needs torch.distributed to be initialized!" + ) + + # Construct global masked blocked parameters (which is DDP-specific). + self._global_masked_blocked_params: tuple[Tensor, ...] = ( + self._global_blocked_params + ) + + # Check num_trainers_per_group and replicated group size. + # NOTE: If num_trainers_per_group = -1, then we use the replicated group size. + self._replicated_group_size: int = self._hybrid_shard_device_mesh.size(0) + + if not ( + 1 + <= distributed_config.num_trainers_per_group + <= self._replicated_group_size + or distributed_config.num_trainers_per_group == -1 + ): + raise ValueError( + f"Invalid number of trainers per group: {distributed_config.num_trainers_per_group}. " + f"Must be between [1, {self._replicated_group_size}] or set to -1." + ) + if distributed_config.num_trainers_per_group == -1: + logger.info( + f"Note that {distributed_config.num_trainers_per_group=}! Defaulting to replicated group size {self._replicated_group_size}." + ) + elif ( + not self._replicated_group_size % distributed_config.num_trainers_per_group + == 0 + ): + raise ValueError( + f"{distributed_config.num_trainers_per_group=} must divide {self._replicated_group_size=}!" + ) + + # Group size for distributing computation / memory requirements. + self._dist_group_size: int = ( + distributed_config.num_trainers_per_group + if distributed_config.num_trainers_per_group != -1 + else self._replicated_group_size + ) + + # Create flag for distributing parameters instead of search directions. + self._communicate_params: bool = distributed_config.communicate_params + + # Determine communication type. 
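# Both CommunicationDType.DEFAULT and CommunicationDType.FP32 map to
# torch.float32 below; only BF16 and FP16 lower the precision used for the
# all-gather buffers.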
+ if distributed_config.communication_dtype == CommunicationDType.BF16: + communication_dtype = torch.bfloat16 + elif distributed_config.communication_dtype == CommunicationDType.FP16: + communication_dtype = torch.float16 + else: + assert distributed_config.communication_dtype in [ + CommunicationDType.FP32, + CommunicationDType.DEFAULT, + ] + communication_dtype = torch.float32 + + # Initialize _dist_group and _group_rank. + # Note that this requires initializing all process groups. + # Splits replicated ranks group into smaller groups of size self._dist_group_size. + # Instantiates this by using DeviceMesh. + ranks_in_all_replicated_groups = self._hybrid_shard_device_mesh.mesh.T + for ranks_in_replicated_group in ranks_in_all_replicated_groups: + device_mesh = get_device_mesh( + device_type=self._hybrid_shard_device_mesh.device_type, + mesh=tuple( + tuple(ranks_in_replicated_subgroup) + for ranks_in_replicated_subgroup in ranks_in_replicated_group.view( + -1, self._dist_group_size + ).tolist() + ), + mesh_dim_names=("replicate", "shard"), + ) + if dist.get_rank() in ranks_in_replicated_group: + # NOTE: We want the process group in the device mesh that the current rank + # belongs to but solely along the "shard" dimension for communications. + # + # For example, if the current rank is 11, then I want the process group + # that contains the ranks [3, 11, 19]. + self._comms_dist_group: dist.ProcessGroup = device_mesh.get_group( + "shard" + ) + + comms_group_rank: int = dist.get_rank(self._comms_dist_group) + + # Assign ranks to blocks with their respective buffer size. + buffer_size_ranks = self._distribute_buffer_sizes( + buffer_sizes=tuple( + blocked_param.numel() * get_dtype_size(communication_dtype) + for blocked_param in self._global_blocked_params + ) + ) + + self._construct_global_block_info_list(buffer_size_ranks) + + # Initialize selectors and local blocked (masked) parameters. + self._distributor_selector: tuple[bool, ...] = tuple( + block_info.group_source_rank == comms_group_rank + for block_info in self._global_block_info_list + ) + self._local_blocked_params: tuple[Tensor, ...] = compress_list( + self._global_blocked_params, self._distributor_selector + ) + self._local_masked_blocked_params: tuple[Tensor, ...] = ( + self._local_blocked_params + ) + self._local_grad_selector: tuple[bool, ...] = (True,) * len( + self._local_blocked_params + ) + + self._construct_distributed_buffers( + buffer_size_ranks=buffer_size_ranks, + communication_dtype=communication_dtype, + comms_group_rank=comms_group_rank, + ) + + @torch.no_grad() + def _get_params_or_grads(self, get_grad: bool = False) -> Iterable[Tensor | None]: + """Helper function to get the local params (or grad) from the param_group, where params are represented as DTensors. + + Args: + get_grad (bool): Whether to return the param or the grad of the param. + Returns: + local (Iterable[Tensor | None]): Local params (or grad) from the param_group. + """ + # If a parameter is in a "dead layer", it won't have any gradient. In this case, we + # should return `None` for the gradient. 
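# Parameters whose local shard is empty (to_local().numel() == 0) are skipped
# entirely, so only parameters at least partially owned by this rank are yielded.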
+ return ( + (None if p.grad is None else p.grad.to_local()) if get_grad else local_p + for p in self._param_group[PARAMS] + if (local_p := p.to_local()).numel() > 0 + ) + + # NOTE: Remove this function once PT2 supports all_gather with functional collective + @torch.no_grad() + @torch.compiler.disable + def all_gather_into_tensor(self) -> None: + dist.all_gather_into_tensor( + self._global_dist_buffer, + self._local_dist_buffer, + group=self._comms_dist_group, + ) + + @torch.no_grad() + def update_params( + self, + masked_blocked_search_directions: tuple[Tensor, ...], + ) -> None: + """Update params stored inside this distributor according to the input search directions argument. + + Args: + masked_blocked_search_directions (tuple[Tensor, ...]): Search directions for each local blocked parameter. + + See the comment in the parent class for details. + + """ + if self._communicate_params: + # Perform your update to your local masked parameters and copy into buffers. + torch._foreach_add_( + self._local_masked_blocked_params, + masked_blocked_search_directions, + ) + torch._foreach_copy_( + self._local_masked_dist_blocked_buffers, + self._local_masked_blocked_params, + ) + + self.all_gather_into_tensor() + + # Copy updated blocked params in global_masked_dist_blocked_buffers + # into global_masked_blocked_params. + torch._foreach_copy_( + self._global_masked_blocked_params, + self._global_masked_dist_blocked_buffers, + ) + + else: + # Search directions multiplied by alpha are distributed. + # Copy the local search directions to the communication buffer. + torch._foreach_copy_( + self._local_masked_dist_blocked_buffers, + masked_blocked_search_directions, + ) + + self.all_gather_into_tensor() + + # Add search directions in global_masked_dist_blocked_buffers + # to global_masked_blocked_params. + torch._foreach_add_( + self._global_masked_blocked_params, + self._global_masked_dist_blocked_buffers, + ) + + def _distribute_buffer_sizes( + self, + buffer_sizes: tuple[int, ...], + ) -> tuple[tuple[int, int], ...]: + """Distribute given buffer sizes across ranks in a group. + + Buffer sizes will be rounded up for memory allocation. Buffers are distributed such that + total buffer sizes of each rank are as even as possible. This is currently performed + using a greedy algorithm. We do not currently consider computational cost + or kernel launching overheads. + + Note: A better distribution strategy should try to minimize the delta of buffer sizes + between the most and the least allocated groups. + + Args: + buffer_sizes (tuple[int, ...]): Buffer sizes of blocks to be distributed. + + Returns: + buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the + buffer size for each block and its assigned rank. + + Example: + Assuming ALIGNMENT_BYTES = 64, given buffer_sizes = [128, 64, 500, 256], group_size = 2 + -> buffer_size_ranks = [(128, 1), (64, 1), (512, 0), (256, 1)] + + """ + ALIGNMENT_BYTES = ( + 64 # necessary for determining buffer size, possibly hardware-dependent + ) + + # Convert each of buffer_sizes into smallest multiple of ALIGNMENT_BYTES that is >= buffer size. 
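# e.g. with ALIGNMENT_BYTES = 64: 500 -> (500 + 63) // 64 * 64 = 512, while 128,
# 64, and 256 are already multiples of 64 and stay unchanged. Continuing the
# docstring example with group_size = 2, the aligned sizes [128, 64, 512, 256]
# are assigned largest-first to the least-loaded group: 512 -> group 0, then
# 256, 128, and 64 -> group 1, giving [(128, 1), (64, 1), (512, 0), (256, 1)].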
+ aligned_buffer_sizes = [ + (buffer_size + ALIGNMENT_BYTES - 1) // ALIGNMENT_BYTES * ALIGNMENT_BYTES + for buffer_size in buffer_sizes + ] + buffer_size_ranks = [(-1, -1)] * len(buffer_sizes) + allocated_buffer_sizes = [ + (0, group_index) for group_index in range(self._dist_group_size) + ] + heapq.heapify(allocated_buffer_sizes) + + for index, aligned_buffer_size in sorted( + enumerate(aligned_buffer_sizes), + key=operator.itemgetter(1), + reverse=True, + ): + # Greedily find the group with the least allocated buffer size and its group index + # in order to allocate buffers on that group. + ( + min_allocated_buffer_size, + min_allocated_buffer_size_group_index, + ) = heapq.heappop(allocated_buffer_sizes) + + heapq.heappush( + allocated_buffer_sizes, + ( + min_allocated_buffer_size + aligned_buffer_size, + min_allocated_buffer_size_group_index, + ), + ) + buffer_size_ranks[index] = ( + aligned_buffer_size, + min_allocated_buffer_size_group_index, + ) + + return tuple(buffer_size_ranks) + + def _construct_composable_block_ids( + self, + param_index: int, + block_index: int, + rank: int | None = None, + ) -> tuple[int, str]: + """Construct composable block ids for each parameter. + + Args: + param_index (int): Index of the current parameter within self._param_group[PARAMS]. + block_index (int): Block index that is accumulated across all parameters within a parameter group. + rank (int | None): Rank of this process group; should be non None in FullyShard/HybridShard setting. (Default: None) + + Returns: + tuple[int, str]: Composable block id tuple containing global block index and local block name. + The latter will be used to identify blocks in the masked tensor. + """ + return (param_index, f"rank_{rank}-block_{block_index}") + + @torch.no_grad() + def _construct_global_block_info_list( + self, buffer_size_ranks: tuple[tuple[int, int], ...] + ) -> None: + """Construct global block info list from param_group and num_blocks_within_param.""" + # Call `super()` instead of `self` as a performance optimization. + # This leads to O(1) instead of O(N) complexity to retrieve the parameters. + non_empty_params: Iterable[DTensor] = filter( + lambda p: p.to_local().numel() > 0, # type: ignore[arg-type] + super()._get_params_or_grads(), + ) + + # Note that for HybridShard, we want to get the rank within each sharded group for the block id. + # When using a device mesh, 0 corresponds to the replicated group and 1 corresponds to the sharded group. + sharded_group_rank = self._hybrid_shard_device_mesh.get_local_rank(1) + self._global_block_info_list: tuple[DDPBlockInfo, ...] 
= tuple( + DDPBlockInfo( + param=param, + composable_block_ids=self._construct_composable_block_ids( + param_index=param_index, + block_index=block_index, + rank=sharded_group_rank, + ), + allocate_zeros_tensor=partial( + self._allocate_zeros_distributed_tensor, + group_source_rank=group_source_rank, + ), + get_tensor=lambda input_tensor: ( + input_tensor.to_local() + if isinstance(input_tensor, dtensor.DTensor) + else input_tensor + ), + group_source_rank=group_source_rank, + ) + for ( + (param_index, param), + (buffer_size_ranks_start, buffer_size_ranks_end), + ) in zip( + enumerate(non_empty_params), + generate_pairwise_indices(self._global_num_blocks_per_param), + strict=True, + ) + for block_index, (_, group_source_rank) in enumerate( + islice( + buffer_size_ranks, buffer_size_ranks_start, buffer_size_ranks_end + ) + ) + ) + + @staticmethod + def _split_local_dist_buffers( + buffer_size_ranks: tuple[tuple[int, int], ...], + local_dist_buffers: tuple[torch.Tensor, ...], + ) -> tuple[torch.Tensor, ...]: + """Split distributed buffers for each local rank into views for each assigned block. + + Args: + buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the + buffer size and an assigned rank for each block. + local_dist_buffers (tuple[torch.Tensor, ...]): A list of local distributed buffers that + correspond to each rank. Each distributed buffer will be split according to the + assigned tensor blocks. + + Returns: + splitted_local_dist_buffers (tuple[torch.Tensor, ...]): A list of tuples containing a view of the + local distributed buffer for each tensor block. + + Example: + tensor0 = tensor(1024) + tensor1 = tensor(1024) + buffer_size_ranks = [(128, 0), (64, 0), (512, 1), (256, 0)] + local_dist_buffers = [tensor0, tensor1] + -> splitted_local_dist_buffers = [ + tensor0's view( 0-128 bytes), + tensor0's view(128-192 bytes), + tensor1's view( 0-512 bytes), + tensor0's view(192-448 bytes), + ] + + """ + + # Create list of lists containing local views of each split tensor for each rank. + split_tensors_list = [] + for rank, local_dist_buffer in enumerate(local_dist_buffers): + required_buffer_sizes = [s for s, r in buffer_size_ranks if r == rank] + remainder_size = local_dist_buffer.size(0) - sum(required_buffer_sizes) + assert ( + remainder_size >= 0 + ), f"Local distributed buffer size {local_dist_buffer.size(0)} is " + "not larger than or equal to the sum of buffer sizes {sum(required_buffer_sizes)}!" + split_tensors = torch.split( + local_dist_buffer, required_buffer_sizes + [remainder_size] + ) + split_tensors_list.append(split_tensors) + + # Obtain ordered buffer ranks containing (view of local buffer, rank). + splitted_local_dist_buffers = [] + buffer_indices = [0] * len( + local_dist_buffers + ) # index counter for each rank for obtaining right buffer + for _, rank in buffer_size_ranks: + splitted_local_dist_buffers.append( + split_tensors_list[rank][buffer_indices[rank]] + ) + buffer_indices[rank] += 1 + + return tuple(splitted_local_dist_buffers) + + def _construct_distributed_buffers( + self, + buffer_size_ranks: tuple[tuple[int, int], ...], + communication_dtype: torch.dtype, + comms_group_rank: int, + ) -> None: + """Construct the distributed buffers for AllGather communications. + + Note that this function will construct the distributed buffer for the AllGather + communication. In addition, it massages the distributed buffer to obtain views + of the buffer corresponding to each block assigned to the current rank. 
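For example (illustrative byte totals): with per-rank buffer requirements of [448, 512] in a group of two ranks, a single int8 buffer of 2 * 512 = 1024 bytes is allocated, split into one 512-byte slice per rank, and each slice is further split into per-block views.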
+ + Args: + buffer_size_ranks (tuple[tuple[int, int], ...]): A list of tuples containing the + buffer size and an assigned rank for each block. + communication_dtype (torch.dtype): The data type used for communication. + comms_group_rank (int): The rank of the current group within the comms group. + + """ + + # Calculate buffer size each rank needs. + local_buffer_sizes = tuple( + sum(buffer_size for buffer_size, rank in buffer_size_ranks if rank == i) + for i in range(self._dist_group_size) + ) + + # Calculate the whole buffer size and obtain buffers for every rank. + max_buffer_size_sum = max(local_buffer_sizes) + total_buffer_size = max_buffer_size_sum * self._dist_group_size + self._global_dist_buffer = torch.zeros( + total_buffer_size, + dtype=torch.int8, + device=self._global_block_info_list[0].param.device, + ) + local_dist_buffers = torch.split(self._global_dist_buffer, max_buffer_size_sum) + splitted_local_dist_buffers = HybridShardDistributor._split_local_dist_buffers( + buffer_size_ranks, local_dist_buffers + ) + + # Get local buffer for specific group rank. + self._local_dist_buffer = local_dist_buffers[comms_group_rank] + + # Obtain the list of buffers corresponding to each block (ignoring padding). + # Note that each buffer is reshaped into the block's shape and viewed in terms + # of the communication data type. + self._global_dist_blocked_buffers = tuple( + buffer.split(blocked_param.numel() * get_dtype_size(communication_dtype))[0] + .view(communication_dtype) + .view(blocked_param.shape) + for buffer, blocked_param in zip( + splitted_local_dist_buffers, self._global_blocked_params, strict=True + ) + ) + self._local_dist_blocked_buffers = compress_list( + self._global_dist_blocked_buffers, self._distributor_selector + ) + self._global_masked_dist_blocked_buffers = self._global_dist_blocked_buffers + self._local_masked_dist_blocked_buffers = self._local_dist_blocked_buffers + + def merge_and_block_gradients( + self, + ) -> tuple[Tensor, ...]: + """Merge and block gradients. + + NOTE: This function MUST be called in the step function of the optimizer after the + gradient has been updated. + + Returns: + local_masked_blocked_grads (tuple[Tensor, ...]): Local blocked gradients masked with grad existence. + + """ + local_masked_blocked_grads = self._merge_and_block_gradients() + + if self._previous_global_grad_selector != self._global_grad_selector: + self._previous_global_grad_selector = self._global_grad_selector + + # Update _local_grad_selector and _local_masked_blocked_params only when global_grad_selector is changed. + self._local_grad_selector = compress_list( + self._global_grad_selector, + self._distributor_selector, + ) + self._local_masked_blocked_params = compress_list( + self._local_blocked_params, self._local_grad_selector + ) + + # Re-compress DDP-specific tensor lists using the updated selector. + self._global_masked_blocked_params = compress_list( + self._global_blocked_params, self._global_grad_selector + ) + self._global_masked_dist_blocked_buffers = compress_list( + self._global_dist_blocked_buffers, self._global_grad_selector + ) + self._local_masked_dist_blocked_buffers = compress_list( + self._local_dist_blocked_buffers, self._local_grad_selector + ) + + return local_masked_blocked_grads + + def _allocate_zeros_distributed_tensor( + self, + shape: tuple[int, ...], + dtype: torch.dtype, + device: torch.device, + group_source_rank: int, + ) -> torch.Tensor: + """Instantiates distributed tensor using DTensor. 
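The tensor is allocated on the "replicate" sub-mesh selected by group_source_rank and replicated across the ranks of that sub-mesh.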
+ + Args: + shape (shape type accepted by torch.zeros() including tuple[int, ...]): + Shape of desired tensor. + dtype (dtype type accepted by torch.zeros() including torch.dtype): + DType of desired tensor. + device (device type accepted by torch.zeros() including torch.device): + Device of desired tensor. + group_source_rank (int): Group rank (with respect to the sharded group of + the 2D submesh) that determines which ranks the DTensor is allocated on. + + Returns: + out (Tensor): Desired Tensor. + + """ + ranks_in_replicated_group = torch.tensor( + dist.get_process_group_ranks(self._hybrid_shard_device_mesh.get_group(0)) + ) + device_mesh_2d = get_device_mesh( + device_type=device.type, + mesh=tuple( + tuple(ranks_in_replicated_subgroup) + for ranks_in_replicated_subgroup in ranks_in_replicated_group.view( + -1, self._dist_group_size + ).tolist() + ), + mesh_dim_names=("replicate", "shard"), + ) + # NOTE: We get all submeshes along the "replicate" dimension, then pick out + # the sub-mesh that the optimizer state is assigned to. + # + # For the example above, this would give me submeshes [[3, 27], [11, 35], [19, 43]]. + # Note that the group source rank must belong to {0, 1, 2} in this case. + # Suppose the group_source_rank = 1, then this would get the submesh [11, 35]. + replicate_submesh = _mesh_resources._get_all_submeshes( + device_mesh_2d, "replicate" + )[group_source_rank] + + return dtensor_zeros( + shape, + dtype=dtype, + device_mesh=replicate_submesh, + placements=[dtensor.Replicate()], + ) diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index ed15da8..d0a0597 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -15,26 +15,15 @@ from functools import partial, reduce from itertools import chain -from operator import methodcaller from typing import Any, cast, Generic, TypeVar import torch from distributed_shampoo.shampoo_types import ( - PrecisionConfig, PreconditionerConfig, PreconditionerValueError, ) from distributed_shampoo.utils.shampoo_block_info import BlockInfo -from distributed_shampoo.utils.shampoo_quantization import ( - QuantizedTensor, - QuantizedTensorList, -) -from distributed_shampoo.utils.shampoo_utils import ( - compress_list, - get_dtype_size, - ParameterizeEnterExitContext, -) - +from distributed_shampoo.utils.shampoo_utils import compress_list, get_dtype_size from matrix_functions import ( check_diagonal, compute_matrix_root_inverse_residuals, @@ -91,12 +80,6 @@ def compress_preconditioner_list( self, local_grad_selector: tuple[bool, ...] ) -> None: ... - @abstractmethod - def dequantize_preconditioners(self) -> None: ... - - @abstractmethod - def quantize_preconditioners(self) -> None: ... - @property def numel_list(self) -> tuple[int, ...]: return self._numel_list @@ -146,12 +129,6 @@ def compress_preconditioner_list( ) -> None: return - def dequantize_preconditioners(self) -> None: - return - - def quantize_preconditioners(self) -> None: - return - class AdagradPreconditionerList(PreconditionerList): """Adagrad / Adam / RMSProp preconditioners for a list of parameters. @@ -176,7 +153,6 @@ class AdagradPreconditionerList(PreconditionerList): beta2 (float): Exponential moving average factor for Adam/RMSprop second moment state. If beta2 = 1., will use unweighted sum. (Default: 1.0) epsilon (float): Epsilon term for regularizing preconditioner to ensure positive definiteness. 
(Default: 1e-10) - precision_config (PrecisionConfig): Data types for optimizer states. (Default: all fields torch.float) use_bias_correction (bool): Flag for using bias correction. (Default: False) """ @@ -188,7 +164,6 @@ def __init__( state: Mapping[Tensor, Any], block_info_list: tuple[BlockInfo, ...], distributor_selector: tuple[bool, ...], - precision_config: PrecisionConfig, beta2: float = 1.0, epsilon: float = 1e-10, use_bias_correction: bool = True, @@ -215,28 +190,23 @@ def __init__( # Instantiate AdaGrad optimizer state for this block. preconditioner_index = str(param_index) + "." + str(block_index) - block_state[ADAGRAD] = QuantizedTensor( - quantized_values=block_info.allocate_zeros_tensor( - shape=block.size(), - dtype=precision_config.grafting_state_dtype, - device=block.device, - ), - block_info=block_info, + block_state[ADAGRAD] = block_info.allocate_zeros_tensor( + shape=block.size(), + dtype=block.dtype, + device=block.device, ) - preconditioner_list.append(block_state[ADAGRAD]) + preconditioner_list.append(block_info.get_tensor(block_state[ADAGRAD])) logger.info( - f"Instantiated Adagrad Preconditioner {preconditioner_index} ({block_state[ADAGRAD].quantized_values.shape} with dtype {block_state[ADAGRAD].quantized_values.dtype}) " + f"Instantiated Adagrad Preconditioner {preconditioner_index} ({block_state[ADAGRAD].shape} with dtype {block_state[ADAGRAD].dtype}) " f"for Parameter {param_index} ({block_info.param.shape}), Block {block_index} ({block.shape})." ) # Masked lists are the list of active preconditioners or values after filtering out gradients with None. - self._local_preconditioner_list = QuantizedTensorList( - quantized_data=compress_list(preconditioner_list, distributor_selector), - quantized_dtype=precision_config.grafting_state_dtype, - computation_dtype=precision_config.computation_dtype, + self._local_preconditioner_list: tuple[Tensor, ...] = compress_list( + preconditioner_list, distributor_selector ) - self._masked_preconditioner_list: QuantizedTensorList = ( + self._masked_preconditioner_list: tuple[Tensor, ...] = ( self._local_preconditioner_list ) @@ -245,12 +215,11 @@ def __init__( self._dims_list, distributor_selector ) self._numel_list: tuple[int, ...] = tuple( - quantized_preconditioner.numel() - for quantized_preconditioner in self._local_preconditioner_list.quantized_value + preconditioner.numel() for preconditioner in self._local_preconditioner_list ) self._num_bytes_list: tuple[int, ...] = tuple( - quantize_preconditioner.numel() * quantize_preconditioner.element_size() - for quantize_preconditioner in self._local_preconditioner_list.quantized_value + preconditioner.numel() * preconditioner.element_size() + for preconditioner in self._local_preconditioner_list ) def update_preconditioners( @@ -264,17 +233,15 @@ def update_preconditioners( ): if self._beta2 == 1.0: torch._foreach_addcmul_( - self._masked_preconditioner_list.dequantized_value, + self._masked_preconditioner_list, masked_grad_list, masked_grad_list, value=1.0, ) else: - torch._foreach_mul_( - self._masked_preconditioner_list.dequantized_value, self._beta2 - ) + torch._foreach_mul_(self._masked_preconditioner_list, self._beta2) torch._foreach_addcmul_( - self._masked_preconditioner_list.dequantized_value, + self._masked_preconditioner_list, masked_grad_list, masked_grad_list, value=1 - self._beta2, @@ -298,7 +265,7 @@ def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, .. 
f"## {self.__class__.__name__}:{self.precondition.__name__} ##" ): masked_bias_corrected_preconditioner_list = torch._foreach_div( - self._masked_preconditioner_list.dequantized_value, + self._masked_preconditioner_list, self._bias_correction2, ) torch._foreach_sqrt_(masked_bias_corrected_preconditioner_list) @@ -309,39 +276,22 @@ def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, .. masked_grad_list, masked_bias_corrected_preconditioner_list ) - def dequantize_preconditioners(self) -> None: - with profiler.record_function( - f"## {self.__class__.__name__}:{self.dequantize_preconditioners.__name__} ##" - ): - self._masked_preconditioner_list.dequantize_() - - def quantize_preconditioners(self) -> None: - with profiler.record_function( - f"## {self.__class__.__name__}:{self.quantize_preconditioners.__name__} ##" - ): - self._masked_preconditioner_list.quantize_() - def compress_preconditioner_list( self, local_grad_selector: tuple[bool, ...] ) -> None: with profiler.record_function( f"## {self.__class__.__name__}:{self.compress_preconditioner_list.__name__} ##" ): - self._masked_preconditioner_list = self._local_preconditioner_list.compress( - local_grad_selector + self._masked_preconditioner_list = compress_list( + self._local_preconditioner_list, local_grad_selector ) -FactorMatricesType = TypeVar( - "FactorMatricesType", tuple[QuantizedTensor, ...], QuantizedTensorList -) - - @dataclass -class BaseShampooKroneckerFactors(Generic[FactorMatricesType], OptimizerModule): +class BaseShampooKroneckerFactors(OptimizerModule): """Base class for Shampoo Kronecker factors.""" - factor_matrices: FactorMatricesType + factor_matrices: tuple[Tensor, ...] factor_matrix_indices: tuple[str, ...] is_factor_matrices_diagonal: tuple[Tensor, ...] = field(init=False) @@ -354,12 +304,10 @@ def __post_init__(self) -> None: @dataclass -class ShampooKroneckerFactorsState( - BaseShampooKroneckerFactors[tuple[QuantizedTensor, ...]] -): +class ShampooKroneckerFactorsState(BaseShampooKroneckerFactors): """Shampoo Kronecker factors (wrapped) for storing in the optimizer state.""" - inv_factor_matrices: tuple[QuantizedTensor, ...] + inv_factor_matrices: tuple[Tensor, ...] def __post_init__(self) -> None: super().__post_init__() @@ -367,10 +315,10 @@ def __post_init__(self) -> None: @dataclass -class ShampooKroneckerFactorsList(BaseShampooKroneckerFactors[QuantizedTensorList]): +class ShampooKroneckerFactorsList(BaseShampooKroneckerFactors): """Shampoo Kronecker factors (unwrapped) for operations during optimizer computation.""" - inv_factor_matrices: QuantizedTensorList + inv_factor_matrices: tuple[Tensor, ...] def __post_init__(self) -> None: super().__post_init__() @@ -378,13 +326,11 @@ def __post_init__(self) -> None: @dataclass -class EigenvalueCorrectedShampooKroneckerFactorsState( - BaseShampooKroneckerFactors[tuple[QuantizedTensor, ...]] -): +class EigenvalueCorrectedShampooKroneckerFactorsState(BaseShampooKroneckerFactors): """Eigenvalue-corrected Shampoo Kronecker factors (wrapped) for storing in the optimizer state.""" - factor_matrices_eigenvectors: tuple[QuantizedTensor, ...] - corrected_eigenvalues: QuantizedTensor + factor_matrices_eigenvectors: tuple[Tensor, ...] 
+ corrected_eigenvalues: Tensor def __post_init__(self) -> None: super().__post_init__() @@ -392,13 +338,11 @@ def __post_init__(self) -> None: @dataclass -class EigenvalueCorrectedShampooKroneckerFactorsList( - BaseShampooKroneckerFactors[QuantizedTensorList] -): +class EigenvalueCorrectedShampooKroneckerFactorsList(BaseShampooKroneckerFactors): """Eigenvalue-corrected Shampoo Kronecker factors (unwrapped) for operations during optimizer computation.""" - factor_matrices_eigenvectors: QuantizedTensorList - corrected_eigenvalues: QuantizedTensorList + factor_matrices_eigenvectors: tuple[Tensor, ...] + corrected_eigenvalues: Tensor def __post_init__(self) -> None: super().__post_init__() @@ -426,7 +370,6 @@ class BaseShampooPreconditionerList( Note that this should have the same length as block_list. distributor_selector (tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter is selected by the current Distributor. - precision_config (PrecisionConfig): Data types for optimizer states. (Default: all fields torch.float) preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation. (Default: DefaultShampooConfig) beta2 (float): Exponential moving average factor for Shampoo factor matrices. If beta2 = 1., will use unweighted sum. (Default: 1.0) @@ -436,10 +379,7 @@ class BaseShampooPreconditionerList( If the order of the tensor exceeds the length of the list, we revert to using the default value. If 0 is used, uses the default inverse root -1 / (2 * o), where o is the order of the tensor. (Default: 0) use_bias_correction (bool): Flag for using bias correction. (Default: True) - use_protected_eigh (bool): Flag for using two guards to prevent failures of torch.linalg.eigh. (Default: True) - 1. Attempts to compute root inverse in preconditioner_dtype precision. - 2. Attempts to recompute the eigendecomposition if using lower-precision fails. - 3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail. + factor_matrix_dtype (torch.dtype): Data type for accumulating and computing root inverse of preconditioners. (Default: torch.float) """ @@ -450,24 +390,22 @@ def __init__( state: Mapping[Tensor, Any], block_info_list: tuple[BlockInfo, ...], distributor_selector: tuple[bool, ...], - precision_config: PrecisionConfig, preconditioner_config: PreconditionerConfig, beta2: float = 1.0, epsilon: float = 1e-12, inv_root_override: int | tuple[int, ...] = 0, use_bias_correction: bool = True, - use_protected_eigh: bool = True, + factor_matrix_dtype: torch.dtype = torch.float, ) -> None: super().__init__(block_list) # Initialize parameters. - self._precision_config = precision_config self._preconditioner_config = preconditioner_config self._beta2 = beta2 self._epsilon = epsilon self._inv_root_override = inv_root_override + self._factor_matrix_dtype = factor_matrix_dtype self._use_bias_correction = use_bias_correction - self._use_protected_eigh = use_protected_eigh self._bias_correction2: Tensor = torch.tensor(1.0) # Create the Kronecker factors. @@ -490,7 +428,7 @@ def _create_base_kronecker_factors( self, block_info: BlockInfo, dims: torch.Size, - ) -> BaseShampooKroneckerFactors[tuple[QuantizedTensor, ...]]: + ) -> BaseShampooKroneckerFactors: """ Creates a BaseShampooKroneckerFactor object for a given block. @@ -499,16 +437,13 @@ def _create_base_kronecker_factors( dims (torch.Size): The dimensions of the block. 
Returns: - kronecker_factors_state (BaseShampooKroneckerFactors[tuple[QuantizedTensor, ...]]): An object containing the Kronecker factors for the block. + kronecker_factors_state (BaseShampooKroneckerFactors): An object containing the Kronecker factors for the block. """ factor_matrices = tuple( - QuantizedTensor( - quantized_values=block_info.allocate_zeros_tensor( - shape=(dim, dim), - dtype=self._precision_config.factor_matrix_dtype, - device=block_info.param.device, - ), - block_info=block_info, + block_info.allocate_zeros_tensor( + shape=(dim, dim), + dtype=self._factor_matrix_dtype, + device=block_info.param.device, ) for dim in dims ) @@ -548,12 +483,14 @@ def _create_kronecker_factors_list( self, kronecker_factors_state: ShampooKroneckerFactorsState | EigenvalueCorrectedShampooKroneckerFactorsState, + block_info: BlockInfo, ) -> ShampooKroneckerFactorsListType: """ Creates a ShampooKroneckerFactorsList object from the given ShampooKroneckerFactorsState. Args: kronecker_factors_state (ShampooKroneckerFactorsState | EigenvalueCorrectedShampooKroneckerFactorsState): The state containing the Kronecker factors. + block_info (BlockInfo): The BlockInfo object containing information about the block. Returns: kronecker_factors_list (ShampooKroneckerFactorsListType): A list of ShampooKroneckerFactors objects. @@ -586,7 +523,7 @@ def _create_kronecker_factors_state( ) kronecker_factors_list.append( - self._create_kronecker_factors_list(block_state[SHAMPOO]) + self._create_kronecker_factors_list(block_state[SHAMPOO], block_info) ) logger.info( @@ -765,10 +702,7 @@ def _initialize_state_lists( ) self._num_bytes_list: tuple[int, ...] = tuple( numel - * ( - get_dtype_size(self._precision_config.factor_matrix_dtype) - + get_dtype_size(block.dtype) - ) + * (get_dtype_size(self._factor_matrix_dtype) + get_dtype_size(block.dtype)) // 2 for numel, block in zip(self._numel_list, local_block_list, strict=True) ) @@ -805,9 +739,7 @@ def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: ): # Scale Kronecker factors as a list. if self._beta2 != 1.0: - torch._foreach_mul_( - kronecker_factors.factor_matrices.dequantized_value, self._beta2 - ) + torch._foreach_mul_(kronecker_factors.factor_matrices, self._beta2) # Construct outer product list for updating Kronecker factors. outer_product_list = tuple( @@ -822,7 +754,7 @@ def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: # Update Kronecker factors. 
torch._foreach_add_( - kronecker_factors.factor_matrices.dequantized_value, + kronecker_factors.factor_matrices, outer_product_list, alpha=1 - self._beta2 if self._beta2 != 1.0 else 1.0, ) @@ -845,13 +777,10 @@ def _create_kronecker_factors_state_for_block( self, block: Tensor, block_info: BlockInfo, dims: torch.Size ) -> ShampooKroneckerFactorsState: inv_factor_matrices = tuple( - QuantizedTensor( - quantized_values=block_info.allocate_zeros_tensor( - shape=(dim, dim), - dtype=self._precision_config.inv_factor_matrix_dtype, - device=block_info.param.device, - ), - block_info=block_info, + block_info.allocate_zeros_tensor( + shape=(dim, dim), + dtype=block.dtype, + device=block_info.param.device, ) for dim in dims ) @@ -869,20 +798,17 @@ def _create_kronecker_factors_list( self, kronecker_factors_state: ShampooKroneckerFactorsState | EigenvalueCorrectedShampooKroneckerFactorsState, + block_info: BlockInfo, ) -> ShampooKroneckerFactorsList: assert isinstance(kronecker_factors_state, ShampooKroneckerFactorsState) return ShampooKroneckerFactorsList( - # Factor matrices computation (accumulation, root inverse) should use the determined dtype. - factor_matrices=QuantizedTensorList( - quantized_data=kronecker_factors_state.factor_matrices, - quantized_dtype=self._precision_config.factor_matrix_dtype, - computation_dtype=self._precision_config.factor_matrix_computation_dtype, + factor_matrices=tuple( + block_info.get_tensor(t) + for t in kronecker_factors_state.factor_matrices ), - # Inverse factor matrices computation (preconditioning) should use the dtype of the block / gradient. - inv_factor_matrices=QuantizedTensorList( - quantized_data=kronecker_factors_state.inv_factor_matrices, - quantized_dtype=self._precision_config.inv_factor_matrix_dtype, - computation_dtype=self._precision_config.computation_dtype, + inv_factor_matrices=tuple( + block_info.get_tensor(t) + for t in kronecker_factors_state.inv_factor_matrices ), factor_matrix_indices=kronecker_factors_state.factor_matrix_indices, ) @@ -912,7 +838,7 @@ def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, .. return tuple( self._precondition_grad( grad=masked_grad, - preconditioner_list=kronecker_factors.inv_factor_matrices.dequantized_value, + preconditioner_list=kronecker_factors.inv_factor_matrices, ) for masked_grad, kronecker_factors in zip( masked_grad_list, self._masked_kronecker_factors_list, strict=True @@ -940,8 +866,8 @@ def _amortized_computation(self) -> None: is_factor_matrix_diagonal, factor_matrix_index, ) in zip( - kronecker_factors.factor_matrices.dequantized_value, - kronecker_factors.inv_factor_matrices.dequantized_value, + kronecker_factors.factor_matrices, + kronecker_factors.inv_factor_matrices, kronecker_factors.is_factor_matrices_diagonal, kronecker_factors.factor_matrix_indices, strict=True, @@ -978,16 +904,13 @@ def _amortized_computation(self) -> None: is_diagonal=bool(is_factor_matrix_diagonal), ).to(dtype=inv_factor_matrix.dtype) except Exception as exception: - # If self._use_protected_eigh is True, will reuse previous matrix if matrix inverse root computation fails. - if not self._use_protected_eigh: - raise exception - else: - logger.warning( - f"Matrix computation failed for factor matrix {factor_matrix_index} " - f"with {exception=}. Using previous inverted factor matrix and continuing..." - ) - # Define computed_inv_factor_matrix to prevent undefined local variable error. 
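
To make the `_update_factor_matrices` and `precondition` hunks above concrete for a single 2-D block: each tensor dimension gets an outer-product accumulator, and preconditioning applies the inverse (2·order)-th roots of those accumulators on each side of the gradient. The following is a hedged from-scratch sketch under those assumptions, not the library's `matrix_inverse_root` implementation.

```python
import torch

# Hedged sketch for one 2-D gradient block G; names are illustrative.
def inverse_root(factor: torch.Tensor, root: int, epsilon: float) -> torch.Tensor:
    # Eigendecomposition-based A^(-1/root), roughly what an eigen-based config computes.
    eigvals, eigvecs = torch.linalg.eigh(factor)
    eigvals = eigvals.clamp(min=0.0) + epsilon
    return eigvecs @ torch.diag(eigvals.pow(-1.0 / root)) @ eigvecs.T

G = torch.randn(4, 3)
L = torch.zeros(4, 4)   # dim-0 factor matrix
R = torch.zeros(3, 3)   # dim-1 factor matrix

# Factor matrix update (beta2 == 1.0 branch): accumulate outer products along each dimension.
L.add_(G @ G.T)
R.add_(G.T @ G)

# Default inverse root for an order-2 tensor is 2 * order = 4.
L_inv_root = inverse_root(L, root=4, epsilon=1e-12)
R_inv_root = inverse_root(R, root=4, epsilon=1e-12)

# Precondition the gradient on both sides (the 2-D case of the einsum-based helper).
preconditioned_grad = L_inv_root @ G @ R_inv_root
```
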
- computed_inv_factor_matrix = inv_factor_matrix + # Reuse previous matrix if matrix inverse root computation fails. + logger.warning( + f"Matrix computation failed for factor matrix {factor_matrix_index} " + f"with {exception=}. Using previous inverted factor matrix and continuing..." + ) + # Define computed_inv_factor_matrix to prevent undefined local variable error. + computed_inv_factor_matrix = inv_factor_matrix # Check if we encounter NaN or inf values in computed inverse matrix. if ( @@ -1001,22 +924,6 @@ def _amortized_computation(self) -> None: ) inv_factor_matrix.copy_(computed_inv_factor_matrix) - def dequantize_preconditioners(self) -> None: - with profiler.record_function( - f"## {self.__class__.__name__}:{self.dequantize_preconditioners.__name__} ##" - ): - for kronecker_factors in self._masked_kronecker_factors_list: - kronecker_factors.factor_matrices.dequantize_() - kronecker_factors.inv_factor_matrices.dequantize_() - - def quantize_preconditioners(self) -> None: - with profiler.record_function( - f"## {self.__class__.__name__}:{self.quantize_preconditioners.__name__} ##" - ): - for kronecker_factors in self._masked_kronecker_factors_list: - kronecker_factors.factor_matrices.quantize_() - kronecker_factors.inv_factor_matrices.quantize_() - @torch.compiler.disable def compute_root_inverse_residuals( self, @@ -1034,8 +941,8 @@ def compute_root_inverse_residuals( strict=True, ): for factor_matrix, inv_factor_matrix in zip( - kronecker_factors.factor_matrices.dequantized_value, - kronecker_factors.inv_factor_matrices.dequantized_value, + kronecker_factors.factor_matrices, + kronecker_factors.inv_factor_matrices, strict=True, ): bias_corrected_factor_matrix = factor_matrix / self._bias_correction2 @@ -1074,23 +981,17 @@ def _create_kronecker_factors_state_for_block( self, block: Tensor, block_info: BlockInfo, dims: torch.Size ) -> EigenvalueCorrectedShampooKroneckerFactorsState: factor_matrices_eigenvectors = tuple( - QuantizedTensor( - quantized_values=block_info.allocate_zeros_tensor( - shape=(dim, dim), - dtype=self._precision_config.factor_matrix_eigenvectors_dtype, - device=block_info.param.device, - ), - block_info=block_info, + block_info.allocate_zeros_tensor( + shape=(dim, dim), + dtype=block.dtype, + device=block_info.param.device, ) for dim in dims ) - corrected_eigenvalues = QuantizedTensor( - quantized_values=block_info.allocate_zeros_tensor( - shape=tuple(dims), - dtype=self._precision_config.corrected_eigenvalues_dtype, - device=block_info.param.device, - ), - block_info=block_info, + corrected_eigenvalues = block_info.allocate_zeros_tensor( + shape=tuple(dims), + dtype=block.dtype, + device=block_info.param.device, ) base_kronecker_factors = self._create_base_kronecker_factors( @@ -1107,25 +1008,22 @@ def _create_kronecker_factors_list( self, kronecker_factors_state: ShampooKroneckerFactorsState | EigenvalueCorrectedShampooKroneckerFactorsState, + block_info: BlockInfo, ) -> EigenvalueCorrectedShampooKroneckerFactorsList: assert isinstance( kronecker_factors_state, EigenvalueCorrectedShampooKroneckerFactorsState ) return EigenvalueCorrectedShampooKroneckerFactorsList( - factor_matrices=QuantizedTensorList( - quantized_data=kronecker_factors_state.factor_matrices, - quantized_dtype=self._precision_config.factor_matrix_dtype, - computation_dtype=self._precision_config.factor_matrix_computation_dtype, + factor_matrices=tuple( + block_info.get_tensor(t) + for t in kronecker_factors_state.factor_matrices ), - factor_matrices_eigenvectors=QuantizedTensorList( - 
quantized_data=kronecker_factors_state.factor_matrices_eigenvectors, - quantized_dtype=self._precision_config.factor_matrix_eigenvectors_dtype, - computation_dtype=self._precision_config.computation_dtype, + factor_matrices_eigenvectors=tuple( + block_info.get_tensor(t) + for t in kronecker_factors_state.factor_matrices_eigenvectors ), - corrected_eigenvalues=QuantizedTensorList( - quantized_data=(kronecker_factors_state.corrected_eigenvalues,), - quantized_dtype=self._precision_config.corrected_eigenvalues_dtype, - computation_dtype=self._precision_config.computation_dtype, + corrected_eigenvalues=block_info.get_tensor( + kronecker_factors_state.corrected_eigenvalues ), factor_matrix_indices=kronecker_factors_state.factor_matrix_indices, ) @@ -1179,9 +1077,7 @@ def _update_eigenvalue_corrections( self._masked_kronecker_factors_list, strict=True, ): - factor_eigenvectors = ( - kronecker_factors.factor_matrices_eigenvectors.dequantized_value - ) + factor_eigenvectors = kronecker_factors.factor_matrices_eigenvectors if factor_eigenvectors[0].any(): grad = self._precondition_grad( grad=grad, @@ -1190,11 +1086,9 @@ def _update_eigenvalue_corrections( # Scale corrected eigenvalues. # NOTE: The case when self._beta2 == 1.0 is not well tested and might not be stable. if self._beta2 != 1.0: - kronecker_factors.corrected_eigenvalues.dequantized_value[0].mul_( - self._beta2 - ) + kronecker_factors.corrected_eigenvalues.mul_(self._beta2) # Update corrected eigenvalues (squared gradient in eigenbasis of Shampoo preconditioner). - kronecker_factors.corrected_eigenvalues.dequantized_value[0].add_( + kronecker_factors.corrected_eigenvalues.add_( grad.square(), alpha=1 - self._beta2 if self._beta2 != 1.0 else 1.0, ) @@ -1219,12 +1113,8 @@ def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, .. self._masked_root_list, strict=True, ): - factor_eigenvectors = ( - kronecker_factors.factor_matrices_eigenvectors.dequantized_value - ) - corrected_eigenvalues = ( - kronecker_factors.corrected_eigenvalues.dequantized_value[0] - ) + factor_eigenvectors = kronecker_factors.factor_matrices_eigenvectors + corrected_eigenvalues = kronecker_factors.corrected_eigenvalues use_eigenbasis = factor_eigenvectors[0].any() grad = masked_grad.clone() if use_eigenbasis: @@ -1267,8 +1157,8 @@ def _amortized_computation(self) -> None: is_factor_matrix_diagonal, factor_matrix_index, ) in zip( - kronecker_factors.factor_matrices.dequantized_value, - kronecker_factors.factor_matrices_eigenvectors.dequantized_value, + kronecker_factors.factor_matrices, + kronecker_factors.factor_matrices_eigenvectors, kronecker_factors.is_factor_matrices_diagonal, kronecker_factors.factor_matrix_indices, strict=True, @@ -1292,16 +1182,13 @@ def _amortized_computation(self) -> None: is_diagonal=bool(is_factor_matrix_diagonal), ) except Exception as exception: - # If self._use_protected_eigh is True, will reuse previous matrix if matrix eigenvector computation fails. - if not self._use_protected_eigh: - raise exception - else: - logger.warning( - f"Matrix computation failed for factor matrix {factor_matrix_index} " - f"with {exception=}. Using previous factor matrix eigenvectors and continuing..." - ) - # Define computed_eigenvectors to prevent undefined local variable error. - computed_eigenvectors = factor_matrix_eigenvectors + # Reuse previous matrix if matrix eigenvector computation fails. + logger.warning( + f"Matrix computation failed for factor matrix {factor_matrix_index} " + f"with {exception=}. 
Using previous factor matrix eigenvectors and continuing..." + ) + # Define computed_eigenvectors to prevent undefined local variable error. + computed_eigenvectors = factor_matrix_eigenvectors # Check if we encounter NaN or inf values in computed eigenvectors. if ( @@ -1314,42 +1201,3 @@ def _amortized_computation(self) -> None: f"To mitigate, check factor matrix before the matrix computation: {factor_matrix=}" ) factor_matrix_eigenvectors.copy_(computed_eigenvectors) - - def dequantize_preconditioners(self) -> None: - with profiler.record_function( - f"## {self.__class__.__name__}:{self.dequantize_preconditioners.__name__} ##" - ): - for kronecker_factors in self._masked_kronecker_factors_list: - kronecker_factors.factor_matrices.dequantize_() - kronecker_factors.factor_matrices_eigenvectors.dequantize_() - kronecker_factors.corrected_eigenvalues.dequantize_() - - def quantize_preconditioners(self) -> None: - with profiler.record_function( - f"## {self.__class__.__name__}:{self.quantize_preconditioners.__name__} ##" - ): - for kronecker_factors in self._masked_kronecker_factors_list: - kronecker_factors.factor_matrices.quantize_() - kronecker_factors.factor_matrices_eigenvectors.quantize_() - kronecker_factors.corrected_eigenvalues.quantize_() - - -class DequantizePreconditionersContext(ParameterizeEnterExitContext): - """DequantizePreconditionersContext is used for automatically dequantize and then quantize the preconditioners used within this context. - - Args: - preconditioner_list (PreconditionerList): Preconditioner list which contains the preconditioners to be dequantized and quantized. - - Examples: - >>> with DequantizePreconditionersContext(preconditioner_list): - >>> # Do something with the preconditioners which are dequantized. - >>> # After the context is exited, the preconditioners will be quantized. - - """ - - def __init__(self, preconditioner_list: PreconditionerList) -> None: - super().__init__( - input_with_enter_exit_context=preconditioner_list, - enter_method_caller=methodcaller("dequantize_preconditioners"), - exit_method_caller=methodcaller("quantize_preconditioners"), - ) diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 8bb061f..dd41ae5 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -18,7 +18,6 @@ from distributed_shampoo.shampoo_types import ( DefaultEigenvalueCorrectedShampooConfig, DefaultShampooConfig, - PrecisionConfig, PreconditionerValueError, ShampooPreconditionerConfig, ) @@ -28,14 +27,12 @@ from distributed_shampoo.utils.shampoo_preconditioner_list import ( AdagradPreconditionerList, BaseShampooPreconditionerList, - DequantizePreconditionersContext, EigenvalueCorrectedShampooPreconditionerList, PreconditionerList, SGDPreconditionerList, ShampooPreconditionerList, ) -from distributed_shampoo.utils.shampoo_quantization import QuantizedTensorList -from matrix_functions import EigenConfig +from matrix_functions_types import EigenConfig from torch import Tensor @@ -77,27 +74,26 @@ def _test_update_preconditioners_and_precondition( masked_grad_lists: list[tuple[Tensor, ...]], masked_expected_preconditioned_grad_list: tuple[Tensor, ...] 
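
As a companion to the eigenvalue-corrected hunks above: the corrected eigenvalues are accumulated in the eigenbasis of the Kronecker factors, and preconditioning rotates the gradient into that basis, divides elementwise by the root of the bias-corrected eigenvalues, and rotates back. A hedged single-block sketch for the 2-D case, with illustrative names only:

```python
import torch

# Hedged sketch of eigenvalue-corrected preconditioning for one 2-D block; not the library API.
def ec_precondition(grad, eigenvectors, corrected_eigenvalues,
                    beta2, bias_correction2, epsilon=1e-12, root=2):
    Q_left, Q_right = eigenvectors                       # factor matrix eigenvectors
    # Rotate the gradient into the eigenbasis of the Shampoo preconditioner.
    grad_eig = Q_left.T @ grad @ Q_right
    # Update corrected eigenvalues with the squared gradient in that basis.
    corrected_eigenvalues.mul_(beta2).add_(grad_eig.square(), alpha=1.0 - beta2)
    denom = (corrected_eigenvalues / bias_correction2).pow(1.0 / root).add_(epsilon)
    # Precondition in the eigenbasis, then rotate back to the original basis.
    return Q_left @ (grad_eig / denom) @ Q_right.T

G = torch.randn(4, 3)
Q_l = torch.linalg.qr(torch.randn(4, 4))[0]   # stand-in orthogonal eigenvector matrices
Q_r = torch.linalg.qr(torch.randn(3, 3))[0]
lam = torch.zeros(4, 3)                        # corrected eigenvalues, one per gradient entry
update = ec_precondition(G, (Q_l, Q_r), lam, beta2=0.999, bias_correction2=1 - 0.999)
```
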
| None, ) -> None: - with DequantizePreconditionersContext(preconditioner_list=preconditioner_list): - for step, masked_grad_list in enumerate(masked_grad_lists, start=1): - preconditioner_list.update_preconditioners( - masked_grad_list=masked_grad_list, - step=torch.tensor(step), - # Only update the complete preconditioner during the last call to update_preconditioners(). - perform_amortized_computation=isinstance( - preconditioner_list, BaseShampooPreconditionerList - ) - and step == len(masked_grad_lists), + for step, masked_grad_list in enumerate(masked_grad_lists, start=1): + preconditioner_list.update_preconditioners( + masked_grad_list=masked_grad_list, + step=torch.tensor(step), + # Only update the complete preconditioner during the last call to update_preconditioners(). + perform_amortized_computation=isinstance( + preconditioner_list, BaseShampooPreconditionerList ) - masked_preconditioned_grad_list = preconditioner_list.precondition( - masked_grad_list=masked_grad_lists[-1] + and step == len(masked_grad_lists), ) - if masked_expected_preconditioned_grad_list is not None: - torch.testing.assert_close( - masked_preconditioned_grad_list, - masked_expected_preconditioned_grad_list, - ) - else: - self.assertIsNone(masked_preconditioned_grad_list) + masked_preconditioned_grad_list = preconditioner_list.precondition( + masked_grad_list=masked_grad_lists[-1] + ) + if masked_expected_preconditioned_grad_list is not None: + torch.testing.assert_close( + masked_preconditioned_grad_list, + masked_expected_preconditioned_grad_list, + ) + else: + self.assertIsNone(masked_preconditioned_grad_list) def test_update_preconditioners_and_precondition(self) -> None: masked_grad_list = ( @@ -131,16 +127,10 @@ def _test_compress_preconditioner_list( self, expected_compress_list_call_count: int, ) -> None: - with ( - mock.patch.object( - shampoo_preconditioner_list, - "compress_list", - ) as mock_compress_list, - mock.patch.object( - QuantizedTensorList, - "compress", - ) as mock_compress_quant_list, - ): + with mock.patch.object( + shampoo_preconditioner_list, + "compress_list", + ) as mock_compress_list: # Count the number of list compressions at the preconditioner list level, including compressions of QuantizedTensorList. # Each call to compress() under QuantizedTensorList counts once, though note that it calls compress_list() three times inside. 
self.assertIsNone( @@ -149,7 +139,7 @@ def _test_compress_preconditioner_list( ) ) self.assertEqual( - mock_compress_list.call_count + mock_compress_quant_list.call_count, + mock_compress_list.call_count, expected_compress_list_call_count, ) @@ -195,7 +185,6 @@ def _instantiate_preconditioner_list( state=self._state, block_info_list=self._block_info_list, distributor_selector=self._distributor_selector, - precision_config=PrecisionConfig(), **kwargs, ) @@ -309,7 +298,6 @@ def test_abstract_methods(self) -> None: ), ), distributor_selector=(True,), - precision_config=PrecisionConfig(), preconditioner_config=DefaultShampooConfig, beta2=1.0, ) @@ -338,16 +326,9 @@ def _instantiate_preconditioner_list( # type: ignore[override] def _test_raise_invalid_value_in_factor_matrix( self, invalid_value: float ) -> None: - with ( - DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), - self.assertRaisesRegex( - PreconditionerValueError, - re.escape( - f"Encountered {str(invalid_value)} values in factor matrix" - ), - ), + with self.assertRaisesRegex( + PreconditionerValueError, + re.escape(f"Encountered {str(invalid_value)} values in factor matrix"), ): self._preconditioner_list.update_preconditioners( masked_grad_list=( @@ -371,21 +352,14 @@ def test_raise_nan_and_inf_in_inv_factor_matrix_amortized_computation( self, ) -> None: for invalid_value in (torch.nan, torch.inf): - with ( - DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), - self.subTest(invalid_value=invalid_value), - self.assertRaisesRegex( - PreconditionerValueError, - re.escape("Encountered nan or inf values in"), - ), - mock.patch.object( - shampoo_preconditioner_list, - self._amortized_computation_function(), - side_effect=(torch.tensor([invalid_value]),), - ) as mock_amortized_computation, - ): + with self.subTest(invalid_value=invalid_value), self.assertRaisesRegex( + PreconditionerValueError, + re.escape("Encountered nan or inf values in"), + ), mock.patch.object( + shampoo_preconditioner_list, + self._amortized_computation_function(), + side_effect=(torch.tensor([invalid_value]),), + ) as mock_amortized_computation: self._preconditioner_list.update_preconditioners( masked_grad_list=( torch.tensor([1.0, 0.0]), @@ -404,13 +378,7 @@ def test_amortized_computation_internal_failure(self) -> None: # Simulate the situation throws an exception (not nan and inf) to test the warning side_effect=ZeroDivisionError, ) as mock_amortized_computation: - with ( - DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), - self.assertLogs(level="WARNING") as cm, - ): - # Because use_protected_eigh is True, we expect the warning to be logged. + with self.assertLogs(level="WARNING") as cm: self._preconditioner_list.update_preconditioners( masked_grad_list=( torch.tensor([1.0, 0.0]), @@ -434,27 +402,6 @@ def test_amortized_computation_internal_failure(self) -> None: mock_amortized_computation.assert_called() mock_amortized_computation.reset_mock() - # Turn off use_protected_eigh and expect ZeroDivisionError to be logged. 
- self._preconditioner_list = self._instantiate_preconditioner_list( - use_protected_eigh=False, - ) - with ( - DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), - self.assertRaises(ZeroDivisionError), - ): - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([1.0, 0.0]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[1.0, 0.0]]), - ), - step=torch.tensor(1), - perform_amortized_computation=True, - ) - mock_amortized_computation.assert_called() - # Note: This is needed for type checking to infer the type of argument into mock.patch.object. shampoo_preconditioner_list_module: ModuleType = shampoo_preconditioner_list @@ -469,14 +416,9 @@ def test_amortized_computation_factor_matrix_non_diagonal( self._preconditioner_list = self._instantiate_preconditioner_list( epsilon=1.0 ) - with ( - DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), - self.assertLogs( - level="DEBUG", - ) as cm, - ): + with self.assertLogs( + level="DEBUG", + ) as cm: self._preconditioner_list.update_preconditioners( masked_grad_list=( torch.tensor([1.0, 0.0]), @@ -532,7 +474,6 @@ def _instantiate_preconditioner_list( # type: ignore[override] "epsilon": 0.0, "inv_root_override": 0, "use_bias_correction": True, - "use_protected_eigh": True, "preconditioner_config": DefaultShampooConfig, } | kwargs return ShampooPreconditionerList( @@ -540,7 +481,7 @@ def _instantiate_preconditioner_list( # type: ignore[override] state=self._state, block_info_list=self._block_info_list, distributor_selector=self._distributor_selector, - precision_config=PrecisionConfig(factor_matrix_dtype=torch.float64), + factor_matrix_dtype=torch.float64, **kwargs, # type: ignore[arg-type] ) @@ -775,30 +716,28 @@ def test_compute_root_inverse_residuals(self) -> None: state=self._state, block_info_list=(self._block_info_list[0],), distributor_selector=(self._distributor_selector[0],), - precision_config=PrecisionConfig(), preconditioner_config=DefaultShampooConfig, epsilon=0.0, ) masked_grad_list1 = (torch.tensor([1.0, 0.0]),) masked_grad_list2 = (torch.tensor([0.0, 1.0]),) - with DequantizePreconditionersContext(preconditioner_list=preconditioner_list): - preconditioner_list.update_preconditioners( - masked_grad_list=masked_grad_list1, - step=torch.tensor(1), - perform_amortized_computation=False, - ) - preconditioner_list.update_preconditioners( - masked_grad_list=masked_grad_list2, - step=torch.tensor(2), - perform_amortized_computation=True, - ) + preconditioner_list.update_preconditioners( + masked_grad_list=masked_grad_list1, + step=torch.tensor(1), + perform_amortized_computation=False, + ) + preconditioner_list.update_preconditioners( + masked_grad_list=masked_grad_list2, + step=torch.tensor(2), + perform_amortized_computation=True, + ) - # Expect no relative errors and residuals because L is a diagonal matrix. - ( - relative_errors, - relative_residuals, - ) = preconditioner_list.compute_root_inverse_residuals() + # Expect no relative errors and residuals because L is a diagonal matrix. 
+ ( + relative_errors, + relative_residuals, + ) = preconditioner_list.compute_root_inverse_residuals() expected_relative_errors = (torch.tensor(0.0),) expected_relative_residuals = (torch.tensor(0.0),) @@ -821,7 +760,6 @@ def _instantiate_preconditioner_list( # type: ignore[override] "epsilon": 1e-12, "inv_root_override": 0, "use_bias_correction": True, - "use_protected_eigh": True, "preconditioner_config": DefaultEigenvalueCorrectedShampooConfig, } | kwargs return EigenvalueCorrectedShampooPreconditionerList( @@ -829,7 +767,7 @@ def _instantiate_preconditioner_list( # type: ignore[override] state=self._state, block_info_list=self._block_info_list, distributor_selector=self._distributor_selector, - precision_config=PrecisionConfig(factor_matrix_dtype=torch.float64), + factor_matrix_dtype=torch.float64, **kwargs, # type: ignore[arg-type] ) diff --git a/distributed_shampoo/utils/tests/shampoo_quantization_test.py b/distributed_shampoo/utils/tests/shampoo_quantization_test.py index 13979d9..96a06b4 100644 --- a/distributed_shampoo/utils/tests/shampoo_quantization_test.py +++ b/distributed_shampoo/utils/tests/shampoo_quantization_test.py @@ -108,17 +108,14 @@ def test_invalid_quantized_data_type(self) -> None: class QuantizedTensorListInitTest(unittest.TestCase): def test_invalid_quantized_data_type(self) -> None: - with ( - mock.patch.object( - shampoo_quantization, - "isinstance", - side_effect=lambda object, classinfo: False, - ), - self.assertRaisesRegex( - TypeError, - re.escape( - "quantized_data must be collections.abc.Sequence[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]] | collections.abc.Sequence[distributed_shampoo.utils.shampoo_quantization.QuantizedTensor] but get " - ), + with mock.patch.object( + shampoo_quantization, + "isinstance", + side_effect=lambda object, classinfo: False, + ), self.assertRaisesRegex( + TypeError, + re.escape( + "quantized_data must be collections.abc.Sequence[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]] | collections.abc.Sequence[distributed_shampoo.utils.shampoo_quantization.QuantizedTensor] but get " ), ): QuantizedTensorList( diff --git a/tests/matrix_functions_test.py b/tests/matrix_functions_test.py index 3d31a7b..dbd78be 100644 --- a/tests/matrix_functions_test.py +++ b/tests/matrix_functions_test.py @@ -231,27 +231,23 @@ def test_matrix_inverse_root_reach_max_iterations(self) -> None: implementation, msg, ) in root_inv_config_and_implementation_and_msg: - with ( - mock.patch.object( - matrix_functions, - implementation, - return_value=( - None, - None, - NewtonConvergenceFlag.REACHED_MAX_ITERS, - None, - None, - ), - ), - self.subTest( - root_inv_config=root_inv_config, - implementation=implementation, - msg=msg, + with mock.patch.object( + matrix_functions, + implementation, + return_value=( + None, + None, + NewtonConvergenceFlag.REACHED_MAX_ITERS, + None, + None, ), - self.assertLogs( - level="WARNING", - ) as cm, - ): + ), self.subTest( + root_inv_config=root_inv_config, + implementation=implementation, + msg=msg, + ), self.assertLogs( + level="WARNING", + ) as cm: matrix_inverse_root( A=A, root=root,
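
Finally, a note on the `compute_root_inverse_residuals` test earlier in this diff: the residual measures how far the computed inverse root is from exactly inverting the factor matrix. The sketch below is only an illustration of that idea with assumed names; the library's `compute_matrix_root_inverse_residuals` may normalize and bias-correct differently.

```python
import torch

# Hedged sketch: relative residual of an inverse root X ≈ A^(-1/root).
def relative_residual(A: torch.Tensor, X: torch.Tensor, root: int) -> torch.Tensor:
    reconstructed = torch.linalg.matrix_power(torch.linalg.inv(X), root)  # X^(-root) ≈ A
    return torch.linalg.norm(reconstructed - A) / torch.linalg.norm(A)

A = torch.diag(torch.tensor([1.0, 2.0]))      # diagonal factor matrix, as in the test above
X = torch.diag(torch.diagonal(A).pow(-0.25))  # exact inverse 4th root
print(relative_residual(A, X, root=4))        # ~0 for the exact root, up to float error
```
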