Skip to content

Commit 4edcd8b

Browse files
committed
replace marginalized with queryvars
1 parent 010b415 commit 4edcd8b

File tree

4 files changed

+23
-30
lines changed

4 files changed

+23
-30
lines changed

examples/asia/main.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ logp, cfg = most_probable_config(tnet)
3131

3232
# Get the maximum log-probabilities (MMAP)
3333
# To get the probability of lung cancer, we need to marginalize out other variables.
34-
mmap = MMAPModel(instance; marginalized=[1,2,3,5,6,8])
34+
mmap = MMAPModel(instance; queryvars=[4,7])
3535
# We get the most probable configurations on [4, 7]
3636
most_probable_config(mmap)
3737
# The total probability of having lung cancer is roughly half.

src/mmap.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ Computing the most likely assignment to the query variables, Xₘ ⊆ X after m
1717
```
1818
1919
### Fields
20-
* `vars` is the remaining (or not marginalized) degree of freedoms in the tensor network.
20+
* `vars` is the query variables in the tensor network.
2121
* `code` is the tropical tensor network contraction pattern.
2222
* `tensors` is the tensors fed into the tensor network.
2323
* `clusters` is the clusters, each element of this cluster is a [`TensorNetworkModel`](@ref) instance for marginalizing certain variables.
24-
* `fixedvertices` is a dictionary to specifiy degree of freedoms fixed to certain values, which should not have overlap with the marginalized variables.
24+
* `fixedvertices` is a dictionary to specifiy degree of freedoms fixed to certain values, which should not have overlap with the query variables.
2525
"""
2626
struct MMAPModel{LT, AT <: AbstractArray}
2727
vars::Vector{LT}
@@ -37,7 +37,7 @@ function Base.show(io::IO, mmap::MMAPModel)
3737
tc, sc, rw = contraction_complexity(mmap)
3838
println(io, "$(typeof(mmap))")
3939
println(io, "variables: $variables")
40-
println(io, "marginalized variables: $(map(x->x.eliminated_vars, mmap.clusters))")
40+
println(io, "query variables: $(map(x->x.eliminated_vars, mmap.clusters))")
4141
print_tcscrw(io, tc, sc, rw)
4242
end
4343
Base.show(io::IO, ::MIME"text/plain", mmap::MMAPModel) = Base.show(io, mmap)
@@ -58,24 +58,26 @@ end
5858
"""
5959
$(TYPEDSIGNATURES)
6060
"""
61-
function MMAPModel(instance::UAIInstance; marginalized, openvertices = (), optimizer = GreedyMethod(), simplifier = nothing)::MMAPModel
61+
function MMAPModel(instance::UAIInstance; queryvars, openvertices = (), optimizer = GreedyMethod(), simplifier = nothing)::MMAPModel
6262
return MMAPModel(
63-
1:(instance.nvars), instance.cards, instance.factors; marginalized, fixedvertices = Dict(zip(instance.obsvars, instance.obsvals)), optimizer, simplifier, openvertices
63+
1:(instance.nvars), instance.cards, instance.factors; queryvars, fixedvertices = Dict(zip(instance.obsvars, instance.obsvals)), optimizer, simplifier, openvertices
6464
)
6565
end
6666

6767
"""
6868
$(TYPEDSIGNATURES)
6969
"""
70-
function MMAPModel(vars::AbstractVector{LT}, cards::AbstractVector{Int}, factors::Vector{<:Factor{T}}; marginalized, openvertices = (),
71-
fixedvertices = Dict{LT, Int}(),
72-
optimizer = GreedyMethod(), simplifier = nothing,
73-
marginalize_optimizer = GreedyMethod(), marginalize_simplifier = nothing
74-
)::MMAPModel where {T, LT}
70+
function MMAPModel(vars::AbstractVector{LT}, cards::AbstractVector{Int}, factors::Vector{<:Factor{T}}; queryvars, openvertices = (),
71+
fixedvertices = Dict{LT, Int}(),
72+
optimizer = GreedyMethod(), simplifier = nothing,
73+
marginalize_optimizer = GreedyMethod(), marginalize_simplifier = nothing
74+
)::MMAPModel where {T, LT}
7575
all_ixs = [[[var] for var in vars]..., [[factor.vars...] for factor in factors]...] # labels for vertex tensors (unity tensors) and edge tensors
7676
iy = collect(LT, openvertices)
77-
if !isempty(setdiff(iy, vars))
78-
error("Marginalized variables should not contain any output variable.")
77+
evidencevars = collect(LT, keys(fixedvertices))
78+
marginalized = setdiff(vars, iy queryvars evidencevars)
79+
if !isempty(setdiff(iy, marginalized))
80+
error("Marginalized variables should not contain any output variable, got $(marginalized) and $iy.")
7981
end
8082
all_tensors = [[ones(T, cards[i]) for i in 1:length(vars)]..., getfield.(factors, :vals)...]
8183
size_dict = OMEinsum.get_size_dict(all_ixs, all_tensors)

test/cuda.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,22 @@ end
4242
optimizer = TreeSA(ntrials = 1, niters = 2, βs = 1:0.1:40)
4343
tn_ref = TensorNetworkModel(instance; optimizer)
4444
# does not marginalize any var
45-
tn = MMAPModel(instance; marginalized = Int[], optimizer)
45+
tn = MMAPModel(instance; queryvars = collect(1:instance.nvars), optimizer)
4646
r1, r2 = maximum_logp(tn_ref; usecuda = true), maximum_logp(tn; usecuda = true)
4747
@test r1 isa CuArray
4848
@test r2 isa CuArray
4949
@test r1 r2
5050

5151
# marginalize all vars
52-
tn2 = MMAPModel(instance; marginalized = collect(1:(instance.nvars)), optimizer)
52+
tn2 = MMAPModel(instance; queryvars = Int[], optimizer)
5353
cup = probability(tn_ref; usecuda = true)
5454
culogp = maximum_logp(tn2; usecuda = true)
5555
@test cup isa RescaledArray{T, N, <:CuArray} where {T, N}
5656
@test culogp isa CuArray
5757
@test Array(cup)[] exp(Array(culogp)[])
5858

5959
# does not optimize over open vertices
60-
tn3 = MMAPModel(instance; marginalized = [2, 4, 6], optimizer)
60+
tn3 = MMAPModel(instance; queryvars = setdiff(1:instance.nvars, [2, 4, 6]), optimizer)
6161
logp, config = most_probable_config(tn3; usecuda = true)
6262
@test log_probability(tn3, config) logp
6363
end

test/mmap.jl

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,17 @@ end
1515
tn_ref = TensorNetworkModel(instance; optimizer)
1616

1717
# Does not marginalize any var
18-
mmap = MMAPModel(instance; marginalized = Int[], optimizer)
18+
mmap = MMAPModel(instance; queryvars = collect(1:instance.nvars), optimizer)
1919
@debug(mmap)
2020
@test maximum_logp(tn_ref) maximum_logp(mmap)
2121

2222
# Marginalize all vars
23-
mmap2 = MMAPModel(instance; marginalized = collect(1:(instance.nvars)), optimizer)
23+
mmap2 = MMAPModel(instance; queryvars = Int[], optimizer)
2424
@debug(mmap2)
2525
@test Array(probability(tn_ref))[] exp(maximum_logp(mmap2)[])
2626

2727
# Does not optimize over open vertices
28-
mmap3 = MMAPModel(instance; marginalized = [2, 4, 6], optimizer)
28+
mmap3 = MMAPModel(instance; queryvars = setdiff(1:instance.nvars, [2, 4, 6]), optimizer)
2929
@debug(mmap3)
3030
logp, config = most_probable_config(mmap3)
3131
@test log_probability(mmap3, config) logp
@@ -42,17 +42,8 @@ end
4242
@info "Testing: $problem_name"
4343
model_filepath, evidence_filepath, query_filepath, solution_filepath = get_instance_filepaths(problem_name, "MMAP")
4444
instance = read_instance(model_filepath; evidence_filepath, query_filepath, solution_filepath)
45-
model = MMAPModel(instance; marginalized = setdiff(1:(instance.nvars), instance.queryvars), optimizer)
45+
model = MMAPModel(instance; queryvars = instance.queryvars, optimizer)
4646
_, solution = most_probable_config(model)
4747
@test solution == instance.reference_solution
4848
end
49-
end
50-
51-
using Artifacts
52-
include("utils.jl")
53-
model_filepath, evidence_filepath, query_filepath, solution_filepath = get_instance_filepaths("Segmentation_11", "MMAP")
54-
instance = read_instance(model_filepath; evidence_filepath, query_filepath, solution_filepath)
55-
ref_sol = read_solution_file(solution_filepath)[2:end]
56-
57-
optimizer = TreeSA(ntrials=1, niters=2, βs=1:0.1:40)
58-
mmap = MMAPModel(instance; marginalized=setdiff(1:instance.nvars, instance.queryvars), optimizer)
49+
end

0 commit comments

Comments
 (0)