Skip to content

Commit 024d97d

Browse files
Make secondorder handling explicit
1 parent f8d61fd commit 024d97d

File tree

5 files changed

+77
-68
lines changed

5 files changed

+77
-68
lines changed

ext/OptimizationZygoteExt.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ using ADTypes, SciMLBase
1919
import Zygote
2020

2121
function OptimizationBase.instantiate_function(
22-
f::OptimizationFunction{true}, x, adtype::ADTypes.AutoZygote,
22+
f::OptimizationFunction{true}, x,
23+
adtype::Union{ADTypes.AutoZygote,
24+
DifferentiationInterface.SecondOrder{
25+
<:ADTypes.AbstractADType, <:ADTypes.AutoZygote}},
2326
p = SciMLBase.NullParameters(), num_cons = 0;
2427
g = false, h = false, hv = false, fg = false, fgh = false,
2528
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
@@ -280,7 +283,10 @@ function OptimizationBase.instantiate_function(
280283
end
281284

282285
function OptimizationBase.instantiate_function(
283-
f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AutoZygote},
286+
f::OptimizationFunction{true}, x,
287+
adtype::ADTypes.AutoSparse{<:Union{ADTypes.AutoZygote,
288+
DifferentiationInterface.SecondOrder{
289+
<:ADTypes.AbstractADType, <:ADTypes.AutoZygote}}},
284290
p = SciMLBase.NullParameters(), num_cons = 0;
285291
g = false, h = false, hv = false, fg = false, fgh = false,
286292
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,

src/adtypes.jl

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ Hessian is not defined via Zygote.
220220
AutoZygote
221221

222222
function generate_adtype(adtype)
223-
if !(adtype isa SciMLBase.NoAD && adtype isa DifferentiationInterface.SecondOrder)
223+
if !(adtype isa SciMLBase.NoAD || adtype isa DifferentiationInterface.SecondOrder)
224224
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)
225225
elseif adtype isa DifferentiationInterface.SecondOrder
226226
soadtype = adtype
@@ -232,48 +232,49 @@ function generate_adtype(adtype)
232232
return adtype, soadtype
233233
end
234234

235-
function generate_sparse_adtype(adtype)
235+
function spadtype_to_spsoadtype(adtype)
236+
if !(adtype.dense_ad isa SciMLBase.NoAD ||
237+
adtype.dense_ad isa DifferentiationInterface.SecondOrder)
238+
soadtype = AutoSparse(
239+
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
240+
sparsity_detector = adtype.sparsity_detector,
241+
coloring_algorithm = adtype.coloring_algorithm)
242+
else
243+
soadtype = adtype
244+
end
245+
return soadtype
246+
end
247+
248+
function filled_spad(adtype)
236249
if adtype.sparsity_detector isa ADTypes.NoSparsityDetector &&
237-
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
250+
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
238251
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
239252
coloring_algorithm = GreedyColoringAlgorithm())
240-
if !(adtype.dense_ad isa SciMLBase.NoAD &&
241-
adtype.dense_ad isa DifferentiationInterface.SecondOrder)
242-
soadtype = AutoSparse(
243-
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
244-
sparsity_detector = TracerSparsityDetector(),
245-
coloring_algorithm = GreedyColoringAlgorithm())
246-
end
247253
elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector &&
248254
!(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm)
249255
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
250256
coloring_algorithm = adtype.coloring_algorithm)
251-
if !(adtype.dense_ad isa SciMLBase.NoAD &&
252-
adtype.dense_ad isa DifferentiationInterface.SecondOrder)
253-
soadtype = AutoSparse(
254-
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
255-
sparsity_detector = TracerSparsityDetector(),
256-
coloring_algorithm = adtype.coloring_algorithm)
257-
end
258257
elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) &&
259258
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
260259
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector,
261260
coloring_algorithm = GreedyColoringAlgorithm())
262-
if !(adtype.dense_ad isa SciMLBase.NoAD &&
263-
adtype.dense_ad isa DifferentiationInterface.SecondOrder)
264-
soadtype = AutoSparse(
265-
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
266-
sparsity_detector = adtype.sparsity_detector,
267-
coloring_algorithm = GreedyColoringAlgorithm())
268-
end
261+
end
262+
end
263+
264+
function generate_sparse_adtype(adtype)
265+
266+
if !(adtype.dense_ad isa DifferentiationInterface.SecondOrder)
267+
adtype = filled_spad(adtype)
268+
soadtype = spadtype_to_spsoadtype(adtype)
269269
else
270-
if !(adtype.dense_ad isa SciMLBase.NoAD &&
271-
adtype.dense_ad isa DifferentiationInterface.SecondOrder)
272-
soadtype = AutoSparse(
273-
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
274-
sparsity_detector = adtype.sparsity_detector,
275-
coloring_algorithm = adtype.coloring_algorithm)
276-
end
270+
soadtype = adtype
271+
adtype = AutoSparse(
272+
adtype.dense_ad.inner,
273+
sparsity_detector = soadtype.sparsity_detector,
274+
coloring_algorithm = soadtype.coloring_algorithm)
275+
adtype = filled_spad(adtype)
276+
soadtype = filled_spad(soadtype)
277277
end
278+
278279
return adtype, soadtype
279280
end

