Skip to content

Commit 6ff9f93

Browse files
authored
Views of KroneckerArray (#27)
1 parent 9ff937e commit 6ff9f93

File tree

8 files changed

+138
-8
lines changed

8 files changed

+138
-8
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.20"
4+
version = "0.1.21"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -23,7 +23,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
2323
[compat]
2424
Adapt = "4.3.0"
2525
BlockArrays = "1.6"
26-
BlockSparseArrays = "0.7.20"
26+
BlockSparseArrays = "0.7.21"
2727
DerivableInterfaces = "0.5.0"
2828
DiagonalArrays = "0.3.5"
2929
FillArrays = "1.13.0"

ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@ module KroneckerArraysBlockSparseArraysExt
33
using BlockArrays: Block
44
using BlockSparseArrays: BlockIndexVector, GenericBlockIndex
55
using KroneckerArrays: CartesianPair, CartesianProduct
6-
function Base.getindex(b::Block, I1::CartesianPair, Irest::CartesianPair...)
6+
function Base.getindex(
7+
b::Block,
8+
I1::Union{CartesianPair,CartesianProduct},
9+
Irest::Union{CartesianPair,CartesianProduct}...,
10+
)
711
return GenericBlockIndex(b, (I1, Irest...))
812
end
913
function Base.getindex(b::Block, I1::CartesianProduct, Irest::CartesianProduct...)

src/fillarrays/kroneckerarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatr
2222
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
2323

2424
_getindex(a::Eye, I1::Colon, I2::Colon) = a
25+
_view(a::Eye, I1::Colon, I2::Colon) = a
2526

2627
# Like `adapt` but preserves `Eye`.
2728
_adapt(to, a::Eye) = a

src/kroneckerarray.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,17 @@ end
178178
# Fix ambigiuity error.
179179
Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[]
180180

181+
# Allow customizing for `FillArrays.Eye`.
182+
_view(a::AbstractArray, I...) = view(a, I...)
183+
function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N}
184+
return _view(arg1(a), arg1.(I)...) _view(arg2(a), arg2.(I)...)
185+
end
186+
function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N}
187+
return _view(arg1(a), arg1.(I)...) _view(arg2(a), arg2.(I)...)
188+
end
189+
# Fix ambigiuity error.
190+
Base.view(a::KroneckerArray{<:Any,0}) = _view(arg1(a)) * _view(arg2(a))
191+
181192
function Base.:(==)(a::KroneckerArray, b::KroneckerArray)
182193
return arg1(a) == arg1(b) && arg2(a) == arg2(b)
183194
end

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
2121
Adapt = "4"
2222
Aqua = "0.8"
2323
BlockArrays = "1.6"
24-
BlockSparseArrays = "0.7.19"
24+
BlockSparseArrays = "0.7.21"
2525
DerivableInterfaces = "0.5"
2626
DiagonalArrays = "0.3.7"
2727
FillArrays = "1"

test/test_basics.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ using KroneckerArrays:
1111
CartesianProductUnitRange,
1212
,
1313
×,
14+
arg1,
15+
arg2,
1416
cartesianproduct,
1517
cartesianrange,
1618
kron_nd,
@@ -67,8 +69,8 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
6769
@test x == y
6870

6971
a = @constinferred(randn(elt, 2, 2) randn(elt, 3, 3))
70-
b = randn(elt, 2, 2) randn(elt, 3, 3)
71-
c = a.a b.b
72+
b = @constinferred(randn(elt, 2, 2) randn(elt, 3, 3))
73+
c = @constinferred(a.a b.b)
7274
@test a isa KroneckerArray{elt,2,typeof(a.a),typeof(a.b)}
7375
@test similar(typeof(a), (2, 3)) isa Matrix{elt}
7476
@test size(similar(typeof(a), (2, 3))) == (2, 3)
@@ -101,6 +103,14 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
101103
@test tr(a) tr(collect(a))
102104
@test norm(a) norm(collect(a))
103105

