Skip to content

Commit 0edc12b

Browse files
Merge pull request #112 from SciML/DIv6
Move iterator checking here and make symbolics stuff extension
2 parents 1ebca6f + a25ebfb commit 0edc12b

9 files changed

+169
-127
lines changed

Project.toml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimizationBase"
22
uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
33
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
4-
version = "2.1.0"
4+
version = "2.2.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -17,24 +17,27 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1717
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1818
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
1919
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
20-
SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe"
21-
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
22-
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
2320

2421
[weakdeps]
2522
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
2623
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
2724
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
25+
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
26+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
2827
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
2928
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
29+
SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe"
3030
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3131

3232
[extensions]
3333
OptimizationEnzymeExt = "Enzyme"
3434
OptimizationFiniteDiffExt = "FiniteDiff"
3535
OptimizationForwardDiffExt = "ForwardDiff"
36+
OptimizationMLDataDevicesExt = "MLDataDevices"
37+
OptimizationMLUtilsExt = "MLUtils"
3638
OptimizationMTKExt = "ModelingToolkit"
3739
OptimizationReverseDiffExt = "ReverseDiff"
40+
OptimizationSymbolicAnalysisExt = "SymbolicAnalysis"
3841
OptimizationZygoteExt = "Zygote"
3942

4043
[compat]
@@ -56,8 +59,6 @@ SciMLBase = "2"
5659
SparseConnectivityTracer = "0.6"
5760
SparseMatrixColorings = "0.4"
5861
SymbolicAnalysis = "0.3"
59-
SymbolicIndexingInterface = "0.3"
60-
Symbolics = "5.12, 6"
6162
Zygote = "0.6.67"
6263
julia = "1.10"
6364

ext/OptimizationMLDataDevicesExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module OptimizationMLDataDevicesExt
2+
3+
using MLDataDevices
4+
using OptimizationBase
5+
6+
OptimizationBase.isa_dataiterator(::DeviceIterator) = true
7+
8+
end

