Skip to content

Commit

Permalink
Simplify test
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Dec 12, 2024
1 parent a92348d commit 736c76d
Showing 1 changed file with 3 additions and 25 deletions.
28 changes: 3 additions & 25 deletions distributed_shampoo/tests/shampoo_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,6 @@
RMSpropGraftingConfig,
ShampooPreconditionerConfig,
)
from matrix_functions_types import (
DefaultEigenConfig,
DefaultEighConfig,
EigenvectorConfig,
MatrixFunctionConfig,
RootInvConfig,
)


class AdaGradGraftingConfigTest(unittest.TestCase):
Expand Down Expand Up @@ -85,16 +78,13 @@ def _get_grafting_config_type(
PreconditionerConfigType = TypeVar(
"PreconditionerConfigType", bound=Type[PreconditionerConfig]
)
AmortizedComputationConfigType = TypeVar(
"AmortizedComputationConfigType", bound=MatrixFunctionConfig
)


class AbstractPreconditionerConfigTest:
class PreconditionerConfigTest(
ABC,
unittest.TestCase,
Generic[PreconditionerConfigType, AmortizedComputationConfigType],
Generic[PreconditionerConfigType],
):
def test_illegal_num_tolerated_failed_amortized_computations(self) -> None:
num_tolerated_failed_amortized_computations = -1
Expand All @@ -108,7 +98,6 @@ def test_illegal_num_tolerated_failed_amortized_computations(self) -> None:
),
):
self._get_preconditioner_config_type()(
amortized_computation_config=self._get_amortized_computation_config(),
num_tolerated_failed_amortized_computations=num_tolerated_failed_amortized_computations,
)

Expand All @@ -117,20 +106,12 @@ def _get_preconditioner_config_type(
self,
) -> PreconditionerConfigType: ...

@abstractmethod
def _get_amortized_computation_config(
self,
) -> AmortizedComputationConfigType: ...


class ShampooPreconditionerConfigTest(
AbstractPreconditionerConfigTest.PreconditionerConfigTest[
Type[ShampooPreconditionerConfig], RootInvConfig
Type[ShampooPreconditionerConfig]
]
):
def _get_amortized_computation_config(self) -> RootInvConfig:
return DefaultEigenConfig

def _get_preconditioner_config_type(
self,
) -> Type[ShampooPreconditionerConfig]:
Expand All @@ -139,12 +120,9 @@ def _get_preconditioner_config_type(

class EigenvalueCorrectedShampooPreconditionerConfigTest(
AbstractPreconditionerConfigTest.PreconditionerConfigTest[
Type[EigenvalueCorrectedShampooPreconditionerConfig], EigenvectorConfig
Type[EigenvalueCorrectedShampooPreconditionerConfig]
]
):
def _get_amortized_computation_config(self) -> EigenvectorConfig:
return DefaultEighConfig

def _get_preconditioner_config_type(
self,
) -> Type[EigenvalueCorrectedShampooPreconditionerConfig]:
Expand Down

0 comments on commit 736c76d

Please sign in to comment.