Skip to content

Commit a3cbcd1

Browse files
committed
Fix return type and tolerance in test
1 parent 8411ed6 commit a3cbcd1

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

distributed_shampoo/utils/shampoo_preconditioner_list.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1165,7 +1165,7 @@ def _adaptive_amortized_computation_frequency_criterion_below_or_equal_tolerance
11651165
off_diagonal_summed = squared_approximate_eigenvalues.fill_diagonal_(0.0).sum()
11661166
norm = torch.sqrt(diagonal_summed + off_diagonal_summed)
11671167
off_diagonal_norm = torch.sqrt(off_diagonal_summed)
1168-
return off_diagonal_norm <= tolerance * norm
1168+
return bool(off_diagonal_norm <= tolerance * norm)
11691169

11701170
@torch.compiler.disable
11711171
def _amortized_computation(self) -> None:

distributed_shampoo/utils/tests/shampoo_preconditioner_list_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1136,7 +1136,7 @@ def test_adaptive_amortized_computation_frequency_criterion_below_tolerance(
11361136
assert not (off_diagonal_norm <= test_tolerance1 * norm)
11371137

11381138
# Below tolerance.
1139-
test_tolerance2 = 2e-7
1139+
test_tolerance2 = 3e-7
11401140
test_criterion2 = EigenvalueCorrectedShampooPreconditionerList._adaptive_amortized_computation_frequency_criterion_below_or_equal_tolerance(
11411141
test_factor_matrix, test_factor_matrix_eigenvectors, test_tolerance2
11421142
)

0 commit comments

Comments
 (0)