16
16
# `CacheTree` stores intermediate `NestedEinsum` contraction results.
17
17
# It is a tree structure that is isomorphic to the contraction tree,
18
18
# `content` is the cached intermediate contraction result.
19
- # `siblings ` are the siblings of current node.
20
- struct CacheTree{T}
19
+ # `children ` are the children of current node, e.g. tensors that are contracted to get `content` .
20
+ mutable struct CacheTree{T}
21
21
content:: AbstractArray{T}
22
- siblings :: Vector{CacheTree{T}}
22
+ const children :: Vector{CacheTree{T}}
23
23
end
24
24
25
25
function cached_einsum (se:: SlicedEinsum , @nospecialize (xs), size_dict)
@@ -62,7 +62,7 @@ function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::Abs
62
62
if OMEinsum. isleaf (code)
63
63
return CacheTree (dy, CacheTree{T}[])
64
64
else
65
- xs = ntuple (i -> cache. siblings [i]. content, length (cache. siblings ))
65
+ xs = ntuple (i -> cache. children [i]. content, length (cache. children ))
66
66
# `einsum_grad` is the back-propagation rule for einsum function.
67
67
# If the forward pass is `y = einsum(EinCode(inputs_labels, output_labels), (A, B, ...), size_dict)`
68
68
# Then the back-propagation pass is
@@ -73,7 +73,7 @@ function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::Abs
73
73
# ```
74
74
# Let `L` be the loss, we will have `y̅ := ∂L/∂y`, `A̅ := ∂L/∂A`...
75
75
dxs = einsum_backward_rule (code. eins, xs, cache. content, size_dict, dy)
76
- return CacheTree (dy, generate_gradient_tree .(code. args, cache. siblings , dxs, Ref (size_dict)))
76
+ return CacheTree (dy, generate_gradient_tree .(code. args, cache. children , dxs, Ref (size_dict)))
77
77
end
78
78
end
79
79
@@ -116,7 +116,7 @@ function extract_leaves!(code, cache, res)
116
116
res[code. tensorindex] = cache. content
117
117
else
118
118
# recurse deeper
119
- extract_leaves! .(code. args, cache. siblings , Ref (res))
119
+ extract_leaves! .(code. args, cache. children , Ref (res))
120
120
end
121
121
return res
122
122
end
@@ -145,10 +145,7 @@ The following example is taken from [`examples/asia-network/main.jl`](https://te
145
145
```jldoctest; setup = :(using TensorInference, Random; Random.seed!(0))
146
146
julia> model = read_model_file(pkgdir(TensorInference, "examples", "asia-network", "model.uai"));
147
147
148
- julia> tn = TensorNetworkModel(model; evidence=Dict(1=>0))
149
- TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
150
- variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
151
- contraction time = 2^6.022, space = 2^2.0, read-write = 2^7.077
148
+ julia> tn = TensorNetworkModel(model; evidence=Dict(1=>0));
152
149
153
150
julia> marginals(tn)
154
151
Dict{Vector{Int64}, Vector{Float64}} with 8 entries:
@@ -161,10 +158,7 @@ Dict{Vector{Int64}, Vector{Float64}} with 8 entries:
161
158
[7] => [0.145092, 0.854908]
162
159
[2] => [0.05, 0.95]
163
160
164
- julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]])
165
- TensorNetworkModel{Int64, DynamicNestedEinsum{Int64}, Array{Float64}}
166
- variables: 1 (evidence → 0), 2, 3, 4, 5, 6, 7, 8
167
- contraction time = 2^7.781, space = 2^5.0, read-write = 2^8.443
161
+ julia> tn2 = TensorNetworkModel(model; evidence=Dict(1=>0), mars=[[2, 3], [3, 4]]);
168
162
169
163
julia> marginals(tn2)
170
164
Dict{Vector{Int64}, Matrix{Float64}} with 2 entries:
0 commit comments