Skip to content

Commit e5ae5a2

Browse files
authored
UAI reference solution comparison tests (#41)
* Update Artifacts.toml * Add MAP UAI reference solution comparison test * Implement `read_query_file` * Make `get_instance_filepaths` return the query file path as well * Add a `queryvars` field to UAIInstance composite type * Add MMAP UAI reference solution comparison test * Add comments beside MMAP problems that fail the test * Disable Promedus MAR tests * Treat reading PR sol file as a separate case and simplify tests * Refactor `read_solution_file` function * Print problem name being tested for all "UAI Reference Sols Comp" testsets * Move `get_problems_names` function to `test/utils.jl` * Fix typo in function name * Parse PR solution as Float64 * Minor * Enable Pedigree MAR UAI reference tests * Add PR UAI reference solution comparison test * Update the UAI file format URL
1 parent 7e884de commit e5ae5a2

File tree

13 files changed

+155
-53
lines changed

13 files changed

+155
-53
lines changed

benchmark/bench_map.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Artifacts
66

77
const SUITE = BenchmarkGroup()
88

9-
model_filepath, evidence_filepath, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
9+
model_filepath, evidence_filepath, _, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
1010
problem = read_instance(model_filepath; evidence_filepath, solution_filepath)
1111

1212
optimizer = TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40)

benchmark/bench_mar.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using Artifacts
88

99
const SUITE = BenchmarkGroup()
1010

11-
model_filepath, evidence_filepath, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
11+
model_filepath, evidence_filepath, _, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
1212
problem = read_instance(model_filepath; evidence_filepath, solution_filepath)
1313

1414
optimizer = TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)

benchmark/bench_mmap.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Artifacts
66

77
const SUITE = BenchmarkGroup()
88

9-
model_filepath, evidence_filepath, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
9+
model_filepath, evidence_filepath, _, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
1010
problem = read_instance(model_filepath; evidence_filepath, solution_filepath)
1111
optimizer = TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40)
1212

src/Core.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ $(TYPEDEF)
2424
2525
* `obsvars` is a vector of observed variables,
2626
* `obsvals` is a vector of observed values,
27+
* `queryvars` is a vector of query variables,
2728
* `reference_solution` is a vector with the reference solution.
2829
"""
2930
struct UAIInstance{ET, FT <: Factor{ET}}
@@ -34,7 +35,8 @@ struct UAIInstance{ET, FT <: Factor{ET}}
3435

3536
obsvars::Vector{Int}
3637
obsvals::Vector{Int}
37-
reference_solution::Union{Vector{Vector{ET}}, Vector{Int}}
38+
queryvars::Vector{Int}
39+
reference_solution::Union{Vector{Vector{ET}}, Vector{Int}, Float64}
3840
end
3941

4042
"""

src/utils.jl

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Parse the problem instance found in `model_filepath` defined in the UAI model
55
format.
66
77
The UAI file formats are defined in:
8-
https://personal.utdallas.edu/~vibhav.gogate/uai16-evaluation/uaiformat.html
8+
https://uaicompetition.github.io/uci-2022/file-formats/
99
"""
1010
function read_model_file(model_filepath; factor_eltype = Float64)
1111
# Read the uai file into an array of lines
@@ -69,7 +69,7 @@ Return the observed variables and values in `evidence_filepath`. If the passed
6969
file path is an empty string, return empty vectors.
7070
7171
The UAI file formats are defined in:
72-
https://personal.utdallas.edu/~vibhav.gogate/uai16-evaluation/uaiformat.html
72+
https://uaicompetition.github.io/uci-2022/file-formats/
7373
"""
7474
function read_evidence_file(evidence_filepath::AbstractString)
7575

@@ -93,35 +93,67 @@ function read_evidence_file(evidence_filepath::AbstractString)
9393
return obsvars, obsvals
9494
end
9595

