@@ -188,6 +188,82 @@ function apply!(o::GradNormControl, state, x::AbstractArray{T}, dx) where T
end


"""
    AdaptiveGradNormControl(accumulator, τ = 1.0; epsilon = 1e-8, lb = 0.1,
                            momentum = 0.9, throw = true, clipreportthresh = Inf)

Gradient norm control using exponential moving statistics. Maintains, per array,
an exponential moving average of the mean and standard deviation of the gradient
L2-norm (with the given `momentum`) and rescales the gradient whenever its
current norm exceeds `mean + τ * std`. Norms below `lb * sqrt(length(x))` are
never clipped. `accumulator` must hold two elements, which accumulate the total
pre-clip and post-clip gradient norms across calls. If `throw = true`, a
non-finite gradient norm raises a `DomainError`; clips where the norm exceeds
`clipreportthresh` times the threshold are reported with `println`.
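# Examples

A minimal usage sketch, assuming the rule is chained in front of a base
optimiser from Optimisers.jl (`model` and `grads` are placeholders):

```julia
acc = zeros(2)  # accumulates [pre-clip, post-clip] norm totals
rule = OptimiserChain(AdaptiveGradNormControl(acc, 2.0), Adam(1e-3))
state = Optimisers.setup(rule, model)
state, model = Optimisers.update!(state, model, grads)
```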
"""
struct AdaptiveGradNormControl <: Optimisers.AbstractRule
    tau::Float64
    epsilon::Float64
    lb::Float64
    throw::Bool
    momentum::Float64
    heavyclipthresh::Real
    accumulator::AbstractVector{<:Float64}
end

function AdaptiveGradNormControl(accumulator, τ = 1.0; epsilon = 1e-8, lb = 0.1,
                                 momentum = 0.9, throw = true, clipreportthresh = Inf)
    if length(accumulator) != 2
        # `throw` is shadowed by the keyword argument here, so call it via Base.
        Base.throw(ArgumentError("accumulator must be an array with two elements"))
    end
    AdaptiveGradNormControl(τ, epsilon, lb, throw, momentum, clipreportthresh, accumulator)
end

# Helper to update exponential moving statistics of the gradient norm.
function update_running_stats(curr_norm, prev_mean, prev_std, momentum)
    new_mean = momentum * prev_mean + (1 - momentum) * curr_norm
    # Variance identity: var = E[(x - μ)²] = E[x²] - μ², where E[x²] is
    # tracked as an exponential moving average of squared norms.
    new_var = momentum * (prev_std^2 + prev_mean^2) +
              (1 - momentum) * curr_norm^2 - new_mean^2
    new_std = sqrt(max(new_var, 1e-8))
    return new_mean, new_std
end
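As a quick sanity check on the variance identity above (a hypothetical
standalone snippet, not part of the rule): tracking E[x] and E[x²] directly
with the same momentum yields the same standard deviation.

```julia
let m = 0.9, mu = 0.0, sd = 0.0, ex = 0.0, ex2 = 0.0
    for x in (1.0, 2.0, 1.5, 3.0)
        ex  = m * ex  + (1 - m) * x    # running E[x]
        ex2 = m * ex2 + (1 - m) * x^2  # running E[x²]
        mu, sd = update_running_stats(x, mu, sd, m)
    end
    @assert isapprox(sd, sqrt(max(ex2 - ex^2, 1e-8)))
end
```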

function Optimisers.init(o::AdaptiveGradNormControl, x::AbstractArray{T}) where T
    minthresh = o.lb * sqrt(length(x))
    return (T(0), T(0), minthresh)  # running mean, running std, clipping floor
end
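For intuition, a hypothetical REPL sketch of the per-array state this produces:
with the default `lb = 0.1`, a 100-element array gets a floor of
`0.1 * sqrt(100) = 1.0` below which norms are never clipped.

```julia
o = AdaptiveGradNormControl(zeros(2))
Optimisers.init(o, zeros(Float32, 100))  # (0.0f0, 0.0f0, 1.0): mean, std, floor
```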

function Optimisers.apply!(o::AdaptiveGradNormControl, state, x::AbstractArray{T}, dx) where T
    mu, std, minthresh = state
    utdx = Optimisers.unthunk(dx)
    current_norm = Optimisers._norm(utdx, 2)
    o.accumulator[1] += current_norm
    if o.throw && !isfinite(current_norm)
        throw(DomainError("gradient has L2-norm $current_norm"))
    end
    if current_norm < minthresh
        o.accumulator[2] += current_norm
        # Unclear whether the running mean should be updated when the norm
        # falls below the floor; it currently is.
        new_mean, new_std = update_running_stats(current_norm, mu, std, o.momentum)
        return (new_mean, new_std, minthresh), dx
    end
    if mu == 0
        # First norm above the floor: seed the running statistics with it.
        o.accumulator[2] += current_norm
        return (current_norm, current_norm, minthresh), dx
    end
    threshold = mu + o.tau * std
    if current_norm > threshold
        lambda = T(threshold / (current_norm + o.epsilon))
        clipped_norm = current_norm * lambda
        if current_norm > threshold * o.heavyclipthresh
            println("Heavy clipping on $(size(utdx)): ", current_norm, " -> ", clipped_norm,
                    " with mu ", mu, " and std ", std)
        end
        new_mean, new_std = update_running_stats(clipped_norm, mu, std, o.momentum)
        o.accumulator[2] += clipped_norm
        return (new_mean, new_std, minthresh), dx * lambda
    end
    o.accumulator[2] += current_norm
    new_mean, new_std = update_running_stats(current_norm, mu, std, o.momentum)
    return (new_mean, new_std, minthresh), dx
end
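To make the rescaling branch concrete, a worked example with hypothetical
numbers: with `mu = 1.0`, `std = 0.5`, and `τ = 2.0`, the threshold is `2.0`,
so a gradient of norm `4.0` is scaled by `λ ≈ 0.5` and lands on the threshold.

```julia
mu, sd, tau, eps = 1.0, 0.5, 2.0, 1e-8
current_norm = 4.0
threshold = mu + tau * sd                  # 2.0
lambda = threshold / (current_norm + eps)  # ≈ 0.5
current_norm * lambda                      # ≈ 2.0, the clipped norm
```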
"""
    Apollo(opt::AdamW = AdamW(), r::Function = dim -> ceil(Int, sqrt(dim)); u = 100, sort_dims = true)
    Apollo(η::Real, args...; kw...)