From 55f1413a5b3e73896e43ceed851aaa0b1dc7b014 Mon Sep 17 00:00:00 2001 From: runame Date: Fri, 6 Dec 2024 15:27:30 +0000 Subject: [PATCH 01/22] Refactor matrix functions types --- matrix_functions.py | 20 ++++++++++---------- matrix_functions_types.py | 28 ++++++++++++++-------------- tests/matrix_functions_test.py | 16 +++++++--------- 3 files changed, 31 insertions(+), 33 deletions(-) diff --git a/matrix_functions.py b/matrix_functions.py index 70b1366..9a14093 100644 --- a/matrix_functions.py +++ b/matrix_functions.py @@ -20,11 +20,11 @@ CoupledHigherOrderConfig, CoupledNewtonConfig, DefaultEigenConfig, - DefaultEighEigenvalueCorrectionConfig, + DefaultEighConfig, EigenConfig, - EigenvalueCorrectionConfig, - EighEigenvalueCorrectionConfig, - QREigenvalueCorrectionConfig, + EigenvectorConfig, + EighConfig, + QRConfig, RootInvConfig, ) @@ -599,7 +599,7 @@ def compute_matrix_root_inverse_residuals( def matrix_eigenvectors( A: Tensor, eigenvectors_estimate: Tensor | None = None, - eigenvector_computation_config: EigenvalueCorrectionConfig = DefaultEighEigenvalueCorrectionConfig, + eigenvector_computation_config: EigenvectorConfig = DefaultEighConfig, is_diagonal: bool = False, ) -> Tensor: """Compute eigenvectors of matrix using eigendecomposition of symmetric positive (semi-)definite matrix. @@ -612,8 +612,8 @@ def matrix_eigenvectors( A (Tensor): Square matrix of interest. eigenvectors_estimate (Tensor | None): The current estimate of the eigenvectors of A. (Default: None) - eigenvector_computation_config (EigenvalueCorrectionConfig): Determines how eigenvectors are computed. - (Default: DefaultEighEigenvalueCorrectionConfig) + eigenvector_computation_config (EigenvectorConfig): Determines how eigenvectors are computed. + (Default: DefaultEighConfig) is_diagonal (bool): Whether A is diagonal. (Default: False) Returns: @@ -638,15 +638,15 @@ def matrix_eigenvectors( device=A.device, ) - if type(eigenvector_computation_config) is EighEigenvalueCorrectionConfig: + if type(eigenvector_computation_config) is EighConfig: return _compute_eigenvalue_decomposition( A, retry_double_precision=eigenvector_computation_config.retry_double_precision, )[1] - elif type(eigenvector_computation_config) is QREigenvalueCorrectionConfig: + elif type(eigenvector_computation_config) is QRConfig: assert ( eigenvectors_estimate is not None - ), "Estimate of eigenvectors is required when using QREigenvalueCorrectionConfig." + ), "Estimate of eigenvectors is required when using QRConfig." return _compute_orthogonal_iterations( A, eigenvectors_estimate=eigenvectors_estimate, diff --git a/matrix_functions_types.py b/matrix_functions_types.py index af267d3..4a58450 100644 --- a/matrix_functions_types.py +++ b/matrix_functions_types.py @@ -13,18 +13,18 @@ @dataclass(init=False) -class PreconditionerComputationConfig(AbstractDataclass): - """Configuration for preconditioner computation in Shampoo.""" +class MatrixFunctionConfig(AbstractDataclass): + """Base dataclass for matrix function configurations.""" @dataclass(init=False) -class RootInvConfig(PreconditionerComputationConfig): - """Base dataclass for matrix root inverse method configurations in Shampoo.""" +class RootInvConfig(MatrixFunctionConfig): + """Base dataclass for matrix root inverse (`matrix_inverse_root`) method configurations.""" @dataclass(kw_only=True) class EigenConfig(RootInvConfig): - """Configuration for eigendecomposition method in Shampoo. + """Configuration for eigendecomposition (`_matrix_inverse_root_eigen`) method. 
Args: make_positive_semidefinite (bool): Perturbs matrix eigenvalues to ensure it is numerically positive semi-definite. (Default: True) @@ -45,7 +45,7 @@ class EigenConfig(RootInvConfig): @dataclass(kw_only=True) class CoupledNewtonConfig(RootInvConfig): - """Configuration for coupled Newton method in Shampoo. + """Configuration for coupled Newton (`_matrix_inverse_root_newton`) method. Args: max_iterations (int): Maximum number of iterations for coupled Newton iteration. (Default: 100) @@ -59,7 +59,7 @@ class CoupledNewtonConfig(RootInvConfig): @dataclass(kw_only=True) class CoupledHigherOrderConfig(RootInvConfig): - """Configuration for coupled higher-order method in Shampoo. + """Configuration for coupled higher-order (`_matrix_inverse_root_higher_order`) method. Args: rel_epsilon (float): Relative epsilon for coupled higher order method. Adds epsilon * lambda_max * I to matrix @@ -81,13 +81,13 @@ class CoupledHigherOrderConfig(RootInvConfig): @dataclass(init=False) -class EigenvalueCorrectionConfig(PreconditionerComputationConfig): - """Base dataclass for matrix eigenvector method configurations in eigenvalue-corrected Shampoo.""" +class EigenvectorConfig(MatrixFunctionConfig): + """Base dataclass for matrix eigenvector (`matrix_eigenvectors`) method.""" @dataclass(kw_only=True) -class EighEigenvalueCorrectionConfig(EigenvalueCorrectionConfig): - """Configuration for eigendecomposition method used in eigenvalue-corrected Shampoo. +class EighConfig(EigenvectorConfig): + """Configuration for eigendecomposition (`_compute_eigenvalue_decomposition`) method. Args: retry_double_precision (bool): Whether to re-trying eigendecomposition with higher(double) precision if lower precision fails due @@ -98,12 +98,12 @@ class EighEigenvalueCorrectionConfig(EigenvalueCorrectionConfig): retry_double_precision: bool = True -DefaultEighEigenvalueCorrectionConfig = EighEigenvalueCorrectionConfig() +DefaultEighConfig = EighConfig() @dataclass(kw_only=True) -class QREigenvalueCorrectionConfig(EigenvalueCorrectionConfig): - """Configuration for orthogonal/simultaneous iterations (QR algorithm) used in eigenvalue-corrected Shampoo. +class QRConfig(EigenvectorConfig): + """Configuration for orthogonal/simultaneous iterations/QR algorithm (`_compute_orthogonal_iterations`). Args: max_iterations (int): The maximum number of iterations to perform. (Default: 1) diff --git a/tests/matrix_functions_test.py b/tests/matrix_functions_test.py index 1658201..61f7069 100644 --- a/tests/matrix_functions_test.py +++ b/tests/matrix_functions_test.py @@ -35,8 +35,8 @@ CoupledHigherOrderConfig, CoupledNewtonConfig, EigenConfig, - EigenvalueCorrectionConfig, - QREigenvalueCorrectionConfig, + EigenvectorConfig, + QRConfig, RootInvConfig, ) from torch import Tensor @@ -859,18 +859,16 @@ def test_matrix_eigenvectors(self) -> None: rtol=rtol, ) - # Tests for `QREigenvalueCorrectionConfig`. + # Tests for `QRConfig`. initialization_strategies = { "zero": lambda A: torch.zeros_like(A), "identity": lambda A: torch.eye(A.shape[0], dtype=A.dtype, device=A.device), "exact": lambda A: matrix_eigenvectors(A), # Eigendecomposition. } for name, initialization_fn in initialization_strategies.items(): - with self.subTest( - f"Test with QREigenvalueCorrectionConfig with {name} initialization." - ): + with self.subTest(f"Test with QRConfig with {name} initialization."): # Set `max_iterations` to large int to run until numerical tolerance. 
- qr_config = QREigenvalueCorrectionConfig(max_iterations=10_000) + qr_config = QRConfig(max_iterations=10_000) for A, expected_eigenvectors in zip( A_list, expected_eigenvectors_list, strict=True ): @@ -899,12 +897,12 @@ def test_invalid_eigenvalue_correction_config( mock.patch.object( matrix_functions, "type", - side_effect=lambda object: EigenvalueCorrectionConfig, + side_effect=lambda object: EigenvectorConfig, ), self.assertRaisesRegex( NotImplementedError, re.escape( - "Eigenvector computation method is not implemented! Specified eigenvector method is eigenvector_computation_config=EighEigenvalueCorrectionConfig(retry_double_precision=True)." + "Eigenvector computation method is not implemented! Specified eigenvector method is eigenvector_computation_config=EighConfig(retry_double_precision=True)." ), ), ): From 9291d1599edd4c8532e04ae85f8094893bc3b3b9 Mon Sep 17 00:00:00 2001 From: runame Date: Fri, 6 Dec 2024 15:34:34 +0000 Subject: [PATCH 02/22] Refactor Shampoo types --- distributed_shampoo/distributed_shampoo.py | 60 +++++++++++-------- distributed_shampoo/examples/trainer_utils.py | 21 ++++--- .../shampoo_eigenvalue_correction_test.py | 16 ++--- distributed_shampoo/shampoo_types.py | 60 +++++++++++++++++++ .../tests/distributed_shampoo_test.py | 22 +++---- .../utils/shampoo_preconditioner_list.py | 48 ++++++++------- .../tests/shampoo_preconditioner_list_test.py | 23 +++++-- 7 files changed, 171 insertions(+), 79 deletions(-) diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index ef8e2ce..67d3fa4 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -24,8 +24,10 @@ BETAS, DAMPENING, DDPShampooConfig, + DefaultShampooConfig, DistributedConfig, DISTRIBUTOR, + EigenvalueCorrectedShampooPreconditionerConfig, EPSILON, FILTERED_GRAD, FILTERED_GRAD_LIST, @@ -49,10 +51,12 @@ PrecisionConfig, PRECONDITION_FREQUENCY, PRECONDITIONER_COMPUTATION_CONFIG, + PreconditionerComputationConfig, PREVIOUS_GRAD_SELECTOR, RMSpropGraftingConfig, SGDGraftingConfig, SHAMPOO_PRECONDITIONER_LIST, + ShampooPreconditionerConfig, ShampooPT2CompileConfig, START_PRECONDITIONING_STEP, STEP, @@ -91,13 +95,7 @@ ) from distributed_shampoo.utils.shampoo_utils import compress_list -from matrix_functions_types import ( - DefaultEigenConfig, - EigenConfig, - EigenvalueCorrectionConfig, - PreconditionerComputationConfig, - RootInvConfig, -) +from matrix_functions_types import EigenConfig, RootInvConfig from torch.optim.optimizer import ParamsT, StateDict logger: logging.Logger = logging.getLogger(__name__) @@ -216,7 +214,7 @@ class DistributedShampoo(torch.optim.Optimizer): updated every iteration while the eigenbasis of Shampoo's preconditioner is only computed every `precondition_frequency` steps. Alternatively, this can be seen as running Adam in the eigenbasis of Shampoo's preconditioner, also known as SOAP. - When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectionConfig`, there is typically no need to use learning + When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectedShampooPreconditionerConfig`, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. 
Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. @@ -236,8 +234,8 @@ class DistributedShampoo(torch.optim.Optimizer): weight_decay (float): Weight decay (L2 penalty). (Default: 0.) max_preconditioner_dim (int): Maximum preconditioner dimension. (Default: 1024) precondition_frequency (int): Frequency of updating all components of the preconditioner. - If this field is an instance RootInvConfig, this is the update frequency of the root inverse of the preconditioner. - If this field is an instance EigenvalueCorrectionConfig, this is the update frequency of the eigenbasis of preconditioner. + If this field is an instance ShampooPreconditionerConfig, this is the update frequency of the root inverse of the preconditioner. + If this field is an instance EigenvalueCorrectedShampooPreconditionerConfig, this is the update frequency of the eigenbasis of preconditioner. (Default: 1) start_preconditioning_step (int): Iteration to start computing inverse preconditioner. If -1, uses the same value as precondition_frequency. (Default: -1) @@ -245,7 +243,7 @@ class DistributedShampoo(torch.optim.Optimizer): use -1 / l1 for 1-D tensor (vectors), -1 / l2 for 2-D tensors (matrices), and so on. If the order of the tensor exceeds the order of the tensor, reverts to the default value. If 0 is used, uses the default inverse root -1 / (2 * o), where o is the order of the tensor. If preconditioner_computation_config is an instance of - EigenvalueCorrectionConfig, the default is -1 / 2. + EigenvalueCorrectedShampooPreconditionerConfig, the default is -1 / 2. (Default: 0) exponent_multiplier (float | None): **DEPRECATING** Number to be multiplied to the numerator of the inverse root, i.e., eta where the exponent is -eta / (2 * p). (Default: None) @@ -272,8 +270,8 @@ class DistributedShampoo(torch.optim.Optimizer): track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes. (Default: False) preconditioner_computation_config (PreconditionerComputationConfig): Configuration for preconditioner computation. - If this field is an instance RootInvConfig, Shampoo uses the root inverse of the preconditioner. - If this field is an instance EigenvalueCorrectionConfig Shampoo uses corrected the eigenvalues/running Adam in the eigenbasis of preconditioner. + If this field is an instance ShampooPreconditionerConfig, Shampoo uses the root inverse of the preconditioner. + If this field is an instance EigenvalueCorrectedShampooPreconditionerConfig Shampoo uses corrected the eigenvalues/running Adam in the eigenbasis of preconditioner. (Default: DefaultEigenConfig) """ @@ -305,7 +303,7 @@ def __init__( precision_config: PrecisionConfig | None = None, use_protected_eigh: bool = True, track_root_inv_residuals: bool = False, - preconditioner_computation_config: PreconditionerComputationConfig = DefaultEigenConfig, + preconditioner_computation_config: PreconditionerComputationConfig = DefaultShampooConfig, ) -> None: # Hyperparameter checks. if not lr >= 0.0: @@ -419,11 +417,14 @@ def __init__( "Both preconditioner_dtype and precision_config are provided. Please use only precision_config as preconditioner_dtype is deprecated." 
) + amortized_computation_config = ( + preconditioner_computation_config.amortized_computation_config + ) if ( - not isinstance(preconditioner_computation_config, RootInvConfig) + not isinstance(amortized_computation_config, RootInvConfig) ) and track_root_inv_residuals: raise ValueError( - f"{track_root_inv_residuals=} has to be set to False when {preconditioner_computation_config=} is not an instance of RootInvConfig." + f"{track_root_inv_residuals=} has to be set to False when {amortized_computation_config=} is not an instance of RootInvConfig." ) # Create default precision config if it is not provided. @@ -432,14 +433,14 @@ def __init__( # Set exponent multiplier if this is not provided. if ( - isinstance(preconditioner_computation_config, EigenConfig) + isinstance(amortized_computation_config, EigenConfig) and exponent_multiplier is not None ): logger.warning( f"{exponent_multiplier=} is deprecating. Please consider using EigenConfig.exponent_multiplier directly and setting exponent_multipler=None instead in the future." ) - preconditioner_computation_config = dataclasses.replace( - preconditioner_computation_config, + amortized_computation_config = dataclasses.replace( + amortized_computation_config, exponent_multiplier=exponent_multiplier, ) @@ -534,13 +535,22 @@ def _instantiate_shampoo_preconditioner_list( for state_lists, group in zip( self._per_group_state_lists, self.param_groups, strict=True ): - state_lists[SHAMPOO_PRECONDITIONER_LIST] = ( - EigenvalueCorrectedShampooPreconditionerList - if isinstance( - group[PRECONDITIONER_COMPUTATION_CONFIG], EigenvalueCorrectionConfig + if ( + type(group[PRECONDITIONER_COMPUTATION_CONFIG]) + is ShampooPreconditionerConfig + ): + preconditioner_list_cls = ShampooPreconditionerList + elif ( + type(group[PRECONDITIONER_COMPUTATION_CONFIG]) + is EigenvalueCorrectedShampooPreconditionerConfig + ): + preconditioner_list_cls = EigenvalueCorrectedShampooPreconditionerList # type: ignore[assignment] + else: + raise NotImplementedError( + f"{group[PRECONDITIONER_COMPUTATION_CONFIG]=} not supported!" 
) - else ShampooPreconditionerList - )( + + state_lists[SHAMPOO_PRECONDITIONER_LIST] = preconditioner_list_cls( block_list=state_lists[DISTRIBUTOR].global_blocked_params, state=self.state, block_info_list=state_lists[DISTRIBUTOR].global_block_info_list, diff --git a/distributed_shampoo/examples/trainer_utils.py b/distributed_shampoo/examples/trainer_utils.py index 635a57b..e819d72 100644 --- a/distributed_shampoo/examples/trainer_utils.py +++ b/distributed_shampoo/examples/trainer_utils.py @@ -24,16 +24,17 @@ CommunicationDType, CoupledHigherOrderConfig, CoupledNewtonConfig, + DefaultEigenvalueCorrectedShampooConfig, + DefaultShampooConfig, + DefaultSOAPConfig, DistributedConfig, DistributedShampoo, - EigenConfig, - EighEigenvalueCorrectionConfig, GraftingConfig, PrecisionConfig, PreconditionerComputationConfig, - QREigenvalueCorrectionConfig, RMSpropGraftingConfig, SGDGraftingConfig, + ShampooPreconditionerConfig, ) from distributed_shampoo.examples.convnet import ConvNet @@ -541,27 +542,31 @@ def instantiate_preconditioner_computation_config( preconditioner_computation_type: PreconditionerComputationType, ) -> PreconditionerComputationConfig: if preconditioner_computation_type == PreconditionerComputationType.EIGEN_ROOT_INV: - return EigenConfig() + return DefaultShampooConfig elif ( preconditioner_computation_type == PreconditionerComputationType.COUPLED_NEWTON_ROOT_INV ): - return CoupledNewtonConfig() + return ShampooPreconditionerConfig( + amortized_computation_config=CoupledNewtonConfig(), + ) elif ( preconditioner_computation_type == PreconditionerComputationType.COUPLED_HIGHER_ORDER_ROOT_INV ): - return CoupledHigherOrderConfig() + return ShampooPreconditionerConfig( + amortized_computation_config=CoupledHigherOrderConfig(), + ) elif ( preconditioner_computation_type == PreconditionerComputationType.EIGH_EIGENVALUE_CORRECTION ): - return EighEigenvalueCorrectionConfig() + return DefaultEigenvalueCorrectedShampooConfig elif ( preconditioner_computation_type == PreconditionerComputationType.QR_EIGENVALUE_CORRECTION ): - return QREigenvalueCorrectionConfig() + return DefaultSOAPConfig else: raise ValueError( f"Invalid PreconditionerComputationType {preconditioner_computation_type}!" 
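The sketch below is illustrative only and is not part of the patch; it assumes the names and import paths introduced by this commit (the wrapper configs added to `shampoo_types.py` and the renamed matrix-function configs from the previous commit). It spells out the key structural change made in `trainer_utils.py` above: an amortized matrix-function config is no longer passed directly but is nested inside a preconditioner config via `amortized_computation_config`.

```python
# Illustrative only -- not part of the patch. Assumes the names introduced by this commit.
from distributed_shampoo.shampoo_types import (
    DefaultEigenvalueCorrectedShampooConfig,  # wraps DefaultEighConfig
    DefaultShampooConfig,                     # wraps DefaultEigenConfig
    DefaultSOAPConfig,                        # wraps QRConfig()
    EigenvalueCorrectedShampooPreconditionerConfig,
    ShampooPreconditionerConfig,
)
from matrix_functions_types import CoupledNewtonConfig, QRConfig

# Before: a RootInvConfig such as CoupledNewtonConfig() was passed directly.
# After: it is nested inside a ShampooPreconditionerConfig.
newton_shampoo = ShampooPreconditionerConfig(
    amortized_computation_config=CoupledNewtonConfig(),
)

# Eigenvalue-corrected Shampoo/SOAP wraps an EigenvectorConfig instead.
qr_soap = EigenvalueCorrectedShampooPreconditionerConfig(
    amortized_computation_config=QRConfig(),
)

# The module-level defaults are pre-built instances of these wrappers.
assert isinstance(DefaultShampooConfig, ShampooPreconditionerConfig)
assert isinstance(DefaultSOAPConfig, EigenvalueCorrectedShampooPreconditionerConfig)
assert isinstance(
    DefaultEigenvalueCorrectedShampooConfig,
    EigenvalueCorrectedShampooPreconditionerConfig,
)
```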
diff --git a/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py b/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py index f63ab7a..1d0b6ec 100644 --- a/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py +++ b/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py @@ -17,13 +17,13 @@ import torch from distributed_shampoo.distributed_shampoo import DistributedShampoo +from distributed_shampoo.shampoo_types import ( + DefaultEigenvalueCorrectedShampooConfig, + DefaultSOAPConfig, +) from distributed_shampoo.tests.shampoo_test_utils import ( compare_two_optimizers_on_weight_and_loss, ) -from matrix_functions_types import ( - DefaultEighEigenvalueCorrectionConfig, - QREigenvalueCorrectionConfig, -) from torch.optim.adagrad import Adagrad from torch.optim.adam import Adam from torch.optim.adamw import AdamW @@ -54,7 +54,7 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), - (DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()), + (DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, @@ -93,7 +93,7 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None: (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), - (DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()), + (DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, @@ -134,7 +134,7 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), - (DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()), + (DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, @@ -175,7 +175,7 @@ def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None: (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), - (DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()), + (DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index e0dd6a4..42af3ce 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -13,6 +13,15 @@ import torch from commons import AbstractDataclass + +from matrix_functions_types import ( + DefaultEigenConfig, + DefaultEighConfig, + EigenvectorConfig, + MatrixFunctionConfig, + QRConfig, + RootInvConfig, +) from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import ShardingStrategy from torch.nn.parameter import Parameter @@ -71,6 +80,57 @@ class PreconditionerValueError(ValueError): ###### DATACLASSES ###### +@dataclass(init=False) +class PreconditionerComputationConfig(AbstractDataclass): + """Configuration for preconditioner computation in DistributedShampoo. + + Args: + amortized_computation_config (MatrixFunctionConfig): Configuration for the amortized computation, e.g., inverse-root or eigenvector computation. 
+ + """ + + amortized_computation_config: MatrixFunctionConfig + + +@dataclass(kw_only=True) +class ShampooPreconditionerConfig(PreconditionerComputationConfig): + """Configuration for Shampoo preconditioner computation. + + Args: + amortized_computation_config (RootInvConfig): Configuration for the inverse-root computation. + + """ + + amortized_computation_config: RootInvConfig + + +DefaultShampooConfig = ShampooPreconditionerConfig( + amortized_computation_config=DefaultEigenConfig +) + + +@dataclass(kw_only=True) +class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerComputationConfig): + """Configuration for eigenvalue-corrected Shampoo/SOAP preconditioner computation. + + Args: + amortized_computation_config (EigenvectorConfig): Configuration for the eigenvector computation. + + """ + + amortized_computation_config: EigenvectorConfig + + +DefaultEigenvalueCorrectedShampooConfig = ( + EigenvalueCorrectedShampooPreconditionerConfig( + amortized_computation_config=DefaultEighConfig, + ) +) +DefaultSOAPConfig = EigenvalueCorrectedShampooPreconditionerConfig( + amortized_computation_config=QRConfig(), +) + + @dataclass class FSDPParameterMetadata: """FSDP Metadata for a parameter. diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index 396db52..05aa644 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -22,6 +22,8 @@ from distributed_shampoo.shampoo_types import ( AdaGradGraftingConfig, DDPShampooConfig, + DefaultEigenvalueCorrectedShampooConfig, + DefaultShampooConfig, DistributedConfig, GRAFTING_PRECONDITIONER_LIST, GraftingConfig, @@ -33,10 +35,6 @@ ShampooPT2CompileConfig, ) from distributed_shampoo.utils.shampoo_quantization import QuantizedTensorList -from matrix_functions_types import ( - DefaultEigenConfig, - DefaultEighEigenvalueCorrectionConfig, -) from torch import nn @@ -49,7 +47,11 @@ def setUp(self) -> None: def test_invalid_grafting_config(self) -> None: with ( mock.patch.object( - distributed_shampoo, "type", side_effect=lambda object: GraftingConfig + distributed_shampoo, + "type", + side_effect=lambda object: GraftingConfig + if type(object) is SGDGraftingConfig + else type(object), ), self.assertRaisesRegex( NotImplementedError, @@ -251,7 +253,7 @@ def test_setting_exponent_multiplier_with_eigen_config(self) -> None: lr=0.01, start_preconditioning_step=1, exponent_multiplier=2.0, - preconditioner_computation_config=DefaultEigenConfig, + preconditioner_computation_config=DefaultShampooConfig, ) self.assertCountEqual( [r.msg for r in cm.records], @@ -264,7 +266,7 @@ def test_conflict_eigenvalue_correction_and_track_root_inv_residuals(self) -> No with self.assertRaisesRegex( ValueError, re.escape( - "track_root_inv_residuals=True has to be set to False when preconditioner_computation_config=EighEigenvalueCorrectionConfig(retry_double_precision=True) is not an instance of RootInvConfig." + "track_root_inv_residuals=True has to be set to False when amortized_computation_config=EighConfig(retry_double_precision=True) is not an instance of RootInvConfig." 
), ): DistributedShampoo( @@ -272,7 +274,7 @@ def test_conflict_eigenvalue_correction_and_track_root_inv_residuals(self) -> No lr=0.01, start_preconditioning_step=1, track_root_inv_residuals=True, - preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + preconditioner_computation_config=DefaultEigenvalueCorrectedShampooConfig, ) @@ -495,7 +497,7 @@ def setUp(self) -> None: ), "use_merge_dims": True, "precision_config": PrecisionConfig(), - "preconditioner_computation_config": DefaultEigenConfig, + "preconditioner_computation_config": DefaultShampooConfig, } }, } @@ -889,7 +891,7 @@ def _instantiate_optimizer( distributed_config=None, grafting_config=None, precision_config=precision_config, - preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + preconditioner_computation_config=DefaultEigenvalueCorrectedShampooConfig, ) def _assert_state_list_dtype( diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index 7bb3231..4c89442 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -19,7 +19,11 @@ from typing import Any, cast, Generic, TypeVar import torch -from distributed_shampoo.shampoo_types import PrecisionConfig, PreconditionerValueError +from distributed_shampoo.shampoo_types import ( + PrecisionConfig, + PreconditionerComputationConfig, + PreconditionerValueError, +) from distributed_shampoo.utils.shampoo_block_info import BlockInfo from distributed_shampoo.utils.shampoo_quantization import ( QuantizedTensor, @@ -38,12 +42,7 @@ matrix_inverse_root, ) -from matrix_functions_types import ( - DefaultEigenConfig, - EigenvalueCorrectionConfig, - PreconditionerComputationConfig, - RootInvConfig, -) +from matrix_functions_types import EigenvectorConfig, RootInvConfig from optimizer_modules import OptimizerModule from torch import Tensor from torch.autograd import profiler @@ -428,7 +427,7 @@ class BaseShampooPreconditionerList( distributor_selector (tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter is selected by the current Distributor. precision_config (PrecisionConfig): Data types for optimizer states. (Default: all fields torch.float) - preconditioner_computation_config (PreconditionerComputationConfig): Configuration for preconditioner computation. (Default: DefaultEigenConfig) + preconditioner_computation_config (PreconditionerComputationConfig): Configuration for preconditioner computation. (Default: DefaultShampooConfig) beta2 (float): Exponential moving average factor for Shampoo factor matrices. If beta2 = 1., will use unweighted sum. (Default: 1.0) epsilon (float): Epsilon term for regularizing preconditioner to ensure positive definiteness. (Default: 1e-12) @@ -452,7 +451,7 @@ def __init__( block_info_list: tuple[BlockInfo, ...], distributor_selector: tuple[bool, ...], precision_config: PrecisionConfig, - preconditioner_computation_config: PreconditionerComputationConfig = DefaultEigenConfig, + preconditioner_computation_config: PreconditionerComputationConfig, beta2: float = 1.0, epsilon: float = 1e-12, inv_root_override: int | tuple[int, ...] = 0, @@ -959,20 +958,22 @@ def _amortized_computation(self) -> None: ) # Compute inverse preconditioner. 
+ root_inv_config = cast( + RootInvConfig, + self._preconditioner_computation_config.amortized_computation_config, + ) try: computed_inv_factor_matrix = matrix_inverse_root( A=bias_corrected_factor_matrix, root=Fraction( root / getattr( - self._preconditioner_computation_config, + root_inv_config, "exponent_multiplier", 1, ) ), - root_inv_config=cast( - RootInvConfig, self._preconditioner_computation_config - ), + root_inv_config=root_inv_config, epsilon=self._epsilon, is_diagonal=bool(is_factor_matrix_diagonal), ).to(dtype=inv_factor_matrix.dtype) @@ -983,7 +984,7 @@ def _amortized_computation(self) -> None: else: logger.warning( f"Matrix computation failed for factor matrix {factor_matrix_index} " - f"with {exception=}. Using previous inversed factor matrix and continuing..." + f"with {exception=}. Using previous inverted factor matrix and continuing..." ) # Define computed_inv_factor_matrix to prevent undefined local variable error. computed_inv_factor_matrix = inv_factor_matrix @@ -1020,6 +1021,10 @@ def quantize_preconditioners(self) -> None: def compute_root_inverse_residuals( self, ) -> tuple[tuple[Tensor, ...], tuple[Tensor, ...]]: + root_inv_config = cast( + RootInvConfig, + self._preconditioner_computation_config.amortized_computation_config, + ) relative_errors = [] relative_residuals = [] @@ -1043,15 +1048,13 @@ def compute_root_inverse_residuals( root=Fraction( root / getattr( - self._preconditioner_computation_config, + root_inv_config, "exponent_multiplier", 1, ) ), epsilon=self._epsilon, - root_inv_config=cast( - RootInvConfig, self._preconditioner_computation_config - ), + root_inv_config=root_inv_config, ) relative_errors.append(relative_error) relative_residuals.append(relative_residual) @@ -1277,14 +1280,15 @@ def _amortized_computation(self) -> None: ) # Compute eigenvectors of factor matrix. 
+ eigenvector_computation_config = cast( + EigenvectorConfig, + self._preconditioner_computation_config.amortized_computation_config, + ) try: computed_eigenvectors = matrix_eigenvectors( A=factor_matrix, eigenvectors_estimate=factor_matrix_eigenvectors, - eigenvector_computation_config=cast( - EigenvalueCorrectionConfig, - self._preconditioner_computation_config, - ), + eigenvector_computation_config=eigenvector_computation_config, is_diagonal=bool(is_factor_matrix_diagonal), ) except Exception as exception: diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 6d38af6..54f87c0 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -15,7 +15,13 @@ from unittest import mock import torch -from distributed_shampoo.shampoo_types import PrecisionConfig, PreconditionerValueError +from distributed_shampoo.shampoo_types import ( + DefaultEigenvalueCorrectedShampooConfig, + DefaultShampooConfig, + PrecisionConfig, + PreconditionerValueError, + ShampooPreconditionerConfig, +) from distributed_shampoo.utils import shampoo_preconditioner_list from distributed_shampoo.utils.shampoo_block_info import BlockInfo @@ -29,7 +35,7 @@ ShampooPreconditionerList, ) from distributed_shampoo.utils.shampoo_quantization import QuantizedTensorList -from matrix_functions_types import DefaultEighEigenvalueCorrectionConfig, EigenConfig +from matrix_functions import EigenConfig from torch import Tensor @@ -304,6 +310,7 @@ def test_abstract_methods(self) -> None: ), distributor_selector=(True,), precision_config=PrecisionConfig(), + preconditioner_computation_config=DefaultShampooConfig, beta2=1.0, ) @@ -526,6 +533,7 @@ def _instantiate_preconditioner_list( # type: ignore[override] "inv_root_override": 0, "use_bias_correction": True, "use_protected_eigh": True, + "preconditioner_computation_config": DefaultShampooConfig, } | kwargs return ShampooPreconditionerList( block_list=self._block_list, @@ -533,7 +541,7 @@ def _instantiate_preconditioner_list( # type: ignore[override] block_info_list=self._block_info_list, distributor_selector=self._distributor_selector, precision_config=PrecisionConfig(factor_matrix_dtype=torch.float64), - **kwargs, + **kwargs, # type: ignore[arg-type] ) def test_update_preconditioners_and_precondition(self) -> None: @@ -718,7 +726,9 @@ def test_inverse_roots_from_override( """ Tests that the inverse roots are computed correctly from inv_root_override. 
""" - preconditioner_computation_config = EigenConfig(exponent_multiplier=2.0) + preconditioner_computation_config = ShampooPreconditionerConfig( + amortized_computation_config=EigenConfig(exponent_multiplier=2.0), + ) masked_grad_list1 = ( torch.tensor([1.0, 0.0]), @@ -766,6 +776,7 @@ def test_compute_root_inverse_residuals(self) -> None: block_info_list=(self._block_info_list[0],), distributor_selector=(self._distributor_selector[0],), precision_config=PrecisionConfig(), + preconditioner_computation_config=DefaultShampooConfig, epsilon=0.0, ) @@ -811,7 +822,7 @@ def _instantiate_preconditioner_list( # type: ignore[override] "inv_root_override": 0, "use_bias_correction": True, "use_protected_eigh": True, - "preconditioner_computation_config": DefaultEighEigenvalueCorrectionConfig, + "preconditioner_computation_config": DefaultEigenvalueCorrectedShampooConfig, } | kwargs return EigenvalueCorrectedShampooPreconditionerList( block_list=self._block_list, @@ -1044,7 +1055,7 @@ def test_inverse_roots_from_override( beta2=1.0, use_bias_correction=True, inv_root_override=inv_root_override, - preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + preconditioner_computation_config=DefaultEigenvalueCorrectedShampooConfig, ), masked_grad_lists=[masked_grad_list1, masked_grad_list2], masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, From dac3ac1e467bde41a2f1547190fac6d1128a7238 Mon Sep 17 00:00:00 2001 From: runame Date: Fri, 6 Dec 2024 15:35:14 +0000 Subject: [PATCH 03/22] Adjust UI and docs --- distributed_shampoo/README.md | 15 +++++++------ distributed_shampoo/__init__.py | 38 +++++++++++++++++++++------------ 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/distributed_shampoo/README.md b/distributed_shampoo/README.md index 0d35aa2..9b1e48d 100644 --- a/distributed_shampoo/README.md +++ b/distributed_shampoo/README.md @@ -64,7 +64,7 @@ A few notes on hyperparameters: - We allow for decoupled and coupled weight decay. If one sets `use_decoupled_weight_decay=True`, then you are enabling AdamW-style weight decay, while `use_decoupled_weight_decay=False` corresponds to the normal L2-regularization style weight decay. -- When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectionConfig`, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. +- When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectedShampooPreconditionerConfig` (see Example 5), there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. 
### Example 1: [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html) with Momentum @@ -221,7 +221,7 @@ optimizer = DistributedShampoo( ) ``` -### Example 5: eigenvalue-corrected Shampoo (SOAP) +### Example 5: eigenvalue-corrected Shampoo/SOAP If we previously used the optimizer: ```python @@ -241,7 +241,10 @@ optimizer = AdamW( we would instead use: ```python import torch -from distributed_shampoo import DistributedShampoo, EighEigenvalueCorrectionConfig +from distributed_shampoo import ( + DistributedShampoo, + DefaultEigenvalueCorrectedShampooConfig, +) model = instantiate_model() @@ -254,9 +257,9 @@ optimizer = DistributedShampoo( max_preconditioner_dim=8192, precondition_frequency=100, use_decoupled_weight_decay=True, - # This can also be set to `QREigenvalueCorrectionConfig` which is less expensive - # and might therefore allow for a smaller `precondition_frequency`. - preconditioner_computation_config=EighEigenvalueCorrectionConfig(), + # This can also be set to `DefaultSOAPConfig` which uses QR decompositions, hence is + # less expensive and might thereby allow for a smaller `precondition_frequency`. + preconditioner_computation_config=DefaultEigenvalueCorrectedShampooConfig, ) ``` diff --git a/distributed_shampoo/__init__.py b/distributed_shampoo/__init__.py index 2b2c45f..6ba8c2a 100644 --- a/distributed_shampoo/__init__.py +++ b/distributed_shampoo/__init__.py @@ -13,14 +13,20 @@ AdamGraftingConfig, CommunicationDType, DDPShampooConfig, + DefaultEigenvalueCorrectedShampooConfig, + DefaultShampooConfig, + DefaultSOAPConfig, DistributedConfig, + EigenvalueCorrectedShampooPreconditionerConfig, FSDPShampooConfig, FullyShardShampooConfig, GraftingConfig, HSDPShampooConfig, PrecisionConfig, + PreconditionerComputationConfig, RMSpropGraftingConfig, SGDGraftingConfig, + ShampooPreconditionerConfig, ShampooPT2CompileConfig, ) from distributed_shampoo.utils.shampoo_fsdp_utils import compile_fsdp_parameter_metadata @@ -28,12 +34,9 @@ CoupledHigherOrderConfig, CoupledNewtonConfig, DefaultEigenConfig, - DefaultEighEigenvalueCorrectionConfig, EigenConfig, - EigenvalueCorrectionConfig, - EighEigenvalueCorrectionConfig, - PreconditionerComputationConfig, - QREigenvalueCorrectionConfig, + EigenvectorConfig, + MatrixFunctionConfig, RootInvConfig, ) @@ -58,15 +61,22 @@ "PrecisionConfig", # `preconditioner_computation_config` options. "PreconditionerComputationConfig", # Abstract base class. - "RootInvConfig", # Abstract base class (based on `PreconditionerComputationConfig`). - "EigenConfig", - "DefaultEigenConfig", # Default `RootInvConfig`. - "CoupledNewtonConfig", - "CoupledHigherOrderConfig", - "EigenvalueCorrectionConfig", # Abstract base class (based on `PreconditionerComputationConfig`). - "EighEigenvalueCorrectionConfig", - "DefaultEighEigenvalueCorrectionConfig", # Default `EigenvalueCorrectionConfig`. - "QREigenvalueCorrectionConfig", + "ShampooPreconditionerConfig", # Based on `PreconditionerComputationConfig`. + "DefaultShampooConfig", # Default `ShampooPreconditionerConfig` using `EigenConfig`. + "EigenvalueCorrectedShampooPreconditionerConfig", # Based on `PreconditionerComputationConfig`. + "DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighConfig`. + "DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QRConfig`. + # matrix functions configs. + "MatrixFunctionConfig", # Abstract base class. + "RootInvConfig", # Abstract base class (based on `MatrixFunctionConfig`). 
+ "EigenConfig", # Based on `RootInvConfig`. + "DefaultEigenConfig", # Default `RootInvConfig` using `EigenConfig`. + "CoupledNewtonConfig", # Based on `RootInvConfig`. + "CoupledHigherOrderConfig", # Based on `RootInvConfig`. + "EigenvectorConfig", # Abstract base class (based on `MatrixFunctionConfig`). + "EighConfig", # Based on `EigenvectorConfig`. + "DefaultEighConfig", # Default `EigenvectorConfig` using `EighConfig`. + "QRConfig", # Based on `EigenvectorConfig`. # Other utilities. "compile_fsdp_parameter_metadata", # For `FSDPShampooConfig` and `HSDPShampooConfig`. "CommunicationDType", # For `DDPShampooConfig` and `HSDPShampooConfig`. From 7e20cf37b2b1b22ae342bcbbf40346cc3b98bd93 Mon Sep 17 00:00:00 2001 From: runame Date: Fri, 6 Dec 2024 20:26:59 +0000 Subject: [PATCH 04/22] Replace preconditioner_computation_config with preconditioner_config --- distributed_shampoo/README.md | 4 +- distributed_shampoo/__init__.py | 10 ++--- distributed_shampoo/distributed_shampoo.py | 33 +++++++--------- distributed_shampoo/examples/trainer_utils.py | 8 ++-- .../shampoo_eigenvalue_correction_test.py | 24 ++++++------ distributed_shampoo/shampoo_types.py | 8 ++-- .../tests/distributed_shampoo_test.py | 39 ++++++++++++++----- .../utils/shampoo_preconditioner_list.py | 14 +++---- .../tests/shampoo_preconditioner_list_test.py | 14 +++---- 9 files changed, 83 insertions(+), 71 deletions(-) diff --git a/distributed_shampoo/README.md b/distributed_shampoo/README.md index 9b1e48d..ef9fd53 100644 --- a/distributed_shampoo/README.md +++ b/distributed_shampoo/README.md @@ -64,7 +64,7 @@ A few notes on hyperparameters: - We allow for decoupled and coupled weight decay. If one sets `use_decoupled_weight_decay=True`, then you are enabling AdamW-style weight decay, while `use_decoupled_weight_decay=False` corresponds to the normal L2-regularization style weight decay. -- When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectedShampooPreconditionerConfig` (see Example 5), there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. +- When setting `preconditioner_config` as an instance of `EigenvalueCorrectedShampooPreconditionerConfig` (see Example 5), there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. ### Example 1: [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html) with Momentum @@ -259,7 +259,7 @@ optimizer = DistributedShampoo( use_decoupled_weight_decay=True, # This can also be set to `DefaultSOAPConfig` which uses QR decompositions, hence is # less expensive and might thereby allow for a smaller `precondition_frequency`. 
- preconditioner_computation_config=DefaultEigenvalueCorrectedShampooConfig, + preconditioner_config=DefaultEigenvalueCorrectedShampooConfig, ) ``` diff --git a/distributed_shampoo/__init__.py b/distributed_shampoo/__init__.py index 6ba8c2a..da71b9e 100644 --- a/distributed_shampoo/__init__.py +++ b/distributed_shampoo/__init__.py @@ -23,7 +23,7 @@ GraftingConfig, HSDPShampooConfig, PrecisionConfig, - PreconditionerComputationConfig, + PreconditionerConfig, RMSpropGraftingConfig, SGDGraftingConfig, ShampooPreconditionerConfig, @@ -59,11 +59,11 @@ "HSDPShampooConfig", # `precision_config`. "PrecisionConfig", - # `preconditioner_computation_config` options. - "PreconditionerComputationConfig", # Abstract base class. - "ShampooPreconditionerConfig", # Based on `PreconditionerComputationConfig`. + # `preconditioner_config` options. + "PreconditionerConfig", # Abstract base class. + "ShampooPreconditionerConfig", # Based on `PreconditionerConfig`. "DefaultShampooConfig", # Default `ShampooPreconditionerConfig` using `EigenConfig`. - "EigenvalueCorrectedShampooPreconditionerConfig", # Based on `PreconditionerComputationConfig`. + "EigenvalueCorrectedShampooPreconditionerConfig", # Based on `PreconditionerConfig`. "DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighConfig`. "DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QRConfig`. # matrix functions configs. diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index 67d3fa4..376daad 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -50,8 +50,8 @@ PRECISION_CONFIG, PrecisionConfig, PRECONDITION_FREQUENCY, - PRECONDITIONER_COMPUTATION_CONFIG, - PreconditionerComputationConfig, + PRECONDITIONER_CONFIG, + PreconditionerConfig, PREVIOUS_GRAD_SELECTOR, RMSpropGraftingConfig, SGDGraftingConfig, @@ -214,7 +214,7 @@ class DistributedShampoo(torch.optim.Optimizer): updated every iteration while the eigenbasis of Shampoo's preconditioner is only computed every `precondition_frequency` steps. Alternatively, this can be seen as running Adam in the eigenbasis of Shampoo's preconditioner, also known as SOAP. - When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectedShampooPreconditionerConfig`, there is typically no need to use learning + When setting `preconditioner_config` as an instance of `EigenvalueCorrectedShampooPreconditionerConfig`, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. @@ -242,7 +242,7 @@ class DistributedShampoo(torch.optim.Optimizer): inv_root_override (int, Sequence[int]): Inverse root to use in Shampoo. If a list [l1, l2, ..., lp], then we will use -1 / l1 for 1-D tensor (vectors), -1 / l2 for 2-D tensors (matrices), and so on. If the order of the tensor exceeds the order of the tensor, reverts to the default value. If 0 is used, uses the default inverse - root -1 / (2 * o), where o is the order of the tensor. If preconditioner_computation_config is an instance of + root -1 / (2 * o), where o is the order of the tensor. 
If preconditioner_config is an instance of EigenvalueCorrectedShampooPreconditionerConfig, the default is -1 / 2. (Default: 0) exponent_multiplier (float | None): **DEPRECATING** Number to be multiplied to the numerator of the inverse root, i.e., eta where the @@ -269,7 +269,7 @@ class DistributedShampoo(torch.optim.Optimizer): 3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail. track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes. (Default: False) - preconditioner_computation_config (PreconditionerComputationConfig): Configuration for preconditioner computation. + preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation. If this field is an instance ShampooPreconditionerConfig, Shampoo uses the root inverse of the preconditioner. If this field is an instance EigenvalueCorrectedShampooPreconditionerConfig Shampoo uses corrected the eigenvalues/running Adam in the eigenbasis of preconditioner. (Default: DefaultEigenConfig) @@ -303,7 +303,7 @@ def __init__( precision_config: PrecisionConfig | None = None, use_protected_eigh: bool = True, track_root_inv_residuals: bool = False, - preconditioner_computation_config: PreconditionerComputationConfig = DefaultShampooConfig, + preconditioner_config: PreconditionerConfig = DefaultShampooConfig, ) -> None: # Hyperparameter checks. if not lr >= 0.0: @@ -418,7 +418,7 @@ def __init__( ) amortized_computation_config = ( - preconditioner_computation_config.amortized_computation_config + preconditioner_config.amortized_computation_config ) if ( not isinstance(amortized_computation_config, RootInvConfig) @@ -464,7 +464,7 @@ def __init__( GRAFTING_CONFIG: grafting_config, USE_MERGE_DIMS: use_merge_dims, PRECISION_CONFIG: precision_config, - PRECONDITIONER_COMPUTATION_CONFIG: preconditioner_computation_config, + PRECONDITIONER_CONFIG: preconditioner_config, }, ) @@ -535,19 +535,16 @@ def _instantiate_shampoo_preconditioner_list( for state_lists, group in zip( self._per_group_state_lists, self.param_groups, strict=True ): - if ( - type(group[PRECONDITIONER_COMPUTATION_CONFIG]) - is ShampooPreconditionerConfig - ): + if type(group[PRECONDITIONER_CONFIG]) is ShampooPreconditionerConfig: preconditioner_list_cls = ShampooPreconditionerList elif ( - type(group[PRECONDITIONER_COMPUTATION_CONFIG]) + type(group[PRECONDITIONER_CONFIG]) is EigenvalueCorrectedShampooPreconditionerConfig ): preconditioner_list_cls = EigenvalueCorrectedShampooPreconditionerList # type: ignore[assignment] else: raise NotImplementedError( - f"{group[PRECONDITIONER_COMPUTATION_CONFIG]=} not supported!" + f"{group[PRECONDITIONER_CONFIG]=} not supported!" ) state_lists[SHAMPOO_PRECONDITIONER_LIST] = preconditioner_list_cls( @@ -555,9 +552,7 @@ def _instantiate_shampoo_preconditioner_list( state=self.state, block_info_list=state_lists[DISTRIBUTOR].global_block_info_list, distributor_selector=state_lists[DISTRIBUTOR].distributor_selector, - preconditioner_computation_config=group[ - PRECONDITIONER_COMPUTATION_CONFIG - ], + preconditioner_config=group[PRECONDITIONER_CONFIG], precision_config=group[PRECISION_CONFIG], beta2=group[BETAS][1], epsilon=group[EPSILON], @@ -598,9 +593,7 @@ def _instantiate_grafting(self) -> None: is AdamGraftingConfig, ) else: - raise NotImplementedError( - f"Unsupported grafting config: {group[GRAFTING_CONFIG]=}." 
- ) + raise NotImplementedError(f"{group[GRAFTING_CONFIG]=} not supported!") @torch.no_grad() def _instantiate_steps(self) -> None: diff --git a/distributed_shampoo/examples/trainer_utils.py b/distributed_shampoo/examples/trainer_utils.py index e819d72..313d980 100644 --- a/distributed_shampoo/examples/trainer_utils.py +++ b/distributed_shampoo/examples/trainer_utils.py @@ -31,7 +31,7 @@ DistributedShampoo, GraftingConfig, PrecisionConfig, - PreconditionerComputationConfig, + PreconditionerConfig, RMSpropGraftingConfig, SGDGraftingConfig, ShampooPreconditionerConfig, @@ -498,7 +498,7 @@ def instantiate_optimizer( precision_config=precision_config, use_protected_eigh=use_protected_eigh, track_root_inv_residuals=track_root_inv_residuals, - preconditioner_computation_config=instantiate_preconditioner_computation_config( + preconditioner_config=instantiate_preconditioner_config( preconditioner_computation_type ), ) # type: ignore[assignment] @@ -538,9 +538,9 @@ def instantiate_grafting_config( raise ValueError(f"Invalid GraftingType {grafting_type}!") -def instantiate_preconditioner_computation_config( +def instantiate_preconditioner_config( preconditioner_computation_type: PreconditionerComputationType, -) -> PreconditionerComputationConfig: +) -> PreconditionerConfig: if preconditioner_computation_type == PreconditionerComputationType.EIGEN_ROOT_INV: return DefaultShampooConfig elif ( diff --git a/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py b/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py index 1d0b6ec..19e9177 100644 --- a/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py +++ b/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py @@ -49,7 +49,7 @@ def _optim_factory( def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. - for weight_decay, device, preconditioner_computation_config in product( + for weight_decay, device, preconditioner_config in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() @@ -64,7 +64,7 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: with self.subTest( weight_decay=weight_decay, device=device, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ): compare_two_optimizers_on_weight_and_loss( control_optim_factory=partial( @@ -81,14 +81,14 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: start_preconditioning_step=math.inf, use_decoupled_weight_decay=False, grafting_config=None, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ), device=device, ) def test_adam_eigenvalue_correction_on_quadratic(self) -> None: # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. 
- for weight_decay, device, preconditioner_computation_config in product( + for weight_decay, device, preconditioner_config in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() @@ -104,7 +104,7 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None: with self.subTest( weight_decay=weight_decay, device=device, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ): compare_two_optimizers_on_weight_and_loss( control_optim_factory=partial( @@ -122,14 +122,14 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None: start_preconditioning_step=math.inf, use_decoupled_weight_decay=False, grafting_config=None, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ), device=device, ) def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. - for weight_decay, device, preconditioner_computation_config in product( + for weight_decay, device, preconditioner_config in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() @@ -145,7 +145,7 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: with self.subTest( weight_decay=weight_decay, device=device, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ): compare_two_optimizers_on_weight_and_loss( control_optim_factory=partial( @@ -163,14 +163,14 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: start_preconditioning_step=math.inf, use_decoupled_weight_decay=True, grafting_config=None, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ), device=device, ) def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None: # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. 
- for weight_decay, device, preconditioner_computation_config in product( + for weight_decay, device, preconditioner_config in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() @@ -185,7 +185,7 @@ def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None: with self.subTest( weight_decay=weight_decay, device=device, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ): compare_two_optimizers_on_weight_and_loss( control_optim_factory=partial( @@ -206,7 +206,7 @@ def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None: use_decoupled_weight_decay=False, grafting_config=None, use_bias_correction=False, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ), device=device, ) diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 42af3ce..76ab2b3 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -44,7 +44,7 @@ PRECISION_CONFIG = "precision_config" PRECONDITION_FREQUENCY = "precondition_frequency" PRECONDITIONER_DTYPE = "preconditioner_dtype" -PRECONDITIONER_COMPUTATION_CONFIG = "preconditioner_computation_config" +PRECONDITIONER_CONFIG = "preconditioner_config" START_PRECONDITIONING_STEP = "start_preconditioning_step" USE_EIGENVALUE_CORRECTION = "use_eigenvalue_correction" USE_BIAS_CORRECTION = "use_bias_correction" @@ -81,7 +81,7 @@ class PreconditionerValueError(ValueError): ###### DATACLASSES ###### @dataclass(init=False) -class PreconditionerComputationConfig(AbstractDataclass): +class PreconditionerConfig(AbstractDataclass): """Configuration for preconditioner computation in DistributedShampoo. Args: @@ -93,7 +93,7 @@ class PreconditionerComputationConfig(AbstractDataclass): @dataclass(kw_only=True) -class ShampooPreconditionerConfig(PreconditionerComputationConfig): +class ShampooPreconditionerConfig(PreconditionerConfig): """Configuration for Shampoo preconditioner computation. Args: @@ -110,7 +110,7 @@ class ShampooPreconditionerConfig(PreconditionerComputationConfig): @dataclass(kw_only=True) -class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerComputationConfig): +class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig): """Configuration for eigenvalue-corrected Shampoo/SOAP preconditioner computation. 
Args: diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index 05aa644..01d841f 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -30,8 +30,10 @@ MASKED_FILTERED_GRAD_LIST, MASKED_MOMENTUM_LIST, PrecisionConfig, + PreconditionerConfig, SGDGraftingConfig, SHAMPOO_PRECONDITIONER_LIST, + ShampooPreconditionerConfig, ShampooPT2CompileConfig, ) from distributed_shampoo.utils.shampoo_quantization import QuantizedTensorList @@ -44,21 +46,38 @@ def setUp(self) -> None: nn.Linear(5, 10, bias=False), ) - def test_invalid_grafting_config(self) -> None: + def test_invalid_preconditioner_config(self) -> None: with ( mock.patch.object( distributed_shampoo, "type", - side_effect=lambda object: GraftingConfig - if type(object) is SGDGraftingConfig - else type(object), + side_effect=lambda object: { + ShampooPreconditionerConfig: PreconditionerConfig + }.get(type(object), type(object)), ), self.assertRaisesRegex( NotImplementedError, - re.escape( - "Unsupported grafting config: group[GRAFTING_CONFIG]=SGDGraftingConfig" + re.escape("group[PRECONDITIONER_CONFIG]=ShampooPreconditionerConfig"), + ), + ): + DistributedShampoo( + self._model.parameters(), + preconditioner_config=DefaultShampooConfig, + ) + + def test_invalid_grafting_config(self) -> None: + with ( + mock.patch.object( + distributed_shampoo, + "type", + side_effect=lambda object: {SGDGraftingConfig: GraftingConfig}.get( + type(object), type(object) ), ), + self.assertRaisesRegex( + NotImplementedError, + re.escape("group[GRAFTING_CONFIG]=SGDGraftingConfig"), + ), ): DistributedShampoo( self._model.parameters(), @@ -253,7 +272,7 @@ def test_setting_exponent_multiplier_with_eigen_config(self) -> None: lr=0.01, start_preconditioning_step=1, exponent_multiplier=2.0, - preconditioner_computation_config=DefaultShampooConfig, + preconditioner_config=DefaultShampooConfig, ) self.assertCountEqual( [r.msg for r in cm.records], @@ -274,7 +293,7 @@ def test_conflict_eigenvalue_correction_and_track_root_inv_residuals(self) -> No lr=0.01, start_preconditioning_step=1, track_root_inv_residuals=True, - preconditioner_computation_config=DefaultEigenvalueCorrectedShampooConfig, + preconditioner_config=DefaultEigenvalueCorrectedShampooConfig, ) @@ -497,7 +516,7 @@ def setUp(self) -> None: ), "use_merge_dims": True, "precision_config": PrecisionConfig(), - "preconditioner_computation_config": DefaultShampooConfig, + "preconditioner_config": DefaultShampooConfig, } }, } @@ -891,7 +910,7 @@ def _instantiate_optimizer( distributed_config=None, grafting_config=None, precision_config=precision_config, - preconditioner_computation_config=DefaultEigenvalueCorrectedShampooConfig, + preconditioner_config=DefaultEigenvalueCorrectedShampooConfig, ) def _assert_state_list_dtype( diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index 4c89442..ed15da8 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -21,7 +21,7 @@ import torch from distributed_shampoo.shampoo_types import ( PrecisionConfig, - PreconditionerComputationConfig, + PreconditionerConfig, PreconditionerValueError, ) from distributed_shampoo.utils.shampoo_block_info import BlockInfo @@ -427,7 +427,7 @@ class BaseShampooPreconditionerList( distributor_selector (tuple[bool, ...]): Distributor selector is a 
boolean list indicating whether a blocked parameter is selected by the current Distributor. precision_config (PrecisionConfig): Data types for optimizer states. (Default: all fields torch.float) - preconditioner_computation_config (PreconditionerComputationConfig): Configuration for preconditioner computation. (Default: DefaultShampooConfig) + preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation. (Default: DefaultShampooConfig) beta2 (float): Exponential moving average factor for Shampoo factor matrices. If beta2 = 1., will use unweighted sum. (Default: 1.0) epsilon (float): Epsilon term for regularizing preconditioner to ensure positive definiteness. (Default: 1e-12) @@ -451,7 +451,7 @@ def __init__( block_info_list: tuple[BlockInfo, ...], distributor_selector: tuple[bool, ...], precision_config: PrecisionConfig, - preconditioner_computation_config: PreconditionerComputationConfig, + preconditioner_config: PreconditionerConfig, beta2: float = 1.0, epsilon: float = 1e-12, inv_root_override: int | tuple[int, ...] = 0, @@ -462,7 +462,7 @@ def __init__( # Initialize parameters. self._precision_config = precision_config - self._preconditioner_computation_config = preconditioner_computation_config + self._preconditioner_config = preconditioner_config self._beta2 = beta2 self._epsilon = epsilon self._inv_root_override = inv_root_override @@ -960,7 +960,7 @@ def _amortized_computation(self) -> None: # Compute inverse preconditioner. root_inv_config = cast( RootInvConfig, - self._preconditioner_computation_config.amortized_computation_config, + self._preconditioner_config.amortized_computation_config, ) try: computed_inv_factor_matrix = matrix_inverse_root( @@ -1023,7 +1023,7 @@ def compute_root_inverse_residuals( ) -> tuple[tuple[Tensor, ...], tuple[Tensor, ...]]: root_inv_config = cast( RootInvConfig, - self._preconditioner_computation_config.amortized_computation_config, + self._preconditioner_config.amortized_computation_config, ) relative_errors = [] relative_residuals = [] @@ -1282,7 +1282,7 @@ def _amortized_computation(self) -> None: # Compute eigenvectors of factor matrix. eigenvector_computation_config = cast( EigenvectorConfig, - self._preconditioner_computation_config.amortized_computation_config, + self._preconditioner_config.amortized_computation_config, ) try: computed_eigenvectors = matrix_eigenvectors( diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 54f87c0..8bb061f 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -310,7 +310,7 @@ def test_abstract_methods(self) -> None: ), distributor_selector=(True,), precision_config=PrecisionConfig(), - preconditioner_computation_config=DefaultShampooConfig, + preconditioner_config=DefaultShampooConfig, beta2=1.0, ) @@ -533,7 +533,7 @@ def _instantiate_preconditioner_list( # type: ignore[override] "inv_root_override": 0, "use_bias_correction": True, "use_protected_eigh": True, - "preconditioner_computation_config": DefaultShampooConfig, + "preconditioner_config": DefaultShampooConfig, } | kwargs return ShampooPreconditionerList( block_list=self._block_list, @@ -726,7 +726,7 @@ def test_inverse_roots_from_override( """ Tests that the inverse roots are computed correctly from inv_root_override. 
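One detail of the invalid-config tests above deserves a standalone illustration before the test changes continue below: the name `type` is patched inside the module under test so that a perfectly valid config instance reports its abstract base class, which forces the code down its `NotImplementedError` branch. A self-contained sketch of the same trick, using toy classes rather than the Shampoo types:

import unittest
from unittest import mock

class Base: ...

class Concrete(Base): ...

def dispatch(config: Base) -> str:
    # Stand-in for the optimizer's chain of `type(...) is SomeConfig` checks.
    if type(config) is Concrete:
        return "handled"
    raise NotImplementedError(f"Unsupported config: {type(config).__name__}")

class DispatchTest(unittest.TestCase):
    def test_unsupported_config(self) -> None:
        # Patch `type` as resolved inside this module; the lambda captures the real
        # builtin via a default argument so it does not recurse into the mock.
        with mock.patch(
            f"{__name__}.type",
            create=True,
            side_effect=lambda obj, _type=type: {Concrete: Base}.get(_type(obj), _type(obj)),
        ):
            with self.assertRaisesRegex(NotImplementedError, "Unsupported config: Base"):
                dispatch(Concrete())

if __name__ == "__main__":
    unittest.main()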
""" - preconditioner_computation_config = ShampooPreconditionerConfig( + preconditioner_config = ShampooPreconditionerConfig( amortized_computation_config=EigenConfig(exponent_multiplier=2.0), ) @@ -753,7 +753,7 @@ def test_inverse_roots_from_override( beta2=1.0, use_bias_correction=True, inv_root_override=inv_root_override, - preconditioner_computation_config=preconditioner_computation_config, + preconditioner_config=preconditioner_config, ), masked_grad_lists=[masked_grad_list1, masked_grad_list2], masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, @@ -776,7 +776,7 @@ def test_compute_root_inverse_residuals(self) -> None: block_info_list=(self._block_info_list[0],), distributor_selector=(self._distributor_selector[0],), precision_config=PrecisionConfig(), - preconditioner_computation_config=DefaultShampooConfig, + preconditioner_config=DefaultShampooConfig, epsilon=0.0, ) @@ -822,7 +822,7 @@ def _instantiate_preconditioner_list( # type: ignore[override] "inv_root_override": 0, "use_bias_correction": True, "use_protected_eigh": True, - "preconditioner_computation_config": DefaultEigenvalueCorrectedShampooConfig, + "preconditioner_config": DefaultEigenvalueCorrectedShampooConfig, } | kwargs return EigenvalueCorrectedShampooPreconditionerList( block_list=self._block_list, @@ -1055,7 +1055,7 @@ def test_inverse_roots_from_override( beta2=1.0, use_bias_correction=True, inv_root_override=inv_root_override, - preconditioner_computation_config=DefaultEigenvalueCorrectedShampooConfig, + preconditioner_config=DefaultEigenvalueCorrectedShampooConfig, ), masked_grad_lists=[masked_grad_list1, masked_grad_list2], masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, From 1bbaa2f1193867412210b2fe9902dcf51229437c Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 7 Dec 2024 19:19:51 +0000 Subject: [PATCH 05/22] Fix docstring --- distributed_shampoo/distributed_shampoo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index 376daad..c31043c 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -272,7 +272,7 @@ class DistributedShampoo(torch.optim.Optimizer): preconditioner_config (PreconditionerConfig): Configuration for preconditioner computation. If this field is an instance ShampooPreconditionerConfig, Shampoo uses the root inverse of the preconditioner. If this field is an instance EigenvalueCorrectedShampooPreconditionerConfig Shampoo uses corrected the eigenvalues/running Adam in the eigenbasis of preconditioner. - (Default: DefaultEigenConfig) + (Default: DefaultShampooConfig) """ From 32c5df5707c7034276b18a01c745c8face3958d6 Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 9 Dec 2024 16:59:01 +0000 Subject: [PATCH 06/22] Add tolerance for amortized computation failures --- distributed_shampoo/shampoo_types.py | 2 + .../utils/shampoo_preconditioner_list.py | 95 +++++++++++++++---- 2 files changed, 81 insertions(+), 16 deletions(-) diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 76ab2b3..234c598 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -86,10 +86,12 @@ class PreconditionerConfig(AbstractDataclass): Args: amortized_computation_config (MatrixFunctionConfig): Configuration for the amortized computation, e.g., inverse-root or eigenvector computation. 
+ num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3) """ amortized_computation_config: MatrixFunctionConfig + num_tolerated_failed_amortized_computations: int = 3 @dataclass(kw_only=True) diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index ed15da8..5d70dd8 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -695,6 +695,28 @@ def _check_factor_matrix_for_diagonality_nan_and_inf( f"{factor_matrix.isinf().any()=}, {factor_matrix.isnan().any()=}." ) + def _raise_exception_if_tolerance_exceeded( + self, counter: int, exception: Exception + ) -> None: + """Raises an exception if the number of failed amortized computations exceeds the tolerance. + + Args: + counter (int): The counter for the number of failed amortized computations. + exception (Exception): The exception to raise. + + Raises: + exception (Exception): The exception to raise. + + """ + tolerance = ( + self._preconditioner_config.num_tolerated_failed_amortized_computations + ) + if counter > tolerance: + logger.error( + f"Exceeded tolerance ({tolerance}) for number of failed amortized computations." + ) + raise exception + def update_preconditioners( self, masked_grad_list: tuple[Tensor, ...], @@ -746,10 +768,18 @@ def _initialize_state_lists( self._inv_root_override, self._local_order_list, ) + self._local_failed_amortized_computation_counter_list: list[list[int]] = [ + [0] * len(kronecker_factors.factor_matrices) + for kronecker_factors in self._local_kronecker_factors_list + if kronecker_factors is not None + ] # Masked lists are the list of active preconditioners or values after filtering out gradients with None. self._masked_order_list: tuple[int, ...] = self._local_order_list self._masked_root_list: tuple[int, ...] = self._local_root_list + self._masked_failed_amortized_computation_counter_list: list[list[int]] = ( + self._local_failed_amortized_computation_counter_list + ) self._masked_kronecker_factors_list: tuple[ ShampooKroneckerFactorsListType, ..., @@ -785,6 +815,14 @@ def compress_preconditioner_list( self._masked_root_list: tuple[int, ...] 
= compress_list( # type: ignore[no-redef] self._local_root_list, local_grad_selector ) + self._masked_failed_amortized_computation_counter_list: list[list[int]] = ( # type: ignore[no-redef] + list( + compress_list( + self._local_failed_amortized_computation_counter_list, + local_grad_selector, + ) + ) + ) self._masked_kronecker_factors_list: tuple[ # type: ignore[no-redef] ShampooKroneckerFactorsListType, ..., @@ -929,22 +967,25 @@ def _amortized_computation(self) -> None: with profiler.record_function( f"## {self.__class__.__name__}:{self._amortized_computation.__name__} ##" ): - for kronecker_factors, root in zip( + for kronecker_factors, root, fail_counter_list in zip( self._masked_kronecker_factors_list, self._masked_root_list, + self._masked_failed_amortized_computation_counter_list, strict=True, ): - for ( + for idx, ( factor_matrix, inv_factor_matrix, is_factor_matrix_diagonal, factor_matrix_index, - ) in zip( - kronecker_factors.factor_matrices.dequantized_value, - kronecker_factors.inv_factor_matrices.dequantized_value, - kronecker_factors.is_factor_matrices_diagonal, - kronecker_factors.factor_matrix_indices, - strict=True, + ) in enumerate( + zip( + kronecker_factors.factor_matrices.dequantized_value, + kronecker_factors.inv_factor_matrices.dequantized_value, + kronecker_factors.is_factor_matrices_diagonal, + kronecker_factors.factor_matrix_indices, + strict=True, + ) ): # Add epsilon term and incorporate bias correction. bias_corrected_factor_matrix = ( @@ -977,11 +1018,19 @@ def _amortized_computation(self) -> None: epsilon=self._epsilon, is_diagonal=bool(is_factor_matrix_diagonal), ).to(dtype=inv_factor_matrix.dtype) + # Reset counter for failed amortized computations. + fail_counter_list[idx] = 0 except Exception as exception: # If self._use_protected_eigh is True, will reuse previous matrix if matrix inverse root computation fails. if not self._use_protected_eigh: raise exception else: + # Increment counter for failed amortized computations. + fail_counter_list[idx] += 1 + # Only reuse previous matrix if tolerance is not exceeded. + self._raise_exception_if_tolerance_exceeded( + fail_counter_list[idx], exception + ) logger.warning( f"Matrix computation failed for factor matrix {factor_matrix_index} " f"with {exception=}. Using previous inverted factor matrix and continuing..." 
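Stripped of the Kronecker-factor bookkeeping, the mechanism this commit introduces is: every amortized computation keeps a counter of consecutive failures, a failure falls back to the previously computed result and increments the counter, a success resets it, and only once the counter exceeds the configured tolerance is the exception re-raised. A distilled sketch of that control flow, with toy names rather than the library code:

import logging
from typing import Callable

logger = logging.getLogger(__name__)

def amortized_update(
    compute: Callable[[], float],
    previous_result: float,
    failure_counter: int,
    tolerance: int = 3,
) -> tuple[float, int]:
    """Return (result, updated_failure_counter); re-raise once failures exceed `tolerance`."""
    try:
        return compute(), 0  # success resets the counter
    except Exception as exception:
        failure_counter += 1
        if failure_counter > tolerance:
            logger.error(f"Exceeded tolerance ({tolerance}) for number of failed amortized computations.")
            raise exception
        logger.warning(f"Computation failed with {exception=}. Reusing previous result and continuing...")
        return previous_result, failure_counter

With the default tolerance of 3, the fourth consecutive failure is the first one that propagates, which is what the new test below exercises.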
@@ -1260,18 +1309,24 @@ def _amortized_computation(self) -> None: with profiler.record_function( f"## {self.__class__.__name__}:{self._amortized_computation.__name__} ##" ): - for kronecker_factors in self._masked_kronecker_factors_list: - for ( + for kronecker_factors, fail_counter_list in zip( + self._masked_kronecker_factors_list, + self._masked_failed_amortized_computation_counter_list, + strict=True, + ): + for idx, ( factor_matrix, factor_matrix_eigenvectors, is_factor_matrix_diagonal, factor_matrix_index, - ) in zip( - kronecker_factors.factor_matrices.dequantized_value, - kronecker_factors.factor_matrices_eigenvectors.dequantized_value, - kronecker_factors.is_factor_matrices_diagonal, - kronecker_factors.factor_matrix_indices, - strict=True, + ) in enumerate( + zip( + kronecker_factors.factor_matrices.dequantized_value, + kronecker_factors.factor_matrices_eigenvectors.dequantized_value, + kronecker_factors.is_factor_matrices_diagonal, + kronecker_factors.factor_matrix_indices, + strict=True, + ) ): BaseShampooPreconditionerList._check_factor_matrix_for_diagonality_nan_and_inf( factor_matrix=factor_matrix, @@ -1291,11 +1346,19 @@ def _amortized_computation(self) -> None: eigenvector_computation_config=eigenvector_computation_config, is_diagonal=bool(is_factor_matrix_diagonal), ) + # Reset counter for failed amortized computations. + fail_counter_list[idx] = 0 except Exception as exception: # If self._use_protected_eigh is True, will reuse previous matrix if matrix eigenvector computation fails. if not self._use_protected_eigh: raise exception else: + # Increment counter for failed amortized computations. + fail_counter_list[idx] += 1 + # Only reuse previous matrix if tolerance is not exceeded. + self._raise_exception_if_tolerance_exceeded( + fail_counter_list[idx], exception + ) logger.warning( f"Matrix computation failed for factor matrix {factor_matrix_index} " f"with {exception=}. Using previous factor matrix eigenvectors and continuing..." From 65f4f2731d6679fc3ffb275cede663d536e39da7 Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 9 Dec 2024 16:59:49 +0000 Subject: [PATCH 07/22] Add test for amortized computation failure tolerance --- .../tests/shampoo_preconditioner_list_test.py | 115 +++++++++++++++++- 1 file changed, 113 insertions(+), 2 deletions(-) diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 8bb061f..cd205d8 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -142,7 +142,7 @@ def _test_compress_preconditioner_list( ) as mock_compress_quant_list, ): # Count the number of list compressions at the preconditioner list level, including compressions of QuantizedTensorList. - # Each call to compress() under QuantizedTensorList counts once, though note that it calls compress_list() three times inside. + # Each call to compress() under QuantizedTensorList counts once, though note that it calls compress_list() four times inside. self.assertIsNone( self._preconditioner_list.compress_preconditioner_list( local_grad_selector=(True,) * len(self._block_list) @@ -327,6 +327,9 @@ def test_abstract_methods(self) -> None: # Use outer class as wrapper to avoid running the abstract test. class AbstractTest: class BaseShampooPreconditionerListTest(abc.ABC, AdagradPreconditionerListTest): + # Number of calls to the amortized computation function per update. 
+ NUM_AMORTIZED_COMPUTATION_CALLS = 5 + @abc.abstractmethod def _amortized_computation_function(self) -> str: ... @@ -455,6 +458,114 @@ def test_amortized_computation_internal_failure(self) -> None: ) mock_amortized_computation.assert_called() + def test_amortized_computation_failure_tolerance(self) -> None: + self._preconditioner_list = self._instantiate_preconditioner_list() + masked_grad_list0 = ( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ) + masked_grad_list = ( + torch.tensor([0.0, 1.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[0.0, 1.0]]), + ) + + with mock.patch.object( + shampoo_preconditioner_list, + self._amortized_computation_function(), + side_effect=[ + *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, + *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, + *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, + *(torch.tensor([1.0]),) * self.NUM_AMORTIZED_COMPUTATION_CALLS, + *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, + *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, + *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, + *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, + ], + ) as mock_amortized_computation: + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ): + step = 1 + # Accumulate factor matrices for valid amortized computation. + self._preconditioner_list.update_preconditioners( + masked_grad_list=masked_grad_list0, + step=torch.tensor(step), + perform_amortized_computation=False, + ) + self.assertEqual(mock_amortized_computation.call_count, 0) + step += 1 + + # Case 1: amortized computation fails less often than tolerance -> no error. + self._preconditioner_list.update_preconditioners( + masked_grad_list=masked_grad_list, + step=torch.tensor(step), + perform_amortized_computation=True, + ) + self.assertEqual( + mock_amortized_computation.call_count, + self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1), + ) + step += 1 + + # Case 2: amortized computation fails exactly as often as tolerance (3) -> no error. + for _ in range(2): + self._preconditioner_list.update_preconditioners( + masked_grad_list=masked_grad_list, + step=torch.tensor(step), + perform_amortized_computation=True, + ) + self.assertEqual( + mock_amortized_computation.call_count, + self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1), + ) + step += 1 + + # Case 3: amortized computation succeeds after tolerance hit (test reset) -> no error. + self._preconditioner_list.update_preconditioners( + masked_grad_list=masked_grad_list, + step=torch.tensor(step), + perform_amortized_computation=True, + ) + self.assertEqual( + mock_amortized_computation.call_count, + self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1), + ) + step += 1 + + # Case 4: amortized computation fails more often than tolerance -> error. + for _ in range(3): + self._preconditioner_list.update_preconditioners( + masked_grad_list=masked_grad_list, + step=torch.tensor(step), + perform_amortized_computation=True, + ) + self.assertEqual( + mock_amortized_computation.call_count, + self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1), + ) + step += 1 + # At tolerance now. 
+ with self.assertRaises(ValueError): + with self.assertLogs(level="ERROR") as log: + self._preconditioner_list.update_preconditioners( + masked_grad_list=masked_grad_list, + step=torch.tensor(step), + perform_amortized_computation=True, + ) + self.assertIn( + "Exceeded tolerance (3) for number of failed amortized computations.", + log.output, + ) + # The error will be raised for the first Kronecker factor, so the + # call expected count should only be increased by 1. + self.assertEqual( + mock_amortized_computation.call_count, + self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 2) + 1, + ) + # Note: This is needed for type checking to infer the type of argument into mock.patch.object. shampoo_preconditioner_list_module: ModuleType = shampoo_preconditioner_list @@ -517,7 +628,7 @@ def test_num_bytes(self) -> None: self.assertEqual(self._preconditioner_list.num_bytes(), 204) def test_compress_preconditioner_list(self) -> None: - self._test_compress_preconditioner_list(expected_compress_list_call_count=3) + self._test_compress_preconditioner_list(expected_compress_list_call_count=4) class ShampooPreconditionerListTest(AbstractTest.BaseShampooPreconditionerListTest): From 9a137b560e3ee26b7bd1b2c5c9af95008f79ec52 Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 9 Dec 2024 18:18:42 +0000 Subject: [PATCH 08/22] Adjust abstractmethod test --- distributed_shampoo/utils/shampoo_preconditioner_list.py | 1 - .../utils/tests/shampoo_preconditioner_list_test.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index 5d70dd8..6dae748 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -771,7 +771,6 @@ def _initialize_state_lists( self._local_failed_amortized_computation_counter_list: list[list[int]] = [ [0] * len(kronecker_factors.factor_matrices) for kronecker_factors in self._local_kronecker_factors_list - if kronecker_factors is not None ] # Masked lists are the list of active preconditioners or values after filtering out gradients with None. 
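The failure-tolerance test above leans on a standard `unittest.mock` behaviour that is easy to miss: when `side_effect` is a list, each call consumes the next entry, raising it if it is an exception and returning it otherwise, so the long list at the top of the test is effectively a per-call script for the five factor-matrix computations of each update. A tiny, self-contained reminder of that behaviour (toy function, not the Shampoo code):

from unittest import mock

def compute() -> float:
    raise NotImplementedError  # stand-in; replaced by the mock below

with mock.patch(f"{__name__}.compute", side_effect=[ValueError, 1.0]) as mocked:
    try:
        compute()  # first call raises ValueError
    except ValueError:
        pass
    assert compute() == 1.0  # second call returns the listed value
    assert mocked.call_count == 2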
diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index cd205d8..610e676 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -308,7 +308,7 @@ def test_abstract_methods(self) -> None: composable_block_ids=(0, "block_0"), ), ), - distributor_selector=(True,), + distributor_selector=(False,), precision_config=PrecisionConfig(), preconditioner_config=DefaultShampooConfig, beta2=1.0, From 03245f56ae298f88fb7d5aa0c94107ac4dacbf8e Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 9 Dec 2024 19:26:17 +0000 Subject: [PATCH 09/22] Add check that tolerance value non-negative --- distributed_shampoo/shampoo_types.py | 6 ++ .../tests/shampoo_types_test.py | 82 ++++++++++++++++++- 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 234c598..4074fea 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -93,6 +93,12 @@ class PreconditionerConfig(AbstractDataclass): amortized_computation_config: MatrixFunctionConfig num_tolerated_failed_amortized_computations: int = 3 + def __post_init__(self) -> None: + if self.num_tolerated_failed_amortized_computations < 0: + raise ValueError( + f"Invalid num_tolerated_failed_amortized_computations value: {self.num_tolerated_failed_amortized_computations}. Must be >= 0." + ) + @dataclass(kw_only=True) class ShampooPreconditionerConfig(PreconditionerConfig): diff --git a/distributed_shampoo/tests/shampoo_types_test.py b/distributed_shampoo/tests/shampoo_types_test.py index 2ec773f..dd9a59d 100644 --- a/distributed_shampoo/tests/shampoo_types_test.py +++ b/distributed_shampoo/tests/shampoo_types_test.py @@ -9,12 +9,23 @@ import re import unittest -from typing import Type +from abc import ABC, abstractmethod +from typing import Generic, Type, TypeVar from distributed_shampoo.shampoo_types import ( AdaGradGraftingConfig, AdamGraftingConfig, + EigenvalueCorrectedShampooPreconditionerConfig, + PreconditionerConfig, RMSpropGraftingConfig, + ShampooPreconditionerConfig, +) +from matrix_functions_types import ( + DefaultEigenConfig, + DefaultEighConfig, + EigenvectorConfig, + MatrixFunctionConfig, + RootInvConfig, ) @@ -69,3 +80,72 @@ def _get_grafting_config_type( self, ) -> Type[RMSpropGraftingConfig] | Type[AdamGraftingConfig]: return AdamGraftingConfig + + +PreconditionerConfigType = TypeVar( + "PreconditionerConfigType", bound=Type[PreconditionerConfig] +) +AmortizedComputationConfigType = TypeVar( + "AmortizedComputationConfigType", bound=MatrixFunctionConfig +) + + +class AbstractPreconditionerConfigTest: + class PreconditionerConfigTest( + ABC, + unittest.TestCase, + Generic[PreconditionerConfigType, AmortizedComputationConfigType], + ): + def test_illegal_num_tolerated_failed_amortized_computations(self) -> None: + num_tolerated_failed_amortized_computations = -1 + with ( + self.assertRaisesRegex( + ValueError, + re.escape( + f"Invalid num_tolerated_failed_amortized_computations value: " + f"{num_tolerated_failed_amortized_computations}. Must be >= 0." 
+ ), + ), + ): + self._get_preconditioner_config_type()( + amortized_computation_config=self._get_amortized_computation_config(), + num_tolerated_failed_amortized_computations=num_tolerated_failed_amortized_computations, + ) + + @abstractmethod + def _get_preconditioner_config_type( + self, + ) -> PreconditionerConfigType: ... + + @abstractmethod + def _get_amortized_computation_config( + self, + ) -> AmortizedComputationConfigType: ... + + +class ShampooPreconditionerConfigTest( + AbstractPreconditionerConfigTest.PreconditionerConfigTest[ + Type[ShampooPreconditionerConfig], RootInvConfig + ] +): + def _get_amortized_computation_config(self) -> RootInvConfig: + return DefaultEigenConfig + + def _get_preconditioner_config_type( + self, + ) -> Type[ShampooPreconditionerConfig]: + return ShampooPreconditionerConfig + + +class EigenvalueCorrectedShampooPreconditionerConfigTest( + AbstractPreconditionerConfigTest.PreconditionerConfigTest[ + Type[EigenvalueCorrectedShampooPreconditionerConfig], EigenvectorConfig + ] +): + def _get_amortized_computation_config(self) -> EigenvectorConfig: + return DefaultEighConfig + + def _get_preconditioner_config_type( + self, + ) -> Type[EigenvalueCorrectedShampooPreconditionerConfig]: + return EigenvalueCorrectedShampooPreconditionerConfig From 527c35e63faf32ac5a55b7422583074a227a428c Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 9 Dec 2024 22:46:41 +0000 Subject: [PATCH 10/22] Make failure tracking coarser --- .../utils/shampoo_preconditioner_list.py | 114 ++++++++++-------- 1 file changed, 65 insertions(+), 49 deletions(-) diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index 6dae748..8684fe9 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -768,15 +768,14 @@ def _initialize_state_lists( self._inv_root_override, self._local_order_list, ) - self._local_failed_amortized_computation_counter_list: list[list[int]] = [ - [0] * len(kronecker_factors.factor_matrices) - for kronecker_factors in self._local_kronecker_factors_list - ] + self._local_failed_amortized_computation_counter_list: list[int] = [0] * len( + self._local_kronecker_factors_list + ) # Masked lists are the list of active preconditioners or values after filtering out gradients with None. self._masked_order_list: tuple[int, ...] = self._local_order_list self._masked_root_list: tuple[int, ...] = self._local_root_list - self._masked_failed_amortized_computation_counter_list: list[list[int]] = ( + self._masked_failed_amortized_computation_counter_list: list[int] = ( self._local_failed_amortized_computation_counter_list ) self._masked_kronecker_factors_list: tuple[ @@ -814,7 +813,7 @@ def compress_preconditioner_list( self._masked_root_list: tuple[int, ...] 
= compress_list( # type: ignore[no-redef] self._local_root_list, local_grad_selector ) - self._masked_failed_amortized_computation_counter_list: list[list[int]] = ( # type: ignore[no-redef] + self._masked_failed_amortized_computation_counter_list: list[int] = ( # type: ignore[no-redef] list( compress_list( self._local_failed_amortized_computation_counter_list, @@ -966,25 +965,25 @@ def _amortized_computation(self) -> None: with profiler.record_function( f"## {self.__class__.__name__}:{self._amortized_computation.__name__} ##" ): - for kronecker_factors, root, fail_counter_list in zip( - self._masked_kronecker_factors_list, - self._masked_root_list, - self._masked_failed_amortized_computation_counter_list, - strict=True, + for idx, (kronecker_factors, root) in enumerate( + zip( + self._masked_kronecker_factors_list, + self._masked_root_list, + strict=True, + ) ): - for idx, ( + success_tracker: list[bool] = [] + for ( factor_matrix, inv_factor_matrix, is_factor_matrix_diagonal, factor_matrix_index, - ) in enumerate( - zip( - kronecker_factors.factor_matrices.dequantized_value, - kronecker_factors.inv_factor_matrices.dequantized_value, - kronecker_factors.is_factor_matrices_diagonal, - kronecker_factors.factor_matrix_indices, - strict=True, - ) + ) in zip( + kronecker_factors.factor_matrices.dequantized_value, + kronecker_factors.inv_factor_matrices.dequantized_value, + kronecker_factors.is_factor_matrices_diagonal, + kronecker_factors.factor_matrix_indices, + strict=True, ): # Add epsilon term and incorporate bias correction. bias_corrected_factor_matrix = ( @@ -1017,19 +1016,15 @@ def _amortized_computation(self) -> None: epsilon=self._epsilon, is_diagonal=bool(is_factor_matrix_diagonal), ).to(dtype=inv_factor_matrix.dtype) - # Reset counter for failed amortized computations. - fail_counter_list[idx] = 0 + # Add success to success tracker. + success_tracker.append(True) except Exception as exception: # If self._use_protected_eigh is True, will reuse previous matrix if matrix inverse root computation fails. if not self._use_protected_eigh: raise exception else: - # Increment counter for failed amortized computations. - fail_counter_list[idx] += 1 - # Only reuse previous matrix if tolerance is not exceeded. - self._raise_exception_if_tolerance_exceeded( - fail_counter_list[idx], exception - ) + # Add failure to success tracker. + success_tracker.append(False) logger.warning( f"Matrix computation failed for factor matrix {factor_matrix_index} " f"with {exception=}. Using previous inverted factor matrix and continuing..." @@ -1049,6 +1044,20 @@ def _amortized_computation(self) -> None: ) inv_factor_matrix.copy_(computed_inv_factor_matrix) + if all(success_tracker): + # Reset counter for failed amortized computations. + self._masked_failed_amortized_computation_counter_list[idx] = 0 + else: + # Increment counter for failed amortized computations. + self._masked_failed_amortized_computation_counter_list[idx] += 1 + # Only reuse previous eigenvectors if tolerance is not exceeded. + self._raise_exception_if_tolerance_exceeded( + self._masked_failed_amortized_computation_counter_list[idx], + ValueError( + f"Exceeded tolerance for number of failed root inverse computations for {kronecker_factors.factor_matrix_indices}." 
+ ), + ) + def dequantize_preconditioners(self) -> None: with profiler.record_function( f"## {self.__class__.__name__}:{self.dequantize_preconditioners.__name__} ##" @@ -1308,24 +1317,21 @@ def _amortized_computation(self) -> None: with profiler.record_function( f"## {self.__class__.__name__}:{self._amortized_computation.__name__} ##" ): - for kronecker_factors, fail_counter_list in zip( - self._masked_kronecker_factors_list, - self._masked_failed_amortized_computation_counter_list, - strict=True, + for idx, kronecker_factors in enumerate( + self._masked_kronecker_factors_list ): - for idx, ( + success_tracker: list[bool] = [] + for ( factor_matrix, factor_matrix_eigenvectors, is_factor_matrix_diagonal, factor_matrix_index, - ) in enumerate( - zip( - kronecker_factors.factor_matrices.dequantized_value, - kronecker_factors.factor_matrices_eigenvectors.dequantized_value, - kronecker_factors.is_factor_matrices_diagonal, - kronecker_factors.factor_matrix_indices, - strict=True, - ) + ) in zip( + kronecker_factors.factor_matrices.dequantized_value, + kronecker_factors.factor_matrices_eigenvectors.dequantized_value, + kronecker_factors.is_factor_matrices_diagonal, + kronecker_factors.factor_matrix_indices, + strict=True, ): BaseShampooPreconditionerList._check_factor_matrix_for_diagonality_nan_and_inf( factor_matrix=factor_matrix, @@ -1345,19 +1351,15 @@ def _amortized_computation(self) -> None: eigenvector_computation_config=eigenvector_computation_config, is_diagonal=bool(is_factor_matrix_diagonal), ) - # Reset counter for failed amortized computations. - fail_counter_list[idx] = 0 + # Add success to success tracker. + success_tracker.append(True) except Exception as exception: # If self._use_protected_eigh is True, will reuse previous matrix if matrix eigenvector computation fails. if not self._use_protected_eigh: raise exception else: - # Increment counter for failed amortized computations. - fail_counter_list[idx] += 1 - # Only reuse previous matrix if tolerance is not exceeded. - self._raise_exception_if_tolerance_exceeded( - fail_counter_list[idx], exception - ) + # Add failure to success tracker. + success_tracker.append(False) logger.warning( f"Matrix computation failed for factor matrix {factor_matrix_index} " f"with {exception=}. Using previous factor matrix eigenvectors and continuing..." @@ -1377,6 +1379,20 @@ def _amortized_computation(self) -> None: ) factor_matrix_eigenvectors.copy_(computed_eigenvectors) + if all(success_tracker): + # Reset counter for failed amortized computations. + self._masked_failed_amortized_computation_counter_list[idx] = 0 + else: + # Increment counter for failed amortized computations. + self._masked_failed_amortized_computation_counter_list[idx] += 1 + # Only reuse previous eigenvectors if tolerance is not exceeded. + self._raise_exception_if_tolerance_exceeded( + self._masked_failed_amortized_computation_counter_list[idx], + ValueError( + f"Exceeded tolerance for number of failed eigenvector computations for {kronecker_factors.factor_matrix_indices}." 
+ ), + ) + def dequantize_preconditioners(self) -> None: with profiler.record_function( f"## {self.__class__.__name__}:{self.dequantize_preconditioners.__name__} ##" From 01dde5fc93fb6eafcf7588af05afd759d18a5cf5 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 10 Dec 2024 15:06:42 +0000 Subject: [PATCH 11/22] Reduce code duplication --- .../utils/shampoo_preconditioner_list.py | 67 +++++++++---------- 1 file changed, 32 insertions(+), 35 deletions(-) diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index 8684fe9..6687d56 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -696,26 +696,33 @@ def _check_factor_matrix_for_diagonality_nan_and_inf( ) def _raise_exception_if_tolerance_exceeded( - self, counter: int, exception: Exception + self, success_tracker: list[bool], idx: int, exception: Exception ) -> None: """Raises an exception if the number of failed amortized computations exceeds the tolerance. + Resets the counter at the given index when all amortized computations are successful. + Args: - counter (int): The counter for the number of failed amortized computations. + success_tracker (list[bool]): A list of booleans indicating whether the amortized computation was successful. + idx (int): The index of the preconditioner. exception (Exception): The exception to raise. Raises: exception (Exception): The exception to raise. """ - tolerance = ( - self._preconditioner_config.num_tolerated_failed_amortized_computations - ) - if counter > tolerance: - logger.error( - f"Exceeded tolerance ({tolerance}) for number of failed amortized computations." + if all(success_tracker): + # Reset counter for failed amortized computations. + self._masked_failed_amortized_computation_counter_list[idx] = 0 + else: + # Increment counter for failed amortized computations. + self._masked_failed_amortized_computation_counter_list[idx] += 1 + # Raise the exception if the tolerance at the given index is exceeded. + tolerance = ( + self._preconditioner_config.num_tolerated_failed_amortized_computations ) - raise exception + if self._masked_failed_amortized_computation_counter_list[idx] > tolerance: + raise exception def update_preconditioners( self, @@ -1044,19 +1051,14 @@ def _amortized_computation(self) -> None: ) inv_factor_matrix.copy_(computed_inv_factor_matrix) - if all(success_tracker): - # Reset counter for failed amortized computations. - self._masked_failed_amortized_computation_counter_list[idx] = 0 - else: - # Increment counter for failed amortized computations. - self._masked_failed_amortized_computation_counter_list[idx] += 1 - # Only reuse previous eigenvectors if tolerance is not exceeded. - self._raise_exception_if_tolerance_exceeded( - self._masked_failed_amortized_computation_counter_list[idx], - ValueError( - f"Exceeded tolerance for number of failed root inverse computations for {kronecker_factors.factor_matrix_indices}." - ), - ) + # Only reuse previous inverse roots if tolerance is not exceeded. + self._raise_exception_if_tolerance_exceeded( + success_tracker, + idx, + ValueError( + f"Exceeded tolerance for number of failed inverse root computations for {kronecker_factors.factor_matrix_indices}." 
+ ), + ) def dequantize_preconditioners(self) -> None: with profiler.record_function( @@ -1379,19 +1381,14 @@ def _amortized_computation(self) -> None: ) factor_matrix_eigenvectors.copy_(computed_eigenvectors) - if all(success_tracker): - # Reset counter for failed amortized computations. - self._masked_failed_amortized_computation_counter_list[idx] = 0 - else: - # Increment counter for failed amortized computations. - self._masked_failed_amortized_computation_counter_list[idx] += 1 - # Only reuse previous eigenvectors if tolerance is not exceeded. - self._raise_exception_if_tolerance_exceeded( - self._masked_failed_amortized_computation_counter_list[idx], - ValueError( - f"Exceeded tolerance for number of failed eigenvector computations for {kronecker_factors.factor_matrix_indices}." - ), - ) + # Only reuse previous eigenvectors if tolerance is not exceeded. + self._raise_exception_if_tolerance_exceeded( + success_tracker, + idx, + ValueError( + f"Exceeded tolerance for number of failed eigenvector computations for {kronecker_factors.factor_matrix_indices}." + ), + ) def dequantize_preconditioners(self) -> None: with profiler.record_function( From d3a10aff8de1679ecae85daba7cfd21d2b3fbbe1 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 10 Dec 2024 17:48:18 +0000 Subject: [PATCH 12/22] Set default values --- distributed_shampoo/shampoo_types.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 76ab2b3..42e57ab 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -97,16 +97,14 @@ class ShampooPreconditionerConfig(PreconditionerConfig): """Configuration for Shampoo preconditioner computation. Args: - amortized_computation_config (RootInvConfig): Configuration for the inverse-root computation. + amortized_computation_config (RootInvConfig): Configuration for the inverse-root computation. (Default: DefaultEigenConfig) """ - amortized_computation_config: RootInvConfig + amortized_computation_config: RootInvConfig = DefaultEigenConfig -DefaultShampooConfig = ShampooPreconditionerConfig( - amortized_computation_config=DefaultEigenConfig -) +DefaultShampooConfig = ShampooPreconditionerConfig() @dataclass(kw_only=True) @@ -114,17 +112,15 @@ class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig): """Configuration for eigenvalue-corrected Shampoo/SOAP preconditioner computation. Args: - amortized_computation_config (EigenvectorConfig): Configuration for the eigenvector computation. + amortized_computation_config (EigenvectorConfig): Configuration for the eigenvector computation. 
(Default: DefaultEighConfig) """ - amortized_computation_config: EigenvectorConfig + amortized_computation_config: EigenvectorConfig = DefaultEighConfig DefaultEigenvalueCorrectedShampooConfig = ( - EigenvalueCorrectedShampooPreconditionerConfig( - amortized_computation_config=DefaultEighConfig, - ) + EigenvalueCorrectedShampooPreconditionerConfig() ) DefaultSOAPConfig = EigenvalueCorrectedShampooPreconditionerConfig( amortized_computation_config=QRConfig(), From 273e0a1585bdfa6adb1fb454aed3963b7c0dc518 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 10 Dec 2024 18:07:15 +0000 Subject: [PATCH 13/22] Fix defaults with default_factory --- distributed_shampoo/shampoo_types.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 42e57ab..bf5ccdd 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -8,7 +8,7 @@ """ import enum -from dataclasses import dataclass +from dataclasses import dataclass, field import torch @@ -101,7 +101,9 @@ class ShampooPreconditionerConfig(PreconditionerConfig): """ - amortized_computation_config: RootInvConfig = DefaultEigenConfig + amortized_computation_config: RootInvConfig = field( + default_factory=lambda: DefaultEigenConfig + ) DefaultShampooConfig = ShampooPreconditionerConfig() @@ -116,7 +118,9 @@ class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig): """ - amortized_computation_config: EigenvectorConfig = DefaultEighConfig + amortized_computation_config: EigenvectorConfig = field( + default_factory=lambda: DefaultEighConfig + ) DefaultEigenvalueCorrectedShampooConfig = ( From a92348d048d4b509404d65a3f7824a076d028629 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 10 Dec 2024 18:12:08 +0000 Subject: [PATCH 14/22] Improve naming --- .../utils/shampoo_preconditioner_list.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index 6687d56..8548d5b 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -695,8 +695,11 @@ def _check_factor_matrix_for_diagonality_nan_and_inf( f"{factor_matrix.isinf().any()=}, {factor_matrix.isnan().any()=}." ) - def _raise_exception_if_tolerance_exceeded( - self, success_tracker: list[bool], idx: int, exception: Exception + def _raise_exception_if_failure_tolerance_exceeded( + self, + success_tracker: list[bool], + preconditioner_index: int, + exception: Exception, ) -> None: """Raises an exception if the number of failed amortized computations exceeds the tolerance. @@ -704,7 +707,7 @@ def _raise_exception_if_tolerance_exceeded( Args: success_tracker (list[bool]): A list of booleans indicating whether the amortized computation was successful. - idx (int): The index of the preconditioner. + preconditioner_index (int): The index of the preconditioner. exception (Exception): The exception to raise. Raises: @@ -713,15 +716,22 @@ def _raise_exception_if_tolerance_exceeded( """ if all(success_tracker): # Reset counter for failed amortized computations. - self._masked_failed_amortized_computation_counter_list[idx] = 0 + self._masked_failed_amortized_computation_counter_list[ + preconditioner_index + ] = 0 else: # Increment counter for failed amortized computations. 
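Stepping back to the two default-value commits above: a dataclass field cannot default to another dataclass instance directly, because such instances are unhashable by default and `@dataclass` rejects them as mutable defaults, hence the move to `field(default_factory=...)`; and `__post_init__` rejects a negative tolerance up front. A toy version of both pieces (sketch only, not the real config classes):

from dataclasses import dataclass, field

@dataclass(kw_only=True)
class AmortizedConfigSketch:
    max_iterations: int = 100

@dataclass(kw_only=True)
class PreconditionerConfigSketch:
    # A bare `= AmortizedConfigSketch()` default would raise
    # "mutable default ... is not allowed" at class-definition time.
    amortized_computation_config: AmortizedConfigSketch = field(
        default_factory=AmortizedConfigSketch
    )
    num_tolerated_failed_amortized_computations: int = 3

    def __post_init__(self) -> None:
        if self.num_tolerated_failed_amortized_computations < 0:
            raise ValueError(
                f"Invalid num_tolerated_failed_amortized_computations value: "
                f"{self.num_tolerated_failed_amortized_computations}. Must be >= 0."
            )

defaults = PreconditionerConfigSketch()  # all defaults, no factory call repeated at each call site
strict = PreconditionerConfigSketch(num_tolerated_failed_amortized_computations=0)  # never reuse a stale result
# PreconditionerConfigSketch(num_tolerated_failed_amortized_computations=-1) would raise ValueError.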
- self._masked_failed_amortized_computation_counter_list[idx] += 1 + self._masked_failed_amortized_computation_counter_list[ + preconditioner_index + ] += 1 # Raise the exception if the tolerance at the given index is exceeded. + failure_counter = self._masked_failed_amortized_computation_counter_list[ + preconditioner_index + ] tolerance = ( self._preconditioner_config.num_tolerated_failed_amortized_computations ) - if self._masked_failed_amortized_computation_counter_list[idx] > tolerance: + if failure_counter > tolerance: raise exception def update_preconditioners( @@ -1052,7 +1062,7 @@ def _amortized_computation(self) -> None: inv_factor_matrix.copy_(computed_inv_factor_matrix) # Only reuse previous inverse roots if tolerance is not exceeded. - self._raise_exception_if_tolerance_exceeded( + self._raise_exception_if_failure_tolerance_exceeded( success_tracker, idx, ValueError( @@ -1382,7 +1392,7 @@ def _amortized_computation(self) -> None: factor_matrix_eigenvectors.copy_(computed_eigenvectors) # Only reuse previous eigenvectors if tolerance is not exceeded. - self._raise_exception_if_tolerance_exceeded( + self._raise_exception_if_failure_tolerance_exceeded( success_tracker, idx, ValueError( From 736c76d8f022eb68461770b7bfbeee5da839ac64 Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 12 Dec 2024 15:16:51 +0000 Subject: [PATCH 15/22] Simplify test --- .../tests/shampoo_types_test.py | 28 ++----------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/distributed_shampoo/tests/shampoo_types_test.py b/distributed_shampoo/tests/shampoo_types_test.py index dd9a59d..5b2e588 100644 --- a/distributed_shampoo/tests/shampoo_types_test.py +++ b/distributed_shampoo/tests/shampoo_types_test.py @@ -20,13 +20,6 @@ RMSpropGraftingConfig, ShampooPreconditionerConfig, ) -from matrix_functions_types import ( - DefaultEigenConfig, - DefaultEighConfig, - EigenvectorConfig, - MatrixFunctionConfig, - RootInvConfig, -) class AdaGradGraftingConfigTest(unittest.TestCase): @@ -85,16 +78,13 @@ def _get_grafting_config_type( PreconditionerConfigType = TypeVar( "PreconditionerConfigType", bound=Type[PreconditionerConfig] ) -AmortizedComputationConfigType = TypeVar( - "AmortizedComputationConfigType", bound=MatrixFunctionConfig -) class AbstractPreconditionerConfigTest: class PreconditionerConfigTest( ABC, unittest.TestCase, - Generic[PreconditionerConfigType, AmortizedComputationConfigType], + Generic[PreconditionerConfigType], ): def test_illegal_num_tolerated_failed_amortized_computations(self) -> None: num_tolerated_failed_amortized_computations = -1 @@ -108,7 +98,6 @@ def test_illegal_num_tolerated_failed_amortized_computations(self) -> None: ), ): self._get_preconditioner_config_type()( - amortized_computation_config=self._get_amortized_computation_config(), num_tolerated_failed_amortized_computations=num_tolerated_failed_amortized_computations, ) @@ -117,20 +106,12 @@ def _get_preconditioner_config_type( self, ) -> PreconditionerConfigType: ... - @abstractmethod - def _get_amortized_computation_config( - self, - ) -> AmortizedComputationConfigType: ... 
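The simplified test above keeps a pattern used throughout this test suite that is worth illustrating on its own: the abstract TestCase is nested inside a plain wrapper class so that unittest's collector never tries to run it, and only the concrete subclasses are discovered. A minimal sketch of the pattern with toy names:

import abc
import unittest

class AbstractTestWrapper:
    class ConfigTest(abc.ABC, unittest.TestCase):
        def test_default_construction(self) -> None:
            config_type = self._get_config_type()
            self.assertIsInstance(config_type(), config_type)

        @abc.abstractmethod
        def _get_config_type(self) -> type: ...

class DictConfigTest(AbstractTestWrapper.ConfigTest):
    def _get_config_type(self) -> type:
        return dict

if __name__ == "__main__":
    unittest.main()  # discovers DictConfigTest only; the nested abstract case is never collected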
- class ShampooPreconditionerConfigTest( AbstractPreconditionerConfigTest.PreconditionerConfigTest[ - Type[ShampooPreconditionerConfig], RootInvConfig + Type[ShampooPreconditionerConfig] ] ): - def _get_amortized_computation_config(self) -> RootInvConfig: - return DefaultEigenConfig - def _get_preconditioner_config_type( self, ) -> Type[ShampooPreconditionerConfig]: @@ -139,12 +120,9 @@ def _get_preconditioner_config_type( class EigenvalueCorrectedShampooPreconditionerConfigTest( AbstractPreconditionerConfigTest.PreconditionerConfigTest[ - Type[EigenvalueCorrectedShampooPreconditionerConfig], EigenvectorConfig + Type[EigenvalueCorrectedShampooPreconditionerConfig] ] ): - def _get_amortized_computation_config(self) -> EigenvectorConfig: - return DefaultEighConfig - def _get_preconditioner_config_type( self, ) -> Type[EigenvalueCorrectedShampooPreconditionerConfig]: From cf08da0efb3fa05acd9f64ef453cae509fe11461 Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 19 Dec 2024 11:38:10 +0000 Subject: [PATCH 16/22] Fix test --- .../tests/shampoo_preconditioner_list_test.py | 101 +++++++++++++----- 1 file changed, 77 insertions(+), 24 deletions(-) diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 610e676..27c0535 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -10,6 +10,7 @@ import abc import re import unittest +import warnings from types import ModuleType from typing import Any from unittest import mock @@ -475,14 +476,15 @@ def test_amortized_computation_failure_tolerance(self) -> None: shampoo_preconditioner_list, self._amortized_computation_function(), side_effect=[ - *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, + *(ValueError,) * (self.NUM_AMORTIZED_COMPUTATION_CALLS - 1), + torch.tensor([1.0]), *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, *(torch.tensor([1.0]),) * self.NUM_AMORTIZED_COMPUTATION_CALLS, *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, - *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, + ValueError, ], ) as mock_amortized_computation: with DequantizePreconditionersContext( @@ -499,10 +501,23 @@ def test_amortized_computation_failure_tolerance(self) -> None: step += 1 # Case 1: amortized computation fails less often than tolerance -> no error. - self._preconditioner_list.update_preconditioners( - masked_grad_list=masked_grad_list, - step=torch.tensor(step), - perform_amortized_computation=True, + with self.assertLogs(level="WARNING") as cm: + self._preconditioner_list.update_preconditioners( + masked_grad_list=masked_grad_list, + step=torch.tensor(step), + perform_amortized_computation=True, + ) + # Check that warnings are logged for four failed amortized computations. + # The fifth one doesn't raise an exception, so no warning is logged. + self.assertCountEqual( + # Only extracts the first sentence in the warning message for simple comparison. + [r.msg.split(". 
", maxsplit=1)[0] for r in cm.records], + [ + "Matrix computation failed for factor matrix 0.block_0.0 with exception=ValueError()", + "Matrix computation failed for factor matrix 1.block_0.0 with exception=ValueError()", + "Matrix computation failed for factor matrix 1.block_0.1 with exception=ValueError()", + "Matrix computation failed for factor matrix 1.block_1.0 with exception=ValueError()", + ], ) self.assertEqual( mock_amortized_computation.call_count, @@ -512,10 +527,23 @@ def test_amortized_computation_failure_tolerance(self) -> None: # Case 2: amortized computation fails exactly as often as tolerance (3) -> no error. for _ in range(2): - self._preconditioner_list.update_preconditioners( - masked_grad_list=masked_grad_list, - step=torch.tensor(step), - perform_amortized_computation=True, + with self.assertLogs(level="WARNING") as cm: + self._preconditioner_list.update_preconditioners( + masked_grad_list=masked_grad_list, + step=torch.tensor(step), + perform_amortized_computation=True, + ) + # Check that warnings are logged for all failed amortized computations. + self.assertCountEqual( + # Only extracts the first sentence in the warning message for simple comparison. + [r.msg.split(". ", maxsplit=1)[0] for r in cm.records], + [ + "Matrix computation failed for factor matrix 0.block_0.0 with exception=ValueError()", + "Matrix computation failed for factor matrix 1.block_0.0 with exception=ValueError()", + "Matrix computation failed for factor matrix 1.block_0.1 with exception=ValueError()", + "Matrix computation failed for factor matrix 1.block_1.0 with exception=ValueError()", + "Matrix computation failed for factor matrix 1.block_1.1 with exception=ValueError()", + ], ) self.assertEqual( mock_amortized_computation.call_count, @@ -524,10 +552,17 @@ def test_amortized_computation_failure_tolerance(self) -> None: step += 1 # Case 3: amortized computation succeeds after tolerance hit (test reset) -> no error. - self._preconditioner_list.update_preconditioners( - masked_grad_list=masked_grad_list, - step=torch.tensor(step), - perform_amortized_computation=True, + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + self._preconditioner_list.update_preconditioners( + masked_grad_list=masked_grad_list, + step=torch.tensor(step), + perform_amortized_computation=True, + ) + self.assertEqual( + len(warning_list), + 0, + f"Expected no warnings but got: {warning_list}", ) self.assertEqual( mock_amortized_computation.call_count, @@ -537,10 +572,23 @@ def test_amortized_computation_failure_tolerance(self) -> None: # Case 4: amortized computation fails more often than tolerance -> error. for _ in range(3): - self._preconditioner_list.update_preconditioners( - masked_grad_list=masked_grad_list, - step=torch.tensor(step), - perform_amortized_computation=True, + with self.assertLogs(level="WARNING") as cm: + self._preconditioner_list.update_preconditioners( + masked_grad_list=masked_grad_list, + step=torch.tensor(step), + perform_amortized_computation=True, + ) + # Check that warnings are logged for four failed amortized computations. + self.assertCountEqual( + # Only extracts the first sentence in the warning message for simple comparison. + [r.msg.split(". 
", maxsplit=1)[0] for r in cm.records], + [ + "Matrix computation failed for factor matrix 0.block_0.0 with exception=ValueError()", + "Matrix computation failed for factor matrix 1.block_0.0 with exception=ValueError()", + "Matrix computation failed for factor matrix 1.block_0.1 with exception=ValueError()", + "Matrix computation failed for factor matrix 1.block_1.0 with exception=ValueError()", + "Matrix computation failed for factor matrix 1.block_1.1 with exception=ValueError()", + ], ) self.assertEqual( mock_amortized_computation.call_count, @@ -548,17 +596,22 @@ def test_amortized_computation_failure_tolerance(self) -> None: ) step += 1 # At tolerance now. - with self.assertRaises(ValueError): - with self.assertLogs(level="ERROR") as log: + with self.assertLogs(level="WARNING") as cm: + expected_error_message = "Exceeded tolerance.*('0.block_0.0',)." + with self.assertRaisesRegex(ValueError, expected_error_message): self._preconditioner_list.update_preconditioners( masked_grad_list=masked_grad_list, step=torch.tensor(step), perform_amortized_computation=True, ) - self.assertIn( - "Exceeded tolerance (3) for number of failed amortized computations.", - log.output, - ) + # Check that the warning is logged for the failed amortized computation of the first matrix. + self.assertCountEqual( + # Only extracts the first sentence in the warning message for simple comparison. + [r.msg.split(". ", maxsplit=1)[0] for r in cm.records], + [ + "Matrix computation failed for factor matrix 0.block_0.0 with exception=ValueError()", + ], + ) # The error will be raised for the first Kronecker factor, so the # call expected count should only be increased by 1. self.assertEqual( From 5b37d84d2da287fe5ca8b66ad4819e0c43cbae1d Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 19 Dec 2024 19:49:10 +0000 Subject: [PATCH 17/22] Use keywords explicitly --- .../utils/shampoo_preconditioner_list.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index 8548d5b..f33e624 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -1063,9 +1063,9 @@ def _amortized_computation(self) -> None: # Only reuse previous inverse roots if tolerance is not exceeded. self._raise_exception_if_failure_tolerance_exceeded( - success_tracker, - idx, - ValueError( + success_tracker=success_tracker, + preconditioner_index=idx, + exception=ValueError( f"Exceeded tolerance for number of failed inverse root computations for {kronecker_factors.factor_matrix_indices}." ), ) @@ -1393,9 +1393,9 @@ def _amortized_computation(self) -> None: # Only reuse previous eigenvectors if tolerance is not exceeded. self._raise_exception_if_failure_tolerance_exceeded( - success_tracker, - idx, - ValueError( + success_tracker=success_tracker, + preconditioner_index=idx, + exception=ValueError( f"Exceeded tolerance for number of failed eigenvector computations for {kronecker_factors.factor_matrix_indices}." 
), ) From 77934296d123dae4d599e8dc90097db70f5219fe Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 19 Dec 2024 20:27:14 +0000 Subject: [PATCH 18/22] Revert outdated change --- .../utils/tests/shampoo_preconditioner_list_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index ba44780..b120d8b 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -298,7 +298,7 @@ def test_abstract_methods(self) -> None: composable_block_ids=(0, "block_0"), ), ), - distributor_selector=(False,), + distributor_selector=(True,), preconditioner_config=DefaultShampooConfig, beta2=1.0, ) From 98051d2999d8d1377c8bd68fc2079527583cdc07 Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 19 Dec 2024 20:37:19 +0000 Subject: [PATCH 19/22] Simplify no warnings assertion --- .../utils/tests/shampoo_preconditioner_list_test.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index b120d8b..066748d 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -10,7 +10,6 @@ import abc import re import unittest -import warnings from types import ModuleType from typing import Any from unittest import mock @@ -500,18 +499,12 @@ def test_amortized_computation_failure_tolerance(self) -> None: step += 1 # Case 3: amortized computation succeeds after tolerance hit (test reset) -> no error. - with warnings.catch_warnings(record=True) as warning_list: - warnings.simplefilter("always") + with self.assertNoLogs(level="WARNING") as cm: self._preconditioner_list.update_preconditioners( masked_grad_list=masked_grad_list, step=torch.tensor(step), perform_amortized_computation=True, ) - self.assertEqual( - len(warning_list), - 0, - f"Expected no warnings but got: {warning_list}", - ) self.assertEqual( mock_amortized_computation.call_count, self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1), From 8ea571e14c75cdf612ca95937ce518f05e038959 Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 19 Dec 2024 20:40:48 +0000 Subject: [PATCH 20/22] Remove leftover variable --- .../utils/tests/shampoo_preconditioner_list_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 066748d..53b30f3 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -499,7 +499,7 @@ def test_amortized_computation_failure_tolerance(self) -> None: step += 1 # Case 3: amortized computation succeeds after tolerance hit (test reset) -> no error. 
- with self.assertNoLogs(level="WARNING") as cm: + with self.assertNoLogs(level="WARNING"): self._preconditioner_list.update_preconditioners( masked_grad_list=masked_grad_list, step=torch.tensor(step), From c098c6a063532b02bb26ac41d225f72222633b0b Mon Sep 17 00:00:00 2001 From: runame Date: Fri, 20 Dec 2024 16:56:15 +0000 Subject: [PATCH 21/22] Improve readability of call count check --- .../utils/tests/shampoo_preconditioner_list_test.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 53b30f3..e1f7d06 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -422,6 +422,8 @@ def test_amortized_computation_failure_tolerance(self) -> None: torch.tensor([[0.0, 1.0]]), ) + # Initialize step counter. + step = 1 with mock.patch.object( shampoo_preconditioner_list, self._amortized_computation_function(), @@ -437,7 +439,6 @@ def test_amortized_computation_failure_tolerance(self) -> None: ValueError, ], ) as mock_amortized_computation: - step = 1 # Accumulate factor matrices for valid amortized computation. self._preconditioner_list.update_preconditioners( masked_grad_list=masked_grad_list0, @@ -536,7 +537,9 @@ def test_amortized_computation_failure_tolerance(self) -> None: self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1), ) step += 1 - # At tolerance now. + # Cache current call count. + previous_call_count = mock_amortized_computation.call_count + # Exactly at failure tolerance now. with self.assertLogs(level="WARNING") as cm: expected_error_message = "Exceeded tolerance.*('0.block_0.0',)." with self.assertRaisesRegex(ValueError, expected_error_message): @@ -554,10 +557,10 @@ def test_amortized_computation_failure_tolerance(self) -> None: ], ) # The error will be raised for the first Kronecker factor, so the - # call expected count should only be increased by 1. + # expected call count should only be increased by 1. self.assertEqual( mock_amortized_computation.call_count, - self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 2) + 1, + previous_call_count + 1, ) # Note: This is needed for type checking to infer the type of argument into mock.patch.object. From 16853eaf78bce67180006089741ceb9e7865598d Mon Sep 17 00:00:00 2001 From: runame Date: Fri, 20 Dec 2024 18:18:25 +0000 Subject: [PATCH 22/22] Further improve readability of test --- .../tests/shampoo_preconditioner_list_test.py | 70 ++++++++----------- 1 file changed, 30 insertions(+), 40 deletions(-) diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index d920bfc..a7439a6 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -424,21 +424,35 @@ def test_amortized_computation_failure_tolerance(self) -> None: # Initialize step counter. 
step = 1 - with mock.patch.object( - shampoo_preconditioner_list, - self._amortized_computation_function(), - side_effect=[ - *(ValueError,) * (self.NUM_AMORTIZED_COMPUTATION_CALLS - 1), - torch.tensor([1.0]), - *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, - *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, - *(torch.tensor([1.0]),) * self.NUM_AMORTIZED_COMPUTATION_CALLS, - *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, - *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, - *(ValueError,) * self.NUM_AMORTIZED_COMPUTATION_CALLS, - ValueError, - ], - ) as mock_amortized_computation: + # Define the side effect for each call of the amortized computation function. + fail = ValueError + success = torch.tensor([1.0]) + all_but_one_fail = (fail,) * (self.NUM_AMORTIZED_COMPUTATION_CALLS - 1) + ( + success, + ) + all_fail = (fail,) * self.NUM_AMORTIZED_COMPUTATION_CALLS + all_success = (success,) * self.NUM_AMORTIZED_COMPUTATION_CALLS + with ( + mock.patch.object( + shampoo_preconditioner_list, + self._amortized_computation_function(), + # Note that the cases causally depend on each other. + side_effect=[ + # Case 1: amortized computation fails less often than tolerance. + *all_but_one_fail, # Success for a single Kronecker factor is not enough to reset counter. + # Case 2: amortized computation fails exactly as often as tolerance (3). + *all_fail, + *all_fail, + # Case 3: amortized computation succeeds after tolerance hit (counter is reset). + *all_success, + # Case 4: amortized computation fails more often than tolerance. + *all_fail, + *all_fail, + *all_fail, + fail, # One failure is enough to raise an exception in this case. + ], + ) as mock_amortized_computation + ): # Accumulate factor matrices for valid amortized computation. self._preconditioner_list.update_preconditioners( masked_grad_list=masked_grad_list0, @@ -456,7 +470,7 @@ def test_amortized_computation_failure_tolerance(self) -> None: perform_amortized_computation=True, ) # Check that warnings are logged for four failed amortized computations. - # The fifth one doesn't raise an exception, so no warning is logged. + # The fifth one doesn't raise an exception (see the definition of the side effect), so no warning is logged. self.assertCountEqual( # Only extracts the first sentence in the warning message for simple comparison. [r.msg.split(". ", maxsplit=1)[0] for r in cm.records], @@ -467,10 +481,6 @@ def test_amortized_computation_failure_tolerance(self) -> None: "Matrix computation failed for factor matrix 1.block_1.0 with exception=ValueError()", ], ) - self.assertEqual( - mock_amortized_computation.call_count, - self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1), - ) step += 1 # Case 2: amortized computation fails exactly as often as tolerance (3) -> no error. @@ -493,10 +503,6 @@ def test_amortized_computation_failure_tolerance(self) -> None: "Matrix computation failed for factor matrix 1.block_1.1 with exception=ValueError()", ], ) - self.assertEqual( - mock_amortized_computation.call_count, - self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1), - ) step += 1 # Case 3: amortized computation succeeds after tolerance hit (test reset) -> no error. @@ -506,10 +512,6 @@ def test_amortized_computation_failure_tolerance(self) -> None: step=torch.tensor(step), perform_amortized_computation=True, ) - self.assertEqual( - mock_amortized_computation.call_count, - self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1), - ) step += 1 # Case 4: amortized computation fails more often than tolerance -> error. 
@@ -532,13 +534,7 @@ def test_amortized_computation_failure_tolerance(self) -> None: "Matrix computation failed for factor matrix 1.block_1.1 with exception=ValueError()", ], ) - self.assertEqual( - mock_amortized_computation.call_count, - self.NUM_AMORTIZED_COMPUTATION_CALLS * (step - 1), - ) step += 1 - # Cache current call count. - previous_call_count = mock_amortized_computation.call_count # Exactly at failure tolerance now. with self.assertLogs(level="WARNING") as cm: expected_error_message = "Exceeded tolerance.*('0.block_0.0',)." @@ -556,12 +552,6 @@ def test_amortized_computation_failure_tolerance(self) -> None: "Matrix computation failed for factor matrix 0.block_0.0 with exception=ValueError()", ], ) - # The error will be raised for the first Kronecker factor, so the - # expected call count should only be increased by 1. - self.assertEqual( - mock_amortized_computation.call_count, - previous_call_count + 1, - ) # Note: This is needed for type checking to infer the type of argument into mock.patch.object. shampoo_preconditioner_list_module: ModuleType = shampoo_preconditioner_list
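Taken together, the test refactors above pin down the failure-tolerance contract of the amortized computation: every failed inverse-root or eigenvector computation logs a "Matrix computation failed for factor matrix ..." warning, a fully successful pass resets the per-preconditioner failure counter, and once the counter exceeds the tolerance the next failure raises the `ValueError` built by the caller of `_raise_exception_if_failure_tolerance_exceeded`. The sketch below only illustrates that contract under stated assumptions: `FailureToleranceSketch`, `tolerance`, and `_failure_counts` are hypothetical names, and only the keyword-only call style of `_raise_exception_if_failure_tolerance_exceeded` is taken from the "Use keywords explicitly" patch; the actual bookkeeping lives in `shampoo_preconditioner_list.py`.

```python
class FailureToleranceSketch:
    """Illustrative per-preconditioner failure bookkeeping; not the library's implementation."""

    def __init__(self, num_preconditioners: int, tolerance: int = 3) -> None:
        self._tolerance = tolerance
        # Consecutive failed amortized computations, one counter per preconditioner/block.
        self._failure_counts = [0] * num_preconditioners

    def _raise_exception_if_failure_tolerance_exceeded(
        self,
        *,
        success_tracker: list[bool],
        preconditioner_index: int,
        exception: Exception,
    ) -> None:
        if all(success_tracker):
            # Only a fully successful amortized computation resets the counter (Case 3);
            # success for a single Kronecker factor is not enough (Case 1).
            self._failure_counts[preconditioner_index] = 0
        else:
            self._failure_counts[preconditioner_index] += 1
        # Failing exactly `tolerance` times is still tolerated (Case 2); one more
        # consecutive failure raises the caller-supplied exception (Case 4).
        if self._failure_counts[preconditioner_index] > self._tolerance:
            raise exception


# Three consecutive failures stay within the tolerance of 3 ...
bookkeeping = FailureToleranceSketch(num_preconditioners=1, tolerance=3)
for _ in range(3):
    bookkeeping._raise_exception_if_failure_tolerance_exceeded(
        success_tracker=[False],
        preconditioner_index=0,
        exception=ValueError("Exceeded tolerance for ('0.block_0.0',)."),
    )
# ... and a fourth consecutive failure would raise the ValueError above.
```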
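On the testing side, two idioms recur in these patches and may be unfamiliar: `unittest.mock` side-effect sequences that mix exceptions with return values (used above to script which amortized computations fail at each step), and the `assertLogs`/`assertNoLogs` context managers (the latter, available since Python 3.10, replaces the `warnings.catch_warnings(record=True)` bookkeeping removed in the "Simplify no warnings assertion" patch). The self-contained example below demonstrates both idioms in isolation; `ToolingExample` and its messages are hypothetical and unrelated to the repository's test suite.

```python
import logging
import unittest
from unittest import mock


class ToolingExample(unittest.TestCase):
    def test_side_effect_sequence(self) -> None:
        # A side_effect iterable yields one outcome per call; exception classes or
        # instances in the sequence are raised instead of returned.
        target = mock.MagicMock(side_effect=[ValueError, "ok"])
        with self.assertRaises(ValueError):
            target()
        self.assertEqual(target(), "ok")

    def test_log_assertions(self) -> None:
        # assertLogs captures records so individual messages can be compared, e.g. by
        # keeping only the first sentence as done in the patches above.
        with self.assertLogs(level="WARNING") as cm:
            logging.getLogger(__name__).warning("Computation failed. Details follow.")
        self.assertEqual(
            [r.msg.split(". ", maxsplit=1)[0] for r in cm.records],
            ["Computation failed"],
        )
        # assertNoLogs is the complementary assertion: it fails the test if anything
        # is logged at WARNING or above inside the block.
        with self.assertNoLogs(level="WARNING"):
            pass


if __name__ == "__main__":
    unittest.main()
```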