Open-sourced update on 11/25/2024
Summary:
1. Refactor tests by adding `compare_two_optimizers_on_weight_and_loss()`.
2. Various code refactorings.

Reviewed By: chuanhaozhuge

Differential Revision: D66429617

fbshipit-source-id: 7f4073f848b5c255501dc294a7a50b0b794a42d6
tsunghsienlee authored and facebook-github-bot committed Nov 25, 2024
1 parent 693993d · commit c373c71
Showing 7 changed files with 149 additions and 245 deletions.
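
The diff below deletes the duplicated `_train_quadratic()` and `_test_baseline_and_shampoo()` helpers from both GPU test classes and routes every test through the new shared utility named in the summary. The utility itself lives in distributed_shampoo/tests/shampoo_test_utils.py, whose diff is among the five changed files not rendered on this page. The standalone sketch below is reconstructed from the deleted helpers and the new call sites and shows roughly what the helper does; the `total_steps` parameter, its default of 5, and the exact signature are assumptions, not the committed implementation.

# Hypothetical sketch only: reconstructed from the deleted helpers and the new
# call sites; the committed compare_two_optimizers_on_weight_and_loss() in
# shampoo_test_utils.py may differ in signature and defaults.
from collections.abc import Callable

import torch
from distributed_shampoo.tests.shampoo_test_utils import construct_training_problem
from torch.optim.optimizer import ParamsT  # available in recent PyTorch releases


def compare_two_optimizers_on_weight_and_loss(
    control_optim_factory: Callable[[ParamsT], torch.optim.Optimizer],
    experimental_optim_factory: Callable[[ParamsT], torch.optim.Optimizer],
    device: torch.device,
    total_steps: int = 5,  # assumed default; the deleted helpers always ran 5 steps
) -> None:
    def train(
        optim_factory: Callable[[ParamsT], torch.optim.Optimizer],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Same small training problem the deleted helpers built: a linear model
        # with layer dims (10, 1, 1) and weights filled with 1.0.
        model, loss, data, target = construct_training_problem(
            model_linear_layers_dims=(10, 1, 1),
            device=device,
            fill=1.0,
        )
        optimizer = optim_factory(model.parameters())
        for _ in range(total_steps):
            optimizer.zero_grad()
            objective = loss(model(data), target)
            objective.backward()
            optimizer.step()
        # Return the first layer's weights and the final loss, both on CPU.
        return model.linear_layers[0].weight.data.cpu(), objective.detach().cpu()

    control_weight, control_loss = train(control_optim_factory)
    experimental_weight, experimental_loss = train(experimental_optim_factory)
    # The two optimizers are expected to match on both the final loss and the weights.
    torch.testing.assert_close(experimental_loss, control_loss)
    torch.testing.assert_close(experimental_weight, control_weight)

With this helper in place, the call sites in the two test files below differ only in the optimizer factories they pass as `control_optim_factory` and `experimental_optim_factory`.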
distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py (83 changes: 15 additions & 68 deletions)
@@ -11,19 +11,19 @@
 
 import math
 import unittest
-from collections.abc import Callable
 from functools import partial
 from itertools import product
 from typing import Any, Type
 
 import torch
 from distributed_shampoo.distributed_shampoo import DistributedShampoo
-from distributed_shampoo.tests.shampoo_test_utils import construct_training_problem
+from distributed_shampoo.tests.shampoo_test_utils import (
+    compare_two_optimizers_on_weight_and_loss,
+)
 from matrix_functions_types import (
     DefaultEighEigenvalueCorrectionConfig,
     QREigenvalueCorrectionConfig,
 )
-from torch.nn.parameter import Parameter
 from torch.optim.adagrad import Adagrad
 from torch.optim.adam import Adam
 from torch.optim.adamw import AdamW
@@ -39,59 +39,6 @@
 
 
 class DistributedShampooEigenvalueCorrectionTest(unittest.TestCase):
-    @staticmethod
-    def _train_quadratic(
-        optim_factory: Callable[
-            [ParamsT],
-            torch.optim.Optimizer,
-        ],
-        device: torch.device,
-    ) -> tuple[Parameter, torch.Tensor]:
-        model, loss, data, target = construct_training_problem(
-            model_linear_layers_dims=(10, 1, 1),
-            device=device,
-            fill=1.0,
-        )
-        params = model.parameters()
-        optimizer = optim_factory(params)
-        for _ in range(5):
-            optimizer.zero_grad()
-            objective = loss(model(data), target)
-            objective.backward()
-            optimizer.step()
-        return model.linear_layers[0].weight.data.cpu(), objective.detach().cpu()
-
-    @staticmethod
-    def _test_baseline_and_shampoo(
-        baseline_optim_factory: Callable[
-            [ParamsT],
-            torch.optim.Optimizer,
-        ],
-        shampoo_optim_factory: Callable[
-            [ParamsT],
-            torch.optim.Optimizer,
-        ],
-        device: torch.device,
-    ) -> None:
-        (
-            baseline_params,
-            baseline_loss,
-        ) = DistributedShampooEigenvalueCorrectionTest._train_quadratic(
-            baseline_optim_factory,
-            device=device,
-        )
-        shampoo_params, shampoo_loss = (
-            DistributedShampooEigenvalueCorrectionTest._train_quadratic(
-                shampoo_optim_factory,
-                device=device,
-            )
-        )
-        torch.testing.assert_close(shampoo_loss, baseline_loss)
-        torch.testing.assert_close(
-            shampoo_params,
-            baseline_params,
-        )
-
     @staticmethod
     def _optim_factory(
         parameters: ParamsT,
@@ -119,11 +66,11 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None:
                 device=device,
                 preconditioner_computation_config=preconditioner_computation_config,
             ):
-                DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo(
-                    baseline_optim_factory=partial(
+                compare_two_optimizers_on_weight_and_loss(
+                    control_optim_factory=partial(
                         optim_factory, optim_cls=Adagrad, eps=1e-15
                     ),
-                    shampoo_optim_factory=partial(
+                    experimental_optim_factory=partial(
                         optim_factory,
                         optim_cls=DistributedShampoo,
                         betas=(0.0, 1.0),
@@ -159,13 +106,13 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None:
                 device=device,
                 preconditioner_computation_config=preconditioner_computation_config,
             ):
-                DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo(
-                    baseline_optim_factory=partial(
+                compare_two_optimizers_on_weight_and_loss(
+                    control_optim_factory=partial(
                         optim_factory,
                         optim_cls=Adam,
                         eps=1e-15,
                     ),
-                    shampoo_optim_factory=partial(
+                    experimental_optim_factory=partial(
                         optim_factory,
                         optim_cls=DistributedShampoo,
                         epsilon=1e-15,
@@ -200,13 +147,13 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None:
                 device=device,
                 preconditioner_computation_config=preconditioner_computation_config,
             ):
-                DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo(
-                    baseline_optim_factory=partial(
+                compare_two_optimizers_on_weight_and_loss(
+                    control_optim_factory=partial(
                         optim_factory,
                         optim_cls=AdamW,
                         eps=1e-15,
                     ),
-                    shampoo_optim_factory=partial(
+                    experimental_optim_factory=partial(
                         optim_factory,
                         optim_cls=DistributedShampoo,
                         epsilon=1e-15,
@@ -240,14 +187,14 @@ def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None:
                 device=device,
                 preconditioner_computation_config=preconditioner_computation_config,
             ):
-                DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo(
-                    baseline_optim_factory=partial(
+                compare_two_optimizers_on_weight_and_loss(
+                    control_optim_factory=partial(
                         optim_factory,
                         optim_cls=RMSprop,
                         alpha=0.99,
                         eps=1e-15,
                     ),
-                    shampoo_optim_factory=partial(
+                    experimental_optim_factory=partial(
                         optim_factory,
                         optim_cls=DistributedShampoo,
                         betas=(0.0, 0.99),
distributed_shampoo/gpu_tests/shampoo_grafting_test.py (87 changes: 18 additions & 69 deletions)
@@ -11,7 +11,6 @@
 
 import math
 import unittest
-from collections.abc import Callable
 from functools import partial
 from itertools import product
 from typing import Any, Type
@@ -24,8 +23,9 @@
     RMSpropGraftingConfig,
     SGDGraftingConfig,
 )
-from distributed_shampoo.tests.shampoo_test_utils import construct_training_problem
-from torch.nn.parameter import Parameter
+from distributed_shampoo.tests.shampoo_test_utils import (
+    compare_two_optimizers_on_weight_and_loss,
+)
 from torch.optim.adagrad import Adagrad
 from torch.optim.adam import Adam
 from torch.optim.adamw import AdamW
@@ -35,57 +35,6 @@
 
 
 class DistributedShampooGraftingTest(unittest.TestCase):
-    @staticmethod
-    def _train_quadratic(
-        optim_factory: Callable[
-            [ParamsT],
-            torch.optim.Optimizer,
-        ],
-        device: torch.device,
-    ) -> tuple[Parameter, torch.Tensor]:
-        model, loss, data, target = construct_training_problem(
-            model_linear_layers_dims=(10, 1, 1),
-            device=device,
-            fill=1.0,
-        )
-        params = model.parameters()
-        optimizer = optim_factory(params)
-        for _ in range(5):
-            optimizer.zero_grad()
-            objective = loss(model(data), target)
-            objective.backward()
-            optimizer.step()
-        return model.linear_layers[0].weight.data.cpu(), objective.detach().cpu()
-
-    @staticmethod
-    def _test_baseline_and_shampoo(
-        baseline_optim_factory: Callable[
-            [ParamsT],
-            torch.optim.Optimizer,
-        ],
-        shampoo_optim_factory: Callable[
-            [ParamsT],
-            torch.optim.Optimizer,
-        ],
-        device: torch.device,
-    ) -> None:
-        (
-            baseline_params,
-            baseline_loss,
-        ) = DistributedShampooGraftingTest._train_quadratic(
-            baseline_optim_factory,
-            device=device,
-        )
-        shampoo_params, shampoo_loss = DistributedShampooGraftingTest._train_quadratic(
-            shampoo_optim_factory,
-            device=device,
-        )
-        torch.testing.assert_close(shampoo_loss, baseline_loss)
-        torch.testing.assert_close(
-            shampoo_params,
-            baseline_params,
-        )
-
     @staticmethod
     def _optim_factory(
         parameters: ParamsT,
@@ -108,11 +57,11 @@ def test_adagrad_grafting_on_quadratic(self) -> None:
                 weight_decay=weight_decay,
             )
             with self.subTest(weight_decay=weight_decay, device=device):
-                DistributedShampooGraftingTest._test_baseline_and_shampoo(
-                    baseline_optim_factory=partial(
+                compare_two_optimizers_on_weight_and_loss(
+                    control_optim_factory=partial(
                         optim_factory, optim_cls=Adagrad, eps=1e-10
                     ),
-                    shampoo_optim_factory=partial(
+                    experimental_optim_factory=partial(
                         optim_factory,
                         optim_cls=DistributedShampoo,
                         betas=(0.0, 1.0),
@@ -144,11 +93,11 @@ def test_adam_grafting_on_quadratic(self) -> None:
                 weight_decay=weight_decay,
             )
             with self.subTest(weight_decay=weight_decay, device=device):
-                DistributedShampooGraftingTest._test_baseline_and_shampoo(
-                    baseline_optim_factory=partial(
+                compare_two_optimizers_on_weight_and_loss(
+                    control_optim_factory=partial(
                         optim_factory, optim_cls=Adam, eps=1e-8
                     ),
-                    shampoo_optim_factory=partial(
+                    experimental_optim_factory=partial(
                         optim_factory,
                         optim_cls=DistributedShampoo,
                         epsilon=1e-8,
@@ -180,11 +129,11 @@ def test_adamw_grafting_on_quadratic(self) -> None:
                 weight_decay=weight_decay,
             )
             with self.subTest(weight_decay=weight_decay, device=device):
-                DistributedShampooGraftingTest._test_baseline_and_shampoo(
-                    baseline_optim_factory=partial(
+                compare_two_optimizers_on_weight_and_loss(
+                    control_optim_factory=partial(
                         optim_factory, optim_cls=AdamW, eps=1e-8
                     ),
-                    shampoo_optim_factory=partial(
+                    experimental_optim_factory=partial(
                         optim_factory,
                         optim_cls=DistributedShampoo,
                         epsilon=1e-8,
@@ -215,14 +164,14 @@ def test_rmsprop_grafting_on_quadratic(self) -> None:
                 weight_decay=weight_decay,
             )
             with self.subTest(weight_decay=weight_decay, device=device):
-                DistributedShampooGraftingTest._test_baseline_and_shampoo(
-                    baseline_optim_factory=partial(
+                compare_two_optimizers_on_weight_and_loss(
+                    control_optim_factory=partial(
                         optim_factory,
                         optim_cls=RMSprop,
                         alpha=0.99,
                         eps=1e-8,
                     ),
-                    shampoo_optim_factory=partial(
+                    experimental_optim_factory=partial(
                         optim_factory,
                         optim_cls=DistributedShampoo,
                         betas=(0.0, 0.99),
@@ -258,13 +207,13 @@ def test_sgd_grafting_on_quadratic(self) -> None:
             with self.subTest(
                 weight_decay=weight_decay, use_nesterov=use_nesterov, device=device
             ):
-                DistributedShampooGraftingTest._test_baseline_and_shampoo(
-                    baseline_optim_factory=partial(
+                compare_two_optimizers_on_weight_and_loss(
+                    control_optim_factory=partial(
                         optim_factory,
                         optim_cls=SGD,
                         nesterov=use_nesterov,
                     ),
-                    shampoo_optim_factory=partial(
+                    experimental_optim_factory=partial(
                         optim_factory,
                         optim_cls=DistributedShampoo,
                         betas=(0.0, 0.9),
(Diffs for the remaining 5 changed files, including distributed_shampoo/tests/shampoo_test_utils.py, are not shown here.)
