Skip to content

Commit e10bed6

Browse files
Merge pull request #118 from SciML/secondorder
Remove automatic FoR `soadtype` creations
2 parents 1cb8a90 + 40208b6 commit e10bed6

File tree

6 files changed

+97
-116
lines changed

6 files changed

+97
-116
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimizationBase"
22
uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
33
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
4-
version = "2.2.1"
4+
version = "2.3.0"
55

66

77
[deps]

ext/OptimizationZygoteExt.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ using ADTypes, SciMLBase
1919
import Zygote
2020

2121
function OptimizationBase.instantiate_function(
22-
f::OptimizationFunction{true}, x, adtype::ADTypes.AutoZygote,
22+
f::OptimizationFunction{true}, x,
23+
adtype::Union{ADTypes.AutoZygote,
24+
DifferentiationInterface.SecondOrder{
25+
<:ADTypes.AbstractADType, <:ADTypes.AutoZygote}},
2326
p = SciMLBase.NullParameters(), num_cons = 0;
2427
g = false, h = false, hv = false, fg = false, fgh = false,
2528
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
@@ -280,7 +283,10 @@ function OptimizationBase.instantiate_function(
280283
end
281284

282285
function OptimizationBase.instantiate_function(
283-
f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AutoZygote},
286+
f::OptimizationFunction{true}, x,
287+
adtype::ADTypes.AutoSparse{<:Union{ADTypes.AutoZygote,
288+
DifferentiationInterface.SecondOrder{
289+
<:ADTypes.AbstractADType, <:ADTypes.AutoZygote}}},
284290
p = SciMLBase.NullParameters(), num_cons = 0;
285291
g = false, h = false, hv = false, fg = false, fgh = false,
286292
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,

src/adtypes.jl

Lines changed: 36 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -220,102 +220,60 @@ Hessian is not defined via Zygote.
220220
AutoZygote
221221

222222
function generate_adtype(adtype)
223-
if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode
224-
soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote?
225-
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
226-
soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype)
227-
else
223+
if !(adtype isa SciMLBase.NoAD || adtype isa DifferentiationInterface.SecondOrder)
224+
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)
225+
elseif adtype isa DifferentiationInterface.SecondOrder
226+
soadtype = adtype
227+
adtype = adtype.inner
228+
elseif adtype isa SciMLBase.NoAD
228229
soadtype = adtype
230+
adtype = adtype
229231
end
230232
return adtype, soadtype
231233
end
232234

233-
function generate_sparse_adtype(adtype)
235+
function spadtype_to_spsoadtype(adtype)
236+
if !(adtype.dense_ad isa SciMLBase.NoAD ||
237+
adtype.dense_ad isa DifferentiationInterface.SecondOrder)
238+
soadtype = AutoSparse(
239+
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
240+
sparsity_detector = adtype.sparsity_detector,
241+
coloring_algorithm = adtype.coloring_algorithm)
242+
else
243+
soadtype = adtype
244+
end
245+
return soadtype
246+
end
247+
248+
function filled_spad(adtype)
234249
if adtype.sparsity_detector isa ADTypes.NoSparsityDetector &&
235250
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
236251
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
237252
coloring_algorithm = GreedyColoringAlgorithm())
238-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
239-
soadtype = AutoSparse(
240-
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
241-
sparsity_detector = TracerSparsityDetector(),
242-
coloring_algorithm = GreedyColoringAlgorithm())
243-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
244-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
245-
soadtype = AutoSparse(
246-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
247-
sparsity_detector = TracerSparsityDetector(),
248-
coloring_algorithm = GreedyColoringAlgorithm()) #make zygote?
249-
elseif !(adtype isa SciMLBase.NoAD) &&
250-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
251-
soadtype = AutoSparse(
252-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
253-
sparsity_detector = TracerSparsityDetector(),
254-
coloring_algorithm = GreedyColoringAlgorithm())
255-
end
256253
elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector &&
257254
!(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm)
258255
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
259256
coloring_algorithm = adtype.coloring_algorithm)
260-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
261-
soadtype = AutoSparse(
262-
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
263-
sparsity_detector = TracerSparsityDetector(),
264-
coloring_algorithm = adtype.coloring_algorithm)
265-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
266-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
267-
soadtype = AutoSparse(
268-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
269-
sparsity_detector = TracerSparsityDetector(),
270-
coloring_algorithm = adtype.coloring_algorithm)
271-
elseif !(adtype isa SciMLBase.NoAD) &&
272-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
273-
soadtype = AutoSparse(
274-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
275-
sparsity_detector = TracerSparsityDetector(),
276-
coloring_algorithm = adtype.coloring_algorithm)
277-
end
278257
elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) &&
279258
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
280259
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector,
281260
coloring_algorithm = GreedyColoringAlgorithm())
282-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
283-
soadtype = AutoSparse(
284-
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
285-
sparsity_detector = adtype.sparsity_detector,
286-
coloring_algorithm = GreedyColoringAlgorithm())
287-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
288-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
289-
soadtype = AutoSparse(
290-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
291-
sparsity_detector = adtype.sparsity_detector,
292-
coloring_algorithm = GreedyColoringAlgorithm())
293-
elseif !(adtype isa SciMLBase.NoAD) &&
294-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
295-
soadtype = AutoSparse(
296-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
297-
sparsity_detector = adtype.sparsity_detector,
298-
coloring_algorithm = GreedyColoringAlgorithm())
299-
end
261+
end
262+
end
263+
264+
function generate_sparse_adtype(adtype)
265+
if !(adtype.dense_ad isa DifferentiationInterface.SecondOrder)
266+
adtype = filled_spad(adtype)
267+
soadtype = spadtype_to_spsoadtype(adtype)
300268
else
301-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
302-
soadtype = AutoSparse(
303-
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
304-
sparsity_detector = adtype.sparsity_detector,
305-
coloring_algorithm = adtype.coloring_algorithm)
306-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
307-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
308-
soadtype = AutoSparse(
309-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
310-
sparsity_detector = adtype.sparsity_detector,
311-
coloring_algorithm = adtype.coloring_algorithm)
312-
elseif !(adtype isa SciMLBase.NoAD) &&
313-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
314-
soadtype = AutoSparse(
315-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
316-
sparsity_detector = adtype.sparsity_detector,
317-
coloring_algorithm = adtype.coloring_algorithm)
318-
end
269+
soadtype = adtype
270+
adtype = AutoSparse(
271+
adtype.dense_ad.inner,
272+
sparsity_detector = soadtype.sparsity_detector,
273+
coloring_algorithm = soadtype.coloring_algorithm)
274+
adtype = filled_spad(adtype)
275+
soadtype = filled_spad(soadtype)
319276
end
277+
320278
return adtype, soadtype
321279
end

