77 changes: 67 additions & 10 deletions emerging_optimizers/soap/soap.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from itertools import chain
from typing import Callable, Iterable, List, Optional, Tuple, Union

@@ -81,6 +82,7 @@ class SOAP(optim.Optimizer):
power_iter_steps: Number of power iteration steps to perform before QR decomposition.
More steps can lead to better convergence but increased computation time.
max_update_rms: Clip the update RMS to this value (0 means no clipping).
use_kl_shampoo: Whether to apply the KL-Shampoo correction when accumulating the kronecker factor statistics.
"""

def __init__(
@@ -107,6 +109,7 @@ def __init__(
adaptive_update_tolerance: Optional[float] = None,
power_iter_steps: int = 1,
max_update_rms: float = 0.0,
use_kl_shampoo: bool = False,
) -> None:
# Check for betas.
if betas is None:
@@ -159,6 +162,7 @@ def __init__(
"adaptive_update_tolerance": adaptive_update_tolerance,
"power_iter_steps": power_iter_steps,
"max_update_rms": max_update_rms,
"use_kl_shampoo": use_kl_shampoo,
}
super().__init__(params, defaults)

@@ -194,6 +198,21 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(grad)

if "Q" not in state:
state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape]

# Define kronecker_factor_update_fn here, based on whether KL-Shampoo is enabled,
# because it needs access to state and group.
kronecker_factor_update_fn = partial(
update_kronecker_factors, precondition_1d=group["precondition_1d"]
)
if group["use_kl_shampoo"]:
kronecker_factor_update_fn = partial(
update_kronecker_factors_kl_shampoo,
eigenbasis_list=state["Q"],
eps=group["eps"],
)
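# Both variants accept the same (kronecker_factor_list, grad, shampoo_beta)
# keyword arguments, so the call sites below do not need to branch on the flag.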

# Initialize kronecker factor matrices
if "GG" not in state:
state["GG"] = init_kronecker_factors(
@@ -204,11 +223,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
# Update preconditioner matrices with gradient statistics,
# do not use shampoo_beta for EMA at first step
with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
update_kronecker_factors(
kronecker_factor_list=state["GG"],
grad=grad,
shampoo_beta=0.0,
precondition_1d=group["precondition_1d"],
kronecker_factor_update_fn(
kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=group["shampoo_beta"]
)

# Increment step counter
@@ -228,7 +244,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
grad_projected = precondition(
grad=grad,
eigenbasis_list=state.get("Q"),
eigenbasis_list=state["Q"],
dims=[[0], [0]],
)
torch.cuda.nvtx.range_pop()
@@ -255,7 +271,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
norm_precond_grad = precondition(
grad=adam_update,
eigenbasis_list=state.get("Q"),
eigenbasis_list=state["Q"],
dims=[[0], [1]],
)
torch.cuda.nvtx.range_pop()
@@ -283,11 +299,10 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

torch.cuda.nvtx.range_push("update_kronecker_factors")
with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
update_kronecker_factors(
kronecker_factor_update_fn(
kronecker_factor_list=state["GG"],
grad=grad,
shampoo_beta=shampoo_beta,
precondition_1d=group["precondition_1d"],
shampoo_beta=0.0,
)
torch.cuda.nvtx.range_pop()

@@ -453,6 +468,48 @@ def update_kronecker_factors(
kronecker_factor_list[idx].lerp_(outer_product, 1 - shampoo_beta)


@torch.no_grad() # type: ignore[misc]
def update_kronecker_factors_kl_shampoo(
kronecker_factor_list: List[torch.Tensor],
grad: torch.Tensor,
shampoo_beta: float,
eigenbasis_list: List[torch.Tensor],
eps: float,
eigval_exp: float = -1.0,
) -> None:
"""Updates the kronecker factor matrices in place using KL-Shampoo correction.

Implement Kullback–Leibler Minimization from https://arxiv.org/pdf/2509.03378

Args:
kronecker_factor_list: List of preconditioner matrices (L and R) to update.
grad: Gradient tensor of the parameter being optimized.
shampoo_beta: Momentum coefficient for updating preconditioners.
eigenbasis_list: List of orthonormal eigenbases of the kronecker factor matrices.
eps: Small offset for numerical stability.
eigval_exp: Exponent applied to the approximate eigenvalues.
"""
assert grad.dim() == 2, "KL-Shampoo mathematical correction is only supported for 2D tensors"

# Scale the gradient matrix by the approximate eigenvalues and the eigenbasis
# G@Q_R@λ_R^(−1)@Q_R.T@G.T/dim(GG.T) and G.T@Q_L@λ_L^(−1)@Q_L.T@G/dim(G.TG)
updates = []
for idx, (kronecker_factor, eigenbasis) in enumerate(zip(kronecker_factor_list, eigenbasis_list, strict=True)):
approx_eigvals = utils.eig.conjugate(kronecker_factor, eigenbasis, diag=True)
scale_factor = 1 / grad.shape[idx] * approx_eigvals.clamp_min(eps) ** eigval_exp

logging.debug(f"scale_factor[{idx}]: {scale_factor}")

correction = (eigenbasis * scale_factor[None, :]) @ eigenbasis.T

maybe_transpose_grad = grad.T if idx == 1 else grad
updates.append(utils.eig.conjugate(correction, maybe_transpose_grad))

# Note that the updates calculated in the previous loop apply to the kronecker factors in reverse order
for kronecker_factor, update in zip(kronecker_factor_list, updates[::-1], strict=True):
kronecker_factor.lerp_(update, 1 - shampoo_beta)


@torch.no_grad() # type: ignore[misc]
def update_eigenbasis_and_momentum(
kronecker_factor_list: List[torch.Tensor],
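A minimal usage sketch of the new flag, for context only: the import path mirrors the file location above, the lr keyword and its value are assumptions not shown in this diff, and any other required constructor arguments should be taken from the __init__ signature.

import torch
from emerging_optimizers.soap.soap import SOAP  # path inferred from the diff above

# A bias-free linear layer keeps every gradient 2D, which the KL-Shampoo
# correction requires (it asserts grad.dim() == 2).
model = torch.nn.Linear(32, 8, bias=False)

# use_kl_shampoo=True routes the kronecker factor updates through
# update_kronecker_factors_kl_shampoo instead of update_kronecker_factors.
optimizer = SOAP(
    model.parameters(),
    lr=3e-4,  # assumed standard lr keyword; adjust to the actual signature
    use_kl_shampoo=True,
)

x = torch.randn(16, 32)
loss = model(x).square().mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()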
3 changes: 1 addition & 2 deletions tests/ci/L0_Tests_GPU.sh
@@ -18,9 +18,8 @@ export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0
error=0
coverage run -p --source=emerging_optimizers tests/test_muon_utils.py || error=1
coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py || error=1
coverage run -p --source=emerging_optimizers tests/test_soap_functions.py || error=1
coverage run -p --source=emerging_optimizers tests/test_soap_utils.py || error=1
coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py || error=1
coverage run -p --source=emerging_optimizers tests/test_soap.py || error=1
coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py || error=1
coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cuda || error=1
coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py || error=1
3 changes: 1 addition & 2 deletions tests/ci/L1_Tests_GPU.sh
@@ -17,9 +17,8 @@ export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0
error=0
python tests/test_muon_utils.py || error=1
python tests/test_orthogonalized_optimizer.py || error=1
python tests/test_soap_functions.py || error=1
python tests/test_soap_utils.py || error=1
python tests/soap_smoke_test.py || error=1
python tests/test_soap.py || error=1
python tests/test_scalar_optimizers.py --device=cuda || error=1
python tests/test_spectral_clipping_utils.py || error=1
python tests/test_triton_kernels.py || error=1
97 changes: 0 additions & 97 deletions tests/soap_smoke_test.py

This file was deleted.
