Skip to content

Commit 1cb8a90

Browse files
Merge pull request #116 from SciML/DIv6
fix symbolic analysis dispatches
2 parents 0edc12b + 8acf17c commit 1cb8a90

File tree

8 files changed

+83
-70
lines changed

8 files changed

+83
-70
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
name = "OptimizationBase"
22
uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
33
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
4-
version = "2.2.0"
4+
version = "2.2.1"
5+
56

67
[deps]
78
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/OptimizationSymbolicAnalysisExt.jl

Lines changed: 63 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,84 @@
11
module OptimizationSymbolicAnalysisExt
22

3-
using OptimizationBase, SciMLBase, SymbolicAnalysis, SymbolicAnalysis.Symbolics
3+
using OptimizationBase, SciMLBase, SymbolicAnalysis, SymbolicAnalysis.Symbolics,
4+
OptimizationBase.ArrayInterface
45
using SymbolicAnalysis: AnalysisResult
5-
import Symbolics: variable, Equation, Inequality, unwrap, @variables
6+
import SymbolicAnalysis.Symbolics: variable, Equation, Inequality, unwrap, @variables
67

78
function OptimizationBase.symify_cache(
89
f::OptimizationFunction{iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP,
910
CJP, CHP, O, EX, CEX, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV},
10-
prob) where {iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O,
11+
prob, num_cons,
12+
manifold) where {
13+
iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O,
1114
EX <: Nothing, CEX <: Nothing, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV}
12-
try
13-
vars = if prob.u0 isa Matrix
14-
@variables X[1:size(prob.u0, 1), 1:size(prob.u0, 2)]
15-
else
16-
ArrayInterface.restructure(
17-
prob.u0, [variable(:x, i) for i in eachindex(prob.u0)])
18-
end
19-
params = if prob.p isa SciMLBase.NullParameters
20-
[]
21-
elseif prob.p isa MTK.MTKParameters
22-
[variable(, i) for i in eachindex(vcat(p...))]
23-
else
24-
ArrayInterface.restructure(p, [variable(, i) for i in eachindex(p)])
25-
end
15+
obj_expr = f.expr
16+
cons_expr = f.cons_expr === nothing ? nothing : getfield.(f.cons_expr, Ref(:lhs))
2617

27-
if prob.u0 isa Matrix
28-
vars = vars[1]
29-
end
18+
if obj_expr === nothing || cons_expr === nothing
19+
try
20+
vars = if prob.u0 isa Matrix
21+
@variables X[1:size(prob.u0, 1), 1:size(prob.u0, 2)]
22+
else
23+
ArrayInterface.restructure(
24+
prob.u0, [variable(:x, i) for i in eachindex(prob.u0)])
25+
end
26+
params = if prob.p isa SciMLBase.NullParameters
27+
[]
28+
elseif prob.p isa MTK.MTKParameters
29+
[variable(, i) for i in eachindex(vcat(p...))]
30+
else
31+
ArrayInterface.restructure(p, [variable(, i) for i in eachindex(p)])
32+
end
33+
34+
if prob.u0 isa Matrix
35+
vars = vars[1]
36+
end
3037

31-
obj_expr = f.f(vars, params)
38+
if obj_expr === nothing
39+
obj_expr = f.f(vars, params)
40+
end
3241

33-
if SciMLBase.isinplace(prob) && !isnothing(prob.f.cons)
34-
lhs = Array{Symbolics.Num}(undef, num_cons)
35-
f.cons(lhs, vars)
36-
cons = Union{Equation, Inequality}[]
42+
if cons_expr === nothing && SciMLBase.isinplace(prob) && !isnothing(prob.f.cons)
43+
lhs = Array{Symbolics.Num}(undef, num_cons)
44+
f.cons(lhs, vars)
45+
cons = Union{Equation, Inequality}[]
3746

