Skip to content

Commit 94061d8

Browse files
Add constant parameters and try to update Enzyme
1 parent 4e2968f commit 94061d8

File tree

5 files changed

+273
-380
lines changed

5 files changed

+273
-380
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ using Core: Vararg
1818
end
1919

2020
function inner_grad(θ, bθ, f, p)
21-
Enzyme.autodiff_deferred(Enzyme.Reverse,
22-
Const(firstapply),
21+
Enzyme.autodiff(Enzyme.Reverse,
22+
firstapply,
2323
Active,
2424
Const(f),
2525
Enzyme.Duplicated(θ, bθ),
@@ -29,9 +29,8 @@ function inner_grad(θ, bθ, f, p)
2929
end
3030

3131
function inner_grad_primal(θ, bθ, f, p)
32-
Enzyme.autodiff_deferred(Enzyme.ReverseWithPrimal,
33-
Const(firstapply),
34-
Active,
32+
Enzyme.autodiff(Enzyme.ReverseWithPrimal,
33+
firstapply,
3534
Const(f),
3635
Enzyme.Duplicated(θ, bθ),
3736
Const(p)
@@ -40,9 +39,8 @@ end
4039

4140
function hv_f2_alloc(x, f, p)
4241
dx = Enzyme.make_zero(x)
43-
Enzyme.autodiff_deferred(Enzyme.Reverse,
42+
Enzyme.autodiff(Enzyme.Reverse,
4443
firstapply,
45-
Active,
4644
Const(f),
4745
Enzyme.Duplicated(x, dx),
4846
Const(p)
@@ -58,7 +56,7 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi
5856
end
5957

6058
function cons_f2(x, dx, fcons, p, num_cons, i)
61-
Enzyme.autodiff_deferred(Enzyme.Reverse, inner_cons, Active, Enzyme.Duplicated(x, dx),
59+
Enzyme.autodiff(Enzyme.Reverse, inner_cons, Enzyme.Duplicated(x, dx),
6260
Const(fcons), Const(p), Const(num_cons), Const(i))
6361
return nothing
6462
end
@@ -70,8 +68,8 @@ function inner_cons_oop(
7068
end
7169

7270
function cons_f2_oop(x, dx, fcons, p, i)
73-
Enzyme.autodiff_deferred(
74-
Enzyme.Reverse, inner_cons_oop, Active, Enzyme.Duplicated(x, dx),
71+
Enzyme.autodiff(
72+
Enzyme.Reverse, inner_cons_oop, Enzyme.Duplicated(x, dx),
7573
Const(fcons), Const(p), Const(i))
7674
return nothing
7775
end
@@ -83,7 +81,7 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))
8381
end
8482

8583
function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
86-
Enzyme.autodiff_deferred(Enzyme.Reverse, lagrangian, Active, Enzyme.Duplicated(x, dx),
84+
Enzyme.autodiff(Enzyme.Reverse, lagrangian, Active, Enzyme.Duplicated(x, dx),
8785
Const(_f), Const(cons), Const(p), Const(λ), Const(σ))
8886
return nothing
8987
end

0 commit comments

Comments
 (0)