src/cache.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;
4242

4343
num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)
4444

45+
if !(prob.f.adtype isa DifferentiationInterface.SecondOrder) &&
46+
(SciMLBase.requireshessian(opt) || SciMLBase.requiresconshess(opt) ||
47+
SciMLBase.requireslagh(opt))
48+
@warn "The selected optimization algorithm requires second order derivatives, but `SecondOrder` ADtype was not provided.
49+
So a `SecondOrder` with $adtype for both inner and outer will be created; this can be suboptimal and may not work in all cases, so
50+
an explicit `SecondOrder` ADtype is recommended."
51+
end
52+
4553
f = OptimizationBase.instantiate_function(
4654
prob.f, reinit_cache, prob.f.adtype, num_cons;
4755
g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt),

test/adtests.jl

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ optprob.cons_h(H3, x0)
105105
μ = randn(1)
106106
σ = rand()
107107
optprob.lag_h(H4, x0, σ, μ)
108-
@test H4 ≈ σ * H1 + μ[1] * H3[1] rtol=1e-6
108+
@test H4 ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6
109109

110110
G2 = Array{Float64}(undef, 2)
111111
H2 = Array{Float64}(undef, 2, 2)
@@ -142,7 +142,7 @@ optprob.cons_h(H3, x0)
142142
μ = randn(1)
143143
σ = rand()
144144
optprob.lag_h(H4, x0, σ, μ)
145-
@test H4 ≈ σ * H1 + μ[1] * H3[1] rtol=1e-6
145+
@test H4 ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6
146146

147147
G2 = Array{Float64}(undef, 2)
148148
H2 = Array{Float64}(undef, 2, 2)
@@ -179,7 +179,7 @@ optprob.cons_h(H3, x0)
179179
μ = randn(1)
180180
σ = rand()
181181
optprob.lag_h(H4, x0, σ, μ)
182-
@test H4 ≈ σ * H1 + μ[1] * H3[1] rtol=1e-6
182+
@test H4 ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6
183183

184184
G2 = Array{Float64}(undef, 2)
185185
H2 = Array{Float64}(undef, 2, 2)
@@ -217,14 +217,15 @@ optprob.cons_h(H3, x0)
217217
μ = randn(1)
218218
σ = rand()
219219
optprob.lag_h(H4, x0, σ, μ)
220-
@test H4 ≈ σ * H1 + μ[1] * H3[1] rtol=1e-6
220+
@test H4 ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6
221221

222222
G2 = Array{Float64}(undef, 2)
223223
H2 = Array{Float64}(undef, 2, 2)
224224

225-
optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote(), cons = cons)
225+
optf = OptimizationFunction(
226+
rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()), cons = cons)
226227
optprob = OptimizationBase.instantiate_function(
227-
optf, x0, OptimizationBase.AutoZygote(),
228+
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
228229
nothing, 1, g = true, h = true, hv = true,
229230
cons_j = true, cons_h = true, cons_vjp = true,
230231
cons_jvp = true, lag_h = true)
@@ -254,14 +255,14 @@ optprob.cons_h(H3, x0)
254255
μ = randn(1)
255256
σ = rand()
256257
optprob.lag_h(H4, x0, σ, μ)
257-
@test H4 ≈ σ * H1 + μ[1] * H3[1] rtol=1e-6
258+
@test H4 ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6
258259

