Skip to content

Start handling abstract block types #150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.7.14"
version = "0.7.15"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
20 changes: 19 additions & 1 deletion src/abstractblocksparsearray/abstractblocksparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ function Base.setindex!(a::AbstractBlockSparseArray{<:Any,0}, value, ::Block{0})
return a
end

# Custom `_convert` works around the issue that
# `convert(::Type{<:Diagonal}, ::AbstractMatrix)` isnt' defined
# in Julia v1.10 (https://github.com/JuliaLang/julia/pull/48895,
# https://github.com/JuliaLang/julia/pull/52487).
# TODO: Delete once we drop support for Julia v1.10.
_convert(::Type{T}, a::AbstractArray) where {T} = convert(T, a)
using LinearAlgebra: LinearAlgebra, Diagonal, diag, isdiag
_construct(T::Type{<:Diagonal}, a::AbstractMatrix) = T(diag(a))
function _convert(T::Type{<:Diagonal}, a::AbstractMatrix)
LinearAlgebra.checksquare(a)
return isdiag(a) ? _construct(T, a) : throw(InexactError(:convert, T, a))
end

function Base.setindex!(
a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Block{1},N}
) where {N}
Expand All @@ -74,7 +87,12 @@ function Base.setindex!(
),
)
end
blocks(a)[Int.(I)...] = value
# Custom `_convert` works around the issue that
# `convert(::Type{<:Diagonal}, ::AbstractMatrix)` isnt' defined
# in Julia v1.10 (https://github.com/JuliaLang/julia/pull/48895,
# https://github.com/JuliaLang/julia/pull/52487).
# TODO: Delete once we drop support for Julia v1.10.
blocks(a)[Int.(I)...] = _convert(blocktype(a), value)
return a
end

Expand Down
34 changes: 29 additions & 5 deletions src/abstractblocksparsearray/arraylayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,37 @@ function ArrayLayouts.MemoryLayout(
end

function Base.similar(
mul::MulAdd{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout},<:Any,<:Any,A,B},
mul::MulAdd{
<:BlockLayout{<:SparseLayout,BlockLayoutA},
<:BlockLayout{<:SparseLayout,BlockLayoutB},
LayoutC,
T,
A,
B,
C,
},
elt::Type,
axes,
) where {A,B}
# TODO: Use something like `Base.promote_op(*, A, B)` to determine the output block type.
output_blocktype = similartype(blocktype(A), Type{elt}, Tuple{blockaxistype.(axes)...})
return similar(BlockSparseArray{elt,length(axes),output_blocktype}, axes)
) where {BlockLayoutA,BlockLayoutB,LayoutC,T,A,B,C}

# TODO: Consider using this instead:
# ```julia
# blockmultype = MulAdd{BlockLayoutA,BlockLayoutB,LayoutC,T,blocktype(A),blocktype(B),C}
# output_blocktype = Base.promote_op(
# similar, blockmultype, Type{elt}, Tuple{eltype.(eachblockaxis.(axes))...}
# )
# ```
# The issue is that it in some cases it seems to lose some information about the block types.

# TODO: Maybe this should be:
# output_blocktype = Base.promote_op(
# mul!, blocktype(mul.A), blocktype(mul.B), blocktype(mul.C), typeof(mul.α), typeof(mul.β)
# )

output_blocktype = Base.promote_op(*, blocktype(mul.A), blocktype(mul.B))
output_blocktype′ =
!isconcretetype(output_blocktype) ? AbstractMatrix{elt} : output_blocktype
return similar(BlockSparseArray{elt,length(axes),output_blocktype′}, axes)
end

# Materialize a SubArray view.
Expand Down
22 changes: 14 additions & 8 deletions src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,11 @@ end

function blocksparse_similar(a, elt::Type, axes::Tuple)
ndims = length(axes)
blockt = similartype(blocktype(a), Type{elt}, Tuple{blockaxistype.(axes)...})
return BlockSparseArray{elt,ndims,blockt}(undef, axes)
# TODO: Define a version of `similartype` that catches the case
# where the output isn't concrete and returns an `AbstractArray`.
blockt = Base.promote_op(similar, blocktype(a), Type{elt}, Tuple{blockaxistype.(axes)...})
blockt′ = !isconcretetype(blockt) ? AbstractArray{elt,ndims} : blockt
return BlockSparseArray{elt,ndims,blockt′}(undef, axes)
end
@interface ::AbstractBlockSparseArrayInterface function Base.similar(
a::AbstractArray, elt::Type, axes::Tuple{Vararg{Int}}
Expand Down Expand Up @@ -283,13 +286,11 @@ function Base.similar(
elt::Type,
axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}},
)
# TODO: Use `@interface interface(a) similar(...)`.
return @interface interface(a) similar(a, elt, axes)
end

