Skip to content

Commit 1ee8b12

Browse files
authored
Make block sparse SVD more generic (#147)
1 parent 311f391 commit 1ee8b12

File tree

5 files changed

+48
-27
lines changed

5 files changed

+48
-27
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.12"
4+
version = "0.7.13"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ include("blocksparsearray/blockdiagonalarray.jl")
4545
include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl")
4646

4747
# factorizations
48+
include("factorizations/tensorproducts.jl")
4849
include("factorizations/svd.jl")
4950
include("factorizations/truncation.jl")
5051
include("factorizations/qr.jl")

src/blocksparsearrayinterface/getunstoredblock.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@ struct GetUnstoredBlock{Axes}
66
end
77

88
@inline function (f::GetUnstoredBlock)(
9-
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
10-
) where {N}
11-
# TODO: Make sure this works for sparse or block sparse blocks, immutable
12-
# blocks, diagonal blocks, etc.!
13-
b_ax = ntuple(ndims(a)) do d
9+
::Type{<:AbstractArray{A,N}}, I::Vararg{Int,N}
10+
) where {A,N}
11+
ax = ntuple(N) do d
1412
return only(axes(f.axes[d][Block(I[d])]))
1513
end
16-
b = similar(eltype(a), b_ax)
17-
zero!(b)
18-
return b
14+
return zero!(similar(A, ax))
15+
end
16+
@inline function (f::GetUnstoredBlock)(
17+
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
18+
) where {N}
19+
return f(typeof(a), I...)
1920
end
2021
# TODO: Use `Base.to_indices`.
2122
@inline function (f::GetUnstoredBlock)(
@@ -25,16 +26,17 @@ end
2526
end
2627

2728
# TODO: this is a hack and is also type-unstable
29+
using LinearAlgebra: Diagonal
30+
using TypeParameterAccessors: similartype
2831
function (f::GetUnstoredBlock)(
29-
a::AbstractMatrix{LinearAlgebra.Diagonal{T,V}}, I::Vararg{Int,2}
30-
) where {T,V}
31-
b_size = ntuple(ndims(a)) do d
32-
return length(f.axes[d][Block(I[d])])
32+
::Type{<:AbstractMatrix{<:Diagonal{<:Any,V}}}, I::Vararg{Int,2}
33+
) where {V}
34+
ax = ntuple(2) do d
35+
return only(axes(f.axes[d][Block(I[d])]))
3336
end
34-
if I[1] == I[2]
35-
diag = zero!(similar(V, b_size[1]))
36-
return LinearAlgebra.Diagonal{T,V}(diag)
37+
if allequal(I)
38+
return Diagonal(zero!(similar(V, first(ax))))
3739
else
38-
return zeros(T, b_size...)
40+
return zero!(similar(similartype(V, typeof(ax)), ax))
3941
end
4042
end

src/factorizations/svd.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ end
2525
function similar_output(
2626
::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
2727
)
28-
U = similar(A, axes(A, 1), S_axes[1])
29-
S = similar(A, BlockType(diagonaltype(realtype(blocktype(A)))), S_axes)
30-
Vt = similar(A, S_axes[2], axes(A, 2))
31-
return U, S, Vt
28+
BU, BS, BVᴴ = fieldtypes(Base.promote_op(svd_compact!, blocktype(A), typeof(alg.alg)))
29+
U = similar(A, BlockType(BU), (axes(A, 1), S_axes[1]))
30+
S = similar(A, BlockType(BS), S_axes)
31+
Vᴴ = similar(A, BlockType(BVᴴ), (S_axes[2], axes(A, 2)))
32+
return U, S, Vᴴ
3233
end
3334

3435
function MatrixAlgebraKit.initialize_output(
@@ -48,19 +49,17 @@ function MatrixAlgebraKit.initialize_output(
4849
bcolIs = Int.(last.(Tuple.(bIs)))
4950
for bI in eachblockstoredindex(A)
5051
row, col = Int.(Tuple(bI))
51-
len = minimum(length, (brows[row], bcols[col]))
52-
u_axes[col] = brows[row][Base.OneTo(len)]
53-
v_axes[col] = bcols[col][Base.OneTo(len)]
52+
u_axes[col] = infimum(brows[row], bcols[col])
53+
v_axes[col] = infimum(bcols[col], brows[row])
5454
end
5555

5656
# fill in values for blocks that aren't present, pairing them in order of occurence
5757
# this is a convention, which at least gives the expected results for blockdiagonal
5858
emptyrows = setdiff(1:bm, browIs)
5959
emptycols = setdiff(1:bn, bcolIs)
6060
for (row, col) in zip(emptyrows, emptycols)
61-
len = minimum(length, (brows[row], bcols[col]))
62-
u_axes[col] = brows[row][Base.OneTo(len)]
63-
v_axes[col] = bcols[col][Base.OneTo(len)]
61+
u_axes[col] = infimum(brows[row], bcols[col])
62+
v_axes[col] = infimum(bcols[col], brows[row])
6463
end
6564

6665
u_axis = mortar_axis(u_axes)

src/factorizations/tensorproducts.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
function infimum(r1::AbstractUnitRange, r2::AbstractUnitRange)
2+
(isone(first(r1)) && isone(first(r2))) ||
3+
throw(ArgumentError("infimum only defined for ranges starting at 1"))
4+
if length(r1) length(r2)
5+
return r1
6+
else
7+
return r1[r2]
8+
end
9+
end
10+
11+
function supremum(r1::AbstractUnitRange, r2::AbstractUnitRange)
12+
(isone(first(r1)) && isone(first(r2))) ||
13+
throw(ArgumentError("supremum only defined for ranges starting at 1"))
14+
if length(r1) length(r2)
15+
return r1
16+
else
17+
return r2
18+
end
19+
end

0 commit comments

Comments
 (0)