Skip to content

Commit 6e798c2

Browse files
Merge pull request #29 from wsmoses/er
Migrate to Enzyme easy_rule
2 parents c9a79db + 3bfd459 commit 6e798c2

File tree

2 files changed

+5
-49
lines changed

2 files changed

+5
-49
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ FastPowerReverseDiffExt = "ReverseDiff"
2424
FastPowerTrackerExt = "Tracker"
2525

2626
[compat]
27-
Enzyme = "0.13"
27+
Enzyme = "0.13.89"
2828
ForwardDiff = "0.10, 1"
2929
Measurements = "2"
3030
MonteCarloMeasurements = "1"

ext/FastPowerEnzymeExt.jl

Lines changed: 4 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5,53 +5,9 @@ import FastPower: fastpower
55
using Enzyme
66
using Enzyme.EnzymeRules: FwdConfig
77

8-
function Enzyme.EnzymeRules.forward(config::FwdConfig,
9-
func::Const{typeof(FastPower.fastpower)},
10-
RT::Type{<:Union{Duplicated, DuplicatedNoNeed}},
11-
_x::Union{Const, Duplicated}, _y::Union{Const, Duplicated})
12-
x = _x.val
13-
y = _y.val
14-
ret = func.val(x, y)
15-
T = typeof(ret)
16-
if !(_x isa Const)
17-
dxval = _x.dval * y * (fastpower(x, y - 1))
18-
else
19-
dxval = make_zero(_x.val)
20-
end
21-
if !(_y isa Const)
22-
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) :
23-
_y.dval*(fastpower(x, y))*log(x)
24-
else
25-
dyval = make_zero(_y.val)
26-
end
27-
if RT <: DuplicatedNoNeed
28-
return convert(T, dxval + dyval)
29-
else
30-
return Duplicated(ret, convert(T, dxval + dyval))
31-
end
32-
end
33-
34-
function EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.RevConfigWidth{1},
35-
func::Const{typeof(fastpower)}, ::Union{Type{<:Active}, Type{<:Const}},
36-
x::Union{Const, Active}, y::Union{Const, Active})
37-
if EnzymeRules.needs_primal(config)
38-
primal = func.val(x.val, y.val)
39-
else
40-
primal = nothing
41-
end
42-
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
43-
end
44-
45-
function EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1},
46-
func::Const{typeof(fastpower)}, dret, tape, _x::Union{Const, Active}, _y::Union{
47-
Const, Active})
48-
x = _x.val
49-
y = _y.val
50-
dxval = _x isa Const ? nothing : dret.val * y * (fastpower(x, y - 1))
51-
dyval = _y isa Const ? nothing :
52-
(x isa Real && x<=0 ? Base.oftype(float(x), NaN) :
53-
dret.val * (fastpower(x, y)) * log(x))
54-
return (dxval, dyval)
55-
end
8+
Enzyme.EnzymeRules.@easy_rule(
9+
FastPower.fastpower(x, y),
10+
( y * fastpower(x, y - 1), Ω * log(x) )
11+
)
5612

5713
end

0 commit comments

Comments
 (0)