From fe7cd458fd39cf3980bb702ef1444a7a1dbccb60 Mon Sep 17 00:00:00 2001 From: Tsung-Hsien Lee Date: Wed, 1 Jan 2025 11:59:06 -0800 Subject: [PATCH] Refactor AdamGraftingConfig to inherit from RMSPropGraftingConfig (#74) Summary: Pull Request resolved: https://github.com/facebookresearch/optimizers/pull/74 This diff refactors `AdamGraftingConfig` to inherit from `RMSPropGraftingConfig`, taking advantage of the fact that `Adam` is essentially `RMSProp` with momentum. This inheritance was previously blocked by an `isinstance` check, but is now enabled by the switch to `type` checking. Additionally, this diff fixes an oversight in [the previous `type.__subclasses__()` refactor](https://github.com/facebookresearch/optimizers/pull/72), ensuring that parent classes are properly included. Reviewed By: anana10c Differential Revision: D67719358 fbshipit-source-id: 24a0264bd3d314f2b5b370af3b075dde4773fb4f --- distributed_shampoo/shampoo_types.py | 9 +------ .../tests/shampoo_types_test.py | 24 ++++++++++++++++--- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 2e89352..d95abed 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -326,7 +326,7 @@ def __post_init__(self) -> None: @dataclass(kw_only=True) -class AdamGraftingConfig(AdaGradGraftingConfig): +class AdamGraftingConfig(RMSpropGraftingConfig): """Configuration for grafting from Adam. Args: @@ -337,10 +337,3 @@ class AdamGraftingConfig(AdaGradGraftingConfig): """ beta2: float = 0.999 - - def __post_init__(self) -> None: - super().__post_init__() - if not 0.0 < self.beta2 <= 1.0: - raise ValueError( - f"Invalid grafting beta2 parameter: {self.beta2}. Must be in (0.0, 1.0]." - ) diff --git a/distributed_shampoo/tests/shampoo_types_test.py b/distributed_shampoo/tests/shampoo_types_test.py index 19d32d1..4e2e3cb 100644 --- a/distributed_shampoo/tests/shampoo_types_test.py +++ b/distributed_shampoo/tests/shampoo_types_test.py @@ -10,6 +10,9 @@ import itertools import re import unittest +from functools import reduce +from operator import or_ +from typing import TypeVar from distributed_shampoo.shampoo_types import ( AdaGradGraftingConfig, @@ -18,10 +21,24 @@ ) +SubclassesType = TypeVar("SubclassesType") + + +def get_all_subclasses(cls: SubclassesType) -> list[SubclassesType]: + def get_all_unique_subclasses(cls: SubclassesType) -> set[SubclassesType]: + """Gets all unique subclasses of a given class recursively.""" + assert ( + subclasses := getattr(cls, "__subclasses__", lambda: None)() + ) is not None, f"{cls} does not have __subclasses__." + return reduce(or_, map(get_all_unique_subclasses, subclasses), set()) + + return list(get_all_unique_subclasses(cls)) + + class AdaGradGraftingConfigSubclassesTest(unittest.TestCase): def test_illegal_epsilon(self) -> None: epsilon = 0.0 - for cls in AdaGradGraftingConfig.__subclasses__(): + for cls in [AdaGradGraftingConfig] + get_all_subclasses(AdaGradGraftingConfig): with self.subTest(cls=cls): self.assertRaisesRegex( ValueError, @@ -36,7 +53,8 @@ def test_illegal_beta2( self, ) -> None: for cls, beta2 in itertools.product( - RMSpropGraftingConfig.__subclasses__(), (-1.0, 0.0, 1.3) + [RMSpropGraftingConfig] + get_all_subclasses(RMSpropGraftingConfig), + (-1.0, 0.0, 1.3), ): with self.subTest(cls=cls, beta2=beta2): self.assertRaisesRegex( @@ -52,7 +70,7 @@ def test_illegal_beta2( class PreconditionerConfigSubclassesTest(unittest.TestCase): def test_illegal_num_tolerated_failed_amortized_computations(self) -> None: num_tolerated_failed_amortized_computations = -1 - for cls in PreconditionerConfig.__subclasses__(): + for cls in get_all_subclasses(PreconditionerConfig): with self.subTest(cls=cls): self.assertRaisesRegex( ValueError,