1- using ArrayLayouts, FillArrays, Random, StableRNGs, LinearAlgebra, Test
1+ using ArrayLayouts, FillArrays, Random, StableRNGs, LinearAlgebra, Test, Quaternions
22using ArrayLayouts: DenseColumnMajor, AbstractStridedLayout, AbstractColumnMajor, DiagonalLayout, mul, Mul, zero!
33
44Random. seed! (0 )
@@ -89,6 +89,23 @@ Random.seed!(0)
8989 @test mul (A,X) == A* X
9090 @test mul (X,A) == X* A
9191 end
92+
93+ @testset " Diagonal Fill" begin
94+ for (A, B) in (([1 : 4 ;], [3 : 6 ;]), (reshape ([1 : 16 ;],4 ,4 ), reshape (2 .* [1 : 16 ;],4 ,4 )))
95+ D = Diagonal (Fill (3 , 4 ))
96+ M = MulAdd (2 , D, A, 3 , B)
97+ @test copy (M) == mul! (B, D, A, 2 , 3 )
98+ M = MulAdd (1 , D, A, 0 , B)
99+ @test copy (M) == mul! (B, D, A)
100+ end
101+
102+ A, B = [1 : 4 ;], reshape ([3 : 6 ;], 4 , 1 )
103+ D = Diagonal (Fill (3 , 1 ))
104+ M = MulAdd (2 , A, D, 3 , B)
105+ @test copy (M) == (VERSION >= v " 1.9" ? mul! (B, A, D, 2 , 3 ) : 2 * A * D + 3 * B)
106+ M = MulAdd (1 , A, D, 0 , B)
107+ @test copy (M) == (VERSION >= v " 1.9" ? mul! (B, A, D) : A * D)
108+ end
92109 end
93110
94111 @testset " Matrix * Matrix" begin
@@ -98,17 +115,28 @@ Random.seed!(0)
98115 B in (randn (5 ,5 ), view (randn (5 ,5 ),:,:), view (randn (5 ,5 ),1 : 5 ,:),
99116 view (randn (5 ,5 ),1 : 5 ,1 : 5 ), view (randn (5 ,5 ),:,1 : 5 ))
100117 C = similar (B);
118+ D = similar (C);
101119
102120 C .= MulAdd (1.0 ,A,B,0.0 ,C)
103- @test C == BLAS. gemm! (' N' , ' N' , 1.0 , A, B, 0.0 , similar (C) )
121+ @test C == BLAS. gemm! (' N' , ' N' , 1.0 , A, B, 0.0 , D )
104122
105123 C .= MulAdd (2.0 ,A,B,0.0 ,C)
106- @test C == BLAS. gemm! (' N' , ' N' , 2.0 , A, B, 0.0 , similar (C) )
124+ @test C == BLAS. gemm! (' N' , ' N' , 2.0 , A, B, 0.0 , D )
107125
108126 C = copy (B)
109127 C .= MulAdd (2.0 ,A,B,1.0 ,C)
110128 @test C == BLAS. gemm! (' N' , ' N' , 2.0 , A, B, 1.0 , copy (B))
111129 end
130+
131+ A, B = ones (100 , 100 ), ones (100 , 100 )
132+ C = ones (100 , 100 )
133+ C .= MulAdd (2 ,A,B,1 ,C)
134+ @test C ≈ BLAS. gemm! (' N' , ' N' , 2.0 , A, B, 1.0 , copy (B))
135+
136+ A, B = Float64[i+ j for i in 1 : 100 , j in 1 : 100 ], Float64[i+ j for i in 1 : 100 , j in 1 : 100 ]
137+ C = ones (100 , 100 )
138+ C .= MulAdd (2 ,A,B,1 ,C)
139+ @test_broken C ≈ BLAS. gemm! (' N' , ' N' , 2.0 , A, B, 1.0 , copy (B))
112140 end
113141
114142 @testset " gemm Complex" begin
@@ -276,7 +304,8 @@ Random.seed!(0)
276304 vx = view (x,1 : 2 )
277305 vy = view (y,:)
278306 muladd! (2.0 , VA, vx, 3.0 , vy)
279- @test @allocated (muladd! (2.0 , VA, vx, 3.0 , vy)) == 0
307+ # spurious allocations in tests
308+ @test @allocated (muladd! (2.0 , VA, vx, 3.0 , vy)) < 100
280309 end
281310
282311 @testset " BigFloat" begin
@@ -680,7 +709,7 @@ Random.seed!(0)
680709 b = randn (5 )
681710 c = randn (5 ) + im* randn (5 )
682711 d = randn (5 ) + im* randn (5 )
683-
712+
684713 @test ArrayLayouts. dot (a,b) ≈ ArrayLayouts. dotu (a,b) ≈ mul (a' ,b)
685714 @test ArrayLayouts. dot (a,b) ≈ dot (a,b)
686715 @test eltype (Dot (a,1 : 5 )) == Float64
@@ -693,7 +722,7 @@ Random.seed!(0)
693722 @test ArrayLayouts. dot (c,b) == mul (c' ,b)
694723 @test ArrayLayouts. dotu (c,b) == mul (transpose (c),b)
695724 @test ArrayLayouts. dot (c,b) ≈ dot (c,b)
696-
725+
697726 @test ArrayLayouts. dot (a,d) == mul (a' ,d)
698727 @test ArrayLayouts. dotu (a,d) == mul (transpose (a),d)
699728 @test ArrayLayouts. dot (a,d) ≈ dot (a,d)
@@ -730,9 +759,88 @@ Random.seed!(0)
730759 X = randn (rng, ComplexF64, 8 , 4 )
731760 Y = randn (rng, 8 , 2 )
732761 @test mul (Y' ,X) ≈ Y' X
762+
763+ for A in (randn (5 ,5 ), view (randn (5 ,5 ),:,:), view (randn (5 ,5 ),1 : 5 ,:),
764+ view (randn (5 ,5 ),1 : 5 ,1 : 5 ), view (randn (5 ,5 ),:,1 : 5 )),
765+ B in (randn (5 ,5 ), view (randn (5 ,5 ),:,:), view (randn (5 ,5 ),1 : 5 ,:),
766+ view (randn (5 ,5 ),1 : 5 ,1 : 5 ), view (randn (5 ,5 ),:,1 : 5 ))
767+ C = similar (B);
768+ D = similar (C);
769+
770+ C .= MulAdd (1 ,A,B,0 ,C)
771+ @test C ≈ BLAS. gemm! (' N' , ' N' , 1.0 , A, B, 0.0 , D)
772+
773+ C = copy (B)
774+ C .= MulAdd (2 ,A,B,1 ,C)
775+ @test C ≈ BLAS. gemm! (' N' , ' N' , 2.0 , A, B, 1.0 , copy (B))
776+ end
733777 end
734778
735779 @testset " Vec * Adj" begin
736780 @test ArrayLayouts. mul (1 : 5 , (1 : 4 )' ) == (1 : 5 ) * (1 : 4 )'
737781 end
782+
783+ @testset " Fill" begin
784+ mutable struct MFillMat{T} <: FillArrays.AbstractFill{T,2,NTuple{2,Base.OneTo{Int}}}
785+ x :: T
786+ sz :: NTuple{2,Int}
787+ end
788+ MFillMat (x:: T , sz:: NTuple{2,Int} ) where {T} = MFillMat {T} (x, sz)
789+ MFillMat (x:: T , sz:: Vararg{Int,2} ) where {T} = MFillMat {T} (x, sz)
790+ Base. size (M:: MFillMat ) = M. sz
791+ FillArrays. getindex_value (M:: MFillMat ) = M. x
792+ Base. copyto! (M:: MFillMat , A:: Broadcast.Broadcasted ) = (M. x = only (unique (A)); M)
793+ Base. copyto! (M:: MFillMat , A:: Broadcast.Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}} ) = (M. x = only (unique (A)); M)
794+
795+ M = MulAdd (1 , Fill (2 ,4 ,4 ), Fill (3 ,4 ,4 ), 2 , MFillMat (2 ,4 ,4 ))
796+ X = copy (M)
797+ @test X == Fill (28 ,4 ,4 )
798+
799+ M = MulAdd (1 , Fill (2 ,4 ,4 ), Fill (3 ,4 ,4 ), 0 , MFillMat (2 ,4 ,4 ))
800+ X = copy (M)
801+ @test X == Fill (24 ,4 ,4 )
802+ end
803+
804+ @testset " non-commutative" begin
805+ A = [quat (rand (4 )... ) for i in 1 : 4 , j in 1 : 4 ]
806+ B = [quat (rand (4 )... ) for i in 1 : 4 , j in 1 : 4 ]
807+ C = [quat (rand (4 )... ) for i in 1 : 4 , j in 1 : 4 ]
808+ α, β = quat (0 ,0 ,0 ,1 ), quat (0 ,1 ,0 ,0 )
809+ M = MulAdd (α, A, B, β, C)
810+ @test copy (M) ≈ mul! (copy (C), A, B, α, β) ≈ A * B * α + C * β
811+
812+ SA = Symmetric (A)
813+ M = MulAdd (α, SA, B, β, C)
814+ @test copy (M) ≈ mul! (copy (C), SA, B, α, β) ≈ SA * B * α + C * β
815+
816+ B = [quat (rand (4 )... ) for i in 1 : 4 ]
817+ C = [quat (rand (4 )... ) for i in 1 : 4 ]
818+ M = MulAdd (α, A, B, β, C)
819+ @test copy (M) ≈ mul! (copy (C), A, B, α, β) ≈ A * B * α + C * β
820+
821+ M = MulAdd (α, SA, B, β, C)
822+ @test copy (M) ≈ mul! (copy (C), SA, B, α, β) ≈ SA * B * α + C * β
823+
824+ A = [quat (rand (4 )... ) for i in 1 : 4 ]
825+ B = [quat (rand (4 )... ) for i in 1 : 1 , j in 1 : 1 ]
826+ C = [quat (rand (4 )... ) for i in 1 : 4 , j in 1 : 1 ]
827+ M = MulAdd (α, A, B, β, C)
828+ @test copy (M) ≈ mul! (copy (C), A, B, α, β) ≈ A * B * α + C * β
829+
830+ D = Diagonal (Fill (quat (rand (4 )... ), 4 ))
831+ b = [quat (rand (4 )... ) for i in 1 : 4 ]
832+ c = [quat (rand (4 )... ) for i in 1 : 4 ]
833+ M = MulAdd (α, D, b, β, c)
834+ @test copy (M) ≈ mul! (copy (c), D, b, α, β) ≈ D * b * α + c * β
835+
836+ D = Diagonal (Fill (quat (rand (4 )... ), 1 ))
837+ b = [quat (rand (4 )... ) for i in 1 : 4 ]
838+ c = [quat (rand (4 )... ) for i in 1 : 4 , j in 1 : 1 ]
839+ M = MulAdd (α, b, D, β, c)
840+ if VERSION >= v " 1.9"
841+ @test copy (M) ≈ mul! (copy (c), b, D, α, β) ≈ b * D * α + c * β
842+ else
843+ @test copy (M) ≈ b * D * α + c * β
844+ end
845+ end
738846end
0 commit comments