diff --git a/Project.toml b/Project.toml
index 8f42c60..bd500e3 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,30 +1,33 @@
 name = "OMEinsum"
 uuid = "ebe7aa44-baf0-506c-a96f-8464559b3922"
 authors = ["Andreas Peter "]
-version = "0.7.4"
+version = "0.7.5"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 BatchedRoutines = "a9ab73d0-e05c-5df1-8fde-d6a4645b8d8e"
-CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715"
-Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
 
+[weakdeps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+
+[extensions]
+CUDAExt = "CUDA"
+
 [compat]
 AbstractTrees = "0.3, 0.4"
 BatchedRoutines = "0.2"
-CUDA = "4"
+CUDA = "4, 5"
 ChainRulesCore = "1"
 Combinatorics = "1.0"
 MacroTools = "0.5"
 OMEinsumContractionOrders = "0.8"
-Requires = "0.5, 1"
 TupleTools = "1.2, 1.3"
 julia = "1"
 
@@ -42,4 +45,4 @@ TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test", "Documenter", "LinearAlgebra", "ProgressMeter", "SymEngine", "Random", "Zygote", "DoubleFloats", "TropicalNumbers", "ForwardDiff", "Polynomials"]
+test = ["Test", "Documenter", "LinearAlgebra", "ProgressMeter", "SymEngine", "Random", "Zygote", "DoubleFloats", "TropicalNumbers", "ForwardDiff", "Polynomials", "CUDA"]
diff --git a/README.md b/README.md
index 464f358..c0f748c 100644
--- a/README.md
+++ b/README.md
@@ -169,7 +169,10 @@ julia> ein"ijk,ijk->"(s,s)[]
 
 Using that method, it's easy to find that e.g. the peterson graph allows no 3 colouring, since
 ```julia
-julia> ein"afl,bhn,cjf,dlh,enj,ago,big,cki,dmk,eom->"(fill(s, 10)...)[]
+julia> code = ein"afl,bhn,cjf,dlh,enj,ago,big,cki,dmk,eom->"
+afl, bhn, cjf, dlh, enj, ago, big, cki, dmk, eom
+
+julia> code(fill(s, 10)...)[]
 0
 ```
 
@@ -178,11 +181,45 @@ embedded in a pentagon as depicted here:
 
 ![](https://upload.wikimedia.org/wikipedia/commons/thumb/f/f5/Petersen_graph.svg/252px-Petersen_graph.svg.png)
 
+`OMEinsum` does not optimize the contraction order by default, so the above contraction can be time-consuming. To speed up the contraction, we can use `optimize_code` to optimize the contraction order:
+```julia
+julia> optcode = optimize_code(code, uniformsize(code, 3), TreeSA())
+SlicedEinsum{Char, DynamicNestedEinsum{Char}}(Char[], ago, goa ->
+├─ ago
+└─ gcojl, cjal -> goa
+   ├─ bgck, bojlk -> gcojl
+   │  ├─ big, cki -> bgck
+   │  │  ├─ big
+   │  │  └─ cki
+   │  └─ bhomj, lhmk -> bojlk
+   │     ├─ bhn, omnj -> bhomj
+   │     │  ├─ bhn
+   │     │  └─ eom, enj -> omnj
+   │     │     ⋮
+   │     │
+   │     └─ dlh, dmk -> lhmk
+   │        ├─ dlh
+   │        └─ dmk
+   └─ cjf, afl -> cjal
+      ├─ cjf
+      └─ afl
+)
+
+julia> contraction_complexity(optcode, uniformsize(optcode, 3))
+Time complexity: 2^12.737881076857779
+Space complexity: 2^7.92481250360578
+Read-write complexity: 2^11.247334178028728
+
+julia> optcode(fill(s, 10)...)[]
+0
+```
+We can see that the time complexity of the optimized code is much smaller than that of the original contraction. To learn more about contraction order optimization, please check the Julia package [`OMEinsumContractionOrders.jl`](https://github.com/TensorBFS/OMEinsumContractionOrders.jl).
+
 Confronted with the above result, we can ask whether the peterson graph allows a relaxed variation of 3 colouring, having one vertex that might accept duplicate colours. The answer to that can be found using the gradient w.r.t a vertex:
 ```julia
 julia> using Zygote: gradient
 
-julia> gradient(x->ein"afl,bhn,cjf,dlh,enj,ago,big,cki,dmk,eom->"(x,s,s,s,s,s,s,s,s,s)[], s)[1] |> sum
+julia> gradient(x->optcode(x,s,s,s,s,s,s,s,s,s)[], s)[1] |> sum
 0
 ```
 This tells us that even if we allow duplicates on one vertex, there are no 3-colourings for the peterson graph.
diff --git a/docs/src/contractionorder.md b/docs/src/contractionorder.md
index d19990c..428fcdb 100644
--- a/docs/src/contractionorder.md
+++ b/docs/src/contractionorder.md
@@ -12,11 +12,11 @@ using OMEinsum
 code = ein"ij,jk,kl,li->"
 ```
 
-The time and space complexity can be obtained by calling the [`timespacereadwrite_complexity`](@ref) function.
+The time and space complexity can be obtained by calling the [`contraction_complexity`](@ref) function.
 
 ```@example 3
 size_dict = uniformsize(code, 10)
-timespacereadwrite_complexity(code, size_dict)
+contraction_complexity(code, size_dict)
 ```
 
 The return values are `log2` values of the number of iterations, number of elements of the largest tensor and the number of elementwise read-write operations.
@@ -29,5 +29,5 @@ The output value is a binary contraction tree with type [`NestedEinsum`](@ref) t
 The time and readwrite complexities are significantly reduced comparing to the direct contraction.
 
 ```@example 3
-timespacereadwrite_complexity(optcode, size_dict)
+contraction_complexity(optcode, size_dict)
 ```
\ No newline at end of file
diff --git a/docs/src/docstrings.md b/docs/src/docstrings.md
index 0e7ceec..95d3d07 100644
--- a/docs/src/docstrings.md
+++ b/docs/src/docstrings.md
@@ -1,3 +1,3 @@
 ```@autodocs
-Modules = [OMEinsum]
+Modules = [OMEinsum, OMEinsum.OMEinsumContractionOrders]
 ```
diff --git a/src/cueinsum.jl b/ext/CUDAExt.jl
similarity index 90%
rename from src/cueinsum.jl
rename to ext/CUDAExt.jl
index dcf0208..54de708 100644
--- a/src/cueinsum.jl
+++ b/ext/CUDAExt.jl
@@ -1,5 +1,13 @@
-using .CUDA
+module CUDAExt
+import OMEinsum: asarray, get_output_array, einsum, loop_einsum!, _batched_gemm, asscalar
+using OMEinsum: EinArray, Diag, Repeat, Duplicate, DefaultRule, EinCode, DynamicEinCode, StaticEinCode, NestedEinsum, SimpleBinaryRule, match_rule, loop_einsum, getiy, getixs, _unique, einarray, align_eltypes, siblings, isleaf, tensorindex, _safe_set, rootcode
+import OMEinsum
+using LinearAlgebra
+import LinearAlgebra: BlasFloat
+using CUDA
+
+const CuBlasFloat = Union{BlasFloat, Float16, ComplexF16}
 
 const CUDAArrayTypes{T,N} = Union{LinearAlgebra.Transpose{T,<:CuArray{T,N}}, DenseCuArray{T,N}, LinearAlgebra.Adjoint{T,<:CuArray{T,N}}}
 _unwrap(x::LinearAlgebra.Adjoint{T,<:CuArray{T}}) where T = CuArray(x)
 _unwrap(x::LinearAlgebra.Transpose{T,<:CuArray{T}}) where T = CuArray(x)
@@ -115,3 +123,5 @@ function einsum(code::DynamicEinCode, @nospecialize(xs::NTuple{N,CUDAArrayTypes}
 end
 
 @info("OMEinsum loaded the CUDA module successfully")
+
+end
\ No newline at end of file
diff --git a/src/OMEinsum.jl b/src/OMEinsum.jl
index e9014d2..24bf6e8 100644
--- a/src/OMEinsum.jl
+++ b/src/OMEinsum.jl
@@ -27,8 +27,6 @@ export CodeOptimizer, CodeSimplifier,
 writejson, readjson,
 label_elimination_order
 
-const CuBlasFloat = Union{BlasFloat, Float16, ComplexF16}
-
 include("Core.jl")
 include("loop_einsum.jl")
 include("utils.jl")
@@ -36,11 +34,6 @@ include("utils.jl")
 
 include("unaryrules.jl")
include("binaryrules.jl") -using Requires -function __init__() - @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" include("cueinsum.jl") -end - include("interfaces.jl") include("einsequence.jl") include("slicing.jl") diff --git a/src/einsequence.jl b/src/einsequence.jl index 951800e..3eae41a 100644 --- a/src/einsequence.jl +++ b/src/einsequence.jl @@ -153,6 +153,11 @@ function parse_nested_expr(expr, tensors, allinds) end # the contraction tree +""" + NestedEinsum{LT} <: AbstractEinsum + +The abstract type for contraction trees. It has two subtypes, [`DynamicNestedEinsum`](@ref) and [`StaticNestedEinsum`](@ref). +""" abstract type NestedEinsum{LT} <: AbstractEinsum end """ diff --git a/src/slicing.jl b/src/slicing.jl index 1d463c0..c2804eb 100644 --- a/src/slicing.jl +++ b/src/slicing.jl @@ -95,7 +95,8 @@ function einsum(se::SlicedEinsum, @nospecialize(xs::NTuple{N,AbstractArray} wher res = get_output_array(xs, getindex.(Ref(size_dict), it.iyv)) eins_sliced = drop_slicedim(se.eins, se.slicing) for (k, slicemap) in enumerate(it) - @debug "computing slice $k/$(length(it))" + # NOTE: @debug will break Zygote + # @debug "computing slice $k/$(length(it))" xsi = ntuple(i->take_slice(xs[i], it.ixsv[i], slicemap), length(xs)) resi = einsum(eins_sliced, xsi, it.size_dict_sliced; kwargs...) res = fill_slice!(res, it.iyv, resi, slicemap)