96+
"""
97+
$(TYPEDSIGNATURES)
98+
99+
Return the query variables in `query_filepath`. If the passed file path is an
100+
empty string, return an empty vector.
101+
102+
The UAI file formats are defined in:
103+
https://uaicompetition.github.io/uci-2022/file-formats/
104+
"""
105+
function read_query_file(query_filepath::AbstractString)
106+
isempty(query_filepath) && return Int64[]
107+
108+
# Read the first line of the uai query file
109+
line = open(query_filepath) do file
110+
readlines(file)
111+
end |> first
112+
113+
# Separate the number of query vars and their indices
114+
nqueryvars, queryvars_zero_based = split(line) |> x -> parse.(Int, x) |> x -> (x[1], x[2:end])
115+
116+
# Convert to 1-based indexing
117+
queryvars = queryvars_zero_based .+ 1
118+
119+
@assert nqueryvars == length(queryvars)
120+
121+
return queryvars
122+
end
123+
96124
function read_solution_file(solution_filepath::AbstractString; factor_eltype = Float64)
125+
97126
result = Vector{factor_eltype}[]
98127
extension = splitext(solution_filepath)[2]
128+
129+
# Read the solution file into an array of lines
130+
rawlines = open(solution_filepath) do file
131+
readlines(file)
132+
end
133+
99134
if extension == ".MAR"
100-
return read_mar_solution_file(solution_filepath; factor_eltype)
101-
elseif extension == ".MAP" || extension == ".MMAP" || extension == ".PR"
102-
# Return the last line of the file as a vector of integers
103-
result = open(solution_filepath) do file
104-
readlines(file)
105-
end |> last |> split |> x -> parse.(Int, x)
135+
result = parse_mar_solution_file(rawlines; factor_eltype)
136+
elseif extension == ".MAP" || extension == ".MMAP"
137+
# Return all elements except the first in the last line as a vector of integers
138+
result = last(rawlines) |> split |> x -> x[2:end] |> x -> parse.(Int, x)
139+
elseif extension == ".PR"
140+
# Parse the number in the last line as a floating point
141+
result = last(rawlines) |> x -> parse(Float64, x)
106142
end
143+
107144
return result
108145
end
109146

110147
"""
111148
$(TYPEDSIGNATURES)
112149
113-
Return the marginals of all variables. The order of the variables is the same
114-
as in the model
150+
Parse the solution marginals of all variables from the UAI MAR solution file.
151+
The order of the variables is the same as in the model definition.
115152
116153
The UAI file formats are defined in:
117-
https://personal.utdallas.edu/~vibhav.gogate/uai16-evaluation/uaiformat.html
154+
https://uaicompetition.github.io/uci-2022/file-formats/
118155
"""
119-
function read_mar_solution_file(solution_filepath::AbstractString; factor_eltype = Float64)
120-
121-
# Read the uai mar file into an array of lines
122-
rawlines = open(solution_filepath) do file
123-
readlines(file)
124-
end
156+
function parse_mar_solution_file(rawlines::Vector{String}; factor_eltype = Float64)
125157

126158
parsed_margs = split(rawlines[2]) |> x -> x[2:end] |> x -> parse.(factor_eltype, x)
127159

@@ -192,13 +224,15 @@ Read a UAI problem instance from a file.
192224
function read_instance(
193225
model_filepath::AbstractString;
194226
evidence_filepath::AbstractString = "",
227+
query_filepath::AbstractString = "",
195228
solution_filepath::AbstractString = "",
196229
eltype = Float64
197230
)::UAIInstance
198231
nvars, cards, ncliques, factors = read_model_file(model_filepath; factor_eltype = eltype)
199232
obsvars, obsvals = read_evidence_file(evidence_filepath)
233+
queryvars = read_query_file(query_filepath)
200234
reference_solution = isempty(solution_filepath) ? Vector{eltype}[] : read_solution_file(solution_filepath)
201-
return UAIInstance(nvars, ncliques, cards, factors, obsvars, obsvals, reference_solution)
235+
return UAIInstance(nvars, ncliques, cards, factors, obsvars, obsvals, queryvars, reference_solution)
202236
end
203237

