Open-sourced update on 02/20/2024 (#84)
Summary:

1. Remove `exponent_multiplier` from the hyperparameters; it is now set via `EigenConfig` (see the migration sketch after this list).
2. Add `eigen_decomp_offload_device` to `_matrix_inverse_root_eigen()`.
3. Groom the `dataclass` docstrings.
4. Remove the PyTorch 2.0 compile check on CUDA availability: it limited compilation to NVIDIA GPUs even when other accelerators are available, so the config check is now a purely algorithmic setting check.
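
For context, a minimal migration sketch (not part of this commit; the import paths, model, and multiplier value are assumptions for illustration): the removed `exponent_multiplier` keyword is instead passed through `EigenConfig` inside the preconditioner config, mirroring the `trainer_utils.py` change below.

```python
# Sketch only: migrating off the removed top-level keyword.
# Import paths are assumed from this repo's layout; adjust to your setup.
import torch

from distributed_shampoo import DistributedShampoo, ShampooPreconditionerConfig
from matrix_functions_types import EigenConfig

model = torch.nn.Linear(8, 4)

# Before this commit (keyword now removed):
#   DistributedShampoo(model.parameters(), lr=0.01, exponent_multiplier=2.0, ...)

# After this commit: set the multiplier on EigenConfig instead.
optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.01,
    start_preconditioning_step=1,
    preconditioner_config=ShampooPreconditionerConfig(
        amortized_computation_config=EigenConfig(exponent_multiplier=2.0),
    ),
)
```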

Differential Revision: D69944184
tsunghsienlee authored and facebook-github-bot committed Feb 21, 2025
1 parent 23e8606 commit dfb3c4b
Showing 11 changed files with 144 additions and 104 deletions.
34 changes: 4 additions & 30 deletions distributed_shampoo/distributed_shampoo.py
@@ -7,7 +7,6 @@
"""

import dataclasses
import logging
from collections.abc import Callable, Iterator, Sequence
from copy import deepcopy
@@ -91,7 +90,6 @@
)
from distributed_shampoo.utils.shampoo_utils import compress_list

from matrix_functions_types import EigenConfig
from torch.optim.optimizer import ParamsT, StateDict

logger: logging.Logger = logging.getLogger(__name__)
@@ -271,8 +269,6 @@ class DistributedShampoo(torch.optim.Optimizer):
root -1 / (2 * o), where o is the order of the tensor. If preconditioner_config is an instance of
EigenvalueCorrectedShampooPreconditionerConfig, the default is -1 / 2.
(Default: 0)
exponent_multiplier (float | None): **DEPRECATING** Number to be multiplied to the numerator of the inverse root, i.e., eta where the
exponent is -eta / (2 * p). (Default: None)
use_nesterov (bool): Flag for using Nesterov momentum. (default: False)
use_bias_correction (bool): Flag for using bias correction. (Default: True)
use_decoupled_weight_decay (bool): Flag for using AdamW-style decoupled weight decay. (Default: True)
@@ -306,7 +302,6 @@ def __init__(
precondition_frequency: int = 1,
start_preconditioning_step: int = -1,
inv_root_override: int | Sequence[int] = 0,
exponent_multiplier: float | None = None,
use_nesterov: bool = False,
use_bias_correction: bool = True,
use_decoupled_weight_decay: bool = True,
@@ -390,28 +385,6 @@ def __init__(
"Continuing without using momentum or Nesterov acceleration..."
)

# Provide error for system Pytorch compile availability
if shampoo_pt2_compile_config is not None and not torch.cuda.is_available():
raise ValueError(
"Backend does NOT support Pytorch 2.0 compile. Switch to shampoo_pt2_compile_config=None."
)

amortized_computation_config = (
preconditioner_config.amortized_computation_config
)
# Set exponent multiplier if this is not provided.
if (
isinstance(amortized_computation_config, EigenConfig)
and exponent_multiplier is not None
):
logger.warning(
f"{exponent_multiplier=} is deprecating. Please consider using EigenConfig.exponent_multiplier directly and setting exponent_multipler=None instead in the future."
)
amortized_computation_config = dataclasses.replace(
amortized_computation_config,
exponent_multiplier=exponent_multiplier,
)

super().__init__(
params,
{
@@ -561,7 +534,9 @@ def _instantiate_grafting(self) -> None:

@torch.no_grad()
def _instantiate_steps(self) -> None:
for state_lists in self._per_group_state_lists:
for state_lists, group in zip(
self._per_group_state_lists, self.param_groups, strict=True
):
assert (
len(state_lists[DISTRIBUTOR].local_block_info_list) > 0
), "There is no params in your param_group. Please check the instantiation of DistributedShampoo "
@@ -573,8 +548,7 @@ def _instantiate_steps(self) -> None:

# In order to ensure that the step counter is checkpointed correctly, we store it
# as a tensor (which is replicated across all devices) under the first parameter's state.
block_info = state_lists[DISTRIBUTOR].local_block_info_list[0]
self.state[block_info.param][STEP] = state_lists[STEP]
self.state[group[PARAMS][0]][STEP] = state_lists[STEP]

@torch.no_grad()
def _instantiate_momentum(self) -> None:
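
As an aside on the `_instantiate_steps` hunk above: the replicated step counter is now stored under the param group's first parameter rather than the first block's parameter. A toy illustration of the general pattern (hypothetical `TinyOptimizer`, not code from this repo) of keeping the step as a tensor under the first parameter's state so it is checkpointed via `state_dict()`:

```python
import torch


class TinyOptimizer(torch.optim.Optimizer):
    """Hypothetical optimizer; only illustrates the step-under-first-param pattern."""

    def __init__(self, params, lr: float = 0.1):
        super().__init__(params, defaults={"lr": lr})
        for group in self.param_groups:
            # Stored as a tensor so it is captured by state_dict()/load_state_dict().
            self.state[group["params"][0]]["step"] = torch.tensor(0)

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            self.state[group["params"][0]]["step"] += 1  # in-place tensor update
            for p in group["params"]:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group["lr"])


model = torch.nn.Linear(4, 2)
opt = TinyOptimizer(model.parameters())
model(torch.randn(3, 4)).sum().backward()
opt.step()
print(opt.state_dict()["state"])  # the step tensor appears under the first parameter
```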
18 changes: 14 additions & 4 deletions distributed_shampoo/examples/trainer_utils.py
@@ -25,10 +25,10 @@
CoupledHigherOrderConfig,
CoupledNewtonConfig,
DefaultEigenvalueCorrectedShampooConfig,
DefaultShampooConfig,
DefaultSOAPConfig,
DistributedConfig,
DistributedShampoo,
EigenConfig,
GraftingConfig,
PreconditionerConfig,
RMSpropGraftingConfig,
@@ -424,7 +424,6 @@ def instantiate_optimizer(
precondition_frequency=precondition_frequency,
start_preconditioning_step=start_preconditioning_step,
inv_root_override=inv_root_override,
exponent_multiplier=exponent_multiplier,
use_nesterov=use_nesterov,
use_bias_correction=use_bias_correction,
use_decoupled_weight_decay=use_decoupled_weight_decay,
@@ -435,7 +434,8 @@
distributed_config=distributed_config,
preconditioner_dtype=preconditioner_dtype.value,
preconditioner_config=instantiate_preconditioner_config(
preconditioner_computation_type
preconditioner_computation_type=preconditioner_computation_type,
exponent_multiplier=exponent_multiplier,
),
) # type: ignore[assignment]
else:
@@ -476,9 +476,19 @@ def instantiate_grafting_config(

def instantiate_preconditioner_config(
preconditioner_computation_type: PreconditionerComputationType,
exponent_multiplier: float,
) -> PreconditionerConfig:
assert (
exponent_multiplier == 1.0
or preconditioner_computation_type
== PreconditionerComputationType.EIGEN_ROOT_INV
), "Exponent multiplier is only supported for EIGH root inverse computation."
if preconditioner_computation_type == PreconditionerComputationType.EIGEN_ROOT_INV:
return DefaultShampooConfig
return ShampooPreconditionerConfig(
amortized_computation_config=EigenConfig(
exponent_multiplier=exponent_multiplier
)
)
elif (
preconditioner_computation_type
== PreconditionerComputationType.COUPLED_NEWTON_ROOT_INV
26 changes: 14 additions & 12 deletions distributed_shampoo/shampoo_types.py
@@ -83,7 +83,7 @@ class PreconditionerValueError(ValueError):
class PreconditionerConfig(AbstractDataclass):
"""Configuration for preconditioner computation in DistributedShampoo.
Args:
Attributes:
amortized_computation_config (MatrixFunctionConfig): Configuration for the amortized computation, e.g., inverse-root or eigenvector computation.
num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3)
@@ -103,8 +103,9 @@ def __post_init__(self) -> None:
class ShampooPreconditionerConfig(PreconditionerConfig):
"""Configuration for Shampoo preconditioner computation.
Args:
Attributes:
amortized_computation_config (RootInvConfig): Configuration for the inverse-root computation. (Default: DefaultEigenConfig)
num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3)
"""

@@ -120,9 +121,10 @@ class ShampooPreconditionerConfig(PreconditionerConfig):
class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig):
"""Configuration for eigenvalue-corrected Shampoo/SOAP preconditioner computation.
Args:
Attributes:
amortized_computation_config (EigenvectorConfig): Configuration for the eigenvector computation.
(Default: DefaultEighEigenvectorConfig)
num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3)
"""

@@ -143,7 +145,7 @@ class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig):
class FSDPParameterMetadata:
"""FSDP Metadata for a parameter.
Args:
Attributes:
fqn (str): Fully qualified name of the parameter.
shape (torch.Size): Shape of the parameter.
numel (int): Number of elements in the parameter.
@@ -172,7 +174,7 @@ class DDPShampooConfig(DistributedConfig):
Enables distributed computation and optimizer states (like ZeRO-1) via DTensor for Shampoo.
Args:
Attributes:
communication_dtype (CommunicationDType): Data type for communication between ranks. (Default: DEFAULT)
num_trainers_per_group (int): Number of GPUs per distributed process group for distributed computation/memory.
If num_trainers_per_group = -1 is used, then defaults to using the LOCAL_WORLD_SIZE. (Default: -1)
@@ -192,7 +194,7 @@ class FSDPShampooConfig(DistributedConfig):
Passes in additional metadata necessary to run FSDP Shampoo.
Args:
Attributes:
param_to_metadata (dict[Parameter, FSDPParameterMetadata]): Dictionary mapping parameter to its metadata from FSDP.
"""
@@ -207,7 +209,7 @@ class HSDPShampooConfig(FSDPShampooConfig, DDPShampooConfig):
Enables distributed computation and optimizer states (like ZeRO-1) via DTensor for Shampoo across ranks with shared
parameters between different HSDP process groups.
Args:
Attributes:
device_mesh (torch.distributed.device_mesh.DeviceMesh): A 2D device mesh that specifies the layout of the numbers of
shard and replicate dimensions.
param_to_metadata (dict[Parameter, FSDPParameterMetadata]): Dictionary mapping parameter to its metadata from HSDP.
@@ -238,7 +240,7 @@ class HybridShardShampooConfig(FullyShardShampooConfig, DDPShampooConfig):
Enables distributed computation and optimizer states (like ZeRO-1) via DTensor for Shampoo across ranks with shared
parameters between different Hybrid Shard process groups.
Args:
Attributes:
device_mesh (torch.distributed.device_mesh.DeviceMesh): Device mesh for Hybrid Shard.
communication_dtype (CommunicationDType): Data type for communication between ranks. (Default: DEFAULT)
num_trainers_per_group (int): Number of GPUs per distributed process group for distributed computation/memory.
@@ -259,7 +261,7 @@ class ShampooPT2CompileConfig:
Enables Shampoo pytorch compilation with configure to speed up model training.
For more details: https://pytorch.org/get-started/pytorch-2.0/
Args:
Attributes:
pytorch_compile_backend (str): The backend for PT2 compilation. More info about PT2 backends:
https://pytorch.org/docs/stable/torch.compiler.html (Default: inductor)
enable_shampoo_pt2_dynamic_shape (bool | None): Compile Shampoo in static, dynamic or auto-dynamic shape mode (Default: False).
@@ -291,7 +293,7 @@ class SGDGraftingConfig(GraftingConfig):
class AdaGradGraftingConfig(GraftingConfig):
"""Configuration for grafting from AdaGrad.
Args:
Attributes:
epsilon (float): Epsilon term for regularizing square-root of the aggregated second moment to ensure positive definiteness.
(Default: 1e-10)
@@ -308,7 +310,7 @@ def __post_init__(self) -> None:
class RMSpropGraftingConfig(AdaGradGraftingConfig):
"""Configuration for grafting from RMSprop.
Args:
Attributes:
beta2 (float): Exponential moving average factor for second moment. (Default: 0.99)
epsilon (float): Epsilon term for regularizing square-root of the second moment to ensure positive definiteness.
(Default: 1e-10)
@@ -329,7 +331,7 @@ def __post_init__(self) -> None:
class AdamGraftingConfig(RMSpropGraftingConfig):
"""Configuration for grafting from Adam.
Args:
Attributes:
beta2 (float): Exponential moving average factor for second moment. (Default: 0.999)
epsilon (float): Epsilon term for regularizing square-root of the second moment to ensure positive definiteness.
(Default: 1e-10)
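
Since the hunks above only rename `Args:` to `Attributes:` in these docstrings, construction of the configs is unchanged. A minimal example using the documented defaults (import path assumed from the file path in this diff):

```python
from distributed_shampoo.shampoo_types import AdamGraftingConfig

# Values mirror the defaults documented above.
grafting_config = AdamGraftingConfig(beta2=0.999, epsilon=1e-10)
print(grafting_config)
```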
31 changes: 0 additions & 31 deletions distributed_shampoo/tests/distributed_shampoo_test.py
@@ -28,7 +28,6 @@
PreconditionerConfig,
SGDGraftingConfig,
ShampooPreconditionerConfig,
ShampooPT2CompileConfig,
)
from torch import nn

@@ -149,18 +148,6 @@ def test_invalid_with_incorrect_hyperparameter_setting(self) -> None:
**incorrect_hyperparameter_setting,
)

def test_invalid_cuda_pytorch_compile_setting(self) -> None:
with mock.patch.object(torch.cuda, "is_available", return_value=False):
self.assertRaisesRegex(
ValueError,
re.escape(
"Backend does NOT support Pytorch 2.0 compile. Switch to shampoo_pt2_compile_config=None."
),
DistributedShampoo,
self._model.parameters(),
shampoo_pt2_compile_config=ShampooPT2CompileConfig(),
)

def test_nesterov_and_zero_momentum(self) -> None:
with self.assertLogs(
level="WARNING",
@@ -194,24 +181,6 @@ def test_invalid_distributed_config(self) -> None:
distributed_config=DDPShampooConfig(),
)

def test_setting_exponent_multiplier_with_eigen_config(self) -> None:
with self.assertLogs(
level="WARNING",
) as cm:
DistributedShampoo(
self._model.parameters(),
lr=0.01,
start_preconditioning_step=1,
exponent_multiplier=2.0,
preconditioner_config=DefaultShampooConfig,
)
self.assertCountEqual(
[r.msg for r in cm.records],
[
"exponent_multiplier=2.0 is deprecating. Please consider using EigenConfig.exponent_multiplier directly and setting exponent_multipler=None instead in the future."
],
)


class DistributedShampooTest(unittest.TestCase):
def setUp(self) -> None:
@@ -32,6 +32,7 @@

from torch import distributed as dist
from torch.optim.optimizer import ParamsT
from torch.testing._comparison import default_tolerances
from torch.testing._internal.common_distributed import (
DynamoDistributedMultiProcTestCase,
)
@@ -96,7 +97,7 @@ def _shampoo_optim_factory(
lr=0.001,
betas=(0.9, 1.0),
epsilon=1e-8,
momentum=0.0,
momentum=0.9,
weight_decay=0.0,
max_preconditioner_dim=20,
precondition_frequency=1,
@@ -116,14 +117,25 @@ def test_losses(self) -> None:
for num_trainers_per_group, (
communication_dtype,
communicate_params,
(rtol, atol),
) in product(
(-1, 1, 2),
(
(CommunicationDType.DEFAULT, False),
(CommunicationDType.DEFAULT, True),
(CommunicationDType.FP16, False),
(CommunicationDType.BF16, False),
# Expecting CommunicationDType.DEFAULT would have bitwise identical results (by setting rtol=atol=0.0).
(CommunicationDType.DEFAULT, False, (0.0, 0.0)),
(CommunicationDType.DEFAULT, True, (0.0, 0.0)),
# Using FP16 for distributed parameters prohibitively lowers precision.
(
CommunicationDType.FP16,
False,
default_tolerances(torch.float16),
),
(
CommunicationDType.BF16,
False,
# BF16 requires 2x tolerances than the original bfloat16 tolerances.
[2 * tol for tol in default_tolerances(torch.bfloat16)],
),
),
):
with self.subTest(
@@ -146,6 +158,8 @@
model_dead_layer_dims=(20, 20),
device=self._device,
fill=0.01,
rtol=rtol,
atol=atol,
)

# This mock is used to catch the number of calls to Shampoo's step(), which happened after __init__().
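
The `(rtol, atol)` pairs introduced above presumably feed a closeness check on the two configurations' losses; the consuming test helper sits outside this hunk. A standalone sketch of that comparison pattern, using the same `default_tolerances` utility:

```python
# Sketch only: dtype-derived tolerances for comparing two loss curves.
import torch
from torch.testing._comparison import default_tolerances

rtol, atol = default_tolerances(torch.bfloat16)
rtol, atol = 2 * rtol, 2 * atol  # BF16 communication gets 2x the stock tolerances

expected_losses = torch.rand(100) + 0.5
actual_losses = expected_losses * 1.005  # small relative perturbation
torch.testing.assert_close(actual_losses, expected_losses, rtol=rtol, atol=atol)
```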