Skip to content

Commit 47ad5ff

Browse files
authored
Fix matrix factorizations of block sparse arrays with abstract block type (#154)
1 parent 33330bd commit 47ad5ff

File tree

8 files changed

+192
-66
lines changed

8 files changed

+192
-66
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.17"
4+
version = "0.7.18"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/factorizations/eig.jl

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ using MatrixAlgebraKit:
1818

1919
for f in [:default_eig_algorithm, :default_eigh_algorithm]
2020
@eval begin
21-
function MatrixAlgebraKit.$f(arrayt::Type{<:AbstractBlockSparseMatrix}; kwargs...)
22-
alg = $f(blocktype(arrayt); kwargs...)
23-
return BlockPermutedDiagonalAlgorithm(alg)
21+
function MatrixAlgebraKit.$f(::Type{<:AbstractBlockSparseMatrix}; kwargs...)
22+
return BlockPermutedDiagonalAlgorithm() do block
23+
return $f(block; kwargs...)
24+
end
2425
end
2526
end
2627
end
@@ -45,12 +46,23 @@ function MatrixAlgebraKit.check_input(
4546
return nothing
4647
end
4748

49+
function output_type(f::typeof(eig_full!), A::Type{<:AbstractMatrix{T}}) where {T}
50+
DV = Base.promote_op(f, A)
51+
!isconcretetype(DV) && return Tuple{AbstractMatrix{complex(T)},AbstractMatrix{complex(T)}}
52+
return DV
53+
end
54+
function output_type(f::typeof(eigh_full!), A::Type{<:AbstractMatrix{T}}) where {T}
55+
DV = Base.promote_op(f, A)
56+
!isconcretetype(DV) && return Tuple{AbstractMatrix{real(T)},AbstractMatrix{T}}
57+
return DV
58+
end
59+
4860
for f in [:eig_full!, :eigh_full!]
4961
@eval begin
5062
function MatrixAlgebraKit.initialize_output(
5163
::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
5264
)
53-
Td, Tv = fieldtypes(Base.promote_op($f, blocktype(A), typeof(alg.alg)))
65+
Td, Tv = fieldtypes(output_type($f, blocktype(A)))
5466
D = similar(A, BlockType(Td))
5567
V = similar(A, BlockType(Tv))
5668
return (D, V)
@@ -60,7 +72,9 @@ for f in [:eig_full!, :eigh_full!]
6072
)
6173
check_input($f, A, (D, V))
6274
for I in eachstoredblockdiagindex(A)
63-
D[I], V[I] = $f(@view(A[I]), alg.alg)
75+
block = @view!(A[I])
76+
block_alg = block_algorithm(alg, block)
77+
D[I], V[I] = $f(block, block_alg)
6478
end
6579
for I in eachunstoredblockdiagindex(A)
6680
# TODO: Support setting `LinearAlgebra.I` directly, and/or
@@ -72,19 +86,31 @@ for f in [:eig_full!, :eigh_full!]
7286
end
7387
end
7488

89+
function output_type(f::typeof(eig_vals!), A::Type{<:AbstractMatrix{T}}) where {T}
90+
D = Base.promote_op(f, A)
91+
!isconcretetype(D) && return AbstractVector{complex(T)}
92+
return D
93+
end
94+
function output_type(f::typeof(eigh_vals!), A::Type{<:AbstractMatrix{T}}) where {T}
95+
D = Base.promote_op(f, A)
96+
!isconcretetype(D) && return AbstractVector{real(T)}
97+
return D
98+
end
99+
75100
for f in [:eig_vals!, :eigh_vals!]
76101
@eval begin
77102
function MatrixAlgebraKit.initialize_output(
78103
::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
79104
)
80-
T = Base.promote_op($f, blocktype(A), typeof(alg.alg))
105+
T = output_type($f, blocktype(A))
81106
return similar(A, BlockType(T), axes(A, 1))
82107
end
83108
function MatrixAlgebraKit.$f(
84109
A::AbstractBlockSparseMatrix, D, alg::BlockPermutedDiagonalAlgorithm
85110
)
86111
for I in eachblockstoredindex(A)
87-
D[I] = $f(@view!(A[I]), alg.alg)
112+
block = @view!(A[I])
113+
D[I] = $f(block, block_algorithm(alg, block))
88114
end
89115
return D
90116
end

src/factorizations/lq.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ using MatrixAlgebraKit: MatrixAlgebraKit, default_lq_algorithm, lq_compact!, lq_
33
function MatrixAlgebraKit.default_lq_algorithm(
44
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
55
)
6-
alg = default_lq_algorithm(blocktype(A); kwargs...)
7-
return BlockPermutedDiagonalAlgorithm(alg)
6+
return BlockPermutedDiagonalAlgorithm() do block
7+
return default_lq_algorithm(block; kwargs...)
8+
end
89
end
910

1011
function similar_output(
@@ -58,8 +59,10 @@ function MatrixAlgebraKit.initialize_output(
5859
# allocate output
5960
for bI in eachblockstoredindex(A)
6061
brow, bcol = Tuple(bI)
62+
block = @view!(A[bI])
63+
block_alg = block_algorithm(alg, block)
6164
L[brow, brow], Q[brow, bcol] = MatrixAlgebraKit.initialize_output(
62-
lq_compact!, @view!(A[bI]), alg.alg
65+
lq_compact!, block, block_alg
6366
)
6467
end
6568

@@ -105,8 +108,10 @@ function MatrixAlgebraKit.initialize_output(
105108
# allocate output
106109
for bI in eachblockstoredindex(A)
107110
brow, bcol = Tuple(bI)
111+
block = @view!(A[bI])
112+
block_alg = block_algorithm(alg, block)
108113
L[brow, brow], Q[brow, bcol] = MatrixAlgebraKit.initialize_output(
109-
lq_full!, @view!(A[bI]), alg.alg
114+
lq_full!, block, block_alg
110115
)
111116
end
112117

@@ -154,7 +159,9 @@ function MatrixAlgebraKit.lq_compact!(
154159
for bI in eachblockstoredindex(A)
155160
brow, bcol = Tuple(bI)
156161
lq = (@view!(L[brow, brow]), @view!(Q[brow, bcol]))
157-
lq′ = lq_compact!(@view!(A[bI]), lq, alg.alg)
162+
block = @view!(A[bI])
163+
block_alg = block_algorithm(alg, block)
164+
lq′ = lq_compact!(block, lq, block_alg)
158165
@assert lq === lq′ "lq_compact! might not be in-place"
159166
end
160167

@@ -183,7 +190,9 @@ function MatrixAlgebraKit.lq_full!(
183190
for bI in eachblockstoredindex(A)
184191
brow, bcol = Tuple(bI)
185192
lq = (@view!(L[brow, brow]), @view!(Q[brow, bcol]))
186-
lq′ = lq_full!(@view!(A[bI]), lq, alg.alg)
193+
block = @view!(A[bI])
194+
block_alg = block_algorithm(alg, block)
195+
lq′ = lq_full!(block, lq, block_alg)
187196
@assert lq === lq′ "lq_full! might not be in-place"
188197
end
189198

src/factorizations/qr.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ using MatrixAlgebraKit:
22
MatrixAlgebraKit, default_qr_algorithm, lq_compact!, lq_full!, qr_compact!, qr_full!
33

44
function MatrixAlgebraKit.default_qr_algorithm(
5-
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
5+
::Type{<:AbstractBlockSparseMatrix}; kwargs...
66
)
7-
alg = default_qr_algorithm(blocktype(A); kwargs...)
8-
return BlockPermutedDiagonalAlgorithm(alg)
7+
return BlockPermutedDiagonalAlgorithm() do block
8+
return default_qr_algorithm(block; kwargs...)
9+
end
910
end
1011

1112
function similar_output(
@@ -59,8 +60,10 @@ function MatrixAlgebraKit.initialize_output(
5960
# allocate output
6061
for bI in eachblockstoredindex(A)
6162
brow, bcol = Tuple(bI)
63+
block = @view!(A[bI])
64+
block_alg = block_algorithm(alg, block)
6265
Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit.initialize_output(
63-
qr_compact!, @view!(A[bI]), alg.alg
66+
qr_compact!, block, block_alg
6467
)
6568
end
6669

@@ -106,8 +109,10 @@ function MatrixAlgebraKit.initialize_output(
106109
# allocate output
107110
for bI in eachblockstoredindex(A)
108111
brow, bcol = Tuple(bI)
112+
block = @view!(A[bI])
113+
block_alg = block_algorithm(alg, block)
109114
Q[brow, bcol], R[bcol, bcol] = MatrixAlgebraKit.initialize_output(
110-
qr_full!, @view!(A[bI]), alg.alg
115+
qr_full!, block, block_alg
111116
)
112117
end
113118

@@ -155,7 +160,9 @@ function MatrixAlgebraKit.qr_compact!(
155160
for bI in eachblockstoredindex(A)
156161
brow, bcol = Tuple(bI)
157162
qr = (@view!(Q[brow, bcol]), @view!(R[bcol, bcol]))
158-
qr′ = qr_compact!(@view!(A[bI]), qr, alg.alg)
163+
block = @view!(A[bI])
164+
block_alg = block_algorithm(alg, block)
165+
qr′ = qr_compact!(block, qr, block_alg)
159166
@assert qr === qr′ "qr_compact! might not be in-place"
160167
end
161168

@@ -184,7 +191,9 @@ function MatrixAlgebraKit.qr_full!(
184191
for bI in eachblockstoredindex(A)
185192
brow, bcol = Tuple(bI)
186193
qr = (@view!(Q[brow, bcol]), @view!(R[bcol, bcol]))
187-
qr′ = qr_full!(@view!(A[bI]), qr, alg.alg)
194+
block = @view!(A[bI])
195+
block_alg = block_algorithm(alg, block)
196+
qr′ = qr_full!(block, qr, block_alg)
188197
@assert qr === qr′ "qr_full! might not be in-place"
189198
end
190199

src/factorizations/svd.jl

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,26 @@ A wrapper for `MatrixAlgebraKit.AbstractAlgorithm` that implements the wrapped a
1010
a block-by-block basis, which is possible if the input matrix is a block-diagonal matrix or
1111
a block permuted block-diagonal matrix.
1212
"""
13-
struct BlockPermutedDiagonalAlgorithm{A<:MatrixAlgebraKit.AbstractAlgorithm} <:
14-
MatrixAlgebraKit.AbstractAlgorithm
15-
alg::A
13+
struct BlockPermutedDiagonalAlgorithm{F} <: MatrixAlgebraKit.AbstractAlgorithm
14+
falg::F
15+
end
16+
function block_algorithm(alg::BlockPermutedDiagonalAlgorithm, a::AbstractMatrix)
17+
return block_algorithm(alg, typeof(a))
18+
end
19+
function block_algorithm(alg::BlockPermutedDiagonalAlgorithm, A::Type{<:AbstractMatrix})
20+
return alg.falg(A)
1621
end
1722

1823
function MatrixAlgebraKit.default_svd_algorithm(
19-
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
24+
::Type{<:AbstractBlockSparseMatrix}; kwargs...
2025
)
21-
alg = default_svd_algorithm(blocktype(A); kwargs...)
22-
return BlockPermutedDiagonalAlgorithm(alg)
26+
return BlockPermutedDiagonalAlgorithm() do block
27+
return default_svd_algorithm(block; kwargs...)
28+
end
2329
end
2430

25-
function output_type(
26-
::typeof(svd_compact!),
27-
A::Type{<:AbstractMatrix{T}},
28-
Alg::Type{<:MatrixAlgebraKit.AbstractAlgorithm},
29-
) where {T}
30-
USVᴴ = Base.promote_op(svd_compact!, A, Alg)
31+
function output_type(::typeof(svd_compact!), A::Type{<:AbstractMatrix{T}}) where {T}
32+
USVᴴ = Base.promote_op(svd_compact!, A)
3133
!isconcretetype(USVᴴ) &&
3234
return Tuple{AbstractMatrix{T},AbstractMatrix{realtype(T)},AbstractMatrix{T}}
3335
return USVᴴ
@@ -36,7 +38,7 @@ end
3638
function similar_output(
3739
::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
3840
)
39-
BU, BS, BVᴴ = fieldtypes(output_type(svd_compact!, blocktype(A), typeof(alg.alg)))
41+
BU, BS, BVᴴ = fieldtypes(output_type(svd_compact!, blocktype(A)))
4042
U = similar(A, BlockType(BU), (axes(A, 1), S_axes[1]))
4143
S = similar(A, BlockType(BS), S_axes)
4244
Vᴴ = similar(A, BlockType(BVᴴ), (S_axes[2], axes(A, 2)))
@@ -81,8 +83,10 @@ function MatrixAlgebraKit.initialize_output(
8183
# allocate output
8284
for bI in eachblockstoredindex(A)
8385
brow, bcol = Tuple(bI)
86+
block = @view!(A[bI])
87+
block_alg = block_algorithm(alg, block)
8488
U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit.initialize_output(
85-
svd_compact!, @view!(A[bI]), alg.alg
89+
svd_compact!, block, block_alg
8690
)
8791
end
8892

@@ -140,8 +144,10 @@ function MatrixAlgebraKit.initialize_output(
140144
# allocate output
141145
for bI in eachblockstoredindex(A)
142146
brow, bcol = Tuple(bI)
147+
block = @view!(A[bI])
148+
block_alg = block_algorithm(alg, block)
143149
U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit.initialize_output(
144-
svd_full!, @view!(A[bI]), alg.alg
150+
svd_full!, block, block_alg
145151
)
146152
end
147153

@@ -196,7 +202,9 @@ function MatrixAlgebraKit.svd_compact!(
196202
for bI in eachblockstoredindex(A)
197203
brow, bcol = Tuple(bI)
198204
usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol]))
199-
usvᴴ′ = svd_compact!(@view!(A[bI]), usvᴴ, alg.alg)
205+
block = @view!(A[bI])
206+
block_alg = block_algorithm(alg, block)
207+
usvᴴ′ = svd_compact!(block, usvᴴ, block_alg)
200208
@assert usvᴴ === usvᴴ′ "svd_compact! might not be in-place"
201209
end
202210

@@ -226,7 +234,9 @@ function MatrixAlgebraKit.svd_full!(
226234
for bI in eachblockstoredindex(A)
227235
brow, bcol = Tuple(bI)
228236
usvᴴ = (@view!(U[brow, bcol]), @view!(S[bcol, bcol]), @view!(Vᴴ[bcol, bcol]))
229-
usvᴴ′ = svd_full!(@view!(A[bI]), usvᴴ, alg.alg)
237+
block = @view!(A[bI])
238+
block_alg = block_algorithm(alg, block)
239+
usvᴴ′ = svd_full!(block, usvᴴ, block_alg)
230240
@assert usvᴴ === usvᴴ′ "svd_full! might not be in-place"
231241
end
232242

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ DiagonalArrays = "0.3"
3131
GPUArraysCore = "0.2"
3232
JLArrays = "0.2"
3333
LinearAlgebra = "1"
34-
MatrixAlgebraKit = "0.2"
34+
MatrixAlgebraKit = "0.2.5"
3535
Random = "1"
3636
SafeTestsets = "0.1"
3737
SparseArraysBase = "0.5.11"

0 commit comments

Comments
 (0)