Skip to content

Commit 356c7c2

Browse files
authored
Merge pull request #78 from TensorBFS/jg/fix-overflow
Fix the overflow issue in probability
2 parents 83ee4e7 + a1100ad commit 356c7c2

File tree

5 files changed

+30
-7
lines changed

5 files changed

+30
-7
lines changed

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ makedocs(;
6767
"Contributing" => "contributing.md",
6868
],
6969
doctest = false,
70+
warnonly = :missing_docs,
7071
)
7172

7273
deploydocs(;

docs/src/api/public.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,6 @@ read_model_file
6666
read_evidence_file
6767
read_td_file
6868
sample
69+
update_evidence!
70+
update_temperature
6971
```

src/Core.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,16 @@ function log_probability(tn::TensorNetworkModel, config::Union{Dict, AbstractVec
203203
assign = config isa AbstractVector ? Dict(zip(get_vars(tn), config)) : config
204204
return sum(x -> log(x[2][(getindex.(Ref(assign), x[1]) .+ 1)...]), zip(getixsv(tn.code), tn.tensors))
205205
end
206+
"""
207+
$(TYPEDSIGNATURES)
208+
209+
Evaluate the log probability (or partition function).
210+
It is the logged version of [`probability`](@ref), which is less likely to overflow.
211+
"""
212+
function log_probability(tn::TensorNetworkModel; usecuda = false)::AbstractArray
213+
res = probability(tn; usecuda, rescale=true)
214+
return asarray(res.log_factor .+ log.(res.normalized_value), res.normalized_value)
215+
end
206216

207217
"""
208218
$(TYPEDSIGNATURES)

src/mar.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ tensor network.
140140
- `rescale`: Specifies whether to rescale the tensors during contraction.
141141
142142
### Example
143-
The following example is taken from [`examples/asia-network/main.jl`](@ref).
143+
The following example is taken from [`examples/asia-network/main.jl`](https://tensorbfs.github.io/TensorInference.jl/dev/generated/asia-network/main/).
144144
145145
```jldoctest; setup = :(using TensorInference, Random; Random.seed!(0))
146146
julia> model = read_model_file(pkgdir(TensorInference, "examples", "asia-network", "model.uai"));

test/pr.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,35 @@ using TensorInference
66
@testset "UAI Reference Solution Comparison" begin
77
problems = dataset_from_artifact("uai2014")["PR"]
88
problem_sets = [
9-
#("Alchemy", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)), # fails
9+
#("Alchemy", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)),
1010
#("CSP", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)),
1111
#("DBN", KaHyParBipartite(sc_target = 25)),
12-
#("Grids", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)), # fails
13-
#("linkage", TreeSA(ntrials = 3, niters = 20, βs = 0.1:0.1:40)), # fails
12+
#("Grids", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)),
13+
#("linkage", TreeSA(ntrials = 3, niters = 20, βs = 0.1:0.1:40)),
1414
#("ObjectDetection", TreeSA(ntrials = 1, niters = 5, βs = 1:0.1:100)),
1515
("Pedigree", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)),
1616
#("Promedus", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)),
17-
#("relational", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)), # fails
17+
#("relational", TreeSA(ntrials=1, niters=5, βs=0.1:0.1:100)), tw too large
1818
("Segmentation", TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100))
1919
]
2020
for (problem_set_name, optimizer) in problem_sets
2121
@testset "$(problem_set_name) problem set" begin
2222
for (id, problem) in problems[problem_set_name]
2323
@info "Testing: $(problem_set_name)_$id"
2424
tn = TensorNetworkModel(read_model(problem); optimizer, evidence=read_evidence(problem))
25-
solution = probability(tn) |> first |> log10
26-
@test isapprox(solution, read_solution(problem); atol = 1e-3)
25+
solution = log_probability(tn) / log(10) |> first
26+
@test isapprox(solution, read_solution(problem); atol = 1e-3, rtol=1e-3)
2727
end
2828
end
2929
end
3030
end
31+
32+
@testset "issue 77" begin
33+
problems = dataset_from_artifact("uai2014")["PR"]
34+
problem_set_name = "Alchemy"
35+
optimizer = TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)
36+
id, problem = problems[problem_set_name] |> first
37+
tn = TensorNetworkModel(read_model(problem); optimizer, evidence=read_evidence(problem))
38+
solution = log_probability(tn) / log(10) |> first
39+
@test isapprox(solution, read_solution(problem); atol=1e-3)
40+
end

0 commit comments

Comments
 (0)