Skip to content

refactor: disable internal mutation of BasicSymbolic #718

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83"
ReadOnlyDicts = "795d4caa-f5a7-4580-b5d8-c01d53451803"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Expand Down Expand Up @@ -52,6 +54,8 @@ LabelledArrays = "1.5"
MultivariatePolynomials = "0.5"
NaNMath = "0.3, 1.1.2"
OhMyThreads = "0.7"
ReadOnlyArrays = "0.2.0"
ReadOnlyDicts = "1.0.0"
ReverseDiff = "1"
Setfield = "0.7, 0.8, 1"
SpecialFunctions = "0.10, 1.0, 2"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/manual/rewrite.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ If you want to match a variable number of subexpressions at once, you will need
@rule(+(~~xs) => ~~xs)(x + y + z)

# output
3-element view(::Vector{Any}, 1:3) with eltype Any:
3-element view(::ReadOnlyArrays.ReadOnlyVector{Any, Vector{Any}}, 1:3) with eltype Any:
z
y
x
Expand Down
2 changes: 2 additions & 0 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import ArrayInterface
using WeakValueDicts: WeakValueDict
import ExproniconLite as EL
import TaskLocalValues: TaskLocalValue
using ReadOnlyArrays
using ReadOnlyDicts

include("cache.jl")
Base.@deprecate istree iscall
Expand Down
6 changes: 3 additions & 3 deletions src/polyform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ end
# ismul(x)
function quick_mul(x, y)
if haskey(x.dict, y) && x.dict[y] >= 1
d = copy(x.dict)
d = copy(parent(x.dict))
if d[y] > 1
d[y] -= 1
elseif d[y] == 1
Expand All @@ -490,7 +490,7 @@ end
function quick_mulpow(x, y)
y.exp isa Number || return (x, y)
if haskey(x.dict, y.base)
d = copy(x.dict)
d = copy(parent(x.dict))
if x.dict[y.base] > y.exp
d[y.base] -= y.exp
den = 1
Expand All @@ -509,7 +509,7 @@ end

# Double mul case
function quick_mulmul(x, y)
num_dict, den_dict = _merge_div(x.dict, y.dict)
num_dict, den_dict = _merge_div(parent(x.dict), parent(y.dict))
Mul(symtype(x), x.coeff, num_dict), Mul(symtype(y), y.coeff, den_dict)
end

Expand Down
32 changes: 12 additions & 20 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ sdict(kv...) = Dict{Any, Any}(kv...)
using Base: RefValue
const EMPTY_ARGS = []
const EMPTY_HASH = RefValue(UInt(0))
const EMPTY_DICT = sdict()
const EMPTY_DICT = ReadOnlyDict(sdict())
const EMPTY_DICT_T = typeof(EMPTY_DICT)
const ENABLE_HASHCONSING = Ref(true)

Expand Down Expand Up @@ -62,6 +62,10 @@ const ENABLE_HASHCONSING = Ref(true)
end
end

function Base.setproperty!(x::BasicSymbolic, sym::Symbol, v)
error("Mutating `BasicSymbolic` is not allowed")
end

function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic})
ScalarSymbolic()
end
Expand Down Expand Up @@ -147,8 +151,8 @@ end

@inline head(x::BasicSymbolic) = operation(x)

@cache function TermInterface.sorted_arguments(x::BasicSymbolic)::Vector{Any}
args = copy(arguments(x))
@cache function TermInterface.sorted_arguments(x::BasicSymbolic)::ReadOnlyVector{Any}
args = copy(parent(arguments(x)))
@compactified x::BasicSymbolic begin
Add => @goto ADD
Mul => @goto MUL
Expand All @@ -167,7 +171,7 @@ end

TermInterface.children(x::BasicSymbolic) = arguments(x)
TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x)
function TermInterface.arguments(x::BasicSymbolic)
function TermInterface.arguments(x::BasicSymbolic)::ReadOnlyVector{Any}
@compactified x::BasicSymbolic begin
Term => return x.arguments
Add => @goto ADDMUL
Expand Down Expand Up @@ -554,17 +558,10 @@ function Sym{T}(name::Symbol; kw...) where {T}
BasicSymbolic(s)
end

function unwrap_arr!(arr)
for i in eachindex(arr)
arr[i] = unwrap(arr[i])
end
end

function Term{T}(f, args; kw...) where T
if eltype(args) !== Any
args = convert(Vector{Any}, args)
end
unwrap_arr!(args)

s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...)
BasicSymbolic(s)
Expand All @@ -574,16 +571,8 @@ function Term(f, args; metadata=NO_METADATA)
Term{_promote_symtype(f, args)}(f, args, metadata=metadata)
end

function unwrap_dict(dict)
if any(k -> unwrap(k) !== k, keys(dict))
return typeof(dict)(unwrap(k) => v for (k, v) in dict)
end
return dict
end

function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
coeff = unwrap(coeff)
dict = unwrap_dict(dict)
if isempty(dict)
return coeff
elseif _iszero(coeff) && length(dict) == 1
Expand All @@ -602,7 +591,6 @@ end

function Mul(T, a, b; metadata=NO_METADATA, kw...)
a = unwrap(a)
b = unwrap_dict(b)
isempty(b) && return a
if _isone(a) && length(b) == 1
pair = first(b)
Expand Down Expand Up @@ -1342,6 +1330,10 @@ function _merge!(f::F, d, others...; filter=x->false) where F
acc
end

function mapvalues(f, d1::ReadOnlyDict)
mapvalues(f, parent(d1))
end

function mapvalues(f, d1::AbstractDict)
d = copy(d1)
for (k, v) in d
Expand Down
Loading