feat: add gradient with AutoReactant #918
Base: main · Changes from 2 commits
**New file: the `DifferentiationInterfaceReactantExt` module (+13 lines)**

```julia
module DifferentiationInterfaceReactantExt

using ADTypes: ADTypes, AutoReactant
import DifferentiationInterface as DI
using Reactant: @compile, ConcreteRArray, ConcreteRNumber, to_rarray

DI.check_available(backend::AutoReactant) = DI.check_available(backend.mode)
DI.inplace_support(backend::AutoReactant) = DI.inplace_support(backend.mode)

include("utils.jl")
include("onearg.jl")

end # module
```
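For context, here is a minimal sketch of how the extension is meant to be driven from the outside. The zero-argument `AutoReactant()` constructor and the choice of inner mode are assumptions; the diff itself only requires that the backend expose a `.mode` field.

```julia
using DifferentiationInterface  # exports `gradient`
using ADTypes: AutoReactant

f(x) = sum(abs2, x)
x = rand(Float32, 16)

# Assumption: a default-constructed AutoReactant picks a sensible inner mode.
backend = AutoReactant()

# High-level DI entry point: prepares (and hence compiles) internally.
g = gradient(f, backend, x)
```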
**New file: `onearg.jl` (+80 lines)**

```julia
struct ReactantGradientPrep{SIG, XR, GR, CG, CG!, CVG, CVG!} <: DI.GradientPrep{SIG}
    _sig::Val{SIG}
    xr::XR
    gr::GR
    compiled_gradient::CG
    compiled_gradient!::CG!
    compiled_value_and_gradient::CVG
    compiled_value_and_gradient!::CVG!
end
```

> **Reviewer:** Can we have different prep objects for each of the compiled variants? The reason is that one variant may compile whereas another may not.
>
> **Author:** I'm not sure what you mean here. You want to explore all 2^4 combinations of compiled/non-compiled operator variants?
>
> **Reviewer:** I mean that they should be distinct. Perhaps they can be templated or anything else, but they should be distinct (and therefore `prepare` should only compile the variant that will be used).
>
> **Author:** That is not compatible with DI's API. A preparation result must allow calling all four variants of an operator.
>
> **Reviewer:** Nope, and rip. That seems like a design limitation of DI. Is it something that's fixable in time?
>
> **Author:** Just because you don't like it doesn't automatically mean it is a design limitation; it could also just be a design decision you disagree with 😉 In fact, it is completely consistent with the way most backends work, otherwise I would have done it differently. And it is rather convenient for users to have one preparation for four variants. It is not fixable without a breaking release, which I'm fairly reluctant to do given that we have 1000 indirect dependents.
>
> **Author:** See #685 to continue this design discussion (I opened that issue a year ago ^^)
```julia
function DI.prepare_gradient_nokwarg(
        strict::Val, f::F, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
    ) where {F, C}
    _sig = DI.signature(f, rebackend, x; strict)
    backend = rebackend.mode
    xr = to_reac(x)
    gr = to_reac(similar(x))
    contextsr = map(to_reac, contexts)
    compiled_gradient = @compile DI.gradient(f, backend, xr, contextsr...)
    compiled_gradient! = @compile DI.gradient!(f, gr, backend, xr, contextsr...)
    compiled_value_and_gradient = @compile DI.value_and_gradient(f, backend, xr, contextsr...)
    compiled_value_and_gradient! = @compile DI.value_and_gradient!(f, gr, backend, xr, contextsr...)
    return ReactantGradientPrep(
        _sig,
        xr,
        gr,
        compiled_gradient,
        compiled_gradient!,
        compiled_value_and_gradient,
        compiled_value_and_gradient!,
    )
end
```
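The four `@compile` calls above are the heart of the preparation step: each operator variant is traced and compiled once, for the given input shape, and the resulting thunks are stored in the prep object. As a standalone illustration of that compile-once, call-many pattern (function and variable names here are made up, not from the diff):

```julia
using Reactant

square_sum(x) = sum(abs2, x)
xr = Reactant.to_rarray(rand(Float32, 8))

compiled = @compile square_sum(xr)  # trace + compile once for this shape
y = compiled(xr)                    # cheap to call again on same-shaped inputs
```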
```julia
function DI.gradient(
        f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
    ) where {F, C}
    DI.check_prep(f, prep, rebackend, x)
    backend = rebackend.mode
    (; xr, compiled_gradient) = prep
    copyto!(xr, x)
    contextsr = map(to_reac, contexts)
    gr = compiled_gradient(f, backend, xr, contextsr...)
    return gr
end
```
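A quick sketch of the prepared workflow these methods enable (placeholder function and data; the `AutoReactant()` constructor is assumed as above):

```julia
import DifferentiationInterface as DI

f(x) = sum(abs2, x)
x = rand(Float32, 16)
backend = AutoReactant()  # assumed constructor

prep = DI.prepare_gradient(f, backend, x)  # compiles all four variants
g = DI.gradient(f, prep, backend, x)       # runs the compiled thunk
# Note: as written in the diff, `g` is the Reactant array returned by the
# compiled thunk, not a native Array.
```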
```julia
function DI.value_and_gradient(
        f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
    ) where {F, C}
    DI.check_prep(f, prep, rebackend, x)
    backend = rebackend.mode
    (; xr, compiled_value_and_gradient) = prep
    copyto!(xr, x)
    contextsr = map(to_reac, contexts)
    yr, gr = compiled_value_and_gradient(f, backend, xr, contextsr...)
    return yr, gr
end

function DI.gradient!(
        f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
    ) where {F, C}
    DI.check_prep(f, prep, rebackend, x)
    backend = rebackend.mode
    (; xr, gr, compiled_gradient!) = prep
    copyto!(xr, x)
    contextsr = map(to_reac, contexts)
    compiled_gradient!(f, gr, backend, xr, contextsr...)
    return copyto!(grad, gr)
end

function DI.value_and_gradient!(
        f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
    ) where {F, C}
    DI.check_prep(f, prep, rebackend, x)
    backend = rebackend.mode
    (; xr, gr, compiled_value_and_gradient!) = prep
    copyto!(xr, x)
    contextsr = map(to_reac, contexts)
    yr, gr = compiled_value_and_gradient!(f, gr, backend, xr, contextsr...)
    return yr, copyto!(grad, gr)
end
```
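The in-place variants follow the same round-trip: copy `x` into the preallocated `xr`, execute the compiled thunk against the prep's `gr` buffer, then copy back into the caller's `grad`. A usage sketch (placeholder names):

```julia
grad = similar(x)
y, _ = DI.value_and_gradient!(f, grad, prep, backend, x)
# `grad` now holds the gradient; `y` is the (Reactant) function value.
```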
**New file: `utils.jl` (+7 lines)**

```julia
to_reac(x::AbstractArray) = to_rarray(x)
to_reac(x::ConcreteRArray) = x
to_reac(x::Number) = ConcreteRNumber(x)
to_reac(x::ConcreteRNumber) = x

to_reac(c::DI.Constant) = DI.Constant(to_reac(DI.unwrap(c)))
to_reac(c::DI.Cache) = DI.Cache(to_reac(DI.unwrap(c)))
```
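For context, a minimal sketch of what these internal conversions do to DI context wrappers (`to_reac` is private to the extension; the values here are made up):

```julia
c = DI.Constant(2.0f0)
cr = to_reac(c)   # DI.Constant wrapping a ConcreteRNumber

A = rand(Float32, 4)
Ar = to_reac(A)   # ConcreteRArray; already-converted inputs pass through unchanged
```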
**New file: test script (+21 lines)**

```julia
using Pkg
Pkg.add(url = "https://github.com/EnzymeAD/Enzyme.jl")
Pkg.add("Reactant")

using DifferentiationInterface
using DifferentiationInterfaceTest
using Reactant
using Test

backend = AutoReactant()

@test check_available(backend)
@test check_inplace(backend)

test_differentiation(
    backend, DifferentiationInterfaceTest.default_scenarios(;
        include_constantified = true, include_cachified = false
    );
    excluded = vcat(SECOND_ORDER, :jacobian, :derivative, :pushforward, :pullback),
    logging = false
)
```

> **Reviewer:** Can you add a test that the prep contains no data except the compiled function, if it was compiled for a Reactant array?
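A sketch of one way the requested test could start, under the interpretation that a prep built from an already-Reactant input should alias that input rather than copy it (this is an assumption about the intent, not code from the diff):

```julia
@testset "prep aliases Reactant inputs" begin
    f(x) = sum(abs2, x)
    xr = Reactant.to_rarray(rand(Float32, 4))
    prep = prepare_gradient(f, backend, xr)
    # to_reac is the identity on ConcreteRArray, so no copy should be made.
    @test prep.xr === xr
end
```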
> **Reviewer:** Can we remove the word "experimental" and just say this is an incomplete implementation? For example, it's better supported here than Diffractor, which is listed above as fully supported.
>
> **Author:** Diffractor is listed in the README as broken, but sure.
>
> **Author:** Another argument for the "experimental" label is that it protects us if we decide we made an API mistake, because breaking an "experimental" feature is arguably OK within SemVer.