# Fixes ambiguity error.
function Base.similar(a::AnyAbstractBlockSparseArray, elt::Type, axes::Tuple{})
# TODO: Use `@interface interface(a) similar(...)`.
return @interface interface(a) similar(a, elt, axes)
end

Expand All @@ -301,7 +302,6 @@ function Base.similar(
AbstractBlockedUnitRange{<:Integer},Vararg{AbstractBlockedUnitRange{<:Integer}}
},
)
# TODO: Use `@interface interface(a) similar(...)`.
return @interface interface(a) similar(a, elt, axes)
end

Expand All @@ -311,7 +311,6 @@ function Base.similar(
elt::Type,
axes::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
)
# TODO: Use `@interface interface(a) similar(...)`.
return @interface interface(a) similar(a, elt, axes)
end

Expand All @@ -321,9 +320,17 @@ function Base.similar(
elt::Type,
axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
)
# TODO: Use `@interface interface(a) similar(...)`.
return @interface interface(a) similar(a, elt, axes)
end
function Base.similar(a::AnyAbstractBlockSparseArray, elt::Type)
return @interface interface(a) similar(a, elt, axes(a))
end
function Base.similar(
a::AnyAbstractBlockSparseArray,
axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
)
return @interface interface(a) similar(a, eltype(a), axes)
end

