Add option to correct eigenvalues of Shampoo's preconditioner
Summary:
This update is based on #27, developed by Runa Eschenhagen (runame) and Tsung-Hsien Lee (tsunghsienlee). The research idea in this update originated from Runa Eschenhagen's internship at Fundamental AI Research (FAIR) at Meta during the summer of 2024. Concurrently, Runa Eschenhagen, Michael Shi (hjmshi), and Aaron Defazio (adefazio) worked on this method, which was also empirically evaluated on language models by Nikhil Vyas et al. [3], showing promising results.

This update enables approximately correcting the eigenvalues and running Adam in the eigenbasis of Shampoo's preconditioner. A variation of this method was first proposed for K-FAC by George et al. [1], and Anil et al. [2] noted its applicability to Shampoo in Appendix B, although they did not present empirical results or further discussion.
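
For intuition, the update for a single 2-D parameter can be sketched as follows. This is an illustrative simplification with assumed names, not the API added by this diff; momentum, grafting, blocking, and distributed logic are omitted, and the running average is applied only to the corrected eigenvalues here.

```python
import torch

def eigenvalue_corrected_shampoo_step(
    param: torch.Tensor,   # 2-D parameter W
    grad: torch.Tensor,    # gradient G, same shape as W
    state: dict,           # per-parameter state carried across steps
    step: int,             # 1-indexed iteration count
    lr: float = 1e-2,
    beta2: float = 0.999,
    epsilon: float = 1e-8,
    precondition_frequency: int = 10,
) -> None:
    m, n = grad.shape
    if not state:
        state["L"] = torch.zeros(m, m)              # left Kronecker factor
        state["R"] = torch.zeros(n, n)              # right Kronecker factor
        state["QL"] = torch.eye(m)                  # eigenbasis of L
        state["QR"] = torch.eye(n)                  # eigenbasis of R
        state["corrected_sq"] = torch.zeros(m, n)   # running average of squared rotated gradient

    # Accumulate Shampoo's factor matrices every step.
    state["L"] += grad @ grad.T
    state["R"] += grad.T @ grad

    # Amortized computation: refresh the eigenbasis only every `precondition_frequency` steps.
    if step % precondition_frequency == 0:
        _, state["QL"] = torch.linalg.eigh(state["L"])
        _, state["QR"] = torch.linalg.eigh(state["R"])

    # Every step: rotate the gradient into the eigenbasis and update the running
    # average of its square there -- these are the "corrected eigenvalues".
    rotated = state["QL"].T @ grad @ state["QR"]
    state["corrected_sq"].mul_(beta2).addcmul_(rotated, rotated, value=1.0 - beta2)

    # Adam-style step in the eigenbasis, then rotate back to parameter space.
    bias_correction2 = 1.0 - beta2**step
    update = rotated / (state["corrected_sq"] / bias_correction2).sqrt().add_(epsilon)
    param.add_(state["QL"] @ update @ state["QR"].T, alpha=-lr)
```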

References:
1. [Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis](https://arxiv.org/abs/1806.03884). Thomas George, César Laurent, Xavier Bouthillier, Nicolas Ballas, Pascal Vincent. NeurIPS, 2018.
2. [Scalable Second-Order Optimization for Deep Learning](https://arxiv.org/pdf/2002.09018.pdf). Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, and Yoram Singer. Tech Report, 2021.
3. [SOAP: Improving and Stabilizing Shampoo using Adam](https://arxiv.org/abs/2409.11321). Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade. Tech Report, 2024.

Reviewed By: hjmshi

Differential Revision: D65402620

fbshipit-source-id: 8ea4f761cfae04c5622a968cb499654816e4aa3e
tsunghsienlee authored and facebook-github-bot committed Nov 4, 2024
1 parent cc0a1ee commit f3451cd
Showing 7 changed files with 1,731 additions and 336 deletions.
5 changes: 5 additions & 0 deletions distributed_shampoo/README.md
@@ -37,6 +37,7 @@ Key distinctives of this implementation include:
- Choice of precision for preconditioner accumulation and root inverse computation.
- Ability to cache split parameters.
- Merging of small dimensions.
- [EXPERIMENTAL] Option to (approximately) correct the eigenvalues/run Adam in the eigenbasis of Shampoo's preconditioner [2,6,7].

## Requirements

@@ -62,6 +63,8 @@ A few notes on hyperparameters:

- We allow for decoupled and coupled weight decay. Setting `use_decoupled_weight_decay=True` enables AdamW-style (decoupled) weight decay, while `use_decoupled_weight_decay=False` corresponds to standard L2-regularization-style weight decay.

- When `preconditioner_computation_config` is set to an instance of `EigenvalueCorrectionConfig`, learning rate grafting from Adam is typically unnecessary (`grafting_config=None`), and Adam's tuned `lr`, `betas`, and `weight_decay`, when available, should be a good starting point for further tuning (see the sketch below these notes). However, the case of `beta2=1.0`, i.e., an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial.

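A minimal configuration sketch is shown below. The import paths and the concrete `EigenvalueCorrectionConfig()` instantiation are assumptions for illustration; substitute whichever eigenvalue-correction config your installed version exports.

```python
import torch

from distributed_shampoo.distributed_shampoo import DistributedShampoo
# Assumed concrete eigenvalue-correction config; use the one exported by your version.
from matrix_functions_types import EigenvalueCorrectionConfig

model = torch.nn.Linear(512, 10)
optimizer = DistributedShampoo(
    model.parameters(),
    lr=3e-4,                     # Adam's tuned values are a good starting point
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    use_decoupled_weight_decay=True,
    max_preconditioner_dim=1024,
    precondition_frequency=50,   # how often the eigenbasis is recomputed
    grafting_config=None,        # grafting from Adam is typically unnecessary here
    preconditioner_computation_config=EigenvalueCorrectionConfig(),
)
```
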
### Example 1: [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html) with Momentum

If we previously used the optimizer:
@@ -479,3 +482,5 @@ When encountering those errors, following are things you could try:
3. [Learning Rate Grafting: Transferability of Optimizer Tuning](https://openreview.net/pdf?id=FpKgG31Z_i9). Naman Agarwal, Rohan Anil, Elad Hazan, Tomer Koren, and Cyril Zhang. Tech Report, 2021.
4. [Functions of Matrices: Theory and Computation](https://epubs.siam.org/doi/book/10.1137/1.9780898717778). Nicholas J. Higham. SIAM, 2008.
5. [A Distributed Data-Parallel PyTorch Implementation of the Distributed Shampoo Optimizer for Training Neural Networks At-Scale](https://arxiv.org/pdf/2309.06497.pdf). Hao-Jun Michael Shi, Tsung-Hsien Lee, Shintaro Iwasaki, Jose Gallego-Posada, Zhijing Li, Kaushik Rangadurai, Dheevatsa Mudigere, and Michael Rabbat. Tech Report, 2023.
6. [Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis](https://arxiv.org/abs/1806.03884). Thomas George, César Laurent, Xavier Bouthillier, Nicolas Ballas, Pascal Vincent. NeurIPS, 2018.
7. [SOAP: Improving and Stabilizing Shampoo using Adam](https://arxiv.org/abs/2409.11321). Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade. Tech Report, 2024.
116 changes: 77 additions & 39 deletions distributed_shampoo/distributed_shampoo.py
@@ -58,9 +58,9 @@
PRECISION_CONFIG,
PrecisionConfig,
PRECONDITION_FREQUENCY,
PRECONDITIONER_COMPUTATION_CONFIG,
PREVIOUS_GRAD_SELECTOR,
RMSpropGraftingConfig,
ROOT_INV_CONFIG,
SGDGraftingConfig,
SHAMPOO_PRECONDITIONER_LIST,
ShampooPT2CompileConfig,
@@ -90,6 +90,7 @@
from distributed_shampoo.utils.shampoo_preconditioner_list import (
AdagradPreconditionerList,
DequantizePreconditionersContext,
EigenvalueCorrectedShampooPreconditionerList,
SGDPreconditionerList,
ShampooPreconditionerList,
)
@@ -100,7 +101,13 @@
)
from distributed_shampoo.utils.shampoo_utils import compress_list

from matrix_functions_types import DefaultEigenConfig, EigenConfig, RootInvConfig
from matrix_functions_types import (
DefaultEigenConfig,
EigenConfig,
EigenvalueCorrectionConfig,
PreconditionerComputationConfig,
RootInvConfig,
)
from torch.optim.optimizer import ParamsT, StateDict

logger: logging.Logger = logging.getLogger(__name__)
@@ -214,6 +221,16 @@ class DistributedShampoo(torch.optim.Optimizer):
particular tensor shape. Recommended to use `static` mode here for Shampoo.
More about dynamic shape: https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html
5. [EXPERIMENTAL] Eigenvalue correction: We can (approximately) correct the eigenvalues of Shampoo's preconditioner by accumulating a running
average of the squared gradient in the eigenbasis of Shampoo's preconditioner. This running average (with hyperparameter `betas[1]`) is
updated every iteration while the eigenbasis of Shampoo's preconditioner is only computed every `precondition_frequency` steps.
Alternatively, this can be seen as running Adam in the eigenbasis of Shampoo's preconditioner.
When setting `preconditioner_computation_config` as an instance of EigenvalueCorrectionConfig, there is typically no need to use learning
rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be
a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet.
Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial.
Args:
params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
lr (float): Learning rate. (Default: 1e-2)
@@ -228,13 +245,18 @@
dampening (float): Dampening parameter for momentum. (Default: 0.)
weight_decay (float): Weight decay (L2 penalty). (Default: 0.)
max_preconditioner_dim (int): Maximum preconditioner dimension. (Default: 1024)
precondition_frequency (int): Frequency for computing root inverse preconditioner. (Default: 1)
precondition_frequency (int): Frequency of updating all components of the preconditioner.
If preconditioner_computation_config is an instance of RootInvConfig, this is the update frequency of the root inverse of the preconditioner.
If preconditioner_computation_config is an instance of EigenvalueCorrectionConfig, this is the update frequency of the eigenbasis of the preconditioner.
(Default: 1)
start_preconditioning_step (int): Iteration to start computing inverse preconditioner. If -1, uses
the same value as precondition_frequency. (Default: -1)
inv_root_override (int, Sequence[int]): Inverse root to use in Shampoo. If a list [l1, l2, ..., lp], then we will
use -1 / l1 for 1-D tensors (vectors), -1 / l2 for 2-D tensors (matrices), and so on. If the order of the
tensor exceeds the length of the list, reverts to the default value. If 0 is used, uses the default inverse
root -1 / (2 * o), where o is the order of the tensor. (Default: 0)
root -1 / (2 * o), where o is the order of the tensor. If preconditioner_computation_config is an instance of
EigenvalueCorrectionConfig, the default is -1 / 2.
(Default: 0)
exponent_multiplier (float | None): **DEPRECATING** Number that multiplies the numerator of the inverse root, i.e., eta where the
exponent is -eta / (2 * p). (Default: None)
use_nesterov (bool): Flag for using Nesterov momentum. (default: False)
@@ -259,7 +281,10 @@
3. Otherwise, re-uses previous inverse factor matrix when both root inverse computations fail.
track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes.
(Default: False)
root_inv_config (RootInvConfig): Configuration for root inverse computation. (Default: DefaultEigenConfig)
preconditioner_computation_config (PreconditionerComputationConfig): Configuration for preconditioner computation.
If this field is an instance RootInvConfig, Shampoo uses the root inverse of the preconditioner.
If this field is an instance EigenvalueCorrectionConfig Shampoo uses corrected the eigenvalues/running Adam in the eigenbasis of preconditioner.
(Default: DefaultEigenConfig)
"""

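As a reading aid for the `inv_root_override` rule documented above, the following hypothetical helper mirrors only the documented behavior; it is not code from this commit.

```python
from typing import Sequence, Union

def resolve_inverse_exponent(order: int, inv_root_override: Union[int, Sequence[int]] = 0) -> float:
    """Hypothetical illustration of the documented `inv_root_override` semantics."""
    default_root = 2 * order
    if isinstance(inv_root_override, int):
        root = default_root if inv_root_override == 0 else inv_root_override
    else:
        # A list [l1, l2, ..., lp] assigns l1 to 1-D blocks, l2 to 2-D blocks, and so on;
        # tensor orders beyond the length of the list fall back to the default root.
        root = inv_root_override[order - 1] if order <= len(inv_root_override) else default_root
    return -1.0 / root

assert resolve_inverse_exponent(order=2) == -1.0 / 4                             # matrices: -1/4 by default
assert resolve_inverse_exponent(order=3, inv_root_override=[2, 4]) == -1.0 / 6   # beyond the list: default
```
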
@@ -290,7 +315,7 @@ def __init__(
precision_config: Optional[PrecisionConfig] = None,
use_protected_eigh: bool = True,
track_root_inv_residuals: bool = False,
root_inv_config: RootInvConfig = DefaultEigenConfig,
preconditioner_computation_config: PreconditionerComputationConfig = DefaultEigenConfig,
) -> None:
# Hyperparameter checks.
if not lr >= 0.0:
@@ -404,17 +429,28 @@ def __init__(
"Both preconditioner_dtype and precision_config are provided. Please use only precision_config as preconditioner_dtype is deprecated."
)

if (
not isinstance(preconditioner_computation_config, RootInvConfig)
) and track_root_inv_residuals:
raise ValueError(
f"{track_root_inv_residuals=} has to be set to False when {preconditioner_computation_config=} is not an instance of RootInvConfig."
)

# Create default precision config if it is not provided.
if precision_config is None:
precision_config = PrecisionConfig()

# Set exponent multiplier if this is not provided.
if isinstance(root_inv_config, EigenConfig) and exponent_multiplier is not None:
if (
isinstance(preconditioner_computation_config, EigenConfig)
and exponent_multiplier is not None
):
logger.warning(
f"{exponent_multiplier=} is deprecating. Please consider using EigenConfig.exponent_multiplier directly and setting exponent_multipler=None instead in the future."
)
root_inv_config = dataclasses.replace(
root_inv_config, exponent_multiplier=exponent_multiplier
preconditioner_computation_config = dataclasses.replace(
preconditioner_computation_config,
exponent_multiplier=exponent_multiplier,
)

super().__init__(
Expand All @@ -437,7 +473,7 @@ def __init__(
GRAFTING_CONFIG: grafting_config,
USE_MERGE_DIMS: use_merge_dims,
PRECISION_CONFIG: precision_config,
ROOT_INV_CONFIG: root_inv_config,
PRECONDITIONER_COMPUTATION_CONFIG: preconditioner_computation_config,
},
)

@@ -508,17 +544,25 @@ def _instantiate_shampoo_preconditioner_list(
for state_lists, group in zip(
self._per_group_state_lists, self.param_groups, strict=True
):
state_lists[SHAMPOO_PRECONDITIONER_LIST] = ShampooPreconditionerList(
state_lists[SHAMPOO_PRECONDITIONER_LIST] = (
EigenvalueCorrectedShampooPreconditionerList
if isinstance(
group[PRECONDITIONER_COMPUTATION_CONFIG], EigenvalueCorrectionConfig
)
else ShampooPreconditionerList
)(
block_list=state_lists[DISTRIBUTOR].global_blocked_params,
state=self.state,
block_info_list=state_lists[DISTRIBUTOR].global_block_info_list,
distributor_selector=state_lists[DISTRIBUTOR].distributor_selector,
root_inv_config=group[ROOT_INV_CONFIG],
preconditioner_computation_config=group[
PRECONDITIONER_COMPUTATION_CONFIG
],
precision_config=group[PRECISION_CONFIG],
beta2=group[BETAS][1],
epsilon=group[EPSILON],
inv_root_override=group[INV_ROOT_OVERRIDE],
use_bias_correction=group[USE_BIAS_CORRECTION],
precision_config=group[PRECISION_CONFIG],
use_protected_eigh=use_protected_eigh,
)

@@ -755,6 +799,7 @@ def _mask_state_lists(state_lists: Dict[str, Any], group: Dict[str, Any]) -> Non
)

@torch.no_grad()
@torch.compiler.disable
def _compute_and_log_root_inverse_residuals(
self,
) -> None:
@@ -806,16 +851,6 @@ def _compute_and_log_root_inverse_residuals(
f"{torch.quantile(relative_residuals, quantiles, interpolation='nearest')}"
)

@torch.no_grad()
@torch.compiler.disable
def _compute_root_inverse(
self, state_lists: Dict[str, Any], compute_root_inverse: bool
) -> None:
if compute_root_inverse:
state_lists[SHAMPOO_PRECONDITIONER_LIST].compute_root_inverse()
if self._track_root_inv_residuals:
self._compute_and_log_root_inverse_residuals()

@torch.no_grad()
@torch.compiler.disable
def _precondition_and_grafting(
@@ -881,13 +916,17 @@ def _update_preconditioners(
self,
state_lists: Dict[str, Any],
step: torch.Tensor,
perform_amortized_computation: bool,
grafting_config_not_none: bool,
) -> None:
# Update Shampoo and grafting preconditioners / factor matrices.
# Update Shampoo and grafting preconditioners.
state_lists[SHAMPOO_PRECONDITIONER_LIST].update_preconditioners(
masked_grad_list=state_lists[MASKED_BLOCKED_GRADS],
step=step,
perform_amortized_computation=perform_amortized_computation,
)
if perform_amortized_computation and self._track_root_inv_residuals:
self._compute_and_log_root_inverse_residuals()
if grafting_config_not_none:
state_lists[GRAFTING_PRECONDITIONER_LIST].update_preconditioners(
masked_grad_list=state_lists[MASKED_BLOCKED_GRADS],
@@ -1005,7 +1044,7 @@ def _per_group_step_impl(
momentum_param: float,
dampening: float,
grafting_config_not_none: bool,
compute_root_inverse: bool,
perform_amortized_computation: bool,
use_decoupled_weight_decay: bool,
use_bias_correction: bool,
use_grafting_method: bool,
@@ -1028,23 +1067,23 @@
if grafting_config_not_none
else contextlib.nullcontext()
):
# Update Shampoo and grafting preconditioners / factor matrices.
# Example for AdaGrad accumulation:
# Update Shampoo and grafting preconditioners.
# Example for AdaGrad accumulation:
# 1. Update factor matrices/grafting preconditioners.
# L <- L + G * G^T
# R <- R + G^T * G
# V <- V + G^2 (element-wise)
# (and similar)
self._update_preconditioners(
state_lists,
step,
grafting_config_not_none,
)

# Compute matrix root inverse.
# 2. Compute root inverse if necessary.
# L_inv <- L ** (-1/4)
# R_inv <- R ** (-1/4)
# (and similar)
self._compute_root_inverse(state_lists, compute_root_inverse)
# (and similar);
self._update_preconditioners(
state_lists=state_lists,
step=step,
perform_amortized_computation=perform_amortized_computation,
grafting_config_not_none=grafting_config_not_none,
)

# Compute filtered gradient or EMA of the gradients if beta1 > 0 and beta3 > 0.
# Note that we use two beta factors here akin to Lion.
@@ -1157,8 +1196,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
momentum_param = group[MOMENTUM]
dampening = group[DAMPENING]
grafting_config_not_none = group[GRAFTING_CONFIG] is not None
# Check compute root inverse or not for preconditioner
compute_root_inverse = (
perform_amortized_computation = (
step.item() % group[PRECONDITION_FREQUENCY] == 0
and step.item() > group[START_PRECONDITIONING_STEP]
or step.item() == group[START_PRECONDITIONING_STEP]
@@ -1182,7 +1220,7 @@
momentum_param,
dampening,
grafting_config_not_none,
compute_root_inverse,
perform_amortized_computation,
use_decoupled_weight_decay,
use_bias_correction,
use_grafting_method,