Skip to content

Commit 3a592dc

Browse files
Merge pull request #108 from gdalle/gd/di_v06
Update to DifferentiationInterface v0.6
2 parents e30f933 + e4e34d5 commit 3a592dc

File tree

5 files changed

+2014
-131
lines changed

5 files changed

+2014
-131
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ OptimizationReverseDiffExt = "ReverseDiff"
3838
OptimizationZygoteExt = "Zygote"
3939

4040
[compat]
41-
ADTypes = "1.5"
41+
ADTypes = "1.9"
4242
ArrayInterface = "7.6"
43-
DifferentiationInterface = "0.5"
43+
DifferentiationInterface = "0.6.1"
4444
DocStringExtensions = "0.9"
45-
Enzyme = "0.12.12"
45+
Enzyme = "0.13.2"
4646
FastClosures = "0.3"
4747
FiniteDiff = "2.12"
4848
ForwardDiff = "0.10.26"

src/OptimizationDIExt.jl

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ function instantiate_function(
3939
adtype, soadtype = generate_adtype(adtype)
4040

4141
if g == true && f.grad === nothing
42-
extras_grad = prepare_gradient(_f, adtype, x)
42+
prep_grad = prepare_gradient(_f, adtype, x)
4343
function grad(res, θ)
44-
gradient!(_f, res, adtype, θ, extras_grad)
44+
gradient!(_f, res, prep_grad, adtype, θ)
4545
end
4646
if p !== SciMLBase.NullParameters() && p !== nothing
4747
function grad(res, θ, p)
@@ -57,10 +57,10 @@ function instantiate_function(
5757

5858
if fg == true && f.fg === nothing
5959
if g == false
60-
extras_grad = prepare_gradient(_f, adtype, x)
60+
prep_grad = prepare_gradient(_f, adtype, x)
6161
end
6262
function fg!(res, θ)
63-
(y, _) = value_and_gradient!(_f, res, adtype, θ, extras_grad)
63+
(y, _) = value_and_gradient!(_f, res, prep_grad, adtype, θ)
6464
return y
6565
end
6666
if p !== SciMLBase.NullParameters() && p !== nothing
@@ -79,9 +79,9 @@ function instantiate_function(
7979
hess_sparsity = f.hess_prototype
8080
hess_colors = f.hess_colorvec
8181
if h == true && f.hess === nothing
82-
extras_hess = prepare_hessian(_f, soadtype, x)
82+
prep_hess = prepare_hessian(_f, soadtype, x)
8383
function hess(res, θ)
84-
hessian!(_f, res, soadtype, θ, extras_hess)
84+
hessian!(_f, res, prep_hess, soadtype, θ)
8585
end
8686
if p !== SciMLBase.NullParameters() && p !== nothing
8787
function hess(res, θ, p)
@@ -98,7 +98,7 @@ function instantiate_function(
9898
if fgh == true && f.fgh === nothing
9999
function fgh!(G, H, θ)
100100
(y, _, _) = value_derivative_and_second_derivative!(
101-
_f, G, H, soadtype, θ, extras_hess)
101+
_f, G, H, prep_hess, soadtype, θ)
102102
return y
103103
end
104104
if p !== SciMLBase.NullParameters() && p !== nothing
@@ -116,14 +116,14 @@ function instantiate_function(
116116
end
117117

118118
if hv == true && f.hv === nothing
119-
extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x)))
119+
prep_hvp = prepare_hvp(_f, soadtype, x, (zeros(eltype(x), size(x)),))
120120
function hv!(H, θ, v)
121-
hvp!(_f, H, soadtype, θ, v, extras_hvp)
121+
only(hvp!(_f, (H,), prep_hvp, soadtype, θ, (v,)))
122122
end
123123
if p !== SciMLBase.NullParameters() && p !== nothing
124124
function hv!(H, θ, v, p)
125125
global _p = p
126-
hvp!(_f, H, soadtype, θ, v)
126+
only(hvp!(_f, (H,), soadtype, θ, (v,)))
127127
end
128128
end
129129
elseif hv == true
@@ -156,9 +156,9 @@ function instantiate_function(
156156
cons_jac_prototype = f.cons_jac_prototype
157157
cons_jac_colorvec = f.cons_jac_colorvec
158158
if cons !== nothing && cons_j == true && f.cons_j === nothing
159-
extras_jac = prepare_jacobian(cons_oop, adtype, x)
159+
prep_jac = prepare_jacobian(cons_oop, adtype, x)
160160
function cons_j!(J, θ)
161-
jacobian!(cons_oop, J, adtype, θ, extras_jac)
161+
jacobian!(cons_oop, J, prep_jac, adtype, θ)
162162
if size(J, 1) == 1
163163
J = vec(J)
164164
end
@@ -170,9 +170,9 @@ function instantiate_function(
170170
end
171171

172172
if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
173-
extras_pullback = prepare_pullback(cons_oop, adtype, x, ones(eltype(x), num_cons))
173+
prep_pullback = prepare_pullback(cons_oop, adtype, x, (ones(eltype(x), num_cons),))
174174
function cons_vjp!(J, θ, v)
175-
pullback!(cons_oop, J, adtype, θ, v, extras_pullback)
175+
only(pullback!(cons_oop, (J,), prep_pullback, adtype, θ, (v,)))
176176
end
177177
elseif cons_vjp == true && cons !== nothing
178178
cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p)
@@ -181,10 +181,10 @@ function instantiate_function(
181181
end
182182

183183
if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
184-
extras_pushforward = prepare_pushforward(
185-
cons_oop, adtype, x, ones(eltype(x), length(x)))
184+
prep_pushforward = prepare_pushforward(
185+
cons_oop, adtype, x, (ones(eltype(x), length(x)),))
186186
function cons_jvp!(J, θ, v)
187-
pushforward!(cons_oop, J, adtype, θ, v, extras_pushforward)
187+
only(pushforward!(cons_oop, (J,), prep_pushforward, adtype, θ, (v,)))
188188
end
189189
elseif cons_jvp == true && cons !== nothing
190190
cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p)
@@ -196,11 +196,11 @@ function instantiate_function(
196196
conshess_colors = f.cons_hess_colorvec
197197
if cons !== nothing && f.cons_h === nothing && cons_h == true
198198
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
199-
extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x))
199+
prep_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x))
200200

