Skip to content

Commit 9ca6c3d

Browse files
authored
Merge pull request #8 from JuliaLinearAlgebra/copymethods
Add copy, lmul! and rmul! methods
2 parents 040521b + 7545ae3 commit 9ca6c3d

File tree

7 files changed

+66
-9
lines changed

7 files changed

+66
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "RectangularFullPacked"
22
uuid = "27983f2f-6524-42ba-a408-2b5a31c238e4"
3-
version = "0.1.0"
3+
version = "0.2.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/RectangularFullPacked.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ using LinearAlgebra
66
using LinearAlgebra: BlasFloat, checksquare
77

88
import Base: \
9+
import LinearAlgebra.BLAS: syrk!
10+
import LinearAlgebra: Hermitian
911

1012
abstract type AbstractRFP{T} <: AbstractMatrix{T} end
1113

src/cholesky.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,13 @@ struct CholeskyRFP{T<:BlasFloat} <: Factorization{T}
44
uplo::Char
55
end
66

7-
LinearAlgebra.cholesky!(A::HermitianRFP{T}) where {T<:BlasFloat} =
8-
CholeskyRFP(LAPACK_RFP.pftrf!(A.transr, A.uplo, A.data), A.transr, A.uplo)
7+
function LinearAlgebra.cholesky!(A::HermitianRFP{T}) where {T<:BlasFloat}
8+
return CholeskyRFP(
9+
LAPACK_RFP.pftrf!(A.transr, A.uplo, A.data),
10+
A.transr,
11+
A.uplo,
12+
)
13+
end
914
LinearAlgebra.cholesky(A::HermitianRFP{T}) where {T<:BlasFloat} = cholesky!(copy(A))
1015
LinearAlgebra.factorize(A::HermitianRFP) = cholesky(A)
1116

src/hermitian.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,18 @@ end
99

1010
#HermitianRFP(A::TriangularRFP) = HermitianRFP(A.data, A.transr, A.uplo)
1111

12+
function Hermitian(A::TriangularRFP{<:LinearAlgebra.BlasReal}, uplo::Symbol)
13+
Symbol(A.uplo) == uplo ||
14+
throw(ArgumentError("A.uplo = $(A.uplo) conflicts with argument uplo = $uplo"))
15+
return Hermitian(A)
16+
end
17+
18+
function Hermitian(A::TriangularRFP{<:LinearAlgebra.BlasReal})
19+
return HermitianRFP(A.data, A.transr, A.uplo)
20+
end
21+
22+
Base.copy(A::HermitianRFP{T}) where {T} = HermitianRFP{T}(copy(A.data), A.transr, A.uplo)
23+
1224
function Base.getindex(A::HermitianRFP, i::Integer, j::Integer)
1325
(A.uplo == 'L' ? i < j : i > j) && return conj(getindex(A, j, i))
1426
n, k, l = checkbounds(A, i, j)
@@ -28,4 +40,16 @@ function Ac_mul_A_RFP(A::Matrix{T}, uplo = :U) where {T<:BlasFloat}
2840
return HermitianRFP(LAPACK_RFP.sfrk!('N', ul, tr, 1.0, A, 0.0, par), 'N', ul)
2941
end
3042

31-
Base.copy(A::HermitianRFP) = HermitianRFP(copy(A.data), A.transr, A.uplo)
43+
function syrk!(
44+
trans::AbstractChar,
45+
α::Real,
46+
A::StridedMatrix{T},
47+
β::Real,
48+
C::HermitianRFP{T},
49+
) where {T}
50+
return HermitianRFP(
51+
LAPACK_RFP.sfrk!(C.transr, C.uplo, Char(trans), α, A, β, C.data),
52+
C.transr,
53+
C.uplo,
54+
)
55+
end

src/triangular.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ struct TriangularRFP{T<:BlasFloat} <: AbstractRFP{T}
44
uplo::Char
55
end
66

