4 changes: 2 additions & 2 deletions Project.toml
@@ -45,7 +45,7 @@ Logging = "1.10"
 LoggingExtras = "0.4, 1"
 Lux = "1.12.4"
 MLUtils = "0.4"
-ModelingToolkit = "10.23"
+ModelingToolkit = "11"
 Mooncake = "0.4.138"
 Optim = ">= 1.4.1"
 Optimisers = ">= 0.2.5"
@@ -64,7 +64,7 @@ SafeTestsets = "0.1"
 SciMLBase = "2.122.1"
 SciMLSensitivity = "7"
 SparseArrays = "1.10"
-Symbolics = "6"
+Symbolics = "6, 7"
 TerminalLoggers = "0.1"
 Test = "1.10"
 Tracker = "0.2"
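
These compat bumps let the top-level Optimization package resolve against ModelingToolkit v11 and Symbolics v7. As a quick sanity check in a downstream environment (a hedged sketch, not part of this PR), the resolved versions can be inspected with Pkg:

using Pkg
Pkg.status()   # after resolution, ModelingToolkit should report 11.x and Symbolics 6.x or 7.x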
5 changes: 1 addition & 4 deletions lib/OptimizationBase/Project.toml
@@ -23,7 +23,6 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
-ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -34,15 +33,14 @@ OptimizationFiniteDiffExt = "FiniteDiff"
 OptimizationForwardDiffExt = "ForwardDiff"
 OptimizationMLDataDevicesExt = "MLDataDevices"
 OptimizationMLUtilsExt = "MLUtils"
-OptimizationMTKExt = "ModelingToolkit"
 OptimizationReverseDiffExt = "ReverseDiff"
 OptimizationSymbolicAnalysisExt = "SymbolicAnalysis"
 OptimizationZygoteExt = "Zygote"

 [compat]
 ADTypes = "1.14"
 ArrayInterface = "7.6"
-DifferentiationInterface = "0.7"
+DifferentiationInterface = "0.7.13"
 DocStringExtensions = "0.9"
 Enzyme = "0.13.2"
 FastClosures = "0.3"
@@ -51,7 +49,6 @@ ForwardDiff = "0.10.26, 1"
 LinearAlgebra = "1.9, 1.10"
 MLDataDevices = "1"
 MLUtils = "0.4"
-ModelingToolkit = "10.23"
 PDMats = "0.11"
 Reexport = "1.2"
 ReverseDiff = "1.14"
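
With the ModelingToolkit weak dependency and the OptimizationMTKExt entry removed here, the extension is presumably re-registered against ModelingToolkitBase in a portion of this Project.toml that is collapsed in the diff view (an assumption, based on the new OptimizationMTKBaseExt module added below). If so, loading ModelingToolkitBase alongside OptimizationBase should activate the extension, which can be verified with a short hedged check:

using OptimizationBase, ModelingToolkitBase

# Returns the extension module if the weakdep wiring is in place, `nothing` otherwise.
ext = Base.get_extension(OptimizationBase, :OptimizationMTKBaseExt)
@assert ext !== nothing "OptimizationMTKBaseExt did not load"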
213 changes: 213 additions & 0 deletions lib/OptimizationBase/ext/OptimizationMTKBaseExt.jl
@@ -0,0 +1,213 @@
module OptimizationMTKBaseExt

import OptimizationBase, OptimizationBase.ArrayInterface
import SciMLBase
import SciMLBase: OptimizationFunction
import OptimizationBase.ADTypes: AutoSymbolics, AutoSparse
using ModelingToolkitBase

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, x, adtype::AutoSparse{<:AutoSymbolics}, p,
num_cons = 0;
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
p = isnothing(p) ? SciMLBase.NullParameters() : p

sys = complete(ModelingToolkitBase.modelingtoolkitize(OptimizationProblem(f, x, p;
lcons = fill(0.0,
num_cons),
ucons = fill(0.0,
num_cons))))
#sys = ModelingToolkit.structural_simplify(sys)
# don't need to pass `x` or `p` since they're defaults now
mtkprob = OptimizationProblem(sys, nothing; grad = g, hess = h,
sparse = true, cons_j = cons_j, cons_h = cons_h,
cons_sparse = true)
f = mtkprob.f

grad = (G, θ, args...) -> f.grad(G, θ, mtkprob.p, args...)

hess = (H, θ, args...) -> f.hess(H, θ, mtkprob.p, args...)

hv = function (H, θ, v, args...)
res = similar(f.hess_prototype, eltype(θ))
hess(res, θ, args...)
H .= res * v
end

if !isnothing(f.cons)
cons = (res, θ) -> f.cons(res, θ, mtkprob.p)
cons_j = (J, θ) -> f.cons_j(J, θ, mtkprob.p)
cons_h = (res, θ) -> f.cons_h(res, θ, mtkprob.p)
else
cons = nothing
cons_j = nothing
cons_h = nothing
end

return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
cons = cons, cons_j = cons_j, cons_h = cons_h,
hess_prototype = f.hess_prototype,
cons_jac_prototype = f.cons_jac_prototype,
cons_hess_prototype = f.cons_hess_prototype,
expr = OptimizationBase.symbolify(f.expr),
cons_expr = OptimizationBase.symbolify.(f.cons_expr),
sys = sys,
observed = f.observed)
end

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
adtype::AutoSparse{<:AutoSymbolics}, num_cons = 0;
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
p = isnothing(cache.p) ? SciMLBase.NullParameters() : cache.p

sys = complete(ModelingToolkitBase.modelingtoolkitize(OptimizationProblem(f, cache.u0,
cache.p;
lcons = fill(0.0,
num_cons),
ucons = fill(0.0,
num_cons))))
#sys = ModelingToolkit.structural_simplify(sys)
# don't need to pass `x` or `p` since they're defaults now
mtkprob = OptimizationProblem(sys, nothing; grad = g, hess = h,
sparse = true, cons_j = cons_j, cons_h = cons_h,
cons_sparse = true)
f = mtkprob.f

grad = (G, θ, args...) -> f.grad(G, θ, mtkprob.p, args...)

hess = (H, θ, args...) -> f.hess(H, θ, mtkprob.p, args...)

hv = function (H, θ, v, args...)
res = similar(f.hess_prototype, eltype(θ))
hess(res, θ, args...)
H .= res * v
end
if !isnothing(f.cons)
cons = (res, θ) -> f.cons(res, θ, mtkprob.p)
cons_j = (J, θ) -> f.cons_j(J, θ, mtkprob.p)
cons_h = (res, θ) -> f.cons_h(res, θ, mtkprob.p)
else
cons = nothing
cons_j = nothing
cons_h = nothing
end

return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
cons = cons, cons_j = cons_j, cons_h = cons_h,
hess_prototype = f.hess_prototype,
cons_jac_prototype = f.cons_jac_prototype,
cons_hess_prototype = f.cons_hess_prototype,
expr = OptimizationBase.symbolify(f.expr),
cons_expr = OptimizationBase.symbolify.(f.cons_expr),
sys = sys,
observed = f.observed)
end

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, x, adtype::AutoSymbolics, p,
num_cons = 0; g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
p = isnothing(p) ? SciMLBase.NullParameters() : p

sys = complete(ModelingToolkitBase.modelingtoolkitize(OptimizationProblem(f, x, p;
lcons = fill(0.0,
num_cons),
ucons = fill(0.0,
num_cons))))
#sys = ModelingToolkit.structural_simplify(sys)
# don't need to pass `x` or `p` since they're defaults now
mtkprob = OptimizationProblem(sys, nothing; grad = g, hess = h,
sparse = false, cons_j = cons_j, cons_h = cons_h,
cons_sparse = false)
f = mtkprob.f

grad = (G, θ, args...) -> f.grad(G, θ, mtkprob.p, args...)

hess = (H, θ, args...) -> f.hess(H, θ, mtkprob.p, args...)

hv = function (H, θ, v, args...)
res = ArrayInterface.zeromatrix(θ)
hess(res, θ, args...)
H .= res * v
end

if !isnothing(f.cons)
cons = (res, θ) -> f.cons(res, θ, mtkprob.p)
cons_j = (J, θ) -> f.cons_j(J, θ, mtkprob.p)
cons_h = (res, θ) -> f.cons_h(res, θ, mtkprob.p)
else
cons = nothing
cons_j = nothing
cons_h = nothing
end

return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
cons = cons, cons_j = cons_j, cons_h = cons_h,
hess_prototype = f.hess_prototype,
cons_jac_prototype = f.cons_jac_prototype,
cons_hess_prototype = f.cons_hess_prototype,
expr = OptimizationBase.symbolify(f.expr),
cons_expr = OptimizationBase.symbolify.(f.cons_expr),
sys = sys,
observed = f.observed)
end

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
adtype::AutoSymbolics, num_cons = 0;
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
lag_h = false)
p = isnothing(cache.p) ? SciMLBase.NullParameters() : cache.p

sys = complete(ModelingToolkitBase.modelingtoolkitize(OptimizationProblem(f, cache.u0,
cache.p;
lcons = fill(0.0,
num_cons),
ucons = fill(0.0,
num_cons))))
#sys = ModelingToolkit.structural_simplify(sys)
# don't need to pass `x` or `p` since they're defaults now
mtkprob = OptimizationProblem(sys, nothing; grad = g, hess = h,
sparse = false, cons_j = cons_j, cons_h = cons_h,
cons_sparse = false)
f = mtkprob.f

grad = (G, θ, args...) -> f.grad(G, θ, mtkprob.p, args...)

hess = (H, θ, args...) -> f.hess(H, θ, mtkprob.p, args...)

hv = function (H, θ, v, args...)
res = ArrayInterface.zeromatrix(θ)
hess(res, θ, args...)
H .= res * v
end

if !isnothing(f.cons)
cons = (res, θ) -> f.cons(res, θ, mtkprob.p)
cons_j = (J, θ) -> f.cons_j(J, θ, mtkprob.p)
cons_h = (res, θ) -> f.cons_h(res, θ, mtkprob.p)
else
cons = nothing
cons_j = nothing
cons_h = nothing
end

return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
cons = cons, cons_j = cons_j, cons_h = cons_h,
hess_prototype = f.hess_prototype,
cons_jac_prototype = f.cons_jac_prototype,
cons_hess_prototype = f.cons_hess_prototype,
expr = OptimizationBase.symbolify(f.expr),
cons_expr = OptimizationBase.symbolify.(f.cons_expr),
sys = sys,
observed = f.observed)
end

end
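
For context, here is a minimal usage sketch of the new extension (not part of this PR; it assumes the hypothetical Rosenbrock objective below and that the ModelingToolkitBase weakdep is wired up as described above). Requesting AutoSymbolics routes instantiate_function through the symbolic code path defined in this file:

using OptimizationBase, ModelingToolkitBase, SciMLBase
import OptimizationBase.ADTypes: AutoSymbolics

# Hypothetical test objective; any scalar f(u, p) works.
rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
u0 = zeros(2)
p = [1.0, 100.0]

optf = SciMLBase.OptimizationFunction(rosenbrock, AutoSymbolics())
inst = OptimizationBase.instantiate_function(optf, u0, AutoSymbolics(), p; g = true, h = true)

G = zeros(2)
H = zeros(2, 2)
inst.grad(G, u0)   # gradient generated symbolically via modelingtoolkitize
inst.hess(H, u0)   # dense symbolic Hessian (use AutoSparse(AutoSymbolics()) for the sparse path)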