Commit 9322df0: Update formatting according to new ruff version

runame committed Nov 25, 2024
1 parent c373c71
Showing 7 changed files with 190 additions and 136 deletions.
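Every hunk in this commit applies the same mechanical change: context managers that were previously chained with commas directly after the with keyword are now wrapped together in one parenthesized group, the style recent ruff formatter releases emit when the configured target Python version supports parenthesized with-statements (part of the Python 3.10 grammar). A minimal, self-contained sketch of the before/after shape, using contextlib.nullcontext stand-ins rather than this repository's own context managers:

    import contextlib

    # Old style: context managers joined by commas directly after "with"; long
    # calls force the awkward wrapping seen in the removed lines of this diff.
    with contextlib.nullcontext("first") as a, contextlib.nullcontext(
        "second"
    ) as b:
        print(a, b)

    # New style: the whole group of context managers is wrapped in parentheses,
    # one per line, each followed by a trailing comma.
    with (
        contextlib.nullcontext("first") as a,
        contextlib.nullcontext("second") as b,
    ):
        print(a, b)

The reformatting only changes how the with-statements are written; the set of context managers entered, and their order, is identical.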
17 changes: 10 additions & 7 deletions distributed_shampoo/distributed_shampoo.py
@@ -1044,14 +1044,17 @@ def _per_group_step_impl(
             use_decoupled_weight_decay,
         )
 
-        with DequantizePreconditionersContext(
-            preconditioner_list=state_lists[SHAMPOO_PRECONDITIONER_LIST]
-        ), (
-            DequantizePreconditionersContext(
-                preconditioner_list=state_lists[GRAFTING_PRECONDITIONER_LIST]
-            )
-            if grafting_config_not_none
-            else contextlib.nullcontext()
-        ):
+        with (
+            DequantizePreconditionersContext(
+                preconditioner_list=state_lists[SHAMPOO_PRECONDITIONER_LIST]
+            ),
+            (
+                DequantizePreconditionersContext(
+                    preconditioner_list=state_lists[GRAFTING_PRECONDITIONER_LIST]
+                )
+                if grafting_config_not_none
+                else contextlib.nullcontext()
+            ),
+        ):
             # Update Shampoo and grafting preconditioners.
             # Example for AdaGrad accumulation:
97 changes: 55 additions & 42 deletions distributed_shampoo/tests/distributed_shampoo_test.py
@@ -47,12 +47,15 @@ def setUp(self) -> None:
         )
 
     def test_invalid_grafting_config(self) -> None:
-        with mock.patch.object(
-            distributed_shampoo, "type", side_effect=lambda object: GraftingConfig
-        ), self.assertRaisesRegex(
-            NotImplementedError,
-            re.escape(
-                "Unsupported grafting config: group[GRAFTING_CONFIG]=SGDGraftingConfig"
-            ),
-        ):
+        with (
+            mock.patch.object(
+                distributed_shampoo, "type", side_effect=lambda object: GraftingConfig
+            ),
+            self.assertRaisesRegex(
+                NotImplementedError,
+                re.escape(
+                    "Unsupported grafting config: group[GRAFTING_CONFIG]=SGDGraftingConfig"
+                ),
+            ),
+        ):
             DistributedShampoo(

@@ -126,22 +129,26 @@ def test_invalid_with_incorrect_hyperparameter_setting(self) -> None:
             incorrect_hyperparameter_setting,
             expected_error_msg,
         ) in incorrect_hyperparameter_setting_and_expected_error_msg:
-            with self.subTest(
-                incorrect_hyperparameter_setting=incorrect_hyperparameter_setting,
-                expected_error_msg=expected_error_msg,
-            ), self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)):
+            with (
+                self.subTest(
+                    incorrect_hyperparameter_setting=incorrect_hyperparameter_setting,
+                    expected_error_msg=expected_error_msg,
+                ),
+                self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)),
+            ):
                 DistributedShampoo(
                     self._model.parameters(),
                     **incorrect_hyperparameter_setting,
                 )
 
     def test_invalid_pytorch_compile_setting(self) -> None:
-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=False
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "Both use_pytorch_compile and shampoo_pt2_compile_config are provided. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
-            ),
-        ):
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=False),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "Both use_pytorch_compile and shampoo_pt2_compile_config are provided. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
+                ),
+            ),
+        ):
             DistributedShampoo(

@@ -150,12 +157,13 @@ def test_invalid_pytorch_compile_setting(self) -> None:
                 shampoo_pt2_compile_config=ShampooPT2CompileConfig(),
             )
 
-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=False
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "use_pytorch_compile=False conflicts with non-None shampoo_pt2_compile_config arg. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
-            ),
-        ):
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=False),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "use_pytorch_compile=False conflicts with non-None shampoo_pt2_compile_config arg. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
+                ),
+            ),
+        ):
             DistributedShampoo(

@@ -165,11 +173,12 @@ def test_invalid_pytorch_compile_setting(self) -> None:
             )
 
     def test_warning_pytorch_compile_setting(self) -> None:
-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=True
-        ), self.assertLogs(
-            level="WARNING",
-        ) as cm:
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=True),
+            self.assertLogs(
+                level="WARNING",
+            ) as cm,
+        ):
             DistributedShampoo(
                 self._model.parameters(),
                 lr=0.01,

@@ -183,12 +192,13 @@ def test_warning_pytorch_compile_setting(self) -> None:
             )
 
     def test_invalid_cuda_pytorch_compile_setting(self) -> None:
-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=False
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "Backend does NOT support Pytorch 2.0 compile. Switch to use_pytorch_compile in (False, None) and shampoo_pt2_compile_config=None."
-            ),
-        ):
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=False),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "Backend does NOT support Pytorch 2.0 compile. Switch to use_pytorch_compile in (False, None) and shampoo_pt2_compile_config=None."
+                ),
+            ),
+        ):
             DistributedShampoo(

@@ -213,16 +223,19 @@ def test_nesterov_and_zero_momentum(self) -> None:
             )
 
     def test_invalid_distributed_config(self) -> None:
-        with self.assertRaisesRegex(
-            NotImplementedError,
-            re.escape(
-                "distributed_config=DDPShampooConfig(communication_dtype=<CommunicationDType.DEFAULT: 0>, "
-                "num_trainers_per_group=-1, communicate_params=False) not supported!"
-            ),
-        ), mock.patch.object(
-            distributed_shampoo,
-            "type",
-            side_effect=lambda object: DistributedConfig,
-        ):
+        with (
+            self.assertRaisesRegex(
+                NotImplementedError,
+                re.escape(
+                    "distributed_config=DDPShampooConfig(communication_dtype=<CommunicationDType.DEFAULT: 0>, "
+                    "num_trainers_per_group=-1, communicate_params=False) not supported!"
+                ),
+            ),
+            mock.patch.object(
+                distributed_shampoo,
+                "type",
+                side_effect=lambda object: DistributedConfig,
+            ),
+        ):
             DistributedShampoo(
                 params=self._model.parameters(),
24 changes: 13 additions & 11 deletions distributed_shampoo/tests/shampoo_types_test.py
@@ -22,11 +22,12 @@ class AdaGradGraftingConfigTest(unittest.TestCase):
     def test_illegal_epsilon(self) -> None:
         epsilon = 0.0
         grafting_config_type = self._get_grafting_config_type()
-        with self.subTest(
-            grafting_config_type=grafting_config_type
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."),
-        ):
+        with (
+            self.subTest(grafting_config_type=grafting_config_type),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."),
+            ),
+        ):
             grafting_config_type(epsilon=epsilon)
 
@@ -46,12 +47,13 @@ def test_illegal_beta2(
     ) -> None:
         grafting_config_type = self._get_grafting_config_type()
         for beta2 in (-1.0, 0.0, 1.3):
-            with self.subTest(
-                grafting_config_type=grafting_config_type, beta2=beta2
-            ), self.assertRaisesRegex(
-                ValueError,
-                re.escape(
-                    f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]."
-                ),
-            ):
+            with (
+                self.subTest(grafting_config_type=grafting_config_type, beta2=beta2),
+                self.assertRaisesRegex(
+                    ValueError,
+                    re.escape(
+                        f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]."
+                    ),
+                ),
+            ):
                 grafting_config_type(beta2=beta2)
[File header not captured in this view; the hunks below modify the Shampoo HSDP distributor test.]
@@ -313,11 +313,12 @@ def test_dist_is_initialized(self) -> None:
             device_mesh=mesh_2d,
         )
 
-        with mock.patch.object(
-            torch.distributed, "is_initialized", return_value=False
-        ), self.assertRaisesRegex(
-            RuntimeError,
-            re.escape("HSDPDistributor needs torch.distributed to be initialized!"),
-        ):
+        with (
+            mock.patch.object(torch.distributed, "is_initialized", return_value=False),
+            self.assertRaisesRegex(
+                RuntimeError,
+                re.escape("HSDPDistributor needs torch.distributed to be initialized!"),
+            ),
+        ):
             ShampooHSDPDistributorTest._train_model(
                 optim_factory=ShampooHSDPDistributorTest._shampoo_optim_factory(

@@ -339,12 +340,15 @@ def test_incompatible_replicated_group_size_and_num_trainers_per_group(
         )
 
         # Hijack the DeviceMesh.size() method to return 4 instead of 2 to bypass the check of num_trainers_per_group.
-        with mock.patch.object(
-            torch.distributed.device_mesh.DeviceMesh, "size", return_value=4
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!"
-            ),
-        ):
+        with (
+            mock.patch.object(
+                torch.distributed.device_mesh.DeviceMesh, "size", return_value=4
+            ),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!"
+                ),
+            ),
+        ):
             ShampooHSDPDistributorTest._train_model(
88 changes: 55 additions & 33 deletions distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py
@@ -125,13 +125,16 @@ def _test_compress_preconditioner_list(
         self,
         expected_compress_list_call_count: int,
     ) -> None:
-        with mock.patch.object(
-            shampoo_preconditioner_list,
-            "compress_list",
-        ) as mock_compress_list, mock.patch.object(
-            QuantizedTensorList,
-            "compress",
-        ) as mock_compress_quant_list:
+        with (
+            mock.patch.object(
+                shampoo_preconditioner_list,
+                "compress_list",
+            ) as mock_compress_list,
+            mock.patch.object(
+                QuantizedTensorList,
+                "compress",
+            ) as mock_compress_quant_list,
+        ):
             # Count the number of list compressions at the preconditioner list level, including compressions of QuantizedTensorList.
             # Each call to compress() under QuantizedTensorList counts once, though note that it calls compress_list() three times inside.
             self.assertIsNone(

@@ -328,11 +331,16 @@ def _instantiate_preconditioner_list( # type: ignore[override]
     def _test_raise_invalid_value_in_factor_matrix(
         self, invalid_value: float
     ) -> None:
-        with DequantizePreconditionersContext(
-            preconditioner_list=self._preconditioner_list
-        ), self.assertRaisesRegex(
-            PreconditionerValueError,
-            re.escape(f"Encountered {str(invalid_value)} values in factor matrix"),
-        ):
+        with (
+            DequantizePreconditionersContext(
+                preconditioner_list=self._preconditioner_list
+            ),
+            self.assertRaisesRegex(
+                PreconditionerValueError,
+                re.escape(
+                    f"Encountered {str(invalid_value)} values in factor matrix"
+                ),
+            ),
+        ):
             self._preconditioner_list.update_preconditioners(
                 masked_grad_list=(

@@ -356,16 +364,21 @@ def test_raise_nan_and_inf_in_inv_factor_matrix_amortized_computation(
         self,
     ) -> None:
         for invalid_value in (torch.nan, torch.inf):
-            with DequantizePreconditionersContext(
-                preconditioner_list=self._preconditioner_list
-            ), self.subTest(invalid_value=invalid_value), self.assertRaisesRegex(
-                PreconditionerValueError,
-                re.escape("Encountered nan or inf values in"),
-            ), mock.patch.object(
-                shampoo_preconditioner_list,
-                self._amortized_computation_function(),
-                side_effect=(torch.tensor([invalid_value]),),
-            ) as mock_amortized_computation:
+            with (
+                DequantizePreconditionersContext(
+                    preconditioner_list=self._preconditioner_list
+                ),
+                self.subTest(invalid_value=invalid_value),
+                self.assertRaisesRegex(
+                    PreconditionerValueError,
+                    re.escape("Encountered nan or inf values in"),
+                ),
+                mock.patch.object(
+                    shampoo_preconditioner_list,
+                    self._amortized_computation_function(),
+                    side_effect=(torch.tensor([invalid_value]),),
+                ) as mock_amortized_computation,
+            ):
                 self._preconditioner_list.update_preconditioners(
                     masked_grad_list=(
                         torch.tensor([1.0, 0.0]),

@@ -384,9 +397,12 @@ def test_amortized_computation_internal_failure(self) -> None:
             # Simulate the situation throws an exception (not nan and inf) to test the warning
             side_effect=ZeroDivisionError,
         ) as mock_amortized_computation:
-            with DequantizePreconditionersContext(
-                preconditioner_list=self._preconditioner_list
-            ), self.assertLogs(level="WARNING") as cm:
+            with (
+                DequantizePreconditionersContext(
+                    preconditioner_list=self._preconditioner_list
+                ),
+                self.assertLogs(level="WARNING") as cm,
+            ):
                 # Because use_protected_eigh is True, we expect the warning to be logged.
                 self._preconditioner_list.update_preconditioners(
                     masked_grad_list=(

@@ -415,9 +431,12 @@ def test_amortized_computation_internal_failure(self) -> None:
             self._preconditioner_list = self._instantiate_preconditioner_list(
                 use_protected_eigh=False,
            )
-            with DequantizePreconditionersContext(
-                preconditioner_list=self._preconditioner_list
-            ), self.assertRaises(ZeroDivisionError):
+            with (
+                DequantizePreconditionersContext(
+                    preconditioner_list=self._preconditioner_list
+                ),
+                self.assertRaises(ZeroDivisionError),
+            ):
                 self._preconditioner_list.update_preconditioners(
                     masked_grad_list=(
                         torch.tensor([1.0, 0.0]),

@@ -443,11 +462,14 @@ def test_amortized_computation_factor_matrix_non_diagonal(
         self._preconditioner_list = self._instantiate_preconditioner_list(
             epsilon=1.0
         )
-        with DequantizePreconditionersContext(
-            preconditioner_list=self._preconditioner_list
-        ), self.assertLogs(
-            level="DEBUG",
-        ) as cm:
+        with (
+            DequantizePreconditionersContext(
+                preconditioner_list=self._preconditioner_list
+            ),
+            self.assertLogs(
+                level="DEBUG",
+            ) as cm,
+        ):
             self._preconditioner_list.update_preconditioners(
                 masked_grad_list=(
                     torch.tensor([1.0, 0.0]),
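As a usage sketch (hypothetical, not part of this commit): the combination of mock.patch.object and assertRaisesRegex inside one parenthesized with-statement, as used throughout the updated tests above, looks like this in a standalone unittest case. The patched target here (re.escape) is arbitrary and chosen only so the example runs on its own:

    import re
    import unittest
    from unittest import mock


    class ParenthesizedWithExampleTest(unittest.TestCase):
        def test_mock_and_assert_raises_in_one_with(self) -> None:
            # Patch re.escape so that calling it raises, then assert on the error;
            # both context managers are entered by a single parenthesized with.
            with (
                mock.patch.object(re, "escape", side_effect=ValueError("boom")),
                self.assertRaisesRegex(ValueError, "boom"),
            ):
                re.escape("anything")


    if __name__ == "__main__":
        unittest.main()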