From 66f348c6496dae63bbce292f3d1de54a4ef01351 Mon Sep 17 00:00:00 2001 From: Tsung-Hsien Lee Date: Tue, 25 Feb 2025 14:55:20 -0800 Subject: [PATCH] Add ignored_dims config for single factor matrix enablement Summary: This diff adds a new configuration option called `ignored_dims` to `PreconditionerConfig`; this option allows the user to specify which dimensions of each parameter tensor should be ignored when computing the preconditioner. Note that `PreconditionerConfig.ignored_dims` is not compatible with `inv_root_override`, and we plan to merge `inv_root_override` into `PreconditionerConfig.amortized_computation_config.exponent_multiplier` as `Preconditioner.exponent_override`, a list of floats representing the exponent overrides for each order of tensors. Given `Preconditioner.exponent_override=[e1, e2, ..., ep]`, we will use ^e1 for 1-D tensors (vectors), ^e2 for 2-D tensors (matrices), and so on; setting `ei=0` as the exponent for i-dimensional tensors results in no preconditioning for all i-dimensional tensors. On the other hand, setting `i in Preconditioner.ignored_dims` only results in no preconditioning of the i-th dimension for all tensors (if their order is >= i). For example, if `Preconditioner.exponent_override=[0.5, 0.0, 0.25]` and `Preconditioner.ignored_dims=[0, 2]`, this means: no preconditioning for 1-D tensors (due to `0 in Preconditioner.ignored_dims`, which makes the setting `Preconditioner.exponent_override[0]=0.5` redundant); no preconditioning for 2-D tensors (due to `Preconditioner.exponent_override[1]=0.0`); and, for 3-D tensors, preconditioning the first and second dimensions with ^0.25 (due to `Preconditioner.exponent_override[2]=0.25`) and no preconditioning of the third dimension (due to `2 in Preconditioner.ignored_dims`). Pair-programmed with runame. Reviewed By: anana10c Differential Revision: D70198403 fbshipit-source-id: 6ba4f84461cc32c185cac949b74c6cd51b33e795 --- distributed_shampoo/distributed_shampoo.py | 6 + .../shampoo_eigenvalue_correction_test.py | 77 ++++++++-- distributed_shampoo/shampoo_types.py | 8 + .../tests/distributed_shampoo_test.py | 16 ++ .../tests/shampoo_types_test.py | 14 ++ .../utils/shampoo_preconditioner_list.py | 141 ++++++++++++++---- .../tests/shampoo_preconditioner_list_test.py | 130 +++++++++++++++- 7 files changed, 341 insertions(+), 51 deletions(-) diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index 27d36ab..df463a0 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -385,6 +385,12 @@ def __init__( "Continuing without using momentum or Nesterov acceleration..." ) + # Check potential conflict between preconditioner_config.ignored_dims and inv_root_override. + if preconditioner_config.ignored_dims != [] and inv_root_override != 0: + raise ValueError( + f"{preconditioner_config.ignored_dims=} is not supported when {inv_root_override=} is not set to 0. Please set {inv_root_override=} to 0 if you set {preconditioner_config.ignored_dims=}."
+ ) + super().__init__( params, { diff --git a/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py b/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py index 19e9177..dabefee 100644 --- a/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py +++ b/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py @@ -20,6 +20,7 @@ from distributed_shampoo.shampoo_types import ( DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig, + EigenvalueCorrectedShampooPreconditionerConfig, ) from distributed_shampoo.tests.shampoo_test_utils import ( compare_two_optimizers_on_weight_and_loss, @@ -48,13 +49,23 @@ def _optim_factory( return optim_cls(parameters, **kwargs) def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: - # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. - for weight_decay, device, preconditioner_config in product( + # Test with and without weight decay, with CPU or GPU, and using eigendecomposition, QR algorithm, or eigendecomposition with all dims ignored. + for weight_decay, device, ( + start_preconditioning_step, + preconditioner_config, + ) in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), - (DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig), + ( + (math.inf, DefaultEigenvalueCorrectedShampooConfig), + (math.inf, DefaultSOAPConfig), + ( + 1, + EigenvalueCorrectedShampooPreconditionerConfig(ignored_dims=[0, 1]), + ), + ), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, @@ -64,6 +75,7 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: with self.subTest( weight_decay=weight_decay, device=device, + start_preconditioning_step=start_preconditioning_step, preconditioner_config=preconditioner_config, ): compare_two_optimizers_on_weight_and_loss( @@ -78,7 +90,7 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: momentum=0.0, max_preconditioner_dim=10, precondition_frequency=1, - start_preconditioning_step=math.inf, + start_preconditioning_step=start_preconditioning_step, use_decoupled_weight_decay=False, grafting_config=None, preconditioner_config=preconditioner_config, @@ -87,13 +99,23 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: ) def test_adam_eigenvalue_correction_on_quadratic(self) -> None: - # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. - for weight_decay, device, preconditioner_config in product( + # Test with and without weight decay, with CPU or GPU, and using eigendecomposition, QR algorithm, or eigendecomposition with all dims ignored. 
+ for weight_decay, device, ( + start_preconditioning_step, + preconditioner_config, + ) in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), - (DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig), + ( + (math.inf, DefaultEigenvalueCorrectedShampooConfig), + (math.inf, DefaultSOAPConfig), + ( + 1, + EigenvalueCorrectedShampooPreconditionerConfig(ignored_dims=[0, 1]), + ), + ), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, @@ -104,6 +126,7 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None: with self.subTest( weight_decay=weight_decay, device=device, + start_preconditioning_step=start_preconditioning_step, preconditioner_config=preconditioner_config, ): compare_two_optimizers_on_weight_and_loss( @@ -119,7 +142,7 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None: momentum=0.0, max_preconditioner_dim=10, precondition_frequency=1, - start_preconditioning_step=math.inf, + start_preconditioning_step=start_preconditioning_step, use_decoupled_weight_decay=False, grafting_config=None, preconditioner_config=preconditioner_config, @@ -128,13 +151,23 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None: ) def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: - # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. - for weight_decay, device, preconditioner_config in product( + # Test with and without weight decay, with CPU or GPU, and using eigendecomposition, QR algorithm, or eigendecomposition with all dims ignored. + for weight_decay, device, ( + start_preconditioning_step, + preconditioner_config, + ) in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), - (DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig), + ( + (math.inf, DefaultEigenvalueCorrectedShampooConfig), + (math.inf, DefaultSOAPConfig), + ( + 1, + EigenvalueCorrectedShampooPreconditionerConfig(ignored_dims=[0, 1]), + ), + ), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, @@ -145,6 +178,7 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: with self.subTest( weight_decay=weight_decay, device=device, + start_preconditioning_step=start_preconditioning_step, preconditioner_config=preconditioner_config, ): compare_two_optimizers_on_weight_and_loss( @@ -160,7 +194,7 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: momentum=0.0, max_preconditioner_dim=10, precondition_frequency=1, - start_preconditioning_step=math.inf, + start_preconditioning_step=start_preconditioning_step, use_decoupled_weight_decay=True, grafting_config=None, preconditioner_config=preconditioner_config, @@ -169,13 +203,23 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: ) def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None: - # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. - for weight_decay, device, preconditioner_config in product( + # Test with and without weight decay, with CPU or GPU, and using eigendecomposition, QR algorithm, or eigendecomposition with all dims ignored. 
+ for weight_decay, device, ( + start_preconditioning_step, + preconditioner_config, + ) in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), - (DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig), + ( + (math.inf, DefaultEigenvalueCorrectedShampooConfig), + (math.inf, DefaultSOAPConfig), + ( + 1, + EigenvalueCorrectedShampooPreconditionerConfig(ignored_dims=[0, 1]), + ), + ), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, @@ -185,6 +229,7 @@ def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None: with self.subTest( weight_decay=weight_decay, device=device, + start_preconditioning_step=start_preconditioning_step, preconditioner_config=preconditioner_config, ): compare_two_optimizers_on_weight_and_loss( @@ -202,7 +247,7 @@ def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None: momentum=0.0, max_preconditioner_dim=10, precondition_frequency=1, - start_preconditioning_step=math.inf, + start_preconditioning_step=start_preconditioning_step, use_decoupled_weight_decay=False, grafting_config=None, use_bias_correction=False, diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index b621571..74bc7e6 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -86,17 +86,23 @@ class PreconditionerConfig(AbstractDataclass): Attributes: amortized_computation_config (MatrixFunctionConfig): Configuration for the amortized computation, e.g., inverse-root or eigenvector computation. num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3) + ignored_dims (list[int]): List of dimensions to ignore when computing the preconditioner. This is equivalent to setting the preconditioner for these dimensions to the identity matrix. (Default: []) """ amortized_computation_config: MatrixFunctionConfig # type: ignore num_tolerated_failed_amortized_computations: int = 3 + ignored_dims: list[int] = field(default_factory=list) def __post_init__(self) -> None: if self.num_tolerated_failed_amortized_computations < 0: raise ValueError( f"Invalid num_tolerated_failed_amortized_computations value: {self.num_tolerated_failed_amortized_computations}. Must be >= 0." ) + if len(self.ignored_dims) != len(set(self.ignored_dims)): + raise ValueError( + f"Invalid ignored_dims value: {self.ignored_dims}. Must be a list of unique dimensions." + ) @dataclass(kw_only=True) @@ -106,6 +112,7 @@ class ShampooPreconditionerConfig(PreconditionerConfig): Attributes: amortized_computation_config (RootInvConfig): Configuration for the inverse-root computation. (Default: DefaultEigenConfig) num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3) + ignored_dims (list[int]): List of dimensions to ignore when computing the preconditioner. This is equivalent to setting the preconditioner for these dimensions to the identity matrix. (Default: []) """ @@ -125,6 +132,7 @@ class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig): amortized_computation_config (EigenvectorConfig): Configuration for the eigenvector computation. (Default: DefaultEighEigenvectorConfig) num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. 
(Default: 3) + ignored_dims (list[int]): List of dimensions to ignore when computing the preconditioner. This is equivalent to setting the preconditioner for these dimensions to the identity matrix. (Default: []) """ diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index 296c962..20baf7b 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -181,6 +181,22 @@ def test_invalid_distributed_config(self) -> None: distributed_config=DDPShampooConfig(), ) + def test_ignored_dims_conflicts_with_inv_root_override(self) -> None: + inv_root_override = 2 + preconditioner_config = ShampooPreconditionerConfig( + ignored_dims=[1, 3], + ) + self.assertRaisesRegex( + ValueError, + re.escape( + f"{preconditioner_config.ignored_dims=} is not supported when {inv_root_override=} is not set to 0. Please set {inv_root_override=} to 0 if you set {preconditioner_config.ignored_dims=}." + ), + DistributedShampoo, + params=self._model.parameters(), + inv_root_override=inv_root_override, + preconditioner_config=preconditioner_config, + ) + class DistributedShampooTest(unittest.TestCase): def setUp(self) -> None: diff --git a/distributed_shampoo/tests/shampoo_types_test.py b/distributed_shampoo/tests/shampoo_types_test.py index 2fbebbb..6766e41 100644 --- a/distributed_shampoo/tests/shampoo_types_test.py +++ b/distributed_shampoo/tests/shampoo_types_test.py @@ -98,3 +98,17 @@ def test_illegal_num_tolerated_failed_amortized_computations(self) -> None: cls, num_tolerated_failed_amortized_computations=num_tolerated_failed_amortized_computations, ) + + def test_illegal_ignored_dims(self) -> None: + ignored_dims = [1, 2, 3, 1] + # Not testing for the base class PreconditionerConfig because it is an abstract class. + for cls in get_all_subclasses(PreconditionerConfig, include_cls_self=False): + with self.subTest(cls=cls): + self.assertRaisesRegex( + ValueError, + re.escape( + f"Invalid ignored_dims value: {ignored_dims}. Must be a list of unique dimensions." + ), + cls, + ignored_dims=ignored_dims, + ) diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index ed3601c..a8c4bda 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -12,7 +12,7 @@ from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass, field from fractions import Fraction -from functools import partial, reduce +from functools import reduce from itertools import chain from typing import Any, cast, Generic, TypeVar @@ -391,7 +391,7 @@ class BaseShampooPreconditionerList( state (Mapping[Tensor, Any]): Mapping containing optimizer state. block_info_list (tuple[BlockInfo, ...]): List containing corresponding BlockInfo for each block/parameter in block_list. Note that this should have the same length as block_list. - preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation. (Default: DefaultShampooConfig) + preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation. beta2 (float): Exponential moving average factor for Shampoo factor matrices. If beta2 = 1., will use unweighted sum. (Default: 1.0) epsilon (float): Epsilon term for regularizing preconditioner to ensure positive definiteness. 
(Default: 1e-12) @@ -428,12 +428,24 @@ def __init__( self._use_bias_correction = use_bias_correction self._bias_correction2: Tensor = torch.tensor(1.0) + preconditioned_dims_selector_list: tuple[tuple[bool, ...], ...] = tuple( + tuple(d not in preconditioner_config.ignored_dims for d in range(len(dims))) + for dims in self._dims_list + ) + preconditioned_dims_list: tuple[tuple[int, ...], ...] = tuple( + compress_list(dims, preconditioned_dims_selector) + for dims, preconditioned_dims_selector in zip( + self._dims_list, preconditioned_dims_selector_list, strict=True + ) + ) + # Create the Kronecker factors. kronecker_factors_list: list[ShampooKroneckerFactorsListType] = ( self._create_kronecker_factors_state( block_list=block_list, state=state, block_info_list=block_info_list, + preconditioned_dims_list=preconditioned_dims_list, ) ) @@ -441,19 +453,21 @@ def __init__( self._initialize_state_lists( block_list=block_list, kronecker_factors_list=kronecker_factors_list, + preconditioned_dims_list=preconditioned_dims_list, + preconditioned_dims_selector_list=preconditioned_dims_selector_list, ) def _create_base_kronecker_factors( self, block_info: BlockInfo, - dims: torch.Size, + preconditioned_dims: tuple[int, ...], ) -> BaseShampooKroneckerFactors: """ Creates a BaseShampooKroneckerFactor object for a given block. Args: block_info (BlockInfo): The BlockInfo object containing information about the block. - dims (torch.Size): The dimensions of the block. + preconditioned_dims (tuple[int, ...]): The preconditioned dimensions of the block. Returns: kronecker_factors_state (BaseShampooKroneckerFactors): An object containing the Kronecker factors for the block. @@ -464,13 +478,13 @@ def _create_base_kronecker_factors( dtype=self._factor_matrix_dtype, device=block_info.param.device, ) - for dim in dims + for dim in preconditioned_dims ) param_index, block_index = block_info.composable_block_ids factor_matrix_indices = tuple( ".".join((str(param_index), str(block_index), str(k))) - for k in range(len(dims)) + for k in range(len(preconditioned_dims)) ) return BaseShampooKroneckerFactors( factor_matrices=factor_matrices, @@ -483,6 +497,7 @@ def _create_kronecker_factors_state_for_block( block: Tensor, block_info: BlockInfo, dims: torch.Size, + preconditioned_dims: tuple[int, ...], ) -> ShampooKroneckerFactorsState | EigenvalueCorrectedShampooKroneckerFactorsState: """ Creates a Kronecker factors state object for a given block. @@ -491,6 +506,7 @@ def _create_kronecker_factors_state_for_block( block (Tensor): The block of the parameter. block_info (BlockInfo): The BlockInfo object containing information about the block. dims (torch.Size): The dimensions of the block. + preconditioned_dims (tuple[int, ...]): The preconditioned dimensions of the block. Returns: kronecker_factors_state (ShampooKroneckerFactorsState | EigenvalueCorrectedShampooKroneckerFactorsState): An object containing the Kronecker factors for the block. @@ -522,6 +538,7 @@ def _create_kronecker_factors_state( # type: ignore state: Mapping[Tensor, Any], block_info_list: tuple[BlockInfo, ...], + preconditioned_dims_list: tuple[tuple[int, ...], ...], ) -> list[ShampooKroneckerFactorsListType]: # Instantiate (blocked) Kronecker factors and construct list of Kronecker factors. 
# NOTE: We need to instantiate the Kronecker factor states within the optimizer's state dictionary, @@ -529,8 +546,12 @@ def _create_kronecker_factors_state( # This is because the optimizer state is defined per-parameter, but ShampooPreconditionerList is defined # across each parameter group (which includes multiple parameters). kronecker_factors_list = [] - for block, block_info, dims in zip( - block_list, block_info_list, self._dims_list, strict=True + for block, block_info, dims, preconditioned_dims in zip( + block_list, + block_info_list, + self._dims_list, + preconditioned_dims_list, + strict=True, ): param_index, block_index = block_info.composable_block_ids if block_index not in state[block_info.param]: @@ -538,7 +559,10 @@ def _create_kronecker_factors_state( block_state = state[block_info.param][block_index] block_state[SHAMPOO] = self._create_kronecker_factors_state_for_block( - block=block, block_info=block_info, dims=dims + block=block, + block_info=block_info, + dims=dims, + preconditioned_dims=preconditioned_dims, ) kronecker_factors_list.append( @@ -591,7 +615,7 @@ def _get_inverse_roots_from_override_with_high_order_default( higher_order_default (Callable[[int], int]): Function for computing the inverse root for orders greater than the length of the inverse root override list. Returns: - root_list (int): Inverse roots to use in Shampoo for a list of tensors. + root_list (tuple[int, ...]): Inverse roots to use in Shampoo for a list of tensors. """ if isinstance(inv_root_override, Sequence): @@ -726,6 +750,8 @@ def _initialize_state_lists( self, block_list: tuple[Tensor, ...], kronecker_factors_list: list[ShampooKroneckerFactorsListType], + preconditioned_dims_list: tuple[tuple[int, ...], ...], + preconditioned_dims_selector_list: tuple[tuple[bool, ...], ...], ) -> None: # Initialize local lists. self._local_kronecker_factors_list: tuple[ @@ -742,6 +768,9 @@ def _initialize_state_lists( self._local_failed_amortized_computation_counter_list: list[int] = [0] * len( self._local_kronecker_factors_list ) + self._local_preconditioned_dims_selector_list: tuple[tuple[bool, ...], ...] = ( + preconditioned_dims_selector_list + ) # Masked lists are the list of active preconditioners or values after filtering out gradients with None. self._masked_order_list: tuple[int, ...] = self._local_order_list @@ -753,11 +782,15 @@ def _initialize_state_lists( ShampooKroneckerFactorsListType, ..., ] = self._local_kronecker_factors_list + self._masked_preconditioned_dims_selector_list: tuple[tuple[bool, ...], ...] = ( + self._local_preconditioned_dims_selector_list + ) # Construct lists of bytes and numels for logging purposes. # NOTE: These lists are constructed across all blocked parameters. self._numel_list: tuple[int, ...] = tuple( - sum(2 * dim**2 for dim in dims) for dims in self._dims_list + sum(2 * dim**2 for dim in preconditioned_dims) + for preconditioned_dims in preconditioned_dims_list ) self._num_bytes_list: tuple[int, ...] 
= tuple( numel @@ -790,6 +823,9 @@ def compress_preconditioner_list( ShampooKroneckerFactorsListType, ..., ] = compress_list(self._local_kronecker_factors_list, local_grad_selector) + self._masked_preconditioned_dims_selector_list = compress_list( # type: ignore[no-redef] + self._local_preconditioned_dims_selector_list, local_grad_selector + ) def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: with profiler.record_function( @@ -798,14 +834,15 @@ def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: # NOTE: Unlike AdagradPreconditionerList, we will loop through each gradient individually. # We apply foreach operators onto the list of Kronecker factor matrices (as opposed to the # full list of gradients/optimizer states). - for grad, order, kronecker_factors in zip( + for grad, order, preconditioned_dims_selector, kronecker_factors in zip( masked_grad_list, self._masked_order_list, + self._masked_preconditioned_dims_selector_list, self._masked_kronecker_factors_list, strict=True, ): # Scale Kronecker factors as a list. - if self._beta2 != 1.0: + if self._beta2 != 1.0 and kronecker_factors.factor_matrices: torch._foreach_mul_(kronecker_factors.factor_matrices, self._beta2) # Construct outer product list for updating Kronecker factors. @@ -816,23 +853,38 @@ def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: # Contracts across all dimensions except for k. dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] ) - for k in range(order) + for k in compress_list(range(order), preconditioned_dims_selector) ) - # Update Kronecker factors. - torch._foreach_add_( - kronecker_factors.factor_matrices, - outer_product_list, - alpha=1 - self._beta2 if self._beta2 != 1.0 else 1.0, - ) + # Because of preconditioned_dims_selector, we may have no factor matrices to update. + if kronecker_factors.factor_matrices: + # Update Kronecker factors. + torch._foreach_add_( + kronecker_factors.factor_matrices, + outer_product_list, + alpha=1 - self._beta2 if self._beta2 != 1.0 else 1.0, + ) @staticmethod def _precondition_grad( grad: Tensor, + preconditioned_dims_selector: tuple[bool, ...], preconditioner_list: tuple[Tensor, ...], dims: tuple[list[int], list[int]] = ([0], [0]), ) -> Tensor: - return reduce(partial(torch.tensordot, dims=dims), preconditioner_list, grad) + # TODO: Need to refactor this function to be more efficient. Ideally eliminate those branches. + # Might consider einsum? + preconditioner_list_iter = iter(preconditioner_list) + return reduce( + lambda grad, should_precondition: torch.tensordot( + grad, next(preconditioner_list_iter), dims=dims + ) + if should_precondition + # Perform a left rotation on grad if not preconditioned. 
+ else grad.permute(*range(1, grad.ndim), 0), + preconditioned_dims_selector, + grad, + ) class ShampooPreconditionerList( @@ -841,7 +893,11 @@ class ShampooPreconditionerList( """Shampoo preconditioners for list of parameters.""" def _create_kronecker_factors_state_for_block( - self, block: Tensor, block_info: BlockInfo, dims: torch.Size + self, + block: Tensor, + block_info: BlockInfo, + dims: torch.Size, + preconditioned_dims: tuple[int, ...], ) -> ShampooKroneckerFactorsState: inv_factor_matrices = tuple( block_info.allocate_zeros_tensor( @@ -849,11 +905,11 @@ def _create_kronecker_factors_state_for_block( dtype=block.dtype, device=block_info.param.device, ) - for dim in dims + for dim in preconditioned_dims ) base_kronecker_factors = self._create_base_kronecker_factors( - block_info=block_info, dims=dims + block_info=block_info, preconditioned_dims=preconditioned_dims ) return ShampooKroneckerFactorsState( factor_matrices=base_kronecker_factors.factor_matrices, @@ -905,10 +961,14 @@ def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, .. return tuple( self._precondition_grad( grad=masked_grad, + preconditioned_dims_selector=preconditioned_dims_selector, preconditioner_list=kronecker_factors.inv_factor_matrices, ) - for masked_grad, kronecker_factors in zip( - masked_grad_list, self._masked_kronecker_factors_list, strict=True + for masked_grad, preconditioned_dims_selector, kronecker_factors in zip( + masked_grad_list, + self._masked_preconditioned_dims_selector_list, + self._masked_kronecker_factors_list, + strict=True, ) ) @@ -1013,7 +1073,11 @@ class EigenvalueCorrectedShampooPreconditionerList( """Eigenvalue-corrected Shampoo preconditioners for list of parameters.""" def _create_kronecker_factors_state_for_block( - self, block: Tensor, block_info: BlockInfo, dims: torch.Size + self, + block: Tensor, + block_info: BlockInfo, + dims: torch.Size, + preconditioned_dims: tuple[int, ...], ) -> EigenvalueCorrectedShampooKroneckerFactorsState: factor_matrices_eigenvectors = tuple( block_info.allocate_zeros_tensor( @@ -1021,16 +1085,17 @@ def _create_kronecker_factors_state_for_block( dtype=block.dtype, device=block_info.param.device, ) - for dim in dims + for dim in preconditioned_dims ) corrected_eigenvalues = block_info.allocate_zeros_tensor( + # Note that the corrected eigenvalues are not affected by the preconditioned_dims. size=tuple(dims), dtype=block.dtype, device=block_info.param.device, ) base_kronecker_factors = self._create_base_kronecker_factors( - block_info=block_info, dims=dims + block_info=block_info, preconditioned_dims=preconditioned_dims ) return EigenvalueCorrectedShampooKroneckerFactorsState( factor_matrices=base_kronecker_factors.factor_matrices, @@ -1107,15 +1172,17 @@ def _update_eigenvalue_corrections( f"## {self.__class__.__name__}:{self._update_eigenvalue_corrections.__name__} ##" ): # NOTE: Unlike AdagradPreconditionerList, we will loop through each gradient individually. - for grad, kronecker_factors in zip( + for grad, preconditioned_dims_selector, kronecker_factors in zip( masked_grad_list, + self._masked_preconditioned_dims_selector_list, self._masked_kronecker_factors_list, strict=True, ): factor_eigenvectors = kronecker_factors.factor_matrices_eigenvectors - if factor_eigenvectors[0].any(): + if factor_eigenvectors and factor_eigenvectors[0].any(): grad = self._precondition_grad( grad=grad, + preconditioned_dims_selector=preconditioned_dims_selector, preconditioner_list=factor_eigenvectors, ) # Scale corrected eigenvalues. 
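For intuition about the selector-based logic introduced in `_precondition_grad` above, here is a minimal standalone sketch (not part of the patch), assuming only plain PyTorch; the helper name `precondition_sketch` and the toy shapes are invented for illustration. It demonstrates the equivalence that the new `test_precondition_grad` below also checks: skipping a dimension via a left rotation of the gradient is the same as contracting that dimension with an identity matrix.

```python
import torch


def precondition_sketch(grad, selector, preconditioners, dims=([0], [0])):
    # Walk over the per-dimension selector: contract the leading dimension with
    # the next preconditioner when selected, otherwise rotate it to the back
    # untouched. Either way the leading dimension moves to the end, so after
    # processing every selector entry the original dimension order is restored.
    preconditioner_iter = iter(preconditioners)
    for should_precondition in selector:
        if should_precondition:
            grad = torch.tensordot(grad, next(preconditioner_iter), dims=dims)
        else:
            grad = grad.permute(*range(1, grad.ndim), 0)
    return grad


torch.manual_seed(0)
grad = torch.randn(2, 3, 4)
preconditioners = [torch.randn(d, d) for d in grad.shape]

# Skip dimension 1: pass only the preconditioners for the selected dimensions.
skipped = precondition_sketch(
    grad, (True, False, True), [preconditioners[0], preconditioners[2]]
)

# Control: precondition every dimension, with the identity in the skipped slot.
with_identity = precondition_sketch(
    grad, (True, True, True), [preconditioners[0], torch.eye(3), preconditioners[2]]
)
torch.testing.assert_close(skipped, with_identity)
```

This also makes clear why the factor matrices for ignored dimensions can be dropped entirely rather than stored as identity matrices.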
@@ -1142,20 +1209,27 @@ def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, .. f"## {self.__class__.__name__}:{self.precondition.__name__} ##" ): preconditioned_grad_list = [] - for masked_grad, kronecker_factors, root in zip( + for ( + masked_grad, + preconditioned_dims_selector, + kronecker_factors, + root, + ) in zip( masked_grad_list, + self._masked_preconditioned_dims_selector_list, self._masked_kronecker_factors_list, self._masked_root_list, strict=True, ): factor_eigenvectors = kronecker_factors.factor_matrices_eigenvectors corrected_eigenvalues = kronecker_factors.corrected_eigenvalues - use_eigenbasis = factor_eigenvectors[0].any() + use_eigenbasis = factor_eigenvectors and factor_eigenvectors[0].any() grad = masked_grad.clone() if use_eigenbasis: # Convert to eigenbasis of Shampoo factor matrices. grad = self._precondition_grad( grad=grad, + preconditioned_dims_selector=preconditioned_dims_selector, preconditioner_list=factor_eigenvectors, ) @@ -1169,6 +1243,7 @@ def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, .. # Convert back to basis of the parameters. grad = self._precondition_grad( grad=grad, + preconditioned_dims_selector=preconditioned_dims_selector, preconditioner_list=factor_eigenvectors, dims=([0], [1]), ) diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 8d11822..14bbab0 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -8,6 +8,7 @@ """ import abc +import math import re import unittest from types import ModuleType @@ -21,7 +22,6 @@ PreconditionerValueError, ShampooPreconditionerConfig, ) - from distributed_shampoo.utils import shampoo_preconditioner_list from distributed_shampoo.utils.shampoo_block_info import BlockInfo from distributed_shampoo.utils.shampoo_preconditioner_list import ( @@ -32,6 +32,8 @@ SGDPreconditionerList, ShampooPreconditionerList, ) + +from distributed_shampoo.utils.shampoo_utils import compress_list from matrix_functions_types import EigenConfig from torch import Tensor @@ -575,6 +577,63 @@ def test_amortized_computation_factor_matrix_non_diagonal( ) mock_check_diagonal.assert_called() + def test_precondition_grad(self) -> None: + # Generate a random gradient tensor with shape (2, 3, 4, 5, 6, 7). + grad = torch.randn((2, 3, 4, 5, 6, 7)) + + # Define selectors for which dimensions to precondition in the experimental setup. + # Note that in the control setup, we precondition all dimensions, using identity matrices for the dimensions marked `False` in the experimental setup. + experimental_preconditioned_dims_selector = ( + True, + False, + False, + True, + True, + False, + ) + # Define selectors for which dimensions to precondition in the control setup. + control_preconditioned_dims_selector = (True,) * grad.ndim + + # Create a list of random preconditioner matrices for each dimension of the gradient. + preconditioner_list = [torch.randn((d, d)) for d in grad.shape] + + # Compress the preconditioner list based on experimental_preconditioned_dims_selector. + experimental_preconditioner_list = compress_list( + preconditioner_list, + experimental_preconditioned_dims_selector, + ) + + # Create a control preconditioner list, using identity matrices where not preconditioning.
+ control_preconditioner_list = [ + preconditioner + if should_precondition + else torch.eye(preconditioner.shape[0]) + for preconditioner, should_precondition in zip( + preconditioner_list, + experimental_preconditioned_dims_selector, + strict=True, + ) + ] + + # Compare the results of preconditioning the gradient with both setups for different contract dimensions. + for dims in (([0], [0]), ([0], [1])): + torch.testing.assert_close( + self._preconditioner_list._precondition_grad( + grad=grad, + preconditioned_dims_selector=experimental_preconditioned_dims_selector, + preconditioner_list=experimental_preconditioner_list, + dims=dims, + ), + self._preconditioner_list._precondition_grad( + grad=grad, + preconditioned_dims_selector=control_preconditioned_dims_selector, + preconditioner_list=control_preconditioner_list, + dims=dims, + ), + rtol=0.0, + atol=0.0, + ) + def test_numel_list(self) -> None: self.assertEqual(self._preconditioner_list.numel_list, (8, 16, 10)) @@ -594,7 +653,7 @@ def test_num_bytes(self) -> None: self.assertEqual(self._preconditioner_list.num_bytes(), 204) def test_compress_preconditioner_list(self) -> None: - self._test_compress_preconditioner_list(expected_compress_list_call_count=4) + self._test_compress_preconditioner_list(expected_compress_list_call_count=5) class ShampooPreconditionerListTest(AbstractTest.BaseShampooPreconditionerListTest): @@ -764,6 +823,73 @@ def test_update_preconditioners_and_precondition(self) -> None: ), ) + def test_update_preconditioners_and_precondition_with_ignored_dims(self) -> None: + """ + + (1) Tensor of Size 2 + G1 = [4, 0]^T + G2 = [0, 4]^T + + L = G1 * G1^T + G2 * G2^T = [[4*4, 0], [0, 4*4]] + P = L^{-1/2} G2 = [[1/4, 0], [0, 1/4]] G2 = [0, 1]^T + + (2) Tensor of Size 2 x 2 + G1 = [[3, 0], [0, 3]] + G2 = [[4, 0], [0, 4]] + + L = G1 * G1^T + G2 * G2^T = [[3*3+4*4, 0], [0, 3*3+4*4]] + R = G1^T * G1 + G2^T * G2 = [[3*3+4*4, 0], [0, 3*3+4*4]] + P = L^{-1/4} G2 R^{-1/4} = [[1/sqrt(5), 0], [0, 1/sqrt(5)]] G2 [[1/sqrt(5), 0], [0, 1/sqrt(5)]] = G2 / 5 + + (3) Tensor of Size 1 x 2 + G1 = [[2, 0]] + G2 = [[0, 2]] + + L = G1 * G1^T + G2 * G2^T = 2*2+2*2 = 8 + R = G1^T * G1 + G2^T * G2 = [[4, 0], [0, 4]] + P = L^{-1/4} G2 R^{-1/4} = 8^{-1/4} G2 [[1/sqrt(2), 0], [0, 1/sqrt(2)]] = G2 / (sqrt(2 * sqrt(8))) + + """ + masked_grad_list1 = ( + torch.tensor([4.0, 0.0]), + torch.eye(2) * 3, + torch.tensor([[2.0, 0.0]]), + ) + masked_grad_list2 = ( + torch.tensor([0.0, 4.0]), + torch.eye(2) * 4, + torch.tensor([[0.0, 2.0]]), + ) + + masked_expected_preconditioned_grad_list = [ + torch.tensor([0.0, 1.0]), + masked_grad_list2[1] / 5, + masked_grad_list2[2] / math.sqrt(2 * math.sqrt(8)), + ] + + # The default case where we do not ignore any dimensions. + self._test_update_preconditioners_and_precondition( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=1.0, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=tuple( + masked_expected_preconditioned_grad_list + ), + ) + + # When ignoring all the dimensions, the preconditioner should be the identity matrix, and the expected preconditioned gradient should be the same as the input gradient. 
+ self._test_update_preconditioners_and_precondition( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=1.0, + preconditioner_config=ShampooPreconditionerConfig( + ignored_dims=[0, 1], + ), + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=masked_grad_list2, + ) + def test_inv_root_override_and_exponent_multiplier(self) -> None: """ For this example, we modify the one given above such that the inv_root_override = 2
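Finally, a hypothetical end-to-end usage sketch of the new `ignored_dims` option (not part of the patch): the import paths follow the files touched in this diff, while the model, learning rate, and the choice `ignored_dims=[1]` are illustrative assumptions only; `inv_root_override` is left at its default of 0, as required by the new check in `__init__`.

```python
import torch

from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import ShampooPreconditionerConfig

# Toy model; any nn.Module works.
model = torch.nn.Linear(16, 4)

# Ignore dimension 1 of every block, so 2-D blocks keep only their left factor
# matrix and dimension 1 is effectively preconditioned by the identity.
optimizer = DistributedShampoo(
    model.parameters(),
    lr=1e-2,
    preconditioner_config=ShampooPreconditionerConfig(ignored_dims=[1]),
)

loss = model(torch.randn(8, 16)).square().sum()
loss.backward()
optimizer.step()
```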