From 90de41827bc6a1700738c2bb9fd822841691ecaf Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 20 Nov 2024 12:20:59 +0000 Subject: [PATCH] Fix mypy errors --- distributed_shampoo/examples/trainer_utils.py | 2 +- distributed_shampoo/gpu_tests/shampoo_grafting_test.py | 2 +- distributed_shampoo/shampoo_types.py | 1 - distributed_shampoo/tests/distributed_shampoo_test.py | 2 +- .../utils/gpu_tests/shampoo_fully_shard_distributor_test.py | 2 +- .../utils/tests/shampoo_preconditioner_list_test.py | 2 +- tests/matrix_functions_test.py | 2 +- 7 files changed, 6 insertions(+), 7 deletions(-) diff --git a/distributed_shampoo/examples/trainer_utils.py b/distributed_shampoo/examples/trainer_utils.py index 146e552..635a57b 100644 --- a/distributed_shampoo/examples/trainer_utils.py +++ b/distributed_shampoo/examples/trainer_utils.py @@ -529,7 +529,7 @@ def instantiate_grafting_config( epsilon=grafting_epsilon, ) elif grafting_type == GraftingType.SGD: - return SGDGraftingConfig( + return SGDGraftingConfig( # type: ignore[abstract] beta2=grafting_beta2, epsilon=grafting_epsilon, ) diff --git a/distributed_shampoo/gpu_tests/shampoo_grafting_test.py b/distributed_shampoo/gpu_tests/shampoo_grafting_test.py index 68f10b7..56b2c96 100644 --- a/distributed_shampoo/gpu_tests/shampoo_grafting_test.py +++ b/distributed_shampoo/gpu_tests/shampoo_grafting_test.py @@ -274,7 +274,7 @@ def test_sgd_grafting_on_quadratic(self) -> None: start_preconditioning_step=math.inf, use_nesterov=use_nesterov, use_decoupled_weight_decay=False, - grafting_config=SGDGraftingConfig(), + grafting_config=SGDGraftingConfig(), # type: ignore[abstract] ), device=device, ) diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 56b6bd2..e0dd6a4 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -246,7 +246,6 @@ class AdaGradGraftingConfig(GraftingConfig): epsilon: float = 1e-10 def __post_init__(self) -> None: - super().__init__() if not self.epsilon > 0.0: raise ValueError(f"Invalid epsilon value: {self.epsilon}. Must be > 0.0.") diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index 9a96fe5..3a41d77 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -60,7 +60,7 @@ def test_invalid_grafting_config(self) -> None: ): DistributedShampoo( self._model.parameters(), - grafting_config=SGDGraftingConfig(), + grafting_config=SGDGraftingConfig(), # type: ignore[abstract] ) def test_invalid_with_incorrect_hyperparameter_setting(self) -> None: diff --git a/distributed_shampoo/utils/gpu_tests/shampoo_fully_shard_distributor_test.py b/distributed_shampoo/utils/gpu_tests/shampoo_fully_shard_distributor_test.py index e707c6e..998dac5 100644 --- a/distributed_shampoo/utils/gpu_tests/shampoo_fully_shard_distributor_test.py +++ b/distributed_shampoo/utils/gpu_tests/shampoo_fully_shard_distributor_test.py @@ -216,7 +216,7 @@ def _model_factory( @skip_if_lt_x_gpu(2) def test_fully_shard_shampoo_against_default_shampoo(self) -> None: - fully_shard_config = FullyShardShampooConfig() + fully_shard_config = FullyShardShampooConfig() # type: ignore[abstract] ShampooFullyShardDistributorTest._test_two_configs( ShampooFullyShardDistributorTest._shampoo_optim_factory( None, diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 8592b81..e3e69d5 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -282,7 +282,7 @@ def test_abstract_methods(self) -> None: # Basic setup for instantiating BaseShampooPreconditionerList. params = (torch.tensor([1.0, 2.0]),) block_list = (params[0],) - state = {params[0]: {}} + state: dict[Tensor, dict] = {params[0]: {}} block_info_list = ( BlockInfo( param=params[0], diff --git a/tests/matrix_functions_test.py b/tests/matrix_functions_test.py index f830d02..4ea176d 100644 --- a/tests/matrix_functions_test.py +++ b/tests/matrix_functions_test.py @@ -302,7 +302,7 @@ def test_matrix_inverse_root_with_invalid_root_inv_config(self) -> None: matrix_inverse_root( A=A, root=root, - root_inv_config=InvalidRootInvConfig(), + root_inv_config=InvalidRootInvConfig(), # type: ignore[abstract] is_diagonal=False, )