ext/OptimizationMLUtilsExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module OptimizationMLUtilsExt
2+
3+
using MLUtils
4+
using OptimizationBase
5+
6+
OptimizationBase.isa_dataiterator(::MLUtils.DataLoader) = true
7+
8+
end
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
module OptimizationSymbolicAnalysisExt
2+
3+
using OptimizationBase, SciMLBase, SymbolicAnalysis, SymbolicAnalysis.Symbolics
4+
using SymbolicAnalysis: AnalysisResult
5+
import Symbolics: variable, Equation, Inequality, unwrap, @variables
6+
7+
function OptimizationBase.symify_cache(
8+
f::OptimizationFunction{iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP,
9+
CJP, CHP, O, EX, CEX, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV},
10+
prob) where {iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O,
11+
EX <: Nothing, CEX <: Nothing, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV}
12+
try
13+
vars = if prob.u0 isa Matrix
14+
@variables X[1:size(prob.u0, 1), 1:size(prob.u0, 2)]
15+
else
16+
ArrayInterface.restructure(
17+
prob.u0, [variable(:x, i) for i in eachindex(prob.u0)])
18+
end
19+
params = if prob.p isa SciMLBase.NullParameters
20+
[]
21+
elseif prob.p isa MTK.MTKParameters
22+
[variable(, i) for i in eachindex(vcat(p...))]
23+
else
24+
ArrayInterface.restructure(p, [variable(, i) for i in eachindex(p)])
25+
end
26+
27+
if prob.u0 isa Matrix
28+
vars = vars[1]
29+
end
30+
31+
obj_expr = f.f(vars, params)
32+
33+
if SciMLBase.isinplace(prob) && !isnothing(prob.f.cons)
34+
lhs = Array{Symbolics.Num}(undef, num_cons)
35+
f.cons(lhs, vars)
36+
cons = Union{Equation, Inequality}[]
37+
38+
if !isnothing(prob.lcons)
39+
for i in 1:num_cons
40+
if !isinf(prob.lcons[i])
41+
if prob.lcons[i] != prob.ucons[i]
42+
push!(cons, prob.lcons[i] lhs[i])
43+
else
44+
push!(cons, lhs[i] ~ prob.ucons[i])
45+
end
46+
end
47+
end
48+
end
49+
50+
if !isnothing(prob.ucons)
51+
for i in 1:num_cons
52+
if !isinf(prob.ucons[i]) && prob.lcons[i] != prob.ucons[i]
53+
push!(cons, lhs[i] prob.ucons[i])
54+
end
55+
end
56+
end
57+
if (isnothing(prob.lcons) || all(isinf, prob.lcons)) &&
58+
(isnothing(prob.ucons) || all(isinf, prob.ucons))
59+
throw(ArgumentError("Constraints passed have no proper bounds defined.
60+
Ensure you pass equal bounds (the scalar that the constraint should evaluate to) for equality constraints
61+
or pass the lower and upper bounds for inequality constraints."))
62+
end
63+
cons_expr = lhs
64+
elseif !isnothing(prob.f.cons)
65+
cons_expr = f.cons(vars, params)
66+
else
67+
cons_expr = nothing
68+
end
69+
catch err
70+
throw(ArgumentError("Automatic symbolic expression generation with failed with error: $err.
71+
Try by setting `structural_analysis = false` instead if the solver doesn't require symbolic expressions."))
72+
end
73+
return obj_expr, cons_expr
74+
end
75+
76+
function analysis(obj_expr, cons_expr)
77+
if obj_expr !== nothing
78+
obj_expr = obj_expr |> Symbolics.unwrap
79+
if manifold === nothing
80+
obj_res = analyze(obj_expr)
81+
else
82+
obj_res = analyze(obj_expr, manifold)
83+
end
84+
@info "Objective Euclidean curvature: $(obj_res.curvature)"
85+
if obj_res.gcurvature !== nothing
86+
@info "Objective Geodesic curvature: $(obj_res.gcurvature)"
87+
end
88+
end
89+
90+
if cons_expr !== nothing
91+
cons_expr = cons_expr .|> Symbolics.unwrap
92+
if manifold === nothing
93+
cons_res = analyze.(cons_expr)
94+
else
95+
cons_res = analyze.(cons_expr, Ref(manifold))
96+
end
97+
for i in 1:num_cons
98+
@info "Constraints Euclidean curvature: $(cons_res[i].curvature)"
99+
100+
if cons_res[i].gcurvature !== nothing
101+
@info "Constraints Geodesic curvature: $(cons_res[i].gcurvature)"
102+
end
103+
end
104+
end
105+
106+
return obj_res, cons_res
107+
end
108+
109+
end

ext/OptimizationZygoteExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ function OptimizationBase.instantiate_function(
332332
function hess(res, θ)
333333
hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p))
334334
end
335-
hess_sparsity = prep_hess.coloring_result.S
335+
hess_sparsity = prep_hess.coloring_result.A
336336
hess_colors = prep_hess.coloring_result.color
337337

338338
if p !== SciMLBase.NullParameters() && p !== nothing
@@ -415,7 +415,7 @@ function OptimizationBase.instantiate_function(
415415
J = vec(J)
416416
end
417417
end
418-
cons_jac_prototype = prep_jac.coloring_result.S
418+
cons_jac_prototype = prep_jac.coloring_result.A
419419
cons_jac_colorvec = prep_jac.coloring_result.color
420420
elseif cons !== nothing && cons_j == true
421421
cons_j! = (J, θ) -> f.cons_j(J, θ, p)
@@ -455,7 +455,7 @@ function OptimizationBase.instantiate_function(
455455
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
456456
for i in 1:num_cons]
457457
colores = getfield.(prep_cons_hess, :coloring_result)
458-
conshess_sparsity = getfield.(colores, :S)
458+
conshess_sparsity = getfield.(colores, :A)
459459
conshess_colors = getfield.(colores, :color)
460460
function cons_h!(H, θ)
461461
for i in 1:num_cons
@@ -474,7 +474,7 @@ function OptimizationBase.instantiate_function(
474474
lag_extras = prepare_hessian(
475475
lagrangian, soadtype, x, Constant(one(eltype(x))),
476476
Constant(ones(eltype(x), num_cons)), Constant(p))
477-
lag_hess_prototype = lag_extras.coloring_result.S
477+
lag_hess_prototype = lag_extras.coloring_result.A
478478
lag_hess_colors = lag_extras.coloring_result.color
479479

480480
function lag_h!(H::AbstractMatrix, θ, σ, λ)

src/OptimizationBase.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,6 @@ if !isdefined(Base, :get_extension)
99
end
1010

1111
using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra
12-
using SymbolicIndexingInterface
13-
using SymbolicAnalysis
14-
using SymbolicAnalysis: AnalysisResult
15-
import Symbolics
16-
import Symbolics: variable, Equation, Inequality, unwrap, @variables
1712
import SciMLBase: OptimizationProblem,
1813
OptimizationFunction, ObjSense,
1914
MaxSense, MinSense, OptimizationStats
@@ -31,6 +26,7 @@ Base.iterate(::NullData, i = 1) = nothing
3126
Base.length(::NullData) = 0
3227

3328
include("adtypes.jl")
29+
include("symify.jl")
3430
include("cache.jl")
3531
include("OptimizationDIExt.jl")
3632
include("OptimizationDISparseExt.jl")

src/OptimizationDISparseExt.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function instantiate_function(
6565
function hess(res, θ)
6666
hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p))
6767
end
68-
hess_sparsity = prep_hess.coloring_result.S
68+
hess_sparsity = prep_hess.coloring_result.A
6969
hess_colors = prep_hess.coloring_result.color
7070

7171
if p !== SciMLBase.NullParameters() && p !== nothing
@@ -147,7 +147,7 @@ function instantiate_function(
147147
J = vec(J)
148148
end
149149
end
150-
cons_jac_prototype = prep_jac.coloring_result.S
150+
cons_jac_prototype = prep_jac.coloring_result.A
151151
cons_jac_colorvec = prep_jac.coloring_result.color
152152
elseif cons_j === true && f.cons !== nothing
153153
cons_j! = (J, θ) -> f.cons_j(J, θ, p)
@@ -185,7 +185,7 @@ function instantiate_function(
185185
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
186186
for i in 1:num_cons]
187187
colores = getfield.(prep_cons_hess, :coloring_result)
188-
conshess_sparsity = getfield.(colores, :S)
188+
conshess_sparsity = getfield.(colores, :A)
189189
conshess_colors = getfield.(colores, :color)
190190
function cons_h!(H, θ)
191191
for i in 1:num_cons
@@ -204,7 +204,7 @@ function instantiate_function(
204204
lag_prep = prepare_hessian(
205205
lagrangian, soadtype, x, Constant(one(eltype(x))),
206206
Constant(ones(eltype(x), num_cons)), Constant(p))
207-
lag_hess_prototype = lag_prep.coloring_result.S
207+
lag_hess_prototype = lag_prep.coloring_result.A
208208
lag_hess_colors = lag_prep.coloring_result.color
209209

210210
function lag_h!(H::AbstractMatrix, θ, σ, λ)
@@ -357,7 +357,7 @@ function instantiate_function(
357357
function hess(θ)
358358
hessian(f.f, prep_hess, soadtype, θ, Constant(p))
359359
end
360-
hess_sparsity = prep_hess.coloring_result.S
360+
hess_sparsity = prep_hess.coloring_result.A
361361
hess_colors = prep_hess.coloring_result.color
362362

363363
if p !== SciMLBase.NullParameters() && p !== nothing
@@ -410,7 +410,7 @@ function instantiate_function(
410410
end
411411
return J
412412
end
413-
cons_jac_prototype = prep_jac.coloring_result.S
413+
cons_jac_prototype = prep_jac.coloring_result.A
414414
cons_jac_colorvec = prep_jac.coloring_result.color
415415
elseif cons_j === true && f.cons !== nothing
416416
cons_j! = (θ) -> f.cons_j(θ, p)
@@ -459,7 +459,7 @@ function instantiate_function(
459459
return H
460460
end
461461
colores = getfield.(prep_cons_hess, :coloring_result)
462-
conshess_sparsity = getfield.(colores, :S)
462+
conshess_sparsity = getfield.(colores, :A)
463463
conshess_colors = getfield.(colores, :color)
464464
elseif cons_h == true && f.cons !== nothing
465465
cons_h! = (res, θ) -> f.cons_h(res, θ, p)
@@ -482,7 +482,7 @@ function instantiate_function(
482482
return hess
483483
end
484484
end
485-
lag_hess_prototype = lag_prep.coloring_result.S
485+
lag_hess_prototype = lag_prep.coloring_result.A
486486
lag_hess_colors = lag_prep.coloring_result.color
487487

488488
if p !== SciMLBase.NullParameters() && p !== nothing

0 commit comments

Comments
 (0)