Skip to content

Commit

Permalink
Faster greedy (#107)
Browse files Browse the repository at this point in the history
* boost greedy

* update

* rm asarray type restriction

* fix asarray

* fix tests
  • Loading branch information
GiggleLiu authored May 27, 2021
1 parent 3aa20d2 commit 30b2de9
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 43 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OMEinsum"
uuid = "ebe7aa44-baf0-506c-a96f-8464559b3922"
authors = ["Andreas Peter <[email protected]>"]
version = "0.4.1"
version = "0.4.2"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand All @@ -17,6 +17,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"

[compat]
AbstractTrees = "0.3"
BatchedRoutines = "0.2"
CUDA = "2, 3.1"
ChainRulesCore = "0.9"
Expand All @@ -25,7 +26,6 @@ MacroTools = "0.5"
PkgBenchmark = "0.2"
Requires = "0.5, 1"
TupleTools = "1.2"
AbstractTrees = "0.3"
julia = "1"

[extras]
Expand Down
16 changes: 16 additions & 0 deletions benchmark/greedy_order.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using OMEinsum, LightGraphs, BenchmarkTools

# Build a size dictionary that maps every label of `code` (all labels of the inputs
# `ixs` followed by the output labels `iy`) to the same `size`.
function uniformsize(@nospecialize(code::EinCode{ixs,iy}), size::Int) where {ixs, iy}
    all_labels = [Iterators.flatten(ixs)..., iy...]
    label_sizes = [label => size for label in all_labels]
    return Dict(label_sizes)
end
# Nested codes are flattened first, then handled like a plain code.
function uniformsize(ne::OMEinsum.NestedEinsum, size::Int)
    flat_code = Iterators.flatten(ne)
    return uniformsize(flat_code, size)
end

# Build an `EinCode` from a random regular graph created with
# `LightGraphs.random_regular_graph(n, k)`: one two-label input per edge (labels are
# the sorted end points), one single-label input per vertex, and an empty output.
function random_regular_eincode(n, k)
    graph = LightGraphs.random_regular_graph(n, k)
    edge_labels = [minmax(edge.src, edge.dst) for edge in LightGraphs.edges(graph)]
    vertex_labels = [(v,) for v in LightGraphs.vertices(graph)]
    return EinCode((edge_labels..., vertex_labels...), ())
end

# Benchmark setup: a random 3-regular graph with `number_of_nodes` vertices, every label
# given size 2 via `uniformsize`.
number_of_nodes = 200
code = random_regular_eincode(number_of_nodes, 3)
# NOTE: despite its name, `optcode` holds the result of `@benchmark`, not an optimized code.
optcode = @benchmark optimize_greedy($code, $(uniformsize(code, 2)); nrepeat=10, method=OMEinsum.MinSpaceOut())
63 changes: 42 additions & 21 deletions src/contractionorder/contractionorder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,17 @@ function parse_tree(ein, vertices)
end
end

function optimize_greedy(code::EinCode{ixs, iy}, size_dict; method=MinSpaceOut(), nrepeat=10) where {ixs, iy}
"""
    optimize_greedy(eincode, size_dict; method=MinSpaceOut(), nrepeat=10)

Greedily optimize the contraction order of `eincode`. The input labels and output labels
of `eincode` are collected into vectors and handed, together with `size_dict`, to the
vector-based method.

# Keywords
* `method`: the greedy strategy, one of
    * `MinSpaceOut`, always choose the next contraction that produces the minimum output tensor.
    * `MinSpaceDiff`, always choose the next contraction that minimizes the total space.
* `nrepeat`: forwarded unchanged to the vector-based method.
"""
function optimize_greedy(@nospecialize(code::EinCode{ixs, iy}), size_dict; method=MinSpaceOut(), nrepeat=10) where {ixs, iy}
    # Forward the caller's `method`. Passing a fresh `MinSpaceOut()` here would silently
    # discard any other strategy the caller asked for.
    return optimize_greedy(collect(ixs), collect(iy), size_dict; method=method, nrepeat=nrepeat)
end
function optimize_greedy(ixs::AbstractVector, iy::AbstractVector, size_dict; method=MinSpaceOut(), nrepeat=10)
if length(ixs) < 2
return code
end
Expand All @@ -62,16 +72,12 @@ function optimize_greedy(code::EinCode{ixs, iy}, size_dict; method=MinSpaceOut()
tree, _, _ = tree_greedy(incidence_list, log2_edge_sizes; method=method, nrepeat=nrepeat)
parse_eincode!(incidence_list, tree, 1:length(ixs))[2]
end

optimize_greedy(code::Int, size_dict; method=MinSpaceOut()) = code

function optimize_greedy(code::NestedEinsum, size_dict; method=MinSpaceOut())
args = optimize_greedy.(code.args, Ref(size_dict); method=method)
# An integer `code` is returned as is; the keywords are accepted but not used.
function optimize_greedy(code::Int, size_dict; method=MinSpaceOut(), nrepeat=10)
    return code
end
function optimize_greedy(code::NestedEinsum, size_dict; method=MinSpaceOut(), nrepeat=10)
args = optimize_greedy.(code.args, Ref(size_dict); method=method, nrepeat=nrepeat)
if length(code.args) > 2
# generate coarse grained hypergraph.
#hyper_incidence_list = code.eins # TODO
#tree_greedy(hyper_incidence_list, log2_edge_sizes; method=method)
nested = optimize_greedy(code.eins, size_dict; method=method)
nested = optimize_greedy(code.eins, size_dict; method=method, nrepeat=nrepeat)
replace_args(nested, args)
else
NestedEinsum(args, code.eins)
Expand All @@ -83,22 +89,37 @@ function replace_args(nested::NestedEinsum, trueargs)
end
replace_args(nested::Int, trueargs) = trueargs[nested]

ContractionOrder.timespace_complexity(ei::Int, log2_edge_sizes) = -Inf, -Inf
function ContractionOrder.timespace_complexity(ei::NestedEinsum, log2_edge_sizes)
tcscs = timespace_complexity.(ei.args, Ref(log2_edge_sizes))
tc2, sc2 = timespace_complexity(ei.eins, log2_edge_sizes)
tc = ContractionOrder.log2sumexp2([getindex.(tcscs, 1)..., tc2])
sc = max(reduce(max, getindex.(tcscs, 2)), sc2)
"""
    ContractionOrder.timespace_complexity(eincode, size_dict)

Compute the time and space complexity of the einsum contraction `eincode`.

* The time complexity is defined as `log2(number of element multiplication)`.
* The space complexity is defined as `log2(size of the maximum intermediate tensor)`.

For an integer `eincode`, both values are `-Inf`.
"""
function ContractionOrder.timespace_complexity(ei::Int, size_dict)
    return -Inf, -Inf
end
# Complexity of a nested contraction: evaluate every argument recursively, evaluate the
# top-level code `ei.eins`, then combine the time complexities with `log2sumexp2` and
# take the largest space complexity.
function ContractionOrder.timespace_complexity(ei::NestedEinsum, size_dict)
    nargs = length(ei.args)
    # the extra slot at the end receives the time complexity of `ei.eins` itself
    time_costs = Vector{Float64}(undef, nargs + 1)
    space_costs = Vector{Float64}(undef, nargs)
    for (k, arg) in enumerate(ei.args)
        time_costs[k], space_costs[k] = timespace_complexity(arg, size_dict)
    end
    own_tc, own_sc = timespace_complexity(collect(getixs(ei.eins)), collect(getiy(ei.eins)), size_dict)
    time_costs[end] = own_tc
    total_tc = ContractionOrder.log2sumexp2(time_costs)
    total_sc = max(reduce(max, space_costs), own_sc)
    return total_tc, total_sc
end

function ContractionOrder.timespace_complexity(ei::EinCode{ixs, iy}, log2_edge_sizes) where {ixs, iy}
# `EinCode` entry point: collect the input and output labels stored in the type
# parameters into vectors and delegate to the vector-based method.
function ContractionOrder.timespace_complexity(@nospecialize(ei::EinCode{ixs, iy}), size_dict) where {ixs, iy}
    input_labels = collect(ixs)
    output_labels = collect(iy)
    return ContractionOrder.timespace_complexity(input_labels, output_labels, size_dict)
end
# Time and space complexity of a single contraction described by label vectors.
#
# `ixs` holds one label collection per input, `iy` holds the output labels, and
# `size_dict[l]` is the size of label `l`.
#
# Returns `(tc, sc)`:
# * `tc` is the sum of `log2(size_dict[l])` over the labels that occur at least twice
#   among all inputs and the output, or `-Inf` when there is no such label;
# * `sc` is the sum of `log2(size_dict[l])` over the output labels, or `0.0` when `iy`
#   is empty.
function ContractionOrder.timespace_complexity(ixs::AbstractVector, iy::AbstractVector, size_dict)
    labels = vcat(collect.(ixs)..., iy)
    # keep each label that appears two or more times, once
    loop_inds = unique!(filter(l -> count(==(l), labels) >= 2, labels))

    # `size_dict` stores plain sizes, so the logarithm is taken here
    tc = isempty(loop_inds) ? -Inf : sum(l -> log2(size_dict[l]), loop_inds)
    sc = isempty(iy) ? 0.0 : sum(l -> log2(size_dict[l]), iy)
    return tc, sc
end

Expand All @@ -114,4 +135,4 @@ _flatten(i::Int, iy) = [i=>iy]
# Flatten a nested code into a single `EinCode`: `_flatten(code)` yields index => labels
# pairs, which are placed in index order as the inputs; the output labels are those of
# the top-level code.
function Base.Iterators.flatten(code::NestedEinsum)
    labels_by_index = Dict(_flatten(code))
    input_labels = ntuple(k -> labels_by_index[k], length(labels_by_index))
    return EinCode(input_labels, OMEinsum.getiy(code.eins))
end
end
29 changes: 23 additions & 6 deletions src/contractionorder/greedy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ function _tree_greedy(incidence_list::IncidenceList{VT,ET}, log2_edge_sizes; met
log2_scs = Float64[]

tree = Dict{VT,Any}([v=>v for v in vertices(incidence_list)])
cost_values = evaluate_costs(method, incidence_list, log2_edge_sizes)
while true
cost_values = evaluate_costs(method, incidence_list, log2_edge_sizes)
if length(cost_values) == 0
vpool = collect(vertices(incidence_list))
pair = minmax(vpool[1], vpool[2]) # to prevent empty intersect
Expand All @@ -89,6 +89,7 @@ function _tree_greedy(incidence_list::IncidenceList{VT,ET}, log2_edge_sizes; met
else
return ContractionTree(tree[pair[1]], tree[pair[2]]), log2_tcs, log2_scs
end
update_costs!(cost_values, pair..., method, incidence_list, log2_edge_sizes)
end
end

Expand Down Expand Up @@ -128,12 +129,28 @@ function evaluate_costs(method, incidence_list::IncidenceList{VT,ET}, log2_edge_
return cost_values
end

function find_best_cost(cost_values)
# Update `cost_values` in place for the vertex pair `(va, vb)`:
# * for every neighbor `vj` of `va` in `incidence_list`, (re)compute the greedy loss of
#   the ordered pair `minmax(vj, va)` and store it;
# * afterwards remove every entry whose key contains `vb`.
# NOTE(review): presumably the caller has already merged `vb` into `va`, which is why
# pairs mentioning `vb` are dropped - confirm against the call site.
# Returns `nothing`.
function update_costs!(cost_values, va, vb, method, incidence_list::IncidenceList{VT,ET}, log2_edge_sizes) where {VT,ET}
    for vj in neighbors(incidence_list, va)
        vx, vy = minmax(vj, va)
        cost_values[(vx, vy)] = greedy_loss(method, incidence_list, log2_edge_sizes, vx, vy)
    end
    # `filter!` removes the stale pairs without deleting from the dictionary while
    # iterating over its keys.
    filter!(entry -> !(vb in first(entry)), cost_values)
    return nothing
end

# Return a key of `cost_values` whose value is minimal. When several keys share the
# minimum, one of them is picked at random with `rand`.
# Throws an error when `cost_values` is empty.
function find_best_cost(cost_values::Dict{PT}) where PT
    isempty(cost_values) && error("cost value information missing")
    minval = minimum(values(cost_values))
    # collect every key that attains the minimum, typed by the dictionary's key type
    candidates = PT[pair for (pair, cost) in cost_values if cost == minval]
    return rand(candidates)
end

function analyze_contraction(incidence_list::IncidenceList{VT,ET}, vi::VT, vj::VT) where {VT,ET}
Expand Down
4 changes: 2 additions & 2 deletions src/contractionorder/incidencelist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ Base.copy(il::IncidenceList) = IncidenceList(deepcopy(il.v2e), deepcopy(il.e2v),
function neighbors(il::IncidenceList{VT}, v) where VT
res = VT[]
for e in il.v2e[v]
for v in il.e2v[e]
push!(res, v)
for vj in il.e2v[e]
v != vj && push!(res, vj)
end
end
return unique!(res)
Expand Down
5 changes: 2 additions & 3 deletions src/cueinsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ using .CUDA

println("OMEinsum: YOU FIND CUDA!")

#include("cudapatch.jl")

# `asarray` methods for a `CuArray` parent: a non-array `x` is wrapped in a
# zero-dimensional array built with `fill` and converted with `CuArray`.
# (No `where T` here: the signatures do not use a type variable.)
asarray(x::Number, arr::CuArray) = CuArray(fill(x, ()))
asarray(x, arr::CuArray) = CuArray(fill(x, ()))
# An `AbstractArray` `x` is returned unchanged.
asarray(x::AbstractArray, y::CuArray) = x

# Convert a zero-dimensional `ReshapedArray` whose parent is a `CuArray` by calling
# `Array` on the parent.
# NOTE(review): the result has the shape of `x.parent`, not of `x` - confirm callers
# only rely on the single stored element.
Base.Array(x::Base.ReshapedArray{T,0,<:CuArray}) where T = Array(x.parent)

Expand Down
11 changes: 6 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
export asarray

"""
    asarray(x[, parent::AbstractArray]) -> AbstractArray

Return a 0-dimensional array with item `x`; if `x` is already an `AbstractArray`,
return it unchanged.
If a `parent` is supplied, it will try to match the parent array type.
"""
asarray(x::Number) = fill(x, ())
asarray(x::Number, arr::Array) = fill(x, ())
asarray(x::AbstractArray, args...) = x
asarray(x) = fill(x, ())
asarray(x::AbstractArray) = x
asarray(x, arr::Array) = fill(x, ())
asarray(x::AbstractArray, y::Array) = x

"""
nopermute(ix,iy)
Expand Down Expand Up @@ -98,4 +99,4 @@ function _batched_gemm(C1::Char, C2::Char, A::AbstractArray{T,3}, B::AbstractArr
mul!(view(C,:,:,l), a, b)
end
return C
end
end
10 changes: 6 additions & 4 deletions test/contractionorder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ end
Random.seed!(2)
incidence_list = IncidenceList(Dict('A' => ['a', 'b'], 'B'=>['a', 'c', 'd'], 'C'=>['b', 'c', 'e', 'f'], 'D'=>['e'], 'E'=>['d', 'f']))
log2_edge_sizes = Dict([c=>i for (i,c) in enumerate(['a', 'b', 'c', 'd', 'e', 'f'])]...)
edge_sizes = Dict([c=>(1<<i) for (i,c) in enumerate(['a', 'b', 'c', 'd', 'e', 'f'])]...)
il = copy(incidence_list)
contract_pair!(il, 'A', 'B', log2_edge_sizes)
target = IncidenceList(Dict('A' => ['b', 'c', 'd'], 'C'=>['b', 'c', 'e', 'f'], 'D'=>['e'], 'E'=>['d', 'f']))
Expand All @@ -41,14 +42,14 @@ end
size_dict = Dict([c=>(1<<i) for (i,c) in enumerate(['a', 'b', 'c', 'd', 'e', 'f'])]...)
Random.seed!(2)
optcode2 = optimize_greedy(eincode, size_dict)
tc, sc = timespace_complexity(optcode2, log2_edge_sizes)
tc, sc = timespace_complexity(optcode2, edge_sizes)
@test 16 <= tc <= log2(exp2(10)+exp2(16)+exp2(15)+exp2(9))
@test sc == 11
@test optcode1 == optcode2
eincode3 = ein"(ab,acd),bcef,e,df->"
Random.seed!(2)
optcode3 = optimize_greedy(eincode3, size_dict)
tc, sc = timespace_complexity(optcode3, log2_edge_sizes)
tc, sc = timespace_complexity(optcode3, edge_sizes)
@test 16 <= tc <= log2(exp2(10)+exp2(16)+exp2(15)+exp2(9)+1e-8)
end

Expand All @@ -73,11 +74,12 @@ end
code = EinCode((c60_edges..., [(i,) for i=1:60]...), ())
size_dict = Dict([i=>2 for i in 1:60])
log2_edge_sizes = Dict([i=>1 for i in 1:60])
tc, sc = timespace_complexity(code, log2_edge_sizes)
edge_sizes = Dict([i=>2 for i in 1:60])
tc, sc = timespace_complexity(code, edge_sizes)
@test tc == 60
@test sc == 0
optcode = optimize_greedy(code, size_dict)
tc2, sc2 = timespace_complexity(optcode, log2_edge_sizes)
tc2, sc2 = timespace_complexity(optcode, edge_sizes)
@test sc2 == 10
xs = vcat([TropicalF64.([-1 1; 1 -1]) for i=1:90], [TropicalF64.([0, 0]) for i=1:60])
@test Base.Iterators.flatten(optcode) == code
Expand Down

0 comments on commit 30b2de9

Please sign in to comment.