Skip to content

Commit 80ae44f

Browse files
Demo: fix tests in complex BP (#100)
* Add conj with cost_and_gradient Signed-off-by: 周唤海 <[email protected]> * fix tests --------- Signed-off-by: 周唤海 <[email protected]> Co-authored-by: 周唤海 <[email protected]>
1 parent d914583 commit 80ae44f

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

src/belief.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Ve
8080
# TODO: speed up if needed!
8181
code = star_code(length(vectors_in))
8282
cost, gradient = cost_and_gradient(code, (t, vectors_in...))
83-
for (o, g) in zip(vectors_out, gradient[2:end])
83+
for (o, g) in zip(vectors_out, conj.(gradient[2:end]))
8484
o .= g
8585
end
8686
return cost[]
@@ -115,7 +115,7 @@ Run the belief propagation algorithm, and return the final state and the informa
115115
116116
### Keyword Arguments
117117
- `max_iter::Int=100`: the maximum number of iterations
118-
- `tol::Float64=1e-6`: the tolerance for the convergence
118+
- `tol::Float64=1e-6`: the tolerance for the convergence, the convergence is checked by infidelity of messages in consecutive iterations. For complex numbers, the converged message may be different only by a phase factor.
119119
- `damping::Float64=0.2`: the damping factor for the message update, updated-message = damping * old-message + (1 - damping) * new-message
120120
"""
121121
function belief_propagate(bp::BeliefPropgation; kwargs...)
@@ -133,14 +133,21 @@ function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::In
133133
collect_message!(bp, state; normalize = true)
134134
process_message!(state; normalize = true, damping = damping)
135135
# check convergence
136-
if all(iv -> all(it -> isapprox(state.message_in[iv][it], pre_message_in[iv][it], atol = tol), 1:length(bp.v2t[iv])), 1:num_variables(bp))
136+
if all(iv -> all(it -> message_converged(state.message_in[iv][it], pre_message_in[iv][it], atol = tol), 1:length(bp.v2t[iv])), 1:num_variables(bp))
137137
return BPInfo(true, i)
138138
end
139139
pre_message_in = deepcopy(state.message_in)
140140
end
141141
return BPInfo(false, max_iter)
142142
end
143143

144+
# check if two messages are converged by fidelity (needed for complex numbers)
145+
function message_converged(a, b; atol)
146+
a_norm = norm(a)
147+
b_norm = norm(b)
148+
return isapprox(a_norm, b_norm, atol=atol) && isapprox(sqrt(abs(a' * b)), a_norm, atol=atol)
149+
end
150+
144151
# if BP is exact and converged (e.g. tree like), the result should be the same as the tensor network contraction
145152
function contraction_results(state::BPState{T}) where {T}
146153
return [sum(reduce((x, y) -> x .* y, mi)) for mi in state.message_in]

src/mar.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ probabilities of the queried variables, represented by tensors.
7878
function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dict{Vector{Int}}
7979
# sometimes, the cost can overflow, then we need to rescale the tensors during contraction.
8080
cost, grads = cost_and_gradient(tn.code, (adapt_tensors(tn; usecuda, rescale)...,))
81+
grads = conj.(grads)
8182
@debug "cost = $cost"
8283
ixs = OMEinsum.getixsv(tn.code)
8384
queryvars = ixs[tn.unity_tensors_idx]

test/belief.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,21 @@ end
4646
@testset "belief propagation" begin
4747
n = 5
4848
chi = 3
49-
mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi)
49+
mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi)
5050
bp = BeliefPropgation(mps_uai)
5151
@test TensorInference.initial_state(bp) isa TensorInference.BPState
52-
state, info = belief_propagate(bp)
52+
state, info = belief_propagate(bp; max_iter=100, tol=1e-8)
5353
@test info.converged
5454
@test info.iterations < 20
5555
mars = marginals(state)
5656
tnet = TensorNetworkModel(mps_uai)
5757
mars_tnet = marginals(tnet)
5858
for v in 1:TensorInference.num_variables(bp)
59-
@test mars[[v]] ≈ mars_tnet[[v]] atol=1e-6
59+
@test mars[[v]] ≈ mars_tnet[[v]] atol=1e-4
6060
end
6161
end
6262

63-
@testset "belief propagation on circle" begin
63+
@testset "belief propagation on circle (Real)" begin
6464
n = 10
6565
chi = 3
6666
mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi; periodic=true)
@@ -78,6 +78,25 @@ end
7878
end
7979
end
8080

81+
82+
@testset "belief propagation on circle (Complex)" begin
83+
n = 10
84+
chi = 3
85+
mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi; periodic=true) # FIXME: fail to converge
86+
bp = BeliefPropgation(mps_uai)
87+
@test TensorInference.initial_state(bp) isa TensorInference.BPState
88+
state, info = belief_propagate(bp; max_iter=100, tol=1e-6)
89+
@test info.converged
90+
@test info.iterations < 100
91+
contraction_res = TensorInference.contraction_results(state)
92+
tnet = TensorNetworkModel(mps_uai)
93+
mars = marginals(state)
94+
mars_tnet = marginals(tnet)
95+
for v in 1:TensorInference.num_variables(bp)
96+
@test TensorInference.message_converged(mars[[v]], mars_tnet[[v]]; atol=1e-4)
97+
end
98+
end
99+
81100
@testset "marginal uai2014" begin
82101
for problem in [problem_from_artifact("uai2014", "MAR", "Promedus", 14), problem_from_artifact("uai2014", "MAR", "ObjectDetection", 42)]
83102
optimizer = TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)

0 commit comments

Comments
 (0)