Add pre-commit hook for linting and formatting (#56)

Summary:
Also, clean up the workflows by merging the formatting workflows into one and simplifying the naming to be tool-agnostic.

Pull Request resolved: #56

Reviewed By: anana10c

Differential Revision: D66510666

Pulled By: tsunghsienlee

fbshipit-source-id: 33386ebdc703ed124805aa4d51c3d746ff4d85e9
runame authored and facebook-github-bot committed Nov 26, 2024
1 parent c373c71 commit 14146a1
Showing 15 changed files with 222 additions and 157 deletions.
12 changes: 0 additions & 12 deletions .github/workflows/format-ruff.yaml

This file was deleted.

@@ -1,4 +1,4 @@
-name: format-usort
+name: format

on: [push, pull_request]

@@ -18,3 +18,10 @@ jobs:
      - name: Run usort check.
        run: |
          usort check .
+  ruff:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: astral-sh/ruff-action@v1
+        with:
+          args: "format --check"
@@ -1,4 +1,4 @@
-name: lint-ruff
+name: lint

on: [push, pull_request]

@@ -1,4 +1,4 @@
-name: type-check-mypy
+name: type-check

on: [push, pull_request]

16 changes: 16 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,16 @@
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.8.0
+    hooks:
+      # Run the linter.
+      - id: ruff
+        types_or: [ python, pyi ]
+        args: [ --fix ]
+      # Run the formatter.
+      - id: ruff-format
+        types_or: [ python, pyi ]
+  - repo: https://github.com/facebook/usort
+    rev: v1.0.8
+    hooks:
+      - id: usort
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
@@ -8,11 +8,11 @@ optimizers should be created in a separate public repo.
## Pull Requests
We actively welcome your pull requests for existing optimizers.

-1. Fork the repo and create your branch from `main`. Install the package inside of your Python environment with `pip install -e ".[dev]"`.
+1. Fork the repo and create your branch from `main`. Install the package inside of your Python environment with `pip install -e ".[dev]"`. Run `pre-commit install` to set up the git hook scripts.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes. To run the subset of the tests that can be run on CPU use `make test`; to run the tests for a single GPU use `make test-gpu` and to run the subset of tests that require 2-4 GPUs use `make test-multi-gpu`.
-5. Make sure your code lints. You can use `make lint` and `make format` to automatically lint and format the code where possible. Use `make type-check` for type checking.
+5. Make sure your code lints. You can use `make lint` and `make format` to automatically lint and format the code where possible (these also run automatically on `git commit` if the pre-commit hook was installed; however, this does not guarantee that no linting errors remain). Use `make type-check` for type checking.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

> [!NOTE]
7 changes: 3 additions & 4 deletions README.md
@@ -4,10 +4,9 @@
3.10 | 3.11 | 3.12](https://img.shields.io/badge/python-3.10_|_3.11_|_3.12-blue.svg)](https://www.python.org/downloads/)
![tests](https://github.com/facebookresearch/optimizers/actions/workflows/tests.yaml/badge.svg)
![gpu-tests](https://github.com/facebookresearch/optimizers/actions/workflows/gpu-tests.yaml/badge.svg)
-![lint-ruff](https://github.com/facebookresearch/optimizers/actions/workflows/lint-ruff.yaml/badge.svg)
-![format-ruff](https://github.com/facebookresearch/optimizers/actions/workflows/format-ruff.yaml/badge.svg)
-![format-usort](https://github.com/facebookresearch/optimizers/actions/workflows/format-usort.yaml/badge.svg)
-![type-check-mypy](https://github.com/facebookresearch/optimizers/actions/workflows/type-check-mypy.yaml/badge.svg)
+![linting](https://github.com/facebookresearch/optimizers/actions/workflows/lint.yaml/badge.svg)
+![formatting](https://github.com/facebookresearch/optimizers/actions/workflows/format.yaml/badge.svg)
+![type-checking](https://github.com/facebookresearch/optimizers/actions/workflows/type-check.yaml/badge.svg)
![examples](https://github.com/facebookresearch/optimizers/actions/workflows/examples.yaml/badge.svg)

*Copyright (c) Meta Platforms, Inc. and affiliates.
17 changes: 10 additions & 7 deletions distributed_shampoo/distributed_shampoo.py
@@ -1044,14 +1044,17 @@ def _per_group_step_impl(
use_decoupled_weight_decay,
)

-        with DequantizePreconditionersContext(
-            preconditioner_list=state_lists[SHAMPOO_PRECONDITIONER_LIST]
-        ), (
+        with (
            DequantizePreconditionersContext(
-                preconditioner_list=state_lists[GRAFTING_PRECONDITIONER_LIST]
-            )
-            if grafting_config_not_none
-            else contextlib.nullcontext()
+                preconditioner_list=state_lists[SHAMPOO_PRECONDITIONER_LIST]
+            ),
+            (
+                DequantizePreconditionersContext(
+                    preconditioner_list=state_lists[GRAFTING_PRECONDITIONER_LIST]
+                )
+                if grafting_config_not_none
+                else contextlib.nullcontext()
+            ),
        ):
# Update Shampoo and grafting preconditioners.
# Example for AdaGrad accumulation:
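
For context, the hunk above (and the test updates below) only regroup stacked context managers using the parenthesized `with` syntax available since Python 3.10, which is the layout ruff-format emits for long multi-manager `with` statements; runtime behavior is unchanged. A minimal, self-contained sketch of the before/after pattern, using a hypothetical `dequantize` helper as a stand-in for `DequantizePreconditionersContext`:

```python
import contextlib


@contextlib.contextmanager
def dequantize(name: str):
    # Hypothetical stand-in for DequantizePreconditionersContext: it just
    # reports when the block is entered and exited.
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")


grafting_config_not_none = True

# Old layout: context managers chained directly after `with`.
with dequantize("shampoo"), (
    dequantize("grafting") if grafting_config_not_none else contextlib.nullcontext()
):
    pass

# New layout (Python 3.10+): the same managers grouped inside parentheses,
# which is how ruff-format wraps long multi-manager `with` statements.
with (
    dequantize("shampoo"),
    (
        dequantize("grafting")
        if grafting_config_not_none
        else contextlib.nullcontext()
    ),
):
    pass
```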
97 changes: 55 additions & 42 deletions distributed_shampoo/tests/distributed_shampoo_test.py
@@ -47,12 +47,15 @@ def setUp(self) -> None:
)

def test_invalid_grafting_config(self) -> None:
-        with mock.patch.object(
-            distributed_shampoo, "type", side_effect=lambda object: GraftingConfig
-        ), self.assertRaisesRegex(
-            NotImplementedError,
-            re.escape(
-                "Unsupported grafting config: group[GRAFTING_CONFIG]=SGDGraftingConfig"
+        with (
+            mock.patch.object(
+                distributed_shampoo, "type", side_effect=lambda object: GraftingConfig
+            ),
+            self.assertRaisesRegex(
+                NotImplementedError,
+                re.escape(
+                    "Unsupported grafting config: group[GRAFTING_CONFIG]=SGDGraftingConfig"
+                ),
            ),
):
DistributedShampoo(
@@ -126,22 +129,26 @@ def test_invalid_with_incorrect_hyperparameter_setting(self) -> None:
incorrect_hyperparameter_setting,
expected_error_msg,
) in incorrect_hyperparameter_setting_and_expected_error_msg:
-            with self.subTest(
-                incorrect_hyperparameter_setting=incorrect_hyperparameter_setting,
-                expected_error_msg=expected_error_msg,
-            ), self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)):
+            with (
+                self.subTest(
+                    incorrect_hyperparameter_setting=incorrect_hyperparameter_setting,
+                    expected_error_msg=expected_error_msg,
+                ),
+                self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)),
+            ):
DistributedShampoo(
self._model.parameters(),
**incorrect_hyperparameter_setting,
)

def test_invalid_pytorch_compile_setting(self) -> None:
-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=False
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "Both use_pytorch_compile and shampoo_pt2_compile_config are provided. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=False),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "Both use_pytorch_compile and shampoo_pt2_compile_config are provided. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
+                ),
),
):
DistributedShampoo(
@@ -150,12 +157,13 @@ def test_invalid_pytorch_compile_setting(self) -> None:
shampoo_pt2_compile_config=ShampooPT2CompileConfig(),
)

-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=False
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "use_pytorch_compile=False conflicts with non-None shampoo_pt2_compile_config arg. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=False),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "use_pytorch_compile=False conflicts with non-None shampoo_pt2_compile_config arg. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
+                ),
),
):
DistributedShampoo(
@@ -165,11 +173,12 @@ def test_warning_pytorch_compile_setting(self) -> None:
)

def test_warning_pytorch_compile_setting(self) -> None:
-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=True
-        ), self.assertLogs(
-            level="WARNING",
-        ) as cm:
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=True),
+            self.assertLogs(
+                level="WARNING",
+            ) as cm,
+        ):
DistributedShampoo(
self._model.parameters(),
lr=0.01,
@@ -183,12 +192,13 @@ def test_invalid_cuda_pytorch_compile_setting(self) -> None:
)

def test_invalid_cuda_pytorch_compile_setting(self) -> None:
-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=False
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "Backend does NOT support Pytorch 2.0 compile. Switch to use_pytorch_compile in (False, None) and shampoo_pt2_compile_config=None."
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=False),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "Backend does NOT support Pytorch 2.0 compile. Switch to use_pytorch_compile in (False, None) and shampoo_pt2_compile_config=None."
+                ),
),
):
DistributedShampoo(
@@ -213,16 +223,19 @@ def test_nesterov_and_zero_momentum(self) -> None:
)

def test_invalid_distributed_config(self) -> None:
-        with self.assertRaisesRegex(
-            NotImplementedError,
-            re.escape(
-                "distributed_config=DDPShampooConfig(communication_dtype=<CommunicationDType.DEFAULT: 0>, "
-                "num_trainers_per_group=-1, communicate_params=False) not supported!"
+        with (
+            self.assertRaisesRegex(
+                NotImplementedError,
+                re.escape(
+                    "distributed_config=DDPShampooConfig(communication_dtype=<CommunicationDType.DEFAULT: 0>, "
+                    "num_trainers_per_group=-1, communicate_params=False) not supported!"
+                ),
            ),
+            mock.patch.object(
+                distributed_shampoo,
+                "type",
+                side_effect=lambda object: DistributedConfig,
+            ),
-        ), mock.patch.object(
-            distributed_shampoo,
-            "type",
-            side_effect=lambda object: DistributedConfig,
):
DistributedShampoo(
params=self._model.parameters(),
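
The test-file changes in this commit follow the same mechanical pattern: each `with a, b:` pairing of `mock.patch.object` with `assertRaisesRegex`, `assertLogs`, or `subTest` is regrouped into a single parenthesized `with`. A short, self-contained sketch of the resulting idiom, assuming Python 3.10+ and only the standard library (the `_Backend` class is hypothetical, not from this repository):

```python
import unittest
from unittest import mock


class _Backend:
    # Hypothetical helper whose method gets patched in the test below.
    def is_available(self) -> bool:
        return True


class ParenthesizedWithExampleTest(unittest.TestCase):
    def test_patched_backend_raises(self) -> None:
        backend = _Backend()
        # Both context managers live in a single parenthesized `with`,
        # matching the layout used throughout the updated tests.
        with (
            mock.patch.object(_Backend, "is_available", return_value=False),
            self.assertRaisesRegex(RuntimeError, "backend not available"),
        ):
            if not backend.is_available():
                raise RuntimeError("backend not available")


if __name__ == "__main__":
    unittest.main()
```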
24 changes: 13 additions & 11 deletions distributed_shampoo/tests/shampoo_types_test.py
@@ -22,11 +22,12 @@ class AdaGradGraftingConfigTest(unittest.TestCase):
def test_illegal_epsilon(self) -> None:
epsilon = 0.0
grafting_config_type = self._get_grafting_config_type()
-        with self.subTest(
-            grafting_config_type=grafting_config_type
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."),
+        with (
+            self.subTest(grafting_config_type=grafting_config_type),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."),
+            ),
):
grafting_config_type(epsilon=epsilon)

@@ -46,12 +47,13 @@ def test_illegal_beta2(
) -> None:
grafting_config_type = self._get_grafting_config_type()
for beta2 in (-1.0, 0.0, 1.3):
-            with self.subTest(
-                grafting_config_type=grafting_config_type, beta2=beta2
-            ), self.assertRaisesRegex(
-                ValueError,
-                re.escape(
-                    f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]."
+            with (
+                self.subTest(grafting_config_type=grafting_config_type, beta2=beta2),
+                self.assertRaisesRegex(
+                    ValueError,
+                    re.escape(
+                        f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]."
+                    ),
),
):
grafting_config_type(beta2=beta2)
@@ -313,11 +313,12 @@ def test_dist_is_initialized(self) -> None:
device_mesh=mesh_2d,
)

-        with mock.patch.object(
-            torch.distributed, "is_initialized", return_value=False
-        ), self.assertRaisesRegex(
-            RuntimeError,
-            re.escape("HSDPDistributor needs torch.distributed to be initialized!"),
+        with (
+            mock.patch.object(torch.distributed, "is_initialized", return_value=False),
+            self.assertRaisesRegex(
+                RuntimeError,
+                re.escape("HSDPDistributor needs torch.distributed to be initialized!"),
+            ),
):
ShampooHSDPDistributorTest._train_model(
optim_factory=ShampooHSDPDistributorTest._shampoo_optim_factory(
@@ -339,12 +340,15 @@ def test_incompatible_replicated_group_size_and_num_trainers_per_group(
)

# Hijack the DeviceMesh.size() method to return 4 instead of 2 to bypass the check of num_trainers_per_group.
-        with mock.patch.object(
-            torch.distributed.device_mesh.DeviceMesh, "size", return_value=4
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!"
+        with (
+            mock.patch.object(
+                torch.distributed.device_mesh.DeviceMesh, "size", return_value=4
+            ),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!"
+                ),
),
):
ShampooHSDPDistributorTest._train_model(