Skip to content

Commit c850508

Browse files
tests pass
1 parent 043a01c commit c850508

File tree

3 files changed

+6
-7
lines changed

3 files changed

+6
-7
lines changed

ext/OptimizationZygoteExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,11 @@ function OptimizationBase.instantiate_function(
101101
if hv == true && f.hv === nothing
102102
prep_hvp = prepare_hvp(f.f, soadtype, x, (zeros(eltype(x), size(x)),), Constant(p))
103103
function hv!(H, θ, v)
104-
hvp!(f.f, H, prep_hvp, soadtype, θ, (v,), Constant(p))
104+
hvp!(f.f, (H,), prep_hvp, soadtype, θ, (v,), Constant(p))
105105
end
106106
if p !== SciMLBase.NullParameters() && p !== nothing
107107
function hv!(H, θ, v, p)
108-
hvp!(f.f, H, prep_hvp, soadtype, θ, (v,), Constant(p))
108+
hvp!(f.f, (H,), prep_hvp, soadtype, θ, (v,), Constant(p))
109109
end
110110
end
111111
elseif hv == true
@@ -141,9 +141,9 @@ function OptimizationBase.instantiate_function(
141141
cons_jac_prototype = f.cons_jac_prototype
142142
cons_jac_colorvec = f.cons_jac_colorvec
143143
if cons !== nothing && cons_j == true && f.cons_j === nothing
144-
prep_jac = prepare_jacobian(cons_oop, adtype, x, Constant(p))
144+
prep_jac = prepare_jacobian(cons_oop, adtype, x)
145145
function cons_j!(J, θ)
146-
jacobian!(cons_oop, J, prep_jac, adtype, θ, Constant(p))
146+
jacobian!(cons_oop, J, prep_jac, adtype, θ)
147147
if size(J, 1) == 1
148148
J = vec(J)
149149
end

src/OptimizationDISparseExt.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,8 @@ function instantiate_function(
210210
end
211211

212212
function cons_oop(x, i)
213-
_res = zeros(eltype(x))
213+
_res = zeros(eltype(x), num_cons)
214214
f.cons(_res, x, p)
215-
@show _res
216215
return _res[i]
217216
end
218217

test/adtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ optprob.cons_h(H3, x0)
202202
optprob.cons(res, x0)
203203
@test res == [0.0]
204204
J = Array{Float64}(undef, 2)
205-
@test optprob.cons_j(J, [5.0, 3.0])
205+
optprob.cons_j(J, [5.0, 3.0])
206206
@test J == [10.0, 6.0]
207207
vJ = Array{Float64}(undef, 2)
208208
optprob.cons_vjp(vJ, [5.0, 3.0], [1.0])

0 commit comments

Comments
 (0)