Skip to content

feat: add LittleBigDict and use it in BasicSymbolic #725

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Expand Down Expand Up @@ -51,6 +52,7 @@ LabelledArrays = "1.5"
MultivariatePolynomials = "0.5"
NaNMath = "0.3, 1.1.2"
OhMyThreads = "0.7"
OrderedCollections = "1.8.0"
ReverseDiff = "1"
RuntimeGeneratedFunctions = "0.5.13"
Setfield = "0.7, 0.8, 1"
Expand Down
1 change: 1 addition & 0 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import TermInterface: iscall, isexpr, head, children,
import ArrayInterface
import ExproniconLite as EL
import TaskLocalValues: TaskLocalValue
import OrderedCollections: LittleDict

include("cache.jl")
Base.@deprecate istree iscall
Expand Down
101 changes: 92 additions & 9 deletions src/small_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ GC'ed when removed.
defaultval(::Type{T}) where {T <: Number} = zero(T)
defaultval(::Type{Any}) = nothing

function Base.getindex(x::Backing, i::Int)
Base.@propagate_inbounds function Base.getindex(x::Backing, i::Int)
@boundscheck 1 <= i <= x.len
if i == 1
x.x1
Expand All @@ -43,7 +43,7 @@ function Base.getindex(x::Backing, i::Int)
end
end

function Base.setindex!(x::Backing, v, i::Int)
Base.@propagate_inbounds function Base.setindex!(x::Backing, v, i::Int)
@boundscheck 1 <= i <= x.len
if i == 1
setfield!(x, :x1, v)
Expand All @@ -54,20 +54,30 @@ function Base.setindex!(x::Backing, v, i::Int)
end
end

function Base.push!(x::Backing, v)
Base.@propagate_inbounds function Base.push!(x::Backing, v)
x.len < 3 || throw(ArgumentError("`Backing` is full"))
x.len += 1
x[x.len] = v
end

function Base.pop!(x::Backing{T}) where {T}
Base.@propagate_inbounds function Base.pop!(x::Backing{T}) where {T}
x.len > 0 || throw(ArgumentError("Array is empty"))
v = x[x.len]
x[x.len] = defaultval(T)
x.len -= 1
v
end

Base.@propagate_inbounds function Base.deleteat!(x::Backing{T}, i::Int) where {T}
@boundscheck 1 <= i <= x.len
x[i] = defaultval(T)
for j in i:x.len-1
x[i] = x[i + 1]
end
x.len -= 1
x
end

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -113,21 +123,94 @@ Base.convert(::Type{SmallVec{T, V}}, x::SmallVec{T, V}) where {T, V} = x

Base.size(x::SmallVec) = size(x.data)
Base.isempty(x::SmallVec) = isempty(x.data)
Base.getindex(x::SmallVec, i::Int) = x.data[i]
Base.setindex!(x::SmallVec, v, i::Int) = setindex!(x.data, v, i)
Base.@propagate_inbounds Base.getindex(x::SmallVec, i::Int) = x.data[i]
Base.@propagate_inbounds Base.setindex!(x::SmallVec, v, i::Int) = setindex!(x.data, v, i)
Base.@propagate_inbounds Base.deleteat!(x::SmallVec, i) = deleteat!(x.data, i)

function Base.push!(x::SmallVec{T, V}, v) where {T, V}
Base.@propagate_inbounds function Base.push!(x::SmallVec{T, V}, v) where {T, V}
buf = x.data
buf isa Backing{T} || return push!(buf::V, v)
isfull(buf) || return push!(buf::Backing{T}, v)
x.data = V(buf)
return push!(x.data::V, v)
end

Base.pop!(x::SmallVec) = pop!(x.data)
Base.@propagate_inbounds Base.pop!(x::SmallVec) = pop!(x.data)

function Base.sizehint!(x::SmallVec{T, V}, n; kwargs...) where {T, V}
x.data isa Backing && return x
if x.data isa Backing
if n > 3
x.data = V(x.data)
end
return x
end
sizehint!(x.data, n; kwargs...)
x
end

