
Refactor Shampoo types
runame committed Dec 6, 2024
1 parent 55f1413 commit 9291d15
Showing 7 changed files with 171 additions and 79 deletions.
60 changes: 35 additions & 25 deletions distributed_shampoo/distributed_shampoo.py
@@ -24,8 +24,10 @@
BETAS,
DAMPENING,
DDPShampooConfig,
DefaultShampooConfig,
DistributedConfig,
DISTRIBUTOR,
EigenvalueCorrectedShampooPreconditionerConfig,
EPSILON,
FILTERED_GRAD,
FILTERED_GRAD_LIST,
@@ -49,10 +51,12 @@
PrecisionConfig,
PRECONDITION_FREQUENCY,
PRECONDITIONER_COMPUTATION_CONFIG,
PreconditionerComputationConfig,
PREVIOUS_GRAD_SELECTOR,
RMSpropGraftingConfig,
SGDGraftingConfig,
SHAMPOO_PRECONDITIONER_LIST,
ShampooPreconditionerConfig,
ShampooPT2CompileConfig,
START_PRECONDITIONING_STEP,
STEP,
@@ -91,13 +95,7 @@
)
from distributed_shampoo.utils.shampoo_utils import compress_list

from matrix_functions_types import (
DefaultEigenConfig,
EigenConfig,
EigenvalueCorrectionConfig,
PreconditionerComputationConfig,
RootInvConfig,
)
from matrix_functions_types import EigenConfig, RootInvConfig
from torch.optim.optimizer import ParamsT, StateDict

logger: logging.Logger = logging.getLogger(__name__)
@@ -216,7 +214,7 @@ class DistributedShampoo(torch.optim.Optimizer):
updated every iteration while the eigenbasis of Shampoo's preconditioner is only computed every `precondition_frequency` steps.
Alternatively, this can be seen as running Adam in the eigenbasis of Shampoo's preconditioner, also known as SOAP.
When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectionConfig`, there is typically no need to use learning
When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectedShampooPreconditionerConfig`, there is typically no need to use learning
rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be
a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet.
Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial.
@@ -236,16 +234,16 @@ class DistributedShampoo(torch.optim.Optimizer):
weight_decay (float): Weight decay (L2 penalty). (Default: 0.)
max_preconditioner_dim (int): Maximum preconditioner dimension. (Default: 1024)
precondition_frequency (int): Frequency of updating all components of the preconditioner.
If this field is an instance of RootInvConfig, this is the update frequency of the root inverse of the preconditioner.
If this field is an instance of EigenvalueCorrectionConfig, this is the update frequency of the eigenbasis of the preconditioner.
If this field is an instance of ShampooPreconditionerConfig, this is the update frequency of the root inverse of the preconditioner.
If this field is an instance of EigenvalueCorrectedShampooPreconditionerConfig, this is the update frequency of the eigenbasis of the preconditioner.
(Default: 1)
start_preconditioning_step (int): Iteration to start computing inverse preconditioner. If -1, uses
the same value as precondition_frequency. (Default: -1)
inv_root_override (int, Sequence[int]): Inverse root to use in Shampoo. If a list [l1, l2, ..., lp], then we will
use -1 / l1 for 1-D tensors (vectors), -1 / l2 for 2-D tensors (matrices), and so on. If the order of the
tensor exceeds the length of the list, reverts to the default value. If 0 is used, uses the default inverse
root -1 / (2 * o), where o is the order of the tensor. If preconditioner_computation_config is an instance of
EigenvalueCorrectionConfig, the default is -1 / 2.
EigenvalueCorrectedShampooPreconditionerConfig, the default is -1 / 2.
(Default: 0)
exponent_multiplier (float | None): **DEPRECATING** Number to be multiplied to the numerator of the inverse root, i.e., eta where the
exponent is -eta / (2 * p). (Default: None)
@@ -272,8 +270,8 @@ class DistributedShampoo(torch.optim.Optimizer):
track_root_inv_residuals (bool): Track errors and residuals of root inverse. For debugging purposes.
(Default: False)
preconditioner_computation_config (PreconditionerComputationConfig): Configuration for preconditioner computation.
If this field is an instance of RootInvConfig, Shampoo uses the root inverse of the preconditioner.
If this field is an instance of EigenvalueCorrectionConfig, Shampoo uses the corrected eigenvalues / runs Adam in the eigenbasis of the preconditioner.
If this field is an instance of ShampooPreconditionerConfig, Shampoo uses the root inverse of the preconditioner.
If this field is an instance of EigenvalueCorrectedShampooPreconditionerConfig, Shampoo uses the corrected eigenvalues / runs Adam in the eigenbasis of the preconditioner.
(Default: DefaultEigenConfig)
"""
@@ -305,7 +303,7 @@ def __init__(
precision_config: PrecisionConfig | None = None,
use_protected_eigh: bool = True,
track_root_inv_residuals: bool = False,
preconditioner_computation_config: PreconditionerComputationConfig = DefaultEigenConfig,
preconditioner_computation_config: PreconditionerComputationConfig = DefaultShampooConfig,
) -> None:
# Hyperparameter checks.
if not lr >= 0.0:
@@ -419,11 +417,14 @@ def __init__(
"Both preconditioner_dtype and precision_config are provided. Please use only precision_config as preconditioner_dtype is deprecated."
)

amortized_computation_config = (
preconditioner_computation_config.amortized_computation_config
)
if (
not isinstance(preconditioner_computation_config, RootInvConfig)
not isinstance(amortized_computation_config, RootInvConfig)
) and track_root_inv_residuals:
raise ValueError(
f"{track_root_inv_residuals=} has to be set to False when {preconditioner_computation_config=} is not an instance of RootInvConfig."
f"{track_root_inv_residuals=} has to be set to False when {amortized_computation_config=} is not an instance of RootInvConfig."
)

# Create default precision config if it is not provided.
@@ -432,14 +433,14 @@ def __init__(

# Set exponent multiplier if this is not provided.
if (
isinstance(preconditioner_computation_config, EigenConfig)
isinstance(amortized_computation_config, EigenConfig)
and exponent_multiplier is not None
):
logger.warning(
f"{exponent_multiplier=} is deprecating. Please consider using EigenConfig.exponent_multiplier directly and setting exponent_multipler=None instead in the future."
)
preconditioner_computation_config = dataclasses.replace(
preconditioner_computation_config,
amortized_computation_config = dataclasses.replace(
amortized_computation_config,
exponent_multiplier=exponent_multiplier,
)

@@ -534,13 +535,22 @@ def _instantiate_shampoo_preconditioner_list(
for state_lists, group in zip(
self._per_group_state_lists, self.param_groups, strict=True
):
state_lists[SHAMPOO_PRECONDITIONER_LIST] = (
EigenvalueCorrectedShampooPreconditionerList
if isinstance(
group[PRECONDITIONER_COMPUTATION_CONFIG], EigenvalueCorrectionConfig
if (
type(group[PRECONDITIONER_COMPUTATION_CONFIG])
is ShampooPreconditionerConfig
):
preconditioner_list_cls = ShampooPreconditionerList
elif (
type(group[PRECONDITIONER_COMPUTATION_CONFIG])
is EigenvalueCorrectedShampooPreconditionerConfig
):
preconditioner_list_cls = EigenvalueCorrectedShampooPreconditionerList # type: ignore[assignment]
else:
raise NotImplementedError(
f"{group[PRECONDITIONER_COMPUTATION_CONFIG]=} not supported!"
)
else ShampooPreconditionerList
)(

state_lists[SHAMPOO_PRECONDITIONER_LIST] = preconditioner_list_cls(
block_list=state_lists[DISTRIBUTOR].global_blocked_params,
state=self.state,
block_info_list=state_lists[DISTRIBUTOR].global_block_info_list,
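As context for the API change above, here is a minimal usage sketch, based on the updated docstring, of passing the new config objects to the optimizer. The top-level import path, the grafting choices, and the hyperparameter values are illustrative assumptions, not part of this commit.

import torch

from distributed_shampoo import (
    DefaultEigenvalueCorrectedShampooConfig,
    DefaultShampooConfig,
    DistributedShampoo,
    SGDGraftingConfig,
)

model = torch.nn.Linear(10, 2)

# Classic Shampoo: DefaultShampooConfig wraps the eigendecomposition-based
# root-inverse computation; learning-rate grafting is kept, as is typical here.
shampoo = DistributedShampoo(
    model.parameters(),
    lr=0.01,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    weight_decay=1e-5,
    precondition_frequency=50,
    grafting_config=SGDGraftingConfig(),
    preconditioner_computation_config=DefaultShampooConfig,
)

# Eigenvalue-corrected Shampoo (SOAP-like): per the docstring above, grafting is
# typically unnecessary and Adam-style hyperparameters are a good starting point.
soap = DistributedShampoo(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.999),
    epsilon=1e-8,
    weight_decay=1e-5,
    precondition_frequency=50,
    grafting_config=None,
    preconditioner_computation_config=DefaultEigenvalueCorrectedShampooConfig,
)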
21 changes: 13 additions & 8 deletions distributed_shampoo/examples/trainer_utils.py
@@ -24,16 +24,17 @@
CommunicationDType,
CoupledHigherOrderConfig,
CoupledNewtonConfig,
DefaultEigenvalueCorrectedShampooConfig,
DefaultShampooConfig,
DefaultSOAPConfig,
DistributedConfig,
DistributedShampoo,
EigenConfig,
EighEigenvalueCorrectionConfig,
GraftingConfig,
PrecisionConfig,
PreconditionerComputationConfig,
QREigenvalueCorrectionConfig,
RMSpropGraftingConfig,
SGDGraftingConfig,
ShampooPreconditionerConfig,
)
from distributed_shampoo.examples.convnet import ConvNet

@@ -541,27 +542,31 @@ def instantiate_preconditioner_computation_config(
preconditioner_computation_type: PreconditionerComputationType,
) -> PreconditionerComputationConfig:
if preconditioner_computation_type == PreconditionerComputationType.EIGEN_ROOT_INV:
return EigenConfig()
return DefaultShampooConfig
elif (
preconditioner_computation_type
== PreconditionerComputationType.COUPLED_NEWTON_ROOT_INV
):
return CoupledNewtonConfig()
return ShampooPreconditionerConfig(
amortized_computation_config=CoupledNewtonConfig(),
)
elif (
preconditioner_computation_type
== PreconditionerComputationType.COUPLED_HIGHER_ORDER_ROOT_INV
):
return CoupledHigherOrderConfig()
return ShampooPreconditionerConfig(
amortized_computation_config=CoupledHigherOrderConfig(),
)
elif (
preconditioner_computation_type
== PreconditionerComputationType.EIGH_EIGENVALUE_CORRECTION
):
return EighEigenvalueCorrectionConfig()
return DefaultEigenvalueCorrectedShampooConfig
elif (
preconditioner_computation_type
== PreconditionerComputationType.QR_EIGENVALUE_CORRECTION
):
return QREigenvalueCorrectionConfig()
return DefaultSOAPConfig
else:
raise ValueError(
f"Invalid PreconditionerComputationType {preconditioner_computation_type}!"
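The change in this helper reflects the new two-level structure: matrix-function configs such as CoupledNewtonConfig are no longer returned directly but are wrapped in a preconditioner config, while the eigenvalue-corrected cases map to prebuilt defaults. A brief sketch of the pattern, assuming these names are exported by the distributed_shampoo package as the imports above suggest:

from distributed_shampoo import (
    CoupledNewtonConfig,
    DefaultSOAPConfig,
    EigenvalueCorrectedShampooPreconditionerConfig,
    ShampooPreconditionerConfig,
)

# Before this commit the matrix-function config itself served as the
# preconditioner computation config, e.g. `return CoupledNewtonConfig()`.
# After the refactor it is nested under amortized_computation_config:
newton_shampoo_config = ShampooPreconditionerConfig(
    amortized_computation_config=CoupledNewtonConfig(),
)

# The eigenvalue-corrected branches now return ready-made defaults; DefaultSOAPConfig
# is an EigenvalueCorrectedShampooPreconditionerConfig built around a QR-based
# eigenvector computation.
assert isinstance(DefaultSOAPConfig, EigenvalueCorrectedShampooPreconditionerConfig)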
@@ -17,13 +17,13 @@

import torch
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import (
DefaultEigenvalueCorrectedShampooConfig,
DefaultSOAPConfig,
)
from distributed_shampoo.tests.shampoo_test_utils import (
compare_two_optimizers_on_weight_and_loss,
)
from matrix_functions_types import (
DefaultEighEigenvalueCorrectionConfig,
QREigenvalueCorrectionConfig,
)
from torch.optim.adagrad import Adagrad
from torch.optim.adam import Adam
from torch.optim.adamw import AdamW
@@ -54,7 +54,7 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None:
(torch.device("cpu"),) + (torch.device("cuda"),)
if torch.cuda.is_available()
else (),
(DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()),
(DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig),
):
optim_factory = partial(
DistributedShampooEigenvalueCorrectionTest._optim_factory,
@@ -93,7 +93,7 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None:
(torch.device("cpu"),) + (torch.device("cuda"),)
if torch.cuda.is_available()
else (),
(DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()),
(DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig),
):
optim_factory = partial(
DistributedShampooEigenvalueCorrectionTest._optim_factory,
@@ -134,7 +134,7 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None:
(torch.device("cpu"),) + (torch.device("cuda"),)
if torch.cuda.is_available()
else (),
(DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()),
(DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig),
):
optim_factory = partial(
DistributedShampooEigenvalueCorrectionTest._optim_factory,
@@ -175,7 +175,7 @@ def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None:
(torch.device("cpu"),) + (torch.device("cuda"),)
if torch.cuda.is_available()
else (),
(DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()),
(DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig),
):
optim_factory = partial(
DistributedShampooEigenvalueCorrectionTest._optim_factory,
60 changes: 60 additions & 0 deletions distributed_shampoo/shampoo_types.py
@@ -13,6 +13,15 @@
import torch

from commons import AbstractDataclass

from matrix_functions_types import (
DefaultEigenConfig,
DefaultEighConfig,
EigenvectorConfig,
MatrixFunctionConfig,
QRConfig,
RootInvConfig,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import ShardingStrategy
from torch.nn.parameter import Parameter
@@ -71,6 +80,57 @@ class PreconditionerValueError(ValueError):


###### DATACLASSES ######
@dataclass(init=False)
class PreconditionerComputationConfig(AbstractDataclass):
    """Configuration for preconditioner computation in DistributedShampoo.

    Args:
        amortized_computation_config (MatrixFunctionConfig): Configuration for the amortized computation, e.g., inverse-root or eigenvector computation.

    """

    amortized_computation_config: MatrixFunctionConfig


@dataclass(kw_only=True)
class ShampooPreconditionerConfig(PreconditionerComputationConfig):
    """Configuration for Shampoo preconditioner computation.

    Args:
        amortized_computation_config (RootInvConfig): Configuration for the inverse-root computation.

    """

    amortized_computation_config: RootInvConfig


DefaultShampooConfig = ShampooPreconditionerConfig(
    amortized_computation_config=DefaultEigenConfig
)


@dataclass(kw_only=True)
class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerComputationConfig):
    """Configuration for eigenvalue-corrected Shampoo/SOAP preconditioner computation.

    Args:
        amortized_computation_config (EigenvectorConfig): Configuration for the eigenvector computation.

    """

    amortized_computation_config: EigenvectorConfig


DefaultEigenvalueCorrectedShampooConfig = (
    EigenvalueCorrectedShampooPreconditionerConfig(
        amortized_computation_config=DefaultEighConfig,
    )
)
DefaultSOAPConfig = EigenvalueCorrectedShampooPreconditionerConfig(
    amortized_computation_config=QRConfig(),
)


@dataclass
class FSDPParameterMetadata:
"""FSDP Metadata for a parameter.
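To make the relationship between the new types concrete: both concrete configs derive from the abstract PreconditionerComputationConfig, while the optimizer dispatches on the exact dataclass type (see the `type(...) is ...` checks in distributed_shampoo.py above). A small sketch, using the module path shown in the test imports; the printed class names are assumptions based on the defaults defined in this file.

from distributed_shampoo.shampoo_types import (
    DefaultEigenvalueCorrectedShampooConfig,
    DefaultShampooConfig,
    DefaultSOAPConfig,
    EigenvalueCorrectedShampooPreconditionerConfig,
    PreconditionerComputationConfig,
    ShampooPreconditionerConfig,
)

# Both defaults share the abstract base class ...
assert isinstance(DefaultShampooConfig, PreconditionerComputationConfig)
assert isinstance(DefaultEigenvalueCorrectedShampooConfig, PreconditionerComputationConfig)

# ... but the optimizer checks the exact type, so only these two concrete dataclasses
# are accepted; DefaultSOAPConfig is another instance of the eigenvalue-corrected type.
assert type(DefaultShampooConfig) is ShampooPreconditionerConfig
assert type(DefaultSOAPConfig) is EigenvalueCorrectedShampooPreconditionerConfig

# What differs between the defaults is the nested amortized computation config:
print(type(DefaultShampooConfig.amortized_computation_config).__name__)  # e.g. EigenConfig
print(type(DefaultSOAPConfig.amortized_computation_config).__name__)     # e.g. QRConfig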