|
1 | 1 | ### This supports BLAS-style multiplication |
2 | | -# α * A * B + β C |
| 2 | +# A * B * α + C * β |
3 | 3 | # but avoids the broadcast machinery |
4 | 4 |
|
5 | | -# Lazy representation of α*A*B + β*C |
| 5 | +# Lazy representation of A*B*α + C*β |
6 | 6 | struct MulAdd{StyleA, StyleB, StyleC, T, AA, BB, CC} |
7 | 7 | α::T |
8 | 8 | A::AA |
9 | 9 | B::BB |
10 | 10 | β::T |
11 | 11 | C::CC |
| 12 | + Czero::Bool # this flag indicates whether C isa Zeros, or a copy of one |
| 13 | + # the idea is that if Czero == true, then downstream packages don't need to |
| 14 | + # fill C with zeros before performing the muladd |
12 | 15 | end |
13 | 16 |
|
14 | | -@inline MulAdd{StyleA,StyleB,StyleC}(α::T, A::AA, B::BB, β::T, C::CC) where {StyleA,StyleB,StyleC,T,AA,BB,CC} = |
15 | | - MulAdd{StyleA,StyleB,StyleC,T,AA,BB,CC}(α,A,B,β,C) |
| 17 | +@inline function MulAdd{StyleA,StyleB,StyleC}(α::T, A::AA, B::BB, β::T, C::CC; |
| 18 | + Czero = C isa Zeros) where {StyleA,StyleB,StyleC,T,AA,BB,CC} |
| 19 | + MulAdd{StyleA,StyleB,StyleC,T,AA,BB,CC}(α,A,B,β,C,Czero) |
| 20 | +end |
16 | 21 |
|
17 | | -@inline function MulAdd{StyleA,StyleB,StyleC}(αT, A, B, βV, C) where {StyleA,StyleB,StyleC} |
| 22 | +@inline function MulAdd{StyleA,StyleB,StyleC}(αT, A, B, βV, C; kw...) where {StyleA,StyleB,StyleC} |
18 | 23 | α,β = promote(αT,βV) |
19 | | - MulAdd{StyleA,StyleB,StyleC}(α, A, B, β, C) |
| 24 | + MulAdd{StyleA,StyleB,StyleC}(α, A, B, β, C; kw...) |
20 | 25 | end |
21 | 26 |
|
22 | | -@inline MulAdd(α, A::AA, B::BB, β, C::CC) where {AA,BB,CC} = |
23 | | - MulAdd{typeof(MemoryLayout(AA)), typeof(MemoryLayout(BB)), typeof(MemoryLayout(CC))}(α, A, B, β, C) |
| 27 | +@inline MulAdd(α, A::AA, B::BB, β, C::CC; kw...) where {AA,BB,CC} = |
| 28 | + MulAdd{typeof(MemoryLayout(AA)), typeof(MemoryLayout(BB)), typeof(MemoryLayout(CC))}(α, A, B, β, C; kw...) |
24 | 29 |
|
25 | 30 | MulAdd(A, B) = MulAdd(Mul(A, B)) |
26 | 31 | function MulAdd(M::Mul) |
@@ -67,15 +72,15 @@ const BlasMatMulVecAdd{StyleA,StyleB,StyleC,T<:BlasFloat} = MulAdd{StyleA,StyleB |
67 | 72 | const BlasMatMulMatAdd{StyleA,StyleB,StyleC,T<:BlasFloat} = MulAdd{StyleA,StyleB,StyleC,T,<:AbstractMatrix{T},<:AbstractMatrix{T},<:AbstractMatrix{T}} |
68 | 73 | const BlasVecMulMatAdd{StyleA,StyleB,StyleC,T<:BlasFloat} = MulAdd{StyleA,StyleB,StyleC,T,<:AbstractVector{T},<:AbstractMatrix{T},<:AbstractMatrix{T}} |
69 | 74 |
|
70 | | -muladd!(α, A, B, β, C) = materialize!(MulAdd(α, A, B, β, C)) |
| 75 | +muladd!(α, A, B, β, C; kw...) = materialize!(MulAdd(α, A, B, β, C; kw...)) |
71 | 76 | materialize(M::MulAdd) = copy(instantiate(M)) |
72 | 77 | copy(M::MulAdd) = copyto!(similar(M), M) |
73 | 78 |
|
74 | 79 | _fill_copyto!(dest, C) = copyto!(dest, C) |
75 | 80 | _fill_copyto!(dest, C::Zeros) = zero!(dest) # exploit special fill! overload |
76 | 81 |
|
77 | 82 | @inline copyto!(dest::AbstractArray{T}, M::MulAdd) where T = |
78 | | - muladd!(M.α, unalias(dest,M.A), unalias(dest,M.B), M.β, _fill_copyto!(dest, M.C)) |
| 83 | + muladd!(M.α, unalias(dest,M.A), unalias(dest,M.B), M.β, _fill_copyto!(dest, M.C); Czero = M.Czero) |
79 | 84 |
|
80 | 85 | # Modified from LinearAlgebra._generic_matmatmul! |
81 | 86 | const tilebufsize = 10800 # Approximately 32k/3 |
|
0 commit comments