Skip to content

Commit

Permalink
Improve naming
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Dec 10, 2024
1 parent 8f19d2f commit a92348d
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions distributed_shampoo/utils/shampoo_preconditioner_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,16 +695,19 @@ def _check_factor_matrix_for_diagonality_nan_and_inf(
f"{factor_matrix.isinf().any()=}, {factor_matrix.isnan().any()=}."
)

def _raise_exception_if_tolerance_exceeded(
self, success_tracker: list[bool], idx: int, exception: Exception
def _raise_exception_if_failure_tolerance_exceeded(
self,
success_tracker: list[bool],
preconditioner_index: int,
exception: Exception,
) -> None:
"""Raises an exception if the number of failed amortized computations exceeds the tolerance.
Resets the counter at the given index when all amortized computations are successful.
Args:
success_tracker (list[bool]): A list of booleans indicating whether the amortized computation was successful.
idx (int): The index of the preconditioner.
preconditioner_index (int): The index of the preconditioner.
exception (Exception): The exception to raise.
Raises:
Expand All @@ -713,15 +716,22 @@ def _raise_exception_if_tolerance_exceeded(
"""
if all(success_tracker):
# Reset counter for failed amortized computations.
self._masked_failed_amortized_computation_counter_list[idx] = 0
self._masked_failed_amortized_computation_counter_list[
preconditioner_index
] = 0
else:
# Increment counter for failed amortized computations.
self._masked_failed_amortized_computation_counter_list[idx] += 1
self._masked_failed_amortized_computation_counter_list[
preconditioner_index
] += 1
# Raise the exception if the tolerance at the given index is exceeded.
failure_counter = self._masked_failed_amortized_computation_counter_list[
preconditioner_index
]
tolerance = (
self._preconditioner_config.num_tolerated_failed_amortized_computations
)
if self._masked_failed_amortized_computation_counter_list[idx] > tolerance:
if failure_counter > tolerance:
raise exception

def update_preconditioners(
Expand Down Expand Up @@ -1052,7 +1062,7 @@ def _amortized_computation(self) -> None:
inv_factor_matrix.copy_(computed_inv_factor_matrix)

# Only reuse previous inverse roots if tolerance is not exceeded.
self._raise_exception_if_tolerance_exceeded(
self._raise_exception_if_failure_tolerance_exceeded(
success_tracker,
idx,
ValueError(
Expand Down Expand Up @@ -1382,7 +1392,7 @@ def _amortized_computation(self) -> None:
factor_matrix_eigenvectors.copy_(computed_eigenvectors)

# Only reuse previous eigenvectors if tolerance is not exceeded.
self._raise_exception_if_tolerance_exceeded(
self._raise_exception_if_failure_tolerance_exceeded(
success_tracker,
idx,
ValueError(
Expand Down

0 comments on commit a92348d

Please sign in to comment.