259260
G2 = Array{Float64}(undef, 2)
260261
H2 = Array{Float64}(undef, 2, 2)
261262

262-
optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoFiniteDiff(), cons = cons)
263+
optf = OptimizationFunction(rosenbrock, DifferentiationInterface.SecondOrder(ADTypes.AutoFiniteDiff(), ADTypes.AutoReverseDiff()), cons = cons)
263264
optprob = OptimizationBase.instantiate_function(
264-
optf, x0, OptimizationBase.AutoFiniteDiff(),
265+
optf, x0, DifferentiationInterface.SecondOrder(ADTypes.AutoFiniteDiff(), ADTypes.AutoReverseDiff()),
265266
nothing, 1, g = true, h = true, hv = true,
266267
cons_j = true, cons_h = true, cons_vjp = true,
267268
cons_jvp = true, lag_h = true)
@@ -287,11 +288,12 @@ optprob.cons_h(H3, x0)
287288
H3 = [Array{Float64}(undef, 2, 2)]
288289
optprob.cons_h(H3, x0)
289290
@test H3[1] ≈ [2.0 0.0; 0.0 2.0] rtol=1e-5
291+
Random.seed!(123)
290292
H4 = Array{Float64}(undef, 2, 2)
291293
μ = randn(1)
292294
σ = rand()
293295
optprob.lag_h(H4, x0, σ, μ)
294-
@test H4 ≈ σ * H1 + μ[1] * H3[1] rtol=1e-6
296+
@test H4 ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6
295297
end
296298

297299
@testset "two constraints tests" begin
@@ -448,9 +450,10 @@ end
448450
G2 = Array{Float64}(undef, 2)
449451
H2 = Array{Float64}(undef, 2, 2)
450452

451-
optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote(), cons = con2_c)
453+
optf = OptimizationFunction(
454+
rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()), cons = con2_c)
452455
optprob = OptimizationBase.instantiate_function(
453-
optf, x0, OptimizationBase.AutoZygote(),
456+
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
454457
nothing, 2, g = true, h = true, hv = true,
455458
cons_j = true, cons_h = true, cons_vjp = true,
456459
cons_jvp = true, lag_h = true)
@@ -486,9 +489,9 @@ end
486489
H2 = Array{Float64}(undef, 2, 2)
487490

488491
optf = OptimizationFunction(
489-
rosenbrock, OptimizationBase.AutoFiniteDiff(), cons = con2_c)
492+
rosenbrock, DifferentiationInterface.SecondOrder(ADTypes.AutoFiniteDiff(), ADTypes.AutoReverseDiff()), cons = con2_c)
490493
optprob = OptimizationBase.instantiate_function(
491-
optf, x0, OptimizationBase.AutoFiniteDiff(),
494+
optf, x0, DifferentiationInterface.SecondOrder(ADTypes.AutoFiniteDiff(), ADTypes.AutoReverseDiff()),
492495
nothing, 2, g = true, h = true, hv = true,
493496
cons_j = true, cons_h = true, cons_vjp = true,
494497
cons_jvp = true, lag_h = true)
@@ -734,12 +737,12 @@ end
734737
@test lag_H lag_H_expected
735738
@test nnz(lag_H) == 5
736739

