Skip to content

Commit 08a4d38

Browse files
authored
Matrix functions with abstract block type (#155)
1 parent 47ad5ff commit 08a4d38

File tree

3 files changed

+84
-66
lines changed

3 files changed

+84
-66
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.18"
4+
version = "0.7.19"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractblocksparsearray/linearalgebra.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,14 @@ const MATRIX_FUNCTIONS_UNSTABLE = [
9393
]
9494

9595
function initialize_output_blocksparse(f::F, a::AbstractMatrix) where {F}
96-
B = Base.promote_op(f, blocktype(a))
97-
return similar(a, BlockType(B))
96+
blockt = Base.promote_op(f, blocktype(a))
97+
elt′ = Base.promote_op(f, eltype(a))
98+
blockt′ = if !(blockt <: AbstractMatrix{elt′}) || blockt === Union{}
99+
AbstractMatrix{elt′}
100+
else
101+
blockt
102+
end
103+
return similar(a, BlockType(blockt′))
98104
end
99105

100106
function matrix_function_blocksparse(f::F, a::AbstractMatrix; kwargs...) where {F}
@@ -117,8 +123,14 @@ end
117123
for f in MATRIX_FUNCTIONS_UNSTABLE
118124
@eval begin
119125
function initialize_output_blocksparse(::typeof($f), a::AbstractMatrix)
120-
B = similartype(blocktype(a), complex(eltype(a)))
121-
return similar(a, BlockType(B))
126+
elt′ = complex(eltype(a))
127+
blockt = Base.promote_op(similar, blocktype(a), elt′)
128+
blockt′ = if !(blockt <: AbstractMatrix{elt′}) || blockt === Union{}
129+
AbstractMatrix{elt′}
130+
else
131+
blockt
132+
end
133+
return similar(a, BlockType(blockt′))
122134
end
123135
end
124136
end

test/test_factorizations.jl

Lines changed: 67 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -34,79 +34,85 @@ using StableRNGs: StableRNG
3434
using Test: @inferred, @test, @test_broken, @test_throws, @testset
3535

3636
@testset "Matrix functions (T=$elt)" for elt in (Float32, Float64, ComplexF64)
37-
rng = StableRNG(123)
38-
a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3])
39-
a[Block(1, 1)] = randn(rng, elt, 2, 2)
40-
a[Block(2, 2)] = randn(rng, elt, 3, 3)
41-
MATRIX_FUNCTIONS = BlockSparseArrays.MATRIX_FUNCTIONS
42-
MATRIX_FUNCTIONS = [MATRIX_FUNCTIONS; [:inv, :pinv]]
43-
# Only works when real, also isn't defined in Julia 1.10.
44-
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt])
45-
MATRIX_FUNCTIONS_LOW_ACCURACY = [:acoth]
46-
for f in setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_LOW_ACCURACY)
47-
@eval begin
48-
fa = $f($a)
49-
@test Matrix(fa) $f(Matrix($a)) rtol = (eps(real($elt)))
50-
@test fa isa BlockSparseMatrix
51-
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
37+
for matrixt in (Matrix, AbstractMatrix)
38+
a = BlockSparseMatrix{elt,matrixt{elt}}(undef, [2, 3], [2, 3])
39+
rng = StableRNG(123)
40+
a[Block(1, 1)] = randn(rng, elt, 2, 2)
41+
a[Block(2, 2)] = randn(rng, elt, 3, 3)
42+
MATRIX_FUNCTIONS = BlockSparseArrays.MATRIX_FUNCTIONS
43+
MATRIX_FUNCTIONS = [MATRIX_FUNCTIONS; [:inv, :pinv]]
44+
# Only works when real, also isn't defined in Julia 1.10.
45+
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt])
46+
MATRIX_FUNCTIONS_LOW_ACCURACY = [:acoth]
47+
for f in setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_LOW_ACCURACY)
48+
@eval begin
49+
fa = $f($a)
50+
@test Matrix(fa) $f(Matrix($a)) rtol = (eps(real($elt)))
51+
@test fa isa BlockSparseMatrix
52+
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
53+
end
5254
end
53-
end
54-
for f in MATRIX_FUNCTIONS_LOW_ACCURACY
55-
@eval begin
56-
fa = $f($a)
57-
if !Sys.isapple() && ($elt <: Real)
58-
# `acoth` appears to be broken on this matrix on Windows and Ubuntu
59-
# for real matrices.
60-
@test_broken Matrix(fa) $f(Matrix($a)) rtol = eps(real($elt))
61-
else
62-
@test Matrix(fa) $f(Matrix($a)) rtol = eps(real($elt))
55+
for f in MATRIX_FUNCTIONS_LOW_ACCURACY
56+
@eval begin
57+
fa = $f($a)
58+
if !Sys.isapple() && ($elt <: Real)
59+
# `acoth` appears to be broken on this matrix on Windows and Ubuntu
60+
# for real matrices.
61+
@test_broken Matrix(fa) $f(Matrix($a)) rtol = eps(real($elt))
62+
else
63+
@test Matrix(fa) $f(Matrix($a)) rtol = eps(real($elt))
64+
end
65+
@test fa isa BlockSparseMatrix
66+
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
6367
end
64-
@test fa isa BlockSparseMatrix
65-
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
6668
end
6769
end
6870

