Update formatting according to new ruff version #55
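The change is purely mechanical: every `with A(...), B(...):` statement that was previously formatted by breaking after a closing parenthesis is rewritten as a single parenthesized `with (...)` block, one context manager per line with a trailing comma, which is the layout the newer ruff formatter produces (parenthesized context managers are valid syntax on newer Python versions, officially documented from Python 3.10). Below is a minimal, self-contained sketch of the before/after shape; it uses hypothetical file names and is not code from this repository, and the exact line breaks chosen by the formatter depend on line length and configuration.

# Old style: context managers chained on the `with` line, wrapped mid-call.
# (Hypothetical illustration, not taken from this PR.)
with open("a.txt", "w") as fa, open(
    "b.txt", "w"
) as fb:
    fa.write("a")
    fb.write("b")

# New style targeted by the updated formatter: one parenthesized `with`,
# one context manager per line, each followed by a trailing comma.
with (
    open("a.txt") as fa,
    open("b.txt") as fb,
):
    print(fa.read(), fb.read())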

Closed · wants to merge 1 commit
17 changes: 10 additions & 7 deletions distributed_shampoo/distributed_shampoo.py
@@ -1044,14 +1044,17 @@ def _per_group_step_impl(
             use_decoupled_weight_decay,
         )

-        with DequantizePreconditionersContext(
-            preconditioner_list=state_lists[SHAMPOO_PRECONDITIONER_LIST]
-        ), (
+        with (
             DequantizePreconditionersContext(
-                preconditioner_list=state_lists[GRAFTING_PRECONDITIONER_LIST]
-            )
-            if grafting_config_not_none
-            else contextlib.nullcontext()
+                preconditioner_list=state_lists[SHAMPOO_PRECONDITIONER_LIST]
+            ),
+            (
+                DequantizePreconditionersContext(
+                    preconditioner_list=state_lists[GRAFTING_PRECONDITIONER_LIST]
+                )
+                if grafting_config_not_none
+                else contextlib.nullcontext()
+            ),
         ):
             # Update Shampoo and grafting preconditioners.
             # Example for AdaGrad accumulation:
97 changes: 55 additions & 42 deletions distributed_shampoo/tests/distributed_shampoo_test.py
@@ -47,12 +47,15 @@ def setUp(self) -> None:
         )

     def test_invalid_grafting_config(self) -> None:
-        with mock.patch.object(
-            distributed_shampoo, "type", side_effect=lambda object: GraftingConfig
-        ), self.assertRaisesRegex(
-            NotImplementedError,
-            re.escape(
-                "Unsupported grafting config: group[GRAFTING_CONFIG]=SGDGraftingConfig"
+        with (
+            mock.patch.object(
+                distributed_shampoo, "type", side_effect=lambda object: GraftingConfig
+            ),
+            self.assertRaisesRegex(
+                NotImplementedError,
+                re.escape(
+                    "Unsupported grafting config: group[GRAFTING_CONFIG]=SGDGraftingConfig"
+                ),
             ),
         ):
             DistributedShampoo(
@@ -126,22 +129,26 @@ def test_invalid_with_incorrect_hyperparameter_setting(self) -> None:
             incorrect_hyperparameter_setting,
             expected_error_msg,
         ) in incorrect_hyperparameter_setting_and_expected_error_msg:
-            with self.subTest(
-                incorrect_hyperparameter_setting=incorrect_hyperparameter_setting,
-                expected_error_msg=expected_error_msg,
-            ), self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)):
+            with (
+                self.subTest(
+                    incorrect_hyperparameter_setting=incorrect_hyperparameter_setting,
+                    expected_error_msg=expected_error_msg,
+                ),
+                self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)),
+            ):
                 DistributedShampoo(
                     self._model.parameters(),
                     **incorrect_hyperparameter_setting,
                 )

     def test_invalid_pytorch_compile_setting(self) -> None:
-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=False
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "Both use_pytorch_compile and shampoo_pt2_compile_config are provided. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=False),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "Both use_pytorch_compile and shampoo_pt2_compile_config are provided. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
+                ),
             ),
         ):
             DistributedShampoo(
@@ -150,12 +157,13 @@ def test_invalid_pytorch_compile_setting(self) -> None:
                 shampoo_pt2_compile_config=ShampooPT2CompileConfig(),
             )

-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=False
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "use_pytorch_compile=False conflicts with non-None shampoo_pt2_compile_config arg. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=False),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "use_pytorch_compile=False conflicts with non-None shampoo_pt2_compile_config arg. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
+                ),
             ),
         ):
             DistributedShampoo(
@@ -165,11 +173,12 @@ def test_invalid_pytorch_compile_setting(self) -> None:
             )

     def test_warning_pytorch_compile_setting(self) -> None:
-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=True
-        ), self.assertLogs(
-            level="WARNING",
-        ) as cm:
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=True),
+            self.assertLogs(
+                level="WARNING",
+            ) as cm,
+        ):
             DistributedShampoo(
                 self._model.parameters(),
                 lr=0.01,
@@ -183,12 +192,13 @@ def test_warning_pytorch_compile_setting(self) -> None:
             )

     def test_invalid_cuda_pytorch_compile_setting(self) -> None:
-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=False
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "Backend does NOT support Pytorch 2.0 compile. Switch to use_pytorch_compile in (False, None) and shampoo_pt2_compile_config=None."
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=False),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "Backend does NOT support Pytorch 2.0 compile. Switch to use_pytorch_compile in (False, None) and shampoo_pt2_compile_config=None."
+                ),
             ),
         ):
             DistributedShampoo(
@@ -213,16 +223,19 @@ def test_nesterov_and_zero_momentum(self) -> None:
         )

     def test_invalid_distributed_config(self) -> None:
-        with self.assertRaisesRegex(
-            NotImplementedError,
-            re.escape(
-                "distributed_config=DDPShampooConfig(communication_dtype=<CommunicationDType.DEFAULT: 0>, "
-                "num_trainers_per_group=-1, communicate_params=False) not supported!"
+        with (
+            self.assertRaisesRegex(
+                NotImplementedError,
+                re.escape(
+                    "distributed_config=DDPShampooConfig(communication_dtype=<CommunicationDType.DEFAULT: 0>, "
+                    "num_trainers_per_group=-1, communicate_params=False) not supported!"
+                ),
             ),
-        ), mock.patch.object(
-            distributed_shampoo,
-            "type",
-            side_effect=lambda object: DistributedConfig,
+            mock.patch.object(
+                distributed_shampoo,
+                "type",
+                side_effect=lambda object: DistributedConfig,
+            ),
         ):
             DistributedShampoo(
                 params=self._model.parameters(),
24 changes: 13 additions & 11 deletions distributed_shampoo/tests/shampoo_types_test.py
@@ -22,11 +22,12 @@ class AdaGradGraftingConfigTest(unittest.TestCase):
     def test_illegal_epsilon(self) -> None:
         epsilon = 0.0
         grafting_config_type = self._get_grafting_config_type()
-        with self.subTest(
-            grafting_config_type=grafting_config_type
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."),
+        with (
+            self.subTest(grafting_config_type=grafting_config_type),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."),
+            ),
         ):
             grafting_config_type(epsilon=epsilon)

@@ -46,12 +47,13 @@ def test_illegal_beta2(
     ) -> None:
         grafting_config_type = self._get_grafting_config_type()
         for beta2 in (-1.0, 0.0, 1.3):
-            with self.subTest(
-                grafting_config_type=grafting_config_type, beta2=beta2
-            ), self.assertRaisesRegex(
-                ValueError,
-                re.escape(
-                    f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]."
+            with (
+                self.subTest(grafting_config_type=grafting_config_type, beta2=beta2),
+                self.assertRaisesRegex(
+                    ValueError,
+                    re.escape(
+                        f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]."
+                    ),
                 ),
             ):
                 grafting_config_type(beta2=beta2)
@@ -313,11 +313,12 @@ def test_dist_is_initialized(self) -> None:
             device_mesh=mesh_2d,
         )

-        with mock.patch.object(
-            torch.distributed, "is_initialized", return_value=False
-        ), self.assertRaisesRegex(
-            RuntimeError,
-            re.escape("HSDPDistributor needs torch.distributed to be initialized!"),
+        with (
+            mock.patch.object(torch.distributed, "is_initialized", return_value=False),
+            self.assertRaisesRegex(
+                RuntimeError,
+                re.escape("HSDPDistributor needs torch.distributed to be initialized!"),
+            ),
         ):
             ShampooHSDPDistributorTest._train_model(
                 optim_factory=ShampooHSDPDistributorTest._shampoo_optim_factory(
@@ -339,12 +340,15 @@ def test_incompatible_replicated_group_size_and_num_trainers_per_group(
         )

         # Hijack the DeviceMesh.size() method to return 4 instead of 2 to bypass the check of num_trainers_per_group.
-        with mock.patch.object(
-            torch.distributed.device_mesh.DeviceMesh, "size", return_value=4
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!"
+        with (
+            mock.patch.object(
+                torch.distributed.device_mesh.DeviceMesh, "size", return_value=4
+            ),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!"
+                ),
             ),
         ):
             ShampooHSDPDistributorTest._train_model(
88 changes: 55 additions & 33 deletions distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py
@@ -125,13 +125,16 @@ def _test_compress_preconditioner_list(
         self,
         expected_compress_list_call_count: int,
     ) -> None:
-        with mock.patch.object(
-            shampoo_preconditioner_list,
-            "compress_list",
-        ) as mock_compress_list, mock.patch.object(
-            QuantizedTensorList,
-            "compress",
-        ) as mock_compress_quant_list:
+        with (
+            mock.patch.object(
+                shampoo_preconditioner_list,
+                "compress_list",
+            ) as mock_compress_list,
+            mock.patch.object(
+                QuantizedTensorList,
+                "compress",
+            ) as mock_compress_quant_list,
+        ):
             # Count the number of list compressions at the preconditioner list level, including compressions of QuantizedTensorList.
             # Each call to compress() under QuantizedTensorList counts once, though note that it calls compress_list() three times inside.
             self.assertIsNone(
@@ -328,11 +331,16 @@ def _instantiate_preconditioner_list( # type: ignore[override]
     def _test_raise_invalid_value_in_factor_matrix(
         self, invalid_value: float
     ) -> None:
-        with DequantizePreconditionersContext(
-            preconditioner_list=self._preconditioner_list
-        ), self.assertRaisesRegex(
-            PreconditionerValueError,
-            re.escape(f"Encountered {str(invalid_value)} values in factor matrix"),
+        with (
+            DequantizePreconditionersContext(
+                preconditioner_list=self._preconditioner_list
+            ),
+            self.assertRaisesRegex(
+                PreconditionerValueError,
+                re.escape(
+                    f"Encountered {str(invalid_value)} values in factor matrix"
+                ),
+            ),
         ):
             self._preconditioner_list.update_preconditioners(
                 masked_grad_list=(
@@ -356,16 +364,21 @@ def test_raise_nan_and_inf_in_inv_factor_matrix_amortized_computation(
         self,
     ) -> None:
         for invalid_value in (torch.nan, torch.inf):
-            with DequantizePreconditionersContext(
-                preconditioner_list=self._preconditioner_list
-            ), self.subTest(invalid_value=invalid_value), self.assertRaisesRegex(
-                PreconditionerValueError,
-                re.escape("Encountered nan or inf values in"),
-            ), mock.patch.object(
-                shampoo_preconditioner_list,
-                self._amortized_computation_function(),
-                side_effect=(torch.tensor([invalid_value]),),
-            ) as mock_amortized_computation:
+            with (
+                DequantizePreconditionersContext(
+                    preconditioner_list=self._preconditioner_list
+                ),
+                self.subTest(invalid_value=invalid_value),
+                self.assertRaisesRegex(
+                    PreconditionerValueError,
+                    re.escape("Encountered nan or inf values in"),
+                ),
+                mock.patch.object(
+                    shampoo_preconditioner_list,
+                    self._amortized_computation_function(),
+                    side_effect=(torch.tensor([invalid_value]),),
+                ) as mock_amortized_computation,
+            ):
                 self._preconditioner_list.update_preconditioners(
                     masked_grad_list=(
                         torch.tensor([1.0, 0.0]),
@@ -384,9 +397,12 @@ def test_amortized_computation_internal_failure(self) -> None:
             # Simulate the situation throws an exception (not nan and inf) to test the warning
             side_effect=ZeroDivisionError,
         ) as mock_amortized_computation:
-            with DequantizePreconditionersContext(
-                preconditioner_list=self._preconditioner_list
-            ), self.assertLogs(level="WARNING") as cm:
+            with (
+                DequantizePreconditionersContext(
+                    preconditioner_list=self._preconditioner_list
+                ),
+                self.assertLogs(level="WARNING") as cm,
+            ):
                 # Because use_protected_eigh is True, we expect the warning to be logged.
                 self._preconditioner_list.update_preconditioners(
                     masked_grad_list=(
@@ -415,9 +431,12 @@ def test_amortized_computation_internal_failure(self) -> None:
         self._preconditioner_list = self._instantiate_preconditioner_list(
             use_protected_eigh=False,
         )
-        with DequantizePreconditionersContext(
-            preconditioner_list=self._preconditioner_list
-        ), self.assertRaises(ZeroDivisionError):
+        with (
+            DequantizePreconditionersContext(
+                preconditioner_list=self._preconditioner_list
+            ),
+            self.assertRaises(ZeroDivisionError),
+        ):
             self._preconditioner_list.update_preconditioners(
                 masked_grad_list=(
                     torch.tensor([1.0, 0.0]),
@@ -443,11 +462,14 @@ def test_amortized_computation_factor_matrix_non_diagonal(
         self._preconditioner_list = self._instantiate_preconditioner_list(
             epsilon=1.0
         )
-        with DequantizePreconditionersContext(
-            preconditioner_list=self._preconditioner_list
-        ), self.assertLogs(
-            level="DEBUG",
-        ) as cm:
+        with (
+            DequantizePreconditionersContext(
+                preconditioner_list=self._preconditioner_list
+            ),
+            self.assertLogs(
+                level="DEBUG",
+            ) as cm,
+        ):
             self._preconditioner_list.update_preconditioners(
                 masked_grad_list=(
                     torch.tensor([1.0, 0.0]),