From a3cbcd114a140757fbc6be02b9f5330ded66b0ae Mon Sep 17 00:00:00 2001 From: runame Date: Sun, 23 Feb 2025 23:57:21 +0000 Subject: [PATCH] Fix return type and tolerance in test --- distributed_shampoo/utils/shampoo_preconditioner_list.py | 2 +- .../utils/tests/shampoo_preconditioner_list_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index 695f300..9d4b015 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -1165,7 +1165,7 @@ def _adaptive_amortized_computation_frequency_criterion_below_or_equal_tolerance off_diagonal_summed = squared_approximate_eigenvalues.fill_diagonal_(0.0).sum() norm = torch.sqrt(diagonal_summed + off_diagonal_summed) off_diagonal_norm = torch.sqrt(off_diagonal_summed) - return off_diagonal_norm <= tolerance * norm + return bool(off_diagonal_norm <= tolerance * norm) @torch.compiler.disable def _amortized_computation(self) -> None: diff --git a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py index 6705860..d7075e1 100644 --- a/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py @@ -1136,7 +1136,7 @@ def test_adaptive_amortized_computation_frequency_criterion_below_tolerance( assert not (off_diagonal_norm <= test_tolerance1 * norm) # Below tolerance. - test_tolerance2 = 2e-7 + test_tolerance2 = 3e-7 test_criterion2 = EigenvalueCorrectedShampooPreconditionerList._adaptive_amortized_computation_frequency_criterion_below_or_equal_tolerance( test_factor_matrix, test_factor_matrix_eigenvectors, test_tolerance2 )