204238
function read_instance_from_string(uai::AbstractString; eltype = Float64)::UAIInstance

test/Artifacts.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[uai2014]
2-
git-tree-sha1 = "d05c5e541cc06f1cb4d8e24c1067e07472e36c24"
2+
git-tree-sha1 = "199ed43697fe22447c6c64a939b222fd4073f2d0"
33

44
[[uai2014.download]]
5-
sha256 = "73c91cd68931aec562499ab66ed2326771b829aa715e790609c6a1b86c9a9ad8"
5+
sha256 = "5d93ced227cff3eb2ae7feb77dcb6c780212b47a0c0355dda8439de6f5b9d369"
66
url = "https://github.com/mroavi/uai-2014-inference-competition/raw/main/uai2014.tar.gz"

test/cuda.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using TensorInference, CUDA
44
CUDA.allowscalar(false)
55

66
@testset "gradient-based tensor network solvers" begin
7-
model_filepath, evidence_filepath, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
7+
model_filepath, evidence_filepath, _, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
88
instance = read_instance(model_filepath; evidence_filepath, solution_filepath)
99

1010
# does not optimize over open vertices
@@ -21,7 +21,7 @@ CUDA.allowscalar(false)
2121
end
2222

2323
@testset "map" begin
24-
model_filepath, evidence_filepath, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
24+
model_filepath, evidence_filepath, _, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
2525
instance = read_instance(model_filepath; evidence_filepath, solution_filepath)
2626

2727
# does not optimize over open vertices
@@ -36,7 +36,7 @@ end
3636
end
3737

3838
@testset "mmap" begin
39-
model_filepath, evidence_filepath, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
39+
model_filepath, evidence_filepath, _, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
4040
instance = read_instance(model_filepath; evidence_filepath, solution_filepath)
4141

4242
optimizer = TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40)

test/map.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using OMEinsum
33
using TensorInference
44

55
@testset "gradient-based tensor network solvers" begin
6-
model_filepath, evidence_filepath, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
6+
model_filepath, evidence_filepath, _, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
77
instance = read_instance(model_filepath; evidence_filepath, solution_filepath)
88

99
# does not optimize over open vertices
@@ -14,3 +14,13 @@ using TensorInference
1414
@test log_probability(tn, config) logp
1515
@test maximum_logp(tn)[] logp
1616
end
17+
18+
@testset "UAI Reference Solution Comparison" begin
19+
problem_name = "Promedas_70"
20+
@info "Testing: $problem_name"
21+
model_filepath, evidence_filepath, _, solution_filepath = get_instance_filepaths(problem_name, "MAP")
22+
instance = read_instance(model_filepath; evidence_filepath, solution_filepath)
23+
tn = TensorNetworkModel(instance; optimizer = TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100))
24+
_, solution = most_probable_config(tn)
25+
@test solution == instance.reference_solution
26+
end

test/mar.jl

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using TensorInference
1212
end
1313

1414
@testset "cached, rescaled contract" begin
15-
model_filepath, evidence_filepath, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
15+
model_filepath, evidence_filepath, _, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
1616
instance = read_instance(model_filepath; evidence_filepath, solution_filepath)
1717
ref_sol = instance.reference_solution
1818
optimizer = TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)
@@ -34,39 +34,27 @@ end
3434
@test isapprox(ti_sol, ref_sol; atol = 1e-5)
3535
end
3636

