Skip to content

Commit

Permalink
Leverage __subclasses__() to improve test configs
Browse files Browse the repository at this point in the history
Summary:
Current tests on the configs with subclasses relied on explicit instantiating subclasses to test it. There are some limitations on this approach:
1. It is hard to catch newly added subclasses.
2. Due to some unknown interactions with `typing.Generic`, [the current `buck` test discovery mechanism is not able to discover the tests in it](#64 (comment)).

To resolve this, this diff refactors those tests with [`type.__subclasses()`](https://docs.python.org/3/reference/datamodel.html#type.__subclasses__) to tests the configs with subclasses.

Differential Revision: D67652761
  • Loading branch information
tsunghsienlee authored and facebook-github-bot committed Dec 26, 2024
1 parent 3b8a3a6 commit c758179
Showing 1 changed file with 23 additions and 82 deletions.
105 changes: 23 additions & 82 deletions distributed_shampoo/tests/shampoo_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,88 +7,56 @@
"""

import itertools
import re
import unittest
from abc import ABC, abstractmethod
from typing import Generic, Type, TypeVar

from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdamGraftingConfig,
EigenvalueCorrectedShampooPreconditionerConfig,
PreconditionerConfig,
RMSpropGraftingConfig,
ShampooPreconditionerConfig,
)


class AdaGradGraftingConfigTest(unittest.TestCase):
class AdaGradGraftingConfigSubclassesTest(unittest.TestCase):
def test_illegal_epsilon(self) -> None:
epsilon = 0.0
grafting_config_type = self._get_grafting_config_type()
with (
self.subTest(grafting_config_type=grafting_config_type),
self.assertRaisesRegex(
ValueError,
re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."),
),
):
grafting_config_type(epsilon=epsilon)

def _get_grafting_config_type(
self,
) -> (
Type[AdaGradGraftingConfig]
| Type[RMSpropGraftingConfig]
| Type[AdamGraftingConfig]
):
return AdaGradGraftingConfig
for cls in AdaGradGraftingConfig.__subclasses__():
with (
self.subTest(cls=cls),
self.assertRaisesRegex(
ValueError,
re.escape(f"Invalid epsilon value: {epsilon}. Must be > 0.0."),
),
):
cls(epsilon=epsilon)


class RMSpropGraftingConfigTest(AdaGradGraftingConfigTest):
class RMSpropGraftingConfigSubclassesTest(AdaGradGraftingConfigSubclassesTest):
def test_illegal_beta2(
self,
) -> None:
grafting_config_type = self._get_grafting_config_type()
for beta2 in (-1.0, 0.0, 1.3):
for cls, beta2 in itertools.product(
RMSpropGraftingConfig.__subclasses__(), (-1.0, 0.0, 1.3)
):
with (
self.subTest(grafting_config_type=grafting_config_type, beta2=beta2),
self.subTest(cls=cls, beta2=beta2),
self.assertRaisesRegex(
ValueError,
re.escape(
f"Invalid grafting beta2 parameter: {beta2}. Must be in (0.0, 1.0]."
),
),
):
grafting_config_type(beta2=beta2)

def _get_grafting_config_type(
self,
) -> Type[RMSpropGraftingConfig] | Type[AdamGraftingConfig]:
return RMSpropGraftingConfig


class AdamGraftingConfigTest(RMSpropGraftingConfigTest):
def _get_grafting_config_type(
self,
) -> Type[RMSpropGraftingConfig] | Type[AdamGraftingConfig]:
return AdamGraftingConfig


PreconditionerConfigType = TypeVar(
"PreconditionerConfigType", bound=PreconditionerConfig
)
cls(beta2=beta2)


class AbstractPreconditionerConfigTest:
class PreconditionerConfigTest(
ABC,
unittest.TestCase,
Generic[PreconditionerConfigType],
):
def test_illegal_num_tolerated_failed_amortized_computations(self) -> None:
num_tolerated_failed_amortized_computations = -1
class PreconditionerConfigSubclassesTest(unittest.TestCase):
def test_illegal_num_tolerated_failed_amortized_computations(self) -> None:
num_tolerated_failed_amortized_computations = -1
for cls in PreconditionerConfig.__subclasses__():
with (
self.subTest(cls=cls),
self.assertRaisesRegex(
ValueError,
re.escape(
Expand All @@ -97,33 +65,6 @@ def test_illegal_num_tolerated_failed_amortized_computations(self) -> None:
),
),
):
self._get_preconditioner_config_type()(
cls( # type: ignore
num_tolerated_failed_amortized_computations=num_tolerated_failed_amortized_computations,
)

@abstractmethod
def _get_preconditioner_config_type(
self,
) -> Type[PreconditionerConfigType]: ...


class ShampooPreconditionerConfigTest(
AbstractPreconditionerConfigTest.PreconditionerConfigTest[
ShampooPreconditionerConfig
]
):
def _get_preconditioner_config_type(
self,
) -> Type[ShampooPreconditionerConfig]:
return ShampooPreconditionerConfig


class EigenvalueCorrectedShampooPreconditionerConfigTest(
AbstractPreconditionerConfigTest.PreconditionerConfigTest[
EigenvalueCorrectedShampooPreconditionerConfig
]
):
def _get_preconditioner_config_type(
self,
) -> Type[EigenvalueCorrectedShampooPreconditionerConfig]:
return EigenvalueCorrectedShampooPreconditionerConfig

0 comments on commit c758179

Please sign in to comment.