Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add control over tolerance for failed amortized computations #64

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
55f1413
Refactor matrix functions types
runame Dec 6, 2024
9291d15
Refactor Shampoo types
runame Dec 6, 2024
dac3ac1
Adjust UI and docs
runame Dec 6, 2024
7e20cf3
Replace preconditioner_computation_config with preconditioner_config
runame Dec 6, 2024
1bbaa2f
Fix docstring
runame Dec 7, 2024
32c5df5
Add tolerance for amortized computation failures
runame Dec 9, 2024
65f4f27
Add test for amortized computation failure tolerance
runame Dec 9, 2024
9a137b5
Adjust abstractmethod test
runame Dec 9, 2024
03245f5
Add check that tolerance value non-negative
runame Dec 9, 2024
527c35e
Make failure tracking coarser
runame Dec 9, 2024
01dde5f
Reduce code duplication
runame Dec 10, 2024
6d9810b
Merge branch 'main' into configs-refactor
runame Dec 10, 2024
d3a10af
Set default values
runame Dec 10, 2024
709ab1c
Merge branch 'configs-refactor' into fail-counter
runame Dec 10, 2024
196a781
Merge branch 'fail-counter' into fail-counter-v2
runame Dec 10, 2024
273e0a1
Fix defaults with default_factory
runame Dec 10, 2024
d53ffc6
Merge branch 'configs-refactor' into fail-counter
runame Dec 10, 2024
8f19d2f
Merge branch 'fail-counter' into fail-counter-v2
runame Dec 10, 2024
a92348d
Improve naming
runame Dec 10, 2024
736c76d
Simplify test
runame Dec 12, 2024
54b8879
Merge branch 'main' into fail-counter-v2
runame Dec 18, 2024
cf08da0
Fix test
runame Dec 19, 2024
5b37d84
Use keywords explicitly
runame Dec 19, 2024
9e0e46e
Merge branch 'main' into fail-counter-v2
runame Dec 19, 2024
7793429
Revert outdated change
runame Dec 19, 2024
98051d2
Simplify no warnings assertion
runame Dec 19, 2024
8ea571e
Remove leftover variable
runame Dec 19, 2024
c098c6a
Improve readability of call count check
runame Dec 20, 2024
701e5e9
Merge branch 'main' into fail-counter-v2
runame Dec 20, 2024
16853ea
Further improve readability of test
runame Dec 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions distributed_shampoo/shampoo_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,18 @@ class PreconditionerConfig(AbstractDataclass):

Args:
amortized_computation_config (MatrixFunctionConfig): Configuration for the amortized computation, e.g., inverse-root or eigenvector computation.
num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3)

