Skip to content

Views of KroneckerArray #27

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.20"
version = "0.1.21"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -23,7 +23,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
[compat]
Adapt = "4.3.0"
BlockArrays = "1.6"
BlockSparseArrays = "0.7.20"
BlockSparseArrays = "0.7.21"
DerivableInterfaces = "0.5.0"
DiagonalArrays = "0.3.5"
FillArrays = "1.13.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ module KroneckerArraysBlockSparseArraysExt
using BlockArrays: Block
using BlockSparseArrays: BlockIndexVector, GenericBlockIndex
using KroneckerArrays: CartesianPair, CartesianProduct
function Base.getindex(b::Block, I1::CartesianPair, Irest::CartesianPair...)
function Base.getindex(
b::Block,
I1::Union{CartesianPair,CartesianProduct},
Irest::Union{CartesianPair,CartesianProduct}...,
)
return GenericBlockIndex(b, (I1, Irest...))
end
function Base.getindex(b::Block, I1::CartesianProduct, Irest::CartesianProduct...)
Expand Down
1 change: 1 addition & 0 deletions src/fillarrays/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatr
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}

_getindex(a::Eye, I1::Colon, I2::Colon) = a
_view(a::Eye, I1::Colon, I2::Colon) = a

# Like `adapt` but preserves `Eye`.
_adapt(to, a::Eye) = a
Expand Down
11 changes: 11 additions & 0 deletions src/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,17 @@ end
# Fix ambigiuity error.
Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[]

# Allow customizing for `FillArrays.Eye`.
_view(a::AbstractArray, I...) = view(a, I...)
function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N}
return _view(arg1(a), arg1.(I)...) ⊗ _view(arg2(a), arg2.(I)...)
end
function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N}
return _view(arg1(a), arg1.(I)...) ⊗ _view(arg2(a), arg2.(I)...)
end
# Fix ambigiuity error.
Base.view(a::KroneckerArray{<:Any,0}) = _view(arg1(a)) * _view(arg2(a))

function Base.:(==)(a::KroneckerArray, b::KroneckerArray)
return arg1(a) == arg1(b) && arg2(a) == arg2(b)
end
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Adapt = "4"
Aqua = "0.8"
BlockArrays = "1.6"
BlockSparseArrays = "0.7.19"
BlockSparseArrays = "0.7.21"
DerivableInterfaces = "0.5"
DiagonalArrays = "0.3.7"
FillArrays = "1"
Expand Down
14 changes: 12 additions & 2 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ using KroneckerArrays:
CartesianProductUnitRange,
⊗,
×,
arg1,
arg2,
cartesianproduct,
cartesianrange,
kron_nd,
Expand Down Expand Up @@ -67,8 +69,8 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
@test x == y

a = @constinferred(randn(elt, 2, 2) ⊗ randn(elt, 3, 3))
b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
c = a.a ⊗ b.b
b = @constinferred(randn(elt, 2, 2) ⊗ randn(elt, 3, 3))
c = @constinferred(a.a ⊗ b.b)
@test a isa KroneckerArray{elt,2,typeof(a.a),typeof(a.b)}
@test similar(typeof(a), (2, 3)) isa Matrix{elt}
@test size(similar(typeof(a), (2, 3))) == (2, 3)
Expand Down Expand Up @@ -101,6 +103,14 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
@test tr(a) ≈ tr(collect(a))
@test norm(a) ≈ norm(collect(a))

# Views
a = @constinferred(randn(elt, 2, 2) ⊗ randn(elt, 3, 3))
b = @constinferred(view(a, (1:2) × (2:3), (1:2) × (2:3)))
@test arg1(b) === view(arg1(a), 1:2, 1:2)
@test arg1(b) == arg1(a)[1:2, 1:2]
@test arg2(b) === view(arg2(a), 2:3, 2:3)
@test arg2(b) == arg2(a)[2:3, 2:3]

# Broadcasting
a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
style = KroneckerStyle(BroadcastStyle(typeof(a.a)), BroadcastStyle(typeof(a.b)))
Expand Down
91 changes: 90 additions & 1 deletion test/test_blocksparsearrays.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Adapt: adapt
using BlockArrays: Block, BlockRange
using BlockArrays: Block, BlockRange, mortar
using BlockSparseArrays:
BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype
using FillArrays: Eye, SquareEye
Expand Down Expand Up @@ -38,22 +38,69 @@ arrayts = (Array, JLArray)
@test a[Block(1, 2)] == dev(zeros(elt, 2, 3) ⊗ zeros(elt, 2, 3))
@test a[Block(1, 2)] isa valtype(d)

# Slicing
r = blockrange([2 × 2, 3 × 3])
d = Dict(
Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)),
Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)),
)
a = dev(blocksparse(d, r, r))
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
@test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]

# Slicing
r = blockrange([2 × 2, 3 × 3])
d = Dict(
Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)),
Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)),
)
a = dev(blocksparse(d, r, r))
i1 = Block(1)[(1:2) × (1:2)]
i2 = Block(2)[(2:3) × (2:3)]
I = mortar([i1, i2])
b = @view a[I, I]
b′ = copy(b)
@test b[Block(2, 2)] == b′[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]]
@test_broken b[Block(1, 2)]

