Skip to content

Commit fb0d57a

Browse files
committed
.
1 parent 2fdaaa7 commit fb0d57a

File tree

1 file changed

+34
-29
lines changed

1 file changed

+34
-29
lines changed

src/rules.jl

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ function apply!((; opt, eta, mu, lambda, fallback)::Muon, state, x::AbstractArra
5050
# Nesterov update fed to NS5: U ← μ*state + (1-μ)*dx  (momentum β is `μ`, buffer m is `state`, gradient g is `dx`)
5151
U = @. μ * state + (1 - μ) * dx
5252
# orthogonalize
53+
@.. U = U / (norm(U) + T(1e-6))
5354
Ot = newtonschulz5(U)
5455
# post shape factor √max(1, r/c)
5556
r, c... = size(x)
@@ -59,46 +60,50 @@ function apply!((; opt, eta, mu, lambda, fallback)::Muon, state, x::AbstractArra
5960
end
6061
end
6162

62-
const NS5_COEFFS = (; a = 3.4445, b = -4.7750, c = 2.0315)
63-
64-
function _newtonschulz5!(X::AbstractMatrix{T}, ::Val{true}) where T
65-
n = size(X, 1)
63+
"""
    newtonschulz5!(X::AbstractMatrix)

Overwrite `X` with the result of five quintic Newton–Schulz steps

    X ← a*X + b*(X*X')*X + c*(X*X')^2*X

using `(a, b, c) = (3.4445, -4.7750, 2.0315)`, and return `X`.

The Gram matrix is formed on the smaller side of `X`: `X*X'` when `X` has no more
rows than columns, `X'*X` otherwise (applied from the right). No normalization of
`X` is performed here.
"""
function newtonschulz5!(X::AbstractMatrix{T}) where T
    # Coefficients are converted to the element type, as `newtonschulz5` does, so the
    # scalars handed to `mul!` match the arrays.
    # NOTE(review): Float64 scalars with e.g. Float32 arrays may miss the typed BLAS
    # kernel in 5-arg `mul!` — confirm on the targeted Julia versions.
    a, b, c = T(3.4445), T(-4.7750), T(2.0315)
    wide = size(X, 1) <= size(X, 2)
    n = minimum(size(X))
    A = similar(X, n, n)  # Gram matrix
    B = similar(X, n, n)  # b*A + c*A^2
    x = similar(X)        # mirror of X: X is both input and output of the last mul!
    for _ in 1:5
        if wide
            mul!(A, X, X')
        else
            mul!(A, X', X)
        end
        B .= A
        mul!(B, A, A, c, b)         # B = c*A*A + b*B, with B == A on entry
        x .= X
        if wide
            mul!(X, B, x, true, a)  # X = B*x + a*X
        else
            mul!(X, x, B, true, a)  # X = x*B + a*X
        end
    end
    return X
end
7786

78-
function _newtonschulz5!(X::AbstractMatrix{T}, ::Val{false}) where T
79-
n = size(X, 2)
80-
A = similar(X, n, n)
81-
B = similar(X, n, n)
82-
x = similar(X) # mirror of X
83-
(; a, b, c) = NS5_COEFFS
84-
for _ in 1:5
85-
mul!(A, X', X)
86-
B .= A; mul!(B, A, A, c, b)
87-
x .= X; mul!(X, x, B, true, a)
87+
"""
    newtonschulz5(X::AbstractMatrix)

Return the result of five quintic Newton–Schulz steps

    X ← a*X + b*(X*X')*X + c*(X*X')^2*X

using `(a, b, c) = (3.4445, -4.7750, 2.0315)` converted to the element type of `X`.
The input array is not mutated; every step allocates new arrays.

The Gram matrix is formed on the smaller side of `X`: `X*X'` when `X` has no more
rows than columns, `X'*X` otherwise (applied from the right). No normalization of
`X` is performed here.
"""
function newtonschulz5(X::AbstractMatrix{T}) where T
    a, b, c = T(3.4445), T(-4.7750), T(2.0315)
    if size(X, 1) <= size(X, 2)
        for _ in 1:5
            A = X * X'
            # `b` multiplies the linear Gram term and `c` the squared one, matching
            # the polynomial above and the in-place `newtonschulz5!`.
            B = b * A + c * A * A
            X = a * X + B * X
        end
    else
        for _ in 1:5
            A = X' * X
            B = b * A + c * A * A
            X = a * X + X * B
        end
    end
    return X
end
91104

92-
_newtonschulz5!(G, c::Bool) = _newtonschulz5!(G, Val(c)::Union{Val{true},Val{false}})
93-
94-
function newtonschulz5(G::AbstractMatrix{T}) where T
95-
X = G / (norm(G) + T(1e-7))
96-
return _newtonschulz5!(X, size(G, 1) <= size(G, 2))
97-
end
98-
99-
# Applies `a*X + b*X*X'*X + c*X*X'*X*X'*X` five times,
100-
# with two methods based on X*X' and X'*X respectively
101-
newtonschulz5(G::AbstractArray) = reshape(newtonschulz5(reshape(G, size(G,1), :)), size(G))
105+
# Fallback for arrays that are not matrices: collapse every dimension after the first
# into columns, apply the in-place matrix method to that reshaped array, and hand the
# result back reshaped to the size of `X`.
function newtonschulz5!(X::AbstractArray)
    flat = reshape(X, size(X, 1), :)
    return reshape(newtonschulz5!(flat), size(X))
end
106+
# Fallback for arrays that are not matrices: collapse every dimension after the first
# into columns, apply the matrix method to that reshaped array, and return the result
# reshaped to the size of `X`.
function newtonschulz5(X::AbstractArray)
    flat = reshape(X, size(X, 1), :)
    return reshape(newtonschulz5(flat), size(X))
end
102107

103108
# Set the learning rate of a `Muon` rule to `η`. The wrapped rule `r.opt` gets its own
# rate rescaled so that the ratio `r.opt.eta / r.eta` is the same before and after.
function adjust(r::Muon, η::Real)
    ratio = r.opt.eta / r.eta
    inner = adjust(r.opt, eta = ratio * η)
    return adjust(r, eta = η, opt = inner)
end
104109

0 commit comments

Comments
 (0)