Demo: fix tests in complex BP #100


Merged: 2 commits, Jul 2, 2025
13 changes: 10 additions & 3 deletions src/belief.jl
@@ -80,7 +80,7 @@ function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Ve
     # TODO: speed up if needed!
     code = star_code(length(vectors_in))
     cost, gradient = cost_and_gradient(code, (t, vectors_in...))
-    for (o, g) in zip(vectors_out, gradient[2:end])
+    for (o, g) in zip(vectors_out, conj.(gradient[2:end]))
         o .= g
     end
     return cost[]
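
For context on the `conj.` in this hunk: einsum-style contractions do not implicitly conjugate their inputs, so the outgoing BP message (the environment of a tensor) is the plain partial derivative of the contraction value. Reverse-mode AD, however, conventionally returns the conjugated cotangent for complex arrays, so conjugating the returned gradient recovers the environment. A minimal sketch of that convention on a rank-1 contraction; the names here are illustrative, not from the package:

a = randn(ComplexF64, 3)
b = randn(ComplexF64, 3)
C = sum(a .* b)        # einsum-style contraction, no implicit conjugation
# The true environment of b (the holomorphic derivative ∂C/∂b) is a itself,
# but an AD engine following the ChainRules convention hands back its conjugate:
grad_b = conj.(a)      # assumed AD output, not computed by a real AD call
env_b = conj.(grad_b)  # conjugating back recovers the environment
@assert env_b ≈ a
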
@@ -115,7 +115,7 @@ Run the belief propagation algorithm, and return the final state and the informa
 
 ### Keyword Arguments
 - `max_iter::Int=100`: the maximum number of iterations
-- `tol::Float64=1e-6`: the tolerance for the convergence
+- `tol::Float64=1e-6`: the tolerance for convergence, which is checked via the infidelity of messages in consecutive iterations. For complex numbers, converged messages may differ by a global phase factor.
 - `damping::Float64=0.2`: the damping factor for the message update: updated-message = damping * old-message + (1 - damping) * new-message
 """
 function belief_propagate(bp::BeliefPropgation; kwargs...)
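
The damping rule in the docstring is a convex combination of the old and new messages, which stabilizes the fixed-point iteration on loopy graphs. A minimal sketch of the update; `damped_update`, `old`, and `new` are hypothetical names, not part of the package API:

# damping = 0 recovers the undamped update; values near 1 change messages slowly
function damped_update(old::AbstractArray, new::AbstractArray; damping = 0.2)
    return damping .* old .+ (1 - damping) .* new
end
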
@@ -133,14 +133,21 @@ function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::In
         collect_message!(bp, state; normalize = true)
         process_message!(state; normalize = true, damping = damping)
         # check convergence
-        if all(iv -> all(it -> isapprox(state.message_in[iv][it], pre_message_in[iv][it], atol = tol), 1:length(bp.v2t[iv])), 1:num_variables(bp))
+        if all(iv -> all(it -> message_converged(state.message_in[iv][it], pre_message_in[iv][it], atol = tol), 1:length(bp.v2t[iv])), 1:num_variables(bp))
             return BPInfo(true, i)
         end
         pre_message_in = deepcopy(state.message_in)
     end
     return BPInfo(false, max_iter)
 end
 
+# check whether two messages have converged up to a global phase, via fidelity (needed for complex numbers)
+function message_converged(a, b; atol)
+    a_norm = norm(a)
+    b_norm = norm(b)
+    return isapprox(a_norm, b_norm, atol = atol) && isapprox(sqrt(abs(a' * b)), a_norm, atol = atol)
+end
+
 # if BP is exact and converged (e.g. tree-like), the result should be the same as the tensor network contraction
 function contraction_results(state::BPState{T}) where {T}
     return [sum(reduce((x, y) -> x .* y, mi)) for mi in state.message_in]
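
A quick illustration of the phase invariance that `message_converged` provides and a plain elementwise `isapprox` does not (illustrative usage; `norm` and `normalize!` come from LinearAlgebra):

using LinearAlgebra

a = normalize!(randn(ComplexF64, 4))
b = cis(0.7) .* a                     # the same message up to a global phase
isapprox(a, b, atol = 1e-6)           # false: elementwise comparison sees the phase
message_converged(a, b; atol = 1e-6)  # true: the fidelity test ignores a global phase
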
1 change: 1 addition & 0 deletions src/mar.jl
@@ -78,6 +78,7 @@ probabilities of the queried variables, represented by tensors.
 function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dict{Vector{Int}}
     # sometimes the cost can overflow, so we rescale the tensors during contraction
     cost, grads = cost_and_gradient(tn.code, (adapt_tensors(tn; usecuda, rescale)...,))
+    grads = conj.(grads)
     @debug "cost = $cost"
     ixs = OMEinsum.getixsv(tn.code)
     queryvars = ixs[tn.unity_tensors_idx]
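
The rescaling mentioned in the comment above keeps contraction values representable by factoring magnitudes out into log-space. A hypothetical sketch of the idea, not the package's actual rescaled-array implementation:

# Carry a normalized tensor together with its scale in log-space, so that
# long products accumulate in the exponent instead of overflowing.
struct Rescaled{T, N}
    data::Array{T, N}    # normalized so that maximum(abs, data) == 1
    logscale::Float64    # the factored-out magnitude, stored as a log
end

function rescale(x::AbstractArray)
    s = maximum(abs, x)  # assumes x is not all zeros
    return Rescaled(collect(x ./ s), log(s))
end

# Elementwise product of two rescaled tensors: the log-scales add.
function mul(x::Rescaled, y::Rescaled)
    r = rescale(x.data .* y.data)
    return Rescaled(r.data, r.logscale + x.logscale + y.logscale)
end
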
27 changes: 23 additions & 4 deletions test/belief.jl
@@ -46,21 +46,21 @@ end
 @testset "belief propagation" begin
     n = 5
     chi = 3
-    mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi)
+    mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi)
     bp = BeliefPropgation(mps_uai)
     @test TensorInference.initial_state(bp) isa TensorInference.BPState
-    state, info = belief_propagate(bp)
+    state, info = belief_propagate(bp; max_iter=100, tol=1e-8)
     @test info.converged
     @test info.iterations < 20
     mars = marginals(state)
     tnet = TensorNetworkModel(mps_uai)
     mars_tnet = marginals(tnet)
     for v in 1:TensorInference.num_variables(bp)
-        @test mars[[v]] ≈ mars_tnet[[v]] atol=1e-6
+        @test mars[[v]] ≈ mars_tnet[[v]] atol=1e-4
     end
 end
 
-@testset "belief propagation on circle" begin
+@testset "belief propagation on circle (Real)" begin
     n = 10
     chi = 3
     mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi; periodic=true)
@@ -78,6 +78,25 @@ end
     end
 end
 
+
+@testset "belief propagation on circle (Complex)" begin
+    n = 10
+    chi = 3
+    mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi; periodic=true) # FIXME: fails to converge
+    bp = BeliefPropgation(mps_uai)
+    @test TensorInference.initial_state(bp) isa TensorInference.BPState
+    state, info = belief_propagate(bp; max_iter=100, tol=1e-6)
+    @test info.converged
+    @test info.iterations < 100
+    contraction_res = TensorInference.contraction_results(state)
+    tnet = TensorNetworkModel(mps_uai)
+    mars = marginals(state)
+    mars_tnet = marginals(tnet)
+    for v in 1:TensorInference.num_variables(bp)
+        @test TensorInference.message_converged(mars[[v]], mars_tnet[[v]]; atol=1e-4)
+    end
+end
+
 @testset "marginal uai2014" begin
     for problem in [problem_from_artifact("uai2014", "MAR", "Promedus", 14), problem_from_artifact("uai2014", "MAR", "ObjectDetection", 42)]
         optimizer = TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)