diff --git a/src/debug.jl b/src/debug.jl index 3593779..9daddd9 100644 --- a/src/debug.jl +++ b/src/debug.jl @@ -1,5 +1,15 @@ using .BipartiteGraphs: Label, BipartiteAdjacencyList + +struct SSAUses{T} + eqs::Vector{T} # equation uses + vars::Vector{T} # variable uses +end +Base.eltype(info::SSAUses{T}) where {T} = T +Base.copy(uses::SSAUses) = SSAUses(copy(uses.eqs), copy(uses.vars)) + +ssa_uses(::SystemStructure) = nothing + struct SystemStructurePrintMatrix <: AbstractMatrix{Union{Label, BipartiteAdjacencyList}} bpg::BipartiteGraph @@ -7,6 +17,7 @@ struct SystemStructurePrintMatrix <: var_to_diff::DiffGraph eq_to_diff::DiffGraph var_eq_matching::Union{Matching, Nothing} + ssa_uses::Union{Nothing, SSAUses} end """ @@ -19,31 +30,36 @@ function SystemStructurePrintMatrix(s::SystemStructure) complete(s.solvable_graph), complete(s.var_to_diff), complete(s.eq_to_diff), - nothing) + nothing, + ssa_uses(s)) end -Base.size(bgpm::SystemStructurePrintMatrix) = (max(nsrcs(bgpm.bpg), ndsts(bgpm.bpg)) + 1, 9) +Base.size(bgpm::SystemStructurePrintMatrix) = (max(nsrcs(bgpm.bpg), ndsts(bgpm.bpg)) + 1, 9 + 2(bgpm.ssa_uses !== nothing)) function compute_diff_label(diff_graph, i, symbol) di = i - 1 <= length(diff_graph) ? diff_graph[i - 1] : nothing return di === nothing ? Label("") : Label(string(di, symbol)) end function Base.getindex(bgpm::SystemStructurePrintMatrix, i::Integer, j::Integer) checkbounds(bgpm, i, j) + if bgpm.ssa_uses === nothing + # Skip SSAUse-related columns. + j += j ≥ 5 + j ≥ 11 + end if i <= 1 - return (Label.(("# eq", "∂ₜ", " ", " ", "", "# v", "∂ₜ", " ", " ")))[j] - elseif j == 5 + return (Label.(("# eq", "∂ₜ", " ", " ", "%", "", "# v", "∂ₜ", " ", " ", "%")))[j] + elseif j == 6 colors = Base.text_colors return Label("|", :light_black) elseif j == 2 return compute_diff_label(bgpm.eq_to_diff, i, '↑') elseif j == 3 return compute_diff_label(invview(bgpm.eq_to_diff), i, '↓') - elseif j == 7 - return compute_diff_label(bgpm.var_to_diff, i, '↑') elseif j == 8 + return compute_diff_label(bgpm.var_to_diff, i, '↑') + elseif j == 9 return compute_diff_label(invview(bgpm.var_to_diff), i, '↓') elseif j == 1 return Label((i - 1 <= length(bgpm.eq_to_diff)) ? string(i - 1) : "") - elseif j == 6 + elseif j == 7 return Label((i - 1 <= length(bgpm.var_to_diff)) ? string(i - 1) : "") elseif j == 4 return BipartiteAdjacencyList( @@ -56,7 +72,11 @@ function Base.getindex(bgpm::SystemStructurePrintMatrix, i::Integer, j::Integer) bgpm.var_eq_matching !== nothing && (i - 1 <= length(invview(bgpm.var_eq_matching))) ? invview(bgpm.var_eq_matching)[i - 1] : unassigned) - elseif j == 9 + elseif j == 5 + return get(bgpm.ssa_uses.eqs, i - 1, Label("")) + elseif j == 11 + return get(bgpm.ssa_uses.vars, i - 1, Label("")) + elseif j == 10 match = unassigned if bgpm.var_eq_matching !== nothing && i - 1 <= length(bgpm.var_eq_matching) match = bgpm.var_eq_matching[i - 1] @@ -95,7 +115,11 @@ end struct MatchedSystemStructure structure::SystemStructure var_eq_matching::Matching + ssa_uses::Union{Nothing, SSAUses} end +MatchedSystemStructure(structure, var_eq_matching) = MatchedSystemStructure(structure, var_eq_matching, nothing) + +ssa_uses(ms::MatchedSystemStructure) = ms.ssa_uses """ Create a SystemStructurePrintMatrix to display the contents @@ -106,12 +130,12 @@ function SystemStructurePrintMatrix(ms::MatchedSystemStructure) complete(ms.structure.solvable_graph), complete(ms.structure.var_to_diff), complete(ms.structure.eq_to_diff), - complete(ms.var_eq_matching, - nsrcs(ms.structure.graph))) + complete(ms.var_eq_matching, nsrcs(ms.structure.graph)), + ms.ssa_uses) end function Base.copy(ms::MatchedSystemStructure) - MatchedSystemStructure(Base.copy(ms.structure), Base.copy(ms.var_eq_matching)) + MatchedSystemStructure(copy(ms.structure), copy(ms.var_eq_matching), copy(ms.ssa_uses)) end function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure) @@ -136,4 +160,11 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure) printstyled(io, string(" ", symbol); color) printstyled(io, string(" ", label)) end + if ms.ssa_uses !== nothing + T = eltype(ms.ssa_uses) + (symbol, label, color) = BipartiteGraphs.overview_label(T) + print(io, " | ") + printstyled(io, string(" ", symbol); color) + printstyled(io, string(" ", label)) + end end diff --git a/src/interface.jl b/src/interface.jl index 1cd7523..a572fca 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -5,7 +5,7 @@ abstract type TransformationState{T} end abstract type AbstractTearingState{T} <: TransformationState{T} end struct SelectedState end -BipartiteGraphs.overview_label(::Type{SelectedState}) = ('∫', " Seleced State", :cyan) +BipartiteGraphs.overview_label(::Type{SelectedState}) = ('∫', " Selected State", :cyan) function linear_subsys_adjmat! end function eq_derivative! end