Skip to content

Commit

Permalink
Refactor AdamGraftingConfig to inherit from RMSPropGraftingConfig (#74)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #74

This diff refactors `AdamGraftingConfig` to inherit from `RMSPropGraftingConfig`, taking advantage of the fact that `Adam` is essentially `RMSProp` with momentum. This inheritance was previously blocked by an `isinstance` check, but is now enabled by the switch to `type` checking.

Additionally, this diff fixes an oversight in [the previous `type.__subclasses__()` refactor](#72), ensuring that parent classes are properly included.

Reviewed By: anana10c

Differential Revision: D67719358

fbshipit-source-id: 24a0264bd3d314f2b5b370af3b075dde4773fb4f
  • Loading branch information
tsunghsienlee authored and facebook-github-bot committed Jan 1, 2025
1 parent 8159f78 commit fe7cd45
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 11 deletions.
9 changes: 1 addition & 8 deletions distributed_shampoo/shampoo_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def __post_init__(self) -> None:


@dataclass(kw_only=True)
class AdamGraftingConfig(AdaGradGraftingConfig):
class AdamGraftingConfig(RMSpropGraftingConfig):
"""Configuration for grafting from Adam.
Args:
Expand All @@ -337,10 +337,3 @@ class AdamGraftingConfig(AdaGradGraftingConfig):
"""

beta2: float = 0.999

def __post_init__(self) -> None:
super().__post_init__()
if not 0.0 < self.beta2 <= 1.0:
raise ValueError(
f"Invalid grafting beta2 parameter: {self.beta2}. Must be in (0.0, 1.0]."
)
24 changes: 21 additions & 3 deletions distributed_shampoo/tests/shampoo_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import itertools
import re
import unittest
from functools import reduce
from operator import or_
from typing import TypeVar

from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
Expand All @@ -18,10 +21,24 @@
)


SubclassesType = TypeVar("SubclassesType")


def get_all_subclasses(cls: SubclassesType) -> list[SubclassesType]:
def get_all_unique_subclasses(cls: SubclassesType) -> set[SubclassesType]:
"""Gets all unique subclasses of a given class recursively."""
assert (
subclasses := getattr(cls, "__subclasses__", lambda: None)()
) is not None, f"{cls} does not have __subclasses__."
return reduce(or_, map(get_all_unique_subclasses, subclasses), set())

return list(get_all_unique_subclasses(cls))


class AdaGradGraftingConfigSubclassesTest(unittest.TestCase):
def test_illegal_epsilon(self) -> None:
epsilon = 0.0
for cls in AdaGradGraftingConfig.__subclasses__():
for cls in [AdaGradGraftingConfig] + get_all_subclasses(AdaGradGraftingConfig):
with self.subTest(cls=cls):
self.assertRaisesRegex(
ValueError,
Expand All @@ -36,7 +53,8 @@ def test_illegal_beta2(
self,
) -> None:
for cls, beta2 in itertools.product(
RMSpropGraftingConfig.__subclasses__(), (-1.0, 0.0, 1.3)
[RMSpropGraftingConfig] + get_all_subclasses(RMSpropGraftingConfig),
(-1.0, 0.0, 1.3),
):
with self.subTest(cls=cls, beta2=beta2):
self.assertRaisesRegex(
Expand All @@ -52,7 +70,7 @@ def test_illegal_beta2(
class PreconditionerConfigSubclassesTest(unittest.TestCase):
def test_illegal_num_tolerated_failed_amortized_computations(self) -> None:
num_tolerated_failed_amortized_computations = -1
for cls in PreconditionerConfig.__subclasses__():
for cls in get_all_subclasses(PreconditionerConfig):
with self.subTest(cls=cls):
self.assertRaisesRegex(
ValueError,
Expand Down

0 comments on commit fe7cd45

Please sign in to comment.