diff --git a/.github/workflows/format-ruff.yaml b/.github/workflows/format-ruff.yaml deleted file mode 100644 index 0d1d6aa..0000000 --- a/.github/workflows/format-ruff.yaml +++ /dev/null @@ -1,12 +0,0 @@ -name: format-ruff - -on: [push, pull_request] - -jobs: - ruff: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/ruff-action@v1 - with: - args: "format --check" diff --git a/.github/workflows/format-usort.yaml b/.github/workflows/format.yaml similarity index 69% rename from .github/workflows/format-usort.yaml rename to .github/workflows/format.yaml index 31b29a2..5caa8ca 100644 --- a/.github/workflows/format-usort.yaml +++ b/.github/workflows/format.yaml @@ -1,4 +1,4 @@ -name: format-usort +name: format on: [push, pull_request] @@ -18,3 +18,10 @@ jobs: - name: Run usort check. run: | usort check . + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v1 + with: + args: "format --check" diff --git a/.github/workflows/lint-ruff.yaml b/.github/workflows/lint.yaml similarity index 90% rename from .github/workflows/lint-ruff.yaml rename to .github/workflows/lint.yaml index 8713a69..05ea7a3 100644 --- a/.github/workflows/lint-ruff.yaml +++ b/.github/workflows/lint.yaml @@ -1,4 +1,4 @@ -name: lint-ruff +name: lint on: [push, pull_request] diff --git a/.github/workflows/type-check-mypy.yaml b/.github/workflows/type-check.yaml similarity index 96% rename from .github/workflows/type-check-mypy.yaml rename to .github/workflows/type-check.yaml index 020b39c..637411e 100644 --- a/.github/workflows/type-check-mypy.yaml +++ b/.github/workflows/type-check.yaml @@ -1,4 +1,4 @@ -name: type-check-mypy +name: type-check on: [push, pull_request] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..8fded0b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,16 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.8.0 + hooks: + # Run the linter. + - id: ruff + types_or: [ python, pyi ] + args: [ --fix ] + # Run the formatter. + - id: ruff-format + types_or: [ python, pyi ] +- repo: https://github.com/facebook/usort + rev: v1.0.8 + hooks: + - id: usort diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9697b39..4349f5e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,11 +8,11 @@ optimizers should be created in a separate public repo. ## Pull Requests We actively welcome your pull requests for existing optimizers. -1. Fork the repo and create your branch from `main`. Install the package inside of your Python environment with `pip install -e ".[dev]"`. +1. Fork the repo and create your branch from `main`. Install the package inside of your Python environment with `pip install -e ".[dev]"`. Run `pre-commit install` to set up the git hook scripts. 2. If you've added code that should be tested, add tests. 3. If you've changed APIs, update the documentation. 4. Ensure the test suite passes. To run the subset of the tests that can be run on CPU use `make test`; to run the tests for a single GPU use `make test-gpu` and to run the subset of tests that require 2-4 GPUs use `make test-multi-gpu`. -5. Make sure your code lints. You can use `make lint` and `make format` to automatically lint and format the code where possible. Use `make type-check` for type checking. +5. Make sure your code lints. 
You can use `make lint` and `make format` to automatically lint and format the code where possible (these also run automatically on `git commit` if the pre-commit hooks are installed, though this does not guarantee that no linting errors remain). Use `make type-check` for type checking. 6. If you haven't already, complete the Contributor License Agreement ("CLA"). > [!NOTE] diff --git a/README.md b/README.md index 81ef658..b01fde4 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,9 @@ 3.10 | 3.11 | 3.12](https://img.shields.io/badge/python-3.10_|_3.11_|_3.12-blue.svg)](https://www.python.org/downloads/) ![tests](https://github.com/facebookresearch/optimizers/actions/workflows/tests.yaml/badge.svg) ![gpu-tests](https://github.com/facebookresearch/optimizers/actions/workflows/gpu-tests.yaml/badge.svg) -![lint-ruff](https://github.com/facebookresearch/optimizers/actions/workflows/lint-ruff.yaml/badge.svg) -![format-ruff](https://github.com/facebookresearch/optimizers/actions/workflows/format-ruff.yaml/badge.svg) -![format-usort](https://github.com/facebookresearch/optimizers/actions/workflows/format-usort.yaml/badge.svg) -![type-check-mypy](https://github.com/facebookresearch/optimizers/actions/workflows/type-check-mypy.yaml/badge.svg) +![linting](https://github.com/facebookresearch/optimizers/actions/workflows/lint.yaml/badge.svg) +![formatting](https://github.com/facebookresearch/optimizers/actions/workflows/format.yaml/badge.svg) +![type-checking](https://github.com/facebookresearch/optimizers/actions/workflows/type-check.yaml/badge.svg) ![examples](https://github.com/facebookresearch/optimizers/actions/workflows/examples.yaml/badge.svg) *Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index e0b5337..d222b1f 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -1044,14 +1044,17 @@ def _per_group_step_impl( use_decoupled_weight_decay, ) - with DequantizePreconditionersContext( - preconditioner_list=state_lists[SHAMPOO_PRECONDITIONER_LIST] - ), ( + with ( DequantizePreconditionersContext( - preconditioner_list=state_lists[GRAFTING_PRECONDITIONER_LIST] - ) - if grafting_config_not_none - else contextlib.nullcontext() + preconditioner_list=state_lists[SHAMPOO_PRECONDITIONER_LIST] + ), + ( + DequantizePreconditionersContext( + preconditioner_list=state_lists[GRAFTING_PRECONDITIONER_LIST] + ) + if grafting_config_not_none + else contextlib.nullcontext() + ), ): # Update Shampoo and grafting preconditioners.
# Example for AdaGrad accumulation: diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index c191f40..396db52 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -47,12 +47,15 @@ def setUp(self) -> None: ) def test_invalid_grafting_config(self) -> None: - with mock.patch.object( - distributed_shampoo, "type", side_effect=lambda object: GraftingConfig - ), self.assertRaisesRegex( - NotImplementedError, - re.escape( - "Unsupported grafting config: group[GRAFTING_CONFIG]=SGDGraftingConfig" + with ( + mock.patch.object( + distributed_shampoo, "type", side_effect=lambda object: GraftingConfig + ), + self.assertRaisesRegex( + NotImplementedError, + re.escape( + "Unsupported grafting config: group[GRAFTING_CONFIG]=SGDGraftingConfig" + ), ), ): DistributedShampoo( @@ -126,22 +129,26 @@ def test_invalid_with_incorrect_hyperparameter_setting(self) -> None: incorrect_hyperparameter_setting, expected_error_msg, ) in incorrect_hyperparameter_setting_and_expected_error_msg: - with self.subTest( - incorrect_hyperparameter_setting=incorrect_hyperparameter_setting, - expected_error_msg=expected_error_msg, - ), self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)): + with ( + self.subTest( + incorrect_hyperparameter_setting=incorrect_hyperparameter_setting, + expected_error_msg=expected_error_msg, + ), + self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)), + ): DistributedShampoo( self._model.parameters(), **incorrect_hyperparameter_setting, ) def test_invalid_pytorch_compile_setting(self) -> None: - with mock.patch.object( - torch.cuda, "is_available", return_value=False - ), self.assertRaisesRegex( - ValueError, - re.escape( - "Both use_pytorch_compile and shampoo_pt2_compile_config are provided. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating." + with ( + mock.patch.object(torch.cuda, "is_available", return_value=False), + self.assertRaisesRegex( + ValueError, + re.escape( + "Both use_pytorch_compile and shampoo_pt2_compile_config are provided. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating." + ), ), ): DistributedShampoo( @@ -150,12 +157,13 @@ def test_invalid_pytorch_compile_setting(self) -> None: shampoo_pt2_compile_config=ShampooPT2CompileConfig(), ) - with mock.patch.object( - torch.cuda, "is_available", return_value=False - ), self.assertRaisesRegex( - ValueError, - re.escape( - "use_pytorch_compile=False conflicts with non-None shampoo_pt2_compile_config arg. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating." + with ( + mock.patch.object(torch.cuda, "is_available", return_value=False), + self.assertRaisesRegex( + ValueError, + re.escape( + "use_pytorch_compile=False conflicts with non-None shampoo_pt2_compile_config arg. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating." 
+ ), ), ): DistributedShampoo( @@ -165,11 +173,12 @@ def test_invalid_pytorch_compile_setting(self) -> None: ) def test_warning_pytorch_compile_setting(self) -> None: - with mock.patch.object( - torch.cuda, "is_available", return_value=True - ), self.assertLogs( - level="WARNING", - ) as cm: + with ( + mock.patch.object(torch.cuda, "is_available", return_value=True), + self.assertLogs( + level="WARNING", + ) as cm, + ): DistributedShampoo( self._model.parameters(), lr=0.01, @@ -183,12 +192,13 @@ def test_warning_pytorch_compile_setting(self) -> None: ) def test_invalid_cuda_pytorch_compile_setting(self) -> None: - with mock.patch.object( - torch.cuda, "is_available", return_value=False - ), self.assertRaisesRegex( - ValueError, - re.escape( - "Backend does NOT support Pytorch 2.0 compile. Switch to use_pytorch_compile in (False, None) and shampoo_pt2_compile_config=None." + with ( + mock.patch.object(torch.cuda, "is_available", return_value=False), + self.assertRaisesRegex( + ValueError, + re.escape( + "Backend does NOT support Pytorch 2.0 compile. Switch to use_pytorch_compile in (False, None) and shampoo_pt2_compile_config=None." + ), ), ): DistributedShampoo( @@ -213,16 +223,19 @@ def test_nesterov_and_zero_momentum(self) -> None: ) def test_invalid_distributed_config(self) -> None: - with self.assertRaisesRegex( - NotImplementedError, - re.escape( - "distributed_config=DDPShampooConfig(communication_dtype=, " - "num_trainers_per_group=-1, communicate_params=False) not supported!" + with ( + self.assertRaisesRegex( + NotImplementedError, + re.escape( + "distributed_config=DDPShampooConfig(communication_dtype=, " + "num_trainers_per_group=-1, communicate_params=False) not supported!" + ), + ), + mock.patch.object( + distributed_shampoo, + "type", + side_effect=lambda object: DistributedConfig, ), - ), mock.patch.object( - distributed_shampoo, - "type", - side_effect=lambda object: DistributedConfig, ): DistributedShampoo( params=self._model.parameters(), diff --git a/distributed_shampoo/tests/shampoo_types_test.py b/distributed_shampoo/tests/shampoo_types_test.py index 49d8a52..2ec773f 100644 --- a/distributed_shampoo/tests/shampoo_types_test.py +++ b/distributed_shampoo/tests/shampoo_types_test.py @@ -22,11 +22,12 @@ class AdaGradGraftingConfigTest(unittest.TestCase): def test_illegal_epsilon(self) -> None: epsilon = 0.0 grafting_config_type = self._get_grafting_config_type() - with self.subTest( - grafting_config_type=grafting_config_type - ), self.assertRaisesRegex( - ValueError, - re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."), + with ( + self.subTest(grafting_config_type=grafting_config_type), + self.assertRaisesRegex( + ValueError, + re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."), + ), ): grafting_config_type(epsilon=epsilon) @@ -46,12 +47,13 @@ def test_illegal_beta2( ) -> None: grafting_config_type = self._get_grafting_config_type() for beta2 in (-1.0, 0.0, 1.3): - with self.subTest( - grafting_config_type=grafting_config_type, beta2=beta2 - ), self.assertRaisesRegex( - ValueError, - re.escape( - f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]." + with ( + self.subTest(grafting_config_type=grafting_config_type, beta2=beta2), + self.assertRaisesRegex( + ValueError, + re.escape( + f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]." 
+ ), ), ): grafting_config_type(beta2=beta2) diff --git a/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py b/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py index f9fd65b..fa58208 100644 --- a/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py +++ b/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py @@ -313,11 +313,12 @@ def test_dist_is_initialized(self) -> None: device_mesh=mesh_2d, ) - with mock.patch.object( - torch.distributed, "is_initialized", return_value=False - ), self.assertRaisesRegex( - RuntimeError, - re.escape("HSDPDistributor needs torch.distributed to be initialized!"), + with ( + mock.patch.object(torch.distributed, "is_initialized", return_value=False), + self.assertRaisesRegex( + RuntimeError, + re.escape("HSDPDistributor needs torch.distributed to be initialized!"), + ), ): ShampooHSDPDistributorTest._train_model( optim_factory=ShampooHSDPDistributorTest._shampoo_optim_factory( @@ -339,12 +340,15 @@ def test_incompatible_replicated_group_size_and_num_trainers_per_group( ) # Hijack the DeviceMesh.size() method to return 4 instead of 2 to bypass the check of num_trainers_per_group. - with mock.patch.object( - torch.distributed.device_mesh.DeviceMesh, "size", return_value=4 - ), self.assertRaisesRegex( - ValueError, - re.escape( - "distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!" + with ( + mock.patch.object( + torch.distributed.device_mesh.DeviceMesh, "size", return_value=4 + ), + self.assertRaisesRegex( + ValueError, + re.escape( + "distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!" + ), ), ): ShampooHSDPDistributorTest._train_model( diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index cce8aca..6d38af6 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -125,13 +125,16 @@ def _test_compress_preconditioner_list( self, expected_compress_list_call_count: int, ) -> None: - with mock.patch.object( - shampoo_preconditioner_list, - "compress_list", - ) as mock_compress_list, mock.patch.object( - QuantizedTensorList, - "compress", - ) as mock_compress_quant_list: + with ( + mock.patch.object( + shampoo_preconditioner_list, + "compress_list", + ) as mock_compress_list, + mock.patch.object( + QuantizedTensorList, + "compress", + ) as mock_compress_quant_list, + ): # Count the number of list compressions at the preconditioner list level, including compressions of QuantizedTensorList. # Each call to compress() under QuantizedTensorList counts once, though note that it calls compress_list() three times inside. 
self.assertIsNone( @@ -328,11 +331,16 @@ def _instantiate_preconditioner_list( # type: ignore[override] def _test_raise_invalid_value_in_factor_matrix( self, invalid_value: float ) -> None: - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), self.assertRaisesRegex( - PreconditionerValueError, - re.escape(f"Encountered {str(invalid_value)} values in factor matrix"), + with ( + DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), + self.assertRaisesRegex( + PreconditionerValueError, + re.escape( + f"Encountered {str(invalid_value)} values in factor matrix" + ), + ), ): self._preconditioner_list.update_preconditioners( masked_grad_list=( @@ -356,16 +364,21 @@ def test_raise_nan_and_inf_in_inv_factor_matrix_amortized_computation( self, ) -> None: for invalid_value in (torch.nan, torch.inf): - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), self.subTest(invalid_value=invalid_value), self.assertRaisesRegex( - PreconditionerValueError, - re.escape("Encountered nan or inf values in"), - ), mock.patch.object( - shampoo_preconditioner_list, - self._amortized_computation_function(), - side_effect=(torch.tensor([invalid_value]),), - ) as mock_amortized_computation: + with ( + DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), + self.subTest(invalid_value=invalid_value), + self.assertRaisesRegex( + PreconditionerValueError, + re.escape("Encountered nan or inf values in"), + ), + mock.patch.object( + shampoo_preconditioner_list, + self._amortized_computation_function(), + side_effect=(torch.tensor([invalid_value]),), + ) as mock_amortized_computation, + ): self._preconditioner_list.update_preconditioners( masked_grad_list=( torch.tensor([1.0, 0.0]), @@ -384,9 +397,12 @@ def test_amortized_computation_internal_failure(self) -> None: # Simulate the situation throws an exception (not nan and inf) to test the warning side_effect=ZeroDivisionError, ) as mock_amortized_computation: - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), self.assertLogs(level="WARNING") as cm: + with ( + DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), + self.assertLogs(level="WARNING") as cm, + ): # Because use_protected_eigh is True, we expect the warning to be logged. 
self._preconditioner_list.update_preconditioners( masked_grad_list=( @@ -415,9 +431,12 @@ def test_amortized_computation_internal_failure(self) -> None: self._preconditioner_list = self._instantiate_preconditioner_list( use_protected_eigh=False, ) - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), self.assertRaises(ZeroDivisionError): + with ( + DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), + self.assertRaises(ZeroDivisionError), + ): self._preconditioner_list.update_preconditioners( masked_grad_list=( torch.tensor([1.0, 0.0]), @@ -443,11 +462,14 @@ def test_amortized_computation_factor_matrix_non_diagonal( self._preconditioner_list = self._instantiate_preconditioner_list( epsilon=1.0 ) - with DequantizePreconditionersContext( - preconditioner_list=self._preconditioner_list - ), self.assertLogs( - level="DEBUG", - ) as cm: + with ( + DequantizePreconditionersContext( + preconditioner_list=self._preconditioner_list + ), + self.assertLogs( + level="DEBUG", + ) as cm, + ): self._preconditioner_list.update_preconditioners( masked_grad_list=( torch.tensor([1.0, 0.0]), diff --git a/distributed_shampoo/utils/tests/shampoo_quantization_test.py b/distributed_shampoo/utils/tests/shampoo_quantization_test.py index 96a06b4..13979d9 100644 --- a/distributed_shampoo/utils/tests/shampoo_quantization_test.py +++ b/distributed_shampoo/utils/tests/shampoo_quantization_test.py @@ -108,14 +108,17 @@ def test_invalid_quantized_data_type(self) -> None: class QuantizedTensorListInitTest(unittest.TestCase): def test_invalid_quantized_data_type(self) -> None: - with mock.patch.object( - shampoo_quantization, - "isinstance", - side_effect=lambda object, classinfo: False, - ), self.assertRaisesRegex( - TypeError, - re.escape( - "quantized_data must be collections.abc.Sequence[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]] | collections.abc.Sequence[distributed_shampoo.utils.shampoo_quantization.QuantizedTensor] but get " + with ( + mock.patch.object( + shampoo_quantization, + "isinstance", + side_effect=lambda object, classinfo: False, + ), + self.assertRaisesRegex( + TypeError, + re.escape( + "quantized_data must be collections.abc.Sequence[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]] | collections.abc.Sequence[distributed_shampoo.utils.shampoo_quantization.QuantizedTensor] but get " + ), ), ): QuantizedTensorList( diff --git a/pyproject.toml b/pyproject.toml index 8603623..d574bc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ examples = [ dev = [ "torch-shampoo[examples]", + "pre-commit", "ruff", "usort", "mypy", diff --git a/tests/matrix_functions_test.py b/tests/matrix_functions_test.py index 4ea176d..1658201 100644 --- a/tests/matrix_functions_test.py +++ b/tests/matrix_functions_test.py @@ -231,23 +231,27 @@ def test_matrix_inverse_root_reach_max_iterations(self) -> None: implementation, msg, ) in root_inv_config_and_implementation_and_msg: - with mock.patch.object( - matrix_functions, - implementation, - return_value=( - None, - None, - NewtonConvergenceFlag.REACHED_MAX_ITERS, - None, - None, + with ( + mock.patch.object( + matrix_functions, + implementation, + return_value=( + None, + None, + NewtonConvergenceFlag.REACHED_MAX_ITERS, + None, + None, + ), ), - ), self.subTest( - root_inv_config=root_inv_config, - implementation=implementation, - msg=msg, - ), self.assertLogs( - level="WARNING", - ) as cm: + self.subTest( + root_inv_config=root_inv_config, + 
implementation=implementation, + msg=msg, + ), + self.assertLogs( + level="WARNING", + ) as cm, + ): matrix_inverse_root( A=A, root=root, @@ -891,14 +895,17 @@ def test_matrix_eigenvectors(self) -> None: def test_invalid_eigenvalue_correction_config( self, ) -> None: - with mock.patch.object( - matrix_functions, - "type", - side_effect=lambda object: EigenvalueCorrectionConfig, - ), self.assertRaisesRegex( - NotImplementedError, - re.escape( - "Eigenvector computation method is not implemented! Specified eigenvector method is eigenvector_computation_config=EighEigenvalueCorrectionConfig(retry_double_precision=True)." + with ( + mock.patch.object( + matrix_functions, + "type", + side_effect=lambda object: EigenvalueCorrectionConfig, + ), + self.assertRaisesRegex( + NotImplementedError, + re.escape( + "Eigenvector computation method is not implemented! Specified eigenvector method is eigenvector_computation_config=EighEigenvalueCorrectionConfig(retry_double_precision=True)." + ), ), ): matrix_eigenvectors(
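Most of the Python hunks above apply one mechanical refactor: chained context managers are regrouped using the parenthesized `with` syntax available in Python 3.10+, so each manager sits on its own line with a trailing comma. Below is a minimal, self-contained sketch of the before/after pattern (not part of the diff); it uses stand-in managers from `contextlib` and a hypothetical `grafting_enabled` flag instead of the Shampoo classes shown in the hunks.

```python
import contextlib

grafting_enabled = False  # hypothetical flag mirroring grafting_config_not_none

# Old style: managers chained on a single `with` header; long expressions can only
# be wrapped by breaking lines inside their own parentheses.
with contextlib.suppress(ValueError), (
    contextlib.suppress(KeyError) if grafting_enabled else contextlib.nullcontext()
):
    int("not a number")  # raises ValueError, which suppress() swallows

# New style (Python 3.10+): the whole group is wrapped in parentheses, so each
# manager gets its own line and a trailing comma, which is the form this diff adopts.
with (
    contextlib.suppress(ValueError),
    (
        contextlib.suppress(KeyError)
        if grafting_enabled
        else contextlib.nullcontext()
    ),
):
    int("not a number")  # raises ValueError, which suppress() swallows
```

The behavior is identical in both forms; only the grouping changes, which is also why the hunks above touch indentation and parentheses but not the underlying calls.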