Skip to content

Commit d55caaa

Browse files
authored
Start handling abstract block types (#150)
1 parent 89faa55 commit d55caaa

15 files changed

+1031
-826
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.14"
4+
version = "0.7.15"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractblocksparsearray/abstractblocksparsearray.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,19 @@ function Base.setindex!(a::AbstractBlockSparseArray{<:Any,0}, value, ::Block{0})
6363
return a
6464
end
6565

66+
# Custom `_convert` works around the issue that
67+
# `convert(::Type{<:Diagonal}, ::AbstractMatrix)` isnt' defined
68+
# in Julia v1.10 (https://github.com/JuliaLang/julia/pull/48895,
69+
# https://github.com/JuliaLang/julia/pull/52487).
70+
# TODO: Delete once we drop support for Julia v1.10.
71+
_convert(::Type{T}, a::AbstractArray) where {T} = convert(T, a)
72+
using LinearAlgebra: LinearAlgebra, Diagonal, diag, isdiag
73+
_construct(T::Type{<:Diagonal}, a::AbstractMatrix) = T(diag(a))
74+
function _convert(T::Type{<:Diagonal}, a::AbstractMatrix)
75+
LinearAlgebra.checksquare(a)
76+
return isdiag(a) ? _construct(T, a) : throw(InexactError(:convert, T, a))
77+
end
78+
6679
function Base.setindex!(
6780
a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Block{1},N}
6881
) where {N}
@@ -74,7 +87,12 @@ function Base.setindex!(
7487
),
7588
)
7689
end
77-
blocks(a)[Int.(I)...] = value
90+
# Custom `_convert` works around the issue that
91+
# `convert(::Type{<:Diagonal}, ::AbstractMatrix)` isnt' defined
92+
# in Julia v1.10 (https://github.com/JuliaLang/julia/pull/48895,
93+
# https://github.com/JuliaLang/julia/pull/52487).
94+
# TODO: Delete once we drop support for Julia v1.10.
95+
blocks(a)[Int.(I)...] = _convert(blocktype(a), value)
7896
return a
7997
end
8098

src/abstractblocksparsearray/arraylayouts.jl

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,37 @@ function ArrayLayouts.MemoryLayout(
2323
end
2424

2525
function Base.similar(
26-
mul::MulAdd{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout},<:Any,<:Any,A,B},
26+
mul::MulAdd{
27+
<:BlockLayout{<:SparseLayout,BlockLayoutA},
28+
<:BlockLayout{<:SparseLayout,BlockLayoutB},
29+
LayoutC,
30+
T,
31+
A,
32+
B,
33+
C,
34+
},
2735
elt::Type,
2836
axes,
29-
) where {A,B}
30-
# TODO: Use something like `Base.promote_op(*, A, B)` to determine the output block type.
31-
output_blocktype = similartype(blocktype(A), Type{elt}, Tuple{blockaxistype.(axes)...})
32-
return similar(BlockSparseArray{elt,length(axes),output_blocktype}, axes)
37+
) where {BlockLayoutA,BlockLayoutB,LayoutC,T,A,B,C}
38+
39+
# TODO: Consider using this instead:
40+
# ```julia
41+
# blockmultype = MulAdd{BlockLayoutA,BlockLayoutB,LayoutC,T,blocktype(A),blocktype(B),C}
42+
# output_blocktype = Base.promote_op(
43+
# similar, blockmultype, Type{elt}, Tuple{eltype.(eachblockaxis.(axes))...}
44+
# )
45+
# ```
46+
# The issue is that it in some cases it seems to lose some information about the block types.
47+
48+
# TODO: Maybe this should be:
49+
# output_blocktype = Base.promote_op(
50+
# mul!, blocktype(mul.A), blocktype(mul.B), blocktype(mul.C), typeof(mul.α), typeof(mul.β)
51+
# )
52+
53+
output_blocktype = Base.promote_op(*, blocktype(mul.A), blocktype(mul.B))
54+
output_blocktype′ =
55+
!isconcretetype(output_blocktype) ? AbstractMatrix{elt} : output_blocktype
56+
return similar(BlockSparseArray{elt,length(axes),output_blocktype′}, axes)
3357
end
3458

