Skip to content

Commit 2aadf02

Browse files
format
1 parent c850508 commit 2aadf02

File tree

5 files changed

+106
-58
lines changed

5 files changed

+106
-58
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi
5858
end
5959

6060
function cons_f2(x, dx, fcons, p, num_cons, i)
61-
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
61+
Enzyme.autodiff_deferred(
62+
Enzyme.Reverse, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
6263
Const(fcons), Const(p), Const(num_cons), Const(i))
6364
return nothing
6465
end
@@ -83,7 +84,8 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))
8384
end
8485

8586
function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
86-
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
87+
Enzyme.autodiff_deferred(
88+
Enzyme.Reverse, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
8789
Const(_f), Const(cons), Const(p), Const(λ), Const(σ))
8890
return nothing
8991
end

ext/OptimizationZygoteExt.jl

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ function OptimizationBase.instantiate_function(
2424
g = false, h = false, hv = false, fg = false, fgh = false,
2525
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
2626
lag_h = false)
27-
2827
adtype, soadtype = OptimizationBase.generate_adtype(adtype)
2928

3029
if g == true && f.grad === nothing
@@ -83,12 +82,14 @@ function OptimizationBase.instantiate_function(
8382

8483
if fgh == true && f.fgh === nothing
8584
function fgh!(G, H, θ)
86-
(y, _, _) = value_derivative_and_second_derivative!(f.f, G, H, prep_hess, soadtype, θ, Constant(p))
85+
(y, _, _) = value_derivative_and_second_derivative!(
86+
f.f, G, H, prep_hess, soadtype, θ, Constant(p))
8787
return y
8888
end
8989
if p !== SciMLBase.NullParameters() && p !== nothing
9090
function fgh!(G, H, θ, p)
91-
(y, _, _) = value_derivative_and_second_derivative!(f.f, G, H, prep_hess, soadtype, θ, Constant(p))
91+
(y, _, _) = value_derivative_and_second_derivative!(
92+
f.f, G, H, prep_hess, soadtype, θ, Constant(p))
9293
return y
9394
end
9495
end
@@ -180,7 +181,8 @@ function OptimizationBase.instantiate_function(
180181
conshess_sparsity = f.cons_hess_prototype
181182
conshess_colors = f.cons_hess_colorvec
182183
if cons !== nothing && cons_h == true && f.cons_h === nothing
183-
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i)) for i in 1:num_cons]
184+
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
185+
for i in 1:num_cons]
184186

185187
function cons_h!(H, θ)
186188
for i in 1:num_cons
@@ -197,20 +199,23 @@ function OptimizationBase.instantiate_function(
197199

198200
if f.lag_h === nothing && cons !== nothing && lag_h == true
199201
lag_extras = prepare_hessian(
200-
lagrangian, soadtype, x, Constant(one(eltype(x))), Constant(ones(eltype(x), num_cons)), Constant(p))
202+
lagrangian, soadtype, x, Constant(one(eltype(x))),
203+
Constant(ones(eltype(x), num_cons)), Constant(p))
201204
lag_hess_prototype = zeros(Bool, num_cons, length(x))
202205

