From f3451cdfce8a123b6161220f58b782b7cebac7d1 Mon Sep 17 00:00:00 2001 From: Tsung-Hsien Lee Date: Mon, 4 Nov 2024 10:16:39 -0800 Subject: [PATCH] Add option to correct eigenvalues of Shampoo's preconditioner MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: This update is based on https://github.com/facebookresearch/optimizers/pull/27, developed by Runa Eschenhagen (runame) and Tsung-Hsien Lee (tsunghsienlee). The research idea in this update originated from Runa Eschenhagen's internship at The Fundamental AI Research (FAIR) at Meta during the summer of 2024. Concurrently, Runa Eschenhagen, Michael Shi (hjmshi), Aaron Defazio (adefazio) worked on this method, which was also empirically evaluated on language models by Nikhil Vyas et al. [3], showing promising results. This update enables approximately correcting the eigenvalues and running Adam in the eigenbasis of Shampoo's preconditioner. A variation of this method was first proposed for K-FAC by George et al. [1], and Anil et al. [2] noted its applicability to Shampoo in Appendix B, although they did not present empirical results or further discussion. References: 1. [Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis](https://arxiv.org/abs/1806.03884). Thomas George, César Laurent, Xavier Bouthillier, Nicolas Ballas, Pascal Vincent. NeurIPS, 2018. 2. [Scalable Second-Order Optimization for Deep Learning](https://arxiv.org/pdf/2002.09018.pdf). Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, and Yoram Singer. Tech Report, 2021. 3. [SOAP: Improving and Stabilizing Shampoo using Adam](https://arxiv.org/abs/2409.11321). Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade. Tech Report, 2024. Reviewed By: hjmshi Differential Revision: D65402620 fbshipit-source-id: 8ea4f761cfae04c5622a968cb499654816e4aa3e --- distributed_shampoo/README.md | 5 + distributed_shampoo/distributed_shampoo.py | 116 +- .../shampoo_eigenvalue_correction_test.py | 241 ++++ distributed_shampoo/shampoo_types.py | 7 +- .../tests/distributed_shampoo_test.py | 126 +- .../utils/shampoo_preconditioner_list.py | 1041 +++++++++++++---- .../tests/shampoo_preconditioner_list_test.py | 531 ++++++++- 7 files changed, 1731 insertions(+), 336 deletions(-) create mode 100644 distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py diff --git a/distributed_shampoo/README.md b/distributed_shampoo/README.md index 84490c7..8b8cd11 100644 --- a/distributed_shampoo/README.md +++ b/distributed_shampoo/README.md @@ -37,6 +37,7 @@ Key distinctives of this implementation include: - Choice of precision for preconditioner accumulation and root inverse computation. - Ability to cache split parameters. - Merging of small dimensions. +- [EXPERIMENTAL] Option to (approximately) correct the eigenvalues/run Adam in the eigenbasis of Shampoo's preconditioner [2,6,7]. ## Requirements @@ -62,6 +63,8 @@ A few notes on hyperparameters: - We allow for decoupled and coupled weight decay. If one sets `use_decoupled_weight_decay=True`, then you are enabling AdamW-style weight decay, while `use_decoupled_weight_decay=False` corresponds to the normal L2-regularization style weight decay. 
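As a minimal illustrative sketch of the bullet above (not taken from the library; `p`, `g`, `lr`, and `wd` are placeholder names), the two weight-decay settings correspond roughly to the following updates:

```python
import torch

p = torch.randn(10)   # a parameter block
g = torch.randn(10)   # its gradient
lr, wd = 1e-2, 1e-4   # learning rate and weight decay

# use_decoupled_weight_decay=False (coupled, L2-style): the decay term is folded into
# the gradient, so it is preconditioned/grafted together with the rest of the update.
g_coupled = g + wd * p

# use_decoupled_weight_decay=True (decoupled, AdamW-style): the decay is applied directly
# to the parameter, independently of the preconditioned gradient step.
p_decayed = p * (1 - lr * wd)
```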
+- When setting `preconditioner_computation_config` as an instance of EigenvalueCorrectionConfig, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. + ### Example 1: [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html) with Momentum If we previously used the optimizer: @@ -479,3 +482,5 @@ When encountering those errors, following are things you could try: 3. [Learning Rate Grafting: Transferability of Optimizer Tuning](https://openreview.net/pdf?id=FpKgG31Z_i9). Naman Agarwal, Rohan Anil, Elad Hazan, Tomer Koren, and Cyril Zhang. Tech Report, 2021. 4. [Functions of Matrices: Theory and Computation](https://epubs.siam.org/doi/book/10.1137/1.9780898717778). Nicholas J. Higham. SIAM, 2008. 5. [A Distributed Data-Parallel PyTorch Implementation of the Distributed Shampoo Optimizer for Training Neural Networks At-Scale](https://arxiv.org/pdf/2309.06497.pdf). Hao-Jun Michael Shi, Tsung-Hsien Lee, Shintaro Iwasaki, Jose Gallego-Posada, Zhijing Li, Kaushik Rangadurai, Dheevatsa Mudigere, and Michael Rabbat. Tech Report, 2023. +6. [Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis](https://arxiv.org/abs/1806.03884). Thomas George, César Laurent, Xavier Bouthillier, Nicolas Ballas, Pascal Vincent. NeurIPS, 2018. +7. [SOAP: Improving and Stabilizing Shampoo using Adam](https://arxiv.org/abs/2409.11321). Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade. Tech Report, 2024. diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index 6e8f197..b0bf3a1 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -58,9 +58,9 @@ PRECISION_CONFIG, PrecisionConfig, PRECONDITION_FREQUENCY, + PRECONDITIONER_COMPUTATION_CONFIG, PREVIOUS_GRAD_SELECTOR, RMSpropGraftingConfig, - ROOT_INV_CONFIG, SGDGraftingConfig, SHAMPOO_PRECONDITIONER_LIST, ShampooPT2CompileConfig, @@ -90,6 +90,7 @@ from distributed_shampoo.utils.shampoo_preconditioner_list import ( AdagradPreconditionerList, DequantizePreconditionersContext, + EigenvalueCorrectedShampooPreconditionerList, SGDPreconditionerList, ShampooPreconditionerList, ) @@ -100,7 +101,13 @@ ) from distributed_shampoo.utils.shampoo_utils import compress_list -from matrix_functions_types import DefaultEigenConfig, EigenConfig, RootInvConfig +from matrix_functions_types import ( + DefaultEigenConfig, + EigenConfig, + EigenvalueCorrectionConfig, + PreconditionerComputationConfig, + RootInvConfig, +) from torch.optim.optimizer import ParamsT, StateDict logger: logging.Logger = logging.getLogger(__name__) @@ -214,6 +221,16 @@ class DistributedShampoo(torch.optim.Optimizer): particular tensor shape. Recommended to use `static` mode here for Shampoo. More about dynamic shape: https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html + 5. [EXPERIMENTAL] Eigenvalue correction: We can (approximately) correct the eigenvalues of Shampoo's preconditioner by accumulating a running + average of the squared gradient in the eigenbasis of Shampoo's preconditioner. 
This running average (with hyperparameter `betas[1]`) is + updated every iteration while the eigenbasis of Shampoo's preconditioner is only computed every `precondition_frequency` steps. + Alternatively, this can be seen as running Adam in the eigenbasis of Shampoo's preconditioner. + + When setting `preconditioner_computation_config` as an instance of EigenvalueCorrectionConfig, there is typically no need to use learning + rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be + a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. + Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. + + Args: params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups. lr (float): Learning rate. (Default: 1e-2) @@ -228,13 +245,18 @@ class DistributedShampoo(torch.optim.Optimizer): dampening (float): Dampening parameter for momentum. (Default: 0.) weight_decay (float): Weight decay (L2 penalty). (Default: 0.) max_preconditioner_dim (int): Maximum preconditioner dimension. (Default: 1024) - precondition_frequency (int): Frequency for computing root inverse preconditioner. (Default: 1) + precondition_frequency (int): Frequency of updating all components of the preconditioner. + If this field is an instance of RootInvConfig, this is the update frequency of the root inverse of the preconditioner. + If this field is an instance of EigenvalueCorrectionConfig, this is the update frequency of the eigenbasis of Shampoo's preconditioner. + (Default: 1) start_preconditioning_step (int): Iteration to start computing inverse preconditioner. If -1, uses the same value as precondition_frequency. (Default: -1) inv_root_override (int, Sequence[int]): Inverse root to use in Shampoo. If a list [l1, l2, ..., lp], then we will use -1 / l1 for 1-D tensors (vectors), -1 / l2 for 2-D tensors (matrices), and so on. If the order of the tensor exceeds the length of the list, reverts to the default value. If 0 is used, uses the default inverse - root -1 / (2 * o), where o is the order of the tensor. (Default: 0) + root -1 / (2 * o), where o is the order of the tensor. If preconditioner_computation_config is an instance of + EigenvalueCorrectionConfig, the default is -1 / 2. + (Default: 0) exponent_multiplier (float | None): **DEPRECATING** Number to be multiplied to the numerator of the inverse root, i.e., eta where the exponent is -eta / (2 * p). (Default: None) use_nesterov (bool): Flag for using Nesterov momentum. (default: False) @@ -259,7 +281,10 @@ class DistributedShampoo(torch.optim.Optimizer): 3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail. track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes. (Default: False) - root_inv_config (RootInvConfig): Configuration for root inverse computation. (Default: DefaultEigenConfig) + preconditioner_computation_config (PreconditionerComputationConfig): Configuration for preconditioner computation. + If this field is an instance of RootInvConfig, Shampoo uses the root inverse of the preconditioner. + If this field is an instance of EigenvalueCorrectionConfig, Shampoo corrects the eigenvalues / runs Adam in the eigenbasis of the preconditioner. 
+ (Default: DefaultEigenConfig) """ @@ -290,7 +315,7 @@ def __init__( precision_config: Optional[PrecisionConfig] = None, use_protected_eigh: bool = True, track_root_inv_residuals: bool = False, - root_inv_config: RootInvConfig = DefaultEigenConfig, + preconditioner_computation_config: PreconditionerComputationConfig = DefaultEigenConfig, ) -> None: # Hyperparameter checks. if not lr >= 0.0: @@ -404,17 +429,28 @@ def __init__( "Both preconditioner_dtype and precision_config are provided. Please use only precision_config as preconditioner_dtype is deprecated." ) + if ( + not isinstance(preconditioner_computation_config, RootInvConfig) + ) and track_root_inv_residuals: + raise ValueError( + f"{track_root_inv_residuals=} has to be set to False when {preconditioner_computation_config=} is not an instance of RootInvConfig." + ) + # Create default precision config if it is not provided. if precision_config is None: precision_config = PrecisionConfig() # Set exponent multiplier if this is not provided. - if isinstance(root_inv_config, EigenConfig) and exponent_multiplier is not None: + if ( + isinstance(preconditioner_computation_config, EigenConfig) + and exponent_multiplier is not None + ): logger.warning( f"{exponent_multiplier=} is deprecating. Please consider using EigenConfig.exponent_multiplier directly and setting exponent_multipler=None instead in the future." ) - root_inv_config = dataclasses.replace( - root_inv_config, exponent_multiplier=exponent_multiplier + preconditioner_computation_config = dataclasses.replace( + preconditioner_computation_config, + exponent_multiplier=exponent_multiplier, ) super().__init__( @@ -437,7 +473,7 @@ def __init__( GRAFTING_CONFIG: grafting_config, USE_MERGE_DIMS: use_merge_dims, PRECISION_CONFIG: precision_config, - ROOT_INV_CONFIG: root_inv_config, + PRECONDITIONER_COMPUTATION_CONFIG: preconditioner_computation_config, }, ) @@ -508,17 +544,25 @@ def _instantiate_shampoo_preconditioner_list( for state_lists, group in zip( self._per_group_state_lists, self.param_groups, strict=True ): - state_lists[SHAMPOO_PRECONDITIONER_LIST] = ShampooPreconditionerList( + state_lists[SHAMPOO_PRECONDITIONER_LIST] = ( + EigenvalueCorrectedShampooPreconditionerList + if isinstance( + group[PRECONDITIONER_COMPUTATION_CONFIG], EigenvalueCorrectionConfig + ) + else ShampooPreconditionerList + )( block_list=state_lists[DISTRIBUTOR].global_blocked_params, state=self.state, block_info_list=state_lists[DISTRIBUTOR].global_block_info_list, distributor_selector=state_lists[DISTRIBUTOR].distributor_selector, - root_inv_config=group[ROOT_INV_CONFIG], + preconditioner_computation_config=group[ + PRECONDITIONER_COMPUTATION_CONFIG + ], + precision_config=group[PRECISION_CONFIG], beta2=group[BETAS][1], epsilon=group[EPSILON], inv_root_override=group[INV_ROOT_OVERRIDE], use_bias_correction=group[USE_BIAS_CORRECTION], - precision_config=group[PRECISION_CONFIG], use_protected_eigh=use_protected_eigh, ) @@ -755,6 +799,7 @@ def _mask_state_lists(state_lists: Dict[str, Any], group: Dict[str, Any]) -> Non ) @torch.no_grad() + @torch.compiler.disable def _compute_and_log_root_inverse_residuals( self, ) -> None: @@ -806,16 +851,6 @@ def _compute_and_log_root_inverse_residuals( f"{torch.quantile(relative_residuals, quantiles, interpolation='nearest')}" ) - @torch.no_grad() - @torch.compiler.disable - def _compute_root_inverse( - self, state_lists: Dict[str, Any], compute_root_inverse: bool - ) -> None: - if compute_root_inverse: - state_lists[SHAMPOO_PRECONDITIONER_LIST].compute_root_inverse() - 
if self._track_root_inv_residuals: - self._compute_and_log_root_inverse_residuals() - @torch.no_grad() @torch.compiler.disable def _precondition_and_grafting( @@ -881,13 +916,17 @@ def _update_preconditioners( self, state_lists: Dict[str, Any], step: torch.Tensor, + perform_amortized_computation: bool, grafting_config_not_none: bool, ) -> None: - # Update Shampoo and grafting preconditioners / factor matrices. + # Update Shampoo and grafting preconditioners. state_lists[SHAMPOO_PRECONDITIONER_LIST].update_preconditioners( masked_grad_list=state_lists[MASKED_BLOCKED_GRADS], step=step, + perform_amortized_computation=perform_amortized_computation, ) + if perform_amortized_computation and self._track_root_inv_residuals: + self._compute_and_log_root_inverse_residuals() if grafting_config_not_none: state_lists[GRAFTING_PRECONDITIONER_LIST].update_preconditioners( masked_grad_list=state_lists[MASKED_BLOCKED_GRADS], @@ -1005,7 +1044,7 @@ def _per_group_step_impl( momentum_param: float, dampening: float, grafting_config_not_none: bool, - compute_root_inverse: bool, + perform_amortized_computation: bool, use_decoupled_weight_decay: bool, use_bias_correction: bool, use_grafting_method: bool, @@ -1028,23 +1067,23 @@ def _per_group_step_impl( if grafting_config_not_none else contextlib.nullcontext() ): - # Update Shampoo and grafting preconditioners / factor matrices. - # Example for AdaGrad accumulation: + # Update Shampoo and grafting preconditioners. + # Example for AdaGrad accumulation: + # 1. Update factor matrices/grafting preconditioners. # L <- L + G * G^T # R <- R + G^T * G # V <- V + G^2 (element-wise) # (and similar) - self._update_preconditioners( - state_lists, - step, - grafting_config_not_none, - ) - - # Compute matrix root inverse. + # 2. Compute root inverse if necessary. # L_inv <- L ** (-1/4) # R_inv <- R ** (-1/4) - # (and similar) - self._compute_root_inverse(state_lists, compute_root_inverse) + # (and similar); + self._update_preconditioners( + state_lists=state_lists, + step=step, + perform_amortized_computation=perform_amortized_computation, + grafting_config_not_none=grafting_config_not_none, + ) # Compute filtered gradient or EMA of the gradients if beta1 > 0 and beta3 > 0. # Note that we use two beta factors here akin to Lion. @@ -1157,8 +1196,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] momentum_param = group[MOMENTUM] dampening = group[DAMPENING] grafting_config_not_none = group[GRAFTING_CONFIG] is not None - # Check compute root inverse or not for preconditioner - compute_root_inverse = ( + perform_amortized_computation = ( step.item() % group[PRECONDITION_FREQUENCY] == 0 and step.item() > group[START_PRECONDITIONING_STEP] or step.item() == group[START_PRECONDITIONING_STEP] @@ -1182,7 +1220,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] momentum_param, dampening, grafting_config_not_none, - compute_root_inverse, + perform_amortized_computation, use_decoupled_weight_decay, use_bias_correction, use_grafting_method, diff --git a/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py b/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py new file mode 100644 index 0000000..116dbbf --- /dev/null +++ b/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py @@ -0,0 +1,241 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. 
+ +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. + +""" + +#!/usr/bin/env python3 + +import math +import unittest +from functools import partial +from itertools import product +from typing import Any, Callable, Type + +import torch +from distributed_shampoo.distributed_shampoo import DistributedShampoo +from distributed_shampoo.tests.shampoo_test_utils import construct_training_problem +from matrix_functions_types import DefaultEighEigenvalueCorrectionConfig +from torch.nn.parameter import Parameter +from torch.optim.adagrad import Adagrad +from torch.optim.adam import Adam +from torch.optim.adamw import AdamW +from torch.optim.optimizer import ParamsT +from torch.optim.rmsprop import RMSprop + + +# Note: We have to set the epsilon to a very small value (i.e., 1e-15) due to the +# the place epsilon is added in the PyTorch optimizers (i.e., AdaGrad, RMSProp, Adam, AdamW) +# and Distributed Shampoo. +# The PyTorch optimizers add epsilon outside of the square root, and Distributed Shampoo +# adds epsilon inside of the square root. + + +class DistributedShampooEigenvalueCorrectionTest(unittest.TestCase): + @staticmethod + def _train_quadratic( + optim_factory: Callable[ + [ParamsT], + torch.optim.Optimizer, + ], + device: torch.device, + ) -> tuple[Parameter, torch.Tensor]: + model, loss, data, target = construct_training_problem( + model_linear_layers_dims=(10, 1, 1), + device=device, + fill=1.0, + ) + params = model.parameters() + optimizer = optim_factory(params) + for _ in range(5): + optimizer.zero_grad() + objective = loss(model(data), target) + objective.backward() + optimizer.step() + return model.linear_layers[0].weight.data.cpu(), objective.detach().cpu() + + @staticmethod + def _test_baseline_and_shampoo( + baseline_optim_factory: Callable[ + [ParamsT], + torch.optim.Optimizer, + ], + shampoo_optim_factory: Callable[ + [ParamsT], + torch.optim.Optimizer, + ], + device: torch.device, + ) -> None: + ( + baseline_params, + baseline_loss, + ) = DistributedShampooEigenvalueCorrectionTest._train_quadratic( + baseline_optim_factory, + device=device, + ) + shampoo_params, shampoo_loss = ( + DistributedShampooEigenvalueCorrectionTest._train_quadratic( + shampoo_optim_factory, + device=device, + ) + ) + torch.testing.assert_close(shampoo_loss, baseline_loss) + torch.testing.assert_close( + shampoo_params, + baseline_params, + ) + + @staticmethod + def _optim_factory( + parameters: ParamsT, + optim_cls: Type[torch.optim.Optimizer], + **kwargs: Any, + ) -> torch.optim.Optimizer: + return optim_cls(parameters, **kwargs) + + def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: + # Test with and without weight decay, and with CPU or GPU + for weight_decay, device in product( + (0.0, 0.3), + (torch.device("cpu"),) + (torch.device("cuda"),) + if torch.cuda.is_available() + else (), + ): + optim_factory = partial( + DistributedShampooEigenvalueCorrectionTest._optim_factory, + lr=0.01, + weight_decay=weight_decay, + ) + with self.subTest(weight_decay=weight_decay, device=device): + DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo( + baseline_optim_factory=partial( + optim_factory, optim_cls=Adagrad, eps=1e-15 + ), + shampoo_optim_factory=partial( + optim_factory, + optim_cls=DistributedShampoo, + betas=(0.0, 1.0), + epsilon=1e-15, + momentum=0.0, + max_preconditioner_dim=10, + precondition_frequency=1, + start_preconditioning_step=math.inf, + use_decoupled_weight_decay=False, + 
grafting_config=None, + preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + ), + device=device, + ) + + def test_adam_eigenvalue_correction_on_quadratic(self) -> None: + # Test with and without weight decay, and with CPU or GPU + for weight_decay, device in product( + (0.0, 0.3), + (torch.device("cpu"),) + (torch.device("cuda"),) + if torch.cuda.is_available() + else (), + ): + optim_factory = partial( + DistributedShampooEigenvalueCorrectionTest._optim_factory, + lr=0.001, + betas=(0.9, 0.999), + weight_decay=weight_decay, + ) + with self.subTest(weight_decay=weight_decay, device=device): + DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo( + baseline_optim_factory=partial( + optim_factory, + optim_cls=Adam, + eps=1e-15, + ), + shampoo_optim_factory=partial( + optim_factory, + optim_cls=DistributedShampoo, + epsilon=1e-15, + momentum=0.0, + max_preconditioner_dim=10, + precondition_frequency=1, + start_preconditioning_step=math.inf, + use_decoupled_weight_decay=False, + grafting_config=None, + preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + ), + device=device, + ) + + def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: + # Test with and without weight decay, and with CPU or GPU + for weight_decay, device in product( + (0.0, 0.3), + (torch.device("cpu"),) + (torch.device("cuda"),) + if torch.cuda.is_available() + else (), + ): + optim_factory = partial( + DistributedShampooEigenvalueCorrectionTest._optim_factory, + lr=0.001, + betas=(0.9, 0.999), + weight_decay=weight_decay, + ) + with self.subTest(weight_decay=weight_decay, device=device): + DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo( + baseline_optim_factory=partial( + optim_factory, + optim_cls=AdamW, + eps=1e-15, + ), + shampoo_optim_factory=partial( + optim_factory, + optim_cls=DistributedShampoo, + epsilon=1e-15, + momentum=0.0, + max_preconditioner_dim=10, + precondition_frequency=1, + start_preconditioning_step=math.inf, + use_decoupled_weight_decay=True, + grafting_config=None, + preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + ), + device=device, + ) + + def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None: + # Test with and without weight decay, and with CPU or GPU + for weight_decay, device in product( + (0.0, 0.3), + (torch.device("cpu"),) + (torch.device("cuda"),) + if torch.cuda.is_available() + else (), + ): + optim_factory = partial( + DistributedShampooEigenvalueCorrectionTest._optim_factory, + lr=0.01, + weight_decay=weight_decay, + ) + with self.subTest(weight_decay=weight_decay, device=device): + DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo( + baseline_optim_factory=partial( + optim_factory, + optim_cls=RMSprop, + alpha=0.99, + eps=1e-15, + ), + shampoo_optim_factory=partial( + optim_factory, + optim_cls=DistributedShampoo, + betas=(0.0, 0.99), + epsilon=1e-15, + momentum=0.0, + max_preconditioner_dim=10, + precondition_frequency=1, + start_preconditioning_step=math.inf, + use_decoupled_weight_decay=False, + grafting_config=None, + use_bias_correction=False, + preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + ), + device=device, + ) diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index da890d7..60af2f4 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -34,8 +34,9 @@ PRECISION_CONFIG = "precision_config" PRECONDITION_FREQUENCY = 
"precondition_frequency" PRECONDITIONER_DTYPE = "preconditioner_dtype" -ROOT_INV_CONFIG = "root_inv_config" +PRECONDITIONER_COMPUTATION_CONFIG = "preconditioner_computation_config" START_PRECONDITIONING_STEP = "start_preconditioning_step" +USE_EIGENVALUE_CORRECTION = "use_eigenvalue_correction" USE_BIAS_CORRECTION = "use_bias_correction" USE_DECOUPLED_WEIGHT_DECAY = "use_decoupled_weight_decay" USE_MERGE_DIMS = "use_merge_dims" @@ -102,6 +103,8 @@ class PrecisionConfig: factor_matrix_dtype (torch.dtype): Data type for storing Shampoo factor matrices. (Default: torch.float32) inv_factor_matrix_dtype (torch.dtype): Data type for storing Shampoo inverse factor matrices. (Default: torch.float32) factor_matrix_computation_dtype (torch.dtype): Data type for accumulating factor matrices and computing their inverses. (Default: torch.float32) + corrected_eigenvalues_dtype (torch.dtype): Data type for storing the corrected eigenvalues of Shampoo preconditioner (EMA). (Default: torch.float32) + factor_matrix_eigenvectors_dtype (torch.dtype): Data type for storing the eigenvectors of Shampoo factor matrices. (Default: torch.float32) filtered_grad_dtype (torch.dtype): Data type for storing filtered gradients (EMA). (Default: torch.float32) momentum_dtype (torch.dtype): Data type for storing momentum states. (Default: torch.float32) grafting_state_dtype (torch.dtype): Data type for storing grafting preconditioners, if applicable. (Default: torch.float32) @@ -117,6 +120,8 @@ class PrecisionConfig: computation_dtype: torch.dtype = torch.float32 factor_matrix_dtype: torch.dtype = torch.float32 inv_factor_matrix_dtype: torch.dtype = torch.float32 + corrected_eigenvalues_dtype: torch.dtype = torch.float32 + factor_matrix_eigenvectors_dtype: torch.dtype = torch.float32 factor_matrix_computation_dtype: torch.dtype = torch.float32 filtered_grad_dtype: torch.dtype = torch.float32 momentum_dtype: torch.dtype = torch.float32 diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index 5c29dd8..942de40 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -36,7 +36,10 @@ ShampooPreconditionerList, ) from distributed_shampoo.utils.shampoo_quantization import QuantizedTensorList -from matrix_functions_types import DefaultEigenConfig +from matrix_functions_types import ( + DefaultEigenConfig, + DefaultEighEigenvalueCorrectionConfig, +) from torch import nn @@ -238,7 +241,7 @@ def test_setting_exponent_multiplier_with_eigen_config(self) -> None: lr=0.01, start_preconditioning_step=1, exponent_multiplier=2.0, - root_inv_config=DefaultEigenConfig, + preconditioner_computation_config=DefaultEigenConfig, ) self.assertCountEqual( [r.msg for r in cm.records], @@ -247,6 +250,21 @@ def test_setting_exponent_multiplier_with_eigen_config(self) -> None: ], ) + def test_conflict_eigenvalue_correction_and_track_root_inv_residuals(self) -> None: + with self.assertRaisesRegex( + ValueError, + re.escape( + "track_root_inv_residuals=True has to be set to False when preconditioner_computation_config=EighEigenvalueCorrectionConfig(retry_double_precision=True) is not an instance of RootInvConfig." 
+ ), + ): + DistributedShampoo( + self._model.parameters(), + lr=0.01, + start_preconditioning_step=1, + track_root_inv_residuals=True, + preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + ) + class DistributedShampooTest(unittest.TestCase): def setUp(self) -> None: @@ -460,7 +478,7 @@ def setUp(self) -> None: ), "use_merge_dims": True, "precision_config": PrecisionConfig(), - "root_inv_config": DefaultEigenConfig, + "preconditioner_computation_config": DefaultEigenConfig, } }, } @@ -835,6 +853,108 @@ def test_setting_preconditioner_dtype_only(self) -> None: ) +class EigenvalueCorrectedDistributedShampooPrecisionTest( + DistributedShampooPrecisionTest +): + def _instantiate_optimizer( + self, precision_config: PrecisionConfig + ) -> DistributedShampoo: + return DistributedShampoo( + self._model.parameters(), + lr=0.01, + betas=(0.9, 1.0), + epsilon=1e-12, + momentum=0.99, + weight_decay=0.0, + max_preconditioner_dim=5, + precondition_frequency=1, + start_preconditioning_step=1, + distributed_config=None, + grafting_config=None, + precision_config=precision_config, + preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + ) + + def _assert_state_list_dtype( + self, state_list: Dict[str, Any], precision_config: PrecisionConfig + ) -> None: + # TODO: is it possible to avoid accessing private field _masked_kronecker_factors_list? + for kronecker_factor in state_list[ + SHAMPOO_PRECONDITIONER_LIST + ]._masked_kronecker_factors_list: + self._assert_equal_state_dtype( + kronecker_factor.factor_matrices, + precision_config.factor_matrix_computation_dtype, + precision_config.factor_matrix_dtype, + ) + self._assert_equal_state_dtype( + kronecker_factor.factor_matrices_eigenvectors, + precision_config.computation_dtype, + precision_config.factor_matrix_eigenvectors_dtype, + ) + self._assert_equal_state_dtype( + kronecker_factor.corrected_eigenvalues, + precision_config.computation_dtype, + precision_config.corrected_eigenvalues_dtype, + ) + self._assert_equal_state_dtype( + state_list[MASKED_FILTERED_GRAD_LIST], + precision_config.computation_dtype, + precision_config.filtered_grad_dtype, + ) + self._assert_equal_state_dtype( + state_list[MASKED_MOMENTUM_LIST], + precision_config.computation_dtype, + precision_config.momentum_dtype, + ) + + def test_precision_configs(self) -> None: + precision_configs = [ + PrecisionConfig(computation_dtype=torch.float16), + PrecisionConfig(factor_matrix_dtype=torch.float16), + PrecisionConfig(factor_matrix_eigenvectors_dtype=torch.float16), + PrecisionConfig(corrected_eigenvalues_dtype=torch.float16), + PrecisionConfig(filtered_grad_dtype=torch.float16), + PrecisionConfig(momentum_dtype=torch.float16), + PrecisionConfig( + factor_matrix_dtype=torch.float16, + factor_matrix_eigenvectors_dtype=torch.float16, + ), + PrecisionConfig( + factor_matrix_dtype=torch.float16, + factor_matrix_eigenvectors_dtype=torch.float16, + corrected_eigenvalues_dtype=torch.float16, + ), + PrecisionConfig( + factor_matrix_dtype=torch.float16, + factor_matrix_eigenvectors_dtype=torch.float16, + corrected_eigenvalues_dtype=torch.float16, + filtered_grad_dtype=torch.float16, + momentum_dtype=torch.float16, + ), + PrecisionConfig(factor_matrix_computation_dtype=torch.float64), + PrecisionConfig( + factor_matrix_dtype=torch.float64, + factor_matrix_eigenvectors_dtype=torch.float16, + corrected_eigenvalues_dtype=torch.float64, + factor_matrix_computation_dtype=torch.float64, + ), + ] + + for precision_config in precision_configs: + with 
self.subTest(precision_config=precision_config): + optimizer = self._instantiate_optimizer( + precision_config=precision_config + ) + for state_list in optimizer._per_group_state_lists: + self._assert_state_list_dtype(state_list, precision_config) + + for _ in range(2): + optimizer.step() + for state_list in optimizer._per_group_state_lists: + self._assert_state_list_dtype(state_list, precision_config) + + class DistributedShampooNoneGradTest(unittest.TestCase): def setUp(self) -> None: self._model = nn.Sequential( diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index baef780..7a8fed3 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -9,12 +9,14 @@ import logging from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass, field from fractions import Fraction +from functools import partial, reduce from itertools import chain from operator import methodcaller -from typing import Any, DefaultDict, Sequence, Tuple, Union +from typing import Any, cast, DefaultDict, Generic, Sequence, TypeVar import torch from distributed_shampoo.shampoo_types import PrecisionConfig, PreconditionerValueError @@ -32,10 +34,16 @@ from matrix_functions import ( check_diagonal, compute_matrix_root_inverse_residuals, + matrix_eigenvectors, matrix_inverse_root, ) -from matrix_functions_types import DefaultEigenConfig, RootInvConfig +from matrix_functions_types import ( + DefaultEigenConfig, + EigenvalueCorrectionConfig, + PreconditionerComputationConfig, + RootInvConfig, +) from optimizer_modules import OptimizerModule from torch import Tensor from torch.autograd import profiler @@ -51,36 +59,37 @@ class PreconditionerList(ABC): """Preconditioner base class. Args: - block_list (Tuple[Tensor, ...]): List of (blocks of) parameters. + block_list (tuple[Tensor, ...]): List of (blocks of) parameters. """ def __init__( self, - block_list: Tuple[Tensor, ...], + block_list: tuple[Tensor, ...], ) -> None: super().__init__() - self._numel_list: Tuple[int, ...] = (0,) * len(block_list) - self._dims_list: Tuple[torch.Size, ...] = tuple( + self._numel_list: tuple[int, ...] = (0,) * len(block_list) + self._dims_list: tuple[torch.Size, ...] = tuple( block.size() for block in block_list ) - self._num_bytes_list: Tuple[int, ...] = (0,) * len(block_list) + self._num_bytes_list: tuple[int, ...] = (0,) * len(block_list) @abstractmethod def update_preconditioners( self, - masked_grad_list: Tuple[Tensor, ...], + masked_grad_list: tuple[Tensor, ...], step: Tensor, + perform_amortized_computation: bool, ) -> None: ... @abstractmethod def precondition( - self, masked_grad_list: Tuple[Tensor, ...] - ) -> Tuple[Tensor, ...]: ... + self, masked_grad_list: tuple[Tensor, ...] + ) -> tuple[Tensor, ...]: ... @abstractmethod def compress_preconditioner_list( - self, local_grad_selector: Tuple[bool, ...] + self, local_grad_selector: tuple[bool, ...] ) -> None: ... @abstractmethod @@ -90,15 +99,15 @@ def dequantize_preconditioners(self) -> None: ... def quantize_preconditioners(self) -> None: ... 
@property - def numel_list(self) -> Tuple[int, ...]: + def numel_list(self) -> tuple[int, ...]: return self._numel_list @property - def dims_list(self) -> Tuple[torch.Size, ...]: + def dims_list(self) -> tuple[torch.Size, ...]: return self._dims_list @property - def num_bytes_list(self) -> Tuple[int, ...]: + def num_bytes_list(self) -> tuple[int, ...]: return self._num_bytes_list def numel(self) -> int: @@ -112,28 +121,29 @@ class SGDPreconditionerList(PreconditionerList): """SGD (identity) preconditioners for a list of parameters. Args: - block_list (Tuple[Tensor, ...]): List of (blocks of) parameters. + block_list (tuple[Tensor, ...]): List of (blocks of) parameters. """ def __init__( self, - block_list: Tuple[Tensor, ...], + block_list: tuple[Tensor, ...], ) -> None: super().__init__(block_list) def update_preconditioners( self, - masked_grad_list: Tuple[Tensor, ...], + masked_grad_list: tuple[Tensor, ...], step: Tensor, + perform_amortized_computation: bool = False, ) -> None: return - def precondition(self, masked_grad_list: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: + def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, ...]: return masked_grad_list def compress_preconditioner_list( - self, local_grad_selector: Tuple[bool, ...] + self, local_grad_selector: tuple[bool, ...] ) -> None: return @@ -158,11 +168,11 @@ class AdagradPreconditionerList(PreconditionerList): Other variants can also be specified. Args: - block_list (Tuple[Tensor, ...]): List of (blocks of) parameters. + block_list (tuple[Tensor, ...]): List of (blocks of) parameters. state (DefaultDict[Tensor, Any]): Dictionary containing optimizer state. - block_info_list (Tuple[BlockInfo, ...]): List containing corresponding BlockInfo for each block/parameter in block_list. + block_info_list (tuple[BlockInfo, ...]): List containing corresponding BlockInfo for each block/parameter in block_list. Note that this should have the same length as block_list. - distributor_selector (Tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter + distributor_selector (tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter is selected by the current Distributor. beta2 (float): Exponential moving average factor for Adam/RMSprop second moment state. If beta2 = 1., will use unweighted sum. (Default: 1.0) @@ -174,11 +184,11 @@ class AdagradPreconditionerList(PreconditionerList): def __init__( self, - block_list: Tuple[Tensor, ...], + block_list: tuple[Tensor, ...], # type: ignore state: DefaultDict[Tensor, Any], - block_info_list: Tuple[BlockInfo, ...], - distributor_selector: Tuple[bool, ...], + block_info_list: tuple[BlockInfo, ...], + distributor_selector: tuple[bool, ...], precision_config: PrecisionConfig, beta2: float = 1.0, epsilon: float = 1e-10, @@ -208,7 +218,9 @@ def __init__( preconditioner_index = str(param_index) + "." + str(block_index) block_state[ADAGRAD] = QuantizedTensor( quantized_values=block_info.allocate_zeros_tensor( - block.size(), precision_config.grafting_state_dtype, block.device + shape=block.size(), + dtype=precision_config.grafting_state_dtype, + device=block.device, ), block_info=block_info, ) @@ -230,22 +242,23 @@ def __init__( ) # Construct lists of dims, bytes, and numels for logging purposes. - self._dims_list: Tuple[torch.Size, ...] = compress_list( + self._dims_list: tuple[torch.Size, ...] = compress_list( self._dims_list, distributor_selector ) - self._numel_list: Tuple[int, ...] 
= tuple( + self._numel_list: tuple[int, ...] = tuple( quantized_preconditioner.numel() for quantized_preconditioner in self._local_preconditioner_list.quantized_value ) - self._num_bytes_list: Tuple[int, ...] = tuple( + self._num_bytes_list: tuple[int, ...] = tuple( quantize_preconditioner.numel() * quantize_preconditioner.element_size() for quantize_preconditioner in self._local_preconditioner_list.quantized_value ) def update_preconditioners( self, - masked_grad_list: Tuple[Tensor, ...], + masked_grad_list: tuple[Tensor, ...], step: Tensor, + perform_amortized_computation: bool = False, ) -> None: with profiler.record_function( f"## {self.__class__.__name__}:{self.update_preconditioners.__name__} ##" @@ -272,7 +285,16 @@ def update_preconditioners( if self._use_bias_correction and self._beta2 < 1.0: self._bias_correction2 = torch.tensor(1.0) - self._beta2**step - def precondition(self, masked_grad_list: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: + def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, ...]: + """ + Preconditions the gradient list using the AdaGrad preconditioner. + + Args: + masked_grad_list (tuple[Tensor, ...]): A tuple of gradients with None values removed. + + Returns: + tuple[Tensor, ...]: A tuple of preconditioned gradients. + """ with profiler.record_function( f"## {self.__class__.__name__}:{self.precondition.__name__} ##" ): @@ -301,7 +323,7 @@ def quantize_preconditioners(self) -> None: self._masked_preconditioner_list.quantize_() def compress_preconditioner_list( - self, local_grad_selector: Tuple[bool, ...] + self, local_grad_selector: tuple[bool, ...] ) -> None: with profiler.record_function( f"## {self.__class__.__name__}:{self.compress_preconditioner_list.__name__} ##" @@ -311,70 +333,110 @@ def compress_preconditioner_list( ) +FactorMatricesType = TypeVar( + "FactorMatricesType", tuple[QuantizedTensor, ...], QuantizedTensorList +) + + @dataclass -class ShampooKroneckerFactorsState(OptimizerModule): - """Shampoo Kronecker Factors (wrapped) for storing in the optimizer state.""" +class BaseShampooKroneckerFactors(Generic[FactorMatricesType], OptimizerModule): + """Base class for Shampoo Kronecker factors.""" - factor_matrices: Tuple[QuantizedTensor, ...] - inv_factor_matrices: Tuple[QuantizedTensor, ...] - factor_matrix_indices: Tuple[str, ...] - is_factor_matrices_diagonal: Tuple[Tensor, ...] = field(init=False) + factor_matrices: FactorMatricesType + factor_matrix_indices: tuple[str, ...] + is_factor_matrices_diagonal: tuple[Tensor, ...] = field(init=False) def __post_init__(self) -> None: super().__init__() - assert ( - len(self.factor_matrices) - == len(self.inv_factor_matrices) - == len(self.factor_matrix_indices) - ) + assert len(self.factor_matrices) == len(self.factor_matrix_indices) self.is_factor_matrices_diagonal = tuple( torch.tensor(True) for _ in range(len(self.factor_matrices)) ) @dataclass -class ShampooKroneckerFactorsList(OptimizerModule): - """Shampoo Kronecker Factors (unwrapped) for operations during optimizer computation.""" +class ShampooKroneckerFactorsState( + BaseShampooKroneckerFactors[tuple[QuantizedTensor, ...]] +): + """Shampoo Kronecker factors (wrapped) for storing in the optimizer state.""" + + inv_factor_matrices: tuple[QuantizedTensor, ...] 
+ + def __post_init__(self) -> None: + super().__post_init__() + assert len(self.factor_matrices) == len(self.inv_factor_matrices) + + +@dataclass +class ShampooKroneckerFactorsList(BaseShampooKroneckerFactors[QuantizedTensorList]): + """Shampoo Kronecker factors (unwrapped) for operations during optimizer computation.""" - factor_matrices: QuantizedTensorList inv_factor_matrices: QuantizedTensorList - factor_matrix_indices: Tuple[str, ...] - is_factor_matrices_diagonal: Tuple[Tensor, ...] = field(init=False) def __post_init__(self) -> None: - super().__init__() - assert ( - len(self.factor_matrices) - == len(self.inv_factor_matrices) - == len(self.factor_matrix_indices) - ) - self.is_factor_matrices_diagonal = tuple( - torch.tensor(True) for _ in range(len(self.factor_matrices)) - ) + super().__post_init__() + assert len(self.factor_matrices) == len(self.inv_factor_matrices) -class ShampooPreconditionerList(PreconditionerList): - """Shampoo preconditioners for list of parameters. +@dataclass +class EigenvalueCorrectedShampooKroneckerFactorsState( + BaseShampooKroneckerFactors[tuple[QuantizedTensor, ...]] +): + """Eigenvalue-corrected Shampoo Kronecker factors (wrapped) for storing in the optimizer state.""" + + factor_matrices_eigenvectors: tuple[QuantizedTensor, ...] + corrected_eigenvalues: QuantizedTensor + + def __post_init__(self) -> None: + super().__post_init__() + assert len(self.factor_matrices) == len(self.factor_matrices_eigenvectors) + + +@dataclass +class EigenvalueCorrectedShampooKroneckerFactorsList( + BaseShampooKroneckerFactors[QuantizedTensorList] +): + """Eigenvalue-corrected Shampoo Kronecker factors (unwrapped) for operations during optimizer computation.""" + + factor_matrices_eigenvectors: QuantizedTensorList + corrected_eigenvalues: QuantizedTensorList + + def __post_init__(self) -> None: + super().__post_init__() + assert len(self.factor_matrices) == len(self.factor_matrices_eigenvectors) + + +ShampooKroneckerFactorsListType = TypeVar( + "ShampooKroneckerFactorsListType", + ShampooKroneckerFactorsList, + EigenvalueCorrectedShampooKroneckerFactorsList, +) + + +class BaseShampooPreconditionerList( + PreconditionerList, Generic[ShampooKroneckerFactorsListType] +): + """Base class for Shampoo preconditioners. NOTE: Does not support sparse gradients at this time. Args: - block_list (Tuple[Tensor, ...]): List of (blocks of) parameters. + block_list (tuple[Tensor, ...]): List of (blocks of) parameters. state (DefaultDict[Tensor, Any]): Dictionary containing optimizer state. - block_info_list (Tuple[BlockInfo, ...]): List containing corresponding BlockInfo for each block/parameter in block_list. + block_info_list (tuple[BlockInfo, ...]): List containing corresponding BlockInfo for each block/parameter in block_list. Note that this should have the same length as block_list. - distributor_selector (Tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter + distributor_selector (tuple[bool, ...]): Distributor selector is a boolean list indicating whether a blocked parameter is selected by the current Distributor. - root_inv_config (RootInvConfig): Configuration for root inverse computation. (Default: DefaultEigenConfig) + precision_config (PrecisionConfig): Data types for optimizer states. (Default: all fields torch.float) + preconditioner_computation_config (PreconditionerComputationConfig): Configuration for preconditioner computation. 
(Default: DefaultEigenConfig) beta2 (float): Exponential moving average factor for Shampoo factor matrices. If beta2 = 1., will use unweighted sum. (Default: 1.0) epsilon (float): Epsilon term for regularizing preconditioner to ensure positive definiteness. (Default: 1e-12) - inv_root_override (Union[int, Tuple[int, ...]]): Inverse root to use in Shampoo. If a list [l0, l1, l2, ..., lp], then we will + inv_root_override (int | tuple[int, ...]): Inverse root to use in Shampoo. If a list [l0, l1, l2, ..., lp], then we will use -1 / l0 for 0-D tensors (scalars), -1 / l1 for 1-D tensor (vectors), -1 / l2 for 2-D tensors (matrices), and so on. If the order of the tensor exceeds the length of the list, we revert to using the default value. If 0 is used, uses the default inverse root -1 / (2 * o), where o is the order of the tensor. (Default: 0) use_bias_correction (bool): Flag for using bias correction. (Default: True) - precision_config (PrecisionConfig): Data types for optimizer states. (Default: all fields torch.float) use_protected_eigh (bool): Flag for using two guards to prevent failures of torch.linalg.eigh. (Default: True) 1. Attempts to compute root inverse in preconditioner_dtype precision. 2. Attempts to recompute the eigendecomposition if using lower-precision fails. @@ -384,16 +446,16 @@ class ShampooPreconditionerList(PreconditionerList): def __init__( self, - block_list: Tuple[Tensor, ...], + block_list: tuple[Tensor, ...], # type: ignore state: DefaultDict[Tensor, Any], - block_info_list: Tuple[BlockInfo, ...], - distributor_selector: Tuple[bool, ...], + block_info_list: tuple[BlockInfo, ...], + distributor_selector: tuple[bool, ...], precision_config: PrecisionConfig, - root_inv_config: RootInvConfig = DefaultEigenConfig, + preconditioner_computation_config: PreconditionerComputationConfig = DefaultEigenConfig, beta2: float = 1.0, epsilon: float = 1e-12, - inv_root_override: Union[int, Tuple[int, ...]] = 0, + inv_root_override: int | tuple[int, ...] = 0, use_bias_correction: bool = True, use_protected_eigh: bool = True, ) -> None: @@ -401,7 +463,7 @@ def __init__( # Initialize parameters. self._precision_config = precision_config - self._root_inv_config = root_inv_config + self._preconditioner_computation_config = preconditioner_computation_config self._beta2 = beta2 self._epsilon = epsilon self._inv_root_override = inv_root_override @@ -409,6 +471,103 @@ def __init__( self._use_protected_eigh = use_protected_eigh self._bias_correction2: Tensor = torch.tensor(1.0) + # Create the Kronecker factors. + kronecker_factors_list: list[ShampooKroneckerFactorsListType] = ( + self._create_kronecker_factors_state( + block_list=block_list, + state=state, + block_info_list=block_info_list, + ) + ) + + # Initialize state lists. + self._initialize_state_lists( + block_list=block_list, + kronecker_factors_list=kronecker_factors_list, + distributor_selector=distributor_selector, + ) + + def _create_base_kronecker_factors( + self, + block_info: BlockInfo, + dims: torch.Size, + ) -> BaseShampooKroneckerFactors[tuple[QuantizedTensor, ...]]: + """ + Creates a BaseShampooKroneckerFactor object for a given block. + + Args: + block_info (BlockInfo): The BlockInfo object containing information about the block. + dims (torch.Size): The dimensions of the block. + + Returns: + kronecker_factors_state (BaseShampooKroneckerFactors[tuple[QuantizedTensor, ...]]): An object containing the Kronecker factors for the block. 
+ """ + factor_matrices = tuple( + QuantizedTensor( + quantized_values=block_info.allocate_zeros_tensor( + shape=(dim, dim), + dtype=self._precision_config.factor_matrix_dtype, + device=block_info.param.device, + ), + block_info=block_info, + ) + for dim in dims + ) + + param_index, block_index = block_info.composable_block_ids + factor_matrix_indices = tuple( + ".".join((str(param_index), str(block_index), str(k))) + for k in range(len(dims)) + ) + return BaseShampooKroneckerFactors( + factor_matrices=factor_matrices, + factor_matrix_indices=factor_matrix_indices, + ) + + @abstractmethod + def _create_kronecker_factors_state_for_block( + self, + block: Tensor, + block_info: BlockInfo, + dims: torch.Size, + ) -> ShampooKroneckerFactorsState | EigenvalueCorrectedShampooKroneckerFactorsState: + """ + Creates a Kronecker factors state object for a given block. + + Args: + block (Tensor): The block of the parameter. + block_info (BlockInfo): The BlockInfo object containing information about the block. + dims (torch.Size): The dimensions of the block. + + Returns: + kronecker_factors_state (ShampooKroneckerFactorsState | EigenvalueCorrectedShampooKroneckerFactorsState): An object containing the Kronecker factors for the block. + """ + ... + + @abstractmethod + def _create_kronecker_factors_list( + self, + kronecker_factors_state: ShampooKroneckerFactorsState + | EigenvalueCorrectedShampooKroneckerFactorsState, + ) -> ShampooKroneckerFactorsListType: + """ + Creates a ShampooKroneckerFactorsList object from the given ShampooKroneckerFactorsState. + + Args: + kronecker_factors_state (ShampooKroneckerFactorsState): The state containing the Kronecker factors. + + Returns: + kronecker_factors_list: A list of ShampooKroneckerFactors objects. + """ + ... + + def _create_kronecker_factors_state( + self, + block_list: tuple[Tensor, ...], + # type: ignore + state: DefaultDict[Tensor, Any], + block_info_list: tuple[BlockInfo, ...], + ) -> list[ShampooKroneckerFactorsListType]: # Instantiate (blocked) Kronecker factors and construct list of Kronecker factors. # NOTE: We need to instantiate the Kronecker factor states within the optimizer's state dictionary, # and do not explicitly store them as ShampooPreconditionerList attributes here. @@ -423,104 +582,46 @@ def __init__( state[block_info.param][block_index] = {} block_state = state[block_info.param][block_index] - # Instantiate ShampooKroneckerFactors for this block. - factor_matrices = tuple( - QuantizedTensor( - quantized_values=block_info.allocate_zeros_tensor( - (dim, dim), - self._precision_config.factor_matrix_dtype, - block_info.param.device, - ), - block_info=block_info, - ) - for dim in dims - ) - inv_factor_matrices = tuple( - QuantizedTensor( - quantized_values=block_info.allocate_zeros_tensor( - (dim, dim), - self._precision_config.inv_factor_matrix_dtype, - block_info.param.device, - ), - block_info=block_info, - ) - for dim in dims + block_state[SHAMPOO] = self._create_kronecker_factors_state_for_block( + block=block, block_info=block_info, dims=dims ) - preconditioner_index = str(param_index) + "." + str(block_index) - factor_matrix_indices = tuple( - preconditioner_index + "." 
+ str(k) for k in range(len(dims)) - ) - block_state[SHAMPOO] = ShampooKroneckerFactorsState( - factor_matrices=factor_matrices, - inv_factor_matrices=inv_factor_matrices, - factor_matrix_indices=factor_matrix_indices, - ) kronecker_factors_list.append( - ShampooKroneckerFactorsList( - # Factor matrices computation (accumulation, root inverse) should use the determined dtype. - factor_matrices=QuantizedTensorList( - quantized_data=factor_matrices, - quantized_dtype=self._precision_config.factor_matrix_dtype, - computation_dtype=self._precision_config.factor_matrix_computation_dtype, - ), - # Inverse factor matrices computation (preconditioning) should use the dtype of the block / gradient. - inv_factor_matrices=QuantizedTensorList( - quantized_data=inv_factor_matrices, - quantized_dtype=self._precision_config.inv_factor_matrix_dtype, - computation_dtype=self._precision_config.computation_dtype, - ), - factor_matrix_indices=factor_matrix_indices, - ) + self._create_kronecker_factors_list(block_state[SHAMPOO]) ) logger.info( - f"Instantiated Shampoo Preconditioner {preconditioner_index} " + f"Instantiated Shampoo Preconditioner {str(param_index) + '.' + str(block_index)} " f"({[(factor_matrix.quantized_values.shape, factor_matrix.quantized_values.dtype) for factor_matrix in block_state[SHAMPOO].factor_matrices]}) " f"for Parameter {param_index} ({block_info.param.shape}), Block {block_index} ({block.shape})." ) - # Initialize local lists. - local_block_list = compress_list(block_list, distributor_selector) - self._local_kronecker_factors_list: Tuple[ShampooKroneckerFactorsList, ...] = ( - compress_list(kronecker_factors_list, distributor_selector) - ) - self._local_order_list: Tuple[int, ...] = tuple( - block.dim() for block in local_block_list - ) - self._local_root_list: Tuple[int, ...] = self._get_inverse_roots_from_override( - self._inv_root_override, self._local_order_list - ) + return kronecker_factors_list - # Masked lists are the list of active preconditioners or values after filtering out gradients with None. - self._masked_order_list: Tuple[int, ...] = self._local_order_list - self._masked_root_list: Tuple[int, ...] = self._local_root_list - self._masked_kronecker_factors_list: Tuple[ShampooKroneckerFactorsList, ...] = ( - self._local_kronecker_factors_list - ) + @abstractmethod + def _get_inverse_roots_from_override( + self, + inv_root_override: int | Sequence[int], + order_list: tuple[int, ...], + ) -> tuple[int, ...]: + """ + Retrieves the inverse roots from the override parameter. - # Construct lists of bytes and numels for logging purposes. - # NOTE: These lists are constructed across all blocked parameters. - self._dims_list: Tuple[torch.Size, ...] = compress_list( - self._dims_list, distributor_selector - ) - self._numel_list: Tuple[int, ...] = tuple( - sum(2 * dim**2 for dim in dims) for dims in self._dims_list - ) - self._num_bytes_list: Tuple[int, ...] = tuple( - numel - * ( - get_dtype_size(self._precision_config.factor_matrix_dtype) - + get_dtype_size(block.dtype) - ) - // 2 - for numel, block in zip(self._numel_list, local_block_list, strict=True) - ) + Args: + inv_root_override (int | Sequence[int]): The override value for the inverse root. + order_list (tuple[int, ...]): A list of orders for each tensor in the preconditioner. + + Returns: + tuple[int, ...]: A list of inverse roots for each tensor in the preconditioner. + """ + ... @staticmethod - def _get_inverse_roots_from_override( - inv_root_override: Union[int, Sequence[int]], order_list: Tuple[int, ...] 
- ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: """Retrieves the appropriate root from the inverse root override parameter for a list of tensor orders. @@ -529,20 +630,21 @@ def _get_inverse_roots_from_override( If order = 1, then we will return 1; If order = 2, then we will return 4; If order = 3, then we will return 3; - If order > 3, then we will return 2 * order. + If order > 3, then we will return high_order_default(order). Args: - inv_root_override (int, Sequence[int]): Inverse root override int or list. - order_list (Tuple[int, ...]): List of orders for their corresponding tensors. + inv_root_override (int | Sequence[int]): Inverse root override int or list. + order_list (tuple[int, ...]): List of orders for their corresponding tensors. + high_order_default (Callable[[int], int]): Function for computing the inverse root for orders greater than the length of the inverse root override list. Returns: - root_list (Tuple[int, ...]): Inverse roots to use in Shampoo for a list of tensors. + root_list (tuple[int, ...]): Inverse roots to use in Shampoo for a list of tensors. """ if isinstance(inv_root_override, Sequence): return tuple( ( - 2 * order + high_order_default(order) if order >= len(inv_root_override) else inv_root_override[order] ) @@ -550,16 +652,151 @@ def _get_inverse_roots_from_override( ) else: return ( - tuple(2 * order for order in order_list) + tuple(high_order_default(order) for order in order_list) if inv_root_override == 0 else (inv_root_override,) * len(order_list) ) + @abstractmethod + def _amortized_computation(self) -> None: + """ + Performs the amortized computation needed by each Shampoo implementation. + This amortized computation is computationally heavy work that cannot be done at each step. + As a result, each Shampoo implementation implements this method according to its own needs. + """ + ... + + @staticmethod + def _check_factor_matrix_for_diagonality_nan_and_inf( + factor_matrix: Tensor, + is_factor_matrix_diagonal: Tensor, + factor_matrix_index: str, + ) -> None: + # For tracking diagonality of the factor matrix. + # Checks if the factor matrix is currently diagonal, then checks whether or not + # the update factor matrix is diagonal. + if is_factor_matrix_diagonal and not check_diagonal(factor_matrix): + is_factor_matrix_diagonal.copy_(torch.tensor(False)) + logger.debug(f"Factor matrix {factor_matrix_index} is not diagonal.") + + # Check for nan or inf values. + if torch.isnan(factor_matrix).any(): + raise PreconditionerValueError( + f"Encountered nan values in factor matrix {factor_matrix_index}! " + f"To mitigate, check if nan inputs are being passed into the network or nan gradients " + f"are being passed to the optimizer. " + f"For debugging purposes, factor_matrix {factor_matrix_index}: " + f"{torch.min(factor_matrix)=}, {torch.max(factor_matrix)=}, " + f"{factor_matrix.isinf().any()=}, {factor_matrix.isnan().any()=}." + ) + if torch.isinf(factor_matrix).any(): + raise PreconditionerValueError( + f"Encountered inf values in factor matrix {factor_matrix_index}! " + f"In some cases, this may be due to divergence of the algorithm. " + f"To mitigate, try decreasing the learning rate or increasing grafting epsilon. " 
+ f"For debugging purposes, factor_matrix {factor_matrix_index}: " + f"{torch.min(factor_matrix)=}, {torch.max(factor_matrix)=}, " + f"{factor_matrix.isinf().any()=}, {factor_matrix.isnan().any()=}." + ) + def update_preconditioners( - self, masked_grad_list: Tuple[Tensor, ...], step: Tensor + self, + masked_grad_list: tuple[Tensor, ...], + step: Tensor, + perform_amortized_computation: bool, ) -> None: + """ + Updates the preconditioners. + + Args: + masked_grad_list (tuple[Tensor, ...]): A list of gradients with their corresponding masks. + step (Tensor): The current step. + perform_amortized_computation (bool): Whether to perform an amortized computation. + + Returns: + None + """ with profiler.record_function( f"## {self.__class__.__name__}:{self.update_preconditioners.__name__} ##" + ): + # Update the Kronecker factor matrices. + self._update_factor_matrices(masked_grad_list=masked_grad_list) + + # Update bias correction term based on step. + if self._use_bias_correction and self._beta2 < 1.0: + self._bias_correction2 = torch.tensor(1.0) - self._beta2**step + + # In Shampoo, this is equivalent to computing the inverse factor matrix. + # In Eigenvalue-Corrected Shampoo, this is equivalent to computing the eigenvector of the factor matrix. + if perform_amortized_computation: + self._amortized_computation() + + def _initialize_state_lists( + self, + block_list: tuple[Tensor, ...], + kronecker_factors_list: list[ShampooKroneckerFactorsListType], + distributor_selector: tuple[bool, ...], + ) -> None: + # Initialize local lists. + local_block_list = compress_list(block_list, distributor_selector) + self._local_kronecker_factors_list: tuple[ + ShampooKroneckerFactorsListType, + ..., + ] = compress_list(kronecker_factors_list, distributor_selector) + self._local_order_list: tuple[int, ...] = tuple( + block.dim() for block in local_block_list + ) + self._local_root_list: tuple[int, ...] = self._get_inverse_roots_from_override( + self._inv_root_override, + self._local_order_list, + ) + + # Masked lists are the list of active preconditioners or values after filtering out gradients with None. + self._masked_order_list: tuple[int, ...] = self._local_order_list + self._masked_root_list: tuple[int, ...] = self._local_root_list + self._masked_kronecker_factors_list: tuple[ + ShampooKroneckerFactorsListType, + ..., + ] = self._local_kronecker_factors_list + + # Construct lists of bytes and numels for logging purposes. + # NOTE: These lists are constructed across all blocked parameters. + self._dims_list: tuple[torch.Size, ...] = compress_list( + self._dims_list, distributor_selector + ) + self._numel_list: tuple[int, ...] = tuple( + sum(2 * dim**2 for dim in dims) for dims in self._dims_list + ) + self._num_bytes_list: tuple[int, ...] = tuple( + numel + * ( + get_dtype_size(self._precision_config.factor_matrix_dtype) + + get_dtype_size(block.dtype) + ) + // 2 + for numel, block in zip(self._numel_list, local_block_list, strict=True) + ) + + def compress_preconditioner_list( + self, local_grad_selector: tuple[bool, ...] 
+ ) -> None: + with profiler.record_function( + f"## {self.__class__.__name__}:{self.compress_preconditioner_list.__name__} ##" + ): + self._masked_order_list = compress_list( + self._local_order_list, local_grad_selector + ) + self._masked_root_list = compress_list( + self._local_root_list, local_grad_selector + ) + self._masked_kronecker_factors_list: tuple[ + ShampooKroneckerFactorsListType, + ..., + ] = compress_list(self._local_kronecker_factors_list, local_grad_selector) + + def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: + with profiler.record_function( + f"## {self.__class__.__name__}:{self._update_factor_matrices.__name__} ##" ): # NOTE: Unlike AdagradPreconditionerList, we will loop through each gradient individually. # We apply foreach operators onto the list of Kronecker factor matrices (as opposed to the @@ -594,43 +831,107 @@ def update_preconditioners( alpha=1 - self._beta2 if self._beta2 != 1.0 else 1.0, ) - # Update bias correction term based on step list. - if self._use_bias_correction and self._beta2 < 1.0: - self._bias_correction2 = torch.tensor(1.0) - self._beta2**step + @staticmethod + def _precondition_grad( + grad: Tensor, + preconditioner_list: tuple[Tensor, ...], + dims: tuple[list[int], list[int]] = ([0], [0]), + ) -> Tensor: + return reduce(partial(torch.tensordot, dims=dims), preconditioner_list, grad) + + +class ShampooPreconditionerList( + BaseShampooPreconditionerList[ShampooKroneckerFactorsList] +): + """Shampoo preconditioners for list of parameters.""" + + def _create_kronecker_factors_state_for_block( + self, block: Tensor, block_info: BlockInfo, dims: torch.Size + ) -> ShampooKroneckerFactorsState: + inv_factor_matrices = tuple( + QuantizedTensor( + quantized_values=block_info.allocate_zeros_tensor( + shape=(dim, dim), + dtype=self._precision_config.inv_factor_matrix_dtype, + device=block_info.param.device, + ), + block_info=block_info, + ) + for dim in dims + ) + + base_kronecker_factors = self._create_base_kronecker_factors( + block_info=block_info, dims=dims + ) + return ShampooKroneckerFactorsState( + factor_matrices=base_kronecker_factors.factor_matrices, + factor_matrix_indices=base_kronecker_factors.factor_matrix_indices, + inv_factor_matrices=inv_factor_matrices, + ) - def precondition(self, masked_grad_list: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: + def _create_kronecker_factors_list( + self, + kronecker_factors_state: ShampooKroneckerFactorsState + | EigenvalueCorrectedShampooKroneckerFactorsState, + ) -> ShampooKroneckerFactorsList: + assert isinstance(kronecker_factors_state, ShampooKroneckerFactorsState) + return ShampooKroneckerFactorsList( + # Factor matrices computation (accumulation, root inverse) should use the determined dtype. + factor_matrices=QuantizedTensorList( + quantized_data=kronecker_factors_state.factor_matrices, + quantized_dtype=self._precision_config.factor_matrix_dtype, + computation_dtype=self._precision_config.factor_matrix_computation_dtype, + ), + # Inverse factor matrices computation (preconditioning) should use the dtype of the block / gradient. 
+ inv_factor_matrices=QuantizedTensorList( + quantized_data=kronecker_factors_state.inv_factor_matrices, + quantized_dtype=self._precision_config.inv_factor_matrix_dtype, + computation_dtype=self._precision_config.computation_dtype, + ), + factor_matrix_indices=kronecker_factors_state.factor_matrix_indices, + ) + + def _get_inverse_roots_from_override( + self, + inv_root_override: int | Sequence[int], + order_list: tuple[int, ...], + ) -> tuple[int, ...]: + return BaseShampooPreconditionerList._get_inverse_roots_from_override_with_high_order_default( + inv_root_override, order_list, high_order_default=lambda order: 2 * order + ) + + def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, ...]: + """ + Preconditions a list of gradients using the Shampoo preconditioner. + + Args: + masked_grad_list (tuple[Tensor, ...]): A list of gradients with their corresponding masks. + + Returns: + tuple[Tensor, ...]: A list of preconditioned gradients. + """ with profiler.record_function( f"## {self.__class__.__name__}:{self.precondition.__name__} ##" ): - - def precondition_masked_grad( - masked_grad: Tensor, - inv_factor_matrices: Tuple[Tensor, ...], - ) -> Tensor: - for inv_factor_matrix in inv_factor_matrices: - masked_grad = torch.tensordot( - masked_grad, inv_factor_matrix, [[0], [0]] - ) - return masked_grad - return tuple( - precondition_masked_grad( - masked_grad=masked_grad, - inv_factor_matrices=kronecker_factors.inv_factor_matrices.dequantized_value, + self._precondition_grad( + grad=masked_grad, + preconditioner_list=kronecker_factors.inv_factor_matrices.dequantized_value, ) for masked_grad, kronecker_factors in zip( masked_grad_list, self._masked_kronecker_factors_list, strict=True ) ) - def compute_root_inverse(self) -> None: + @torch.compiler.disable + def _amortized_computation(self) -> None: # NOTE: This function currently only computes the matrix root inverse based on # the masked lists which combines both selection based on the distributor and where # grad is not None. Implicitly, this assumes that there are no changes between the # selector or masking from iteration-to-iteration within a single precondition_frequency # interval. with profiler.record_function( - f"## {self.__class__.__name__}:{self.compute_root_inverse.__name__} ##" + f"## {self.__class__.__name__}:{self._amortized_computation.__name__} ##" ): for kronecker_factors, root in zip( self._masked_kronecker_factors_list, @@ -649,81 +950,58 @@ def compute_root_inverse(self) -> None: kronecker_factors.factor_matrix_indices, strict=True, ): - # For tracking diagonality of the preconditioner. - # Checks if the preconditioner is currently diagonal, then checks whether or not - # the update matrix is diagonal. - if is_factor_matrix_diagonal and not check_diagonal(factor_matrix): - is_factor_matrix_diagonal.copy_(torch.tensor(False)) - logger.debug( - f"Factor matrix {factor_matrix_index} is not diagonal." - ) - # Add epsilon term and incorporate bias correction. bias_corrected_factor_matrix = ( factor_matrix / self._bias_correction2 ) - # Check for nan or inf values. - if torch.isnan(bias_corrected_factor_matrix).any(): - raise PreconditionerValueError( - f"Encountered nan values in bias-corrected factor matrix {factor_matrix_index}! " - f"To mitigate, check if nan inputs are being passed into the network or nan gradients " - f"are being passed to the optimizer. 
" - f"For debugging purposes, factor_matrix {factor_matrix_index}: " - f"{torch.min(factor_matrix)=}, {torch.max(factor_matrix)=}, " - f"{factor_matrix.isinf().any()=}, {factor_matrix.isnan().any()=}." - ) - if torch.isinf(bias_corrected_factor_matrix).any(): - raise PreconditionerValueError( - f"Encountered inf values in bias-corrected factor matrix {factor_matrix_index}! " - f"In some cases, this may be due to divergence of the algorithm. " - f"To mitigate, try decreasing the learning rate or increasing grafting epsilon. " - f"For debugging purposes, factor_matrix {factor_matrix_index}: " - f"{torch.min(factor_matrix)=}, {torch.max(factor_matrix)=}, " - f"{factor_matrix.isinf().any()=}, {factor_matrix.isnan().any()=}." - ) + BaseShampooPreconditionerList._check_factor_matrix_for_diagonality_nan_and_inf( + factor_matrix=bias_corrected_factor_matrix, + is_factor_matrix_diagonal=is_factor_matrix_diagonal, + factor_matrix_index=factor_matrix_index, + ) # Compute inverse preconditioner. - # If reuse_previous_inv_factor_matrix is True, will reuse previous matrix if matrix - # inverse root computation fails. try: computed_inv_factor_matrix = matrix_inverse_root( A=bias_corrected_factor_matrix, root=Fraction( root / getattr( - self._root_inv_config, "exponent_multiplier", 1 + self._preconditioner_computation_config, + "exponent_multiplier", + 1, ) ), - root_inv_config=self._root_inv_config, + root_inv_config=cast( + RootInvConfig, self._preconditioner_computation_config + ), epsilon=self._epsilon, is_diagonal=is_factor_matrix_diagonal, ).to(dtype=inv_factor_matrix.dtype) - - # Check if we encounter NaN or inf values in computed inverse matrix. - if ( - torch.isnan(computed_inv_factor_matrix).any() - or torch.isinf(computed_inv_factor_matrix).any() - ): - torch.set_printoptions(threshold=100_000) - raise PreconditionerValueError( - f"Encountered nan or inf values in inverse factor matrix {factor_matrix_index}! " - f"To mitigate, check factor matrix before matrix inverse root computation: " - f"{bias_corrected_factor_matrix=}" - ) - - inv_factor_matrix.copy_(computed_inv_factor_matrix) - - except PreconditionerValueError as pve: - raise pve except Exception as exception: + # If self._use_protected_eigh is True, will reuse previous matrix if matrix inverse root computation fails. if not self._use_protected_eigh: raise exception else: logger.warning( - f"Matrix inverse root computation failed for factor matrix {factor_matrix_index} " - f"with exception {exception}. Using previous inv_factor_matrix and continuing..." + f"Matrix computation failed for factor matrix {factor_matrix_index} " + f"with {exception=}. Using previous inversed factor matrix and continuing..." ) + # Define computed_inv_factor_matrix to prevent undefined local variable error. + computed_inv_factor_matrix = inv_factor_matrix + + # Check if we encounter NaN or inf values in computed inverse matrix. + if ( + torch.isnan(computed_inv_factor_matrix).any() + or torch.isinf(computed_inv_factor_matrix).any() + ): + torch.set_printoptions(threshold=100_000) + raise PreconditionerValueError( + f"Encountered nan or inf values in inverse factor matrix {factor_matrix_index}! 
" + f"To mitigate, check factor matrix before the matrix computation: {bias_corrected_factor_matrix=}" + ) + inv_factor_matrix.copy_(computed_inv_factor_matrix) def dequantize_preconditioners(self) -> None: with profiler.record_function( @@ -741,25 +1019,10 @@ def quantize_preconditioners(self) -> None: kronecker_factors.factor_matrices.quantize_() kronecker_factors.inv_factor_matrices.quantize_() - def compress_preconditioner_list( - self, local_grad_selector: Tuple[bool, ...] - ) -> None: - with profiler.record_function( - f"## {self.__class__.__name__}:{self.compress_preconditioner_list.__name__} ##" - ): - self._masked_order_list = compress_list( - self._local_order_list, local_grad_selector - ) - self._masked_root_list = compress_list( - self._local_root_list, local_grad_selector - ) - self._masked_kronecker_factors_list: Tuple[ - ShampooKroneckerFactorsList, ... - ] = compress_list(self._local_kronecker_factors_list, local_grad_selector) - + @torch.compiler.disable def compute_root_inverse_residuals( self, - ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: + ) -> tuple[tuple[Tensor, ...], tuple[Tensor, ...]]: relative_errors = [] relative_residuals = [] @@ -781,10 +1044,17 @@ def compute_root_inverse_residuals( A=bias_corrected_factor_matrix, X_hat=inv_factor_matrix, root=Fraction( - root / getattr(self._root_inv_config, "exponent_multiplier", 1) + root + / getattr( + self._preconditioner_computation_config, + "exponent_multiplier", + 1, + ) ), epsilon=self._epsilon, - root_inv_config=self._root_inv_config, + root_inv_config=cast( + RootInvConfig, self._preconditioner_computation_config + ), ) relative_errors.append(relative_error) relative_residuals.append(relative_residual) @@ -795,6 +1065,273 @@ def compute_root_inverse_residuals( ) +class EigenvalueCorrectedShampooPreconditionerList( + BaseShampooPreconditionerList[EigenvalueCorrectedShampooKroneckerFactorsList] +): + """Eigenvalue-corrected Shampoo preconditioners for list of parameters.""" + + def _create_kronecker_factors_state_for_block( + self, block: Tensor, block_info: BlockInfo, dims: torch.Size + ) -> EigenvalueCorrectedShampooKroneckerFactorsState: + factor_matrices_eigenvectors = tuple( + QuantizedTensor( + quantized_values=block_info.allocate_zeros_tensor( + shape=(dim, dim), + dtype=self._precision_config.factor_matrix_eigenvectors_dtype, + device=block_info.param.device, + ), + block_info=block_info, + ) + for dim in dims + ) + corrected_eigenvalues = QuantizedTensor( + quantized_values=block_info.allocate_zeros_tensor( + shape=tuple(dims), + dtype=self._precision_config.corrected_eigenvalues_dtype, + device=block_info.param.device, + ), + block_info=block_info, + ) + + base_kronecker_factors = self._create_base_kronecker_factors( + block_info=block_info, dims=dims + ) + return EigenvalueCorrectedShampooKroneckerFactorsState( + factor_matrices=base_kronecker_factors.factor_matrices, + factor_matrices_eigenvectors=factor_matrices_eigenvectors, + corrected_eigenvalues=corrected_eigenvalues, + factor_matrix_indices=base_kronecker_factors.factor_matrix_indices, + ) + + def _create_kronecker_factors_list( + self, + kronecker_factors_state: ShampooKroneckerFactorsState + | EigenvalueCorrectedShampooKroneckerFactorsState, + ) -> EigenvalueCorrectedShampooKroneckerFactorsList: + assert isinstance( + kronecker_factors_state, EigenvalueCorrectedShampooKroneckerFactorsState + ) + return EigenvalueCorrectedShampooKroneckerFactorsList( + factor_matrices=QuantizedTensorList( + 
quantized_data=kronecker_factors_state.factor_matrices, + quantized_dtype=self._precision_config.factor_matrix_dtype, + computation_dtype=self._precision_config.factor_matrix_computation_dtype, + ), + factor_matrices_eigenvectors=QuantizedTensorList( + quantized_data=kronecker_factors_state.factor_matrices_eigenvectors, + quantized_dtype=self._precision_config.factor_matrix_eigenvectors_dtype, + computation_dtype=self._precision_config.computation_dtype, + ), + corrected_eigenvalues=QuantizedTensorList( + quantized_data=(kronecker_factors_state.corrected_eigenvalues,), + quantized_dtype=self._precision_config.corrected_eigenvalues_dtype, + computation_dtype=self._precision_config.computation_dtype, + ), + factor_matrix_indices=kronecker_factors_state.factor_matrix_indices, + ) + + def _get_inverse_roots_from_override( + self, + inv_root_override: int | Sequence[int], + order_list: tuple[int, ...], + ) -> tuple[int, ...]: + return BaseShampooPreconditionerList._get_inverse_roots_from_override_with_high_order_default( + inv_root_override, order_list, high_order_default=lambda order: 2 + ) + + def update_preconditioners( + self, + masked_grad_list: tuple[Tensor, ...], + step: Tensor, + perform_amortized_computation: bool, + ) -> None: + """ + Updates the preconditioners. + + Args: + masked_grad_list (tuple[Tensor, ...]): A list of gradients with their corresponding masks. + step (Tensor): The current step. + perform_amortized_computation (bool): Whether to perform an amortized computation. + + Returns: + None + """ + with profiler.record_function( + f"## {self.__class__.__name__}:{self.update_preconditioners.__name__} ##" + ): + super().update_preconditioners( + masked_grad_list=masked_grad_list, + step=step, + perform_amortized_computation=perform_amortized_computation, + ) + # Update the eigenvalue corrections of Shampoo's preconditioner. + self._update_eigenvalue_corrections(masked_grad_list=masked_grad_list) + + def _update_eigenvalue_corrections( + self, masked_grad_list: tuple[Tensor, ...] + ) -> None: + with profiler.record_function( + f"## {self.__class__.__name__}:{self._update_eigenvalue_corrections.__name__} ##" + ): + # NOTE: Unlike AdagradPreconditionerList, we will loop through each gradient individually. + for grad, kronecker_factors in zip( + masked_grad_list, + self._masked_kronecker_factors_list, + strict=True, + ): + factor_eigenvectors = ( + kronecker_factors.factor_matrices_eigenvectors.dequantized_value + ) + if factor_eigenvectors[0].any(): + grad = self._precondition_grad( + grad=grad, + preconditioner_list=factor_eigenvectors, + ) + # Scale corrected eigenvalues. + # NOTE: The case when self._beta2 == 1.0 is not well tested and might not be stable. + if self._beta2 != 1.0: + kronecker_factors.corrected_eigenvalues.dequantized_value[0].mul_( + self._beta2 + ) + # Update corrected eigenvalues (squared gradient in eigenbasis of Shampoo preconditioner). + kronecker_factors.corrected_eigenvalues.dequantized_value[0].add_( + grad.square(), + alpha=1 - self._beta2 if self._beta2 != 1.0 else 1.0, + ) + + def precondition(self, masked_grad_list: tuple[Tensor, ...]) -> tuple[Tensor, ...]: + """ + Preconditions a list of gradients using the Eigenvalue-Corrected Shampoo preconditioner. + + Args: + masked_grad_list (tuple[Tensor, ...]): A list of gradients with their corresponding masks. + + Returns: + tuple[Tensor, ...]: A list of preconditioned gradients. 
+ """ + with profiler.record_function( + f"## {self.__class__.__name__}:{self.precondition.__name__} ##" + ): + preconditioned_grad_list = [] + for masked_grad, kronecker_factors, root in zip( + masked_grad_list, + self._masked_kronecker_factors_list, + self._masked_root_list, + strict=True, + ): + factor_eigenvectors = ( + kronecker_factors.factor_matrices_eigenvectors.dequantized_value + ) + corrected_eigenvalues = ( + kronecker_factors.corrected_eigenvalues.dequantized_value[0] + ) + use_eigenbasis = factor_eigenvectors[0].any() + grad = masked_grad.clone() + if use_eigenbasis: + # Convert to eigenbasis of Shampoo factor matrices. + grad = self._precondition_grad( + grad=grad, + preconditioner_list=factor_eigenvectors, + ) + + # Precondition with inverse root of corrected eigenvalues. + grad.div_( + corrected_eigenvalues.div(self._bias_correction2) + .add_(self._epsilon) + .pow_(1 / root) + ) + if use_eigenbasis: + # Convert back to basis of the parameters. + grad = self._precondition_grad( + grad=grad, + preconditioner_list=factor_eigenvectors, + dims=([0], [1]), + ) + preconditioned_grad_list.append(grad) + return tuple(preconditioned_grad_list) + + @torch.compiler.disable + def _amortized_computation(self) -> None: + # NOTE: This function currently only computes the preconditioner eigenvectors based on + # the masked lists which combines both selection based on the distributor and where + # grad is not None. Implicitly, this assumes that there are no changes between the + # selector or masking from iteration-to-iteration within a single precondition_frequency + # interval. + with profiler.record_function( + f"## {self.__class__.__name__}:{self._amortized_computation.__name__} ##" + ): + for kronecker_factors in self._masked_kronecker_factors_list: + for ( + factor_matrix, + factor_matrix_eigenvectors, + is_factor_matrix_diagonal, + factor_matrix_index, + ) in zip( + kronecker_factors.factor_matrices.dequantized_value, + kronecker_factors.factor_matrices_eigenvectors.dequantized_value, + kronecker_factors.is_factor_matrices_diagonal, + kronecker_factors.factor_matrix_indices, + strict=True, + ): + BaseShampooPreconditionerList._check_factor_matrix_for_diagonality_nan_and_inf( + factor_matrix=factor_matrix, + is_factor_matrix_diagonal=is_factor_matrix_diagonal, + factor_matrix_index=factor_matrix_index, + ) + + # Compute eigenvectors of factor matrix. + try: + computed_eigenvectors = matrix_eigenvectors( + A=factor_matrix, + eigenvector_computation_config=cast( + EigenvalueCorrectionConfig, + self._preconditioner_computation_config, + ), + is_diagonal=is_factor_matrix_diagonal, + ) + except Exception as exception: + # If self._use_protected_eigh is True, will reuse previous matrix if matrix eigenvector computation fails. + if not self._use_protected_eigh: + raise exception + else: + logger.warning( + f"Matrix computation failed for factor matrix {factor_matrix_index} " + f"with {exception=}. Using previous factor matrix eigenvectors and continuing..." + ) + # Define computed_eigenvectors to prevent undefined local variable error. + computed_eigenvectors = factor_matrix_eigenvectors + + # Check if we encounter NaN or inf values in computed eigenvectors. + if ( + torch.isnan(computed_eigenvectors).any() + or torch.isinf(computed_eigenvectors).any() + ): + torch.set_printoptions(threshold=100_000) + raise PreconditionerValueError( + f"Encountered nan or inf values in eigenvectors of factor matrix {factor_matrix_index}! 
" + f"To mitigate, check factor matrix before the matrix computation: {factor_matrix=}" + ) + factor_matrix_eigenvectors.copy_(computed_eigenvectors) + + def dequantize_preconditioners(self) -> None: + with profiler.record_function( + f"## {self.__class__.__name__}:{self.dequantize_preconditioners.__name__} ##" + ): + for kronecker_factors in self._masked_kronecker_factors_list: + kronecker_factors.factor_matrices.dequantize_() + kronecker_factors.factor_matrices_eigenvectors.dequantize_() + kronecker_factors.corrected_eigenvalues.dequantize_() + + def quantize_preconditioners(self) -> None: + with profiler.record_function( + f"## {self.__class__.__name__}:{self.quantize_preconditioners.__name__} ##" + ): + for kronecker_factors in self._masked_kronecker_factors_list: + kronecker_factors.factor_matrices.quantize_() + kronecker_factors.factor_matrices_eigenvectors.quantize_() + kronecker_factors.corrected_eigenvalues.quantize_() + + class DequantizePreconditionersContext(ParameterizeEnterExitContext): """DequantizePreconditionersContext is used for automatically dequantize and then quantize the preconditioners used within this context. diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 3f84622..56e9687 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -21,12 +21,13 @@ from distributed_shampoo.utils.shampoo_preconditioner_list import ( AdagradPreconditionerList, DequantizePreconditionersContext, + EigenvalueCorrectedShampooPreconditionerList, PreconditionerList, SGDPreconditionerList, ShampooPreconditionerList, ) from distributed_shampoo.utils.shampoo_quantization import QuantizedTensorList -from matrix_functions_types import EigenConfig +from matrix_functions_types import DefaultEighEigenvalueCorrectionConfig, EigenConfig from torch import Tensor @@ -73,9 +74,12 @@ def _test_update_preconditioners_and_precondition( preconditioner_list.update_preconditioners( masked_grad_list=masked_grad_list, step=torch.tensor(step), + # Only compute the new layerwise direction when the update_preconditioners() reach the last step. + perform_amortized_computation=isinstance( + preconditioner_list, ShampooPreconditionerList + ) + and step == len(masked_grad_lists), ) - if isinstance(preconditioner_list, ShampooPreconditionerList): - preconditioner_list.compute_root_inverse() masked_preconditioned_grad_list = preconditioner_list.precondition( masked_grad_list=masked_grad_lists[-1] ) @@ -472,7 +476,7 @@ def test_inverse_roots_from_override( """ Tests that the inverse roots are computed correctly from inv_root_override. 
""" - root_inv_config = EigenConfig(exponent_multiplier=2.0) + preconditioner_computation_config = EigenConfig(exponent_multiplier=2.0) masked_grad_list1 = ( torch.tensor([1.0, 0.0]), @@ -497,7 +501,7 @@ def test_inverse_roots_from_override( beta2=1.0, use_bias_correction=True, inv_root_override=inv_root_override, - root_inv_config=root_inv_config, + preconditioner_computation_config=preconditioner_computation_config, ), masked_grad_lists=[masked_grad_list1, masked_grad_list2], masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, @@ -510,37 +514,37 @@ def test_raise_inf_in_factor_matrix_compute_root_inverse(self) -> None: with DequantizePreconditionersContext( preconditioner_list=self._preconditioner_list ): - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([torch.inf, torch.inf]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[torch.inf, torch.inf]]), - ), - step=torch.tensor(1), - ) with self.assertRaisesRegex( PreconditionerValueError, - re.escape("Encountered inf values in bias-corrected factor matrix"), + re.escape("Encountered inf values in factor matrix"), ): - self._preconditioner_list.compute_root_inverse() + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([torch.inf, torch.inf]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[torch.inf, torch.inf]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) def test_raise_nan_in_factor_matrix_compute_root_inverse(self) -> None: with DequantizePreconditionersContext( preconditioner_list=self._preconditioner_list ): - self._preconditioner_list.update_preconditioners( - masked_grad_list=( - torch.tensor([torch.nan, torch.nan]), - torch.eye(2) / torch.tensor(2.0).sqrt(), - torch.tensor([[torch.nan, torch.nan]]), - ), - step=torch.tensor(1), - ) with self.assertRaisesRegex( PreconditionerValueError, - re.escape("Encountered nan values in bias-corrected factor matrix"), + re.escape("Encountered nan values in factor matrix"), ): - self._preconditioner_list.compute_root_inverse() + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([torch.nan, torch.nan]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[torch.nan, torch.nan]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) # Note: This is needed for pyre to infer the type of argument into mock.patch.object. 
shampoo_preconditioner_list_module: ModuleType = shampoo_preconditioner_list @@ -559,7 +563,15 @@ def test_raise_inf_in_inv_factor_matrix_compute_root_inverse( PreconditionerValueError, re.escape("Encountered nan or inf values in inverse factor matrix"), ): - self._preconditioner_list.compute_root_inverse() + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) mock_matrix_inverse_root.assert_called_once() @mock.patch.object( @@ -576,7 +588,15 @@ def test_raise_nan_in_inv_factor_matrix_compute_root_inverse( PreconditionerValueError, re.escape("Encountered nan or inf values in inverse factor matrix"), ): - self._preconditioner_list.compute_root_inverse() + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) mock_matrix_inverse_root.assert_called_once() @mock.patch.object( @@ -592,20 +612,28 @@ def test_matrix_compute_root_inverse_internal_failure( preconditioner_list=self._preconditioner_list ), self.assertLogs(level="WARNING") as cm: # Because use_protected_eigh is True, we expect the warning to be logged. - self._preconditioner_list.compute_root_inverse() + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) self.assertCountEqual( [r.msg for r in cm.records], [ - "Matrix inverse root computation failed for factor matrix 0.block_0.0 with exception ." - " Using previous inv_factor_matrix and continuing...", - "Matrix inverse root computation failed for factor matrix 1.block_0.0 with exception ." - " Using previous inv_factor_matrix and continuing...", - "Matrix inverse root computation failed for factor matrix 1.block_0.1 with exception ." - " Using previous inv_factor_matrix and continuing...", - "Matrix inverse root computation failed for factor matrix 1.block_1.0 with exception ." - " Using previous inv_factor_matrix and continuing...", - "Matrix inverse root computation failed for factor matrix 1.block_1.1 with exception ." - " Using previous inv_factor_matrix and continuing...", + "Matrix computation failed for factor matrix 0.block_0.0 with exception=ZeroDivisionError()." + " Using previous inversed factor matrix and continuing...", + "Matrix computation failed for factor matrix 1.block_0.0 with exception=ZeroDivisionError()." + " Using previous inversed factor matrix and continuing...", + "Matrix computation failed for factor matrix 1.block_0.1 with exception=ZeroDivisionError()." + " Using previous inversed factor matrix and continuing...", + "Matrix computation failed for factor matrix 1.block_1.0 with exception=ZeroDivisionError()." + " Using previous inversed factor matrix and continuing...", + "Matrix computation failed for factor matrix 1.block_1.1 with exception=ZeroDivisionError()." 
+ " Using previous inversed factor matrix and continuing...", ], ) mock_matrix_inverse_root.assert_called() @@ -617,7 +645,15 @@ def test_matrix_compute_root_inverse_internal_failure( with DequantizePreconditionersContext( preconditioner_list=self._preconditioner_list ), self.assertRaises(ZeroDivisionError): - self._preconditioner_list.compute_root_inverse() + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) mock_matrix_inverse_root.assert_called() @mock.patch.object( @@ -634,7 +670,15 @@ def test_matrix_compute_root_inverse_factor_matrix_non_diagonal( ), self.assertLogs( level="DEBUG", ) as cm: - self._preconditioner_list.compute_root_inverse() + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) self.assertCountEqual( [r.msg for r in cm.records], [ @@ -691,12 +735,13 @@ def test_compute_root_inverse_residuals(self) -> None: preconditioner_list.update_preconditioners( masked_grad_list=masked_grad_list1, step=torch.tensor(1), + perform_amortized_computation=False, ) preconditioner_list.update_preconditioners( masked_grad_list=masked_grad_list2, step=torch.tensor(2), + perform_amortized_computation=True, ) - preconditioner_list.compute_root_inverse() # Expect no relative errors and residuals because L is a diagonal matrix. ( @@ -709,3 +754,407 @@ def test_compute_root_inverse_residuals(self) -> None: self.assertTupleEqual(relative_errors, expected_relative_errors) self.assertTupleEqual(relative_residuals, expected_relative_residuals) + + +class EigenvalueCorrectedShampooPreconditionerListTest(AdagradPreconditionerListTest): + def _instantiate_preconditioner_list(self, **kwargs: Any) -> PreconditionerList: + kwargs = { + "beta2": 1.0, + "epsilon": 1e-12, + "inv_root_override": 0, + "use_bias_correction": True, + "use_protected_eigh": True, + "preconditioner_computation_config": DefaultEighEigenvalueCorrectionConfig, + } | kwargs + return EigenvalueCorrectedShampooPreconditionerList( + block_list=self._block_list, + state=self._state, + block_info_list=self._block_info_list, + distributor_selector=self._distributor_selector, + precision_config=PrecisionConfig(factor_matrix_dtype=torch.float64), + **kwargs, + ) + + def test_update_preconditioners_and_precondition(self) -> None: + """ + We provide examples where we update the preconditioners twice using specially + chosen gradients such that we get a scalar * identity matrix for both Kronecker + factor matrices for all parameters of interest. 
+ + Specifically, for the beta2 = 1 case, we have 3 parameters and define their gradients + as the following in order to get the expected preconditioned gradient list: + + (1) Tensor of Size 2 + G1 = [1, 0]^T + G2 = [0, 1]^T + + L = G1 * G1^T + G2 * G2^T = [[1, 0], [0, 1]] + B = [[1, 0], [0, 1]] # eigenvectors of L + E = G1^2 + (B G2)^2 # corrected eigenvalues + P = B ((B G2) / sqrt(E + eps)) = G2 / sqrt(E + eps) ≈ G2 + + (2) Tensor of Size 2 x 2 + G1 = [[1, 0], [0, 1]] / sqrt(2) + G2 = [[1, 0], [0, 1]] / sqrt(2) + + L = G1 * G1^T + G2 * G2^T = [[1, 0], [0, 1]] + R = G1^T * G1 + G2^T * G2 = [[1, 0], [0, 1]] + B_L = [[1, 0], [0, 1]] # eigenvectors of L + B_R = [[1, 0], [0, 1]] # eigenvectors of R + E = G1^2 + (B_L G2 B_R)^2 # corrected eigenvalues + P = B_L ((B_L G2 B_R) / sqrt(E + eps) B_R = G2 / sqrt(E + eps) ≈ G2 + + (3) Tensor of Size 1 x 2 + G1 = [[1, 0]] + G2 = [[0, 1]] + + L = G1 * G1^T + G2 * G2^T = 2 + R = G1^T * G1 + G2^T * G2 = [[1, 0], [0, 1]] + B_L = 1 # eigenvectors of L + B_R = [[1, 0], [0, 1]] # eigenvectors of R + E = G1^2 + (B_L G2 B_R)^2 # corrected eigenvalues + P = B_L ((B_L G2 B_R) / sqrt(E + eps) B_R = G2 / sqrt(E + eps) ≈ G2 + + """ + masked_grad_list1 = ( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ) + masked_grad_list2 = ( + torch.tensor([0.0, 1.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[0.0, 1.0]]), + ) + + masked_expected_preconditioned_grad_list = ( + torch.tensor([0.0, 1.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[0.0, 1.0]]), + ) + self._test_update_preconditioners_and_precondition( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=1.0, + use_bias_correction=True, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, + ) + + """ + For the other two cases (beta2 < 1), note: + + E = beta2 * (1 - beta2) G1^2 + (1 - beta2) G2^2 + + Therefore, in order to retain the identity matrix, we simply need to scale each gradient by: + + G1 -> G1 / sqrt(beta2 * (1 - beta2)) + G2 -> G2 / sqrt(1 - beta2). + + """ + beta2 = 0.9 + + beta2_compensated_grad_list1 = torch._foreach_div( + masked_grad_list1, + torch.tensor(beta2 * (1 - beta2)).sqrt(), + ) + beta2_compensated_grad_list2 = torch._foreach_div( + masked_grad_list2, + torch.tensor(1 - beta2).sqrt(), + ) + + masked_expected_preconditioned_grad_list = [ + torch.tensor([0.0, 1.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[0.0, 1.0]]), + ] + # Fix scaling due to EMA. + torch._foreach_div_( + masked_expected_preconditioned_grad_list, + torch.tensor(1 - beta2).sqrt(), + ) + + self._test_update_preconditioners_and_precondition( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=beta2, + use_bias_correction=False, + ), + masked_grad_lists=[ + beta2_compensated_grad_list1, + beta2_compensated_grad_list2, + ], + masked_expected_preconditioned_grad_list=tuple( + masked_expected_preconditioned_grad_list + ), + ) + + """ + For the last case of including bias correction, we re-scale the entire matrix by the + bias correction at iteration 2. + + E -> E / (1 - beta2^2). + + Therefore, it is sufficient to additionally scale by this value: + + G1 -> sqrt(1 - beta2^2) * G1 + G2 -> sqrt(1 - beta2^2) * G2. 
+ + """ + bias_compensated_grad_list1 = torch._foreach_mul( + beta2_compensated_grad_list1, + torch.tensor(1 - beta2**2).sqrt(), + ) + bias_compensated_grad_list2 = torch._foreach_mul( + beta2_compensated_grad_list2, + torch.tensor(1 - beta2**2).sqrt(), + ) + + # Fix scaling due to bias correction. + torch._foreach_mul_( + masked_expected_preconditioned_grad_list, + torch.tensor(1 - beta2**2).sqrt(), + ) + + self._test_update_preconditioners_and_precondition( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=beta2, + use_bias_correction=True, + ), + masked_grad_lists=[ + bias_compensated_grad_list1, + bias_compensated_grad_list2, + ], + masked_expected_preconditioned_grad_list=tuple( + masked_expected_preconditioned_grad_list + ), + ) + + def test_inv_root_override(self) -> None: + """ + For this example, we modify the one given above such that the inv_root_override = 1. + Therefore, in all cases, the exponent is -2 / 2 = -1. + This should result in the following behavior: + + (1) Tensor of Size 2 + G1 = [1, 0]^T + G2 = [0, 2]^T + + L = G1 * G1^T + G2 * G2^T = [[1, 0], [0, 4]] + B = [[1, 0], [0, 1]] # eigenvectors of L + E = G1^2 + (B G2)^2 # corrected eigenvalues + P = B ((B G2) / (E + eps) = G2 / (E + eps) ≈ [0, 0.5]^T + + (2) Tensor of Size 2 x 2 + G1 = [[1, 0], [0, 1]] / sqrt(2) + G2 = [[1, 0], [0, 1]] / sqrt(2) + + L = G1 * G1^T + G2 * G2^T = [[1, 0], [0, 1]] + R = G1^T * G1 + G2^T * G2 = [[1, 0], [0, 1]] + B_L = [[1, 0], [0, 1]] # eigenvectors of L + B_R = [[1, 0], [0, 1]] # eigenvectors of R + E = G1^2 + (B_L G2 B_R)^2 # corrected eigenvalues + P = B_L ((B_L G2 B_R) / (E + eps) B_R = G2 / (E + eps) ≈ G2 + + (3) Tensor of Size 1 x 2 + G1 = [[1, 0]] + G2 = [[0, 2]] + + L = G1 * G1^T + G2 * G2^T = 2 + R = G1^T * G1 + G2^T * G2 = [[1, 0], [0, 4]] + B_L = 1 # eigenvectors of L + B_R = [[1, 0], [0, 1]] # eigenvectors of R + E = G1^2 + (B_L G2 B_R)^2 # corrected eigenvalues + P = B_L ((B_L G2 B_R) / (E + eps)) B_R = G2 / (E + eps) ≈ [[0, 0.5]] + + """ + + def test_inverse_roots_from_override( + inv_root_override: Union[int, List[int]], + ) -> None: + """ + Tests that the inverse roots are computed correctly from inv_root_override. 
+ """ + masked_grad_list1 = ( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ) + masked_grad_list2 = ( + torch.tensor([0.0, 2.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[0.0, 2.0]]), + ) + + masked_expected_preconditioned_grad_list = ( + torch.tensor([0, 0.5]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[0, 0.5]]), + ) + + with self.subTest(inv_root_override=inv_root_override): + self._test_update_preconditioners_and_precondition( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=1.0, + use_bias_correction=True, + inv_root_override=inv_root_override, + preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, + ) + + test_inverse_roots_from_override(inv_root_override=1) + test_inverse_roots_from_override(inv_root_override=[1, 1, 1]) + + """Tests for compute_preconditioner_eigenvectors.""" + + def test_raise_inf_in_factor_matrix_compute_preconditioner_eigenvectors( + self, + ) -> None: + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ): + with self.assertRaisesRegex( + ValueError, + re.escape("Encountered inf values in factor matrix"), + ): + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([torch.inf, torch.inf]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[torch.inf, torch.inf]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + + def test_raise_nan_in_factor_matrix_compute_preconditioner_eigenvectors( + self, + ) -> None: + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ): + with self.assertRaisesRegex( + ValueError, + re.escape("Encountered nan values in factor matrix"), + ): + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([torch.nan, torch.nan]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[torch.nan, torch.nan]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + + # Note: This is needed for pyre to infer the type of argument into mock.patch.object. 
+ shampoo_preconditioner_list_module: ModuleType = shampoo_preconditioner_list + + @mock.patch.object( + shampoo_preconditioner_list_module, + "matrix_eigenvectors", + side_effect=(torch.tensor([torch.inf]),), + ) + def test_raise_inf_in_inv_factor_matrix_compute_preconditioner_eigenvectors( + self, mock_matrix_eigenvectors: mock.Mock + ) -> None: + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), self.assertRaisesRegex( + PreconditionerValueError, + re.escape("Encountered nan or inf values in eigenvectors of factor matrix"), + ): + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + mock_matrix_eigenvectors.assert_called_once() + + @mock.patch.object( + shampoo_preconditioner_list_module, + "matrix_eigenvectors", + side_effect=(torch.tensor([torch.nan]),), + ) + def test_raise_nan_in_inv_factor_matrix_compute_preconditioner_eigenvectors( + self, mock_matrix_eigenvectors: mock.Mock + ) -> None: + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), self.assertRaisesRegex( + PreconditionerValueError, + re.escape("Encountered nan or inf values in eigenvectors of factor matrix"), + ): + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + mock_matrix_eigenvectors.assert_called_once() + + @mock.patch.object( + shampoo_preconditioner_list_module, + "check_diagonal", + return_value=False, + ) + def test_matrix_compute_preconditioner_eigenvectors_factor_matrix_non_diagonal( + self, mock_check_diagonal: mock.Mock + ) -> None: + self._preconditioner_list = self._instantiate_preconditioner_list(epsilon=1.0) + with DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), self.assertLogs( + level="DEBUG", + ) as cm: + self._preconditioner_list.update_preconditioners( + masked_grad_list=( + torch.tensor([1.0, 0.0]), + torch.eye(2) / torch.tensor(2.0).sqrt(), + torch.tensor([[1.0, 0.0]]), + ), + step=torch.tensor(1), + perform_amortized_computation=True, + ) + self.assertCountEqual( + [r.msg for r in cm.records], + [ + "Factor matrix 0.block_0.0 is not diagonal.", + "Factor matrix 1.block_0.0 is not diagonal.", + "Factor matrix 1.block_0.1 is not diagonal.", + "Factor matrix 1.block_1.0 is not diagonal.", + "Factor matrix 1.block_1.1 is not diagonal.", + ], + ) + mock_check_diagonal.assert_called() + + """End of tests for compute_preconditioner_eigenvectors.""" + + def test_numel_list(self) -> None: + self.assertEqual(self._preconditioner_list.numel_list, (8, 16, 10)) + + def test_dims_list(self) -> None: + self.assertEqual( + self._preconditioner_list.dims_list, + (torch.Size([2]), torch.Size([2, 2]), torch.Size([1, 2])), + ) + + def test_num_bytes_list(self) -> None: + self.assertEqual(self._preconditioner_list.num_bytes_list, (48, 96, 60)) + + def test_numel(self) -> None: + self.assertEqual(self._preconditioner_list.numel(), 34) + + def test_num_bytes(self) -> None: + self.assertEqual(self._preconditioner_list.num_bytes(), 204) + + def test_compress_preconditioner_list(self) -> None: + self._test_compress_preconditioner_list(expected_compress_list_call_count=3)
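
For reference, a minimal usage sketch of the new option follows. It assumes the `DistributedShampoo` constructor arguments shown below (hyperparameter values are illustrative only, following the README guidance of starting from Adam's tuned settings) and a single-process, non-distributed setup; the exact signature may differ.

# Usage sketch; hyperparameters are illustrative, not prescriptive.
import torch

from distributed_shampoo.distributed_shampoo import DistributedShampoo
from matrix_functions_types import DefaultEighEigenvalueCorrectionConfig

model = torch.nn.Linear(32, 8)
optimizer = DistributedShampoo(
    model.parameters(),
    lr=3e-4,  # start from Adam's tuned learning rate
    betas=(0.9, 0.999),  # start from Adam's tuned betas
    epsilon=1e-12,
    weight_decay=1e-5,
    use_decoupled_weight_decay=True,
    precondition_frequency=50,  # how often the amortized eigenvector computation runs
    grafting_config=None,  # grafting from Adam is typically unnecessary with eigenvalue correction
    preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig,
)

loss = model(torch.randn(4, 32)).square().mean()
loss.backward()
optimizer.step()

With an `EigenvalueCorrectionConfig` such as `DefaultEighEigenvalueCorrectionConfig`, the amortized computation produces eigenvectors of the Kronecker factor matrices rather than inverse roots, and the squared gradients are accumulated in that eigenbasis, as implemented by `EigenvalueCorrectedShampooPreconditionerList` above.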