Skip to content

Commit 5cdad75

Browse files
authored
Implement fast coloring without decompression prep (#200)
1 parent e5ce73c commit 5cdad75

File tree

5 files changed

+180
-71
lines changed

5 files changed

+180
-71
lines changed

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ SparseMatrixColorings
1515

1616
```@docs
1717
coloring
18+
fast_coloring
1819
ColoringProblem
1920
GreedyColoringAlgorithm
2021
ConstantColoringAlgorithm

src/SparseMatrixColorings.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ export NaturalOrder, RandomOrder, LargestFirst
5959
export DynamicDegreeBasedOrder, SmallestLast, IncidenceDegree, DynamicLargestFirst
6060
export ColoringProblem, GreedyColoringAlgorithm, AbstractColoringResult
6161
export ConstantColoringAlgorithm
62-
export coloring
62+
export coloring, fast_coloring
6363
export column_colors, row_colors, ncolors
6464
export column_groups, row_groups
6565
export sparsity_pattern

src/interface.jl

Lines changed: 117 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,12 @@ function GreedyColoringAlgorithm(
116116
return GreedyColoringAlgorithm{decompression,typeof(order)}(order, postprocessing)
117117
end
118118

119+
## Coloring
120+
121+
abstract type WithOrWithoutResult end
122+
struct WithResult <: WithOrWithoutResult end
123+
struct WithoutResult <: WithOrWithoutResult end
124+
119125
"""
120126
coloring(
121127
S::AbstractMatrix,
@@ -175,99 +181,166 @@ julia> collect.(column_groups(result))
175181
- [`compress`](@ref)
176182
- [`decompress`](@ref)
177183
"""
178-
function coloring end
179-
180184
function coloring(
181185
A::AbstractMatrix,
182-
::ColoringProblem{:nonsymmetric,:column},
186+
problem::ColoringProblem,
187+
algo::GreedyColoringAlgorithm;
188+
decompression_eltype::Type{R}=Float64,
189+
symmetric_pattern::Bool=false,
190+
) where {R}
191+
return _coloring(WithResult(), A, problem, algo, R, symmetric_pattern)
192+
end
193+
194+
"""
195+
fast_coloring(
196+
S::AbstractMatrix,
197+
problem::ColoringProblem,
198+
algo::GreedyColoringAlgorithm;
199+
[symmetric_pattern=false]
200+
)
201+
202+
Solve a [`ColoringProblem`](@ref) on the matrix `S` with a [`GreedyColoringAlgorithm`](@ref) and return
203+
204+
- a single color vector for `:column` and `:row` problems
205+
- a tuple of color vectors for `:bidirectional` problems
206+
207+
This function is very similar to [`coloring`](@ref), but it skips the computation of an [`AbstractColoringResult`](@ref) to speed things up.
208+
209+
# See also
210+
211+
- [`coloring`](@ref)
212+
"""
213+
function fast_coloring(
214+
A::AbstractMatrix,
215+
problem::ColoringProblem,
183216
algo::GreedyColoringAlgorithm;
184-
decompression_eltype::Type=Float64,
185217
symmetric_pattern::Bool=false,
218+
)
219+
return _coloring(WithoutResult(), A, problem, algo, Float64, symmetric_pattern)
220+
end
221+
222+
function _coloring(
223+
speed_setting::WithOrWithoutResult,
224+
A::AbstractMatrix,
225+
::ColoringProblem{:nonsymmetric,:column},
226+
algo::GreedyColoringAlgorithm,
227+
decompression_eltype::Type,
228+
symmetric_pattern::Bool,
186229
)
187230
bg = BipartiteGraph(
188231
A; symmetric_pattern=symmetric_pattern || A isa Union{Symmetric,Hermitian}
189232
)
190233
color = partial_distance2_coloring(bg, Val(2), algo.order)
191-
return ColumnColoringResult(A, bg, color)
234+
if speed_setting isa WithResult
235+
return ColumnColoringResult(A, bg, color)
236+
else
237+
return color
238+
end
192239
end
193240

194-
function coloring(
241+
function _coloring(
242+
speed_setting::WithOrWithoutResult,
195243
A::AbstractMatrix,
196244
::ColoringProblem{:nonsymmetric,:row},
197-
algo::GreedyColoringAlgorithm;
198-
decompression_eltype::Type=Float64,
199-
symmetric_pattern::Bool=false,
245+
algo::GreedyColoringAlgorithm,
246+
decompression_eltype::Type,
247+
symmetric_pattern::Bool,
200248
)
201249
bg = BipartiteGraph(
202250
A; symmetric_pattern=symmetric_pattern || A isa Union{Symmetric,Hermitian}
203251
)
204252
color = partial_distance2_coloring(bg, Val(1), algo.order)
205-
return RowColoringResult(A, bg, color)
253+
if speed_setting isa WithResult
254+
return RowColoringResult(A, bg, color)
255+
else
256+
return color
257+
end
206258
end
207259

208-
function coloring(
260+
function _coloring(
261+
speed_setting::WithOrWithoutResult,
209262
A::AbstractMatrix,
210263
::ColoringProblem{:symmetric,:column},
211-
algo::GreedyColoringAlgorithm{:direct};
212-
decompression_eltype::Type=Float64,
264+
algo::GreedyColoringAlgorithm{:direct},
265+
decompression_eltype::Type,
266+
symmetric_pattern::Bool,
213267
)
214268
ag = AdjacencyGraph(A)
215269
color, star_set = star_coloring(ag, algo.order; postprocessing=algo.postprocessing)
216-
return StarSetColoringResult(A, ag, color, star_set)
270+
if speed_setting isa WithResult
271+
return StarSetColoringResult(A, ag, color, star_set)
272+
else
273+
return color
274+
end
217275
end
218276

219-
function coloring(
277+
function _coloring(
278+
speed_setting::WithOrWithoutResult,
220279
A::AbstractMatrix,
221280
::ColoringProblem{:symmetric,:column},
222-
algo::GreedyColoringAlgorithm{:substitution};
223-
decompression_eltype::Type=Float64,
224-
)
281+
algo::GreedyColoringAlgorithm{:substitution},
282+
decompression_eltype::Type{R},
283+
symmetric_pattern::Bool,
284+
) where {R}
225285
ag = AdjacencyGraph(A)
226286
color, tree_set = acyclic_coloring(ag, algo.order; postprocessing=algo.postprocessing)
227-
return TreeSetColoringResult(A, ag, color, tree_set, decompression_eltype)
287+
if speed_setting isa WithResult
288+
return TreeSetColoringResult(A, ag, color, tree_set, R)
289+
else
290+
return color
291+
end
228292
end
229293

230-
function coloring(
294+
function _coloring(
295+
speed_setting::WithOrWithoutResult,
231296
A::AbstractMatrix,
232297
::ColoringProblem{:nonsymmetric,:bidirectional},
233-
algo::GreedyColoringAlgorithm{decompression};
234-
decompression_eltype::Type{R}=Float64,
235-
symmetric_pattern::Bool=false,
236-
) where {decompression,R}
298+
algo::GreedyColoringAlgorithm{:direct},
299+
decompression_eltype::Type{R},
300+
symmetric_pattern::Bool,
301+
) where {R}
237302
A_and_Aᵀ = bidirectional_pattern(A; symmetric_pattern)
238303
ag = AdjacencyGraph(A_and_Aᵀ; has_diagonal=false)
239-
240-
if decompression == :direct
241-
color, star_set = star_coloring(ag, algo.order; postprocessing=algo.postprocessing)
304+
color, star_set = star_coloring(ag, algo.order; postprocessing=algo.postprocessing)
305+
if speed_setting isa WithResult
242306
symmetric_result = StarSetColoringResult(A_and_Aᵀ, ag, color, star_set)
307+
return BicoloringResult(A, ag, symmetric_result, R)
243308
else
244-
color, tree_set = acyclic_coloring(
245-
ag, algo.order; postprocessing=algo.postprocessing
246-
)
247-
symmetric_result = TreeSetColoringResult(
248-
A_and_Aᵀ, ag, color, tree_set, decompression_eltype
249-
)
309+
row_color, column_color, _ = remap_colors(color, maximum(color), size(A)...)
310+
return row_color, column_color
311+
end
312+
end
313+
314+
function _coloring(
315+
speed_setting::WithOrWithoutResult,
316+
A::AbstractMatrix,
317+
::ColoringProblem{:nonsymmetric,:bidirectional},
318+
algo::GreedyColoringAlgorithm{:substitution},
319+
decompression_eltype::Type{R},
320+
symmetric_pattern::Bool,
321+
) where {R}
322+
A_and_Aᵀ = bidirectional_pattern(A; symmetric_pattern)
323+
ag = AdjacencyGraph(A_and_Aᵀ; has_diagonal=false)
324+
color, tree_set = acyclic_coloring(ag, algo.order; postprocessing=algo.postprocessing)
325+
if speed_setting isa WithResult
326+
symmetric_result = TreeSetColoringResult(A_and_Aᵀ, ag, color, tree_set, R)
327+
return BicoloringResult(A, ag, symmetric_result, R)
328+
else
329+
row_color, column_color, _ = remap_colors(color, maximum(color), size(A)...)
330+
return row_color, column_color
250331
end
251-
return BicoloringResult(A, ag, symmetric_result, decompression_eltype)
252332
end
253333

254334
## ADTypes interface
255335

256336
function ADTypes.column_coloring(A::AbstractMatrix, algo::GreedyColoringAlgorithm)
257-
bg = BipartiteGraph(A; symmetric_pattern=A isa Union{Symmetric,Hermitian})
258-
color = partial_distance2_coloring(bg, Val(2), algo.order)
259-
return color
337+
return fast_coloring(A, ColoringProblem{:nonsymmetric,:column}(), algo)
260338
end
261339

262340
function ADTypes.row_coloring(A::AbstractMatrix, algo::GreedyColoringAlgorithm)
263-
bg = BipartiteGraph(A; symmetric_pattern=A isa Union{Symmetric,Hermitian})
264-
color = partial_distance2_coloring(bg, Val(1), algo.order)
265-
return color
341+
return fast_coloring(A, ColoringProblem{:nonsymmetric,:row}(), algo)
266342
end
267343

268344
function ADTypes.symmetric_coloring(A::AbstractMatrix, algo::GreedyColoringAlgorithm)
269-
ag = AdjacencyGraph(A)
270-
# never postprocess because end users do not expect zeros
271-
color, star_set = star_coloring(ag, algo.order; postprocessing=false)
272-
return color
345+
return fast_coloring(A, ColoringProblem{:symmetric,:column}(), algo)
273346
end

test/small.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ using Test
2727
color0 = [1, 1, 2]
2828
@test structurally_orthogonal_columns(A0, color0)
2929
@test directly_recoverable_columns(A0, color0)
30-
test_coloring_decompression(A0, problem, algo; B0, color0)
30+
test_coloring_decompression(A0, problem, algo; B0, color0, test_fast=true)
3131
end;
3232

3333
@testset "Row coloring & decompression" begin
@@ -45,7 +45,7 @@ end;
4545
color0 = [1, 1, 2]
4646
@test structurally_orthogonal_columns(transpose(A0), color0)
4747
@test directly_recoverable_columns(transpose(A0), color0)
48-
test_coloring_decompression(A0, problem, algo; B0, color0)
48+
test_coloring_decompression(A0, problem, algo; B0, color0, test_fast=true)
4949
end;
5050

5151
@testset "Symmetric coloring & direct decompression" begin
@@ -57,15 +57,15 @@ end;
5757
A0, B0, color0 = example.A, example.B, example.color
5858
@test symmetrically_orthogonal_columns(A0, color0)
5959
@test directly_recoverable_columns(A0, color0)
60-
test_coloring_decompression(A0, problem, algo; B0, color0)
60+
test_coloring_decompression(A0, problem, algo; B0, color0, test_fast=true)
6161
end
6262

6363
@testset "Fig 1 from 'Efficient computation of sparse hessians using coloring and AD'" begin
6464
example = efficient_fig_1()
6565
A0, B0, color0 = example.A, example.B, example.color
6666
@test symmetrically_orthogonal_columns(A0, color0)
6767
@test directly_recoverable_columns(A0, color0)
68-
test_coloring_decompression(A0, problem, algo; B0, color0)
68+
test_coloring_decompression(A0, problem, algo; B0, color0, test_fast=true)
6969
end
7070
end;
7171

@@ -77,13 +77,13 @@ end;
7777
example = what_fig_61()
7878
A0, B0, color0 = example.A, example.B, example.color
7979
# our coloring doesn't give the color0 from the example, but that's okay
80-
test_coloring_decompression(A0, problem, algo)
80+
test_coloring_decompression(A0, problem, algo; test_fast=true)
8181
end
8282

8383
@testset "Fig 4 from 'Efficient computation of sparse hessians using coloring and AD'" begin
8484
example = efficient_fig_4()
8585
A0, B0, color0 = example.A, example.B, example.color
86-
test_coloring_decompression(A0, problem, algo; B0, color0)
86+
test_coloring_decompression(A0, problem, algo; B0, color0, test_fast=true)
8787
end
8888
end;
8989

@@ -180,5 +180,19 @@ end;
180180
A, problem, GreedyColoringAlgorithm{:direct}(order; postprocessing=true)
181181
),
182182
)
183+
184+
test_bicoloring_decompression(
185+
A,
186+
problem,
187+
GreedyColoringAlgorithm{:direct}(order; postprocessing=true);
188+
test_fast=true,
189+
)
190+
191+
test_bicoloring_decompression(
192+
A,
193+
problem,
194+
GreedyColoringAlgorithm{:substitution}(order; postprocessing=true);
195+
test_fast=true,
196+
)
183197
end
184-
end
198+
end;

0 commit comments

Comments
 (0)