Skip to content

Commit f8d61fd

Browse files
Remove autmatic FoR soadtype creations
1 parent 1cb8a90 commit f8d61fd

File tree

1 file changed

+15
-57
lines changed

1 file changed

+15
-57
lines changed

src/adtypes.jl

Lines changed: 15 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -220,12 +220,14 @@ Hessian is not defined via Zygote.
220220
AutoZygote
221221

222222
function generate_adtype(adtype)
223-
if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode
224-
soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote?
225-
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
226-
soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype)
227-
else
223+
if !(adtype isa SciMLBase.NoAD && adtype isa DifferentiationInterface.SecondOrder)
224+
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)
225+
elseif adtype isa DifferentiationInterface.SecondOrder
226+
soadtype = adtype
227+
adtype = adtype.inner
228+
elseif adtype isa SciMLBase.NoAD
228229
soadtype = adtype
230+
adtype = adtype
229231
end
230232
return adtype, soadtype
231233
end
@@ -235,86 +237,42 @@ function generate_sparse_adtype(adtype)
235237
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
236238
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
237239
coloring_algorithm = GreedyColoringAlgorithm())
238-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
240+
if !(adtype.dense_ad isa SciMLBase.NoAD &&
241+
adtype.dense_ad isa DifferentiationInterface.SecondOrder)
239242
soadtype = AutoSparse(
240243
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
241244
sparsity_detector = TracerSparsityDetector(),
242245
coloring_algorithm = GreedyColoringAlgorithm())
243-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
244-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
245-
soadtype = AutoSparse(
246-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
247-
sparsity_detector = TracerSparsityDetector(),
248-
coloring_algorithm = GreedyColoringAlgorithm()) #make zygote?
249-
elseif !(adtype isa SciMLBase.NoAD) &&
250-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
251-
soadtype = AutoSparse(
252-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
253-
sparsity_detector = TracerSparsityDetector(),
254-
coloring_algorithm = GreedyColoringAlgorithm())
255246
end
256247
elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector &&
257248
!(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm)
258249
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
259250
coloring_algorithm = adtype.coloring_algorithm)
260-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
251+
if !(adtype.dense_ad isa SciMLBase.NoAD &&
252+
adtype.dense_ad isa DifferentiationInterface.SecondOrder)
261253
soadtype = AutoSparse(
262254
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
263255
sparsity_detector = TracerSparsityDetector(),
264256
coloring_algorithm = adtype.coloring_algorithm)
265-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
266-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
267-
soadtype = AutoSparse(
268-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
269-
sparsity_detector = TracerSparsityDetector(),
270-
coloring_algorithm = adtype.coloring_algorithm)
271-
elseif !(adtype isa SciMLBase.NoAD) &&
272-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
273-
soadtype = AutoSparse(
274-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
275-
sparsity_detector = TracerSparsityDetector(),
276-
coloring_algorithm = adtype.coloring_algorithm)
277257
end
278258
elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) &&
279259
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
280260
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector,
281261
coloring_algorithm = GreedyColoringAlgorithm())
282-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
262+
if !(adtype.dense_ad isa SciMLBase.NoAD &&
263+
adtype.dense_ad isa DifferentiationInterface.SecondOrder)
283264
soadtype = AutoSparse(
284265
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
285266
sparsity_detector = adtype.sparsity_detector,
286267
coloring_algorithm = GreedyColoringAlgorithm())
287-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
288-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
289-
soadtype = AutoSparse(
290-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
291-
sparsity_detector = adtype.sparsity_detector,
292-
coloring_algorithm = GreedyColoringAlgorithm())
293-
elseif !(adtype isa SciMLBase.NoAD) &&
294-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
295-
soadtype = AutoSparse(
296-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
297-
sparsity_detector = adtype.sparsity_detector,
298-
coloring_algorithm = GreedyColoringAlgorithm())
299268
end
300269
else
301-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
270+
if !(adtype.dense_ad isa SciMLBase.NoAD &&
271+
adtype.dense_ad isa DifferentiationInterface.SecondOrder)
302272
soadtype = AutoSparse(
303273
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
304274
sparsity_detector = adtype.sparsity_detector,
305275
coloring_algorithm = adtype.coloring_algorithm)
306-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
307-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
308-
soadtype = AutoSparse(
309-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
310-
sparsity_detector = adtype.sparsity_detector,
311-
coloring_algorithm = adtype.coloring_algorithm)
312-
elseif !(adtype isa SciMLBase.NoAD) &&
313-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
314-
soadtype = AutoSparse(
315-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
316-
sparsity_detector = adtype.sparsity_detector,
317-
coloring_algorithm = adtype.coloring_algorithm)
318276
end
319277
end
320278
return adtype, soadtype

0 commit comments

Comments
 (0)