Skip to content

Commit

Permalink
Reduce code duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Dec 10, 2024
1 parent 527c35e commit 01dde5f
Showing 1 changed file with 32 additions and 35 deletions.
67 changes: 32 additions & 35 deletions distributed_shampoo/utils/shampoo_preconditioner_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,26 +696,33 @@ def _check_factor_matrix_for_diagonality_nan_and_inf(
)

def _raise_exception_if_tolerance_exceeded(
self, counter: int, exception: Exception
self, success_tracker: list[bool], idx: int, exception: Exception
) -> None:
"""Raises an exception if the number of failed amortized computations exceeds the tolerance.
Resets the counter at the given index when all amortized computations are successful.
Args:
counter (int): The counter for the number of failed amortized computations.
success_tracker (list[bool]): A list of booleans indicating whether the amortized computation was successful.
idx (int): The index of the preconditioner.
exception (Exception): The exception to raise.
Raises:
exception (Exception): The exception to raise.
"""
tolerance = (
self._preconditioner_config.num_tolerated_failed_amortized_computations
)
if counter > tolerance:
logger.error(
f"Exceeded tolerance ({tolerance}) for number of failed amortized computations."
if all(success_tracker):
# Reset counter for failed amortized computations.
self._masked_failed_amortized_computation_counter_list[idx] = 0
else:
# Increment counter for failed amortized computations.
self._masked_failed_amortized_computation_counter_list[idx] += 1
# Raise the exception if the tolerance at the given index is exceeded.
tolerance = (
self._preconditioner_config.num_tolerated_failed_amortized_computations
)
raise exception
if self._masked_failed_amortized_computation_counter_list[idx] > tolerance:
raise exception

def update_preconditioners(
self,
Expand Down Expand Up @@ -1044,19 +1051,14 @@ def _amortized_computation(self) -> None:
)
inv_factor_matrix.copy_(computed_inv_factor_matrix)

if all(success_tracker):
# Reset counter for failed amortized computations.
self._masked_failed_amortized_computation_counter_list[idx] = 0
else:
# Increment counter for failed amortized computations.
self._masked_failed_amortized_computation_counter_list[idx] += 1
# Only reuse previous eigenvectors if tolerance is not exceeded.
self._raise_exception_if_tolerance_exceeded(
self._masked_failed_amortized_computation_counter_list[idx],
ValueError(
f"Exceeded tolerance for number of failed root inverse computations for {kronecker_factors.factor_matrix_indices}."
),
)
# Only reuse previous inverse roots if tolerance is not exceeded.
self._raise_exception_if_tolerance_exceeded(
success_tracker,
idx,
ValueError(
f"Exceeded tolerance for number of failed inverse root computations for {kronecker_factors.factor_matrix_indices}."
),
)

def dequantize_preconditioners(self) -> None:
with profiler.record_function(
Expand Down Expand Up @@ -1379,19 +1381,14 @@ def _amortized_computation(self) -> None:
)
factor_matrix_eigenvectors.copy_(computed_eigenvectors)

if all(success_tracker):
# Reset counter for failed amortized computations.
self._masked_failed_amortized_computation_counter_list[idx] = 0
else:
# Increment counter for failed amortized computations.
self._masked_failed_amortized_computation_counter_list[idx] += 1
# Only reuse previous eigenvectors if tolerance is not exceeded.
self._raise_exception_if_tolerance_exceeded(
self._masked_failed_amortized_computation_counter_list[idx],
ValueError(
f"Exceeded tolerance for number of failed eigenvector computations for {kronecker_factors.factor_matrix_indices}."
),
)
# Only reuse previous eigenvectors if tolerance is not exceeded.
self._raise_exception_if_tolerance_exceeded(
success_tracker,
idx,
ValueError(
f"Exceeded tolerance for number of failed eigenvector computations for {kronecker_factors.factor_matrix_indices}."
),
)

def dequantize_preconditioners(self) -> None:
with profiler.record_function(
Expand Down

0 comments on commit 01dde5f

Please sign in to comment.