Skip to content

Commit 0835e86

Browse files
authored
Adding GradNormControl
1 parent 9ae6539 commit 0835e86

File tree

4 files changed

+71
-5
lines changed

4 files changed

+71
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CannotWaitForTheseOptimisers"
22
uuid = "16124dda-d9fe-413b-a880-e3f4df3aa341"
33
authors = ["murrellb <[email protected]> and contributors"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"

src/CannotWaitForTheseOptimisers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ import Optimisers: OptimiserChain, AbstractRule, Leaf, adjust, adjust!, _adjust,
66
include("rules.jl")
77
include("adjust.jl")
88

9-
export Muon, Apollo, NormGrowthCap
9+
export Muon, Apollo, NormGrowthCap, GradNormControl
1010

1111
end

src/rules.jl

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ NormGrowthCap(τ = 1.01; ϵ = 1e-8, lb = 1e-7, throw = true, scale = true) = Nor
9797
init(o::NormGrowthCap, x::AbstractArray{T}) where T = T(0)
9898

9999
function apply!(o::NormGrowthCap, state, x::AbstractArray{T}, dx) where T
100-
current_norm = _norm(dx, 2)
100+
current_norm = _norm(Optimisers.unthunk(dx), 2)
101101
if o.throw && !isfinite(current_norm)
102102
throw(DomainError("gradient has L2-norm $current_norm, for array $(summary(x))"))
103103
end
@@ -106,7 +106,7 @@ function apply!(o::NormGrowthCap, state, x::AbstractArray{T}, dx) where T
106106
else
107107
#If you're below the hard min, then don't scale
108108
if o.scale
109-
minthresh = o.lb * sqrt(length(dx))
109+
minthresh = o.lb * sqrt(length(Optimisers.unthunk(dx)))
110110
else
111111
minthresh = o.lb
112112
end
@@ -123,6 +123,71 @@ function apply!(o::NormGrowthCap, state, x::AbstractArray{T}, dx) where T
123123
end
124124
end
125125

126+
127+
"""
    GradNormControl(accumulator, τ = 1.1; epsilon = 1e-8, lb = 0.1, throw = true, scale = true, clipreportthresh = Inf)

NormGrowthCap with additional control, accumulation, and reporting options.
`accumulator` must be an array of `Float64` with two elements, which is where the unscaled and scaled gradient norms are added into, allowing you to monitor the sum of the norms. It is your job to print/reset this.
The `clipreportthresh` keyword is stored in the `heavyclipthresh` field: clipping events are
reported when the norm growth ratio exceeds `τ * heavyclipthresh`.
"""
struct GradNormControl <: Optimisers.AbstractRule
    tau::Float64                # max allowed growth ratio of the gradient L2-norm per step
    epsilon::Float64            # numerical fudge to avoid division by zero
    lb::Float64                 # min grad norm, to stop a tensor getting stuck near zero
    throw::Bool                 # throw a DomainError on non-finite gradient norms
    scale::Bool                 # scale `lb` by sqrt(length(x)) when computing the per-tensor floor
    heavyclipthresh::Float64    # was ::Real; Float64 keeps the struct concretely typed (Inf converts fine)
    accumulator::AbstractVector{Float64}  # [sum of unscaled norms, sum of scaled norms]
end
142+
143+
"""
    GradNormControl(accumulator, τ = 1.1; epsilon = 1e-8, lb = 0.1, throw = true, scale = true, clipreportthresh = Inf)

Keyword-friendly outer constructor. Validates that `accumulator` has exactly two
slots before building the rule. Throws `ArgumentError` otherwise.
"""
function GradNormControl(accumulator, τ = 1.1; epsilon = 1e-8, lb = 0.1, throw = true, scale = true, clipreportthresh = Inf)
    # BUG FIX: the keyword argument `throw::Bool` shadows `Base.throw` inside this
    # function, so a bare `throw(ArgumentError(...))` would try to call the Bool and
    # raise a MethodError instead of the intended ArgumentError. Qualify explicitly.
    if length(accumulator) != 2
        Base.throw(ArgumentError("accumulator must be an array with two elements, initialized to 0"))
    end
    GradNormControl(τ, epsilon, lb, throw, scale, clipreportthresh, accumulator)
end
149+
150+
"""
    init(o::GradNormControl, x)

Initial per-tensor state: `(previous_norm, minthresh)`, where `previous_norm`
starts at zero (first step is never clipped) and `minthresh` is the hard
minimum-norm floor, optionally scaled by the tensor size.
"""
function init(o::GradNormControl, x::AbstractArray{T}) where T
    # With `scale`, the floor grows with tensor size as lb * sqrt(length(x)).
    minfloor = o.scale ? o.lb * sqrt(length(x)) : o.lb
    return T(0), minfloor
end
158+
159+
"""
    apply!(o::GradNormControl, state, x, dx)

Clip `dx` so its L2-norm grows by at most a factor `o.tau` relative to the
previous step's (possibly clipped) norm. Accumulates the unscaled norm into
`o.accumulator[1]` and the post-clipping norm into `o.accumulator[2]`.
Returns `(new_state, dx′)` per the Optimisers.jl rule interface.
"""
function apply!(o::GradNormControl, state, x::AbstractArray{T}, dx) where T
    prevnorm, minthresh = state
    utdx = Optimisers.unthunk(dx)
    current_norm = Optimisers._norm(utdx, 2)
    o.accumulator[1] += current_norm  # running sum of unscaled norms
    if o.throw && !isfinite(current_norm)
        throw(DomainError("gradient has L2-norm $current_norm, for array $(summary(x))"))
    end
    # Pass through unscaled on the first step (prevnorm == 0) and whenever the
    # norm is below the hard floor (avoids pinning tiny gradients near zero).
    if prevnorm == 0 || current_norm < minthresh
        o.accumulator[2] += current_norm
        return (current_norm, minthresh), dx
    end
    ratio = current_norm / (prevnorm + o.epsilon)
    if ratio <= o.tau
        # Growth within tolerance: no clipping.
        o.accumulator[2] += current_norm
        return (current_norm, minthresh), dx
    end
    # Clip so the new norm is (approximately) tau * prevnorm.
    lambda = T((o.tau * prevnorm) / (current_norm + o.epsilon))
    if ratio > o.tau * o.heavyclipthresh
        # NOTE(review): bare println for diagnostics; @info/@warn would be more idiomatic.
        println("Heavy clipping on $(size(utdx)):", current_norm, "->", current_norm * lambda)
    end
    o.accumulator[2] += current_norm * lambda
    # Scale the already-unthunked gradient: `dx` itself may still be a lazy
    # thunk, which is not guaranteed to support `*` directly.
    return (current_norm * lambda, minthresh), utdx * lambda
end
189+
190+
126191
"""
127192
Apollo(opt::AdamW = AdamW(), r::Function = dim -> ceil(Int, sqrt(dim)); u = 100, sort_dims = true)
128193
Apollo(η::Real, args...; kw...)

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ Random.seed!(1)
77

88
RULES = [
    Muon(), Apollo(),
    OptimiserChain(NormGrowthCap(), Apollo()),
    # BUG FIX: GradNormControl has no zero-argument constructor — it requires a
    # two-element Float64 accumulator as its first positional argument, so
    # `GradNormControl()` would throw a MethodError here. Pass a fresh zeros(2).
    OptimiserChain(GradNormControl(zeros(2)), Muon())
]
1213

1314
name(o) = typeof(o).name.name # just for printing testset headings

0 commit comments

Comments
 (0)