203206
function lag_h!(H::AbstractMatrix, θ, σ, λ)
204207
if σ == zero(eltype(θ))
205208
cons_h!(H, θ)
206209
H *= λ
207210
else
208-
hessian!(lagrangian, H, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
211+
hessian!(lagrangian, H, lag_extras, soadtype, θ,
212+
Constant(σ), Constant(λ), Constant(p))
209213
end
210214
end
211215

212216
function lag_h!(h::AbstractVector, θ, σ, λ)
213-
H = hessian(lagrangian, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
217+
H = hessian(
218+
lagrangian, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
214219
k = 0
215220
for i in 1:length(θ)
216221
for j in 1:i
@@ -226,12 +231,14 @@ function OptimizationBase.instantiate_function(
226231
cons_h(H, θ)
227232
H *= λ
228233
else
229-
hessian!(lagrangian, H, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
234+
hessian!(lagrangian, H, lag_extras, soadtype, θ,
235+
Constant(σ), Constant(λ), Constant(p))
230236
end
231237
end
232238

233239
function lag_h!(h::AbstractVector, θ, σ, λ, p)
234-
H = hessian(lagrangian, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
240+
H = hessian(lagrangian, lag_extras, soadtype, θ,
241+
Constant(σ), Constant(λ), Constant(p))
235242
k = 0
236243
for i in 1:length(θ)
237244
for j in 1:i
@@ -303,12 +310,14 @@ function OptimizationBase.instantiate_function(
303310
extras_grad = prepare_gradient(_f, adtype.dense_ad, x, Constant(p))
304311
end
305312
function fg!(res, θ)
306-
(y, _) = value_and_gradient!(_f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
313+
(y, _) = value_and_gradient!(
314+
_f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
307315
return y
308316
end
309317
if p !== SciMLBase.NullParameters() && p !== nothing
310318
function fg!(res, θ, p)
311-
(y, _) = value_and_gradient!(_f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
319+
(y, _) = value_and_gradient!(
320+
_f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
312321
return y
313322
end
314323
end
@@ -341,13 +350,15 @@ function OptimizationBase.instantiate_function(
341350

342351
if fgh == true && f.fgh === nothing
343352
function fgh!(G, H, θ)
344-
(y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, prep_hess, soadtype, Constant(p))
353+
(y, _, _) = value_derivative_and_second_derivative!(
354+
_f, G, H, θ, prep_hess, soadtype, Constant(p))
345355
return y
346356
end
347357

348358
if p !== SciMLBase.NullParameters() && p !== nothing
349359
function fgh!(G, H, θ, p)
350-
(y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, prep_hess, soadtype, Constant(p))
360+
(y, _, _) = value_derivative_and_second_derivative!(
361+
_f, G, H, θ, prep_hess, soadtype, Constant(p))
351362
return y
352363
end
353364
end
@@ -419,7 +430,8 @@ function OptimizationBase.instantiate_function(
419430
extras_pullback = prepare_pullback(
420431
cons_oop, adtype.dense_ad, x, (ones(eltype(x), num_cons),), Constant(p))
421432
function cons_vjp!(J, θ, v)
422-
pullback!(cons_oop, (J,), extras_pullback, adtype.dense_ad, θ, (v,), Constant(p))
433+
pullback!(
434+
cons_oop, (J,), extras_pullback, adtype.dense_ad, θ, (v,), Constant(p))
423435
end
424436
elseif cons_vjp == true
425437
cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p)
@@ -431,7 +443,8 @@ function OptimizationBase.instantiate_function(
431443
extras_pushforward = prepare_pushforward(
432444
cons_oop, adtype.dense_ad, x, (ones(eltype(x), length(x)),), Constant(p))
433445
function cons_jvp!(J, θ, v)
434-
pushforward!(cons_oop, (J,), extras_pushforward, adtype.dense_ad, θ, (v,), Constant(p))
446+
pushforward!(
447+
cons_oop, (J,), extras_pushforward, adtype.dense_ad, θ, (v,), Constant(p))
435448
end
436449
elseif cons_jvp == true
437450
cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p)
@@ -442,7 +455,8 @@ function OptimizationBase.instantiate_function(
442455
conshess_sparsity = f.cons_hess_prototype
443456
conshess_colors = f.cons_hess_colorvec
444457
if cons !== nothing && f.cons_h === nothing && cons_h == true
445-
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i)) for i in 1:num_cons]
458+
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
459+
for i in 1:num_cons]
446460
colores = getfield.(prep_cons_hess, :coloring_result)
447461
conshess_sparsity = getfield.(colores, :S)
448462
conshess_colors = getfield.(colores, :color)
@@ -461,7 +475,8 @@ function OptimizationBase.instantiate_function(
461475
lag_hess_colors = f.lag_hess_colorvec
462476
if cons !== nothing && f.lag_h === nothing && lag_h == true
463477
lag_extras = prepare_hessian(
464-
lagrangian, soadtype, x, Constant(one(eltype(x))), Constant(ones(eltype(x), num_cons)), Constant(p))
478+
lagrangian, soadtype, x, Constant(one(eltype(x))),
479+
Constant(ones(eltype(x), num_cons)), Constant(p))
465480
lag_hess_prototype = lag_extras.coloring_result.S
466481
lag_hess_colors = lag_extras.coloring_result.color
467482

@@ -470,12 +485,14 @@ function OptimizationBase.instantiate_function(
470485
cons_h(H, θ)
471486
H *= λ
472487
else
473-
hessian!(lagrangian, H, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
488+
hessian!(lagrangian, H, lag_extras, soadtype, θ,
489+
Constant(σ), Constant(λ), Constant(p))
474490
end
475491
end
476492

477493
function lag_h!(h, θ, σ, λ)
478-
H = hessian(lagrangian, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
494+
H = hessian(
495+
lagrangian, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
479496
k = 0
480497
rows, cols, _ = findnz(H)
481498
for (i, j) in zip(rows, cols)
@@ -492,12 +509,14 @@ function OptimizationBase.instantiate_function(
492509
cons_h!(H, θ)
493510
H *= λ
494511
else
495-
hessian!(lagrangian, H, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
512+
hessian!(lagrangian, H, lag_extras, soadtype, θ,
513+
Constant(σ), Constant(λ), Constant(p))
496514
end
497515
end
498516

499517
function lag_h!(h::AbstractVector, θ, σ, λ, p)
500-
H = hessian(lagrangian, lag_extras, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
518+
H = hessian(lagrangian, lag_extras, soadtype, θ,
519+
Constant(σ), Constant(λ), Constant(p))
501520
k = 0
502521
for i in 1:length(θ)
503522
for j in 1:i

src/OptimizationDIExt.jl

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,10 @@ function instantiate_function(
3131
g = false, h = false, hv = false, fg = false, fgh = false,
3232
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
3333
lag_h = false)
34-
3534
adtype, soadtype = generate_adtype(adtype)
3635

3736
if g == true && f.grad === nothing
38-
prep_grad = prepare_gradient(f.f, adtype, x, Constant(p))
37+
prep_grad = prepare_gradient(f.f, adtype, x, Constant(p))
3938
function grad(res, θ)
4039
gradient!(f.f, res, prep_grad, adtype, θ, Constant(p))
4140
end
@@ -183,7 +182,8 @@ function instantiate_function(
183182
conshess_sparsity = f.cons_hess_prototype
184183
conshess_colors = f.cons_hess_colorvec
185184
if f.cons !== nothing && f.cons_h === nothing && cons_h == true
186-
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i)) for i in 1:num_cons]
185+
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
186+
for i in 1:num_cons]
187187

188188
function cons_h!(H, θ)
189189
for i in 1:num_cons
@@ -200,20 +200,23 @@ function instantiate_function(
200200

201201
if f.cons !== nothing && lag_h == true && f.lag_h === nothing
202202
lag_prep = prepare_hessian(
203-
lagrangian, soadtype, x, Constant(one(eltype(x))), Constant(ones(eltype(x), num_cons)), Constant(p))
203+
lagrangian, soadtype, x, Constant(one(eltype(x))),
204+
Constant(ones(eltype(x), num_cons)), Constant(p))
204205
lag_hess_prototype = zeros(Bool, num_cons, length(x))
205206

206207
function lag_h!(H::AbstractMatrix, θ, σ, λ)
207208
if σ == zero(eltype(θ))
208209
cons_h!(H, θ)
209210
H *= λ
210211
else
211-
hessian!(lagrangian, H, lag_prep, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
212+
hessian!(lagrangian, H, lag_prep, soadtype, θ,
213+
Constant(σ), Constant(λ), Constant(p))
212214
end
213215
end
214216

215217
function lag_h!(h::AbstractVector, θ, σ, λ)
216-
H = hessian(lagrangian, lag_prep, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
218+
H = hessian(
219+
lagrangian, lag_prep, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
217220
k = 0
218221
for i in 1:length(θ)
219222
for j in 1:i
@@ -229,12 +232,14 @@ function instantiate_function(
229232
cons_h!(H, θ)
230233
H *= λ
231234
else
232-
hessian!(lagrangian, H, lag_prep, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
235+
hessian!(lagrangian, H, lag_prep, soadtype, θ,
236+
Constant(σ), Constant(λ), Constant(p))
233237
end
234238
end
235239

236240
function lag_h!(h::AbstractVector, θ, σ, λ, p)
237-
H = hessian(lagrangian, lag_prep, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
241+
H = hessian(lagrangian, lag_prep, soadtype, θ,
242+
Constant(σ), Constant(λ), Constant(p))
238243
k = 0
239244
for i in 1:length(θ)
240245
for j in 1:i
@@ -341,12 +346,14 @@ function instantiate_function(
341346

342347
if fgh == true && f.fgh === nothing
343348
function fgh!(θ)
344-
(y, G, H) = value_derivative_and_second_derivative(f.f, prep_hess, adtype, θ, Constant(p))
349+
(y, G, H) = value_derivative_and_second_derivative(
350+
f.f, prep_hess, adtype, θ, Constant(p))
345351
return y, G, H
346352
end
347353
if p !== SciMLBase.NullParameters() && p !== nothing
348354
function fgh!(θ, p)
349-
(y, G, H) = value_derivative_and_second_derivative(f.f, prep_hess, adtype, θ, Constant(p))
355+
(y, G, H) = value_derivative_and_second_derivative(
356+
f.f, prep_hess, adtype, θ, Constant(p))
350357
return y, G, H
351358
end
352359
end
@@ -396,7 +403,8 @@ function instantiate_function(
396403
end
397404

398405
if f.cons_vjp === nothing && cons_vjp == true && f.cons !== nothing
399-
prep_pullback = prepare_pullback(f.cons, adtype, x, (ones(eltype(x), num_cons),), Constant(p))
406+
prep_pullback = prepare_pullback(
407+
f.cons, adtype, x, (ones(eltype(x), num_cons),), Constant(p))
400408
function cons_vjp!(θ, v)
401409
return only(pullback(f.cons, prep_pullback, adtype, θ, (v,), Constant(p)))
402410
end
@@ -424,7 +432,8 @@ function instantiate_function(
424432
function cons_i(x, i)
425433
return f.cons(x, p)[i]
426434
end
427-
prep_cons_hess = [prepare_hessian(cons_i, soadtype, x, Constant(i)) for i in 1:num_cons]
435+
prep_cons_hess = [prepare_hessian(cons_i, soadtype, x, Constant(i))
436+
for i in 1:num_cons]
428437

429438
function cons_h!(θ)
430439
H = map(1:num_cons) do i
@@ -442,14 +451,16 @@ function instantiate_function(
442451

443452
if f.cons !== nothing && lag_h == true && f.lag_h === nothing
444453
lag_prep = prepare_hessian(
445-
lagrangian, soadtype, x, Constant(one(eltype(x))), Constant(ones(eltype(x), num_cons)), Constant(p))
454+
lagrangian, soadtype, x, Constant(one(eltype(x))),
455+
Constant(ones(eltype(x), num_cons)), Constant(p))
446456
lag_hess_prototype = zeros(Bool, num_cons, length(x))
447457

448458
function lag_h!(θ, σ, λ)
449459
if σ == zero(eltype(θ))
450460
return λ .* cons_h(θ)
451461
else
452-
return hessian(lagrangian, lag_prep, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
462+
return hessian(lagrangian, lag_prep, soadtype, θ,
463+
Constant(σ), Constant(λ), Constant(p))
453464
end
454465
end
455466

@@ -458,7 +469,8 @@ function instantiate_function(
458469
if σ == zero(eltype(θ))
459470
return λ .* cons_h(θ)
460471
else
461-
return hessian(lagrangian, lag_prep, soadtype, θ, Constant(σ), Constant(λ), Constant(p))
472+
return hessian(lagrangian, lag_prep, soadtype, θ,
473+
Constant(σ), Constant(λ), Constant(p))
462474
end
463475
end
464476
end

0 commit comments

Comments
 (0)