
Commit 75be96e

Actually test sparse zygote
1 parent 2aadf02 commit 75be96e

3 files changed, +26 -26 lines changed


Project.toml

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 name = "OptimizationBase"
 uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
 authors = ["Vaibhav Dixit <[email protected]> and contributors"]
-version = "2.0.4"
+version = "2.1.0"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/OptimizationZygoteExt.jl

Lines changed: 24 additions & 24 deletions
@@ -290,13 +290,13 @@ function OptimizationBase.instantiate_function(
     adtype, soadtype = OptimizationBase.generate_sparse_adtype(adtype)

     if g == true && f.grad === nothing
-        extras_grad = prepare_gradient(_f, adtype.dense_ad, x, Constant(p))
+        extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p))
         function grad(res, θ)
-            gradient!(_f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
+            gradient!(f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
         end
         if p !== SciMLBase.NullParameters() && p !== nothing
             function grad(res, θ, p)
-                gradient!(_f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
+                gradient!(f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
             end
         end
     elseif g == true
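
For context, this hunk switches the prepared gradient to the raw two-argument objective f.f, with the parameters supplied through DifferentiationInterface's Constant context so they are carried along without being differentiated. A minimal standalone sketch of the prepare-once/reuse pattern (objective, backend choice, and values are illustrative, not from the commit):

using DifferentiationInterface, Zygote
using ADTypes: AutoZygote

# Two-argument objective, matching how OptimizationFunction stores it in f.f.
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

x = zeros(2)
p = [1.0, 100.0]
backend = AutoZygote()

# Prepare once at a representative point, then reuse the prep for every call.
prep = prepare_gradient(rosenbrock, backend, x, Constant(p))
G = similar(x)
gradient!(rosenbrock, G, prep, backend, x, Constant(p))
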
@@ -307,17 +307,17 @@ function OptimizationBase.instantiate_function(

     if fg == true && f.fg === nothing
         if g == false
-            extras_grad = prepare_gradient(_f, adtype.dense_ad, x, Constant(p))
+            extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p))
         end
         function fg!(res, θ)
             (y, _) = value_and_gradient!(
-                _f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
+                f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
             return y
         end
         if p !== SciMLBase.NullParameters() && p !== nothing
             function fg!(res, θ, p)
                 (y, _) = value_and_gradient!(
-                    _f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
+                    f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
                 return y
             end
         end
@@ -330,16 +330,16 @@ function OptimizationBase.instantiate_function(
     hess_sparsity = f.hess_prototype
     hess_colors = f.hess_colorvec
     if h == true && f.hess === nothing
-        prep_hess = prepare_hessian(_f, soadtype, x, Constant(p))
+        prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p))
         function hess(res, θ)
-            hessian!(_f, res, prep_hess, soadtype, θ, Constant(p))
+            hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p))
         end
-        hess_sparsity = extras_hess.coloring_result.S
-        hess_colors = extras_hess.coloring_result.color
+        hess_sparsity = prep_hess.coloring_result.S
+        hess_colors = prep_hess.coloring_result.color

         if p !== SciMLBase.NullParameters() && p !== nothing
             function hess(res, θ, p)
-                hessian!(_f, res, prep_hess, soadtype, θ, Constant(p))
+                hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p))
             end
         end
     elseif h == true
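
The renamed prep_hess is the preparation object returned by prepare_hessian; with a sparse second-order backend it carries the detected sparsity pattern (coloring_result.S) and column coloring (coloring_result.color) that get forwarded as hess_prototype and hess_colorvec. A rough sketch of that mechanism, assuming an AutoSparse backend like the one generate_sparse_adtype builds (coloring_result is internal API, read here exactly as the commit does):

using DifferentiationInterface, ForwardDiff
using ADTypes: AutoSparse, AutoForwardDiff
using SparseConnectivityTracer: TracerSparsityDetector
using SparseMatrixColorings: GreedyColoringAlgorithm

# Sparse second-order backend, analogous to what generate_sparse_adtype produces.
soadtype = AutoSparse(AutoForwardDiff();
    sparsity_detector = TracerSparsityDetector(),
    coloring_algorithm = GreedyColoringAlgorithm())

f(x, p) = sum(abs2, x) + p * x[2]^3
x = ones(4)

prep = prepare_hessian(f, soadtype, x, Constant(2.0))
H = hessian(f, prep, soadtype, x, Constant(2.0))  # sparse result

hess_sparsity = prep.coloring_result.S    # detected sparsity pattern
hess_colors = prep.coloring_result.color  # column coloring vector
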
@@ -351,14 +351,14 @@ function OptimizationBase.instantiate_function(
     if fgh == true && f.fgh === nothing
         function fgh!(G, H, θ)
             (y, _, _) = value_derivative_and_second_derivative!(
-                _f, G, H, θ, prep_hess, soadtype, Constant(p))
+                f.f, G, H, θ, prep_hess, soadtype, Constant(p))
             return y
         end

         if p !== SciMLBase.NullParameters() && p !== nothing
             function fgh!(G, H, θ, p)
                 (y, _, _) = value_derivative_and_second_derivative!(
-                    _f, G, H, θ, prep_hess, soadtype, Constant(p))
+                    f.f, G, H, θ, prep_hess, soadtype, Constant(p))
                 return y
             end
         end
@@ -371,11 +371,11 @@ function OptimizationBase.instantiate_function(
     if hv == true && f.hv === nothing
         prep_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x)))
         function hv!(H, θ, v)
-            hvp!(_f, H, prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p))
+            hvp!(f.f, (H,), prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p))
         end
         if p !== SciMLBase.NullParameters() && p !== nothing
             function hv!(H, θ, v, p)
-                hvp!(_f, H, prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p))
+                hvp!(f.f, (H,), prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p))
             end
         end
     elseif hv == true
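
The hv! fix wraps the result buffer in a one-element tuple because DifferentiationInterface's hvp! takes tuples of tangent seeds and result buffers, one entry per batched Hessian-vector product (the unchanged prepare_hvp context line above still passes a bare array; the sketch below uses the tuple form throughout). A condensed sketch of the calling convention (function, backend, and values are illustrative):

using DifferentiationInterface, ForwardDiff
using ADTypes: AutoForwardDiff

f(x, p) = sum(abs2, x .* p)
x = [1.0, 2.0]
v = [0.0, 1.0]
p = [3.0, 4.0]
backend = AutoForwardDiff()

# Seeds and result buffers both travel in tuples, one entry per tangent.
prep = prepare_hvp(f, backend, x, (v,), Constant(p))
Hv = similar(x)
hvp!(f, (Hv,), prep, backend, x, (v,), Constant(p))
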
@@ -411,15 +411,15 @@ function OptimizationBase.instantiate_function(
     cons_jac_prototype = f.cons_jac_prototype
     cons_jac_colorvec = f.cons_jac_colorvec
     if cons !== nothing && cons_j == true && f.cons_j === nothing
-        prep_jac = prepare_jacobian(cons_oop, adtype, x, Constant(p))
+        prep_jac = prepare_jacobian(cons_oop, adtype, x)
         function cons_j!(J, θ)
-            jacobian!(cons_oop, J, prep_jac, adtype, θ, Constant(p))
+            jacobian!(cons_oop, J, prep_jac, adtype, θ)
             if size(J, 1) == 1
                 J = vec(J)
             end
         end
-        cons_jac_prototype = extras_jac.coloring_result.S
-        cons_jac_colorvec = extras_jac.coloring_result.color
+        cons_jac_prototype = prep_jac.coloring_result.S
+        cons_jac_colorvec = prep_jac.coloring_result.color
     elseif cons !== nothing && cons_j == true
         cons_j! = (J, θ) -> f.cons_j(J, θ, p)
     else
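
In this hunk Constant(p) is dropped rather than renamed: cons_oop is a closure that already captures the parameters, so both the preparation and the call see a single-argument function. A sketch of the sparse-Jacobian pattern with such a closure (constraint function and backend invented for illustration; coloring_result read as in the commit):

using DifferentiationInterface, ForwardDiff
using ADTypes: AutoSparse, AutoForwardDiff
using SparseConnectivityTracer: TracerSparsityDetector
using SparseMatrixColorings: GreedyColoringAlgorithm

p = 1.0
cons_oop = x -> [x[1]^2 + x[2]^2 - p, x[1] * x[2]]  # closure over p

adtype = AutoSparse(AutoForwardDiff();
    sparsity_detector = TracerSparsityDetector(),
    coloring_algorithm = GreedyColoringAlgorithm())

x = [0.5, 0.5]
prep = prepare_jacobian(cons_oop, adtype, x)  # no Constant context needed
J = jacobian(cons_oop, prep, adtype, x)       # sparse result

cons_jac_prototype = prep.coloring_result.S
cons_jac_colorvec = prep.coloring_result.color
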
@@ -428,10 +428,10 @@ function OptimizationBase.instantiate_function(

     if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
         extras_pullback = prepare_pullback(
-            cons_oop, adtype.dense_ad, x, (ones(eltype(x), num_cons),), Constant(p))
+            cons_oop, adtype.dense_ad, x, (ones(eltype(x), num_cons),))
         function cons_vjp!(J, θ, v)
             pullback!(
-                cons_oop, (J,), extras_pullback, adtype.dense_ad, θ, (v,), Constant(p))
+                cons_oop, (J,), extras_pullback, adtype.dense_ad, θ, (v,))
         end
     elseif cons_vjp == true
         cons_vjp! = (J, θ, v) -> f.cons_vjp(J, θ, v, p)
@@ -441,10 +441,10 @@ function OptimizationBase.instantiate_function(

     if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
         extras_pushforward = prepare_pushforward(
-            cons_oop, adtype.dense_ad, x, (ones(eltype(x), length(x)),), Constant(p))
+            cons_oop, adtype.dense_ad, x, (ones(eltype(x), length(x)),))
         function cons_jvp!(J, θ, v)
             pushforward!(
-                cons_oop, (J,), extras_pushforward, adtype.dense_ad, θ, (v,), Constant(p))
+                cons_oop, (J,), extras_pushforward, adtype.dense_ad, θ, (v,))
         end
     elseif cons_jvp == true
         cons_jvp! = (J, θ, v) -> f.cons_jvp(J, θ, v, p)
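
The two product operators above follow the same tuple convention as hvp!: pullback! computes the vector-Jacobian product vᵀJ from output-space seeds, and pushforward! computes the Jacobian-vector product Jv from input-space seeds, each prepared once with seed tuples of the matching shape. A condensed sketch (closure, backend, and seeds illustrative):

using DifferentiationInterface, ForwardDiff
using ADTypes: AutoForwardDiff

cons_oop = x -> [x[1]^2 + x[2]^2, x[1] - x[2]]
x = [0.5, 0.5]
backend = AutoForwardDiff()

# VJP: seed lives in output space, result has the shape of x.
vjp_prep = prepare_pullback(cons_oop, backend, x, (ones(2),))
vJ = similar(x)
pullback!(cons_oop, (vJ,), vjp_prep, backend, x, ([1.0, 0.0],))

# JVP: seed lives in input space, result has the shape of the output.
jvp_prep = prepare_pushforward(cons_oop, backend, x, (ones(2),))
Jv = zeros(2)
pushforward!(cons_oop, (Jv,), jvp_prep, backend, x, ([1.0, 0.0],))
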
@@ -482,7 +482,7 @@ function OptimizationBase.instantiate_function(

     function lag_h!(H::AbstractMatrix, θ, σ, λ)
         if σ == zero(eltype(θ))
-            cons_h(H, θ)
+            cons_h!(H, θ)
             H *= λ
         else
             hessian!(lagrangian, H, lag_extras, soadtype, θ,
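
The final hunk fixes a stale name: the in-place constraint-Hessian closure built earlier in this function is cons_h!, not cons_h. In the σ == 0 branch the objective's contribution to the Lagrangian Hessian drops out, so only the λ-weighted constraint Hessians are assembled, while the general hessian! path handles the σ ≠ 0 case.
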

test/adtests.jl

Lines changed: 1 addition & 1 deletion
@@ -739,7 +739,7 @@ end

     # Instantiate the optimization problem
     optprob = OptimizationBase.instantiate_function(optf, x0,
-        OptimizationBase.AutoSparseForwardDiff(),
+        OptimizationBase.AutoSparseZygote(),
         nothing, 2, g = true, h = true, cons_j = true, cons_h = true, lag_h = true)
     # Test gradient
     G = zeros(3)
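
For reference, here is a runnable approximation of the surrounding test, with the objective, constraints, and starting point as illustrative stand-ins for the ones defined earlier in adtests.jl (the optprob.grad field access follows the instantiated-function convention used throughout the file):

using OptimizationBase, SciMLBase, Zygote

rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2 + x[3]^2
cons(res, x, p) = (res .= [x[1]^2 + x[2]^2, x[2] * sin(x[3])])

x0 = zeros(3)
optf = SciMLBase.OptimizationFunction(rosenbrock, cons = cons)

optprob = OptimizationBase.instantiate_function(optf, x0,
    OptimizationBase.AutoSparseZygote(),
    nothing, 2, g = true, h = true, cons_j = true, cons_h = true, lag_h = true)

G = zeros(3)
optprob.grad(G, x0)  # the gradient now flows through the sparse Zygote path
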
