diff --git a/Project.toml b/Project.toml index 90c66c5..014b5e0 100644 --- a/Project.toml +++ b/Project.toml @@ -4,12 +4,10 @@ authors = ["Jin-Guo Liu", "Martin Roa Villescas"] version = "0.5.0" [deps] -Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" ProblemReductions = "899c297d-f7d2-4ebf-8815-a35996def416" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -22,15 +20,13 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" TensorInferenceCUDAExt = "CUDA" [compat] -Artifacts = "1" CUDA = "4, 5" DocStringExtensions = "0.8.6, 0.9" LinearAlgebra = "1" -OMEinsum = "0.8" +OMEinsum = "0.8.7" Pkg = "1" -PrecompileTools = "1" PrettyTables = "2" ProblemReductions = "0.3" StatsBase = "0.34" TropicalNumbers = "0.5.4, 0.6" -julia = "1.9" +julia = "1.10" diff --git a/docs/src/api/public.md b/docs/src/api/public.md index 5616b95..ca1e718 100644 --- a/docs/src/api/public.md +++ b/docs/src/api/public.md @@ -43,6 +43,7 @@ RescaledArray TensorNetworkModel ArtifactProblemSpec UAIModel +BeliefPropgation ``` ## Functions @@ -56,6 +57,7 @@ marginals maximum_logp most_probable_config probability +belief_propagate dataset_from_artifact problem_from_artifact read_model @@ -69,4 +71,6 @@ sample update_evidence! update_temperature random_matrix_product_state +random_matrix_product_uai +random_tensor_train_uai ``` diff --git a/docs/src/tensor-networks.md b/docs/src/tensor-networks.md index a86442b..f1f5836 100644 --- a/docs/src/tensor-networks.md +++ b/docs/src/tensor-networks.md @@ -205,6 +205,13 @@ Some of these have been implemented in the [OMEinsum](https://github.com/under-Peter/OMEinsum.jl) package. Please check [Performance Tips](@ref) for more details. +## Belief propagation + +Belief propagation[^Yedidia2003] is a message passing algorithm that can be used to compute the marginals of a probabilistic graphical model. It has close connections with the tensor networks. It can be viewed as a way to gauge the tensor networks[^Tindall2023], and can be combined with tensor networks to achieve better performance[^Wang2024]. + +Belief propagation is an approximate method, and the quality of the approximation can be improved by the loop series expansion[^Evenbly2024]. + + ## References [^Orus2014]: @@ -227,3 +234,15 @@ Some of these have been implemented in the [^Liu2023]: Liu J G, Gao X, Cain M, et al. Computing solution space properties of combinatorial optimization problems via generic tensor networks[J]. SIAM Journal on Scientific Computing, 2023, 45(3): A1239-A1270. + +[^Yedidia2003]: + Yedidia, J.S., Freeman, W.T., Weiss, Y., 2003. Understanding belief propagation and its generalizations, in: Exploring Artificial Intelligence in the New Millennium. Morgan Kaufmann Publishers Inc., San Francisco, CA, USA, pp. 239–269. + +[^Wang2024]: + Wang, Y., Zhang, Y.E., Pan, F., Zhang, P., 2024. Tensor Network Message Passing. Phys. Rev. Lett. 132, 117401. https://doi.org/10.1103/PhysRevLett.132.117401 + +[^Tindall2023]: + Tindall, J., Fishman, M.T., 2023. Gauging tensor networks with belief propagation. SciPost Phys. 15, 222. https://doi.org/10.21468/SciPostPhys.15.6.222 + +[^Evenbly2024]: + Evenbly, G., Pancotti, N., Milsted, A., Gray, J., Chan, G.K.-L., 2024. Loop Series Expansions for Tensor Networks. https://doi.org/10.48550/arXiv.2409.03108 diff --git a/examples/hard-core-lattice-gas/main.jl b/examples/hard-core-lattice-gas/main.jl index 14cc289..0739442 100644 --- a/examples/hard-core-lattice-gas/main.jl +++ b/examples/hard-core-lattice-gas/main.jl @@ -62,7 +62,7 @@ mars = marginals(pmodel) show_graph(SimpleGraph(graph), sites; vertex_colors=[(b = mars[[i]][2]; (1-b, 1-b, 1-b)) for i in 1:nv(graph)], texts=fill("", nv(graph))) # The can see the sites at the corner is more likely to be occupied. # To obtain two-site correlations, one can set the variables to query marginal probabilities manually. -pmodel2 = TensorNetworkModel(problem, β; mars=[[e.src, e.dst] for e in edges(graph)]) +pmodel2 = TensorNetworkModel(problem, β; unity_tensors_labels = [[e.src, e.dst] for e in edges(graph)]) mars = marginals(pmodel2); # We show the probability that both sites on an edge are not occupied diff --git a/ext/TensorInferenceCUDAExt.jl b/ext/TensorInferenceCUDAExt.jl index 40204fc..2f99883 100644 --- a/ext/TensorInferenceCUDAExt.jl +++ b/ext/TensorInferenceCUDAExt.jl @@ -1,7 +1,7 @@ module TensorInferenceCUDAExt using CUDA: CuArray import CUDA -import TensorInference: match_arraytype, keep_only!, onehot_like, togpu +import TensorInference: keep_only!, onehot_like, togpu function onehot_like(A::CuArray, j) mask = zero(A) @@ -9,9 +9,6 @@ function onehot_like(A::CuArray, j) return mask end -# NOTE: this interface should be in OMEinsum -match_arraytype(::Type{<:CuArray{T, N}}, target::AbstractArray{T, N}) where {T, N} = CuArray(target) - function keep_only!(x::CuArray{T}, j) where T CUDA.@allowscalar hotvalue = x[j] fill!(x, zero(T)) diff --git a/src/Core.jl b/src/Core.jl index dcaf26f..2b623e6 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -45,18 +45,18 @@ $(TYPEDEF) Probabilistic modeling with a tensor network. ### Fields -* `vars` are the degrees of freedom in the tensor network. +* `nvars` are the number of variables in the tensor network. * `code` is the tensor network contraction pattern. -* `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `mars`. +* `tensors` are the tensors fed into the tensor network, the leading tensors are unity tensors associated with `unity_tensors_labels`. * `evidence` is a dictionary used to specify degrees of freedom that are fixed to certain values. -* `mars` is a vector, each element is a vector of variables to compute marginal probabilities. +* `unity_tensors_idx` is a vector of indices of the unity tensors in the `tensors` array. Unity tensors are dummy tensors used to obtain the marginal probabilities. """ -struct TensorNetworkModel{LT, ET, MT <: AbstractArray} - vars::Vector{LT} +struct TensorNetworkModel{ET, MT <: AbstractArray} + nvars::Int code::ET tensors::Vector{MT} - evidence::Dict{LT, Int} - mars::Vector{Vector{LT}} + evidence::Dict{Int, Int} + unity_tensors_idx::Vector{Int} end """ @@ -78,7 +78,7 @@ end function Base.show(io::IO, tn::TensorNetworkModel) open = getiyv(tn.code) - variables = join([string_var(var, open, tn.evidence) for var in tn.vars], ", ") + variables = join([string_var(var, open, tn.evidence) for var in get_vars(tn)], ", ") tc, sc, rw = contraction_complexity(tn) println(io, "$(typeof(tn))") println(io, "variables: $variables") @@ -110,84 +110,25 @@ $(TYPEDSIGNATURES) * `evidence` is a dictionary of evidences, the values are integers start counting from 0. * `optimizer` is the tensor network contraction order optimizer, please check the package [`OMEinsumContractionOrders.jl`](https://github.com/TensorBFS/OMEinsumContractionOrders.jl) for available algorithms. * `simplifier` is some strategies for speeding up the `optimizer`, please refer the same link above. -* `mars` is a list of marginal probabilities. It is all single variables by default, i.e. `[[1], [2], ..., [n]]`. One can also specify multi-variables, which may increase the computational complexity. +* `unity_tensors_labels` is a list of labels for the unity tensors. It is all single variables by default, i.e. `[[1], [2], ..., [n]]`. One can also specify multi-variables, which may increase the computational complexity. """ function TensorNetworkModel( - model::UAIModel; + model::UAIModel{ET, FT}; openvars = (), evidence = Dict{Int,Int}(), optimizer = GreedyMethod(), simplifier = nothing, - mars = [[i] for i=1:model.nvars] -)::TensorNetworkModel - return TensorNetworkModel( - 1:(model.nvars), - model.cards, - model.factors; - openvars, - evidence, - optimizer, - simplifier, - mars - ) -end - -""" -$(TYPEDSIGNATURES) -""" -function TensorNetworkModel( - vars::AbstractVector{LT}, - cards::AbstractVector{Int}, - factors::Vector{<:Factor{T}}; - openvars = (), - evidence = Dict{LT, Int}(), - optimizer = GreedyMethod(), - simplifier = nothing, - mars = [[v] for v in vars] -)::TensorNetworkModel where {T, LT} - # The 1st argument of `EinCode` is a vector of vector of labels for specifying the input tensors, - # The 2nd argument of `EinCode` is a vector of labels for specifying the output tensor, - # e.g. - # `EinCode([[1, 2], [2, 3]], [1, 3])` is the EinCode for matrix multiplication. - rawcode = EinCode([mars..., [[factor.vars...] for factor in factors]...], collect(LT, openvars)) # labels for vertex tensors (unity tensors) and edge tensors - tensors = Array{T}[[ones(T, [cards[i] for i in mar]...) for mar in mars]..., [t.vals for t in factors]...] - return TensorNetworkModel(collect(LT, vars), rawcode, tensors; evidence, optimizer, simplifier, mars) -end - -""" -$(TYPEDSIGNATURES) -""" -function TensorNetworkModel( - vars::AbstractVector{LT}, - rawcode::EinCode, - tensors::Vector{<:AbstractArray}; - evidence = Dict{LT, Int}(), - optimizer = GreedyMethod(), - simplifier = nothing, - mars = [[v] for v in vars] -)::TensorNetworkModel where {LT} + unity_tensors_labels = [[i] for i=1:model.nvars] +) where {ET, FT} # `optimize_code` optimizes the contraction order of a raw tensor network without a contraction order specified. # The 1st argument is the contraction pattern to be optimized (without contraction order). # The 2nd arugment is the size dictionary, which is a label-integer dictionary. # The 3rd and 4th arguments are the optimizer and simplifier that configures which algorithm to use and simplify. + rawcode = EinCode([unity_tensors_labels..., [[factor.vars...] for factor in model.factors]...], collect(Int, openvars)) # labels for vertex tensors (unity tensors) and edge tensors + tensors = Array{ET}[[ones(ET, [model.cards[i] for i in lb]...) for lb in unity_tensors_labels]..., [t.vals for t in model.factors]...] size_dict = OMEinsum.get_size_dict(getixsv(rawcode), tensors) code = optimize_code(rawcode, size_dict, optimizer, simplifier) - TensorNetworkModel(collect(LT, vars), code, tensors, evidence, mars) -end - -""" -$(TYPEDSIGNATURES) -""" -function TensorNetworkModel( - model::UAIModel{T}, code; - evidence = Dict{Int,Int}(), - mars = [[i] for i=1:model.nvars], - vars = [1:model.nvars...] -)::TensorNetworkModel where{T} - @debug "constructing tensor network model from code" - tensors = Array{T}[[ones(T, [model.cards[i] for i in mar]...) for mar in mars]..., [t.vals for t in model.factors]...] - - return TensorNetworkModel(vars, code, tensors, evidence, mars) + return TensorNetworkModel(model.nvars, code, tensors, evidence, collect(Int, 1:length(unity_tensors_labels))) end """ @@ -195,17 +136,16 @@ $(TYPEDSIGNATURES) Get the variables in this tensor network, they are also known as legs, labels, or degree of freedoms. """ -get_vars(tn::TensorNetworkModel)::Vector = tn.vars +get_vars(tn::TensorNetworkModel)::Vector = 1:tn.nvars """ $(TYPEDSIGNATURES) -Get the cardinalities of variables in this tensor network. +Get the ardinalities of variables in this tensor network. """ function get_cards(tn::TensorNetworkModel; fixedisone = false)::Vector - vars = get_vars(tn) size_dict = OMEinsum.get_size_dict(getixsv(tn.code), tn.tensors) - [fixedisone && haskey(tn.evidence, vars[k]) ? 1 : size_dict[vars[k]] for k in eachindex(vars)] + [fixedisone && haskey(tn.evidence, k) ? 1 : size_dict[k] for k in 1:tn.nvars] end chevidence(tn::TensorNetworkModel, evidence) = TensorNetworkModel(tn.vars, tn.code, tn.tensors, evidence) @@ -250,7 +190,4 @@ Returns the contraction complexity of a tensor newtork model. """ function OMEinsum.contraction_complexity(tn::TensorNetworkModel) return contraction_complexity(tn.code, Dict(zip(get_vars(tn), get_cards(tn; fixedisone = true)))) -end - -# adapt array type with the target array type -match_arraytype(::Type{<:Array{T, N}}, target::AbstractArray{T, N}) where {T, N} = Array(target) +end \ No newline at end of file diff --git a/src/RescaledArray.jl b/src/RescaledArray.jl index 2e48e7a..8cbebec 100644 --- a/src/RescaledArray.jl +++ b/src/RescaledArray.jl @@ -46,4 +46,9 @@ end Base.size(arr::RescaledArray) = size(arr.normalized_value) Base.size(arr::RescaledArray, i::Int) = size(arr.normalized_value, i) -match_arraytype(::Type{<:RescaledArray{T, N, AT}}, target::AbstractArray{T, N}) where {T, N, AT} = rescale_array(match_arraytype(AT, target)) +function OMEinsum.get_output_array(xs::NTuple{N, RescaledArray{T}}, size, fillzero::Bool) where {N, T} + return RescaledArray(zero(T), OMEinsum.get_output_array(getfield.(xs, :normalized_value), size, fillzero)) +end +# The following two APIs are required by OMEinsum +Base.fill!(r::RescaledArray, x) = (fill!(r.normalized_value, x ./ exp(r.log_factor)); r) +Base.conj(r::RescaledArray) = RescaledArray(conj(r.log_factor), conj(r.normalized_value)) diff --git a/src/TensorInference.jl b/src/TensorInference.jl index 19125ba..a1e7482 100644 --- a/src/TensorInference.jl +++ b/src/TensorInference.jl @@ -8,6 +8,7 @@ $(EXPORTS) module TensorInference using OMEinsum, LinearAlgebra +using OMEinsum: CacheTree, cached_einsum using DocStringExtensions, TropicalNumbers # The Tropical GEMM support using StatsBase @@ -40,8 +41,11 @@ export MMAPModel # for ProblemReductions export update_temperature +# belief propagation +export BeliefPropgation, belief_propagate + # utils -export random_matrix_product_state +export random_matrix_product_state, random_tensor_train_uai, random_matrix_product_uai include("Core.jl") include("RescaledArray.jl") @@ -51,14 +55,6 @@ include("map.jl") include("mmap.jl") include("sampling.jl") include("cspmodels.jl") - -# import PrecompileTools -# PrecompileTools.@setup_workload begin -# # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the -# # precompile file and potentially make loading faster. -# PrecompileTools.@compile_workload begin -# include("../example/asia-network/main.jl") -# end -# end +include("belief.jl") end # module diff --git a/src/belief.jl b/src/belief.jl new file mode 100644 index 0000000..eede996 --- /dev/null +++ b/src/belief.jl @@ -0,0 +1,154 @@ +""" +$TYPEDEF + BeliefPropgation(nvars::Int, t2v::AbstractVector{Vector{Int}}, tensors::AbstractVector{AbstractArray{T}}) where T + +A belief propagation object. + +### Fields +- `t2v::Vector{Vector{Int}}`: a mapping from tensors to variables +- `v2t::Vector{Vector{Int}}`: a mapping from variables to tensors +- `tensors::Vector{AbstractArray{T}}`: the tensors +""" +struct BeliefPropgation{T} + t2v::Vector{Vector{Int}} # a mapping from tensors to variables + v2t::Vector{Vector{Int}} # a mapping from variables to tensors + tensors::Vector{AbstractArray{T}} # the tensors +end +num_tensors(bp::BeliefPropgation) = length(bp.t2v) +ProblemReductions.num_variables(bp::BeliefPropgation) = length(bp.v2t) + +function BeliefPropgation(nvars::Int, t2v::AbstractVector{Vector{Int}}, tensors::AbstractVector{AbstractArray{T}}) where {T} + # initialize the inverse mapping + v2t = [Int[] for _ in 1:nvars] + for (i, edge) in enumerate(t2v) + for v in edge + push!(v2t[v], i) + end + end + return BeliefPropgation(t2v, v2t, tensors) +end + +""" +$(TYPEDSIGNATURES) + +Construct a belief propagation object from a [`UAIModel`](@ref). +""" +function BeliefPropgation(uai::UAIModel{T}) where {T} + return BeliefPropgation(uai.nvars, [collect(Int, f.vars) for f in uai.factors], AbstractArray{T}[f.vals for f in uai.factors]) +end + +struct BPState{T, VT <: AbstractVector{T}} + message_in::Vector{Vector{VT}} # for each variable, we store the incoming messages + message_out::Vector{Vector{VT}} # the outgoing messages +end + +# message_in -> message_out +function process_message!(bp::BPState; normalize, damping) + for (ov, iv) in zip(bp.message_out, bp.message_in) + _process_message!(ov, iv, normalize, damping) + end +end +function _process_message!(ov::Vector, iv::Vector, normalize::Bool, damping) + # process the message, TODO: speed up if needed! + for (i, v) in enumerate(ov) + w = similar(v) + fill!(w, one(eltype(v))) # clear the output vector + for (j, u) in enumerate(iv) + j != i && (w .*= u) + end + normalize && normalize!(w, 1) + v .= v .* damping + (1 - damping) * w + end +end + +function collect_message!(bp::BeliefPropgation, state::BPState; normalize::Bool) + for it in 1:num_tensors(bp) + out = vectors_on_tensor(state.message_in, bp, it) + _collect_message!(out, bp.tensors[it], vectors_on_tensor(state.message_out, bp, it)) + normalize && normalize!.(out, 1) + end +end +# collect the vectors associated with the target tensor +function vectors_on_tensor(messages, bp::BeliefPropgation, it::Int) + return map(bp.t2v[it]) do v + # the message goes to the idx-th tensor from variable v + messages[v][findfirst(==(it), bp.v2t[v])] + end +end +function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Vector) + @assert length(vectors_out) == length(vectors_in) == ndims(t) "dimensions mismatch: $(length(vectors_out)), $(length(vectors_in)), $(ndims(t))" + # TODO: speed up if needed! + code = star_code(length(vectors_in)) + cost, gradient = cost_and_gradient(code, (t, vectors_in...)) + for (o, g) in zip(vectors_out, gradient[2:end]) + o .= g + end + return cost[] +end + +# star code: contract a tensor with multiple vectors, one for each dimension +function star_code(n::Int) + ix1, ixrest = collect(1:n), [[i] for i in 1:n] + ne = DynamicNestedEinsum([DynamicNestedEinsum{Int}(1), DynamicNestedEinsum{Int}(2)], DynamicEinCode([ix1, ixrest[1]], collect(2:n))) + for i in 2:n + ne = DynamicNestedEinsum([ne, DynamicNestedEinsum{Int}(i + 1)], DynamicEinCode([ne.eins.iy, ixrest[i]], collect((i + 1):n))) + end + return ne +end + +function initial_state(bp::BeliefPropgation{T}) where {T} + size_dict = OMEinsum.get_size_dict(bp.t2v, bp.tensors) + edges_vectors = Vector{Vector{T}}[] + for (i, tids) in enumerate(bp.v2t) + push!(edges_vectors, [ones(T, size_dict[i]) for _ in 1:length(tids)]) + end + return BPState(deepcopy(edges_vectors), edges_vectors) +end + +""" +$(TYPEDSIGNATURES) + +Run the belief propagation algorithm, and return the final state and the information about the convergence. + +### Arguments +- `bp::BeliefPropgation`: the belief propagation object + +### Keyword Arguments +- `max_iter::Int=100`: the maximum number of iterations +- `tol::Float64=1e-6`: the tolerance for the convergence +- `damping::Float64=0.2`: the damping factor for the message update, updated-message = damping * old-message + (1 - damping) * new-message +""" +function belief_propagate(bp::BeliefPropgation; kwargs...) + state = initial_state(bp) + info = belief_propagate!(bp, state; kwargs...) + return state, info +end +struct BPInfo + converged::Bool + iterations::Int +end +function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::Int = 100, tol = 1e-6, damping = 0.2) where {T} + pre_message_in = deepcopy(state.message_in) + for i in 1:max_iter + collect_message!(bp, state; normalize = true) + process_message!(state; normalize = true, damping = damping) + # check convergence + if all(iv -> all(it -> isapprox(state.message_in[iv][it], pre_message_in[iv][it], atol = tol), 1:length(bp.v2t[iv])), 1:num_variables(bp)) + return BPInfo(true, i) + end + pre_message_in = deepcopy(state.message_in) + end + return BPInfo(false, max_iter) +end + +# if BP is exact and converged (e.g. tree like), the result should be the same as the tensor network contraction +function contraction_results(state::BPState{T}) where {T} + return [sum(reduce((x, y) -> x .* y, mi)) for mi in state.message_in] +end + +""" +$(TYPEDSIGNATURES) +""" +function marginals(state::BPState{T}) where {T} + return Dict([v] => normalize!(reduce((x, y) -> x .* y, mi), 1) for (v, mi) in enumerate(state.message_in)) +end \ No newline at end of file diff --git a/src/cspmodels.jl b/src/cspmodels.jl index 128f684..dbd7b58 100644 --- a/src/cspmodels.jl +++ b/src/cspmodels.jl @@ -28,10 +28,11 @@ Convert a constraint satisfiability problem (or energy model) to a probabilistic * `mars` is the list of variables to be marginalized. """ function TensorNetworkModel(problem::ConstraintSatisfactionProblem, β::T; evidence::Dict=Dict{Int,Int}(), - optimizer=GreedyMethod(), openvars=Int[], simplifier=nothing, mars=[[l] for l in variables(problem)]) where T <: Real + optimizer=GreedyMethod(), openvars=Int[], simplifier=nothing, unity_tensors_labels = [[l] for l in variables(problem)]) where T <: Real tensors, ixs = generate_tensors(β, problem) factors = [Factor((ix...,), t) for (ix, t) in zip(ixs, tensors)] - return TensorNetworkModel(variables(problem), fill(num_flavors(problem), num_variables(problem)), factors; openvars, evidence, optimizer, simplifier, mars) + model = UAIModel(num_variables(problem), fill(num_flavors(problem), num_variables(problem)), factors) + return TensorNetworkModel(model; openvars, evidence, optimizer, simplifier, unity_tensors_labels) end """ @@ -47,8 +48,9 @@ The program will regenerate tensors from the problem, without repeated optimizin """ function update_temperature(tnet::TensorNetworkModel, problem::ConstraintSatisfactionProblem, β::Real) tensors, ixs = generate_tensors(β, problem) - alltensors = [tnet.tensors[1:length(tnet.mars)]..., tensors...] - return TensorNetworkModel(tnet.vars, tnet.code, alltensors, tnet.evidence, tnet.mars) + @assert tnet.unity_tensors_idx == collect(1:length(tnet.unity_tensors_idx)) "The target tensor network can not be updated! Got `unity_tensors_idx = $(tnet.unity_tensors_idx)`" + alltensors = [tnet.tensors[tnet.unity_tensors_idx]..., tensors...] + return TensorNetworkModel(tnet.nvars, tnet.code, alltensors, tnet.evidence, tnet.unity_tensors_idx) end function MMAPModel(problem::ConstraintSatisfactionProblem, β::Real; diff --git a/src/map.jl b/src/map.jl index 372133f..21a7c3c 100644 --- a/src/map.jl +++ b/src/map.jl @@ -2,7 +2,7 @@ ########### Backward tropical tensor contraction ############## # This part is copied from [`GenericTensorNetworks`](https://github.com/QuEraComputing/GenericTensorNetworks.jl). -function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Tropical}} where {M}, y, size_dict, dy) +function OMEinsum.einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Tropical}} where {M}, y, size_dict, dy) return backward_tropical!(OMEinsum.getixs(eins), xs, OMEinsum.getiy(eins), y, dy, size_dict) end @@ -53,13 +53,18 @@ $(TYPEDSIGNATURES) Returns the largest log-probability and the most probable configuration. """ function most_probable_config(tn::TensorNetworkModel; usecuda = false)::Tuple{Real, Vector} - expected_mars = [[l] for l in get_vars(tn)] - @assert tn.mars[1:length(expected_mars)] == expected_mars "To get the the most probable configuration, the leading elements of `tn.vars` must be `$expected_mars`" - vars = get_vars(tn) + tensor_indices = check_queryvars(tn, [[v] for v in 1:tn.nvars]) tensors = map(t -> Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale = false)) - logp, grads = cost_and_gradient(tn.code, tensors) + logp, grads = cost_and_gradient(tn.code, (tensors...,)) # use Array to convert CuArray to CPU arrays - return content(Array(logp)[]), map(k -> haskey(tn.evidence, vars[k]) ? tn.evidence[vars[k]] : argmax(grads[k]) - 1, 1:length(vars)) + return content(Array(logp)[]), map(k -> haskey(tn.evidence, k) ? tn.evidence[k] : argmax(grads[tensor_indices[k]]) - 1, 1:tn.nvars) +end +# check if the queryvars are included in the unity tensors labels, if yes, return the indices of the unity tensors +function check_queryvars(tn::TensorNetworkModel, queryvars::AbstractVector{Vector{Int}}) + ixs = OMEinsum.getixsv(tn.code) + indices = [findfirst(==(l), ixs[tn.unity_tensors_idx]) for l in queryvars] + @assert !any(isnothing, indices) "To get the the most probable configuration, the unity tensors labels must include all variables. Query variables: $queryvars, Unity tensors labels: $(ixs[tn.unity_tensors_idx])" + return tn.unity_tensors_idx[indices] end """ diff --git a/src/mar.jl b/src/mar.jl index a436d3d..3e399b4 100644 --- a/src/mar.jl +++ b/src/mar.jl @@ -12,115 +12,6 @@ function adapt_tensors(code, tensors, evidence; usecuda, rescale) end end -# ######### Inference by back propagation ############ -# `CacheTree` stores intermediate `NestedEinsum` contraction results. -# It is a tree structure that isomorphic to the contraction tree, -# `content` is the cached intermediate contraction result. -# `children` are the children of current node, e.g. tensors that are contracted to get `content`. -mutable struct CacheTree{T} - content::AbstractArray{T} - const children::Vector{CacheTree{T}} -end - -function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict) - # slicing is not supported yet. - if length(se.slicing) != 0 - @warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`." - end - return cached_einsum(se.eins, xs, size_dict) -end - -# recursively contract and cache a tensor network -function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict) - if OMEinsum.isleaf(code) - # For a leaf node, cache the input tensor - y = xs[code.tensorindex] - return CacheTree(y, CacheTree{eltype(y)}[]) - else - # For a non-leaf node, compute the einsum and cache the contraction result - caches = [cached_einsum(arg, xs, size_dict) for arg in code.args] - # `einsum` evaluates the einsum contraction, - # Its 1st argument is the contraction pattern, - # Its 2nd one is a tuple of input tensors, - # Its 3rd argument is the size dictionary (label as the key, size as the value). - y = einsum(code.eins, ntuple(i -> caches[i].content, length(caches)), size_dict) - return CacheTree(y, caches) - end -end - -# computed gradient tree by back propagation -function generate_gradient_tree(se::SlicedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T} - if length(se.slicing) != 0 - @warn "Slicing is not supported for generating masked tree! Fallback to `NestedEinsum`." - end - return generate_gradient_tree(se.eins, cache, dy, size_dict) -end - -# recursively compute the gradients and store it into a tree. -# also known as the back-propagation algorithm. -function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::AbstractArray{T}, size_dict::Dict) where {T} - if OMEinsum.isleaf(code) - return CacheTree(dy, CacheTree{T}[]) - else - xs = ntuple(i -> cache.children[i].content, length(cache.children)) - # `einsum_grad` is the back-propagation rule for einsum function. - # If the forward pass is `y = einsum(EinCode(inputs_labels, output_labels), (A, B, ...), size_dict)` - # Then the back-propagation pass is - # ``` - # A̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 1) - # B̅ = einsum_grad(inputs_labels, (A, B, ...), output_labels, size_dict, y̅, 2) - # ... - # ``` - # Let `L` be the loss, we will have `y̅ := ∂L/∂y`, `A̅ := ∂L/∂A`... - dxs = einsum_backward_rule(code.eins, xs, cache.content, size_dict, dy) - return CacheTree(dy, generate_gradient_tree.(code.args, cache.children, dxs, Ref(size_dict))) - end -end - -# a unified interface of the backward rules for real numbers and tropical numbers -function einsum_backward_rule(eins, xs::NTuple{M, AbstractArray{<:Real}} where {M}, y, size_dict, dy) - return ntuple(i -> OMEinsum.einsum_grad(OMEinsum.getixs(eins), xs, OMEinsum.getiy(eins), size_dict, dy, i), length(xs)) -end - -# the main function for generating the gradient tree. -function gradient_tree(code, xs) - # infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary. - size_dict = OMEinsum.get_size_dict!(getixsv(code), xs, Dict{Int, Int}()) - # forward compute and cache intermediate results. - cache = cached_einsum(code, xs, size_dict) - # initialize `y̅` as `1`. Note we always start from `L̅ := 1`. - dy = match_arraytype(typeof(cache.content), ones(eltype(cache.content), size(cache.content))) - # back-propagate - return copy(cache.content), generate_gradient_tree(code, cache, dy, size_dict) -end - -# evaluate the cost and the gradient of leaves -function cost_and_gradient(code, xs) - cost, tree = gradient_tree(code, xs) - # extract the gradients on leaves (i.e. the input tensors). - return cost, extract_leaves(code, tree) -end - -# since slicing is not supported, we forward it to NestedEinsum. -extract_leaves(code::SlicedEinsum, cache::CacheTree) = extract_leaves(code.eins, cache) - -# extract gradients on leaf nodes. -function extract_leaves(code::NestedEinsum, cache::CacheTree) - res = Vector{Any}(undef, length(getixsv(code))) - return extract_leaves!(code, cache, res) -end - -function extract_leaves!(code, cache, res) - if OMEinsum.isleaf(code) - # extract - res[code.tensorindex] = cache.content - else - # resurse deeper - extract_leaves!.(code.args, cache.children, Ref(res)) - end - return res -end - """ $(TYPEDSIGNATURES) @@ -130,14 +21,16 @@ are their respective marginals. A marginal is a probability distribution over a subset of variables, obtained by integrating or summing over the remaining variables in the model. By default, the function returns the marginals of all individual variables. To specify which marginal variables to query, set the -`mars` field when constructing a [`TensorNetworkModel`](@ref). Note that +`unity_tensors_labels` field when constructing a [`TensorNetworkModel`](@ref). Note that the choice of marginal variables will affect the contraction order of the tensor network. ### Arguments - `tn`: The [`TensorNetworkModel`](@ref) to query. -- `usecuda`: Specifies whether to use CUDA for tensor contraction. -- `rescale`: Specifies whether to rescale the tensors during contraction. + +### Keyword Arguments +- `usecuda::Bool`: Specifies whether to use CUDA for tensor contraction. +- `rescale::Bool`: Specifies whether to rescale the tensors during contraction. ### Example The following example is taken from [`examples/asia-network/main.jl`](https://tensorbfs.github.io/TensorInference.jl/dev/generated/asia-network/main/). @@ -158,7 +51,7 @@ Dict{Vector{Int64}, Vector{Float64}} with 8 entries: [7] => [0.145092, 0.854908] [2] => [0.05, 0.95] -julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]]); +julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), unity_tensors_labels = [[2, 3], [3, 4]]); julia> marginals(tn2) Dict{Vector{Int64}, Matrix{Float64}} with 2 entries: @@ -184,11 +77,13 @@ probabilities of the queried variables, represented by tensors. """ function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dict{Vector{Int}} # sometimes, the cost can overflow, then we need to rescale the tensors during contraction. - cost, grads = cost_and_gradient(tn.code, adapt_tensors(tn; usecuda, rescale)) + cost, grads = cost_and_gradient(tn.code, (adapt_tensors(tn; usecuda, rescale)...,)) @debug "cost = $cost" + ixs = OMEinsum.getixsv(tn.code) + queryvars = ixs[tn.unity_tensors_idx] if rescale - return Dict(zip(tn.mars, LinearAlgebra.normalize!.(getfield.(grads[1:length(tn.mars)], :normalized_value), 1))) + return Dict(zip(queryvars, LinearAlgebra.normalize!.(getfield.(grads[tn.unity_tensors_idx], :normalized_value), 1))) else - return Dict(zip(tn.mars, LinearAlgebra.normalize!.(grads[1:length(tn.mars)], 1))) + return Dict(zip(queryvars, LinearAlgebra.normalize!.(grads[tn.unity_tensors_idx], 1))) end -end +end \ No newline at end of file diff --git a/src/mmap.jl b/src/mmap.jl index 05d2deb..0d877e4 100644 --- a/src/mmap.jl +++ b/src/mmap.jl @@ -178,7 +178,7 @@ end function most_probable_config(mmap::MMAPModel; usecuda = false)::Tuple{Real, Vector} vars = get_vars(mmap) tensors = map(t -> OMEinsum.asarray(Tropical.(log.(t)), t), adapt_tensors(mmap; usecuda, rescale = false)) - logp, grads = cost_and_gradient(mmap.code, tensors) + logp, grads = cost_and_gradient(mmap.code, (tensors...,)) # use Array to convert CuArray to CPU arrays return content(Array(logp)[]), map(k -> haskey(mmap.evidence, vars[k]) ? mmap.evidence[vars[k]] : argmax(grads[k]) - 1, 1:length(vars)) end diff --git a/src/sampling.jl b/src/sampling.jl index 588ccc3..b94d880 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -134,9 +134,9 @@ function generate_samples!(code::DynamicNestedEinsum, cache::CacheTree{T}, iy_en @assert length(iy_env) == ndims(env) if !(OMEinsum.isleaf(code)) ixs, iy = getixsv(code.eins), getiyv(code.eins) - for (subcode, child, ix) in zip(code.args, cache.children, ixs) + for (subcode, child, ix) in zip(code.args, cache.siblings, ixs) # subenv for the current child, use it to sample and update its cache - siblings = filter(x->x !== child, cache.children) + siblings = filter(x->x !== child, cache.siblings) siblings_ixs = filter(x->x !== ix, ixs) iy_subenv = batch_label ∈ ix ? ix : [ix..., batch_label] envcode = optimize_code(EinCode([siblings_ixs..., iy_env], iy_subenv), size_dict, GreedyMethod(; nrepeat=1)) @@ -184,12 +184,12 @@ end function udpate_cache_tree!(ne::NestedEinsum, cache::CacheTree{T}, el::Pair{<:AbstractVector{L}}, batch_label::L, size_dict::Dict{L}) where {T, L} OMEinsum.isleaf(ne) && return updated = false - for (subcode, child, ix) in zip(ne.args, cache.children, getixsv(ne.eins)) + for (subcode, child, ix) in zip(ne.args, cache.siblings, getixsv(ne.eins)) if any(x->x ∈ el.first, ix) updated = true child.content = _eliminate!(child.content, ix, el, batch_label) udpate_cache_tree!(subcode, child, el, batch_label, size_dict) end end - updated && (cache.content = einsum(ne.eins, (getfield.(cache.children, :content)...,), size_dict)) + updated && (cache.content = einsum(ne.eins, (getfield.(cache.siblings, :content)...,), size_dict)) end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index ce90819..11064b6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -334,6 +334,18 @@ connected in a chain. - `d` is the dimension of the physical indices. """ function random_matrix_product_state(::Type{T}, n::Int, chi::Int, d::Int=2) where T + uai = random_matrix_product_uai(T, n, chi, d) + return TensorNetworkModel(uai; optimizer=GreedyMethod()) +end +random_matrix_product_state(n::Int, chi::Int, d::Int=2) = random_matrix_product_state(ComplexF64, n, chi, d) + +""" +$TYPEDSIGNATURES + +Generate a random UAIModel that represents a matrix product state (MPS). +Similar to [`random_matrix_product_state`](@ref), but returns the UAIModel directly. +""" +function random_matrix_product_uai(::Type{T}, n::Int, chi::Int, d::Int=2) where T # chi ^ (n-1) * (variance^n)^2 == 1/d^n variance = d^(-1/2) * chi^(-1/2+1/2n) tensors = Any[randn(T, d, chi) .* variance] @@ -351,12 +363,41 @@ function random_matrix_product_state(::Type{T}, n::Int, chi::Int, d::Int=2) wher push!(ixs_ket, [virtual_indices_ket[n-1], physical_indices[n]]) push!(ixs_bra, [virtual_indices_bra[n-1], physical_indices[n]]) tensors, ixs = [tensors..., conj.(tensors)...], [ixs_ket..., ixs_bra...] - return TensorNetworkModel( - collect(1:3n-2), - optimize_code(DynamicEinCode(ixs, Int[]), OMEinsum.get_size_dict(ixs, tensors), GreedyMethod()), - tensors, - Dict{Int, Int}(), - Vector{Int}[[i] for i=1:n] + size_dict = OMEinsum.get_size_dict(ixs, tensors) + nvars = 3n-2 + return UAIModel( + nvars, + [size_dict[i] for i=1:nvars], + [Factor((ixs[i]...,), tensors[i]) for i in 1:length(tensors)] ) end -random_matrix_product_state(n::Int, chi::Int, d::Int=2) = random_matrix_product_state(ComplexF64, n, chi, d) + + +""" +$TYPEDSIGNATURES + +Tensor train (TT) is a tensor network model that is widely used in quantum +many-body physics. This model is different from the matrix product state (MPS) +in that it does not have an extra copy for representing the bra state. +""" +function random_tensor_train_uai(::Type{T}, n::Int, chi::Int, d::Int=2; periodic=false) where T + # chi ^ (n-1) * (variance^n)^2 == 1/d^n + variance = d^(-1/2) * chi^(-1/2+1/2n) + physical_indices = collect(1:n) + virtual_indices = collect(n+1:2n) + tensors = Any[(periodic ? rand(T, chi, d, chi) : rand(T, d, chi)) .* variance] + ixs = [periodic ? [virtual_indices[n], physical_indices[1], virtual_indices[1]] : [physical_indices[1], virtual_indices[1]]] + for i = 2:n-1 + push!(tensors, rand(T, chi, d, chi) .* variance) + push!(ixs, [virtual_indices[i-1], physical_indices[i], virtual_indices[i]]) + end + push!(tensors, (periodic ? rand(T, chi, d, chi) : rand(T, chi, d)) .* variance) + push!(ixs, periodic ? [virtual_indices[n-1], physical_indices[n], virtual_indices[n]] : [virtual_indices[n-1], physical_indices[n]]) + size_dict = OMEinsum.get_size_dict(ixs, tensors) + nvars = periodic ? 2n : 2n-1 + return UAIModel( + nvars, + [size_dict[i] for i=1:nvars], + [Factor((ixs[i]...,), tensors[i]) for i in 1:length(tensors)] + ) +end \ No newline at end of file diff --git a/test/belief.jl b/test/belief.jl new file mode 100644 index 0000000..150c302 --- /dev/null +++ b/test/belief.jl @@ -0,0 +1,103 @@ +using TensorInference, Test +using OMEinsum, LinearAlgebra + +@testset "process message" begin + mi = [[1.0, 2, 3], [2.0, 3, 4], [3.0, 4, 5]] + mo_expected = [[6.0, 12, 20], [3.0, 8, 15], [2.0, 6, 12]] + mo = similar.(mi) + TensorInference._process_message!(mo, mi, false, 0) + for i in 1:length(mo) + @test mo[i] ≈ mo_expected[i] atol=1e-8 + end + + TensorInference._process_message!(mo, mi, true, 0) + for i in 1:length(mo) + @test mo[i] ≈ normalize!(mo_expected[i], 1) atol=1e-8 + end +end + +@testset "star code" begin + code = TensorInference.star_code(3) + c1, c2, c3, c4 = [DynamicNestedEinsum{Int}(i) for i in 1:4] + ne1 = DynamicNestedEinsum([c1, c2], DynamicEinCode([[1, 2, 3], [1]], [2, 3])) + ne2 = DynamicNestedEinsum([ne1, c3], DynamicEinCode([[2, 3], [2]], [3])) + ne3 = DynamicNestedEinsum([ne2, c4], DynamicEinCode([[3], [3]], Int[])) + @test code == ne3 + t = randn(2, 2, 2) + v1 = randn(2) + v2 = randn(2) + v3 = randn(2) + vectors_out = [similar(v1), similar(v2), similar(v3)] + TensorInference._collect_message!(vectors_out, t, [v1, v2, v3]) + @test vectors_out[1] ≈ reshape(t, 2, 4) * kron(v3, v2) # NOTE: v3 is the little end + @test vectors_out[2] ≈ vec(v1' * reshape(reshape(t, 4, 2) * v3, 2, 2)) + @test vectors_out[3] ≈ vec(kron(v2, v1)' * reshape(t, 4, 2)) +end + +@testset "constructor" begin + problem = problem_from_artifact("uai2014", "MAR", "Promedus", 14) + uai = read_model(problem) + bp = BeliefPropgation(uai) + @test length(bp.v2t) == 414 + @test TensorInference.num_tensors(bp) == 414 + @test TensorInference.num_variables(bp) == length(unique(vcat([collect(Int, f.vars) for f in uai.factors]...))) +end + +@testset "belief propagation" begin + n = 5 + chi = 3 + mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi) + bp = BeliefPropgation(mps_uai) + @test TensorInference.initial_state(bp) isa TensorInference.BPState + state, info = belief_propagate(bp) + @test info.converged + @test info.iterations < 20 + mars = marginals(state) + tnet = TensorNetworkModel(mps_uai) + mars_tnet = marginals(tnet) + for v in 1:TensorInference.num_variables(bp) + @test mars[[v]] ≈ mars_tnet[[v]] atol=1e-6 + end +end + +@testset "belief propagation on circle" begin + n = 10 + chi = 3 + mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi; periodic=true) + bp = BeliefPropgation(mps_uai) + @test TensorInference.initial_state(bp) isa TensorInference.BPState + state, info = belief_propagate(bp; max_iter=100, tol=1e-6) + @test info.converged + @test info.iterations < 100 + contraction_res = TensorInference.contraction_results(state) + tnet = TensorNetworkModel(mps_uai) + mars = marginals(state) + mars_tnet = marginals(tnet) + for v in 1:TensorInference.num_variables(bp) + @test mars[[v]] ≈ mars_tnet[[v]] atol=1e-4 + end +end + +@testset "marginal uai2014" begin + for problem in [problem_from_artifact("uai2014", "MAR", "Promedus", 14), problem_from_artifact("uai2014", "MAR", "ObjectDetection", 42)] + optimizer = TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100) + evidence = Dict{Int, Int}() + model = read_model(problem) + + tn = TensorNetworkModel(model; optimizer, evidence) + mars_tnet = marginals(tn) + + code = tn.code.eins + tensors = tn.tensors + size_dict = Dict(i => d for (i, d) in enumerate(model.cards)) + + bp = BeliefPropgation(model) + state, info = belief_propagate(bp; max_iter=300, tol=1e-6) + @test info.converged + mars = marginals(state) + + for v in 1:TensorInference.num_variables(bp) + @test mars[[v]] ≈ mars_tnet[[v]] atol=1e-2 + end + end +end \ No newline at end of file diff --git a/test/cspmodels.jl b/test/cspmodels.jl index c7b559b..cc26457 100644 --- a/test/cspmodels.jl +++ b/test/cspmodels.jl @@ -7,7 +7,7 @@ using GenericTensorNetworks β = 2.0 g = GenericTensorNetworks.Graphs.smallgraph(:petersen) problem = IndependentSet(g) - model = TensorNetworkModel(problem, β; mars=[[2, 3]]) + model = TensorNetworkModel(problem, β; unity_tensors_labels = [[2, 3]]) mars = marginals(model)[[2, 3]] problem2 = IndependentSet(g) mars2 = TensorInference.normalize!(GenericTensorNetworks.solve(GenericTensorNetwork(problem2; openvertices=[2, 3]), PartitionFunction(β)), 1) @@ -28,7 +28,7 @@ using GenericTensorNetworks β = 1.0 problem = SpinGlass(g, -ones(Int, ne(g)), zeros(Int, nv(g))) - model = TensorNetworkModel(problem, β; mars=[[2, 3]]) + model = TensorNetworkModel(problem, β; unity_tensors_labels = [[2, 3]]) samples = sample(model, 100) @test sum(energy.(Ref(problem), samples))/100 <= -14 end \ No newline at end of file diff --git a/test/map.jl b/test/map.jl index abc0c98..c5c95e0 100644 --- a/test/map.jl +++ b/test/map.jl @@ -2,16 +2,14 @@ using Test using OMEinsum using TensorInference -@testset "load from code" begin +@testset "load from model" begin model = problem_from_artifact("uai2014", "MAR", "Promedus", 14) tn1 = TensorNetworkModel(read_model(model); evidence=read_evidence(model), optimizer = TreeSA(ntrials = 3, niters = 2, βs = 1:0.1:80)) - tn2 = TensorNetworkModel(read_model(model), tn1.code, evidence=read_evidence(model)) - - @test tn1.code == tn2.code + @test tn1 isa TensorNetworkModel end @testset "gradient-based tensor network solvers" begin @@ -22,8 +20,7 @@ end evidence=read_evidence(model), optimizer = TreeSA(ntrials = 3, niters = 2, βs = 1:0.1:80)) @debug contraction_complexity(tn) - most_probable_config(tn) - @time logp, config = most_probable_config(tn) + logp, config = most_probable_config(tn) @test log_probability(tn, config) ≈ logp @test maximum_logp(tn)[] ≈ logp end diff --git a/test/mar.jl b/test/mar.jl index 7aa10b4..01843eb 100644 --- a/test/mar.jl +++ b/test/mar.jl @@ -1,6 +1,5 @@ using Test using OMEinsum -using KaHyPar using TensorInference @testset "composite number" begin @@ -9,6 +8,13 @@ using TensorInference op = ein"ij, j -> i" @test Array(x) ≈ exp(2.0) .* [2.0, 3.0] @test op(Array(A), Array(x)) ≈ Array(op(A, x)) + + @test OMEinsum.get_output_array((A,), (2,), true) ≈ RescaledArray(0.0, [0.0, 0.0]) + @test fill!(RescaledArray(0.0, [0.0, 0.0]), 5.0) ≈ [5.0, 5.0] + + C = RescaledArray(2.0 + 1im, [2.0im 3.0; 5.0 6.0]) + @test conj(C) isa RescaledArray + @test conj(C) ≈ RescaledArray(2.0 - 1im, [-2.0im 3.0; 5.0 6.0]) end @testset "cached, rescaled contract" begin @@ -24,12 +30,12 @@ end # cached contract xs = TensorInference.adapt_tensors(tn; usecuda = false, rescale = true) size_dict = OMEinsum.get_size_dict!(getixsv(tn.code), xs, Dict{Int, Int}()) - cache = TensorInference.cached_einsum(tn.code, xs, size_dict) + cache = OMEinsum.cached_einsum(tn.code, xs, size_dict) @test cache.content isa RescaledArray @test Array(cache.content) ≈ p1 # compute marginals - ti_sol = marginals(tn) + ti_sol = marginals(tn; rescale = true) ref_sol[collect(keys(evidence))] .= fill([1.0], length(evidence)) # imitate dummy vars @test isapprox([ti_sol[[i]] for i=1:length(ref_sol)], ref_sol; atol = 1e-5) end @@ -116,7 +122,7 @@ end 0.1 0.3 0.2 0.9 """) n = 10000 - tnet = TensorNetworkModel(model; mars=[[2, 3], [3, 4]]) + tnet = TensorNetworkModel(model; unity_tensors_labels = [[2, 3], [3, 4]]) mars = marginals(tnet) tnet23 = TensorNetworkModel(model; openvars=[2,3]) tnet34 = TensorNetworkModel(model; openvars=[3,4]) @@ -124,8 +130,8 @@ end @test mars[[3, 4]] ≈ probability(tnet34) vars = [[2, 4], [3, 5]] - tnet1 = TensorNetworkModel(model; mars=vars, evidence=Dict(3=>1)) - tnet2 = TensorNetworkModel(model; mars=vars, evidence=Dict(3=>0)) + tnet1 = TensorNetworkModel(model; unity_tensors_labels = vars, evidence=Dict(3=>1)) + tnet2 = TensorNetworkModel(model; unity_tensors_labels = vars, evidence=Dict(3=>0)) mars1 = marginals(tnet1) mars2 = marginals(tnet2) update_evidence!(tnet1, Dict(3=>0)) diff --git a/test/pr.jl b/test/pr.jl index 53c1c1e..f6c0c08 100644 --- a/test/pr.jl +++ b/test/pr.jl @@ -1,6 +1,5 @@ using Test using OMEinsum -using KaHyPar using TensorInference @testset "UAI Reference Solution Comparison" begin diff --git a/test/runtests.jl b/test/runtests.jl index 6bd7da2..c32af7a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,6 +24,14 @@ end include("cspmodels.jl") end +@testset "utils" begin + include("utils.jl") +end + +@testset "belief propagation" begin + include("belief.jl") +end + using CUDA if CUDA.functional() include("cuda.jl") diff --git a/test/sampling.jl b/test/sampling.jl index 0f06320..d95b8a2 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -70,10 +70,11 @@ end Random.seed!(140) mps = random_matrix_product_state(n, chi) num_samples = 10000 + ixs = OMEinsum.getixsv(mps.code) samples = map(1:num_samples) do i - sample(mps, 1; queryvars=vcat(mps.mars...)).samples[:,1] + sample(mps, 1; queryvars=collect(1:n)).samples[:,1] end - samples = sample(mps, num_samples; queryvars=vcat(mps.mars...)) + samples = sample(mps, num_samples; queryvars=collect(1:n)) indices = map(samples) do sample sum(i->sample[i] * 2^(i-1), 1:n) + 1 end diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 0000000..18958c0 --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,15 @@ +using TensorInference, Test + +@testset "tensor train" begin + tt = random_tensor_train_uai(Float64, 5, 3) + @test tt.nvars == length(unique(vcat([collect(Int, f.vars) for f in tt.factors]...))) + + tt = random_tensor_train_uai(Float64, 5, 3; periodic=true) + @test tt.nvars == length(unique(vcat([collect(Int, f.vars) for f in tt.factors]...))) +end + +@testset "mps" begin + tt = random_matrix_product_uai(Float64, 5, 3) + @test tt.nvars == length(unique(vcat([collect(Int, f.vars) for f in tt.factors]...))) +end +