3559
# Materialize a SubArray view.

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,11 @@ end
231231

232232
function blocksparse_similar(a, elt::Type, axes::Tuple)
233233
ndims = length(axes)
234-
blockt = similartype(blocktype(a), Type{elt}, Tuple{blockaxistype.(axes)...})
235-
return BlockSparseArray{elt,ndims,blockt}(undef, axes)
234+
# TODO: Define a version of `similartype` that catches the case
235+
# where the output isn't concrete and returns an `AbstractArray`.
236+
blockt = Base.promote_op(similar, blocktype(a), Type{elt}, Tuple{blockaxistype.(axes)...})
237+
blockt′ = !isconcretetype(blockt) ? AbstractArray{elt,ndims} : blockt
238+
return BlockSparseArray{elt,ndims,blockt′}(undef, axes)
236239
end
237240
@interface ::AbstractBlockSparseArrayInterface function Base.similar(
238241
a::AbstractArray, elt::Type, axes::Tuple{Vararg{Int}}
@@ -283,13 +286,11 @@ function Base.similar(
283286
elt::Type,
284287
axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}},
285288
)
286-
# TODO: Use `@interface interface(a) similar(...)`.
287289
return @interface interface(a) similar(a, elt, axes)
288290
end
289291

290292
# Fixes ambiguity error.
291293
function Base.similar(a::AnyAbstractBlockSparseArray, elt::Type, axes::Tuple{})
292-
# TODO: Use `@interface interface(a) similar(...)`.
293294
return @interface interface(a) similar(a, elt, axes)
294295
end
295296

@@ -301,7 +302,6 @@ function Base.similar(
301302
AbstractBlockedUnitRange{<:Integer},Vararg{AbstractBlockedUnitRange{<:Integer}}
302303
},
303304
)
304-
# TODO: Use `@interface interface(a) similar(...)`.
305305
return @interface interface(a) similar(a, elt, axes)
306306
end
307307

@@ -311,7 +311,6 @@ function Base.similar(
311311
elt::Type,
312312
axes::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
313313
)
314-
# TODO: Use `@interface interface(a) similar(...)`.
315314
return @interface interface(a) similar(a, elt, axes)
316315
end
317316

@@ -321,9 +320,17 @@ function Base.similar(
321320
elt::Type,
322321
axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
323322
)
324-
# TODO: Use `@interface interface(a) similar(...)`.
325323
return @interface interface(a) similar(a, elt, axes)
326324
end
325+
function Base.similar(a::AnyAbstractBlockSparseArray, elt::Type)
326+
return @interface interface(a) similar(a, elt, axes(a))
327+
end
328+
function Base.similar(
329+
a::AnyAbstractBlockSparseArray,
330+
axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
331+
)
332+
return @interface interface(a) similar(a, eltype(a), axes)
333+
end
327334