737-
optf = OptimizationFunction(sparse_objective, OptimizationBase.AutoSparseZygote(),
740+
optf = OptimizationFunction(sparse_objective, AutoSparse(DifferentiationInterface.SecondOrder(ADTypes.AutoForwardDiff(), ADTypes.AutoZygote())),
738741
cons = sparse_constraints)
739742

740743
# Instantiate the optimization problem
741744
optprob = OptimizationBase.instantiate_function(optf, x0,
742-
OptimizationBase.AutoSparseZygote(),
745+
AutoSparse(DifferentiationInterface.SecondOrder(ADTypes.AutoForwardDiff(), ADTypes.AutoZygote())),
743746
nothing, 2, g = true, h = true, cons_j = true, cons_h = true, lag_h = true)
744747
# Test gradient
745748
G = zeros(3)
@@ -1065,10 +1068,10 @@ end
10651068

10661069
cons = (x, p) -> [x[1]^2 + x[2]^2]
10671070
optf = OptimizationFunction{false}(rosenbrock,
1068-
OptimizationBase.AutoZygote(),
1071+
SecondOrder(AutoForwardDiff(), AutoZygote()),
10691072
cons = cons)
10701073
optprob = OptimizationBase.instantiate_function(
1071-
optf, x0, OptimizationBase.AutoZygote(),
1074+
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
10721075
nothing, 1, g = true, h = true, cons_j = true, cons_h = true)
10731076

10741077
@test optprob.grad(x0) == G1
@@ -1081,10 +1084,10 @@ end
10811084

10821085
cons = (x, p) -> [x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]]
10831086
optf = OptimizationFunction{false}(rosenbrock,
1084-
OptimizationBase.AutoZygote(),
1087+
SecondOrder(AutoForwardDiff(), AutoZygote()),
10851088
cons = cons)
10861089
optprob = OptimizationBase.instantiate_function(
1087-
optf, x0, OptimizationBase.AutoZygote(),
1090+
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
10881091
nothing, 2, g = true, h = true, cons_j = true, cons_h = true)
10891092

10901093
@test optprob.grad(x0) == G1

test/matrixvalued.jl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ using OptimizationBase, LinearAlgebra, ForwardDiff, Zygote, FiniteDiff,
33
using Test, ReverseDiff
44

55
@testset "Matrix Valued" begin
6-
for adtype in [AutoForwardDiff(), AutoZygote(), AutoFiniteDiff(),
6+
for adtype in [AutoForwardDiff(), SecondOrder(AutoForwardDiff(), AutoZygote()), SecondOrder(AutoForwardDiff(), AutoFiniteDiff()),
77
AutoSparse(AutoForwardDiff(), sparsity_detector = TracerLocalSparsityDetector()),
8-
AutoSparse(AutoZygote(), sparsity_detector = TracerLocalSparsityDetector()),
9-
AutoSparse(AutoFiniteDiff(), sparsity_detector = TracerLocalSparsityDetector())]
8+
AutoSparse(SecondOrder(AutoForwardDiff(), AutoZygote()), sparsity_detector = TracerLocalSparsityDetector()),
9+
AutoSparse(SecondOrder(AutoForwardDiff(), AutoFiniteDiff()), sparsity_detector = TracerLocalSparsityDetector())]
1010
# 1. Matrix Factorization
11+
@show adtype
1112
function matrix_factorization_objective(X, A)
1213
U, V = @view(X[1:size(A, 1), 1:Int(size(A, 2) / 2)]),
1314
@view(X[1:size(A, 1), (Int(size(A, 2) / 2) + 1):size(A, 2)])
@@ -28,12 +29,7 @@ using Test, ReverseDiff
2829
cons_j = true, cons_h = true)
2930
optf.grad(hcat(U_mf, V_mf))
3031
optf.hess(hcat(U_mf, V_mf))
31-
if adtype != AutoSparse(
32-
AutoForwardDiff(), sparsity_detector = TracerLocalSparsityDetector()) &&
33-
adtype !=
34-
AutoSparse(AutoZygote(), sparsity_detector = TracerLocalSparsityDetector()) &&
35-
adtype !=
36-
AutoSparse(AutoFiniteDiff(), sparsity_detector = TracerLocalSparsityDetector())
32+
if !(adtype isa ADTypes.AutoSparse)
3733
optf.cons_j(hcat(U_mf, V_mf))
3834
optf.cons_h(hcat(U_mf, V_mf))
3935
end
@@ -55,12 +51,7 @@ using Test, ReverseDiff
5551
cons_j = true, cons_h = true)
5652
optf.grad(X_pca)
5753
optf.hess(X_pca)
58-
if adtype != AutoSparse(
59-
AutoForwardDiff(), sparsity_detector = TracerLocalSparsityDetector()) &&
60-
adtype !=
61-
AutoSparse(AutoZygote(), sparsity_detector = TracerLocalSparsityDetector()) &&
62-
adtype !=
63-
AutoSparse(AutoFiniteDiff(), sparsity_detector = TracerLocalSparsityDetector())
54+
if !(adtype isa ADTypes.AutoSparse)
6455
optf.cons_j(X_pca)
6556
optf.cons_h(X_pca)
6657
end

0 commit comments

Comments
 (0)