diff --git a/Project.toml b/Project.toml index 357437b..7f585fd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "SparseArraysBase" uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208" authors = ["ITensor developers and contributors"] -version = "0.5.10" +version = "0.5.11" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" @@ -15,6 +16,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] Accessors = "0.1.41" +Adapt = "4.3.0" Aqua = "0.8.9" ArrayLayouts = "1.11.0" DerivableInterfaces = "0.5" diff --git a/src/abstractsparsearray.jl b/src/abstractsparsearray.jl index 3e0422f..9dfe012 100644 --- a/src/abstractsparsearray.jl +++ b/src/abstractsparsearray.jl @@ -35,6 +35,26 @@ function Base._cat(dims, a::AnyAbstractSparseArray...) return concatenate(dims, a...) end +function map_stored(f, a::AnyAbstractSparseArray) + kvs = storedpairs(a) + # `collect` to convert to `Vector`, since otherwise + # if it stays as `Dictionary` we might hit issues like + # https://github.com/andyferris/Dictionaries.jl/issues/163. + ks = collect(first.(kvs)) + vs = collect(last.(kvs)) + vs′ = map(f, vs) + a′ = zero!(similar(a, eltype(vs′))) + for (k, v′) in zip(ks, vs′) + a′[k] = v′ + end + return a′ +end + +using Adapt: adapt +function Base.print_array(io::IO, a::AnyAbstractSparseArray) + a′ = map_stored(adapt(Array), a) + return @invoke Base.print_array(io::typeof(io), a′::AbstractArray{<:Any,ndims(a)}) +end function Base.replace_in_print_matrix( a::AnyAbstractSparseVecOrMat, i::Integer, j::Integer, s::AbstractString ) diff --git a/src/abstractsparsearrayinterface.jl b/src/abstractsparsearrayinterface.jl index f31ca8a..6af80e6 100644 --- a/src/abstractsparsearrayinterface.jl +++ b/src/abstractsparsearrayinterface.jl @@ -256,7 +256,11 @@ function sparse_mul!( for I2 in eachstoredindex(a2) I_dest = mul_indices(I1, I2) if !isnothing(I_dest) - a_dest[I_dest] = mul!!(a_dest[I_dest], a1[I1], a2[I2], α, β′) + if isstored(a_dest, I_dest) + a_dest[I_dest] = mul!!(a_dest[I_dest], a1[I1], a2[I2], α, β′) + else + a_dest[I_dest] = a1[I1] * a2[I2] * α + end end end end diff --git a/test/Project.toml b/test/Project.toml index 0d9f3a5..9176f07 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,7 +2,6 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" -DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -17,7 +16,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Adapt = "4.2.0" Aqua = "0.8.11" ArrayLayouts = "1.11.1" -DerivableInterfaces = "0.5" Dictionaries = "0.4.4" JLArrays = "0.2.0" LinearAlgebra = "<0.0.1, 1"