From 2ab70b35aee531573ec5338e311f7a6d9d8ae323 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 13 Feb 2025 00:48:16 -0500 Subject: [PATCH 01/26] Create `BasicSymbolicImpl` struct to separate metadata from hash consing [skip ci] --- src/types.jl | 88 +++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 63 insertions(+), 25 deletions(-) diff --git a/src/types.jl b/src/types.jl index e66425097..6c7b610df 100644 --- a/src/types.jl +++ b/src/types.jl @@ -24,19 +24,17 @@ const EMPTY_DICT_T = typeof(EMPTY_DICT) const ENABLE_HASHCONSING = Ref(true) @compactify show_methods=false begin - @abstract mutable struct BasicSymbolic{T} <: Symbolic{T} - metadata::Metadata = NO_METADATA - end - mutable struct Sym{T} <: BasicSymbolic{T} + @abstract mutable struct BasicSymbolicImpl{T} end + mutable struct Sym{T} <: BasicSymbolicImpl{T} name::Symbol = :OOF end - mutable struct Term{T} <: BasicSymbolic{T} + mutable struct Term{T} <: BasicSymbolicImpl{T} f::Any = identity # base/num if Pow; issorted if Add/Dict arguments::Vector{Any} = EMPTY_ARGS hash::RefValue{UInt} = EMPTY_HASH hash2::RefValue{UInt} = EMPTY_HASH end - mutable struct Mul{T} <: BasicSymbolic{T} + mutable struct Mul{T} <: BasicSymbolicImpl{T} coeff::Any = 0 # exp/den if Pow dict::EMPTY_DICT_T = EMPTY_DICT hash::RefValue{UInt} = EMPTY_HASH @@ -44,7 +42,7 @@ const ENABLE_HASHCONSING = Ref(true) arguments::Vector{Any} = EMPTY_ARGS issorted::RefValue{Bool} = NOT_SORTED end - mutable struct Add{T} <: BasicSymbolic{T} + mutable struct Add{T} <: BasicSymbolicImpl{T} coeff::Any = 0 # exp/den if Pow dict::EMPTY_DICT_T = EMPTY_DICT hash::RefValue{UInt} = EMPTY_HASH @@ -52,25 +50,33 @@ const ENABLE_HASHCONSING = Ref(true) arguments::Vector{Any} = EMPTY_ARGS issorted::RefValue{Bool} = NOT_SORTED end - mutable struct Div{T} <: BasicSymbolic{T} + mutable struct Div{T} <: BasicSymbolicImpl{T} num::Any = 1 den::Any = 1 simplified::Bool = false arguments::Vector{Any} = EMPTY_ARGS end - mutable struct Pow{T} <: BasicSymbolic{T} + mutable struct Pow{T} <: BasicSymbolicImpl{T} base::Any = 1 exp::Any = 1 arguments::Vector{Any} = EMPTY_ARGS end end +@kwdef struct BasicSymbolic{T} <: Symbolic{T} + impl::BasicSymbolicImpl{T} + metadata::Metadata = NO_METADATA +end + function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic}) ScalarSymbolic() end function exprtype(x::BasicSymbolic) - @compactified x::BasicSymbolic begin + exprtype(x.impl) +end +function exprtype(impl::BasicSymbolicImpl) + @compactified impl::BasicSymbolicImpl begin Term => TERM Add => ADD Mul => MUL @@ -81,7 +87,15 @@ function exprtype(x::BasicSymbolic) end end -const wvd = WeakValueDict{UInt, BasicSymbolic}() +function Base.getproperty(x::BasicSymbolic, sym::Symbol) + if sym === :metadata || sym === :impl + return getfield(x, sym) + else + return getproperty(x.impl, sym) + end +end + +const wvd = WeakValueDict{UInt, BasicSymbolicImpl}() # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @@ -99,7 +113,7 @@ function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple nt = getproperties(obj) nt_new = merge(nt, patch) # Call outer constructor because hash consing cannot be applied in inner constructor - @compactified obj::BasicSymbolic begin + @compactified obj.impl::BasicSymbolicImpl begin Sym => Sym{T}(nt_new.name; nt_new...) Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0))) Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0))) @@ -128,9 +142,12 @@ symtype(x) = typeof(x) @inline symtype(::Type{<:Symbolic{T}}) where T = T # We're returning a function pointer -@inline function operation(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Term => x.f +function operation(x::BasicSymbolic) + operation(x.impl) +end +@inline function operation(impl::BasicSymbolicImpl) + @compactified impl::BasicSymbolicImpl begin + Term => impl.f Add => (+) Mul => (*) Div => (/) @@ -144,7 +161,7 @@ end function TermInterface.sorted_arguments(x::BasicSymbolic) args = arguments(x) - @compactified x::BasicSymbolic begin + @compactified x.impl::BasicSymbolicImpl begin Add => @goto ADD Mul => @goto MUL _ => return args @@ -169,7 +186,10 @@ end TermInterface.children(x::BasicSymbolic) = arguments(x) TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x) function TermInterface.arguments(x::BasicSymbolic) - @compactified x::BasicSymbolic begin + arguments(x.impl) +end +function TermInterface.arguments(x::BasicSymbolicImpl) + @compactified x::BasicSymbolicImpl begin Term => return x.arguments Add => @goto ADDMUL Mul => @goto ADDMUL @@ -219,7 +239,15 @@ end isexpr(s::BasicSymbolic) = !issym(s) iscall(s::BasicSymbolic) = isexpr(s) -@inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false +@inline function isa_SymType(T::Val{S}, x) where {S} + if x isa BasicSymbolic + Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolicImpl), T, x.impl) + elseif x isa BasicSymbolicImpl + Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolicImpl), T, x) + else + false + end +end """ issym(x) @@ -395,7 +423,14 @@ end Base.one( s::Symbolic) = one( symtype(s)) Base.zero(s::Symbolic) = zero(symtype(s)) -Base.nameof(s::BasicSymbolic) = issym(s) ? s.name : error("None Sym BasicSymbolic doesn't have a name") +Base.nameof(s::BasicSymbolic) = nameof(s.impl) +function Base.nameof(s::BasicSymbolicImpl) + if issym(s) + s.name + else + error("None Sym BasicSymbolic doesn't have a name") + end +end ## This is much faster than hash of an array of Any hashvec(xs, z) = foldr(hash, xs, init=z) @@ -458,7 +493,8 @@ function hash2(n::T, salt::UInt) where {T <: Number} hash(T, hash(n, salt)) end hash2(s::BasicSymbolic) = hash2(s, zero(UInt)) -function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T} +hash2(s::BasicSymbolicImpl) = hash2(s, zero(UInt)) +function hash2(s::BasicSymbolicImpl{T}, salt::UInt)::UInt where {T} E = exprtype(s) h::UInt = 0 if E === SYM @@ -520,7 +556,7 @@ Custom functions `hash2` and `isequal_with_metadata` are used instead of `Base.h `Base.isequal` to accommodate metadata without disrupting existing tests reliant on the original behavior of those functions. """ -function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic +function BasicSymbolicImpl(s::BasicSymbolicImpl)::BasicSymbolicImpl if !ENABLE_HASHCONSING[] return s end @@ -533,18 +569,20 @@ function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic end end -function Sym{T}(name::Symbol; kw...) where {T} +function Sym{T}(name::Symbol; metadata = NO_METADATA, kw...) where {T} s = Sym{T}(; name, kw...) - BasicSymbolic(s) + bsi = BasicSymbolicImpl(s) + BasicSymbolic(bsi, metadata) end -function Term{T}(f, args; kw...) where T +function Term{T}(f, args; metadata = NO_METADATA, kw...) where T if eltype(args) !== Any args = convert(Vector{Any}, args) end s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...) - BasicSymbolic(s) + bsi = BasicSymbolicImpl(s) + BasicSymbolic(bsi, metadata) end function Term(f, args; metadata=NO_METADATA) From 17408e8925ee0117f91897ba1e4f9c9982a61941 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 13 Feb 2025 12:29:03 -0500 Subject: [PATCH 02/26] Rename `BasicSymbolicImpl` field from `impl` to `expr` [skip ci] --- src/types.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/types.jl b/src/types.jl index 6c7b610df..82dbcd95b 100644 --- a/src/types.jl +++ b/src/types.jl @@ -64,7 +64,7 @@ const ENABLE_HASHCONSING = Ref(true) end @kwdef struct BasicSymbolic{T} <: Symbolic{T} - impl::BasicSymbolicImpl{T} + expr::BasicSymbolicImpl{T} metadata::Metadata = NO_METADATA end @@ -73,10 +73,10 @@ function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic}) end function exprtype(x::BasicSymbolic) - exprtype(x.impl) + exprtype(x.expr) end -function exprtype(impl::BasicSymbolicImpl) - @compactified impl::BasicSymbolicImpl begin +function exprtype(expr::BasicSymbolicImpl) + @compactified expr::BasicSymbolicImpl begin Term => TERM Add => ADD Mul => MUL @@ -88,10 +88,10 @@ function exprtype(impl::BasicSymbolicImpl) end function Base.getproperty(x::BasicSymbolic, sym::Symbol) - if sym === :metadata || sym === :impl + if sym === :metadata || sym === :expr return getfield(x, sym) else - return getproperty(x.impl, sym) + return getproperty(x.expr, sym) end end @@ -113,7 +113,7 @@ function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple nt = getproperties(obj) nt_new = merge(nt, patch) # Call outer constructor because hash consing cannot be applied in inner constructor - @compactified obj.impl::BasicSymbolicImpl begin + @compactified obj.expr::BasicSymbolicImpl begin Sym => Sym{T}(nt_new.name; nt_new...) Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0))) Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0))) @@ -143,11 +143,11 @@ symtype(x) = typeof(x) # We're returning a function pointer function operation(x::BasicSymbolic) - operation(x.impl) + operation(x.expr) end -@inline function operation(impl::BasicSymbolicImpl) - @compactified impl::BasicSymbolicImpl begin - Term => impl.f +@inline function operation(expr::BasicSymbolicImpl) + @compactified expr::BasicSymbolicImpl begin + Term => expr.f Add => (+) Mul => (*) Div => (/) @@ -161,7 +161,7 @@ end function TermInterface.sorted_arguments(x::BasicSymbolic) args = arguments(x) - @compactified x.impl::BasicSymbolicImpl begin + @compactified x.expr::BasicSymbolicImpl begin Add => @goto ADD Mul => @goto MUL _ => return args @@ -186,7 +186,7 @@ end TermInterface.children(x::BasicSymbolic) = arguments(x) TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x) function TermInterface.arguments(x::BasicSymbolic) - arguments(x.impl) + arguments(x.expr) end function TermInterface.arguments(x::BasicSymbolicImpl) @compactified x::BasicSymbolicImpl begin @@ -241,7 +241,7 @@ iscall(s::BasicSymbolic) = isexpr(s) @inline function isa_SymType(T::Val{S}, x) where {S} if x isa BasicSymbolic - Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolicImpl), T, x.impl) + Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolicImpl), T, x.expr) elseif x isa BasicSymbolicImpl Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolicImpl), T, x) else @@ -423,7 +423,7 @@ end Base.one( s::Symbolic) = one( symtype(s)) Base.zero(s::Symbolic) = zero(symtype(s)) -Base.nameof(s::BasicSymbolic) = nameof(s.impl) +Base.nameof(s::BasicSymbolic) = nameof(s.expr) function Base.nameof(s::BasicSymbolicImpl) if issym(s) s.name From 295a1dfa70bf94923402e214aa3fb51f6e783d46 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 13 Feb 2025 15:42:34 -0500 Subject: [PATCH 03/26] Adapt `BasicSymbolic` hash consing constructor --- src/types.jl | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/types.jl b/src/types.jl index 82dbcd95b..2ed771435 100644 --- a/src/types.jl +++ b/src/types.jl @@ -602,8 +602,9 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T end end - s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...) - BasicSymbolic(s) + s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments=[], issorted=RefValue(false), kw...) + bsi = BasicSymbolicImpl(s) + BasicSymbolic(bsi, metadata) end function Mul(T, a, b; metadata=NO_METADATA, kw...) @@ -618,8 +619,9 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...) else coeff = a dict = b - s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...) - BasicSymbolic(s) + s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments=[], issorted=RefValue(false), kw...) + bsi = BasicSymbolicImpl(s) + BasicSymbolic(bsi, metadata) end end @@ -650,7 +652,7 @@ function maybe_intcoeff(x) end end -function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T} +function Div{T}(n, d, simplified=false; metadata=NO_METADATA, kwargs...) where {T} if T<:Number && !(T<:SafeReal) n, d = quick_cancel(n, d) end @@ -684,8 +686,9 @@ function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T} end end - s = Div{T}(; num=n, den=d, simplified, arguments=[], metadata) - BasicSymbolic(s) + s = Div{T}(; num=n, den=d, simplified, arguments=[]) + bsi = BasicSymbolicImpl(s) + BasicSymbolic(bsi, metadata) end function Div(n,d, simplified=false; kw...) @@ -702,8 +705,9 @@ end function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T} _iszero(b) && return 1 _isone(b) && return a - s = Pow{T}(; base=a, exp=b, arguments=[], metadata) - BasicSymbolic(s) + s = Pow{T}(; base=a, exp=b, arguments=[]) + bsi = BasicSymbolicImpl(s) + BasicSymbolic(bsi, metadata) end function Pow(a, b; metadata = NO_METADATA, kwargs...) From 1fe09e5ca0752a279b9624be018560e91f9495c8 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 13 Feb 2025 15:59:40 -0500 Subject: [PATCH 04/26] Adapt `setproperties` with new `BasicSymbolic` struct --- src/types.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/types.jl b/src/types.jl index 2ed771435..6eb0f7b4c 100644 --- a/src/types.jl +++ b/src/types.jl @@ -110,10 +110,11 @@ const SIMPLIFIED = 0x01 << 0 #@inline issimplified(x::BasicSymbolic) = is_of_type(x, SIMPLIFIED) function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple)::BasicSymbolic{T} where T - nt = getproperties(obj) - nt_new = merge(nt, patch) + expr = obj.expr + nt = getproperties(expr) + nt_new = merge(nt, (metadata = obj.metadata,), patch) # Call outer constructor because hash consing cannot be applied in inner constructor - @compactified obj.expr::BasicSymbolicImpl begin + @compactified expr::BasicSymbolicImpl begin Sym => Sym{T}(nt_new.name; nt_new...) Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0))) Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0))) From b4932e65ec13dc71e45c6b7e02beec86135fe3ad Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 13 Feb 2025 16:00:14 -0500 Subject: [PATCH 05/26] Call custom constructor of `Mul` in `maybe_intcoeff` --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 6eb0f7b4c..a742b373b 100644 --- a/src/types.jl +++ b/src/types.jl @@ -642,7 +642,7 @@ ratio(x::Rat,y::Rat) = x//y function maybe_intcoeff(x) if ismul(x) if x.coeff isa Rational && isone(x.coeff.den) - Mul{symtype(x)}(; coeff=x.coeff.num, dict=x.dict, x.metadata, arguments=[], issorted=RefValue(false)) + Mul(symtype(x), x.coeff.num, x.dict; x.metadata) else x end From 9da1a9e4bb6d864d686653fef6287a9b406331b2 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 13 Feb 2025 16:25:57 -0500 Subject: [PATCH 06/26] Fix: `hash2` of `BasicSymbolic` equals its `BasicSymbolicImpl` --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index a742b373b..900dc3e7f 100644 --- a/src/types.jl +++ b/src/types.jl @@ -493,7 +493,7 @@ hash2(s, salt::UInt) = hash(s, salt) function hash2(n::T, salt::UInt) where {T <: Number} hash(T, hash(n, salt)) end -hash2(s::BasicSymbolic) = hash2(s, zero(UInt)) +hash2(s::BasicSymbolic) = hash2(s.expr, zero(UInt)) hash2(s::BasicSymbolicImpl) = hash2(s, zero(UInt)) function hash2(s::BasicSymbolicImpl{T}, salt::UInt)::UInt where {T} E = exprtype(s) From 56c990cc519e51b71de4ac4533b462919a26528f Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 13 Feb 2025 17:40:11 -0500 Subject: [PATCH 07/26] Operate CSE `topological_sort` `dfs` on `BasicSymbolicImpl` --- src/code.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/code.jl b/src/code.jl index 38167c1d0..5c4053e71 100644 --- a/src/code.jl +++ b/src/code.jl @@ -717,6 +717,9 @@ function topological_sort(graph) visited = IdDict() function dfs(node) + if node isa BasicSymbolic + node = node.expr + end if haskey(visited, node) return visited[node] end From 4fadeab4c8a5e708a11f76e22436fb8a30114d69 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 13 Feb 2025 17:40:38 -0500 Subject: [PATCH 08/26] Add `isexpr` & `iscall` methods for `BasicSymbolicImpl` --- src/types.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/types.jl b/src/types.jl index 900dc3e7f..3a844ed5a 100644 --- a/src/types.jl +++ b/src/types.jl @@ -237,8 +237,10 @@ function TermInterface.arguments(x::BasicSymbolicImpl) return args end -isexpr(s::BasicSymbolic) = !issym(s) -iscall(s::BasicSymbolic) = isexpr(s) +isexpr(s::BasicSymbolic) = isexpr(s.expr) +isexpr(expr::BasicSymbolicImpl) = !issym(expr) +iscall(s::BasicSymbolic) = iscall(s.expr) +iscall(expr::BasicSymbolicImpl) = isexpr(expr) @inline function isa_SymType(T::Val{S}, x) where {S} if x isa BasicSymbolic From d7737dd819217942a619fc9105881448411a2ddc Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 13 Feb 2025 17:42:51 -0500 Subject: [PATCH 09/26] Remove `metadata` from `hash2` [skip ci] --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 3a844ed5a..50c2ffd8d 100644 --- a/src/types.jl +++ b/src/types.jl @@ -527,7 +527,7 @@ function hash2(s::BasicSymbolicImpl{T}, salt::UInt)::UInt where {T} else error_on_type() end - h = hash(metadata(s), hash(T, h)) + h = hash(T, h) if hasproperty(s, :hash2) s.hash2[] = h end From 65cc88f6ec035ba0607708b8057175e7a28271b3 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 14 Feb 2025 13:25:47 -0500 Subject: [PATCH 10/26] Change `Base.isequal` for `BasicSymbolicImpl` --- src/types.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 50c2ffd8d..2a936d2aa 100644 --- a/src/types.jl +++ b/src/types.jl @@ -284,7 +284,10 @@ function _allarequal(xs, ys; comparator = isequal)::Bool return true end -function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S} +function Base.isequal(a::BasicSymbolic, b::BasicSymbolic) + isequal(a.expr, b.expr) +end +function Base.isequal(a::BasicSymbolicImpl{T}, b::BasicSymbolicImpl{S}) where {T,S} a === b && return true E = exprtype(a) From df49641860ba4d25154d6edf32f790ced6691547 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 14 Feb 2025 13:32:07 -0500 Subject: [PATCH 11/26] Modify flyweight factory for `BasicSymbolicImpl` [skip ci] --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 2a936d2aa..71d48b2c3 100644 --- a/src/types.jl +++ b/src/types.jl @@ -568,7 +568,7 @@ function BasicSymbolicImpl(s::BasicSymbolicImpl)::BasicSymbolicImpl end h = hash2(s) t = get!(wvd, h, s) - if t === s || isequal_with_metadata(t, s) + if t === s || isequal(t, s) return t else return s From db30dc8f0f28a2c4891dcaa693f5e9bfd4ce3561 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 16 Feb 2025 23:36:33 -0500 Subject: [PATCH 12/26] Create `MetadataImpl` struct to keep track of metadata tree Co-authored-by: Aayush Sabharwal --- src/types.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 71d48b2c3..9b94c1c0f 100644 --- a/src/types.jl +++ b/src/types.jl @@ -63,9 +63,14 @@ const ENABLE_HASHCONSING = Ref(true) end end +struct MetadataImpl + this::Metadata + children::Vector{Any} +end + @kwdef struct BasicSymbolic{T} <: Symbolic{T} expr::BasicSymbolicImpl{T} - metadata::Metadata = NO_METADATA + meta::MetadataImpl end function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic}) From 580ad334b04b1efcb66c6d5b707f8c46ece3a186 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 16 Feb 2025 23:37:53 -0500 Subject: [PATCH 13/26] Modify `getproperty(::BasicSymbolic)` for metadata --- src/types.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 9b94c1c0f..4251318be 100644 --- a/src/types.jl +++ b/src/types.jl @@ -93,7 +93,9 @@ function exprtype(expr::BasicSymbolicImpl) end function Base.getproperty(x::BasicSymbolic, sym::Symbol) - if sym === :metadata || sym === :expr + if sym === :metadata + return getfield(x, :meta).this + elseif sym === :expr || sym === :meta return getfield(x, sym) else return getproperty(x.expr, sym) From 831f8fc1943d6a4c91133d50a56d6009e9c78dbb Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 16 Feb 2025 23:38:18 -0500 Subject: [PATCH 14/26] Add `isequal` method for `MetadataImpl` --- src/types.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/types.jl b/src/types.jl index 4251318be..050a8ba90 100644 --- a/src/types.jl +++ b/src/types.jl @@ -303,6 +303,12 @@ function Base.isequal(a::BasicSymbolicImpl{T}, b::BasicSymbolicImpl{S}) where {T T === S || return false return _isequal(a, b, E)::Bool end +function Base.isequal(a::MetadataImpl, b::MetadataImpl) + (a === b) || + (isequal_with_metadata(a.this, b.this) && + isequal_with_metadata(a.children, b.children)) +end + function _isequal(a, b, E; comparator = isequal) if E === SYM nameof(a) === nameof(b) From 29a6eb4c6485e1c023fc5fef853c15edcb2266ff Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 16 Feb 2025 23:40:03 -0500 Subject: [PATCH 15/26] Modify `isequal_with_metadata` with new `BasicSymbolic` structure --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 050a8ba90..0daaa96aa 100644 --- a/src/types.jl +++ b/src/types.jl @@ -350,7 +350,7 @@ function isequal_with_metadata(a::BasicSymbolic{T}, b::BasicSymbolic{S})::Bool w E === exprtype(b) || return false T === S || return false - _isequal(a, b, E; comparator = isequal_with_metadata)::Bool && isequal_with_metadata(metadata(a), metadata(b)) || return false + _isequal(a, b, E; comparator = isequal_with_metadata)::Bool && isequal(a.meta, b.meta) || return false end """ From f8d0e78fa751d9c628a562e13a156614fcfc2fd0 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 16 Feb 2025 23:43:35 -0500 Subject: [PATCH 16/26] Add `getmetadata` methods bc `metadata` kwarg takes outer-scope function in `BasicSymbolic` constructors --- src/types.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 0daaa96aa..288d3dcff 100644 --- a/src/types.jl +++ b/src/types.jl @@ -935,6 +935,7 @@ _issafecanon(::typeof(^), s) = !iscall(s) || !(operation(s) in (*, ^)) issafecanon(f, ss...) = all(x->issafecanon(f, x), ss) +getmetadata(s) = metadata(s) function getmetadata(s::Symbolic, ctx) md = metadata(s) if md isa AbstractDict @@ -943,11 +944,13 @@ function getmetadata(s::Symbolic, ctx) throw(ArgumentError("$s does not have metadata for $ctx")) end end - function getmetadata(s::Symbolic, ctx, default) md = metadata(s) md isa AbstractDict ? get(md, ctx, default) : default end +function getmetadata(d::AbstractDict, ctx) + d[ctx] +end # pirated for Setfield purposes: using Base: ImmutableDict From e9f4ff8f2ede359c014b613f363fc5a0c791e775 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 16 Feb 2025 23:44:15 -0500 Subject: [PATCH 17/26] Modify `BasicSymbolic` constructors with new struct structure --- src/types.jl | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/types.jl b/src/types.jl index 288d3dcff..fd93cc33a 100644 --- a/src/types.jl +++ b/src/types.jl @@ -591,7 +591,8 @@ end function Sym{T}(name::Symbol; metadata = NO_METADATA, kw...) where {T} s = Sym{T}(; name, kw...) bsi = BasicSymbolicImpl(s) - BasicSymbolic(bsi, metadata) + mdi = MetadataImpl(metadata, Vector()) + BasicSymbolic(bsi, mdi) end function Term{T}(f, args; metadata = NO_METADATA, kw...) where T @@ -601,7 +602,8 @@ function Term{T}(f, args; metadata = NO_METADATA, kw...) where T s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...) bsi = BasicSymbolicImpl(s) - BasicSymbolic(bsi, metadata) + mdi = MetadataImpl(metadata, getmetadata.(args)) + BasicSymbolic(bsi, mdi) end function Term(f, args; metadata=NO_METADATA) @@ -623,7 +625,8 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments=[], issorted=RefValue(false), kw...) bsi = BasicSymbolicImpl(s) - BasicSymbolic(bsi, metadata) + mdi = MetadataImpl(metadata, getmetadata.(arguments(s))) + BasicSymbolic(bsi, mdi) end function Mul(T, a, b; metadata=NO_METADATA, kw...) @@ -640,7 +643,8 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...) dict = b s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments=[], issorted=RefValue(false), kw...) bsi = BasicSymbolicImpl(s) - BasicSymbolic(bsi, metadata) + mdi = MetadataImpl(metadata, getmetadata.(arguments(s))) + BasicSymbolic(bsi, mdi) end end @@ -707,7 +711,8 @@ function Div{T}(n, d, simplified=false; metadata=NO_METADATA, kwargs...) where { s = Div{T}(; num=n, den=d, simplified, arguments=[]) bsi = BasicSymbolicImpl(s) - BasicSymbolic(bsi, metadata) + mdi = MetadataImpl(metadata, getmetadata.(arguments(s))) + BasicSymbolic(bsi, mdi) end function Div(n,d, simplified=false; kw...) @@ -726,7 +731,8 @@ function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T} _isone(b) && return a s = Pow{T}(; base=a, exp=b, arguments=[]) bsi = BasicSymbolicImpl(s) - BasicSymbolic(bsi, metadata) + mdi = MetadataImpl(metadata, getmetadata.(arguments(s))) + BasicSymbolic(bsi, mdi) end function Pow(a, b; metadata = NO_METADATA, kwargs...) From 5293e9813334111640b078b9348982205d32b30a Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 16 Feb 2025 23:44:43 -0500 Subject: [PATCH 18/26] Add `metadata_children` function for accessing metadata tree --- src/types.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/types.jl b/src/types.jl index fd93cc33a..4036282c9 100644 --- a/src/types.jl +++ b/src/types.jl @@ -923,6 +923,10 @@ end metadata(s::Symbolic) = s.metadata metadata(s::Symbolic, meta) = Setfield.@set! s.metadata = meta +function metadata_children(s::BasicSymbolic) + s.meta.children +end + function hasmetadata(s::Symbolic, ctx) metadata(s) isa AbstractDict && haskey(metadata(s), ctx) end From 4373146aa7791a25cd974bddf0d48ad3eeb20a31 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 16 Feb 2025 23:45:19 -0500 Subject: [PATCH 19/26] Modify hash consing tests with new `BasicSymbolic` struct --- test/hash_consing.jl | 66 ++++++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/test/hash_consing.jl b/test/hash_consing.jl index 02b678b27..46a32bebb 100644 --- a/test/hash_consing.jl +++ b/test/hash_consing.jl @@ -7,23 +7,23 @@ struct Ctx2 end @testset "Sym" begin x1 = only(@syms x) x2 = only(@syms x) - @test x1 === x2 + @test x1.expr === x2.expr x3 = only(@syms x::Float64) - @test x1 !== x3 + @test x1.expr !== x3.expr x4 = only(@syms x::Float64) - @test x1 !== x4 - @test x3 === x4 + @test x1.expr !== x4.expr + @test x3.expr === x4.expr x5 = only(@syms x::Int) x6 = only(@syms x::Int) - @test x1 !== x5 - @test x3 !== x5 - @test x5 === x6 + @test x1.expr !== x5.expr + @test x3.expr !== x5.expr + @test x5.expr === x6.expr xm1 = setmetadata(x1, Ctx1, "meta_1") xm2 = setmetadata(x1, Ctx1, "meta_1") - @test xm1 === xm2 + @test xm1.expr === xm2.expr xm3 = setmetadata(x1, Ctx2, "meta_2") - @test xm1 !== xm3 + @test xm1.expr === xm3.expr end @syms a b c @@ -31,73 +31,73 @@ end @testset "Term" begin t1 = sin(a) t2 = sin(a) - @test t1 === t2 + @test t1.expr === t2.expr t3 = Term(identity,[a]) t4 = Term(identity,[a]) - @test t3 === t4 + @test t3.expr === t4.expr t5 = Term{Int}(identity,[a]) - @test t3 !== t5 + @test t3.expr !== t5.expr tm1 = setmetadata(t1, Ctx1, "meta_1") - @test t1 !== tm1 + @test t1.expr === tm1.expr end @testset "Add" begin d1 = a + b d2 = b + a - @test d1 === d2 + @test d1.expr === d2.expr d3 = b - 2 + a d4 = a + b - 2 - @test d3 === d4 + @test d3.expr === d4.expr d5 = Add(Int, 0, Dict(a => 1, b => 1)) - @test d5 !== d1 + @test d5.expr !== d1.expr dm1 = setmetadata(d1,Ctx1,"meta_1") - @test d1 !== dm1 + @test d1.expr === dm1.expr end @testset "Mul" begin m1 = a*b m2 = b*a - @test m1 === m2 + @test m1.expr === m2.expr m3 = 6*a*b m4 = 3*a*2*b - @test m3 === m4 + @test m3.expr === m4.expr m5 = Mul(Int, 1, Dict(a => 1, b => 1)) - @test m5 !== m1 + @test m5.expr !== m1.expr mm1 = setmetadata(m1, Ctx1, "meta_1") - @test m1 !== mm1 + @test m1.expr === mm1.expr end @testset "Div" begin v1 = a/b v2 = a/b - @test v1 === v2 + @test v1.expr === v2.expr v3 = -1/a v4 = -1/a - @test v3 === v4 + @test v3.expr === v4.expr v5 = 3a/6 v6 = 2a/4 - @test v5 === v6 + @test v5.expr === v6.expr v7 = Div{Float64}(-1,a) - @test v7 !== v3 + @test v7.expr !== v3.expr vm1 = setmetadata(v1,Ctx1, "meta_1") - @test vm1 !== v1 + @test vm1.expr === v1.expr end @testset "Pow" begin p1 = a^b p2 = a^b - @test p1 === p2 + @test p1.expr === p2.expr p3 = a^(2^-b) p4 = a^(2^-b) - @test p3 === p4 + @test p3.expr === p4.expr p5 = Pow{Float64}(a,b) - @test p1 !== p5 + @test p1.expr !== p5.expr pm1 = setmetadata(p1,Ctx1, "meta_1") - @test pm1 !== p1 + @test pm1.expr === p1.expr end @testset "Equivalent numbers" begin @@ -114,7 +114,7 @@ end a1 = setmetadata(a, Int, b) b1 = setmetadata(b, Int, 3) a2 = setmetadata(a, Int, b1) - @test a1 !== a2 + @test a1.expr === a2.expr @test !SymbolicUtils.isequal_with_metadata(a1, a2) @test metadata(metadata(a1)[Int]) === nothing @test metadata(metadata(a2)[Int])[Int] == 3 @@ -123,7 +123,7 @@ end @testset "Compare metadata of expression tree" begin @syms a b aa = setmetadata(a, Int, b) - @test aa !== a + @test aa.expr === a.expr @test isequal(a, aa) @test !SymbolicUtils.isequal_with_metadata(a, aa) @test !SymbolicUtils.isequal_with_metadata(2a, 2aa) @@ -144,6 +144,6 @@ end h = SymbolicUtils.hash2(ex) @test h == ex.hash2[] ex2 = setmetadata(ex, Int, 3) - @test ex2.hash2[] != h + @test ex2.hash2[] == h end end From 4a996190e624ed937a3f4dfa70f1b7b555d8b4de Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 16 Feb 2025 23:46:22 -0500 Subject: [PATCH 20/26] Modify rewrite metadata tests --- test/rewrite.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/rewrite.jl b/test/rewrite.jl index c2e920f9b..4def1569e 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -1,4 +1,5 @@ using SymbolicUtils +using SymbolicUtils: metadata_children include("utils.jl") @@ -88,24 +89,24 @@ end ex1 = ex + c @test SymbolicUtils.isterm(ex1) - @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + @test getmetadata(metadata_children(ex1)[1], MetaData) == :metadata ex = a ex = setmetadata(ex, MetaData, :metadata) ex1 = ex + b - @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + @test getmetadata(metadata_children(ex1)[1], MetaData) == :metadata ex = a * b ex = setmetadata(ex, MetaData, :metadata) ex1 = ex * c @test SymbolicUtils.isterm(ex1) - @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + @test getmetadata(metadata_children(ex1)[1], MetaData) == :metadata ex = a ex = setmetadata(ex, MetaData, :metadata) ex1 = ex * b - @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + @test getmetadata(metadata_children(ex1)[1], MetaData) == :metadata end \ No newline at end of file From 440c17bbd614700728e41607de35eeae115e6777 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 20 Feb 2025 01:55:06 -0500 Subject: [PATCH 21/26] Refactor `-(::SN, ::SN)` for easier debugging --- src/types.jl | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/types.jl b/src/types.jl index 4036282c9..e61f9fd2a 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1427,12 +1427,16 @@ function -(a::SN) end function -(a::SN, b::SN) - (!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b) - isadd(a) && isadd(b) ? Add(sub_t(a,b), - a.coeff - b.coeff, - _merge(-, a.dict, - b.dict, - filter=_iszero)) : a + (-b) + if !issafecanon(+, a) || !issafecanon(*, b) + return term(-, a, b) + elseif isadd(a) && isadd(b) + t = sub_t(a, b) + c = a.coeff - b.coeff + d = _merge(-, a.dict, b.dict, filter = _iszero) + return Add(t, c, d) + else + return a + (-b) + end end -(a::Number, b::SN) = a + (-b) From 4d0072603dd7ba0bdaa64e96a9e173b2ffb82769 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 20 Feb 2025 01:57:51 -0500 Subject: [PATCH 22/26] Make `BasicSymbolicImpl` children `BasicSymbolicImpl` --- src/types.jl | 80 ++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 11 deletions(-) diff --git a/src/types.jl b/src/types.jl index e61f9fd2a..184e26ce1 100644 --- a/src/types.jl +++ b/src/types.jl @@ -244,6 +244,45 @@ function TermInterface.arguments(x::BasicSymbolicImpl) return args end +""" +$(TYPEDSIGNATURES) + +For given `coeff` and `dict`, return arguments of type `BasicSymbolicImpl`, children's +metadata and a new dictionary with `BasicSymbolicImpl` as key's type for preparation of the +construction of either `Add` or `Mul`. +""" +function get_arguments_metadata(coeff, dict::AbstractDict, type::ExprType) + siz = length(dict) + idcoeff = type === ADD ? iszero(coeff) : isone(coeff) + args = Vector() + sizehint!(args, idcoeff ? siz : siz + 1) + idcoeff || push!(args, coeff) + if type === ADD + for (k, v) in dict + if k isa BasicSymbolicImpl + k = BasicSymbolic(k, MetadataImpl()) + end + push!(args, applicable(*, k, v) ? k * v : maketerm(k, *, [k, v], nothing)) + end + else # MUL + for (k, v) in dict + if k isa BasicSymbolicImpl + k = BasicSymbolic(k, MetadataImpl()) + end + push!(args, unstable_pow(k, v)) + end + end + metadata_children = map(getmetaimpl, args) + for i in 1:length(args) + if args[i] isa BasicSymbolic + args[i] = args[i].expr + end + end + keys = idcoeff ? args : @view args[2:end] + bsi_dict = Dict(zip(keys, values(dict))) + return args, metadata_children, bsi_dict +end + isexpr(s::BasicSymbolic) = isexpr(s.expr) isexpr(expr::BasicSymbolicImpl) = !issym(expr) iscall(s::BasicSymbolic) = iscall(s.expr) @@ -599,10 +638,15 @@ function Term{T}(f, args; metadata = NO_METADATA, kw...) where T if eltype(args) !== Any args = convert(Vector{Any}, args) end - + metadata_children = map(getmetaimpl, args) + for i in 1:length(args) + if args[i] isa BasicSymbolic + args[i] = args[i].expr + end + end s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...) bsi = BasicSymbolicImpl(s) - mdi = MetadataImpl(metadata, getmetadata.(args)) + mdi = MetadataImpl(metadata, metadata_children) BasicSymbolic(bsi, mdi) end @@ -622,10 +666,10 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T return Mul(T, coeff, dict) end end - - s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments=[], issorted=RefValue(false), kw...) + arguments, metadata_children, dict = get_arguments_metadata(coeff, dict, ADD) + s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments, issorted=RefValue(false), kw...) bsi = BasicSymbolicImpl(s) - mdi = MetadataImpl(metadata, getmetadata.(arguments(s))) + mdi = MetadataImpl(metadata, metadata_children) BasicSymbolic(bsi, mdi) end @@ -641,9 +685,10 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...) else coeff = a dict = b - s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments=[], issorted=RefValue(false), kw...) + arguments, metadata_children, dict = get_arguments_metadata(coeff, dict, MUL) + s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments, issorted=RefValue(false), kw...) bsi = BasicSymbolicImpl(s) - mdi = MetadataImpl(metadata, getmetadata.(arguments(s))) + mdi = MetadataImpl(metadata, metadata_children) BasicSymbolic(bsi, mdi) end end @@ -708,10 +753,16 @@ function Div{T}(n, d, simplified=false; metadata=NO_METADATA, kwargs...) where { end end end - + metadata_children = [getmetaimpl(n), getmetaimpl(d)] + if n isa BasicSymbolic + n = n.expr + end + if d isa BasicSymbolic + d = d.expr + end s = Div{T}(; num=n, den=d, simplified, arguments=[]) bsi = BasicSymbolicImpl(s) - mdi = MetadataImpl(metadata, getmetadata.(arguments(s))) + mdi = MetadataImpl(metadata, metadata_children) BasicSymbolic(bsi, mdi) end @@ -728,10 +779,17 @@ end function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T} _iszero(b) && return 1 - _isone(b) && return a + _isone(b) && return a + metadata_children = [getmetaimpl(a), getmetaimpl(b)] + if a isa BasicSymbolic + a = a.expr + end + if b isa BasicSymbolic + b = b.expr + end s = Pow{T}(; base=a, exp=b, arguments=[]) bsi = BasicSymbolicImpl(s) - mdi = MetadataImpl(metadata, getmetadata.(arguments(s))) + mdi = MetadataImpl(metadata, metadata_children) BasicSymbolic(bsi, mdi) end From 2fcac041991bb8f431b73f1f843887242ac8cc3c Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 20 Feb 2025 01:59:04 -0500 Subject: [PATCH 23/26] Make `getproperty` return `BasicSymbolic` if applicable --- src/types.jl | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 184e26ce1..4a868733e 100644 --- a/src/types.jl +++ b/src/types.jl @@ -97,8 +97,24 @@ function Base.getproperty(x::BasicSymbolic, sym::Symbol) return getfield(x, :meta).this elseif sym === :expr || sym === :meta return getfield(x, sym) + elseif sym === :base || sym === :num + bsi = getproperty(getfield(x, :expr), sym) + if bsi isa BasicSymbolicImpl + mdi = getfield(x, :meta).children[1] + return BasicSymbolic(bsi, mdi) + else + return bsi + end + elseif sym === :exp || sym === :den + bsi = getproperty(getfield(x, :expr), sym) + if bsi isa BasicSymbolicImpl + mdi = getfield(x, :meta).children[2] + return BasicSymbolic(bsi, mdi) + else + return bsi + end else - return getproperty(x.expr, sym) + return getproperty(getfield(x, :expr), sym) end end From d03e0efb8e624c46207cbaeb2b1cd3156e00f6ba Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 20 Feb 2025 02:02:07 -0500 Subject: [PATCH 24/26] `arguments` wraps `BasicSymbolicImpl` and `MetadataImpl` Co-authored-by: Aayush Sabharwal --- src/types.jl | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 4a868733e..5193ca22e 100644 --- a/src/types.jl +++ b/src/types.jl @@ -68,11 +68,18 @@ struct MetadataImpl children::Vector{Any} end +function MetadataImpl() + MetadataImpl(nothing, Vector()) +end + @kwdef struct BasicSymbolic{T} <: Symbolic{T} expr::BasicSymbolicImpl{T} meta::MetadataImpl end +getmetaimpl(x::BasicSymbolic) = x.meta +getmetaimpl(::Any) = nothing + function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic}) ScalarSymbolic() end @@ -210,7 +217,19 @@ end TermInterface.children(x::BasicSymbolic) = arguments(x) TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x) function TermInterface.arguments(x::BasicSymbolic) - arguments(x.expr) + args = arguments(x.expr) + args_metadata = x.meta.children + res = Vector() + for (arg, meta) in zip(args, args_metadata) + if arg isa BasicSymbolicImpl + if isnothing(meta) + meta = MetadataImpl() + end + arg = BasicSymbolic(arg, meta) + end + push!(res, arg) + end + res end function TermInterface.arguments(x::BasicSymbolicImpl) @compactified x::BasicSymbolicImpl begin From 2eae8514b34221b6c3f401b1eb4676ee6e204aa3 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 20 Feb 2025 02:02:55 -0500 Subject: [PATCH 25/26] Revert "Add `metadata_children` function for accessing metadata tree" Co-authored-by: Aayush Sabharwal This reverts commit 5293e9813334111640b078b9348982205d32b30a. --- src/types.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/types.jl b/src/types.jl index 5193ca22e..b0e59ca2b 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1016,10 +1016,6 @@ end metadata(s::Symbolic) = s.metadata metadata(s::Symbolic, meta) = Setfield.@set! s.metadata = meta -function metadata_children(s::BasicSymbolic) - s.meta.children -end - function hasmetadata(s::Symbolic, ctx) metadata(s) isa AbstractDict && haskey(metadata(s), ctx) end From cc5e2739162c783f0c50e919bef3d2b668ad5dfd Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 20 Feb 2025 02:03:54 -0500 Subject: [PATCH 26/26] Revert "Modify rewrite metadata tests" Co-authored-by: Aayush Sabharwal This reverts commit 4a996190e624ed937a3f4dfa70f1b7b555d8b4de. --- test/rewrite.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/rewrite.jl b/test/rewrite.jl index 4def1569e..c2e920f9b 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -1,5 +1,4 @@ using SymbolicUtils -using SymbolicUtils: metadata_children include("utils.jl") @@ -89,24 +88,24 @@ end ex1 = ex + c @test SymbolicUtils.isterm(ex1) - @test getmetadata(metadata_children(ex1)[1], MetaData) == :metadata + @test getmetadata(arguments(ex1)[1], MetaData) == :metadata ex = a ex = setmetadata(ex, MetaData, :metadata) ex1 = ex + b - @test getmetadata(metadata_children(ex1)[1], MetaData) == :metadata + @test getmetadata(arguments(ex1)[1], MetaData) == :metadata ex = a * b ex = setmetadata(ex, MetaData, :metadata) ex1 = ex * c @test SymbolicUtils.isterm(ex1) - @test getmetadata(metadata_children(ex1)[1], MetaData) == :metadata + @test getmetadata(arguments(ex1)[1], MetaData) == :metadata ex = a ex = setmetadata(ex, MetaData, :metadata) ex1 = ex * b - @test getmetadata(metadata_children(ex1)[1], MetaData) == :metadata + @test getmetadata(arguments(ex1)[1], MetaData) == :metadata end \ No newline at end of file