diff --git a/distributed_shampoo/utils/shampoo_preconditioner_list.py b/distributed_shampoo/utils/shampoo_preconditioner_list.py index 8684fe9..6687d56 100644 --- a/distributed_shampoo/utils/shampoo_preconditioner_list.py +++ b/distributed_shampoo/utils/shampoo_preconditioner_list.py @@ -696,26 +696,33 @@ def _check_factor_matrix_for_diagonality_nan_and_inf( ) def _raise_exception_if_tolerance_exceeded( - self, counter: int, exception: Exception + self, success_tracker: list[bool], idx: int, exception: Exception ) -> None: """Raises an exception if the number of failed amortized computations exceeds the tolerance. + Resets the counter at the given index when all amortized computations are successful. + Args: - counter (int): The counter for the number of failed amortized computations. + success_tracker (list[bool]): A list of booleans indicating whether the amortized computation was successful. + idx (int): The index of the preconditioner. exception (Exception): The exception to raise. Raises: exception (Exception): The exception to raise. """ - tolerance = ( - self._preconditioner_config.num_tolerated_failed_amortized_computations - ) - if counter > tolerance: - logger.error( - f"Exceeded tolerance ({tolerance}) for number of failed amortized computations." + if all(success_tracker): + # Reset counter for failed amortized computations. + self._masked_failed_amortized_computation_counter_list[idx] = 0 + else: + # Increment counter for failed amortized computations. + self._masked_failed_amortized_computation_counter_list[idx] += 1 + # Raise the exception if the tolerance at the given index is exceeded. + tolerance = ( + self._preconditioner_config.num_tolerated_failed_amortized_computations ) - raise exception + if self._masked_failed_amortized_computation_counter_list[idx] > tolerance: + raise exception def update_preconditioners( self, @@ -1044,19 +1051,14 @@ def _amortized_computation(self) -> None: ) inv_factor_matrix.copy_(computed_inv_factor_matrix) - if all(success_tracker): - # Reset counter for failed amortized computations. - self._masked_failed_amortized_computation_counter_list[idx] = 0 - else: - # Increment counter for failed amortized computations. - self._masked_failed_amortized_computation_counter_list[idx] += 1 - # Only reuse previous eigenvectors if tolerance is not exceeded. - self._raise_exception_if_tolerance_exceeded( - self._masked_failed_amortized_computation_counter_list[idx], - ValueError( - f"Exceeded tolerance for number of failed root inverse computations for {kronecker_factors.factor_matrix_indices}." - ), - ) + # Only reuse previous inverse roots if tolerance is not exceeded. + self._raise_exception_if_tolerance_exceeded( + success_tracker, + idx, + ValueError( + f"Exceeded tolerance for number of failed inverse root computations for {kronecker_factors.factor_matrix_indices}." + ), + ) def dequantize_preconditioners(self) -> None: with profiler.record_function( @@ -1379,19 +1381,14 @@ def _amortized_computation(self) -> None: ) factor_matrix_eigenvectors.copy_(computed_eigenvectors) - if all(success_tracker): - # Reset counter for failed amortized computations. - self._masked_failed_amortized_computation_counter_list[idx] = 0 - else: - # Increment counter for failed amortized computations. - self._masked_failed_amortized_computation_counter_list[idx] += 1 - # Only reuse previous eigenvectors if tolerance is not exceeded. - self._raise_exception_if_tolerance_exceeded( - self._masked_failed_amortized_computation_counter_list[idx], - ValueError( - f"Exceeded tolerance for number of failed eigenvector computations for {kronecker_factors.factor_matrix_indices}." - ), - ) + # Only reuse previous eigenvectors if tolerance is not exceeded. + self._raise_exception_if_tolerance_exceeded( + success_tracker, + idx, + ValueError( + f"Exceeded tolerance for number of failed eigenvector computations for {kronecker_factors.factor_matrix_indices}." + ), + ) def dequantize_preconditioners(self) -> None: with profiler.record_function(