Skip to content

Commit

Permalink
Fix 1.5 compatibility (#116)
Browse files Browse the repository at this point in the history
* fix 1.5 compatibility

* update compatibility

* update project.toml

* rm 1.5 ci
  • Loading branch information
GiggleLiu authored Sep 1, 2021
1 parent d8248f5 commit c949af5
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 17 deletions.
10 changes: 4 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OMEinsum"
uuid = "ebe7aa44-baf0-506c-a96f-8464559b3922"
authors = ["Andreas Peter <[email protected]>"]
-version = "0.4.7"
+version = "0.4.8"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand All @@ -11,21 +11,19 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"

[compat]
AbstractTrees = "0.3"
BatchedRoutines = "0.2"
-CUDA = "2, 3.1"
-ChainRulesCore = "1.0"
+CUDA = "3.4"
+ChainRulesCore = "1.3"
Combinatorics = "1.0"
MacroTools = "0.5"
-PkgBenchmark = "0.2"
Requires = "0.5, 1"
-TupleTools = "1.2"
+TupleTools = "1.2, 1.3"
julia = "1"

[extras]
Expand Down
6 changes: 3 additions & 3 deletions src/contractionorder/greedy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function _tree_greedy(incidence_list::IncidenceList{VT,ET}, log2_edge_sizes; met
end

function contract_pair!(incidence_list, vi, vj, log2_edge_sizes)
-log2dim(legs) = sum(l->log2_edge_sizes[l], legs, init=0)
+log2dim(legs) = isempty(legs) ? 0 : sum(l->log2_edge_sizes[l], legs) # for 1.5, you need this patch because `init` kw is not allowed.
@assert vj > vi
# compute time complexity and output tensor
legsets = analyze_contraction(incidence_list, vi, vj)
Expand Down Expand Up @@ -186,13 +186,13 @@ function analyze_contraction(incidence_list::IncidenceList{VT,ET}, vi::VT, vj::V
end

function greedy_loss(::MinSpaceOut, incidence_list, log2_edge_sizes, vi, vj)
-log2dim(legs) = sum(l->log2_edge_sizes[l], legs, init=0)
+log2dim(legs) = isempty(legs) ? 0 : sum(l->log2_edge_sizes[l], legs) # for 1.5, you need this patch because `init` kw is not allowed.
legs = analyze_contraction(incidence_list, vi, vj)
log2dim(legs.l01)+log2dim(legs.l02)+log2dim(legs.l012)
end

function greedy_loss(::MinSpaceDiff, incidence_list, log2_edge_sizes, vi, vj)
-log2dim(legs) = sum(l->log2_edge_sizes[l], legs, init=0)
+log2dim(legs) = isempty(legs) ? 0 : sum(l->log2_edge_sizes[l], legs) # for 1.5, you need this patch because `init` kw is not allowed.
legs = analyze_contraction(incidence_list, vi, vj)
D1,D2,D12,D01,D02,D012 = log2dim.(getfield.(Ref(legs), 1:6))
exp2(D01+D02+D012) - exp2(D1+D01+D12) - exp2(D2+D02+D12) # out - in
Expand Down
14 changes: 10 additions & 4 deletions src/cueinsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ asscalar(x::DenseCuArray) = Array(x)[]
Base.Array(x::Base.ReshapedArray{T,0,<:CuArray}) where T = Array(x.parent)

function get_output_array(xs::NTuple{N, DenseCuArray{<:Any,M} where M}, size; has_repeated_indices=true) where N
-out = CUDA.zeros(promote_type(map(eltype,xs)...), size)
+CUDA.zeros(promote_type(map(eltype,xs)...), size)
end

CUDA.cudaconvert(A::EinArray{T}) where T = EinArray{T}(cudaconvert.(A.xs), A.x_indexers, A.y_indexer, A.size, A.ICIS, A.OCIS)
Expand All @@ -31,13 +31,19 @@ function loop_einsum!(code::EinCode{ixs, iy},
A = einarray(code, xs, size_dict)
if NO == length(iy)
y = reshape(y, fill(1, ndims(A)-NO)...,size(y)...)
-dropdims(Base.mapreducedim!(x->x, +, y, A), dims=(1:ndims(A)-NO...,))
+raw = Base.mapreducedim!(x->x, +, y, A)
+if ndims(A)-NO > 0 # fix 1.7 compatibility
+raw = dropdims(raw, dims=(1:ndims(A)-NO...,))
+end
+return raw
else
y_ = CUDA.zeros(T, size(A)[end-NO+1:end]...)
y_ = reshape(y_, fill(1, ndims(A)-NO)...,size(y_)...)
raw = Base.mapreducedim!(x->x, +, y_, A)
-out = dropdims(raw, dims=(1:ndims(A)-NO...,))
-expanddims!(EinCode{((iy_...,),), iy}(), out, y)
+if ndims(A)-NO > 0 # fix 1.7 compatibility
+raw = dropdims(raw, dims=(1:ndims(A)-NO...,))
+end
+return expanddims!(EinCode{((iy_...,),), iy}(), raw, y)
end
end

Expand Down
8 changes: 5 additions & 3 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,14 @@ function dynamic_einsum(ixs, xs, iy, size_dict)
if length(ixs) == 1
einsum(rule, ixs[1], iy, xs[1], size_dict)
else
-einsum(rule, ixs, iy, xs, size_dict)
+einsum(rule, ixs, iy, (xs...,), size_dict)
end
end

dynamic_einsum(::EinCode{ixs, iy}, xs; kwargs...) where {ixs, iy} = dynamic_einsum(ixs, xs, iy; kwargs...)

# the fallback
function einsum(::DefaultRule, ixs, iy, xs, size_dict)
@debug "DefaultRule loop_einsum" ixs => iy size.(xs)
-loop_einsum(EinCode{ixs, iy}(), xs, size_dict)
end
+loop_einsum(EinCode{ixs, iy}(), (xs...,), size_dict)
end
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ If a `parent` is supplied, it will try to match the parent array type.
"""
asarray(x) = fill(x, ())
asarray(x::AbstractArray) = x
-asarray(x, arr::Array) = fill(x, ())
+asarray(x, arr::AbstractArray) = fill(x, ())
asarray(x::AbstractArray, y::Array) = x
asscalar(x) = x
asscalar(x::AbstractArray) = x[]
Expand Down

0 comments on commit c949af5

Please sign in to comment.