Skip to content

Commit f991254

Browse files
authored
More generalizations for generic block types (#144)
1 parent 51fceec commit f991254

File tree

4 files changed

+16
-11
lines changed

4 files changed

+16
-11
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.9"
4+
version = "0.7.10"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -44,7 +44,7 @@ SparseArraysBase = "0.5"
4444
SplitApplyCombine = "1.2.3"
4545
TensorAlgebra = "0.3.2"
4646
Test = "1.10"
47-
TypeParameterAccessors = "0.4"
47+
TypeParameterAccessors = "0.4.1"
4848
julia = "1.10"
4949

5050
[extras]

src/blocksparsearray/blocksparsearray.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using BlockArrays:
99
using DerivableInterfaces: @interface
1010
using Dictionaries: Dictionary
1111
using SparseArraysBase: SparseArrayDOK
12+
using TypeParameterAccessors: similartype
1213

1314
"""
1415
SparseArrayDOK{T}(undef_blocks, axes)
@@ -173,7 +174,9 @@ end
173174
function BlockSparseArray{T,N}(
174175
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange{<:Integer},N}}
175176
) where {T,N}
176-
return BlockSparseArray{T,N,Array{T,N}}(undef, axes)
177+
axt = Tuple{blockaxistype.(axes)...}
178+
A = similartype(Array{T}, axt)
179+
return BlockSparseArray{T,N,A}(undef, axes)
177180
end
178181

179182
function BlockSparseArray{T,N}(

src/blocksparsearrayinterface/arraylayouts.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,18 @@ using LinearAlgebra: LinearAlgebra, dot, mul!
1111
return a_dest
1212
end
1313

14+
function DerivableInterfaces.interface(m::MulAdd)
15+
return interface(m.A, m.B, m.C)
16+
end
17+
1418
function ArrayLayouts.materialize!(
1519
m::MatMulMatAdd{
1620
<:BlockLayout{<:SparseLayout},
1721
<:BlockLayout{<:SparseLayout},
1822
<:BlockLayout{<:SparseLayout},
1923
},
2024
)
21-
α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C
22-
@interface BlockSparseArrayInterface() muladd!(m.α, m.A, m.B, m.β, m.C)
25+
@interface interface(m) muladd!(m.α, m.A, m.B, m.β, m.C)
2326
return m.C
2427
end
2528
function ArrayLayouts.materialize!(
@@ -29,7 +32,7 @@ function ArrayLayouts.materialize!(
2932
<:BlockLayout{<:SparseLayout},
3033
},
3134
)
32-
@interface BlockSparseArrayInterface() matmul!(m)
35+
@interface interface(m) matmul!(m)
3336
return m.C
3437
end
3538

@@ -42,5 +45,5 @@ end
4245
end
4346

4447
function Base.copy(d::Dot{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout}})
45-
return @interface BlockSparseArrayInterface() dot(d.A, d.B)
48+
return @interface interface(d.A, d.B) dot(d.A, d.B)
4649
end

src/factorizations/svd.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
using DiagonalArrays: diagonaltype
12
using MatrixAlgebraKit:
23
MatrixAlgebraKit, check_input, default_svd_algorithm, svd_compact!, svd_full!
4+
using TypeParameterAccessors: realtype
35

46
"""
57
BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm)
@@ -24,10 +26,7 @@ function similar_output(
2426
::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
2527
)
2628
U = similar(A, axes(A, 1), S_axes[1])
27-
T = real(eltype(A))
28-
# TODO: this should be replaced with a more general similar function that can handle setting
29-
# the blocktype and element type - something like S = similar(A, BlockType(...))
30-
S = BlockSparseMatrix{T,Diagonal{T,Vector{T}}}(undef, S_axes)
29+
S = similar(A, BlockType(diagonaltype(realtype(blocktype(A)))), S_axes)
3130
Vt = similar(A, S_axes[2], axes(A, 2))
3231
return U, S, Vt
3332
end

0 commit comments

Comments
 (0)