mutable struct LittleBigDict{K, V, KVec, VVec, D <: AbstractDict{K, V}} <: AbstractDict{K, V}
data::Union{LittleDict{K, V, SmallVec{K, KVec}, SmallVec{V, VVec}}, D}

function LittleBigDict{K, V, Kv, Vv, D}(keys, vals) where {K, V, Kv, Vv, D}
nk = length(keys)
nv = length(vals)
nk == nv || throw(ArgumentError("Got $nk keys for $nv values"))
if nk < 25
keys = SmallVec{K, Kv}(keys)
vals = SmallVec{V, Vv}(vals)
new{K, V, Kv, Vv, D}(LittleDict{K, V}(keys, vals))
else
new{K, V, Kv, Vv, D}(D(zip(keys, vals)))
end
end

function LittleBigDict{K, V, Kv, Vv, D}(d::D) where {K, V, Kv, Vv, D}
if length(d) < 25
return LittleBigDict{K, V, Kv, Vv, D}(collect(keys(d)), collect(values(d)))
else
return new{K, V, Kv, Vv, D}(d)
end
end

function LittleBigDict{K, V, Kv, Vv, D}(d::AbstractDict) where {K, V, Kv, Vv, D}
LittleBigDict{K, V, Kv, Vv, D}(collect(keys(d)), collect(values(d)))
end
end

function LittleBigDict{K, V, D}() where {K, V, D}
LittleBigDict{K, V, Vector{K}, Vector{V}, D}((), ())
end
LittleBigDict{K, V}() where {K, V} = LittleBigDict{K, V, Dict{K, V}}()

Base.haskey(x::LittleBigDict, k) = haskey(x.data, k)
Base.length(x::LittleBigDict) = length(x.data)
Base.getkey(x::LittleBigDict, k, d) = getkey(x.data, k, d)
Base.get(x::LittleBigDict, k, d) = get(x.data, k, d)
function Base.sizehint!(x::LittleBigDict{K, V, Kv, Vv, D}, n; kwargs...) where {K, V, Kv, Vv, D}
if x.data isa LittleDict && n >= 25
x.data = D(x.data)
end
sizehint!(x.data, n; kwargs...)
end
function Base.setindex!(x::LittleBigDict{K, V, Kv, Vv, D}, v, k) where {K, V, Kv, Vv, D}
if x.data isa LittleDict
delete!(x.data, k)
get!(Returns(v), x.data, k)
if length(x.data) > 25
x.data = D(x.data)
end
v
else
setindex!(x.data, v, k)
end
end
Base.getindex(x::LittleBigDict, k) = getindex(x.data, k)
Base.delete!(x::LittleBigDict, k) = delete!(x.data, k)
function Base.get!(f::Base.Callable, x::LittleBigDict{K, V, Kv, Vv, D}, k) where {K, V, Kv, Vv, D}
res = get!(f, x.data, k)
if x.data isa LittleDict && length(x.data) > 25
x.data = D(x.data)
end
res
end
Base.iterate(x::LittleBigDict, args...) = iterate(x.data, args...)
2 changes: 1 addition & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ abstract type Symbolic{T} end
const Metadata = Union{Nothing,Base.ImmutableDict{DataType,Any}}
const NO_METADATA = nothing

sdict(kv...) = Dict{Any, Any}(kv...)
sdict() = LittleBigDict{Any, Any}()

using Base: RefValue
const SmallV{T} = SmallVec{T, Vector{T}}
Expand Down
2 changes: 1 addition & 1 deletion test/inspect_output/ex.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
1 DIV
2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2))
3 │ ├─ POW
4 │ │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2))
4 │ │ ├─ ADD(scalar = 1, coeffs = (x => 2, y => 3))
5 │ │ │ ├─ 1
6 │ │ │ ├─ MUL(scalar = 2, powers = (x => 1,))
7 │ │ │ │ ├─ 2
Expand Down
Loading