src/cache.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;
4242

4343
num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)
4444

45+
if !(prob.f.adtype isa DifferentiationInterface.SecondOrder) &&
46+
(SciMLBase.requireshessian(opt) || SciMLBase.requiresconshess(opt) ||
47+
SciMLBase.requireslagh(opt))
48+
@warn "The selected optimization algorithm requires second order derivatives, but `SecondOrder` ADtype was not provided.
49+
So a `SecondOrder` with $adtype for both inner and outer will be created, this can be suboptimal and not work in some cases so
50+
an explicit `SecondOrder` ADtype is recommended."
51+
end
52+
4553
f = OptimizationBase.instantiate_function(
4654
prob.f, reinit_cache, prob.f.adtype, num_cons;
4755
g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt),

test/adtests.jl

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ optprob.cons_h(H3, x0)
105105
μ = randn(1)
106106
σ = rand()
107107
optprob.lag_h(H4, x0, σ, μ)
108-
@test H4≈σ * H1 + μ[1] * H3[1] rtol=1e-6
108+
@test H4≈σ * H2 + μ[1] * H3[1] rtol=1e-6
109109

110110
G2 = Array{Float64}(undef, 2)
111111
H2 = Array{Float64}(undef, 2, 2)
@@ -142,7 +142,7 @@ optprob.cons_h(H3, x0)
142142
μ = randn(1)
143143
σ = rand()
144144
optprob.lag_h(H4, x0, σ, μ)
145-
@test H4≈σ * H1 + μ[1] * H3[1] rtol=1e-6
145+
@test H4≈σ * H2 + μ[1] * H3[1] rtol=1e-6
146146

147147
G2 = Array{Float64}(undef, 2)
148148
H2 = Array{Float64}(undef, 2, 2)
@@ -179,7 +179,7 @@ optprob.cons_h(H3, x0)
179179
μ = randn(1)
180180
σ = rand()
181181
optprob.lag_h(H4, x0, σ, μ)
182-
@test H4≈σ * H1 + μ[1] * H3[1] rtol=1e-6
182+
@test H4≈σ * H2 + μ[1] * H3[1] rtol=1e-6
183183

184184
G2 = Array{Float64}(undef, 2)
185185
H2 = Array{Float64}(undef, 2, 2)
@@ -217,14 +217,15 @@ optprob.cons_h(H3, x0)
217217
μ = randn(1)
218218
σ = rand()
219219
optprob.lag_h(H4, x0, σ, μ)
220-
@test H4≈σ * H1 + μ[1] * H3[1] rtol=1e-6
220+
@test H4≈σ * H2 + μ[1] * H3[1] rtol=1e-6
221221

222222
G2 = Array{Float64}(undef, 2)
223223
H2 = Array{Float64}(undef, 2, 2)
224224

