Bring in improvements from modded-nanogpt repo #14

Open: wants to merge 2 commits into base: master


leloykun
Contributor

@leloykun leloykun commented Feb 24, 2025

Adds code for:

  1. Optimizing Newton-Schulz coefficients
  2. Tighter estimate of spectral norm using Gram iteration taken from https://arxiv.org/pdf/2305.16173

Usage note:

def zeropower_via_newtonschulz5(
    G: Tensor, steps: int, enable_better_spec_norm_est: bool = False
) -> Tensor:
    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
-     a, b, c = (3.4445, -4.7750,  2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
-     for i in range(steps):
+     for i, (a, b, c) in enumerate([
+         ...[insert the coefficients here]...
+     ]):
        A = X @ X.mT
        if i == 0 and enable_better_spec_norm_est:
            # Tighter estimate of spectral norm using 1st Gram iteration.
            # Taken from https://arxiv.org/pdf/2305.16173
            S_norm_est_over_f_norm__squared = A.norm(dim=(-2, -1), keepdim=True)
            X = X / (S_norm_est_over_f_norm__squared**0.5 + 1e-7)
            A = A / (S_norm_est_over_f_norm__squared + 1e-7)
        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X
    
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
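
For reference, the bound behind enable_better_spec_norm_est can be checked numerically. The sketch below is a standalone NumPy check, not part of this PR, and the helper name spectral_norm_bounds is hypothetical. It compares the exact spectral norm with the Frobenius norm (the existing normalization) and with the first Gram iteration estimate, ||G G^T||_F^(1/2). In exact arithmetic the Gram estimate sits between the two, so dividing by it should still leave the spectral norm at most 1. It does not say anything about bfloat16 behaviour.

```python
import numpy as np


def spectral_norm_bounds(G):
    """Return (exact spectral norm, first-Gram-iteration bound, Frobenius bound)."""
    exact = np.linalg.norm(G, ord=2)                     # largest singular value
    gram = np.sqrt(np.linalg.norm(G @ G.T, ord="fro"))   # ||G G^T||_F ** 0.5
    fro = np.linalg.norm(G, ord="fro")                   # bound used by the plain normalization
    return exact, gram, fro


rng = np.random.default_rng(0)
G = rng.standard_normal((64, 256))
exact, gram, fro = spectral_norm_bounds(G)

# Expected ordering: exact <= gram <= fro
print(f"exact={exact:.4f}  gram={gram:.4f}  frobenius={fro:.4f}")

# After dividing by the Gram-based estimate, the spectral norm is still <= 1.
print(np.linalg.norm(G / gram, ord=2))
```

The tighter the estimate, the closer the largest singular value starts to 1, which is the apparent motivation for the option.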

@KellerJordan
Owner

Have we confirmed that this option never causes any instability?

It's potentially risky, so it's important to confirm. I will wait to accept it until I have evidence.

@toothacher17

Hey @leloykun, I tried your JAX scripts and got a new set of coefficients:

(4.0246, -6.4224, 2.6026)
(3.9872, -6.2793, 2.5377)
(3.3260, -4.8258, 1.9451)
(2.8778, -3.6189, 1.6208)
(3.0133, -3.6424, 1.6122)

Do you have any recommendation for which one to use, or should I just pick one at random?

@leloykun
Contributor Author

Hi @toothacher17,

In zeropower_via_newtonschulz5, you should replace

for i in range(steps):

with

for i, (a, b, c) in enumerate([
    ...[insert the coefficients here]...
]):
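
To make the replacement concrete, here is a minimal sketch of the loop consuming one (a, b, c) tuple per iteration. It assumes the five tuples posted earlier in this thread form a five-step schedule applied in the order listed, rather than five alternatives; that reading, and the order, are assumptions. The function below is a float64 NumPy port with a hypothetical name, written only for illustration; the PR's function uses PyTorch and bfloat16, and the Gram-based normalization is left out here.

```python
import numpy as np

# Coefficients posted earlier in this thread. ASSUMPTION: one tuple per
# Newton-Schulz iteration, used in the order they were listed.
COEFFS = [
    (4.0246, -6.4224, 2.6026),
    (3.9872, -6.2793, 2.5377),
    (3.3260, -4.8258, 1.9451),
    (2.8778, -3.6189, 1.6208),
    (3.0133, -3.6424, 1.6122),
]


def zeropower_via_newtonschulz5_np(G, coeffs=COEFFS):
    """Illustrative NumPy port (hypothetical name), not the PR's implementation."""
    assert G.ndim >= 2
    X = G.astype(np.float64)
    transposed = G.shape[-2] > G.shape[-1]
    if transposed:
        X = np.swapaxes(X, -2, -1)

    # Ensure spectral norm is at most 1 (the Frobenius norm is an upper bound).
    X = X / (np.linalg.norm(X, axis=(-2, -1), keepdims=True) + 1e-7)

    # The number of iterations is now len(coeffs); each step has its own (a, b, c).
    for a, b, c in coeffs:
        A = X @ np.swapaxes(X, -2, -1)
        B = b * A + c * A @ A
        X = a * X + B @ X

    if transposed:
        X = np.swapaxes(X, -2, -1)
    return X


G = np.random.default_rng(0).standard_normal((8, 32))
X = zeropower_via_newtonschulz5_np(G)
# Singular values of the result; inspect how close they are to 1.
print(np.linalg.svd(X, compute_uv=False))
```

With this change the iteration count is set by the length of the coefficient list, so the steps argument no longer controls the loop.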