"""

amortized_computation_config: MatrixFunctionConfig # type: ignore
num_tolerated_failed_amortized_computations: int = 3

def __post_init__(self) -> None:
if self.num_tolerated_failed_amortized_computations < 0:
raise ValueError(
f"Invalid num_tolerated_failed_amortized_computations value: {self.num_tolerated_failed_amortized_computations}. Must be >= 0."
)


@dataclass(kw_only=True)
Expand Down
84 changes: 72 additions & 12 deletions distributed_shampoo/tests/shampoo_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,29 @@

import re
import unittest
from typing import Type
from abc import ABC, abstractmethod
from typing import Generic, Type, TypeVar

from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdamGraftingConfig,
EigenvalueCorrectedShampooPreconditionerConfig,
PreconditionerConfig,
RMSpropGraftingConfig,
ShampooPreconditionerConfig,
)


class AdaGradGraftingConfigTest(unittest.TestCase):
def test_illegal_epsilon(self) -> None:
epsilon = 0.0
grafting_config_type = self._get_grafting_config_type()
with self.subTest(
grafting_config_type=grafting_config_type
), self.assertRaisesRegex(
ValueError,
re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."),
with (
self.subTest(grafting_config_type=grafting_config_type),
self.assertRaisesRegex(
ValueError,
re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."),
),
):
grafting_config_type(epsilon=epsilon)

Expand All @@ -46,12 +51,13 @@ def test_illegal_beta2(
) -> None:
grafting_config_type = self._get_grafting_config_type()
for beta2 in (-1.0, 0.0, 1.3):
with self.subTest(
grafting_config_type=grafting_config_type, beta2=beta2
), self.assertRaisesRegex(
ValueError,
re.escape(
f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]."
with (
self.subTest(grafting_config_type=grafting_config_type, beta2=beta2),
self.assertRaisesRegex(
ValueError,
re.escape(
f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]."
),
),
):
grafting_config_type(beta2=beta2)
Expand All @@ -67,3 +73,57 @@ def _get_grafting_config_type(
self,
) -> Type[RMSpropGraftingConfig] | Type[AdamGraftingConfig]:
return AdamGraftingConfig


PreconditionerConfigType = TypeVar(
"PreconditionerConfigType", bound=Type[PreconditionerConfig]
)


class AbstractPreconditionerConfigTest:
class PreconditionerConfigTest(
ABC,
unittest.TestCase,
Generic[PreconditionerConfigType],
):
def test_illegal_num_tolerated_failed_amortized_computations(self) -> None:
num_tolerated_failed_amortized_computations = -1
with (
self.assertRaisesRegex(
ValueError,
re.escape(
f"Invalid num_tolerated_failed_amortized_computations value: "
f"{num_tolerated_failed_amortized_computations}. Must be >= 0."
),
),
):
self._get_preconditioner_config_type()(
num_tolerated_failed_amortized_computations=num_tolerated_failed_amortized_computations,
)

@abstractmethod
def _get_preconditioner_config_type(
self,
) -> PreconditionerConfigType: ...


class ShampooPreconditionerConfigTest(
AbstractPreconditionerConfigTest.PreconditionerConfigTest[
Type[ShampooPreconditionerConfig]
]
):
def _get_preconditioner_config_type(
self,
) -> Type[ShampooPreconditionerConfig]:
return ShampooPreconditionerConfig


class EigenvalueCorrectedShampooPreconditionerConfigTest(
AbstractPreconditionerConfigTest.PreconditionerConfigTest[
Type[EigenvalueCorrectedShampooPreconditionerConfig]
]
):
def _get_preconditioner_config_type(
self,
) -> Type[EigenvalueCorrectedShampooPreconditionerConfig]:
return EigenvalueCorrectedShampooPreconditionerConfig
Comment on lines +110 to +129
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @runame ,

It seems these two tests do not discovered by the test discovery; if you check the CI before and after this pull request, the total number of tests ran does not change, and I believe this is due to those two classes here are still considered as abstract class so they don't get instantiated at all.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the two tests are actually discovered and run by the CI. Note that the "before" CI does not actually reflect the state before the two tests are added, since it is merely the PR that was created before it (#63 before #64). See the commit history for the actual order in which the PRs were merged. Now we can compare the CI before and after the two tests were added and see that the number of tests increased from 23 to 25.

Finally, I also verified locally that commenting out the two tests results in two less tests being run.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I check the commit history, and you are right the two tests are discovered in here.

Now it seems the issue is happened in the the Meta internal test discovery could not find those two tests so I was wondering the current setup is the culprit. If I run the shampoo_types_test.py in Meta internal, it will only discover 5 test cases, but it should be 7 test cases.

Screenshot 2024-12-24 at 8 04 25 AM

We might need a different setup to accommodate this unfortunately.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh interesting. Do you have any idea why the tests are not discovered internally?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generic is the reason because if we don't use that, it will be discovered. However, how and why are something I don't know, I will create a minimum example with Generic to verify this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#72 is a better design to resolve this but we might want to figure this out in the future for curiosity sake.

97 changes: 90 additions & 7 deletions distributed_shampoo/utils/shampoo_preconditioner_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,45 @@ def _check_factor_matrix_for_diagonality_nan_and_inf(
f"{factor_matrix.isinf().any()=}, {factor_matrix.isnan().any()=}."
)

def _raise_exception_if_failure_tolerance_exceeded(
self,
success_tracker: list[bool],
preconditioner_index: int,
exception: Exception,
) -> None:
"""Raises an exception if the number of failed amortized computations exceeds the tolerance.

Resets the counter at the given index when all amortized computations are successful.

Args:
success_tracker (list[bool]): A list of booleans indicating whether the amortized computation was successful.
preconditioner_index (int): The index of the preconditioner.
exception (Exception): The exception to raise.

Raises:
exception (Exception): The exception to raise.

