Skip to content

Commit

Permalink
Slight refactor of TypeVar definition in shampoo_types_test.py
Browse files Browse the repository at this point in the history
Summary: This is slightly succinct by saving one `Type` usage and make the`bound` check more semantically correct (ie., checking the use `PreconditionerConfigType` is the subtype of `PreconditionerConfig`).

Differential Revision: D67624903
  • Loading branch information
tsunghsienlee authored and facebook-github-bot committed Dec 24, 2024
1 parent 7b14418 commit 81678e5
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions distributed_shampoo/tests/shampoo_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _get_grafting_config_type(


PreconditionerConfigType = TypeVar(
"PreconditionerConfigType", bound=Type[PreconditionerConfig]
"PreconditionerConfigType", bound=PreconditionerConfig
)


Expand Down Expand Up @@ -104,12 +104,12 @@ def test_illegal_num_tolerated_failed_amortized_computations(self) -> None:
@abstractmethod
def _get_preconditioner_config_type(
self,
) -> PreconditionerConfigType: ...
) -> Type[PreconditionerConfigType]: ...


class ShampooPreconditionerConfigTest(
AbstractPreconditionerConfigTest.PreconditionerConfigTest[
Type[ShampooPreconditionerConfig]
ShampooPreconditionerConfig
]
):
def _get_preconditioner_config_type(
Expand All @@ -120,7 +120,7 @@ def _get_preconditioner_config_type(

class EigenvalueCorrectedShampooPreconditionerConfigTest(
AbstractPreconditionerConfigTest.PreconditionerConfigTest[
Type[EigenvalueCorrectedShampooPreconditionerConfig]
EigenvalueCorrectedShampooPreconditionerConfig
]
):
def _get_preconditioner_config_type(
Expand Down

0 comments on commit 81678e5

Please sign in to comment.