225-
optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote(), cons = cons)
225+
optf = OptimizationFunction(
226+
rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()), cons = cons)
226227
optprob = OptimizationBase.instantiate_function(
227-
optf, x0, OptimizationBase.AutoZygote(),
228+
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
228229
nothing, 1, g = true, h = true, hv = true,
229230
cons_j = true, cons_h = true, cons_vjp = true,
230231
cons_jvp = true, lag_h = true)
@@ -254,14 +255,19 @@ optprob.cons_h(H3, x0)
254255
μ = randn(1)
255256
σ = rand()
256257
optprob.lag_h(H4, x0, σ, μ)
257-
@test H4≈σ * H1 + μ[1] * H3[1] rtol=1e-6
258+
@test H4≈σ * H2 + μ[1] * H3[1] rtol=1e-6
258259

259260
G2 = Array{Float64}(undef, 2)
260261
H2 = Array{Float64}(undef, 2, 2)
261262

262-
optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoFiniteDiff(), cons = cons)
263+
optf = OptimizationFunction(rosenbrock,
264+
DifferentiationInterface.SecondOrder(
265+
ADTypes.AutoFiniteDiff(), ADTypes.AutoReverseDiff()),
266+
cons = cons)
263267
optprob = OptimizationBase.instantiate_function(
264-
optf, x0, OptimizationBase.AutoFiniteDiff(),
268+
optf, x0,
269+
DifferentiationInterface.SecondOrder(
270+
ADTypes.AutoFiniteDiff(), ADTypes.AutoReverseDiff()),
265271
nothing, 1, g = true, h = true, hv = true,
266272
cons_j = true, cons_h = true, cons_vjp = true,
267273
cons_jvp = true, lag_h = true)
@@ -287,11 +293,12 @@ optprob.cons_h(H3, x0)
287293
H3 = [Array{Float64}(undef, 2, 2)]
288294
optprob.cons_h(H3, x0)
289295
@test H3≈[[2.0 0.0; 0.0 2.0]] rtol=1e-5
296+
Random.seed!(123)
290297
H4 = Array{Float64}(undef, 2, 2)
291298
μ = randn(1)
292299
σ = rand()
293300
optprob.lag_h(H4, x0, σ, μ)
294-
@test H4≈σ * H1 + μ[1] * H3[1] rtol=1e-6
301+
@test H4≈σ * H2 + μ[1] * H3[1] rtol=1e-6
295302
end
296303

297304
@testset "two constraints tests" begin
@@ -448,9 +455,10 @@ end
448455
G2 = Array{Float64}(undef, 2)
449456
H2 = Array{Float64}(undef, 2, 2)
450457

451-
optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote(), cons = con2_c)
458+
optf = OptimizationFunction(
459+
rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()), cons = con2_c)
452460
optprob = OptimizationBase.instantiate_function(
453-
optf, x0, OptimizationBase.AutoZygote(),
461+
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
454462
nothing, 2, g = true, h = true, hv = true,
455463
cons_j = true, cons_h = true, cons_vjp = true,
456464
cons_jvp = true, lag_h = true)
@@ -486,9 +494,13 @@ end
486494
H2 = Array{Float64}(undef, 2, 2)
487495

488496
optf = OptimizationFunction(
489-
rosenbrock, OptimizationBase.AutoFiniteDiff(), cons = con2_c)
497+
rosenbrock, DifferentiationInterface.SecondOrder(
498+
ADTypes.AutoFiniteDiff(), ADTypes.AutoReverseDiff()),
499+
cons = con2_c)
490500
optprob = OptimizationBase.instantiate_function(
491-
optf, x0, OptimizationBase.AutoFiniteDiff(),
501+
optf, x0,
502+
DifferentiationInterface.SecondOrder(
503+
ADTypes.AutoFiniteDiff(), ADTypes.AutoReverseDiff()),
492504
nothing, 2, g = true, h = true, hv = true,
493505
cons_j = true, cons_h = true, cons_vjp = true,
494506
cons_jvp = true, lag_h = true)
@@ -734,12 +746,15 @@ end
734746
@test lag_H ≈ lag_H_expected
735747
@test nnz(lag_H) == 5
736748

