
Commit 6043b0c

Merge pull request #121 from SciML/secondorder
Make zygote second order FD over Zygote
2 parents e10bed6 + c12dd3b

File tree

4 files changed: +29, -13 lines

ext/OptimizationZygoteExt.jl

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
     gradient!, hessian!, hvp!, jacobian!, gradient, hessian,
     hvp, jacobian, Constant
 using ADTypes, SciMLBase
-import Zygote
+import Zygote, Zygote.ForwardDiff
 
 function OptimizationBase.instantiate_function(
     f::OptimizationFunction{true}, x,
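
A brief aside on the new import line: Zygote already depends on ForwardDiff, so the extension can reach it as Zygote.ForwardDiff without taking on a direct dependency of its own. A minimal sketch of the pattern (the gradient call is purely illustrative, not code from this file):

    import Zygote, Zygote.ForwardDiff  # ForwardDiff comes along as a dependency of Zygote

    # The module is now bound to the name ForwardDiff and usable directly,
    # e.g. for the forward-mode outer pass of a second-order computation:
    ForwardDiff.gradient(x -> sum(abs2, x), [1.0, 2.0])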

src/adtypes.jl

Lines changed: 11 additions & 2 deletions

@@ -220,8 +220,11 @@ Hessian is not defined via Zygote.
 AutoZygote
 
 function generate_adtype(adtype)
-    if !(adtype isa SciMLBase.NoAD || adtype isa DifferentiationInterface.SecondOrder)
+    if !(adtype isa SciMLBase.NoAD || adtype isa DifferentiationInterface.SecondOrder ||
+         adtype isa AutoZygote)
         soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)
+    elseif adtype isa AutoZygote
+        soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype)
     elseif adtype isa DifferentiationInterface.SecondOrder
         soadtype = adtype
         adtype = adtype.inner

@@ -234,11 +237,17 @@ end
 
 function spadtype_to_spsoadtype(adtype)
     if !(adtype.dense_ad isa SciMLBase.NoAD ||
-         adtype.dense_ad isa DifferentiationInterface.SecondOrder)
+         adtype.dense_ad isa DifferentiationInterface.SecondOrder ||
+         adtype.dense_ad isa AutoZygote)
         soadtype = AutoSparse(
             DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
             sparsity_detector = adtype.sparsity_detector,
             coloring_algorithm = adtype.coloring_algorithm)
+    elseif adtype.dense_ad isa AutoZygote
+        soadtype = AutoSparse(
+            DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
+            sparsity_detector = adtype.sparsity_detector,
+            coloring_algorithm = adtype.coloring_algorithm)
     else
         soadtype = adtype
     end
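
For context, a hedged sketch of what the new branch amounts to from the user's side: a bare AutoZygote() is now promoted to forward-over-reverse for second order, while first-order calls stay pure Zygote. The Rosenbrock objective and the direct DifferentiationInterface calls below are illustrative additions, not code from this commit:

    using ADTypes: AutoZygote, AutoForwardDiff
    import DifferentiationInterface as DI
    import Zygote, ForwardDiff  # backends must be loaded for DI to dispatch to them

    # What generate_adtype now builds when it sees a bare AutoZygote():
    soadtype = DI.SecondOrder(AutoForwardDiff(), AutoZygote())

    # Illustrative objective (assumed, not part of the diff):
    rosenbrock(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
    x0 = [0.5, 0.5]

    DI.gradient(rosenbrock, AutoZygote(), x0)  # first order: reverse mode via Zygote
    DI.hessian(rosenbrock, soadtype, x0)       # second order: ForwardDiff over Zygote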

src/cache.jl

Lines changed: 9 additions & 2 deletions

@@ -42,11 +42,18 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;
 
     num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)
 
-    if !(prob.f.adtype isa DifferentiationInterface.SecondOrder) &&
+    if !(prob.f.adtype isa DifferentiationInterface.SecondOrder ||
+         prob.f.adtype isa AutoZygote) &&
        (SciMLBase.requireshessian(opt) || SciMLBase.requiresconshess(opt) ||
         SciMLBase.requireslagh(opt))
         @warn "The selected optimization algorithm requires second order derivatives, but `SecondOrder` ADtype was not provided.
-        So a `SecondOrder` with $adtype for both inner and outer will be created, this can be suboptimal and not work in some cases so
+        So a `SecondOrder` with $(prob.f.adtype) for both inner and outer will be created, this can be suboptimal and not work in some cases so
+        an explicit `SecondOrder` ADtype is recommended."
+    elseif prob.f.adtype isa AutoZygote &&
+           (SciMLBase.requiresconshess(opt) || SciMLBase.requireslagh(opt) ||
+            SciMLBase.requireshessian(opt))
+        @warn "The selected optimization algorithm requires second order derivatives, but `AutoZygote` ADtype was provided.
+        So a `SecondOrder` with `AutoZygote` for inner and `AutoForwardDiff` for outer will be created, for choosing another pair
         an explicit `SecondOrder` ADtype is recommended."
     end
 
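
As the new warning text says, the automatic pairing can be overridden: passing an explicit SecondOrder keeps the warning from firing and lets the user pick the inner/outer combination. A hedged usage sketch, assuming the Optimization.jl front end and OptimizationOptimJL's Newton() as an example of a Hessian-requiring solver (neither appears in this diff):

    using Optimization, OptimizationOptimJL  # assumed front-end packages, for illustration
    using ADTypes: AutoZygote, AutoForwardDiff
    import DifferentiationInterface: SecondOrder
    import Zygote

    rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

    # Bare AutoZygote(): the new elseif branch warns, then builds
    # SecondOrder(AutoForwardDiff(), AutoZygote()) automatically.
    prob_ad = OptimizationProblem(
        OptimizationFunction(rosenbrock, AutoZygote()), zeros(2), [1.0, 100.0])
    solve(prob_ad, Newton())

    # Explicit SecondOrder: no warning, and the pairing is the user's choice.
    prob_so = OptimizationProblem(
        OptimizationFunction(rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote())),
        zeros(2), [1.0, 100.0])
    solve(prob_so, Newton())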

test/adtests.jl

Lines changed: 8 additions & 8 deletions

@@ -223,9 +223,9 @@ optprob.cons_h(H3, x0)
 H2 = Array{Float64}(undef, 2, 2)
 
 optf = OptimizationFunction(
-    rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()), cons = cons)
+    rosenbrock, AutoZygote(), cons = cons)
 optprob = OptimizationBase.instantiate_function(
-    optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
+    optf, x0, AutoZygote(),
     nothing, 1, g = true, h = true, hv = true,
     cons_j = true, cons_h = true, cons_vjp = true,
     cons_jvp = true, lag_h = true)

@@ -456,9 +456,9 @@ end
 H2 = Array{Float64}(undef, 2, 2)
 
 optf = OptimizationFunction(
-    rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()), cons = con2_c)
+    rosenbrock, AutoZygote(), cons = con2_c)
 optprob = OptimizationBase.instantiate_function(
-    optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
+    optf, x0, AutoZygote(),
     nothing, 2, g = true, h = true, hv = true,
     cons_j = true, cons_h = true, cons_vjp = true,
     cons_jvp = true, lag_h = true)

@@ -1080,10 +1080,10 @@ end
 
 cons = (x, p) -> [x[1]^2 + x[2]^2]
 optf = OptimizationFunction{false}(rosenbrock,
-    SecondOrder(AutoForwardDiff(), AutoZygote()),
+    AutoZygote(),
     cons = cons)
 optprob = OptimizationBase.instantiate_function(
-    optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
+    optf, x0, AutoZygote(),
     nothing, 1, g = true, h = true, cons_j = true, cons_h = true)
 
 @test optprob.grad(x0) == G1

@@ -1096,10 +1096,10 @@ end
 
 cons = (x, p) -> [x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]]
 optf = OptimizationFunction{false}(rosenbrock,
-    SecondOrder(AutoForwardDiff(), AutoZygote()),
+    AutoZygote(),
     cons = cons)
 optprob = OptimizationBase.instantiate_function(
-    optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
+    optf, x0, AutoZygote(),
     nothing, 2, g = true, h = true, cons_j = true, cons_h = true)
 
 @test optprob.grad(x0) == G1
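
For readers skimming the test changes, a condensed sketch of what the updated calls now exercise. The rosenbrock and cons definitions are assumed from earlier in adtests.jl and reproduced only for illustration; the commit itself only swaps the explicit SecondOrder argument for AutoZygote():

    using OptimizationBase, SciMLBase, ADTypes
    import Zygote

    # Assumed from earlier in the test file, shown here for completeness:
    rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
    cons = (x, p) -> [x[1]^2 + x[2]^2]
    x0 = zeros(2)

    optf = OptimizationFunction{false}(rosenbrock, AutoZygote(), cons = cons)
    optprob = OptimizationBase.instantiate_function(
        optf, x0, AutoZygote(), nothing, 1,
        g = true, h = true, cons_j = true, cons_h = true)

    optprob.grad(x0)  # gradient via Zygote
    optprob.hess(x0)  # Hessian via the automatically built SecondOrder(AutoForwardDiff(), AutoZygote())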
