diff --git a/src/code.jl b/src/code.jl index 38167c1d0..5c4053e71 100644 --- a/src/code.jl +++ b/src/code.jl @@ -717,6 +717,9 @@ function topological_sort(graph) visited = IdDict() function dfs(node) + if node isa BasicSymbolic + node = node.expr + end if haskey(visited, node) return visited[node] end diff --git a/src/types.jl b/src/types.jl index e66425097..b0e59ca2b 100644 --- a/src/types.jl +++ b/src/types.jl @@ -24,19 +24,17 @@ const EMPTY_DICT_T = typeof(EMPTY_DICT) const ENABLE_HASHCONSING = Ref(true) @compactify show_methods=false begin - @abstract mutable struct BasicSymbolic{T} <: Symbolic{T} - metadata::Metadata = NO_METADATA - end - mutable struct Sym{T} <: BasicSymbolic{T} + @abstract mutable struct BasicSymbolicImpl{T} end + mutable struct Sym{T} <: BasicSymbolicImpl{T} name::Symbol = :OOF end - mutable struct Term{T} <: BasicSymbolic{T} + mutable struct Term{T} <: BasicSymbolicImpl{T} f::Any = identity # base/num if Pow; issorted if Add/Dict arguments::Vector{Any} = EMPTY_ARGS hash::RefValue{UInt} = EMPTY_HASH hash2::RefValue{UInt} = EMPTY_HASH end - mutable struct Mul{T} <: BasicSymbolic{T} + mutable struct Mul{T} <: BasicSymbolicImpl{T} coeff::Any = 0 # exp/den if Pow dict::EMPTY_DICT_T = EMPTY_DICT hash::RefValue{UInt} = EMPTY_HASH @@ -44,7 +42,7 @@ const ENABLE_HASHCONSING = Ref(true) arguments::Vector{Any} = EMPTY_ARGS issorted::RefValue{Bool} = NOT_SORTED end - mutable struct Add{T} <: BasicSymbolic{T} + mutable struct Add{T} <: BasicSymbolicImpl{T} coeff::Any = 0 # exp/den if Pow dict::EMPTY_DICT_T = EMPTY_DICT hash::RefValue{UInt} = EMPTY_HASH @@ -52,25 +50,45 @@ const ENABLE_HASHCONSING = Ref(true) arguments::Vector{Any} = EMPTY_ARGS issorted::RefValue{Bool} = NOT_SORTED end - mutable struct Div{T} <: BasicSymbolic{T} + mutable struct Div{T} <: BasicSymbolicImpl{T} num::Any = 1 den::Any = 1 simplified::Bool = false arguments::Vector{Any} = EMPTY_ARGS end - mutable struct Pow{T} <: BasicSymbolic{T} + mutable struct Pow{T} <: BasicSymbolicImpl{T} base::Any = 1 exp::Any = 1 arguments::Vector{Any} = EMPTY_ARGS end end +struct MetadataImpl + this::Metadata + children::Vector{Any} +end + +function MetadataImpl() + MetadataImpl(nothing, Vector()) +end + +@kwdef struct BasicSymbolic{T} <: Symbolic{T} + expr::BasicSymbolicImpl{T} + meta::MetadataImpl +end + +getmetaimpl(x::BasicSymbolic) = x.meta +getmetaimpl(::Any) = nothing + function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic}) ScalarSymbolic() end function exprtype(x::BasicSymbolic) - @compactified x::BasicSymbolic begin + exprtype(x.expr) +end +function exprtype(expr::BasicSymbolicImpl) + @compactified expr::BasicSymbolicImpl begin Term => TERM Add => ADD Mul => MUL @@ -81,7 +99,33 @@ function exprtype(x::BasicSymbolic) end end -const wvd = WeakValueDict{UInt, BasicSymbolic}() +function Base.getproperty(x::BasicSymbolic, sym::Symbol) + if sym === :metadata + return getfield(x, :meta).this + elseif sym === :expr || sym === :meta + return getfield(x, sym) + elseif sym === :base || sym === :num + bsi = getproperty(getfield(x, :expr), sym) + if bsi isa BasicSymbolicImpl + mdi = getfield(x, :meta).children[1] + return BasicSymbolic(bsi, mdi) + else + return bsi + end + elseif sym === :exp || sym === :den + bsi = getproperty(getfield(x, :expr), sym) + if bsi isa BasicSymbolicImpl + mdi = getfield(x, :meta).children[2] + return BasicSymbolic(bsi, mdi) + else + return bsi + end + else + return getproperty(getfield(x, :expr), sym) + end +end + +const wvd = WeakValueDict{UInt, BasicSymbolicImpl}() # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @@ -96,10 +140,11 @@ const SIMPLIFIED = 0x01 << 0 #@inline issimplified(x::BasicSymbolic) = is_of_type(x, SIMPLIFIED) function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple)::BasicSymbolic{T} where T - nt = getproperties(obj) - nt_new = merge(nt, patch) + expr = obj.expr + nt = getproperties(expr) + nt_new = merge(nt, (metadata = obj.metadata,), patch) # Call outer constructor because hash consing cannot be applied in inner constructor - @compactified obj::BasicSymbolic begin + @compactified expr::BasicSymbolicImpl begin Sym => Sym{T}(nt_new.name; nt_new...) Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0))) Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0))) @@ -128,9 +173,12 @@ symtype(x) = typeof(x) @inline symtype(::Type{<:Symbolic{T}}) where T = T # We're returning a function pointer -@inline function operation(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Term => x.f +function operation(x::BasicSymbolic) + operation(x.expr) +end +@inline function operation(expr::BasicSymbolicImpl) + @compactified expr::BasicSymbolicImpl begin + Term => expr.f Add => (+) Mul => (*) Div => (/) @@ -144,7 +192,7 @@ end function TermInterface.sorted_arguments(x::BasicSymbolic) args = arguments(x) - @compactified x::BasicSymbolic begin + @compactified x.expr::BasicSymbolicImpl begin Add => @goto ADD Mul => @goto MUL _ => return args @@ -169,7 +217,22 @@ end TermInterface.children(x::BasicSymbolic) = arguments(x) TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x) function TermInterface.arguments(x::BasicSymbolic) - @compactified x::BasicSymbolic begin + args = arguments(x.expr) + args_metadata = x.meta.children + res = Vector() + for (arg, meta) in zip(args, args_metadata) + if arg isa BasicSymbolicImpl + if isnothing(meta) + meta = MetadataImpl() + end + arg = BasicSymbolic(arg, meta) + end + push!(res, arg) + end + res +end +function TermInterface.arguments(x::BasicSymbolicImpl) + @compactified x::BasicSymbolicImpl begin Term => return x.arguments Add => @goto ADDMUL Mul => @goto ADDMUL @@ -216,10 +279,59 @@ function TermInterface.arguments(x::BasicSymbolic) return args end -isexpr(s::BasicSymbolic) = !issym(s) -iscall(s::BasicSymbolic) = isexpr(s) +""" +$(TYPEDSIGNATURES) + +For given `coeff` and `dict`, return arguments of type `BasicSymbolicImpl`, children's +metadata and a new dictionary with `BasicSymbolicImpl` as key's type for preparation of the +construction of either `Add` or `Mul`. +""" +function get_arguments_metadata(coeff, dict::AbstractDict, type::ExprType) + siz = length(dict) + idcoeff = type === ADD ? iszero(coeff) : isone(coeff) + args = Vector() + sizehint!(args, idcoeff ? siz : siz + 1) + idcoeff || push!(args, coeff) + if type === ADD + for (k, v) in dict + if k isa BasicSymbolicImpl + k = BasicSymbolic(k, MetadataImpl()) + end + push!(args, applicable(*, k, v) ? k * v : maketerm(k, *, [k, v], nothing)) + end + else # MUL + for (k, v) in dict + if k isa BasicSymbolicImpl + k = BasicSymbolic(k, MetadataImpl()) + end + push!(args, unstable_pow(k, v)) + end + end + metadata_children = map(getmetaimpl, args) + for i in 1:length(args) + if args[i] isa BasicSymbolic + args[i] = args[i].expr + end + end + keys = idcoeff ? args : @view args[2:end] + bsi_dict = Dict(zip(keys, values(dict))) + return args, metadata_children, bsi_dict +end + +isexpr(s::BasicSymbolic) = isexpr(s.expr) +isexpr(expr::BasicSymbolicImpl) = !issym(expr) +iscall(s::BasicSymbolic) = iscall(s.expr) +iscall(expr::BasicSymbolicImpl) = isexpr(expr) -@inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false +@inline function isa_SymType(T::Val{S}, x) where {S} + if x isa BasicSymbolic + Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolicImpl), T, x.expr) + elseif x isa BasicSymbolicImpl + Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolicImpl), T, x) + else + false + end +end """ issym(x) @@ -253,7 +365,10 @@ function _allarequal(xs, ys; comparator = isequal)::Bool return true end -function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S} +function Base.isequal(a::BasicSymbolic, b::BasicSymbolic) + isequal(a.expr, b.expr) +end +function Base.isequal(a::BasicSymbolicImpl{T}, b::BasicSymbolicImpl{S}) where {T,S} a === b && return true E = exprtype(a) @@ -262,6 +377,12 @@ function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S} T === S || return false return _isequal(a, b, E)::Bool end +function Base.isequal(a::MetadataImpl, b::MetadataImpl) + (a === b) || + (isequal_with_metadata(a.this, b.this) && + isequal_with_metadata(a.children, b.children)) +end + function _isequal(a, b, E; comparator = isequal) if E === SYM nameof(a) === nameof(b) @@ -303,7 +424,7 @@ function isequal_with_metadata(a::BasicSymbolic{T}, b::BasicSymbolic{S})::Bool w E === exprtype(b) || return false T === S || return false - _isequal(a, b, E; comparator = isequal_with_metadata)::Bool && isequal_with_metadata(metadata(a), metadata(b)) || return false + _isequal(a, b, E; comparator = isequal_with_metadata)::Bool && isequal(a.meta, b.meta) || return false end """ @@ -395,7 +516,14 @@ end Base.one( s::Symbolic) = one( symtype(s)) Base.zero(s::Symbolic) = zero(symtype(s)) -Base.nameof(s::BasicSymbolic) = issym(s) ? s.name : error("None Sym BasicSymbolic doesn't have a name") +Base.nameof(s::BasicSymbolic) = nameof(s.expr) +function Base.nameof(s::BasicSymbolicImpl) + if issym(s) + s.name + else + error("None Sym BasicSymbolic doesn't have a name") + end +end ## This is much faster than hash of an array of Any hashvec(xs, z) = foldr(hash, xs, init=z) @@ -457,8 +585,9 @@ hash2(s, salt::UInt) = hash(s, salt) function hash2(n::T, salt::UInt) where {T <: Number} hash(T, hash(n, salt)) end -hash2(s::BasicSymbolic) = hash2(s, zero(UInt)) -function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T} +hash2(s::BasicSymbolic) = hash2(s.expr, zero(UInt)) +hash2(s::BasicSymbolicImpl) = hash2(s, zero(UInt)) +function hash2(s::BasicSymbolicImpl{T}, salt::UInt)::UInt where {T} E = exprtype(s) h::UInt = 0 if E === SYM @@ -488,7 +617,7 @@ function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T} else error_on_type() end - h = hash(metadata(s), hash(T, h)) + h = hash(T, h) if hasproperty(s, :hash2) s.hash2[] = h end @@ -520,31 +649,40 @@ Custom functions `hash2` and `isequal_with_metadata` are used instead of `Base.h `Base.isequal` to accommodate metadata without disrupting existing tests reliant on the original behavior of those functions. """ -function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic +function BasicSymbolicImpl(s::BasicSymbolicImpl)::BasicSymbolicImpl if !ENABLE_HASHCONSING[] return s end h = hash2(s) t = get!(wvd, h, s) - if t === s || isequal_with_metadata(t, s) + if t === s || isequal(t, s) return t else return s end end -function Sym{T}(name::Symbol; kw...) where {T} +function Sym{T}(name::Symbol; metadata = NO_METADATA, kw...) where {T} s = Sym{T}(; name, kw...) - BasicSymbolic(s) + bsi = BasicSymbolicImpl(s) + mdi = MetadataImpl(metadata, Vector()) + BasicSymbolic(bsi, mdi) end -function Term{T}(f, args; kw...) where T +function Term{T}(f, args; metadata = NO_METADATA, kw...) where T if eltype(args) !== Any args = convert(Vector{Any}, args) end - + metadata_children = map(getmetaimpl, args) + for i in 1:length(args) + if args[i] isa BasicSymbolic + args[i] = args[i].expr + end + end s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...) - BasicSymbolic(s) + bsi = BasicSymbolicImpl(s) + mdi = MetadataImpl(metadata, metadata_children) + BasicSymbolic(bsi, mdi) end function Term(f, args; metadata=NO_METADATA) @@ -563,9 +701,11 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T return Mul(T, coeff, dict) end end - - s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...) - BasicSymbolic(s) + arguments, metadata_children, dict = get_arguments_metadata(coeff, dict, ADD) + s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments, issorted=RefValue(false), kw...) + bsi = BasicSymbolicImpl(s) + mdi = MetadataImpl(metadata, metadata_children) + BasicSymbolic(bsi, mdi) end function Mul(T, a, b; metadata=NO_METADATA, kw...) @@ -580,8 +720,11 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...) else coeff = a dict = b - s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...) - BasicSymbolic(s) + arguments, metadata_children, dict = get_arguments_metadata(coeff, dict, MUL) + s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments, issorted=RefValue(false), kw...) + bsi = BasicSymbolicImpl(s) + mdi = MetadataImpl(metadata, metadata_children) + BasicSymbolic(bsi, mdi) end end @@ -601,7 +744,7 @@ ratio(x::Rat,y::Rat) = x//y function maybe_intcoeff(x) if ismul(x) if x.coeff isa Rational && isone(x.coeff.den) - Mul{symtype(x)}(; coeff=x.coeff.num, dict=x.dict, x.metadata, arguments=[], issorted=RefValue(false)) + Mul(symtype(x), x.coeff.num, x.dict; x.metadata) else x end @@ -612,7 +755,7 @@ function maybe_intcoeff(x) end end -function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T} +function Div{T}(n, d, simplified=false; metadata=NO_METADATA, kwargs...) where {T} if T<:Number && !(T<:SafeReal) n, d = quick_cancel(n, d) end @@ -645,9 +788,17 @@ function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T} end end end - - s = Div{T}(; num=n, den=d, simplified, arguments=[], metadata) - BasicSymbolic(s) + metadata_children = [getmetaimpl(n), getmetaimpl(d)] + if n isa BasicSymbolic + n = n.expr + end + if d isa BasicSymbolic + d = d.expr + end + s = Div{T}(; num=n, den=d, simplified, arguments=[]) + bsi = BasicSymbolicImpl(s) + mdi = MetadataImpl(metadata, metadata_children) + BasicSymbolic(bsi, mdi) end function Div(n,d, simplified=false; kw...) @@ -663,9 +814,18 @@ end function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T} _iszero(b) && return 1 - _isone(b) && return a - s = Pow{T}(; base=a, exp=b, arguments=[], metadata) - BasicSymbolic(s) + _isone(b) && return a + metadata_children = [getmetaimpl(a), getmetaimpl(b)] + if a isa BasicSymbolic + a = a.expr + end + if b isa BasicSymbolic + b = b.expr + end + s = Pow{T}(; base=a, exp=b, arguments=[]) + bsi = BasicSymbolicImpl(s) + mdi = MetadataImpl(metadata, metadata_children) + BasicSymbolic(bsi, mdi) end function Pow(a, b; metadata = NO_METADATA, kwargs...) @@ -874,6 +1034,7 @@ _issafecanon(::typeof(^), s) = !iscall(s) || !(operation(s) in (*, ^)) issafecanon(f, ss...) = all(x->issafecanon(f, x), ss) +getmetadata(s) = metadata(s) function getmetadata(s::Symbolic, ctx) md = metadata(s) if md isa AbstractDict @@ -882,11 +1043,13 @@ function getmetadata(s::Symbolic, ctx) throw(ArgumentError("$s does not have metadata for $ctx")) end end - function getmetadata(s::Symbolic, ctx, default) md = metadata(s) md isa AbstractDict ? get(md, ctx, default) : default end +function getmetadata(d::AbstractDict, ctx) + d[ctx] +end # pirated for Setfield purposes: using Base: ImmutableDict @@ -1353,12 +1516,16 @@ function -(a::SN) end function -(a::SN, b::SN) - (!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b) - isadd(a) && isadd(b) ? Add(sub_t(a,b), - a.coeff - b.coeff, - _merge(-, a.dict, - b.dict, - filter=_iszero)) : a + (-b) + if !issafecanon(+, a) || !issafecanon(*, b) + return term(-, a, b) + elseif isadd(a) && isadd(b) + t = sub_t(a, b) + c = a.coeff - b.coeff + d = _merge(-, a.dict, b.dict, filter = _iszero) + return Add(t, c, d) + else + return a + (-b) + end end -(a::Number, b::SN) = a + (-b) diff --git a/test/hash_consing.jl b/test/hash_consing.jl index 02b678b27..46a32bebb 100644 --- a/test/hash_consing.jl +++ b/test/hash_consing.jl @@ -7,23 +7,23 @@ struct Ctx2 end @testset "Sym" begin x1 = only(@syms x) x2 = only(@syms x) - @test x1 === x2 + @test x1.expr === x2.expr x3 = only(@syms x::Float64) - @test x1 !== x3 + @test x1.expr !== x3.expr x4 = only(@syms x::Float64) - @test x1 !== x4 - @test x3 === x4 + @test x1.expr !== x4.expr + @test x3.expr === x4.expr x5 = only(@syms x::Int) x6 = only(@syms x::Int) - @test x1 !== x5 - @test x3 !== x5 - @test x5 === x6 + @test x1.expr !== x5.expr + @test x3.expr !== x5.expr + @test x5.expr === x6.expr xm1 = setmetadata(x1, Ctx1, "meta_1") xm2 = setmetadata(x1, Ctx1, "meta_1") - @test xm1 === xm2 + @test xm1.expr === xm2.expr xm3 = setmetadata(x1, Ctx2, "meta_2") - @test xm1 !== xm3 + @test xm1.expr === xm3.expr end @syms a b c @@ -31,73 +31,73 @@ end @testset "Term" begin t1 = sin(a) t2 = sin(a) - @test t1 === t2 + @test t1.expr === t2.expr t3 = Term(identity,[a]) t4 = Term(identity,[a]) - @test t3 === t4 + @test t3.expr === t4.expr t5 = Term{Int}(identity,[a]) - @test t3 !== t5 + @test t3.expr !== t5.expr tm1 = setmetadata(t1, Ctx1, "meta_1") - @test t1 !== tm1 + @test t1.expr === tm1.expr end @testset "Add" begin d1 = a + b d2 = b + a - @test d1 === d2 + @test d1.expr === d2.expr d3 = b - 2 + a d4 = a + b - 2 - @test d3 === d4 + @test d3.expr === d4.expr d5 = Add(Int, 0, Dict(a => 1, b => 1)) - @test d5 !== d1 + @test d5.expr !== d1.expr dm1 = setmetadata(d1,Ctx1,"meta_1") - @test d1 !== dm1 + @test d1.expr === dm1.expr end @testset "Mul" begin m1 = a*b m2 = b*a - @test m1 === m2 + @test m1.expr === m2.expr m3 = 6*a*b m4 = 3*a*2*b - @test m3 === m4 + @test m3.expr === m4.expr m5 = Mul(Int, 1, Dict(a => 1, b => 1)) - @test m5 !== m1 + @test m5.expr !== m1.expr mm1 = setmetadata(m1, Ctx1, "meta_1") - @test m1 !== mm1 + @test m1.expr === mm1.expr end @testset "Div" begin v1 = a/b v2 = a/b - @test v1 === v2 + @test v1.expr === v2.expr v3 = -1/a v4 = -1/a - @test v3 === v4 + @test v3.expr === v4.expr v5 = 3a/6 v6 = 2a/4 - @test v5 === v6 + @test v5.expr === v6.expr v7 = Div{Float64}(-1,a) - @test v7 !== v3 + @test v7.expr !== v3.expr vm1 = setmetadata(v1,Ctx1, "meta_1") - @test vm1 !== v1 + @test vm1.expr === v1.expr end @testset "Pow" begin p1 = a^b p2 = a^b - @test p1 === p2 + @test p1.expr === p2.expr p3 = a^(2^-b) p4 = a^(2^-b) - @test p3 === p4 + @test p3.expr === p4.expr p5 = Pow{Float64}(a,b) - @test p1 !== p5 + @test p1.expr !== p5.expr pm1 = setmetadata(p1,Ctx1, "meta_1") - @test pm1 !== p1 + @test pm1.expr === p1.expr end @testset "Equivalent numbers" begin @@ -114,7 +114,7 @@ end a1 = setmetadata(a, Int, b) b1 = setmetadata(b, Int, 3) a2 = setmetadata(a, Int, b1) - @test a1 !== a2 + @test a1.expr === a2.expr @test !SymbolicUtils.isequal_with_metadata(a1, a2) @test metadata(metadata(a1)[Int]) === nothing @test metadata(metadata(a2)[Int])[Int] == 3 @@ -123,7 +123,7 @@ end @testset "Compare metadata of expression tree" begin @syms a b aa = setmetadata(a, Int, b) - @test aa !== a + @test aa.expr === a.expr @test isequal(a, aa) @test !SymbolicUtils.isequal_with_metadata(a, aa) @test !SymbolicUtils.isequal_with_metadata(2a, 2aa) @@ -144,6 +144,6 @@ end h = SymbolicUtils.hash2(ex) @test h == ex.hash2[] ex2 = setmetadata(ex, Int, 3) - @test ex2.hash2[] != h + @test ex2.hash2[] == h end end