diff --git a/emerging_optimizers/soap/soap.py b/emerging_optimizers/soap/soap.py
index 37e1f9a..0d1bd2a 100644
--- a/emerging_optimizers/soap/soap.py
+++ b/emerging_optimizers/soap/soap.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from itertools import chain
 from typing import Callable, Iterable, List, Optional, Tuple, Union

@@ -81,6 +82,7 @@ class SOAP(optim.Optimizer):
         power_iter_steps: Number of power iteration steps to perform before QR decomposition.
             More steps can lead to better convergence but increased computation time.
         max_update_rms: Clip the update RMS to this value (0 means no clipping).
+        use_kl_shampoo: Whether to use the KL-Shampoo correction when updating the kronecker factors.
     """

     def __init__(
@@ -107,6 +109,7 @@ def __init__(
         adaptive_update_tolerance: Optional[float] = None,
         power_iter_steps: int = 1,
         max_update_rms: float = 0.0,
+        use_kl_shampoo: bool = False,
     ) -> None:
         # Check for betas.
         if betas is None:
@@ -159,6 +162,7 @@ def __init__(
             "adaptive_update_tolerance": adaptive_update_tolerance,
             "power_iter_steps": power_iter_steps,
             "max_update_rms": max_update_rms,
+            "use_kl_shampoo": use_kl_shampoo,
         }
         super().__init__(params, defaults)

@@ -194,6 +198,21 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                     # Exponential moving average of squared gradient values
                     state["exp_avg_sq"] = torch.zeros_like(grad)

+                if "Q" not in state:
+                    state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape]
+
+                # Define kronecker_factor_update_fn based on whether to use KL-Shampoo here,
+                # because it needs access to state and group.
+                kronecker_factor_update_fn = partial(
+                    update_kronecker_factors, precondition_1d=group["precondition_1d"]
+                )
+                if group["use_kl_shampoo"]:
+                    kronecker_factor_update_fn = partial(
+                        update_kronecker_factors_kl_shampoo,
+                        eigenbasis_list=state["Q"],
+                        eps=group["eps"],
+                    )
+
                 # Initialize kronecker factor matrices
                 if "GG" not in state:
                     state["GG"] = init_kronecker_factors(
@@ -204,11 +223,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                     # Update preconditioner matrices with gradient statistics,
                     # do not use shampoo_beta for EMA at first step
                     with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
-                        update_kronecker_factors(
-                            kronecker_factor_list=state["GG"],
-                            grad=grad,
-                            shampoo_beta=0.0,
-                            precondition_1d=group["precondition_1d"],
+                        kronecker_factor_update_fn(
+                            kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=0.0
                         )

                     # Increment step counter
@@ -228,7 +244,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
                     grad_projected = precondition(
                         grad=grad,
-                        eigenbasis_list=state.get("Q"),
+                        eigenbasis_list=state["Q"],
                         dims=[[0], [0]],
                     )
                 torch.cuda.nvtx.range_pop()
@@ -255,7 +271,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
                     norm_precond_grad = precondition(
                         grad=adam_update,
-                        eigenbasis_list=state.get("Q"),
+                        eigenbasis_list=state["Q"],
                         dims=[[0], [1]],
                     )
                 torch.cuda.nvtx.range_pop()
@@ -283,11 +299,10 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

             torch.cuda.nvtx.range_push("update_kronecker_factors")
             with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
-                update_kronecker_factors(
+                kronecker_factor_update_fn(
                     kronecker_factor_list=state["GG"],
                     grad=grad,
-                    shampoo_beta=shampoo_beta,
-                    precondition_1d=group["precondition_1d"],
+                    shampoo_beta=shampoo_beta,
                 )
             torch.cuda.nvtx.range_pop()

@@ -453,6 +468,48 @@ def update_kronecker_factors(
         kronecker_factor_list[idx].lerp_(outer_product, 1 - shampoo_beta)


+@torch.no_grad()  # type: ignore[misc]
+def update_kronecker_factors_kl_shampoo(
+    kronecker_factor_list: List[torch.Tensor],
+    grad: torch.Tensor,
+    shampoo_beta: float,
+    eigenbasis_list: List[torch.Tensor],
+    eps: float,
+    eigval_exp: float = -1.0,
+) -> None:
+    """Updates the kronecker factor matrices in place using the KL-Shampoo correction.
+
+    Implements Kullback–Leibler minimization from https://arxiv.org/pdf/2509.03378.
+
+    Args:
+        kronecker_factor_list: List of preconditioner matrices (L and R) to update.
+        grad: Gradient tensor of the parameter being optimized.
+        shampoo_beta: Momentum coefficient for updating preconditioners.
+        eigenbasis_list: List of orthonormal eigenbases of the kronecker factor matrices.
+        eps: Small offset for numerical stability.
+        eigval_exp: Exponent of the eigenvalues.
+    """
+    assert grad.dim() == 2, "KL-Shampoo mathematical correction is only supported for 2D tensors"
+
+    # Scale the gradient matrix by the approximate eigenvalues and the eigenbasis:
+    # G @ Q_R @ λ_R^(−1) @ Q_R.T @ G.T / n  and  G.T @ Q_L @ λ_L^(−1) @ Q_L.T @ G / m,  where G is (m, n)
+    updates = []
+    for idx, (kronecker_factor, eigenbasis) in enumerate(zip(kronecker_factor_list, eigenbasis_list, strict=True)):
+        approx_eigvals = utils.eig.conjugate(kronecker_factor, eigenbasis, diag=True)
+        scale_factor = 1 / grad.shape[idx] * approx_eigvals.clamp_min(eps) ** eigval_exp
+
+        logging.debug(f"scale_factor[{idx}]: {scale_factor}")
+
+        correction = (eigenbasis * scale_factor[None, :]) @ eigenbasis.T
+
+        maybe_transpose_grad = grad.T if idx == 1 else grad
+        updates.append(utils.eig.conjugate(correction, maybe_transpose_grad))
+
+    # Note that updates calculated in the previous loop are in reverse order of the kronecker factors they apply to
+    for kronecker_factor, update in zip(kronecker_factor_list, updates[::-1], strict=True):
+        kronecker_factor.lerp_(update, 1 - shampoo_beta)
+
+
 @torch.no_grad()  # type: ignore[misc]
 def update_eigenbasis_and_momentum(
     kronecker_factor_list: List[torch.Tensor],
diff --git a/tests/ci/L0_Tests_GPU.sh b/tests/ci/L0_Tests_GPU.sh
index 3faf9b1..c07e1d4 100644
--- a/tests/ci/L0_Tests_GPU.sh
+++ b/tests/ci/L0_Tests_GPU.sh
@@ -18,9 +18,8 @@ export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0
 error=0
 coverage run -p --source=emerging_optimizers tests/test_muon_utils.py || error=1
 coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py || error=1
-coverage run -p --source=emerging_optimizers tests/test_soap_functions.py || error=1
 coverage run -p --source=emerging_optimizers tests/test_soap_utils.py || error=1
-coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py || error=1
+coverage run -p --source=emerging_optimizers tests/test_soap.py || error=1
 coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py || error=1
 coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cuda || error=1
 coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py || error=1
diff --git a/tests/ci/L1_Tests_GPU.sh b/tests/ci/L1_Tests_GPU.sh
index ab73433..c07ac75 100644
--- a/tests/ci/L1_Tests_GPU.sh
+++ b/tests/ci/L1_Tests_GPU.sh
@@ -17,9 +17,8 @@ export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0
 error=0
 python tests/test_muon_utils.py || error=1
 python tests/test_orthogonalized_optimizer.py || error=1
-python tests/test_soap_functions.py || error=1
 python tests/test_soap_utils.py || error=1
-python tests/soap_smoke_test.py || error=1
+python tests/test_soap.py || error=1
 python tests/test_scalar_optimizers.py --device=cuda || error=1
 python tests/test_spectral_clipping_utils.py || error=1
 python tests/test_triton_kernels.py || error=1
diff --git a/tests/soap_smoke_test.py b/tests/soap_smoke_test.py
deleted file mode 100644
index f7cdc7e..0000000
--- a/tests/soap_smoke_test.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import random
-
-import numpy as np
-import torch
-
-from emerging_optimizers.soap.soap import SOAP
-
-
-config = {
-    "lr": 0.001,
-    "weight_decay": 0.01,
-    "adam_beta1": 0.9,
-    "adam_beta2": 0.95,
-    "eps": 1e-8,
-    "precondition_frequency": 1,
-    "shampoo_beta": 0.95,
-    "precondition_1d": False,
-    "adam_warmup_steps": 1,
-    "fp32_matmul_prec": "highest",
-    "use_adaptive_criteria": False,
-    "trace_normalization": False,
-    "power_iter_steps": 1,
-}
-
-
-def main() -> None:
-    # seed for reproducibility
-    seed = 42
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    # Define the size of the random matrix (parameter size).
-    rows = 5
-    cols = 3
-    matrix_shape = (rows, cols)
-
-    # Create a random matrix as a torch Parameter.
-    param = torch.nn.Parameter(torch.randn(matrix_shape, device="cuda"))
-    print(f"Param is on device: {param.device}")
-    # Instantiate the custom SOAP optimizer with the random parameter.
-    optimizer = SOAP(
-        [param],
-        lr=config["lr"],
-        weight_decay=config["weight_decay"],
-        betas=(config["adam_beta1"], config["adam_beta2"]),
-        eps=config["eps"],
-        precondition_frequency=config["precondition_frequency"],
-        shampoo_beta=config["shampoo_beta"],
-        precondition_1d=config["precondition_1d"],
-        adam_warmup_steps=config["adam_warmup_steps"],
-        trace_normalization=config["trace_normalization"],
-        fp32_matmul_prec=config["fp32_matmul_prec"],
-        use_adaptive_criteria=config["use_adaptive_criteria"],
-        power_iter_steps=config["power_iter_steps"],
-    )
-
-    # Number of time steps (iterations) to simulate.
-    time_steps = 11
-
-    print("Initial parameter values:")
-    print(param.data)
-    print("---------------------------")
-
-    # Simulate a time series of random gradients.
-    for t in range(time_steps):
-        # Simulate a random gradient matrix.
-        random_gradient = torch.randn(matrix_shape, device=param.device)
-
-        # In a normal training loop, backward() would populate .grad.
-        # Here, we manually assign the gradient.
-        param.grad = random_gradient
-
-        # Run the optimizer step.
-        optimizer.step()
-
-        print("After time step", t + 1)
-        print(param.data)
-        print("----------")
-
-
-if __name__ == "__main__":
-    print("Starting SOAP tests....")
-    main()
diff --git a/tests/test_soap_functions.py b/tests/test_soap.py
similarity index 67%
rename from tests/test_soap_functions.py
rename to tests/test_soap.py
index 81016c7..51cd11c 100644
--- a/tests/test_soap_functions.py
+++ b/tests/test_soap.py
@@ -13,29 +13,63 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Any
+from functools import partial
+from typing import Any, List

 import torch
 from absl.testing import absltest, parameterized

+from emerging_optimizers.soap import soap
 from emerging_optimizers.soap.soap import (
-    SOAP,
     _clip_update_rms_in_place,
     _get_precondition_frequency,
     _is_eigenbasis_update_step,
-    init_kronecker_factors,
-    precondition,
-    update_kronecker_factors,
 )
 from emerging_optimizers.utils.precondition_schedules import LinearSchedule


+def kl_shampoo_update_ref(
+    kronecker_factor_list: List[torch.Tensor],
+    grad: torch.Tensor,
+    eigenbasis_list: List[torch.Tensor],
+    shampoo_beta: float,
+    eps: float,
+    eigval_exp: float = -1.0,
+) -> None:
+    """Reference implementation of the KL-Shampoo update.
+
+    The same functionality, implemented independently, serves as the testing reference; the chance of
+    two independent implementations having the same bug is very low.
+
+    """
+    if grad.dim() != 2:
+        raise ValueError("KL-Shampoo mathematical correction is only supported for 2D tensors")
+    # Scale the gradient matrix by the approximate eigenvalues and the eigenbasis:
+    # G @ Q_R @ λ_R^(−1) @ Q_R.T @ G.T / n  and  G.T @ Q_L @ λ_L^(−1) @ Q_L.T @ G / m,  where G is (m, n)
+    scale_factors = [
+        1
+        / grad.shape[idx]
+        * (torch.diag(eigenbasis_list[idx].T @ kronecker_factor_list[idx] @ eigenbasis_list[idx]) + eps) ** eigval_exp
+        for idx in range(len(kronecker_factor_list))
+    ]
+    kronecker_product_corrections = [
+        (eigenbasis_list[idx] * scale_factors[idx][None, :]) @ eigenbasis_list[idx].T
+        for idx in range(len(kronecker_factor_list))
+    ]
+    kronecker_product_updates = [
+        grad @ kronecker_product_corrections[1] @ grad.T,
+        grad.T @ kronecker_product_corrections[0] @ grad,
+    ]
+    for idx in range(len(kronecker_factor_list)):
+        kronecker_factor_list[idx].lerp_(kronecker_product_updates[idx], 1 - shampoo_beta)
+
+
 class SoapFunctionsTest(parameterized.TestCase):
     def test_init_preconditioner_multidim_tensor_shapes(self) -> None:
         """Tests init_preconditioner with a multi-dimensional tensor."""
         grad = torch.randn(3, 4, 5)
         state: dict[str, Any] = {}
-        state["GG"] = init_kronecker_factors(grad, precondition_1d=False)
+        state["GG"] = soap.init_kronecker_factors(grad, precondition_1d=False)
         self.assertEqual(len(state["GG"]), grad.dim())
         self.assertEqual(state["GG"][0].shape, (3, 3))
         self.assertEqual(state["GG"][1].shape, (4, 4))
@@ -49,9 +84,9 @@ def test_init_preconditioner_multidim_tensor_shapes(self) -> None:
     def test_adam_warmup_steps(self, adam_warmup_steps: int) -> None:
-        """Tests that adam_warmup_steps causes state["Q"] to be None until the specified steps are completed."""
-        param = torch.randn(5, 3, requires_grad=True)
+        """Tests that adam_warmup_steps keeps state["Q"] as identity until the specified steps are completed."""
+        param = torch.randn(5, 3, requires_grad=True, device="cuda")

-        optimizer = SOAP(
+        optimizer = soap.SOAP(
             [param],
             lr=0.001,
             weight_decay=0.01,
@@ -59,20 +94,21 @@ def test_adam_warmup_steps(self, adam_warmup_steps: int) -> None:
             precondition_frequency=1,
         )

+        dummy_Q = [torch.eye(shape, device=param.device) for shape in param.shape]
         for step in range(adam_warmup_steps - 1):
             param.grad = torch.randn_like(param)
             optimizer.step()

             state = optimizer.state[param]
-            self.assertNotIn("Q", state, f"Q should not exist at step {step}")
+            torch.testing.assert_close(
+                state["Q"], dummy_Q, atol=0, rtol=0, msg=f"Q should stay identity at step {step}"
+            )

         for step in range(adam_warmup_steps - 1, adam_warmup_steps + 3):
             param.grad = torch.randn_like(param)
             optimizer.step()

             state = optimizer.state[param]
-            self.assertIn("Q", state, f"Q should exist at step {step}")
-            self.assertIsNotNone(state["Q"], f"Q should not be None at step {step}")
             # Verify Q has the right shape (a list with tensors for each dim)
             self.assertIsInstance(state["Q"], list)
             self.assertEqual(len(state["Q"]), param.dim())
@@ -87,11 +123,11 @@ def test_update_kronecker_factors(self) -> None:
         grad = torch.randn(dim0, dim1, dim2)

         # Initialize factors
-        initial_factors = init_kronecker_factors(grad, precondition_1d=False)
+        initial_factors = soap.init_kronecker_factors(grad, precondition_1d=False)

         kronecker_factors = [f.clone() for f in initial_factors]

-        update_kronecker_factors(
+        soap.update_kronecker_factors(
             kronecker_factor_list=kronecker_factors,
             grad=grad,
             shampoo_beta=shampoo_beta,
@@ -161,12 +197,12 @@ def test_project_and_project_back(self, N: int, M: int) -> None:
         Q_R = torch.linalg.qr(torch.randn(N, N))[0]
         orthonormal_matrix_list = [Q_L, Q_R]

-        projected = precondition(
+        projected = soap.precondition(
             grad=grad,
             eigenbasis_list=orthonormal_matrix_list,
             dims=[[0], [0]],
         )
-        recov = precondition(
+        recov = soap.precondition(
             grad=projected,
             eigenbasis_list=orthonormal_matrix_list,
             dims=[[0], [1]],
@@ -203,14 +239,14 @@ def test_is_eigenbasis_update_step_fixed_frequency(
     def test_soap_optimizer_fixed_frequency(self) -> None:
         """Test that SOAP optimizer can be created with fixed precondition frequency (default case)."""
         param = torch.randn(10, 5, requires_grad=True)
-        optimizer = SOAP([param], lr=1e-3, precondition_frequency=10)
+        optimizer = soap.SOAP([param], lr=1e-3, precondition_frequency=10)
         self.assertEqual(optimizer.param_groups[0]["precondition_frequency"], 10)

     def test_soap_optimizer_class_based_schedule(self) -> None:
         """Test that SOAP optimizer can be created with class-based precondition frequency schedule."""
         param = torch.randn(10, 5, requires_grad=True)
         schedule = LinearSchedule(min_freq=2, max_freq=10, transition_steps=100)
-        optimizer = SOAP([param], lr=1e-3, precondition_frequency=schedule)
+        optimizer = soap.SOAP([param], lr=1e-3, precondition_frequency=schedule)
         self.assertTrue((optimizer.param_groups[0]["precondition_frequency"]) == schedule)

         self.assertEqual(schedule(0), 2)
@@ -248,6 +284,77 @@ def test_clip_update_rms(self, max_rms: float) -> None:
         else:
             self.assertTrue(torch.linalg.norm(u_clipped) / math.sqrt(u.numel()) <= max_rms)

+    @parameterized.parameters(
+        (4, 5),
+        (3, 3),
+        (5, 4),
+    )
+    def test_kl_shampoo_update(self, m: int, n: int) -> None:
+        rand_exp_fn = partial(torch.randint, low=-4, high=-1, dtype=torch.float32, device="cuda")
+        kronecker_factor_list = [
+            2 ** rand_exp_fn(size=(m, m)),
+            2 ** rand_exp_fn(size=(n, n)),
+        ]
+        kronecker_factor_list_ref = [f.clone() for f in kronecker_factor_list]
+
+        test_grad = 2 ** rand_exp_fn(size=(m, n))
+        eigenbasis_list = [2 ** rand_exp_fn(size=(m, m)), 2 ** rand_exp_fn(size=(n, n))]
+        kwargs = dict(
+            grad=test_grad,
+            shampoo_beta=0.5,
+            eps=1e-8,
+            eigval_exp=-1.0,
+            eigenbasis_list=eigenbasis_list,
+        )
+        kl_shampoo_update_ref(kronecker_factor_list_ref, **kwargs)
+        soap.update_kronecker_factors_kl_shampoo(kronecker_factor_list, **kwargs)
+
+        torch.testing.assert_close(kronecker_factor_list[0], kronecker_factor_list_ref[0], atol=1e-6, rtol=1e-6)
+        torch.testing.assert_close(kronecker_factor_list[1], kronecker_factor_list_ref[1], atol=1e-6, rtol=1e-6)
+
+
+class SoapTest(parameterized.TestCase):
+    def setUp(self) -> None:
+        self.default_config = {
+            "lr": 0.001,
+            "weight_decay": 0.01,
+            "betas": (0.9, 0.95),
+            "eps": 1e-8,
+            "precondition_frequency": 1,
+            "shampoo_beta": 0.95,
+            "precondition_1d": False,
+            "adam_warmup_steps": 1,
+            "fp32_matmul_prec": "highest",
+            "use_adaptive_criteria": False,
+            "trace_normalization": False,
+            "power_iter_steps": 1,
+        }
+
+    def test_10steps_smoke(self) -> None:
+        param = torch.randn(5, 3, requires_grad=True, device="cuda")
+        optimizer = soap.SOAP(
+            [param],
+            **self.default_config,
+        )
+
+        for _ in range(10):
+            param.grad = torch.randn_like(param)
+            optimizer.step()
+            param.grad = None
+
+    def test_with_kl_shampoo_10steps_smoke(self) -> None:
+        param = torch.randn(5, 3, requires_grad=True, device="cuda")
+        optimizer = soap.SOAP(
+            [param],
+            **self.default_config,
+            use_kl_shampoo=True,
+        )
+
+        for _ in range(10):
+            param.grad = torch.randn_like(param)
+            optimizer.step()
+            param.grad = None
+

 if __name__ == "__main__":
     absltest.main()
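For reviewers who want to try the new update rule outside the repository, here is a minimal standalone sketch of the KL-Shampoo factor update this patch adds. It mirrors kl_shampoo_update_ref above but uses only plain PyTorch; the function name kl_shampoo_factor_update, the clamp-based eps handling, and the QR-based construction of the demo eigenbases are illustrative assumptions, not part of the patch.

import torch


def kl_shampoo_factor_update(
    factors: list[torch.Tensor],     # [L (m x m), R (n x n)] kronecker factors, updated in place
    grad: torch.Tensor,              # gradient G of shape (m, n)
    eigenbases: list[torch.Tensor],  # [Q_L (m x m), Q_R (n x n)] approximate eigenbases
    shampoo_beta: float,
    eps: float = 1e-8,
) -> None:
    # Approximate eigenvalues of each factor in its current eigenbasis: diag(Q^T A Q), clamped for stability.
    # Each correction is Q diag(eigvals)^-1 Q^T / d, where d is that factor's dimension; the correction
    # built from Q_R feeds the L update and the one built from Q_L feeds the R update.
    m, n = grad.shape
    corrections = []
    for dim, factor, basis in zip((m, n), factors, eigenbases):
        approx_eigvals = torch.diag(basis.T @ factor @ basis).clamp_min(eps)
        corrections.append((basis / approx_eigvals[None, :]) @ basis.T / dim)
    # L <- beta * L + (1 - beta) * G @ correction_R @ G^T, and symmetrically for R with G transposed.
    factors[0].lerp_(grad @ corrections[1] @ grad.T, 1 - shampoo_beta)
    factors[1].lerp_(grad.T @ corrections[0] @ grad, 1 - shampoo_beta)


if __name__ == "__main__":
    m, n = 4, 3
    g = torch.randn(m, n)
    factors = [g @ g.T, g.T @ g]
    eigenbases = [torch.linalg.qr(torch.randn(d, d))[0] for d in (m, n)]
    kl_shampoo_factor_update(factors, g, eigenbases, shampoo_beta=0.95)
    print(factors[0].shape, factors[1].shape)  # torch.Size([4, 4]) torch.Size([3, 3])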