
Implement Belief propagation as a new inference backend #97

Merged (15 commits) on May 12, 2025
8 changes: 2 additions & 6 deletions Project.toml
@@ -4,12 +4,10 @@ authors = ["Jin-Guo Liu", "Martin Roa Villescas"]
version = "0.5.0"

[deps]
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
ProblemReductions = "899c297d-f7d2-4ebf-8815-a35996def416"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -22,15 +20,13 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
TensorInferenceCUDAExt = "CUDA"

[compat]
Artifacts = "1"
CUDA = "4, 5"
DocStringExtensions = "0.8.6, 0.9"
LinearAlgebra = "1"
OMEinsum = "0.8"
OMEinsum = "0.8.7"
Pkg = "1"
PrecompileTools = "1"
PrettyTables = "2"
ProblemReductions = "0.3"
StatsBase = "0.34"
TropicalNumbers = "0.5.4, 0.6"
julia = "1.9"
julia = "1.10"
4 changes: 4 additions & 0 deletions docs/src/api/public.md
@@ -43,6 +43,7 @@ RescaledArray
TensorNetworkModel
ArtifactProblemSpec
UAIModel
BeliefPropgation
```

## Functions
@@ -56,6 +57,7 @@ marginals
maximum_logp
most_probable_config
probability
belief_propagate
dataset_from_artifact
problem_from_artifact
read_model
@@ -69,4 +71,6 @@ sample
update_evidence!
update_temperature
random_matrix_product_state
random_matrix_product_uai
random_tensor_train_uai
```
19 changes: 19 additions & 0 deletions docs/src/tensor-networks.md
@@ -205,6 +205,13 @@ Some of these have been implemented in the
[OMEinsum](https://github.com/under-Peter/OMEinsum.jl) package. Please check
[Performance Tips](@ref) for more details.

## Belief propagation

Belief propagation[^Yedidia2003] is a message-passing algorithm for computing the marginals of a probabilistic graphical model. It has close connections with tensor networks: it can be viewed as a way to gauge a tensor network[^Tindall2023], and it can be combined with tensor network contraction to achieve better performance[^Wang2024].

Belief propagation is an approximate method; the quality of its approximation can be systematically improved with a loop series expansion[^Evenbly2024].
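
The sketch below illustrates the message-passing scheme itself on a pairwise model. It is self-contained and deliberately independent of this package's interface (the exported entry points are `BeliefPropgation` and `belief_propagate`); the names `PairwiseModel`, `belief`, and `loopy_bp` are illustrative only.

```julia
using LinearAlgebra: normalize!

# Illustrative loopy belief propagation on a pairwise model; this is NOT the
# package's implementation. Unary potentials phi[i][xi]; edge potentials
# psi[k][xi, xj] for edges[k] = (i, j).
struct PairwiseModel
    nvars::Int
    edges::Vector{Tuple{Int, Int}}
    psi::Vector{Matrix{Float64}}
    phi::Vector{Vector{Float64}}
end

# Product of the unary potential at variable `i` with all incoming messages,
# excluding the message that travels through edge `skip`.
function belief(m::PairwiseModel, mu, i::Int; skip::Int = 0)
    b = copy(m.phi[i])
    for (k, (s, t)) in enumerate(m.edges)
        k == skip && continue
        t == i && (b .*= mu[(k, 1)])   # message s -> t arrives at i
        s == i && (b .*= mu[(k, 2)])   # message t -> s arrives at i
    end
    return b
end

function loopy_bp(m::PairwiseModel; maxiter::Int = 100, tol::Float64 = 1e-8)
    # mu[(k, 1)] is the message i -> j along edges[k] = (i, j); mu[(k, 2)] is j -> i.
    mu = Dict{Tuple{Int, Int}, Vector{Float64}}()
    for (k, (i, j)) in enumerate(m.edges)
        mu[(k, 1)] = normalize!(ones(length(m.phi[j])), 1)
        mu[(k, 2)] = normalize!(ones(length(m.phi[i])), 1)
    end
    for _ in 1:maxiter
        err = 0.0
        for (k, (i, j)) in enumerate(m.edges)
            # message update: sum over the sender's state, weighted by psi
            new_ij = normalize!(transpose(m.psi[k]) * belief(m, mu, i; skip = k), 1)
            new_ji = normalize!(m.psi[k] * belief(m, mu, j; skip = k), 1)
            err = max(err, maximum(abs, new_ij - mu[(k, 1)]),
                           maximum(abs, new_ji - mu[(k, 2)]))
            mu[(k, 1)], mu[(k, 2)] = new_ij, new_ji
        end
        err < tol && break   # fixed point reached (exact on trees)
    end
    # normalized beliefs approximate the single-site marginals
    return [normalize!(belief(m, mu, i), 1) for i in 1:m.nvars]
end

# Two ferromagnetically coupled binary variables with a bias on the second.
m = PairwiseModel(2, [(1, 2)], [[2.0 1.0; 1.0 2.0]], [[1.0, 1.0], [1.0, 3.0]])
loopy_bp(m)   # exact here, since a single edge contains no loops
```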


## References

[^Orus2014]:
@@ -227,3 +234,15 @@

[^Liu2023]:
Liu J G, Gao X, Cain M, et al. Computing solution space properties of combinatorial optimization problems via generic tensor networks[J]. SIAM Journal on Scientific Computing, 2023, 45(3): A1239-A1270.

[^Yedidia2003]:
Yedidia, J.S., Freeman, W.T., Weiss, Y., 2003. Understanding belief propagation and its generalizations, in: Exploring Artificial Intelligence in the New Millennium. Morgan Kaufmann Publishers Inc., San Francisco, CA, USA, pp. 239–269.

[^Wang2024]:
Wang, Y., Zhang, Y.E., Pan, F., Zhang, P., 2024. Tensor Network Message Passing. Phys. Rev. Lett. 132, 117401. https://doi.org/10.1103/PhysRevLett.132.117401

[^Tindall2023]:
Tindall, J., Fishman, M.T., 2023. Gauging tensor networks with belief propagation. SciPost Phys. 15, 222. https://doi.org/10.21468/SciPostPhys.15.6.222

[^Evenbly2024]:
Evenbly, G., Pancotti, N., Milsted, A., Gray, J., Chan, G.K.-L., 2024. Loop Series Expansions for Tensor Networks. https://doi.org/10.48550/arXiv.2409.03108
2 changes: 1 addition & 1 deletion examples/hard-core-lattice-gas/main.jl
@@ -62,7 +62,7 @@ mars = marginals(pmodel)
show_graph(SimpleGraph(graph), sites; vertex_colors=[(b = mars[[i]][2]; (1-b, 1-b, 1-b)) for i in 1:nv(graph)], texts=fill("", nv(graph)))
# One can see that the sites at the corners are more likely to be occupied.
# To obtain two-site correlations, one can manually specify the variables whose marginal probabilities are queried.
pmodel2 = TensorNetworkModel(problem, β; mars=[[e.src, e.dst] for e in edges(graph)])
pmodel2 = TensorNetworkModel(problem, β; unity_tensors_labels = [[e.src, e.dst] for e in edges(graph)])
mars = marginals(pmodel2);

# We show the probability that both sites on an edge are not occupied
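# A hedged sketch of that query (the remainder of this file is truncated in
# this diff): assuming `mars` maps each pair `[e.src, e.dst]` to a 2×2 table
# whose `[1, 1]` entry is the both-unoccupied configuration, it reads
p00 = [mars[[e.src, e.dst]][1, 1] for e in edges(graph)]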
5 changes: 1 addition & 4 deletions ext/TensorInferenceCUDAExt.jl
@@ -1,17 +1,14 @@
module TensorInferenceCUDAExt
using CUDA: CuArray
import CUDA
import TensorInference: match_arraytype, keep_only!, onehot_like, togpu
import TensorInference: keep_only!, onehot_like, togpu

function onehot_like(A::CuArray, j)
mask = zero(A)
CUDA.@allowscalar mask[j] = one(eltype(mask))
return mask
end

# NOTE: this interface should be in OMEinsum
match_arraytype(::Type{<:CuArray{T, N}}, target::AbstractArray{T, N}) where {T, N} = CuArray(target)

function keep_only!(x::CuArray{T}, j) where T
CUDA.@allowscalar hotvalue = x[j]
fill!(x, zero(T))
101 changes: 19 additions & 82 deletions src/Core.jl
@@ -45,18 +45,18 @@ $(TYPEDEF)
Probabilistic modeling with a tensor network.

### Fields
* `vars` are the degrees of freedom in the tensor network.
* `nvars` is the number of variables in the tensor network.
* `code` is the tensor network contraction pattern.
* `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `mars`.
* `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `unity_tensors_labels`.
* `evidence` is a dictionary used to specify degrees of freedom that are fixed to certain values.
* `mars` is a vector, each element is a vector of variables to compute marginal probabilities.
* `unity_tensors_idx` is a vector of indices of the unity tensors in the `tensors` array. Unity tensors are dummy tensors used to obtain the marginal probabilities.
"""
struct TensorNetworkModel{LT, ET, MT <: AbstractArray}
vars::Vector{LT}
struct TensorNetworkModel{ET, MT <: AbstractArray}
nvars::Int
code::ET
tensors::Vector{MT}
evidence::Dict{LT, Int}
mars::Vector{Vector{LT}}
evidence::Dict{Int, Int}
unity_tensors_idx::Vector{Int}
end

"""
@@ -78,7 +78,7 @@ end

function Base.show(io::IO, tn::TensorNetworkModel)
open = getiyv(tn.code)
variables = join([string_var(var, open, tn.evidence) for var in tn.vars], ", ")
variables = join([string_var(var, open, tn.evidence) for var in get_vars(tn)], ", ")
tc, sc, rw = contraction_complexity(tn)
println(io, "$(typeof(tn))")
println(io, "variables: $variables")
@@ -110,102 +110,42 @@ $(TYPEDSIGNATURES)
* `evidence` is a dictionary of evidence; its values are integers starting from 0.
* `optimizer` is the tensor network contraction order optimizer; please check the package [`OMEinsumContractionOrders.jl`](https://github.com/TensorBFS/OMEinsumContractionOrders.jl) for available algorithms.
* `simplifier` is a strategy for speeding up the `optimizer`; please refer to the same link above.
* `mars` is a list of marginal probabilities. It is all single variables by default, i.e. `[[1], [2], ..., [n]]`. One can also specify multi-variables, which may increase the computational complexity.
* `unity_tensors_labels` is a list of labels for the unity tensors. By default it contains all single variables, i.e. `[[1], [2], ..., [n]]`. One can also specify multi-variable labels, which may increase the computational complexity.
"""
function TensorNetworkModel(
model::UAIModel;
model::UAIModel{ET, FT};
openvars = (),
evidence = Dict{Int,Int}(),
optimizer = GreedyMethod(),
simplifier = nothing,
mars = [[i] for i=1:model.nvars]
)::TensorNetworkModel
return TensorNetworkModel(
1:(model.nvars),
model.cards,
model.factors;
openvars,
evidence,
optimizer,
simplifier,
mars
)
end

"""
$(TYPEDSIGNATURES)
"""
function TensorNetworkModel(
vars::AbstractVector{LT},
cards::AbstractVector{Int},
factors::Vector{<:Factor{T}};
openvars = (),
evidence = Dict{LT, Int}(),
optimizer = GreedyMethod(),
simplifier = nothing,
mars = [[v] for v in vars]
)::TensorNetworkModel where {T, LT}
# The 1st argument of `EinCode` is a vector of vectors of labels specifying the input tensors,
# The 2nd argument of `EinCode` is a vector of labels for specifying the output tensor,
# e.g.
# `EinCode([[1, 2], [2, 3]], [1, 3])` is the EinCode for matrix multiplication.
rawcode = EinCode([mars..., [[factor.vars...] for factor in factors]...], collect(LT, openvars)) # labels for vertex tensors (unity tensors) and edge tensors
tensors = Array{T}[[ones(T, [cards[i] for i in mar]...) for mar in mars]..., [t.vals for t in factors]...]
return TensorNetworkModel(collect(LT, vars), rawcode, tensors; evidence, optimizer, simplifier, mars)
end

"""
$(TYPEDSIGNATURES)
"""
function TensorNetworkModel(
vars::AbstractVector{LT},
rawcode::EinCode,
tensors::Vector{<:AbstractArray};
evidence = Dict{LT, Int}(),
optimizer = GreedyMethod(),
simplifier = nothing,
mars = [[v] for v in vars]
)::TensorNetworkModel where {LT}
unity_tensors_labels = [[i] for i=1:model.nvars]
) where {ET, FT}
# `optimize_code` optimizes the contraction order of a raw tensor network without a contraction order specified.
# The 1st argument is the contraction pattern to be optimized (without contraction order).
# The 2nd argument is the size dictionary, which is a label-integer dictionary.
# The 3rd and 4th arguments are the optimizer and simplifier, which configure the optimization and simplification algorithms.
rawcode = EinCode([unity_tensors_labels..., [[factor.vars...] for factor in model.factors]...], collect(Int, openvars)) # labels for vertex tensors (unity tensors) and edge tensors
tensors = Array{ET}[[ones(ET, [model.cards[i] for i in lb]...) for lb in unity_tensors_labels]..., [t.vals for t in model.factors]...]
size_dict = OMEinsum.get_size_dict(getixsv(rawcode), tensors)
code = optimize_code(rawcode, size_dict, optimizer, simplifier)
TensorNetworkModel(collect(LT, vars), code, tensors, evidence, mars)
end

"""
$(TYPEDSIGNATURES)
"""
function TensorNetworkModel(
model::UAIModel{T}, code;
evidence = Dict{Int,Int}(),
mars = [[i] for i=1:model.nvars],
vars = [1:model.nvars...]
)::TensorNetworkModel where{T}
@debug "constructing tensor network model from code"
tensors = Array{T}[[ones(T, [model.cards[i] for i in mar]...) for mar in mars]..., [t.vals for t in model.factors]...]

return TensorNetworkModel(vars, code, tensors, evidence, mars)
return TensorNetworkModel(model.nvars, code, tensors, evidence, collect(Int, 1:length(unity_tensors_labels)))
end

"""
$(TYPEDSIGNATURES)

Get the variables in this tensor network; they are also known as legs, labels, or degrees of freedom.
"""
get_vars(tn::TensorNetworkModel)::Vector = tn.vars
get_vars(tn::TensorNetworkModel)::Vector = 1:tn.nvars

"""
$(TYPEDSIGNATURES)

Get the cardinalities of variables in this tensor network.
"""
function get_cards(tn::TensorNetworkModel; fixedisone = false)::Vector
vars = get_vars(tn)
size_dict = OMEinsum.get_size_dict(getixsv(tn.code), tn.tensors)
[fixedisone && haskey(tn.evidence, vars[k]) ? 1 : size_dict[vars[k]] for k in eachindex(vars)]
[fixedisone && haskey(tn.evidence, k) ? 1 : size_dict[k] for k in 1:tn.nvars]
end

chevidence(tn::TensorNetworkModel, evidence) = TensorNetworkModel(tn.vars, tn.code, tn.tensors, evidence)
@@ -250,7 +190,4 @@ Returns the contraction complexity of a tensor network model.
"""
function OMEinsum.contraction_complexity(tn::TensorNetworkModel)
return contraction_complexity(tn.code, Dict(zip(get_vars(tn), get_cards(tn; fixedisone = true))))
end

# adapt array type with the target array type
match_arraytype(::Type{<:Array{T, N}}, target::AbstractArray{T, N}) where {T, N} = Array(target)
end
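
The refactor above replaces the `mars` keyword with `unity_tensors_labels` and stores `unity_tensors_idx` instead of a per-variable label list. A hedged usage sketch of the new constructor follows; the positional `Factor` and `UAIModel` constructors are an assumption based on the field accesses in this diff (`nvars`, `cards`, `factors`, `vars`, `vals`), not a documented API.

```julia
using TensorInference

# A toy 3-variable chain model (positional constructors assumed, see above).
cards = [2, 2, 2]
factors = [TensorInference.Factor((1, 2), rand(2, 2)),
           TensorInference.Factor((2, 3), rand(2, 2))]
model = TensorInference.UAIModel(3, cards, factors)

# Place a unity tensor on the variable pair whose joint marginal we want.
tnm = TensorNetworkModel(model; unity_tensors_labels = [[1, 2]])
mars = marginals(tnm)   # expected: a probability table over variables [1, 2]
```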
7 changes: 6 additions & 1 deletion src/RescaledArray.jl
@@ -46,4 +46,9 @@
Base.size(arr::RescaledArray) = size(arr.normalized_value)
Base.size(arr::RescaledArray, i::Int) = size(arr.normalized_value, i)

match_arraytype(::Type{<:RescaledArray{T, N, AT}}, target::AbstractArray{T, N}) where {T, N, AT} = rescale_array(match_arraytype(AT, target))
function OMEinsum.get_output_array(xs::NTuple{N, RescaledArray{T}}, size, fillzero::Bool) where {N, T}
return RescaledArray(zero(T), OMEinsum.get_output_array(getfield.(xs, :normalized_value), size, fillzero))
end
# The following two APIs are required by OMEinsum
Base.fill!(r::RescaledArray, x) = (fill!(r.normalized_value, x ./ exp(r.log_factor)); r)
Base.conj(r::RescaledArray) = RescaledArray(conj(r.log_factor), conj(r.normalized_value))
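
For context on these overloads: a `RescaledArray` represents a tensor as `exp(log_factor) * normalized_value`, keeping the stored entries at order one so that long contraction sequences neither overflow nor underflow; `OMEinsum.get_output_array` must therefore allocate outputs in the same rescaled form, and `fill!` must divide out the log prefactor. A standalone sketch of the representation (illustrative; `MyRescaled` is a stand-in, not the package's definition):

```julia
# Illustrative rescaled-array idea: store x as exp(log_factor) .* normalized_value.
struct MyRescaled{T, N}
    log_factor::T
    normalized_value::Array{T, N}
end

# Factor the largest magnitude out into the logarithmic prefactor
# (assumes a nonzero maximum).
function rescale(x::Array{T}) where T
    m = maximum(abs, x)
    return MyRescaled(log(m), x ./ m)
end

# Recover the represented value.
unrescale(r::MyRescaled) = exp(r.log_factor) .* r.normalized_value
```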
16 changes: 6 additions & 10 deletions src/TensorInference.jl
@@ -8,6 +8,7 @@ $(EXPORTS)
module TensorInference

using OMEinsum, LinearAlgebra
using OMEinsum: CacheTree, cached_einsum
using DocStringExtensions, TropicalNumbers
# The Tropical GEMM support
using StatsBase
@@ -40,8 +41,11 @@ export MMAPModel
# for ProblemReductions
export update_temperature

# belief propagation
export BeliefPropgation, belief_propagate

# utils
export random_matrix_product_state
export random_matrix_product_state, random_tensor_train_uai, random_matrix_product_uai

include("Core.jl")
include("RescaledArray.jl")
@@ -51,14 +55,6 @@ include("map.jl")
include("mmap.jl")
include("sampling.jl")
include("cspmodels.jl")

# import PrecompileTools
# PrecompileTools.@setup_workload begin
# # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
# # precompile file and potentially make loading faster.
# PrecompileTools.@compile_workload begin
# include("../example/asia-network/main.jl")
# end
# end
include("belief.jl")

end # module