fix docs, upgrade CUDA and bump version (#158)
* fix docs, upgrade CUDA and bump version

* make CUDA an extension

* update readme
GiggleLiu authored Sep 23, 2023
1 parent 3f034a7 commit 96582cc
Showing 8 changed files with 70 additions and 21 deletions.
15 changes: 9 additions & 6 deletions Project.toml
@@ -1,30 +1,33 @@
 name = "OMEinsum"
 uuid = "ebe7aa44-baf0-506c-a96f-8464559b3922"
 authors = ["Andreas Peter <[email protected]>"]
-version = "0.7.4"
+version = "0.7.5"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 BatchedRoutines = "a9ab73d0-e05c-5df1-8fde-d6a4645b8d8e"
-CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
-Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
 
+[weakdeps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+
+[extensions]
+CUDAExt = "CUDA"
+
 [compat]
 AbstractTrees = "0.3, 0.4"
 BatchedRoutines = "0.2"
-CUDA = "4"
+CUDA = "4, 5"
 ChainRulesCore = "1"
 Combinatorics = "1.0"
 MacroTools = "0.5"
 OMEinsumContractionOrders = "0.8"
-Requires = "0.5, 1"
 TupleTools = "1.2, 1.3"
 julia = "1"
@@ -42,4 +45,4 @@ TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
-test = ["Test", "Documenter", "LinearAlgebra", "ProgressMeter", "SymEngine", "Random", "Zygote", "DoubleFloats", "TropicalNumbers", "ForwardDiff", "Polynomials"]
+test = ["Test", "Documenter", "LinearAlgebra", "ProgressMeter", "SymEngine", "Random", "Zygote", "DoubleFloats", "TropicalNumbers", "ForwardDiff", "Polynomials", "CUDA"]
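The `[weakdeps]`/`[extensions]` stanzas above turn CUDA into a weak dependency: on Julia ≥ 1.9 the `CUDAExt` module is compiled and loaded only once the user imports CUDA. A quick way to observe this at the REPL (a sketch, not part of the commit):

```julia
using OMEinsum
Base.get_extension(OMEinsum, :CUDAExt)  # nothing — CUDA has not been loaded yet

using CUDA
Base.get_extension(OMEinsum, :CUDAExt)  # now returns the CUDAExt module
```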
41 changes: 39 additions & 2 deletions README.md
@@ -169,7 +169,10 @@ julia> ein"ijk,ijk->"(s,s)[]

Using that method, it is easy to see that, e.g., the Petersen graph admits no 3-colouring, since
```julia
julia> ein"afl,bhn,cjf,dlh,enj,ago,big,cki,dmk,eom->"(fill(s, 10)...)[]
julia> code = ein"afl,bhn,cjf,dlh,enj,ago,big,cki,dmk,eom->"
afl, bhn, cjf, dlh, enj, ago, big, cki, dmk, eom

julia> code(fill(s, 10)...)[]
0
```
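For context: `s` is defined earlier in the README, outside this hunk, as the rank-3 tensor whose entry is 1 exactly when its three indices are pairwise distinct. One way to construct such a tensor (a sketch; the README's own definition may differ in style):

```julia
# Entry is 1 iff all three indices differ — a sketch of the colouring tensor `s`.
s = map(x -> Int(length(unique(x.I)) == 3), CartesianIndices((3, 3, 3)))
```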

@@ -178,11 +181,45 @@ embedded in a pentagon as depicted here:

![](https://upload.wikimedia.org/wikipedia/commons/thumb/f/f5/Petersen_graph.svg/252px-Petersen_graph.svg.png)

`OMEinsum` does not optimize the contraction order by default, so the above contraction can be time-consuming. To speed it up, we can use `optimize_code` to find a better contraction order:
```julia
julia> optcode = optimize_code(code, uniformsize(code, 3), TreeSA())
SlicedEinsum{Char, DynamicNestedEinsum{Char}}(Char[], ago, goa ->
├─ ago
└─ gcojl, cjal -> goa
├─ bgck, bojlk -> gcojl
│ ├─ big, cki -> bgck
│ │ ├─ big
│ │ └─ cki
│ └─ bhomj, lhmk -> bojlk
│ ├─ bhn, omnj -> bhomj
│ │ ├─ bhn
│ │ └─ eom, enj -> omnj
│ │ ├─ eom
│ │ └─ enj
│ └─ dlh, dmk -> lhmk
│ ├─ dlh
│ └─ dmk
└─ cjf, afl -> cjal
├─ cjf
└─ afl
)

julia> contraction_complexity(optcode, uniformsize(optcode, 3))
Time complexity: 2^12.737881076857779
Space complexity: 2^7.92481250360578
Read-write complexity: 2^11.247334178028728

julia> optcode(fill(s, 10)...)[]
0
```
We can see that the time complexity of the optimized code is much smaller than that of the naive contraction. To learn more about contraction-order optimization, please check the Julia package [`OMEinsumContractionOrders.jl`](https://github.com/TensorBFS/OMEinsumContractionOrders.jl).
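For comparison, a cheaper heuristic optimizer can be swapped in for `TreeSA`; a sketch (outputs omitted; `GreedyMethod` is assumed to be re-exported by OMEinsum, as in recent releases):

```julia
# A faster but usually less effective alternative to TreeSA (a sketch).
greedycode = optimize_code(code, uniformsize(code, 3), GreedyMethod())
contraction_complexity(greedycode, uniformsize(greedycode, 3))
```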

Confronted with the above result, we can ask whether the Petersen graph allows a relaxed variant of 3-colouring, in which one vertex may accept duplicate colours. The answer can be found using the gradient w.r.t. a vertex:
```julia
julia> using Zygote: gradient

-julia> gradient(x->ein"afl,bhn,cjf,dlh,enj,ago,big,cki,dmk,eom->"(x,s,s,s,s,s,s,s,s,s)[], s)[1] |> sum
+julia> gradient(x->optcode(x,s,s,s,s,s,s,s,s,s)[], s)[1] |> sum
0
```
This tells us that even if we allow duplicates on one vertex, there is no 3-colouring of the Petersen graph.
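Each entry of that gradient counts the proper colourings of the remaining graph for one colour assignment of the relaxed vertex; since all entries are nonnegative, a zero sum means every entry is zero. A hypothetical way to double-check this (not part of the README):

```julia
# All gradient entries should be zero — a sketch, assuming `optcode` and `s` as above.
g = gradient(x -> optcode(x, s, s, s, s, s, s, s, s, s)[], s)[1]
@assert all(iszero, g)
```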
6 changes: 3 additions & 3 deletions docs/src/contractionorder.md
@@ -12,11 +12,11 @@ using OMEinsum
code = ein"ij,jk,kl,li->"
```

-The time and space complexity can be obtained by calling the [`timespacereadwrite_complexity`](@ref) function.
+The time and space complexity can be obtained by calling the [`contraction_complexity`](@ref) function.
```@example 3
size_dict = uniformsize(code, 10)
-timespacereadwrite_complexity(code, size_dict)
+contraction_complexity(code, size_dict)
```

The returned values are the `log2` values of the number of iterations, the number of elements of the largest tensor, and the number of elementwise read-write operations.
@@ -29,5 +29,5 @@ The output value is a binary contraction tree with type [`NestedEinsum`](@ref) t
The time and read-write complexities are significantly reduced compared to the direct contraction.

```@example 3
-timespacereadwrite_complexity(optcode, size_dict)
+contraction_complexity(optcode, size_dict)
```
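The lines elided from this hunk presumably construct `optcode`; a hypothetical reconstruction of that step, consistent with the README example above:

```julia
# Hypothetical sketch of the elided step that produces `optcode`:
optcode = optimize_code(code, size_dict, TreeSA())
```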
2 changes: 1 addition & 1 deletion docs/src/docstrings.md
@@ -1,3 +1,3 @@
```@autodocs
-Modules = [OMEinsum]
+Modules = [OMEinsum, OMEinsum.OMEinsumContractionOrders]
```
12 changes: 11 additions & 1 deletion src/cueinsum.jl → ext/CUDAExt.jl
@@ -1,5 +1,13 @@
-using .CUDA
+module CUDAExt
+
+import OMEinsum: asarray, get_output_array, einsum, loop_einsum!, _batched_gemm, asscalar
+using OMEinsum: EinArray, Diag, Repeat, Duplicate, DefaultRule, EinCode, DynamicEinCode, StaticEinCode, NestedEinsum, SimpleBinaryRule, match_rule, loop_einsum, getiy, getixs, _unique, einarray, align_eltypes, siblings, isleaf, tensorindex, _safe_set, rootcode
+import OMEinsum
+using LinearAlgebra
+import LinearAlgebra: BlasFloat
+using CUDA
+
+const CuBlasFloat = Union{BlasFloat, Float16, ComplexF16}
const CUDAArrayTypes{T,N} = Union{LinearAlgebra.Transpose{T,<:CuArray{T,N}}, DenseCuArray{T,N}, LinearAlgebra.Adjoint{T,<:CuArray{T,N}}}
_unwrap(x::LinearAlgebra.Adjoint{T,<:CuArray{T}}) where T = CuArray(x)
_unwrap(x::LinearAlgebra.Transpose{T,<:CuArray{T}}) where T = CuArray(x)
@@ -115,3 +123,5 @@ function einsum(code::DynamicEinCode, @nospecialize(xs::NTuple{N,CUDAArrayTypes}
end

@info("OMEinsum loaded the CUDA module successfully")

+end
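With CUDA now an extension, GPU execution remains opt-in; a minimal usage sketch (assumes a CUDA-capable device):

```julia
using OMEinsum, CUDA  # importing CUDA activates the CUDAExt module above

a = CUDA.rand(Float32, 2, 2)
b = CUDA.rand(Float32, 2, 2)
ein"ij,jk->ik"(a, b)  # dispatches to the CUDA einsum path and returns a CuArray
```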
7 changes: 0 additions & 7 deletions src/OMEinsum.jl
@@ -27,20 +27,13 @@ export CodeOptimizer, CodeSimplifier,
writejson, readjson,
label_elimination_order

-const CuBlasFloat = Union{BlasFloat, Float16, ComplexF16}

include("Core.jl")
include("loop_einsum.jl")
include("utils.jl")

include("unaryrules.jl")
include("binaryrules.jl")

-using Requires
-function __init__()
-    @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" include("cueinsum.jl")
-end

include("interfaces.jl")
include("einsequence.jl")
include("slicing.jl")
5 changes: 5 additions & 0 deletions src/einsequence.jl
@@ -153,6 +153,11 @@ function parse_nested_expr(expr, tensors, allinds)
end

# the contraction tree
"""
NestedEinsum{LT} <: AbstractEinsum
The abstract type for contraction trees. It has two subtypes, [`DynamicNestedEinsum`](@ref) and [`StaticNestedEinsum`](@ref).
"""
abstract type NestedEinsum{LT} <: AbstractEinsum end

"""
3 changes: 2 additions & 1 deletion src/slicing.jl
@@ -95,7 +95,8 @@ function einsum(se::SlicedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} wher
    res = get_output_array(xs, getindex.(Ref(size_dict), it.iyv))
    eins_sliced = drop_slicedim(se.eins, se.slicing)
    for (k, slicemap) in enumerate(it)
-        @debug "computing slice $k/$(length(it))"
+        # NOTE: @debug will break Zygote
+        # @debug "computing slice $k/$(length(it))"
        xsi = ntuple(i->take_slice(xs[i], it.ixsv[i], slicemap), length(xs))
        resi = einsum(eins_sliced, xsi, it.size_dict_sliced; kwargs...)
        res = fill_slice!(res, it.iyv, resi, slicemap)
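The commented-out `@debug` above is a workaround: logging macros expand to code containing `try`/`catch`, which Zygote's reverse-mode transform does not support, so differentiating a sliced contraction would error. A hypothetical minimal reproduction (exact behaviour depends on the Zygote version):

```julia
using Zygote

function f(x)
    @debug "computing slice"   # logging macro expands to try/catch internally
    return sum(abs2, x)
end

# May throw a "try/catch is not supported" compilation error when differentiated:
Zygote.gradient(f, [1.0, 2.0, 3.0])
```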
