diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 14bbab0..b96c1c7 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -617,22 +617,21 @@ def test_precondition_grad(self) -> None: # Compare the results of preconditioning the gradient with both setups for different contract dimensions. for dims in (([0], [0]), ([0], [1])): - torch.testing.assert_close( - self._preconditioner_list._precondition_grad( - grad=grad, - preconditioned_dims_selector=experimental_preconditioned_dims_selector, - preconditioner_list=experimental_preconditioner_list, - dims=dims, - ), - self._preconditioner_list._precondition_grad( - grad=grad, - preconditioned_dims_selector=control_preconditioned_dims_selector, - preconditioner_list=control_preconditioner_list, - dims=dims, - ), - rtol=0.0, - atol=0.0, - ) + with self.subTest(dims=dims): + torch.testing.assert_close( + self._preconditioner_list._precondition_grad( # type: ignore[attr-defined] + grad=grad, + preconditioned_dims_selector=experimental_preconditioned_dims_selector, + preconditioner_list=experimental_preconditioner_list, + dims=dims, + ), + self._preconditioner_list._precondition_grad( # type: ignore[attr-defined] + grad=grad, + preconditioned_dims_selector=control_preconditioned_dims_selector, + preconditioner_list=control_preconditioner_list, + dims=dims, + ), + ) def test_numel_list(self) -> None: self.assertEqual(self._preconditioner_list.numel_list, (8, 16, 10))