Skip to content

Commit 8d98411

Browse files
authored
Add diagtrav function with dims keyword (#151)
* Add diagtrav function with dims keyword * add tests * Update test_blockkron.jl * remove moved code
1 parent 68deb82 commit 8d98411

File tree

5 files changed

+29
-30
lines changed

5 files changed

+29
-30
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LazyBandedMatrices"
22
uuid = "d7e5e226-e90b-4449-9968-0f923699bf6f"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.11.6"
4+
version = "0.11.7"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -28,7 +28,7 @@ BlockArrays = "1.0"
2828
BlockBandedMatrices = "0.13"
2929
FillArrays = "1.0"
3030
InfiniteArrays = "0.15"
31-
LazyArrays = "2.2.3"
31+
LazyArrays = "2.8"
3232
MatrixFactorizations = "3.0"
3333
StaticArrays = "1.0"
3434
julia = "1.10"

src/LazyBandedMatrices.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ const LazyBlockBandedLayouts = LazyArraysBlockBandedMatricesExt.LazyBlockBandedL
3838

3939
export DiagTrav, KronTrav, blockkron, BlockKron, BlockBroadcastArray, BlockVcat, BlockHcat, BlockHvcat, unitblocks
4040

41+
## TODO: export diagtrav, invdiagtrav
42+
4143
include("tridiag.jl")
4244
include("bidiag.jl")
4345
include("special.jl")

src/blockconcat.jl

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -492,30 +492,7 @@ resize!(c::BlockBroadcastVector{T,typeof(vcat)}, N::Block{1}) where T = BlockBro
492492
# BlockVec
493493
####
494494

495-
# support LazyArrays v2.8 where BlockVec is moved
496-
if isdefined(LazyBandedMatrices.LazyArraysBlockArraysExt, :BlockVec)
497-
const BlockVec = LazyBandedMatrices.LazyArraysBlockArraysExt.BlockVec
498-
else
499-
const BlockVec{T, M<:AbstractMatrix{T}} = ApplyVector{T, typeof(blockvec), <:Tuple{M}}
500-
501-
BlockVec{T}(M::AbstractMatrix{T}) where T = ApplyVector{T}(blockvec, M)
502-
BlockVec(M::AbstractMatrix{T}) where T = BlockVec{T}(M)
503-
axes(b::BlockVec) = (blockedrange(Fill(size(b.args[1])...)),)
504-
size(b::BlockVec) = (length(b.args[1]),)
505-
506-
view(b::BlockVec, K::Block{1}) = view(b.args[1], :, Int(K))
507-
Base.@propagate_inbounds getindex(b::BlockVec, k::Int) = b.args[1][k]
508-
Base.@propagate_inbounds setindex!(b::BlockVec, v, k::Int) = setindex!(b.args[1], v, k)
509-
510-
_resize!(A::AbstractMatrix, m, n) = A[1:m, 1:n]
511-
_resize!(At::Transpose, m, n) = transpose(transpose(At)[1:n, 1:m])
512-
_resize!(Ac::Adjoint, m, n) = (Ac')[1:n, 1:m]'
513-
resize!(b::BlockVec, K::Block{1}) = BlockVec(_resize!(b.args[1], size(b.args[1],1), Int(K)))
514-
515-
applylayout(::Type{typeof(blockvec)}, ::AbstractPaddedLayout) = PaddedColumns{ApplyLayout{typeof(blockvec)}}()
516-
paddeddata(b::BlockVec) = BlockVec(paddeddata(b.args[1]))
517-
end
518-
495+
const BlockVec = LazyBandedMatrices.LazyArraysBlockArraysExt.BlockVec
519496

520497
####
521498
# summary

src/blockkron.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,22 @@ function getindex(A::InvDiagTrav{T}, k::Int, j::Int) where T
237237
end
238238
end
239239

240-
invdiagtrav(a) = InvDiagTrav(a)
240+
diagtrav(a::AbstractMatrix) = DiagTrav(a)
241+
function diagtrav(a::AbstractArray{T,3}; dims=1:3) where T
242+
if dims == 1:3
243+
DiagTrav(a)
244+
else
245+
@assert dims == 1:2
246+
ret = BlockedMatrix{T}(undef, (_krontrav_axes(axes(a,1), axes(a,2)), axes(a,3)))
247+
forin axes(a,3)
248+
ret[:,ℓ] = DiagTrav(view(a,:,:,ℓ))
249+
end
250+
ret
251+
end
252+
end
253+
254+
diagtrav(a::InvDiagTrav) = a.vector
255+
invdiagtrav(a::AbstractVector) = InvDiagTrav(a)
241256
invdiagtrav(a::DiagTrav) = a.array
242257

243258
-(A::DiagTrav) = DiagTrav(-A.array)

test/test_blockkron.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using LazyBandedMatrices, FillArrays, BandedMatrices, BlockBandedMatrices, Block
44

55
using LinearAlgebra
66
import BlockBandedMatrices: isbandedblockbanded, isbanded, BandedBlockBandedStyle, BandedLayout, _BandedBlockBandedMatrix
7-
import LazyBandedMatrices: KronTravBandedBlockBandedLayout, BroadcastBandedLayout, BroadcastBandedBlockBandedLayout, arguments, call, blockcolsupport, InvDiagTrav, invdiagtrav, pad, krontrav
7+
import LazyBandedMatrices: KronTravBandedBlockBandedLayout, BroadcastBandedLayout, BroadcastBandedBlockBandedLayout, arguments, call, blockcolsupport, InvDiagTrav, invdiagtrav, pad, krontrav, diagtrav
88
import ArrayLayouts: FillLayout, OnesLayout
99
import LazyArrays: resizedata!, FillLayout, arguments, colsupport, call, LazyArrayStyle
1010
import BandedMatrices: BandedColumns
@@ -24,7 +24,7 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data)
2424
@testset "Kron" begin
2525
@testset "DiagTrav" begin
2626
A = [1 2 3; 4 5 6; 7 8 9]
27-
@test DiagTrav(A) == Vector(DiagTrav(A)) == [1, 4, 2, 7, 5, 3]
27+
@test DiagTrav(A) == Vector(DiagTrav(A)) == diagtrav(A) == [1, 4, 2, 7, 5, 3]
2828
@test resize!(DiagTrav(A), Block(2)) == [1, 4,2]
2929
@test maximum(abs, DiagTrav(A)) == 7
3030
@test copy(DiagTrav(A)) == DiagTrav(A)
@@ -51,14 +51,18 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data)
5151
@test resize!(DiagTrav(A), Block(2)) == [1, 3,2]
5252

