From 9322df0d9db8985d99f43f6fd67835ed5b4b27d3 Mon Sep 17 00:00:00 2001
From: runame
Date: Mon, 25 Nov 2024 18:28:19 +0000
Subject: [PATCH] Update formatting according to new ruff version

---
 distributed_shampoo/distributed_shampoo.py    | 17 ++--
 .../tests/distributed_shampoo_test.py         | 97 +++++++++++--------
 .../tests/shampoo_types_test.py               | 24 ++---
 .../shampoo_hsdp_distributor_test.py          | 26 ++---
 .../tests/shampoo_preconditioner_list_test.py | 88 ++++++++++-------
 .../utils/tests/shampoo_quantization_test.py  | 19 ++--
 tests/matrix_functions_test.py                | 55 ++++++-----
 7 files changed, 190 insertions(+), 136 deletions(-)

diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py
index e0b5337..d222b1f 100644
--- a/distributed_shampoo/distributed_shampoo.py
+++ b/distributed_shampoo/distributed_shampoo.py
@@ -1044,14 +1044,17 @@ def _per_group_step_impl(
             use_decoupled_weight_decay,
         )

-        with DequantizePreconditionersContext(
-            preconditioner_list=state_lists[SHAMPOO_PRECONDITIONER_LIST]
-        ), (
+        with (
             DequantizePreconditionersContext(
-                preconditioner_list=state_lists[GRAFTING_PRECONDITIONER_LIST]
-            )
-            if grafting_config_not_none
-            else contextlib.nullcontext()
+                preconditioner_list=state_lists[SHAMPOO_PRECONDITIONER_LIST]
+            ),
+            (
+                DequantizePreconditionersContext(
+                    preconditioner_list=state_lists[GRAFTING_PRECONDITIONER_LIST]
+                )
+                if grafting_config_not_none
+                else contextlib.nullcontext()
+            ),
         ):
             # Update Shampoo and grafting preconditioners.
             # Example for AdaGrad accumulation:
diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py
index c191f40..396db52 100644
--- a/distributed_shampoo/tests/distributed_shampoo_test.py
+++ b/distributed_shampoo/tests/distributed_shampoo_test.py
@@ -47,12 +47,15 @@ def setUp(self) -> None:
         )

     def test_invalid_grafting_config(self) -> None:
-        with mock.patch.object(
-            distributed_shampoo, "type", side_effect=lambda object: GraftingConfig
-        ), self.assertRaisesRegex(
-            NotImplementedError,
-            re.escape(
-                "Unsupported grafting config: group[GRAFTING_CONFIG]=SGDGraftingConfig"
+        with (
+            mock.patch.object(
+                distributed_shampoo, "type", side_effect=lambda object: GraftingConfig
+            ),
+            self.assertRaisesRegex(
+                NotImplementedError,
+                re.escape(
+                    "Unsupported grafting config: group[GRAFTING_CONFIG]=SGDGraftingConfig"
+                ),
             ),
         ):
             DistributedShampoo(
@@ -126,22 +129,26 @@ def test_invalid_with_incorrect_hyperparameter_setting(self) -> None:
             incorrect_hyperparameter_setting,
             expected_error_msg,
         ) in incorrect_hyperparameter_setting_and_expected_error_msg:
-            with self.subTest(
-                incorrect_hyperparameter_setting=incorrect_hyperparameter_setting,
-                expected_error_msg=expected_error_msg,
-            ), self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)):
+            with (
+                self.subTest(
+                    incorrect_hyperparameter_setting=incorrect_hyperparameter_setting,
+                    expected_error_msg=expected_error_msg,
+                ),
+                self.assertRaisesRegex(ValueError, re.escape(expected_error_msg)),
+            ):
                 DistributedShampoo(
                     self._model.parameters(),
                     **incorrect_hyperparameter_setting,
                 )

     def test_invalid_pytorch_compile_setting(self) -> None:
-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=False
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "Both use_pytorch_compile and shampoo_pt2_compile_config are provided. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=False),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "Both use_pytorch_compile and shampoo_pt2_compile_config are provided. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
+                ),
             ),
         ):
             DistributedShampoo(
@@ -150,12 +157,13 @@ def test_invalid_pytorch_compile_setting(self) -> None:
                 shampoo_pt2_compile_config=ShampooPT2CompileConfig(),
             )

-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=False
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "use_pytorch_compile=False conflicts with non-None shampoo_pt2_compile_config arg. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=False),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "use_pytorch_compile=False conflicts with non-None shampoo_pt2_compile_config arg. Please use only shampoo_pt2_compile_config as use_pytorch_compile is deprecating."
+                ),
             ),
         ):
             DistributedShampoo(
@@ -165,11 +173,12 @@ def test_invalid_pytorch_compile_setting(self) -> None:
             )

     def test_warning_pytorch_compile_setting(self) -> None:
-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=True
-        ), self.assertLogs(
-            level="WARNING",
-        ) as cm:
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=True),
+            self.assertLogs(
+                level="WARNING",
+            ) as cm,
+        ):
             DistributedShampoo(
                 self._model.parameters(),
                 lr=0.01,
@@ -183,12 +192,13 @@ def test_warning_pytorch_compile_setting(self) -> None:
             )

     def test_invalid_cuda_pytorch_compile_setting(self) -> None:
-        with mock.patch.object(
-            torch.cuda, "is_available", return_value=False
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "Backend does NOT support Pytorch 2.0 compile. Switch to use_pytorch_compile in (False, None) and shampoo_pt2_compile_config=None."
+        with (
+            mock.patch.object(torch.cuda, "is_available", return_value=False),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "Backend does NOT support Pytorch 2.0 compile. Switch to use_pytorch_compile in (False, None) and shampoo_pt2_compile_config=None."
+                ),
             ),
         ):
             DistributedShampoo(
@@ -213,16 +223,19 @@ def test_nesterov_and_zero_momentum(self) -> None:
             )

     def test_invalid_distributed_config(self) -> None:
-        with self.assertRaisesRegex(
-            NotImplementedError,
-            re.escape(
-                "distributed_config=DDPShampooConfig(communication_dtype=, "
-                "num_trainers_per_group=-1, communicate_params=False) not supported!"
+        with (
+            self.assertRaisesRegex(
+                NotImplementedError,
+                re.escape(
+                    "distributed_config=DDPShampooConfig(communication_dtype=, "
+                    "num_trainers_per_group=-1, communicate_params=False) not supported!"
+                ),
+            ),
+            mock.patch.object(
+                distributed_shampoo,
+                "type",
+                side_effect=lambda object: DistributedConfig,
             ),
-        ), mock.patch.object(
-            distributed_shampoo,
-            "type",
-            side_effect=lambda object: DistributedConfig,
         ):
             DistributedShampoo(
                 params=self._model.parameters(),
diff --git a/distributed_shampoo/tests/shampoo_types_test.py b/distributed_shampoo/tests/shampoo_types_test.py
index 49d8a52..2ec773f 100644
--- a/distributed_shampoo/tests/shampoo_types_test.py
+++ b/distributed_shampoo/tests/shampoo_types_test.py
@@ -22,11 +22,12 @@ class AdaGradGraftingConfigTest(unittest.TestCase):
     def test_illegal_epsilon(self) -> None:
         epsilon = 0.0
         grafting_config_type = self._get_grafting_config_type()
-        with self.subTest(
-            grafting_config_type=grafting_config_type
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."),
+        with (
+            self.subTest(grafting_config_type=grafting_config_type),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."),
+            ),
         ):
             grafting_config_type(epsilon=epsilon)

@@ -46,12 +47,13 @@ def test_illegal_beta2(
     ) -> None:
         grafting_config_type = self._get_grafting_config_type()
         for beta2 in (-1.0, 0.0, 1.3):
-            with self.subTest(
-                grafting_config_type=grafting_config_type, beta2=beta2
-            ), self.assertRaisesRegex(
-                ValueError,
-                re.escape(
-                    f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]."
+            with (
+                self.subTest(grafting_config_type=grafting_config_type, beta2=beta2),
+                self.assertRaisesRegex(
+                    ValueError,
+                    re.escape(
+                        f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]."
+                    ),
                 ),
             ):
                 grafting_config_type(beta2=beta2)
diff --git a/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py b/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py
index f9fd65b..fa58208 100644
--- a/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py
+++ b/distributed_shampoo/utils/gpu_tests/shampoo_hsdp_distributor_test.py
@@ -313,11 +313,12 @@ def test_dist_is_initialized(self) -> None:
             device_mesh=mesh_2d,
         )

-        with mock.patch.object(
-            torch.distributed, "is_initialized", return_value=False
-        ), self.assertRaisesRegex(
-            RuntimeError,
-            re.escape("HSDPDistributor needs torch.distributed to be initialized!"),
+        with (
+            mock.patch.object(torch.distributed, "is_initialized", return_value=False),
+            self.assertRaisesRegex(
+                RuntimeError,
+                re.escape("HSDPDistributor needs torch.distributed to be initialized!"),
+            ),
         ):
             ShampooHSDPDistributorTest._train_model(
                 optim_factory=ShampooHSDPDistributorTest._shampoo_optim_factory(
@@ -339,12 +340,15 @@ def test_incompatible_replicated_group_size_and_num_trainers_per_group(
         )

         # Hijack the DeviceMesh.size() method to return 4 instead of 2 to bypass the check of num_trainers_per_group.
-        with mock.patch.object(
-            torch.distributed.device_mesh.DeviceMesh, "size", return_value=4
-        ), self.assertRaisesRegex(
-            ValueError,
-            re.escape(
-                "distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!"
+        with (
+            mock.patch.object(
+                torch.distributed.device_mesh.DeviceMesh, "size", return_value=4
+            ),
+            self.assertRaisesRegex(
+                ValueError,
+                re.escape(
+                    "distributed_config.num_trainers_per_group=3 must divide self._replicated_group_size=4!"
+                ),
             ),
         ):
             ShampooHSDPDistributorTest._train_model(
diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py
index cce8aca..6d38af6 100644
--- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py
+++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py
@@ -125,13 +125,16 @@ def _test_compress_preconditioner_list(
         self,
         expected_compress_list_call_count: int,
     ) -> None:
-        with mock.patch.object(
-            shampoo_preconditioner_list,
-            "compress_list",
-        ) as mock_compress_list, mock.patch.object(
-            QuantizedTensorList,
-            "compress",
-        ) as mock_compress_quant_list:
+        with (
+            mock.patch.object(
+                shampoo_preconditioner_list,
+                "compress_list",
+            ) as mock_compress_list,
+            mock.patch.object(
+                QuantizedTensorList,
+                "compress",
+            ) as mock_compress_quant_list,
+        ):
             # Count the number of list compressions at the preconditioner list level, including compressions of QuantizedTensorList.
             # Each call to compress() under QuantizedTensorList counts once, though note that it calls compress_list() three times inside.
             self.assertIsNone(
@@ -328,11 +331,16 @@ def _instantiate_preconditioner_list( # type: ignore[override]
     def _test_raise_invalid_value_in_factor_matrix(
         self, invalid_value: float
     ) -> None:
-        with DequantizePreconditionersContext(
-            preconditioner_list=self._preconditioner_list
-        ), self.assertRaisesRegex(
-            PreconditionerValueError,
-            re.escape(f"Encountered {str(invalid_value)} values in factor matrix"),
+        with (
+            DequantizePreconditionersContext(
+                preconditioner_list=self._preconditioner_list
+            ),
+            self.assertRaisesRegex(
+                PreconditionerValueError,
+                re.escape(
+                    f"Encountered {str(invalid_value)} values in factor matrix"
+                ),
+            ),
         ):
             self._preconditioner_list.update_preconditioners(
                 masked_grad_list=(
@@ -356,16 +364,21 @@ def test_raise_nan_and_inf_in_inv_factor_matrix_amortized_computation(
         self,
     ) -> None:
         for invalid_value in (torch.nan, torch.inf):
-            with DequantizePreconditionersContext(
-                preconditioner_list=self._preconditioner_list
-            ), self.subTest(invalid_value=invalid_value), self.assertRaisesRegex(
-                PreconditionerValueError,
-                re.escape("Encountered nan or inf values in"),
-            ), mock.patch.object(
-                shampoo_preconditioner_list,
-                self._amortized_computation_function(),
-                side_effect=(torch.tensor([invalid_value]),),
-            ) as mock_amortized_computation:
+            with (
+                DequantizePreconditionersContext(
+                    preconditioner_list=self._preconditioner_list
+                ),
+                self.subTest(invalid_value=invalid_value),
+                self.assertRaisesRegex(
+                    PreconditionerValueError,
+                    re.escape("Encountered nan or inf values in"),
+                ),
+                mock.patch.object(
+                    shampoo_preconditioner_list,
+                    self._amortized_computation_function(),
+                    side_effect=(torch.tensor([invalid_value]),),
+                ) as mock_amortized_computation,
+            ):
                 self._preconditioner_list.update_preconditioners(
                     masked_grad_list=(
                         torch.tensor([1.0, 0.0]),
@@ -384,9 +397,12 @@ def test_amortized_computation_internal_failure(self) -> None:
             # Simulate the situation throws an exception (not nan and inf) to test the warning
             side_effect=ZeroDivisionError,
         ) as mock_amortized_computation:
-            with DequantizePreconditionersContext(
-                preconditioner_list=self._preconditioner_list
-            ), self.assertLogs(level="WARNING") as cm:
+            with (
+                DequantizePreconditionersContext(
+                    preconditioner_list=self._preconditioner_list
+                ),
+                self.assertLogs(level="WARNING") as cm,
+            ):
                 # Because use_protected_eigh is True, we expect the warning to be logged.
                 self._preconditioner_list.update_preconditioners(
                     masked_grad_list=(
                         torch.tensor([1.0, 0.0]),
@@ -415,9 +431,12 @@ def test_amortized_computation_internal_failure(self) -> None:
         self._preconditioner_list = self._instantiate_preconditioner_list(
             use_protected_eigh=False,
         )
-        with DequantizePreconditionersContext(
-            preconditioner_list=self._preconditioner_list
-        ), self.assertRaises(ZeroDivisionError):
+        with (
+            DequantizePreconditionersContext(
+                preconditioner_list=self._preconditioner_list
+            ),
+            self.assertRaises(ZeroDivisionError),
+        ):
             self._preconditioner_list.update_preconditioners(
                 masked_grad_list=(
                     torch.tensor([1.0, 0.0]),
@@ -443,11 +462,14 @@ def test_amortized_computation_factor_matrix_non_diagonal(
         self._preconditioner_list = self._instantiate_preconditioner_list(
             epsilon=1.0
         )
-        with DequantizePreconditionersContext(
-            preconditioner_list=self._preconditioner_list
-        ), self.assertLogs(
-            level="DEBUG",
-        ) as cm:
+        with (
+            DequantizePreconditionersContext(
+                preconditioner_list=self._preconditioner_list
+            ),
+            self.assertLogs(
+                level="DEBUG",
+            ) as cm,
+        ):
             self._preconditioner_list.update_preconditioners(
                 masked_grad_list=(
                     torch.tensor([1.0, 0.0]),
diff --git a/distributed_shampoo/utils/tests/shampoo_quantization_test.py b/distributed_shampoo/utils/tests/shampoo_quantization_test.py
index 96a06b4..13979d9 100644
--- a/distributed_shampoo/utils/tests/shampoo_quantization_test.py
+++ b/distributed_shampoo/utils/tests/shampoo_quantization_test.py
@@ -108,14 +108,17 @@ def test_invalid_quantized_data_type(self) -> None:

 class QuantizedTensorListInitTest(unittest.TestCase):
     def test_invalid_quantized_data_type(self) -> None:
-        with mock.patch.object(
-            shampoo_quantization,
-            "isinstance",
-            side_effect=lambda object, classinfo: False,
-        ), self.assertRaisesRegex(
-            TypeError,
-            re.escape(
-                "quantized_data must be collections.abc.Sequence[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]] | collections.abc.Sequence[distributed_shampoo.utils.shampoo_quantization.QuantizedTensor] but get "
+        with (
+            mock.patch.object(
+                shampoo_quantization,
+                "isinstance",
+                side_effect=lambda object, classinfo: False,
+            ),
+            self.assertRaisesRegex(
+                TypeError,
+                re.escape(
+                    "quantized_data must be collections.abc.Sequence[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]] | collections.abc.Sequence[distributed_shampoo.utils.shampoo_quantization.QuantizedTensor] but get "
+                ),
             ),
         ):
             QuantizedTensorList(
diff --git a/tests/matrix_functions_test.py b/tests/matrix_functions_test.py
index 4ea176d..1658201 100644
--- a/tests/matrix_functions_test.py
+++ b/tests/matrix_functions_test.py
@@ -231,23 +231,27 @@ def test_matrix_inverse_root_reach_max_iterations(self) -> None:
             implementation,
             msg,
         ) in root_inv_config_and_implementation_and_msg:
-            with mock.patch.object(
-                matrix_functions,
-                implementation,
-                return_value=(
-                    None,
-                    None,
-                    NewtonConvergenceFlag.REACHED_MAX_ITERS,
-                    None,
-                    None,
+            with (
+                mock.patch.object(
+                    matrix_functions,
+                    implementation,
+                    return_value=(
+                        None,
+                        None,
+                        NewtonConvergenceFlag.REACHED_MAX_ITERS,
+                        None,
+                        None,
+                    ),
                 ),
-            ), self.subTest(
-                root_inv_config=root_inv_config,
-                implementation=implementation,
-                msg=msg,
-            ), self.assertLogs(
-                level="WARNING",
-            ) as cm:
+                self.subTest(
+                    root_inv_config=root_inv_config,
+                    implementation=implementation,
+                    msg=msg,
+                ),
+                self.assertLogs(
+                    level="WARNING",
+                ) as cm,
+            ):
                 matrix_inverse_root(
                     A=A,
                     root=root,
@@ -891,14 +895,17 @@ def test_matrix_eigenvectors(self) -> None:
     def test_invalid_eigenvalue_correction_config(
         self,
     ) -> None:
-        with mock.patch.object(
-            matrix_functions,
-            "type",
-            side_effect=lambda object: EigenvalueCorrectionConfig,
-        ), self.assertRaisesRegex(
-            NotImplementedError,
-            re.escape(
-                "Eigenvector computation method is not implemented! Specified eigenvector method is eigenvector_computation_config=EighEigenvalueCorrectionConfig(retry_double_precision=True)."
+        with (
+            mock.patch.object(
+                matrix_functions,
+                "type",
+                side_effect=lambda object: EigenvalueCorrectionConfig,
+            ),
+            self.assertRaisesRegex(
+                NotImplementedError,
+                re.escape(
+                    "Eigenvector computation method is not implemented! Specified eigenvector method is eigenvector_computation_config=EighEigenvalueCorrectionConfig(retry_double_precision=True)."
+                ),
             ),
         ):
             matrix_eigenvectors(
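
Every hunk above applies the same mechanical rewrite: a chain of context managers spelled `with cm1(...), cm2(...):`, with continuations hanging off closing parentheses, becomes a single parenthesized `with (...)` block holding one context manager per line with trailing commas, which is the layout newer ruff releases emit. A minimal standalone sketch of the two spellings follows; it is illustrative only and not part of the patch. The `resource` helper is made up (standing in for mock.patch.object, assertRaisesRegex, DequantizePreconditionersContext, and friends), and the parenthesized form assumes an interpreter whose parser accepts it (accepted by CPython from 3.9, documented as official syntax in 3.10).

# Illustrative sketch only -- not part of the patch above. "resource" is a
# made-up stand-in for the real context managers the tests combine.
from contextlib import contextmanager


@contextmanager
def resource(name: str):
    # Prints on enter/exit so the nesting order is visible when run.
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")


# Old layout (left-hand side of the hunks): context managers chained after a
# closing parenthesis, which forces awkward line breaks once names get long.
with resource("a"), resource("b") as b:
    print("body", b)

# New layout (right-hand side of the hunks): all context managers wrapped in
# one pair of parentheses, one per line, each ending with a trailing comma.
with (
    resource("a"),
    resource("b") as b,
):
    print("body", b)

Both spellings compile to identical behavior; only the source layout differs, which is why the patch is formatting-only and touches no test logic.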