Add ignored_dims config for single-factor matrix enablement
Summary:
This diff adds a new configuration option, `ignored_dims`, to `PreconditionerConfig`; it allows the user to specify which dimensions of a tensor should be ignored when computing the preconditioner, which is equivalent to setting the preconditioner for those dimensions to the identity matrix.

Note that `PreconditionerConfig.ignored_dims` is not compatible with `inv_root_override`. We plan to merge `inv_root_override` into `PreconditionerConfig.amortized_computation_config.exponent_multiplier` as `PreconditionerConfig.exponent_override`, a list of floats representing the exponent override for each tensor order: given `PreconditionerConfig.exponent_override=[e1, e2, ..., ep]`, we will use ^e1 for 1-D tensors (vectors), ^e2 for 2-D tensors (matrices), and so on, and setting `ei=0` results in no preconditioning for any i-D tensor. By contrast, setting `i in PreconditionerConfig.ignored_dims` disables preconditioning only for the i-th dimension of every tensor that has one (i.e., whose order is > i).

For example, if `PreconditionerConfig.exponent_override=[0.5, 0.0, 0.25]` and `PreconditionerConfig.ignored_dims=[0, 2]`, then 1-D tensors are not preconditioned (because `0 in PreconditionerConfig.ignored_dims`, which makes the setting `PreconditionerConfig.exponent_override[0]=0.5` redundant), 2-D tensors are not preconditioned (because `PreconditionerConfig.exponent_override[1]=0.0`), and for 3-D tensors only the second dimension is preconditioned, with ^0.25 (due to `PreconditionerConfig.exponent_override[2]=0.25`), while the first and third dimensions are not (because `0` and `2` are in `PreconditionerConfig.ignored_dims`).
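A minimal usage sketch of the new option as it lands in this diff (config construction only; both config classes and the duplicate-entry check appear in the `shampoo_types.py` changes below):

```python
from distributed_shampoo.shampoo_types import (
    EigenvalueCorrectedShampooPreconditionerConfig,
    ShampooPreconditionerConfig,
)

# Skip preconditioning for the second and fourth dimension of every tensor;
# this is equivalent to using the identity preconditioner for those dimensions.
config = ShampooPreconditionerConfig(ignored_dims=[1, 3])

# Ignoring every dimension of a 2-D parameter disables preconditioning entirely,
# which the eigenvalue-correction tests below use to match Adagrad/Adam/RMSprop.
ec_config = EigenvalueCorrectedShampooPreconditionerConfig(ignored_dims=[0, 1])

# Duplicate entries are rejected by PreconditionerConfig.__post_init__:
# ShampooPreconditionerConfig(ignored_dims=[1, 3, 1])  # raises ValueError
```

And a sketch of the planned `exponent_override` semantics described above. Since `exponent_override` is future work and not part of the current API, the helper below is purely illustrative, not the library's implementation:

```python
def preconditioning_exponent(
    order: int,  # tensor order (number of dimensions)
    dim: int,  # 0-indexed dimension within the tensor
    exponent_override: list[float],
    ignored_dims: list[int],
) -> float:
    """Exponent applied to `dim` of an order-`order` tensor; 0.0 means no preconditioning."""
    if dim in ignored_dims:  # per-dimension opt-out applies to all tensors
        return 0.0
    return exponent_override[order - 1]  # per-order exponent; 0.0 disables the whole order


exponent_override = [0.5, 0.0, 0.25]
ignored_dims = [0, 2]

# 1-D tensors: dim 0 is ignored, so exponent_override[0]=0.5 is redundant.
assert preconditioning_exponent(1, 0, exponent_override, ignored_dims) == 0.0
# 2-D tensors: exponent_override[1]=0.0 disables the non-ignored dim 1 as well.
assert preconditioning_exponent(2, 1, exponent_override, ignored_dims) == 0.0
# 3-D tensors: only dim 1 is preconditioned, with exponent 0.25.
assert preconditioning_exponent(3, 0, exponent_override, ignored_dims) == 0.0
assert preconditioning_exponent(3, 1, exponent_override, ignored_dims) == 0.25
assert preconditioning_exponent(3, 2, exponent_override, ignored_dims) == 0.0
```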

Pair-programmed with runame.

Reviewed By: anana10c

Differential Revision: D70198403

fbshipit-source-id: 6ba4f84461cc32c185cac949b74c6cd51b33e795
tsunghsienlee authored and facebook-github-bot committed Feb 25, 2025
1 parent 947ceec commit 66f348c
Showing 7 changed files with 341 additions and 51 deletions.
6 changes: 6 additions & 0 deletions distributed_shampoo/distributed_shampoo.py
@@ -385,6 +385,12 @@ def __init__(
"Continuing without using momentum or Nesterov acceleration..."
)

# Check potential conflict between preconditioner_config.ignored_dims and inv_root_override.
if preconditioner_config.ignored_dims != [] and inv_root_override != 0:
raise ValueError(
f"{preconditioner_config.ignored_dims=} is not supported when {inv_root_override=} is not set to 0. Please set {inv_root_override=} to 0 if you set {preconditioner_config.ignored_dims=}."
)

super().__init__(
params,
{
77 changes: 61 additions & 16 deletions distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py
@@ -20,6 +20,7 @@
from distributed_shampoo.shampoo_types import (
DefaultEigenvalueCorrectedShampooConfig,
DefaultSOAPConfig,
EigenvalueCorrectedShampooPreconditionerConfig,
)
from distributed_shampoo.tests.shampoo_test_utils import (
compare_two_optimizers_on_weight_and_loss,
@@ -48,13 +49,23 @@ def _optim_factory(
return optim_cls(parameters, **kwargs)

def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None:
# Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm.
for weight_decay, device, preconditioner_config in product(
# Test with and without weight decay, with CPU or GPU, and using eigendecomposition, QR algorithm, or eigendecomposition with all dims ignored.
for weight_decay, device, (
start_preconditioning_step,
preconditioner_config,
) in product(
(0.0, 0.3),
(torch.device("cpu"),) + (torch.device("cuda"),)
if torch.cuda.is_available()
else (),
(DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig),
(
(math.inf, DefaultEigenvalueCorrectedShampooConfig),
(math.inf, DefaultSOAPConfig),
(
1,
EigenvalueCorrectedShampooPreconditionerConfig(ignored_dims=[0, 1]),
),
),
):
optim_factory = partial(
DistributedShampooEigenvalueCorrectionTest._optim_factory,
@@ -64,6 +75,7 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None:
with self.subTest(
weight_decay=weight_decay,
device=device,
start_preconditioning_step=start_preconditioning_step,
preconditioner_config=preconditioner_config,
):
compare_two_optimizers_on_weight_and_loss(
@@ -78,7 +90,7 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None:
momentum=0.0,
max_preconditioner_dim=10,
precondition_frequency=1,
start_preconditioning_step=math.inf,
start_preconditioning_step=start_preconditioning_step,
use_decoupled_weight_decay=False,
grafting_config=None,
preconditioner_config=preconditioner_config,
@@ -87,13 +99,23 @@
)

def test_adam_eigenvalue_correction_on_quadratic(self) -> None:
# Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm.
for weight_decay, device, preconditioner_config in product(
# Test with and without weight decay, with CPU or GPU, and using eigendecomposition, QR algorithm, or eigendecomposition with all dims ignored.
for weight_decay, device, (
start_preconditioning_step,
preconditioner_config,
) in product(
(0.0, 0.3),
(torch.device("cpu"),) + (torch.device("cuda"),)
if torch.cuda.is_available()
else (),
(DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig),
(
(math.inf, DefaultEigenvalueCorrectedShampooConfig),
(math.inf, DefaultSOAPConfig),
(
1,
EigenvalueCorrectedShampooPreconditionerConfig(ignored_dims=[0, 1]),
),
),
):
optim_factory = partial(
DistributedShampooEigenvalueCorrectionTest._optim_factory,
@@ -104,6 +126,7 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None:
with self.subTest(
weight_decay=weight_decay,
device=device,
start_preconditioning_step=start_preconditioning_step,
preconditioner_config=preconditioner_config,
):
compare_two_optimizers_on_weight_and_loss(
@@ -119,7 +142,7 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None:
momentum=0.0,
max_preconditioner_dim=10,
precondition_frequency=1,
start_preconditioning_step=math.inf,
start_preconditioning_step=start_preconditioning_step,
use_decoupled_weight_decay=False,
grafting_config=None,
preconditioner_config=preconditioner_config,
@@ -128,13 +151,23 @@
)

def test_adamw_eigenvalue_correction_on_quadratic(self) -> None:
# Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm.
for weight_decay, device, preconditioner_config in product(
# Test with and without weight decay, with CPU or GPU, and using eigendecomposition, QR algorithm, or eigendecomposition with all dims ignored.
for weight_decay, device, (
start_preconditioning_step,
preconditioner_config,
) in product(
(0.0, 0.3),
(torch.device("cpu"),) + (torch.device("cuda"),)
if torch.cuda.is_available()
else (),
(DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig),
(
(math.inf, DefaultEigenvalueCorrectedShampooConfig),
(math.inf, DefaultSOAPConfig),
(
1,
EigenvalueCorrectedShampooPreconditionerConfig(ignored_dims=[0, 1]),
),
),
):
optim_factory = partial(
DistributedShampooEigenvalueCorrectionTest._optim_factory,
@@ -145,6 +178,7 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None:
with self.subTest(
weight_decay=weight_decay,
device=device,
start_preconditioning_step=start_preconditioning_step,
preconditioner_config=preconditioner_config,
):
compare_two_optimizers_on_weight_and_loss(
@@ -160,7 +194,7 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None:
momentum=0.0,
max_preconditioner_dim=10,
precondition_frequency=1,
start_preconditioning_step=math.inf,
start_preconditioning_step=start_preconditioning_step,
use_decoupled_weight_decay=True,
grafting_config=None,
preconditioner_config=preconditioner_config,
@@ -169,13 +203,23 @@
)

def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None:
# Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm.
for weight_decay, device, preconditioner_config in product(
# Test with and without weight decay, with CPU or GPU, and using eigendecomposition, QR algorithm, or eigendecomposition with all dims ignored.
for weight_decay, device, (
start_preconditioning_step,
preconditioner_config,
) in product(
(0.0, 0.3),
(torch.device("cpu"),) + (torch.device("cuda"),)
if torch.cuda.is_available()
else (),
(DefaultEigenvalueCorrectedShampooConfig, DefaultSOAPConfig),
(
(math.inf, DefaultEigenvalueCorrectedShampooConfig),
(math.inf, DefaultSOAPConfig),
(
1,
EigenvalueCorrectedShampooPreconditionerConfig(ignored_dims=[0, 1]),
),
),
):
optim_factory = partial(
DistributedShampooEigenvalueCorrectionTest._optim_factory,
@@ -185,6 +229,7 @@ def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None:
with self.subTest(
weight_decay=weight_decay,
device=device,
start_preconditioning_step=start_preconditioning_step,
preconditioner_config=preconditioner_config,
):
compare_two_optimizers_on_weight_and_loss(
@@ -202,7 +247,7 @@ def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None:
momentum=0.0,
max_preconditioner_dim=10,
precondition_frequency=1,
start_preconditioning_step=math.inf,
start_preconditioning_step=start_preconditioning_step,
use_decoupled_weight_decay=False,
grafting_config=None,
use_bias_correction=False,
8 changes: 8 additions & 0 deletions distributed_shampoo/shampoo_types.py
@@ -86,17 +86,23 @@ class PreconditionerConfig(AbstractDataclass):
Attributes:
amortized_computation_config (MatrixFunctionConfig): Configuration for the amortized computation, e.g., inverse-root or eigenvector computation.
num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3)
ignored_dims (list[int]): List of dimensions to ignore when computing the preconditioner. This is equivalent to setting the preconditioner for these dimensions to the identity matrix. (Default: [])
"""

amortized_computation_config: MatrixFunctionConfig # type: ignore
num_tolerated_failed_amortized_computations: int = 3
ignored_dims: list[int] = field(default_factory=list)

def __post_init__(self) -> None:
if self.num_tolerated_failed_amortized_computations < 0:
raise ValueError(
f"Invalid num_tolerated_failed_amortized_computations value: {self.num_tolerated_failed_amortized_computations}. Must be >= 0."
)
if len(self.ignored_dims) != len(set(self.ignored_dims)):
raise ValueError(
f"Invalid ignored_dims value: {self.ignored_dims}. Must be a list of unique dimensions."
)


@dataclass(kw_only=True)
@@ -106,6 +112,7 @@ class ShampooPreconditionerConfig(PreconditionerConfig):
Attributes:
amortized_computation_config (RootInvConfig): Configuration for the inverse-root computation. (Default: DefaultEigenConfig)
num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3)
ignored_dims (list[int]): List of dimensions to ignore when computing the preconditioner. This is equivalent to setting the preconditioner for these dimensions to the identity matrix. (Default: [])
"""

@@ -125,6 +132,7 @@ class EigenvalueCorrectedShampooPreconditionerConfig(PreconditionerConfig):
amortized_computation_config (EigenvectorConfig): Configuration for the eigenvector computation.
(Default: DefaultEighEigenvectorConfig)
num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3)
ignored_dims (list[int]): List of dimensions to ignore when computing the preconditioner. This is equivalent to setting the preconditioner for these dimensions to the identity matrix. (Default: [])
"""

16 changes: 16 additions & 0 deletions distributed_shampoo/tests/distributed_shampoo_test.py
@@ -181,6 +181,22 @@ def test_invalid_distributed_config(self) -> None:
distributed_config=DDPShampooConfig(),
)

def test_ignored_dims_conflicts_with_inv_root_override(self) -> None:
inv_root_override = 2
preconditioner_config = ShampooPreconditionerConfig(
ignored_dims=[1, 3],
)
self.assertRaisesRegex(
ValueError,
re.escape(
f"{preconditioner_config.ignored_dims=} is not supported when {inv_root_override=} is not set to 0. Please set {inv_root_override=} to 0 if you set {preconditioner_config.ignored_dims=}."
),
DistributedShampoo,
params=self._model.parameters(),
inv_root_override=inv_root_override,
preconditioner_config=preconditioner_config,
)


class DistributedShampooTest(unittest.TestCase):
def setUp(self) -> None:
14 changes: 14 additions & 0 deletions distributed_shampoo/tests/shampoo_types_test.py
@@ -98,3 +98,17 @@ def test_illegal_num_tolerated_failed_amortized_computations(self) -> None:
cls,
num_tolerated_failed_amortized_computations=num_tolerated_failed_amortized_computations,
)

def test_illegal_ignored_dims(self) -> None:
ignored_dims = [1, 2, 3, 1]
# Not testing for the base class PreconditionerConfig because it is an abstract class.
for cls in get_all_subclasses(PreconditionerConfig, include_cls_self=False):
with self.subTest(cls=cls):
self.assertRaisesRegex(
ValueError,
re.escape(
f"Invalid ignored_dims value: {ignored_dims}. Must be a list of unique dimensions."
),
cls,
ignored_dims=ignored_dims,
)