Open-sourced update on 11/13/2024
Summary:
1. Migrate `typing.{Callable,Iterable,Iterator,Mapping,Sequence}` to `collections.abc.{Callable,Iterable,Iterator,Mapping,Sequence}` because [PEP 585](https://peps.python.org/pep-0585/) deprecated the `typing` aliases in Python 3.9 (see the import sketch below the summary).
2. Add integration test for using QR algorithm in eigenvalue-corrected Shampoo.
3. Factor out `AbstractDataclass` into `commons.py`.
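
As a quick illustration of item 1, the sketch below shows the import change mandated by PEP 585; the helper function is hypothetical and not part of this diff.

```python
# Before (deprecated since Python 3.9 by PEP 585):
#   from typing import Callable, Iterable, Mapping, Sequence
# After:
from collections.abc import Callable, Iterable, Mapping, Sequence


def apply_pipeline(fns: Sequence[Callable[[int], int]], xs: Iterable[int]) -> Mapping[int, int]:
    """Hypothetical helper: feed each input through every function in order."""
    result: dict[int, int] = {}
    for x in xs:
        y = x
        for fn in fns:
            y = fn(y)
        result[x] = y
    return result
```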

Reviewed By: hjmshi

Differential Revision: D65863644

fbshipit-source-id: 555df03be30a84ba55eb16b650fbcc572894437e
tsunghsienlee authored and facebook-github-bot committed Nov 13, 2024
1 parent f68fcbe commit 13f68b7
Showing 27 changed files with 147 additions and 96 deletions.
19 changes: 19 additions & 0 deletions commons.py
@@ -0,0 +1,19 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""

from dataclasses import dataclass
from typing import Any


@dataclass
class AbstractDataclass:
def __new__(cls, *args: Any, **kwargs: Any) -> "AbstractDataclass":
if cls == AbstractDataclass or cls.__bases__[0] == AbstractDataclass:
raise TypeError(f"Cannot instantiate abstract class: {cls.__name__}.")
return super().__new__(cls)
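
For reference, a small runnable sketch of how the factored-out `AbstractDataclass` behaves; the `Example*` subclasses are hypothetical and only illustrate the `__new__` guard above.

```python
from dataclasses import dataclass

from commons import AbstractDataclass


@dataclass
class ExampleBaseConfig(AbstractDataclass):
    """Hypothetical abstract config: its first base is AbstractDataclass, so it cannot be instantiated."""


@dataclass
class ExampleConfig(ExampleBaseConfig):
    """Hypothetical concrete config: one level further down, so instantiation succeeds."""

    value: int = 0


for abstract_cls in (AbstractDataclass, ExampleBaseConfig):
    try:
        abstract_cls()
    except TypeError as error:
        print(error)  # Cannot instantiate abstract class: ...

print(ExampleConfig(value=1))  # ExampleConfig(value=1)
```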
2 changes: 1 addition & 1 deletion distributed_shampoo/README.md
@@ -64,7 +64,7 @@ A few notes on hyperparameters:

- We allow for decoupled and coupled weight decay. Setting `use_decoupled_weight_decay=True` enables AdamW-style (decoupled) weight decay, while `use_decoupled_weight_decay=False` corresponds to standard L2-regularization-style weight decay.

- When setting `preconditioner_computation_config` as an instance of EigenvalueCorrectionConfig, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial.
- When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectionConfig`, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial.
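
A minimal construction sketch of this recommendation, assuming the import paths used by the tests in this commit; the model and hyperparameter values are placeholders.

```python
import torch
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from matrix_functions_types import DefaultEighEigenvalueCorrectionConfig

model = torch.nn.Linear(4, 1)  # placeholder model
optimizer = DistributedShampoo(
    model.parameters(),
    lr=1e-3,               # start from Adam's tuned values
    betas=(0.9, 0.999),
    weight_decay=1e-5,
    use_decoupled_weight_decay=True,
    grafting_config=None,  # grafting from Adam is typically unnecessary here
    preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig,
)
```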

### Example 1: [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html) with Momentum

8 changes: 4 additions & 4 deletions distributed_shampoo/distributed_shampoo.py
@@ -10,9 +10,10 @@
import contextlib
import dataclasses
import logging
from collections.abc import Callable, Iterator, Sequence
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Iterator, NoReturn, Sequence
from typing import Any, NoReturn

import torch

@@ -215,7 +216,7 @@ class DistributedShampoo(torch.optim.Optimizer):
updated every iteration while the eigenbasis of Shampoo's preconditioner is only computed every `precondition_frequency` steps.
Alternatively, this can be seen as running Adam in the eigenbasis of Shampoo's preconditioner.
When setting `preconditioner_computation_config` as an instance of EigenvalueCorrectionConfig, there is typically no need to use learning
When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectionConfig`, there is typically no need to use learning
rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be
a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet.
Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial.
@@ -233,7 +234,7 @@ class DistributedShampoo(torch.optim.Optimizer):
momentum (float): Momentum parameter. (Default: 0.)
dampening (float): Dampening parameter for momentum. (Default: 0.)
weight_decay (float): Weight decay (L2 penalty). (Default: 0.)
max_preconditioner_dim (int): Maximum preconditioner dimensio. (Default: 1024)
max_preconditioner_dim (int): Maximum preconditioner dimension. (Default: 1024)
precondition_frequency (int): Frequency of updating all components of the preconditioner.
If this field is an instance of RootInvConfig, this is the update frequency of the root inverse of the preconditioner.
If this field is an instance of EigenvalueCorrectionConfig, this is the update frequency of the eigenbasis of the preconditioner.
@@ -819,7 +820,6 @@ def _compute_and_log_root_inverse_residuals(
SHAMPOO_PRECONDITIONER_LIST
].compute_root_inverse_residuals(),
)

quantiles = torch.as_tensor(
[0, 0.25, 0.5, 0.75, 1],
device=relative_errors.device,
60 changes: 42 additions & 18 deletions distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py
@@ -11,14 +11,18 @@

import math
import unittest
from collections.abc import Callable
from functools import partial
from itertools import product
from typing import Any, Callable, Type
from typing import Any, Type

import torch
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.tests.shampoo_test_utils import construct_training_problem
from matrix_functions_types import DefaultEighEigenvalueCorrectionConfig
from matrix_functions_types import (
DefaultEighEigenvalueCorrectionConfig,
QREigenvalueCorrectionConfig,
)
from torch.nn.parameter import Parameter
from torch.optim.adagrad import Adagrad
from torch.optim.adam import Adam
@@ -97,19 +101,24 @@ def _optim_factory(
return optim_cls(parameters, **kwargs)

def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None:
# Test with and without weight decay, and with CPU or GPU
for weight_decay, device in product(
# Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm.
for weight_decay, device, preconditioner_computation_config in product(
(0.0, 0.3),
(torch.device("cpu"),) + (torch.device("cuda"),)
if torch.cuda.is_available()
else (),
(DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()),
):
optim_factory = partial(
DistributedShampooEigenvalueCorrectionTest._optim_factory,
lr=0.01,
weight_decay=weight_decay,
)
with self.subTest(weight_decay=weight_decay, device=device):
with self.subTest(
weight_decay=weight_decay,
device=device,
preconditioner_computation_config=preconditioner_computation_config,
):
DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo(
baseline_optim_factory=partial(
optim_factory, optim_cls=Adagrad, eps=1e-15
@@ -125,26 +134,31 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None:
start_preconditioning_step=math.inf,
use_decoupled_weight_decay=False,
grafting_config=None,
preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig,
preconditioner_computation_config=preconditioner_computation_config,
),
device=device,
)

def test_adam_eigenvalue_correction_on_quadratic(self) -> None:
# Test with and without weight decay, and with CPU or GPU
for weight_decay, device in product(
# Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm.
for weight_decay, device, preconditioner_computation_config in product(
(0.0, 0.3),
(torch.device("cpu"),) + (torch.device("cuda"),)
if torch.cuda.is_available()
else (),
(DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()),
):
optim_factory = partial(
DistributedShampooEigenvalueCorrectionTest._optim_factory,
lr=0.001,
betas=(0.9, 0.999),
weight_decay=weight_decay,
)
with self.subTest(weight_decay=weight_decay, device=device):
with self.subTest(
weight_decay=weight_decay,
device=device,
preconditioner_computation_config=preconditioner_computation_config,
):
DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo(
baseline_optim_factory=partial(
optim_factory,
@@ -161,26 +175,31 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None:
start_preconditioning_step=math.inf,
use_decoupled_weight_decay=False,
grafting_config=None,
preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig,
preconditioner_computation_config=preconditioner_computation_config,
),
device=device,
)

def test_adamw_eigenvalue_correction_on_quadratic(self) -> None:
# Test with and without weight decay, and with CPU or GPU
for weight_decay, device in product(
# Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm.
for weight_decay, device, preconditioner_computation_config in product(
(0.0, 0.3),
(torch.device("cpu"),) + (torch.device("cuda"),)
if torch.cuda.is_available()
else (),
(DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()),
):
optim_factory = partial(
DistributedShampooEigenvalueCorrectionTest._optim_factory,
lr=0.001,
betas=(0.9, 0.999),
weight_decay=weight_decay,
)
with self.subTest(weight_decay=weight_decay, device=device):
with self.subTest(
weight_decay=weight_decay,
device=device,
preconditioner_computation_config=preconditioner_computation_config,
):
DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo(
baseline_optim_factory=partial(
optim_factory,
@@ -197,25 +216,30 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None:
start_preconditioning_step=math.inf,
use_decoupled_weight_decay=True,
grafting_config=None,
preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig,
preconditioner_computation_config=preconditioner_computation_config,
),
device=device,
)

def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None:
# Test with and without weight decay, and with CPU or GPU
for weight_decay, device in product(
# Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm.
for weight_decay, device, preconditioner_computation_config in product(
(0.0, 0.3),
(torch.device("cpu"),) + (torch.device("cuda"),)
if torch.cuda.is_available()
else (),
(DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()),
):
optim_factory = partial(
DistributedShampooEigenvalueCorrectionTest._optim_factory,
lr=0.01,
weight_decay=weight_decay,
)
with self.subTest(weight_decay=weight_decay, device=device):
with self.subTest(
weight_decay=weight_decay,
device=device,
preconditioner_computation_config=preconditioner_computation_config,
):
DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo(
baseline_optim_factory=partial(
optim_factory,
@@ -235,7 +259,7 @@ def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None:
use_decoupled_weight_decay=False,
grafting_config=None,
use_bias_correction=False,
preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig,
preconditioner_computation_config=preconditioner_computation_config,
),
device=device,
)
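
The new tests above also exercise the QR-based config. Below is a hypothetical end-to-end sketch of opting into it (toy model and data, import paths as in the test file; mirroring the test, `QREigenvalueCorrectionConfig` is instantiated while `DefaultEighEigenvalueCorrectionConfig` is passed as-is).

```python
import torch
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from matrix_functions_types import QREigenvalueCorrectionConfig

model = torch.nn.Linear(8, 1)  # toy model
optimizer = DistributedShampoo(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    grafting_config=None,
    # QR-based eigenvalue correction instead of the default eigendecomposition:
    preconditioner_computation_config=QREigenvalueCorrectionConfig(),
)

x, y = torch.randn(16, 8), torch.randn(16, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```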
3 changes: 2 additions & 1 deletion distributed_shampoo/gpu_tests/shampoo_grafting_test.py
@@ -11,9 +11,10 @@

import math
import unittest
from collections.abc import Callable
from functools import partial
from itertools import product
from typing import Any, Callable, Type
from typing import Any, Type

import torch
from distributed_shampoo.distributed_shampoo import DistributedShampoo
2 changes: 1 addition & 1 deletion distributed_shampoo/gpu_tests/shampoo_pt2_test.py
@@ -11,8 +11,8 @@

import itertools
import unittest
from collections.abc import Callable
from functools import partial
from typing import Callable

import torch
from distributed_shampoo.distributed_shampoo import DistributedShampoo
19 changes: 2 additions & 17 deletions distributed_shampoo/shampoo_types.py
@@ -9,9 +9,10 @@

import enum
from dataclasses import dataclass
from typing import Any

import torch

from commons import AbstractDataclass
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import ShardingStrategy
from torch.nn.parameter import Parameter
@@ -128,20 +129,10 @@ class PrecisionConfig:
grafting_state_dtype: torch.dtype = torch.float32


@dataclass
class AbstractDataclass:
def __new__(cls, *args: Any, **kwargs: Any) -> "AbstractDataclass":
if cls == AbstractDataclass or cls.__bases__[0] == AbstractDataclass:
raise TypeError(f"Cannot instantiate abstract class: {cls.__name__}.")
return super().__new__(cls)


@dataclass
class DistributedConfig(AbstractDataclass):
"""Abstract dataclass for distributed configs in Shampoo."""

...


@dataclass(kw_only=True)
class DDPShampooConfig(DistributedConfig):
@@ -184,8 +175,6 @@ class FullyShardShampooConfig(DistributedConfig):
Currently only a placeholder used for Shampoo optimizer to select FullyShardDistributor.
"""

pass


@dataclass
class HSDPShampooConfig(FSDPShampooConfig, DDPShampooConfig):
@@ -238,15 +227,11 @@ class ShampooPT2CompileConfig:
class GraftingConfig(AbstractDataclass):
"""Abstract dataclass for grafting configurations in Shampoo."""

...


@dataclass
class SGDGraftingConfig(GraftingConfig):
"""Configuration for grafting from SGD."""

...


@dataclass(kw_only=True)
class AdaGradGraftingConfig(GraftingConfig):
15 changes: 0 additions & 15 deletions distributed_shampoo/tests/shampoo_types_test.py
@@ -12,27 +12,12 @@
from typing import Type

from distributed_shampoo.shampoo_types import (
AbstractDataclass,
AdaGradGraftingConfig,
AdamGraftingConfig,
DistributedConfig,
GraftingConfig,
RMSpropGraftingConfig,
)


class InvalidAbstractDataclassInitTest(unittest.TestCase):
def test_invalid_init(self) -> None:
for abstract_cls in (AbstractDataclass, DistributedConfig, GraftingConfig):
with self.subTest(abstract_cls=abstract_cls), self.assertRaisesRegex(
TypeError,
re.escape(
f"Cannot instantiate abstract class: {abstract_cls.__name__}."
),
):
abstract_cls()


class AdaGradGraftingConfigTest(unittest.TestCase):
def test_illegal_epsilon(self) -> None:
epsilon = 0.0
(additional changed file; name not shown)
@@ -13,9 +13,9 @@
import contextlib
import re
import unittest
from itertools import product

from typing import Callable
from collections.abc import Callable
from itertools import product
from unittest import mock

import torch
(additional changed file; name not shown)
@@ -46,8 +46,7 @@ def _verify_deivce_mesh(self, device_mesh: DeviceMesh) -> None:
def test_get_device_mesh(self) -> None:
mesh = tuple(
map(
# type: ignore
tuple,
tuple, # type: ignore
torch.tensor(range(self.world_size))
.view(-1, self.world_size // 2)
.tolist(),
@@ -70,7 +69,7 @@ def test_get_device_mesh(self) -> None:
"__init__",
) as mock_device_mesh_init:
device_mesh = get_device_mesh(
device_type=self.device_type, # type: ignore[attr-defined]
device_type=self.device_type, # type: ignore
mesh=mesh,
mesh_dim_names=("replicate", "shard"),
)
(additional changed file; name not shown)
@@ -10,9 +10,9 @@
#!/usr/bin/env python3

import unittest
from collections.abc import Callable
from functools import partial
from itertools import pairwise
from typing import Callable

import torch
from distributed_shampoo.distributed_shampoo import DistributedShampoo
(additional changed file; name not shown)
@@ -10,8 +10,8 @@
#!/usr/bin/env python3

import unittest
from collections.abc import Callable
from functools import partial
from typing import Callable

import torch
from distributed_shampoo.distributed_shampoo import DistributedShampoo
(remaining file diffs not loaded)