diff --git a/distributed_shampoo/tests/shampoo_types_test.py b/distributed_shampoo/tests/shampoo_types_test.py index 5b2e588..62258b5 100644 --- a/distributed_shampoo/tests/shampoo_types_test.py +++ b/distributed_shampoo/tests/shampoo_types_test.py @@ -76,7 +76,7 @@ def _get_grafting_config_type( PreconditionerConfigType = TypeVar( - "PreconditionerConfigType", bound=Type[PreconditionerConfig] + "PreconditionerConfigType", bound=PreconditionerConfig ) @@ -104,12 +104,12 @@ def test_illegal_num_tolerated_failed_amortized_computations(self) -> None: @abstractmethod def _get_preconditioner_config_type( self, - ) -> PreconditionerConfigType: ... + ) -> Type[PreconditionerConfigType]: ... class ShampooPreconditionerConfigTest( AbstractPreconditionerConfigTest.PreconditionerConfigTest[ - Type[ShampooPreconditionerConfig] + ShampooPreconditionerConfig ] ): def _get_preconditioner_config_type( @@ -120,7 +120,7 @@ def _get_preconditioner_config_type( class EigenvalueCorrectedShampooPreconditionerConfigTest( AbstractPreconditionerConfigTest.PreconditionerConfigTest[ - Type[EigenvalueCorrectedShampooPreconditionerConfig] + EigenvalueCorrectedShampooPreconditionerConfig ] ): def _get_preconditioner_config_type(