737-
optf = OptimizationFunction(sparse_objective, OptimizationBase.AutoSparseZygote(),
749+
optf = OptimizationFunction(sparse_objective,
750+
AutoSparse(DifferentiationInterface.SecondOrder(
751+
ADTypes.AutoForwardDiff(), ADTypes.AutoZygote())),
738752
cons = sparse_constraints)
739753

740754
# Instantiate the optimization problem
741755
optprob = OptimizationBase.instantiate_function(optf, x0,
742-
OptimizationBase.AutoSparseZygote(),
756+
AutoSparse(DifferentiationInterface.SecondOrder(
757+
ADTypes.AutoForwardDiff(), ADTypes.AutoZygote())),
743758
nothing, 2, g = true, h = true, cons_j = true, cons_h = true, lag_h = true)
744759
# Test gradient
745760
G = zeros(3)
@@ -1065,10 +1080,10 @@ end
10651080

10661081
cons = (x, p) -> [x[1]^2 + x[2]^2]
10671082
optf = OptimizationFunction{false}(rosenbrock,
1068-
OptimizationBase.AutoZygote(),
1083+
SecondOrder(AutoForwardDiff(), AutoZygote()),
10691084
cons = cons)
10701085
optprob = OptimizationBase.instantiate_function(
1071-
optf, x0, OptimizationBase.AutoZygote(),
1086+
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
10721087
nothing, 1, g = true, h = true, cons_j = true, cons_h = true)
10731088

10741089
@test optprob.grad(x0) == G1
@@ -1081,10 +1096,10 @@ end
10811096

10821097
cons = (x, p) -> [x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]]
10831098
optf = OptimizationFunction{false}(rosenbrock,
1084-
OptimizationBase.AutoZygote(),
1099+
SecondOrder(AutoForwardDiff(), AutoZygote()),
10851100
cons = cons)
10861101
optprob = OptimizationBase.instantiate_function(
1087-
optf, x0, OptimizationBase.AutoZygote(),
1102+
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
10881103
nothing, 2, g = true, h = true, cons_j = true, cons_h = true)
10891104

10901105
@test optprob.grad(x0) == G1

test/matrixvalued.jl

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@ using OptimizationBase, LinearAlgebra, ForwardDiff, Zygote, FiniteDiff,
33
using Test, ReverseDiff
44

55
@testset "Matrix Valued" begin
6-
for adtype in [AutoForwardDiff(), AutoZygote(), AutoFiniteDiff(),
6+
for adtype in [AutoForwardDiff(), SecondOrder(AutoForwardDiff(), AutoZygote()),
7+
SecondOrder(AutoForwardDiff(), AutoFiniteDiff()),
78
AutoSparse(AutoForwardDiff(), sparsity_detector = TracerLocalSparsityDetector()),
8-
AutoSparse(AutoZygote(), sparsity_detector = TracerLocalSparsityDetector()),
9-
AutoSparse(AutoFiniteDiff(), sparsity_detector = TracerLocalSparsityDetector())]
9+
AutoSparse(SecondOrder(AutoForwardDiff(), AutoZygote()),
10+
sparsity_detector = TracerLocalSparsityDetector()),
11+
AutoSparse(SecondOrder(AutoForwardDiff(), AutoFiniteDiff()),
12+
sparsity_detector = TracerLocalSparsityDetector())]
1013
# 1. Matrix Factorization
14+
@show adtype
1115
function matrix_factorization_objective(X, A)
1216
U, V = @view(X[1:size(A, 1), 1:Int(size(A, 2) / 2)]),
1317
@view(X[1:size(A, 1), (Int(size(A, 2) / 2) + 1):size(A, 2)])
@@ -28,12 +32,7 @@ using Test, ReverseDiff
2832
cons_j = true, cons_h = true)
2933
optf.grad(hcat(U_mf, V_mf))
3034
optf.hess(hcat(U_mf, V_mf))
31-
if adtype != AutoSparse(
32-
AutoForwardDiff(), sparsity_detector = TracerLocalSparsityDetector()) &&
33-
adtype !=
34-
AutoSparse(AutoZygote(), sparsity_detector = TracerLocalSparsityDetector()) &&
35-
adtype !=
36-
AutoSparse(AutoFiniteDiff(), sparsity_detector = TracerLocalSparsityDetector())
35+
if !(adtype isa ADTypes.AutoSparse)
3736
optf.cons_j(hcat(U_mf, V_mf))
3837
optf.cons_h(hcat(U_mf, V_mf))
3938
end
@@ -55,12 +54,7 @@ using Test, ReverseDiff
5554
cons_j = true, cons_h = true)
5655
optf.grad(X_pca)
5756
optf.hess(X_pca)
58-
if adtype != AutoSparse(
59-
AutoForwardDiff(), sparsity_detector = TracerLocalSparsityDetector()) &&
60-
adtype !=
61-
AutoSparse(AutoZygote(), sparsity_detector = TracerLocalSparsityDetector()) &&
62-
adtype !=
63-
AutoSparse(AutoFiniteDiff(), sparsity_detector = TracerLocalSparsityDetector())
57+
if !(adtype isa ADTypes.AutoSparse)
6458
optf.cons_j(X_pca)
6559
optf.cons_h(X_pca)
6660
end

0 commit comments

Comments
 (0)