Skip to content

Commit

Permalink
Refactor AdamGraftingConfig to inherit from RMSPropGraftingConfig
Browse files Browse the repository at this point in the history
Summary:
This diff refactors `AdamGraftingConfig` to inherit from `RMSPropGraftingConfig`, taking advantage of the fact that `Adam` is essentially `RMSProp` with momentum. This inheritance was previously blocked by an `isinstance` check, but is now enabled by the switch to `type` checking.

Additionally, this diff fixes an oversight in [the previous type.__subclasses__() refactor](#72), ensuring that parent classes are properly included.

Differential Revision: D67719358
  • Loading branch information
tsunghsienlee authored and facebook-github-bot committed Dec 30, 2024
1 parent 8159f78 commit 47b9019
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 10 deletions.
9 changes: 1 addition & 8 deletions distributed_shampoo/shampoo_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def __post_init__(self) -> None:


@dataclass(kw_only=True)
class AdamGraftingConfig(AdaGradGraftingConfig):
class AdamGraftingConfig(RMSpropGraftingConfig):
"""Configuration for grafting from Adam.
Args:
Expand All @@ -337,10 +337,3 @@ class AdamGraftingConfig(AdaGradGraftingConfig):
"""

beta2: float = 0.999

def __post_init__(self) -> None:
super().__post_init__()
if not 0.0 < self.beta2 <= 1.0:
raise ValueError(
f"Invalid grafting beta2 parameter: {self.beta2}. Must be in (0.0, 1.0]."
)
5 changes: 3 additions & 2 deletions distributed_shampoo/tests/shampoo_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class AdaGradGraftingConfigSubclassesTest(unittest.TestCase):
def test_illegal_epsilon(self) -> None:
epsilon = 0.0
for cls in AdaGradGraftingConfig.__subclasses__():
for cls in [AdaGradGraftingConfig] + AdaGradGraftingConfig.__subclasses__():
with self.subTest(cls=cls):
self.assertRaisesRegex(
ValueError,
Expand All @@ -36,7 +36,8 @@ def test_illegal_beta2(
self,
) -> None:
for cls, beta2 in itertools.product(
RMSpropGraftingConfig.__subclasses__(), (-1.0, 0.0, 1.3)
[RMSpropGraftingConfig] + RMSpropGraftingConfig.__subclasses__(),
(-1.0, 0.0, 1.3),
):
with self.subTest(cls=cls, beta2=beta2):
self.assertRaisesRegex(
Expand Down

0 comments on commit 47b9019

Please sign in to comment.