"""
if all(success_tracker):
# Reset counter for failed amortized computations.
self._masked_failed_amortized_computation_counter_list[
preconditioner_index
] = 0
else:
# Increment counter for failed amortized computations.
self._masked_failed_amortized_computation_counter_list[
preconditioner_index
] += 1
# Raise the exception if the tolerance at the given index is exceeded.
failure_counter = self._masked_failed_amortized_computation_counter_list[
preconditioner_index
]
tolerance = (
self._preconditioner_config.num_tolerated_failed_amortized_computations
)
if failure_counter > tolerance:
raise exception

def update_preconditioners(
self,
masked_grad_list: tuple[Tensor, ...],
Expand Down Expand Up @@ -678,10 +717,16 @@ def _initialize_state_lists(
self._inv_root_override,
self._local_order_list,
)
self._local_failed_amortized_computation_counter_list: list[int] = [0] * len(
self._local_kronecker_factors_list
)

# Masked lists are the list of active preconditioners or values after filtering out gradients with None.
self._masked_order_list: tuple[int, ...] = self._local_order_list
self._masked_root_list: tuple[int, ...] = self._local_root_list
self._masked_failed_amortized_computation_counter_list: list[int] = (
self._local_failed_amortized_computation_counter_list
)
self._masked_kronecker_factors_list: tuple[
ShampooKroneckerFactorsListType,
...,
Expand Down Expand Up @@ -714,6 +759,14 @@ def compress_preconditioner_list(
self._masked_root_list: tuple[int, ...] = compress_list( # type: ignore[no-redef]
self._local_root_list, local_grad_selector
)
self._masked_failed_amortized_computation_counter_list: list[int] = ( # type: ignore[no-redef]
list(
compress_list(
self._local_failed_amortized_computation_counter_list,
local_grad_selector,
)
)
)
self._masked_kronecker_factors_list: tuple[ # type: ignore[no-redef]
ShampooKroneckerFactorsListType,
...,
Expand Down Expand Up @@ -850,11 +903,14 @@ def _amortized_computation(self) -> None:
with profiler.record_function(
f"## {self.__class__.__name__}:{self._amortized_computation.__name__} ##"
):
for kronecker_factors, root in zip(
self._masked_kronecker_factors_list,
self._masked_root_list,
strict=True,
for idx, (kronecker_factors, root) in enumerate(
zip(
self._masked_kronecker_factors_list,
self._masked_root_list,
strict=True,
)
):
success_tracker: list[bool] = []
for (
factor_matrix,
inv_factor_matrix,
Expand Down Expand Up @@ -898,8 +954,11 @@ def _amortized_computation(self) -> None:
epsilon=self._epsilon,
is_diagonal=bool(is_factor_matrix_diagonal),
).to(dtype=inv_factor_matrix.dtype)
# Add success to success tracker.
success_tracker.append(True)
except Exception as exception:
# Reuse previous matrix if matrix inverse root computation fails.
# Add failure to success tracker.
success_tracker.append(False)
logger.warning(
f"Matrix computation failed for factor matrix {factor_matrix_index} "
f"with {exception=}. Using previous inverted factor matrix and continuing..."
Expand All @@ -919,6 +978,15 @@ def _amortized_computation(self) -> None:
)
inv_factor_matrix.copy_(computed_inv_factor_matrix)

# Only reuse previous inverse roots if tolerance is not exceeded.
self._raise_exception_if_failure_tolerance_exceeded(
success_tracker=success_tracker,
preconditioner_index=idx,
exception=ValueError(
f"Exceeded tolerance for number of failed inverse root computations for {kronecker_factors.factor_matrix_indices}."
),
)


class EigenvalueCorrectedShampooPreconditionerList(
BaseShampooPreconditionerList[EigenvalueCorrectedShampooKroneckerFactorsList]
Expand Down Expand Up @@ -1098,7 +1166,10 @@ def _amortized_computation(self) -> None:
with profiler.record_function(
f"## {self.__class__.__name__}:{self._amortized_computation.__name__} ##"
):
for kronecker_factors in self._masked_kronecker_factors_list:
for idx, kronecker_factors in enumerate(
self._masked_kronecker_factors_list
):
success_tracker: list[bool] = []
for (
factor_matrix,
factor_matrix_eigenvectors,
Expand Down Expand Up @@ -1129,8 +1200,11 @@ def _amortized_computation(self) -> None:
eigenvector_computation_config=eigenvector_computation_config,
is_diagonal=bool(is_factor_matrix_diagonal),
)
# Add success to success tracker.
success_tracker.append(True)
except Exception as exception:
# Reuse previous matrix if matrix eigenvector computation fails.
# Add failure to success tracker.
success_tracker.append(False)
logger.warning(
f"Matrix computation failed for factor matrix {factor_matrix_index} "
f"with {exception=}. Using previous factor matrix eigenvectors and continuing..."
Expand All @@ -1149,3 +1223,12 @@ def _amortized_computation(self) -> None:
f"To mitigate, check factor matrix before the matrix computation: {factor_matrix=}"
)
factor_matrix_eigenvectors.copy_(computed_eigenvectors)

# Only reuse previous eigenvectors if tolerance is not exceeded.
self._raise_exception_if_failure_tolerance_exceeded(
success_tracker=success_tracker,
preconditioner_index=idx,
exception=ValueError(
f"Exceeded tolerance for number of failed eigenvector computations for {kronecker_factors.factor_matrix_indices}."
),
)
Loading
Loading