@@ -220,12 +220,14 @@ Hessian is not defined via Zygote.
220
220
AutoZygote
221
221
222
222
function generate_adtype (adtype)
223
- if ! (adtype isa SciMLBase. NoAD) && ADTypes. mode (adtype) isa ADTypes. ForwardMode
224
- soadtype = DifferentiationInterface. SecondOrder (adtype, AutoReverseDiff ()) # make zygote?
225
- elseif ! (adtype isa SciMLBase. NoAD) && ADTypes. mode (adtype) isa ADTypes. ReverseMode
226
- soadtype = DifferentiationInterface. SecondOrder (AutoForwardDiff (), adtype)
227
- else
223
+ if ! (adtype isa SciMLBase. NoAD && adtype isa DifferentiationInterface. SecondOrder)
224
+ soadtype = DifferentiationInterface. SecondOrder (adtype, adtype)
225
+ elseif adtype isa DifferentiationInterface. SecondOrder
226
+ soadtype = adtype
227
+ adtype = adtype. inner
228
+ elseif adtype isa SciMLBase. NoAD
228
229
soadtype = adtype
230
+ adtype = adtype
229
231
end
230
232
return adtype, soadtype
231
233
end
@@ -235,86 +237,42 @@ function generate_sparse_adtype(adtype)
235
237
adtype. coloring_algorithm isa ADTypes. NoColoringAlgorithm
236
238
adtype = AutoSparse (adtype. dense_ad; sparsity_detector = TracerSparsityDetector (),
237
239
coloring_algorithm = GreedyColoringAlgorithm ())
238
- if adtype. dense_ad isa ADTypes. AutoFiniteDiff
240
+ if ! (adtype. dense_ad isa SciMLBase. NoAD &&
241
+ adtype. dense_ad isa DifferentiationInterface. SecondOrder)
239
242
soadtype = AutoSparse (
240
243
DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
241
244
sparsity_detector = TracerSparsityDetector (),
242
245
coloring_algorithm = GreedyColoringAlgorithm ())
243
- elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
244
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
245
- soadtype = AutoSparse (
246
- DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
247
- sparsity_detector = TracerSparsityDetector (),
248
- coloring_algorithm = GreedyColoringAlgorithm ()) # make zygote?
249
- elseif ! (adtype isa SciMLBase. NoAD) &&
250
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ReverseMode
251
- soadtype = AutoSparse (
252
- DifferentiationInterface. SecondOrder (AutoForwardDiff (), adtype. dense_ad),
253
- sparsity_detector = TracerSparsityDetector (),
254
- coloring_algorithm = GreedyColoringAlgorithm ())
255
246
end
256
247
elseif adtype. sparsity_detector isa ADTypes. NoSparsityDetector &&
257
248
! (adtype. coloring_algorithm isa ADTypes. NoColoringAlgorithm)
258
249
adtype = AutoSparse (adtype. dense_ad; sparsity_detector = TracerSparsityDetector (),
259
250
coloring_algorithm = adtype. coloring_algorithm)
260
- if adtype. dense_ad isa ADTypes. AutoFiniteDiff
251
+ if ! (adtype. dense_ad isa SciMLBase. NoAD &&
252
+ adtype. dense_ad isa DifferentiationInterface. SecondOrder)
261
253
soadtype = AutoSparse (
262
254
DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
263
255
sparsity_detector = TracerSparsityDetector (),
264
256
coloring_algorithm = adtype. coloring_algorithm)
265
- elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
266
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
267
- soadtype = AutoSparse (
268
- DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
269
- sparsity_detector = TracerSparsityDetector (),
270
- coloring_algorithm = adtype. coloring_algorithm)
271
- elseif ! (adtype isa SciMLBase. NoAD) &&
272
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ReverseMode
273
- soadtype = AutoSparse (
274
- DifferentiationInterface. SecondOrder (AutoForwardDiff (), adtype. dense_ad),
275
- sparsity_detector = TracerSparsityDetector (),
276
- coloring_algorithm = adtype. coloring_algorithm)
277
257
end
278
258
elseif ! (adtype. sparsity_detector isa ADTypes. NoSparsityDetector) &&
279
259
adtype. coloring_algorithm isa ADTypes. NoColoringAlgorithm
280
260
adtype = AutoSparse (adtype. dense_ad; sparsity_detector = adtype. sparsity_detector,
281
261
coloring_algorithm = GreedyColoringAlgorithm ())
282
- if adtype. dense_ad isa ADTypes. AutoFiniteDiff
262
+ if ! (adtype. dense_ad isa SciMLBase. NoAD &&
263
+ adtype. dense_ad isa DifferentiationInterface. SecondOrder)
283
264
soadtype = AutoSparse (
284
265
DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
285
266
sparsity_detector = adtype. sparsity_detector,
286
267
coloring_algorithm = GreedyColoringAlgorithm ())
287
- elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
288
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
289
- soadtype = AutoSparse (
290
- DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
291
- sparsity_detector = adtype. sparsity_detector,
292
- coloring_algorithm = GreedyColoringAlgorithm ())
293
- elseif ! (adtype isa SciMLBase. NoAD) &&
294
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ReverseMode
295
- soadtype = AutoSparse (
296
- DifferentiationInterface. SecondOrder (AutoForwardDiff (), adtype. dense_ad),
297
- sparsity_detector = adtype. sparsity_detector,
298
- coloring_algorithm = GreedyColoringAlgorithm ())
299
268
end
300
269
else
301
- if adtype. dense_ad isa ADTypes. AutoFiniteDiff
270
+ if ! (adtype. dense_ad isa SciMLBase. NoAD &&
271
+ adtype. dense_ad isa DifferentiationInterface. SecondOrder)
302
272
soadtype = AutoSparse (
303
273
DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
304
274
sparsity_detector = adtype. sparsity_detector,
305
275
coloring_algorithm = adtype. coloring_algorithm)
306
- elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
307
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
308
- soadtype = AutoSparse (
309
- DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
310
- sparsity_detector = adtype. sparsity_detector,
311
- coloring_algorithm = adtype. coloring_algorithm)
312
- elseif ! (adtype isa SciMLBase. NoAD) &&
313
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ReverseMode
314
- soadtype = AutoSparse (
315
- DifferentiationInterface. SecondOrder (AutoForwardDiff (), adtype. dense_ad),
316
- sparsity_detector = adtype. sparsity_detector,
317
- coloring_algorithm = adtype. coloring_algorithm)
318
276
end
319
277
end
320
278
return adtype, soadtype
0 commit comments