diff --git a/src/types.jl b/src/types.jl index 683f58d44..97d97b5a7 100644 --- a/src/types.jl +++ b/src/types.jl @@ -102,7 +102,7 @@ end """ $(SIGNATURES) -Returns the [numeric type](https://docs.julialang.org/en/v1/base/numbers/#Standard-Numeric-Types) +Returns the [numeric type](https://docs.julialang.org/en/v1/base/numbers/#Standard-Numeric-Types) of `x`. By default this is just `typeof(x)`. Define this for your symbolic types if you want [`SymbolicUtils.simplify`](@ref) to apply rules specific to numbers (such as commutativity of multiplication). Or such @@ -561,9 +561,9 @@ function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata) st = symtype(T) pst = _promote_symtype(head, args) # Use promoted symtype only if not a subtype of the existing symtype of T. - # This is useful when calling `maketerm(BasicSymbolic{Number}, (==), [true, false])` - # Where the result would have a symtype of Bool. - # Please see discussion in https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/609 + # This is useful when calling `maketerm(BasicSymbolic{Number}, (==), [true, false])` + # Where the result would have a symtype of Bool. + # Please see discussion in https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/609 # TODO this should be optimized. new_st = if st <: AbstractArray st @@ -816,27 +816,37 @@ function show_ref(io, f, args) print(io, "]") end +import Base.nameof +# To fall through the `nameof` in the `show_call` below +Base.nameof(f, arg, args...) = nameof(f) + +""" + show_call(io, f, args) +Displays the function call with given args. There are different outputs if `f` +is unary, binary or otherwise. `f`'s output can also be decorated using +`Base.nameof` provided with the function as well as with the `symtype` +of `f`'s arguments. +""" function show_call(io, f, args) - fname = iscall(f) ? Symbol(repr(f)) : nameof(f) + fname = nameof(f, symtype.(args)...) + frep = Symbol(repr(f)) + len_args = length(args) - if Base.isunaryoperator(fname) && len_args == 1 + + if Base.isunaryoperator(frep) && len_args == 1 print(io, "$fname") print_arg(io, first(args), paren=true) - elseif Base.isbinaryoperator(fname) && len_args > 1 + elseif Base.isbinaryoperator(frep) && len_args > 1 for (i, t) in enumerate(args) i != 1 && print(io, " $fname ") print_arg(io, t, paren=true) end else - if issym(f) - Base.show_unquoted(io, nameof(f)) - else - Base.show(io, f) - end + print(io, "$fname") print(io, "(") - for i=1:length(args) + for i=1:len_args print(io, args[i]) - i != length(args) && print(io, ", ") + i != len_args && print(io, ", ") end print(io, ")") end diff --git a/test/basics.jl b/test/basics.jl index 1402f9aca..23945115a 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -3,12 +3,13 @@ using SymbolicUtils using IfElse: ifelse using Setfield using Test, ReferenceTests +import Base.nameof include("utils.jl") @testset "@syms" begin let - @syms a b::Float64 f(::Real) g(p, h(q::Real))::Int + @syms a b::Float64 f(::Real) g(p, h(q::Real))::Int @test issym(a) && symtype(a) == Number @test a.name === :a @@ -235,6 +236,38 @@ end @test_reference "inspect_output/sub14.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 14))) end +let + +sq(x) = return SymbolicUtils.Term{Number}(sq, [x]) + +function Base.nameof(::typeof(sq), arg) + if arg <: Real + return :sqrt_R + elseif arg <: Complex + return :sqrt_C + else + return :sqrt + end +end + +@testset "call printing" begin + get_print(sym) = begin b = IOBuffer(); print(b, sym); String(take!(b)); end + + x,y,z = @syms x::Real y::Complex z + @syms e() f(x) g(x,y) h(x,y,z) + + @test get_print(e()) == "e()" + @test get_print(f(x)) == "f(x)" + @test get_print(g(x,y)) == "g(x, y)" + @test get_print(h(x,y,z)) == "h(x, y, z)" + + @test get_print(sq(x)) == "sqrt_R(x)" + @test get_print(sq(y)) == "sqrt_C(y)" + @test get_print(sq(z)) == "sqrt(z)" +end + +end + @testset "maketerm" begin @syms a b c @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).dict, Dict(a=>1,b=>1,c=>1)) @@ -249,7 +282,7 @@ end # test that maketerm sets metadata correctly metadata = Base.ImmutableDict{DataType, Any}(Ctx1, "meta_1") metadata2 = Base.ImmutableDict{DataType, Any}(Ctx2, "meta_2") - + d = b * c @set! d.metadata = metadata2 @@ -277,12 +310,12 @@ end @test symtype(new_expr) == Bool # Doesn't know return type, promoted symtype is Any - foo(x,y) = x^2 + x + foo(x,y) = x^2 + x new_expr = SymbolicUtils.maketerm(typeof(ref_expr), foo, [a, b], nothing) @test symtype(new_expr) == Number # Promoted symtype is a subtype of referred - @syms x::Int y::Int + @syms x::Int y::Int new_expr = SymbolicUtils.maketerm(typeof(ref_expr), (+), [x, y], nothing) @test symtype(new_expr) == Int64 @@ -384,5 +417,5 @@ end ax = adjoint(x) @test isequal(ax, x) @test ax === x - @test isequal(adjoint(y), conj(y)) + @test isequal(adjoint(y), conj(y)) end