Skip to content

Commit cbc8e97

Browse files
Merge pull request #910 from SciML/ChrisRackauckas-patch-1
Simplify calls to Optim.jl
2 parents 93eb623 + 2f3366e commit cbc8e97

File tree

1 file changed

+39
-16
lines changed

1 file changed

+39
-16
lines changed

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -393,14 +393,22 @@ function SciMLBase.__solve(cache::OptimizationCache{
393393
end
394394
end
395395
u0_type = eltype(cache.u0)
396-
optim_f = Optim.TwiceDifferentiable(_loss, gg, fg!, hh, cache.u0,
397-
real(zero(u0_type)),
398-
Optim.NLSolversBase.alloc_DF(cache.u0,
399-
real(zero(u0_type))),
400-
isnothing(cache.f.hess_prototype) ?
401-
Optim.NLSolversBase.alloc_H(cache.u0,
402-
real(zero(u0_type))) :
403-
convert.(u0_type, cache.f.hess_prototype))
396+
397+
optim_f = if SciMLBase.requireshessian(cache.opt)
398+
Optim.TwiceDifferentiable(_loss, gg, fg!, hh, cache.u0,
399+
real(zero(u0_type)),
400+
Optim.NLSolversBase.alloc_DF(cache.u0,
401+
real(zero(u0_type))),
402+
isnothing(cache.f.hess_prototype) ?
403+
Optim.NLSolversBase.alloc_H(cache.u0,
404+
real(zero(u0_type))) :
405+
convert.(u0_type, cache.f.hess_prototype))
406+
else
407+
Optim.OnceDifferentiable(_loss, gg, fg!, cache.u0,
408+
real(zero(u0_type)),
409+
Optim.NLSolversBase.alloc_DF(cache.u0,
410+
real(zero(u0_type))))
411+
end
404412

405413
cons_hl! = function (h, θ, λ)
406414
res = [similar(h) for i in 1:length(λ)]
@@ -412,15 +420,26 @@ function SciMLBase.__solve(cache::OptimizationCache{
412420

413421
lb = cache.lb === nothing ? [] : cache.lb
414422
ub = cache.ub === nothing ? [] : cache.ub
415-
if cache.f.cons !== nothing
416-
optim_fc = Optim.TwiceDifferentiableConstraints(cache.f.cons, cache.f.cons_j,
417-
cons_hl!,
418-
lb, ub,
419-
cache.lcons, cache.ucons)
423+
424+
optim_fc = if SciMLBase.requireshessian(cache.opt)
425+
if cache.f.cons !== nothing
426+
Optim.TwiceDifferentiableConstraints(cache.f.cons, cache.f.cons_j,
427+
cons_hl!,
428+
lb, ub,
429+
cache.lcons, cache.ucons)
430+
else
431+
Optim.TwiceDifferentiableConstraints(lb, ub)
432+
end
420433
else
421-
optim_fc = Optim.TwiceDifferentiableConstraints(lb, ub)
434+
if cache.f.cons !== nothing
435+
Optim.OnceDifferentiableConstraints(cache.f.cons, cache.f.cons_j,
436+
lb, ub,
437+
cache.lcons, cache.ucons)
438+
else
439+
Optim.OnceDifferentiableConstraints(lb, ub)
440+
end
422441
end
423-
442+
424443
opt_args = __map_optimizer_args(cache, cache.opt, callback = _cb,
425444
maxiters = cache.solver_args.maxiters,
426445
maxtime = cache.solver_args.maxtime,
@@ -429,7 +448,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
429448
cache.solver_args...)
430449

431450
t0 = time()
432-
opt_res = Optim.optimize(optim_f, optim_fc, cache.u0, cache.opt, opt_args)
451+
if lb === nothing && ub === nothing && cache.f.cons === nothing
452+
opt_res = Optim.optimize(optim_f, cache.u0, cache.opt, opt_args)
453+
else
454+
opt_res = Optim.optimize(optim_f, optim_fc, cache.u0, cache.opt, opt_args)
455+
end
433456
t1 = time()
434457
opt_ret = Symbol(Optim.converged(opt_res))
435458
stats = Optimization.OptimizationStats(; iterations = opt_res.iterations,

0 commit comments

Comments
 (0)