106+
# Views
107+
a = @constinferred(randn(elt, 2, 2) randn(elt, 3, 3))
108+
b = @constinferred(view(a, (1:2) × (2:3), (1:2) × (2:3)))
109+
@test arg1(b) === view(arg1(a), 1:2, 1:2)
110+
@test arg1(b) == arg1(a)[1:2, 1:2]
111+
@test arg2(b) === view(arg2(a), 2:3, 2:3)
112+
@test arg2(b) == arg2(a)[2:3, 2:3]
113+
104114
# Broadcasting
105115
a = randn(elt, 2, 2) randn(elt, 3, 3)
106116
style = KroneckerStyle(BroadcastStyle(typeof(a.a)), BroadcastStyle(typeof(a.b)))

test/test_blocksparsearrays.jl

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Adapt: adapt
2-
using BlockArrays: Block, BlockRange
2+
using BlockArrays: Block, BlockRange, mortar
33
using BlockSparseArrays:
44
BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype
55
using FillArrays: Eye, SquareEye
@@ -38,22 +38,69 @@ arrayts = (Array, JLArray)
3838
@test a[Block(1, 2)] == dev(zeros(elt, 2, 3) zeros(elt, 2, 3))
3939
@test a[Block(1, 2)] isa valtype(d)
4040

41+
# Slicing
42+
r = blockrange([2 × 2, 3 × 3])
43+
d = Dict(
44+
Block(1, 1) => dev(randn(elt, 2, 2) randn(elt, 2, 2)),
45+
Block(2, 2) => dev(randn(elt, 3, 3) randn(elt, 3, 3)),
46+
)
47+
a = dev(blocksparse(d, r, r))
48+
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
49+
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
50+
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
51+
@test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
52+
53+
# Slicing
54+
r = blockrange([2 × 2, 3 × 3])
55+
d = Dict(
56+
Block(1, 1) => dev(randn(elt, 2, 2) randn(elt, 2, 2)),
57+
Block(2, 2) => dev(randn(elt, 3, 3) randn(elt, 3, 3)),
58+
)
59+
a = dev(blocksparse(d, r, r))
60+
i1 = Block(1)[(1:2) × (1:2)]
61+
i2 = Block(2)[(2:3) × (2:3)]
62+
I = mortar([i1, i2])
63+
b = @view a[I, I]
64+
b′ = copy(b)
65+
@test b[Block(2, 2)] == b′[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]]
66+
@test_broken b[Block(1, 2)]
67+
68+
# Slicing
69+
r = blockrange([2 × 2, 3 × 3])
70+
d = Dict(
71+
Block(1, 1) => dev(randn(elt, 2, 2) randn(elt, 2, 2)),
72+
Block(2, 2) => dev(randn(elt, 3, 3) randn(elt, 3, 3)),
73+
)
74+
a = dev(blocksparse(d, r, r))
75+
i1 = Block(1)[(1:2) × (1:2)]
76+
i2 = Block(2)[(2:3) × (2:3)]
77+
I = [i1, i2]
78+
b = @view a[I, I]
79+
b′ = copy(b)
80+
@test b[Block(2, 2)] == b′[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]]
81+
@test_broken b[Block(1, 2)]
82+
83+
# Matrix multiplication
4184
b = a * a
4285
@test typeof(b) === typeof(a)
4386
@test Array(b) Array(a) * Array(a)
4487

88+
# Addition (mapping, broadcasting)
4589
b = a + a
4690
@test typeof(b) === typeof(a)
4791
@test Array(b) Array(a) + Array(a)
4892

93+
# Scaling (mapping, broadcasting)
4994
b = 3a
5095
@test typeof(b) === typeof(a)
5196
@test Array(b) 3Array(a)
5297

