diff --git a/Project.toml b/Project.toml index 1177da1..d1c03f7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.20" +version = "0.1.21" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -23,7 +23,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"] [compat] Adapt = "4.3.0" BlockArrays = "1.6" -BlockSparseArrays = "0.7.20" +BlockSparseArrays = "0.7.21" DerivableInterfaces = "0.5.0" DiagonalArrays = "0.3.5" FillArrays = "1.13.0" diff --git a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl index 114674e..9fa1018 100644 --- a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl +++ b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl @@ -3,7 +3,11 @@ module KroneckerArraysBlockSparseArraysExt using BlockArrays: Block using BlockSparseArrays: BlockIndexVector, GenericBlockIndex using KroneckerArrays: CartesianPair, CartesianProduct -function Base.getindex(b::Block, I1::CartesianPair, Irest::CartesianPair...) +function Base.getindex( + b::Block, + I1::Union{CartesianPair,CartesianProduct}, + Irest::Union{CartesianPair,CartesianProduct}..., +) return GenericBlockIndex(b, (I1, Irest...)) end function Base.getindex(b::Block, I1::CartesianProduct, Irest::CartesianProduct...) diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index db51fc9..fc40ef7 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -22,6 +22,7 @@ const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatr const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} _getindex(a::Eye, I1::Colon, I2::Colon) = a +_view(a::Eye, I1::Colon, I2::Colon) = a # Like `adapt` but preserves `Eye`. _adapt(to, a::Eye) = a diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 636c1cd..94d601a 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -178,6 +178,17 @@ end # Fix ambigiuity error. Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[] +# Allow customizing for `FillArrays.Eye`. +_view(a::AbstractArray, I...) = view(a, I...) +function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N} + return _view(arg1(a), arg1.(I)...) ⊗ _view(arg2(a), arg2.(I)...) +end +function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N} + return _view(arg1(a), arg1.(I)...) ⊗ _view(arg2(a), arg2.(I)...) +end +# Fix ambigiuity error. +Base.view(a::KroneckerArray{<:Any,0}) = _view(arg1(a)) * _view(arg2(a)) + function Base.:(==)(a::KroneckerArray, b::KroneckerArray) return arg1(a) == arg1(b) && arg2(a) == arg2(b) end diff --git a/test/Project.toml b/test/Project.toml index 9f9c5a9..16b4265 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -21,7 +21,7 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Adapt = "4" Aqua = "0.8" BlockArrays = "1.6" -BlockSparseArrays = "0.7.19" +BlockSparseArrays = "0.7.21" DerivableInterfaces = "0.5" DiagonalArrays = "0.3.7" FillArrays = "1" diff --git a/test/test_basics.jl b/test/test_basics.jl index c8fa835..0fe2ef3 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -11,6 +11,8 @@ using KroneckerArrays: CartesianProductUnitRange, ⊗, ×, + arg1, + arg2, cartesianproduct, cartesianrange, kron_nd, @@ -67,8 +69,8 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test x == y a = @constinferred(randn(elt, 2, 2) ⊗ randn(elt, 3, 3)) - b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) - c = a.a ⊗ b.b + b = @constinferred(randn(elt, 2, 2) ⊗ randn(elt, 3, 3)) + c = @constinferred(a.a ⊗ b.b) @test a isa KroneckerArray{elt,2,typeof(a.a),typeof(a.b)} @test similar(typeof(a), (2, 3)) isa Matrix{elt} @test size(similar(typeof(a), (2, 3))) == (2, 3) @@ -101,6 +103,14 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test tr(a) ≈ tr(collect(a)) @test norm(a) ≈ norm(collect(a)) + # Views + a = @constinferred(randn(elt, 2, 2) ⊗ randn(elt, 3, 3)) + b = @constinferred(view(a, (1:2) × (2:3), (1:2) × (2:3))) + @test arg1(b) === view(arg1(a), 1:2, 1:2) + @test arg1(b) == arg1(a)[1:2, 1:2] + @test arg2(b) === view(arg2(a), 2:3, 2:3) + @test arg2(b) == arg2(a)[2:3, 2:3] + # Broadcasting a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) style = KroneckerStyle(BroadcastStyle(typeof(a.a)), BroadcastStyle(typeof(a.b))) diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index 1d37e15..bd88d9d 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -1,5 +1,5 @@ using Adapt: adapt -using BlockArrays: Block, BlockRange +using BlockArrays: Block, BlockRange, mortar using BlockSparseArrays: BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype using FillArrays: Eye, SquareEye @@ -38,22 +38,69 @@ arrayts = (Array, JLArray) @test a[Block(1, 2)] == dev(zeros(elt, 2, 3) ⊗ zeros(elt, 2, 3)) @test a[Block(1, 2)] isa valtype(d) + # Slicing + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, r, r)) + @test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] == + a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)] + @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] + @test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] + + # Slicing + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, r, r)) + i1 = Block(1)[(1:2) × (1:2)] + i2 = Block(2)[(2:3) × (2:3)] + I = mortar([i1, i2]) + b = @view a[I, I] + b′ = copy(b) + @test b[Block(2, 2)] == b′[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] + @test_broken b[Block(1, 2)] + + # Slicing + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, r, r)) + i1 = Block(1)[(1:2) × (1:2)] + i2 = Block(2)[(2:3) × (2:3)] + I = [i1, i2] + b = @view a[I, I] + b′ = copy(b) + @test b[Block(2, 2)] == b′[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] + @test_broken b[Block(1, 2)] + + # Matrix multiplication b = a * a @test typeof(b) === typeof(a) @test Array(b) ≈ Array(a) * Array(a) + # Addition (mapping, broadcasting) b = a + a @test typeof(b) === typeof(a) @test Array(b) ≈ Array(a) + Array(a) + # Scaling (mapping, broadcasting) b = 3a @test typeof(b) === typeof(a) @test Array(b) ≈ 3Array(a) + # Dividing (mapping, broadcasting) b = a / 3 @test typeof(b) === typeof(a) @test Array(b) ≈ Array(a) / 3 + # Norm @test norm(a) ≈ norm(Array(a)) if arrayt === Array @@ -102,6 +149,48 @@ end @test a[Block(1, 2)] == dev(Eye(2, 3) ⊗ zeros(elt, 2, 3)) @test a[Block(1, 2)] isa valtype(d) + # Slicing + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(Eye{elt}(2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(Eye{elt}(3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, r, r)) + @test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] == + a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)] + @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] + @test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] + + # Slicing + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(Eye{elt}(2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(Eye{elt}(3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, r, r)) + i1 = Block(1)[(1:2) × (1:2)] + i2 = Block(2)[(2:3) × (2:3)] + I = mortar([i1, i2]) + b = @view a[I, I] + @test b[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] + @test_broken copy(b) + @test_broken b[Block(1, 2)] + + # Slicing + r = blockrange([2 × 2, 3 × 3]) + d = Dict( + Block(1, 1) => dev(Eye{elt}(2, 2) ⊗ randn(elt, 2, 2)), + Block(2, 2) => dev(Eye{elt}(3, 3) ⊗ randn(elt, 3, 3)), + ) + a = dev(blocksparse(d, r, r)) + i1 = Block(1)[(1:2) × (1:2)] + i2 = Block(2)[(2:3) × (2:3)] + I = [i1, i2] + b = @view a[I, I] + @test b[Block(2, 2)] == a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] + @test_broken copy(b) + @test_broken b[Block(1, 2)] + b = @constinferred a * a @test typeof(b) === typeof(a) @test Array(b) ≈ Array(a) * Array(a) diff --git a/test/test_fillarrays.jl b/test/test_fillarrays.jl index f852ef9..62e0234 100644 --- a/test/test_fillarrays.jl +++ b/test/test_fillarrays.jl @@ -1,9 +1,10 @@ using DerivableInterfaces: zero! using FillArrays: Eye, Zeros -using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗ +using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, ×, arg1, arg2 using LinearAlgebra: det, norm, pinv using StableRNGs: StableRNG using Test: @test, @test_throws, @testset +using TestExtras: @constinferred @testset "FillArrays.Eye" begin MATRIX_FUNCTIONS = KroneckerArrays.MATRIX_FUNCTIONS @@ -18,12 +19,26 @@ using Test: @test, @test_throws, @testset @test 2a == Eye(2) ⊗ (2a.b) @test a * a == Eye(2) ⊗ (a.b * a.b) + # Views + a = @constinferred(Eye(2) ⊗ randn(3, 3)) + b = @constinferred(view(a, (:) × (2:3), (:) × (2:3))) + @test arg1(b) === Eye(2) + @test arg2(b) === view(arg2(a), 2:3, 2:3) + @test arg2(b) == arg2(a)[2:3, 2:3] + a = randn(3, 3) ⊗ Eye(2) @test size(a) == (6, 6) @test a + a == (2a.a) ⊗ Eye(2) @test 2a == (2a.a) ⊗ Eye(2) @test a * a == (a.a * a.a) ⊗ Eye(2) + # Views + a = @constinferred(randn(3, 3) ⊗ Eye(2)) + b = @constinferred(view(a, (2:3) × (:), (2:3) × (:))) + @test arg1(b) === view(arg1(a), 2:3, 2:3) + @test arg1(b) == arg1(a)[2:3, 2:3] + @test arg2(b) === Eye(2) + # similar a = Eye(2) ⊗ randn(3, 3) for a′ in (