@@ -97,7 +97,7 @@ NormGrowthCap(τ = 1.01; ϵ = 1e-8, lb = 1e-7, throw = true, scale = true) = Nor
 init(o::NormGrowthCap, x::AbstractArray{T}) where T = T(0)
 
 function apply!(o::NormGrowthCap, state, x::AbstractArray{T}, dx) where T
-    current_norm = _norm(dx, 2)
+    current_norm = _norm(Optimisers.unthunk(dx), 2)
     if o.throw && !isfinite(current_norm)
         throw(DomainError("gradient has L2-norm $current_norm, for array $(summary(x))"))
     end
@@ -106,7 +106,7 @@ function apply!(o::NormGrowthCap, state, x::AbstractArray{T}, dx) where T
     else
         # If you're below the hard min, then don't scale
         if o.scale
-            minthresh = o.lb * sqrt(length(dx))
+            minthresh = o.lb * sqrt(length(Optimisers.unthunk(dx)))
         else
             minthresh = o.lb
         end
@@ -123,6 +123,71 @@ function apply!(o::NormGrowthCap, state, x::AbstractArray{T}, dx) where T
     end
 end
 
+
+"""
+    GradNormControl(accumulator, τ = 1.1; epsilon = 1e-8, lb = 0.1, throw = true, scale = true, clipreportthresh = Inf)
+
+NormGrowthCap with additional control, accumulation, and reporting options.
+`accumulator` must be a two-element `Float64` array into which the unscaled and scaled gradient norms are accumulated, so you can monitor their running sums. Printing and resetting the accumulator is the caller's responsibility.
+"""
+struct GradNormControl <: Optimisers.AbstractRule
+    tau::Float64
+    epsilon::Float64
+    lb::Float64 # Min grad norm, to stop a tensor getting stuck near zero
+    throw::Bool
+    scale::Bool
+    heavyclipthresh::Real
+    accumulator::AbstractVector{<:Float64}
+end
+
+function GradNormControl(accumulator, τ = 1.1; epsilon = 1e-8, lb = 0.1, throw = true, scale = true, clipreportthresh = Inf)
+    if length(accumulator) != 2
+        Base.throw(ArgumentError("accumulator must be an array with two elements, initialized to 0")) # qualify with Base: the `throw` keyword argument shadows Base.throw here
+    end
+    GradNormControl(τ, epsilon, lb, throw, scale, clipreportthresh, accumulator)
+end
+
+function init(o::GradNormControl, x::AbstractArray{T}) where T
+    if o.scale
+        minthresh = o.lb * sqrt(length(x))
+    else
+        minthresh = o.lb
+    end
+    return T(0), minthresh
+end
+
+function apply!(o::GradNormControl, state, x::AbstractArray{T}, dx) where T
+    prevnorm, minthresh = state
+    utdx = Optimisers.unthunk(dx)
+    current_norm = Optimisers._norm(utdx, 2)
+    o.accumulator[1] += current_norm
+    if o.throw && !isfinite(current_norm)
+        throw(DomainError("gradient has L2-norm $current_norm, for array $(summary(x))"))
+    end
+    if prevnorm == 0
+        o.accumulator[2] += current_norm
+        return (current_norm, minthresh), dx
+    else
+        if current_norm < minthresh
+            o.accumulator[2] += current_norm
+            return (current_norm, minthresh), dx
+        end
+        ratio = current_norm / (prevnorm + o.epsilon)
+        if ratio > o.tau
+            lambda = T((o.tau * prevnorm) / (current_norm + o.epsilon))
+            if ratio > o.tau * o.heavyclipthresh
+                println("Heavy clipping on $(size(utdx)): ", current_norm, " -> ", current_norm * lambda)
+            end
+            o.accumulator[2] += current_norm * lambda
+            return (current_norm * lambda, minthresh), dx * lambda
+        else
+            o.accumulator[2] += current_norm
+            return (current_norm, minthresh), dx
+        end
+    end
+end
+
+
 """
     Apollo(opt::AdamW = AdamW(), r::Function = dim -> ceil(Int, sqrt(dim)); u = 100, sort_dims = true)
     Apollo(η::Real, args...; kw...)
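# Usage sketch: one way GradNormControl could be driven, assuming it is chained with
# standard Optimisers.jl rules via OptimiserChain and that it is in scope. `model`,
# `grads`, and the logging cadence are illustrative placeholders, not names from
# this package.
using Optimisers

accum = zeros(Float64, 2)                          # [sum of unscaled norms, sum of scaled norms]
rule  = OptimiserChain(GradNormControl(accum, 1.1), AdamW(1e-3))
state = Optimisers.setup(rule, model)              # `model` is any Functors-compatible model

for (step, grad) in enumerate(grads)               # `grads` produced by your AD of choice
    state, model = Optimisers.update!(state, model, grad)
    if step % 100 == 0
        @info "gradient norms" unscaled = accum[1] scaled = accum[2]
        accum .= 0                                 # the caller resets the accumulator
    end
end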