# Slicing
r = blockrange([2 × 2, 3 × 3])
d = Dict(
Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)),
Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)),
)
a = dev(blocksparse(d, r, r))
i1 = Block(1)[(1:2) × (1:2)]
i2 = Block(2)[(2:3) × (2:3)]
I = [i1, i2]
b = @view a[I, I]
b′ = copy(b)
@test b[Block(2, 2)] == b′[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]]
@test_broken b[Block(1, 2)]

# Matrix multiplication
b = a * a
@test typeof(b) === typeof(a)
@test Array(b) ≈ Array(a) * Array(a)

# Addition (mapping, broadcasting)
b = a + a
@test typeof(b) === typeof(a)
@test Array(b) ≈ Array(a) + Array(a)

# Scaling (mapping, broadcasting)
b = 3a
@test typeof(b) === typeof(a)
@test Array(b) ≈ 3Array(a)

# Dividing (mapping, broadcasting)
b = a / 3
@test typeof(b) === typeof(a)
@test Array(b) ≈ Array(a) / 3

# Norm
@test norm(a) ≈ norm(Array(a))

if arrayt === Array
Expand Down Expand Up @@ -102,6 +149,48 @@ end
@test a[Block(1, 2)] == dev(Eye(2, 3) ⊗ zeros(elt, 2, 3))
@test a[Block(1, 2)] isa valtype(d)

# Slicing
r = blockrange([2 × 2, 3 × 3])
d = Dict(
Block(1, 1) => dev(Eye{elt}(2, 2) ⊗ randn(elt, 2, 2)),
Block(2, 2) => dev(Eye{elt}(3, 3) ⊗ randn(elt, 3, 3)),
)
a = dev(blocksparse(d, r, r))
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
@test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]

# Slicing
r = blockrange([2 × 2, 3 × 3])
d = Dict(
Block(1, 1) => dev(Eye{elt}(2, 2) ⊗ randn(elt, 2, 2)),
Block(2, 2) => dev(Eye{elt}(3, 3) ⊗ randn(elt, 3, 3)),
)
a = dev(blocksparse(d, r, r))
i1 = Block(1)[(1:2) × (1:2)]
i2 = Block(2)[(2:3) × (2:3)]
I = mortar([i1, i2])
b = @view a[I, I]
@test b[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]]
@test_broken copy(b)
@test_broken b[Block(1, 2)]

# Slicing
r = blockrange([2 × 2, 3 × 3])
d = Dict(
Block(1, 1) => dev(Eye{elt}(2, 2) ⊗ randn(elt, 2, 2)),
Block(2, 2) => dev(Eye{elt}(3, 3) ⊗ randn(elt, 3, 3)),
)
a = dev(blocksparse(d, r, r))
i1 = Block(1)[(1:2) × (1:2)]
i2 = Block(2)[(2:3) × (2:3)]
I = [i1, i2]
b = @view a[I, I]
@test b[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]]
@test_broken copy(b)
@test_broken b[Block(1, 2)]

b = @constinferred a * a
@test typeof(b) === typeof(a)
@test Array(b) ≈ Array(a) * Array(a)
Expand Down
17 changes: 16 additions & 1 deletion test/test_fillarrays.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
using DerivableInterfaces: zero!
using FillArrays: Eye, Zeros
using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗
using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, ×, arg1, arg2
using LinearAlgebra: det, norm, pinv
using StableRNGs: StableRNG
using Test: @test, @test_throws, @testset
using TestExtras: @constinferred

@testset "FillArrays.Eye" begin
MATRIX_FUNCTIONS = KroneckerArrays.MATRIX_FUNCTIONS
Expand All @@ -18,12 +19,26 @@ using Test: @test, @test_throws, @testset
@test 2a == Eye(2) ⊗ (2a.b)
@test a * a == Eye(2) ⊗ (a.b * a.b)

# Views
a = @constinferred(Eye(2) ⊗ randn(3, 3))
b = @constinferred(view(a, (:) × (2:3), (:) × (2:3)))
@test arg1(b) === Eye(2)
@test arg2(b) === view(arg2(a), 2:3, 2:3)
@test arg2(b) == arg2(a)[2:3, 2:3]

a = randn(3, 3) ⊗ Eye(2)
@test size(a) == (6, 6)
@test a + a == (2a.a) ⊗ Eye(2)
@test 2a == (2a.a) ⊗ Eye(2)
@test a * a == (a.a * a.a) ⊗ Eye(2)

# Views
a = @constinferred(randn(3, 3) ⊗ Eye(2))
b = @constinferred(view(a, (2:3) × (:), (2:3) × (:)))
@test arg1(b) === view(arg1(a), 2:3, 2:3)
@test arg1(b) == arg1(a)[2:3, 2:3]
@test arg2(b) === Eye(2)

# similar
a = Eye(2) ⊗ randn(3, 3)
for a′ in (
Expand Down
Loading