38-
if !isnothing(prob.lcons)
39-
for i in 1:num_cons
40-
if !isinf(prob.lcons[i])
41-
if prob.lcons[i] != prob.ucons[i]
42-
push!(cons, prob.lcons[i] lhs[i])
43-
else
44-
push!(cons, lhs[i] ~ prob.ucons[i])
47+
if !isnothing(prob.lcons)
48+
for i in 1:num_cons
49+
if !isinf(prob.lcons[i])
50+
if prob.lcons[i] != prob.ucons[i]
51+
push!(cons, prob.lcons[i] lhs[i])
52+
else
53+
push!(cons, lhs[i] ~ prob.ucons[i])
54+
end
4555
end
4656
end
4757
end
48-
end
4958

50-
if !isnothing(prob.ucons)
51-
for i in 1:num_cons
52-
if !isinf(prob.ucons[i]) && prob.lcons[i] != prob.ucons[i]
53-
push!(cons, lhs[i] prob.ucons[i])
59+
if !isnothing(prob.ucons)
60+
for i in 1:num_cons
61+
if !isinf(prob.ucons[i]) && prob.lcons[i] != prob.ucons[i]
62+
push!(cons, lhs[i] prob.ucons[i])
63+
end
5464
end
5565
end
66+
if (isnothing(prob.lcons) || all(isinf, prob.lcons)) &&
67+
(isnothing(prob.ucons) || all(isinf, prob.ucons))
68+
throw(ArgumentError("Constraints passed have no proper bounds defined.
69+
Ensure you pass equal bounds (the scalar that the constraint should evaluate to) for equality constraints
70+
or pass the lower and upper bounds for inequality constraints."))
71+
end
72+
cons_expr = lhs
73+
elseif cons_expr === nothing && !isnothing(prob.f.cons)
74+
cons_expr = f.cons(vars, params)
5675
end
57-
if (isnothing(prob.lcons) || all(isinf, prob.lcons)) &&
58-
(isnothing(prob.ucons) || all(isinf, prob.ucons))
59-
throw(ArgumentError("Constraints passed have no proper bounds defined.
60-
Ensure you pass equal bounds (the scalar that the constraint should evaluate to) for equality constraints
61-
or pass the lower and upper bounds for inequality constraints."))
62-
end
63-
cons_expr = lhs
64-
elseif !isnothing(prob.f.cons)
65-
cons_expr = f.cons(vars, params)
66-
else
67-
cons_expr = nothing
76+
catch err
77+
throw(ArgumentError("Automatic symbolic expression generation with failed with error: $err.
78+
Try by setting `structural_analysis = false` instead if the solver doesn't require symbolic expressions."))
6879
end
69-
catch err
70-
throw(ArgumentError("Automatic symbolic expression generation with failed with error: $err.
71-
Try by setting `structural_analysis = false` instead if the solver doesn't require symbolic expressions."))
7280
end
73-
return obj_expr, cons_expr
74-
end
7581

76-
function analysis(obj_expr, cons_expr)
7782
if obj_expr !== nothing
7883
obj_expr = obj_expr |> Symbolics.unwrap
7984
if manifold === nothing
@@ -85,6 +90,8 @@ function analysis(obj_expr, cons_expr)
8590
if obj_res.gcurvature !== nothing
8691
@info "Objective Geodesic curvature: $(obj_res.gcurvature)"
8792
end
93+
else
94+
obj_res = nothing
8895
end
8996

9097
if cons_expr !== nothing
@@ -101,6 +108,8 @@ function analysis(obj_expr, cons_expr)
101108
@info "Constraints Geodesic curvature: $(cons_res[i].gcurvature)"
102109
end
103110
end
111+
else
112+
cons_res = nothing
104113
end
105114

106115
return obj_res, cons_res

src/cache.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,7 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;
5050
cons_vjp = SciMLBase.allowsconsjvp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt))
5151

5252
if structural_analysis
53-
obj_expr, cons_expr = symify_cache(f, prob)
54-
try
55-
obj_res, cons_res = analysis(obj_expr, cons_expr)
56-
catch err
57-
throw("Structural analysis requires SymbolicAnalysis.jl to be loaded, either add `using SymbolicAnalysis` to your script or set `structural_analysis = false`.")
58-
end
53+
obj_res, cons_res = symify_cache(f, prob, num_cons, manifold)
5954
else
6055
obj_res = nothing
6156
cons_res = nothing

src/symify.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
function symify_cache(f::OptimizationFunction, prob)
2-
obj_expr = f.expr
3-
cons_expr = f.cons_expr === nothing ? nothing : getfield.(f.cons_expr, Ref(:lhs))
4-
5-
return obj_expr, cons_expr
1+
function symify_cache(f::OptimizationFunction, prob, num_cons, manifold)
2+
throw("Structural analysis requires SymbolicAnalysis.jl to be loaded, either add `using SymbolicAnalysis` to your script or set `structural_analysis = false`.")
63
end

test/Project.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,22 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
99
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1010
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1111
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
12+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
1314
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1415
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
1516
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
1617
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1718
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
19+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
20+
OptimizationManopt = "e57b7fff-7ee7-4550-b4f0-90e9476e9fb6"
1821
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1922
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2023
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2124
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
2225
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
26+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
27+
SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe"
2328
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
2429
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2530
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -35,4 +40,8 @@ Lux = ">= 0.4.50"
3540
Manifolds = "0.9"
3641
Optim = ">= 1.4.1"
3742
Optimisers = ">= 0.2.5"
43+
Optimization = "4"
44+
OptimizationManopt = "0.0.4"
45+
SparseConnectivityTracer = "0.6"
46+
SymbolicAnalysis = "0.3.0"
3847
SafeTestsets = ">= 0.0.1"

test/cvxtest.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ optf = OptimizationFunction(rosenbrock, AutoZygote(), cons = con2_c)
2929
prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf], ucons = [1.0, 0.0],
3030
lb = [-1.0, -1.0], ub = [1.0, 1.0], structural_analysis = true)
3131
@time res = solve(prob, Optimization.LBFGS(), maxiters = 100)
32-
@test res.cache.analysis_results.objective.curvature == SymbolicAnalysis.Convex
32+
@test res.cache.analysis_results.objective.curvature == SymbolicAnalysis.UnknownCurvature
3333
@test res.cache.analysis_results.constraints[1].curvature == SymbolicAnalysis.Convex
3434
@test res.cache.analysis_results.constraints[2].curvature ==
3535
SymbolicAnalysis.UnknownCurvature
@@ -46,7 +46,7 @@ optf = OptimizationFunction(f, Optimization.AutoForwardDiff())
4646
prob = OptimizationProblem(optf, data2[1]; manifold = M, structural_analysis = true)
4747

4848
opt = OptimizationManopt.GradientDescentOptimizer()
49-
@time sol = solve(prob, Optimization.LBFGS(), maxiters = 100)
50-
@test sol.minimizer < 1e-1
49+
@time sol = solve(prob, opt, maxiters = 100)
50+
@test sol.minimum < 1e-1
5151
@test sol.cache.analysis_results.objective.curvature == SymbolicAnalysis.UnknownCurvature
5252
@test sol.cache.analysis_results.objective.gcurvature == SymbolicAnalysis.GConvex

test/matrixvalued.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ using Test, ReverseDiff
7878
Omega_mc = rand(4, 4) .> 0.5 # Mask for observed entries (boolean matrix)
7979
X_mc = rand(4, 4) # Matrix to be completed
8080
optf = OptimizationFunction{false}(
81-
matrix_completion_objective, adtype, cons = rank_constraint)
81+
matrix_completion_objective, adtype)
8282
optf = OptimizationBase.instantiate_function(
8383
optf, X_mc, adtype, (A_mc, Omega_mc), g = true, h = true)
8484
optf.grad(X_mc)

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@ using Test
33

44
@testset "OptimizationBase.jl" begin
55
include("adtests.jl")
6+
include("cvxtest.jl")
7+
include("matrixvalued.jl")
68
end

0 commit comments

Comments
 (0)