Skip to content

Commit

Permalink
Slight refactor of TypeVar definition in shampoo_types_test.py (#71)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #71

This is slightly succinct by saving one `Type` usage and make the`bound` check more semantically correct (ie., checking the use `PreconditionerConfigType` is the subtype of `PreconditionerConfig`).

Reviewed By: chuanhaozhuge

Differential Revision: D67624903

fbshipit-source-id: 880c25bab712c9bf05e89967b1dc6335a9611de0
  • Loading branch information
tsunghsienlee authored and facebook-github-bot committed Dec 25, 2024
1 parent 7b14418 commit 3b8a3a6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions distributed_shampoo/tests/shampoo_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _get_grafting_config_type(


PreconditionerConfigType = TypeVar(
"PreconditionerConfigType", bound=Type[PreconditionerConfig]
"PreconditionerConfigType", bound=PreconditionerConfig
)


Expand Down Expand Up @@ -104,12 +104,12 @@ def test_illegal_num_tolerated_failed_amortized_computations(self) -> None:
@abstractmethod
def _get_preconditioner_config_type(
self,
) -> PreconditionerConfigType: ...
) -> Type[PreconditionerConfigType]: ...


class ShampooPreconditionerConfigTest(
AbstractPreconditionerConfigTest.PreconditionerConfigTest[
Type[ShampooPreconditionerConfig]
ShampooPreconditionerConfig
]
):
def _get_preconditioner_config_type(
Expand All @@ -120,7 +120,7 @@ def _get_preconditioner_config_type(

class EigenvalueCorrectedShampooPreconditionerConfigTest(
AbstractPreconditionerConfigTest.PreconditionerConfigTest[
Type[EigenvalueCorrectedShampooPreconditionerConfig]
EigenvalueCorrectedShampooPreconditionerConfig
]
):
def _get_preconditioner_config_type(
Expand Down

0 comments on commit 3b8a3a6

Please sign in to comment.