From 3b8a3a608401dae77e8063fd6ac2a1ef2d30591f Mon Sep 17 00:00:00 2001 From: Tsung-Hsien Lee Date: Wed, 25 Dec 2024 08:20:20 -0800 Subject: [PATCH] Slight refactor of TypeVar definition in `shampoo_types_test.py` (#71) Summary: Pull Request resolved: https://github.com/facebookresearch/optimizers/pull/71 This is slightly succinct by saving one `Type` usage and make the`bound` check more semantically correct (ie., checking the use `PreconditionerConfigType` is the subtype of `PreconditionerConfig`). Reviewed By: chuanhaozhuge Differential Revision: D67624903 fbshipit-source-id: 880c25bab712c9bf05e89967b1dc6335a9611de0 --- distributed_shampoo/tests/shampoo_types_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed_shampoo/tests/shampoo_types_test.py b/distributed_shampoo/tests/shampoo_types_test.py index 5b2e588..62258b5 100644 --- a/distributed_shampoo/tests/shampoo_types_test.py +++ b/distributed_shampoo/tests/shampoo_types_test.py @@ -76,7 +76,7 @@ def _get_grafting_config_type( PreconditionerConfigType = TypeVar( - "PreconditionerConfigType", bound=Type[PreconditionerConfig] + "PreconditionerConfigType", bound=PreconditionerConfig ) @@ -104,12 +104,12 @@ def test_illegal_num_tolerated_failed_amortized_computations(self) -> None: @abstractmethod def _get_preconditioner_config_type( self, - ) -> PreconditionerConfigType: ... + ) -> Type[PreconditionerConfigType]: ... class ShampooPreconditionerConfigTest( AbstractPreconditionerConfigTest.PreconditionerConfigTest[ - Type[ShampooPreconditionerConfig] + ShampooPreconditionerConfig ] ): def _get_preconditioner_config_type( @@ -120,7 +120,7 @@ def _get_preconditioner_config_type( class EigenvalueCorrectedShampooPreconditionerConfigTest( AbstractPreconditionerConfigTest.PreconditionerConfigTest[ - Type[EigenvalueCorrectedShampooPreconditionerConfig] + EigenvalueCorrectedShampooPreconditionerConfig ] ): def _get_preconditioner_config_type(