7-
function TriangularRFP(A::Matrix{T}, uplo::Symbol = :U; transr::Symbol=:N) where {T}
7+
function TriangularRFP(A::Matrix{T}, uplo::Symbol = :U; transr::Symbol = :N) where {T}
88
n = checksquare(A)
99
ul = first(string(uplo))
1010
if ul "UL"
@@ -21,7 +21,7 @@ function TriangularRFP(A::Matrix{T}, uplo::Symbol = :U; transr::Symbol=:N) where
2121
ul,
2222
)
2323
end
24-
24+
2525
function Base.Array(A::TriangularRFP{T}) where {T}
2626
n, k, l = _rfpsize(A)
2727
C = Array{T}(undef, (n, n))
@@ -36,7 +36,7 @@ function Base.getindex(A::TriangularRFP{T}, i::Integer, j::Integer) where {T}
3636
(A.uplo == 'L' ? i < j : i > j) && return zero(T)
3737
rs, doconj = _packedinds(A, Int(i), Int(j), iseven(n), l)
3838
val = A.data[first(rs), last(rs)]
39-
return doconj ? conj(val) : val
39+
return doconj ? conj(val) : val
4040
end
4141

4242
function Base.setindex!(A::TriangularRFP{T}, x::T, i::Integer, j::Integer) where {T}

src/utilities.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ end
2323
function _packedinds(i::Int, j::Int, lower::Bool, neven::Bool, tr::Bool, l::Int)
2424
if lower
2525
conj = l < j
26-
inds = conj ? (j - l, i + !neven - l) : (i + neven, j)
26+
inds = conj ? (j - l, i + !neven - l) : (i + neven, j)
2727
else
2828
conj = (j + !neven) l
2929
inds = conj ? (l + neven + j, i) : (i, j + !neven - l)
@@ -55,7 +55,8 @@ function _rfpsize(A::AbstractRFP)
5555
dsz = size(A.data)
5656
k, l = A.transr == 'N' ? dsz : reverse(dsz)
5757
L = 2l
58-
isone(abs(k - L)) || throw(ArgumentError("size(A.data) = $dsz is not consistent with RFP"))
58+
isone(abs(k - L)) ||
59+
throw(ArgumentError("size(A.data) = $dsz is not consistent with RFP"))
5960
return k - (L < k), k, l
6061
end
6162

@@ -73,3 +74,13 @@ function Base.size(A::AbstractRFP)
7374
n, k, l = _rfpsize(A)
7475
return (n, n)
7576
end
77+
78+
function LinearAlgebra.rmul!(A::AbstractRFP, B::Number)
79+
rmul!(A.data, B)
80+
return A
81+
end
82+
83+
function LinearAlgebra.lmul!(A::Number, B::AbstractRFP)
84+
lmul!(A, B.data)
85+
return B
86+
end

test/runtests.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ import RectangularFullPacked: Ac_mul_A_RFP, TriangularRFP
6969
A = rand(elty, 10, n)
7070
AcA = A'A
7171
AcA_RFP = Ac_mul_A_RFP(A, uplo)
72+
@test AcA_RFP BLAS.syrk!(elty <: Complex ? 'C' : 'T', 1.0, A, 0.0, copy(AcA_RFP))
7273
o = ones(elty, n)
7374

7475
@test AcA AcA_RFP
@@ -97,4 +98,18 @@ import RectangularFullPacked: Ac_mul_A_RFP, TriangularRFP
9798
@test A \ o A_RFP \ o
9899
@test inv(A) Array(inv(A_RFP))
99100
end
101+
102+
@testset "In-place scalar multiplication" begin
103+
U = lu(rand(7, 7)).U
104+
B = sqrt(π)
105+
@test rmul!(copy(U), B) rmul!(TriangularRFP(U, :U), B)
106+
@test lmul!(B, copy(U)) lmul!(B, TriangularRFP(U, :U; transr=:T))
107+
end
108+
109+
@testset "Hermitian from Triangular" begin
110+
U = lu(rand(7,7)).U
111+
@test Hermitian(TriangularRFP(U, :U)) Hermitian(U, :U)
112+
@test Hermitian(TriangularRFP(U, :U), :U) Hermitian(U, :U)
113+
@test_throws ArgumentError Hermitian(TriangularRFP(U, :U), :L)
114+
end
100115
end

0 commit comments

Comments
 (0)