Add BandedPlusSemiseparable and QR #53
base: master
Codecov Report
❌ Patch coverage is

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##           master      #53      +/-   ##
==========================================
+ Coverage   89.70%   95.41%   +5.70%
==========================================
  Files           4        5       +1
  Lines         272      523     +251
==========================================
+ Hits          244      499     +255
+ Misses         28       24       -4
```
@TaoChenImperial I added a function level to separate out the pre-computation. Unfortunately there's still allocation:

```julia-repl
julia> U,V = randn(n,r), randn(n,r);

julia> W,S = randn(n,p), randn(n,p);

julia> A = BandedPlusSemiseparableQRPerturbedFactors(U,V,W,S,B);

julia> storage = zeros(T,A.n), UᵀU_lookup_table(A), ūw̄_sum_lookup_table(A), d_extra_lookup_table(A);

julia> @time bandedplussemi_qr!(A, storage...);
  0.000354 seconds (12.13 k allocations: 725.844 KiB)
```

I'll put some comments in the review so you can see where some of the issues are.
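To track the allocation count while the review comments are addressed, one common pattern is to call a function once (so compilation is not counted) and then measure a second call with `@allocated` inside a function. The two toy kernels below, `scale_into!` and `scaled`, are hypothetical stand-ins for the in-place and allocating styles discussed in the review; `bandedplussemi_qr!` could be measured the same way.

```julia
# Hypothetical toy kernels standing in for the two coding styles.
scale_into!(y, x, a) = (y .= a .* x; y)   # writes into a preallocated output
scaled(x, a) = a .* x                      # returns a new array on every call

# Warm up once so compilation is not counted, then measure a second call.
function allocs_inplace(y, x, a)
    scale_into!(y, x, a)
    return @allocated scale_into!(y, x, a)
end

function allocs_outofplace(x, a)
    scaled(x, a)
    return @allocated scaled(x, a)
end

x = randn(1000)
y = similar(x)
@show allocs_inplace(y, x, 2.0)
@show allocs_outofplace(x, 2.0)
```

The in-place call should report 0 bytes, while the out-of-place call reports at least the size of the new array.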
```julia
struct BandedPlusSemiseparableQRPerturbedFactors{T} <: LayoutMatrix{T}
    n::Int # matrix size
```

We don't need to store this as it can be inferred from `size(B,1)`.
```julia
    r::Int # lower rank
```

This can be inferred from `size(U,2)`.
```julia
    p::Int # upper rank
```

This can be inferred from `size(W,2)`.
```julia
    l::Int # lower bandwidth
```

`l` and `m` can be inferred from `bandwidths(B)`.
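A minimal sketch of the idea behind these four comments, assuming the factors are stored as plain matrices: keep only the arrays and derive the sizes on demand. The struct name `BPSFactorsSketch` and the accessors `matsize`, `lowerrank` and `upperrank` are hypothetical, and `B` is a dense `Matrix` here only to keep the example dependency-free; with a `BandedMatrix` from BandedMatrices.jl, `bandwidths(A.B)` returns the `(l, m)` pair in the same spirit.

```julia
# Hypothetical, simplified container: only the arrays are stored.
struct BPSFactorsSketch{T}
    U::Matrix{T}   # n × r lower-rank factor
    V::Matrix{T}   # n × r
    W::Matrix{T}   # n × p upper-rank factor
    S::Matrix{T}   # n × p
    B::Matrix{T}   # n × n banded part (dense here for simplicity)
end

# Quantities that used to be fields are derived from array sizes.
matsize(A::BPSFactorsSketch) = size(A.B, 1)
lowerrank(A::BPSFactorsSketch) = size(A.U, 2)
upperrank(A::BPSFactorsSketch) = size(A.W, 2)

A = BPSFactorsSketch(randn(6, 2), randn(6, 2), randn(6, 3), randn(6, 3), randn(6, 6))
(n, r) = size(A.U)   # both sizes from a single call
```

Dropping the stored copies also removes the risk of the integer fields drifting out of sync with the arrays.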
```julia
    j = A.j[]
    UᵀA_1 = (A.U[j+1:min(A.l+j+1,A.n),:])'*A.B[j+1:min(A.l+j+1,A.n),j+1] + UᵀU[j+2,:,:]*A.V[j+1,:] #the j+1th column of UᵀA

    k̄ = A.V[j+1,:] + A.Q * A.S[j+1,1:A.p] + A.K * UᵀA_1
```

This is allocating. You should use views.

And in all the code below.
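As a sketch of what "use views" could look like for the `k̄` line, the function below writes the same sum into a preallocated vector using `@views` and the five-argument `mul!(C, A, B, α, β)` from LinearAlgebra (Julia 1.3 and later), which computes `A*B*α + C*β` in place. The name `kbar!` and the shapes (`V` n×r, `Q` r×p, `S` with at least p columns, `K` r×r, `UtA1` of length r) are assumptions for illustration, not taken from the PR.

```julia
using LinearAlgebra

# Hypothetical in-place version of
#   k̄ = V[j+1,:] + Q * S[j+1,1:p] + K * UᵀA_1
# Assumed shapes: V n×r, Q r×p, S n×p, K r×r, UtA1 of length r.
function kbar!(kbar, V, Q, S, K, UtA1, j, p)
    T = eltype(kbar)
    @views copyto!(kbar, V[j+1, :])                    # kbar = V[j+1,:]
    @views mul!(kbar, Q, S[j+1, 1:p], one(T), one(T))  # kbar += Q * S[j+1,1:p]
    mul!(kbar, K, UtA1, one(T), one(T))                # kbar += K * UtA1
    return kbar
end

n, r, p, j = 8, 3, 2, 4
V, Q, S, K = randn(n, r), randn(r, p), randn(n, p), randn(r, r)
UtA1 = randn(r)
kbar = zeros(r)   # allocated once, reused across steps
kbar!(kbar, V, Q, S, K, UtA1, j, p)
```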
```julia
end

function update_next_submatrix!(A, k̄, b, τ, w̄₁, ū₁, d₁, f, d̄, c₁, c₂, c₃, c₄, c₅, c₆)
    Q_prev = A.Q[:,:]
```

These are all allocating. Why do we need them?
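If an old value of `A.Q` really is needed later in the step, one allocation-free option is to keep a buffer allocated once and overwrite it with `copyto!`, instead of `A.Q[:,:]`, which builds a new matrix on every call. The helper `save_prev!` and the buffer `Q_prev` below are hypothetical.

```julia
# `Q[:, :]` would allocate a new matrix on every call; copyto! reuses a buffer.
save_prev!(Q_prev, Q) = copyto!(Q_prev, Q)

# Warm up once, then count bytes allocated by a second call.
function allocs_save_prev(Q_prev, Q)
    save_prev!(Q_prev, Q)
    return @allocated save_prev!(Q_prev, Q)
end

Q = randn(4, 3)
Q_prev = similar(Q)   # allocated once, e.g. when the factorization object is built
@show allocs_save_prev(Q_prev, Q)
```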
```julia
ūw̄_sum = zeros(A.n, A.r, A.p)
ūw̄_sum_current = zeros(A.r, A.p)
for t in 1:A.n
    ūw̄_sum[t,:,:] = ūw̄_sum_current[:,:]
```

Why the `[:,:]`???
```julia
for t in 1:A.n
    ūw̄_sum[t,:,:] = ūw̄_sum_current[:,:]
    ūw̄_sum_current[:,:] += A.U[t,:] * (A.W[t,1:A.p])'
```

Why the `[:,:]`?
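In this loop, `ūw̄_sum_current[:,:]` on the right-hand side makes an extra copy, and `+=` builds a new matrix each iteration. A sketch of the same prefix-sum table written with `.=` and a fused broadcast, assuming `U` is n×r and only the first `p` columns of `W` are used (as in `A.W[t,1:A.p]`); the function name `uw_sum_table!` is hypothetical.

```julia
# Hypothetical allocation-free builder for the prefix-sum table:
#   table[t,:,:] = sum over s < t of U[s,:] * W[s,1:p]'
function uw_sum_table!(table, current, U, W, p)
    fill!(current, zero(eltype(current)))
    for t in axes(U, 1)
        table[t, :, :] .= current                  # copies into the slice, no temporary
        @views current .+= U[t, :] .* W[t, 1:p]'   # fused rank-one update, no temporary
    end
    return table
end

n, r, p = 5, 2, 3
U, W = randn(n, r), randn(n, p)
table = zeros(n, r, p)
current = zeros(r, p)
uw_sum_table!(table, current, U, W, p)
```

Since `table[t,:,:]` is a strided slice with the first index fixed, storing the table as r×p×n would make each slice contiguous; whether that helps depends on how the table is read later.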
```julia
function fast_UᵀA(U, V, W, S, B, j)
    # compute U[j:end,:]ᵀ*A[j:end,j:end] where A = tril(UV',-1) + B + triu(WS',1) in O(n)
    n = size(U,1)
    r = size(U,2)
```

Suggested change:

```diff
-    r = size(U,2)
+    (n,r) = size(U)
```
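While optimizing `fast_UᵀA`, a dense reference is handy for tests. The sketch below forms A = tril(UVᵀ,-1) + B + triu(WSᵀ,1) explicitly and computes U[j:end,:]ᵀ A[j:end,j:end] directly; it is quadratic in memory, so it is only meant for small test sizes. The names `dense_bps` and `reference_UtA` are hypothetical, and `B` is converted to a dense matrix.

```julia
using LinearAlgebra

# Dense form of the banded-plus-semiseparable matrix (testing only).
dense_bps(U, V, W, S, B) = tril(U * V', -1) + Matrix(B) + triu(W * S', 1)

# Direct dense reference for U[j:end,:]' * A[j:end,j:end].
function reference_UtA(U, V, W, S, B, j)
    A = dense_bps(U, V, W, S, B)
    return U[j:end, :]' * A[j:end, j:end]
end

n, r, p = 6, 2, 2
U, V, W, S = randn(n, r), randn(n, r), randn(n, p), randn(n, p)
B = diagm(0 => randn(n), 1 => randn(n - 1), -1 => randn(n - 1))   # tridiagonal B
A = dense_bps(U, V, W, S, B)
```

A test could then compare `fast_UᵀA(U, V, W, S, B, j)` against `reference_UtA(U, V, W, S, B, j)` with `≈` for several `j`.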
```julia
end

function bandedplussemi_qr!(A, τ, tables...)
    n = A.n
    n = size(A.B, 1)
    for i in 1 : n-1
        onestep_qr!(A, τ, tables...)
    end

    A.B[n,n] = A[n,n]
```

Add a comment explaining what this is doing.
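For orientation on the `1 : n-1` loop and the final `A.B[n,n] = A[n,n]`, here is a plain dense Householder QR with LAPACK-style reflectors and `τ`. It is not the PR's structured algorithm, and the names `householder_qr!` and `form_Q` are hypothetical; it only illustrates that an n×n matrix needs n-1 reflectors and that R[n,n] comes out of the last trailing update rather than from a reflector of its own, which is presumably what the line after the loop records.

```julia
using LinearAlgebra

# Dense Householder QR sketch (LAPACK-style storage): on return the upper
# triangle of A holds R, the strict lower triangle holds the reflector
# vectors v (with implicit leading 1), and τ holds the reflector scalars.
function householder_qr!(A::Matrix{Float64})
    n = size(A, 1)
    τ = zeros(n)                          # τ[n] stays 0: no reflector needed
    for j in 1:n-1
        x = view(A, j:n, j)
        α = x[1]
        σ = norm(x)
        σ == 0 && continue                # zero column: nothing to eliminate
        β = α >= 0 ? -σ : σ               # sign choice avoids cancellation
        τ[j] = (β - α) / β
        v = view(x, 2:length(x))
        v ./= (α - β)                     # store v below the diagonal
        x[1] = β                          # R[j,j]
        # Apply H = I - τ[j] * [1; v] * [1; v]' to the trailing columns.
        for k in j+1:n
            c = view(A, j:n, k)
            cv = view(c, 2:length(c))
            s = τ[j] * (c[1] + dot(v, cv))
            c[1] -= s
            cv .-= s .* v
        end
    end
    return τ
end

# Rebuild Q = H_1 H_2 ... H_(n-1) explicitly (testing only).
function form_Q(F::Matrix{Float64}, τ::Vector{Float64})
    n = size(F, 1)
    Q = Matrix{Float64}(I, n, n)
    for j in n-1:-1:1
        w = [1.0; F[j+1:n, j]]
        Q[j:n, :] .-= (τ[j] .* w) * (w' * Q[j:n, :])
    end
    return Q
end

A0 = [4.0 1.0 2.0; 2.0 3.0 1.0; 1.0 2.0 5.0]
F = copy(A0)
τ = householder_qr!(F)
Q, R = form_Q(F, τ), triu(F)
```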
```julia
Eₛ_prev = copy(A.Eₛ)
Xₛ_prev = copy(A.Xₛ)
Yₛ_prev = copy(A.Yₛ)
Zₛ_prev = copy(A.Zₛ)
```

Why do we need to make copies?
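If the old `Eₛ`, `Xₛ`, `Yₛ`, `Zₛ` are only needed to compute the new ones within a step, one common alternative to `copy` is double buffering: keep two preallocated arrays and swap which one is "current" each step, so nothing is copied. This needs the fields to be reassignable (a `mutable struct`, or the pair held in a mutable wrapper); the `DoubleBuffer` type, `advance!` and the toy update `relax!` below are hypothetical.

```julia
# Hypothetical double buffer: two preallocated arrays whose roles alternate.
mutable struct DoubleBuffer{M<:AbstractArray}
    cur::M    # state after the most recent step
    prev::M   # state before it
end

DoubleBuffer(x::AbstractArray) = DoubleBuffer(copy(x), similar(x))

# Swap roles (O(1), no copying), then write the new state into `cur`
# from the old state, which now lives in `prev`.
function advance!(update!, db::DoubleBuffer)
    db.cur, db.prev = db.prev, db.cur
    update!(db.cur, db.prev)
    return db
end

# Toy update rule: new = 0.5 * old + 1, written in place.
relax!(new, old) = (new .= 0.5 .* old .+ 1; new)

db = DoubleBuffer(zeros(2, 2))
for _ in 1:3
    advance!(relax!, db)
end
```

After three steps `db.cur` holds 1.75 and `db.prev` holds 1.5 in every entry, with no arrays allocated inside the loop.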
```julia
b[1:min(A.l-1,length(b))] += (A.Xₛ * A.S[j+1,1:A.p] + A.Yₛ * UᵀA_1 + A.Zₛ[:,1])[2:min(A.l,length(b)+1)]
b = A.B[j+2:min(l+j+1, n), j+1]
if l > 0 || m > 0
    b[1:min(l-1,length(b))] .+= view((A.Xₛ * view(A.S, j+1, 1:p) + A.Yₛ * UᵀA_1 + view(A.Zₛ, :, 1)), 2:min(l,length(b)+1))
```

Can this be rewritten as `muladd!` calls?

Suggested change:

```diff
-    b[1:min(l-1,length(b))] .+= view((A.Xₛ * view(A.S, j+1, 1:p) + A.Yₛ * UᵀA_1 + view(A.Zₛ, :, 1)), 2:min(l,length(b)+1))
+    kr = 2:min(l,length(b)+1)
+    v = view(b, kr .- 1)
+    T = eltype(A)
+    mul!(v, view(A.Xₛ, kr,:), view(A.S, j+1, 1:p), one(T), one(T))
+    mul!(v, view(A.Yₛ,kr,:), UᵀA_1, one(T), one(T))
+    v .+= view(A.Zₛ, kr, 1)
```
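The suggested rewrite relies on the identity (Xₛ s)[kr] = Xₛ[kr,:] s, so only the rows that are actually used get multiplied, and on the five-argument `mul!(C, A, B, α, β)` to accumulate into `b` in place. A quick self-contained check that both forms agree, with made-up sizes; `Xs`, `Ys`, `Zs`, `s` and `u` are hypothetical stand-ins for `A.Xₛ`, `A.Yₛ`, `A.Zₛ`, `view(A.S, j+1, 1:p)` and `UᵀA_1`.

```julia
using LinearAlgebra

# Allocating form from the PR (temporaries for each product and the sum).
function update_alloc!(b, Xs, Ys, Zs, s, u, l)
    b[1:min(l-1, length(b))] .+= view(Xs * s + Ys * u + view(Zs, :, 1), 2:min(l, length(b)+1))
    return b
end

# In-place form following the review suggestion.
function update_inplace!(b, Xs, Ys, Zs, s, u, l)
    kr = 2:min(l, length(b)+1)
    v = view(b, kr .- 1)
    T = eltype(b)
    mul!(v, view(Xs, kr, :), s, one(T), one(T))   # v += Xs[kr,:] * s
    mul!(v, view(Ys, kr, :), u, one(T), one(T))   # v += Ys[kr,:] * u
    v .+= view(Zs, kr, 1)                         # v += Zs[kr,1]
    return b
end

l = 4
b = randn(5)
Xs, Ys, Zs = randn(l, 2), randn(l, 3), randn(l, 2)
s, u = randn(2), randn(3)
```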
```julia
ūw̄_sum_current = zeros(eltype(A), r, p)
for t in 1:n
    ūw̄_sum[t,:,:] .= ūw̄_sum_current
    ūw̄_sum_current .+= view(A.U, t, :) * (view(A.W, t, 1:p))'
```

Suggested change:

```diff
-    ūw̄_sum_current .+= view(A.U, t, :) * (view(A.W, t, 1:p))'
+    mul!(ūw̄_sum_current, view(A.U, t, :), view(A.W, t, 1:p)', one(eltype(A)), one(eltype(A)))
```
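For BLAS element types such as `Float64`, another allocation-free route for this rank-one update is the level-2 BLAS routine `LinearAlgebra.BLAS.ger!(α, x, y, C)`, which overwrites `C` with α·x·y' + C. The wrapper name `rank1_update!` and the sizes below are hypothetical; the check compares it with the allocating outer product.

```julia
using LinearAlgebra

# Hypothetical wrapper: C .+= U[t,:] * W[t,1:p]' without a temporary matrix.
function rank1_update!(C, U, W, t, p)
    LinearAlgebra.BLAS.ger!(one(eltype(C)), view(U, t, :), view(W, t, 1:p), C)
    return C
end

U, W = randn(6, 3), randn(6, 4)
C = randn(3, 2)
C0 = copy(C)
rank1_update!(C, U, W, 5, 2)
```

The fused broadcast `C .+= U[t,:] .* W[t,1:p]'` (under `@views`) gives the same result for any element type, so it is a reasonable fallback when the data are not BLAS types.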
No description provided.