37-
function get_problems_names(problem_set::String)
38-
# Capture the problem names that belong to the current problem_set
39-
regex = Regex("($(problem_set)_\\d*)(\\.uai)\$")
40-
return readdir(joinpath(artifact"uai2014", "MAR"); sort = false) |>
41-
x -> map(y -> match(regex, y), x) |> # apply regex
42-
x -> filter(!isnothing, x) |> # filter out `nothing` values
43-
x -> map(first, x) # get the first capture of each element
44-
end
45-
46-
@testset "gradient-based tensor network solvers" begin
37+
@testset "UAI Reference Solution Comparison" begin
4738
problem_sets = [
4839
#("Alchemy", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)),
4940
#("CSP", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)),
5041
#("DBN", KaHyParBipartite(sc_target = 25)),
5142
#("Grids", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)), # greedy also works
5243
#("linkage", TreeSA(ntrials = 3, niters = 20, βs = 0.1:0.1:40)), # linkage_15 fails
5344
#("ObjectDetection", TreeSA(ntrials = 1, niters = 5, βs = 1:0.1:100)),
54-
#("Pedigree", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)), # greedy also works
55-
("Promedus", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)), # greedy also works
45+
("Pedigree", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)), # greedy also works
46+
#("Promedus", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)), # greedy also works
5647
#("relational", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)),
5748
("Segmentation", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)) # greedy also works
5849
]
59-
6050
for (problem_set, optimizer) in problem_sets
61-
@testset "$(problem_set) problem_set" begin
62-
51+
@testset "$(problem_set) problem set" begin
6352
# Capture the problem names that belong to the current problem set
64-
problem_names = get_problems_names(problem_set)
65-
53+
problem_names = get_problem_names(problem_set, "MAR")
6654
for problem_name in problem_names
6755
@info "Testing: $problem_name"
6856
@testset "$(problem_name)" begin
69-
model_filepath, evidence_filepath, solution_filepath = get_instance_filepaths(problem_name, "MAR")
57+
model_filepath, evidence_filepath, _, solution_filepath = get_instance_filepaths(problem_name, "MAR")
7058
instance = read_instance(model_filepath; evidence_filepath, solution_filepath)
7159
ref_sol = instance.reference_solution
7260
obsvars = instance.obsvars

test/mmap.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,42 @@ using TensorInference
88
end
99

1010
@testset "gradient-based tensor network solvers" begin
11-
model_filepath, evidence_filepath, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
11+
model_filepath, evidence_filepath, _, solution_filepath = get_instance_filepaths("Promedus_14", "MAR")
1212
instance = read_instance(model_filepath; evidence_filepath, solution_filepath)
1313

1414
optimizer = TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40)
1515
tn_ref = TensorNetworkModel(instance; optimizer)
16-
# does not marginalize any var
16+
17+
# Does not marginalize any var
1718
mmap = MMAPModel(instance; marginalized = Int[], optimizer)
1819
@debug(mmap)
1920
@test maximum_logp(tn_ref) maximum_logp(mmap)
2021

21-
# marginalize all vars
22+
# Marginalize all vars
2223
mmap2 = MMAPModel(instance; marginalized = collect(1:(instance.nvars)), optimizer)
2324
@debug(mmap2)
2425
@test Array(probability(tn_ref))[] exp(maximum_logp(mmap2)[])
2526

26-
# does not optimize over open vertices
27+
# Does not optimize over open vertices
2728
mmap3 = MMAPModel(instance; marginalized = [2, 4, 6], optimizer)
2829
@debug(mmap3)
2930
logp, config = most_probable_config(mmap3)
3031
@test log_probability(mmap3, config) logp
31-
end
32+
33+
end
34+
35+
@testset "UAI Reference Solution Comparison" begin
36+
problems = [
37+
("Segmentation_12", TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40)),
38+
# ("Segmentation_13", TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40)), # fails!
39+
# ("Segmentation_14", TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40)) # fails!
40+
]
41+
for (problem_name, optimizer) in problems
42+
@info "Testing: $problem_name"
43+
model_filepath, evidence_filepath, query_filepath, solution_filepath = get_instance_filepaths(problem_name, "MMAP")
44+
instance = read_instance(model_filepath; evidence_filepath, query_filepath, solution_filepath)
45+
model = MMAPModel(instance; marginalized = setdiff(1:(instance.nvars), instance.queryvars), optimizer)
46+
_, solution = most_probable_config(model)
47+
@test solution == instance.reference_solution
48+
end
49+
end

0 commit comments

Comments
 (0)