@@ -50,6 +50,7 @@ function apply!((; opt, eta, mu, lambda, fallback)::Muon, state, x::AbstractArra
5050 # Nesterov update fed to NS5: U ← β m + (1-β) g
5151 U = @. μ * state + (1 - μ) * dx
5252 # orthogonalize
53+ @. . U = U / (norm (U) + T (1e-6 ))
5354 Ot = newtonschulz5 (U)
5455 # post shape factor √max(1, r/c)
5556 r, c... = size (x)
@@ -59,46 +60,50 @@ function apply!((; opt, eta, mu, lambda, fallback)::Muon, state, x::AbstractArra
5960 end
6061end
6162
62- const NS5_COEFFS = (; a = 3.4445 , b = - 4.7750 , c = 2.0315 )
63-
64- function _newtonschulz5 ! (X:: AbstractMatrix{T} , :: Val{true } ) where T
65- n = size (X, 1 )
# Applies `X = a*X + b*X*X'*X + c*X*X'*X*X'*X` five times, overwriting `X`,
# with two methods based on X*X' and X'*X respectively.
#
# The branch is chosen so that the Gram matrix `A` is the smaller of the two
# possible squares (`min(size(X)...)` on each side). Returns the mutated `X`.
function newtonschulz5!(X::AbstractMatrix{T}) where T
    n = minimum(size(X))
    A = similar(X, n, n)
    B = similar(X, n, n)
    x = similar(X)  # mirror of X, holds the previous iterate during the update
    # Convert the coefficients to the element type: scalars of a different
    # precision (e.g. Float64 with Float32 arrays) can keep the 5-argument `mul!`
    # from taking its BLAS path.
    a, b, c = T(3.4445), T(-4.7750), T(2.0315)
    if size(X, 1) <= size(X, 2)
        for _ in 1:5
            mul!(A, X, X')                  # A = X X'
            B .= A
            mul!(B, A, A, c, b)             # B = c A² + b A
            x .= X
            mul!(X, B, x, true, a)          # X = B X + a X
        end
    else
        for _ in 1:5
            mul!(A, X', X)                  # A = X' X
            B .= A
            mul!(B, A, A, c, b)             # B = c A² + b A
            x .= X
            mul!(X, x, B, true, a)          # X = X B + a X
        end
    end
    return X
end
7786
78- function _newtonschulz5! (X:: AbstractMatrix{T} , :: Val{false} ) where T
79- n = size (X, 2 )
80- A = similar (X, n, n)
81- B = similar (X, n, n)
82- x = similar (X) # mirror of X
83- (; a, b, c) = NS5_COEFFS
84- for _ in 1 : 5
85- mul! (A, X' , X)
86- B .= A; mul! (B, A, A, c, b)
87- x .= X; mul! (X, x, B, true , a)
# Non-mutating Newton–Schulz iteration: applies
# `X = a*X + b*X*X'*X + c*X*X'*X*X'*X` five times and returns the result.
#
# The input array is never written to; each step rebinds `X` to a fresh array.
# The product is formed from `X*X'` when `X` has no more rows than columns and
# from `X'*X` otherwise, so the square intermediate is the smaller one.
function newtonschulz5(X::AbstractMatrix{T}) where T
    a, b, c = T(3.4445), T(-4.7750), T(2.0315)
    if size(X, 1) <= size(X, 2)
        for _ in 1:5
            A = X * X'
            # `b` multiplies the linear term and `c` the quadratic one, matching
            # the formula above and the in-place method's `mul!(B, A, A, c, b)`.
            B = b * A + c * A * A
            X = a * X + B * X
        end
    else
        for _ in 1:5
            A = X' * X
            B = b * A + c * A * A
            X = a * X + X * B
        end
    end
    return X
end
91104
92- _newtonschulz5! (G, c:: Bool ) = _newtonschulz5! (G, Val (c):: Union{Val{true},Val{false}} )
93-
94- function newtonschulz5 (G:: AbstractMatrix{T} ) where T
95- X = G / (norm (G) + T (1e-7 ))
96- return _newtonschulz5! (X, size (G, 1 ) <= size (G, 2 ))
97- end
98-
99- # Applies `a*X + b*X*X'*X + c*X*X'*X*X'*X` five times,
100- # with two methods based on X*X' and X'*X respectively
101- newtonschulz5 (G:: AbstractArray ) = reshape (newtonschulz5 (reshape (G, size (G,1 ), :)), size (G))
# Generic-rank entry point for the in-place iteration: view `X` as a matrix
# whose rows are the first dimension, run the matrix method on it, and return
# the result reshaped to the original size.
# NOTE(review): whether `X` itself is updated relies on `reshape` sharing
# memory with `X` for the array type in use — confirm for non-`Array` inputs.
function newtonschulz5!(X::AbstractArray)
    dims = size(X)
    flat = reshape(X, size(X, 1), :)
    return reshape(newtonschulz5!(flat), dims)
end
# Generic-rank entry point for the non-mutating iteration: collapse every
# dimension after the first into columns, run the matrix method, and reshape
# the result back to the size of `X`.
function newtonschulz5(X::AbstractArray)
    dims = size(X)
    flat = reshape(X, size(X, 1), :)
    return reshape(newtonschulz5(flat), dims)
end
102107
# Set the rule's `eta` to `η`, and rescale the wrapped rule's `eta` so that the
# ratio `r.opt.eta / r.eta` is unchanged.
function adjust(r::Muon, η::Real)
    ratio = r.opt.eta / r.eta
    inner = adjust(r.opt, eta = ratio * η)
    return adjust(r, eta = η, opt = inner)
end
104109
0 commit comments