5353
A = DiagTrav(randn(3,3,3))
54+
@test A == diagtrav(A.array)
5455
@test A[Block(1)] == A[1:1,1,1]
5556
@test A[Block(2)] == [A.array[2,1,1], A.array[1,2,1], A.array[1,1,2]]
5657
@test A[Block(3)] == [A.array[3,1,1], A.array[2,2,1], A.array[2,1,2],
5758
A.array[1,3,1], A.array[1,2,2], A.array[1,1,3]]
5859
@test A == [A[Block(1)]; A[Block(2)]; A[Block(3)]]
5960

61+
A = randn(3,3,3)
62+
@test diagtrav(A; dims=1:2) == [diagtrav(A[:,:,1]) diagtrav(A[:,:,2]) diagtrav(A[:,:,3])]
63+
6064
A = reshape(1:9,3,3)'
61-
@test DiagTrav(A) == Vector(DiagTrav(A)) == [1, 4, 2, 7, 5, 3]
65+
@test DiagTrav(A) == Vector(DiagTrav(A)) == diagtrav(A) == [1, 4, 2, 7, 5, 3]
6266
A = reshape(1:12,3,4)'
6367
@test DiagTrav(A) == [1, 4, 2, 7, 5, 3, 10, 8, 6]
6468
A = reshape(1:12,3,4)
@@ -84,6 +88,7 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data)
8488
@testset "InvDiagTrav" begin
8589
A = [1 2 3; 4 5 6; 7 8 9]
8690
@test invdiagtrav(BlockedVector(DiagTrav(A))) == invdiagtrav(DiagTrav(A)) == [1 2 3; 4 5 0; 7 0 0]
91+
@test diagtrav(invdiagtrav(diagtrav(A))) == diagtrav(invdiagtrav(BlockedVector(diagtrav(A)))) == diagtrav(A)
8792
end
8893

8994
@testset "BlockKron" begin

0 commit comments

Comments
 (0)