diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index 6687d56..8548d5b 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -695,8 +695,11 @@ def _check_factor_matrix_for_diagonality_nan_and_inf( f"{factor_matrix.isinf().any()=}, {factor_matrix.isnan().any()=}." ) - def _raise_exception_if_tolerance_exceeded( - self, success_tracker: list[bool], idx: int, exception: Exception + def _raise_exception_if_failure_tolerance_exceeded( + self, + success_tracker: list[bool], + preconditioner_index: int, + exception: Exception, ) -> None: """Raises an exception if the number of failed amortized computations exceeds the tolerance. @@ -704,7 +707,7 @@ def _raise_exception_if_tolerance_exceeded( Args: success_tracker (list[bool]): A list of booleans indicating whether the amortized computation was successful. - idx (int): The index of the preconditioner. + preconditioner_index (int): The index of the preconditioner. exception (Exception): The exception to raise. Raises: @@ -713,15 +716,22 @@ def _raise_exception_if_tolerance_exceeded( """ if all(success_tracker): # Reset counter for failed amortized computations. - self._masked_failed_amortized_computation_counter_list[idx] = 0 + self._masked_failed_amortized_computation_counter_list[ + preconditioner_index + ] = 0 else: # Increment counter for failed amortized computations. - self._masked_failed_amortized_computation_counter_list[idx] += 1 + self._masked_failed_amortized_computation_counter_list[ + preconditioner_index + ] += 1 # Raise the exception if the tolerance at the given index is exceeded. + failure_counter = self._masked_failed_amortized_computation_counter_list[ + preconditioner_index + ] tolerance = ( self._preconditioner_config.num_tolerated_failed_amortized_computations ) - if self._masked_failed_amortized_computation_counter_list[idx] > tolerance: + if failure_counter > tolerance: raise exception def update_preconditioners( @@ -1052,7 +1062,7 @@ def _amortized_computation(self) -> None: inv_factor_matrix.copy_(computed_inv_factor_matrix) # Only reuse previous inverse roots if tolerance is not exceeded. - self._raise_exception_if_tolerance_exceeded( + self._raise_exception_if_failure_tolerance_exceeded( success_tracker, idx, ValueError( @@ -1382,7 +1392,7 @@ def _amortized_computation(self) -> None: factor_matrix_eigenvectors.copy_(computed_eigenvectors) # Only reuse previous eigenvectors if tolerance is not exceeded. - self._raise_exception_if_tolerance_exceeded( + self._raise_exception_if_failure_tolerance_exceeded( success_tracker, idx, ValueError(