Skip to content

Commit

Permalink
Polish AD and unify some interfaces (#129)
Browse files Browse the repository at this point in the history
* matrix type does not match

* use ProjectTo

* update OMEinsum

* update docstring
  • Loading branch information
GiggleLiu authored Dec 4, 2021
1 parent 55cbe75 commit 105fade
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OMEinsum"
uuid = "ebe7aa44-baf0-506c-a96f-8464559b3922"
authors = ["Andreas Peter <[email protected]>"]
version = "0.6.3"
version = "0.6.4"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
28 changes: 27 additions & 1 deletion src/Core.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export EinCode, EinIndexer, EinArray
export einarray
export einarray, getiyv, getixsv

"""
EinCode
Expand All @@ -21,7 +21,33 @@ struct StaticEinCode{ixs, iy} <: EinCode end
getixs(::StaticEinCode{ixs}) where ixs = ixs
getiy(::StaticEinCode{ixs, iy}) where {ixs, iy} = iy
labeltype(::StaticEinCode{ixs,iy}) where {ixs, iy} = promote_type(eltype.(ixs)..., eltype(iy))
"""
getixsv(code)
Get labels of input tensors for `EinCode`, `NestedEinsum` and some other einsum like objects.
Returns a vector of vector.
```jldoctest; setup = :(using OMEinsum)
julia> getixsv(ein"(ij,jk),k->i")
3-element Vector{Vector{Char}}:
['i', 'j']
['j', 'k']
['k']
```
"""
getixsv(code::StaticEinCode) = [collect(labeltype(code), ix) for ix in getixs(code)]
"""
getiy(code)
Get labels of the output tensor for `EinCode`, `NestedEinsum` and some other einsum like objects.
Returns a vector.
```jldoctest; setup = :(using OMEinsum)
julia> getiyv(ein"(ij,jk),k->i")
1-element Vector{Char}:
'i': ASCII/Unicode U+0069 (category Ll: Letter, lowercase)
```
"""
getiyv(code::StaticEinCode) = collect(labeltype(code), getiy(code))

"""
Expand Down
5 changes: 1 addition & 4 deletions src/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ function einsum_grad(ixs, @nospecialize(xs), iy, size_dict, cdy, i)
nxs = _insertat( xs, i, cdy)
niy = ixs[i]
y = einsum(DynamicEinCode(nixs, niy), nxs, size_dict)
y = conj(y) # do not use `conj!` to help computing Hessians.
typeof(y) == typeof(xs[i]) && return y
xs[i] isa Array{<:Real} && return convert(typeof(xs[i]), real(y))
convert(typeof(xs[i]), y)
return ChainRulesCore.ProjectTo(xs[i])(conj(y)) # do not use `conj!` because we want to support Hessians.
end

function ChainRulesCore.rrule(::typeof(einsum), code::EinCode, @nospecialize(xs), size_dict)
Expand Down
2 changes: 2 additions & 0 deletions src/deprecation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ end
@deprecate dynamic_einsum(ixs, xs, iy; size_info=nothing) einsum(DynamicEinCode(ixs, iy), xs; size_info=size_info)
@deprecate dynamic_einsum(code::EinCode, xs; size_info=nothing) code(xs...; size_info=size_info)
@deprecate dynamic_einsum(code::NestedEinsum, xs; size_info=nothing) code(xs...; size_info=size_info)

@deprecate collect_ixs getixsv
6 changes: 4 additions & 2 deletions src/einsequence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,7 @@ function get_size_dict!(ne::NestedEinsum, @nospecialize(xs), size_info::Dict{LT}
return get_size_dict_!(ixs, [collect(Int, size(xs[i])) for i in ks], size_info)
end

collect_ixs(ne::EinCode) = [_collect(ix) for ix in getixs(ne)]
function collect_ixs(ne::NestedEinsum)
function getixsv(ne::NestedEinsum)
d = OMEinsum.collect_ixs!(ne, Dict{Int,Vector{OMEinsum.labeltype(ne.eins)}}())
ks = sort!(collect(keys(d)))
return @inbounds [d[i] for i in ks]
Expand Down Expand Up @@ -291,3 +290,6 @@ function flatten(code::NestedEinsum)
StaticEinCode{ntuple(i->(ixd[i]...,), length(ixd)), OMEinsum.getiy(code.eins)}()
end
end

labeltype(ne::NestedEinsum) = labeltype(ne.eins)
getiyv(ne::NestedEinsum) = getiyv(ne.eins)
5 changes: 3 additions & 2 deletions test/einsequence.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Test, OMEinsum
using OMEinsum: IndexGroup, NestedEinsum, parse_nested, DynamicEinCode, isleaf, collect_ixs
using OMEinsum: IndexGroup, NestedEinsum, parse_nested, DynamicEinCode, isleaf, getixsv, getiyv
@testset "einsequence" begin
@test push!(IndexGroup([],1), 'c').inds == IndexGroup(['c'], 1).inds
@test isempty(IndexGroup([],1))
Expand All @@ -16,7 +16,8 @@ using OMEinsum: IndexGroup, NestedEinsum, parse_nested, DynamicEinCode, isleaf,
size_info = Dict('k'=>2)
a, b, c, d = randn(2), randn(2,2), randn(2), randn(2)
@test ein"((i,ij),i),j->ik"(a, b, c, d; size_info=size_info) ein"i,ij,i,j->ik"(a, b, c, d; size_info=size_info)
@test collect_ixs(ein"((i,ij),i),j->ik") == collect_ixs(ein"i,ij,i,j->ik") == collect_ixs(DynamicEinCode(ein"i,ij,i,j->ik")) == [['i'], ['i','j'], ['i'], ['j']]
@test getixsv(ein"((i,ij),i),j->ik") == getixsv(ein"i,ij,i,j->ik") == getixsv(DynamicEinCode(ein"i,ij,i,j->ik")) == [['i'], ['i','j'], ['i'], ['j']]
@test getiyv(ein"((i,ij),i),j->ik") == getiyv(ein"i,ij,i,j->ik") == getiyv(DynamicEinCode(ein"i,ij,i,j->ik")) == ['i','k']
end

@testset "macro" begin
Expand Down

0 comments on commit 105fade

Please sign in to comment.