Skip to content

Commit

Permalink
upgrade contraction complexity (#150)
Browse files Browse the repository at this point in the history
* update contraction order

* bump version
  • Loading branch information
GiggleLiu authored Sep 2, 2022
1 parent 3656c8f commit 20ce5fb
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 12 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
*.json
.DS_Store
/docs/build/
/docs/site/
local/
*.swp
Manifest.toml
Expand Down
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
name = "OMEinsum"
uuid = "ebe7aa44-baf0-506c-a96f-8464559b3922"
authors = ["Andreas Peter <[email protected]>"]
version = "0.7.1"
version = "0.7.2"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
BatchedRoutines = "a9ab73d0-e05c-5df1-8fde-d6a4645b8d8e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
Expand All @@ -17,13 +18,13 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"

[compat]
AbstractTrees = "0.4"
AbstractTrees = "0.3, 0.4"
BatchedRoutines = "0.2"
CUDA = "3.10"
ChainRulesCore = "1"
Combinatorics = "1.0"
MacroTools = "0.5"
OMEinsumContractionOrders = "0.7"
OMEinsumContractionOrders = "0.8"
Requires = "0.5, 1"
TupleTools = "1.2, 1.3"
julia = "1"
Expand Down
3 changes: 1 addition & 2 deletions src/OMEinsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ export @ein_str, @ein, ein
export einsum, dynamic_einsum
export EinCode, EinIndexer, EinArray, DynamicEinCode, StaticEinCode, AbstractEinsum, NestedEinsum, SlicedEinsum, DynamicNestedEinsum, StaticNestedEinsum
export getiyv, getixsv, uniquelabels, labeltype
export timespace_complexity, timespacereadwrite_complexity
export flop
export loop_einsum, loop_einsum!, allow_loops
export asarray, asscalar
Expand All @@ -23,7 +22,7 @@ export CodeOptimizer, CodeSimplifier,
uniformsize,
optimize_code, optimize_permute,
# time space complexity
peak_memory, timespace_complexity, timespacereadwrite_complexity, flop,
peak_memory, timespace_complexity, timespacereadwrite_complexity, flop, contraction_complexity,
# file io
writejson, readjson,
label_elimination_order
Expand Down
2 changes: 1 addition & 1 deletion src/contractionorder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end
OMEinsumContractionOrders.optimize_permute(code::AbstractEinsum) = decorate(optimize_permute(rawcode(code)))
OMEinsumContractionOrders.peak_memory(code::AbstractEinsum, size_dict::Dict) = peak_memory(rawcode(code), size_dict)
OMEinsumContractionOrders.flop(code::AbstractEinsum, size_dict::Dict) = flop(rawcode(code), size_dict)
OMEinsumContractionOrders.timespacereadwrite_complexity(code::AbstractEinsum, size_dict) = timespacereadwrite_complexity(rawcode(code), size_dict)
OMEinsumContractionOrders.contraction_complexity(code::AbstractEinsum, size_dict) = contraction_complexity(rawcode(code), size_dict)

OMEinsumContractionOrders.uniformsize(code::AbstractEinsum, size) = Dict([l=>size for l in uniquelabels(code)])
OMEinsumContractionOrders.label_elimination_order(code::AbstractEinsum) = label_elimination_order(rawcode(code))
Expand Down
2 changes: 1 addition & 1 deletion src/cueinsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function einsum(neinsum::NestedEinsum, @nospecialize(xs::NTuple{N,CUDAArrayTypes
# do not use `setindex!` because we need to make the AD work
mxs = Vector{AbstractArray}(undef, length(siblings(neinsum)))
for (i, arg) in enumerate(siblings(neinsum))
mxs = _safe_set(mxs, i, isleaf(arg) ? xs[arg.tensorindex] : einsum(arg, xs, size_dict; active_free=active_free))
mxs = _safe_set(mxs, i, isleaf(arg) ? xs[tensorindex(arg)] : einsum(arg, xs, size_dict; active_free=active_free))
end
res = einsum(rootcode(neinsum), (mxs...,), size_dict)
active_free && for mx in mxs # free CuArray aggressively.
Expand Down
8 changes: 4 additions & 4 deletions test/contractionorder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using Test, Random
size_dict = Dict([c=>(1<<i) for (i,c) in enumerate(['a', 'b', 'c', 'd', 'e', 'f'])]...)
Random.seed!(2)
optcode2 = optimize_code(eincode, size_dict, GreedyMethod())
tc, sc = timespace_complexity(optcode2, edge_sizes)
tc, sc = contraction_complexity(optcode2, edge_sizes)
# test flop
@test tc log2(flop(optcode2, edge_sizes))
@test flop(ein"i->", Dict('i'=>4)) == 4
Expand All @@ -20,7 +20,7 @@ using Test, Random
eincode3 = ein"(ab,acd),bcef,e,df->"
Random.seed!(2)
optcode3 = optimize_code(eincode3, size_dict, GreedyMethod())
tc, sc = timespace_complexity(optcode3, edge_sizes)
tc, sc = contraction_complexity(optcode3, edge_sizes)
@test 16 <= tc <= log2(exp2(10)+exp2(16)+exp2(15)+exp2(9)+1e-8)
end

Expand All @@ -46,11 +46,11 @@ end
size_dict = Dict([i=>2 for i in 1:60])
log2_edge_sizes = Dict([i=>1 for i in 1:60])
edge_sizes = Dict([i=>2 for i in 1:60])
tc, sc = timespace_complexity(code, edge_sizes)
tc, sc = contraction_complexity(code, edge_sizes)
@test tc == 60
@test sc == 0
optcode = optimize_code(code, size_dict, TreeSA(ntrials=1), MergeVectors())
tc2, sc2 = timespace_complexity(optcode, edge_sizes)
tc2, sc2 = contraction_complexity(optcode, edge_sizes)
@test sc2 == 10
xs = vcat([TropicalF64.([-1 1; 1 -1]) for i=1:90], [TropicalF64.([0, 0]) for i=1:60])
@test OMEinsum.flatten(optcode) == code
Expand Down

0 comments on commit 20ce5fb

Please sign in to comment.