Skip to content

Commit

Permalink
gpu binary rule (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu authored Jun 12, 2021
1 parent ded4a27 commit 69d5520
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OMEinsum"
uuid = "ebe7aa44-baf0-506c-a96f-8464559b3922"
authors = ["Andreas Peter <[email protected]>"]
version = "0.4.3"
version = "0.4.4"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
26 changes: 4 additions & 22 deletions src/binaryrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ end
# S = 1
# T = 1
function einsum(::SimpleBinaryRule{(),(), ()}, xs::NTuple{2, Any})
asarray(xs[1][] * xs[2], xs[1])
asarray(asscalar(xs[1]) * xs[2], xs[1])
end

# i,->i : 100
# S = N
# T = N
function einsum(::SimpleBinaryRule{('i',),(), ('i',)}, xs::NTuple{2, Any})
xs[1] .* xs[2][]
xs[1] .* asscalar(xs[2])
end

# j,j-> : 010
Expand Down Expand Up @@ -167,31 +167,13 @@ end
# 010
function einsum(::SimpleBinaryRule{('j','l'), ('j','l'), ('l',)}, xs::NTuple{2, Any})
a, b = xs
T = promote_type(eltype(xs[1]), eltype(xs[2]))
out = similar(a, T, size(a, 2))
@inbounds for k=1:size(a, 2)
elem = zero(T)
for i=1:size(a, 1)
elem += a[i,k] * b[i,k]
end
out[k] = elem
end
return out
dropdims(mapreduce(*, +, a, b; dims=1); dims=1)
end

# 101
function einsum(::SimpleBinaryRule{('i','l'), ('k','l'), ('i','k','l')}, xs::NTuple{2, Any})
a, b = xs
T = promote_type(eltype(xs[1]), eltype(xs[2]))
out = similar(a, T, size(a, 1), size(b, 1), size(a, 2))
@inbounds for k=1:size(a, 2)
for j=1:size(b, 1)
for i=1:size(a, 1)
out[i,j,k] = a[i,k] * b[j,k]
end
end
end
return out
_batched_gemm('N', 'N', reshape(a, size(a, 1), 1, size(a, 2)), reshape(b, 1, size(b, 1), size(b, 2)))
end
@inline function einsum(::SimpleBinaryRule{('i','l'), ('k','l'), ('k','i','l')}, xs::NTuple{2, Any})
einsum(SimpleBinaryRule{('i','l'),('k','l'), ('i','k','l')}(), (xs[2], xs[1]))
Expand Down
5 changes: 5 additions & 0 deletions src/cueinsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ println("OMEinsum: YOU FIND CUDA!")

asarray(x, arr::CuArray) where T = CuArray(fill(x, ()))
asarray(x::AbstractArray, y::CuArray) = x
asscalar(x::DenseCuArray) = Array(x)[]

Base.Array(x::Base.ReshapedArray{T,0,<:CuArray}) where T = Array(x.parent)

Expand All @@ -20,6 +21,10 @@ for TP in [:Diag, :Repeat, :Duplicate, :DefaultRule]
end
end

function einsum(::SimpleBinaryRule{('j',), ('j',), ()}, xs::NTuple{2, DenseCuArray})
dropdims(reshape(xs[1],1,:) * xs[2]; dims=1)
end

function loop_einsum!(code::EinCode{ixs, iy},
xs::NTuple{N, DenseCuArray{<:Any,M} where M},
y::DenseCuArray{T,L}, size_dict::Dict{LT}) where {N,L,T, ixs, iy, LT}
Expand Down
4 changes: 3 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export asarray
export asarray, asscalar

"""
asarray(x[, parent::AbstractArray]) -> AbstactArray
Expand All @@ -10,6 +10,8 @@ asarray(x) = fill(x, ())
asarray(x::AbstractArray) = x
asarray(x, arr::Array) = fill(x, ())
asarray(x::AbstractArray, y::Array) = x
asscalar(x) = x
asscalar(x::AbstractArray) = x[]

"""
nopermute(ix,iy)
Expand Down
14 changes: 14 additions & 0 deletions test/cueinsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,17 @@ end
@test Array(res) Array(loop_einsum(code, (xs...,), size_dict))
end
end

@testset "binary rules" begin
for (code, a, b) in [
(ein"j,j->", randn(10), randn(10)),
(ein"i,->i", randn(10), fill(2.0, ())),
(ein",->", fill(2.0,()), fill(2.0, ())),
(ein"il,kl->ikl", randn(10, 10), randn(10, 10)),
]
res0 = code(a, b)
res1 = code(CuArray(a), CuArray(b))
@test res1 isa CuArray
@test res0 Array(res1)
end
end

0 comments on commit 69d5520

Please sign in to comment.