diff --git a/Project.toml b/Project.toml index cbc04c758..3ec83ad17 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,8 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83" +ReadOnlyDicts = "795d4caa-f5a7-4580-b5d8-c01d53451803" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -52,6 +54,8 @@ LabelledArrays = "1.5" MultivariatePolynomials = "0.5" NaNMath = "0.3, 1.1.2" OhMyThreads = "0.7" +ReadOnlyArrays = "0.2.0" +ReadOnlyDicts = "1.0.0" ReverseDiff = "1" Setfield = "0.7, 0.8, 1" SpecialFunctions = "0.10, 1.0, 2" diff --git a/docs/src/manual/rewrite.md b/docs/src/manual/rewrite.md index 047bef71a..05d18e82a 100644 --- a/docs/src/manual/rewrite.md +++ b/docs/src/manual/rewrite.md @@ -71,7 +71,7 @@ If you want to match a variable number of subexpressions at once, you will need @rule(+(~~xs) => ~~xs)(x + y + z) # output -3-element view(::Vector{Any}, 1:3) with eltype Any: +3-element view(::ReadOnlyArrays.ReadOnlyVector{Any, Vector{Any}}, 1:3) with eltype Any: z y x diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index f965be206..0bfc30a8c 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -23,6 +23,8 @@ import ArrayInterface using WeakValueDicts: WeakValueDict import ExproniconLite as EL import TaskLocalValues: TaskLocalValue +using ReadOnlyArrays +using ReadOnlyDicts include("cache.jl") Base.@deprecate istree iscall diff --git a/src/polyform.jl b/src/polyform.jl index 7d6bc906e..d41fa2070 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -471,7 +471,7 @@ end # ismul(x) function quick_mul(x, y) if haskey(x.dict, y) && x.dict[y] >= 1 - d = copy(x.dict) + d = copy(parent(x.dict)) if d[y] > 1 d[y] -= 1 elseif d[y] == 1 @@ -490,7 +490,7 @@ end function quick_mulpow(x, y) y.exp isa Number || return (x, y) if haskey(x.dict, y.base) - d = copy(x.dict) + d = copy(parent(x.dict)) if x.dict[y.base] > y.exp d[y.base] -= y.exp den = 1 @@ -509,7 +509,7 @@ end # Double mul case function quick_mulmul(x, y) - num_dict, den_dict = _merge_div(x.dict, y.dict) + num_dict, den_dict = _merge_div(parent(x.dict), parent(y.dict)) Mul(symtype(x), x.coeff, num_dict), Mul(symtype(y), y.coeff, den_dict) end diff --git a/src/types.jl b/src/types.jl index 152a97290..3b1d23276 100644 --- a/src/types.jl +++ b/src/types.jl @@ -18,7 +18,7 @@ sdict(kv...) = Dict{Any, Any}(kv...) using Base: RefValue const EMPTY_ARGS = [] const EMPTY_HASH = RefValue(UInt(0)) -const EMPTY_DICT = sdict() +const EMPTY_DICT = ReadOnlyDict(sdict()) const EMPTY_DICT_T = typeof(EMPTY_DICT) const ENABLE_HASHCONSING = Ref(true) @@ -62,6 +62,10 @@ const ENABLE_HASHCONSING = Ref(true) end end +function Base.setproperty!(x::BasicSymbolic, sym::Symbol, v) + error("Mutating `BasicSymbolic` is not allowed") +end + function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic}) ScalarSymbolic() end @@ -147,8 +151,8 @@ end @inline head(x::BasicSymbolic) = operation(x) -@cache function TermInterface.sorted_arguments(x::BasicSymbolic)::Vector{Any} - args = copy(arguments(x)) +@cache function TermInterface.sorted_arguments(x::BasicSymbolic)::ReadOnlyVector{Any} + args = copy(parent(arguments(x))) @compactified x::BasicSymbolic begin Add => @goto ADD Mul => @goto MUL @@ -167,7 +171,7 @@ end TermInterface.children(x::BasicSymbolic) = arguments(x) TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x) -function TermInterface.arguments(x::BasicSymbolic) +function TermInterface.arguments(x::BasicSymbolic)::ReadOnlyVector{Any} @compactified x::BasicSymbolic begin Term => return x.arguments Add => @goto ADDMUL @@ -554,17 +558,10 @@ function Sym{T}(name::Symbol; kw...) where {T} BasicSymbolic(s) end -function unwrap_arr!(arr) - for i in eachindex(arr) - arr[i] = unwrap(arr[i]) - end -end - function Term{T}(f, args; kw...) where T if eltype(args) !== Any args = convert(Vector{Any}, args) end - unwrap_arr!(args) s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...) BasicSymbolic(s) @@ -574,16 +571,8 @@ function Term(f, args; metadata=NO_METADATA) Term{_promote_symtype(f, args)}(f, args, metadata=metadata) end -function unwrap_dict(dict) - if any(k -> unwrap(k) !== k, keys(dict)) - return typeof(dict)(unwrap(k) => v for (k, v) in dict) - end - return dict -end - function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T coeff = unwrap(coeff) - dict = unwrap_dict(dict) if isempty(dict) return coeff elseif _iszero(coeff) && length(dict) == 1 @@ -602,7 +591,6 @@ end function Mul(T, a, b; metadata=NO_METADATA, kw...) a = unwrap(a) - b = unwrap_dict(b) isempty(b) && return a if _isone(a) && length(b) == 1 pair = first(b) @@ -1342,6 +1330,10 @@ function _merge!(f::F, d, others...; filter=x->false) where F acc end +function mapvalues(f, d1::ReadOnlyDict) + mapvalues(f, parent(d1)) +end + function mapvalues(f, d1::AbstractDict) d = copy(d1) for (k, v) in d