diff --git a/Project.toml b/Project.toml index 8e7206c1..8e23fc0d 100644 --- a/Project.toml +++ b/Project.toml @@ -8,20 +8,24 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [extensions] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" +AbstractFFTsEnzymeCoreExt = "EnzymeCore" [compat] ChainRulesCore = "1" +EnzymeCore = "0.3" julia = "^1.0" [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["ChainRulesCore", "ChainRulesTestUtils", "Random", "Test", "Unitful"] +test = ["ChainRulesCore", "ChainRulesTestUtils", "Enzyme", "Random", "Test", "Unitful"] diff --git a/ext/AbstractFFTsEnzymeCoreExt.jl b/ext/AbstractFFTsEnzymeCoreExt.jl new file mode 100644 index 00000000..75a08f78 --- /dev/null +++ b/ext/AbstractFFTsEnzymeCoreExt.jl @@ -0,0 +1,58 @@ +module AbstractFFTsEnzymeCoreExt + +using AbstractFFTs +using AbstractFFTs.LinearAlgebra +using EnzymeCore +using EnzymeCore.EnzymeRules + +###################### +# Forward-mode rules # +###################### + +const DuplicatedOrBatchDuplicated{T} = Union{Duplicated{T},BatchDuplicated{T}} + +# since FFTs are linear, implement all forward-model rules generically at a low-level + +function EnzymeRules.forward( + func::Const{typeof(mul!)}, + RT::Type{<:Const}, + y::DuplicatedOrBatchDuplicated{<:StridedArray{T}}, + p::Const{<:AbstractFFTs.Plan{T}}, + x::DuplicatedOrBatchDuplicated{<:StridedArray{T}}, +) where {T} + val = func.val(y.val, p.val, x.val) + if x isa Duplicated && y isa Duplicated + dval = func.val(y.dval, p.val, x.dval) + elseif x isa Duplicated && y isa Duplicated + dval = map(y.dval, x.dval) do dy, dx + return func.val(dy, p.val, dx) + end + end + return nothing +end + +function EnzymeRules.forward( + func::Const{typeof(*)}, + RT::Type{ + <:Union{Const,Duplicated,DuplicatedNoNeed,BatchDuplicated,BatchDuplicatedNoNeed} + }, + p::Const{<:AbstractFFTs.Plan}, + x::DuplicatedOrBatchDuplicated{<:StridedArray}, +) + RT <: Const && return func.val(p.val, x.val) + if x isa Duplicated + dval = func.val(p.val, x.dval) + RT <: DuplicatedNoNeed && return dval + val = func.val(p.val, x.val) + RT <: Duplicated && return Duplicated(val, dval) + else # x isa BatchDuplicated + dval = map(x.dval) do dx + return func.val(p.val, dx) + end + RT <: BatchDuplicatedNoNeed && return dval + val = func.val(p.val, x.val) + RT <: BatchDuplicated && return BatchDuplicated(val, dval) + end +end + +end # module