diff --git a/distributed_shampoo/README.md b/distributed_shampoo/README.md index 84490c7..8b8cd11 100644 --- a/distributed_shampoo/README.md +++ b/distributed_shampoo/README.md @@ -37,6 +37,7 @@ Key distinctives of this implementation include: - Choice of precision for preconditioner accumulation and root inverse computation. - Ability to cache split parameters. - Merging of small dimensions. +- [EXPERIMENTAL] Option to (approximately) correct the eigenvalues/run Adam in the eigenbasis of Shampoo's preconditioner [2,6,7]. ## Requirements @@ -62,6 +63,8 @@ A few notes on hyperparameters: - We allow for decoupled and coupled weight decay. If one sets `use_decoupled_weight_decay=True`, then you are enabling AdamW-style weight decay, while `use_decoupled_weight_decay=False` corresponds to the normal L2-regularization style weight decay. +- When setting `preconditioner_computation_config` as an instance of EigenvalueCorrectionConfig, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. + ### Example 1: [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html) with Momentum If we previously used the optimizer: @@ -479,3 +482,5 @@ When encountering those errors, following are things you could try: 3. [Learning Rate Grafting: Transferability of Optimizer Tuning](https://openreview.net/pdf?id=FpKgG31Z_i9). Naman Agarwal, Rohan Anil, Elad Hazan, Tomer Koren, and Cyril Zhang. Tech Report, 2021. 4. [Functions of Matrices: Theory and Computation](https://epubs.siam.org/doi/book/10.1137/1.9780898717778). Nicholas J. Higham. SIAM, 2008. 5. [A Distributed Data-Parallel PyTorch Implementation of the Distributed Shampoo Optimizer for Training Neural Networks At-Scale](https://arxiv.org/pdf/2309.06497.pdf). Hao-Jun Michael Shi, Tsung-Hsien Lee, Shintaro Iwasaki, Jose Gallego-Posada, Zhijing Li, Kaushik Rangadurai, Dheevatsa Mudigere, and Michael Rabbat. Tech Report, 2023. +6. [Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis](https://arxiv.org/abs/1806.03884). Thomas George, César Laurent, Xavier Bouthillier, Nicolas Ballas, Pascal Vincent. NeurIPS, 2018. +7. [SOAP: Improving and Stabilizing Shampoo using Adam](https://arxiv.org/abs/2409.11321). Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade. Tech Report, 2024. 
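To make the README guidance above concrete, here is a minimal usage sketch of the experimental eigenvalue-correction option. It is not part of the diff: the import paths mirror the new GPU test (`distributed_shampoo.distributed_shampoo` and `matrix_functions_types`), and all hyperparameter values are illustrative placeholders rather than recommendations.

```python
# Hedged sketch: running Adam in the eigenbasis of Shampoo's preconditioner.
# Import paths follow shampoo_eigenvalue_correction_test.py; hyperparameters are placeholders.
import torch

from distributed_shampoo.distributed_shampoo import DistributedShampoo
from matrix_functions_types import DefaultEighEigenvalueCorrectionConfig

model = torch.nn.Linear(10, 2)
optimizer = DistributedShampoo(
    model.parameters(),
    lr=1e-3,                     # start from Adam's tuned learning rate
    betas=(0.9, 0.999),          # start from Adam's tuned betas
    epsilon=1e-12,
    weight_decay=1e-2,
    use_decoupled_weight_decay=True,
    max_preconditioner_dim=1024,
    precondition_frequency=100,  # how often the eigenbasis is recomputed
    grafting_config=None,        # grafting from Adam is typically unnecessary here
    preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig,
)
```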
diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index 6e8f197..b0bf3a1 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -58,9 +58,9 @@     PRECISION_CONFIG,     PrecisionConfig,     PRECONDITION_FREQUENCY, +    PRECONDITIONER_COMPUTATION_CONFIG,     PREVIOUS_GRAD_SELECTOR,     RMSpropGraftingConfig, -    ROOT_INV_CONFIG,     SGDGraftingConfig,     SHAMPOO_PRECONDITIONER_LIST,     ShampooPT2CompileConfig, @@ -90,6 +90,7 @@ from distributed_shampoo.utils.shampoo_preconditioner_list import (     AdagradPreconditionerList,     DequantizePreconditionersContext, +    EigenvalueCorrectedShampooPreconditionerList,     SGDPreconditionerList,     ShampooPreconditionerList, ) @@ -100,7 +101,13 @@ ) from distributed_shampoo.utils.shampoo_utils import compress_list -from matrix_functions_types import DefaultEigenConfig, EigenConfig, RootInvConfig +from matrix_functions_types import ( +    DefaultEigenConfig, +    EigenConfig, +    EigenvalueCorrectionConfig, +    PreconditionerComputationConfig, +    RootInvConfig, +) from torch.optim.optimizer import ParamsT, StateDict  logger: logging.Logger = logging.getLogger(__name__) @@ -214,6 +221,16 @@ class DistributedShampoo(torch.optim.Optimizer):             particular tensor shape. Recommended to use `static` mode here for Shampoo.             More about dynamic shape: https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html +        5. [EXPERIMENTAL] Eigenvalue correction: We can (approximately) correct the eigenvalues of Shampoo's preconditioner by accumulating a running +            average of the squared gradient in the eigenbasis of Shampoo's preconditioner. This running average (with hyperparameter `betas[1]`) is +            updated every iteration while the eigenbasis of Shampoo's preconditioner is only computed every `precondition_frequency` steps. +            Alternatively, this can be seen as running Adam in the eigenbasis of Shampoo's preconditioner. + +            When setting `preconditioner_computation_config` as an instance of EigenvalueCorrectionConfig, there is typically no need to use learning +            rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be +            a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. +            Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. +     Args:         params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.         lr (float): Learning rate. (Default: 1e-2) @@ -228,13 +245,18 @@ class DistributedShampoo(torch.optim.Optimizer):         dampening (float): Dampening parameter for momentum. (Default: 0.)         weight_decay (float): Weight decay (L2 penalty). (Default: 0.)         max_preconditioner_dim (int): Maximum preconditioner dimension. (Default: 1024) -        precondition_frequency (int): Frequency for computing root inverse preconditioner. (Default: 1) +        precondition_frequency (int): Frequency of updating all components of the preconditioner. +            If preconditioner_computation_config is an instance of RootInvConfig, this is the update frequency of the root inverse of the preconditioner. +            If preconditioner_computation_config is an instance of EigenvalueCorrectionConfig, this is the update frequency of the eigenbasis of the preconditioner. +            (Default: 1)         start_preconditioning_step (int): Iteration to start computing inverse preconditioner. If -1, uses             the same value as precondition_frequency. (Default: -1)         inv_root_override (int, Sequence[int]): Inverse root to use in Shampoo.
If a list [l1, l2, ..., lp], then we will use -1 / l1 for 1-D tensor (vectors), -1 / l2 for 2-D tensors (matrices), and so on. If the order of the tensor exceeds the length of the list, reverts to the default value. If 0 is used, uses the default inverse -            root -1 / (2 * o), where o is the order of the tensor. (Default: 0) +            root -1 / (2 * o), where o is the order of the tensor. If preconditioner_computation_config is an instance of +            EigenvalueCorrectionConfig, the default is -1 / 2. +            (Default: 0)         exponent_multiplier (float | None): **DEPRECATING** Number to be multiplied to the numerator of the inverse root, i.e., eta where the             exponent is -eta / (2 * p). (Default: None)         use_nesterov (bool): Flag for using Nesterov momentum. (default: False) @@ -259,7 +281,10 @@ class DistributedShampoo(torch.optim.Optimizer):             3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail.         track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes.             (Default: False) -        root_inv_config (RootInvConfig): Configuration for root inverse computation. (Default: DefaultEigenConfig) +        preconditioner_computation_config (PreconditionerComputationConfig): Configuration for preconditioner computation. +            If this field is an instance of RootInvConfig, Shampoo uses the root inverse of the preconditioner. +            If this field is an instance of EigenvalueCorrectionConfig, Shampoo (approximately) corrects the eigenvalues/runs Adam in the eigenbasis of the preconditioner. +            (Default: DefaultEigenConfig)      """ @@ -290,7 +315,7 @@ def __init__(         precision_config: Optional[PrecisionConfig] = None,         use_protected_eigh: bool = True,         track_root_inv_residuals: bool = False, -        root_inv_config: RootInvConfig = DefaultEigenConfig, +        preconditioner_computation_config: PreconditionerComputationConfig = DefaultEigenConfig,     ) -> None:         # Hyperparameter checks.         if not lr >= 0.0: @@ -404,17 +429,28 @@ def __init__(                 "Both preconditioner_dtype and precision_config are provided. Please use only precision_config as preconditioner_dtype is deprecated."             ) +        if ( +            not isinstance(preconditioner_computation_config, RootInvConfig) +        ) and track_root_inv_residuals: +            raise ValueError( +                f"{track_root_inv_residuals=} has to be set to False when {preconditioner_computation_config=} is not an instance of RootInvConfig." +            ) +         # Create default precision config if it is not provided.         if precision_config is None:             precision_config = PrecisionConfig()          # Set exponent multiplier if this is not provided. -        if isinstance(root_inv_config, EigenConfig) and exponent_multiplier is not None: +        if ( +            isinstance(preconditioner_computation_config, EigenConfig) +            and exponent_multiplier is not None +        ):             logger.warning(                 f"{exponent_multiplier=} is deprecating. Please consider using EigenConfig.exponent_multiplier directly and setting exponent_multipler=None instead in the future."
) - root_inv_config = dataclasses.replace( - root_inv_config, exponent_multiplier=exponent_multiplier + preconditioner_computation_config = dataclasses.replace( + preconditioner_computation_config, + exponent_multiplier=exponent_multiplier, ) super().__init__( @@ -437,7 +473,7 @@ def __init__( GRAFTING_CONFIG: grafting_config, USE_MERGE_DIMS: use_merge_dims, PRECISION_CONFIG: precision_config, - ROOT_INV_CONFIG: root_inv_config, + PRECONDITIONER_COMPUTATION_CONFIG: preconditioner_computation_config, }, ) @@ -508,17 +544,25 @@ def _instantiate_shampoo_preconditioner_list( for state_lists, group in zip( self._per_group_state_lists, self.param_groups, strict=True ): - state_lists[SHAMPOO_PRECONDITIONER_LIST] = ShampooPreconditionerList( + state_lists[SHAMPOO_PRECONDITIONER_LIST] = ( + EigenvalueCorrectedShampooPreconditionerList + if isinstance( + group[PRECONDITIONER_COMPUTATION_CONFIG], EigenvalueCorrectionConfig + ) + else ShampooPreconditionerList + )( block_list=state_lists[DISTRIBUTOR].global_blocked_params, state=self.state, block_info_list=state_lists[DISTRIBUTOR].global_block_info_list, distributor_selector=state_lists[DISTRIBUTOR].distributor_selector, - root_inv_config=group[ROOT_INV_CONFIG], + preconditioner_computation_config=group[ + PRECONDITIONER_COMPUTATION_CONFIG + ], + precision_config=group[PRECISION_CONFIG], beta2=group[BETAS][1], epsilon=group[EPSILON], inv_root_override=group[INV_ROOT_OVERRIDE], use_bias_correction=group[USE_BIAS_CORRECTION], - precision_config=group[PRECISION_CONFIG], use_protected_eigh=use_protected_eigh, ) @@ -755,6 +799,7 @@ def _mask_state_lists(state_lists: Dict[str, Any], group: Dict[str, Any]) -> Non ) @torch.no_grad() + @torch.compiler.disable def _compute_and_log_root_inverse_residuals( self, ) -> None: @@ -806,16 +851,6 @@ def _compute_and_log_root_inverse_residuals( f"{torch.quantile(relative_residuals, quantiles, interpolation='nearest')}" ) - @torch.no_grad() - @torch.compiler.disable - def _compute_root_inverse( - self, state_lists: Dict[str, Any], compute_root_inverse: bool - ) -> None: - if compute_root_inverse: - state_lists[SHAMPOO_PRECONDITIONER_LIST].compute_root_inverse() - if self._track_root_inv_residuals: - self._compute_and_log_root_inverse_residuals() - @torch.no_grad() @torch.compiler.disable def _precondition_and_grafting( @@ -881,13 +916,17 @@ def _update_preconditioners( self, state_lists: Dict[str, Any], step: torch.Tensor, + perform_amortized_computation: bool, grafting_config_not_none: bool, ) -> None: - # Update Shampoo and grafting preconditioners / factor matrices. + # Update Shampoo and grafting preconditioners. state_lists[SHAMPOO_PRECONDITIONER_LIST].update_preconditioners( masked_grad_list=state_lists[MASKED_BLOCKED_GRADS], step=step, + perform_amortized_computation=perform_amortized_computation, ) + if perform_amortized_computation and self._track_root_inv_residuals: + self._compute_and_log_root_inverse_residuals() if grafting_config_not_none: state_lists[GRAFTING_PRECONDITIONER_LIST].update_preconditioners( masked_grad_list=state_lists[MASKED_BLOCKED_GRADS], @@ -1005,7 +1044,7 @@ def _per_group_step_impl( momentum_param: float, dampening: float, grafting_config_not_none: bool, - compute_root_inverse: bool, + perform_amortized_computation: bool, use_decoupled_weight_decay: bool, use_bias_correction: bool, use_grafting_method: bool, @@ -1028,23 +1067,23 @@ def _per_group_step_impl( if grafting_config_not_none else contextlib.nullcontext() ): - # Update Shampoo and grafting preconditioners / factor matrices. 
-                # Example for AdaGrad accumulation: +                # Update Shampoo and grafting preconditioners. +                # Example for AdaGrad accumulation: +                # 1. Update factor matrices/grafting preconditioners.                 #   L <- L + G * G^T                 #   R <- R + G^T * G                 #   V <- V + G^2 (element-wise)                 #   (and similar) -                self._update_preconditioners( -                    state_lists, -                    step, -                    grafting_config_not_none, -                ) - -                # Compute matrix root inverse. +                # 2. Compute root inverse if necessary.                 #   L_inv <- L ** (-1/4)                 #   R_inv <- R ** (-1/4) -                #   (and similar) -                self._compute_root_inverse(state_lists, compute_root_inverse) +                #   (and similar) +                self._update_preconditioners( +                    state_lists=state_lists, +                    step=step, +                    perform_amortized_computation=perform_amortized_computation, +                    grafting_config_not_none=grafting_config_not_none, +                )              # Compute filtered gradient or EMA of the gradients if beta1 > 0 and beta3 > 0.             # Note that we use two beta factors here akin to Lion. @@ -1157,8 +1196,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]             momentum_param = group[MOMENTUM]             dampening = group[DAMPENING]             grafting_config_not_none = group[GRAFTING_CONFIG] is not None -            # Check compute root inverse or not for preconditioner -            compute_root_inverse = ( +            perform_amortized_computation = (                 step.item() % group[PRECONDITION_FREQUENCY] == 0                 and step.item() > group[START_PRECONDITIONING_STEP]                 or step.item() == group[START_PRECONDITIONING_STEP] @@ -1182,7 +1220,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]                 momentum_param,                 dampening,                 grafting_config_not_none, -                compute_root_inverse, +                perform_amortized_computation,                 use_decoupled_weight_decay,                 use_bias_correction,                 use_grafting_method, diff --git a/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py b/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py new file mode 100644 index 0000000..116dbbf --- /dev/null +++ b/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py @@ -0,0 +1,241 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. + +""" + +#!/usr/bin/env python3 + +import math +import unittest +from functools import partial +from itertools import product +from typing import Any, Callable, Type + +import torch +from distributed_shampoo.distributed_shampoo import DistributedShampoo +from distributed_shampoo.tests.shampoo_test_utils import construct_training_problem +from matrix_functions_types import DefaultEighEigenvalueCorrectionConfig +from torch.nn.parameter import Parameter +from torch.optim.adagrad import Adagrad +from torch.optim.adam import Adam +from torch.optim.adamw import AdamW +from torch.optim.optimizer import ParamsT +from torch.optim.rmsprop import RMSprop + + +# Note: We have to set the epsilon to a very small value (i.e., 1e-15) due to the +# place where epsilon is added in the PyTorch optimizers (i.e., AdaGrad, RMSProp, Adam, AdamW) +# and Distributed Shampoo. +# The PyTorch optimizers add epsilon outside of the square root, and Distributed Shampoo +# adds epsilon inside of the square root.
+ + +class DistributedShampooEigenvalueCorrectionTest(unittest.TestCase): + @staticmethod + def _train_quadratic( + optim_factory: Callable[ + [ParamsT], + torch.optim.Optimizer, + ], + device: torch.device, + ) -> tuple[Parameter, torch.Tensor]: + model, loss, data, target = construct_training_problem( + model_linear_layers_dims=(10, 1, 1), + device=device, + fill=1.0, + ) + params = model.parameters() + optimizer = optim_factory(params) + for _ in range(5): + optimizer.zero_grad() + objective = loss(model(data), target) + objective.backward() + optimizer.step() + return model.linear_layers[0].weight.data.cpu(), objective.detach().cpu() + + @staticmethod + def _test_baseline_and_shampoo( + baseline_optim_factory: Callable[ + [ParamsT], + torch.optim.Optimizer, + ], + shampoo_optim_factory: Callable[ + [ParamsT], + torch.optim.Optimizer, + ], + device: torch.device, + ) -> None: + ( + baseline_params, + baseline_loss, + ) = DistributedShampooEigenvalueCorrectionTest._train_quadratic( + baseline_optim_factory, + device=device, + ) + shampoo_params, shampoo_loss = ( + DistributedShampooEigenvalueCorrectionTest._train_quadratic( + shampoo_optim_factory, + device=device, + ) + ) + torch.testing.assert_close(shampoo_loss, baseline_loss) + torch.testing.assert_close( + shampoo_params, + baseline_params, + ) + + @staticmethod + def _optim_factory( + parameters: ParamsT, + optim_cls: Type[torch.optim.Optimizer], + **kwargs: Any, + ) -> torch.optim.Optimizer: + return optim_cls(parameters, **kwargs) + + def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: + # Test with and without weight decay, and with CPU or GPU + for weight_decay, device in product( + (0.0, 0.3), + (torch.device("cpu"),) + (torch.device("cuda"),) + if torch.cuda.is_available() + else (), + ): + optim_factory = partial( + DistributedShampooEigenvalueCorrectionTest._optim_factory, + lr=0.01, + weight_decay=weight_decay, + ) + with self.subTest(weight_decay=weight_decay, device=device): + DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo( + baseline_optim_factory=partial( + optim_factory, optim_cls=Adagrad, eps=1e-15 + ), + shampoo_optim_factory=partial( + optim_factory, + optim_cls=DistributedShampoo, + betas=(0.0, 1.0), + epsilon=1e-15, + momentum=0.0, + max_preconditioner_dim=10, + precondition_frequency=1, + start_preconditioning_step=math.inf, + use_decoupled_weight_decay=False, + grafting_config=None, + preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + ), + device=device, + ) + + def test_adam_eigenvalue_correction_on_quadratic(self) -> None: + # Test with and without weight decay, and with CPU or GPU + for weight_decay, device in product( + (0.0, 0.3), + (torch.device("cpu"),) + (torch.device("cuda"),) + if torch.cuda.is_available() + else (), + ): + optim_factory = partial( + DistributedShampooEigenvalueCorrectionTest._optim_factory, + lr=0.001, + betas=(0.9, 0.999), + weight_decay=weight_decay, + ) + with self.subTest(weight_decay=weight_decay, device=device): + DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo( + baseline_optim_factory=partial( + optim_factory, + optim_cls=Adam, + eps=1e-15, + ), + shampoo_optim_factory=partial( + optim_factory, + optim_cls=DistributedShampoo, + epsilon=1e-15, + momentum=0.0, + max_preconditioner_dim=10, + precondition_frequency=1, + start_preconditioning_step=math.inf, + use_decoupled_weight_decay=False, + grafting_config=None, + preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + ), + 
device=device, + ) + + def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: + # Test with and without weight decay, and with CPU or GPU + for weight_decay, device in product( + (0.0, 0.3), + (torch.device("cpu"),) + (torch.device("cuda"),) + if torch.cuda.is_available() + else (), + ): + optim_factory = partial( + DistributedShampooEigenvalueCorrectionTest._optim_factory, + lr=0.001, + betas=(0.9, 0.999), + weight_decay=weight_decay, + ) + with self.subTest(weight_decay=weight_decay, device=device): + DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo( + baseline_optim_factory=partial( + optim_factory, + optim_cls=AdamW, + eps=1e-15, + ), + shampoo_optim_factory=partial( + optim_factory, + optim_cls=DistributedShampoo, + epsilon=1e-15, + momentum=0.0, + max_preconditioner_dim=10, + precondition_frequency=1, + start_preconditioning_step=math.inf, + use_decoupled_weight_decay=True, + grafting_config=None, + preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + ), + device=device, + ) + + def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None: + # Test with and without weight decay, and with CPU or GPU + for weight_decay, device in product( + (0.0, 0.3), + (torch.device("cpu"),) + (torch.device("cuda"),) + if torch.cuda.is_available() + else (), + ): + optim_factory = partial( + DistributedShampooEigenvalueCorrectionTest._optim_factory, + lr=0.01, + weight_decay=weight_decay, + ) + with self.subTest(weight_decay=weight_decay, device=device): + DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo( + baseline_optim_factory=partial( + optim_factory, + optim_cls=RMSprop, + alpha=0.99, + eps=1e-15, + ), + shampoo_optim_factory=partial( + optim_factory, + optim_cls=DistributedShampoo, + betas=(0.0, 0.99), + epsilon=1e-15, + momentum=0.0, + max_preconditioner_dim=10, + precondition_frequency=1, + start_preconditioning_step=math.inf, + use_decoupled_weight_decay=False, + grafting_config=None, + use_bias_correction=False, + preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + ), + device=device, + ) diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index da890d7..60af2f4 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -34,8 +34,9 @@ PRECISION_CONFIG = "precision_config" PRECONDITION_FREQUENCY = "precondition_frequency" PRECONDITIONER_DTYPE = "preconditioner_dtype" -ROOT_INV_CONFIG = "root_inv_config" +PRECONDITIONER_COMPUTATION_CONFIG = "preconditioner_computation_config" START_PRECONDITIONING_STEP = "start_preconditioning_step" +USE_EIGENVALUE_CORRECTION = "use_eigenvalue_correction" USE_BIAS_CORRECTION = "use_bias_correction" USE_DECOUPLED_WEIGHT_DECAY = "use_decoupled_weight_decay" USE_MERGE_DIMS = "use_merge_dims" @@ -102,6 +103,8 @@ class PrecisionConfig: factor_matrix_dtype (torch.dtype): Data type for storing Shampoo factor matrices. (Default: torch.float32) inv_factor_matrix_dtype (torch.dtype): Data type for storing Shampoo inverse factor matrices. (Default: torch.float32) factor_matrix_computation_dtype (torch.dtype): Data type for accumulating factor matrices and computing their inverses. (Default: torch.float32) + corrected_eigenvalues_dtype (torch.dtype): Data type for storing the corrected eigenvalues of Shampoo preconditioner (EMA). (Default: torch.float32) + factor_matrix_eigenvectors_dtype (torch.dtype): Data type for storing the eigenvectors of Shampoo factor matrices. 
(Default: torch.float32) filtered_grad_dtype (torch.dtype): Data type for storing filtered gradients (EMA). (Default: torch.float32) momentum_dtype (torch.dtype): Data type for storing momentum states. (Default: torch.float32) grafting_state_dtype (torch.dtype): Data type for storing grafting preconditioners, if applicable. (Default: torch.float32) @@ -117,6 +120,8 @@ class PrecisionConfig: computation_dtype: torch.dtype = torch.float32 factor_matrix_dtype: torch.dtype = torch.float32 inv_factor_matrix_dtype: torch.dtype = torch.float32 + corrected_eigenvalues_dtype: torch.dtype = torch.float32 + factor_matrix_eigenvectors_dtype: torch.dtype = torch.float32 factor_matrix_computation_dtype: torch.dtype = torch.float32 filtered_grad_dtype: torch.dtype = torch.float32 momentum_dtype: torch.dtype = torch.float32 diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index 5c29dd8..942de40 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -36,7 +36,10 @@ ShampooPreconditionerList, ) from distributed_shampoo.utils.shampoo_quantization import QuantizedTensorList -from matrix_functions_types import DefaultEigenConfig +from matrix_functions_types import ( + DefaultEigenConfig, + DefaultEighEigenvalueCorrectionConfig, +) from torch import nn @@ -238,7 +241,7 @@ def test_setting_exponent_multiplier_with_eigen_config(self) -> None: lr=0.01, start_preconditioning_step=1, exponent_multiplier=2.0, - root_inv_config=DefaultEigenConfig, + preconditioner_computation_config=DefaultEigenConfig, ) self.assertCountEqual( [r.msg for r in cm.records], @@ -247,6 +250,21 @@ def test_setting_exponent_multiplier_with_eigen_config(self) -> None: ], ) + def test_conflict_eigenvalue_correction_and_track_root_inv_residuals(self) -> None: + with self.assertRaisesRegex( + ValueError, + re.escape( + "track_root_inv_residuals=True has to be set to False when preconditioner_computation_config=EighEigenvalueCorrectionConfig(retry_double_precision=True) is not an instance of RootInvConfig." + ), + ): + DistributedShampoo( + self._model.parameters(), + lr=0.01, + start_preconditioning_step=1, + track_root_inv_residuals=True, + preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + ) + class DistributedShampooTest(unittest.TestCase): def setUp(self) -> None: @@ -460,7 +478,7 @@ def setUp(self) -> None: ), "use_merge_dims": True, "precision_config": PrecisionConfig(), - "root_inv_config": DefaultEigenConfig, + "preconditioner_computation_config": DefaultEigenConfig, } }, } @@ -835,6 +853,108 @@ def test_setting_preconditioner_dtype_only(self) -> None: ) +class EigenvalueCorrectedDistributedShampooPrecisionTest( + DistributedShampooPrecisionTest +): + def _instantiate_optimizer( + self, precision_config: PrecisionConfig + ) -> DistributedShampoo: + return DistributedShampoo( + self._model.parameters(), + lr=0.01, + betas=(0.9, 1.0), + epsilon=1e-12, + momentum=0.99, + weight_decay=0.0, + max_preconditioner_dim=5, + precondition_frequency=1, + start_preconditioning_step=1, + distributed_config=None, + grafting_config=None, + precision_config=precision_config, + preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + ) + + def _assert_state_list_dtype( + self, state_list: Dict[str, Any], precision_config: PrecisionConfig + ) -> None: + # TODO: is it possible to avoid accessing private field _masked_kronecker_factors_list? 
+ for kronecker_factor in state_list[ + SHAMPOO_PRECONDITIONER_LIST + ]._masked_kronecker_factors_list: + self._assert_equal_state_dtype( + kronecker_factor.factor_matrices, + precision_config.factor_matrix_computation_dtype, + precision_config.factor_matrix_dtype, + ) + self._assert_equal_state_dtype( + kronecker_factor.factor_matrices_eigenvectors, + precision_config.computation_dtype, + precision_config.factor_matrix_eigenvectors_dtype, + ) + self._assert_equal_state_dtype( + kronecker_factor.corrected_eigenvalues, + precision_config.computation_dtype, + precision_config.corrected_eigenvalues_dtype, + ) + self._assert_equal_state_dtype( + state_list[MASKED_FILTERED_GRAD_LIST], + precision_config.computation_dtype, + precision_config.filtered_grad_dtype, + ) + self._assert_equal_state_dtype( + state_list[MASKED_MOMENTUM_LIST], + precision_config.computation_dtype, + precision_config.momentum_dtype, + ) + + def test_precision_configs(self) -> None: + precision_configs = [ + PrecisionConfig(computation_dtype=torch.float16), + PrecisionConfig(factor_matrix_dtype=torch.float16), + PrecisionConfig(factor_matrix_eigenvectors_dtype=torch.float16), + PrecisionConfig(corrected_eigenvalues_dtype=torch.float16), + PrecisionConfig(filtered_grad_dtype=torch.float16), + PrecisionConfig(momentum_dtype=torch.float16), + PrecisionConfig( + factor_matrix_dtype=torch.float16, + factor_matrix_eigenvectors_dtype=torch.float16, + ), + PrecisionConfig( + factor_matrix_dtype=torch.float16, + factor_matrix_eigenvectors_dtype=torch.float16, + corrected_eigenvalues_dtype=torch.float16, + ), + PrecisionConfig( + factor_matrix_dtype=torch.float16, + factor_matrix_eigenvectors_dtype=torch.float16, + corrected_eigenvalues_dtype=torch.float16, + filtered_grad_dtype=torch.float16, + momentum_dtype=torch.float16, + ), + PrecisionConfig(factor_matrix_computation_dtype=torch.float64), + PrecisionConfig( + factor_matrix_dtype=torch.float64, + factor_matrix_eigenvectors_dtype=torch.float16, + corrected_eigenvalues_dtype=torch.float64, + factor_matrix_computation_dtype=torch.float64, + ), + ] + + for precision_config in precision_configs: + with self.subTest(precision_config=precision_config): + optimizer = self._instantiate_optimizer( + precision_config=precision_config + ) + for state_list in optimizer._per_group_state_lists: + self._assert_state_list_dtype(state_list, precision_config) + + for _ in range(2): + optimizer.step() + for state_list in optimizer._per_group_state_lists: + self._assert_state_list_dtype(state_list, precision_config) + + class DistributedShampooNoneGradTest(unittest.TestCase): def setUp(self) -> None: self._model = nn.Sequential( diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index baef780..7a8fed3 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -9,12 +9,14 @@ import logging from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass, field from fractions import Fraction +from functools import partial, reduce from itertools import chain from operator import methodcaller -from typing import Any, DefaultDict, Sequence, Tuple, Union +from typing import Any, cast, DefaultDict, Generic, Sequence, TypeVar import torch from distributed_shampoo.shampoo_types import PrecisionConfig, PreconditionerValueError @@ -32,10 +34,16 @@ from matrix_functions import ( check_diagonal, 
compute_matrix_root_inverse_residuals, + matrix_eigenvectors, matrix_inverse_root, ) -from matrix_functions_types import DefaultEigenConfig, RootInvConfig +from matrix_functions_types import ( + DefaultEigenConfig, + EigenvalueCorrectionConfig, + PreconditionerComputationConfig, + RootInvConfig, +) from optimizer_modules import OptimizerModule from torch import Tensor from torch.autograd import profiler @@ -51,36 +59,37 @@ class PreconditionerList(ABC): """Preconditioner base class. Args: - block_list (Tuple[Tensor, ...]): List of (blocks of) parameters. + block_list (tuple[Tensor, ...]): List of (blocks of) parameters. """ def __init__( self, - block_list: Tuple[Tensor, ...], + block_list: tuple[Tensor, ...], ) -> None: super().__init__() - self._numel_list: Tuple[int, ...] = (0,) * len(block_list) - self._dims_list: Tuple[torch.Size, ...] = tuple( + self._numel_list: tuple[int, ...] = (0,) * len(block_list) + self._dims_list: tuple[torch.Size, ...] = tuple( block.size() for block in block_list ) - self._num_bytes_list: Tuple[int, ...] = (0,) * len(block_list) + self._num_bytes_list: tuple[int, ...] = (0,) * len(block_list) @abstractmethod def update_preconditioners( self, - masked_grad_list: Tuple[Tensor, ...], + masked_grad_list: tuple[Tensor, ...], step: Tensor, + perform_amortized_computation: bool, ) -> None: ... @abstractmethod def precondition( - self, masked_grad_list: Tuple[Tensor, ...] - ) -> Tuple[Tensor, ...]: ... + self, masked_grad_list: tuple[Tensor, ...] + ) -> tuple[Tensor, ...]: ... @abstractmethod def compress_preconditioner_list( - self, local_grad_selector: Tuple[bool, ...] + self, local_grad_selector: tuple[bool, ...] ) -> None: ... @abstractmethod @@ -90,15 +99,15 @@ def dequantize_preconditioners(self) -> None: ... def quantize_preconditioners(self) -> None: ... @property - def numel_list(self) -> Tuple[int, ...]: + def numel_list(self) -> tuple[int, ...]: return self._numel_list @property - def dims_list(self) -> Tuple[torch.Size, ...]: + def dims_list(self) -> tuple[torch.Size, ...]: return self._dims_list @property - def num_bytes_list(self) -> Tuple[int, ...]: + def num_bytes_list(self) -> tuple[int, ...]: return self._num_bytes_list def numel(self) -> int: @@ -112,28 +121,29 @@ class SGDPreconditionerList(PreconditionerList): """SGD (identity) preconditioners for a list of parameters. Args: - block_list (Tuple[Tensor, ...]): List of (blocks of) parameters. + block_list (tuple[Tensor, ...]): List of (blocks of) parameters. """ def __init__( self, - block_list: Tuple[Tensor, ...], + block_list: tuple[Tensor, ...], ) -> None: super().__init__(block_list) def update_preconditioners( self, - masked_grad_list: Tuple[Tensor, ...], + masked_grad_list: tuple[Tensor, ...], step: Tensor, + perform_amortized_computation: bool = False, ) -> None: return - def precondition(self, masked_grad_list: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: + def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, ...]: return masked_grad_list def compress_preconditioner_list( - self, local_grad_selector: Tuple[bool, ...] + self, local_grad_selector: tuple[bool, ...] ) -> None: return @@ -158,11 +168,11 @@ class AdagradPreconditionerList(PreconditionerList): Other variants can also be specified. Args: - block_list (Tuple[Tensor, ...]): List of (blocks of) parameters. + block_list (tuple[Tensor, ...]): List of (blocks of) parameters. state (DefaultDict[Tensor, Any]): Dictionary containing optimizer state. 
- block_info_list (Tuple[BlockInfo, ...]): List containing corresponding BlockInfo for each block/parameter in block_list. + block_info_list (tuple[BlockInfo, ...]): List containing corresponding BlockInfo for each block/parameter in block_list. Note that this should have the same length as block_list. - distributor_selector (Tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter + distributor_selector (tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter is selected by the current Distributor. beta2 (float): Exponential moving average factor for Adam/RMSprop second moment state. If beta2 = 1., will use unweighted sum. (Default: 1.0) @@ -174,11 +184,11 @@ class AdagradPreconditionerList(PreconditionerList): def __init__( self, - block_list: Tuple[Tensor, ...], + block_list: tuple[Tensor, ...], # type: ignore state: DefaultDict[Tensor, Any], - block_info_list: Tuple[BlockInfo, ...], - distributor_selector: Tuple[bool, ...], + block_info_list: tuple[BlockInfo, ...], + distributor_selector: tuple[bool, ...], precision_config: PrecisionConfig, beta2: float = 1.0, epsilon: float = 1e-10, @@ -208,7 +218,9 @@ def __init__( preconditioner_index = str(param_index) + "." + str(block_index) block_state[ADAGRAD] = QuantizedTensor( quantized_values=block_info.allocate_zeros_tensor( - block.size(), precision_config.grafting_state_dtype, block.device + shape=block.size(), + dtype=precision_config.grafting_state_dtype, + device=block.device, ), block_info=block_info, ) @@ -230,22 +242,23 @@ def __init__( ) # Construct lists of dims, bytes, and numels for logging purposes. - self._dims_list: Tuple[torch.Size, ...] = compress_list( + self._dims_list: tuple[torch.Size, ...] = compress_list( self._dims_list, distributor_selector ) - self._numel_list: Tuple[int, ...] = tuple( + self._numel_list: tuple[int, ...] = tuple( quantized_preconditioner.numel() for quantized_preconditioner in self._local_preconditioner_list.quantized_value ) - self._num_bytes_list: Tuple[int, ...] = tuple( + self._num_bytes_list: tuple[int, ...] = tuple( quantize_preconditioner.numel() * quantize_preconditioner.element_size() for quantize_preconditioner in self._local_preconditioner_list.quantized_value ) def update_preconditioners( self, - masked_grad_list: Tuple[Tensor, ...], + masked_grad_list: tuple[Tensor, ...], step: Tensor, + perform_amortized_computation: bool = False, ) -> None: with profiler.record_function( f"## {self.__class__.__name__}:{self.update_preconditioners.__name__} ##" @@ -272,7 +285,16 @@ def update_preconditioners( if self._use_bias_correction and self._beta2 < 1.0: self._bias_correction2 = torch.tensor(1.0) - self._beta2**step - def precondition(self, masked_grad_list: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: + def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, ...]: + """ + Preconditions the gradient list using the AdaGrad preconditioner. + + Args: + masked_grad_list (tuple[Tensor, ...]): A tuple of gradients with None values removed. + + Returns: + tuple[Tensor, ...]: A tuple of preconditioned gradients. + """ with profiler.record_function( f"## {self.__class__.__name__}:{self.precondition.__name__} ##" ): @@ -301,7 +323,7 @@ def quantize_preconditioners(self) -> None: self._masked_preconditioner_list.quantize_() def compress_preconditioner_list( - self, local_grad_selector: Tuple[bool, ...] + self, local_grad_selector: tuple[bool, ...] 
) -> None: with profiler.record_function( f"## {self.__class__.__name__}:{self.compress_preconditioner_list.__name__} ##" @@ -311,70 +333,110 @@ def compress_preconditioner_list( ) +FactorMatricesType = TypeVar( + "FactorMatricesType", tuple[QuantizedTensor, ...], QuantizedTensorList +) + + @dataclass -class ShampooKroneckerFactorsState(OptimizerModule): - """Shampoo Kronecker Factors (wrapped) for storing in the optimizer state.""" +class BaseShampooKroneckerFactors(Generic[FactorMatricesType], OptimizerModule): + """Base class for Shampoo Kronecker factors.""" - factor_matrices: Tuple[QuantizedTensor, ...] - inv_factor_matrices: Tuple[QuantizedTensor, ...] - factor_matrix_indices: Tuple[str, ...] - is_factor_matrices_diagonal: Tuple[Tensor, ...] = field(init=False) + factor_matrices: FactorMatricesType + factor_matrix_indices: tuple[str, ...] + is_factor_matrices_diagonal: tuple[Tensor, ...] = field(init=False) def __post_init__(self) -> None: super().__init__() - assert ( - len(self.factor_matrices) - == len(self.inv_factor_matrices) - == len(self.factor_matrix_indices) - ) + assert len(self.factor_matrices) == len(self.factor_matrix_indices) self.is_factor_matrices_diagonal = tuple( torch.tensor(True) for _ in range(len(self.factor_matrices)) ) @dataclass -class ShampooKroneckerFactorsList(OptimizerModule): - """Shampoo Kronecker Factors (unwrapped) for operations during optimizer computation.""" +class ShampooKroneckerFactorsState( + BaseShampooKroneckerFactors[tuple[QuantizedTensor, ...]] +): + """Shampoo Kronecker factors (wrapped) for storing in the optimizer state.""" + + inv_factor_matrices: tuple[QuantizedTensor, ...] + + def __post_init__(self) -> None: + super().__post_init__() + assert len(self.factor_matrices) == len(self.inv_factor_matrices) + + +@dataclass +class ShampooKroneckerFactorsList(BaseShampooKroneckerFactors[QuantizedTensorList]): + """Shampoo Kronecker factors (unwrapped) for operations during optimizer computation.""" - factor_matrices: QuantizedTensorList inv_factor_matrices: QuantizedTensorList - factor_matrix_indices: Tuple[str, ...] - is_factor_matrices_diagonal: Tuple[Tensor, ...] = field(init=False) def __post_init__(self) -> None: - super().__init__() - assert ( - len(self.factor_matrices) - == len(self.inv_factor_matrices) - == len(self.factor_matrix_indices) - ) - self.is_factor_matrices_diagonal = tuple( - torch.tensor(True) for _ in range(len(self.factor_matrices)) - ) + super().__post_init__() + assert len(self.factor_matrices) == len(self.inv_factor_matrices) -class ShampooPreconditionerList(PreconditionerList): - """Shampoo preconditioners for list of parameters. +@dataclass +class EigenvalueCorrectedShampooKroneckerFactorsState( + BaseShampooKroneckerFactors[tuple[QuantizedTensor, ...]] +): + """Eigenvalue-corrected Shampoo Kronecker factors (wrapped) for storing in the optimizer state.""" + + factor_matrices_eigenvectors: tuple[QuantizedTensor, ...] 
+ corrected_eigenvalues: QuantizedTensor + + def __post_init__(self) -> None: + super().__post_init__() + assert len(self.factor_matrices) == len(self.factor_matrices_eigenvectors) + + +@dataclass +class EigenvalueCorrectedShampooKroneckerFactorsList( + BaseShampooKroneckerFactors[QuantizedTensorList] +): + """Eigenvalue-corrected Shampoo Kronecker factors (unwrapped) for operations during optimizer computation.""" + + factor_matrices_eigenvectors: QuantizedTensorList + corrected_eigenvalues: QuantizedTensorList + + def __post_init__(self) -> None: + super().__post_init__() + assert len(self.factor_matrices) == len(self.factor_matrices_eigenvectors) + + +ShampooKroneckerFactorsListType = TypeVar( + "ShampooKroneckerFactorsListType", + ShampooKroneckerFactorsList, + EigenvalueCorrectedShampooKroneckerFactorsList, +) + + +class BaseShampooPreconditionerList( + PreconditionerList, Generic[ShampooKroneckerFactorsListType] +): + """Base class for Shampoo preconditioners. NOTE: Does not support sparse gradients at this time. Args: - block_list (Tuple[Tensor, ...]): List of (blocks of) parameters. + block_list (tuple[Tensor, ...]): List of (blocks of) parameters. state (DefaultDict[Tensor, Any]): Dictionary containing optimizer state. - block_info_list (Tuple[BlockInfo, ...]): List containing corresponding BlockInfo for each block/parameter in block_list. + block_info_list (tuple[BlockInfo, ...]): List containing corresponding BlockInfo for each block/parameter in block_list. Note that this should have the same length as block_list. - distributor_selector (Tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter + distributor_selector (tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter is selected by the current Distributor. - root_inv_config (RootInvConfig): Configuration for root inverse computation. (Default: DefaultEigenConfig) + precision_config (PrecisionConfig): Data types for optimizer states. (Default: all fields torch.float) + preconditioner_computation_config (PreconditionerComputationConfig): Configuration for preconditioner computation. (Default: DefaultEigenConfig) beta2 (float): Exponential moving average factor for Shampoo factor matrices. If beta2 = 1., will use unweighted sum. (Default: 1.0) epsilon (float): Epsilon term for regularizing preconditioner to ensure positive definiteness. (Default: 1e-12) - inv_root_override (Union[int, Tuple[int, ...]]): Inverse root to use in Shampoo. If a list [l0, l1, l2, ..., lp], then we will + inv_root_override (int | tuple[int, ...]): Inverse root to use in Shampoo. If a list [l0, l1, l2, ..., lp], then we will use -1 / l0 for 0-D tensors (scalars), -1 / l1 for 1-D tensor (vectors), -1 / l2 for 2-D tensors (matrices), and so on. If the order of the tensor exceeds the length of the list, we revert to using the default value. If 0 is used, uses the default inverse root -1 / (2 * o), where o is the order of the tensor. (Default: 0) use_bias_correction (bool): Flag for using bias correction. (Default: True) - precision_config (PrecisionConfig): Data types for optimizer states. (Default: all fields torch.float) use_protected_eigh (bool): Flag for using two guards to prevent failures of torch.linalg.eigh. (Default: True) 1. Attempts to compute root inverse in preconditioner_dtype precision. 2. Attempts to recompute the eigendecomposition if using lower-precision fails. 
@@ -384,16 +446,16 @@ class ShampooPreconditionerList(PreconditionerList): def __init__( self, - block_list: Tuple[Tensor, ...], + block_list: tuple[Tensor, ...], # type: ignore state: DefaultDict[Tensor, Any], - block_info_list: Tuple[BlockInfo, ...], - distributor_selector: Tuple[bool, ...], + block_info_list: tuple[BlockInfo, ...], + distributor_selector: tuple[bool, ...], precision_config: PrecisionConfig, - root_inv_config: RootInvConfig = DefaultEigenConfig, + preconditioner_computation_config: PreconditionerComputationConfig = DefaultEigenConfig, beta2: float = 1.0, epsilon: float = 1e-12, - inv_root_override: Union[int, Tuple[int, ...]] = 0, + inv_root_override: int | tuple[int, ...] = 0, use_bias_correction: bool = True, use_protected_eigh: bool = True, ) -> None: @@ -401,7 +463,7 @@ def __init__( # Initialize parameters. self._precision_config = precision_config - self._root_inv_config = root_inv_config + self._preconditioner_computation_config = preconditioner_computation_config self._beta2 = beta2 self._epsilon = epsilon self._inv_root_override = inv_root_override @@ -409,6 +471,103 @@ def __init__( self._use_protected_eigh = use_protected_eigh self._bias_correction2: Tensor = torch.tensor(1.0) + # Create the Kronecker factors. + kronecker_factors_list: list[ShampooKroneckerFactorsListType] = ( + self._create_kronecker_factors_state( + block_list=block_list, + state=state, + block_info_list=block_info_list, + ) + ) + + # Initialize state lists. + self._initialize_state_lists( + block_list=block_list, + kronecker_factors_list=kronecker_factors_list, + distributor_selector=distributor_selector, + ) + + def _create_base_kronecker_factors( + self, + block_info: BlockInfo, + dims: torch.Size, + ) -> BaseShampooKroneckerFactors[tuple[QuantizedTensor, ...]]: + """ + Creates a BaseShampooKroneckerFactor object for a given block. + + Args: + block_info (BlockInfo): The BlockInfo object containing information about the block. + dims (torch.Size): The dimensions of the block. + + Returns: + kronecker_factors_state (BaseShampooKroneckerFactors[tuple[QuantizedTensor, ...]]): An object containing the Kronecker factors for the block. + """ + factor_matrices = tuple( + QuantizedTensor( + quantized_values=block_info.allocate_zeros_tensor( + shape=(dim, dim), + dtype=self._precision_config.factor_matrix_dtype, + device=block_info.param.device, + ), + block_info=block_info, + ) + for dim in dims + ) + + param_index, block_index = block_info.composable_block_ids + factor_matrix_indices = tuple( + ".".join((str(param_index), str(block_index), str(k))) + for k in range(len(dims)) + ) + return BaseShampooKroneckerFactors( + factor_matrices=factor_matrices, + factor_matrix_indices=factor_matrix_indices, + ) + + @abstractmethod + def _create_kronecker_factors_state_for_block( + self, + block: Tensor, + block_info: BlockInfo, + dims: torch.Size, + ) -> ShampooKroneckerFactorsState | EigenvalueCorrectedShampooKroneckerFactorsState: + """ + Creates a Kronecker factors state object for a given block. + + Args: + block (Tensor): The block of the parameter. + block_info (BlockInfo): The BlockInfo object containing information about the block. + dims (torch.Size): The dimensions of the block. + + Returns: + kronecker_factors_state (ShampooKroneckerFactorsState | EigenvalueCorrectedShampooKroneckerFactorsState): An object containing the Kronecker factors for the block. + """ + ... 
+ + @abstractmethod + def _create_kronecker_factors_list( + self, + kronecker_factors_state: ShampooKroneckerFactorsState + | EigenvalueCorrectedShampooKroneckerFactorsState, + ) -> ShampooKroneckerFactorsListType: + """ + Creates a ShampooKroneckerFactorsList object from the given ShampooKroneckerFactorsState. + + Args: + kronecker_factors_state (ShampooKroneckerFactorsState): The state containing the Kronecker factors. + + Returns: + kronecker_factors_list: A list of ShampooKroneckerFactors objects. + """ + ... + + def _create_kronecker_factors_state( + self, + block_list: tuple[Tensor, ...], + # type: ignore + state: DefaultDict[Tensor, Any], + block_info_list: tuple[BlockInfo, ...], + ) -> list[ShampooKroneckerFactorsListType]: # Instantiate (blocked) Kronecker factors and construct list of Kronecker factors. # NOTE: We need to instantiate the Kronecker factor states within the optimizer's state dictionary, # and do not explicitly store them as ShampooPreconditionerList attributes here. @@ -423,104 +582,46 @@ def __init__( state[block_info.param][block_index] = {} block_state = state[block_info.param][block_index] - # Instantiate ShampooKroneckerFactors for this block. - factor_matrices = tuple( - QuantizedTensor( - quantized_values=block_info.allocate_zeros_tensor( - (dim, dim), - self._precision_config.factor_matrix_dtype, - block_info.param.device, - ), - block_info=block_info, - ) - for dim in dims - ) - inv_factor_matrices = tuple( - QuantizedTensor( - quantized_values=block_info.allocate_zeros_tensor( - (dim, dim), - self._precision_config.inv_factor_matrix_dtype, - block_info.param.device, - ), - block_info=block_info, - ) - for dim in dims + block_state[SHAMPOO] = self._create_kronecker_factors_state_for_block( + block=block, block_info=block_info, dims=dims ) - preconditioner_index = str(param_index) + "." + str(block_index) - factor_matrix_indices = tuple( - preconditioner_index + "." + str(k) for k in range(len(dims)) - ) - block_state[SHAMPOO] = ShampooKroneckerFactorsState( - factor_matrices=factor_matrices, - inv_factor_matrices=inv_factor_matrices, - factor_matrix_indices=factor_matrix_indices, - ) kronecker_factors_list.append( - ShampooKroneckerFactorsList( - # Factor matrices computation (accumulation, root inverse) should use the determined dtype. - factor_matrices=QuantizedTensorList( - quantized_data=factor_matrices, - quantized_dtype=self._precision_config.factor_matrix_dtype, - computation_dtype=self._precision_config.factor_matrix_computation_dtype, - ), - # Inverse factor matrices computation (preconditioning) should use the dtype of the block / gradient. - inv_factor_matrices=QuantizedTensorList( - quantized_data=inv_factor_matrices, - quantized_dtype=self._precision_config.inv_factor_matrix_dtype, - computation_dtype=self._precision_config.computation_dtype, - ), - factor_matrix_indices=factor_matrix_indices, - ) + self._create_kronecker_factors_list(block_state[SHAMPOO]) ) logger.info( - f"Instantiated Shampoo Preconditioner {preconditioner_index} " + f"Instantiated Shampoo Preconditioner {str(param_index) + '.' + str(block_index)} " f"({[(factor_matrix.quantized_values.shape, factor_matrix.quantized_values.dtype) for factor_matrix in block_state[SHAMPOO].factor_matrices]}) " f"for Parameter {param_index} ({block_info.param.shape}), Block {block_index} ({block.shape})." ) - # Initialize local lists. - local_block_list = compress_list(block_list, distributor_selector) - self._local_kronecker_factors_list: Tuple[ShampooKroneckerFactorsList, ...] 
= ( -            compress_list(kronecker_factors_list, distributor_selector) -        ) -        self._local_order_list: Tuple[int, ...] = tuple( -            block.dim() for block in local_block_list -        ) -        self._local_root_list: Tuple[int, ...] = self._get_inverse_roots_from_override( -            self._inv_root_override, self._local_order_list -        ) +        return kronecker_factors_list -        # Masked lists are the list of active preconditioners or values after filtering out gradients with None. -        self._masked_order_list: Tuple[int, ...] = self._local_order_list -        self._masked_root_list: Tuple[int, ...] = self._local_root_list -        self._masked_kronecker_factors_list: Tuple[ShampooKroneckerFactorsList, ...] = ( -            self._local_kronecker_factors_list -        ) +    @abstractmethod +    def _get_inverse_roots_from_override( +        self, +        inv_root_override: int | Sequence[int], +        order_list: tuple[int, ...], +    ) -> tuple[int, ...]: +        """ +        Retrieves the inverse roots from the override parameter. -        # Construct lists of bytes and numels for logging purposes. -        # NOTE: These lists are constructed across all blocked parameters. -        self._dims_list: Tuple[torch.Size, ...] = compress_list( -            self._dims_list, distributor_selector -        ) -        self._numel_list: Tuple[int, ...] = tuple( -            sum(2 * dim**2 for dim in dims) for dims in self._dims_list -        ) -        self._num_bytes_list: Tuple[int, ...] = tuple( -            numel -            * ( -                get_dtype_size(self._precision_config.factor_matrix_dtype) -                + get_dtype_size(block.dtype) -            ) -            // 2 -            for numel, block in zip(self._numel_list, local_block_list, strict=True) -        ) +        Args: +            inv_root_override (int | Sequence[int]): The override value for the inverse root. +            order_list (tuple[int, ...]): A list of orders for each tensor in the preconditioner. + +        Returns: +            tuple[int, ...]: A list of inverse roots for each tensor in the preconditioner. +        """ +        ...      @staticmethod -    def _get_inverse_roots_from_override( -        inv_root_override: Union[int, Sequence[int]], order_list: Tuple[int, ...] -    ) -> Tuple[int, ...]: +    def _get_inverse_roots_from_override_with_high_order_default( +        inv_root_override: int | Sequence[int], +        order_list: tuple[int, ...], +        high_order_default: Callable[[int], int], +    ) -> tuple[int, ...]:         """Retrieves the appropriate root from the inverse root override parameter         for a list of tensor orders. @@ -529,20 +630,21 @@         If order = 1, then we will return 1;         If order = 2, then we will return 4;         If order = 3, then we will return 3; -        If order > 3, then we will return 2 * order. +        If order > 3, then we will return high_order_default(order).          Args: -            inv_root_override (int, Sequence[int]): Inverse root override int or list. -            order_list (Tuple[int, ...]): List of orders for their corresponding tensors. +            inv_root_override (int | Sequence[int]): Inverse root override int or list. +            order_list (tuple[int, ...]): List of orders for their corresponding tensors. +            high_order_default (Callable[[int], int]): Function for computing the inverse root for orders greater than the length of the inverse root override list.          Returns: -            root_list (Tuple[int, ...]): Inverse roots to use in Shampoo for a list of tensors. +            root_list (tuple[int, ...]): Inverse roots to use in Shampoo for a list of tensors.
""" if isinstance(inv_root_override, Sequence): return tuple( ( - 2 * order + high_order_default(order) if order >= len(inv_root_override) else inv_root_override[order] ) @@ -550,16 +652,151 @@ def _get_inverse_roots_from_override( ) else: return ( - tuple(2 * order for order in order_list) + tuple(high_order_default(order) for order in order_list) if inv_root_override == 0 else (inv_root_override,) * len(order_list) ) + @abstractmethod + def _amortized_computation(self) -> None: + """ + Computes the amortized computation needed for each Shampoo implementation. + This amortized computation is computation heavy work that could not be done for eeac step. + As a result, each Shampoo implementation may implement this method for its neeed. + """ + ... + + @staticmethod + def _check_factor_matrix_for_diagonality_nan_and_inf( + factor_matrix: Tensor, + is_factor_matrix_diagonal: Tensor, + factor_matrix_index: str, + ) -> None: + # For tracking diagonality of the factor matrix. + # Checks if the factor matrix is currently diagonal, then checks whether or not + # the update factor matrix is diagonal. + if is_factor_matrix_diagonal and not check_diagonal(factor_matrix): + is_factor_matrix_diagonal.copy_(torch.tensor(False)) + logger.debug(f"Factor matrix {factor_matrix_index} is not diagonal.") + + # Check for nan or inf values. + if torch.isnan(factor_matrix).any(): + raise PreconditionerValueError( + f"Encountered nan values in factor matrix {factor_matrix_index}! " + f"To mitigate, check if nan inputs are being passed into the network or nan gradients " + f"are being passed to the optimizer." + f"For debugging purposes, factor_matrix {factor_matrix_index}: " + f"{torch.min(factor_matrix)=}, {torch.max(factor_matrix)=}, " + f"{factor_matrix.isinf().any()=}, {factor_matrix.isnan().any()=}." + ) + if torch.isinf(factor_matrix).any(): + raise PreconditionerValueError( + f"Encountered inf values in factor matrix {factor_matrix_index}! " + f"In some cases, this may be due to divergence of the algorithm. " + f"To mitigate, try decreasing the learning rate or increasing grafting epsilon." + f"For debugging purposes, factor_matrix {factor_matrix_index}: " + f"{torch.min(factor_matrix)=}, {torch.max(factor_matrix)=}, " + f"{factor_matrix.isinf().any()=}, {factor_matrix.isnan().any()=}." + ) + def update_preconditioners( - self, masked_grad_list: Tuple[Tensor, ...], step: Tensor + self, + masked_grad_list: tuple[Tensor, ...], + step: Tensor, + perform_amortized_computation: bool, ) -> None: + """ + Updates the preconditioners. + + Args: + masked_grad_list (tuple[Tensor, ...]): A list of gradients with their corresponding masks. + step (Tensor): The current step. + perform_amortized_computation (bool): Whether to perform an amortized computation. + + Returns: + None + """ with profiler.record_function( f"## {self.__class__.__name__}:{self.update_preconditioners.__name__} ##" + ): + # Update the Kronecker factor matrices. + self._update_factor_matrices(masked_grad_list=masked_grad_list) + + # Update bias correction term based on step. + if self._use_bias_correction and self._beta2 < 1.0: + self._bias_correction2 = torch.tensor(1.0) - self._beta2**step + + # In Shampoo, this is equivalent to computing the inverse factor matrix. + # In Eigenvalue-Corrected Shampoo, this is equivalent to computing the eigenvector of the factor matrix. 
+ if perform_amortized_computation: + self._amortized_computation() + + def _initialize_state_lists( + self, + block_list: tuple[Tensor, ...], + kronecker_factors_list: list[ShampooKroneckerFactorsListType], + distributor_selector: tuple[bool, ...], + ) -> None: + # Initialize local lists. + local_block_list = compress_list(block_list, distributor_selector) + self._local_kronecker_factors_list: tuple[ + ShampooKroneckerFactorsListType, + ..., + ] = compress_list(kronecker_factors_list, distributor_selector) + self._local_order_list: tuple[int, ...] = tuple( + block.dim() for block in local_block_list + ) + self._local_root_list: tuple[int, ...] = self._get_inverse_roots_from_override( + self._inv_root_override, + self._local_order_list, + ) + + # Masked lists are the list of active preconditioners or values after filtering out gradients with None. + self._masked_order_list: tuple[int, ...] = self._local_order_list + self._masked_root_list: tuple[int, ...] = self._local_root_list + self._masked_kronecker_factors_list: tuple[ + ShampooKroneckerFactorsListType, + ..., + ] = self._local_kronecker_factors_list + + # Construct lists of bytes and numels for logging purposes. + # NOTE: These lists are constructed across all blocked parameters. + self._dims_list: tuple[torch.Size, ...] = compress_list( + self._dims_list, distributor_selector + ) + self._numel_list: tuple[int, ...] = tuple( + sum(2 * dim**2 for dim in dims) for dims in self._dims_list + ) + self._num_bytes_list: tuple[int, ...] = tuple( + numel + * ( + get_dtype_size(self._precision_config.factor_matrix_dtype) + + get_dtype_size(block.dtype) + ) + // 2 + for numel, block in zip(self._numel_list, local_block_list, strict=True) + ) + + def compress_preconditioner_list( + self, local_grad_selector: tuple[bool, ...] + ) -> None: + with profiler.record_function( + f"## {self.__class__.__name__}:{self.compress_preconditioner_list.__name__} ##" + ): + self._masked_order_list = compress_list( + self._local_order_list, local_grad_selector + ) + self._masked_root_list = compress_list( + self._local_root_list, local_grad_selector + ) + self._masked_kronecker_factors_list: tuple[ + ShampooKroneckerFactorsListType, + ..., + ] = compress_list(self._local_kronecker_factors_list, local_grad_selector) + + def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: + with profiler.record_function( + f"## {self.__class__.__name__}:{self._update_factor_matrices.__name__} ##" ): # NOTE: Unlike AdagradPreconditionerList, we will loop through each gradient individually. # We apply foreach operators onto the list of Kronecker factor matrices (as opposed to the @@ -594,43 +831,107 @@ def update_preconditioners( alpha=1 - self._beta2 if self._beta2 != 1.0 else 1.0, ) - # Update bias correction term based on step list. 
- if self._use_bias_correction and self._beta2 < 1.0: - self._bias_correction2 = torch.tensor(1.0) - self._beta2**step + @staticmethod + def _precondition_grad( + grad: Tensor, + preconditioner_list: tuple[Tensor, ...], + dims: tuple[list[int], list[int]] = ([0], [0]), + ) -> Tensor: + return reduce(partial(torch.tensordot, dims=dims), preconditioner_list, grad) + + +class ShampooPreconditionerList( + BaseShampooPreconditionerList[ShampooKroneckerFactorsList] +): + """Shampoo preconditioners for list of parameters.""" + + def _create_kronecker_factors_state_for_block( + self, block: Tensor, block_info: BlockInfo, dims: torch.Size + ) -> ShampooKroneckerFactorsState: + inv_factor_matrices = tuple( + QuantizedTensor( + quantized_values=block_info.allocate_zeros_tensor( + shape=(dim, dim), + dtype=self._precision_config.inv_factor_matrix_dtype, + device=block_info.param.device, + ), + block_info=block_info, + ) + for dim in dims + ) + + base_kronecker_factors = self._create_base_kronecker_factors( + block_info=block_info, dims=dims + ) + return ShampooKroneckerFactorsState( + factor_matrices=base_kronecker_factors.factor_matrices, + factor_matrix_indices=base_kronecker_factors.factor_matrix_indices, + inv_factor_matrices=inv_factor_matrices, + ) - def precondition(self, masked_grad_list: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: + def _create_kronecker_factors_list( + self, + kronecker_factors_state: ShampooKroneckerFactorsState + | EigenvalueCorrectedShampooKroneckerFactorsState, + ) -> ShampooKroneckerFactorsList: + assert isinstance(kronecker_factors_state, ShampooKroneckerFactorsState) + return ShampooKroneckerFactorsList( + # Factor matrices computation (accumulation, root inverse) should use the determined dtype. + factor_matrices=QuantizedTensorList( + quantized_data=kronecker_factors_state.factor_matrices, + quantized_dtype=self._precision_config.factor_matrix_dtype, + computation_dtype=self._precision_config.factor_matrix_computation_dtype, + ), + # Inverse factor matrices computation (preconditioning) should use the dtype of the block / gradient. + inv_factor_matrices=QuantizedTensorList( + quantized_data=kronecker_factors_state.inv_factor_matrices, + quantized_dtype=self._precision_config.inv_factor_matrix_dtype, + computation_dtype=self._precision_config.computation_dtype, + ), + factor_matrix_indices=kronecker_factors_state.factor_matrix_indices, + ) + + def _get_inverse_roots_from_override( + self, + inv_root_override: int | Sequence[int], + order_list: tuple[int, ...], + ) -> tuple[int, ...]: + return BaseShampooPreconditionerList._get_inverse_roots_from_override_with_high_order_default( + inv_root_override, order_list, high_order_default=lambda order: 2 * order + ) + + def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, ...]: + """ + Preconditions a list of gradients using the Shampoo preconditioner. + + Args: + masked_grad_list (tuple[Tensor, ...]): A list of gradients with their corresponding masks. + + Returns: + tuple[Tensor, ...]: A list of preconditioned gradients. 
+ """ with profiler.record_function( f"## {self.__class__.__name__}:{self.precondition.__name__} ##" ): - - def precondition_masked_grad( - masked_grad: Tensor, - inv_factor_matrices: Tuple[Tensor, ...], - ) -> Tensor: - for inv_factor_matrix in inv_factor_matrices: - masked_grad = torch.tensordot( - masked_grad, inv_factor_matrix, [[0], [0]] - ) - return masked_grad - return tuple( - precondition_masked_grad( - masked_grad=masked_grad, - inv_factor_matrices=kronecker_factors.inv_factor_matrices.dequantized_value, + self._precondition_grad( + grad=masked_grad, + preconditioner_list=kronecker_factors.inv_factor_matrices.dequantized_value, ) for masked_grad, kronecker_factors in zip( masked_grad_list, self._masked_kronecker_factors_list, strict=True ) ) - def compute_root_inverse(self) -> None: + @torch.compiler.disable + def _amortized_computation(self) -> None: # NOTE: This function currently only computes the matrix root inverse based on # the masked lists which combines both selection based on the distributor and where # grad is not None. Implicitly, this assumes that there are no changes between the # selector or masking from iteration-to-iteration within a single precondition_frequency # interval. with profiler.record_function( - f"## {self.__class__.__name__}:{self.compute_root_inverse.__name__} ##" + f"## {self.__class__.__name__}:{self._amortized_computation.__name__} ##" ): for kronecker_factors, root in zip( self._masked_kronecker_factors_list, @@ -649,81 +950,58 @@ def compute_root_inverse(self) -> None: kronecker_factors.factor_matrix_indices, strict=True, ): - # For tracking diagonality of the preconditioner. - # Checks if the preconditioner is currently diagonal, then checks whether or not - # the update matrix is diagonal. - if is_factor_matrix_diagonal and not check_diagonal(factor_matrix): - is_factor_matrix_diagonal.copy_(torch.tensor(False)) - logger.debug( - f"Factor matrix {factor_matrix_index} is not diagonal." - ) - # Add epsilon term and incorporate bias correction. bias_corrected_factor_matrix = ( factor_matrix / self._bias_correction2 ) - # Check for nan or inf values. - if torch.isnan(bias_corrected_factor_matrix).any(): - raise PreconditionerValueError( - f"Encountered nan values in bias-corrected factor matrix {factor_matrix_index}! " - f"To mitigate, check if nan inputs are being passed into the network or nan gradients " - f"are being passed to the optimizer. " - f"For debugging purposes, factor_matrix {factor_matrix_index}: " - f"{torch.min(factor_matrix)=}, {torch.max(factor_matrix)=}, " - f"{factor_matrix.isinf().any()=}, {factor_matrix.isnan().any()=}." - ) - if torch.isinf(bias_corrected_factor_matrix).any(): - raise PreconditionerValueError( - f"Encountered inf values in bias-corrected factor matrix {factor_matrix_index}! " - f"In some cases, this may be due to divergence of the algorithm. " - f"To mitigate, try decreasing the learning rate or increasing grafting epsilon. " - f"For debugging purposes, factor_matrix {factor_matrix_index}: " - f"{torch.min(factor_matrix)=}, {torch.max(factor_matrix)=}, " - f"{factor_matrix.isinf().any()=}, {factor_matrix.isnan().any()=}." - ) + BaseShampooPreconditionerList._check_factor_matrix_for_diagonality_nan_and_inf( + factor_matrix=bias_corrected_factor_matrix, + is_factor_matrix_diagonal=is_factor_matrix_diagonal, + factor_matrix_index=factor_matrix_index, + ) # Compute inverse preconditioner. - # If reuse_previous_inv_factor_matrix is True, will reuse previous matrix if matrix - # inverse root computation fails. 
try: computed_inv_factor_matrix = matrix_inverse_root( A=bias_corrected_factor_matrix, root=Fraction( root / getattr( - self._root_inv_config, "exponent_multiplier", 1 + self._preconditioner_computation_config, + "exponent_multiplier", + 1, ) ), - root_inv_config=self._root_inv_config, + root_inv_config=cast( + RootInvConfig, self._preconditioner_computation_config + ), epsilon=self._epsilon, is_diagonal=is_factor_matrix_diagonal, ).to(dtype=inv_factor_matrix.dtype) - - # Check if we encounter NaN or inf values in computed inverse matrix. - if ( - torch.isnan(computed_inv_factor_matrix).any() - or torch.isinf(computed_inv_factor_matrix).any() - ): - torch.set_printoptions(threshold=100_000) - raise PreconditionerValueError( - f"Encountered nan or inf values in inverse factor matrix {factor_matrix_index}! " - f"To mitigate, check factor matrix before matrix inverse root computation: " - f"{bias_corrected_factor_matrix=}" - ) - - inv_factor_matrix.copy_(computed_inv_factor_matrix) - - except PreconditionerValueError as pve: - raise pve except Exception as exception: + # If self._use_protected_eigh is True, will reuse previous matrix if matrix inverse root computation fails. if not self._use_protected_eigh: raise exception else: logger.warning( - f"Matrix inverse root computation failed for factor matrix {factor_matrix_index} " - f"with exception {exception}. Using previous inv_factor_matrix and continuing..." + f"Matrix computation failed for factor matrix {factor_matrix_index} " + f"with {exception=}. Using previous inversed factor matrix and continuing..." ) + # Define computed_inv_factor_matrix to prevent undefined local variable error. + computed_inv_factor_matrix = inv_factor_matrix + + # Check if we encounter NaN or inf values in computed inverse matrix. + if ( + torch.isnan(computed_inv_factor_matrix).any() + or torch.isinf(computed_inv_factor_matrix).any() + ): + torch.set_printoptions(threshold=100_000) + raise PreconditionerValueError( + f"Encountered nan or inf values in inverse factor matrix {factor_matrix_index}! " + f"To mitigate, check factor matrix before the matrix computation: {bias_corrected_factor_matrix=}" + ) + inv_factor_matrix.copy_(computed_inv_factor_matrix) def dequantize_preconditioners(self) -> None: with profiler.record_function( @@ -741,25 +1019,10 @@ def quantize_preconditioners(self) -> None: kronecker_factors.factor_matrices.quantize_() kronecker_factors.inv_factor_matrices.quantize_() - def compress_preconditioner_list( - self, local_grad_selector: Tuple[bool, ...] - ) -> None: - with profiler.record_function( - f"## {self.__class__.__name__}:{self.compress_preconditioner_list.__name__} ##" - ): - self._masked_order_list = compress_list( - self._local_order_list, local_grad_selector - ) - self._masked_root_list = compress_list( - self._local_root_list, local_grad_selector - ) - self._masked_kronecker_factors_list: Tuple[ - ShampooKroneckerFactorsList, ... 
- ] = compress_list(self._local_kronecker_factors_list, local_grad_selector) - + @torch.compiler.disable def compute_root_inverse_residuals( self, - ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: + ) -> tuple[tuple[Tensor, ...], tuple[Tensor, ...]]: relative_errors = [] relative_residuals = [] @@ -781,10 +1044,17 @@ def compute_root_inverse_residuals( A=bias_corrected_factor_matrix, X_hat=inv_factor_matrix, root=Fraction( - root / getattr(self._root_inv_config, "exponent_multiplier", 1) + root + / getattr( + self._preconditioner_computation_config, + "exponent_multiplier", + 1, + ) ), epsilon=self._epsilon, - root_inv_config=self._root_inv_config, + root_inv_config=cast( + RootInvConfig, self._preconditioner_computation_config + ), ) relative_errors.append(relative_error) relative_residuals.append(relative_residual) @@ -795,6 +1065,273 @@ def compute_root_inverse_residuals( ) +class EigenvalueCorrectedShampooPreconditionerList( + BaseShampooPreconditionerList[EigenvalueCorrectedShampooKroneckerFactorsList] +): + """Eigenvalue-corrected Shampoo preconditioners for list of parameters.""" + + def _create_kronecker_factors_state_for_block( + self, block: Tensor, block_info: BlockInfo, dims: torch.Size + ) -> EigenvalueCorrectedShampooKroneckerFactorsState: + factor_matrices_eigenvectors = tuple( + QuantizedTensor( + quantized_values=block_info.allocate_zeros_tensor( + shape=(dim, dim), + dtype=self._precision_config.factor_matrix_eigenvectors_dtype, + device=block_info.param.device, + ), + block_info=block_info, + ) + for dim in dims + ) + corrected_eigenvalues = QuantizedTensor( + quantized_values=block_info.allocate_zeros_tensor( + shape=tuple(dims), + dtype=self._precision_config.corrected_eigenvalues_dtype, + device=block_info.param.device, + ), + block_info=block_info, + ) + + base_kronecker_factors = self._create_base_kronecker_factors( + block_info=block_info, dims=dims + ) + return EigenvalueCorrectedShampooKroneckerFactorsState( + factor_matrices=base_kronecker_factors.factor_matrices, + factor_matrices_eigenvectors=factor_matrices_eigenvectors, + corrected_eigenvalues=corrected_eigenvalues, + factor_matrix_indices=base_kronecker_factors.factor_matrix_indices, + ) + + def _create_kronecker_factors_list( + self, + kronecker_factors_state: ShampooKroneckerFactorsState + | EigenvalueCorrectedShampooKroneckerFactorsState, + ) -> EigenvalueCorrectedShampooKroneckerFactorsList: + assert isinstance( + kronecker_factors_state, EigenvalueCorrectedShampooKroneckerFactorsState + ) + return EigenvalueCorrectedShampooKroneckerFactorsList( + factor_matrices=QuantizedTensorList( + quantized_data=kronecker_factors_state.factor_matrices, + quantized_dtype=self._precision_config.factor_matrix_dtype, + computation_dtype=self._precision_config.factor_matrix_computation_dtype, + ), + factor_matrices_eigenvectors=QuantizedTensorList( + quantized_data=kronecker_factors_state.factor_matrices_eigenvectors, + quantized_dtype=self._precision_config.factor_matrix_eigenvectors_dtype, + computation_dtype=self._precision_config.computation_dtype, + ), + corrected_eigenvalues=QuantizedTensorList( + quantized_data=(kronecker_factors_state.corrected_eigenvalues,), + quantized_dtype=self._precision_config.corrected_eigenvalues_dtype, + computation_dtype=self._precision_config.computation_dtype, + ), + factor_matrix_indices=kronecker_factors_state.factor_matrix_indices, + ) + + def _get_inverse_roots_from_override( + self, + inv_root_override: int | Sequence[int], + order_list: tuple[int, ...], + ) -> tuple[int, 
...]: + return BaseShampooPreconditionerList._get_inverse_roots_from_override_with_high_order_default( + inv_root_override, order_list, high_order_default=lambda order: 2 + ) + + def update_preconditioners( + self, + masked_grad_list: tuple[Tensor, ...], + step: Tensor, + perform_amortized_computation: bool, + ) -> None: + """ + Updates the preconditioners. + + Args: + masked_grad_list (tuple[Tensor, ...]): A list of gradients with their corresponding masks. + step (Tensor): The current step. + perform_amortized_computation (bool): Whether to perform an amortized computation. + + Returns: + None + """ + with profiler.record_function( + f"## {self.__class__.__name__}:{self.update_preconditioners.__name__} ##" + ): + super().update_preconditioners( + masked_grad_list=masked_grad_list, + step=step, + perform_amortized_computation=perform_amortized_computation, + ) + # Update the eigenvalue corrections of Shampoo's preconditioner. + self._update_eigenvalue_corrections(masked_grad_list=masked_grad_list) + + def _update_eigenvalue_corrections( + self, masked_grad_list: tuple[Tensor, ...] + ) -> None: + with profiler.record_function( + f"## {self.__class__.__name__}:{self._update_eigenvalue_corrections.__name__} ##" + ): + # NOTE: Unlike AdagradPreconditionerList, we will loop through each gradient individually. + for grad, kronecker_factors in zip( + masked_grad_list, + self._masked_kronecker_factors_list, + strict=True, + ): + factor_eigenvectors = ( + kronecker_factors.factor_matrices_eigenvectors.dequantized_value + ) + if factor_eigenvectors[0].any(): + grad = self._precondition_grad( + grad=grad, + preconditioner_list=factor_eigenvectors, + ) + # Scale corrected eigenvalues. + # NOTE: The case when self._beta2 == 1.0 is not well tested and might not be stable. + if self._beta2 != 1.0: + kronecker_factors.corrected_eigenvalues.dequantized_value[0].mul_( + self._beta2 + ) + # Update corrected eigenvalues (squared gradient in eigenbasis of Shampoo preconditioner). + kronecker_factors.corrected_eigenvalues.dequantized_value[0].add_( + grad.square(), + alpha=1 - self._beta2 if self._beta2 != 1.0 else 1.0, + ) + + def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, ...]: + """ + Preconditions a list of gradients using the Eigenvalue-Corrected Shampoo preconditioner. + + Args: + masked_grad_list (tuple[Tensor, ...]): A list of gradients with their corresponding masks. + + Returns: + tuple[Tensor, ...]: A list of preconditioned gradients. + """ + with profiler.record_function( + f"## {self.__class__.__name__}:{self.precondition.__name__} ##" + ): + preconditioned_grad_list = [] + for masked_grad, kronecker_factors, root in zip( + masked_grad_list, + self._masked_kronecker_factors_list, + self._masked_root_list, + strict=True, + ): + factor_eigenvectors = ( + kronecker_factors.factor_matrices_eigenvectors.dequantized_value + ) + corrected_eigenvalues = ( + kronecker_factors.corrected_eigenvalues.dequantized_value[0] + ) + use_eigenbasis = factor_eigenvectors[0].any() + grad = masked_grad.clone() + if use_eigenbasis: + # Convert to eigenbasis of Shampoo factor matrices. + grad = self._precondition_grad( + grad=grad, + preconditioner_list=factor_eigenvectors, + ) + + # Precondition with inverse root of corrected eigenvalues. + grad.div_( + corrected_eigenvalues.div(self._bias_correction2) + .add_(self._epsilon) + .pow_(1 / root) + ) + if use_eigenbasis: + # Convert back to basis of the parameters. 
+ grad = self._precondition_grad( + grad=grad, + preconditioner_list=factor_eigenvectors, + dims=([0], [1]), + ) + preconditioned_grad_list.append(grad) + return tuple(preconditioned_grad_list) + + @torch.compiler.disable + def _amortized_computation(self) -> None: + # NOTE: This function currently only computes the preconditioner eigenvectors based on + # the masked lists which combines both selection based on the distributor and where + # grad is not None. Implicitly, this assumes that there are no changes between the + # selector or masking from iteration-to-iteration within a single precondition_frequency + # interval. + with profiler.record_function( + f"## {self.__class__.__name__}:{self._amortized_computation.__name__} ##" + ): + for kronecker_factors in self._masked_kronecker_factors_list: + for ( + factor_matrix, + factor_matrix_eigenvectors, + is_factor_matrix_diagonal, + factor_matrix_index, + ) in zip( + kronecker_factors.factor_matrices.dequantized_value, + kronecker_factors.factor_matrices_eigenvectors.dequantized_value, + kronecker_factors.is_factor_matrices_diagonal, + kronecker_factors.factor_matrix_indices, + strict=True, + ): + BaseShampooPreconditionerList._check_factor_matrix_for_diagonality_nan_and_inf( + factor_matrix=factor_matrix, + is_factor_matrix_diagonal=is_factor_matrix_diagonal, + factor_matrix_index=factor_matrix_index, + ) + + # Compute eigenvectors of factor matrix. + try: + computed_eigenvectors = matrix_eigenvectors( + A=factor_matrix, + eigenvector_computation_config=cast( + EigenvalueCorrectionConfig, + self._preconditioner_computation_config, + ), + is_diagonal=is_factor_matrix_diagonal, + ) + except Exception as exception: + # If self._use_protected_eigh is True, will reuse previous matrix if matrix eigenvector computation fails. + if not self._use_protected_eigh: + raise exception + else: + logger.warning( + f"Matrix computation failed for factor matrix {factor_matrix_index} " + f"with {exception=}. Using previous factor matrix eigenvectors and continuing..." + ) + # Define computed_eigenvectors to prevent undefined local variable error. + computed_eigenvectors = factor_matrix_eigenvectors + + # Check if we encounter NaN or inf values in computed eigenvectors. + if ( + torch.isnan(computed_eigenvectors).any() + or torch.isinf(computed_eigenvectors).any() + ): + torch.set_printoptions(threshold=100_000) + raise PreconditionerValueError( + f"Encountered nan or inf values in eigenvectors of factor matrix {factor_matrix_index}! 
" + f"To mitigate, check factor matrix before the matrix computation: {factor_matrix=}" + ) + factor_matrix_eigenvectors.copy_(computed_eigenvectors) + + def dequantize_preconditioners(self) -> None: + with profiler.record_function( + f"## {self.__class__.__name__}:{self.dequantize_preconditioners.__name__} ##" + ): + for kronecker_factors in self._masked_kronecker_factors_list: + kronecker_factors.factor_matrices.dequantize_() + kronecker_factors.factor_matrices_eigenvectors.dequantize_() + kronecker_factors.corrected_eigenvalues.dequantize_() + + def quantize_preconditioners(self) -> None: + with profiler.record_function( + f"## {self.__class__.__name__}:{self.quantize_preconditioners.__name__} ##" + ): + for kronecker_factors in self._masked_kronecker_factors_list: + kronecker_factors.factor_matrices.quantize_() + kronecker_factors.factor_matrices_eigenvectors.quantize_() + kronecker_factors.corrected_eigenvalues.quantize_() + + class DequantizePreconditionersContext(ParameterizeEnterExitContext): """DequantizePreconditionersContext is used for automatically dequantize and then quantize the preconditioners used within this context. diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 3f84622..56e9687 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -21,12 +21,13 @@ from distributed_shampoo.utils.shampoo_preconditioner_list import ( AdagradPreconditionerList, DequantizePreconditionersContext, + EigenvalueCorrectedShampooPreconditionerList, PreconditionerList, SGDPreconditionerList, ShampooPreconditionerList, ) from distributed_shampoo.utils.shampoo_quantization import QuantizedTensorList -from matrix_functions_types import EigenConfig +from matrix_functions_types import DefaultEighEigenvalueCorrectionConfig, EigenConfig from torch import Tensor @@ -73,9 +74,12 @@ def _test_update_preconditioners_and_precondition( preconditioner_list.update_preconditioners( masked_grad_list=masked_grad_list, step=torch.tensor(step), + # Only compute the new layerwise direction when the update_preconditioners() reach the last step. + perform_amortized_computation=isinstance( + preconditioner_list, ShampooPreconditionerList + ) + and step == len(masked_grad_lists), ) - if isinstance(preconditioner_list, ShampooPreconditionerList): - preconditioner_list.compute_root_inverse() masked_preconditioned_grad_list = preconditioner_list.precondition( masked_grad_list=masked_grad_lists[-1] ) @@ -472,7 +476,7 @@ def test_inverse_roots_from_override( """ Tests that the inverse roots are computed correctly from inv_root_override. 
""" - root_inv_config = EigenConfig(exponent_multiplier=2.0) + preconditioner_computation_config = EigenConfig(exponent_multiplier=2.0) masked_grad_list1 = ( torch.tensor([1.0, 0.0]), @@ -497,7 +501,7 @@ def test_inverse_roots_from_override( beta2=1.0, use_bias_correction=True, inv_root_override=inv_root_override, - root_inv_config=root_inv_config, + preconditioner_computation_config=preconditioner_computation_config, ), masked_grad_lists=[masked_grad_list1, masked_grad_list2], masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, @@ -510,37 +514,37 @@ def test_raise_inf_in_factor_matrix_compute_root_inverse(self) -> None: with DequantizePreconditionersContext( preconditioner_list=self._preconditioner_list ): - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([torch.inf, torch.inf]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[torch.inf, torch.inf]]), - ), - step=torch.tensor(1), - ) with self.assertRaisesRegex( PreconditionerValueError, - re.escape("Encountered inf values in bias-corrected factor matrix"), + re.escape("Encountered inf values in factor matrix"), ): - self._preconditioner_list.compute_root_inverse() + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([torch.inf, torch.inf]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[torch.inf, torch.inf]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) def test_raise_nan_in_factor_matrix_compute_root_inverse(self) -> None: with DequantizePreconditionersContext( preconditioner_list=self._preconditioner_list ): - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([torch.nan, torch.nan]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[torch.nan, torch.nan]]), - ), - step=torch.tensor(1), - ) with self.assertRaisesRegex( PreconditionerValueError, - re.escape("Encountered nan values in bias-corrected factor matrix"), + re.escape("Encountered nan values in factor matrix"), ): - self._preconditioner_list.compute_root_inverse() + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([torch.nan, torch.nan]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[torch.nan, torch.nan]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) # Note: This is needed for pyre to infer the type of argument into mock.patch.object. 
shampoo_preconditioner_list_module: ModuleType = shampoo_preconditioner_list @@ -559,7 +563,15 @@ def test_raise_inf_in_inv_factor_matrix_compute_root_inverse( PreconditionerValueError, re.escape("Encountered nan or inf values in inverse factor matrix"), ): - self._preconditioner_list.compute_root_inverse() + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) mock_matrix_inverse_root.assert_called_once() @mock.patch.object( @@ -576,7 +588,15 @@ def test_raise_nan_in_inv_factor_matrix_compute_root_inverse( PreconditionerValueError, re.escape("Encountered nan or inf values in inverse factor matrix"), ): - self._preconditioner_list.compute_root_inverse() + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) mock_matrix_inverse_root.assert_called_once() @mock.patch.object( @@ -592,20 +612,28 @@ def test_matrix_compute_root_inverse_internal_failure( preconditioner_list=self._preconditioner_list ), self.assertLogs(level="WARNING") as cm: # Because use_protected_eigh is True, we expect the warning to be logged. - self._preconditioner_list.compute_root_inverse() + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) self.assertCountEqual( [r.msg for r in cm.records], [ - "Matrix inverse root computation failed for factor matrix 0.block_0.0 with exception ." - " Using previous inv_factor_matrix and continuing...", - "Matrix inverse root computation failed for factor matrix 1.block_0.0 with exception ." - " Using previous inv_factor_matrix and continuing...", - "Matrix inverse root computation failed for factor matrix 1.block_0.1 with exception ." - " Using previous inv_factor_matrix and continuing...", - "Matrix inverse root computation failed for factor matrix 1.block_1.0 with exception ." - " Using previous inv_factor_matrix and continuing...", - "Matrix inverse root computation failed for factor matrix 1.block_1.1 with exception ." - " Using previous inv_factor_matrix and continuing...", + "Matrix computation failed for factor matrix 0.block_0.0 with exception=ZeroDivisionError()." + " Using previous inversed factor matrix and continuing...", + "Matrix computation failed for factor matrix 1.block_0.0 with exception=ZeroDivisionError()." + " Using previous inversed factor matrix and continuing...", + "Matrix computation failed for factor matrix 1.block_0.1 with exception=ZeroDivisionError()." + " Using previous inversed factor matrix and continuing...", + "Matrix computation failed for factor matrix 1.block_1.0 with exception=ZeroDivisionError()." + " Using previous inversed factor matrix and continuing...", + "Matrix computation failed for factor matrix 1.block_1.1 with exception=ZeroDivisionError()." 
+ " Using previous inversed factor matrix and continuing...", ], ) mock_matrix_inverse_root.assert_called() @@ -617,7 +645,15 @@ def test_matrix_compute_root_inverse_internal_failure( with DequantizePreconditionersContext( preconditioner_list=self._preconditioner_list ), self.assertRaises(ZeroDivisionError): - self._preconditioner_list.compute_root_inverse() + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) mock_matrix_inverse_root.assert_called() @mock.patch.object( @@ -634,7 +670,15 @@ def test_matrix_compute_root_inverse_factor_matrix_non_diagonal( ), self.assertLogs( level="DEBUG", ) as cm: - self._preconditioner_list.compute_root_inverse() + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) self.assertCountEqual( [r.msg for r in cm.records], [ @@ -691,12 +735,13 @@ def test_compute_root_inverse_residuals(self) -> None: preconditioner_list.update_preconditioners( masked_grad_list=masked_grad_list1, step=torch.tensor(1), + perform_amortized_computation=False, ) preconditioner_list.update_preconditioners( masked_grad_list=masked_grad_list2, step=torch.tensor(2), + perform_amortized_computation=True, ) - preconditioner_list.compute_root_inverse() # Expect no relative errors and residuals because L is a diagonal matrix. ( @@ -709,3 +754,407 @@ def test_compute_root_inverse_residuals(self) -> None: self.assertTupleEqual(relative_errors, expected_relative_errors) self.assertTupleEqual(relative_residuals, expected_relative_residuals) + + +class EigenvalueCorrectedShampooPreconditionerListTest(AdagradPreconditionerListTest): + def _instantiate_preconditioner_list(self, **kwargs: Any) -> PreconditionerList: + kwargs = { + "beta2": 1.0, + "epsilon": 1e-12, + "inv_root_override": 0, + "use_bias_correction": True, + "use_protected_eigh": True, + "preconditioner_computation_config": DefaultEighEigenvalueCorrectionConfig, + } | kwargs + return EigenvalueCorrectedShampooPreconditionerList( + block_list=self._block_list, + state=self._state, + block_info_list=self._block_info_list, + distributor_selector=self._distributor_selector, + precision_config=PrecisionConfig(factor_matrix_dtype=torch.float64), + **kwargs, + ) + + def test_update_preconditioners_and_precondition(self) -> None: + """ + We provide examples where we update the preconditioners twice using specially + chosen gradients such that we get a scalar * identity matrix for both Kronecker + factor matrices for all parameters of interest. 
+ + Specifically, for the beta2 = 1 case, we have 3 parameters and define their gradients + as the following in order to get the expected preconditioned gradient list: + + (1) Tensor of Size 2 + G1 = [1, 0]^T + G2 = [0, 1]^T + + L = G1 * G1^T + G2 * G2^T = [[1, 0], [0, 1]] + B = [[1, 0], [0, 1]] # eigenvectors of L + E = G1^2 + (B G2)^2 # corrected eigenvalues + P = B ((B G2) / sqrt(E + eps)) = G2 / sqrt(E + eps) ≈ G2 + + (2) Tensor of Size 2 x 2 + G1 = [[1, 0], [0, 1]] / sqrt(2) + G2 = [[1, 0], [0, 1]] / sqrt(2) + + L = G1 * G1^T + G2 * G2^T = [[1, 0], [0, 1]] + R = G1^T * G1 + G2^T * G2 = [[1, 0], [0, 1]] + B_L = [[1, 0], [0, 1]] # eigenvectors of L + B_R = [[1, 0], [0, 1]] # eigenvectors of R + E = G1^2 + (B_L G2 B_R)^2 # corrected eigenvalues + P = B_L ((B_L G2 B_R) / sqrt(E + eps) B_R = G2 / sqrt(E + eps) ≈ G2 + + (3) Tensor of Size 1 x 2 + G1 = [[1, 0]] + G2 = [[0, 1]] + + L = G1 * G1^T + G2 * G2^T = 2 + R = G1^T * G1 + G2^T * G2 = [[1, 0], [0, 1]] + B_L = 1 # eigenvectors of L + B_R = [[1, 0], [0, 1]] # eigenvectors of R + E = G1^2 + (B_L G2 B_R)^2 # corrected eigenvalues + P = B_L ((B_L G2 B_R) / sqrt(E + eps) B_R = G2 / sqrt(E + eps) ≈ G2 + + """ + masked_grad_list1 = ( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ) + masked_grad_list2 = ( + torch.tensor([0.0, 1.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[0.0, 1.0]]), + ) + + masked_expected_preconditioned_grad_list = ( + torch.tensor([0.0, 1.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[0.0, 1.0]]), + ) + self._test_update_preconditioners_and_precondition( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=1.0, + use_bias_correction=True, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, + ) + + """ + For the other two cases (beta2 < 1), note: + + E = beta2 * (1 - beta2) G1^2 + (1 - beta2) G2^2 + + Therefore, in order to retain the identity matrix, we simply need to scale each gradient by: + + G1 -> G1 / sqrt(beta2 * (1 - beta2)) + G2 -> G2 / sqrt(1 - beta2). + + """ + beta2 = 0.9 + + beta2_compensated_grad_list1 = torch._foreach_div( + masked_grad_list1, + torch.tensor(beta2 * (1 - beta2)).sqrt(), + ) + beta2_compensated_grad_list2 = torch._foreach_div( + masked_grad_list2, + torch.tensor(1 - beta2).sqrt(), + ) + + masked_expected_preconditioned_grad_list = [ + torch.tensor([0.0, 1.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[0.0, 1.0]]), + ] + # Fix scaling due to EMA. + torch._foreach_div_( + masked_expected_preconditioned_grad_list, + torch.tensor(1 - beta2).sqrt(), + ) + + self._test_update_preconditioners_and_precondition( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=beta2, + use_bias_correction=False, + ), + masked_grad_lists=[ + beta2_compensated_grad_list1, + beta2_compensated_grad_list2, + ], + masked_expected_preconditioned_grad_list=tuple( + masked_expected_preconditioned_grad_list + ), + ) + + """ + For the last case of including bias correction, we re-scale the entire matrix by the + bias correction at iteration 2. + + E -> E / (1 - beta2^2). + + Therefore, it is sufficient to additionally scale by this value: + + G1 -> sqrt(1 - beta2^2) * G1 + G2 -> sqrt(1 - beta2^2) * G2. 
+ + """ + bias_compensated_grad_list1 = torch._foreach_mul( + beta2_compensated_grad_list1, + torch.tensor(1 - beta2**2).sqrt(), + ) + bias_compensated_grad_list2 = torch._foreach_mul( + beta2_compensated_grad_list2, + torch.tensor(1 - beta2**2).sqrt(), + ) + + # Fix scaling due to bias correction. + torch._foreach_mul_( + masked_expected_preconditioned_grad_list, + torch.tensor(1 - beta2**2).sqrt(), + ) + + self._test_update_preconditioners_and_precondition( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=beta2, + use_bias_correction=True, + ), + masked_grad_lists=[ + bias_compensated_grad_list1, + bias_compensated_grad_list2, + ], + masked_expected_preconditioned_grad_list=tuple( + masked_expected_preconditioned_grad_list + ), + ) + + def test_inv_root_override(self) -> None: + """ + For this example, we modify the one given above such that the inv_root_override = 1. + Therefore, in all cases, the exponent is -2 / 2 = -1. + This should result in the following behavior: + + (1) Tensor of Size 2 + G1 = [1, 0]^T + G2 = [0, 2]^T + + L = G1 * G1^T + G2 * G2^T = [[1, 0], [0, 4]] + B = [[1, 0], [0, 1]] # eigenvectors of L + E = G1^2 + (B G2)^2 # corrected eigenvalues + P = B ((B G2) / (E + eps) = G2 / (E + eps) ≈ [0, 0.5]^T + + (2) Tensor of Size 2 x 2 + G1 = [[1, 0], [0, 1]] / sqrt(2) + G2 = [[1, 0], [0, 1]] / sqrt(2) + + L = G1 * G1^T + G2 * G2^T = [[1, 0], [0, 1]] + R = G1^T * G1 + G2^T * G2 = [[1, 0], [0, 1]] + B_L = [[1, 0], [0, 1]] # eigenvectors of L + B_R = [[1, 0], [0, 1]] # eigenvectors of R + E = G1^2 + (B_L G2 B_R)^2 # corrected eigenvalues + P = B_L ((B_L G2 B_R) / (E + eps) B_R = G2 / (E + eps) ≈ G2 + + (3) Tensor of Size 1 x 2 + G1 = [[1, 0]] + G2 = [[0, 2]] + + L = G1 * G1^T + G2 * G2^T = 2 + R = G1^T * G1 + G2^T * G2 = [[1, 0], [0, 4]] + B_L = 1 # eigenvectors of L + B_R = [[1, 0], [0, 1]] # eigenvectors of R + E = G1^2 + (B_L G2 B_R)^2 # corrected eigenvalues + P = B_L ((B_L G2 B_R) / (E + eps)) B_R = G2 / (E + eps) ≈ [[0, 0.5]] + + """ + + def test_inverse_roots_from_override( + inv_root_override: Union[int, List[int]], + ) -> None: + """ + Tests that the inverse roots are computed correctly from inv_root_override. 
+ """ + masked_grad_list1 = ( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ) + masked_grad_list2 = ( + torch.tensor([0.0, 2.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[0.0, 2.0]]), + ) + + masked_expected_preconditioned_grad_list = ( + torch.tensor([0, 0.5]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[0, 0.5]]), + ) + + with self.subTest(inv_root_override=inv_root_override): + self._test_update_preconditioners_and_precondition( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=1.0, + use_bias_correction=True, + inv_root_override=inv_root_override, + preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, + ) + + test_inverse_roots_from_override(inv_root_override=1) + test_inverse_roots_from_override(inv_root_override=[1, 1, 1]) + + """Tests for compute_preconditioner_eigenvectors.""" + + def test_raise_inf_in_factor_matrix_compute_preconditioner_eigenvectors( + self, + ) -> None: + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ): + with self.assertRaisesRegex( + ValueError, + re.escape("Encountered inf values in factor matrix"), + ): + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([torch.inf, torch.inf]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[torch.inf, torch.inf]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + + def test_raise_nan_in_factor_matrix_compute_preconditioner_eigenvectors( + self, + ) -> None: + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ): + with self.assertRaisesRegex( + ValueError, + re.escape("Encountered nan values in factor matrix"), + ): + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([torch.nan, torch.nan]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[torch.nan, torch.nan]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + + # Note: This is needed for pyre to infer the type of argument into mock.patch.object. 
+ shampoo_preconditioner_list_module: ModuleType = shampoo_preconditioner_list + + @mock.patch.object( + shampoo_preconditioner_list_module, + "matrix_eigenvectors", + side_effect=(torch.tensor([torch.inf]),), + ) + def test_raise_inf_in_inv_factor_matrix_compute_preconditioner_eigenvectors( + self, mock_matrix_eigenvectors: mock.Mock + ) -> None: + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), self.assertRaisesRegex( + PreconditionerValueError, + re.escape("Encountered nan or inf values in eigenvectors of factor matrix"), + ): + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + mock_matrix_eigenvectors.assert_called_once() + + @mock.patch.object( + shampoo_preconditioner_list_module, + "matrix_eigenvectors", + side_effect=(torch.tensor([torch.nan]),), + ) + def test_raise_nan_in_inv_factor_matrix_compute_preconditioner_eigenvectors( + self, mock_matrix_eigenvectors: mock.Mock + ) -> None: + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), self.assertRaisesRegex( + PreconditionerValueError, + re.escape("Encountered nan or inf values in eigenvectors of factor matrix"), + ): + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + mock_matrix_eigenvectors.assert_called_once() + + @mock.patch.object( + shampoo_preconditioner_list_module, + "check_diagonal", + return_value=False, + ) + def test_matrix_compute_preconditioner_eigenvectors_factor_matrix_non_diagonal( + self, mock_check_diagonal: mock.Mock + ) -> None: + self._preconditioner_list = self._instantiate_preconditioner_list(epsilon=1.0) + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), self.assertLogs( + level="DEBUG", + ) as cm: + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + self.assertCountEqual( + [r.msg for r in cm.records], + [ + "Factor matrix 0.block_0.0 is not diagonal.", + "Factor matrix 1.block_0.0 is not diagonal.", + "Factor matrix 1.block_0.1 is not diagonal.", + "Factor matrix 1.block_1.0 is not diagonal.", + "Factor matrix 1.block_1.1 is not diagonal.", + ], + ) + mock_check_diagonal.assert_called() + + """End of tests for compute_preconditioner_eigenvectors.""" + + def test_numel_list(self) -> None: + self.assertEqual(self._preconditioner_list.numel_list, (8, 16, 10)) + + def test_dims_list(self) -> None: + self.assertEqual( + self._preconditioner_list.dims_list, + (torch.Size([2]), torch.Size([2, 2]), torch.Size([1, 2])), + ) + + def test_num_bytes_list(self) -> None: + self.assertEqual(self._preconditioner_list.num_bytes_list, (48, 96, 60)) + + def test_numel(self) -> None: + self.assertEqual(self._preconditioner_list.numel(), 34) + + def test_num_bytes(self) -> None: + self.assertEqual(self._preconditioner_list.num_bytes(), 204) + + def test_compress_preconditioner_list(self) -> None: + self._test_compress_preconditioner_list(expected_compress_list_call_count=3)
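Illustrative sketches (not part of the patch). The inverse-root override refactor above keeps the existing selection rules but parameterizes the default through `high_order_default` (`2 * order` for classic Shampoo, a constant `2` for the eigenvalue-corrected variant). The following is a minimal standalone sketch of that selection logic outside the repository's classes; the function name and the printed values are illustrative only.

```python
from collections.abc import Callable, Sequence


def inverse_roots_from_override(
    inv_root_override: int | Sequence[int],
    order_list: tuple[int, ...],
    high_order_default: Callable[[int], int],
) -> tuple[int, ...]:
    """An int override applies to every block (0 means "use the default rule");
    a sequence override is indexed by tensor order, falling back to high_order_default."""
    if isinstance(inv_root_override, Sequence):
        return tuple(
            high_order_default(order)
            if order >= len(inv_root_override)
            else inv_root_override[order]
            for order in order_list
        )
    return (
        tuple(high_order_default(order) for order in order_list)
        if inv_root_override == 0
        else (inv_root_override,) * len(order_list)
    )


orders = (1, 2, 3)
print(inverse_roots_from_override(0, orders, lambda order: 2 * order))  # (2, 4, 6) -- Shampoo default
print(inverse_roots_from_override(0, orders, lambda order: 2))  # (2, 2, 2) -- eigenvalue-corrected default
print(inverse_roots_from_override([0, 1, 4, 3], orders, lambda order: 2 * order))  # (1, 4, 3)
```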
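For the classic path, `ShampooPreconditionerList._amortized_computation` computes the inverse root of each factor matrix and `precondition` contracts those inverse factors into the gradient. Below is a plain-torch sketch of that computation for a single 2D block, using `torch.linalg.eigh` as a stand-in for the repository's `matrix_inverse_root`; the ridge term, clamp, and random tensors are illustrative assumptions, not the library's defaults.

```python
from functools import partial, reduce

import torch

torch.manual_seed(0)

# A single 2D gradient block and its Kronecker factor matrices L = G G^T and R = G^T G.
grad = torch.randn(4, 3)
factor_L = grad @ grad.T + 1e-4 * torch.eye(4)  # small ridge so the inverse root exists
factor_R = grad.T @ grad + 1e-4 * torch.eye(3)


def inverse_root_via_eigh(factor: torch.Tensor, root: int) -> torch.Tensor:
    """Compute factor^(-1/root) from an eigendecomposition (stand-in for matrix_inverse_root)."""
    eigenvalues, eigenvectors = torch.linalg.eigh(factor)
    return eigenvectors @ torch.diag(eigenvalues.clamp(min=1e-12).pow(-1.0 / root)) @ eigenvectors.T


# Amortized computation (every precondition_frequency steps): root = 2 * order = 4 for a 2D block.
inv_factors = [inverse_root_via_eigh(factor, root=4) for factor in (factor_L, factor_R)]

# Preconditioning (every step): contract each inverse factor into the gradient along dim 0,
# the same reduce(tensordot) pattern used by _precondition_grad.
preconditioned = reduce(partial(torch.tensordot, dims=([0], [0])), inv_factors, grad)
print(preconditioned.shape)  # torch.Size([4, 3])
```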
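The eigenvalue-corrected path in `EigenvalueCorrectedShampooPreconditionerList` instead stores only the eigenvectors of the factor matrices (amortized) plus an every-step EMA of the squared gradient expressed in that eigenbasis, and preconditions by dividing in the eigenbasis before rotating back. A self-contained numerical sketch of that transform for one block, in plain torch with none of the repository's quantization, masking, or distribution machinery; the beta2, epsilon, and single-step bias correction shown are illustrative assumptions.

```python
from functools import partial, reduce

import torch

torch.manual_seed(0)

grad = torch.randn(4, 3)
factor_L = grad @ grad.T  # L = sum of G G^T
factor_R = grad.T @ grad  # R = sum of G^T G

# Amortized computation (every precondition_frequency steps): eigenvectors of each factor matrix.
eigenvectors = [torch.linalg.eigh(factor).eigenvectors for factor in (factor_L, factor_R)]


def rotate(g: torch.Tensor, mats: list[torch.Tensor], dims=([0], [0])) -> torch.Tensor:
    # Same reduce(tensordot) contraction pattern as _precondition_grad.
    return reduce(partial(torch.tensordot, dims=dims), mats, g)


# Every-step work: EMA of the squared gradient in the eigenbasis (a single step shown here).
beta2 = 0.99
grad_in_eigenbasis = rotate(grad, eigenvectors)
corrected_eigenvalues = (1 - beta2) * grad_in_eigenbasis.square()
bias_correction2 = 1 - beta2  # bias correction after one step

# Precondition: divide by the bias-corrected eigenvalues^(1/root) (root = 2 here) and rotate back.
epsilon = 1e-12
preconditioned = rotate(
    grad_in_eigenbasis / (corrected_eigenvalues / bias_correction2 + epsilon).pow(1 / 2),
    eigenvectors,
    dims=([0], [1]),
)
print(preconditioned.shape)  # torch.Size([4, 3])
```

The fixed root of 2 matches the eigenvalue-corrected override default in this diff; the real implementation additionally handles gradient masking, quantized storage, and the untested `beta2 = 1.0` accumulation case noted in the code comments.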