Skip to content

Commit 6f21d7e

Browse files
committed
Fix gemm operation
Fixes #7
1 parent 0053ea3 commit 6f21d7e

File tree

4 files changed

+222
-26
lines changed

4 files changed

+222
-26
lines changed

blas/src/main/java/dev/ludovic/netlib/blas/Java11BLAS.java

+12-12
Original file line numberDiff line numberDiff line change
@@ -1407,7 +1407,7 @@ protected void sgemmNN(int m, int n, int k, float alpha, float[] a, int offseta,
14071407
sum02 = Math.fma(a00, b02, sum02);
14081408
sum10 = Math.fma(a10, b00, sum10);
14091409
sum11 = Math.fma(a10, b01, sum11);
1410-
sum11 = Math.fma(a10, b02, sum12);
1410+
sum12 = Math.fma(a10, b02, sum12);
14111411
sum20 = Math.fma(a20, b00, sum20);
14121412
sum21 = Math.fma(a20, b01, sum21);
14131413
sum22 = Math.fma(a20, b02, sum22);
@@ -1422,7 +1422,7 @@ protected void sgemmNN(int m, int n, int k, float alpha, float[] a, int offseta,
14221422
sum02 = Math.fma(a01, b12, sum02);
14231423
sum10 = Math.fma(a11, b10, sum10);
14241424
sum11 = Math.fma(a11, b11, sum11);
1425-
sum11 = Math.fma(a11, b12, sum12);
1425+
sum12 = Math.fma(a11, b12, sum12);
14261426
sum20 = Math.fma(a21, b10, sum20);
14271427
sum21 = Math.fma(a21, b11, sum21);
14281428
sum22 = Math.fma(a21, b12, sum22);
@@ -1439,7 +1439,7 @@ protected void sgemmNN(int m, int n, int k, float alpha, float[] a, int offseta,
14391439
sum02 = Math.fma(a00, b02, sum02);
14401440
sum10 = Math.fma(a10, b00, sum10);
14411441
sum11 = Math.fma(a10, b01, sum11);
1442-
sum11 = Math.fma(a10, b02, sum12);
1442+
sum12 = Math.fma(a10, b02, sum12);
14431443
sum20 = Math.fma(a20, b00, sum20);
14441444
sum21 = Math.fma(a20, b01, sum21);
14451445
sum22 = Math.fma(a20, b02, sum22);
@@ -1602,7 +1602,7 @@ protected void sgemmNT(int m, int n, int k, float alpha, float[] a, int offseta,
16021602
sum02 = Math.fma(a00, b02, sum02);
16031603
sum10 = Math.fma(a10, b00, sum10);
16041604
sum11 = Math.fma(a10, b01, sum11);
1605-
sum11 = Math.fma(a10, b02, sum12);
1605+
sum12 = Math.fma(a10, b02, sum12);
16061606
sum20 = Math.fma(a20, b00, sum20);
16071607
sum21 = Math.fma(a20, b01, sum21);
16081608
sum22 = Math.fma(a20, b02, sum22);
@@ -1617,7 +1617,7 @@ protected void sgemmNT(int m, int n, int k, float alpha, float[] a, int offseta,
16171617
sum02 = Math.fma(a01, b12, sum02);
16181618
sum10 = Math.fma(a11, b10, sum10);
16191619
sum11 = Math.fma(a11, b11, sum11);
1620-
sum11 = Math.fma(a11, b12, sum12);
1620+
sum12 = Math.fma(a11, b12, sum12);
16211621
sum20 = Math.fma(a21, b10, sum20);
16221622
sum21 = Math.fma(a21, b11, sum21);
16231623
sum22 = Math.fma(a21, b12, sum22);
@@ -1634,7 +1634,7 @@ protected void sgemmNT(int m, int n, int k, float alpha, float[] a, int offseta,
16341634
sum02 = Math.fma(a00, b02, sum02);
16351635
sum10 = Math.fma(a10, b00, sum10);
16361636
sum11 = Math.fma(a10, b01, sum11);
1637-
sum11 = Math.fma(a10, b02, sum12);
1637+
sum12 = Math.fma(a10, b02, sum12);
16381638
sum20 = Math.fma(a20, b00, sum20);
16391639
sum21 = Math.fma(a20, b01, sum21);
16401640
sum22 = Math.fma(a20, b02, sum22);
@@ -1798,7 +1798,7 @@ protected void sgemmTN(int m, int n, int k, float alpha, float[] a, int offseta,
17981798
sum02 = Math.fma(a00, b02, sum02);
17991799
sum10 = Math.fma(a10, b00, sum10);
18001800
sum11 = Math.fma(a10, b01, sum11);
1801-
sum11 = Math.fma(a10, b02, sum12);
1801+
sum12 = Math.fma(a10, b02, sum12);
18021802
sum20 = Math.fma(a20, b00, sum20);
18031803
sum21 = Math.fma(a20, b01, sum21);
18041804
sum22 = Math.fma(a20, b02, sum22);
@@ -1813,7 +1813,7 @@ protected void sgemmTN(int m, int n, int k, float alpha, float[] a, int offseta,
18131813
sum02 = Math.fma(a01, b12, sum02);
18141814
sum10 = Math.fma(a11, b10, sum10);
18151815
sum11 = Math.fma(a11, b11, sum11);
1816-
sum11 = Math.fma(a11, b12, sum12);
1816+
sum12 = Math.fma(a11, b12, sum12);
18171817
sum20 = Math.fma(a21, b10, sum20);
18181818
sum21 = Math.fma(a21, b11, sum21);
18191819
sum22 = Math.fma(a21, b12, sum22);
@@ -1830,7 +1830,7 @@ protected void sgemmTN(int m, int n, int k, float alpha, float[] a, int offseta,
18301830
sum02 = Math.fma(a00, b02, sum02);
18311831
sum10 = Math.fma(a10, b00, sum10);
18321832
sum11 = Math.fma(a10, b01, sum11);
1833-
sum11 = Math.fma(a10, b02, sum12);
1833+
sum12 = Math.fma(a10, b02, sum12);
18341834
sum20 = Math.fma(a20, b00, sum20);
18351835
sum21 = Math.fma(a20, b01, sum21);
18361836
sum22 = Math.fma(a20, b02, sum22);
@@ -1994,7 +1994,7 @@ protected void sgemmTT(int m, int n, int k, float alpha, float[] a, int offseta,
19941994
sum02 = Math.fma(a00, b02, sum02);
19951995
sum10 = Math.fma(a10, b00, sum10);
19961996
sum11 = Math.fma(a10, b01, sum11);
1997-
sum11 = Math.fma(a10, b02, sum12);
1997+
sum12 = Math.fma(a10, b02, sum12);
19981998
sum20 = Math.fma(a20, b00, sum20);
19991999
sum21 = Math.fma(a20, b01, sum21);
20002000
sum22 = Math.fma(a20, b02, sum22);
@@ -2009,7 +2009,7 @@ protected void sgemmTT(int m, int n, int k, float alpha, float[] a, int offseta,
20092009
sum02 = Math.fma(a01, b12, sum02);
20102010
sum10 = Math.fma(a11, b10, sum10);
20112011
sum11 = Math.fma(a11, b11, sum11);
2012-
sum11 = Math.fma(a11, b12, sum12);
2012+
sum12 = Math.fma(a11, b12, sum12);
20132013
sum20 = Math.fma(a21, b10, sum20);
20142014
sum21 = Math.fma(a21, b11, sum21);
20152015
sum22 = Math.fma(a21, b12, sum22);
@@ -2026,7 +2026,7 @@ protected void sgemmTT(int m, int n, int k, float alpha, float[] a, int offseta,
20262026
sum02 = Math.fma(a00, b02, sum02);
20272027
sum10 = Math.fma(a10, b00, sum10);
20282028
sum11 = Math.fma(a10, b01, sum11);
2029-
sum11 = Math.fma(a10, b02, sum12);
2029+
sum12 = Math.fma(a10, b02, sum12);
20302030
sum20 = Math.fma(a20, b00, sum20);
20312031
sum21 = Math.fma(a20, b01, sum21);
20322032
sum22 = Math.fma(a20, b02, sum22);

blas/src/main/java/dev/ludovic/netlib/blas/Java8BLAS.java

+12-12
Original file line numberDiff line numberDiff line change
@@ -1890,7 +1890,7 @@ protected void sgemmNN(int m, int n, int k, float alpha, float[] a, int offseta,
18901890
sum02 = a00 * b02 + sum02;
18911891
sum10 = a10 * b00 + sum10;
18921892
sum11 = a10 * b01 + sum11;
1893-
sum11 = a10 * b02 + sum12;
1893+
sum12 = a10 * b02 + sum12;
18941894
sum20 = a20 * b00 + sum20;
18951895
sum21 = a20 * b01 + sum21;
18961896
sum22 = a20 * b02 + sum22;
@@ -1905,7 +1905,7 @@ protected void sgemmNN(int m, int n, int k, float alpha, float[] a, int offseta,
19051905
sum02 = a01 * b12 + sum02;
19061906
sum10 = a11 * b10 + sum10;
19071907
sum11 = a11 * b11 + sum11;
1908-
sum11 = a11 * b12 + sum12;
1908+
sum12 = a11 * b12 + sum12;
19091909
sum20 = a21 * b10 + sum20;
19101910
sum21 = a21 * b11 + sum21;
19111911
sum22 = a21 * b12 + sum22;
@@ -1922,7 +1922,7 @@ protected void sgemmNN(int m, int n, int k, float alpha, float[] a, int offseta,
19221922
sum02 = a00 * b02 + sum02;
19231923
sum10 = a10 * b00 + sum10;
19241924
sum11 = a10 * b01 + sum11;
1925-
sum11 = a10 * b02 + sum12;
1925+
sum12 = a10 * b02 + sum12;
19261926
sum20 = a20 * b00 + sum20;
19271927
sum21 = a20 * b01 + sum21;
19281928
sum22 = a20 * b02 + sum22;
@@ -2085,7 +2085,7 @@ protected void sgemmNT(int m, int n, int k, float alpha, float[] a, int offseta,
20852085
sum02 = a00 * b02 + sum02;
20862086
sum10 = a10 * b00 + sum10;
20872087
sum11 = a10 * b01 + sum11;
2088-
sum11 = a10 * b02 + sum12;
2088+
sum12 = a10 * b02 + sum12;
20892089
sum20 = a20 * b00 + sum20;
20902090
sum21 = a20 * b01 + sum21;
20912091
sum22 = a20 * b02 + sum22;
@@ -2100,7 +2100,7 @@ protected void sgemmNT(int m, int n, int k, float alpha, float[] a, int offseta,
21002100
sum02 = a01 * b12 + sum02;
21012101
sum10 = a11 * b10 + sum10;
21022102
sum11 = a11 * b11 + sum11;
2103-
sum11 = a11 * b12 + sum12;
2103+
sum12 = a11 * b12 + sum12;
21042104
sum20 = a21 * b10 + sum20;
21052105
sum21 = a21 * b11 + sum21;
21062106
sum22 = a21 * b12 + sum22;
@@ -2117,7 +2117,7 @@ protected void sgemmNT(int m, int n, int k, float alpha, float[] a, int offseta,
21172117
sum02 = a00 * b02 + sum02;
21182118
sum10 = a10 * b00 + sum10;
21192119
sum11 = a10 * b01 + sum11;
2120-
sum11 = a10 * b02 + sum12;
2120+
sum12 = a10 * b02 + sum12;
21212121
sum20 = a20 * b00 + sum20;
21222122
sum21 = a20 * b01 + sum21;
21232123
sum22 = a20 * b02 + sum22;
@@ -2281,7 +2281,7 @@ protected void sgemmTN(int m, int n, int k, float alpha, float[] a, int offseta,
22812281
sum02 = a00 * b02 + sum02;
22822282
sum10 = a10 * b00 + sum10;
22832283
sum11 = a10 * b01 + sum11;
2284-
sum11 = a10 * b02 + sum12;
2284+
sum12 = a10 * b02 + sum12;
22852285
sum20 = a20 * b00 + sum20;
22862286
sum21 = a20 * b01 + sum21;
22872287
sum22 = a20 * b02 + sum22;
@@ -2296,7 +2296,7 @@ protected void sgemmTN(int m, int n, int k, float alpha, float[] a, int offseta,
22962296
sum02 = a01 * b12 + sum02;
22972297
sum10 = a11 * b10 + sum10;
22982298
sum11 = a11 * b11 + sum11;
2299-
sum11 = a11 * b12 + sum12;
2299+
sum12 = a11 * b12 + sum12;
23002300
sum20 = a21 * b10 + sum20;
23012301
sum21 = a21 * b11 + sum21;
23022302
sum22 = a21 * b12 + sum22;
@@ -2313,7 +2313,7 @@ protected void sgemmTN(int m, int n, int k, float alpha, float[] a, int offseta,
23132313
sum02 = a00 * b02 + sum02;
23142314
sum10 = a10 * b00 + sum10;
23152315
sum11 = a10 * b01 + sum11;
2316-
sum11 = a10 * b02 + sum12;
2316+
sum12 = a10 * b02 + sum12;
23172317
sum20 = a20 * b00 + sum20;
23182318
sum21 = a20 * b01 + sum21;
23192319
sum22 = a20 * b02 + sum22;
@@ -2477,7 +2477,7 @@ protected void sgemmTT(int m, int n, int k, float alpha, float[] a, int offseta,
24772477
sum02 = a00 * b02 + sum02;
24782478
sum10 = a10 * b00 + sum10;
24792479
sum11 = a10 * b01 + sum11;
2480-
sum11 = a10 * b02 + sum12;
2480+
sum12 = a10 * b02 + sum12;
24812481
sum20 = a20 * b00 + sum20;
24822482
sum21 = a20 * b01 + sum21;
24832483
sum22 = a20 * b02 + sum22;
@@ -2492,7 +2492,7 @@ protected void sgemmTT(int m, int n, int k, float alpha, float[] a, int offseta,
24922492
sum02 = a01 * b12 + sum02;
24932493
sum10 = a11 * b10 + sum10;
24942494
sum11 = a11 * b11 + sum11;
2495-
sum11 = a11 * b12 + sum12;
2495+
sum12 = a11 * b12 + sum12;
24962496
sum20 = a21 * b10 + sum20;
24972497
sum21 = a21 * b11 + sum21;
24982498
sum22 = a21 * b12 + sum22;
@@ -2509,7 +2509,7 @@ protected void sgemmTT(int m, int n, int k, float alpha, float[] a, int offseta,
25092509
sum02 = a00 * b02 + sum02;
25102510
sum10 = a10 * b00 + sum10;
25112511
sum11 = a10 * b01 + sum11;
2512-
sum11 = a10 * b02 + sum12;
2512+
sum12 = a10 * b02 + sum12;
25132513
sum20 = a20 * b00 + sum20;
25142514
sum21 = a20 * b01 + sum21;
25152515
sum22 = a20 * b02 + sum22;

blas/src/test/java/dev/ludovic/netlib/blas/DgemmTest.java

+102-2
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,108 @@ void testSanity(BLAS blas) {
8181
blas.dgemm("T", "N", M, N, K, 0.0, dgeAT, K, dgeB, K, 1.0, dgeCcopy = dgeC.clone(), M);
8282
assertArrayEquals(expected, dgeCcopy, depsilon);
8383

84-
f2j.dgemm("T", "T", M, N, K, 0.0, dgeAT, K, dgeBT, N, 1.0, expected = dgeC.clone(), M);
85-
blas.dgemm("T", "T", M, N, K, 0.0, dgeAT, K, dgeBT, N, 1.0, dgeCcopy = dgeC.clone(), M);
84+
f2j.dgemm("T", "T", M/2, N, K, 0.0, dgeAT, K, dgeBT, N, 1.0, expected = dgeC.clone(), M/2);
85+
blas.dgemm("T", "T", M/2, N, K, 0.0, dgeAT, K, dgeBT, N, 1.0, dgeCcopy = dgeC.clone(), M/2);
86+
assertArrayEquals(expected, dgeCcopy, depsilon);
87+
88+
f2j.dgemm("N", "N", M/2, N, K, 1.0, dgeA, M/2, dgeB, K, 2.0, expected = dgeC.clone(), M/2);
89+
blas.dgemm("N", "N", M/2, N, K, 1.0, dgeA, M/2, dgeB, K, 2.0, dgeCcopy = dgeC.clone(), M/2);
90+
assertArrayEquals(expected, dgeCcopy, depsilon);
91+
92+
f2j.dgemm("N", "T", M/2, N, K, 1.0, dgeA, M/2, dgeBT, N, 2.0, expected = dgeC.clone(), M/2);
93+
blas.dgemm("N", "T", M/2, N, K, 1.0, dgeA, M/2, dgeBT, N, 2.0, dgeCcopy = dgeC.clone(), M/2);
94+
assertArrayEquals(expected, dgeCcopy, depsilon);
95+
96+
f2j.dgemm("T", "N", M/2, N, K, 1.0, dgeAT, K, dgeB, K, 2.0, expected = dgeC.clone(), M/2);
97+
blas.dgemm("T", "N", M/2, N, K, 1.0, dgeAT, K, dgeB, K, 2.0, dgeCcopy = dgeC.clone(), M/2);
98+
assertArrayEquals(expected, dgeCcopy, depsilon);
99+
100+
f2j.dgemm("T", "T", M/2, N, K, 1.0, dgeAT, K, dgeBT, N, 2.0, expected = dgeC.clone(), M/2);
101+
blas.dgemm("T", "T", M/2, N, K, 1.0, dgeAT, K, dgeBT, N, 2.0, dgeCcopy = dgeC.clone(), M/2);
102+
assertArrayEquals(expected, dgeCcopy, depsilon);
103+
104+
f2j.dgemm("N", "N", M/2, N, K, 1.0, dgeA, M/2, dgeB, K, 0.0, expected = dgeC.clone(), M/2);
105+
blas.dgemm("N", "N", M/2, N, K, 1.0, dgeA, M/2, dgeB, K, 0.0, dgeCcopy = dgeC.clone(), M/2);
106+
assertArrayEquals(expected, dgeCcopy, depsilon);
107+
108+
f2j.dgemm("N", "T", M/2, N, K, 1.0, dgeA, M/2, dgeBT, N, 0.0, expected = dgeC.clone(), M/2);
109+
blas.dgemm("N", "T", M/2, N, K, 1.0, dgeA, M/2, dgeBT, N, 0.0, dgeCcopy = dgeC.clone(), M/2);
110+
assertArrayEquals(expected, dgeCcopy, depsilon);
111+
112+
f2j.dgemm("T", "N", M/2, N, K, 1.0, dgeAT, K, dgeB, K, 0.0, expected = dgeC.clone(), M/2);
113+
blas.dgemm("T", "N", M/2, N, K, 1.0, dgeAT, K, dgeB, K, 0.0, dgeCcopy = dgeC.clone(), M/2);
114+
assertArrayEquals(expected, dgeCcopy, depsilon);
115+
116+
f2j.dgemm("T", "T", M/2, N, K, 1.0, dgeAT, K, dgeBT, N, 0.0, expected = dgeC.clone(), M/2);
117+
blas.dgemm("T", "T", M/2, N, K, 1.0, dgeAT, K, dgeBT, N, 0.0, dgeCcopy = dgeC.clone(), M/2);
118+
assertArrayEquals(expected, dgeCcopy, depsilon);
119+
120+
f2j.dgemm("N", "N", M/2, N, K, 0.0, dgeA, M/2, dgeB, K, 1.0, expected = dgeC.clone(), M/2);
121+
blas.dgemm("N", "N", M/2, N, K, 0.0, dgeA, M/2, dgeB, K, 1.0, dgeCcopy = dgeC.clone(), M/2);
122+
assertArrayEquals(expected, dgeCcopy, depsilon);
123+
124+
f2j.dgemm("N", "T", M/2, N, K, 0.0, dgeA, M/2, dgeBT, N, 1.0, expected = dgeC.clone(), M/2);
125+
blas.dgemm("N", "T", M/2, N, K, 0.0, dgeA, M/2, dgeBT, N, 1.0, dgeCcopy = dgeC.clone(), M/2);
126+
assertArrayEquals(expected, dgeCcopy, depsilon);
127+
128+
f2j.dgemm("T", "N", M/2, N, K, 0.0, dgeAT, K, dgeB, K, 1.0, expected = dgeC.clone(), M/2);
129+
blas.dgemm("T", "N", M/2, N, K, 0.0, dgeAT, K, dgeB, K, 1.0, dgeCcopy = dgeC.clone(), M/2);
130+
assertArrayEquals(expected, dgeCcopy, depsilon);
131+
132+
f2j.dgemm("T", "T", M/2, N, K, 0.0, dgeAT, K, dgeBT, N, 1.0, expected = dgeC.clone(), M/2);
133+
blas.dgemm("T", "T", M/2, N, K, 0.0, dgeAT, K, dgeBT, N, 1.0, dgeCcopy = dgeC.clone(), M/2);
134+
assertArrayEquals(expected, dgeCcopy, depsilon);
135+
136+
f2j.dgemm("T", "T", M, N/2, K, 0.0, dgeAT, K, dgeBT, N/2, 1.0, expected = dgeC.clone(), M);
137+
blas.dgemm("T", "T", M, N/2, K, 0.0, dgeAT, K, dgeBT, N/2, 1.0, dgeCcopy = dgeC.clone(), M);
138+
assertArrayEquals(expected, dgeCcopy, depsilon);
139+
140+
f2j.dgemm("N", "N", M, N/2, K, 1.0, dgeA, M, dgeB, K, 2.0, expected = dgeC.clone(), M);
141+
blas.dgemm("N", "N", M, N/2, K, 1.0, dgeA, M, dgeB, K, 2.0, dgeCcopy = dgeC.clone(), M);
142+
assertArrayEquals(expected, dgeCcopy, depsilon);
143+
144+
f2j.dgemm("N", "T", M, N/2, K, 1.0, dgeA, M, dgeBT, N/2, 2.0, expected = dgeC.clone(), M);
145+
blas.dgemm("N", "T", M, N/2, K, 1.0, dgeA, M, dgeBT, N/2, 2.0, dgeCcopy = dgeC.clone(), M);
146+
assertArrayEquals(expected, dgeCcopy, depsilon);
147+
148+
f2j.dgemm("T", "N", M, N/2, K, 1.0, dgeAT, K, dgeB, K, 2.0, expected = dgeC.clone(), M);
149+
blas.dgemm("T", "N", M, N/2, K, 1.0, dgeAT, K, dgeB, K, 2.0, dgeCcopy = dgeC.clone(), M);
150+
assertArrayEquals(expected, dgeCcopy, depsilon);
151+
152+
f2j.dgemm("T", "T", M, N/2, K, 1.0, dgeAT, K, dgeBT, N/2, 2.0, expected = dgeC.clone(), M);
153+
blas.dgemm("T", "T", M, N/2, K, 1.0, dgeAT, K, dgeBT, N/2, 2.0, dgeCcopy = dgeC.clone(), M);
154+
assertArrayEquals(expected, dgeCcopy, depsilon);
155+
156+
f2j.dgemm("N", "N", M, N/2, K, 1.0, dgeA, M, dgeB, K, 0.0, expected = dgeC.clone(), M);
157+
blas.dgemm("N", "N", M, N/2, K, 1.0, dgeA, M, dgeB, K, 0.0, dgeCcopy = dgeC.clone(), M);
158+
assertArrayEquals(expected, dgeCcopy, depsilon);
159+
160+
f2j.dgemm("N", "T", M, N/2, K, 1.0, dgeA, M, dgeBT, N/2, 0.0, expected = dgeC.clone(), M);
161+
blas.dgemm("N", "T", M, N/2, K, 1.0, dgeA, M, dgeBT, N/2, 0.0, dgeCcopy = dgeC.clone(), M);
162+
assertArrayEquals(expected, dgeCcopy, depsilon);
163+
164+
f2j.dgemm("T", "N", M, N/2, K, 1.0, dgeAT, K, dgeB, K, 0.0, expected = dgeC.clone(), M);
165+
blas.dgemm("T", "N", M, N/2, K, 1.0, dgeAT, K, dgeB, K, 0.0, dgeCcopy = dgeC.clone(), M);
166+
assertArrayEquals(expected, dgeCcopy, depsilon);
167+
168+
f2j.dgemm("T", "T", M, N/2, K, 1.0, dgeAT, K, dgeBT, N/2, 0.0, expected = dgeC.clone(), M);
169+
blas.dgemm("T", "T", M, N/2, K, 1.0, dgeAT, K, dgeBT, N/2, 0.0, dgeCcopy = dgeC.clone(), M);
170+
assertArrayEquals(expected, dgeCcopy, depsilon);
171+
172+
f2j.dgemm("N", "N", M, N/2, K, 0.0, dgeA, M, dgeB, K, 1.0, expected = dgeC.clone(), M);
173+
blas.dgemm("N", "N", M, N/2, K, 0.0, dgeA, M, dgeB, K, 1.0, dgeCcopy = dgeC.clone(), M);
174+
assertArrayEquals(expected, dgeCcopy, depsilon);
175+
176+
f2j.dgemm("N", "T", M, N/2, K, 0.0, dgeA, M, dgeBT, N/2, 1.0, expected = dgeC.clone(), M);
177+
blas.dgemm("N", "T", M, N/2, K, 0.0, dgeA, M, dgeBT, N/2, 1.0, dgeCcopy = dgeC.clone(), M);
178+
assertArrayEquals(expected, dgeCcopy, depsilon);
179+
180+
f2j.dgemm("T", "N", M, N/2, K, 0.0, dgeAT, K, dgeB, K, 1.0, expected = dgeC.clone(), M);
181+
blas.dgemm("T", "N", M, N/2, K, 0.0, dgeAT, K, dgeB, K, 1.0, dgeCcopy = dgeC.clone(), M);
182+
assertArrayEquals(expected, dgeCcopy, depsilon);
183+
184+
f2j.dgemm("T", "T", M, N/2, K, 0.0, dgeAT, K, dgeBT, N/2, 1.0, expected = dgeC.clone(), M);
185+
blas.dgemm("T", "T", M, N/2, K, 0.0, dgeAT, K, dgeBT, N/2, 1.0, dgeCcopy = dgeC.clone(), M);
86186
assertArrayEquals(expected, dgeCcopy, depsilon);
87187
}
88188
}

0 commit comments

Comments
 (0)