6971
# Catch case of off-diagonal blocks.
70-
rng = StableRNG(123)
71-
a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3])
72-
a[Block(1, 1)] = randn(rng, elt, 2, 2)
73-
a[Block(1, 2)] = randn(rng, elt, 2, 3)
74-
for f in MATRIX_FUNCTIONS
75-
@eval begin
76-
@test_throws ArgumentError $f($a)
72+
for matrixt in (Matrix, AbstractMatrix)
73+
a = BlockSparseMatrix{elt,matrixt{elt}}(undef, [2, 3], [2, 3])
74+
rng = StableRNG(123)
75+
a[Block(1, 1)] = randn(rng, elt, 2, 2)
76+
a[Block(1, 2)] = randn(rng, elt, 2, 3)
77+
for f in BlockSparseArrays.MATRIX_FUNCTIONS
78+
@eval begin
79+
@test_throws ArgumentError $f($a)
80+
end
7781
end
7882
end
7983

8084
# Missing diagonal blocks.
81-
rng = StableRNG(123)
82-
a = BlockSparseMatrix{elt}(undef, [2, 3], [2, 3])
83-
a[Block(2, 2)] = randn(rng, elt, 3, 3)
84-
MATRIX_FUNCTIONS = BlockSparseArrays.MATRIX_FUNCTIONS
85-
# These functions involve inverses so they break when there are zeros on the diagonal.
86-
MATRIX_FUNCTIONS_SINGULAR = [
87-
:log, :acsc, :asec, :acot, :acsch, :asech, :acoth, :csc, :cot, :csch, :coth
88-
]
89-
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_SINGULAR)
90-
# Dense version is broken for some reason, investigate.
91-
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt])
92-
for f in MATRIX_FUNCTIONS
93-
@eval begin
94-
fa = $f($a)
95-
@test Matrix(fa) $f(Matrix($a)) rtol = (eps(real($elt)))
96-
@test fa isa BlockSparseMatrix
97-
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
85+
for matrixt in (Matrix, AbstractMatrix)
86+
a = BlockSparseMatrix{elt,matrixt{elt}}(undef, [2, 3], [2, 3])
87+
rng = StableRNG(123)
88+
a[Block(2, 2)] = randn(rng, elt, 3, 3)
89+
MATRIX_FUNCTIONS = BlockSparseArrays.MATRIX_FUNCTIONS
90+
# These functions involve inverses so they break when there are zeros on the diagonal.
91+
MATRIX_FUNCTIONS_SINGULAR = [
92+
:log, :acsc, :asec, :acot, :acsch, :asech, :acoth, :csc, :cot, :csch, :coth
93+
]
94+
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, MATRIX_FUNCTIONS_SINGULAR)
95+
# Dense version is broken for some reason, investigate.
96+
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt])
97+
for f in MATRIX_FUNCTIONS
98+
@eval begin
99+
fa = $f($a)
100+
@test Matrix(fa) $f(Matrix($a)) rtol = (eps(real($elt)))
101+
@test fa isa BlockSparseMatrix
102+
@test issetequal(eachblockstoredindex(fa), [Block(1, 1), Block(2, 2)])
103+
end
98104
end
99-
end
100105

101-
SINGULAR_EXCEPTION = if VERSION < v"1.11-"
102-
# A different exception is thrown in older versions of Julia.
103-
LinearAlgebra.LAPACKException
104-
else
105-
LinearAlgebra.SingularException
106-
end
107-
for f in setdiff(MATRIX_FUNCTIONS_SINGULAR, [:log])
108-
@eval begin
109-
@test_throws $SINGULAR_EXCEPTION $f($a)
106+
SINGULAR_EXCEPTION = if VERSION < v"1.11-"
107+
# A different exception is thrown in older versions of Julia.
108+
LinearAlgebra.LAPACKException
109+
else
110+
LinearAlgebra.SingularException
111+
end
112+
for f in setdiff(MATRIX_FUNCTIONS_SINGULAR, [:log])
113+
@eval begin
114+
@test_throws $SINGULAR_EXCEPTION $f($a)
115+
end
110116
end
111117
end
112118
end

0 commit comments

Comments
 (0)