Skip to content

Commit 849a28b

Browse files
authored
Merge pull request #96 from TensorBFS/jg/fix-sampling
Fix sampling algorithm
2 parents 9986d3f + 8a250cd commit 849a28b

File tree

10 files changed

+236
-106
lines changed

10 files changed

+236
-106
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
1313
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1414
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
15+
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
1516
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1617
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1718
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
@@ -25,6 +26,7 @@ LinearAlgebra = "1"
2526
OMEinsum = "0.8"
2627
Pkg = "1"
2728
PrecompileTools = "1"
29+
PrettyTables = "2"
2830
Requires = "1"
2931
StatsBase = "0.34"
3032
TropicalNumbers = "0.5.4, 0.6"

docs/src/api/public.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,5 @@ read_td_file
6868
sample
6969
update_evidence!
7070
update_temperature
71+
random_matrix_product_state
7172
```

src/Core.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,8 @@ Get the cardinalities of variables in this tensor network.
204204
"""
205205
function get_cards(tn::TensorNetworkModel; fixedisone = false)::Vector
206206
vars = get_vars(tn)
207-
[fixedisone && haskey(tn.evidence, vars[k]) ? 1 : length(tn.tensors[k]) for k in eachindex(vars)]
207+
size_dict = OMEinsum.get_size_dict(getixsv(tn.code), tn.tensors)
208+
[fixedisone && haskey(tn.evidence, vars[k]) ? 1 : size_dict[vars[k]] for k in eachindex(vars)]
208209
end
209210

210211
chevidence(tn::TensorNetworkModel, evidence) = TensorNetworkModel(tn.vars, tn.code, tn.tensors, evidence)

src/RescaledArray.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ $(TYPEDSIGNATURES)
2323
Returns a rescaled array that is equivalent to the input tensor.
2424
"""
2525
function rescale_array(tensor::AbstractArray{T})::RescaledArray where {T}
26-
maxf = maximum(tensor)
26+
maxf = maximum(abs, tensor)
2727
if iszero(maxf)
2828
@warn("The maximum value of the array to rescale is 0!")
2929
return RescaledArray(zero(T), tensor)
3030
end
31-
return RescaledArray(log(maxf), OMEinsum.asarray(tensor ./ maxf, tensor))
31+
return RescaledArray(T(log(maxf)), OMEinsum.asarray(tensor ./ maxf, tensor))
3232
end
3333

3434
for CT in [:DynamicEinCode, :StaticEinCode]
@@ -46,4 +46,4 @@ end
4646
Base.size(arr::RescaledArray) = size(arr.normalized_value)
4747
Base.size(arr::RescaledArray, i::Int) = size(arr.normalized_value, i)
4848

49-
match_arraytype(::Type{<:RescaledArray{T, N}}, target::AbstractArray{T, N}) where {T, N} = rescale_array(target)
49+
match_arraytype(::Type{<:RescaledArray{T, N, AT}}, target::AbstractArray{T, N}) where {T, N, AT} = rescale_array(match_arraytype(AT, target))

src/TensorInference.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using OMEinsum, LinearAlgebra
1111
using DocStringExtensions, TropicalNumbers
1212
# The Tropical GEMM support
1313
using StatsBase
14+
using PrettyTables
1415
import Pkg
1516

1617
# reexport OMEinsum functions
@@ -34,6 +35,9 @@ export sample
3435
# MMAP
3536
export MMAPModel
3637

38+
# utils
39+
export random_matrix_product_state
40+
3741
include("Core.jl")
3842
include("RescaledArray.jl")
3943
include("utils.jl")

src/mar.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ end
1616
# `CacheTree` stores intermediate `NestedEinsum` contraction results.
1717
# It is a tree structure that isomorphic to the contraction tree,
1818
# `content` is the cached intermediate contraction result.
19-
# `siblings` are the siblings of current node.
20-
struct CacheTree{T}
19+
# `children` are the children of the current node, e.g. tensors that are contracted to get `content`.
20+
mutable struct CacheTree{T}
2121
content::AbstractArray{T}
22-
siblings::Vector{CacheTree{T}}
22+
const children::Vector{CacheTree{T}}
2323
end
2424

2525
function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
@@ -62,7 +62,7 @@ function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::Abs
6262
if OMEinsum.isleaf(code)
6363
return CacheTree(dy, CacheTree{T}[])
6464
else
65-
xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
65+
xs = ntuple(i -> cache.children[i].content, length(cache.children))
6666
# `einsum_grad` is the back-propagation rule for einsum function.
6767
# If the forward pass is `y = einsum(EinCode(inputs_labels, output_labels), (A, B, ...), size_dict)`
6868
# Then the back-propagation pass is
@@ -73,7 +73,7 @@ function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::Abs
7373
# ```
7474
# Let `L` be the loss, we will have `y̅ := ∂L/∂y`, `A̅ := ∂L/∂A`...
7575
dxs = einsum_backward_rule(code.eins, xs, cache.content, size_dict, dy)
76-
return CacheTree(dy, generate_gradient_tree.(code.args, cache.siblings, dxs, Ref(size_dict)))
76+
return CacheTree(dy, generate_gradient_tree.(code.args, cache.children, dxs, Ref(size_dict)))
7777
end
7878
end
7979

@@ -116,7 +116,7 @@ function extract_leaves!(code, cache, res)
116116
res[code.tensorindex] = cache.content
117117
else
118118
# recurse deeper
119-
extract_leaves!.(code.args, cache.siblings, Ref(res))
119+
extract_leaves!.(code.args, cache.children, Ref(res))
120120
end
121121
return res
122122
end
@@ -145,10 +145,7 @@ The following example is taken from [`examples/asia-network/main.jl`](https://te
145145
```jldoctest; setup = :(using TensorInference, Random; Random.seed!(0))
146146
julia> model = read_model_file(pkgdir(TensorInference, "examples", "asia-network", "model.uai"));
147147
148-
julia> tn = TensorNetworkModel(model; evidence=Dict(1=>0))
149-
TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
150-
variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
151-
contraction time = 2^6.022, space = 2^2.0, read-write = 2^7.077
148+
julia> tn = TensorNetworkModel(model; evidence=Dict(1=>0));
152149
153150
julia> marginals(tn)
154151
Dict{Vector{Int64}, Vector{Float64}} with 8 entries:
@@ -161,10 +158,7 @@ Dict{Vector{Int64}, Vector{Float64}} with 8 entries:
161158
[7] => [0.145092, 0.854908]
162159
[2] => [0.05, 0.95]
163160
164-
julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]])
165-
TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
166-
variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
167-
contraction time = 2^7.781, space = 2^5.0, read-write = 2^8.443
161+
julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]]);
168162
169163
julia> marginals(tn2)
170164
Dict{Vector{Int64}, Matrix{Float64}} with 2 entries:

0 commit comments

Comments
 (0)