Skip to content

Commit 37b5183

Browse files
authored
More general block types in broadcast style (#145)
1 parent f991254 commit 37b5183

File tree

5 files changed

+70
-13
lines changed

5 files changed

+70
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.10"
4+
version = "0.7.11"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractblocksparsearray/broadcast.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using BlockArrays: AbstractBlockedUnitRange, BlockSlice
2-
using Base.Broadcast: Broadcast
2+
using Base.Broadcast: Broadcast, BroadcastStyle
33

44
function Broadcast.BroadcastStyle(arraytype::Type{<:AnyAbstractBlockSparseArray})
5-
return BlockSparseArrayStyle{ndims(arraytype)}()
5+
return BlockSparseArrayStyle(BroadcastStyle(blocktype(arraytype)))
66
end
77

88
# Fix ambiguity error with `BlockArrays`.

src/blocksparsearray/blocksparsearray.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,24 @@ function BlockSparseArray{T,N,A}(
171171
return BlockSparseArray{T,N,A}(undef, (dim1, dim_rest...))
172172
end
173173

174+
function similartype_unchecked(
175+
A::Type{<:AbstractArray{T}}, axt::Type{<:Tuple{Vararg{Any,N}}}
176+
) where {T,N}
177+
A′ = Base.promote_op(similar, A, axt)
178+
return !isconcretetype(A′) ? Array{T,N} : A′
179+
end
180+
174181
function BlockSparseArray{T,N}(
175182
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange{<:Integer},N}}
176183
) where {T,N}
177184
axt = Tuple{blockaxistype.(axes)...}
178-
A = similartype(Array{T}, axt)
185+
# Ideally we would use:
186+
# ```julia
187+
# A = similartype(Array{T}, axt)
188+
# ```
189+
# but that doesn't work when `similar` isn't defined or
190+
# isn't type stable.
191+
A = similartype_unchecked(Array{T}, axt)
179192
return BlockSparseArray{T,N,A}(undef, axes)
180193
end
181194

src/blocksparsearrayinterface/broadcast.jl

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,44 @@
1-
using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted
1+
using Base.Broadcast:
2+
Broadcast, BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted
23
using GPUArraysCore: @allowscalar
34
using MapBroadcast: Mapped
45
using DerivableInterfaces: DerivableInterfaces, @interface
56

6-
abstract type AbstractBlockSparseArrayStyle{N} <: AbstractArrayStyle{N} end
7+
abstract type AbstractBlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <:
8+
AbstractArrayStyle{N} end
79

8-
function DerivableInterfaces.interface(::Type{<:AbstractBlockSparseArrayStyle})
9-
return BlockSparseArrayInterface()
10+
blockstyle(::AbstractBlockSparseArrayStyle{N,B}) where {N,B<:AbstractArrayStyle{N}} = B()
11+
12+
function Broadcast.BroadcastStyle(
13+
style1::AbstractBlockSparseArrayStyle, style2::AbstractBlockSparseArrayStyle
14+
)
15+
style = Broadcast.result_style(blockstyle(style1), blockstyle(style2))
16+
return BlockSparseArrayStyle(style)
1017
end
1118

12-
struct BlockSparseArrayStyle{N} <: AbstractBlockSparseArrayStyle{N} end
19+
function DerivableInterfaces.interface(
20+
::Type{<:AbstractBlockSparseArrayStyle{N,B}}
21+
) where {N,B<:AbstractArrayStyle{N}}
22+
return BlockSparseArrayInterface(interface(B))
23+
end
1324

14-
# Define for new sparse array types.
15-
# function Broadcast.BroadcastStyle(arraytype::Type{<:MyBlockSparseArray})
16-
# return BlockSparseArrayStyle{ndims(arraytype)}()
17-
# end
25+
struct BlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <:
26+
AbstractBlockSparseArrayStyle{N,B}
27+
blockstyle::B
28+
end
29+
function BlockSparseArrayStyle{N}(blockstyle::AbstractArrayStyle{N}) where {N}
30+
return BlockSparseArrayStyle{N,typeof(blockstyle)}(blockstyle)
31+
end
1832

33+
function BlockSparseArrayStyle{N,B}() where {N,B<:AbstractArrayStyle{N}}
34+
return BlockSparseArrayStyle{N,B}(B())
35+
end
36+
BlockSparseArrayStyle{N}() where {N} = BlockSparseArrayStyle{N}(DefaultArrayStyle{N}())
1937
BlockSparseArrayStyle(::Val{N}) where {N} = BlockSparseArrayStyle{N}()
2038
BlockSparseArrayStyle{M}(::Val{N}) where {M,N} = BlockSparseArrayStyle{N}()
39+
function BlockSparseArrayStyle{M,B}(::Val{N}) where {M,B<:AbstractArrayStyle{M},N}
40+
return BlockSparseArrayStyle{N}(B(Val(N)))
41+
end
2142

2243
Broadcast.BroadcastStyle(a::BlockSparseArrayStyle, ::DefaultArrayStyle{0}) = a
2344
function Broadcast.BroadcastStyle(

test/test_basics.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ using BlockSparseArrays:
3333
eachblockstoredindex,
3434
eachstoredblock,
3535
eachstoredblockdiagindex,
36+
similartype_unchecked,
3637
sparsemortar,
3738
view!
3839
using GPUArraysCore: @allowscalar
@@ -44,6 +45,28 @@ using TestExtras: @constinferred
4445
using TypeParameterAccessors: TypeParameterAccessors, Position
4546
include("TestBlockSparseArraysUtils.jl")
4647

48+
@testset "similartype_unchecked" begin
49+
@test @constinferred(similartype_unchecked(Array{Float32}, NTuple{2,Int})) ===
50+
Matrix{Float32}
51+
@test @constinferred(similartype_unchecked(Array{Float32}, NTuple{2,Base.OneTo{Int}})) ===
52+
Matrix{Float32}
53+
if VERSION < v"1.11-"
54+
# Not type stable in Julia 1.10.
55+
@test similartype_unchecked(AbstractArray{Float32}, NTuple{2,Int}) === Matrix{Float32}
56+
@test similartype_unchecked(JLArray{Float32}, NTuple{2,Int}) === JLMatrix{Float32}
57+
@test similartype_unchecked(JLArray{Float32}, NTuple{2,Base.OneTo{Int}}) ===
58+
JLMatrix{Float32}
59+
else
60+
@test @constinferred(similartype_unchecked(AbstractArray{Float32}, NTuple{2,Int})) ===
61+
Matrix{Float32}
62+
@test @constinferred(similartype_unchecked(JLArray{Float32}, NTuple{2,Int})) ===
63+
JLMatrix{Float32}
64+
@test @constinferred(
65+
similartype_unchecked(JLArray{Float32}, NTuple{2,Base.OneTo{Int}})
66+
) === JLMatrix{Float32}
67+
end
68+
end
69+
4770
arrayts = (Array, JLArray)
4871
@testset "BlockSparseArrays (arraytype=$arrayt, eltype=$elt)" for arrayt in arrayts,
4972
elt in (Float32, Float64, Complex{Float32}, Complex{Float64})

0 commit comments

Comments
 (0)