Skip to content

Commit 5f43d01

Browse files
refactor: don't use NaNMath.pow in codegen rewriters if integral exponent
1 parent da3bd6d commit 5f43d01

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

src/code.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,13 @@ function function_to_expr(op::typeof(^), O, st)
146146
return toexpr(Term(inv, Any[ex]), st)
147147
else
148148
args = Any[Term(inv, Any[ex]), -args[2]]
149-
op = get(st.rewrites, :nanmath, false) ? op : NaNMath.pow
149+
op = get(st.rewrites, :nanmath, false) || args[2] isa Integer ? op : NaNMath.pow
150150
return toexpr(Term(op, args), st)
151151
end
152152
end
153-
get(st.rewrites, :nanmath, false) === true || return nothing
153+
if !get(st.rewrites, :nanmath, false) || args[2] isa Integer
154+
return nothing
155+
end
154156
return toexpr(Term(NaNMath.pow, args), st)
155157
end
156158

test/code.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,19 +100,19 @@ nanmath_st.rewrites[:nanmath] = true
100100
@test toexpr(NaNMath.pow(a, b), nanmath_st) == :($(NaNMath.pow)(a, b))
101101

102102
@test toexpr(a^2) == :($(^)(a, 2))
103-
@test toexpr(a^2, nanmath_st) == :($(NaNMath.pow)(a, 2))
103+
@test toexpr(a^2, nanmath_st) == :($(^)(a, 2))
104104
@test toexpr(NaNMath.pow(a, 2)) == :($(^)(a, 2))
105-
@test toexpr(NaNMath.pow(a, 2), nanmath_st) == :($(NaNMath.pow)(a, 2))
105+
@test toexpr(NaNMath.pow(a, 2), nanmath_st) == :($(^)(a, 2))
106106

107107
@test toexpr(a^-1) == :($(/)(1, a))
108108
@test toexpr(a^-1, nanmath_st) == :($(/)(1, a))
109109
@test toexpr(NaNMath.pow(a, -1)) == :($(inv)(a))
110110
@test toexpr(NaNMath.pow(a, -1), nanmath_st) == :($(inv)(a))
111111

112112
@test toexpr(a^-2) == :($(/)(1, $(^)(a, 2)))
113-
@test toexpr(a^-2, nanmath_st) == :($(/)(1, $(NaNMath.pow)(a, 2)))
114-
@test toexpr(NaNMath.pow(a, -2)) == :($(NaNMath.pow)($(inv)(a), 2))
115-
@test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(NaNMath.pow)($(inv)(a), 2))
113+
@test toexpr(a^-2, nanmath_st) == :($(/)(1, $(^)(a, 2)))
114+
@test toexpr(NaNMath.pow(a, -2)) == :($(^)($(inv)(a), 2))
115+
@test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(^)($(inv)(a), 2))
116116

117117
f = GlobalRef(NaNMath, :sin)
118118
test_repr(toexpr(LiteralExpr(:(let x=1, y=2

0 commit comments

Comments
 (0)