# Fixes ambiguity errors with BlockArrays.
function Base.similar(
Expand All @@ -343,7 +350,6 @@ end
function Base.similar(
a::AnyAbstractBlockSparseArray, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
)
# TODO: Use `@interface interface(a) similar(...)`.
return @interface interface(a) similar(a, elt, axes)
end

Expand Down
8 changes: 8 additions & 0 deletions src/blocksparsearrayinterface/blocksparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ function eachblockstoredindex(a::AbstractArray)
return Block.(Tuple.(eachstoredindex(blocks(a))))
end

function SparseArraysBase.isstored(a::AbstractArray, I1::Block{1}, Irest::Block{1}...)
I = (I1, Irest...)
return isstored(blocks(a), Int.(I)...)
end
function SparseArraysBase.isstored(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
return isstored(a, Tuple(I)...)
end

using DiagonalArrays: diagindices
# Block version of `DiagonalArrays.diagindices`.
function blockdiagindices(a::AbstractArray)
Expand Down
1 change: 1 addition & 0 deletions src/blocksparsearrayinterface/getunstoredblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ end
ax = ntuple(N) do d
return only(axes(f.axes[d][Block(I[d])]))
end
!isconcretetype(A) && return zero!(similar(Array{eltype(A),N}, ax))
return zero!(similar(A, ax))
end
@inline function (f::GetUnstoredBlock)(
Expand Down
55 changes: 38 additions & 17 deletions src/blocksparsearrayinterface/map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,38 @@ function union_eachblockstoredindex(as::AbstractArray...)
return ∪(map(eachblockstoredindex, as)...)
end

# Get a view of a block assuming it is stored.
function viewblock_stored(a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}) where {N}
return blocks(a)[Int.(I)...]
end
function viewblock_stored(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
return viewblock_stored(a, Tuple(I)...)
end

using FillArrays: Zeros
# Get a view of a block if it is stored, otherwise return a lazy zeros.
function viewblock_or_zeros(a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}) where {N}
if isstored(a, I...)
return viewblock_stored(a, I...)
else
block_ax = map((ax, i) -> eachblockaxis(ax)[Int(i)], axes(a), I)
return Zeros{eltype(a)}(block_ax)
end
end
function viewblock_or_zeros(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
return viewblock_or_zeros(a, Tuple(I)...)
end

function map_block!(f, a_dest::AbstractArray, I::Block, a_srcs::AbstractArray...)
a_srcs_I = map(a_src -> viewblock_or_zeros(a_src, I), a_srcs)
if isstored(a_dest, I)
a_dest[I] .= f.(a_srcs_I...)
else
a_dest[I] = Broadcast.broadcast_preserving_zero_d(f, a_srcs_I...)
end
return a_dest
end

function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
# TODO: This assumes element types are numbers, generalize this logic.
f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
Expand All @@ -27,22 +59,7 @@ function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
BlockRange(a_dest)
end
for I in Is
# TODO: Use:
# block_dest = @view a_dest[I]
# or:
# block_dest = @view! a_dest[I]
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(I))...]
# TODO: Use:
# block_srcs = map(a_src -> @view(a_src[I]), a_srcs)
block_srcs = map(a_srcs) do a_src
return blocks_maybe_single(a_src)[Int.(Tuple(I))...]
end
# TODO: Use `map!!` to handle immutable blocks.
map!(f, block_dest, block_srcs...)
# Replace the entire block, handles initializing new blocks
# or if blocks are immutable.
# TODO: Use `a_dest[I] = block_dest`.
blocks(a_dest)[Int.(Tuple(I))...] = block_dest
map_block!(f, a_dest, I, a_srcs...)
end
return a_dest
end
Expand Down Expand Up @@ -151,8 +168,12 @@ end
function map_stored_blocks(f, a::AbstractArray)
block_stored_indices = collect(eachblockstoredindex(a))
if isempty(block_stored_indices)
eltype_a′ = Base.promote_op(f, eltype(a))
blocktype_a′ = Base.promote_op(f, blocktype(a))
return BlockSparseArray{eltype(blocktype_a′),ndims(a),blocktype_a′}(undef, axes(a))
eltype_a′′ = !isconcretetype(eltype_a′) ? Any : eltype_a′
blocktype_a′′ =
!isconcretetype(blocktype_a′) ? AbstractArray{eltype_a′′,ndims(a)} : blocktype_a′
return BlockSparseArray{eltype_a′′,ndims(a),blocktype_a′′}(undef, axes(a))
end
stored_blocks = map(B -> f(@view!(a[B])), block_stored_indices)
blocktype_a′ = eltype(stored_blocks)
Expand Down
13 changes: 12 additions & 1 deletion src/factorizations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,21 @@ function MatrixAlgebraKit.default_svd_algorithm(
return BlockPermutedDiagonalAlgorithm(alg)
end

function output_type(
::typeof(svd_compact!),
A::Type{<:AbstractMatrix{T}},
Alg::Type{<:MatrixAlgebraKit.AbstractAlgorithm},
) where {T}
USVᴴ = Base.promote_op(svd_compact!, A, Alg)
!isconcretetype(USVᴴ) &&
return Tuple{AbstractMatrix{T},AbstractMatrix{realtype(T)},AbstractMatrix{T}}
return USVᴴ
end

function similar_output(
::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
)
BU, BS, BVᴴ = fieldtypes(Base.promote_op(svd_compact!, blocktype(A), typeof(alg.alg)))
BU, BS, BVᴴ = fieldtypes(output_type(svd_compact!, blocktype(A), typeof(alg.alg)))
U = similar(A, BlockType(BU), (axes(A, 1), S_axes[1]))
S = similar(A, BlockType(BS), S_axes)
Vᴴ = similar(A, BlockType(BVᴴ), (S_axes[2], axes(A, 2)))
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Expand All @@ -25,6 +26,7 @@ Aqua = "0.8"
ArrayLayouts = "1"
BlockArrays = "1"
BlockSparseArrays = "0.7"
DerivableInterfaces = "0.5"
DiagonalArrays = "0.3"
GPUArraysCore = "0.2"
JLArrays = "0.2"
Expand Down
44 changes: 44 additions & 0 deletions test/test_abstract_blocktype.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using Adapt: adapt
using BlockArrays: Block
using BlockSparseArrays: BlockSparseMatrix, blockstoredlength
using JLArrays: JLArray
using SparseArraysBase: storedlength
using Test: @test, @test_broken, @testset

elts = (Float32, Float64, ComplexF32)
arrayts = (Array, JLArray)
@testset "Abstract block type (arraytype=$arrayt, eltype=$elt)" for arrayt in arrayts,
elt in elts

dev = adapt(arrayt)
a = BlockSparseMatrix{elt,AbstractMatrix{elt}}(undef, [2, 3], [2, 3])
@test sprint(show, MIME"text/plain"(), a) isa String
@test iszero(storedlength(a))
@test iszero(blockstoredlength(a))
a[Block(1, 1)] = dev(randn(elt, 2, 2))
a[Block(2, 2)] = dev(randn(elt, 3, 3))
@test !iszero(a[Block(1, 1)])
@test a[Block(1, 1)] isa arrayt{elt,2}
@test !iszero(a[Block(2, 2)])
@test a[Block(2, 2)] isa arrayt{elt,2}
@test iszero(a[Block(2, 1)])
@test a[Block(2, 1)] isa Matrix{elt}
@test iszero(a[Block(1, 2)])
@test a[Block(1, 2)] isa Matrix{elt}

b = copy(a)
@test Array(b) ≈ Array(a)

b = a + a
@test Array(b) ≈ Array(a) + Array(a)

b = 3a
@test Array(b) ≈ 3Array(a)

if arrayt === Array
b = a * a
@test Array(b) ≈ Array(a) * Array(a)
else
@test_broken a * a
end
end
Loading
Loading