diff --git a/distributed_shampoo/tests/shampoo_types_test.py b/distributed_shampoo/tests/shampoo_types_test.py index dd9a59d..5b2e588 100644 --- a/distributed_shampoo/tests/shampoo_types_test.py +++ b/distributed_shampoo/tests/shampoo_types_test.py @@ -20,13 +20,6 @@ RMSpropGraftingConfig, ShampooPreconditionerConfig, ) -from matrix_functions_types import ( - DefaultEigenConfig, - DefaultEighConfig, - EigenvectorConfig, - MatrixFunctionConfig, - RootInvConfig, -) class AdaGradGraftingConfigTest(unittest.TestCase): @@ -85,16 +78,13 @@ def _get_grafting_config_type( PreconditionerConfigType = TypeVar( "PreconditionerConfigType", bound=Type[PreconditionerConfig] ) -AmortizedComputationConfigType = TypeVar( - "AmortizedComputationConfigType", bound=MatrixFunctionConfig -) class AbstractPreconditionerConfigTest: class PreconditionerConfigTest( ABC, unittest.TestCase, - Generic[PreconditionerConfigType, AmortizedComputationConfigType], + Generic[PreconditionerConfigType], ): def test_illegal_num_tolerated_failed_amortized_computations(self) -> None: num_tolerated_failed_amortized_computations = -1 @@ -108,7 +98,6 @@ def test_illegal_num_tolerated_failed_amortized_computations(self) -> None: ), ): self._get_preconditioner_config_type()( - amortized_computation_config=self._get_amortized_computation_config(), num_tolerated_failed_amortized_computations=num_tolerated_failed_amortized_computations, ) @@ -117,20 +106,12 @@ def _get_preconditioner_config_type( self, ) -> PreconditionerConfigType: ... - @abstractmethod - def _get_amortized_computation_config( - self, - ) -> AmortizedComputationConfigType: ... - class ShampooPreconditionerConfigTest( AbstractPreconditionerConfigTest.PreconditionerConfigTest[ - Type[ShampooPreconditionerConfig], RootInvConfig + Type[ShampooPreconditionerConfig] ] ): - def _get_amortized_computation_config(self) -> RootInvConfig: - return DefaultEigenConfig - def _get_preconditioner_config_type( self, ) -> Type[ShampooPreconditionerConfig]: @@ -139,12 +120,9 @@ def _get_preconditioner_config_type( class EigenvalueCorrectedShampooPreconditionerConfigTest( AbstractPreconditionerConfigTest.PreconditionerConfigTest[ - Type[EigenvalueCorrectedShampooPreconditionerConfig], EigenvectorConfig + Type[EigenvalueCorrectedShampooPreconditionerConfig] ] ): - def _get_amortized_computation_config(self) -> EigenvectorConfig: - return DefaultEighConfig - def _get_preconditioner_config_type( self, ) -> Type[EigenvalueCorrectedShampooPreconditionerConfig]: