Skip to content

Commit 4ee09d2

Browse files
authored
Generalize BlockIndexVector (#156)
1 parent 08a4d38 commit 4ee09d2

File tree

3 files changed

+100
-11
lines changed

3 files changed

+100
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.19"
4+
version = "0.7.20"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 98 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -216,22 +216,111 @@ function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::UnitRange{<:
216216
)
217217
end
218218

219-
struct BlockIndexVector{T<:Integer,I<:AbstractVector{T},TB<:Integer} <:
220-
AbstractVector{BlockIndex{1,Tuple{TB},Tuple{T}}}
221-
block::Block{1,TB}
219+
struct GenericBlockIndex{N,TI<:Tuple{Vararg{Integer,N}},Tα<:Tuple{Vararg{Any,N}}}
220+
I::TI
221+
α::Tα
222+
end
223+
@inline function GenericBlockIndex(a::NTuple{N,Block{1}}, b::Tuple) where {N}
224+
return GenericBlockIndex(Int.(a), b)
225+
end
226+
@inline function GenericBlockIndex(::Tuple{}, b::Tuple{})
227+
return GenericBlockIndex{0,Tuple{},Tuple{}}((), ())
228+
end
229+
@inline GenericBlockIndex(a::Integer, b) = GenericBlockIndex((a,), (b,))
230+
@inline GenericBlockIndex(a::Tuple, b) = GenericBlockIndex(a, (b,))
231+
@inline GenericBlockIndex(a::Integer, b::Tuple) = GenericBlockIndex((a,), b)
232+
@inline GenericBlockIndex() = GenericBlockIndex((), ())
233+
@inline GenericBlockIndex(a::Block, b::Tuple) = GenericBlockIndex(a.n, b)
234+
@inline GenericBlockIndex(a::Block, b) = GenericBlockIndex(a, (b,))
235+
@inline function GenericBlockIndex(
236+
I::Tuple{Vararg{Integer,N}}, α::Tuple{Vararg{Any,M}}
237+
) where {M,N}
238+
M <= N || throw(ArgumentError("number of indices must not exceed the number of blocks"))
239+
α2 = ntuple(k -> k <= M ? α[k] : 1, N)
240+
GenericBlockIndex(I, α2)
241+
end
242+
BlockArrays.block(b::GenericBlockIndex) = Block(b.I...)
243+
BlockArrays.blockindex(b::GenericBlockIndex{1}) = b.α[1]
244+
function GenericBlockIndex(indcs::Tuple{Vararg{GenericBlockIndex{1},N}}) where {N}
245+
GenericBlockIndex(block.(indcs), blockindex.(indcs))
246+
end
247+
function print_tuple_elements(io::IO, @nospecialize(t))
248+
if !isempty(t)
249+
print(io, t[1])
250+
for n in t[2:end]
251+
print(io, ", ", n)
252+
end
253+
end
254+
return nothing
255+
end
256+
function Base.show(io::IO, B::GenericBlockIndex)
257+
show(io, Block(B.I...))
258+
print(io, "[")
259+
print_tuple_elements(io, B.α)
260+
print(io, "]")
261+
return nothing
262+
end
263+
264+
using Base: @propagate_inbounds
265+
@propagate_inbounds function Base.getindex(b::AbstractVector, K::GenericBlockIndex{1})
266+
return b[Block(K.I[1])][K.α[1]]
267+
end
268+
@propagate_inbounds function Base.getindex(
269+
b::AbstractArray{T,N}, K::GenericBlockIndex{N}
270+
) where {T,N}
271+
return b[block(K)][K.α...]
272+
end
273+
@propagate_inbounds function Base.getindex(
274+
b::AbstractArray, K::GenericBlockIndex{1}, J::GenericBlockIndex{1}...
275+
)
276+
return b[GenericBlockIndex(tuple(K, J...))]
277+
end
278+
279+
function blockindextype(TB::Type{<:Integer}, TI::Vararg{Type{<:Integer},N}) where {N}
280+
return BlockIndex{N,NTuple{N,TB},Tuple{TI...}}
281+
end
282+
function blockindextype(TB::Type{<:Integer}, TI::Vararg{Type,N}) where {N}
283+
return GenericBlockIndex{N,NTuple{N,TB},Tuple{TI...}}
284+
end
285+
286+
struct BlockIndexVector{N,I<:NTuple{N,AbstractVector},TB<:Integer,BT} <: AbstractArray{BT,N}
287+
block::Block{N,TB}
222288
indices::I
289+
function BlockIndexVector(
290+
block::Block{N,TB}, indices::I
291+
) where {N,I<:NTuple{N,AbstractVector},TB<:Integer}
292+
BT = blockindextype(TB, eltype.(indices)...)
293+
return new{N,I,TB,BT}(block, indices)
294+
end
295+
end
296+
function BlockIndexVector(block::Block{1}, indices::AbstractVector)
297+
return BlockIndexVector(block, (indices,))
298+
end
299+
Base.size(a::BlockIndexVector) = length.(a.indices)
300+
function Base.getindex(a::BlockIndexVector{N}, I::Vararg{Integer,N}) where {N}
301+
return a.block[map((r, i) -> r[i], a.indices, I)...]
223302
end
224-
Base.length(a::BlockIndexVector) = length(a.indices)
225-
Base.size(a::BlockIndexVector) = (length(a),)
226-
BlockArrays.Block(a::BlockIndexVector) = a.block
227-
Base.getindex(a::BlockIndexVector, I::Integer) = Block(a)[a.indices[I]]
228-
Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy(a.indices))
303+
BlockArrays.block(b::BlockIndexVector) = b.block
304+
BlockArrays.Block(b::BlockIndexVector) = b.block
305+
306+
Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy.(a.indices))
307+
308+
using ArrayLayouts: LayoutArray
309+
@propagate_inbounds Base.getindex(b::AbstractArray{T,N}, K::BlockIndexVector{N}) where {T,N} = b[block(
310+
K
311+
)][K.indices...]
312+
@propagate_inbounds Base.getindex(b::LayoutArray{T,N}, K::BlockIndexVector{N}) where {T,N} = b[block(
313+
K
314+
)][K.indices...]
315+
@propagate_inbounds Base.getindex(b::LayoutArray{T,1}, K::BlockIndexVector{1}) where {T} = b[block(
316+
K
317+
)][K.indices...]
229318

230319
function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::AbstractArray{Bool})
231320
I_blocks = blocks(BlockedVector(I, blocklengths(a)))
232321
I′_blocks = map(eachindex(I_blocks)) do b
233322
I_b = findall(I_blocks[b])
234-
BlockIndexVector(Block(b), I_b)
323+
return BlockIndexVector(Block(b), I_b)
235324
end
236325
return mortar(filter(!isempty, I′_blocks))
237326
end

src/abstractblocksparsearray/views.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ to_block(I::BlockIndexRange{1}) = Block(I)
9595
to_block(I::BlockIndexVector) = Block(I)
9696
to_block_indices(I::Block{1}) = Colon()
9797
to_block_indices(I::BlockIndexRange{1}) = only(I.indices)
98-
to_block_indices(I::BlockIndexVector) = I.indices
98+
to_block_indices(I::BlockIndexVector) = only(I.indices)
9999

100100
function Base.view(
101101
a::AbstractBlockSparseArray{<:Any,N},

0 commit comments

Comments
 (0)