@@ -189,6 +189,23 @@ function default_blasmul!(α, A::AbstractMatrix, B::AbstractMatrix, β, C::Abstr
189189 C
190190end
191191
192+ function default_blasmul! (α, A:: AbstractVector , B:: AbstractMatrix , β, C:: AbstractMatrix )
193+ mA, = size (A)
194+ mB, nB = size (B)
195+ 1 == mB || throw (DimensionMismatch (" Dimensions must match" ))
196+ size (C) == (mA, nB) || throw (DimensionMismatch (" Dimensions must match" ))
197+
198+ lmul! (β, C)
199+
200+ (iszero (mA) || iszero (nB)) && return C
201+
202+ for k in colsupport (A), j in rowsupport (B)
203+ _default_blasmul_loop! (α, A, B, β, C, k, j)
204+ end
205+ C
206+ end
207+
208+
192209function _default_blasmul! (:: IndexLinear , α, A:: AbstractMatrix , B:: AbstractVector , β, C:: AbstractVector )
193210 mA, nA = size (A)
194211 mB = length (B)
@@ -266,6 +283,11 @@ function materialize!(M::MatMulVecAdd)
266283 default_blasmul! (α, unalias (C,A), unalias (C,B), iszero (β) ? false : β, C)
267284end
268285
286+ function materialize! (M:: VecMulMatAdd )
287+ α, A, B, β, C = M. α, M. A, M. B, M. β, M. C
288+ default_blasmul! (α, unalias (C,A), unalias (C,B), iszero (β) ? false : β, C)
289+ end
290+
269291@inline _gemv! (tA, α, A, x, β, y) = BLAS. gemv! (tA, α, unalias (y,A), unalias (y,x), β, y)
270292@inline _gemm! (tA, tB, α, A, B, β, C) = BLAS. gemm! (tA, tB, α, unalias (C,A), unalias (C,B), β, C)
271293
@@ -424,8 +446,7 @@ function similar(M::MulAdd{<:DualLayout,<:Any,ZerosLayout}, ::Type{T}, (x,y)) wh
424446 trans (similar (trans (M. A), T, y))
425447end
426448
427- function similar (M:: MulAdd{<:Any,<:DualLayout,ZerosLayout} , :: Type{T} , (x,y)) where T
428- @assert length (x) == 1
449+ function similar (M:: MulAdd{ScalarLayout,<:DualLayout,ZerosLayout} , :: Type{T} , (x,y)) where T
429450 trans = transtype (M. B)
430451 trans (similar (trans (M. B), T, y))
431452end
@@ -434,3 +455,4 @@ const ZerosLayouts = Union{ZerosLayout,DualLayout{ZerosLayout}}
434455copy (M:: MulAdd{<:ZerosLayouts, <:ZerosLayouts, <:ZerosLayouts} ) = M. C
435456copy (M:: MulAdd{<:ZerosLayouts, <:Any, <:ZerosLayouts} ) = M. C
436457copy (M:: MulAdd{<:Any, <:ZerosLayouts, <:ZerosLayouts} ) = M. C
458+
0 commit comments