From 13f68b7862c2d2cfe99564f25fd759d676ac0f66 Mon Sep 17 00:00:00 2001 From: Tsung-Hsien Lee Date: Wed, 13 Nov 2024 01:39:30 -0800 Subject: [PATCH] Open-sourced update on 11/13/2024 Summary: 1. Migrate `typing.{Callable,Iterable,Iterator,Mapping,Sequence}` to `collections.abc.{Callable,Iterable,Iterator,Mapping,Sequence}` because [PEP 585](https://peps.python.org/pep-0585/) deprecated those in `typing` in Python 3.9. 2. Add integration test for using QR algorithm in eigenvalue-corrected Shampoo. 3. Factor out `AbstractDataclass` into `commons.py`. Reviewed By: hjmshi Differential Revision: D65863644 fbshipit-source-id: 555df03be30a84ba55eb16b650fbcc572894437e --- commons.py | 19 ++++++ distributed_shampoo/README.md | 2 +- distributed_shampoo/distributed_shampoo.py | 8 +-- .../shampoo_eigenvalue_correction_test.py | 60 +++++++++++++------ .../gpu_tests/shampoo_grafting_test.py | 3 +- .../gpu_tests/shampoo_pt2_test.py | 2 +- distributed_shampoo/shampoo_types.py | 19 +----- .../tests/shampoo_types_test.py | 15 ----- .../gpu_tests/shampoo_ddp_distributor_test.py | 4 +- .../gpu_tests/shampoo_dist_utils_test.py | 5 +- .../shampoo_fsdp_distributor_test.py | 2 +- .../shampoo_fully_shard_distributor_test.py | 2 +- .../shampoo_hsdp_distributor_test.py | 2 +- .../utils/shampoo_block_info.py | 2 +- .../utils/shampoo_ddp_distributor.py | 4 +- .../utils/shampoo_distributor.py | 3 +- .../utils/shampoo_fully_shard_distributor.py | 2 +- .../utils/shampoo_hsdp_distributor.py | 4 +- .../utils/shampoo_preconditioner_list.py | 4 +- .../utils/shampoo_quantization.py | 8 +-- distributed_shampoo/utils/shampoo_utils.py | 7 ++- .../tests/shampoo_preconditioner_list_test.py | 8 ++- .../utils/tests/shampoo_quantization_test.py | 2 +- .../utils/tests/shampoo_utils_test.py | 13 +++- matrix_functions_types.py | 8 +-- optimizer_modules.py | 3 +- tests/commons_test.py | 32 ++++++++++ 27 files changed, 147 insertions(+), 96 deletions(-) create mode 100644 commons.py create mode 100644 tests/commons_test.py diff --git a/commons.py b/commons.py new file mode 100644 index 0000000..bf940e9 --- /dev/null +++ b/commons.py @@ -0,0 +1,19 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. + +""" + +from dataclasses import dataclass +from typing import Any + + +@dataclass +class AbstractDataclass: + def __new__(cls, *args: Any, **kwargs: Any) -> "AbstractDataclass": + if cls == AbstractDataclass or cls.__bases__[0] == AbstractDataclass: + raise TypeError(f"Cannot instantiate abstract class: {cls.__name__}.") + return super().__new__(cls) diff --git a/distributed_shampoo/README.md b/distributed_shampoo/README.md index 58cb57b..f0c1746 100644 --- a/distributed_shampoo/README.md +++ b/distributed_shampoo/README.md @@ -64,7 +64,7 @@ A few notes on hyperparameters: - We allow for decoupled and coupled weight decay. If one sets `use_decoupled_weight_decay=True`, then you are enabling AdamW-style weight decay, while `use_decoupled_weight_decay=False` corresponds to the normal L2-regularization style weight decay. -- When setting `preconditioner_computation_config` as an instance of EigenvalueCorrectionConfig, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. 
an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. +- When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectionConfig`, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. ### Example 1: [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html) with Momentum diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index 8fb67b3..b09a2f4 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -10,9 +10,10 @@ import contextlib import dataclasses import logging +from collections.abc import Callable, Iterator, Sequence from copy import deepcopy from functools import partial -from typing import Any, Callable, Iterator, NoReturn, Sequence +from typing import Any, NoReturn import torch @@ -215,7 +216,7 @@ class DistributedShampoo(torch.optim.Optimizer): updated every iteration while the eigenbasis of Shampoo's preconditioner is only computed every `precondition_frequency` steps. Alternatively, this can be seen as running Adam in the eigenbasis of Shampoo's preconditioner. - When setting `preconditioner_computation_config` as an instance of EigenvalueCorrectionConfig, there is typically no need to use learning + When setting `preconditioner_computation_config` as an instance of `EigenvalueCorrectionConfig`, there is typically no need to use learning rate grafting from Adam (`grafting_config=None`) and, when they are available, Adam's optimal `lr`, `betas`, and `weight_decay` should be a good starting point for further tuning. However, the case of `beta2=1.0`, i.e. an AdaGrad-like accumulation, has not been explored yet. Also, in settings where Shampoo would usually graft its learning rate from SGD, grafting might still be beneficial. @@ -233,7 +234,7 @@ class DistributedShampoo(torch.optim.Optimizer): momentum (float): Momentum parameter. (Default: 0.) dampening (float): Dampening parameter for momentum. (Default: 0.) weight_decay (float): Weight decay (L2 penalty). (Default: 0.) - max_preconditioner_dim (int): Maximum preconditioner dimensio. (Default: 1024) + max_preconditioner_dim (int): Maximum preconditioner dimension. (Default: 1024) precondition_frequency (int): Frequency of updating all components of the preconditioner. If this field is an instance RootInvConfig, this is the update frequency of the root inverse of the preconditioner. If this field is an instance EigenvalueCorrectionConfig, this is the update frequency of the eigenbasis of preconditioner. 
@@ -819,7 +820,6 @@ def _compute_and_log_root_inverse_residuals( SHAMPOO_PRECONDITIONER_LIST ].compute_root_inverse_residuals(), ) - quantiles = torch.as_tensor( [0, 0.25, 0.5, 0.75, 1], device=relative_errors.device, diff --git a/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py b/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py index 116dbbf..67dd271 100644 --- a/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py +++ b/distributed_shampoo/gpu_tests/shampoo_eigenvalue_correction_test.py @@ -11,14 +11,18 @@ import math import unittest +from collections.abc import Callable from functools import partial from itertools import product -from typing import Any, Callable, Type +from typing import Any, Type import torch from distributed_shampoo.distributed_shampoo import DistributedShampoo from distributed_shampoo.tests.shampoo_test_utils import construct_training_problem -from matrix_functions_types import DefaultEighEigenvalueCorrectionConfig +from matrix_functions_types import ( + DefaultEighEigenvalueCorrectionConfig, + QREigenvalueCorrectionConfig, +) from torch.nn.parameter import Parameter from torch.optim.adagrad import Adagrad from torch.optim.adam import Adam @@ -97,19 +101,24 @@ def _optim_factory( return optim_cls(parameters, **kwargs) def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: - # Test with and without weight decay, and with CPU or GPU - for weight_decay, device in product( + # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. + for weight_decay, device, preconditioner_computation_config in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), + (DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, lr=0.01, weight_decay=weight_decay, ) - with self.subTest(weight_decay=weight_decay, device=device): + with self.subTest( + weight_decay=weight_decay, + device=device, + preconditioner_computation_config=preconditioner_computation_config, + ): DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo( baseline_optim_factory=partial( optim_factory, optim_cls=Adagrad, eps=1e-15 @@ -125,18 +134,19 @@ def test_adagrad_eigenvalue_correction_on_quadratic(self) -> None: start_preconditioning_step=math.inf, use_decoupled_weight_decay=False, grafting_config=None, - preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + preconditioner_computation_config=preconditioner_computation_config, ), device=device, ) def test_adam_eigenvalue_correction_on_quadratic(self) -> None: - # Test with and without weight decay, and with CPU or GPU - for weight_decay, device in product( + # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. 
+ for weight_decay, device, preconditioner_computation_config in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), + (DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, @@ -144,7 +154,11 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None: betas=(0.9, 0.999), weight_decay=weight_decay, ) - with self.subTest(weight_decay=weight_decay, device=device): + with self.subTest( + weight_decay=weight_decay, + device=device, + preconditioner_computation_config=preconditioner_computation_config, + ): DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo( baseline_optim_factory=partial( optim_factory, @@ -161,18 +175,19 @@ def test_adam_eigenvalue_correction_on_quadratic(self) -> None: start_preconditioning_step=math.inf, use_decoupled_weight_decay=False, grafting_config=None, - preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + preconditioner_computation_config=preconditioner_computation_config, ), device=device, ) def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: - # Test with and without weight decay, and with CPU or GPU - for weight_decay, device in product( + # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. + for weight_decay, device, preconditioner_computation_config in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), + (DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, @@ -180,7 +195,11 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: betas=(0.9, 0.999), weight_decay=weight_decay, ) - with self.subTest(weight_decay=weight_decay, device=device): + with self.subTest( + weight_decay=weight_decay, + device=device, + preconditioner_computation_config=preconditioner_computation_config, + ): DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo( baseline_optim_factory=partial( optim_factory, @@ -197,25 +216,30 @@ def test_adamw_eigenvalue_correction_on_quadratic(self) -> None: start_preconditioning_step=math.inf, use_decoupled_weight_decay=True, grafting_config=None, - preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + preconditioner_computation_config=preconditioner_computation_config, ), device=device, ) def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None: - # Test with and without weight decay, and with CPU or GPU - for weight_decay, device in product( + # Test with and without weight decay, with CPU or GPU, and using eigendecomposition or QR algorithm. 
+ for weight_decay, device, preconditioner_computation_config in product( (0.0, 0.3), (torch.device("cpu"),) + (torch.device("cuda"),) if torch.cuda.is_available() else (), + (DefaultEighEigenvalueCorrectionConfig, QREigenvalueCorrectionConfig()), ): optim_factory = partial( DistributedShampooEigenvalueCorrectionTest._optim_factory, lr=0.01, weight_decay=weight_decay, ) - with self.subTest(weight_decay=weight_decay, device=device): + with self.subTest( + weight_decay=weight_decay, + device=device, + preconditioner_computation_config=preconditioner_computation_config, + ): DistributedShampooEigenvalueCorrectionTest._test_baseline_and_shampoo( baseline_optim_factory=partial( optim_factory, @@ -235,7 +259,7 @@ def test_rmsprop_eigenvalue_correction_on_quadratic(self) -> None: use_decoupled_weight_decay=False, grafting_config=None, use_bias_correction=False, - preconditioner_computation_config=DefaultEighEigenvalueCorrectionConfig, + preconditioner_computation_config=preconditioner_computation_config, ), device=device, ) diff --git a/distributed_shampoo/gpu_tests/shampoo_grafting_test.py b/distributed_shampoo/gpu_tests/shampoo_grafting_test.py index 229eee6..68f10b7 100644 --- a/distributed_shampoo/gpu_tests/shampoo_grafting_test.py +++ b/distributed_shampoo/gpu_tests/shampoo_grafting_test.py @@ -11,9 +11,10 @@ import math import unittest +from collections.abc import Callable from functools import partial from itertools import product -from typing import Any, Callable, Type +from typing import Any, Type import torch from distributed_shampoo.distributed_shampoo import DistributedShampoo diff --git a/distributed_shampoo/gpu_tests/shampoo_pt2_test.py b/distributed_shampoo/gpu_tests/shampoo_pt2_test.py index 2dd2ff7..bd78eea 100644 --- a/distributed_shampoo/gpu_tests/shampoo_pt2_test.py +++ b/distributed_shampoo/gpu_tests/shampoo_pt2_test.py @@ -11,8 +11,8 @@ import itertools import unittest +from collections.abc import Callable from functools import partial -from typing import Callable import torch from distributed_shampoo.distributed_shampoo import DistributedShampoo diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index bc25349..62f7168 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -9,9 +9,10 @@ import enum from dataclasses import dataclass -from typing import Any import torch + +from commons import AbstractDataclass from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import ShardingStrategy from torch.nn.parameter import Parameter @@ -128,20 +129,10 @@ class PrecisionConfig: grafting_state_dtype: torch.dtype = torch.float32 -@dataclass -class AbstractDataclass: - def __new__(cls, *args: Any, **kwargs: Any) -> "AbstractDataclass": - if cls == AbstractDataclass or cls.__bases__[0] == AbstractDataclass: - raise TypeError(f"Cannot instantiate abstract class: {cls.__name__}.") - return super().__new__(cls) - - @dataclass class DistributedConfig(AbstractDataclass): """Abstract dataclass for distributed configs in Shampoo.""" - ... - @dataclass(kw_only=True) class DDPShampooConfig(DistributedConfig): @@ -184,8 +175,6 @@ class FullyShardShampooConfig(DistributedConfig): Currently only a placeholder used for Shampoo optimizer to select FullyShardDistributor. 
""" - pass - @dataclass class HSDPShampooConfig(FSDPShampooConfig, DDPShampooConfig): @@ -238,15 +227,11 @@ class ShampooPT2CompileConfig: class GraftingConfig(AbstractDataclass): """Abstract dataclass for grafting configurations in Shampoo.""" - ... - @dataclass class SGDGraftingConfig(GraftingConfig): """Configuration for grafting from SGD.""" - ... - @dataclass(kw_only=True) class AdaGradGraftingConfig(GraftingConfig): diff --git a/distributed_shampoo/tests/shampoo_types_test.py b/distributed_shampoo/tests/shampoo_types_test.py index 3ee82e3..49d8a52 100644 --- a/distributed_shampoo/tests/shampoo_types_test.py +++ b/distributed_shampoo/tests/shampoo_types_test.py @@ -12,27 +12,12 @@ from typing import Type from distributed_shampoo.shampoo_types import ( - AbstractDataclass, AdaGradGraftingConfig, AdamGraftingConfig, - DistributedConfig, - GraftingConfig, RMSpropGraftingConfig, ) -class InvalidAbstractDataclassInitTest(unittest.TestCase): - def test_invalid_init(self) -> None: - for abstract_cls in (AbstractDataclass, DistributedConfig, GraftingConfig): - with self.subTest(abstract_cls=abstract_cls), self.assertRaisesRegex( - TypeError, - re.escape( - f"Cannot instantiate abstract class: {abstract_cls.__name__}." - ), - ): - abstract_cls() - - class AdaGradGraftingConfigTest(unittest.TestCase): def test_illegal_epsilon(self) -> None: epsilon = 0.0 diff --git a/distributed_shampoo/utils/gpu_tests/shampoo_ddp_distributor_test.py b/distributed_shampoo/utils/gpu_tests/shampoo_ddp_distributor_test.py index 90c3ce7..b8079a7 100644 --- a/distributed_shampoo/utils/gpu_tests/shampoo_ddp_distributor_test.py +++ b/distributed_shampoo/utils/gpu_tests/shampoo_ddp_distributor_test.py @@ -13,9 +13,9 @@ import contextlib import re import unittest -from itertools import product -from typing import Callable +from collections.abc import Callable +from itertools import product from unittest import mock import torch diff --git a/distributed_shampoo/utils/gpu_tests/shampoo_dist_utils_test.py b/distributed_shampoo/utils/gpu_tests/shampoo_dist_utils_test.py index 3f46f48..3f4f80e 100644 --- a/distributed_shampoo/utils/gpu_tests/shampoo_dist_utils_test.py +++ b/distributed_shampoo/utils/gpu_tests/shampoo_dist_utils_test.py @@ -46,8 +46,7 @@ def _verify_deivce_mesh(self, device_mesh: DeviceMesh) -> None: def test_get_device_mesh(self) -> None: mesh = tuple( map( - # type: ignore - tuple, + tuple, # type: ignore torch.tensor(range(self.world_size)) .view(-1, self.world_size // 2) .tolist(), @@ -70,7 +69,7 @@ def test_get_device_mesh(self) -> None: "__init__", ) as mock_device_mesh_init: device_mesh = get_device_mesh( - device_type=self.device_type, # type: ignore[attr-defined] + device_type=self.device_type, # type: ignore mesh=mesh, mesh_dim_names=("replicate", "shard"), ) diff --git a/distributed_shampoo/utils/gpu_tests/shampoo_fsdp_distributor_test.py b/distributed_shampoo/utils/gpu_tests/shampoo_fsdp_distributor_test.py index 3c6c0b4..f3b239e 100644 --- a/distributed_shampoo/utils/gpu_tests/shampoo_fsdp_distributor_test.py +++ b/distributed_shampoo/utils/gpu_tests/shampoo_fsdp_distributor_test.py @@ -10,9 +10,9 @@ #!/usr/bin/env python3 import unittest +from collections.abc import Callable from functools import partial from itertools import pairwise -from typing import Callable import torch from distributed_shampoo.distributed_shampoo import DistributedShampoo diff --git a/distributed_shampoo/utils/gpu_tests/shampoo_fully_shard_distributor_test.py 
b/distributed_shampoo/utils/gpu_tests/shampoo_fully_shard_distributor_test.py index 5bd15bd..e707c6e 100644 --- a/distributed_shampoo/utils/gpu_tests/shampoo_fully_shard_distributor_test.py +++ b/distributed_shampoo/utils/gpu_tests/shampoo_fully_shard_distributor_test.py @@ -10,8 +10,8 @@ #!/usr/bin/env python3 import unittest +from collections.abc import Callable from functools import partial -from typing import Callable import torch from distributed_shampoo.distributed_shampoo import DistributedShampoo diff --git a/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py b/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py index fe3c212..f9fd65b 100644 --- a/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py +++ b/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py @@ -11,9 +11,9 @@ import re import unittest +from collections.abc import Callable from functools import partial from itertools import pairwise, product -from typing import Callable from unittest import mock import torch diff --git a/distributed_shampoo/utils/shampoo_block_info.py b/distributed_shampoo/utils/shampoo_block_info.py index 82e7e22..2a11708 100644 --- a/distributed_shampoo/utils/shampoo_block_info.py +++ b/distributed_shampoo/utils/shampoo_block_info.py @@ -7,8 +7,8 @@ """ +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable import torch from torch import Tensor diff --git a/distributed_shampoo/utils/shampoo_ddp_distributor.py b/distributed_shampoo/utils/shampoo_ddp_distributor.py index ca6bea3..ad9408a 100644 --- a/distributed_shampoo/utils/shampoo_ddp_distributor.py +++ b/distributed_shampoo/utils/shampoo_ddp_distributor.py @@ -78,9 +78,9 @@ def __init__( # Determine communication type. if distributed_config.communication_dtype == CommunicationDType.BF16: - communication_dtype = torch.bfloat16 + communication_dtype: torch.dtype = torch.bfloat16 elif distributed_config.communication_dtype == CommunicationDType.FP16: - communication_dtype = torch.float16 + communication_dtype: torch.dtype = torch.float16 else: assert distributed_config.communication_dtype in [ CommunicationDType.FP32, diff --git a/distributed_shampoo/utils/shampoo_distributor.py b/distributed_shampoo/utils/shampoo_distributor.py index ce11065..8762b84 100644 --- a/distributed_shampoo/utils/shampoo_distributor.py +++ b/distributed_shampoo/utils/shampoo_distributor.py @@ -8,8 +8,9 @@ """ from abc import ABC, abstractmethod +from collections.abc import Iterable from operator import attrgetter -from typing import Any, Iterable +from typing import Any import torch from distributed_shampoo.shampoo_types import ( diff --git a/distributed_shampoo/utils/shampoo_fully_shard_distributor.py b/distributed_shampoo/utils/shampoo_fully_shard_distributor.py index 38b6a31..130cce0 100644 --- a/distributed_shampoo/utils/shampoo_fully_shard_distributor.py +++ b/distributed_shampoo/utils/shampoo_fully_shard_distributor.py @@ -7,7 +7,7 @@ """ -from typing import Iterable +from collections.abc import Iterable from distributed_shampoo.shampoo_types import PARAMS from distributed_shampoo.utils.shampoo_block_info import BlockInfo diff --git a/distributed_shampoo/utils/shampoo_hsdp_distributor.py b/distributed_shampoo/utils/shampoo_hsdp_distributor.py index 433dcfc..c2b48e7 100644 --- a/distributed_shampoo/utils/shampoo_hsdp_distributor.py +++ b/distributed_shampoo/utils/shampoo_hsdp_distributor.py @@ -152,9 +152,9 @@ def __init__( # Determine communication type. 
if distributed_config.communication_dtype == CommunicationDType.BF16: - communication_dtype = torch.bfloat16 + communication_dtype: torch.dtype = torch.bfloat16 elif distributed_config.communication_dtype == CommunicationDType.FP16: - communication_dtype = torch.float16 + communication_dtype: torch.dtype = torch.float16 else: assert distributed_config.communication_dtype in [ CommunicationDType.FP32, diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index 2d5f5bd..ce71318 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -9,14 +9,14 @@ import logging from abc import ABC, abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass, field from fractions import Fraction from functools import partial, reduce from itertools import chain from operator import methodcaller -from typing import Any, cast, Generic, Mapping, Sequence, TypeVar +from typing import Any, cast, Generic, TypeVar import torch from distributed_shampoo.shampoo_types import PrecisionConfig, PreconditionerValueError diff --git a/distributed_shampoo/utils/shampoo_quantization.py b/distributed_shampoo/utils/shampoo_quantization.py index 3b2dfa4..a9a0045 100644 --- a/distributed_shampoo/utils/shampoo_quantization.py +++ b/distributed_shampoo/utils/shampoo_quantization.py @@ -9,8 +9,8 @@ import logging import typing +from collections.abc import Sequence from operator import methodcaller -from typing import Sequence import torch from distributed_shampoo.utils.shampoo_block_info import BlockInfo @@ -118,10 +118,8 @@ def _convert_float_to_float(src: torch.Tensor, dest: torch.Tensor) -> None: class QuantizedTensorList: def __init__( self, - quantized_data: ( - Sequence[tuple[Tensor, Tensor | None, Tensor | None]] - | Sequence[QuantizedTensor] - ), + quantized_data: Sequence[tuple[Tensor, Tensor | None, Tensor | None]] + | Sequence[QuantizedTensor], quantized_dtype: torch.dtype, computation_dtype: torch.dtype = torch.float32, ) -> None: diff --git a/distributed_shampoo/utils/shampoo_utils.py b/distributed_shampoo/utils/shampoo_utils.py index 5f44810..5162631 100644 --- a/distributed_shampoo/utils/shampoo_utils.py +++ b/distributed_shampoo/utils/shampoo_utils.py @@ -8,10 +8,11 @@ """ import math +from collections.abc import Callable, Iterator, Sequence from functools import partial from itertools import accumulate, chain, compress, pairwise from types import TracebackType -from typing import Callable, Iterator, Sequence, Type, TypeVar +from typing import Type, TypeVar import torch from torch import Tensor @@ -25,7 +26,7 @@ def merge_small_dims(tensor_shape: Sequence[int], threshold: int) -> tuple[int, threshold (int): Threshold on the maximum size of each dimension. Returns: - new_tensor_shape (list[int]): New tensor shape. + new_tensor_shape (tuple[int, ...]): New tensor shape. """ @@ -49,7 +50,7 @@ def multi_dim_split(tensor: Tensor, split_size: int) -> tuple[Tensor, ...]: split_size (int): Size of a single chunk. Returns: - split_grad (list[Tensor]): List of tensors. + split_tensors (tuple[Tensor, ...]): List of tensors. """ split_tensors: tuple[Tensor, ...] 
= (tensor,) diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 1cd4c0b..cdb0b61 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -458,7 +458,9 @@ class ShampooPreconditionerListTest(AbstractTest.BaseShampooPreconditionerListTe def _amortized_computation_function(self) -> str: return "matrix_inverse_root" - def _instantiate_preconditioner_list(self, **kwargs: Any) -> PreconditionerList: # type: ignore[override] + def _instantiate_preconditioner_list( # type: ignore[override] + self, **kwargs: Any + ) -> ShampooPreconditionerList: kwargs = { "beta2": 1.0, "epsilon": 0.0, @@ -741,7 +743,9 @@ class EigenvalueCorrectedShampooPreconditionerListTest( def _amortized_computation_function(self) -> str: return "matrix_eigenvectors" - def _instantiate_preconditioner_list(self, **kwargs: Any) -> PreconditionerList: # type: ignore[override] + def _instantiate_preconditioner_list( # type: ignore[override] + self, **kwargs: Any + ) -> EigenvalueCorrectedShampooPreconditionerList: kwargs = { "beta2": 1.0, "epsilon": 1e-12, diff --git a/distributed_shampoo/utils/tests/shampoo_quantization_test.py b/distributed_shampoo/utils/tests/shampoo_quantization_test.py index 6af98d7..96a06b4 100644 --- a/distributed_shampoo/utils/tests/shampoo_quantization_test.py +++ b/distributed_shampoo/utils/tests/shampoo_quantization_test.py @@ -115,7 +115,7 @@ def test_invalid_quantized_data_type(self) -> None: ), self.assertRaisesRegex( TypeError, re.escape( - "quantized_data must be typing.Union[typing.Sequence[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]], typing.Sequence[distributed_shampoo.utils.shampoo_quantization.QuantizedTensor]] but get " + "quantized_data must be collections.abc.Sequence[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]] | collections.abc.Sequence[distributed_shampoo.utils.shampoo_quantization.QuantizedTensor] but get " ), ): QuantizedTensorList( diff --git a/distributed_shampoo/utils/tests/shampoo_utils_test.py b/distributed_shampoo/utils/tests/shampoo_utils_test.py index 2e08eeb..aa26a1b 100644 --- a/distributed_shampoo/utils/tests/shampoo_utils_test.py +++ b/distributed_shampoo/utils/tests/shampoo_utils_test.py @@ -53,9 +53,16 @@ def test_merge_small_dims_all_ones(self) -> None: def test_merge_small_dims_empty(self) -> None: merged_dims = (0,) threshold = 10 - self.assertEqual(merge_small_dims((0,), threshold), merged_dims) - self.assertEqual(merge_small_dims((0, 1), threshold), merged_dims) - self.assertEqual(merge_small_dims((0, 1, 5, 10, 20), threshold), merged_dims) + self.assertEqual( + merge_small_dims(tensor_shape=(0,), threshold=threshold), merged_dims + ) + self.assertEqual( + merge_small_dims(tensor_shape=(0, 1), threshold=threshold), merged_dims + ) + self.assertEqual( + merge_small_dims(tensor_shape=(0, 1, 5, 10, 20), threshold=threshold), + merged_dims, + ) class MultiDimSplitTest(unittest.TestCase): diff --git a/matrix_functions_types.py b/matrix_functions_types.py index 8801dd1..ecfac8c 100644 --- a/matrix_functions_types.py +++ b/matrix_functions_types.py @@ -9,22 +9,18 @@ from dataclasses import dataclass -from distributed_shampoo.shampoo_types import AbstractDataclass +from commons import AbstractDataclass @dataclass class PreconditionerComputationConfig(AbstractDataclass): """Configuration for preconditioner computation in Shampoo.""" - ... 
- @dataclass class RootInvConfig(PreconditionerComputationConfig): """Base dataclass for matrix root inverse method configurations in Shampoo.""" - ... - @dataclass(kw_only=True) class EigenConfig(RootInvConfig): @@ -88,8 +84,6 @@ class CoupledHigherOrderConfig(RootInvConfig): class EigenvalueCorrectionConfig(PreconditionerComputationConfig): """Base dataclass for matrix eigenvector method configurations in eigenvalue-corrected Shampoo.""" - ... - @dataclass(kw_only=True) class EighEigenvalueCorrectionConfig(EigenvalueCorrectionConfig): diff --git a/optimizer_modules.py b/optimizer_modules.py index e148ac8..2492a3f 100644 --- a/optimizer_modules.py +++ b/optimizer_modules.py @@ -8,8 +8,9 @@ """ import logging +from collections.abc import Iterable from copy import deepcopy -from typing import Any, Iterable +from typing import Any import torch from torch.optim.optimizer import StateDict diff --git a/tests/commons_test.py b/tests/commons_test.py new file mode 100644 index 0000000..a3ccc62 --- /dev/null +++ b/tests/commons_test.py @@ -0,0 +1,32 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. + +""" + +import re +import unittest + +from dataclasses import dataclass + +from commons import AbstractDataclass + + +@dataclass +class DummyOptimizerConfig(AbstractDataclass): + """Dummy abstract dataclass for testing. Instantiation should fail.""" + + +class InvalidAbstractDataclassInitTest(unittest.TestCase): + def test_invalid_init(self) -> None: + for abstract_cls in (AbstractDataclass, DummyOptimizerConfig): + with self.subTest(abstract_cls=abstract_cls), self.assertRaisesRegex( + TypeError, + re.escape( + f"Cannot instantiate abstract class: {abstract_cls.__name__}." + ), + ): + abstract_cls()
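The short sketch below is not part of the diff; it is an illustrative example of how the factored-out `commons.AbstractDataclass` (item 3 of the summary) is meant to be used downstream, mirroring the `DistributedConfig` / `DDPShampooConfig` pattern in `shampoo_types.py`: a direct subclass of `AbstractDataclass` stays abstract, and only its own subclasses can be instantiated. The names `ExampleConfig`, `ConcreteExampleConfig`, and `max_dim` are hypothetical.

    from dataclasses import dataclass

    from commons import AbstractDataclass


    @dataclass
    class ExampleConfig(AbstractDataclass):
        """Hypothetical abstract config: direct subclasses of AbstractDataclass
        cannot be instantiated (their first base is AbstractDataclass)."""


    @dataclass(kw_only=True)
    class ConcreteExampleConfig(ExampleConfig):
        """Hypothetical concrete config: second-level subclasses pass the
        __new__ check in commons.AbstractDataclass and can be constructed."""

        max_dim: int = 1024


    # Both of these raise TypeError("Cannot instantiate abstract class: ..."),
    # matching the behavior exercised in tests/commons_test.py:
    #   AbstractDataclass()
    #   ExampleConfig()

    # Only the leaf-level config is constructible.
    config = ConcreteExampleConfig(max_dim=2048)

With this patch, `shampoo_types.py` and `matrix_functions_types.py` import `AbstractDataclass` from `commons` instead of defining it locally, so the same instantiation guard applies to `DistributedConfig`, `GraftingConfig`, and `PreconditionerComputationConfig`.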