Skip to content

Commit 1d3be8a

Browse files
authored
Adding AdaptiveGradNormControl
1 parent 7ad2fce commit 1d3be8a

File tree

3 files changed

+79
-2
lines changed

3 files changed

+79
-2
lines changed

src/CannotWaitForTheseOptimisers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ import Optimisers: OptimiserChain, AbstractRule, Leaf, adjust, adjust!, _adjust,
66
include("rules.jl")
77
include("adjust.jl")
88

9-
export Muon, Apollo, NormGrowthCap, GradNormControl
9+
export Muon, Apollo, NormGrowthCap, GradNormControl, AdaptiveGradNormControl
1010

1111
end

src/rules.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,82 @@ function apply!(o::GradNormControl, state, x::AbstractArray{T}, dx) where T
188188
end
189189

190190

191+
"""
192+
AdaptiveGradNormControl(accumulator, τ = 1.0; epsilon = 1e-8, lb = 0.1,
193+
momentum = 0.90, throw = true, clipreportthresh = Inf)
194+
195+
Gradient norm control using exponential moving statistics. Clips gradients when the
196+
current norm exceeds mean + τ * std.
197+
"""
198+
struct AdaptiveGradNormControl <: Optimisers.AbstractRule
199+
tau::Float64
200+
epsilon::Float64
201+
lb::Float64
202+
throw::Bool
203+
momentum::Float64
204+
heavyclipthresh::Real
205+
accumulator::AbstractVector{<:Float64}
206+
end
207+
208+
function AdaptiveGradNormControl(accumulator, τ = 1.0; epsilon = 1e-8, lb = 0.1,
209+
momentum = 0.9, throw = true, clipreportthresh = Inf)
210+
if length(accumulator) != 2
211+
throw(ArgumentError("accumulator must be an array with two elements"))
212+
end
213+
AdaptiveGradNormControl(τ, epsilon, lb, throw, momentum, clipreportthresh, accumulator)
214+
end
215+
216+
# Helper function to update running statistics
217+
function update_running_stats(curr_norm, prev_mean, prev_std, momentum)
218+
new_mean = momentum * prev_mean + (1 - momentum) * curr_norm
219+
# Variance update formula: var = E[(x - μ)²] = E[x²] - μ²
220+
new_var = momentum * (prev_std^2 + prev_mean^2) +
221+
(1 - momentum) * curr_norm^2 - new_mean^2
222+
new_std = sqrt(max(new_var, 1e-8))
223+
return new_mean, new_std
224+
end
225+
226+
function Optimisers.init(o::AdaptiveGradNormControl, x::AbstractArray{T}) where T
227+
minthresh = o.lb * sqrt(length(x))
228+
return (T(0), T(0), minthresh) # mean, std, minthresh
229+
end
230+
231+
function Optimisers.apply!(o::AdaptiveGradNormControl, state, x::AbstractArray{T}, dx) where T
232+
mu, std, minthresh = state
233+
utdx = Optimisers.unthunk(dx)
234+
current_norm = Optimisers._norm(utdx, 2)
235+
o.accumulator[1] += current_norm
236+
if o.throw && !isfinite(current_norm)
237+
throw(DomainError("gradient has L2-norm $current_norm"))
238+
end
239+
if current_norm < minthresh
240+
o.accumulator[2] += current_norm
241+
new_mean, new_std = update_running_stats(current_norm, mu, std, o.momentum) #Unsure if we should adjust the mean if they fall below the threshold?
242+
return (new_mean, new_std, minthresh), dx
243+
end
244+
if mu == 0
245+
o.accumulator[2] += current_norm
246+
return (current_norm, current_norm, minthresh), dx
247+
end
248+
threshold = mu + o.tau * std
249+
if current_norm > threshold
250+
lambda = T(threshold / (current_norm + o.epsilon))
251+
clipped_norm = current_norm * lambda
252+
if current_norm > threshold * o.heavyclipthresh
253+
println("Heavy clipping on $(size(utdx)): ", current_norm, "->", clipped_norm, " with mu ", mu, " and std ", std)
254+
end
255+
new_mean, new_std = update_running_stats(clipped_norm, mu, std, o.momentum)
256+
o.accumulator[2] += clipped_norm
257+
return (new_mean, new_std, minthresh), dx * lambda
258+
end
259+
o.accumulator[2] += current_norm
260+
new_mean, new_std = update_running_stats(current_norm, mu, std, o.momentum)
261+
return (new_mean, new_std, minthresh), dx
262+
end
263+
264+
265+
266+
191267
"""
192268
Apollo(opt::AdamW = AdamW(), r::Function = dim -> ceil(Int, sqrt(dim)); u = 100, sort_dims = true)
193269
Apollo(η::Real, args...; kw...)

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ Random.seed!(1)
88
RULES = [
99
Muon(), Apollo(),
1010
OptimiserChain(NormGrowthCap(), Apollo()),
11-
OptimiserChain(GradNormControl([0.0,0.0]), Muon())
11+
OptimiserChain(GradNormControl([0.0,0.0]), Muon()),
12+
OptimiserChain(AdaptiveGradNormControl([0.0,0.0]), Muon())
1213
]
1314

1415
name(o) = typeof(o).name.name # just for printing testset headings

0 commit comments

Comments
 (0)