328335
# Fixes ambiguity errors with BlockArrays.
329336
function Base.similar(
@@ -343,7 +350,6 @@ end
343350
function Base.similar(
344351
a::AnyAbstractBlockSparseArray, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
345352
)
346-
# TODO: Use `@interface interface(a) similar(...)`.
347353
return @interface interface(a) similar(a, elt, axes)
348354
end
349355

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ function eachblockstoredindex(a::AbstractArray)
4242
return Block.(Tuple.(eachstoredindex(blocks(a))))
4343
end
4444

45+
function SparseArraysBase.isstored(a::AbstractArray, I1::Block{1}, Irest::Block{1}...)
46+
I = (I1, Irest...)
47+
return isstored(blocks(a), Int.(I)...)
48+
end
49+
function SparseArraysBase.isstored(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
50+
return isstored(a, Tuple(I)...)
51+
end
52+
4553
using DiagonalArrays: diagindices
4654
# Block version of `DiagonalArrays.diagindices`.
4755
function blockdiagindices(a::AbstractArray)

src/blocksparsearrayinterface/getunstoredblock.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ end
1111
ax = ntuple(N) do d
1212
return only(axes(f.axes[d][Block(I[d])]))
1313
end
14+
!isconcretetype(A) && return zero!(similar(Array{eltype(A),N}, ax))
1415
return zero!(similar(A, ax))
1516
end
1617
@inline function (f::GetUnstoredBlock)(

src/blocksparsearrayinterface/map.jl

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,38 @@ function union_eachblockstoredindex(as::AbstractArray...)
1818
return (map(eachblockstoredindex, as)...)
1919
end
2020

21+
# Get a view of a block assuming it is stored.
22+
function viewblock_stored(a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}) where {N}
23+
return blocks(a)[Int.(I)...]
24+
end
25+
function viewblock_stored(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
26+
return viewblock_stored(a, Tuple(I)...)
27+
end
28+
29+
using FillArrays: Zeros
30+
# Get a view of a block if it is stored, otherwise return a lazy zeros.
31+
function viewblock_or_zeros(a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}) where {N}
32+
if isstored(a, I...)
33+
return viewblock_stored(a, I...)
34+
else
35+
block_ax = map((ax, i) -> eachblockaxis(ax)[Int(i)], axes(a), I)
36+
return Zeros{eltype(a)}(block_ax)
37+
end
38+
end
39+
function viewblock_or_zeros(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
40+
return viewblock_or_zeros(a, Tuple(I)...)
41+
end
42+
43+
function map_block!(f, a_dest::AbstractArray, I::Block, a_srcs::AbstractArray...)
44+
a_srcs_I = map(a_src -> viewblock_or_zeros(a_src, I), a_srcs)
45+
if isstored(a_dest, I)
46+
a_dest[I] .= f.(a_srcs_I...)
47+
else
48+
a_dest[I] = Broadcast.broadcast_preserving_zero_d(f, a_srcs_I...)
49+
end
50+
return a_dest
51+
end
52+
2153
function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
2254
# TODO: This assumes element types are numbers, generalize this logic.
2355
f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
@@ -27,22 +59,7 @@ function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
2759
BlockRange(a_dest)
2860
end
2961
for I in Is
30-
# TODO: Use:
31-
# block_dest = @view a_dest[I]
32-
# or:
33-
# block_dest = @view! a_dest[I]
34-
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(I))...]
35-
# TODO: Use:
36-
# block_srcs = map(a_src -> @view(a_src[I]), a_srcs)
37-
block_srcs = map(a_srcs) do a_src
38-
return blocks_maybe_single(a_src)[Int.(Tuple(I))...]
39-
end
40-
# TODO: Use `map!!` to handle immutable blocks.
41-
map!(f, block_dest, block_srcs...)
42-
# Replace the entire block, handles initializing new blocks
43-
# or if blocks are immutable.
44-
# TODO: Use `a_dest[I] = block_dest`.
45-
blocks(a_dest)[Int.(Tuple(I))...] = block_dest
62+
map_block!(f, a_dest, I, a_srcs...)
4663
end
4764
return a_dest
4865
end
@@ -151,8 +168,12 @@ end
151168
function map_stored_blocks(f, a::AbstractArray)
152169
block_stored_indices = collect(eachblockstoredindex(a))
153170
if isempty(block_stored_indices)
171+
eltype_a′ = Base.promote_op(f, eltype(a))
154172
blocktype_a′ = Base.promote_op(f, blocktype(a))
155-
return BlockSparseArray{eltype(blocktype_a′),ndims(a),blocktype_a′}(undef, axes(a))
173+
eltype_a′′ = !isconcretetype(eltype_a′) ? Any : eltype_a′
174+
blocktype_a′′ =
175+
!isconcretetype(blocktype_a′) ? AbstractArray{eltype_a′′,ndims(a)} : blocktype_a′
176+
return BlockSparseArray{eltype_a′′,ndims(a),blocktype_a′′}(undef, axes(a))
156177
end
157178
stored_blocks = map(B -> f(@view!(a[B])), block_stored_indices)
158179
blocktype_a′ = eltype(stored_blocks)

src/factorizations/svd.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,21 @@ function MatrixAlgebraKit.default_svd_algorithm(
2222
return BlockPermutedDiagonalAlgorithm(alg)
2323
end
2424

25+
function output_type(
26+
::typeof(svd_compact!),
27+
A::Type{<:AbstractMatrix{T}},
28+
Alg::Type{<:MatrixAlgebraKit.AbstractAlgorithm},
29+
) where {T}
30+
USVᴴ = Base.promote_op(svd_compact!, A, Alg)
31+
!isconcretetype(USVᴴ) &&
32+
return Tuple{AbstractMatrix{T},AbstractMatrix{realtype(T)},AbstractMatrix{T}}
33+
return USVᴴ
34+
end
35+
2536
function similar_output(
2637
::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
2738
)
28-
BU, BS, BVᴴ = fieldtypes(Base.promote_op(svd_compact!, blocktype(A), typeof(alg.alg)))
39+
BU, BS, BVᴴ = fieldtypes(output_type(svd_compact!, blocktype(A), typeof(alg.alg)))
2940
U = similar(A, BlockType(BU), (axes(A, 1), S_axes[1]))
3041
S = similar(A, BlockType(BS), S_axes)
3142
Vᴴ = similar(A, BlockType(BVᴴ), (S_axes[2], axes(A, 2)))

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
44
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
55
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
66
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
7+
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
78
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
89
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
910
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
@@ -25,6 +26,7 @@ Aqua = "0.8"
2526
ArrayLayouts = "1"
2627
BlockArrays = "1"
2728
BlockSparseArrays = "0.7"
29+
DerivableInterfaces = "0.5"
2830
DiagonalArrays = "0.3"
2931
GPUArraysCore = "0.2"
3032
JLArrays = "0.2"

test/test_abstract_blocktype.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using Adapt: adapt
2+
using BlockArrays: Block
3+
using BlockSparseArrays: BlockSparseMatrix, blockstoredlength
4+
using JLArrays: JLArray
5+
using SparseArraysBase: storedlength
6+
using Test: @test, @test_broken, @testset
7+
8+
elts = (Float32, Float64, ComplexF32)
9+
arrayts = (Array, JLArray)
10+
@testset "Abstract block type (arraytype=$arrayt, eltype=$elt)" for arrayt in arrayts,
11+
elt in elts
12+
13+
dev = adapt(arrayt)
14+
a = BlockSparseMatrix{elt,AbstractMatrix{elt}}(undef, [2, 3], [2, 3])
15+
@test sprint(show, MIME"text/plain"(), a) isa String
16+
@test iszero(storedlength(a))
17+
@test iszero(blockstoredlength(a))
18+
a[Block(1, 1)] = dev(randn(elt, 2, 2))
19+
a[Block(2, 2)] = dev(randn(elt, 3, 3))
20+
@test !iszero(a[Block(1, 1)])
21+
@test a[Block(1, 1)] isa arrayt{elt,2}
22+
@test !iszero(a[Block(2, 2)])
23+
@test a[Block(2, 2)] isa arrayt{elt,2}
24+
@test iszero(a[Block(2, 1)])
25+
@test a[Block(2, 1)] isa Matrix{elt}
26+
@test iszero(a[Block(1, 2)])
27+
@test a[Block(1, 2)] isa Matrix{elt}
28+
29+
b = copy(a)
30+
@test Array(b) Array(a)
31+
32+
b = a + a
33+
@test Array(b) Array(a) + Array(a)
34+
35+
b = 3a
36+
@test Array(b) 3Array(a)
37+
38+
if arrayt === Array
39+
b = a * a
40+
@test Array(b) Array(a) * Array(a)
41+
else
42+
@test_broken a * a
43+
end
44+
end

0 commit comments

Comments
 (0)