diff --git a/distributed_shampoo/tests/shampoo_types_test.py b/distributed_shampoo/tests/shampoo_types_test.py index 4e2e3cb..2fbebbb 100644 --- a/distributed_shampoo/tests/shampoo_types_test.py +++ b/distributed_shampoo/tests/shampoo_types_test.py @@ -24,21 +24,37 @@ SubclassesType = TypeVar("SubclassesType") -def get_all_subclasses(cls: SubclassesType) -> list[SubclassesType]: +def get_all_subclasses( + cls: SubclassesType, include_cls_self: bool = True +) -> list[SubclassesType]: + """ + Retrieves all subclasses of a given class, optionally including the class itself. + + This function uses a helper function to recursively find all unique subclasses + of the specified class. + + Args: + cls (SubclassesType): The class for which to find subclasses. + include_cls_self (bool): Whether to include the class itself in the result. (Default: True) + + Returns: + list[SubclassesType]: A list of all unique subclasses of the given class. + """ + def get_all_unique_subclasses(cls: SubclassesType) -> set[SubclassesType]: """Gets all unique subclasses of a given class recursively.""" assert ( subclasses := getattr(cls, "__subclasses__", lambda: None)() ) is not None, f"{cls} does not have __subclasses__." - return reduce(or_, map(get_all_unique_subclasses, subclasses), set()) + return reduce(or_, map(get_all_unique_subclasses, subclasses), {cls}) - return list(get_all_unique_subclasses(cls)) + return list(get_all_unique_subclasses(cls) - (set() if include_cls_self else {cls})) class AdaGradGraftingConfigSubclassesTest(unittest.TestCase): def test_illegal_epsilon(self) -> None: epsilon = 0.0 - for cls in [AdaGradGraftingConfig] + get_all_subclasses(AdaGradGraftingConfig): + for cls in get_all_subclasses(AdaGradGraftingConfig): with self.subTest(cls=cls): self.assertRaisesRegex( ValueError, @@ -53,7 +69,7 @@ def test_illegal_beta2( self, ) -> None: for cls, beta2 in itertools.product( - [RMSpropGraftingConfig] + get_all_subclasses(RMSpropGraftingConfig), + get_all_subclasses(RMSpropGraftingConfig), (-1.0, 0.0, 1.3), ): with self.subTest(cls=cls, beta2=beta2): @@ -70,7 +86,8 @@ def test_illegal_beta2( class PreconditionerConfigSubclassesTest(unittest.TestCase): def test_illegal_num_tolerated_failed_amortized_computations(self) -> None: num_tolerated_failed_amortized_computations = -1 - for cls in get_all_subclasses(PreconditionerConfig): + # Not testing for the base class PreconditionerConfig because it is an abstract class. + for cls in get_all_subclasses(PreconditionerConfig, include_cls_self=False): with self.subTest(cls=cls): self.assertRaisesRegex( ValueError,