diff --git a/src/code.jl b/src/code.jl index 8b1953b20..774c26931 100644 --- a/src/code.jl +++ b/src/code.jl @@ -140,18 +140,18 @@ end function function_to_expr(op::typeof(^), O, st) args = arguments(O) - if length(args) == 2 && args[2] isa Real && args[2] < 0 - ex = args[1] - if args[2] == -1 - return toexpr(Term(inv, Any[ex]), st) - else - args = Any[Term(inv, Any[ex]), -args[2]] - op = get(st.rewrites, :nanmath, false) ? op : NaNMath.pow - return toexpr(Term(op, args), st) - end + if args[2] isa Real && args[2] < 0 + args[1] = Term(inv, Any[args[1]]) + args[2] = -args[2] + end + if isequal(args[2], 1) + return toexpr(args[1], st) + end + if get(st.rewrites, :nanmath, false) === true && !(args[2] isa Integer) + op = NaNMath.pow + return toexpr(Term(op, args), st) end - get(st.rewrites, :nanmath, false) === true || return nothing - return toexpr(Term(NaNMath.pow, args), st) + return nothing end function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st) diff --git a/test/code.jl b/test/code.jl index 0e25437a1..918ef9dcc 100644 --- a/test/code.jl +++ b/test/code.jl @@ -100,19 +100,19 @@ nanmath_st.rewrites[:nanmath] = true @test toexpr(NaNMath.pow(a, b), nanmath_st) == :($(NaNMath.pow)(a, b)) @test toexpr(a^2) == :($(^)(a, 2)) - @test toexpr(a^2, nanmath_st) == :($(NaNMath.pow)(a, 2)) - @test toexpr(NaNMath.pow(a, 2)) == :($(NaNMath.pow)(a, 2)) - @test toexpr(NaNMath.pow(a, 2), nanmath_st) == :($(NaNMath.pow)(a, 2)) + @test toexpr(a^2, nanmath_st) == :($(^)(a, 2)) + @test toexpr(NaNMath.pow(a, 2)) == :($(^)(a, 2)) + @test toexpr(NaNMath.pow(a, 2), nanmath_st) == :($(^)(a, 2)) @test toexpr(a^-1) == :($(/)(1, a)) @test toexpr(a^-1, nanmath_st) == :($(/)(1, a)) - @test toexpr(NaNMath.pow(a, -1)) == :($(NaNMath.pow)(a, -1)) - @test toexpr(NaNMath.pow(a, -1), nanmath_st) == :($(NaNMath.pow)(a, -1)) + @test toexpr(NaNMath.pow(a, -1)) == :($(inv)(a)) + @test toexpr(NaNMath.pow(a, -1), nanmath_st) == :($(inv)(a)) @test toexpr(a^-2) == :($(/)(1, $(^)(a, 2))) - @test toexpr(a^-2, nanmath_st) == :($(/)(1, $(NaNMath.pow)(a, 2))) - @test toexpr(NaNMath.pow(a, -2)) == :($(NaNMath.pow)(a, -2)) - @test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(NaNMath.pow)(a, -2)) + @test toexpr(a^-2, nanmath_st) == :($(/)(1, $(^)(a, 2))) + @test toexpr(NaNMath.pow(a, -2)) == :($(^)($(inv)(a), 2)) + @test toexpr(NaNMath.pow(a, -2), nanmath_st) == :($(^)($(inv)(a), 2)) f = GlobalRef(NaNMath, :sin) test_repr(toexpr(LiteralExpr(:(let x=1, y=2