Skip to content

Commit 6057823

Browse files
committed
fix an issue for F-contiguous arrays
1 parent 8a39e55 commit 6057823

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

dpnp/backend/extensions/blas/syrk.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,20 @@ static sycl::event syrk_impl(sycl::queue &exec_q,
152152
std::int64_t i = idx[0];
153153
std::int64_t j = idx[1];
154154
if (j > i) {
155-
res[j * ldc + i] = res[i * ldc + j];
155+
// result form row_major::syrk is row major and result form
156+
// column_major::syrk is column major, so copying upper
157+
// triangle to lower triangle is different for each case
158+
if (is_row_major) {
159+
// row-major: res[i][j] = res[i * ldc + j]
160+
res[j * ldc + i] = res[i * ldc + j];
161+
}
162+
else {
163+
// column-major: res[i][j] = res[i + j * ldc]
164+
res[i * ldc + j] = res[j * ldc + i];
165+
}
156166
}
157167
});
158168
});
159-
160169
return copy_event;
161170
}
162171

dpnp/tests/test_product.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1198,12 +1198,15 @@ def test_syrk(self, dt):
11981198
assert result is iout
11991199
assert_dtype_allclose(result, expected)
12001200

1201+
result = ia.mT @ ia
1202+
expected = a.T @ a
1203+
assert_dtype_allclose(result, expected)
1204+
12011205
@pytest.mark.parametrize(
12021206
"order, out_order",
12031207
[("C", "C"), ("C", "F"), ("F", "C"), ("F", "F")],
12041208
)
12051209
def test_syrk_out_order(self, order, out_order):
1206-
# test syrk with out keyword
12071210
a = generate_random_numpy_array((5, 4), order=order, low=-5, high=5)
12081211
out = numpy.empty((5, 5), dtype=a.dtype, order=out_order)
12091212
ia, iout = dpnp.array(a), dpnp.array(out)
@@ -1215,6 +1218,14 @@ def test_syrk_out_order(self, order, out_order):
12151218
assert result.flags.f_contiguous == expected.flags.f_contiguous
12161219
assert_dtype_allclose(result, expected)
12171220

1221+
@pytest.mark.parametrize("order", ["F", "C"])
1222+
def test_syrk_order(self, order):
1223+
a = generate_random_numpy_array((4, 6), order=order, low=-5, high=5)
1224+
ia = dpnp.array(a)
1225+
expected = numpy.matmul(a, a.T)
1226+
result = dpnp.matmul(ia, ia.mT)
1227+
assert_dtype_allclose(result, expected)
1228+
12181229
def test_bool(self):
12191230
a = generate_random_numpy_array((3, 4), dtype=dpnp.bool)
12201231
b = generate_random_numpy_array((4, 5), dtype=dpnp.bool)

0 commit comments

Comments
 (0)