98+
# Dividing (mapping, broadcasting)
5399
b = a / 3
54100
@test typeof(b) === typeof(a)
55101
@test Array(b) Array(a) / 3
56102

103+
# Norm
57104
@test norm(a) norm(Array(a))
58105

59106
if arrayt === Array
@@ -102,6 +149,48 @@ end
102149
@test a[Block(1, 2)] == dev(Eye(2, 3) zeros(elt, 2, 3))
103150
@test a[Block(1, 2)] isa valtype(d)
104151

152+
# Slicing
153+
r = blockrange([2 × 2, 3 × 3])
154+
d = Dict(
155+
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
156+
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
157+
)
158+
a = dev(blocksparse(d, r, r))
159+
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
160+
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
161+
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
162+
@test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
163+
164+
# Slicing
165+
r = blockrange([2 × 2, 3 × 3])
166+
d = Dict(
167+
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
168+
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
169+
)
170+
a = dev(blocksparse(d, r, r))
171+
i1 = Block(1)[(1:2) × (1:2)]
172+
i2 = Block(2)[(2:3) × (2:3)]
173+
I = mortar([i1, i2])
174+
b = @view a[I, I]
175+
@test b[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]]
176+
@test_broken copy(b)
177+
@test_broken b[Block(1, 2)]
178+
179+
# Slicing
180+
r = blockrange([2 × 2, 3 × 3])
181+
d = Dict(
182+
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
183+
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
184+
)
185+
a = dev(blocksparse(d, r, r))
186+
i1 = Block(1)[(1:2) × (1:2)]
187+
i2 = Block(2)[(2:3) × (2:3)]
188+
I = [i1, i2]
189+
b = @view a[I, I]
190+
@test b[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]]
191+
@test_broken copy(b)
192+
@test_broken b[Block(1, 2)]
193+
105194
b = @constinferred a * a
106195
@test typeof(b) === typeof(a)
107196
@test Array(b) Array(a) * Array(a)

test/test_fillarrays.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
using DerivableInterfaces: zero!
22
using FillArrays: Eye, Zeros
3-
using KroneckerArrays: KroneckerArrays, KroneckerArray,
3+
using KroneckerArrays: KroneckerArrays, KroneckerArray, , ×, arg1, arg2
44
using LinearAlgebra: det, norm, pinv
55
using StableRNGs: StableRNG
66
using Test: @test, @test_throws, @testset
7+
using TestExtras: @constinferred
78

89
@testset "FillArrays.Eye" begin
910
MATRIX_FUNCTIONS = KroneckerArrays.MATRIX_FUNCTIONS
@@ -18,12 +19,26 @@ using Test: @test, @test_throws, @testset
1819
@test 2a == Eye(2) (2a.b)
1920
@test a * a == Eye(2) (a.b * a.b)
2021

22+
# Views
23+
a = @constinferred(Eye(2) randn(3, 3))
24+
b = @constinferred(view(a, (:) × (2:3), (:) × (2:3)))
25+
@test arg1(b) === Eye(2)
26+
@test arg2(b) === view(arg2(a), 2:3, 2:3)
27+
@test arg2(b) == arg2(a)[2:3, 2:3]
28+
2129
a = randn(3, 3) Eye(2)
2230
@test size(a) == (6, 6)
2331
@test a + a == (2a.a) Eye(2)
2432
@test 2a == (2a.a) Eye(2)
2533
@test a * a == (a.a * a.a) Eye(2)
2634

35+
# Views
36+
a = @constinferred(randn(3, 3) Eye(2))
37+
b = @constinferred(view(a, (2:3) × (:), (2:3) × (:)))
38+
@test arg1(b) === view(arg1(a), 2:3, 2:3)
39+
@test arg1(b) == arg1(a)[2:3, 2:3]
40+
@test arg2(b) === Eye(2)
41+
2742
# similar
2843
a = Eye(2) randn(3, 3)
2944
for a′ in (

0 commit comments

Comments
 (0)