201201
function cons_h!(H, θ)
202202
for i in 1:num_cons
203-
hessian!(fncs[i], H[i], soadtype, θ, extras_cons_hess[i])
203+
hessian!(fncs[i], H[i], prep_cons_hess[i], soadtype, θ)
204204
end
205205
end
206206
elseif cons_h == true && cons !== nothing
@@ -212,7 +212,7 @@ function instantiate_function(
212212
lag_hess_prototype = f.lag_hess_prototype
213213

214214
if cons !== nothing && lag_h == true && f.lag_h === nothing
215-
lag_extras = prepare_hessian(
215+
lag_prep = prepare_hessian(
216216
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
217217
lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)
218218

@@ -221,13 +221,13 @@ function instantiate_function(
221221
cons_h(H, θ)
222222
H *= λ
223223
else
224-
H .= @view(hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)[
224+
H .= @view(hessian(lagrangian, lag_prep, soadtype, vcat(θ, [σ], λ))[
225225
1:length(θ), 1:length(θ)])
226226
end
227227
end
228228

229229
function lag_h!(h::AbstractVector, θ, σ, λ)
230-
H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)
230+
H = hessian(lagrangian, lag_prep, soadtype, vcat(θ, [σ], λ))
231231
k = 0
232232
for i in 1:length(θ)
233233
for j in 1:i
@@ -244,14 +244,14 @@ function instantiate_function(
244244
H *= λ
245245
else
246246
global _p = p
247-
H .= @view(hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)[
247+
H .= @view(hessian(lagrangian, lag_prep, soadtype, vcat(θ, [σ], λ))[
248248
1:length(θ), 1:length(θ)])
249249
end
250250
end
251251

252252
function lag_h!(h::AbstractVector, θ, σ, λ, p)
253253
global _p = p
254-
H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)
254+
H = hessian(lagrangian, lag_prep, soadtype, vcat(θ, [σ], λ))
255255
k = 0
256256
for i in 1:length(θ)
257257
for j in 1:i
@@ -308,9 +308,9 @@ function instantiate_function(
308308
adtype, soadtype = generate_adtype(adtype)
309309

310310
if g == true && f.grad === nothing
311-
extras_grad = prepare_gradient(_f, adtype, x)
311+
prep_grad = prepare_gradient(_f, adtype, x)
312312
function grad(θ)
313-
gradient(_f, adtype, θ, extras_grad)
313+
gradient(_f, prep_grad, adtype, θ)
314314
end
315315
if p !== SciMLBase.NullParameters() && p !== nothing
316316
function grad(θ, p)
@@ -326,10 +326,10 @@ function instantiate_function(
326326

327327
if fg == true && f.fg === nothing
328328
if g == false
329-
extras_grad = prepare_gradient(_f, adtype, x)
329+
prep_grad = prepare_gradient(_f, adtype, x)
330330
end
331331
function fg!(θ)
332-
(y, res) = value_and_gradient(_f, adtype, θ, extras_grad)
332+
(y, res) = value_and_gradient(_f, prep_grad, adtype, θ)
333333
return y, res
334334
end
335335
if p !== SciMLBase.NullParameters() && p !== nothing
@@ -348,9 +348,9 @@ function instantiate_function(
348348
hess_sparsity = f.hess_prototype
349349
hess_colors = f.hess_colorvec
350350
if h == true && f.hess === nothing
351-
extras_hess = prepare_hessian(_f, soadtype, x)
351+
prep_hess = prepare_hessian(_f, soadtype, x)
352352
function hess(θ)
353-
hessian(_f, soadtype, θ, extras_hess)
353+
hessian(_f, prep_hess, soadtype, θ)
354354
end
355355
if p !== SciMLBase.NullParameters() && p !== nothing
356356
function hess(θ, p)
@@ -366,7 +366,7 @@ function instantiate_function(
366366

367367
if fgh == true && f.fgh === nothing
368368
function fgh!(θ)
369-
(y, G, H) = value_derivative_and_second_derivative(_f, adtype, θ, extras_hess)
369+
(y, G, H) = value_derivative_and_second_derivative(_f, prep_hess, adtype, θ)
370370
return y, G, H
371371
end
372372
if p !== SciMLBase.NullParameters() && p !== nothing
@@ -383,14 +383,14 @@ function instantiate_function(
383383
end
384384

385385
if hv == true && f.hv === nothing
386-
extras_hvp = prepare_hvp(_f, soadtype, x, zeros(eltype(x), size(x)))
386+
prep_hvp = prepare_hvp(_f, soadtype, x, (zeros(eltype(x), size(x)),))
387387
function hv!(θ, v)
388-
hvp(_f, soadtype, θ, v, extras_hvp)
388+
only(hvp(_f, prep_hvp, soadtype, θ, (v)))
389389
end
390390
if p !== SciMLBase.NullParameters() && p !== nothing
391391
function hv!(θ, v, p)
392392
global _p = p
393-
hvp(_f, soadtype, θ, v, extras_hvp)
393+
only(vp(_f, prep_hvp, soadtype, θ, (v,)))
394394
end
395395
end
396396
elseif hv == true
@@ -417,9 +417,9 @@ function instantiate_function(
417417
cons_jac_prototype = f.cons_jac_prototype
418418
cons_jac_colorvec = f.cons_jac_colorvec
419419
if cons !== nothing && cons_j == true && f.cons_j === nothing
420-
extras_jac = prepare_jacobian(cons, adtype, x)
420+
prep_jac = prepare_jacobian(cons, adtype, x)
421421
function cons_j!(θ)
422-
J = jacobian(cons, adtype, θ, extras_jac)
422+
J = jacobian(cons, prep_jac, adtype, θ)
423423
if size(J, 1) == 1
424424
J = vec(J)
425425
end
@@ -432,9 +432,9 @@ function instantiate_function(
432432
end
433433

434434
if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
435-
extras_pullback = prepare_pullback(cons, adtype, x, ones(eltype(x), num_cons))
435+
prep_pullback = prepare_pullback(cons, adtype, x, (ones(eltype(x), num_cons),))
436436
function cons_vjp!(θ, v)
437-
return pullback(cons, adtype, θ, v, extras_pullback)
437+
return only(pullback(cons, prep_pullback, adtype, θ, (v,)))
438438
end
439439
elseif cons_vjp == true && cons !== nothing
440440
cons_vjp! = (θ, v) -> f.cons_vjp(θ, v, p)
@@ -443,10 +443,10 @@ function instantiate_function(
443443
end
444444

445445
if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
446-
extras_pushforward = prepare_pushforward(
447-
cons, adtype, x, ones(eltype(x), length(x)))
446+
prep_pushforward = prepare_pushforward(
447+
cons, adtype, x, (ones(eltype(x), length(x)),))
448448
function cons_jvp!(θ, v)
449-
return pushforward(cons, adtype, θ, v, extras_pushforward)
449+
return only(pushforward(cons, prep_pushforward, adtype, θ, (v,)))
450450
end
451451
elseif cons_jvp == true && cons !== nothing
452452
cons_jvp! = (θ, v) -> f.cons_jvp(θ, v, p)
@@ -458,11 +458,11 @@ function instantiate_function(
458458
conshess_colors = f.cons_hess_colorvec
459459
if cons !== nothing && cons_h == true && f.cons_h === nothing
460460
fncs = [(x) -> cons(x)[i] for i in 1:num_cons]
461-
extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x))
461+
prep_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x))
462462

463463
function cons_h!(θ)
464464
H = map(1:num_cons) do i
465-
hessian(fncs[i], soadtype, θ, extras_cons_hess[i])
465+
hessian(fncs[i], prep_cons_hess[i], soadtype, θ)
466466
end
467467
return H
468468
end
@@ -475,15 +475,15 @@ function instantiate_function(
475475
lag_hess_prototype = f.lag_hess_prototype
476476

477477
if cons !== nothing && lag_h == true && f.lag_h === nothing
478-
lag_extras = prepare_hessian(
478+
lag_prep = prepare_hessian(
479479
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
480480
lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)
481481

482482
function lag_h!(θ, σ, λ)
483483
if σ == zero(eltype(θ))
484484
return λ .* cons_h(θ)
485485
else
486-
return hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)[
486+
return hessian(lagrangian, lag_prep, soadtype, vcat(θ, [σ], λ))[
487487
1:length(θ), 1:length(θ)]
488488
end
489489
end
@@ -494,7 +494,7 @@ function instantiate_function(
494494
return λ .* cons_h(θ)
495495
else
496496
global _p = p
497-
return hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)[
497+
return hessian(lagrangian, lag_prep, soadtype, vcat(θ, [σ], λ))[
498498
1:length(θ), 1:length(θ)]
499499
end
500500
end

0 commit comments

Comments
 (0)